xref: /freebsd/contrib/llvm-project/llvm/lib/CodeGen/MachineUniformityAnalysis.cpp (revision 59c8e88e72633afbc47a4ace0d2170d00d51f7dc)
1 //===- MachineUniformityAnalysis.cpp --------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "llvm/CodeGen/MachineUniformityAnalysis.h"
10 #include "llvm/ADT/GenericUniformityImpl.h"
11 #include "llvm/CodeGen/MachineCycleAnalysis.h"
12 #include "llvm/CodeGen/MachineDominators.h"
13 #include "llvm/CodeGen/MachineRegisterInfo.h"
14 #include "llvm/CodeGen/MachineSSAContext.h"
15 #include "llvm/CodeGen/TargetInstrInfo.h"
16 #include "llvm/InitializePasses.h"
17 
18 using namespace llvm;
19 
20 template <>
21 bool llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::hasDivergentDefs(
22     const MachineInstr &I) const {
23   for (auto &op : I.all_defs()) {
24     if (isDivergent(op.getReg()))
25       return true;
26   }
27   return false;
28 }
29 
30 template <>
31 bool llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::markDefsDivergent(
32     const MachineInstr &Instr) {
33   bool insertedDivergent = false;
34   const auto &MRI = F.getRegInfo();
35   const auto &RBI = *F.getSubtarget().getRegBankInfo();
36   const auto &TRI = *MRI.getTargetRegisterInfo();
37   for (auto &op : Instr.all_defs()) {
38     if (!op.getReg().isVirtual())
39       continue;
40     assert(!op.getSubReg());
41     if (TRI.isUniformReg(MRI, RBI, op.getReg()))
42       continue;
43     insertedDivergent |= markDivergent(op.getReg());
44   }
45   return insertedDivergent;
46 }
47 
48 template <>
49 void llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::initialize() {
50   const auto &InstrInfo = *F.getSubtarget().getInstrInfo();
51 
52   for (const MachineBasicBlock &block : F) {
53     for (const MachineInstr &instr : block) {
54       auto uniformity = InstrInfo.getInstructionUniformity(instr);
55       if (uniformity == InstructionUniformity::AlwaysUniform) {
56         addUniformOverride(instr);
57         continue;
58       }
59 
60       if (uniformity == InstructionUniformity::NeverUniform) {
61         markDivergent(instr);
62       }
63     }
64   }
65 }
66 
67 template <>
68 void llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::pushUsers(
69     Register Reg) {
70   assert(isDivergent(Reg));
71   const auto &RegInfo = F.getRegInfo();
72   for (MachineInstr &UserInstr : RegInfo.use_instructions(Reg)) {
73     markDivergent(UserInstr);
74   }
75 }
76 
77 template <>
78 void llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::pushUsers(
79     const MachineInstr &Instr) {
80   assert(!isAlwaysUniform(Instr));
81   if (Instr.isTerminator())
82     return;
83   for (const MachineOperand &op : Instr.all_defs()) {
84     auto Reg = op.getReg();
85     if (isDivergent(Reg))
86       pushUsers(Reg);
87   }
88 }
89 
90 template <>
91 bool llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::usesValueFromCycle(
92     const MachineInstr &I, const MachineCycle &DefCycle) const {
93   assert(!isAlwaysUniform(I));
94   for (auto &Op : I.operands()) {
95     if (!Op.isReg() || !Op.readsReg())
96       continue;
97     auto Reg = Op.getReg();
98 
99     // FIXME: Physical registers need to be properly checked instead of always
100     // returning true
101     if (Reg.isPhysical())
102       return true;
103 
104     auto *Def = F.getRegInfo().getVRegDef(Reg);
105     if (DefCycle.contains(Def->getParent()))
106       return true;
107   }
108   return false;
109 }
110 
111 template <>
112 void llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::
113     propagateTemporalDivergence(const MachineInstr &I,
114                                 const MachineCycle &DefCycle) {
115   const auto &RegInfo = F.getRegInfo();
116   for (auto &Op : I.all_defs()) {
117     if (!Op.getReg().isVirtual())
118       continue;
119     auto Reg = Op.getReg();
120     if (isDivergent(Reg))
121       continue;
122     for (MachineInstr &UserInstr : RegInfo.use_instructions(Reg)) {
123       if (DefCycle.contains(UserInstr.getParent()))
124         continue;
125       markDivergent(UserInstr);
126     }
127   }
128 }
129 
130 template <>
131 bool llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::isDivergentUse(
132     const MachineOperand &U) const {
133   if (!U.isReg())
134     return false;
135 
136   auto Reg = U.getReg();
137   if (isDivergent(Reg))
138     return true;
139 
140   const auto &RegInfo = F.getRegInfo();
141   auto *Def = RegInfo.getOneDef(Reg);
142   if (!Def)
143     return true;
144 
145   auto *DefInstr = Def->getParent();
146   auto *UseInstr = U.getParent();
147   return isTemporalDivergent(*UseInstr->getParent(), *DefInstr);
148 }
149 
150 // This ensures explicit instantiation of
151 // GenericUniformityAnalysisImpl::ImplDeleter::operator()
152 template class llvm::GenericUniformityInfo<MachineSSAContext>;
153 template struct llvm::GenericUniformityAnalysisImplDeleter<
154     llvm::GenericUniformityAnalysisImpl<MachineSSAContext>>;
155 
156 MachineUniformityInfo llvm::computeMachineUniformityInfo(
157     MachineFunction &F, const MachineCycleInfo &cycleInfo,
158     const MachineDomTree &domTree, bool HasBranchDivergence) {
159   assert(F.getRegInfo().isSSA() && "Expected to be run on SSA form!");
160   MachineUniformityInfo UI(F, domTree, cycleInfo);
161   if (HasBranchDivergence)
162     UI.compute();
163   return UI;
164 }
165 
166 namespace {
167 
168 /// Legacy analysis pass which computes a \ref MachineUniformityInfo.
169 class MachineUniformityAnalysisPass : public MachineFunctionPass {
170   MachineUniformityInfo UI;
171 
172 public:
173   static char ID;
174 
175   MachineUniformityAnalysisPass();
176 
177   MachineUniformityInfo &getUniformityInfo() { return UI; }
178   const MachineUniformityInfo &getUniformityInfo() const { return UI; }
179 
180   bool runOnMachineFunction(MachineFunction &F) override;
181   void getAnalysisUsage(AnalysisUsage &AU) const override;
182   void print(raw_ostream &OS, const Module *M = nullptr) const override;
183 
184   // TODO: verify analysis
185 };
186 
187 class MachineUniformityInfoPrinterPass : public MachineFunctionPass {
188 public:
189   static char ID;
190 
191   MachineUniformityInfoPrinterPass();
192 
193   bool runOnMachineFunction(MachineFunction &F) override;
194   void getAnalysisUsage(AnalysisUsage &AU) const override;
195 };
196 
197 } // namespace
198 
199 char MachineUniformityAnalysisPass::ID = 0;
200 
201 MachineUniformityAnalysisPass::MachineUniformityAnalysisPass()
202     : MachineFunctionPass(ID) {
203   initializeMachineUniformityAnalysisPassPass(*PassRegistry::getPassRegistry());
204 }
205 
206 INITIALIZE_PASS_BEGIN(MachineUniformityAnalysisPass, "machine-uniformity",
207                       "Machine Uniformity Info Analysis", true, true)
208 INITIALIZE_PASS_DEPENDENCY(MachineCycleInfoWrapperPass)
209 INITIALIZE_PASS_DEPENDENCY(MachineDominatorTree)
210 INITIALIZE_PASS_END(MachineUniformityAnalysisPass, "machine-uniformity",
211                     "Machine Uniformity Info Analysis", true, true)
212 
213 void MachineUniformityAnalysisPass::getAnalysisUsage(AnalysisUsage &AU) const {
214   AU.setPreservesAll();
215   AU.addRequired<MachineCycleInfoWrapperPass>();
216   AU.addRequired<MachineDominatorTree>();
217   MachineFunctionPass::getAnalysisUsage(AU);
218 }
219 
220 bool MachineUniformityAnalysisPass::runOnMachineFunction(MachineFunction &MF) {
221   auto &DomTree = getAnalysis<MachineDominatorTree>().getBase();
222   auto &CI = getAnalysis<MachineCycleInfoWrapperPass>().getCycleInfo();
223   // FIXME: Query TTI::hasBranchDivergence. -run-pass seems to end up with a
224   // default NoTTI
225   UI = computeMachineUniformityInfo(MF, CI, DomTree, true);
226   return false;
227 }
228 
229 void MachineUniformityAnalysisPass::print(raw_ostream &OS,
230                                           const Module *) const {
231   OS << "MachineUniformityInfo for function: " << UI.getFunction().getName()
232      << "\n";
233   UI.print(OS);
234 }
235 
236 char MachineUniformityInfoPrinterPass::ID = 0;
237 
238 MachineUniformityInfoPrinterPass::MachineUniformityInfoPrinterPass()
239     : MachineFunctionPass(ID) {
240   initializeMachineUniformityInfoPrinterPassPass(
241       *PassRegistry::getPassRegistry());
242 }
243 
244 INITIALIZE_PASS_BEGIN(MachineUniformityInfoPrinterPass,
245                       "print-machine-uniformity",
246                       "Print Machine Uniformity Info Analysis", true, true)
247 INITIALIZE_PASS_DEPENDENCY(MachineUniformityAnalysisPass)
248 INITIALIZE_PASS_END(MachineUniformityInfoPrinterPass,
249                     "print-machine-uniformity",
250                     "Print Machine Uniformity Info Analysis", true, true)
251 
252 void MachineUniformityInfoPrinterPass::getAnalysisUsage(
253     AnalysisUsage &AU) const {
254   AU.setPreservesAll();
255   AU.addRequired<MachineUniformityAnalysisPass>();
256   MachineFunctionPass::getAnalysisUsage(AU);
257 }
258 
259 bool MachineUniformityInfoPrinterPass::runOnMachineFunction(
260     MachineFunction &F) {
261   auto &UI = getAnalysis<MachineUniformityAnalysisPass>();
262   UI.print(errs());
263   return false;
264 }
265