xref: /freebsd/contrib/llvm-project/llvm/lib/Target/RISCV/RISCVFoldMemOffset.cpp (revision 700637cbb5e582861067a11aaca4d053546871d2)
1 //===- RISCVFoldMemOffset.cpp - Fold ADDI into memory offsets ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===---------------------------------------------------------------------===//
8 //
9 // Look for ADDIs that can be removed by folding their immediate into later
10 // load/store addresses. There may be other arithmetic instructions between the
11 // addi and load/store that we need to reassociate through. If the final result
12 // of the arithmetic is only used by load/store addresses, we can fold the
13 // offset into all the loads/stores as long as it doesn't create an offset
14 // that is too large.
15 //
16 //===---------------------------------------------------------------------===//
17 
18 #include "RISCV.h"
19 #include "RISCVSubtarget.h"
20 #include "llvm/CodeGen/MachineFunctionPass.h"
21 #include <queue>
22 
23 using namespace llvm;
24 
25 #define DEBUG_TYPE "riscv-fold-mem-offset"
26 #define RISCV_FOLD_MEM_OFFSET_NAME "RISC-V Fold Memory Offset"
27 
28 namespace {
29 
30 class RISCVFoldMemOffset : public MachineFunctionPass {
31 public:
32   static char ID;
33 
RISCVFoldMemOffset()34   RISCVFoldMemOffset() : MachineFunctionPass(ID) {}
35 
36   bool runOnMachineFunction(MachineFunction &MF) override;
37 
38   bool foldOffset(Register OrigReg, int64_t InitialOffset,
39                   const MachineRegisterInfo &MRI,
40                   DenseMap<MachineInstr *, int64_t> &FoldableInstrs);
41 
getAnalysisUsage(AnalysisUsage & AU) const42   void getAnalysisUsage(AnalysisUsage &AU) const override {
43     AU.setPreservesCFG();
44     MachineFunctionPass::getAnalysisUsage(AU);
45   }
46 
getPassName() const47   StringRef getPassName() const override { return RISCV_FOLD_MEM_OFFSET_NAME; }
48 };
49 
// Thin wrapper over std::optional<int64_t> that can be assigned to and
// accumulated into. An empty wrapper is treated as zero by operator+=.
class FoldableOffset {
  std::optional<int64_t> Value;

public:
  // True once a value has been assigned or accumulated.
  bool hasValue() const { return Value.has_value(); }

  // Stored value; the wrapper must not be empty.
  int64_t getValue() const { return *Value; }

  // Replace whatever is stored with RHS.
  FoldableOffset &operator=(int64_t RHS) {
    Value.emplace(RHS);
    return *this;
  }

  // Add RHS to the stored value, starting from zero when empty. The sum is
  // formed in unsigned arithmetic so that it wraps instead of overflowing.
  FoldableOffset &operator+=(int64_t RHS) {
    const uint64_t Current = static_cast<uint64_t>(Value.value_or(0));
    Value = static_cast<int64_t>(Current + static_cast<uint64_t>(RHS));
    return *this;
  }

  // Same as getValue(); the wrapper must not be empty.
  int64_t operator*() { return *Value; }
};
72 
73 } // end anonymous namespace
74 
75 char RISCVFoldMemOffset::ID = 0;
INITIALIZE_PASS(RISCVFoldMemOffset,DEBUG_TYPE,RISCV_FOLD_MEM_OFFSET_NAME,false,false)76 INITIALIZE_PASS(RISCVFoldMemOffset, DEBUG_TYPE, RISCV_FOLD_MEM_OFFSET_NAME,
77                 false, false)
78 
79 FunctionPass *llvm::createRISCVFoldMemOffsetPass() {
80   return new RISCVFoldMemOffset();
81 }
82 
83 // Walk forward from the ADDI looking for arithmetic instructions we can
84 // analyze or memory instructions that use it as part of their address
85 // calculation. For each arithmetic instruction we lookup how the offset
86 // contributes to the value in that register use that information to
87 // calculate the contribution to the output of this instruction.
88 // Only addition and left shift are supported.
89 // FIXME: Add multiplication by constant. The constant will be in a register.
foldOffset(Register OrigReg,int64_t InitialOffset,const MachineRegisterInfo & MRI,DenseMap<MachineInstr *,int64_t> & FoldableInstrs)90 bool RISCVFoldMemOffset::foldOffset(
91     Register OrigReg, int64_t InitialOffset, const MachineRegisterInfo &MRI,
92     DenseMap<MachineInstr *, int64_t> &FoldableInstrs) {
93   // Map to hold how much the offset contributes to the value of this register.
94   DenseMap<Register, int64_t> RegToOffsetMap;
95 
96   // Insert root offset into the map.
97   RegToOffsetMap[OrigReg] = InitialOffset;
98 
99   std::queue<Register> Worklist;
100   Worklist.push(OrigReg);
101 
102   while (!Worklist.empty()) {
103     Register Reg = Worklist.front();
104     Worklist.pop();
105 
106     if (!Reg.isVirtual())
107       return false;
108 
109     for (auto &User : MRI.use_nodbg_instructions(Reg)) {
110       FoldableOffset Offset;
111 
112       switch (User.getOpcode()) {
113       default:
114         return false;
115       case RISCV::ADD:
116         if (auto I = RegToOffsetMap.find(User.getOperand(1).getReg());
117             I != RegToOffsetMap.end())
118           Offset = I->second;
119         if (auto I = RegToOffsetMap.find(User.getOperand(2).getReg());
120             I != RegToOffsetMap.end())
121           Offset += I->second;
122         break;
123       case RISCV::SH1ADD:
124         if (auto I = RegToOffsetMap.find(User.getOperand(1).getReg());
125             I != RegToOffsetMap.end())
126           Offset = (uint64_t)I->second << 1;
127         if (auto I = RegToOffsetMap.find(User.getOperand(2).getReg());
128             I != RegToOffsetMap.end())
129           Offset += I->second;
130         break;
131       case RISCV::SH2ADD:
132         if (auto I = RegToOffsetMap.find(User.getOperand(1).getReg());
133             I != RegToOffsetMap.end())
134           Offset = (uint64_t)I->second << 2;
135         if (auto I = RegToOffsetMap.find(User.getOperand(2).getReg());
136             I != RegToOffsetMap.end())
137           Offset += I->second;
138         break;
139       case RISCV::SH3ADD:
140         if (auto I = RegToOffsetMap.find(User.getOperand(1).getReg());
141             I != RegToOffsetMap.end())
142           Offset = (uint64_t)I->second << 3;
143         if (auto I = RegToOffsetMap.find(User.getOperand(2).getReg());
144             I != RegToOffsetMap.end())
145           Offset += I->second;
146         break;
147       case RISCV::ADD_UW:
148       case RISCV::SH1ADD_UW:
149       case RISCV::SH2ADD_UW:
150       case RISCV::SH3ADD_UW:
151         // Don't fold through the zero extended input.
152         if (User.getOperand(1).getReg() == Reg)
153           return false;
154         if (auto I = RegToOffsetMap.find(User.getOperand(2).getReg());
155             I != RegToOffsetMap.end())
156           Offset = I->second;
157         break;
158       case RISCV::SLLI: {
159         unsigned ShAmt = User.getOperand(2).getImm();
160         if (auto I = RegToOffsetMap.find(User.getOperand(1).getReg());
161             I != RegToOffsetMap.end())
162           Offset = (uint64_t)I->second << ShAmt;
163         break;
164       }
165       case RISCV::LB:
166       case RISCV::LBU:
167       case RISCV::SB:
168       case RISCV::LH:
169       case RISCV::LH_INX:
170       case RISCV::LHU:
171       case RISCV::FLH:
172       case RISCV::SH:
173       case RISCV::SH_INX:
174       case RISCV::FSH:
175       case RISCV::LW:
176       case RISCV::LW_INX:
177       case RISCV::LWU:
178       case RISCV::FLW:
179       case RISCV::SW:
180       case RISCV::SW_INX:
181       case RISCV::FSW:
182       case RISCV::LD:
183       case RISCV::FLD:
184       case RISCV::SD:
185       case RISCV::FSD: {
186         // Can't fold into store value.
187         if (User.getOperand(0).getReg() == Reg)
188           return false;
189 
190         // Existing offset must be immediate.
191         if (!User.getOperand(2).isImm())
192           return false;
193 
194         // Require at least one operation between the ADDI and the load/store.
195         // We have other optimizations that should handle the simple case.
196         if (User.getOperand(1).getReg() == OrigReg)
197           return false;
198 
199         auto I = RegToOffsetMap.find(User.getOperand(1).getReg());
200         if (I == RegToOffsetMap.end())
201           return false;
202 
203         int64_t LocalOffset = User.getOperand(2).getImm();
204         assert(isInt<12>(LocalOffset));
205         int64_t CombinedOffset = (uint64_t)LocalOffset + (uint64_t)I->second;
206         if (!isInt<12>(CombinedOffset))
207           return false;
208 
209         FoldableInstrs[&User] = CombinedOffset;
210         continue;
211       }
212       }
213 
214       // If we reach here we should have an accumulated offset.
215       assert(Offset.hasValue() && "Expected an offset");
216 
217       // If the offset is new or changed, add the destination register to the
218       // work list.
219       int64_t OffsetVal = Offset.getValue();
220       auto P =
221           RegToOffsetMap.try_emplace(User.getOperand(0).getReg(), OffsetVal);
222       if (P.second) {
223         Worklist.push(User.getOperand(0).getReg());
224       } else if (P.first->second != OffsetVal) {
225         P.first->second = OffsetVal;
226         Worklist.push(User.getOperand(0).getReg());
227       }
228     }
229   }
230 
231   return true;
232 }
233 
runOnMachineFunction(MachineFunction & MF)234 bool RISCVFoldMemOffset::runOnMachineFunction(MachineFunction &MF) {
235   if (skipFunction(MF.getFunction()))
236     return false;
237 
238   // This optimization may increase size by preventing compression.
239   if (MF.getFunction().hasOptSize())
240     return false;
241 
242   MachineRegisterInfo &MRI = MF.getRegInfo();
243 
244   bool MadeChange = false;
245   for (MachineBasicBlock &MBB : MF) {
246     for (MachineInstr &MI : llvm::make_early_inc_range(MBB)) {
247       // FIXME: We can support ADDIW from an LUI+ADDIW pair if the result is
248       // equivalent to LUI+ADDI.
249       if (MI.getOpcode() != RISCV::ADDI)
250         continue;
251 
252       // We only want to optimize register ADDIs.
253       if (!MI.getOperand(1).isReg() || !MI.getOperand(2).isImm())
254         continue;
255 
256       // Ignore 'li'.
257       if (MI.getOperand(1).getReg() == RISCV::X0)
258         continue;
259 
260       int64_t Offset = MI.getOperand(2).getImm();
261       assert(isInt<12>(Offset));
262 
263       DenseMap<MachineInstr *, int64_t> FoldableInstrs;
264 
265       if (!foldOffset(MI.getOperand(0).getReg(), Offset, MRI, FoldableInstrs))
266         continue;
267 
268       if (FoldableInstrs.empty())
269         continue;
270 
271       // We can fold this ADDI.
272       // Rewrite all the instructions.
273       for (auto [MemMI, NewOffset] : FoldableInstrs)
274         MemMI->getOperand(2).setImm(NewOffset);
275 
276       MRI.replaceRegWith(MI.getOperand(0).getReg(), MI.getOperand(1).getReg());
277       MRI.clearKillFlags(MI.getOperand(1).getReg());
278       MI.eraseFromParent();
279     }
280   }
281 
282   return MadeChange;
283 }
284