1 //===- DXILShaderFlags.cpp - DXIL Shader Flags helper objects -------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 /// 9 /// \file This file contains helper objects and APIs for working with DXIL 10 /// Shader Flags. 11 /// 12 //===----------------------------------------------------------------------===// 13 14 #include "DXILShaderFlags.h" 15 #include "DirectX.h" 16 #include "llvm/ADT/SCCIterator.h" 17 #include "llvm/ADT/SmallVector.h" 18 #include "llvm/Analysis/CallGraph.h" 19 #include "llvm/Analysis/DXILResource.h" 20 #include "llvm/IR/Attributes.h" 21 #include "llvm/IR/DiagnosticInfo.h" 22 #include "llvm/IR/Instruction.h" 23 #include "llvm/IR/Instructions.h" 24 #include "llvm/IR/IntrinsicInst.h" 25 #include "llvm/IR/Intrinsics.h" 26 #include "llvm/IR/IntrinsicsDirectX.h" 27 #include "llvm/IR/Module.h" 28 #include "llvm/InitializePasses.h" 29 #include "llvm/Support/FormatVariadic.h" 30 #include "llvm/Support/raw_ostream.h" 31 32 using namespace llvm; 33 using namespace llvm::dxil; 34 35 static bool hasUAVsAtEveryStage(const DXILResourceMap &DRM, 36 const ModuleMetadataInfo &MMDI) { 37 if (DRM.uavs().empty()) 38 return false; 39 40 switch (MMDI.ShaderProfile) { 41 default: 42 return false; 43 case Triple::EnvironmentType::Compute: 44 case Triple::EnvironmentType::Pixel: 45 return false; 46 case Triple::EnvironmentType::Vertex: 47 case Triple::EnvironmentType::Geometry: 48 case Triple::EnvironmentType::Hull: 49 case Triple::EnvironmentType::Domain: 50 return true; 51 case Triple::EnvironmentType::Library: 52 case Triple::EnvironmentType::RayGeneration: 53 case Triple::EnvironmentType::Intersection: 54 case Triple::EnvironmentType::AnyHit: 55 case Triple::EnvironmentType::ClosestHit: 56 case Triple::EnvironmentType::Miss: 57 case Triple::EnvironmentType::Callable: 58 case Triple::EnvironmentType::Mesh: 59 case Triple::EnvironmentType::Amplification: 60 return MMDI.ValidatorVersion < VersionTuple(1, 8); 61 } 62 } 63 64 static bool checkWaveOps(Intrinsic::ID IID) { 65 // Currently unsupported intrinsics 66 // case Intrinsic::dx_wave_getlanecount: 67 // case Intrinsic::dx_wave_allequal: 68 // case Intrinsic::dx_wave_ballot: 69 // case Intrinsic::dx_wave_readfirst: 70 // case Intrinsic::dx_wave_reduce.and: 71 // case Intrinsic::dx_wave_reduce.or: 72 // case Intrinsic::dx_wave_reduce.xor: 73 // case Intrinsic::dx_wave_prefixop: 74 // case Intrinsic::dx_quad.readat: 75 // case Intrinsic::dx_quad.readacrossx: 76 // case Intrinsic::dx_quad.readacrossy: 77 // case Intrinsic::dx_quad.readacrossdiagonal: 78 // case Intrinsic::dx_wave_prefixballot: 79 // case Intrinsic::dx_wave_match: 80 // case Intrinsic::dx_wavemulti.*: 81 // case Intrinsic::dx_wavemulti.ballot: 82 // case Intrinsic::dx_quad.vote: 83 switch (IID) { 84 default: 85 return false; 86 case Intrinsic::dx_wave_is_first_lane: 87 case Intrinsic::dx_wave_getlaneindex: 88 case Intrinsic::dx_wave_any: 89 case Intrinsic::dx_wave_all: 90 case Intrinsic::dx_wave_readlane: 91 case Intrinsic::dx_wave_active_countbits: 92 // Wave Active Op Variants 93 case Intrinsic::dx_wave_reduce_sum: 94 case Intrinsic::dx_wave_reduce_usum: 95 case Intrinsic::dx_wave_reduce_max: 96 case Intrinsic::dx_wave_reduce_umax: 97 return true; 98 } 99 } 100 101 /// Update the shader flags mask based on the given instruction. 102 /// \param CSF Shader flags mask to update. 103 /// \param I Instruction to check. 104 void ModuleShaderFlags::updateFunctionFlags(ComputedShaderFlags &CSF, 105 const Instruction &I, 106 DXILResourceTypeMap &DRTM, 107 const ModuleMetadataInfo &MMDI) { 108 if (!CSF.Doubles) 109 CSF.Doubles = I.getType()->isDoubleTy(); 110 111 if (!CSF.Doubles) { 112 for (const Value *Op : I.operands()) { 113 if (Op->getType()->isDoubleTy()) { 114 CSF.Doubles = true; 115 break; 116 } 117 } 118 } 119 120 if (CSF.Doubles) { 121 switch (I.getOpcode()) { 122 case Instruction::FDiv: 123 case Instruction::UIToFP: 124 case Instruction::SIToFP: 125 case Instruction::FPToUI: 126 case Instruction::FPToSI: 127 CSF.DX11_1_DoubleExtensions = true; 128 break; 129 } 130 } 131 132 if (!CSF.LowPrecisionPresent) 133 CSF.LowPrecisionPresent = 134 I.getType()->isIntegerTy(16) || I.getType()->isHalfTy(); 135 136 if (!CSF.LowPrecisionPresent) { 137 for (const Value *Op : I.operands()) { 138 if (Op->getType()->isIntegerTy(16) || Op->getType()->isHalfTy()) { 139 CSF.LowPrecisionPresent = true; 140 break; 141 } 142 } 143 } 144 145 if (CSF.LowPrecisionPresent) { 146 if (CSF.NativeLowPrecisionMode) 147 CSF.NativeLowPrecision = true; 148 else 149 CSF.MinimumPrecision = true; 150 } 151 152 if (!CSF.Int64Ops) 153 CSF.Int64Ops = I.getType()->isIntegerTy(64); 154 155 if (!CSF.Int64Ops && !isa<LifetimeIntrinsic>(&I)) { 156 for (const Value *Op : I.operands()) { 157 if (Op->getType()->isIntegerTy(64)) { 158 CSF.Int64Ops = true; 159 break; 160 } 161 } 162 } 163 164 if (auto *II = dyn_cast<IntrinsicInst>(&I)) { 165 switch (II->getIntrinsicID()) { 166 default: 167 break; 168 case Intrinsic::dx_resource_handlefrombinding: { 169 dxil::ResourceTypeInfo &RTI = DRTM[cast<TargetExtType>(II->getType())]; 170 171 // Set ResMayNotAlias if DXIL validator version >= 1.8 and the function 172 // uses UAVs 173 if (!CSF.ResMayNotAlias && CanSetResMayNotAlias && 174 MMDI.ValidatorVersion >= VersionTuple(1, 8) && RTI.isUAV()) 175 CSF.ResMayNotAlias = true; 176 177 switch (RTI.getResourceKind()) { 178 case dxil::ResourceKind::StructuredBuffer: 179 case dxil::ResourceKind::RawBuffer: 180 CSF.EnableRawAndStructuredBuffers = true; 181 break; 182 default: 183 break; 184 } 185 break; 186 } 187 case Intrinsic::dx_resource_load_typedbuffer: { 188 dxil::ResourceTypeInfo &RTI = 189 DRTM[cast<TargetExtType>(II->getArgOperand(0)->getType())]; 190 if (RTI.isTyped()) 191 CSF.TypedUAVLoadAdditionalFormats |= RTI.getTyped().ElementCount > 1; 192 break; 193 } 194 } 195 } 196 // Handle call instructions 197 if (auto *CI = dyn_cast<CallInst>(&I)) { 198 const Function *CF = CI->getCalledFunction(); 199 // Merge-in shader flags mask of the called function in the current module 200 if (FunctionFlags.contains(CF)) 201 CSF.merge(FunctionFlags[CF]); 202 203 // TODO: Set DX11_1_DoubleExtensions if I is a call to DXIL intrinsic 204 // DXIL::Opcode::Fma https://github.com/llvm/llvm-project/issues/114554 205 206 CSF.WaveOps |= checkWaveOps(CI->getIntrinsicID()); 207 } 208 } 209 210 /// Set shader flags that apply to all functions within the module 211 ComputedShaderFlags 212 ModuleShaderFlags::gatherGlobalModuleFlags(const Module &M, 213 const DXILResourceMap &DRM, 214 const ModuleMetadataInfo &MMDI) { 215 216 ComputedShaderFlags CSF; 217 218 // Set DisableOptimizations flag based on the presence of OptimizeNone 219 // attribute of entry functions. 220 if (MMDI.EntryPropertyVec.size() > 0) { 221 CSF.DisableOptimizations = MMDI.EntryPropertyVec[0].Entry->hasFnAttribute( 222 llvm::Attribute::OptimizeNone); 223 // Ensure all entry functions have the same optimization attribute 224 for (const auto &EntryFunProps : MMDI.EntryPropertyVec) 225 if (CSF.DisableOptimizations != 226 EntryFunProps.Entry->hasFnAttribute(llvm::Attribute::OptimizeNone)) 227 EntryFunProps.Entry->getContext().diagnose(DiagnosticInfoUnsupported( 228 *(EntryFunProps.Entry), "Inconsistent optnone attribute ")); 229 } 230 231 CSF.UAVsAtEveryStage = hasUAVsAtEveryStage(DRM, MMDI); 232 233 // Set the Max64UAVs flag if the number of UAVs is > 8 234 uint32_t NumUAVs = 0; 235 for (auto &UAV : DRM.uavs()) 236 if (MMDI.ValidatorVersion < VersionTuple(1, 6)) 237 NumUAVs++; 238 else // MMDI.ValidatorVersion >= VersionTuple(1, 6) 239 NumUAVs += UAV.getBinding().Size; 240 if (NumUAVs > 8) 241 CSF.Max64UAVs = true; 242 243 // Set the module flag that enables native low-precision execution mode. 244 // NativeLowPrecisionMode can only be set when the command line option 245 // -enable-16bit-types is provided. This is indicated by the dx.nativelowprec 246 // module flag being set 247 // This flag is needed even if the module does not use 16-bit types because a 248 // corresponding debug module may include 16-bit types, and tools that use the 249 // debug module may expect it to have the same flags as the original 250 if (auto *NativeLowPrec = mdconst::extract_or_null<ConstantInt>( 251 M.getModuleFlag("dx.nativelowprec"))) 252 if (MMDI.ShaderModelVersion >= VersionTuple(6, 2)) 253 CSF.NativeLowPrecisionMode = NativeLowPrec->getValue().getBoolValue(); 254 255 // Set ResMayNotAlias to true if DXIL validator version < 1.8 and there 256 // are UAVs present globally. 257 if (CanSetResMayNotAlias && MMDI.ValidatorVersion < VersionTuple(1, 8)) 258 CSF.ResMayNotAlias = !DRM.uavs().empty(); 259 260 return CSF; 261 } 262 263 /// Construct ModuleShaderFlags for module Module M 264 void ModuleShaderFlags::initialize(Module &M, DXILResourceTypeMap &DRTM, 265 const DXILResourceMap &DRM, 266 const ModuleMetadataInfo &MMDI) { 267 268 CanSetResMayNotAlias = MMDI.DXILVersion >= VersionTuple(1, 7); 269 // The command line option -res-may-alias will set the dx.resmayalias module 270 // flag to 1, thereby disabling the ability to set the ResMayNotAlias flag 271 if (auto *ResMayAlias = mdconst::extract_or_null<ConstantInt>( 272 M.getModuleFlag("dx.resmayalias"))) 273 if (ResMayAlias->getValue().getBoolValue()) 274 CanSetResMayNotAlias = false; 275 276 ComputedShaderFlags GlobalSFMask = gatherGlobalModuleFlags(M, DRM, MMDI); 277 278 CallGraph CG(M); 279 280 // Compute Shader Flags Mask for all functions using post-order visit of SCC 281 // of the call graph. 282 for (scc_iterator<CallGraph *> SCCI = scc_begin(&CG); !SCCI.isAtEnd(); 283 ++SCCI) { 284 const std::vector<CallGraphNode *> &CurSCC = *SCCI; 285 286 // Union of shader masks of all functions in CurSCC 287 ComputedShaderFlags SCCSF; 288 // List of functions in CurSCC that are neither external nor declarations 289 // and hence whose flags are collected 290 SmallVector<Function *> CurSCCFuncs; 291 for (CallGraphNode *CGN : CurSCC) { 292 Function *F = CGN->getFunction(); 293 if (!F) 294 continue; 295 296 if (F->isDeclaration()) { 297 assert(!F->getName().starts_with("dx.op.") && 298 "DXIL Shader Flag analysis should not be run post-lowering."); 299 continue; 300 } 301 302 ComputedShaderFlags CSF = GlobalSFMask; 303 for (const auto &BB : *F) 304 for (const auto &I : BB) 305 updateFunctionFlags(CSF, I, DRTM, MMDI); 306 // Update combined shader flags mask for all functions in this SCC 307 SCCSF.merge(CSF); 308 309 CurSCCFuncs.push_back(F); 310 } 311 312 // Update combined shader flags mask for all functions of the module 313 CombinedSFMask.merge(SCCSF); 314 315 // Shader flags mask of each of the functions in an SCC of the call graph is 316 // the union of all functions in the SCC. Update shader flags masks of 317 // functions in CurSCC accordingly. This is trivially true if SCC contains 318 // one function. 319 for (Function *F : CurSCCFuncs) 320 // Merge SCCSF with that of F 321 FunctionFlags[F].merge(SCCSF); 322 } 323 } 324 325 void ComputedShaderFlags::print(raw_ostream &OS) const { 326 uint64_t FlagVal = (uint64_t) * this; 327 OS << formatv("; Shader Flags Value: {0:x8}\n;\n", FlagVal); 328 if (FlagVal == 0) 329 return; 330 OS << "; Note: shader requires additional functionality:\n"; 331 #define SHADER_FEATURE_FLAG(FeatureBit, DxilModuleNum, FlagName, Str) \ 332 if (FlagName) \ 333 (OS << ";").indent(7) << Str << "\n"; 334 #include "llvm/BinaryFormat/DXContainerConstants.def" 335 OS << "; Note: extra DXIL module flags:\n"; 336 #define DXIL_MODULE_FLAG(DxilModuleBit, FlagName, Str) \ 337 if (FlagName) \ 338 (OS << ";").indent(7) << Str << "\n"; 339 #include "llvm/BinaryFormat/DXContainerConstants.def" 340 OS << ";\n"; 341 } 342 343 /// Return the shader flags mask of the specified function Func. 344 const ComputedShaderFlags & 345 ModuleShaderFlags::getFunctionFlags(const Function *Func) const { 346 auto Iter = FunctionFlags.find(Func); 347 assert((Iter != FunctionFlags.end() && Iter->first == Func) && 348 "Get Shader Flags : No Shader Flags Mask exists for function"); 349 return Iter->second; 350 } 351 352 //===----------------------------------------------------------------------===// 353 // ShaderFlagsAnalysis and ShaderFlagsAnalysisPrinterPass 354 355 // Provide an explicit template instantiation for the static ID. 356 AnalysisKey ShaderFlagsAnalysis::Key; 357 358 ModuleShaderFlags ShaderFlagsAnalysis::run(Module &M, 359 ModuleAnalysisManager &AM) { 360 DXILResourceTypeMap &DRTM = AM.getResult<DXILResourceTypeAnalysis>(M); 361 DXILResourceMap &DRM = AM.getResult<DXILResourceAnalysis>(M); 362 const ModuleMetadataInfo MMDI = AM.getResult<DXILMetadataAnalysis>(M); 363 364 ModuleShaderFlags MSFI; 365 MSFI.initialize(M, DRTM, DRM, MMDI); 366 367 return MSFI; 368 } 369 370 PreservedAnalyses ShaderFlagsAnalysisPrinter::run(Module &M, 371 ModuleAnalysisManager &AM) { 372 const ModuleShaderFlags &FlagsInfo = AM.getResult<ShaderFlagsAnalysis>(M); 373 // Print description of combined shader flags for all module functions 374 OS << "; Combined Shader Flags for Module\n"; 375 FlagsInfo.getCombinedFlags().print(OS); 376 // Print shader flags mask for each of the module functions 377 OS << "; Shader Flags for Module Functions\n"; 378 for (const auto &F : M.getFunctionList()) { 379 if (F.isDeclaration()) 380 continue; 381 const ComputedShaderFlags &SFMask = FlagsInfo.getFunctionFlags(&F); 382 OS << formatv("; Function {0} : {1:x8}\n;\n", F.getName(), 383 (uint64_t)(SFMask)); 384 } 385 386 return PreservedAnalyses::all(); 387 } 388 389 //===----------------------------------------------------------------------===// 390 // ShaderFlagsAnalysis and ShaderFlagsAnalysisPrinterPass 391 392 bool ShaderFlagsAnalysisWrapper::runOnModule(Module &M) { 393 DXILResourceTypeMap &DRTM = 394 getAnalysis<DXILResourceTypeWrapperPass>().getResourceTypeMap(); 395 DXILResourceMap &DRM = 396 getAnalysis<DXILResourceWrapperPass>().getResourceMap(); 397 const ModuleMetadataInfo MMDI = 398 getAnalysis<DXILMetadataAnalysisWrapperPass>().getModuleMetadata(); 399 400 MSFI.initialize(M, DRTM, DRM, MMDI); 401 return false; 402 } 403 404 void ShaderFlagsAnalysisWrapper::getAnalysisUsage(AnalysisUsage &AU) const { 405 AU.setPreservesAll(); 406 AU.addRequiredTransitive<DXILResourceTypeWrapperPass>(); 407 AU.addRequiredTransitive<DXILResourceWrapperPass>(); 408 AU.addRequired<DXILMetadataAnalysisWrapperPass>(); 409 } 410 411 char ShaderFlagsAnalysisWrapper::ID = 0; 412 413 INITIALIZE_PASS_BEGIN(ShaderFlagsAnalysisWrapper, "dx-shader-flag-analysis", 414 "DXIL Shader Flag Analysis", true, true) 415 INITIALIZE_PASS_DEPENDENCY(DXILResourceTypeWrapperPass) 416 INITIALIZE_PASS_DEPENDENCY(DXILMetadataAnalysisWrapperPass) 417 INITIALIZE_PASS_END(ShaderFlagsAnalysisWrapper, "dx-shader-flag-analysis", 418 "DXIL Shader Flag Analysis", true, true) 419