xref: /freebsd/contrib/llvm-project/llvm/lib/Target/DirectX/DXILPrepare.cpp (revision 700637cbb5e582861067a11aaca4d053546871d2)
1 //===- DXILPrepare.cpp - Prepare LLVM Module for DXIL encoding ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
9 /// \file This file contains passes and utilities to convert a modern LLVM
10 /// module into a module compatible with the LLVM 3.7-based DirectX Intermediate
11 /// Language (DXIL).
12 //===----------------------------------------------------------------------===//
13 
14 #include "DXILRootSignature.h"
15 #include "DXILShaderFlags.h"
16 #include "DirectX.h"
17 #include "DirectXIRPasses/PointerTypeAnalysis.h"
18 #include "llvm/ADT/STLExtras.h"
19 #include "llvm/ADT/SmallVector.h"
20 #include "llvm/ADT/StringSet.h"
21 #include "llvm/Analysis/DXILMetadataAnalysis.h"
22 #include "llvm/Analysis/DXILResource.h"
23 #include "llvm/CodeGen/Passes.h"
24 #include "llvm/IR/AttributeMask.h"
25 #include "llvm/IR/IRBuilder.h"
26 #include "llvm/IR/Instruction.h"
27 #include "llvm/IR/IntrinsicInst.h"
28 #include "llvm/IR/Module.h"
29 #include "llvm/InitializePasses.h"
30 #include "llvm/Pass.h"
31 #include "llvm/Support/Compiler.h"
32 #include "llvm/Support/VersionTuple.h"
33 
34 #define DEBUG_TYPE "dxil-prepare"
35 
36 using namespace llvm;
37 using namespace llvm::dxil;
38 
39 namespace {
40 
isValidForDXIL(Attribute::AttrKind Attr)41 constexpr bool isValidForDXIL(Attribute::AttrKind Attr) {
42   return is_contained({Attribute::Alignment,
43                        Attribute::AlwaysInline,
44                        Attribute::Builtin,
45                        Attribute::ByVal,
46                        Attribute::InAlloca,
47                        Attribute::Cold,
48                        Attribute::Convergent,
49                        Attribute::InlineHint,
50                        Attribute::InReg,
51                        Attribute::JumpTable,
52                        Attribute::MinSize,
53                        Attribute::Naked,
54                        Attribute::Nest,
55                        Attribute::NoAlias,
56                        Attribute::NoBuiltin,
57                        Attribute::NoDuplicate,
58                        Attribute::NoImplicitFloat,
59                        Attribute::NoInline,
60                        Attribute::NonLazyBind,
61                        Attribute::NonNull,
62                        Attribute::Dereferenceable,
63                        Attribute::DereferenceableOrNull,
64                        Attribute::Memory,
65                        Attribute::NoRedZone,
66                        Attribute::NoReturn,
67                        Attribute::NoUnwind,
68                        Attribute::OptimizeForSize,
69                        Attribute::OptimizeNone,
70                        Attribute::ReadNone,
71                        Attribute::ReadOnly,
72                        Attribute::Returned,
73                        Attribute::ReturnsTwice,
74                        Attribute::SExt,
75                        Attribute::StackAlignment,
76                        Attribute::StackProtect,
77                        Attribute::StackProtectReq,
78                        Attribute::StackProtectStrong,
79                        Attribute::SafeStack,
80                        Attribute::StructRet,
81                        Attribute::SanitizeAddress,
82                        Attribute::SanitizeThread,
83                        Attribute::SanitizeMemory,
84                        Attribute::UWTable,
85                        Attribute::ZExt},
86                       Attr);
87 }
88 
collectDeadStringAttrs(AttributeMask & DeadAttrs,AttributeSet && AS,const StringSet<> & LiveKeys,bool AllowExperimental)89 static void collectDeadStringAttrs(AttributeMask &DeadAttrs, AttributeSet &&AS,
90                                    const StringSet<> &LiveKeys,
91                                    bool AllowExperimental) {
92   for (auto &Attr : AS) {
93     if (!Attr.isStringAttribute())
94       continue;
95     StringRef Key = Attr.getKindAsString();
96     if (LiveKeys.contains(Key))
97       continue;
98     if (AllowExperimental && Key.starts_with("exp-"))
99       continue;
100     DeadAttrs.addAttribute(Key);
101   }
102 }
103 
removeStringFunctionAttributes(Function & F,bool AllowExperimental)104 static void removeStringFunctionAttributes(Function &F,
105                                            bool AllowExperimental) {
106   AttributeList Attrs = F.getAttributes();
107   const StringSet<> LiveKeys = {"waveops-include-helper-lanes",
108                                 "fp32-denorm-mode"};
109   // Collect DeadKeys in FnAttrs.
110   AttributeMask DeadAttrs;
111   collectDeadStringAttrs(DeadAttrs, Attrs.getFnAttrs(), LiveKeys,
112                          AllowExperimental);
113   collectDeadStringAttrs(DeadAttrs, Attrs.getRetAttrs(), LiveKeys,
114                          AllowExperimental);
115 
116   F.removeFnAttrs(DeadAttrs);
117   F.removeRetAttrs(DeadAttrs);
118 }
119 
cleanModuleFlags(Module & M)120 static void cleanModuleFlags(Module &M) {
121   NamedMDNode *MDFlags = M.getModuleFlagsMetadata();
122   if (!MDFlags)
123     return;
124 
125   SmallVector<llvm::Module::ModuleFlagEntry> FlagEntries;
126   M.getModuleFlagsMetadata(FlagEntries);
127   bool Updated = false;
128   for (auto &Flag : FlagEntries) {
129     // llvm 3.7 only supports behavior up to AppendUnique.
130     if (Flag.Behavior <= Module::ModFlagBehavior::AppendUnique)
131       continue;
132     Flag.Behavior = Module::ModFlagBehavior::Warning;
133     Updated = true;
134   }
135 
136   if (!Updated)
137     return;
138 
139   MDFlags->eraseFromParent();
140 
141   for (auto &Flag : FlagEntries)
142     M.addModuleFlag(Flag.Behavior, Flag.Key->getString(), Flag.Val);
143 }
144 
145 class DXILPrepareModule : public ModulePass {
146 
maybeGenerateBitcast(IRBuilder<> & Builder,PointerTypeMap & PointerTypes,Instruction & Inst,Value * Operand,Type * Ty)147   static Value *maybeGenerateBitcast(IRBuilder<> &Builder,
148                                      PointerTypeMap &PointerTypes,
149                                      Instruction &Inst, Value *Operand,
150                                      Type *Ty) {
151     // Omit bitcasts if the incoming value matches the instruction type.
152     auto It = PointerTypes.find(Operand);
153     if (It != PointerTypes.end()) {
154       auto *OpTy = cast<TypedPointerType>(It->second)->getElementType();
155       if (OpTy == Ty)
156         return nullptr;
157     }
158 
159     Type *ValTy = Operand->getType();
160     // Also omit the bitcast for matching global array types
161     if (auto *GlobalVar = dyn_cast<GlobalVariable>(Operand))
162       ValTy = GlobalVar->getValueType();
163 
164     if (auto *AI = dyn_cast<AllocaInst>(Operand))
165       ValTy = AI->getAllocatedType();
166 
167     if (auto *ArrTy = dyn_cast<ArrayType>(ValTy)) {
168       Type *ElTy = ArrTy->getElementType();
169       if (ElTy == Ty)
170         return nullptr;
171     }
172 
173     // finally, drill down GEP instructions until we get the array
174     // that is being accessed, and compare element types
175     if (ConstantExpr *GEPInstr = dyn_cast<ConstantExpr>(Operand)) {
176       while (GEPInstr->getOpcode() == Instruction::GetElementPtr) {
177         Value *OpArg = GEPInstr->getOperand(0);
178         if (ConstantExpr *NewGEPInstr = dyn_cast<ConstantExpr>(OpArg)) {
179           GEPInstr = NewGEPInstr;
180           continue;
181         }
182 
183         if (auto *GlobalVar = dyn_cast<GlobalVariable>(OpArg))
184           ValTy = GlobalVar->getValueType();
185         if (auto *AI = dyn_cast<AllocaInst>(Operand))
186           ValTy = AI->getAllocatedType();
187         if (auto *ArrTy = dyn_cast<ArrayType>(ValTy)) {
188           Type *ElTy = ArrTy->getElementType();
189           if (ElTy == Ty)
190             return nullptr;
191         }
192         break;
193       }
194     }
195 
196     // Insert bitcasts where we are removing the instruction.
197     Builder.SetInsertPoint(&Inst);
198     // This code only gets hit in opaque-pointer mode, so the type of the
199     // pointer doesn't matter.
200     PointerType *PtrTy = cast<PointerType>(Operand->getType());
201     return Builder.Insert(
202         CastInst::Create(Instruction::BitCast, Operand,
203                          Builder.getPtrTy(PtrTy->getAddressSpace())));
204   }
205 
getCompatibleInstructionMDs(llvm::Module & M)206   static std::array<unsigned, 6> getCompatibleInstructionMDs(llvm::Module &M) {
207     return {M.getMDKindID("dx.nonuniform"),
208             M.getMDKindID("dx.controlflow.hints"),
209             M.getMDKindID("dx.precise"),
210             llvm::LLVMContext::MD_range,
211             llvm::LLVMContext::MD_alias_scope,
212             llvm::LLVMContext::MD_noalias};
213   }
214 
215 public:
runOnModule(Module & M)216   bool runOnModule(Module &M) override {
217     PointerTypeMap PointerTypes = PointerTypeAnalysis::run(M);
218     AttributeMask AttrMask;
219     for (Attribute::AttrKind I = Attribute::None; I != Attribute::EndAttrKinds;
220          I = Attribute::AttrKind(I + 1)) {
221       if (!isValidForDXIL(I))
222         AttrMask.addAttribute(I);
223     }
224 
225     const dxil::ModuleMetadataInfo MetadataInfo =
226         getAnalysis<DXILMetadataAnalysisWrapperPass>().getModuleMetadata();
227     VersionTuple ValVer = MetadataInfo.ValidatorVersion;
228     bool SkipValidation = ValVer.getMajor() == 0 && ValVer.getMinor() == 0;
229 
230     // construct allowlist of valid metadata node kinds
231     std::array<unsigned, 6> DXILCompatibleMDs = getCompatibleInstructionMDs(M);
232 
233     for (auto &F : M.functions()) {
234       F.removeFnAttrs(AttrMask);
235       F.removeRetAttrs(AttrMask);
236       // Only remove string attributes if we are not skipping validation.
237       // This will reserve the experimental attributes when validation version
238       // is 0.0 for experiment mode.
239       removeStringFunctionAttributes(F, SkipValidation);
240       for (size_t Idx = 0, End = F.arg_size(); Idx < End; ++Idx)
241         F.removeParamAttrs(Idx, AttrMask);
242 
243       // Lifetime intrinsics in LLVM 3.7 do not have the memory FnAttr
244       if (Intrinsic::ID IID = F.getIntrinsicID();
245           IID == Intrinsic::lifetime_start || IID == Intrinsic::lifetime_end)
246         F.removeFnAttr(Attribute::Memory);
247 
248       for (auto &BB : F) {
249         IRBuilder<> Builder(&BB);
250         for (auto &I : make_early_inc_range(BB)) {
251 
252           I.dropUnknownNonDebugMetadata(DXILCompatibleMDs);
253 
254           // Emtting NoOp bitcast instructions allows the ValueEnumerator to be
255           // unmodified as it reserves instruction IDs during contruction.
256           if (auto *LI = dyn_cast<LoadInst>(&I)) {
257             if (Value *NoOpBitcast = maybeGenerateBitcast(
258                     Builder, PointerTypes, I, LI->getPointerOperand(),
259                     LI->getType())) {
260               LI->replaceAllUsesWith(
261                   Builder.CreateLoad(LI->getType(), NoOpBitcast));
262               LI->eraseFromParent();
263             }
264             continue;
265           }
266           if (auto *SI = dyn_cast<StoreInst>(&I)) {
267             if (Value *NoOpBitcast = maybeGenerateBitcast(
268                     Builder, PointerTypes, I, SI->getPointerOperand(),
269                     SI->getValueOperand()->getType())) {
270 
271               SI->replaceAllUsesWith(
272                   Builder.CreateStore(SI->getValueOperand(), NoOpBitcast));
273               SI->eraseFromParent();
274             }
275             continue;
276           }
277           if (auto *GEP = dyn_cast<GetElementPtrInst>(&I)) {
278             if (Value *NoOpBitcast = maybeGenerateBitcast(
279                     Builder, PointerTypes, I, GEP->getPointerOperand(),
280                     GEP->getSourceElementType()))
281               GEP->setOperand(0, NoOpBitcast);
282             continue;
283           }
284           if (auto *CB = dyn_cast<CallBase>(&I)) {
285             CB->removeFnAttrs(AttrMask);
286             CB->removeRetAttrs(AttrMask);
287             for (size_t Idx = 0, End = CB->arg_size(); Idx < End; ++Idx)
288               CB->removeParamAttrs(Idx, AttrMask);
289             // LLVM 3.7 Lifetime intrinics require an i8* pointer operand, so we
290             // insert a bitcast here to ensure that is the case
291             if (isa<LifetimeIntrinsic>(CB)) {
292               Value *PtrOperand = CB->getArgOperand(1);
293               Builder.SetInsertPoint(CB);
294               PointerType *PtrTy = cast<PointerType>(PtrOperand->getType());
295               Value *NoOpBitcast = Builder.Insert(
296                   CastInst::Create(Instruction::BitCast, PtrOperand,
297                                    Builder.getPtrTy(PtrTy->getAddressSpace())));
298               CB->setArgOperand(1, NoOpBitcast);
299             }
300             continue;
301           }
302         }
303       }
304     }
305     // Remove flags not for DXIL.
306     cleanModuleFlags(M);
307 
308     // dx.rootsignatures will have been parsed from its metadata form as its
309     // binary form as part of the RootSignatureAnalysisWrapper, so safely
310     // remove it as it is not recognized in DXIL
311     if (NamedMDNode *RootSignature = M.getNamedMetadata("dx.rootsignatures"))
312       RootSignature->eraseFromParent();
313 
314     return true;
315   }
316 
DXILPrepareModule()317   DXILPrepareModule() : ModulePass(ID) {}
getAnalysisUsage(AnalysisUsage & AU) const318   void getAnalysisUsage(AnalysisUsage &AU) const override {
319     AU.addRequired<DXILMetadataAnalysisWrapperPass>();
320     AU.addRequired<RootSignatureAnalysisWrapper>();
321     AU.addPreserved<RootSignatureAnalysisWrapper>();
322     AU.addPreserved<ShaderFlagsAnalysisWrapper>();
323     AU.addPreserved<DXILMetadataAnalysisWrapperPass>();
324     AU.addPreserved<DXILResourceWrapperPass>();
325   }
326   static char ID; // Pass identification.
327 };
328 char DXILPrepareModule::ID = 0;
329 
330 } // end anonymous namespace
331 
332 INITIALIZE_PASS_BEGIN(DXILPrepareModule, DEBUG_TYPE, "DXIL Prepare Module",
333                       false, false)
INITIALIZE_PASS_DEPENDENCY(DXILMetadataAnalysisWrapperPass)334 INITIALIZE_PASS_DEPENDENCY(DXILMetadataAnalysisWrapperPass)
335 INITIALIZE_PASS_DEPENDENCY(RootSignatureAnalysisWrapper)
336 INITIALIZE_PASS_END(DXILPrepareModule, DEBUG_TYPE, "DXIL Prepare Module", false,
337                     false)
338 
339 ModulePass *llvm::createDXILPrepareModulePass() {
340   return new DXILPrepareModule();
341 }
342