1 //===- DXILFlattenArrays.cpp - Flattens DXIL Arrays-----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===---------------------------------------------------------------------===// 8 /// 9 /// \file This file contains a pass to flatten arrays for the DirectX Backend. 10 /// 11 //===----------------------------------------------------------------------===// 12 13 #include "DXILFlattenArrays.h" 14 #include "DirectX.h" 15 #include "llvm/ADT/PostOrderIterator.h" 16 #include "llvm/ADT/STLExtras.h" 17 #include "llvm/IR/BasicBlock.h" 18 #include "llvm/IR/DerivedTypes.h" 19 #include "llvm/IR/IRBuilder.h" 20 #include "llvm/IR/InstVisitor.h" 21 #include "llvm/IR/ReplaceConstant.h" 22 #include "llvm/Support/Casting.h" 23 #include "llvm/Support/MathExtras.h" 24 #include "llvm/Transforms/Utils/Local.h" 25 #include <cassert> 26 #include <cstddef> 27 #include <cstdint> 28 #include <utility> 29 30 #define DEBUG_TYPE "dxil-flatten-arrays" 31 32 using namespace llvm; 33 namespace { 34 35 class DXILFlattenArraysLegacy : public ModulePass { 36 37 public: 38 bool runOnModule(Module &M) override; 39 DXILFlattenArraysLegacy() : ModulePass(ID) {} 40 41 static char ID; // Pass identification. 42 }; 43 44 struct GEPInfo { 45 ArrayType *RootFlattenedArrayType; 46 Value *RootPointerOperand; 47 SmallMapVector<Value *, APInt, 4> VariableOffsets; 48 APInt ConstantOffset; 49 }; 50 51 class DXILFlattenArraysVisitor 52 : public InstVisitor<DXILFlattenArraysVisitor, bool> { 53 public: 54 DXILFlattenArraysVisitor( 55 SmallDenseMap<GlobalVariable *, GlobalVariable *> &GlobalMap) 56 : GlobalMap(GlobalMap) {} 57 bool visit(Function &F); 58 // InstVisitor methods. They return true if the instruction was scalarized, 59 // false if nothing changed. 60 bool visitGetElementPtrInst(GetElementPtrInst &GEPI); 61 bool visitAllocaInst(AllocaInst &AI); 62 bool visitInstruction(Instruction &I) { return false; } 63 bool visitSelectInst(SelectInst &SI) { return false; } 64 bool visitICmpInst(ICmpInst &ICI) { return false; } 65 bool visitFCmpInst(FCmpInst &FCI) { return false; } 66 bool visitUnaryOperator(UnaryOperator &UO) { return false; } 67 bool visitBinaryOperator(BinaryOperator &BO) { return false; } 68 bool visitCastInst(CastInst &CI) { return false; } 69 bool visitBitCastInst(BitCastInst &BCI) { return false; } 70 bool visitInsertElementInst(InsertElementInst &IEI) { return false; } 71 bool visitExtractElementInst(ExtractElementInst &EEI) { return false; } 72 bool visitShuffleVectorInst(ShuffleVectorInst &SVI) { return false; } 73 bool visitPHINode(PHINode &PHI) { return false; } 74 bool visitLoadInst(LoadInst &LI); 75 bool visitStoreInst(StoreInst &SI); 76 bool visitCallInst(CallInst &ICI) { return false; } 77 bool visitFreezeInst(FreezeInst &FI) { return false; } 78 static bool isMultiDimensionalArray(Type *T); 79 static std::pair<unsigned, Type *> getElementCountAndType(Type *ArrayTy); 80 81 private: 82 SmallVector<WeakTrackingVH> PotentiallyDeadInstrs; 83 SmallDenseMap<GEPOperator *, GEPInfo> GEPChainInfoMap; 84 SmallDenseMap<GlobalVariable *, GlobalVariable *> &GlobalMap; 85 bool finish(); 86 ConstantInt *genConstFlattenIndices(ArrayRef<Value *> Indices, 87 ArrayRef<uint64_t> Dims, 88 IRBuilder<> &Builder); 89 Value *genInstructionFlattenIndices(ArrayRef<Value *> Indices, 90 ArrayRef<uint64_t> Dims, 91 IRBuilder<> &Builder); 92 }; 93 } // namespace 94 95 bool DXILFlattenArraysVisitor::finish() { 96 GEPChainInfoMap.clear(); 97 RecursivelyDeleteTriviallyDeadInstructionsPermissive(PotentiallyDeadInstrs); 98 return true; 99 } 100 101 bool DXILFlattenArraysVisitor::isMultiDimensionalArray(Type *T) { 102 if (ArrayType *ArrType = dyn_cast<ArrayType>(T)) 103 return isa<ArrayType>(ArrType->getElementType()); 104 return false; 105 } 106 107 std::pair<unsigned, Type *> 108 DXILFlattenArraysVisitor::getElementCountAndType(Type *ArrayTy) { 109 unsigned TotalElements = 1; 110 Type *CurrArrayTy = ArrayTy; 111 while (auto *InnerArrayTy = dyn_cast<ArrayType>(CurrArrayTy)) { 112 TotalElements *= InnerArrayTy->getNumElements(); 113 CurrArrayTy = InnerArrayTy->getElementType(); 114 } 115 return std::make_pair(TotalElements, CurrArrayTy); 116 } 117 118 ConstantInt *DXILFlattenArraysVisitor::genConstFlattenIndices( 119 ArrayRef<Value *> Indices, ArrayRef<uint64_t> Dims, IRBuilder<> &Builder) { 120 assert(Indices.size() == Dims.size() && 121 "Indicies and dimmensions should be the same"); 122 unsigned FlatIndex = 0; 123 unsigned Multiplier = 1; 124 125 for (int I = Indices.size() - 1; I >= 0; --I) { 126 unsigned DimSize = Dims[I]; 127 ConstantInt *CIndex = dyn_cast<ConstantInt>(Indices[I]); 128 assert(CIndex && "This function expects all indicies to be ConstantInt"); 129 FlatIndex += CIndex->getZExtValue() * Multiplier; 130 Multiplier *= DimSize; 131 } 132 return Builder.getInt32(FlatIndex); 133 } 134 135 Value *DXILFlattenArraysVisitor::genInstructionFlattenIndices( 136 ArrayRef<Value *> Indices, ArrayRef<uint64_t> Dims, IRBuilder<> &Builder) { 137 if (Indices.size() == 1) 138 return Indices[0]; 139 140 Value *FlatIndex = Builder.getInt32(0); 141 unsigned Multiplier = 1; 142 143 for (int I = Indices.size() - 1; I >= 0; --I) { 144 unsigned DimSize = Dims[I]; 145 Value *VMultiplier = Builder.getInt32(Multiplier); 146 Value *ScaledIndex = Builder.CreateMul(Indices[I], VMultiplier); 147 FlatIndex = Builder.CreateAdd(FlatIndex, ScaledIndex); 148 Multiplier *= DimSize; 149 } 150 return FlatIndex; 151 } 152 153 bool DXILFlattenArraysVisitor::visitLoadInst(LoadInst &LI) { 154 unsigned NumOperands = LI.getNumOperands(); 155 for (unsigned I = 0; I < NumOperands; ++I) { 156 Value *CurrOpperand = LI.getOperand(I); 157 ConstantExpr *CE = dyn_cast<ConstantExpr>(CurrOpperand); 158 if (CE && CE->getOpcode() == Instruction::GetElementPtr) { 159 GetElementPtrInst *OldGEP = 160 cast<GetElementPtrInst>(CE->getAsInstruction()); 161 OldGEP->insertBefore(LI.getIterator()); 162 163 IRBuilder<> Builder(&LI); 164 LoadInst *NewLoad = 165 Builder.CreateLoad(LI.getType(), OldGEP, LI.getName()); 166 NewLoad->setAlignment(LI.getAlign()); 167 LI.replaceAllUsesWith(NewLoad); 168 LI.eraseFromParent(); 169 visitGetElementPtrInst(*OldGEP); 170 return true; 171 } 172 } 173 return false; 174 } 175 176 bool DXILFlattenArraysVisitor::visitStoreInst(StoreInst &SI) { 177 unsigned NumOperands = SI.getNumOperands(); 178 for (unsigned I = 0; I < NumOperands; ++I) { 179 Value *CurrOpperand = SI.getOperand(I); 180 ConstantExpr *CE = dyn_cast<ConstantExpr>(CurrOpperand); 181 if (CE && CE->getOpcode() == Instruction::GetElementPtr) { 182 GetElementPtrInst *OldGEP = 183 cast<GetElementPtrInst>(CE->getAsInstruction()); 184 OldGEP->insertBefore(SI.getIterator()); 185 186 IRBuilder<> Builder(&SI); 187 StoreInst *NewStore = Builder.CreateStore(SI.getValueOperand(), OldGEP); 188 NewStore->setAlignment(SI.getAlign()); 189 SI.replaceAllUsesWith(NewStore); 190 SI.eraseFromParent(); 191 visitGetElementPtrInst(*OldGEP); 192 return true; 193 } 194 } 195 return false; 196 } 197 198 bool DXILFlattenArraysVisitor::visitAllocaInst(AllocaInst &AI) { 199 if (!isMultiDimensionalArray(AI.getAllocatedType())) 200 return false; 201 202 ArrayType *ArrType = cast<ArrayType>(AI.getAllocatedType()); 203 IRBuilder<> Builder(&AI); 204 auto [TotalElements, BaseType] = getElementCountAndType(ArrType); 205 206 ArrayType *FattenedArrayType = ArrayType::get(BaseType, TotalElements); 207 AllocaInst *FlatAlloca = 208 Builder.CreateAlloca(FattenedArrayType, nullptr, AI.getName() + ".1dim"); 209 FlatAlloca->setAlignment(AI.getAlign()); 210 AI.replaceAllUsesWith(FlatAlloca); 211 AI.eraseFromParent(); 212 return true; 213 } 214 215 bool DXILFlattenArraysVisitor::visitGetElementPtrInst(GetElementPtrInst &GEP) { 216 // Do not visit GEPs more than once 217 if (GEPChainInfoMap.contains(cast<GEPOperator>(&GEP))) 218 return false; 219 220 Value *PtrOperand = GEP.getPointerOperand(); 221 // It shouldn't(?) be possible for the pointer operand of a GEP to be a PHI 222 // node unless HLSL has pointers. If this assumption is incorrect or HLSL gets 223 // pointer types, then the handling of this case can be implemented later. 224 assert(!isa<PHINode>(PtrOperand) && 225 "Pointer operand of GEP should not be a PHI Node"); 226 227 // Replace a GEP ConstantExpr pointer operand with a GEP instruction so that 228 // it can be visited 229 if (auto *PtrOpGEPCE = dyn_cast<ConstantExpr>(PtrOperand); 230 PtrOpGEPCE && PtrOpGEPCE->getOpcode() == Instruction::GetElementPtr) { 231 GetElementPtrInst *OldGEPI = 232 cast<GetElementPtrInst>(PtrOpGEPCE->getAsInstruction()); 233 OldGEPI->insertBefore(GEP.getIterator()); 234 235 IRBuilder<> Builder(&GEP); 236 SmallVector<Value *> Indices(GEP.indices()); 237 Value *NewGEP = 238 Builder.CreateGEP(GEP.getSourceElementType(), OldGEPI, Indices, 239 GEP.getName(), GEP.getNoWrapFlags()); 240 assert(isa<GetElementPtrInst>(NewGEP) && 241 "Expected newly-created GEP to be an instruction"); 242 GetElementPtrInst *NewGEPI = cast<GetElementPtrInst>(NewGEP); 243 244 GEP.replaceAllUsesWith(NewGEPI); 245 GEP.eraseFromParent(); 246 visitGetElementPtrInst(*OldGEPI); 247 visitGetElementPtrInst(*NewGEPI); 248 return true; 249 } 250 251 // Construct GEPInfo for this GEP 252 GEPInfo Info; 253 254 // Obtain the variable and constant byte offsets computed by this GEP 255 const DataLayout &DL = GEP.getDataLayout(); 256 unsigned BitWidth = DL.getIndexTypeSizeInBits(GEP.getType()); 257 Info.ConstantOffset = {BitWidth, 0}; 258 [[maybe_unused]] bool Success = GEP.collectOffset( 259 DL, BitWidth, Info.VariableOffsets, Info.ConstantOffset); 260 assert(Success && "Failed to collect offsets for GEP"); 261 262 // If there is a parent GEP, inherit the root array type and pointer, and 263 // merge the byte offsets. Otherwise, this GEP is itself the root of a GEP 264 // chain and we need to deterine the root array type 265 if (auto *PtrOpGEP = dyn_cast<GEPOperator>(PtrOperand)) { 266 assert(GEPChainInfoMap.contains(PtrOpGEP) && 267 "Expected parent GEP to be visited before this GEP"); 268 GEPInfo &PGEPInfo = GEPChainInfoMap[PtrOpGEP]; 269 Info.RootFlattenedArrayType = PGEPInfo.RootFlattenedArrayType; 270 Info.RootPointerOperand = PGEPInfo.RootPointerOperand; 271 for (auto &VariableOffset : PGEPInfo.VariableOffsets) 272 Info.VariableOffsets.insert(VariableOffset); 273 Info.ConstantOffset += PGEPInfo.ConstantOffset; 274 } else { 275 Info.RootPointerOperand = PtrOperand; 276 277 // We should try to determine the type of the root from the pointer rather 278 // than the GEP's source element type because this could be a scalar GEP 279 // into an array-typed pointer from an Alloca or Global Variable. 280 Type *RootTy = GEP.getSourceElementType(); 281 if (auto *GlobalVar = dyn_cast<GlobalVariable>(PtrOperand)) { 282 if (GlobalMap.contains(GlobalVar)) 283 GlobalVar = GlobalMap[GlobalVar]; 284 Info.RootPointerOperand = GlobalVar; 285 RootTy = GlobalVar->getValueType(); 286 } else if (auto *Alloca = dyn_cast<AllocaInst>(PtrOperand)) 287 RootTy = Alloca->getAllocatedType(); 288 assert(!isMultiDimensionalArray(RootTy) && 289 "Expected root array type to be flattened"); 290 291 // If the root type is not an array, we don't need to do any flattening 292 if (!isa<ArrayType>(RootTy)) 293 return false; 294 295 Info.RootFlattenedArrayType = cast<ArrayType>(RootTy); 296 } 297 298 // GEPs without users or GEPs with non-GEP users should be replaced such that 299 // the chain of GEPs they are a part of are collapsed to a single GEP into a 300 // flattened array. 301 bool ReplaceThisGEP = GEP.users().empty(); 302 for (Value *User : GEP.users()) 303 if (!isa<GetElementPtrInst>(User)) 304 ReplaceThisGEP = true; 305 306 if (ReplaceThisGEP) { 307 unsigned BytesPerElem = 308 DL.getTypeAllocSize(Info.RootFlattenedArrayType->getArrayElementType()); 309 assert(isPowerOf2_32(BytesPerElem) && 310 "Bytes per element should be a power of 2"); 311 312 // Compute the 32-bit index for this flattened GEP from the constant and 313 // variable byte offsets in the GEPInfo 314 IRBuilder<> Builder(&GEP); 315 Value *ZeroIndex = Builder.getInt32(0); 316 uint64_t ConstantOffset = 317 Info.ConstantOffset.udiv(BytesPerElem).getZExtValue(); 318 assert(ConstantOffset < UINT32_MAX && 319 "Constant byte offset for flat GEP index must fit within 32 bits"); 320 Value *FlattenedIndex = Builder.getInt32(ConstantOffset); 321 for (auto [VarIndex, Multiplier] : Info.VariableOffsets) { 322 assert(Multiplier.getActiveBits() <= 32 && 323 "The multiplier for a flat GEP index must fit within 32 bits"); 324 assert(VarIndex->getType()->isIntegerTy(32) && 325 "Expected i32-typed GEP indices"); 326 Value *VI; 327 if (Multiplier.getZExtValue() % BytesPerElem != 0) { 328 // This can happen, e.g., with i8 GEPs. To handle this we just divide 329 // by BytesPerElem using an instruction after multiplying VarIndex by 330 // Multiplier. 331 VI = Builder.CreateMul(VarIndex, 332 Builder.getInt32(Multiplier.getZExtValue())); 333 VI = Builder.CreateLShr(VI, Builder.getInt32(Log2_32(BytesPerElem))); 334 } else 335 VI = Builder.CreateMul( 336 VarIndex, 337 Builder.getInt32(Multiplier.getZExtValue() / BytesPerElem)); 338 FlattenedIndex = Builder.CreateAdd(FlattenedIndex, VI); 339 } 340 341 // Construct a new GEP for the flattened array to replace the current GEP 342 Value *NewGEP = Builder.CreateGEP( 343 Info.RootFlattenedArrayType, Info.RootPointerOperand, 344 {ZeroIndex, FlattenedIndex}, GEP.getName(), GEP.getNoWrapFlags()); 345 346 // Replace the current GEP with the new GEP. Store GEPInfo into the map 347 // for later use in case this GEP was not the end of the chain 348 GEPChainInfoMap.insert({cast<GEPOperator>(NewGEP), std::move(Info)}); 349 GEP.replaceAllUsesWith(NewGEP); 350 GEP.eraseFromParent(); 351 return true; 352 } 353 354 // This GEP is potentially dead at the end of the pass since it may not have 355 // any users anymore after GEP chains have been collapsed. We retain store 356 // GEPInfo for GEPs down the chain to use to compute their indices. 357 GEPChainInfoMap.insert({cast<GEPOperator>(&GEP), std::move(Info)}); 358 PotentiallyDeadInstrs.emplace_back(&GEP); 359 return false; 360 } 361 362 bool DXILFlattenArraysVisitor::visit(Function &F) { 363 bool MadeChange = false; 364 ReversePostOrderTraversal<Function *> RPOT(&F); 365 for (BasicBlock *BB : make_early_inc_range(RPOT)) { 366 for (Instruction &I : make_early_inc_range(*BB)) 367 MadeChange |= InstVisitor::visit(I); 368 } 369 finish(); 370 return MadeChange; 371 } 372 373 static void collectElements(Constant *Init, 374 SmallVectorImpl<Constant *> &Elements) { 375 // Base case: If Init is not an array, add it directly to the vector. 376 auto *ArrayTy = dyn_cast<ArrayType>(Init->getType()); 377 if (!ArrayTy) { 378 Elements.push_back(Init); 379 return; 380 } 381 unsigned ArrSize = ArrayTy->getNumElements(); 382 if (isa<ConstantAggregateZero>(Init)) { 383 for (unsigned I = 0; I < ArrSize; ++I) 384 Elements.push_back(Constant::getNullValue(ArrayTy->getElementType())); 385 return; 386 } 387 388 // Recursive case: Process each element in the array. 389 if (auto *ArrayConstant = dyn_cast<ConstantArray>(Init)) { 390 for (unsigned I = 0; I < ArrayConstant->getNumOperands(); ++I) { 391 collectElements(ArrayConstant->getOperand(I), Elements); 392 } 393 } else if (auto *DataArrayConstant = dyn_cast<ConstantDataArray>(Init)) { 394 for (unsigned I = 0; I < DataArrayConstant->getNumElements(); ++I) { 395 collectElements(DataArrayConstant->getElementAsConstant(I), Elements); 396 } 397 } else { 398 llvm_unreachable( 399 "Expected a ConstantArray or ConstantDataArray for array initializer!"); 400 } 401 } 402 403 static Constant *transformInitializer(Constant *Init, Type *OrigType, 404 ArrayType *FlattenedType, 405 LLVMContext &Ctx) { 406 // Handle ConstantAggregateZero (zero-initialized constants) 407 if (isa<ConstantAggregateZero>(Init)) 408 return ConstantAggregateZero::get(FlattenedType); 409 410 // Handle UndefValue (undefined constants) 411 if (isa<UndefValue>(Init)) 412 return UndefValue::get(FlattenedType); 413 414 if (!isa<ArrayType>(OrigType)) 415 return Init; 416 417 SmallVector<Constant *> FlattenedElements; 418 collectElements(Init, FlattenedElements); 419 assert(FlattenedType->getNumElements() == FlattenedElements.size() && 420 "The number of collected elements should match the FlattenedType"); 421 return ConstantArray::get(FlattenedType, FlattenedElements); 422 } 423 424 static void flattenGlobalArrays( 425 Module &M, SmallDenseMap<GlobalVariable *, GlobalVariable *> &GlobalMap) { 426 LLVMContext &Ctx = M.getContext(); 427 for (GlobalVariable &G : M.globals()) { 428 Type *OrigType = G.getValueType(); 429 if (!DXILFlattenArraysVisitor::isMultiDimensionalArray(OrigType)) 430 continue; 431 432 ArrayType *ArrType = cast<ArrayType>(OrigType); 433 auto [TotalElements, BaseType] = 434 DXILFlattenArraysVisitor::getElementCountAndType(ArrType); 435 ArrayType *FattenedArrayType = ArrayType::get(BaseType, TotalElements); 436 437 // Create a new global variable with the updated type 438 // Note: Initializer is set via transformInitializer 439 GlobalVariable *NewGlobal = 440 new GlobalVariable(M, FattenedArrayType, G.isConstant(), G.getLinkage(), 441 /*Initializer=*/nullptr, G.getName() + ".1dim", &G, 442 G.getThreadLocalMode(), G.getAddressSpace(), 443 G.isExternallyInitialized()); 444 445 // Copy relevant attributes 446 NewGlobal->setUnnamedAddr(G.getUnnamedAddr()); 447 if (G.getAlignment() > 0) { 448 NewGlobal->setAlignment(G.getAlign()); 449 } 450 451 if (G.hasInitializer()) { 452 Constant *Init = G.getInitializer(); 453 Constant *NewInit = 454 transformInitializer(Init, OrigType, FattenedArrayType, Ctx); 455 NewGlobal->setInitializer(NewInit); 456 } 457 GlobalMap[&G] = NewGlobal; 458 } 459 } 460 461 static bool flattenArrays(Module &M) { 462 bool MadeChange = false; 463 SmallDenseMap<GlobalVariable *, GlobalVariable *> GlobalMap; 464 flattenGlobalArrays(M, GlobalMap); 465 DXILFlattenArraysVisitor Impl(GlobalMap); 466 for (auto &F : make_early_inc_range(M.functions())) { 467 if (F.isDeclaration()) 468 continue; 469 MadeChange |= Impl.visit(F); 470 } 471 for (auto &[Old, New] : GlobalMap) { 472 Old->replaceAllUsesWith(New); 473 Old->eraseFromParent(); 474 MadeChange = true; 475 } 476 return MadeChange; 477 } 478 479 PreservedAnalyses DXILFlattenArrays::run(Module &M, ModuleAnalysisManager &) { 480 bool MadeChanges = flattenArrays(M); 481 if (!MadeChanges) 482 return PreservedAnalyses::all(); 483 PreservedAnalyses PA; 484 return PA; 485 } 486 487 bool DXILFlattenArraysLegacy::runOnModule(Module &M) { 488 return flattenArrays(M); 489 } 490 491 char DXILFlattenArraysLegacy::ID = 0; 492 493 INITIALIZE_PASS_BEGIN(DXILFlattenArraysLegacy, DEBUG_TYPE, 494 "DXIL Array Flattener", false, false) 495 INITIALIZE_PASS_END(DXILFlattenArraysLegacy, DEBUG_TYPE, "DXIL Array Flattener", 496 false, false) 497 498 ModulePass *llvm::createDXILFlattenArraysLegacyPass() { 499 return new DXILFlattenArraysLegacy(); 500 } 501