1 //===- DXILTranslateMetadata.cpp - Pass to emit DXIL metadata -------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "DXILTranslateMetadata.h" 10 #include "DXILShaderFlags.h" 11 #include "DirectX.h" 12 #include "llvm/ADT/SmallVector.h" 13 #include "llvm/ADT/Twine.h" 14 #include "llvm/Analysis/DXILMetadataAnalysis.h" 15 #include "llvm/Analysis/DXILResource.h" 16 #include "llvm/IR/BasicBlock.h" 17 #include "llvm/IR/Constants.h" 18 #include "llvm/IR/DiagnosticInfo.h" 19 #include "llvm/IR/DiagnosticPrinter.h" 20 #include "llvm/IR/Function.h" 21 #include "llvm/IR/IRBuilder.h" 22 #include "llvm/IR/LLVMContext.h" 23 #include "llvm/IR/MDBuilder.h" 24 #include "llvm/IR/Metadata.h" 25 #include "llvm/IR/Module.h" 26 #include "llvm/InitializePasses.h" 27 #include "llvm/Pass.h" 28 #include "llvm/Support/ErrorHandling.h" 29 #include "llvm/Support/VersionTuple.h" 30 #include "llvm/TargetParser/Triple.h" 31 #include <cstdint> 32 33 using namespace llvm; 34 using namespace llvm::dxil; 35 36 namespace { 37 /// A simple Wrapper DiagnosticInfo that generates Module-level diagnostic 38 /// for TranslateMetadata pass 39 class DiagnosticInfoTranslateMD : public DiagnosticInfo { 40 private: 41 const Twine &Msg; 42 const Module &Mod; 43 44 public: 45 /// \p M is the module for which the diagnostic is being emitted. \p Msg is 46 /// the message to show. Note that this class does not copy this message, so 47 /// this reference must be valid for the whole life time of the diagnostic. 48 DiagnosticInfoTranslateMD(const Module &M, 49 const Twine &Msg LLVM_LIFETIME_BOUND, 50 DiagnosticSeverity Severity = DS_Error) 51 : DiagnosticInfo(DK_Unsupported, Severity), Msg(Msg), Mod(M) {} 52 53 void print(DiagnosticPrinter &DP) const override { 54 DP << Mod.getName() << ": " << Msg << '\n'; 55 } 56 }; 57 58 enum class EntryPropsTag { 59 ShaderFlags = 0, 60 GSState, 61 DSState, 62 HSState, 63 NumThreads, 64 AutoBindingSpace, 65 RayPayloadSize, 66 RayAttribSize, 67 ShaderKind, 68 MSState, 69 ASStateTag, 70 WaveSize, 71 EntryRootSig, 72 }; 73 74 } // namespace 75 76 static NamedMDNode *emitResourceMetadata(Module &M, DXILResourceMap &DRM, 77 DXILResourceTypeMap &DRTM) { 78 LLVMContext &Context = M.getContext(); 79 80 for (ResourceInfo &RI : DRM) 81 if (!RI.hasSymbol()) 82 RI.createSymbol(M, 83 DRTM[RI.getHandleTy()].createElementStruct(RI.getName())); 84 85 SmallVector<Metadata *> SRVs, UAVs, CBufs, Smps; 86 for (const ResourceInfo &RI : DRM.srvs()) 87 SRVs.push_back(RI.getAsMetadata(M, DRTM[RI.getHandleTy()])); 88 for (const ResourceInfo &RI : DRM.uavs()) 89 UAVs.push_back(RI.getAsMetadata(M, DRTM[RI.getHandleTy()])); 90 for (const ResourceInfo &RI : DRM.cbuffers()) 91 CBufs.push_back(RI.getAsMetadata(M, DRTM[RI.getHandleTy()])); 92 for (const ResourceInfo &RI : DRM.samplers()) 93 Smps.push_back(RI.getAsMetadata(M, DRTM[RI.getHandleTy()])); 94 95 Metadata *SRVMD = SRVs.empty() ? nullptr : MDNode::get(Context, SRVs); 96 Metadata *UAVMD = UAVs.empty() ? nullptr : MDNode::get(Context, UAVs); 97 Metadata *CBufMD = CBufs.empty() ? nullptr : MDNode::get(Context, CBufs); 98 Metadata *SmpMD = Smps.empty() ? nullptr : MDNode::get(Context, Smps); 99 100 if (DRM.empty()) 101 return nullptr; 102 103 NamedMDNode *ResourceMD = M.getOrInsertNamedMetadata("dx.resources"); 104 ResourceMD->addOperand( 105 MDNode::get(M.getContext(), {SRVMD, UAVMD, CBufMD, SmpMD})); 106 107 return ResourceMD; 108 } 109 110 static StringRef getShortShaderStage(Triple::EnvironmentType Env) { 111 switch (Env) { 112 case Triple::Pixel: 113 return "ps"; 114 case Triple::Vertex: 115 return "vs"; 116 case Triple::Geometry: 117 return "gs"; 118 case Triple::Hull: 119 return "hs"; 120 case Triple::Domain: 121 return "ds"; 122 case Triple::Compute: 123 return "cs"; 124 case Triple::Library: 125 return "lib"; 126 case Triple::Mesh: 127 return "ms"; 128 case Triple::Amplification: 129 return "as"; 130 default: 131 break; 132 } 133 llvm_unreachable("Unsupported environment for DXIL generation."); 134 } 135 136 static uint32_t getShaderStage(Triple::EnvironmentType Env) { 137 return (uint32_t)Env - (uint32_t)llvm::Triple::Pixel; 138 } 139 140 static SmallVector<Metadata *> 141 getTagValueAsMetadata(EntryPropsTag Tag, uint64_t Value, LLVMContext &Ctx) { 142 SmallVector<Metadata *> MDVals; 143 MDVals.emplace_back(ConstantAsMetadata::get( 144 ConstantInt::get(Type::getInt32Ty(Ctx), static_cast<int>(Tag)))); 145 switch (Tag) { 146 case EntryPropsTag::ShaderFlags: 147 MDVals.emplace_back(ConstantAsMetadata::get( 148 ConstantInt::get(Type::getInt64Ty(Ctx), Value))); 149 break; 150 case EntryPropsTag::ShaderKind: 151 MDVals.emplace_back(ConstantAsMetadata::get( 152 ConstantInt::get(Type::getInt32Ty(Ctx), Value))); 153 break; 154 case EntryPropsTag::GSState: 155 case EntryPropsTag::DSState: 156 case EntryPropsTag::HSState: 157 case EntryPropsTag::NumThreads: 158 case EntryPropsTag::AutoBindingSpace: 159 case EntryPropsTag::RayPayloadSize: 160 case EntryPropsTag::RayAttribSize: 161 case EntryPropsTag::MSState: 162 case EntryPropsTag::ASStateTag: 163 case EntryPropsTag::WaveSize: 164 case EntryPropsTag::EntryRootSig: 165 llvm_unreachable("NYI: Unhandled entry property tag"); 166 } 167 return MDVals; 168 } 169 170 static MDTuple * 171 getEntryPropAsMetadata(const EntryProperties &EP, uint64_t EntryShaderFlags, 172 const Triple::EnvironmentType ShaderProfile) { 173 SmallVector<Metadata *> MDVals; 174 LLVMContext &Ctx = EP.Entry->getContext(); 175 if (EntryShaderFlags != 0) 176 MDVals.append(getTagValueAsMetadata(EntryPropsTag::ShaderFlags, 177 EntryShaderFlags, Ctx)); 178 179 if (EP.Entry != nullptr) { 180 // FIXME: support more props. 181 // See https://github.com/llvm/llvm-project/issues/57948. 182 // Add shader kind for lib entries. 183 if (ShaderProfile == Triple::EnvironmentType::Library && 184 EP.ShaderStage != Triple::EnvironmentType::Library) 185 MDVals.append(getTagValueAsMetadata(EntryPropsTag::ShaderKind, 186 getShaderStage(EP.ShaderStage), Ctx)); 187 188 if (EP.ShaderStage == Triple::EnvironmentType::Compute) { 189 MDVals.emplace_back(ConstantAsMetadata::get(ConstantInt::get( 190 Type::getInt32Ty(Ctx), static_cast<int>(EntryPropsTag::NumThreads)))); 191 Metadata *NumThreadVals[] = {ConstantAsMetadata::get(ConstantInt::get( 192 Type::getInt32Ty(Ctx), EP.NumThreadsX)), 193 ConstantAsMetadata::get(ConstantInt::get( 194 Type::getInt32Ty(Ctx), EP.NumThreadsY)), 195 ConstantAsMetadata::get(ConstantInt::get( 196 Type::getInt32Ty(Ctx), EP.NumThreadsZ))}; 197 MDVals.emplace_back(MDNode::get(Ctx, NumThreadVals)); 198 } 199 } 200 if (MDVals.empty()) 201 return nullptr; 202 return MDNode::get(Ctx, MDVals); 203 } 204 205 MDTuple *constructEntryMetadata(const Function *EntryFn, MDTuple *Signatures, 206 MDNode *Resources, MDTuple *Properties, 207 LLVMContext &Ctx) { 208 // Each entry point metadata record specifies: 209 // * reference to the entry point function global symbol 210 // * unmangled name 211 // * list of signatures 212 // * list of resources 213 // * list of tag-value pairs of shader capabilities and other properties 214 Metadata *MDVals[5]; 215 MDVals[0] = 216 EntryFn ? ValueAsMetadata::get(const_cast<Function *>(EntryFn)) : nullptr; 217 MDVals[1] = MDString::get(Ctx, EntryFn ? EntryFn->getName() : ""); 218 MDVals[2] = Signatures; 219 MDVals[3] = Resources; 220 MDVals[4] = Properties; 221 return MDNode::get(Ctx, MDVals); 222 } 223 224 static MDTuple *emitEntryMD(const EntryProperties &EP, MDTuple *Signatures, 225 MDNode *MDResources, 226 const uint64_t EntryShaderFlags, 227 const Triple::EnvironmentType ShaderProfile) { 228 MDTuple *Properties = 229 getEntryPropAsMetadata(EP, EntryShaderFlags, ShaderProfile); 230 return constructEntryMetadata(EP.Entry, Signatures, MDResources, Properties, 231 EP.Entry->getContext()); 232 } 233 234 static void emitValidatorVersionMD(Module &M, const ModuleMetadataInfo &MMDI) { 235 if (MMDI.ValidatorVersion.empty()) 236 return; 237 238 LLVMContext &Ctx = M.getContext(); 239 IRBuilder<> IRB(Ctx); 240 Metadata *MDVals[2]; 241 MDVals[0] = 242 ConstantAsMetadata::get(IRB.getInt32(MMDI.ValidatorVersion.getMajor())); 243 MDVals[1] = ConstantAsMetadata::get( 244 IRB.getInt32(MMDI.ValidatorVersion.getMinor().value_or(0))); 245 NamedMDNode *ValVerNode = M.getOrInsertNamedMetadata("dx.valver"); 246 // Set validator version obtained from DXIL Metadata Analysis pass 247 ValVerNode->clearOperands(); 248 ValVerNode->addOperand(MDNode::get(Ctx, MDVals)); 249 } 250 251 static void emitShaderModelVersionMD(Module &M, 252 const ModuleMetadataInfo &MMDI) { 253 LLVMContext &Ctx = M.getContext(); 254 IRBuilder<> IRB(Ctx); 255 Metadata *SMVals[3]; 256 VersionTuple SM = MMDI.ShaderModelVersion; 257 SMVals[0] = MDString::get(Ctx, getShortShaderStage(MMDI.ShaderProfile)); 258 SMVals[1] = ConstantAsMetadata::get(IRB.getInt32(SM.getMajor())); 259 SMVals[2] = ConstantAsMetadata::get(IRB.getInt32(SM.getMinor().value_or(0))); 260 NamedMDNode *SMMDNode = M.getOrInsertNamedMetadata("dx.shaderModel"); 261 SMMDNode->addOperand(MDNode::get(Ctx, SMVals)); 262 } 263 264 static void emitDXILVersionTupleMD(Module &M, const ModuleMetadataInfo &MMDI) { 265 LLVMContext &Ctx = M.getContext(); 266 IRBuilder<> IRB(Ctx); 267 VersionTuple DXILVer = MMDI.DXILVersion; 268 Metadata *DXILVals[2]; 269 DXILVals[0] = ConstantAsMetadata::get(IRB.getInt32(DXILVer.getMajor())); 270 DXILVals[1] = 271 ConstantAsMetadata::get(IRB.getInt32(DXILVer.getMinor().value_or(0))); 272 NamedMDNode *DXILVerMDNode = M.getOrInsertNamedMetadata("dx.version"); 273 DXILVerMDNode->addOperand(MDNode::get(Ctx, DXILVals)); 274 } 275 276 static MDTuple *emitTopLevelLibraryNode(Module &M, MDNode *RMD, 277 uint64_t ShaderFlags) { 278 LLVMContext &Ctx = M.getContext(); 279 MDTuple *Properties = nullptr; 280 if (ShaderFlags != 0) { 281 SmallVector<Metadata *> MDVals; 282 MDVals.append( 283 getTagValueAsMetadata(EntryPropsTag::ShaderFlags, ShaderFlags, Ctx)); 284 Properties = MDNode::get(Ctx, MDVals); 285 } 286 // Library has an entry metadata with resource table metadata and all other 287 // MDNodes as null. 288 return constructEntryMetadata(nullptr, nullptr, RMD, Properties, Ctx); 289 } 290 291 // TODO: We might need to refactor this to be more generic, 292 // in case we need more metadata to be replaced. 293 static void translateBranchMetadata(Module &M) { 294 for (Function &F : M) { 295 for (BasicBlock &BB : F) { 296 Instruction *BBTerminatorInst = BB.getTerminator(); 297 298 MDNode *HlslControlFlowMD = 299 BBTerminatorInst->getMetadata("hlsl.controlflow.hint"); 300 301 if (!HlslControlFlowMD) 302 continue; 303 304 assert(HlslControlFlowMD->getNumOperands() == 2 && 305 "invalid operands for hlsl.controlflow.hint"); 306 307 MDBuilder MDHelper(M.getContext()); 308 ConstantInt *Op1 = 309 mdconst::extract<ConstantInt>(HlslControlFlowMD->getOperand(1)); 310 311 SmallVector<llvm::Metadata *, 2> Vals( 312 ArrayRef<Metadata *>{MDHelper.createString("dx.controlflow.hints"), 313 MDHelper.createConstant(Op1)}); 314 315 MDNode *MDNode = llvm::MDNode::get(M.getContext(), Vals); 316 317 BBTerminatorInst->setMetadata("dx.controlflow.hints", MDNode); 318 BBTerminatorInst->setMetadata("hlsl.controlflow.hint", nullptr); 319 } 320 } 321 } 322 323 static void translateMetadata(Module &M, DXILResourceMap &DRM, 324 DXILResourceTypeMap &DRTM, 325 const ModuleShaderFlags &ShaderFlags, 326 const ModuleMetadataInfo &MMDI) { 327 LLVMContext &Ctx = M.getContext(); 328 IRBuilder<> IRB(Ctx); 329 SmallVector<MDNode *> EntryFnMDNodes; 330 331 emitValidatorVersionMD(M, MMDI); 332 emitShaderModelVersionMD(M, MMDI); 333 emitDXILVersionTupleMD(M, MMDI); 334 NamedMDNode *NamedResourceMD = emitResourceMetadata(M, DRM, DRTM); 335 auto *ResourceMD = 336 (NamedResourceMD != nullptr) ? NamedResourceMD->getOperand(0) : nullptr; 337 // FIXME: Add support to construct Signatures 338 // See https://github.com/llvm/llvm-project/issues/57928 339 MDTuple *Signatures = nullptr; 340 341 if (MMDI.ShaderProfile == Triple::EnvironmentType::Library) { 342 // Get the combined shader flag mask of all functions in the library to be 343 // used as shader flags mask value associated with top-level library entry 344 // metadata. 345 uint64_t CombinedMask = ShaderFlags.getCombinedFlags(); 346 EntryFnMDNodes.emplace_back( 347 emitTopLevelLibraryNode(M, ResourceMD, CombinedMask)); 348 } else if (MMDI.EntryPropertyVec.size() > 1) { 349 M.getContext().diagnose(DiagnosticInfoTranslateMD( 350 M, "Non-library shader: One and only one entry expected")); 351 } 352 353 for (const EntryProperties &EntryProp : MMDI.EntryPropertyVec) { 354 const ComputedShaderFlags &EntrySFMask = 355 ShaderFlags.getFunctionFlags(EntryProp.Entry); 356 357 // If ShaderProfile is Library, mask is already consolidated in the 358 // top-level library node. Hence it is not emitted. 359 uint64_t EntryShaderFlags = 0; 360 if (MMDI.ShaderProfile != Triple::EnvironmentType::Library) { 361 EntryShaderFlags = EntrySFMask; 362 if (EntryProp.ShaderStage != MMDI.ShaderProfile) { 363 M.getContext().diagnose(DiagnosticInfoTranslateMD( 364 M, 365 "Shader stage '" + 366 Twine(getShortShaderStage(EntryProp.ShaderStage) + 367 "' for entry '" + Twine(EntryProp.Entry->getName()) + 368 "' different from specified target profile '" + 369 Twine(Triple::getEnvironmentTypeName(MMDI.ShaderProfile) + 370 "'")))); 371 } 372 } 373 EntryFnMDNodes.emplace_back(emitEntryMD(EntryProp, Signatures, ResourceMD, 374 EntryShaderFlags, 375 MMDI.ShaderProfile)); 376 } 377 378 NamedMDNode *EntryPointsNamedMD = 379 M.getOrInsertNamedMetadata("dx.entryPoints"); 380 for (auto *Entry : EntryFnMDNodes) 381 EntryPointsNamedMD->addOperand(Entry); 382 } 383 384 PreservedAnalyses DXILTranslateMetadata::run(Module &M, 385 ModuleAnalysisManager &MAM) { 386 DXILResourceMap &DRM = MAM.getResult<DXILResourceAnalysis>(M); 387 DXILResourceTypeMap &DRTM = MAM.getResult<DXILResourceTypeAnalysis>(M); 388 const ModuleShaderFlags &ShaderFlags = MAM.getResult<ShaderFlagsAnalysis>(M); 389 const dxil::ModuleMetadataInfo MMDI = MAM.getResult<DXILMetadataAnalysis>(M); 390 391 translateMetadata(M, DRM, DRTM, ShaderFlags, MMDI); 392 translateBranchMetadata(M); 393 394 return PreservedAnalyses::all(); 395 } 396 397 namespace { 398 class DXILTranslateMetadataLegacy : public ModulePass { 399 public: 400 static char ID; // Pass identification, replacement for typeid 401 explicit DXILTranslateMetadataLegacy() : ModulePass(ID) {} 402 403 StringRef getPassName() const override { return "DXIL Translate Metadata"; } 404 405 void getAnalysisUsage(AnalysisUsage &AU) const override { 406 AU.addRequired<DXILResourceTypeWrapperPass>(); 407 AU.addRequired<DXILResourceWrapperPass>(); 408 AU.addRequired<ShaderFlagsAnalysisWrapper>(); 409 AU.addRequired<DXILMetadataAnalysisWrapperPass>(); 410 AU.addPreserved<DXILResourceWrapperPass>(); 411 AU.addPreserved<DXILMetadataAnalysisWrapperPass>(); 412 AU.addPreserved<ShaderFlagsAnalysisWrapper>(); 413 AU.addPreserved<DXILResourceBindingWrapperPass>(); 414 } 415 416 bool runOnModule(Module &M) override { 417 DXILResourceMap &DRM = 418 getAnalysis<DXILResourceWrapperPass>().getResourceMap(); 419 DXILResourceTypeMap &DRTM = 420 getAnalysis<DXILResourceTypeWrapperPass>().getResourceTypeMap(); 421 const ModuleShaderFlags &ShaderFlags = 422 getAnalysis<ShaderFlagsAnalysisWrapper>().getShaderFlags(); 423 dxil::ModuleMetadataInfo MMDI = 424 getAnalysis<DXILMetadataAnalysisWrapperPass>().getModuleMetadata(); 425 426 translateMetadata(M, DRM, DRTM, ShaderFlags, MMDI); 427 translateBranchMetadata(M); 428 return true; 429 } 430 }; 431 432 } // namespace 433 434 char DXILTranslateMetadataLegacy::ID = 0; 435 436 ModulePass *llvm::createDXILTranslateMetadataLegacyPass() { 437 return new DXILTranslateMetadataLegacy(); 438 } 439 440 INITIALIZE_PASS_BEGIN(DXILTranslateMetadataLegacy, "dxil-translate-metadata", 441 "DXIL Translate Metadata", false, false) 442 INITIALIZE_PASS_DEPENDENCY(DXILResourceWrapperPass) 443 INITIALIZE_PASS_DEPENDENCY(ShaderFlagsAnalysisWrapper) 444 INITIALIZE_PASS_DEPENDENCY(DXILMetadataAnalysisWrapperPass) 445 INITIALIZE_PASS_END(DXILTranslateMetadataLegacy, "dxil-translate-metadata", 446 "DXIL Translate Metadata", false, false) 447