xref: /freebsd/contrib/llvm-project/llvm/lib/Transforms/Scalar/ConstantHoisting.cpp (revision 3ceba58a7509418b47b8fca2d2b6bbf088714e26)
1 //===- ConstantHoisting.cpp - Prepare code for expensive constants --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This pass identifies expensive constants to hoist and coalesces them to
// better prepare the code for SelectionDAG-based code generation. This works
// around the limitations of the basic-block-at-a-time approach.
12 //
// First it scans all instructions for integer constants and calculates their
// cost. If the constant can be folded into the instruction (the cost is
// TCC_Free) or the cost is just a simple operation (TCC_Basic), then we don't
// consider it expensive and leave it alone. This is the default behavior and
// the default implementation of getIntImmCostInst will always return TCC_Free.
//
// If the cost is more than TCC_Basic, then the integer constant can't be folded
// into the instruction and it might be beneficial to hoist the constant.
// Similar constants are coalesced to reduce register pressure and
// materialization code.
23 //
24 // When a constant is hoisted, it is also hidden behind a bitcast to force it to
25 // be live-out of the basic block. Otherwise the constant would be just
26 // duplicated and each basic block would have its own copy in the SelectionDAG.
27 // The SelectionDAG recognizes such constants as opaque and doesn't perform
28 // certain transformations on them, which would create a new expensive constant.
29 //
30 // This optimization is only applied to integer constants in instructions and
31 // simple (this means not nested) constant cast expressions. For example:
32 // %0 = load i64* inttoptr (i64 big_constant to i64*)
33 //===----------------------------------------------------------------------===//
34 
35 #include "llvm/Transforms/Scalar/ConstantHoisting.h"
36 #include "llvm/ADT/APInt.h"
37 #include "llvm/ADT/DenseMap.h"
38 #include "llvm/ADT/SmallPtrSet.h"
39 #include "llvm/ADT/SmallVector.h"
40 #include "llvm/ADT/Statistic.h"
41 #include "llvm/Analysis/BlockFrequencyInfo.h"
42 #include "llvm/Analysis/ProfileSummaryInfo.h"
43 #include "llvm/Analysis/TargetTransformInfo.h"
44 #include "llvm/IR/BasicBlock.h"
45 #include "llvm/IR/Constants.h"
46 #include "llvm/IR/DataLayout.h"
47 #include "llvm/IR/DebugInfoMetadata.h"
48 #include "llvm/IR/Dominators.h"
49 #include "llvm/IR/Function.h"
50 #include "llvm/IR/InstrTypes.h"
51 #include "llvm/IR/Instruction.h"
52 #include "llvm/IR/Instructions.h"
53 #include "llvm/IR/IntrinsicInst.h"
54 #include "llvm/IR/Operator.h"
55 #include "llvm/IR/Value.h"
56 #include "llvm/InitializePasses.h"
57 #include "llvm/Pass.h"
58 #include "llvm/Support/BlockFrequency.h"
59 #include "llvm/Support/Casting.h"
60 #include "llvm/Support/CommandLine.h"
61 #include "llvm/Support/Debug.h"
62 #include "llvm/Support/raw_ostream.h"
63 #include "llvm/Transforms/Scalar.h"
64 #include "llvm/Transforms/Utils/Local.h"
65 #include "llvm/Transforms/Utils/SizeOpts.h"
66 #include <algorithm>
67 #include <cassert>
68 #include <cstdint>
69 #include <iterator>
70 #include <tuple>
71 #include <utility>
72 
73 using namespace llvm;
74 using namespace consthoist;
75 
76 #define DEBUG_TYPE "consthoist"
77 
// Pass statistics; the quoted strings are the descriptions attached to each
// counter.
STATISTIC(NumConstantsHoisted, "Number of constants hoisted");
STATISTIC(NumConstantsRebased, "Number of constants rebased");

// Hidden option, on by default. When set, the legacy pass requires
// BlockFrequencyInfo and hands it to the implementation; when cleared, the
// implementation receives a null BFI instead.
static cl::opt<bool> ConstHoistWithBlockFrequency(
    "consthoist-with-block-frequency", cl::init(true), cl::Hidden,
    cl::desc("Enable the use of the block frequency analysis to reduce the "
             "chance to execute const materialization more frequently than "
             "without hoisting."));

// Hidden option, off by default. Gates the collection of constant GEP
// expressions as hoisting candidates.
static cl::opt<bool> ConstHoistGEP(
    "consthoist-gep", cl::init(false), cl::Hidden,
    cl::desc("Try hoisting constant gep expressions"));

// Hidden option, defaults to 0.
// NOTE(review): the code consuming this threshold is outside the visible part
// of the file; the description string below is the only grounding for its
// meaning.
static cl::opt<unsigned>
MinNumOfDependentToRebase("consthoist-min-num-to-rebase",
    cl::desc("Do not rebase if number of dependent constants of a Base is less "
             "than this number."),
    cl::init(0), cl::Hidden);
96 
97 namespace {
98 
99 /// The constant hoisting pass.
100 class ConstantHoistingLegacyPass : public FunctionPass {
101 public:
102   static char ID; // Pass identification, replacement for typeid
103 
104   ConstantHoistingLegacyPass() : FunctionPass(ID) {
105     initializeConstantHoistingLegacyPassPass(*PassRegistry::getPassRegistry());
106   }
107 
108   bool runOnFunction(Function &Fn) override;
109 
110   StringRef getPassName() const override { return "Constant Hoisting"; }
111 
112   void getAnalysisUsage(AnalysisUsage &AU) const override {
113     AU.setPreservesCFG();
114     if (ConstHoistWithBlockFrequency)
115       AU.addRequired<BlockFrequencyInfoWrapperPass>();
116     AU.addRequired<DominatorTreeWrapperPass>();
117     AU.addRequired<ProfileSummaryInfoWrapperPass>();
118     AU.addRequired<TargetTransformInfoWrapperPass>();
119   }
120 
121 private:
122   ConstantHoistingPass Impl;
123 };
124 
125 } // end anonymous namespace
126 
// Definition of the static pass identifier declared in the class; the
// constructor passes it to the FunctionPass base.
char ConstantHoistingLegacyPass::ID = 0;

// Register the legacy pass under the command line name "consthoist" with the
// description "Constant Hoisting", and list the analysis passes that must be
// initialized with it.
// NOTE(review): the two trailing `false` arguments are presumably the
// "CFG only" and "is analysis" flags -- confirm against the macro definition.
INITIALIZE_PASS_BEGIN(ConstantHoistingLegacyPass, "consthoist",
                      "Constant Hoisting", false, false)
INITIALIZE_PASS_DEPENDENCY(BlockFrequencyInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(ProfileSummaryInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
INITIALIZE_PASS_END(ConstantHoistingLegacyPass, "consthoist",
                    "Constant Hoisting", false, false)
137 
138 FunctionPass *llvm::createConstantHoistingPass() {
139   return new ConstantHoistingLegacyPass();
140 }
141 
142 /// Perform the constant hoisting optimization for the given function.
143 bool ConstantHoistingLegacyPass::runOnFunction(Function &Fn) {
144   if (skipFunction(Fn))
145     return false;
146 
147   LLVM_DEBUG(dbgs() << "********** Begin Constant Hoisting **********\n");
148   LLVM_DEBUG(dbgs() << "********** Function: " << Fn.getName() << '\n');
149 
150   bool MadeChange =
151       Impl.runImpl(Fn, getAnalysis<TargetTransformInfoWrapperPass>().getTTI(Fn),
152                    getAnalysis<DominatorTreeWrapperPass>().getDomTree(),
153                    ConstHoistWithBlockFrequency
154                        ? &getAnalysis<BlockFrequencyInfoWrapperPass>().getBFI()
155                        : nullptr,
156                    Fn.getEntryBlock(),
157                    &getAnalysis<ProfileSummaryInfoWrapperPass>().getPSI());
158 
159   LLVM_DEBUG(dbgs() << "********** End Constant Hoisting **********\n");
160 
161   return MadeChange;
162 }
163 
164 void ConstantHoistingPass::collectMatInsertPts(
165     const RebasedConstantListType &RebasedConstants,
166     SmallVectorImpl<BasicBlock::iterator> &MatInsertPts) const {
167   for (const RebasedConstantInfo &RCI : RebasedConstants)
168     for (const ConstantUser &U : RCI.Uses)
169       MatInsertPts.emplace_back(findMatInsertPt(U.Inst, U.OpndIdx));
170 }
171 
172 /// Find the constant materialization insertion point.
173 BasicBlock::iterator ConstantHoistingPass::findMatInsertPt(Instruction *Inst,
174                                                            unsigned Idx) const {
175   // If the operand is a cast instruction, then we have to materialize the
176   // constant before the cast instruction.
177   if (Idx != ~0U) {
178     Value *Opnd = Inst->getOperand(Idx);
179     if (auto CastInst = dyn_cast<Instruction>(Opnd))
180       if (CastInst->isCast())
181         return CastInst->getIterator();
182   }
183 
184   // The simple and common case. This also includes constant expressions.
185   if (!isa<PHINode>(Inst) && !Inst->isEHPad())
186     return Inst->getIterator();
187 
188   // We can't insert directly before a phi node or an eh pad. Insert before
189   // the terminator of the incoming or dominating block.
190   assert(Entry != Inst->getParent() && "PHI or landing pad in entry block!");
191   BasicBlock *InsertionBlock = nullptr;
192   if (Idx != ~0U && isa<PHINode>(Inst)) {
193     InsertionBlock = cast<PHINode>(Inst)->getIncomingBlock(Idx);
194     if (!InsertionBlock->isEHPad()) {
195       return InsertionBlock->getTerminator()->getIterator();
196     }
197   } else {
198     InsertionBlock = Inst->getParent();
199   }
200 
201   // This must be an EH pad. Iterate over immediate dominators until we find a
202   // non-EH pad. We need to skip over catchswitch blocks, which are both EH pads
203   // and terminators.
204   auto *IDom = DT->getNode(InsertionBlock)->getIDom();
205   while (IDom->getBlock()->isEHPad()) {
206     assert(Entry != IDom->getBlock() && "eh pad in entry block");
207     IDom = IDom->getIDom();
208   }
209 
210   return IDom->getBlock()->getTerminator()->getIterator();
211 }
212 
/// Given \p BBs as input, find another set of BBs which collectively
/// dominates \p BBs and have the minimal sum of frequencies. Return the BB
/// set found in \p BBs.
///
/// \param DT    dominator tree used to walk from each block towards \p Entry.
/// \param BFI   source of the per-block frequencies that are compared.
/// \param Entry root of the walk; asserted not to be a member of \p BBs.
/// \param BBs   in/out: the blocks to cover on entry, the chosen insertion
///              blocks on return.
static void findBestInsertionSet(DominatorTree &DT, BlockFrequencyInfo &BFI,
                                 BasicBlock *Entry,
                                 SetVector<BasicBlock *> &BBs) {
  assert(!BBs.count(Entry) && "Assume Entry is not in BBs");
  // Nodes on the current path to the root.
  SmallPtrSet<BasicBlock *, 8> Path;
  // Candidates includes any block 'BB' in set 'BBs' that is not strictly
  // dominated by any other blocks in set 'BBs', and all nodes in the path
  // in the dominator tree from Entry to 'BB'.
  SmallPtrSet<BasicBlock *, 16> Candidates;
  for (auto *BB : BBs) {
    // Ignore unreachable basic blocks.
    if (!DT.isReachableFromEntry(BB))
      continue;
    Path.clear();
    // Walk up the dominator tree until Entry or another BB in BBs
    // is reached. Insert the nodes on the way to the Path.
    BasicBlock *Node = BB;
    // The "Path" is a candidate path to be added into Candidates set.
    bool isCandidate = false;
    do {
      Path.insert(Node);
      // Reaching Entry or a node already known to be a candidate means the
      // whole path collected so far belongs to the candidate set.
      if (Node == Entry || Candidates.count(Node)) {
        isCandidate = true;
        break;
      }
      assert(DT.getNode(Node)->getIDom() &&
             "Entry doens't dominate current Node");
      Node = DT.getNode(Node)->getIDom()->getBlock();
    } while (!BBs.count(Node));

    // If isCandidate is false, Node is another Block in BBs dominating
    // current 'BB'. Drop the nodes on the Path.
    if (!isCandidate)
      continue;

    // Add nodes on the Path into Candidates.
    Candidates.insert(Path.begin(), Path.end());
  }

  // Sort the nodes in Candidates in top-down order and save the nodes
  // in Orders.
  // Orders doubles as the work list: Idx is the read cursor while candidate
  // dominator-tree children are appended at the end.
  unsigned Idx = 0;
  SmallVector<BasicBlock *, 16> Orders;
  Orders.push_back(Entry);
  while (Idx != Orders.size()) {
    BasicBlock *Node = Orders[Idx++];
    for (auto *ChildDomNode : DT.getNode(Node)->children()) {
      if (Candidates.count(ChildDomNode->getBlock()))
        Orders.push_back(ChildDomNode->getBlock());
    }
  }

  // Visit Orders in bottom-up order.
  using InsertPtsCostPair =
      std::pair<SetVector<BasicBlock *>, BlockFrequency>;

  // InsertPtsMap is a map from a BB to the best insertion points for the
  // subtree of BB (subtree not including the BB itself).
  DenseMap<BasicBlock *, InsertPtsCostPair> InsertPtsMap;
  // NOTE(review): the loop below keeps references to the entry for Node while
  // looking up (and possibly creating) the entry for its parent; the up-front
  // reserve is presumably what keeps those references valid -- confirm.
  InsertPtsMap.reserve(Orders.size() + 1);
  for (BasicBlock *Node : llvm::reverse(Orders)) {
    bool NodeInBBs = BBs.count(Node);
    auto &InsertPts = InsertPtsMap[Node].first;
    BlockFrequency &InsertPtsFreq = InsertPtsMap[Node].second;

    // Return the optimal insert points in BBs.
    if (Node == Entry) {
      BBs.clear();
      // Prefer Entry when the accumulated subtree frequency is higher, or on
      // a tie when hoisting replaces more than one insertion point.
      if (InsertPtsFreq > BFI.getBlockFreq(Node) ||
          (InsertPtsFreq == BFI.getBlockFreq(Node) && InsertPts.size() > 1))
        BBs.insert(Entry);
      else
        BBs.insert(InsertPts.begin(), InsertPts.end());
      break;
    }

    BasicBlock *Parent = DT.getNode(Node)->getIDom()->getBlock();
    // Initially, ParentInsertPts is empty and ParentPtsFreq is 0. Every child
    // will update its parent's ParentInsertPts and ParentPtsFreq.
    auto &ParentInsertPts = InsertPtsMap[Parent].first;
    BlockFrequency &ParentPtsFreq = InsertPtsMap[Parent].second;
    // Choose to insert in Node or in subtree of Node.
    // Don't hoist to EHPad because we may not find a proper place to insert
    // in EHPad.
    // If the total frequency of InsertPts is the same as the frequency of the
    // target Node, and InsertPts contains more than one nodes, choose hoisting
    // to reduce code size.
    if (NodeInBBs ||
        (!Node->isEHPad() &&
         (InsertPtsFreq > BFI.getBlockFreq(Node) ||
          (InsertPtsFreq == BFI.getBlockFreq(Node) && InsertPts.size() > 1)))) {
      // Node itself becomes the insertion point reported to its parent.
      ParentInsertPts.insert(Node);
      ParentPtsFreq += BFI.getBlockFreq(Node);
    } else {
      // Keep the insertion points found below Node and pass them upwards.
      ParentInsertPts.insert(InsertPts.begin(), InsertPts.end());
      ParentPtsFreq += InsertPtsFreq;
    }
  }
}
316 
317 /// Find an insertion point that dominates all uses.
318 SetVector<BasicBlock::iterator>
319 ConstantHoistingPass::findConstantInsertionPoint(
320     const ConstantInfo &ConstInfo,
321     const ArrayRef<BasicBlock::iterator> MatInsertPts) const {
322   assert(!ConstInfo.RebasedConstants.empty() && "Invalid constant info entry.");
323   // Collect all basic blocks.
324   SetVector<BasicBlock *> BBs;
325   SetVector<BasicBlock::iterator> InsertPts;
326 
327   for (BasicBlock::iterator MatInsertPt : MatInsertPts)
328     BBs.insert(MatInsertPt->getParent());
329 
330   if (BBs.count(Entry)) {
331     InsertPts.insert(Entry->begin());
332     return InsertPts;
333   }
334 
335   if (BFI) {
336     findBestInsertionSet(*DT, *BFI, Entry, BBs);
337     for (BasicBlock *BB : BBs)
338       InsertPts.insert(BB->getFirstInsertionPt());
339     return InsertPts;
340   }
341 
342   while (BBs.size() >= 2) {
343     BasicBlock *BB, *BB1, *BB2;
344     BB1 = BBs.pop_back_val();
345     BB2 = BBs.pop_back_val();
346     BB = DT->findNearestCommonDominator(BB1, BB2);
347     if (BB == Entry) {
348       InsertPts.insert(Entry->begin());
349       return InsertPts;
350     }
351     BBs.insert(BB);
352   }
353   assert((BBs.size() == 1) && "Expected only one element.");
354   Instruction &FirstInst = (*BBs.begin())->front();
355   InsertPts.insert(findMatInsertPt(&FirstInst));
356   return InsertPts;
357 }
358 
359 /// Record constant integer ConstInt for instruction Inst at operand
360 /// index Idx.
361 ///
362 /// The operand at index Idx is not necessarily the constant integer itself. It
363 /// could also be a cast instruction or a constant expression that uses the
364 /// constant integer.
365 void ConstantHoistingPass::collectConstantCandidates(
366     ConstCandMapType &ConstCandMap, Instruction *Inst, unsigned Idx,
367     ConstantInt *ConstInt) {
368   if (ConstInt->getType()->isVectorTy())
369     return;
370 
371   InstructionCost Cost;
372   // Ask the target about the cost of materializing the constant for the given
373   // instruction and operand index.
374   if (auto IntrInst = dyn_cast<IntrinsicInst>(Inst))
375     Cost = TTI->getIntImmCostIntrin(IntrInst->getIntrinsicID(), Idx,
376                                     ConstInt->getValue(), ConstInt->getType(),
377                                     TargetTransformInfo::TCK_SizeAndLatency);
378   else
379     Cost = TTI->getIntImmCostInst(
380         Inst->getOpcode(), Idx, ConstInt->getValue(), ConstInt->getType(),
381         TargetTransformInfo::TCK_SizeAndLatency, Inst);
382 
383   // Ignore cheap integer constants.
384   if (Cost > TargetTransformInfo::TCC_Basic) {
385     ConstCandMapType::iterator Itr;
386     bool Inserted;
387     ConstPtrUnionType Cand = ConstInt;
388     std::tie(Itr, Inserted) = ConstCandMap.insert(std::make_pair(Cand, 0));
389     if (Inserted) {
390       ConstIntCandVec.push_back(ConstantCandidate(ConstInt));
391       Itr->second = ConstIntCandVec.size() - 1;
392     }
393     ConstIntCandVec[Itr->second].addUser(Inst, Idx, *Cost.getValue());
394     LLVM_DEBUG(if (isa<ConstantInt>(Inst->getOperand(Idx))) dbgs()
395                    << "Collect constant " << *ConstInt << " from " << *Inst
396                    << " with cost " << Cost << '\n';
397                else dbgs() << "Collect constant " << *ConstInt
398                            << " indirectly from " << *Inst << " via "
399                            << *Inst->getOperand(Idx) << " with cost " << Cost
400                            << '\n';);
401   }
402 }
403 
404 /// Record constant GEP expression for instruction Inst at operand index Idx.
405 void ConstantHoistingPass::collectConstantCandidates(
406     ConstCandMapType &ConstCandMap, Instruction *Inst, unsigned Idx,
407     ConstantExpr *ConstExpr) {
408   // TODO: Handle vector GEPs
409   if (ConstExpr->getType()->isVectorTy())
410     return;
411 
412   GlobalVariable *BaseGV = dyn_cast<GlobalVariable>(ConstExpr->getOperand(0));
413   if (!BaseGV)
414     return;
415 
416   // Get offset from the base GV.
417   PointerType *GVPtrTy = cast<PointerType>(BaseGV->getType());
418   IntegerType *OffsetTy = DL->getIndexType(*Ctx, GVPtrTy->getAddressSpace());
419   APInt Offset(DL->getTypeSizeInBits(OffsetTy), /*val*/ 0, /*isSigned*/ true);
420   auto *GEPO = cast<GEPOperator>(ConstExpr);
421 
422   // TODO: If we have a mix of inbounds and non-inbounds GEPs, then basing a
423   // non-inbounds GEP on an inbounds GEP is potentially incorrect. Restrict to
424   // inbounds GEP for now -- alternatively, we could drop inbounds from the
425   // constant expression,
426   if (!GEPO->isInBounds())
427     return;
428 
429   if (!GEPO->accumulateConstantOffset(*DL, Offset))
430     return;
431 
432   if (!Offset.isIntN(32))
433     return;
434 
435   // A constant GEP expression that has a GlobalVariable as base pointer is
436   // usually lowered to a load from constant pool. Such operation is unlikely
437   // to be cheaper than compute it by <Base + Offset>, which can be lowered to
438   // an ADD instruction or folded into Load/Store instruction.
439   InstructionCost Cost =
440       TTI->getIntImmCostInst(Instruction::Add, 1, Offset, OffsetTy,
441                              TargetTransformInfo::TCK_SizeAndLatency, Inst);
442   ConstCandVecType &ExprCandVec = ConstGEPCandMap[BaseGV];
443   ConstCandMapType::iterator Itr;
444   bool Inserted;
445   ConstPtrUnionType Cand = ConstExpr;
446   std::tie(Itr, Inserted) = ConstCandMap.insert(std::make_pair(Cand, 0));
447   if (Inserted) {
448     ExprCandVec.push_back(ConstantCandidate(
449         ConstantInt::get(Type::getInt32Ty(*Ctx), Offset.getLimitedValue()),
450         ConstExpr));
451     Itr->second = ExprCandVec.size() - 1;
452   }
453   ExprCandVec[Itr->second].addUser(Inst, Idx, *Cost.getValue());
454 }
455 
456 /// Check the operand for instruction Inst at index Idx.
457 void ConstantHoistingPass::collectConstantCandidates(
458     ConstCandMapType &ConstCandMap, Instruction *Inst, unsigned Idx) {
459   Value *Opnd = Inst->getOperand(Idx);
460 
461   // Visit constant integers.
462   if (auto ConstInt = dyn_cast<ConstantInt>(Opnd)) {
463     collectConstantCandidates(ConstCandMap, Inst, Idx, ConstInt);
464     return;
465   }
466 
467   // Visit cast instructions that have constant integers.
468   if (auto CastInst = dyn_cast<Instruction>(Opnd)) {
469     // Only visit cast instructions, which have been skipped. All other
470     // instructions should have already been visited.
471     if (!CastInst->isCast())
472       return;
473 
474     if (auto *ConstInt = dyn_cast<ConstantInt>(CastInst->getOperand(0))) {
475       // Pretend the constant is directly used by the instruction and ignore
476       // the cast instruction.
477       collectConstantCandidates(ConstCandMap, Inst, Idx, ConstInt);
478       return;
479     }
480   }
481 
482   // Visit constant expressions that have constant integers.
483   if (auto ConstExpr = dyn_cast<ConstantExpr>(Opnd)) {
484     // Handle constant gep expressions.
485     if (ConstHoistGEP && isa<GEPOperator>(ConstExpr))
486       collectConstantCandidates(ConstCandMap, Inst, Idx, ConstExpr);
487 
488     // Only visit constant cast expressions.
489     if (!ConstExpr->isCast())
490       return;
491 
492     if (auto ConstInt = dyn_cast<ConstantInt>(ConstExpr->getOperand(0))) {
493       // Pretend the constant is directly used by the instruction and ignore
494       // the constant expression.
495       collectConstantCandidates(ConstCandMap, Inst, Idx, ConstInt);
496       return;
497     }
498   }
499 }
500 
501 /// Scan the instruction for expensive integer constants and record them
502 /// in the constant candidate vector.
503 void ConstantHoistingPass::collectConstantCandidates(
504     ConstCandMapType &ConstCandMap, Instruction *Inst) {
505   // Skip all cast instructions. They are visited indirectly later on.
506   if (Inst->isCast())
507     return;
508 
509   // Scan all operands.
510   for (unsigned Idx = 0, E = Inst->getNumOperands(); Idx != E; ++Idx) {
511     // The cost of materializing the constants (defined in
512     // `TargetTransformInfo::getIntImmCostInst`) for instructions which only
513     // take constant variables is lower than `TargetTransformInfo::TCC_Basic`.
514     // So it's safe for us to collect constant candidates from all
515     // IntrinsicInsts.
516     if (canReplaceOperandWithVariable(Inst, Idx)) {
517       collectConstantCandidates(ConstCandMap, Inst, Idx);
518     }
519   } // end of for all operands
520 }
521 
522 /// Collect all integer constants in the function that cannot be folded
523 /// into an instruction itself.
524 void ConstantHoistingPass::collectConstantCandidates(Function &Fn) {
525   ConstCandMapType ConstCandMap;
526   for (BasicBlock &BB : Fn) {
527     // Ignore unreachable basic blocks.
528     if (!DT->isReachableFromEntry(&BB))
529       continue;
530     for (Instruction &Inst : BB)
531       if (!TTI->preferToKeepConstantsAttached(Inst, Fn))
532         collectConstantCandidates(ConstCandMap, &Inst);
533   }
534 }
535 
536 // This helper function is necessary to deal with values that have different
537 // bit widths (APInt Operator- does not like that). If the value cannot be
538 // represented in uint64 we return an "empty" APInt. This is then interpreted
539 // as the value is not in range.
540 static std::optional<APInt> calculateOffsetDiff(const APInt &V1,
541                                                 const APInt &V2) {
542   std::optional<APInt> Res;
543   unsigned BW = V1.getBitWidth() > V2.getBitWidth() ?
544                 V1.getBitWidth() : V2.getBitWidth();
545   uint64_t LimVal1 = V1.getLimitedValue();
546   uint64_t LimVal2 = V2.getLimitedValue();
547 
548   if (LimVal1 == ~0ULL || LimVal2 == ~0ULL)
549     return Res;
550 
551   uint64_t Diff = LimVal1 - LimVal2;
552   return APInt(BW, Diff, true);
553 }
554 
555 // From a list of constants, one needs to picked as the base and the other
556 // constants will be transformed into an offset from that base constant. The
557 // question is which we can pick best? For example, consider these constants
558 // and their number of uses:
559 //
560 //  Constants| 2 | 4 | 12 | 42 |
561 //  NumUses  | 3 | 2 |  8 |  7 |
562 //
563 // Selecting constant 12 because it has the most uses will generate negative
564 // offsets for constants 2 and 4 (i.e. -10 and -8 respectively). If negative
565 // offsets lead to less optimal code generation, then there might be better
566 // solutions. Suppose immediates in the range of 0..35 are most optimally
567 // supported by the architecture, then selecting constant 2 is most optimal
568 // because this will generate offsets: 0, 2, 10, 40. Offsets 0, 2 and 10 are in
569 // range 0..35, and thus 3 + 2 + 8 = 13 uses are in range. Selecting 12 would
570 // have only 8 uses in range, so choosing 2 as a base is more optimal. Thus, in
571 // selecting the base constant the range of the offsets is a very important
572 // factor too that we take into account here. This algorithm calculates a total
573 // costs for selecting a constant as the base and substract the costs if
574 // immediates are out of range. It has quadratic complexity, so we call this
575 // function only when we're optimising for size and there are less than 100
576 // constants, we fall back to the straightforward algorithm otherwise
577 // which does not do all the offset calculations.
578 unsigned
579 ConstantHoistingPass::maximizeConstantsInRange(ConstCandVecType::iterator S,
580                                            ConstCandVecType::iterator E,
581                                            ConstCandVecType::iterator &MaxCostItr) {
582   unsigned NumUses = 0;
583 
584   if (!OptForSize || std::distance(S,E) > 100) {
585     for (auto ConstCand = S; ConstCand != E; ++ConstCand) {
586       NumUses += ConstCand->Uses.size();
587       if (ConstCand->CumulativeCost > MaxCostItr->CumulativeCost)
588         MaxCostItr = ConstCand;
589     }
590     return NumUses;
591   }
592 
593   LLVM_DEBUG(dbgs() << "== Maximize constants in range ==\n");
594   InstructionCost MaxCost = -1;
595   for (auto ConstCand = S; ConstCand != E; ++ConstCand) {
596     auto Value = ConstCand->ConstInt->getValue();
597     Type *Ty = ConstCand->ConstInt->getType();
598     InstructionCost Cost = 0;
599     NumUses += ConstCand->Uses.size();
600     LLVM_DEBUG(dbgs() << "= Constant: " << ConstCand->ConstInt->getValue()
601                       << "\n");
602 
603     for (auto User : ConstCand->Uses) {
604       unsigned Opcode = User.Inst->getOpcode();
605       unsigned OpndIdx = User.OpndIdx;
606       Cost += TTI->getIntImmCostInst(Opcode, OpndIdx, Value, Ty,
607                                      TargetTransformInfo::TCK_SizeAndLatency);
608       LLVM_DEBUG(dbgs() << "Cost: " << Cost << "\n");
609 
610       for (auto C2 = S; C2 != E; ++C2) {
611         std::optional<APInt> Diff = calculateOffsetDiff(
612             C2->ConstInt->getValue(), ConstCand->ConstInt->getValue());
613         if (Diff) {
614           const InstructionCost ImmCosts =
615               TTI->getIntImmCodeSizeCost(Opcode, OpndIdx, *Diff, Ty);
616           Cost -= ImmCosts;
617           LLVM_DEBUG(dbgs() << "Offset " << *Diff << " "
618                             << "has penalty: " << ImmCosts << "\n"
619                             << "Adjusted cost: " << Cost << "\n");
620         }
621       }
622     }
623     LLVM_DEBUG(dbgs() << "Cumulative cost: " << Cost << "\n");
624     if (Cost > MaxCost) {
625       MaxCost = Cost;
626       MaxCostItr = ConstCand;
627       LLVM_DEBUG(dbgs() << "New candidate: " << MaxCostItr->ConstInt->getValue()
628                         << "\n");
629     }
630   }
631   return NumUses;
632 }
633 
634 /// Find the base constant within the given range and rebase all other
635 /// constants with respect to the base constant.
636 void ConstantHoistingPass::findAndMakeBaseConstant(
637     ConstCandVecType::iterator S, ConstCandVecType::iterator E,
638     SmallVectorImpl<consthoist::ConstantInfo> &ConstInfoVec) {
639   auto MaxCostItr = S;
640   unsigned NumUses = maximizeConstantsInRange(S, E, MaxCostItr);
641 
642   // Don't hoist constants that have only one use.
643   if (NumUses <= 1)
644     return;
645 
646   ConstantInt *ConstInt = MaxCostItr->ConstInt;
647   ConstantExpr *ConstExpr = MaxCostItr->ConstExpr;
648   ConstantInfo ConstInfo;
649   ConstInfo.BaseInt = ConstInt;
650   ConstInfo.BaseExpr = ConstExpr;
651   Type *Ty = ConstInt->getType();
652 
653   // Rebase the constants with respect to the base constant.
654   for (auto ConstCand = S; ConstCand != E; ++ConstCand) {
655     APInt Diff = ConstCand->ConstInt->getValue() - ConstInt->getValue();
656     Constant *Offset = Diff == 0 ? nullptr : ConstantInt::get(Ty, Diff);
657     Type *ConstTy =
658         ConstCand->ConstExpr ? ConstCand->ConstExpr->getType() : nullptr;
659     ConstInfo.RebasedConstants.push_back(
660       RebasedConstantInfo(std::move(ConstCand->Uses), Offset, ConstTy));
661   }
662   ConstInfoVec.push_back(std::move(ConstInfo));
663 }
664 
665 /// Finds and combines constant candidates that can be easily
666 /// rematerialized with an add from a common base constant.
667 void ConstantHoistingPass::findBaseConstants(GlobalVariable *BaseGV) {
668   // If BaseGV is nullptr, find base among candidate constant integers;
669   // Otherwise find base among constant GEPs that share the same BaseGV.
670   ConstCandVecType &ConstCandVec = BaseGV ?
671       ConstGEPCandMap[BaseGV] : ConstIntCandVec;
672   ConstInfoVecType &ConstInfoVec = BaseGV ?
673       ConstGEPInfoMap[BaseGV] : ConstIntInfoVec;
674 
675   // Sort the constants by value and type. This invalidates the mapping!
676   llvm::stable_sort(ConstCandVec, [](const ConstantCandidate &LHS,
677                                      const ConstantCandidate &RHS) {
678     if (LHS.ConstInt->getType() != RHS.ConstInt->getType())
679       return LHS.ConstInt->getBitWidth() < RHS.ConstInt->getBitWidth();
680     return LHS.ConstInt->getValue().ult(RHS.ConstInt->getValue());
681   });
682 
683   // Simple linear scan through the sorted constant candidate vector for viable
684   // merge candidates.
685   auto MinValItr = ConstCandVec.begin();
686   for (auto CC = std::next(ConstCandVec.begin()), E = ConstCandVec.end();
687        CC != E; ++CC) {
688     if (MinValItr->ConstInt->getType() == CC->ConstInt->getType()) {
689       Type *MemUseValTy = nullptr;
690       for (auto &U : CC->Uses) {
691         auto *UI = U.Inst;
692         if (LoadInst *LI = dyn_cast<LoadInst>(UI)) {
693           MemUseValTy = LI->getType();
694           break;
695         } else if (StoreInst *SI = dyn_cast<StoreInst>(UI)) {
696           // Make sure the constant is used as pointer operand of the StoreInst.
697           if (SI->getPointerOperand() == SI->getOperand(U.OpndIdx)) {
698             MemUseValTy = SI->getValueOperand()->getType();
699             break;
700           }
701         }
702       }
703 
704       // Check if the constant is in range of an add with immediate.
705       APInt Diff = CC->ConstInt->getValue() - MinValItr->ConstInt->getValue();
706       if ((Diff.getBitWidth() <= 64) &&
707           TTI->isLegalAddImmediate(Diff.getSExtValue()) &&
708           // Check if Diff can be used as offset in addressing mode of the user
709           // memory instruction.
710           (!MemUseValTy || TTI->isLegalAddressingMode(MemUseValTy,
711            /*BaseGV*/nullptr, /*BaseOffset*/Diff.getSExtValue(),
712            /*HasBaseReg*/true, /*Scale*/0)))
713         continue;
714     }
715     // We either have now a different constant type or the constant is not in
716     // range of an add with immediate anymore.
717     findAndMakeBaseConstant(MinValItr, CC, ConstInfoVec);
718     // Start a new base constant search.
719     MinValItr = CC;
720   }
721   // Finalize the last base constant search.
722   findAndMakeBaseConstant(MinValItr, ConstCandVec.end(), ConstInfoVec);
723 }
724 
725 /// Replaces operand \p Idx of instruction \p Inst with the materialized
726 /// instruction \p Mat. If \p Inst is a PHI node and an earlier operand
727 /// (index below \p Idx) has the same incoming basic block, that earlier
728 /// incoming value is reused instead of \p Mat.
729 /// \return The update always succeeds; returns true if \p Mat was used for
730 ///         the update and false if an earlier PHI incoming value was reused.
731 static bool updateOperand(Instruction *Inst, unsigned Idx, Instruction *Mat) {
732   if (auto PHI = dyn_cast<PHINode>(Inst)) {
733     // Check if any previous operand of the PHI node has the same incoming basic
734     // block. This is a very odd case that happens when the incoming basic block
735     // has a switch statement. In this case reuse the value of the previous
736     // operand(s); otherwise we would fail verification due to different values.
737     // The values are actually the same, but the variable names are different
738     // and the verifier does not accept that.
739     BasicBlock *IncomingBB = PHI->getIncomingBlock(Idx);
740     for (unsigned i = 0; i < Idx; ++i) { // Only operands before Idx are checked.
741       if (PHI->getIncomingBlock(i) == IncomingBB) {
742         Value *IncomingVal = PHI->getIncomingValue(i);
743         Inst->setOperand(Idx, IncomingVal);
744         return false; // Mat was not used for this operand.
745       }
746     }
747   }
748 
749   Inst->setOperand(Idx, Mat);
750   return true;
751 }
752 
753 /// Emit materialization code for the one rebased constant use described by
754 /// \p Adj, derived from \p Base, and update that user's operand.
755 void ConstantHoistingPass::emitBaseConstants(Instruction *Base,
756                                              UserAdjustment *Adj) {
757   Instruction *Mat = Base; // Without an offset, Base itself is the value used.
758 
759   // The same offset can be dereferenced as different types in a nested struct.
760   if (!Adj->Offset && Adj->Ty && Adj->Ty != Base->getType())
761     Adj->Offset = ConstantInt::get(Type::getInt32Ty(*Ctx), 0); // Forces the GEP + bitcast path below.
762 
763   if (Adj->Offset) {
764     if (Adj->Ty) {
765       // Constant being rebased is a ConstantExpr: i8-typed GEP from Base.
766       Mat = GetElementPtrInst::Create(Type::getInt8Ty(*Ctx), Base, Adj->Offset,
767                                       "mat_gep", Adj->MatInsertPt);
768       // Hide it behind a bitcast to the type recorded in Adj->Ty.
769       Mat = new BitCastInst(Mat, Adj->Ty, "mat_bitcast",
770                             Adj->MatInsertPt->getIterator());
771     } else
772       // Constant being rebased is a ConstantInt: materialize Base + Offset.
773       Mat =
774           BinaryOperator::Create(Instruction::Add, Base, Adj->Offset,
775                                  "const_mat", Adj->MatInsertPt->getIterator());
776 
777     LLVM_DEBUG(dbgs() << "Materialize constant (" << *Base->getOperand(0)
778                       << " + " << *Adj->Offset << ") in BB "
779                       << Mat->getParent()->getName() << '\n'
780                       << *Mat << '\n');
781     Mat->setDebugLoc(Adj->User.Inst->getDebugLoc());
782   }
783   Value *Opnd = Adj->User.Inst->getOperand(Adj->User.OpndIdx);
784 
785   // Case 1: the user's operand is a constant integer.
786   if (isa<ConstantInt>(Opnd)) {
787     LLVM_DEBUG(dbgs() << "Update: " << *Adj->User.Inst << '\n');
788     if (!updateOperand(Adj->User.Inst, Adj->User.OpndIdx, Mat) && Adj->Offset)
789       Mat->eraseFromParent(); // Mat was created above but ended up unused.
790     LLVM_DEBUG(dbgs() << "To    : " << *Adj->User.Inst << '\n');
791     return;
792   }
793 
794   // Case 2: the user's operand is a cast instruction.
795   if (auto CastInst = dyn_cast<Instruction>(Opnd)) {
796     assert(CastInst->isCast() && "Expected an cast instruction!");
797     // Check if we have already visited this cast instruction before to avoid
798     // unnecessary cloning; ClonedCastMap caches one clone per original cast.
799     Instruction *&ClonedCastInst = ClonedCastMap[CastInst];
800     if (!ClonedCastInst) {
801       ClonedCastInst = CastInst->clone();
802       ClonedCastInst->setOperand(0, Mat);
803       ClonedCastInst->insertAfter(CastInst);
804       // Use the same debug location as the original cast instruction.
805       ClonedCastInst->setDebugLoc(CastInst->getDebugLoc());
806       LLVM_DEBUG(dbgs() << "Clone instruction: " << *CastInst << '\n'
807                         << "To               : " << *ClonedCastInst << '\n');
808     }
809 
810     LLVM_DEBUG(dbgs() << "Update: " << *Adj->User.Inst << '\n');
811     updateOperand(Adj->User.Inst, Adj->User.OpndIdx, ClonedCastInst);
812     LLVM_DEBUG(dbgs() << "To    : " << *Adj->User.Inst << '\n');
813     return;
814   }
815 
816   // Case 3: the user's operand is a constant expression.
817   if (auto ConstExpr = dyn_cast<ConstantExpr>(Opnd)) {
818     if (isa<GEPOperator>(ConstExpr)) {
819       // Operand is a ConstantGEP, replace it directly with Mat.
820       updateOperand(Adj->User.Inst, Adj->User.OpndIdx, Mat);
821       return;
822     }
823 
824     // Aside from constant GEPs, only constant cast expressions are collected.
825     assert(ConstExpr->isCast() && "ConstExpr should be a cast");
826     Instruction *ConstExprInst = ConstExpr->getAsInstruction();
827     ConstExprInst->insertBefore(Adj->MatInsertPt);
828     ConstExprInst->setOperand(0, Mat);
829 
830     // Use the same debug location as the instruction we are about to update.
831     ConstExprInst->setDebugLoc(Adj->User.Inst->getDebugLoc());
832 
833     LLVM_DEBUG(dbgs() << "Create instruction: " << *ConstExprInst << '\n'
834                       << "From              : " << *ConstExpr << '\n');
835     LLVM_DEBUG(dbgs() << "Update: " << *Adj->User.Inst << '\n');
836     if (!updateOperand(Adj->User.Inst, Adj->User.OpndIdx, ConstExprInst)) {
837       ConstExprInst->eraseFromParent(); // Not used: an earlier PHI value won.
838       if (Adj->Offset)
839         Mat->eraseFromParent(); // Its only user was just erased.
840     }
841     LLVM_DEBUG(dbgs() << "To    : " << *Adj->User.Inst << '\n');
842     return;
843   }
844 }
845 
846 /// Hoist every base constant recorded for \p BaseGV (or every integer base
847 /// constant when it is null) behind a bitcast and rebase its dependent uses.
848 bool ConstantHoistingPass::emitBaseConstants(GlobalVariable *BaseGV) {
849   bool MadeChange = false;
850   SmallVectorImpl<consthoist::ConstantInfo> &ConstInfoVec =
851       BaseGV ? ConstGEPInfoMap[BaseGV] : ConstIntInfoVec;
852   for (const consthoist::ConstantInfo &ConstInfo : ConstInfoVec) {
853     SmallVector<BasicBlock::iterator, 4> MatInsertPts;
854     collectMatInsertPts(ConstInfo.RebasedConstants, MatInsertPts);
855     SetVector<BasicBlock::iterator> IPSet =
856         findConstantInsertionPoint(ConstInfo, MatInsertPts);
857     // We can have an empty set if the function contains unreachable blocks.
858     if (IPSet.empty())
859       continue;
860 
861     unsigned UsesNum = 0;
862     unsigned ReBasesNum = 0;
863     unsigned NotRebasedNum = 0;
864     for (const BasicBlock::iterator &IP : IPSet) {
865       // First, collect the constant uses that depend on this IP of the base.
866       UsesNum = 0; // Recounted for every IP; the total does not depend on IP.
867       SmallVector<UserAdjustment, 4> ToBeRebased;
868       unsigned MatCtr = 0; // Running index into MatInsertPts, one per use.
869       for (auto const &RCI : ConstInfo.RebasedConstants) {
870         UsesNum += RCI.Uses.size();
871         for (auto const &U : RCI.Uses) {
872           const BasicBlock::iterator &MatInsertPt = MatInsertPts[MatCtr++];
873           BasicBlock *OrigMatInsertBB = MatInsertPt->getParent();
874           // If the Base constant is to be inserted in multiple places,
875           // generate the rebase for U using the Base that dominates U.
876           if (IPSet.size() == 1 ||
877               DT->dominates(IP->getParent(), OrigMatInsertBB))
878             ToBeRebased.emplace_back(RCI.Offset, RCI.Ty, MatInsertPt, U);
879         }
880       }
881 
882       // If only a few constants depend on this IP of the base, skip rebasing,
883       // assuming the base and the rebased have the same materialization cost.
884       if (ToBeRebased.size() < MinNumOfDependentToRebase) {
885         NotRebasedNum += ToBeRebased.size();
886         continue; // No Base instruction is emitted at this IP.
887       }
888 
889       // Emit an instance of the base at this IP.
890       Instruction *Base = nullptr;
891       // Hoist and hide the base constant behind a bitcast to its own type.
892       if (ConstInfo.BaseExpr) {
893         assert(BaseGV && "A base constant expression must have an base GV");
894         Type *Ty = ConstInfo.BaseExpr->getType();
895         Base = new BitCastInst(ConstInfo.BaseExpr, Ty, "const", IP);
896       } else {
897         IntegerType *Ty = ConstInfo.BaseInt->getIntegerType();
898         Base = new BitCastInst(ConstInfo.BaseInt, Ty, "const", IP);
899       }
900 
901       Base->setDebugLoc(IP->getDebugLoc());
902 
903       LLVM_DEBUG(dbgs() << "Hoist constant (" << *ConstInfo.BaseInt
904                         << ") to BB " << IP->getParent()->getName() << '\n'
905                         << *Base << '\n');
906 
907       // Emit materialization code for rebased constants depending on this IP.
908       for (UserAdjustment &R : ToBeRebased) {
909         emitBaseConstants(Base, &R);
910         ReBasesNum++;
911         // Merge this user's debug location into the debug location of Base.
912         Base->setDebugLoc(DILocation::getMergedLocation(
913             Base->getDebugLoc(), R.User.Inst->getDebugLoc()));
914       }
915       assert(!Base->use_empty() && "The use list is empty!?");
916       assert(isa<Instruction>(Base->user_back()) &&
917              "All uses should be instructions.");
918     }
919     (void)UsesNum;
920     (void)ReBasesNum;
921     (void)NotRebasedNum;
922     // Every use must have been either rebased or counted as skipped above.
923     assert(UsesNum == (ReBasesNum + NotRebasedNum) &&
924            "Not all uses are rebased");
925 
926     NumConstantsHoisted++;
927 
928     // Base constant is also included in ConstInfo.RebasedConstants, so
929     // deduct 1 from ConstInfo.RebasedConstants.size().
930     NumConstantsRebased += ConstInfo.RebasedConstants.size() - 1;
931 
932     MadeChange = true; // NOTE(review): also set when every IP was skipped above.
933   }
934   return MadeChange;
935 }
936 
937 /// Erase each original cast instruction that was cloned (a key of
938 /// ClonedCastMap) if it no longer has any uses.
939 void ConstantHoistingPass::deleteDeadCastInst() const {
940   for (auto const &I : ClonedCastMap)
941     if (I.first->use_empty()) // I.first is the original cast, not the clone.
942       I.first->eraseFromParent();
943 }
944 
945 /// Optimize expensive constants in \p Fn. \return true if a change was reported.
946 bool ConstantHoistingPass::runImpl(Function &Fn, TargetTransformInfo &TTI,
947                                    DominatorTree &DT, BlockFrequencyInfo *BFI,
948                                    BasicBlock &Entry, ProfileSummaryInfo *PSI) {
949   this->TTI = &TTI;
950   this->DT = &DT;
951   this->BFI = BFI;
952   this->DL = &Fn.getDataLayout();
953   this->Ctx = &Fn.getContext();
954   this->Entry = &Entry;
955   this->PSI = PSI;
956   this->OptForSize = Entry.getParent()->hasOptSize() ||
957                      llvm::shouldOptimizeForSize(Entry.getParent(), PSI, BFI,
958                                                  PGSOQueryType::IRPass);
959 
960   // Step 1: collect all constant candidates.
961   collectConstantCandidates(Fn);
962 
963   // Step 2: combine constants that can be easily materialized with an add from
964   // a common base constant (integer candidates first, then per-global GEPs).
965   if (!ConstIntCandVec.empty())
966     findBaseConstants(nullptr);
967   for (const auto &MapEntry : ConstGEPCandMap)
968     if (!MapEntry.second.empty())
969       findBaseConstants(MapEntry.first);
970 
971   // Step 3: hoist the base constants and emit materialization code for the
972   // dependent constants.
973   bool MadeChange = false;
974   if (!ConstIntInfoVec.empty())
975     MadeChange = emitBaseConstants(nullptr);
976   for (const auto &MapEntry : ConstGEPInfoMap)
977     if (!MapEntry.second.empty())
978       MadeChange |= emitBaseConstants(MapEntry.first);
979 
980 
981   // Remove original cast instructions that are left without uses.
982   deleteDeadCastInst();
983 
984   cleanup(); // Defined outside this chunk; presumably resets per-run state.
985 
986   return MadeChange;
987 }
988 
989 PreservedAnalyses ConstantHoistingPass::run(Function &F,
990                                             FunctionAnalysisManager &AM) {
991   auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
992   auto &TTI = AM.getResult<TargetIRAnalysis>(F);
993   auto BFI = ConstHoistWithBlockFrequency
994                  ? &AM.getResult<BlockFrequencyAnalysis>(F)
995                  : nullptr; // Only requested when ConstHoistWithBlockFrequency is set.
996   auto &MAMProxy = AM.getResult<ModuleAnalysisManagerFunctionProxy>(F);
997   auto *PSI = MAMProxy.getCachedResult<ProfileSummaryAnalysis>(*F.getParent()); // Cached lookup; presumably null if absent.
998   if (!runImpl(F, TTI, DT, BFI, F.getEntryBlock(), PSI))
999     return PreservedAnalyses::all(); // runImpl reported no change.
1000 
1001   PreservedAnalyses PA;
1002   PA.preserveSet<CFGAnalyses>(); // Reports the CFGAnalyses set as preserved.
1003   return PA;
1004 }
1005