xref: /freebsd/contrib/llvm-project/clang/lib/CodeGen/CGHLSLRuntime.cpp (revision 5f757f3ff9144b609b3c433dfd370cc6bdc191ad)
1 //===----- CGHLSLRuntime.cpp - Interface to HLSL Runtimes -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This provides an abstract class for HLSL code generation.  Concrete
10 // subclasses of this implement code generation for specific HLSL
11 // runtime libraries.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "CGHLSLRuntime.h"
16 #include "CGDebugInfo.h"
17 #include "CodeGenModule.h"
18 #include "clang/AST/Decl.h"
19 #include "clang/Basic/TargetOptions.h"
20 #include "llvm/IR/IntrinsicsDirectX.h"
21 #include "llvm/IR/Metadata.h"
22 #include "llvm/IR/Module.h"
23 #include "llvm/Support/FormatVariadic.h"
24 
25 using namespace clang;
26 using namespace CodeGen;
27 using namespace clang::hlsl;
28 using namespace llvm;
29 
30 namespace {
31 
32 void addDxilValVersion(StringRef ValVersionStr, llvm::Module &M) {
33   // The validation of ValVersionStr is done at HLSLToolChain::TranslateArgs.
34   // Assume ValVersionStr is legal here.
35   VersionTuple Version;
36   if (Version.tryParse(ValVersionStr) || Version.getBuild() ||
37       Version.getSubminor() || !Version.getMinor()) {
38     return;
39   }
40 
41   uint64_t Major = Version.getMajor();
42   uint64_t Minor = *Version.getMinor();
43 
44   auto &Ctx = M.getContext();
45   IRBuilder<> B(M.getContext());
46   MDNode *Val = MDNode::get(Ctx, {ConstantAsMetadata::get(B.getInt32(Major)),
47                                   ConstantAsMetadata::get(B.getInt32(Minor))});
48   StringRef DXILValKey = "dx.valver";
49   auto *DXILValMD = M.getOrInsertNamedMetadata(DXILValKey);
50   DXILValMD->addOperand(Val);
51 }
52 void addDisableOptimizations(llvm::Module &M) {
53   StringRef Key = "dx.disable_optimizations";
54   M.addModuleFlag(llvm::Module::ModFlagBehavior::Override, Key, 1);
55 }
56 // cbuffer will be translated into global variable in special address space.
57 // If translate into C,
58 // cbuffer A {
59 //   float a;
60 //   float b;
61 // }
62 // float foo() { return a + b; }
63 //
64 // will be translated into
65 //
66 // struct A {
67 //   float a;
68 //   float b;
69 // } cbuffer_A __attribute__((address_space(4)));
70 // float foo() { return cbuffer_A.a + cbuffer_A.b; }
71 //
72 // layoutBuffer will create the struct A type.
73 // replaceBuffer will replace use of global variable a and b with cbuffer_A.a
74 // and cbuffer_A.b.
75 //
76 void layoutBuffer(CGHLSLRuntime::Buffer &Buf, const DataLayout &DL) {
77   if (Buf.Constants.empty())
78     return;
79 
80   std::vector<llvm::Type *> EltTys;
81   for (auto &Const : Buf.Constants) {
82     GlobalVariable *GV = Const.first;
83     Const.second = EltTys.size();
84     llvm::Type *Ty = GV->getValueType();
85     EltTys.emplace_back(Ty);
86   }
87   Buf.LayoutStruct = llvm::StructType::get(EltTys[0]->getContext(), EltTys);
88 }
89 
90 GlobalVariable *replaceBuffer(CGHLSLRuntime::Buffer &Buf) {
91   // Create global variable for CB.
92   GlobalVariable *CBGV = new GlobalVariable(
93       Buf.LayoutStruct, /*isConstant*/ true,
94       GlobalValue::LinkageTypes::ExternalLinkage, nullptr,
95       llvm::formatv("{0}{1}", Buf.Name, Buf.IsCBuffer ? ".cb." : ".tb."),
96       GlobalValue::NotThreadLocal);
97 
98   IRBuilder<> B(CBGV->getContext());
99   Value *ZeroIdx = B.getInt32(0);
100   // Replace Const use with CB use.
101   for (auto &[GV, Offset] : Buf.Constants) {
102     Value *GEP =
103         B.CreateGEP(Buf.LayoutStruct, CBGV, {ZeroIdx, B.getInt32(Offset)});
104 
105     assert(Buf.LayoutStruct->getElementType(Offset) == GV->getValueType() &&
106            "constant type mismatch");
107 
108     // Replace.
109     GV->replaceAllUsesWith(GEP);
110     // Erase GV.
111     GV->removeDeadConstantUsers();
112     GV->eraseFromParent();
113   }
114   return CBGV;
115 }
116 
117 } // namespace
118 
119 void CGHLSLRuntime::addConstant(VarDecl *D, Buffer &CB) {
120   if (D->getStorageClass() == SC_Static) {
121     // For static inside cbuffer, take as global static.
122     // Don't add to cbuffer.
123     CGM.EmitGlobal(D);
124     return;
125   }
126 
127   auto *GV = cast<GlobalVariable>(CGM.GetAddrOfGlobalVar(D));
128   // Add debug info for constVal.
129   if (CGDebugInfo *DI = CGM.getModuleDebugInfo())
130     if (CGM.getCodeGenOpts().getDebugInfo() >=
131         codegenoptions::DebugInfoKind::LimitedDebugInfo)
132       DI->EmitGlobalVariable(cast<GlobalVariable>(GV), D);
133 
134   // FIXME: support packoffset.
135   // See https://github.com/llvm/llvm-project/issues/57914.
136   uint32_t Offset = 0;
137   bool HasUserOffset = false;
138 
139   unsigned LowerBound = HasUserOffset ? Offset : UINT_MAX;
140   CB.Constants.emplace_back(std::make_pair(GV, LowerBound));
141 }
142 
143 void CGHLSLRuntime::addBufferDecls(const DeclContext *DC, Buffer &CB) {
144   for (Decl *it : DC->decls()) {
145     if (auto *ConstDecl = dyn_cast<VarDecl>(it)) {
146       addConstant(ConstDecl, CB);
147     } else if (isa<CXXRecordDecl, EmptyDecl>(it)) {
148       // Nothing to do for this declaration.
149     } else if (isa<FunctionDecl>(it)) {
150       // A function within an cbuffer is effectively a top-level function,
151       // as it only refers to globally scoped declarations.
152       CGM.EmitTopLevelDecl(it);
153     }
154   }
155 }
156 
157 void CGHLSLRuntime::addBuffer(const HLSLBufferDecl *D) {
158   Buffers.emplace_back(Buffer(D));
159   addBufferDecls(D, Buffers.back());
160 }
161 
162 void CGHLSLRuntime::finishCodeGen() {
163   auto &TargetOpts = CGM.getTarget().getTargetOpts();
164   llvm::Module &M = CGM.getModule();
165   Triple T(M.getTargetTriple());
166   if (T.getArch() == Triple::ArchType::dxil)
167     addDxilValVersion(TargetOpts.DxilValidatorVersion, M);
168 
169   generateGlobalCtorDtorCalls();
170   if (CGM.getCodeGenOpts().OptimizationLevel == 0)
171     addDisableOptimizations(M);
172 
173   const DataLayout &DL = M.getDataLayout();
174 
175   for (auto &Buf : Buffers) {
176     layoutBuffer(Buf, DL);
177     GlobalVariable *GV = replaceBuffer(Buf);
178     M.insertGlobalVariable(GV);
179     llvm::hlsl::ResourceClass RC = Buf.IsCBuffer
180                                        ? llvm::hlsl::ResourceClass::CBuffer
181                                        : llvm::hlsl::ResourceClass::SRV;
182     llvm::hlsl::ResourceKind RK = Buf.IsCBuffer
183                                       ? llvm::hlsl::ResourceKind::CBuffer
184                                       : llvm::hlsl::ResourceKind::TBuffer;
185     std::string TyName =
186         Buf.Name.str() + (Buf.IsCBuffer ? ".cb." : ".tb.") + "ty";
187     addBufferResourceAnnotation(GV, TyName, RC, RK, /*IsROV=*/false,
188                                 Buf.Binding);
189   }
190 }
191 
192 CGHLSLRuntime::Buffer::Buffer(const HLSLBufferDecl *D)
193     : Name(D->getName()), IsCBuffer(D->isCBuffer()),
194       Binding(D->getAttr<HLSLResourceBindingAttr>()) {}
195 
196 void CGHLSLRuntime::addBufferResourceAnnotation(llvm::GlobalVariable *GV,
197                                                 llvm::StringRef TyName,
198                                                 llvm::hlsl::ResourceClass RC,
199                                                 llvm::hlsl::ResourceKind RK,
200                                                 bool IsROV,
201                                                 BufferResBinding &Binding) {
202   llvm::Module &M = CGM.getModule();
203 
204   NamedMDNode *ResourceMD = nullptr;
205   switch (RC) {
206   case llvm::hlsl::ResourceClass::UAV:
207     ResourceMD = M.getOrInsertNamedMetadata("hlsl.uavs");
208     break;
209   case llvm::hlsl::ResourceClass::SRV:
210     ResourceMD = M.getOrInsertNamedMetadata("hlsl.srvs");
211     break;
212   case llvm::hlsl::ResourceClass::CBuffer:
213     ResourceMD = M.getOrInsertNamedMetadata("hlsl.cbufs");
214     break;
215   default:
216     assert(false && "Unsupported buffer type!");
217     return;
218   }
219 
220   assert(ResourceMD != nullptr &&
221          "ResourceMD must have been set by the switch above.");
222 
223   llvm::hlsl::FrontendResource Res(
224       GV, TyName, RK, IsROV, Binding.Reg.value_or(UINT_MAX), Binding.Space);
225   ResourceMD->addOperand(Res.getMetadata());
226 }
227 
228 void CGHLSLRuntime::annotateHLSLResource(const VarDecl *D, GlobalVariable *GV) {
229   const Type *Ty = D->getType()->getPointeeOrArrayElementType();
230   if (!Ty)
231     return;
232   const auto *RD = Ty->getAsCXXRecordDecl();
233   if (!RD)
234     return;
235   const auto *Attr = RD->getAttr<HLSLResourceAttr>();
236   if (!Attr)
237     return;
238 
239   llvm::hlsl::ResourceClass RC = Attr->getResourceClass();
240   llvm::hlsl::ResourceKind RK = Attr->getResourceKind();
241   bool IsROV = Attr->getIsROV();
242 
243   QualType QT(Ty, 0);
244   BufferResBinding Binding(D->getAttr<HLSLResourceBindingAttr>());
245   addBufferResourceAnnotation(GV, QT.getAsString(), RC, RK, IsROV, Binding);
246 }
247 
248 CGHLSLRuntime::BufferResBinding::BufferResBinding(
249     HLSLResourceBindingAttr *Binding) {
250   if (Binding) {
251     llvm::APInt RegInt(64, 0);
252     Binding->getSlot().substr(1).getAsInteger(10, RegInt);
253     Reg = RegInt.getLimitedValue();
254     llvm::APInt SpaceInt(64, 0);
255     Binding->getSpace().substr(5).getAsInteger(10, SpaceInt);
256     Space = SpaceInt.getLimitedValue();
257   } else {
258     Space = 0;
259   }
260 }
261 
262 void clang::CodeGen::CGHLSLRuntime::setHLSLEntryAttributes(
263     const FunctionDecl *FD, llvm::Function *Fn) {
264   const auto *ShaderAttr = FD->getAttr<HLSLShaderAttr>();
265   assert(ShaderAttr && "All entry functions must have a HLSLShaderAttr");
266   const StringRef ShaderAttrKindStr = "hlsl.shader";
267   Fn->addFnAttr(ShaderAttrKindStr,
268                 ShaderAttr->ConvertShaderTypeToStr(ShaderAttr->getType()));
269   if (HLSLNumThreadsAttr *NumThreadsAttr = FD->getAttr<HLSLNumThreadsAttr>()) {
270     const StringRef NumThreadsKindStr = "hlsl.numthreads";
271     std::string NumThreadsStr =
272         formatv("{0},{1},{2}", NumThreadsAttr->getX(), NumThreadsAttr->getY(),
273                 NumThreadsAttr->getZ());
274     Fn->addFnAttr(NumThreadsKindStr, NumThreadsStr);
275   }
276 }
277 
278 static Value *buildVectorInput(IRBuilder<> &B, Function *F, llvm::Type *Ty) {
279   if (const auto *VT = dyn_cast<FixedVectorType>(Ty)) {
280     Value *Result = PoisonValue::get(Ty);
281     for (unsigned I = 0; I < VT->getNumElements(); ++I) {
282       Value *Elt = B.CreateCall(F, {B.getInt32(I)});
283       Result = B.CreateInsertElement(Result, Elt, I);
284     }
285     return Result;
286   }
287   return B.CreateCall(F, {B.getInt32(0)});
288 }
289 
290 llvm::Value *CGHLSLRuntime::emitInputSemantic(IRBuilder<> &B,
291                                               const ParmVarDecl &D,
292                                               llvm::Type *Ty) {
293   assert(D.hasAttrs() && "Entry parameter missing annotation attribute!");
294   if (D.hasAttr<HLSLSV_GroupIndexAttr>()) {
295     llvm::Function *DxGroupIndex =
296         CGM.getIntrinsic(Intrinsic::dx_flattened_thread_id_in_group);
297     return B.CreateCall(FunctionCallee(DxGroupIndex));
298   }
299   if (D.hasAttr<HLSLSV_DispatchThreadIDAttr>()) {
300     llvm::Function *DxThreadID = CGM.getIntrinsic(Intrinsic::dx_thread_id);
301     return buildVectorInput(B, DxThreadID, Ty);
302   }
303   assert(false && "Unhandled parameter attribute");
304   return nullptr;
305 }
306 
307 void CGHLSLRuntime::emitEntryFunction(const FunctionDecl *FD,
308                                       llvm::Function *Fn) {
309   llvm::Module &M = CGM.getModule();
310   llvm::LLVMContext &Ctx = M.getContext();
311   auto *EntryTy = llvm::FunctionType::get(llvm::Type::getVoidTy(Ctx), false);
312   Function *EntryFn =
313       Function::Create(EntryTy, Function::ExternalLinkage, FD->getName(), &M);
314 
315   // Copy function attributes over, we have no argument or return attributes
316   // that can be valid on the real entry.
317   AttributeList NewAttrs = AttributeList::get(Ctx, AttributeList::FunctionIndex,
318                                               Fn->getAttributes().getFnAttrs());
319   EntryFn->setAttributes(NewAttrs);
320   setHLSLEntryAttributes(FD, EntryFn);
321 
322   // Set the called function as internal linkage.
323   Fn->setLinkage(GlobalValue::InternalLinkage);
324 
325   BasicBlock *BB = BasicBlock::Create(Ctx, "entry", EntryFn);
326   IRBuilder<> B(BB);
327   llvm::SmallVector<Value *> Args;
328   // FIXME: support struct parameters where semantics are on members.
329   // See: https://github.com/llvm/llvm-project/issues/57874
330   unsigned SRetOffset = 0;
331   for (const auto &Param : Fn->args()) {
332     if (Param.hasStructRetAttr()) {
333       // FIXME: support output.
334       // See: https://github.com/llvm/llvm-project/issues/57874
335       SRetOffset = 1;
336       Args.emplace_back(PoisonValue::get(Param.getType()));
337       continue;
338     }
339     const ParmVarDecl *PD = FD->getParamDecl(Param.getArgNo() - SRetOffset);
340     Args.push_back(emitInputSemantic(B, *PD, Param.getType()));
341   }
342 
343   CallInst *CI = B.CreateCall(FunctionCallee(Fn), Args);
344   (void)CI;
345   // FIXME: Handle codegen for return type semantics.
346   // See: https://github.com/llvm/llvm-project/issues/57875
347   B.CreateRetVoid();
348 }
349 
350 static void gatherFunctions(SmallVectorImpl<Function *> &Fns, llvm::Module &M,
351                             bool CtorOrDtor) {
352   const auto *GV =
353       M.getNamedGlobal(CtorOrDtor ? "llvm.global_ctors" : "llvm.global_dtors");
354   if (!GV)
355     return;
356   const auto *CA = dyn_cast<ConstantArray>(GV->getInitializer());
357   if (!CA)
358     return;
359   // The global_ctor array elements are a struct [Priority, Fn *, COMDat].
360   // HLSL neither supports priorities or COMDat values, so we will check those
361   // in an assert but not handle them.
362 
363   llvm::SmallVector<Function *> CtorFns;
364   for (const auto &Ctor : CA->operands()) {
365     if (isa<ConstantAggregateZero>(Ctor))
366       continue;
367     ConstantStruct *CS = cast<ConstantStruct>(Ctor);
368 
369     assert(cast<ConstantInt>(CS->getOperand(0))->getValue() == 65535 &&
370            "HLSL doesn't support setting priority for global ctors.");
371     assert(isa<ConstantPointerNull>(CS->getOperand(2)) &&
372            "HLSL doesn't support COMDat for global ctors.");
373     Fns.push_back(cast<Function>(CS->getOperand(1)));
374   }
375 }
376 
377 void CGHLSLRuntime::generateGlobalCtorDtorCalls() {
378   llvm::Module &M = CGM.getModule();
379   SmallVector<Function *> CtorFns;
380   SmallVector<Function *> DtorFns;
381   gatherFunctions(CtorFns, M, true);
382   gatherFunctions(DtorFns, M, false);
383 
384   // Insert a call to the global constructor at the beginning of the entry block
385   // to externally exported functions. This is a bit of a hack, but HLSL allows
386   // global constructors, but doesn't support driver initialization of globals.
387   for (auto &F : M.functions()) {
388     if (!F.hasFnAttribute("hlsl.shader"))
389       continue;
390     IRBuilder<> B(&F.getEntryBlock(), F.getEntryBlock().begin());
391     for (auto *Fn : CtorFns)
392       B.CreateCall(FunctionCallee(Fn));
393 
394     // Insert global dtors before the terminator of the last instruction
395     B.SetInsertPoint(F.back().getTerminator());
396     for (auto *Fn : DtorFns)
397       B.CreateCall(FunctionCallee(Fn));
398   }
399 
400   // No need to keep global ctors/dtors for non-lib profile after call to
401   // ctors/dtors added for entry.
402   Triple T(M.getTargetTriple());
403   if (T.getEnvironment() != Triple::EnvironmentType::Library) {
404     if (auto *GV = M.getNamedGlobal("llvm.global_ctors"))
405       GV->eraseFromParent();
406     if (auto *GV = M.getNamedGlobal("llvm.global_dtors"))
407       GV->eraseFromParent();
408   }
409 }
410