1 //===- DXILPrepare.cpp - Prepare LLVM Module for DXIL encoding ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
/// \file This file contains passes and utilities to convert a modern LLVM
10 /// module into a module compatible with the LLVM 3.7-based DirectX Intermediate
11 /// Language (DXIL).
12 //===----------------------------------------------------------------------===//
13
14 #include "DXILRootSignature.h"
15 #include "DXILShaderFlags.h"
16 #include "DirectX.h"
17 #include "DirectXIRPasses/PointerTypeAnalysis.h"
18 #include "llvm/ADT/STLExtras.h"
19 #include "llvm/ADT/SmallVector.h"
20 #include "llvm/ADT/StringSet.h"
21 #include "llvm/Analysis/DXILMetadataAnalysis.h"
22 #include "llvm/Analysis/DXILResource.h"
23 #include "llvm/CodeGen/Passes.h"
24 #include "llvm/IR/AttributeMask.h"
25 #include "llvm/IR/IRBuilder.h"
26 #include "llvm/IR/Instruction.h"
27 #include "llvm/IR/IntrinsicInst.h"
28 #include "llvm/IR/Module.h"
29 #include "llvm/InitializePasses.h"
30 #include "llvm/Pass.h"
31 #include "llvm/Support/Compiler.h"
32 #include "llvm/Support/VersionTuple.h"
33
34 #define DEBUG_TYPE "dxil-prepare"
35
36 using namespace llvm;
37 using namespace llvm::dxil;
38
39 namespace {
40
isValidForDXIL(Attribute::AttrKind Attr)41 constexpr bool isValidForDXIL(Attribute::AttrKind Attr) {
42 return is_contained({Attribute::Alignment,
43 Attribute::AlwaysInline,
44 Attribute::Builtin,
45 Attribute::ByVal,
46 Attribute::InAlloca,
47 Attribute::Cold,
48 Attribute::Convergent,
49 Attribute::InlineHint,
50 Attribute::InReg,
51 Attribute::JumpTable,
52 Attribute::MinSize,
53 Attribute::Naked,
54 Attribute::Nest,
55 Attribute::NoAlias,
56 Attribute::NoBuiltin,
57 Attribute::NoDuplicate,
58 Attribute::NoImplicitFloat,
59 Attribute::NoInline,
60 Attribute::NonLazyBind,
61 Attribute::NonNull,
62 Attribute::Dereferenceable,
63 Attribute::DereferenceableOrNull,
64 Attribute::Memory,
65 Attribute::NoRedZone,
66 Attribute::NoReturn,
67 Attribute::NoUnwind,
68 Attribute::OptimizeForSize,
69 Attribute::OptimizeNone,
70 Attribute::ReadNone,
71 Attribute::ReadOnly,
72 Attribute::Returned,
73 Attribute::ReturnsTwice,
74 Attribute::SExt,
75 Attribute::StackAlignment,
76 Attribute::StackProtect,
77 Attribute::StackProtectReq,
78 Attribute::StackProtectStrong,
79 Attribute::SafeStack,
80 Attribute::StructRet,
81 Attribute::SanitizeAddress,
82 Attribute::SanitizeThread,
83 Attribute::SanitizeMemory,
84 Attribute::UWTable,
85 Attribute::ZExt},
86 Attr);
87 }
88
collectDeadStringAttrs(AttributeMask & DeadAttrs,AttributeSet && AS,const StringSet<> & LiveKeys,bool AllowExperimental)89 static void collectDeadStringAttrs(AttributeMask &DeadAttrs, AttributeSet &&AS,
90 const StringSet<> &LiveKeys,
91 bool AllowExperimental) {
92 for (auto &Attr : AS) {
93 if (!Attr.isStringAttribute())
94 continue;
95 StringRef Key = Attr.getKindAsString();
96 if (LiveKeys.contains(Key))
97 continue;
98 if (AllowExperimental && Key.starts_with("exp-"))
99 continue;
100 DeadAttrs.addAttribute(Key);
101 }
102 }
103
removeStringFunctionAttributes(Function & F,bool AllowExperimental)104 static void removeStringFunctionAttributes(Function &F,
105 bool AllowExperimental) {
106 AttributeList Attrs = F.getAttributes();
107 const StringSet<> LiveKeys = {"waveops-include-helper-lanes",
108 "fp32-denorm-mode"};
109 // Collect DeadKeys in FnAttrs.
110 AttributeMask DeadAttrs;
111 collectDeadStringAttrs(DeadAttrs, Attrs.getFnAttrs(), LiveKeys,
112 AllowExperimental);
113 collectDeadStringAttrs(DeadAttrs, Attrs.getRetAttrs(), LiveKeys,
114 AllowExperimental);
115
116 F.removeFnAttrs(DeadAttrs);
117 F.removeRetAttrs(DeadAttrs);
118 }
119
cleanModuleFlags(Module & M)120 static void cleanModuleFlags(Module &M) {
121 NamedMDNode *MDFlags = M.getModuleFlagsMetadata();
122 if (!MDFlags)
123 return;
124
125 SmallVector<llvm::Module::ModuleFlagEntry> FlagEntries;
126 M.getModuleFlagsMetadata(FlagEntries);
127 bool Updated = false;
128 for (auto &Flag : FlagEntries) {
129 // llvm 3.7 only supports behavior up to AppendUnique.
130 if (Flag.Behavior <= Module::ModFlagBehavior::AppendUnique)
131 continue;
132 Flag.Behavior = Module::ModFlagBehavior::Warning;
133 Updated = true;
134 }
135
136 if (!Updated)
137 return;
138
139 MDFlags->eraseFromParent();
140
141 for (auto &Flag : FlagEntries)
142 M.addModuleFlag(Flag.Behavior, Flag.Key->getString(), Flag.Val);
143 }
144
145 class DXILPrepareModule : public ModulePass {
146
maybeGenerateBitcast(IRBuilder<> & Builder,PointerTypeMap & PointerTypes,Instruction & Inst,Value * Operand,Type * Ty)147 static Value *maybeGenerateBitcast(IRBuilder<> &Builder,
148 PointerTypeMap &PointerTypes,
149 Instruction &Inst, Value *Operand,
150 Type *Ty) {
151 // Omit bitcasts if the incoming value matches the instruction type.
152 auto It = PointerTypes.find(Operand);
153 if (It != PointerTypes.end()) {
154 auto *OpTy = cast<TypedPointerType>(It->second)->getElementType();
155 if (OpTy == Ty)
156 return nullptr;
157 }
158
159 Type *ValTy = Operand->getType();
160 // Also omit the bitcast for matching global array types
161 if (auto *GlobalVar = dyn_cast<GlobalVariable>(Operand))
162 ValTy = GlobalVar->getValueType();
163
164 if (auto *AI = dyn_cast<AllocaInst>(Operand))
165 ValTy = AI->getAllocatedType();
166
167 if (auto *ArrTy = dyn_cast<ArrayType>(ValTy)) {
168 Type *ElTy = ArrTy->getElementType();
169 if (ElTy == Ty)
170 return nullptr;
171 }
172
173 // finally, drill down GEP instructions until we get the array
174 // that is being accessed, and compare element types
175 if (ConstantExpr *GEPInstr = dyn_cast<ConstantExpr>(Operand)) {
176 while (GEPInstr->getOpcode() == Instruction::GetElementPtr) {
177 Value *OpArg = GEPInstr->getOperand(0);
178 if (ConstantExpr *NewGEPInstr = dyn_cast<ConstantExpr>(OpArg)) {
179 GEPInstr = NewGEPInstr;
180 continue;
181 }
182
183 if (auto *GlobalVar = dyn_cast<GlobalVariable>(OpArg))
184 ValTy = GlobalVar->getValueType();
185 if (auto *AI = dyn_cast<AllocaInst>(Operand))
186 ValTy = AI->getAllocatedType();
187 if (auto *ArrTy = dyn_cast<ArrayType>(ValTy)) {
188 Type *ElTy = ArrTy->getElementType();
189 if (ElTy == Ty)
190 return nullptr;
191 }
192 break;
193 }
194 }
195
196 // Insert bitcasts where we are removing the instruction.
197 Builder.SetInsertPoint(&Inst);
198 // This code only gets hit in opaque-pointer mode, so the type of the
199 // pointer doesn't matter.
200 PointerType *PtrTy = cast<PointerType>(Operand->getType());
201 return Builder.Insert(
202 CastInst::Create(Instruction::BitCast, Operand,
203 Builder.getPtrTy(PtrTy->getAddressSpace())));
204 }
205
getCompatibleInstructionMDs(llvm::Module & M)206 static std::array<unsigned, 6> getCompatibleInstructionMDs(llvm::Module &M) {
207 return {M.getMDKindID("dx.nonuniform"),
208 M.getMDKindID("dx.controlflow.hints"),
209 M.getMDKindID("dx.precise"),
210 llvm::LLVMContext::MD_range,
211 llvm::LLVMContext::MD_alias_scope,
212 llvm::LLVMContext::MD_noalias};
213 }
214
215 public:
runOnModule(Module & M)216 bool runOnModule(Module &M) override {
217 PointerTypeMap PointerTypes = PointerTypeAnalysis::run(M);
218 AttributeMask AttrMask;
219 for (Attribute::AttrKind I = Attribute::None; I != Attribute::EndAttrKinds;
220 I = Attribute::AttrKind(I + 1)) {
221 if (!isValidForDXIL(I))
222 AttrMask.addAttribute(I);
223 }
224
225 const dxil::ModuleMetadataInfo MetadataInfo =
226 getAnalysis<DXILMetadataAnalysisWrapperPass>().getModuleMetadata();
227 VersionTuple ValVer = MetadataInfo.ValidatorVersion;
228 bool SkipValidation = ValVer.getMajor() == 0 && ValVer.getMinor() == 0;
229
230 // construct allowlist of valid metadata node kinds
231 std::array<unsigned, 6> DXILCompatibleMDs = getCompatibleInstructionMDs(M);
232
233 for (auto &F : M.functions()) {
234 F.removeFnAttrs(AttrMask);
235 F.removeRetAttrs(AttrMask);
236 // Only remove string attributes if we are not skipping validation.
237 // This will reserve the experimental attributes when validation version
238 // is 0.0 for experiment mode.
239 removeStringFunctionAttributes(F, SkipValidation);
240 for (size_t Idx = 0, End = F.arg_size(); Idx < End; ++Idx)
241 F.removeParamAttrs(Idx, AttrMask);
242
243 // Lifetime intrinsics in LLVM 3.7 do not have the memory FnAttr
244 if (Intrinsic::ID IID = F.getIntrinsicID();
245 IID == Intrinsic::lifetime_start || IID == Intrinsic::lifetime_end)
246 F.removeFnAttr(Attribute::Memory);
247
248 for (auto &BB : F) {
249 IRBuilder<> Builder(&BB);
250 for (auto &I : make_early_inc_range(BB)) {
251
252 I.dropUnknownNonDebugMetadata(DXILCompatibleMDs);
253
254 // Emtting NoOp bitcast instructions allows the ValueEnumerator to be
255 // unmodified as it reserves instruction IDs during contruction.
256 if (auto *LI = dyn_cast<LoadInst>(&I)) {
257 if (Value *NoOpBitcast = maybeGenerateBitcast(
258 Builder, PointerTypes, I, LI->getPointerOperand(),
259 LI->getType())) {
260 LI->replaceAllUsesWith(
261 Builder.CreateLoad(LI->getType(), NoOpBitcast));
262 LI->eraseFromParent();
263 }
264 continue;
265 }
266 if (auto *SI = dyn_cast<StoreInst>(&I)) {
267 if (Value *NoOpBitcast = maybeGenerateBitcast(
268 Builder, PointerTypes, I, SI->getPointerOperand(),
269 SI->getValueOperand()->getType())) {
270
271 SI->replaceAllUsesWith(
272 Builder.CreateStore(SI->getValueOperand(), NoOpBitcast));
273 SI->eraseFromParent();
274 }
275 continue;
276 }
277 if (auto *GEP = dyn_cast<GetElementPtrInst>(&I)) {
278 if (Value *NoOpBitcast = maybeGenerateBitcast(
279 Builder, PointerTypes, I, GEP->getPointerOperand(),
280 GEP->getSourceElementType()))
281 GEP->setOperand(0, NoOpBitcast);
282 continue;
283 }
284 if (auto *CB = dyn_cast<CallBase>(&I)) {
285 CB->removeFnAttrs(AttrMask);
286 CB->removeRetAttrs(AttrMask);
287 for (size_t Idx = 0, End = CB->arg_size(); Idx < End; ++Idx)
288 CB->removeParamAttrs(Idx, AttrMask);
289 // LLVM 3.7 Lifetime intrinics require an i8* pointer operand, so we
290 // insert a bitcast here to ensure that is the case
291 if (isa<LifetimeIntrinsic>(CB)) {
292 Value *PtrOperand = CB->getArgOperand(1);
293 Builder.SetInsertPoint(CB);
294 PointerType *PtrTy = cast<PointerType>(PtrOperand->getType());
295 Value *NoOpBitcast = Builder.Insert(
296 CastInst::Create(Instruction::BitCast, PtrOperand,
297 Builder.getPtrTy(PtrTy->getAddressSpace())));
298 CB->setArgOperand(1, NoOpBitcast);
299 }
300 continue;
301 }
302 }
303 }
304 }
305 // Remove flags not for DXIL.
306 cleanModuleFlags(M);
307
308 // dx.rootsignatures will have been parsed from its metadata form as its
309 // binary form as part of the RootSignatureAnalysisWrapper, so safely
310 // remove it as it is not recognized in DXIL
311 if (NamedMDNode *RootSignature = M.getNamedMetadata("dx.rootsignatures"))
312 RootSignature->eraseFromParent();
313
314 return true;
315 }
316
DXILPrepareModule()317 DXILPrepareModule() : ModulePass(ID) {}
getAnalysisUsage(AnalysisUsage & AU) const318 void getAnalysisUsage(AnalysisUsage &AU) const override {
319 AU.addRequired<DXILMetadataAnalysisWrapperPass>();
320 AU.addRequired<RootSignatureAnalysisWrapper>();
321 AU.addPreserved<RootSignatureAnalysisWrapper>();
322 AU.addPreserved<ShaderFlagsAnalysisWrapper>();
323 AU.addPreserved<DXILMetadataAnalysisWrapperPass>();
324 AU.addPreserved<DXILResourceWrapperPass>();
325 }
326 static char ID; // Pass identification.
327 };
328 char DXILPrepareModule::ID = 0;
329
330 } // end anonymous namespace
331
332 INITIALIZE_PASS_BEGIN(DXILPrepareModule, DEBUG_TYPE, "DXIL Prepare Module",
333 false, false)
INITIALIZE_PASS_DEPENDENCY(DXILMetadataAnalysisWrapperPass)334 INITIALIZE_PASS_DEPENDENCY(DXILMetadataAnalysisWrapperPass)
335 INITIALIZE_PASS_DEPENDENCY(RootSignatureAnalysisWrapper)
336 INITIALIZE_PASS_END(DXILPrepareModule, DEBUG_TYPE, "DXIL Prepare Module", false,
337 false)
338
339 ModulePass *llvm::createDXILPrepareModulePass() {
340 return new DXILPrepareModule();
341 }
342