1 //===- DXContainerGlobals.cpp - DXContainer global generator pass ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // DXContainerGlobalsPass implementation. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "DXILRootSignature.h" 14 #include "DXILShaderFlags.h" 15 #include "DirectX.h" 16 #include "llvm/ADT/SmallVector.h" 17 #include "llvm/ADT/StringExtras.h" 18 #include "llvm/ADT/StringRef.h" 19 #include "llvm/Analysis/DXILMetadataAnalysis.h" 20 #include "llvm/Analysis/DXILResource.h" 21 #include "llvm/BinaryFormat/DXContainer.h" 22 #include "llvm/CodeGen/Passes.h" 23 #include "llvm/IR/Constants.h" 24 #include "llvm/IR/Module.h" 25 #include "llvm/InitializePasses.h" 26 #include "llvm/MC/DXContainerPSVInfo.h" 27 #include "llvm/Pass.h" 28 #include "llvm/Support/MD5.h" 29 #include "llvm/TargetParser/Triple.h" 30 #include "llvm/Transforms/Utils/ModuleUtils.h" 31 #include <optional> 32 33 using namespace llvm; 34 using namespace llvm::dxil; 35 using namespace llvm::mcdxbc; 36 37 namespace { 38 class DXContainerGlobals : public llvm::ModulePass { 39 40 GlobalVariable *buildContainerGlobal(Module &M, Constant *Content, 41 StringRef Name, StringRef SectionName); 42 GlobalVariable *getFeatureFlags(Module &M); 43 GlobalVariable *computeShaderHash(Module &M); 44 GlobalVariable *buildSignature(Module &M, Signature &Sig, StringRef Name, 45 StringRef SectionName); 46 void addSignature(Module &M, SmallVector<GlobalValue *> &Globals); 47 void addRootSignature(Module &M, SmallVector<GlobalValue *> &Globals); 48 void addResourcesForPSV(Module &M, PSVRuntimeInfo &PSV); 49 void addPipelineStateValidationInfo(Module &M, 50 SmallVector<GlobalValue *> &Globals); 51 52 public: 53 static char ID; // Pass identification, replacement for typeid 54 DXContainerGlobals() : ModulePass(ID) {} 55 56 StringRef getPassName() const override { 57 return "DXContainer Global Emitter"; 58 } 59 60 bool runOnModule(Module &M) override; 61 62 void getAnalysisUsage(AnalysisUsage &AU) const override { 63 AU.setPreservesAll(); 64 AU.addRequired<ShaderFlagsAnalysisWrapper>(); 65 AU.addRequired<RootSignatureAnalysisWrapper>(); 66 AU.addRequired<DXILMetadataAnalysisWrapperPass>(); 67 AU.addRequired<DXILResourceTypeWrapperPass>(); 68 AU.addRequired<DXILResourceWrapperPass>(); 69 } 70 }; 71 72 } // namespace 73 74 bool DXContainerGlobals::runOnModule(Module &M) { 75 llvm::SmallVector<GlobalValue *> Globals; 76 Globals.push_back(getFeatureFlags(M)); 77 Globals.push_back(computeShaderHash(M)); 78 addSignature(M, Globals); 79 addRootSignature(M, Globals); 80 addPipelineStateValidationInfo(M, Globals); 81 appendToCompilerUsed(M, Globals); 82 return true; 83 } 84 85 GlobalVariable *DXContainerGlobals::getFeatureFlags(Module &M) { 86 uint64_t CombinedFeatureFlags = getAnalysis<ShaderFlagsAnalysisWrapper>() 87 .getShaderFlags() 88 .getCombinedFlags() 89 .getFeatureFlags(); 90 91 Constant *FeatureFlagsConstant = 92 ConstantInt::get(M.getContext(), APInt(64, CombinedFeatureFlags)); 93 return buildContainerGlobal(M, FeatureFlagsConstant, "dx.sfi0", "SFI0"); 94 } 95 96 GlobalVariable *DXContainerGlobals::computeShaderHash(Module &M) { 97 auto *DXILConstant = 98 cast<ConstantDataArray>(M.getNamedGlobal("dx.dxil")->getInitializer()); 99 MD5 Digest; 100 Digest.update(DXILConstant->getRawDataValues()); 101 MD5::MD5Result Result = Digest.final(); 102 103 dxbc::ShaderHash HashData = {0, {0}}; 104 // The Hash's IncludesSource flag gets set whenever the hashed shader includes 105 // debug information. 106 if (M.debug_compile_units_begin() != M.debug_compile_units_end()) 107 HashData.Flags = static_cast<uint32_t>(dxbc::HashFlags::IncludesSource); 108 109 memcpy(reinterpret_cast<void *>(&HashData.Digest), Result.data(), 16); 110 if (sys::IsBigEndianHost) 111 HashData.swapBytes(); 112 StringRef Data(reinterpret_cast<char *>(&HashData), sizeof(dxbc::ShaderHash)); 113 114 Constant *ModuleConstant = 115 ConstantDataArray::get(M.getContext(), arrayRefFromStringRef(Data)); 116 return buildContainerGlobal(M, ModuleConstant, "dx.hash", "HASH"); 117 } 118 119 GlobalVariable *DXContainerGlobals::buildContainerGlobal( 120 Module &M, Constant *Content, StringRef Name, StringRef SectionName) { 121 auto *GV = new llvm::GlobalVariable( 122 M, Content->getType(), true, GlobalValue::PrivateLinkage, Content, Name); 123 GV->setSection(SectionName); 124 GV->setAlignment(Align(4)); 125 return GV; 126 } 127 128 GlobalVariable *DXContainerGlobals::buildSignature(Module &M, Signature &Sig, 129 StringRef Name, 130 StringRef SectionName) { 131 SmallString<256> Data; 132 raw_svector_ostream OS(Data); 133 Sig.write(OS); 134 Constant *Constant = 135 ConstantDataArray::getString(M.getContext(), Data, /*AddNull*/ false); 136 return buildContainerGlobal(M, Constant, Name, SectionName); 137 } 138 139 void DXContainerGlobals::addSignature(Module &M, 140 SmallVector<GlobalValue *> &Globals) { 141 // FIXME: support graphics shader. 142 // see issue https://github.com/llvm/llvm-project/issues/90504. 143 144 Signature InputSig; 145 Globals.emplace_back(buildSignature(M, InputSig, "dx.isg1", "ISG1")); 146 147 Signature OutputSig; 148 Globals.emplace_back(buildSignature(M, OutputSig, "dx.osg1", "OSG1")); 149 } 150 151 void DXContainerGlobals::addRootSignature(Module &M, 152 SmallVector<GlobalValue *> &Globals) { 153 154 dxil::ModuleMetadataInfo &MMI = 155 getAnalysis<DXILMetadataAnalysisWrapperPass>().getModuleMetadata(); 156 157 // Root Signature in Library don't compile to DXContainer. 158 if (MMI.ShaderProfile == llvm::Triple::Library) 159 return; 160 161 assert(MMI.EntryPropertyVec.size() == 1); 162 163 auto &RSA = getAnalysis<RootSignatureAnalysisWrapper>().getRSInfo(); 164 const Function *EntryFunction = MMI.EntryPropertyVec[0].Entry; 165 const std::optional<mcdxbc::RootSignatureDesc> &RS = 166 RSA.getDescForFunction(EntryFunction); 167 168 if (!RS) 169 return; 170 171 SmallString<256> Data; 172 raw_svector_ostream OS(Data); 173 174 RS->write(OS); 175 176 Constant *Constant = 177 ConstantDataArray::getString(M.getContext(), Data, /*AddNull*/ false); 178 Globals.emplace_back(buildContainerGlobal(M, Constant, "dx.rts0", "RTS0")); 179 } 180 181 void DXContainerGlobals::addResourcesForPSV(Module &M, PSVRuntimeInfo &PSV) { 182 const DXILResourceMap &DRM = 183 getAnalysis<DXILResourceWrapperPass>().getResourceMap(); 184 DXILResourceTypeMap &DRTM = 185 getAnalysis<DXILResourceTypeWrapperPass>().getResourceTypeMap(); 186 187 auto MakeBinding = 188 [](const dxil::ResourceInfo::ResourceBinding &Binding, 189 const dxbc::PSV::ResourceType Type, const dxil::ResourceKind Kind, 190 const dxbc::PSV::ResourceFlags Flags = dxbc::PSV::ResourceFlags()) { 191 dxbc::PSV::v2::ResourceBindInfo BindInfo; 192 BindInfo.Type = Type; 193 BindInfo.LowerBound = Binding.LowerBound; 194 BindInfo.UpperBound = Binding.LowerBound + Binding.Size - 1; 195 BindInfo.Space = Binding.Space; 196 BindInfo.Kind = static_cast<dxbc::PSV::ResourceKind>(Kind); 197 BindInfo.Flags = Flags; 198 return BindInfo; 199 }; 200 201 for (const dxil::ResourceInfo &RI : DRM.cbuffers()) { 202 const dxil::ResourceInfo::ResourceBinding &Binding = RI.getBinding(); 203 PSV.Resources.push_back(MakeBinding(Binding, dxbc::PSV::ResourceType::CBV, 204 dxil::ResourceKind::CBuffer)); 205 } 206 for (const dxil::ResourceInfo &RI : DRM.samplers()) { 207 const dxil::ResourceInfo::ResourceBinding &Binding = RI.getBinding(); 208 PSV.Resources.push_back(MakeBinding(Binding, 209 dxbc::PSV::ResourceType::Sampler, 210 dxil::ResourceKind::Sampler)); 211 } 212 for (const dxil::ResourceInfo &RI : DRM.srvs()) { 213 const dxil::ResourceInfo::ResourceBinding &Binding = RI.getBinding(); 214 215 dxil::ResourceTypeInfo &TypeInfo = DRTM[RI.getHandleTy()]; 216 dxbc::PSV::ResourceType ResType; 217 if (TypeInfo.isStruct()) 218 ResType = dxbc::PSV::ResourceType::SRVStructured; 219 else if (TypeInfo.isTyped()) 220 ResType = dxbc::PSV::ResourceType::SRVTyped; 221 else 222 ResType = dxbc::PSV::ResourceType::SRVRaw; 223 224 PSV.Resources.push_back( 225 MakeBinding(Binding, ResType, TypeInfo.getResourceKind())); 226 } 227 for (const dxil::ResourceInfo &RI : DRM.uavs()) { 228 const dxil::ResourceInfo::ResourceBinding &Binding = RI.getBinding(); 229 230 dxil::ResourceTypeInfo &TypeInfo = DRTM[RI.getHandleTy()]; 231 dxbc::PSV::ResourceType ResType; 232 if (RI.hasCounter()) 233 ResType = dxbc::PSV::ResourceType::UAVStructuredWithCounter; 234 else if (TypeInfo.isStruct()) 235 ResType = dxbc::PSV::ResourceType::UAVStructured; 236 else if (TypeInfo.isTyped()) 237 ResType = dxbc::PSV::ResourceType::UAVTyped; 238 else 239 ResType = dxbc::PSV::ResourceType::UAVRaw; 240 241 dxbc::PSV::ResourceFlags Flags; 242 // TODO: Add support for dxbc::PSV::ResourceFlag::UsedByAtomic64, tracking 243 // with https://github.com/llvm/llvm-project/issues/104392 244 Flags.Flags = 0u; 245 246 PSV.Resources.push_back( 247 MakeBinding(Binding, ResType, TypeInfo.getResourceKind(), Flags)); 248 } 249 } 250 251 void DXContainerGlobals::addPipelineStateValidationInfo( 252 Module &M, SmallVector<GlobalValue *> &Globals) { 253 SmallString<256> Data; 254 raw_svector_ostream OS(Data); 255 PSVRuntimeInfo PSV; 256 PSV.BaseData.MinimumWaveLaneCount = 0; 257 PSV.BaseData.MaximumWaveLaneCount = std::numeric_limits<uint32_t>::max(); 258 259 dxil::ModuleMetadataInfo &MMI = 260 getAnalysis<DXILMetadataAnalysisWrapperPass>().getModuleMetadata(); 261 assert(MMI.EntryPropertyVec.size() == 1 || 262 MMI.ShaderProfile == Triple::Library); 263 PSV.BaseData.ShaderStage = 264 static_cast<uint8_t>(MMI.ShaderProfile - Triple::Pixel); 265 266 addResourcesForPSV(M, PSV); 267 268 // Hardcoded values here to unblock loading the shader into D3D. 269 // 270 // TODO: Lots more stuff to do here! 271 // 272 // See issue https://github.com/llvm/llvm-project/issues/96674. 273 switch (MMI.ShaderProfile) { 274 case Triple::Compute: 275 PSV.BaseData.NumThreadsX = MMI.EntryPropertyVec[0].NumThreadsX; 276 PSV.BaseData.NumThreadsY = MMI.EntryPropertyVec[0].NumThreadsY; 277 PSV.BaseData.NumThreadsZ = MMI.EntryPropertyVec[0].NumThreadsZ; 278 break; 279 default: 280 break; 281 } 282 283 if (MMI.ShaderProfile != Triple::Library) 284 PSV.EntryName = MMI.EntryPropertyVec[0].Entry->getName(); 285 286 PSV.finalize(MMI.ShaderProfile); 287 PSV.write(OS); 288 Constant *Constant = 289 ConstantDataArray::getString(M.getContext(), Data, /*AddNull*/ false); 290 Globals.emplace_back(buildContainerGlobal(M, Constant, "dx.psv0", "PSV0")); 291 } 292 293 char DXContainerGlobals::ID = 0; 294 INITIALIZE_PASS_BEGIN(DXContainerGlobals, "dxil-globals", 295 "DXContainer Global Emitter", false, true) 296 INITIALIZE_PASS_DEPENDENCY(ShaderFlagsAnalysisWrapper) 297 INITIALIZE_PASS_DEPENDENCY(DXILMetadataAnalysisWrapperPass) 298 INITIALIZE_PASS_DEPENDENCY(DXILResourceTypeWrapperPass) 299 INITIALIZE_PASS_DEPENDENCY(DXILResourceWrapperPass) 300 INITIALIZE_PASS_END(DXContainerGlobals, "dxil-globals", 301 "DXContainer Global Emitter", false, true) 302 303 ModulePass *llvm::createDXContainerGlobalsPass() { 304 return new DXContainerGlobals(); 305 } 306