1 //===- MachineUniformityAnalysis.cpp --------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "llvm/CodeGen/MachineUniformityAnalysis.h" 10 #include "llvm/ADT/GenericUniformityImpl.h" 11 #include "llvm/CodeGen/MachineCycleAnalysis.h" 12 #include "llvm/CodeGen/MachineDominators.h" 13 #include "llvm/CodeGen/MachineRegisterInfo.h" 14 #include "llvm/CodeGen/MachineSSAContext.h" 15 #include "llvm/CodeGen/TargetInstrInfo.h" 16 #include "llvm/InitializePasses.h" 17 18 using namespace llvm; 19 20 template <> 21 bool llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::hasDivergentDefs( 22 const MachineInstr &I) const { 23 for (auto &op : I.all_defs()) { 24 if (isDivergent(op.getReg())) 25 return true; 26 } 27 return false; 28 } 29 30 template <> 31 bool llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::markDefsDivergent( 32 const MachineInstr &Instr) { 33 bool insertedDivergent = false; 34 const auto &MRI = F.getRegInfo(); 35 const auto &RBI = *F.getSubtarget().getRegBankInfo(); 36 const auto &TRI = *MRI.getTargetRegisterInfo(); 37 for (auto &op : Instr.all_defs()) { 38 if (!op.getReg().isVirtual()) 39 continue; 40 assert(!op.getSubReg()); 41 if (TRI.isUniformReg(MRI, RBI, op.getReg())) 42 continue; 43 insertedDivergent |= markDivergent(op.getReg()); 44 } 45 return insertedDivergent; 46 } 47 48 template <> 49 void llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::initialize() { 50 const auto &InstrInfo = *F.getSubtarget().getInstrInfo(); 51 52 for (const MachineBasicBlock &block : F) { 53 for (const MachineInstr &instr : block) { 54 auto uniformity = InstrInfo.getInstructionUniformity(instr); 55 if (uniformity == InstructionUniformity::AlwaysUniform) { 56 addUniformOverride(instr); 57 continue; 58 } 59 60 if (uniformity == InstructionUniformity::NeverUniform) { 61 markDivergent(instr); 62 } 63 } 64 } 65 } 66 67 template <> 68 void llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::pushUsers( 69 Register Reg) { 70 assert(isDivergent(Reg)); 71 const auto &RegInfo = F.getRegInfo(); 72 for (MachineInstr &UserInstr : RegInfo.use_instructions(Reg)) { 73 markDivergent(UserInstr); 74 } 75 } 76 77 template <> 78 void llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::pushUsers( 79 const MachineInstr &Instr) { 80 assert(!isAlwaysUniform(Instr)); 81 if (Instr.isTerminator()) 82 return; 83 for (const MachineOperand &op : Instr.all_defs()) { 84 auto Reg = op.getReg(); 85 if (isDivergent(Reg)) 86 pushUsers(Reg); 87 } 88 } 89 90 template <> 91 bool llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::usesValueFromCycle( 92 const MachineInstr &I, const MachineCycle &DefCycle) const { 93 assert(!isAlwaysUniform(I)); 94 for (auto &Op : I.operands()) { 95 if (!Op.isReg() || !Op.readsReg()) 96 continue; 97 auto Reg = Op.getReg(); 98 99 // FIXME: Physical registers need to be properly checked instead of always 100 // returning true 101 if (Reg.isPhysical()) 102 return true; 103 104 auto *Def = F.getRegInfo().getVRegDef(Reg); 105 if (DefCycle.contains(Def->getParent())) 106 return true; 107 } 108 return false; 109 } 110 111 template <> 112 void llvm::GenericUniformityAnalysisImpl<MachineSSAContext>:: 113 propagateTemporalDivergence(const MachineInstr &I, 114 const MachineCycle &DefCycle) { 115 const auto &RegInfo = F.getRegInfo(); 116 for (auto &Op : I.all_defs()) { 117 if (!Op.getReg().isVirtual()) 118 continue; 119 auto Reg = Op.getReg(); 120 if (isDivergent(Reg)) 121 continue; 122 for (MachineInstr &UserInstr : RegInfo.use_instructions(Reg)) { 123 if (DefCycle.contains(UserInstr.getParent())) 124 continue; 125 markDivergent(UserInstr); 126 } 127 } 128 } 129 130 template <> 131 bool llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::isDivergentUse( 132 const MachineOperand &U) const { 133 if (!U.isReg()) 134 return false; 135 136 auto Reg = U.getReg(); 137 if (isDivergent(Reg)) 138 return true; 139 140 const auto &RegInfo = F.getRegInfo(); 141 auto *Def = RegInfo.getOneDef(Reg); 142 if (!Def) 143 return true; 144 145 auto *DefInstr = Def->getParent(); 146 auto *UseInstr = U.getParent(); 147 return isTemporalDivergent(*UseInstr->getParent(), *DefInstr); 148 } 149 150 // This ensures explicit instantiation of 151 // GenericUniformityAnalysisImpl::ImplDeleter::operator() 152 template class llvm::GenericUniformityInfo<MachineSSAContext>; 153 template struct llvm::GenericUniformityAnalysisImplDeleter< 154 llvm::GenericUniformityAnalysisImpl<MachineSSAContext>>; 155 156 MachineUniformityInfo llvm::computeMachineUniformityInfo( 157 MachineFunction &F, const MachineCycleInfo &cycleInfo, 158 const MachineDominatorTree &domTree, bool HasBranchDivergence) { 159 assert(F.getRegInfo().isSSA() && "Expected to be run on SSA form!"); 160 MachineUniformityInfo UI(domTree, cycleInfo); 161 if (HasBranchDivergence) 162 UI.compute(); 163 return UI; 164 } 165 166 namespace { 167 168 class MachineUniformityInfoPrinterPass : public MachineFunctionPass { 169 public: 170 static char ID; 171 172 MachineUniformityInfoPrinterPass(); 173 174 bool runOnMachineFunction(MachineFunction &F) override; 175 void getAnalysisUsage(AnalysisUsage &AU) const override; 176 }; 177 178 } // namespace 179 180 char MachineUniformityAnalysisPass::ID = 0; 181 182 MachineUniformityAnalysisPass::MachineUniformityAnalysisPass() 183 : MachineFunctionPass(ID) { 184 initializeMachineUniformityAnalysisPassPass(*PassRegistry::getPassRegistry()); 185 } 186 187 INITIALIZE_PASS_BEGIN(MachineUniformityAnalysisPass, "machine-uniformity", 188 "Machine Uniformity Info Analysis", true, true) 189 INITIALIZE_PASS_DEPENDENCY(MachineCycleInfoWrapperPass) 190 INITIALIZE_PASS_DEPENDENCY(MachineDominatorTreeWrapperPass) 191 INITIALIZE_PASS_END(MachineUniformityAnalysisPass, "machine-uniformity", 192 "Machine Uniformity Info Analysis", true, true) 193 194 void MachineUniformityAnalysisPass::getAnalysisUsage(AnalysisUsage &AU) const { 195 AU.setPreservesAll(); 196 AU.addRequired<MachineCycleInfoWrapperPass>(); 197 AU.addRequired<MachineDominatorTreeWrapperPass>(); 198 MachineFunctionPass::getAnalysisUsage(AU); 199 } 200 201 bool MachineUniformityAnalysisPass::runOnMachineFunction(MachineFunction &MF) { 202 auto &DomTree = 203 getAnalysis<MachineDominatorTreeWrapperPass>().getDomTree().getBase(); 204 auto &CI = getAnalysis<MachineCycleInfoWrapperPass>().getCycleInfo(); 205 // FIXME: Query TTI::hasBranchDivergence. -run-pass seems to end up with a 206 // default NoTTI 207 UI = computeMachineUniformityInfo(MF, CI, DomTree, true); 208 return false; 209 } 210 211 void MachineUniformityAnalysisPass::print(raw_ostream &OS, 212 const Module *) const { 213 OS << "MachineUniformityInfo for function: " << UI.getFunction().getName() 214 << "\n"; 215 UI.print(OS); 216 } 217 218 char MachineUniformityInfoPrinterPass::ID = 0; 219 220 MachineUniformityInfoPrinterPass::MachineUniformityInfoPrinterPass() 221 : MachineFunctionPass(ID) { 222 initializeMachineUniformityInfoPrinterPassPass( 223 *PassRegistry::getPassRegistry()); 224 } 225 226 INITIALIZE_PASS_BEGIN(MachineUniformityInfoPrinterPass, 227 "print-machine-uniformity", 228 "Print Machine Uniformity Info Analysis", true, true) 229 INITIALIZE_PASS_DEPENDENCY(MachineUniformityAnalysisPass) 230 INITIALIZE_PASS_END(MachineUniformityInfoPrinterPass, 231 "print-machine-uniformity", 232 "Print Machine Uniformity Info Analysis", true, true) 233 234 void MachineUniformityInfoPrinterPass::getAnalysisUsage( 235 AnalysisUsage &AU) const { 236 AU.setPreservesAll(); 237 AU.addRequired<MachineUniformityAnalysisPass>(); 238 MachineFunctionPass::getAnalysisUsage(AU); 239 } 240 241 bool MachineUniformityInfoPrinterPass::runOnMachineFunction( 242 MachineFunction &F) { 243 auto &UI = getAnalysis<MachineUniformityAnalysisPass>(); 244 UI.print(errs()); 245 return false; 246 } 247