xref: /freebsd/contrib/llvm-project/llvm/lib/Target/VE/VVPISelLowering.cpp (revision b64c5a0ace59af62eff52bfe110a521dc73c937b)
1 //===-- VVPISelLowering.cpp - VE DAG Lowering Implementation --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the lowering and legalization of vector instructions to
// VVP_* layer SDNodes.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "VECustomDAG.h"
15 #include "VEISelLowering.h"
16 
17 using namespace llvm;
18 
19 #define DEBUG_TYPE "ve-lower"
20 
21 SDValue VETargetLowering::splitMaskArithmetic(SDValue Op,
22                                               SelectionDAG &DAG) const {
23   VECustomDAG CDAG(DAG, Op);
24   SDValue AVL =
25       CDAG.getConstant(Op.getValueType().getVectorNumElements(), MVT::i32);
26   SDValue A = Op->getOperand(0);
27   SDValue B = Op->getOperand(1);
28   SDValue LoA = CDAG.getUnpack(MVT::v256i1, A, PackElem::Lo, AVL);
29   SDValue HiA = CDAG.getUnpack(MVT::v256i1, A, PackElem::Hi, AVL);
30   SDValue LoB = CDAG.getUnpack(MVT::v256i1, B, PackElem::Lo, AVL);
31   SDValue HiB = CDAG.getUnpack(MVT::v256i1, B, PackElem::Hi, AVL);
32   unsigned Opc = Op.getOpcode();
33   auto LoRes = CDAG.getNode(Opc, MVT::v256i1, {LoA, LoB});
34   auto HiRes = CDAG.getNode(Opc, MVT::v256i1, {HiA, HiB});
35   return CDAG.getPack(MVT::v512i1, LoRes, HiRes, AVL);
36 }
37 
38 SDValue VETargetLowering::lowerToVVP(SDValue Op, SelectionDAG &DAG) const {
39   // Can we represent this as a VVP node.
40   const unsigned Opcode = Op->getOpcode();
41   auto VVPOpcodeOpt = getVVPOpcode(Opcode);
42   if (!VVPOpcodeOpt)
43     return SDValue();
44   unsigned VVPOpcode = *VVPOpcodeOpt;
45   const bool FromVP = ISD::isVPOpcode(Opcode);
46 
47   // The representative and legalized vector type of this operation.
48   VECustomDAG CDAG(DAG, Op);
49   // Dispatch to complex lowering functions.
50   switch (VVPOpcode) {
51   case VEISD::VVP_LOAD:
52   case VEISD::VVP_STORE:
53     return lowerVVP_LOAD_STORE(Op, CDAG);
54   case VEISD::VVP_GATHER:
55   case VEISD::VVP_SCATTER:
56     return lowerVVP_GATHER_SCATTER(Op, CDAG);
57   }
58 
59   EVT OpVecVT = *getIdiomaticVectorType(Op.getNode());
60   EVT LegalVecVT = getTypeToTransformTo(*DAG.getContext(), OpVecVT);
61   auto Packing = getTypePacking(LegalVecVT.getSimpleVT());
62 
63   SDValue AVL;
64   SDValue Mask;
65 
66   if (FromVP) {
67     // All upstream VP SDNodes always have a mask and avl.
68     auto MaskIdx = ISD::getVPMaskIdx(Opcode);
69     auto AVLIdx = ISD::getVPExplicitVectorLengthIdx(Opcode);
70     if (MaskIdx)
71       Mask = Op->getOperand(*MaskIdx);
72     if (AVLIdx)
73       AVL = Op->getOperand(*AVLIdx);
74   }
75 
76   // Materialize default mask and avl.
77   if (!AVL)
78     AVL = CDAG.getConstant(OpVecVT.getVectorNumElements(), MVT::i32);
79   if (!Mask)
80     Mask = CDAG.getConstantMask(Packing, true);
81 
82   assert(LegalVecVT.isSimple());
83   if (isVVPUnaryOp(VVPOpcode))
84     return CDAG.getNode(VVPOpcode, LegalVecVT, {Op->getOperand(0), Mask, AVL});
85   if (isVVPBinaryOp(VVPOpcode))
86     return CDAG.getNode(VVPOpcode, LegalVecVT,
87                         {Op->getOperand(0), Op->getOperand(1), Mask, AVL});
88   if (isVVPReductionOp(VVPOpcode)) {
89     auto SrcHasStart = hasReductionStartParam(Op->getOpcode());
90     SDValue StartV = SrcHasStart ? Op->getOperand(0) : SDValue();
91     SDValue VectorV = Op->getOperand(SrcHasStart ? 1 : 0);
92     return CDAG.getLegalReductionOpVVP(VVPOpcode, Op.getValueType(), StartV,
93                                        VectorV, Mask, AVL, Op->getFlags());
94   }
95 
96   switch (VVPOpcode) {
97   default:
98     llvm_unreachable("lowerToVVP called for unexpected SDNode.");
99   case VEISD::VVP_FFMA: {
100     // VE has a swizzled operand order in FMA (compared to LLVM IR and
101     // SDNodes).
102     auto X = Op->getOperand(2);
103     auto Y = Op->getOperand(0);
104     auto Z = Op->getOperand(1);
105     return CDAG.getNode(VVPOpcode, LegalVecVT, {X, Y, Z, Mask, AVL});
106   }
107   case VEISD::VVP_SELECT: {
108     auto Mask = Op->getOperand(0);
109     auto OnTrue = Op->getOperand(1);
110     auto OnFalse = Op->getOperand(2);
111     return CDAG.getNode(VVPOpcode, LegalVecVT, {OnTrue, OnFalse, Mask, AVL});
112   }
113   case VEISD::VVP_SETCC: {
114     EVT LegalResVT = getTypeToTransformTo(*DAG.getContext(), Op.getValueType());
115     auto LHS = Op->getOperand(0);
116     auto RHS = Op->getOperand(1);
117     auto Pred = Op->getOperand(2);
118     return CDAG.getNode(VVPOpcode, LegalResVT, {LHS, RHS, Pred, Mask, AVL});
119   }
120   }
121 }
122 
123 SDValue VETargetLowering::lowerVVP_LOAD_STORE(SDValue Op,
124                                               VECustomDAG &CDAG) const {
125   auto VVPOpc = *getVVPOpcode(Op->getOpcode());
126   const bool IsLoad = (VVPOpc == VEISD::VVP_LOAD);
127 
128   // Shares.
129   SDValue BasePtr = getMemoryPtr(Op);
130   SDValue Mask = getNodeMask(Op);
131   SDValue Chain = getNodeChain(Op);
132   SDValue AVL = getNodeAVL(Op);
133   // Store specific.
134   SDValue Data = getStoredValue(Op);
135   // Load specific.
136   SDValue PassThru = getNodePassthru(Op);
137 
138   SDValue StrideV = getLoadStoreStride(Op, CDAG);
139 
140   auto DataVT = *getIdiomaticVectorType(Op.getNode());
141   auto Packing = getTypePacking(DataVT);
142 
143   // TODO: Infer lower AVL from mask.
144   if (!AVL)
145     AVL = CDAG.getConstant(DataVT.getVectorNumElements(), MVT::i32);
146 
147   // Default to the all-true mask.
148   if (!Mask)
149     Mask = CDAG.getConstantMask(Packing, true);
150 
151   if (IsLoad) {
152     MVT LegalDataVT = getLegalVectorType(
153         Packing, DataVT.getVectorElementType().getSimpleVT());
154 
155     auto NewLoadV = CDAG.getNode(VEISD::VVP_LOAD, {LegalDataVT, MVT::Other},
156                                  {Chain, BasePtr, StrideV, Mask, AVL});
157 
158     if (!PassThru || PassThru->isUndef())
159       return NewLoadV;
160 
161     // Convert passthru to an explicit select node.
162     SDValue DataV = CDAG.getNode(VEISD::VVP_SELECT, DataVT,
163                                  {NewLoadV, PassThru, Mask, AVL});
164     SDValue NewLoadChainV = SDValue(NewLoadV.getNode(), 1);
165 
166     // Merge them back into one node.
167     return CDAG.getMergeValues({DataV, NewLoadChainV});
168   }
169 
170   // VVP_STORE
171   assert(VVPOpc == VEISD::VVP_STORE);
172   if (getTypeAction(*CDAG.getDAG()->getContext(), Data.getValueType()) !=
173       TargetLowering::TypeLegal)
174     // Doesn't lower store instruction if an operand is not lowered yet.
175     // If it isn't, return SDValue().  In this way, LLVM will try to lower
176     // store instruction again after lowering all operands.
177     return SDValue();
178   return CDAG.getNode(VEISD::VVP_STORE, Op.getNode()->getVTList(),
179                       {Chain, Data, BasePtr, StrideV, Mask, AVL});
180 }
181 
182 SDValue VETargetLowering::splitPackedLoadStore(SDValue Op,
183                                                VECustomDAG &CDAG) const {
184   auto VVPOC = *getVVPOpcode(Op.getOpcode());
185   assert((VVPOC == VEISD::VVP_LOAD) || (VVPOC == VEISD::VVP_STORE));
186 
187   MVT DataVT = getIdiomaticVectorType(Op.getNode())->getSimpleVT();
188   assert(getTypePacking(DataVT) == Packing::Dense &&
189          "Can only split packed load/store");
190   MVT SplitDataVT = splitVectorType(DataVT);
191 
192   assert(!getNodePassthru(Op) &&
193          "Should have been folded in lowering to VVP layer");
194 
195   // Analyze the operation
196   SDValue PackedMask = getNodeMask(Op);
197   SDValue PackedAVL = getAnnotatedNodeAVL(Op).first;
198   SDValue PackPtr = getMemoryPtr(Op);
199   SDValue PackData = getStoredValue(Op);
200   SDValue PackStride = getLoadStoreStride(Op, CDAG);
201 
202   unsigned ChainResIdx = PackData ? 0 : 1;
203 
204   SDValue PartOps[2];
205 
206   SDValue UpperPartAVL; // we will use this for packing things back together
207   for (PackElem Part : {PackElem::Hi, PackElem::Lo}) {
208     // VP ops already have an explicit mask and AVL. When expanding from non-VP
209     // attach those additional inputs here.
210     auto SplitTM = CDAG.getTargetSplitMask(PackedMask, PackedAVL, Part);
211 
212     // Keep track of the (higher) lvl.
213     if (Part == PackElem::Hi)
214       UpperPartAVL = SplitTM.AVL;
215 
216     // Attach non-predicating value operands
217     SmallVector<SDValue, 4> OpVec;
218 
219     // Chain
220     OpVec.push_back(getNodeChain(Op));
221 
222     // Data
223     if (PackData) {
224       SDValue PartData =
225           CDAG.getUnpack(SplitDataVT, PackData, Part, SplitTM.AVL);
226       OpVec.push_back(PartData);
227     }
228 
229     // Ptr & Stride
230     // Push (ptr + ElemBytes * <Part>, 2 * ElemBytes)
231     // Stride info
232     // EVT DataVT = LegalizeVectorType(getMemoryDataVT(Op), Op, DAG, Mode);
233     OpVec.push_back(CDAG.getSplitPtrOffset(PackPtr, PackStride, Part));
234     OpVec.push_back(CDAG.getSplitPtrStride(PackStride));
235 
236     // Add predicating args and generate part node
237     OpVec.push_back(SplitTM.Mask);
238     OpVec.push_back(SplitTM.AVL);
239 
240     if (PackData) {
241       // Store
242       PartOps[(int)Part] = CDAG.getNode(VVPOC, MVT::Other, OpVec);
243     } else {
244       // Load
245       PartOps[(int)Part] =
246           CDAG.getNode(VVPOC, {SplitDataVT, MVT::Other}, OpVec);
247     }
248   }
249 
250   // Merge the chains
251   SDValue LowChain = SDValue(PartOps[(int)PackElem::Lo].getNode(), ChainResIdx);
252   SDValue HiChain = SDValue(PartOps[(int)PackElem::Hi].getNode(), ChainResIdx);
253   SDValue FusedChains =
254       CDAG.getNode(ISD::TokenFactor, MVT::Other, {LowChain, HiChain});
255 
256   // Chain only [store]
257   if (PackData)
258     return FusedChains;
259 
260   // Re-pack into full packed vector result
261   MVT PackedVT =
262       getLegalVectorType(Packing::Dense, DataVT.getVectorElementType());
263   SDValue PackedVals = CDAG.getPack(PackedVT, PartOps[(int)PackElem::Lo],
264                                     PartOps[(int)PackElem::Hi], UpperPartAVL);
265 
266   return CDAG.getMergeValues({PackedVals, FusedChains});
267 }
268 
269 SDValue VETargetLowering::lowerVVP_GATHER_SCATTER(SDValue Op,
270                                                   VECustomDAG &CDAG) const {
271   EVT DataVT = *getIdiomaticVectorType(Op.getNode());
272   auto Packing = getTypePacking(DataVT);
273   MVT LegalDataVT =
274       getLegalVectorType(Packing, DataVT.getVectorElementType().getSimpleVT());
275 
276   SDValue AVL = getAnnotatedNodeAVL(Op).first;
277   SDValue Index = getGatherScatterIndex(Op);
278   SDValue BasePtr = getMemoryPtr(Op);
279   SDValue Mask = getNodeMask(Op);
280   SDValue Chain = getNodeChain(Op);
281   SDValue Scale = getGatherScatterScale(Op);
282   SDValue PassThru = getNodePassthru(Op);
283   SDValue StoredValue = getStoredValue(Op);
284   if (PassThru && PassThru->isUndef())
285     PassThru = SDValue();
286 
287   bool IsScatter = (bool)StoredValue;
288 
289   // TODO: Infer lower AVL from mask.
290   if (!AVL)
291     AVL = CDAG.getConstant(DataVT.getVectorNumElements(), MVT::i32);
292 
293   // Default to the all-true mask.
294   if (!Mask)
295     Mask = CDAG.getConstantMask(Packing, true);
296 
297   SDValue AddressVec =
298       CDAG.getGatherScatterAddress(BasePtr, Scale, Index, Mask, AVL);
299   if (IsScatter)
300     return CDAG.getNode(VEISD::VVP_SCATTER, MVT::Other,
301                         {Chain, StoredValue, AddressVec, Mask, AVL});
302 
303   // Gather.
304   SDValue NewLoadV = CDAG.getNode(VEISD::VVP_GATHER, {LegalDataVT, MVT::Other},
305                                   {Chain, AddressVec, Mask, AVL});
306 
307   if (!PassThru)
308     return NewLoadV;
309 
310   // TODO: Use vvp_select
311   SDValue DataV = CDAG.getNode(VEISD::VVP_SELECT, LegalDataVT,
312                                {NewLoadV, PassThru, Mask, AVL});
313   SDValue NewLoadChainV = SDValue(NewLoadV.getNode(), 1);
314   return CDAG.getMergeValues({DataV, NewLoadChainV});
315 }
316 
317 SDValue VETargetLowering::legalizeInternalLoadStoreOp(SDValue Op,
318                                                       VECustomDAG &CDAG) const {
319   LLVM_DEBUG(dbgs() << "::legalizeInternalLoadStoreOp\n";);
320   MVT DataVT = getIdiomaticVectorType(Op.getNode())->getSimpleVT();
321 
322   // TODO: Recognize packable load,store.
323   if (isPackedVectorType(DataVT))
324     return splitPackedLoadStore(Op, CDAG);
325 
326   return legalizePackedAVL(Op, CDAG);
327 }
328 
329 SDValue VETargetLowering::legalizeInternalVectorOp(SDValue Op,
330                                                    SelectionDAG &DAG) const {
331   LLVM_DEBUG(dbgs() << "::legalizeInternalVectorOp\n";);
332   VECustomDAG CDAG(DAG, Op);
333 
334   // Dispatch to specialized legalization functions.
335   switch (Op->getOpcode()) {
336   case VEISD::VVP_LOAD:
337   case VEISD::VVP_STORE:
338     return legalizeInternalLoadStoreOp(Op, CDAG);
339   }
340 
341   EVT IdiomVT = Op.getValueType();
342   if (isPackedVectorType(IdiomVT) &&
343       !supportsPackedMode(Op.getOpcode(), IdiomVT))
344     return splitVectorOp(Op, CDAG);
345 
346   // TODO: Implement odd/even splitting.
347   return legalizePackedAVL(Op, CDAG);
348 }
349 
350 SDValue VETargetLowering::splitVectorOp(SDValue Op, VECustomDAG &CDAG) const {
351   MVT ResVT = splitVectorType(Op.getValue(0).getSimpleValueType());
352 
353   auto AVLPos = getAVLPos(Op->getOpcode());
354   auto MaskPos = getMaskPos(Op->getOpcode());
355 
356   SDValue PackedMask = getNodeMask(Op);
357   auto AVLPair = getAnnotatedNodeAVL(Op);
358   SDValue PackedAVL = AVLPair.first;
359   assert(!AVLPair.second && "Expecting non pack-legalized oepration");
360 
361   // request the parts
362   SDValue PartOps[2];
363 
364   SDValue UpperPartAVL; // we will use this for packing things back together
365   for (PackElem Part : {PackElem::Hi, PackElem::Lo}) {
366     // VP ops already have an explicit mask and AVL. When expanding from non-VP
367     // attach those additional inputs here.
368     auto SplitTM = CDAG.getTargetSplitMask(PackedMask, PackedAVL, Part);
369 
370     if (Part == PackElem::Hi)
371       UpperPartAVL = SplitTM.AVL;
372 
373     // Attach non-predicating value operands
374     SmallVector<SDValue, 4> OpVec;
375     for (unsigned i = 0; i < Op.getNumOperands(); ++i) {
376       if (AVLPos && ((int)i) == *AVLPos)
377         continue;
378       if (MaskPos && ((int)i) == *MaskPos)
379         continue;
380 
381       // Value operand
382       auto PackedOperand = Op.getOperand(i);
383       auto UnpackedOpVT = splitVectorType(PackedOperand.getSimpleValueType());
384       SDValue PartV =
385           CDAG.getUnpack(UnpackedOpVT, PackedOperand, Part, SplitTM.AVL);
386       OpVec.push_back(PartV);
387     }
388 
389     // Add predicating args and generate part node.
390     OpVec.push_back(SplitTM.Mask);
391     OpVec.push_back(SplitTM.AVL);
392     // Emit legal VVP nodes.
393     PartOps[(int)Part] =
394         CDAG.getNode(Op.getOpcode(), ResVT, OpVec, Op->getFlags());
395   }
396 
397   // Re-package vectors.
398   return CDAG.getPack(Op.getValueType(), PartOps[(int)PackElem::Lo],
399                       PartOps[(int)PackElem::Hi], UpperPartAVL);
400 }
401 
402 SDValue VETargetLowering::legalizePackedAVL(SDValue Op,
403                                             VECustomDAG &CDAG) const {
404   LLVM_DEBUG(dbgs() << "::legalizePackedAVL\n";);
405   // Only required for VEC and VVP ops.
406   if (!isVVPOrVEC(Op->getOpcode()))
407     return Op;
408 
409   // Operation already has a legal AVL.
410   auto AVL = getNodeAVL(Op);
411   if (isLegalAVL(AVL))
412     return Op;
413 
414   // Half and round up EVL for 32bit element types.
415   SDValue LegalAVL = AVL;
416   MVT IdiomVT = getIdiomaticVectorType(Op.getNode())->getSimpleVT();
417   if (isPackedVectorType(IdiomVT)) {
418     assert(maySafelyIgnoreMask(Op) &&
419            "TODO Shift predication from EVL into Mask");
420 
421     if (auto *ConstAVL = dyn_cast<ConstantSDNode>(AVL)) {
422       LegalAVL = CDAG.getConstant((ConstAVL->getZExtValue() + 1) / 2, MVT::i32);
423     } else {
424       auto ConstOne = CDAG.getConstant(1, MVT::i32);
425       auto PlusOne = CDAG.getNode(ISD::ADD, MVT::i32, {AVL, ConstOne});
426       LegalAVL = CDAG.getNode(ISD::SRL, MVT::i32, {PlusOne, ConstOne});
427     }
428   }
429 
430   SDValue AnnotatedLegalAVL = CDAG.annotateLegalAVL(LegalAVL);
431 
432   // Copy the operand list.
433   int NumOp = Op->getNumOperands();
434   auto AVLPos = getAVLPos(Op->getOpcode());
435   std::vector<SDValue> FixedOperands;
436   for (int i = 0; i < NumOp; ++i) {
437     if (AVLPos && (i == *AVLPos)) {
438       FixedOperands.push_back(AnnotatedLegalAVL);
439       continue;
440     }
441     FixedOperands.push_back(Op->getOperand(i));
442   }
443 
444   // Clone the operation with fixed operands.
445   auto Flags = Op->getFlags();
446   SDValue NewN =
447       CDAG.getNode(Op->getOpcode(), Op->getVTList(), FixedOperands, Flags);
448   return NewN;
449 }
450