1 //===- DXILShaderFlags.cpp - DXIL Shader Flags helper objects -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
9 /// \file This file contains helper objects and APIs for working with DXIL
10 /// Shader Flags.
11 ///
12 //===----------------------------------------------------------------------===//
13
14 #include "DXILShaderFlags.h"
15 #include "DirectX.h"
16 #include "llvm/ADT/SCCIterator.h"
17 #include "llvm/ADT/SmallVector.h"
18 #include "llvm/Analysis/CallGraph.h"
19 #include "llvm/Analysis/DXILResource.h"
20 #include "llvm/IR/Attributes.h"
21 #include "llvm/IR/DiagnosticInfo.h"
22 #include "llvm/IR/Instruction.h"
23 #include "llvm/IR/Instructions.h"
24 #include "llvm/IR/IntrinsicInst.h"
25 #include "llvm/IR/Intrinsics.h"
26 #include "llvm/IR/IntrinsicsDirectX.h"
27 #include "llvm/IR/Module.h"
28 #include "llvm/InitializePasses.h"
29 #include "llvm/Support/FormatVariadic.h"
30 #include "llvm/Support/raw_ostream.h"
31
32 using namespace llvm;
33 using namespace llvm::dxil;
34
hasUAVsAtEveryStage(const DXILResourceMap & DRM,const ModuleMetadataInfo & MMDI)35 static bool hasUAVsAtEveryStage(const DXILResourceMap &DRM,
36 const ModuleMetadataInfo &MMDI) {
37 if (DRM.uavs().empty())
38 return false;
39
40 switch (MMDI.ShaderProfile) {
41 default:
42 return false;
43 case Triple::EnvironmentType::Compute:
44 case Triple::EnvironmentType::Pixel:
45 return false;
46 case Triple::EnvironmentType::Vertex:
47 case Triple::EnvironmentType::Geometry:
48 case Triple::EnvironmentType::Hull:
49 case Triple::EnvironmentType::Domain:
50 return true;
51 case Triple::EnvironmentType::Library:
52 case Triple::EnvironmentType::RayGeneration:
53 case Triple::EnvironmentType::Intersection:
54 case Triple::EnvironmentType::AnyHit:
55 case Triple::EnvironmentType::ClosestHit:
56 case Triple::EnvironmentType::Miss:
57 case Triple::EnvironmentType::Callable:
58 case Triple::EnvironmentType::Mesh:
59 case Triple::EnvironmentType::Amplification:
60 return MMDI.ValidatorVersion < VersionTuple(1, 8);
61 }
62 }
63
checkWaveOps(Intrinsic::ID IID)64 static bool checkWaveOps(Intrinsic::ID IID) {
65 // Currently unsupported intrinsics
66 // case Intrinsic::dx_wave_getlanecount:
67 // case Intrinsic::dx_wave_allequal:
68 // case Intrinsic::dx_wave_ballot:
69 // case Intrinsic::dx_wave_readfirst:
70 // case Intrinsic::dx_wave_reduce.and:
71 // case Intrinsic::dx_wave_reduce.or:
72 // case Intrinsic::dx_wave_reduce.xor:
73 // case Intrinsic::dx_wave_prefixop:
74 // case Intrinsic::dx_quad.readat:
75 // case Intrinsic::dx_quad.readacrossx:
76 // case Intrinsic::dx_quad.readacrossy:
77 // case Intrinsic::dx_quad.readacrossdiagonal:
78 // case Intrinsic::dx_wave_prefixballot:
79 // case Intrinsic::dx_wave_match:
80 // case Intrinsic::dx_wavemulti.*:
81 // case Intrinsic::dx_wavemulti.ballot:
82 // case Intrinsic::dx_quad.vote:
83 switch (IID) {
84 default:
85 return false;
86 case Intrinsic::dx_wave_is_first_lane:
87 case Intrinsic::dx_wave_getlaneindex:
88 case Intrinsic::dx_wave_any:
89 case Intrinsic::dx_wave_all:
90 case Intrinsic::dx_wave_readlane:
91 case Intrinsic::dx_wave_active_countbits:
92 // Wave Active Op Variants
93 case Intrinsic::dx_wave_reduce_sum:
94 case Intrinsic::dx_wave_reduce_usum:
95 case Intrinsic::dx_wave_reduce_max:
96 case Intrinsic::dx_wave_reduce_umax:
97 return true;
98 }
99 }
100
101 /// Update the shader flags mask based on the given instruction.
102 /// \param CSF Shader flags mask to update.
103 /// \param I Instruction to check.
updateFunctionFlags(ComputedShaderFlags & CSF,const Instruction & I,DXILResourceTypeMap & DRTM,const ModuleMetadataInfo & MMDI)104 void ModuleShaderFlags::updateFunctionFlags(ComputedShaderFlags &CSF,
105 const Instruction &I,
106 DXILResourceTypeMap &DRTM,
107 const ModuleMetadataInfo &MMDI) {
108 if (!CSF.Doubles)
109 CSF.Doubles = I.getType()->isDoubleTy();
110
111 if (!CSF.Doubles) {
112 for (const Value *Op : I.operands()) {
113 if (Op->getType()->isDoubleTy()) {
114 CSF.Doubles = true;
115 break;
116 }
117 }
118 }
119
120 if (CSF.Doubles) {
121 switch (I.getOpcode()) {
122 case Instruction::FDiv:
123 case Instruction::UIToFP:
124 case Instruction::SIToFP:
125 case Instruction::FPToUI:
126 case Instruction::FPToSI:
127 CSF.DX11_1_DoubleExtensions = true;
128 break;
129 }
130 }
131
132 if (!CSF.LowPrecisionPresent)
133 CSF.LowPrecisionPresent =
134 I.getType()->isIntegerTy(16) || I.getType()->isHalfTy();
135
136 if (!CSF.LowPrecisionPresent) {
137 for (const Value *Op : I.operands()) {
138 if (Op->getType()->isIntegerTy(16) || Op->getType()->isHalfTy()) {
139 CSF.LowPrecisionPresent = true;
140 break;
141 }
142 }
143 }
144
145 if (CSF.LowPrecisionPresent) {
146 if (CSF.NativeLowPrecisionMode)
147 CSF.NativeLowPrecision = true;
148 else
149 CSF.MinimumPrecision = true;
150 }
151
152 if (!CSF.Int64Ops)
153 CSF.Int64Ops = I.getType()->isIntegerTy(64);
154
155 if (!CSF.Int64Ops && !isa<LifetimeIntrinsic>(&I)) {
156 for (const Value *Op : I.operands()) {
157 if (Op->getType()->isIntegerTy(64)) {
158 CSF.Int64Ops = true;
159 break;
160 }
161 }
162 }
163
164 if (auto *II = dyn_cast<IntrinsicInst>(&I)) {
165 switch (II->getIntrinsicID()) {
166 default:
167 break;
168 case Intrinsic::dx_resource_handlefrombinding: {
169 dxil::ResourceTypeInfo &RTI = DRTM[cast<TargetExtType>(II->getType())];
170
171 // Set ResMayNotAlias if DXIL validator version >= 1.8 and the function
172 // uses UAVs
173 if (!CSF.ResMayNotAlias && CanSetResMayNotAlias &&
174 MMDI.ValidatorVersion >= VersionTuple(1, 8) && RTI.isUAV())
175 CSF.ResMayNotAlias = true;
176
177 switch (RTI.getResourceKind()) {
178 case dxil::ResourceKind::StructuredBuffer:
179 case dxil::ResourceKind::RawBuffer:
180 CSF.EnableRawAndStructuredBuffers = true;
181 break;
182 default:
183 break;
184 }
185 break;
186 }
187 case Intrinsic::dx_resource_load_typedbuffer: {
188 dxil::ResourceTypeInfo &RTI =
189 DRTM[cast<TargetExtType>(II->getArgOperand(0)->getType())];
190 if (RTI.isTyped())
191 CSF.TypedUAVLoadAdditionalFormats |= RTI.getTyped().ElementCount > 1;
192 break;
193 }
194 }
195 }
196 // Handle call instructions
197 if (auto *CI = dyn_cast<CallInst>(&I)) {
198 const Function *CF = CI->getCalledFunction();
199 // Merge-in shader flags mask of the called function in the current module
200 if (FunctionFlags.contains(CF))
201 CSF.merge(FunctionFlags[CF]);
202
203 // TODO: Set DX11_1_DoubleExtensions if I is a call to DXIL intrinsic
204 // DXIL::Opcode::Fma https://github.com/llvm/llvm-project/issues/114554
205
206 CSF.WaveOps |= checkWaveOps(CI->getIntrinsicID());
207 }
208 }
209
210 /// Set shader flags that apply to all functions within the module
211 ComputedShaderFlags
gatherGlobalModuleFlags(const Module & M,const DXILResourceMap & DRM,const ModuleMetadataInfo & MMDI)212 ModuleShaderFlags::gatherGlobalModuleFlags(const Module &M,
213 const DXILResourceMap &DRM,
214 const ModuleMetadataInfo &MMDI) {
215
216 ComputedShaderFlags CSF;
217
218 // Set DisableOptimizations flag based on the presence of OptimizeNone
219 // attribute of entry functions.
220 if (MMDI.EntryPropertyVec.size() > 0) {
221 CSF.DisableOptimizations = MMDI.EntryPropertyVec[0].Entry->hasFnAttribute(
222 llvm::Attribute::OptimizeNone);
223 // Ensure all entry functions have the same optimization attribute
224 for (const auto &EntryFunProps : MMDI.EntryPropertyVec)
225 if (CSF.DisableOptimizations !=
226 EntryFunProps.Entry->hasFnAttribute(llvm::Attribute::OptimizeNone))
227 EntryFunProps.Entry->getContext().diagnose(DiagnosticInfoUnsupported(
228 *(EntryFunProps.Entry), "Inconsistent optnone attribute "));
229 }
230
231 CSF.UAVsAtEveryStage = hasUAVsAtEveryStage(DRM, MMDI);
232
233 // Set the Max64UAVs flag if the number of UAVs is > 8
234 uint32_t NumUAVs = 0;
235 for (auto &UAV : DRM.uavs())
236 if (MMDI.ValidatorVersion < VersionTuple(1, 6))
237 NumUAVs++;
238 else // MMDI.ValidatorVersion >= VersionTuple(1, 6)
239 NumUAVs += UAV.getBinding().Size;
240 if (NumUAVs > 8)
241 CSF.Max64UAVs = true;
242
243 // Set the module flag that enables native low-precision execution mode.
244 // NativeLowPrecisionMode can only be set when the command line option
245 // -enable-16bit-types is provided. This is indicated by the dx.nativelowprec
246 // module flag being set
247 // This flag is needed even if the module does not use 16-bit types because a
248 // corresponding debug module may include 16-bit types, and tools that use the
249 // debug module may expect it to have the same flags as the original
250 if (auto *NativeLowPrec = mdconst::extract_or_null<ConstantInt>(
251 M.getModuleFlag("dx.nativelowprec")))
252 if (MMDI.ShaderModelVersion >= VersionTuple(6, 2))
253 CSF.NativeLowPrecisionMode = NativeLowPrec->getValue().getBoolValue();
254
255 // Set ResMayNotAlias to true if DXIL validator version < 1.8 and there
256 // are UAVs present globally.
257 if (CanSetResMayNotAlias && MMDI.ValidatorVersion < VersionTuple(1, 8))
258 CSF.ResMayNotAlias = !DRM.uavs().empty();
259
260 return CSF;
261 }
262
263 /// Construct ModuleShaderFlags for module Module M
initialize(Module & M,DXILResourceTypeMap & DRTM,const DXILResourceMap & DRM,const ModuleMetadataInfo & MMDI)264 void ModuleShaderFlags::initialize(Module &M, DXILResourceTypeMap &DRTM,
265 const DXILResourceMap &DRM,
266 const ModuleMetadataInfo &MMDI) {
267
268 CanSetResMayNotAlias = MMDI.DXILVersion >= VersionTuple(1, 7);
269 // The command line option -res-may-alias will set the dx.resmayalias module
270 // flag to 1, thereby disabling the ability to set the ResMayNotAlias flag
271 if (auto *ResMayAlias = mdconst::extract_or_null<ConstantInt>(
272 M.getModuleFlag("dx.resmayalias")))
273 if (ResMayAlias->getValue().getBoolValue())
274 CanSetResMayNotAlias = false;
275
276 ComputedShaderFlags GlobalSFMask = gatherGlobalModuleFlags(M, DRM, MMDI);
277
278 CallGraph CG(M);
279
280 // Compute Shader Flags Mask for all functions using post-order visit of SCC
281 // of the call graph.
282 for (scc_iterator<CallGraph *> SCCI = scc_begin(&CG); !SCCI.isAtEnd();
283 ++SCCI) {
284 const std::vector<CallGraphNode *> &CurSCC = *SCCI;
285
286 // Union of shader masks of all functions in CurSCC
287 ComputedShaderFlags SCCSF;
288 // List of functions in CurSCC that are neither external nor declarations
289 // and hence whose flags are collected
290 SmallVector<Function *> CurSCCFuncs;
291 for (CallGraphNode *CGN : CurSCC) {
292 Function *F = CGN->getFunction();
293 if (!F)
294 continue;
295
296 if (F->isDeclaration()) {
297 assert(!F->getName().starts_with("dx.op.") &&
298 "DXIL Shader Flag analysis should not be run post-lowering.");
299 continue;
300 }
301
302 ComputedShaderFlags CSF = GlobalSFMask;
303 for (const auto &BB : *F)
304 for (const auto &I : BB)
305 updateFunctionFlags(CSF, I, DRTM, MMDI);
306 // Update combined shader flags mask for all functions in this SCC
307 SCCSF.merge(CSF);
308
309 CurSCCFuncs.push_back(F);
310 }
311
312 // Update combined shader flags mask for all functions of the module
313 CombinedSFMask.merge(SCCSF);
314
315 // Shader flags mask of each of the functions in an SCC of the call graph is
316 // the union of all functions in the SCC. Update shader flags masks of
317 // functions in CurSCC accordingly. This is trivially true if SCC contains
318 // one function.
319 for (Function *F : CurSCCFuncs)
320 // Merge SCCSF with that of F
321 FunctionFlags[F].merge(SCCSF);
322 }
323 }
324
print(raw_ostream & OS) const325 void ComputedShaderFlags::print(raw_ostream &OS) const {
326 uint64_t FlagVal = (uint64_t) * this;
327 OS << formatv("; Shader Flags Value: {0:x8}\n;\n", FlagVal);
328 if (FlagVal == 0)
329 return;
330 OS << "; Note: shader requires additional functionality:\n";
331 #define SHADER_FEATURE_FLAG(FeatureBit, DxilModuleNum, FlagName, Str) \
332 if (FlagName) \
333 (OS << ";").indent(7) << Str << "\n";
334 #include "llvm/BinaryFormat/DXContainerConstants.def"
335 OS << "; Note: extra DXIL module flags:\n";
336 #define DXIL_MODULE_FLAG(DxilModuleBit, FlagName, Str) \
337 if (FlagName) \
338 (OS << ";").indent(7) << Str << "\n";
339 #include "llvm/BinaryFormat/DXContainerConstants.def"
340 OS << ";\n";
341 }
342
343 /// Return the shader flags mask of the specified function Func.
344 const ComputedShaderFlags &
getFunctionFlags(const Function * Func) const345 ModuleShaderFlags::getFunctionFlags(const Function *Func) const {
346 auto Iter = FunctionFlags.find(Func);
347 assert((Iter != FunctionFlags.end() && Iter->first == Func) &&
348 "Get Shader Flags : No Shader Flags Mask exists for function");
349 return Iter->second;
350 }
351
352 //===----------------------------------------------------------------------===//
353 // ShaderFlagsAnalysis and ShaderFlagsAnalysisPrinterPass
354
// Out-of-line definition of the static ShaderFlagsAnalysis::Key member. This
// is a plain static data member definition, not a template instantiation.
AnalysisKey ShaderFlagsAnalysis::Key;
357
run(Module & M,ModuleAnalysisManager & AM)358 ModuleShaderFlags ShaderFlagsAnalysis::run(Module &M,
359 ModuleAnalysisManager &AM) {
360 DXILResourceTypeMap &DRTM = AM.getResult<DXILResourceTypeAnalysis>(M);
361 DXILResourceMap &DRM = AM.getResult<DXILResourceAnalysis>(M);
362 const ModuleMetadataInfo MMDI = AM.getResult<DXILMetadataAnalysis>(M);
363
364 ModuleShaderFlags MSFI;
365 MSFI.initialize(M, DRTM, DRM, MMDI);
366
367 return MSFI;
368 }
369
run(Module & M,ModuleAnalysisManager & AM)370 PreservedAnalyses ShaderFlagsAnalysisPrinter::run(Module &M,
371 ModuleAnalysisManager &AM) {
372 const ModuleShaderFlags &FlagsInfo = AM.getResult<ShaderFlagsAnalysis>(M);
373 // Print description of combined shader flags for all module functions
374 OS << "; Combined Shader Flags for Module\n";
375 FlagsInfo.getCombinedFlags().print(OS);
376 // Print shader flags mask for each of the module functions
377 OS << "; Shader Flags for Module Functions\n";
378 for (const auto &F : M.getFunctionList()) {
379 if (F.isDeclaration())
380 continue;
381 const ComputedShaderFlags &SFMask = FlagsInfo.getFunctionFlags(&F);
382 OS << formatv("; Function {0} : {1:x8}\n;\n", F.getName(),
383 (uint64_t)(SFMask));
384 }
385
386 return PreservedAnalyses::all();
387 }
388
389 //===----------------------------------------------------------------------===//
// ShaderFlagsAnalysisWrapper (legacy pass manager wrapper)
391
runOnModule(Module & M)392 bool ShaderFlagsAnalysisWrapper::runOnModule(Module &M) {
393 DXILResourceTypeMap &DRTM =
394 getAnalysis<DXILResourceTypeWrapperPass>().getResourceTypeMap();
395 DXILResourceMap &DRM =
396 getAnalysis<DXILResourceWrapperPass>().getResourceMap();
397 const ModuleMetadataInfo MMDI =
398 getAnalysis<DXILMetadataAnalysisWrapperPass>().getModuleMetadata();
399
400 MSFI.initialize(M, DRTM, DRM, MMDI);
401 return false;
402 }
403
getAnalysisUsage(AnalysisUsage & AU) const404 void ShaderFlagsAnalysisWrapper::getAnalysisUsage(AnalysisUsage &AU) const {
405 AU.setPreservesAll();
406 AU.addRequiredTransitive<DXILResourceTypeWrapperPass>();
407 AU.addRequiredTransitive<DXILResourceWrapperPass>();
408 AU.addRequired<DXILMetadataAnalysisWrapperPass>();
409 }
410
411 char ShaderFlagsAnalysisWrapper::ID = 0;
412
413 INITIALIZE_PASS_BEGIN(ShaderFlagsAnalysisWrapper, "dx-shader-flag-analysis",
414 "DXIL Shader Flag Analysis", true, true)
415 INITIALIZE_PASS_DEPENDENCY(DXILResourceTypeWrapperPass)
416 INITIALIZE_PASS_DEPENDENCY(DXILMetadataAnalysisWrapperPass)
417 INITIALIZE_PASS_END(ShaderFlagsAnalysisWrapper, "dx-shader-flag-analysis",
418 "DXIL Shader Flag Analysis", true, true)
419