1 //===----- CGHLSLRuntime.cpp - Interface to HLSL Runtimes -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This provides an abstract class for HLSL code generation. Concrete
10 // subclasses of this implement code generation for specific HLSL
11 // runtime libraries.
12 //
13 //===----------------------------------------------------------------------===//
14
15 #include "CGHLSLRuntime.h"
16 #include "CGDebugInfo.h"
17 #include "CodeGenFunction.h"
18 #include "CodeGenModule.h"
19 #include "TargetInfo.h"
20 #include "clang/AST/ASTContext.h"
21 #include "clang/AST/Decl.h"
22 #include "clang/AST/RecursiveASTVisitor.h"
23 #include "clang/AST/Type.h"
24 #include "clang/Basic/TargetOptions.h"
25 #include "llvm/ADT/SmallVector.h"
26 #include "llvm/Frontend/HLSL/RootSignatureMetadata.h"
27 #include "llvm/IR/Constants.h"
28 #include "llvm/IR/DerivedTypes.h"
29 #include "llvm/IR/GlobalVariable.h"
30 #include "llvm/IR/LLVMContext.h"
31 #include "llvm/IR/Metadata.h"
32 #include "llvm/IR/Module.h"
33 #include "llvm/IR/Type.h"
34 #include "llvm/IR/Value.h"
35 #include "llvm/Support/Alignment.h"
36 #include "llvm/Support/ErrorHandling.h"
37 #include "llvm/Support/FormatVariadic.h"
38
39 using namespace clang;
40 using namespace CodeGen;
41 using namespace clang::hlsl;
42 using namespace llvm;
43
44 using llvm::hlsl::CBufferRowSizeInBytes;
45
46 namespace {
47
addDxilValVersion(StringRef ValVersionStr,llvm::Module & M)48 void addDxilValVersion(StringRef ValVersionStr, llvm::Module &M) {
49 // The validation of ValVersionStr is done at HLSLToolChain::TranslateArgs.
50 // Assume ValVersionStr is legal here.
51 VersionTuple Version;
52 if (Version.tryParse(ValVersionStr) || Version.getBuild() ||
53 Version.getSubminor() || !Version.getMinor()) {
54 return;
55 }
56
57 uint64_t Major = Version.getMajor();
58 uint64_t Minor = *Version.getMinor();
59
60 auto &Ctx = M.getContext();
61 IRBuilder<> B(M.getContext());
62 MDNode *Val = MDNode::get(Ctx, {ConstantAsMetadata::get(B.getInt32(Major)),
63 ConstantAsMetadata::get(B.getInt32(Minor))});
64 StringRef DXILValKey = "dx.valver";
65 auto *DXILValMD = M.getOrInsertNamedMetadata(DXILValKey);
66 DXILValMD->addOperand(Val);
67 }
68
addRootSignature(llvm::dxbc::RootSignatureVersion RootSigVer,ArrayRef<llvm::hlsl::rootsig::RootElement> Elements,llvm::Function * Fn,llvm::Module & M)69 void addRootSignature(llvm::dxbc::RootSignatureVersion RootSigVer,
70 ArrayRef<llvm::hlsl::rootsig::RootElement> Elements,
71 llvm::Function *Fn, llvm::Module &M) {
72 auto &Ctx = M.getContext();
73
74 llvm::hlsl::rootsig::MetadataBuilder RSBuilder(Ctx, Elements);
75 MDNode *RootSignature = RSBuilder.BuildRootSignature();
76
77 ConstantAsMetadata *Version = ConstantAsMetadata::get(ConstantInt::get(
78 llvm::Type::getInt32Ty(Ctx), llvm::to_underlying(RootSigVer)));
79 MDNode *MDVals =
80 MDNode::get(Ctx, {ValueAsMetadata::get(Fn), RootSignature, Version});
81
82 StringRef RootSignatureValKey = "dx.rootsignatures";
83 auto *RootSignatureValMD = M.getOrInsertNamedMetadata(RootSignatureValKey);
84 RootSignatureValMD->addOperand(MDVals);
85 }
86
87 } // namespace
88
89 llvm::Type *
convertHLSLSpecificType(const Type * T,SmallVector<int32_t> * Packoffsets)90 CGHLSLRuntime::convertHLSLSpecificType(const Type *T,
91 SmallVector<int32_t> *Packoffsets) {
92 assert(T->isHLSLSpecificType() && "Not an HLSL specific type!");
93
94 // Check if the target has a specific translation for this type first.
95 if (llvm::Type *TargetTy =
96 CGM.getTargetCodeGenInfo().getHLSLType(CGM, T, Packoffsets))
97 return TargetTy;
98
99 llvm_unreachable("Generic handling of HLSL types is not supported.");
100 }
101
getArch()102 llvm::Triple::ArchType CGHLSLRuntime::getArch() {
103 return CGM.getTarget().getTriple().getArch();
104 }
105
106 // Returns true if the type is an HLSL resource class or an array of them
isResourceRecordTypeOrArrayOf(const clang::Type * Ty)107 static bool isResourceRecordTypeOrArrayOf(const clang::Type *Ty) {
108 while (const ConstantArrayType *CAT = dyn_cast<ConstantArrayType>(Ty))
109 Ty = CAT->getArrayElementTypeNoTypeQual();
110 return Ty->isHLSLResourceRecord();
111 }
112
113 // Emits constant global variables for buffer constants declarations
114 // and creates metadata linking the constant globals with the buffer global.
emitBufferGlobalsAndMetadata(const HLSLBufferDecl * BufDecl,llvm::GlobalVariable * BufGV)115 void CGHLSLRuntime::emitBufferGlobalsAndMetadata(const HLSLBufferDecl *BufDecl,
116 llvm::GlobalVariable *BufGV) {
117 LLVMContext &Ctx = CGM.getLLVMContext();
118
119 // get the layout struct from constant buffer target type
120 llvm::Type *BufType = BufGV->getValueType();
121 llvm::Type *BufLayoutType =
122 cast<llvm::TargetExtType>(BufType)->getTypeParameter(0);
123 llvm::StructType *LayoutStruct = cast<llvm::StructType>(
124 cast<llvm::TargetExtType>(BufLayoutType)->getTypeParameter(0));
125
126 // Start metadata list associating the buffer global variable with its
127 // constatns
128 SmallVector<llvm::Metadata *> BufGlobals;
129 BufGlobals.push_back(ValueAsMetadata::get(BufGV));
130
131 const auto *ElemIt = LayoutStruct->element_begin();
132 for (Decl *D : BufDecl->buffer_decls()) {
133 if (isa<CXXRecordDecl, EmptyDecl>(D))
134 // Nothing to do for this declaration.
135 continue;
136 if (isa<FunctionDecl>(D)) {
137 // A function within an cbuffer is effectively a top-level function.
138 CGM.EmitTopLevelDecl(D);
139 continue;
140 }
141 VarDecl *VD = dyn_cast<VarDecl>(D);
142 if (!VD)
143 continue;
144
145 QualType VDTy = VD->getType();
146 if (VDTy.getAddressSpace() != LangAS::hlsl_constant) {
147 if (VD->getStorageClass() == SC_Static ||
148 VDTy.getAddressSpace() == LangAS::hlsl_groupshared ||
149 isResourceRecordTypeOrArrayOf(VDTy.getTypePtr())) {
150 // Emit static and groupshared variables and resource classes inside
151 // cbuffer as regular globals
152 CGM.EmitGlobal(VD);
153 } else {
154 // Anything else that is not in the hlsl_constant address space must be
155 // an empty struct or a zero-sized array and can be ignored
156 assert(BufDecl->getASTContext().getTypeSize(VDTy) == 0 &&
157 "constant buffer decl with non-zero sized type outside of "
158 "hlsl_constant address space");
159 }
160 continue;
161 }
162
163 assert(ElemIt != LayoutStruct->element_end() &&
164 "number of elements in layout struct does not match");
165 llvm::Type *LayoutType = *ElemIt++;
166
167 // FIXME: handle resources inside user defined structs
168 // (llvm/wg-hlsl#175)
169
170 // create global variable for the constant and to metadata list
171 GlobalVariable *ElemGV =
172 cast<GlobalVariable>(CGM.GetAddrOfGlobalVar(VD, LayoutType));
173 BufGlobals.push_back(ValueAsMetadata::get(ElemGV));
174 }
175 assert(ElemIt == LayoutStruct->element_end() &&
176 "number of elements in layout struct does not match");
177
178 // add buffer metadata to the module
179 CGM.getModule()
180 .getOrInsertNamedMetadata("hlsl.cbs")
181 ->addOperand(MDNode::get(Ctx, BufGlobals));
182 }
183
184 // Creates resource handle type for the HLSL buffer declaration
185 static const clang::HLSLAttributedResourceType *
createBufferHandleType(const HLSLBufferDecl * BufDecl)186 createBufferHandleType(const HLSLBufferDecl *BufDecl) {
187 ASTContext &AST = BufDecl->getASTContext();
188 QualType QT = AST.getHLSLAttributedResourceType(
189 AST.HLSLResourceTy,
190 QualType(BufDecl->getLayoutStruct()->getTypeForDecl(), 0),
191 HLSLAttributedResourceType::Attributes(ResourceClass::CBuffer));
192 return cast<HLSLAttributedResourceType>(QT.getTypePtr());
193 }
194
195 // Iterates over all declarations in the HLSL buffer and based on the
196 // packoffset or register(c#) annotations it fills outs the Layout
197 // vector with the user-specified layout offsets.
198 // The buffer offsets can be specified 2 ways:
199 // 1. declarations in cbuffer {} block can have a packoffset annotation
200 // (translates to HLSLPackOffsetAttr)
201 // 2. default constant buffer declarations at global scope can have
202 // register(c#) annotations (translates to HLSLResourceBindingAttr with
203 // RegisterType::C)
204 // It is not guaranteed that all declarations in a buffer have an annotation.
205 // For those where it is not specified a -1 value is added to the Layout
206 // vector. In the final layout these declarations will be placed at the end
207 // of the HLSL buffer after all of the elements with specified offset.
fillPackoffsetLayout(const HLSLBufferDecl * BufDecl,SmallVector<int32_t> & Layout)208 static void fillPackoffsetLayout(const HLSLBufferDecl *BufDecl,
209 SmallVector<int32_t> &Layout) {
210 assert(Layout.empty() && "expected empty vector for layout");
211 assert(BufDecl->hasValidPackoffset());
212
213 for (Decl *D : BufDecl->buffer_decls()) {
214 if (isa<CXXRecordDecl, EmptyDecl>(D) || isa<FunctionDecl>(D)) {
215 continue;
216 }
217 VarDecl *VD = dyn_cast<VarDecl>(D);
218 if (!VD || VD->getType().getAddressSpace() != LangAS::hlsl_constant)
219 continue;
220
221 if (!VD->hasAttrs()) {
222 Layout.push_back(-1);
223 continue;
224 }
225
226 int32_t Offset = -1;
227 for (auto *Attr : VD->getAttrs()) {
228 if (auto *POA = dyn_cast<HLSLPackOffsetAttr>(Attr)) {
229 Offset = POA->getOffsetInBytes();
230 break;
231 }
232 auto *RBA = dyn_cast<HLSLResourceBindingAttr>(Attr);
233 if (RBA &&
234 RBA->getRegisterType() == HLSLResourceBindingAttr::RegisterType::C) {
235 Offset = RBA->getSlotNumber() * CBufferRowSizeInBytes;
236 break;
237 }
238 }
239 Layout.push_back(Offset);
240 }
241 }
242
243 // Codegen for HLSLBufferDecl
addBuffer(const HLSLBufferDecl * BufDecl)244 void CGHLSLRuntime::addBuffer(const HLSLBufferDecl *BufDecl) {
245
246 assert(BufDecl->isCBuffer() && "tbuffer codegen is not supported yet");
247
248 // create resource handle type for the buffer
249 const clang::HLSLAttributedResourceType *ResHandleTy =
250 createBufferHandleType(BufDecl);
251
252 // empty constant buffer is ignored
253 if (ResHandleTy->getContainedType()->getAsCXXRecordDecl()->isEmpty())
254 return;
255
256 // create global variable for the constant buffer
257 SmallVector<int32_t> Layout;
258 if (BufDecl->hasValidPackoffset())
259 fillPackoffsetLayout(BufDecl, Layout);
260
261 llvm::TargetExtType *TargetTy =
262 cast<llvm::TargetExtType>(convertHLSLSpecificType(
263 ResHandleTy, BufDecl->hasValidPackoffset() ? &Layout : nullptr));
264 llvm::GlobalVariable *BufGV = new GlobalVariable(
265 TargetTy, /*isConstant*/ false,
266 GlobalValue::LinkageTypes::ExternalLinkage, PoisonValue::get(TargetTy),
267 llvm::formatv("{0}{1}", BufDecl->getName(),
268 BufDecl->isCBuffer() ? ".cb" : ".tb"),
269 GlobalValue::NotThreadLocal);
270 CGM.getModule().insertGlobalVariable(BufGV);
271
272 // Add globals for constant buffer elements and create metadata nodes
273 emitBufferGlobalsAndMetadata(BufDecl, BufGV);
274
275 // Initialize cbuffer from binding (implicit or explicit)
276 HLSLResourceBindingAttr *RBA = BufDecl->getAttr<HLSLResourceBindingAttr>();
277 assert(RBA &&
278 "cbuffer/tbuffer should always have resource binding attribute");
279 initializeBufferFromBinding(BufDecl, BufGV, RBA);
280 }
281
282 llvm::TargetExtType *
getHLSLBufferLayoutType(const RecordType * StructType)283 CGHLSLRuntime::getHLSLBufferLayoutType(const RecordType *StructType) {
284 const auto Entry = LayoutTypes.find(StructType);
285 if (Entry != LayoutTypes.end())
286 return Entry->getSecond();
287 return nullptr;
288 }
289
addHLSLBufferLayoutType(const RecordType * StructType,llvm::TargetExtType * LayoutTy)290 void CGHLSLRuntime::addHLSLBufferLayoutType(const RecordType *StructType,
291 llvm::TargetExtType *LayoutTy) {
292 assert(getHLSLBufferLayoutType(StructType) == nullptr &&
293 "layout type for this struct already exist");
294 LayoutTypes[StructType] = LayoutTy;
295 }
296
finishCodeGen()297 void CGHLSLRuntime::finishCodeGen() {
298 auto &TargetOpts = CGM.getTarget().getTargetOpts();
299 auto &CodeGenOpts = CGM.getCodeGenOpts();
300 auto &LangOpts = CGM.getLangOpts();
301 llvm::Module &M = CGM.getModule();
302 Triple T(M.getTargetTriple());
303 if (T.getArch() == Triple::ArchType::dxil)
304 addDxilValVersion(TargetOpts.DxilValidatorVersion, M);
305 if (CodeGenOpts.ResMayAlias)
306 M.setModuleFlag(llvm::Module::ModFlagBehavior::Error, "dx.resmayalias", 1);
307
308 // NativeHalfType corresponds to the -fnative-half-type clang option which is
309 // aliased by clang-dxc's -enable-16bit-types option. This option is used to
310 // set the UseNativeLowPrecision DXIL module flag in the DirectX backend
311 if (LangOpts.NativeHalfType)
312 M.setModuleFlag(llvm::Module::ModFlagBehavior::Error, "dx.nativelowprec",
313 1);
314
315 generateGlobalCtorDtorCalls();
316 }
317
setHLSLEntryAttributes(const FunctionDecl * FD,llvm::Function * Fn)318 void clang::CodeGen::CGHLSLRuntime::setHLSLEntryAttributes(
319 const FunctionDecl *FD, llvm::Function *Fn) {
320 const auto *ShaderAttr = FD->getAttr<HLSLShaderAttr>();
321 assert(ShaderAttr && "All entry functions must have a HLSLShaderAttr");
322 const StringRef ShaderAttrKindStr = "hlsl.shader";
323 Fn->addFnAttr(ShaderAttrKindStr,
324 llvm::Triple::getEnvironmentTypeName(ShaderAttr->getType()));
325 if (HLSLNumThreadsAttr *NumThreadsAttr = FD->getAttr<HLSLNumThreadsAttr>()) {
326 const StringRef NumThreadsKindStr = "hlsl.numthreads";
327 std::string NumThreadsStr =
328 formatv("{0},{1},{2}", NumThreadsAttr->getX(), NumThreadsAttr->getY(),
329 NumThreadsAttr->getZ());
330 Fn->addFnAttr(NumThreadsKindStr, NumThreadsStr);
331 }
332 if (HLSLWaveSizeAttr *WaveSizeAttr = FD->getAttr<HLSLWaveSizeAttr>()) {
333 const StringRef WaveSizeKindStr = "hlsl.wavesize";
334 std::string WaveSizeStr =
335 formatv("{0},{1},{2}", WaveSizeAttr->getMin(), WaveSizeAttr->getMax(),
336 WaveSizeAttr->getPreferred());
337 Fn->addFnAttr(WaveSizeKindStr, WaveSizeStr);
338 }
339 // HLSL entry functions are materialized for module functions with
340 // HLSLShaderAttr attribute. SetLLVMFunctionAttributesForDefinition called
341 // later in the compiler-flow for such module functions is not aware of and
342 // hence not able to set attributes of the newly materialized entry functions.
343 // So, set attributes of entry function here, as appropriate.
344 if (CGM.getCodeGenOpts().OptimizationLevel == 0)
345 Fn->addFnAttr(llvm::Attribute::OptimizeNone);
346 Fn->addFnAttr(llvm::Attribute::NoInline);
347 }
348
buildVectorInput(IRBuilder<> & B,Function * F,llvm::Type * Ty)349 static Value *buildVectorInput(IRBuilder<> &B, Function *F, llvm::Type *Ty) {
350 if (const auto *VT = dyn_cast<FixedVectorType>(Ty)) {
351 Value *Result = PoisonValue::get(Ty);
352 for (unsigned I = 0; I < VT->getNumElements(); ++I) {
353 Value *Elt = B.CreateCall(F, {B.getInt32(I)});
354 Result = B.CreateInsertElement(Result, Elt, I);
355 }
356 return Result;
357 }
358 return B.CreateCall(F, {B.getInt32(0)});
359 }
360
addSPIRVBuiltinDecoration(llvm::GlobalVariable * GV,unsigned BuiltIn)361 static void addSPIRVBuiltinDecoration(llvm::GlobalVariable *GV,
362 unsigned BuiltIn) {
363 LLVMContext &Ctx = GV->getContext();
364 IRBuilder<> B(GV->getContext());
365 MDNode *Operands = MDNode::get(
366 Ctx,
367 {ConstantAsMetadata::get(B.getInt32(/* Spirv::Decoration::BuiltIn */ 11)),
368 ConstantAsMetadata::get(B.getInt32(BuiltIn))});
369 MDNode *Decoration = MDNode::get(Ctx, {Operands});
370 GV->addMetadata("spirv.Decorations", *Decoration);
371 }
372
createSPIRVBuiltinLoad(IRBuilder<> & B,llvm::Module & M,llvm::Type * Ty,const Twine & Name,unsigned BuiltInID)373 static llvm::Value *createSPIRVBuiltinLoad(IRBuilder<> &B, llvm::Module &M,
374 llvm::Type *Ty, const Twine &Name,
375 unsigned BuiltInID) {
376 auto *GV = new llvm::GlobalVariable(
377 M, Ty, /* isConstant= */ true, llvm::GlobalValue::ExternalLinkage,
378 /* Initializer= */ nullptr, Name, /* insertBefore= */ nullptr,
379 llvm::GlobalVariable::GeneralDynamicTLSModel,
380 /* AddressSpace */ 7, /* isExternallyInitialized= */ true);
381 addSPIRVBuiltinDecoration(GV, BuiltInID);
382 GV->setVisibility(llvm::GlobalValue::HiddenVisibility);
383 return B.CreateLoad(Ty, GV);
384 }
385
emitInputSemantic(IRBuilder<> & B,const ParmVarDecl & D,llvm::Type * Ty)386 llvm::Value *CGHLSLRuntime::emitInputSemantic(IRBuilder<> &B,
387 const ParmVarDecl &D,
388 llvm::Type *Ty) {
389 assert(D.hasAttrs() && "Entry parameter missing annotation attribute!");
390 if (D.hasAttr<HLSLSV_GroupIndexAttr>()) {
391 llvm::Function *GroupIndex =
392 CGM.getIntrinsic(getFlattenedThreadIdInGroupIntrinsic());
393 return B.CreateCall(FunctionCallee(GroupIndex));
394 }
395 if (D.hasAttr<HLSLSV_DispatchThreadIDAttr>()) {
396 llvm::Intrinsic::ID IntrinID = getThreadIdIntrinsic();
397 llvm::Function *ThreadIDIntrinsic =
398 llvm::Intrinsic::isOverloaded(IntrinID)
399 ? CGM.getIntrinsic(IntrinID, {CGM.Int32Ty})
400 : CGM.getIntrinsic(IntrinID);
401 return buildVectorInput(B, ThreadIDIntrinsic, Ty);
402 }
403 if (D.hasAttr<HLSLSV_GroupThreadIDAttr>()) {
404 llvm::Intrinsic::ID IntrinID = getGroupThreadIdIntrinsic();
405 llvm::Function *GroupThreadIDIntrinsic =
406 llvm::Intrinsic::isOverloaded(IntrinID)
407 ? CGM.getIntrinsic(IntrinID, {CGM.Int32Ty})
408 : CGM.getIntrinsic(IntrinID);
409 return buildVectorInput(B, GroupThreadIDIntrinsic, Ty);
410 }
411 if (D.hasAttr<HLSLSV_GroupIDAttr>()) {
412 llvm::Intrinsic::ID IntrinID = getGroupIdIntrinsic();
413 llvm::Function *GroupIDIntrinsic =
414 llvm::Intrinsic::isOverloaded(IntrinID)
415 ? CGM.getIntrinsic(IntrinID, {CGM.Int32Ty})
416 : CGM.getIntrinsic(IntrinID);
417 return buildVectorInput(B, GroupIDIntrinsic, Ty);
418 }
419 if (D.hasAttr<HLSLSV_PositionAttr>()) {
420 if (getArch() == llvm::Triple::spirv)
421 return createSPIRVBuiltinLoad(B, CGM.getModule(), Ty, "sv_position",
422 /* BuiltIn::Position */ 0);
423 llvm_unreachable("SV_Position semantic not implemented for this target.");
424 }
425 assert(false && "Unhandled parameter attribute");
426 return nullptr;
427 }
428
emitEntryFunction(const FunctionDecl * FD,llvm::Function * Fn)429 void CGHLSLRuntime::emitEntryFunction(const FunctionDecl *FD,
430 llvm::Function *Fn) {
431 llvm::Module &M = CGM.getModule();
432 llvm::LLVMContext &Ctx = M.getContext();
433 auto *EntryTy = llvm::FunctionType::get(llvm::Type::getVoidTy(Ctx), false);
434 Function *EntryFn =
435 Function::Create(EntryTy, Function::ExternalLinkage, FD->getName(), &M);
436
437 // Copy function attributes over, we have no argument or return attributes
438 // that can be valid on the real entry.
439 AttributeList NewAttrs = AttributeList::get(Ctx, AttributeList::FunctionIndex,
440 Fn->getAttributes().getFnAttrs());
441 EntryFn->setAttributes(NewAttrs);
442 setHLSLEntryAttributes(FD, EntryFn);
443
444 // Set the called function as internal linkage.
445 Fn->setLinkage(GlobalValue::InternalLinkage);
446
447 BasicBlock *BB = BasicBlock::Create(Ctx, "entry", EntryFn);
448 IRBuilder<> B(BB);
449 llvm::SmallVector<Value *> Args;
450
451 SmallVector<OperandBundleDef, 1> OB;
452 if (CGM.shouldEmitConvergenceTokens()) {
453 assert(EntryFn->isConvergent());
454 llvm::Value *I =
455 B.CreateIntrinsic(llvm::Intrinsic::experimental_convergence_entry, {});
456 llvm::Value *bundleArgs[] = {I};
457 OB.emplace_back("convergencectrl", bundleArgs);
458 }
459
460 // FIXME: support struct parameters where semantics are on members.
461 // See: https://github.com/llvm/llvm-project/issues/57874
462 unsigned SRetOffset = 0;
463 for (const auto &Param : Fn->args()) {
464 if (Param.hasStructRetAttr()) {
465 // FIXME: support output.
466 // See: https://github.com/llvm/llvm-project/issues/57874
467 SRetOffset = 1;
468 Args.emplace_back(PoisonValue::get(Param.getType()));
469 continue;
470 }
471 const ParmVarDecl *PD = FD->getParamDecl(Param.getArgNo() - SRetOffset);
472 Args.push_back(emitInputSemantic(B, *PD, Param.getType()));
473 }
474
475 CallInst *CI = B.CreateCall(FunctionCallee(Fn), Args, OB);
476 CI->setCallingConv(Fn->getCallingConv());
477 // FIXME: Handle codegen for return type semantics.
478 // See: https://github.com/llvm/llvm-project/issues/57875
479 B.CreateRetVoid();
480
481 // Add and identify root signature to function, if applicable
482 for (const Attr *Attr : FD->getAttrs()) {
483 if (const auto *RSAttr = dyn_cast<RootSignatureAttr>(Attr)) {
484 auto *RSDecl = RSAttr->getSignatureDecl();
485 addRootSignature(RSDecl->getVersion(), RSDecl->getRootElements(), EntryFn,
486 M);
487 }
488 }
489 }
490
gatherFunctions(SmallVectorImpl<Function * > & Fns,llvm::Module & M,bool CtorOrDtor)491 static void gatherFunctions(SmallVectorImpl<Function *> &Fns, llvm::Module &M,
492 bool CtorOrDtor) {
493 const auto *GV =
494 M.getNamedGlobal(CtorOrDtor ? "llvm.global_ctors" : "llvm.global_dtors");
495 if (!GV)
496 return;
497 const auto *CA = dyn_cast<ConstantArray>(GV->getInitializer());
498 if (!CA)
499 return;
500 // The global_ctor array elements are a struct [Priority, Fn *, COMDat].
501 // HLSL neither supports priorities or COMDat values, so we will check those
502 // in an assert but not handle them.
503
504 for (const auto &Ctor : CA->operands()) {
505 if (isa<ConstantAggregateZero>(Ctor))
506 continue;
507 ConstantStruct *CS = cast<ConstantStruct>(Ctor);
508
509 assert(cast<ConstantInt>(CS->getOperand(0))->getValue() == 65535 &&
510 "HLSL doesn't support setting priority for global ctors.");
511 assert(isa<ConstantPointerNull>(CS->getOperand(2)) &&
512 "HLSL doesn't support COMDat for global ctors.");
513 Fns.push_back(cast<Function>(CS->getOperand(1)));
514 }
515 }
516
generateGlobalCtorDtorCalls()517 void CGHLSLRuntime::generateGlobalCtorDtorCalls() {
518 llvm::Module &M = CGM.getModule();
519 SmallVector<Function *> CtorFns;
520 SmallVector<Function *> DtorFns;
521 gatherFunctions(CtorFns, M, true);
522 gatherFunctions(DtorFns, M, false);
523
524 // Insert a call to the global constructor at the beginning of the entry block
525 // to externally exported functions. This is a bit of a hack, but HLSL allows
526 // global constructors, but doesn't support driver initialization of globals.
527 for (auto &F : M.functions()) {
528 if (!F.hasFnAttribute("hlsl.shader"))
529 continue;
530 auto *Token = getConvergenceToken(F.getEntryBlock());
531 Instruction *IP = &*F.getEntryBlock().begin();
532 SmallVector<OperandBundleDef, 1> OB;
533 if (Token) {
534 llvm::Value *bundleArgs[] = {Token};
535 OB.emplace_back("convergencectrl", bundleArgs);
536 IP = Token->getNextNode();
537 }
538 IRBuilder<> B(IP);
539 for (auto *Fn : CtorFns) {
540 auto CI = B.CreateCall(FunctionCallee(Fn), {}, OB);
541 CI->setCallingConv(Fn->getCallingConv());
542 }
543
544 // Insert global dtors before the terminator of the last instruction
545 B.SetInsertPoint(F.back().getTerminator());
546 for (auto *Fn : DtorFns) {
547 auto CI = B.CreateCall(FunctionCallee(Fn), {}, OB);
548 CI->setCallingConv(Fn->getCallingConv());
549 }
550 }
551
552 // No need to keep global ctors/dtors for non-lib profile after call to
553 // ctors/dtors added for entry.
554 Triple T(M.getTargetTriple());
555 if (T.getEnvironment() != Triple::EnvironmentType::Library) {
556 if (auto *GV = M.getNamedGlobal("llvm.global_ctors"))
557 GV->eraseFromParent();
558 if (auto *GV = M.getNamedGlobal("llvm.global_dtors"))
559 GV->eraseFromParent();
560 }
561 }
562
initializeBuffer(CodeGenModule & CGM,llvm::GlobalVariable * GV,Intrinsic::ID IntrID,ArrayRef<llvm::Value * > Args)563 static void initializeBuffer(CodeGenModule &CGM, llvm::GlobalVariable *GV,
564 Intrinsic::ID IntrID,
565 ArrayRef<llvm::Value *> Args) {
566
567 LLVMContext &Ctx = CGM.getLLVMContext();
568 llvm::Function *InitResFunc = llvm::Function::Create(
569 llvm::FunctionType::get(CGM.VoidTy, false),
570 llvm::GlobalValue::InternalLinkage,
571 ("_init_buffer_" + GV->getName()).str(), CGM.getModule());
572 InitResFunc->addFnAttr(llvm::Attribute::AlwaysInline);
573
574 llvm::BasicBlock *EntryBB =
575 llvm::BasicBlock::Create(Ctx, "entry", InitResFunc);
576 CGBuilderTy Builder(CGM, Ctx);
577 const DataLayout &DL = CGM.getModule().getDataLayout();
578 Builder.SetInsertPoint(EntryBB);
579
580 // Make sure the global variable is buffer resource handle
581 llvm::Type *HandleTy = GV->getValueType();
582 assert(HandleTy->isTargetExtTy() && "unexpected type of the buffer global");
583
584 llvm::Value *CreateHandle = Builder.CreateIntrinsic(
585 /*ReturnType=*/HandleTy, IntrID, Args, nullptr,
586 Twine(GV->getName()).concat("_h"));
587
588 llvm::Value *HandleRef = Builder.CreateStructGEP(GV->getValueType(), GV, 0);
589 Builder.CreateAlignedStore(CreateHandle, HandleRef,
590 HandleRef->getPointerAlignment(DL));
591 Builder.CreateRetVoid();
592
593 CGM.AddCXXGlobalInit(InitResFunc);
594 }
595
initializeBufferFromBinding(const HLSLBufferDecl * BufDecl,llvm::GlobalVariable * GV,HLSLResourceBindingAttr * RBA)596 void CGHLSLRuntime::initializeBufferFromBinding(const HLSLBufferDecl *BufDecl,
597 llvm::GlobalVariable *GV,
598 HLSLResourceBindingAttr *RBA) {
599 assert(RBA && "expect a nonnull binding attribute");
600 llvm::Type *Int1Ty = llvm::Type::getInt1Ty(CGM.getLLVMContext());
601 auto *NonUniform = llvm::ConstantInt::get(Int1Ty, false);
602 auto *Index = llvm::ConstantInt::get(CGM.IntTy, 0);
603 auto *RangeSize = llvm::ConstantInt::get(CGM.IntTy, 1);
604 auto *Space = llvm::ConstantInt::get(CGM.IntTy, RBA->getSpaceNumber());
605 Value *Name = nullptr;
606
607 llvm::Intrinsic::ID IntrinsicID =
608 RBA->hasRegisterSlot()
609 ? CGM.getHLSLRuntime().getCreateHandleFromBindingIntrinsic()
610 : CGM.getHLSLRuntime().getCreateHandleFromImplicitBindingIntrinsic();
611
612 std::string Str(BufDecl->getName());
613 std::string GlobalName(Str + ".str");
614 Name = CGM.GetAddrOfConstantCString(Str, GlobalName.c_str()).getPointer();
615
616 // buffer with explicit binding
617 if (RBA->hasRegisterSlot()) {
618 auto *RegSlot = llvm::ConstantInt::get(CGM.IntTy, RBA->getSlotNumber());
619 SmallVector<Value *> Args{Space, RegSlot, RangeSize,
620 Index, NonUniform, Name};
621 initializeBuffer(CGM, GV, IntrinsicID, Args);
622 } else {
623 // buffer with implicit binding
624 auto *OrderID =
625 llvm::ConstantInt::get(CGM.IntTy, RBA->getImplicitBindingOrderID());
626 SmallVector<Value *> Args{OrderID, Space, RangeSize,
627 Index, NonUniform, Name};
628 initializeBuffer(CGM, GV, IntrinsicID, Args);
629 }
630 }
631
handleGlobalVarDefinition(const VarDecl * VD,llvm::GlobalVariable * GV)632 void CGHLSLRuntime::handleGlobalVarDefinition(const VarDecl *VD,
633 llvm::GlobalVariable *GV) {
634 if (auto Attr = VD->getAttr<HLSLVkExtBuiltinInputAttr>())
635 addSPIRVBuiltinDecoration(GV, Attr->getBuiltIn());
636 }
637
getConvergenceToken(BasicBlock & BB)638 llvm::Instruction *CGHLSLRuntime::getConvergenceToken(BasicBlock &BB) {
639 if (!CGM.shouldEmitConvergenceTokens())
640 return nullptr;
641
642 auto E = BB.end();
643 for (auto I = BB.begin(); I != E; ++I) {
644 auto *II = dyn_cast<llvm::IntrinsicInst>(&*I);
645 if (II && llvm::isConvergenceControlIntrinsic(II->getIntrinsicID())) {
646 return II;
647 }
648 }
649 llvm_unreachable("Convergence token should have been emitted.");
650 return nullptr;
651 }
652
653 class OpaqueValueVisitor : public RecursiveASTVisitor<OpaqueValueVisitor> {
654 public:
655 llvm::SmallPtrSet<OpaqueValueExpr *, 8> OVEs;
OpaqueValueVisitor()656 OpaqueValueVisitor() {}
657
VisitOpaqueValueExpr(OpaqueValueExpr * E)658 bool VisitOpaqueValueExpr(OpaqueValueExpr *E) {
659 OVEs.insert(E);
660 return true;
661 }
662 };
663
emitInitListOpaqueValues(CodeGenFunction & CGF,InitListExpr * E)664 void CGHLSLRuntime::emitInitListOpaqueValues(CodeGenFunction &CGF,
665 InitListExpr *E) {
666
667 typedef CodeGenFunction::OpaqueValueMappingData OpaqueValueMappingData;
668 OpaqueValueVisitor Visitor;
669 Visitor.TraverseStmt(E);
670 for (auto *OVE : Visitor.OVEs) {
671 if (CGF.isOpaqueValueEmitted(OVE))
672 continue;
673 if (OpaqueValueMappingData::shouldBindAsLValue(OVE)) {
674 LValue LV = CGF.EmitLValue(OVE->getSourceExpr());
675 OpaqueValueMappingData::bind(CGF, OVE, LV);
676 } else {
677 RValue RV = CGF.EmitAnyExpr(OVE->getSourceExpr());
678 OpaqueValueMappingData::bind(CGF, OVE, RV);
679 }
680 }
681 }
682