1 //===- DXILOpLower.cpp - Lowering LLVM intrinsic to DIXLOp function -------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 /// 9 /// \file This file contains passes and utilities to lower llvm intrinsic call 10 /// to DXILOp function call. 11 //===----------------------------------------------------------------------===// 12 13 #include "DXILConstants.h" 14 #include "DXILIntrinsicExpansion.h" 15 #include "DXILOpBuilder.h" 16 #include "DirectX.h" 17 #include "llvm/ADT/SmallVector.h" 18 #include "llvm/CodeGen/Passes.h" 19 #include "llvm/IR/IRBuilder.h" 20 #include "llvm/IR/Instruction.h" 21 #include "llvm/IR/Intrinsics.h" 22 #include "llvm/IR/IntrinsicsDirectX.h" 23 #include "llvm/IR/Module.h" 24 #include "llvm/IR/PassManager.h" 25 #include "llvm/Pass.h" 26 #include "llvm/Support/ErrorHandling.h" 27 28 #define DEBUG_TYPE "dxil-op-lower" 29 30 using namespace llvm; 31 using namespace llvm::dxil; 32 33 static bool isVectorArgExpansion(Function &F) { 34 switch (F.getIntrinsicID()) { 35 case Intrinsic::dx_dot2: 36 case Intrinsic::dx_dot3: 37 case Intrinsic::dx_dot4: 38 return true; 39 } 40 return false; 41 } 42 43 static SmallVector<Value *> populateOperands(Value *Arg, IRBuilder<> &Builder) { 44 SmallVector<Value *> ExtractedElements; 45 auto *VecArg = dyn_cast<FixedVectorType>(Arg->getType()); 46 for (unsigned I = 0; I < VecArg->getNumElements(); ++I) { 47 Value *Index = ConstantInt::get(Type::getInt32Ty(Arg->getContext()), I); 48 Value *ExtractedElement = Builder.CreateExtractElement(Arg, Index); 49 ExtractedElements.push_back(ExtractedElement); 50 } 51 return ExtractedElements; 52 } 53 54 static SmallVector<Value *> argVectorFlatten(CallInst *Orig, 55 IRBuilder<> &Builder) { 56 // Note: arg[NumOperands-1] is a pointer and is not needed by our flattening. 57 unsigned NumOperands = Orig->getNumOperands() - 1; 58 assert(NumOperands > 0); 59 Value *Arg0 = Orig->getOperand(0); 60 [[maybe_unused]] auto *VecArg0 = dyn_cast<FixedVectorType>(Arg0->getType()); 61 assert(VecArg0); 62 SmallVector<Value *> NewOperands = populateOperands(Arg0, Builder); 63 for (unsigned I = 1; I < NumOperands; ++I) { 64 Value *Arg = Orig->getOperand(I); 65 [[maybe_unused]] auto *VecArg = dyn_cast<FixedVectorType>(Arg->getType()); 66 assert(VecArg); 67 assert(VecArg0->getElementType() == VecArg->getElementType()); 68 assert(VecArg0->getNumElements() == VecArg->getNumElements()); 69 auto NextOperandList = populateOperands(Arg, Builder); 70 NewOperands.append(NextOperandList.begin(), NextOperandList.end()); 71 } 72 return NewOperands; 73 } 74 75 static void lowerIntrinsic(dxil::OpCode DXILOp, Function &F, Module &M) { 76 IRBuilder<> B(M.getContext()); 77 DXILOpBuilder DXILB(M, B); 78 Type *OverloadTy = DXILB.getOverloadTy(DXILOp, F.getFunctionType()); 79 for (User *U : make_early_inc_range(F.users())) { 80 CallInst *CI = dyn_cast<CallInst>(U); 81 if (!CI) 82 continue; 83 84 SmallVector<Value *> Args; 85 Value *DXILOpArg = B.getInt32(static_cast<unsigned>(DXILOp)); 86 Args.emplace_back(DXILOpArg); 87 B.SetInsertPoint(CI); 88 if (isVectorArgExpansion(F)) { 89 SmallVector<Value *> NewArgs = argVectorFlatten(CI, B); 90 Args.append(NewArgs.begin(), NewArgs.end()); 91 } else 92 Args.append(CI->arg_begin(), CI->arg_end()); 93 94 CallInst *DXILCI = 95 DXILB.createDXILOpCall(DXILOp, F.getReturnType(), OverloadTy, Args); 96 97 CI->replaceAllUsesWith(DXILCI); 98 CI->eraseFromParent(); 99 } 100 if (F.user_empty()) 101 F.eraseFromParent(); 102 } 103 104 static bool lowerIntrinsics(Module &M) { 105 bool Updated = false; 106 107 #define DXIL_OP_INTRINSIC_MAP 108 #include "DXILOperation.inc" 109 #undef DXIL_OP_INTRINSIC_MAP 110 111 for (Function &F : make_early_inc_range(M.functions())) { 112 if (!F.isDeclaration()) 113 continue; 114 Intrinsic::ID ID = F.getIntrinsicID(); 115 if (ID == Intrinsic::not_intrinsic) 116 continue; 117 auto LowerIt = LowerMap.find(ID); 118 if (LowerIt == LowerMap.end()) 119 continue; 120 lowerIntrinsic(LowerIt->second, F, M); 121 Updated = true; 122 } 123 return Updated; 124 } 125 126 namespace { 127 /// A pass that transforms external global definitions into declarations. 128 class DXILOpLowering : public PassInfoMixin<DXILOpLowering> { 129 public: 130 PreservedAnalyses run(Module &M, ModuleAnalysisManager &) { 131 if (lowerIntrinsics(M)) 132 return PreservedAnalyses::none(); 133 return PreservedAnalyses::all(); 134 } 135 }; 136 } // namespace 137 138 namespace { 139 class DXILOpLoweringLegacy : public ModulePass { 140 public: 141 bool runOnModule(Module &M) override { return lowerIntrinsics(M); } 142 StringRef getPassName() const override { return "DXIL Op Lowering"; } 143 DXILOpLoweringLegacy() : ModulePass(ID) {} 144 145 static char ID; // Pass identification. 146 void getAnalysisUsage(llvm::AnalysisUsage &AU) const override { 147 // Specify the passes that your pass depends on 148 AU.addRequired<DXILIntrinsicExpansionLegacy>(); 149 } 150 }; 151 char DXILOpLoweringLegacy::ID = 0; 152 } // end anonymous namespace 153 154 INITIALIZE_PASS_BEGIN(DXILOpLoweringLegacy, DEBUG_TYPE, "DXIL Op Lowering", 155 false, false) 156 INITIALIZE_PASS_END(DXILOpLoweringLegacy, DEBUG_TYPE, "DXIL Op Lowering", false, 157 false) 158 159 ModulePass *llvm::createDXILOpLoweringLegacyPass() { 160 return new DXILOpLoweringLegacy(); 161 } 162