1 //===- DXILIntrinsicExpansion.cpp - Prepare LLVM Module for DXIL encoding--===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
/// \file This file contains expansions for intrinsics that have no
/// corresponding opcode in DirectX Intermediate Language (DXIL).
11 //===----------------------------------------------------------------------===//
12
#include "DXILIntrinsicExpansion.h"
#include "DirectX.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/IntrinsicsDirectX.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/PassManager.h"
#include "llvm/IR/Type.h"
#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/MathExtras.h"
31
32 #define DEBUG_TYPE "dxil-intrinsic-expansion"
33
34 using namespace llvm;
35
36 class DXILIntrinsicExpansionLegacy : public ModulePass {
37
38 public:
39 bool runOnModule(Module &M) override;
DXILIntrinsicExpansionLegacy()40 DXILIntrinsicExpansionLegacy() : ModulePass(ID) {}
41
42 static char ID; // Pass identification.
43 };
44
resourceAccessNeeds64BitExpansion(Module * M,Type * OverloadTy,bool IsRaw)45 static bool resourceAccessNeeds64BitExpansion(Module *M, Type *OverloadTy,
46 bool IsRaw) {
47 if (IsRaw && M->getTargetTriple().getDXILVersion() > VersionTuple(1, 2))
48 return false;
49
50 Type *ScalarTy = OverloadTy->getScalarType();
51 return ScalarTy->isDoubleTy() || ScalarTy->isIntegerTy(64);
52 }
53
/// Returns true if calls to \p F should be handed to expandIntrinsic.
///
/// The math and DirectX intrinsics in the first group are always selected.
/// Raw/typed buffer loads and stores are selected only when
/// resourceAccessNeeds64BitExpansion says their element type must be split;
/// the element type is field 0 of the returned struct for loads, and the
/// value parameter (index 3 for raw, 2 for typed) for stores.
/// NOTE(review): keep this list in sync with the cases in expandIntrinsic; an
/// intrinsic selected here but not handled there is left unexpanded.
static bool isIntrinsicExpansion(Function &F) {
  switch (F.getIntrinsicID()) {
  case Intrinsic::abs:
  case Intrinsic::atan2:
  case Intrinsic::exp:
  case Intrinsic::is_fpclass:
  case Intrinsic::log:
  case Intrinsic::log10:
  case Intrinsic::pow:
  case Intrinsic::powi:
  case Intrinsic::dx_all:
  case Intrinsic::dx_any:
  case Intrinsic::dx_cross:
  case Intrinsic::dx_uclamp:
  case Intrinsic::dx_sclamp:
  case Intrinsic::dx_nclamp:
  case Intrinsic::dx_degrees:
  case Intrinsic::dx_lerp:
  case Intrinsic::dx_normalize:
  case Intrinsic::dx_fdot:
  case Intrinsic::dx_sdot:
  case Intrinsic::dx_udot:
  case Intrinsic::dx_sign:
  case Intrinsic::dx_step:
  case Intrinsic::dx_radians:
  case Intrinsic::usub_sat:
  case Intrinsic::vector_reduce_add:
  case Intrinsic::vector_reduce_fadd:
    return true;
  // Buffer accesses are only expanded when they carry 64-bit elements.
  case Intrinsic::dx_resource_load_rawbuffer:
    return resourceAccessNeeds64BitExpansion(
        F.getParent(), F.getReturnType()->getStructElementType(0),
        /*IsRaw*/ true);
  case Intrinsic::dx_resource_load_typedbuffer:
    return resourceAccessNeeds64BitExpansion(
        F.getParent(), F.getReturnType()->getStructElementType(0),
        /*IsRaw*/ false);
  case Intrinsic::dx_resource_store_rawbuffer:
    return resourceAccessNeeds64BitExpansion(
        F.getParent(), F.getFunctionType()->getParamType(3), /*IsRaw*/ true);
  case Intrinsic::dx_resource_store_typedbuffer:
    return resourceAccessNeeds64BitExpansion(
        F.getParent(), F.getFunctionType()->getParamType(2), /*IsRaw*/ false);
  }
  return false;
}
100
expandUsubSat(CallInst * Orig)101 static Value *expandUsubSat(CallInst *Orig) {
102 Value *A = Orig->getArgOperand(0);
103 Value *B = Orig->getArgOperand(1);
104 Type *Ty = A->getType();
105
106 IRBuilder<> Builder(Orig);
107
108 Value *Cmp = Builder.CreateICmpULT(A, B, "usub.cmp");
109 Value *Sub = Builder.CreateSub(A, B, "usub.sub");
110 Value *Zero = ConstantInt::get(Ty, 0);
111 return Builder.CreateSelect(Cmp, Zero, Sub, "usub.sat");
112 }
113
expandVecReduceAdd(CallInst * Orig,Intrinsic::ID IntrinsicId)114 static Value *expandVecReduceAdd(CallInst *Orig, Intrinsic::ID IntrinsicId) {
115 assert(IntrinsicId == Intrinsic::vector_reduce_add ||
116 IntrinsicId == Intrinsic::vector_reduce_fadd);
117
118 IRBuilder<> Builder(Orig);
119 bool IsFAdd = (IntrinsicId == Intrinsic::vector_reduce_fadd);
120
121 Value *X = Orig->getOperand(IsFAdd ? 1 : 0);
122 Type *Ty = X->getType();
123 auto *XVec = dyn_cast<FixedVectorType>(Ty);
124 unsigned XVecSize = XVec->getNumElements();
125 Value *Sum = Builder.CreateExtractElement(X, static_cast<uint64_t>(0));
126
127 // Handle the initial start value for floating-point addition.
128 if (IsFAdd) {
129 Constant *StartValue = dyn_cast<Constant>(Orig->getOperand(0));
130 if (StartValue && !StartValue->isZeroValue())
131 Sum = Builder.CreateFAdd(Sum, StartValue);
132 }
133
134 // Accumulate the remaining vector elements.
135 for (unsigned I = 1; I < XVecSize; I++) {
136 Value *Elt = Builder.CreateExtractElement(X, I);
137 if (IsFAdd)
138 Sum = Builder.CreateFAdd(Sum, Elt);
139 else
140 Sum = Builder.CreateAdd(Sum, Elt);
141 }
142
143 return Sum;
144 }
145
expandAbs(CallInst * Orig)146 static Value *expandAbs(CallInst *Orig) {
147 Value *X = Orig->getOperand(0);
148 IRBuilder<> Builder(Orig);
149 Type *Ty = X->getType();
150 Type *EltTy = Ty->getScalarType();
151 Constant *Zero = Ty->isVectorTy()
152 ? ConstantVector::getSplat(
153 ElementCount::getFixed(
154 cast<FixedVectorType>(Ty)->getNumElements()),
155 ConstantInt::get(EltTy, 0))
156 : ConstantInt::get(EltTy, 0);
157 auto *V = Builder.CreateSub(Zero, X);
158 return Builder.CreateIntrinsic(Ty, Intrinsic::smax, {X, V}, nullptr,
159 "dx.max");
160 }
161
expandCrossIntrinsic(CallInst * Orig)162 static Value *expandCrossIntrinsic(CallInst *Orig) {
163
164 VectorType *VT = cast<VectorType>(Orig->getType());
165 if (cast<FixedVectorType>(VT)->getNumElements() != 3)
166 reportFatalUsageError("return vector must have exactly 3 elements");
167
168 Value *op0 = Orig->getOperand(0);
169 Value *op1 = Orig->getOperand(1);
170 IRBuilder<> Builder(Orig);
171
172 Value *op0_x = Builder.CreateExtractElement(op0, (uint64_t)0, "x0");
173 Value *op0_y = Builder.CreateExtractElement(op0, 1, "x1");
174 Value *op0_z = Builder.CreateExtractElement(op0, 2, "x2");
175
176 Value *op1_x = Builder.CreateExtractElement(op1, (uint64_t)0, "y0");
177 Value *op1_y = Builder.CreateExtractElement(op1, 1, "y1");
178 Value *op1_z = Builder.CreateExtractElement(op1, 2, "y2");
179
180 auto MulSub = [&](Value *x0, Value *y0, Value *x1, Value *y1) -> Value * {
181 Value *xy = Builder.CreateFMul(x0, y1);
182 Value *yx = Builder.CreateFMul(y0, x1);
183 return Builder.CreateFSub(xy, yx, Orig->getName());
184 };
185
186 Value *yz_zy = MulSub(op0_y, op0_z, op1_y, op1_z);
187 Value *zx_xz = MulSub(op0_z, op0_x, op1_z, op1_x);
188 Value *xy_yx = MulSub(op0_x, op0_y, op1_x, op1_y);
189
190 Value *cross = PoisonValue::get(VT);
191 cross = Builder.CreateInsertElement(cross, yz_zy, (uint64_t)0);
192 cross = Builder.CreateInsertElement(cross, zx_xz, 1);
193 cross = Builder.CreateInsertElement(cross, xy_yx, 2);
194 return cross;
195 }
196
197 // Create appropriate DXIL float dot intrinsic for the given A and B operands
198 // The appropriate opcode will be determined by the size of the operands
199 // The dot product is placed in the position indicated by Orig
expandFloatDotIntrinsic(CallInst * Orig,Value * A,Value * B)200 static Value *expandFloatDotIntrinsic(CallInst *Orig, Value *A, Value *B) {
201 Type *ATy = A->getType();
202 [[maybe_unused]] Type *BTy = B->getType();
203 assert(ATy->isVectorTy() && BTy->isVectorTy());
204
205 IRBuilder<> Builder(Orig);
206
207 auto *AVec = dyn_cast<FixedVectorType>(ATy);
208
209 assert(ATy->getScalarType()->isFloatingPointTy());
210
211 Intrinsic::ID DotIntrinsic = Intrinsic::dx_dot4;
212 int NumElts = AVec->getNumElements();
213 switch (NumElts) {
214 case 2:
215 DotIntrinsic = Intrinsic::dx_dot2;
216 break;
217 case 3:
218 DotIntrinsic = Intrinsic::dx_dot3;
219 break;
220 case 4:
221 DotIntrinsic = Intrinsic::dx_dot4;
222 break;
223 default:
224 reportFatalUsageError(
225 "Invalid dot product input vector: length is outside 2-4");
226 return nullptr;
227 }
228
229 SmallVector<Value *> Args;
230 for (int I = 0; I < NumElts; ++I)
231 Args.push_back(Builder.CreateExtractElement(A, Builder.getInt32(I)));
232 for (int I = 0; I < NumElts; ++I)
233 Args.push_back(Builder.CreateExtractElement(B, Builder.getInt32(I)));
234 return Builder.CreateIntrinsic(ATy->getScalarType(), DotIntrinsic, Args,
235 nullptr, "dot");
236 }
237
238 // Create the appropriate DXIL float dot intrinsic for the operands of Orig
239 // The appropriate opcode will be determined by the size of the operands
240 // The dot product is placed in the position indicated by Orig
expandFloatDotIntrinsic(CallInst * Orig)241 static Value *expandFloatDotIntrinsic(CallInst *Orig) {
242 return expandFloatDotIntrinsic(Orig, Orig->getOperand(0),
243 Orig->getOperand(1));
244 }
245
246 // Expand integer dot product to multiply and add ops
expandIntegerDotIntrinsic(CallInst * Orig,Intrinsic::ID DotIntrinsic)247 static Value *expandIntegerDotIntrinsic(CallInst *Orig,
248 Intrinsic::ID DotIntrinsic) {
249 assert(DotIntrinsic == Intrinsic::dx_sdot ||
250 DotIntrinsic == Intrinsic::dx_udot);
251 Value *A = Orig->getOperand(0);
252 Value *B = Orig->getOperand(1);
253 Type *ATy = A->getType();
254 [[maybe_unused]] Type *BTy = B->getType();
255 assert(ATy->isVectorTy() && BTy->isVectorTy());
256
257 IRBuilder<> Builder(Orig);
258
259 auto *AVec = dyn_cast<FixedVectorType>(ATy);
260
261 assert(ATy->getScalarType()->isIntegerTy());
262
263 Value *Result;
264 Intrinsic::ID MadIntrinsic = DotIntrinsic == Intrinsic::dx_sdot
265 ? Intrinsic::dx_imad
266 : Intrinsic::dx_umad;
267 Value *Elt0 = Builder.CreateExtractElement(A, (uint64_t)0);
268 Value *Elt1 = Builder.CreateExtractElement(B, (uint64_t)0);
269 Result = Builder.CreateMul(Elt0, Elt1);
270 for (unsigned I = 1; I < AVec->getNumElements(); I++) {
271 Elt0 = Builder.CreateExtractElement(A, I);
272 Elt1 = Builder.CreateExtractElement(B, I);
273 Result = Builder.CreateIntrinsic(Result->getType(), MadIntrinsic,
274 ArrayRef<Value *>{Elt0, Elt1, Result},
275 nullptr, "dx.mad");
276 }
277 return Result;
278 }
279
expandExpIntrinsic(CallInst * Orig)280 static Value *expandExpIntrinsic(CallInst *Orig) {
281 Value *X = Orig->getOperand(0);
282 IRBuilder<> Builder(Orig);
283 Type *Ty = X->getType();
284 Type *EltTy = Ty->getScalarType();
285 Constant *Log2eConst =
286 Ty->isVectorTy() ? ConstantVector::getSplat(
287 ElementCount::getFixed(
288 cast<FixedVectorType>(Ty)->getNumElements()),
289 ConstantFP::get(EltTy, numbers::log2ef))
290 : ConstantFP::get(EltTy, numbers::log2ef);
291 Value *NewX = Builder.CreateFMul(Log2eConst, X);
292 auto *Exp2Call =
293 Builder.CreateIntrinsic(Ty, Intrinsic::exp2, {NewX}, nullptr, "dx.exp2");
294 Exp2Call->setTailCall(Orig->isTailCall());
295 Exp2Call->setAttributes(Orig->getAttributes());
296 return Exp2Call;
297 }
298
expandIsFPClass(CallInst * Orig)299 static Value *expandIsFPClass(CallInst *Orig) {
300 Value *T = Orig->getArgOperand(1);
301 auto *TCI = dyn_cast<ConstantInt>(T);
302
303 // These FPClassTest cases have DXIL opcodes, so they will be handled in
304 // DXIL Op Lowering instead.
305 switch (TCI->getZExtValue()) {
306 case FPClassTest::fcInf:
307 case FPClassTest::fcNan:
308 case FPClassTest::fcNormal:
309 case FPClassTest::fcFinite:
310 return nullptr;
311 }
312
313 IRBuilder<> Builder(Orig);
314
315 Value *F = Orig->getArgOperand(0);
316 Type *FTy = F->getType();
317 unsigned FNumElem = 0; // 0 => F is not a vector
318
319 unsigned BitWidth; // Bit width of F or the ElemTy of F
320 Type *BitCastTy; // An IntNTy of the same bitwidth as F or ElemTy of F
321
322 if (auto *FVecTy = dyn_cast<FixedVectorType>(FTy)) {
323 Type *ElemTy = FVecTy->getElementType();
324 FNumElem = FVecTy->getNumElements();
325 BitWidth = ElemTy->getPrimitiveSizeInBits();
326 BitCastTy = FixedVectorType::get(Builder.getIntNTy(BitWidth), FNumElem);
327 } else {
328 BitWidth = FTy->getPrimitiveSizeInBits();
329 BitCastTy = Builder.getIntNTy(BitWidth);
330 }
331
332 Value *FBitCast = Builder.CreateBitCast(F, BitCastTy);
333 switch (TCI->getZExtValue()) {
334 case FPClassTest::fcNegZero: {
335 Value *NegZero =
336 ConstantInt::get(Builder.getIntNTy(BitWidth), 1 << (BitWidth - 1));
337 Value *RetVal;
338 if (FNumElem) {
339 Value *NegZeroSplat = Builder.CreateVectorSplat(FNumElem, NegZero);
340 RetVal =
341 Builder.CreateICmpEQ(FBitCast, NegZeroSplat, "is.fpclass.negzero");
342 } else
343 RetVal = Builder.CreateICmpEQ(FBitCast, NegZero, "is.fpclass.negzero");
344 return RetVal;
345 }
346 default:
347 reportFatalUsageError("Unsupported FPClassTest");
348 }
349 }
350
expandAnyOrAllIntrinsic(CallInst * Orig,Intrinsic::ID IntrinsicId)351 static Value *expandAnyOrAllIntrinsic(CallInst *Orig,
352 Intrinsic::ID IntrinsicId) {
353 Value *X = Orig->getOperand(0);
354 IRBuilder<> Builder(Orig);
355 Type *Ty = X->getType();
356 Type *EltTy = Ty->getScalarType();
357
358 auto ApplyOp = [&Builder](Intrinsic::ID IntrinsicId, Value *Result,
359 Value *Elt) {
360 if (IntrinsicId == Intrinsic::dx_any)
361 return Builder.CreateOr(Result, Elt);
362 assert(IntrinsicId == Intrinsic::dx_all);
363 return Builder.CreateAnd(Result, Elt);
364 };
365
366 Value *Result = nullptr;
367 if (!Ty->isVectorTy()) {
368 Result = EltTy->isFloatingPointTy()
369 ? Builder.CreateFCmpUNE(X, ConstantFP::get(EltTy, 0))
370 : Builder.CreateICmpNE(X, ConstantInt::get(EltTy, 0));
371 } else {
372 auto *XVec = dyn_cast<FixedVectorType>(Ty);
373 Value *Cond =
374 EltTy->isFloatingPointTy()
375 ? Builder.CreateFCmpUNE(
376 X, ConstantVector::getSplat(
377 ElementCount::getFixed(XVec->getNumElements()),
378 ConstantFP::get(EltTy, 0)))
379 : Builder.CreateICmpNE(
380 X, ConstantVector::getSplat(
381 ElementCount::getFixed(XVec->getNumElements()),
382 ConstantInt::get(EltTy, 0)));
383 Result = Builder.CreateExtractElement(Cond, (uint64_t)0);
384 for (unsigned I = 1; I < XVec->getNumElements(); I++) {
385 Value *Elt = Builder.CreateExtractElement(Cond, I);
386 Result = ApplyOp(IntrinsicId, Result, Elt);
387 }
388 }
389 return Result;
390 }
391
expandLerpIntrinsic(CallInst * Orig)392 static Value *expandLerpIntrinsic(CallInst *Orig) {
393 Value *X = Orig->getOperand(0);
394 Value *Y = Orig->getOperand(1);
395 Value *S = Orig->getOperand(2);
396 IRBuilder<> Builder(Orig);
397 auto *V = Builder.CreateFSub(Y, X);
398 V = Builder.CreateFMul(S, V);
399 return Builder.CreateFAdd(X, V, "dx.lerp");
400 }
401
expandLogIntrinsic(CallInst * Orig,float LogConstVal=numbers::ln2f)402 static Value *expandLogIntrinsic(CallInst *Orig,
403 float LogConstVal = numbers::ln2f) {
404 Value *X = Orig->getOperand(0);
405 IRBuilder<> Builder(Orig);
406 Type *Ty = X->getType();
407 Type *EltTy = Ty->getScalarType();
408 Constant *Ln2Const =
409 Ty->isVectorTy() ? ConstantVector::getSplat(
410 ElementCount::getFixed(
411 cast<FixedVectorType>(Ty)->getNumElements()),
412 ConstantFP::get(EltTy, LogConstVal))
413 : ConstantFP::get(EltTy, LogConstVal);
414 auto *Log2Call =
415 Builder.CreateIntrinsic(Ty, Intrinsic::log2, {X}, nullptr, "elt.log2");
416 Log2Call->setTailCall(Orig->isTailCall());
417 Log2Call->setAttributes(Orig->getAttributes());
418 return Builder.CreateFMul(Ln2Const, Log2Call);
419 }
expandLog10Intrinsic(CallInst * Orig)420 static Value *expandLog10Intrinsic(CallInst *Orig) {
421 return expandLogIntrinsic(Orig, numbers::ln2f / numbers::ln10f);
422 }
423
424 // Use dot product of vector operand with itself to calculate the length.
425 // Divide the vector by that length to normalize it.
expandNormalizeIntrinsic(CallInst * Orig)426 static Value *expandNormalizeIntrinsic(CallInst *Orig) {
427 Value *X = Orig->getOperand(0);
428 Type *Ty = Orig->getType();
429 Type *EltTy = Ty->getScalarType();
430 IRBuilder<> Builder(Orig);
431
432 auto *XVec = dyn_cast<FixedVectorType>(Ty);
433 if (!XVec) {
434 if (auto *constantFP = dyn_cast<ConstantFP>(X)) {
435 const APFloat &fpVal = constantFP->getValueAPF();
436 if (fpVal.isZero())
437 reportFatalUsageError("Invalid input scalar: length is zero");
438 }
439 return Builder.CreateFDiv(X, X);
440 }
441
442 Value *DotProduct = expandFloatDotIntrinsic(Orig, X, X);
443
444 // verify that the length is non-zero
445 // (if the dot product is non-zero, then the length is non-zero)
446 if (auto *constantFP = dyn_cast<ConstantFP>(DotProduct)) {
447 const APFloat &fpVal = constantFP->getValueAPF();
448 if (fpVal.isZero())
449 reportFatalUsageError("Invalid input vector: length is zero");
450 }
451
452 Value *Multiplicand = Builder.CreateIntrinsic(EltTy, Intrinsic::dx_rsqrt,
453 ArrayRef<Value *>{DotProduct},
454 nullptr, "dx.rsqrt");
455
456 Value *MultiplicandVec =
457 Builder.CreateVectorSplat(XVec->getNumElements(), Multiplicand);
458 return Builder.CreateFMul(X, MultiplicandVec);
459 }
460
/// Expands llvm.atan2(Y, X) as atan(Y / X) followed by quadrant corrections,
/// applied as a chain of selects over disjoint conditions.
///
/// Orig's fast-math flags are set on the builder before any instruction is
/// emitted. The atan call copies Orig's tail-call marker and attribute list.
/// NOTE(review): for x == 0 and y == 0 the final select yields pi/2 (the
/// condition uses y >= 0), not IEEE atan2(+/-0, +0) = +/-0; confirm this is
/// the intended (DXC-compatible) behavior.
static Value *expandAtan2Intrinsic(CallInst *Orig) {
  Value *Y = Orig->getOperand(0);
  Value *X = Orig->getOperand(1);
  Type *Ty = X->getType();
  IRBuilder<> Builder(Orig);
  Builder.setFastMathFlags(Orig->getFastMathFlags());

  Value *Tan = Builder.CreateFDiv(Y, X);

  CallInst *Atan =
      Builder.CreateIntrinsic(Ty, Intrinsic::atan, {Tan}, nullptr, "Elt.Atan");
  Atan->setTailCall(Orig->isTailCall());
  Atan->setAttributes(Orig->getAttributes());

  // Modify atan result based on https://en.wikipedia.org/wiki/Atan2.
  Constant *Pi = ConstantFP::get(Ty, llvm::numbers::pi);
  Constant *HalfPi = ConstantFP::get(Ty, llvm::numbers::pi / 2);
  Constant *NegHalfPi = ConstantFP::get(Ty, -llvm::numbers::pi / 2);
  Constant *Zero = ConstantFP::get(Ty, 0);
  Value *AtanAddPi = Builder.CreateFAdd(Atan, Pi);
  Value *AtanSubPi = Builder.CreateFSub(Atan, Pi);

  // x > 0 -> atan.
  Value *Result = Atan;
  // Ordered compares: a NaN operand makes every condition false, so the
  // result stays atan(Y / X).
  Value *XLt0 = Builder.CreateFCmpOLT(X, Zero);
  Value *XEq0 = Builder.CreateFCmpOEQ(X, Zero);
  Value *YGe0 = Builder.CreateFCmpOGE(Y, Zero);
  Value *YLt0 = Builder.CreateFCmpOLT(Y, Zero);

  // x < 0, y >= 0 -> atan + pi.
  Value *XLt0AndYGe0 = Builder.CreateAnd(XLt0, YGe0);
  Result = Builder.CreateSelect(XLt0AndYGe0, AtanAddPi, Result);

  // x < 0, y < 0 -> atan - pi.
  Value *XLt0AndYLt0 = Builder.CreateAnd(XLt0, YLt0);
  Result = Builder.CreateSelect(XLt0AndYLt0, AtanSubPi, Result);

  // x == 0, y < 0 -> -pi/2
  Value *XEq0AndYLt0 = Builder.CreateAnd(XEq0, YLt0);
  Result = Builder.CreateSelect(XEq0AndYLt0, NegHalfPi, Result);

  // x == 0, y >= 0 -> pi/2 (this includes y == 0)
  Value *XEq0AndYGe0 = Builder.CreateAnd(XEq0, YGe0);
  Result = Builder.CreateSelect(XEq0AndYGe0, HalfPi, Result);

  return Result;
}
508
expandPowIntrinsic(CallInst * Orig,Intrinsic::ID IntrinsicId)509 static Value *expandPowIntrinsic(CallInst *Orig, Intrinsic::ID IntrinsicId) {
510
511 Value *X = Orig->getOperand(0);
512 Value *Y = Orig->getOperand(1);
513 Type *Ty = X->getType();
514 IRBuilder<> Builder(Orig);
515
516 if (IntrinsicId == Intrinsic::powi)
517 Y = Builder.CreateSIToFP(Y, Ty);
518
519 auto *Log2Call =
520 Builder.CreateIntrinsic(Ty, Intrinsic::log2, {X}, nullptr, "elt.log2");
521 auto *Mul = Builder.CreateFMul(Log2Call, Y);
522 auto *Exp2Call =
523 Builder.CreateIntrinsic(Ty, Intrinsic::exp2, {Mul}, nullptr, "elt.exp2");
524 Exp2Call->setTailCall(Orig->isTailCall());
525 Exp2Call->setAttributes(Orig->getAttributes());
526 return Exp2Call;
527 }
528
expandStepIntrinsic(CallInst * Orig)529 static Value *expandStepIntrinsic(CallInst *Orig) {
530
531 Value *X = Orig->getOperand(0);
532 Value *Y = Orig->getOperand(1);
533 Type *Ty = X->getType();
534 IRBuilder<> Builder(Orig);
535
536 Constant *One = ConstantFP::get(Ty->getScalarType(), 1.0);
537 Constant *Zero = ConstantFP::get(Ty->getScalarType(), 0.0);
538 Value *Cond = Builder.CreateFCmpOLT(Y, X);
539
540 if (Ty != Ty->getScalarType()) {
541 auto *XVec = dyn_cast<FixedVectorType>(Ty);
542 One = ConstantVector::getSplat(
543 ElementCount::getFixed(XVec->getNumElements()), One);
544 Zero = ConstantVector::getSplat(
545 ElementCount::getFixed(XVec->getNumElements()), Zero);
546 }
547
548 return Builder.CreateSelect(Cond, Zero, One);
549 }
550
expandRadiansIntrinsic(CallInst * Orig)551 static Value *expandRadiansIntrinsic(CallInst *Orig) {
552 Value *X = Orig->getOperand(0);
553 Type *Ty = X->getType();
554 IRBuilder<> Builder(Orig);
555 Value *PiOver180 = ConstantFP::get(Ty, llvm::numbers::pi / 180.0);
556 return Builder.CreateFMul(X, PiOver180);
557 }
558
expandBufferLoadIntrinsic(CallInst * Orig,bool IsRaw)559 static bool expandBufferLoadIntrinsic(CallInst *Orig, bool IsRaw) {
560 IRBuilder<> Builder(Orig);
561
562 Type *BufferTy = Orig->getType()->getStructElementType(0);
563 Type *ScalarTy = BufferTy->getScalarType();
564 bool IsDouble = ScalarTy->isDoubleTy();
565 assert(IsDouble || ScalarTy->isIntegerTy(64) &&
566 "Only expand double or int64 scalars or vectors");
567 bool IsVector = false;
568 unsigned ExtractNum = 2;
569 if (auto *VT = dyn_cast<FixedVectorType>(BufferTy)) {
570 ExtractNum = 2 * VT->getNumElements();
571 IsVector = true;
572 assert(IsRaw || ExtractNum == 4 && "TypedBufferLoad vector must be size 2");
573 }
574
575 SmallVector<Value *, 2> Loads;
576 Value *Result = PoisonValue::get(BufferTy);
577 unsigned Base = 0;
578 // If we need to extract more than 4 i32; we need to break it up into
579 // more than one load. LoadNum tells us how many i32s we are loading in
580 // each load
581 while (ExtractNum > 0) {
582 unsigned LoadNum = std::min(ExtractNum, 4u);
583 Type *Ty = VectorType::get(Builder.getInt32Ty(), LoadNum, false);
584
585 Type *LoadType = StructType::get(Ty, Builder.getInt1Ty());
586 Intrinsic::ID LoadIntrinsic = Intrinsic::dx_resource_load_typedbuffer;
587 SmallVector<Value *, 3> Args = {Orig->getOperand(0), Orig->getOperand(1)};
588 if (IsRaw) {
589 LoadIntrinsic = Intrinsic::dx_resource_load_rawbuffer;
590 Value *Tmp = Builder.getInt32(4 * Base * 2);
591 Args.push_back(Builder.CreateAdd(Orig->getOperand(2), Tmp));
592 }
593
594 CallInst *Load = Builder.CreateIntrinsic(LoadType, LoadIntrinsic, Args);
595 Loads.push_back(Load);
596
597 // extract the buffer load's result
598 Value *Extract = Builder.CreateExtractValue(Load, {0});
599
600 SmallVector<Value *> ExtractElements;
601 for (unsigned I = 0; I < LoadNum; ++I)
602 ExtractElements.push_back(
603 Builder.CreateExtractElement(Extract, Builder.getInt32(I)));
604
605 // combine into double(s) or int64(s)
606 for (unsigned I = 0; I < LoadNum; I += 2) {
607 Value *Combined = nullptr;
608 if (IsDouble)
609 // For doubles, use dx_asdouble intrinsic
610 Combined = Builder.CreateIntrinsic(
611 Builder.getDoubleTy(), Intrinsic::dx_asdouble,
612 {ExtractElements[I], ExtractElements[I + 1]});
613 else {
614 // For int64, manually combine two int32s
615 // First, zero-extend both values to i64
616 Value *Lo =
617 Builder.CreateZExt(ExtractElements[I], Builder.getInt64Ty());
618 Value *Hi =
619 Builder.CreateZExt(ExtractElements[I + 1], Builder.getInt64Ty());
620 // Shift the high bits left by 32 bits
621 Value *ShiftedHi = Builder.CreateShl(Hi, Builder.getInt64(32));
622 // OR the high and low bits together
623 Combined = Builder.CreateOr(Lo, ShiftedHi);
624 }
625
626 if (IsVector)
627 Result = Builder.CreateInsertElement(Result, Combined,
628 Builder.getInt32((I / 2) + Base));
629 else
630 Result = Combined;
631 }
632
633 ExtractNum -= LoadNum;
634 Base += LoadNum / 2;
635 }
636
637 Value *CheckBit = nullptr;
638 for (User *U : make_early_inc_range(Orig->users())) {
639 // If it's not a ExtractValueInst, we don't know how to
640 // handle it
641 auto *EVI = dyn_cast<ExtractValueInst>(U);
642 if (!EVI)
643 llvm_unreachable("Unexpected user of typedbufferload");
644
645 ArrayRef<unsigned> Indices = EVI->getIndices();
646 assert(Indices.size() == 1);
647
648 if (Indices[0] == 0) {
649 // Use of the value(s)
650 EVI->replaceAllUsesWith(Result);
651 } else {
652 // Use of the check bit
653 assert(Indices[0] == 1 && "Unexpected type for typedbufferload");
654 // Note: This does not always match the historical behaviour of DXC.
655 // See https://github.com/microsoft/DirectXShaderCompiler/issues/7622
656 if (!CheckBit) {
657 SmallVector<Value *, 2> CheckBits;
658 for (Value *L : Loads)
659 CheckBits.push_back(Builder.CreateExtractValue(L, {1}));
660 CheckBit = Builder.CreateAnd(CheckBits);
661 }
662 EVI->replaceAllUsesWith(CheckBit);
663 }
664 EVI->eraseFromParent();
665 }
666 Orig->eraseFromParent();
667 return true;
668 }
669
expandBufferStoreIntrinsic(CallInst * Orig,bool IsRaw)670 static bool expandBufferStoreIntrinsic(CallInst *Orig, bool IsRaw) {
671 IRBuilder<> Builder(Orig);
672
673 unsigned ValIndex = IsRaw ? 3 : 2;
674 Type *BufferTy = Orig->getFunctionType()->getParamType(ValIndex);
675 Type *ScalarTy = BufferTy->getScalarType();
676 bool IsDouble = ScalarTy->isDoubleTy();
677 assert((IsDouble || ScalarTy->isIntegerTy(64)) &&
678 "Only expand double or int64 scalars or vectors");
679
680 // Determine if we're dealing with a vector or scalar
681 bool IsVector = false;
682 unsigned ExtractNum = 2;
683 unsigned VecLen = 0;
684 if (auto *VT = dyn_cast<FixedVectorType>(BufferTy)) {
685 VecLen = VT->getNumElements();
686 assert(IsRaw || VecLen == 2 && "TypedBufferStore vector must be size 2");
687 ExtractNum = VecLen * 2;
688 IsVector = true;
689 }
690
691 // Create the appropriate vector type for the result
692 Type *Int32Ty = Builder.getInt32Ty();
693 Type *ResultTy = VectorType::get(Int32Ty, ExtractNum, false);
694 Value *Val = PoisonValue::get(ResultTy);
695
696 Type *SplitElementTy = Int32Ty;
697 if (IsVector)
698 SplitElementTy = VectorType::get(SplitElementTy, VecLen, false);
699
700 Value *LowBits = nullptr;
701 Value *HighBits = nullptr;
702 // Split the 64-bit values into 32-bit components
703 if (IsDouble) {
704 auto *SplitTy = llvm::StructType::get(SplitElementTy, SplitElementTy);
705 Value *Split = Builder.CreateIntrinsic(SplitTy, Intrinsic::dx_splitdouble,
706 {Orig->getOperand(ValIndex)});
707 LowBits = Builder.CreateExtractValue(Split, 0);
708 HighBits = Builder.CreateExtractValue(Split, 1);
709 } else {
710 // Handle int64 type(s)
711 Value *InputVal = Orig->getOperand(ValIndex);
712 Constant *ShiftAmt = Builder.getInt64(32);
713 if (IsVector)
714 ShiftAmt =
715 ConstantVector::getSplat(ElementCount::getFixed(VecLen), ShiftAmt);
716
717 // Split into low and high 32-bit parts
718 LowBits = Builder.CreateTrunc(InputVal, SplitElementTy);
719 Value *ShiftedVal = Builder.CreateLShr(InputVal, ShiftAmt);
720 HighBits = Builder.CreateTrunc(ShiftedVal, SplitElementTy);
721 }
722
723 if (IsVector) {
724 SmallVector<int, 8> Mask;
725 for (unsigned I = 0; I < VecLen; ++I) {
726 Mask.push_back(I);
727 Mask.push_back(I + VecLen);
728 }
729 Val = Builder.CreateShuffleVector(LowBits, HighBits, Mask);
730 } else {
731 Val = Builder.CreateInsertElement(Val, LowBits, Builder.getInt32(0));
732 Val = Builder.CreateInsertElement(Val, HighBits, Builder.getInt32(1));
733 }
734
735 // If we need to extract more than 4 i32; we need to break it up into
736 // more than one store. StoreNum tells us how many i32s we are storing in
737 // each store
738 unsigned Base = 0;
739 while (ExtractNum > 0) {
740 unsigned StoreNum = std::min(ExtractNum, 4u);
741
742 Intrinsic::ID StoreIntrinsic = Intrinsic::dx_resource_store_typedbuffer;
743 SmallVector<Value *, 4> Args = {Orig->getOperand(0), Orig->getOperand(1)};
744 if (IsRaw) {
745 StoreIntrinsic = Intrinsic::dx_resource_store_rawbuffer;
746 Value *Tmp = Builder.getInt32(4 * Base);
747 Args.push_back(Builder.CreateAdd(Orig->getOperand(2), Tmp));
748 }
749
750 SmallVector<int, 4> Mask;
751 for (unsigned I = 0; I < StoreNum; ++I) {
752 Mask.push_back(Base + I);
753 }
754
755 Value *SubVal = Val;
756 if (VecLen > 2)
757 SubVal = Builder.CreateShuffleVector(Val, Mask);
758
759 Args.push_back(SubVal);
760 // Create the final intrinsic call
761 Builder.CreateIntrinsic(Builder.getVoidTy(), StoreIntrinsic, Args);
762
763 ExtractNum -= StoreNum;
764 Base += StoreNum;
765 }
766 Orig->eraseFromParent();
767 return true;
768 }
769
getMaxForClamp(Intrinsic::ID ClampIntrinsic)770 static Intrinsic::ID getMaxForClamp(Intrinsic::ID ClampIntrinsic) {
771 if (ClampIntrinsic == Intrinsic::dx_uclamp)
772 return Intrinsic::umax;
773 if (ClampIntrinsic == Intrinsic::dx_sclamp)
774 return Intrinsic::smax;
775 assert(ClampIntrinsic == Intrinsic::dx_nclamp);
776 return Intrinsic::maxnum;
777 }
778
getMinForClamp(Intrinsic::ID ClampIntrinsic)779 static Intrinsic::ID getMinForClamp(Intrinsic::ID ClampIntrinsic) {
780 if (ClampIntrinsic == Intrinsic::dx_uclamp)
781 return Intrinsic::umin;
782 if (ClampIntrinsic == Intrinsic::dx_sclamp)
783 return Intrinsic::smin;
784 assert(ClampIntrinsic == Intrinsic::dx_nclamp);
785 return Intrinsic::minnum;
786 }
787
expandClampIntrinsic(CallInst * Orig,Intrinsic::ID ClampIntrinsic)788 static Value *expandClampIntrinsic(CallInst *Orig,
789 Intrinsic::ID ClampIntrinsic) {
790 Value *X = Orig->getOperand(0);
791 Value *Min = Orig->getOperand(1);
792 Value *Max = Orig->getOperand(2);
793 Type *Ty = X->getType();
794 IRBuilder<> Builder(Orig);
795 auto *MaxCall = Builder.CreateIntrinsic(Ty, getMaxForClamp(ClampIntrinsic),
796 {X, Min}, nullptr, "dx.max");
797 return Builder.CreateIntrinsic(Ty, getMinForClamp(ClampIntrinsic),
798 {MaxCall, Max}, nullptr, "dx.min");
799 }
800
expandDegreesIntrinsic(CallInst * Orig)801 static Value *expandDegreesIntrinsic(CallInst *Orig) {
802 Value *X = Orig->getOperand(0);
803 Type *Ty = X->getType();
804 IRBuilder<> Builder(Orig);
805 Value *DegreesRatio = ConstantFP::get(Ty, 180.0 * llvm::numbers::inv_pi);
806 return Builder.CreateFMul(X, DegreesRatio);
807 }
808
expandSignIntrinsic(CallInst * Orig)809 static Value *expandSignIntrinsic(CallInst *Orig) {
810 Value *X = Orig->getOperand(0);
811 Type *Ty = X->getType();
812 Type *ScalarTy = Ty->getScalarType();
813 Type *RetTy = Orig->getType();
814 Constant *Zero = Constant::getNullValue(Ty);
815
816 IRBuilder<> Builder(Orig);
817
818 Value *GT;
819 Value *LT;
820 if (ScalarTy->isFloatingPointTy()) {
821 GT = Builder.CreateFCmpOLT(Zero, X);
822 LT = Builder.CreateFCmpOLT(X, Zero);
823 } else {
824 assert(ScalarTy->isIntegerTy());
825 GT = Builder.CreateICmpSLT(Zero, X);
826 LT = Builder.CreateICmpSLT(X, Zero);
827 }
828
829 Value *ZextGT = Builder.CreateZExt(GT, RetTy);
830 Value *ZextLT = Builder.CreateZExt(LT, RetTy);
831
832 return Builder.CreateSub(ZextGT, ZextLT);
833 }
834
/// Expands one call \p Orig to intrinsic \p F in place.
///
/// Two kinds of expansion are dispatched here:
///   - Value-producing expansions return a replacement value. Orig's uses are
///     redirected to it and Orig is erased.
///   - Buffer load/store expansions rewrite and erase Orig themselves.
/// Returns true if Orig was expanded (and erased). Returns false if Orig was
/// left untouched, e.g. an is.fpclass test handled later by DXIL op lowering,
/// or an intrinsic with no case below.
static bool expandIntrinsic(Function &F, CallInst *Orig) {
  Value *Result = nullptr;
  Intrinsic::ID IntrinsicId = F.getIntrinsicID();
  switch (IntrinsicId) {
  case Intrinsic::abs:
    Result = expandAbs(Orig);
    break;
  case Intrinsic::atan2:
    Result = expandAtan2Intrinsic(Orig);
    break;
  case Intrinsic::exp:
    Result = expandExpIntrinsic(Orig);
    break;
  case Intrinsic::is_fpclass:
    // May return nullptr for classes that have a DXIL opcode.
    Result = expandIsFPClass(Orig);
    break;
  case Intrinsic::log:
    Result = expandLogIntrinsic(Orig);
    break;
  case Intrinsic::log10:
    Result = expandLog10Intrinsic(Orig);
    break;
  case Intrinsic::pow:
  case Intrinsic::powi:
    Result = expandPowIntrinsic(Orig, IntrinsicId);
    break;
  case Intrinsic::dx_all:
  case Intrinsic::dx_any:
    Result = expandAnyOrAllIntrinsic(Orig, IntrinsicId);
    break;
  case Intrinsic::dx_cross:
    Result = expandCrossIntrinsic(Orig);
    break;
  case Intrinsic::dx_uclamp:
  case Intrinsic::dx_sclamp:
  case Intrinsic::dx_nclamp:
    Result = expandClampIntrinsic(Orig, IntrinsicId);
    break;
  case Intrinsic::dx_degrees:
    Result = expandDegreesIntrinsic(Orig);
    break;
  case Intrinsic::dx_lerp:
    Result = expandLerpIntrinsic(Orig);
    break;
  case Intrinsic::dx_normalize:
    Result = expandNormalizeIntrinsic(Orig);
    break;
  case Intrinsic::dx_fdot:
    Result = expandFloatDotIntrinsic(Orig);
    break;
  case Intrinsic::dx_sdot:
  case Intrinsic::dx_udot:
    Result = expandIntegerDotIntrinsic(Orig, IntrinsicId);
    break;
  case Intrinsic::dx_sign:
    Result = expandSignIntrinsic(Orig);
    break;
  case Intrinsic::dx_step:
    Result = expandStepIntrinsic(Orig);
    break;
  case Intrinsic::dx_radians:
    Result = expandRadiansIntrinsic(Orig);
    break;
  // The buffer expansions erase Orig themselves, so they return directly
  // instead of going through the replace-and-erase path below.
  case Intrinsic::dx_resource_load_rawbuffer:
    if (expandBufferLoadIntrinsic(Orig, /*IsRaw*/ true))
      return true;
    break;
  case Intrinsic::dx_resource_store_rawbuffer:
    if (expandBufferStoreIntrinsic(Orig, /*IsRaw*/ true))
      return true;
    break;
  case Intrinsic::dx_resource_load_typedbuffer:
    if (expandBufferLoadIntrinsic(Orig, /*IsRaw*/ false))
      return true;
    break;
  case Intrinsic::dx_resource_store_typedbuffer:
    if (expandBufferStoreIntrinsic(Orig, /*IsRaw*/ false))
      return true;
    break;
  case Intrinsic::usub_sat:
    Result = expandUsubSat(Orig);
    break;
  case Intrinsic::vector_reduce_add:
  case Intrinsic::vector_reduce_fadd:
    Result = expandVecReduceAdd(Orig, IntrinsicId);
    break;
  }
  if (Result) {
    Orig->replaceAllUsesWith(Result);
    Orig->eraseFromParent();
    return true;
  }
  return false;
}
929
expansionIntrinsics(Module & M)930 static bool expansionIntrinsics(Module &M) {
931 for (auto &F : make_early_inc_range(M.functions())) {
932 if (!isIntrinsicExpansion(F))
933 continue;
934 bool IntrinsicExpanded = false;
935 for (User *U : make_early_inc_range(F.users())) {
936 auto *IntrinsicCall = dyn_cast<CallInst>(U);
937 if (!IntrinsicCall)
938 continue;
939 IntrinsicExpanded = expandIntrinsic(F, IntrinsicCall);
940 }
941 if (F.user_empty() && IntrinsicExpanded)
942 F.eraseFromParent();
943 }
944 return true;
945 }
946
run(Module & M,ModuleAnalysisManager &)947 PreservedAnalyses DXILIntrinsicExpansion::run(Module &M,
948 ModuleAnalysisManager &) {
949 if (expansionIntrinsics(M))
950 return PreservedAnalyses::none();
951 return PreservedAnalyses::all();
952 }
953
runOnModule(Module & M)954 bool DXILIntrinsicExpansionLegacy::runOnModule(Module &M) {
955 return expansionIntrinsics(M);
956 }
957
// Definition of the pass-identification member declared in the class as
// `static char ID`.
char DXILIntrinsicExpansionLegacy::ID = 0;

// Legacy pass registration under the name DEBUG_TYPE
// ("dxil-intrinsic-expansion") with description "DXIL Intrinsic Expansion".
// NOTE(review): the two trailing `false` arguments are presumably the
// CFG-only and is-analysis flags; confirm against llvm/PassSupport.h.
INITIALIZE_PASS_BEGIN(DXILIntrinsicExpansionLegacy, DEBUG_TYPE,
                      "DXIL Intrinsic Expansion", false, false)
INITIALIZE_PASS_END(DXILIntrinsicExpansionLegacy, DEBUG_TYPE,
                    "DXIL Intrinsic Expansion", false, false)
964
createDXILIntrinsicExpansionLegacyPass()965 ModulePass *llvm::createDXILIntrinsicExpansionLegacyPass() {
966 return new DXILIntrinsicExpansionLegacy();
967 }
968