xref: /freebsd/contrib/llvm-project/llvm/lib/Target/DirectX/DXILShaderFlags.cpp (revision 700637cbb5e582861067a11aaca4d053546871d2)
1 //===- DXILShaderFlags.cpp - DXIL Shader Flags helper objects -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
9 /// \file This file contains helper objects and APIs for working with DXIL
10 ///       Shader Flags.
11 ///
12 //===----------------------------------------------------------------------===//
13 
14 #include "DXILShaderFlags.h"
15 #include "DirectX.h"
16 #include "llvm/ADT/SCCIterator.h"
17 #include "llvm/ADT/SmallVector.h"
18 #include "llvm/Analysis/CallGraph.h"
19 #include "llvm/Analysis/DXILResource.h"
20 #include "llvm/IR/Attributes.h"
21 #include "llvm/IR/DiagnosticInfo.h"
22 #include "llvm/IR/Instruction.h"
23 #include "llvm/IR/Instructions.h"
24 #include "llvm/IR/IntrinsicInst.h"
25 #include "llvm/IR/Intrinsics.h"
26 #include "llvm/IR/IntrinsicsDirectX.h"
27 #include "llvm/IR/Module.h"
28 #include "llvm/InitializePasses.h"
29 #include "llvm/Support/FormatVariadic.h"
30 #include "llvm/Support/raw_ostream.h"
31 
32 using namespace llvm;
33 using namespace llvm::dxil;
34 
hasUAVsAtEveryStage(const DXILResourceMap & DRM,const ModuleMetadataInfo & MMDI)35 static bool hasUAVsAtEveryStage(const DXILResourceMap &DRM,
36                                 const ModuleMetadataInfo &MMDI) {
37   if (DRM.uavs().empty())
38     return false;
39 
40   switch (MMDI.ShaderProfile) {
41   default:
42     return false;
43   case Triple::EnvironmentType::Compute:
44   case Triple::EnvironmentType::Pixel:
45     return false;
46   case Triple::EnvironmentType::Vertex:
47   case Triple::EnvironmentType::Geometry:
48   case Triple::EnvironmentType::Hull:
49   case Triple::EnvironmentType::Domain:
50     return true;
51   case Triple::EnvironmentType::Library:
52   case Triple::EnvironmentType::RayGeneration:
53   case Triple::EnvironmentType::Intersection:
54   case Triple::EnvironmentType::AnyHit:
55   case Triple::EnvironmentType::ClosestHit:
56   case Triple::EnvironmentType::Miss:
57   case Triple::EnvironmentType::Callable:
58   case Triple::EnvironmentType::Mesh:
59   case Triple::EnvironmentType::Amplification:
60     return MMDI.ValidatorVersion < VersionTuple(1, 8);
61   }
62 }
63 
checkWaveOps(Intrinsic::ID IID)64 static bool checkWaveOps(Intrinsic::ID IID) {
65   // Currently unsupported intrinsics
66   // case Intrinsic::dx_wave_getlanecount:
67   // case Intrinsic::dx_wave_allequal:
68   // case Intrinsic::dx_wave_ballot:
69   // case Intrinsic::dx_wave_readfirst:
70   // case Intrinsic::dx_wave_reduce.and:
71   // case Intrinsic::dx_wave_reduce.or:
72   // case Intrinsic::dx_wave_reduce.xor:
73   // case Intrinsic::dx_wave_prefixop:
74   // case Intrinsic::dx_quad.readat:
75   // case Intrinsic::dx_quad.readacrossx:
76   // case Intrinsic::dx_quad.readacrossy:
77   // case Intrinsic::dx_quad.readacrossdiagonal:
78   // case Intrinsic::dx_wave_prefixballot:
79   // case Intrinsic::dx_wave_match:
80   // case Intrinsic::dx_wavemulti.*:
81   // case Intrinsic::dx_wavemulti.ballot:
82   // case Intrinsic::dx_quad.vote:
83   switch (IID) {
84   default:
85     return false;
86   case Intrinsic::dx_wave_is_first_lane:
87   case Intrinsic::dx_wave_getlaneindex:
88   case Intrinsic::dx_wave_any:
89   case Intrinsic::dx_wave_all:
90   case Intrinsic::dx_wave_readlane:
91   case Intrinsic::dx_wave_active_countbits:
92   // Wave Active Op Variants
93   case Intrinsic::dx_wave_reduce_sum:
94   case Intrinsic::dx_wave_reduce_usum:
95   case Intrinsic::dx_wave_reduce_max:
96   case Intrinsic::dx_wave_reduce_umax:
97     return true;
98   }
99 }
100 
101 /// Update the shader flags mask based on the given instruction.
102 /// \param CSF Shader flags mask to update.
103 /// \param I Instruction to check.
updateFunctionFlags(ComputedShaderFlags & CSF,const Instruction & I,DXILResourceTypeMap & DRTM,const ModuleMetadataInfo & MMDI)104 void ModuleShaderFlags::updateFunctionFlags(ComputedShaderFlags &CSF,
105                                             const Instruction &I,
106                                             DXILResourceTypeMap &DRTM,
107                                             const ModuleMetadataInfo &MMDI) {
108   if (!CSF.Doubles)
109     CSF.Doubles = I.getType()->isDoubleTy();
110 
111   if (!CSF.Doubles) {
112     for (const Value *Op : I.operands()) {
113       if (Op->getType()->isDoubleTy()) {
114         CSF.Doubles = true;
115         break;
116       }
117     }
118   }
119 
120   if (CSF.Doubles) {
121     switch (I.getOpcode()) {
122     case Instruction::FDiv:
123     case Instruction::UIToFP:
124     case Instruction::SIToFP:
125     case Instruction::FPToUI:
126     case Instruction::FPToSI:
127       CSF.DX11_1_DoubleExtensions = true;
128       break;
129     }
130   }
131 
132   if (!CSF.LowPrecisionPresent)
133     CSF.LowPrecisionPresent =
134         I.getType()->isIntegerTy(16) || I.getType()->isHalfTy();
135 
136   if (!CSF.LowPrecisionPresent) {
137     for (const Value *Op : I.operands()) {
138       if (Op->getType()->isIntegerTy(16) || Op->getType()->isHalfTy()) {
139         CSF.LowPrecisionPresent = true;
140         break;
141       }
142     }
143   }
144 
145   if (CSF.LowPrecisionPresent) {
146     if (CSF.NativeLowPrecisionMode)
147       CSF.NativeLowPrecision = true;
148     else
149       CSF.MinimumPrecision = true;
150   }
151 
152   if (!CSF.Int64Ops)
153     CSF.Int64Ops = I.getType()->isIntegerTy(64);
154 
155   if (!CSF.Int64Ops && !isa<LifetimeIntrinsic>(&I)) {
156     for (const Value *Op : I.operands()) {
157       if (Op->getType()->isIntegerTy(64)) {
158         CSF.Int64Ops = true;
159         break;
160       }
161     }
162   }
163 
164   if (auto *II = dyn_cast<IntrinsicInst>(&I)) {
165     switch (II->getIntrinsicID()) {
166     default:
167       break;
168     case Intrinsic::dx_resource_handlefrombinding: {
169       dxil::ResourceTypeInfo &RTI = DRTM[cast<TargetExtType>(II->getType())];
170 
171       // Set ResMayNotAlias if DXIL validator version >= 1.8 and the function
172       // uses UAVs
173       if (!CSF.ResMayNotAlias && CanSetResMayNotAlias &&
174           MMDI.ValidatorVersion >= VersionTuple(1, 8) && RTI.isUAV())
175         CSF.ResMayNotAlias = true;
176 
177       switch (RTI.getResourceKind()) {
178       case dxil::ResourceKind::StructuredBuffer:
179       case dxil::ResourceKind::RawBuffer:
180         CSF.EnableRawAndStructuredBuffers = true;
181         break;
182       default:
183         break;
184       }
185       break;
186     }
187     case Intrinsic::dx_resource_load_typedbuffer: {
188       dxil::ResourceTypeInfo &RTI =
189           DRTM[cast<TargetExtType>(II->getArgOperand(0)->getType())];
190       if (RTI.isTyped())
191         CSF.TypedUAVLoadAdditionalFormats |= RTI.getTyped().ElementCount > 1;
192       break;
193     }
194     }
195   }
196   // Handle call instructions
197   if (auto *CI = dyn_cast<CallInst>(&I)) {
198     const Function *CF = CI->getCalledFunction();
199     // Merge-in shader flags mask of the called function in the current module
200     if (FunctionFlags.contains(CF))
201       CSF.merge(FunctionFlags[CF]);
202 
203     // TODO: Set DX11_1_DoubleExtensions if I is a call to DXIL intrinsic
204     // DXIL::Opcode::Fma https://github.com/llvm/llvm-project/issues/114554
205 
206     CSF.WaveOps |= checkWaveOps(CI->getIntrinsicID());
207   }
208 }
209 
210 /// Set shader flags that apply to all functions within the module
211 ComputedShaderFlags
gatherGlobalModuleFlags(const Module & M,const DXILResourceMap & DRM,const ModuleMetadataInfo & MMDI)212 ModuleShaderFlags::gatherGlobalModuleFlags(const Module &M,
213                                            const DXILResourceMap &DRM,
214                                            const ModuleMetadataInfo &MMDI) {
215 
216   ComputedShaderFlags CSF;
217 
218   // Set DisableOptimizations flag based on the presence of OptimizeNone
219   // attribute of entry functions.
220   if (MMDI.EntryPropertyVec.size() > 0) {
221     CSF.DisableOptimizations = MMDI.EntryPropertyVec[0].Entry->hasFnAttribute(
222         llvm::Attribute::OptimizeNone);
223     // Ensure all entry functions have the same optimization attribute
224     for (const auto &EntryFunProps : MMDI.EntryPropertyVec)
225       if (CSF.DisableOptimizations !=
226           EntryFunProps.Entry->hasFnAttribute(llvm::Attribute::OptimizeNone))
227         EntryFunProps.Entry->getContext().diagnose(DiagnosticInfoUnsupported(
228             *(EntryFunProps.Entry), "Inconsistent optnone attribute "));
229   }
230 
231   CSF.UAVsAtEveryStage = hasUAVsAtEveryStage(DRM, MMDI);
232 
233   // Set the Max64UAVs flag if the number of UAVs is > 8
234   uint32_t NumUAVs = 0;
235   for (auto &UAV : DRM.uavs())
236     if (MMDI.ValidatorVersion < VersionTuple(1, 6))
237       NumUAVs++;
238     else // MMDI.ValidatorVersion >= VersionTuple(1, 6)
239       NumUAVs += UAV.getBinding().Size;
240   if (NumUAVs > 8)
241     CSF.Max64UAVs = true;
242 
243   // Set the module flag that enables native low-precision execution mode.
244   // NativeLowPrecisionMode can only be set when the command line option
245   // -enable-16bit-types is provided. This is indicated by the dx.nativelowprec
246   // module flag being set
247   // This flag is needed even if the module does not use 16-bit types because a
248   // corresponding debug module may include 16-bit types, and tools that use the
249   // debug module may expect it to have the same flags as the original
250   if (auto *NativeLowPrec = mdconst::extract_or_null<ConstantInt>(
251           M.getModuleFlag("dx.nativelowprec")))
252     if (MMDI.ShaderModelVersion >= VersionTuple(6, 2))
253       CSF.NativeLowPrecisionMode = NativeLowPrec->getValue().getBoolValue();
254 
255   // Set ResMayNotAlias to true if DXIL validator version < 1.8 and there
256   // are UAVs present globally.
257   if (CanSetResMayNotAlias && MMDI.ValidatorVersion < VersionTuple(1, 8))
258     CSF.ResMayNotAlias = !DRM.uavs().empty();
259 
260   return CSF;
261 }
262 
263 /// Construct ModuleShaderFlags for module Module M
initialize(Module & M,DXILResourceTypeMap & DRTM,const DXILResourceMap & DRM,const ModuleMetadataInfo & MMDI)264 void ModuleShaderFlags::initialize(Module &M, DXILResourceTypeMap &DRTM,
265                                    const DXILResourceMap &DRM,
266                                    const ModuleMetadataInfo &MMDI) {
267 
268   CanSetResMayNotAlias = MMDI.DXILVersion >= VersionTuple(1, 7);
269   // The command line option -res-may-alias will set the dx.resmayalias module
270   // flag to 1, thereby disabling the ability to set the ResMayNotAlias flag
271   if (auto *ResMayAlias = mdconst::extract_or_null<ConstantInt>(
272           M.getModuleFlag("dx.resmayalias")))
273     if (ResMayAlias->getValue().getBoolValue())
274       CanSetResMayNotAlias = false;
275 
276   ComputedShaderFlags GlobalSFMask = gatherGlobalModuleFlags(M, DRM, MMDI);
277 
278   CallGraph CG(M);
279 
280   // Compute Shader Flags Mask for all functions using post-order visit of SCC
281   // of the call graph.
282   for (scc_iterator<CallGraph *> SCCI = scc_begin(&CG); !SCCI.isAtEnd();
283        ++SCCI) {
284     const std::vector<CallGraphNode *> &CurSCC = *SCCI;
285 
286     // Union of shader masks of all functions in CurSCC
287     ComputedShaderFlags SCCSF;
288     // List of functions in CurSCC that are neither external nor declarations
289     // and hence whose flags are collected
290     SmallVector<Function *> CurSCCFuncs;
291     for (CallGraphNode *CGN : CurSCC) {
292       Function *F = CGN->getFunction();
293       if (!F)
294         continue;
295 
296       if (F->isDeclaration()) {
297         assert(!F->getName().starts_with("dx.op.") &&
298                "DXIL Shader Flag analysis should not be run post-lowering.");
299         continue;
300       }
301 
302       ComputedShaderFlags CSF = GlobalSFMask;
303       for (const auto &BB : *F)
304         for (const auto &I : BB)
305           updateFunctionFlags(CSF, I, DRTM, MMDI);
306       // Update combined shader flags mask for all functions in this SCC
307       SCCSF.merge(CSF);
308 
309       CurSCCFuncs.push_back(F);
310     }
311 
312     // Update combined shader flags mask for all functions of the module
313     CombinedSFMask.merge(SCCSF);
314 
315     // Shader flags mask of each of the functions in an SCC of the call graph is
316     // the union of all functions in the SCC. Update shader flags masks of
317     // functions in CurSCC accordingly. This is trivially true if SCC contains
318     // one function.
319     for (Function *F : CurSCCFuncs)
320       // Merge SCCSF with that of F
321       FunctionFlags[F].merge(SCCSF);
322   }
323 }
324 
print(raw_ostream & OS) const325 void ComputedShaderFlags::print(raw_ostream &OS) const {
326   uint64_t FlagVal = (uint64_t) * this;
327   OS << formatv("; Shader Flags Value: {0:x8}\n;\n", FlagVal);
328   if (FlagVal == 0)
329     return;
330   OS << "; Note: shader requires additional functionality:\n";
331 #define SHADER_FEATURE_FLAG(FeatureBit, DxilModuleNum, FlagName, Str)          \
332   if (FlagName)                                                                \
333     (OS << ";").indent(7) << Str << "\n";
334 #include "llvm/BinaryFormat/DXContainerConstants.def"
335   OS << "; Note: extra DXIL module flags:\n";
336 #define DXIL_MODULE_FLAG(DxilModuleBit, FlagName, Str)                         \
337   if (FlagName)                                                                \
338     (OS << ";").indent(7) << Str << "\n";
339 #include "llvm/BinaryFormat/DXContainerConstants.def"
340   OS << ";\n";
341 }
342 
343 /// Return the shader flags mask of the specified function Func.
344 const ComputedShaderFlags &
getFunctionFlags(const Function * Func) const345 ModuleShaderFlags::getFunctionFlags(const Function *Func) const {
346   auto Iter = FunctionFlags.find(Func);
347   assert((Iter != FunctionFlags.end() && Iter->first == Func) &&
348          "Get Shader Flags : No Shader Flags Mask exists for function");
349   return Iter->second;
350 }
351 
352 //===----------------------------------------------------------------------===//
353 // ShaderFlagsAnalysis and ShaderFlagsAnalysisPrinterPass
354 
// Out-of-line definition of the static AnalysisKey member of
// ShaderFlagsAnalysis.
AnalysisKey ShaderFlagsAnalysis::Key;
357 
run(Module & M,ModuleAnalysisManager & AM)358 ModuleShaderFlags ShaderFlagsAnalysis::run(Module &M,
359                                            ModuleAnalysisManager &AM) {
360   DXILResourceTypeMap &DRTM = AM.getResult<DXILResourceTypeAnalysis>(M);
361   DXILResourceMap &DRM = AM.getResult<DXILResourceAnalysis>(M);
362   const ModuleMetadataInfo MMDI = AM.getResult<DXILMetadataAnalysis>(M);
363 
364   ModuleShaderFlags MSFI;
365   MSFI.initialize(M, DRTM, DRM, MMDI);
366 
367   return MSFI;
368 }
369 
run(Module & M,ModuleAnalysisManager & AM)370 PreservedAnalyses ShaderFlagsAnalysisPrinter::run(Module &M,
371                                                   ModuleAnalysisManager &AM) {
372   const ModuleShaderFlags &FlagsInfo = AM.getResult<ShaderFlagsAnalysis>(M);
373   // Print description of combined shader flags for all module functions
374   OS << "; Combined Shader Flags for Module\n";
375   FlagsInfo.getCombinedFlags().print(OS);
376   // Print shader flags mask for each of the module functions
377   OS << "; Shader Flags for Module Functions\n";
378   for (const auto &F : M.getFunctionList()) {
379     if (F.isDeclaration())
380       continue;
381     const ComputedShaderFlags &SFMask = FlagsInfo.getFunctionFlags(&F);
382     OS << formatv("; Function {0} : {1:x8}\n;\n", F.getName(),
383                   (uint64_t)(SFMask));
384   }
385 
386   return PreservedAnalyses::all();
387 }
388 
389 //===----------------------------------------------------------------------===//
// ShaderFlagsAnalysisWrapper (legacy pass manager wrapper)
391 
runOnModule(Module & M)392 bool ShaderFlagsAnalysisWrapper::runOnModule(Module &M) {
393   DXILResourceTypeMap &DRTM =
394       getAnalysis<DXILResourceTypeWrapperPass>().getResourceTypeMap();
395   DXILResourceMap &DRM =
396       getAnalysis<DXILResourceWrapperPass>().getResourceMap();
397   const ModuleMetadataInfo MMDI =
398       getAnalysis<DXILMetadataAnalysisWrapperPass>().getModuleMetadata();
399 
400   MSFI.initialize(M, DRTM, DRM, MMDI);
401   return false;
402 }
403 
getAnalysisUsage(AnalysisUsage & AU) const404 void ShaderFlagsAnalysisWrapper::getAnalysisUsage(AnalysisUsage &AU) const {
405   AU.setPreservesAll();
406   AU.addRequiredTransitive<DXILResourceTypeWrapperPass>();
407   AU.addRequiredTransitive<DXILResourceWrapperPass>();
408   AU.addRequired<DXILMetadataAnalysisWrapperPass>();
409 }
410 
411 char ShaderFlagsAnalysisWrapper::ID = 0;
412 
413 INITIALIZE_PASS_BEGIN(ShaderFlagsAnalysisWrapper, "dx-shader-flag-analysis",
414                       "DXIL Shader Flag Analysis", true, true)
415 INITIALIZE_PASS_DEPENDENCY(DXILResourceTypeWrapperPass)
416 INITIALIZE_PASS_DEPENDENCY(DXILMetadataAnalysisWrapperPass)
417 INITIALIZE_PASS_END(ShaderFlagsAnalysisWrapper, "dx-shader-flag-analysis",
418                     "DXIL Shader Flag Analysis", true, true)
419