Skip to content

Commit a239343

Browse files
authored
[libc][math][c23] Add f16sqrtf C23 math function (#95251)
Part of #95250.
1 parent 389142e commit a239343

29 files changed

+352
-123
lines changed

libc/config/linux/aarch64/entrypoints.txt

+1
Original file line numberDiff line numberDiff line change
@@ -503,6 +503,7 @@ if(LIBC_TYPES_HAS_FLOAT16)
503503
libc.src.math.canonicalizef16
504504
libc.src.math.ceilf16
505505
libc.src.math.copysignf16
506+
libc.src.math.f16sqrtf
506507
libc.src.math.fabsf16
507508
libc.src.math.fdimf16
508509
libc.src.math.floorf16

libc/config/linux/x86_64/entrypoints.txt

+1
Original file line numberDiff line numberDiff line change
@@ -535,6 +535,7 @@ if(LIBC_TYPES_HAS_FLOAT16)
535535
libc.src.math.canonicalizef16
536536
libc.src.math.ceilf16
537537
libc.src.math.copysignf16
538+
libc.src.math.f16sqrtf
538539
libc.src.math.fabsf16
539540
libc.src.math.fdimf16
540541
libc.src.math.floorf16

libc/docs/math/index.rst

+2
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,8 @@ Higher Math Functions
280280
+-----------+------------------+-----------------+------------------------+----------------------+------------------------+------------------------+----------------------------+
281281
| fma | |check| | |check| | | | | 7.12.13.1 | F.10.10.1 |
282282
+-----------+------------------+-----------------+------------------------+----------------------+------------------------+------------------------+----------------------------+
283+
| f16sqrt | |check| | | | N/A | | 7.12.14.6 | F.10.11 |
284+
+-----------+------------------+-----------------+------------------------+----------------------+------------------------+------------------------+----------------------------+
283285
| fsqrt | N/A | | | N/A | | 7.12.14.6 | F.10.11 |
284286
+-----------+------------------+-----------------+------------------------+----------------------+------------------------+------------------------+----------------------------+
285287
| hypot | |check| | |check| | | | | 7.12.7.4 | F.10.4.4 |

libc/spec/stdc.td

+2
Original file line numberDiff line numberDiff line change
@@ -714,6 +714,8 @@ def StdC : StandardSpec<"stdc"> {
714714
GuardedFunctionSpec<"totalorderf16", RetValSpec<IntType>, [ArgSpec<Float16Ptr>, ArgSpec<Float16Ptr>], "LIBC_TYPES_HAS_FLOAT16">,
715715

716716
GuardedFunctionSpec<"totalordermagf16", RetValSpec<IntType>, [ArgSpec<Float16Ptr>, ArgSpec<Float16Ptr>], "LIBC_TYPES_HAS_FLOAT16">,
717+
718+
GuardedFunctionSpec<"f16sqrtf", RetValSpec<Float16Type>, [ArgSpec<FloatType>], "LIBC_TYPES_HAS_FLOAT16">,
717719
]
718720
>;
719721

libc/src/__support/FPUtil/generic/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ add_header_library(
44
sqrt.h
55
sqrt_80_bit_long_double.h
66
DEPENDS
7+
libc.hdr.fenv_macros
78
libc.src.__support.common
89
libc.src.__support.CPP.bit
910
libc.src.__support.CPP.type_traits

libc/src/__support/FPUtil/generic/sqrt.h

+99-29
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
#include "src/__support/common.h"
1919
#include "src/__support/uint128.h"
2020

21+
#include "hdr/fenv_macros.h"
22+
2123
namespace LIBC_NAMESPACE {
2224
namespace fputil {
2325

@@ -64,40 +66,50 @@ LIBC_INLINE void normalize<long double>(int &exponent, UInt128 &mantissa) {
6466

6567
// Correctly rounded IEEE 754 SQRT for all rounding modes.
6668
// Shift-and-add algorithm.
67-
template <typename T>
68-
LIBC_INLINE cpp::enable_if_t<cpp::is_floating_point_v<T>, T> sqrt(T x) {
69-
70-
if constexpr (internal::SpecialLongDouble<T>::VALUE) {
69+
template <typename OutType, typename InType>
70+
LIBC_INLINE cpp::enable_if_t<cpp::is_floating_point_v<OutType> &&
71+
cpp::is_floating_point_v<InType> &&
72+
sizeof(OutType) <= sizeof(InType),
73+
OutType>
74+
sqrt(InType x) {
75+
if constexpr (internal::SpecialLongDouble<OutType>::VALUE &&
76+
internal::SpecialLongDouble<InType>::VALUE) {
7177
// Special 80-bit long double.
7278
return x86::sqrt(x);
7379
} else {
7480
// IEEE floating points formats.
75-
using FPBits_t = typename fputil::FPBits<T>;
76-
using StorageType = typename FPBits_t::StorageType;
77-
constexpr StorageType ONE = StorageType(1) << FPBits_t::FRACTION_LEN;
78-
constexpr auto FLT_NAN = FPBits_t::quiet_nan().get_val();
79-
80-
FPBits_t bits(x);
81-
82-
if (bits == FPBits_t::inf(Sign::POS) || bits.is_zero() || bits.is_nan()) {
81+
using OutFPBits = typename fputil::FPBits<OutType>;
82+
using OutStorageType = typename OutFPBits::StorageType;
83+
using InFPBits = typename fputil::FPBits<InType>;
84+
using InStorageType = typename InFPBits::StorageType;
85+
constexpr InStorageType ONE = InStorageType(1) << InFPBits::FRACTION_LEN;
86+
constexpr auto FLT_NAN = OutFPBits::quiet_nan().get_val();
87+
constexpr int EXTRA_FRACTION_LEN =
88+
InFPBits::FRACTION_LEN - OutFPBits::FRACTION_LEN;
89+
constexpr InStorageType EXTRA_FRACTION_MASK =
90+
(InStorageType(1) << EXTRA_FRACTION_LEN) - 1;
91+
92+
InFPBits bits(x);
93+
94+
if (bits == InFPBits::inf(Sign::POS) || bits.is_zero() || bits.is_nan()) {
8395
// sqrt(+Inf) = +Inf
8496
// sqrt(+0) = +0
8597
// sqrt(-0) = -0
8698
// sqrt(NaN) = NaN
8799
// sqrt(-NaN) = -NaN
88-
return x;
100+
return static_cast<OutType>(x);
89101
} else if (bits.is_neg()) {
90102
// sqrt(-Inf) = NaN
91103
// sqrt(-x) = NaN
92104
return FLT_NAN;
93105
} else {
94106
int x_exp = bits.get_exponent();
95-
StorageType x_mant = bits.get_mantissa();
107+
InStorageType x_mant = bits.get_mantissa();
96108

97109
// Step 1a: Normalize denormal input and append hidden bit to the mantissa
98110
if (bits.is_subnormal()) {
99111
++x_exp; // let x_exp be the correct exponent of ONE bit.
100-
internal::normalize<T>(x_exp, x_mant);
112+
internal::normalize<InType>(x_exp, x_mant);
101113
} else {
102114
x_mant |= ONE;
103115
}
@@ -120,47 +132,105 @@ LIBC_INLINE cpp::enable_if_t<cpp::is_floating_point_v<T>, T> sqrt(T x) {
120132
// So the nth digit y_n of the mantissa of sqrt(x) can be found by:
121133
// y_n = 1 if 2*r(n-1) >= 2*y(n - 1) + 2^(-n-1)
122134
// 0 otherwise.
123-
StorageType y = ONE;
124-
StorageType r = x_mant - ONE;
135+
InStorageType y = ONE;
136+
InStorageType r = x_mant - ONE;
125137

126-
for (StorageType current_bit = ONE >> 1; current_bit; current_bit >>= 1) {
138+
for (InStorageType current_bit = ONE >> 1; current_bit;
139+
current_bit >>= 1) {
127140
r <<= 1;
128-
StorageType tmp = (y << 1) + current_bit; // 2*y(n - 1) + 2^(-n-1)
141+
InStorageType tmp = (y << 1) + current_bit; // 2*y(n - 1) + 2^(-n-1)
129142
if (r >= tmp) {
130143
r -= tmp;
131144
y += current_bit;
132145
}
133146
}
134147

135148
// We compute one more iteration in order to round correctly.
136-
bool lsb = static_cast<bool>(y & 1); // Least significant bit
137-
bool rb = false; // Round bit
149+
bool lsb = (y & (InStorageType(1) << EXTRA_FRACTION_LEN)) !=
150+
0; // Least significant bit
151+
bool rb = false; // Round bit
138152
r <<= 2;
139-
StorageType tmp = (y << 2) + 1;
153+
InStorageType tmp = (y << 2) + 1;
140154
if (r >= tmp) {
141155
r -= tmp;
142156
rb = true;
143157
}
144158

159+
bool sticky = false;
160+
161+
if constexpr (EXTRA_FRACTION_LEN > 0) {
162+
sticky = rb || (y & EXTRA_FRACTION_MASK) != 0;
163+
rb = (y & (InStorageType(1) << (EXTRA_FRACTION_LEN - 1))) != 0;
164+
}
165+
145166
// Remove hidden bit and append the exponent field.
146-
x_exp = ((x_exp >> 1) + FPBits_t::EXP_BIAS);
167+
x_exp = ((x_exp >> 1) + OutFPBits::EXP_BIAS);
168+
169+
OutStorageType y_out = static_cast<OutStorageType>(
170+
((y - ONE) >> EXTRA_FRACTION_LEN) |
171+
(static_cast<OutStorageType>(x_exp) << OutFPBits::FRACTION_LEN));
172+
173+
if constexpr (EXTRA_FRACTION_LEN > 0) {
174+
if (x_exp >= OutFPBits::MAX_BIASED_EXPONENT) {
175+
switch (quick_get_round()) {
176+
case FE_TONEAREST:
177+
case FE_UPWARD:
178+
return OutFPBits::inf().get_val();
179+
default:
180+
return OutFPBits::max_normal().get_val();
181+
}
182+
}
183+
184+
if (x_exp <
185+
-OutFPBits::EXP_BIAS - OutFPBits::SIG_LEN + EXTRA_FRACTION_LEN) {
186+
switch (quick_get_round()) {
187+
case FE_UPWARD:
188+
return OutFPBits::min_subnormal().get_val();
189+
default:
190+
return OutType(0.0);
191+
}
192+
}
147193

148-
y = (y - ONE) |
149-
(static_cast<StorageType>(x_exp) << FPBits_t::FRACTION_LEN);
194+
if (x_exp <= 0) {
195+
int underflow_extra_fraction_len = EXTRA_FRACTION_LEN - x_exp + 1;
196+
InStorageType underflow_extra_fraction_mask =
197+
(InStorageType(1) << underflow_extra_fraction_len) - 1;
198+
199+
rb = (y & (InStorageType(1) << (underflow_extra_fraction_len - 1))) !=
200+
0;
201+
OutStorageType subnormal_mant =
202+
static_cast<OutStorageType>(y >> underflow_extra_fraction_len);
203+
lsb = (subnormal_mant & 1) != 0;
204+
sticky = sticky || (y & underflow_extra_fraction_mask) != 0;
205+
206+
switch (quick_get_round()) {
207+
case FE_TONEAREST:
208+
if (rb && (lsb || sticky))
209+
++subnormal_mant;
210+
break;
211+
case FE_UPWARD:
212+
if (rb || sticky)
213+
++subnormal_mant;
214+
break;
215+
}
216+
217+
return cpp::bit_cast<OutType>(subnormal_mant);
218+
}
219+
}
150220

151221
switch (quick_get_round()) {
152222
case FE_TONEAREST:
153223
// Round to nearest, ties to even
154224
if (rb && (lsb || (r != 0)))
155-
++y;
225+
++y_out;
156226
break;
157227
case FE_UPWARD:
158-
if (rb || (r != 0))
159-
++y;
228+
if (rb || (r != 0) || sticky)
229+
++y_out;
160230
break;
161231
}
162232

163-
return cpp::bit_cast<T>(y);
233+
return cpp::bit_cast<OutType>(y_out);
164234
}
165235
}
166236
}

libc/src/math/CMakeLists.txt

+2
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,8 @@ add_math_entrypoint_object(exp10f)
9999
add_math_entrypoint_object(expm1)
100100
add_math_entrypoint_object(expm1f)
101101

102+
add_math_entrypoint_object(f16sqrtf)
103+
102104
add_math_entrypoint_object(fabs)
103105
add_math_entrypoint_object(fabsf)
104106
add_math_entrypoint_object(fabsl)

libc/src/math/f16sqrtf.h

+20
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
//===-- Implementation header for f16sqrtf ----------------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef LLVM_LIBC_SRC_MATH_F16SQRTF_H
10+
#define LLVM_LIBC_SRC_MATH_F16SQRTF_H
11+
12+
#include "src/__support/macros/properties/types.h"
13+
14+
namespace LIBC_NAMESPACE {
15+
16+
float16 f16sqrtf(float x);
17+
18+
} // namespace LIBC_NAMESPACE
19+
20+
#endif // LLVM_LIBC_SRC_MATH_F16SQRTF_H

libc/src/math/generic/CMakeLists.txt

+13
Original file line numberDiff line numberDiff line change
@@ -3601,3 +3601,16 @@ add_entrypoint_object(
36013601
COMPILE_OPTIONS
36023602
-O3
36033603
)
3604+
3605+
add_entrypoint_object(
3606+
f16sqrtf
3607+
SRCS
3608+
f16sqrtf.cpp
3609+
HDRS
3610+
../f16sqrtf.h
3611+
DEPENDS
3612+
libc.src.__support.macros.properties.types
3613+
libc.src.__support.FPUtil.sqrt
3614+
COMPILE_OPTIONS
3615+
-O3
3616+
)

libc/src/math/generic/acosf.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ LLVM_LIBC_FUNCTION(float, acosf, (float x)) {
113113
xbits.set_sign(Sign::POS);
114114
double xd = static_cast<double>(xbits.get_val());
115115
double u = fputil::multiply_add(-0.5, xd, 0.5);
116-
double cv = 2 * fputil::sqrt(u);
116+
double cv = 2 * fputil::sqrt<double>(u);
117117

118118
double r3 = asin_eval(u);
119119
double r = fputil::multiply_add(cv * u, r3, cv);

libc/src/math/generic/acoshf.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,8 @@ LLVM_LIBC_FUNCTION(float, acoshf, (float x)) {
6666

6767
double x_d = static_cast<double>(x);
6868
// acosh(x) = log(x + sqrt(x^2 - 1))
69-
return static_cast<float>(
70-
log_eval(x_d + fputil::sqrt(fputil::multiply_add(x_d, x_d, -1.0))));
69+
return static_cast<float>(log_eval(
70+
x_d + fputil::sqrt<double>(fputil::multiply_add(x_d, x_d, -1.0))));
7171
}
7272

7373
} // namespace LIBC_NAMESPACE

libc/src/math/generic/asinf.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ LLVM_LIBC_FUNCTION(float, asinf, (float x)) {
144144
double sign = SIGN[x_sign];
145145
double xd = static_cast<double>(xbits.get_val());
146146
double u = fputil::multiply_add(-0.5, xd, 0.5);
147-
double c1 = sign * (-2 * fputil::sqrt(u));
147+
double c1 = sign * (-2 * fputil::sqrt<double>(u));
148148
double c2 = fputil::multiply_add(sign, M_MATH_PI_2, c1);
149149
double c3 = c1 * u;
150150

libc/src/math/generic/asinhf.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -97,9 +97,9 @@ LLVM_LIBC_FUNCTION(float, asinhf, (float x)) {
9797

9898
// asinh(x) = log(x + sqrt(x^2 + 1))
9999
return static_cast<float>(
100-
x_sign *
101-
log_eval(fputil::multiply_add(
102-
x_d, x_sign, fputil::sqrt(fputil::multiply_add(x_d, x_d, 1.0)))));
100+
x_sign * log_eval(fputil::multiply_add(
101+
x_d, x_sign,
102+
fputil::sqrt<double>(fputil::multiply_add(x_d, x_d, 1.0)))));
103103
}
104104

105105
} // namespace LIBC_NAMESPACE

libc/src/math/generic/f16sqrtf.cpp

+19
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
//===-- Implementation of f16sqrtf function -------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "src/math/f16sqrtf.h"
10+
#include "src/__support/FPUtil/sqrt.h"
11+
#include "src/__support/common.h"
12+
13+
namespace LIBC_NAMESPACE {
14+
15+
LLVM_LIBC_FUNCTION(float16, f16sqrtf, (float x)) {
16+
return fputil::sqrt<float16>(x);
17+
}
18+
19+
} // namespace LIBC_NAMESPACE

libc/src/math/generic/hypotf.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ LLVM_LIBC_FUNCTION(float, hypotf, (float x, float y)) {
4242
double err = (x_sq >= y_sq) ? (sum_sq - x_sq) - y_sq : (sum_sq - y_sq) - x_sq;
4343

4444
// Take sqrt in double precision.
45-
DoubleBits result(fputil::sqrt(sum_sq));
45+
DoubleBits result(fputil::sqrt<double>(sum_sq));
4646

4747
if (!DoubleBits(sum_sq).is_inf_or_nan()) {
4848
// Correct rounding.

libc/src/math/generic/powf.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -562,7 +562,7 @@ LLVM_LIBC_FUNCTION(float, powf, (float x, float y)) {
562562
switch (y_u) {
563563
case 0x3f00'0000: // y = 0.5f
564564
// pow(x, 1/2) = sqrt(x)
565-
return fputil::sqrt(x);
565+
return fputil::sqrt<float>(x);
566566
case 0x3f80'0000: // y = 1.0f
567567
return x;
568568
case 0x4000'0000: // y = 2.0f

libc/src/math/generic/sqrt.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,6 @@
1212

1313
namespace LIBC_NAMESPACE {
1414

15-
LLVM_LIBC_FUNCTION(double, sqrt, (double x)) { return fputil::sqrt(x); }
15+
LLVM_LIBC_FUNCTION(double, sqrt, (double x)) { return fputil::sqrt<double>(x); }
1616

1717
} // namespace LIBC_NAMESPACE

libc/src/math/generic/sqrtf.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,6 @@
1212

1313
namespace LIBC_NAMESPACE {
1414

15-
LLVM_LIBC_FUNCTION(float, sqrtf, (float x)) { return fputil::sqrt(x); }
15+
LLVM_LIBC_FUNCTION(float, sqrtf, (float x)) { return fputil::sqrt<float>(x); }
1616

1717
} // namespace LIBC_NAMESPACE

libc/src/math/generic/sqrtf128.cpp

+3-1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212

1313
namespace LIBC_NAMESPACE {
1414

15-
LLVM_LIBC_FUNCTION(float128, sqrtf128, (float128 x)) { return fputil::sqrt(x); }
15+
LLVM_LIBC_FUNCTION(float128, sqrtf128, (float128 x)) {
16+
return fputil::sqrt<float128>(x);
17+
}
1618

1719
} // namespace LIBC_NAMESPACE

libc/src/math/generic/sqrtl.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
namespace LIBC_NAMESPACE {
1414

1515
LLVM_LIBC_FUNCTION(long double, sqrtl, (long double x)) {
16-
return fputil::sqrt(x);
16+
return fputil::sqrt<long double>(x);
1717
}
1818

1919
} // namespace LIBC_NAMESPACE

0 commit comments

Comments
 (0)