1 //===-- AArch64SelectionDAGInfo.cpp - AArch64 SelectionDAG Info -----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the AArch64SelectionDAGInfo class.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "AArch64SelectionDAGInfo.h"
14 #include "AArch64MachineFunctionInfo.h"
15
16 #define GET_SDNODE_DESC
17 #include "AArch64GenSDNodeInfo.inc"
18 #undef GET_SDNODE_DESC
19
20 using namespace llvm;
21
22 #define DEBUG_TYPE "aarch64-selectiondag-info"
23
24 static cl::opt<bool>
25 LowerToSMERoutines("aarch64-lower-to-sme-routines", cl::Hidden,
26 cl::desc("Enable AArch64 SME memory operations "
27 "to lower to librt functions"),
28 cl::init(true));
29
AArch64SelectionDAGInfo()30 AArch64SelectionDAGInfo::AArch64SelectionDAGInfo()
31 : SelectionDAGGenTargetInfo(AArch64GenSDNodeInfo) {}
32
/// Debug-build verification of AArch64-specific SelectionDAG nodes.
///
/// For the opcodes listed explicitly below, asserts the number of results and
/// operands and the relationships between their value types. Every other
/// opcode is forwarded to SelectionDAGGenTargetInfo::verifyTargetNode.
///
/// Note: the opcodes handled here are NOT additionally passed to the generic
/// verifier. When NDEBUG is defined the whole body is compiled out, so no
/// checking (and no forwarding) happens in release builds.
void AArch64SelectionDAGInfo::verifyTargetNode(const SelectionDAG &DAG,
                                               const SDNode *N) const {
#ifndef NDEBUG
  switch (N->getOpcode()) {
  default:
    // No custom check for this opcode: defer to the generic verifier.
    return SelectionDAGGenTargetInfo::verifyTargetNode(DAG, N);
  case AArch64ISD::SADDWT:
  case AArch64ISD::SADDWB:
  case AArch64ISD::UADDWT:
  case AArch64ISD::UADDWB: {
    // One result, two integer-vector operands. The result has exactly the
    // type of operand 0; operand 1 has the same total bit width but twice as
    // many lanes (i.e. elements half as wide).
    assert(N->getNumValues() == 1 && "Expected one result!");
    assert(N->getNumOperands() == 2 && "Expected two operands!");
    EVT VT = N->getValueType(0);
    EVT Op0VT = N->getOperand(0).getValueType();
    EVT Op1VT = N->getOperand(1).getValueType();
    assert(VT.isVector() && Op0VT.isVector() && Op1VT.isVector() &&
           VT.isInteger() && Op0VT.isInteger() && Op1VT.isInteger() &&
           "Expected integer vectors!");
    assert(VT == Op0VT &&
           "Expected result and first input to have the same type!");
    assert(Op0VT.getSizeInBits() == Op1VT.getSizeInBits() &&
           "Expected vectors of equal size!");
    assert(Op0VT.getVectorElementCount() * 2 == Op1VT.getVectorElementCount() &&
           "Expected result vector and first input vector to have half the "
           "lanes of the second input vector!");
    break;
  }
  case AArch64ISD::SUNPKLO:
  case AArch64ISD::SUNPKHI:
  case AArch64ISD::UUNPKLO:
  case AArch64ISD::UUNPKHI: {
    // One integer-vector operand and one integer-vector result of the same
    // total bit width, where the result has half as many lanes (i.e. elements
    // twice as wide).
    assert(N->getNumValues() == 1 && "Expected one result!");
    assert(N->getNumOperands() == 1 && "Expected one operand!");
    EVT VT = N->getValueType(0);
    EVT OpVT = N->getOperand(0).getValueType();
    assert(OpVT.isVector() && VT.isVector() && OpVT.isInteger() &&
           VT.isInteger() && "Expected integer vectors!");
    assert(OpVT.getSizeInBits() == VT.getSizeInBits() &&
           "Expected vectors of equal size!");
    assert(OpVT.getVectorElementCount() == VT.getVectorElementCount() * 2 &&
           "Expected result vector with half the lanes of its input!");
    break;
  }
  case AArch64ISD::TRN1:
  case AArch64ISD::TRN2:
  case AArch64ISD::UZP1:
  case AArch64ISD::UZP2:
  case AArch64ISD::ZIP1:
  case AArch64ISD::ZIP2: {
    // Two vector operands; the result and both operands must all have
    // exactly the same vector type (integer or not).
    assert(N->getNumValues() == 1 && "Expected one result!");
    assert(N->getNumOperands() == 2 && "Expected two operands!");
    EVT VT = N->getValueType(0);
    EVT Op0VT = N->getOperand(0).getValueType();
    EVT Op1VT = N->getOperand(1).getValueType();
    assert(VT.isVector() && Op0VT.isVector() && Op1VT.isVector() &&
           "Expected vectors!");
    assert(VT == Op0VT && VT == Op1VT && "Expected matching vectors!");
    break;
  }
  case AArch64ISD::RSHRNB_I: {
    // Integer-vector result with the same total bit width as operand 0 but
    // twice as many lanes; operand 1 must be a ConstantSDNode of type i32.
    assert(N->getNumValues() == 1 && "Expected one result!");
    assert(N->getNumOperands() == 2 && "Expected two operands!");
    EVT VT = N->getValueType(0);
    EVT Op0VT = N->getOperand(0).getValueType();
    EVT Op1VT = N->getOperand(1).getValueType();
    assert(VT.isVector() && VT.isInteger() &&
           "Expected integer vector result type!");
    assert(Op0VT.isVector() && Op0VT.isInteger() &&
           "Expected first operand to be an integer vector!");
    assert(VT.getSizeInBits() == Op0VT.getSizeInBits() &&
           "Expected vectors of equal size!");
    assert(VT.getVectorElementCount() == Op0VT.getVectorElementCount() * 2 &&
           "Expected input vector with half the lanes of its result!");
    assert(Op1VT == MVT::i32 && isa<ConstantSDNode>(N->getOperand(1)) &&
           "Expected second operand to be a constant i32!");
    break;
  }
  }
#endif
}
113
EmitMOPS(unsigned Opcode,SelectionDAG & DAG,const SDLoc & DL,SDValue Chain,SDValue Dst,SDValue SrcOrValue,SDValue Size,Align Alignment,bool isVolatile,MachinePointerInfo DstPtrInfo,MachinePointerInfo SrcPtrInfo) const114 SDValue AArch64SelectionDAGInfo::EmitMOPS(unsigned Opcode, SelectionDAG &DAG,
115 const SDLoc &DL, SDValue Chain,
116 SDValue Dst, SDValue SrcOrValue,
117 SDValue Size, Align Alignment,
118 bool isVolatile,
119 MachinePointerInfo DstPtrInfo,
120 MachinePointerInfo SrcPtrInfo) const {
121
122 // Get the constant size of the copy/set.
123 uint64_t ConstSize = 0;
124 if (auto *C = dyn_cast<ConstantSDNode>(Size))
125 ConstSize = C->getZExtValue();
126
127 const bool IsSet = Opcode == AArch64::MOPSMemorySetPseudo ||
128 Opcode == AArch64::MOPSMemorySetTaggingPseudo;
129
130 MachineFunction &MF = DAG.getMachineFunction();
131
132 auto Vol =
133 isVolatile ? MachineMemOperand::MOVolatile : MachineMemOperand::MONone;
134 auto DstFlags = MachineMemOperand::MOStore | Vol;
135 auto *DstOp =
136 MF.getMachineMemOperand(DstPtrInfo, DstFlags, ConstSize, Alignment);
137
138 if (IsSet) {
139 // Extend value to i64, if required.
140 if (SrcOrValue.getValueType() != MVT::i64)
141 SrcOrValue = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i64, SrcOrValue);
142 SDValue Ops[] = {Dst, Size, SrcOrValue, Chain};
143 const EVT ResultTys[] = {MVT::i64, MVT::i64, MVT::Other};
144 MachineSDNode *Node = DAG.getMachineNode(Opcode, DL, ResultTys, Ops);
145 DAG.setNodeMemRefs(Node, {DstOp});
146 return SDValue(Node, 2);
147 } else {
148 SDValue Ops[] = {Dst, SrcOrValue, Size, Chain};
149 const EVT ResultTys[] = {MVT::i64, MVT::i64, MVT::i64, MVT::Other};
150 MachineSDNode *Node = DAG.getMachineNode(Opcode, DL, ResultTys, Ops);
151
152 auto SrcFlags = MachineMemOperand::MOLoad | Vol;
153 auto *SrcOp =
154 MF.getMachineMemOperand(SrcPtrInfo, SrcFlags, ConstSize, Alignment);
155 DAG.setNodeMemRefs(Node, {DstOp, SrcOp});
156 return SDValue(Node, 3);
157 }
158 }
159
EmitStreamingCompatibleMemLibCall(SelectionDAG & DAG,const SDLoc & DL,SDValue Chain,SDValue Dst,SDValue Src,SDValue Size,RTLIB::Libcall LC) const160 SDValue AArch64SelectionDAGInfo::EmitStreamingCompatibleMemLibCall(
161 SelectionDAG &DAG, const SDLoc &DL, SDValue Chain, SDValue Dst, SDValue Src,
162 SDValue Size, RTLIB::Libcall LC) const {
163 const AArch64Subtarget &STI =
164 DAG.getMachineFunction().getSubtarget<AArch64Subtarget>();
165 const AArch64TargetLowering *TLI = STI.getTargetLowering();
166 TargetLowering::ArgListEntry DstEntry;
167 DstEntry.Ty = PointerType::getUnqual(*DAG.getContext());
168 DstEntry.Node = Dst;
169 TargetLowering::ArgListTy Args;
170 Args.push_back(DstEntry);
171
172 RTLIB::Libcall NewLC;
173 switch (LC) {
174 case RTLIB::MEMCPY: {
175 NewLC = RTLIB::SC_MEMCPY;
176 TargetLowering::ArgListEntry Entry;
177 Entry.Ty = PointerType::getUnqual(*DAG.getContext());
178 Entry.Node = Src;
179 Args.push_back(Entry);
180 break;
181 }
182 case RTLIB::MEMMOVE: {
183 NewLC = RTLIB::SC_MEMMOVE;
184 TargetLowering::ArgListEntry Entry;
185 Entry.Ty = PointerType::getUnqual(*DAG.getContext());
186 Entry.Node = Src;
187 Args.push_back(Entry);
188 break;
189 }
190 case RTLIB::MEMSET: {
191 NewLC = RTLIB::SC_MEMSET;
192 TargetLowering::ArgListEntry Entry;
193 Entry.Ty = Type::getInt32Ty(*DAG.getContext());
194 Src = DAG.getZExtOrTrunc(Src, DL, MVT::i32);
195 Entry.Node = Src;
196 Args.push_back(Entry);
197 break;
198 }
199 default:
200 return SDValue();
201 }
202
203 EVT PointerVT = TLI->getPointerTy(DAG.getDataLayout());
204 SDValue Symbol = DAG.getExternalSymbol(TLI->getLibcallName(NewLC), PointerVT);
205 TargetLowering::ArgListEntry SizeEntry;
206 SizeEntry.Node = Size;
207 SizeEntry.Ty = DAG.getDataLayout().getIntPtrType(*DAG.getContext());
208 Args.push_back(SizeEntry);
209
210 TargetLowering::CallLoweringInfo CLI(DAG);
211 PointerType *RetTy = PointerType::getUnqual(*DAG.getContext());
212 CLI.setDebugLoc(DL).setChain(Chain).setLibCallee(
213 TLI->getLibcallCallingConv(NewLC), RetTy, Symbol, std::move(Args));
214 return TLI->LowerCallTo(CLI).second;
215 }
216
EmitTargetCodeForMemcpy(SelectionDAG & DAG,const SDLoc & DL,SDValue Chain,SDValue Dst,SDValue Src,SDValue Size,Align Alignment,bool isVolatile,bool AlwaysInline,MachinePointerInfo DstPtrInfo,MachinePointerInfo SrcPtrInfo) const217 SDValue AArch64SelectionDAGInfo::EmitTargetCodeForMemcpy(
218 SelectionDAG &DAG, const SDLoc &DL, SDValue Chain, SDValue Dst, SDValue Src,
219 SDValue Size, Align Alignment, bool isVolatile, bool AlwaysInline,
220 MachinePointerInfo DstPtrInfo, MachinePointerInfo SrcPtrInfo) const {
221 const AArch64Subtarget &STI =
222 DAG.getMachineFunction().getSubtarget<AArch64Subtarget>();
223
224 if (STI.hasMOPS())
225 return EmitMOPS(AArch64::MOPSMemoryCopyPseudo, DAG, DL, Chain, Dst, Src,
226 Size, Alignment, isVolatile, DstPtrInfo, SrcPtrInfo);
227
228 auto *AFI = DAG.getMachineFunction().getInfo<AArch64FunctionInfo>();
229 SMEAttrs Attrs = AFI->getSMEFnAttrs();
230 if (LowerToSMERoutines && !Attrs.hasNonStreamingInterfaceAndBody())
231 return EmitStreamingCompatibleMemLibCall(DAG, DL, Chain, Dst, Src, Size,
232 RTLIB::MEMCPY);
233 return SDValue();
234 }
235
EmitTargetCodeForMemset(SelectionDAG & DAG,const SDLoc & dl,SDValue Chain,SDValue Dst,SDValue Src,SDValue Size,Align Alignment,bool isVolatile,bool AlwaysInline,MachinePointerInfo DstPtrInfo) const236 SDValue AArch64SelectionDAGInfo::EmitTargetCodeForMemset(
237 SelectionDAG &DAG, const SDLoc &dl, SDValue Chain, SDValue Dst, SDValue Src,
238 SDValue Size, Align Alignment, bool isVolatile, bool AlwaysInline,
239 MachinePointerInfo DstPtrInfo) const {
240 const AArch64Subtarget &STI =
241 DAG.getMachineFunction().getSubtarget<AArch64Subtarget>();
242
243 if (STI.hasMOPS())
244 return EmitMOPS(AArch64::MOPSMemorySetPseudo, DAG, dl, Chain, Dst, Src,
245 Size, Alignment, isVolatile, DstPtrInfo,
246 MachinePointerInfo{});
247
248 auto *AFI = DAG.getMachineFunction().getInfo<AArch64FunctionInfo>();
249 SMEAttrs Attrs = AFI->getSMEFnAttrs();
250 if (LowerToSMERoutines && !Attrs.hasNonStreamingInterfaceAndBody())
251 return EmitStreamingCompatibleMemLibCall(DAG, dl, Chain, Dst, Src, Size,
252 RTLIB::MEMSET);
253 return SDValue();
254 }
255
EmitTargetCodeForMemmove(SelectionDAG & DAG,const SDLoc & dl,SDValue Chain,SDValue Dst,SDValue Src,SDValue Size,Align Alignment,bool isVolatile,MachinePointerInfo DstPtrInfo,MachinePointerInfo SrcPtrInfo) const256 SDValue AArch64SelectionDAGInfo::EmitTargetCodeForMemmove(
257 SelectionDAG &DAG, const SDLoc &dl, SDValue Chain, SDValue Dst, SDValue Src,
258 SDValue Size, Align Alignment, bool isVolatile,
259 MachinePointerInfo DstPtrInfo, MachinePointerInfo SrcPtrInfo) const {
260 const AArch64Subtarget &STI =
261 DAG.getMachineFunction().getSubtarget<AArch64Subtarget>();
262
263 if (STI.hasMOPS())
264 return EmitMOPS(AArch64::MOPSMemoryMovePseudo, DAG, dl, Chain, Dst, Src,
265 Size, Alignment, isVolatile, DstPtrInfo, SrcPtrInfo);
266
267 auto *AFI = DAG.getMachineFunction().getInfo<AArch64FunctionInfo>();
268 SMEAttrs Attrs = AFI->getSMEFnAttrs();
269 if (LowerToSMERoutines && !Attrs.hasNonStreamingInterfaceAndBody())
270 return EmitStreamingCompatibleMemLibCall(DAG, dl, Chain, Dst, Src, Size,
271 RTLIB::MEMMOVE);
272 return SDValue();
273 }
274
// Tag-store sizes (in bytes) at or above this threshold are emitted as a single
// STGloop/STZGloop pseudo instead of an unrolled sequence of tag stores.
// A negative value disables the loop form.
static constexpr int kSetTagLoopThreshold = 176;
276
EmitUnrolledSetTag(SelectionDAG & DAG,const SDLoc & dl,SDValue Chain,SDValue Ptr,uint64_t ObjSize,const MachineMemOperand * BaseMemOperand,bool ZeroData)277 static SDValue EmitUnrolledSetTag(SelectionDAG &DAG, const SDLoc &dl,
278 SDValue Chain, SDValue Ptr, uint64_t ObjSize,
279 const MachineMemOperand *BaseMemOperand,
280 bool ZeroData) {
281 MachineFunction &MF = DAG.getMachineFunction();
282 unsigned ObjSizeScaled = ObjSize / 16;
283
284 SDValue TagSrc = Ptr;
285 if (Ptr.getOpcode() == ISD::FrameIndex) {
286 int FI = cast<FrameIndexSDNode>(Ptr)->getIndex();
287 Ptr = DAG.getTargetFrameIndex(FI, MVT::i64);
288 // A frame index operand may end up as [SP + offset] => it is fine to use SP
289 // register as the tag source.
290 TagSrc = DAG.getRegister(AArch64::SP, MVT::i64);
291 }
292
293 const unsigned OpCode1 = ZeroData ? AArch64ISD::STZG : AArch64ISD::STG;
294 const unsigned OpCode2 = ZeroData ? AArch64ISD::STZ2G : AArch64ISD::ST2G;
295
296 SmallVector<SDValue, 8> OutChains;
297 unsigned OffsetScaled = 0;
298 while (OffsetScaled < ObjSizeScaled) {
299 if (ObjSizeScaled - OffsetScaled >= 2) {
300 SDValue AddrNode = DAG.getMemBasePlusOffset(
301 Ptr, TypeSize::getFixed(OffsetScaled * 16), dl);
302 SDValue St = DAG.getMemIntrinsicNode(
303 OpCode2, dl, DAG.getVTList(MVT::Other),
304 {Chain, TagSrc, AddrNode},
305 MVT::v4i64,
306 MF.getMachineMemOperand(BaseMemOperand, OffsetScaled * 16, 16 * 2));
307 OffsetScaled += 2;
308 OutChains.push_back(St);
309 continue;
310 }
311
312 if (ObjSizeScaled - OffsetScaled > 0) {
313 SDValue AddrNode = DAG.getMemBasePlusOffset(
314 Ptr, TypeSize::getFixed(OffsetScaled * 16), dl);
315 SDValue St = DAG.getMemIntrinsicNode(
316 OpCode1, dl, DAG.getVTList(MVT::Other),
317 {Chain, TagSrc, AddrNode},
318 MVT::v2i64,
319 MF.getMachineMemOperand(BaseMemOperand, OffsetScaled * 16, 16));
320 OffsetScaled += 1;
321 OutChains.push_back(St);
322 }
323 }
324
325 SDValue Res = DAG.getNode(ISD::TokenFactor, dl, MVT::Other, OutChains);
326 return Res;
327 }
328
EmitTargetCodeForSetTag(SelectionDAG & DAG,const SDLoc & dl,SDValue Chain,SDValue Addr,SDValue Size,MachinePointerInfo DstPtrInfo,bool ZeroData) const329 SDValue AArch64SelectionDAGInfo::EmitTargetCodeForSetTag(
330 SelectionDAG &DAG, const SDLoc &dl, SDValue Chain, SDValue Addr,
331 SDValue Size, MachinePointerInfo DstPtrInfo, bool ZeroData) const {
332 uint64_t ObjSize = Size->getAsZExtVal();
333 assert(ObjSize % 16 == 0);
334
335 MachineFunction &MF = DAG.getMachineFunction();
336 MachineMemOperand *BaseMemOperand = MF.getMachineMemOperand(
337 DstPtrInfo, MachineMemOperand::MOStore, ObjSize, Align(16));
338
339 bool UseSetTagRangeLoop =
340 kSetTagLoopThreshold >= 0 && (int)ObjSize >= kSetTagLoopThreshold;
341 if (!UseSetTagRangeLoop)
342 return EmitUnrolledSetTag(DAG, dl, Chain, Addr, ObjSize, BaseMemOperand,
343 ZeroData);
344
345 const EVT ResTys[] = {MVT::i64, MVT::i64, MVT::Other};
346
347 unsigned Opcode;
348 if (Addr.getOpcode() == ISD::FrameIndex) {
349 int FI = cast<FrameIndexSDNode>(Addr)->getIndex();
350 Addr = DAG.getTargetFrameIndex(FI, MVT::i64);
351 Opcode = ZeroData ? AArch64::STZGloop : AArch64::STGloop;
352 } else {
353 Opcode = ZeroData ? AArch64::STZGloop_wback : AArch64::STGloop_wback;
354 }
355 SDValue Ops[] = {DAG.getTargetConstant(ObjSize, dl, MVT::i64), Addr, Chain};
356 SDNode *St = DAG.getMachineNode(Opcode, dl, ResTys, Ops);
357
358 DAG.setNodeMemRefs(cast<MachineSDNode>(St), {BaseMemOperand});
359 return SDValue(St, 2);
360 }
361