xref: /freebsd/contrib/llvm-project/llvm/lib/Analysis/UniformityAnalysis.cpp (revision 5f757f3ff9144b609b3c433dfd370cc6bdc191ad)
106c3fb27SDimitry Andric //===- UniformityAnalysis.cpp ---------------------------------------------===//
2bdd1243dSDimitry Andric //
3bdd1243dSDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4bdd1243dSDimitry Andric // See https://llvm.org/LICENSE.txt for license information.
5bdd1243dSDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6bdd1243dSDimitry Andric //
7bdd1243dSDimitry Andric //===----------------------------------------------------------------------===//
8bdd1243dSDimitry Andric 
9bdd1243dSDimitry Andric #include "llvm/Analysis/UniformityAnalysis.h"
10bdd1243dSDimitry Andric #include "llvm/ADT/GenericUniformityImpl.h"
11bdd1243dSDimitry Andric #include "llvm/Analysis/CycleAnalysis.h"
12bdd1243dSDimitry Andric #include "llvm/Analysis/TargetTransformInfo.h"
13bdd1243dSDimitry Andric #include "llvm/IR/Constants.h"
14bdd1243dSDimitry Andric #include "llvm/IR/Dominators.h"
15bdd1243dSDimitry Andric #include "llvm/IR/InstIterator.h"
16bdd1243dSDimitry Andric #include "llvm/IR/Instructions.h"
17bdd1243dSDimitry Andric #include "llvm/InitializePasses.h"
18bdd1243dSDimitry Andric 
19bdd1243dSDimitry Andric using namespace llvm;
20bdd1243dSDimitry Andric 
21bdd1243dSDimitry Andric template <>
22bdd1243dSDimitry Andric bool llvm::GenericUniformityAnalysisImpl<SSAContext>::hasDivergentDefs(
23bdd1243dSDimitry Andric     const Instruction &I) const {
24bdd1243dSDimitry Andric   return isDivergent((const Value *)&I);
25bdd1243dSDimitry Andric }
26bdd1243dSDimitry Andric 
27bdd1243dSDimitry Andric template <>
28bdd1243dSDimitry Andric bool llvm::GenericUniformityAnalysisImpl<SSAContext>::markDefsDivergent(
2906c3fb27SDimitry Andric     const Instruction &Instr) {
3006c3fb27SDimitry Andric   return markDivergent(cast<Value>(&Instr));
31bdd1243dSDimitry Andric }
32bdd1243dSDimitry Andric 
33bdd1243dSDimitry Andric template <> void llvm::GenericUniformityAnalysisImpl<SSAContext>::initialize() {
34bdd1243dSDimitry Andric   for (auto &I : instructions(F)) {
3506c3fb27SDimitry Andric     if (TTI->isSourceOfDivergence(&I))
36bdd1243dSDimitry Andric       markDivergent(I);
3706c3fb27SDimitry Andric     else if (TTI->isAlwaysUniform(&I))
38bdd1243dSDimitry Andric       addUniformOverride(I);
39bdd1243dSDimitry Andric   }
40bdd1243dSDimitry Andric   for (auto &Arg : F.args()) {
41bdd1243dSDimitry Andric     if (TTI->isSourceOfDivergence(&Arg)) {
42bdd1243dSDimitry Andric       markDivergent(&Arg);
43bdd1243dSDimitry Andric     }
44bdd1243dSDimitry Andric   }
45bdd1243dSDimitry Andric }
46bdd1243dSDimitry Andric 
47bdd1243dSDimitry Andric template <>
48bdd1243dSDimitry Andric void llvm::GenericUniformityAnalysisImpl<SSAContext>::pushUsers(
49bdd1243dSDimitry Andric     const Value *V) {
50bdd1243dSDimitry Andric   for (const auto *User : V->users()) {
5106c3fb27SDimitry Andric     if (const auto *UserInstr = dyn_cast<const Instruction>(User)) {
5206c3fb27SDimitry Andric       markDivergent(*UserInstr);
53bdd1243dSDimitry Andric     }
54bdd1243dSDimitry Andric   }
55bdd1243dSDimitry Andric }
56bdd1243dSDimitry Andric 
57bdd1243dSDimitry Andric template <>
58bdd1243dSDimitry Andric void llvm::GenericUniformityAnalysisImpl<SSAContext>::pushUsers(
59bdd1243dSDimitry Andric     const Instruction &Instr) {
60bdd1243dSDimitry Andric   assert(!isAlwaysUniform(Instr));
61bdd1243dSDimitry Andric   if (Instr.isTerminator())
62bdd1243dSDimitry Andric     return;
63bdd1243dSDimitry Andric   pushUsers(cast<Value>(&Instr));
64bdd1243dSDimitry Andric }
65bdd1243dSDimitry Andric 
66bdd1243dSDimitry Andric template <>
67bdd1243dSDimitry Andric bool llvm::GenericUniformityAnalysisImpl<SSAContext>::usesValueFromCycle(
68bdd1243dSDimitry Andric     const Instruction &I, const Cycle &DefCycle) const {
6906c3fb27SDimitry Andric   assert(!isAlwaysUniform(I));
70bdd1243dSDimitry Andric   for (const Use &U : I.operands()) {
71bdd1243dSDimitry Andric     if (auto *I = dyn_cast<Instruction>(&U)) {
72bdd1243dSDimitry Andric       if (DefCycle.contains(I->getParent()))
73bdd1243dSDimitry Andric         return true;
74bdd1243dSDimitry Andric     }
75bdd1243dSDimitry Andric   }
76bdd1243dSDimitry Andric   return false;
77bdd1243dSDimitry Andric }
78bdd1243dSDimitry Andric 
7906c3fb27SDimitry Andric template <>
8006c3fb27SDimitry Andric void llvm::GenericUniformityAnalysisImpl<
8106c3fb27SDimitry Andric     SSAContext>::propagateTemporalDivergence(const Instruction &I,
8206c3fb27SDimitry Andric                                              const Cycle &DefCycle) {
8306c3fb27SDimitry Andric   if (isDivergent(I))
8406c3fb27SDimitry Andric     return;
8506c3fb27SDimitry Andric   for (auto *User : I.users()) {
8606c3fb27SDimitry Andric     auto *UserInstr = cast<Instruction>(User);
8706c3fb27SDimitry Andric     if (DefCycle.contains(UserInstr->getParent()))
8806c3fb27SDimitry Andric       continue;
8906c3fb27SDimitry Andric     markDivergent(*UserInstr);
9006c3fb27SDimitry Andric   }
9106c3fb27SDimitry Andric }
9206c3fb27SDimitry Andric 
9306c3fb27SDimitry Andric template <>
9406c3fb27SDimitry Andric bool llvm::GenericUniformityAnalysisImpl<SSAContext>::isDivergentUse(
9506c3fb27SDimitry Andric     const Use &U) const {
9606c3fb27SDimitry Andric   const auto *V = U.get();
9706c3fb27SDimitry Andric   if (isDivergent(V))
9806c3fb27SDimitry Andric     return true;
9906c3fb27SDimitry Andric   if (const auto *DefInstr = dyn_cast<Instruction>(V)) {
10006c3fb27SDimitry Andric     const auto *UseInstr = cast<Instruction>(U.getUser());
10106c3fb27SDimitry Andric     return isTemporalDivergent(*UseInstr->getParent(), *DefInstr);
10206c3fb27SDimitry Andric   }
10306c3fb27SDimitry Andric   return false;
10406c3fb27SDimitry Andric }
10506c3fb27SDimitry Andric 
106bdd1243dSDimitry Andric // This ensures explicit instantiation of
107bdd1243dSDimitry Andric // GenericUniformityAnalysisImpl::ImplDeleter::operator()
108bdd1243dSDimitry Andric template class llvm::GenericUniformityInfo<SSAContext>;
109bdd1243dSDimitry Andric template struct llvm::GenericUniformityAnalysisImplDeleter<
110bdd1243dSDimitry Andric     llvm::GenericUniformityAnalysisImpl<SSAContext>>;
111bdd1243dSDimitry Andric 
112bdd1243dSDimitry Andric //===----------------------------------------------------------------------===//
113bdd1243dSDimitry Andric //  UniformityInfoAnalysis and related pass implementations
114bdd1243dSDimitry Andric //===----------------------------------------------------------------------===//
115bdd1243dSDimitry Andric 
116bdd1243dSDimitry Andric llvm::UniformityInfo UniformityInfoAnalysis::run(Function &F,
117bdd1243dSDimitry Andric                                                  FunctionAnalysisManager &FAM) {
118bdd1243dSDimitry Andric   auto &DT = FAM.getResult<DominatorTreeAnalysis>(F);
119bdd1243dSDimitry Andric   auto &TTI = FAM.getResult<TargetIRAnalysis>(F);
120bdd1243dSDimitry Andric   auto &CI = FAM.getResult<CycleAnalysis>(F);
121*5f757f3fSDimitry Andric   UniformityInfo UI{DT, CI, &TTI};
12206c3fb27SDimitry Andric   // Skip computation if we can assume everything is uniform.
12306c3fb27SDimitry Andric   if (TTI.hasBranchDivergence(&F))
12406c3fb27SDimitry Andric     UI.compute();
12506c3fb27SDimitry Andric 
12606c3fb27SDimitry Andric   return UI;
127bdd1243dSDimitry Andric }
128bdd1243dSDimitry Andric 
129bdd1243dSDimitry Andric AnalysisKey UniformityInfoAnalysis::Key;
130bdd1243dSDimitry Andric 
131bdd1243dSDimitry Andric UniformityInfoPrinterPass::UniformityInfoPrinterPass(raw_ostream &OS)
132bdd1243dSDimitry Andric     : OS(OS) {}
133bdd1243dSDimitry Andric 
134bdd1243dSDimitry Andric PreservedAnalyses UniformityInfoPrinterPass::run(Function &F,
135bdd1243dSDimitry Andric                                                  FunctionAnalysisManager &AM) {
136bdd1243dSDimitry Andric   OS << "UniformityInfo for function '" << F.getName() << "':\n";
137bdd1243dSDimitry Andric   AM.getResult<UniformityInfoAnalysis>(F).print(OS);
138bdd1243dSDimitry Andric 
139bdd1243dSDimitry Andric   return PreservedAnalyses::all();
140bdd1243dSDimitry Andric }
141bdd1243dSDimitry Andric 
142bdd1243dSDimitry Andric //===----------------------------------------------------------------------===//
143bdd1243dSDimitry Andric //  UniformityInfoWrapperPass Implementation
144bdd1243dSDimitry Andric //===----------------------------------------------------------------------===//
145bdd1243dSDimitry Andric 
146bdd1243dSDimitry Andric char UniformityInfoWrapperPass::ID = 0;
147bdd1243dSDimitry Andric 
148bdd1243dSDimitry Andric UniformityInfoWrapperPass::UniformityInfoWrapperPass() : FunctionPass(ID) {
149bdd1243dSDimitry Andric   initializeUniformityInfoWrapperPassPass(*PassRegistry::getPassRegistry());
150bdd1243dSDimitry Andric }
151bdd1243dSDimitry Andric 
15206c3fb27SDimitry Andric INITIALIZE_PASS_BEGIN(UniformityInfoWrapperPass, "uniformity",
15306c3fb27SDimitry Andric                       "Uniformity Analysis", true, true)
154bdd1243dSDimitry Andric INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
15506c3fb27SDimitry Andric INITIALIZE_PASS_DEPENDENCY(CycleInfoWrapperPass)
156bdd1243dSDimitry Andric INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
15706c3fb27SDimitry Andric INITIALIZE_PASS_END(UniformityInfoWrapperPass, "uniformity",
15806c3fb27SDimitry Andric                     "Uniformity Analysis", true, true)
159bdd1243dSDimitry Andric 
160bdd1243dSDimitry Andric void UniformityInfoWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const {
161bdd1243dSDimitry Andric   AU.setPreservesAll();
162bdd1243dSDimitry Andric   AU.addRequired<DominatorTreeWrapperPass>();
16306c3fb27SDimitry Andric   AU.addRequiredTransitive<CycleInfoWrapperPass>();
164bdd1243dSDimitry Andric   AU.addRequired<TargetTransformInfoWrapperPass>();
165bdd1243dSDimitry Andric }
166bdd1243dSDimitry Andric 
167bdd1243dSDimitry Andric bool UniformityInfoWrapperPass::runOnFunction(Function &F) {
168bdd1243dSDimitry Andric   auto &cycleInfo = getAnalysis<CycleInfoWrapperPass>().getResult();
169bdd1243dSDimitry Andric   auto &domTree = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
170bdd1243dSDimitry Andric   auto &targetTransformInfo =
171bdd1243dSDimitry Andric       getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
172bdd1243dSDimitry Andric 
173bdd1243dSDimitry Andric   m_function = &F;
174*5f757f3fSDimitry Andric   m_uniformityInfo = UniformityInfo{domTree, cycleInfo, &targetTransformInfo};
17506c3fb27SDimitry Andric 
17606c3fb27SDimitry Andric   // Skip computation if we can assume everything is uniform.
17706c3fb27SDimitry Andric   if (targetTransformInfo.hasBranchDivergence(m_function))
17806c3fb27SDimitry Andric     m_uniformityInfo.compute();
17906c3fb27SDimitry Andric 
180bdd1243dSDimitry Andric   return false;
181bdd1243dSDimitry Andric }
182bdd1243dSDimitry Andric 
183bdd1243dSDimitry Andric void UniformityInfoWrapperPass::print(raw_ostream &OS, const Module *) const {
184bdd1243dSDimitry Andric   OS << "UniformityInfo for function '" << m_function->getName() << "':\n";
185bdd1243dSDimitry Andric }
186bdd1243dSDimitry Andric 
187bdd1243dSDimitry Andric void UniformityInfoWrapperPass::releaseMemory() {
188bdd1243dSDimitry Andric   m_uniformityInfo = UniformityInfo{};
189bdd1243dSDimitry Andric   m_function = nullptr;
190bdd1243dSDimitry Andric }
191