1 //===----- CGHLSLRuntime.cpp - Interface to HLSL Runtimes -----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This provides an abstract class for HLSL code generation. Concrete 10 // subclasses of this implement code generation for specific HLSL 11 // runtime libraries. 12 // 13 //===----------------------------------------------------------------------===// 14 15 #include "CGHLSLRuntime.h" 16 #include "CGDebugInfo.h" 17 #include "CodeGenModule.h" 18 #include "clang/AST/Decl.h" 19 #include "clang/Basic/TargetOptions.h" 20 #include "llvm/IR/Metadata.h" 21 #include "llvm/IR/Module.h" 22 #include "llvm/Support/FormatVariadic.h" 23 24 using namespace clang; 25 using namespace CodeGen; 26 using namespace clang::hlsl; 27 using namespace llvm; 28 29 namespace { 30 31 void addDxilValVersion(StringRef ValVersionStr, llvm::Module &M) { 32 // The validation of ValVersionStr is done at HLSLToolChain::TranslateArgs. 33 // Assume ValVersionStr is legal here. 34 VersionTuple Version; 35 if (Version.tryParse(ValVersionStr) || Version.getBuild() || 36 Version.getSubminor() || !Version.getMinor()) { 37 return; 38 } 39 40 uint64_t Major = Version.getMajor(); 41 uint64_t Minor = *Version.getMinor(); 42 43 auto &Ctx = M.getContext(); 44 IRBuilder<> B(M.getContext()); 45 MDNode *Val = MDNode::get(Ctx, {ConstantAsMetadata::get(B.getInt32(Major)), 46 ConstantAsMetadata::get(B.getInt32(Minor))}); 47 StringRef DXILValKey = "dx.valver"; 48 auto *DXILValMD = M.getOrInsertNamedMetadata(DXILValKey); 49 DXILValMD->addOperand(Val); 50 } 51 void addDisableOptimizations(llvm::Module &M) { 52 StringRef Key = "dx.disable_optimizations"; 53 M.addModuleFlag(llvm::Module::ModFlagBehavior::Override, Key, 1); 54 } 55 // cbuffer will be translated into global variable in special address space. 56 // If translate into C, 57 // cbuffer A { 58 // float a; 59 // float b; 60 // } 61 // float foo() { return a + b; } 62 // 63 // will be translated into 64 // 65 // struct A { 66 // float a; 67 // float b; 68 // } cbuffer_A __attribute__((address_space(4))); 69 // float foo() { return cbuffer_A.a + cbuffer_A.b; } 70 // 71 // layoutBuffer will create the struct A type. 72 // replaceBuffer will replace use of global variable a and b with cbuffer_A.a 73 // and cbuffer_A.b. 74 // 75 void layoutBuffer(CGHLSLRuntime::Buffer &Buf, const DataLayout &DL) { 76 if (Buf.Constants.empty()) 77 return; 78 79 std::vector<llvm::Type *> EltTys; 80 for (auto &Const : Buf.Constants) { 81 GlobalVariable *GV = Const.first; 82 Const.second = EltTys.size(); 83 llvm::Type *Ty = GV->getValueType(); 84 EltTys.emplace_back(Ty); 85 } 86 Buf.LayoutStruct = llvm::StructType::get(EltTys[0]->getContext(), EltTys); 87 } 88 89 GlobalVariable *replaceBuffer(CGHLSLRuntime::Buffer &Buf) { 90 // Create global variable for CB. 91 GlobalVariable *CBGV = new GlobalVariable( 92 Buf.LayoutStruct, /*isConstant*/ true, 93 GlobalValue::LinkageTypes::ExternalLinkage, nullptr, 94 llvm::formatv("{0}{1}", Buf.Name, Buf.IsCBuffer ? ".cb." : ".tb."), 95 GlobalValue::NotThreadLocal); 96 97 IRBuilder<> B(CBGV->getContext()); 98 Value *ZeroIdx = B.getInt32(0); 99 // Replace Const use with CB use. 100 for (auto &[GV, Offset] : Buf.Constants) { 101 Value *GEP = 102 B.CreateGEP(Buf.LayoutStruct, CBGV, {ZeroIdx, B.getInt32(Offset)}); 103 104 assert(Buf.LayoutStruct->getElementType(Offset) == GV->getValueType() && 105 "constant type mismatch"); 106 107 // Replace. 108 GV->replaceAllUsesWith(GEP); 109 // Erase GV. 110 GV->removeDeadConstantUsers(); 111 GV->eraseFromParent(); 112 } 113 return CBGV; 114 } 115 116 } // namespace 117 118 llvm::Triple::ArchType CGHLSLRuntime::getArch() { 119 return CGM.getTarget().getTriple().getArch(); 120 } 121 122 void CGHLSLRuntime::addConstant(VarDecl *D, Buffer &CB) { 123 if (D->getStorageClass() == SC_Static) { 124 // For static inside cbuffer, take as global static. 125 // Don't add to cbuffer. 126 CGM.EmitGlobal(D); 127 return; 128 } 129 130 auto *GV = cast<GlobalVariable>(CGM.GetAddrOfGlobalVar(D)); 131 // Add debug info for constVal. 132 if (CGDebugInfo *DI = CGM.getModuleDebugInfo()) 133 if (CGM.getCodeGenOpts().getDebugInfo() >= 134 codegenoptions::DebugInfoKind::LimitedDebugInfo) 135 DI->EmitGlobalVariable(cast<GlobalVariable>(GV), D); 136 137 // FIXME: support packoffset. 138 // See https://github.com/llvm/llvm-project/issues/57914. 139 uint32_t Offset = 0; 140 bool HasUserOffset = false; 141 142 unsigned LowerBound = HasUserOffset ? Offset : UINT_MAX; 143 CB.Constants.emplace_back(std::make_pair(GV, LowerBound)); 144 } 145 146 void CGHLSLRuntime::addBufferDecls(const DeclContext *DC, Buffer &CB) { 147 for (Decl *it : DC->decls()) { 148 if (auto *ConstDecl = dyn_cast<VarDecl>(it)) { 149 addConstant(ConstDecl, CB); 150 } else if (isa<CXXRecordDecl, EmptyDecl>(it)) { 151 // Nothing to do for this declaration. 152 } else if (isa<FunctionDecl>(it)) { 153 // A function within an cbuffer is effectively a top-level function, 154 // as it only refers to globally scoped declarations. 155 CGM.EmitTopLevelDecl(it); 156 } 157 } 158 } 159 160 void CGHLSLRuntime::addBuffer(const HLSLBufferDecl *D) { 161 Buffers.emplace_back(Buffer(D)); 162 addBufferDecls(D, Buffers.back()); 163 } 164 165 void CGHLSLRuntime::finishCodeGen() { 166 auto &TargetOpts = CGM.getTarget().getTargetOpts(); 167 llvm::Module &M = CGM.getModule(); 168 Triple T(M.getTargetTriple()); 169 if (T.getArch() == Triple::ArchType::dxil) 170 addDxilValVersion(TargetOpts.DxilValidatorVersion, M); 171 172 generateGlobalCtorDtorCalls(); 173 if (CGM.getCodeGenOpts().OptimizationLevel == 0) 174 addDisableOptimizations(M); 175 176 const DataLayout &DL = M.getDataLayout(); 177 178 for (auto &Buf : Buffers) { 179 layoutBuffer(Buf, DL); 180 GlobalVariable *GV = replaceBuffer(Buf); 181 M.insertGlobalVariable(GV); 182 llvm::hlsl::ResourceClass RC = Buf.IsCBuffer 183 ? llvm::hlsl::ResourceClass::CBuffer 184 : llvm::hlsl::ResourceClass::SRV; 185 llvm::hlsl::ResourceKind RK = Buf.IsCBuffer 186 ? llvm::hlsl::ResourceKind::CBuffer 187 : llvm::hlsl::ResourceKind::TBuffer; 188 addBufferResourceAnnotation(GV, RC, RK, /*IsROV=*/false, 189 llvm::hlsl::ElementType::Invalid, Buf.Binding); 190 } 191 } 192 193 CGHLSLRuntime::Buffer::Buffer(const HLSLBufferDecl *D) 194 : Name(D->getName()), IsCBuffer(D->isCBuffer()), 195 Binding(D->getAttr<HLSLResourceBindingAttr>()) {} 196 197 void CGHLSLRuntime::addBufferResourceAnnotation(llvm::GlobalVariable *GV, 198 llvm::hlsl::ResourceClass RC, 199 llvm::hlsl::ResourceKind RK, 200 bool IsROV, 201 llvm::hlsl::ElementType ET, 202 BufferResBinding &Binding) { 203 llvm::Module &M = CGM.getModule(); 204 205 NamedMDNode *ResourceMD = nullptr; 206 switch (RC) { 207 case llvm::hlsl::ResourceClass::UAV: 208 ResourceMD = M.getOrInsertNamedMetadata("hlsl.uavs"); 209 break; 210 case llvm::hlsl::ResourceClass::SRV: 211 ResourceMD = M.getOrInsertNamedMetadata("hlsl.srvs"); 212 break; 213 case llvm::hlsl::ResourceClass::CBuffer: 214 ResourceMD = M.getOrInsertNamedMetadata("hlsl.cbufs"); 215 break; 216 default: 217 assert(false && "Unsupported buffer type!"); 218 return; 219 } 220 assert(ResourceMD != nullptr && 221 "ResourceMD must have been set by the switch above."); 222 223 llvm::hlsl::FrontendResource Res( 224 GV, RK, ET, IsROV, Binding.Reg.value_or(UINT_MAX), Binding.Space); 225 ResourceMD->addOperand(Res.getMetadata()); 226 } 227 228 static llvm::hlsl::ElementType 229 calculateElementType(const ASTContext &Context, const clang::Type *ResourceTy) { 230 using llvm::hlsl::ElementType; 231 232 // TODO: We may need to update this when we add things like ByteAddressBuffer 233 // that don't have a template parameter (or, indeed, an element type). 234 const auto *TST = ResourceTy->getAs<TemplateSpecializationType>(); 235 assert(TST && "Resource types must be template specializations"); 236 ArrayRef<TemplateArgument> Args = TST->template_arguments(); 237 assert(!Args.empty() && "Resource has no element type"); 238 239 // At this point we have a resource with an element type, so we can assume 240 // that it's valid or we would have diagnosed the error earlier. 241 QualType ElTy = Args[0].getAsType(); 242 243 // We should either have a basic type or a vector of a basic type. 244 if (const auto *VecTy = ElTy->getAs<clang::VectorType>()) 245 ElTy = VecTy->getElementType(); 246 247 if (ElTy->isSignedIntegerType()) { 248 switch (Context.getTypeSize(ElTy)) { 249 case 16: 250 return ElementType::I16; 251 case 32: 252 return ElementType::I32; 253 case 64: 254 return ElementType::I64; 255 } 256 } else if (ElTy->isUnsignedIntegerType()) { 257 switch (Context.getTypeSize(ElTy)) { 258 case 16: 259 return ElementType::U16; 260 case 32: 261 return ElementType::U32; 262 case 64: 263 return ElementType::U64; 264 } 265 } else if (ElTy->isSpecificBuiltinType(BuiltinType::Half)) 266 return ElementType::F16; 267 else if (ElTy->isSpecificBuiltinType(BuiltinType::Float)) 268 return ElementType::F32; 269 else if (ElTy->isSpecificBuiltinType(BuiltinType::Double)) 270 return ElementType::F64; 271 272 // TODO: We need to handle unorm/snorm float types here once we support them 273 llvm_unreachable("Invalid element type for resource"); 274 } 275 276 void CGHLSLRuntime::annotateHLSLResource(const VarDecl *D, GlobalVariable *GV) { 277 const Type *Ty = D->getType()->getPointeeOrArrayElementType(); 278 if (!Ty) 279 return; 280 const auto *RD = Ty->getAsCXXRecordDecl(); 281 if (!RD) 282 return; 283 const auto *HLSLResAttr = RD->getAttr<HLSLResourceAttr>(); 284 const auto *HLSLResClassAttr = RD->getAttr<HLSLResourceClassAttr>(); 285 if (!HLSLResAttr || !HLSLResClassAttr) 286 return; 287 288 llvm::hlsl::ResourceClass RC = HLSLResClassAttr->getResourceClass(); 289 llvm::hlsl::ResourceKind RK = HLSLResAttr->getResourceKind(); 290 bool IsROV = HLSLResAttr->getIsROV(); 291 llvm::hlsl::ElementType ET = calculateElementType(CGM.getContext(), Ty); 292 293 BufferResBinding Binding(D->getAttr<HLSLResourceBindingAttr>()); 294 addBufferResourceAnnotation(GV, RC, RK, IsROV, ET, Binding); 295 } 296 297 CGHLSLRuntime::BufferResBinding::BufferResBinding( 298 HLSLResourceBindingAttr *Binding) { 299 if (Binding) { 300 llvm::APInt RegInt(64, 0); 301 Binding->getSlot().substr(1).getAsInteger(10, RegInt); 302 Reg = RegInt.getLimitedValue(); 303 llvm::APInt SpaceInt(64, 0); 304 Binding->getSpace().substr(5).getAsInteger(10, SpaceInt); 305 Space = SpaceInt.getLimitedValue(); 306 } else { 307 Space = 0; 308 } 309 } 310 311 void clang::CodeGen::CGHLSLRuntime::setHLSLEntryAttributes( 312 const FunctionDecl *FD, llvm::Function *Fn) { 313 const auto *ShaderAttr = FD->getAttr<HLSLShaderAttr>(); 314 assert(ShaderAttr && "All entry functions must have a HLSLShaderAttr"); 315 const StringRef ShaderAttrKindStr = "hlsl.shader"; 316 Fn->addFnAttr(ShaderAttrKindStr, 317 llvm::Triple::getEnvironmentTypeName(ShaderAttr->getType())); 318 if (HLSLNumThreadsAttr *NumThreadsAttr = FD->getAttr<HLSLNumThreadsAttr>()) { 319 const StringRef NumThreadsKindStr = "hlsl.numthreads"; 320 std::string NumThreadsStr = 321 formatv("{0},{1},{2}", NumThreadsAttr->getX(), NumThreadsAttr->getY(), 322 NumThreadsAttr->getZ()); 323 Fn->addFnAttr(NumThreadsKindStr, NumThreadsStr); 324 } 325 } 326 327 static Value *buildVectorInput(IRBuilder<> &B, Function *F, llvm::Type *Ty) { 328 if (const auto *VT = dyn_cast<FixedVectorType>(Ty)) { 329 Value *Result = PoisonValue::get(Ty); 330 for (unsigned I = 0; I < VT->getNumElements(); ++I) { 331 Value *Elt = B.CreateCall(F, {B.getInt32(I)}); 332 Result = B.CreateInsertElement(Result, Elt, I); 333 } 334 return Result; 335 } 336 return B.CreateCall(F, {B.getInt32(0)}); 337 } 338 339 llvm::Value *CGHLSLRuntime::emitInputSemantic(IRBuilder<> &B, 340 const ParmVarDecl &D, 341 llvm::Type *Ty) { 342 assert(D.hasAttrs() && "Entry parameter missing annotation attribute!"); 343 if (D.hasAttr<HLSLSV_GroupIndexAttr>()) { 344 llvm::Function *DxGroupIndex = 345 CGM.getIntrinsic(Intrinsic::dx_flattened_thread_id_in_group); 346 return B.CreateCall(FunctionCallee(DxGroupIndex)); 347 } 348 if (D.hasAttr<HLSLSV_DispatchThreadIDAttr>()) { 349 llvm::Function *ThreadIDIntrinsic = 350 CGM.getIntrinsic(getThreadIdIntrinsic()); 351 return buildVectorInput(B, ThreadIDIntrinsic, Ty); 352 } 353 assert(false && "Unhandled parameter attribute"); 354 return nullptr; 355 } 356 357 void CGHLSLRuntime::emitEntryFunction(const FunctionDecl *FD, 358 llvm::Function *Fn) { 359 llvm::Module &M = CGM.getModule(); 360 llvm::LLVMContext &Ctx = M.getContext(); 361 auto *EntryTy = llvm::FunctionType::get(llvm::Type::getVoidTy(Ctx), false); 362 Function *EntryFn = 363 Function::Create(EntryTy, Function::ExternalLinkage, FD->getName(), &M); 364 365 // Copy function attributes over, we have no argument or return attributes 366 // that can be valid on the real entry. 367 AttributeList NewAttrs = AttributeList::get(Ctx, AttributeList::FunctionIndex, 368 Fn->getAttributes().getFnAttrs()); 369 EntryFn->setAttributes(NewAttrs); 370 setHLSLEntryAttributes(FD, EntryFn); 371 372 // Set the called function as internal linkage. 373 Fn->setLinkage(GlobalValue::InternalLinkage); 374 375 BasicBlock *BB = BasicBlock::Create(Ctx, "entry", EntryFn); 376 IRBuilder<> B(BB); 377 llvm::SmallVector<Value *> Args; 378 // FIXME: support struct parameters where semantics are on members. 379 // See: https://github.com/llvm/llvm-project/issues/57874 380 unsigned SRetOffset = 0; 381 for (const auto &Param : Fn->args()) { 382 if (Param.hasStructRetAttr()) { 383 // FIXME: support output. 384 // See: https://github.com/llvm/llvm-project/issues/57874 385 SRetOffset = 1; 386 Args.emplace_back(PoisonValue::get(Param.getType())); 387 continue; 388 } 389 const ParmVarDecl *PD = FD->getParamDecl(Param.getArgNo() - SRetOffset); 390 Args.push_back(emitInputSemantic(B, *PD, Param.getType())); 391 } 392 393 CallInst *CI = B.CreateCall(FunctionCallee(Fn), Args); 394 (void)CI; 395 // FIXME: Handle codegen for return type semantics. 396 // See: https://github.com/llvm/llvm-project/issues/57875 397 B.CreateRetVoid(); 398 } 399 400 static void gatherFunctions(SmallVectorImpl<Function *> &Fns, llvm::Module &M, 401 bool CtorOrDtor) { 402 const auto *GV = 403 M.getNamedGlobal(CtorOrDtor ? "llvm.global_ctors" : "llvm.global_dtors"); 404 if (!GV) 405 return; 406 const auto *CA = dyn_cast<ConstantArray>(GV->getInitializer()); 407 if (!CA) 408 return; 409 // The global_ctor array elements are a struct [Priority, Fn *, COMDat]. 410 // HLSL neither supports priorities or COMDat values, so we will check those 411 // in an assert but not handle them. 412 413 llvm::SmallVector<Function *> CtorFns; 414 for (const auto &Ctor : CA->operands()) { 415 if (isa<ConstantAggregateZero>(Ctor)) 416 continue; 417 ConstantStruct *CS = cast<ConstantStruct>(Ctor); 418 419 assert(cast<ConstantInt>(CS->getOperand(0))->getValue() == 65535 && 420 "HLSL doesn't support setting priority for global ctors."); 421 assert(isa<ConstantPointerNull>(CS->getOperand(2)) && 422 "HLSL doesn't support COMDat for global ctors."); 423 Fns.push_back(cast<Function>(CS->getOperand(1))); 424 } 425 } 426 427 void CGHLSLRuntime::generateGlobalCtorDtorCalls() { 428 llvm::Module &M = CGM.getModule(); 429 SmallVector<Function *> CtorFns; 430 SmallVector<Function *> DtorFns; 431 gatherFunctions(CtorFns, M, true); 432 gatherFunctions(DtorFns, M, false); 433 434 // Insert a call to the global constructor at the beginning of the entry block 435 // to externally exported functions. This is a bit of a hack, but HLSL allows 436 // global constructors, but doesn't support driver initialization of globals. 437 for (auto &F : M.functions()) { 438 if (!F.hasFnAttribute("hlsl.shader")) 439 continue; 440 IRBuilder<> B(&F.getEntryBlock(), F.getEntryBlock().begin()); 441 for (auto *Fn : CtorFns) 442 B.CreateCall(FunctionCallee(Fn)); 443 444 // Insert global dtors before the terminator of the last instruction 445 B.SetInsertPoint(F.back().getTerminator()); 446 for (auto *Fn : DtorFns) 447 B.CreateCall(FunctionCallee(Fn)); 448 } 449 450 // No need to keep global ctors/dtors for non-lib profile after call to 451 // ctors/dtors added for entry. 452 Triple T(M.getTargetTriple()); 453 if (T.getEnvironment() != Triple::EnvironmentType::Library) { 454 if (auto *GV = M.getNamedGlobal("llvm.global_ctors")) 455 GV->eraseFromParent(); 456 if (auto *GV = M.getNamedGlobal("llvm.global_dtors")) 457 GV->eraseFromParent(); 458 } 459 } 460