xref: /freebsd/contrib/llvm-project/llvm/lib/Target/RISCV/RISCVExpandAtomicPseudoInsts.cpp (revision 95eb4b873b6a8b527c5bd78d7191975dfca38998)
1 //===-- RISCVExpandAtomicPseudoInsts.cpp - Expand atomic pseudo instrs. ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains a pass that expands atomic pseudo instructions into
10 // target instructions. This pass should be run at the last possible moment,
11 // avoiding the possibility for other passes to break the requirements for
12 // forward progress in the LR/SC block.
13 //
14 //===----------------------------------------------------------------------===//
15 
16 #include "RISCV.h"
17 #include "RISCVInstrInfo.h"
18 #include "RISCVTargetMachine.h"
19 
20 #include "llvm/CodeGen/LivePhysRegs.h"
21 #include "llvm/CodeGen/MachineFunctionPass.h"
22 #include "llvm/CodeGen/MachineInstrBuilder.h"
23 
24 using namespace llvm;
25 
// Human-readable pass name, shared by getPassName() and INITIALIZE_PASS later
// in this file.
#define RISCV_EXPAND_ATOMIC_PSEUDO_NAME \
  "RISC-V atomic pseudo instruction expansion pass"
28 
29 namespace {
30 
31 class RISCVExpandAtomicPseudo : public MachineFunctionPass {
32 public:
33   const RISCVSubtarget *STI;
34   const RISCVInstrInfo *TII;
35   static char ID;
36 
37   RISCVExpandAtomicPseudo() : MachineFunctionPass(ID) {
38     initializeRISCVExpandAtomicPseudoPass(*PassRegistry::getPassRegistry());
39   }
40 
41   bool runOnMachineFunction(MachineFunction &MF) override;
42 
43   StringRef getPassName() const override {
44     return RISCV_EXPAND_ATOMIC_PSEUDO_NAME;
45   }
46 
47 private:
48   bool expandMBB(MachineBasicBlock &MBB);
49   bool expandMI(MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
50                 MachineBasicBlock::iterator &NextMBBI);
51   bool expandAtomicBinOp(MachineBasicBlock &MBB,
52                          MachineBasicBlock::iterator MBBI, AtomicRMWInst::BinOp,
53                          bool IsMasked, int Width,
54                          MachineBasicBlock::iterator &NextMBBI);
55   bool expandAtomicMinMaxOp(MachineBasicBlock &MBB,
56                             MachineBasicBlock::iterator MBBI,
57                             AtomicRMWInst::BinOp, bool IsMasked, int Width,
58                             MachineBasicBlock::iterator &NextMBBI);
59   bool expandAtomicCmpXchg(MachineBasicBlock &MBB,
60                            MachineBasicBlock::iterator MBBI, bool IsMasked,
61                            int Width, MachineBasicBlock::iterator &NextMBBI);
62 #ifndef NDEBUG
63   unsigned getInstSizeInBytes(const MachineFunction &MF) const {
64     unsigned Size = 0;
65     for (auto &MBB : MF)
66       for (auto &MI : MBB)
67         Size += TII->getInstSizeInBytes(MI);
68     return Size;
69   }
70 #endif
71 };
72 
73 char RISCVExpandAtomicPseudo::ID = 0;
74 
75 bool RISCVExpandAtomicPseudo::runOnMachineFunction(MachineFunction &MF) {
76   STI = &MF.getSubtarget<RISCVSubtarget>();
77   TII = STI->getInstrInfo();
78 
79 #ifndef NDEBUG
80   const unsigned OldSize = getInstSizeInBytes(MF);
81 #endif
82 
83   bool Modified = false;
84   for (auto &MBB : MF)
85     Modified |= expandMBB(MBB);
86 
87 #ifndef NDEBUG
88   const unsigned NewSize = getInstSizeInBytes(MF);
89   assert(OldSize >= NewSize);
90 #endif
91   return Modified;
92 }
93 
94 bool RISCVExpandAtomicPseudo::expandMBB(MachineBasicBlock &MBB) {
95   bool Modified = false;
96 
97   MachineBasicBlock::iterator MBBI = MBB.begin(), E = MBB.end();
98   while (MBBI != E) {
99     MachineBasicBlock::iterator NMBBI = std::next(MBBI);
100     Modified |= expandMI(MBB, MBBI, NMBBI);
101     MBBI = NMBBI;
102   }
103 
104   return Modified;
105 }
106 
107 bool RISCVExpandAtomicPseudo::expandMI(MachineBasicBlock &MBB,
108                                        MachineBasicBlock::iterator MBBI,
109                                        MachineBasicBlock::iterator &NextMBBI) {
110   // RISCVInstrInfo::getInstSizeInBytes expects that the total size of the
111   // expanded instructions for each pseudo is correct in the Size field of the
112   // tablegen definition for the pseudo.
113   switch (MBBI->getOpcode()) {
114   case RISCV::PseudoAtomicLoadNand32:
115     return expandAtomicBinOp(MBB, MBBI, AtomicRMWInst::Nand, false, 32,
116                              NextMBBI);
117   case RISCV::PseudoAtomicLoadNand64:
118     return expandAtomicBinOp(MBB, MBBI, AtomicRMWInst::Nand, false, 64,
119                              NextMBBI);
120   case RISCV::PseudoMaskedAtomicSwap32:
121     return expandAtomicBinOp(MBB, MBBI, AtomicRMWInst::Xchg, true, 32,
122                              NextMBBI);
123   case RISCV::PseudoMaskedAtomicLoadAdd32:
124     return expandAtomicBinOp(MBB, MBBI, AtomicRMWInst::Add, true, 32, NextMBBI);
125   case RISCV::PseudoMaskedAtomicLoadSub32:
126     return expandAtomicBinOp(MBB, MBBI, AtomicRMWInst::Sub, true, 32, NextMBBI);
127   case RISCV::PseudoMaskedAtomicLoadNand32:
128     return expandAtomicBinOp(MBB, MBBI, AtomicRMWInst::Nand, true, 32,
129                              NextMBBI);
130   case RISCV::PseudoMaskedAtomicLoadMax32:
131     return expandAtomicMinMaxOp(MBB, MBBI, AtomicRMWInst::Max, true, 32,
132                                 NextMBBI);
133   case RISCV::PseudoMaskedAtomicLoadMin32:
134     return expandAtomicMinMaxOp(MBB, MBBI, AtomicRMWInst::Min, true, 32,
135                                 NextMBBI);
136   case RISCV::PseudoMaskedAtomicLoadUMax32:
137     return expandAtomicMinMaxOp(MBB, MBBI, AtomicRMWInst::UMax, true, 32,
138                                 NextMBBI);
139   case RISCV::PseudoMaskedAtomicLoadUMin32:
140     return expandAtomicMinMaxOp(MBB, MBBI, AtomicRMWInst::UMin, true, 32,
141                                 NextMBBI);
142   case RISCV::PseudoCmpXchg32:
143     return expandAtomicCmpXchg(MBB, MBBI, false, 32, NextMBBI);
144   case RISCV::PseudoCmpXchg64:
145     return expandAtomicCmpXchg(MBB, MBBI, false, 64, NextMBBI);
146   case RISCV::PseudoMaskedCmpXchg32:
147     return expandAtomicCmpXchg(MBB, MBBI, true, 32, NextMBBI);
148   }
149 
150   return false;
151 }
152 
153 static unsigned getLRForRMW32(AtomicOrdering Ordering,
154                               const RISCVSubtarget *Subtarget) {
155   switch (Ordering) {
156   default:
157     llvm_unreachable("Unexpected AtomicOrdering");
158   case AtomicOrdering::Monotonic:
159     return RISCV::LR_W;
160   case AtomicOrdering::Acquire:
161     if (Subtarget->hasStdExtZtso())
162       return RISCV::LR_W;
163     return RISCV::LR_W_AQ;
164   case AtomicOrdering::Release:
165     return RISCV::LR_W;
166   case AtomicOrdering::AcquireRelease:
167     if (Subtarget->hasStdExtZtso())
168       return RISCV::LR_W;
169     return RISCV::LR_W_AQ;
170   case AtomicOrdering::SequentiallyConsistent:
171     return RISCV::LR_W_AQ_RL;
172   }
173 }
174 
175 static unsigned getSCForRMW32(AtomicOrdering Ordering,
176                               const RISCVSubtarget *Subtarget) {
177   switch (Ordering) {
178   default:
179     llvm_unreachable("Unexpected AtomicOrdering");
180   case AtomicOrdering::Monotonic:
181     return RISCV::SC_W;
182   case AtomicOrdering::Acquire:
183     return RISCV::SC_W;
184   case AtomicOrdering::Release:
185     if (Subtarget->hasStdExtZtso())
186       return RISCV::SC_W;
187     return RISCV::SC_W_RL;
188   case AtomicOrdering::AcquireRelease:
189     if (Subtarget->hasStdExtZtso())
190       return RISCV::SC_W;
191     return RISCV::SC_W_RL;
192   case AtomicOrdering::SequentiallyConsistent:
193     return RISCV::SC_W_RL;
194   }
195 }
196 
197 static unsigned getLRForRMW64(AtomicOrdering Ordering,
198                               const RISCVSubtarget *Subtarget) {
199   switch (Ordering) {
200   default:
201     llvm_unreachable("Unexpected AtomicOrdering");
202   case AtomicOrdering::Monotonic:
203     return RISCV::LR_D;
204   case AtomicOrdering::Acquire:
205     if (Subtarget->hasStdExtZtso())
206       return RISCV::LR_D;
207     return RISCV::LR_D_AQ;
208   case AtomicOrdering::Release:
209     return RISCV::LR_D;
210   case AtomicOrdering::AcquireRelease:
211     if (Subtarget->hasStdExtZtso())
212       return RISCV::LR_D;
213     return RISCV::LR_D_AQ;
214   case AtomicOrdering::SequentiallyConsistent:
215     return RISCV::LR_D_AQ_RL;
216   }
217 }
218 
219 static unsigned getSCForRMW64(AtomicOrdering Ordering,
220                               const RISCVSubtarget *Subtarget) {
221   switch (Ordering) {
222   default:
223     llvm_unreachable("Unexpected AtomicOrdering");
224   case AtomicOrdering::Monotonic:
225     return RISCV::SC_D;
226   case AtomicOrdering::Acquire:
227     return RISCV::SC_D;
228   case AtomicOrdering::Release:
229     if (Subtarget->hasStdExtZtso())
230       return RISCV::SC_D;
231     return RISCV::SC_D_RL;
232   case AtomicOrdering::AcquireRelease:
233     if (Subtarget->hasStdExtZtso())
234       return RISCV::SC_D;
235     return RISCV::SC_D_RL;
236   case AtomicOrdering::SequentiallyConsistent:
237     return RISCV::SC_D_RL;
238   }
239 }
240 
241 static unsigned getLRForRMW(AtomicOrdering Ordering, int Width,
242                             const RISCVSubtarget *Subtarget) {
243   if (Width == 32)
244     return getLRForRMW32(Ordering, Subtarget);
245   if (Width == 64)
246     return getLRForRMW64(Ordering, Subtarget);
247   llvm_unreachable("Unexpected LR width\n");
248 }
249 
250 static unsigned getSCForRMW(AtomicOrdering Ordering, int Width,
251                             const RISCVSubtarget *Subtarget) {
252   if (Width == 32)
253     return getSCForRMW32(Ordering, Subtarget);
254   if (Width == 64)
255     return getSCForRMW64(Ordering, Subtarget);
256   llvm_unreachable("Unexpected SC width\n");
257 }
258 
259 static void doAtomicBinOpExpansion(const RISCVInstrInfo *TII, MachineInstr &MI,
260                                    DebugLoc DL, MachineBasicBlock *ThisMBB,
261                                    MachineBasicBlock *LoopMBB,
262                                    MachineBasicBlock *DoneMBB,
263                                    AtomicRMWInst::BinOp BinOp, int Width,
264                                    const RISCVSubtarget *STI) {
265   Register DestReg = MI.getOperand(0).getReg();
266   Register ScratchReg = MI.getOperand(1).getReg();
267   Register AddrReg = MI.getOperand(2).getReg();
268   Register IncrReg = MI.getOperand(3).getReg();
269   AtomicOrdering Ordering =
270       static_cast<AtomicOrdering>(MI.getOperand(4).getImm());
271 
272   // .loop:
273   //   lr.[w|d] dest, (addr)
274   //   binop scratch, dest, val
275   //   sc.[w|d] scratch, scratch, (addr)
276   //   bnez scratch, loop
277   BuildMI(LoopMBB, DL, TII->get(getLRForRMW(Ordering, Width, STI)), DestReg)
278       .addReg(AddrReg);
279   switch (BinOp) {
280   default:
281     llvm_unreachable("Unexpected AtomicRMW BinOp");
282   case AtomicRMWInst::Nand:
283     BuildMI(LoopMBB, DL, TII->get(RISCV::AND), ScratchReg)
284         .addReg(DestReg)
285         .addReg(IncrReg);
286     BuildMI(LoopMBB, DL, TII->get(RISCV::XORI), ScratchReg)
287         .addReg(ScratchReg)
288         .addImm(-1);
289     break;
290   }
291   BuildMI(LoopMBB, DL, TII->get(getSCForRMW(Ordering, Width, STI)), ScratchReg)
292       .addReg(AddrReg)
293       .addReg(ScratchReg);
294   BuildMI(LoopMBB, DL, TII->get(RISCV::BNE))
295       .addReg(ScratchReg)
296       .addReg(RISCV::X0)
297       .addMBB(LoopMBB);
298 }
299 
300 static void insertMaskedMerge(const RISCVInstrInfo *TII, DebugLoc DL,
301                               MachineBasicBlock *MBB, Register DestReg,
302                               Register OldValReg, Register NewValReg,
303                               Register MaskReg, Register ScratchReg) {
304   assert(OldValReg != ScratchReg && "OldValReg and ScratchReg must be unique");
305   assert(OldValReg != MaskReg && "OldValReg and MaskReg must be unique");
306   assert(ScratchReg != MaskReg && "ScratchReg and MaskReg must be unique");
307 
308   // We select bits from newval and oldval using:
309   // https://graphics.stanford.edu/~seander/bithacks.html#MaskedMerge
310   // r = oldval ^ ((oldval ^ newval) & masktargetdata);
311   BuildMI(MBB, DL, TII->get(RISCV::XOR), ScratchReg)
312       .addReg(OldValReg)
313       .addReg(NewValReg);
314   BuildMI(MBB, DL, TII->get(RISCV::AND), ScratchReg)
315       .addReg(ScratchReg)
316       .addReg(MaskReg);
317   BuildMI(MBB, DL, TII->get(RISCV::XOR), DestReg)
318       .addReg(OldValReg)
319       .addReg(ScratchReg);
320 }
321 
322 static void doMaskedAtomicBinOpExpansion(const RISCVInstrInfo *TII,
323                                          MachineInstr &MI, DebugLoc DL,
324                                          MachineBasicBlock *ThisMBB,
325                                          MachineBasicBlock *LoopMBB,
326                                          MachineBasicBlock *DoneMBB,
327                                          AtomicRMWInst::BinOp BinOp, int Width,
328                                          const RISCVSubtarget *STI) {
329   assert(Width == 32 && "Should never need to expand masked 64-bit operations");
330   Register DestReg = MI.getOperand(0).getReg();
331   Register ScratchReg = MI.getOperand(1).getReg();
332   Register AddrReg = MI.getOperand(2).getReg();
333   Register IncrReg = MI.getOperand(3).getReg();
334   Register MaskReg = MI.getOperand(4).getReg();
335   AtomicOrdering Ordering =
336       static_cast<AtomicOrdering>(MI.getOperand(5).getImm());
337 
338   // .loop:
339   //   lr.w destreg, (alignedaddr)
340   //   binop scratch, destreg, incr
341   //   xor scratch, destreg, scratch
342   //   and scratch, scratch, masktargetdata
343   //   xor scratch, destreg, scratch
344   //   sc.w scratch, scratch, (alignedaddr)
345   //   bnez scratch, loop
346   BuildMI(LoopMBB, DL, TII->get(getLRForRMW32(Ordering, STI)), DestReg)
347       .addReg(AddrReg);
348   switch (BinOp) {
349   default:
350     llvm_unreachable("Unexpected AtomicRMW BinOp");
351   case AtomicRMWInst::Xchg:
352     BuildMI(LoopMBB, DL, TII->get(RISCV::ADDI), ScratchReg)
353         .addReg(IncrReg)
354         .addImm(0);
355     break;
356   case AtomicRMWInst::Add:
357     BuildMI(LoopMBB, DL, TII->get(RISCV::ADD), ScratchReg)
358         .addReg(DestReg)
359         .addReg(IncrReg);
360     break;
361   case AtomicRMWInst::Sub:
362     BuildMI(LoopMBB, DL, TII->get(RISCV::SUB), ScratchReg)
363         .addReg(DestReg)
364         .addReg(IncrReg);
365     break;
366   case AtomicRMWInst::Nand:
367     BuildMI(LoopMBB, DL, TII->get(RISCV::AND), ScratchReg)
368         .addReg(DestReg)
369         .addReg(IncrReg);
370     BuildMI(LoopMBB, DL, TII->get(RISCV::XORI), ScratchReg)
371         .addReg(ScratchReg)
372         .addImm(-1);
373     break;
374   }
375 
376   insertMaskedMerge(TII, DL, LoopMBB, ScratchReg, DestReg, ScratchReg, MaskReg,
377                     ScratchReg);
378 
379   BuildMI(LoopMBB, DL, TII->get(getSCForRMW32(Ordering, STI)), ScratchReg)
380       .addReg(AddrReg)
381       .addReg(ScratchReg);
382   BuildMI(LoopMBB, DL, TII->get(RISCV::BNE))
383       .addReg(ScratchReg)
384       .addReg(RISCV::X0)
385       .addMBB(LoopMBB);
386 }
387 
388 bool RISCVExpandAtomicPseudo::expandAtomicBinOp(
389     MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
390     AtomicRMWInst::BinOp BinOp, bool IsMasked, int Width,
391     MachineBasicBlock::iterator &NextMBBI) {
392   MachineInstr &MI = *MBBI;
393   DebugLoc DL = MI.getDebugLoc();
394 
395   MachineFunction *MF = MBB.getParent();
396   auto LoopMBB = MF->CreateMachineBasicBlock(MBB.getBasicBlock());
397   auto DoneMBB = MF->CreateMachineBasicBlock(MBB.getBasicBlock());
398 
399   // Insert new MBBs.
400   MF->insert(++MBB.getIterator(), LoopMBB);
401   MF->insert(++LoopMBB->getIterator(), DoneMBB);
402 
403   // Set up successors and transfer remaining instructions to DoneMBB.
404   LoopMBB->addSuccessor(LoopMBB);
405   LoopMBB->addSuccessor(DoneMBB);
406   DoneMBB->splice(DoneMBB->end(), &MBB, MI, MBB.end());
407   DoneMBB->transferSuccessors(&MBB);
408   MBB.addSuccessor(LoopMBB);
409 
410   if (!IsMasked)
411     doAtomicBinOpExpansion(TII, MI, DL, &MBB, LoopMBB, DoneMBB, BinOp, Width,
412                            STI);
413   else
414     doMaskedAtomicBinOpExpansion(TII, MI, DL, &MBB, LoopMBB, DoneMBB, BinOp,
415                                  Width, STI);
416 
417   NextMBBI = MBB.end();
418   MI.eraseFromParent();
419 
420   LivePhysRegs LiveRegs;
421   computeAndAddLiveIns(LiveRegs, *LoopMBB);
422   computeAndAddLiveIns(LiveRegs, *DoneMBB);
423 
424   return true;
425 }
426 
427 static void insertSext(const RISCVInstrInfo *TII, DebugLoc DL,
428                        MachineBasicBlock *MBB, Register ValReg,
429                        Register ShamtReg) {
430   BuildMI(MBB, DL, TII->get(RISCV::SLL), ValReg)
431       .addReg(ValReg)
432       .addReg(ShamtReg);
433   BuildMI(MBB, DL, TII->get(RISCV::SRA), ValReg)
434       .addReg(ValReg)
435       .addReg(ShamtReg);
436 }
437 
438 bool RISCVExpandAtomicPseudo::expandAtomicMinMaxOp(
439     MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
440     AtomicRMWInst::BinOp BinOp, bool IsMasked, int Width,
441     MachineBasicBlock::iterator &NextMBBI) {
442   assert(IsMasked == true &&
443          "Should only need to expand masked atomic max/min");
444   assert(Width == 32 && "Should never need to expand masked 64-bit operations");
445 
446   MachineInstr &MI = *MBBI;
447   DebugLoc DL = MI.getDebugLoc();
448   MachineFunction *MF = MBB.getParent();
449   auto LoopHeadMBB = MF->CreateMachineBasicBlock(MBB.getBasicBlock());
450   auto LoopIfBodyMBB = MF->CreateMachineBasicBlock(MBB.getBasicBlock());
451   auto LoopTailMBB = MF->CreateMachineBasicBlock(MBB.getBasicBlock());
452   auto DoneMBB = MF->CreateMachineBasicBlock(MBB.getBasicBlock());
453 
454   // Insert new MBBs.
455   MF->insert(++MBB.getIterator(), LoopHeadMBB);
456   MF->insert(++LoopHeadMBB->getIterator(), LoopIfBodyMBB);
457   MF->insert(++LoopIfBodyMBB->getIterator(), LoopTailMBB);
458   MF->insert(++LoopTailMBB->getIterator(), DoneMBB);
459 
460   // Set up successors and transfer remaining instructions to DoneMBB.
461   LoopHeadMBB->addSuccessor(LoopIfBodyMBB);
462   LoopHeadMBB->addSuccessor(LoopTailMBB);
463   LoopIfBodyMBB->addSuccessor(LoopTailMBB);
464   LoopTailMBB->addSuccessor(LoopHeadMBB);
465   LoopTailMBB->addSuccessor(DoneMBB);
466   DoneMBB->splice(DoneMBB->end(), &MBB, MI, MBB.end());
467   DoneMBB->transferSuccessors(&MBB);
468   MBB.addSuccessor(LoopHeadMBB);
469 
470   Register DestReg = MI.getOperand(0).getReg();
471   Register Scratch1Reg = MI.getOperand(1).getReg();
472   Register Scratch2Reg = MI.getOperand(2).getReg();
473   Register AddrReg = MI.getOperand(3).getReg();
474   Register IncrReg = MI.getOperand(4).getReg();
475   Register MaskReg = MI.getOperand(5).getReg();
476   bool IsSigned = BinOp == AtomicRMWInst::Min || BinOp == AtomicRMWInst::Max;
477   AtomicOrdering Ordering =
478       static_cast<AtomicOrdering>(MI.getOperand(IsSigned ? 7 : 6).getImm());
479 
480   //
481   // .loophead:
482   //   lr.w destreg, (alignedaddr)
483   //   and scratch2, destreg, mask
484   //   mv scratch1, destreg
485   //   [sext scratch2 if signed min/max]
486   //   ifnochangeneeded scratch2, incr, .looptail
487   BuildMI(LoopHeadMBB, DL, TII->get(getLRForRMW32(Ordering, STI)), DestReg)
488       .addReg(AddrReg);
489   BuildMI(LoopHeadMBB, DL, TII->get(RISCV::AND), Scratch2Reg)
490       .addReg(DestReg)
491       .addReg(MaskReg);
492   BuildMI(LoopHeadMBB, DL, TII->get(RISCV::ADDI), Scratch1Reg)
493       .addReg(DestReg)
494       .addImm(0);
495 
496   switch (BinOp) {
497   default:
498     llvm_unreachable("Unexpected AtomicRMW BinOp");
499   case AtomicRMWInst::Max: {
500     insertSext(TII, DL, LoopHeadMBB, Scratch2Reg, MI.getOperand(6).getReg());
501     BuildMI(LoopHeadMBB, DL, TII->get(RISCV::BGE))
502         .addReg(Scratch2Reg)
503         .addReg(IncrReg)
504         .addMBB(LoopTailMBB);
505     break;
506   }
507   case AtomicRMWInst::Min: {
508     insertSext(TII, DL, LoopHeadMBB, Scratch2Reg, MI.getOperand(6).getReg());
509     BuildMI(LoopHeadMBB, DL, TII->get(RISCV::BGE))
510         .addReg(IncrReg)
511         .addReg(Scratch2Reg)
512         .addMBB(LoopTailMBB);
513     break;
514   }
515   case AtomicRMWInst::UMax:
516     BuildMI(LoopHeadMBB, DL, TII->get(RISCV::BGEU))
517         .addReg(Scratch2Reg)
518         .addReg(IncrReg)
519         .addMBB(LoopTailMBB);
520     break;
521   case AtomicRMWInst::UMin:
522     BuildMI(LoopHeadMBB, DL, TII->get(RISCV::BGEU))
523         .addReg(IncrReg)
524         .addReg(Scratch2Reg)
525         .addMBB(LoopTailMBB);
526     break;
527   }
528 
529   // .loopifbody:
530   //   xor scratch1, destreg, incr
531   //   and scratch1, scratch1, mask
532   //   xor scratch1, destreg, scratch1
533   insertMaskedMerge(TII, DL, LoopIfBodyMBB, Scratch1Reg, DestReg, IncrReg,
534                     MaskReg, Scratch1Reg);
535 
536   // .looptail:
537   //   sc.w scratch1, scratch1, (addr)
538   //   bnez scratch1, loop
539   BuildMI(LoopTailMBB, DL, TII->get(getSCForRMW32(Ordering, STI)), Scratch1Reg)
540       .addReg(AddrReg)
541       .addReg(Scratch1Reg);
542   BuildMI(LoopTailMBB, DL, TII->get(RISCV::BNE))
543       .addReg(Scratch1Reg)
544       .addReg(RISCV::X0)
545       .addMBB(LoopHeadMBB);
546 
547   NextMBBI = MBB.end();
548   MI.eraseFromParent();
549 
550   LivePhysRegs LiveRegs;
551   computeAndAddLiveIns(LiveRegs, *LoopHeadMBB);
552   computeAndAddLiveIns(LiveRegs, *LoopIfBodyMBB);
553   computeAndAddLiveIns(LiveRegs, *LoopTailMBB);
554   computeAndAddLiveIns(LiveRegs, *DoneMBB);
555 
556   return true;
557 }
558 
559 // If a BNE on the cmpxchg comparison result immediately follows the cmpxchg
560 // operation, it can be folded into the cmpxchg expansion by
561 // modifying the branch within 'LoopHead' (which performs the same
562 // comparison). This is a valid transformation because after altering the
563 // LoopHead's BNE destination, the BNE following the cmpxchg becomes
564 // redundant and and be deleted. In the case of a masked cmpxchg, an
565 // appropriate AND and BNE must be matched.
566 //
567 // On success, returns true and deletes the matching BNE or AND+BNE, sets the
568 // LoopHeadBNETarget argument to the target that should be used within the
569 // loop head, and removes that block as a successor to MBB.
570 bool tryToFoldBNEOnCmpXchgResult(MachineBasicBlock &MBB,
571                                  MachineBasicBlock::iterator MBBI,
572                                  Register DestReg, Register CmpValReg,
573                                  Register MaskReg,
574                                  MachineBasicBlock *&LoopHeadBNETarget) {
575   SmallVector<MachineInstr *> ToErase;
576   auto E = MBB.end();
577   if (MBBI == E)
578     return false;
579   MBBI = skipDebugInstructionsForward(MBBI, E);
580 
581   // If we have a masked cmpxchg, match AND dst, DestReg, MaskReg.
582   if (MaskReg.isValid()) {
583     if (MBBI == E || MBBI->getOpcode() != RISCV::AND)
584       return false;
585     Register ANDOp1 = MBBI->getOperand(1).getReg();
586     Register ANDOp2 = MBBI->getOperand(2).getReg();
587     if (!(ANDOp1 == DestReg && ANDOp2 == MaskReg) &&
588         !(ANDOp1 == MaskReg && ANDOp2 == DestReg))
589       return false;
590     // We now expect the BNE to use the result of the AND as an operand.
591     DestReg = MBBI->getOperand(0).getReg();
592     ToErase.push_back(&*MBBI);
593     MBBI = skipDebugInstructionsForward(std::next(MBBI), E);
594   }
595 
596   // Match BNE DestReg, MaskReg.
597   if (MBBI == E || MBBI->getOpcode() != RISCV::BNE)
598     return false;
599   Register BNEOp0 = MBBI->getOperand(0).getReg();
600   Register BNEOp1 = MBBI->getOperand(1).getReg();
601   if (!(BNEOp0 == DestReg && BNEOp1 == CmpValReg) &&
602       !(BNEOp0 == CmpValReg && BNEOp1 == DestReg))
603     return false;
604 
605   // Make sure the branch is the only user of the AND.
606   if (MaskReg.isValid()) {
607     if (BNEOp0 == DestReg && !MBBI->getOperand(0).isKill())
608       return false;
609     if (BNEOp1 == DestReg && !MBBI->getOperand(1).isKill())
610       return false;
611   }
612 
613   ToErase.push_back(&*MBBI);
614   LoopHeadBNETarget = MBBI->getOperand(2).getMBB();
615   MBBI = skipDebugInstructionsForward(std::next(MBBI), E);
616   if (MBBI != E)
617     return false;
618 
619   MBB.removeSuccessor(LoopHeadBNETarget);
620   for (auto *MI : ToErase)
621     MI->eraseFromParent();
622   return true;
623 }
624 
625 bool RISCVExpandAtomicPseudo::expandAtomicCmpXchg(
626     MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI, bool IsMasked,
627     int Width, MachineBasicBlock::iterator &NextMBBI) {
628   MachineInstr &MI = *MBBI;
629   DebugLoc DL = MI.getDebugLoc();
630   MachineFunction *MF = MBB.getParent();
631   auto LoopHeadMBB = MF->CreateMachineBasicBlock(MBB.getBasicBlock());
632   auto LoopTailMBB = MF->CreateMachineBasicBlock(MBB.getBasicBlock());
633   auto DoneMBB = MF->CreateMachineBasicBlock(MBB.getBasicBlock());
634 
635   Register DestReg = MI.getOperand(0).getReg();
636   Register ScratchReg = MI.getOperand(1).getReg();
637   Register AddrReg = MI.getOperand(2).getReg();
638   Register CmpValReg = MI.getOperand(3).getReg();
639   Register NewValReg = MI.getOperand(4).getReg();
640   Register MaskReg = IsMasked ? MI.getOperand(5).getReg() : Register();
641 
642   MachineBasicBlock *LoopHeadBNETarget = DoneMBB;
643   tryToFoldBNEOnCmpXchgResult(MBB, std::next(MBBI), DestReg, CmpValReg, MaskReg,
644                               LoopHeadBNETarget);
645 
646   // Insert new MBBs.
647   MF->insert(++MBB.getIterator(), LoopHeadMBB);
648   MF->insert(++LoopHeadMBB->getIterator(), LoopTailMBB);
649   MF->insert(++LoopTailMBB->getIterator(), DoneMBB);
650 
651   // Set up successors and transfer remaining instructions to DoneMBB.
652   LoopHeadMBB->addSuccessor(LoopTailMBB);
653   LoopHeadMBB->addSuccessor(LoopHeadBNETarget);
654   LoopTailMBB->addSuccessor(DoneMBB);
655   LoopTailMBB->addSuccessor(LoopHeadMBB);
656   DoneMBB->splice(DoneMBB->end(), &MBB, MI, MBB.end());
657   DoneMBB->transferSuccessors(&MBB);
658   MBB.addSuccessor(LoopHeadMBB);
659 
660   AtomicOrdering Ordering =
661       static_cast<AtomicOrdering>(MI.getOperand(IsMasked ? 6 : 5).getImm());
662 
663   if (!IsMasked) {
664     // .loophead:
665     //   lr.[w|d] dest, (addr)
666     //   bne dest, cmpval, done
667     BuildMI(LoopHeadMBB, DL, TII->get(getLRForRMW(Ordering, Width, STI)),
668             DestReg)
669         .addReg(AddrReg);
670     BuildMI(LoopHeadMBB, DL, TII->get(RISCV::BNE))
671         .addReg(DestReg)
672         .addReg(CmpValReg)
673         .addMBB(LoopHeadBNETarget);
674     // .looptail:
675     //   sc.[w|d] scratch, newval, (addr)
676     //   bnez scratch, loophead
677     BuildMI(LoopTailMBB, DL, TII->get(getSCForRMW(Ordering, Width, STI)),
678             ScratchReg)
679         .addReg(AddrReg)
680         .addReg(NewValReg);
681     BuildMI(LoopTailMBB, DL, TII->get(RISCV::BNE))
682         .addReg(ScratchReg)
683         .addReg(RISCV::X0)
684         .addMBB(LoopHeadMBB);
685   } else {
686     // .loophead:
687     //   lr.w dest, (addr)
688     //   and scratch, dest, mask
689     //   bne scratch, cmpval, done
690     Register MaskReg = MI.getOperand(5).getReg();
691     BuildMI(LoopHeadMBB, DL, TII->get(getLRForRMW(Ordering, Width, STI)),
692             DestReg)
693         .addReg(AddrReg);
694     BuildMI(LoopHeadMBB, DL, TII->get(RISCV::AND), ScratchReg)
695         .addReg(DestReg)
696         .addReg(MaskReg);
697     BuildMI(LoopHeadMBB, DL, TII->get(RISCV::BNE))
698         .addReg(ScratchReg)
699         .addReg(CmpValReg)
700         .addMBB(LoopHeadBNETarget);
701 
702     // .looptail:
703     //   xor scratch, dest, newval
704     //   and scratch, scratch, mask
705     //   xor scratch, dest, scratch
706     //   sc.w scratch, scratch, (adrr)
707     //   bnez scratch, loophead
708     insertMaskedMerge(TII, DL, LoopTailMBB, ScratchReg, DestReg, NewValReg,
709                       MaskReg, ScratchReg);
710     BuildMI(LoopTailMBB, DL, TII->get(getSCForRMW(Ordering, Width, STI)),
711             ScratchReg)
712         .addReg(AddrReg)
713         .addReg(ScratchReg);
714     BuildMI(LoopTailMBB, DL, TII->get(RISCV::BNE))
715         .addReg(ScratchReg)
716         .addReg(RISCV::X0)
717         .addMBB(LoopHeadMBB);
718   }
719 
720   NextMBBI = MBB.end();
721   MI.eraseFromParent();
722 
723   LivePhysRegs LiveRegs;
724   computeAndAddLiveIns(LiveRegs, *LoopHeadMBB);
725   computeAndAddLiveIns(LiveRegs, *LoopTailMBB);
726   computeAndAddLiveIns(LiveRegs, *DoneMBB);
727 
728   return true;
729 }
730 
731 } // end of anonymous namespace
732 
733 INITIALIZE_PASS(RISCVExpandAtomicPseudo, "riscv-expand-atomic-pseudo",
734                 RISCV_EXPAND_ATOMIC_PSEUDO_NAME, false, false)
735 
736 namespace llvm {
737 
738 FunctionPass *createRISCVExpandAtomicPseudoPass() {
739   return new RISCVExpandAtomicPseudo();
740 }
741 
742 } // end of namespace llvm
743