xref: /freebsd/contrib/llvm-project/llvm/lib/Target/VE/VVPISelLowering.cpp (revision 43e29d03f416d7dda52112a29600a7c82ee1a91e)
1 //===-- VVPISelLowering.cpp - VE DAG Lowering Implementation --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the lowering and legalization of vector instructions to
10 // VVP_* layer SDNodes.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "VECustomDAG.h"
15 #include "VEISelLowering.h"
16 
17 using namespace llvm;
18 
19 #define DEBUG_TYPE "ve-lower"
20 
21 SDValue VETargetLowering::splitMaskArithmetic(SDValue Op,
22                                               SelectionDAG &DAG) const {
23   VECustomDAG CDAG(DAG, Op);
24   SDValue AVL =
25       CDAG.getConstant(Op.getValueType().getVectorNumElements(), MVT::i32);
26   SDValue A = Op->getOperand(0);
27   SDValue B = Op->getOperand(1);
28   SDValue LoA = CDAG.getUnpack(MVT::v256i1, A, PackElem::Lo, AVL);
29   SDValue HiA = CDAG.getUnpack(MVT::v256i1, A, PackElem::Hi, AVL);
30   SDValue LoB = CDAG.getUnpack(MVT::v256i1, B, PackElem::Lo, AVL);
31   SDValue HiB = CDAG.getUnpack(MVT::v256i1, B, PackElem::Hi, AVL);
32   unsigned Opc = Op.getOpcode();
33   auto LoRes = CDAG.getNode(Opc, MVT::v256i1, {LoA, LoB});
34   auto HiRes = CDAG.getNode(Opc, MVT::v256i1, {HiA, HiB});
35   return CDAG.getPack(MVT::v512i1, LoRes, HiRes, AVL);
36 }
37 
38 SDValue VETargetLowering::lowerToVVP(SDValue Op, SelectionDAG &DAG) const {
39   // Can we represent this as a VVP node.
40   const unsigned Opcode = Op->getOpcode();
41   auto VVPOpcodeOpt = getVVPOpcode(Opcode);
42   if (!VVPOpcodeOpt)
43     return SDValue();
44   unsigned VVPOpcode = *VVPOpcodeOpt;
45   const bool FromVP = ISD::isVPOpcode(Opcode);
46 
47   // The representative and legalized vector type of this operation.
48   VECustomDAG CDAG(DAG, Op);
49   // Dispatch to complex lowering functions.
50   switch (VVPOpcode) {
51   case VEISD::VVP_LOAD:
52   case VEISD::VVP_STORE:
53     return lowerVVP_LOAD_STORE(Op, CDAG);
54   case VEISD::VVP_GATHER:
55   case VEISD::VVP_SCATTER:
56     return lowerVVP_GATHER_SCATTER(Op, CDAG);
57   }
58 
59   EVT OpVecVT = *getIdiomaticVectorType(Op.getNode());
60   EVT LegalVecVT = getTypeToTransformTo(*DAG.getContext(), OpVecVT);
61   auto Packing = getTypePacking(LegalVecVT.getSimpleVT());
62 
63   SDValue AVL;
64   SDValue Mask;
65 
66   if (FromVP) {
67     // All upstream VP SDNodes always have a mask and avl.
68     auto MaskIdx = ISD::getVPMaskIdx(Opcode);
69     auto AVLIdx = ISD::getVPExplicitVectorLengthIdx(Opcode);
70     if (MaskIdx)
71       Mask = Op->getOperand(*MaskIdx);
72     if (AVLIdx)
73       AVL = Op->getOperand(*AVLIdx);
74   }
75 
76   // Materialize default mask and avl.
77   if (!AVL)
78     AVL = CDAG.getConstant(OpVecVT.getVectorNumElements(), MVT::i32);
79   if (!Mask)
80     Mask = CDAG.getConstantMask(Packing, true);
81 
82   assert(LegalVecVT.isSimple());
83   if (isVVPUnaryOp(VVPOpcode))
84     return CDAG.getNode(VVPOpcode, LegalVecVT, {Op->getOperand(0), Mask, AVL});
85   if (isVVPBinaryOp(VVPOpcode))
86     return CDAG.getNode(VVPOpcode, LegalVecVT,
87                         {Op->getOperand(0), Op->getOperand(1), Mask, AVL});
88   if (isVVPReductionOp(VVPOpcode)) {
89     auto SrcHasStart = hasReductionStartParam(Op->getOpcode());
90     SDValue StartV = SrcHasStart ? Op->getOperand(0) : SDValue();
91     SDValue VectorV = Op->getOperand(SrcHasStart ? 1 : 0);
92     return CDAG.getLegalReductionOpVVP(VVPOpcode, Op.getValueType(), StartV,
93                                        VectorV, Mask, AVL, Op->getFlags());
94   }
95 
96   switch (VVPOpcode) {
97   default:
98     llvm_unreachable("lowerToVVP called for unexpected SDNode.");
99   case VEISD::VVP_FFMA: {
100     // VE has a swizzled operand order in FMA (compared to LLVM IR and
101     // SDNodes).
102     auto X = Op->getOperand(2);
103     auto Y = Op->getOperand(0);
104     auto Z = Op->getOperand(1);
105     return CDAG.getNode(VVPOpcode, LegalVecVT, {X, Y, Z, Mask, AVL});
106   }
107   case VEISD::VVP_SELECT: {
108     auto Mask = Op->getOperand(0);
109     auto OnTrue = Op->getOperand(1);
110     auto OnFalse = Op->getOperand(2);
111     return CDAG.getNode(VVPOpcode, LegalVecVT, {OnTrue, OnFalse, Mask, AVL});
112   }
113   case VEISD::VVP_SETCC: {
114     EVT LegalResVT = getTypeToTransformTo(*DAG.getContext(), Op.getValueType());
115     auto LHS = Op->getOperand(0);
116     auto RHS = Op->getOperand(1);
117     auto Pred = Op->getOperand(2);
118     return CDAG.getNode(VVPOpcode, LegalResVT, {LHS, RHS, Pred, Mask, AVL});
119   }
120   }
121 }
122 
123 SDValue VETargetLowering::lowerVVP_LOAD_STORE(SDValue Op,
124                                               VECustomDAG &CDAG) const {
125   auto VVPOpc = *getVVPOpcode(Op->getOpcode());
126   const bool IsLoad = (VVPOpc == VEISD::VVP_LOAD);
127 
128   // Shares.
129   SDValue BasePtr = getMemoryPtr(Op);
130   SDValue Mask = getNodeMask(Op);
131   SDValue Chain = getNodeChain(Op);
132   SDValue AVL = getNodeAVL(Op);
133   // Store specific.
134   SDValue Data = getStoredValue(Op);
135   // Load specific.
136   SDValue PassThru = getNodePassthru(Op);
137 
138   SDValue StrideV = getLoadStoreStride(Op, CDAG);
139 
140   auto DataVT = *getIdiomaticVectorType(Op.getNode());
141   auto Packing = getTypePacking(DataVT);
142 
143   // TODO: Infer lower AVL from mask.
144   if (!AVL)
145     AVL = CDAG.getConstant(DataVT.getVectorNumElements(), MVT::i32);
146 
147   // Default to the all-true mask.
148   if (!Mask)
149     Mask = CDAG.getConstantMask(Packing, true);
150 
151   if (IsLoad) {
152     MVT LegalDataVT = getLegalVectorType(
153         Packing, DataVT.getVectorElementType().getSimpleVT());
154 
155     auto NewLoadV = CDAG.getNode(VEISD::VVP_LOAD, {LegalDataVT, MVT::Other},
156                                  {Chain, BasePtr, StrideV, Mask, AVL});
157 
158     if (!PassThru || PassThru->isUndef())
159       return NewLoadV;
160 
161     // Convert passthru to an explicit select node.
162     SDValue DataV = CDAG.getNode(VEISD::VVP_SELECT, DataVT,
163                                  {NewLoadV, PassThru, Mask, AVL});
164     SDValue NewLoadChainV = SDValue(NewLoadV.getNode(), 1);
165 
166     // Merge them back into one node.
167     return CDAG.getMergeValues({DataV, NewLoadChainV});
168   }
169 
170   // VVP_STORE
171   assert(VVPOpc == VEISD::VVP_STORE);
172   return CDAG.getNode(VEISD::VVP_STORE, Op.getNode()->getVTList(),
173                       {Chain, Data, BasePtr, StrideV, Mask, AVL});
174 }
175 
176 SDValue VETargetLowering::splitPackedLoadStore(SDValue Op,
177                                                VECustomDAG &CDAG) const {
178   auto VVPOC = *getVVPOpcode(Op.getOpcode());
179   assert((VVPOC == VEISD::VVP_LOAD) || (VVPOC == VEISD::VVP_STORE));
180 
181   MVT DataVT = getIdiomaticVectorType(Op.getNode())->getSimpleVT();
182   assert(getTypePacking(DataVT) == Packing::Dense &&
183          "Can only split packed load/store");
184   MVT SplitDataVT = splitVectorType(DataVT);
185 
186   assert(!getNodePassthru(Op) &&
187          "Should have been folded in lowering to VVP layer");
188 
189   // Analyze the operation
190   SDValue PackedMask = getNodeMask(Op);
191   SDValue PackedAVL = getAnnotatedNodeAVL(Op).first;
192   SDValue PackPtr = getMemoryPtr(Op);
193   SDValue PackData = getStoredValue(Op);
194   SDValue PackStride = getLoadStoreStride(Op, CDAG);
195 
196   unsigned ChainResIdx = PackData ? 0 : 1;
197 
198   SDValue PartOps[2];
199 
200   SDValue UpperPartAVL; // we will use this for packing things back together
201   for (PackElem Part : {PackElem::Hi, PackElem::Lo}) {
202     // VP ops already have an explicit mask and AVL. When expanding from non-VP
203     // attach those additional inputs here.
204     auto SplitTM = CDAG.getTargetSplitMask(PackedMask, PackedAVL, Part);
205 
206     // Keep track of the (higher) lvl.
207     if (Part == PackElem::Hi)
208       UpperPartAVL = SplitTM.AVL;
209 
210     // Attach non-predicating value operands
211     SmallVector<SDValue, 4> OpVec;
212 
213     // Chain
214     OpVec.push_back(getNodeChain(Op));
215 
216     // Data
217     if (PackData) {
218       SDValue PartData =
219           CDAG.getUnpack(SplitDataVT, PackData, Part, SplitTM.AVL);
220       OpVec.push_back(PartData);
221     }
222 
223     // Ptr & Stride
224     // Push (ptr + ElemBytes * <Part>, 2 * ElemBytes)
225     // Stride info
226     // EVT DataVT = LegalizeVectorType(getMemoryDataVT(Op), Op, DAG, Mode);
227     OpVec.push_back(CDAG.getSplitPtrOffset(PackPtr, PackStride, Part));
228     OpVec.push_back(CDAG.getSplitPtrStride(PackStride));
229 
230     // Add predicating args and generate part node
231     OpVec.push_back(SplitTM.Mask);
232     OpVec.push_back(SplitTM.AVL);
233 
234     if (PackData) {
235       // Store
236       PartOps[(int)Part] = CDAG.getNode(VVPOC, MVT::Other, OpVec);
237     } else {
238       // Load
239       PartOps[(int)Part] =
240           CDAG.getNode(VVPOC, {SplitDataVT, MVT::Other}, OpVec);
241     }
242   }
243 
244   // Merge the chains
245   SDValue LowChain = SDValue(PartOps[(int)PackElem::Lo].getNode(), ChainResIdx);
246   SDValue HiChain = SDValue(PartOps[(int)PackElem::Hi].getNode(), ChainResIdx);
247   SDValue FusedChains =
248       CDAG.getNode(ISD::TokenFactor, MVT::Other, {LowChain, HiChain});
249 
250   // Chain only [store]
251   if (PackData)
252     return FusedChains;
253 
254   // Re-pack into full packed vector result
255   MVT PackedVT =
256       getLegalVectorType(Packing::Dense, DataVT.getVectorElementType());
257   SDValue PackedVals = CDAG.getPack(PackedVT, PartOps[(int)PackElem::Lo],
258                                     PartOps[(int)PackElem::Hi], UpperPartAVL);
259 
260   return CDAG.getMergeValues({PackedVals, FusedChains});
261 }
262 
263 SDValue VETargetLowering::lowerVVP_GATHER_SCATTER(SDValue Op,
264                                                   VECustomDAG &CDAG) const {
265   EVT DataVT = *getIdiomaticVectorType(Op.getNode());
266   auto Packing = getTypePacking(DataVT);
267   MVT LegalDataVT =
268       getLegalVectorType(Packing, DataVT.getVectorElementType().getSimpleVT());
269 
270   SDValue AVL = getAnnotatedNodeAVL(Op).first;
271   SDValue Index = getGatherScatterIndex(Op);
272   SDValue BasePtr = getMemoryPtr(Op);
273   SDValue Mask = getNodeMask(Op);
274   SDValue Chain = getNodeChain(Op);
275   SDValue Scale = getGatherScatterScale(Op);
276   SDValue PassThru = getNodePassthru(Op);
277   SDValue StoredValue = getStoredValue(Op);
278   if (PassThru && PassThru->isUndef())
279     PassThru = SDValue();
280 
281   bool IsScatter = (bool)StoredValue;
282 
283   // TODO: Infer lower AVL from mask.
284   if (!AVL)
285     AVL = CDAG.getConstant(DataVT.getVectorNumElements(), MVT::i32);
286 
287   // Default to the all-true mask.
288   if (!Mask)
289     Mask = CDAG.getConstantMask(Packing, true);
290 
291   SDValue AddressVec =
292       CDAG.getGatherScatterAddress(BasePtr, Scale, Index, Mask, AVL);
293   if (IsScatter)
294     return CDAG.getNode(VEISD::VVP_SCATTER, MVT::Other,
295                         {Chain, StoredValue, AddressVec, Mask, AVL});
296 
297   // Gather.
298   SDValue NewLoadV = CDAG.getNode(VEISD::VVP_GATHER, {LegalDataVT, MVT::Other},
299                                   {Chain, AddressVec, Mask, AVL});
300 
301   if (!PassThru)
302     return NewLoadV;
303 
304   // TODO: Use vvp_select
305   SDValue DataV = CDAG.getNode(VEISD::VVP_SELECT, LegalDataVT,
306                                {NewLoadV, PassThru, Mask, AVL});
307   SDValue NewLoadChainV = SDValue(NewLoadV.getNode(), 1);
308   return CDAG.getMergeValues({DataV, NewLoadChainV});
309 }
310 
311 SDValue VETargetLowering::legalizeInternalLoadStoreOp(SDValue Op,
312                                                       VECustomDAG &CDAG) const {
313   LLVM_DEBUG(dbgs() << "::legalizeInternalLoadStoreOp\n";);
314   MVT DataVT = getIdiomaticVectorType(Op.getNode())->getSimpleVT();
315 
316   // TODO: Recognize packable load,store.
317   if (isPackedVectorType(DataVT))
318     return splitPackedLoadStore(Op, CDAG);
319 
320   return legalizePackedAVL(Op, CDAG);
321 }
322 
323 SDValue VETargetLowering::legalizeInternalVectorOp(SDValue Op,
324                                                    SelectionDAG &DAG) const {
325   LLVM_DEBUG(dbgs() << "::legalizeInternalVectorOp\n";);
326   VECustomDAG CDAG(DAG, Op);
327 
328   // Dispatch to specialized legalization functions.
329   switch (Op->getOpcode()) {
330   case VEISD::VVP_LOAD:
331   case VEISD::VVP_STORE:
332     return legalizeInternalLoadStoreOp(Op, CDAG);
333   }
334 
335   EVT IdiomVT = Op.getValueType();
336   if (isPackedVectorType(IdiomVT) &&
337       !supportsPackedMode(Op.getOpcode(), IdiomVT))
338     return splitVectorOp(Op, CDAG);
339 
340   // TODO: Implement odd/even splitting.
341   return legalizePackedAVL(Op, CDAG);
342 }
343 
344 SDValue VETargetLowering::splitVectorOp(SDValue Op, VECustomDAG &CDAG) const {
345   MVT ResVT = splitVectorType(Op.getValue(0).getSimpleValueType());
346 
347   auto AVLPos = getAVLPos(Op->getOpcode());
348   auto MaskPos = getMaskPos(Op->getOpcode());
349 
350   SDValue PackedMask = getNodeMask(Op);
351   auto AVLPair = getAnnotatedNodeAVL(Op);
352   SDValue PackedAVL = AVLPair.first;
353   assert(!AVLPair.second && "Expecting non pack-legalized oepration");
354 
355   // request the parts
356   SDValue PartOps[2];
357 
358   SDValue UpperPartAVL; // we will use this for packing things back together
359   for (PackElem Part : {PackElem::Hi, PackElem::Lo}) {
360     // VP ops already have an explicit mask and AVL. When expanding from non-VP
361     // attach those additional inputs here.
362     auto SplitTM = CDAG.getTargetSplitMask(PackedMask, PackedAVL, Part);
363 
364     if (Part == PackElem::Hi)
365       UpperPartAVL = SplitTM.AVL;
366 
367     // Attach non-predicating value operands
368     SmallVector<SDValue, 4> OpVec;
369     for (unsigned i = 0; i < Op.getNumOperands(); ++i) {
370       if (AVLPos && ((int)i) == *AVLPos)
371         continue;
372       if (MaskPos && ((int)i) == *MaskPos)
373         continue;
374 
375       // Value operand
376       auto PackedOperand = Op.getOperand(i);
377       auto UnpackedOpVT = splitVectorType(PackedOperand.getSimpleValueType());
378       SDValue PartV =
379           CDAG.getUnpack(UnpackedOpVT, PackedOperand, Part, SplitTM.AVL);
380       OpVec.push_back(PartV);
381     }
382 
383     // Add predicating args and generate part node.
384     OpVec.push_back(SplitTM.Mask);
385     OpVec.push_back(SplitTM.AVL);
386     // Emit legal VVP nodes.
387     PartOps[(int)Part] =
388         CDAG.getNode(Op.getOpcode(), ResVT, OpVec, Op->getFlags());
389   }
390 
391   // Re-package vectors.
392   return CDAG.getPack(Op.getValueType(), PartOps[(int)PackElem::Lo],
393                       PartOps[(int)PackElem::Hi], UpperPartAVL);
394 }
395 
396 SDValue VETargetLowering::legalizePackedAVL(SDValue Op,
397                                             VECustomDAG &CDAG) const {
398   LLVM_DEBUG(dbgs() << "::legalizePackedAVL\n";);
399   // Only required for VEC and VVP ops.
400   if (!isVVPOrVEC(Op->getOpcode()))
401     return Op;
402 
403   // Operation already has a legal AVL.
404   auto AVL = getNodeAVL(Op);
405   if (isLegalAVL(AVL))
406     return Op;
407 
408   // Half and round up EVL for 32bit element types.
409   SDValue LegalAVL = AVL;
410   MVT IdiomVT = getIdiomaticVectorType(Op.getNode())->getSimpleVT();
411   if (isPackedVectorType(IdiomVT)) {
412     assert(maySafelyIgnoreMask(Op) &&
413            "TODO Shift predication from EVL into Mask");
414 
415     if (auto *ConstAVL = dyn_cast<ConstantSDNode>(AVL)) {
416       LegalAVL = CDAG.getConstant((ConstAVL->getZExtValue() + 1) / 2, MVT::i32);
417     } else {
418       auto ConstOne = CDAG.getConstant(1, MVT::i32);
419       auto PlusOne = CDAG.getNode(ISD::ADD, MVT::i32, {AVL, ConstOne});
420       LegalAVL = CDAG.getNode(ISD::SRL, MVT::i32, {PlusOne, ConstOne});
421     }
422   }
423 
424   SDValue AnnotatedLegalAVL = CDAG.annotateLegalAVL(LegalAVL);
425 
426   // Copy the operand list.
427   int NumOp = Op->getNumOperands();
428   auto AVLPos = getAVLPos(Op->getOpcode());
429   std::vector<SDValue> FixedOperands;
430   for (int i = 0; i < NumOp; ++i) {
431     if (AVLPos && (i == *AVLPos)) {
432       FixedOperands.push_back(AnnotatedLegalAVL);
433       continue;
434     }
435     FixedOperands.push_back(Op->getOperand(i));
436   }
437 
438   // Clone the operation with fixed operands.
439   auto Flags = Op->getFlags();
440   SDValue NewN =
441       CDAG.getNode(Op->getOpcode(), Op->getVTList(), FixedOperands, Flags);
442   return NewN;
443 }
444