1 //===- MachineUniformityAnalysis.cpp --------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "llvm/CodeGen/MachineUniformityAnalysis.h" 10 #include "llvm/ADT/GenericUniformityImpl.h" 11 #include "llvm/Analysis/TargetTransformInfo.h" 12 #include "llvm/CodeGen/MachineCycleAnalysis.h" 13 #include "llvm/CodeGen/MachineDominators.h" 14 #include "llvm/CodeGen/MachineRegisterInfo.h" 15 #include "llvm/CodeGen/MachineSSAContext.h" 16 #include "llvm/CodeGen/TargetInstrInfo.h" 17 #include "llvm/InitializePasses.h" 18 19 using namespace llvm; 20 21 template <> 22 bool llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::hasDivergentDefs( 23 const MachineInstr &I) const { 24 for (auto &op : I.all_defs()) { 25 if (isDivergent(op.getReg())) 26 return true; 27 } 28 return false; 29 } 30 31 template <> 32 bool llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::markDefsDivergent( 33 const MachineInstr &Instr) { 34 bool insertedDivergent = false; 35 const auto &MRI = F.getRegInfo(); 36 const auto &RBI = *F.getSubtarget().getRegBankInfo(); 37 const auto &TRI = *MRI.getTargetRegisterInfo(); 38 for (auto &op : Instr.all_defs()) { 39 if (!op.getReg().isVirtual()) 40 continue; 41 assert(!op.getSubReg()); 42 if (TRI.isUniformReg(MRI, RBI, op.getReg())) 43 continue; 44 insertedDivergent |= markDivergent(op.getReg()); 45 } 46 return insertedDivergent; 47 } 48 49 template <> 50 void llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::initialize() { 51 const auto &InstrInfo = *F.getSubtarget().getInstrInfo(); 52 53 for (const MachineBasicBlock &block : F) { 54 for (const MachineInstr &instr : block) { 55 auto uniformity = InstrInfo.getInstructionUniformity(instr); 56 if (uniformity == InstructionUniformity::AlwaysUniform) { 57 addUniformOverride(instr); 58 continue; 59 } 60 61 if (uniformity == InstructionUniformity::NeverUniform) { 62 markDivergent(instr); 63 } 64 } 65 } 66 } 67 68 template <> 69 void llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::pushUsers( 70 Register Reg) { 71 assert(isDivergent(Reg)); 72 const auto &RegInfo = F.getRegInfo(); 73 for (MachineInstr &UserInstr : RegInfo.use_instructions(Reg)) { 74 markDivergent(UserInstr); 75 } 76 } 77 78 template <> 79 void llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::pushUsers( 80 const MachineInstr &Instr) { 81 assert(!isAlwaysUniform(Instr)); 82 if (Instr.isTerminator()) 83 return; 84 for (const MachineOperand &op : Instr.all_defs()) { 85 auto Reg = op.getReg(); 86 if (isDivergent(Reg)) 87 pushUsers(Reg); 88 } 89 } 90 91 template <> 92 bool llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::usesValueFromCycle( 93 const MachineInstr &I, const MachineCycle &DefCycle) const { 94 assert(!isAlwaysUniform(I)); 95 for (auto &Op : I.operands()) { 96 if (!Op.isReg() || !Op.readsReg()) 97 continue; 98 auto Reg = Op.getReg(); 99 100 // FIXME: Physical registers need to be properly checked instead of always 101 // returning true 102 if (Reg.isPhysical()) 103 return true; 104 105 auto *Def = F.getRegInfo().getVRegDef(Reg); 106 if (DefCycle.contains(Def->getParent())) 107 return true; 108 } 109 return false; 110 } 111 112 template <> 113 void llvm::GenericUniformityAnalysisImpl<MachineSSAContext>:: 114 propagateTemporalDivergence(const MachineInstr &I, 115 const MachineCycle &DefCycle) { 116 const auto &RegInfo = F.getRegInfo(); 117 for (auto &Op : I.all_defs()) { 118 if (!Op.getReg().isVirtual()) 119 continue; 120 auto Reg = Op.getReg(); 121 for (MachineInstr &UserInstr : RegInfo.use_instructions(Reg)) { 122 if (DefCycle.contains(UserInstr.getParent())) 123 continue; 124 markDivergent(UserInstr); 125 126 recordTemporalDivergence(Reg, &UserInstr, &DefCycle); 127 } 128 } 129 } 130 131 template <> 132 bool llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::isDivergentUse( 133 const MachineOperand &U) const { 134 if (!U.isReg()) 135 return false; 136 137 auto Reg = U.getReg(); 138 if (isDivergent(Reg)) 139 return true; 140 141 const auto &RegInfo = F.getRegInfo(); 142 auto *Def = RegInfo.getOneDef(Reg); 143 if (!Def) 144 return true; 145 146 auto *DefInstr = Def->getParent(); 147 auto *UseInstr = U.getParent(); 148 return isTemporalDivergent(*UseInstr->getParent(), *DefInstr); 149 } 150 151 // This ensures explicit instantiation of 152 // GenericUniformityAnalysisImpl::ImplDeleter::operator() 153 template class llvm::GenericUniformityInfo<MachineSSAContext>; 154 template struct llvm::GenericUniformityAnalysisImplDeleter< 155 llvm::GenericUniformityAnalysisImpl<MachineSSAContext>>; 156 157 MachineUniformityInfo llvm::computeMachineUniformityInfo( 158 MachineFunction &F, const MachineCycleInfo &cycleInfo, 159 const MachineDominatorTree &domTree, bool HasBranchDivergence) { 160 assert(F.getRegInfo().isSSA() && "Expected to be run on SSA form!"); 161 MachineUniformityInfo UI(domTree, cycleInfo); 162 if (HasBranchDivergence) 163 UI.compute(); 164 return UI; 165 } 166 167 namespace { 168 169 class MachineUniformityInfoPrinterPass : public MachineFunctionPass { 170 public: 171 static char ID; 172 173 MachineUniformityInfoPrinterPass(); 174 175 bool runOnMachineFunction(MachineFunction &F) override; 176 void getAnalysisUsage(AnalysisUsage &AU) const override; 177 }; 178 179 } // namespace 180 181 AnalysisKey MachineUniformityAnalysis::Key; 182 183 MachineUniformityAnalysis::Result 184 MachineUniformityAnalysis::run(MachineFunction &MF, 185 MachineFunctionAnalysisManager &MFAM) { 186 auto &DomTree = MFAM.getResult<MachineDominatorTreeAnalysis>(MF); 187 auto &CI = MFAM.getResult<MachineCycleAnalysis>(MF); 188 auto &FAM = MFAM.getResult<FunctionAnalysisManagerMachineFunctionProxy>(MF) 189 .getManager(); 190 auto &F = MF.getFunction(); 191 auto &TTI = FAM.getResult<TargetIRAnalysis>(F); 192 return computeMachineUniformityInfo(MF, CI, DomTree, 193 TTI.hasBranchDivergence(&F)); 194 } 195 196 PreservedAnalyses 197 MachineUniformityPrinterPass::run(MachineFunction &MF, 198 MachineFunctionAnalysisManager &MFAM) { 199 auto &MUI = MFAM.getResult<MachineUniformityAnalysis>(MF); 200 OS << "MachineUniformityInfo for function: "; 201 MF.getFunction().printAsOperand(OS, /*PrintType=*/false); 202 OS << '\n'; 203 MUI.print(OS); 204 return PreservedAnalyses::all(); 205 } 206 207 char MachineUniformityAnalysisPass::ID = 0; 208 209 MachineUniformityAnalysisPass::MachineUniformityAnalysisPass() 210 : MachineFunctionPass(ID) { 211 initializeMachineUniformityAnalysisPassPass(*PassRegistry::getPassRegistry()); 212 } 213 214 INITIALIZE_PASS_BEGIN(MachineUniformityAnalysisPass, "machine-uniformity", 215 "Machine Uniformity Info Analysis", false, true) 216 INITIALIZE_PASS_DEPENDENCY(MachineCycleInfoWrapperPass) 217 INITIALIZE_PASS_DEPENDENCY(MachineDominatorTreeWrapperPass) 218 INITIALIZE_PASS_END(MachineUniformityAnalysisPass, "machine-uniformity", 219 "Machine Uniformity Info Analysis", false, true) 220 221 void MachineUniformityAnalysisPass::getAnalysisUsage(AnalysisUsage &AU) const { 222 AU.setPreservesAll(); 223 AU.addRequiredTransitive<MachineCycleInfoWrapperPass>(); 224 AU.addRequired<MachineDominatorTreeWrapperPass>(); 225 MachineFunctionPass::getAnalysisUsage(AU); 226 } 227 228 bool MachineUniformityAnalysisPass::runOnMachineFunction(MachineFunction &MF) { 229 auto &DomTree = getAnalysis<MachineDominatorTreeWrapperPass>().getDomTree(); 230 auto &CI = getAnalysis<MachineCycleInfoWrapperPass>().getCycleInfo(); 231 // FIXME: Query TTI::hasBranchDivergence. -run-pass seems to end up with a 232 // default NoTTI 233 UI = computeMachineUniformityInfo(MF, CI, DomTree, true); 234 return false; 235 } 236 237 void MachineUniformityAnalysisPass::print(raw_ostream &OS, 238 const Module *) const { 239 OS << "MachineUniformityInfo for function: "; 240 UI.getFunction().getFunction().printAsOperand(OS, /*PrintType=*/false); 241 OS << '\n'; 242 UI.print(OS); 243 } 244 245 char MachineUniformityInfoPrinterPass::ID = 0; 246 247 MachineUniformityInfoPrinterPass::MachineUniformityInfoPrinterPass() 248 : MachineFunctionPass(ID) { 249 initializeMachineUniformityInfoPrinterPassPass( 250 *PassRegistry::getPassRegistry()); 251 } 252 253 INITIALIZE_PASS_BEGIN(MachineUniformityInfoPrinterPass, 254 "print-machine-uniformity", 255 "Print Machine Uniformity Info Analysis", true, true) 256 INITIALIZE_PASS_DEPENDENCY(MachineUniformityAnalysisPass) 257 INITIALIZE_PASS_END(MachineUniformityInfoPrinterPass, 258 "print-machine-uniformity", 259 "Print Machine Uniformity Info Analysis", true, true) 260 261 void MachineUniformityInfoPrinterPass::getAnalysisUsage( 262 AnalysisUsage &AU) const { 263 AU.setPreservesAll(); 264 AU.addRequired<MachineUniformityAnalysisPass>(); 265 MachineFunctionPass::getAnalysisUsage(AU); 266 } 267 268 bool MachineUniformityInfoPrinterPass::runOnMachineFunction( 269 MachineFunction &F) { 270 auto &UI = getAnalysis<MachineUniformityAnalysisPass>(); 271 UI.print(errs()); 272 return false; 273 } 274