xref: /freebsd/contrib/llvm-project/libc/src/__support/FPUtil/dyadic_float.h (revision bb722a7d0f1642bff6487f943ad0427799a6e5bf)
1 //===-- A class to store high precision floating point numbers --*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef LLVM_LIBC_SRC___SUPPORT_FPUTIL_DYADIC_FLOAT_H
10 #define LLVM_LIBC_SRC___SUPPORT_FPUTIL_DYADIC_FLOAT_H
11 
12 #include "FEnvImpl.h"
13 #include "FPBits.h"
14 #include "hdr/errno_macros.h"
15 #include "hdr/fenv_macros.h"
16 #include "multiply_add.h"
17 #include "rounding_mode.h"
18 #include "src/__support/CPP/type_traits.h"
19 #include "src/__support/big_int.h"
20 #include "src/__support/macros/config.h"
21 #include "src/__support/macros/optimization.h" // LIBC_UNLIKELY
22 #include "src/__support/macros/properties/types.h"
23 
24 #include <stddef.h>
25 
26 namespace LIBC_NAMESPACE_DECL {
27 namespace fputil {
28 
29 // Decide whether to round a UInt up, down or not at all at a given bit
30 // position, based on the current rounding mode. The assumption is that the
31 // caller is going to make the integer `value >> rshift`, and then might need
32 // to round it up by 1 depending on the value of the bits shifted off the
33 // bottom.
34 //
35 // `logical_sign` causes the behavior of FE_DOWNWARD and FE_UPWARD to
36 // be reversed, which is what you'd want if this is the mantissa of a
37 // negative floating-point number.
38 //
39 // Return value is +1 if the value should be rounded up; -1 if it should be
40 // rounded down; 0 if it's exact and needs no rounding.
41 template <size_t Bits>
42 LIBC_INLINE constexpr int
rounding_direction(const LIBC_NAMESPACE::UInt<Bits> & value,size_t rshift,Sign logical_sign)43 rounding_direction(const LIBC_NAMESPACE::UInt<Bits> &value, size_t rshift,
44                    Sign logical_sign) {
45   if (rshift == 0 || (rshift < Bits && (value << (Bits - rshift)) == 0) ||
46       (rshift >= Bits && value == 0))
47     return 0; // exact
48 
49   switch (quick_get_round()) {
50   case FE_TONEAREST:
51     if (rshift > 0 && rshift <= Bits && value.get_bit(rshift - 1)) {
52       // We round up, unless the value is an exact halfway case and
53       // the bit that will end up in the units place is 0, in which
54       // case tie-break-to-even says round down.
55       bool round_bit = rshift < Bits ? value.get_bit(rshift) : 0;
56       return round_bit != 0 || (value << (Bits - rshift + 1)) != 0 ? +1 : -1;
57     } else {
58       return -1;
59     }
60   case FE_TOWARDZERO:
61     return -1;
62   case FE_DOWNWARD:
63     return logical_sign.is_neg() &&
64                    (rshift < Bits && (value << (Bits - rshift)) != 0)
65                ? +1
66                : -1;
67   case FE_UPWARD:
68     return logical_sign.is_pos() &&
69                    (rshift < Bits && (value << (Bits - rshift)) != 0)
70                ? +1
71                : -1;
72   default:
73     __builtin_unreachable();
74   }
75 }
76 
77 // A generic class to perform computations of high precision floating points.
78 // We store the value in dyadic format, including 3 fields:
79 //   sign    : boolean value - false means positive, true means negative
80 //   exponent: the exponent value of the least significant bit of the mantissa.
81 //   mantissa: unsigned integer of length `Bits`.
82 // So the real value that is stored is:
83 //   real value = (-1)^sign * 2^exponent * (mantissa as unsigned integer)
84 // The stored data is normal if for non-zero mantissa, the leading bit is 1.
85 // The outputs of the constructors and most functions will be normalized.
86 // To simplify and improve the efficiency, many functions will assume that the
87 // inputs are normal.
88 template <size_t Bits> struct DyadicFloat {
89   using MantissaType = LIBC_NAMESPACE::UInt<Bits>;
90 
91   Sign sign = Sign::POS;
92   int exponent = 0;
93   MantissaType mantissa = MantissaType(0);
94 
95   LIBC_INLINE constexpr DyadicFloat() = default;
96 
97   template <typename T, cpp::enable_if_t<cpp::is_floating_point_v<T>, int> = 0>
DyadicFloatDyadicFloat98   LIBC_INLINE constexpr DyadicFloat(T x) {
99     static_assert(FPBits<T>::FRACTION_LEN < Bits);
100     FPBits<T> x_bits(x);
101     sign = x_bits.sign();
102     exponent = x_bits.get_explicit_exponent() - FPBits<T>::FRACTION_LEN;
103     mantissa = MantissaType(x_bits.get_explicit_mantissa());
104     normalize();
105   }
106 
DyadicFloatDyadicFloat107   LIBC_INLINE constexpr DyadicFloat(Sign s, int e, const MantissaType &m)
108       : sign(s), exponent(e), mantissa(m) {
109     normalize();
110   }
111 
112   // Normalizing the mantissa, bringing the leading 1 bit to the most
113   // significant bit.
normalizeDyadicFloat114   LIBC_INLINE constexpr DyadicFloat &normalize() {
115     if (!mantissa.is_zero()) {
116       int shift_length = cpp::countl_zero(mantissa);
117       exponent -= shift_length;
118       mantissa <<= static_cast<size_t>(shift_length);
119     }
120     return *this;
121   }
122 
123   // Used for aligning exponents.  Output might not be normalized.
shift_leftDyadicFloat124   LIBC_INLINE constexpr DyadicFloat &shift_left(unsigned shift_length) {
125     if (shift_length < Bits) {
126       exponent -= static_cast<int>(shift_length);
127       mantissa <<= shift_length;
128     } else {
129       exponent = 0;
130       mantissa = MantissaType(0);
131     }
132     return *this;
133   }
134 
135   // Used for aligning exponents.  Output might not be normalized.
shift_rightDyadicFloat136   LIBC_INLINE constexpr DyadicFloat &shift_right(unsigned shift_length) {
137     if (shift_length < Bits) {
138       exponent += static_cast<int>(shift_length);
139       mantissa >>= shift_length;
140     } else {
141       exponent = 0;
142       mantissa = MantissaType(0);
143     }
144     return *this;
145   }
146 
147   // Assume that it is already normalized.  Output the unbiased exponent.
get_unbiased_exponentDyadicFloat148   LIBC_INLINE constexpr int get_unbiased_exponent() const {
149     return exponent + (Bits - 1);
150   }
151 
152   // Produce a correctly rounded DyadicFloat from a too-large mantissa,
153   // by shifting it down and rounding if necessary.
154   template <size_t MantissaBits>
155   LIBC_INLINE constexpr static DyadicFloat<Bits>
roundDyadicFloat156   round(Sign result_sign, int result_exponent,
157         const LIBC_NAMESPACE::UInt<MantissaBits> &input_mantissa,
158         size_t rshift) {
159     MantissaType result_mantissa(input_mantissa >> rshift);
160     if (rounding_direction(input_mantissa, rshift, result_sign) > 0) {
161       ++result_mantissa;
162       if (result_mantissa == 0) {
163         // Rounding up made the mantissa integer wrap round to 0,
164         // carrying a bit off the top. So we've rounded up to the next
165         // exponent.
166         result_mantissa.set_bit(Bits - 1);
167         ++result_exponent;
168       }
169     }
170     return DyadicFloat(result_sign, result_exponent, result_mantissa);
171   }
172 
173   template <typename T, bool ShouldSignalExceptions>
174   LIBC_INLINE constexpr cpp::enable_if_t<
175       cpp::is_floating_point_v<T> && (FPBits<T>::FRACTION_LEN < Bits), T>
generic_asDyadicFloat176   generic_as() const {
177     using FPBits = FPBits<T>;
178     using StorageType = typename FPBits::StorageType;
179 
180     constexpr int EXTRA_FRACTION_LEN = Bits - 1 - FPBits::FRACTION_LEN;
181 
182     if (mantissa == 0)
183       return FPBits::zero(sign).get_val();
184 
185     int unbiased_exp = get_unbiased_exponent();
186 
187     if (unbiased_exp + FPBits::EXP_BIAS >= FPBits::MAX_BIASED_EXPONENT) {
188       if constexpr (ShouldSignalExceptions) {
189         set_errno_if_required(ERANGE);
190         raise_except_if_required(FE_OVERFLOW | FE_INEXACT);
191       }
192 
193       switch (quick_get_round()) {
194       case FE_TONEAREST:
195         return FPBits::inf(sign).get_val();
196       case FE_TOWARDZERO:
197         return FPBits::max_normal(sign).get_val();
198       case FE_DOWNWARD:
199         if (sign.is_pos())
200           return FPBits::max_normal(Sign::POS).get_val();
201         return FPBits::inf(Sign::NEG).get_val();
202       case FE_UPWARD:
203         if (sign.is_neg())
204           return FPBits::max_normal(Sign::NEG).get_val();
205         return FPBits::inf(Sign::POS).get_val();
206       default:
207         __builtin_unreachable();
208       }
209     }
210 
211     StorageType out_biased_exp = 0;
212     StorageType out_mantissa = 0;
213     bool round = false;
214     bool sticky = false;
215     bool underflow = false;
216 
217     if (unbiased_exp < -FPBits::EXP_BIAS - FPBits::FRACTION_LEN) {
218       sticky = true;
219       underflow = true;
220     } else if (unbiased_exp == -FPBits::EXP_BIAS - FPBits::FRACTION_LEN) {
221       round = true;
222       MantissaType sticky_mask = (MantissaType(1) << (Bits - 1)) - 1;
223       sticky = (mantissa & sticky_mask) != 0;
224     } else {
225       int extra_fraction_len = EXTRA_FRACTION_LEN;
226 
227       if (unbiased_exp < 1 - FPBits::EXP_BIAS) {
228         underflow = true;
229         extra_fraction_len += 1 - FPBits::EXP_BIAS - unbiased_exp;
230       } else {
231         out_biased_exp =
232             static_cast<StorageType>(unbiased_exp + FPBits::EXP_BIAS);
233       }
234 
235       MantissaType round_mask = MantissaType(1) << (extra_fraction_len - 1);
236       round = (mantissa & round_mask) != 0;
237       MantissaType sticky_mask = round_mask - 1;
238       sticky = (mantissa & sticky_mask) != 0;
239 
240       out_mantissa = static_cast<StorageType>(mantissa >> extra_fraction_len);
241     }
242 
243     bool lsb = (out_mantissa & 1) != 0;
244 
245     StorageType result =
246         FPBits::create_value(sign, out_biased_exp, out_mantissa).uintval();
247 
248     switch (quick_get_round()) {
249     case FE_TONEAREST:
250       if (round && (lsb || sticky))
251         ++result;
252       break;
253     case FE_DOWNWARD:
254       if (sign.is_neg() && (round || sticky))
255         ++result;
256       break;
257     case FE_UPWARD:
258       if (sign.is_pos() && (round || sticky))
259         ++result;
260       break;
261     default:
262       break;
263     }
264 
265     if (ShouldSignalExceptions && (round || sticky)) {
266       int excepts = FE_INEXACT;
267       if (FPBits(result).is_inf()) {
268         set_errno_if_required(ERANGE);
269         excepts |= FE_OVERFLOW;
270       } else if (underflow) {
271         set_errno_if_required(ERANGE);
272         excepts |= FE_UNDERFLOW;
273       }
274       raise_except_if_required(excepts);
275     }
276 
277     return FPBits(result).get_val();
278   }
279 
280   template <typename T, bool ShouldSignalExceptions,
281             typename = cpp::enable_if_t<cpp::is_floating_point_v<T> &&
282                                             (FPBits<T>::FRACTION_LEN < Bits),
283                                         void>>
fast_asDyadicFloat284   LIBC_INLINE constexpr T fast_as() const {
285     if (LIBC_UNLIKELY(mantissa.is_zero()))
286       return FPBits<T>::zero(sign).get_val();
287 
288     // Assume that it is normalized, and output is also normal.
289     constexpr uint32_t PRECISION = FPBits<T>::FRACTION_LEN + 1;
290     using output_bits_t = typename FPBits<T>::StorageType;
291     constexpr output_bits_t IMPLICIT_MASK =
292         FPBits<T>::SIG_MASK - FPBits<T>::FRACTION_MASK;
293 
294     int exp_hi = exponent + static_cast<int>((Bits - 1) + FPBits<T>::EXP_BIAS);
295 
296     if (LIBC_UNLIKELY(exp_hi > 2 * FPBits<T>::EXP_BIAS)) {
297       // Results overflow.
298       T d_hi =
299           FPBits<T>::create_value(sign, 2 * FPBits<T>::EXP_BIAS, IMPLICIT_MASK)
300               .get_val();
301       // volatile prevents constant propagation that would result in infinity
302       // always being returned no matter the current rounding mode.
303       volatile T two = static_cast<T>(2.0);
304       T r = two * d_hi;
305 
306       // TODO: Whether rounding down the absolute value to max_normal should
307       // also raise FE_OVERFLOW and set ERANGE is debatable.
308       if (ShouldSignalExceptions && FPBits<T>(r).is_inf())
309         set_errno_if_required(ERANGE);
310 
311       return r;
312     }
313 
314     bool denorm = false;
315     uint32_t shift = Bits - PRECISION;
316     if (LIBC_UNLIKELY(exp_hi <= 0)) {
317       // Output is denormal.
318       denorm = true;
319       shift = (Bits - PRECISION) + static_cast<uint32_t>(1 - exp_hi);
320 
321       exp_hi = FPBits<T>::EXP_BIAS;
322     }
323 
324     int exp_lo = exp_hi - static_cast<int>(PRECISION) - 1;
325 
326     MantissaType m_hi =
327         shift >= MantissaType::BITS ? MantissaType(0) : mantissa >> shift;
328 
329     T d_hi = FPBits<T>::create_value(
330                  sign, static_cast<output_bits_t>(exp_hi),
331                  (static_cast<output_bits_t>(m_hi) & FPBits<T>::SIG_MASK) |
332                      IMPLICIT_MASK)
333                  .get_val();
334 
335     MantissaType round_mask =
336         shift - 1 >= MantissaType::BITS ? 0 : MantissaType(1) << (shift - 1);
337     MantissaType sticky_mask = round_mask - MantissaType(1);
338 
339     bool round_bit = !(mantissa & round_mask).is_zero();
340     bool sticky_bit = !(mantissa & sticky_mask).is_zero();
341     int round_and_sticky = int(round_bit) * 2 + int(sticky_bit);
342 
343     T d_lo;
344 
345     if (LIBC_UNLIKELY(exp_lo <= 0)) {
346       // d_lo is denormal, but the output is normal.
347       int scale_up_exponent = 1 - exp_lo;
348       T scale_up_factor =
349           FPBits<T>::create_value(Sign::POS,
350                                   static_cast<output_bits_t>(
351                                       FPBits<T>::EXP_BIAS + scale_up_exponent),
352                                   IMPLICIT_MASK)
353               .get_val();
354       T scale_down_factor =
355           FPBits<T>::create_value(Sign::POS,
356                                   static_cast<output_bits_t>(
357                                       FPBits<T>::EXP_BIAS - scale_up_exponent),
358                                   IMPLICIT_MASK)
359               .get_val();
360 
361       d_lo = FPBits<T>::create_value(
362                  sign, static_cast<output_bits_t>(exp_lo + scale_up_exponent),
363                  IMPLICIT_MASK)
364                  .get_val();
365 
366       return multiply_add(d_lo, T(round_and_sticky), d_hi * scale_up_factor) *
367              scale_down_factor;
368     }
369 
370     d_lo = FPBits<T>::create_value(sign, static_cast<output_bits_t>(exp_lo),
371                                    IMPLICIT_MASK)
372                .get_val();
373 
374     // Still correct without FMA instructions if `d_lo` is not underflow.
375     T r = multiply_add(d_lo, T(round_and_sticky), d_hi);
376 
377     if (LIBC_UNLIKELY(denorm)) {
378       // Exponent before rounding is in denormal range, simply clear the
379       // exponent field.
380       output_bits_t clear_exp = static_cast<output_bits_t>(
381           output_bits_t(exp_hi) << FPBits<T>::SIG_LEN);
382       output_bits_t r_bits = FPBits<T>(r).uintval() - clear_exp;
383 
384       if (!(r_bits & FPBits<T>::EXP_MASK)) {
385         // Output is denormal after rounding, clear the implicit bit for 80-bit
386         // long double.
387         r_bits -= IMPLICIT_MASK;
388 
389         // TODO: IEEE Std 754-2019 lets implementers choose whether to check for
390         // "tininess" before or after rounding for base-2 formats, as long as
391         // the same choice is made for all operations. Our choice to check after
392         // rounding might not be the same as the hardware's.
393         if (ShouldSignalExceptions && round_and_sticky) {
394           set_errno_if_required(ERANGE);
395           raise_except_if_required(FE_UNDERFLOW);
396         }
397       }
398 
399       return FPBits<T>(r_bits).get_val();
400     }
401 
402     return r;
403   }
404 
405   // Assume that it is already normalized.
406   // Output is rounded correctly with respect to the current rounding mode.
407   template <typename T, bool ShouldSignalExceptions,
408             typename = cpp::enable_if_t<cpp::is_floating_point_v<T> &&
409                                             (FPBits<T>::FRACTION_LEN < Bits),
410                                         void>>
asDyadicFloat411   LIBC_INLINE constexpr T as() const {
412     if constexpr (cpp::is_same_v<T, bfloat16>
413 #if defined(LIBC_TYPES_HAS_FLOAT16) && !defined(__LIBC_USE_FLOAT16_CONVERSION)
414                   || cpp::is_same_v<T, float16>
415 #endif
416     )
417       return generic_as<T, ShouldSignalExceptions>();
418     else
419       return fast_as<T, ShouldSignalExceptions>();
420   }
421 
422   template <typename T,
423             typename = cpp::enable_if_t<cpp::is_floating_point_v<T> &&
424                                             (FPBits<T>::FRACTION_LEN < Bits),
425                                         void>>
TDyadicFloat426   LIBC_INLINE explicit constexpr operator T() const {
427     return as<T, /*ShouldSignalExceptions=*/false>();
428   }
429 
as_mantissa_typeDyadicFloat430   LIBC_INLINE constexpr MantissaType as_mantissa_type() const {
431     if (mantissa.is_zero())
432       return 0;
433 
434     MantissaType new_mant = mantissa;
435     if (exponent > 0) {
436       new_mant <<= exponent;
437     } else {
438       // Cast the exponent to size_t before negating it, rather than after,
439       // to avoid undefined behavior negating INT_MIN as an integer (although
440       // exponents coming in to this function _shouldn't_ be that large). The
441       // result should always end up as a positive size_t.
442       size_t shift = -static_cast<size_t>(exponent);
443       new_mant >>= shift;
444     }
445 
446     if (sign.is_neg()) {
447       new_mant = (~new_mant) + 1;
448     }
449 
450     return new_mant;
451   }
452 
453   LIBC_INLINE constexpr MantissaType
454   as_mantissa_type_rounded(int *round_dir_out = nullptr) const {
455     int round_dir = 0;
456     MantissaType new_mant;
457     if (mantissa.is_zero()) {
458       new_mant = 0;
459     } else {
460       new_mant = mantissa;
461       if (exponent > 0) {
462         new_mant <<= exponent;
463       } else if (exponent < 0) {
464         // Cast the exponent to size_t before negating it, rather than after,
465         // to avoid undefined behavior negating INT_MIN as an integer (although
466         // exponents coming in to this function _shouldn't_ be that large). The
467         // result should always end up as a positive size_t.
468         size_t shift = -static_cast<size_t>(exponent);
469         if (shift >= Bits)
470           new_mant = 0;
471         else
472           new_mant >>= shift;
473         round_dir = rounding_direction(mantissa, shift, sign);
474         if (round_dir > 0)
475           ++new_mant;
476       }
477 
478       if (sign.is_neg()) {
479         new_mant = (~new_mant) + 1;
480       }
481     }
482 
483     if (round_dir_out)
484       *round_dir_out = round_dir;
485 
486     return new_mant;
487   }
488 
489   LIBC_INLINE constexpr DyadicFloat operator-() const {
490     return DyadicFloat(sign.negate(), exponent, mantissa);
491   }
492 };
493 
494 // Quick add - Add 2 dyadic floats with rounding toward 0 and then normalize the
495 // output:
496 //   - Align the exponents so that:
497 //     new a.exponent = new b.exponent = max(a.exponent, b.exponent)
498 //   - Add or subtract the mantissas depending on the signs.
499 //   - Normalize the result.
500 // The absolute errors compared to the mathematical sum is bounded by:
501 //   | quick_add(a, b) - (a + b) | < MSB(a + b) * 2^(-Bits + 2),
502 // i.e., errors are up to 2 ULPs.
503 // Assume inputs are normalized (by constructors or other functions) so that we
504 // don't need to normalize the inputs again in this function.  If the inputs are
505 // not normalized, the results might lose precision significantly.
506 template <size_t Bits>
quick_add(DyadicFloat<Bits> a,DyadicFloat<Bits> b)507 LIBC_INLINE constexpr DyadicFloat<Bits> quick_add(DyadicFloat<Bits> a,
508                                                   DyadicFloat<Bits> b) {
509   if (LIBC_UNLIKELY(a.mantissa.is_zero()))
510     return b;
511   if (LIBC_UNLIKELY(b.mantissa.is_zero()))
512     return a;
513 
514   // Align exponents
515   if (a.exponent > b.exponent)
516     b.shift_right(static_cast<unsigned>(a.exponent - b.exponent));
517   else if (b.exponent > a.exponent)
518     a.shift_right(static_cast<unsigned>(b.exponent - a.exponent));
519 
520   DyadicFloat<Bits> result;
521 
522   if (a.sign == b.sign) {
523     // Addition
524     result.sign = a.sign;
525     result.exponent = a.exponent;
526     result.mantissa = a.mantissa;
527     if (result.mantissa.add_overflow(b.mantissa)) {
528       // Mantissa addition overflow.
529       result.shift_right(1);
530       result.mantissa.val[DyadicFloat<Bits>::MantissaType::WORD_COUNT - 1] |=
531           (uint64_t(1) << 63);
532     }
533     // Result is already normalized.
534     return result;
535   }
536 
537   // Subtraction
538   if (a.mantissa >= b.mantissa) {
539     result.sign = a.sign;
540     result.exponent = a.exponent;
541     result.mantissa = a.mantissa - b.mantissa;
542   } else {
543     result.sign = b.sign;
544     result.exponent = b.exponent;
545     result.mantissa = b.mantissa - a.mantissa;
546   }
547 
548   return result.normalize();
549 }
550 
551 template <size_t Bits>
quick_sub(DyadicFloat<Bits> a,DyadicFloat<Bits> b)552 LIBC_INLINE constexpr DyadicFloat<Bits> quick_sub(DyadicFloat<Bits> a,
553                                                   DyadicFloat<Bits> b) {
554   return quick_add(a, -b);
555 }
556 
557 // Quick Mul - Slightly less accurate but efficient multiplication of 2 dyadic
558 // floats with rounding toward 0 and then normalize the output:
559 //   result.exponent = a.exponent + b.exponent + Bits,
560 //   result.mantissa = quick_mul_hi(a.mantissa + b.mantissa)
561 //                   ~ (full product a.mantissa * b.mantissa) >> Bits.
562 // The errors compared to the mathematical product is bounded by:
563 //   2 * errors of quick_mul_hi = 2 * (UInt<Bits>::WORD_COUNT - 1) in ULPs.
564 // Assume inputs are normalized (by constructors or other functions) so that we
565 // don't need to normalize the inputs again in this function.  If the inputs are
566 // not normalized, the results might lose precision significantly.
567 template <size_t Bits>
quick_mul(const DyadicFloat<Bits> & a,const DyadicFloat<Bits> & b)568 LIBC_INLINE constexpr DyadicFloat<Bits> quick_mul(const DyadicFloat<Bits> &a,
569                                                   const DyadicFloat<Bits> &b) {
570   DyadicFloat<Bits> result;
571   result.sign = (a.sign != b.sign) ? Sign::NEG : Sign::POS;
572   result.exponent = a.exponent + b.exponent + static_cast<int>(Bits);
573 
574   if (!(a.mantissa.is_zero() || b.mantissa.is_zero())) {
575     result.mantissa = a.mantissa.quick_mul_hi(b.mantissa);
576     // Check the leading bit directly, should be faster than using clz in
577     // normalize().
578     if (result.mantissa.val[DyadicFloat<Bits>::MantissaType::WORD_COUNT - 1] >>
579             63 ==
580         0)
581       result.shift_left(1);
582   } else {
583     result.mantissa = (typename DyadicFloat<Bits>::MantissaType)(0);
584   }
585   return result;
586 }
587 
588 // Correctly rounded multiplication of 2 dyadic floats, assuming the
589 // exponent remains within range.
590 template <size_t Bits>
591 LIBC_INLINE constexpr DyadicFloat<Bits>
rounded_mul(const DyadicFloat<Bits> & a,const DyadicFloat<Bits> & b)592 rounded_mul(const DyadicFloat<Bits> &a, const DyadicFloat<Bits> &b) {
593   using DblMant = LIBC_NAMESPACE::UInt<(2 * Bits)>;
594   Sign result_sign = (a.sign != b.sign) ? Sign::NEG : Sign::POS;
595   int result_exponent = a.exponent + b.exponent + static_cast<int>(Bits);
596   auto product = DblMant(a.mantissa) * DblMant(b.mantissa);
597   // As in quick_mul(), renormalize by 1 bit manually rather than countl_zero
598   if (product.get_bit(2 * Bits - 1) == 0) {
599     product <<= 1;
600     result_exponent -= 1;
601   }
602 
603   return DyadicFloat<Bits>::round(result_sign, result_exponent, product, Bits);
604 }
605 
606 // Approximate reciprocal - given a nonzero a, make a good approximation to 1/a.
607 // The method is Newton-Raphson iteration, based on quick_mul.
608 template <size_t Bits, typename = cpp::enable_if_t<(Bits >= 32)>>
609 LIBC_INLINE constexpr DyadicFloat<Bits>
approx_reciprocal(const DyadicFloat<Bits> & a)610 approx_reciprocal(const DyadicFloat<Bits> &a) {
611   // Given an approximation x to 1/a, a better one is x' = x(2-ax).
612   //
613   // You can derive this by using the Newton-Raphson formula with the function
614   // f(x) = 1/x - a. But another way to see that it works is to say: suppose
615   // that ax = 1-e for some small error e. Then ax' = ax(2-ax) = (1-e)(1+e) =
616   // 1-e^2. So the error in x' is the square of the error in x, i.e. the number
617   // of correct bits in x' is double the number in x.
618 
619   // An initial approximation to the reciprocal
620   DyadicFloat<Bits> x(Sign::POS, -32 - a.exponent - int(Bits),
621                       uint64_t(0xFFFFFFFFFFFFFFFF) /
622                           static_cast<uint64_t>(a.mantissa >> (Bits - 32)));
623 
624   // The constant 2, which we'll need in every iteration
625   DyadicFloat<Bits> two(Sign::POS, 1, 1);
626 
627   // We expect at least 31 correct bits from our 32-bit starting approximation
628   size_t ok_bits = 31;
629 
630   // The number of good bits doubles in each iteration, except that rounding
631   // errors introduce a little extra each time. Subtract a bit from our
632   // accuracy assessment to account for that.
633   while (ok_bits < Bits) {
634     x = quick_mul(x, quick_sub(two, quick_mul(a, x)));
635     ok_bits = 2 * ok_bits - 1;
636   }
637 
638   return x;
639 }
640 
641 // Correctly rounded division of 2 dyadic floats, assuming the
642 // exponent remains within range.
643 template <size_t Bits>
644 LIBC_INLINE constexpr DyadicFloat<Bits>
rounded_div(const DyadicFloat<Bits> & af,const DyadicFloat<Bits> & bf)645 rounded_div(const DyadicFloat<Bits> &af, const DyadicFloat<Bits> &bf) {
646   using DblMant = LIBC_NAMESPACE::UInt<(Bits * 2 + 64)>;
647 
648   // Make an approximation to the quotient as a * (1/b). Both the
649   // multiplication and the reciprocal are a bit sloppy, which doesn't
650   // matter, because we're going to correct for that below.
651   auto qf = fputil::quick_mul(af, fputil::approx_reciprocal(bf));
652 
653   // Switch to BigInt and stop using quick_add and quick_mul: now
654   // we're working in exact integers so as to get the true remainder.
655   DblMant a = af.mantissa, b = bf.mantissa, q = qf.mantissa;
656   q <<= 2; // leave room for a round bit, even if exponent decreases
657   a <<= af.exponent - bf.exponent - qf.exponent + 2;
658   DblMant qb = q * b;
659   if (qb < a) {
660     DblMant too_small = a - b;
661     while (qb <= too_small) {
662       qb += b;
663       ++q;
664     }
665   } else {
666     while (qb > a) {
667       qb -= b;
668       --q;
669     }
670   }
671 
672   DyadicFloat<(Bits * 2)> qbig(qf.sign, qf.exponent - 2, q);
673   return DyadicFloat<Bits>::round(qbig.sign, qbig.exponent + Bits,
674                                   qbig.mantissa, Bits);
675 }
676 
677 // Simple polynomial approximation.
678 template <size_t Bits>
679 LIBC_INLINE constexpr DyadicFloat<Bits>
multiply_add(const DyadicFloat<Bits> & a,const DyadicFloat<Bits> & b,const DyadicFloat<Bits> & c)680 multiply_add(const DyadicFloat<Bits> &a, const DyadicFloat<Bits> &b,
681              const DyadicFloat<Bits> &c) {
682   return quick_add(c, quick_mul(a, b));
683 }
684 
685 // Simple exponentiation implementation for printf. Only handles positive
686 // exponents, since division isn't implemented.
687 template <size_t Bits>
pow_n(const DyadicFloat<Bits> & a,uint32_t power)688 LIBC_INLINE constexpr DyadicFloat<Bits> pow_n(const DyadicFloat<Bits> &a,
689                                               uint32_t power) {
690   DyadicFloat<Bits> result = 1.0;
691   DyadicFloat<Bits> cur_power = a;
692 
693   while (power > 0) {
694     if ((power % 2) > 0) {
695       result = quick_mul(result, cur_power);
696     }
697     power = power >> 1;
698     cur_power = quick_mul(cur_power, cur_power);
699   }
700   return result;
701 }
702 
703 template <size_t Bits>
mul_pow_2(const DyadicFloat<Bits> & a,int32_t pow_2)704 LIBC_INLINE constexpr DyadicFloat<Bits> mul_pow_2(const DyadicFloat<Bits> &a,
705                                                   int32_t pow_2) {
706   DyadicFloat<Bits> result = a;
707   result.exponent += pow_2;
708   return result;
709 }
710 
711 } // namespace fputil
712 } // namespace LIBC_NAMESPACE_DECL
713 
714 #endif // LLVM_LIBC_SRC___SUPPORT_FPUTIL_DYADIC_FLOAT_H
715