1 //===- DXILPrepare.cpp - Prepare LLVM Module for DXIL encoding ------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 /// 9 /// \file This file contains pases and utilities to convert a modern LLVM 10 /// module into a module compatible with the LLVM 3.7-based DirectX Intermediate 11 /// Language (DXIL). 12 //===----------------------------------------------------------------------===// 13 14 #include "DXILRootSignature.h" 15 #include "DXILShaderFlags.h" 16 #include "DirectX.h" 17 #include "DirectXIRPasses/PointerTypeAnalysis.h" 18 #include "llvm/ADT/STLExtras.h" 19 #include "llvm/ADT/SmallVector.h" 20 #include "llvm/ADT/StringSet.h" 21 #include "llvm/Analysis/DXILMetadataAnalysis.h" 22 #include "llvm/Analysis/DXILResource.h" 23 #include "llvm/CodeGen/Passes.h" 24 #include "llvm/IR/AttributeMask.h" 25 #include "llvm/IR/IRBuilder.h" 26 #include "llvm/IR/Instruction.h" 27 #include "llvm/IR/IntrinsicInst.h" 28 #include "llvm/IR/Module.h" 29 #include "llvm/InitializePasses.h" 30 #include "llvm/Pass.h" 31 #include "llvm/Support/Compiler.h" 32 #include "llvm/Support/VersionTuple.h" 33 34 #define DEBUG_TYPE "dxil-prepare" 35 36 using namespace llvm; 37 using namespace llvm::dxil; 38 39 namespace { 40 41 constexpr bool isValidForDXIL(Attribute::AttrKind Attr) { 42 return is_contained({Attribute::Alignment, 43 Attribute::AlwaysInline, 44 Attribute::Builtin, 45 Attribute::ByVal, 46 Attribute::InAlloca, 47 Attribute::Cold, 48 Attribute::Convergent, 49 Attribute::InlineHint, 50 Attribute::InReg, 51 Attribute::JumpTable, 52 Attribute::MinSize, 53 Attribute::Naked, 54 Attribute::Nest, 55 Attribute::NoAlias, 56 Attribute::NoBuiltin, 57 Attribute::NoDuplicate, 58 Attribute::NoImplicitFloat, 59 Attribute::NoInline, 60 Attribute::NonLazyBind, 61 Attribute::NonNull, 62 Attribute::Dereferenceable, 63 Attribute::DereferenceableOrNull, 64 Attribute::Memory, 65 Attribute::NoRedZone, 66 Attribute::NoReturn, 67 Attribute::NoUnwind, 68 Attribute::OptimizeForSize, 69 Attribute::OptimizeNone, 70 Attribute::ReadNone, 71 Attribute::ReadOnly, 72 Attribute::Returned, 73 Attribute::ReturnsTwice, 74 Attribute::SExt, 75 Attribute::StackAlignment, 76 Attribute::StackProtect, 77 Attribute::StackProtectReq, 78 Attribute::StackProtectStrong, 79 Attribute::SafeStack, 80 Attribute::StructRet, 81 Attribute::SanitizeAddress, 82 Attribute::SanitizeThread, 83 Attribute::SanitizeMemory, 84 Attribute::UWTable, 85 Attribute::ZExt}, 86 Attr); 87 } 88 89 static void collectDeadStringAttrs(AttributeMask &DeadAttrs, AttributeSet &&AS, 90 const StringSet<> &LiveKeys, 91 bool AllowExperimental) { 92 for (auto &Attr : AS) { 93 if (!Attr.isStringAttribute()) 94 continue; 95 StringRef Key = Attr.getKindAsString(); 96 if (LiveKeys.contains(Key)) 97 continue; 98 if (AllowExperimental && Key.starts_with("exp-")) 99 continue; 100 DeadAttrs.addAttribute(Key); 101 } 102 } 103 104 static void removeStringFunctionAttributes(Function &F, 105 bool AllowExperimental) { 106 AttributeList Attrs = F.getAttributes(); 107 const StringSet<> LiveKeys = {"waveops-include-helper-lanes", 108 "fp32-denorm-mode"}; 109 // Collect DeadKeys in FnAttrs. 110 AttributeMask DeadAttrs; 111 collectDeadStringAttrs(DeadAttrs, Attrs.getFnAttrs(), LiveKeys, 112 AllowExperimental); 113 collectDeadStringAttrs(DeadAttrs, Attrs.getRetAttrs(), LiveKeys, 114 AllowExperimental); 115 116 F.removeFnAttrs(DeadAttrs); 117 F.removeRetAttrs(DeadAttrs); 118 } 119 120 static void cleanModuleFlags(Module &M) { 121 NamedMDNode *MDFlags = M.getModuleFlagsMetadata(); 122 if (!MDFlags) 123 return; 124 125 SmallVector<llvm::Module::ModuleFlagEntry> FlagEntries; 126 M.getModuleFlagsMetadata(FlagEntries); 127 bool Updated = false; 128 for (auto &Flag : FlagEntries) { 129 // llvm 3.7 only supports behavior up to AppendUnique. 130 if (Flag.Behavior <= Module::ModFlagBehavior::AppendUnique) 131 continue; 132 Flag.Behavior = Module::ModFlagBehavior::Warning; 133 Updated = true; 134 } 135 136 if (!Updated) 137 return; 138 139 MDFlags->eraseFromParent(); 140 141 for (auto &Flag : FlagEntries) 142 M.addModuleFlag(Flag.Behavior, Flag.Key->getString(), Flag.Val); 143 } 144 145 class DXILPrepareModule : public ModulePass { 146 147 static Value *maybeGenerateBitcast(IRBuilder<> &Builder, 148 PointerTypeMap &PointerTypes, 149 Instruction &Inst, Value *Operand, 150 Type *Ty) { 151 // Omit bitcasts if the incoming value matches the instruction type. 152 auto It = PointerTypes.find(Operand); 153 if (It != PointerTypes.end()) { 154 auto *OpTy = cast<TypedPointerType>(It->second)->getElementType(); 155 if (OpTy == Ty) 156 return nullptr; 157 } 158 159 Type *ValTy = Operand->getType(); 160 // Also omit the bitcast for matching global array types 161 if (auto *GlobalVar = dyn_cast<GlobalVariable>(Operand)) 162 ValTy = GlobalVar->getValueType(); 163 164 if (auto *AI = dyn_cast<AllocaInst>(Operand)) 165 ValTy = AI->getAllocatedType(); 166 167 if (auto *ArrTy = dyn_cast<ArrayType>(ValTy)) { 168 Type *ElTy = ArrTy->getElementType(); 169 if (ElTy == Ty) 170 return nullptr; 171 } 172 173 // finally, drill down GEP instructions until we get the array 174 // that is being accessed, and compare element types 175 if (ConstantExpr *GEPInstr = dyn_cast<ConstantExpr>(Operand)) { 176 while (GEPInstr->getOpcode() == Instruction::GetElementPtr) { 177 Value *OpArg = GEPInstr->getOperand(0); 178 if (ConstantExpr *NewGEPInstr = dyn_cast<ConstantExpr>(OpArg)) { 179 GEPInstr = NewGEPInstr; 180 continue; 181 } 182 183 if (auto *GlobalVar = dyn_cast<GlobalVariable>(OpArg)) 184 ValTy = GlobalVar->getValueType(); 185 if (auto *AI = dyn_cast<AllocaInst>(Operand)) 186 ValTy = AI->getAllocatedType(); 187 if (auto *ArrTy = dyn_cast<ArrayType>(ValTy)) { 188 Type *ElTy = ArrTy->getElementType(); 189 if (ElTy == Ty) 190 return nullptr; 191 } 192 break; 193 } 194 } 195 196 // Insert bitcasts where we are removing the instruction. 197 Builder.SetInsertPoint(&Inst); 198 // This code only gets hit in opaque-pointer mode, so the type of the 199 // pointer doesn't matter. 200 PointerType *PtrTy = cast<PointerType>(Operand->getType()); 201 return Builder.Insert( 202 CastInst::Create(Instruction::BitCast, Operand, 203 Builder.getPtrTy(PtrTy->getAddressSpace()))); 204 } 205 206 static std::array<unsigned, 6> getCompatibleInstructionMDs(llvm::Module &M) { 207 return {M.getMDKindID("dx.nonuniform"), 208 M.getMDKindID("dx.controlflow.hints"), 209 M.getMDKindID("dx.precise"), 210 llvm::LLVMContext::MD_range, 211 llvm::LLVMContext::MD_alias_scope, 212 llvm::LLVMContext::MD_noalias}; 213 } 214 215 public: 216 bool runOnModule(Module &M) override { 217 PointerTypeMap PointerTypes = PointerTypeAnalysis::run(M); 218 AttributeMask AttrMask; 219 for (Attribute::AttrKind I = Attribute::None; I != Attribute::EndAttrKinds; 220 I = Attribute::AttrKind(I + 1)) { 221 if (!isValidForDXIL(I)) 222 AttrMask.addAttribute(I); 223 } 224 225 const dxil::ModuleMetadataInfo MetadataInfo = 226 getAnalysis<DXILMetadataAnalysisWrapperPass>().getModuleMetadata(); 227 VersionTuple ValVer = MetadataInfo.ValidatorVersion; 228 bool SkipValidation = ValVer.getMajor() == 0 && ValVer.getMinor() == 0; 229 230 // construct allowlist of valid metadata node kinds 231 std::array<unsigned, 6> DXILCompatibleMDs = getCompatibleInstructionMDs(M); 232 233 for (auto &F : M.functions()) { 234 F.removeFnAttrs(AttrMask); 235 F.removeRetAttrs(AttrMask); 236 // Only remove string attributes if we are not skipping validation. 237 // This will reserve the experimental attributes when validation version 238 // is 0.0 for experiment mode. 239 removeStringFunctionAttributes(F, SkipValidation); 240 for (size_t Idx = 0, End = F.arg_size(); Idx < End; ++Idx) 241 F.removeParamAttrs(Idx, AttrMask); 242 243 // Lifetime intrinsics in LLVM 3.7 do not have the memory FnAttr 244 if (Intrinsic::ID IID = F.getIntrinsicID(); 245 IID == Intrinsic::lifetime_start || IID == Intrinsic::lifetime_end) 246 F.removeFnAttr(Attribute::Memory); 247 248 for (auto &BB : F) { 249 IRBuilder<> Builder(&BB); 250 for (auto &I : make_early_inc_range(BB)) { 251 252 I.dropUnknownNonDebugMetadata(DXILCompatibleMDs); 253 254 // Emtting NoOp bitcast instructions allows the ValueEnumerator to be 255 // unmodified as it reserves instruction IDs during contruction. 256 if (auto *LI = dyn_cast<LoadInst>(&I)) { 257 if (Value *NoOpBitcast = maybeGenerateBitcast( 258 Builder, PointerTypes, I, LI->getPointerOperand(), 259 LI->getType())) { 260 LI->replaceAllUsesWith( 261 Builder.CreateLoad(LI->getType(), NoOpBitcast)); 262 LI->eraseFromParent(); 263 } 264 continue; 265 } 266 if (auto *SI = dyn_cast<StoreInst>(&I)) { 267 if (Value *NoOpBitcast = maybeGenerateBitcast( 268 Builder, PointerTypes, I, SI->getPointerOperand(), 269 SI->getValueOperand()->getType())) { 270 271 SI->replaceAllUsesWith( 272 Builder.CreateStore(SI->getValueOperand(), NoOpBitcast)); 273 SI->eraseFromParent(); 274 } 275 continue; 276 } 277 if (auto *GEP = dyn_cast<GetElementPtrInst>(&I)) { 278 if (Value *NoOpBitcast = maybeGenerateBitcast( 279 Builder, PointerTypes, I, GEP->getPointerOperand(), 280 GEP->getSourceElementType())) 281 GEP->setOperand(0, NoOpBitcast); 282 continue; 283 } 284 if (auto *CB = dyn_cast<CallBase>(&I)) { 285 CB->removeFnAttrs(AttrMask); 286 CB->removeRetAttrs(AttrMask); 287 for (size_t Idx = 0, End = CB->arg_size(); Idx < End; ++Idx) 288 CB->removeParamAttrs(Idx, AttrMask); 289 // LLVM 3.7 Lifetime intrinics require an i8* pointer operand, so we 290 // insert a bitcast here to ensure that is the case 291 if (isa<LifetimeIntrinsic>(CB)) { 292 Value *PtrOperand = CB->getArgOperand(1); 293 Builder.SetInsertPoint(CB); 294 PointerType *PtrTy = cast<PointerType>(PtrOperand->getType()); 295 Value *NoOpBitcast = Builder.Insert( 296 CastInst::Create(Instruction::BitCast, PtrOperand, 297 Builder.getPtrTy(PtrTy->getAddressSpace()))); 298 CB->setArgOperand(1, NoOpBitcast); 299 } 300 continue; 301 } 302 } 303 } 304 } 305 // Remove flags not for DXIL. 306 cleanModuleFlags(M); 307 308 // dx.rootsignatures will have been parsed from its metadata form as its 309 // binary form as part of the RootSignatureAnalysisWrapper, so safely 310 // remove it as it is not recognized in DXIL 311 if (NamedMDNode *RootSignature = M.getNamedMetadata("dx.rootsignatures")) 312 RootSignature->eraseFromParent(); 313 314 return true; 315 } 316 317 DXILPrepareModule() : ModulePass(ID) {} 318 void getAnalysisUsage(AnalysisUsage &AU) const override { 319 AU.addRequired<DXILMetadataAnalysisWrapperPass>(); 320 AU.addRequired<RootSignatureAnalysisWrapper>(); 321 AU.addPreserved<RootSignatureAnalysisWrapper>(); 322 AU.addPreserved<ShaderFlagsAnalysisWrapper>(); 323 AU.addPreserved<DXILMetadataAnalysisWrapperPass>(); 324 AU.addPreserved<DXILResourceWrapperPass>(); 325 } 326 static char ID; // Pass identification. 327 }; 328 char DXILPrepareModule::ID = 0; 329 330 } // end anonymous namespace 331 332 INITIALIZE_PASS_BEGIN(DXILPrepareModule, DEBUG_TYPE, "DXIL Prepare Module", 333 false, false) 334 INITIALIZE_PASS_DEPENDENCY(DXILMetadataAnalysisWrapperPass) 335 INITIALIZE_PASS_DEPENDENCY(RootSignatureAnalysisWrapper) 336 INITIALIZE_PASS_END(DXILPrepareModule, DEBUG_TYPE, "DXIL Prepare Module", false, 337 false) 338 339 ModulePass *llvm::createDXILPrepareModulePass() { 340 return new DXILPrepareModule(); 341 } 342