xref: /freebsd/contrib/llvm-project/llvm/lib/Target/RISCV/RISCVExpandAtomicPseudoInsts.cpp (revision 069ac18495ad8fde2748bc94b0f80a50250bb01d)
1 //===-- RISCVExpandAtomicPseudoInsts.cpp - Expand atomic pseudo instrs. ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains a pass that expands atomic pseudo instructions into
10 // target instructions. This pass should be run at the last possible moment,
11 // avoiding the possibility for other passes to break the requirements for
12 // forward progress in the LR/SC block.
13 //
14 //===----------------------------------------------------------------------===//
15 
16 #include "RISCV.h"
17 #include "RISCVInstrInfo.h"
18 #include "RISCVTargetMachine.h"
19 
20 #include "llvm/CodeGen/LivePhysRegs.h"
21 #include "llvm/CodeGen/MachineFunctionPass.h"
22 #include "llvm/CodeGen/MachineInstrBuilder.h"
23 
24 using namespace llvm;
25 
// Human-readable name of this pass. It is returned by getPassName() and is
// also handed to INITIALIZE_PASS at the bottom of the file.
#define RISCV_EXPAND_ATOMIC_PSEUDO_NAME                                        \
  "RISC-V atomic pseudo instruction expansion pass"
28 
29 namespace {
30 
/// Machine-function pass that rewrites the RISC-V atomic pseudo instructions
/// recognised by expandMI() into explicit LR/SC instruction sequences placed
/// in newly created basic blocks.
class RISCVExpandAtomicPseudo : public MachineFunctionPass {
public:
  // Instruction info of the function currently being processed; assigned at
  // the start of runOnMachineFunction().
  const RISCVInstrInfo *TII;
  // Pass identifier passed to the MachineFunctionPass constructor.
  static char ID;

  RISCVExpandAtomicPseudo() : MachineFunctionPass(ID) {
    initializeRISCVExpandAtomicPseudoPass(*PassRegistry::getPassRegistry());
  }

  // Expands the pseudos in every block of MF. Returns true if any
  // instruction was expanded.
  bool runOnMachineFunction(MachineFunction &MF) override;

  StringRef getPassName() const override {
    return RISCV_EXPAND_ATOMIC_PSEUDO_NAME;
  }

private:
  // Visits each instruction of MBB and hands it to expandMI().
  bool expandMBB(MachineBasicBlock &MBB);
  // Dispatches on the opcode at MBBI to one of the expanders below. Returns
  // false when the instruction is not one of the handled pseudos.
  bool expandMI(MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
                MachineBasicBlock::iterator &NextMBBI);
  // Expands a (masked or unmasked) atomic read-modify-write pseudo into a
  // single-block LR/SC loop.
  bool expandAtomicBinOp(MachineBasicBlock &MBB,
                         MachineBasicBlock::iterator MBBI, AtomicRMWInst::BinOp,
                         bool IsMasked, int Width,
                         MachineBasicBlock::iterator &NextMBBI);
  // Expands a masked atomic min/max pseudo into a three-block LR/SC loop.
  bool expandAtomicMinMaxOp(MachineBasicBlock &MBB,
                            MachineBasicBlock::iterator MBBI,
                            AtomicRMWInst::BinOp, bool IsMasked, int Width,
                            MachineBasicBlock::iterator &NextMBBI);
  // Expands a (masked or unmasked) compare-and-exchange pseudo into a
  // two-block LR/SC loop.
  bool expandAtomicCmpXchg(MachineBasicBlock &MBB,
                           MachineBasicBlock::iterator MBBI, bool IsMasked,
                           int Width, MachineBasicBlock::iterator &NextMBBI);
#ifndef NDEBUG
  // Debug-only helper: sums TII->getInstSizeInBytes() over every instruction
  // in MF. runOnMachineFunction() compares the value before and after
  // expansion and asserts that the function did not grow.
  unsigned getInstSizeInBytes(const MachineFunction &MF) const {
    unsigned Size = 0;
    for (auto &MBB : MF)
      for (auto &MI : MBB)
        Size += TII->getInstSizeInBytes(MI);
    return Size;
  }
#endif
};
71 
// Out-of-line definition of the static pass identifier declared in the class;
// the constructor passes it to MachineFunctionPass.
char RISCVExpandAtomicPseudo::ID = 0;
73 
74 bool RISCVExpandAtomicPseudo::runOnMachineFunction(MachineFunction &MF) {
75   TII = MF.getSubtarget<RISCVSubtarget>().getInstrInfo();
76 
77 #ifndef NDEBUG
78   const unsigned OldSize = getInstSizeInBytes(MF);
79 #endif
80 
81   bool Modified = false;
82   for (auto &MBB : MF)
83     Modified |= expandMBB(MBB);
84 
85 #ifndef NDEBUG
86   const unsigned NewSize = getInstSizeInBytes(MF);
87   assert(OldSize >= NewSize);
88 #endif
89   return Modified;
90 }
91 
92 bool RISCVExpandAtomicPseudo::expandMBB(MachineBasicBlock &MBB) {
93   bool Modified = false;
94 
95   MachineBasicBlock::iterator MBBI = MBB.begin(), E = MBB.end();
96   while (MBBI != E) {
97     MachineBasicBlock::iterator NMBBI = std::next(MBBI);
98     Modified |= expandMI(MBB, MBBI, NMBBI);
99     MBBI = NMBBI;
100   }
101 
102   return Modified;
103 }
104 
105 bool RISCVExpandAtomicPseudo::expandMI(MachineBasicBlock &MBB,
106                                        MachineBasicBlock::iterator MBBI,
107                                        MachineBasicBlock::iterator &NextMBBI) {
108   // RISCVInstrInfo::getInstSizeInBytes expects that the total size of the
109   // expanded instructions for each pseudo is correct in the Size field of the
110   // tablegen definition for the pseudo.
111   switch (MBBI->getOpcode()) {
112   case RISCV::PseudoAtomicLoadNand32:
113     return expandAtomicBinOp(MBB, MBBI, AtomicRMWInst::Nand, false, 32,
114                              NextMBBI);
115   case RISCV::PseudoAtomicLoadNand64:
116     return expandAtomicBinOp(MBB, MBBI, AtomicRMWInst::Nand, false, 64,
117                              NextMBBI);
118   case RISCV::PseudoMaskedAtomicSwap32:
119     return expandAtomicBinOp(MBB, MBBI, AtomicRMWInst::Xchg, true, 32,
120                              NextMBBI);
121   case RISCV::PseudoMaskedAtomicLoadAdd32:
122     return expandAtomicBinOp(MBB, MBBI, AtomicRMWInst::Add, true, 32, NextMBBI);
123   case RISCV::PseudoMaskedAtomicLoadSub32:
124     return expandAtomicBinOp(MBB, MBBI, AtomicRMWInst::Sub, true, 32, NextMBBI);
125   case RISCV::PseudoMaskedAtomicLoadNand32:
126     return expandAtomicBinOp(MBB, MBBI, AtomicRMWInst::Nand, true, 32,
127                              NextMBBI);
128   case RISCV::PseudoMaskedAtomicLoadMax32:
129     return expandAtomicMinMaxOp(MBB, MBBI, AtomicRMWInst::Max, true, 32,
130                                 NextMBBI);
131   case RISCV::PseudoMaskedAtomicLoadMin32:
132     return expandAtomicMinMaxOp(MBB, MBBI, AtomicRMWInst::Min, true, 32,
133                                 NextMBBI);
134   case RISCV::PseudoMaskedAtomicLoadUMax32:
135     return expandAtomicMinMaxOp(MBB, MBBI, AtomicRMWInst::UMax, true, 32,
136                                 NextMBBI);
137   case RISCV::PseudoMaskedAtomicLoadUMin32:
138     return expandAtomicMinMaxOp(MBB, MBBI, AtomicRMWInst::UMin, true, 32,
139                                 NextMBBI);
140   case RISCV::PseudoCmpXchg32:
141     return expandAtomicCmpXchg(MBB, MBBI, false, 32, NextMBBI);
142   case RISCV::PseudoCmpXchg64:
143     return expandAtomicCmpXchg(MBB, MBBI, false, 64, NextMBBI);
144   case RISCV::PseudoMaskedCmpXchg32:
145     return expandAtomicCmpXchg(MBB, MBBI, true, 32, NextMBBI);
146   }
147 
148   return false;
149 }
150 
151 static unsigned getLRForRMW32(AtomicOrdering Ordering) {
152   switch (Ordering) {
153   default:
154     llvm_unreachable("Unexpected AtomicOrdering");
155   case AtomicOrdering::Monotonic:
156     return RISCV::LR_W;
157   case AtomicOrdering::Acquire:
158     return RISCV::LR_W_AQ;
159   case AtomicOrdering::Release:
160     return RISCV::LR_W;
161   case AtomicOrdering::AcquireRelease:
162     return RISCV::LR_W_AQ;
163   case AtomicOrdering::SequentiallyConsistent:
164     return RISCV::LR_W_AQ_RL;
165   }
166 }
167 
168 static unsigned getSCForRMW32(AtomicOrdering Ordering) {
169   switch (Ordering) {
170   default:
171     llvm_unreachable("Unexpected AtomicOrdering");
172   case AtomicOrdering::Monotonic:
173     return RISCV::SC_W;
174   case AtomicOrdering::Acquire:
175     return RISCV::SC_W;
176   case AtomicOrdering::Release:
177     return RISCV::SC_W_RL;
178   case AtomicOrdering::AcquireRelease:
179     return RISCV::SC_W_RL;
180   case AtomicOrdering::SequentiallyConsistent:
181     return RISCV::SC_W_RL;
182   }
183 }
184 
185 static unsigned getLRForRMW64(AtomicOrdering Ordering) {
186   switch (Ordering) {
187   default:
188     llvm_unreachable("Unexpected AtomicOrdering");
189   case AtomicOrdering::Monotonic:
190     return RISCV::LR_D;
191   case AtomicOrdering::Acquire:
192     return RISCV::LR_D_AQ;
193   case AtomicOrdering::Release:
194     return RISCV::LR_D;
195   case AtomicOrdering::AcquireRelease:
196     return RISCV::LR_D_AQ;
197   case AtomicOrdering::SequentiallyConsistent:
198     return RISCV::LR_D_AQ_RL;
199   }
200 }
201 
202 static unsigned getSCForRMW64(AtomicOrdering Ordering) {
203   switch (Ordering) {
204   default:
205     llvm_unreachable("Unexpected AtomicOrdering");
206   case AtomicOrdering::Monotonic:
207     return RISCV::SC_D;
208   case AtomicOrdering::Acquire:
209     return RISCV::SC_D;
210   case AtomicOrdering::Release:
211     return RISCV::SC_D_RL;
212   case AtomicOrdering::AcquireRelease:
213     return RISCV::SC_D_RL;
214   case AtomicOrdering::SequentiallyConsistent:
215     return RISCV::SC_D_RL;
216   }
217 }
218 
219 static unsigned getLRForRMW(AtomicOrdering Ordering, int Width) {
220   if (Width == 32)
221     return getLRForRMW32(Ordering);
222   if (Width == 64)
223     return getLRForRMW64(Ordering);
224   llvm_unreachable("Unexpected LR width\n");
225 }
226 
227 static unsigned getSCForRMW(AtomicOrdering Ordering, int Width) {
228   if (Width == 32)
229     return getSCForRMW32(Ordering);
230   if (Width == 64)
231     return getSCForRMW64(Ordering);
232   llvm_unreachable("Unexpected SC width\n");
233 }
234 
235 static void doAtomicBinOpExpansion(const RISCVInstrInfo *TII, MachineInstr &MI,
236                                    DebugLoc DL, MachineBasicBlock *ThisMBB,
237                                    MachineBasicBlock *LoopMBB,
238                                    MachineBasicBlock *DoneMBB,
239                                    AtomicRMWInst::BinOp BinOp, int Width) {
240   Register DestReg = MI.getOperand(0).getReg();
241   Register ScratchReg = MI.getOperand(1).getReg();
242   Register AddrReg = MI.getOperand(2).getReg();
243   Register IncrReg = MI.getOperand(3).getReg();
244   AtomicOrdering Ordering =
245       static_cast<AtomicOrdering>(MI.getOperand(4).getImm());
246 
247   // .loop:
248   //   lr.[w|d] dest, (addr)
249   //   binop scratch, dest, val
250   //   sc.[w|d] scratch, scratch, (addr)
251   //   bnez scratch, loop
252   BuildMI(LoopMBB, DL, TII->get(getLRForRMW(Ordering, Width)), DestReg)
253       .addReg(AddrReg);
254   switch (BinOp) {
255   default:
256     llvm_unreachable("Unexpected AtomicRMW BinOp");
257   case AtomicRMWInst::Nand:
258     BuildMI(LoopMBB, DL, TII->get(RISCV::AND), ScratchReg)
259         .addReg(DestReg)
260         .addReg(IncrReg);
261     BuildMI(LoopMBB, DL, TII->get(RISCV::XORI), ScratchReg)
262         .addReg(ScratchReg)
263         .addImm(-1);
264     break;
265   }
266   BuildMI(LoopMBB, DL, TII->get(getSCForRMW(Ordering, Width)), ScratchReg)
267       .addReg(AddrReg)
268       .addReg(ScratchReg);
269   BuildMI(LoopMBB, DL, TII->get(RISCV::BNE))
270       .addReg(ScratchReg)
271       .addReg(RISCV::X0)
272       .addMBB(LoopMBB);
273 }
274 
275 static void insertMaskedMerge(const RISCVInstrInfo *TII, DebugLoc DL,
276                               MachineBasicBlock *MBB, Register DestReg,
277                               Register OldValReg, Register NewValReg,
278                               Register MaskReg, Register ScratchReg) {
279   assert(OldValReg != ScratchReg && "OldValReg and ScratchReg must be unique");
280   assert(OldValReg != MaskReg && "OldValReg and MaskReg must be unique");
281   assert(ScratchReg != MaskReg && "ScratchReg and MaskReg must be unique");
282 
283   // We select bits from newval and oldval using:
284   // https://graphics.stanford.edu/~seander/bithacks.html#MaskedMerge
285   // r = oldval ^ ((oldval ^ newval) & masktargetdata);
286   BuildMI(MBB, DL, TII->get(RISCV::XOR), ScratchReg)
287       .addReg(OldValReg)
288       .addReg(NewValReg);
289   BuildMI(MBB, DL, TII->get(RISCV::AND), ScratchReg)
290       .addReg(ScratchReg)
291       .addReg(MaskReg);
292   BuildMI(MBB, DL, TII->get(RISCV::XOR), DestReg)
293       .addReg(OldValReg)
294       .addReg(ScratchReg);
295 }
296 
297 static void doMaskedAtomicBinOpExpansion(
298     const RISCVInstrInfo *TII, MachineInstr &MI, DebugLoc DL,
299     MachineBasicBlock *ThisMBB, MachineBasicBlock *LoopMBB,
300     MachineBasicBlock *DoneMBB, AtomicRMWInst::BinOp BinOp, int Width) {
301   assert(Width == 32 && "Should never need to expand masked 64-bit operations");
302   Register DestReg = MI.getOperand(0).getReg();
303   Register ScratchReg = MI.getOperand(1).getReg();
304   Register AddrReg = MI.getOperand(2).getReg();
305   Register IncrReg = MI.getOperand(3).getReg();
306   Register MaskReg = MI.getOperand(4).getReg();
307   AtomicOrdering Ordering =
308       static_cast<AtomicOrdering>(MI.getOperand(5).getImm());
309 
310   // .loop:
311   //   lr.w destreg, (alignedaddr)
312   //   binop scratch, destreg, incr
313   //   xor scratch, destreg, scratch
314   //   and scratch, scratch, masktargetdata
315   //   xor scratch, destreg, scratch
316   //   sc.w scratch, scratch, (alignedaddr)
317   //   bnez scratch, loop
318   BuildMI(LoopMBB, DL, TII->get(getLRForRMW32(Ordering)), DestReg)
319       .addReg(AddrReg);
320   switch (BinOp) {
321   default:
322     llvm_unreachable("Unexpected AtomicRMW BinOp");
323   case AtomicRMWInst::Xchg:
324     BuildMI(LoopMBB, DL, TII->get(RISCV::ADDI), ScratchReg)
325         .addReg(IncrReg)
326         .addImm(0);
327     break;
328   case AtomicRMWInst::Add:
329     BuildMI(LoopMBB, DL, TII->get(RISCV::ADD), ScratchReg)
330         .addReg(DestReg)
331         .addReg(IncrReg);
332     break;
333   case AtomicRMWInst::Sub:
334     BuildMI(LoopMBB, DL, TII->get(RISCV::SUB), ScratchReg)
335         .addReg(DestReg)
336         .addReg(IncrReg);
337     break;
338   case AtomicRMWInst::Nand:
339     BuildMI(LoopMBB, DL, TII->get(RISCV::AND), ScratchReg)
340         .addReg(DestReg)
341         .addReg(IncrReg);
342     BuildMI(LoopMBB, DL, TII->get(RISCV::XORI), ScratchReg)
343         .addReg(ScratchReg)
344         .addImm(-1);
345     break;
346   }
347 
348   insertMaskedMerge(TII, DL, LoopMBB, ScratchReg, DestReg, ScratchReg, MaskReg,
349                     ScratchReg);
350 
351   BuildMI(LoopMBB, DL, TII->get(getSCForRMW32(Ordering)), ScratchReg)
352       .addReg(AddrReg)
353       .addReg(ScratchReg);
354   BuildMI(LoopMBB, DL, TII->get(RISCV::BNE))
355       .addReg(ScratchReg)
356       .addReg(RISCV::X0)
357       .addMBB(LoopMBB);
358 }
359 
360 bool RISCVExpandAtomicPseudo::expandAtomicBinOp(
361     MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
362     AtomicRMWInst::BinOp BinOp, bool IsMasked, int Width,
363     MachineBasicBlock::iterator &NextMBBI) {
364   MachineInstr &MI = *MBBI;
365   DebugLoc DL = MI.getDebugLoc();
366 
367   MachineFunction *MF = MBB.getParent();
368   auto LoopMBB = MF->CreateMachineBasicBlock(MBB.getBasicBlock());
369   auto DoneMBB = MF->CreateMachineBasicBlock(MBB.getBasicBlock());
370 
371   // Insert new MBBs.
372   MF->insert(++MBB.getIterator(), LoopMBB);
373   MF->insert(++LoopMBB->getIterator(), DoneMBB);
374 
375   // Set up successors and transfer remaining instructions to DoneMBB.
376   LoopMBB->addSuccessor(LoopMBB);
377   LoopMBB->addSuccessor(DoneMBB);
378   DoneMBB->splice(DoneMBB->end(), &MBB, MI, MBB.end());
379   DoneMBB->transferSuccessors(&MBB);
380   MBB.addSuccessor(LoopMBB);
381 
382   if (!IsMasked)
383     doAtomicBinOpExpansion(TII, MI, DL, &MBB, LoopMBB, DoneMBB, BinOp, Width);
384   else
385     doMaskedAtomicBinOpExpansion(TII, MI, DL, &MBB, LoopMBB, DoneMBB, BinOp,
386                                  Width);
387 
388   NextMBBI = MBB.end();
389   MI.eraseFromParent();
390 
391   LivePhysRegs LiveRegs;
392   computeAndAddLiveIns(LiveRegs, *LoopMBB);
393   computeAndAddLiveIns(LiveRegs, *DoneMBB);
394 
395   return true;
396 }
397 
398 static void insertSext(const RISCVInstrInfo *TII, DebugLoc DL,
399                        MachineBasicBlock *MBB, Register ValReg,
400                        Register ShamtReg) {
401   BuildMI(MBB, DL, TII->get(RISCV::SLL), ValReg)
402       .addReg(ValReg)
403       .addReg(ShamtReg);
404   BuildMI(MBB, DL, TII->get(RISCV::SRA), ValReg)
405       .addReg(ValReg)
406       .addReg(ShamtReg);
407 }
408 
// Expands the masked atomic min/max pseudo at MBBI into an LR/SC loop made of
// three new blocks (loop head, conditional body, loop tail) followed by a
// "done" block. The done block receives the pseudo together with every
// instruction after it, plus MBB's former successors. Only masked 32-bit
// operations are supported. On return the pseudo has been erased, NextMBBI is
// MBB.end() and live-ins have been computed for the new blocks. Always
// returns true.
bool RISCVExpandAtomicPseudo::expandAtomicMinMaxOp(
    MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
    AtomicRMWInst::BinOp BinOp, bool IsMasked, int Width,
    MachineBasicBlock::iterator &NextMBBI) {
  assert(IsMasked == true &&
         "Should only need to expand masked atomic max/min");
  assert(Width == 32 && "Should never need to expand masked 64-bit operations");

  MachineInstr &MI = *MBBI;
  DebugLoc DL = MI.getDebugLoc();
  MachineFunction *MF = MBB.getParent();
  auto LoopHeadMBB = MF->CreateMachineBasicBlock(MBB.getBasicBlock());
  auto LoopIfBodyMBB = MF->CreateMachineBasicBlock(MBB.getBasicBlock());
  auto LoopTailMBB = MF->CreateMachineBasicBlock(MBB.getBasicBlock());
  auto DoneMBB = MF->CreateMachineBasicBlock(MBB.getBasicBlock());

  // Insert new MBBs.
  // Resulting layout: MBB, LoopHead, LoopIfBody, LoopTail, Done.
  MF->insert(++MBB.getIterator(), LoopHeadMBB);
  MF->insert(++LoopHeadMBB->getIterator(), LoopIfBodyMBB);
  MF->insert(++LoopIfBodyMBB->getIterator(), LoopTailMBB);
  MF->insert(++LoopTailMBB->getIterator(), DoneMBB);

  // Set up successors and transfer remaining instructions to DoneMBB.
  LoopHeadMBB->addSuccessor(LoopIfBodyMBB);
  LoopHeadMBB->addSuccessor(LoopTailMBB);
  LoopIfBodyMBB->addSuccessor(LoopTailMBB);
  LoopTailMBB->addSuccessor(LoopHeadMBB);
  LoopTailMBB->addSuccessor(DoneMBB);
  DoneMBB->splice(DoneMBB->end(), &MBB, MI, MBB.end());
  DoneMBB->transferSuccessors(&MBB);
  MBB.addSuccessor(LoopHeadMBB);

  Register DestReg = MI.getOperand(0).getReg();
  Register Scratch1Reg = MI.getOperand(1).getReg();
  Register Scratch2Reg = MI.getOperand(2).getReg();
  Register AddrReg = MI.getOperand(3).getReg();
  Register IncrReg = MI.getOperand(4).getReg();
  Register MaskReg = MI.getOperand(5).getReg();
  // The signed variants carry an extra register operand (index 6, used below
  // as the shift amount for sign extension), which moves the ordering
  // immediate from index 6 to index 7.
  bool IsSigned = BinOp == AtomicRMWInst::Min || BinOp == AtomicRMWInst::Max;
  AtomicOrdering Ordering =
      static_cast<AtomicOrdering>(MI.getOperand(IsSigned ? 7 : 6).getImm());

  //
  // .loophead:
  //   lr.w destreg, (alignedaddr)
  //   and scratch2, destreg, mask
  //   mv scratch1, destreg
  //   [sext scratch2 if signed min/max]
  //   ifnochangeneeded scratch2, incr, .looptail
  BuildMI(LoopHeadMBB, DL, TII->get(getLRForRMW32(Ordering)), DestReg)
      .addReg(AddrReg);
  BuildMI(LoopHeadMBB, DL, TII->get(RISCV::AND), Scratch2Reg)
      .addReg(DestReg)
      .addReg(MaskReg);
  // scratch1 starts as a copy of the loaded word, so that the store in the
  // loop tail writes the word back unchanged when the if-body is skipped.
  BuildMI(LoopHeadMBB, DL, TII->get(RISCV::ADDI), Scratch1Reg)
      .addReg(DestReg)
      .addImm(0);

  // Branch straight to the loop tail when the current value already
  // satisfies the min/max relation against incr.
  switch (BinOp) {
  default:
    llvm_unreachable("Unexpected AtomicRMW BinOp");
  case AtomicRMWInst::Max: {
    insertSext(TII, DL, LoopHeadMBB, Scratch2Reg, MI.getOperand(6).getReg());
    BuildMI(LoopHeadMBB, DL, TII->get(RISCV::BGE))
        .addReg(Scratch2Reg)
        .addReg(IncrReg)
        .addMBB(LoopTailMBB);
    break;
  }
  case AtomicRMWInst::Min: {
    insertSext(TII, DL, LoopHeadMBB, Scratch2Reg, MI.getOperand(6).getReg());
    BuildMI(LoopHeadMBB, DL, TII->get(RISCV::BGE))
        .addReg(IncrReg)
        .addReg(Scratch2Reg)
        .addMBB(LoopTailMBB);
    break;
  }
  case AtomicRMWInst::UMax:
    BuildMI(LoopHeadMBB, DL, TII->get(RISCV::BGEU))
        .addReg(Scratch2Reg)
        .addReg(IncrReg)
        .addMBB(LoopTailMBB);
    break;
  case AtomicRMWInst::UMin:
    BuildMI(LoopHeadMBB, DL, TII->get(RISCV::BGEU))
        .addReg(IncrReg)
        .addReg(Scratch2Reg)
        .addMBB(LoopTailMBB);
    break;
  }

  // .loopifbody:
  //   xor scratch1, destreg, incr
  //   and scratch1, scratch1, mask
  //   xor scratch1, destreg, scratch1
  insertMaskedMerge(TII, DL, LoopIfBodyMBB, Scratch1Reg, DestReg, IncrReg,
                    MaskReg, Scratch1Reg);

  // .looptail:
  //   sc.w scratch1, scratch1, (addr)
  //   bnez scratch1, loop
  BuildMI(LoopTailMBB, DL, TII->get(getSCForRMW32(Ordering)), Scratch1Reg)
      .addReg(AddrReg)
      .addReg(Scratch1Reg);
  BuildMI(LoopTailMBB, DL, TII->get(RISCV::BNE))
      .addReg(Scratch1Reg)
      .addReg(RISCV::X0)
      .addMBB(LoopHeadMBB);

  // Nothing is left in MBB after the pseudo's position.
  NextMBBI = MBB.end();
  MI.eraseFromParent();

  LivePhysRegs LiveRegs;
  computeAndAddLiveIns(LiveRegs, *LoopHeadMBB);
  computeAndAddLiveIns(LiveRegs, *LoopIfBodyMBB);
  computeAndAddLiveIns(LiveRegs, *LoopTailMBB);
  computeAndAddLiveIns(LiveRegs, *DoneMBB);

  return true;
}
529 
// If a BNE on the cmpxchg comparison result immediately follows the cmpxchg
// operation, it can be folded into the cmpxchg expansion by
// modifying the branch within 'LoopHead' (which performs the same
// comparison). This is a valid transformation because after altering the
// LoopHead's BNE destination, the BNE following the cmpxchg becomes
// redundant and can be deleted. In the case of a masked cmpxchg, an
// appropriate AND and BNE must be matched.
//
// MBBI is the position right after the cmpxchg pseudo; debug instructions
// are skipped while matching. MaskReg is an invalid Register for the unmasked
// case.
//
// On success, returns true and deletes the matching BNE or AND+BNE, sets the
// LoopHeadBNETarget argument to the target that should be used within the
// loop head, and removes that block as a successor to MBB.
//
// NOTE(review): LoopHeadBNETarget is assigned before the final "nothing
// follows the BNE" check, so it can also be changed on a path that returns
// false (BNE matched but followed by another instruction). In that case
// nothing is erased and MBB's successor list is left alone. Confirm whether
// callers are meant to rely on this.
bool tryToFoldBNEOnCmpXchgResult(MachineBasicBlock &MBB,
                                 MachineBasicBlock::iterator MBBI,
                                 Register DestReg, Register CmpValReg,
                                 Register MaskReg,
                                 MachineBasicBlock *&LoopHeadBNETarget) {
  // Instructions to delete once the whole pattern has matched.
  SmallVector<MachineInstr *> ToErase;
  auto E = MBB.end();
  if (MBBI == E)
    return false;
  MBBI = skipDebugInstructionsForward(MBBI, E);

  // If we have a masked cmpxchg, match AND dst, DestReg, MaskReg.
  if (MaskReg.isValid()) {
    if (MBBI == E || MBBI->getOpcode() != RISCV::AND)
      return false;
    // Accept the two source operands in either order.
    Register ANDOp1 = MBBI->getOperand(1).getReg();
    Register ANDOp2 = MBBI->getOperand(2).getReg();
    if (!(ANDOp1 == DestReg && ANDOp2 == MaskReg) &&
        !(ANDOp1 == MaskReg && ANDOp2 == DestReg))
      return false;
    // We now expect the BNE to use the result of the AND as an operand.
    DestReg = MBBI->getOperand(0).getReg();
    ToErase.push_back(&*MBBI);
    MBBI = skipDebugInstructionsForward(std::next(MBBI), E);
  }

  // Match BNE DestReg, CmpValReg (operands in either order).
  if (MBBI == E || MBBI->getOpcode() != RISCV::BNE)
    return false;
  Register BNEOp0 = MBBI->getOperand(0).getReg();
  Register BNEOp1 = MBBI->getOperand(1).getReg();
  if (!(BNEOp0 == DestReg && BNEOp1 == CmpValReg) &&
      !(BNEOp0 == CmpValReg && BNEOp1 == DestReg))
    return false;

  // Make sure the branch is the only user of the AND.
  // This is checked through the kill flag on the BNE operand that reads the
  // AND result.
  if (MaskReg.isValid()) {
    if (BNEOp0 == DestReg && !MBBI->getOperand(0).isKill())
      return false;
    if (BNEOp1 == DestReg && !MBBI->getOperand(1).isKill())
      return false;
  }

  ToErase.push_back(&*MBBI);
  LoopHeadBNETarget = MBBI->getOperand(2).getMBB();
  // Only fold when the BNE is the last (non-debug) instruction of the block.
  MBBI = skipDebugInstructionsForward(std::next(MBBI), E);
  if (MBBI != E)
    return false;

  MBB.removeSuccessor(LoopHeadBNETarget);
  for (auto *MI : ToErase)
    MI->eraseFromParent();
  return true;
}
595 
596 bool RISCVExpandAtomicPseudo::expandAtomicCmpXchg(
597     MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI, bool IsMasked,
598     int Width, MachineBasicBlock::iterator &NextMBBI) {
599   MachineInstr &MI = *MBBI;
600   DebugLoc DL = MI.getDebugLoc();
601   MachineFunction *MF = MBB.getParent();
602   auto LoopHeadMBB = MF->CreateMachineBasicBlock(MBB.getBasicBlock());
603   auto LoopTailMBB = MF->CreateMachineBasicBlock(MBB.getBasicBlock());
604   auto DoneMBB = MF->CreateMachineBasicBlock(MBB.getBasicBlock());
605 
606   Register DestReg = MI.getOperand(0).getReg();
607   Register ScratchReg = MI.getOperand(1).getReg();
608   Register AddrReg = MI.getOperand(2).getReg();
609   Register CmpValReg = MI.getOperand(3).getReg();
610   Register NewValReg = MI.getOperand(4).getReg();
611   Register MaskReg = IsMasked ? MI.getOperand(5).getReg() : Register();
612 
613   MachineBasicBlock *LoopHeadBNETarget = DoneMBB;
614   tryToFoldBNEOnCmpXchgResult(MBB, std::next(MBBI), DestReg, CmpValReg, MaskReg,
615                               LoopHeadBNETarget);
616 
617   // Insert new MBBs.
618   MF->insert(++MBB.getIterator(), LoopHeadMBB);
619   MF->insert(++LoopHeadMBB->getIterator(), LoopTailMBB);
620   MF->insert(++LoopTailMBB->getIterator(), DoneMBB);
621 
622   // Set up successors and transfer remaining instructions to DoneMBB.
623   LoopHeadMBB->addSuccessor(LoopTailMBB);
624   LoopHeadMBB->addSuccessor(LoopHeadBNETarget);
625   LoopTailMBB->addSuccessor(DoneMBB);
626   LoopTailMBB->addSuccessor(LoopHeadMBB);
627   DoneMBB->splice(DoneMBB->end(), &MBB, MI, MBB.end());
628   DoneMBB->transferSuccessors(&MBB);
629   MBB.addSuccessor(LoopHeadMBB);
630 
631   AtomicOrdering Ordering =
632       static_cast<AtomicOrdering>(MI.getOperand(IsMasked ? 6 : 5).getImm());
633 
634   if (!IsMasked) {
635     // .loophead:
636     //   lr.[w|d] dest, (addr)
637     //   bne dest, cmpval, done
638     BuildMI(LoopHeadMBB, DL, TII->get(getLRForRMW(Ordering, Width)), DestReg)
639         .addReg(AddrReg);
640     BuildMI(LoopHeadMBB, DL, TII->get(RISCV::BNE))
641         .addReg(DestReg)
642         .addReg(CmpValReg)
643         .addMBB(LoopHeadBNETarget);
644     // .looptail:
645     //   sc.[w|d] scratch, newval, (addr)
646     //   bnez scratch, loophead
647     BuildMI(LoopTailMBB, DL, TII->get(getSCForRMW(Ordering, Width)), ScratchReg)
648         .addReg(AddrReg)
649         .addReg(NewValReg);
650     BuildMI(LoopTailMBB, DL, TII->get(RISCV::BNE))
651         .addReg(ScratchReg)
652         .addReg(RISCV::X0)
653         .addMBB(LoopHeadMBB);
654   } else {
655     // .loophead:
656     //   lr.w dest, (addr)
657     //   and scratch, dest, mask
658     //   bne scratch, cmpval, done
659     Register MaskReg = MI.getOperand(5).getReg();
660     BuildMI(LoopHeadMBB, DL, TII->get(getLRForRMW(Ordering, Width)), DestReg)
661         .addReg(AddrReg);
662     BuildMI(LoopHeadMBB, DL, TII->get(RISCV::AND), ScratchReg)
663         .addReg(DestReg)
664         .addReg(MaskReg);
665     BuildMI(LoopHeadMBB, DL, TII->get(RISCV::BNE))
666         .addReg(ScratchReg)
667         .addReg(CmpValReg)
668         .addMBB(LoopHeadBNETarget);
669 
670     // .looptail:
671     //   xor scratch, dest, newval
672     //   and scratch, scratch, mask
673     //   xor scratch, dest, scratch
674     //   sc.w scratch, scratch, (adrr)
675     //   bnez scratch, loophead
676     insertMaskedMerge(TII, DL, LoopTailMBB, ScratchReg, DestReg, NewValReg,
677                       MaskReg, ScratchReg);
678     BuildMI(LoopTailMBB, DL, TII->get(getSCForRMW(Ordering, Width)), ScratchReg)
679         .addReg(AddrReg)
680         .addReg(ScratchReg);
681     BuildMI(LoopTailMBB, DL, TII->get(RISCV::BNE))
682         .addReg(ScratchReg)
683         .addReg(RISCV::X0)
684         .addMBB(LoopHeadMBB);
685   }
686 
687   NextMBBI = MBB.end();
688   MI.eraseFromParent();
689 
690   LivePhysRegs LiveRegs;
691   computeAndAddLiveIns(LiveRegs, *LoopHeadMBB);
692   computeAndAddLiveIns(LiveRegs, *LoopTailMBB);
693   computeAndAddLiveIns(LiveRegs, *DoneMBB);
694 
695   return true;
696 }
697 
698 } // end of anonymous namespace
699 
// Pass registration: ties the RISCVExpandAtomicPseudo class to the argument
// string "riscv-expand-atomic-pseudo" and to the descriptive name defined at
// the top of the file. The two trailing flags are INITIALIZE_PASS's
// "CFG only" and "is analysis" parameters.
INITIALIZE_PASS(RISCVExpandAtomicPseudo, "riscv-expand-atomic-pseudo",
                RISCV_EXPAND_ATOMIC_PSEUDO_NAME, false, false)
702 
namespace llvm {

// Factory for the pass: returns a newly heap-allocated
// RISCVExpandAtomicPseudo instance as a FunctionPass pointer.
FunctionPass *createRISCVExpandAtomicPseudoPass() {
  return new RISCVExpandAtomicPseudo();
}

} // end of namespace llvm
710