1 //===- MachineUniformityAnalysis.cpp --------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "llvm/CodeGen/MachineUniformityAnalysis.h" 10 #include "llvm/ADT/GenericUniformityImpl.h" 11 #include "llvm/CodeGen/MachineCycleAnalysis.h" 12 #include "llvm/CodeGen/MachineDominators.h" 13 #include "llvm/CodeGen/MachineRegisterInfo.h" 14 #include "llvm/CodeGen/MachineSSAContext.h" 15 #include "llvm/CodeGen/TargetInstrInfo.h" 16 #include "llvm/InitializePasses.h" 17 18 using namespace llvm; 19 20 template <> 21 bool llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::hasDivergentDefs( 22 const MachineInstr &I) const { 23 for (auto &op : I.all_defs()) { 24 if (isDivergent(op.getReg())) 25 return true; 26 } 27 return false; 28 } 29 30 template <> 31 bool llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::markDefsDivergent( 32 const MachineInstr &Instr) { 33 bool insertedDivergent = false; 34 const auto &MRI = F.getRegInfo(); 35 const auto &RBI = *F.getSubtarget().getRegBankInfo(); 36 const auto &TRI = *MRI.getTargetRegisterInfo(); 37 for (auto &op : Instr.all_defs()) { 38 if (!op.getReg().isVirtual()) 39 continue; 40 assert(!op.getSubReg()); 41 if (TRI.isUniformReg(MRI, RBI, op.getReg())) 42 continue; 43 insertedDivergent |= markDivergent(op.getReg()); 44 } 45 return insertedDivergent; 46 } 47 48 template <> 49 void llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::initialize() { 50 const auto &InstrInfo = *F.getSubtarget().getInstrInfo(); 51 52 for (const MachineBasicBlock &block : F) { 53 for (const MachineInstr &instr : block) { 54 auto uniformity = InstrInfo.getInstructionUniformity(instr); 55 if (uniformity == InstructionUniformity::AlwaysUniform) { 56 addUniformOverride(instr); 57 continue; 58 } 59 60 if (uniformity == InstructionUniformity::NeverUniform) { 61 markDivergent(instr); 62 } 63 } 64 } 65 } 66 67 template <> 68 void llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::pushUsers( 69 Register Reg) { 70 assert(isDivergent(Reg)); 71 const auto &RegInfo = F.getRegInfo(); 72 for (MachineInstr &UserInstr : RegInfo.use_instructions(Reg)) { 73 markDivergent(UserInstr); 74 } 75 } 76 77 template <> 78 void llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::pushUsers( 79 const MachineInstr &Instr) { 80 assert(!isAlwaysUniform(Instr)); 81 if (Instr.isTerminator()) 82 return; 83 for (const MachineOperand &op : Instr.all_defs()) { 84 auto Reg = op.getReg(); 85 if (isDivergent(Reg)) 86 pushUsers(Reg); 87 } 88 } 89 90 template <> 91 bool llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::usesValueFromCycle( 92 const MachineInstr &I, const MachineCycle &DefCycle) const { 93 assert(!isAlwaysUniform(I)); 94 for (auto &Op : I.operands()) { 95 if (!Op.isReg() || !Op.readsReg()) 96 continue; 97 auto Reg = Op.getReg(); 98 99 // FIXME: Physical registers need to be properly checked instead of always 100 // returning true 101 if (Reg.isPhysical()) 102 return true; 103 104 auto *Def = F.getRegInfo().getVRegDef(Reg); 105 if (DefCycle.contains(Def->getParent())) 106 return true; 107 } 108 return false; 109 } 110 111 template <> 112 void llvm::GenericUniformityAnalysisImpl<MachineSSAContext>:: 113 propagateTemporalDivergence(const MachineInstr &I, 114 const MachineCycle &DefCycle) { 115 const auto &RegInfo = F.getRegInfo(); 116 for (auto &Op : I.all_defs()) { 117 if (!Op.getReg().isVirtual()) 118 continue; 119 auto Reg = Op.getReg(); 120 if (isDivergent(Reg)) 121 continue; 122 for (MachineInstr &UserInstr : RegInfo.use_instructions(Reg)) { 123 if (DefCycle.contains(UserInstr.getParent())) 124 continue; 125 markDivergent(UserInstr); 126 } 127 } 128 } 129 130 template <> 131 bool llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::isDivergentUse( 132 const MachineOperand &U) const { 133 if (!U.isReg()) 134 return false; 135 136 auto Reg = U.getReg(); 137 if (isDivergent(Reg)) 138 return true; 139 140 const auto &RegInfo = F.getRegInfo(); 141 auto *Def = RegInfo.getOneDef(Reg); 142 if (!Def) 143 return true; 144 145 auto *DefInstr = Def->getParent(); 146 auto *UseInstr = U.getParent(); 147 return isTemporalDivergent(*UseInstr->getParent(), *DefInstr); 148 } 149 150 // This ensures explicit instantiation of 151 // GenericUniformityAnalysisImpl::ImplDeleter::operator() 152 template class llvm::GenericUniformityInfo<MachineSSAContext>; 153 template struct llvm::GenericUniformityAnalysisImplDeleter< 154 llvm::GenericUniformityAnalysisImpl<MachineSSAContext>>; 155 156 MachineUniformityInfo llvm::computeMachineUniformityInfo( 157 MachineFunction &F, const MachineCycleInfo &cycleInfo, 158 const MachineDomTree &domTree, bool HasBranchDivergence) { 159 assert(F.getRegInfo().isSSA() && "Expected to be run on SSA form!"); 160 MachineUniformityInfo UI(F, domTree, cycleInfo); 161 if (HasBranchDivergence) 162 UI.compute(); 163 return UI; 164 } 165 166 namespace { 167 168 /// Legacy analysis pass which computes a \ref MachineUniformityInfo. 169 class MachineUniformityAnalysisPass : public MachineFunctionPass { 170 MachineUniformityInfo UI; 171 172 public: 173 static char ID; 174 175 MachineUniformityAnalysisPass(); 176 177 MachineUniformityInfo &getUniformityInfo() { return UI; } 178 const MachineUniformityInfo &getUniformityInfo() const { return UI; } 179 180 bool runOnMachineFunction(MachineFunction &F) override; 181 void getAnalysisUsage(AnalysisUsage &AU) const override; 182 void print(raw_ostream &OS, const Module *M = nullptr) const override; 183 184 // TODO: verify analysis 185 }; 186 187 class MachineUniformityInfoPrinterPass : public MachineFunctionPass { 188 public: 189 static char ID; 190 191 MachineUniformityInfoPrinterPass(); 192 193 bool runOnMachineFunction(MachineFunction &F) override; 194 void getAnalysisUsage(AnalysisUsage &AU) const override; 195 }; 196 197 } // namespace 198 199 char MachineUniformityAnalysisPass::ID = 0; 200 201 MachineUniformityAnalysisPass::MachineUniformityAnalysisPass() 202 : MachineFunctionPass(ID) { 203 initializeMachineUniformityAnalysisPassPass(*PassRegistry::getPassRegistry()); 204 } 205 206 INITIALIZE_PASS_BEGIN(MachineUniformityAnalysisPass, "machine-uniformity", 207 "Machine Uniformity Info Analysis", true, true) 208 INITIALIZE_PASS_DEPENDENCY(MachineCycleInfoWrapperPass) 209 INITIALIZE_PASS_DEPENDENCY(MachineDominatorTree) 210 INITIALIZE_PASS_END(MachineUniformityAnalysisPass, "machine-uniformity", 211 "Machine Uniformity Info Analysis", true, true) 212 213 void MachineUniformityAnalysisPass::getAnalysisUsage(AnalysisUsage &AU) const { 214 AU.setPreservesAll(); 215 AU.addRequired<MachineCycleInfoWrapperPass>(); 216 AU.addRequired<MachineDominatorTree>(); 217 MachineFunctionPass::getAnalysisUsage(AU); 218 } 219 220 bool MachineUniformityAnalysisPass::runOnMachineFunction(MachineFunction &MF) { 221 auto &DomTree = getAnalysis<MachineDominatorTree>().getBase(); 222 auto &CI = getAnalysis<MachineCycleInfoWrapperPass>().getCycleInfo(); 223 // FIXME: Query TTI::hasBranchDivergence. -run-pass seems to end up with a 224 // default NoTTI 225 UI = computeMachineUniformityInfo(MF, CI, DomTree, true); 226 return false; 227 } 228 229 void MachineUniformityAnalysisPass::print(raw_ostream &OS, 230 const Module *) const { 231 OS << "MachineUniformityInfo for function: " << UI.getFunction().getName() 232 << "\n"; 233 UI.print(OS); 234 } 235 236 char MachineUniformityInfoPrinterPass::ID = 0; 237 238 MachineUniformityInfoPrinterPass::MachineUniformityInfoPrinterPass() 239 : MachineFunctionPass(ID) { 240 initializeMachineUniformityInfoPrinterPassPass( 241 *PassRegistry::getPassRegistry()); 242 } 243 244 INITIALIZE_PASS_BEGIN(MachineUniformityInfoPrinterPass, 245 "print-machine-uniformity", 246 "Print Machine Uniformity Info Analysis", true, true) 247 INITIALIZE_PASS_DEPENDENCY(MachineUniformityAnalysisPass) 248 INITIALIZE_PASS_END(MachineUniformityInfoPrinterPass, 249 "print-machine-uniformity", 250 "Print Machine Uniformity Info Analysis", true, true) 251 252 void MachineUniformityInfoPrinterPass::getAnalysisUsage( 253 AnalysisUsage &AU) const { 254 AU.setPreservesAll(); 255 AU.addRequired<MachineUniformityAnalysisPass>(); 256 MachineFunctionPass::getAnalysisUsage(AU); 257 } 258 259 bool MachineUniformityInfoPrinterPass::runOnMachineFunction( 260 MachineFunction &F) { 261 auto &UI = getAnalysis<MachineUniformityAnalysisPass>(); 262 UI.print(errs()); 263 return false; 264 } 265