xref: /freebsd/contrib/llvm-project/clang/lib/CodeGen/CGHLSLRuntime.cpp (revision bdd1243df58e60e85101c09001d9812a789b6bc4)
181ad6265SDimitry Andric //===----- CGHLSLRuntime.cpp - Interface to HLSL Runtimes -----------------===//
281ad6265SDimitry Andric //
381ad6265SDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
481ad6265SDimitry Andric // See https://llvm.org/LICENSE.txt for license information.
581ad6265SDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
681ad6265SDimitry Andric //
781ad6265SDimitry Andric //===----------------------------------------------------------------------===//
881ad6265SDimitry Andric //
981ad6265SDimitry Andric // This provides an abstract class for HLSL code generation.  Concrete
1081ad6265SDimitry Andric // subclasses of this implement code generation for specific HLSL
1181ad6265SDimitry Andric // runtime libraries.
1281ad6265SDimitry Andric //
1381ad6265SDimitry Andric //===----------------------------------------------------------------------===//
1481ad6265SDimitry Andric 
1581ad6265SDimitry Andric #include "CGHLSLRuntime.h"
16*bdd1243dSDimitry Andric #include "CGDebugInfo.h"
1781ad6265SDimitry Andric #include "CodeGenModule.h"
18*bdd1243dSDimitry Andric #include "clang/AST/Decl.h"
1981ad6265SDimitry Andric #include "clang/Basic/TargetOptions.h"
20*bdd1243dSDimitry Andric #include "llvm/IR/IntrinsicsDirectX.h"
2181ad6265SDimitry Andric #include "llvm/IR/Metadata.h"
2281ad6265SDimitry Andric #include "llvm/IR/Module.h"
23*bdd1243dSDimitry Andric #include "llvm/Support/FormatVariadic.h"
2481ad6265SDimitry Andric 
2581ad6265SDimitry Andric using namespace clang;
2681ad6265SDimitry Andric using namespace CodeGen;
27*bdd1243dSDimitry Andric using namespace clang::hlsl;
2881ad6265SDimitry Andric using namespace llvm;
2981ad6265SDimitry Andric 
3081ad6265SDimitry Andric namespace {
31*bdd1243dSDimitry Andric 
3281ad6265SDimitry Andric void addDxilValVersion(StringRef ValVersionStr, llvm::Module &M) {
3381ad6265SDimitry Andric   // The validation of ValVersionStr is done at HLSLToolChain::TranslateArgs.
3481ad6265SDimitry Andric   // Assume ValVersionStr is legal here.
3581ad6265SDimitry Andric   VersionTuple Version;
3681ad6265SDimitry Andric   if (Version.tryParse(ValVersionStr) || Version.getBuild() ||
3781ad6265SDimitry Andric       Version.getSubminor() || !Version.getMinor()) {
3881ad6265SDimitry Andric     return;
3981ad6265SDimitry Andric   }
4081ad6265SDimitry Andric 
4181ad6265SDimitry Andric   uint64_t Major = Version.getMajor();
4281ad6265SDimitry Andric   uint64_t Minor = *Version.getMinor();
4381ad6265SDimitry Andric 
4481ad6265SDimitry Andric   auto &Ctx = M.getContext();
4581ad6265SDimitry Andric   IRBuilder<> B(M.getContext());
4681ad6265SDimitry Andric   MDNode *Val = MDNode::get(Ctx, {ConstantAsMetadata::get(B.getInt32(Major)),
4781ad6265SDimitry Andric                                   ConstantAsMetadata::get(B.getInt32(Minor))});
48*bdd1243dSDimitry Andric   StringRef DXILValKey = "dx.valver";
49*bdd1243dSDimitry Andric   auto *DXILValMD = M.getOrInsertNamedMetadata(DXILValKey);
50*bdd1243dSDimitry Andric   DXILValMD->addOperand(Val);
5181ad6265SDimitry Andric }
52*bdd1243dSDimitry Andric void addDisableOptimizations(llvm::Module &M) {
53*bdd1243dSDimitry Andric   StringRef Key = "dx.disable_optimizations";
54*bdd1243dSDimitry Andric   M.addModuleFlag(llvm::Module::ModFlagBehavior::Override, Key, 1);
55*bdd1243dSDimitry Andric }
// A cbuffer is translated into a global variable in a special address space.
// Expressed in C, the HLSL code
//
// cbuffer A {
//   float a;
//   float b;
// }
// float foo() { return a + b; }
//
// is translated into
//
// struct A {
//   float a;
//   float b;
// } cbuffer_A __attribute__((address_space(4)));
// float foo() { return cbuffer_A.a + cbuffer_A.b; }
//
// layoutBuffer creates the struct A type.
// replaceBuffer replaces uses of the global variables a and b with
// cbuffer_A.a and cbuffer_A.b.
//
76*bdd1243dSDimitry Andric void layoutBuffer(CGHLSLRuntime::Buffer &Buf, const DataLayout &DL) {
77*bdd1243dSDimitry Andric   if (Buf.Constants.empty())
78*bdd1243dSDimitry Andric     return;
79*bdd1243dSDimitry Andric 
80*bdd1243dSDimitry Andric   std::vector<llvm::Type *> EltTys;
81*bdd1243dSDimitry Andric   for (auto &Const : Buf.Constants) {
82*bdd1243dSDimitry Andric     GlobalVariable *GV = Const.first;
83*bdd1243dSDimitry Andric     Const.second = EltTys.size();
84*bdd1243dSDimitry Andric     llvm::Type *Ty = GV->getValueType();
85*bdd1243dSDimitry Andric     EltTys.emplace_back(Ty);
86*bdd1243dSDimitry Andric   }
87*bdd1243dSDimitry Andric   Buf.LayoutStruct = llvm::StructType::get(EltTys[0]->getContext(), EltTys);
88*bdd1243dSDimitry Andric }
89*bdd1243dSDimitry Andric 
90*bdd1243dSDimitry Andric GlobalVariable *replaceBuffer(CGHLSLRuntime::Buffer &Buf) {
91*bdd1243dSDimitry Andric   // Create global variable for CB.
92*bdd1243dSDimitry Andric   GlobalVariable *CBGV = new GlobalVariable(
93*bdd1243dSDimitry Andric       Buf.LayoutStruct, /*isConstant*/ true,
94*bdd1243dSDimitry Andric       GlobalValue::LinkageTypes::ExternalLinkage, nullptr,
95*bdd1243dSDimitry Andric       llvm::formatv("{0}{1}", Buf.Name, Buf.IsCBuffer ? ".cb." : ".tb."),
96*bdd1243dSDimitry Andric       GlobalValue::NotThreadLocal);
97*bdd1243dSDimitry Andric 
98*bdd1243dSDimitry Andric   IRBuilder<> B(CBGV->getContext());
99*bdd1243dSDimitry Andric   Value *ZeroIdx = B.getInt32(0);
100*bdd1243dSDimitry Andric   // Replace Const use with CB use.
101*bdd1243dSDimitry Andric   for (auto &[GV, Offset] : Buf.Constants) {
102*bdd1243dSDimitry Andric     Value *GEP =
103*bdd1243dSDimitry Andric         B.CreateGEP(Buf.LayoutStruct, CBGV, {ZeroIdx, B.getInt32(Offset)});
104*bdd1243dSDimitry Andric 
105*bdd1243dSDimitry Andric     assert(Buf.LayoutStruct->getElementType(Offset) == GV->getValueType() &&
106*bdd1243dSDimitry Andric            "constant type mismatch");
107*bdd1243dSDimitry Andric 
108*bdd1243dSDimitry Andric     // Replace.
109*bdd1243dSDimitry Andric     GV->replaceAllUsesWith(GEP);
110*bdd1243dSDimitry Andric     // Erase GV.
111*bdd1243dSDimitry Andric     GV->removeDeadConstantUsers();
112*bdd1243dSDimitry Andric     GV->eraseFromParent();
113*bdd1243dSDimitry Andric   }
114*bdd1243dSDimitry Andric   return CBGV;
115*bdd1243dSDimitry Andric }
116*bdd1243dSDimitry Andric 
11781ad6265SDimitry Andric } // namespace
11881ad6265SDimitry Andric 
119*bdd1243dSDimitry Andric void CGHLSLRuntime::addConstant(VarDecl *D, Buffer &CB) {
120*bdd1243dSDimitry Andric   if (D->getStorageClass() == SC_Static) {
121*bdd1243dSDimitry Andric     // For static inside cbuffer, take as global static.
122*bdd1243dSDimitry Andric     // Don't add to cbuffer.
123*bdd1243dSDimitry Andric     CGM.EmitGlobal(D);
124*bdd1243dSDimitry Andric     return;
125*bdd1243dSDimitry Andric   }
126*bdd1243dSDimitry Andric 
127*bdd1243dSDimitry Andric   auto *GV = cast<GlobalVariable>(CGM.GetAddrOfGlobalVar(D));
128*bdd1243dSDimitry Andric   // Add debug info for constVal.
129*bdd1243dSDimitry Andric   if (CGDebugInfo *DI = CGM.getModuleDebugInfo())
130*bdd1243dSDimitry Andric     if (CGM.getCodeGenOpts().getDebugInfo() >=
131*bdd1243dSDimitry Andric         codegenoptions::DebugInfoKind::LimitedDebugInfo)
132*bdd1243dSDimitry Andric       DI->EmitGlobalVariable(cast<GlobalVariable>(GV), D);
133*bdd1243dSDimitry Andric 
134*bdd1243dSDimitry Andric   // FIXME: support packoffset.
135*bdd1243dSDimitry Andric   // See https://github.com/llvm/llvm-project/issues/57914.
136*bdd1243dSDimitry Andric   uint32_t Offset = 0;
137*bdd1243dSDimitry Andric   bool HasUserOffset = false;
138*bdd1243dSDimitry Andric 
139*bdd1243dSDimitry Andric   unsigned LowerBound = HasUserOffset ? Offset : UINT_MAX;
140*bdd1243dSDimitry Andric   CB.Constants.emplace_back(std::make_pair(GV, LowerBound));
141*bdd1243dSDimitry Andric }
142*bdd1243dSDimitry Andric 
143*bdd1243dSDimitry Andric void CGHLSLRuntime::addBufferDecls(const DeclContext *DC, Buffer &CB) {
144*bdd1243dSDimitry Andric   for (Decl *it : DC->decls()) {
145*bdd1243dSDimitry Andric     if (auto *ConstDecl = dyn_cast<VarDecl>(it)) {
146*bdd1243dSDimitry Andric       addConstant(ConstDecl, CB);
147*bdd1243dSDimitry Andric     } else if (isa<CXXRecordDecl, EmptyDecl>(it)) {
148*bdd1243dSDimitry Andric       // Nothing to do for this declaration.
149*bdd1243dSDimitry Andric     } else if (isa<FunctionDecl>(it)) {
150*bdd1243dSDimitry Andric       // A function within an cbuffer is effectively a top-level function,
151*bdd1243dSDimitry Andric       // as it only refers to globally scoped declarations.
152*bdd1243dSDimitry Andric       CGM.EmitTopLevelDecl(it);
153*bdd1243dSDimitry Andric     }
154*bdd1243dSDimitry Andric   }
155*bdd1243dSDimitry Andric }
156*bdd1243dSDimitry Andric 
157*bdd1243dSDimitry Andric void CGHLSLRuntime::addBuffer(const HLSLBufferDecl *D) {
158*bdd1243dSDimitry Andric   Buffers.emplace_back(Buffer(D));
159*bdd1243dSDimitry Andric   addBufferDecls(D, Buffers.back());
160*bdd1243dSDimitry Andric }
161*bdd1243dSDimitry Andric 
16281ad6265SDimitry Andric void CGHLSLRuntime::finishCodeGen() {
16381ad6265SDimitry Andric   auto &TargetOpts = CGM.getTarget().getTargetOpts();
16481ad6265SDimitry Andric   llvm::Module &M = CGM.getModule();
165*bdd1243dSDimitry Andric   Triple T(M.getTargetTriple());
166*bdd1243dSDimitry Andric   if (T.getArch() == Triple::ArchType::dxil)
16781ad6265SDimitry Andric     addDxilValVersion(TargetOpts.DxilValidatorVersion, M);
168*bdd1243dSDimitry Andric 
169*bdd1243dSDimitry Andric   generateGlobalCtorDtorCalls();
170*bdd1243dSDimitry Andric   if (CGM.getCodeGenOpts().OptimizationLevel == 0)
171*bdd1243dSDimitry Andric     addDisableOptimizations(M);
172*bdd1243dSDimitry Andric 
173*bdd1243dSDimitry Andric   const DataLayout &DL = M.getDataLayout();
174*bdd1243dSDimitry Andric 
175*bdd1243dSDimitry Andric   for (auto &Buf : Buffers) {
176*bdd1243dSDimitry Andric     layoutBuffer(Buf, DL);
177*bdd1243dSDimitry Andric     GlobalVariable *GV = replaceBuffer(Buf);
178*bdd1243dSDimitry Andric     M.getGlobalList().push_back(GV);
179*bdd1243dSDimitry Andric     llvm::hlsl::ResourceClass RC = Buf.IsCBuffer
180*bdd1243dSDimitry Andric                                        ? llvm::hlsl::ResourceClass::CBuffer
181*bdd1243dSDimitry Andric                                        : llvm::hlsl::ResourceClass::SRV;
182*bdd1243dSDimitry Andric     llvm::hlsl::ResourceKind RK = Buf.IsCBuffer
183*bdd1243dSDimitry Andric                                       ? llvm::hlsl::ResourceKind::CBuffer
184*bdd1243dSDimitry Andric                                       : llvm::hlsl::ResourceKind::TBuffer;
185*bdd1243dSDimitry Andric     std::string TyName =
186*bdd1243dSDimitry Andric         Buf.Name.str() + (Buf.IsCBuffer ? ".cb." : ".tb.") + "ty";
187*bdd1243dSDimitry Andric     addBufferResourceAnnotation(GV, TyName, RC, RK, Buf.Binding);
188*bdd1243dSDimitry Andric   }
189*bdd1243dSDimitry Andric }
190*bdd1243dSDimitry Andric 
191*bdd1243dSDimitry Andric CGHLSLRuntime::Buffer::Buffer(const HLSLBufferDecl *D)
192*bdd1243dSDimitry Andric     : Name(D->getName()), IsCBuffer(D->isCBuffer()),
193*bdd1243dSDimitry Andric       Binding(D->getAttr<HLSLResourceBindingAttr>()) {}
194*bdd1243dSDimitry Andric 
195*bdd1243dSDimitry Andric void CGHLSLRuntime::addBufferResourceAnnotation(llvm::GlobalVariable *GV,
196*bdd1243dSDimitry Andric                                                 llvm::StringRef TyName,
197*bdd1243dSDimitry Andric                                                 llvm::hlsl::ResourceClass RC,
198*bdd1243dSDimitry Andric                                                 llvm::hlsl::ResourceKind RK,
199*bdd1243dSDimitry Andric                                                 BufferResBinding &Binding) {
200*bdd1243dSDimitry Andric   llvm::Module &M = CGM.getModule();
201*bdd1243dSDimitry Andric 
202*bdd1243dSDimitry Andric   NamedMDNode *ResourceMD = nullptr;
203*bdd1243dSDimitry Andric   switch (RC) {
204*bdd1243dSDimitry Andric   case llvm::hlsl::ResourceClass::UAV:
205*bdd1243dSDimitry Andric     ResourceMD = M.getOrInsertNamedMetadata("hlsl.uavs");
206*bdd1243dSDimitry Andric     break;
207*bdd1243dSDimitry Andric   case llvm::hlsl::ResourceClass::SRV:
208*bdd1243dSDimitry Andric     ResourceMD = M.getOrInsertNamedMetadata("hlsl.srvs");
209*bdd1243dSDimitry Andric     break;
210*bdd1243dSDimitry Andric   case llvm::hlsl::ResourceClass::CBuffer:
211*bdd1243dSDimitry Andric     ResourceMD = M.getOrInsertNamedMetadata("hlsl.cbufs");
212*bdd1243dSDimitry Andric     break;
213*bdd1243dSDimitry Andric   default:
214*bdd1243dSDimitry Andric     assert(false && "Unsupported buffer type!");
215*bdd1243dSDimitry Andric     return;
216*bdd1243dSDimitry Andric   }
217*bdd1243dSDimitry Andric 
218*bdd1243dSDimitry Andric   assert(ResourceMD != nullptr &&
219*bdd1243dSDimitry Andric          "ResourceMD must have been set by the switch above.");
220*bdd1243dSDimitry Andric 
221*bdd1243dSDimitry Andric   llvm::hlsl::FrontendResource Res(
222*bdd1243dSDimitry Andric       GV, TyName, RK, Binding.Reg.value_or(UINT_MAX), Binding.Space);
223*bdd1243dSDimitry Andric   ResourceMD->addOperand(Res.getMetadata());
224*bdd1243dSDimitry Andric }
225*bdd1243dSDimitry Andric 
226*bdd1243dSDimitry Andric static llvm::hlsl::ResourceKind
227*bdd1243dSDimitry Andric castResourceShapeToResourceKind(HLSLResourceAttr::ResourceKind RK) {
228*bdd1243dSDimitry Andric   switch (RK) {
229*bdd1243dSDimitry Andric   case HLSLResourceAttr::ResourceKind::Texture1D:
230*bdd1243dSDimitry Andric     return llvm::hlsl::ResourceKind::Texture1D;
231*bdd1243dSDimitry Andric   case HLSLResourceAttr::ResourceKind::Texture2D:
232*bdd1243dSDimitry Andric     return llvm::hlsl::ResourceKind::Texture2D;
233*bdd1243dSDimitry Andric   case HLSLResourceAttr::ResourceKind::Texture2DMS:
234*bdd1243dSDimitry Andric     return llvm::hlsl::ResourceKind::Texture2DMS;
235*bdd1243dSDimitry Andric   case HLSLResourceAttr::ResourceKind::Texture3D:
236*bdd1243dSDimitry Andric     return llvm::hlsl::ResourceKind::Texture3D;
237*bdd1243dSDimitry Andric   case HLSLResourceAttr::ResourceKind::TextureCube:
238*bdd1243dSDimitry Andric     return llvm::hlsl::ResourceKind::TextureCube;
239*bdd1243dSDimitry Andric   case HLSLResourceAttr::ResourceKind::Texture1DArray:
240*bdd1243dSDimitry Andric     return llvm::hlsl::ResourceKind::Texture1DArray;
241*bdd1243dSDimitry Andric   case HLSLResourceAttr::ResourceKind::Texture2DArray:
242*bdd1243dSDimitry Andric     return llvm::hlsl::ResourceKind::Texture2DArray;
243*bdd1243dSDimitry Andric   case HLSLResourceAttr::ResourceKind::Texture2DMSArray:
244*bdd1243dSDimitry Andric     return llvm::hlsl::ResourceKind::Texture2DMSArray;
245*bdd1243dSDimitry Andric   case HLSLResourceAttr::ResourceKind::TextureCubeArray:
246*bdd1243dSDimitry Andric     return llvm::hlsl::ResourceKind::TextureCubeArray;
247*bdd1243dSDimitry Andric   case HLSLResourceAttr::ResourceKind::TypedBuffer:
248*bdd1243dSDimitry Andric     return llvm::hlsl::ResourceKind::TypedBuffer;
249*bdd1243dSDimitry Andric   case HLSLResourceAttr::ResourceKind::RawBuffer:
250*bdd1243dSDimitry Andric     return llvm::hlsl::ResourceKind::RawBuffer;
251*bdd1243dSDimitry Andric   case HLSLResourceAttr::ResourceKind::StructuredBuffer:
252*bdd1243dSDimitry Andric     return llvm::hlsl::ResourceKind::StructuredBuffer;
253*bdd1243dSDimitry Andric   case HLSLResourceAttr::ResourceKind::CBufferKind:
254*bdd1243dSDimitry Andric     return llvm::hlsl::ResourceKind::CBuffer;
255*bdd1243dSDimitry Andric   case HLSLResourceAttr::ResourceKind::SamplerKind:
256*bdd1243dSDimitry Andric     return llvm::hlsl::ResourceKind::Sampler;
257*bdd1243dSDimitry Andric   case HLSLResourceAttr::ResourceKind::TBuffer:
258*bdd1243dSDimitry Andric     return llvm::hlsl::ResourceKind::TBuffer;
259*bdd1243dSDimitry Andric   case HLSLResourceAttr::ResourceKind::RTAccelerationStructure:
260*bdd1243dSDimitry Andric     return llvm::hlsl::ResourceKind::RTAccelerationStructure;
261*bdd1243dSDimitry Andric   case HLSLResourceAttr::ResourceKind::FeedbackTexture2D:
262*bdd1243dSDimitry Andric     return llvm::hlsl::ResourceKind::FeedbackTexture2D;
263*bdd1243dSDimitry Andric   case HLSLResourceAttr::ResourceKind::FeedbackTexture2DArray:
264*bdd1243dSDimitry Andric     return llvm::hlsl::ResourceKind::FeedbackTexture2DArray;
265*bdd1243dSDimitry Andric   }
266*bdd1243dSDimitry Andric   // Make sure to update HLSLResourceAttr::ResourceKind when add new Kind to
267*bdd1243dSDimitry Andric   // hlsl::ResourceKind. Assume FeedbackTexture2DArray is the last enum for
268*bdd1243dSDimitry Andric   // HLSLResourceAttr::ResourceKind.
269*bdd1243dSDimitry Andric   static_assert(
270*bdd1243dSDimitry Andric       static_cast<uint32_t>(
271*bdd1243dSDimitry Andric           HLSLResourceAttr::ResourceKind::FeedbackTexture2DArray) ==
272*bdd1243dSDimitry Andric       (static_cast<uint32_t>(llvm::hlsl::ResourceKind::NumEntries) - 2));
273*bdd1243dSDimitry Andric   llvm_unreachable("all switch cases should be covered");
274*bdd1243dSDimitry Andric }
275*bdd1243dSDimitry Andric 
276*bdd1243dSDimitry Andric void CGHLSLRuntime::annotateHLSLResource(const VarDecl *D, GlobalVariable *GV) {
277*bdd1243dSDimitry Andric   const Type *Ty = D->getType()->getPointeeOrArrayElementType();
278*bdd1243dSDimitry Andric   if (!Ty)
279*bdd1243dSDimitry Andric     return;
280*bdd1243dSDimitry Andric   const auto *RD = Ty->getAsCXXRecordDecl();
281*bdd1243dSDimitry Andric   if (!RD)
282*bdd1243dSDimitry Andric     return;
283*bdd1243dSDimitry Andric   const auto *Attr = RD->getAttr<HLSLResourceAttr>();
284*bdd1243dSDimitry Andric   if (!Attr)
285*bdd1243dSDimitry Andric     return;
286*bdd1243dSDimitry Andric 
287*bdd1243dSDimitry Andric   HLSLResourceAttr::ResourceClass RC = Attr->getResourceType();
288*bdd1243dSDimitry Andric   llvm::hlsl::ResourceKind RK =
289*bdd1243dSDimitry Andric       castResourceShapeToResourceKind(Attr->getResourceShape());
290*bdd1243dSDimitry Andric 
291*bdd1243dSDimitry Andric   QualType QT(Ty, 0);
292*bdd1243dSDimitry Andric   BufferResBinding Binding(D->getAttr<HLSLResourceBindingAttr>());
293*bdd1243dSDimitry Andric   addBufferResourceAnnotation(GV, QT.getAsString(),
294*bdd1243dSDimitry Andric                               static_cast<llvm::hlsl::ResourceClass>(RC), RK,
295*bdd1243dSDimitry Andric                               Binding);
296*bdd1243dSDimitry Andric }
297*bdd1243dSDimitry Andric 
298*bdd1243dSDimitry Andric CGHLSLRuntime::BufferResBinding::BufferResBinding(
299*bdd1243dSDimitry Andric     HLSLResourceBindingAttr *Binding) {
300*bdd1243dSDimitry Andric   if (Binding) {
301*bdd1243dSDimitry Andric     llvm::APInt RegInt(64, 0);
302*bdd1243dSDimitry Andric     Binding->getSlot().substr(1).getAsInteger(10, RegInt);
303*bdd1243dSDimitry Andric     Reg = RegInt.getLimitedValue();
304*bdd1243dSDimitry Andric     llvm::APInt SpaceInt(64, 0);
305*bdd1243dSDimitry Andric     Binding->getSpace().substr(5).getAsInteger(10, SpaceInt);
306*bdd1243dSDimitry Andric     Space = SpaceInt.getLimitedValue();
307*bdd1243dSDimitry Andric   } else {
308*bdd1243dSDimitry Andric     Space = 0;
309*bdd1243dSDimitry Andric   }
310*bdd1243dSDimitry Andric }
311*bdd1243dSDimitry Andric 
312*bdd1243dSDimitry Andric void clang::CodeGen::CGHLSLRuntime::setHLSLEntryAttributes(
313*bdd1243dSDimitry Andric     const FunctionDecl *FD, llvm::Function *Fn) {
314*bdd1243dSDimitry Andric   const auto *ShaderAttr = FD->getAttr<HLSLShaderAttr>();
315*bdd1243dSDimitry Andric   assert(ShaderAttr && "All entry functions must have a HLSLShaderAttr");
316*bdd1243dSDimitry Andric   const StringRef ShaderAttrKindStr = "hlsl.shader";
317*bdd1243dSDimitry Andric   Fn->addFnAttr(ShaderAttrKindStr,
318*bdd1243dSDimitry Andric                 ShaderAttr->ConvertShaderTypeToStr(ShaderAttr->getType()));
319*bdd1243dSDimitry Andric   if (HLSLNumThreadsAttr *NumThreadsAttr = FD->getAttr<HLSLNumThreadsAttr>()) {
320*bdd1243dSDimitry Andric     const StringRef NumThreadsKindStr = "hlsl.numthreads";
321*bdd1243dSDimitry Andric     std::string NumThreadsStr =
322*bdd1243dSDimitry Andric         formatv("{0},{1},{2}", NumThreadsAttr->getX(), NumThreadsAttr->getY(),
323*bdd1243dSDimitry Andric                 NumThreadsAttr->getZ());
324*bdd1243dSDimitry Andric     Fn->addFnAttr(NumThreadsKindStr, NumThreadsStr);
325*bdd1243dSDimitry Andric   }
326*bdd1243dSDimitry Andric }
327*bdd1243dSDimitry Andric 
328*bdd1243dSDimitry Andric static Value *buildVectorInput(IRBuilder<> &B, Function *F, llvm::Type *Ty) {
329*bdd1243dSDimitry Andric   if (const auto *VT = dyn_cast<FixedVectorType>(Ty)) {
330*bdd1243dSDimitry Andric     Value *Result = PoisonValue::get(Ty);
331*bdd1243dSDimitry Andric     for (unsigned I = 0; I < VT->getNumElements(); ++I) {
332*bdd1243dSDimitry Andric       Value *Elt = B.CreateCall(F, {B.getInt32(I)});
333*bdd1243dSDimitry Andric       Result = B.CreateInsertElement(Result, Elt, I);
334*bdd1243dSDimitry Andric     }
335*bdd1243dSDimitry Andric     return Result;
336*bdd1243dSDimitry Andric   }
337*bdd1243dSDimitry Andric   return B.CreateCall(F, {B.getInt32(0)});
338*bdd1243dSDimitry Andric }
339*bdd1243dSDimitry Andric 
340*bdd1243dSDimitry Andric llvm::Value *CGHLSLRuntime::emitInputSemantic(IRBuilder<> &B,
341*bdd1243dSDimitry Andric                                               const ParmVarDecl &D,
342*bdd1243dSDimitry Andric                                               llvm::Type *Ty) {
343*bdd1243dSDimitry Andric   assert(D.hasAttrs() && "Entry parameter missing annotation attribute!");
344*bdd1243dSDimitry Andric   if (D.hasAttr<HLSLSV_GroupIndexAttr>()) {
345*bdd1243dSDimitry Andric     llvm::Function *DxGroupIndex =
346*bdd1243dSDimitry Andric         CGM.getIntrinsic(Intrinsic::dx_flattened_thread_id_in_group);
347*bdd1243dSDimitry Andric     return B.CreateCall(FunctionCallee(DxGroupIndex));
348*bdd1243dSDimitry Andric   }
349*bdd1243dSDimitry Andric   if (D.hasAttr<HLSLSV_DispatchThreadIDAttr>()) {
350*bdd1243dSDimitry Andric     llvm::Function *DxThreadID = CGM.getIntrinsic(Intrinsic::dx_thread_id);
351*bdd1243dSDimitry Andric     return buildVectorInput(B, DxThreadID, Ty);
352*bdd1243dSDimitry Andric   }
353*bdd1243dSDimitry Andric   assert(false && "Unhandled parameter attribute");
354*bdd1243dSDimitry Andric   return nullptr;
355*bdd1243dSDimitry Andric }
356*bdd1243dSDimitry Andric 
357*bdd1243dSDimitry Andric void CGHLSLRuntime::emitEntryFunction(const FunctionDecl *FD,
358*bdd1243dSDimitry Andric                                       llvm::Function *Fn) {
359*bdd1243dSDimitry Andric   llvm::Module &M = CGM.getModule();
360*bdd1243dSDimitry Andric   llvm::LLVMContext &Ctx = M.getContext();
361*bdd1243dSDimitry Andric   auto *EntryTy = llvm::FunctionType::get(llvm::Type::getVoidTy(Ctx), false);
362*bdd1243dSDimitry Andric   Function *EntryFn =
363*bdd1243dSDimitry Andric       Function::Create(EntryTy, Function::ExternalLinkage, FD->getName(), &M);
364*bdd1243dSDimitry Andric 
365*bdd1243dSDimitry Andric   // Copy function attributes over, we have no argument or return attributes
366*bdd1243dSDimitry Andric   // that can be valid on the real entry.
367*bdd1243dSDimitry Andric   AttributeList NewAttrs = AttributeList::get(Ctx, AttributeList::FunctionIndex,
368*bdd1243dSDimitry Andric                                               Fn->getAttributes().getFnAttrs());
369*bdd1243dSDimitry Andric   EntryFn->setAttributes(NewAttrs);
370*bdd1243dSDimitry Andric   setHLSLEntryAttributes(FD, EntryFn);
371*bdd1243dSDimitry Andric 
372*bdd1243dSDimitry Andric   // Set the called function as internal linkage.
373*bdd1243dSDimitry Andric   Fn->setLinkage(GlobalValue::InternalLinkage);
374*bdd1243dSDimitry Andric 
375*bdd1243dSDimitry Andric   BasicBlock *BB = BasicBlock::Create(Ctx, "entry", EntryFn);
376*bdd1243dSDimitry Andric   IRBuilder<> B(BB);
377*bdd1243dSDimitry Andric   llvm::SmallVector<Value *> Args;
378*bdd1243dSDimitry Andric   // FIXME: support struct parameters where semantics are on members.
379*bdd1243dSDimitry Andric   // See: https://github.com/llvm/llvm-project/issues/57874
380*bdd1243dSDimitry Andric   unsigned SRetOffset = 0;
381*bdd1243dSDimitry Andric   for (const auto &Param : Fn->args()) {
382*bdd1243dSDimitry Andric     if (Param.hasStructRetAttr()) {
383*bdd1243dSDimitry Andric       // FIXME: support output.
384*bdd1243dSDimitry Andric       // See: https://github.com/llvm/llvm-project/issues/57874
385*bdd1243dSDimitry Andric       SRetOffset = 1;
386*bdd1243dSDimitry Andric       Args.emplace_back(PoisonValue::get(Param.getType()));
387*bdd1243dSDimitry Andric       continue;
388*bdd1243dSDimitry Andric     }
389*bdd1243dSDimitry Andric     const ParmVarDecl *PD = FD->getParamDecl(Param.getArgNo() - SRetOffset);
390*bdd1243dSDimitry Andric     Args.push_back(emitInputSemantic(B, *PD, Param.getType()));
391*bdd1243dSDimitry Andric   }
392*bdd1243dSDimitry Andric 
393*bdd1243dSDimitry Andric   CallInst *CI = B.CreateCall(FunctionCallee(Fn), Args);
394*bdd1243dSDimitry Andric   (void)CI;
395*bdd1243dSDimitry Andric   // FIXME: Handle codegen for return type semantics.
396*bdd1243dSDimitry Andric   // See: https://github.com/llvm/llvm-project/issues/57875
397*bdd1243dSDimitry Andric   B.CreateRetVoid();
398*bdd1243dSDimitry Andric }
399*bdd1243dSDimitry Andric 
400*bdd1243dSDimitry Andric static void gatherFunctions(SmallVectorImpl<Function *> &Fns, llvm::Module &M,
401*bdd1243dSDimitry Andric                             bool CtorOrDtor) {
402*bdd1243dSDimitry Andric   const auto *GV =
403*bdd1243dSDimitry Andric       M.getNamedGlobal(CtorOrDtor ? "llvm.global_ctors" : "llvm.global_dtors");
404*bdd1243dSDimitry Andric   if (!GV)
405*bdd1243dSDimitry Andric     return;
406*bdd1243dSDimitry Andric   const auto *CA = dyn_cast<ConstantArray>(GV->getInitializer());
407*bdd1243dSDimitry Andric   if (!CA)
408*bdd1243dSDimitry Andric     return;
409*bdd1243dSDimitry Andric   // The global_ctor array elements are a struct [Priority, Fn *, COMDat].
410*bdd1243dSDimitry Andric   // HLSL neither supports priorities or COMDat values, so we will check those
411*bdd1243dSDimitry Andric   // in an assert but not handle them.
412*bdd1243dSDimitry Andric 
413*bdd1243dSDimitry Andric   llvm::SmallVector<Function *> CtorFns;
414*bdd1243dSDimitry Andric   for (const auto &Ctor : CA->operands()) {
415*bdd1243dSDimitry Andric     if (isa<ConstantAggregateZero>(Ctor))
416*bdd1243dSDimitry Andric       continue;
417*bdd1243dSDimitry Andric     ConstantStruct *CS = cast<ConstantStruct>(Ctor);
418*bdd1243dSDimitry Andric 
419*bdd1243dSDimitry Andric     assert(cast<ConstantInt>(CS->getOperand(0))->getValue() == 65535 &&
420*bdd1243dSDimitry Andric            "HLSL doesn't support setting priority for global ctors.");
421*bdd1243dSDimitry Andric     assert(isa<ConstantPointerNull>(CS->getOperand(2)) &&
422*bdd1243dSDimitry Andric            "HLSL doesn't support COMDat for global ctors.");
423*bdd1243dSDimitry Andric     Fns.push_back(cast<Function>(CS->getOperand(1)));
424*bdd1243dSDimitry Andric   }
425*bdd1243dSDimitry Andric }
426*bdd1243dSDimitry Andric 
427*bdd1243dSDimitry Andric void CGHLSLRuntime::generateGlobalCtorDtorCalls() {
428*bdd1243dSDimitry Andric   llvm::Module &M = CGM.getModule();
429*bdd1243dSDimitry Andric   SmallVector<Function *> CtorFns;
430*bdd1243dSDimitry Andric   SmallVector<Function *> DtorFns;
431*bdd1243dSDimitry Andric   gatherFunctions(CtorFns, M, true);
432*bdd1243dSDimitry Andric   gatherFunctions(DtorFns, M, false);
433*bdd1243dSDimitry Andric 
434*bdd1243dSDimitry Andric   // Insert a call to the global constructor at the beginning of the entry block
435*bdd1243dSDimitry Andric   // to externally exported functions. This is a bit of a hack, but HLSL allows
436*bdd1243dSDimitry Andric   // global constructors, but doesn't support driver initialization of globals.
437*bdd1243dSDimitry Andric   for (auto &F : M.functions()) {
438*bdd1243dSDimitry Andric     if (!F.hasFnAttribute("hlsl.shader"))
439*bdd1243dSDimitry Andric       continue;
440*bdd1243dSDimitry Andric     IRBuilder<> B(&F.getEntryBlock(), F.getEntryBlock().begin());
441*bdd1243dSDimitry Andric     for (auto *Fn : CtorFns)
442*bdd1243dSDimitry Andric       B.CreateCall(FunctionCallee(Fn));
443*bdd1243dSDimitry Andric 
444*bdd1243dSDimitry Andric     // Insert global dtors before the terminator of the last instruction
445*bdd1243dSDimitry Andric     B.SetInsertPoint(F.back().getTerminator());
446*bdd1243dSDimitry Andric     for (auto *Fn : DtorFns)
447*bdd1243dSDimitry Andric       B.CreateCall(FunctionCallee(Fn));
448*bdd1243dSDimitry Andric   }
449*bdd1243dSDimitry Andric 
450*bdd1243dSDimitry Andric   // No need to keep global ctors/dtors for non-lib profile after call to
451*bdd1243dSDimitry Andric   // ctors/dtors added for entry.
452*bdd1243dSDimitry Andric   Triple T(M.getTargetTriple());
453*bdd1243dSDimitry Andric   if (T.getEnvironment() != Triple::EnvironmentType::Library) {
454*bdd1243dSDimitry Andric     if (auto *GV = M.getNamedGlobal("llvm.global_ctors"))
455*bdd1243dSDimitry Andric       GV->eraseFromParent();
456*bdd1243dSDimitry Andric     if (auto *GV = M.getNamedGlobal("llvm.global_dtors"))
457*bdd1243dSDimitry Andric       GV->eraseFromParent();
458*bdd1243dSDimitry Andric   }
45981ad6265SDimitry Andric }
460