xref: /freebsd/contrib/llvm-project/clang/lib/CodeGen/CGHLSLRuntime.cpp (revision 3ceba58a7509418b47b8fca2d2b6bbf088714e26)
1 //===----- CGHLSLRuntime.cpp - Interface to HLSL Runtimes -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This provides an abstract class for HLSL code generation.  Concrete
10 // subclasses of this implement code generation for specific HLSL
11 // runtime libraries.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "CGHLSLRuntime.h"
16 #include "CGDebugInfo.h"
17 #include "CodeGenModule.h"
18 #include "clang/AST/Decl.h"
19 #include "clang/Basic/TargetOptions.h"
20 #include "llvm/IR/Metadata.h"
21 #include "llvm/IR/Module.h"
22 #include "llvm/Support/FormatVariadic.h"
23 
24 using namespace clang;
25 using namespace CodeGen;
26 using namespace clang::hlsl;
27 using namespace llvm;
28 
29 namespace {
30 
31 void addDxilValVersion(StringRef ValVersionStr, llvm::Module &M) {
32   // The validation of ValVersionStr is done at HLSLToolChain::TranslateArgs.
33   // Assume ValVersionStr is legal here.
34   VersionTuple Version;
35   if (Version.tryParse(ValVersionStr) || Version.getBuild() ||
36       Version.getSubminor() || !Version.getMinor()) {
37     return;
38   }
39 
40   uint64_t Major = Version.getMajor();
41   uint64_t Minor = *Version.getMinor();
42 
43   auto &Ctx = M.getContext();
44   IRBuilder<> B(M.getContext());
45   MDNode *Val = MDNode::get(Ctx, {ConstantAsMetadata::get(B.getInt32(Major)),
46                                   ConstantAsMetadata::get(B.getInt32(Minor))});
47   StringRef DXILValKey = "dx.valver";
48   auto *DXILValMD = M.getOrInsertNamedMetadata(DXILValKey);
49   DXILValMD->addOperand(Val);
50 }
51 void addDisableOptimizations(llvm::Module &M) {
52   StringRef Key = "dx.disable_optimizations";
53   M.addModuleFlag(llvm::Module::ModFlagBehavior::Override, Key, 1);
54 }
55 // cbuffer will be translated into global variable in special address space.
56 // If translate into C,
57 // cbuffer A {
58 //   float a;
59 //   float b;
60 // }
61 // float foo() { return a + b; }
62 //
63 // will be translated into
64 //
65 // struct A {
66 //   float a;
67 //   float b;
68 // } cbuffer_A __attribute__((address_space(4)));
69 // float foo() { return cbuffer_A.a + cbuffer_A.b; }
70 //
71 // layoutBuffer will create the struct A type.
72 // replaceBuffer will replace use of global variable a and b with cbuffer_A.a
73 // and cbuffer_A.b.
74 //
75 void layoutBuffer(CGHLSLRuntime::Buffer &Buf, const DataLayout &DL) {
76   if (Buf.Constants.empty())
77     return;
78 
79   std::vector<llvm::Type *> EltTys;
80   for (auto &Const : Buf.Constants) {
81     GlobalVariable *GV = Const.first;
82     Const.second = EltTys.size();
83     llvm::Type *Ty = GV->getValueType();
84     EltTys.emplace_back(Ty);
85   }
86   Buf.LayoutStruct = llvm::StructType::get(EltTys[0]->getContext(), EltTys);
87 }
88 
89 GlobalVariable *replaceBuffer(CGHLSLRuntime::Buffer &Buf) {
90   // Create global variable for CB.
91   GlobalVariable *CBGV = new GlobalVariable(
92       Buf.LayoutStruct, /*isConstant*/ true,
93       GlobalValue::LinkageTypes::ExternalLinkage, nullptr,
94       llvm::formatv("{0}{1}", Buf.Name, Buf.IsCBuffer ? ".cb." : ".tb."),
95       GlobalValue::NotThreadLocal);
96 
97   IRBuilder<> B(CBGV->getContext());
98   Value *ZeroIdx = B.getInt32(0);
99   // Replace Const use with CB use.
100   for (auto &[GV, Offset] : Buf.Constants) {
101     Value *GEP =
102         B.CreateGEP(Buf.LayoutStruct, CBGV, {ZeroIdx, B.getInt32(Offset)});
103 
104     assert(Buf.LayoutStruct->getElementType(Offset) == GV->getValueType() &&
105            "constant type mismatch");
106 
107     // Replace.
108     GV->replaceAllUsesWith(GEP);
109     // Erase GV.
110     GV->removeDeadConstantUsers();
111     GV->eraseFromParent();
112   }
113   return CBGV;
114 }
115 
116 } // namespace
117 
118 llvm::Triple::ArchType CGHLSLRuntime::getArch() {
119   return CGM.getTarget().getTriple().getArch();
120 }
121 
122 void CGHLSLRuntime::addConstant(VarDecl *D, Buffer &CB) {
123   if (D->getStorageClass() == SC_Static) {
124     // For static inside cbuffer, take as global static.
125     // Don't add to cbuffer.
126     CGM.EmitGlobal(D);
127     return;
128   }
129 
130   auto *GV = cast<GlobalVariable>(CGM.GetAddrOfGlobalVar(D));
131   // Add debug info for constVal.
132   if (CGDebugInfo *DI = CGM.getModuleDebugInfo())
133     if (CGM.getCodeGenOpts().getDebugInfo() >=
134         codegenoptions::DebugInfoKind::LimitedDebugInfo)
135       DI->EmitGlobalVariable(cast<GlobalVariable>(GV), D);
136 
137   // FIXME: support packoffset.
138   // See https://github.com/llvm/llvm-project/issues/57914.
139   uint32_t Offset = 0;
140   bool HasUserOffset = false;
141 
142   unsigned LowerBound = HasUserOffset ? Offset : UINT_MAX;
143   CB.Constants.emplace_back(std::make_pair(GV, LowerBound));
144 }
145 
146 void CGHLSLRuntime::addBufferDecls(const DeclContext *DC, Buffer &CB) {
147   for (Decl *it : DC->decls()) {
148     if (auto *ConstDecl = dyn_cast<VarDecl>(it)) {
149       addConstant(ConstDecl, CB);
150     } else if (isa<CXXRecordDecl, EmptyDecl>(it)) {
151       // Nothing to do for this declaration.
152     } else if (isa<FunctionDecl>(it)) {
153       // A function within an cbuffer is effectively a top-level function,
154       // as it only refers to globally scoped declarations.
155       CGM.EmitTopLevelDecl(it);
156     }
157   }
158 }
159 
160 void CGHLSLRuntime::addBuffer(const HLSLBufferDecl *D) {
161   Buffers.emplace_back(Buffer(D));
162   addBufferDecls(D, Buffers.back());
163 }
164 
165 void CGHLSLRuntime::finishCodeGen() {
166   auto &TargetOpts = CGM.getTarget().getTargetOpts();
167   llvm::Module &M = CGM.getModule();
168   Triple T(M.getTargetTriple());
169   if (T.getArch() == Triple::ArchType::dxil)
170     addDxilValVersion(TargetOpts.DxilValidatorVersion, M);
171 
172   generateGlobalCtorDtorCalls();
173   if (CGM.getCodeGenOpts().OptimizationLevel == 0)
174     addDisableOptimizations(M);
175 
176   const DataLayout &DL = M.getDataLayout();
177 
178   for (auto &Buf : Buffers) {
179     layoutBuffer(Buf, DL);
180     GlobalVariable *GV = replaceBuffer(Buf);
181     M.insertGlobalVariable(GV);
182     llvm::hlsl::ResourceClass RC = Buf.IsCBuffer
183                                        ? llvm::hlsl::ResourceClass::CBuffer
184                                        : llvm::hlsl::ResourceClass::SRV;
185     llvm::hlsl::ResourceKind RK = Buf.IsCBuffer
186                                       ? llvm::hlsl::ResourceKind::CBuffer
187                                       : llvm::hlsl::ResourceKind::TBuffer;
188     addBufferResourceAnnotation(GV, RC, RK, /*IsROV=*/false,
189                                 llvm::hlsl::ElementType::Invalid, Buf.Binding);
190   }
191 }
192 
193 CGHLSLRuntime::Buffer::Buffer(const HLSLBufferDecl *D)
194     : Name(D->getName()), IsCBuffer(D->isCBuffer()),
195       Binding(D->getAttr<HLSLResourceBindingAttr>()) {}
196 
197 void CGHLSLRuntime::addBufferResourceAnnotation(llvm::GlobalVariable *GV,
198                                                 llvm::hlsl::ResourceClass RC,
199                                                 llvm::hlsl::ResourceKind RK,
200                                                 bool IsROV,
201                                                 llvm::hlsl::ElementType ET,
202                                                 BufferResBinding &Binding) {
203   llvm::Module &M = CGM.getModule();
204 
205   NamedMDNode *ResourceMD = nullptr;
206   switch (RC) {
207   case llvm::hlsl::ResourceClass::UAV:
208     ResourceMD = M.getOrInsertNamedMetadata("hlsl.uavs");
209     break;
210   case llvm::hlsl::ResourceClass::SRV:
211     ResourceMD = M.getOrInsertNamedMetadata("hlsl.srvs");
212     break;
213   case llvm::hlsl::ResourceClass::CBuffer:
214     ResourceMD = M.getOrInsertNamedMetadata("hlsl.cbufs");
215     break;
216   default:
217     assert(false && "Unsupported buffer type!");
218     return;
219   }
220   assert(ResourceMD != nullptr &&
221          "ResourceMD must have been set by the switch above.");
222 
223   llvm::hlsl::FrontendResource Res(
224       GV, RK, ET, IsROV, Binding.Reg.value_or(UINT_MAX), Binding.Space);
225   ResourceMD->addOperand(Res.getMetadata());
226 }
227 
228 static llvm::hlsl::ElementType
229 calculateElementType(const ASTContext &Context, const clang::Type *ResourceTy) {
230   using llvm::hlsl::ElementType;
231 
232   // TODO: We may need to update this when we add things like ByteAddressBuffer
233   // that don't have a template parameter (or, indeed, an element type).
234   const auto *TST = ResourceTy->getAs<TemplateSpecializationType>();
235   assert(TST && "Resource types must be template specializations");
236   ArrayRef<TemplateArgument> Args = TST->template_arguments();
237   assert(!Args.empty() && "Resource has no element type");
238 
239   // At this point we have a resource with an element type, so we can assume
240   // that it's valid or we would have diagnosed the error earlier.
241   QualType ElTy = Args[0].getAsType();
242 
243   // We should either have a basic type or a vector of a basic type.
244   if (const auto *VecTy = ElTy->getAs<clang::VectorType>())
245     ElTy = VecTy->getElementType();
246 
247   if (ElTy->isSignedIntegerType()) {
248     switch (Context.getTypeSize(ElTy)) {
249     case 16:
250       return ElementType::I16;
251     case 32:
252       return ElementType::I32;
253     case 64:
254       return ElementType::I64;
255     }
256   } else if (ElTy->isUnsignedIntegerType()) {
257     switch (Context.getTypeSize(ElTy)) {
258     case 16:
259       return ElementType::U16;
260     case 32:
261       return ElementType::U32;
262     case 64:
263       return ElementType::U64;
264     }
265   } else if (ElTy->isSpecificBuiltinType(BuiltinType::Half))
266     return ElementType::F16;
267   else if (ElTy->isSpecificBuiltinType(BuiltinType::Float))
268     return ElementType::F32;
269   else if (ElTy->isSpecificBuiltinType(BuiltinType::Double))
270     return ElementType::F64;
271 
272   // TODO: We need to handle unorm/snorm float types here once we support them
273   llvm_unreachable("Invalid element type for resource");
274 }
275 
276 void CGHLSLRuntime::annotateHLSLResource(const VarDecl *D, GlobalVariable *GV) {
277   const Type *Ty = D->getType()->getPointeeOrArrayElementType();
278   if (!Ty)
279     return;
280   const auto *RD = Ty->getAsCXXRecordDecl();
281   if (!RD)
282     return;
283   const auto *HLSLResAttr = RD->getAttr<HLSLResourceAttr>();
284   const auto *HLSLResClassAttr = RD->getAttr<HLSLResourceClassAttr>();
285   if (!HLSLResAttr || !HLSLResClassAttr)
286     return;
287 
288   llvm::hlsl::ResourceClass RC = HLSLResClassAttr->getResourceClass();
289   llvm::hlsl::ResourceKind RK = HLSLResAttr->getResourceKind();
290   bool IsROV = HLSLResAttr->getIsROV();
291   llvm::hlsl::ElementType ET = calculateElementType(CGM.getContext(), Ty);
292 
293   BufferResBinding Binding(D->getAttr<HLSLResourceBindingAttr>());
294   addBufferResourceAnnotation(GV, RC, RK, IsROV, ET, Binding);
295 }
296 
297 CGHLSLRuntime::BufferResBinding::BufferResBinding(
298     HLSLResourceBindingAttr *Binding) {
299   if (Binding) {
300     llvm::APInt RegInt(64, 0);
301     Binding->getSlot().substr(1).getAsInteger(10, RegInt);
302     Reg = RegInt.getLimitedValue();
303     llvm::APInt SpaceInt(64, 0);
304     Binding->getSpace().substr(5).getAsInteger(10, SpaceInt);
305     Space = SpaceInt.getLimitedValue();
306   } else {
307     Space = 0;
308   }
309 }
310 
311 void clang::CodeGen::CGHLSLRuntime::setHLSLEntryAttributes(
312     const FunctionDecl *FD, llvm::Function *Fn) {
313   const auto *ShaderAttr = FD->getAttr<HLSLShaderAttr>();
314   assert(ShaderAttr && "All entry functions must have a HLSLShaderAttr");
315   const StringRef ShaderAttrKindStr = "hlsl.shader";
316   Fn->addFnAttr(ShaderAttrKindStr,
317                 llvm::Triple::getEnvironmentTypeName(ShaderAttr->getType()));
318   if (HLSLNumThreadsAttr *NumThreadsAttr = FD->getAttr<HLSLNumThreadsAttr>()) {
319     const StringRef NumThreadsKindStr = "hlsl.numthreads";
320     std::string NumThreadsStr =
321         formatv("{0},{1},{2}", NumThreadsAttr->getX(), NumThreadsAttr->getY(),
322                 NumThreadsAttr->getZ());
323     Fn->addFnAttr(NumThreadsKindStr, NumThreadsStr);
324   }
325 }
326 
327 static Value *buildVectorInput(IRBuilder<> &B, Function *F, llvm::Type *Ty) {
328   if (const auto *VT = dyn_cast<FixedVectorType>(Ty)) {
329     Value *Result = PoisonValue::get(Ty);
330     for (unsigned I = 0; I < VT->getNumElements(); ++I) {
331       Value *Elt = B.CreateCall(F, {B.getInt32(I)});
332       Result = B.CreateInsertElement(Result, Elt, I);
333     }
334     return Result;
335   }
336   return B.CreateCall(F, {B.getInt32(0)});
337 }
338 
339 llvm::Value *CGHLSLRuntime::emitInputSemantic(IRBuilder<> &B,
340                                               const ParmVarDecl &D,
341                                               llvm::Type *Ty) {
342   assert(D.hasAttrs() && "Entry parameter missing annotation attribute!");
343   if (D.hasAttr<HLSLSV_GroupIndexAttr>()) {
344     llvm::Function *DxGroupIndex =
345         CGM.getIntrinsic(Intrinsic::dx_flattened_thread_id_in_group);
346     return B.CreateCall(FunctionCallee(DxGroupIndex));
347   }
348   if (D.hasAttr<HLSLSV_DispatchThreadIDAttr>()) {
349     llvm::Function *ThreadIDIntrinsic =
350         CGM.getIntrinsic(getThreadIdIntrinsic());
351     return buildVectorInput(B, ThreadIDIntrinsic, Ty);
352   }
353   assert(false && "Unhandled parameter attribute");
354   return nullptr;
355 }
356 
357 void CGHLSLRuntime::emitEntryFunction(const FunctionDecl *FD,
358                                       llvm::Function *Fn) {
359   llvm::Module &M = CGM.getModule();
360   llvm::LLVMContext &Ctx = M.getContext();
361   auto *EntryTy = llvm::FunctionType::get(llvm::Type::getVoidTy(Ctx), false);
362   Function *EntryFn =
363       Function::Create(EntryTy, Function::ExternalLinkage, FD->getName(), &M);
364 
365   // Copy function attributes over, we have no argument or return attributes
366   // that can be valid on the real entry.
367   AttributeList NewAttrs = AttributeList::get(Ctx, AttributeList::FunctionIndex,
368                                               Fn->getAttributes().getFnAttrs());
369   EntryFn->setAttributes(NewAttrs);
370   setHLSLEntryAttributes(FD, EntryFn);
371 
372   // Set the called function as internal linkage.
373   Fn->setLinkage(GlobalValue::InternalLinkage);
374 
375   BasicBlock *BB = BasicBlock::Create(Ctx, "entry", EntryFn);
376   IRBuilder<> B(BB);
377   llvm::SmallVector<Value *> Args;
378   // FIXME: support struct parameters where semantics are on members.
379   // See: https://github.com/llvm/llvm-project/issues/57874
380   unsigned SRetOffset = 0;
381   for (const auto &Param : Fn->args()) {
382     if (Param.hasStructRetAttr()) {
383       // FIXME: support output.
384       // See: https://github.com/llvm/llvm-project/issues/57874
385       SRetOffset = 1;
386       Args.emplace_back(PoisonValue::get(Param.getType()));
387       continue;
388     }
389     const ParmVarDecl *PD = FD->getParamDecl(Param.getArgNo() - SRetOffset);
390     Args.push_back(emitInputSemantic(B, *PD, Param.getType()));
391   }
392 
393   CallInst *CI = B.CreateCall(FunctionCallee(Fn), Args);
394   (void)CI;
395   // FIXME: Handle codegen for return type semantics.
396   // See: https://github.com/llvm/llvm-project/issues/57875
397   B.CreateRetVoid();
398 }
399 
400 static void gatherFunctions(SmallVectorImpl<Function *> &Fns, llvm::Module &M,
401                             bool CtorOrDtor) {
402   const auto *GV =
403       M.getNamedGlobal(CtorOrDtor ? "llvm.global_ctors" : "llvm.global_dtors");
404   if (!GV)
405     return;
406   const auto *CA = dyn_cast<ConstantArray>(GV->getInitializer());
407   if (!CA)
408     return;
409   // The global_ctor array elements are a struct [Priority, Fn *, COMDat].
410   // HLSL neither supports priorities or COMDat values, so we will check those
411   // in an assert but not handle them.
412 
413   llvm::SmallVector<Function *> CtorFns;
414   for (const auto &Ctor : CA->operands()) {
415     if (isa<ConstantAggregateZero>(Ctor))
416       continue;
417     ConstantStruct *CS = cast<ConstantStruct>(Ctor);
418 
419     assert(cast<ConstantInt>(CS->getOperand(0))->getValue() == 65535 &&
420            "HLSL doesn't support setting priority for global ctors.");
421     assert(isa<ConstantPointerNull>(CS->getOperand(2)) &&
422            "HLSL doesn't support COMDat for global ctors.");
423     Fns.push_back(cast<Function>(CS->getOperand(1)));
424   }
425 }
426 
427 void CGHLSLRuntime::generateGlobalCtorDtorCalls() {
428   llvm::Module &M = CGM.getModule();
429   SmallVector<Function *> CtorFns;
430   SmallVector<Function *> DtorFns;
431   gatherFunctions(CtorFns, M, true);
432   gatherFunctions(DtorFns, M, false);
433 
434   // Insert a call to the global constructor at the beginning of the entry block
435   // to externally exported functions. This is a bit of a hack, but HLSL allows
436   // global constructors, but doesn't support driver initialization of globals.
437   for (auto &F : M.functions()) {
438     if (!F.hasFnAttribute("hlsl.shader"))
439       continue;
440     IRBuilder<> B(&F.getEntryBlock(), F.getEntryBlock().begin());
441     for (auto *Fn : CtorFns)
442       B.CreateCall(FunctionCallee(Fn));
443 
444     // Insert global dtors before the terminator of the last instruction
445     B.SetInsertPoint(F.back().getTerminator());
446     for (auto *Fn : DtorFns)
447       B.CreateCall(FunctionCallee(Fn));
448   }
449 
450   // No need to keep global ctors/dtors for non-lib profile after call to
451   // ctors/dtors added for entry.
452   Triple T(M.getTargetTriple());
453   if (T.getEnvironment() != Triple::EnvironmentType::Library) {
454     if (auto *GV = M.getNamedGlobal("llvm.global_ctors"))
455       GV->eraseFromParent();
456     if (auto *GV = M.getNamedGlobal("llvm.global_dtors"))
457       GV->eraseFromParent();
458   }
459 }
460