1 //===- DXILIntrinsicExpansion.cpp - Prepare LLVM Module for DXIL encoding--===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 /// 9 /// \file This file contains DXIL intrinsic expansions for those that don't have 10 // opcodes in DirectX Intermediate Language (DXIL). 11 //===----------------------------------------------------------------------===// 12 13 #include "DXILIntrinsicExpansion.h" 14 #include "DirectX.h" 15 #include "llvm/ADT/STLExtras.h" 16 #include "llvm/ADT/SmallVector.h" 17 #include "llvm/CodeGen/Passes.h" 18 #include "llvm/IR/IRBuilder.h" 19 #include "llvm/IR/InstrTypes.h" 20 #include "llvm/IR/Instruction.h" 21 #include "llvm/IR/Instructions.h" 22 #include "llvm/IR/Intrinsics.h" 23 #include "llvm/IR/IntrinsicsDirectX.h" 24 #include "llvm/IR/Module.h" 25 #include "llvm/IR/PassManager.h" 26 #include "llvm/IR/Type.h" 27 #include "llvm/Pass.h" 28 #include "llvm/Support/Casting.h" 29 #include "llvm/Support/ErrorHandling.h" 30 #include "llvm/Support/MathExtras.h" 31 32 #define DEBUG_TYPE "dxil-intrinsic-expansion" 33 34 using namespace llvm; 35 36 class DXILIntrinsicExpansionLegacy : public ModulePass { 37 38 public: 39 bool runOnModule(Module &M) override; 40 DXILIntrinsicExpansionLegacy() : ModulePass(ID) {} 41 42 static char ID; // Pass identification. 
43 }; 44 45 static bool resourceAccessNeeds64BitExpansion(Module *M, Type *OverloadTy, 46 bool IsRaw) { 47 if (IsRaw && M->getTargetTriple().getDXILVersion() > VersionTuple(1, 2)) 48 return false; 49 50 Type *ScalarTy = OverloadTy->getScalarType(); 51 return ScalarTy->isDoubleTy() || ScalarTy->isIntegerTy(64); 52 } 53 54 static bool isIntrinsicExpansion(Function &F) { 55 switch (F.getIntrinsicID()) { 56 case Intrinsic::abs: 57 case Intrinsic::atan2: 58 case Intrinsic::exp: 59 case Intrinsic::is_fpclass: 60 case Intrinsic::log: 61 case Intrinsic::log10: 62 case Intrinsic::pow: 63 case Intrinsic::powi: 64 case Intrinsic::dx_all: 65 case Intrinsic::dx_any: 66 case Intrinsic::dx_cross: 67 case Intrinsic::dx_uclamp: 68 case Intrinsic::dx_sclamp: 69 case Intrinsic::dx_nclamp: 70 case Intrinsic::dx_degrees: 71 case Intrinsic::dx_lerp: 72 case Intrinsic::dx_normalize: 73 case Intrinsic::dx_fdot: 74 case Intrinsic::dx_sdot: 75 case Intrinsic::dx_udot: 76 case Intrinsic::dx_sign: 77 case Intrinsic::dx_step: 78 case Intrinsic::dx_radians: 79 case Intrinsic::usub_sat: 80 case Intrinsic::vector_reduce_add: 81 case Intrinsic::vector_reduce_fadd: 82 return true; 83 case Intrinsic::dx_resource_load_rawbuffer: 84 return resourceAccessNeeds64BitExpansion( 85 F.getParent(), F.getReturnType()->getStructElementType(0), 86 /*IsRaw*/ true); 87 case Intrinsic::dx_resource_load_typedbuffer: 88 return resourceAccessNeeds64BitExpansion( 89 F.getParent(), F.getReturnType()->getStructElementType(0), 90 /*IsRaw*/ false); 91 case Intrinsic::dx_resource_store_rawbuffer: 92 return resourceAccessNeeds64BitExpansion( 93 F.getParent(), F.getFunctionType()->getParamType(3), /*IsRaw*/ true); 94 case Intrinsic::dx_resource_store_typedbuffer: 95 return resourceAccessNeeds64BitExpansion( 96 F.getParent(), F.getFunctionType()->getParamType(2), /*IsRaw*/ false); 97 } 98 return false; 99 } 100 101 static Value *expandUsubSat(CallInst *Orig) { 102 Value *A = Orig->getArgOperand(0); 103 Value *B = 
Orig->getArgOperand(1); 104 Type *Ty = A->getType(); 105 106 IRBuilder<> Builder(Orig); 107 108 Value *Cmp = Builder.CreateICmpULT(A, B, "usub.cmp"); 109 Value *Sub = Builder.CreateSub(A, B, "usub.sub"); 110 Value *Zero = ConstantInt::get(Ty, 0); 111 return Builder.CreateSelect(Cmp, Zero, Sub, "usub.sat"); 112 } 113 114 static Value *expandVecReduceAdd(CallInst *Orig, Intrinsic::ID IntrinsicId) { 115 assert(IntrinsicId == Intrinsic::vector_reduce_add || 116 IntrinsicId == Intrinsic::vector_reduce_fadd); 117 118 IRBuilder<> Builder(Orig); 119 bool IsFAdd = (IntrinsicId == Intrinsic::vector_reduce_fadd); 120 121 Value *X = Orig->getOperand(IsFAdd ? 1 : 0); 122 Type *Ty = X->getType(); 123 auto *XVec = dyn_cast<FixedVectorType>(Ty); 124 unsigned XVecSize = XVec->getNumElements(); 125 Value *Sum = Builder.CreateExtractElement(X, static_cast<uint64_t>(0)); 126 127 // Handle the initial start value for floating-point addition. 128 if (IsFAdd) { 129 Constant *StartValue = dyn_cast<Constant>(Orig->getOperand(0)); 130 if (StartValue && !StartValue->isZeroValue()) 131 Sum = Builder.CreateFAdd(Sum, StartValue); 132 } 133 134 // Accumulate the remaining vector elements. 135 for (unsigned I = 1; I < XVecSize; I++) { 136 Value *Elt = Builder.CreateExtractElement(X, I); 137 if (IsFAdd) 138 Sum = Builder.CreateFAdd(Sum, Elt); 139 else 140 Sum = Builder.CreateAdd(Sum, Elt); 141 } 142 143 return Sum; 144 } 145 146 static Value *expandAbs(CallInst *Orig) { 147 Value *X = Orig->getOperand(0); 148 IRBuilder<> Builder(Orig); 149 Type *Ty = X->getType(); 150 Type *EltTy = Ty->getScalarType(); 151 Constant *Zero = Ty->isVectorTy() 152 ? 
ConstantVector::getSplat( 153 ElementCount::getFixed( 154 cast<FixedVectorType>(Ty)->getNumElements()), 155 ConstantInt::get(EltTy, 0)) 156 : ConstantInt::get(EltTy, 0); 157 auto *V = Builder.CreateSub(Zero, X); 158 return Builder.CreateIntrinsic(Ty, Intrinsic::smax, {X, V}, nullptr, 159 "dx.max"); 160 } 161 162 static Value *expandCrossIntrinsic(CallInst *Orig) { 163 164 VectorType *VT = cast<VectorType>(Orig->getType()); 165 if (cast<FixedVectorType>(VT)->getNumElements() != 3) 166 reportFatalUsageError("return vector must have exactly 3 elements"); 167 168 Value *op0 = Orig->getOperand(0); 169 Value *op1 = Orig->getOperand(1); 170 IRBuilder<> Builder(Orig); 171 172 Value *op0_x = Builder.CreateExtractElement(op0, (uint64_t)0, "x0"); 173 Value *op0_y = Builder.CreateExtractElement(op0, 1, "x1"); 174 Value *op0_z = Builder.CreateExtractElement(op0, 2, "x2"); 175 176 Value *op1_x = Builder.CreateExtractElement(op1, (uint64_t)0, "y0"); 177 Value *op1_y = Builder.CreateExtractElement(op1, 1, "y1"); 178 Value *op1_z = Builder.CreateExtractElement(op1, 2, "y2"); 179 180 auto MulSub = [&](Value *x0, Value *y0, Value *x1, Value *y1) -> Value * { 181 Value *xy = Builder.CreateFMul(x0, y1); 182 Value *yx = Builder.CreateFMul(y0, x1); 183 return Builder.CreateFSub(xy, yx, Orig->getName()); 184 }; 185 186 Value *yz_zy = MulSub(op0_y, op0_z, op1_y, op1_z); 187 Value *zx_xz = MulSub(op0_z, op0_x, op1_z, op1_x); 188 Value *xy_yx = MulSub(op0_x, op0_y, op1_x, op1_y); 189 190 Value *cross = PoisonValue::get(VT); 191 cross = Builder.CreateInsertElement(cross, yz_zy, (uint64_t)0); 192 cross = Builder.CreateInsertElement(cross, zx_xz, 1); 193 cross = Builder.CreateInsertElement(cross, xy_yx, 2); 194 return cross; 195 } 196 197 // Create appropriate DXIL float dot intrinsic for the given A and B operands 198 // The appropriate opcode will be determined by the size of the operands 199 // The dot product is placed in the position indicated by Orig 200 static Value 
*expandFloatDotIntrinsic(CallInst *Orig, Value *A, Value *B) { 201 Type *ATy = A->getType(); 202 [[maybe_unused]] Type *BTy = B->getType(); 203 assert(ATy->isVectorTy() && BTy->isVectorTy()); 204 205 IRBuilder<> Builder(Orig); 206 207 auto *AVec = dyn_cast<FixedVectorType>(ATy); 208 209 assert(ATy->getScalarType()->isFloatingPointTy()); 210 211 Intrinsic::ID DotIntrinsic = Intrinsic::dx_dot4; 212 int NumElts = AVec->getNumElements(); 213 switch (NumElts) { 214 case 2: 215 DotIntrinsic = Intrinsic::dx_dot2; 216 break; 217 case 3: 218 DotIntrinsic = Intrinsic::dx_dot3; 219 break; 220 case 4: 221 DotIntrinsic = Intrinsic::dx_dot4; 222 break; 223 default: 224 reportFatalUsageError( 225 "Invalid dot product input vector: length is outside 2-4"); 226 return nullptr; 227 } 228 229 SmallVector<Value *> Args; 230 for (int I = 0; I < NumElts; ++I) 231 Args.push_back(Builder.CreateExtractElement(A, Builder.getInt32(I))); 232 for (int I = 0; I < NumElts; ++I) 233 Args.push_back(Builder.CreateExtractElement(B, Builder.getInt32(I))); 234 return Builder.CreateIntrinsic(ATy->getScalarType(), DotIntrinsic, Args, 235 nullptr, "dot"); 236 } 237 238 // Create the appropriate DXIL float dot intrinsic for the operands of Orig 239 // The appropriate opcode will be determined by the size of the operands 240 // The dot product is placed in the position indicated by Orig 241 static Value *expandFloatDotIntrinsic(CallInst *Orig) { 242 return expandFloatDotIntrinsic(Orig, Orig->getOperand(0), 243 Orig->getOperand(1)); 244 } 245 246 // Expand integer dot product to multiply and add ops 247 static Value *expandIntegerDotIntrinsic(CallInst *Orig, 248 Intrinsic::ID DotIntrinsic) { 249 assert(DotIntrinsic == Intrinsic::dx_sdot || 250 DotIntrinsic == Intrinsic::dx_udot); 251 Value *A = Orig->getOperand(0); 252 Value *B = Orig->getOperand(1); 253 Type *ATy = A->getType(); 254 [[maybe_unused]] Type *BTy = B->getType(); 255 assert(ATy->isVectorTy() && BTy->isVectorTy()); 256 257 IRBuilder<> 
Builder(Orig); 258 259 auto *AVec = dyn_cast<FixedVectorType>(ATy); 260 261 assert(ATy->getScalarType()->isIntegerTy()); 262 263 Value *Result; 264 Intrinsic::ID MadIntrinsic = DotIntrinsic == Intrinsic::dx_sdot 265 ? Intrinsic::dx_imad 266 : Intrinsic::dx_umad; 267 Value *Elt0 = Builder.CreateExtractElement(A, (uint64_t)0); 268 Value *Elt1 = Builder.CreateExtractElement(B, (uint64_t)0); 269 Result = Builder.CreateMul(Elt0, Elt1); 270 for (unsigned I = 1; I < AVec->getNumElements(); I++) { 271 Elt0 = Builder.CreateExtractElement(A, I); 272 Elt1 = Builder.CreateExtractElement(B, I); 273 Result = Builder.CreateIntrinsic(Result->getType(), MadIntrinsic, 274 ArrayRef<Value *>{Elt0, Elt1, Result}, 275 nullptr, "dx.mad"); 276 } 277 return Result; 278 } 279 280 static Value *expandExpIntrinsic(CallInst *Orig) { 281 Value *X = Orig->getOperand(0); 282 IRBuilder<> Builder(Orig); 283 Type *Ty = X->getType(); 284 Type *EltTy = Ty->getScalarType(); 285 Constant *Log2eConst = 286 Ty->isVectorTy() ? ConstantVector::getSplat( 287 ElementCount::getFixed( 288 cast<FixedVectorType>(Ty)->getNumElements()), 289 ConstantFP::get(EltTy, numbers::log2ef)) 290 : ConstantFP::get(EltTy, numbers::log2ef); 291 Value *NewX = Builder.CreateFMul(Log2eConst, X); 292 auto *Exp2Call = 293 Builder.CreateIntrinsic(Ty, Intrinsic::exp2, {NewX}, nullptr, "dx.exp2"); 294 Exp2Call->setTailCall(Orig->isTailCall()); 295 Exp2Call->setAttributes(Orig->getAttributes()); 296 return Exp2Call; 297 } 298 299 static Value *expandIsFPClass(CallInst *Orig) { 300 Value *T = Orig->getArgOperand(1); 301 auto *TCI = dyn_cast<ConstantInt>(T); 302 303 // These FPClassTest cases have DXIL opcodes, so they will be handled in 304 // DXIL Op Lowering instead. 
305 switch (TCI->getZExtValue()) { 306 case FPClassTest::fcInf: 307 case FPClassTest::fcNan: 308 case FPClassTest::fcNormal: 309 case FPClassTest::fcFinite: 310 return nullptr; 311 } 312 313 IRBuilder<> Builder(Orig); 314 315 Value *F = Orig->getArgOperand(0); 316 Type *FTy = F->getType(); 317 unsigned FNumElem = 0; // 0 => F is not a vector 318 319 unsigned BitWidth; // Bit width of F or the ElemTy of F 320 Type *BitCastTy; // An IntNTy of the same bitwidth as F or ElemTy of F 321 322 if (auto *FVecTy = dyn_cast<FixedVectorType>(FTy)) { 323 Type *ElemTy = FVecTy->getElementType(); 324 FNumElem = FVecTy->getNumElements(); 325 BitWidth = ElemTy->getPrimitiveSizeInBits(); 326 BitCastTy = FixedVectorType::get(Builder.getIntNTy(BitWidth), FNumElem); 327 } else { 328 BitWidth = FTy->getPrimitiveSizeInBits(); 329 BitCastTy = Builder.getIntNTy(BitWidth); 330 } 331 332 Value *FBitCast = Builder.CreateBitCast(F, BitCastTy); 333 switch (TCI->getZExtValue()) { 334 case FPClassTest::fcNegZero: { 335 Value *NegZero = 336 ConstantInt::get(Builder.getIntNTy(BitWidth), 1 << (BitWidth - 1)); 337 Value *RetVal; 338 if (FNumElem) { 339 Value *NegZeroSplat = Builder.CreateVectorSplat(FNumElem, NegZero); 340 RetVal = 341 Builder.CreateICmpEQ(FBitCast, NegZeroSplat, "is.fpclass.negzero"); 342 } else 343 RetVal = Builder.CreateICmpEQ(FBitCast, NegZero, "is.fpclass.negzero"); 344 return RetVal; 345 } 346 default: 347 reportFatalUsageError("Unsupported FPClassTest"); 348 } 349 } 350 351 static Value *expandAnyOrAllIntrinsic(CallInst *Orig, 352 Intrinsic::ID IntrinsicId) { 353 Value *X = Orig->getOperand(0); 354 IRBuilder<> Builder(Orig); 355 Type *Ty = X->getType(); 356 Type *EltTy = Ty->getScalarType(); 357 358 auto ApplyOp = [&Builder](Intrinsic::ID IntrinsicId, Value *Result, 359 Value *Elt) { 360 if (IntrinsicId == Intrinsic::dx_any) 361 return Builder.CreateOr(Result, Elt); 362 assert(IntrinsicId == Intrinsic::dx_all); 363 return Builder.CreateAnd(Result, Elt); 364 }; 365 366 Value 
*Result = nullptr; 367 if (!Ty->isVectorTy()) { 368 Result = EltTy->isFloatingPointTy() 369 ? Builder.CreateFCmpUNE(X, ConstantFP::get(EltTy, 0)) 370 : Builder.CreateICmpNE(X, ConstantInt::get(EltTy, 0)); 371 } else { 372 auto *XVec = dyn_cast<FixedVectorType>(Ty); 373 Value *Cond = 374 EltTy->isFloatingPointTy() 375 ? Builder.CreateFCmpUNE( 376 X, ConstantVector::getSplat( 377 ElementCount::getFixed(XVec->getNumElements()), 378 ConstantFP::get(EltTy, 0))) 379 : Builder.CreateICmpNE( 380 X, ConstantVector::getSplat( 381 ElementCount::getFixed(XVec->getNumElements()), 382 ConstantInt::get(EltTy, 0))); 383 Result = Builder.CreateExtractElement(Cond, (uint64_t)0); 384 for (unsigned I = 1; I < XVec->getNumElements(); I++) { 385 Value *Elt = Builder.CreateExtractElement(Cond, I); 386 Result = ApplyOp(IntrinsicId, Result, Elt); 387 } 388 } 389 return Result; 390 } 391 392 static Value *expandLerpIntrinsic(CallInst *Orig) { 393 Value *X = Orig->getOperand(0); 394 Value *Y = Orig->getOperand(1); 395 Value *S = Orig->getOperand(2); 396 IRBuilder<> Builder(Orig); 397 auto *V = Builder.CreateFSub(Y, X); 398 V = Builder.CreateFMul(S, V); 399 return Builder.CreateFAdd(X, V, "dx.lerp"); 400 } 401 402 static Value *expandLogIntrinsic(CallInst *Orig, 403 float LogConstVal = numbers::ln2f) { 404 Value *X = Orig->getOperand(0); 405 IRBuilder<> Builder(Orig); 406 Type *Ty = X->getType(); 407 Type *EltTy = Ty->getScalarType(); 408 Constant *Ln2Const = 409 Ty->isVectorTy() ? 
ConstantVector::getSplat( 410 ElementCount::getFixed( 411 cast<FixedVectorType>(Ty)->getNumElements()), 412 ConstantFP::get(EltTy, LogConstVal)) 413 : ConstantFP::get(EltTy, LogConstVal); 414 auto *Log2Call = 415 Builder.CreateIntrinsic(Ty, Intrinsic::log2, {X}, nullptr, "elt.log2"); 416 Log2Call->setTailCall(Orig->isTailCall()); 417 Log2Call->setAttributes(Orig->getAttributes()); 418 return Builder.CreateFMul(Ln2Const, Log2Call); 419 } 420 static Value *expandLog10Intrinsic(CallInst *Orig) { 421 return expandLogIntrinsic(Orig, numbers::ln2f / numbers::ln10f); 422 } 423 424 // Use dot product of vector operand with itself to calculate the length. 425 // Divide the vector by that length to normalize it. 426 static Value *expandNormalizeIntrinsic(CallInst *Orig) { 427 Value *X = Orig->getOperand(0); 428 Type *Ty = Orig->getType(); 429 Type *EltTy = Ty->getScalarType(); 430 IRBuilder<> Builder(Orig); 431 432 auto *XVec = dyn_cast<FixedVectorType>(Ty); 433 if (!XVec) { 434 if (auto *constantFP = dyn_cast<ConstantFP>(X)) { 435 const APFloat &fpVal = constantFP->getValueAPF(); 436 if (fpVal.isZero()) 437 reportFatalUsageError("Invalid input scalar: length is zero"); 438 } 439 return Builder.CreateFDiv(X, X); 440 } 441 442 Value *DotProduct = expandFloatDotIntrinsic(Orig, X, X); 443 444 // verify that the length is non-zero 445 // (if the dot product is non-zero, then the length is non-zero) 446 if (auto *constantFP = dyn_cast<ConstantFP>(DotProduct)) { 447 const APFloat &fpVal = constantFP->getValueAPF(); 448 if (fpVal.isZero()) 449 reportFatalUsageError("Invalid input vector: length is zero"); 450 } 451 452 Value *Multiplicand = Builder.CreateIntrinsic(EltTy, Intrinsic::dx_rsqrt, 453 ArrayRef<Value *>{DotProduct}, 454 nullptr, "dx.rsqrt"); 455 456 Value *MultiplicandVec = 457 Builder.CreateVectorSplat(XVec->getNumElements(), Multiplicand); 458 return Builder.CreateFMul(X, MultiplicandVec); 459 } 460 461 static Value *expandAtan2Intrinsic(CallInst *Orig) { 462 Value *Y = 
Orig->getOperand(0); 463 Value *X = Orig->getOperand(1); 464 Type *Ty = X->getType(); 465 IRBuilder<> Builder(Orig); 466 Builder.setFastMathFlags(Orig->getFastMathFlags()); 467 468 Value *Tan = Builder.CreateFDiv(Y, X); 469 470 CallInst *Atan = 471 Builder.CreateIntrinsic(Ty, Intrinsic::atan, {Tan}, nullptr, "Elt.Atan"); 472 Atan->setTailCall(Orig->isTailCall()); 473 Atan->setAttributes(Orig->getAttributes()); 474 475 // Modify atan result based on https://en.wikipedia.org/wiki/Atan2. 476 Constant *Pi = ConstantFP::get(Ty, llvm::numbers::pi); 477 Constant *HalfPi = ConstantFP::get(Ty, llvm::numbers::pi / 2); 478 Constant *NegHalfPi = ConstantFP::get(Ty, -llvm::numbers::pi / 2); 479 Constant *Zero = ConstantFP::get(Ty, 0); 480 Value *AtanAddPi = Builder.CreateFAdd(Atan, Pi); 481 Value *AtanSubPi = Builder.CreateFSub(Atan, Pi); 482 483 // x > 0 -> atan. 484 Value *Result = Atan; 485 Value *XLt0 = Builder.CreateFCmpOLT(X, Zero); 486 Value *XEq0 = Builder.CreateFCmpOEQ(X, Zero); 487 Value *YGe0 = Builder.CreateFCmpOGE(Y, Zero); 488 Value *YLt0 = Builder.CreateFCmpOLT(Y, Zero); 489 490 // x < 0, y >= 0 -> atan + pi. 491 Value *XLt0AndYGe0 = Builder.CreateAnd(XLt0, YGe0); 492 Result = Builder.CreateSelect(XLt0AndYGe0, AtanAddPi, Result); 493 494 // x < 0, y < 0 -> atan - pi. 
495 Value *XLt0AndYLt0 = Builder.CreateAnd(XLt0, YLt0); 496 Result = Builder.CreateSelect(XLt0AndYLt0, AtanSubPi, Result); 497 498 // x == 0, y < 0 -> -pi/2 499 Value *XEq0AndYLt0 = Builder.CreateAnd(XEq0, YLt0); 500 Result = Builder.CreateSelect(XEq0AndYLt0, NegHalfPi, Result); 501 502 // x == 0, y > 0 -> pi/2 503 Value *XEq0AndYGe0 = Builder.CreateAnd(XEq0, YGe0); 504 Result = Builder.CreateSelect(XEq0AndYGe0, HalfPi, Result); 505 506 return Result; 507 } 508 509 static Value *expandPowIntrinsic(CallInst *Orig, Intrinsic::ID IntrinsicId) { 510 511 Value *X = Orig->getOperand(0); 512 Value *Y = Orig->getOperand(1); 513 Type *Ty = X->getType(); 514 IRBuilder<> Builder(Orig); 515 516 if (IntrinsicId == Intrinsic::powi) 517 Y = Builder.CreateSIToFP(Y, Ty); 518 519 auto *Log2Call = 520 Builder.CreateIntrinsic(Ty, Intrinsic::log2, {X}, nullptr, "elt.log2"); 521 auto *Mul = Builder.CreateFMul(Log2Call, Y); 522 auto *Exp2Call = 523 Builder.CreateIntrinsic(Ty, Intrinsic::exp2, {Mul}, nullptr, "elt.exp2"); 524 Exp2Call->setTailCall(Orig->isTailCall()); 525 Exp2Call->setAttributes(Orig->getAttributes()); 526 return Exp2Call; 527 } 528 529 static Value *expandStepIntrinsic(CallInst *Orig) { 530 531 Value *X = Orig->getOperand(0); 532 Value *Y = Orig->getOperand(1); 533 Type *Ty = X->getType(); 534 IRBuilder<> Builder(Orig); 535 536 Constant *One = ConstantFP::get(Ty->getScalarType(), 1.0); 537 Constant *Zero = ConstantFP::get(Ty->getScalarType(), 0.0); 538 Value *Cond = Builder.CreateFCmpOLT(Y, X); 539 540 if (Ty != Ty->getScalarType()) { 541 auto *XVec = dyn_cast<FixedVectorType>(Ty); 542 One = ConstantVector::getSplat( 543 ElementCount::getFixed(XVec->getNumElements()), One); 544 Zero = ConstantVector::getSplat( 545 ElementCount::getFixed(XVec->getNumElements()), Zero); 546 } 547 548 return Builder.CreateSelect(Cond, Zero, One); 549 } 550 551 static Value *expandRadiansIntrinsic(CallInst *Orig) { 552 Value *X = Orig->getOperand(0); 553 Type *Ty = X->getType(); 554 
IRBuilder<> Builder(Orig); 555 Value *PiOver180 = ConstantFP::get(Ty, llvm::numbers::pi / 180.0); 556 return Builder.CreateFMul(X, PiOver180); 557 } 558 559 static bool expandBufferLoadIntrinsic(CallInst *Orig, bool IsRaw) { 560 IRBuilder<> Builder(Orig); 561 562 Type *BufferTy = Orig->getType()->getStructElementType(0); 563 Type *ScalarTy = BufferTy->getScalarType(); 564 bool IsDouble = ScalarTy->isDoubleTy(); 565 assert(IsDouble || ScalarTy->isIntegerTy(64) && 566 "Only expand double or int64 scalars or vectors"); 567 bool IsVector = false; 568 unsigned ExtractNum = 2; 569 if (auto *VT = dyn_cast<FixedVectorType>(BufferTy)) { 570 ExtractNum = 2 * VT->getNumElements(); 571 IsVector = true; 572 assert(IsRaw || ExtractNum == 4 && "TypedBufferLoad vector must be size 2"); 573 } 574 575 SmallVector<Value *, 2> Loads; 576 Value *Result = PoisonValue::get(BufferTy); 577 unsigned Base = 0; 578 // If we need to extract more than 4 i32; we need to break it up into 579 // more than one load. LoadNum tells us how many i32s we are loading in 580 // each load 581 while (ExtractNum > 0) { 582 unsigned LoadNum = std::min(ExtractNum, 4u); 583 Type *Ty = VectorType::get(Builder.getInt32Ty(), LoadNum, false); 584 585 Type *LoadType = StructType::get(Ty, Builder.getInt1Ty()); 586 Intrinsic::ID LoadIntrinsic = Intrinsic::dx_resource_load_typedbuffer; 587 SmallVector<Value *, 3> Args = {Orig->getOperand(0), Orig->getOperand(1)}; 588 if (IsRaw) { 589 LoadIntrinsic = Intrinsic::dx_resource_load_rawbuffer; 590 Value *Tmp = Builder.getInt32(4 * Base * 2); 591 Args.push_back(Builder.CreateAdd(Orig->getOperand(2), Tmp)); 592 } 593 594 CallInst *Load = Builder.CreateIntrinsic(LoadType, LoadIntrinsic, Args); 595 Loads.push_back(Load); 596 597 // extract the buffer load's result 598 Value *Extract = Builder.CreateExtractValue(Load, {0}); 599 600 SmallVector<Value *> ExtractElements; 601 for (unsigned I = 0; I < LoadNum; ++I) 602 ExtractElements.push_back( 603 
Builder.CreateExtractElement(Extract, Builder.getInt32(I))); 604 605 // combine into double(s) or int64(s) 606 for (unsigned I = 0; I < LoadNum; I += 2) { 607 Value *Combined = nullptr; 608 if (IsDouble) 609 // For doubles, use dx_asdouble intrinsic 610 Combined = Builder.CreateIntrinsic( 611 Builder.getDoubleTy(), Intrinsic::dx_asdouble, 612 {ExtractElements[I], ExtractElements[I + 1]}); 613 else { 614 // For int64, manually combine two int32s 615 // First, zero-extend both values to i64 616 Value *Lo = 617 Builder.CreateZExt(ExtractElements[I], Builder.getInt64Ty()); 618 Value *Hi = 619 Builder.CreateZExt(ExtractElements[I + 1], Builder.getInt64Ty()); 620 // Shift the high bits left by 32 bits 621 Value *ShiftedHi = Builder.CreateShl(Hi, Builder.getInt64(32)); 622 // OR the high and low bits together 623 Combined = Builder.CreateOr(Lo, ShiftedHi); 624 } 625 626 if (IsVector) 627 Result = Builder.CreateInsertElement(Result, Combined, 628 Builder.getInt32((I / 2) + Base)); 629 else 630 Result = Combined; 631 } 632 633 ExtractNum -= LoadNum; 634 Base += LoadNum / 2; 635 } 636 637 Value *CheckBit = nullptr; 638 for (User *U : make_early_inc_range(Orig->users())) { 639 // If it's not a ExtractValueInst, we don't know how to 640 // handle it 641 auto *EVI = dyn_cast<ExtractValueInst>(U); 642 if (!EVI) 643 llvm_unreachable("Unexpected user of typedbufferload"); 644 645 ArrayRef<unsigned> Indices = EVI->getIndices(); 646 assert(Indices.size() == 1); 647 648 if (Indices[0] == 0) { 649 // Use of the value(s) 650 EVI->replaceAllUsesWith(Result); 651 } else { 652 // Use of the check bit 653 assert(Indices[0] == 1 && "Unexpected type for typedbufferload"); 654 // Note: This does not always match the historical behaviour of DXC. 
655 // See https://github.com/microsoft/DirectXShaderCompiler/issues/7622 656 if (!CheckBit) { 657 SmallVector<Value *, 2> CheckBits; 658 for (Value *L : Loads) 659 CheckBits.push_back(Builder.CreateExtractValue(L, {1})); 660 CheckBit = Builder.CreateAnd(CheckBits); 661 } 662 EVI->replaceAllUsesWith(CheckBit); 663 } 664 EVI->eraseFromParent(); 665 } 666 Orig->eraseFromParent(); 667 return true; 668 } 669 670 static bool expandBufferStoreIntrinsic(CallInst *Orig, bool IsRaw) { 671 IRBuilder<> Builder(Orig); 672 673 unsigned ValIndex = IsRaw ? 3 : 2; 674 Type *BufferTy = Orig->getFunctionType()->getParamType(ValIndex); 675 Type *ScalarTy = BufferTy->getScalarType(); 676 bool IsDouble = ScalarTy->isDoubleTy(); 677 assert((IsDouble || ScalarTy->isIntegerTy(64)) && 678 "Only expand double or int64 scalars or vectors"); 679 680 // Determine if we're dealing with a vector or scalar 681 bool IsVector = false; 682 unsigned ExtractNum = 2; 683 unsigned VecLen = 0; 684 if (auto *VT = dyn_cast<FixedVectorType>(BufferTy)) { 685 VecLen = VT->getNumElements(); 686 assert(IsRaw || VecLen == 2 && "TypedBufferStore vector must be size 2"); 687 ExtractNum = VecLen * 2; 688 IsVector = true; 689 } 690 691 // Create the appropriate vector type for the result 692 Type *Int32Ty = Builder.getInt32Ty(); 693 Type *ResultTy = VectorType::get(Int32Ty, ExtractNum, false); 694 Value *Val = PoisonValue::get(ResultTy); 695 696 Type *SplitElementTy = Int32Ty; 697 if (IsVector) 698 SplitElementTy = VectorType::get(SplitElementTy, VecLen, false); 699 700 Value *LowBits = nullptr; 701 Value *HighBits = nullptr; 702 // Split the 64-bit values into 32-bit components 703 if (IsDouble) { 704 auto *SplitTy = llvm::StructType::get(SplitElementTy, SplitElementTy); 705 Value *Split = Builder.CreateIntrinsic(SplitTy, Intrinsic::dx_splitdouble, 706 {Orig->getOperand(ValIndex)}); 707 LowBits = Builder.CreateExtractValue(Split, 0); 708 HighBits = Builder.CreateExtractValue(Split, 1); 709 } else { 710 // Handle 
int64 type(s) 711 Value *InputVal = Orig->getOperand(ValIndex); 712 Constant *ShiftAmt = Builder.getInt64(32); 713 if (IsVector) 714 ShiftAmt = 715 ConstantVector::getSplat(ElementCount::getFixed(VecLen), ShiftAmt); 716 717 // Split into low and high 32-bit parts 718 LowBits = Builder.CreateTrunc(InputVal, SplitElementTy); 719 Value *ShiftedVal = Builder.CreateLShr(InputVal, ShiftAmt); 720 HighBits = Builder.CreateTrunc(ShiftedVal, SplitElementTy); 721 } 722 723 if (IsVector) { 724 SmallVector<int, 8> Mask; 725 for (unsigned I = 0; I < VecLen; ++I) { 726 Mask.push_back(I); 727 Mask.push_back(I + VecLen); 728 } 729 Val = Builder.CreateShuffleVector(LowBits, HighBits, Mask); 730 } else { 731 Val = Builder.CreateInsertElement(Val, LowBits, Builder.getInt32(0)); 732 Val = Builder.CreateInsertElement(Val, HighBits, Builder.getInt32(1)); 733 } 734 735 // If we need to extract more than 4 i32; we need to break it up into 736 // more than one store. StoreNum tells us how many i32s we are storing in 737 // each store 738 unsigned Base = 0; 739 while (ExtractNum > 0) { 740 unsigned StoreNum = std::min(ExtractNum, 4u); 741 742 Intrinsic::ID StoreIntrinsic = Intrinsic::dx_resource_store_typedbuffer; 743 SmallVector<Value *, 4> Args = {Orig->getOperand(0), Orig->getOperand(1)}; 744 if (IsRaw) { 745 StoreIntrinsic = Intrinsic::dx_resource_store_rawbuffer; 746 Value *Tmp = Builder.getInt32(4 * Base); 747 Args.push_back(Builder.CreateAdd(Orig->getOperand(2), Tmp)); 748 } 749 750 SmallVector<int, 4> Mask; 751 for (unsigned I = 0; I < StoreNum; ++I) { 752 Mask.push_back(Base + I); 753 } 754 755 Value *SubVal = Val; 756 if (VecLen > 2) 757 SubVal = Builder.CreateShuffleVector(Val, Mask); 758 759 Args.push_back(SubVal); 760 // Create the final intrinsic call 761 Builder.CreateIntrinsic(Builder.getVoidTy(), StoreIntrinsic, Args); 762 763 ExtractNum -= StoreNum; 764 Base += StoreNum; 765 } 766 Orig->eraseFromParent(); 767 return true; 768 } 769 770 static Intrinsic::ID 
getMaxForClamp(Intrinsic::ID ClampIntrinsic) { 771 if (ClampIntrinsic == Intrinsic::dx_uclamp) 772 return Intrinsic::umax; 773 if (ClampIntrinsic == Intrinsic::dx_sclamp) 774 return Intrinsic::smax; 775 assert(ClampIntrinsic == Intrinsic::dx_nclamp); 776 return Intrinsic::maxnum; 777 } 778 779 static Intrinsic::ID getMinForClamp(Intrinsic::ID ClampIntrinsic) { 780 if (ClampIntrinsic == Intrinsic::dx_uclamp) 781 return Intrinsic::umin; 782 if (ClampIntrinsic == Intrinsic::dx_sclamp) 783 return Intrinsic::smin; 784 assert(ClampIntrinsic == Intrinsic::dx_nclamp); 785 return Intrinsic::minnum; 786 } 787 788 static Value *expandClampIntrinsic(CallInst *Orig, 789 Intrinsic::ID ClampIntrinsic) { 790 Value *X = Orig->getOperand(0); 791 Value *Min = Orig->getOperand(1); 792 Value *Max = Orig->getOperand(2); 793 Type *Ty = X->getType(); 794 IRBuilder<> Builder(Orig); 795 auto *MaxCall = Builder.CreateIntrinsic(Ty, getMaxForClamp(ClampIntrinsic), 796 {X, Min}, nullptr, "dx.max"); 797 return Builder.CreateIntrinsic(Ty, getMinForClamp(ClampIntrinsic), 798 {MaxCall, Max}, nullptr, "dx.min"); 799 } 800 801 static Value *expandDegreesIntrinsic(CallInst *Orig) { 802 Value *X = Orig->getOperand(0); 803 Type *Ty = X->getType(); 804 IRBuilder<> Builder(Orig); 805 Value *DegreesRatio = ConstantFP::get(Ty, 180.0 * llvm::numbers::inv_pi); 806 return Builder.CreateFMul(X, DegreesRatio); 807 } 808 809 static Value *expandSignIntrinsic(CallInst *Orig) { 810 Value *X = Orig->getOperand(0); 811 Type *Ty = X->getType(); 812 Type *ScalarTy = Ty->getScalarType(); 813 Type *RetTy = Orig->getType(); 814 Constant *Zero = Constant::getNullValue(Ty); 815 816 IRBuilder<> Builder(Orig); 817 818 Value *GT; 819 Value *LT; 820 if (ScalarTy->isFloatingPointTy()) { 821 GT = Builder.CreateFCmpOLT(Zero, X); 822 LT = Builder.CreateFCmpOLT(X, Zero); 823 } else { 824 assert(ScalarTy->isIntegerTy()); 825 GT = Builder.CreateICmpSLT(Zero, X); 826 LT = Builder.CreateICmpSLT(X, Zero); 827 } 828 829 Value *ZextGT = 
Builder.CreateZExt(GT, RetTy); 830 Value *ZextLT = Builder.CreateZExt(LT, RetTy); 831 832 return Builder.CreateSub(ZextGT, ZextLT); 833 } 834 835 static bool expandIntrinsic(Function &F, CallInst *Orig) { 836 Value *Result = nullptr; 837 Intrinsic::ID IntrinsicId = F.getIntrinsicID(); 838 switch (IntrinsicId) { 839 case Intrinsic::abs: 840 Result = expandAbs(Orig); 841 break; 842 case Intrinsic::atan2: 843 Result = expandAtan2Intrinsic(Orig); 844 break; 845 case Intrinsic::exp: 846 Result = expandExpIntrinsic(Orig); 847 break; 848 case Intrinsic::is_fpclass: 849 Result = expandIsFPClass(Orig); 850 break; 851 case Intrinsic::log: 852 Result = expandLogIntrinsic(Orig); 853 break; 854 case Intrinsic::log10: 855 Result = expandLog10Intrinsic(Orig); 856 break; 857 case Intrinsic::pow: 858 case Intrinsic::powi: 859 Result = expandPowIntrinsic(Orig, IntrinsicId); 860 break; 861 case Intrinsic::dx_all: 862 case Intrinsic::dx_any: 863 Result = expandAnyOrAllIntrinsic(Orig, IntrinsicId); 864 break; 865 case Intrinsic::dx_cross: 866 Result = expandCrossIntrinsic(Orig); 867 break; 868 case Intrinsic::dx_uclamp: 869 case Intrinsic::dx_sclamp: 870 case Intrinsic::dx_nclamp: 871 Result = expandClampIntrinsic(Orig, IntrinsicId); 872 break; 873 case Intrinsic::dx_degrees: 874 Result = expandDegreesIntrinsic(Orig); 875 break; 876 case Intrinsic::dx_lerp: 877 Result = expandLerpIntrinsic(Orig); 878 break; 879 case Intrinsic::dx_normalize: 880 Result = expandNormalizeIntrinsic(Orig); 881 break; 882 case Intrinsic::dx_fdot: 883 Result = expandFloatDotIntrinsic(Orig); 884 break; 885 case Intrinsic::dx_sdot: 886 case Intrinsic::dx_udot: 887 Result = expandIntegerDotIntrinsic(Orig, IntrinsicId); 888 break; 889 case Intrinsic::dx_sign: 890 Result = expandSignIntrinsic(Orig); 891 break; 892 case Intrinsic::dx_step: 893 Result = expandStepIntrinsic(Orig); 894 break; 895 case Intrinsic::dx_radians: 896 Result = expandRadiansIntrinsic(Orig); 897 break; 898 case 
Intrinsic::dx_resource_load_rawbuffer: 899 if (expandBufferLoadIntrinsic(Orig, /*IsRaw*/ true)) 900 return true; 901 break; 902 case Intrinsic::dx_resource_store_rawbuffer: 903 if (expandBufferStoreIntrinsic(Orig, /*IsRaw*/ true)) 904 return true; 905 break; 906 case Intrinsic::dx_resource_load_typedbuffer: 907 if (expandBufferLoadIntrinsic(Orig, /*IsRaw*/ false)) 908 return true; 909 break; 910 case Intrinsic::dx_resource_store_typedbuffer: 911 if (expandBufferStoreIntrinsic(Orig, /*IsRaw*/ false)) 912 return true; 913 break; 914 case Intrinsic::usub_sat: 915 Result = expandUsubSat(Orig); 916 break; 917 case Intrinsic::vector_reduce_add: 918 case Intrinsic::vector_reduce_fadd: 919 Result = expandVecReduceAdd(Orig, IntrinsicId); 920 break; 921 } 922 if (Result) { 923 Orig->replaceAllUsesWith(Result); 924 Orig->eraseFromParent(); 925 return true; 926 } 927 return false; 928 } 929 930 static bool expansionIntrinsics(Module &M) { 931 for (auto &F : make_early_inc_range(M.functions())) { 932 if (!isIntrinsicExpansion(F)) 933 continue; 934 bool IntrinsicExpanded = false; 935 for (User *U : make_early_inc_range(F.users())) { 936 auto *IntrinsicCall = dyn_cast<CallInst>(U); 937 if (!IntrinsicCall) 938 continue; 939 IntrinsicExpanded = expandIntrinsic(F, IntrinsicCall); 940 } 941 if (F.user_empty() && IntrinsicExpanded) 942 F.eraseFromParent(); 943 } 944 return true; 945 } 946 947 PreservedAnalyses DXILIntrinsicExpansion::run(Module &M, 948 ModuleAnalysisManager &) { 949 if (expansionIntrinsics(M)) 950 return PreservedAnalyses::none(); 951 return PreservedAnalyses::all(); 952 } 953 954 bool DXILIntrinsicExpansionLegacy::runOnModule(Module &M) { 955 return expansionIntrinsics(M); 956 } 957 958 char DXILIntrinsicExpansionLegacy::ID = 0; 959 960 INITIALIZE_PASS_BEGIN(DXILIntrinsicExpansionLegacy, DEBUG_TYPE, 961 "DXIL Intrinsic Expansion", false, false) 962 INITIALIZE_PASS_END(DXILIntrinsicExpansionLegacy, DEBUG_TYPE, 963 "DXIL Intrinsic Expansion", false, false) 964 965 
// Factory for the legacy pass manager: returns a newly allocated
// DXILIntrinsicExpansionLegacy pass.
ModulePass *llvm::createDXILIntrinsicExpansionLegacyPass() {
  auto *LegacyPass = new DXILIntrinsicExpansionLegacy();
  return LegacyPass;
}