xref: /freebsd/contrib/llvm-project/llvm/lib/CodeGen/MachineUniformityAnalysis.cpp (revision be092bcde96bdcfde9013d60e442cca023bfbd1b)
1  //===- MachineUniformityAnalysis.cpp --------------------------------------===//
2  //
3  // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4  // See https://llvm.org/LICENSE.txt for license information.
5  // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6  //
7  //===----------------------------------------------------------------------===//
8  
9  #include "llvm/CodeGen/MachineUniformityAnalysis.h"
10  #include "llvm/ADT/GenericUniformityImpl.h"
11  #include "llvm/CodeGen/MachineCycleAnalysis.h"
12  #include "llvm/CodeGen/MachineDominators.h"
13  #include "llvm/CodeGen/MachineRegisterInfo.h"
14  #include "llvm/CodeGen/MachineSSAContext.h"
15  #include "llvm/CodeGen/TargetInstrInfo.h"
16  #include "llvm/InitializePasses.h"
17  
18  using namespace llvm;
19  
20  template <>
21  bool llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::hasDivergentDefs(
22      const MachineInstr &I) const {
23    for (auto &op : I.operands()) {
24      if (!op.isReg() || !op.isDef())
25        continue;
26      if (isDivergent(op.getReg()))
27        return true;
28    }
29    return false;
30  }
31  
32  template <>
33  bool llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::markDefsDivergent(
34      const MachineInstr &Instr, bool AllDefsDivergent) {
35    bool insertedDivergent = false;
36    const auto &MRI = F.getRegInfo();
37    const auto &TRI = *MRI.getTargetRegisterInfo();
38    for (auto &op : Instr.operands()) {
39      if (!op.isReg() || !op.isDef())
40        continue;
41      if (!op.getReg().isVirtual())
42        continue;
43      assert(!op.getSubReg());
44      if (!AllDefsDivergent) {
45        auto *RC = MRI.getRegClassOrNull(op.getReg());
46        if (RC && !TRI.isDivergentRegClass(RC))
47          continue;
48      }
49      insertedDivergent |= markDivergent(op.getReg());
50    }
51    return insertedDivergent;
52  }
53  
54  template <>
55  void llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::initialize() {
56    const auto &InstrInfo = *F.getSubtarget().getInstrInfo();
57  
58    for (const MachineBasicBlock &block : F) {
59      for (const MachineInstr &instr : block) {
60        auto uniformity = InstrInfo.getInstructionUniformity(instr);
61        if (uniformity == InstructionUniformity::AlwaysUniform) {
62          addUniformOverride(instr);
63          continue;
64        }
65  
66        if (uniformity == InstructionUniformity::NeverUniform) {
67          markDefsDivergent(instr, /* AllDefsDivergent = */ false);
68        }
69      }
70    }
71  }
72  
73  template <>
74  void llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::pushUsers(
75      Register Reg) {
76    const auto &RegInfo = F.getRegInfo();
77    for (MachineInstr &UserInstr : RegInfo.use_instructions(Reg)) {
78      if (isAlwaysUniform(UserInstr))
79        continue;
80      if (markDivergent(UserInstr))
81        Worklist.push_back(&UserInstr);
82    }
83  }
84  
85  template <>
86  void llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::pushUsers(
87      const MachineInstr &Instr) {
88    assert(!isAlwaysUniform(Instr));
89    if (Instr.isTerminator())
90      return;
91    for (const MachineOperand &op : Instr.operands()) {
92      if (op.isReg() && op.isDef() && op.getReg().isVirtual())
93        pushUsers(op.getReg());
94    }
95  }
96  
97  template <>
98  bool llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::usesValueFromCycle(
99      const MachineInstr &I, const MachineCycle &DefCycle) const {
100    assert(!isAlwaysUniform(I));
101    for (auto &Op : I.operands()) {
102      if (!Op.isReg() || !Op.readsReg())
103        continue;
104      auto Reg = Op.getReg();
105      assert(Reg.isVirtual());
106      auto *Def = F.getRegInfo().getVRegDef(Reg);
107      if (DefCycle.contains(Def->getParent()))
108        return true;
109    }
110    return false;
111  }
112  
// Explicitly instantiate the MachineSSAContext specializations used by the
// machine uniformity analysis in this translation unit, which includes
// llvm/ADT/GenericUniformityImpl.h. Instantiating
// GenericUniformityAnalysisImplDeleter ensures its operator() is emitted here.
template class llvm::GenericUniformityInfo<MachineSSAContext>;
template struct llvm::GenericUniformityAnalysisImplDeleter<
    llvm::GenericUniformityAnalysisImpl<MachineSSAContext>>;
118  
119  MachineUniformityInfo
120  llvm::computeMachineUniformityInfo(MachineFunction &F,
121                                     const MachineCycleInfo &cycleInfo,
122                                     const MachineDomTree &domTree) {
123    assert(F.getRegInfo().isSSA() && "Expected to be run on SSA form!");
124    return MachineUniformityInfo(F, domTree, cycleInfo);
125  }
126  
127  namespace {
128  
129  /// Legacy analysis pass which computes a \ref MachineUniformityInfo.
130  class MachineUniformityAnalysisPass : public MachineFunctionPass {
131    MachineUniformityInfo UI;
132  
133  public:
134    static char ID;
135  
136    MachineUniformityAnalysisPass();
137  
138    MachineUniformityInfo &getUniformityInfo() { return UI; }
139    const MachineUniformityInfo &getUniformityInfo() const { return UI; }
140  
141    bool runOnMachineFunction(MachineFunction &F) override;
142    void getAnalysisUsage(AnalysisUsage &AU) const override;
143    void print(raw_ostream &OS, const Module *M = nullptr) const override;
144  
145    // TODO: verify analysis
146  };
147  
148  class MachineUniformityInfoPrinterPass : public MachineFunctionPass {
149  public:
150    static char ID;
151  
152    MachineUniformityInfoPrinterPass();
153  
154    bool runOnMachineFunction(MachineFunction &F) override;
155    void getAnalysisUsage(AnalysisUsage &AU) const override;
156  };
157  
158  } // namespace
159  
160  char MachineUniformityAnalysisPass::ID = 0;
161  
162  MachineUniformityAnalysisPass::MachineUniformityAnalysisPass()
163      : MachineFunctionPass(ID) {
164    initializeMachineUniformityAnalysisPassPass(*PassRegistry::getPassRegistry());
165  }
166  
// Legacy pass-manager registration of MachineUniformityAnalysisPass under the
// command-line name "machine-uniformity". It declares dependencies on
// MachineCycleInfoWrapperPass and MachineDominatorTree, the analyses requested
// in getAnalysisUsage(). The trailing `true, true` arguments are presumably
// the CFG-only and is-analysis flags (per the macro parameters in
// llvm/PassSupport.h; confirm there).
INITIALIZE_PASS_BEGIN(MachineUniformityAnalysisPass, "machine-uniformity",
                      "Machine Uniformity Info Analysis", true, true)
INITIALIZE_PASS_DEPENDENCY(MachineCycleInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(MachineDominatorTree)
INITIALIZE_PASS_END(MachineUniformityAnalysisPass, "machine-uniformity",
                    "Machine Uniformity Info Analysis", true, true)
173  
174  void MachineUniformityAnalysisPass::getAnalysisUsage(AnalysisUsage &AU) const {
175    AU.setPreservesAll();
176    AU.addRequired<MachineCycleInfoWrapperPass>();
177    AU.addRequired<MachineDominatorTree>();
178    MachineFunctionPass::getAnalysisUsage(AU);
179  }
180  
181  bool MachineUniformityAnalysisPass::runOnMachineFunction(MachineFunction &MF) {
182    auto &DomTree = getAnalysis<MachineDominatorTree>().getBase();
183    auto &CI = getAnalysis<MachineCycleInfoWrapperPass>().getCycleInfo();
184    UI = computeMachineUniformityInfo(MF, CI, DomTree);
185    return false;
186  }
187  
188  void MachineUniformityAnalysisPass::print(raw_ostream &OS,
189                                            const Module *) const {
190    OS << "MachineUniformityInfo for function: " << UI.getFunction().getName()
191       << "\n";
192    UI.print(OS);
193  }
194  
195  char MachineUniformityInfoPrinterPass::ID = 0;
196  
197  MachineUniformityInfoPrinterPass::MachineUniformityInfoPrinterPass()
198      : MachineFunctionPass(ID) {
199    initializeMachineUniformityInfoPrinterPassPass(
200        *PassRegistry::getPassRegistry());
201  }
202  
// Legacy pass-manager registration of the printer pass under the command-line
// name "print-machine-uniformity". Its only dependency is
// MachineUniformityAnalysisPass, which getAnalysisUsage() requires. The
// trailing `true, true` arguments are presumably the CFG-only and
// is-analysis flags (see llvm/PassSupport.h to confirm).
INITIALIZE_PASS_BEGIN(MachineUniformityInfoPrinterPass,
                      "print-machine-uniformity",
                      "Print Machine Uniformity Info Analysis", true, true)
INITIALIZE_PASS_DEPENDENCY(MachineUniformityAnalysisPass)
INITIALIZE_PASS_END(MachineUniformityInfoPrinterPass,
                    "print-machine-uniformity",
                    "Print Machine Uniformity Info Analysis", true, true)
210  
211  void MachineUniformityInfoPrinterPass::getAnalysisUsage(
212      AnalysisUsage &AU) const {
213    AU.setPreservesAll();
214    AU.addRequired<MachineUniformityAnalysisPass>();
215    MachineFunctionPass::getAnalysisUsage(AU);
216  }
217  
218  bool MachineUniformityInfoPrinterPass::runOnMachineFunction(
219      MachineFunction &F) {
220    auto &UI = getAnalysis<MachineUniformityAnalysisPass>();
221    UI.print(errs());
222    return false;
223  }
224