xref: /freebsd/contrib/llvm-project/llvm/lib/CodeGen/MachineUniformityAnalysis.cpp (revision 770cf0a5f02dc8983a89c6568d741fbc25baa999)
1 //===- MachineUniformityAnalysis.cpp --------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "llvm/CodeGen/MachineUniformityAnalysis.h"
10 #include "llvm/ADT/GenericUniformityImpl.h"
11 #include "llvm/Analysis/TargetTransformInfo.h"
12 #include "llvm/CodeGen/MachineCycleAnalysis.h"
13 #include "llvm/CodeGen/MachineDominators.h"
14 #include "llvm/CodeGen/MachineRegisterInfo.h"
15 #include "llvm/CodeGen/MachineSSAContext.h"
16 #include "llvm/CodeGen/TargetInstrInfo.h"
17 #include "llvm/InitializePasses.h"
18 
19 using namespace llvm;
20 
21 template <>
22 bool llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::hasDivergentDefs(
23     const MachineInstr &I) const {
24   for (auto &op : I.all_defs()) {
25     if (isDivergent(op.getReg()))
26       return true;
27   }
28   return false;
29 }
30 
31 template <>
32 bool llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::markDefsDivergent(
33     const MachineInstr &Instr) {
34   bool insertedDivergent = false;
35   const auto &MRI = F.getRegInfo();
36   const auto &RBI = *F.getSubtarget().getRegBankInfo();
37   const auto &TRI = *MRI.getTargetRegisterInfo();
38   for (auto &op : Instr.all_defs()) {
39     if (!op.getReg().isVirtual())
40       continue;
41     assert(!op.getSubReg());
42     if (TRI.isUniformReg(MRI, RBI, op.getReg()))
43       continue;
44     insertedDivergent |= markDivergent(op.getReg());
45   }
46   return insertedDivergent;
47 }
48 
49 template <>
50 void llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::initialize() {
51   const auto &InstrInfo = *F.getSubtarget().getInstrInfo();
52 
53   for (const MachineBasicBlock &block : F) {
54     for (const MachineInstr &instr : block) {
55       auto uniformity = InstrInfo.getInstructionUniformity(instr);
56       if (uniformity == InstructionUniformity::AlwaysUniform) {
57         addUniformOverride(instr);
58         continue;
59       }
60 
61       if (uniformity == InstructionUniformity::NeverUniform) {
62         markDivergent(instr);
63       }
64     }
65   }
66 }
67 
68 template <>
69 void llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::pushUsers(
70     Register Reg) {
71   assert(isDivergent(Reg));
72   const auto &RegInfo = F.getRegInfo();
73   for (MachineInstr &UserInstr : RegInfo.use_instructions(Reg)) {
74     markDivergent(UserInstr);
75   }
76 }
77 
78 template <>
79 void llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::pushUsers(
80     const MachineInstr &Instr) {
81   assert(!isAlwaysUniform(Instr));
82   if (Instr.isTerminator())
83     return;
84   for (const MachineOperand &op : Instr.all_defs()) {
85     auto Reg = op.getReg();
86     if (isDivergent(Reg))
87       pushUsers(Reg);
88   }
89 }
90 
91 template <>
92 bool llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::usesValueFromCycle(
93     const MachineInstr &I, const MachineCycle &DefCycle) const {
94   assert(!isAlwaysUniform(I));
95   for (auto &Op : I.operands()) {
96     if (!Op.isReg() || !Op.readsReg())
97       continue;
98     auto Reg = Op.getReg();
99 
100     // FIXME: Physical registers need to be properly checked instead of always
101     // returning true
102     if (Reg.isPhysical())
103       return true;
104 
105     auto *Def = F.getRegInfo().getVRegDef(Reg);
106     if (DefCycle.contains(Def->getParent()))
107       return true;
108   }
109   return false;
110 }
111 
112 template <>
113 void llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::
114     propagateTemporalDivergence(const MachineInstr &I,
115                                 const MachineCycle &DefCycle) {
116   const auto &RegInfo = F.getRegInfo();
117   for (auto &Op : I.all_defs()) {
118     if (!Op.getReg().isVirtual())
119       continue;
120     auto Reg = Op.getReg();
121     for (MachineInstr &UserInstr : RegInfo.use_instructions(Reg)) {
122       if (DefCycle.contains(UserInstr.getParent()))
123         continue;
124       markDivergent(UserInstr);
125 
126       recordTemporalDivergence(Reg, &UserInstr, &DefCycle);
127     }
128   }
129 }
130 
131 template <>
132 bool llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::isDivergentUse(
133     const MachineOperand &U) const {
134   if (!U.isReg())
135     return false;
136 
137   auto Reg = U.getReg();
138   if (isDivergent(Reg))
139     return true;
140 
141   const auto &RegInfo = F.getRegInfo();
142   auto *Def = RegInfo.getOneDef(Reg);
143   if (!Def)
144     return true;
145 
146   auto *DefInstr = Def->getParent();
147   auto *UseInstr = U.getParent();
148   return isTemporalDivergent(*UseInstr->getParent(), *DefInstr);
149 }
150 
151 // This ensures explicit instantiation of
152 // GenericUniformityAnalysisImpl::ImplDeleter::operator()
153 template class llvm::GenericUniformityInfo<MachineSSAContext>;
154 template struct llvm::GenericUniformityAnalysisImplDeleter<
155     llvm::GenericUniformityAnalysisImpl<MachineSSAContext>>;
156 
157 MachineUniformityInfo llvm::computeMachineUniformityInfo(
158     MachineFunction &F, const MachineCycleInfo &cycleInfo,
159     const MachineDominatorTree &domTree, bool HasBranchDivergence) {
160   assert(F.getRegInfo().isSSA() && "Expected to be run on SSA form!");
161   MachineUniformityInfo UI(domTree, cycleInfo);
162   if (HasBranchDivergence)
163     UI.compute();
164   return UI;
165 }
166 
167 namespace {
168 
169 class MachineUniformityInfoPrinterPass : public MachineFunctionPass {
170 public:
171   static char ID;
172 
173   MachineUniformityInfoPrinterPass();
174 
175   bool runOnMachineFunction(MachineFunction &F) override;
176   void getAnalysisUsage(AnalysisUsage &AU) const override;
177 };
178 
179 } // namespace
180 
181 AnalysisKey MachineUniformityAnalysis::Key;
182 
183 MachineUniformityAnalysis::Result
184 MachineUniformityAnalysis::run(MachineFunction &MF,
185                                MachineFunctionAnalysisManager &MFAM) {
186   auto &DomTree = MFAM.getResult<MachineDominatorTreeAnalysis>(MF);
187   auto &CI = MFAM.getResult<MachineCycleAnalysis>(MF);
188   auto &FAM = MFAM.getResult<FunctionAnalysisManagerMachineFunctionProxy>(MF)
189                   .getManager();
190   auto &F = MF.getFunction();
191   auto &TTI = FAM.getResult<TargetIRAnalysis>(F);
192   return computeMachineUniformityInfo(MF, CI, DomTree,
193                                       TTI.hasBranchDivergence(&F));
194 }
195 
196 PreservedAnalyses
197 MachineUniformityPrinterPass::run(MachineFunction &MF,
198                                   MachineFunctionAnalysisManager &MFAM) {
199   auto &MUI = MFAM.getResult<MachineUniformityAnalysis>(MF);
200   OS << "MachineUniformityInfo for function: ";
201   MF.getFunction().printAsOperand(OS, /*PrintType=*/false);
202   OS << '\n';
203   MUI.print(OS);
204   return PreservedAnalyses::all();
205 }
206 
207 char MachineUniformityAnalysisPass::ID = 0;
208 
209 MachineUniformityAnalysisPass::MachineUniformityAnalysisPass()
210     : MachineFunctionPass(ID) {
211   initializeMachineUniformityAnalysisPassPass(*PassRegistry::getPassRegistry());
212 }
213 
214 INITIALIZE_PASS_BEGIN(MachineUniformityAnalysisPass, "machine-uniformity",
215                       "Machine Uniformity Info Analysis", false, true)
216 INITIALIZE_PASS_DEPENDENCY(MachineCycleInfoWrapperPass)
217 INITIALIZE_PASS_DEPENDENCY(MachineDominatorTreeWrapperPass)
218 INITIALIZE_PASS_END(MachineUniformityAnalysisPass, "machine-uniformity",
219                     "Machine Uniformity Info Analysis", false, true)
220 
221 void MachineUniformityAnalysisPass::getAnalysisUsage(AnalysisUsage &AU) const {
222   AU.setPreservesAll();
223   AU.addRequiredTransitive<MachineCycleInfoWrapperPass>();
224   AU.addRequired<MachineDominatorTreeWrapperPass>();
225   MachineFunctionPass::getAnalysisUsage(AU);
226 }
227 
228 bool MachineUniformityAnalysisPass::runOnMachineFunction(MachineFunction &MF) {
229   auto &DomTree = getAnalysis<MachineDominatorTreeWrapperPass>().getDomTree();
230   auto &CI = getAnalysis<MachineCycleInfoWrapperPass>().getCycleInfo();
231   // FIXME: Query TTI::hasBranchDivergence. -run-pass seems to end up with a
232   // default NoTTI
233   UI = computeMachineUniformityInfo(MF, CI, DomTree, true);
234   return false;
235 }
236 
237 void MachineUniformityAnalysisPass::print(raw_ostream &OS,
238                                           const Module *) const {
239   OS << "MachineUniformityInfo for function: ";
240   UI.getFunction().getFunction().printAsOperand(OS, /*PrintType=*/false);
241   OS << '\n';
242   UI.print(OS);
243 }
244 
245 char MachineUniformityInfoPrinterPass::ID = 0;
246 
247 MachineUniformityInfoPrinterPass::MachineUniformityInfoPrinterPass()
248     : MachineFunctionPass(ID) {
249   initializeMachineUniformityInfoPrinterPassPass(
250       *PassRegistry::getPassRegistry());
251 }
252 
253 INITIALIZE_PASS_BEGIN(MachineUniformityInfoPrinterPass,
254                       "print-machine-uniformity",
255                       "Print Machine Uniformity Info Analysis", true, true)
256 INITIALIZE_PASS_DEPENDENCY(MachineUniformityAnalysisPass)
257 INITIALIZE_PASS_END(MachineUniformityInfoPrinterPass,
258                     "print-machine-uniformity",
259                     "Print Machine Uniformity Info Analysis", true, true)
260 
261 void MachineUniformityInfoPrinterPass::getAnalysisUsage(
262     AnalysisUsage &AU) const {
263   AU.setPreservesAll();
264   AU.addRequired<MachineUniformityAnalysisPass>();
265   MachineFunctionPass::getAnalysisUsage(AU);
266 }
267 
268 bool MachineUniformityInfoPrinterPass::runOnMachineFunction(
269     MachineFunction &F) {
270   auto &UI = getAnalysis<MachineUniformityAnalysisPass>();
271   UI.print(errs());
272   return false;
273 }
274