//===- DXILIntrinsicExpansion.cpp - Expand intrinsics with no DXIL opcode -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
/// \file This file contains DXIL intrinsic expansions for those that don't have
/// opcodes in DirectX Intermediate Language (DXIL).
11 //===----------------------------------------------------------------------===//
12
13 #include "DXILIntrinsicExpansion.h"
14 #include "DirectX.h"
15 #include "llvm/ADT/STLExtras.h"
16 #include "llvm/ADT/SmallVector.h"
17 #include "llvm/CodeGen/Passes.h"
18 #include "llvm/IR/IRBuilder.h"
19 #include "llvm/IR/Instruction.h"
20 #include "llvm/IR/Instructions.h"
21 #include "llvm/IR/Intrinsics.h"
22 #include "llvm/IR/IntrinsicsDirectX.h"
23 #include "llvm/IR/Module.h"
24 #include "llvm/IR/PassManager.h"
25 #include "llvm/IR/Type.h"
26 #include "llvm/Pass.h"
27 #include "llvm/Support/ErrorHandling.h"
28 #include "llvm/Support/MathExtras.h"
29
30 #define DEBUG_TYPE "dxil-intrinsic-expansion"
31
32 using namespace llvm;
33
isIntrinsicExpansion(Function & F)34 static bool isIntrinsicExpansion(Function &F) {
35 switch (F.getIntrinsicID()) {
36 case Intrinsic::abs:
37 case Intrinsic::exp:
38 case Intrinsic::log:
39 case Intrinsic::log10:
40 case Intrinsic::pow:
41 case Intrinsic::dx_any:
42 case Intrinsic::dx_clamp:
43 case Intrinsic::dx_uclamp:
44 case Intrinsic::dx_lerp:
45 case Intrinsic::dx_sdot:
46 case Intrinsic::dx_udot:
47 return true;
48 }
49 return false;
50 }
51
expandAbs(CallInst * Orig)52 static bool expandAbs(CallInst *Orig) {
53 Value *X = Orig->getOperand(0);
54 IRBuilder<> Builder(Orig->getParent());
55 Builder.SetInsertPoint(Orig);
56 Type *Ty = X->getType();
57 Type *EltTy = Ty->getScalarType();
58 Constant *Zero = Ty->isVectorTy()
59 ? ConstantVector::getSplat(
60 ElementCount::getFixed(
61 cast<FixedVectorType>(Ty)->getNumElements()),
62 ConstantInt::get(EltTy, 0))
63 : ConstantInt::get(EltTy, 0);
64 auto *V = Builder.CreateSub(Zero, X);
65 auto *MaxCall =
66 Builder.CreateIntrinsic(Ty, Intrinsic::smax, {X, V}, nullptr, "dx.max");
67 Orig->replaceAllUsesWith(MaxCall);
68 Orig->eraseFromParent();
69 return true;
70 }
71
expandIntegerDot(CallInst * Orig,Intrinsic::ID DotIntrinsic)72 static bool expandIntegerDot(CallInst *Orig, Intrinsic::ID DotIntrinsic) {
73 assert(DotIntrinsic == Intrinsic::dx_sdot ||
74 DotIntrinsic == Intrinsic::dx_udot);
75 Intrinsic::ID MadIntrinsic = DotIntrinsic == Intrinsic::dx_sdot
76 ? Intrinsic::dx_imad
77 : Intrinsic::dx_umad;
78 Value *A = Orig->getOperand(0);
79 Value *B = Orig->getOperand(1);
80 [[maybe_unused]] Type *ATy = A->getType();
81 [[maybe_unused]] Type *BTy = B->getType();
82 assert(ATy->isVectorTy() && BTy->isVectorTy());
83
84 IRBuilder<> Builder(Orig->getParent());
85 Builder.SetInsertPoint(Orig);
86
87 auto *AVec = dyn_cast<FixedVectorType>(A->getType());
88 Value *Elt0 = Builder.CreateExtractElement(A, (uint64_t)0);
89 Value *Elt1 = Builder.CreateExtractElement(B, (uint64_t)0);
90 Value *Result = Builder.CreateMul(Elt0, Elt1);
91 for (unsigned I = 1; I < AVec->getNumElements(); I++) {
92 Elt0 = Builder.CreateExtractElement(A, I);
93 Elt1 = Builder.CreateExtractElement(B, I);
94 Result = Builder.CreateIntrinsic(Result->getType(), MadIntrinsic,
95 ArrayRef<Value *>{Elt0, Elt1, Result},
96 nullptr, "dx.mad");
97 }
98 Orig->replaceAllUsesWith(Result);
99 Orig->eraseFromParent();
100 return true;
101 }
102
expandExpIntrinsic(CallInst * Orig)103 static bool expandExpIntrinsic(CallInst *Orig) {
104 Value *X = Orig->getOperand(0);
105 IRBuilder<> Builder(Orig->getParent());
106 Builder.SetInsertPoint(Orig);
107 Type *Ty = X->getType();
108 Type *EltTy = Ty->getScalarType();
109 Constant *Log2eConst =
110 Ty->isVectorTy() ? ConstantVector::getSplat(
111 ElementCount::getFixed(
112 cast<FixedVectorType>(Ty)->getNumElements()),
113 ConstantFP::get(EltTy, numbers::log2ef))
114 : ConstantFP::get(EltTy, numbers::log2ef);
115 Value *NewX = Builder.CreateFMul(Log2eConst, X);
116 auto *Exp2Call =
117 Builder.CreateIntrinsic(Ty, Intrinsic::exp2, {NewX}, nullptr, "dx.exp2");
118 Exp2Call->setTailCall(Orig->isTailCall());
119 Exp2Call->setAttributes(Orig->getAttributes());
120 Orig->replaceAllUsesWith(Exp2Call);
121 Orig->eraseFromParent();
122 return true;
123 }
124
expandAnyIntrinsic(CallInst * Orig)125 static bool expandAnyIntrinsic(CallInst *Orig) {
126 Value *X = Orig->getOperand(0);
127 IRBuilder<> Builder(Orig->getParent());
128 Builder.SetInsertPoint(Orig);
129 Type *Ty = X->getType();
130 Type *EltTy = Ty->getScalarType();
131
132 if (!Ty->isVectorTy()) {
133 Value *Cond = EltTy->isFloatingPointTy()
134 ? Builder.CreateFCmpUNE(X, ConstantFP::get(EltTy, 0))
135 : Builder.CreateICmpNE(X, ConstantInt::get(EltTy, 0));
136 Orig->replaceAllUsesWith(Cond);
137 } else {
138 auto *XVec = dyn_cast<FixedVectorType>(Ty);
139 Value *Cond =
140 EltTy->isFloatingPointTy()
141 ? Builder.CreateFCmpUNE(
142 X, ConstantVector::getSplat(
143 ElementCount::getFixed(XVec->getNumElements()),
144 ConstantFP::get(EltTy, 0)))
145 : Builder.CreateICmpNE(
146 X, ConstantVector::getSplat(
147 ElementCount::getFixed(XVec->getNumElements()),
148 ConstantInt::get(EltTy, 0)));
149 Value *Result = Builder.CreateExtractElement(Cond, (uint64_t)0);
150 for (unsigned I = 1; I < XVec->getNumElements(); I++) {
151 Value *Elt = Builder.CreateExtractElement(Cond, I);
152 Result = Builder.CreateOr(Result, Elt);
153 }
154 Orig->replaceAllUsesWith(Result);
155 }
156 Orig->eraseFromParent();
157 return true;
158 }
159
expandLerpIntrinsic(CallInst * Orig)160 static bool expandLerpIntrinsic(CallInst *Orig) {
161 Value *X = Orig->getOperand(0);
162 Value *Y = Orig->getOperand(1);
163 Value *S = Orig->getOperand(2);
164 IRBuilder<> Builder(Orig->getParent());
165 Builder.SetInsertPoint(Orig);
166 auto *V = Builder.CreateFSub(Y, X);
167 V = Builder.CreateFMul(S, V);
168 auto *Result = Builder.CreateFAdd(X, V, "dx.lerp");
169 Orig->replaceAllUsesWith(Result);
170 Orig->eraseFromParent();
171 return true;
172 }
173
expandLogIntrinsic(CallInst * Orig,float LogConstVal=numbers::ln2f)174 static bool expandLogIntrinsic(CallInst *Orig,
175 float LogConstVal = numbers::ln2f) {
176 Value *X = Orig->getOperand(0);
177 IRBuilder<> Builder(Orig->getParent());
178 Builder.SetInsertPoint(Orig);
179 Type *Ty = X->getType();
180 Type *EltTy = Ty->getScalarType();
181 Constant *Ln2Const =
182 Ty->isVectorTy() ? ConstantVector::getSplat(
183 ElementCount::getFixed(
184 cast<FixedVectorType>(Ty)->getNumElements()),
185 ConstantFP::get(EltTy, LogConstVal))
186 : ConstantFP::get(EltTy, LogConstVal);
187 auto *Log2Call =
188 Builder.CreateIntrinsic(Ty, Intrinsic::log2, {X}, nullptr, "elt.log2");
189 Log2Call->setTailCall(Orig->isTailCall());
190 Log2Call->setAttributes(Orig->getAttributes());
191 auto *Result = Builder.CreateFMul(Ln2Const, Log2Call);
192 Orig->replaceAllUsesWith(Result);
193 Orig->eraseFromParent();
194 return true;
195 }
expandLog10Intrinsic(CallInst * Orig)196 static bool expandLog10Intrinsic(CallInst *Orig) {
197 return expandLogIntrinsic(Orig, numbers::ln2f / numbers::ln10f);
198 }
199
expandPowIntrinsic(CallInst * Orig)200 static bool expandPowIntrinsic(CallInst *Orig) {
201
202 Value *X = Orig->getOperand(0);
203 Value *Y = Orig->getOperand(1);
204 Type *Ty = X->getType();
205 IRBuilder<> Builder(Orig->getParent());
206 Builder.SetInsertPoint(Orig);
207
208 auto *Log2Call =
209 Builder.CreateIntrinsic(Ty, Intrinsic::log2, {X}, nullptr, "elt.log2");
210 auto *Mul = Builder.CreateFMul(Log2Call, Y);
211 auto *Exp2Call =
212 Builder.CreateIntrinsic(Ty, Intrinsic::exp2, {Mul}, nullptr, "elt.exp2");
213 Exp2Call->setTailCall(Orig->isTailCall());
214 Exp2Call->setAttributes(Orig->getAttributes());
215 Orig->replaceAllUsesWith(Exp2Call);
216 Orig->eraseFromParent();
217 return true;
218 }
219
getMaxForClamp(Type * ElemTy,Intrinsic::ID ClampIntrinsic)220 static Intrinsic::ID getMaxForClamp(Type *ElemTy,
221 Intrinsic::ID ClampIntrinsic) {
222 if (ClampIntrinsic == Intrinsic::dx_uclamp)
223 return Intrinsic::umax;
224 assert(ClampIntrinsic == Intrinsic::dx_clamp);
225 if (ElemTy->isVectorTy())
226 ElemTy = ElemTy->getScalarType();
227 if (ElemTy->isIntegerTy())
228 return Intrinsic::smax;
229 assert(ElemTy->isFloatingPointTy());
230 return Intrinsic::maxnum;
231 }
232
getMinForClamp(Type * ElemTy,Intrinsic::ID ClampIntrinsic)233 static Intrinsic::ID getMinForClamp(Type *ElemTy,
234 Intrinsic::ID ClampIntrinsic) {
235 if (ClampIntrinsic == Intrinsic::dx_uclamp)
236 return Intrinsic::umin;
237 assert(ClampIntrinsic == Intrinsic::dx_clamp);
238 if (ElemTy->isVectorTy())
239 ElemTy = ElemTy->getScalarType();
240 if (ElemTy->isIntegerTy())
241 return Intrinsic::smin;
242 assert(ElemTy->isFloatingPointTy());
243 return Intrinsic::minnum;
244 }
245
expandClampIntrinsic(CallInst * Orig,Intrinsic::ID ClampIntrinsic)246 static bool expandClampIntrinsic(CallInst *Orig, Intrinsic::ID ClampIntrinsic) {
247 Value *X = Orig->getOperand(0);
248 Value *Min = Orig->getOperand(1);
249 Value *Max = Orig->getOperand(2);
250 Type *Ty = X->getType();
251 IRBuilder<> Builder(Orig->getParent());
252 Builder.SetInsertPoint(Orig);
253 auto *MaxCall = Builder.CreateIntrinsic(
254 Ty, getMaxForClamp(Ty, ClampIntrinsic), {X, Min}, nullptr, "dx.max");
255 auto *MinCall =
256 Builder.CreateIntrinsic(Ty, getMinForClamp(Ty, ClampIntrinsic),
257 {MaxCall, Max}, nullptr, "dx.min");
258
259 Orig->replaceAllUsesWith(MinCall);
260 Orig->eraseFromParent();
261 return true;
262 }
263
expandIntrinsic(Function & F,CallInst * Orig)264 static bool expandIntrinsic(Function &F, CallInst *Orig) {
265 switch (F.getIntrinsicID()) {
266 case Intrinsic::abs:
267 return expandAbs(Orig);
268 case Intrinsic::exp:
269 return expandExpIntrinsic(Orig);
270 case Intrinsic::log:
271 return expandLogIntrinsic(Orig);
272 case Intrinsic::log10:
273 return expandLog10Intrinsic(Orig);
274 case Intrinsic::pow:
275 return expandPowIntrinsic(Orig);
276 case Intrinsic::dx_any:
277 return expandAnyIntrinsic(Orig);
278 case Intrinsic::dx_uclamp:
279 case Intrinsic::dx_clamp:
280 return expandClampIntrinsic(Orig, F.getIntrinsicID());
281 case Intrinsic::dx_lerp:
282 return expandLerpIntrinsic(Orig);
283 case Intrinsic::dx_sdot:
284 case Intrinsic::dx_udot:
285 return expandIntegerDot(Orig, F.getIntrinsicID());
286 }
287 return false;
288 }
289
expansionIntrinsics(Module & M)290 static bool expansionIntrinsics(Module &M) {
291 for (auto &F : make_early_inc_range(M.functions())) {
292 if (!isIntrinsicExpansion(F))
293 continue;
294 bool IntrinsicExpanded = false;
295 for (User *U : make_early_inc_range(F.users())) {
296 auto *IntrinsicCall = dyn_cast<CallInst>(U);
297 if (!IntrinsicCall)
298 continue;
299 IntrinsicExpanded = expandIntrinsic(F, IntrinsicCall);
300 }
301 if (F.user_empty() && IntrinsicExpanded)
302 F.eraseFromParent();
303 }
304 return true;
305 }
306
run(Module & M,ModuleAnalysisManager &)307 PreservedAnalyses DXILIntrinsicExpansion::run(Module &M,
308 ModuleAnalysisManager &) {
309 if (expansionIntrinsics(M))
310 return PreservedAnalyses::none();
311 return PreservedAnalyses::all();
312 }
313
runOnModule(Module & M)314 bool DXILIntrinsicExpansionLegacy::runOnModule(Module &M) {
315 return expansionIntrinsics(M);
316 }
317
318 char DXILIntrinsicExpansionLegacy::ID = 0;
319
320 INITIALIZE_PASS_BEGIN(DXILIntrinsicExpansionLegacy, DEBUG_TYPE,
321 "DXIL Intrinsic Expansion", false, false)
322 INITIALIZE_PASS_END(DXILIntrinsicExpansionLegacy, DEBUG_TYPE,
323 "DXIL Intrinsic Expansion", false, false)
324
createDXILIntrinsicExpansionLegacyPass()325 ModulePass *llvm::createDXILIntrinsicExpansionLegacyPass() {
326 return new DXILIntrinsicExpansionLegacy();
327 }
328