xref: /freebsd/contrib/llvm-project/llvm/lib/Support/APFixedPoint.cpp (revision 2b8331622f0b212cf3bb4fc4914a501e5321d506)
1 //===- APFixedPoint.cpp - Fixed point constant handling ---------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 /// \file
10 /// Defines the implementation for the fixed point number interface.
11 //
12 //===----------------------------------------------------------------------===//
13 
#include "llvm/ADT/APFixedPoint.h"
#include "llvm/ADT/APFloat.h"

#include <cmath>
16 
17 namespace llvm {
18 
19 APFixedPoint APFixedPoint::convert(const FixedPointSemantics &DstSema,
20                                    bool *Overflow) const {
21   APSInt NewVal = Val;
22   unsigned DstWidth = DstSema.getWidth();
23   unsigned DstScale = DstSema.getScale();
24   bool Upscaling = DstScale > getScale();
25   if (Overflow)
26     *Overflow = false;
27 
28   if (Upscaling) {
29     NewVal = NewVal.extend(NewVal.getBitWidth() + DstScale - getScale());
30     NewVal <<= (DstScale - getScale());
31   } else {
32     NewVal >>= (getScale() - DstScale);
33   }
34 
35   auto Mask = APInt::getBitsSetFrom(
36       NewVal.getBitWidth(),
37       std::min(DstScale + DstSema.getIntegralBits(), NewVal.getBitWidth()));
38   APInt Masked(NewVal & Mask);
39 
40   // Change in the bits above the sign
41   if (!(Masked == Mask || Masked == 0)) {
42     // Found overflow in the bits above the sign
43     if (DstSema.isSaturated())
44       NewVal = NewVal.isNegative() ? Mask : ~Mask;
45     else if (Overflow)
46       *Overflow = true;
47   }
48 
49   // If the dst semantics are unsigned, but our value is signed and negative, we
50   // clamp to zero.
51   if (!DstSema.isSigned() && NewVal.isSigned() && NewVal.isNegative()) {
52     // Found negative overflow for unsigned result
53     if (DstSema.isSaturated())
54       NewVal = 0;
55     else if (Overflow)
56       *Overflow = true;
57   }
58 
59   NewVal = NewVal.extOrTrunc(DstWidth);
60   NewVal.setIsSigned(DstSema.isSigned());
61   return APFixedPoint(NewVal, DstSema);
62 }
63 
64 int APFixedPoint::compare(const APFixedPoint &Other) const {
65   APSInt ThisVal = getValue();
66   APSInt OtherVal = Other.getValue();
67   bool ThisSigned = Val.isSigned();
68   bool OtherSigned = OtherVal.isSigned();
69   unsigned OtherScale = Other.getScale();
70   unsigned OtherWidth = OtherVal.getBitWidth();
71 
72   unsigned CommonWidth = std::max(Val.getBitWidth(), OtherWidth);
73 
74   // Prevent overflow in the event the widths are the same but the scales differ
75   CommonWidth += getScale() >= OtherScale ? getScale() - OtherScale
76                                           : OtherScale - getScale();
77 
78   ThisVal = ThisVal.extOrTrunc(CommonWidth);
79   OtherVal = OtherVal.extOrTrunc(CommonWidth);
80 
81   unsigned CommonScale = std::max(getScale(), OtherScale);
82   ThisVal = ThisVal.shl(CommonScale - getScale());
83   OtherVal = OtherVal.shl(CommonScale - OtherScale);
84 
85   if (ThisSigned && OtherSigned) {
86     if (ThisVal.sgt(OtherVal))
87       return 1;
88     else if (ThisVal.slt(OtherVal))
89       return -1;
90   } else if (!ThisSigned && !OtherSigned) {
91     if (ThisVal.ugt(OtherVal))
92       return 1;
93     else if (ThisVal.ult(OtherVal))
94       return -1;
95   } else if (ThisSigned && !OtherSigned) {
96     if (ThisVal.isSignBitSet())
97       return -1;
98     else if (ThisVal.ugt(OtherVal))
99       return 1;
100     else if (ThisVal.ult(OtherVal))
101       return -1;
102   } else {
103     // !ThisSigned && OtherSigned
104     if (OtherVal.isSignBitSet())
105       return 1;
106     else if (ThisVal.ugt(OtherVal))
107       return 1;
108     else if (ThisVal.ult(OtherVal))
109       return -1;
110   }
111 
112   return 0;
113 }
114 
115 APFixedPoint APFixedPoint::getMax(const FixedPointSemantics &Sema) {
116   bool IsUnsigned = !Sema.isSigned();
117   auto Val = APSInt::getMaxValue(Sema.getWidth(), IsUnsigned);
118   if (IsUnsigned && Sema.hasUnsignedPadding())
119     Val = Val.lshr(1);
120   return APFixedPoint(Val, Sema);
121 }
122 
123 APFixedPoint APFixedPoint::getMin(const FixedPointSemantics &Sema) {
124   auto Val = APSInt::getMinValue(Sema.getWidth(), !Sema.isSigned());
125   return APFixedPoint(Val, Sema);
126 }
127 
128 bool FixedPointSemantics::fitsInFloatSemantics(
129     const fltSemantics &FloatSema) const {
130   // A fixed point semantic fits in a floating point semantic if the maximum
131   // and minimum values as integers of the fixed point semantic can fit in the
132   // floating point semantic.
133 
134   // If these values do not fit, then a floating point rescaling of the true
135   // maximum/minimum value will not fit either, so the floating point semantic
136   // cannot be used to perform such a rescaling.
137 
138   APSInt MaxInt = APFixedPoint::getMax(*this).getValue();
139   APFloat F(FloatSema);
140   APFloat::opStatus Status = F.convertFromAPInt(MaxInt, MaxInt.isSigned(),
141                                                 APFloat::rmNearestTiesToAway);
142   if ((Status & APFloat::opOverflow) || !isSigned())
143     return !(Status & APFloat::opOverflow);
144 
145   APSInt MinInt = APFixedPoint::getMin(*this).getValue();
146   Status = F.convertFromAPInt(MinInt, MinInt.isSigned(),
147                               APFloat::rmNearestTiesToAway);
148   return !(Status & APFloat::opOverflow);
149 }
150 
151 FixedPointSemantics FixedPointSemantics::getCommonSemantics(
152     const FixedPointSemantics &Other) const {
153   unsigned CommonScale = std::max(getScale(), Other.getScale());
154   unsigned CommonWidth =
155       std::max(getIntegralBits(), Other.getIntegralBits()) + CommonScale;
156 
157   bool ResultIsSigned = isSigned() || Other.isSigned();
158   bool ResultIsSaturated = isSaturated() || Other.isSaturated();
159   bool ResultHasUnsignedPadding = false;
160   if (!ResultIsSigned) {
161     // Both are unsigned.
162     ResultHasUnsignedPadding = hasUnsignedPadding() &&
163                                Other.hasUnsignedPadding() && !ResultIsSaturated;
164   }
165 
166   // If the result is signed, add an extra bit for the sign. Otherwise, if it is
167   // unsigned and has unsigned padding, we only need to add the extra padding
168   // bit back if we are not saturating.
169   if (ResultIsSigned || ResultHasUnsignedPadding)
170     CommonWidth++;
171 
172   return FixedPointSemantics(CommonWidth, CommonScale, ResultIsSigned,
173                              ResultIsSaturated, ResultHasUnsignedPadding);
174 }
175 
176 APFixedPoint APFixedPoint::add(const APFixedPoint &Other,
177                                bool *Overflow) const {
178   auto CommonFXSema = Sema.getCommonSemantics(Other.getSemantics());
179   APFixedPoint ConvertedThis = convert(CommonFXSema);
180   APFixedPoint ConvertedOther = Other.convert(CommonFXSema);
181   APSInt ThisVal = ConvertedThis.getValue();
182   APSInt OtherVal = ConvertedOther.getValue();
183   bool Overflowed = false;
184 
185   APSInt Result;
186   if (CommonFXSema.isSaturated()) {
187     Result = CommonFXSema.isSigned() ? ThisVal.sadd_sat(OtherVal)
188                                      : ThisVal.uadd_sat(OtherVal);
189   } else {
190     Result = ThisVal.isSigned() ? ThisVal.sadd_ov(OtherVal, Overflowed)
191                                 : ThisVal.uadd_ov(OtherVal, Overflowed);
192   }
193 
194   if (Overflow)
195     *Overflow = Overflowed;
196 
197   return APFixedPoint(Result, CommonFXSema);
198 }
199 
200 APFixedPoint APFixedPoint::sub(const APFixedPoint &Other,
201                                bool *Overflow) const {
202   auto CommonFXSema = Sema.getCommonSemantics(Other.getSemantics());
203   APFixedPoint ConvertedThis = convert(CommonFXSema);
204   APFixedPoint ConvertedOther = Other.convert(CommonFXSema);
205   APSInt ThisVal = ConvertedThis.getValue();
206   APSInt OtherVal = ConvertedOther.getValue();
207   bool Overflowed = false;
208 
209   APSInt Result;
210   if (CommonFXSema.isSaturated()) {
211     Result = CommonFXSema.isSigned() ? ThisVal.ssub_sat(OtherVal)
212                                      : ThisVal.usub_sat(OtherVal);
213   } else {
214     Result = ThisVal.isSigned() ? ThisVal.ssub_ov(OtherVal, Overflowed)
215                                 : ThisVal.usub_ov(OtherVal, Overflowed);
216   }
217 
218   if (Overflow)
219     *Overflow = Overflowed;
220 
221   return APFixedPoint(Result, CommonFXSema);
222 }
223 
224 APFixedPoint APFixedPoint::mul(const APFixedPoint &Other,
225                                bool *Overflow) const {
226   auto CommonFXSema = Sema.getCommonSemantics(Other.getSemantics());
227   APFixedPoint ConvertedThis = convert(CommonFXSema);
228   APFixedPoint ConvertedOther = Other.convert(CommonFXSema);
229   APSInt ThisVal = ConvertedThis.getValue();
230   APSInt OtherVal = ConvertedOther.getValue();
231   bool Overflowed = false;
232 
233   // Widen the LHS and RHS so we can perform a full multiplication.
234   unsigned Wide = CommonFXSema.getWidth() * 2;
235   if (CommonFXSema.isSigned()) {
236     ThisVal = ThisVal.sext(Wide);
237     OtherVal = OtherVal.sext(Wide);
238   } else {
239     ThisVal = ThisVal.zext(Wide);
240     OtherVal = OtherVal.zext(Wide);
241   }
242 
243   // Perform the full multiplication and downscale to get the same scale.
244   //
245   // Note that the right shifts here perform an implicit downwards rounding.
246   // This rounding could discard bits that would technically place the result
247   // outside the representable range. We interpret the spec as allowing us to
248   // perform the rounding step first, avoiding the overflow case that would
249   // arise.
250   APSInt Result;
251   if (CommonFXSema.isSigned())
252     Result = ThisVal.smul_ov(OtherVal, Overflowed)
253                     .ashr(CommonFXSema.getScale());
254   else
255     Result = ThisVal.umul_ov(OtherVal, Overflowed)
256                     .lshr(CommonFXSema.getScale());
257   assert(!Overflowed && "Full multiplication cannot overflow!");
258   Result.setIsSigned(CommonFXSema.isSigned());
259 
260   // If our result lies outside of the representative range of the common
261   // semantic, we either have overflow or saturation.
262   APSInt Max = APFixedPoint::getMax(CommonFXSema).getValue()
263                                                  .extOrTrunc(Wide);
264   APSInt Min = APFixedPoint::getMin(CommonFXSema).getValue()
265                                                  .extOrTrunc(Wide);
266   if (CommonFXSema.isSaturated()) {
267     if (Result < Min)
268       Result = Min;
269     else if (Result > Max)
270       Result = Max;
271   } else
272     Overflowed = Result < Min || Result > Max;
273 
274   if (Overflow)
275     *Overflow = Overflowed;
276 
277   return APFixedPoint(Result.sextOrTrunc(CommonFXSema.getWidth()),
278                       CommonFXSema);
279 }
280 
281 APFixedPoint APFixedPoint::div(const APFixedPoint &Other,
282                                bool *Overflow) const {
283   auto CommonFXSema = Sema.getCommonSemantics(Other.getSemantics());
284   APFixedPoint ConvertedThis = convert(CommonFXSema);
285   APFixedPoint ConvertedOther = Other.convert(CommonFXSema);
286   APSInt ThisVal = ConvertedThis.getValue();
287   APSInt OtherVal = ConvertedOther.getValue();
288   bool Overflowed = false;
289 
290   // Widen the LHS and RHS so we can perform a full division.
291   unsigned Wide = CommonFXSema.getWidth() * 2;
292   if (CommonFXSema.isSigned()) {
293     ThisVal = ThisVal.sext(Wide);
294     OtherVal = OtherVal.sext(Wide);
295   } else {
296     ThisVal = ThisVal.zext(Wide);
297     OtherVal = OtherVal.zext(Wide);
298   }
299 
300   // Upscale to compensate for the loss of precision from division, and
301   // perform the full division.
302   ThisVal = ThisVal.shl(CommonFXSema.getScale());
303   APSInt Result;
304   if (CommonFXSema.isSigned()) {
305     APInt Rem;
306     APInt::sdivrem(ThisVal, OtherVal, Result, Rem);
307     // If the quotient is negative and the remainder is nonzero, round
308     // towards negative infinity by subtracting epsilon from the result.
309     if (ThisVal.isNegative() != OtherVal.isNegative() && !Rem.isZero())
310       Result = Result - 1;
311   } else
312     Result = ThisVal.udiv(OtherVal);
313   Result.setIsSigned(CommonFXSema.isSigned());
314 
315   // If our result lies outside of the representative range of the common
316   // semantic, we either have overflow or saturation.
317   APSInt Max = APFixedPoint::getMax(CommonFXSema).getValue()
318                                                  .extOrTrunc(Wide);
319   APSInt Min = APFixedPoint::getMin(CommonFXSema).getValue()
320                                                  .extOrTrunc(Wide);
321   if (CommonFXSema.isSaturated()) {
322     if (Result < Min)
323       Result = Min;
324     else if (Result > Max)
325       Result = Max;
326   } else
327     Overflowed = Result < Min || Result > Max;
328 
329   if (Overflow)
330     *Overflow = Overflowed;
331 
332   return APFixedPoint(Result.sextOrTrunc(CommonFXSema.getWidth()),
333                       CommonFXSema);
334 }
335 
336 APFixedPoint APFixedPoint::shl(unsigned Amt, bool *Overflow) const {
337   APSInt ThisVal = Val;
338   bool Overflowed = false;
339 
340   // Widen the LHS.
341   unsigned Wide = Sema.getWidth() * 2;
342   if (Sema.isSigned())
343     ThisVal = ThisVal.sext(Wide);
344   else
345     ThisVal = ThisVal.zext(Wide);
346 
347   // Clamp the shift amount at the original width, and perform the shift.
348   Amt = std::min(Amt, ThisVal.getBitWidth());
349   APSInt Result = ThisVal << Amt;
350   Result.setIsSigned(Sema.isSigned());
351 
352   // If our result lies outside of the representative range of the
353   // semantic, we either have overflow or saturation.
354   APSInt Max = APFixedPoint::getMax(Sema).getValue().extOrTrunc(Wide);
355   APSInt Min = APFixedPoint::getMin(Sema).getValue().extOrTrunc(Wide);
356   if (Sema.isSaturated()) {
357     if (Result < Min)
358       Result = Min;
359     else if (Result > Max)
360       Result = Max;
361   } else
362     Overflowed = Result < Min || Result > Max;
363 
364   if (Overflow)
365     *Overflow = Overflowed;
366 
367   return APFixedPoint(Result.sextOrTrunc(Sema.getWidth()), Sema);
368 }
369 
370 void APFixedPoint::toString(SmallVectorImpl<char> &Str) const {
371   APSInt Val = getValue();
372   unsigned Scale = getScale();
373 
374   if (Val.isSigned() && Val.isNegative() && Val != -Val) {
375     Val = -Val;
376     Str.push_back('-');
377   }
378 
379   APSInt IntPart = Val >> Scale;
380 
381   // Add 4 digits to hold the value after multiplying 10 (the radix)
382   unsigned Width = Val.getBitWidth() + 4;
383   APInt FractPart = Val.zextOrTrunc(Scale).zext(Width);
384   APInt FractPartMask = APInt::getAllOnes(Scale).zext(Width);
385   APInt RadixInt = APInt(Width, 10);
386 
387   IntPart.toString(Str, /*Radix=*/10);
388   Str.push_back('.');
389   do {
390     (FractPart * RadixInt)
391         .lshr(Scale)
392         .toString(Str, /*Radix=*/10, Val.isSigned());
393     FractPart = (FractPart * RadixInt) & FractPartMask;
394   } while (FractPart != 0);
395 }
396 
397 APFixedPoint APFixedPoint::negate(bool *Overflow) const {
398   if (!isSaturated()) {
399     if (Overflow)
400       *Overflow =
401           (!isSigned() && Val != 0) || (isSigned() && Val.isMinSignedValue());
402     return APFixedPoint(-Val, Sema);
403   }
404 
405   // We never overflow for saturation
406   if (Overflow)
407     *Overflow = false;
408 
409   if (isSigned())
410     return Val.isMinSignedValue() ? getMax(Sema) : APFixedPoint(-Val, Sema);
411   else
412     return APFixedPoint(Sema);
413 }
414 
415 APSInt APFixedPoint::convertToInt(unsigned DstWidth, bool DstSign,
416                                   bool *Overflow) const {
417   APSInt Result = getIntPart();
418   unsigned SrcWidth = getWidth();
419 
420   APSInt DstMin = APSInt::getMinValue(DstWidth, !DstSign);
421   APSInt DstMax = APSInt::getMaxValue(DstWidth, !DstSign);
422 
423   if (SrcWidth < DstWidth) {
424     Result = Result.extend(DstWidth);
425   } else if (SrcWidth > DstWidth) {
426     DstMin = DstMin.extend(SrcWidth);
427     DstMax = DstMax.extend(SrcWidth);
428   }
429 
430   if (Overflow) {
431     if (Result.isSigned() && !DstSign) {
432       *Overflow = Result.isNegative() || Result.ugt(DstMax);
433     } else if (Result.isUnsigned() && DstSign) {
434       *Overflow = Result.ugt(DstMax);
435     } else {
436       *Overflow = Result < DstMin || Result > DstMax;
437     }
438   }
439 
440   Result.setIsSigned(DstSign);
441   return Result.extOrTrunc(DstWidth);
442 }
443 
444 const fltSemantics *APFixedPoint::promoteFloatSemantics(const fltSemantics *S) {
445   if (S == &APFloat::BFloat())
446     return &APFloat::IEEEdouble();
447   else if (S == &APFloat::IEEEhalf())
448     return &APFloat::IEEEsingle();
449   else if (S == &APFloat::IEEEsingle())
450     return &APFloat::IEEEdouble();
451   else if (S == &APFloat::IEEEdouble())
452     return &APFloat::IEEEquad();
453   llvm_unreachable("Could not promote float type!");
454 }
455 
456 APFloat APFixedPoint::convertToFloat(const fltSemantics &FloatSema) const {
457   // For some operations, rounding mode has an effect on the result, while
458   // other operations are lossless and should never result in rounding.
459   // To signify which these operations are, we define two rounding modes here.
460   APFloat::roundingMode RM = APFloat::rmNearestTiesToEven;
461   APFloat::roundingMode LosslessRM = APFloat::rmTowardZero;
462 
463   // Make sure that we are operating in a type that works with this fixed-point
464   // semantic.
465   const fltSemantics *OpSema = &FloatSema;
466   while (!Sema.fitsInFloatSemantics(*OpSema))
467     OpSema = promoteFloatSemantics(OpSema);
468 
469   // Convert the fixed point value bits as an integer. If the floating point
470   // value does not have the required precision, we will round according to the
471   // given mode.
472   APFloat Flt(*OpSema);
473   APFloat::opStatus S = Flt.convertFromAPInt(Val, Sema.isSigned(), RM);
474 
475   // If we cared about checking for precision loss, we could look at this
476   // status.
477   (void)S;
478 
479   // Scale down the integer value in the float to match the correct scaling
480   // factor.
481   APFloat ScaleFactor(std::pow(2, -(int)Sema.getScale()));
482   bool Ignored;
483   ScaleFactor.convert(*OpSema, LosslessRM, &Ignored);
484   Flt.multiply(ScaleFactor, LosslessRM);
485 
486   if (OpSema != &FloatSema)
487     Flt.convert(FloatSema, RM, &Ignored);
488 
489   return Flt;
490 }
491 
492 APFixedPoint APFixedPoint::getFromIntValue(const APSInt &Value,
493                                            const FixedPointSemantics &DstFXSema,
494                                            bool *Overflow) {
495   FixedPointSemantics IntFXSema = FixedPointSemantics::GetIntegerSemantics(
496       Value.getBitWidth(), Value.isSigned());
497   return APFixedPoint(Value, IntFXSema).convert(DstFXSema, Overflow);
498 }
499 
500 APFixedPoint
501 APFixedPoint::getFromFloatValue(const APFloat &Value,
502                                 const FixedPointSemantics &DstFXSema,
503                                 bool *Overflow) {
504   // For some operations, rounding mode has an effect on the result, while
505   // other operations are lossless and should never result in rounding.
506   // To signify which these operations are, we define two rounding modes here,
507   // even though they are the same mode.
508   APFloat::roundingMode RM = APFloat::rmTowardZero;
509   APFloat::roundingMode LosslessRM = APFloat::rmTowardZero;
510 
511   const fltSemantics &FloatSema = Value.getSemantics();
512 
513   if (Value.isNaN()) {
514     // Handle NaN immediately.
515     if (Overflow)
516       *Overflow = true;
517     return APFixedPoint(DstFXSema);
518   }
519 
520   // Make sure that we are operating in a type that works with this fixed-point
521   // semantic.
522   const fltSemantics *OpSema = &FloatSema;
523   while (!DstFXSema.fitsInFloatSemantics(*OpSema))
524     OpSema = promoteFloatSemantics(OpSema);
525 
526   APFloat Val = Value;
527 
528   bool Ignored;
529   if (&FloatSema != OpSema)
530     Val.convert(*OpSema, LosslessRM, &Ignored);
531 
532   // Scale up the float so that the 'fractional' part of the mantissa ends up in
533   // the integer range instead. Rounding mode is irrelevant here.
534   // It is fine if this overflows to infinity even for saturating types,
535   // since we will use floating point comparisons to check for saturation.
536   APFloat ScaleFactor(std::pow(2, DstFXSema.getScale()));
537   ScaleFactor.convert(*OpSema, LosslessRM, &Ignored);
538   Val.multiply(ScaleFactor, LosslessRM);
539 
540   // Convert to the integral representation of the value. This rounding mode
541   // is significant.
542   APSInt Res(DstFXSema.getWidth(), !DstFXSema.isSigned());
543   Val.convertToInteger(Res, RM, &Ignored);
544 
545   // Round the integral value and scale back. This makes the
546   // overflow calculations below work properly. If we do not round here,
547   // we risk checking for overflow with a value that is outside the
548   // representable range of the fixed-point semantic even though no overflow
549   // would occur had we rounded first.
550   ScaleFactor = APFloat(std::pow(2, -(int)DstFXSema.getScale()));
551   ScaleFactor.convert(*OpSema, LosslessRM, &Ignored);
552   Val.roundToIntegral(RM);
553   Val.multiply(ScaleFactor, LosslessRM);
554 
555   // Check for overflow/saturation by checking if the floating point value
556   // is outside the range representable by the fixed-point value.
557   APFloat FloatMax = getMax(DstFXSema).convertToFloat(*OpSema);
558   APFloat FloatMin = getMin(DstFXSema).convertToFloat(*OpSema);
559   bool Overflowed = false;
560   if (DstFXSema.isSaturated()) {
561     if (Val > FloatMax)
562       Res = getMax(DstFXSema).getValue();
563     else if (Val < FloatMin)
564       Res = getMin(DstFXSema).getValue();
565   } else
566     Overflowed = Val > FloatMax || Val < FloatMin;
567 
568   if (Overflow)
569     *Overflow = Overflowed;
570 
571   return APFixedPoint(Res, DstFXSema);
572 }
573 
574 } // namespace llvm
575