1 //===----- CGHLSLRuntime.cpp - Interface to HLSL Runtimes -----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This provides an abstract class for HLSL code generation. Concrete 10 // subclasses of this implement code generation for specific HLSL 11 // runtime libraries. 12 // 13 //===----------------------------------------------------------------------===// 14 15 #include "CGHLSLRuntime.h" 16 #include "CGDebugInfo.h" 17 #include "CodeGenModule.h" 18 #include "clang/AST/Decl.h" 19 #include "clang/Basic/TargetOptions.h" 20 #include "llvm/IR/IntrinsicsDirectX.h" 21 #include "llvm/IR/Metadata.h" 22 #include "llvm/IR/Module.h" 23 #include "llvm/Support/FormatVariadic.h" 24 25 using namespace clang; 26 using namespace CodeGen; 27 using namespace clang::hlsl; 28 using namespace llvm; 29 30 namespace { 31 32 void addDxilValVersion(StringRef ValVersionStr, llvm::Module &M) { 33 // The validation of ValVersionStr is done at HLSLToolChain::TranslateArgs. 34 // Assume ValVersionStr is legal here. 35 VersionTuple Version; 36 if (Version.tryParse(ValVersionStr) || Version.getBuild() || 37 Version.getSubminor() || !Version.getMinor()) { 38 return; 39 } 40 41 uint64_t Major = Version.getMajor(); 42 uint64_t Minor = *Version.getMinor(); 43 44 auto &Ctx = M.getContext(); 45 IRBuilder<> B(M.getContext()); 46 MDNode *Val = MDNode::get(Ctx, {ConstantAsMetadata::get(B.getInt32(Major)), 47 ConstantAsMetadata::get(B.getInt32(Minor))}); 48 StringRef DXILValKey = "dx.valver"; 49 auto *DXILValMD = M.getOrInsertNamedMetadata(DXILValKey); 50 DXILValMD->addOperand(Val); 51 } 52 void addDisableOptimizations(llvm::Module &M) { 53 StringRef Key = "dx.disable_optimizations"; 54 M.addModuleFlag(llvm::Module::ModFlagBehavior::Override, Key, 1); 55 } 56 // cbuffer will be translated into global variable in special address space. 57 // If translate into C, 58 // cbuffer A { 59 // float a; 60 // float b; 61 // } 62 // float foo() { return a + b; } 63 // 64 // will be translated into 65 // 66 // struct A { 67 // float a; 68 // float b; 69 // } cbuffer_A __attribute__((address_space(4))); 70 // float foo() { return cbuffer_A.a + cbuffer_A.b; } 71 // 72 // layoutBuffer will create the struct A type. 73 // replaceBuffer will replace use of global variable a and b with cbuffer_A.a 74 // and cbuffer_A.b. 75 // 76 void layoutBuffer(CGHLSLRuntime::Buffer &Buf, const DataLayout &DL) { 77 if (Buf.Constants.empty()) 78 return; 79 80 std::vector<llvm::Type *> EltTys; 81 for (auto &Const : Buf.Constants) { 82 GlobalVariable *GV = Const.first; 83 Const.second = EltTys.size(); 84 llvm::Type *Ty = GV->getValueType(); 85 EltTys.emplace_back(Ty); 86 } 87 Buf.LayoutStruct = llvm::StructType::get(EltTys[0]->getContext(), EltTys); 88 } 89 90 GlobalVariable *replaceBuffer(CGHLSLRuntime::Buffer &Buf) { 91 // Create global variable for CB. 92 GlobalVariable *CBGV = new GlobalVariable( 93 Buf.LayoutStruct, /*isConstant*/ true, 94 GlobalValue::LinkageTypes::ExternalLinkage, nullptr, 95 llvm::formatv("{0}{1}", Buf.Name, Buf.IsCBuffer ? ".cb." : ".tb."), 96 GlobalValue::NotThreadLocal); 97 98 IRBuilder<> B(CBGV->getContext()); 99 Value *ZeroIdx = B.getInt32(0); 100 // Replace Const use with CB use. 101 for (auto &[GV, Offset] : Buf.Constants) { 102 Value *GEP = 103 B.CreateGEP(Buf.LayoutStruct, CBGV, {ZeroIdx, B.getInt32(Offset)}); 104 105 assert(Buf.LayoutStruct->getElementType(Offset) == GV->getValueType() && 106 "constant type mismatch"); 107 108 // Replace. 109 GV->replaceAllUsesWith(GEP); 110 // Erase GV. 111 GV->removeDeadConstantUsers(); 112 GV->eraseFromParent(); 113 } 114 return CBGV; 115 } 116 117 } // namespace 118 119 void CGHLSLRuntime::addConstant(VarDecl *D, Buffer &CB) { 120 if (D->getStorageClass() == SC_Static) { 121 // For static inside cbuffer, take as global static. 122 // Don't add to cbuffer. 123 CGM.EmitGlobal(D); 124 return; 125 } 126 127 auto *GV = cast<GlobalVariable>(CGM.GetAddrOfGlobalVar(D)); 128 // Add debug info for constVal. 129 if (CGDebugInfo *DI = CGM.getModuleDebugInfo()) 130 if (CGM.getCodeGenOpts().getDebugInfo() >= 131 codegenoptions::DebugInfoKind::LimitedDebugInfo) 132 DI->EmitGlobalVariable(cast<GlobalVariable>(GV), D); 133 134 // FIXME: support packoffset. 135 // See https://github.com/llvm/llvm-project/issues/57914. 136 uint32_t Offset = 0; 137 bool HasUserOffset = false; 138 139 unsigned LowerBound = HasUserOffset ? Offset : UINT_MAX; 140 CB.Constants.emplace_back(std::make_pair(GV, LowerBound)); 141 } 142 143 void CGHLSLRuntime::addBufferDecls(const DeclContext *DC, Buffer &CB) { 144 for (Decl *it : DC->decls()) { 145 if (auto *ConstDecl = dyn_cast<VarDecl>(it)) { 146 addConstant(ConstDecl, CB); 147 } else if (isa<CXXRecordDecl, EmptyDecl>(it)) { 148 // Nothing to do for this declaration. 149 } else if (isa<FunctionDecl>(it)) { 150 // A function within an cbuffer is effectively a top-level function, 151 // as it only refers to globally scoped declarations. 152 CGM.EmitTopLevelDecl(it); 153 } 154 } 155 } 156 157 void CGHLSLRuntime::addBuffer(const HLSLBufferDecl *D) { 158 Buffers.emplace_back(Buffer(D)); 159 addBufferDecls(D, Buffers.back()); 160 } 161 162 void CGHLSLRuntime::finishCodeGen() { 163 auto &TargetOpts = CGM.getTarget().getTargetOpts(); 164 llvm::Module &M = CGM.getModule(); 165 Triple T(M.getTargetTriple()); 166 if (T.getArch() == Triple::ArchType::dxil) 167 addDxilValVersion(TargetOpts.DxilValidatorVersion, M); 168 169 generateGlobalCtorDtorCalls(); 170 if (CGM.getCodeGenOpts().OptimizationLevel == 0) 171 addDisableOptimizations(M); 172 173 const DataLayout &DL = M.getDataLayout(); 174 175 for (auto &Buf : Buffers) { 176 layoutBuffer(Buf, DL); 177 GlobalVariable *GV = replaceBuffer(Buf); 178 M.getGlobalList().push_back(GV); 179 llvm::hlsl::ResourceClass RC = Buf.IsCBuffer 180 ? llvm::hlsl::ResourceClass::CBuffer 181 : llvm::hlsl::ResourceClass::SRV; 182 llvm::hlsl::ResourceKind RK = Buf.IsCBuffer 183 ? llvm::hlsl::ResourceKind::CBuffer 184 : llvm::hlsl::ResourceKind::TBuffer; 185 std::string TyName = 186 Buf.Name.str() + (Buf.IsCBuffer ? ".cb." : ".tb.") + "ty"; 187 addBufferResourceAnnotation(GV, TyName, RC, RK, Buf.Binding); 188 } 189 } 190 191 CGHLSLRuntime::Buffer::Buffer(const HLSLBufferDecl *D) 192 : Name(D->getName()), IsCBuffer(D->isCBuffer()), 193 Binding(D->getAttr<HLSLResourceBindingAttr>()) {} 194 195 void CGHLSLRuntime::addBufferResourceAnnotation(llvm::GlobalVariable *GV, 196 llvm::StringRef TyName, 197 llvm::hlsl::ResourceClass RC, 198 llvm::hlsl::ResourceKind RK, 199 BufferResBinding &Binding) { 200 llvm::Module &M = CGM.getModule(); 201 202 NamedMDNode *ResourceMD = nullptr; 203 switch (RC) { 204 case llvm::hlsl::ResourceClass::UAV: 205 ResourceMD = M.getOrInsertNamedMetadata("hlsl.uavs"); 206 break; 207 case llvm::hlsl::ResourceClass::SRV: 208 ResourceMD = M.getOrInsertNamedMetadata("hlsl.srvs"); 209 break; 210 case llvm::hlsl::ResourceClass::CBuffer: 211 ResourceMD = M.getOrInsertNamedMetadata("hlsl.cbufs"); 212 break; 213 default: 214 assert(false && "Unsupported buffer type!"); 215 return; 216 } 217 218 assert(ResourceMD != nullptr && 219 "ResourceMD must have been set by the switch above."); 220 221 llvm::hlsl::FrontendResource Res( 222 GV, TyName, RK, Binding.Reg.value_or(UINT_MAX), Binding.Space); 223 ResourceMD->addOperand(Res.getMetadata()); 224 } 225 226 static llvm::hlsl::ResourceKind 227 castResourceShapeToResourceKind(HLSLResourceAttr::ResourceKind RK) { 228 switch (RK) { 229 case HLSLResourceAttr::ResourceKind::Texture1D: 230 return llvm::hlsl::ResourceKind::Texture1D; 231 case HLSLResourceAttr::ResourceKind::Texture2D: 232 return llvm::hlsl::ResourceKind::Texture2D; 233 case HLSLResourceAttr::ResourceKind::Texture2DMS: 234 return llvm::hlsl::ResourceKind::Texture2DMS; 235 case HLSLResourceAttr::ResourceKind::Texture3D: 236 return llvm::hlsl::ResourceKind::Texture3D; 237 case HLSLResourceAttr::ResourceKind::TextureCube: 238 return llvm::hlsl::ResourceKind::TextureCube; 239 case HLSLResourceAttr::ResourceKind::Texture1DArray: 240 return llvm::hlsl::ResourceKind::Texture1DArray; 241 case HLSLResourceAttr::ResourceKind::Texture2DArray: 242 return llvm::hlsl::ResourceKind::Texture2DArray; 243 case HLSLResourceAttr::ResourceKind::Texture2DMSArray: 244 return llvm::hlsl::ResourceKind::Texture2DMSArray; 245 case HLSLResourceAttr::ResourceKind::TextureCubeArray: 246 return llvm::hlsl::ResourceKind::TextureCubeArray; 247 case HLSLResourceAttr::ResourceKind::TypedBuffer: 248 return llvm::hlsl::ResourceKind::TypedBuffer; 249 case HLSLResourceAttr::ResourceKind::RawBuffer: 250 return llvm::hlsl::ResourceKind::RawBuffer; 251 case HLSLResourceAttr::ResourceKind::StructuredBuffer: 252 return llvm::hlsl::ResourceKind::StructuredBuffer; 253 case HLSLResourceAttr::ResourceKind::CBufferKind: 254 return llvm::hlsl::ResourceKind::CBuffer; 255 case HLSLResourceAttr::ResourceKind::SamplerKind: 256 return llvm::hlsl::ResourceKind::Sampler; 257 case HLSLResourceAttr::ResourceKind::TBuffer: 258 return llvm::hlsl::ResourceKind::TBuffer; 259 case HLSLResourceAttr::ResourceKind::RTAccelerationStructure: 260 return llvm::hlsl::ResourceKind::RTAccelerationStructure; 261 case HLSLResourceAttr::ResourceKind::FeedbackTexture2D: 262 return llvm::hlsl::ResourceKind::FeedbackTexture2D; 263 case HLSLResourceAttr::ResourceKind::FeedbackTexture2DArray: 264 return llvm::hlsl::ResourceKind::FeedbackTexture2DArray; 265 } 266 // Make sure to update HLSLResourceAttr::ResourceKind when add new Kind to 267 // hlsl::ResourceKind. Assume FeedbackTexture2DArray is the last enum for 268 // HLSLResourceAttr::ResourceKind. 269 static_assert( 270 static_cast<uint32_t>( 271 HLSLResourceAttr::ResourceKind::FeedbackTexture2DArray) == 272 (static_cast<uint32_t>(llvm::hlsl::ResourceKind::NumEntries) - 2)); 273 llvm_unreachable("all switch cases should be covered"); 274 } 275 276 void CGHLSLRuntime::annotateHLSLResource(const VarDecl *D, GlobalVariable *GV) { 277 const Type *Ty = D->getType()->getPointeeOrArrayElementType(); 278 if (!Ty) 279 return; 280 const auto *RD = Ty->getAsCXXRecordDecl(); 281 if (!RD) 282 return; 283 const auto *Attr = RD->getAttr<HLSLResourceAttr>(); 284 if (!Attr) 285 return; 286 287 HLSLResourceAttr::ResourceClass RC = Attr->getResourceType(); 288 llvm::hlsl::ResourceKind RK = 289 castResourceShapeToResourceKind(Attr->getResourceShape()); 290 291 QualType QT(Ty, 0); 292 BufferResBinding Binding(D->getAttr<HLSLResourceBindingAttr>()); 293 addBufferResourceAnnotation(GV, QT.getAsString(), 294 static_cast<llvm::hlsl::ResourceClass>(RC), RK, 295 Binding); 296 } 297 298 CGHLSLRuntime::BufferResBinding::BufferResBinding( 299 HLSLResourceBindingAttr *Binding) { 300 if (Binding) { 301 llvm::APInt RegInt(64, 0); 302 Binding->getSlot().substr(1).getAsInteger(10, RegInt); 303 Reg = RegInt.getLimitedValue(); 304 llvm::APInt SpaceInt(64, 0); 305 Binding->getSpace().substr(5).getAsInteger(10, SpaceInt); 306 Space = SpaceInt.getLimitedValue(); 307 } else { 308 Space = 0; 309 } 310 } 311 312 void clang::CodeGen::CGHLSLRuntime::setHLSLEntryAttributes( 313 const FunctionDecl *FD, llvm::Function *Fn) { 314 const auto *ShaderAttr = FD->getAttr<HLSLShaderAttr>(); 315 assert(ShaderAttr && "All entry functions must have a HLSLShaderAttr"); 316 const StringRef ShaderAttrKindStr = "hlsl.shader"; 317 Fn->addFnAttr(ShaderAttrKindStr, 318 ShaderAttr->ConvertShaderTypeToStr(ShaderAttr->getType())); 319 if (HLSLNumThreadsAttr *NumThreadsAttr = FD->getAttr<HLSLNumThreadsAttr>()) { 320 const StringRef NumThreadsKindStr = "hlsl.numthreads"; 321 std::string NumThreadsStr = 322 formatv("{0},{1},{2}", NumThreadsAttr->getX(), NumThreadsAttr->getY(), 323 NumThreadsAttr->getZ()); 324 Fn->addFnAttr(NumThreadsKindStr, NumThreadsStr); 325 } 326 } 327 328 static Value *buildVectorInput(IRBuilder<> &B, Function *F, llvm::Type *Ty) { 329 if (const auto *VT = dyn_cast<FixedVectorType>(Ty)) { 330 Value *Result = PoisonValue::get(Ty); 331 for (unsigned I = 0; I < VT->getNumElements(); ++I) { 332 Value *Elt = B.CreateCall(F, {B.getInt32(I)}); 333 Result = B.CreateInsertElement(Result, Elt, I); 334 } 335 return Result; 336 } 337 return B.CreateCall(F, {B.getInt32(0)}); 338 } 339 340 llvm::Value *CGHLSLRuntime::emitInputSemantic(IRBuilder<> &B, 341 const ParmVarDecl &D, 342 llvm::Type *Ty) { 343 assert(D.hasAttrs() && "Entry parameter missing annotation attribute!"); 344 if (D.hasAttr<HLSLSV_GroupIndexAttr>()) { 345 llvm::Function *DxGroupIndex = 346 CGM.getIntrinsic(Intrinsic::dx_flattened_thread_id_in_group); 347 return B.CreateCall(FunctionCallee(DxGroupIndex)); 348 } 349 if (D.hasAttr<HLSLSV_DispatchThreadIDAttr>()) { 350 llvm::Function *DxThreadID = CGM.getIntrinsic(Intrinsic::dx_thread_id); 351 return buildVectorInput(B, DxThreadID, Ty); 352 } 353 assert(false && "Unhandled parameter attribute"); 354 return nullptr; 355 } 356 357 void CGHLSLRuntime::emitEntryFunction(const FunctionDecl *FD, 358 llvm::Function *Fn) { 359 llvm::Module &M = CGM.getModule(); 360 llvm::LLVMContext &Ctx = M.getContext(); 361 auto *EntryTy = llvm::FunctionType::get(llvm::Type::getVoidTy(Ctx), false); 362 Function *EntryFn = 363 Function::Create(EntryTy, Function::ExternalLinkage, FD->getName(), &M); 364 365 // Copy function attributes over, we have no argument or return attributes 366 // that can be valid on the real entry. 367 AttributeList NewAttrs = AttributeList::get(Ctx, AttributeList::FunctionIndex, 368 Fn->getAttributes().getFnAttrs()); 369 EntryFn->setAttributes(NewAttrs); 370 setHLSLEntryAttributes(FD, EntryFn); 371 372 // Set the called function as internal linkage. 373 Fn->setLinkage(GlobalValue::InternalLinkage); 374 375 BasicBlock *BB = BasicBlock::Create(Ctx, "entry", EntryFn); 376 IRBuilder<> B(BB); 377 llvm::SmallVector<Value *> Args; 378 // FIXME: support struct parameters where semantics are on members. 379 // See: https://github.com/llvm/llvm-project/issues/57874 380 unsigned SRetOffset = 0; 381 for (const auto &Param : Fn->args()) { 382 if (Param.hasStructRetAttr()) { 383 // FIXME: support output. 384 // See: https://github.com/llvm/llvm-project/issues/57874 385 SRetOffset = 1; 386 Args.emplace_back(PoisonValue::get(Param.getType())); 387 continue; 388 } 389 const ParmVarDecl *PD = FD->getParamDecl(Param.getArgNo() - SRetOffset); 390 Args.push_back(emitInputSemantic(B, *PD, Param.getType())); 391 } 392 393 CallInst *CI = B.CreateCall(FunctionCallee(Fn), Args); 394 (void)CI; 395 // FIXME: Handle codegen for return type semantics. 396 // See: https://github.com/llvm/llvm-project/issues/57875 397 B.CreateRetVoid(); 398 } 399 400 static void gatherFunctions(SmallVectorImpl<Function *> &Fns, llvm::Module &M, 401 bool CtorOrDtor) { 402 const auto *GV = 403 M.getNamedGlobal(CtorOrDtor ? "llvm.global_ctors" : "llvm.global_dtors"); 404 if (!GV) 405 return; 406 const auto *CA = dyn_cast<ConstantArray>(GV->getInitializer()); 407 if (!CA) 408 return; 409 // The global_ctor array elements are a struct [Priority, Fn *, COMDat]. 410 // HLSL neither supports priorities or COMDat values, so we will check those 411 // in an assert but not handle them. 412 413 llvm::SmallVector<Function *> CtorFns; 414 for (const auto &Ctor : CA->operands()) { 415 if (isa<ConstantAggregateZero>(Ctor)) 416 continue; 417 ConstantStruct *CS = cast<ConstantStruct>(Ctor); 418 419 assert(cast<ConstantInt>(CS->getOperand(0))->getValue() == 65535 && 420 "HLSL doesn't support setting priority for global ctors."); 421 assert(isa<ConstantPointerNull>(CS->getOperand(2)) && 422 "HLSL doesn't support COMDat for global ctors."); 423 Fns.push_back(cast<Function>(CS->getOperand(1))); 424 } 425 } 426 427 void CGHLSLRuntime::generateGlobalCtorDtorCalls() { 428 llvm::Module &M = CGM.getModule(); 429 SmallVector<Function *> CtorFns; 430 SmallVector<Function *> DtorFns; 431 gatherFunctions(CtorFns, M, true); 432 gatherFunctions(DtorFns, M, false); 433 434 // Insert a call to the global constructor at the beginning of the entry block 435 // to externally exported functions. This is a bit of a hack, but HLSL allows 436 // global constructors, but doesn't support driver initialization of globals. 437 for (auto &F : M.functions()) { 438 if (!F.hasFnAttribute("hlsl.shader")) 439 continue; 440 IRBuilder<> B(&F.getEntryBlock(), F.getEntryBlock().begin()); 441 for (auto *Fn : CtorFns) 442 B.CreateCall(FunctionCallee(Fn)); 443 444 // Insert global dtors before the terminator of the last instruction 445 B.SetInsertPoint(F.back().getTerminator()); 446 for (auto *Fn : DtorFns) 447 B.CreateCall(FunctionCallee(Fn)); 448 } 449 450 // No need to keep global ctors/dtors for non-lib profile after call to 451 // ctors/dtors added for entry. 452 Triple T(M.getTargetTriple()); 453 if (T.getEnvironment() != Triple::EnvironmentType::Library) { 454 if (auto *GV = M.getNamedGlobal("llvm.global_ctors")) 455 GV->eraseFromParent(); 456 if (auto *GV = M.getNamedGlobal("llvm.global_dtors")) 457 GV->eraseFromParent(); 458 } 459 } 460