xref: /freebsd/contrib/llvm-project/clang/lib/CodeGen/CGHLSLBuiltins.cpp (revision 700637cbb5e582861067a11aaca4d053546871d2)
1 //===------- CGHLSLBuiltins.cpp - Emit LLVM Code for HLSL builtins --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This contains code to emit HLSL Builtin calls as LLVM code.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "CGBuiltin.h"
14 #include "CGHLSLRuntime.h"
15 #include "CodeGenFunction.h"
16 
17 using namespace clang;
18 using namespace CodeGen;
19 using namespace llvm;
20 
handleAsDoubleBuiltin(CodeGenFunction & CGF,const CallExpr * E)21 static Value *handleAsDoubleBuiltin(CodeGenFunction &CGF, const CallExpr *E) {
22   assert((E->getArg(0)->getType()->hasUnsignedIntegerRepresentation() &&
23           E->getArg(1)->getType()->hasUnsignedIntegerRepresentation()) &&
24          "asdouble operands types mismatch");
25   Value *OpLowBits = CGF.EmitScalarExpr(E->getArg(0));
26   Value *OpHighBits = CGF.EmitScalarExpr(E->getArg(1));
27 
28   llvm::Type *ResultType = CGF.DoubleTy;
29   int N = 1;
30   if (auto *VTy = E->getArg(0)->getType()->getAs<clang::VectorType>()) {
31     N = VTy->getNumElements();
32     ResultType = llvm::FixedVectorType::get(CGF.DoubleTy, N);
33   }
34 
35   if (CGF.CGM.getTarget().getTriple().isDXIL())
36     return CGF.Builder.CreateIntrinsic(
37         /*ReturnType=*/ResultType, Intrinsic::dx_asdouble,
38         {OpLowBits, OpHighBits}, nullptr, "hlsl.asdouble");
39 
40   if (!E->getArg(0)->getType()->isVectorType()) {
41     OpLowBits = CGF.Builder.CreateVectorSplat(1, OpLowBits);
42     OpHighBits = CGF.Builder.CreateVectorSplat(1, OpHighBits);
43   }
44 
45   llvm::SmallVector<int> Mask;
46   for (int i = 0; i < N; i++) {
47     Mask.push_back(i);
48     Mask.push_back(i + N);
49   }
50 
51   Value *BitVec = CGF.Builder.CreateShuffleVector(OpLowBits, OpHighBits, Mask);
52 
53   return CGF.Builder.CreateBitCast(BitVec, ResultType);
54 }
55 
handleHlslClip(const CallExpr * E,CodeGenFunction * CGF)56 static Value *handleHlslClip(const CallExpr *E, CodeGenFunction *CGF) {
57   Value *Op0 = CGF->EmitScalarExpr(E->getArg(0));
58 
59   Constant *FZeroConst = ConstantFP::getZero(CGF->FloatTy);
60   Value *CMP;
61   Value *LastInstr;
62 
63   if (const auto *VecTy = E->getArg(0)->getType()->getAs<clang::VectorType>()) {
64     FZeroConst = ConstantVector::getSplat(
65         ElementCount::getFixed(VecTy->getNumElements()), FZeroConst);
66     auto *FCompInst = CGF->Builder.CreateFCmpOLT(Op0, FZeroConst);
67     CMP = CGF->Builder.CreateIntrinsic(
68         CGF->Builder.getInt1Ty(), CGF->CGM.getHLSLRuntime().getAnyIntrinsic(),
69         {FCompInst});
70   } else {
71     CMP = CGF->Builder.CreateFCmpOLT(Op0, FZeroConst);
72   }
73 
74   if (CGF->CGM.getTarget().getTriple().isDXIL()) {
75     LastInstr = CGF->Builder.CreateIntrinsic(Intrinsic::dx_discard, {CMP});
76   } else if (CGF->CGM.getTarget().getTriple().isSPIRV()) {
77     BasicBlock *LT0 = CGF->createBasicBlock("lt0", CGF->CurFn);
78     BasicBlock *End = CGF->createBasicBlock("end", CGF->CurFn);
79 
80     CGF->Builder.CreateCondBr(CMP, LT0, End);
81 
82     CGF->Builder.SetInsertPoint(LT0);
83 
84     CGF->Builder.CreateIntrinsic(Intrinsic::spv_discard, {});
85 
86     LastInstr = CGF->Builder.CreateBr(End);
87     CGF->Builder.SetInsertPoint(End);
88   } else {
89     llvm_unreachable("Backend Codegen not supported.");
90   }
91 
92   return LastInstr;
93 }
94 
handleHlslSplitdouble(const CallExpr * E,CodeGenFunction * CGF)95 static Value *handleHlslSplitdouble(const CallExpr *E, CodeGenFunction *CGF) {
96   Value *Op0 = CGF->EmitScalarExpr(E->getArg(0));
97   const auto *OutArg1 = dyn_cast<HLSLOutArgExpr>(E->getArg(1));
98   const auto *OutArg2 = dyn_cast<HLSLOutArgExpr>(E->getArg(2));
99 
100   CallArgList Args;
101   LValue Op1TmpLValue =
102       CGF->EmitHLSLOutArgExpr(OutArg1, Args, OutArg1->getType());
103   LValue Op2TmpLValue =
104       CGF->EmitHLSLOutArgExpr(OutArg2, Args, OutArg2->getType());
105 
106   if (CGF->getTarget().getCXXABI().areArgsDestroyedLeftToRightInCallee())
107     Args.reverseWritebacks();
108 
109   Value *LowBits = nullptr;
110   Value *HighBits = nullptr;
111 
112   if (CGF->CGM.getTarget().getTriple().isDXIL()) {
113     llvm::Type *RetElementTy = CGF->Int32Ty;
114     if (auto *Op0VecTy = E->getArg(0)->getType()->getAs<clang::VectorType>())
115       RetElementTy = llvm::VectorType::get(
116           CGF->Int32Ty, ElementCount::getFixed(Op0VecTy->getNumElements()));
117     auto *RetTy = llvm::StructType::get(RetElementTy, RetElementTy);
118 
119     CallInst *CI = CGF->Builder.CreateIntrinsic(
120         RetTy, Intrinsic::dx_splitdouble, {Op0}, nullptr, "hlsl.splitdouble");
121 
122     LowBits = CGF->Builder.CreateExtractValue(CI, 0);
123     HighBits = CGF->Builder.CreateExtractValue(CI, 1);
124   } else {
125     // For Non DXIL targets we generate the instructions.
126 
127     if (!Op0->getType()->isVectorTy()) {
128       FixedVectorType *DestTy = FixedVectorType::get(CGF->Int32Ty, 2);
129       Value *Bitcast = CGF->Builder.CreateBitCast(Op0, DestTy);
130 
131       LowBits = CGF->Builder.CreateExtractElement(Bitcast, (uint64_t)0);
132       HighBits = CGF->Builder.CreateExtractElement(Bitcast, 1);
133     } else {
134       int NumElements = 1;
135       if (const auto *VecTy =
136               E->getArg(0)->getType()->getAs<clang::VectorType>())
137         NumElements = VecTy->getNumElements();
138 
139       FixedVectorType *Uint32VecTy =
140           FixedVectorType::get(CGF->Int32Ty, NumElements * 2);
141       Value *Uint32Vec = CGF->Builder.CreateBitCast(Op0, Uint32VecTy);
142       if (NumElements == 1) {
143         LowBits = CGF->Builder.CreateExtractElement(Uint32Vec, (uint64_t)0);
144         HighBits = CGF->Builder.CreateExtractElement(Uint32Vec, 1);
145       } else {
146         SmallVector<int> EvenMask, OddMask;
147         for (int I = 0, E = NumElements; I != E; ++I) {
148           EvenMask.push_back(I * 2);
149           OddMask.push_back(I * 2 + 1);
150         }
151         LowBits = CGF->Builder.CreateShuffleVector(Uint32Vec, EvenMask);
152         HighBits = CGF->Builder.CreateShuffleVector(Uint32Vec, OddMask);
153       }
154     }
155   }
156   CGF->Builder.CreateStore(LowBits, Op1TmpLValue.getAddress());
157   auto *LastInst =
158       CGF->Builder.CreateStore(HighBits, Op2TmpLValue.getAddress());
159   CGF->EmitWritebacks(Args);
160   return LastInst;
161 }
162 
163 // Return dot product intrinsic that corresponds to the QT scalar type
getDotProductIntrinsic(CGHLSLRuntime & RT,QualType QT)164 static Intrinsic::ID getDotProductIntrinsic(CGHLSLRuntime &RT, QualType QT) {
165   if (QT->isFloatingType())
166     return RT.getFDotIntrinsic();
167   if (QT->isSignedIntegerType())
168     return RT.getSDotIntrinsic();
169   assert(QT->isUnsignedIntegerType());
170   return RT.getUDotIntrinsic();
171 }
172 
getFirstBitHighIntrinsic(CGHLSLRuntime & RT,QualType QT)173 static Intrinsic::ID getFirstBitHighIntrinsic(CGHLSLRuntime &RT, QualType QT) {
174   if (QT->hasSignedIntegerRepresentation()) {
175     return RT.getFirstBitSHighIntrinsic();
176   }
177 
178   assert(QT->hasUnsignedIntegerRepresentation());
179   return RT.getFirstBitUHighIntrinsic();
180 }
181 
182 // Return wave active sum that corresponds to the QT scalar type
getWaveActiveSumIntrinsic(llvm::Triple::ArchType Arch,CGHLSLRuntime & RT,QualType QT)183 static Intrinsic::ID getWaveActiveSumIntrinsic(llvm::Triple::ArchType Arch,
184                                                CGHLSLRuntime &RT, QualType QT) {
185   switch (Arch) {
186   case llvm::Triple::spirv:
187     return Intrinsic::spv_wave_reduce_sum;
188   case llvm::Triple::dxil: {
189     if (QT->isUnsignedIntegerType())
190       return Intrinsic::dx_wave_reduce_usum;
191     return Intrinsic::dx_wave_reduce_sum;
192   }
193   default:
194     llvm_unreachable("Intrinsic WaveActiveSum"
195                      " not supported by target architecture");
196   }
197 }
198 
199 // Return wave active sum that corresponds to the QT scalar type
getWaveActiveMaxIntrinsic(llvm::Triple::ArchType Arch,CGHLSLRuntime & RT,QualType QT)200 static Intrinsic::ID getWaveActiveMaxIntrinsic(llvm::Triple::ArchType Arch,
201                                                CGHLSLRuntime &RT, QualType QT) {
202   switch (Arch) {
203   case llvm::Triple::spirv:
204     if (QT->isUnsignedIntegerType())
205       return Intrinsic::spv_wave_reduce_umax;
206     return Intrinsic::spv_wave_reduce_max;
207   case llvm::Triple::dxil: {
208     if (QT->isUnsignedIntegerType())
209       return Intrinsic::dx_wave_reduce_umax;
210     return Intrinsic::dx_wave_reduce_max;
211   }
212   default:
213     llvm_unreachable("Intrinsic WaveActiveMax"
214                      " not supported by target architecture");
215   }
216 }
217 
218 // Returns the mangled name for a builtin function that the SPIR-V backend
219 // will expand into a spec Constant.
getSpecConstantFunctionName(clang::QualType SpecConstantType,ASTContext & Context)220 static std::string getSpecConstantFunctionName(clang::QualType SpecConstantType,
221                                                ASTContext &Context) {
222   // The parameter types for our conceptual intrinsic function.
223   QualType ClangParamTypes[] = {Context.IntTy, SpecConstantType};
224 
225   // Create a temporary FunctionDecl for the builtin fuction. It won't be
226   // added to the AST.
227   FunctionProtoType::ExtProtoInfo EPI;
228   QualType FnType =
229       Context.getFunctionType(SpecConstantType, ClangParamTypes, EPI);
230   DeclarationName FuncName = &Context.Idents.get("__spirv_SpecConstant");
231   FunctionDecl *FnDeclForMangling = FunctionDecl::Create(
232       Context, Context.getTranslationUnitDecl(), SourceLocation(),
233       SourceLocation(), FuncName, FnType, /*TSI=*/nullptr, SC_Extern);
234 
235   // Attach the created parameter declarations to the function declaration.
236   SmallVector<ParmVarDecl *, 2> ParamDecls;
237   for (QualType ParamType : ClangParamTypes) {
238     ParmVarDecl *PD = ParmVarDecl::Create(
239         Context, FnDeclForMangling, SourceLocation(), SourceLocation(),
240         /*IdentifierInfo*/ nullptr, ParamType, /*TSI*/ nullptr, SC_None,
241         /*DefaultArg*/ nullptr);
242     ParamDecls.push_back(PD);
243   }
244   FnDeclForMangling->setParams(ParamDecls);
245 
246   // Get the mangled name.
247   std::string Name;
248   llvm::raw_string_ostream MangledNameStream(Name);
249   std::unique_ptr<MangleContext> Mangler(Context.createMangleContext());
250   Mangler->mangleName(FnDeclForMangling, MangledNameStream);
251   MangledNameStream.flush();
252 
253   return Name;
254 }
255 
EmitHLSLBuiltinExpr(unsigned BuiltinID,const CallExpr * E,ReturnValueSlot ReturnValue)256 Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
257                                             const CallExpr *E,
258                                             ReturnValueSlot ReturnValue) {
259   if (!getLangOpts().HLSL)
260     return nullptr;
261 
262   switch (BuiltinID) {
263   case Builtin::BI__builtin_hlsl_adduint64: {
264     Value *OpA = EmitScalarExpr(E->getArg(0));
265     Value *OpB = EmitScalarExpr(E->getArg(1));
266     QualType Arg0Ty = E->getArg(0)->getType();
267     uint64_t NumElements = Arg0Ty->castAs<VectorType>()->getNumElements();
268     assert(Arg0Ty == E->getArg(1)->getType() &&
269            "AddUint64 operand types must match");
270     assert(Arg0Ty->hasIntegerRepresentation() &&
271            "AddUint64 operands must have an integer representation");
272     assert((NumElements == 2 || NumElements == 4) &&
273            "AddUint64 operands must have 2 or 4 elements");
274 
275     llvm::Value *LowA;
276     llvm::Value *HighA;
277     llvm::Value *LowB;
278     llvm::Value *HighB;
279 
280     // Obtain low and high words of inputs A and B
281     if (NumElements == 2) {
282       LowA = Builder.CreateExtractElement(OpA, (uint64_t)0, "LowA");
283       HighA = Builder.CreateExtractElement(OpA, (uint64_t)1, "HighA");
284       LowB = Builder.CreateExtractElement(OpB, (uint64_t)0, "LowB");
285       HighB = Builder.CreateExtractElement(OpB, (uint64_t)1, "HighB");
286     } else {
287       LowA = Builder.CreateShuffleVector(OpA, {0, 2}, "LowA");
288       HighA = Builder.CreateShuffleVector(OpA, {1, 3}, "HighA");
289       LowB = Builder.CreateShuffleVector(OpB, {0, 2}, "LowB");
290       HighB = Builder.CreateShuffleVector(OpB, {1, 3}, "HighB");
291     }
292 
293     // Use an uadd_with_overflow to compute the sum of low words and obtain a
294     // carry value
295     llvm::Value *Carry;
296     llvm::Value *LowSum = EmitOverflowIntrinsic(
297         *this, Intrinsic::uadd_with_overflow, LowA, LowB, Carry);
298     llvm::Value *ZExtCarry =
299         Builder.CreateZExt(Carry, HighA->getType(), "CarryZExt");
300 
301     // Sum the high words and the carry
302     llvm::Value *HighSum = Builder.CreateAdd(HighA, HighB, "HighSum");
303     llvm::Value *HighSumPlusCarry =
304         Builder.CreateAdd(HighSum, ZExtCarry, "HighSumPlusCarry");
305 
306     if (NumElements == 4) {
307       return Builder.CreateShuffleVector(LowSum, HighSumPlusCarry, {0, 2, 1, 3},
308                                          "hlsl.AddUint64");
309     }
310 
311     llvm::Value *Result = PoisonValue::get(OpA->getType());
312     Result = Builder.CreateInsertElement(Result, LowSum, (uint64_t)0,
313                                          "hlsl.AddUint64.upto0");
314     Result = Builder.CreateInsertElement(Result, HighSumPlusCarry, (uint64_t)1,
315                                          "hlsl.AddUint64");
316     return Result;
317   }
318   case Builtin::BI__builtin_hlsl_resource_getpointer: {
319     Value *HandleOp = EmitScalarExpr(E->getArg(0));
320     Value *IndexOp = EmitScalarExpr(E->getArg(1));
321 
322     llvm::Type *RetTy = ConvertType(E->getType());
323     return Builder.CreateIntrinsic(
324         RetTy, CGM.getHLSLRuntime().getCreateResourceGetPointerIntrinsic(),
325         ArrayRef<Value *>{HandleOp, IndexOp});
326   }
327   case Builtin::BI__builtin_hlsl_resource_uninitializedhandle: {
328     llvm::Type *HandleTy = CGM.getTypes().ConvertType(E->getType());
329     return llvm::PoisonValue::get(HandleTy);
330   }
331   case Builtin::BI__builtin_hlsl_resource_handlefrombinding: {
332     llvm::Type *HandleTy = CGM.getTypes().ConvertType(E->getType());
333     Value *RegisterOp = EmitScalarExpr(E->getArg(1));
334     Value *SpaceOp = EmitScalarExpr(E->getArg(2));
335     Value *RangeOp = EmitScalarExpr(E->getArg(3));
336     Value *IndexOp = EmitScalarExpr(E->getArg(4));
337     Value *Name = EmitScalarExpr(E->getArg(5));
338     // FIXME: NonUniformResourceIndex bit is not yet implemented
339     // (llvm/llvm-project#135452)
340     Value *NonUniform =
341         llvm::ConstantInt::get(llvm::Type::getInt1Ty(getLLVMContext()), false);
342 
343     llvm::Intrinsic::ID IntrinsicID =
344         CGM.getHLSLRuntime().getCreateHandleFromBindingIntrinsic();
345     SmallVector<Value *> Args{SpaceOp, RegisterOp, RangeOp,
346                               IndexOp, NonUniform, Name};
347     return Builder.CreateIntrinsic(HandleTy, IntrinsicID, Args);
348   }
349   case Builtin::BI__builtin_hlsl_resource_handlefromimplicitbinding: {
350     llvm::Type *HandleTy = CGM.getTypes().ConvertType(E->getType());
351     Value *SpaceOp = EmitScalarExpr(E->getArg(1));
352     Value *RangeOp = EmitScalarExpr(E->getArg(2));
353     Value *IndexOp = EmitScalarExpr(E->getArg(3));
354     Value *OrderID = EmitScalarExpr(E->getArg(4));
355     Value *Name = EmitScalarExpr(E->getArg(5));
356     // FIXME: NonUniformResourceIndex bit is not yet implemented
357     // (llvm/llvm-project#135452)
358     Value *NonUniform =
359         llvm::ConstantInt::get(llvm::Type::getInt1Ty(getLLVMContext()), false);
360 
361     llvm::Intrinsic::ID IntrinsicID =
362         CGM.getHLSLRuntime().getCreateHandleFromImplicitBindingIntrinsic();
363     SmallVector<Value *> Args{OrderID, SpaceOp,    RangeOp,
364                               IndexOp, NonUniform, Name};
365     return Builder.CreateIntrinsic(HandleTy, IntrinsicID, Args);
366   }
367   case Builtin::BI__builtin_hlsl_all: {
368     Value *Op0 = EmitScalarExpr(E->getArg(0));
369     return Builder.CreateIntrinsic(
370         /*ReturnType=*/llvm::Type::getInt1Ty(getLLVMContext()),
371         CGM.getHLSLRuntime().getAllIntrinsic(), ArrayRef<Value *>{Op0}, nullptr,
372         "hlsl.all");
373   }
374   case Builtin::BI__builtin_hlsl_and: {
375     Value *Op0 = EmitScalarExpr(E->getArg(0));
376     Value *Op1 = EmitScalarExpr(E->getArg(1));
377     return Builder.CreateAnd(Op0, Op1, "hlsl.and");
378   }
379   case Builtin::BI__builtin_hlsl_or: {
380     Value *Op0 = EmitScalarExpr(E->getArg(0));
381     Value *Op1 = EmitScalarExpr(E->getArg(1));
382     return Builder.CreateOr(Op0, Op1, "hlsl.or");
383   }
384   case Builtin::BI__builtin_hlsl_any: {
385     Value *Op0 = EmitScalarExpr(E->getArg(0));
386     return Builder.CreateIntrinsic(
387         /*ReturnType=*/llvm::Type::getInt1Ty(getLLVMContext()),
388         CGM.getHLSLRuntime().getAnyIntrinsic(), ArrayRef<Value *>{Op0}, nullptr,
389         "hlsl.any");
390   }
391   case Builtin::BI__builtin_hlsl_asdouble:
392     return handleAsDoubleBuiltin(*this, E);
393   case Builtin::BI__builtin_hlsl_elementwise_clamp: {
394     Value *OpX = EmitScalarExpr(E->getArg(0));
395     Value *OpMin = EmitScalarExpr(E->getArg(1));
396     Value *OpMax = EmitScalarExpr(E->getArg(2));
397 
398     QualType Ty = E->getArg(0)->getType();
399     if (auto *VecTy = Ty->getAs<VectorType>())
400       Ty = VecTy->getElementType();
401 
402     Intrinsic::ID Intr;
403     if (Ty->isFloatingType()) {
404       Intr = CGM.getHLSLRuntime().getNClampIntrinsic();
405     } else if (Ty->isUnsignedIntegerType()) {
406       Intr = CGM.getHLSLRuntime().getUClampIntrinsic();
407     } else {
408       assert(Ty->isSignedIntegerType());
409       Intr = CGM.getHLSLRuntime().getSClampIntrinsic();
410     }
411     return Builder.CreateIntrinsic(
412         /*ReturnType=*/OpX->getType(), Intr,
413         ArrayRef<Value *>{OpX, OpMin, OpMax}, nullptr, "hlsl.clamp");
414   }
415   case Builtin::BI__builtin_hlsl_crossf16:
416   case Builtin::BI__builtin_hlsl_crossf32: {
417     Value *Op0 = EmitScalarExpr(E->getArg(0));
418     Value *Op1 = EmitScalarExpr(E->getArg(1));
419     assert(E->getArg(0)->getType()->hasFloatingRepresentation() &&
420            E->getArg(1)->getType()->hasFloatingRepresentation() &&
421            "cross operands must have a float representation");
422     // make sure each vector has exactly 3 elements
423     assert(
424         E->getArg(0)->getType()->castAs<VectorType>()->getNumElements() == 3 &&
425         E->getArg(1)->getType()->castAs<VectorType>()->getNumElements() == 3 &&
426         "input vectors must have 3 elements each");
427     return Builder.CreateIntrinsic(
428         /*ReturnType=*/Op0->getType(), CGM.getHLSLRuntime().getCrossIntrinsic(),
429         ArrayRef<Value *>{Op0, Op1}, nullptr, "hlsl.cross");
430   }
431   case Builtin::BI__builtin_hlsl_dot: {
432     Value *Op0 = EmitScalarExpr(E->getArg(0));
433     Value *Op1 = EmitScalarExpr(E->getArg(1));
434     llvm::Type *T0 = Op0->getType();
435     llvm::Type *T1 = Op1->getType();
436 
437     // If the arguments are scalars, just emit a multiply
438     if (!T0->isVectorTy() && !T1->isVectorTy()) {
439       if (T0->isFloatingPointTy())
440         return Builder.CreateFMul(Op0, Op1, "hlsl.dot");
441 
442       if (T0->isIntegerTy())
443         return Builder.CreateMul(Op0, Op1, "hlsl.dot");
444 
445       llvm_unreachable(
446           "Scalar dot product is only supported on ints and floats.");
447     }
448     // For vectors, validate types and emit the appropriate intrinsic
449     assert(CGM.getContext().hasSameUnqualifiedType(E->getArg(0)->getType(),
450                                                    E->getArg(1)->getType()) &&
451            "Dot product operands must have the same type.");
452 
453     auto *VecTy0 = E->getArg(0)->getType()->castAs<VectorType>();
454     assert(VecTy0 && "Dot product argument must be a vector.");
455 
456     return Builder.CreateIntrinsic(
457         /*ReturnType=*/T0->getScalarType(),
458         getDotProductIntrinsic(CGM.getHLSLRuntime(), VecTy0->getElementType()),
459         ArrayRef<Value *>{Op0, Op1}, nullptr, "hlsl.dot");
460   }
461   case Builtin::BI__builtin_hlsl_dot4add_i8packed: {
462     Value *X = EmitScalarExpr(E->getArg(0));
463     Value *Y = EmitScalarExpr(E->getArg(1));
464     Value *Acc = EmitScalarExpr(E->getArg(2));
465 
466     Intrinsic::ID ID = CGM.getHLSLRuntime().getDot4AddI8PackedIntrinsic();
467     // Note that the argument order disagrees between the builtin and the
468     // intrinsic here.
469     return Builder.CreateIntrinsic(
470         /*ReturnType=*/Acc->getType(), ID, ArrayRef<Value *>{Acc, X, Y},
471         nullptr, "hlsl.dot4add.i8packed");
472   }
473   case Builtin::BI__builtin_hlsl_dot4add_u8packed: {
474     Value *X = EmitScalarExpr(E->getArg(0));
475     Value *Y = EmitScalarExpr(E->getArg(1));
476     Value *Acc = EmitScalarExpr(E->getArg(2));
477 
478     Intrinsic::ID ID = CGM.getHLSLRuntime().getDot4AddU8PackedIntrinsic();
479     // Note that the argument order disagrees between the builtin and the
480     // intrinsic here.
481     return Builder.CreateIntrinsic(
482         /*ReturnType=*/Acc->getType(), ID, ArrayRef<Value *>{Acc, X, Y},
483         nullptr, "hlsl.dot4add.u8packed");
484   }
485   case Builtin::BI__builtin_hlsl_elementwise_firstbithigh: {
486     Value *X = EmitScalarExpr(E->getArg(0));
487 
488     return Builder.CreateIntrinsic(
489         /*ReturnType=*/ConvertType(E->getType()),
490         getFirstBitHighIntrinsic(CGM.getHLSLRuntime(), E->getArg(0)->getType()),
491         ArrayRef<Value *>{X}, nullptr, "hlsl.firstbithigh");
492   }
493   case Builtin::BI__builtin_hlsl_elementwise_firstbitlow: {
494     Value *X = EmitScalarExpr(E->getArg(0));
495 
496     return Builder.CreateIntrinsic(
497         /*ReturnType=*/ConvertType(E->getType()),
498         CGM.getHLSLRuntime().getFirstBitLowIntrinsic(), ArrayRef<Value *>{X},
499         nullptr, "hlsl.firstbitlow");
500   }
501   case Builtin::BI__builtin_hlsl_lerp: {
502     Value *X = EmitScalarExpr(E->getArg(0));
503     Value *Y = EmitScalarExpr(E->getArg(1));
504     Value *S = EmitScalarExpr(E->getArg(2));
505     if (!E->getArg(0)->getType()->hasFloatingRepresentation())
506       llvm_unreachable("lerp operand must have a float representation");
507     return Builder.CreateIntrinsic(
508         /*ReturnType=*/X->getType(), CGM.getHLSLRuntime().getLerpIntrinsic(),
509         ArrayRef<Value *>{X, Y, S}, nullptr, "hlsl.lerp");
510   }
511   case Builtin::BI__builtin_hlsl_normalize: {
512     Value *X = EmitScalarExpr(E->getArg(0));
513 
514     assert(E->getArg(0)->getType()->hasFloatingRepresentation() &&
515            "normalize operand must have a float representation");
516 
517     return Builder.CreateIntrinsic(
518         /*ReturnType=*/X->getType(),
519         CGM.getHLSLRuntime().getNormalizeIntrinsic(), ArrayRef<Value *>{X},
520         nullptr, "hlsl.normalize");
521   }
522   case Builtin::BI__builtin_hlsl_elementwise_degrees: {
523     Value *X = EmitScalarExpr(E->getArg(0));
524 
525     assert(E->getArg(0)->getType()->hasFloatingRepresentation() &&
526            "degree operand must have a float representation");
527 
528     return Builder.CreateIntrinsic(
529         /*ReturnType=*/X->getType(), CGM.getHLSLRuntime().getDegreesIntrinsic(),
530         ArrayRef<Value *>{X}, nullptr, "hlsl.degrees");
531   }
532   case Builtin::BI__builtin_hlsl_elementwise_frac: {
533     Value *Op0 = EmitScalarExpr(E->getArg(0));
534     if (!E->getArg(0)->getType()->hasFloatingRepresentation())
535       llvm_unreachable("frac operand must have a float representation");
536     return Builder.CreateIntrinsic(
537         /*ReturnType=*/Op0->getType(), CGM.getHLSLRuntime().getFracIntrinsic(),
538         ArrayRef<Value *>{Op0}, nullptr, "hlsl.frac");
539   }
540   case Builtin::BI__builtin_hlsl_elementwise_isinf: {
541     Value *Op0 = EmitScalarExpr(E->getArg(0));
542     llvm::Type *Xty = Op0->getType();
543     llvm::Type *retType = llvm::Type::getInt1Ty(this->getLLVMContext());
544     if (Xty->isVectorTy()) {
545       auto *XVecTy = E->getArg(0)->getType()->castAs<VectorType>();
546       retType = llvm::VectorType::get(
547           retType, ElementCount::getFixed(XVecTy->getNumElements()));
548     }
549     if (!E->getArg(0)->getType()->hasFloatingRepresentation())
550       llvm_unreachable("isinf operand must have a float representation");
551     return Builder.CreateIntrinsic(retType, Intrinsic::dx_isinf,
552                                    ArrayRef<Value *>{Op0}, nullptr, "dx.isinf");
553   }
554   case Builtin::BI__builtin_hlsl_mad: {
555     Value *M = EmitScalarExpr(E->getArg(0));
556     Value *A = EmitScalarExpr(E->getArg(1));
557     Value *B = EmitScalarExpr(E->getArg(2));
558     if (E->getArg(0)->getType()->hasFloatingRepresentation())
559       return Builder.CreateIntrinsic(
560           /*ReturnType*/ M->getType(), Intrinsic::fmuladd,
561           ArrayRef<Value *>{M, A, B}, nullptr, "hlsl.fmad");
562 
563     if (E->getArg(0)->getType()->hasSignedIntegerRepresentation()) {
564       if (CGM.getTarget().getTriple().getArch() == llvm::Triple::dxil)
565         return Builder.CreateIntrinsic(
566             /*ReturnType*/ M->getType(), Intrinsic::dx_imad,
567             ArrayRef<Value *>{M, A, B}, nullptr, "dx.imad");
568 
569       Value *Mul = Builder.CreateNSWMul(M, A);
570       return Builder.CreateNSWAdd(Mul, B);
571     }
572     assert(E->getArg(0)->getType()->hasUnsignedIntegerRepresentation());
573     if (CGM.getTarget().getTriple().getArch() == llvm::Triple::dxil)
574       return Builder.CreateIntrinsic(
575           /*ReturnType=*/M->getType(), Intrinsic::dx_umad,
576           ArrayRef<Value *>{M, A, B}, nullptr, "dx.umad");
577 
578     Value *Mul = Builder.CreateNUWMul(M, A);
579     return Builder.CreateNUWAdd(Mul, B);
580   }
581   case Builtin::BI__builtin_hlsl_elementwise_rcp: {
582     Value *Op0 = EmitScalarExpr(E->getArg(0));
583     if (!E->getArg(0)->getType()->hasFloatingRepresentation())
584       llvm_unreachable("rcp operand must have a float representation");
585     llvm::Type *Ty = Op0->getType();
586     llvm::Type *EltTy = Ty->getScalarType();
587     Constant *One = Ty->isVectorTy()
588                         ? ConstantVector::getSplat(
589                               ElementCount::getFixed(
590                                   cast<FixedVectorType>(Ty)->getNumElements()),
591                               ConstantFP::get(EltTy, 1.0))
592                         : ConstantFP::get(EltTy, 1.0);
593     return Builder.CreateFDiv(One, Op0, "hlsl.rcp");
594   }
595   case Builtin::BI__builtin_hlsl_elementwise_rsqrt: {
596     Value *Op0 = EmitScalarExpr(E->getArg(0));
597     if (!E->getArg(0)->getType()->hasFloatingRepresentation())
598       llvm_unreachable("rsqrt operand must have a float representation");
599     return Builder.CreateIntrinsic(
600         /*ReturnType=*/Op0->getType(), CGM.getHLSLRuntime().getRsqrtIntrinsic(),
601         ArrayRef<Value *>{Op0}, nullptr, "hlsl.rsqrt");
602   }
603   case Builtin::BI__builtin_hlsl_elementwise_saturate: {
604     Value *Op0 = EmitScalarExpr(E->getArg(0));
605     assert(E->getArg(0)->getType()->hasFloatingRepresentation() &&
606            "saturate operand must have a float representation");
607     return Builder.CreateIntrinsic(
608         /*ReturnType=*/Op0->getType(),
609         CGM.getHLSLRuntime().getSaturateIntrinsic(), ArrayRef<Value *>{Op0},
610         nullptr, "hlsl.saturate");
611   }
612   case Builtin::BI__builtin_hlsl_select: {
613     Value *OpCond = EmitScalarExpr(E->getArg(0));
614     RValue RValTrue = EmitAnyExpr(E->getArg(1));
615     Value *OpTrue =
616         RValTrue.isScalar()
617             ? RValTrue.getScalarVal()
618             : RValTrue.getAggregatePointer(E->getArg(1)->getType(), *this);
619     RValue RValFalse = EmitAnyExpr(E->getArg(2));
620     Value *OpFalse =
621         RValFalse.isScalar()
622             ? RValFalse.getScalarVal()
623             : RValFalse.getAggregatePointer(E->getArg(2)->getType(), *this);
624     if (auto *VTy = E->getType()->getAs<VectorType>()) {
625       if (!OpTrue->getType()->isVectorTy())
626         OpTrue =
627             Builder.CreateVectorSplat(VTy->getNumElements(), OpTrue, "splat");
628       if (!OpFalse->getType()->isVectorTy())
629         OpFalse =
630             Builder.CreateVectorSplat(VTy->getNumElements(), OpFalse, "splat");
631     }
632 
633     Value *SelectVal =
634         Builder.CreateSelect(OpCond, OpTrue, OpFalse, "hlsl.select");
635     if (!RValTrue.isScalar())
636       Builder.CreateStore(SelectVal, ReturnValue.getAddress(),
637                           ReturnValue.isVolatile());
638 
639     return SelectVal;
640   }
641   case Builtin::BI__builtin_hlsl_step: {
642     Value *Op0 = EmitScalarExpr(E->getArg(0));
643     Value *Op1 = EmitScalarExpr(E->getArg(1));
644     assert(E->getArg(0)->getType()->hasFloatingRepresentation() &&
645            E->getArg(1)->getType()->hasFloatingRepresentation() &&
646            "step operands must have a float representation");
647     return Builder.CreateIntrinsic(
648         /*ReturnType=*/Op0->getType(), CGM.getHLSLRuntime().getStepIntrinsic(),
649         ArrayRef<Value *>{Op0, Op1}, nullptr, "hlsl.step");
650   }
651   case Builtin::BI__builtin_hlsl_wave_active_all_true: {
652     Value *Op = EmitScalarExpr(E->getArg(0));
653     assert(Op->getType()->isIntegerTy(1) &&
654            "Intrinsic WaveActiveAllTrue operand must be a bool");
655 
656     Intrinsic::ID ID = CGM.getHLSLRuntime().getWaveActiveAllTrueIntrinsic();
657     return EmitRuntimeCall(
658         Intrinsic::getOrInsertDeclaration(&CGM.getModule(), ID), {Op});
659   }
660   case Builtin::BI__builtin_hlsl_wave_active_any_true: {
661     Value *Op = EmitScalarExpr(E->getArg(0));
662     assert(Op->getType()->isIntegerTy(1) &&
663            "Intrinsic WaveActiveAnyTrue operand must be a bool");
664 
665     Intrinsic::ID ID = CGM.getHLSLRuntime().getWaveActiveAnyTrueIntrinsic();
666     return EmitRuntimeCall(
667         Intrinsic::getOrInsertDeclaration(&CGM.getModule(), ID), {Op});
668   }
669   case Builtin::BI__builtin_hlsl_wave_active_count_bits: {
670     Value *OpExpr = EmitScalarExpr(E->getArg(0));
671     Intrinsic::ID ID = CGM.getHLSLRuntime().getWaveActiveCountBitsIntrinsic();
672     return EmitRuntimeCall(
673         Intrinsic::getOrInsertDeclaration(&CGM.getModule(), ID),
674         ArrayRef{OpExpr});
675   }
676   case Builtin::BI__builtin_hlsl_wave_active_sum: {
677     // Due to the use of variadic arguments, explicitly retreive argument
678     Value *OpExpr = EmitScalarExpr(E->getArg(0));
679     Intrinsic::ID IID = getWaveActiveSumIntrinsic(
680         getTarget().getTriple().getArch(), CGM.getHLSLRuntime(),
681         E->getArg(0)->getType());
682 
683     return EmitRuntimeCall(Intrinsic::getOrInsertDeclaration(
684                                &CGM.getModule(), IID, {OpExpr->getType()}),
685                            ArrayRef{OpExpr}, "hlsl.wave.active.sum");
686   }
687   case Builtin::BI__builtin_hlsl_wave_active_max: {
688     // Due to the use of variadic arguments, explicitly retrieve argument
689     Value *OpExpr = EmitScalarExpr(E->getArg(0));
690     Intrinsic::ID IID = getWaveActiveMaxIntrinsic(
691         getTarget().getTriple().getArch(), CGM.getHLSLRuntime(),
692         E->getArg(0)->getType());
693 
694     return EmitRuntimeCall(Intrinsic::getOrInsertDeclaration(
695                                &CGM.getModule(), IID, {OpExpr->getType()}),
696                            ArrayRef{OpExpr}, "hlsl.wave.active.max");
697   }
698   case Builtin::BI__builtin_hlsl_wave_get_lane_index: {
699     // We don't define a SPIR-V intrinsic, instead it is a SPIR-V built-in
700     // defined in SPIRVBuiltins.td. So instead we manually get the matching name
701     // for the DirectX intrinsic and the demangled builtin name
702     switch (CGM.getTarget().getTriple().getArch()) {
703     case llvm::Triple::dxil:
704       return EmitRuntimeCall(Intrinsic::getOrInsertDeclaration(
705           &CGM.getModule(), Intrinsic::dx_wave_getlaneindex));
706     case llvm::Triple::spirv:
707       return EmitRuntimeCall(CGM.CreateRuntimeFunction(
708           llvm::FunctionType::get(IntTy, {}, false),
709           "__hlsl_wave_get_lane_index", {}, false, true));
710     default:
711       llvm_unreachable(
712           "Intrinsic WaveGetLaneIndex not supported by target architecture");
713     }
714   }
715   case Builtin::BI__builtin_hlsl_wave_is_first_lane: {
716     Intrinsic::ID ID = CGM.getHLSLRuntime().getWaveIsFirstLaneIntrinsic();
717     return EmitRuntimeCall(
718         Intrinsic::getOrInsertDeclaration(&CGM.getModule(), ID));
719   }
720   case Builtin::BI__builtin_hlsl_wave_get_lane_count: {
721     Intrinsic::ID ID = CGM.getHLSLRuntime().getWaveGetLaneCountIntrinsic();
722     return EmitRuntimeCall(
723         Intrinsic::getOrInsertDeclaration(&CGM.getModule(), ID));
724   }
725   case Builtin::BI__builtin_hlsl_wave_read_lane_at: {
726     // Due to the use of variadic arguments we must explicitly retrieve them and
727     // create our function type.
728     Value *OpExpr = EmitScalarExpr(E->getArg(0));
729     Value *OpIndex = EmitScalarExpr(E->getArg(1));
730     return EmitRuntimeCall(
731         Intrinsic::getOrInsertDeclaration(
732             &CGM.getModule(), CGM.getHLSLRuntime().getWaveReadLaneAtIntrinsic(),
733             {OpExpr->getType()}),
734         ArrayRef{OpExpr, OpIndex}, "hlsl.wave.readlane");
735   }
736   case Builtin::BI__builtin_hlsl_elementwise_sign: {
737     auto *Arg0 = E->getArg(0);
738     Value *Op0 = EmitScalarExpr(Arg0);
739     llvm::Type *Xty = Op0->getType();
740     llvm::Type *retType = llvm::Type::getInt32Ty(this->getLLVMContext());
741     if (Xty->isVectorTy()) {
742       auto *XVecTy = Arg0->getType()->castAs<VectorType>();
743       retType = llvm::VectorType::get(
744           retType, ElementCount::getFixed(XVecTy->getNumElements()));
745     }
746     assert((Arg0->getType()->hasFloatingRepresentation() ||
747             Arg0->getType()->hasIntegerRepresentation()) &&
748            "sign operand must have a float or int representation");
749 
750     if (Arg0->getType()->hasUnsignedIntegerRepresentation()) {
751       Value *Cmp = Builder.CreateICmpEQ(Op0, ConstantInt::get(Xty, 0));
752       return Builder.CreateSelect(Cmp, ConstantInt::get(retType, 0),
753                                   ConstantInt::get(retType, 1), "hlsl.sign");
754     }
755 
756     return Builder.CreateIntrinsic(
757         retType, CGM.getHLSLRuntime().getSignIntrinsic(),
758         ArrayRef<Value *>{Op0}, nullptr, "hlsl.sign");
759   }
760   case Builtin::BI__builtin_hlsl_elementwise_radians: {
761     Value *Op0 = EmitScalarExpr(E->getArg(0));
762     assert(E->getArg(0)->getType()->hasFloatingRepresentation() &&
763            "radians operand must have a float representation");
764     return Builder.CreateIntrinsic(
765         /*ReturnType=*/Op0->getType(),
766         CGM.getHLSLRuntime().getRadiansIntrinsic(), ArrayRef<Value *>{Op0},
767         nullptr, "hlsl.radians");
768   }
769   case Builtin::BI__builtin_hlsl_buffer_update_counter: {
770     Value *ResHandle = EmitScalarExpr(E->getArg(0));
771     Value *Offset = EmitScalarExpr(E->getArg(1));
772     Value *OffsetI8 = Builder.CreateIntCast(Offset, Int8Ty, true);
773     return Builder.CreateIntrinsic(
774         /*ReturnType=*/Offset->getType(),
775         CGM.getHLSLRuntime().getBufferUpdateCounterIntrinsic(),
776         ArrayRef<Value *>{ResHandle, OffsetI8}, nullptr);
777   }
778   case Builtin::BI__builtin_hlsl_elementwise_splitdouble: {
779 
780     assert((E->getArg(0)->getType()->hasFloatingRepresentation() &&
781             E->getArg(1)->getType()->hasUnsignedIntegerRepresentation() &&
782             E->getArg(2)->getType()->hasUnsignedIntegerRepresentation()) &&
783            "asuint operands types mismatch");
784     return handleHlslSplitdouble(E, this);
785   }
786   case Builtin::BI__builtin_hlsl_elementwise_clip:
787     assert(E->getArg(0)->getType()->hasFloatingRepresentation() &&
788            "clip operands types mismatch");
789     return handleHlslClip(E, this);
790   case Builtin::BI__builtin_hlsl_group_memory_barrier_with_group_sync: {
791     Intrinsic::ID ID =
792         CGM.getHLSLRuntime().getGroupMemoryBarrierWithGroupSyncIntrinsic();
793     return EmitRuntimeCall(
794         Intrinsic::getOrInsertDeclaration(&CGM.getModule(), ID));
795   }
796   case Builtin::BI__builtin_get_spirv_spec_constant_bool:
797   case Builtin::BI__builtin_get_spirv_spec_constant_short:
798   case Builtin::BI__builtin_get_spirv_spec_constant_ushort:
799   case Builtin::BI__builtin_get_spirv_spec_constant_int:
800   case Builtin::BI__builtin_get_spirv_spec_constant_uint:
801   case Builtin::BI__builtin_get_spirv_spec_constant_longlong:
802   case Builtin::BI__builtin_get_spirv_spec_constant_ulonglong:
803   case Builtin::BI__builtin_get_spirv_spec_constant_half:
804   case Builtin::BI__builtin_get_spirv_spec_constant_float:
805   case Builtin::BI__builtin_get_spirv_spec_constant_double: {
806     llvm::Function *SpecConstantFn = getSpecConstantFunction(E->getType());
807     llvm::Value *SpecId = EmitScalarExpr(E->getArg(0));
808     llvm::Value *DefaultVal = EmitScalarExpr(E->getArg(1));
809     llvm::Value *Args[] = {SpecId, DefaultVal};
810     return Builder.CreateCall(SpecConstantFn, Args);
811   }
812   }
813   return nullptr;
814 }
815 
getSpecConstantFunction(const clang::QualType & SpecConstantType)816 llvm::Function *clang::CodeGen::CodeGenFunction::getSpecConstantFunction(
817     const clang::QualType &SpecConstantType) {
818 
819   // Find or create the declaration for the function.
820   llvm::Module *M = &CGM.getModule();
821   std::string MangledName =
822       getSpecConstantFunctionName(SpecConstantType, getContext());
823   llvm::Function *SpecConstantFn = M->getFunction(MangledName);
824 
825   if (!SpecConstantFn) {
826     llvm::Type *IntType = ConvertType(getContext().IntTy);
827     llvm::Type *RetTy = ConvertType(SpecConstantType);
828     llvm::Type *ArgTypes[] = {IntType, RetTy};
829     llvm::FunctionType *FnTy = llvm::FunctionType::get(RetTy, ArgTypes, false);
830     SpecConstantFn = llvm::Function::Create(
831         FnTy, llvm::GlobalValue::ExternalLinkage, MangledName, M);
832   }
833   return SpecConstantFn;
834 }
835