xref: /freebsd/contrib/llvm-project/llvm/lib/Target/DirectX/DXILOpLowering.cpp (revision 1342eb5a832fa10e689a29faab3acb6054e4778c)
1 //===- DXILOpLowering.cpp - Lowering to DXIL operations -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "DXILOpLowering.h"
10 #include "DXILConstants.h"
11 #include "DXILOpBuilder.h"
12 #include "DXILShaderFlags.h"
13 #include "DirectX.h"
14 #include "llvm/ADT/SmallVector.h"
15 #include "llvm/Analysis/DXILMetadataAnalysis.h"
16 #include "llvm/Analysis/DXILResource.h"
17 #include "llvm/CodeGen/Passes.h"
18 #include "llvm/IR/DiagnosticInfo.h"
19 #include "llvm/IR/IRBuilder.h"
20 #include "llvm/IR/Instruction.h"
21 #include "llvm/IR/Instructions.h"
22 #include "llvm/IR/Intrinsics.h"
23 #include "llvm/IR/IntrinsicsDirectX.h"
24 #include "llvm/IR/Module.h"
25 #include "llvm/IR/PassManager.h"
26 #include "llvm/InitializePasses.h"
27 #include "llvm/Pass.h"
28 #include "llvm/Support/ErrorHandling.h"
29 #include "llvm/Support/FormatVariadic.h"
30 
31 #define DEBUG_TYPE "dxil-op-lower"
32 
33 using namespace llvm;
34 using namespace llvm::dxil;
35 
36 namespace {
37 class OpLowerer {
38   Module &M;
39   DXILOpBuilder OpBuilder;
40   DXILResourceMap &DRM;
41   DXILResourceTypeMap &DRTM;
42   const ModuleMetadataInfo &MMDI;
43   SmallVector<CallInst *> CleanupCasts;
44 
45 public:
46   OpLowerer(Module &M, DXILResourceMap &DRM, DXILResourceTypeMap &DRTM,
47             const ModuleMetadataInfo &MMDI)
48       : M(M), OpBuilder(M), DRM(DRM), DRTM(DRTM), MMDI(MMDI) {}
49 
50   /// Replace every call to \c F using \c ReplaceCall, and then erase \c F. If
51   /// there is an error replacing a call, we emit a diagnostic and return true.
52   [[nodiscard]] bool
53   replaceFunction(Function &F,
54                   llvm::function_ref<Error(CallInst *CI)> ReplaceCall) {
55     for (User *U : make_early_inc_range(F.users())) {
56       CallInst *CI = dyn_cast<CallInst>(U);
57       if (!CI)
58         continue;
59 
60       if (Error E = ReplaceCall(CI)) {
61         std::string Message(toString(std::move(E)));
62         M.getContext().diagnose(DiagnosticInfoUnsupported(
63             *CI->getFunction(), Message, CI->getDebugLoc()));
64 
65         return true;
66       }
67     }
68     if (F.user_empty())
69       F.eraseFromParent();
70     return false;
71   }
72 
73   struct IntrinArgSelect {
74     enum class Type {
75 #define DXIL_OP_INTRINSIC_ARG_SELECT_TYPE(name) name,
76 #include "DXILOperation.inc"
77     };
78     Type Type;
79     int Value;
80   };
81 
82   /// Replaces uses of a struct with uses of an equivalent named struct.
83   ///
84   /// DXIL operations that return structs give them well known names, so we need
85   /// to update uses when we switch from an LLVM intrinsic to an op.
86   Error replaceNamedStructUses(CallInst *Intrin, CallInst *DXILOp) {
87     auto *IntrinTy = cast<StructType>(Intrin->getType());
88     auto *DXILOpTy = cast<StructType>(DXILOp->getType());
89     if (!IntrinTy->isLayoutIdentical(DXILOpTy))
90       return make_error<StringError>(
91           "Type mismatch between intrinsic and DXIL op",
92           inconvertibleErrorCode());
93 
94     for (Use &U : make_early_inc_range(Intrin->uses()))
95       if (auto *EVI = dyn_cast<ExtractValueInst>(U.getUser()))
96         EVI->setOperand(0, DXILOp);
97       else if (auto *IVI = dyn_cast<InsertValueInst>(U.getUser()))
98         IVI->setOperand(0, DXILOp);
99       else
100         return make_error<StringError>("DXIL ops that return structs may only "
101                                        "be used by insert- and extractvalue",
102                                        inconvertibleErrorCode());
103     return Error::success();
104   }
105 
106   [[nodiscard]] bool
107   replaceFunctionWithOp(Function &F, dxil::OpCode DXILOp,
108                         ArrayRef<IntrinArgSelect> ArgSelects) {
109     return replaceFunction(F, [&](CallInst *CI) -> Error {
110       OpBuilder.getIRB().SetInsertPoint(CI);
111       SmallVector<Value *> Args;
112       if (ArgSelects.size()) {
113         for (const IntrinArgSelect &A : ArgSelects) {
114           switch (A.Type) {
115           case IntrinArgSelect::Type::Index:
116             Args.push_back(CI->getArgOperand(A.Value));
117             break;
118           case IntrinArgSelect::Type::I8:
119             Args.push_back(OpBuilder.getIRB().getInt8((uint8_t)A.Value));
120             break;
121           case IntrinArgSelect::Type::I32:
122             Args.push_back(OpBuilder.getIRB().getInt32(A.Value));
123             break;
124           }
125         }
126       } else {
127         Args.append(CI->arg_begin(), CI->arg_end());
128       }
129 
130       Expected<CallInst *> OpCall =
131           OpBuilder.tryCreateOp(DXILOp, Args, CI->getName(), F.getReturnType());
132       if (Error E = OpCall.takeError())
133         return E;
134 
135       if (isa<StructType>(CI->getType())) {
136         if (Error E = replaceNamedStructUses(CI, *OpCall))
137           return E;
138       } else
139         CI->replaceAllUsesWith(*OpCall);
140 
141       CI->eraseFromParent();
142       return Error::success();
143     });
144   }
145 
146   /// Create a cast between a `target("dx")` type and `dx.types.Handle`, which
147   /// is intended to be removed by the end of lowering. This is used to allow
148   /// lowering of ops which need to change their return or argument types in a
149   /// piecemeal way - we can add the casts in to avoid updating all of the uses
150   /// or defs, and by the end all of the casts will be redundant.
151   Value *createTmpHandleCast(Value *V, Type *Ty) {
152     CallInst *Cast = OpBuilder.getIRB().CreateIntrinsic(
153         Intrinsic::dx_resource_casthandle, {Ty, V->getType()}, {V});
154     CleanupCasts.push_back(Cast);
155     return Cast;
156   }
157 
158   void cleanupHandleCasts() {
159     SmallVector<CallInst *> ToRemove;
160     SmallVector<Function *> CastFns;
161 
162     for (CallInst *Cast : CleanupCasts) {
163       // These casts were only put in to ease the move from `target("dx")` types
164       // to `dx.types.Handle in a piecemeal way. At this point, all of the
165       // non-cast uses should now be `dx.types.Handle`, and remaining casts
166       // should all form pairs to and from the now unused `target("dx")` type.
167       CastFns.push_back(Cast->getCalledFunction());
168 
169       // If the cast is not to `dx.types.Handle`, it should be the first part of
170       // the pair. Keep track so we can remove it once it has no more uses.
171       if (Cast->getType() != OpBuilder.getHandleType()) {
172         ToRemove.push_back(Cast);
173         continue;
174       }
175       // Otherwise, we're the second handle in a pair. Forward the arguments and
176       // remove the (second) cast.
177       CallInst *Def = cast<CallInst>(Cast->getOperand(0));
178       assert(Def->getIntrinsicID() == Intrinsic::dx_resource_casthandle &&
179              "Unbalanced pair of temporary handle casts");
180       Cast->replaceAllUsesWith(Def->getOperand(0));
181       Cast->eraseFromParent();
182     }
183     for (CallInst *Cast : ToRemove) {
184       assert(Cast->user_empty() && "Temporary handle cast still has users");
185       Cast->eraseFromParent();
186     }
187 
188     // Deduplicate the cast functions so that we only erase each one once.
189     llvm::sort(CastFns);
190     CastFns.erase(llvm::unique(CastFns), CastFns.end());
191     for (Function *F : CastFns)
192       F->eraseFromParent();
193 
194     CleanupCasts.clear();
195   }
196 
197   // Remove the resource global associated with the handleFromBinding call
198   // instruction and their uses as they aren't needed anymore.
199   // TODO: We should verify that all the globals get removed.
200   // It's expected we'll need a custom pass in the future that will eliminate
201   // the need for this here.
202   void removeResourceGlobals(CallInst *CI) {
203     for (User *User : make_early_inc_range(CI->users())) {
204       if (StoreInst *Store = dyn_cast<StoreInst>(User)) {
205         Value *V = Store->getOperand(1);
206         Store->eraseFromParent();
207         if (GlobalVariable *GV = dyn_cast<GlobalVariable>(V))
208           if (GV->use_empty()) {
209             GV->removeDeadConstantUsers();
210             GV->eraseFromParent();
211           }
212       }
213     }
214   }
215 
216   void replaceHandleFromBindingCall(CallInst *CI, Value *Replacement) {
217     assert(CI->getCalledFunction()->getIntrinsicID() ==
218            Intrinsic::dx_resource_handlefrombinding);
219 
220     removeResourceGlobals(CI);
221 
222     auto *NameGlobal = dyn_cast<llvm::GlobalVariable>(CI->getArgOperand(5));
223 
224     CI->replaceAllUsesWith(Replacement);
225     CI->eraseFromParent();
226 
227     if (NameGlobal && NameGlobal->use_empty())
228       NameGlobal->removeFromParent();
229   }
230 
231   [[nodiscard]] bool lowerToCreateHandle(Function &F) {
232     IRBuilder<> &IRB = OpBuilder.getIRB();
233     Type *Int8Ty = IRB.getInt8Ty();
234     Type *Int32Ty = IRB.getInt32Ty();
235 
236     return replaceFunction(F, [&](CallInst *CI) -> Error {
237       IRB.SetInsertPoint(CI);
238 
239       auto *It = DRM.find(CI);
240       assert(It != DRM.end() && "Resource not in map?");
241       dxil::ResourceInfo &RI = *It;
242 
243       const auto &Binding = RI.getBinding();
244       dxil::ResourceClass RC = DRTM[RI.getHandleTy()].getResourceClass();
245 
246       Value *IndexOp = CI->getArgOperand(3);
247       if (Binding.LowerBound != 0)
248         IndexOp = IRB.CreateAdd(IndexOp,
249                                 ConstantInt::get(Int32Ty, Binding.LowerBound));
250 
251       std::array<Value *, 4> Args{
252           ConstantInt::get(Int8Ty, llvm::to_underlying(RC)),
253           ConstantInt::get(Int32Ty, Binding.RecordID), IndexOp,
254           CI->getArgOperand(4)};
255       Expected<CallInst *> OpCall =
256           OpBuilder.tryCreateOp(OpCode::CreateHandle, Args, CI->getName());
257       if (Error E = OpCall.takeError())
258         return E;
259 
260       Value *Cast = createTmpHandleCast(*OpCall, CI->getType());
261       replaceHandleFromBindingCall(CI, Cast);
262       return Error::success();
263     });
264   }
265 
266   [[nodiscard]] bool lowerToBindAndAnnotateHandle(Function &F) {
267     IRBuilder<> &IRB = OpBuilder.getIRB();
268     Type *Int32Ty = IRB.getInt32Ty();
269 
270     return replaceFunction(F, [&](CallInst *CI) -> Error {
271       IRB.SetInsertPoint(CI);
272 
273       auto *It = DRM.find(CI);
274       assert(It != DRM.end() && "Resource not in map?");
275       dxil::ResourceInfo &RI = *It;
276 
277       const auto &Binding = RI.getBinding();
278       dxil::ResourceTypeInfo &RTI = DRTM[RI.getHandleTy()];
279       dxil::ResourceClass RC = RTI.getResourceClass();
280 
281       Value *IndexOp = CI->getArgOperand(3);
282       if (Binding.LowerBound != 0)
283         IndexOp = IRB.CreateAdd(IndexOp,
284                                 ConstantInt::get(Int32Ty, Binding.LowerBound));
285 
286       std::pair<uint32_t, uint32_t> Props =
287           RI.getAnnotateProps(*F.getParent(), RTI);
288 
289       // For `CreateHandleFromBinding` we need the upper bound rather than the
290       // size, so we need to be careful about the difference for "unbounded".
291       uint32_t Unbounded = std::numeric_limits<uint32_t>::max();
292       uint32_t UpperBound = Binding.Size == Unbounded
293                                 ? Unbounded
294                                 : Binding.LowerBound + Binding.Size - 1;
295       Constant *ResBind = OpBuilder.getResBind(Binding.LowerBound, UpperBound,
296                                                Binding.Space, RC);
297       std::array<Value *, 3> BindArgs{ResBind, IndexOp, CI->getArgOperand(4)};
298       Expected<CallInst *> OpBind = OpBuilder.tryCreateOp(
299           OpCode::CreateHandleFromBinding, BindArgs, CI->getName());
300       if (Error E = OpBind.takeError())
301         return E;
302 
303       std::array<Value *, 2> AnnotateArgs{
304           *OpBind, OpBuilder.getResProps(Props.first, Props.second)};
305       Expected<CallInst *> OpAnnotate = OpBuilder.tryCreateOp(
306           OpCode::AnnotateHandle, AnnotateArgs,
307           CI->hasName() ? CI->getName() + "_annot" : Twine());
308       if (Error E = OpAnnotate.takeError())
309         return E;
310 
311       Value *Cast = createTmpHandleCast(*OpAnnotate, CI->getType());
312       replaceHandleFromBindingCall(CI, Cast);
313       return Error::success();
314     });
315   }
316 
317   /// Lower `dx.resource.handlefrombinding` intrinsics depending on the shader
318   /// model and taking into account binding information from
319   /// DXILResourceAnalysis.
320   bool lowerHandleFromBinding(Function &F) {
321     if (MMDI.DXILVersion < VersionTuple(1, 6))
322       return lowerToCreateHandle(F);
323     return lowerToBindAndAnnotateHandle(F);
324   }
325 
326   /// Replace uses of \c Intrin with the values in the `dx.ResRet` of \c Op.
327   /// Since we expect to be post-scalarization, make an effort to avoid vectors.
328   Error replaceResRetUses(CallInst *Intrin, CallInst *Op, bool HasCheckBit) {
329     IRBuilder<> &IRB = OpBuilder.getIRB();
330 
331     Instruction *OldResult = Intrin;
332     Type *OldTy = Intrin->getType();
333 
334     if (HasCheckBit) {
335       auto *ST = cast<StructType>(OldTy);
336 
337       Value *CheckOp = nullptr;
338       Type *Int32Ty = IRB.getInt32Ty();
339       for (Use &U : make_early_inc_range(OldResult->uses())) {
340         if (auto *EVI = dyn_cast<ExtractValueInst>(U.getUser())) {
341           ArrayRef<unsigned> Indices = EVI->getIndices();
342           assert(Indices.size() == 1);
343           // We're only interested in uses of the check bit for now.
344           if (Indices[0] != 1)
345             continue;
346           if (!CheckOp) {
347             Value *NewEVI = IRB.CreateExtractValue(Op, 4);
348             Expected<CallInst *> OpCall = OpBuilder.tryCreateOp(
349                 OpCode::CheckAccessFullyMapped, {NewEVI},
350                 OldResult->hasName() ? OldResult->getName() + "_check"
351                                      : Twine(),
352                 Int32Ty);
353             if (Error E = OpCall.takeError())
354               return E;
355             CheckOp = *OpCall;
356           }
357           EVI->replaceAllUsesWith(CheckOp);
358           EVI->eraseFromParent();
359         }
360       }
361 
362       if (OldResult->use_empty()) {
363         // Only the check bit was used, so we're done here.
364         OldResult->eraseFromParent();
365         return Error::success();
366       }
367 
368       assert(OldResult->hasOneUse() &&
369              isa<ExtractValueInst>(*OldResult->user_begin()) &&
370              "Expected only use to be extract of first element");
371       OldResult = cast<Instruction>(*OldResult->user_begin());
372       OldTy = ST->getElementType(0);
373     }
374 
375     // For scalars, we just extract the first element.
376     if (!isa<FixedVectorType>(OldTy)) {
377       Value *EVI = IRB.CreateExtractValue(Op, 0);
378       OldResult->replaceAllUsesWith(EVI);
379       OldResult->eraseFromParent();
380       if (OldResult != Intrin) {
381         assert(Intrin->use_empty() && "Intrinsic still has uses?");
382         Intrin->eraseFromParent();
383       }
384       return Error::success();
385     }
386 
387     std::array<Value *, 4> Extracts = {};
388     SmallVector<ExtractElementInst *> DynamicAccesses;
389 
390     // The users of the operation should all be scalarized, so we attempt to
391     // replace the extractelements with extractvalues directly.
392     for (Use &U : make_early_inc_range(OldResult->uses())) {
393       if (auto *EEI = dyn_cast<ExtractElementInst>(U.getUser())) {
394         if (auto *IndexOp = dyn_cast<ConstantInt>(EEI->getIndexOperand())) {
395           size_t IndexVal = IndexOp->getZExtValue();
396           assert(IndexVal < 4 && "Index into buffer load out of range");
397           if (!Extracts[IndexVal])
398             Extracts[IndexVal] = IRB.CreateExtractValue(Op, IndexVal);
399           EEI->replaceAllUsesWith(Extracts[IndexVal]);
400           EEI->eraseFromParent();
401         } else {
402           DynamicAccesses.push_back(EEI);
403         }
404       }
405     }
406 
407     const auto *VecTy = cast<FixedVectorType>(OldTy);
408     const unsigned N = VecTy->getNumElements();
409 
410     // If there's a dynamic access we need to round trip through stack memory so
411     // that we don't leave vectors around.
412     if (!DynamicAccesses.empty()) {
413       Type *Int32Ty = IRB.getInt32Ty();
414       Constant *Zero = ConstantInt::get(Int32Ty, 0);
415 
416       Type *ElTy = VecTy->getElementType();
417       Type *ArrayTy = ArrayType::get(ElTy, N);
418       Value *Alloca = IRB.CreateAlloca(ArrayTy);
419 
420       for (int I = 0, E = N; I != E; ++I) {
421         if (!Extracts[I])
422           Extracts[I] = IRB.CreateExtractValue(Op, I);
423         Value *GEP = IRB.CreateInBoundsGEP(
424             ArrayTy, Alloca, {Zero, ConstantInt::get(Int32Ty, I)});
425         IRB.CreateStore(Extracts[I], GEP);
426       }
427 
428       for (ExtractElementInst *EEI : DynamicAccesses) {
429         Value *GEP = IRB.CreateInBoundsGEP(ArrayTy, Alloca,
430                                            {Zero, EEI->getIndexOperand()});
431         Value *Load = IRB.CreateLoad(ElTy, GEP);
432         EEI->replaceAllUsesWith(Load);
433         EEI->eraseFromParent();
434       }
435     }
436 
437     // If we still have uses, then we're not fully scalarized and need to
438     // recreate the vector. This should only happen for things like exported
439     // functions from libraries.
440     if (!OldResult->use_empty()) {
441       for (int I = 0, E = N; I != E; ++I)
442         if (!Extracts[I])
443           Extracts[I] = IRB.CreateExtractValue(Op, I);
444 
445       Value *Vec = PoisonValue::get(OldTy);
446       for (int I = 0, E = N; I != E; ++I)
447         Vec = IRB.CreateInsertElement(Vec, Extracts[I], I);
448       OldResult->replaceAllUsesWith(Vec);
449     }
450 
451     OldResult->eraseFromParent();
452     if (OldResult != Intrin) {
453       assert(Intrin->use_empty() && "Intrinsic still has uses?");
454       Intrin->eraseFromParent();
455     }
456 
457     return Error::success();
458   }
459 
460   [[nodiscard]] bool lowerTypedBufferLoad(Function &F, bool HasCheckBit) {
461     IRBuilder<> &IRB = OpBuilder.getIRB();
462     Type *Int32Ty = IRB.getInt32Ty();
463 
464     return replaceFunction(F, [&](CallInst *CI) -> Error {
465       IRB.SetInsertPoint(CI);
466 
467       Value *Handle =
468           createTmpHandleCast(CI->getArgOperand(0), OpBuilder.getHandleType());
469       Value *Index0 = CI->getArgOperand(1);
470       Value *Index1 = UndefValue::get(Int32Ty);
471 
472       Type *OldTy = CI->getType();
473       if (HasCheckBit)
474         OldTy = cast<StructType>(OldTy)->getElementType(0);
475       Type *NewRetTy = OpBuilder.getResRetType(OldTy->getScalarType());
476 
477       std::array<Value *, 3> Args{Handle, Index0, Index1};
478       Expected<CallInst *> OpCall = OpBuilder.tryCreateOp(
479           OpCode::BufferLoad, Args, CI->getName(), NewRetTy);
480       if (Error E = OpCall.takeError())
481         return E;
482       if (Error E = replaceResRetUses(CI, *OpCall, HasCheckBit))
483         return E;
484 
485       return Error::success();
486     });
487   }
488 
489   [[nodiscard]] bool lowerRawBufferLoad(Function &F) {
490     const DataLayout &DL = F.getDataLayout();
491     IRBuilder<> &IRB = OpBuilder.getIRB();
492     Type *Int8Ty = IRB.getInt8Ty();
493     Type *Int32Ty = IRB.getInt32Ty();
494 
495     return replaceFunction(F, [&](CallInst *CI) -> Error {
496       IRB.SetInsertPoint(CI);
497 
498       Type *OldTy = cast<StructType>(CI->getType())->getElementType(0);
499       Type *ScalarTy = OldTy->getScalarType();
500       Type *NewRetTy = OpBuilder.getResRetType(ScalarTy);
501 
502       Value *Handle =
503           createTmpHandleCast(CI->getArgOperand(0), OpBuilder.getHandleType());
504       Value *Index0 = CI->getArgOperand(1);
505       Value *Index1 = CI->getArgOperand(2);
506       uint64_t NumElements =
507           DL.getTypeSizeInBits(OldTy) / DL.getTypeSizeInBits(ScalarTy);
508       Value *Mask = ConstantInt::get(Int8Ty, ~(~0U << NumElements));
509       Value *Align =
510           ConstantInt::get(Int32Ty, DL.getPrefTypeAlign(ScalarTy).value());
511 
512       Expected<CallInst *> OpCall =
513           MMDI.DXILVersion >= VersionTuple(1, 2)
514               ? OpBuilder.tryCreateOp(OpCode::RawBufferLoad,
515                                       {Handle, Index0, Index1, Mask, Align},
516                                       CI->getName(), NewRetTy)
517               : OpBuilder.tryCreateOp(OpCode::BufferLoad,
518                                       {Handle, Index0, Index1}, CI->getName(),
519                                       NewRetTy);
520       if (Error E = OpCall.takeError())
521         return E;
522       if (Error E = replaceResRetUses(CI, *OpCall, /*HasCheckBit=*/true))
523         return E;
524 
525       return Error::success();
526     });
527   }
528 
529   [[nodiscard]] bool lowerCBufferLoad(Function &F) {
530     IRBuilder<> &IRB = OpBuilder.getIRB();
531 
532     return replaceFunction(F, [&](CallInst *CI) -> Error {
533       IRB.SetInsertPoint(CI);
534 
535       Type *OldTy = cast<StructType>(CI->getType())->getElementType(0);
536       Type *ScalarTy = OldTy->getScalarType();
537       Type *NewRetTy = OpBuilder.getCBufRetType(ScalarTy);
538 
539       Value *Handle =
540           createTmpHandleCast(CI->getArgOperand(0), OpBuilder.getHandleType());
541       Value *Index = CI->getArgOperand(1);
542 
543       Expected<CallInst *> OpCall = OpBuilder.tryCreateOp(
544           OpCode::CBufferLoadLegacy, {Handle, Index}, CI->getName(), NewRetTy);
545       if (Error E = OpCall.takeError())
546         return E;
547       if (Error E = replaceNamedStructUses(CI, *OpCall))
548         return E;
549 
550       CI->eraseFromParent();
551       return Error::success();
552     });
553   }
554 
555   [[nodiscard]] bool lowerUpdateCounter(Function &F) {
556     IRBuilder<> &IRB = OpBuilder.getIRB();
557     Type *Int32Ty = IRB.getInt32Ty();
558 
559     return replaceFunction(F, [&](CallInst *CI) -> Error {
560       IRB.SetInsertPoint(CI);
561       Value *Handle =
562           createTmpHandleCast(CI->getArgOperand(0), OpBuilder.getHandleType());
563       Value *Op1 = CI->getArgOperand(1);
564 
565       std::array<Value *, 2> Args{Handle, Op1};
566 
567       Expected<CallInst *> OpCall = OpBuilder.tryCreateOp(
568           OpCode::UpdateCounter, Args, CI->getName(), Int32Ty);
569 
570       if (Error E = OpCall.takeError())
571         return E;
572 
573       CI->replaceAllUsesWith(*OpCall);
574       CI->eraseFromParent();
575       return Error::success();
576     });
577   }
578 
579   [[nodiscard]] bool lowerGetPointer(Function &F) {
580     // These should have already been handled in DXILResourceAccess, so we can
581     // just clean up the dead prototype.
582     assert(F.user_empty() && "getpointer operations should have been removed");
583     F.eraseFromParent();
584     return false;
585   }
586 
587   [[nodiscard]] bool lowerBufferStore(Function &F, bool IsRaw) {
588     const DataLayout &DL = F.getDataLayout();
589     IRBuilder<> &IRB = OpBuilder.getIRB();
590     Type *Int8Ty = IRB.getInt8Ty();
591     Type *Int32Ty = IRB.getInt32Ty();
592 
593     return replaceFunction(F, [&](CallInst *CI) -> Error {
594       IRB.SetInsertPoint(CI);
595 
596       Value *Handle =
597           createTmpHandleCast(CI->getArgOperand(0), OpBuilder.getHandleType());
598       Value *Index0 = CI->getArgOperand(1);
599       Value *Index1 = IsRaw ? CI->getArgOperand(2) : UndefValue::get(Int32Ty);
600 
601       Value *Data = CI->getArgOperand(IsRaw ? 3 : 2);
602       Type *DataTy = Data->getType();
603       Type *ScalarTy = DataTy->getScalarType();
604 
605       uint64_t NumElements =
606           DL.getTypeSizeInBits(DataTy) / DL.getTypeSizeInBits(ScalarTy);
607       Value *Mask =
608           ConstantInt::get(Int8Ty, IsRaw ? ~(~0U << NumElements) : 15U);
609 
610       // TODO: check that we only have vector or scalar...
611       if (NumElements > 4)
612         return make_error<StringError>(
613             "Buffer store data must have at most 4 elements",
614             inconvertibleErrorCode());
615 
616       std::array<Value *, 4> DataElements{nullptr, nullptr, nullptr, nullptr};
617       if (DataTy == ScalarTy)
618         DataElements[0] = Data;
619       else {
620         // Since we're post-scalarizer, if we see a vector here it's likely
621         // constructed solely for the argument of the store. Just use the scalar
622         // values from before they're inserted into the temporary.
623         auto *IEI = dyn_cast<InsertElementInst>(Data);
624         while (IEI) {
625           auto *IndexOp = dyn_cast<ConstantInt>(IEI->getOperand(2));
626           if (!IndexOp)
627             break;
628           size_t IndexVal = IndexOp->getZExtValue();
629           assert(IndexVal < 4 && "Too many elements for buffer store");
630           DataElements[IndexVal] = IEI->getOperand(1);
631           IEI = dyn_cast<InsertElementInst>(IEI->getOperand(0));
632         }
633       }
634 
635       // If for some reason we weren't able to forward the arguments from the
636       // scalarizer artifact, then we may need to actually extract elements from
637       // the vector.
638       for (int I = 0, E = NumElements; I < E; ++I)
639         if (DataElements[I] == nullptr)
640           DataElements[I] =
641               IRB.CreateExtractElement(Data, ConstantInt::get(Int32Ty, I));
642 
643       // For any elements beyond the length of the vector, we should fill it up
644       // with undef - however, for typed buffers we repeat the first element to
645       // match DXC.
646       for (int I = NumElements, E = 4; I < E; ++I)
647         if (DataElements[I] == nullptr)
648           DataElements[I] = IsRaw ? UndefValue::get(ScalarTy) : DataElements[0];
649 
650       dxil::OpCode Op = OpCode::BufferStore;
651       SmallVector<Value *, 9> Args{
652           Handle,          Index0,          Index1,          DataElements[0],
653           DataElements[1], DataElements[2], DataElements[3], Mask};
654       if (IsRaw && MMDI.DXILVersion >= VersionTuple(1, 2)) {
655         Op = OpCode::RawBufferStore;
656         // RawBufferStore requires the alignment
657         Args.push_back(
658             ConstantInt::get(Int32Ty, DL.getPrefTypeAlign(ScalarTy).value()));
659       }
660       Expected<CallInst *> OpCall =
661           OpBuilder.tryCreateOp(Op, Args, CI->getName());
662       if (Error E = OpCall.takeError())
663         return E;
664 
665       CI->eraseFromParent();
666       // Clean up any leftover `insertelement`s
667       auto *IEI = dyn_cast<InsertElementInst>(Data);
668       while (IEI && IEI->use_empty()) {
669         InsertElementInst *Tmp = IEI;
670         IEI = dyn_cast<InsertElementInst>(IEI->getOperand(0));
671         Tmp->eraseFromParent();
672       }
673 
674       return Error::success();
675     });
676   }
677 
678   [[nodiscard]] bool lowerCtpopToCountBits(Function &F) {
679     IRBuilder<> &IRB = OpBuilder.getIRB();
680     Type *Int32Ty = IRB.getInt32Ty();
681 
682     return replaceFunction(F, [&](CallInst *CI) -> Error {
683       IRB.SetInsertPoint(CI);
684       SmallVector<Value *> Args;
685       Args.append(CI->arg_begin(), CI->arg_end());
686 
687       Type *RetTy = Int32Ty;
688       Type *FRT = F.getReturnType();
689       if (const auto *VT = dyn_cast<VectorType>(FRT))
690         RetTy = VectorType::get(RetTy, VT);
691 
692       Expected<CallInst *> OpCall = OpBuilder.tryCreateOp(
693           dxil::OpCode::CountBits, Args, CI->getName(), RetTy);
694       if (Error E = OpCall.takeError())
695         return E;
696 
697       // If the result type is 32 bits we can do a direct replacement.
698       if (FRT->isIntOrIntVectorTy(32)) {
699         CI->replaceAllUsesWith(*OpCall);
700         CI->eraseFromParent();
701         return Error::success();
702       }
703 
704       unsigned CastOp;
705       unsigned CastOp2;
706       if (FRT->isIntOrIntVectorTy(16)) {
707         CastOp = Instruction::ZExt;
708         CastOp2 = Instruction::SExt;
709       } else { // must be 64 bits
710         assert(FRT->isIntOrIntVectorTy(64) &&
711                "Currently only lowering 16, 32, or 64 bit ctpop to CountBits \
712                 is supported.");
713         CastOp = Instruction::Trunc;
714         CastOp2 = Instruction::Trunc;
715       }
716 
717       // It is correct to replace the ctpop with the dxil op and
718       // remove all casts to i32
719       bool NeedsCast = false;
720       for (User *User : make_early_inc_range(CI->users())) {
721         Instruction *I = dyn_cast<Instruction>(User);
722         if (I && (I->getOpcode() == CastOp || I->getOpcode() == CastOp2) &&
723             I->getType() == RetTy) {
724           I->replaceAllUsesWith(*OpCall);
725           I->eraseFromParent();
726         } else
727           NeedsCast = true;
728       }
729 
730       // It is correct to replace a ctpop with the dxil op and
731       // a cast from i32 to the return type of the ctpop
732       // the cast is emitted here if there is a non-cast to i32
733       // instr which uses the ctpop
734       if (NeedsCast) {
735         Value *Cast =
736             IRB.CreateZExtOrTrunc(*OpCall, F.getReturnType(), "ctpop.cast");
737         CI->replaceAllUsesWith(Cast);
738       }
739 
740       CI->eraseFromParent();
741       return Error::success();
742     });
743   }
744 
745   [[nodiscard]] bool lowerLifetimeIntrinsic(Function &F) {
746     IRBuilder<> &IRB = OpBuilder.getIRB();
747     return replaceFunction(F, [&](CallInst *CI) -> Error {
748       IRB.SetInsertPoint(CI);
749       Value *Ptr = CI->getArgOperand(1);
750       assert(Ptr->getType()->isPointerTy() &&
751              "Expected operand of lifetime intrinsic to be a pointer");
752 
753       auto ZeroOrUndef = [&](Type *Ty) {
754         return MMDI.ValidatorVersion < VersionTuple(1, 6)
755                    ? Constant::getNullValue(Ty)
756                    : UndefValue::get(Ty);
757       };
758 
759       Value *Val = nullptr;
760       if (auto *GV = dyn_cast<GlobalVariable>(Ptr)) {
761         if (GV->hasInitializer() || GV->isExternallyInitialized())
762           return Error::success();
763         Val = ZeroOrUndef(GV->getValueType());
764       } else if (auto *AI = dyn_cast<AllocaInst>(Ptr))
765         Val = ZeroOrUndef(AI->getAllocatedType());
766 
767       assert(Val && "Expected operand of lifetime intrinsic to be a global "
768                     "variable or alloca instruction");
769       IRB.CreateStore(Val, Ptr, false);
770 
771       CI->eraseFromParent();
772       return Error::success();
773     });
774   }
775 
776   [[nodiscard]] bool lowerIsFPClass(Function &F) {
777     IRBuilder<> &IRB = OpBuilder.getIRB();
778     Type *RetTy = IRB.getInt1Ty();
779 
780     return replaceFunction(F, [&](CallInst *CI) -> Error {
781       IRB.SetInsertPoint(CI);
782       SmallVector<Value *> Args;
783       Value *Fl = CI->getArgOperand(0);
784       Args.push_back(Fl);
785 
786       dxil::OpCode OpCode;
787       Value *T = CI->getArgOperand(1);
788       auto *TCI = dyn_cast<ConstantInt>(T);
789       switch (TCI->getZExtValue()) {
790       case FPClassTest::fcInf:
791         OpCode = dxil::OpCode::IsInf;
792         break;
793       case FPClassTest::fcNan:
794         OpCode = dxil::OpCode::IsNaN;
795         break;
796       case FPClassTest::fcNormal:
797         OpCode = dxil::OpCode::IsNormal;
798         break;
799       case FPClassTest::fcFinite:
800         OpCode = dxil::OpCode::IsFinite;
801         break;
802       default:
803         SmallString<128> Msg =
804             formatv("Unsupported FPClassTest {0} for DXIL Op Lowering",
805                     TCI->getZExtValue());
806         return make_error<StringError>(Msg, inconvertibleErrorCode());
807       }
808 
809       Expected<CallInst *> OpCall =
810           OpBuilder.tryCreateOp(OpCode, Args, CI->getName(), RetTy);
811       if (Error E = OpCall.takeError())
812         return E;
813 
814       CI->replaceAllUsesWith(*OpCall);
815       CI->eraseFromParent();
816       return Error::success();
817     });
818   }
819 
820   bool lowerIntrinsics() {
821     bool Updated = false;
822     bool HasErrors = false;
823 
824     for (Function &F : make_early_inc_range(M.functions())) {
825       if (!F.isDeclaration())
826         continue;
827       Intrinsic::ID ID = F.getIntrinsicID();
828       switch (ID) {
829       // NOTE: Skip dx_resource_casthandle here. They are
830       // resolved after this loop in cleanupHandleCasts.
831       case Intrinsic::dx_resource_casthandle:
832       // NOTE: llvm.dbg.value is supported as is in DXIL.
833       case Intrinsic::dbg_value:
834       case Intrinsic::not_intrinsic:
835         if (F.use_empty())
836           F.eraseFromParent();
837         continue;
838       default:
839         if (F.use_empty())
840           F.eraseFromParent();
841         else {
842           SmallString<128> Msg = formatv(
843               "Unsupported intrinsic {0} for DXIL lowering", F.getName());
844           M.getContext().emitError(Msg);
845           HasErrors |= true;
846         }
847         break;
848 
849 #define DXIL_OP_INTRINSIC(OpCode, Intrin, ...)                                 \
850   case Intrin:                                                                 \
851     HasErrors |= replaceFunctionWithOp(                                        \
852         F, OpCode, ArrayRef<IntrinArgSelect>{__VA_ARGS__});                    \
853     break;
854 #include "DXILOperation.inc"
855       case Intrinsic::dx_resource_handlefrombinding:
856         HasErrors |= lowerHandleFromBinding(F);
857         break;
858       case Intrinsic::dx_resource_getpointer:
859         HasErrors |= lowerGetPointer(F);
860         break;
861       case Intrinsic::dx_resource_load_typedbuffer:
862         HasErrors |= lowerTypedBufferLoad(F, /*HasCheckBit=*/true);
863         break;
864       case Intrinsic::dx_resource_store_typedbuffer:
865         HasErrors |= lowerBufferStore(F, /*IsRaw=*/false);
866         break;
867       case Intrinsic::dx_resource_load_rawbuffer:
868         HasErrors |= lowerRawBufferLoad(F);
869         break;
870       case Intrinsic::dx_resource_store_rawbuffer:
871         HasErrors |= lowerBufferStore(F, /*IsRaw=*/true);
872         break;
873       case Intrinsic::dx_resource_load_cbufferrow_2:
874       case Intrinsic::dx_resource_load_cbufferrow_4:
875       case Intrinsic::dx_resource_load_cbufferrow_8:
876         HasErrors |= lowerCBufferLoad(F);
877         break;
878       case Intrinsic::dx_resource_updatecounter:
879         HasErrors |= lowerUpdateCounter(F);
880         break;
881       case Intrinsic::ctpop:
882         HasErrors |= lowerCtpopToCountBits(F);
883         break;
884       case Intrinsic::lifetime_start:
885       case Intrinsic::lifetime_end:
886         if (F.use_empty())
887           F.eraseFromParent();
888         else {
889           if (MMDI.DXILVersion < VersionTuple(1, 6))
890             HasErrors |= lowerLifetimeIntrinsic(F);
891           else
892             continue;
893         }
894         break;
895       case Intrinsic::is_fpclass:
896         HasErrors |= lowerIsFPClass(F);
897         break;
898       }
899       Updated = true;
900     }
901     if (Updated && !HasErrors)
902       cleanupHandleCasts();
903 
904     return Updated;
905   }
906 };
907 } // namespace
908 
909 PreservedAnalyses DXILOpLowering::run(Module &M, ModuleAnalysisManager &MAM) {
910   DXILResourceMap &DRM = MAM.getResult<DXILResourceAnalysis>(M);
911   DXILResourceTypeMap &DRTM = MAM.getResult<DXILResourceTypeAnalysis>(M);
912   const ModuleMetadataInfo MMDI = MAM.getResult<DXILMetadataAnalysis>(M);
913 
914   const bool MadeChanges = OpLowerer(M, DRM, DRTM, MMDI).lowerIntrinsics();
915   if (!MadeChanges)
916     return PreservedAnalyses::all();
917   PreservedAnalyses PA;
918   PA.preserve<DXILResourceAnalysis>();
919   PA.preserve<DXILMetadataAnalysis>();
920   PA.preserve<ShaderFlagsAnalysis>();
921   return PA;
922 }
923 
924 namespace {
925 class DXILOpLoweringLegacy : public ModulePass {
926 public:
927   bool runOnModule(Module &M) override {
928     DXILResourceMap &DRM =
929         getAnalysis<DXILResourceWrapperPass>().getResourceMap();
930     DXILResourceTypeMap &DRTM =
931         getAnalysis<DXILResourceTypeWrapperPass>().getResourceTypeMap();
932     const ModuleMetadataInfo MMDI =
933         getAnalysis<DXILMetadataAnalysisWrapperPass>().getModuleMetadata();
934 
935     return OpLowerer(M, DRM, DRTM, MMDI).lowerIntrinsics();
936   }
937   StringRef getPassName() const override { return "DXIL Op Lowering"; }
938   DXILOpLoweringLegacy() : ModulePass(ID) {}
939 
940   static char ID; // Pass identification.
941   void getAnalysisUsage(llvm::AnalysisUsage &AU) const override {
942     AU.addRequired<DXILResourceTypeWrapperPass>();
943     AU.addRequired<DXILResourceWrapperPass>();
944     AU.addRequired<DXILMetadataAnalysisWrapperPass>();
945     AU.addPreserved<DXILResourceWrapperPass>();
946     AU.addPreserved<DXILMetadataAnalysisWrapperPass>();
947     AU.addPreserved<ShaderFlagsAnalysisWrapper>();
948   }
949 };
950 char DXILOpLoweringLegacy::ID = 0;
951 } // end anonymous namespace
952 
953 INITIALIZE_PASS_BEGIN(DXILOpLoweringLegacy, DEBUG_TYPE, "DXIL Op Lowering",
954                       false, false)
955 INITIALIZE_PASS_DEPENDENCY(DXILResourceTypeWrapperPass)
956 INITIALIZE_PASS_DEPENDENCY(DXILResourceWrapperPass)
957 INITIALIZE_PASS_END(DXILOpLoweringLegacy, DEBUG_TYPE, "DXIL Op Lowering", false,
958                     false)
959 
960 ModulePass *llvm::createDXILOpLoweringLegacyPass() {
961   return new DXILOpLoweringLegacy();
962 }
963