xref: /freebsd/contrib/llvm-project/llvm/lib/Target/DirectX/DXILOpLowering.cpp (revision b2d2a78ad80ec68d4a17f5aef97d21686cb1e29b)
//===- DXILOpLowering.cpp - Lowering LLVM intrinsic to DXILOp function ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
/// \file This file contains passes and utilities to lower LLVM intrinsic calls
/// to DXILOp function calls.
11 //===----------------------------------------------------------------------===//
12 
13 #include "DXILConstants.h"
14 #include "DXILIntrinsicExpansion.h"
15 #include "DXILOpBuilder.h"
16 #include "DirectX.h"
17 #include "llvm/ADT/SmallVector.h"
18 #include "llvm/CodeGen/Passes.h"
19 #include "llvm/IR/IRBuilder.h"
20 #include "llvm/IR/Instruction.h"
21 #include "llvm/IR/Intrinsics.h"
22 #include "llvm/IR/IntrinsicsDirectX.h"
23 #include "llvm/IR/Module.h"
24 #include "llvm/IR/PassManager.h"
25 #include "llvm/Pass.h"
26 #include "llvm/Support/ErrorHandling.h"
27 
28 #define DEBUG_TYPE "dxil-op-lower"
29 
30 using namespace llvm;
31 using namespace llvm::dxil;
32 
33 static bool isVectorArgExpansion(Function &F) {
34   switch (F.getIntrinsicID()) {
35   case Intrinsic::dx_dot2:
36   case Intrinsic::dx_dot3:
37   case Intrinsic::dx_dot4:
38     return true;
39   }
40   return false;
41 }
42 
43 static SmallVector<Value *> populateOperands(Value *Arg, IRBuilder<> &Builder) {
44   SmallVector<Value *> ExtractedElements;
45   auto *VecArg = dyn_cast<FixedVectorType>(Arg->getType());
46   for (unsigned I = 0; I < VecArg->getNumElements(); ++I) {
47     Value *Index = ConstantInt::get(Type::getInt32Ty(Arg->getContext()), I);
48     Value *ExtractedElement = Builder.CreateExtractElement(Arg, Index);
49     ExtractedElements.push_back(ExtractedElement);
50   }
51   return ExtractedElements;
52 }
53 
54 static SmallVector<Value *> argVectorFlatten(CallInst *Orig,
55                                              IRBuilder<> &Builder) {
56   // Note: arg[NumOperands-1] is a pointer and is not needed by our flattening.
57   unsigned NumOperands = Orig->getNumOperands() - 1;
58   assert(NumOperands > 0);
59   Value *Arg0 = Orig->getOperand(0);
60   [[maybe_unused]] auto *VecArg0 = dyn_cast<FixedVectorType>(Arg0->getType());
61   assert(VecArg0);
62   SmallVector<Value *> NewOperands = populateOperands(Arg0, Builder);
63   for (unsigned I = 1; I < NumOperands; ++I) {
64     Value *Arg = Orig->getOperand(I);
65     [[maybe_unused]] auto *VecArg = dyn_cast<FixedVectorType>(Arg->getType());
66     assert(VecArg);
67     assert(VecArg0->getElementType() == VecArg->getElementType());
68     assert(VecArg0->getNumElements() == VecArg->getNumElements());
69     auto NextOperandList = populateOperands(Arg, Builder);
70     NewOperands.append(NextOperandList.begin(), NextOperandList.end());
71   }
72   return NewOperands;
73 }
74 
75 static void lowerIntrinsic(dxil::OpCode DXILOp, Function &F, Module &M) {
76   IRBuilder<> B(M.getContext());
77   DXILOpBuilder DXILB(M, B);
78   Type *OverloadTy = DXILB.getOverloadTy(DXILOp, F.getFunctionType());
79   for (User *U : make_early_inc_range(F.users())) {
80     CallInst *CI = dyn_cast<CallInst>(U);
81     if (!CI)
82       continue;
83 
84     SmallVector<Value *> Args;
85     Value *DXILOpArg = B.getInt32(static_cast<unsigned>(DXILOp));
86     Args.emplace_back(DXILOpArg);
87     B.SetInsertPoint(CI);
88     if (isVectorArgExpansion(F)) {
89       SmallVector<Value *> NewArgs = argVectorFlatten(CI, B);
90       Args.append(NewArgs.begin(), NewArgs.end());
91     } else
92       Args.append(CI->arg_begin(), CI->arg_end());
93 
94     CallInst *DXILCI =
95         DXILB.createDXILOpCall(DXILOp, F.getReturnType(), OverloadTy, Args);
96 
97     CI->replaceAllUsesWith(DXILCI);
98     CI->eraseFromParent();
99   }
100   if (F.user_empty())
101     F.eraseFromParent();
102 }
103 
104 static bool lowerIntrinsics(Module &M) {
105   bool Updated = false;
106 
107 #define DXIL_OP_INTRINSIC_MAP
108 #include "DXILOperation.inc"
109 #undef DXIL_OP_INTRINSIC_MAP
110 
111   for (Function &F : make_early_inc_range(M.functions())) {
112     if (!F.isDeclaration())
113       continue;
114     Intrinsic::ID ID = F.getIntrinsicID();
115     if (ID == Intrinsic::not_intrinsic)
116       continue;
117     auto LowerIt = LowerMap.find(ID);
118     if (LowerIt == LowerMap.end())
119       continue;
120     lowerIntrinsic(LowerIt->second, F, M);
121     Updated = true;
122   }
123   return Updated;
124 }
125 
126 namespace {
127 /// A pass that transforms external global definitions into declarations.
128 class DXILOpLowering : public PassInfoMixin<DXILOpLowering> {
129 public:
130   PreservedAnalyses run(Module &M, ModuleAnalysisManager &) {
131     if (lowerIntrinsics(M))
132       return PreservedAnalyses::none();
133     return PreservedAnalyses::all();
134   }
135 };
136 } // namespace
137 
138 namespace {
139 class DXILOpLoweringLegacy : public ModulePass {
140 public:
141   bool runOnModule(Module &M) override { return lowerIntrinsics(M); }
142   StringRef getPassName() const override { return "DXIL Op Lowering"; }
143   DXILOpLoweringLegacy() : ModulePass(ID) {}
144 
145   static char ID; // Pass identification.
146   void getAnalysisUsage(llvm::AnalysisUsage &AU) const override {
147     // Specify the passes that your pass depends on
148     AU.addRequired<DXILIntrinsicExpansionLegacy>();
149   }
150 };
151 char DXILOpLoweringLegacy::ID = 0;
152 } // end anonymous namespace
153 
154 INITIALIZE_PASS_BEGIN(DXILOpLoweringLegacy, DEBUG_TYPE, "DXIL Op Lowering",
155                       false, false)
156 INITIALIZE_PASS_END(DXILOpLoweringLegacy, DEBUG_TYPE, "DXIL Op Lowering", false,
157                     false)
158 
159 ModulePass *llvm::createDXILOpLoweringLegacyPass() {
160   return new DXILOpLoweringLegacy();
161 }
162