1 //===- UniformityAnalysis.cpp ---------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "llvm/Analysis/UniformityAnalysis.h" 10 #include "llvm/ADT/GenericUniformityImpl.h" 11 #include "llvm/Analysis/CycleAnalysis.h" 12 #include "llvm/Analysis/TargetTransformInfo.h" 13 #include "llvm/IR/Dominators.h" 14 #include "llvm/IR/InstIterator.h" 15 #include "llvm/IR/Instructions.h" 16 #include "llvm/InitializePasses.h" 17 18 using namespace llvm; 19 20 template <> 21 bool llvm::GenericUniformityAnalysisImpl<SSAContext>::hasDivergentDefs( 22 const Instruction &I) const { 23 return isDivergent((const Value *)&I); 24 } 25 26 template <> 27 bool llvm::GenericUniformityAnalysisImpl<SSAContext>::markDefsDivergent( 28 const Instruction &Instr) { 29 return markDivergent(cast<Value>(&Instr)); 30 } 31 32 template <> void llvm::GenericUniformityAnalysisImpl<SSAContext>::initialize() { 33 for (auto &I : instructions(F)) { 34 if (TTI->isSourceOfDivergence(&I)) 35 markDivergent(I); 36 else if (TTI->isAlwaysUniform(&I)) 37 addUniformOverride(I); 38 } 39 for (auto &Arg : F.args()) { 40 if (TTI->isSourceOfDivergence(&Arg)) { 41 markDivergent(&Arg); 42 } 43 } 44 } 45 46 template <> 47 void llvm::GenericUniformityAnalysisImpl<SSAContext>::pushUsers( 48 const Value *V) { 49 for (const auto *User : V->users()) { 50 if (const auto *UserInstr = dyn_cast<const Instruction>(User)) { 51 markDivergent(*UserInstr); 52 } 53 } 54 } 55 56 template <> 57 void llvm::GenericUniformityAnalysisImpl<SSAContext>::pushUsers( 58 const Instruction &Instr) { 59 assert(!isAlwaysUniform(Instr)); 60 if (Instr.isTerminator()) 61 return; 62 pushUsers(cast<Value>(&Instr)); 63 } 64 65 template <> 66 bool llvm::GenericUniformityAnalysisImpl<SSAContext>::usesValueFromCycle( 67 const Instruction &I, const Cycle &DefCycle) const { 68 assert(!isAlwaysUniform(I)); 69 for (const Use &U : I.operands()) { 70 if (auto *I = dyn_cast<Instruction>(&U)) { 71 if (DefCycle.contains(I->getParent())) 72 return true; 73 } 74 } 75 return false; 76 } 77 78 template <> 79 void llvm::GenericUniformityAnalysisImpl< 80 SSAContext>::propagateTemporalDivergence(const Instruction &I, 81 const Cycle &DefCycle) { 82 for (auto *User : I.users()) { 83 auto *UserInstr = cast<Instruction>(User); 84 if (DefCycle.contains(UserInstr->getParent())) 85 continue; 86 markDivergent(*UserInstr); 87 recordTemporalDivergence(&I, UserInstr, &DefCycle); 88 } 89 } 90 91 template <> 92 bool llvm::GenericUniformityAnalysisImpl<SSAContext>::isDivergentUse( 93 const Use &U) const { 94 const auto *V = U.get(); 95 if (isDivergent(V)) 96 return true; 97 if (const auto *DefInstr = dyn_cast<Instruction>(V)) { 98 const auto *UseInstr = cast<Instruction>(U.getUser()); 99 return isTemporalDivergent(*UseInstr->getParent(), *DefInstr); 100 } 101 return false; 102 } 103 104 // This ensures explicit instantiation of 105 // GenericUniformityAnalysisImpl::ImplDeleter::operator() 106 template class llvm::GenericUniformityInfo<SSAContext>; 107 template struct llvm::GenericUniformityAnalysisImplDeleter< 108 llvm::GenericUniformityAnalysisImpl<SSAContext>>; 109 110 //===----------------------------------------------------------------------===// 111 // UniformityInfoAnalysis and related pass implementations 112 //===----------------------------------------------------------------------===// 113 114 llvm::UniformityInfo UniformityInfoAnalysis::run(Function &F, 115 FunctionAnalysisManager &FAM) { 116 auto &DT = FAM.getResult<DominatorTreeAnalysis>(F); 117 auto &TTI = FAM.getResult<TargetIRAnalysis>(F); 118 auto &CI = FAM.getResult<CycleAnalysis>(F); 119 UniformityInfo UI{DT, CI, &TTI}; 120 // Skip computation if we can assume everything is uniform. 121 if (TTI.hasBranchDivergence(&F)) 122 UI.compute(); 123 124 return UI; 125 } 126 127 AnalysisKey UniformityInfoAnalysis::Key; 128 129 UniformityInfoPrinterPass::UniformityInfoPrinterPass(raw_ostream &OS) 130 : OS(OS) {} 131 132 PreservedAnalyses UniformityInfoPrinterPass::run(Function &F, 133 FunctionAnalysisManager &AM) { 134 OS << "UniformityInfo for function '" << F.getName() << "':\n"; 135 AM.getResult<UniformityInfoAnalysis>(F).print(OS); 136 137 return PreservedAnalyses::all(); 138 } 139 140 //===----------------------------------------------------------------------===// 141 // UniformityInfoWrapperPass Implementation 142 //===----------------------------------------------------------------------===// 143 144 char UniformityInfoWrapperPass::ID = 0; 145 146 UniformityInfoWrapperPass::UniformityInfoWrapperPass() : FunctionPass(ID) {} 147 148 INITIALIZE_PASS_BEGIN(UniformityInfoWrapperPass, "uniformity", 149 "Uniformity Analysis", false, true) 150 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) 151 INITIALIZE_PASS_DEPENDENCY(CycleInfoWrapperPass) 152 INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) 153 INITIALIZE_PASS_END(UniformityInfoWrapperPass, "uniformity", 154 "Uniformity Analysis", false, true) 155 156 void UniformityInfoWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const { 157 AU.setPreservesAll(); 158 AU.addRequired<DominatorTreeWrapperPass>(); 159 AU.addRequiredTransitive<CycleInfoWrapperPass>(); 160 AU.addRequired<TargetTransformInfoWrapperPass>(); 161 } 162 163 bool UniformityInfoWrapperPass::runOnFunction(Function &F) { 164 auto &cycleInfo = getAnalysis<CycleInfoWrapperPass>().getResult(); 165 auto &domTree = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); 166 auto &targetTransformInfo = 167 getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); 168 169 m_function = &F; 170 m_uniformityInfo = UniformityInfo{domTree, cycleInfo, &targetTransformInfo}; 171 172 // Skip computation if we can assume everything is uniform. 173 if (targetTransformInfo.hasBranchDivergence(m_function)) 174 m_uniformityInfo.compute(); 175 176 return false; 177 } 178 179 void UniformityInfoWrapperPass::print(raw_ostream &OS, const Module *) const { 180 OS << "UniformityInfo for function '" << m_function->getName() << "':\n"; 181 } 182 183 void UniformityInfoWrapperPass::releaseMemory() { 184 m_uniformityInfo = UniformityInfo{}; 185 m_function = nullptr; 186 } 187