1 //===-- AMDGPUCodeGenPrepare.cpp ------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 /// \file
10 /// This pass does misc. AMDGPU optimizations on IR *just* before instruction
11 /// selection.
12 //
13 //===----------------------------------------------------------------------===//
14
15 #include "AMDGPU.h"
16 #include "AMDGPUTargetMachine.h"
17 #include "llvm/Analysis/AssumptionCache.h"
18 #include "llvm/Analysis/UniformityAnalysis.h"
19 #include "llvm/Analysis/ValueTracking.h"
20 #include "llvm/CodeGen/TargetPassConfig.h"
21 #include "llvm/IR/IRBuilder.h"
22 #include "llvm/IR/InstVisitor.h"
23 #include "llvm/InitializePasses.h"
24 #include "llvm/Support/CommandLine.h"
25 #include "llvm/Support/KnownBits.h"
26 #include "llvm/Transforms/Utils/Local.h"
27
28 #define DEBUG_TYPE "amdgpu-late-codegenprepare"
29
30 using namespace llvm;
31
32 // Scalar load widening needs running after load-store-vectorizer as that pass
33 // doesn't handle overlapping cases. In addition, this pass enhances the
34 // widening to handle cases where scalar sub-dword loads are naturally aligned
35 // only but not dword aligned.
36 static cl::opt<bool>
37 WidenLoads("amdgpu-late-codegenprepare-widen-constant-loads",
38 cl::desc("Widen sub-dword constant address space loads in "
39 "AMDGPULateCodeGenPrepare"),
40 cl::ReallyHidden, cl::init(true));
41
42 namespace {
43
44 class AMDGPULateCodeGenPrepare
45 : public FunctionPass,
46 public InstVisitor<AMDGPULateCodeGenPrepare, bool> {
47 Module *Mod = nullptr;
48 const DataLayout *DL = nullptr;
49
50 AssumptionCache *AC = nullptr;
51 UniformityInfo *UA = nullptr;
52
53 SmallVector<WeakTrackingVH, 8> DeadInsts;
54
55 public:
56 static char ID;
57
AMDGPULateCodeGenPrepare()58 AMDGPULateCodeGenPrepare() : FunctionPass(ID) {}
59
getPassName() const60 StringRef getPassName() const override {
61 return "AMDGPU IR late optimizations";
62 }
63
getAnalysisUsage(AnalysisUsage & AU) const64 void getAnalysisUsage(AnalysisUsage &AU) const override {
65 AU.addRequired<TargetPassConfig>();
66 AU.addRequired<AssumptionCacheTracker>();
67 AU.addRequired<UniformityInfoWrapperPass>();
68 AU.setPreservesAll();
69 }
70
71 bool doInitialization(Module &M) override;
72 bool runOnFunction(Function &F) override;
73
visitInstruction(Instruction &)74 bool visitInstruction(Instruction &) { return false; }
75
76 // Check if the specified value is at least DWORD aligned.
isDWORDAligned(const Value * V) const77 bool isDWORDAligned(const Value *V) const {
78 KnownBits Known = computeKnownBits(V, *DL, 0, AC);
79 return Known.countMinTrailingZeros() >= 2;
80 }
81
82 bool canWidenScalarExtLoad(LoadInst &LI) const;
83 bool visitLoadInst(LoadInst &LI);
84 };
85
86 using ValueToValueMap = DenseMap<const Value *, Value *>;
87
88 class LiveRegOptimizer {
89 private:
90 Module *Mod = nullptr;
91 const DataLayout *DL = nullptr;
92 const GCNSubtarget *ST;
93 /// The scalar type to convert to
94 Type *ConvertToScalar;
95 /// The set of visited Instructions
96 SmallPtrSet<Instruction *, 4> Visited;
97 /// Map of Value -> Converted Value
98 ValueToValueMap ValMap;
99 /// Map of containing conversions from Optimal Type -> Original Type per BB.
100 DenseMap<BasicBlock *, ValueToValueMap> BBUseValMap;
101
102 public:
103 /// Calculate the and \p return the type to convert to given a problematic \p
104 /// OriginalType. In some instances, we may widen the type (e.g. v2i8 -> i32).
105 Type *calculateConvertType(Type *OriginalType);
106 /// Convert the virtual register defined by \p V to the compatible vector of
107 /// legal type
108 Value *convertToOptType(Instruction *V, BasicBlock::iterator &InstPt);
109 /// Convert the virtual register defined by \p V back to the original type \p
110 /// ConvertType, stripping away the MSBs in cases where there was an imperfect
111 /// fit (e.g. v2i32 -> v7i8)
112 Value *convertFromOptType(Type *ConvertType, Instruction *V,
113 BasicBlock::iterator &InstPt,
114 BasicBlock *InsertBlock);
115 /// Check for problematic PHI nodes or cross-bb values based on the value
116 /// defined by \p I, and coerce to legal types if necessary. For problematic
117 /// PHI node, we coerce all incoming values in a single invocation.
118 bool optimizeLiveType(Instruction *I,
119 SmallVectorImpl<WeakTrackingVH> &DeadInsts);
120
121 // Whether or not the type should be replaced to avoid inefficient
122 // legalization code
shouldReplace(Type * ITy)123 bool shouldReplace(Type *ITy) {
124 FixedVectorType *VTy = dyn_cast<FixedVectorType>(ITy);
125 if (!VTy)
126 return false;
127
128 auto TLI = ST->getTargetLowering();
129
130 Type *EltTy = VTy->getElementType();
131 // If the element size is not less than the convert to scalar size, then we
132 // can't do any bit packing
133 if (!EltTy->isIntegerTy() ||
134 EltTy->getScalarSizeInBits() > ConvertToScalar->getScalarSizeInBits())
135 return false;
136
137 // Only coerce illegal types
138 TargetLoweringBase::LegalizeKind LK =
139 TLI->getTypeConversion(EltTy->getContext(), EVT::getEVT(EltTy, false));
140 return LK.first != TargetLoweringBase::TypeLegal;
141 }
142
LiveRegOptimizer(Module * Mod,const GCNSubtarget * ST)143 LiveRegOptimizer(Module *Mod, const GCNSubtarget *ST) : Mod(Mod), ST(ST) {
144 DL = &Mod->getDataLayout();
145 ConvertToScalar = Type::getInt32Ty(Mod->getContext());
146 }
147 };
148
149 } // end anonymous namespace
150
doInitialization(Module & M)151 bool AMDGPULateCodeGenPrepare::doInitialization(Module &M) {
152 Mod = &M;
153 DL = &Mod->getDataLayout();
154 return false;
155 }
156
runOnFunction(Function & F)157 bool AMDGPULateCodeGenPrepare::runOnFunction(Function &F) {
158 if (skipFunction(F))
159 return false;
160
161 const TargetPassConfig &TPC = getAnalysis<TargetPassConfig>();
162 const TargetMachine &TM = TPC.getTM<TargetMachine>();
163 const GCNSubtarget &ST = TM.getSubtarget<GCNSubtarget>(F);
164
165 AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
166 UA = &getAnalysis<UniformityInfoWrapperPass>().getUniformityInfo();
167
168 // "Optimize" the virtual regs that cross basic block boundaries. When
169 // building the SelectionDAG, vectors of illegal types that cross basic blocks
170 // will be scalarized and widened, with each scalar living in its
171 // own register. To work around this, this optimization converts the
172 // vectors to equivalent vectors of legal type (which are converted back
173 // before uses in subsequent blocks), to pack the bits into fewer physical
174 // registers (used in CopyToReg/CopyFromReg pairs).
175 LiveRegOptimizer LRO(Mod, &ST);
176
177 bool Changed = false;
178
179 bool HasScalarSubwordLoads = ST.hasScalarSubwordLoads();
180
181 for (auto &BB : reverse(F))
182 for (Instruction &I : make_early_inc_range(reverse(BB))) {
183 Changed |= !HasScalarSubwordLoads && visit(I);
184 Changed |= LRO.optimizeLiveType(&I, DeadInsts);
185 }
186
187 RecursivelyDeleteTriviallyDeadInstructionsPermissive(DeadInsts);
188 return Changed;
189 }
190
calculateConvertType(Type * OriginalType)191 Type *LiveRegOptimizer::calculateConvertType(Type *OriginalType) {
192 assert(OriginalType->getScalarSizeInBits() <=
193 ConvertToScalar->getScalarSizeInBits());
194
195 FixedVectorType *VTy = cast<FixedVectorType>(OriginalType);
196
197 TypeSize OriginalSize = DL->getTypeSizeInBits(VTy);
198 TypeSize ConvertScalarSize = DL->getTypeSizeInBits(ConvertToScalar);
199 unsigned ConvertEltCount =
200 (OriginalSize + ConvertScalarSize - 1) / ConvertScalarSize;
201
202 if (OriginalSize <= ConvertScalarSize)
203 return IntegerType::get(Mod->getContext(), ConvertScalarSize);
204
205 return VectorType::get(Type::getIntNTy(Mod->getContext(), ConvertScalarSize),
206 ConvertEltCount, false);
207 }
208
convertToOptType(Instruction * V,BasicBlock::iterator & InsertPt)209 Value *LiveRegOptimizer::convertToOptType(Instruction *V,
210 BasicBlock::iterator &InsertPt) {
211 FixedVectorType *VTy = cast<FixedVectorType>(V->getType());
212 Type *NewTy = calculateConvertType(V->getType());
213
214 TypeSize OriginalSize = DL->getTypeSizeInBits(VTy);
215 TypeSize NewSize = DL->getTypeSizeInBits(NewTy);
216
217 IRBuilder<> Builder(V->getParent(), InsertPt);
218 // If there is a bitsize match, we can fit the old vector into a new vector of
219 // desired type.
220 if (OriginalSize == NewSize)
221 return Builder.CreateBitCast(V, NewTy, V->getName() + ".bc");
222
223 // If there is a bitsize mismatch, we must use a wider vector.
224 assert(NewSize > OriginalSize);
225 uint64_t ExpandedVecElementCount = NewSize / VTy->getScalarSizeInBits();
226
227 SmallVector<int, 8> ShuffleMask;
228 uint64_t OriginalElementCount = VTy->getElementCount().getFixedValue();
229 for (unsigned I = 0; I < OriginalElementCount; I++)
230 ShuffleMask.push_back(I);
231
232 for (uint64_t I = OriginalElementCount; I < ExpandedVecElementCount; I++)
233 ShuffleMask.push_back(OriginalElementCount);
234
235 Value *ExpandedVec = Builder.CreateShuffleVector(V, ShuffleMask);
236 return Builder.CreateBitCast(ExpandedVec, NewTy, V->getName() + ".bc");
237 }
238
convertFromOptType(Type * ConvertType,Instruction * V,BasicBlock::iterator & InsertPt,BasicBlock * InsertBB)239 Value *LiveRegOptimizer::convertFromOptType(Type *ConvertType, Instruction *V,
240 BasicBlock::iterator &InsertPt,
241 BasicBlock *InsertBB) {
242 FixedVectorType *NewVTy = cast<FixedVectorType>(ConvertType);
243
244 TypeSize OriginalSize = DL->getTypeSizeInBits(V->getType());
245 TypeSize NewSize = DL->getTypeSizeInBits(NewVTy);
246
247 IRBuilder<> Builder(InsertBB, InsertPt);
248 // If there is a bitsize match, we simply convert back to the original type.
249 if (OriginalSize == NewSize)
250 return Builder.CreateBitCast(V, NewVTy, V->getName() + ".bc");
251
252 // If there is a bitsize mismatch, then we must have used a wider value to
253 // hold the bits.
254 assert(OriginalSize > NewSize);
255 // For wide scalars, we can just truncate the value.
256 if (!V->getType()->isVectorTy()) {
257 Instruction *Trunc = cast<Instruction>(
258 Builder.CreateTrunc(V, IntegerType::get(Mod->getContext(), NewSize)));
259 return cast<Instruction>(Builder.CreateBitCast(Trunc, NewVTy));
260 }
261
262 // For wider vectors, we must strip the MSBs to convert back to the original
263 // type.
264 VectorType *ExpandedVT = VectorType::get(
265 Type::getIntNTy(Mod->getContext(), NewVTy->getScalarSizeInBits()),
266 (OriginalSize / NewVTy->getScalarSizeInBits()), false);
267 Instruction *Converted =
268 cast<Instruction>(Builder.CreateBitCast(V, ExpandedVT));
269
270 unsigned NarrowElementCount = NewVTy->getElementCount().getFixedValue();
271 SmallVector<int, 8> ShuffleMask(NarrowElementCount);
272 std::iota(ShuffleMask.begin(), ShuffleMask.end(), 0);
273
274 return Builder.CreateShuffleVector(Converted, ShuffleMask);
275 }
276
optimizeLiveType(Instruction * I,SmallVectorImpl<WeakTrackingVH> & DeadInsts)277 bool LiveRegOptimizer::optimizeLiveType(
278 Instruction *I, SmallVectorImpl<WeakTrackingVH> &DeadInsts) {
279 SmallVector<Instruction *, 4> Worklist;
280 SmallPtrSet<PHINode *, 4> PhiNodes;
281 SmallPtrSet<Instruction *, 4> Defs;
282 SmallPtrSet<Instruction *, 4> Uses;
283
284 Worklist.push_back(cast<Instruction>(I));
285 while (!Worklist.empty()) {
286 Instruction *II = Worklist.pop_back_val();
287
288 if (!Visited.insert(II).second)
289 continue;
290
291 if (!shouldReplace(II->getType()))
292 continue;
293
294 if (PHINode *Phi = dyn_cast<PHINode>(II)) {
295 PhiNodes.insert(Phi);
296 // Collect all the incoming values of problematic PHI nodes.
297 for (Value *V : Phi->incoming_values()) {
298 // Repeat the collection process for newly found PHI nodes.
299 if (PHINode *OpPhi = dyn_cast<PHINode>(V)) {
300 if (!PhiNodes.count(OpPhi) && !Visited.count(OpPhi))
301 Worklist.push_back(OpPhi);
302 continue;
303 }
304
305 Instruction *IncInst = dyn_cast<Instruction>(V);
306 // Other incoming value types (e.g. vector literals) are unhandled
307 if (!IncInst && !isa<ConstantAggregateZero>(V))
308 return false;
309
310 // Collect all other incoming values for coercion.
311 if (IncInst)
312 Defs.insert(IncInst);
313 }
314 }
315
316 // Collect all relevant uses.
317 for (User *V : II->users()) {
318 // Repeat the collection process for problematic PHI nodes.
319 if (PHINode *OpPhi = dyn_cast<PHINode>(V)) {
320 if (!PhiNodes.count(OpPhi) && !Visited.count(OpPhi))
321 Worklist.push_back(OpPhi);
322 continue;
323 }
324
325 Instruction *UseInst = cast<Instruction>(V);
326 // Collect all uses of PHINodes and any use the crosses BB boundaries.
327 if (UseInst->getParent() != II->getParent() || isa<PHINode>(II)) {
328 Uses.insert(UseInst);
329 if (!Defs.count(II) && !isa<PHINode>(II)) {
330 Defs.insert(II);
331 }
332 }
333 }
334 }
335
336 // Coerce and track the defs.
337 for (Instruction *D : Defs) {
338 if (!ValMap.contains(D)) {
339 BasicBlock::iterator InsertPt = std::next(D->getIterator());
340 Value *ConvertVal = convertToOptType(D, InsertPt);
341 assert(ConvertVal);
342 ValMap[D] = ConvertVal;
343 }
344 }
345
346 // Construct new-typed PHI nodes.
347 for (PHINode *Phi : PhiNodes) {
348 ValMap[Phi] = PHINode::Create(calculateConvertType(Phi->getType()),
349 Phi->getNumIncomingValues(),
350 Phi->getName() + ".tc", Phi->getIterator());
351 }
352
353 // Connect all the PHI nodes with their new incoming values.
354 for (PHINode *Phi : PhiNodes) {
355 PHINode *NewPhi = cast<PHINode>(ValMap[Phi]);
356 bool MissingIncVal = false;
357 for (int I = 0, E = Phi->getNumIncomingValues(); I < E; I++) {
358 Value *IncVal = Phi->getIncomingValue(I);
359 if (isa<ConstantAggregateZero>(IncVal)) {
360 Type *NewType = calculateConvertType(Phi->getType());
361 NewPhi->addIncoming(ConstantInt::get(NewType, 0, false),
362 Phi->getIncomingBlock(I));
363 } else if (ValMap.contains(IncVal) && ValMap[IncVal])
364 NewPhi->addIncoming(ValMap[IncVal], Phi->getIncomingBlock(I));
365 else
366 MissingIncVal = true;
367 }
368 if (MissingIncVal) {
369 Value *DeadVal = ValMap[Phi];
370 // The coercion chain of the PHI is broken. Delete the Phi
371 // from the ValMap and any connected / user Phis.
372 SmallVector<Value *, 4> PHIWorklist;
373 SmallPtrSet<Value *, 4> VisitedPhis;
374 PHIWorklist.push_back(DeadVal);
375 while (!PHIWorklist.empty()) {
376 Value *NextDeadValue = PHIWorklist.pop_back_val();
377 VisitedPhis.insert(NextDeadValue);
378 auto OriginalPhi =
379 std::find_if(PhiNodes.begin(), PhiNodes.end(),
380 [this, &NextDeadValue](PHINode *CandPhi) {
381 return ValMap[CandPhi] == NextDeadValue;
382 });
383 // This PHI may have already been removed from maps when
384 // unwinding a previous Phi
385 if (OriginalPhi != PhiNodes.end())
386 ValMap.erase(*OriginalPhi);
387
388 DeadInsts.emplace_back(cast<Instruction>(NextDeadValue));
389
390 for (User *U : NextDeadValue->users()) {
391 if (!VisitedPhis.contains(cast<PHINode>(U)))
392 PHIWorklist.push_back(U);
393 }
394 }
395 } else {
396 DeadInsts.emplace_back(cast<Instruction>(Phi));
397 }
398 }
399 // Coerce back to the original type and replace the uses.
400 for (Instruction *U : Uses) {
401 // Replace all converted operands for a use.
402 for (auto [OpIdx, Op] : enumerate(U->operands())) {
403 if (ValMap.contains(Op) && ValMap[Op]) {
404 Value *NewVal = nullptr;
405 if (BBUseValMap.contains(U->getParent()) &&
406 BBUseValMap[U->getParent()].contains(ValMap[Op]))
407 NewVal = BBUseValMap[U->getParent()][ValMap[Op]];
408 else {
409 BasicBlock::iterator InsertPt = U->getParent()->getFirstNonPHIIt();
410 // We may pick up ops that were previously converted for users in
411 // other blocks. If there is an originally typed definition of the Op
412 // already in this block, simply reuse it.
413 if (isa<Instruction>(Op) && !isa<PHINode>(Op) &&
414 U->getParent() == cast<Instruction>(Op)->getParent()) {
415 NewVal = Op;
416 } else {
417 NewVal =
418 convertFromOptType(Op->getType(), cast<Instruction>(ValMap[Op]),
419 InsertPt, U->getParent());
420 BBUseValMap[U->getParent()][ValMap[Op]] = NewVal;
421 }
422 }
423 assert(NewVal);
424 U->setOperand(OpIdx, NewVal);
425 }
426 }
427 }
428
429 return true;
430 }
431
canWidenScalarExtLoad(LoadInst & LI) const432 bool AMDGPULateCodeGenPrepare::canWidenScalarExtLoad(LoadInst &LI) const {
433 unsigned AS = LI.getPointerAddressSpace();
434 // Skip non-constant address space.
435 if (AS != AMDGPUAS::CONSTANT_ADDRESS &&
436 AS != AMDGPUAS::CONSTANT_ADDRESS_32BIT)
437 return false;
438 // Skip non-simple loads.
439 if (!LI.isSimple())
440 return false;
441 Type *Ty = LI.getType();
442 // Skip aggregate types.
443 if (Ty->isAggregateType())
444 return false;
445 unsigned TySize = DL->getTypeStoreSize(Ty);
446 // Only handle sub-DWORD loads.
447 if (TySize >= 4)
448 return false;
449 // That load must be at least naturally aligned.
450 if (LI.getAlign() < DL->getABITypeAlign(Ty))
451 return false;
452 // It should be uniform, i.e. a scalar load.
453 return UA->isUniform(&LI);
454 }
455
visitLoadInst(LoadInst & LI)456 bool AMDGPULateCodeGenPrepare::visitLoadInst(LoadInst &LI) {
457 if (!WidenLoads)
458 return false;
459
460 // Skip if that load is already aligned on DWORD at least as it's handled in
461 // SDAG.
462 if (LI.getAlign() >= 4)
463 return false;
464
465 if (!canWidenScalarExtLoad(LI))
466 return false;
467
468 int64_t Offset = 0;
469 auto *Base =
470 GetPointerBaseWithConstantOffset(LI.getPointerOperand(), Offset, *DL);
471 // If that base is not DWORD aligned, it's not safe to perform the following
472 // transforms.
473 if (!isDWORDAligned(Base))
474 return false;
475
476 int64_t Adjust = Offset & 0x3;
477 if (Adjust == 0) {
478 // With a zero adjust, the original alignment could be promoted with a
479 // better one.
480 LI.setAlignment(Align(4));
481 return true;
482 }
483
484 IRBuilder<> IRB(&LI);
485 IRB.SetCurrentDebugLocation(LI.getDebugLoc());
486
487 unsigned LdBits = DL->getTypeStoreSizeInBits(LI.getType());
488 auto IntNTy = Type::getIntNTy(LI.getContext(), LdBits);
489
490 auto *NewPtr = IRB.CreateConstGEP1_64(
491 IRB.getInt8Ty(),
492 IRB.CreateAddrSpaceCast(Base, LI.getPointerOperand()->getType()),
493 Offset - Adjust);
494
495 LoadInst *NewLd = IRB.CreateAlignedLoad(IRB.getInt32Ty(), NewPtr, Align(4));
496 NewLd->copyMetadata(LI);
497 NewLd->setMetadata(LLVMContext::MD_range, nullptr);
498
499 unsigned ShAmt = Adjust * 8;
500 auto *NewVal = IRB.CreateBitCast(
501 IRB.CreateTrunc(IRB.CreateLShr(NewLd, ShAmt), IntNTy), LI.getType());
502 LI.replaceAllUsesWith(NewVal);
503 DeadInsts.emplace_back(&LI);
504
505 return true;
506 }
507
508 INITIALIZE_PASS_BEGIN(AMDGPULateCodeGenPrepare, DEBUG_TYPE,
509 "AMDGPU IR late optimizations", false, false)
510 INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
511 INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
512 INITIALIZE_PASS_DEPENDENCY(UniformityInfoWrapperPass)
513 INITIALIZE_PASS_END(AMDGPULateCodeGenPrepare, DEBUG_TYPE,
514 "AMDGPU IR late optimizations", false, false)
515
516 char AMDGPULateCodeGenPrepare::ID = 0;
517
createAMDGPULateCodeGenPreparePass()518 FunctionPass *llvm::createAMDGPULateCodeGenPreparePass() {
519 return new AMDGPULateCodeGenPrepare();
520 }
521