1 //===- MachineUniformityAnalysis.cpp --------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "llvm/CodeGen/MachineUniformityAnalysis.h"
10 #include "llvm/ADT/GenericUniformityImpl.h"
11 #include "llvm/Analysis/TargetTransformInfo.h"
12 #include "llvm/CodeGen/MachineCycleAnalysis.h"
13 #include "llvm/CodeGen/MachineDominators.h"
14 #include "llvm/CodeGen/MachineRegisterInfo.h"
15 #include "llvm/CodeGen/MachineSSAContext.h"
16 #include "llvm/CodeGen/TargetInstrInfo.h"
17 #include "llvm/InitializePasses.h"
18
19 using namespace llvm;
20
21 template <>
hasDivergentDefs(const MachineInstr & I) const22 bool llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::hasDivergentDefs(
23 const MachineInstr &I) const {
24 for (auto &op : I.all_defs()) {
25 if (isDivergent(op.getReg()))
26 return true;
27 }
28 return false;
29 }
30
31 template <>
markDefsDivergent(const MachineInstr & Instr)32 bool llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::markDefsDivergent(
33 const MachineInstr &Instr) {
34 bool insertedDivergent = false;
35 const auto &MRI = F.getRegInfo();
36 const auto &RBI = *F.getSubtarget().getRegBankInfo();
37 const auto &TRI = *MRI.getTargetRegisterInfo();
38 for (auto &op : Instr.all_defs()) {
39 if (!op.getReg().isVirtual())
40 continue;
41 assert(!op.getSubReg());
42 if (TRI.isUniformReg(MRI, RBI, op.getReg()))
43 continue;
44 insertedDivergent |= markDivergent(op.getReg());
45 }
46 return insertedDivergent;
47 }
48
49 template <>
initialize()50 void llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::initialize() {
51 const auto &InstrInfo = *F.getSubtarget().getInstrInfo();
52
53 for (const MachineBasicBlock &block : F) {
54 for (const MachineInstr &instr : block) {
55 auto uniformity = InstrInfo.getInstructionUniformity(instr);
56 if (uniformity == InstructionUniformity::AlwaysUniform) {
57 addUniformOverride(instr);
58 continue;
59 }
60
61 if (uniformity == InstructionUniformity::NeverUniform) {
62 markDivergent(instr);
63 }
64 }
65 }
66 }
67
68 template <>
pushUsers(Register Reg)69 void llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::pushUsers(
70 Register Reg) {
71 assert(isDivergent(Reg));
72 const auto &RegInfo = F.getRegInfo();
73 for (MachineInstr &UserInstr : RegInfo.use_instructions(Reg)) {
74 markDivergent(UserInstr);
75 }
76 }
77
78 template <>
pushUsers(const MachineInstr & Instr)79 void llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::pushUsers(
80 const MachineInstr &Instr) {
81 assert(!isAlwaysUniform(Instr));
82 if (Instr.isTerminator())
83 return;
84 for (const MachineOperand &op : Instr.all_defs()) {
85 auto Reg = op.getReg();
86 if (isDivergent(Reg))
87 pushUsers(Reg);
88 }
89 }
90
91 template <>
usesValueFromCycle(const MachineInstr & I,const MachineCycle & DefCycle) const92 bool llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::usesValueFromCycle(
93 const MachineInstr &I, const MachineCycle &DefCycle) const {
94 assert(!isAlwaysUniform(I));
95 for (auto &Op : I.operands()) {
96 if (!Op.isReg() || !Op.readsReg())
97 continue;
98 auto Reg = Op.getReg();
99
100 // FIXME: Physical registers need to be properly checked instead of always
101 // returning true
102 if (Reg.isPhysical())
103 return true;
104
105 auto *Def = F.getRegInfo().getVRegDef(Reg);
106 if (DefCycle.contains(Def->getParent()))
107 return true;
108 }
109 return false;
110 }
111
112 template <>
113 void llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::
propagateTemporalDivergence(const MachineInstr & I,const MachineCycle & DefCycle)114 propagateTemporalDivergence(const MachineInstr &I,
115 const MachineCycle &DefCycle) {
116 const auto &RegInfo = F.getRegInfo();
117 for (auto &Op : I.all_defs()) {
118 if (!Op.getReg().isVirtual())
119 continue;
120 auto Reg = Op.getReg();
121 for (MachineInstr &UserInstr : RegInfo.use_instructions(Reg)) {
122 if (DefCycle.contains(UserInstr.getParent()))
123 continue;
124 markDivergent(UserInstr);
125
126 recordTemporalDivergence(Reg, &UserInstr, &DefCycle);
127 }
128 }
129 }
130
131 template <>
isDivergentUse(const MachineOperand & U) const132 bool llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::isDivergentUse(
133 const MachineOperand &U) const {
134 if (!U.isReg())
135 return false;
136
137 auto Reg = U.getReg();
138 if (isDivergent(Reg))
139 return true;
140
141 const auto &RegInfo = F.getRegInfo();
142 auto *Def = RegInfo.getOneDef(Reg);
143 if (!Def)
144 return true;
145
146 auto *DefInstr = Def->getParent();
147 auto *UseInstr = U.getParent();
148 return isTemporalDivergent(*UseInstr->getParent(), *DefInstr);
149 }
150
// Explicitly instantiate GenericUniformityInfo for machine IR. The second
// instantiation ensures that
// GenericUniformityAnalysisImplDeleter<...>::operator() (the deleter used by
// GenericUniformityAnalysisImpl) is emitted in this translation unit.
template class llvm::GenericUniformityInfo<MachineSSAContext>;
template struct llvm::GenericUniformityAnalysisImplDeleter<
    llvm::GenericUniformityAnalysisImpl<MachineSSAContext>>;
156
computeMachineUniformityInfo(MachineFunction & F,const MachineCycleInfo & cycleInfo,const MachineDominatorTree & domTree,bool HasBranchDivergence)157 MachineUniformityInfo llvm::computeMachineUniformityInfo(
158 MachineFunction &F, const MachineCycleInfo &cycleInfo,
159 const MachineDominatorTree &domTree, bool HasBranchDivergence) {
160 assert(F.getRegInfo().isSSA() && "Expected to be run on SSA form!");
161 MachineUniformityInfo UI(domTree, cycleInfo);
162 if (HasBranchDivergence)
163 UI.compute();
164 return UI;
165 }
166
namespace {

/// Legacy pass that prints the uniformity info computed by
/// MachineUniformityAnalysisPass to errs(). Its runOnMachineFunction always
/// returns false, so it never reports modifying the function.
class MachineUniformityInfoPrinterPass : public MachineFunctionPass {
public:
  static char ID;

  MachineUniformityInfoPrinterPass();

  bool runOnMachineFunction(MachineFunction &F) override;
  void getAnalysisUsage(AnalysisUsage &AU) const override;
};

} // namespace
180
/// Identity key used by the new pass manager to look up this analysis.
AnalysisKey MachineUniformityAnalysis::Key;
182
183 MachineUniformityAnalysis::Result
run(MachineFunction & MF,MachineFunctionAnalysisManager & MFAM)184 MachineUniformityAnalysis::run(MachineFunction &MF,
185 MachineFunctionAnalysisManager &MFAM) {
186 auto &DomTree = MFAM.getResult<MachineDominatorTreeAnalysis>(MF);
187 auto &CI = MFAM.getResult<MachineCycleAnalysis>(MF);
188 auto &FAM = MFAM.getResult<FunctionAnalysisManagerMachineFunctionProxy>(MF)
189 .getManager();
190 auto &F = MF.getFunction();
191 auto &TTI = FAM.getResult<TargetIRAnalysis>(F);
192 return computeMachineUniformityInfo(MF, CI, DomTree,
193 TTI.hasBranchDivergence(&F));
194 }
195
196 PreservedAnalyses
run(MachineFunction & MF,MachineFunctionAnalysisManager & MFAM)197 MachineUniformityPrinterPass::run(MachineFunction &MF,
198 MachineFunctionAnalysisManager &MFAM) {
199 auto &MUI = MFAM.getResult<MachineUniformityAnalysis>(MF);
200 OS << "MachineUniformityInfo for function: ";
201 MF.getFunction().printAsOperand(OS, /*PrintType=*/false);
202 OS << '\n';
203 MUI.print(OS);
204 return PreservedAnalyses::all();
205 }
206
/// Legacy pass identifier; the legacy pass manager identifies the pass by
/// this variable's address.
char MachineUniformityAnalysisPass::ID = 0;
208
MachineUniformityAnalysisPass()209 MachineUniformityAnalysisPass::MachineUniformityAnalysisPass()
210 : MachineFunctionPass(ID) {
211 initializeMachineUniformityAnalysisPassPass(*PassRegistry::getPassRegistry());
212 }
213
// Register the legacy analysis under "machine-uniformity" together with the
// analyses it depends on. The trailing arguments are (cfg-only = false,
// is-analysis = true).
INITIALIZE_PASS_BEGIN(MachineUniformityAnalysisPass, "machine-uniformity",
                      "Machine Uniformity Info Analysis", false, true)
INITIALIZE_PASS_DEPENDENCY(MachineCycleInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(MachineDominatorTreeWrapperPass)
INITIALIZE_PASS_END(MachineUniformityAnalysisPass, "machine-uniformity",
                    "Machine Uniformity Info Analysis", false, true)
220
/// The pass only computes UI and never modifies the function, so it preserves
/// everything. Cycle info is required transitively.
/// NOTE(review): presumably this is because the stored UI keeps referring to
/// the cycle info — confirm. The dominator tree is a plain requirement.
void MachineUniformityAnalysisPass::getAnalysisUsage(AnalysisUsage &AU) const {
  AU.setPreservesAll();
  AU.addRequiredTransitive<MachineCycleInfoWrapperPass>();
  AU.addRequired<MachineDominatorTreeWrapperPass>();
  MachineFunctionPass::getAnalysisUsage(AU);
}
227
runOnMachineFunction(MachineFunction & MF)228 bool MachineUniformityAnalysisPass::runOnMachineFunction(MachineFunction &MF) {
229 auto &DomTree = getAnalysis<MachineDominatorTreeWrapperPass>().getDomTree();
230 auto &CI = getAnalysis<MachineCycleInfoWrapperPass>().getCycleInfo();
231 // FIXME: Query TTI::hasBranchDivergence. -run-pass seems to end up with a
232 // default NoTTI
233 UI = computeMachineUniformityInfo(MF, CI, DomTree, true);
234 return false;
235 }
236
print(raw_ostream & OS,const Module *) const237 void MachineUniformityAnalysisPass::print(raw_ostream &OS,
238 const Module *) const {
239 OS << "MachineUniformityInfo for function: ";
240 UI.getFunction().getFunction().printAsOperand(OS, /*PrintType=*/false);
241 OS << '\n';
242 UI.print(OS);
243 }
244
/// Legacy pass identifier for the printer pass; the legacy pass manager
/// identifies the pass by this variable's address.
char MachineUniformityInfoPrinterPass::ID = 0;
246
MachineUniformityInfoPrinterPass()247 MachineUniformityInfoPrinterPass::MachineUniformityInfoPrinterPass()
248 : MachineFunctionPass(ID) {
249 initializeMachineUniformityInfoPrinterPassPass(
250 *PassRegistry::getPassRegistry());
251 }
252
// Register the legacy printer under "print-machine-uniformity". It depends on
// MachineUniformityAnalysisPass. The trailing arguments are
// (cfg-only = true, is-analysis = true).
INITIALIZE_PASS_BEGIN(MachineUniformityInfoPrinterPass,
                      "print-machine-uniformity",
                      "Print Machine Uniformity Info Analysis", true, true)
INITIALIZE_PASS_DEPENDENCY(MachineUniformityAnalysisPass)
INITIALIZE_PASS_END(MachineUniformityInfoPrinterPass,
                    "print-machine-uniformity",
                    "Print Machine Uniformity Info Analysis", true, true)
260
/// The printer only reads results, so it preserves everything. It requires
/// the uniformity analysis it prints.
void MachineUniformityInfoPrinterPass::getAnalysisUsage(
    AnalysisUsage &AU) const {
  AU.setPreservesAll();
  AU.addRequired<MachineUniformityAnalysisPass>();
  MachineFunctionPass::getAnalysisUsage(AU);
}
267
runOnMachineFunction(MachineFunction & F)268 bool MachineUniformityInfoPrinterPass::runOnMachineFunction(
269 MachineFunction &F) {
270 auto &UI = getAnalysis<MachineUniformityAnalysisPass>();
271 UI.print(errs());
272 return false;
273 }
274