xref: /freebsd/contrib/llvm-project/llvm/lib/CodeGen/TypePromotion.cpp (revision 0fca6ea1d4eea4c934cfff25ac9ee8ad6fe95583)
1 //===----- TypePromotion.cpp ----------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 /// \file
10 /// This is an opcode based type promotion pass for small types that would
11 /// otherwise be promoted during legalisation. This works around the limitations
12 /// of selection dag for cyclic regions. The search begins from the operands
13 /// of icmp instructions, where a tree, consisting of non-wrapping or safe
14 /// wrapping instructions, is built, checked and promoted if possible.
15 ///
16 //===----------------------------------------------------------------------===//
17 
18 #include "llvm/CodeGen/TypePromotion.h"
19 #include "llvm/ADT/SetVector.h"
20 #include "llvm/ADT/StringRef.h"
21 #include "llvm/Analysis/LoopInfo.h"
22 #include "llvm/Analysis/TargetTransformInfo.h"
23 #include "llvm/CodeGen/Passes.h"
24 #include "llvm/CodeGen/TargetLowering.h"
25 #include "llvm/CodeGen/TargetPassConfig.h"
26 #include "llvm/CodeGen/TargetSubtargetInfo.h"
27 #include "llvm/IR/Attributes.h"
28 #include "llvm/IR/BasicBlock.h"
29 #include "llvm/IR/Constants.h"
30 #include "llvm/IR/IRBuilder.h"
31 #include "llvm/IR/InstrTypes.h"
32 #include "llvm/IR/Instruction.h"
33 #include "llvm/IR/Instructions.h"
34 #include "llvm/IR/Type.h"
35 #include "llvm/IR/Value.h"
36 #include "llvm/InitializePasses.h"
37 #include "llvm/Pass.h"
38 #include "llvm/Support/Casting.h"
39 #include "llvm/Support/CommandLine.h"
40 #include "llvm/Target/TargetMachine.h"
41 
42 #define DEBUG_TYPE "type-promotion"
43 #define PASS_NAME "Type Promotion"
44 
45 using namespace llvm;
46 
47 static cl::opt<bool> DisablePromotion("disable-type-promotion", cl::Hidden,
48                                       cl::init(false),
49                                       cl::desc("Disable type promotion pass"));
50 
51 // The goal of this pass is to enable more efficient code generation for
52 // operations on narrow types (i.e. types with < 32-bits) and this is a
53 // motivating IR code example:
54 //
55 //   define hidden i32 @cmp(i8 zeroext) {
56 //     %2 = add i8 %0, -49
57 //     %3 = icmp ult i8 %2, 3
58 //     ..
59 //   }
60 //
61 // The issue here is that i8 is type-legalized to i32 because i8 is not a
62 // legal type. Thus, arithmetic is done in integer-precision, but then the
63 // byte value is masked out as follows:
64 //
65 //   t19: i32 = add t4, Constant:i32<-49>
66 //     t24: i32 = and t19, Constant:i32<255>
67 //
68 // Consequently, we generate code like this:
69 //
70 //   subs  r0, #49
71 //   uxtb  r1, r0
72 //   cmp r1, #3
73 //
74 // This shows that masking out the byte value results in generation of
75 // the UXTB instruction. This is not optimal as r0 already contains the byte
76 // value we need, and so instead we can just generate:
77 //
78 //   sub.w r1, r0, #49
79 //   cmp r1, #3
80 //
81 // We achieve this by type promoting the IR to i32 like so for this example:
82 //
83 //   define i32 @cmp(i8 zeroext %c) {
84 //     %0 = zext i8 %c to i32
85 //     %c.off = add i32 %0, -49
86 //     %1 = icmp ult i32 %c.off, 3
87 //     ..
88 //   }
89 //
90 // For this to be valid and legal, we need to prove that the i32 add is
91 // producing the same value as the i8 addition, and that e.g. no overflow
92 // happens.
93 //
94 // A brief sketch of the algorithm and some terminology.
95 // We pattern match interesting IR patterns:
96 // - which have "sources": instructions producing narrow values (i8, i16), and
97 // - they have "sinks": instructions consuming these narrow values.
98 //
99 // We collect all instructions connecting sources and sinks in a worklist, so
100 // that we can mutate these instructions and perform type promotion when it
101 // is legal to do so.
102 
103 namespace {
104 class IRPromoter {
105   LLVMContext &Ctx;
106   unsigned PromotedWidth = 0;
107   SetVector<Value *> &Visited;
108   SetVector<Value *> &Sources;
109   SetVector<Instruction *> &Sinks;
110   SmallPtrSetImpl<Instruction *> &SafeWrap;
111   SmallPtrSetImpl<Instruction *> &InstsToRemove;
112   IntegerType *ExtTy = nullptr;
113   SmallPtrSet<Value *, 8> NewInsts;
114   DenseMap<Value *, SmallVector<Type *, 4>> TruncTysMap;
115   SmallPtrSet<Value *, 8> Promoted;
116 
117   void ReplaceAllUsersOfWith(Value *From, Value *To);
118   void ExtendSources();
119   void ConvertTruncs();
120   void PromoteTree();
121   void TruncateSinks();
122   void Cleanup();
123 
124 public:
IRPromoter(LLVMContext & C,unsigned Width,SetVector<Value * > & visited,SetVector<Value * > & sources,SetVector<Instruction * > & sinks,SmallPtrSetImpl<Instruction * > & wrap,SmallPtrSetImpl<Instruction * > & instsToRemove)125   IRPromoter(LLVMContext &C, unsigned Width, SetVector<Value *> &visited,
126              SetVector<Value *> &sources, SetVector<Instruction *> &sinks,
127              SmallPtrSetImpl<Instruction *> &wrap,
128              SmallPtrSetImpl<Instruction *> &instsToRemove)
129       : Ctx(C), PromotedWidth(Width), Visited(visited), Sources(sources),
130         Sinks(sinks), SafeWrap(wrap), InstsToRemove(instsToRemove) {
131     ExtTy = IntegerType::get(Ctx, PromotedWidth);
132   }
133 
134   void Mutate();
135 };
136 
137 class TypePromotionImpl {
138   unsigned TypeSize = 0;
139   const TargetLowering *TLI = nullptr;
140   LLVMContext *Ctx = nullptr;
141   unsigned RegisterBitWidth = 0;
142   SmallPtrSet<Value *, 16> AllVisited;
143   SmallPtrSet<Instruction *, 8> SafeToPromote;
144   SmallPtrSet<Instruction *, 4> SafeWrap;
145   SmallPtrSet<Instruction *, 4> InstsToRemove;
146 
147   // Does V have the same size result type as TypeSize.
148   bool EqualTypeSize(Value *V);
149   // Does V have the same size, or narrower, result type as TypeSize.
150   bool LessOrEqualTypeSize(Value *V);
151   // Does V have a result type that is wider than TypeSize.
152   bool GreaterThanTypeSize(Value *V);
153   // Does V have a result type that is narrower than TypeSize.
154   bool LessThanTypeSize(Value *V);
155   // Should V be a leaf in the promote tree?
156   bool isSource(Value *V);
157   // Should V be a root in the promotion tree?
158   bool isSink(Value *V);
159   // Should we change the result type of V? It will result in the users of V
160   // being visited.
161   bool shouldPromote(Value *V);
162   // Is I an add or a sub, which isn't marked as nuw, but where a wrapping
163   // result won't affect the computation?
164   bool isSafeWrap(Instruction *I);
165   // Can V have its integer type promoted, or can the type be ignored.
166   bool isSupportedType(Value *V);
167   // Is V an instruction with a supported opcode or another value that we can
168   // handle, such as constants and basic blocks.
169   bool isSupportedValue(Value *V);
170   // Is V an instruction thats result can trivially promoted, or has safe
171   // wrapping.
172   bool isLegalToPromote(Value *V);
173   bool TryToPromote(Value *V, unsigned PromotedWidth, const LoopInfo &LI);
174 
175 public:
176   bool run(Function &F, const TargetMachine *TM,
177            const TargetTransformInfo &TTI, const LoopInfo &LI);
178 };
179 
180 class TypePromotionLegacy : public FunctionPass {
181 public:
182   static char ID;
183 
TypePromotionLegacy()184   TypePromotionLegacy() : FunctionPass(ID) {}
185 
getAnalysisUsage(AnalysisUsage & AU) const186   void getAnalysisUsage(AnalysisUsage &AU) const override {
187     AU.addRequired<LoopInfoWrapperPass>();
188     AU.addRequired<TargetTransformInfoWrapperPass>();
189     AU.addRequired<TargetPassConfig>();
190     AU.setPreservesCFG();
191     AU.addPreserved<LoopInfoWrapperPass>();
192   }
193 
getPassName() const194   StringRef getPassName() const override { return PASS_NAME; }
195 
196   bool runOnFunction(Function &F) override;
197 };
198 
199 } // namespace
200 
GenerateSignBits(Instruction * I)201 static bool GenerateSignBits(Instruction *I) {
202   unsigned Opc = I->getOpcode();
203   return Opc == Instruction::AShr || Opc == Instruction::SDiv ||
204          Opc == Instruction::SRem || Opc == Instruction::SExt;
205 }
206 
EqualTypeSize(Value * V)207 bool TypePromotionImpl::EqualTypeSize(Value *V) {
208   return V->getType()->getScalarSizeInBits() == TypeSize;
209 }
210 
LessOrEqualTypeSize(Value * V)211 bool TypePromotionImpl::LessOrEqualTypeSize(Value *V) {
212   return V->getType()->getScalarSizeInBits() <= TypeSize;
213 }
214 
GreaterThanTypeSize(Value * V)215 bool TypePromotionImpl::GreaterThanTypeSize(Value *V) {
216   return V->getType()->getScalarSizeInBits() > TypeSize;
217 }
218 
LessThanTypeSize(Value * V)219 bool TypePromotionImpl::LessThanTypeSize(Value *V) {
220   return V->getType()->getScalarSizeInBits() < TypeSize;
221 }
222 
223 /// Return true if the given value is a source in the use-def chain, producing
224 /// a narrow 'TypeSize' value. These values will be zext to start the promotion
225 /// of the tree to i32. We guarantee that these won't populate the upper bits
226 /// of the register. ZExt on the loads will be free, and the same for call
227 /// return values because we only accept ones that guarantee a zeroext ret val.
228 /// Many arguments will have the zeroext attribute too, so those would be free
229 /// too.
isSource(Value * V)230 bool TypePromotionImpl::isSource(Value *V) {
231   if (!isa<IntegerType>(V->getType()))
232     return false;
233 
234   // TODO Allow zext to be sources.
235   if (isa<Argument>(V))
236     return true;
237   else if (isa<LoadInst>(V))
238     return true;
239   else if (auto *Call = dyn_cast<CallInst>(V))
240     return Call->hasRetAttr(Attribute::AttrKind::ZExt);
241   else if (auto *Trunc = dyn_cast<TruncInst>(V))
242     return EqualTypeSize(Trunc);
243   return false;
244 }
245 
246 /// Return true if V will require any promoted values to be truncated for the
247 /// the IR to remain valid. We can't mutate the value type of these
248 /// instructions.
isSink(Value * V)249 bool TypePromotionImpl::isSink(Value *V) {
250   // TODO The truncate also isn't actually necessary because we would already
251   // proved that the data value is kept within the range of the original data
252   // type. We currently remove any truncs inserted for handling zext sinks.
253 
254   // Sinks are:
255   // - points where the value in the register is being observed, such as an
256   //   icmp, switch or store.
257   // - points where value types have to match, such as calls and returns.
258   // - zext are included to ease the transformation and are generally removed
259   //   later on.
260   if (auto *Store = dyn_cast<StoreInst>(V))
261     return LessOrEqualTypeSize(Store->getValueOperand());
262   if (auto *Return = dyn_cast<ReturnInst>(V))
263     return LessOrEqualTypeSize(Return->getReturnValue());
264   if (auto *ZExt = dyn_cast<ZExtInst>(V))
265     return GreaterThanTypeSize(ZExt);
266   if (auto *Switch = dyn_cast<SwitchInst>(V))
267     return LessThanTypeSize(Switch->getCondition());
268   if (auto *ICmp = dyn_cast<ICmpInst>(V))
269     return ICmp->isSigned() || LessThanTypeSize(ICmp->getOperand(0));
270 
271   return isa<CallInst>(V);
272 }
273 
274 /// Return whether this instruction can safely wrap.
isSafeWrap(Instruction * I)275 bool TypePromotionImpl::isSafeWrap(Instruction *I) {
276   // We can support a potentially wrapping Add/Sub instruction (I) if:
277   // - It is only used by an unsigned icmp.
278   // - The icmp uses a constant.
279   // - The wrapping instruction (I) also uses a constant.
280   //
281   // This a common pattern emitted to check if a value is within a range.
282   //
283   // For example:
284   //
285   // %sub = sub i8 %a, C1
286   // %cmp = icmp ule i8 %sub, C2
287   //
288   // or
289   //
290   // %add = add i8 %a, C1
291   // %cmp = icmp ule i8 %add, C2.
292   //
293   // We will treat an add as though it were a subtract by -C1. To promote
294   // the Add/Sub we will zero extend the LHS and the subtracted amount. For Add,
295   // this means we need to negate the constant, zero extend to RegisterBitWidth,
296   // and negate in the larger type.
297   //
298   // This will produce a value in the range [-zext(C1), zext(X)-zext(C1)] where
299   // C1 is the subtracted amount. This is either a small unsigned number or a
300   // large unsigned number in the promoted type.
301   //
302   // Now we need to correct the compare constant C2. Values >= C1 in the
303   // original add result range have been remapped to large values in the
304   // promoted range. If the compare constant fell into this range we need to
305   // remap it as well. We can do this as -(zext(-C2)).
306   //
307   // For example:
308   //
309   // %sub = sub i8 %a, 2
310   // %cmp = icmp ule i8 %sub, 254
311   //
312   // becomes
313   //
314   // %zext = zext %a to i32
315   // %sub = sub i32 %zext, 2
316   // %cmp = icmp ule i32 %sub, 4294967294
317   //
318   // Another example:
319   //
320   // %sub = sub i8 %a, 1
321   // %cmp = icmp ule i8 %sub, 254
322   //
323   // becomes
324   //
325   // %zext = zext %a to i32
326   // %sub = sub i32 %zext, 1
327   // %cmp = icmp ule i32 %sub, 254
328 
329   unsigned Opc = I->getOpcode();
330   if (Opc != Instruction::Add && Opc != Instruction::Sub)
331     return false;
332 
333   if (!I->hasOneUse() || !isa<ICmpInst>(*I->user_begin()) ||
334       !isa<ConstantInt>(I->getOperand(1)))
335     return false;
336 
337   // Don't support an icmp that deals with sign bits.
338   auto *CI = cast<ICmpInst>(*I->user_begin());
339   if (CI->isSigned() || CI->isEquality())
340     return false;
341 
342   ConstantInt *ICmpConstant = nullptr;
343   if (auto *Const = dyn_cast<ConstantInt>(CI->getOperand(0)))
344     ICmpConstant = Const;
345   else if (auto *Const = dyn_cast<ConstantInt>(CI->getOperand(1)))
346     ICmpConstant = Const;
347   else
348     return false;
349 
350   const APInt &ICmpConst = ICmpConstant->getValue();
351   APInt OverflowConst = cast<ConstantInt>(I->getOperand(1))->getValue();
352   if (Opc == Instruction::Sub)
353     OverflowConst = -OverflowConst;
354 
355   // If the constant is positive, we will end up filling the promoted bits with
356   // all 1s. Make sure that results in a cheap add constant.
357   if (!OverflowConst.isNonPositive()) {
358     // We don't have the true promoted width, just use 64 so we can create an
359     // int64_t for the isLegalAddImmediate call.
360     if (OverflowConst.getBitWidth() >= 64)
361       return false;
362 
363     APInt NewConst = -((-OverflowConst).zext(64));
364     if (!TLI->isLegalAddImmediate(NewConst.getSExtValue()))
365       return false;
366   }
367 
368   SafeWrap.insert(I);
369 
370   if (OverflowConst == 0 || OverflowConst.ugt(ICmpConst)) {
371     LLVM_DEBUG(dbgs() << "IR Promotion: Allowing safe overflow for "
372                       << "const of " << *I << "\n");
373     return true;
374   }
375 
376   LLVM_DEBUG(dbgs() << "IR Promotion: Allowing safe overflow for "
377                     << "const of " << *I << " and " << *CI << "\n");
378   SafeWrap.insert(CI);
379   return true;
380 }
381 
shouldPromote(Value * V)382 bool TypePromotionImpl::shouldPromote(Value *V) {
383   if (!isa<IntegerType>(V->getType()) || isSink(V))
384     return false;
385 
386   if (isSource(V))
387     return true;
388 
389   auto *I = dyn_cast<Instruction>(V);
390   if (!I)
391     return false;
392 
393   if (isa<ICmpInst>(I))
394     return false;
395 
396   return true;
397 }
398 
399 /// Return whether we can safely mutate V's type to ExtTy without having to be
400 /// concerned with zero extending or truncation.
isPromotedResultSafe(Instruction * I)401 static bool isPromotedResultSafe(Instruction *I) {
402   if (GenerateSignBits(I))
403     return false;
404 
405   if (!isa<OverflowingBinaryOperator>(I))
406     return true;
407 
408   return I->hasNoUnsignedWrap();
409 }
410 
ReplaceAllUsersOfWith(Value * From,Value * To)411 void IRPromoter::ReplaceAllUsersOfWith(Value *From, Value *To) {
412   SmallVector<Instruction *, 4> Users;
413   Instruction *InstTo = dyn_cast<Instruction>(To);
414   bool ReplacedAll = true;
415 
416   LLVM_DEBUG(dbgs() << "IR Promotion: Replacing " << *From << " with " << *To
417                     << "\n");
418 
419   for (Use &U : From->uses()) {
420     auto *User = cast<Instruction>(U.getUser());
421     if (InstTo && User->isIdenticalTo(InstTo)) {
422       ReplacedAll = false;
423       continue;
424     }
425     Users.push_back(User);
426   }
427 
428   for (auto *U : Users)
429     U->replaceUsesOfWith(From, To);
430 
431   if (ReplacedAll)
432     if (auto *I = dyn_cast<Instruction>(From))
433       InstsToRemove.insert(I);
434 }
435 
ExtendSources()436 void IRPromoter::ExtendSources() {
437   IRBuilder<> Builder{Ctx};
438 
439   auto InsertZExt = [&](Value *V, Instruction *InsertPt) {
440     assert(V->getType() != ExtTy && "zext already extends to i32");
441     LLVM_DEBUG(dbgs() << "IR Promotion: Inserting ZExt for " << *V << "\n");
442     Builder.SetInsertPoint(InsertPt);
443     if (auto *I = dyn_cast<Instruction>(V))
444       Builder.SetCurrentDebugLocation(I->getDebugLoc());
445 
446     Value *ZExt = Builder.CreateZExt(V, ExtTy);
447     if (auto *I = dyn_cast<Instruction>(ZExt)) {
448       if (isa<Argument>(V))
449         I->moveBefore(InsertPt);
450       else
451         I->moveAfter(InsertPt);
452       NewInsts.insert(I);
453     }
454 
455     ReplaceAllUsersOfWith(V, ZExt);
456   };
457 
458   // Now, insert extending instructions between the sources and their users.
459   LLVM_DEBUG(dbgs() << "IR Promotion: Promoting sources:\n");
460   for (auto *V : Sources) {
461     LLVM_DEBUG(dbgs() << " - " << *V << "\n");
462     if (auto *I = dyn_cast<Instruction>(V))
463       InsertZExt(I, I);
464     else if (auto *Arg = dyn_cast<Argument>(V)) {
465       BasicBlock &BB = Arg->getParent()->front();
466       InsertZExt(Arg, &*BB.getFirstInsertionPt());
467     } else {
468       llvm_unreachable("unhandled source that needs extending");
469     }
470     Promoted.insert(V);
471   }
472 }
473 
PromoteTree()474 void IRPromoter::PromoteTree() {
475   LLVM_DEBUG(dbgs() << "IR Promotion: Mutating the tree..\n");
476 
477   // Mutate the types of the instructions within the tree. Here we handle
478   // constant operands.
479   for (auto *V : Visited) {
480     if (Sources.count(V))
481       continue;
482 
483     auto *I = cast<Instruction>(V);
484     if (Sinks.count(I))
485       continue;
486 
487     for (unsigned i = 0, e = I->getNumOperands(); i < e; ++i) {
488       Value *Op = I->getOperand(i);
489       if ((Op->getType() == ExtTy) || !isa<IntegerType>(Op->getType()))
490         continue;
491 
492       if (auto *Const = dyn_cast<ConstantInt>(Op)) {
493         // For subtract, we only need to zext the constant. We only put it in
494         // SafeWrap because SafeWrap.size() is used elsewhere.
495         // For Add and ICmp we need to find how far the constant is from the
496         // top of its original unsigned range and place it the same distance
497         // from the top of its new unsigned range. We can do this by negating
498         // the constant, zero extending it, then negating in the new type.
499         APInt NewConst;
500         if (SafeWrap.contains(I)) {
501           if (I->getOpcode() == Instruction::ICmp)
502             NewConst = -((-Const->getValue()).zext(PromotedWidth));
503           else if (I->getOpcode() == Instruction::Add && i == 1)
504             NewConst = -((-Const->getValue()).zext(PromotedWidth));
505           else
506             NewConst = Const->getValue().zext(PromotedWidth);
507         } else
508           NewConst = Const->getValue().zext(PromotedWidth);
509 
510         I->setOperand(i, ConstantInt::get(Const->getContext(), NewConst));
511       } else if (isa<UndefValue>(Op))
512         I->setOperand(i, ConstantInt::get(ExtTy, 0));
513     }
514 
515     // Mutate the result type, unless this is an icmp or switch.
516     if (!isa<ICmpInst>(I) && !isa<SwitchInst>(I)) {
517       I->mutateType(ExtTy);
518       Promoted.insert(I);
519     }
520   }
521 }
522 
TruncateSinks()523 void IRPromoter::TruncateSinks() {
524   LLVM_DEBUG(dbgs() << "IR Promotion: Fixing up the sinks:\n");
525 
526   IRBuilder<> Builder{Ctx};
527 
528   auto InsertTrunc = [&](Value *V, Type *TruncTy) -> Instruction * {
529     if (!isa<Instruction>(V) || !isa<IntegerType>(V->getType()))
530       return nullptr;
531 
532     if ((!Promoted.count(V) && !NewInsts.count(V)) || Sources.count(V))
533       return nullptr;
534 
535     LLVM_DEBUG(dbgs() << "IR Promotion: Creating " << *TruncTy << " Trunc for "
536                       << *V << "\n");
537     Builder.SetInsertPoint(cast<Instruction>(V));
538     auto *Trunc = dyn_cast<Instruction>(Builder.CreateTrunc(V, TruncTy));
539     if (Trunc)
540       NewInsts.insert(Trunc);
541     return Trunc;
542   };
543 
544   // Fix up any stores or returns that use the results of the promoted
545   // chain.
546   for (auto *I : Sinks) {
547     LLVM_DEBUG(dbgs() << "IR Promotion: For Sink: " << *I << "\n");
548 
549     // Handle calls separately as we need to iterate over arg operands.
550     if (auto *Call = dyn_cast<CallInst>(I)) {
551       for (unsigned i = 0; i < Call->arg_size(); ++i) {
552         Value *Arg = Call->getArgOperand(i);
553         Type *Ty = TruncTysMap[Call][i];
554         if (Instruction *Trunc = InsertTrunc(Arg, Ty)) {
555           Trunc->moveBefore(Call);
556           Call->setArgOperand(i, Trunc);
557         }
558       }
559       continue;
560     }
561 
562     // Special case switches because we need to truncate the condition.
563     if (auto *Switch = dyn_cast<SwitchInst>(I)) {
564       Type *Ty = TruncTysMap[Switch][0];
565       if (Instruction *Trunc = InsertTrunc(Switch->getCondition(), Ty)) {
566         Trunc->moveBefore(Switch);
567         Switch->setCondition(Trunc);
568       }
569       continue;
570     }
571 
572     // Don't insert a trunc for a zext which can still legally promote.
573     // Nor insert a trunc when the input value to that trunc has the same width
574     // as the zext we are inserting it for.  When this happens the input operand
575     // for the zext will be promoted to the same width as the zext's return type
576     // rendering that zext unnecessary.  This zext gets removed before the end
577     // of the pass.
578     if (auto ZExt = dyn_cast<ZExtInst>(I))
579       if (ZExt->getType()->getScalarSizeInBits() >= PromotedWidth)
580         continue;
581 
582     // Now handle the others.
583     for (unsigned i = 0; i < I->getNumOperands(); ++i) {
584       Type *Ty = TruncTysMap[I][i];
585       if (Instruction *Trunc = InsertTrunc(I->getOperand(i), Ty)) {
586         Trunc->moveBefore(I);
587         I->setOperand(i, Trunc);
588       }
589     }
590   }
591 }
592 
Cleanup()593 void IRPromoter::Cleanup() {
594   LLVM_DEBUG(dbgs() << "IR Promotion: Cleanup..\n");
595   // Some zexts will now have become redundant, along with their trunc
596   // operands, so remove them.
597   for (auto *V : Visited) {
598     if (!isa<ZExtInst>(V))
599       continue;
600 
601     auto ZExt = cast<ZExtInst>(V);
602     if (ZExt->getDestTy() != ExtTy)
603       continue;
604 
605     Value *Src = ZExt->getOperand(0);
606     if (ZExt->getSrcTy() == ZExt->getDestTy()) {
607       LLVM_DEBUG(dbgs() << "IR Promotion: Removing unnecessary cast: " << *ZExt
608                         << "\n");
609       ReplaceAllUsersOfWith(ZExt, Src);
610       continue;
611     }
612 
613     // We've inserted a trunc for a zext sink, but we already know that the
614     // input is in range, negating the need for the trunc.
615     if (NewInsts.count(Src) && isa<TruncInst>(Src)) {
616       auto *Trunc = cast<TruncInst>(Src);
617       assert(Trunc->getOperand(0)->getType() == ExtTy &&
618              "expected inserted trunc to be operating on i32");
619       ReplaceAllUsersOfWith(ZExt, Trunc->getOperand(0));
620     }
621   }
622 
623   for (auto *I : InstsToRemove) {
624     LLVM_DEBUG(dbgs() << "IR Promotion: Removing " << *I << "\n");
625     I->dropAllReferences();
626   }
627 }
628 
ConvertTruncs()629 void IRPromoter::ConvertTruncs() {
630   LLVM_DEBUG(dbgs() << "IR Promotion: Converting truncs..\n");
631   IRBuilder<> Builder{Ctx};
632 
633   for (auto *V : Visited) {
634     if (!isa<TruncInst>(V) || Sources.count(V))
635       continue;
636 
637     auto *Trunc = cast<TruncInst>(V);
638     Builder.SetInsertPoint(Trunc);
639     IntegerType *SrcTy = cast<IntegerType>(Trunc->getOperand(0)->getType());
640     IntegerType *DestTy = cast<IntegerType>(TruncTysMap[Trunc][0]);
641 
642     unsigned NumBits = DestTy->getScalarSizeInBits();
643     ConstantInt *Mask =
644         ConstantInt::get(SrcTy, APInt::getMaxValue(NumBits).getZExtValue());
645     Value *Masked = Builder.CreateAnd(Trunc->getOperand(0), Mask);
646     if (SrcTy->getBitWidth() > ExtTy->getBitWidth())
647       Masked = Builder.CreateTrunc(Masked, ExtTy);
648 
649     if (auto *I = dyn_cast<Instruction>(Masked))
650       NewInsts.insert(I);
651 
652     ReplaceAllUsersOfWith(Trunc, Masked);
653   }
654 }
655 
Mutate()656 void IRPromoter::Mutate() {
657   LLVM_DEBUG(dbgs() << "IR Promotion: Promoting use-def chains to "
658                     << PromotedWidth << "-bits\n");
659 
660   // Cache original types of the values that will likely need truncating
661   for (auto *I : Sinks) {
662     if (auto *Call = dyn_cast<CallInst>(I)) {
663       for (Value *Arg : Call->args())
664         TruncTysMap[Call].push_back(Arg->getType());
665     } else if (auto *Switch = dyn_cast<SwitchInst>(I))
666       TruncTysMap[I].push_back(Switch->getCondition()->getType());
667     else {
668       for (unsigned i = 0; i < I->getNumOperands(); ++i)
669         TruncTysMap[I].push_back(I->getOperand(i)->getType());
670     }
671   }
672   for (auto *V : Visited) {
673     if (!isa<TruncInst>(V) || Sources.count(V))
674       continue;
675     auto *Trunc = cast<TruncInst>(V);
676     TruncTysMap[Trunc].push_back(Trunc->getDestTy());
677   }
678 
679   // Insert zext instructions between sources and their users.
680   ExtendSources();
681 
682   // Promote visited instructions, mutating their types in place.
683   PromoteTree();
684 
685   // Convert any truncs, that aren't sources, into AND masks.
686   ConvertTruncs();
687 
688   // Insert trunc instructions for use by calls, stores etc...
689   TruncateSinks();
690 
691   // Finally, remove unecessary zexts and truncs, delete old instructions and
692   // clear the data structures.
693   Cleanup();
694 
695   LLVM_DEBUG(dbgs() << "IR Promotion: Mutation complete\n");
696 }
697 
698 /// We disallow booleans to make life easier when dealing with icmps but allow
699 /// any other integer that fits in a scalar register. Void types are accepted
700 /// so we can handle switches.
isSupportedType(Value * V)701 bool TypePromotionImpl::isSupportedType(Value *V) {
702   Type *Ty = V->getType();
703 
704   // Allow voids and pointers, these won't be promoted.
705   if (Ty->isVoidTy() || Ty->isPointerTy())
706     return true;
707 
708   if (!isa<IntegerType>(Ty) || cast<IntegerType>(Ty)->getBitWidth() == 1 ||
709       cast<IntegerType>(Ty)->getBitWidth() > RegisterBitWidth)
710     return false;
711 
712   return LessOrEqualTypeSize(V);
713 }
714 
715 /// We accept most instructions, as well as Arguments and ConstantInsts. We
716 /// Disallow casts other than zext and truncs and only allow calls if their
717 /// return value is zeroext. We don't allow opcodes that can introduce sign
718 /// bits.
isSupportedValue(Value * V)719 bool TypePromotionImpl::isSupportedValue(Value *V) {
720   if (auto *I = dyn_cast<Instruction>(V)) {
721     switch (I->getOpcode()) {
722     default:
723       return isa<BinaryOperator>(I) && isSupportedType(I) &&
724              !GenerateSignBits(I);
725     case Instruction::GetElementPtr:
726     case Instruction::Store:
727     case Instruction::Br:
728     case Instruction::Switch:
729       return true;
730     case Instruction::PHI:
731     case Instruction::Select:
732     case Instruction::Ret:
733     case Instruction::Load:
734     case Instruction::Trunc:
735       return isSupportedType(I);
736     case Instruction::BitCast:
737       return I->getOperand(0)->getType() == I->getType();
738     case Instruction::ZExt:
739       return isSupportedType(I->getOperand(0));
740     case Instruction::ICmp:
741       // Now that we allow small types than TypeSize, only allow icmp of
742       // TypeSize because they will require a trunc to be legalised.
743       // TODO: Allow icmp of smaller types, and calculate at the end
744       // whether the transform would be beneficial.
745       if (isa<PointerType>(I->getOperand(0)->getType()))
746         return true;
747       return EqualTypeSize(I->getOperand(0));
748     case Instruction::Call: {
749       // Special cases for calls as we need to check for zeroext
750       // TODO We should accept calls even if they don't have zeroext, as they
751       // can still be sinks.
752       auto *Call = cast<CallInst>(I);
753       return isSupportedType(Call) &&
754              Call->hasRetAttr(Attribute::AttrKind::ZExt);
755     }
756     }
757   } else if (isa<Constant>(V) && !isa<ConstantExpr>(V)) {
758     return isSupportedType(V);
759   } else if (isa<Argument>(V))
760     return isSupportedType(V);
761 
762   return isa<BasicBlock>(V);
763 }
764 
765 /// Check that the type of V would be promoted and that the original type is
766 /// smaller than the targeted promoted type. Check that we're not trying to
767 /// promote something larger than our base 'TypeSize' type.
isLegalToPromote(Value * V)768 bool TypePromotionImpl::isLegalToPromote(Value *V) {
769   auto *I = dyn_cast<Instruction>(V);
770   if (!I)
771     return true;
772 
773   if (SafeToPromote.count(I))
774     return true;
775 
776   if (isPromotedResultSafe(I) || isSafeWrap(I)) {
777     SafeToPromote.insert(I);
778     return true;
779   }
780   return false;
781 }
782 
TryToPromote(Value * V,unsigned PromotedWidth,const LoopInfo & LI)783 bool TypePromotionImpl::TryToPromote(Value *V, unsigned PromotedWidth,
784                                  const LoopInfo &LI) {
785   Type *OrigTy = V->getType();
786   TypeSize = OrigTy->getPrimitiveSizeInBits().getFixedValue();
787   SafeToPromote.clear();
788   SafeWrap.clear();
789 
790   if (!isSupportedValue(V) || !shouldPromote(V) || !isLegalToPromote(V))
791     return false;
792 
793   LLVM_DEBUG(dbgs() << "IR Promotion: TryToPromote: " << *V << ", from "
794                     << TypeSize << " bits to " << PromotedWidth << "\n");
795 
796   SetVector<Value *> WorkList;
797   SetVector<Value *> Sources;
798   SetVector<Instruction *> Sinks;
799   SetVector<Value *> CurrentVisited;
800   WorkList.insert(V);
801 
802   // Return true if V was added to the worklist as a supported instruction,
803   // if it was already visited, or if we don't need to explore it (e.g.
804   // pointer values and GEPs), and false otherwise.
805   auto AddLegalInst = [&](Value *V) {
806     if (CurrentVisited.count(V))
807       return true;
808 
809     // Ignore GEPs because they don't need promoting and the constant indices
810     // will prevent the transformation.
811     if (isa<GetElementPtrInst>(V))
812       return true;
813 
814     if (!isSupportedValue(V) || (shouldPromote(V) && !isLegalToPromote(V))) {
815       LLVM_DEBUG(dbgs() << "IR Promotion: Can't handle: " << *V << "\n");
816       return false;
817     }
818 
819     WorkList.insert(V);
820     return true;
821   };
822 
823   // Iterate through, and add to, a tree of operands and users in the use-def.
824   while (!WorkList.empty()) {
825     Value *V = WorkList.pop_back_val();
826     if (CurrentVisited.count(V))
827       continue;
828 
829     // Ignore non-instructions, other than arguments.
830     if (!isa<Instruction>(V) && !isSource(V))
831       continue;
832 
833     // If we've already visited this value from somewhere, bail now because
834     // the tree has already been explored.
835     // TODO: This could limit the transform, ie if we try to promote something
836     // from an i8 and fail first, before trying an i16.
837     if (AllVisited.count(V))
838       return false;
839 
840     CurrentVisited.insert(V);
841     AllVisited.insert(V);
842 
843     // Calls can be both sources and sinks.
844     if (isSink(V))
845       Sinks.insert(cast<Instruction>(V));
846 
847     if (isSource(V))
848       Sources.insert(V);
849 
850     if (!isSink(V) && !isSource(V)) {
851       if (auto *I = dyn_cast<Instruction>(V)) {
852         // Visit operands of any instruction visited.
853         for (auto &U : I->operands()) {
854           if (!AddLegalInst(U))
855             return false;
856         }
857       }
858     }
859 
860     // Don't visit users of a node which isn't going to be mutated unless its a
861     // source.
862     if (isSource(V) || shouldPromote(V)) {
863       for (Use &U : V->uses()) {
864         if (!AddLegalInst(U.getUser()))
865           return false;
866       }
867     }
868   }
869 
870   LLVM_DEBUG({
871     dbgs() << "IR Promotion: Visited nodes:\n";
872     for (auto *I : CurrentVisited)
873       I->dump();
874   });
875 
876   unsigned ToPromote = 0;
877   unsigned NonFreeArgs = 0;
878   unsigned NonLoopSources = 0, LoopSinks = 0;
879   SmallPtrSet<BasicBlock *, 4> Blocks;
880   for (auto *CV : CurrentVisited) {
881     if (auto *I = dyn_cast<Instruction>(CV))
882       Blocks.insert(I->getParent());
883 
884     if (Sources.count(CV)) {
885       if (auto *Arg = dyn_cast<Argument>(CV))
886         if (!Arg->hasZExtAttr() && !Arg->hasSExtAttr())
887           ++NonFreeArgs;
888       if (!isa<Instruction>(CV) ||
889           !LI.getLoopFor(cast<Instruction>(CV)->getParent()))
890         ++NonLoopSources;
891       continue;
892     }
893 
894     if (isa<PHINode>(CV))
895       continue;
896     if (LI.getLoopFor(cast<Instruction>(CV)->getParent()))
897       ++LoopSinks;
898     if (Sinks.count(cast<Instruction>(CV)))
899       continue;
900     ++ToPromote;
901   }
902 
903   // DAG optimizations should be able to handle these cases better, especially
904   // for function arguments.
905   if (!isa<PHINode>(V) && !(LoopSinks && NonLoopSources) &&
906       (ToPromote < 2 || (Blocks.size() == 1 && NonFreeArgs > SafeWrap.size())))
907     return false;
908 
909   IRPromoter Promoter(*Ctx, PromotedWidth, CurrentVisited, Sources, Sinks,
910                       SafeWrap, InstsToRemove);
911   Promoter.Mutate();
912   return true;
913 }
914 
run(Function & F,const TargetMachine * TM,const TargetTransformInfo & TTI,const LoopInfo & LI)915 bool TypePromotionImpl::run(Function &F, const TargetMachine *TM,
916                             const TargetTransformInfo &TTI,
917                             const LoopInfo &LI) {
918   if (DisablePromotion)
919     return false;
920 
921   LLVM_DEBUG(dbgs() << "IR Promotion: Running on " << F.getName() << "\n");
922 
923   AllVisited.clear();
924   SafeToPromote.clear();
925   SafeWrap.clear();
926   bool MadeChange = false;
927   const DataLayout &DL = F.getDataLayout();
928   const TargetSubtargetInfo *SubtargetInfo = TM->getSubtargetImpl(F);
929   TLI = SubtargetInfo->getTargetLowering();
930   RegisterBitWidth =
931       TTI.getRegisterBitWidth(TargetTransformInfo::RGK_Scalar).getFixedValue();
932   Ctx = &F.getContext();
933 
934   // Return the preferred integer width of the instruction, or zero if we
935   // shouldn't try.
936   auto GetPromoteWidth = [&](Instruction *I) -> uint32_t {
937     if (!isa<IntegerType>(I->getType()))
938       return 0;
939 
940     EVT SrcVT = TLI->getValueType(DL, I->getType());
941     if (SrcVT.isSimple() && TLI->isTypeLegal(SrcVT.getSimpleVT()))
942       return 0;
943 
944     if (TLI->getTypeAction(*Ctx, SrcVT) != TargetLowering::TypePromoteInteger)
945       return 0;
946 
947     EVT PromotedVT = TLI->getTypeToTransformTo(*Ctx, SrcVT);
948     if (TLI->isSExtCheaperThanZExt(SrcVT, PromotedVT))
949       return 0;
950     if (RegisterBitWidth < PromotedVT.getFixedSizeInBits()) {
951       LLVM_DEBUG(dbgs() << "IR Promotion: Couldn't find target register "
952                         << "for promoted type\n");
953       return 0;
954     }
955 
956     // TODO: Should we prefer to use RegisterBitWidth instead?
957     return PromotedVT.getFixedSizeInBits();
958   };
959 
960   auto BBIsInLoop = [&](BasicBlock *BB) -> bool {
961     for (auto *L : LI)
962       if (L->contains(BB))
963         return true;
964     return false;
965   };
966 
967   for (BasicBlock &BB : F) {
968     for (Instruction &I : BB) {
969       if (AllVisited.count(&I))
970         continue;
971 
972       if (isa<ZExtInst>(&I) && isa<PHINode>(I.getOperand(0)) &&
973           isa<IntegerType>(I.getType()) && BBIsInLoop(&BB)) {
974         LLVM_DEBUG(dbgs() << "IR Promotion: Searching from: "
975                           << *I.getOperand(0) << "\n");
976         EVT ZExtVT = TLI->getValueType(DL, I.getType());
977         Instruction *Phi = static_cast<Instruction *>(I.getOperand(0));
978         auto PromoteWidth = ZExtVT.getFixedSizeInBits();
979         if (RegisterBitWidth < PromoteWidth) {
980           LLVM_DEBUG(dbgs() << "IR Promotion: Couldn't find target "
981                             << "register for ZExt type\n");
982           continue;
983         }
984         MadeChange |= TryToPromote(Phi, PromoteWidth, LI);
985       } else if (auto *ICmp = dyn_cast<ICmpInst>(&I)) {
986         // Search up from icmps to try to promote their operands.
987         // Skip signed or pointer compares
988         if (ICmp->isSigned())
989           continue;
990 
991         LLVM_DEBUG(dbgs() << "IR Promotion: Searching from: " << *ICmp << "\n");
992 
993         for (auto &Op : ICmp->operands()) {
994           if (auto *OpI = dyn_cast<Instruction>(Op)) {
995             if (auto PromotedWidth = GetPromoteWidth(OpI)) {
996               MadeChange |= TryToPromote(OpI, PromotedWidth, LI);
997               break;
998             }
999           }
1000         }
1001       }
1002     }
1003     if (!InstsToRemove.empty()) {
1004       for (auto *I : InstsToRemove)
1005         I->eraseFromParent();
1006       InstsToRemove.clear();
1007     }
1008   }
1009 
1010   AllVisited.clear();
1011   SafeToPromote.clear();
1012   SafeWrap.clear();
1013 
1014   return MadeChange;
1015 }
1016 
// Legacy pass manager registration for TypePromotionLegacy. DEBUG_TYPE
// ("type-promotion") and PASS_NAME ("Type Promotion") are the macros defined
// near the top of this file. The listed dependencies are the three analyses
// that TypePromotionLegacy::runOnFunction fetches via getAnalysis<>().
// NOTE(review): the two trailing `false` arguments are presumably the
// "CFG only" and "is analysis" flags of LLVM's INITIALIZE_PASS macros --
// confirm against llvm/PassSupport.h.
1017 INITIALIZE_PASS_BEGIN(TypePromotionLegacy, DEBUG_TYPE, PASS_NAME, false, false)
1018 INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
1019 INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
1020 INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
1021 INITIALIZE_PASS_END(TypePromotionLegacy, DEBUG_TYPE, PASS_NAME, false, false)
1022 
// Out-of-line definition of the legacy pass's static ID member.
// NOTE(review): in LLVM's legacy pass manager the address of ID, rather than
// its value, is presumably what identifies the pass -- confirm.
1023 char TypePromotionLegacy::ID = 0;
1024 
runOnFunction(Function & F)1025 bool TypePromotionLegacy::runOnFunction(Function &F) {
1026   if (skipFunction(F))
1027     return false;
1028 
1029   auto &TPC = getAnalysis<TargetPassConfig>();
1030   auto *TM = &TPC.getTM<TargetMachine>();
1031   auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
1032   auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
1033 
1034   TypePromotionImpl TP;
1035   return TP.run(F, TM, TTI, LI);
1036 }
1037 
createTypePromotionLegacyPass()1038 FunctionPass *llvm::createTypePromotionLegacyPass() {
1039   return new TypePromotionLegacy();
1040 }
1041 
run(Function & F,FunctionAnalysisManager & AM)1042 PreservedAnalyses TypePromotionPass::run(Function &F,
1043                                          FunctionAnalysisManager &AM) {
1044   auto &TTI = AM.getResult<TargetIRAnalysis>(F);
1045   auto &LI = AM.getResult<LoopAnalysis>(F);
1046   TypePromotionImpl TP;
1047 
1048   bool Changed = TP.run(F, TM, TTI, LI);
1049   if (!Changed)
1050     return PreservedAnalyses::all();
1051 
1052   PreservedAnalyses PA;
1053   PA.preserveSet<CFGAnalyses>();
1054   PA.preserve<LoopAnalysis>();
1055   return PA;
1056 }
1057