1 //===- DXILTranslateMetadata.cpp - Pass to emit DXIL metadata ---*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 /// 9 //===----------------------------------------------------------------------===// 10 11 #include "DirectX.h" 12 #include "llvm/ADT/StringSet.h" 13 #include "llvm/ADT/Triple.h" 14 #include "llvm/IR/Constants.h" 15 #include "llvm/IR/Metadata.h" 16 #include "llvm/IR/Module.h" 17 #include "llvm/Pass.h" 18 19 using namespace llvm; 20 21 static uint32_t ConstMDToUint32(const MDOperand &MDO) { 22 ConstantInt *pConst = mdconst::extract<ConstantInt>(MDO); 23 return (uint32_t)pConst->getZExtValue(); 24 } 25 26 static ConstantAsMetadata *Uint32ToConstMD(unsigned v, LLVMContext &Ctx) { 27 return ConstantAsMetadata::get( 28 Constant::getIntegerValue(IntegerType::get(Ctx, 32), APInt(32, v))); 29 } 30 31 constexpr StringLiteral ValVerKey = "dx.valver"; 32 constexpr unsigned DXILVersionNumFields = 2; 33 34 static void emitDXILValidatorVersion(Module &M, VersionTuple &ValidatorVer) { 35 NamedMDNode *DXILValidatorVersionMD = M.getNamedMetadata(ValVerKey); 36 37 // Allow re-writing the validator version, since this can be changed at 38 // later points. 39 if (DXILValidatorVersionMD) 40 M.eraseNamedMetadata(DXILValidatorVersionMD); 41 42 DXILValidatorVersionMD = M.getOrInsertNamedMetadata(ValVerKey); 43 44 auto &Ctx = M.getContext(); 45 Metadata *MDVals[DXILVersionNumFields]; 46 MDVals[0] = Uint32ToConstMD(ValidatorVer.getMajor(), Ctx); 47 MDVals[1] = Uint32ToConstMD(ValidatorVer.getMinor().value_or(0), Ctx); 48 49 DXILValidatorVersionMD->addOperand(MDNode::get(Ctx, MDVals)); 50 } 51 52 static VersionTuple loadDXILValidatorVersion(MDNode *ValVerMD) { 53 if (ValVerMD->getNumOperands() != DXILVersionNumFields) 54 return VersionTuple(); 55 56 unsigned Major = ConstMDToUint32(ValVerMD->getOperand(0)); 57 unsigned Minor = ConstMDToUint32(ValVerMD->getOperand(1)); 58 return VersionTuple(Major, Minor); 59 } 60 61 static void cleanModuleFlags(Module &M) { 62 constexpr StringLiteral DeadKeys[] = {ValVerKey}; 63 // Collect DeadKeys in ModuleFlags. 64 StringSet<> DeadKeySet; 65 for (auto &Key : DeadKeys) { 66 if (M.getModuleFlag(Key)) 67 DeadKeySet.insert(Key); 68 } 69 if (DeadKeySet.empty()) 70 return; 71 72 SmallVector<Module::ModuleFlagEntry, 8> ModuleFlags; 73 M.getModuleFlagsMetadata(ModuleFlags); 74 NamedMDNode *MDFlags = M.getModuleFlagsMetadata(); 75 MDFlags->eraseFromParent(); 76 // Add ModuleFlag which not dead. 77 for (auto &Flag : ModuleFlags) { 78 StringRef Key = Flag.Key->getString(); 79 if (DeadKeySet.contains(Key)) 80 continue; 81 M.addModuleFlag(Flag.Behavior, Key, Flag.Val); 82 } 83 } 84 85 static void cleanModule(Module &M) { cleanModuleFlags(M); } 86 87 namespace { 88 class DXILTranslateMetadata : public ModulePass { 89 public: 90 static char ID; // Pass identification, replacement for typeid 91 explicit DXILTranslateMetadata() : ModulePass(ID), ValidatorVer(1, 0) {} 92 93 StringRef getPassName() const override { return "DXIL Metadata Emit"; } 94 95 bool runOnModule(Module &M) override; 96 97 private: 98 VersionTuple ValidatorVer; 99 }; 100 101 } // namespace 102 103 bool DXILTranslateMetadata::runOnModule(Module &M) { 104 if (MDNode *ValVerMD = cast_or_null<MDNode>(M.getModuleFlag(ValVerKey))) { 105 auto ValVer = loadDXILValidatorVersion(ValVerMD); 106 if (!ValVer.empty()) 107 ValidatorVer = ValVer; 108 } 109 emitDXILValidatorVersion(M, ValidatorVer); 110 cleanModule(M); 111 return false; 112 } 113 114 char DXILTranslateMetadata::ID = 0; 115 116 ModulePass *llvm::createDXILTranslateMetadataPass() { 117 return new DXILTranslateMetadata(); 118 } 119 120 INITIALIZE_PASS(DXILTranslateMetadata, "dxil-metadata-emit", 121 "DXIL Metadata Emit", false, false) 122