//===-- AMDGPUCodeGenPrepare.cpp ------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file
/// This pass does misc. AMDGPU optimizations on IR *just* before instruction
/// selection.
//
//===----------------------------------------------------------------------===//

#include "AMDGPU.h"
#include "AMDGPUTargetMachine.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/UniformityAnalysis.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstVisitor.h"
#include "llvm/InitializePasses.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/KnownBits.h"
#include "llvm/Transforms/Utils/Local.h"

#define DEBUG_TYPE "amdgpu-late-codegenprepare"

using namespace llvm;

// Scalar load widening needs to run after the load-store-vectorizer, as that
// pass doesn't handle overlapping cases. In addition, this pass enhances the
// widening to handle cases where scalar sub-dword loads are naturally aligned
// but not DWORD aligned.
static cl::opt<bool>
    WidenLoads("amdgpu-late-codegenprepare-widen-constant-loads",
               cl::desc("Widen sub-dword constant address space loads in "
                        "AMDGPULateCodeGenPrepare"),
               cl::ReallyHidden, cl::init(true));

namespace {

class AMDGPULateCodeGenPrepare
    : public FunctionPass,
      public InstVisitor<AMDGPULateCodeGenPrepare, bool> {
  Module *Mod = nullptr;
  const DataLayout *DL = nullptr;

  AssumptionCache *AC = nullptr;
  UniformityInfo *UA = nullptr;

  SmallVector<WeakTrackingVH, 8> DeadInsts;

public:
  static char ID;

  AMDGPULateCodeGenPrepare() : FunctionPass(ID) {}

  StringRef getPassName() const override {
    return "AMDGPU IR late optimizations";
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetPassConfig>();
    AU.addRequired<AssumptionCacheTracker>();
    AU.addRequired<UniformityInfoWrapperPass>();
    AU.setPreservesAll();
  }

  bool doInitialization(Module &M) override;
  bool runOnFunction(Function &F) override;

  bool visitInstruction(Instruction &) { return false; }

  // Check if the specified value is at least DWORD aligned.
  bool isDWORDAligned(const Value *V) const {
    KnownBits Known = computeKnownBits(V, *DL, 0, AC);
    return Known.countMinTrailingZeros() >= 2;
  }

  bool canWidenScalarExtLoad(LoadInst &LI) const;
  bool visitLoadInst(LoadInst &LI);
};

using ValueToValueMap = DenseMap<const Value *, Value *>;

class LiveRegOptimizer {
private:
  Module *Mod = nullptr;
  const DataLayout *DL = nullptr;
  const GCNSubtarget *ST;
  /// The scalar type to convert to
  Type *ConvertToScalar;
  /// The set of visited Instructions
  SmallPtrSet<Instruction *, 4> Visited;
  /// Map of Value -> Converted Value
  ValueToValueMap ValMap;
  /// Map containing conversions from Optimal Type -> Original Type per BB.
  DenseMap<BasicBlock *, ValueToValueMap> BBUseValMap;

public:
  /// Calculate and return the type to convert to, given a problematic \p
  /// OriginalType. In some instances, we may widen the type (e.g. v2i8 -> i32).
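  /// For example, since the converted type is built from whole i32 words:
  ///   v3i8 (24 bits)            -> i32
  ///   v8i8 (64 bits)            -> v2i32
  ///   v7i8 (56 bits, rounded up) -> v2i32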
  Type *calculateConvertType(Type *OriginalType);
  /// Convert the virtual register defined by \p V to the compatible vector of
  /// legal type
  Value *convertToOptType(Instruction *V, BasicBlock::iterator &InstPt);
  /// Convert the virtual register defined by \p V back to the original type \p
  /// ConvertType, stripping away the MSBs in cases where there was an
  /// imperfect fit (e.g. v2i32 -> v7i8)
  Value *convertFromOptType(Type *ConvertType, Instruction *V,
                            BasicBlock::iterator &InstPt,
                            BasicBlock *InsertBlock);
  /// Check for problematic PHI nodes or cross-bb values based on the value
  /// defined by \p I, and coerce to legal types if necessary. For a
  /// problematic PHI node, we coerce all incoming values in a single
  /// invocation.
  bool optimizeLiveType(Instruction *I,
                        SmallVectorImpl<WeakTrackingVH> &DeadInsts);

  // Whether or not the type should be replaced to avoid inefficient
  // legalization code
  bool shouldReplace(Type *ITy) {
    FixedVectorType *VTy = dyn_cast<FixedVectorType>(ITy);
    if (!VTy)
      return false;

    auto TLI = ST->getTargetLowering();

    Type *EltTy = VTy->getElementType();
    // If the element type is wider than the scalar type we convert to, we
    // can't do any bit packing.
    if (!EltTy->isIntegerTy() ||
        EltTy->getScalarSizeInBits() > ConvertToScalar->getScalarSizeInBits())
      return false;

    // Only coerce illegal types
    TargetLoweringBase::LegalizeKind LK =
        TLI->getTypeConversion(EltTy->getContext(), EVT::getEVT(EltTy, false));
    return LK.first != TargetLoweringBase::TypeLegal;
  }

  LiveRegOptimizer(Module *Mod, const GCNSubtarget *ST) : Mod(Mod), ST(ST) {
    DL = &Mod->getDataLayout();
    ConvertToScalar = Type::getInt32Ty(Mod->getContext());
  }
};

} // end anonymous namespace

bool AMDGPULateCodeGenPrepare::doInitialization(Module &M) {
  Mod = &M;
  DL = &Mod->getDataLayout();
  return false;
}

bool AMDGPULateCodeGenPrepare::runOnFunction(Function &F) {
  if (skipFunction(F))
    return false;

  const TargetPassConfig &TPC = getAnalysis<TargetPassConfig>();
  const TargetMachine &TM = TPC.getTM<TargetMachine>();
  const GCNSubtarget &ST = TM.getSubtarget<GCNSubtarget>(F);

  AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
  UA = &getAnalysis<UniformityInfoWrapperPass>().getUniformityInfo();

  // "Optimize" the virtual regs that cross basic block boundaries. When
  // building the SelectionDAG, vectors of illegal types that cross basic
  // blocks will be scalarized and widened, with each scalar living in its
  // own register. To work around this, this optimization converts the
  // vectors to equivalent vectors of legal type (which are converted back
  // before uses in subsequent blocks), to pack the bits into fewer physical
  // registers (used in CopyToReg/CopyFromReg pairs).
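  //
  // For example, a uniform v8i8 value defined in one block and used in
  // another would otherwise be split into eight separate registers;
  // LiveRegOptimizer instead bitcasts it to v2i32 right after its definition
  // and converts it back to v8i8 just before its uses in later blocks, so
  // only two 32-bit registers cross the block boundary.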
  LiveRegOptimizer LRO(Mod, &ST);

  bool Changed = false;

  bool HasScalarSubwordLoads = ST.hasScalarSubwordLoads();

  for (auto &BB : reverse(F))
    for (Instruction &I : make_early_inc_range(reverse(BB))) {
      Changed |= !HasScalarSubwordLoads && visit(I);
      Changed |= LRO.optimizeLiveType(&I, DeadInsts);
    }

  RecursivelyDeleteTriviallyDeadInstructionsPermissive(DeadInsts);
  return Changed;
}

Type *LiveRegOptimizer::calculateConvertType(Type *OriginalType) {
  assert(OriginalType->getScalarSizeInBits() <=
         ConvertToScalar->getScalarSizeInBits());

  FixedVectorType *VTy = cast<FixedVectorType>(OriginalType);

  TypeSize OriginalSize = DL->getTypeSizeInBits(VTy);
  TypeSize ConvertScalarSize = DL->getTypeSizeInBits(ConvertToScalar);
  unsigned ConvertEltCount =
      (OriginalSize + ConvertScalarSize - 1) / ConvertScalarSize;

  if (OriginalSize <= ConvertScalarSize)
    return IntegerType::get(Mod->getContext(), ConvertScalarSize);

  return VectorType::get(Type::getIntNTy(Mod->getContext(), ConvertScalarSize),
                         ConvertEltCount, false);
}

Value *LiveRegOptimizer::convertToOptType(Instruction *V,
                                          BasicBlock::iterator &InsertPt) {
  FixedVectorType *VTy = cast<FixedVectorType>(V->getType());
  Type *NewTy = calculateConvertType(V->getType());

  TypeSize OriginalSize = DL->getTypeSizeInBits(VTy);
  TypeSize NewSize = DL->getTypeSizeInBits(NewTy);

  IRBuilder<> Builder(V->getParent(), InsertPt);
  // If there is a bitsize match, we can fit the old vector into a new vector
  // of desired type.
  if (OriginalSize == NewSize)
    return Builder.CreateBitCast(V, NewTy, V->getName() + ".bc");

  // If there is a bitsize mismatch, we must use a wider vector.
  assert(NewSize > OriginalSize);
  uint64_t ExpandedVecElementCount = NewSize / VTy->getScalarSizeInBits();

  SmallVector<int, 8> ShuffleMask;
  uint64_t OriginalElementCount = VTy->getElementCount().getFixedValue();
  for (unsigned I = 0; I < OriginalElementCount; I++)
    ShuffleMask.push_back(I);

  for (uint64_t I = OriginalElementCount; I < ExpandedVecElementCount; I++)
    ShuffleMask.push_back(OriginalElementCount);

  Value *ExpandedVec = Builder.CreateShuffleVector(V, ShuffleMask);
  return Builder.CreateBitCast(ExpandedVec, NewTy, V->getName() + ".bc");
}

Value *LiveRegOptimizer::convertFromOptType(Type *ConvertType, Instruction *V,
                                            BasicBlock::iterator &InsertPt,
                                            BasicBlock *InsertBB) {
  FixedVectorType *NewVTy = cast<FixedVectorType>(ConvertType);

  TypeSize OriginalSize = DL->getTypeSizeInBits(V->getType());
  TypeSize NewSize = DL->getTypeSizeInBits(NewVTy);

  IRBuilder<> Builder(InsertBB, InsertPt);
  // If there is a bitsize match, we simply convert back to the original type.
  if (OriginalSize == NewSize)
    return Builder.CreateBitCast(V, NewVTy, V->getName() + ".bc");

  // If there is a bitsize mismatch, then we must have used a wider value to
  // hold the bits.
  assert(OriginalSize > NewSize);
  // For wide scalars, we can just truncate the value.
  if (!V->getType()->isVectorTy()) {
    Instruction *Trunc = cast<Instruction>(
        Builder.CreateTrunc(V, IntegerType::get(Mod->getContext(), NewSize)));
    return cast<Instruction>(Builder.CreateBitCast(Trunc, NewVTy));
  }

  // For wider vectors, we must strip the MSBs to convert back to the original
  // type.
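  // For example, a value that was widened to v2i32 but originally had type
  // v7i8 is bitcast to v8i8 below and then shuffled down to its low 7
  // elements, dropping the padding element.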
  VectorType *ExpandedVT = VectorType::get(
      Type::getIntNTy(Mod->getContext(), NewVTy->getScalarSizeInBits()),
      (OriginalSize / NewVTy->getScalarSizeInBits()), false);
  Instruction *Converted =
      cast<Instruction>(Builder.CreateBitCast(V, ExpandedVT));

  unsigned NarrowElementCount = NewVTy->getElementCount().getFixedValue();
  SmallVector<int, 8> ShuffleMask(NarrowElementCount);
  std::iota(ShuffleMask.begin(), ShuffleMask.end(), 0);

  return Builder.CreateShuffleVector(Converted, ShuffleMask);
}

bool LiveRegOptimizer::optimizeLiveType(
    Instruction *I, SmallVectorImpl<WeakTrackingVH> &DeadInsts) {
  SmallVector<Instruction *, 4> Worklist;
  SmallPtrSet<PHINode *, 4> PhiNodes;
  SmallPtrSet<Instruction *, 4> Defs;
  SmallPtrSet<Instruction *, 4> Uses;

  Worklist.push_back(cast<Instruction>(I));
  while (!Worklist.empty()) {
    Instruction *II = Worklist.pop_back_val();

    if (!Visited.insert(II).second)
      continue;

    if (!shouldReplace(II->getType()))
      continue;

    if (PHINode *Phi = dyn_cast<PHINode>(II)) {
      PhiNodes.insert(Phi);
      // Collect all the incoming values of problematic PHI nodes.
      for (Value *V : Phi->incoming_values()) {
        // Repeat the collection process for newly found PHI nodes.
        if (PHINode *OpPhi = dyn_cast<PHINode>(V)) {
          if (!PhiNodes.count(OpPhi) && !Visited.count(OpPhi))
            Worklist.push_back(OpPhi);
          continue;
        }

        Instruction *IncInst = dyn_cast<Instruction>(V);
        // Other incoming value types (e.g. vector literals) are unhandled.
        if (!IncInst && !isa<ConstantAggregateZero>(V))
          return false;

        // Collect all other incoming values for coercion.
        if (IncInst)
          Defs.insert(IncInst);
      }
    }

    // Collect all relevant uses.
    for (User *V : II->users()) {
      // Repeat the collection process for problematic PHI nodes.
      if (PHINode *OpPhi = dyn_cast<PHINode>(V)) {
        if (!PhiNodes.count(OpPhi) && !Visited.count(OpPhi))
          Worklist.push_back(OpPhi);
        continue;
      }

      Instruction *UseInst = cast<Instruction>(V);
      // Collect all uses of PHINodes and any use that crosses BB boundaries.
      if (UseInst->getParent() != II->getParent() || isa<PHINode>(II)) {
        Uses.insert(UseInst);
        if (!Defs.count(II) && !isa<PHINode>(II)) {
          Defs.insert(II);
        }
      }
    }
  }

  // Coerce and track the defs.
  for (Instruction *D : Defs) {
    if (!ValMap.contains(D)) {
      BasicBlock::iterator InsertPt = std::next(D->getIterator());
      Value *ConvertVal = convertToOptType(D, InsertPt);
      assert(ConvertVal);
      ValMap[D] = ConvertVal;
    }
  }

  // Construct new-typed PHI nodes.
  for (PHINode *Phi : PhiNodes) {
    ValMap[Phi] = PHINode::Create(calculateConvertType(Phi->getType()),
                                  Phi->getNumIncomingValues(),
                                  Phi->getName() + ".tc", Phi->getIterator());
  }

  // Connect all the PHI nodes with their new incoming values.
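  // Incoming values that were themselves coerced are looked up in ValMap, and
  // zeroinitializer incoming values are re-created as a zero of the converted
  // type. If any incoming value has no converted counterpart, the partially
  // built PHI chain is unwound again further below.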
  for (PHINode *Phi : PhiNodes) {
    PHINode *NewPhi = cast<PHINode>(ValMap[Phi]);
    bool MissingIncVal = false;
    for (int I = 0, E = Phi->getNumIncomingValues(); I < E; I++) {
      Value *IncVal = Phi->getIncomingValue(I);
      if (isa<ConstantAggregateZero>(IncVal)) {
        Type *NewType = calculateConvertType(Phi->getType());
        NewPhi->addIncoming(ConstantInt::get(NewType, 0, false),
                            Phi->getIncomingBlock(I));
      } else if (ValMap.contains(IncVal) && ValMap[IncVal])
        NewPhi->addIncoming(ValMap[IncVal], Phi->getIncomingBlock(I));
      else
        MissingIncVal = true;
    }
    if (MissingIncVal) {
      Value *DeadVal = ValMap[Phi];
      // The coercion chain of the PHI is broken. Delete the Phi
      // from the ValMap and any connected / user Phis.
      SmallVector<Value *, 4> PHIWorklist;
      SmallPtrSet<Value *, 4> VisitedPhis;
      PHIWorklist.push_back(DeadVal);
      while (!PHIWorklist.empty()) {
        Value *NextDeadValue = PHIWorklist.pop_back_val();
        VisitedPhis.insert(NextDeadValue);
        auto OriginalPhi =
            std::find_if(PhiNodes.begin(), PhiNodes.end(),
                         [this, &NextDeadValue](PHINode *CandPhi) {
                           return ValMap[CandPhi] == NextDeadValue;
                         });
        // This PHI may have already been removed from maps when
        // unwinding a previous Phi.
        if (OriginalPhi != PhiNodes.end())
          ValMap.erase(*OriginalPhi);

        DeadInsts.emplace_back(cast<Instruction>(NextDeadValue));

        for (User *U : NextDeadValue->users()) {
          if (!VisitedPhis.contains(cast<PHINode>(U)))
            PHIWorklist.push_back(U);
        }
      }
    } else {
      DeadInsts.emplace_back(cast<Instruction>(Phi));
    }
  }

  // Coerce back to the original type and replace the uses.
  for (Instruction *U : Uses) {
    // Replace all converted operands for a use.
    for (auto [OpIdx, Op] : enumerate(U->operands())) {
      if (ValMap.contains(Op) && ValMap[Op]) {
        Value *NewVal = nullptr;
        if (BBUseValMap.contains(U->getParent()) &&
            BBUseValMap[U->getParent()].contains(ValMap[Op]))
          NewVal = BBUseValMap[U->getParent()][ValMap[Op]];
        else {
          BasicBlock::iterator InsertPt = U->getParent()->getFirstNonPHIIt();
          // We may pick up ops that were previously converted for users in
          // other blocks. If there is an originally typed definition of the Op
          // already in this block, simply reuse it.
          if (isa<Instruction>(Op) && !isa<PHINode>(Op) &&
              U->getParent() == cast<Instruction>(Op)->getParent()) {
            NewVal = Op;
          } else {
            NewVal =
                convertFromOptType(Op->getType(), cast<Instruction>(ValMap[Op]),
                                   InsertPt, U->getParent());
            BBUseValMap[U->getParent()][ValMap[Op]] = NewVal;
          }
        }
        assert(NewVal);
        U->setOperand(OpIdx, NewVal);
      }
    }
  }

  return true;
}

bool AMDGPULateCodeGenPrepare::canWidenScalarExtLoad(LoadInst &LI) const {
  unsigned AS = LI.getPointerAddressSpace();
  // Skip non-constant address spaces.
  if (AS != AMDGPUAS::CONSTANT_ADDRESS &&
      AS != AMDGPUAS::CONSTANT_ADDRESS_32BIT)
    return false;
  // Skip non-simple loads.
  if (!LI.isSimple())
    return false;
  Type *Ty = LI.getType();
  // Skip aggregate types.
  if (Ty->isAggregateType())
    return false;
  unsigned TySize = DL->getTypeStoreSize(Ty);
  // Only handle sub-DWORD loads.
  if (TySize >= 4)
    return false;
  // The load must be at least naturally aligned.
  if (LI.getAlign() < DL->getABITypeAlign(Ty))
    return false;
  // It should be uniform, i.e. a scalar load.
  return UA->isUniform(&LI);
}

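// For example, a uniform i16 load from the constant address space with
// align 2, located 2 bytes past a DWORD-aligned base, is rewritten below into
// an align-4 i32 load of the enclosing DWORD followed by a 16-bit lshr and a
// trunc back to i16.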
bool AMDGPULateCodeGenPrepare::visitLoadInst(LoadInst &LI) {
  if (!WidenLoads)
    return false;

  // Skip if the load is already at least DWORD aligned, as that case is
  // handled in SDAG.
  if (LI.getAlign() >= 4)
    return false;

  if (!canWidenScalarExtLoad(LI))
    return false;

  int64_t Offset = 0;
  auto *Base =
      GetPointerBaseWithConstantOffset(LI.getPointerOperand(), Offset, *DL);
  // If the base is not DWORD aligned, it's not safe to perform the following
  // transforms.
  if (!isDWORDAligned(Base))
    return false;

  int64_t Adjust = Offset & 0x3;
  if (Adjust == 0) {
    // With a zero adjustment, the original alignment can simply be promoted
    // to DWORD alignment.
    LI.setAlignment(Align(4));
    return true;
  }

  IRBuilder<> IRB(&LI);
  IRB.SetCurrentDebugLocation(LI.getDebugLoc());

  unsigned LdBits = DL->getTypeStoreSizeInBits(LI.getType());
  auto IntNTy = Type::getIntNTy(LI.getContext(), LdBits);

  auto *NewPtr = IRB.CreateConstGEP1_64(
      IRB.getInt8Ty(),
      IRB.CreateAddrSpaceCast(Base, LI.getPointerOperand()->getType()),
      Offset - Adjust);

  LoadInst *NewLd = IRB.CreateAlignedLoad(IRB.getInt32Ty(), NewPtr, Align(4));
  NewLd->copyMetadata(LI);
  NewLd->setMetadata(LLVMContext::MD_range, nullptr);

  unsigned ShAmt = Adjust * 8;
  auto *NewVal = IRB.CreateBitCast(
      IRB.CreateTrunc(IRB.CreateLShr(NewLd, ShAmt), IntNTy), LI.getType());
  LI.replaceAllUsesWith(NewVal);
  DeadInsts.emplace_back(&LI);

  return true;
}

INITIALIZE_PASS_BEGIN(AMDGPULateCodeGenPrepare, DEBUG_TYPE,
                      "AMDGPU IR late optimizations", false, false)
INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
INITIALIZE_PASS_DEPENDENCY(UniformityInfoWrapperPass)
INITIALIZE_PASS_END(AMDGPULateCodeGenPrepare, DEBUG_TYPE,
                    "AMDGPU IR late optimizations", false, false)

char AMDGPULateCodeGenPrepare::ID = 0;

FunctionPass *llvm::createAMDGPULateCodeGenPreparePass() {
  return new AMDGPULateCodeGenPrepare();
}