1 //===- DXILCBufferAccess.cpp - Translate CBuffer Loads --------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "DXILCBufferAccess.h" 10 #include "DirectX.h" 11 #include "llvm/Frontend/HLSL/CBuffer.h" 12 #include "llvm/Frontend/HLSL/HLSLResource.h" 13 #include "llvm/IR/IRBuilder.h" 14 #include "llvm/IR/IntrinsicInst.h" 15 #include "llvm/IR/IntrinsicsDirectX.h" 16 #include "llvm/InitializePasses.h" 17 #include "llvm/Pass.h" 18 #include "llvm/Support/FormatVariadic.h" 19 #include "llvm/Transforms/Utils/Local.h" 20 21 #define DEBUG_TYPE "dxil-cbuffer-access" 22 using namespace llvm; 23 24 namespace { 25 /// Helper for building a `load.cbufferrow` intrinsic given a simple type. 26 struct CBufferRowIntrin { 27 Intrinsic::ID IID; 28 Type *RetTy; 29 unsigned int EltSize; 30 unsigned int NumElts; 31 32 CBufferRowIntrin(const DataLayout &DL, Type *Ty) { 33 assert(Ty == Ty->getScalarType() && "Expected scalar type"); 34 35 switch (DL.getTypeSizeInBits(Ty)) { 36 case 16: 37 IID = Intrinsic::dx_resource_load_cbufferrow_8; 38 RetTy = StructType::get(Ty, Ty, Ty, Ty, Ty, Ty, Ty, Ty); 39 EltSize = 2; 40 NumElts = 8; 41 break; 42 case 32: 43 IID = Intrinsic::dx_resource_load_cbufferrow_4; 44 RetTy = StructType::get(Ty, Ty, Ty, Ty); 45 EltSize = 4; 46 NumElts = 4; 47 break; 48 case 64: 49 IID = Intrinsic::dx_resource_load_cbufferrow_2; 50 RetTy = StructType::get(Ty, Ty); 51 EltSize = 8; 52 NumElts = 2; 53 break; 54 default: 55 llvm_unreachable("Only 16, 32, and 64 bit types supported"); 56 } 57 } 58 }; 59 60 // Helper for creating CBuffer handles and loading data from them 61 struct CBufferResource { 62 GlobalVariable *GVHandle; 63 GlobalVariable *Member; 64 size_t MemberOffset; 65 66 LoadInst *Handle; 67 68 CBufferResource(GlobalVariable *GVHandle, GlobalVariable *Member, 69 size_t MemberOffset) 70 : GVHandle(GVHandle), Member(Member), MemberOffset(MemberOffset) {} 71 72 const DataLayout &getDataLayout() { return GVHandle->getDataLayout(); } 73 Type *getValueType() { return Member->getValueType(); } 74 iterator_range<ConstantDataSequential::user_iterator> users() { 75 return Member->users(); 76 } 77 78 /// Get the byte offset of a Pointer-typed Value * `Val` relative to Member. 79 /// `Val` can either be Member itself, or a GEP of a constant offset from 80 /// Member 81 size_t getOffsetForCBufferGEP(Value *Val) { 82 assert(isa<PointerType>(Val->getType()) && 83 "Expected a pointer-typed value"); 84 85 if (Val == Member) 86 return 0; 87 88 if (auto *GEP = dyn_cast<GEPOperator>(Val)) { 89 // Since we should always have a constant offset, we should only ever have 90 // a single GEP of indirection from the Global. 91 assert(GEP->getPointerOperand() == Member && 92 "Indirect access to resource handle"); 93 94 const DataLayout &DL = getDataLayout(); 95 APInt ConstantOffset(DL.getIndexTypeSizeInBits(GEP->getType()), 0); 96 bool Success = GEP->accumulateConstantOffset(DL, ConstantOffset); 97 (void)Success; 98 assert(Success && "Offsets into cbuffer globals must be constant"); 99 100 if (auto *ATy = dyn_cast<ArrayType>(Member->getValueType())) 101 ConstantOffset = 102 hlsl::translateCBufArrayOffset(DL, ConstantOffset, ATy); 103 104 return ConstantOffset.getZExtValue(); 105 } 106 107 llvm_unreachable("Expected Val to be a GlobalVariable or GEP"); 108 } 109 110 /// Create a handle for this cbuffer resource using the IRBuilder `Builder` 111 /// and sets the handle as the current one to use for subsequent calls to 112 /// `loadValue` 113 void createAndSetCurrentHandle(IRBuilder<> &Builder) { 114 Handle = Builder.CreateLoad(GVHandle->getValueType(), GVHandle, 115 GVHandle->getName()); 116 } 117 118 /// Load a value of type `Ty` at offset `Offset` using the handle from the 119 /// last call to `createAndSetCurrentHandle` 120 Value *loadValue(IRBuilder<> &Builder, Type *Ty, size_t Offset, 121 const Twine &Name = "") { 122 assert(Handle && 123 "Expected a handle for this cbuffer global resource to be created " 124 "before loading a value from it"); 125 const DataLayout &DL = getDataLayout(); 126 127 size_t TargetOffset = MemberOffset + Offset; 128 CBufferRowIntrin Intrin(DL, Ty->getScalarType()); 129 // The cbuffer consists of some number of 16-byte rows. 130 unsigned int CurrentRow = TargetOffset / hlsl::CBufferRowSizeInBytes; 131 unsigned int CurrentIndex = 132 (TargetOffset % hlsl::CBufferRowSizeInBytes) / Intrin.EltSize; 133 134 auto *CBufLoad = Builder.CreateIntrinsic( 135 Intrin.RetTy, Intrin.IID, 136 {Handle, ConstantInt::get(Builder.getInt32Ty(), CurrentRow)}, nullptr, 137 Name + ".load"); 138 auto *Elt = Builder.CreateExtractValue(CBufLoad, {CurrentIndex++}, 139 Name + ".extract"); 140 141 Value *Result = nullptr; 142 unsigned int Remaining = 143 ((DL.getTypeSizeInBits(Ty) / 8) / Intrin.EltSize) - 1; 144 145 if (Remaining == 0) { 146 // We only have a single element, so we're done. 147 Result = Elt; 148 149 // However, if we loaded a <1 x T>, then we need to adjust the type here. 150 if (auto *VT = dyn_cast<FixedVectorType>(Ty)) { 151 assert(VT->getNumElements() == 1 && 152 "Can't have multiple elements here"); 153 Result = Builder.CreateInsertElement(PoisonValue::get(VT), Result, 154 Builder.getInt32(0), Name); 155 } 156 return Result; 157 } 158 159 // Walk each element and extract it, wrapping to new rows as needed. 160 SmallVector<Value *> Extracts{Elt}; 161 while (Remaining--) { 162 CurrentIndex %= Intrin.NumElts; 163 164 if (CurrentIndex == 0) 165 CBufLoad = Builder.CreateIntrinsic( 166 Intrin.RetTy, Intrin.IID, 167 {Handle, ConstantInt::get(Builder.getInt32Ty(), ++CurrentRow)}, 168 nullptr, Name + ".load"); 169 170 Extracts.push_back(Builder.CreateExtractValue(CBufLoad, {CurrentIndex++}, 171 Name + ".extract")); 172 } 173 174 // Finally, we build up the original loaded value. 175 Result = PoisonValue::get(Ty); 176 for (int I = 0, E = Extracts.size(); I < E; ++I) 177 Result = 178 Builder.CreateInsertElement(Result, Extracts[I], Builder.getInt32(I), 179 Name + formatv(".upto{}", I)); 180 return Result; 181 } 182 }; 183 184 } // namespace 185 186 /// Replace load via cbuffer global with a load from the cbuffer handle itself. 187 static void replaceLoad(LoadInst *LI, CBufferResource &CBR, 188 SmallVectorImpl<WeakTrackingVH> &DeadInsts) { 189 size_t Offset = CBR.getOffsetForCBufferGEP(LI->getPointerOperand()); 190 IRBuilder<> Builder(LI); 191 CBR.createAndSetCurrentHandle(Builder); 192 Value *Result = CBR.loadValue(Builder, LI->getType(), Offset, LI->getName()); 193 LI->replaceAllUsesWith(Result); 194 DeadInsts.push_back(LI); 195 } 196 197 /// This function recursively copies N array elements from the cbuffer resource 198 /// CBR to the MemCpy Destination. Recursion is used to unravel multidimensional 199 /// arrays into a sequence of scalar/vector extracts and stores. 200 static void copyArrayElemsForMemCpy(IRBuilder<> &Builder, MemCpyInst *MCI, 201 CBufferResource &CBR, ArrayType *ArrTy, 202 size_t ArrOffset, size_t N, 203 const Twine &Name = "") { 204 const DataLayout &DL = MCI->getDataLayout(); 205 Type *ElemTy = ArrTy->getElementType(); 206 size_t ElemTySize = DL.getTypeAllocSize(ElemTy); 207 for (unsigned I = 0; I < N; ++I) { 208 size_t Offset = ArrOffset + I * ElemTySize; 209 210 // Recursively copy nested arrays 211 if (ArrayType *ElemArrTy = dyn_cast<ArrayType>(ElemTy)) { 212 copyArrayElemsForMemCpy(Builder, MCI, CBR, ElemArrTy, Offset, 213 ElemArrTy->getNumElements(), Name); 214 continue; 215 } 216 217 // Load CBuffer value and store it in Dest 218 APInt CBufArrayOffset( 219 DL.getIndexTypeSizeInBits(MCI->getSource()->getType()), Offset); 220 CBufArrayOffset = 221 hlsl::translateCBufArrayOffset(DL, CBufArrayOffset, ArrTy); 222 Value *CBufferVal = 223 CBR.loadValue(Builder, ElemTy, CBufArrayOffset.getZExtValue(), Name); 224 Value *GEP = 225 Builder.CreateInBoundsGEP(Builder.getInt8Ty(), MCI->getDest(), 226 {Builder.getInt32(Offset)}, Name + ".dest"); 227 Builder.CreateStore(CBufferVal, GEP, MCI->isVolatile()); 228 } 229 } 230 231 /// Replace memcpy from a cbuffer global with a memcpy from the cbuffer handle 232 /// itself. Assumes the cbuffer global is an array, and the length of bytes to 233 /// copy is divisible by array element allocation size. 234 /// The memcpy source must also be a direct cbuffer global reference, not a GEP. 235 static void replaceMemCpy(MemCpyInst *MCI, CBufferResource &CBR) { 236 237 ArrayType *ArrTy = dyn_cast<ArrayType>(CBR.getValueType()); 238 assert(ArrTy && "MemCpy lowering is only supported for array types"); 239 240 // This assumption vastly simplifies the implementation 241 if (MCI->getSource() != CBR.Member) 242 reportFatalUsageError( 243 "Expected MemCpy source to be a cbuffer global variable"); 244 245 ConstantInt *Length = dyn_cast<ConstantInt>(MCI->getLength()); 246 uint64_t ByteLength = Length->getZExtValue(); 247 248 // If length to copy is zero, no memcpy is needed 249 if (ByteLength == 0) { 250 MCI->eraseFromParent(); 251 return; 252 } 253 254 const DataLayout &DL = CBR.getDataLayout(); 255 256 Type *ElemTy = ArrTy->getElementType(); 257 size_t ElemSize = DL.getTypeAllocSize(ElemTy); 258 assert(ByteLength % ElemSize == 0 && 259 "Length of bytes to MemCpy must be divisible by allocation size of " 260 "source/destination array elements"); 261 size_t ElemsToCpy = ByteLength / ElemSize; 262 263 IRBuilder<> Builder(MCI); 264 CBR.createAndSetCurrentHandle(Builder); 265 266 copyArrayElemsForMemCpy(Builder, MCI, CBR, ArrTy, 0, ElemsToCpy, 267 "memcpy." + MCI->getDest()->getName() + "." + 268 MCI->getSource()->getName()); 269 270 MCI->eraseFromParent(); 271 } 272 273 static void replaceAccessesWithHandle(CBufferResource &CBR) { 274 SmallVector<WeakTrackingVH> DeadInsts; 275 276 SmallVector<User *> ToProcess{CBR.users()}; 277 while (!ToProcess.empty()) { 278 User *Cur = ToProcess.pop_back_val(); 279 280 // If we have a load instruction, replace the access. 281 if (auto *LI = dyn_cast<LoadInst>(Cur)) { 282 replaceLoad(LI, CBR, DeadInsts); 283 continue; 284 } 285 286 // If we have a memcpy instruction, replace it with multiple accesses and 287 // subsequent stores to the destination 288 if (auto *MCI = dyn_cast<MemCpyInst>(Cur)) { 289 replaceMemCpy(MCI, CBR); 290 continue; 291 } 292 293 // Otherwise, walk users looking for a load... 294 if (isa<GetElementPtrInst>(Cur) || isa<GEPOperator>(Cur)) { 295 ToProcess.append(Cur->user_begin(), Cur->user_end()); 296 continue; 297 } 298 299 llvm_unreachable("Unexpected user of Global"); 300 } 301 RecursivelyDeleteTriviallyDeadInstructions(DeadInsts); 302 } 303 304 static bool replaceCBufferAccesses(Module &M) { 305 std::optional<hlsl::CBufferMetadata> CBufMD = hlsl::CBufferMetadata::get(M); 306 if (!CBufMD) 307 return false; 308 309 for (const hlsl::CBufferMapping &Mapping : *CBufMD) 310 for (const hlsl::CBufferMember &Member : Mapping.Members) { 311 CBufferResource CBR(Mapping.Handle, Member.GV, Member.Offset); 312 replaceAccessesWithHandle(CBR); 313 Member.GV->removeFromParent(); 314 } 315 316 CBufMD->eraseFromModule(); 317 return true; 318 } 319 320 PreservedAnalyses DXILCBufferAccess::run(Module &M, ModuleAnalysisManager &AM) { 321 PreservedAnalyses PA; 322 bool Changed = replaceCBufferAccesses(M); 323 324 if (!Changed) 325 return PreservedAnalyses::all(); 326 return PA; 327 } 328 329 namespace { 330 class DXILCBufferAccessLegacy : public ModulePass { 331 public: 332 bool runOnModule(Module &M) override { return replaceCBufferAccesses(M); } 333 StringRef getPassName() const override { return "DXIL CBuffer Access"; } 334 DXILCBufferAccessLegacy() : ModulePass(ID) {} 335 336 static char ID; // Pass identification. 337 }; 338 char DXILCBufferAccessLegacy::ID = 0; 339 } // end anonymous namespace 340 341 INITIALIZE_PASS(DXILCBufferAccessLegacy, DEBUG_TYPE, "DXIL CBuffer Access", 342 false, false) 343 344 ModulePass *llvm::createDXILCBufferAccessLegacyPass() { 345 return new DXILCBufferAccessLegacy(); 346 } 347