xref: /freebsd/contrib/llvm-project/llvm/lib/Target/VE/VECustomDAG.cpp (revision d439598dd0d341b0c0b77151ba904e09c42f8421)
//===-- VECustomDAG.cpp - VE Custom DAG Nodes ----------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements helpers that the VE backend uses to build and query
// its custom (VVP_* and VEC_*) SelectionDAG nodes during lowering.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "VECustomDAG.h"
15 
16 #ifndef DEBUG_TYPE
17 #define DEBUG_TYPE "vecustomdag"
18 #endif
19 
20 namespace llvm {
21 
22 bool isPackedVectorType(EVT SomeVT) {
23   if (!SomeVT.isVector())
24     return false;
25   return SomeVT.getVectorNumElements() > StandardVectorWidth;
26 }
27 
28 MVT splitVectorType(MVT VT) {
29   if (!VT.isVector())
30     return VT;
31   return MVT::getVectorVT(VT.getVectorElementType(), StandardVectorWidth);
32 }
33 
34 MVT getLegalVectorType(Packing P, MVT ElemVT) {
35   return MVT::getVectorVT(ElemVT, P == Packing::Normal ? StandardVectorWidth
36                                                        : PackedVectorWidth);
37 }
38 
39 Packing getTypePacking(EVT VT) {
40   assert(VT.isVector());
41   return isPackedVectorType(VT) ? Packing::Dense : Packing::Normal;
42 }
43 
44 bool isMaskType(EVT SomeVT) {
45   if (!SomeVT.isVector())
46     return false;
47   return SomeVT.getVectorElementType() == MVT::i1;
48 }
49 
50 bool isMaskArithmetic(SDValue Op) {
51   switch (Op.getOpcode()) {
52   default:
53     return false;
54   case ISD::AND:
55   case ISD::XOR:
56   case ISD::OR:
57     return isMaskType(Op.getValueType());
58   }
59 }
60 
61 /// \returns the VVP_* SDNode opcode corresponsing to \p OC.
62 std::optional<unsigned> getVVPOpcode(unsigned Opcode) {
63   switch (Opcode) {
64   case ISD::MLOAD:
65     return VEISD::VVP_LOAD;
66   case ISD::MSTORE:
67     return VEISD::VVP_STORE;
68 #define HANDLE_VP_TO_VVP(VPOPC, VVPNAME)                                       \
69   case ISD::VPOPC:                                                             \
70     return VEISD::VVPNAME;
71 #define ADD_VVP_OP(VVPNAME, SDNAME)                                            \
72   case VEISD::VVPNAME:                                                         \
73   case ISD::SDNAME:                                                            \
74     return VEISD::VVPNAME;
75 #include "VVPNodes.def"
76   // TODO: Map those in VVPNodes.def too
77   case ISD::EXPERIMENTAL_VP_STRIDED_LOAD:
78     return VEISD::VVP_LOAD;
79   case ISD::EXPERIMENTAL_VP_STRIDED_STORE:
80     return VEISD::VVP_STORE;
81   }
82   return std::nullopt;
83 }
84 
85 bool maySafelyIgnoreMask(SDValue Op) {
86   auto VVPOpc = getVVPOpcode(Op->getOpcode());
87   auto Opc = VVPOpc.value_or(Op->getOpcode());
88 
89   switch (Opc) {
90   case VEISD::VVP_SDIV:
91   case VEISD::VVP_UDIV:
92   case VEISD::VVP_FDIV:
93   case VEISD::VVP_SELECT:
94     return false;
95 
96   default:
97     return true;
98   }
99 }
100 
101 bool supportsPackedMode(unsigned Opcode, EVT IdiomVT) {
102   bool IsPackedOp = isPackedVectorType(IdiomVT);
103   bool IsMaskOp = isMaskType(IdiomVT);
104   switch (Opcode) {
105   default:
106     return false;
107 
108   case VEISD::VEC_BROADCAST:
109     return true;
110 #define REGISTER_PACKED(VVP_NAME) case VEISD::VVP_NAME:
111 #include "VVPNodes.def"
112     return IsPackedOp && !IsMaskOp;
113   }
114 }
115 
116 bool isPackingSupportOpcode(unsigned Opc) {
117   switch (Opc) {
118   case VEISD::VEC_PACK:
119   case VEISD::VEC_UNPACK_LO:
120   case VEISD::VEC_UNPACK_HI:
121     return true;
122   }
123   return false;
124 }
125 
126 bool isVVPOrVEC(unsigned Opcode) {
127   switch (Opcode) {
128   case VEISD::VEC_BROADCAST:
129 #define ADD_VVP_OP(VVPNAME, ...) case VEISD::VVPNAME:
130 #include "VVPNodes.def"
131     return true;
132   }
133   return false;
134 }
135 
136 bool isVVPUnaryOp(unsigned VVPOpcode) {
137   switch (VVPOpcode) {
138 #define ADD_UNARY_VVP_OP(VVPNAME, ...)                                         \
139   case VEISD::VVPNAME:                                                         \
140     return true;
141 #include "VVPNodes.def"
142   }
143   return false;
144 }
145 
146 bool isVVPBinaryOp(unsigned VVPOpcode) {
147   switch (VVPOpcode) {
148 #define ADD_BINARY_VVP_OP(VVPNAME, ...)                                        \
149   case VEISD::VVPNAME:                                                         \
150     return true;
151 #include "VVPNodes.def"
152   }
153   return false;
154 }
155 
156 bool isVVPReductionOp(unsigned Opcode) {
157   switch (Opcode) {
158 #define ADD_REDUCE_VVP_OP(VVP_NAME, SDNAME) case VEISD::VVP_NAME:
159 #include "VVPNodes.def"
160     return true;
161   }
162   return false;
163 }
164 
165 // Return the AVL operand position for this VVP or VEC Op.
166 std::optional<int> getAVLPos(unsigned Opc) {
167   // This is only available for VP SDNodes
168   auto PosOpt = ISD::getVPExplicitVectorLengthIdx(Opc);
169   if (PosOpt)
170     return *PosOpt;
171 
172   // VVP Opcodes.
173   if (isVVPBinaryOp(Opc))
174     return 3;
175 
176   // VM Opcodes.
177   switch (Opc) {
178   case VEISD::VEC_BROADCAST:
179     return 1;
180   case VEISD::VVP_SELECT:
181     return 3;
182   case VEISD::VVP_LOAD:
183     return 4;
184   case VEISD::VVP_STORE:
185     return 5;
186   }
187 
188   return std::nullopt;
189 }
190 
191 std::optional<int> getMaskPos(unsigned Opc) {
192   // This is only available for VP SDNodes
193   auto PosOpt = ISD::getVPMaskIdx(Opc);
194   if (PosOpt)
195     return *PosOpt;
196 
197   // VVP Opcodes.
198   if (isVVPBinaryOp(Opc))
199     return 2;
200 
201   // Other opcodes.
202   switch (Opc) {
203   case ISD::MSTORE:
204     return 4;
205   case ISD::MLOAD:
206     return 3;
207   case VEISD::VVP_SELECT:
208     return 2;
209   }
210 
211   return std::nullopt;
212 }
213 
214 bool isLegalAVL(SDValue AVL) { return AVL->getOpcode() == VEISD::LEGALAVL; }
215 
216 /// Node Properties {
217 
218 SDValue getNodeChain(SDValue Op) {
219   if (MemSDNode *MemN = dyn_cast<MemSDNode>(Op.getNode()))
220     return MemN->getChain();
221 
222   switch (Op->getOpcode()) {
223   case VEISD::VVP_LOAD:
224   case VEISD::VVP_STORE:
225     return Op->getOperand(0);
226   }
227   return SDValue();
228 }
229 
230 SDValue getMemoryPtr(SDValue Op) {
231   if (auto *MemN = dyn_cast<MemSDNode>(Op.getNode()))
232     return MemN->getBasePtr();
233 
234   switch (Op->getOpcode()) {
235   case VEISD::VVP_LOAD:
236     return Op->getOperand(1);
237   case VEISD::VVP_STORE:
238     return Op->getOperand(2);
239   }
240   return SDValue();
241 }
242 
243 std::optional<EVT> getIdiomaticVectorType(SDNode *Op) {
244   unsigned OC = Op->getOpcode();
245 
246   // For memory ops -> the transfered data type
247   if (auto MemN = dyn_cast<MemSDNode>(Op))
248     return MemN->getMemoryVT();
249 
250   switch (OC) {
251   // Standard ISD.
252   case ISD::SELECT: // not aliased with VVP_SELECT
253   case ISD::CONCAT_VECTORS:
254   case ISD::EXTRACT_SUBVECTOR:
255   case ISD::VECTOR_SHUFFLE:
256   case ISD::BUILD_VECTOR:
257   case ISD::SCALAR_TO_VECTOR:
258     return Op->getValueType(0);
259   }
260 
261   // Translate to VVP where possible.
262   unsigned OriginalOC = OC;
263   if (auto VVPOpc = getVVPOpcode(OC))
264     OC = *VVPOpc;
265 
266   if (isVVPReductionOp(OC))
267     return Op->getOperand(hasReductionStartParam(OriginalOC) ? 1 : 0)
268         .getValueType();
269 
270   switch (OC) {
271   default:
272   case VEISD::VVP_SETCC:
273     return Op->getOperand(0).getValueType();
274 
275   case VEISD::VVP_SELECT:
276 #define ADD_BINARY_VVP_OP(VVP_NAME, ...) case VEISD::VVP_NAME:
277 #include "VVPNodes.def"
278     return Op->getValueType(0);
279 
280   case VEISD::VVP_LOAD:
281     return Op->getValueType(0);
282 
283   case VEISD::VVP_STORE:
284     return Op->getOperand(1)->getValueType(0);
285 
286   // VEC
287   case VEISD::VEC_BROADCAST:
288     return Op->getValueType(0);
289   }
290 }
291 
292 SDValue getLoadStoreStride(SDValue Op, VECustomDAG &CDAG) {
293   switch (Op->getOpcode()) {
294   case VEISD::VVP_STORE:
295     return Op->getOperand(3);
296   case VEISD::VVP_LOAD:
297     return Op->getOperand(2);
298   }
299 
300   if (auto *StoreN = dyn_cast<VPStridedStoreSDNode>(Op.getNode()))
301     return StoreN->getStride();
302   if (auto *StoreN = dyn_cast<VPStridedLoadSDNode>(Op.getNode()))
303     return StoreN->getStride();
304 
305   if (isa<MemSDNode>(Op.getNode())) {
306     // Regular MLOAD/MSTORE/LOAD/STORE
307     // No stride argument -> use the contiguous element size as stride.
308     uint64_t ElemStride = getIdiomaticVectorType(Op.getNode())
309                               ->getVectorElementType()
310                               .getStoreSize();
311     return CDAG.getConstant(ElemStride, MVT::i64);
312   }
313   return SDValue();
314 }
315 
316 SDValue getGatherScatterIndex(SDValue Op) {
317   if (auto *N = dyn_cast<MaskedGatherScatterSDNode>(Op.getNode()))
318     return N->getIndex();
319   if (auto *N = dyn_cast<VPGatherScatterSDNode>(Op.getNode()))
320     return N->getIndex();
321   return SDValue();
322 }
323 
324 SDValue getGatherScatterScale(SDValue Op) {
325   if (auto *N = dyn_cast<MaskedGatherScatterSDNode>(Op.getNode()))
326     return N->getScale();
327   if (auto *N = dyn_cast<VPGatherScatterSDNode>(Op.getNode()))
328     return N->getScale();
329   return SDValue();
330 }
331 
332 SDValue getStoredValue(SDValue Op) {
333   switch (Op->getOpcode()) {
334   case ISD::EXPERIMENTAL_VP_STRIDED_STORE:
335   case VEISD::VVP_STORE:
336     return Op->getOperand(1);
337   }
338   if (auto *StoreN = dyn_cast<StoreSDNode>(Op.getNode()))
339     return StoreN->getValue();
340   if (auto *StoreN = dyn_cast<MaskedStoreSDNode>(Op.getNode()))
341     return StoreN->getValue();
342   if (auto *StoreN = dyn_cast<VPStridedStoreSDNode>(Op.getNode()))
343     return StoreN->getValue();
344   if (auto *StoreN = dyn_cast<VPStoreSDNode>(Op.getNode()))
345     return StoreN->getValue();
346   if (auto *StoreN = dyn_cast<MaskedScatterSDNode>(Op.getNode()))
347     return StoreN->getValue();
348   if (auto *StoreN = dyn_cast<VPScatterSDNode>(Op.getNode()))
349     return StoreN->getValue();
350   return SDValue();
351 }
352 
353 SDValue getNodePassthru(SDValue Op) {
354   if (auto *N = dyn_cast<MaskedLoadSDNode>(Op.getNode()))
355     return N->getPassThru();
356   if (auto *N = dyn_cast<MaskedGatherSDNode>(Op.getNode()))
357     return N->getPassThru();
358 
359   return SDValue();
360 }
361 
362 bool hasReductionStartParam(unsigned OPC) {
363   // TODO: Ordered reduction opcodes.
364   if (ISD::isVPReduction(OPC))
365     return true;
366   return false;
367 }
368 
369 unsigned getScalarReductionOpcode(unsigned VVPOC, bool IsMask) {
370   assert(!IsMask && "Mask reduction isel");
371 
372   switch (VVPOC) {
373 #define HANDLE_VVP_REDUCE_TO_SCALAR(VVP_RED_ISD, REDUCE_ISD)                   \
374   case VEISD::VVP_RED_ISD:                                                     \
375     return ISD::REDUCE_ISD;
376 #include "VVPNodes.def"
377   default:
378     break;
379   }
380   llvm_unreachable("Cannot not scalarize this reduction Opcode!");
381 }
382 
383 /// } Node Properties
384 
385 SDValue getNodeAVL(SDValue Op) {
386   auto PosOpt = getAVLPos(Op->getOpcode());
387   return PosOpt ? Op->getOperand(*PosOpt) : SDValue();
388 }
389 
390 SDValue getNodeMask(SDValue Op) {
391   auto PosOpt = getMaskPos(Op->getOpcode());
392   return PosOpt ? Op->getOperand(*PosOpt) : SDValue();
393 }
394 
395 std::pair<SDValue, bool> getAnnotatedNodeAVL(SDValue Op) {
396   SDValue AVL = getNodeAVL(Op);
397   if (!AVL)
398     return {SDValue(), true};
399   if (isLegalAVL(AVL))
400     return {AVL->getOperand(0), true};
401   return {AVL, false};
402 }
403 
404 SDValue VECustomDAG::getConstant(uint64_t Val, EVT VT, bool IsTarget,
405                                  bool IsOpaque) const {
406   return DAG.getConstant(Val, DL, VT, IsTarget, IsOpaque);
407 }
408 
409 SDValue VECustomDAG::getConstantMask(Packing Packing, bool AllTrue) const {
410   auto MaskVT = getLegalVectorType(Packing, MVT::i1);
411 
412   // VEISelDAGtoDAG will replace this pattern with the constant-true VM.
413   auto TrueVal = DAG.getConstant(-1, DL, MVT::i32);
414   auto AVL = getConstant(MaskVT.getVectorNumElements(), MVT::i32);
415   auto Res = getNode(VEISD::VEC_BROADCAST, MaskVT, {TrueVal, AVL});
416   if (AllTrue)
417     return Res;
418 
419   return DAG.getNOT(DL, Res, Res.getValueType());
420 }
421 
422 SDValue VECustomDAG::getMaskBroadcast(EVT ResultVT, SDValue Scalar,
423                                       SDValue AVL) const {
424   // Constant mask splat.
425   if (auto BcConst = dyn_cast<ConstantSDNode>(Scalar))
426     return getConstantMask(getTypePacking(ResultVT),
427                            BcConst->getSExtValue() != 0);
428 
429   // Expand the broadcast to a vector comparison.
430   auto ScalarBoolVT = Scalar.getSimpleValueType();
431   assert(ScalarBoolVT == MVT::i32);
432 
433   // Cast to i32 ty.
434   SDValue CmpElem = DAG.getSExtOrTrunc(Scalar, DL, MVT::i32);
435   unsigned ElemCount = ResultVT.getVectorNumElements();
436   MVT CmpVecTy = MVT::getVectorVT(ScalarBoolVT, ElemCount);
437 
438   // Broadcast to vector.
439   SDValue BCVec =
440       DAG.getNode(VEISD::VEC_BROADCAST, DL, CmpVecTy, {CmpElem, AVL});
441   SDValue ZeroVec =
442       getBroadcast(CmpVecTy, {DAG.getConstant(0, DL, ScalarBoolVT)}, AVL);
443 
444   MVT BoolVecTy = MVT::getVectorVT(MVT::i1, ElemCount);
445 
446   // Broadcast(Data) != Broadcast(0)
447   // TODO: Use a VVP operation for this.
448   return DAG.getSetCC(DL, BoolVecTy, BCVec, ZeroVec, ISD::CondCode::SETNE);
449 }
450 
451 SDValue VECustomDAG::getBroadcast(EVT ResultVT, SDValue Scalar,
452                                   SDValue AVL) const {
453   assert(ResultVT.isVector());
454   auto ScaVT = Scalar.getValueType();
455 
456   if (isMaskType(ResultVT))
457     return getMaskBroadcast(ResultVT, Scalar, AVL);
458 
459   if (isPackedVectorType(ResultVT)) {
460     // v512x packed mode broadcast
461     // Replicate the scalar reg (f32 or i32) onto the opposing half of the full
462     // scalar register. If it's an I64 type, assume that this has already
463     // happened.
464     if (ScaVT == MVT::f32) {
465       Scalar = getNode(VEISD::REPL_F32, MVT::i64, Scalar);
466     } else if (ScaVT == MVT::i32) {
467       Scalar = getNode(VEISD::REPL_I32, MVT::i64, Scalar);
468     }
469   }
470 
471   return getNode(VEISD::VEC_BROADCAST, ResultVT, {Scalar, AVL});
472 }
473 
474 SDValue VECustomDAG::annotateLegalAVL(SDValue AVL) const {
475   if (isLegalAVL(AVL))
476     return AVL;
477   return getNode(VEISD::LEGALAVL, AVL.getValueType(), AVL);
478 }
479 
480 SDValue VECustomDAG::getUnpack(EVT DestVT, SDValue Vec, PackElem Part,
481                                SDValue AVL) const {
482   assert(getAnnotatedNodeAVL(AVL).second && "Expected a pack-legalized AVL");
483 
484   // TODO: Peek through VEC_PACK and VEC_BROADCAST(REPL_<sth> ..) operands.
485   unsigned OC =
486       (Part == PackElem::Lo) ? VEISD::VEC_UNPACK_LO : VEISD::VEC_UNPACK_HI;
487   return DAG.getNode(OC, DL, DestVT, Vec, AVL);
488 }
489 
490 SDValue VECustomDAG::getPack(EVT DestVT, SDValue LoVec, SDValue HiVec,
491                              SDValue AVL) const {
492   assert(getAnnotatedNodeAVL(AVL).second && "Expected a pack-legalized AVL");
493 
494   // TODO: Peek through VEC_UNPACK_LO|HI operands.
495   return DAG.getNode(VEISD::VEC_PACK, DL, DestVT, LoVec, HiVec, AVL);
496 }
497 
498 VETargetMasks VECustomDAG::getTargetSplitMask(SDValue RawMask, SDValue RawAVL,
499                                               PackElem Part) const {
500   // Adjust AVL for this part
501   SDValue NewAVL;
502   SDValue OneV = getConstant(1, MVT::i32);
503   if (Part == PackElem::Hi)
504     NewAVL = getNode(ISD::ADD, MVT::i32, {RawAVL, OneV});
505   else
506     NewAVL = RawAVL;
507   NewAVL = getNode(ISD::SRL, MVT::i32, {NewAVL, OneV});
508 
509   NewAVL = annotateLegalAVL(NewAVL);
510 
511   // Legalize Mask (unpack or all-true)
512   SDValue NewMask;
513   if (!RawMask)
514     NewMask = getConstantMask(Packing::Normal, true);
515   else
516     NewMask = getUnpack(MVT::v256i1, RawMask, Part, NewAVL);
517 
518   return VETargetMasks(NewMask, NewAVL);
519 }
520 
521 SDValue VECustomDAG::getSplitPtrOffset(SDValue Ptr, SDValue ByteStride,
522                                        PackElem Part) const {
523   // High starts at base ptr but has more significant bits in the 64bit vector
524   // element.
525   if (Part == PackElem::Hi)
526     return Ptr;
527   return getNode(ISD::ADD, MVT::i64, {Ptr, ByteStride});
528 }
529 
530 SDValue VECustomDAG::getSplitPtrStride(SDValue PackStride) const {
531   if (auto ConstBytes = dyn_cast<ConstantSDNode>(PackStride))
532     return getConstant(2 * ConstBytes->getSExtValue(), MVT::i64);
533   return getNode(ISD::SHL, MVT::i64, {PackStride, getConstant(1, MVT::i32)});
534 }
535 
536 SDValue VECustomDAG::getGatherScatterAddress(SDValue BasePtr, SDValue Scale,
537                                              SDValue Index, SDValue Mask,
538                                              SDValue AVL) const {
539   EVT IndexVT = Index.getValueType();
540 
541   // Apply scale.
542   SDValue ScaledIndex;
543   if (!Scale || isOneConstant(Scale))
544     ScaledIndex = Index;
545   else {
546     SDValue ScaleBroadcast = getBroadcast(IndexVT, Scale, AVL);
547     ScaledIndex =
548         getNode(VEISD::VVP_MUL, IndexVT, {Index, ScaleBroadcast, Mask, AVL});
549   }
550 
551   // Add basePtr.
552   if (isNullConstant(BasePtr))
553     return ScaledIndex;
554 
555   // re-constitute pointer vector (basePtr + index * scale)
556   SDValue BaseBroadcast = getBroadcast(IndexVT, BasePtr, AVL);
557   auto ResPtr =
558       getNode(VEISD::VVP_ADD, IndexVT, {BaseBroadcast, ScaledIndex, Mask, AVL});
559   return ResPtr;
560 }
561 
562 SDValue VECustomDAG::getLegalReductionOpVVP(unsigned VVPOpcode, EVT ResVT,
563                                             SDValue StartV, SDValue VectorV,
564                                             SDValue Mask, SDValue AVL,
565                                             SDNodeFlags Flags) const {
566 
567   // Optionally attach the start param with a scalar op (where it is
568   // unsupported).
569   bool scalarizeStartParam = StartV && !hasReductionStartParam(VVPOpcode);
570   bool IsMaskReduction = isMaskType(VectorV.getValueType());
571   assert(!IsMaskReduction && "TODO Implement");
572   auto AttachStartValue = [&](SDValue ReductionResV) {
573     if (!scalarizeStartParam)
574       return ReductionResV;
575     auto ScalarOC = getScalarReductionOpcode(VVPOpcode, IsMaskReduction);
576     return getNode(ScalarOC, ResVT, {StartV, ReductionResV});
577   };
578 
579   // Fixup: Always Use sequential 'fmul' reduction.
580   if (!scalarizeStartParam && StartV) {
581     assert(hasReductionStartParam(VVPOpcode));
582     return AttachStartValue(
583         getNode(VVPOpcode, ResVT, {StartV, VectorV, Mask, AVL}, Flags));
584   } else
585     return AttachStartValue(
586         getNode(VVPOpcode, ResVT, {VectorV, Mask, AVL}, Flags));
587 }
588 
589 } // namespace llvm
590