xref: /freebsd/contrib/llvm-project/llvm/lib/Analysis/UniformityAnalysis.cpp (revision 2c2ec6bbc9cc7762a250ffe903bda6c2e44d25ff)
1 //===- UniformityAnalysis.cpp ---------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "llvm/Analysis/UniformityAnalysis.h"
10 #include "llvm/ADT/GenericUniformityImpl.h"
11 #include "llvm/Analysis/CycleAnalysis.h"
12 #include "llvm/Analysis/TargetTransformInfo.h"
13 #include "llvm/IR/Dominators.h"
14 #include "llvm/IR/InstIterator.h"
15 #include "llvm/IR/Instructions.h"
16 #include "llvm/InitializePasses.h"
17 
18 using namespace llvm;
19 
20 template <>
21 bool llvm::GenericUniformityAnalysisImpl<SSAContext>::hasDivergentDefs(
22     const Instruction &I) const {
23   return isDivergent((const Value *)&I);
24 }
25 
26 template <>
27 bool llvm::GenericUniformityAnalysisImpl<SSAContext>::markDefsDivergent(
28     const Instruction &Instr) {
29   return markDivergent(cast<Value>(&Instr));
30 }
31 
32 template <> void llvm::GenericUniformityAnalysisImpl<SSAContext>::initialize() {
33   for (auto &I : instructions(F)) {
34     if (TTI->isSourceOfDivergence(&I))
35       markDivergent(I);
36     else if (TTI->isAlwaysUniform(&I))
37       addUniformOverride(I);
38   }
39   for (auto &Arg : F.args()) {
40     if (TTI->isSourceOfDivergence(&Arg)) {
41       markDivergent(&Arg);
42     }
43   }
44 }
45 
46 template <>
47 void llvm::GenericUniformityAnalysisImpl<SSAContext>::pushUsers(
48     const Value *V) {
49   for (const auto *User : V->users()) {
50     if (const auto *UserInstr = dyn_cast<const Instruction>(User)) {
51       markDivergent(*UserInstr);
52     }
53   }
54 }
55 
56 template <>
57 void llvm::GenericUniformityAnalysisImpl<SSAContext>::pushUsers(
58     const Instruction &Instr) {
59   assert(!isAlwaysUniform(Instr));
60   if (Instr.isTerminator())
61     return;
62   pushUsers(cast<Value>(&Instr));
63 }
64 
65 template <>
66 bool llvm::GenericUniformityAnalysisImpl<SSAContext>::usesValueFromCycle(
67     const Instruction &I, const Cycle &DefCycle) const {
68   assert(!isAlwaysUniform(I));
69   for (const Use &U : I.operands()) {
70     if (auto *I = dyn_cast<Instruction>(&U)) {
71       if (DefCycle.contains(I->getParent()))
72         return true;
73     }
74   }
75   return false;
76 }
77 
78 template <>
79 void llvm::GenericUniformityAnalysisImpl<
80     SSAContext>::propagateTemporalDivergence(const Instruction &I,
81                                              const Cycle &DefCycle) {
82   for (auto *User : I.users()) {
83     auto *UserInstr = cast<Instruction>(User);
84     if (DefCycle.contains(UserInstr->getParent()))
85       continue;
86     markDivergent(*UserInstr);
87     recordTemporalDivergence(&I, UserInstr, &DefCycle);
88   }
89 }
90 
91 template <>
92 bool llvm::GenericUniformityAnalysisImpl<SSAContext>::isDivergentUse(
93     const Use &U) const {
94   const auto *V = U.get();
95   if (isDivergent(V))
96     return true;
97   if (const auto *DefInstr = dyn_cast<Instruction>(V)) {
98     const auto *UseInstr = cast<Instruction>(U.getUser());
99     return isTemporalDivergent(*UseInstr->getParent(), *DefInstr);
100   }
101   return false;
102 }
103 
// Explicit instantiations for SSAContext, the context used by the
// specializations above. Instantiating GenericUniformityInfo<SSAContext> and
// GenericUniformityAnalysisImplDeleter<GenericUniformityAnalysisImpl<SSAContext>>
// here ensures, among other things, that the deleter's operator() (referred
// to as GenericUniformityAnalysisImpl::ImplDeleter::operator()) is emitted in
// this translation unit.
template class llvm::GenericUniformityInfo<SSAContext>;
template struct llvm::GenericUniformityAnalysisImplDeleter<
    llvm::GenericUniformityAnalysisImpl<SSAContext>>;
109 
110 //===----------------------------------------------------------------------===//
111 //  UniformityInfoAnalysis and related pass implementations
112 //===----------------------------------------------------------------------===//
113 
114 llvm::UniformityInfo UniformityInfoAnalysis::run(Function &F,
115                                                  FunctionAnalysisManager &FAM) {
116   auto &DT = FAM.getResult<DominatorTreeAnalysis>(F);
117   auto &TTI = FAM.getResult<TargetIRAnalysis>(F);
118   auto &CI = FAM.getResult<CycleAnalysis>(F);
119   UniformityInfo UI{DT, CI, &TTI};
120   // Skip computation if we can assume everything is uniform.
121   if (TTI.hasBranchDivergence(&F))
122     UI.compute();
123 
124   return UI;
125 }
126 
// Definition of the static analysis key declared by UniformityInfoAnalysis
// for the new pass manager.
AnalysisKey UniformityInfoAnalysis::Key;
128 
129 UniformityInfoPrinterPass::UniformityInfoPrinterPass(raw_ostream &OS)
130     : OS(OS) {}
131 
132 PreservedAnalyses UniformityInfoPrinterPass::run(Function &F,
133                                                  FunctionAnalysisManager &AM) {
134   OS << "UniformityInfo for function '" << F.getName() << "':\n";
135   AM.getResult<UniformityInfoAnalysis>(F).print(OS);
136 
137   return PreservedAnalyses::all();
138 }
139 
140 //===----------------------------------------------------------------------===//
141 //  UniformityInfoWrapperPass Implementation
142 //===----------------------------------------------------------------------===//
143 
// Legacy pass identifier; passed to FunctionPass by the
// UniformityInfoWrapperPass constructor.
char UniformityInfoWrapperPass::ID = 0;
145 
146 UniformityInfoWrapperPass::UniformityInfoWrapperPass() : FunctionPass(ID) {}
147 
// Register the legacy wrapper pass under the name "uniformity" with the
// description "Uniformity Analysis", together with the analyses it depends on
// (dominator tree, cycle info, target transform info).
// NOTE(review): the trailing (false, true) arguments are presumably the
// CFG-only and is-analysis flags of INITIALIZE_PASS_BEGIN/END; confirm against
// PassSupport.h.
INITIALIZE_PASS_BEGIN(UniformityInfoWrapperPass, "uniformity",
                      "Uniformity Analysis", false, true)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(CycleInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
INITIALIZE_PASS_END(UniformityInfoWrapperPass, "uniformity",
                    "Uniformity Analysis", false, true)
155 
/// Declare that this pass preserves all analyses and requires the dominator
/// tree, cycle info and target transform info.
///
/// Cycle info is required *transitively*. NOTE(review): presumably because
/// the stored UniformityInfo keeps referring to the cycle info after
/// runOnFunction returns, so it must stay alive as long as this pass; confirm.
void UniformityInfoWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const {
  AU.setPreservesAll();
  AU.addRequired<DominatorTreeWrapperPass>();
  AU.addRequiredTransitive<CycleInfoWrapperPass>();
  AU.addRequired<TargetTransformInfoWrapperPass>();
}
162 
163 bool UniformityInfoWrapperPass::runOnFunction(Function &F) {
164   auto &cycleInfo = getAnalysis<CycleInfoWrapperPass>().getResult();
165   auto &domTree = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
166   auto &targetTransformInfo =
167       getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
168 
169   m_function = &F;
170   m_uniformityInfo = UniformityInfo{domTree, cycleInfo, &targetTransformInfo};
171 
172   // Skip computation if we can assume everything is uniform.
173   if (targetTransformInfo.hasBranchDivergence(m_function))
174     m_uniformityInfo.compute();
175 
176   return false;
177 }
178 
179 void UniformityInfoWrapperPass::print(raw_ostream &OS, const Module *) const {
180   OS << "UniformityInfo for function '" << m_function->getName() << "':\n";
181 }
182 
183 void UniformityInfoWrapperPass::releaseMemory() {
184   m_uniformityInfo = UniformityInfo{};
185   m_function = nullptr;
186 }
187