181ad6265SDimitry Andric //===----- CGHLSLRuntime.cpp - Interface to HLSL Runtimes -----------------===//
281ad6265SDimitry Andric //
381ad6265SDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
481ad6265SDimitry Andric // See https://llvm.org/LICENSE.txt for license information.
581ad6265SDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
681ad6265SDimitry Andric //
781ad6265SDimitry Andric //===----------------------------------------------------------------------===//
881ad6265SDimitry Andric //
981ad6265SDimitry Andric // This provides an abstract class for HLSL code generation. Concrete
1081ad6265SDimitry Andric // subclasses of this implement code generation for specific HLSL
1181ad6265SDimitry Andric // runtime libraries.
1281ad6265SDimitry Andric //
1381ad6265SDimitry Andric //===----------------------------------------------------------------------===//
1481ad6265SDimitry Andric
1581ad6265SDimitry Andric #include "CGHLSLRuntime.h"
16bdd1243dSDimitry Andric #include "CGDebugInfo.h"
1781ad6265SDimitry Andric #include "CodeGenModule.h"
18bdd1243dSDimitry Andric #include "clang/AST/Decl.h"
1981ad6265SDimitry Andric #include "clang/Basic/TargetOptions.h"
2081ad6265SDimitry Andric #include "llvm/IR/Metadata.h"
2181ad6265SDimitry Andric #include "llvm/IR/Module.h"
22bdd1243dSDimitry Andric #include "llvm/Support/FormatVariadic.h"
2381ad6265SDimitry Andric
2481ad6265SDimitry Andric using namespace clang;
2581ad6265SDimitry Andric using namespace CodeGen;
26bdd1243dSDimitry Andric using namespace clang::hlsl;
2781ad6265SDimitry Andric using namespace llvm;
2881ad6265SDimitry Andric
2981ad6265SDimitry Andric namespace {
30bdd1243dSDimitry Andric
addDxilValVersion(StringRef ValVersionStr,llvm::Module & M)3181ad6265SDimitry Andric void addDxilValVersion(StringRef ValVersionStr, llvm::Module &M) {
3281ad6265SDimitry Andric // The validation of ValVersionStr is done at HLSLToolChain::TranslateArgs.
3381ad6265SDimitry Andric // Assume ValVersionStr is legal here.
3481ad6265SDimitry Andric VersionTuple Version;
3581ad6265SDimitry Andric if (Version.tryParse(ValVersionStr) || Version.getBuild() ||
3681ad6265SDimitry Andric Version.getSubminor() || !Version.getMinor()) {
3781ad6265SDimitry Andric return;
3881ad6265SDimitry Andric }
3981ad6265SDimitry Andric
4081ad6265SDimitry Andric uint64_t Major = Version.getMajor();
4181ad6265SDimitry Andric uint64_t Minor = *Version.getMinor();
4281ad6265SDimitry Andric
4381ad6265SDimitry Andric auto &Ctx = M.getContext();
4481ad6265SDimitry Andric IRBuilder<> B(M.getContext());
4581ad6265SDimitry Andric MDNode *Val = MDNode::get(Ctx, {ConstantAsMetadata::get(B.getInt32(Major)),
4681ad6265SDimitry Andric ConstantAsMetadata::get(B.getInt32(Minor))});
47bdd1243dSDimitry Andric StringRef DXILValKey = "dx.valver";
48bdd1243dSDimitry Andric auto *DXILValMD = M.getOrInsertNamedMetadata(DXILValKey);
49bdd1243dSDimitry Andric DXILValMD->addOperand(Val);
5081ad6265SDimitry Andric }
addDisableOptimizations(llvm::Module & M)51bdd1243dSDimitry Andric void addDisableOptimizations(llvm::Module &M) {
52bdd1243dSDimitry Andric StringRef Key = "dx.disable_optimizations";
53bdd1243dSDimitry Andric M.addModuleFlag(llvm::Module::ModFlagBehavior::Override, Key, 1);
54bdd1243dSDimitry Andric }
55bdd1243dSDimitry Andric // cbuffer will be translated into global variable in special address space.
56bdd1243dSDimitry Andric // If translate into C,
57bdd1243dSDimitry Andric // cbuffer A {
58bdd1243dSDimitry Andric // float a;
59bdd1243dSDimitry Andric // float b;
60bdd1243dSDimitry Andric // }
61bdd1243dSDimitry Andric // float foo() { return a + b; }
62bdd1243dSDimitry Andric //
63bdd1243dSDimitry Andric // will be translated into
64bdd1243dSDimitry Andric //
65bdd1243dSDimitry Andric // struct A {
66bdd1243dSDimitry Andric // float a;
67bdd1243dSDimitry Andric // float b;
68bdd1243dSDimitry Andric // } cbuffer_A __attribute__((address_space(4)));
69bdd1243dSDimitry Andric // float foo() { return cbuffer_A.a + cbuffer_A.b; }
70bdd1243dSDimitry Andric //
71bdd1243dSDimitry Andric // layoutBuffer will create the struct A type.
72bdd1243dSDimitry Andric // replaceBuffer will replace use of global variable a and b with cbuffer_A.a
73bdd1243dSDimitry Andric // and cbuffer_A.b.
74bdd1243dSDimitry Andric //
layoutBuffer(CGHLSLRuntime::Buffer & Buf,const DataLayout & DL)75bdd1243dSDimitry Andric void layoutBuffer(CGHLSLRuntime::Buffer &Buf, const DataLayout &DL) {
76bdd1243dSDimitry Andric if (Buf.Constants.empty())
77bdd1243dSDimitry Andric return;
78bdd1243dSDimitry Andric
79bdd1243dSDimitry Andric std::vector<llvm::Type *> EltTys;
80bdd1243dSDimitry Andric for (auto &Const : Buf.Constants) {
81bdd1243dSDimitry Andric GlobalVariable *GV = Const.first;
82bdd1243dSDimitry Andric Const.second = EltTys.size();
83bdd1243dSDimitry Andric llvm::Type *Ty = GV->getValueType();
84bdd1243dSDimitry Andric EltTys.emplace_back(Ty);
85bdd1243dSDimitry Andric }
86bdd1243dSDimitry Andric Buf.LayoutStruct = llvm::StructType::get(EltTys[0]->getContext(), EltTys);
87bdd1243dSDimitry Andric }
88bdd1243dSDimitry Andric
replaceBuffer(CGHLSLRuntime::Buffer & Buf)89bdd1243dSDimitry Andric GlobalVariable *replaceBuffer(CGHLSLRuntime::Buffer &Buf) {
90bdd1243dSDimitry Andric // Create global variable for CB.
91bdd1243dSDimitry Andric GlobalVariable *CBGV = new GlobalVariable(
92bdd1243dSDimitry Andric Buf.LayoutStruct, /*isConstant*/ true,
93bdd1243dSDimitry Andric GlobalValue::LinkageTypes::ExternalLinkage, nullptr,
94bdd1243dSDimitry Andric llvm::formatv("{0}{1}", Buf.Name, Buf.IsCBuffer ? ".cb." : ".tb."),
95bdd1243dSDimitry Andric GlobalValue::NotThreadLocal);
96bdd1243dSDimitry Andric
97bdd1243dSDimitry Andric IRBuilder<> B(CBGV->getContext());
98bdd1243dSDimitry Andric Value *ZeroIdx = B.getInt32(0);
99bdd1243dSDimitry Andric // Replace Const use with CB use.
100bdd1243dSDimitry Andric for (auto &[GV, Offset] : Buf.Constants) {
101bdd1243dSDimitry Andric Value *GEP =
102bdd1243dSDimitry Andric B.CreateGEP(Buf.LayoutStruct, CBGV, {ZeroIdx, B.getInt32(Offset)});
103bdd1243dSDimitry Andric
104bdd1243dSDimitry Andric assert(Buf.LayoutStruct->getElementType(Offset) == GV->getValueType() &&
105bdd1243dSDimitry Andric "constant type mismatch");
106bdd1243dSDimitry Andric
107bdd1243dSDimitry Andric // Replace.
108bdd1243dSDimitry Andric GV->replaceAllUsesWith(GEP);
109bdd1243dSDimitry Andric // Erase GV.
110bdd1243dSDimitry Andric GV->removeDeadConstantUsers();
111bdd1243dSDimitry Andric GV->eraseFromParent();
112bdd1243dSDimitry Andric }
113bdd1243dSDimitry Andric return CBGV;
114bdd1243dSDimitry Andric }
115bdd1243dSDimitry Andric
11681ad6265SDimitry Andric } // namespace
11781ad6265SDimitry Andric
getArch()118*0fca6ea1SDimitry Andric llvm::Triple::ArchType CGHLSLRuntime::getArch() {
119*0fca6ea1SDimitry Andric return CGM.getTarget().getTriple().getArch();
120*0fca6ea1SDimitry Andric }
121*0fca6ea1SDimitry Andric
addConstant(VarDecl * D,Buffer & CB)122bdd1243dSDimitry Andric void CGHLSLRuntime::addConstant(VarDecl *D, Buffer &CB) {
123bdd1243dSDimitry Andric if (D->getStorageClass() == SC_Static) {
124bdd1243dSDimitry Andric // For static inside cbuffer, take as global static.
125bdd1243dSDimitry Andric // Don't add to cbuffer.
126bdd1243dSDimitry Andric CGM.EmitGlobal(D);
127bdd1243dSDimitry Andric return;
128bdd1243dSDimitry Andric }
129bdd1243dSDimitry Andric
130bdd1243dSDimitry Andric auto *GV = cast<GlobalVariable>(CGM.GetAddrOfGlobalVar(D));
131bdd1243dSDimitry Andric // Add debug info for constVal.
132bdd1243dSDimitry Andric if (CGDebugInfo *DI = CGM.getModuleDebugInfo())
133bdd1243dSDimitry Andric if (CGM.getCodeGenOpts().getDebugInfo() >=
134bdd1243dSDimitry Andric codegenoptions::DebugInfoKind::LimitedDebugInfo)
135bdd1243dSDimitry Andric DI->EmitGlobalVariable(cast<GlobalVariable>(GV), D);
136bdd1243dSDimitry Andric
137bdd1243dSDimitry Andric // FIXME: support packoffset.
138bdd1243dSDimitry Andric // See https://github.com/llvm/llvm-project/issues/57914.
139bdd1243dSDimitry Andric uint32_t Offset = 0;
140bdd1243dSDimitry Andric bool HasUserOffset = false;
141bdd1243dSDimitry Andric
142bdd1243dSDimitry Andric unsigned LowerBound = HasUserOffset ? Offset : UINT_MAX;
143bdd1243dSDimitry Andric CB.Constants.emplace_back(std::make_pair(GV, LowerBound));
144bdd1243dSDimitry Andric }
145bdd1243dSDimitry Andric
addBufferDecls(const DeclContext * DC,Buffer & CB)146bdd1243dSDimitry Andric void CGHLSLRuntime::addBufferDecls(const DeclContext *DC, Buffer &CB) {
147bdd1243dSDimitry Andric for (Decl *it : DC->decls()) {
148bdd1243dSDimitry Andric if (auto *ConstDecl = dyn_cast<VarDecl>(it)) {
149bdd1243dSDimitry Andric addConstant(ConstDecl, CB);
150bdd1243dSDimitry Andric } else if (isa<CXXRecordDecl, EmptyDecl>(it)) {
151bdd1243dSDimitry Andric // Nothing to do for this declaration.
152bdd1243dSDimitry Andric } else if (isa<FunctionDecl>(it)) {
153bdd1243dSDimitry Andric // A function within an cbuffer is effectively a top-level function,
154bdd1243dSDimitry Andric // as it only refers to globally scoped declarations.
155bdd1243dSDimitry Andric CGM.EmitTopLevelDecl(it);
156bdd1243dSDimitry Andric }
157bdd1243dSDimitry Andric }
158bdd1243dSDimitry Andric }
159bdd1243dSDimitry Andric
addBuffer(const HLSLBufferDecl * D)160bdd1243dSDimitry Andric void CGHLSLRuntime::addBuffer(const HLSLBufferDecl *D) {
161bdd1243dSDimitry Andric Buffers.emplace_back(Buffer(D));
162bdd1243dSDimitry Andric addBufferDecls(D, Buffers.back());
163bdd1243dSDimitry Andric }
164bdd1243dSDimitry Andric
finishCodeGen()16581ad6265SDimitry Andric void CGHLSLRuntime::finishCodeGen() {
16681ad6265SDimitry Andric auto &TargetOpts = CGM.getTarget().getTargetOpts();
16781ad6265SDimitry Andric llvm::Module &M = CGM.getModule();
168bdd1243dSDimitry Andric Triple T(M.getTargetTriple());
169bdd1243dSDimitry Andric if (T.getArch() == Triple::ArchType::dxil)
17081ad6265SDimitry Andric addDxilValVersion(TargetOpts.DxilValidatorVersion, M);
171bdd1243dSDimitry Andric
172bdd1243dSDimitry Andric generateGlobalCtorDtorCalls();
173bdd1243dSDimitry Andric if (CGM.getCodeGenOpts().OptimizationLevel == 0)
174bdd1243dSDimitry Andric addDisableOptimizations(M);
175bdd1243dSDimitry Andric
176bdd1243dSDimitry Andric const DataLayout &DL = M.getDataLayout();
177bdd1243dSDimitry Andric
178bdd1243dSDimitry Andric for (auto &Buf : Buffers) {
179bdd1243dSDimitry Andric layoutBuffer(Buf, DL);
180bdd1243dSDimitry Andric GlobalVariable *GV = replaceBuffer(Buf);
18106c3fb27SDimitry Andric M.insertGlobalVariable(GV);
182bdd1243dSDimitry Andric llvm::hlsl::ResourceClass RC = Buf.IsCBuffer
183bdd1243dSDimitry Andric ? llvm::hlsl::ResourceClass::CBuffer
184bdd1243dSDimitry Andric : llvm::hlsl::ResourceClass::SRV;
185bdd1243dSDimitry Andric llvm::hlsl::ResourceKind RK = Buf.IsCBuffer
186bdd1243dSDimitry Andric ? llvm::hlsl::ResourceKind::CBuffer
187bdd1243dSDimitry Andric : llvm::hlsl::ResourceKind::TBuffer;
188cb14a3feSDimitry Andric addBufferResourceAnnotation(GV, RC, RK, /*IsROV=*/false,
189cb14a3feSDimitry Andric llvm::hlsl::ElementType::Invalid, Buf.Binding);
190bdd1243dSDimitry Andric }
191bdd1243dSDimitry Andric }
192bdd1243dSDimitry Andric
Buffer(const HLSLBufferDecl * D)193bdd1243dSDimitry Andric CGHLSLRuntime::Buffer::Buffer(const HLSLBufferDecl *D)
194bdd1243dSDimitry Andric : Name(D->getName()), IsCBuffer(D->isCBuffer()),
195bdd1243dSDimitry Andric Binding(D->getAttr<HLSLResourceBindingAttr>()) {}
196bdd1243dSDimitry Andric
addBufferResourceAnnotation(llvm::GlobalVariable * GV,llvm::hlsl::ResourceClass RC,llvm::hlsl::ResourceKind RK,bool IsROV,llvm::hlsl::ElementType ET,BufferResBinding & Binding)197bdd1243dSDimitry Andric void CGHLSLRuntime::addBufferResourceAnnotation(llvm::GlobalVariable *GV,
198bdd1243dSDimitry Andric llvm::hlsl::ResourceClass RC,
199bdd1243dSDimitry Andric llvm::hlsl::ResourceKind RK,
2005f757f3fSDimitry Andric bool IsROV,
201cb14a3feSDimitry Andric llvm::hlsl::ElementType ET,
202bdd1243dSDimitry Andric BufferResBinding &Binding) {
203bdd1243dSDimitry Andric llvm::Module &M = CGM.getModule();
204bdd1243dSDimitry Andric
205bdd1243dSDimitry Andric NamedMDNode *ResourceMD = nullptr;
206bdd1243dSDimitry Andric switch (RC) {
207bdd1243dSDimitry Andric case llvm::hlsl::ResourceClass::UAV:
208bdd1243dSDimitry Andric ResourceMD = M.getOrInsertNamedMetadata("hlsl.uavs");
209bdd1243dSDimitry Andric break;
210bdd1243dSDimitry Andric case llvm::hlsl::ResourceClass::SRV:
211bdd1243dSDimitry Andric ResourceMD = M.getOrInsertNamedMetadata("hlsl.srvs");
212bdd1243dSDimitry Andric break;
213bdd1243dSDimitry Andric case llvm::hlsl::ResourceClass::CBuffer:
214bdd1243dSDimitry Andric ResourceMD = M.getOrInsertNamedMetadata("hlsl.cbufs");
215bdd1243dSDimitry Andric break;
216bdd1243dSDimitry Andric default:
217bdd1243dSDimitry Andric assert(false && "Unsupported buffer type!");
218bdd1243dSDimitry Andric return;
219bdd1243dSDimitry Andric }
220bdd1243dSDimitry Andric assert(ResourceMD != nullptr &&
221bdd1243dSDimitry Andric "ResourceMD must have been set by the switch above.");
222bdd1243dSDimitry Andric
223bdd1243dSDimitry Andric llvm::hlsl::FrontendResource Res(
224cb14a3feSDimitry Andric GV, RK, ET, IsROV, Binding.Reg.value_or(UINT_MAX), Binding.Space);
225bdd1243dSDimitry Andric ResourceMD->addOperand(Res.getMetadata());
226bdd1243dSDimitry Andric }
227bdd1243dSDimitry Andric
228cb14a3feSDimitry Andric static llvm::hlsl::ElementType
calculateElementType(const ASTContext & Context,const clang::Type * ResourceTy)229cb14a3feSDimitry Andric calculateElementType(const ASTContext &Context, const clang::Type *ResourceTy) {
230cb14a3feSDimitry Andric using llvm::hlsl::ElementType;
231cb14a3feSDimitry Andric
232cb14a3feSDimitry Andric // TODO: We may need to update this when we add things like ByteAddressBuffer
233cb14a3feSDimitry Andric // that don't have a template parameter (or, indeed, an element type).
234cb14a3feSDimitry Andric const auto *TST = ResourceTy->getAs<TemplateSpecializationType>();
235cb14a3feSDimitry Andric assert(TST && "Resource types must be template specializations");
236cb14a3feSDimitry Andric ArrayRef<TemplateArgument> Args = TST->template_arguments();
237cb14a3feSDimitry Andric assert(!Args.empty() && "Resource has no element type");
238cb14a3feSDimitry Andric
239cb14a3feSDimitry Andric // At this point we have a resource with an element type, so we can assume
240cb14a3feSDimitry Andric // that it's valid or we would have diagnosed the error earlier.
241cb14a3feSDimitry Andric QualType ElTy = Args[0].getAsType();
242cb14a3feSDimitry Andric
243cb14a3feSDimitry Andric // We should either have a basic type or a vector of a basic type.
244cb14a3feSDimitry Andric if (const auto *VecTy = ElTy->getAs<clang::VectorType>())
245cb14a3feSDimitry Andric ElTy = VecTy->getElementType();
246cb14a3feSDimitry Andric
247cb14a3feSDimitry Andric if (ElTy->isSignedIntegerType()) {
248cb14a3feSDimitry Andric switch (Context.getTypeSize(ElTy)) {
249cb14a3feSDimitry Andric case 16:
250cb14a3feSDimitry Andric return ElementType::I16;
251cb14a3feSDimitry Andric case 32:
252cb14a3feSDimitry Andric return ElementType::I32;
253cb14a3feSDimitry Andric case 64:
254cb14a3feSDimitry Andric return ElementType::I64;
255cb14a3feSDimitry Andric }
256cb14a3feSDimitry Andric } else if (ElTy->isUnsignedIntegerType()) {
257cb14a3feSDimitry Andric switch (Context.getTypeSize(ElTy)) {
258cb14a3feSDimitry Andric case 16:
259cb14a3feSDimitry Andric return ElementType::U16;
260cb14a3feSDimitry Andric case 32:
261cb14a3feSDimitry Andric return ElementType::U32;
262cb14a3feSDimitry Andric case 64:
263cb14a3feSDimitry Andric return ElementType::U64;
264cb14a3feSDimitry Andric }
265cb14a3feSDimitry Andric } else if (ElTy->isSpecificBuiltinType(BuiltinType::Half))
266cb14a3feSDimitry Andric return ElementType::F16;
267cb14a3feSDimitry Andric else if (ElTy->isSpecificBuiltinType(BuiltinType::Float))
268cb14a3feSDimitry Andric return ElementType::F32;
269cb14a3feSDimitry Andric else if (ElTy->isSpecificBuiltinType(BuiltinType::Double))
270cb14a3feSDimitry Andric return ElementType::F64;
271cb14a3feSDimitry Andric
272cb14a3feSDimitry Andric // TODO: We need to handle unorm/snorm float types here once we support them
273cb14a3feSDimitry Andric llvm_unreachable("Invalid element type for resource");
274cb14a3feSDimitry Andric }
275cb14a3feSDimitry Andric
annotateHLSLResource(const VarDecl * D,GlobalVariable * GV)276bdd1243dSDimitry Andric void CGHLSLRuntime::annotateHLSLResource(const VarDecl *D, GlobalVariable *GV) {
277bdd1243dSDimitry Andric const Type *Ty = D->getType()->getPointeeOrArrayElementType();
278bdd1243dSDimitry Andric if (!Ty)
279bdd1243dSDimitry Andric return;
280bdd1243dSDimitry Andric const auto *RD = Ty->getAsCXXRecordDecl();
281bdd1243dSDimitry Andric if (!RD)
282bdd1243dSDimitry Andric return;
283*0fca6ea1SDimitry Andric const auto *HLSLResAttr = RD->getAttr<HLSLResourceAttr>();
284*0fca6ea1SDimitry Andric const auto *HLSLResClassAttr = RD->getAttr<HLSLResourceClassAttr>();
285*0fca6ea1SDimitry Andric if (!HLSLResAttr || !HLSLResClassAttr)
286bdd1243dSDimitry Andric return;
287bdd1243dSDimitry Andric
288*0fca6ea1SDimitry Andric llvm::hlsl::ResourceClass RC = HLSLResClassAttr->getResourceClass();
289*0fca6ea1SDimitry Andric llvm::hlsl::ResourceKind RK = HLSLResAttr->getResourceKind();
290*0fca6ea1SDimitry Andric bool IsROV = HLSLResAttr->getIsROV();
291cb14a3feSDimitry Andric llvm::hlsl::ElementType ET = calculateElementType(CGM.getContext(), Ty);
292bdd1243dSDimitry Andric
293bdd1243dSDimitry Andric BufferResBinding Binding(D->getAttr<HLSLResourceBindingAttr>());
294cb14a3feSDimitry Andric addBufferResourceAnnotation(GV, RC, RK, IsROV, ET, Binding);
295bdd1243dSDimitry Andric }
296bdd1243dSDimitry Andric
BufferResBinding(HLSLResourceBindingAttr * Binding)297bdd1243dSDimitry Andric CGHLSLRuntime::BufferResBinding::BufferResBinding(
298bdd1243dSDimitry Andric HLSLResourceBindingAttr *Binding) {
299bdd1243dSDimitry Andric if (Binding) {
300bdd1243dSDimitry Andric llvm::APInt RegInt(64, 0);
301bdd1243dSDimitry Andric Binding->getSlot().substr(1).getAsInteger(10, RegInt);
302bdd1243dSDimitry Andric Reg = RegInt.getLimitedValue();
303bdd1243dSDimitry Andric llvm::APInt SpaceInt(64, 0);
304bdd1243dSDimitry Andric Binding->getSpace().substr(5).getAsInteger(10, SpaceInt);
305bdd1243dSDimitry Andric Space = SpaceInt.getLimitedValue();
306bdd1243dSDimitry Andric } else {
307bdd1243dSDimitry Andric Space = 0;
308bdd1243dSDimitry Andric }
309bdd1243dSDimitry Andric }
310bdd1243dSDimitry Andric
setHLSLEntryAttributes(const FunctionDecl * FD,llvm::Function * Fn)311bdd1243dSDimitry Andric void clang::CodeGen::CGHLSLRuntime::setHLSLEntryAttributes(
312bdd1243dSDimitry Andric const FunctionDecl *FD, llvm::Function *Fn) {
313bdd1243dSDimitry Andric const auto *ShaderAttr = FD->getAttr<HLSLShaderAttr>();
314bdd1243dSDimitry Andric assert(ShaderAttr && "All entry functions must have a HLSLShaderAttr");
315bdd1243dSDimitry Andric const StringRef ShaderAttrKindStr = "hlsl.shader";
316bdd1243dSDimitry Andric Fn->addFnAttr(ShaderAttrKindStr,
317*0fca6ea1SDimitry Andric llvm::Triple::getEnvironmentTypeName(ShaderAttr->getType()));
318bdd1243dSDimitry Andric if (HLSLNumThreadsAttr *NumThreadsAttr = FD->getAttr<HLSLNumThreadsAttr>()) {
319bdd1243dSDimitry Andric const StringRef NumThreadsKindStr = "hlsl.numthreads";
320bdd1243dSDimitry Andric std::string NumThreadsStr =
321bdd1243dSDimitry Andric formatv("{0},{1},{2}", NumThreadsAttr->getX(), NumThreadsAttr->getY(),
322bdd1243dSDimitry Andric NumThreadsAttr->getZ());
323bdd1243dSDimitry Andric Fn->addFnAttr(NumThreadsKindStr, NumThreadsStr);
324bdd1243dSDimitry Andric }
325bdd1243dSDimitry Andric }
326bdd1243dSDimitry Andric
buildVectorInput(IRBuilder<> & B,Function * F,llvm::Type * Ty)327bdd1243dSDimitry Andric static Value *buildVectorInput(IRBuilder<> &B, Function *F, llvm::Type *Ty) {
328bdd1243dSDimitry Andric if (const auto *VT = dyn_cast<FixedVectorType>(Ty)) {
329bdd1243dSDimitry Andric Value *Result = PoisonValue::get(Ty);
330bdd1243dSDimitry Andric for (unsigned I = 0; I < VT->getNumElements(); ++I) {
331bdd1243dSDimitry Andric Value *Elt = B.CreateCall(F, {B.getInt32(I)});
332bdd1243dSDimitry Andric Result = B.CreateInsertElement(Result, Elt, I);
333bdd1243dSDimitry Andric }
334bdd1243dSDimitry Andric return Result;
335bdd1243dSDimitry Andric }
336bdd1243dSDimitry Andric return B.CreateCall(F, {B.getInt32(0)});
337bdd1243dSDimitry Andric }
338bdd1243dSDimitry Andric
emitInputSemantic(IRBuilder<> & B,const ParmVarDecl & D,llvm::Type * Ty)339bdd1243dSDimitry Andric llvm::Value *CGHLSLRuntime::emitInputSemantic(IRBuilder<> &B,
340bdd1243dSDimitry Andric const ParmVarDecl &D,
341bdd1243dSDimitry Andric llvm::Type *Ty) {
342bdd1243dSDimitry Andric assert(D.hasAttrs() && "Entry parameter missing annotation attribute!");
343bdd1243dSDimitry Andric if (D.hasAttr<HLSLSV_GroupIndexAttr>()) {
344bdd1243dSDimitry Andric llvm::Function *DxGroupIndex =
345bdd1243dSDimitry Andric CGM.getIntrinsic(Intrinsic::dx_flattened_thread_id_in_group);
346bdd1243dSDimitry Andric return B.CreateCall(FunctionCallee(DxGroupIndex));
347bdd1243dSDimitry Andric }
348bdd1243dSDimitry Andric if (D.hasAttr<HLSLSV_DispatchThreadIDAttr>()) {
349*0fca6ea1SDimitry Andric llvm::Function *ThreadIDIntrinsic =
350*0fca6ea1SDimitry Andric CGM.getIntrinsic(getThreadIdIntrinsic());
351*0fca6ea1SDimitry Andric return buildVectorInput(B, ThreadIDIntrinsic, Ty);
352bdd1243dSDimitry Andric }
353bdd1243dSDimitry Andric assert(false && "Unhandled parameter attribute");
354bdd1243dSDimitry Andric return nullptr;
355bdd1243dSDimitry Andric }
356bdd1243dSDimitry Andric
emitEntryFunction(const FunctionDecl * FD,llvm::Function * Fn)357bdd1243dSDimitry Andric void CGHLSLRuntime::emitEntryFunction(const FunctionDecl *FD,
358bdd1243dSDimitry Andric llvm::Function *Fn) {
359bdd1243dSDimitry Andric llvm::Module &M = CGM.getModule();
360bdd1243dSDimitry Andric llvm::LLVMContext &Ctx = M.getContext();
361bdd1243dSDimitry Andric auto *EntryTy = llvm::FunctionType::get(llvm::Type::getVoidTy(Ctx), false);
362bdd1243dSDimitry Andric Function *EntryFn =
363bdd1243dSDimitry Andric Function::Create(EntryTy, Function::ExternalLinkage, FD->getName(), &M);
364bdd1243dSDimitry Andric
365bdd1243dSDimitry Andric // Copy function attributes over, we have no argument or return attributes
366bdd1243dSDimitry Andric // that can be valid on the real entry.
367bdd1243dSDimitry Andric AttributeList NewAttrs = AttributeList::get(Ctx, AttributeList::FunctionIndex,
368bdd1243dSDimitry Andric Fn->getAttributes().getFnAttrs());
369bdd1243dSDimitry Andric EntryFn->setAttributes(NewAttrs);
370bdd1243dSDimitry Andric setHLSLEntryAttributes(FD, EntryFn);
371bdd1243dSDimitry Andric
372bdd1243dSDimitry Andric // Set the called function as internal linkage.
373bdd1243dSDimitry Andric Fn->setLinkage(GlobalValue::InternalLinkage);
374bdd1243dSDimitry Andric
375bdd1243dSDimitry Andric BasicBlock *BB = BasicBlock::Create(Ctx, "entry", EntryFn);
376bdd1243dSDimitry Andric IRBuilder<> B(BB);
377bdd1243dSDimitry Andric llvm::SmallVector<Value *> Args;
378bdd1243dSDimitry Andric // FIXME: support struct parameters where semantics are on members.
379bdd1243dSDimitry Andric // See: https://github.com/llvm/llvm-project/issues/57874
380bdd1243dSDimitry Andric unsigned SRetOffset = 0;
381bdd1243dSDimitry Andric for (const auto &Param : Fn->args()) {
382bdd1243dSDimitry Andric if (Param.hasStructRetAttr()) {
383bdd1243dSDimitry Andric // FIXME: support output.
384bdd1243dSDimitry Andric // See: https://github.com/llvm/llvm-project/issues/57874
385bdd1243dSDimitry Andric SRetOffset = 1;
386bdd1243dSDimitry Andric Args.emplace_back(PoisonValue::get(Param.getType()));
387bdd1243dSDimitry Andric continue;
388bdd1243dSDimitry Andric }
389bdd1243dSDimitry Andric const ParmVarDecl *PD = FD->getParamDecl(Param.getArgNo() - SRetOffset);
390bdd1243dSDimitry Andric Args.push_back(emitInputSemantic(B, *PD, Param.getType()));
391bdd1243dSDimitry Andric }
392bdd1243dSDimitry Andric
393bdd1243dSDimitry Andric CallInst *CI = B.CreateCall(FunctionCallee(Fn), Args);
394bdd1243dSDimitry Andric (void)CI;
395bdd1243dSDimitry Andric // FIXME: Handle codegen for return type semantics.
396bdd1243dSDimitry Andric // See: https://github.com/llvm/llvm-project/issues/57875
397bdd1243dSDimitry Andric B.CreateRetVoid();
398bdd1243dSDimitry Andric }
399bdd1243dSDimitry Andric
gatherFunctions(SmallVectorImpl<Function * > & Fns,llvm::Module & M,bool CtorOrDtor)400bdd1243dSDimitry Andric static void gatherFunctions(SmallVectorImpl<Function *> &Fns, llvm::Module &M,
401bdd1243dSDimitry Andric bool CtorOrDtor) {
402bdd1243dSDimitry Andric const auto *GV =
403bdd1243dSDimitry Andric M.getNamedGlobal(CtorOrDtor ? "llvm.global_ctors" : "llvm.global_dtors");
404bdd1243dSDimitry Andric if (!GV)
405bdd1243dSDimitry Andric return;
406bdd1243dSDimitry Andric const auto *CA = dyn_cast<ConstantArray>(GV->getInitializer());
407bdd1243dSDimitry Andric if (!CA)
408bdd1243dSDimitry Andric return;
409bdd1243dSDimitry Andric // The global_ctor array elements are a struct [Priority, Fn *, COMDat].
410bdd1243dSDimitry Andric // HLSL neither supports priorities or COMDat values, so we will check those
411bdd1243dSDimitry Andric // in an assert but not handle them.
412bdd1243dSDimitry Andric
413bdd1243dSDimitry Andric llvm::SmallVector<Function *> CtorFns;
414bdd1243dSDimitry Andric for (const auto &Ctor : CA->operands()) {
415bdd1243dSDimitry Andric if (isa<ConstantAggregateZero>(Ctor))
416bdd1243dSDimitry Andric continue;
417bdd1243dSDimitry Andric ConstantStruct *CS = cast<ConstantStruct>(Ctor);
418bdd1243dSDimitry Andric
419bdd1243dSDimitry Andric assert(cast<ConstantInt>(CS->getOperand(0))->getValue() == 65535 &&
420bdd1243dSDimitry Andric "HLSL doesn't support setting priority for global ctors.");
421bdd1243dSDimitry Andric assert(isa<ConstantPointerNull>(CS->getOperand(2)) &&
422bdd1243dSDimitry Andric "HLSL doesn't support COMDat for global ctors.");
423bdd1243dSDimitry Andric Fns.push_back(cast<Function>(CS->getOperand(1)));
424bdd1243dSDimitry Andric }
425bdd1243dSDimitry Andric }
426bdd1243dSDimitry Andric
generateGlobalCtorDtorCalls()427bdd1243dSDimitry Andric void CGHLSLRuntime::generateGlobalCtorDtorCalls() {
428bdd1243dSDimitry Andric llvm::Module &M = CGM.getModule();
429bdd1243dSDimitry Andric SmallVector<Function *> CtorFns;
430bdd1243dSDimitry Andric SmallVector<Function *> DtorFns;
431bdd1243dSDimitry Andric gatherFunctions(CtorFns, M, true);
432bdd1243dSDimitry Andric gatherFunctions(DtorFns, M, false);
433bdd1243dSDimitry Andric
434bdd1243dSDimitry Andric // Insert a call to the global constructor at the beginning of the entry block
435bdd1243dSDimitry Andric // to externally exported functions. This is a bit of a hack, but HLSL allows
436bdd1243dSDimitry Andric // global constructors, but doesn't support driver initialization of globals.
437bdd1243dSDimitry Andric for (auto &F : M.functions()) {
438bdd1243dSDimitry Andric if (!F.hasFnAttribute("hlsl.shader"))
439bdd1243dSDimitry Andric continue;
440bdd1243dSDimitry Andric IRBuilder<> B(&F.getEntryBlock(), F.getEntryBlock().begin());
441bdd1243dSDimitry Andric for (auto *Fn : CtorFns)
442bdd1243dSDimitry Andric B.CreateCall(FunctionCallee(Fn));
443bdd1243dSDimitry Andric
444bdd1243dSDimitry Andric // Insert global dtors before the terminator of the last instruction
445bdd1243dSDimitry Andric B.SetInsertPoint(F.back().getTerminator());
446bdd1243dSDimitry Andric for (auto *Fn : DtorFns)
447bdd1243dSDimitry Andric B.CreateCall(FunctionCallee(Fn));
448bdd1243dSDimitry Andric }
449bdd1243dSDimitry Andric
450bdd1243dSDimitry Andric // No need to keep global ctors/dtors for non-lib profile after call to
451bdd1243dSDimitry Andric // ctors/dtors added for entry.
452bdd1243dSDimitry Andric Triple T(M.getTargetTriple());
453bdd1243dSDimitry Andric if (T.getEnvironment() != Triple::EnvironmentType::Library) {
454bdd1243dSDimitry Andric if (auto *GV = M.getNamedGlobal("llvm.global_ctors"))
455bdd1243dSDimitry Andric GV->eraseFromParent();
456bdd1243dSDimitry Andric if (auto *GV = M.getNamedGlobal("llvm.global_dtors"))
457bdd1243dSDimitry Andric GV->eraseFromParent();
458bdd1243dSDimitry Andric }
45981ad6265SDimitry Andric }
460