1 //===------- CGHLSLBuiltins.cpp - Emit LLVM Code for HLSL builtins --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This contains code to emit HLSL Builtin calls as LLVM code.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "CGBuiltin.h"
14 #include "CGHLSLRuntime.h"
15 #include "CodeGenFunction.h"
16
17 using namespace clang;
18 using namespace CodeGen;
19 using namespace llvm;
20
handleAsDoubleBuiltin(CodeGenFunction & CGF,const CallExpr * E)21 static Value *handleAsDoubleBuiltin(CodeGenFunction &CGF, const CallExpr *E) {
22 assert((E->getArg(0)->getType()->hasUnsignedIntegerRepresentation() &&
23 E->getArg(1)->getType()->hasUnsignedIntegerRepresentation()) &&
24 "asdouble operands types mismatch");
25 Value *OpLowBits = CGF.EmitScalarExpr(E->getArg(0));
26 Value *OpHighBits = CGF.EmitScalarExpr(E->getArg(1));
27
28 llvm::Type *ResultType = CGF.DoubleTy;
29 int N = 1;
30 if (auto *VTy = E->getArg(0)->getType()->getAs<clang::VectorType>()) {
31 N = VTy->getNumElements();
32 ResultType = llvm::FixedVectorType::get(CGF.DoubleTy, N);
33 }
34
35 if (CGF.CGM.getTarget().getTriple().isDXIL())
36 return CGF.Builder.CreateIntrinsic(
37 /*ReturnType=*/ResultType, Intrinsic::dx_asdouble,
38 {OpLowBits, OpHighBits}, nullptr, "hlsl.asdouble");
39
40 if (!E->getArg(0)->getType()->isVectorType()) {
41 OpLowBits = CGF.Builder.CreateVectorSplat(1, OpLowBits);
42 OpHighBits = CGF.Builder.CreateVectorSplat(1, OpHighBits);
43 }
44
45 llvm::SmallVector<int> Mask;
46 for (int i = 0; i < N; i++) {
47 Mask.push_back(i);
48 Mask.push_back(i + N);
49 }
50
51 Value *BitVec = CGF.Builder.CreateShuffleVector(OpLowBits, OpHighBits, Mask);
52
53 return CGF.Builder.CreateBitCast(BitVec, ResultType);
54 }
55
handleHlslClip(const CallExpr * E,CodeGenFunction * CGF)56 static Value *handleHlslClip(const CallExpr *E, CodeGenFunction *CGF) {
57 Value *Op0 = CGF->EmitScalarExpr(E->getArg(0));
58
59 Constant *FZeroConst = ConstantFP::getZero(CGF->FloatTy);
60 Value *CMP;
61 Value *LastInstr;
62
63 if (const auto *VecTy = E->getArg(0)->getType()->getAs<clang::VectorType>()) {
64 FZeroConst = ConstantVector::getSplat(
65 ElementCount::getFixed(VecTy->getNumElements()), FZeroConst);
66 auto *FCompInst = CGF->Builder.CreateFCmpOLT(Op0, FZeroConst);
67 CMP = CGF->Builder.CreateIntrinsic(
68 CGF->Builder.getInt1Ty(), CGF->CGM.getHLSLRuntime().getAnyIntrinsic(),
69 {FCompInst});
70 } else {
71 CMP = CGF->Builder.CreateFCmpOLT(Op0, FZeroConst);
72 }
73
74 if (CGF->CGM.getTarget().getTriple().isDXIL()) {
75 LastInstr = CGF->Builder.CreateIntrinsic(Intrinsic::dx_discard, {CMP});
76 } else if (CGF->CGM.getTarget().getTriple().isSPIRV()) {
77 BasicBlock *LT0 = CGF->createBasicBlock("lt0", CGF->CurFn);
78 BasicBlock *End = CGF->createBasicBlock("end", CGF->CurFn);
79
80 CGF->Builder.CreateCondBr(CMP, LT0, End);
81
82 CGF->Builder.SetInsertPoint(LT0);
83
84 CGF->Builder.CreateIntrinsic(Intrinsic::spv_discard, {});
85
86 LastInstr = CGF->Builder.CreateBr(End);
87 CGF->Builder.SetInsertPoint(End);
88 } else {
89 llvm_unreachable("Backend Codegen not supported.");
90 }
91
92 return LastInstr;
93 }
94
handleHlslSplitdouble(const CallExpr * E,CodeGenFunction * CGF)95 static Value *handleHlslSplitdouble(const CallExpr *E, CodeGenFunction *CGF) {
96 Value *Op0 = CGF->EmitScalarExpr(E->getArg(0));
97 const auto *OutArg1 = dyn_cast<HLSLOutArgExpr>(E->getArg(1));
98 const auto *OutArg2 = dyn_cast<HLSLOutArgExpr>(E->getArg(2));
99
100 CallArgList Args;
101 LValue Op1TmpLValue =
102 CGF->EmitHLSLOutArgExpr(OutArg1, Args, OutArg1->getType());
103 LValue Op2TmpLValue =
104 CGF->EmitHLSLOutArgExpr(OutArg2, Args, OutArg2->getType());
105
106 if (CGF->getTarget().getCXXABI().areArgsDestroyedLeftToRightInCallee())
107 Args.reverseWritebacks();
108
109 Value *LowBits = nullptr;
110 Value *HighBits = nullptr;
111
112 if (CGF->CGM.getTarget().getTriple().isDXIL()) {
113 llvm::Type *RetElementTy = CGF->Int32Ty;
114 if (auto *Op0VecTy = E->getArg(0)->getType()->getAs<clang::VectorType>())
115 RetElementTy = llvm::VectorType::get(
116 CGF->Int32Ty, ElementCount::getFixed(Op0VecTy->getNumElements()));
117 auto *RetTy = llvm::StructType::get(RetElementTy, RetElementTy);
118
119 CallInst *CI = CGF->Builder.CreateIntrinsic(
120 RetTy, Intrinsic::dx_splitdouble, {Op0}, nullptr, "hlsl.splitdouble");
121
122 LowBits = CGF->Builder.CreateExtractValue(CI, 0);
123 HighBits = CGF->Builder.CreateExtractValue(CI, 1);
124 } else {
125 // For Non DXIL targets we generate the instructions.
126
127 if (!Op0->getType()->isVectorTy()) {
128 FixedVectorType *DestTy = FixedVectorType::get(CGF->Int32Ty, 2);
129 Value *Bitcast = CGF->Builder.CreateBitCast(Op0, DestTy);
130
131 LowBits = CGF->Builder.CreateExtractElement(Bitcast, (uint64_t)0);
132 HighBits = CGF->Builder.CreateExtractElement(Bitcast, 1);
133 } else {
134 int NumElements = 1;
135 if (const auto *VecTy =
136 E->getArg(0)->getType()->getAs<clang::VectorType>())
137 NumElements = VecTy->getNumElements();
138
139 FixedVectorType *Uint32VecTy =
140 FixedVectorType::get(CGF->Int32Ty, NumElements * 2);
141 Value *Uint32Vec = CGF->Builder.CreateBitCast(Op0, Uint32VecTy);
142 if (NumElements == 1) {
143 LowBits = CGF->Builder.CreateExtractElement(Uint32Vec, (uint64_t)0);
144 HighBits = CGF->Builder.CreateExtractElement(Uint32Vec, 1);
145 } else {
146 SmallVector<int> EvenMask, OddMask;
147 for (int I = 0, E = NumElements; I != E; ++I) {
148 EvenMask.push_back(I * 2);
149 OddMask.push_back(I * 2 + 1);
150 }
151 LowBits = CGF->Builder.CreateShuffleVector(Uint32Vec, EvenMask);
152 HighBits = CGF->Builder.CreateShuffleVector(Uint32Vec, OddMask);
153 }
154 }
155 }
156 CGF->Builder.CreateStore(LowBits, Op1TmpLValue.getAddress());
157 auto *LastInst =
158 CGF->Builder.CreateStore(HighBits, Op2TmpLValue.getAddress());
159 CGF->EmitWritebacks(Args);
160 return LastInst;
161 }
162
163 // Return dot product intrinsic that corresponds to the QT scalar type
getDotProductIntrinsic(CGHLSLRuntime & RT,QualType QT)164 static Intrinsic::ID getDotProductIntrinsic(CGHLSLRuntime &RT, QualType QT) {
165 if (QT->isFloatingType())
166 return RT.getFDotIntrinsic();
167 if (QT->isSignedIntegerType())
168 return RT.getSDotIntrinsic();
169 assert(QT->isUnsignedIntegerType());
170 return RT.getUDotIntrinsic();
171 }
172
getFirstBitHighIntrinsic(CGHLSLRuntime & RT,QualType QT)173 static Intrinsic::ID getFirstBitHighIntrinsic(CGHLSLRuntime &RT, QualType QT) {
174 if (QT->hasSignedIntegerRepresentation()) {
175 return RT.getFirstBitSHighIntrinsic();
176 }
177
178 assert(QT->hasUnsignedIntegerRepresentation());
179 return RT.getFirstBitUHighIntrinsic();
180 }
181
182 // Return wave active sum that corresponds to the QT scalar type
getWaveActiveSumIntrinsic(llvm::Triple::ArchType Arch,CGHLSLRuntime & RT,QualType QT)183 static Intrinsic::ID getWaveActiveSumIntrinsic(llvm::Triple::ArchType Arch,
184 CGHLSLRuntime &RT, QualType QT) {
185 switch (Arch) {
186 case llvm::Triple::spirv:
187 return Intrinsic::spv_wave_reduce_sum;
188 case llvm::Triple::dxil: {
189 if (QT->isUnsignedIntegerType())
190 return Intrinsic::dx_wave_reduce_usum;
191 return Intrinsic::dx_wave_reduce_sum;
192 }
193 default:
194 llvm_unreachable("Intrinsic WaveActiveSum"
195 " not supported by target architecture");
196 }
197 }
198
199 // Return wave active sum that corresponds to the QT scalar type
getWaveActiveMaxIntrinsic(llvm::Triple::ArchType Arch,CGHLSLRuntime & RT,QualType QT)200 static Intrinsic::ID getWaveActiveMaxIntrinsic(llvm::Triple::ArchType Arch,
201 CGHLSLRuntime &RT, QualType QT) {
202 switch (Arch) {
203 case llvm::Triple::spirv:
204 if (QT->isUnsignedIntegerType())
205 return Intrinsic::spv_wave_reduce_umax;
206 return Intrinsic::spv_wave_reduce_max;
207 case llvm::Triple::dxil: {
208 if (QT->isUnsignedIntegerType())
209 return Intrinsic::dx_wave_reduce_umax;
210 return Intrinsic::dx_wave_reduce_max;
211 }
212 default:
213 llvm_unreachable("Intrinsic WaveActiveMax"
214 " not supported by target architecture");
215 }
216 }
217
218 // Returns the mangled name for a builtin function that the SPIR-V backend
219 // will expand into a spec Constant.
getSpecConstantFunctionName(clang::QualType SpecConstantType,ASTContext & Context)220 static std::string getSpecConstantFunctionName(clang::QualType SpecConstantType,
221 ASTContext &Context) {
222 // The parameter types for our conceptual intrinsic function.
223 QualType ClangParamTypes[] = {Context.IntTy, SpecConstantType};
224
225 // Create a temporary FunctionDecl for the builtin fuction. It won't be
226 // added to the AST.
227 FunctionProtoType::ExtProtoInfo EPI;
228 QualType FnType =
229 Context.getFunctionType(SpecConstantType, ClangParamTypes, EPI);
230 DeclarationName FuncName = &Context.Idents.get("__spirv_SpecConstant");
231 FunctionDecl *FnDeclForMangling = FunctionDecl::Create(
232 Context, Context.getTranslationUnitDecl(), SourceLocation(),
233 SourceLocation(), FuncName, FnType, /*TSI=*/nullptr, SC_Extern);
234
235 // Attach the created parameter declarations to the function declaration.
236 SmallVector<ParmVarDecl *, 2> ParamDecls;
237 for (QualType ParamType : ClangParamTypes) {
238 ParmVarDecl *PD = ParmVarDecl::Create(
239 Context, FnDeclForMangling, SourceLocation(), SourceLocation(),
240 /*IdentifierInfo*/ nullptr, ParamType, /*TSI*/ nullptr, SC_None,
241 /*DefaultArg*/ nullptr);
242 ParamDecls.push_back(PD);
243 }
244 FnDeclForMangling->setParams(ParamDecls);
245
246 // Get the mangled name.
247 std::string Name;
248 llvm::raw_string_ostream MangledNameStream(Name);
249 std::unique_ptr<MangleContext> Mangler(Context.createMangleContext());
250 Mangler->mangleName(FnDeclForMangling, MangledNameStream);
251 MangledNameStream.flush();
252
253 return Name;
254 }
255
EmitHLSLBuiltinExpr(unsigned BuiltinID,const CallExpr * E,ReturnValueSlot ReturnValue)256 Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
257 const CallExpr *E,
258 ReturnValueSlot ReturnValue) {
259 if (!getLangOpts().HLSL)
260 return nullptr;
261
262 switch (BuiltinID) {
263 case Builtin::BI__builtin_hlsl_adduint64: {
264 Value *OpA = EmitScalarExpr(E->getArg(0));
265 Value *OpB = EmitScalarExpr(E->getArg(1));
266 QualType Arg0Ty = E->getArg(0)->getType();
267 uint64_t NumElements = Arg0Ty->castAs<VectorType>()->getNumElements();
268 assert(Arg0Ty == E->getArg(1)->getType() &&
269 "AddUint64 operand types must match");
270 assert(Arg0Ty->hasIntegerRepresentation() &&
271 "AddUint64 operands must have an integer representation");
272 assert((NumElements == 2 || NumElements == 4) &&
273 "AddUint64 operands must have 2 or 4 elements");
274
275 llvm::Value *LowA;
276 llvm::Value *HighA;
277 llvm::Value *LowB;
278 llvm::Value *HighB;
279
280 // Obtain low and high words of inputs A and B
281 if (NumElements == 2) {
282 LowA = Builder.CreateExtractElement(OpA, (uint64_t)0, "LowA");
283 HighA = Builder.CreateExtractElement(OpA, (uint64_t)1, "HighA");
284 LowB = Builder.CreateExtractElement(OpB, (uint64_t)0, "LowB");
285 HighB = Builder.CreateExtractElement(OpB, (uint64_t)1, "HighB");
286 } else {
287 LowA = Builder.CreateShuffleVector(OpA, {0, 2}, "LowA");
288 HighA = Builder.CreateShuffleVector(OpA, {1, 3}, "HighA");
289 LowB = Builder.CreateShuffleVector(OpB, {0, 2}, "LowB");
290 HighB = Builder.CreateShuffleVector(OpB, {1, 3}, "HighB");
291 }
292
293 // Use an uadd_with_overflow to compute the sum of low words and obtain a
294 // carry value
295 llvm::Value *Carry;
296 llvm::Value *LowSum = EmitOverflowIntrinsic(
297 *this, Intrinsic::uadd_with_overflow, LowA, LowB, Carry);
298 llvm::Value *ZExtCarry =
299 Builder.CreateZExt(Carry, HighA->getType(), "CarryZExt");
300
301 // Sum the high words and the carry
302 llvm::Value *HighSum = Builder.CreateAdd(HighA, HighB, "HighSum");
303 llvm::Value *HighSumPlusCarry =
304 Builder.CreateAdd(HighSum, ZExtCarry, "HighSumPlusCarry");
305
306 if (NumElements == 4) {
307 return Builder.CreateShuffleVector(LowSum, HighSumPlusCarry, {0, 2, 1, 3},
308 "hlsl.AddUint64");
309 }
310
311 llvm::Value *Result = PoisonValue::get(OpA->getType());
312 Result = Builder.CreateInsertElement(Result, LowSum, (uint64_t)0,
313 "hlsl.AddUint64.upto0");
314 Result = Builder.CreateInsertElement(Result, HighSumPlusCarry, (uint64_t)1,
315 "hlsl.AddUint64");
316 return Result;
317 }
318 case Builtin::BI__builtin_hlsl_resource_getpointer: {
319 Value *HandleOp = EmitScalarExpr(E->getArg(0));
320 Value *IndexOp = EmitScalarExpr(E->getArg(1));
321
322 llvm::Type *RetTy = ConvertType(E->getType());
323 return Builder.CreateIntrinsic(
324 RetTy, CGM.getHLSLRuntime().getCreateResourceGetPointerIntrinsic(),
325 ArrayRef<Value *>{HandleOp, IndexOp});
326 }
327 case Builtin::BI__builtin_hlsl_resource_uninitializedhandle: {
328 llvm::Type *HandleTy = CGM.getTypes().ConvertType(E->getType());
329 return llvm::PoisonValue::get(HandleTy);
330 }
331 case Builtin::BI__builtin_hlsl_resource_handlefrombinding: {
332 llvm::Type *HandleTy = CGM.getTypes().ConvertType(E->getType());
333 Value *RegisterOp = EmitScalarExpr(E->getArg(1));
334 Value *SpaceOp = EmitScalarExpr(E->getArg(2));
335 Value *RangeOp = EmitScalarExpr(E->getArg(3));
336 Value *IndexOp = EmitScalarExpr(E->getArg(4));
337 Value *Name = EmitScalarExpr(E->getArg(5));
338 // FIXME: NonUniformResourceIndex bit is not yet implemented
339 // (llvm/llvm-project#135452)
340 Value *NonUniform =
341 llvm::ConstantInt::get(llvm::Type::getInt1Ty(getLLVMContext()), false);
342
343 llvm::Intrinsic::ID IntrinsicID =
344 CGM.getHLSLRuntime().getCreateHandleFromBindingIntrinsic();
345 SmallVector<Value *> Args{SpaceOp, RegisterOp, RangeOp,
346 IndexOp, NonUniform, Name};
347 return Builder.CreateIntrinsic(HandleTy, IntrinsicID, Args);
348 }
349 case Builtin::BI__builtin_hlsl_resource_handlefromimplicitbinding: {
350 llvm::Type *HandleTy = CGM.getTypes().ConvertType(E->getType());
351 Value *SpaceOp = EmitScalarExpr(E->getArg(1));
352 Value *RangeOp = EmitScalarExpr(E->getArg(2));
353 Value *IndexOp = EmitScalarExpr(E->getArg(3));
354 Value *OrderID = EmitScalarExpr(E->getArg(4));
355 Value *Name = EmitScalarExpr(E->getArg(5));
356 // FIXME: NonUniformResourceIndex bit is not yet implemented
357 // (llvm/llvm-project#135452)
358 Value *NonUniform =
359 llvm::ConstantInt::get(llvm::Type::getInt1Ty(getLLVMContext()), false);
360
361 llvm::Intrinsic::ID IntrinsicID =
362 CGM.getHLSLRuntime().getCreateHandleFromImplicitBindingIntrinsic();
363 SmallVector<Value *> Args{OrderID, SpaceOp, RangeOp,
364 IndexOp, NonUniform, Name};
365 return Builder.CreateIntrinsic(HandleTy, IntrinsicID, Args);
366 }
367 case Builtin::BI__builtin_hlsl_all: {
368 Value *Op0 = EmitScalarExpr(E->getArg(0));
369 return Builder.CreateIntrinsic(
370 /*ReturnType=*/llvm::Type::getInt1Ty(getLLVMContext()),
371 CGM.getHLSLRuntime().getAllIntrinsic(), ArrayRef<Value *>{Op0}, nullptr,
372 "hlsl.all");
373 }
374 case Builtin::BI__builtin_hlsl_and: {
375 Value *Op0 = EmitScalarExpr(E->getArg(0));
376 Value *Op1 = EmitScalarExpr(E->getArg(1));
377 return Builder.CreateAnd(Op0, Op1, "hlsl.and");
378 }
379 case Builtin::BI__builtin_hlsl_or: {
380 Value *Op0 = EmitScalarExpr(E->getArg(0));
381 Value *Op1 = EmitScalarExpr(E->getArg(1));
382 return Builder.CreateOr(Op0, Op1, "hlsl.or");
383 }
384 case Builtin::BI__builtin_hlsl_any: {
385 Value *Op0 = EmitScalarExpr(E->getArg(0));
386 return Builder.CreateIntrinsic(
387 /*ReturnType=*/llvm::Type::getInt1Ty(getLLVMContext()),
388 CGM.getHLSLRuntime().getAnyIntrinsic(), ArrayRef<Value *>{Op0}, nullptr,
389 "hlsl.any");
390 }
391 case Builtin::BI__builtin_hlsl_asdouble:
392 return handleAsDoubleBuiltin(*this, E);
393 case Builtin::BI__builtin_hlsl_elementwise_clamp: {
394 Value *OpX = EmitScalarExpr(E->getArg(0));
395 Value *OpMin = EmitScalarExpr(E->getArg(1));
396 Value *OpMax = EmitScalarExpr(E->getArg(2));
397
398 QualType Ty = E->getArg(0)->getType();
399 if (auto *VecTy = Ty->getAs<VectorType>())
400 Ty = VecTy->getElementType();
401
402 Intrinsic::ID Intr;
403 if (Ty->isFloatingType()) {
404 Intr = CGM.getHLSLRuntime().getNClampIntrinsic();
405 } else if (Ty->isUnsignedIntegerType()) {
406 Intr = CGM.getHLSLRuntime().getUClampIntrinsic();
407 } else {
408 assert(Ty->isSignedIntegerType());
409 Intr = CGM.getHLSLRuntime().getSClampIntrinsic();
410 }
411 return Builder.CreateIntrinsic(
412 /*ReturnType=*/OpX->getType(), Intr,
413 ArrayRef<Value *>{OpX, OpMin, OpMax}, nullptr, "hlsl.clamp");
414 }
415 case Builtin::BI__builtin_hlsl_crossf16:
416 case Builtin::BI__builtin_hlsl_crossf32: {
417 Value *Op0 = EmitScalarExpr(E->getArg(0));
418 Value *Op1 = EmitScalarExpr(E->getArg(1));
419 assert(E->getArg(0)->getType()->hasFloatingRepresentation() &&
420 E->getArg(1)->getType()->hasFloatingRepresentation() &&
421 "cross operands must have a float representation");
422 // make sure each vector has exactly 3 elements
423 assert(
424 E->getArg(0)->getType()->castAs<VectorType>()->getNumElements() == 3 &&
425 E->getArg(1)->getType()->castAs<VectorType>()->getNumElements() == 3 &&
426 "input vectors must have 3 elements each");
427 return Builder.CreateIntrinsic(
428 /*ReturnType=*/Op0->getType(), CGM.getHLSLRuntime().getCrossIntrinsic(),
429 ArrayRef<Value *>{Op0, Op1}, nullptr, "hlsl.cross");
430 }
431 case Builtin::BI__builtin_hlsl_dot: {
432 Value *Op0 = EmitScalarExpr(E->getArg(0));
433 Value *Op1 = EmitScalarExpr(E->getArg(1));
434 llvm::Type *T0 = Op0->getType();
435 llvm::Type *T1 = Op1->getType();
436
437 // If the arguments are scalars, just emit a multiply
438 if (!T0->isVectorTy() && !T1->isVectorTy()) {
439 if (T0->isFloatingPointTy())
440 return Builder.CreateFMul(Op0, Op1, "hlsl.dot");
441
442 if (T0->isIntegerTy())
443 return Builder.CreateMul(Op0, Op1, "hlsl.dot");
444
445 llvm_unreachable(
446 "Scalar dot product is only supported on ints and floats.");
447 }
448 // For vectors, validate types and emit the appropriate intrinsic
449 assert(CGM.getContext().hasSameUnqualifiedType(E->getArg(0)->getType(),
450 E->getArg(1)->getType()) &&
451 "Dot product operands must have the same type.");
452
453 auto *VecTy0 = E->getArg(0)->getType()->castAs<VectorType>();
454 assert(VecTy0 && "Dot product argument must be a vector.");
455
456 return Builder.CreateIntrinsic(
457 /*ReturnType=*/T0->getScalarType(),
458 getDotProductIntrinsic(CGM.getHLSLRuntime(), VecTy0->getElementType()),
459 ArrayRef<Value *>{Op0, Op1}, nullptr, "hlsl.dot");
460 }
461 case Builtin::BI__builtin_hlsl_dot4add_i8packed: {
462 Value *X = EmitScalarExpr(E->getArg(0));
463 Value *Y = EmitScalarExpr(E->getArg(1));
464 Value *Acc = EmitScalarExpr(E->getArg(2));
465
466 Intrinsic::ID ID = CGM.getHLSLRuntime().getDot4AddI8PackedIntrinsic();
467 // Note that the argument order disagrees between the builtin and the
468 // intrinsic here.
469 return Builder.CreateIntrinsic(
470 /*ReturnType=*/Acc->getType(), ID, ArrayRef<Value *>{Acc, X, Y},
471 nullptr, "hlsl.dot4add.i8packed");
472 }
473 case Builtin::BI__builtin_hlsl_dot4add_u8packed: {
474 Value *X = EmitScalarExpr(E->getArg(0));
475 Value *Y = EmitScalarExpr(E->getArg(1));
476 Value *Acc = EmitScalarExpr(E->getArg(2));
477
478 Intrinsic::ID ID = CGM.getHLSLRuntime().getDot4AddU8PackedIntrinsic();
479 // Note that the argument order disagrees between the builtin and the
480 // intrinsic here.
481 return Builder.CreateIntrinsic(
482 /*ReturnType=*/Acc->getType(), ID, ArrayRef<Value *>{Acc, X, Y},
483 nullptr, "hlsl.dot4add.u8packed");
484 }
485 case Builtin::BI__builtin_hlsl_elementwise_firstbithigh: {
486 Value *X = EmitScalarExpr(E->getArg(0));
487
488 return Builder.CreateIntrinsic(
489 /*ReturnType=*/ConvertType(E->getType()),
490 getFirstBitHighIntrinsic(CGM.getHLSLRuntime(), E->getArg(0)->getType()),
491 ArrayRef<Value *>{X}, nullptr, "hlsl.firstbithigh");
492 }
493 case Builtin::BI__builtin_hlsl_elementwise_firstbitlow: {
494 Value *X = EmitScalarExpr(E->getArg(0));
495
496 return Builder.CreateIntrinsic(
497 /*ReturnType=*/ConvertType(E->getType()),
498 CGM.getHLSLRuntime().getFirstBitLowIntrinsic(), ArrayRef<Value *>{X},
499 nullptr, "hlsl.firstbitlow");
500 }
501 case Builtin::BI__builtin_hlsl_lerp: {
502 Value *X = EmitScalarExpr(E->getArg(0));
503 Value *Y = EmitScalarExpr(E->getArg(1));
504 Value *S = EmitScalarExpr(E->getArg(2));
505 if (!E->getArg(0)->getType()->hasFloatingRepresentation())
506 llvm_unreachable("lerp operand must have a float representation");
507 return Builder.CreateIntrinsic(
508 /*ReturnType=*/X->getType(), CGM.getHLSLRuntime().getLerpIntrinsic(),
509 ArrayRef<Value *>{X, Y, S}, nullptr, "hlsl.lerp");
510 }
511 case Builtin::BI__builtin_hlsl_normalize: {
512 Value *X = EmitScalarExpr(E->getArg(0));
513
514 assert(E->getArg(0)->getType()->hasFloatingRepresentation() &&
515 "normalize operand must have a float representation");
516
517 return Builder.CreateIntrinsic(
518 /*ReturnType=*/X->getType(),
519 CGM.getHLSLRuntime().getNormalizeIntrinsic(), ArrayRef<Value *>{X},
520 nullptr, "hlsl.normalize");
521 }
522 case Builtin::BI__builtin_hlsl_elementwise_degrees: {
523 Value *X = EmitScalarExpr(E->getArg(0));
524
525 assert(E->getArg(0)->getType()->hasFloatingRepresentation() &&
526 "degree operand must have a float representation");
527
528 return Builder.CreateIntrinsic(
529 /*ReturnType=*/X->getType(), CGM.getHLSLRuntime().getDegreesIntrinsic(),
530 ArrayRef<Value *>{X}, nullptr, "hlsl.degrees");
531 }
532 case Builtin::BI__builtin_hlsl_elementwise_frac: {
533 Value *Op0 = EmitScalarExpr(E->getArg(0));
534 if (!E->getArg(0)->getType()->hasFloatingRepresentation())
535 llvm_unreachable("frac operand must have a float representation");
536 return Builder.CreateIntrinsic(
537 /*ReturnType=*/Op0->getType(), CGM.getHLSLRuntime().getFracIntrinsic(),
538 ArrayRef<Value *>{Op0}, nullptr, "hlsl.frac");
539 }
540 case Builtin::BI__builtin_hlsl_elementwise_isinf: {
541 Value *Op0 = EmitScalarExpr(E->getArg(0));
542 llvm::Type *Xty = Op0->getType();
543 llvm::Type *retType = llvm::Type::getInt1Ty(this->getLLVMContext());
544 if (Xty->isVectorTy()) {
545 auto *XVecTy = E->getArg(0)->getType()->castAs<VectorType>();
546 retType = llvm::VectorType::get(
547 retType, ElementCount::getFixed(XVecTy->getNumElements()));
548 }
549 if (!E->getArg(0)->getType()->hasFloatingRepresentation())
550 llvm_unreachable("isinf operand must have a float representation");
551 return Builder.CreateIntrinsic(retType, Intrinsic::dx_isinf,
552 ArrayRef<Value *>{Op0}, nullptr, "dx.isinf");
553 }
554 case Builtin::BI__builtin_hlsl_mad: {
555 Value *M = EmitScalarExpr(E->getArg(0));
556 Value *A = EmitScalarExpr(E->getArg(1));
557 Value *B = EmitScalarExpr(E->getArg(2));
558 if (E->getArg(0)->getType()->hasFloatingRepresentation())
559 return Builder.CreateIntrinsic(
560 /*ReturnType*/ M->getType(), Intrinsic::fmuladd,
561 ArrayRef<Value *>{M, A, B}, nullptr, "hlsl.fmad");
562
563 if (E->getArg(0)->getType()->hasSignedIntegerRepresentation()) {
564 if (CGM.getTarget().getTriple().getArch() == llvm::Triple::dxil)
565 return Builder.CreateIntrinsic(
566 /*ReturnType*/ M->getType(), Intrinsic::dx_imad,
567 ArrayRef<Value *>{M, A, B}, nullptr, "dx.imad");
568
569 Value *Mul = Builder.CreateNSWMul(M, A);
570 return Builder.CreateNSWAdd(Mul, B);
571 }
572 assert(E->getArg(0)->getType()->hasUnsignedIntegerRepresentation());
573 if (CGM.getTarget().getTriple().getArch() == llvm::Triple::dxil)
574 return Builder.CreateIntrinsic(
575 /*ReturnType=*/M->getType(), Intrinsic::dx_umad,
576 ArrayRef<Value *>{M, A, B}, nullptr, "dx.umad");
577
578 Value *Mul = Builder.CreateNUWMul(M, A);
579 return Builder.CreateNUWAdd(Mul, B);
580 }
581 case Builtin::BI__builtin_hlsl_elementwise_rcp: {
582 Value *Op0 = EmitScalarExpr(E->getArg(0));
583 if (!E->getArg(0)->getType()->hasFloatingRepresentation())
584 llvm_unreachable("rcp operand must have a float representation");
585 llvm::Type *Ty = Op0->getType();
586 llvm::Type *EltTy = Ty->getScalarType();
587 Constant *One = Ty->isVectorTy()
588 ? ConstantVector::getSplat(
589 ElementCount::getFixed(
590 cast<FixedVectorType>(Ty)->getNumElements()),
591 ConstantFP::get(EltTy, 1.0))
592 : ConstantFP::get(EltTy, 1.0);
593 return Builder.CreateFDiv(One, Op0, "hlsl.rcp");
594 }
595 case Builtin::BI__builtin_hlsl_elementwise_rsqrt: {
596 Value *Op0 = EmitScalarExpr(E->getArg(0));
597 if (!E->getArg(0)->getType()->hasFloatingRepresentation())
598 llvm_unreachable("rsqrt operand must have a float representation");
599 return Builder.CreateIntrinsic(
600 /*ReturnType=*/Op0->getType(), CGM.getHLSLRuntime().getRsqrtIntrinsic(),
601 ArrayRef<Value *>{Op0}, nullptr, "hlsl.rsqrt");
602 }
603 case Builtin::BI__builtin_hlsl_elementwise_saturate: {
604 Value *Op0 = EmitScalarExpr(E->getArg(0));
605 assert(E->getArg(0)->getType()->hasFloatingRepresentation() &&
606 "saturate operand must have a float representation");
607 return Builder.CreateIntrinsic(
608 /*ReturnType=*/Op0->getType(),
609 CGM.getHLSLRuntime().getSaturateIntrinsic(), ArrayRef<Value *>{Op0},
610 nullptr, "hlsl.saturate");
611 }
612 case Builtin::BI__builtin_hlsl_select: {
613 Value *OpCond = EmitScalarExpr(E->getArg(0));
614 RValue RValTrue = EmitAnyExpr(E->getArg(1));
615 Value *OpTrue =
616 RValTrue.isScalar()
617 ? RValTrue.getScalarVal()
618 : RValTrue.getAggregatePointer(E->getArg(1)->getType(), *this);
619 RValue RValFalse = EmitAnyExpr(E->getArg(2));
620 Value *OpFalse =
621 RValFalse.isScalar()
622 ? RValFalse.getScalarVal()
623 : RValFalse.getAggregatePointer(E->getArg(2)->getType(), *this);
624 if (auto *VTy = E->getType()->getAs<VectorType>()) {
625 if (!OpTrue->getType()->isVectorTy())
626 OpTrue =
627 Builder.CreateVectorSplat(VTy->getNumElements(), OpTrue, "splat");
628 if (!OpFalse->getType()->isVectorTy())
629 OpFalse =
630 Builder.CreateVectorSplat(VTy->getNumElements(), OpFalse, "splat");
631 }
632
633 Value *SelectVal =
634 Builder.CreateSelect(OpCond, OpTrue, OpFalse, "hlsl.select");
635 if (!RValTrue.isScalar())
636 Builder.CreateStore(SelectVal, ReturnValue.getAddress(),
637 ReturnValue.isVolatile());
638
639 return SelectVal;
640 }
641 case Builtin::BI__builtin_hlsl_step: {
642 Value *Op0 = EmitScalarExpr(E->getArg(0));
643 Value *Op1 = EmitScalarExpr(E->getArg(1));
644 assert(E->getArg(0)->getType()->hasFloatingRepresentation() &&
645 E->getArg(1)->getType()->hasFloatingRepresentation() &&
646 "step operands must have a float representation");
647 return Builder.CreateIntrinsic(
648 /*ReturnType=*/Op0->getType(), CGM.getHLSLRuntime().getStepIntrinsic(),
649 ArrayRef<Value *>{Op0, Op1}, nullptr, "hlsl.step");
650 }
651 case Builtin::BI__builtin_hlsl_wave_active_all_true: {
652 Value *Op = EmitScalarExpr(E->getArg(0));
653 assert(Op->getType()->isIntegerTy(1) &&
654 "Intrinsic WaveActiveAllTrue operand must be a bool");
655
656 Intrinsic::ID ID = CGM.getHLSLRuntime().getWaveActiveAllTrueIntrinsic();
657 return EmitRuntimeCall(
658 Intrinsic::getOrInsertDeclaration(&CGM.getModule(), ID), {Op});
659 }
660 case Builtin::BI__builtin_hlsl_wave_active_any_true: {
661 Value *Op = EmitScalarExpr(E->getArg(0));
662 assert(Op->getType()->isIntegerTy(1) &&
663 "Intrinsic WaveActiveAnyTrue operand must be a bool");
664
665 Intrinsic::ID ID = CGM.getHLSLRuntime().getWaveActiveAnyTrueIntrinsic();
666 return EmitRuntimeCall(
667 Intrinsic::getOrInsertDeclaration(&CGM.getModule(), ID), {Op});
668 }
669 case Builtin::BI__builtin_hlsl_wave_active_count_bits: {
670 Value *OpExpr = EmitScalarExpr(E->getArg(0));
671 Intrinsic::ID ID = CGM.getHLSLRuntime().getWaveActiveCountBitsIntrinsic();
672 return EmitRuntimeCall(
673 Intrinsic::getOrInsertDeclaration(&CGM.getModule(), ID),
674 ArrayRef{OpExpr});
675 }
676 case Builtin::BI__builtin_hlsl_wave_active_sum: {
677 // Due to the use of variadic arguments, explicitly retreive argument
678 Value *OpExpr = EmitScalarExpr(E->getArg(0));
679 Intrinsic::ID IID = getWaveActiveSumIntrinsic(
680 getTarget().getTriple().getArch(), CGM.getHLSLRuntime(),
681 E->getArg(0)->getType());
682
683 return EmitRuntimeCall(Intrinsic::getOrInsertDeclaration(
684 &CGM.getModule(), IID, {OpExpr->getType()}),
685 ArrayRef{OpExpr}, "hlsl.wave.active.sum");
686 }
687 case Builtin::BI__builtin_hlsl_wave_active_max: {
688 // Due to the use of variadic arguments, explicitly retreive argument
689 Value *OpExpr = EmitScalarExpr(E->getArg(0));
690 Intrinsic::ID IID = getWaveActiveMaxIntrinsic(
691 getTarget().getTriple().getArch(), CGM.getHLSLRuntime(),
692 E->getArg(0)->getType());
693
694 return EmitRuntimeCall(Intrinsic::getOrInsertDeclaration(
695 &CGM.getModule(), IID, {OpExpr->getType()}),
696 ArrayRef{OpExpr}, "hlsl.wave.active.max");
697 }
698 case Builtin::BI__builtin_hlsl_wave_get_lane_index: {
699 // We don't define a SPIR-V intrinsic, instead it is a SPIR-V built-in
700 // defined in SPIRVBuiltins.td. So instead we manually get the matching name
701 // for the DirectX intrinsic and the demangled builtin name
702 switch (CGM.getTarget().getTriple().getArch()) {
703 case llvm::Triple::dxil:
704 return EmitRuntimeCall(Intrinsic::getOrInsertDeclaration(
705 &CGM.getModule(), Intrinsic::dx_wave_getlaneindex));
706 case llvm::Triple::spirv:
707 return EmitRuntimeCall(CGM.CreateRuntimeFunction(
708 llvm::FunctionType::get(IntTy, {}, false),
709 "__hlsl_wave_get_lane_index", {}, false, true));
710 default:
711 llvm_unreachable(
712 "Intrinsic WaveGetLaneIndex not supported by target architecture");
713 }
714 }
715 case Builtin::BI__builtin_hlsl_wave_is_first_lane: {
716 Intrinsic::ID ID = CGM.getHLSLRuntime().getWaveIsFirstLaneIntrinsic();
717 return EmitRuntimeCall(
718 Intrinsic::getOrInsertDeclaration(&CGM.getModule(), ID));
719 }
720 case Builtin::BI__builtin_hlsl_wave_get_lane_count: {
721 Intrinsic::ID ID = CGM.getHLSLRuntime().getWaveGetLaneCountIntrinsic();
722 return EmitRuntimeCall(
723 Intrinsic::getOrInsertDeclaration(&CGM.getModule(), ID));
724 }
725 case Builtin::BI__builtin_hlsl_wave_read_lane_at: {
726 // Due to the use of variadic arguments we must explicitly retreive them and
727 // create our function type.
728 Value *OpExpr = EmitScalarExpr(E->getArg(0));
729 Value *OpIndex = EmitScalarExpr(E->getArg(1));
730 return EmitRuntimeCall(
731 Intrinsic::getOrInsertDeclaration(
732 &CGM.getModule(), CGM.getHLSLRuntime().getWaveReadLaneAtIntrinsic(),
733 {OpExpr->getType()}),
734 ArrayRef{OpExpr, OpIndex}, "hlsl.wave.readlane");
735 }
736 case Builtin::BI__builtin_hlsl_elementwise_sign: {
737 auto *Arg0 = E->getArg(0);
738 Value *Op0 = EmitScalarExpr(Arg0);
739 llvm::Type *Xty = Op0->getType();
740 llvm::Type *retType = llvm::Type::getInt32Ty(this->getLLVMContext());
741 if (Xty->isVectorTy()) {
742 auto *XVecTy = Arg0->getType()->castAs<VectorType>();
743 retType = llvm::VectorType::get(
744 retType, ElementCount::getFixed(XVecTy->getNumElements()));
745 }
746 assert((Arg0->getType()->hasFloatingRepresentation() ||
747 Arg0->getType()->hasIntegerRepresentation()) &&
748 "sign operand must have a float or int representation");
749
750 if (Arg0->getType()->hasUnsignedIntegerRepresentation()) {
751 Value *Cmp = Builder.CreateICmpEQ(Op0, ConstantInt::get(Xty, 0));
752 return Builder.CreateSelect(Cmp, ConstantInt::get(retType, 0),
753 ConstantInt::get(retType, 1), "hlsl.sign");
754 }
755
756 return Builder.CreateIntrinsic(
757 retType, CGM.getHLSLRuntime().getSignIntrinsic(),
758 ArrayRef<Value *>{Op0}, nullptr, "hlsl.sign");
759 }
760 case Builtin::BI__builtin_hlsl_elementwise_radians: {
761 Value *Op0 = EmitScalarExpr(E->getArg(0));
762 assert(E->getArg(0)->getType()->hasFloatingRepresentation() &&
763 "radians operand must have a float representation");
764 return Builder.CreateIntrinsic(
765 /*ReturnType=*/Op0->getType(),
766 CGM.getHLSLRuntime().getRadiansIntrinsic(), ArrayRef<Value *>{Op0},
767 nullptr, "hlsl.radians");
768 }
769 case Builtin::BI__builtin_hlsl_buffer_update_counter: {
770 Value *ResHandle = EmitScalarExpr(E->getArg(0));
771 Value *Offset = EmitScalarExpr(E->getArg(1));
772 Value *OffsetI8 = Builder.CreateIntCast(Offset, Int8Ty, true);
773 return Builder.CreateIntrinsic(
774 /*ReturnType=*/Offset->getType(),
775 CGM.getHLSLRuntime().getBufferUpdateCounterIntrinsic(),
776 ArrayRef<Value *>{ResHandle, OffsetI8}, nullptr);
777 }
778 case Builtin::BI__builtin_hlsl_elementwise_splitdouble: {
779
780 assert((E->getArg(0)->getType()->hasFloatingRepresentation() &&
781 E->getArg(1)->getType()->hasUnsignedIntegerRepresentation() &&
782 E->getArg(2)->getType()->hasUnsignedIntegerRepresentation()) &&
783 "asuint operands types mismatch");
784 return handleHlslSplitdouble(E, this);
785 }
786 case Builtin::BI__builtin_hlsl_elementwise_clip:
787 assert(E->getArg(0)->getType()->hasFloatingRepresentation() &&
788 "clip operands types mismatch");
789 return handleHlslClip(E, this);
790 case Builtin::BI__builtin_hlsl_group_memory_barrier_with_group_sync: {
791 Intrinsic::ID ID =
792 CGM.getHLSLRuntime().getGroupMemoryBarrierWithGroupSyncIntrinsic();
793 return EmitRuntimeCall(
794 Intrinsic::getOrInsertDeclaration(&CGM.getModule(), ID));
795 }
796 case Builtin::BI__builtin_get_spirv_spec_constant_bool:
797 case Builtin::BI__builtin_get_spirv_spec_constant_short:
798 case Builtin::BI__builtin_get_spirv_spec_constant_ushort:
799 case Builtin::BI__builtin_get_spirv_spec_constant_int:
800 case Builtin::BI__builtin_get_spirv_spec_constant_uint:
801 case Builtin::BI__builtin_get_spirv_spec_constant_longlong:
802 case Builtin::BI__builtin_get_spirv_spec_constant_ulonglong:
803 case Builtin::BI__builtin_get_spirv_spec_constant_half:
804 case Builtin::BI__builtin_get_spirv_spec_constant_float:
805 case Builtin::BI__builtin_get_spirv_spec_constant_double: {
806 llvm::Function *SpecConstantFn = getSpecConstantFunction(E->getType());
807 llvm::Value *SpecId = EmitScalarExpr(E->getArg(0));
808 llvm::Value *DefaultVal = EmitScalarExpr(E->getArg(1));
809 llvm::Value *Args[] = {SpecId, DefaultVal};
810 return Builder.CreateCall(SpecConstantFn, Args);
811 }
812 }
813 return nullptr;
814 }
815
getSpecConstantFunction(const clang::QualType & SpecConstantType)816 llvm::Function *clang::CodeGen::CodeGenFunction::getSpecConstantFunction(
817 const clang::QualType &SpecConstantType) {
818
819 // Find or create the declaration for the function.
820 llvm::Module *M = &CGM.getModule();
821 std::string MangledName =
822 getSpecConstantFunctionName(SpecConstantType, getContext());
823 llvm::Function *SpecConstantFn = M->getFunction(MangledName);
824
825 if (!SpecConstantFn) {
826 llvm::Type *IntType = ConvertType(getContext().IntTy);
827 llvm::Type *RetTy = ConvertType(SpecConstantType);
828 llvm::Type *ArgTypes[] = {IntType, RetTy};
829 llvm::FunctionType *FnTy = llvm::FunctionType::get(RetTy, ArgTypes, false);
830 SpecConstantFn = llvm::Function::Create(
831 FnTy, llvm::GlobalValue::ExternalLinkage, MangledName, M);
832 }
833 return SpecConstantFn;
834 }
835