18
18
#include " src/__support/common.h"
19
19
#include " src/__support/uint128.h"
20
20
21
+ #include " hdr/fenv_macros.h"
22
+
21
23
namespace LIBC_NAMESPACE {
22
24
namespace fputil {
23
25
@@ -64,40 +66,50 @@ LIBC_INLINE void normalize<long double>(int &exponent, UInt128 &mantissa) {
64
66
65
67
// Correctly rounded IEEE 754 SQRT for all rounding modes.
66
68
// Shift-and-add algorithm.
67
- template <typename T>
68
- LIBC_INLINE cpp::enable_if_t <cpp::is_floating_point_v<T>, T> sqrt (T x) {
69
-
70
- if constexpr (internal::SpecialLongDouble<T>::VALUE) {
69
+ template <typename OutType, typename InType>
70
+ LIBC_INLINE cpp::enable_if_t <cpp::is_floating_point_v<OutType> &&
71
+ cpp::is_floating_point_v<InType> &&
72
+ sizeof (OutType) <= sizeof (InType),
73
+ OutType>
74
+ sqrt (InType x) {
75
+ if constexpr (internal::SpecialLongDouble<OutType>::VALUE &&
76
+ internal::SpecialLongDouble<InType>::VALUE) {
71
77
// Special 80-bit long double.
72
78
return x86::sqrt (x);
73
79
} else {
74
80
// IEEE floating-point formats.
75
- using FPBits_t = typename fputil::FPBits<T>;
76
- using StorageType = typename FPBits_t::StorageType;
77
- constexpr StorageType ONE = StorageType (1 ) << FPBits_t::FRACTION_LEN;
78
- constexpr auto FLT_NAN = FPBits_t::quiet_nan ().get_val ();
79
-
80
- FPBits_t bits (x);
81
-
82
- if (bits == FPBits_t::inf (Sign::POS) || bits.is_zero () || bits.is_nan ()) {
81
+ using OutFPBits = typename fputil::FPBits<OutType>;
82
+ using OutStorageType = typename OutFPBits::StorageType;
83
+ using InFPBits = typename fputil::FPBits<InType>;
84
+ using InStorageType = typename InFPBits::StorageType;
85
+ constexpr InStorageType ONE = InStorageType (1 ) << InFPBits::FRACTION_LEN;
86
+ constexpr auto FLT_NAN = OutFPBits::quiet_nan ().get_val ();
87
+ constexpr int EXTRA_FRACTION_LEN =
88
+ InFPBits::FRACTION_LEN - OutFPBits::FRACTION_LEN;
89
+ constexpr InStorageType EXTRA_FRACTION_MASK =
90
+ (InStorageType (1 ) << EXTRA_FRACTION_LEN) - 1 ;
91
+
92
+ InFPBits bits (x);
93
+
94
+ if (bits == InFPBits::inf (Sign::POS) || bits.is_zero () || bits.is_nan ()) {
83
95
// sqrt(+Inf) = +Inf
84
96
// sqrt(+0) = +0
85
97
// sqrt(-0) = -0
86
98
// sqrt(NaN) = NaN
87
99
// sqrt(-NaN) = -NaN
88
- return x ;
100
+ return static_cast <OutType>(x) ;
89
101
} else if (bits.is_neg ()) {
90
102
// sqrt(-Inf) = NaN
91
103
// sqrt(-x) = NaN
92
104
return FLT_NAN;
93
105
} else {
94
106
int x_exp = bits.get_exponent ();
95
- StorageType x_mant = bits.get_mantissa ();
107
+ InStorageType x_mant = bits.get_mantissa ();
96
108
97
109
// Step 1a: Normalize denormal input and append hidden bit to the mantissa
98
110
if (bits.is_subnormal ()) {
99
111
++x_exp; // let x_exp be the correct exponent of ONE bit.
100
- internal::normalize<T >(x_exp, x_mant);
112
+ internal::normalize<InType >(x_exp, x_mant);
101
113
} else {
102
114
x_mant |= ONE;
103
115
}
@@ -120,47 +132,105 @@ LIBC_INLINE cpp::enable_if_t<cpp::is_floating_point_v<T>, T> sqrt(T x) {
120
132
// So the nth digit y_n of the mantissa of sqrt(x) can be found by:
121
133
// y_n = 1 if 2*r(n-1) >= 2*y(n - 1) + 2^(-n-1)
122
134
// 0 otherwise.
123
- StorageType y = ONE;
124
- StorageType r = x_mant - ONE;
135
+ InStorageType y = ONE;
136
+ InStorageType r = x_mant - ONE;
125
137
126
- for (StorageType current_bit = ONE >> 1 ; current_bit; current_bit >>= 1 ) {
138
+ for (InStorageType current_bit = ONE >> 1 ; current_bit;
139
+ current_bit >>= 1 ) {
127
140
r <<= 1 ;
128
- StorageType tmp = (y << 1 ) + current_bit; // 2*y(n - 1) + 2^(-n-1)
141
+ InStorageType tmp = (y << 1 ) + current_bit; // 2*y(n - 1) + 2^(-n-1)
129
142
if (r >= tmp) {
130
143
r -= tmp;
131
144
y += current_bit;
132
145
}
133
146
}
134
147
135
148
// We compute one more iteration in order to round correctly.
136
- bool lsb = static_cast <bool >(y & 1 ); // Least significant bit
137
- bool rb = false ; // Round bit
149
+ bool lsb = (y & (InStorageType (1 ) << EXTRA_FRACTION_LEN)) !=
150
+ 0 ; // Least significant bit
151
+ bool rb = false ; // Round bit
138
152
r <<= 2 ;
139
- StorageType tmp = (y << 2 ) + 1 ;
153
+ InStorageType tmp = (y << 2 ) + 1 ;
140
154
if (r >= tmp) {
141
155
r -= tmp;
142
156
rb = true ;
143
157
}
144
158
159
+ bool sticky = false ;
160
+
161
+ if constexpr (EXTRA_FRACTION_LEN > 0 ) {
162
+ sticky = rb || (y & EXTRA_FRACTION_MASK) != 0 ;
163
+ rb = (y & (InStorageType (1 ) << (EXTRA_FRACTION_LEN - 1 ))) != 0 ;
164
+ }
165
+
145
166
// Remove hidden bit and append the exponent field.
146
- x_exp = ((x_exp >> 1 ) + FPBits_t::EXP_BIAS);
167
+ x_exp = ((x_exp >> 1 ) + OutFPBits::EXP_BIAS);
168
+
169
+ OutStorageType y_out = static_cast <OutStorageType>(
170
+ ((y - ONE) >> EXTRA_FRACTION_LEN) |
171
+ (static_cast <OutStorageType>(x_exp) << OutFPBits::FRACTION_LEN));
172
+
173
+ if constexpr (EXTRA_FRACTION_LEN > 0 ) {
174
+ if (x_exp >= OutFPBits::MAX_BIASED_EXPONENT) {
175
+ switch (quick_get_round ()) {
176
+ case FE_TONEAREST:
177
+ case FE_UPWARD:
178
+ return OutFPBits::inf ().get_val ();
179
+ default :
180
+ return OutFPBits::max_normal ().get_val ();
181
+ }
182
+ }
183
+
184
+ if (x_exp <
185
+ -OutFPBits::EXP_BIAS - OutFPBits::SIG_LEN + EXTRA_FRACTION_LEN) {
186
+ switch (quick_get_round ()) {
187
+ case FE_UPWARD:
188
+ return OutFPBits::min_subnormal ().get_val ();
189
+ default :
190
+ return OutType (0.0 );
191
+ }
192
+ }
147
193
148
- y = (y - ONE) |
149
- (static_cast <StorageType>(x_exp) << FPBits_t::FRACTION_LEN);
194
+ if (x_exp <= 0 ) {
195
+ int underflow_extra_fraction_len = EXTRA_FRACTION_LEN - x_exp + 1 ;
196
+ InStorageType underflow_extra_fraction_mask =
197
+ (InStorageType (1 ) << underflow_extra_fraction_len) - 1 ;
198
+
199
+ rb = (y & (InStorageType (1 ) << (underflow_extra_fraction_len - 1 ))) !=
200
+ 0 ;
201
+ OutStorageType subnormal_mant =
202
+ static_cast <OutStorageType>(y >> underflow_extra_fraction_len);
203
+ lsb = (subnormal_mant & 1 ) != 0 ;
204
+ sticky = sticky || (y & underflow_extra_fraction_mask) != 0 ;
205
+
206
+ switch (quick_get_round ()) {
207
+ case FE_TONEAREST:
208
+ if (rb && (lsb || sticky))
209
+ ++subnormal_mant;
210
+ break ;
211
+ case FE_UPWARD:
212
+ if (rb || sticky)
213
+ ++subnormal_mant;
214
+ break ;
215
+ }
216
+
217
+ return cpp::bit_cast<OutType>(subnormal_mant);
218
+ }
219
+ }
150
220
151
221
switch (quick_get_round ()) {
152
222
case FE_TONEAREST:
153
223
// Round to nearest, ties to even
154
224
if (rb && (lsb || (r != 0 )))
155
- ++y ;
225
+ ++y_out ;
156
226
break ;
157
227
case FE_UPWARD:
158
- if (rb || (r != 0 ))
159
- ++y ;
228
+ if (rb || (r != 0 ) || sticky )
229
+ ++y_out ;
160
230
break ;
161
231
}
162
232
163
- return cpp::bit_cast<T>(y );
233
+ return cpp::bit_cast<OutType>(y_out );
164
234
}
165
235
}
166
236
}
0 commit comments