xref: /freebsd/contrib/llvm-project/llvm/lib/Target/AArch64/AArch64SelectionDAGInfo.cpp (revision b64c5a0ace59af62eff52bfe110a521dc73c937b)
1 //===-- AArch64SelectionDAGInfo.cpp - AArch64 SelectionDAG Info -----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the AArch64SelectionDAGInfo class.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "AArch64TargetMachine.h"
14 using namespace llvm;
15 
16 #define DEBUG_TYPE "aarch64-selectiondag-info"
17 
18 static cl::opt<bool>
19     LowerToSMERoutines("aarch64-lower-to-sme-routines", cl::Hidden,
20                        cl::desc("Enable AArch64 SME memory operations "
21                                 "to lower to librt functions"),
22                        cl::init(true));
23 
24 SDValue AArch64SelectionDAGInfo::EmitMOPS(AArch64ISD::NodeType SDOpcode,
25                                           SelectionDAG &DAG, const SDLoc &DL,
26                                           SDValue Chain, SDValue Dst,
27                                           SDValue SrcOrValue, SDValue Size,
28                                           Align Alignment, bool isVolatile,
29                                           MachinePointerInfo DstPtrInfo,
30                                           MachinePointerInfo SrcPtrInfo) const {
31 
32   // Get the constant size of the copy/set.
33   uint64_t ConstSize = 0;
34   if (auto *C = dyn_cast<ConstantSDNode>(Size))
35     ConstSize = C->getZExtValue();
36 
37   const bool IsSet = SDOpcode == AArch64ISD::MOPS_MEMSET ||
38                      SDOpcode == AArch64ISD::MOPS_MEMSET_TAGGING;
39 
40   const auto MachineOpcode = [&]() {
41     switch (SDOpcode) {
42     case AArch64ISD::MOPS_MEMSET:
43       return AArch64::MOPSMemorySetPseudo;
44     case AArch64ISD::MOPS_MEMSET_TAGGING:
45       return AArch64::MOPSMemorySetTaggingPseudo;
46     case AArch64ISD::MOPS_MEMCOPY:
47       return AArch64::MOPSMemoryCopyPseudo;
48     case AArch64ISD::MOPS_MEMMOVE:
49       return AArch64::MOPSMemoryMovePseudo;
50     default:
51       llvm_unreachable("Unhandled MOPS ISD Opcode");
52     }
53   }();
54 
55   MachineFunction &MF = DAG.getMachineFunction();
56 
57   auto Vol =
58       isVolatile ? MachineMemOperand::MOVolatile : MachineMemOperand::MONone;
59   auto DstFlags = MachineMemOperand::MOStore | Vol;
60   auto *DstOp =
61       MF.getMachineMemOperand(DstPtrInfo, DstFlags, ConstSize, Alignment);
62 
63   if (IsSet) {
64     // Extend value to i64, if required.
65     if (SrcOrValue.getValueType() != MVT::i64)
66       SrcOrValue = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i64, SrcOrValue);
67     SDValue Ops[] = {Dst, Size, SrcOrValue, Chain};
68     const EVT ResultTys[] = {MVT::i64, MVT::i64, MVT::Other};
69     MachineSDNode *Node = DAG.getMachineNode(MachineOpcode, DL, ResultTys, Ops);
70     DAG.setNodeMemRefs(Node, {DstOp});
71     return SDValue(Node, 2);
72   } else {
73     SDValue Ops[] = {Dst, SrcOrValue, Size, Chain};
74     const EVT ResultTys[] = {MVT::i64, MVT::i64, MVT::i64, MVT::Other};
75     MachineSDNode *Node = DAG.getMachineNode(MachineOpcode, DL, ResultTys, Ops);
76 
77     auto SrcFlags = MachineMemOperand::MOLoad | Vol;
78     auto *SrcOp =
79         MF.getMachineMemOperand(SrcPtrInfo, SrcFlags, ConstSize, Alignment);
80     DAG.setNodeMemRefs(Node, {DstOp, SrcOp});
81     return SDValue(Node, 3);
82   }
83 }
84 
85 SDValue AArch64SelectionDAGInfo::EmitStreamingCompatibleMemLibCall(
86     SelectionDAG &DAG, const SDLoc &DL, SDValue Chain, SDValue Dst, SDValue Src,
87     SDValue Size, RTLIB::Libcall LC) const {
88   const AArch64Subtarget &STI =
89       DAG.getMachineFunction().getSubtarget<AArch64Subtarget>();
90   const AArch64TargetLowering *TLI = STI.getTargetLowering();
91   SDValue Symbol;
92   TargetLowering::ArgListEntry DstEntry;
93   DstEntry.Ty = PointerType::getUnqual(*DAG.getContext());
94   DstEntry.Node = Dst;
95   TargetLowering::ArgListTy Args;
96   Args.push_back(DstEntry);
97   EVT PointerVT = TLI->getPointerTy(DAG.getDataLayout());
98 
99   switch (LC) {
100   case RTLIB::MEMCPY: {
101     TargetLowering::ArgListEntry Entry;
102     Entry.Ty = PointerType::getUnqual(*DAG.getContext());
103     Symbol = DAG.getExternalSymbol("__arm_sc_memcpy", PointerVT);
104     Entry.Node = Src;
105     Args.push_back(Entry);
106     break;
107   }
108   case RTLIB::MEMMOVE: {
109     TargetLowering::ArgListEntry Entry;
110     Entry.Ty = PointerType::getUnqual(*DAG.getContext());
111     Symbol = DAG.getExternalSymbol("__arm_sc_memmove", PointerVT);
112     Entry.Node = Src;
113     Args.push_back(Entry);
114     break;
115   }
116   case RTLIB::MEMSET: {
117     TargetLowering::ArgListEntry Entry;
118     Entry.Ty = Type::getInt32Ty(*DAG.getContext());
119     Symbol = DAG.getExternalSymbol("__arm_sc_memset", PointerVT);
120     Src = DAG.getZExtOrTrunc(Src, DL, MVT::i32);
121     Entry.Node = Src;
122     Args.push_back(Entry);
123     break;
124   }
125   default:
126     return SDValue();
127   }
128 
129   TargetLowering::ArgListEntry SizeEntry;
130   SizeEntry.Node = Size;
131   SizeEntry.Ty = DAG.getDataLayout().getIntPtrType(*DAG.getContext());
132   Args.push_back(SizeEntry);
133   assert(Symbol->getOpcode() == ISD::ExternalSymbol &&
134          "Function name is not set");
135 
136   TargetLowering::CallLoweringInfo CLI(DAG);
137   PointerType *RetTy = PointerType::getUnqual(*DAG.getContext());
138   CLI.setDebugLoc(DL).setChain(Chain).setLibCallee(
139       TLI->getLibcallCallingConv(LC), RetTy, Symbol, std::move(Args));
140   return TLI->LowerCallTo(CLI).second;
141 }
142 
143 SDValue AArch64SelectionDAGInfo::EmitTargetCodeForMemcpy(
144     SelectionDAG &DAG, const SDLoc &DL, SDValue Chain, SDValue Dst, SDValue Src,
145     SDValue Size, Align Alignment, bool isVolatile, bool AlwaysInline,
146     MachinePointerInfo DstPtrInfo, MachinePointerInfo SrcPtrInfo) const {
147   const AArch64Subtarget &STI =
148       DAG.getMachineFunction().getSubtarget<AArch64Subtarget>();
149 
150   if (STI.hasMOPS())
151     return EmitMOPS(AArch64ISD::MOPS_MEMCOPY, DAG, DL, Chain, Dst, Src, Size,
152                     Alignment, isVolatile, DstPtrInfo, SrcPtrInfo);
153 
154   SMEAttrs Attrs(DAG.getMachineFunction().getFunction());
155   if (LowerToSMERoutines && !Attrs.hasNonStreamingInterfaceAndBody())
156     return EmitStreamingCompatibleMemLibCall(DAG, DL, Chain, Dst, Src, Size,
157                                              RTLIB::MEMCPY);
158   return SDValue();
159 }
160 
161 SDValue AArch64SelectionDAGInfo::EmitTargetCodeForMemset(
162     SelectionDAG &DAG, const SDLoc &dl, SDValue Chain, SDValue Dst, SDValue Src,
163     SDValue Size, Align Alignment, bool isVolatile, bool AlwaysInline,
164     MachinePointerInfo DstPtrInfo) const {
165   const AArch64Subtarget &STI =
166       DAG.getMachineFunction().getSubtarget<AArch64Subtarget>();
167 
168   if (STI.hasMOPS())
169     return EmitMOPS(AArch64ISD::MOPS_MEMSET, DAG, dl, Chain, Dst, Src, Size,
170                     Alignment, isVolatile, DstPtrInfo, MachinePointerInfo{});
171 
172   SMEAttrs Attrs(DAG.getMachineFunction().getFunction());
173   if (LowerToSMERoutines && !Attrs.hasNonStreamingInterfaceAndBody())
174     return EmitStreamingCompatibleMemLibCall(DAG, dl, Chain, Dst, Src, Size,
175                                              RTLIB::MEMSET);
176   return SDValue();
177 }
178 
179 SDValue AArch64SelectionDAGInfo::EmitTargetCodeForMemmove(
180     SelectionDAG &DAG, const SDLoc &dl, SDValue Chain, SDValue Dst, SDValue Src,
181     SDValue Size, Align Alignment, bool isVolatile,
182     MachinePointerInfo DstPtrInfo, MachinePointerInfo SrcPtrInfo) const {
183   const AArch64Subtarget &STI =
184       DAG.getMachineFunction().getSubtarget<AArch64Subtarget>();
185 
186   if (STI.hasMOPS())
187     return EmitMOPS(AArch64ISD::MOPS_MEMMOVE, DAG, dl, Chain, Dst, Src, Size,
188                     Alignment, isVolatile, DstPtrInfo, SrcPtrInfo);
189 
190   SMEAttrs Attrs(DAG.getMachineFunction().getFunction());
191   if (LowerToSMERoutines && !Attrs.hasNonStreamingInterfaceAndBody())
192     return EmitStreamingCompatibleMemLibCall(DAG, dl, Chain, Dst, Src, Size,
193                                              RTLIB::MEMMOVE);
194   return SDValue();
195 }
196 
// Object size at or above which EmitTargetCodeForSetTag emits one of the
// STGloop/STZGloop pseudos instead of an unrolled sequence of tag stores. A
// negative value would disable the loop form entirely.
static constexpr int kSetTagLoopThreshold = 176;
198 
199 static SDValue EmitUnrolledSetTag(SelectionDAG &DAG, const SDLoc &dl,
200                                   SDValue Chain, SDValue Ptr, uint64_t ObjSize,
201                                   const MachineMemOperand *BaseMemOperand,
202                                   bool ZeroData) {
203   MachineFunction &MF = DAG.getMachineFunction();
204   unsigned ObjSizeScaled = ObjSize / 16;
205 
206   SDValue TagSrc = Ptr;
207   if (Ptr.getOpcode() == ISD::FrameIndex) {
208     int FI = cast<FrameIndexSDNode>(Ptr)->getIndex();
209     Ptr = DAG.getTargetFrameIndex(FI, MVT::i64);
210     // A frame index operand may end up as [SP + offset] => it is fine to use SP
211     // register as the tag source.
212     TagSrc = DAG.getRegister(AArch64::SP, MVT::i64);
213   }
214 
215   const unsigned OpCode1 = ZeroData ? AArch64ISD::STZG : AArch64ISD::STG;
216   const unsigned OpCode2 = ZeroData ? AArch64ISD::STZ2G : AArch64ISD::ST2G;
217 
218   SmallVector<SDValue, 8> OutChains;
219   unsigned OffsetScaled = 0;
220   while (OffsetScaled < ObjSizeScaled) {
221     if (ObjSizeScaled - OffsetScaled >= 2) {
222       SDValue AddrNode = DAG.getMemBasePlusOffset(
223           Ptr, TypeSize::getFixed(OffsetScaled * 16), dl);
224       SDValue St = DAG.getMemIntrinsicNode(
225           OpCode2, dl, DAG.getVTList(MVT::Other),
226           {Chain, TagSrc, AddrNode},
227           MVT::v4i64,
228           MF.getMachineMemOperand(BaseMemOperand, OffsetScaled * 16, 16 * 2));
229       OffsetScaled += 2;
230       OutChains.push_back(St);
231       continue;
232     }
233 
234     if (ObjSizeScaled - OffsetScaled > 0) {
235       SDValue AddrNode = DAG.getMemBasePlusOffset(
236           Ptr, TypeSize::getFixed(OffsetScaled * 16), dl);
237       SDValue St = DAG.getMemIntrinsicNode(
238           OpCode1, dl, DAG.getVTList(MVT::Other),
239           {Chain, TagSrc, AddrNode},
240           MVT::v2i64,
241           MF.getMachineMemOperand(BaseMemOperand, OffsetScaled * 16, 16));
242       OffsetScaled += 1;
243       OutChains.push_back(St);
244     }
245   }
246 
247   SDValue Res = DAG.getNode(ISD::TokenFactor, dl, MVT::Other, OutChains);
248   return Res;
249 }
250 
251 SDValue AArch64SelectionDAGInfo::EmitTargetCodeForSetTag(
252     SelectionDAG &DAG, const SDLoc &dl, SDValue Chain, SDValue Addr,
253     SDValue Size, MachinePointerInfo DstPtrInfo, bool ZeroData) const {
254   uint64_t ObjSize = Size->getAsZExtVal();
255   assert(ObjSize % 16 == 0);
256 
257   MachineFunction &MF = DAG.getMachineFunction();
258   MachineMemOperand *BaseMemOperand = MF.getMachineMemOperand(
259       DstPtrInfo, MachineMemOperand::MOStore, ObjSize, Align(16));
260 
261   bool UseSetTagRangeLoop =
262       kSetTagLoopThreshold >= 0 && (int)ObjSize >= kSetTagLoopThreshold;
263   if (!UseSetTagRangeLoop)
264     return EmitUnrolledSetTag(DAG, dl, Chain, Addr, ObjSize, BaseMemOperand,
265                               ZeroData);
266 
267   const EVT ResTys[] = {MVT::i64, MVT::i64, MVT::Other};
268 
269   unsigned Opcode;
270   if (Addr.getOpcode() == ISD::FrameIndex) {
271     int FI = cast<FrameIndexSDNode>(Addr)->getIndex();
272     Addr = DAG.getTargetFrameIndex(FI, MVT::i64);
273     Opcode = ZeroData ? AArch64::STZGloop : AArch64::STGloop;
274   } else {
275     Opcode = ZeroData ? AArch64::STZGloop_wback : AArch64::STGloop_wback;
276   }
277   SDValue Ops[] = {DAG.getTargetConstant(ObjSize, dl, MVT::i64), Addr, Chain};
278   SDNode *St = DAG.getMachineNode(Opcode, dl, ResTys, Ops);
279 
280   DAG.setNodeMemRefs(cast<MachineSDNode>(St), {BaseMemOperand});
281   return SDValue(St, 2);
282 }
283