1 //===- DXILIntrinsicExpansion.cpp - Prepare LLVM Module for DXIL encoding--===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 /// 9 /// \file This file contains DXIL intrinsic expansions for those that don't have 10 // opcodes in DirectX Intermediate Language (DXIL). 11 //===----------------------------------------------------------------------===// 12 13 #include "DXILIntrinsicExpansion.h" 14 #include "DirectX.h" 15 #include "llvm/ADT/STLExtras.h" 16 #include "llvm/ADT/SmallVector.h" 17 #include "llvm/CodeGen/Passes.h" 18 #include "llvm/IR/IRBuilder.h" 19 #include "llvm/IR/Instruction.h" 20 #include "llvm/IR/Instructions.h" 21 #include "llvm/IR/Intrinsics.h" 22 #include "llvm/IR/IntrinsicsDirectX.h" 23 #include "llvm/IR/Module.h" 24 #include "llvm/IR/PassManager.h" 25 #include "llvm/IR/Type.h" 26 #include "llvm/Pass.h" 27 #include "llvm/Support/ErrorHandling.h" 28 #include "llvm/Support/MathExtras.h" 29 30 #define DEBUG_TYPE "dxil-intrinsic-expansion" 31 32 using namespace llvm; 33 34 static bool isIntrinsicExpansion(Function &F) { 35 switch (F.getIntrinsicID()) { 36 case Intrinsic::abs: 37 case Intrinsic::exp: 38 case Intrinsic::log: 39 case Intrinsic::log10: 40 case Intrinsic::pow: 41 case Intrinsic::dx_any: 42 case Intrinsic::dx_clamp: 43 case Intrinsic::dx_uclamp: 44 case Intrinsic::dx_lerp: 45 case Intrinsic::dx_sdot: 46 case Intrinsic::dx_udot: 47 return true; 48 } 49 return false; 50 } 51 52 static bool expandAbs(CallInst *Orig) { 53 Value *X = Orig->getOperand(0); 54 IRBuilder<> Builder(Orig->getParent()); 55 Builder.SetInsertPoint(Orig); 56 Type *Ty = X->getType(); 57 Type *EltTy = Ty->getScalarType(); 58 Constant *Zero = Ty->isVectorTy() 59 ? ConstantVector::getSplat( 60 ElementCount::getFixed( 61 cast<FixedVectorType>(Ty)->getNumElements()), 62 ConstantInt::get(EltTy, 0)) 63 : ConstantInt::get(EltTy, 0); 64 auto *V = Builder.CreateSub(Zero, X); 65 auto *MaxCall = 66 Builder.CreateIntrinsic(Ty, Intrinsic::smax, {X, V}, nullptr, "dx.max"); 67 Orig->replaceAllUsesWith(MaxCall); 68 Orig->eraseFromParent(); 69 return true; 70 } 71 72 static bool expandIntegerDot(CallInst *Orig, Intrinsic::ID DotIntrinsic) { 73 assert(DotIntrinsic == Intrinsic::dx_sdot || 74 DotIntrinsic == Intrinsic::dx_udot); 75 Intrinsic::ID MadIntrinsic = DotIntrinsic == Intrinsic::dx_sdot 76 ? Intrinsic::dx_imad 77 : Intrinsic::dx_umad; 78 Value *A = Orig->getOperand(0); 79 Value *B = Orig->getOperand(1); 80 [[maybe_unused]] Type *ATy = A->getType(); 81 [[maybe_unused]] Type *BTy = B->getType(); 82 assert(ATy->isVectorTy() && BTy->isVectorTy()); 83 84 IRBuilder<> Builder(Orig->getParent()); 85 Builder.SetInsertPoint(Orig); 86 87 auto *AVec = dyn_cast<FixedVectorType>(A->getType()); 88 Value *Elt0 = Builder.CreateExtractElement(A, (uint64_t)0); 89 Value *Elt1 = Builder.CreateExtractElement(B, (uint64_t)0); 90 Value *Result = Builder.CreateMul(Elt0, Elt1); 91 for (unsigned I = 1; I < AVec->getNumElements(); I++) { 92 Elt0 = Builder.CreateExtractElement(A, I); 93 Elt1 = Builder.CreateExtractElement(B, I); 94 Result = Builder.CreateIntrinsic(Result->getType(), MadIntrinsic, 95 ArrayRef<Value *>{Elt0, Elt1, Result}, 96 nullptr, "dx.mad"); 97 } 98 Orig->replaceAllUsesWith(Result); 99 Orig->eraseFromParent(); 100 return true; 101 } 102 103 static bool expandExpIntrinsic(CallInst *Orig) { 104 Value *X = Orig->getOperand(0); 105 IRBuilder<> Builder(Orig->getParent()); 106 Builder.SetInsertPoint(Orig); 107 Type *Ty = X->getType(); 108 Type *EltTy = Ty->getScalarType(); 109 Constant *Log2eConst = 110 Ty->isVectorTy() ? ConstantVector::getSplat( 111 ElementCount::getFixed( 112 cast<FixedVectorType>(Ty)->getNumElements()), 113 ConstantFP::get(EltTy, numbers::log2ef)) 114 : ConstantFP::get(EltTy, numbers::log2ef); 115 Value *NewX = Builder.CreateFMul(Log2eConst, X); 116 auto *Exp2Call = 117 Builder.CreateIntrinsic(Ty, Intrinsic::exp2, {NewX}, nullptr, "dx.exp2"); 118 Exp2Call->setTailCall(Orig->isTailCall()); 119 Exp2Call->setAttributes(Orig->getAttributes()); 120 Orig->replaceAllUsesWith(Exp2Call); 121 Orig->eraseFromParent(); 122 return true; 123 } 124 125 static bool expandAnyIntrinsic(CallInst *Orig) { 126 Value *X = Orig->getOperand(0); 127 IRBuilder<> Builder(Orig->getParent()); 128 Builder.SetInsertPoint(Orig); 129 Type *Ty = X->getType(); 130 Type *EltTy = Ty->getScalarType(); 131 132 if (!Ty->isVectorTy()) { 133 Value *Cond = EltTy->isFloatingPointTy() 134 ? Builder.CreateFCmpUNE(X, ConstantFP::get(EltTy, 0)) 135 : Builder.CreateICmpNE(X, ConstantInt::get(EltTy, 0)); 136 Orig->replaceAllUsesWith(Cond); 137 } else { 138 auto *XVec = dyn_cast<FixedVectorType>(Ty); 139 Value *Cond = 140 EltTy->isFloatingPointTy() 141 ? Builder.CreateFCmpUNE( 142 X, ConstantVector::getSplat( 143 ElementCount::getFixed(XVec->getNumElements()), 144 ConstantFP::get(EltTy, 0))) 145 : Builder.CreateICmpNE( 146 X, ConstantVector::getSplat( 147 ElementCount::getFixed(XVec->getNumElements()), 148 ConstantInt::get(EltTy, 0))); 149 Value *Result = Builder.CreateExtractElement(Cond, (uint64_t)0); 150 for (unsigned I = 1; I < XVec->getNumElements(); I++) { 151 Value *Elt = Builder.CreateExtractElement(Cond, I); 152 Result = Builder.CreateOr(Result, Elt); 153 } 154 Orig->replaceAllUsesWith(Result); 155 } 156 Orig->eraseFromParent(); 157 return true; 158 } 159 160 static bool expandLerpIntrinsic(CallInst *Orig) { 161 Value *X = Orig->getOperand(0); 162 Value *Y = Orig->getOperand(1); 163 Value *S = Orig->getOperand(2); 164 IRBuilder<> Builder(Orig->getParent()); 165 Builder.SetInsertPoint(Orig); 166 auto *V = Builder.CreateFSub(Y, X); 167 V = Builder.CreateFMul(S, V); 168 auto *Result = Builder.CreateFAdd(X, V, "dx.lerp"); 169 Orig->replaceAllUsesWith(Result); 170 Orig->eraseFromParent(); 171 return true; 172 } 173 174 static bool expandLogIntrinsic(CallInst *Orig, 175 float LogConstVal = numbers::ln2f) { 176 Value *X = Orig->getOperand(0); 177 IRBuilder<> Builder(Orig->getParent()); 178 Builder.SetInsertPoint(Orig); 179 Type *Ty = X->getType(); 180 Type *EltTy = Ty->getScalarType(); 181 Constant *Ln2Const = 182 Ty->isVectorTy() ? ConstantVector::getSplat( 183 ElementCount::getFixed( 184 cast<FixedVectorType>(Ty)->getNumElements()), 185 ConstantFP::get(EltTy, LogConstVal)) 186 : ConstantFP::get(EltTy, LogConstVal); 187 auto *Log2Call = 188 Builder.CreateIntrinsic(Ty, Intrinsic::log2, {X}, nullptr, "elt.log2"); 189 Log2Call->setTailCall(Orig->isTailCall()); 190 Log2Call->setAttributes(Orig->getAttributes()); 191 auto *Result = Builder.CreateFMul(Ln2Const, Log2Call); 192 Orig->replaceAllUsesWith(Result); 193 Orig->eraseFromParent(); 194 return true; 195 } 196 static bool expandLog10Intrinsic(CallInst *Orig) { 197 return expandLogIntrinsic(Orig, numbers::ln2f / numbers::ln10f); 198 } 199 200 static bool expandPowIntrinsic(CallInst *Orig) { 201 202 Value *X = Orig->getOperand(0); 203 Value *Y = Orig->getOperand(1); 204 Type *Ty = X->getType(); 205 IRBuilder<> Builder(Orig->getParent()); 206 Builder.SetInsertPoint(Orig); 207 208 auto *Log2Call = 209 Builder.CreateIntrinsic(Ty, Intrinsic::log2, {X}, nullptr, "elt.log2"); 210 auto *Mul = Builder.CreateFMul(Log2Call, Y); 211 auto *Exp2Call = 212 Builder.CreateIntrinsic(Ty, Intrinsic::exp2, {Mul}, nullptr, "elt.exp2"); 213 Exp2Call->setTailCall(Orig->isTailCall()); 214 Exp2Call->setAttributes(Orig->getAttributes()); 215 Orig->replaceAllUsesWith(Exp2Call); 216 Orig->eraseFromParent(); 217 return true; 218 } 219 220 static Intrinsic::ID getMaxForClamp(Type *ElemTy, 221 Intrinsic::ID ClampIntrinsic) { 222 if (ClampIntrinsic == Intrinsic::dx_uclamp) 223 return Intrinsic::umax; 224 assert(ClampIntrinsic == Intrinsic::dx_clamp); 225 if (ElemTy->isVectorTy()) 226 ElemTy = ElemTy->getScalarType(); 227 if (ElemTy->isIntegerTy()) 228 return Intrinsic::smax; 229 assert(ElemTy->isFloatingPointTy()); 230 return Intrinsic::maxnum; 231 } 232 233 static Intrinsic::ID getMinForClamp(Type *ElemTy, 234 Intrinsic::ID ClampIntrinsic) { 235 if (ClampIntrinsic == Intrinsic::dx_uclamp) 236 return Intrinsic::umin; 237 assert(ClampIntrinsic == Intrinsic::dx_clamp); 238 if (ElemTy->isVectorTy()) 239 ElemTy = ElemTy->getScalarType(); 240 if (ElemTy->isIntegerTy()) 241 return Intrinsic::smin; 242 assert(ElemTy->isFloatingPointTy()); 243 return Intrinsic::minnum; 244 } 245 246 static bool expandClampIntrinsic(CallInst *Orig, Intrinsic::ID ClampIntrinsic) { 247 Value *X = Orig->getOperand(0); 248 Value *Min = Orig->getOperand(1); 249 Value *Max = Orig->getOperand(2); 250 Type *Ty = X->getType(); 251 IRBuilder<> Builder(Orig->getParent()); 252 Builder.SetInsertPoint(Orig); 253 auto *MaxCall = Builder.CreateIntrinsic( 254 Ty, getMaxForClamp(Ty, ClampIntrinsic), {X, Min}, nullptr, "dx.max"); 255 auto *MinCall = 256 Builder.CreateIntrinsic(Ty, getMinForClamp(Ty, ClampIntrinsic), 257 {MaxCall, Max}, nullptr, "dx.min"); 258 259 Orig->replaceAllUsesWith(MinCall); 260 Orig->eraseFromParent(); 261 return true; 262 } 263 264 static bool expandIntrinsic(Function &F, CallInst *Orig) { 265 switch (F.getIntrinsicID()) { 266 case Intrinsic::abs: 267 return expandAbs(Orig); 268 case Intrinsic::exp: 269 return expandExpIntrinsic(Orig); 270 case Intrinsic::log: 271 return expandLogIntrinsic(Orig); 272 case Intrinsic::log10: 273 return expandLog10Intrinsic(Orig); 274 case Intrinsic::pow: 275 return expandPowIntrinsic(Orig); 276 case Intrinsic::dx_any: 277 return expandAnyIntrinsic(Orig); 278 case Intrinsic::dx_uclamp: 279 case Intrinsic::dx_clamp: 280 return expandClampIntrinsic(Orig, F.getIntrinsicID()); 281 case Intrinsic::dx_lerp: 282 return expandLerpIntrinsic(Orig); 283 case Intrinsic::dx_sdot: 284 case Intrinsic::dx_udot: 285 return expandIntegerDot(Orig, F.getIntrinsicID()); 286 } 287 return false; 288 } 289 290 static bool expansionIntrinsics(Module &M) { 291 for (auto &F : make_early_inc_range(M.functions())) { 292 if (!isIntrinsicExpansion(F)) 293 continue; 294 bool IntrinsicExpanded = false; 295 for (User *U : make_early_inc_range(F.users())) { 296 auto *IntrinsicCall = dyn_cast<CallInst>(U); 297 if (!IntrinsicCall) 298 continue; 299 IntrinsicExpanded = expandIntrinsic(F, IntrinsicCall); 300 } 301 if (F.user_empty() && IntrinsicExpanded) 302 F.eraseFromParent(); 303 } 304 return true; 305 } 306 307 PreservedAnalyses DXILIntrinsicExpansion::run(Module &M, 308 ModuleAnalysisManager &) { 309 if (expansionIntrinsics(M)) 310 return PreservedAnalyses::none(); 311 return PreservedAnalyses::all(); 312 } 313 314 bool DXILIntrinsicExpansionLegacy::runOnModule(Module &M) { 315 return expansionIntrinsics(M); 316 } 317 318 char DXILIntrinsicExpansionLegacy::ID = 0; 319 320 INITIALIZE_PASS_BEGIN(DXILIntrinsicExpansionLegacy, DEBUG_TYPE, 321 "DXIL Intrinsic Expansion", false, false) 322 INITIALIZE_PASS_END(DXILIntrinsicExpansionLegacy, DEBUG_TYPE, 323 "DXIL Intrinsic Expansion", false, false) 324 325 ModulePass *llvm::createDXILIntrinsicExpansionLegacyPass() { 326 return new DXILIntrinsicExpansionLegacy(); 327 } 328