//===-- RISCVExpandAtomicPseudoInsts.cpp - Expand atomic pseudo instrs. ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains a pass that expands atomic pseudo instructions into
// target instructions. This pass should be run at the last possible moment,
// so that no later pass can break the forward-progress requirements of the
// LR/SC sequences.
//
//===----------------------------------------------------------------------===//

#include "RISCV.h"
#include "RISCVInstrInfo.h"
#include "RISCVTargetMachine.h"

#include "llvm/CodeGen/LivePhysRegs.h"
#include "llvm/CodeGen/MachineFunctionPass.h"
#include "llvm/CodeGen/MachineInstrBuilder.h"

using namespace llvm;

#define RISCV_EXPAND_ATOMIC_PSEUDO_NAME                                        \
  "RISCV atomic pseudo instruction expansion pass"

namespace {

class RISCVExpandAtomicPseudo : public MachineFunctionPass {
public:
  const RISCVInstrInfo *TII;
  static char ID;

  RISCVExpandAtomicPseudo() : MachineFunctionPass(ID) {
    initializeRISCVExpandAtomicPseudoPass(*PassRegistry::getPassRegistry());
  }

  bool runOnMachineFunction(MachineFunction &MF) override;

  StringRef getPassName() const override {
    return RISCV_EXPAND_ATOMIC_PSEUDO_NAME;
  }

private:
  bool expandMBB(MachineBasicBlock &MBB);
  bool expandMI(MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
                MachineBasicBlock::iterator &NextMBBI);
  bool expandAtomicBinOp(MachineBasicBlock &MBB,
                         MachineBasicBlock::iterator MBBI, AtomicRMWInst::BinOp,
                         bool IsMasked, int Width,
                         MachineBasicBlock::iterator &NextMBBI);
  bool expandAtomicMinMaxOp(MachineBasicBlock &MBB,
                            MachineBasicBlock::iterator MBBI,
                            AtomicRMWInst::BinOp, bool IsMasked, int Width,
                            MachineBasicBlock::iterator &NextMBBI);
  bool expandAtomicCmpXchg(MachineBasicBlock &MBB,
                           MachineBasicBlock::iterator MBBI, bool IsMasked,
                           int Width, MachineBasicBlock::iterator &NextMBBI);
};

char RISCVExpandAtomicPseudo::ID = 0;

bool RISCVExpandAtomicPseudo::runOnMachineFunction(MachineFunction &MF) {
  TII = static_cast<const RISCVInstrInfo *>(MF.getSubtarget().getInstrInfo());
  bool Modified = false;
  for (auto &MBB : MF)
    Modified |= expandMBB(MBB);
  return Modified;
}

bool RISCVExpandAtomicPseudo::expandMBB(MachineBasicBlock &MBB) {
  bool Modified = false;

  MachineBasicBlock::iterator MBBI = MBB.begin(), E = MBB.end();
  while (MBBI != E) {
    MachineBasicBlock::iterator NMBBI = std::next(MBBI);
    Modified |= expandMI(MBB, MBBI, NMBBI);
    MBBI = NMBBI;
  }

  return Modified;
}

bool RISCVExpandAtomicPseudo::expandMI(MachineBasicBlock &MBB,
                                       MachineBasicBlock::iterator MBBI,
                                       MachineBasicBlock::iterator &NextMBBI) {
  // RISCVInstrInfo::getInstSizeInBytes assumes that the Size field in each
  // pseudo's tablegen definition matches the total size of the instructions
  // the pseudo expands to.
  switch (MBBI->getOpcode()) {
  case RISCV::PseudoAtomicLoadNand32:
    return expandAtomicBinOp(MBB, MBBI, AtomicRMWInst::Nand, false, 32,
                             NextMBBI);
  case RISCV::PseudoAtomicLoadNand64:
    return expandAtomicBinOp(MBB, MBBI, AtomicRMWInst::Nand, false, 64,
                             NextMBBI);
  case RISCV::PseudoMaskedAtomicSwap32:
    return expandAtomicBinOp(MBB, MBBI, AtomicRMWInst::Xchg, true, 32,
                             NextMBBI);
  case RISCV::PseudoMaskedAtomicLoadAdd32:
    return expandAtomicBinOp(MBB, MBBI, AtomicRMWInst::Add, true, 32, NextMBBI);
  case RISCV::PseudoMaskedAtomicLoadSub32:
    return expandAtomicBinOp(MBB, MBBI, AtomicRMWInst::Sub, true, 32, NextMBBI);
  case RISCV::PseudoMaskedAtomicLoadNand32:
    return expandAtomicBinOp(MBB, MBBI, AtomicRMWInst::Nand, true, 32,
                             NextMBBI);
  case RISCV::PseudoMaskedAtomicLoadMax32:
    return expandAtomicMinMaxOp(MBB, MBBI, AtomicRMWInst::Max, true, 32,
                                NextMBBI);
  case RISCV::PseudoMaskedAtomicLoadMin32:
    return expandAtomicMinMaxOp(MBB, MBBI, AtomicRMWInst::Min, true, 32,
                                NextMBBI);
  case RISCV::PseudoMaskedAtomicLoadUMax32:
    return expandAtomicMinMaxOp(MBB, MBBI, AtomicRMWInst::UMax, true, 32,
                                NextMBBI);
  case RISCV::PseudoMaskedAtomicLoadUMin32:
    return expandAtomicMinMaxOp(MBB, MBBI, AtomicRMWInst::UMin, true, 32,
                                NextMBBI);
  case RISCV::PseudoCmpXchg32:
    return expandAtomicCmpXchg(MBB, MBBI, false, 32, NextMBBI);
  case RISCV::PseudoCmpXchg64:
    return expandAtomicCmpXchg(MBB, MBBI, false, 64, NextMBBI);
  case RISCV::PseudoMaskedCmpXchg32:
    return expandAtomicCmpXchg(MBB, MBBI, true, 32, NextMBBI);
  }

  return false;
}

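// Map the IR atomic ordering to the LR.W/SC.W variants carrying the required
// aq/rl bits. The acquire half of an ordering is provided by the aq bit on
// the LR and the release half by the rl bit on the SC, so e.g. a Release RMW
// uses a plain LR.W paired with SC.W.RL; only seq_cst sets both bits on both
// instructions. The *ForRMW64 helpers below apply the same mapping to
// LR.D/SC.D.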
static unsigned getLRForRMW32(AtomicOrdering Ordering) {
  switch (Ordering) {
  default:
    llvm_unreachable("Unexpected AtomicOrdering");
  case AtomicOrdering::Monotonic:
    return RISCV::LR_W;
  case AtomicOrdering::Acquire:
    return RISCV::LR_W_AQ;
  case AtomicOrdering::Release:
    return RISCV::LR_W;
  case AtomicOrdering::AcquireRelease:
    return RISCV::LR_W_AQ;
  case AtomicOrdering::SequentiallyConsistent:
    return RISCV::LR_W_AQ_RL;
  }
}

static unsigned getSCForRMW32(AtomicOrdering Ordering) {
  switch (Ordering) {
  default:
    llvm_unreachable("Unexpected AtomicOrdering");
  case AtomicOrdering::Monotonic:
    return RISCV::SC_W;
  case AtomicOrdering::Acquire:
    return RISCV::SC_W;
  case AtomicOrdering::Release:
    return RISCV::SC_W_RL;
  case AtomicOrdering::AcquireRelease:
    return RISCV::SC_W_RL;
  case AtomicOrdering::SequentiallyConsistent:
    return RISCV::SC_W_AQ_RL;
  }
}

static unsigned getLRForRMW64(AtomicOrdering Ordering) {
  switch (Ordering) {
  default:
    llvm_unreachable("Unexpected AtomicOrdering");
  case AtomicOrdering::Monotonic:
    return RISCV::LR_D;
  case AtomicOrdering::Acquire:
    return RISCV::LR_D_AQ;
  case AtomicOrdering::Release:
    return RISCV::LR_D;
  case AtomicOrdering::AcquireRelease:
    return RISCV::LR_D_AQ;
  case AtomicOrdering::SequentiallyConsistent:
    return RISCV::LR_D_AQ_RL;
  }
}

static unsigned getSCForRMW64(AtomicOrdering Ordering) {
  switch (Ordering) {
  default:
    llvm_unreachable("Unexpected AtomicOrdering");
  case AtomicOrdering::Monotonic:
    return RISCV::SC_D;
  case AtomicOrdering::Acquire:
    return RISCV::SC_D;
  case AtomicOrdering::Release:
    return RISCV::SC_D_RL;
  case AtomicOrdering::AcquireRelease:
    return RISCV::SC_D_RL;
  case AtomicOrdering::SequentiallyConsistent:
    return RISCV::SC_D_AQ_RL;
  }
}

static unsigned getLRForRMW(AtomicOrdering Ordering, int Width) {
  if (Width == 32)
    return getLRForRMW32(Ordering);
  if (Width == 64)
    return getLRForRMW64(Ordering);
  llvm_unreachable("Unexpected LR width\n");
}

static unsigned getSCForRMW(AtomicOrdering Ordering, int Width) {
  if (Width == 32)
    return getSCForRMW32(Ordering);
  if (Width == 64)
    return getSCForRMW64(Ordering);
  llvm_unreachable("Unexpected SC width\n");
}

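// Emit the LR/SC retry loop for an unmasked atomic read-modify-write into
// LoopMBB. DestReg receives the value loaded from memory (the result of the
// atomicrmw); ScratchReg holds the value to be stored and is then clobbered
// by the SC status (zero on success, nonzero on failure).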
static void doAtomicBinOpExpansion(const RISCVInstrInfo *TII, MachineInstr &MI,
                                   DebugLoc DL, MachineBasicBlock *ThisMBB,
                                   MachineBasicBlock *LoopMBB,
                                   MachineBasicBlock *DoneMBB,
                                   AtomicRMWInst::BinOp BinOp, int Width) {
  Register DestReg = MI.getOperand(0).getReg();
  Register ScratchReg = MI.getOperand(1).getReg();
  Register AddrReg = MI.getOperand(2).getReg();
  Register IncrReg = MI.getOperand(3).getReg();
  AtomicOrdering Ordering =
      static_cast<AtomicOrdering>(MI.getOperand(4).getImm());

  // .loop:
  //   lr.[w|d] dest, (addr)
  //   binop scratch, dest, val
  //   sc.[w|d] scratch, scratch, (addr)
  //   bnez scratch, loop
  BuildMI(LoopMBB, DL, TII->get(getLRForRMW(Ordering, Width)), DestReg)
      .addReg(AddrReg);
  switch (BinOp) {
  default:
    llvm_unreachable("Unexpected AtomicRMW BinOp");
  case AtomicRMWInst::Nand:
    BuildMI(LoopMBB, DL, TII->get(RISCV::AND), ScratchReg)
        .addReg(DestReg)
        .addReg(IncrReg);
    BuildMI(LoopMBB, DL, TII->get(RISCV::XORI), ScratchReg)
        .addReg(ScratchReg)
        .addImm(-1);
    break;
  }
  BuildMI(LoopMBB, DL, TII->get(getSCForRMW(Ordering, Width)), ScratchReg)
      .addReg(AddrReg)
      .addReg(ScratchReg);
  BuildMI(LoopMBB, DL, TII->get(RISCV::BNE))
      .addReg(ScratchReg)
      .addReg(RISCV::X0)
      .addMBB(LoopMBB);
}

static void insertMaskedMerge(const RISCVInstrInfo *TII, DebugLoc DL,
                              MachineBasicBlock *MBB, Register DestReg,
                              Register OldValReg, Register NewValReg,
                              Register MaskReg, Register ScratchReg) {
  assert(OldValReg != ScratchReg && "OldValReg and ScratchReg must be unique");
  assert(OldValReg != MaskReg && "OldValReg and MaskReg must be unique");
  assert(ScratchReg != MaskReg && "ScratchReg and MaskReg must be unique");

  // We select bits from newval and oldval using:
  // https://graphics.stanford.edu/~seander/bithacks.html#MaskedMerge
  // r = oldval ^ ((oldval ^ newval) & masktargetdata);
  BuildMI(MBB, DL, TII->get(RISCV::XOR), ScratchReg)
      .addReg(OldValReg)
      .addReg(NewValReg);
  BuildMI(MBB, DL, TII->get(RISCV::AND), ScratchReg)
      .addReg(ScratchReg)
      .addReg(MaskReg);
  BuildMI(MBB, DL, TII->get(RISCV::XOR), DestReg)
      .addReg(OldValReg)
      .addReg(ScratchReg);
}

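// Emit the LR.W/SC.W retry loop for a masked atomic read-modify-write: only
// the bits selected by the mask operand are updated. This is how sub-word
// (8- and 16-bit) atomicrmw operations are performed on top of word-aligned
// 32-bit reservations.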
static void doMaskedAtomicBinOpExpansion(
    const RISCVInstrInfo *TII, MachineInstr &MI, DebugLoc DL,
    MachineBasicBlock *ThisMBB, MachineBasicBlock *LoopMBB,
    MachineBasicBlock *DoneMBB, AtomicRMWInst::BinOp BinOp, int Width) {
  assert(Width == 32 && "Should never need to expand masked 64-bit operations");
  Register DestReg = MI.getOperand(0).getReg();
  Register ScratchReg = MI.getOperand(1).getReg();
  Register AddrReg = MI.getOperand(2).getReg();
  Register IncrReg = MI.getOperand(3).getReg();
  Register MaskReg = MI.getOperand(4).getReg();
  AtomicOrdering Ordering =
      static_cast<AtomicOrdering>(MI.getOperand(5).getImm());

  // .loop:
  //   lr.w destreg, (alignedaddr)
  //   binop scratch, destreg, incr
  //   xor scratch, destreg, scratch
  //   and scratch, scratch, masktargetdata
  //   xor scratch, destreg, scratch
  //   sc.w scratch, scratch, (alignedaddr)
  //   bnez scratch, loop
  BuildMI(LoopMBB, DL, TII->get(getLRForRMW32(Ordering)), DestReg)
      .addReg(AddrReg);
  switch (BinOp) {
  default:
    llvm_unreachable("Unexpected AtomicRMW BinOp");
  case AtomicRMWInst::Xchg:
    BuildMI(LoopMBB, DL, TII->get(RISCV::ADDI), ScratchReg)
        .addReg(IncrReg)
        .addImm(0);
    break;
  case AtomicRMWInst::Add:
    BuildMI(LoopMBB, DL, TII->get(RISCV::ADD), ScratchReg)
        .addReg(DestReg)
        .addReg(IncrReg);
    break;
  case AtomicRMWInst::Sub:
    BuildMI(LoopMBB, DL, TII->get(RISCV::SUB), ScratchReg)
        .addReg(DestReg)
        .addReg(IncrReg);
    break;
  case AtomicRMWInst::Nand:
    BuildMI(LoopMBB, DL, TII->get(RISCV::AND), ScratchReg)
        .addReg(DestReg)
        .addReg(IncrReg);
    BuildMI(LoopMBB, DL, TII->get(RISCV::XORI), ScratchReg)
        .addReg(ScratchReg)
        .addImm(-1);
    break;
  }

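  // ScratchReg holds the new value and also serves as both the destination
  // and the scratch operand of the merge. This is fine because
  // insertMaskedMerge reads its NewValReg operand before overwriting the
  // scratch register and only writes its DestReg operand last.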
  insertMaskedMerge(TII, DL, LoopMBB, ScratchReg, DestReg, ScratchReg, MaskReg,
                    ScratchReg);

  BuildMI(LoopMBB, DL, TII->get(getSCForRMW32(Ordering)), ScratchReg)
      .addReg(AddrReg)
      .addReg(ScratchReg);
  BuildMI(LoopMBB, DL, TII->get(RISCV::BNE))
      .addReg(ScratchReg)
      .addReg(RISCV::X0)
      .addMBB(LoopMBB);
}

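// Expand an atomic binop pseudo into an LR/SC loop block followed by a 'done'
// block, rewiring the CFG around the original instruction and recomputing
// live-ins for the newly created blocks.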
bool RISCVExpandAtomicPseudo::expandAtomicBinOp(
    MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
    AtomicRMWInst::BinOp BinOp, bool IsMasked, int Width,
    MachineBasicBlock::iterator &NextMBBI) {
  MachineInstr &MI = *MBBI;
  DebugLoc DL = MI.getDebugLoc();

  MachineFunction *MF = MBB.getParent();
  auto LoopMBB = MF->CreateMachineBasicBlock(MBB.getBasicBlock());
  auto DoneMBB = MF->CreateMachineBasicBlock(MBB.getBasicBlock());

  // Insert new MBBs.
  MF->insert(++MBB.getIterator(), LoopMBB);
  MF->insert(++LoopMBB->getIterator(), DoneMBB);

  // Set up successors and transfer remaining instructions to DoneMBB.
  LoopMBB->addSuccessor(LoopMBB);
  LoopMBB->addSuccessor(DoneMBB);
  DoneMBB->splice(DoneMBB->end(), &MBB, MI, MBB.end());
  DoneMBB->transferSuccessors(&MBB);
  MBB.addSuccessor(LoopMBB);

  if (!IsMasked)
    doAtomicBinOpExpansion(TII, MI, DL, &MBB, LoopMBB, DoneMBB, BinOp, Width);
  else
    doMaskedAtomicBinOpExpansion(TII, MI, DL, &MBB, LoopMBB, DoneMBB, BinOp,
                                 Width);

  NextMBBI = MBB.end();
  MI.eraseFromParent();

  LivePhysRegs LiveRegs;
  computeAndAddLiveIns(LiveRegs, *LoopMBB);
  computeAndAddLiveIns(LiveRegs, *DoneMBB);

  return true;
}

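// Sign-extend a field held in ValReg in place by shifting it left and then
// arithmetically right by ShamtReg, replicating the field's sign bit into the
// upper bits.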
static void insertSext(const RISCVInstrInfo *TII, DebugLoc DL,
                       MachineBasicBlock *MBB, Register ValReg,
                       Register ShamtReg) {
  BuildMI(MBB, DL, TII->get(RISCV::SLL), ValReg)
      .addReg(ValReg)
      .addReg(ShamtReg);
  BuildMI(MBB, DL, TII->get(RISCV::SRA), ValReg)
      .addReg(ValReg)
      .addReg(ShamtReg);
}

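// Expand a masked atomic min/max pseudo. The comparison is performed in the
// loop head; when the loaded value already satisfies the min/max, the merge
// in the 'if body' block is skipped and the SC in the loop tail stores the
// original value back, so the read-modify-write still completes (or retries)
// atomically.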
bool RISCVExpandAtomicPseudo::expandAtomicMinMaxOp(
    MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
    AtomicRMWInst::BinOp BinOp, bool IsMasked, int Width,
    MachineBasicBlock::iterator &NextMBBI) {
  assert(IsMasked == true &&
         "Should only need to expand masked atomic max/min");
  assert(Width == 32 && "Should never need to expand masked 64-bit operations");

  MachineInstr &MI = *MBBI;
  DebugLoc DL = MI.getDebugLoc();
  MachineFunction *MF = MBB.getParent();
  auto LoopHeadMBB = MF->CreateMachineBasicBlock(MBB.getBasicBlock());
  auto LoopIfBodyMBB = MF->CreateMachineBasicBlock(MBB.getBasicBlock());
  auto LoopTailMBB = MF->CreateMachineBasicBlock(MBB.getBasicBlock());
  auto DoneMBB = MF->CreateMachineBasicBlock(MBB.getBasicBlock());

  // Insert new MBBs.
  MF->insert(++MBB.getIterator(), LoopHeadMBB);
  MF->insert(++LoopHeadMBB->getIterator(), LoopIfBodyMBB);
  MF->insert(++LoopIfBodyMBB->getIterator(), LoopTailMBB);
  MF->insert(++LoopTailMBB->getIterator(), DoneMBB);

  // Set up successors and transfer remaining instructions to DoneMBB.
  LoopHeadMBB->addSuccessor(LoopIfBodyMBB);
  LoopHeadMBB->addSuccessor(LoopTailMBB);
  LoopIfBodyMBB->addSuccessor(LoopTailMBB);
  LoopTailMBB->addSuccessor(LoopHeadMBB);
  LoopTailMBB->addSuccessor(DoneMBB);
  DoneMBB->splice(DoneMBB->end(), &MBB, MI, MBB.end());
  DoneMBB->transferSuccessors(&MBB);
  MBB.addSuccessor(LoopHeadMBB);

  Register DestReg = MI.getOperand(0).getReg();
  Register Scratch1Reg = MI.getOperand(1).getReg();
  Register Scratch2Reg = MI.getOperand(2).getReg();
  Register AddrReg = MI.getOperand(3).getReg();
  Register IncrReg = MI.getOperand(4).getReg();
  Register MaskReg = MI.getOperand(5).getReg();
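  // The signed min/max pseudos carry an extra shift-amount operand (operand
  // 6) used to sign-extend the loaded field, so their ordering immediate is
  // operand 7 rather than 6.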
  bool IsSigned = BinOp == AtomicRMWInst::Min || BinOp == AtomicRMWInst::Max;
  AtomicOrdering Ordering =
      static_cast<AtomicOrdering>(MI.getOperand(IsSigned ? 7 : 6).getImm());

  //
  // .loophead:
  //   lr.w destreg, (alignedaddr)
  //   and scratch2, destreg, mask
  //   mv scratch1, destreg
  //   [sext scratch2 if signed min/max]
  //   ifnochangeneeded scratch2, incr, .looptail
  BuildMI(LoopHeadMBB, DL, TII->get(getLRForRMW32(Ordering)), DestReg)
      .addReg(AddrReg);
  BuildMI(LoopHeadMBB, DL, TII->get(RISCV::AND), Scratch2Reg)
      .addReg(DestReg)
      .addReg(MaskReg);
  BuildMI(LoopHeadMBB, DL, TII->get(RISCV::ADDI), Scratch1Reg)
      .addReg(DestReg)
      .addImm(0);

  switch (BinOp) {
  default:
    llvm_unreachable("Unexpected AtomicRMW BinOp");
  case AtomicRMWInst::Max: {
    insertSext(TII, DL, LoopHeadMBB, Scratch2Reg, MI.getOperand(6).getReg());
    BuildMI(LoopHeadMBB, DL, TII->get(RISCV::BGE))
        .addReg(Scratch2Reg)
        .addReg(IncrReg)
        .addMBB(LoopTailMBB);
    break;
  }
  case AtomicRMWInst::Min: {
    insertSext(TII, DL, LoopHeadMBB, Scratch2Reg, MI.getOperand(6).getReg());
    BuildMI(LoopHeadMBB, DL, TII->get(RISCV::BGE))
        .addReg(IncrReg)
        .addReg(Scratch2Reg)
        .addMBB(LoopTailMBB);
    break;
  }
  case AtomicRMWInst::UMax:
    BuildMI(LoopHeadMBB, DL, TII->get(RISCV::BGEU))
        .addReg(Scratch2Reg)
        .addReg(IncrReg)
        .addMBB(LoopTailMBB);
    break;
  case AtomicRMWInst::UMin:
    BuildMI(LoopHeadMBB, DL, TII->get(RISCV::BGEU))
        .addReg(IncrReg)
        .addReg(Scratch2Reg)
        .addMBB(LoopTailMBB);
    break;
  }

  // .loopifbody:
  //   xor scratch1, destreg, incr
  //   and scratch1, scratch1, mask
  //   xor scratch1, destreg, scratch1
  insertMaskedMerge(TII, DL, LoopIfBodyMBB, Scratch1Reg, DestReg, IncrReg,
                    MaskReg, Scratch1Reg);

  // .looptail:
  //   sc.w scratch1, scratch1, (addr)
  //   bnez scratch1, loophead
  BuildMI(LoopTailMBB, DL, TII->get(getSCForRMW32(Ordering)), Scratch1Reg)
      .addReg(AddrReg)
      .addReg(Scratch1Reg);
  BuildMI(LoopTailMBB, DL, TII->get(RISCV::BNE))
      .addReg(Scratch1Reg)
      .addReg(RISCV::X0)
      .addMBB(LoopHeadMBB);

  NextMBBI = MBB.end();
  MI.eraseFromParent();

  LivePhysRegs LiveRegs;
  computeAndAddLiveIns(LiveRegs, *LoopHeadMBB);
  computeAndAddLiveIns(LiveRegs, *LoopIfBodyMBB);
  computeAndAddLiveIns(LiveRegs, *LoopTailMBB);
  computeAndAddLiveIns(LiveRegs, *DoneMBB);

  return true;
}

// If a BNE on the cmpxchg comparison result immediately follows the cmpxchg
// operation, it can be folded into the cmpxchg expansion by modifying the
// branch within 'LoopHead' (which performs the same comparison). This is a
// valid transformation because after altering the LoopHead's BNE destination,
// the BNE following the cmpxchg becomes redundant and can be deleted. In the
// case of a masked cmpxchg, an appropriate AND and BNE must be matched.
//
// On success, returns true and deletes the matching BNE or AND+BNE, sets the
// LoopHeadBNETarget argument to the target that should be used within the
// loop head, and removes that block as a successor to MBB.
bool tryToFoldBNEOnCmpXchgResult(MachineBasicBlock &MBB,
                                 MachineBasicBlock::iterator MBBI,
                                 Register DestReg, Register CmpValReg,
                                 Register MaskReg,
                                 MachineBasicBlock *&LoopHeadBNETarget) {
  SmallVector<MachineInstr *> ToErase;
  auto E = MBB.end();
  if (MBBI == E)
    return false;
  MBBI = skipDebugInstructionsForward(MBBI, E);

  // If we have a masked cmpxchg, match AND dst, DestReg, MaskReg.
  if (MaskReg.isValid()) {
    if (MBBI == E || MBBI->getOpcode() != RISCV::AND)
      return false;
    Register ANDOp1 = MBBI->getOperand(1).getReg();
    Register ANDOp2 = MBBI->getOperand(2).getReg();
    if (!(ANDOp1 == DestReg && ANDOp2 == MaskReg) &&
        !(ANDOp1 == MaskReg && ANDOp2 == DestReg))
      return false;
    // We now expect the BNE to use the result of the AND as an operand.
    DestReg = MBBI->getOperand(0).getReg();
    ToErase.push_back(&*MBBI);
    MBBI = skipDebugInstructionsForward(std::next(MBBI), E);
  }

  // Match BNE DestReg, CmpValReg.
  if (MBBI == E || MBBI->getOpcode() != RISCV::BNE)
    return false;
  Register BNEOp0 = MBBI->getOperand(0).getReg();
  Register BNEOp1 = MBBI->getOperand(1).getReg();
  if (!(BNEOp0 == DestReg && BNEOp1 == CmpValReg) &&
      !(BNEOp0 == CmpValReg && BNEOp1 == DestReg))
    return false;
  ToErase.push_back(&*MBBI);
  LoopHeadBNETarget = MBBI->getOperand(2).getMBB();
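  // Only fold when the BNE is the last (non-debug) instruction in the block.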
  MBBI = skipDebugInstructionsForward(std::next(MBBI), E);
  if (MBBI != E)
    return false;

  MBB.removeSuccessor(LoopHeadBNETarget);
  for (auto *MI : ToErase)
    MI->eraseFromParent();
  return true;
}

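// Expand a (possibly masked) cmpxchg pseudo into an LR/SC compare-and-branch
// loop. If a BNE against the compare value immediately follows the pseudo, it
// is folded into the loop head by tryToFoldBNEOnCmpXchgResult rather than
// being left after the expansion.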
bool RISCVExpandAtomicPseudo::expandAtomicCmpXchg(
    MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI, bool IsMasked,
    int Width, MachineBasicBlock::iterator &NextMBBI) {
  MachineInstr &MI = *MBBI;
  DebugLoc DL = MI.getDebugLoc();
  MachineFunction *MF = MBB.getParent();
  auto LoopHeadMBB = MF->CreateMachineBasicBlock(MBB.getBasicBlock());
  auto LoopTailMBB = MF->CreateMachineBasicBlock(MBB.getBasicBlock());
  auto DoneMBB = MF->CreateMachineBasicBlock(MBB.getBasicBlock());

  Register DestReg = MI.getOperand(0).getReg();
  Register ScratchReg = MI.getOperand(1).getReg();
  Register AddrReg = MI.getOperand(2).getReg();
  Register CmpValReg = MI.getOperand(3).getReg();
  Register NewValReg = MI.getOperand(4).getReg();
  Register MaskReg = IsMasked ? MI.getOperand(5).getReg() : Register();

  MachineBasicBlock *LoopHeadBNETarget = DoneMBB;
  tryToFoldBNEOnCmpXchgResult(MBB, std::next(MBBI), DestReg, CmpValReg, MaskReg,
                              LoopHeadBNETarget);

  // Insert new MBBs.
  MF->insert(++MBB.getIterator(), LoopHeadMBB);
  MF->insert(++LoopHeadMBB->getIterator(), LoopTailMBB);
  MF->insert(++LoopTailMBB->getIterator(), DoneMBB);

  // Set up successors and transfer remaining instructions to DoneMBB.
  LoopHeadMBB->addSuccessor(LoopTailMBB);
  LoopHeadMBB->addSuccessor(LoopHeadBNETarget);
  LoopTailMBB->addSuccessor(DoneMBB);
  LoopTailMBB->addSuccessor(LoopHeadMBB);
  DoneMBB->splice(DoneMBB->end(), &MBB, MI, MBB.end());
  DoneMBB->transferSuccessors(&MBB);
  MBB.addSuccessor(LoopHeadMBB);

  AtomicOrdering Ordering =
      static_cast<AtomicOrdering>(MI.getOperand(IsMasked ? 6 : 5).getImm());

  if (!IsMasked) {
    // .loophead:
    //   lr.[w|d] dest, (addr)
    //   bne dest, cmpval, done
    BuildMI(LoopHeadMBB, DL, TII->get(getLRForRMW(Ordering, Width)), DestReg)
        .addReg(AddrReg);
    BuildMI(LoopHeadMBB, DL, TII->get(RISCV::BNE))
        .addReg(DestReg)
        .addReg(CmpValReg)
        .addMBB(LoopHeadBNETarget);
    // .looptail:
    //   sc.[w|d] scratch, newval, (addr)
    //   bnez scratch, loophead
    BuildMI(LoopTailMBB, DL, TII->get(getSCForRMW(Ordering, Width)), ScratchReg)
        .addReg(AddrReg)
        .addReg(NewValReg);
    BuildMI(LoopTailMBB, DL, TII->get(RISCV::BNE))
        .addReg(ScratchReg)
        .addReg(RISCV::X0)
        .addMBB(LoopHeadMBB);
  } else {
    // .loophead:
    //   lr.w dest, (addr)
    //   and scratch, dest, mask
    //   bne scratch, cmpval, done
    Register MaskReg = MI.getOperand(5).getReg();
    BuildMI(LoopHeadMBB, DL, TII->get(getLRForRMW(Ordering, Width)), DestReg)
        .addReg(AddrReg);
    BuildMI(LoopHeadMBB, DL, TII->get(RISCV::AND), ScratchReg)
        .addReg(DestReg)
        .addReg(MaskReg);
    BuildMI(LoopHeadMBB, DL, TII->get(RISCV::BNE))
        .addReg(ScratchReg)
        .addReg(CmpValReg)
        .addMBB(LoopHeadBNETarget);

    // .looptail:
    //   xor scratch, dest, newval
    //   and scratch, scratch, mask
    //   xor scratch, dest, scratch
    //   sc.w scratch, scratch, (addr)
    //   bnez scratch, loophead
    insertMaskedMerge(TII, DL, LoopTailMBB, ScratchReg, DestReg, NewValReg,
                      MaskReg, ScratchReg);
    BuildMI(LoopTailMBB, DL, TII->get(getSCForRMW(Ordering, Width)), ScratchReg)
        .addReg(AddrReg)
        .addReg(ScratchReg);
    BuildMI(LoopTailMBB, DL, TII->get(RISCV::BNE))
        .addReg(ScratchReg)
        .addReg(RISCV::X0)
        .addMBB(LoopHeadMBB);
  }

  NextMBBI = MBB.end();
  MI.eraseFromParent();

  LivePhysRegs LiveRegs;
  computeAndAddLiveIns(LiveRegs, *LoopHeadMBB);
  computeAndAddLiveIns(LiveRegs, *LoopTailMBB);
  computeAndAddLiveIns(LiveRegs, *DoneMBB);

  return true;
}

} // end of anonymous namespace

INITIALIZE_PASS(RISCVExpandAtomicPseudo, "riscv-expand-atomic-pseudo",
                RISCV_EXPAND_ATOMIC_PSEUDO_NAME, false, false)

namespace llvm {

FunctionPass *createRISCVExpandAtomicPseudoPass() {
  return new RISCVExpandAtomicPseudo();
}

} // end of namespace llvm