xref: /freebsd/contrib/llvm-project/llvm/lib/Target/AArch64/AArch64SelectionDAGInfo.cpp (revision 700637cbb5e582861067a11aaca4d053546871d2)
1 //===-- AArch64SelectionDAGInfo.cpp - AArch64 SelectionDAG Info -----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the AArch64SelectionDAGInfo class.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "AArch64SelectionDAGInfo.h"
14 #include "AArch64MachineFunctionInfo.h"
15 
16 #define GET_SDNODE_DESC
17 #include "AArch64GenSDNodeInfo.inc"
18 #undef GET_SDNODE_DESC
19 
20 using namespace llvm;
21 
22 #define DEBUG_TYPE "aarch64-selectiondag-info"
23 
24 static cl::opt<bool>
25     LowerToSMERoutines("aarch64-lower-to-sme-routines", cl::Hidden,
26                        cl::desc("Enable AArch64 SME memory operations "
27                                 "to lower to librt functions"),
28                        cl::init(true));
29 
AArch64SelectionDAGInfo()30 AArch64SelectionDAGInfo::AArch64SelectionDAGInfo()
31     : SelectionDAGGenTargetInfo(AArch64GenSDNodeInfo) {}
32 
verifyTargetNode(const SelectionDAG & DAG,const SDNode * N) const33 void AArch64SelectionDAGInfo::verifyTargetNode(const SelectionDAG &DAG,
34                                                const SDNode *N) const {
35 #ifndef NDEBUG
36   switch (N->getOpcode()) {
37   default:
38     return SelectionDAGGenTargetInfo::verifyTargetNode(DAG, N);
39   case AArch64ISD::SADDWT:
40   case AArch64ISD::SADDWB:
41   case AArch64ISD::UADDWT:
42   case AArch64ISD::UADDWB: {
43     assert(N->getNumValues() == 1 && "Expected one result!");
44     assert(N->getNumOperands() == 2 && "Expected two operands!");
45     EVT VT = N->getValueType(0);
46     EVT Op0VT = N->getOperand(0).getValueType();
47     EVT Op1VT = N->getOperand(1).getValueType();
48     assert(VT.isVector() && Op0VT.isVector() && Op1VT.isVector() &&
49            VT.isInteger() && Op0VT.isInteger() && Op1VT.isInteger() &&
50            "Expected integer vectors!");
51     assert(VT == Op0VT &&
52            "Expected result and first input to have the same type!");
53     assert(Op0VT.getSizeInBits() == Op1VT.getSizeInBits() &&
54            "Expected vectors of equal size!");
55     assert(Op0VT.getVectorElementCount() * 2 == Op1VT.getVectorElementCount() &&
56            "Expected result vector and first input vector to have half the "
57            "lanes of the second input vector!");
58     break;
59   }
60   case AArch64ISD::SUNPKLO:
61   case AArch64ISD::SUNPKHI:
62   case AArch64ISD::UUNPKLO:
63   case AArch64ISD::UUNPKHI: {
64     assert(N->getNumValues() == 1 && "Expected one result!");
65     assert(N->getNumOperands() == 1 && "Expected one operand!");
66     EVT VT = N->getValueType(0);
67     EVT OpVT = N->getOperand(0).getValueType();
68     assert(OpVT.isVector() && VT.isVector() && OpVT.isInteger() &&
69            VT.isInteger() && "Expected integer vectors!");
70     assert(OpVT.getSizeInBits() == VT.getSizeInBits() &&
71            "Expected vectors of equal size!");
72     assert(OpVT.getVectorElementCount() == VT.getVectorElementCount() * 2 &&
73            "Expected result vector with half the lanes of its input!");
74     break;
75   }
76   case AArch64ISD::TRN1:
77   case AArch64ISD::TRN2:
78   case AArch64ISD::UZP1:
79   case AArch64ISD::UZP2:
80   case AArch64ISD::ZIP1:
81   case AArch64ISD::ZIP2: {
82     assert(N->getNumValues() == 1 && "Expected one result!");
83     assert(N->getNumOperands() == 2 && "Expected two operands!");
84     EVT VT = N->getValueType(0);
85     EVT Op0VT = N->getOperand(0).getValueType();
86     EVT Op1VT = N->getOperand(1).getValueType();
87     assert(VT.isVector() && Op0VT.isVector() && Op1VT.isVector() &&
88            "Expected vectors!");
89     assert(VT == Op0VT && VT == Op1VT && "Expected matching vectors!");
90     break;
91   }
92   case AArch64ISD::RSHRNB_I: {
93     assert(N->getNumValues() == 1 && "Expected one result!");
94     assert(N->getNumOperands() == 2 && "Expected two operands!");
95     EVT VT = N->getValueType(0);
96     EVT Op0VT = N->getOperand(0).getValueType();
97     EVT Op1VT = N->getOperand(1).getValueType();
98     assert(VT.isVector() && VT.isInteger() &&
99            "Expected integer vector result type!");
100     assert(Op0VT.isVector() && Op0VT.isInteger() &&
101            "Expected first operand to be an integer vector!");
102     assert(VT.getSizeInBits() == Op0VT.getSizeInBits() &&
103            "Expected vectors of equal size!");
104     assert(VT.getVectorElementCount() == Op0VT.getVectorElementCount() * 2 &&
105            "Expected input vector with half the lanes of its result!");
106     assert(Op1VT == MVT::i32 && isa<ConstantSDNode>(N->getOperand(1)) &&
107            "Expected second operand to be a constant i32!");
108     break;
109   }
110   }
111 #endif
112 }
113 
EmitMOPS(unsigned Opcode,SelectionDAG & DAG,const SDLoc & DL,SDValue Chain,SDValue Dst,SDValue SrcOrValue,SDValue Size,Align Alignment,bool isVolatile,MachinePointerInfo DstPtrInfo,MachinePointerInfo SrcPtrInfo) const114 SDValue AArch64SelectionDAGInfo::EmitMOPS(unsigned Opcode, SelectionDAG &DAG,
115                                           const SDLoc &DL, SDValue Chain,
116                                           SDValue Dst, SDValue SrcOrValue,
117                                           SDValue Size, Align Alignment,
118                                           bool isVolatile,
119                                           MachinePointerInfo DstPtrInfo,
120                                           MachinePointerInfo SrcPtrInfo) const {
121 
122   // Get the constant size of the copy/set.
123   uint64_t ConstSize = 0;
124   if (auto *C = dyn_cast<ConstantSDNode>(Size))
125     ConstSize = C->getZExtValue();
126 
127   const bool IsSet = Opcode == AArch64::MOPSMemorySetPseudo ||
128                      Opcode == AArch64::MOPSMemorySetTaggingPseudo;
129 
130   MachineFunction &MF = DAG.getMachineFunction();
131 
132   auto Vol =
133       isVolatile ? MachineMemOperand::MOVolatile : MachineMemOperand::MONone;
134   auto DstFlags = MachineMemOperand::MOStore | Vol;
135   auto *DstOp =
136       MF.getMachineMemOperand(DstPtrInfo, DstFlags, ConstSize, Alignment);
137 
138   if (IsSet) {
139     // Extend value to i64, if required.
140     if (SrcOrValue.getValueType() != MVT::i64)
141       SrcOrValue = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i64, SrcOrValue);
142     SDValue Ops[] = {Dst, Size, SrcOrValue, Chain};
143     const EVT ResultTys[] = {MVT::i64, MVT::i64, MVT::Other};
144     MachineSDNode *Node = DAG.getMachineNode(Opcode, DL, ResultTys, Ops);
145     DAG.setNodeMemRefs(Node, {DstOp});
146     return SDValue(Node, 2);
147   } else {
148     SDValue Ops[] = {Dst, SrcOrValue, Size, Chain};
149     const EVT ResultTys[] = {MVT::i64, MVT::i64, MVT::i64, MVT::Other};
150     MachineSDNode *Node = DAG.getMachineNode(Opcode, DL, ResultTys, Ops);
151 
152     auto SrcFlags = MachineMemOperand::MOLoad | Vol;
153     auto *SrcOp =
154         MF.getMachineMemOperand(SrcPtrInfo, SrcFlags, ConstSize, Alignment);
155     DAG.setNodeMemRefs(Node, {DstOp, SrcOp});
156     return SDValue(Node, 3);
157   }
158 }
159 
EmitStreamingCompatibleMemLibCall(SelectionDAG & DAG,const SDLoc & DL,SDValue Chain,SDValue Dst,SDValue Src,SDValue Size,RTLIB::Libcall LC) const160 SDValue AArch64SelectionDAGInfo::EmitStreamingCompatibleMemLibCall(
161     SelectionDAG &DAG, const SDLoc &DL, SDValue Chain, SDValue Dst, SDValue Src,
162     SDValue Size, RTLIB::Libcall LC) const {
163   const AArch64Subtarget &STI =
164       DAG.getMachineFunction().getSubtarget<AArch64Subtarget>();
165   const AArch64TargetLowering *TLI = STI.getTargetLowering();
166   TargetLowering::ArgListEntry DstEntry;
167   DstEntry.Ty = PointerType::getUnqual(*DAG.getContext());
168   DstEntry.Node = Dst;
169   TargetLowering::ArgListTy Args;
170   Args.push_back(DstEntry);
171 
172   RTLIB::Libcall NewLC;
173   switch (LC) {
174   case RTLIB::MEMCPY: {
175     NewLC = RTLIB::SC_MEMCPY;
176     TargetLowering::ArgListEntry Entry;
177     Entry.Ty = PointerType::getUnqual(*DAG.getContext());
178     Entry.Node = Src;
179     Args.push_back(Entry);
180     break;
181   }
182   case RTLIB::MEMMOVE: {
183     NewLC = RTLIB::SC_MEMMOVE;
184     TargetLowering::ArgListEntry Entry;
185     Entry.Ty = PointerType::getUnqual(*DAG.getContext());
186     Entry.Node = Src;
187     Args.push_back(Entry);
188     break;
189   }
190   case RTLIB::MEMSET: {
191     NewLC = RTLIB::SC_MEMSET;
192     TargetLowering::ArgListEntry Entry;
193     Entry.Ty = Type::getInt32Ty(*DAG.getContext());
194     Src = DAG.getZExtOrTrunc(Src, DL, MVT::i32);
195     Entry.Node = Src;
196     Args.push_back(Entry);
197     break;
198   }
199   default:
200     return SDValue();
201   }
202 
203   EVT PointerVT = TLI->getPointerTy(DAG.getDataLayout());
204   SDValue Symbol = DAG.getExternalSymbol(TLI->getLibcallName(NewLC), PointerVT);
205   TargetLowering::ArgListEntry SizeEntry;
206   SizeEntry.Node = Size;
207   SizeEntry.Ty = DAG.getDataLayout().getIntPtrType(*DAG.getContext());
208   Args.push_back(SizeEntry);
209 
210   TargetLowering::CallLoweringInfo CLI(DAG);
211   PointerType *RetTy = PointerType::getUnqual(*DAG.getContext());
212   CLI.setDebugLoc(DL).setChain(Chain).setLibCallee(
213       TLI->getLibcallCallingConv(NewLC), RetTy, Symbol, std::move(Args));
214   return TLI->LowerCallTo(CLI).second;
215 }
216 
EmitTargetCodeForMemcpy(SelectionDAG & DAG,const SDLoc & DL,SDValue Chain,SDValue Dst,SDValue Src,SDValue Size,Align Alignment,bool isVolatile,bool AlwaysInline,MachinePointerInfo DstPtrInfo,MachinePointerInfo SrcPtrInfo) const217 SDValue AArch64SelectionDAGInfo::EmitTargetCodeForMemcpy(
218     SelectionDAG &DAG, const SDLoc &DL, SDValue Chain, SDValue Dst, SDValue Src,
219     SDValue Size, Align Alignment, bool isVolatile, bool AlwaysInline,
220     MachinePointerInfo DstPtrInfo, MachinePointerInfo SrcPtrInfo) const {
221   const AArch64Subtarget &STI =
222       DAG.getMachineFunction().getSubtarget<AArch64Subtarget>();
223 
224   if (STI.hasMOPS())
225     return EmitMOPS(AArch64::MOPSMemoryCopyPseudo, DAG, DL, Chain, Dst, Src,
226                     Size, Alignment, isVolatile, DstPtrInfo, SrcPtrInfo);
227 
228   auto *AFI = DAG.getMachineFunction().getInfo<AArch64FunctionInfo>();
229   SMEAttrs Attrs = AFI->getSMEFnAttrs();
230   if (LowerToSMERoutines && !Attrs.hasNonStreamingInterfaceAndBody())
231     return EmitStreamingCompatibleMemLibCall(DAG, DL, Chain, Dst, Src, Size,
232                                              RTLIB::MEMCPY);
233   return SDValue();
234 }
235 
EmitTargetCodeForMemset(SelectionDAG & DAG,const SDLoc & dl,SDValue Chain,SDValue Dst,SDValue Src,SDValue Size,Align Alignment,bool isVolatile,bool AlwaysInline,MachinePointerInfo DstPtrInfo) const236 SDValue AArch64SelectionDAGInfo::EmitTargetCodeForMemset(
237     SelectionDAG &DAG, const SDLoc &dl, SDValue Chain, SDValue Dst, SDValue Src,
238     SDValue Size, Align Alignment, bool isVolatile, bool AlwaysInline,
239     MachinePointerInfo DstPtrInfo) const {
240   const AArch64Subtarget &STI =
241       DAG.getMachineFunction().getSubtarget<AArch64Subtarget>();
242 
243   if (STI.hasMOPS())
244     return EmitMOPS(AArch64::MOPSMemorySetPseudo, DAG, dl, Chain, Dst, Src,
245                     Size, Alignment, isVolatile, DstPtrInfo,
246                     MachinePointerInfo{});
247 
248   auto *AFI = DAG.getMachineFunction().getInfo<AArch64FunctionInfo>();
249   SMEAttrs Attrs = AFI->getSMEFnAttrs();
250   if (LowerToSMERoutines && !Attrs.hasNonStreamingInterfaceAndBody())
251     return EmitStreamingCompatibleMemLibCall(DAG, dl, Chain, Dst, Src, Size,
252                                              RTLIB::MEMSET);
253   return SDValue();
254 }
255 
EmitTargetCodeForMemmove(SelectionDAG & DAG,const SDLoc & dl,SDValue Chain,SDValue Dst,SDValue Src,SDValue Size,Align Alignment,bool isVolatile,MachinePointerInfo DstPtrInfo,MachinePointerInfo SrcPtrInfo) const256 SDValue AArch64SelectionDAGInfo::EmitTargetCodeForMemmove(
257     SelectionDAG &DAG, const SDLoc &dl, SDValue Chain, SDValue Dst, SDValue Src,
258     SDValue Size, Align Alignment, bool isVolatile,
259     MachinePointerInfo DstPtrInfo, MachinePointerInfo SrcPtrInfo) const {
260   const AArch64Subtarget &STI =
261       DAG.getMachineFunction().getSubtarget<AArch64Subtarget>();
262 
263   if (STI.hasMOPS())
264     return EmitMOPS(AArch64::MOPSMemoryMovePseudo, DAG, dl, Chain, Dst, Src,
265                     Size, Alignment, isVolatile, DstPtrInfo, SrcPtrInfo);
266 
267   auto *AFI = DAG.getMachineFunction().getInfo<AArch64FunctionInfo>();
268   SMEAttrs Attrs = AFI->getSMEFnAttrs();
269   if (LowerToSMERoutines && !Attrs.hasNonStreamingInterfaceAndBody())
270     return EmitStreamingCompatibleMemLibCall(DAG, dl, Chain, Dst, Src, Size,
271                                              RTLIB::MEMMOVE);
272   return SDValue();
273 }
274 
// Object size in bytes at or above which EmitTargetCodeForSetTag emits a
// looping pseudo instead of an unrolled store sequence. That function only
// takes the loop path when this value is non-negative.
static constexpr int kSetTagLoopThreshold = 176;
276 
EmitUnrolledSetTag(SelectionDAG & DAG,const SDLoc & dl,SDValue Chain,SDValue Ptr,uint64_t ObjSize,const MachineMemOperand * BaseMemOperand,bool ZeroData)277 static SDValue EmitUnrolledSetTag(SelectionDAG &DAG, const SDLoc &dl,
278                                   SDValue Chain, SDValue Ptr, uint64_t ObjSize,
279                                   const MachineMemOperand *BaseMemOperand,
280                                   bool ZeroData) {
281   MachineFunction &MF = DAG.getMachineFunction();
282   unsigned ObjSizeScaled = ObjSize / 16;
283 
284   SDValue TagSrc = Ptr;
285   if (Ptr.getOpcode() == ISD::FrameIndex) {
286     int FI = cast<FrameIndexSDNode>(Ptr)->getIndex();
287     Ptr = DAG.getTargetFrameIndex(FI, MVT::i64);
288     // A frame index operand may end up as [SP + offset] => it is fine to use SP
289     // register as the tag source.
290     TagSrc = DAG.getRegister(AArch64::SP, MVT::i64);
291   }
292 
293   const unsigned OpCode1 = ZeroData ? AArch64ISD::STZG : AArch64ISD::STG;
294   const unsigned OpCode2 = ZeroData ? AArch64ISD::STZ2G : AArch64ISD::ST2G;
295 
296   SmallVector<SDValue, 8> OutChains;
297   unsigned OffsetScaled = 0;
298   while (OffsetScaled < ObjSizeScaled) {
299     if (ObjSizeScaled - OffsetScaled >= 2) {
300       SDValue AddrNode = DAG.getMemBasePlusOffset(
301           Ptr, TypeSize::getFixed(OffsetScaled * 16), dl);
302       SDValue St = DAG.getMemIntrinsicNode(
303           OpCode2, dl, DAG.getVTList(MVT::Other),
304           {Chain, TagSrc, AddrNode},
305           MVT::v4i64,
306           MF.getMachineMemOperand(BaseMemOperand, OffsetScaled * 16, 16 * 2));
307       OffsetScaled += 2;
308       OutChains.push_back(St);
309       continue;
310     }
311 
312     if (ObjSizeScaled - OffsetScaled > 0) {
313       SDValue AddrNode = DAG.getMemBasePlusOffset(
314           Ptr, TypeSize::getFixed(OffsetScaled * 16), dl);
315       SDValue St = DAG.getMemIntrinsicNode(
316           OpCode1, dl, DAG.getVTList(MVT::Other),
317           {Chain, TagSrc, AddrNode},
318           MVT::v2i64,
319           MF.getMachineMemOperand(BaseMemOperand, OffsetScaled * 16, 16));
320       OffsetScaled += 1;
321       OutChains.push_back(St);
322     }
323   }
324 
325   SDValue Res = DAG.getNode(ISD::TokenFactor, dl, MVT::Other, OutChains);
326   return Res;
327 }
328 
EmitTargetCodeForSetTag(SelectionDAG & DAG,const SDLoc & dl,SDValue Chain,SDValue Addr,SDValue Size,MachinePointerInfo DstPtrInfo,bool ZeroData) const329 SDValue AArch64SelectionDAGInfo::EmitTargetCodeForSetTag(
330     SelectionDAG &DAG, const SDLoc &dl, SDValue Chain, SDValue Addr,
331     SDValue Size, MachinePointerInfo DstPtrInfo, bool ZeroData) const {
332   uint64_t ObjSize = Size->getAsZExtVal();
333   assert(ObjSize % 16 == 0);
334 
335   MachineFunction &MF = DAG.getMachineFunction();
336   MachineMemOperand *BaseMemOperand = MF.getMachineMemOperand(
337       DstPtrInfo, MachineMemOperand::MOStore, ObjSize, Align(16));
338 
339   bool UseSetTagRangeLoop =
340       kSetTagLoopThreshold >= 0 && (int)ObjSize >= kSetTagLoopThreshold;
341   if (!UseSetTagRangeLoop)
342     return EmitUnrolledSetTag(DAG, dl, Chain, Addr, ObjSize, BaseMemOperand,
343                               ZeroData);
344 
345   const EVT ResTys[] = {MVT::i64, MVT::i64, MVT::Other};
346 
347   unsigned Opcode;
348   if (Addr.getOpcode() == ISD::FrameIndex) {
349     int FI = cast<FrameIndexSDNode>(Addr)->getIndex();
350     Addr = DAG.getTargetFrameIndex(FI, MVT::i64);
351     Opcode = ZeroData ? AArch64::STZGloop : AArch64::STGloop;
352   } else {
353     Opcode = ZeroData ? AArch64::STZGloop_wback : AArch64::STGloop_wback;
354   }
355   SDValue Ops[] = {DAG.getTargetConstant(ObjSize, dl, MVT::i64), Addr, Chain};
356   SDNode *St = DAG.getMachineNode(Opcode, dl, ResTys, Ops);
357 
358   DAG.setNodeMemRefs(cast<MachineSDNode>(St), {BaseMemOperand});
359   return SDValue(St, 2);
360 }
361