//===- DXILOpLowering.cpp - Lowering to DXIL operations -------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "DXILOpLowering.h"
#include "DXILConstants.h"
#include "DXILOpBuilder.h"
#include "DXILShaderFlags.h"
#include "DirectX.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/DXILMetadataAnalysis.h"
#include "llvm/Analysis/DXILResource.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/IR/DiagnosticInfo.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/IntrinsicsDirectX.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/PassManager.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/FormatVariadic.h"

#define DEBUG_TYPE "dxil-op-lower"

using namespace llvm;
using namespace llvm::dxil;

namespace {
/// Rewrites calls to LLVM/DirectX intrinsics into calls to the corresponding
/// DXIL operations, using resource and metadata analysis results to pick the
/// correct lowering for the targeted DXIL version.
class OpLowerer {
  /// The module being lowered (functions are erased/replaced in place).
  Module &M;
  /// Builder that knows how to create `dx.op.*` calls and DXIL types.
  DXILOpBuilder OpBuilder;
  /// Map from resource-producing calls to their binding/record info.
  DXILResourceMap &DRM;
  /// Map from resource handle types to their type information.
  DXILResourceTypeMap &DRTM;
  /// Module-level metadata (DXIL version, validator version, ...).
  const ModuleMetadataInfo &MMDI;
  /// Temporary `dx.resource.casthandle` calls inserted during lowering;
  /// removed in pairs by cleanupHandleCasts() once lowering succeeds.
  SmallVector<CallInst *> CleanupCasts;

public:
  OpLowerer(Module &M, DXILResourceMap &DRM, DXILResourceTypeMap &DRTM,
            const ModuleMetadataInfo &MMDI)
      : M(M), OpBuilder(M), DRM(DRM), DRTM(DRTM), MMDI(MMDI) {}

  /// Replace every call to \c F using \c ReplaceCall, and then erase \c F. If
  /// there is an error replacing a call, we emit a diagnostic and return true.
  [[nodiscard]] bool
  replaceFunction(Function &F,
                  llvm::function_ref<Error(CallInst *CI)> ReplaceCall) {
    // Early-inc iteration: ReplaceCall is expected to erase the call, which
    // mutates F's use list as we walk it.
    for (User *U : make_early_inc_range(F.users())) {
      CallInst *CI = dyn_cast<CallInst>(U);
      if (!CI)
        continue;

      if (Error E = ReplaceCall(CI)) {
        std::string Message(toString(std::move(E)));
        M.getContext().diagnose(DiagnosticInfoUnsupported(
            *CI->getFunction(), Message, CI->getDebugLoc()));

        return true;
      }
    }
    // Only erase the declaration once every call site has been rewritten.
    if (F.user_empty())
      F.eraseFromParent();
    return false;
  }

  /// Describes how one argument of a DXIL op is built from an intrinsic call:
  /// either forwarded from an intrinsic argument (Index) or materialized as an
  /// immediate i8/i32 constant. Used by replaceFunctionWithOp.
  struct IntrinArgSelect {
    enum class Type {
#define DXIL_OP_INTRINSIC_ARG_SELECT_TYPE(name) name,
#include "DXILOperation.inc"
    };
    Type Type;
    int Value;
  };

  /// Replaces uses of a struct with uses of an equivalent named struct.
  ///
  /// DXIL operations that return structs give them well known names, so we need
  /// to update uses when we switch from an LLVM intrinsic to an op.
  Error replaceNamedStructUses(CallInst *Intrin, CallInst *DXILOp) {
    auto *IntrinTy = cast<StructType>(Intrin->getType());
    auto *DXILOpTy = cast<StructType>(DXILOp->getType());
    if (!IntrinTy->isLayoutIdentical(DXILOpTy))
      return make_error<StringError>(
          "Type mismatch between intrinsic and DXIL op",
          inconvertibleErrorCode());

    // Re-point extract/insertvalue users at the op; any other kind of user
    // would observe the differently-named struct type, so reject it.
    for (Use &U : make_early_inc_range(Intrin->uses()))
      if (auto *EVI = dyn_cast<ExtractValueInst>(U.getUser()))
        EVI->setOperand(0, DXILOp);
      else if (auto *IVI = dyn_cast<InsertValueInst>(U.getUser()))
        IVI->setOperand(0, DXILOp);
      else
        return make_error<StringError>("DXIL ops that return structs may only "
                                       "be used by insert- and extractvalue",
                                       inconvertibleErrorCode());
    return Error::success();
  }

  /// Replace all calls to \c F with calls to the DXIL op \c DXILOp, building
  /// each op argument according to \c ArgSelects (or forwarding the intrinsic
  /// arguments verbatim when \c ArgSelects is empty).
  [[nodiscard]] bool
  replaceFunctionWithOp(Function &F, dxil::OpCode DXILOp,
                        ArrayRef<IntrinArgSelect> ArgSelects) {
    return replaceFunction(F, [&](CallInst *CI) -> Error {
      OpBuilder.getIRB().SetInsertPoint(CI);
      SmallVector<Value *> Args;
      if (ArgSelects.size()) {
        for (const IntrinArgSelect &A : ArgSelects) {
          switch (A.Type) {
          case IntrinArgSelect::Type::Index:
            Args.push_back(CI->getArgOperand(A.Value));
            break;
          case IntrinArgSelect::Type::I8:
            Args.push_back(OpBuilder.getIRB().getInt8((uint8_t)A.Value));
            break;
          case IntrinArgSelect::Type::I32:
            Args.push_back(OpBuilder.getIRB().getInt32(A.Value));
            break;
          }
        }
      } else {
        Args.append(CI->arg_begin(), CI->arg_end());
      }

      Expected<CallInst *> OpCall =
          OpBuilder.tryCreateOp(DXILOp, Args, CI->getName(), F.getReturnType());
      if (Error E = OpCall.takeError())
        return E;

      // Struct returns need use-rewriting so the named DXIL struct type
      // replaces the intrinsic's anonymous one.
      if (isa<StructType>(CI->getType())) {
        if (Error E = replaceNamedStructUses(CI, *OpCall))
          return E;
      } else
        CI->replaceAllUsesWith(*OpCall);

      CI->eraseFromParent();
      return Error::success();
    });
  }

  /// Create a cast between a `target("dx")` type and `dx.types.Handle`, which
  /// is intended to be removed by the end of lowering. This is used to allow
  /// lowering of ops which need to change their return or argument types in a
  /// piecemeal way - we can add the casts in to avoid updating all of the uses
  /// or defs, and by the end all of the casts will be redundant.
  Value *createTmpHandleCast(Value *V, Type *Ty) {
    CallInst *Cast = OpBuilder.getIRB().CreateIntrinsic(
        Intrinsic::dx_resource_casthandle, {Ty, V->getType()}, {V});
    CleanupCasts.push_back(Cast);
    return Cast;
  }

  /// Remove all of the temporary casts created by createTmpHandleCast, which
  /// by now should form redundant to/from pairs, and erase their prototypes.
  void cleanupHandleCasts() {
    SmallVector<CallInst *> ToRemove;
    SmallVector<Function *> CastFns;

    for (CallInst *Cast : CleanupCasts) {
      // These casts were only put in to ease the move from `target("dx")` types
      // to `dx.types.Handle in a piecemeal way. At this point, all of the
      // non-cast uses should now be `dx.types.Handle`, and remaining casts
      // should all form pairs to and from the now unused `target("dx")` type.
      CastFns.push_back(Cast->getCalledFunction());

      // If the cast is not to `dx.types.Handle`, it should be the first part of
      // the pair. Keep track so we can remove it once it has no more uses.
      if (Cast->getType() != OpBuilder.getHandleType()) {
        ToRemove.push_back(Cast);
        continue;
      }
      // Otherwise, we're the second handle in a pair. Forward the arguments and
      // remove the (second) cast.
      CallInst *Def = cast<CallInst>(Cast->getOperand(0));
      assert(Def->getIntrinsicID() == Intrinsic::dx_resource_casthandle &&
             "Unbalanced pair of temporary handle casts");
      Cast->replaceAllUsesWith(Def->getOperand(0));
      Cast->eraseFromParent();
    }
    for (CallInst *Cast : ToRemove) {
      assert(Cast->user_empty() && "Temporary handle cast still has users");
      Cast->eraseFromParent();
    }

    // Deduplicate the cast functions so that we only erase each one once.
    llvm::sort(CastFns);
    CastFns.erase(llvm::unique(CastFns), CastFns.end());
    for (Function *F : CastFns)
      F->eraseFromParent();

    CleanupCasts.clear();
  }

  // Remove the resource global associated with the handleFromBinding call
  // instruction and their uses as they aren't needed anymore.
  // TODO: We should verify that all the globals get removed.
  // It's expected we'll need a custom pass in the future that will eliminate
  // the need for this here.
  void removeResourceGlobals(CallInst *CI) {
    // Early-inc range: we erase stores while walking CI's users.
    for (User *User : make_early_inc_range(CI->users())) {
      if (StoreInst *Store = dyn_cast<StoreInst>(User)) {
        Value *V = Store->getOperand(1);
        Store->eraseFromParent();
        // If the store's destination was a now-unused global, drop it too.
        if (GlobalVariable *GV = dyn_cast<GlobalVariable>(V))
          if (GV->use_empty()) {
            GV->removeDeadConstantUsers();
            GV->eraseFromParent();
          }
      }
    }
  }

  /// Replace a `dx.resource.handlefrombinding` call with \c Replacement and
  /// clean up the globals (resource global, name string) that only existed to
  /// feed that call.
  void replaceHandleFromBindingCall(CallInst *CI, Value *Replacement) {
    assert(CI->getCalledFunction()->getIntrinsicID() ==
           Intrinsic::dx_resource_handlefrombinding);

    removeResourceGlobals(CI);

    // Arg 5 holds the resource name string global; capture it before the call
    // is erased so we can drop it if it becomes unused.
    auto *NameGlobal = dyn_cast<llvm::GlobalVariable>(CI->getArgOperand(5));

    CI->replaceAllUsesWith(Replacement);
    CI->eraseFromParent();

    if (NameGlobal && NameGlobal->use_empty())
      NameGlobal->removeFromParent();
  }

  /// Lower `dx.resource.handlefrombinding` to the pre-SM6.6 `CreateHandle`
  /// DXIL op.
  [[nodiscard]] bool lowerToCreateHandle(Function &F) {
    IRBuilder<> &IRB = OpBuilder.getIRB();
    Type *Int8Ty = IRB.getInt8Ty();
    Type *Int32Ty = IRB.getInt32Ty();

    return replaceFunction(F, [&](CallInst *CI) -> Error {
      IRB.SetInsertPoint(CI);

      auto *It = DRM.find(CI);
      assert(It != DRM.end() && "Resource not in map?");
      dxil::ResourceInfo &RI = *It;

      const auto &Binding = RI.getBinding();
      dxil::ResourceClass RC = DRTM[RI.getHandleTy()].getResourceClass();

      // CreateHandle takes an absolute register index, so fold the binding's
      // lower bound into the index operand.
      Value *IndexOp = CI->getArgOperand(3);
      if (Binding.LowerBound != 0)
        IndexOp = IRB.CreateAdd(IndexOp,
                                ConstantInt::get(Int32Ty, Binding.LowerBound));

      std::array<Value *, 4> Args{
          ConstantInt::get(Int8Ty, llvm::to_underlying(RC)),
          ConstantInt::get(Int32Ty, Binding.RecordID), IndexOp,
          CI->getArgOperand(4)};
      Expected<CallInst *> OpCall =
          OpBuilder.tryCreateOp(OpCode::CreateHandle, Args, CI->getName());
      if (Error E = OpCall.takeError())
        return E;

      // Bridge the type gap with a temporary cast; removed later by
      // cleanupHandleCasts().
      Value *Cast = createTmpHandleCast(*OpCall, CI->getType());
      replaceHandleFromBindingCall(CI, Cast);
      return Error::success();
    });
  }

  /// Lower `dx.resource.handlefrombinding` to the SM6.6+ pair of
  /// `CreateHandleFromBinding` + `AnnotateHandle` DXIL ops.
  [[nodiscard]] bool lowerToBindAndAnnotateHandle(Function &F) {
    IRBuilder<> &IRB = OpBuilder.getIRB();
    Type *Int32Ty = IRB.getInt32Ty();

    return replaceFunction(F, [&](CallInst *CI) -> Error {
      IRB.SetInsertPoint(CI);

      auto *It = DRM.find(CI);
      assert(It != DRM.end() && "Resource not in map?");
      dxil::ResourceInfo &RI = *It;

      const auto &Binding = RI.getBinding();
      dxil::ResourceTypeInfo &RTI = DRTM[RI.getHandleTy()];
      dxil::ResourceClass RC = RTI.getResourceClass();

      // The index is absolute, so fold in the binding's lower bound.
      Value *IndexOp = CI->getArgOperand(3);
      if (Binding.LowerBound != 0)
        IndexOp = IRB.CreateAdd(IndexOp,
                                ConstantInt::get(Int32Ty, Binding.LowerBound));

      std::pair<uint32_t, uint32_t> Props =
          RI.getAnnotateProps(*F.getParent(), RTI);

      // For `CreateHandleFromBinding` we need the upper bound rather than the
      // size, so we need to be careful about the difference for "unbounded".
      uint32_t Unbounded = std::numeric_limits<uint32_t>::max();
      uint32_t UpperBound = Binding.Size == Unbounded
                                ? Unbounded
                                : Binding.LowerBound + Binding.Size - 1;
      Constant *ResBind = OpBuilder.getResBind(Binding.LowerBound, UpperBound,
                                               Binding.Space, RC);
      std::array<Value *, 3> BindArgs{ResBind, IndexOp, CI->getArgOperand(4)};
      Expected<CallInst *> OpBind = OpBuilder.tryCreateOp(
          OpCode::CreateHandleFromBinding, BindArgs, CI->getName());
      if (Error E = OpBind.takeError())
        return E;

      std::array<Value *, 2> AnnotateArgs{
          *OpBind, OpBuilder.getResProps(Props.first, Props.second)};
      Expected<CallInst *> OpAnnotate = OpBuilder.tryCreateOp(
          OpCode::AnnotateHandle, AnnotateArgs,
          CI->hasName() ? CI->getName() + "_annot" : Twine());
      if (Error E = OpAnnotate.takeError())
        return E;

      Value *Cast = createTmpHandleCast(*OpAnnotate, CI->getType());
      replaceHandleFromBindingCall(CI, Cast);
      return Error::success();
    });
  }

  /// Lower `dx.resource.handlefrombinding` intrinsics depending on the shader
  /// model and taking into account binding information from
  /// DXILResourceAnalysis.
  bool lowerHandleFromBinding(Function &F) {
    if (MMDI.DXILVersion < VersionTuple(1, 6))
      return lowerToCreateHandle(F);
    return lowerToBindAndAnnotateHandle(F);
  }

  /// Replace uses of \c Intrin with the values in the `dx.ResRet` of \c Op.
  /// Since we expect to be post-scalarization, make an effort to avoid vectors.
  Error replaceResRetUses(CallInst *Intrin, CallInst *Op, bool HasCheckBit) {
    IRBuilder<> &IRB = OpBuilder.getIRB();

    Instruction *OldResult = Intrin;
    Type *OldTy = Intrin->getType();

    if (HasCheckBit) {
      auto *ST = cast<StructType>(OldTy);

      Value *CheckOp = nullptr;
      Type *Int32Ty = IRB.getInt32Ty();
      for (Use &U : make_early_inc_range(OldResult->uses())) {
        if (auto *EVI = dyn_cast<ExtractValueInst>(U.getUser())) {
          ArrayRef<unsigned> Indices = EVI->getIndices();
          assert(Indices.size() == 1);
          // We're only interested in uses of the check bit for now.
          if (Indices[0] != 1)
            continue;
          // Lazily create a single CheckAccessFullyMapped op from the status
          // field (index 4 of the ResRet struct) and share it among all users.
          if (!CheckOp) {
            Value *NewEVI = IRB.CreateExtractValue(Op, 4);
            Expected<CallInst *> OpCall = OpBuilder.tryCreateOp(
                OpCode::CheckAccessFullyMapped, {NewEVI},
                OldResult->hasName() ? OldResult->getName() + "_check"
                                     : Twine(),
                Int32Ty);
            if (Error E = OpCall.takeError())
              return E;
            CheckOp = *OpCall;
          }
          EVI->replaceAllUsesWith(CheckOp);
          EVI->eraseFromParent();
        }
      }

      if (OldResult->use_empty()) {
        // Only the check bit was used, so we're done here.
        OldResult->eraseFromParent();
        return Error::success();
      }

      assert(OldResult->hasOneUse() &&
             isa<ExtractValueInst>(*OldResult->user_begin()) &&
             "Expected only use to be extract of first element");
      // From here on operate on the extract of the value element instead of
      // the {value, check} struct.
      OldResult = cast<Instruction>(*OldResult->user_begin());
      OldTy = ST->getElementType(0);
    }

    // For scalars, we just extract the first element.
    if (!isa<FixedVectorType>(OldTy)) {
      Value *EVI = IRB.CreateExtractValue(Op, 0);
      OldResult->replaceAllUsesWith(EVI);
      OldResult->eraseFromParent();
      if (OldResult != Intrin) {
        assert(Intrin->use_empty() && "Intrinsic still has uses?");
        Intrin->eraseFromParent();
      }
      return Error::success();
    }

    std::array<Value *, 4> Extracts = {};
    SmallVector<ExtractElementInst *> DynamicAccesses;

    // The users of the operation should all be scalarized, so we attempt to
    // replace the extractelements with extractvalues directly.
    for (Use &U : make_early_inc_range(OldResult->uses())) {
      if (auto *EEI = dyn_cast<ExtractElementInst>(U.getUser())) {
        if (auto *IndexOp = dyn_cast<ConstantInt>(EEI->getIndexOperand())) {
          size_t IndexVal = IndexOp->getZExtValue();
          assert(IndexVal < 4 && "Index into buffer load out of range");
          if (!Extracts[IndexVal])
            Extracts[IndexVal] = IRB.CreateExtractValue(Op, IndexVal);
          EEI->replaceAllUsesWith(Extracts[IndexVal]);
          EEI->eraseFromParent();
        } else {
          DynamicAccesses.push_back(EEI);
        }
      }
    }

    const auto *VecTy = cast<FixedVectorType>(OldTy);
    const unsigned N = VecTy->getNumElements();

    // If there's a dynamic access we need to round trip through stack memory
    // so that we don't leave vectors around.
    if (!DynamicAccesses.empty()) {
      Type *Int32Ty = IRB.getInt32Ty();
      Constant *Zero = ConstantInt::get(Int32Ty, 0);

      Type *ElTy = VecTy->getElementType();
      Type *ArrayTy = ArrayType::get(ElTy, N);
      Value *Alloca = IRB.CreateAlloca(ArrayTy);

      // Spill every element to the stack array...
      for (int I = 0, E = N; I != E; ++I) {
        if (!Extracts[I])
          Extracts[I] = IRB.CreateExtractValue(Op, I);
        Value *GEP = IRB.CreateInBoundsGEP(
            ArrayTy, Alloca, {Zero, ConstantInt::get(Int32Ty, I)});
        IRB.CreateStore(Extracts[I], GEP);
      }

      // ...and turn each dynamic extractelement into a dynamic load.
      for (ExtractElementInst *EEI : DynamicAccesses) {
        Value *GEP = IRB.CreateInBoundsGEP(ArrayTy, Alloca,
                                           {Zero, EEI->getIndexOperand()});
        Value *Load = IRB.CreateLoad(ElTy, GEP);
        EEI->replaceAllUsesWith(Load);
        EEI->eraseFromParent();
      }
    }

    // If we still have uses, then we're not fully scalarized and need to
    // recreate the vector. This should only happen for things like exported
    // functions from libraries.
    if (!OldResult->use_empty()) {
      for (int I = 0, E = N; I != E; ++I)
        if (!Extracts[I])
          Extracts[I] = IRB.CreateExtractValue(Op, I);

      Value *Vec = PoisonValue::get(OldTy);
      for (int I = 0, E = N; I != E; ++I)
        Vec = IRB.CreateInsertElement(Vec, Extracts[I], I);
      OldResult->replaceAllUsesWith(Vec);
    }

    OldResult->eraseFromParent();
    if (OldResult != Intrin) {
      assert(Intrin->use_empty() && "Intrinsic still has uses?");
      Intrin->eraseFromParent();
    }

    return Error::success();
  }

  /// Lower `dx.resource.load.typedbuffer` to the `BufferLoad` DXIL op.
  [[nodiscard]] bool lowerTypedBufferLoad(Function &F, bool HasCheckBit) {
    IRBuilder<> &IRB = OpBuilder.getIRB();
    Type *Int32Ty = IRB.getInt32Ty();

    return replaceFunction(F, [&](CallInst *CI) -> Error {
      IRB.SetInsertPoint(CI);

      Value *Handle =
          createTmpHandleCast(CI->getArgOperand(0), OpBuilder.getHandleType());
      Value *Index0 = CI->getArgOperand(1);
      // Typed buffers have no second coordinate.
      Value *Index1 = UndefValue::get(Int32Ty);

      Type *OldTy = CI->getType();
      if (HasCheckBit)
        OldTy = cast<StructType>(OldTy)->getElementType(0);
      Type *NewRetTy = OpBuilder.getResRetType(OldTy->getScalarType());

      std::array<Value *, 3> Args{Handle, Index0, Index1};
      Expected<CallInst *> OpCall = OpBuilder.tryCreateOp(
          OpCode::BufferLoad, Args, CI->getName(), NewRetTy);
      if (Error E = OpCall.takeError())
        return E;
      if (Error E = replaceResRetUses(CI, *OpCall, HasCheckBit))
        return E;

      return Error::success();
    });
  }

  /// Lower `dx.resource.load.rawbuffer` to `RawBufferLoad` (DXIL 1.2+) or the
  /// legacy `BufferLoad` op.
  [[nodiscard]] bool lowerRawBufferLoad(Function &F) {
    const DataLayout &DL = F.getDataLayout();
    IRBuilder<> &IRB = OpBuilder.getIRB();
    Type *Int8Ty = IRB.getInt8Ty();
    Type *Int32Ty = IRB.getInt32Ty();

    return replaceFunction(F, [&](CallInst *CI) -> Error {
      IRB.SetInsertPoint(CI);

      Type *OldTy = cast<StructType>(CI->getType())->getElementType(0);
      Type *ScalarTy = OldTy->getScalarType();
      Type *NewRetTy = OpBuilder.getResRetType(ScalarTy);

      Value *Handle =
          createTmpHandleCast(CI->getArgOperand(0), OpBuilder.getHandleType());
      Value *Index0 = CI->getArgOperand(1);
      Value *Index1 = CI->getArgOperand(2);
      uint64_t NumElements =
          DL.getTypeSizeInBits(OldTy) / DL.getTypeSizeInBits(ScalarTy);
      // Mask has one bit set per loaded element.
      Value *Mask = ConstantInt::get(Int8Ty, ~(~0U << NumElements));
      Value *Align =
          ConstantInt::get(Int32Ty, DL.getPrefTypeAlign(ScalarTy).value());

      Expected<CallInst *> OpCall =
          MMDI.DXILVersion >= VersionTuple(1, 2)
              ? OpBuilder.tryCreateOp(OpCode::RawBufferLoad,
                                      {Handle, Index0, Index1, Mask, Align},
                                      CI->getName(), NewRetTy)
              : OpBuilder.tryCreateOp(OpCode::BufferLoad,
                                      {Handle, Index0, Index1}, CI->getName(),
                                      NewRetTy);
      if (Error E = OpCall.takeError())
        return E;
      if (Error E = replaceResRetUses(CI, *OpCall, /*HasCheckBit=*/true))
        return E;

      return Error::success();
    });
  }

  /// Lower `dx.resource.load.cbufferrow.*` to the `CBufferLoadLegacy` DXIL op.
  [[nodiscard]] bool lowerCBufferLoad(Function &F) {
    IRBuilder<> &IRB = OpBuilder.getIRB();

    return replaceFunction(F, [&](CallInst *CI) -> Error {
      IRB.SetInsertPoint(CI);

      Type *OldTy = cast<StructType>(CI->getType())->getElementType(0);
      Type *ScalarTy = OldTy->getScalarType();
      Type *NewRetTy = OpBuilder.getCBufRetType(ScalarTy);

      Value *Handle =
          createTmpHandleCast(CI->getArgOperand(0), OpBuilder.getHandleType());
      Value *Index = CI->getArgOperand(1);

      Expected<CallInst *> OpCall = OpBuilder.tryCreateOp(
          OpCode::CBufferLoadLegacy, {Handle, Index}, CI->getName(), NewRetTy);
      if (Error E = OpCall.takeError())
        return E;
      // The cbuffer row struct is re-pointed at the well-known DXIL struct
      // type rather than RAUW'd.
      if (Error E = replaceNamedStructUses(CI, *OpCall))
        return E;

      CI->eraseFromParent();
      return Error::success();
    });
  }

  /// Lower `dx.resource.updatecounter` to the `UpdateCounter` DXIL op.
  [[nodiscard]] bool lowerUpdateCounter(Function &F) {
    IRBuilder<> &IRB = OpBuilder.getIRB();
    Type *Int32Ty = IRB.getInt32Ty();

    return replaceFunction(F, [&](CallInst *CI) -> Error {
      IRB.SetInsertPoint(CI);
      Value *Handle =
          createTmpHandleCast(CI->getArgOperand(0), OpBuilder.getHandleType());
      Value *Op1 = CI->getArgOperand(1);

      std::array<Value *, 2> Args{Handle, Op1};

      Expected<CallInst *> OpCall = OpBuilder.tryCreateOp(
          OpCode::UpdateCounter, Args, CI->getName(), Int32Ty);

      if (Error E = OpCall.takeError())
        return E;

      CI->replaceAllUsesWith(*OpCall);
      CI->eraseFromParent();
      return Error::success();
    });
  }

  [[nodiscard]] bool
lowerGetPointer(Function &F) { 580 // These should have already been handled in DXILResourceAccess, so we can 581 // just clean up the dead prototype. 582 assert(F.user_empty() && "getpointer operations should have been removed"); 583 F.eraseFromParent(); 584 return false; 585 } 586 587 [[nodiscard]] bool lowerBufferStore(Function &F, bool IsRaw) { 588 const DataLayout &DL = F.getDataLayout(); 589 IRBuilder<> &IRB = OpBuilder.getIRB(); 590 Type *Int8Ty = IRB.getInt8Ty(); 591 Type *Int32Ty = IRB.getInt32Ty(); 592 593 return replaceFunction(F, [&](CallInst *CI) -> Error { 594 IRB.SetInsertPoint(CI); 595 596 Value *Handle = 597 createTmpHandleCast(CI->getArgOperand(0), OpBuilder.getHandleType()); 598 Value *Index0 = CI->getArgOperand(1); 599 Value *Index1 = IsRaw ? CI->getArgOperand(2) : UndefValue::get(Int32Ty); 600 601 Value *Data = CI->getArgOperand(IsRaw ? 3 : 2); 602 Type *DataTy = Data->getType(); 603 Type *ScalarTy = DataTy->getScalarType(); 604 605 uint64_t NumElements = 606 DL.getTypeSizeInBits(DataTy) / DL.getTypeSizeInBits(ScalarTy); 607 Value *Mask = 608 ConstantInt::get(Int8Ty, IsRaw ? ~(~0U << NumElements) : 15U); 609 610 // TODO: check that we only have vector or scalar... 611 if (NumElements > 4) 612 return make_error<StringError>( 613 "Buffer store data must have at most 4 elements", 614 inconvertibleErrorCode()); 615 616 std::array<Value *, 4> DataElements{nullptr, nullptr, nullptr, nullptr}; 617 if (DataTy == ScalarTy) 618 DataElements[0] = Data; 619 else { 620 // Since we're post-scalarizer, if we see a vector here it's likely 621 // constructed solely for the argument of the store. Just use the scalar 622 // values from before they're inserted into the temporary. 
623 auto *IEI = dyn_cast<InsertElementInst>(Data); 624 while (IEI) { 625 auto *IndexOp = dyn_cast<ConstantInt>(IEI->getOperand(2)); 626 if (!IndexOp) 627 break; 628 size_t IndexVal = IndexOp->getZExtValue(); 629 assert(IndexVal < 4 && "Too many elements for buffer store"); 630 DataElements[IndexVal] = IEI->getOperand(1); 631 IEI = dyn_cast<InsertElementInst>(IEI->getOperand(0)); 632 } 633 } 634 635 // If for some reason we weren't able to forward the arguments from the 636 // scalarizer artifact, then we may need to actually extract elements from 637 // the vector. 638 for (int I = 0, E = NumElements; I < E; ++I) 639 if (DataElements[I] == nullptr) 640 DataElements[I] = 641 IRB.CreateExtractElement(Data, ConstantInt::get(Int32Ty, I)); 642 643 // For any elements beyond the length of the vector, we should fill it up 644 // with undef - however, for typed buffers we repeat the first element to 645 // match DXC. 646 for (int I = NumElements, E = 4; I < E; ++I) 647 if (DataElements[I] == nullptr) 648 DataElements[I] = IsRaw ? 
UndefValue::get(ScalarTy) : DataElements[0]; 649 650 dxil::OpCode Op = OpCode::BufferStore; 651 SmallVector<Value *, 9> Args{ 652 Handle, Index0, Index1, DataElements[0], 653 DataElements[1], DataElements[2], DataElements[3], Mask}; 654 if (IsRaw && MMDI.DXILVersion >= VersionTuple(1, 2)) { 655 Op = OpCode::RawBufferStore; 656 // RawBufferStore requires the alignment 657 Args.push_back( 658 ConstantInt::get(Int32Ty, DL.getPrefTypeAlign(ScalarTy).value())); 659 } 660 Expected<CallInst *> OpCall = 661 OpBuilder.tryCreateOp(Op, Args, CI->getName()); 662 if (Error E = OpCall.takeError()) 663 return E; 664 665 CI->eraseFromParent(); 666 // Clean up any leftover `insertelement`s 667 auto *IEI = dyn_cast<InsertElementInst>(Data); 668 while (IEI && IEI->use_empty()) { 669 InsertElementInst *Tmp = IEI; 670 IEI = dyn_cast<InsertElementInst>(IEI->getOperand(0)); 671 Tmp->eraseFromParent(); 672 } 673 674 return Error::success(); 675 }); 676 } 677 678 [[nodiscard]] bool lowerCtpopToCountBits(Function &F) { 679 IRBuilder<> &IRB = OpBuilder.getIRB(); 680 Type *Int32Ty = IRB.getInt32Ty(); 681 682 return replaceFunction(F, [&](CallInst *CI) -> Error { 683 IRB.SetInsertPoint(CI); 684 SmallVector<Value *> Args; 685 Args.append(CI->arg_begin(), CI->arg_end()); 686 687 Type *RetTy = Int32Ty; 688 Type *FRT = F.getReturnType(); 689 if (const auto *VT = dyn_cast<VectorType>(FRT)) 690 RetTy = VectorType::get(RetTy, VT); 691 692 Expected<CallInst *> OpCall = OpBuilder.tryCreateOp( 693 dxil::OpCode::CountBits, Args, CI->getName(), RetTy); 694 if (Error E = OpCall.takeError()) 695 return E; 696 697 // If the result type is 32 bits we can do a direct replacement. 
698 if (FRT->isIntOrIntVectorTy(32)) { 699 CI->replaceAllUsesWith(*OpCall); 700 CI->eraseFromParent(); 701 return Error::success(); 702 } 703 704 unsigned CastOp; 705 unsigned CastOp2; 706 if (FRT->isIntOrIntVectorTy(16)) { 707 CastOp = Instruction::ZExt; 708 CastOp2 = Instruction::SExt; 709 } else { // must be 64 bits 710 assert(FRT->isIntOrIntVectorTy(64) && 711 "Currently only lowering 16, 32, or 64 bit ctpop to CountBits \ 712 is supported."); 713 CastOp = Instruction::Trunc; 714 CastOp2 = Instruction::Trunc; 715 } 716 717 // It is correct to replace the ctpop with the dxil op and 718 // remove all casts to i32 719 bool NeedsCast = false; 720 for (User *User : make_early_inc_range(CI->users())) { 721 Instruction *I = dyn_cast<Instruction>(User); 722 if (I && (I->getOpcode() == CastOp || I->getOpcode() == CastOp2) && 723 I->getType() == RetTy) { 724 I->replaceAllUsesWith(*OpCall); 725 I->eraseFromParent(); 726 } else 727 NeedsCast = true; 728 } 729 730 // It is correct to replace a ctpop with the dxil op and 731 // a cast from i32 to the return type of the ctpop 732 // the cast is emitted here if there is a non-cast to i32 733 // instr which uses the ctpop 734 if (NeedsCast) { 735 Value *Cast = 736 IRB.CreateZExtOrTrunc(*OpCall, F.getReturnType(), "ctpop.cast"); 737 CI->replaceAllUsesWith(Cast); 738 } 739 740 CI->eraseFromParent(); 741 return Error::success(); 742 }); 743 } 744 745 [[nodiscard]] bool lowerLifetimeIntrinsic(Function &F) { 746 IRBuilder<> &IRB = OpBuilder.getIRB(); 747 return replaceFunction(F, [&](CallInst *CI) -> Error { 748 IRB.SetInsertPoint(CI); 749 Value *Ptr = CI->getArgOperand(1); 750 assert(Ptr->getType()->isPointerTy() && 751 "Expected operand of lifetime intrinsic to be a pointer"); 752 753 auto ZeroOrUndef = [&](Type *Ty) { 754 return MMDI.ValidatorVersion < VersionTuple(1, 6) 755 ? 
Constant::getNullValue(Ty) 756 : UndefValue::get(Ty); 757 }; 758 759 Value *Val = nullptr; 760 if (auto *GV = dyn_cast<GlobalVariable>(Ptr)) { 761 if (GV->hasInitializer() || GV->isExternallyInitialized()) 762 return Error::success(); 763 Val = ZeroOrUndef(GV->getValueType()); 764 } else if (auto *AI = dyn_cast<AllocaInst>(Ptr)) 765 Val = ZeroOrUndef(AI->getAllocatedType()); 766 767 assert(Val && "Expected operand of lifetime intrinsic to be a global " 768 "variable or alloca instruction"); 769 IRB.CreateStore(Val, Ptr, false); 770 771 CI->eraseFromParent(); 772 return Error::success(); 773 }); 774 } 775 776 [[nodiscard]] bool lowerIsFPClass(Function &F) { 777 IRBuilder<> &IRB = OpBuilder.getIRB(); 778 Type *RetTy = IRB.getInt1Ty(); 779 780 return replaceFunction(F, [&](CallInst *CI) -> Error { 781 IRB.SetInsertPoint(CI); 782 SmallVector<Value *> Args; 783 Value *Fl = CI->getArgOperand(0); 784 Args.push_back(Fl); 785 786 dxil::OpCode OpCode; 787 Value *T = CI->getArgOperand(1); 788 auto *TCI = dyn_cast<ConstantInt>(T); 789 switch (TCI->getZExtValue()) { 790 case FPClassTest::fcInf: 791 OpCode = dxil::OpCode::IsInf; 792 break; 793 case FPClassTest::fcNan: 794 OpCode = dxil::OpCode::IsNaN; 795 break; 796 case FPClassTest::fcNormal: 797 OpCode = dxil::OpCode::IsNormal; 798 break; 799 case FPClassTest::fcFinite: 800 OpCode = dxil::OpCode::IsFinite; 801 break; 802 default: 803 SmallString<128> Msg = 804 formatv("Unsupported FPClassTest {0} for DXIL Op Lowering", 805 TCI->getZExtValue()); 806 return make_error<StringError>(Msg, inconvertibleErrorCode()); 807 } 808 809 Expected<CallInst *> OpCall = 810 OpBuilder.tryCreateOp(OpCode, Args, CI->getName(), RetTy); 811 if (Error E = OpCall.takeError()) 812 return E; 813 814 CI->replaceAllUsesWith(*OpCall); 815 CI->eraseFromParent(); 816 return Error::success(); 817 }); 818 } 819 820 bool lowerIntrinsics() { 821 bool Updated = false; 822 bool HasErrors = false; 823 824 for (Function &F : make_early_inc_range(M.functions())) { 
825 if (!F.isDeclaration()) 826 continue; 827 Intrinsic::ID ID = F.getIntrinsicID(); 828 switch (ID) { 829 // NOTE: Skip dx_resource_casthandle here. They are 830 // resolved after this loop in cleanupHandleCasts. 831 case Intrinsic::dx_resource_casthandle: 832 // NOTE: llvm.dbg.value is supported as is in DXIL. 833 case Intrinsic::dbg_value: 834 case Intrinsic::not_intrinsic: 835 if (F.use_empty()) 836 F.eraseFromParent(); 837 continue; 838 default: 839 if (F.use_empty()) 840 F.eraseFromParent(); 841 else { 842 SmallString<128> Msg = formatv( 843 "Unsupported intrinsic {0} for DXIL lowering", F.getName()); 844 M.getContext().emitError(Msg); 845 HasErrors |= true; 846 } 847 break; 848 849 #define DXIL_OP_INTRINSIC(OpCode, Intrin, ...) \ 850 case Intrin: \ 851 HasErrors |= replaceFunctionWithOp( \ 852 F, OpCode, ArrayRef<IntrinArgSelect>{__VA_ARGS__}); \ 853 break; 854 #include "DXILOperation.inc" 855 case Intrinsic::dx_resource_handlefrombinding: 856 HasErrors |= lowerHandleFromBinding(F); 857 break; 858 case Intrinsic::dx_resource_getpointer: 859 HasErrors |= lowerGetPointer(F); 860 break; 861 case Intrinsic::dx_resource_load_typedbuffer: 862 HasErrors |= lowerTypedBufferLoad(F, /*HasCheckBit=*/true); 863 break; 864 case Intrinsic::dx_resource_store_typedbuffer: 865 HasErrors |= lowerBufferStore(F, /*IsRaw=*/false); 866 break; 867 case Intrinsic::dx_resource_load_rawbuffer: 868 HasErrors |= lowerRawBufferLoad(F); 869 break; 870 case Intrinsic::dx_resource_store_rawbuffer: 871 HasErrors |= lowerBufferStore(F, /*IsRaw=*/true); 872 break; 873 case Intrinsic::dx_resource_load_cbufferrow_2: 874 case Intrinsic::dx_resource_load_cbufferrow_4: 875 case Intrinsic::dx_resource_load_cbufferrow_8: 876 HasErrors |= lowerCBufferLoad(F); 877 break; 878 case Intrinsic::dx_resource_updatecounter: 879 HasErrors |= lowerUpdateCounter(F); 880 break; 881 case Intrinsic::ctpop: 882 HasErrors |= lowerCtpopToCountBits(F); 883 break; 884 case Intrinsic::lifetime_start: 885 case 
Intrinsic::lifetime_end: 886 if (F.use_empty()) 887 F.eraseFromParent(); 888 else { 889 if (MMDI.DXILVersion < VersionTuple(1, 6)) 890 HasErrors |= lowerLifetimeIntrinsic(F); 891 else 892 continue; 893 } 894 break; 895 case Intrinsic::is_fpclass: 896 HasErrors |= lowerIsFPClass(F); 897 break; 898 } 899 Updated = true; 900 } 901 if (Updated && !HasErrors) 902 cleanupHandleCasts(); 903 904 return Updated; 905 } 906 }; 907 } // namespace 908 909 PreservedAnalyses DXILOpLowering::run(Module &M, ModuleAnalysisManager &MAM) { 910 DXILResourceMap &DRM = MAM.getResult<DXILResourceAnalysis>(M); 911 DXILResourceTypeMap &DRTM = MAM.getResult<DXILResourceTypeAnalysis>(M); 912 const ModuleMetadataInfo MMDI = MAM.getResult<DXILMetadataAnalysis>(M); 913 914 const bool MadeChanges = OpLowerer(M, DRM, DRTM, MMDI).lowerIntrinsics(); 915 if (!MadeChanges) 916 return PreservedAnalyses::all(); 917 PreservedAnalyses PA; 918 PA.preserve<DXILResourceAnalysis>(); 919 PA.preserve<DXILMetadataAnalysis>(); 920 PA.preserve<ShaderFlagsAnalysis>(); 921 return PA; 922 } 923 924 namespace { 925 class DXILOpLoweringLegacy : public ModulePass { 926 public: 927 bool runOnModule(Module &M) override { 928 DXILResourceMap &DRM = 929 getAnalysis<DXILResourceWrapperPass>().getResourceMap(); 930 DXILResourceTypeMap &DRTM = 931 getAnalysis<DXILResourceTypeWrapperPass>().getResourceTypeMap(); 932 const ModuleMetadataInfo MMDI = 933 getAnalysis<DXILMetadataAnalysisWrapperPass>().getModuleMetadata(); 934 935 return OpLowerer(M, DRM, DRTM, MMDI).lowerIntrinsics(); 936 } 937 StringRef getPassName() const override { return "DXIL Op Lowering"; } 938 DXILOpLoweringLegacy() : ModulePass(ID) {} 939 940 static char ID; // Pass identification. 
  /// Declare required and preserved analyses for the legacy pass manager.
  void getAnalysisUsage(llvm::AnalysisUsage &AU) const override {
    AU.addRequired<DXILResourceTypeWrapperPass>();
    AU.addRequired<DXILResourceWrapperPass>();
    AU.addRequired<DXILMetadataAnalysisWrapperPass>();
    AU.addPreserved<DXILResourceWrapperPass>();
    AU.addPreserved<DXILMetadataAnalysisWrapperPass>();
    AU.addPreserved<ShaderFlagsAnalysisWrapper>();
  }
};
char DXILOpLoweringLegacy::ID = 0;
} // end anonymous namespace

INITIALIZE_PASS_BEGIN(DXILOpLoweringLegacy, DEBUG_TYPE, "DXIL Op Lowering",
                      false, false)
INITIALIZE_PASS_DEPENDENCY(DXILResourceTypeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(DXILResourceWrapperPass)
INITIALIZE_PASS_END(DXILOpLoweringLegacy, DEBUG_TYPE, "DXIL Op Lowering", false,
                    false)

/// Factory for the legacy pass, used by the DirectX target's pass registry.
ModulePass *llvm::createDXILOpLoweringLegacyPass() {
  return new DXILOpLoweringLegacy();
}