1 //===------- CGHLSLBuiltins.cpp - Emit LLVM Code for HLSL builtins --------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This contains code to emit HLSL Builtin calls as LLVM code. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "CGBuiltin.h" 14 #include "CGHLSLRuntime.h" 15 #include "CodeGenFunction.h" 16 17 using namespace clang; 18 using namespace CodeGen; 19 using namespace llvm; 20 21 static Value *handleAsDoubleBuiltin(CodeGenFunction &CGF, const CallExpr *E) { 22 assert((E->getArg(0)->getType()->hasUnsignedIntegerRepresentation() && 23 E->getArg(1)->getType()->hasUnsignedIntegerRepresentation()) && 24 "asdouble operands types mismatch"); 25 Value *OpLowBits = CGF.EmitScalarExpr(E->getArg(0)); 26 Value *OpHighBits = CGF.EmitScalarExpr(E->getArg(1)); 27 28 llvm::Type *ResultType = CGF.DoubleTy; 29 int N = 1; 30 if (auto *VTy = E->getArg(0)->getType()->getAs<clang::VectorType>()) { 31 N = VTy->getNumElements(); 32 ResultType = llvm::FixedVectorType::get(CGF.DoubleTy, N); 33 } 34 35 if (CGF.CGM.getTarget().getTriple().isDXIL()) 36 return CGF.Builder.CreateIntrinsic( 37 /*ReturnType=*/ResultType, Intrinsic::dx_asdouble, 38 {OpLowBits, OpHighBits}, nullptr, "hlsl.asdouble"); 39 40 if (!E->getArg(0)->getType()->isVectorType()) { 41 OpLowBits = CGF.Builder.CreateVectorSplat(1, OpLowBits); 42 OpHighBits = CGF.Builder.CreateVectorSplat(1, OpHighBits); 43 } 44 45 llvm::SmallVector<int> Mask; 46 for (int i = 0; i < N; i++) { 47 Mask.push_back(i); 48 Mask.push_back(i + N); 49 } 50 51 Value *BitVec = CGF.Builder.CreateShuffleVector(OpLowBits, OpHighBits, Mask); 52 53 return CGF.Builder.CreateBitCast(BitVec, ResultType); 54 } 55 56 static Value *handleHlslClip(const CallExpr *E, CodeGenFunction *CGF) { 57 Value *Op0 = CGF->EmitScalarExpr(E->getArg(0)); 58 59 Constant *FZeroConst = ConstantFP::getZero(CGF->FloatTy); 60 Value *CMP; 61 Value *LastInstr; 62 63 if (const auto *VecTy = E->getArg(0)->getType()->getAs<clang::VectorType>()) { 64 FZeroConst = ConstantVector::getSplat( 65 ElementCount::getFixed(VecTy->getNumElements()), FZeroConst); 66 auto *FCompInst = CGF->Builder.CreateFCmpOLT(Op0, FZeroConst); 67 CMP = CGF->Builder.CreateIntrinsic( 68 CGF->Builder.getInt1Ty(), CGF->CGM.getHLSLRuntime().getAnyIntrinsic(), 69 {FCompInst}); 70 } else { 71 CMP = CGF->Builder.CreateFCmpOLT(Op0, FZeroConst); 72 } 73 74 if (CGF->CGM.getTarget().getTriple().isDXIL()) { 75 LastInstr = CGF->Builder.CreateIntrinsic(Intrinsic::dx_discard, {CMP}); 76 } else if (CGF->CGM.getTarget().getTriple().isSPIRV()) { 77 BasicBlock *LT0 = CGF->createBasicBlock("lt0", CGF->CurFn); 78 BasicBlock *End = CGF->createBasicBlock("end", CGF->CurFn); 79 80 CGF->Builder.CreateCondBr(CMP, LT0, End); 81 82 CGF->Builder.SetInsertPoint(LT0); 83 84 CGF->Builder.CreateIntrinsic(Intrinsic::spv_discard, {}); 85 86 LastInstr = CGF->Builder.CreateBr(End); 87 CGF->Builder.SetInsertPoint(End); 88 } else { 89 llvm_unreachable("Backend Codegen not supported."); 90 } 91 92 return LastInstr; 93 } 94 95 static Value *handleHlslSplitdouble(const CallExpr *E, CodeGenFunction *CGF) { 96 Value *Op0 = CGF->EmitScalarExpr(E->getArg(0)); 97 const auto *OutArg1 = dyn_cast<HLSLOutArgExpr>(E->getArg(1)); 98 const auto *OutArg2 = dyn_cast<HLSLOutArgExpr>(E->getArg(2)); 99 100 CallArgList Args; 101 LValue Op1TmpLValue = 102 CGF->EmitHLSLOutArgExpr(OutArg1, Args, OutArg1->getType()); 103 LValue Op2TmpLValue = 104 CGF->EmitHLSLOutArgExpr(OutArg2, Args, OutArg2->getType()); 105 106 if (CGF->getTarget().getCXXABI().areArgsDestroyedLeftToRightInCallee()) 107 Args.reverseWritebacks(); 108 109 Value *LowBits = nullptr; 110 Value *HighBits = nullptr; 111 112 if (CGF->CGM.getTarget().getTriple().isDXIL()) { 113 llvm::Type *RetElementTy = CGF->Int32Ty; 114 if (auto *Op0VecTy = E->getArg(0)->getType()->getAs<clang::VectorType>()) 115 RetElementTy = llvm::VectorType::get( 116 CGF->Int32Ty, ElementCount::getFixed(Op0VecTy->getNumElements())); 117 auto *RetTy = llvm::StructType::get(RetElementTy, RetElementTy); 118 119 CallInst *CI = CGF->Builder.CreateIntrinsic( 120 RetTy, Intrinsic::dx_splitdouble, {Op0}, nullptr, "hlsl.splitdouble"); 121 122 LowBits = CGF->Builder.CreateExtractValue(CI, 0); 123 HighBits = CGF->Builder.CreateExtractValue(CI, 1); 124 } else { 125 // For Non DXIL targets we generate the instructions. 126 127 if (!Op0->getType()->isVectorTy()) { 128 FixedVectorType *DestTy = FixedVectorType::get(CGF->Int32Ty, 2); 129 Value *Bitcast = CGF->Builder.CreateBitCast(Op0, DestTy); 130 131 LowBits = CGF->Builder.CreateExtractElement(Bitcast, (uint64_t)0); 132 HighBits = CGF->Builder.CreateExtractElement(Bitcast, 1); 133 } else { 134 int NumElements = 1; 135 if (const auto *VecTy = 136 E->getArg(0)->getType()->getAs<clang::VectorType>()) 137 NumElements = VecTy->getNumElements(); 138 139 FixedVectorType *Uint32VecTy = 140 FixedVectorType::get(CGF->Int32Ty, NumElements * 2); 141 Value *Uint32Vec = CGF->Builder.CreateBitCast(Op0, Uint32VecTy); 142 if (NumElements == 1) { 143 LowBits = CGF->Builder.CreateExtractElement(Uint32Vec, (uint64_t)0); 144 HighBits = CGF->Builder.CreateExtractElement(Uint32Vec, 1); 145 } else { 146 SmallVector<int> EvenMask, OddMask; 147 for (int I = 0, E = NumElements; I != E; ++I) { 148 EvenMask.push_back(I * 2); 149 OddMask.push_back(I * 2 + 1); 150 } 151 LowBits = CGF->Builder.CreateShuffleVector(Uint32Vec, EvenMask); 152 HighBits = CGF->Builder.CreateShuffleVector(Uint32Vec, OddMask); 153 } 154 } 155 } 156 CGF->Builder.CreateStore(LowBits, Op1TmpLValue.getAddress()); 157 auto *LastInst = 158 CGF->Builder.CreateStore(HighBits, Op2TmpLValue.getAddress()); 159 CGF->EmitWritebacks(Args); 160 return LastInst; 161 } 162 163 // Return dot product intrinsic that corresponds to the QT scalar type 164 static Intrinsic::ID getDotProductIntrinsic(CGHLSLRuntime &RT, QualType QT) { 165 if (QT->isFloatingType()) 166 return RT.getFDotIntrinsic(); 167 if (QT->isSignedIntegerType()) 168 return RT.getSDotIntrinsic(); 169 assert(QT->isUnsignedIntegerType()); 170 return RT.getUDotIntrinsic(); 171 } 172 173 static Intrinsic::ID getFirstBitHighIntrinsic(CGHLSLRuntime &RT, QualType QT) { 174 if (QT->hasSignedIntegerRepresentation()) { 175 return RT.getFirstBitSHighIntrinsic(); 176 } 177 178 assert(QT->hasUnsignedIntegerRepresentation()); 179 return RT.getFirstBitUHighIntrinsic(); 180 } 181 182 // Return wave active sum that corresponds to the QT scalar type 183 static Intrinsic::ID getWaveActiveSumIntrinsic(llvm::Triple::ArchType Arch, 184 CGHLSLRuntime &RT, QualType QT) { 185 switch (Arch) { 186 case llvm::Triple::spirv: 187 return Intrinsic::spv_wave_reduce_sum; 188 case llvm::Triple::dxil: { 189 if (QT->isUnsignedIntegerType()) 190 return Intrinsic::dx_wave_reduce_usum; 191 return Intrinsic::dx_wave_reduce_sum; 192 } 193 default: 194 llvm_unreachable("Intrinsic WaveActiveSum" 195 " not supported by target architecture"); 196 } 197 } 198 199 // Return wave active sum that corresponds to the QT scalar type 200 static Intrinsic::ID getWaveActiveMaxIntrinsic(llvm::Triple::ArchType Arch, 201 CGHLSLRuntime &RT, QualType QT) { 202 switch (Arch) { 203 case llvm::Triple::spirv: 204 if (QT->isUnsignedIntegerType()) 205 return Intrinsic::spv_wave_reduce_umax; 206 return Intrinsic::spv_wave_reduce_max; 207 case llvm::Triple::dxil: { 208 if (QT->isUnsignedIntegerType()) 209 return Intrinsic::dx_wave_reduce_umax; 210 return Intrinsic::dx_wave_reduce_max; 211 } 212 default: 213 llvm_unreachable("Intrinsic WaveActiveMax" 214 " not supported by target architecture"); 215 } 216 } 217 218 // Returns the mangled name for a builtin function that the SPIR-V backend 219 // will expand into a spec Constant. 220 static std::string getSpecConstantFunctionName(clang::QualType SpecConstantType, 221 ASTContext &Context) { 222 // The parameter types for our conceptual intrinsic function. 223 QualType ClangParamTypes[] = {Context.IntTy, SpecConstantType}; 224 225 // Create a temporary FunctionDecl for the builtin fuction. It won't be 226 // added to the AST. 227 FunctionProtoType::ExtProtoInfo EPI; 228 QualType FnType = 229 Context.getFunctionType(SpecConstantType, ClangParamTypes, EPI); 230 DeclarationName FuncName = &Context.Idents.get("__spirv_SpecConstant"); 231 FunctionDecl *FnDeclForMangling = FunctionDecl::Create( 232 Context, Context.getTranslationUnitDecl(), SourceLocation(), 233 SourceLocation(), FuncName, FnType, /*TSI=*/nullptr, SC_Extern); 234 235 // Attach the created parameter declarations to the function declaration. 236 SmallVector<ParmVarDecl *, 2> ParamDecls; 237 for (QualType ParamType : ClangParamTypes) { 238 ParmVarDecl *PD = ParmVarDecl::Create( 239 Context, FnDeclForMangling, SourceLocation(), SourceLocation(), 240 /*IdentifierInfo*/ nullptr, ParamType, /*TSI*/ nullptr, SC_None, 241 /*DefaultArg*/ nullptr); 242 ParamDecls.push_back(PD); 243 } 244 FnDeclForMangling->setParams(ParamDecls); 245 246 // Get the mangled name. 247 std::string Name; 248 llvm::raw_string_ostream MangledNameStream(Name); 249 std::unique_ptr<MangleContext> Mangler(Context.createMangleContext()); 250 Mangler->mangleName(FnDeclForMangling, MangledNameStream); 251 MangledNameStream.flush(); 252 253 return Name; 254 } 255 256 Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID, 257 const CallExpr *E, 258 ReturnValueSlot ReturnValue) { 259 if (!getLangOpts().HLSL) 260 return nullptr; 261 262 switch (BuiltinID) { 263 case Builtin::BI__builtin_hlsl_adduint64: { 264 Value *OpA = EmitScalarExpr(E->getArg(0)); 265 Value *OpB = EmitScalarExpr(E->getArg(1)); 266 QualType Arg0Ty = E->getArg(0)->getType(); 267 uint64_t NumElements = Arg0Ty->castAs<VectorType>()->getNumElements(); 268 assert(Arg0Ty == E->getArg(1)->getType() && 269 "AddUint64 operand types must match"); 270 assert(Arg0Ty->hasIntegerRepresentation() && 271 "AddUint64 operands must have an integer representation"); 272 assert((NumElements == 2 || NumElements == 4) && 273 "AddUint64 operands must have 2 or 4 elements"); 274 275 llvm::Value *LowA; 276 llvm::Value *HighA; 277 llvm::Value *LowB; 278 llvm::Value *HighB; 279 280 // Obtain low and high words of inputs A and B 281 if (NumElements == 2) { 282 LowA = Builder.CreateExtractElement(OpA, (uint64_t)0, "LowA"); 283 HighA = Builder.CreateExtractElement(OpA, (uint64_t)1, "HighA"); 284 LowB = Builder.CreateExtractElement(OpB, (uint64_t)0, "LowB"); 285 HighB = Builder.CreateExtractElement(OpB, (uint64_t)1, "HighB"); 286 } else { 287 LowA = Builder.CreateShuffleVector(OpA, {0, 2}, "LowA"); 288 HighA = Builder.CreateShuffleVector(OpA, {1, 3}, "HighA"); 289 LowB = Builder.CreateShuffleVector(OpB, {0, 2}, "LowB"); 290 HighB = Builder.CreateShuffleVector(OpB, {1, 3}, "HighB"); 291 } 292 293 // Use an uadd_with_overflow to compute the sum of low words and obtain a 294 // carry value 295 llvm::Value *Carry; 296 llvm::Value *LowSum = EmitOverflowIntrinsic( 297 *this, Intrinsic::uadd_with_overflow, LowA, LowB, Carry); 298 llvm::Value *ZExtCarry = 299 Builder.CreateZExt(Carry, HighA->getType(), "CarryZExt"); 300 301 // Sum the high words and the carry 302 llvm::Value *HighSum = Builder.CreateAdd(HighA, HighB, "HighSum"); 303 llvm::Value *HighSumPlusCarry = 304 Builder.CreateAdd(HighSum, ZExtCarry, "HighSumPlusCarry"); 305 306 if (NumElements == 4) { 307 return Builder.CreateShuffleVector(LowSum, HighSumPlusCarry, {0, 2, 1, 3}, 308 "hlsl.AddUint64"); 309 } 310 311 llvm::Value *Result = PoisonValue::get(OpA->getType()); 312 Result = Builder.CreateInsertElement(Result, LowSum, (uint64_t)0, 313 "hlsl.AddUint64.upto0"); 314 Result = Builder.CreateInsertElement(Result, HighSumPlusCarry, (uint64_t)1, 315 "hlsl.AddUint64"); 316 return Result; 317 } 318 case Builtin::BI__builtin_hlsl_resource_getpointer: { 319 Value *HandleOp = EmitScalarExpr(E->getArg(0)); 320 Value *IndexOp = EmitScalarExpr(E->getArg(1)); 321 322 llvm::Type *RetTy = ConvertType(E->getType()); 323 return Builder.CreateIntrinsic( 324 RetTy, CGM.getHLSLRuntime().getCreateResourceGetPointerIntrinsic(), 325 ArrayRef<Value *>{HandleOp, IndexOp}); 326 } 327 case Builtin::BI__builtin_hlsl_resource_uninitializedhandle: { 328 llvm::Type *HandleTy = CGM.getTypes().ConvertType(E->getType()); 329 return llvm::PoisonValue::get(HandleTy); 330 } 331 case Builtin::BI__builtin_hlsl_resource_handlefrombinding: { 332 llvm::Type *HandleTy = CGM.getTypes().ConvertType(E->getType()); 333 Value *RegisterOp = EmitScalarExpr(E->getArg(1)); 334 Value *SpaceOp = EmitScalarExpr(E->getArg(2)); 335 Value *RangeOp = EmitScalarExpr(E->getArg(3)); 336 Value *IndexOp = EmitScalarExpr(E->getArg(4)); 337 Value *Name = EmitScalarExpr(E->getArg(5)); 338 // FIXME: NonUniformResourceIndex bit is not yet implemented 339 // (llvm/llvm-project#135452) 340 Value *NonUniform = 341 llvm::ConstantInt::get(llvm::Type::getInt1Ty(getLLVMContext()), false); 342 343 llvm::Intrinsic::ID IntrinsicID = 344 CGM.getHLSLRuntime().getCreateHandleFromBindingIntrinsic(); 345 SmallVector<Value *> Args{SpaceOp, RegisterOp, RangeOp, 346 IndexOp, NonUniform, Name}; 347 return Builder.CreateIntrinsic(HandleTy, IntrinsicID, Args); 348 } 349 case Builtin::BI__builtin_hlsl_resource_handlefromimplicitbinding: { 350 llvm::Type *HandleTy = CGM.getTypes().ConvertType(E->getType()); 351 Value *SpaceOp = EmitScalarExpr(E->getArg(1)); 352 Value *RangeOp = EmitScalarExpr(E->getArg(2)); 353 Value *IndexOp = EmitScalarExpr(E->getArg(3)); 354 Value *OrderID = EmitScalarExpr(E->getArg(4)); 355 Value *Name = EmitScalarExpr(E->getArg(5)); 356 // FIXME: NonUniformResourceIndex bit is not yet implemented 357 // (llvm/llvm-project#135452) 358 Value *NonUniform = 359 llvm::ConstantInt::get(llvm::Type::getInt1Ty(getLLVMContext()), false); 360 361 llvm::Intrinsic::ID IntrinsicID = 362 CGM.getHLSLRuntime().getCreateHandleFromImplicitBindingIntrinsic(); 363 SmallVector<Value *> Args{OrderID, SpaceOp, RangeOp, 364 IndexOp, NonUniform, Name}; 365 return Builder.CreateIntrinsic(HandleTy, IntrinsicID, Args); 366 } 367 case Builtin::BI__builtin_hlsl_all: { 368 Value *Op0 = EmitScalarExpr(E->getArg(0)); 369 return Builder.CreateIntrinsic( 370 /*ReturnType=*/llvm::Type::getInt1Ty(getLLVMContext()), 371 CGM.getHLSLRuntime().getAllIntrinsic(), ArrayRef<Value *>{Op0}, nullptr, 372 "hlsl.all"); 373 } 374 case Builtin::BI__builtin_hlsl_and: { 375 Value *Op0 = EmitScalarExpr(E->getArg(0)); 376 Value *Op1 = EmitScalarExpr(E->getArg(1)); 377 return Builder.CreateAnd(Op0, Op1, "hlsl.and"); 378 } 379 case Builtin::BI__builtin_hlsl_or: { 380 Value *Op0 = EmitScalarExpr(E->getArg(0)); 381 Value *Op1 = EmitScalarExpr(E->getArg(1)); 382 return Builder.CreateOr(Op0, Op1, "hlsl.or"); 383 } 384 case Builtin::BI__builtin_hlsl_any: { 385 Value *Op0 = EmitScalarExpr(E->getArg(0)); 386 return Builder.CreateIntrinsic( 387 /*ReturnType=*/llvm::Type::getInt1Ty(getLLVMContext()), 388 CGM.getHLSLRuntime().getAnyIntrinsic(), ArrayRef<Value *>{Op0}, nullptr, 389 "hlsl.any"); 390 } 391 case Builtin::BI__builtin_hlsl_asdouble: 392 return handleAsDoubleBuiltin(*this, E); 393 case Builtin::BI__builtin_hlsl_elementwise_clamp: { 394 Value *OpX = EmitScalarExpr(E->getArg(0)); 395 Value *OpMin = EmitScalarExpr(E->getArg(1)); 396 Value *OpMax = EmitScalarExpr(E->getArg(2)); 397 398 QualType Ty = E->getArg(0)->getType(); 399 if (auto *VecTy = Ty->getAs<VectorType>()) 400 Ty = VecTy->getElementType(); 401 402 Intrinsic::ID Intr; 403 if (Ty->isFloatingType()) { 404 Intr = CGM.getHLSLRuntime().getNClampIntrinsic(); 405 } else if (Ty->isUnsignedIntegerType()) { 406 Intr = CGM.getHLSLRuntime().getUClampIntrinsic(); 407 } else { 408 assert(Ty->isSignedIntegerType()); 409 Intr = CGM.getHLSLRuntime().getSClampIntrinsic(); 410 } 411 return Builder.CreateIntrinsic( 412 /*ReturnType=*/OpX->getType(), Intr, 413 ArrayRef<Value *>{OpX, OpMin, OpMax}, nullptr, "hlsl.clamp"); 414 } 415 case Builtin::BI__builtin_hlsl_crossf16: 416 case Builtin::BI__builtin_hlsl_crossf32: { 417 Value *Op0 = EmitScalarExpr(E->getArg(0)); 418 Value *Op1 = EmitScalarExpr(E->getArg(1)); 419 assert(E->getArg(0)->getType()->hasFloatingRepresentation() && 420 E->getArg(1)->getType()->hasFloatingRepresentation() && 421 "cross operands must have a float representation"); 422 // make sure each vector has exactly 3 elements 423 assert( 424 E->getArg(0)->getType()->castAs<VectorType>()->getNumElements() == 3 && 425 E->getArg(1)->getType()->castAs<VectorType>()->getNumElements() == 3 && 426 "input vectors must have 3 elements each"); 427 return Builder.CreateIntrinsic( 428 /*ReturnType=*/Op0->getType(), CGM.getHLSLRuntime().getCrossIntrinsic(), 429 ArrayRef<Value *>{Op0, Op1}, nullptr, "hlsl.cross"); 430 } 431 case Builtin::BI__builtin_hlsl_dot: { 432 Value *Op0 = EmitScalarExpr(E->getArg(0)); 433 Value *Op1 = EmitScalarExpr(E->getArg(1)); 434 llvm::Type *T0 = Op0->getType(); 435 llvm::Type *T1 = Op1->getType(); 436 437 // If the arguments are scalars, just emit a multiply 438 if (!T0->isVectorTy() && !T1->isVectorTy()) { 439 if (T0->isFloatingPointTy()) 440 return Builder.CreateFMul(Op0, Op1, "hlsl.dot"); 441 442 if (T0->isIntegerTy()) 443 return Builder.CreateMul(Op0, Op1, "hlsl.dot"); 444 445 llvm_unreachable( 446 "Scalar dot product is only supported on ints and floats."); 447 } 448 // For vectors, validate types and emit the appropriate intrinsic 449 assert(CGM.getContext().hasSameUnqualifiedType(E->getArg(0)->getType(), 450 E->getArg(1)->getType()) && 451 "Dot product operands must have the same type."); 452 453 auto *VecTy0 = E->getArg(0)->getType()->castAs<VectorType>(); 454 assert(VecTy0 && "Dot product argument must be a vector."); 455 456 return Builder.CreateIntrinsic( 457 /*ReturnType=*/T0->getScalarType(), 458 getDotProductIntrinsic(CGM.getHLSLRuntime(), VecTy0->getElementType()), 459 ArrayRef<Value *>{Op0, Op1}, nullptr, "hlsl.dot"); 460 } 461 case Builtin::BI__builtin_hlsl_dot4add_i8packed: { 462 Value *X = EmitScalarExpr(E->getArg(0)); 463 Value *Y = EmitScalarExpr(E->getArg(1)); 464 Value *Acc = EmitScalarExpr(E->getArg(2)); 465 466 Intrinsic::ID ID = CGM.getHLSLRuntime().getDot4AddI8PackedIntrinsic(); 467 // Note that the argument order disagrees between the builtin and the 468 // intrinsic here. 469 return Builder.CreateIntrinsic( 470 /*ReturnType=*/Acc->getType(), ID, ArrayRef<Value *>{Acc, X, Y}, 471 nullptr, "hlsl.dot4add.i8packed"); 472 } 473 case Builtin::BI__builtin_hlsl_dot4add_u8packed: { 474 Value *X = EmitScalarExpr(E->getArg(0)); 475 Value *Y = EmitScalarExpr(E->getArg(1)); 476 Value *Acc = EmitScalarExpr(E->getArg(2)); 477 478 Intrinsic::ID ID = CGM.getHLSLRuntime().getDot4AddU8PackedIntrinsic(); 479 // Note that the argument order disagrees between the builtin and the 480 // intrinsic here. 481 return Builder.CreateIntrinsic( 482 /*ReturnType=*/Acc->getType(), ID, ArrayRef<Value *>{Acc, X, Y}, 483 nullptr, "hlsl.dot4add.u8packed"); 484 } 485 case Builtin::BI__builtin_hlsl_elementwise_firstbithigh: { 486 Value *X = EmitScalarExpr(E->getArg(0)); 487 488 return Builder.CreateIntrinsic( 489 /*ReturnType=*/ConvertType(E->getType()), 490 getFirstBitHighIntrinsic(CGM.getHLSLRuntime(), E->getArg(0)->getType()), 491 ArrayRef<Value *>{X}, nullptr, "hlsl.firstbithigh"); 492 } 493 case Builtin::BI__builtin_hlsl_elementwise_firstbitlow: { 494 Value *X = EmitScalarExpr(E->getArg(0)); 495 496 return Builder.CreateIntrinsic( 497 /*ReturnType=*/ConvertType(E->getType()), 498 CGM.getHLSLRuntime().getFirstBitLowIntrinsic(), ArrayRef<Value *>{X}, 499 nullptr, "hlsl.firstbitlow"); 500 } 501 case Builtin::BI__builtin_hlsl_lerp: { 502 Value *X = EmitScalarExpr(E->getArg(0)); 503 Value *Y = EmitScalarExpr(E->getArg(1)); 504 Value *S = EmitScalarExpr(E->getArg(2)); 505 if (!E->getArg(0)->getType()->hasFloatingRepresentation()) 506 llvm_unreachable("lerp operand must have a float representation"); 507 return Builder.CreateIntrinsic( 508 /*ReturnType=*/X->getType(), CGM.getHLSLRuntime().getLerpIntrinsic(), 509 ArrayRef<Value *>{X, Y, S}, nullptr, "hlsl.lerp"); 510 } 511 case Builtin::BI__builtin_hlsl_normalize: { 512 Value *X = EmitScalarExpr(E->getArg(0)); 513 514 assert(E->getArg(0)->getType()->hasFloatingRepresentation() && 515 "normalize operand must have a float representation"); 516 517 return Builder.CreateIntrinsic( 518 /*ReturnType=*/X->getType(), 519 CGM.getHLSLRuntime().getNormalizeIntrinsic(), ArrayRef<Value *>{X}, 520 nullptr, "hlsl.normalize"); 521 } 522 case Builtin::BI__builtin_hlsl_elementwise_degrees: { 523 Value *X = EmitScalarExpr(E->getArg(0)); 524 525 assert(E->getArg(0)->getType()->hasFloatingRepresentation() && 526 "degree operand must have a float representation"); 527 528 return Builder.CreateIntrinsic( 529 /*ReturnType=*/X->getType(), CGM.getHLSLRuntime().getDegreesIntrinsic(), 530 ArrayRef<Value *>{X}, nullptr, "hlsl.degrees"); 531 } 532 case Builtin::BI__builtin_hlsl_elementwise_frac: { 533 Value *Op0 = EmitScalarExpr(E->getArg(0)); 534 if (!E->getArg(0)->getType()->hasFloatingRepresentation()) 535 llvm_unreachable("frac operand must have a float representation"); 536 return Builder.CreateIntrinsic( 537 /*ReturnType=*/Op0->getType(), CGM.getHLSLRuntime().getFracIntrinsic(), 538 ArrayRef<Value *>{Op0}, nullptr, "hlsl.frac"); 539 } 540 case Builtin::BI__builtin_hlsl_elementwise_isinf: { 541 Value *Op0 = EmitScalarExpr(E->getArg(0)); 542 llvm::Type *Xty = Op0->getType(); 543 llvm::Type *retType = llvm::Type::getInt1Ty(this->getLLVMContext()); 544 if (Xty->isVectorTy()) { 545 auto *XVecTy = E->getArg(0)->getType()->castAs<VectorType>(); 546 retType = llvm::VectorType::get( 547 retType, ElementCount::getFixed(XVecTy->getNumElements())); 548 } 549 if (!E->getArg(0)->getType()->hasFloatingRepresentation()) 550 llvm_unreachable("isinf operand must have a float representation"); 551 return Builder.CreateIntrinsic(retType, Intrinsic::dx_isinf, 552 ArrayRef<Value *>{Op0}, nullptr, "dx.isinf"); 553 } 554 case Builtin::BI__builtin_hlsl_mad: { 555 Value *M = EmitScalarExpr(E->getArg(0)); 556 Value *A = EmitScalarExpr(E->getArg(1)); 557 Value *B = EmitScalarExpr(E->getArg(2)); 558 if (E->getArg(0)->getType()->hasFloatingRepresentation()) 559 return Builder.CreateIntrinsic( 560 /*ReturnType*/ M->getType(), Intrinsic::fmuladd, 561 ArrayRef<Value *>{M, A, B}, nullptr, "hlsl.fmad"); 562 563 if (E->getArg(0)->getType()->hasSignedIntegerRepresentation()) { 564 if (CGM.getTarget().getTriple().getArch() == llvm::Triple::dxil) 565 return Builder.CreateIntrinsic( 566 /*ReturnType*/ M->getType(), Intrinsic::dx_imad, 567 ArrayRef<Value *>{M, A, B}, nullptr, "dx.imad"); 568 569 Value *Mul = Builder.CreateNSWMul(M, A); 570 return Builder.CreateNSWAdd(Mul, B); 571 } 572 assert(E->getArg(0)->getType()->hasUnsignedIntegerRepresentation()); 573 if (CGM.getTarget().getTriple().getArch() == llvm::Triple::dxil) 574 return Builder.CreateIntrinsic( 575 /*ReturnType=*/M->getType(), Intrinsic::dx_umad, 576 ArrayRef<Value *>{M, A, B}, nullptr, "dx.umad"); 577 578 Value *Mul = Builder.CreateNUWMul(M, A); 579 return Builder.CreateNUWAdd(Mul, B); 580 } 581 case Builtin::BI__builtin_hlsl_elementwise_rcp: { 582 Value *Op0 = EmitScalarExpr(E->getArg(0)); 583 if (!E->getArg(0)->getType()->hasFloatingRepresentation()) 584 llvm_unreachable("rcp operand must have a float representation"); 585 llvm::Type *Ty = Op0->getType(); 586 llvm::Type *EltTy = Ty->getScalarType(); 587 Constant *One = Ty->isVectorTy() 588 ? ConstantVector::getSplat( 589 ElementCount::getFixed( 590 cast<FixedVectorType>(Ty)->getNumElements()), 591 ConstantFP::get(EltTy, 1.0)) 592 : ConstantFP::get(EltTy, 1.0); 593 return Builder.CreateFDiv(One, Op0, "hlsl.rcp"); 594 } 595 case Builtin::BI__builtin_hlsl_elementwise_rsqrt: { 596 Value *Op0 = EmitScalarExpr(E->getArg(0)); 597 if (!E->getArg(0)->getType()->hasFloatingRepresentation()) 598 llvm_unreachable("rsqrt operand must have a float representation"); 599 return Builder.CreateIntrinsic( 600 /*ReturnType=*/Op0->getType(), CGM.getHLSLRuntime().getRsqrtIntrinsic(), 601 ArrayRef<Value *>{Op0}, nullptr, "hlsl.rsqrt"); 602 } 603 case Builtin::BI__builtin_hlsl_elementwise_saturate: { 604 Value *Op0 = EmitScalarExpr(E->getArg(0)); 605 assert(E->getArg(0)->getType()->hasFloatingRepresentation() && 606 "saturate operand must have a float representation"); 607 return Builder.CreateIntrinsic( 608 /*ReturnType=*/Op0->getType(), 609 CGM.getHLSLRuntime().getSaturateIntrinsic(), ArrayRef<Value *>{Op0}, 610 nullptr, "hlsl.saturate"); 611 } 612 case Builtin::BI__builtin_hlsl_select: { 613 Value *OpCond = EmitScalarExpr(E->getArg(0)); 614 RValue RValTrue = EmitAnyExpr(E->getArg(1)); 615 Value *OpTrue = 616 RValTrue.isScalar() 617 ? RValTrue.getScalarVal() 618 : RValTrue.getAggregatePointer(E->getArg(1)->getType(), *this); 619 RValue RValFalse = EmitAnyExpr(E->getArg(2)); 620 Value *OpFalse = 621 RValFalse.isScalar() 622 ? RValFalse.getScalarVal() 623 : RValFalse.getAggregatePointer(E->getArg(2)->getType(), *this); 624 if (auto *VTy = E->getType()->getAs<VectorType>()) { 625 if (!OpTrue->getType()->isVectorTy()) 626 OpTrue = 627 Builder.CreateVectorSplat(VTy->getNumElements(), OpTrue, "splat"); 628 if (!OpFalse->getType()->isVectorTy()) 629 OpFalse = 630 Builder.CreateVectorSplat(VTy->getNumElements(), OpFalse, "splat"); 631 } 632 633 Value *SelectVal = 634 Builder.CreateSelect(OpCond, OpTrue, OpFalse, "hlsl.select"); 635 if (!RValTrue.isScalar()) 636 Builder.CreateStore(SelectVal, ReturnValue.getAddress(), 637 ReturnValue.isVolatile()); 638 639 return SelectVal; 640 } 641 case Builtin::BI__builtin_hlsl_step: { 642 Value *Op0 = EmitScalarExpr(E->getArg(0)); 643 Value *Op1 = EmitScalarExpr(E->getArg(1)); 644 assert(E->getArg(0)->getType()->hasFloatingRepresentation() && 645 E->getArg(1)->getType()->hasFloatingRepresentation() && 646 "step operands must have a float representation"); 647 return Builder.CreateIntrinsic( 648 /*ReturnType=*/Op0->getType(), CGM.getHLSLRuntime().getStepIntrinsic(), 649 ArrayRef<Value *>{Op0, Op1}, nullptr, "hlsl.step"); 650 } 651 case Builtin::BI__builtin_hlsl_wave_active_all_true: { 652 Value *Op = EmitScalarExpr(E->getArg(0)); 653 assert(Op->getType()->isIntegerTy(1) && 654 "Intrinsic WaveActiveAllTrue operand must be a bool"); 655 656 Intrinsic::ID ID = CGM.getHLSLRuntime().getWaveActiveAllTrueIntrinsic(); 657 return EmitRuntimeCall( 658 Intrinsic::getOrInsertDeclaration(&CGM.getModule(), ID), {Op}); 659 } 660 case Builtin::BI__builtin_hlsl_wave_active_any_true: { 661 Value *Op = EmitScalarExpr(E->getArg(0)); 662 assert(Op->getType()->isIntegerTy(1) && 663 "Intrinsic WaveActiveAnyTrue operand must be a bool"); 664 665 Intrinsic::ID ID = CGM.getHLSLRuntime().getWaveActiveAnyTrueIntrinsic(); 666 return EmitRuntimeCall( 667 Intrinsic::getOrInsertDeclaration(&CGM.getModule(), ID), {Op}); 668 } 669 case Builtin::BI__builtin_hlsl_wave_active_count_bits: { 670 Value *OpExpr = EmitScalarExpr(E->getArg(0)); 671 Intrinsic::ID ID = CGM.getHLSLRuntime().getWaveActiveCountBitsIntrinsic(); 672 return EmitRuntimeCall( 673 Intrinsic::getOrInsertDeclaration(&CGM.getModule(), ID), 674 ArrayRef{OpExpr}); 675 } 676 case Builtin::BI__builtin_hlsl_wave_active_sum: { 677 // Due to the use of variadic arguments, explicitly retreive argument 678 Value *OpExpr = EmitScalarExpr(E->getArg(0)); 679 Intrinsic::ID IID = getWaveActiveSumIntrinsic( 680 getTarget().getTriple().getArch(), CGM.getHLSLRuntime(), 681 E->getArg(0)->getType()); 682 683 return EmitRuntimeCall(Intrinsic::getOrInsertDeclaration( 684 &CGM.getModule(), IID, {OpExpr->getType()}), 685 ArrayRef{OpExpr}, "hlsl.wave.active.sum"); 686 } 687 case Builtin::BI__builtin_hlsl_wave_active_max: { 688 // Due to the use of variadic arguments, explicitly retreive argument 689 Value *OpExpr = EmitScalarExpr(E->getArg(0)); 690 Intrinsic::ID IID = getWaveActiveMaxIntrinsic( 691 getTarget().getTriple().getArch(), CGM.getHLSLRuntime(), 692 E->getArg(0)->getType()); 693 694 return EmitRuntimeCall(Intrinsic::getOrInsertDeclaration( 695 &CGM.getModule(), IID, {OpExpr->getType()}), 696 ArrayRef{OpExpr}, "hlsl.wave.active.max"); 697 } 698 case Builtin::BI__builtin_hlsl_wave_get_lane_index: { 699 // We don't define a SPIR-V intrinsic, instead it is a SPIR-V built-in 700 // defined in SPIRVBuiltins.td. So instead we manually get the matching name 701 // for the DirectX intrinsic and the demangled builtin name 702 switch (CGM.getTarget().getTriple().getArch()) { 703 case llvm::Triple::dxil: 704 return EmitRuntimeCall(Intrinsic::getOrInsertDeclaration( 705 &CGM.getModule(), Intrinsic::dx_wave_getlaneindex)); 706 case llvm::Triple::spirv: 707 return EmitRuntimeCall(CGM.CreateRuntimeFunction( 708 llvm::FunctionType::get(IntTy, {}, false), 709 "__hlsl_wave_get_lane_index", {}, false, true)); 710 default: 711 llvm_unreachable( 712 "Intrinsic WaveGetLaneIndex not supported by target architecture"); 713 } 714 } 715 case Builtin::BI__builtin_hlsl_wave_is_first_lane: { 716 Intrinsic::ID ID = CGM.getHLSLRuntime().getWaveIsFirstLaneIntrinsic(); 717 return EmitRuntimeCall( 718 Intrinsic::getOrInsertDeclaration(&CGM.getModule(), ID)); 719 } 720 case Builtin::BI__builtin_hlsl_wave_get_lane_count: { 721 Intrinsic::ID ID = CGM.getHLSLRuntime().getWaveGetLaneCountIntrinsic(); 722 return EmitRuntimeCall( 723 Intrinsic::getOrInsertDeclaration(&CGM.getModule(), ID)); 724 } 725 case Builtin::BI__builtin_hlsl_wave_read_lane_at: { 726 // Due to the use of variadic arguments we must explicitly retreive them and 727 // create our function type. 728 Value *OpExpr = EmitScalarExpr(E->getArg(0)); 729 Value *OpIndex = EmitScalarExpr(E->getArg(1)); 730 return EmitRuntimeCall( 731 Intrinsic::getOrInsertDeclaration( 732 &CGM.getModule(), CGM.getHLSLRuntime().getWaveReadLaneAtIntrinsic(), 733 {OpExpr->getType()}), 734 ArrayRef{OpExpr, OpIndex}, "hlsl.wave.readlane"); 735 } 736 case Builtin::BI__builtin_hlsl_elementwise_sign: { 737 auto *Arg0 = E->getArg(0); 738 Value *Op0 = EmitScalarExpr(Arg0); 739 llvm::Type *Xty = Op0->getType(); 740 llvm::Type *retType = llvm::Type::getInt32Ty(this->getLLVMContext()); 741 if (Xty->isVectorTy()) { 742 auto *XVecTy = Arg0->getType()->castAs<VectorType>(); 743 retType = llvm::VectorType::get( 744 retType, ElementCount::getFixed(XVecTy->getNumElements())); 745 } 746 assert((Arg0->getType()->hasFloatingRepresentation() || 747 Arg0->getType()->hasIntegerRepresentation()) && 748 "sign operand must have a float or int representation"); 749 750 if (Arg0->getType()->hasUnsignedIntegerRepresentation()) { 751 Value *Cmp = Builder.CreateICmpEQ(Op0, ConstantInt::get(Xty, 0)); 752 return Builder.CreateSelect(Cmp, ConstantInt::get(retType, 0), 753 ConstantInt::get(retType, 1), "hlsl.sign"); 754 } 755 756 return Builder.CreateIntrinsic( 757 retType, CGM.getHLSLRuntime().getSignIntrinsic(), 758 ArrayRef<Value *>{Op0}, nullptr, "hlsl.sign"); 759 } 760 case Builtin::BI__builtin_hlsl_elementwise_radians: { 761 Value *Op0 = EmitScalarExpr(E->getArg(0)); 762 assert(E->getArg(0)->getType()->hasFloatingRepresentation() && 763 "radians operand must have a float representation"); 764 return Builder.CreateIntrinsic( 765 /*ReturnType=*/Op0->getType(), 766 CGM.getHLSLRuntime().getRadiansIntrinsic(), ArrayRef<Value *>{Op0}, 767 nullptr, "hlsl.radians"); 768 } 769 case Builtin::BI__builtin_hlsl_buffer_update_counter: { 770 Value *ResHandle = EmitScalarExpr(E->getArg(0)); 771 Value *Offset = EmitScalarExpr(E->getArg(1)); 772 Value *OffsetI8 = Builder.CreateIntCast(Offset, Int8Ty, true); 773 return Builder.CreateIntrinsic( 774 /*ReturnType=*/Offset->getType(), 775 CGM.getHLSLRuntime().getBufferUpdateCounterIntrinsic(), 776 ArrayRef<Value *>{ResHandle, OffsetI8}, nullptr); 777 } 778 case Builtin::BI__builtin_hlsl_elementwise_splitdouble: { 779 780 assert((E->getArg(0)->getType()->hasFloatingRepresentation() && 781 E->getArg(1)->getType()->hasUnsignedIntegerRepresentation() && 782 E->getArg(2)->getType()->hasUnsignedIntegerRepresentation()) && 783 "asuint operands types mismatch"); 784 return handleHlslSplitdouble(E, this); 785 } 786 case Builtin::BI__builtin_hlsl_elementwise_clip: 787 assert(E->getArg(0)->getType()->hasFloatingRepresentation() && 788 "clip operands types mismatch"); 789 return handleHlslClip(E, this); 790 case Builtin::BI__builtin_hlsl_group_memory_barrier_with_group_sync: { 791 Intrinsic::ID ID = 792 CGM.getHLSLRuntime().getGroupMemoryBarrierWithGroupSyncIntrinsic(); 793 return EmitRuntimeCall( 794 Intrinsic::getOrInsertDeclaration(&CGM.getModule(), ID)); 795 } 796 case Builtin::BI__builtin_get_spirv_spec_constant_bool: 797 case Builtin::BI__builtin_get_spirv_spec_constant_short: 798 case Builtin::BI__builtin_get_spirv_spec_constant_ushort: 799 case Builtin::BI__builtin_get_spirv_spec_constant_int: 800 case Builtin::BI__builtin_get_spirv_spec_constant_uint: 801 case Builtin::BI__builtin_get_spirv_spec_constant_longlong: 802 case Builtin::BI__builtin_get_spirv_spec_constant_ulonglong: 803 case Builtin::BI__builtin_get_spirv_spec_constant_half: 804 case Builtin::BI__builtin_get_spirv_spec_constant_float: 805 case Builtin::BI__builtin_get_spirv_spec_constant_double: { 806 llvm::Function *SpecConstantFn = getSpecConstantFunction(E->getType()); 807 llvm::Value *SpecId = EmitScalarExpr(E->getArg(0)); 808 llvm::Value *DefaultVal = EmitScalarExpr(E->getArg(1)); 809 llvm::Value *Args[] = {SpecId, DefaultVal}; 810 return Builder.CreateCall(SpecConstantFn, Args); 811 } 812 } 813 return nullptr; 814 } 815 816 llvm::Function *clang::CodeGen::CodeGenFunction::getSpecConstantFunction( 817 const clang::QualType &SpecConstantType) { 818 819 // Find or create the declaration for the function. 820 llvm::Module *M = &CGM.getModule(); 821 std::string MangledName = 822 getSpecConstantFunctionName(SpecConstantType, getContext()); 823 llvm::Function *SpecConstantFn = M->getFunction(MangledName); 824 825 if (!SpecConstantFn) { 826 llvm::Type *IntType = ConvertType(getContext().IntTy); 827 llvm::Type *RetTy = ConvertType(SpecConstantType); 828 llvm::Type *ArgTypes[] = {IntType, RetTy}; 829 llvm::FunctionType *FnTy = llvm::FunctionType::get(RetTy, ArgTypes, false); 830 SpecConstantFn = llvm::Function::Create( 831 FnTy, llvm::GlobalValue::ExternalLinkage, MangledName, M); 832 } 833 return SpecConstantFn; 834 } 835