1 //===----- CGHLSLRuntime.cpp - Interface to HLSL Runtimes -----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This provides an abstract class for HLSL code generation. Concrete 10 // subclasses of this implement code generation for specific HLSL 11 // runtime libraries. 12 // 13 //===----------------------------------------------------------------------===// 14 15 #include "CGHLSLRuntime.h" 16 #include "CGDebugInfo.h" 17 #include "CodeGenModule.h" 18 #include "clang/AST/Decl.h" 19 #include "clang/Basic/TargetOptions.h" 20 #include "llvm/IR/IntrinsicsDirectX.h" 21 #include "llvm/IR/Metadata.h" 22 #include "llvm/IR/Module.h" 23 #include "llvm/Support/FormatVariadic.h" 24 25 using namespace clang; 26 using namespace CodeGen; 27 using namespace clang::hlsl; 28 using namespace llvm; 29 30 namespace { 31 32 void addDxilValVersion(StringRef ValVersionStr, llvm::Module &M) { 33 // The validation of ValVersionStr is done at HLSLToolChain::TranslateArgs. 34 // Assume ValVersionStr is legal here. 35 VersionTuple Version; 36 if (Version.tryParse(ValVersionStr) || Version.getBuild() || 37 Version.getSubminor() || !Version.getMinor()) { 38 return; 39 } 40 41 uint64_t Major = Version.getMajor(); 42 uint64_t Minor = *Version.getMinor(); 43 44 auto &Ctx = M.getContext(); 45 IRBuilder<> B(M.getContext()); 46 MDNode *Val = MDNode::get(Ctx, {ConstantAsMetadata::get(B.getInt32(Major)), 47 ConstantAsMetadata::get(B.getInt32(Minor))}); 48 StringRef DXILValKey = "dx.valver"; 49 auto *DXILValMD = M.getOrInsertNamedMetadata(DXILValKey); 50 DXILValMD->addOperand(Val); 51 } 52 void addDisableOptimizations(llvm::Module &M) { 53 StringRef Key = "dx.disable_optimizations"; 54 M.addModuleFlag(llvm::Module::ModFlagBehavior::Override, Key, 1); 55 } 56 // cbuffer will be translated into global variable in special address space. 57 // If translate into C, 58 // cbuffer A { 59 // float a; 60 // float b; 61 // } 62 // float foo() { return a + b; } 63 // 64 // will be translated into 65 // 66 // struct A { 67 // float a; 68 // float b; 69 // } cbuffer_A __attribute__((address_space(4))); 70 // float foo() { return cbuffer_A.a + cbuffer_A.b; } 71 // 72 // layoutBuffer will create the struct A type. 73 // replaceBuffer will replace use of global variable a and b with cbuffer_A.a 74 // and cbuffer_A.b. 75 // 76 void layoutBuffer(CGHLSLRuntime::Buffer &Buf, const DataLayout &DL) { 77 if (Buf.Constants.empty()) 78 return; 79 80 std::vector<llvm::Type *> EltTys; 81 for (auto &Const : Buf.Constants) { 82 GlobalVariable *GV = Const.first; 83 Const.second = EltTys.size(); 84 llvm::Type *Ty = GV->getValueType(); 85 EltTys.emplace_back(Ty); 86 } 87 Buf.LayoutStruct = llvm::StructType::get(EltTys[0]->getContext(), EltTys); 88 } 89 90 GlobalVariable *replaceBuffer(CGHLSLRuntime::Buffer &Buf) { 91 // Create global variable for CB. 92 GlobalVariable *CBGV = new GlobalVariable( 93 Buf.LayoutStruct, /*isConstant*/ true, 94 GlobalValue::LinkageTypes::ExternalLinkage, nullptr, 95 llvm::formatv("{0}{1}", Buf.Name, Buf.IsCBuffer ? ".cb." : ".tb."), 96 GlobalValue::NotThreadLocal); 97 98 IRBuilder<> B(CBGV->getContext()); 99 Value *ZeroIdx = B.getInt32(0); 100 // Replace Const use with CB use. 101 for (auto &[GV, Offset] : Buf.Constants) { 102 Value *GEP = 103 B.CreateGEP(Buf.LayoutStruct, CBGV, {ZeroIdx, B.getInt32(Offset)}); 104 105 assert(Buf.LayoutStruct->getElementType(Offset) == GV->getValueType() && 106 "constant type mismatch"); 107 108 // Replace. 109 GV->replaceAllUsesWith(GEP); 110 // Erase GV. 111 GV->removeDeadConstantUsers(); 112 GV->eraseFromParent(); 113 } 114 return CBGV; 115 } 116 117 } // namespace 118 119 void CGHLSLRuntime::addConstant(VarDecl *D, Buffer &CB) { 120 if (D->getStorageClass() == SC_Static) { 121 // For static inside cbuffer, take as global static. 122 // Don't add to cbuffer. 123 CGM.EmitGlobal(D); 124 return; 125 } 126 127 auto *GV = cast<GlobalVariable>(CGM.GetAddrOfGlobalVar(D)); 128 // Add debug info for constVal. 129 if (CGDebugInfo *DI = CGM.getModuleDebugInfo()) 130 if (CGM.getCodeGenOpts().getDebugInfo() >= 131 codegenoptions::DebugInfoKind::LimitedDebugInfo) 132 DI->EmitGlobalVariable(cast<GlobalVariable>(GV), D); 133 134 // FIXME: support packoffset. 135 // See https://github.com/llvm/llvm-project/issues/57914. 136 uint32_t Offset = 0; 137 bool HasUserOffset = false; 138 139 unsigned LowerBound = HasUserOffset ? Offset : UINT_MAX; 140 CB.Constants.emplace_back(std::make_pair(GV, LowerBound)); 141 } 142 143 void CGHLSLRuntime::addBufferDecls(const DeclContext *DC, Buffer &CB) { 144 for (Decl *it : DC->decls()) { 145 if (auto *ConstDecl = dyn_cast<VarDecl>(it)) { 146 addConstant(ConstDecl, CB); 147 } else if (isa<CXXRecordDecl, EmptyDecl>(it)) { 148 // Nothing to do for this declaration. 149 } else if (isa<FunctionDecl>(it)) { 150 // A function within an cbuffer is effectively a top-level function, 151 // as it only refers to globally scoped declarations. 152 CGM.EmitTopLevelDecl(it); 153 } 154 } 155 } 156 157 void CGHLSLRuntime::addBuffer(const HLSLBufferDecl *D) { 158 Buffers.emplace_back(Buffer(D)); 159 addBufferDecls(D, Buffers.back()); 160 } 161 162 void CGHLSLRuntime::finishCodeGen() { 163 auto &TargetOpts = CGM.getTarget().getTargetOpts(); 164 llvm::Module &M = CGM.getModule(); 165 Triple T(M.getTargetTriple()); 166 if (T.getArch() == Triple::ArchType::dxil) 167 addDxilValVersion(TargetOpts.DxilValidatorVersion, M); 168 169 generateGlobalCtorDtorCalls(); 170 if (CGM.getCodeGenOpts().OptimizationLevel == 0) 171 addDisableOptimizations(M); 172 173 const DataLayout &DL = M.getDataLayout(); 174 175 for (auto &Buf : Buffers) { 176 layoutBuffer(Buf, DL); 177 GlobalVariable *GV = replaceBuffer(Buf); 178 M.insertGlobalVariable(GV); 179 llvm::hlsl::ResourceClass RC = Buf.IsCBuffer 180 ? llvm::hlsl::ResourceClass::CBuffer 181 : llvm::hlsl::ResourceClass::SRV; 182 llvm::hlsl::ResourceKind RK = Buf.IsCBuffer 183 ? llvm::hlsl::ResourceKind::CBuffer 184 : llvm::hlsl::ResourceKind::TBuffer; 185 std::string TyName = 186 Buf.Name.str() + (Buf.IsCBuffer ? ".cb." : ".tb.") + "ty"; 187 addBufferResourceAnnotation(GV, TyName, RC, RK, /*IsROV=*/false, 188 Buf.Binding); 189 } 190 } 191 192 CGHLSLRuntime::Buffer::Buffer(const HLSLBufferDecl *D) 193 : Name(D->getName()), IsCBuffer(D->isCBuffer()), 194 Binding(D->getAttr<HLSLResourceBindingAttr>()) {} 195 196 void CGHLSLRuntime::addBufferResourceAnnotation(llvm::GlobalVariable *GV, 197 llvm::StringRef TyName, 198 llvm::hlsl::ResourceClass RC, 199 llvm::hlsl::ResourceKind RK, 200 bool IsROV, 201 BufferResBinding &Binding) { 202 llvm::Module &M = CGM.getModule(); 203 204 NamedMDNode *ResourceMD = nullptr; 205 switch (RC) { 206 case llvm::hlsl::ResourceClass::UAV: 207 ResourceMD = M.getOrInsertNamedMetadata("hlsl.uavs"); 208 break; 209 case llvm::hlsl::ResourceClass::SRV: 210 ResourceMD = M.getOrInsertNamedMetadata("hlsl.srvs"); 211 break; 212 case llvm::hlsl::ResourceClass::CBuffer: 213 ResourceMD = M.getOrInsertNamedMetadata("hlsl.cbufs"); 214 break; 215 default: 216 assert(false && "Unsupported buffer type!"); 217 return; 218 } 219 220 assert(ResourceMD != nullptr && 221 "ResourceMD must have been set by the switch above."); 222 223 llvm::hlsl::FrontendResource Res( 224 GV, TyName, RK, IsROV, Binding.Reg.value_or(UINT_MAX), Binding.Space); 225 ResourceMD->addOperand(Res.getMetadata()); 226 } 227 228 void CGHLSLRuntime::annotateHLSLResource(const VarDecl *D, GlobalVariable *GV) { 229 const Type *Ty = D->getType()->getPointeeOrArrayElementType(); 230 if (!Ty) 231 return; 232 const auto *RD = Ty->getAsCXXRecordDecl(); 233 if (!RD) 234 return; 235 const auto *Attr = RD->getAttr<HLSLResourceAttr>(); 236 if (!Attr) 237 return; 238 239 llvm::hlsl::ResourceClass RC = Attr->getResourceClass(); 240 llvm::hlsl::ResourceKind RK = Attr->getResourceKind(); 241 bool IsROV = Attr->getIsROV(); 242 243 QualType QT(Ty, 0); 244 BufferResBinding Binding(D->getAttr<HLSLResourceBindingAttr>()); 245 addBufferResourceAnnotation(GV, QT.getAsString(), RC, RK, IsROV, Binding); 246 } 247 248 CGHLSLRuntime::BufferResBinding::BufferResBinding( 249 HLSLResourceBindingAttr *Binding) { 250 if (Binding) { 251 llvm::APInt RegInt(64, 0); 252 Binding->getSlot().substr(1).getAsInteger(10, RegInt); 253 Reg = RegInt.getLimitedValue(); 254 llvm::APInt SpaceInt(64, 0); 255 Binding->getSpace().substr(5).getAsInteger(10, SpaceInt); 256 Space = SpaceInt.getLimitedValue(); 257 } else { 258 Space = 0; 259 } 260 } 261 262 void clang::CodeGen::CGHLSLRuntime::setHLSLEntryAttributes( 263 const FunctionDecl *FD, llvm::Function *Fn) { 264 const auto *ShaderAttr = FD->getAttr<HLSLShaderAttr>(); 265 assert(ShaderAttr && "All entry functions must have a HLSLShaderAttr"); 266 const StringRef ShaderAttrKindStr = "hlsl.shader"; 267 Fn->addFnAttr(ShaderAttrKindStr, 268 ShaderAttr->ConvertShaderTypeToStr(ShaderAttr->getType())); 269 if (HLSLNumThreadsAttr *NumThreadsAttr = FD->getAttr<HLSLNumThreadsAttr>()) { 270 const StringRef NumThreadsKindStr = "hlsl.numthreads"; 271 std::string NumThreadsStr = 272 formatv("{0},{1},{2}", NumThreadsAttr->getX(), NumThreadsAttr->getY(), 273 NumThreadsAttr->getZ()); 274 Fn->addFnAttr(NumThreadsKindStr, NumThreadsStr); 275 } 276 } 277 278 static Value *buildVectorInput(IRBuilder<> &B, Function *F, llvm::Type *Ty) { 279 if (const auto *VT = dyn_cast<FixedVectorType>(Ty)) { 280 Value *Result = PoisonValue::get(Ty); 281 for (unsigned I = 0; I < VT->getNumElements(); ++I) { 282 Value *Elt = B.CreateCall(F, {B.getInt32(I)}); 283 Result = B.CreateInsertElement(Result, Elt, I); 284 } 285 return Result; 286 } 287 return B.CreateCall(F, {B.getInt32(0)}); 288 } 289 290 llvm::Value *CGHLSLRuntime::emitInputSemantic(IRBuilder<> &B, 291 const ParmVarDecl &D, 292 llvm::Type *Ty) { 293 assert(D.hasAttrs() && "Entry parameter missing annotation attribute!"); 294 if (D.hasAttr<HLSLSV_GroupIndexAttr>()) { 295 llvm::Function *DxGroupIndex = 296 CGM.getIntrinsic(Intrinsic::dx_flattened_thread_id_in_group); 297 return B.CreateCall(FunctionCallee(DxGroupIndex)); 298 } 299 if (D.hasAttr<HLSLSV_DispatchThreadIDAttr>()) { 300 llvm::Function *DxThreadID = CGM.getIntrinsic(Intrinsic::dx_thread_id); 301 return buildVectorInput(B, DxThreadID, Ty); 302 } 303 assert(false && "Unhandled parameter attribute"); 304 return nullptr; 305 } 306 307 void CGHLSLRuntime::emitEntryFunction(const FunctionDecl *FD, 308 llvm::Function *Fn) { 309 llvm::Module &M = CGM.getModule(); 310 llvm::LLVMContext &Ctx = M.getContext(); 311 auto *EntryTy = llvm::FunctionType::get(llvm::Type::getVoidTy(Ctx), false); 312 Function *EntryFn = 313 Function::Create(EntryTy, Function::ExternalLinkage, FD->getName(), &M); 314 315 // Copy function attributes over, we have no argument or return attributes 316 // that can be valid on the real entry. 317 AttributeList NewAttrs = AttributeList::get(Ctx, AttributeList::FunctionIndex, 318 Fn->getAttributes().getFnAttrs()); 319 EntryFn->setAttributes(NewAttrs); 320 setHLSLEntryAttributes(FD, EntryFn); 321 322 // Set the called function as internal linkage. 323 Fn->setLinkage(GlobalValue::InternalLinkage); 324 325 BasicBlock *BB = BasicBlock::Create(Ctx, "entry", EntryFn); 326 IRBuilder<> B(BB); 327 llvm::SmallVector<Value *> Args; 328 // FIXME: support struct parameters where semantics are on members. 329 // See: https://github.com/llvm/llvm-project/issues/57874 330 unsigned SRetOffset = 0; 331 for (const auto &Param : Fn->args()) { 332 if (Param.hasStructRetAttr()) { 333 // FIXME: support output. 334 // See: https://github.com/llvm/llvm-project/issues/57874 335 SRetOffset = 1; 336 Args.emplace_back(PoisonValue::get(Param.getType())); 337 continue; 338 } 339 const ParmVarDecl *PD = FD->getParamDecl(Param.getArgNo() - SRetOffset); 340 Args.push_back(emitInputSemantic(B, *PD, Param.getType())); 341 } 342 343 CallInst *CI = B.CreateCall(FunctionCallee(Fn), Args); 344 (void)CI; 345 // FIXME: Handle codegen for return type semantics. 346 // See: https://github.com/llvm/llvm-project/issues/57875 347 B.CreateRetVoid(); 348 } 349 350 static void gatherFunctions(SmallVectorImpl<Function *> &Fns, llvm::Module &M, 351 bool CtorOrDtor) { 352 const auto *GV = 353 M.getNamedGlobal(CtorOrDtor ? "llvm.global_ctors" : "llvm.global_dtors"); 354 if (!GV) 355 return; 356 const auto *CA = dyn_cast<ConstantArray>(GV->getInitializer()); 357 if (!CA) 358 return; 359 // The global_ctor array elements are a struct [Priority, Fn *, COMDat]. 360 // HLSL neither supports priorities or COMDat values, so we will check those 361 // in an assert but not handle them. 362 363 llvm::SmallVector<Function *> CtorFns; 364 for (const auto &Ctor : CA->operands()) { 365 if (isa<ConstantAggregateZero>(Ctor)) 366 continue; 367 ConstantStruct *CS = cast<ConstantStruct>(Ctor); 368 369 assert(cast<ConstantInt>(CS->getOperand(0))->getValue() == 65535 && 370 "HLSL doesn't support setting priority for global ctors."); 371 assert(isa<ConstantPointerNull>(CS->getOperand(2)) && 372 "HLSL doesn't support COMDat for global ctors."); 373 Fns.push_back(cast<Function>(CS->getOperand(1))); 374 } 375 } 376 377 void CGHLSLRuntime::generateGlobalCtorDtorCalls() { 378 llvm::Module &M = CGM.getModule(); 379 SmallVector<Function *> CtorFns; 380 SmallVector<Function *> DtorFns; 381 gatherFunctions(CtorFns, M, true); 382 gatherFunctions(DtorFns, M, false); 383 384 // Insert a call to the global constructor at the beginning of the entry block 385 // to externally exported functions. This is a bit of a hack, but HLSL allows 386 // global constructors, but doesn't support driver initialization of globals. 387 for (auto &F : M.functions()) { 388 if (!F.hasFnAttribute("hlsl.shader")) 389 continue; 390 IRBuilder<> B(&F.getEntryBlock(), F.getEntryBlock().begin()); 391 for (auto *Fn : CtorFns) 392 B.CreateCall(FunctionCallee(Fn)); 393 394 // Insert global dtors before the terminator of the last instruction 395 B.SetInsertPoint(F.back().getTerminator()); 396 for (auto *Fn : DtorFns) 397 B.CreateCall(FunctionCallee(Fn)); 398 } 399 400 // No need to keep global ctors/dtors for non-lib profile after call to 401 // ctors/dtors added for entry. 402 Triple T(M.getTargetTriple()); 403 if (T.getEnvironment() != Triple::EnvironmentType::Library) { 404 if (auto *GV = M.getNamedGlobal("llvm.global_ctors")) 405 GV->eraseFromParent(); 406 if (auto *GV = M.getNamedGlobal("llvm.global_dtors")) 407 GV->eraseFromParent(); 408 } 409 } 410