xref: /freebsd/contrib/llvm-project/llvm/lib/Target/DirectX/DXILCBufferAccess.cpp (revision e3f4a63af63bea70bc86b6c790b14aa5ee99fcd0)
1 //===- DXILCBufferAccess.cpp - Translate CBuffer Loads --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "DXILCBufferAccess.h"
10 #include "DirectX.h"
11 #include "llvm/Frontend/HLSL/CBuffer.h"
12 #include "llvm/Frontend/HLSL/HLSLResource.h"
13 #include "llvm/IR/IRBuilder.h"
14 #include "llvm/IR/IntrinsicInst.h"
15 #include "llvm/IR/IntrinsicsDirectX.h"
16 #include "llvm/InitializePasses.h"
17 #include "llvm/Pass.h"
18 #include "llvm/Support/FormatVariadic.h"
19 #include "llvm/Transforms/Utils/Local.h"
20 
21 #define DEBUG_TYPE "dxil-cbuffer-access"
22 using namespace llvm;
23 
24 namespace {
25 /// Helper for building a `load.cbufferrow` intrinsic given a simple type.
26 struct CBufferRowIntrin {
27   Intrinsic::ID IID;
28   Type *RetTy;
29   unsigned int EltSize;
30   unsigned int NumElts;
31 
32   CBufferRowIntrin(const DataLayout &DL, Type *Ty) {
33     assert(Ty == Ty->getScalarType() && "Expected scalar type");
34 
35     switch (DL.getTypeSizeInBits(Ty)) {
36     case 16:
37       IID = Intrinsic::dx_resource_load_cbufferrow_8;
38       RetTy = StructType::get(Ty, Ty, Ty, Ty, Ty, Ty, Ty, Ty);
39       EltSize = 2;
40       NumElts = 8;
41       break;
42     case 32:
43       IID = Intrinsic::dx_resource_load_cbufferrow_4;
44       RetTy = StructType::get(Ty, Ty, Ty, Ty);
45       EltSize = 4;
46       NumElts = 4;
47       break;
48     case 64:
49       IID = Intrinsic::dx_resource_load_cbufferrow_2;
50       RetTy = StructType::get(Ty, Ty);
51       EltSize = 8;
52       NumElts = 2;
53       break;
54     default:
55       llvm_unreachable("Only 16, 32, and 64 bit types supported");
56     }
57   }
58 };
59 
60 // Helper for creating CBuffer handles and loading data from them
61 struct CBufferResource {
62   GlobalVariable *GVHandle;
63   GlobalVariable *Member;
64   size_t MemberOffset;
65 
66   LoadInst *Handle;
67 
68   CBufferResource(GlobalVariable *GVHandle, GlobalVariable *Member,
69                   size_t MemberOffset)
70       : GVHandle(GVHandle), Member(Member), MemberOffset(MemberOffset) {}
71 
72   const DataLayout &getDataLayout() { return GVHandle->getDataLayout(); }
73   Type *getValueType() { return Member->getValueType(); }
74   iterator_range<ConstantDataSequential::user_iterator> users() {
75     return Member->users();
76   }
77 
78   /// Get the byte offset of a Pointer-typed Value * `Val` relative to Member.
79   /// `Val` can either be Member itself, or a GEP of a constant offset from
80   /// Member
81   size_t getOffsetForCBufferGEP(Value *Val) {
82     assert(isa<PointerType>(Val->getType()) &&
83            "Expected a pointer-typed value");
84 
85     if (Val == Member)
86       return 0;
87 
88     if (auto *GEP = dyn_cast<GEPOperator>(Val)) {
89       // Since we should always have a constant offset, we should only ever have
90       // a single GEP of indirection from the Global.
91       assert(GEP->getPointerOperand() == Member &&
92              "Indirect access to resource handle");
93 
94       const DataLayout &DL = getDataLayout();
95       APInt ConstantOffset(DL.getIndexTypeSizeInBits(GEP->getType()), 0);
96       bool Success = GEP->accumulateConstantOffset(DL, ConstantOffset);
97       (void)Success;
98       assert(Success && "Offsets into cbuffer globals must be constant");
99 
100       if (auto *ATy = dyn_cast<ArrayType>(Member->getValueType()))
101         ConstantOffset =
102             hlsl::translateCBufArrayOffset(DL, ConstantOffset, ATy);
103 
104       return ConstantOffset.getZExtValue();
105     }
106 
107     llvm_unreachable("Expected Val to be a GlobalVariable or GEP");
108   }
109 
110   /// Create a handle for this cbuffer resource using the IRBuilder `Builder`
111   /// and sets the handle as the current one to use for subsequent calls to
112   /// `loadValue`
113   void createAndSetCurrentHandle(IRBuilder<> &Builder) {
114     Handle = Builder.CreateLoad(GVHandle->getValueType(), GVHandle,
115                                 GVHandle->getName());
116   }
117 
118   /// Load a value of type `Ty` at offset `Offset` using the handle from the
119   /// last call to `createAndSetCurrentHandle`
120   Value *loadValue(IRBuilder<> &Builder, Type *Ty, size_t Offset,
121                    const Twine &Name = "") {
122     assert(Handle &&
123            "Expected a handle for this cbuffer global resource to be created "
124            "before loading a value from it");
125     const DataLayout &DL = getDataLayout();
126 
127     size_t TargetOffset = MemberOffset + Offset;
128     CBufferRowIntrin Intrin(DL, Ty->getScalarType());
129     // The cbuffer consists of some number of 16-byte rows.
130     unsigned int CurrentRow = TargetOffset / hlsl::CBufferRowSizeInBytes;
131     unsigned int CurrentIndex =
132         (TargetOffset % hlsl::CBufferRowSizeInBytes) / Intrin.EltSize;
133 
134     auto *CBufLoad = Builder.CreateIntrinsic(
135         Intrin.RetTy, Intrin.IID,
136         {Handle, ConstantInt::get(Builder.getInt32Ty(), CurrentRow)}, nullptr,
137         Name + ".load");
138     auto *Elt = Builder.CreateExtractValue(CBufLoad, {CurrentIndex++},
139                                            Name + ".extract");
140 
141     Value *Result = nullptr;
142     unsigned int Remaining =
143         ((DL.getTypeSizeInBits(Ty) / 8) / Intrin.EltSize) - 1;
144 
145     if (Remaining == 0) {
146       // We only have a single element, so we're done.
147       Result = Elt;
148 
149       // However, if we loaded a <1 x T>, then we need to adjust the type here.
150       if (auto *VT = dyn_cast<FixedVectorType>(Ty)) {
151         assert(VT->getNumElements() == 1 &&
152                "Can't have multiple elements here");
153         Result = Builder.CreateInsertElement(PoisonValue::get(VT), Result,
154                                              Builder.getInt32(0), Name);
155       }
156       return Result;
157     }
158 
159     // Walk each element and extract it, wrapping to new rows as needed.
160     SmallVector<Value *> Extracts{Elt};
161     while (Remaining--) {
162       CurrentIndex %= Intrin.NumElts;
163 
164       if (CurrentIndex == 0)
165         CBufLoad = Builder.CreateIntrinsic(
166             Intrin.RetTy, Intrin.IID,
167             {Handle, ConstantInt::get(Builder.getInt32Ty(), ++CurrentRow)},
168             nullptr, Name + ".load");
169 
170       Extracts.push_back(Builder.CreateExtractValue(CBufLoad, {CurrentIndex++},
171                                                     Name + ".extract"));
172     }
173 
174     // Finally, we build up the original loaded value.
175     Result = PoisonValue::get(Ty);
176     for (int I = 0, E = Extracts.size(); I < E; ++I)
177       Result =
178           Builder.CreateInsertElement(Result, Extracts[I], Builder.getInt32(I),
179                                       Name + formatv(".upto{}", I));
180     return Result;
181   }
182 };
183 
184 } // namespace
185 
186 /// Replace load via cbuffer global with a load from the cbuffer handle itself.
187 static void replaceLoad(LoadInst *LI, CBufferResource &CBR,
188                         SmallVectorImpl<WeakTrackingVH> &DeadInsts) {
189   size_t Offset = CBR.getOffsetForCBufferGEP(LI->getPointerOperand());
190   IRBuilder<> Builder(LI);
191   CBR.createAndSetCurrentHandle(Builder);
192   Value *Result = CBR.loadValue(Builder, LI->getType(), Offset, LI->getName());
193   LI->replaceAllUsesWith(Result);
194   DeadInsts.push_back(LI);
195 }
196 
197 /// This function recursively copies N array elements from the cbuffer resource
198 /// CBR to the MemCpy Destination. Recursion is used to unravel multidimensional
199 /// arrays into a sequence of scalar/vector extracts and stores.
200 static void copyArrayElemsForMemCpy(IRBuilder<> &Builder, MemCpyInst *MCI,
201                                     CBufferResource &CBR, ArrayType *ArrTy,
202                                     size_t ArrOffset, size_t N,
203                                     const Twine &Name = "") {
204   const DataLayout &DL = MCI->getDataLayout();
205   Type *ElemTy = ArrTy->getElementType();
206   size_t ElemTySize = DL.getTypeAllocSize(ElemTy);
207   for (unsigned I = 0; I < N; ++I) {
208     size_t Offset = ArrOffset + I * ElemTySize;
209 
210     // Recursively copy nested arrays
211     if (ArrayType *ElemArrTy = dyn_cast<ArrayType>(ElemTy)) {
212       copyArrayElemsForMemCpy(Builder, MCI, CBR, ElemArrTy, Offset,
213                               ElemArrTy->getNumElements(), Name);
214       continue;
215     }
216 
217     // Load CBuffer value and store it in Dest
218     APInt CBufArrayOffset(
219         DL.getIndexTypeSizeInBits(MCI->getSource()->getType()), Offset);
220     CBufArrayOffset =
221         hlsl::translateCBufArrayOffset(DL, CBufArrayOffset, ArrTy);
222     Value *CBufferVal =
223         CBR.loadValue(Builder, ElemTy, CBufArrayOffset.getZExtValue(), Name);
224     Value *GEP =
225         Builder.CreateInBoundsGEP(Builder.getInt8Ty(), MCI->getDest(),
226                                   {Builder.getInt32(Offset)}, Name + ".dest");
227     Builder.CreateStore(CBufferVal, GEP, MCI->isVolatile());
228   }
229 }
230 
231 /// Replace memcpy from a cbuffer global with a memcpy from the cbuffer handle
232 /// itself. Assumes the cbuffer global is an array, and the length of bytes to
233 /// copy is divisible by array element allocation size.
234 /// The memcpy source must also be a direct cbuffer global reference, not a GEP.
235 static void replaceMemCpy(MemCpyInst *MCI, CBufferResource &CBR) {
236 
237   ArrayType *ArrTy = dyn_cast<ArrayType>(CBR.getValueType());
238   assert(ArrTy && "MemCpy lowering is only supported for array types");
239 
240   // This assumption vastly simplifies the implementation
241   if (MCI->getSource() != CBR.Member)
242     reportFatalUsageError(
243         "Expected MemCpy source to be a cbuffer global variable");
244 
245   ConstantInt *Length = dyn_cast<ConstantInt>(MCI->getLength());
246   uint64_t ByteLength = Length->getZExtValue();
247 
248   // If length to copy is zero, no memcpy is needed
249   if (ByteLength == 0) {
250     MCI->eraseFromParent();
251     return;
252   }
253 
254   const DataLayout &DL = CBR.getDataLayout();
255 
256   Type *ElemTy = ArrTy->getElementType();
257   size_t ElemSize = DL.getTypeAllocSize(ElemTy);
258   assert(ByteLength % ElemSize == 0 &&
259          "Length of bytes to MemCpy must be divisible by allocation size of "
260          "source/destination array elements");
261   size_t ElemsToCpy = ByteLength / ElemSize;
262 
263   IRBuilder<> Builder(MCI);
264   CBR.createAndSetCurrentHandle(Builder);
265 
266   copyArrayElemsForMemCpy(Builder, MCI, CBR, ArrTy, 0, ElemsToCpy,
267                           "memcpy." + MCI->getDest()->getName() + "." +
268                               MCI->getSource()->getName());
269 
270   MCI->eraseFromParent();
271 }
272 
273 static void replaceAccessesWithHandle(CBufferResource &CBR) {
274   SmallVector<WeakTrackingVH> DeadInsts;
275 
276   SmallVector<User *> ToProcess{CBR.users()};
277   while (!ToProcess.empty()) {
278     User *Cur = ToProcess.pop_back_val();
279 
280     // If we have a load instruction, replace the access.
281     if (auto *LI = dyn_cast<LoadInst>(Cur)) {
282       replaceLoad(LI, CBR, DeadInsts);
283       continue;
284     }
285 
286     // If we have a memcpy instruction, replace it with multiple accesses and
287     // subsequent stores to the destination
288     if (auto *MCI = dyn_cast<MemCpyInst>(Cur)) {
289       replaceMemCpy(MCI, CBR);
290       continue;
291     }
292 
293     // Otherwise, walk users looking for a load...
294     if (isa<GetElementPtrInst>(Cur) || isa<GEPOperator>(Cur)) {
295       ToProcess.append(Cur->user_begin(), Cur->user_end());
296       continue;
297     }
298 
299     llvm_unreachable("Unexpected user of Global");
300   }
301   RecursivelyDeleteTriviallyDeadInstructions(DeadInsts);
302 }
303 
304 static bool replaceCBufferAccesses(Module &M) {
305   std::optional<hlsl::CBufferMetadata> CBufMD = hlsl::CBufferMetadata::get(M);
306   if (!CBufMD)
307     return false;
308 
309   for (const hlsl::CBufferMapping &Mapping : *CBufMD)
310     for (const hlsl::CBufferMember &Member : Mapping.Members) {
311       CBufferResource CBR(Mapping.Handle, Member.GV, Member.Offset);
312       replaceAccessesWithHandle(CBR);
313       Member.GV->removeFromParent();
314     }
315 
316   CBufMD->eraseFromModule();
317   return true;
318 }
319 
320 PreservedAnalyses DXILCBufferAccess::run(Module &M, ModuleAnalysisManager &AM) {
321   PreservedAnalyses PA;
322   bool Changed = replaceCBufferAccesses(M);
323 
324   if (!Changed)
325     return PreservedAnalyses::all();
326   return PA;
327 }
328 
329 namespace {
330 class DXILCBufferAccessLegacy : public ModulePass {
331 public:
332   bool runOnModule(Module &M) override { return replaceCBufferAccesses(M); }
333   StringRef getPassName() const override { return "DXIL CBuffer Access"; }
334   DXILCBufferAccessLegacy() : ModulePass(ID) {}
335 
336   static char ID; // Pass identification.
337 };
338 char DXILCBufferAccessLegacy::ID = 0;
339 } // end anonymous namespace
340 
341 INITIALIZE_PASS(DXILCBufferAccessLegacy, DEBUG_TYPE, "DXIL CBuffer Access",
342                 false, false)
343 
344 ModulePass *llvm::createDXILCBufferAccessLegacyPass() {
345   return new DXILCBufferAccessLegacy();
346 }
347