xref: /freebsd/contrib/llvm-project/llvm/lib/Target/AMDGPU/AMDGPULateCodeGenPrepare.cpp (revision 0fca6ea1d4eea4c934cfff25ac9ee8ad6fe95583)
//===-- AMDGPULateCodeGenPrepare.cpp --------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 /// \file
10 /// This pass does misc. AMDGPU optimizations on IR *just* before instruction
11 /// selection.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "AMDGPU.h"
16 #include "AMDGPUTargetMachine.h"
17 #include "llvm/Analysis/AssumptionCache.h"
18 #include "llvm/Analysis/UniformityAnalysis.h"
19 #include "llvm/Analysis/ValueTracking.h"
20 #include "llvm/CodeGen/TargetPassConfig.h"
21 #include "llvm/IR/IRBuilder.h"
22 #include "llvm/IR/InstVisitor.h"
23 #include "llvm/InitializePasses.h"
24 #include "llvm/Support/CommandLine.h"
25 #include "llvm/Support/KnownBits.h"
26 #include "llvm/Transforms/Utils/Local.h"
27 
28 #define DEBUG_TYPE "amdgpu-late-codegenprepare"
29 
30 using namespace llvm;
31 
32 // Scalar load widening needs running after load-store-vectorizer as that pass
33 // doesn't handle overlapping cases. In addition, this pass enhances the
34 // widening to handle cases where scalar sub-dword loads are naturally aligned
35 // only but not dword aligned.
36 static cl::opt<bool>
37     WidenLoads("amdgpu-late-codegenprepare-widen-constant-loads",
38                cl::desc("Widen sub-dword constant address space loads in "
39                         "AMDGPULateCodeGenPrepare"),
40                cl::ReallyHidden, cl::init(true));
41 
42 namespace {
43 
44 class AMDGPULateCodeGenPrepare
45     : public FunctionPass,
46       public InstVisitor<AMDGPULateCodeGenPrepare, bool> {
47   Module *Mod = nullptr;
48   const DataLayout *DL = nullptr;
49 
50   AssumptionCache *AC = nullptr;
51   UniformityInfo *UA = nullptr;
52 
53   SmallVector<WeakTrackingVH, 8> DeadInsts;
54 
55 public:
56   static char ID;
57 
AMDGPULateCodeGenPrepare()58   AMDGPULateCodeGenPrepare() : FunctionPass(ID) {}
59 
getPassName() const60   StringRef getPassName() const override {
61     return "AMDGPU IR late optimizations";
62   }
63 
getAnalysisUsage(AnalysisUsage & AU) const64   void getAnalysisUsage(AnalysisUsage &AU) const override {
65     AU.addRequired<TargetPassConfig>();
66     AU.addRequired<AssumptionCacheTracker>();
67     AU.addRequired<UniformityInfoWrapperPass>();
68     AU.setPreservesAll();
69   }
70 
71   bool doInitialization(Module &M) override;
72   bool runOnFunction(Function &F) override;
73 
visitInstruction(Instruction &)74   bool visitInstruction(Instruction &) { return false; }
75 
76   // Check if the specified value is at least DWORD aligned.
isDWORDAligned(const Value * V) const77   bool isDWORDAligned(const Value *V) const {
78     KnownBits Known = computeKnownBits(V, *DL, 0, AC);
79     return Known.countMinTrailingZeros() >= 2;
80   }
81 
82   bool canWidenScalarExtLoad(LoadInst &LI) const;
83   bool visitLoadInst(LoadInst &LI);
84 };
85 
86 using ValueToValueMap = DenseMap<const Value *, Value *>;
87 
88 class LiveRegOptimizer {
89 private:
90   Module *Mod = nullptr;
91   const DataLayout *DL = nullptr;
92   const GCNSubtarget *ST;
93   /// The scalar type to convert to
94   Type *ConvertToScalar;
95   /// The set of visited Instructions
96   SmallPtrSet<Instruction *, 4> Visited;
97   /// Map of Value -> Converted Value
98   ValueToValueMap ValMap;
99   /// Map of containing conversions from Optimal Type -> Original Type per BB.
100   DenseMap<BasicBlock *, ValueToValueMap> BBUseValMap;
101 
102 public:
103   /// Calculate the and \p return  the type to convert to given a problematic \p
104   /// OriginalType. In some instances, we may widen the type (e.g. v2i8 -> i32).
105   Type *calculateConvertType(Type *OriginalType);
106   /// Convert the virtual register defined by \p V to the compatible vector of
107   /// legal type
108   Value *convertToOptType(Instruction *V, BasicBlock::iterator &InstPt);
109   /// Convert the virtual register defined by \p V back to the original type \p
110   /// ConvertType, stripping away the MSBs in cases where there was an imperfect
111   /// fit (e.g. v2i32 -> v7i8)
112   Value *convertFromOptType(Type *ConvertType, Instruction *V,
113                             BasicBlock::iterator &InstPt,
114                             BasicBlock *InsertBlock);
115   /// Check for problematic PHI nodes or cross-bb values based on the value
116   /// defined by \p I, and coerce to legal types if necessary. For problematic
117   /// PHI node, we coerce all incoming values in a single invocation.
118   bool optimizeLiveType(Instruction *I,
119                         SmallVectorImpl<WeakTrackingVH> &DeadInsts);
120 
121   // Whether or not the type should be replaced to avoid inefficient
122   // legalization code
shouldReplace(Type * ITy)123   bool shouldReplace(Type *ITy) {
124     FixedVectorType *VTy = dyn_cast<FixedVectorType>(ITy);
125     if (!VTy)
126       return false;
127 
128     auto TLI = ST->getTargetLowering();
129 
130     Type *EltTy = VTy->getElementType();
131     // If the element size is not less than the convert to scalar size, then we
132     // can't do any bit packing
133     if (!EltTy->isIntegerTy() ||
134         EltTy->getScalarSizeInBits() > ConvertToScalar->getScalarSizeInBits())
135       return false;
136 
137     // Only coerce illegal types
138     TargetLoweringBase::LegalizeKind LK =
139         TLI->getTypeConversion(EltTy->getContext(), EVT::getEVT(EltTy, false));
140     return LK.first != TargetLoweringBase::TypeLegal;
141   }
142 
LiveRegOptimizer(Module * Mod,const GCNSubtarget * ST)143   LiveRegOptimizer(Module *Mod, const GCNSubtarget *ST) : Mod(Mod), ST(ST) {
144     DL = &Mod->getDataLayout();
145     ConvertToScalar = Type::getInt32Ty(Mod->getContext());
146   }
147 };
148 
149 } // end anonymous namespace
150 
doInitialization(Module & M)151 bool AMDGPULateCodeGenPrepare::doInitialization(Module &M) {
152   Mod = &M;
153   DL = &Mod->getDataLayout();
154   return false;
155 }
156 
runOnFunction(Function & F)157 bool AMDGPULateCodeGenPrepare::runOnFunction(Function &F) {
158   if (skipFunction(F))
159     return false;
160 
161   const TargetPassConfig &TPC = getAnalysis<TargetPassConfig>();
162   const TargetMachine &TM = TPC.getTM<TargetMachine>();
163   const GCNSubtarget &ST = TM.getSubtarget<GCNSubtarget>(F);
164 
165   AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
166   UA = &getAnalysis<UniformityInfoWrapperPass>().getUniformityInfo();
167 
168   // "Optimize" the virtual regs that cross basic block boundaries. When
169   // building the SelectionDAG, vectors of illegal types that cross basic blocks
170   // will be scalarized and widened, with each scalar living in its
171   // own register. To work around this, this optimization converts the
172   // vectors to equivalent vectors of legal type (which are converted back
173   // before uses in subsequent blocks), to pack the bits into fewer physical
174   // registers (used in CopyToReg/CopyFromReg pairs).
175   LiveRegOptimizer LRO(Mod, &ST);
176 
177   bool Changed = false;
178 
179   bool HasScalarSubwordLoads = ST.hasScalarSubwordLoads();
180 
181   for (auto &BB : reverse(F))
182     for (Instruction &I : make_early_inc_range(reverse(BB))) {
183       Changed |= !HasScalarSubwordLoads && visit(I);
184       Changed |= LRO.optimizeLiveType(&I, DeadInsts);
185     }
186 
187   RecursivelyDeleteTriviallyDeadInstructionsPermissive(DeadInsts);
188   return Changed;
189 }
190 
calculateConvertType(Type * OriginalType)191 Type *LiveRegOptimizer::calculateConvertType(Type *OriginalType) {
192   assert(OriginalType->getScalarSizeInBits() <=
193          ConvertToScalar->getScalarSizeInBits());
194 
195   FixedVectorType *VTy = cast<FixedVectorType>(OriginalType);
196 
197   TypeSize OriginalSize = DL->getTypeSizeInBits(VTy);
198   TypeSize ConvertScalarSize = DL->getTypeSizeInBits(ConvertToScalar);
199   unsigned ConvertEltCount =
200       (OriginalSize + ConvertScalarSize - 1) / ConvertScalarSize;
201 
202   if (OriginalSize <= ConvertScalarSize)
203     return IntegerType::get(Mod->getContext(), ConvertScalarSize);
204 
205   return VectorType::get(Type::getIntNTy(Mod->getContext(), ConvertScalarSize),
206                          ConvertEltCount, false);
207 }
208 
convertToOptType(Instruction * V,BasicBlock::iterator & InsertPt)209 Value *LiveRegOptimizer::convertToOptType(Instruction *V,
210                                           BasicBlock::iterator &InsertPt) {
211   FixedVectorType *VTy = cast<FixedVectorType>(V->getType());
212   Type *NewTy = calculateConvertType(V->getType());
213 
214   TypeSize OriginalSize = DL->getTypeSizeInBits(VTy);
215   TypeSize NewSize = DL->getTypeSizeInBits(NewTy);
216 
217   IRBuilder<> Builder(V->getParent(), InsertPt);
218   // If there is a bitsize match, we can fit the old vector into a new vector of
219   // desired type.
220   if (OriginalSize == NewSize)
221     return Builder.CreateBitCast(V, NewTy, V->getName() + ".bc");
222 
223   // If there is a bitsize mismatch, we must use a wider vector.
224   assert(NewSize > OriginalSize);
225   uint64_t ExpandedVecElementCount = NewSize / VTy->getScalarSizeInBits();
226 
227   SmallVector<int, 8> ShuffleMask;
228   uint64_t OriginalElementCount = VTy->getElementCount().getFixedValue();
229   for (unsigned I = 0; I < OriginalElementCount; I++)
230     ShuffleMask.push_back(I);
231 
232   for (uint64_t I = OriginalElementCount; I < ExpandedVecElementCount; I++)
233     ShuffleMask.push_back(OriginalElementCount);
234 
235   Value *ExpandedVec = Builder.CreateShuffleVector(V, ShuffleMask);
236   return Builder.CreateBitCast(ExpandedVec, NewTy, V->getName() + ".bc");
237 }
238 
convertFromOptType(Type * ConvertType,Instruction * V,BasicBlock::iterator & InsertPt,BasicBlock * InsertBB)239 Value *LiveRegOptimizer::convertFromOptType(Type *ConvertType, Instruction *V,
240                                             BasicBlock::iterator &InsertPt,
241                                             BasicBlock *InsertBB) {
242   FixedVectorType *NewVTy = cast<FixedVectorType>(ConvertType);
243 
244   TypeSize OriginalSize = DL->getTypeSizeInBits(V->getType());
245   TypeSize NewSize = DL->getTypeSizeInBits(NewVTy);
246 
247   IRBuilder<> Builder(InsertBB, InsertPt);
248   // If there is a bitsize match, we simply convert back to the original type.
249   if (OriginalSize == NewSize)
250     return Builder.CreateBitCast(V, NewVTy, V->getName() + ".bc");
251 
252   // If there is a bitsize mismatch, then we must have used a wider value to
253   // hold the bits.
254   assert(OriginalSize > NewSize);
255   // For wide scalars, we can just truncate the value.
256   if (!V->getType()->isVectorTy()) {
257     Instruction *Trunc = cast<Instruction>(
258         Builder.CreateTrunc(V, IntegerType::get(Mod->getContext(), NewSize)));
259     return cast<Instruction>(Builder.CreateBitCast(Trunc, NewVTy));
260   }
261 
262   // For wider vectors, we must strip the MSBs to convert back to the original
263   // type.
264   VectorType *ExpandedVT = VectorType::get(
265       Type::getIntNTy(Mod->getContext(), NewVTy->getScalarSizeInBits()),
266       (OriginalSize / NewVTy->getScalarSizeInBits()), false);
267   Instruction *Converted =
268       cast<Instruction>(Builder.CreateBitCast(V, ExpandedVT));
269 
270   unsigned NarrowElementCount = NewVTy->getElementCount().getFixedValue();
271   SmallVector<int, 8> ShuffleMask(NarrowElementCount);
272   std::iota(ShuffleMask.begin(), ShuffleMask.end(), 0);
273 
274   return Builder.CreateShuffleVector(Converted, ShuffleMask);
275 }
276 
optimizeLiveType(Instruction * I,SmallVectorImpl<WeakTrackingVH> & DeadInsts)277 bool LiveRegOptimizer::optimizeLiveType(
278     Instruction *I, SmallVectorImpl<WeakTrackingVH> &DeadInsts) {
279   SmallVector<Instruction *, 4> Worklist;
280   SmallPtrSet<PHINode *, 4> PhiNodes;
281   SmallPtrSet<Instruction *, 4> Defs;
282   SmallPtrSet<Instruction *, 4> Uses;
283 
284   Worklist.push_back(cast<Instruction>(I));
285   while (!Worklist.empty()) {
286     Instruction *II = Worklist.pop_back_val();
287 
288     if (!Visited.insert(II).second)
289       continue;
290 
291     if (!shouldReplace(II->getType()))
292       continue;
293 
294     if (PHINode *Phi = dyn_cast<PHINode>(II)) {
295       PhiNodes.insert(Phi);
296       // Collect all the incoming values of problematic PHI nodes.
297       for (Value *V : Phi->incoming_values()) {
298         // Repeat the collection process for newly found PHI nodes.
299         if (PHINode *OpPhi = dyn_cast<PHINode>(V)) {
300           if (!PhiNodes.count(OpPhi) && !Visited.count(OpPhi))
301             Worklist.push_back(OpPhi);
302           continue;
303         }
304 
305         Instruction *IncInst = dyn_cast<Instruction>(V);
306         // Other incoming value types (e.g. vector literals) are unhandled
307         if (!IncInst && !isa<ConstantAggregateZero>(V))
308           return false;
309 
310         // Collect all other incoming values for coercion.
311         if (IncInst)
312           Defs.insert(IncInst);
313       }
314     }
315 
316     // Collect all relevant uses.
317     for (User *V : II->users()) {
318       // Repeat the collection process for problematic PHI nodes.
319       if (PHINode *OpPhi = dyn_cast<PHINode>(V)) {
320         if (!PhiNodes.count(OpPhi) && !Visited.count(OpPhi))
321           Worklist.push_back(OpPhi);
322         continue;
323       }
324 
325       Instruction *UseInst = cast<Instruction>(V);
326       // Collect all uses of PHINodes and any use the crosses BB boundaries.
327       if (UseInst->getParent() != II->getParent() || isa<PHINode>(II)) {
328         Uses.insert(UseInst);
329         if (!Defs.count(II) && !isa<PHINode>(II)) {
330           Defs.insert(II);
331         }
332       }
333     }
334   }
335 
336   // Coerce and track the defs.
337   for (Instruction *D : Defs) {
338     if (!ValMap.contains(D)) {
339       BasicBlock::iterator InsertPt = std::next(D->getIterator());
340       Value *ConvertVal = convertToOptType(D, InsertPt);
341       assert(ConvertVal);
342       ValMap[D] = ConvertVal;
343     }
344   }
345 
346   // Construct new-typed PHI nodes.
347   for (PHINode *Phi : PhiNodes) {
348     ValMap[Phi] = PHINode::Create(calculateConvertType(Phi->getType()),
349                                   Phi->getNumIncomingValues(),
350                                   Phi->getName() + ".tc", Phi->getIterator());
351   }
352 
353   // Connect all the PHI nodes with their new incoming values.
354   for (PHINode *Phi : PhiNodes) {
355     PHINode *NewPhi = cast<PHINode>(ValMap[Phi]);
356     bool MissingIncVal = false;
357     for (int I = 0, E = Phi->getNumIncomingValues(); I < E; I++) {
358       Value *IncVal = Phi->getIncomingValue(I);
359       if (isa<ConstantAggregateZero>(IncVal)) {
360         Type *NewType = calculateConvertType(Phi->getType());
361         NewPhi->addIncoming(ConstantInt::get(NewType, 0, false),
362                             Phi->getIncomingBlock(I));
363       } else if (ValMap.contains(IncVal) && ValMap[IncVal])
364         NewPhi->addIncoming(ValMap[IncVal], Phi->getIncomingBlock(I));
365       else
366         MissingIncVal = true;
367     }
368     if (MissingIncVal) {
369       Value *DeadVal = ValMap[Phi];
370       // The coercion chain of the PHI is broken. Delete the Phi
371       // from the ValMap and any connected / user Phis.
372       SmallVector<Value *, 4> PHIWorklist;
373       SmallPtrSet<Value *, 4> VisitedPhis;
374       PHIWorklist.push_back(DeadVal);
375       while (!PHIWorklist.empty()) {
376         Value *NextDeadValue = PHIWorklist.pop_back_val();
377         VisitedPhis.insert(NextDeadValue);
378         auto OriginalPhi =
379             std::find_if(PhiNodes.begin(), PhiNodes.end(),
380                          [this, &NextDeadValue](PHINode *CandPhi) {
381                            return ValMap[CandPhi] == NextDeadValue;
382                          });
383         // This PHI may have already been removed from maps when
384         // unwinding a previous Phi
385         if (OriginalPhi != PhiNodes.end())
386           ValMap.erase(*OriginalPhi);
387 
388         DeadInsts.emplace_back(cast<Instruction>(NextDeadValue));
389 
390         for (User *U : NextDeadValue->users()) {
391           if (!VisitedPhis.contains(cast<PHINode>(U)))
392             PHIWorklist.push_back(U);
393         }
394       }
395     } else {
396       DeadInsts.emplace_back(cast<Instruction>(Phi));
397     }
398   }
399   // Coerce back to the original type and replace the uses.
400   for (Instruction *U : Uses) {
401     // Replace all converted operands for a use.
402     for (auto [OpIdx, Op] : enumerate(U->operands())) {
403       if (ValMap.contains(Op) && ValMap[Op]) {
404         Value *NewVal = nullptr;
405         if (BBUseValMap.contains(U->getParent()) &&
406             BBUseValMap[U->getParent()].contains(ValMap[Op]))
407           NewVal = BBUseValMap[U->getParent()][ValMap[Op]];
408         else {
409           BasicBlock::iterator InsertPt = U->getParent()->getFirstNonPHIIt();
410           // We may pick up ops that were previously converted for users in
411           // other blocks. If there is an originally typed definition of the Op
412           // already in this block, simply reuse it.
413           if (isa<Instruction>(Op) && !isa<PHINode>(Op) &&
414               U->getParent() == cast<Instruction>(Op)->getParent()) {
415             NewVal = Op;
416           } else {
417             NewVal =
418                 convertFromOptType(Op->getType(), cast<Instruction>(ValMap[Op]),
419                                    InsertPt, U->getParent());
420             BBUseValMap[U->getParent()][ValMap[Op]] = NewVal;
421           }
422         }
423         assert(NewVal);
424         U->setOperand(OpIdx, NewVal);
425       }
426     }
427   }
428 
429   return true;
430 }
431 
canWidenScalarExtLoad(LoadInst & LI) const432 bool AMDGPULateCodeGenPrepare::canWidenScalarExtLoad(LoadInst &LI) const {
433   unsigned AS = LI.getPointerAddressSpace();
434   // Skip non-constant address space.
435   if (AS != AMDGPUAS::CONSTANT_ADDRESS &&
436       AS != AMDGPUAS::CONSTANT_ADDRESS_32BIT)
437     return false;
438   // Skip non-simple loads.
439   if (!LI.isSimple())
440     return false;
441   Type *Ty = LI.getType();
442   // Skip aggregate types.
443   if (Ty->isAggregateType())
444     return false;
445   unsigned TySize = DL->getTypeStoreSize(Ty);
446   // Only handle sub-DWORD loads.
447   if (TySize >= 4)
448     return false;
449   // That load must be at least naturally aligned.
450   if (LI.getAlign() < DL->getABITypeAlign(Ty))
451     return false;
452   // It should be uniform, i.e. a scalar load.
453   return UA->isUniform(&LI);
454 }
455 
visitLoadInst(LoadInst & LI)456 bool AMDGPULateCodeGenPrepare::visitLoadInst(LoadInst &LI) {
457   if (!WidenLoads)
458     return false;
459 
460   // Skip if that load is already aligned on DWORD at least as it's handled in
461   // SDAG.
462   if (LI.getAlign() >= 4)
463     return false;
464 
465   if (!canWidenScalarExtLoad(LI))
466     return false;
467 
468   int64_t Offset = 0;
469   auto *Base =
470       GetPointerBaseWithConstantOffset(LI.getPointerOperand(), Offset, *DL);
471   // If that base is not DWORD aligned, it's not safe to perform the following
472   // transforms.
473   if (!isDWORDAligned(Base))
474     return false;
475 
476   int64_t Adjust = Offset & 0x3;
477   if (Adjust == 0) {
478     // With a zero adjust, the original alignment could be promoted with a
479     // better one.
480     LI.setAlignment(Align(4));
481     return true;
482   }
483 
484   IRBuilder<> IRB(&LI);
485   IRB.SetCurrentDebugLocation(LI.getDebugLoc());
486 
487   unsigned LdBits = DL->getTypeStoreSizeInBits(LI.getType());
488   auto IntNTy = Type::getIntNTy(LI.getContext(), LdBits);
489 
490   auto *NewPtr = IRB.CreateConstGEP1_64(
491       IRB.getInt8Ty(),
492       IRB.CreateAddrSpaceCast(Base, LI.getPointerOperand()->getType()),
493       Offset - Adjust);
494 
495   LoadInst *NewLd = IRB.CreateAlignedLoad(IRB.getInt32Ty(), NewPtr, Align(4));
496   NewLd->copyMetadata(LI);
497   NewLd->setMetadata(LLVMContext::MD_range, nullptr);
498 
499   unsigned ShAmt = Adjust * 8;
500   auto *NewVal = IRB.CreateBitCast(
501       IRB.CreateTrunc(IRB.CreateLShr(NewLd, ShAmt), IntNTy), LI.getType());
502   LI.replaceAllUsesWith(NewVal);
503   DeadInsts.emplace_back(&LI);
504 
505   return true;
506 }
507 
508 INITIALIZE_PASS_BEGIN(AMDGPULateCodeGenPrepare, DEBUG_TYPE,
509                       "AMDGPU IR late optimizations", false, false)
510 INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
511 INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
512 INITIALIZE_PASS_DEPENDENCY(UniformityInfoWrapperPass)
513 INITIALIZE_PASS_END(AMDGPULateCodeGenPrepare, DEBUG_TYPE,
514                     "AMDGPU IR late optimizations", false, false)
515 
516 char AMDGPULateCodeGenPrepare::ID = 0;
517 
createAMDGPULateCodeGenPreparePass()518 FunctionPass *llvm::createAMDGPULateCodeGenPreparePass() {
519   return new AMDGPULateCodeGenPrepare();
520 }
521