xref: /freebsd/contrib/llvm-project/llvm/lib/Transforms/Scalar/Float2Int.cpp (revision e6bfd18d21b225af6a0ed67ceeaf1293b7b9eba5)
1 //===- Float2Int.cpp - Demote floating point ops to work on integers ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the Float2Int pass, which aims to demote floating
10 // point operations to work on integers, where that is losslessly possible.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "llvm/Transforms/Scalar/Float2Int.h"
15 #include "llvm/ADT/APInt.h"
16 #include "llvm/ADT/APSInt.h"
17 #include "llvm/ADT/SmallVector.h"
18 #include "llvm/Analysis/GlobalsModRef.h"
19 #include "llvm/IR/Constants.h"
20 #include "llvm/IR/Dominators.h"
21 #include "llvm/IR/IRBuilder.h"
22 #include "llvm/IR/Module.h"
23 #include "llvm/InitializePasses.h"
24 #include "llvm/Pass.h"
25 #include "llvm/Support/CommandLine.h"
26 #include "llvm/Support/Debug.h"
27 #include "llvm/Support/raw_ostream.h"
28 #include "llvm/Transforms/Scalar.h"
29 #include <deque>
30 
31 #define DEBUG_TYPE "float2int"
32 
33 using namespace llvm;
34 
35 // The algorithm is simple. Start at instructions that convert from the
36 // float to the int domain: fptoui, fptosi and fcmp. Walk up the def-use
37 // graph, using an equivalence datastructure to unify graphs that interfere.
38 //
// Mappable instructions are those with an integer corollary that, given
40 // integer domain inputs, produce an integer output; fadd, for example.
41 //
42 // If a non-mappable instruction is seen, this entire def-use graph is marked
43 // as non-transformable. If we see an instruction that converts from the
44 // integer domain to FP domain (uitofp,sitofp), we terminate our walk.
45 
46 /// The largest integer type worth dealing with.
47 static cl::opt<unsigned>
48 MaxIntegerBW("float2int-max-integer-bw", cl::init(64), cl::Hidden,
49              cl::desc("Max integer bitwidth to consider in float2int"
50                       "(default=64)"));
51 
52 namespace {
53   struct Float2IntLegacyPass : public FunctionPass {
54     static char ID; // Pass identification, replacement for typeid
55     Float2IntLegacyPass() : FunctionPass(ID) {
56       initializeFloat2IntLegacyPassPass(*PassRegistry::getPassRegistry());
57     }
58 
59     bool runOnFunction(Function &F) override {
60       if (skipFunction(F))
61         return false;
62 
63       const DominatorTree &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
64       return Impl.runImpl(F, DT);
65     }
66 
67     void getAnalysisUsage(AnalysisUsage &AU) const override {
68       AU.setPreservesCFG();
69       AU.addRequired<DominatorTreeWrapperPass>();
70       AU.addPreserved<GlobalsAAWrapperPass>();
71     }
72 
73   private:
74     Float2IntPass Impl;
75   };
76 }
77 
78 char Float2IntLegacyPass::ID = 0;
79 INITIALIZE_PASS(Float2IntLegacyPass, "float2int", "Float to int", false, false)
80 
81 // Given a FCmp predicate, return a matching ICmp predicate if one
82 // exists, otherwise return BAD_ICMP_PREDICATE.
83 static CmpInst::Predicate mapFCmpPred(CmpInst::Predicate P) {
84   switch (P) {
85   case CmpInst::FCMP_OEQ:
86   case CmpInst::FCMP_UEQ:
87     return CmpInst::ICMP_EQ;
88   case CmpInst::FCMP_OGT:
89   case CmpInst::FCMP_UGT:
90     return CmpInst::ICMP_SGT;
91   case CmpInst::FCMP_OGE:
92   case CmpInst::FCMP_UGE:
93     return CmpInst::ICMP_SGE;
94   case CmpInst::FCMP_OLT:
95   case CmpInst::FCMP_ULT:
96     return CmpInst::ICMP_SLT;
97   case CmpInst::FCMP_OLE:
98   case CmpInst::FCMP_ULE:
99     return CmpInst::ICMP_SLE;
100   case CmpInst::FCMP_ONE:
101   case CmpInst::FCMP_UNE:
102     return CmpInst::ICMP_NE;
103   default:
104     return CmpInst::BAD_ICMP_PREDICATE;
105   }
106 }
107 
108 // Given a floating point binary operator, return the matching
109 // integer version.
110 static Instruction::BinaryOps mapBinOpcode(unsigned Opcode) {
111   switch (Opcode) {
112   default: llvm_unreachable("Unhandled opcode!");
113   case Instruction::FAdd: return Instruction::Add;
114   case Instruction::FSub: return Instruction::Sub;
115   case Instruction::FMul: return Instruction::Mul;
116   }
117 }
118 
119 // Find the roots - instructions that convert from the FP domain to
120 // integer domain.
121 void Float2IntPass::findRoots(Function &F, const DominatorTree &DT) {
122   for (BasicBlock &BB : F) {
123     // Unreachable code can take on strange forms that we are not prepared to
124     // handle. For example, an instruction may have itself as an operand.
125     if (!DT.isReachableFromEntry(&BB))
126       continue;
127 
128     for (Instruction &I : BB) {
129       if (isa<VectorType>(I.getType()))
130         continue;
131       switch (I.getOpcode()) {
132       default: break;
133       case Instruction::FPToUI:
134       case Instruction::FPToSI:
135         Roots.insert(&I);
136         break;
137       case Instruction::FCmp:
138         if (mapFCmpPred(cast<CmpInst>(&I)->getPredicate()) !=
139             CmpInst::BAD_ICMP_PREDICATE)
140           Roots.insert(&I);
141         break;
142       }
143     }
144   }
145 }
146 
147 // Helper - mark I as having been traversed, having range R.
148 void Float2IntPass::seen(Instruction *I, ConstantRange R) {
149   LLVM_DEBUG(dbgs() << "F2I: " << *I << ":" << R << "\n");
150   auto IT = SeenInsts.find(I);
151   if (IT != SeenInsts.end())
152     IT->second = std::move(R);
153   else
154     SeenInsts.insert(std::make_pair(I, std::move(R)));
155 }
156 
157 // Helper - get a range representing a poison value.
158 ConstantRange Float2IntPass::badRange() {
159   return ConstantRange::getFull(MaxIntegerBW + 1);
160 }
161 ConstantRange Float2IntPass::unknownRange() {
162   return ConstantRange::getEmpty(MaxIntegerBW + 1);
163 }
164 ConstantRange Float2IntPass::validateRange(ConstantRange R) {
165   if (R.getBitWidth() > MaxIntegerBW + 1)
166     return badRange();
167   return R;
168 }
169 
// The most obvious way to structure the search is a depth-first, eager
// search from each root. However, that requires direct recursion and so
// can only handle small instruction sequences. Instead, we split the search
// up into two phases:
//   - walkBackwards:  A breadth-first walk of the use-def graph starting from
//                     the roots. Populate "SeenInsts" with interesting
//                     instructions and poison values if they're obvious and
//                     cheap to compute. Calculate the equivalence set structure
//                     while we're here too.
//   - walkForwards:  Iterate over SeenInsts in reverse order, so we visit
//                     defs before their uses. Calculate the real range info.
181 
182 // Breadth-first walk of the use-def graph; determine the set of nodes
183 // we care about and eagerly determine if some of them are poisonous.
184 void Float2IntPass::walkBackwards() {
185   std::deque<Instruction*> Worklist(Roots.begin(), Roots.end());
186   while (!Worklist.empty()) {
187     Instruction *I = Worklist.back();
188     Worklist.pop_back();
189 
190     if (SeenInsts.find(I) != SeenInsts.end())
191       // Seen already.
192       continue;
193 
194     switch (I->getOpcode()) {
195       // FIXME: Handle select and phi nodes.
196     default:
197       // Path terminated uncleanly.
198       seen(I, badRange());
199       break;
200 
201     case Instruction::UIToFP:
202     case Instruction::SIToFP: {
203       // Path terminated cleanly - use the type of the integer input to seed
204       // the analysis.
205       unsigned BW = I->getOperand(0)->getType()->getPrimitiveSizeInBits();
206       auto Input = ConstantRange::getFull(BW);
207       auto CastOp = (Instruction::CastOps)I->getOpcode();
208       seen(I, validateRange(Input.castOp(CastOp, MaxIntegerBW+1)));
209       continue;
210     }
211 
212     case Instruction::FNeg:
213     case Instruction::FAdd:
214     case Instruction::FSub:
215     case Instruction::FMul:
216     case Instruction::FPToUI:
217     case Instruction::FPToSI:
218     case Instruction::FCmp:
219       seen(I, unknownRange());
220       break;
221     }
222 
223     for (Value *O : I->operands()) {
224       if (Instruction *OI = dyn_cast<Instruction>(O)) {
225         // Unify def-use chains if they interfere.
226         ECs.unionSets(I, OI);
227         if (SeenInsts.find(I)->second != badRange())
228           Worklist.push_back(OI);
229       } else if (!isa<ConstantFP>(O)) {
230         // Not an instruction or ConstantFP? we can't do anything.
231         seen(I, badRange());
232       }
233     }
234   }
235 }
236 
237 // Calculate result range from operand ranges.
238 // Return None if the range cannot be calculated yet.
239 Optional<ConstantRange> Float2IntPass::calcRange(Instruction *I) {
240   SmallVector<ConstantRange, 4> OpRanges;
241   for (Value *O : I->operands()) {
242     if (Instruction *OI = dyn_cast<Instruction>(O)) {
243       auto OpIt = SeenInsts.find(OI);
244       assert(OpIt != SeenInsts.end() && "def not seen before use!");
245       if (OpIt->second == unknownRange())
246         return None; // Wait until operand range has been calculated.
247       OpRanges.push_back(OpIt->second);
248     } else if (ConstantFP *CF = dyn_cast<ConstantFP>(O)) {
249       // Work out if the floating point number can be losslessly represented
250       // as an integer.
251       // APFloat::convertToInteger(&Exact) purports to do what we want, but
252       // the exactness can be too precise. For example, negative zero can
253       // never be exactly converted to an integer.
254       //
255       // Instead, we ask APFloat to round itself to an integral value - this
256       // preserves sign-of-zero - then compare the result with the original.
257       //
258       const APFloat &F = CF->getValueAPF();
259 
260       // First, weed out obviously incorrect values. Non-finite numbers
261       // can't be represented and neither can negative zero, unless
262       // we're in fast math mode.
263       if (!F.isFinite() ||
264           (F.isZero() && F.isNegative() && isa<FPMathOperator>(I) &&
265            !I->hasNoSignedZeros()))
266         return badRange();
267 
268       APFloat NewF = F;
269       auto Res = NewF.roundToIntegral(APFloat::rmNearestTiesToEven);
270       if (Res != APFloat::opOK || NewF != F)
271         return badRange();
272 
273       // OK, it's representable. Now get it.
274       APSInt Int(MaxIntegerBW+1, false);
275       bool Exact;
276       CF->getValueAPF().convertToInteger(Int,
277                                          APFloat::rmNearestTiesToEven,
278                                          &Exact);
279       OpRanges.push_back(ConstantRange(Int));
280     } else {
281       llvm_unreachable("Should have already marked this as badRange!");
282     }
283   }
284 
285   switch (I->getOpcode()) {
286   // FIXME: Handle select and phi nodes.
287   default:
288   case Instruction::UIToFP:
289   case Instruction::SIToFP:
290     llvm_unreachable("Should have been handled in walkForwards!");
291 
292   case Instruction::FNeg: {
293     assert(OpRanges.size() == 1 && "FNeg is a unary operator!");
294     unsigned Size = OpRanges[0].getBitWidth();
295     auto Zero = ConstantRange(APInt::getZero(Size));
296     return Zero.sub(OpRanges[0]);
297   }
298 
299   case Instruction::FAdd:
300   case Instruction::FSub:
301   case Instruction::FMul: {
302     assert(OpRanges.size() == 2 && "its a binary operator!");
303     auto BinOp = (Instruction::BinaryOps) I->getOpcode();
304     return OpRanges[0].binaryOp(BinOp, OpRanges[1]);
305   }
306 
307   //
308   // Root-only instructions - we'll only see these if they're the
309   //                          first node in a walk.
310   //
311   case Instruction::FPToUI:
312   case Instruction::FPToSI: {
313     assert(OpRanges.size() == 1 && "FPTo[US]I is a unary operator!");
314     // Note: We're ignoring the casts output size here as that's what the
315     // caller expects.
316     auto CastOp = (Instruction::CastOps)I->getOpcode();
317     return OpRanges[0].castOp(CastOp, MaxIntegerBW+1);
318   }
319 
320   case Instruction::FCmp:
321     assert(OpRanges.size() == 2 && "FCmp is a binary operator!");
322     return OpRanges[0].unionWith(OpRanges[1]);
323   }
324 }
325 
326 // Walk forwards down the list of seen instructions, so we visit defs before
327 // uses.
328 void Float2IntPass::walkForwards() {
329   std::deque<Instruction *> Worklist;
330   for (const auto &Pair : SeenInsts)
331     if (Pair.second == unknownRange())
332       Worklist.push_back(Pair.first);
333 
334   while (!Worklist.empty()) {
335     Instruction *I = Worklist.back();
336     Worklist.pop_back();
337 
338     if (Optional<ConstantRange> Range = calcRange(I))
339       seen(I, *Range);
340     else
341       Worklist.push_front(I); // Reprocess later.
342   }
343 }
344 
345 // If there is a valid transform to be done, do it.
346 bool Float2IntPass::validateAndTransform() {
347   bool MadeChange = false;
348 
349   // Iterate over every disjoint partition of the def-use graph.
350   for (auto It = ECs.begin(), E = ECs.end(); It != E; ++It) {
351     ConstantRange R(MaxIntegerBW + 1, false);
352     bool Fail = false;
353     Type *ConvertedToTy = nullptr;
354 
355     // For every member of the partition, union all the ranges together.
356     for (auto MI = ECs.member_begin(It), ME = ECs.member_end();
357          MI != ME; ++MI) {
358       Instruction *I = *MI;
359       auto SeenI = SeenInsts.find(I);
360       if (SeenI == SeenInsts.end())
361         continue;
362 
363       R = R.unionWith(SeenI->second);
364       // We need to ensure I has no users that have not been seen.
365       // If it does, transformation would be illegal.
366       //
367       // Don't count the roots, as they terminate the graphs.
368       if (!Roots.contains(I)) {
369         // Set the type of the conversion while we're here.
370         if (!ConvertedToTy)
371           ConvertedToTy = I->getType();
372         for (User *U : I->users()) {
373           Instruction *UI = dyn_cast<Instruction>(U);
374           if (!UI || SeenInsts.find(UI) == SeenInsts.end()) {
375             LLVM_DEBUG(dbgs() << "F2I: Failing because of " << *U << "\n");
376             Fail = true;
377             break;
378           }
379         }
380       }
381       if (Fail)
382         break;
383     }
384 
385     // If the set was empty, or we failed, or the range is poisonous,
386     // bail out.
387     if (ECs.member_begin(It) == ECs.member_end() || Fail ||
388         R.isFullSet() || R.isSignWrappedSet())
389       continue;
390     assert(ConvertedToTy && "Must have set the convertedtoty by this point!");
391 
392     // The number of bits required is the maximum of the upper and
393     // lower limits, plus one so it can be signed.
394     unsigned MinBW = std::max(R.getLower().getMinSignedBits(),
395                               R.getUpper().getMinSignedBits()) + 1;
396     LLVM_DEBUG(dbgs() << "F2I: MinBitwidth=" << MinBW << ", R: " << R << "\n");
397 
398     // If we've run off the realms of the exactly representable integers,
399     // the floating point result will differ from an integer approximation.
400 
401     // Do we need more bits than are in the mantissa of the type we converted
402     // to? semanticsPrecision returns the number of mantissa bits plus one
403     // for the sign bit.
404     unsigned MaxRepresentableBits
405       = APFloat::semanticsPrecision(ConvertedToTy->getFltSemantics()) - 1;
406     if (MinBW > MaxRepresentableBits) {
407       LLVM_DEBUG(dbgs() << "F2I: Value not guaranteed to be representable!\n");
408       continue;
409     }
410     if (MinBW > 64) {
411       LLVM_DEBUG(
412           dbgs() << "F2I: Value requires more than 64 bits to represent!\n");
413       continue;
414     }
415 
416     // OK, R is known to be representable. Now pick a type for it.
417     // FIXME: Pick the smallest legal type that will fit.
418     Type *Ty = (MinBW > 32) ? Type::getInt64Ty(*Ctx) : Type::getInt32Ty(*Ctx);
419 
420     for (auto MI = ECs.member_begin(It), ME = ECs.member_end();
421          MI != ME; ++MI)
422       convert(*MI, Ty);
423     MadeChange = true;
424   }
425 
426   return MadeChange;
427 }
428 
429 Value *Float2IntPass::convert(Instruction *I, Type *ToTy) {
430   if (ConvertedInsts.find(I) != ConvertedInsts.end())
431     // Already converted this instruction.
432     return ConvertedInsts[I];
433 
434   SmallVector<Value*,4> NewOperands;
435   for (Value *V : I->operands()) {
436     // Don't recurse if we're an instruction that terminates the path.
437     if (I->getOpcode() == Instruction::UIToFP ||
438         I->getOpcode() == Instruction::SIToFP) {
439       NewOperands.push_back(V);
440     } else if (Instruction *VI = dyn_cast<Instruction>(V)) {
441       NewOperands.push_back(convert(VI, ToTy));
442     } else if (ConstantFP *CF = dyn_cast<ConstantFP>(V)) {
443       APSInt Val(ToTy->getPrimitiveSizeInBits(), /*isUnsigned=*/false);
444       bool Exact;
445       CF->getValueAPF().convertToInteger(Val,
446                                          APFloat::rmNearestTiesToEven,
447                                          &Exact);
448       NewOperands.push_back(ConstantInt::get(ToTy, Val));
449     } else {
450       llvm_unreachable("Unhandled operand type?");
451     }
452   }
453 
454   // Now create a new instruction.
455   IRBuilder<> IRB(I);
456   Value *NewV = nullptr;
457   switch (I->getOpcode()) {
458   default: llvm_unreachable("Unhandled instruction!");
459 
460   case Instruction::FPToUI:
461     NewV = IRB.CreateZExtOrTrunc(NewOperands[0], I->getType());
462     break;
463 
464   case Instruction::FPToSI:
465     NewV = IRB.CreateSExtOrTrunc(NewOperands[0], I->getType());
466     break;
467 
468   case Instruction::FCmp: {
469     CmpInst::Predicate P = mapFCmpPred(cast<CmpInst>(I)->getPredicate());
470     assert(P != CmpInst::BAD_ICMP_PREDICATE && "Unhandled predicate!");
471     NewV = IRB.CreateICmp(P, NewOperands[0], NewOperands[1], I->getName());
472     break;
473   }
474 
475   case Instruction::UIToFP:
476     NewV = IRB.CreateZExtOrTrunc(NewOperands[0], ToTy);
477     break;
478 
479   case Instruction::SIToFP:
480     NewV = IRB.CreateSExtOrTrunc(NewOperands[0], ToTy);
481     break;
482 
483   case Instruction::FNeg:
484     NewV = IRB.CreateNeg(NewOperands[0], I->getName());
485     break;
486 
487   case Instruction::FAdd:
488   case Instruction::FSub:
489   case Instruction::FMul:
490     NewV = IRB.CreateBinOp(mapBinOpcode(I->getOpcode()),
491                            NewOperands[0], NewOperands[1],
492                            I->getName());
493     break;
494   }
495 
496   // If we're a root instruction, RAUW.
497   if (Roots.count(I))
498     I->replaceAllUsesWith(NewV);
499 
500   ConvertedInsts[I] = NewV;
501   return NewV;
502 }
503 
504 // Perform dead code elimination on the instructions we just modified.
505 void Float2IntPass::cleanup() {
506   for (auto &I : reverse(ConvertedInsts))
507     I.first->eraseFromParent();
508 }
509 
510 bool Float2IntPass::runImpl(Function &F, const DominatorTree &DT) {
511   LLVM_DEBUG(dbgs() << "F2I: Looking at function " << F.getName() << "\n");
512   // Clear out all state.
513   ECs = EquivalenceClasses<Instruction*>();
514   SeenInsts.clear();
515   ConvertedInsts.clear();
516   Roots.clear();
517 
518   Ctx = &F.getParent()->getContext();
519 
520   findRoots(F, DT);
521 
522   walkBackwards();
523   walkForwards();
524 
525   bool Modified = validateAndTransform();
526   if (Modified)
527     cleanup();
528   return Modified;
529 }
530 
531 namespace llvm {
532 FunctionPass *createFloat2IntPass() { return new Float2IntLegacyPass(); }
533 
534 PreservedAnalyses Float2IntPass::run(Function &F, FunctionAnalysisManager &AM) {
535   const DominatorTree &DT = AM.getResult<DominatorTreeAnalysis>(F);
536   if (!runImpl(F, DT))
537     return PreservedAnalyses::all();
538 
539   PreservedAnalyses PA;
540   PA.preserveSet<CFGAnalyses>();
541   return PA;
542 }
543 } // End namespace llvm
544