1 //===-- VECustomDAG.h - VE Custom DAG Nodes ------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines the interfaces that VE uses to lower LLVM code into a 10 // selection DAG. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "VECustomDAG.h" 15 16 #ifndef DEBUG_TYPE 17 #define DEBUG_TYPE "vecustomdag" 18 #endif 19 20 namespace llvm { 21 22 bool isPackedVectorType(EVT SomeVT) { 23 if (!SomeVT.isVector()) 24 return false; 25 return SomeVT.getVectorNumElements() > StandardVectorWidth; 26 } 27 28 MVT splitVectorType(MVT VT) { 29 if (!VT.isVector()) 30 return VT; 31 return MVT::getVectorVT(VT.getVectorElementType(), StandardVectorWidth); 32 } 33 34 MVT getLegalVectorType(Packing P, MVT ElemVT) { 35 return MVT::getVectorVT(ElemVT, P == Packing::Normal ? StandardVectorWidth 36 : PackedVectorWidth); 37 } 38 39 Packing getTypePacking(EVT VT) { 40 assert(VT.isVector()); 41 return isPackedVectorType(VT) ? Packing::Dense : Packing::Normal; 42 } 43 44 bool isMaskType(EVT SomeVT) { 45 if (!SomeVT.isVector()) 46 return false; 47 return SomeVT.getVectorElementType() == MVT::i1; 48 } 49 50 bool isMaskArithmetic(SDValue Op) { 51 switch (Op.getOpcode()) { 52 default: 53 return false; 54 case ISD::AND: 55 case ISD::XOR: 56 case ISD::OR: 57 return isMaskType(Op.getValueType()); 58 } 59 } 60 61 /// \returns the VVP_* SDNode opcode corresponsing to \p OC. 62 std::optional<unsigned> getVVPOpcode(unsigned Opcode) { 63 switch (Opcode) { 64 case ISD::MLOAD: 65 return VEISD::VVP_LOAD; 66 case ISD::MSTORE: 67 return VEISD::VVP_STORE; 68 #define HANDLE_VP_TO_VVP(VPOPC, VVPNAME) \ 69 case ISD::VPOPC: \ 70 return VEISD::VVPNAME; 71 #define ADD_VVP_OP(VVPNAME, SDNAME) \ 72 case VEISD::VVPNAME: \ 73 case ISD::SDNAME: \ 74 return VEISD::VVPNAME; 75 #include "VVPNodes.def" 76 // TODO: Map those in VVPNodes.def too 77 case ISD::EXPERIMENTAL_VP_STRIDED_LOAD: 78 return VEISD::VVP_LOAD; 79 case ISD::EXPERIMENTAL_VP_STRIDED_STORE: 80 return VEISD::VVP_STORE; 81 } 82 return std::nullopt; 83 } 84 85 bool maySafelyIgnoreMask(SDValue Op) { 86 auto VVPOpc = getVVPOpcode(Op->getOpcode()); 87 auto Opc = VVPOpc.value_or(Op->getOpcode()); 88 89 switch (Opc) { 90 case VEISD::VVP_SDIV: 91 case VEISD::VVP_UDIV: 92 case VEISD::VVP_FDIV: 93 case VEISD::VVP_SELECT: 94 return false; 95 96 default: 97 return true; 98 } 99 } 100 101 bool supportsPackedMode(unsigned Opcode, EVT IdiomVT) { 102 bool IsPackedOp = isPackedVectorType(IdiomVT); 103 bool IsMaskOp = isMaskType(IdiomVT); 104 switch (Opcode) { 105 default: 106 return false; 107 108 case VEISD::VEC_BROADCAST: 109 return true; 110 #define REGISTER_PACKED(VVP_NAME) case VEISD::VVP_NAME: 111 #include "VVPNodes.def" 112 return IsPackedOp && !IsMaskOp; 113 } 114 } 115 116 bool isPackingSupportOpcode(unsigned Opc) { 117 switch (Opc) { 118 case VEISD::VEC_PACK: 119 case VEISD::VEC_UNPACK_LO: 120 case VEISD::VEC_UNPACK_HI: 121 return true; 122 } 123 return false; 124 } 125 126 bool isVVPOrVEC(unsigned Opcode) { 127 switch (Opcode) { 128 case VEISD::VEC_BROADCAST: 129 #define ADD_VVP_OP(VVPNAME, ...) case VEISD::VVPNAME: 130 #include "VVPNodes.def" 131 return true; 132 } 133 return false; 134 } 135 136 bool isVVPUnaryOp(unsigned VVPOpcode) { 137 switch (VVPOpcode) { 138 #define ADD_UNARY_VVP_OP(VVPNAME, ...) \ 139 case VEISD::VVPNAME: \ 140 return true; 141 #include "VVPNodes.def" 142 } 143 return false; 144 } 145 146 bool isVVPBinaryOp(unsigned VVPOpcode) { 147 switch (VVPOpcode) { 148 #define ADD_BINARY_VVP_OP(VVPNAME, ...) \ 149 case VEISD::VVPNAME: \ 150 return true; 151 #include "VVPNodes.def" 152 } 153 return false; 154 } 155 156 bool isVVPReductionOp(unsigned Opcode) { 157 switch (Opcode) { 158 #define ADD_REDUCE_VVP_OP(VVP_NAME, SDNAME) case VEISD::VVP_NAME: 159 #include "VVPNodes.def" 160 return true; 161 } 162 return false; 163 } 164 165 // Return the AVL operand position for this VVP or VEC Op. 166 std::optional<int> getAVLPos(unsigned Opc) { 167 // This is only available for VP SDNodes 168 auto PosOpt = ISD::getVPExplicitVectorLengthIdx(Opc); 169 if (PosOpt) 170 return *PosOpt; 171 172 // VVP Opcodes. 173 if (isVVPBinaryOp(Opc)) 174 return 3; 175 176 // VM Opcodes. 177 switch (Opc) { 178 case VEISD::VEC_BROADCAST: 179 return 1; 180 case VEISD::VVP_SELECT: 181 return 3; 182 case VEISD::VVP_LOAD: 183 return 4; 184 case VEISD::VVP_STORE: 185 return 5; 186 } 187 188 return std::nullopt; 189 } 190 191 std::optional<int> getMaskPos(unsigned Opc) { 192 // This is only available for VP SDNodes 193 auto PosOpt = ISD::getVPMaskIdx(Opc); 194 if (PosOpt) 195 return *PosOpt; 196 197 // VVP Opcodes. 198 if (isVVPBinaryOp(Opc)) 199 return 2; 200 201 // Other opcodes. 202 switch (Opc) { 203 case ISD::MSTORE: 204 return 4; 205 case ISD::MLOAD: 206 return 3; 207 case VEISD::VVP_SELECT: 208 return 2; 209 } 210 211 return std::nullopt; 212 } 213 214 bool isLegalAVL(SDValue AVL) { return AVL->getOpcode() == VEISD::LEGALAVL; } 215 216 /// Node Properties { 217 218 SDValue getNodeChain(SDValue Op) { 219 if (MemSDNode *MemN = dyn_cast<MemSDNode>(Op.getNode())) 220 return MemN->getChain(); 221 222 switch (Op->getOpcode()) { 223 case VEISD::VVP_LOAD: 224 case VEISD::VVP_STORE: 225 return Op->getOperand(0); 226 } 227 return SDValue(); 228 } 229 230 SDValue getMemoryPtr(SDValue Op) { 231 if (auto *MemN = dyn_cast<MemSDNode>(Op.getNode())) 232 return MemN->getBasePtr(); 233 234 switch (Op->getOpcode()) { 235 case VEISD::VVP_LOAD: 236 return Op->getOperand(1); 237 case VEISD::VVP_STORE: 238 return Op->getOperand(2); 239 } 240 return SDValue(); 241 } 242 243 std::optional<EVT> getIdiomaticVectorType(SDNode *Op) { 244 unsigned OC = Op->getOpcode(); 245 246 // For memory ops -> the transfered data type 247 if (auto MemN = dyn_cast<MemSDNode>(Op)) 248 return MemN->getMemoryVT(); 249 250 switch (OC) { 251 // Standard ISD. 252 case ISD::SELECT: // not aliased with VVP_SELECT 253 case ISD::CONCAT_VECTORS: 254 case ISD::EXTRACT_SUBVECTOR: 255 case ISD::VECTOR_SHUFFLE: 256 case ISD::BUILD_VECTOR: 257 case ISD::SCALAR_TO_VECTOR: 258 return Op->getValueType(0); 259 } 260 261 // Translate to VVP where possible. 262 unsigned OriginalOC = OC; 263 if (auto VVPOpc = getVVPOpcode(OC)) 264 OC = *VVPOpc; 265 266 if (isVVPReductionOp(OC)) 267 return Op->getOperand(hasReductionStartParam(OriginalOC) ? 1 : 0) 268 .getValueType(); 269 270 switch (OC) { 271 default: 272 case VEISD::VVP_SETCC: 273 return Op->getOperand(0).getValueType(); 274 275 case VEISD::VVP_SELECT: 276 #define ADD_BINARY_VVP_OP(VVP_NAME, ...) case VEISD::VVP_NAME: 277 #include "VVPNodes.def" 278 return Op->getValueType(0); 279 280 case VEISD::VVP_LOAD: 281 return Op->getValueType(0); 282 283 case VEISD::VVP_STORE: 284 return Op->getOperand(1)->getValueType(0); 285 286 // VEC 287 case VEISD::VEC_BROADCAST: 288 return Op->getValueType(0); 289 } 290 } 291 292 SDValue getLoadStoreStride(SDValue Op, VECustomDAG &CDAG) { 293 switch (Op->getOpcode()) { 294 case VEISD::VVP_STORE: 295 return Op->getOperand(3); 296 case VEISD::VVP_LOAD: 297 return Op->getOperand(2); 298 } 299 300 if (auto *StoreN = dyn_cast<VPStridedStoreSDNode>(Op.getNode())) 301 return StoreN->getStride(); 302 if (auto *StoreN = dyn_cast<VPStridedLoadSDNode>(Op.getNode())) 303 return StoreN->getStride(); 304 305 if (isa<MemSDNode>(Op.getNode())) { 306 // Regular MLOAD/MSTORE/LOAD/STORE 307 // No stride argument -> use the contiguous element size as stride. 308 uint64_t ElemStride = getIdiomaticVectorType(Op.getNode()) 309 ->getVectorElementType() 310 .getStoreSize(); 311 return CDAG.getConstant(ElemStride, MVT::i64); 312 } 313 return SDValue(); 314 } 315 316 SDValue getGatherScatterIndex(SDValue Op) { 317 if (auto *N = dyn_cast<MaskedGatherScatterSDNode>(Op.getNode())) 318 return N->getIndex(); 319 if (auto *N = dyn_cast<VPGatherScatterSDNode>(Op.getNode())) 320 return N->getIndex(); 321 return SDValue(); 322 } 323 324 SDValue getGatherScatterScale(SDValue Op) { 325 if (auto *N = dyn_cast<MaskedGatherScatterSDNode>(Op.getNode())) 326 return N->getScale(); 327 if (auto *N = dyn_cast<VPGatherScatterSDNode>(Op.getNode())) 328 return N->getScale(); 329 return SDValue(); 330 } 331 332 SDValue getStoredValue(SDValue Op) { 333 switch (Op->getOpcode()) { 334 case ISD::EXPERIMENTAL_VP_STRIDED_STORE: 335 case VEISD::VVP_STORE: 336 return Op->getOperand(1); 337 } 338 if (auto *StoreN = dyn_cast<StoreSDNode>(Op.getNode())) 339 return StoreN->getValue(); 340 if (auto *StoreN = dyn_cast<MaskedStoreSDNode>(Op.getNode())) 341 return StoreN->getValue(); 342 if (auto *StoreN = dyn_cast<VPStridedStoreSDNode>(Op.getNode())) 343 return StoreN->getValue(); 344 if (auto *StoreN = dyn_cast<VPStoreSDNode>(Op.getNode())) 345 return StoreN->getValue(); 346 if (auto *StoreN = dyn_cast<MaskedScatterSDNode>(Op.getNode())) 347 return StoreN->getValue(); 348 if (auto *StoreN = dyn_cast<VPScatterSDNode>(Op.getNode())) 349 return StoreN->getValue(); 350 return SDValue(); 351 } 352 353 SDValue getNodePassthru(SDValue Op) { 354 if (auto *N = dyn_cast<MaskedLoadSDNode>(Op.getNode())) 355 return N->getPassThru(); 356 if (auto *N = dyn_cast<MaskedGatherSDNode>(Op.getNode())) 357 return N->getPassThru(); 358 359 return SDValue(); 360 } 361 362 bool hasReductionStartParam(unsigned OPC) { 363 // TODO: Ordered reduction opcodes. 364 if (ISD::isVPReduction(OPC)) 365 return true; 366 return false; 367 } 368 369 unsigned getScalarReductionOpcode(unsigned VVPOC, bool IsMask) { 370 assert(!IsMask && "Mask reduction isel"); 371 372 switch (VVPOC) { 373 #define HANDLE_VVP_REDUCE_TO_SCALAR(VVP_RED_ISD, REDUCE_ISD) \ 374 case VEISD::VVP_RED_ISD: \ 375 return ISD::REDUCE_ISD; 376 #include "VVPNodes.def" 377 default: 378 break; 379 } 380 llvm_unreachable("Cannot not scalarize this reduction Opcode!"); 381 } 382 383 /// } Node Properties 384 385 SDValue getNodeAVL(SDValue Op) { 386 auto PosOpt = getAVLPos(Op->getOpcode()); 387 return PosOpt ? Op->getOperand(*PosOpt) : SDValue(); 388 } 389 390 SDValue getNodeMask(SDValue Op) { 391 auto PosOpt = getMaskPos(Op->getOpcode()); 392 return PosOpt ? Op->getOperand(*PosOpt) : SDValue(); 393 } 394 395 std::pair<SDValue, bool> getAnnotatedNodeAVL(SDValue Op) { 396 SDValue AVL = getNodeAVL(Op); 397 if (!AVL) 398 return {SDValue(), true}; 399 if (isLegalAVL(AVL)) 400 return {AVL->getOperand(0), true}; 401 return {AVL, false}; 402 } 403 404 SDValue VECustomDAG::getConstant(uint64_t Val, EVT VT, bool IsTarget, 405 bool IsOpaque) const { 406 return DAG.getConstant(Val, DL, VT, IsTarget, IsOpaque); 407 } 408 409 SDValue VECustomDAG::getConstantMask(Packing Packing, bool AllTrue) const { 410 auto MaskVT = getLegalVectorType(Packing, MVT::i1); 411 412 // VEISelDAGtoDAG will replace this pattern with the constant-true VM. 413 auto TrueVal = DAG.getConstant(-1, DL, MVT::i32); 414 auto AVL = getConstant(MaskVT.getVectorNumElements(), MVT::i32); 415 auto Res = getNode(VEISD::VEC_BROADCAST, MaskVT, {TrueVal, AVL}); 416 if (AllTrue) 417 return Res; 418 419 return DAG.getNOT(DL, Res, Res.getValueType()); 420 } 421 422 SDValue VECustomDAG::getMaskBroadcast(EVT ResultVT, SDValue Scalar, 423 SDValue AVL) const { 424 // Constant mask splat. 425 if (auto BcConst = dyn_cast<ConstantSDNode>(Scalar)) 426 return getConstantMask(getTypePacking(ResultVT), 427 BcConst->getSExtValue() != 0); 428 429 // Expand the broadcast to a vector comparison. 430 auto ScalarBoolVT = Scalar.getSimpleValueType(); 431 assert(ScalarBoolVT == MVT::i32); 432 433 // Cast to i32 ty. 434 SDValue CmpElem = DAG.getSExtOrTrunc(Scalar, DL, MVT::i32); 435 unsigned ElemCount = ResultVT.getVectorNumElements(); 436 MVT CmpVecTy = MVT::getVectorVT(ScalarBoolVT, ElemCount); 437 438 // Broadcast to vector. 439 SDValue BCVec = 440 DAG.getNode(VEISD::VEC_BROADCAST, DL, CmpVecTy, {CmpElem, AVL}); 441 SDValue ZeroVec = 442 getBroadcast(CmpVecTy, {DAG.getConstant(0, DL, ScalarBoolVT)}, AVL); 443 444 MVT BoolVecTy = MVT::getVectorVT(MVT::i1, ElemCount); 445 446 // Broadcast(Data) != Broadcast(0) 447 // TODO: Use a VVP operation for this. 448 return DAG.getSetCC(DL, BoolVecTy, BCVec, ZeroVec, ISD::CondCode::SETNE); 449 } 450 451 SDValue VECustomDAG::getBroadcast(EVT ResultVT, SDValue Scalar, 452 SDValue AVL) const { 453 assert(ResultVT.isVector()); 454 auto ScaVT = Scalar.getValueType(); 455 456 if (isMaskType(ResultVT)) 457 return getMaskBroadcast(ResultVT, Scalar, AVL); 458 459 if (isPackedVectorType(ResultVT)) { 460 // v512x packed mode broadcast 461 // Replicate the scalar reg (f32 or i32) onto the opposing half of the full 462 // scalar register. If it's an I64 type, assume that this has already 463 // happened. 464 if (ScaVT == MVT::f32) { 465 Scalar = getNode(VEISD::REPL_F32, MVT::i64, Scalar); 466 } else if (ScaVT == MVT::i32) { 467 Scalar = getNode(VEISD::REPL_I32, MVT::i64, Scalar); 468 } 469 } 470 471 return getNode(VEISD::VEC_BROADCAST, ResultVT, {Scalar, AVL}); 472 } 473 474 SDValue VECustomDAG::annotateLegalAVL(SDValue AVL) const { 475 if (isLegalAVL(AVL)) 476 return AVL; 477 return getNode(VEISD::LEGALAVL, AVL.getValueType(), AVL); 478 } 479 480 SDValue VECustomDAG::getUnpack(EVT DestVT, SDValue Vec, PackElem Part, 481 SDValue AVL) const { 482 assert(getAnnotatedNodeAVL(AVL).second && "Expected a pack-legalized AVL"); 483 484 // TODO: Peek through VEC_PACK and VEC_BROADCAST(REPL_<sth> ..) operands. 485 unsigned OC = 486 (Part == PackElem::Lo) ? VEISD::VEC_UNPACK_LO : VEISD::VEC_UNPACK_HI; 487 return DAG.getNode(OC, DL, DestVT, Vec, AVL); 488 } 489 490 SDValue VECustomDAG::getPack(EVT DestVT, SDValue LoVec, SDValue HiVec, 491 SDValue AVL) const { 492 assert(getAnnotatedNodeAVL(AVL).second && "Expected a pack-legalized AVL"); 493 494 // TODO: Peek through VEC_UNPACK_LO|HI operands. 495 return DAG.getNode(VEISD::VEC_PACK, DL, DestVT, LoVec, HiVec, AVL); 496 } 497 498 VETargetMasks VECustomDAG::getTargetSplitMask(SDValue RawMask, SDValue RawAVL, 499 PackElem Part) const { 500 // Adjust AVL for this part 501 SDValue NewAVL; 502 SDValue OneV = getConstant(1, MVT::i32); 503 if (Part == PackElem::Hi) 504 NewAVL = getNode(ISD::ADD, MVT::i32, {RawAVL, OneV}); 505 else 506 NewAVL = RawAVL; 507 NewAVL = getNode(ISD::SRL, MVT::i32, {NewAVL, OneV}); 508 509 NewAVL = annotateLegalAVL(NewAVL); 510 511 // Legalize Mask (unpack or all-true) 512 SDValue NewMask; 513 if (!RawMask) 514 NewMask = getConstantMask(Packing::Normal, true); 515 else 516 NewMask = getUnpack(MVT::v256i1, RawMask, Part, NewAVL); 517 518 return VETargetMasks(NewMask, NewAVL); 519 } 520 521 SDValue VECustomDAG::getSplitPtrOffset(SDValue Ptr, SDValue ByteStride, 522 PackElem Part) const { 523 // High starts at base ptr but has more significant bits in the 64bit vector 524 // element. 525 if (Part == PackElem::Hi) 526 return Ptr; 527 return getNode(ISD::ADD, MVT::i64, {Ptr, ByteStride}); 528 } 529 530 SDValue VECustomDAG::getSplitPtrStride(SDValue PackStride) const { 531 if (auto ConstBytes = dyn_cast<ConstantSDNode>(PackStride)) 532 return getConstant(2 * ConstBytes->getSExtValue(), MVT::i64); 533 return getNode(ISD::SHL, MVT::i64, {PackStride, getConstant(1, MVT::i32)}); 534 } 535 536 SDValue VECustomDAG::getGatherScatterAddress(SDValue BasePtr, SDValue Scale, 537 SDValue Index, SDValue Mask, 538 SDValue AVL) const { 539 EVT IndexVT = Index.getValueType(); 540 541 // Apply scale. 542 SDValue ScaledIndex; 543 if (!Scale || isOneConstant(Scale)) 544 ScaledIndex = Index; 545 else { 546 SDValue ScaleBroadcast = getBroadcast(IndexVT, Scale, AVL); 547 ScaledIndex = 548 getNode(VEISD::VVP_MUL, IndexVT, {Index, ScaleBroadcast, Mask, AVL}); 549 } 550 551 // Add basePtr. 552 if (isNullConstant(BasePtr)) 553 return ScaledIndex; 554 555 // re-constitute pointer vector (basePtr + index * scale) 556 SDValue BaseBroadcast = getBroadcast(IndexVT, BasePtr, AVL); 557 auto ResPtr = 558 getNode(VEISD::VVP_ADD, IndexVT, {BaseBroadcast, ScaledIndex, Mask, AVL}); 559 return ResPtr; 560 } 561 562 SDValue VECustomDAG::getLegalReductionOpVVP(unsigned VVPOpcode, EVT ResVT, 563 SDValue StartV, SDValue VectorV, 564 SDValue Mask, SDValue AVL, 565 SDNodeFlags Flags) const { 566 567 // Optionally attach the start param with a scalar op (where it is 568 // unsupported). 569 bool scalarizeStartParam = StartV && !hasReductionStartParam(VVPOpcode); 570 bool IsMaskReduction = isMaskType(VectorV.getValueType()); 571 assert(!IsMaskReduction && "TODO Implement"); 572 auto AttachStartValue = [&](SDValue ReductionResV) { 573 if (!scalarizeStartParam) 574 return ReductionResV; 575 auto ScalarOC = getScalarReductionOpcode(VVPOpcode, IsMaskReduction); 576 return getNode(ScalarOC, ResVT, {StartV, ReductionResV}); 577 }; 578 579 // Fixup: Always Use sequential 'fmul' reduction. 580 if (!scalarizeStartParam && StartV) { 581 assert(hasReductionStartParam(VVPOpcode)); 582 return AttachStartValue( 583 getNode(VVPOpcode, ResVT, {StartV, VectorV, Mask, AVL}, Flags)); 584 } else 585 return AttachStartValue( 586 getNode(VVPOpcode, ResVT, {VectorV, Mask, AVL}, Flags)); 587 } 588 589 } // namespace llvm 590