181ad6265SDimitry Andric //===----- CGHLSLRuntime.cpp - Interface to HLSL Runtimes -----------------===// 281ad6265SDimitry Andric // 381ad6265SDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 481ad6265SDimitry Andric // See https://llvm.org/LICENSE.txt for license information. 581ad6265SDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 681ad6265SDimitry Andric // 781ad6265SDimitry Andric //===----------------------------------------------------------------------===// 881ad6265SDimitry Andric // 981ad6265SDimitry Andric // This provides an abstract class for HLSL code generation. Concrete 1081ad6265SDimitry Andric // subclasses of this implement code generation for specific HLSL 1181ad6265SDimitry Andric // runtime libraries. 1281ad6265SDimitry Andric // 1381ad6265SDimitry Andric //===----------------------------------------------------------------------===// 1481ad6265SDimitry Andric 1581ad6265SDimitry Andric #include "CGHLSLRuntime.h" 16*bdd1243dSDimitry Andric #include "CGDebugInfo.h" 1781ad6265SDimitry Andric #include "CodeGenModule.h" 18*bdd1243dSDimitry Andric #include "clang/AST/Decl.h" 1981ad6265SDimitry Andric #include "clang/Basic/TargetOptions.h" 20*bdd1243dSDimitry Andric #include "llvm/IR/IntrinsicsDirectX.h" 2181ad6265SDimitry Andric #include "llvm/IR/Metadata.h" 2281ad6265SDimitry Andric #include "llvm/IR/Module.h" 23*bdd1243dSDimitry Andric #include "llvm/Support/FormatVariadic.h" 2481ad6265SDimitry Andric 2581ad6265SDimitry Andric using namespace clang; 2681ad6265SDimitry Andric using namespace CodeGen; 27*bdd1243dSDimitry Andric using namespace clang::hlsl; 2881ad6265SDimitry Andric using namespace llvm; 2981ad6265SDimitry Andric 3081ad6265SDimitry Andric namespace { 31*bdd1243dSDimitry Andric 3281ad6265SDimitry Andric void addDxilValVersion(StringRef ValVersionStr, llvm::Module &M) { 3381ad6265SDimitry Andric // The validation of ValVersionStr is done at HLSLToolChain::TranslateArgs. 3481ad6265SDimitry Andric // Assume ValVersionStr is legal here. 3581ad6265SDimitry Andric VersionTuple Version; 3681ad6265SDimitry Andric if (Version.tryParse(ValVersionStr) || Version.getBuild() || 3781ad6265SDimitry Andric Version.getSubminor() || !Version.getMinor()) { 3881ad6265SDimitry Andric return; 3981ad6265SDimitry Andric } 4081ad6265SDimitry Andric 4181ad6265SDimitry Andric uint64_t Major = Version.getMajor(); 4281ad6265SDimitry Andric uint64_t Minor = *Version.getMinor(); 4381ad6265SDimitry Andric 4481ad6265SDimitry Andric auto &Ctx = M.getContext(); 4581ad6265SDimitry Andric IRBuilder<> B(M.getContext()); 4681ad6265SDimitry Andric MDNode *Val = MDNode::get(Ctx, {ConstantAsMetadata::get(B.getInt32(Major)), 4781ad6265SDimitry Andric ConstantAsMetadata::get(B.getInt32(Minor))}); 48*bdd1243dSDimitry Andric StringRef DXILValKey = "dx.valver"; 49*bdd1243dSDimitry Andric auto *DXILValMD = M.getOrInsertNamedMetadata(DXILValKey); 50*bdd1243dSDimitry Andric DXILValMD->addOperand(Val); 5181ad6265SDimitry Andric } 52*bdd1243dSDimitry Andric void addDisableOptimizations(llvm::Module &M) { 53*bdd1243dSDimitry Andric StringRef Key = "dx.disable_optimizations"; 54*bdd1243dSDimitry Andric M.addModuleFlag(llvm::Module::ModFlagBehavior::Override, Key, 1); 55*bdd1243dSDimitry Andric } 56*bdd1243dSDimitry Andric // cbuffer will be translated into global variable in special address space. 57*bdd1243dSDimitry Andric // If translate into C, 58*bdd1243dSDimitry Andric // cbuffer A { 59*bdd1243dSDimitry Andric // float a; 60*bdd1243dSDimitry Andric // float b; 61*bdd1243dSDimitry Andric // } 62*bdd1243dSDimitry Andric // float foo() { return a + b; } 63*bdd1243dSDimitry Andric // 64*bdd1243dSDimitry Andric // will be translated into 65*bdd1243dSDimitry Andric // 66*bdd1243dSDimitry Andric // struct A { 67*bdd1243dSDimitry Andric // float a; 68*bdd1243dSDimitry Andric // float b; 69*bdd1243dSDimitry Andric // } cbuffer_A __attribute__((address_space(4))); 70*bdd1243dSDimitry Andric // float foo() { return cbuffer_A.a + cbuffer_A.b; } 71*bdd1243dSDimitry Andric // 72*bdd1243dSDimitry Andric // layoutBuffer will create the struct A type. 73*bdd1243dSDimitry Andric // replaceBuffer will replace use of global variable a and b with cbuffer_A.a 74*bdd1243dSDimitry Andric // and cbuffer_A.b. 75*bdd1243dSDimitry Andric // 76*bdd1243dSDimitry Andric void layoutBuffer(CGHLSLRuntime::Buffer &Buf, const DataLayout &DL) { 77*bdd1243dSDimitry Andric if (Buf.Constants.empty()) 78*bdd1243dSDimitry Andric return; 79*bdd1243dSDimitry Andric 80*bdd1243dSDimitry Andric std::vector<llvm::Type *> EltTys; 81*bdd1243dSDimitry Andric for (auto &Const : Buf.Constants) { 82*bdd1243dSDimitry Andric GlobalVariable *GV = Const.first; 83*bdd1243dSDimitry Andric Const.second = EltTys.size(); 84*bdd1243dSDimitry Andric llvm::Type *Ty = GV->getValueType(); 85*bdd1243dSDimitry Andric EltTys.emplace_back(Ty); 86*bdd1243dSDimitry Andric } 87*bdd1243dSDimitry Andric Buf.LayoutStruct = llvm::StructType::get(EltTys[0]->getContext(), EltTys); 88*bdd1243dSDimitry Andric } 89*bdd1243dSDimitry Andric 90*bdd1243dSDimitry Andric GlobalVariable *replaceBuffer(CGHLSLRuntime::Buffer &Buf) { 91*bdd1243dSDimitry Andric // Create global variable for CB. 92*bdd1243dSDimitry Andric GlobalVariable *CBGV = new GlobalVariable( 93*bdd1243dSDimitry Andric Buf.LayoutStruct, /*isConstant*/ true, 94*bdd1243dSDimitry Andric GlobalValue::LinkageTypes::ExternalLinkage, nullptr, 95*bdd1243dSDimitry Andric llvm::formatv("{0}{1}", Buf.Name, Buf.IsCBuffer ? ".cb." : ".tb."), 96*bdd1243dSDimitry Andric GlobalValue::NotThreadLocal); 97*bdd1243dSDimitry Andric 98*bdd1243dSDimitry Andric IRBuilder<> B(CBGV->getContext()); 99*bdd1243dSDimitry Andric Value *ZeroIdx = B.getInt32(0); 100*bdd1243dSDimitry Andric // Replace Const use with CB use. 101*bdd1243dSDimitry Andric for (auto &[GV, Offset] : Buf.Constants) { 102*bdd1243dSDimitry Andric Value *GEP = 103*bdd1243dSDimitry Andric B.CreateGEP(Buf.LayoutStruct, CBGV, {ZeroIdx, B.getInt32(Offset)}); 104*bdd1243dSDimitry Andric 105*bdd1243dSDimitry Andric assert(Buf.LayoutStruct->getElementType(Offset) == GV->getValueType() && 106*bdd1243dSDimitry Andric "constant type mismatch"); 107*bdd1243dSDimitry Andric 108*bdd1243dSDimitry Andric // Replace. 109*bdd1243dSDimitry Andric GV->replaceAllUsesWith(GEP); 110*bdd1243dSDimitry Andric // Erase GV. 111*bdd1243dSDimitry Andric GV->removeDeadConstantUsers(); 112*bdd1243dSDimitry Andric GV->eraseFromParent(); 113*bdd1243dSDimitry Andric } 114*bdd1243dSDimitry Andric return CBGV; 115*bdd1243dSDimitry Andric } 116*bdd1243dSDimitry Andric 11781ad6265SDimitry Andric } // namespace 11881ad6265SDimitry Andric 119*bdd1243dSDimitry Andric void CGHLSLRuntime::addConstant(VarDecl *D, Buffer &CB) { 120*bdd1243dSDimitry Andric if (D->getStorageClass() == SC_Static) { 121*bdd1243dSDimitry Andric // For static inside cbuffer, take as global static. 122*bdd1243dSDimitry Andric // Don't add to cbuffer. 123*bdd1243dSDimitry Andric CGM.EmitGlobal(D); 124*bdd1243dSDimitry Andric return; 125*bdd1243dSDimitry Andric } 126*bdd1243dSDimitry Andric 127*bdd1243dSDimitry Andric auto *GV = cast<GlobalVariable>(CGM.GetAddrOfGlobalVar(D)); 128*bdd1243dSDimitry Andric // Add debug info for constVal. 129*bdd1243dSDimitry Andric if (CGDebugInfo *DI = CGM.getModuleDebugInfo()) 130*bdd1243dSDimitry Andric if (CGM.getCodeGenOpts().getDebugInfo() >= 131*bdd1243dSDimitry Andric codegenoptions::DebugInfoKind::LimitedDebugInfo) 132*bdd1243dSDimitry Andric DI->EmitGlobalVariable(cast<GlobalVariable>(GV), D); 133*bdd1243dSDimitry Andric 134*bdd1243dSDimitry Andric // FIXME: support packoffset. 135*bdd1243dSDimitry Andric // See https://github.com/llvm/llvm-project/issues/57914. 136*bdd1243dSDimitry Andric uint32_t Offset = 0; 137*bdd1243dSDimitry Andric bool HasUserOffset = false; 138*bdd1243dSDimitry Andric 139*bdd1243dSDimitry Andric unsigned LowerBound = HasUserOffset ? Offset : UINT_MAX; 140*bdd1243dSDimitry Andric CB.Constants.emplace_back(std::make_pair(GV, LowerBound)); 141*bdd1243dSDimitry Andric } 142*bdd1243dSDimitry Andric 143*bdd1243dSDimitry Andric void CGHLSLRuntime::addBufferDecls(const DeclContext *DC, Buffer &CB) { 144*bdd1243dSDimitry Andric for (Decl *it : DC->decls()) { 145*bdd1243dSDimitry Andric if (auto *ConstDecl = dyn_cast<VarDecl>(it)) { 146*bdd1243dSDimitry Andric addConstant(ConstDecl, CB); 147*bdd1243dSDimitry Andric } else if (isa<CXXRecordDecl, EmptyDecl>(it)) { 148*bdd1243dSDimitry Andric // Nothing to do for this declaration. 149*bdd1243dSDimitry Andric } else if (isa<FunctionDecl>(it)) { 150*bdd1243dSDimitry Andric // A function within an cbuffer is effectively a top-level function, 151*bdd1243dSDimitry Andric // as it only refers to globally scoped declarations. 152*bdd1243dSDimitry Andric CGM.EmitTopLevelDecl(it); 153*bdd1243dSDimitry Andric } 154*bdd1243dSDimitry Andric } 155*bdd1243dSDimitry Andric } 156*bdd1243dSDimitry Andric 157*bdd1243dSDimitry Andric void CGHLSLRuntime::addBuffer(const HLSLBufferDecl *D) { 158*bdd1243dSDimitry Andric Buffers.emplace_back(Buffer(D)); 159*bdd1243dSDimitry Andric addBufferDecls(D, Buffers.back()); 160*bdd1243dSDimitry Andric } 161*bdd1243dSDimitry Andric 16281ad6265SDimitry Andric void CGHLSLRuntime::finishCodeGen() { 16381ad6265SDimitry Andric auto &TargetOpts = CGM.getTarget().getTargetOpts(); 16481ad6265SDimitry Andric llvm::Module &M = CGM.getModule(); 165*bdd1243dSDimitry Andric Triple T(M.getTargetTriple()); 166*bdd1243dSDimitry Andric if (T.getArch() == Triple::ArchType::dxil) 16781ad6265SDimitry Andric addDxilValVersion(TargetOpts.DxilValidatorVersion, M); 168*bdd1243dSDimitry Andric 169*bdd1243dSDimitry Andric generateGlobalCtorDtorCalls(); 170*bdd1243dSDimitry Andric if (CGM.getCodeGenOpts().OptimizationLevel == 0) 171*bdd1243dSDimitry Andric addDisableOptimizations(M); 172*bdd1243dSDimitry Andric 173*bdd1243dSDimitry Andric const DataLayout &DL = M.getDataLayout(); 174*bdd1243dSDimitry Andric 175*bdd1243dSDimitry Andric for (auto &Buf : Buffers) { 176*bdd1243dSDimitry Andric layoutBuffer(Buf, DL); 177*bdd1243dSDimitry Andric GlobalVariable *GV = replaceBuffer(Buf); 178*bdd1243dSDimitry Andric M.getGlobalList().push_back(GV); 179*bdd1243dSDimitry Andric llvm::hlsl::ResourceClass RC = Buf.IsCBuffer 180*bdd1243dSDimitry Andric ? llvm::hlsl::ResourceClass::CBuffer 181*bdd1243dSDimitry Andric : llvm::hlsl::ResourceClass::SRV; 182*bdd1243dSDimitry Andric llvm::hlsl::ResourceKind RK = Buf.IsCBuffer 183*bdd1243dSDimitry Andric ? llvm::hlsl::ResourceKind::CBuffer 184*bdd1243dSDimitry Andric : llvm::hlsl::ResourceKind::TBuffer; 185*bdd1243dSDimitry Andric std::string TyName = 186*bdd1243dSDimitry Andric Buf.Name.str() + (Buf.IsCBuffer ? ".cb." : ".tb.") + "ty"; 187*bdd1243dSDimitry Andric addBufferResourceAnnotation(GV, TyName, RC, RK, Buf.Binding); 188*bdd1243dSDimitry Andric } 189*bdd1243dSDimitry Andric } 190*bdd1243dSDimitry Andric 191*bdd1243dSDimitry Andric CGHLSLRuntime::Buffer::Buffer(const HLSLBufferDecl *D) 192*bdd1243dSDimitry Andric : Name(D->getName()), IsCBuffer(D->isCBuffer()), 193*bdd1243dSDimitry Andric Binding(D->getAttr<HLSLResourceBindingAttr>()) {} 194*bdd1243dSDimitry Andric 195*bdd1243dSDimitry Andric void CGHLSLRuntime::addBufferResourceAnnotation(llvm::GlobalVariable *GV, 196*bdd1243dSDimitry Andric llvm::StringRef TyName, 197*bdd1243dSDimitry Andric llvm::hlsl::ResourceClass RC, 198*bdd1243dSDimitry Andric llvm::hlsl::ResourceKind RK, 199*bdd1243dSDimitry Andric BufferResBinding &Binding) { 200*bdd1243dSDimitry Andric llvm::Module &M = CGM.getModule(); 201*bdd1243dSDimitry Andric 202*bdd1243dSDimitry Andric NamedMDNode *ResourceMD = nullptr; 203*bdd1243dSDimitry Andric switch (RC) { 204*bdd1243dSDimitry Andric case llvm::hlsl::ResourceClass::UAV: 205*bdd1243dSDimitry Andric ResourceMD = M.getOrInsertNamedMetadata("hlsl.uavs"); 206*bdd1243dSDimitry Andric break; 207*bdd1243dSDimitry Andric case llvm::hlsl::ResourceClass::SRV: 208*bdd1243dSDimitry Andric ResourceMD = M.getOrInsertNamedMetadata("hlsl.srvs"); 209*bdd1243dSDimitry Andric break; 210*bdd1243dSDimitry Andric case llvm::hlsl::ResourceClass::CBuffer: 211*bdd1243dSDimitry Andric ResourceMD = M.getOrInsertNamedMetadata("hlsl.cbufs"); 212*bdd1243dSDimitry Andric break; 213*bdd1243dSDimitry Andric default: 214*bdd1243dSDimitry Andric assert(false && "Unsupported buffer type!"); 215*bdd1243dSDimitry Andric return; 216*bdd1243dSDimitry Andric } 217*bdd1243dSDimitry Andric 218*bdd1243dSDimitry Andric assert(ResourceMD != nullptr && 219*bdd1243dSDimitry Andric "ResourceMD must have been set by the switch above."); 220*bdd1243dSDimitry Andric 221*bdd1243dSDimitry Andric llvm::hlsl::FrontendResource Res( 222*bdd1243dSDimitry Andric GV, TyName, RK, Binding.Reg.value_or(UINT_MAX), Binding.Space); 223*bdd1243dSDimitry Andric ResourceMD->addOperand(Res.getMetadata()); 224*bdd1243dSDimitry Andric } 225*bdd1243dSDimitry Andric 226*bdd1243dSDimitry Andric static llvm::hlsl::ResourceKind 227*bdd1243dSDimitry Andric castResourceShapeToResourceKind(HLSLResourceAttr::ResourceKind RK) { 228*bdd1243dSDimitry Andric switch (RK) { 229*bdd1243dSDimitry Andric case HLSLResourceAttr::ResourceKind::Texture1D: 230*bdd1243dSDimitry Andric return llvm::hlsl::ResourceKind::Texture1D; 231*bdd1243dSDimitry Andric case HLSLResourceAttr::ResourceKind::Texture2D: 232*bdd1243dSDimitry Andric return llvm::hlsl::ResourceKind::Texture2D; 233*bdd1243dSDimitry Andric case HLSLResourceAttr::ResourceKind::Texture2DMS: 234*bdd1243dSDimitry Andric return llvm::hlsl::ResourceKind::Texture2DMS; 235*bdd1243dSDimitry Andric case HLSLResourceAttr::ResourceKind::Texture3D: 236*bdd1243dSDimitry Andric return llvm::hlsl::ResourceKind::Texture3D; 237*bdd1243dSDimitry Andric case HLSLResourceAttr::ResourceKind::TextureCube: 238*bdd1243dSDimitry Andric return llvm::hlsl::ResourceKind::TextureCube; 239*bdd1243dSDimitry Andric case HLSLResourceAttr::ResourceKind::Texture1DArray: 240*bdd1243dSDimitry Andric return llvm::hlsl::ResourceKind::Texture1DArray; 241*bdd1243dSDimitry Andric case HLSLResourceAttr::ResourceKind::Texture2DArray: 242*bdd1243dSDimitry Andric return llvm::hlsl::ResourceKind::Texture2DArray; 243*bdd1243dSDimitry Andric case HLSLResourceAttr::ResourceKind::Texture2DMSArray: 244*bdd1243dSDimitry Andric return llvm::hlsl::ResourceKind::Texture2DMSArray; 245*bdd1243dSDimitry Andric case HLSLResourceAttr::ResourceKind::TextureCubeArray: 246*bdd1243dSDimitry Andric return llvm::hlsl::ResourceKind::TextureCubeArray; 247*bdd1243dSDimitry Andric case HLSLResourceAttr::ResourceKind::TypedBuffer: 248*bdd1243dSDimitry Andric return llvm::hlsl::ResourceKind::TypedBuffer; 249*bdd1243dSDimitry Andric case HLSLResourceAttr::ResourceKind::RawBuffer: 250*bdd1243dSDimitry Andric return llvm::hlsl::ResourceKind::RawBuffer; 251*bdd1243dSDimitry Andric case HLSLResourceAttr::ResourceKind::StructuredBuffer: 252*bdd1243dSDimitry Andric return llvm::hlsl::ResourceKind::StructuredBuffer; 253*bdd1243dSDimitry Andric case HLSLResourceAttr::ResourceKind::CBufferKind: 254*bdd1243dSDimitry Andric return llvm::hlsl::ResourceKind::CBuffer; 255*bdd1243dSDimitry Andric case HLSLResourceAttr::ResourceKind::SamplerKind: 256*bdd1243dSDimitry Andric return llvm::hlsl::ResourceKind::Sampler; 257*bdd1243dSDimitry Andric case HLSLResourceAttr::ResourceKind::TBuffer: 258*bdd1243dSDimitry Andric return llvm::hlsl::ResourceKind::TBuffer; 259*bdd1243dSDimitry Andric case HLSLResourceAttr::ResourceKind::RTAccelerationStructure: 260*bdd1243dSDimitry Andric return llvm::hlsl::ResourceKind::RTAccelerationStructure; 261*bdd1243dSDimitry Andric case HLSLResourceAttr::ResourceKind::FeedbackTexture2D: 262*bdd1243dSDimitry Andric return llvm::hlsl::ResourceKind::FeedbackTexture2D; 263*bdd1243dSDimitry Andric case HLSLResourceAttr::ResourceKind::FeedbackTexture2DArray: 264*bdd1243dSDimitry Andric return llvm::hlsl::ResourceKind::FeedbackTexture2DArray; 265*bdd1243dSDimitry Andric } 266*bdd1243dSDimitry Andric // Make sure to update HLSLResourceAttr::ResourceKind when add new Kind to 267*bdd1243dSDimitry Andric // hlsl::ResourceKind. Assume FeedbackTexture2DArray is the last enum for 268*bdd1243dSDimitry Andric // HLSLResourceAttr::ResourceKind. 269*bdd1243dSDimitry Andric static_assert( 270*bdd1243dSDimitry Andric static_cast<uint32_t>( 271*bdd1243dSDimitry Andric HLSLResourceAttr::ResourceKind::FeedbackTexture2DArray) == 272*bdd1243dSDimitry Andric (static_cast<uint32_t>(llvm::hlsl::ResourceKind::NumEntries) - 2)); 273*bdd1243dSDimitry Andric llvm_unreachable("all switch cases should be covered"); 274*bdd1243dSDimitry Andric } 275*bdd1243dSDimitry Andric 276*bdd1243dSDimitry Andric void CGHLSLRuntime::annotateHLSLResource(const VarDecl *D, GlobalVariable *GV) { 277*bdd1243dSDimitry Andric const Type *Ty = D->getType()->getPointeeOrArrayElementType(); 278*bdd1243dSDimitry Andric if (!Ty) 279*bdd1243dSDimitry Andric return; 280*bdd1243dSDimitry Andric const auto *RD = Ty->getAsCXXRecordDecl(); 281*bdd1243dSDimitry Andric if (!RD) 282*bdd1243dSDimitry Andric return; 283*bdd1243dSDimitry Andric const auto *Attr = RD->getAttr<HLSLResourceAttr>(); 284*bdd1243dSDimitry Andric if (!Attr) 285*bdd1243dSDimitry Andric return; 286*bdd1243dSDimitry Andric 287*bdd1243dSDimitry Andric HLSLResourceAttr::ResourceClass RC = Attr->getResourceType(); 288*bdd1243dSDimitry Andric llvm::hlsl::ResourceKind RK = 289*bdd1243dSDimitry Andric castResourceShapeToResourceKind(Attr->getResourceShape()); 290*bdd1243dSDimitry Andric 291*bdd1243dSDimitry Andric QualType QT(Ty, 0); 292*bdd1243dSDimitry Andric BufferResBinding Binding(D->getAttr<HLSLResourceBindingAttr>()); 293*bdd1243dSDimitry Andric addBufferResourceAnnotation(GV, QT.getAsString(), 294*bdd1243dSDimitry Andric static_cast<llvm::hlsl::ResourceClass>(RC), RK, 295*bdd1243dSDimitry Andric Binding); 296*bdd1243dSDimitry Andric } 297*bdd1243dSDimitry Andric 298*bdd1243dSDimitry Andric CGHLSLRuntime::BufferResBinding::BufferResBinding( 299*bdd1243dSDimitry Andric HLSLResourceBindingAttr *Binding) { 300*bdd1243dSDimitry Andric if (Binding) { 301*bdd1243dSDimitry Andric llvm::APInt RegInt(64, 0); 302*bdd1243dSDimitry Andric Binding->getSlot().substr(1).getAsInteger(10, RegInt); 303*bdd1243dSDimitry Andric Reg = RegInt.getLimitedValue(); 304*bdd1243dSDimitry Andric llvm::APInt SpaceInt(64, 0); 305*bdd1243dSDimitry Andric Binding->getSpace().substr(5).getAsInteger(10, SpaceInt); 306*bdd1243dSDimitry Andric Space = SpaceInt.getLimitedValue(); 307*bdd1243dSDimitry Andric } else { 308*bdd1243dSDimitry Andric Space = 0; 309*bdd1243dSDimitry Andric } 310*bdd1243dSDimitry Andric } 311*bdd1243dSDimitry Andric 312*bdd1243dSDimitry Andric void clang::CodeGen::CGHLSLRuntime::setHLSLEntryAttributes( 313*bdd1243dSDimitry Andric const FunctionDecl *FD, llvm::Function *Fn) { 314*bdd1243dSDimitry Andric const auto *ShaderAttr = FD->getAttr<HLSLShaderAttr>(); 315*bdd1243dSDimitry Andric assert(ShaderAttr && "All entry functions must have a HLSLShaderAttr"); 316*bdd1243dSDimitry Andric const StringRef ShaderAttrKindStr = "hlsl.shader"; 317*bdd1243dSDimitry Andric Fn->addFnAttr(ShaderAttrKindStr, 318*bdd1243dSDimitry Andric ShaderAttr->ConvertShaderTypeToStr(ShaderAttr->getType())); 319*bdd1243dSDimitry Andric if (HLSLNumThreadsAttr *NumThreadsAttr = FD->getAttr<HLSLNumThreadsAttr>()) { 320*bdd1243dSDimitry Andric const StringRef NumThreadsKindStr = "hlsl.numthreads"; 321*bdd1243dSDimitry Andric std::string NumThreadsStr = 322*bdd1243dSDimitry Andric formatv("{0},{1},{2}", NumThreadsAttr->getX(), NumThreadsAttr->getY(), 323*bdd1243dSDimitry Andric NumThreadsAttr->getZ()); 324*bdd1243dSDimitry Andric Fn->addFnAttr(NumThreadsKindStr, NumThreadsStr); 325*bdd1243dSDimitry Andric } 326*bdd1243dSDimitry Andric } 327*bdd1243dSDimitry Andric 328*bdd1243dSDimitry Andric static Value *buildVectorInput(IRBuilder<> &B, Function *F, llvm::Type *Ty) { 329*bdd1243dSDimitry Andric if (const auto *VT = dyn_cast<FixedVectorType>(Ty)) { 330*bdd1243dSDimitry Andric Value *Result = PoisonValue::get(Ty); 331*bdd1243dSDimitry Andric for (unsigned I = 0; I < VT->getNumElements(); ++I) { 332*bdd1243dSDimitry Andric Value *Elt = B.CreateCall(F, {B.getInt32(I)}); 333*bdd1243dSDimitry Andric Result = B.CreateInsertElement(Result, Elt, I); 334*bdd1243dSDimitry Andric } 335*bdd1243dSDimitry Andric return Result; 336*bdd1243dSDimitry Andric } 337*bdd1243dSDimitry Andric return B.CreateCall(F, {B.getInt32(0)}); 338*bdd1243dSDimitry Andric } 339*bdd1243dSDimitry Andric 340*bdd1243dSDimitry Andric llvm::Value *CGHLSLRuntime::emitInputSemantic(IRBuilder<> &B, 341*bdd1243dSDimitry Andric const ParmVarDecl &D, 342*bdd1243dSDimitry Andric llvm::Type *Ty) { 343*bdd1243dSDimitry Andric assert(D.hasAttrs() && "Entry parameter missing annotation attribute!"); 344*bdd1243dSDimitry Andric if (D.hasAttr<HLSLSV_GroupIndexAttr>()) { 345*bdd1243dSDimitry Andric llvm::Function *DxGroupIndex = 346*bdd1243dSDimitry Andric CGM.getIntrinsic(Intrinsic::dx_flattened_thread_id_in_group); 347*bdd1243dSDimitry Andric return B.CreateCall(FunctionCallee(DxGroupIndex)); 348*bdd1243dSDimitry Andric } 349*bdd1243dSDimitry Andric if (D.hasAttr<HLSLSV_DispatchThreadIDAttr>()) { 350*bdd1243dSDimitry Andric llvm::Function *DxThreadID = CGM.getIntrinsic(Intrinsic::dx_thread_id); 351*bdd1243dSDimitry Andric return buildVectorInput(B, DxThreadID, Ty); 352*bdd1243dSDimitry Andric } 353*bdd1243dSDimitry Andric assert(false && "Unhandled parameter attribute"); 354*bdd1243dSDimitry Andric return nullptr; 355*bdd1243dSDimitry Andric } 356*bdd1243dSDimitry Andric 357*bdd1243dSDimitry Andric void CGHLSLRuntime::emitEntryFunction(const FunctionDecl *FD, 358*bdd1243dSDimitry Andric llvm::Function *Fn) { 359*bdd1243dSDimitry Andric llvm::Module &M = CGM.getModule(); 360*bdd1243dSDimitry Andric llvm::LLVMContext &Ctx = M.getContext(); 361*bdd1243dSDimitry Andric auto *EntryTy = llvm::FunctionType::get(llvm::Type::getVoidTy(Ctx), false); 362*bdd1243dSDimitry Andric Function *EntryFn = 363*bdd1243dSDimitry Andric Function::Create(EntryTy, Function::ExternalLinkage, FD->getName(), &M); 364*bdd1243dSDimitry Andric 365*bdd1243dSDimitry Andric // Copy function attributes over, we have no argument or return attributes 366*bdd1243dSDimitry Andric // that can be valid on the real entry. 367*bdd1243dSDimitry Andric AttributeList NewAttrs = AttributeList::get(Ctx, AttributeList::FunctionIndex, 368*bdd1243dSDimitry Andric Fn->getAttributes().getFnAttrs()); 369*bdd1243dSDimitry Andric EntryFn->setAttributes(NewAttrs); 370*bdd1243dSDimitry Andric setHLSLEntryAttributes(FD, EntryFn); 371*bdd1243dSDimitry Andric 372*bdd1243dSDimitry Andric // Set the called function as internal linkage. 373*bdd1243dSDimitry Andric Fn->setLinkage(GlobalValue::InternalLinkage); 374*bdd1243dSDimitry Andric 375*bdd1243dSDimitry Andric BasicBlock *BB = BasicBlock::Create(Ctx, "entry", EntryFn); 376*bdd1243dSDimitry Andric IRBuilder<> B(BB); 377*bdd1243dSDimitry Andric llvm::SmallVector<Value *> Args; 378*bdd1243dSDimitry Andric // FIXME: support struct parameters where semantics are on members. 379*bdd1243dSDimitry Andric // See: https://github.com/llvm/llvm-project/issues/57874 380*bdd1243dSDimitry Andric unsigned SRetOffset = 0; 381*bdd1243dSDimitry Andric for (const auto &Param : Fn->args()) { 382*bdd1243dSDimitry Andric if (Param.hasStructRetAttr()) { 383*bdd1243dSDimitry Andric // FIXME: support output. 384*bdd1243dSDimitry Andric // See: https://github.com/llvm/llvm-project/issues/57874 385*bdd1243dSDimitry Andric SRetOffset = 1; 386*bdd1243dSDimitry Andric Args.emplace_back(PoisonValue::get(Param.getType())); 387*bdd1243dSDimitry Andric continue; 388*bdd1243dSDimitry Andric } 389*bdd1243dSDimitry Andric const ParmVarDecl *PD = FD->getParamDecl(Param.getArgNo() - SRetOffset); 390*bdd1243dSDimitry Andric Args.push_back(emitInputSemantic(B, *PD, Param.getType())); 391*bdd1243dSDimitry Andric } 392*bdd1243dSDimitry Andric 393*bdd1243dSDimitry Andric CallInst *CI = B.CreateCall(FunctionCallee(Fn), Args); 394*bdd1243dSDimitry Andric (void)CI; 395*bdd1243dSDimitry Andric // FIXME: Handle codegen for return type semantics. 396*bdd1243dSDimitry Andric // See: https://github.com/llvm/llvm-project/issues/57875 397*bdd1243dSDimitry Andric B.CreateRetVoid(); 398*bdd1243dSDimitry Andric } 399*bdd1243dSDimitry Andric 400*bdd1243dSDimitry Andric static void gatherFunctions(SmallVectorImpl<Function *> &Fns, llvm::Module &M, 401*bdd1243dSDimitry Andric bool CtorOrDtor) { 402*bdd1243dSDimitry Andric const auto *GV = 403*bdd1243dSDimitry Andric M.getNamedGlobal(CtorOrDtor ? "llvm.global_ctors" : "llvm.global_dtors"); 404*bdd1243dSDimitry Andric if (!GV) 405*bdd1243dSDimitry Andric return; 406*bdd1243dSDimitry Andric const auto *CA = dyn_cast<ConstantArray>(GV->getInitializer()); 407*bdd1243dSDimitry Andric if (!CA) 408*bdd1243dSDimitry Andric return; 409*bdd1243dSDimitry Andric // The global_ctor array elements are a struct [Priority, Fn *, COMDat]. 410*bdd1243dSDimitry Andric // HLSL neither supports priorities or COMDat values, so we will check those 411*bdd1243dSDimitry Andric // in an assert but not handle them. 412*bdd1243dSDimitry Andric 413*bdd1243dSDimitry Andric llvm::SmallVector<Function *> CtorFns; 414*bdd1243dSDimitry Andric for (const auto &Ctor : CA->operands()) { 415*bdd1243dSDimitry Andric if (isa<ConstantAggregateZero>(Ctor)) 416*bdd1243dSDimitry Andric continue; 417*bdd1243dSDimitry Andric ConstantStruct *CS = cast<ConstantStruct>(Ctor); 418*bdd1243dSDimitry Andric 419*bdd1243dSDimitry Andric assert(cast<ConstantInt>(CS->getOperand(0))->getValue() == 65535 && 420*bdd1243dSDimitry Andric "HLSL doesn't support setting priority for global ctors."); 421*bdd1243dSDimitry Andric assert(isa<ConstantPointerNull>(CS->getOperand(2)) && 422*bdd1243dSDimitry Andric "HLSL doesn't support COMDat for global ctors."); 423*bdd1243dSDimitry Andric Fns.push_back(cast<Function>(CS->getOperand(1))); 424*bdd1243dSDimitry Andric } 425*bdd1243dSDimitry Andric } 426*bdd1243dSDimitry Andric 427*bdd1243dSDimitry Andric void CGHLSLRuntime::generateGlobalCtorDtorCalls() { 428*bdd1243dSDimitry Andric llvm::Module &M = CGM.getModule(); 429*bdd1243dSDimitry Andric SmallVector<Function *> CtorFns; 430*bdd1243dSDimitry Andric SmallVector<Function *> DtorFns; 431*bdd1243dSDimitry Andric gatherFunctions(CtorFns, M, true); 432*bdd1243dSDimitry Andric gatherFunctions(DtorFns, M, false); 433*bdd1243dSDimitry Andric 434*bdd1243dSDimitry Andric // Insert a call to the global constructor at the beginning of the entry block 435*bdd1243dSDimitry Andric // to externally exported functions. This is a bit of a hack, but HLSL allows 436*bdd1243dSDimitry Andric // global constructors, but doesn't support driver initialization of globals. 437*bdd1243dSDimitry Andric for (auto &F : M.functions()) { 438*bdd1243dSDimitry Andric if (!F.hasFnAttribute("hlsl.shader")) 439*bdd1243dSDimitry Andric continue; 440*bdd1243dSDimitry Andric IRBuilder<> B(&F.getEntryBlock(), F.getEntryBlock().begin()); 441*bdd1243dSDimitry Andric for (auto *Fn : CtorFns) 442*bdd1243dSDimitry Andric B.CreateCall(FunctionCallee(Fn)); 443*bdd1243dSDimitry Andric 444*bdd1243dSDimitry Andric // Insert global dtors before the terminator of the last instruction 445*bdd1243dSDimitry Andric B.SetInsertPoint(F.back().getTerminator()); 446*bdd1243dSDimitry Andric for (auto *Fn : DtorFns) 447*bdd1243dSDimitry Andric B.CreateCall(FunctionCallee(Fn)); 448*bdd1243dSDimitry Andric } 449*bdd1243dSDimitry Andric 450*bdd1243dSDimitry Andric // No need to keep global ctors/dtors for non-lib profile after call to 451*bdd1243dSDimitry Andric // ctors/dtors added for entry. 452*bdd1243dSDimitry Andric Triple T(M.getTargetTriple()); 453*bdd1243dSDimitry Andric if (T.getEnvironment() != Triple::EnvironmentType::Library) { 454*bdd1243dSDimitry Andric if (auto *GV = M.getNamedGlobal("llvm.global_ctors")) 455*bdd1243dSDimitry Andric GV->eraseFromParent(); 456*bdd1243dSDimitry Andric if (auto *GV = M.getNamedGlobal("llvm.global_dtors")) 457*bdd1243dSDimitry Andric GV->eraseFromParent(); 458*bdd1243dSDimitry Andric } 45981ad6265SDimitry Andric } 460