xref: /freebsd/contrib/llvm-project/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp (revision 700637cbb5e582861067a11aaca4d053546871d2)
1 //===- DXILIntrinsicExpansion.cpp - Prepare LLVM Module for DXIL encoding--===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
9 /// \file This file contains DXIL intrinsic expansions for those that don't have
10 /// opcodes in DirectX Intermediate Language (DXIL).
11 //===----------------------------------------------------------------------===//
12 
13 #include "DXILIntrinsicExpansion.h"
14 #include "DirectX.h"
15 #include "llvm/ADT/STLExtras.h"
16 #include "llvm/ADT/SmallVector.h"
17 #include "llvm/CodeGen/Passes.h"
18 #include "llvm/IR/IRBuilder.h"
19 #include "llvm/IR/InstrTypes.h"
20 #include "llvm/IR/Instruction.h"
21 #include "llvm/IR/Instructions.h"
22 #include "llvm/IR/Intrinsics.h"
23 #include "llvm/IR/IntrinsicsDirectX.h"
24 #include "llvm/IR/Module.h"
25 #include "llvm/IR/PassManager.h"
26 #include "llvm/IR/Type.h"
27 #include "llvm/Pass.h"
28 #include "llvm/Support/Casting.h"
29 #include "llvm/Support/ErrorHandling.h"
30 #include "llvm/Support/MathExtras.h"
31 
32 #define DEBUG_TYPE "dxil-intrinsic-expansion"
33 
34 using namespace llvm;
35 
36 class DXILIntrinsicExpansionLegacy : public ModulePass {
37 
38 public:
39   bool runOnModule(Module &M) override;
DXILIntrinsicExpansionLegacy()40   DXILIntrinsicExpansionLegacy() : ModulePass(ID) {}
41 
42   static char ID; // Pass identification.
43 };
44 
resourceAccessNeeds64BitExpansion(Module * M,Type * OverloadTy,bool IsRaw)45 static bool resourceAccessNeeds64BitExpansion(Module *M, Type *OverloadTy,
46                                               bool IsRaw) {
47   if (IsRaw && M->getTargetTriple().getDXILVersion() > VersionTuple(1, 2))
48     return false;
49 
50   Type *ScalarTy = OverloadTy->getScalarType();
51   return ScalarTy->isDoubleTy() || ScalarTy->isIntegerTy(64);
52 }
53 
isIntrinsicExpansion(Function & F)54 static bool isIntrinsicExpansion(Function &F) {
55   switch (F.getIntrinsicID()) {
56   case Intrinsic::abs:
57   case Intrinsic::atan2:
58   case Intrinsic::exp:
59   case Intrinsic::is_fpclass:
60   case Intrinsic::log:
61   case Intrinsic::log10:
62   case Intrinsic::pow:
63   case Intrinsic::powi:
64   case Intrinsic::dx_all:
65   case Intrinsic::dx_any:
66   case Intrinsic::dx_cross:
67   case Intrinsic::dx_uclamp:
68   case Intrinsic::dx_sclamp:
69   case Intrinsic::dx_nclamp:
70   case Intrinsic::dx_degrees:
71   case Intrinsic::dx_lerp:
72   case Intrinsic::dx_normalize:
73   case Intrinsic::dx_fdot:
74   case Intrinsic::dx_sdot:
75   case Intrinsic::dx_udot:
76   case Intrinsic::dx_sign:
77   case Intrinsic::dx_step:
78   case Intrinsic::dx_radians:
79   case Intrinsic::usub_sat:
80   case Intrinsic::vector_reduce_add:
81   case Intrinsic::vector_reduce_fadd:
82     return true;
83   case Intrinsic::dx_resource_load_rawbuffer:
84     return resourceAccessNeeds64BitExpansion(
85         F.getParent(), F.getReturnType()->getStructElementType(0),
86         /*IsRaw*/ true);
87   case Intrinsic::dx_resource_load_typedbuffer:
88     return resourceAccessNeeds64BitExpansion(
89         F.getParent(), F.getReturnType()->getStructElementType(0),
90         /*IsRaw*/ false);
91   case Intrinsic::dx_resource_store_rawbuffer:
92     return resourceAccessNeeds64BitExpansion(
93         F.getParent(), F.getFunctionType()->getParamType(3), /*IsRaw*/ true);
94   case Intrinsic::dx_resource_store_typedbuffer:
95     return resourceAccessNeeds64BitExpansion(
96         F.getParent(), F.getFunctionType()->getParamType(2), /*IsRaw*/ false);
97   }
98   return false;
99 }
100 
expandUsubSat(CallInst * Orig)101 static Value *expandUsubSat(CallInst *Orig) {
102   Value *A = Orig->getArgOperand(0);
103   Value *B = Orig->getArgOperand(1);
104   Type *Ty = A->getType();
105 
106   IRBuilder<> Builder(Orig);
107 
108   Value *Cmp = Builder.CreateICmpULT(A, B, "usub.cmp");
109   Value *Sub = Builder.CreateSub(A, B, "usub.sub");
110   Value *Zero = ConstantInt::get(Ty, 0);
111   return Builder.CreateSelect(Cmp, Zero, Sub, "usub.sat");
112 }
113 
expandVecReduceAdd(CallInst * Orig,Intrinsic::ID IntrinsicId)114 static Value *expandVecReduceAdd(CallInst *Orig, Intrinsic::ID IntrinsicId) {
115   assert(IntrinsicId == Intrinsic::vector_reduce_add ||
116          IntrinsicId == Intrinsic::vector_reduce_fadd);
117 
118   IRBuilder<> Builder(Orig);
119   bool IsFAdd = (IntrinsicId == Intrinsic::vector_reduce_fadd);
120 
121   Value *X = Orig->getOperand(IsFAdd ? 1 : 0);
122   Type *Ty = X->getType();
123   auto *XVec = dyn_cast<FixedVectorType>(Ty);
124   unsigned XVecSize = XVec->getNumElements();
125   Value *Sum = Builder.CreateExtractElement(X, static_cast<uint64_t>(0));
126 
127   // Handle the initial start value for floating-point addition.
128   if (IsFAdd) {
129     Constant *StartValue = dyn_cast<Constant>(Orig->getOperand(0));
130     if (StartValue && !StartValue->isZeroValue())
131       Sum = Builder.CreateFAdd(Sum, StartValue);
132   }
133 
134   // Accumulate the remaining vector elements.
135   for (unsigned I = 1; I < XVecSize; I++) {
136     Value *Elt = Builder.CreateExtractElement(X, I);
137     if (IsFAdd)
138       Sum = Builder.CreateFAdd(Sum, Elt);
139     else
140       Sum = Builder.CreateAdd(Sum, Elt);
141   }
142 
143   return Sum;
144 }
145 
expandAbs(CallInst * Orig)146 static Value *expandAbs(CallInst *Orig) {
147   Value *X = Orig->getOperand(0);
148   IRBuilder<> Builder(Orig);
149   Type *Ty = X->getType();
150   Type *EltTy = Ty->getScalarType();
151   Constant *Zero = Ty->isVectorTy()
152                        ? ConstantVector::getSplat(
153                              ElementCount::getFixed(
154                                  cast<FixedVectorType>(Ty)->getNumElements()),
155                              ConstantInt::get(EltTy, 0))
156                        : ConstantInt::get(EltTy, 0);
157   auto *V = Builder.CreateSub(Zero, X);
158   return Builder.CreateIntrinsic(Ty, Intrinsic::smax, {X, V}, nullptr,
159                                  "dx.max");
160 }
161 
expandCrossIntrinsic(CallInst * Orig)162 static Value *expandCrossIntrinsic(CallInst *Orig) {
163 
164   VectorType *VT = cast<VectorType>(Orig->getType());
165   if (cast<FixedVectorType>(VT)->getNumElements() != 3)
166     reportFatalUsageError("return vector must have exactly 3 elements");
167 
168   Value *op0 = Orig->getOperand(0);
169   Value *op1 = Orig->getOperand(1);
170   IRBuilder<> Builder(Orig);
171 
172   Value *op0_x = Builder.CreateExtractElement(op0, (uint64_t)0, "x0");
173   Value *op0_y = Builder.CreateExtractElement(op0, 1, "x1");
174   Value *op0_z = Builder.CreateExtractElement(op0, 2, "x2");
175 
176   Value *op1_x = Builder.CreateExtractElement(op1, (uint64_t)0, "y0");
177   Value *op1_y = Builder.CreateExtractElement(op1, 1, "y1");
178   Value *op1_z = Builder.CreateExtractElement(op1, 2, "y2");
179 
180   auto MulSub = [&](Value *x0, Value *y0, Value *x1, Value *y1) -> Value * {
181     Value *xy = Builder.CreateFMul(x0, y1);
182     Value *yx = Builder.CreateFMul(y0, x1);
183     return Builder.CreateFSub(xy, yx, Orig->getName());
184   };
185 
186   Value *yz_zy = MulSub(op0_y, op0_z, op1_y, op1_z);
187   Value *zx_xz = MulSub(op0_z, op0_x, op1_z, op1_x);
188   Value *xy_yx = MulSub(op0_x, op0_y, op1_x, op1_y);
189 
190   Value *cross = PoisonValue::get(VT);
191   cross = Builder.CreateInsertElement(cross, yz_zy, (uint64_t)0);
192   cross = Builder.CreateInsertElement(cross, zx_xz, 1);
193   cross = Builder.CreateInsertElement(cross, xy_yx, 2);
194   return cross;
195 }
196 
197 // Create appropriate DXIL float dot intrinsic for the given A and B operands
198 // The appropriate opcode will be determined by the size of the operands
199 // The dot product is placed in the position indicated by Orig
expandFloatDotIntrinsic(CallInst * Orig,Value * A,Value * B)200 static Value *expandFloatDotIntrinsic(CallInst *Orig, Value *A, Value *B) {
201   Type *ATy = A->getType();
202   [[maybe_unused]] Type *BTy = B->getType();
203   assert(ATy->isVectorTy() && BTy->isVectorTy());
204 
205   IRBuilder<> Builder(Orig);
206 
207   auto *AVec = dyn_cast<FixedVectorType>(ATy);
208 
209   assert(ATy->getScalarType()->isFloatingPointTy());
210 
211   Intrinsic::ID DotIntrinsic = Intrinsic::dx_dot4;
212   int NumElts = AVec->getNumElements();
213   switch (NumElts) {
214   case 2:
215     DotIntrinsic = Intrinsic::dx_dot2;
216     break;
217   case 3:
218     DotIntrinsic = Intrinsic::dx_dot3;
219     break;
220   case 4:
221     DotIntrinsic = Intrinsic::dx_dot4;
222     break;
223   default:
224     reportFatalUsageError(
225         "Invalid dot product input vector: length is outside 2-4");
226     return nullptr;
227   }
228 
229   SmallVector<Value *> Args;
230   for (int I = 0; I < NumElts; ++I)
231     Args.push_back(Builder.CreateExtractElement(A, Builder.getInt32(I)));
232   for (int I = 0; I < NumElts; ++I)
233     Args.push_back(Builder.CreateExtractElement(B, Builder.getInt32(I)));
234   return Builder.CreateIntrinsic(ATy->getScalarType(), DotIntrinsic, Args,
235                                  nullptr, "dot");
236 }
237 
238 // Create the appropriate DXIL float dot intrinsic for the operands of Orig
239 // The appropriate opcode will be determined by the size of the operands
240 // The dot product is placed in the position indicated by Orig
expandFloatDotIntrinsic(CallInst * Orig)241 static Value *expandFloatDotIntrinsic(CallInst *Orig) {
242   return expandFloatDotIntrinsic(Orig, Orig->getOperand(0),
243                                  Orig->getOperand(1));
244 }
245 
246 // Expand integer dot product to multiply and add ops
expandIntegerDotIntrinsic(CallInst * Orig,Intrinsic::ID DotIntrinsic)247 static Value *expandIntegerDotIntrinsic(CallInst *Orig,
248                                         Intrinsic::ID DotIntrinsic) {
249   assert(DotIntrinsic == Intrinsic::dx_sdot ||
250          DotIntrinsic == Intrinsic::dx_udot);
251   Value *A = Orig->getOperand(0);
252   Value *B = Orig->getOperand(1);
253   Type *ATy = A->getType();
254   [[maybe_unused]] Type *BTy = B->getType();
255   assert(ATy->isVectorTy() && BTy->isVectorTy());
256 
257   IRBuilder<> Builder(Orig);
258 
259   auto *AVec = dyn_cast<FixedVectorType>(ATy);
260 
261   assert(ATy->getScalarType()->isIntegerTy());
262 
263   Value *Result;
264   Intrinsic::ID MadIntrinsic = DotIntrinsic == Intrinsic::dx_sdot
265                                    ? Intrinsic::dx_imad
266                                    : Intrinsic::dx_umad;
267   Value *Elt0 = Builder.CreateExtractElement(A, (uint64_t)0);
268   Value *Elt1 = Builder.CreateExtractElement(B, (uint64_t)0);
269   Result = Builder.CreateMul(Elt0, Elt1);
270   for (unsigned I = 1; I < AVec->getNumElements(); I++) {
271     Elt0 = Builder.CreateExtractElement(A, I);
272     Elt1 = Builder.CreateExtractElement(B, I);
273     Result = Builder.CreateIntrinsic(Result->getType(), MadIntrinsic,
274                                      ArrayRef<Value *>{Elt0, Elt1, Result},
275                                      nullptr, "dx.mad");
276   }
277   return Result;
278 }
279 
expandExpIntrinsic(CallInst * Orig)280 static Value *expandExpIntrinsic(CallInst *Orig) {
281   Value *X = Orig->getOperand(0);
282   IRBuilder<> Builder(Orig);
283   Type *Ty = X->getType();
284   Type *EltTy = Ty->getScalarType();
285   Constant *Log2eConst =
286       Ty->isVectorTy() ? ConstantVector::getSplat(
287                              ElementCount::getFixed(
288                                  cast<FixedVectorType>(Ty)->getNumElements()),
289                              ConstantFP::get(EltTy, numbers::log2ef))
290                        : ConstantFP::get(EltTy, numbers::log2ef);
291   Value *NewX = Builder.CreateFMul(Log2eConst, X);
292   auto *Exp2Call =
293       Builder.CreateIntrinsic(Ty, Intrinsic::exp2, {NewX}, nullptr, "dx.exp2");
294   Exp2Call->setTailCall(Orig->isTailCall());
295   Exp2Call->setAttributes(Orig->getAttributes());
296   return Exp2Call;
297 }
298 
expandIsFPClass(CallInst * Orig)299 static Value *expandIsFPClass(CallInst *Orig) {
300   Value *T = Orig->getArgOperand(1);
301   auto *TCI = dyn_cast<ConstantInt>(T);
302 
303   // These FPClassTest cases have DXIL opcodes, so they will be handled in
304   // DXIL Op Lowering instead.
305   switch (TCI->getZExtValue()) {
306   case FPClassTest::fcInf:
307   case FPClassTest::fcNan:
308   case FPClassTest::fcNormal:
309   case FPClassTest::fcFinite:
310     return nullptr;
311   }
312 
313   IRBuilder<> Builder(Orig);
314 
315   Value *F = Orig->getArgOperand(0);
316   Type *FTy = F->getType();
317   unsigned FNumElem = 0; // 0 => F is not a vector
318 
319   unsigned BitWidth; // Bit width of F or the ElemTy of F
320   Type *BitCastTy;   // An IntNTy of the same bitwidth as F or ElemTy of F
321 
322   if (auto *FVecTy = dyn_cast<FixedVectorType>(FTy)) {
323     Type *ElemTy = FVecTy->getElementType();
324     FNumElem = FVecTy->getNumElements();
325     BitWidth = ElemTy->getPrimitiveSizeInBits();
326     BitCastTy = FixedVectorType::get(Builder.getIntNTy(BitWidth), FNumElem);
327   } else {
328     BitWidth = FTy->getPrimitiveSizeInBits();
329     BitCastTy = Builder.getIntNTy(BitWidth);
330   }
331 
332   Value *FBitCast = Builder.CreateBitCast(F, BitCastTy);
333   switch (TCI->getZExtValue()) {
334   case FPClassTest::fcNegZero: {
335     Value *NegZero =
336         ConstantInt::get(Builder.getIntNTy(BitWidth), 1 << (BitWidth - 1));
337     Value *RetVal;
338     if (FNumElem) {
339       Value *NegZeroSplat = Builder.CreateVectorSplat(FNumElem, NegZero);
340       RetVal =
341           Builder.CreateICmpEQ(FBitCast, NegZeroSplat, "is.fpclass.negzero");
342     } else
343       RetVal = Builder.CreateICmpEQ(FBitCast, NegZero, "is.fpclass.negzero");
344     return RetVal;
345   }
346   default:
347     reportFatalUsageError("Unsupported FPClassTest");
348   }
349 }
350 
expandAnyOrAllIntrinsic(CallInst * Orig,Intrinsic::ID IntrinsicId)351 static Value *expandAnyOrAllIntrinsic(CallInst *Orig,
352                                       Intrinsic::ID IntrinsicId) {
353   Value *X = Orig->getOperand(0);
354   IRBuilder<> Builder(Orig);
355   Type *Ty = X->getType();
356   Type *EltTy = Ty->getScalarType();
357 
358   auto ApplyOp = [&Builder](Intrinsic::ID IntrinsicId, Value *Result,
359                             Value *Elt) {
360     if (IntrinsicId == Intrinsic::dx_any)
361       return Builder.CreateOr(Result, Elt);
362     assert(IntrinsicId == Intrinsic::dx_all);
363     return Builder.CreateAnd(Result, Elt);
364   };
365 
366   Value *Result = nullptr;
367   if (!Ty->isVectorTy()) {
368     Result = EltTy->isFloatingPointTy()
369                  ? Builder.CreateFCmpUNE(X, ConstantFP::get(EltTy, 0))
370                  : Builder.CreateICmpNE(X, ConstantInt::get(EltTy, 0));
371   } else {
372     auto *XVec = dyn_cast<FixedVectorType>(Ty);
373     Value *Cond =
374         EltTy->isFloatingPointTy()
375             ? Builder.CreateFCmpUNE(
376                   X, ConstantVector::getSplat(
377                          ElementCount::getFixed(XVec->getNumElements()),
378                          ConstantFP::get(EltTy, 0)))
379             : Builder.CreateICmpNE(
380                   X, ConstantVector::getSplat(
381                          ElementCount::getFixed(XVec->getNumElements()),
382                          ConstantInt::get(EltTy, 0)));
383     Result = Builder.CreateExtractElement(Cond, (uint64_t)0);
384     for (unsigned I = 1; I < XVec->getNumElements(); I++) {
385       Value *Elt = Builder.CreateExtractElement(Cond, I);
386       Result = ApplyOp(IntrinsicId, Result, Elt);
387     }
388   }
389   return Result;
390 }
391 
expandLerpIntrinsic(CallInst * Orig)392 static Value *expandLerpIntrinsic(CallInst *Orig) {
393   Value *X = Orig->getOperand(0);
394   Value *Y = Orig->getOperand(1);
395   Value *S = Orig->getOperand(2);
396   IRBuilder<> Builder(Orig);
397   auto *V = Builder.CreateFSub(Y, X);
398   V = Builder.CreateFMul(S, V);
399   return Builder.CreateFAdd(X, V, "dx.lerp");
400 }
401 
expandLogIntrinsic(CallInst * Orig,float LogConstVal=numbers::ln2f)402 static Value *expandLogIntrinsic(CallInst *Orig,
403                                  float LogConstVal = numbers::ln2f) {
404   Value *X = Orig->getOperand(0);
405   IRBuilder<> Builder(Orig);
406   Type *Ty = X->getType();
407   Type *EltTy = Ty->getScalarType();
408   Constant *Ln2Const =
409       Ty->isVectorTy() ? ConstantVector::getSplat(
410                              ElementCount::getFixed(
411                                  cast<FixedVectorType>(Ty)->getNumElements()),
412                              ConstantFP::get(EltTy, LogConstVal))
413                        : ConstantFP::get(EltTy, LogConstVal);
414   auto *Log2Call =
415       Builder.CreateIntrinsic(Ty, Intrinsic::log2, {X}, nullptr, "elt.log2");
416   Log2Call->setTailCall(Orig->isTailCall());
417   Log2Call->setAttributes(Orig->getAttributes());
418   return Builder.CreateFMul(Ln2Const, Log2Call);
419 }
expandLog10Intrinsic(CallInst * Orig)420 static Value *expandLog10Intrinsic(CallInst *Orig) {
421   return expandLogIntrinsic(Orig, numbers::ln2f / numbers::ln10f);
422 }
423 
424 // Use dot product of vector operand with itself to calculate the length.
425 // Divide the vector by that length to normalize it.
expandNormalizeIntrinsic(CallInst * Orig)426 static Value *expandNormalizeIntrinsic(CallInst *Orig) {
427   Value *X = Orig->getOperand(0);
428   Type *Ty = Orig->getType();
429   Type *EltTy = Ty->getScalarType();
430   IRBuilder<> Builder(Orig);
431 
432   auto *XVec = dyn_cast<FixedVectorType>(Ty);
433   if (!XVec) {
434     if (auto *constantFP = dyn_cast<ConstantFP>(X)) {
435       const APFloat &fpVal = constantFP->getValueAPF();
436       if (fpVal.isZero())
437         reportFatalUsageError("Invalid input scalar: length is zero");
438     }
439     return Builder.CreateFDiv(X, X);
440   }
441 
442   Value *DotProduct = expandFloatDotIntrinsic(Orig, X, X);
443 
444   // verify that the length is non-zero
445   // (if the dot product is non-zero, then the length is non-zero)
446   if (auto *constantFP = dyn_cast<ConstantFP>(DotProduct)) {
447     const APFloat &fpVal = constantFP->getValueAPF();
448     if (fpVal.isZero())
449       reportFatalUsageError("Invalid input vector: length is zero");
450   }
451 
452   Value *Multiplicand = Builder.CreateIntrinsic(EltTy, Intrinsic::dx_rsqrt,
453                                                 ArrayRef<Value *>{DotProduct},
454                                                 nullptr, "dx.rsqrt");
455 
456   Value *MultiplicandVec =
457       Builder.CreateVectorSplat(XVec->getNumElements(), Multiplicand);
458   return Builder.CreateFMul(X, MultiplicandVec);
459 }
460 
expandAtan2Intrinsic(CallInst * Orig)461 static Value *expandAtan2Intrinsic(CallInst *Orig) {
462   Value *Y = Orig->getOperand(0);
463   Value *X = Orig->getOperand(1);
464   Type *Ty = X->getType();
465   IRBuilder<> Builder(Orig);
466   Builder.setFastMathFlags(Orig->getFastMathFlags());
467 
468   Value *Tan = Builder.CreateFDiv(Y, X);
469 
470   CallInst *Atan =
471       Builder.CreateIntrinsic(Ty, Intrinsic::atan, {Tan}, nullptr, "Elt.Atan");
472   Atan->setTailCall(Orig->isTailCall());
473   Atan->setAttributes(Orig->getAttributes());
474 
475   // Modify atan result based on https://en.wikipedia.org/wiki/Atan2.
476   Constant *Pi = ConstantFP::get(Ty, llvm::numbers::pi);
477   Constant *HalfPi = ConstantFP::get(Ty, llvm::numbers::pi / 2);
478   Constant *NegHalfPi = ConstantFP::get(Ty, -llvm::numbers::pi / 2);
479   Constant *Zero = ConstantFP::get(Ty, 0);
480   Value *AtanAddPi = Builder.CreateFAdd(Atan, Pi);
481   Value *AtanSubPi = Builder.CreateFSub(Atan, Pi);
482 
483   // x > 0 -> atan.
484   Value *Result = Atan;
485   Value *XLt0 = Builder.CreateFCmpOLT(X, Zero);
486   Value *XEq0 = Builder.CreateFCmpOEQ(X, Zero);
487   Value *YGe0 = Builder.CreateFCmpOGE(Y, Zero);
488   Value *YLt0 = Builder.CreateFCmpOLT(Y, Zero);
489 
490   // x < 0, y >= 0 -> atan + pi.
491   Value *XLt0AndYGe0 = Builder.CreateAnd(XLt0, YGe0);
492   Result = Builder.CreateSelect(XLt0AndYGe0, AtanAddPi, Result);
493 
494   // x < 0, y < 0 -> atan - pi.
495   Value *XLt0AndYLt0 = Builder.CreateAnd(XLt0, YLt0);
496   Result = Builder.CreateSelect(XLt0AndYLt0, AtanSubPi, Result);
497 
498   // x == 0, y < 0 -> -pi/2
499   Value *XEq0AndYLt0 = Builder.CreateAnd(XEq0, YLt0);
500   Result = Builder.CreateSelect(XEq0AndYLt0, NegHalfPi, Result);
501 
502   // x == 0, y > 0 -> pi/2
503   Value *XEq0AndYGe0 = Builder.CreateAnd(XEq0, YGe0);
504   Result = Builder.CreateSelect(XEq0AndYGe0, HalfPi, Result);
505 
506   return Result;
507 }
508 
expandPowIntrinsic(CallInst * Orig,Intrinsic::ID IntrinsicId)509 static Value *expandPowIntrinsic(CallInst *Orig, Intrinsic::ID IntrinsicId) {
510 
511   Value *X = Orig->getOperand(0);
512   Value *Y = Orig->getOperand(1);
513   Type *Ty = X->getType();
514   IRBuilder<> Builder(Orig);
515 
516   if (IntrinsicId == Intrinsic::powi)
517     Y = Builder.CreateSIToFP(Y, Ty);
518 
519   auto *Log2Call =
520       Builder.CreateIntrinsic(Ty, Intrinsic::log2, {X}, nullptr, "elt.log2");
521   auto *Mul = Builder.CreateFMul(Log2Call, Y);
522   auto *Exp2Call =
523       Builder.CreateIntrinsic(Ty, Intrinsic::exp2, {Mul}, nullptr, "elt.exp2");
524   Exp2Call->setTailCall(Orig->isTailCall());
525   Exp2Call->setAttributes(Orig->getAttributes());
526   return Exp2Call;
527 }
528 
expandStepIntrinsic(CallInst * Orig)529 static Value *expandStepIntrinsic(CallInst *Orig) {
530 
531   Value *X = Orig->getOperand(0);
532   Value *Y = Orig->getOperand(1);
533   Type *Ty = X->getType();
534   IRBuilder<> Builder(Orig);
535 
536   Constant *One = ConstantFP::get(Ty->getScalarType(), 1.0);
537   Constant *Zero = ConstantFP::get(Ty->getScalarType(), 0.0);
538   Value *Cond = Builder.CreateFCmpOLT(Y, X);
539 
540   if (Ty != Ty->getScalarType()) {
541     auto *XVec = dyn_cast<FixedVectorType>(Ty);
542     One = ConstantVector::getSplat(
543         ElementCount::getFixed(XVec->getNumElements()), One);
544     Zero = ConstantVector::getSplat(
545         ElementCount::getFixed(XVec->getNumElements()), Zero);
546   }
547 
548   return Builder.CreateSelect(Cond, Zero, One);
549 }
550 
expandRadiansIntrinsic(CallInst * Orig)551 static Value *expandRadiansIntrinsic(CallInst *Orig) {
552   Value *X = Orig->getOperand(0);
553   Type *Ty = X->getType();
554   IRBuilder<> Builder(Orig);
555   Value *PiOver180 = ConstantFP::get(Ty, llvm::numbers::pi / 180.0);
556   return Builder.CreateFMul(X, PiOver180);
557 }
558 
expandBufferLoadIntrinsic(CallInst * Orig,bool IsRaw)559 static bool expandBufferLoadIntrinsic(CallInst *Orig, bool IsRaw) {
560   IRBuilder<> Builder(Orig);
561 
562   Type *BufferTy = Orig->getType()->getStructElementType(0);
563   Type *ScalarTy = BufferTy->getScalarType();
564   bool IsDouble = ScalarTy->isDoubleTy();
565   assert(IsDouble || ScalarTy->isIntegerTy(64) &&
566                          "Only expand double or int64 scalars or vectors");
567   bool IsVector = false;
568   unsigned ExtractNum = 2;
569   if (auto *VT = dyn_cast<FixedVectorType>(BufferTy)) {
570     ExtractNum = 2 * VT->getNumElements();
571     IsVector = true;
572     assert(IsRaw || ExtractNum == 4 && "TypedBufferLoad vector must be size 2");
573   }
574 
575   SmallVector<Value *, 2> Loads;
576   Value *Result = PoisonValue::get(BufferTy);
577   unsigned Base = 0;
578   // If we need to extract more than 4 i32; we need to break it up into
579   // more than one load. LoadNum tells us how many i32s we are loading in
580   // each load
581   while (ExtractNum > 0) {
582     unsigned LoadNum = std::min(ExtractNum, 4u);
583     Type *Ty = VectorType::get(Builder.getInt32Ty(), LoadNum, false);
584 
585     Type *LoadType = StructType::get(Ty, Builder.getInt1Ty());
586     Intrinsic::ID LoadIntrinsic = Intrinsic::dx_resource_load_typedbuffer;
587     SmallVector<Value *, 3> Args = {Orig->getOperand(0), Orig->getOperand(1)};
588     if (IsRaw) {
589       LoadIntrinsic = Intrinsic::dx_resource_load_rawbuffer;
590       Value *Tmp = Builder.getInt32(4 * Base * 2);
591       Args.push_back(Builder.CreateAdd(Orig->getOperand(2), Tmp));
592     }
593 
594     CallInst *Load = Builder.CreateIntrinsic(LoadType, LoadIntrinsic, Args);
595     Loads.push_back(Load);
596 
597     // extract the buffer load's result
598     Value *Extract = Builder.CreateExtractValue(Load, {0});
599 
600     SmallVector<Value *> ExtractElements;
601     for (unsigned I = 0; I < LoadNum; ++I)
602       ExtractElements.push_back(
603           Builder.CreateExtractElement(Extract, Builder.getInt32(I)));
604 
605     // combine into double(s) or int64(s)
606     for (unsigned I = 0; I < LoadNum; I += 2) {
607       Value *Combined = nullptr;
608       if (IsDouble)
609         // For doubles, use dx_asdouble intrinsic
610         Combined = Builder.CreateIntrinsic(
611             Builder.getDoubleTy(), Intrinsic::dx_asdouble,
612             {ExtractElements[I], ExtractElements[I + 1]});
613       else {
614         // For int64, manually combine two int32s
615         // First, zero-extend both values to i64
616         Value *Lo =
617             Builder.CreateZExt(ExtractElements[I], Builder.getInt64Ty());
618         Value *Hi =
619             Builder.CreateZExt(ExtractElements[I + 1], Builder.getInt64Ty());
620         // Shift the high bits left by 32 bits
621         Value *ShiftedHi = Builder.CreateShl(Hi, Builder.getInt64(32));
622         // OR the high and low bits together
623         Combined = Builder.CreateOr(Lo, ShiftedHi);
624       }
625 
626       if (IsVector)
627         Result = Builder.CreateInsertElement(Result, Combined,
628                                              Builder.getInt32((I / 2) + Base));
629       else
630         Result = Combined;
631     }
632 
633     ExtractNum -= LoadNum;
634     Base += LoadNum / 2;
635   }
636 
637   Value *CheckBit = nullptr;
638   for (User *U : make_early_inc_range(Orig->users())) {
639     // If it's not a ExtractValueInst, we don't know how to
640     // handle it
641     auto *EVI = dyn_cast<ExtractValueInst>(U);
642     if (!EVI)
643       llvm_unreachable("Unexpected user of typedbufferload");
644 
645     ArrayRef<unsigned> Indices = EVI->getIndices();
646     assert(Indices.size() == 1);
647 
648     if (Indices[0] == 0) {
649       // Use of the value(s)
650       EVI->replaceAllUsesWith(Result);
651     } else {
652       // Use of the check bit
653       assert(Indices[0] == 1 && "Unexpected type for typedbufferload");
654       // Note: This does not always match the historical behaviour of DXC.
655       // See https://github.com/microsoft/DirectXShaderCompiler/issues/7622
656       if (!CheckBit) {
657         SmallVector<Value *, 2> CheckBits;
658         for (Value *L : Loads)
659           CheckBits.push_back(Builder.CreateExtractValue(L, {1}));
660         CheckBit = Builder.CreateAnd(CheckBits);
661       }
662       EVI->replaceAllUsesWith(CheckBit);
663     }
664     EVI->eraseFromParent();
665   }
666   Orig->eraseFromParent();
667   return true;
668 }
669 
expandBufferStoreIntrinsic(CallInst * Orig,bool IsRaw)670 static bool expandBufferStoreIntrinsic(CallInst *Orig, bool IsRaw) {
671   IRBuilder<> Builder(Orig);
672 
673   unsigned ValIndex = IsRaw ? 3 : 2;
674   Type *BufferTy = Orig->getFunctionType()->getParamType(ValIndex);
675   Type *ScalarTy = BufferTy->getScalarType();
676   bool IsDouble = ScalarTy->isDoubleTy();
677   assert((IsDouble || ScalarTy->isIntegerTy(64)) &&
678          "Only expand double or int64 scalars or vectors");
679 
680   // Determine if we're dealing with a vector or scalar
681   bool IsVector = false;
682   unsigned ExtractNum = 2;
683   unsigned VecLen = 0;
684   if (auto *VT = dyn_cast<FixedVectorType>(BufferTy)) {
685     VecLen = VT->getNumElements();
686     assert(IsRaw || VecLen == 2 && "TypedBufferStore vector must be size 2");
687     ExtractNum = VecLen * 2;
688     IsVector = true;
689   }
690 
691   // Create the appropriate vector type for the result
692   Type *Int32Ty = Builder.getInt32Ty();
693   Type *ResultTy = VectorType::get(Int32Ty, ExtractNum, false);
694   Value *Val = PoisonValue::get(ResultTy);
695 
696   Type *SplitElementTy = Int32Ty;
697   if (IsVector)
698     SplitElementTy = VectorType::get(SplitElementTy, VecLen, false);
699 
700   Value *LowBits = nullptr;
701   Value *HighBits = nullptr;
702   // Split the 64-bit values into 32-bit components
703   if (IsDouble) {
704     auto *SplitTy = llvm::StructType::get(SplitElementTy, SplitElementTy);
705     Value *Split = Builder.CreateIntrinsic(SplitTy, Intrinsic::dx_splitdouble,
706                                            {Orig->getOperand(ValIndex)});
707     LowBits = Builder.CreateExtractValue(Split, 0);
708     HighBits = Builder.CreateExtractValue(Split, 1);
709   } else {
710     // Handle int64 type(s)
711     Value *InputVal = Orig->getOperand(ValIndex);
712     Constant *ShiftAmt = Builder.getInt64(32);
713     if (IsVector)
714       ShiftAmt =
715           ConstantVector::getSplat(ElementCount::getFixed(VecLen), ShiftAmt);
716 
717     // Split into low and high 32-bit parts
718     LowBits = Builder.CreateTrunc(InputVal, SplitElementTy);
719     Value *ShiftedVal = Builder.CreateLShr(InputVal, ShiftAmt);
720     HighBits = Builder.CreateTrunc(ShiftedVal, SplitElementTy);
721   }
722 
723   if (IsVector) {
724     SmallVector<int, 8> Mask;
725     for (unsigned I = 0; I < VecLen; ++I) {
726       Mask.push_back(I);
727       Mask.push_back(I + VecLen);
728     }
729     Val = Builder.CreateShuffleVector(LowBits, HighBits, Mask);
730   } else {
731     Val = Builder.CreateInsertElement(Val, LowBits, Builder.getInt32(0));
732     Val = Builder.CreateInsertElement(Val, HighBits, Builder.getInt32(1));
733   }
734 
735   // If we need to extract more than 4 i32; we need to break it up into
736   // more than one store. StoreNum tells us how many i32s we are storing in
737   // each store
738   unsigned Base = 0;
739   while (ExtractNum > 0) {
740     unsigned StoreNum = std::min(ExtractNum, 4u);
741 
742     Intrinsic::ID StoreIntrinsic = Intrinsic::dx_resource_store_typedbuffer;
743     SmallVector<Value *, 4> Args = {Orig->getOperand(0), Orig->getOperand(1)};
744     if (IsRaw) {
745       StoreIntrinsic = Intrinsic::dx_resource_store_rawbuffer;
746       Value *Tmp = Builder.getInt32(4 * Base);
747       Args.push_back(Builder.CreateAdd(Orig->getOperand(2), Tmp));
748     }
749 
750     SmallVector<int, 4> Mask;
751     for (unsigned I = 0; I < StoreNum; ++I) {
752       Mask.push_back(Base + I);
753     }
754 
755     Value *SubVal = Val;
756     if (VecLen > 2)
757       SubVal = Builder.CreateShuffleVector(Val, Mask);
758 
759     Args.push_back(SubVal);
760     // Create the final intrinsic call
761     Builder.CreateIntrinsic(Builder.getVoidTy(), StoreIntrinsic, Args);
762 
763     ExtractNum -= StoreNum;
764     Base += StoreNum;
765   }
766   Orig->eraseFromParent();
767   return true;
768 }
769 
getMaxForClamp(Intrinsic::ID ClampIntrinsic)770 static Intrinsic::ID getMaxForClamp(Intrinsic::ID ClampIntrinsic) {
771   if (ClampIntrinsic == Intrinsic::dx_uclamp)
772     return Intrinsic::umax;
773   if (ClampIntrinsic == Intrinsic::dx_sclamp)
774     return Intrinsic::smax;
775   assert(ClampIntrinsic == Intrinsic::dx_nclamp);
776   return Intrinsic::maxnum;
777 }
778 
getMinForClamp(Intrinsic::ID ClampIntrinsic)779 static Intrinsic::ID getMinForClamp(Intrinsic::ID ClampIntrinsic) {
780   if (ClampIntrinsic == Intrinsic::dx_uclamp)
781     return Intrinsic::umin;
782   if (ClampIntrinsic == Intrinsic::dx_sclamp)
783     return Intrinsic::smin;
784   assert(ClampIntrinsic == Intrinsic::dx_nclamp);
785   return Intrinsic::minnum;
786 }
787 
expandClampIntrinsic(CallInst * Orig,Intrinsic::ID ClampIntrinsic)788 static Value *expandClampIntrinsic(CallInst *Orig,
789                                    Intrinsic::ID ClampIntrinsic) {
790   Value *X = Orig->getOperand(0);
791   Value *Min = Orig->getOperand(1);
792   Value *Max = Orig->getOperand(2);
793   Type *Ty = X->getType();
794   IRBuilder<> Builder(Orig);
795   auto *MaxCall = Builder.CreateIntrinsic(Ty, getMaxForClamp(ClampIntrinsic),
796                                           {X, Min}, nullptr, "dx.max");
797   return Builder.CreateIntrinsic(Ty, getMinForClamp(ClampIntrinsic),
798                                  {MaxCall, Max}, nullptr, "dx.min");
799 }
800 
expandDegreesIntrinsic(CallInst * Orig)801 static Value *expandDegreesIntrinsic(CallInst *Orig) {
802   Value *X = Orig->getOperand(0);
803   Type *Ty = X->getType();
804   IRBuilder<> Builder(Orig);
805   Value *DegreesRatio = ConstantFP::get(Ty, 180.0 * llvm::numbers::inv_pi);
806   return Builder.CreateFMul(X, DegreesRatio);
807 }
808 
expandSignIntrinsic(CallInst * Orig)809 static Value *expandSignIntrinsic(CallInst *Orig) {
810   Value *X = Orig->getOperand(0);
811   Type *Ty = X->getType();
812   Type *ScalarTy = Ty->getScalarType();
813   Type *RetTy = Orig->getType();
814   Constant *Zero = Constant::getNullValue(Ty);
815 
816   IRBuilder<> Builder(Orig);
817 
818   Value *GT;
819   Value *LT;
820   if (ScalarTy->isFloatingPointTy()) {
821     GT = Builder.CreateFCmpOLT(Zero, X);
822     LT = Builder.CreateFCmpOLT(X, Zero);
823   } else {
824     assert(ScalarTy->isIntegerTy());
825     GT = Builder.CreateICmpSLT(Zero, X);
826     LT = Builder.CreateICmpSLT(X, Zero);
827   }
828 
829   Value *ZextGT = Builder.CreateZExt(GT, RetTy);
830   Value *ZextLT = Builder.CreateZExt(LT, RetTy);
831 
832   return Builder.CreateSub(ZextGT, ZextLT);
833 }
834 
expandIntrinsic(Function & F,CallInst * Orig)835 static bool expandIntrinsic(Function &F, CallInst *Orig) {
836   Value *Result = nullptr;
837   Intrinsic::ID IntrinsicId = F.getIntrinsicID();
838   switch (IntrinsicId) {
839   case Intrinsic::abs:
840     Result = expandAbs(Orig);
841     break;
842   case Intrinsic::atan2:
843     Result = expandAtan2Intrinsic(Orig);
844     break;
845   case Intrinsic::exp:
846     Result = expandExpIntrinsic(Orig);
847     break;
848   case Intrinsic::is_fpclass:
849     Result = expandIsFPClass(Orig);
850     break;
851   case Intrinsic::log:
852     Result = expandLogIntrinsic(Orig);
853     break;
854   case Intrinsic::log10:
855     Result = expandLog10Intrinsic(Orig);
856     break;
857   case Intrinsic::pow:
858   case Intrinsic::powi:
859     Result = expandPowIntrinsic(Orig, IntrinsicId);
860     break;
861   case Intrinsic::dx_all:
862   case Intrinsic::dx_any:
863     Result = expandAnyOrAllIntrinsic(Orig, IntrinsicId);
864     break;
865   case Intrinsic::dx_cross:
866     Result = expandCrossIntrinsic(Orig);
867     break;
868   case Intrinsic::dx_uclamp:
869   case Intrinsic::dx_sclamp:
870   case Intrinsic::dx_nclamp:
871     Result = expandClampIntrinsic(Orig, IntrinsicId);
872     break;
873   case Intrinsic::dx_degrees:
874     Result = expandDegreesIntrinsic(Orig);
875     break;
876   case Intrinsic::dx_lerp:
877     Result = expandLerpIntrinsic(Orig);
878     break;
879   case Intrinsic::dx_normalize:
880     Result = expandNormalizeIntrinsic(Orig);
881     break;
882   case Intrinsic::dx_fdot:
883     Result = expandFloatDotIntrinsic(Orig);
884     break;
885   case Intrinsic::dx_sdot:
886   case Intrinsic::dx_udot:
887     Result = expandIntegerDotIntrinsic(Orig, IntrinsicId);
888     break;
889   case Intrinsic::dx_sign:
890     Result = expandSignIntrinsic(Orig);
891     break;
892   case Intrinsic::dx_step:
893     Result = expandStepIntrinsic(Orig);
894     break;
895   case Intrinsic::dx_radians:
896     Result = expandRadiansIntrinsic(Orig);
897     break;
898   case Intrinsic::dx_resource_load_rawbuffer:
899     if (expandBufferLoadIntrinsic(Orig, /*IsRaw*/ true))
900       return true;
901     break;
902   case Intrinsic::dx_resource_store_rawbuffer:
903     if (expandBufferStoreIntrinsic(Orig, /*IsRaw*/ true))
904       return true;
905     break;
906   case Intrinsic::dx_resource_load_typedbuffer:
907     if (expandBufferLoadIntrinsic(Orig, /*IsRaw*/ false))
908       return true;
909     break;
910   case Intrinsic::dx_resource_store_typedbuffer:
911     if (expandBufferStoreIntrinsic(Orig, /*IsRaw*/ false))
912       return true;
913     break;
914   case Intrinsic::usub_sat:
915     Result = expandUsubSat(Orig);
916     break;
917   case Intrinsic::vector_reduce_add:
918   case Intrinsic::vector_reduce_fadd:
919     Result = expandVecReduceAdd(Orig, IntrinsicId);
920     break;
921   }
922   if (Result) {
923     Orig->replaceAllUsesWith(Result);
924     Orig->eraseFromParent();
925     return true;
926   }
927   return false;
928 }
929 
expansionIntrinsics(Module & M)930 static bool expansionIntrinsics(Module &M) {
931   for (auto &F : make_early_inc_range(M.functions())) {
932     if (!isIntrinsicExpansion(F))
933       continue;
934     bool IntrinsicExpanded = false;
935     for (User *U : make_early_inc_range(F.users())) {
936       auto *IntrinsicCall = dyn_cast<CallInst>(U);
937       if (!IntrinsicCall)
938         continue;
939       IntrinsicExpanded = expandIntrinsic(F, IntrinsicCall);
940     }
941     if (F.user_empty() && IntrinsicExpanded)
942       F.eraseFromParent();
943   }
944   return true;
945 }
946 
run(Module & M,ModuleAnalysisManager &)947 PreservedAnalyses DXILIntrinsicExpansion::run(Module &M,
948                                               ModuleAnalysisManager &) {
949   if (expansionIntrinsics(M))
950     return PreservedAnalyses::none();
951   return PreservedAnalyses::all();
952 }
953 
runOnModule(Module & M)954 bool DXILIntrinsicExpansionLegacy::runOnModule(Module &M) {
955   return expansionIntrinsics(M);
956 }
957 
958 char DXILIntrinsicExpansionLegacy::ID = 0;
959 
960 INITIALIZE_PASS_BEGIN(DXILIntrinsicExpansionLegacy, DEBUG_TYPE,
961                       "DXIL Intrinsic Expansion", false, false)
962 INITIALIZE_PASS_END(DXILIntrinsicExpansionLegacy, DEBUG_TYPE,
963                     "DXIL Intrinsic Expansion", false, false)
964 
createDXILIntrinsicExpansionLegacyPass()965 ModulePass *llvm::createDXILIntrinsicExpansionLegacyPass() {
966   return new DXILIntrinsicExpansionLegacy();
967 }
968