1 //===----- CGHLSLRuntime.cpp - Interface to HLSL Runtimes -----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This provides an abstract class for HLSL code generation. Concrete 10 // subclasses of this implement code generation for specific HLSL 11 // runtime libraries. 12 // 13 //===----------------------------------------------------------------------===// 14 15 #include "CGHLSLRuntime.h" 16 #include "CGDebugInfo.h" 17 #include "CodeGenModule.h" 18 #include "clang/AST/Decl.h" 19 #include "clang/Basic/TargetOptions.h" 20 #include "llvm/IR/IntrinsicsDirectX.h" 21 #include "llvm/IR/Metadata.h" 22 #include "llvm/IR/Module.h" 23 #include "llvm/Support/FormatVariadic.h" 24 25 using namespace clang; 26 using namespace CodeGen; 27 using namespace clang::hlsl; 28 using namespace llvm; 29 30 namespace { 31 32 void addDxilValVersion(StringRef ValVersionStr, llvm::Module &M) { 33 // The validation of ValVersionStr is done at HLSLToolChain::TranslateArgs. 34 // Assume ValVersionStr is legal here. 35 VersionTuple Version; 36 if (Version.tryParse(ValVersionStr) || Version.getBuild() || 37 Version.getSubminor() || !Version.getMinor()) { 38 return; 39 } 40 41 uint64_t Major = Version.getMajor(); 42 uint64_t Minor = *Version.getMinor(); 43 44 auto &Ctx = M.getContext(); 45 IRBuilder<> B(M.getContext()); 46 MDNode *Val = MDNode::get(Ctx, {ConstantAsMetadata::get(B.getInt32(Major)), 47 ConstantAsMetadata::get(B.getInt32(Minor))}); 48 StringRef DXILValKey = "dx.valver"; 49 auto *DXILValMD = M.getOrInsertNamedMetadata(DXILValKey); 50 DXILValMD->addOperand(Val); 51 } 52 void addDisableOptimizations(llvm::Module &M) { 53 StringRef Key = "dx.disable_optimizations"; 54 M.addModuleFlag(llvm::Module::ModFlagBehavior::Override, Key, 1); 55 } 56 // cbuffer will be translated into global variable in special address space. 57 // If translate into C, 58 // cbuffer A { 59 // float a; 60 // float b; 61 // } 62 // float foo() { return a + b; } 63 // 64 // will be translated into 65 // 66 // struct A { 67 // float a; 68 // float b; 69 // } cbuffer_A __attribute__((address_space(4))); 70 // float foo() { return cbuffer_A.a + cbuffer_A.b; } 71 // 72 // layoutBuffer will create the struct A type. 73 // replaceBuffer will replace use of global variable a and b with cbuffer_A.a 74 // and cbuffer_A.b. 75 // 76 void layoutBuffer(CGHLSLRuntime::Buffer &Buf, const DataLayout &DL) { 77 if (Buf.Constants.empty()) 78 return; 79 80 std::vector<llvm::Type *> EltTys; 81 for (auto &Const : Buf.Constants) { 82 GlobalVariable *GV = Const.first; 83 Const.second = EltTys.size(); 84 llvm::Type *Ty = GV->getValueType(); 85 EltTys.emplace_back(Ty); 86 } 87 Buf.LayoutStruct = llvm::StructType::get(EltTys[0]->getContext(), EltTys); 88 } 89 90 GlobalVariable *replaceBuffer(CGHLSLRuntime::Buffer &Buf) { 91 // Create global variable for CB. 92 GlobalVariable *CBGV = new GlobalVariable( 93 Buf.LayoutStruct, /*isConstant*/ true, 94 GlobalValue::LinkageTypes::ExternalLinkage, nullptr, 95 llvm::formatv("{0}{1}", Buf.Name, Buf.IsCBuffer ? ".cb." : ".tb."), 96 GlobalValue::NotThreadLocal); 97 98 IRBuilder<> B(CBGV->getContext()); 99 Value *ZeroIdx = B.getInt32(0); 100 // Replace Const use with CB use. 101 for (auto &[GV, Offset] : Buf.Constants) { 102 Value *GEP = 103 B.CreateGEP(Buf.LayoutStruct, CBGV, {ZeroIdx, B.getInt32(Offset)}); 104 105 assert(Buf.LayoutStruct->getElementType(Offset) == GV->getValueType() && 106 "constant type mismatch"); 107 108 // Replace. 109 GV->replaceAllUsesWith(GEP); 110 // Erase GV. 111 GV->removeDeadConstantUsers(); 112 GV->eraseFromParent(); 113 } 114 return CBGV; 115 } 116 117 } // namespace 118 119 void CGHLSLRuntime::addConstant(VarDecl *D, Buffer &CB) { 120 if (D->getStorageClass() == SC_Static) { 121 // For static inside cbuffer, take as global static. 122 // Don't add to cbuffer. 123 CGM.EmitGlobal(D); 124 return; 125 } 126 127 auto *GV = cast<GlobalVariable>(CGM.GetAddrOfGlobalVar(D)); 128 // Add debug info for constVal. 129 if (CGDebugInfo *DI = CGM.getModuleDebugInfo()) 130 if (CGM.getCodeGenOpts().getDebugInfo() >= 131 codegenoptions::DebugInfoKind::LimitedDebugInfo) 132 DI->EmitGlobalVariable(cast<GlobalVariable>(GV), D); 133 134 // FIXME: support packoffset. 135 // See https://github.com/llvm/llvm-project/issues/57914. 136 uint32_t Offset = 0; 137 bool HasUserOffset = false; 138 139 unsigned LowerBound = HasUserOffset ? Offset : UINT_MAX; 140 CB.Constants.emplace_back(std::make_pair(GV, LowerBound)); 141 } 142 143 void CGHLSLRuntime::addBufferDecls(const DeclContext *DC, Buffer &CB) { 144 for (Decl *it : DC->decls()) { 145 if (auto *ConstDecl = dyn_cast<VarDecl>(it)) { 146 addConstant(ConstDecl, CB); 147 } else if (isa<CXXRecordDecl, EmptyDecl>(it)) { 148 // Nothing to do for this declaration. 149 } else if (isa<FunctionDecl>(it)) { 150 // A function within an cbuffer is effectively a top-level function, 151 // as it only refers to globally scoped declarations. 152 CGM.EmitTopLevelDecl(it); 153 } 154 } 155 } 156 157 void CGHLSLRuntime::addBuffer(const HLSLBufferDecl *D) { 158 Buffers.emplace_back(Buffer(D)); 159 addBufferDecls(D, Buffers.back()); 160 } 161 162 void CGHLSLRuntime::finishCodeGen() { 163 auto &TargetOpts = CGM.getTarget().getTargetOpts(); 164 llvm::Module &M = CGM.getModule(); 165 Triple T(M.getTargetTriple()); 166 if (T.getArch() == Triple::ArchType::dxil) 167 addDxilValVersion(TargetOpts.DxilValidatorVersion, M); 168 169 generateGlobalCtorDtorCalls(); 170 if (CGM.getCodeGenOpts().OptimizationLevel == 0) 171 addDisableOptimizations(M); 172 173 const DataLayout &DL = M.getDataLayout(); 174 175 for (auto &Buf : Buffers) { 176 layoutBuffer(Buf, DL); 177 GlobalVariable *GV = replaceBuffer(Buf); 178 M.insertGlobalVariable(GV); 179 llvm::hlsl::ResourceClass RC = Buf.IsCBuffer 180 ? llvm::hlsl::ResourceClass::CBuffer 181 : llvm::hlsl::ResourceClass::SRV; 182 llvm::hlsl::ResourceKind RK = Buf.IsCBuffer 183 ? llvm::hlsl::ResourceKind::CBuffer 184 : llvm::hlsl::ResourceKind::TBuffer; 185 addBufferResourceAnnotation(GV, RC, RK, /*IsROV=*/false, 186 llvm::hlsl::ElementType::Invalid, Buf.Binding); 187 } 188 } 189 190 CGHLSLRuntime::Buffer::Buffer(const HLSLBufferDecl *D) 191 : Name(D->getName()), IsCBuffer(D->isCBuffer()), 192 Binding(D->getAttr<HLSLResourceBindingAttr>()) {} 193 194 void CGHLSLRuntime::addBufferResourceAnnotation(llvm::GlobalVariable *GV, 195 llvm::hlsl::ResourceClass RC, 196 llvm::hlsl::ResourceKind RK, 197 bool IsROV, 198 llvm::hlsl::ElementType ET, 199 BufferResBinding &Binding) { 200 llvm::Module &M = CGM.getModule(); 201 202 NamedMDNode *ResourceMD = nullptr; 203 switch (RC) { 204 case llvm::hlsl::ResourceClass::UAV: 205 ResourceMD = M.getOrInsertNamedMetadata("hlsl.uavs"); 206 break; 207 case llvm::hlsl::ResourceClass::SRV: 208 ResourceMD = M.getOrInsertNamedMetadata("hlsl.srvs"); 209 break; 210 case llvm::hlsl::ResourceClass::CBuffer: 211 ResourceMD = M.getOrInsertNamedMetadata("hlsl.cbufs"); 212 break; 213 default: 214 assert(false && "Unsupported buffer type!"); 215 return; 216 } 217 assert(ResourceMD != nullptr && 218 "ResourceMD must have been set by the switch above."); 219 220 llvm::hlsl::FrontendResource Res( 221 GV, RK, ET, IsROV, Binding.Reg.value_or(UINT_MAX), Binding.Space); 222 ResourceMD->addOperand(Res.getMetadata()); 223 } 224 225 static llvm::hlsl::ElementType 226 calculateElementType(const ASTContext &Context, const clang::Type *ResourceTy) { 227 using llvm::hlsl::ElementType; 228 229 // TODO: We may need to update this when we add things like ByteAddressBuffer 230 // that don't have a template parameter (or, indeed, an element type). 231 const auto *TST = ResourceTy->getAs<TemplateSpecializationType>(); 232 assert(TST && "Resource types must be template specializations"); 233 ArrayRef<TemplateArgument> Args = TST->template_arguments(); 234 assert(!Args.empty() && "Resource has no element type"); 235 236 // At this point we have a resource with an element type, so we can assume 237 // that it's valid or we would have diagnosed the error earlier. 238 QualType ElTy = Args[0].getAsType(); 239 240 // We should either have a basic type or a vector of a basic type. 241 if (const auto *VecTy = ElTy->getAs<clang::VectorType>()) 242 ElTy = VecTy->getElementType(); 243 244 if (ElTy->isSignedIntegerType()) { 245 switch (Context.getTypeSize(ElTy)) { 246 case 16: 247 return ElementType::I16; 248 case 32: 249 return ElementType::I32; 250 case 64: 251 return ElementType::I64; 252 } 253 } else if (ElTy->isUnsignedIntegerType()) { 254 switch (Context.getTypeSize(ElTy)) { 255 case 16: 256 return ElementType::U16; 257 case 32: 258 return ElementType::U32; 259 case 64: 260 return ElementType::U64; 261 } 262 } else if (ElTy->isSpecificBuiltinType(BuiltinType::Half)) 263 return ElementType::F16; 264 else if (ElTy->isSpecificBuiltinType(BuiltinType::Float)) 265 return ElementType::F32; 266 else if (ElTy->isSpecificBuiltinType(BuiltinType::Double)) 267 return ElementType::F64; 268 269 // TODO: We need to handle unorm/snorm float types here once we support them 270 llvm_unreachable("Invalid element type for resource"); 271 } 272 273 void CGHLSLRuntime::annotateHLSLResource(const VarDecl *D, GlobalVariable *GV) { 274 const Type *Ty = D->getType()->getPointeeOrArrayElementType(); 275 if (!Ty) 276 return; 277 const auto *RD = Ty->getAsCXXRecordDecl(); 278 if (!RD) 279 return; 280 const auto *Attr = RD->getAttr<HLSLResourceAttr>(); 281 if (!Attr) 282 return; 283 284 llvm::hlsl::ResourceClass RC = Attr->getResourceClass(); 285 llvm::hlsl::ResourceKind RK = Attr->getResourceKind(); 286 bool IsROV = Attr->getIsROV(); 287 llvm::hlsl::ElementType ET = calculateElementType(CGM.getContext(), Ty); 288 289 BufferResBinding Binding(D->getAttr<HLSLResourceBindingAttr>()); 290 addBufferResourceAnnotation(GV, RC, RK, IsROV, ET, Binding); 291 } 292 293 CGHLSLRuntime::BufferResBinding::BufferResBinding( 294 HLSLResourceBindingAttr *Binding) { 295 if (Binding) { 296 llvm::APInt RegInt(64, 0); 297 Binding->getSlot().substr(1).getAsInteger(10, RegInt); 298 Reg = RegInt.getLimitedValue(); 299 llvm::APInt SpaceInt(64, 0); 300 Binding->getSpace().substr(5).getAsInteger(10, SpaceInt); 301 Space = SpaceInt.getLimitedValue(); 302 } else { 303 Space = 0; 304 } 305 } 306 307 void clang::CodeGen::CGHLSLRuntime::setHLSLEntryAttributes( 308 const FunctionDecl *FD, llvm::Function *Fn) { 309 const auto *ShaderAttr = FD->getAttr<HLSLShaderAttr>(); 310 assert(ShaderAttr && "All entry functions must have a HLSLShaderAttr"); 311 const StringRef ShaderAttrKindStr = "hlsl.shader"; 312 Fn->addFnAttr(ShaderAttrKindStr, 313 ShaderAttr->ConvertShaderTypeToStr(ShaderAttr->getType())); 314 if (HLSLNumThreadsAttr *NumThreadsAttr = FD->getAttr<HLSLNumThreadsAttr>()) { 315 const StringRef NumThreadsKindStr = "hlsl.numthreads"; 316 std::string NumThreadsStr = 317 formatv("{0},{1},{2}", NumThreadsAttr->getX(), NumThreadsAttr->getY(), 318 NumThreadsAttr->getZ()); 319 Fn->addFnAttr(NumThreadsKindStr, NumThreadsStr); 320 } 321 } 322 323 static Value *buildVectorInput(IRBuilder<> &B, Function *F, llvm::Type *Ty) { 324 if (const auto *VT = dyn_cast<FixedVectorType>(Ty)) { 325 Value *Result = PoisonValue::get(Ty); 326 for (unsigned I = 0; I < VT->getNumElements(); ++I) { 327 Value *Elt = B.CreateCall(F, {B.getInt32(I)}); 328 Result = B.CreateInsertElement(Result, Elt, I); 329 } 330 return Result; 331 } 332 return B.CreateCall(F, {B.getInt32(0)}); 333 } 334 335 llvm::Value *CGHLSLRuntime::emitInputSemantic(IRBuilder<> &B, 336 const ParmVarDecl &D, 337 llvm::Type *Ty) { 338 assert(D.hasAttrs() && "Entry parameter missing annotation attribute!"); 339 if (D.hasAttr<HLSLSV_GroupIndexAttr>()) { 340 llvm::Function *DxGroupIndex = 341 CGM.getIntrinsic(Intrinsic::dx_flattened_thread_id_in_group); 342 return B.CreateCall(FunctionCallee(DxGroupIndex)); 343 } 344 if (D.hasAttr<HLSLSV_DispatchThreadIDAttr>()) { 345 llvm::Function *DxThreadID = CGM.getIntrinsic(Intrinsic::dx_thread_id); 346 return buildVectorInput(B, DxThreadID, Ty); 347 } 348 assert(false && "Unhandled parameter attribute"); 349 return nullptr; 350 } 351 352 void CGHLSLRuntime::emitEntryFunction(const FunctionDecl *FD, 353 llvm::Function *Fn) { 354 llvm::Module &M = CGM.getModule(); 355 llvm::LLVMContext &Ctx = M.getContext(); 356 auto *EntryTy = llvm::FunctionType::get(llvm::Type::getVoidTy(Ctx), false); 357 Function *EntryFn = 358 Function::Create(EntryTy, Function::ExternalLinkage, FD->getName(), &M); 359 360 // Copy function attributes over, we have no argument or return attributes 361 // that can be valid on the real entry. 362 AttributeList NewAttrs = AttributeList::get(Ctx, AttributeList::FunctionIndex, 363 Fn->getAttributes().getFnAttrs()); 364 EntryFn->setAttributes(NewAttrs); 365 setHLSLEntryAttributes(FD, EntryFn); 366 367 // Set the called function as internal linkage. 368 Fn->setLinkage(GlobalValue::InternalLinkage); 369 370 BasicBlock *BB = BasicBlock::Create(Ctx, "entry", EntryFn); 371 IRBuilder<> B(BB); 372 llvm::SmallVector<Value *> Args; 373 // FIXME: support struct parameters where semantics are on members. 374 // See: https://github.com/llvm/llvm-project/issues/57874 375 unsigned SRetOffset = 0; 376 for (const auto &Param : Fn->args()) { 377 if (Param.hasStructRetAttr()) { 378 // FIXME: support output. 379 // See: https://github.com/llvm/llvm-project/issues/57874 380 SRetOffset = 1; 381 Args.emplace_back(PoisonValue::get(Param.getType())); 382 continue; 383 } 384 const ParmVarDecl *PD = FD->getParamDecl(Param.getArgNo() - SRetOffset); 385 Args.push_back(emitInputSemantic(B, *PD, Param.getType())); 386 } 387 388 CallInst *CI = B.CreateCall(FunctionCallee(Fn), Args); 389 (void)CI; 390 // FIXME: Handle codegen for return type semantics. 391 // See: https://github.com/llvm/llvm-project/issues/57875 392 B.CreateRetVoid(); 393 } 394 395 static void gatherFunctions(SmallVectorImpl<Function *> &Fns, llvm::Module &M, 396 bool CtorOrDtor) { 397 const auto *GV = 398 M.getNamedGlobal(CtorOrDtor ? "llvm.global_ctors" : "llvm.global_dtors"); 399 if (!GV) 400 return; 401 const auto *CA = dyn_cast<ConstantArray>(GV->getInitializer()); 402 if (!CA) 403 return; 404 // The global_ctor array elements are a struct [Priority, Fn *, COMDat]. 405 // HLSL neither supports priorities or COMDat values, so we will check those 406 // in an assert but not handle them. 407 408 llvm::SmallVector<Function *> CtorFns; 409 for (const auto &Ctor : CA->operands()) { 410 if (isa<ConstantAggregateZero>(Ctor)) 411 continue; 412 ConstantStruct *CS = cast<ConstantStruct>(Ctor); 413 414 assert(cast<ConstantInt>(CS->getOperand(0))->getValue() == 65535 && 415 "HLSL doesn't support setting priority for global ctors."); 416 assert(isa<ConstantPointerNull>(CS->getOperand(2)) && 417 "HLSL doesn't support COMDat for global ctors."); 418 Fns.push_back(cast<Function>(CS->getOperand(1))); 419 } 420 } 421 422 void CGHLSLRuntime::generateGlobalCtorDtorCalls() { 423 llvm::Module &M = CGM.getModule(); 424 SmallVector<Function *> CtorFns; 425 SmallVector<Function *> DtorFns; 426 gatherFunctions(CtorFns, M, true); 427 gatherFunctions(DtorFns, M, false); 428 429 // Insert a call to the global constructor at the beginning of the entry block 430 // to externally exported functions. This is a bit of a hack, but HLSL allows 431 // global constructors, but doesn't support driver initialization of globals. 432 for (auto &F : M.functions()) { 433 if (!F.hasFnAttribute("hlsl.shader")) 434 continue; 435 IRBuilder<> B(&F.getEntryBlock(), F.getEntryBlock().begin()); 436 for (auto *Fn : CtorFns) 437 B.CreateCall(FunctionCallee(Fn)); 438 439 // Insert global dtors before the terminator of the last instruction 440 B.SetInsertPoint(F.back().getTerminator()); 441 for (auto *Fn : DtorFns) 442 B.CreateCall(FunctionCallee(Fn)); 443 } 444 445 // No need to keep global ctors/dtors for non-lib profile after call to 446 // ctors/dtors added for entry. 447 Triple T(M.getTargetTriple()); 448 if (T.getEnvironment() != Triple::EnvironmentType::Library) { 449 if (auto *GV = M.getNamedGlobal("llvm.global_ctors")) 450 GV->eraseFromParent(); 451 if (auto *GV = M.getNamedGlobal("llvm.global_dtors")) 452 GV->eraseFromParent(); 453 } 454 } 455