xref: /freebsd/contrib/llvm-project/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp (revision 5036d9652a5701d00e9e40ea942c278e9f77d33d)
//===- DXILIntrinsicExpansion.cpp - Expand intrinsics without DXIL opcodes-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
/// \file This file contains DXIL intrinsic expansions for those that don't have
/// opcodes in DirectX Intermediate Language (DXIL).
11 //===----------------------------------------------------------------------===//
12 
13 #include "DXILIntrinsicExpansion.h"
14 #include "DirectX.h"
15 #include "llvm/ADT/STLExtras.h"
16 #include "llvm/ADT/SmallVector.h"
17 #include "llvm/CodeGen/Passes.h"
18 #include "llvm/IR/IRBuilder.h"
19 #include "llvm/IR/Instruction.h"
20 #include "llvm/IR/Instructions.h"
21 #include "llvm/IR/Intrinsics.h"
22 #include "llvm/IR/IntrinsicsDirectX.h"
23 #include "llvm/IR/Module.h"
24 #include "llvm/IR/PassManager.h"
25 #include "llvm/IR/Type.h"
26 #include "llvm/Pass.h"
27 #include "llvm/Support/ErrorHandling.h"
28 #include "llvm/Support/MathExtras.h"
29 
30 #define DEBUG_TYPE "dxil-intrinsic-expansion"
31 
32 using namespace llvm;
33 
34 static bool isIntrinsicExpansion(Function &F) {
35   switch (F.getIntrinsicID()) {
36   case Intrinsic::abs:
37   case Intrinsic::exp:
38   case Intrinsic::log:
39   case Intrinsic::log10:
40   case Intrinsic::pow:
41   case Intrinsic::dx_any:
42   case Intrinsic::dx_clamp:
43   case Intrinsic::dx_uclamp:
44   case Intrinsic::dx_lerp:
45   case Intrinsic::dx_sdot:
46   case Intrinsic::dx_udot:
47     return true;
48   }
49   return false;
50 }
51 
52 static bool expandAbs(CallInst *Orig) {
53   Value *X = Orig->getOperand(0);
54   IRBuilder<> Builder(Orig->getParent());
55   Builder.SetInsertPoint(Orig);
56   Type *Ty = X->getType();
57   Type *EltTy = Ty->getScalarType();
58   Constant *Zero = Ty->isVectorTy()
59                        ? ConstantVector::getSplat(
60                              ElementCount::getFixed(
61                                  cast<FixedVectorType>(Ty)->getNumElements()),
62                              ConstantInt::get(EltTy, 0))
63                        : ConstantInt::get(EltTy, 0);
64   auto *V = Builder.CreateSub(Zero, X);
65   auto *MaxCall =
66       Builder.CreateIntrinsic(Ty, Intrinsic::smax, {X, V}, nullptr, "dx.max");
67   Orig->replaceAllUsesWith(MaxCall);
68   Orig->eraseFromParent();
69   return true;
70 }
71 
72 static bool expandIntegerDot(CallInst *Orig, Intrinsic::ID DotIntrinsic) {
73   assert(DotIntrinsic == Intrinsic::dx_sdot ||
74          DotIntrinsic == Intrinsic::dx_udot);
75   Intrinsic::ID MadIntrinsic = DotIntrinsic == Intrinsic::dx_sdot
76                                    ? Intrinsic::dx_imad
77                                    : Intrinsic::dx_umad;
78   Value *A = Orig->getOperand(0);
79   Value *B = Orig->getOperand(1);
80   [[maybe_unused]] Type *ATy = A->getType();
81   [[maybe_unused]] Type *BTy = B->getType();
82   assert(ATy->isVectorTy() && BTy->isVectorTy());
83 
84   IRBuilder<> Builder(Orig->getParent());
85   Builder.SetInsertPoint(Orig);
86 
87   auto *AVec = dyn_cast<FixedVectorType>(A->getType());
88   Value *Elt0 = Builder.CreateExtractElement(A, (uint64_t)0);
89   Value *Elt1 = Builder.CreateExtractElement(B, (uint64_t)0);
90   Value *Result = Builder.CreateMul(Elt0, Elt1);
91   for (unsigned I = 1; I < AVec->getNumElements(); I++) {
92     Elt0 = Builder.CreateExtractElement(A, I);
93     Elt1 = Builder.CreateExtractElement(B, I);
94     Result = Builder.CreateIntrinsic(Result->getType(), MadIntrinsic,
95                                      ArrayRef<Value *>{Elt0, Elt1, Result},
96                                      nullptr, "dx.mad");
97   }
98   Orig->replaceAllUsesWith(Result);
99   Orig->eraseFromParent();
100   return true;
101 }
102 
103 static bool expandExpIntrinsic(CallInst *Orig) {
104   Value *X = Orig->getOperand(0);
105   IRBuilder<> Builder(Orig->getParent());
106   Builder.SetInsertPoint(Orig);
107   Type *Ty = X->getType();
108   Type *EltTy = Ty->getScalarType();
109   Constant *Log2eConst =
110       Ty->isVectorTy() ? ConstantVector::getSplat(
111                              ElementCount::getFixed(
112                                  cast<FixedVectorType>(Ty)->getNumElements()),
113                              ConstantFP::get(EltTy, numbers::log2ef))
114                        : ConstantFP::get(EltTy, numbers::log2ef);
115   Value *NewX = Builder.CreateFMul(Log2eConst, X);
116   auto *Exp2Call =
117       Builder.CreateIntrinsic(Ty, Intrinsic::exp2, {NewX}, nullptr, "dx.exp2");
118   Exp2Call->setTailCall(Orig->isTailCall());
119   Exp2Call->setAttributes(Orig->getAttributes());
120   Orig->replaceAllUsesWith(Exp2Call);
121   Orig->eraseFromParent();
122   return true;
123 }
124 
125 static bool expandAnyIntrinsic(CallInst *Orig) {
126   Value *X = Orig->getOperand(0);
127   IRBuilder<> Builder(Orig->getParent());
128   Builder.SetInsertPoint(Orig);
129   Type *Ty = X->getType();
130   Type *EltTy = Ty->getScalarType();
131 
132   if (!Ty->isVectorTy()) {
133     Value *Cond = EltTy->isFloatingPointTy()
134                       ? Builder.CreateFCmpUNE(X, ConstantFP::get(EltTy, 0))
135                       : Builder.CreateICmpNE(X, ConstantInt::get(EltTy, 0));
136     Orig->replaceAllUsesWith(Cond);
137   } else {
138     auto *XVec = dyn_cast<FixedVectorType>(Ty);
139     Value *Cond =
140         EltTy->isFloatingPointTy()
141             ? Builder.CreateFCmpUNE(
142                   X, ConstantVector::getSplat(
143                          ElementCount::getFixed(XVec->getNumElements()),
144                          ConstantFP::get(EltTy, 0)))
145             : Builder.CreateICmpNE(
146                   X, ConstantVector::getSplat(
147                          ElementCount::getFixed(XVec->getNumElements()),
148                          ConstantInt::get(EltTy, 0)));
149     Value *Result = Builder.CreateExtractElement(Cond, (uint64_t)0);
150     for (unsigned I = 1; I < XVec->getNumElements(); I++) {
151       Value *Elt = Builder.CreateExtractElement(Cond, I);
152       Result = Builder.CreateOr(Result, Elt);
153     }
154     Orig->replaceAllUsesWith(Result);
155   }
156   Orig->eraseFromParent();
157   return true;
158 }
159 
160 static bool expandLerpIntrinsic(CallInst *Orig) {
161   Value *X = Orig->getOperand(0);
162   Value *Y = Orig->getOperand(1);
163   Value *S = Orig->getOperand(2);
164   IRBuilder<> Builder(Orig->getParent());
165   Builder.SetInsertPoint(Orig);
166   auto *V = Builder.CreateFSub(Y, X);
167   V = Builder.CreateFMul(S, V);
168   auto *Result = Builder.CreateFAdd(X, V, "dx.lerp");
169   Orig->replaceAllUsesWith(Result);
170   Orig->eraseFromParent();
171   return true;
172 }
173 
174 static bool expandLogIntrinsic(CallInst *Orig,
175                                float LogConstVal = numbers::ln2f) {
176   Value *X = Orig->getOperand(0);
177   IRBuilder<> Builder(Orig->getParent());
178   Builder.SetInsertPoint(Orig);
179   Type *Ty = X->getType();
180   Type *EltTy = Ty->getScalarType();
181   Constant *Ln2Const =
182       Ty->isVectorTy() ? ConstantVector::getSplat(
183                              ElementCount::getFixed(
184                                  cast<FixedVectorType>(Ty)->getNumElements()),
185                              ConstantFP::get(EltTy, LogConstVal))
186                        : ConstantFP::get(EltTy, LogConstVal);
187   auto *Log2Call =
188       Builder.CreateIntrinsic(Ty, Intrinsic::log2, {X}, nullptr, "elt.log2");
189   Log2Call->setTailCall(Orig->isTailCall());
190   Log2Call->setAttributes(Orig->getAttributes());
191   auto *Result = Builder.CreateFMul(Ln2Const, Log2Call);
192   Orig->replaceAllUsesWith(Result);
193   Orig->eraseFromParent();
194   return true;
195 }
196 static bool expandLog10Intrinsic(CallInst *Orig) {
197   return expandLogIntrinsic(Orig, numbers::ln2f / numbers::ln10f);
198 }
199 
200 static bool expandPowIntrinsic(CallInst *Orig) {
201 
202   Value *X = Orig->getOperand(0);
203   Value *Y = Orig->getOperand(1);
204   Type *Ty = X->getType();
205   IRBuilder<> Builder(Orig->getParent());
206   Builder.SetInsertPoint(Orig);
207 
208   auto *Log2Call =
209       Builder.CreateIntrinsic(Ty, Intrinsic::log2, {X}, nullptr, "elt.log2");
210   auto *Mul = Builder.CreateFMul(Log2Call, Y);
211   auto *Exp2Call =
212       Builder.CreateIntrinsic(Ty, Intrinsic::exp2, {Mul}, nullptr, "elt.exp2");
213   Exp2Call->setTailCall(Orig->isTailCall());
214   Exp2Call->setAttributes(Orig->getAttributes());
215   Orig->replaceAllUsesWith(Exp2Call);
216   Orig->eraseFromParent();
217   return true;
218 }
219 
220 static Intrinsic::ID getMaxForClamp(Type *ElemTy,
221                                     Intrinsic::ID ClampIntrinsic) {
222   if (ClampIntrinsic == Intrinsic::dx_uclamp)
223     return Intrinsic::umax;
224   assert(ClampIntrinsic == Intrinsic::dx_clamp);
225   if (ElemTy->isVectorTy())
226     ElemTy = ElemTy->getScalarType();
227   if (ElemTy->isIntegerTy())
228     return Intrinsic::smax;
229   assert(ElemTy->isFloatingPointTy());
230   return Intrinsic::maxnum;
231 }
232 
233 static Intrinsic::ID getMinForClamp(Type *ElemTy,
234                                     Intrinsic::ID ClampIntrinsic) {
235   if (ClampIntrinsic == Intrinsic::dx_uclamp)
236     return Intrinsic::umin;
237   assert(ClampIntrinsic == Intrinsic::dx_clamp);
238   if (ElemTy->isVectorTy())
239     ElemTy = ElemTy->getScalarType();
240   if (ElemTy->isIntegerTy())
241     return Intrinsic::smin;
242   assert(ElemTy->isFloatingPointTy());
243   return Intrinsic::minnum;
244 }
245 
246 static bool expandClampIntrinsic(CallInst *Orig, Intrinsic::ID ClampIntrinsic) {
247   Value *X = Orig->getOperand(0);
248   Value *Min = Orig->getOperand(1);
249   Value *Max = Orig->getOperand(2);
250   Type *Ty = X->getType();
251   IRBuilder<> Builder(Orig->getParent());
252   Builder.SetInsertPoint(Orig);
253   auto *MaxCall = Builder.CreateIntrinsic(
254       Ty, getMaxForClamp(Ty, ClampIntrinsic), {X, Min}, nullptr, "dx.max");
255   auto *MinCall =
256       Builder.CreateIntrinsic(Ty, getMinForClamp(Ty, ClampIntrinsic),
257                               {MaxCall, Max}, nullptr, "dx.min");
258 
259   Orig->replaceAllUsesWith(MinCall);
260   Orig->eraseFromParent();
261   return true;
262 }
263 
264 static bool expandIntrinsic(Function &F, CallInst *Orig) {
265   switch (F.getIntrinsicID()) {
266   case Intrinsic::abs:
267     return expandAbs(Orig);
268   case Intrinsic::exp:
269     return expandExpIntrinsic(Orig);
270   case Intrinsic::log:
271     return expandLogIntrinsic(Orig);
272   case Intrinsic::log10:
273     return expandLog10Intrinsic(Orig);
274   case Intrinsic::pow:
275     return expandPowIntrinsic(Orig);
276   case Intrinsic::dx_any:
277     return expandAnyIntrinsic(Orig);
278   case Intrinsic::dx_uclamp:
279   case Intrinsic::dx_clamp:
280     return expandClampIntrinsic(Orig, F.getIntrinsicID());
281   case Intrinsic::dx_lerp:
282     return expandLerpIntrinsic(Orig);
283   case Intrinsic::dx_sdot:
284   case Intrinsic::dx_udot:
285     return expandIntegerDot(Orig, F.getIntrinsicID());
286   }
287   return false;
288 }
289 
290 static bool expansionIntrinsics(Module &M) {
291   for (auto &F : make_early_inc_range(M.functions())) {
292     if (!isIntrinsicExpansion(F))
293       continue;
294     bool IntrinsicExpanded = false;
295     for (User *U : make_early_inc_range(F.users())) {
296       auto *IntrinsicCall = dyn_cast<CallInst>(U);
297       if (!IntrinsicCall)
298         continue;
299       IntrinsicExpanded = expandIntrinsic(F, IntrinsicCall);
300     }
301     if (F.user_empty() && IntrinsicExpanded)
302       F.eraseFromParent();
303   }
304   return true;
305 }
306 
307 PreservedAnalyses DXILIntrinsicExpansion::run(Module &M,
308                                               ModuleAnalysisManager &) {
309   if (expansionIntrinsics(M))
310     return PreservedAnalyses::none();
311   return PreservedAnalyses::all();
312 }
313 
314 bool DXILIntrinsicExpansionLegacy::runOnModule(Module &M) {
315   return expansionIntrinsics(M);
316 }
317 
// Definition of the legacy pass's static ID member. NOTE(review): LLVM's
// legacy PM conventionally identifies a pass by this member's address, not its
// value.
char DXILIntrinsicExpansionLegacy::ID = 0;

// Register the legacy pass under the DEBUG_TYPE name
// ("dxil-intrinsic-expansion") with the description "DXIL Intrinsic Expansion".
// NOTE(review): the two trailing `false` flags are presumably INITIALIZE_PASS's
// cfg-only / is-analysis flags; confirm against llvm/PassSupport.h.
INITIALIZE_PASS_BEGIN(DXILIntrinsicExpansionLegacy, DEBUG_TYPE,
                      "DXIL Intrinsic Expansion", false, false)
INITIALIZE_PASS_END(DXILIntrinsicExpansionLegacy, DEBUG_TYPE,
                    "DXIL Intrinsic Expansion", false, false)
324 
325 ModulePass *llvm::createDXILIntrinsicExpansionLegacyPass() {
326   return new DXILIntrinsicExpansionLegacy();
327 }
328