1 //===- DXILOpLowering.cpp - Lowering to DXIL operations -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "DXILOpLowering.h"
10 #include "DXILConstants.h"
11 #include "DXILOpBuilder.h"
12 #include "DXILShaderFlags.h"
13 #include "DirectX.h"
14 #include "llvm/ADT/SmallVector.h"
15 #include "llvm/Analysis/DXILMetadataAnalysis.h"
16 #include "llvm/Analysis/DXILResource.h"
17 #include "llvm/CodeGen/Passes.h"
18 #include "llvm/IR/DiagnosticInfo.h"
19 #include "llvm/IR/IRBuilder.h"
20 #include "llvm/IR/Instruction.h"
21 #include "llvm/IR/Instructions.h"
22 #include "llvm/IR/Intrinsics.h"
23 #include "llvm/IR/IntrinsicsDirectX.h"
24 #include "llvm/IR/Module.h"
25 #include "llvm/IR/PassManager.h"
26 #include "llvm/InitializePasses.h"
27 #include "llvm/Pass.h"
28 #include "llvm/Support/ErrorHandling.h"
29 #include "llvm/Support/FormatVariadic.h"
30
31 #define DEBUG_TYPE "dxil-op-lower"
32
33 using namespace llvm;
34 using namespace llvm::dxil;
35
36 namespace {
class OpLowerer {
  /// The module whose intrinsics are being lowered.
  Module &M;
  /// Builder used to create DXIL ops and their special types (handles,
  /// ResRet/CBufRet structs, resource binding/property constants).
  DXILOpBuilder OpBuilder;
  /// Per-call resource info from DXILResourceAnalysis (bindings, record IDs).
  DXILResourceMap &DRM;
  /// Resource type info from DXILResourceTypeAnalysis (resource classes).
  DXILResourceTypeMap &DRTM;
  /// Module-level metadata; consulted for DXILVersion and ValidatorVersion.
  const ModuleMetadataInfo &MMDI;
  /// Temporary `dx.resource.casthandle` calls inserted during lowering;
  /// removed in pairs by cleanupHandleCasts() once lowering is done.
  SmallVector<CallInst *> CleanupCasts;

public:
  OpLowerer(Module &M, DXILResourceMap &DRM, DXILResourceTypeMap &DRTM,
            const ModuleMetadataInfo &MMDI)
      : M(M), OpBuilder(M), DRM(DRM), DRTM(DRTM), MMDI(MMDI) {}
49
50 /// Replace every call to \c F using \c ReplaceCall, and then erase \c F. If
51 /// there is an error replacing a call, we emit a diagnostic and return true.
  /// Replace every call to \c F using \c ReplaceCall, and then erase \c F. If
  /// there is an error replacing a call, we emit a diagnostic and return true.
  ///
  /// \c ReplaceCall is expected to erase the call it is given on success, so
  /// we iterate with make_early_inc_range to survive the erasure. \c F is
  /// only erased if no users remain (non-call users would keep it alive).
  [[nodiscard]] bool
  replaceFunction(Function &F,
                  llvm::function_ref<Error(CallInst *CI)> ReplaceCall) {
    for (User *U : make_early_inc_range(F.users())) {
      CallInst *CI = dyn_cast<CallInst>(U);
      if (!CI)
        continue;

      if (Error E = ReplaceCall(CI)) {
        std::string Message(toString(std::move(E)));
        M.getContext().diagnose(DiagnosticInfoUnsupported(
            *CI->getFunction(), Message, CI->getDebugLoc()));

        // Stop at the first failure; the caller treats `true` as "errors
        // were emitted".
        return true;
      }
    }
    if (F.user_empty())
      F.eraseFromParent();
    return false;
  }
72
  /// Describes how one DXIL op argument is produced from an intrinsic call:
  /// either forwarded from an operand of the call, or materialized as an
  /// immediate constant (see the switch in replaceFunctionWithOp).
  struct IntrinArgSelect {
    enum class Type {
#define DXIL_OP_INTRINSIC_ARG_SELECT_TYPE(name) name,
#include "DXILOperation.inc"
    };
    Type Type;  // Selection kind (Index / I8 / I32, from DXILOperation.inc).
    int Value;  // Operand index for Index, or the immediate value for I8/I32.
  };
81
82 /// Replaces uses of a struct with uses of an equivalent named struct.
83 ///
84 /// DXIL operations that return structs give them well known names, so we need
85 /// to update uses when we switch from an LLVM intrinsic to an op.
  /// Replaces uses of a struct with uses of an equivalent named struct.
  ///
  /// DXIL operations that return structs give them well known names, so we
  /// need to update uses when we switch from an LLVM intrinsic to an op.
  ///
  /// The two struct types must be layout-identical, and all users of
  /// \c Intrin must be insertvalue/extractvalue instructions, which are
  /// retargeted in place to read from \c DXILOp instead. Note this does not
  /// erase \c Intrin; the caller does that.
  Error replaceNamedStructUses(CallInst *Intrin, CallInst *DXILOp) {
    auto *IntrinTy = cast<StructType>(Intrin->getType());
    auto *DXILOpTy = cast<StructType>(DXILOp->getType());
    if (!IntrinTy->isLayoutIdentical(DXILOpTy))
      return make_error<StringError>(
          "Type mismatch between intrinsic and DXIL op",
          inconvertibleErrorCode());

    // setOperand removes the use from Intrin's use list, so iterate with
    // early-increment semantics.
    for (Use &U : make_early_inc_range(Intrin->uses()))
      if (auto *EVI = dyn_cast<ExtractValueInst>(U.getUser()))
        EVI->setOperand(0, DXILOp);
      else if (auto *IVI = dyn_cast<InsertValueInst>(U.getUser()))
        IVI->setOperand(0, DXILOp);
      else
        return make_error<StringError>("DXIL ops that return structs may only "
                                       "be used by insert- and extractvalue",
                                       inconvertibleErrorCode());
    return Error::success();
  }
105
  /// Replace all calls to \c F with the DXIL op \c DXILOp and erase \c F.
  ///
  /// If \c ArgSelects is empty the intrinsic's operands are forwarded
  /// verbatim; otherwise each DXIL op argument is built per the selector
  /// (forwarded operand or i8/i32 immediate). Returns true on error, matching
  /// replaceFunction.
  [[nodiscard]] bool
  replaceFunctionWithOp(Function &F, dxil::OpCode DXILOp,
                        ArrayRef<IntrinArgSelect> ArgSelects) {
    return replaceFunction(F, [&](CallInst *CI) -> Error {
      OpBuilder.getIRB().SetInsertPoint(CI);
      SmallVector<Value *> Args;
      if (ArgSelects.size()) {
        for (const IntrinArgSelect &A : ArgSelects) {
          switch (A.Type) {
          case IntrinArgSelect::Type::Index:
            // Forward the call operand at index A.Value.
            Args.push_back(CI->getArgOperand(A.Value));
            break;
          case IntrinArgSelect::Type::I8:
            Args.push_back(OpBuilder.getIRB().getInt8((uint8_t)A.Value));
            break;
          case IntrinArgSelect::Type::I32:
            Args.push_back(OpBuilder.getIRB().getInt32(A.Value));
            break;
          }
        }
      } else {
        Args.append(CI->arg_begin(), CI->arg_end());
      }

      Expected<CallInst *> OpCall =
          OpBuilder.tryCreateOp(DXILOp, Args, CI->getName(), F.getReturnType());
      if (Error E = OpCall.takeError())
        return E;

      // Struct returns need their uses rewritten onto the op's well-known
      // named struct; scalar returns can be RAUW'd directly.
      if (isa<StructType>(CI->getType())) {
        if (Error E = replaceNamedStructUses(CI, *OpCall))
          return E;
      } else
        CI->replaceAllUsesWith(*OpCall);

      CI->eraseFromParent();
      return Error::success();
    });
  }
145
146 /// Create a cast between a `target("dx")` type and `dx.types.Handle`, which
147 /// is intended to be removed by the end of lowering. This is used to allow
148 /// lowering of ops which need to change their return or argument types in a
149 /// piecemeal way - we can add the casts in to avoid updating all of the uses
150 /// or defs, and by the end all of the casts will be redundant.
createTmpHandleCast(Value * V,Type * Ty)151 Value *createTmpHandleCast(Value *V, Type *Ty) {
152 CallInst *Cast = OpBuilder.getIRB().CreateIntrinsic(
153 Intrinsic::dx_resource_casthandle, {Ty, V->getType()}, {V});
154 CleanupCasts.push_back(Cast);
155 return Cast;
156 }
157
  /// Remove all of the temporary handle casts created by createTmpHandleCast.
  /// By this point every cast should occur in a to/from pair; the "to handle"
  /// cast of each pair is forwarded to the original value and erased, after
  /// which the "from handle" casts (and the intrinsic declarations) are dead.
  void cleanupHandleCasts() {
    SmallVector<CallInst *> ToRemove;
    SmallVector<Function *> CastFns;

    for (CallInst *Cast : CleanupCasts) {
      // These casts were only put in to ease the move from `target("dx")`
      // types to `dx.types.Handle in a piecemeal way. At this point, all of
      // the non-cast uses should now be `dx.types.Handle`, and remaining
      // casts should all form pairs to and from the now unused `target("dx")`
      // type.
      CastFns.push_back(Cast->getCalledFunction());

      // If the cast is not to `dx.types.Handle`, it should be the first part
      // of the pair. Keep track so we can remove it once it has no more uses.
      if (Cast->getType() != OpBuilder.getHandleType()) {
        ToRemove.push_back(Cast);
        continue;
      }
      // Otherwise, we're the second handle in a pair. Forward the arguments
      // and remove the (second) cast.
      CallInst *Def = cast<CallInst>(Cast->getOperand(0));
      assert(Def->getIntrinsicID() == Intrinsic::dx_resource_casthandle &&
             "Unbalanced pair of temporary handle casts");
      Cast->replaceAllUsesWith(Def->getOperand(0));
      Cast->eraseFromParent();
    }
    // The first-of-pair casts must be erased after the loop: their users are
    // the second-of-pair casts removed above.
    for (CallInst *Cast : ToRemove) {
      assert(Cast->user_empty() && "Temporary handle cast still has users");
      Cast->eraseFromParent();
    }

    // Deduplicate the cast functions so that we only erase each one once.
    llvm::sort(CastFns);
    CastFns.erase(llvm::unique(CastFns), CastFns.end());
    for (Function *F : CastFns)
      F->eraseFromParent();

    CleanupCasts.clear();
  }
196
197 // Remove the resource global associated with the handleFromBinding call
198 // instruction and their uses as they aren't needed anymore.
199 // TODO: We should verify that all the globals get removed.
200 // It's expected we'll need a custom pass in the future that will eliminate
201 // the need for this here.
removeResourceGlobals(CallInst * CI)202 void removeResourceGlobals(CallInst *CI) {
203 for (User *User : make_early_inc_range(CI->users())) {
204 if (StoreInst *Store = dyn_cast<StoreInst>(User)) {
205 Value *V = Store->getOperand(1);
206 Store->eraseFromParent();
207 if (GlobalVariable *GV = dyn_cast<GlobalVariable>(V))
208 if (GV->use_empty()) {
209 GV->removeDeadConstantUsers();
210 GV->eraseFromParent();
211 }
212 }
213 }
214 }
215
replaceHandleFromBindingCall(CallInst * CI,Value * Replacement)216 void replaceHandleFromBindingCall(CallInst *CI, Value *Replacement) {
217 assert(CI->getCalledFunction()->getIntrinsicID() ==
218 Intrinsic::dx_resource_handlefrombinding);
219
220 removeResourceGlobals(CI);
221
222 auto *NameGlobal = dyn_cast<llvm::GlobalVariable>(CI->getArgOperand(5));
223
224 CI->replaceAllUsesWith(Replacement);
225 CI->eraseFromParent();
226
227 if (NameGlobal && NameGlobal->use_empty())
228 NameGlobal->removeFromParent();
229 }
230
  /// Lower `dx.resource.handlefrombinding` calls to the CreateHandle DXIL op
  /// (used when the DXIL version is below 1.6 — see lowerHandleFromBinding).
  /// Uses the resource class and record ID from the resource analyses.
  [[nodiscard]] bool lowerToCreateHandle(Function &F) {
    IRBuilder<> &IRB = OpBuilder.getIRB();
    Type *Int8Ty = IRB.getInt8Ty();
    Type *Int32Ty = IRB.getInt32Ty();

    return replaceFunction(F, [&](CallInst *CI) -> Error {
      IRB.SetInsertPoint(CI);

      auto *It = DRM.find(CI);
      assert(It != DRM.end() && "Resource not in map?");
      dxil::ResourceInfo &RI = *It;

      const auto &Binding = RI.getBinding();
      dxil::ResourceClass RC = DRTM[RI.getHandleTy()].getResourceClass();

      // The intrinsic's index is relative to the binding's lower bound, so
      // rebase it when the binding doesn't start at 0.
      Value *IndexOp = CI->getArgOperand(3);
      if (Binding.LowerBound != 0)
        IndexOp = IRB.CreateAdd(IndexOp,
                                ConstantInt::get(Int32Ty, Binding.LowerBound));

      std::array<Value *, 4> Args{
          ConstantInt::get(Int8Ty, llvm::to_underlying(RC)),
          ConstantInt::get(Int32Ty, Binding.RecordID), IndexOp,
          CI->getArgOperand(4)};
      Expected<CallInst *> OpCall =
          OpBuilder.tryCreateOp(OpCode::CreateHandle, Args, CI->getName());
      if (Error E = OpCall.takeError())
        return E;

      // Cast back to the intrinsic's target("dx") type so existing users
      // still typecheck; the cast pair is removed in cleanupHandleCasts.
      Value *Cast = createTmpHandleCast(*OpCall, CI->getType());
      replaceHandleFromBindingCall(CI, Cast);
      return Error::success();
    });
  }
265
  /// Lower `dx.resource.handlefrombinding` calls to a CreateHandleFromBinding
  /// op followed by AnnotateHandle (the DXIL 1.6+ form — see
  /// lowerHandleFromBinding).
  [[nodiscard]] bool lowerToBindAndAnnotateHandle(Function &F) {
    IRBuilder<> &IRB = OpBuilder.getIRB();
    Type *Int32Ty = IRB.getInt32Ty();

    return replaceFunction(F, [&](CallInst *CI) -> Error {
      IRB.SetInsertPoint(CI);

      auto *It = DRM.find(CI);
      assert(It != DRM.end() && "Resource not in map?");
      dxil::ResourceInfo &RI = *It;

      const auto &Binding = RI.getBinding();
      dxil::ResourceTypeInfo &RTI = DRTM[RI.getHandleTy()];
      dxil::ResourceClass RC = RTI.getResourceClass();

      // As in lowerToCreateHandle, rebase the binding-relative index.
      Value *IndexOp = CI->getArgOperand(3);
      if (Binding.LowerBound != 0)
        IndexOp = IRB.CreateAdd(IndexOp,
                                ConstantInt::get(Int32Ty, Binding.LowerBound));

      std::pair<uint32_t, uint32_t> Props =
          RI.getAnnotateProps(*F.getParent(), RTI);

      // For `CreateHandleFromBinding` we need the upper bound rather than the
      // size, so we need to be careful about the difference for "unbounded".
      uint32_t Unbounded = std::numeric_limits<uint32_t>::max();
      uint32_t UpperBound = Binding.Size == Unbounded
                                ? Unbounded
                                : Binding.LowerBound + Binding.Size - 1;
      Constant *ResBind = OpBuilder.getResBind(Binding.LowerBound, UpperBound,
                                               Binding.Space, RC);
      std::array<Value *, 3> BindArgs{ResBind, IndexOp, CI->getArgOperand(4)};
      Expected<CallInst *> OpBind = OpBuilder.tryCreateOp(
          OpCode::CreateHandleFromBinding, BindArgs, CI->getName());
      if (Error E = OpBind.takeError())
        return E;

      // Annotate the bound handle with the resource's property constants.
      std::array<Value *, 2> AnnotateArgs{
          *OpBind, OpBuilder.getResProps(Props.first, Props.second)};
      Expected<CallInst *> OpAnnotate = OpBuilder.tryCreateOp(
          OpCode::AnnotateHandle, AnnotateArgs,
          CI->hasName() ? CI->getName() + "_annot" : Twine());
      if (Error E = OpAnnotate.takeError())
        return E;

      // Cast back to the intrinsic's target("dx") type; the cast pair is
      // removed in cleanupHandleCasts.
      Value *Cast = createTmpHandleCast(*OpAnnotate, CI->getType());
      replaceHandleFromBindingCall(CI, Cast);
      return Error::success();
    });
  }
316
317 /// Lower `dx.resource.handlefrombinding` intrinsics depending on the shader
318 /// model and taking into account binding information from
319 /// DXILResourceAnalysis.
lowerHandleFromBinding(Function & F)320 bool lowerHandleFromBinding(Function &F) {
321 if (MMDI.DXILVersion < VersionTuple(1, 6))
322 return lowerToCreateHandle(F);
323 return lowerToBindAndAnnotateHandle(F);
324 }
325
326 /// Replace uses of \c Intrin with the values in the `dx.ResRet` of \c Op.
327 /// Since we expect to be post-scalarization, make an effort to avoid vectors.
  /// Replace uses of \c Intrin with the values in the `dx.ResRet` of \c Op.
  /// Since we expect to be post-scalarization, make an effort to avoid
  /// vectors.
  ///
  /// When \c HasCheckBit is set, \c Intrin returns a {value, i1} struct:
  /// extracts of element 1 are rewritten to a CheckAccessFullyMapped op fed
  /// by element 4 of the ResRet, and the remaining single extract of element
  /// 0 becomes the value to replace. Scalar results use element 0 of the
  /// ResRet; vector results have their extractelements forwarded to
  /// extractvalues, with dynamic indices round-tripped through a stack array
  /// and any remaining whole-vector uses rebuilt via insertelement.
  Error replaceResRetUses(CallInst *Intrin, CallInst *Op, bool HasCheckBit) {
    IRBuilder<> &IRB = OpBuilder.getIRB();

    Instruction *OldResult = Intrin;
    Type *OldTy = Intrin->getType();

    if (HasCheckBit) {
      auto *ST = cast<StructType>(OldTy);

      Value *CheckOp = nullptr;
      Type *Int32Ty = IRB.getInt32Ty();
      for (Use &U : make_early_inc_range(OldResult->uses())) {
        if (auto *EVI = dyn_cast<ExtractValueInst>(U.getUser())) {
          ArrayRef<unsigned> Indices = EVI->getIndices();
          assert(Indices.size() == 1);
          // We're only interested in uses of the check bit for now.
          if (Indices[0] != 1)
            continue;
          // Lazily create a single CheckAccessFullyMapped op fed by the
          // ResRet's status element (index 4) and reuse it for every check
          // bit extract.
          if (!CheckOp) {
            Value *NewEVI = IRB.CreateExtractValue(Op, 4);
            Expected<CallInst *> OpCall = OpBuilder.tryCreateOp(
                OpCode::CheckAccessFullyMapped, {NewEVI},
                OldResult->hasName() ? OldResult->getName() + "_check"
                                     : Twine(),
                Int32Ty);
            if (Error E = OpCall.takeError())
              return E;
            CheckOp = *OpCall;
          }
          EVI->replaceAllUsesWith(CheckOp);
          EVI->eraseFromParent();
        }
      }

      if (OldResult->use_empty()) {
        // Only the check bit was used, so we're done here.
        OldResult->eraseFromParent();
        return Error::success();
      }

      // The only remaining use should be the extract of the value element;
      // from here on we rewrite that extract's uses instead.
      assert(OldResult->hasOneUse() &&
             isa<ExtractValueInst>(*OldResult->user_begin()) &&
             "Expected only use to be extract of first element");
      OldResult = cast<Instruction>(*OldResult->user_begin());
      OldTy = ST->getElementType(0);
    }

    // For scalars, we just extract the first element.
    if (!isa<FixedVectorType>(OldTy)) {
      Value *EVI = IRB.CreateExtractValue(Op, 0);
      OldResult->replaceAllUsesWith(EVI);
      OldResult->eraseFromParent();
      if (OldResult != Intrin) {
        // OldResult was the value extract; the intrinsic itself still needs
        // erasing.
        assert(Intrin->use_empty() && "Intrinsic still has uses?");
        Intrin->eraseFromParent();
      }
      return Error::success();
    }

    std::array<Value *, 4> Extracts = {};
    SmallVector<ExtractElementInst *> DynamicAccesses;

    // The users of the operation should all be scalarized, so we attempt to
    // replace the extractelements with extractvalues directly.
    for (Use &U : make_early_inc_range(OldResult->uses())) {
      if (auto *EEI = dyn_cast<ExtractElementInst>(U.getUser())) {
        if (auto *IndexOp = dyn_cast<ConstantInt>(EEI->getIndexOperand())) {
          size_t IndexVal = IndexOp->getZExtValue();
          assert(IndexVal < 4 && "Index into buffer load out of range");
          if (!Extracts[IndexVal])
            Extracts[IndexVal] = IRB.CreateExtractValue(Op, IndexVal);
          EEI->replaceAllUsesWith(Extracts[IndexVal]);
          EEI->eraseFromParent();
        } else {
          DynamicAccesses.push_back(EEI);
        }
      }
    }

    const auto *VecTy = cast<FixedVectorType>(OldTy);
    const unsigned N = VecTy->getNumElements();

    // If there's a dynamic access we need to round trip through stack memory
    // so that we don't leave vectors around.
    if (!DynamicAccesses.empty()) {
      Type *Int32Ty = IRB.getInt32Ty();
      Constant *Zero = ConstantInt::get(Int32Ty, 0);

      Type *ElTy = VecTy->getElementType();
      Type *ArrayTy = ArrayType::get(ElTy, N);
      Value *Alloca = IRB.CreateAlloca(ArrayTy);

      // Spill every element of the ResRet into the stack array...
      for (int I = 0, E = N; I != E; ++I) {
        if (!Extracts[I])
          Extracts[I] = IRB.CreateExtractValue(Op, I);
        Value *GEP = IRB.CreateInBoundsGEP(
            ArrayTy, Alloca, {Zero, ConstantInt::get(Int32Ty, I)});
        IRB.CreateStore(Extracts[I], GEP);
      }

      // ...then turn each dynamic extractelement into a dynamic load.
      for (ExtractElementInst *EEI : DynamicAccesses) {
        Value *GEP = IRB.CreateInBoundsGEP(ArrayTy, Alloca,
                                           {Zero, EEI->getIndexOperand()});
        Value *Load = IRB.CreateLoad(ElTy, GEP);
        EEI->replaceAllUsesWith(Load);
        EEI->eraseFromParent();
      }
    }

    // If we still have uses, then we're not fully scalarized and need to
    // recreate the vector. This should only happen for things like exported
    // functions from libraries.
    if (!OldResult->use_empty()) {
      for (int I = 0, E = N; I != E; ++I)
        if (!Extracts[I])
          Extracts[I] = IRB.CreateExtractValue(Op, I);

      Value *Vec = PoisonValue::get(OldTy);
      for (int I = 0, E = N; I != E; ++I)
        Vec = IRB.CreateInsertElement(Vec, Extracts[I], I);
      OldResult->replaceAllUsesWith(Vec);
    }

    OldResult->eraseFromParent();
    if (OldResult != Intrin) {
      assert(Intrin->use_empty() && "Intrinsic still has uses?");
      Intrin->eraseFromParent();
    }

    return Error::success();
  }
459
  /// Lower `dx.resource.load.typedbuffer` calls to the BufferLoad DXIL op.
  /// \c HasCheckBit indicates the intrinsic returns a {value, i1} struct
  /// whose i1 is handled via CheckAccessFullyMapped in replaceResRetUses.
  [[nodiscard]] bool lowerTypedBufferLoad(Function &F, bool HasCheckBit) {
    IRBuilder<> &IRB = OpBuilder.getIRB();
    Type *Int32Ty = IRB.getInt32Ty();

    return replaceFunction(F, [&](CallInst *CI) -> Error {
      IRB.SetInsertPoint(CI);

      Value *Handle =
          createTmpHandleCast(CI->getArgOperand(0), OpBuilder.getHandleType());
      Value *Index0 = CI->getArgOperand(1);
      // Typed buffers have no second coordinate; BufferLoad still takes one.
      Value *Index1 = UndefValue::get(Int32Ty);

      // The op returns a dx.ResRet of the loaded value's scalar type.
      Type *OldTy = CI->getType();
      if (HasCheckBit)
        OldTy = cast<StructType>(OldTy)->getElementType(0);
      Type *NewRetTy = OpBuilder.getResRetType(OldTy->getScalarType());

      std::array<Value *, 3> Args{Handle, Index0, Index1};
      Expected<CallInst *> OpCall = OpBuilder.tryCreateOp(
          OpCode::BufferLoad, Args, CI->getName(), NewRetTy);
      if (Error E = OpCall.takeError())
        return E;
      if (Error E = replaceResRetUses(CI, *OpCall, HasCheckBit))
        return E;

      return Error::success();
    });
  }
488
  /// Lower `dx.resource.load.rawbuffer` calls. DXIL 1.2+ uses RawBufferLoad
  /// (which takes an element mask and alignment); older versions fall back to
  /// BufferLoad. The result's uses are rewritten by replaceResRetUses.
  [[nodiscard]] bool lowerRawBufferLoad(Function &F) {
    const DataLayout &DL = F.getDataLayout();
    IRBuilder<> &IRB = OpBuilder.getIRB();
    Type *Int8Ty = IRB.getInt8Ty();
    Type *Int32Ty = IRB.getInt32Ty();

    return replaceFunction(F, [&](CallInst *CI) -> Error {
      IRB.SetInsertPoint(CI);

      // The intrinsic returns {value, i1}; the op returns a dx.ResRet of the
      // value's scalar type.
      Type *OldTy = cast<StructType>(CI->getType())->getElementType(0);
      Type *ScalarTy = OldTy->getScalarType();
      Type *NewRetTy = OpBuilder.getResRetType(ScalarTy);

      Value *Handle =
          createTmpHandleCast(CI->getArgOperand(0), OpBuilder.getHandleType());
      Value *Index0 = CI->getArgOperand(1);
      Value *Index1 = CI->getArgOperand(2);
      // Mask has one bit set per loaded element (e.g. 3 elements -> 0b0111).
      uint64_t NumElements =
          DL.getTypeSizeInBits(OldTy) / DL.getTypeSizeInBits(ScalarTy);
      Value *Mask = ConstantInt::get(Int8Ty, ~(~0U << NumElements));
      Value *Align =
          ConstantInt::get(Int32Ty, DL.getPrefTypeAlign(ScalarTy).value());

      Expected<CallInst *> OpCall =
          MMDI.DXILVersion >= VersionTuple(1, 2)
              ? OpBuilder.tryCreateOp(OpCode::RawBufferLoad,
                                      {Handle, Index0, Index1, Mask, Align},
                                      CI->getName(), NewRetTy)
              : OpBuilder.tryCreateOp(OpCode::BufferLoad,
                                      {Handle, Index0, Index1}, CI->getName(),
                                      NewRetTy);
      if (Error E = OpCall.takeError())
        return E;
      if (Error E = replaceResRetUses(CI, *OpCall, /*HasCheckBit=*/true))
        return E;

      return Error::success();
    });
  }
528
  /// Lower `dx.resource.load.cbufferrow.{2,4,8}` calls to the
  /// CBufferLoadLegacy DXIL op, retargeting struct-element extracts onto the
  /// op's named dx.CBufRet struct.
  [[nodiscard]] bool lowerCBufferLoad(Function &F) {
    IRBuilder<> &IRB = OpBuilder.getIRB();

    return replaceFunction(F, [&](CallInst *CI) -> Error {
      IRB.SetInsertPoint(CI);

      Type *OldTy = cast<StructType>(CI->getType())->getElementType(0);
      Type *ScalarTy = OldTy->getScalarType();
      Type *NewRetTy = OpBuilder.getCBufRetType(ScalarTy);

      Value *Handle =
          createTmpHandleCast(CI->getArgOperand(0), OpBuilder.getHandleType());
      Value *Index = CI->getArgOperand(1);

      Expected<CallInst *> OpCall = OpBuilder.tryCreateOp(
          OpCode::CBufferLoadLegacy, {Handle, Index}, CI->getName(), NewRetTy);
      if (Error E = OpCall.takeError())
        return E;
      // Both types are structs, so forward insert/extractvalue uses rather
      // than RAUW (replaceNamedStructUses does not erase the call itself).
      if (Error E = replaceNamedStructUses(CI, *OpCall))
        return E;

      CI->eraseFromParent();
      return Error::success();
    });
  }
554
lowerUpdateCounter(Function & F)555 [[nodiscard]] bool lowerUpdateCounter(Function &F) {
556 IRBuilder<> &IRB = OpBuilder.getIRB();
557 Type *Int32Ty = IRB.getInt32Ty();
558
559 return replaceFunction(F, [&](CallInst *CI) -> Error {
560 IRB.SetInsertPoint(CI);
561 Value *Handle =
562 createTmpHandleCast(CI->getArgOperand(0), OpBuilder.getHandleType());
563 Value *Op1 = CI->getArgOperand(1);
564
565 std::array<Value *, 2> Args{Handle, Op1};
566
567 Expected<CallInst *> OpCall = OpBuilder.tryCreateOp(
568 OpCode::UpdateCounter, Args, CI->getName(), Int32Ty);
569
570 if (Error E = OpCall.takeError())
571 return E;
572
573 CI->replaceAllUsesWith(*OpCall);
574 CI->eraseFromParent();
575 return Error::success();
576 });
577 }
578
lowerGetPointer(Function & F)579 [[nodiscard]] bool lowerGetPointer(Function &F) {
580 // These should have already been handled in DXILResourceAccess, so we can
581 // just clean up the dead prototype.
582 assert(F.user_empty() && "getpointer operations should have been removed");
583 F.eraseFromParent();
584 return false;
585 }
586
  /// Lower typed- and raw-buffer store intrinsics to BufferStore (or
  /// RawBufferStore for raw stores on DXIL 1.2+). The stored value is
  /// decomposed into up to four scalars; where it was built by a chain of
  /// insertelements (the usual post-scalarizer shape) the inserted scalars
  /// are forwarded directly instead of emitting extractelements.
  [[nodiscard]] bool lowerBufferStore(Function &F, bool IsRaw) {
    const DataLayout &DL = F.getDataLayout();
    IRBuilder<> &IRB = OpBuilder.getIRB();
    Type *Int8Ty = IRB.getInt8Ty();
    Type *Int32Ty = IRB.getInt32Ty();

    return replaceFunction(F, [&](CallInst *CI) -> Error {
      IRB.SetInsertPoint(CI);

      Value *Handle =
          createTmpHandleCast(CI->getArgOperand(0), OpBuilder.getHandleType());
      Value *Index0 = CI->getArgOperand(1);
      // Raw stores carry a second coordinate; typed stores do not.
      Value *Index1 = IsRaw ? CI->getArgOperand(2) : UndefValue::get(Int32Ty);

      Value *Data = CI->getArgOperand(IsRaw ? 3 : 2);
      Type *DataTy = Data->getType();
      Type *ScalarTy = DataTy->getScalarType();

      uint64_t NumElements =
          DL.getTypeSizeInBits(DataTy) / DL.getTypeSizeInBits(ScalarTy);
      // Raw stores mask exactly the elements present; typed stores always
      // write all four components (mask 0b1111).
      Value *Mask =
          ConstantInt::get(Int8Ty, IsRaw ? ~(~0U << NumElements) : 15U);

      // TODO: check that we only have vector or scalar...
      if (NumElements > 4)
        return make_error<StringError>(
            "Buffer store data must have at most 4 elements",
            inconvertibleErrorCode());

      std::array<Value *, 4> DataElements{nullptr, nullptr, nullptr, nullptr};
      if (DataTy == ScalarTy)
        DataElements[0] = Data;
      else {
        // Since we're post-scalarizer, if we see a vector here it's likely
        // constructed solely for the argument of the store. Just use the
        // scalar values from before they're inserted into the temporary.
        auto *IEI = dyn_cast<InsertElementInst>(Data);
        while (IEI) {
          auto *IndexOp = dyn_cast<ConstantInt>(IEI->getOperand(2));
          if (!IndexOp)
            break;
          size_t IndexVal = IndexOp->getZExtValue();
          assert(IndexVal < 4 && "Too many elements for buffer store");
          DataElements[IndexVal] = IEI->getOperand(1);
          IEI = dyn_cast<InsertElementInst>(IEI->getOperand(0));
        }
      }

      // If for some reason we weren't able to forward the arguments from the
      // scalarizer artifact, then we may need to actually extract elements
      // from the vector.
      for (int I = 0, E = NumElements; I < E; ++I)
        if (DataElements[I] == nullptr)
          DataElements[I] =
              IRB.CreateExtractElement(Data, ConstantInt::get(Int32Ty, I));

      // For any elements beyond the length of the vector, we should fill it
      // up with undef - however, for typed buffers we repeat the first
      // element to match DXC.
      for (int I = NumElements, E = 4; I < E; ++I)
        if (DataElements[I] == nullptr)
          DataElements[I] = IsRaw ? UndefValue::get(ScalarTy) : DataElements[0];

      dxil::OpCode Op = OpCode::BufferStore;
      SmallVector<Value *, 9> Args{
          Handle,          Index0,          Index1,          DataElements[0],
          DataElements[1], DataElements[2], DataElements[3], Mask};
      if (IsRaw && MMDI.DXILVersion >= VersionTuple(1, 2)) {
        Op = OpCode::RawBufferStore;
        // RawBufferStore requires the alignment
        Args.push_back(
            ConstantInt::get(Int32Ty, DL.getPrefTypeAlign(ScalarTy).value()));
      }
      Expected<CallInst *> OpCall =
          OpBuilder.tryCreateOp(Op, Args, CI->getName());
      if (Error E = OpCall.takeError())
        return E;

      CI->eraseFromParent();
      // Clean up any leftover `insertelement`s
      auto *IEI = dyn_cast<InsertElementInst>(Data);
      while (IEI && IEI->use_empty()) {
        InsertElementInst *Tmp = IEI;
        IEI = dyn_cast<InsertElementInst>(IEI->getOperand(0));
        Tmp->eraseFromParent();
      }

      return Error::success();
    });
  }
677
  /// Lower `llvm.ctpop` to the CountBits DXIL op, which always produces i32
  /// (or a vector thereof). For 16- and 64-bit ctpops, existing casts of the
  /// result to i32 are folded away where possible; otherwise a single cast
  /// from the op's i32 result back to the original width is emitted.
  [[nodiscard]] bool lowerCtpopToCountBits(Function &F) {
    IRBuilder<> &IRB = OpBuilder.getIRB();
    Type *Int32Ty = IRB.getInt32Ty();

    return replaceFunction(F, [&](CallInst *CI) -> Error {
      IRB.SetInsertPoint(CI);
      SmallVector<Value *> Args;
      Args.append(CI->arg_begin(), CI->arg_end());

      // CountBits returns i32; preserve the element count for vectors.
      Type *RetTy = Int32Ty;
      Type *FRT = F.getReturnType();
      if (const auto *VT = dyn_cast<VectorType>(FRT))
        RetTy = VectorType::get(RetTy, VT);

      Expected<CallInst *> OpCall = OpBuilder.tryCreateOp(
          dxil::OpCode::CountBits, Args, CI->getName(), RetTy);
      if (Error E = OpCall.takeError())
        return E;

      // If the result type is 32 bits we can do a direct replacement.
      if (FRT->isIntOrIntVectorTy(32)) {
        CI->replaceAllUsesWith(*OpCall);
        CI->eraseFromParent();
        return Error::success();
      }

      // The cast opcodes that users would apply to reach i32: widening for
      // 16-bit results, truncation for 64-bit.
      unsigned CastOp;
      unsigned CastOp2;
      if (FRT->isIntOrIntVectorTy(16)) {
        CastOp = Instruction::ZExt;
        CastOp2 = Instruction::SExt;
      } else { // must be 64 bits
        assert(FRT->isIntOrIntVectorTy(64) &&
               "Currently only lowering 16, 32, or 64 bit ctpop to CountBits \
                is supported.");
        CastOp = Instruction::Trunc;
        CastOp2 = Instruction::Trunc;
      }

      // It is correct to replace the ctpop with the dxil op and
      // remove all casts to i32
      bool NeedsCast = false;
      for (User *User : make_early_inc_range(CI->users())) {
        Instruction *I = dyn_cast<Instruction>(User);
        if (I && (I->getOpcode() == CastOp || I->getOpcode() == CastOp2) &&
            I->getType() == RetTy) {
          // The user is itself a cast to i32: fold it into the op's result.
          I->replaceAllUsesWith(*OpCall);
          I->eraseFromParent();
        } else
          NeedsCast = true;
      }

      // It is correct to replace a ctpop with the dxil op and
      // a cast from i32 to the return type of the ctpop
      // the cast is emitted here if there is a non-cast to i32
      // instr which uses the ctpop
      if (NeedsCast) {
        Value *Cast =
            IRB.CreateZExtOrTrunc(*OpCall, F.getReturnType(), "ctpop.cast");
        CI->replaceAllUsesWith(Cast);
      }

      CI->eraseFromParent();
      return Error::success();
    });
  }
744
  /// Lower `llvm.lifetime.start`/`llvm.lifetime.end` (unsupported in older
  /// DXIL) by replacing each call with a store to the marked memory: a
  /// zeroinitializer for validator versions below 1.6, undef otherwise.
  /// NOTE: the pointer is operand 1 — this expects the old two-operand
  /// (size, ptr) lifetime intrinsic form.
  [[nodiscard]] bool lowerLifetimeIntrinsic(Function &F) {
    IRBuilder<> &IRB = OpBuilder.getIRB();
    return replaceFunction(F, [&](CallInst *CI) -> Error {
      IRB.SetInsertPoint(CI);
      Value *Ptr = CI->getArgOperand(1);
      assert(Ptr->getType()->isPointerTy() &&
             "Expected operand of lifetime intrinsic to be a pointer");

      // Older validators reject undef stores, so emit zeroinitializer there.
      auto ZeroOrUndef = [&](Type *Ty) {
        return MMDI.ValidatorVersion < VersionTuple(1, 6)
                   ? Constant::getNullValue(Ty)
                   : UndefValue::get(Ty);
      };

      Value *Val = nullptr;
      if (auto *GV = dyn_cast<GlobalVariable>(Ptr)) {
        // Don't clobber globals that already have (or get) external contents.
        if (GV->hasInitializer() || GV->isExternallyInitialized())
          return Error::success();
        Val = ZeroOrUndef(GV->getValueType());
      } else if (auto *AI = dyn_cast<AllocaInst>(Ptr))
        Val = ZeroOrUndef(AI->getAllocatedType());

      assert(Val && "Expected operand of lifetime intrinsic to be a global "
                    "variable or alloca instruction");
      IRB.CreateStore(Val, Ptr, false);

      CI->eraseFromParent();
      return Error::success();
    });
  }
775
lowerIsFPClass(Function & F)776 [[nodiscard]] bool lowerIsFPClass(Function &F) {
777 IRBuilder<> &IRB = OpBuilder.getIRB();
778 Type *RetTy = IRB.getInt1Ty();
779
780 return replaceFunction(F, [&](CallInst *CI) -> Error {
781 IRB.SetInsertPoint(CI);
782 SmallVector<Value *> Args;
783 Value *Fl = CI->getArgOperand(0);
784 Args.push_back(Fl);
785
786 dxil::OpCode OpCode;
787 Value *T = CI->getArgOperand(1);
788 auto *TCI = dyn_cast<ConstantInt>(T);
789 switch (TCI->getZExtValue()) {
790 case FPClassTest::fcInf:
791 OpCode = dxil::OpCode::IsInf;
792 break;
793 case FPClassTest::fcNan:
794 OpCode = dxil::OpCode::IsNaN;
795 break;
796 case FPClassTest::fcNormal:
797 OpCode = dxil::OpCode::IsNormal;
798 break;
799 case FPClassTest::fcFinite:
800 OpCode = dxil::OpCode::IsFinite;
801 break;
802 default:
803 SmallString<128> Msg =
804 formatv("Unsupported FPClassTest {0} for DXIL Op Lowering",
805 TCI->getZExtValue());
806 return make_error<StringError>(Msg, inconvertibleErrorCode());
807 }
808
809 Expected<CallInst *> OpCall =
810 OpBuilder.tryCreateOp(OpCode, Args, CI->getName(), RetTy);
811 if (Error E = OpCall.takeError())
812 return E;
813
814 CI->replaceAllUsesWith(*OpCall);
815 CI->eraseFromParent();
816 return Error::success();
817 });
818 }
819
  /// Walk every declared function in the module and lower the intrinsics we
  /// know about to DXIL ops, erasing dead declarations along the way.
  /// Returns true if the module was changed; errors are emitted as
  /// diagnostics by the individual lowering helpers.
  bool lowerIntrinsics() {
    bool Updated = false;
    bool HasErrors = false;

    for (Function &F : make_early_inc_range(M.functions())) {
      // Only intrinsic/external declarations are lowered; defined functions
      // are left alone.
      if (!F.isDeclaration())
        continue;
      Intrinsic::ID ID = F.getIntrinsicID();
      switch (ID) {
      // NOTE: Skip dx_resource_casthandle here. They are
      // resolved after this loop in cleanupHandleCasts.
      case Intrinsic::dx_resource_casthandle:
      // NOTE: llvm.dbg.value is supported as is in DXIL.
      case Intrinsic::dbg_value:
      case Intrinsic::not_intrinsic:
        if (F.use_empty())
          F.eraseFromParent();
        continue;
      default:
        // Unknown intrinsics are only an error if they're actually used.
        if (F.use_empty())
          F.eraseFromParent();
        else {
          SmallString<128> Msg = formatv(
              "Unsupported intrinsic {0} for DXIL lowering", F.getName());
          M.getContext().emitError(Msg);
          HasErrors |= true;
        }
        break;

// Cases for every intrinsic with a direct DXIL op mapping are generated
// from the table in DXILOperation.inc.
#define DXIL_OP_INTRINSIC(OpCode, Intrin, ...)                                 \
  case Intrin:                                                                 \
    HasErrors |= replaceFunctionWithOp(                                        \
        F, OpCode, ArrayRef<IntrinArgSelect>{__VA_ARGS__});                    \
    break;
#include "DXILOperation.inc"
      case Intrinsic::dx_resource_handlefrombinding:
        HasErrors |= lowerHandleFromBinding(F);
        break;
      case Intrinsic::dx_resource_getpointer:
        HasErrors |= lowerGetPointer(F);
        break;
      case Intrinsic::dx_resource_load_typedbuffer:
        HasErrors |= lowerTypedBufferLoad(F, /*HasCheckBit=*/true);
        break;
      case Intrinsic::dx_resource_store_typedbuffer:
        HasErrors |= lowerBufferStore(F, /*IsRaw=*/false);
        break;
      case Intrinsic::dx_resource_load_rawbuffer:
        HasErrors |= lowerRawBufferLoad(F);
        break;
      case Intrinsic::dx_resource_store_rawbuffer:
        HasErrors |= lowerBufferStore(F, /*IsRaw=*/true);
        break;
      case Intrinsic::dx_resource_load_cbufferrow_2:
      case Intrinsic::dx_resource_load_cbufferrow_4:
      case Intrinsic::dx_resource_load_cbufferrow_8:
        HasErrors |= lowerCBufferLoad(F);
        break;
      case Intrinsic::dx_resource_updatecounter:
        HasErrors |= lowerUpdateCounter(F);
        break;
      case Intrinsic::ctpop:
        HasErrors |= lowerCtpopToCountBits(F);
        break;
      case Intrinsic::lifetime_start:
      case Intrinsic::lifetime_end:
        if (F.use_empty())
          F.eraseFromParent();
        else {
          // Lifetime markers are only rewritten for DXIL < 1.6; newer
          // versions keep them as-is (note: `continue` skips the
          // Updated flag below).
          if (MMDI.DXILVersion < VersionTuple(1, 6))
            HasErrors |= lowerLifetimeIntrinsic(F);
          else
            continue;
        }
        break;
      case Intrinsic::is_fpclass:
        HasErrors |= lowerIsFPClass(F);
        break;
      }
      Updated = true;
    }
    // Only strip the temporary handle casts when lowering fully succeeded;
    // on error the module is already being rejected.
    if (Updated && !HasErrors)
      cleanupHandleCasts();

    return Updated;
  }
906 };
907 } // namespace
908
run(Module & M,ModuleAnalysisManager & MAM)909 PreservedAnalyses DXILOpLowering::run(Module &M, ModuleAnalysisManager &MAM) {
910 DXILResourceMap &DRM = MAM.getResult<DXILResourceAnalysis>(M);
911 DXILResourceTypeMap &DRTM = MAM.getResult<DXILResourceTypeAnalysis>(M);
912 const ModuleMetadataInfo MMDI = MAM.getResult<DXILMetadataAnalysis>(M);
913
914 const bool MadeChanges = OpLowerer(M, DRM, DRTM, MMDI).lowerIntrinsics();
915 if (!MadeChanges)
916 return PreservedAnalyses::all();
917 PreservedAnalyses PA;
918 PA.preserve<DXILResourceAnalysis>();
919 PA.preserve<DXILMetadataAnalysis>();
920 PA.preserve<ShaderFlagsAnalysis>();
921 return PA;
922 }
923
924 namespace {
925 class DXILOpLoweringLegacy : public ModulePass {
926 public:
runOnModule(Module & M)927 bool runOnModule(Module &M) override {
928 DXILResourceMap &DRM =
929 getAnalysis<DXILResourceWrapperPass>().getResourceMap();
930 DXILResourceTypeMap &DRTM =
931 getAnalysis<DXILResourceTypeWrapperPass>().getResourceTypeMap();
932 const ModuleMetadataInfo MMDI =
933 getAnalysis<DXILMetadataAnalysisWrapperPass>().getModuleMetadata();
934
935 return OpLowerer(M, DRM, DRTM, MMDI).lowerIntrinsics();
936 }
getPassName() const937 StringRef getPassName() const override { return "DXIL Op Lowering"; }
DXILOpLoweringLegacy()938 DXILOpLoweringLegacy() : ModulePass(ID) {}
939
940 static char ID; // Pass identification.
getAnalysisUsage(llvm::AnalysisUsage & AU) const941 void getAnalysisUsage(llvm::AnalysisUsage &AU) const override {
942 AU.addRequired<DXILResourceTypeWrapperPass>();
943 AU.addRequired<DXILResourceWrapperPass>();
944 AU.addRequired<DXILMetadataAnalysisWrapperPass>();
945 AU.addPreserved<DXILResourceWrapperPass>();
946 AU.addPreserved<DXILMetadataAnalysisWrapperPass>();
947 AU.addPreserved<ShaderFlagsAnalysisWrapper>();
948 }
949 };
950 char DXILOpLoweringLegacy::ID = 0;
951 } // end anonymous namespace
952
953 INITIALIZE_PASS_BEGIN(DXILOpLoweringLegacy, DEBUG_TYPE, "DXIL Op Lowering",
954 false, false)
INITIALIZE_PASS_DEPENDENCY(DXILResourceTypeWrapperPass)955 INITIALIZE_PASS_DEPENDENCY(DXILResourceTypeWrapperPass)
956 INITIALIZE_PASS_DEPENDENCY(DXILResourceWrapperPass)
957 INITIALIZE_PASS_END(DXILOpLoweringLegacy, DEBUG_TYPE, "DXIL Op Lowering", false,
958 false)
959
/// Factory for the legacy-PM DXIL Op Lowering pass.
ModulePass *llvm::createDXILOpLoweringLegacyPass() {
  return new DXILOpLoweringLegacy();
}
963