xref: /freebsd/contrib/llvm-project/llvm/lib/Target/DirectX/DXILPrepare.cpp (revision 0fca6ea1d4eea4c934cfff25ac9ee8ad6fe95583)
1 //===- DXILPrepare.cpp - Prepare LLVM Module for DXIL encoding ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
/// \file This file contains passes and utilities to convert a modern LLVM
10 /// module into a module compatible with the LLVM 3.7-based DirectX Intermediate
11 /// Language (DXIL).
12 //===----------------------------------------------------------------------===//
13 
14 #include "DXILMetadata.h"
15 #include "DXILResourceAnalysis.h"
16 #include "DXILShaderFlags.h"
17 #include "DirectX.h"
18 #include "DirectXIRPasses/PointerTypeAnalysis.h"
19 #include "llvm/ADT/STLExtras.h"
20 #include "llvm/ADT/SmallVector.h"
21 #include "llvm/ADT/StringSet.h"
22 #include "llvm/CodeGen/Passes.h"
23 #include "llvm/IR/AttributeMask.h"
24 #include "llvm/IR/IRBuilder.h"
25 #include "llvm/IR/Instruction.h"
26 #include "llvm/IR/Module.h"
27 #include "llvm/InitializePasses.h"
28 #include "llvm/Pass.h"
29 #include "llvm/Support/Compiler.h"
30 #include "llvm/Support/VersionTuple.h"
31 
32 #define DEBUG_TYPE "dxil-prepare"
33 
34 using namespace llvm;
35 using namespace llvm::dxil;
36 
37 namespace {
38 
isValidForDXIL(Attribute::AttrKind Attr)39 constexpr bool isValidForDXIL(Attribute::AttrKind Attr) {
40   return is_contained({Attribute::Alignment,
41                        Attribute::AlwaysInline,
42                        Attribute::Builtin,
43                        Attribute::ByVal,
44                        Attribute::InAlloca,
45                        Attribute::Cold,
46                        Attribute::Convergent,
47                        Attribute::InlineHint,
48                        Attribute::InReg,
49                        Attribute::JumpTable,
50                        Attribute::MinSize,
51                        Attribute::Naked,
52                        Attribute::Nest,
53                        Attribute::NoAlias,
54                        Attribute::NoBuiltin,
55                        Attribute::NoCapture,
56                        Attribute::NoDuplicate,
57                        Attribute::NoImplicitFloat,
58                        Attribute::NoInline,
59                        Attribute::NonLazyBind,
60                        Attribute::NonNull,
61                        Attribute::Dereferenceable,
62                        Attribute::DereferenceableOrNull,
63                        Attribute::Memory,
64                        Attribute::NoRedZone,
65                        Attribute::NoReturn,
66                        Attribute::NoUnwind,
67                        Attribute::OptimizeForSize,
68                        Attribute::OptimizeNone,
69                        Attribute::ReadNone,
70                        Attribute::ReadOnly,
71                        Attribute::Returned,
72                        Attribute::ReturnsTwice,
73                        Attribute::SExt,
74                        Attribute::StackAlignment,
75                        Attribute::StackProtect,
76                        Attribute::StackProtectReq,
77                        Attribute::StackProtectStrong,
78                        Attribute::SafeStack,
79                        Attribute::StructRet,
80                        Attribute::SanitizeAddress,
81                        Attribute::SanitizeThread,
82                        Attribute::SanitizeMemory,
83                        Attribute::UWTable,
84                        Attribute::ZExt},
85                       Attr);
86 }
87 
collectDeadStringAttrs(AttributeMask & DeadAttrs,AttributeSet && AS,const StringSet<> & LiveKeys,bool AllowExperimental)88 static void collectDeadStringAttrs(AttributeMask &DeadAttrs, AttributeSet &&AS,
89                                    const StringSet<> &LiveKeys,
90                                    bool AllowExperimental) {
91   for (auto &Attr : AS) {
92     if (!Attr.isStringAttribute())
93       continue;
94     StringRef Key = Attr.getKindAsString();
95     if (LiveKeys.contains(Key))
96       continue;
97     if (AllowExperimental && Key.starts_with("exp-"))
98       continue;
99     DeadAttrs.addAttribute(Key);
100   }
101 }
102 
removeStringFunctionAttributes(Function & F,bool AllowExperimental)103 static void removeStringFunctionAttributes(Function &F,
104                                            bool AllowExperimental) {
105   AttributeList Attrs = F.getAttributes();
106   const StringSet<> LiveKeys = {"waveops-include-helper-lanes",
107                                 "fp32-denorm-mode"};
108   // Collect DeadKeys in FnAttrs.
109   AttributeMask DeadAttrs;
110   collectDeadStringAttrs(DeadAttrs, Attrs.getFnAttrs(), LiveKeys,
111                          AllowExperimental);
112   collectDeadStringAttrs(DeadAttrs, Attrs.getRetAttrs(), LiveKeys,
113                          AllowExperimental);
114 
115   F.removeFnAttrs(DeadAttrs);
116   F.removeRetAttrs(DeadAttrs);
117 }
118 
cleanModuleFlags(Module & M)119 static void cleanModuleFlags(Module &M) {
120   NamedMDNode *MDFlags = M.getModuleFlagsMetadata();
121   if (!MDFlags)
122     return;
123 
124   SmallVector<llvm::Module::ModuleFlagEntry> FlagEntries;
125   M.getModuleFlagsMetadata(FlagEntries);
126   bool Updated = false;
127   for (auto &Flag : FlagEntries) {
128     // llvm 3.7 only supports behavior up to AppendUnique.
129     if (Flag.Behavior <= Module::ModFlagBehavior::AppendUnique)
130       continue;
131     Flag.Behavior = Module::ModFlagBehavior::Warning;
132     Updated = true;
133   }
134 
135   if (!Updated)
136     return;
137 
138   MDFlags->eraseFromParent();
139 
140   for (auto &Flag : FlagEntries)
141     M.addModuleFlag(Flag.Behavior, Flag.Key->getString(), Flag.Val);
142 }
143 
144 class DXILPrepareModule : public ModulePass {
145 
maybeGenerateBitcast(IRBuilder<> & Builder,PointerTypeMap & PointerTypes,Instruction & Inst,Value * Operand,Type * Ty)146   static Value *maybeGenerateBitcast(IRBuilder<> &Builder,
147                                      PointerTypeMap &PointerTypes,
148                                      Instruction &Inst, Value *Operand,
149                                      Type *Ty) {
150     // Omit bitcasts if the incoming value matches the instruction type.
151     auto It = PointerTypes.find(Operand);
152     if (It != PointerTypes.end())
153       if (cast<TypedPointerType>(It->second)->getElementType() == Ty)
154         return nullptr;
155     // Insert bitcasts where we are removing the instruction.
156     Builder.SetInsertPoint(&Inst);
157     // This code only gets hit in opaque-pointer mode, so the type of the
158     // pointer doesn't matter.
159     PointerType *PtrTy = cast<PointerType>(Operand->getType());
160     return Builder.Insert(
161         CastInst::Create(Instruction::BitCast, Operand,
162                          Builder.getPtrTy(PtrTy->getAddressSpace())));
163   }
164 
165 public:
runOnModule(Module & M)166   bool runOnModule(Module &M) override {
167     PointerTypeMap PointerTypes = PointerTypeAnalysis::run(M);
168     AttributeMask AttrMask;
169     for (Attribute::AttrKind I = Attribute::None; I != Attribute::EndAttrKinds;
170          I = Attribute::AttrKind(I + 1)) {
171       if (!isValidForDXIL(I))
172         AttrMask.addAttribute(I);
173     }
174 
175     dxil::ValidatorVersionMD ValVerMD(M);
176     VersionTuple ValVer = ValVerMD.getAsVersionTuple();
177     bool SkipValidation = ValVer.getMajor() == 0 && ValVer.getMinor() == 0;
178 
179     for (auto &F : M.functions()) {
180       F.removeFnAttrs(AttrMask);
181       F.removeRetAttrs(AttrMask);
182       // Only remove string attributes if we are not skipping validation.
183       // This will reserve the experimental attributes when validation version
184       // is 0.0 for experiment mode.
185       removeStringFunctionAttributes(F, SkipValidation);
186       for (size_t Idx = 0, End = F.arg_size(); Idx < End; ++Idx)
187         F.removeParamAttrs(Idx, AttrMask);
188 
189       for (auto &BB : F) {
190         IRBuilder<> Builder(&BB);
191         for (auto &I : make_early_inc_range(BB)) {
192           if (I.getOpcode() == Instruction::FNeg) {
193             Builder.SetInsertPoint(&I);
194             Value *In = I.getOperand(0);
195             Value *Zero = ConstantFP::get(In->getType(), -0.0);
196             I.replaceAllUsesWith(Builder.CreateFSub(Zero, In));
197             I.eraseFromParent();
198             continue;
199           }
200 
201           // Emtting NoOp bitcast instructions allows the ValueEnumerator to be
202           // unmodified as it reserves instruction IDs during contruction.
203           if (auto LI = dyn_cast<LoadInst>(&I)) {
204             if (Value *NoOpBitcast = maybeGenerateBitcast(
205                     Builder, PointerTypes, I, LI->getPointerOperand(),
206                     LI->getType())) {
207               LI->replaceAllUsesWith(
208                   Builder.CreateLoad(LI->getType(), NoOpBitcast));
209               LI->eraseFromParent();
210             }
211             continue;
212           }
213           if (auto SI = dyn_cast<StoreInst>(&I)) {
214             if (Value *NoOpBitcast = maybeGenerateBitcast(
215                     Builder, PointerTypes, I, SI->getPointerOperand(),
216                     SI->getValueOperand()->getType())) {
217 
218               SI->replaceAllUsesWith(
219                   Builder.CreateStore(SI->getValueOperand(), NoOpBitcast));
220               SI->eraseFromParent();
221             }
222             continue;
223           }
224           if (auto GEP = dyn_cast<GetElementPtrInst>(&I)) {
225             if (Value *NoOpBitcast = maybeGenerateBitcast(
226                     Builder, PointerTypes, I, GEP->getPointerOperand(),
227                     GEP->getSourceElementType()))
228               GEP->setOperand(0, NoOpBitcast);
229             continue;
230           }
231           if (auto *CB = dyn_cast<CallBase>(&I)) {
232             CB->removeFnAttrs(AttrMask);
233             CB->removeRetAttrs(AttrMask);
234             for (size_t Idx = 0, End = CB->arg_size(); Idx < End; ++Idx)
235               CB->removeParamAttrs(Idx, AttrMask);
236             continue;
237           }
238         }
239       }
240     }
241     // Remove flags not for DXIL.
242     cleanModuleFlags(M);
243     return true;
244   }
245 
DXILPrepareModule()246   DXILPrepareModule() : ModulePass(ID) {}
getAnalysisUsage(AnalysisUsage & AU) const247   void getAnalysisUsage(AnalysisUsage &AU) const override {
248     AU.addPreserved<ShaderFlagsAnalysisWrapper>();
249     AU.addPreserved<DXILResourceWrapper>();
250   }
251   static char ID; // Pass identification.
252 };
253 char DXILPrepareModule::ID = 0;
254 
255 } // end anonymous namespace
256 
// Registers DXILPrepareModule under the name DEBUG_TYPE ("dxil-prepare") with
// the description "DXIL Prepare Module". No pass dependencies are declared
// between BEGIN and END. The trailing `false, false` arguments are presumably
// the CFG-only / is-analysis flags; confirm against llvm/PassSupport.h.
INITIALIZE_PASS_BEGIN(DXILPrepareModule, DEBUG_TYPE, "DXIL Prepare Module",
                      false, false)
INITIALIZE_PASS_END(DXILPrepareModule, DEBUG_TYPE, "DXIL Prepare Module", false,
                    false)
261 
createDXILPrepareModulePass()262 ModulePass *llvm::createDXILPrepareModulePass() {
263   return new DXILPrepareModule();
264 }
265