1 //===-- VVPISelLowering.cpp - VE DAG Lowering Implementation --------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the lowering and legalization of vector instructions to 10 // VVP_*layer SDNodes. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "VECustomDAG.h" 15 #include "VEISelLowering.h" 16 17 using namespace llvm; 18 19 #define DEBUG_TYPE "ve-lower" 20 21 SDValue VETargetLowering::splitMaskArithmetic(SDValue Op, 22 SelectionDAG &DAG) const { 23 VECustomDAG CDAG(DAG, Op); 24 SDValue AVL = 25 CDAG.getConstant(Op.getValueType().getVectorNumElements(), MVT::i32); 26 SDValue A = Op->getOperand(0); 27 SDValue B = Op->getOperand(1); 28 SDValue LoA = CDAG.getUnpack(MVT::v256i1, A, PackElem::Lo, AVL); 29 SDValue HiA = CDAG.getUnpack(MVT::v256i1, A, PackElem::Hi, AVL); 30 SDValue LoB = CDAG.getUnpack(MVT::v256i1, B, PackElem::Lo, AVL); 31 SDValue HiB = CDAG.getUnpack(MVT::v256i1, B, PackElem::Hi, AVL); 32 unsigned Opc = Op.getOpcode(); 33 auto LoRes = CDAG.getNode(Opc, MVT::v256i1, {LoA, LoB}); 34 auto HiRes = CDAG.getNode(Opc, MVT::v256i1, {HiA, HiB}); 35 return CDAG.getPack(MVT::v512i1, LoRes, HiRes, AVL); 36 } 37 38 SDValue VETargetLowering::lowerToVVP(SDValue Op, SelectionDAG &DAG) const { 39 // Can we represent this as a VVP node. 40 const unsigned Opcode = Op->getOpcode(); 41 auto VVPOpcodeOpt = getVVPOpcode(Opcode); 42 if (!VVPOpcodeOpt) 43 return SDValue(); 44 unsigned VVPOpcode = *VVPOpcodeOpt; 45 const bool FromVP = ISD::isVPOpcode(Opcode); 46 47 // The representative and legalized vector type of this operation. 48 VECustomDAG CDAG(DAG, Op); 49 // Dispatch to complex lowering functions. 50 switch (VVPOpcode) { 51 case VEISD::VVP_LOAD: 52 case VEISD::VVP_STORE: 53 return lowerVVP_LOAD_STORE(Op, CDAG); 54 case VEISD::VVP_GATHER: 55 case VEISD::VVP_SCATTER: 56 return lowerVVP_GATHER_SCATTER(Op, CDAG); 57 } 58 59 EVT OpVecVT = *getIdiomaticVectorType(Op.getNode()); 60 EVT LegalVecVT = getTypeToTransformTo(*DAG.getContext(), OpVecVT); 61 auto Packing = getTypePacking(LegalVecVT.getSimpleVT()); 62 63 SDValue AVL; 64 SDValue Mask; 65 66 if (FromVP) { 67 // All upstream VP SDNodes always have a mask and avl. 68 auto MaskIdx = ISD::getVPMaskIdx(Opcode); 69 auto AVLIdx = ISD::getVPExplicitVectorLengthIdx(Opcode); 70 if (MaskIdx) 71 Mask = Op->getOperand(*MaskIdx); 72 if (AVLIdx) 73 AVL = Op->getOperand(*AVLIdx); 74 } 75 76 // Materialize default mask and avl. 77 if (!AVL) 78 AVL = CDAG.getConstant(OpVecVT.getVectorNumElements(), MVT::i32); 79 if (!Mask) 80 Mask = CDAG.getConstantMask(Packing, true); 81 82 assert(LegalVecVT.isSimple()); 83 if (isVVPUnaryOp(VVPOpcode)) 84 return CDAG.getNode(VVPOpcode, LegalVecVT, {Op->getOperand(0), Mask, AVL}); 85 if (isVVPBinaryOp(VVPOpcode)) 86 return CDAG.getNode(VVPOpcode, LegalVecVT, 87 {Op->getOperand(0), Op->getOperand(1), Mask, AVL}); 88 if (isVVPReductionOp(VVPOpcode)) { 89 auto SrcHasStart = hasReductionStartParam(Op->getOpcode()); 90 SDValue StartV = SrcHasStart ? Op->getOperand(0) : SDValue(); 91 SDValue VectorV = Op->getOperand(SrcHasStart ? 1 : 0); 92 return CDAG.getLegalReductionOpVVP(VVPOpcode, Op.getValueType(), StartV, 93 VectorV, Mask, AVL, Op->getFlags()); 94 } 95 96 switch (VVPOpcode) { 97 default: 98 llvm_unreachable("lowerToVVP called for unexpected SDNode."); 99 case VEISD::VVP_FFMA: { 100 // VE has a swizzled operand order in FMA (compared to LLVM IR and 101 // SDNodes). 102 auto X = Op->getOperand(2); 103 auto Y = Op->getOperand(0); 104 auto Z = Op->getOperand(1); 105 return CDAG.getNode(VVPOpcode, LegalVecVT, {X, Y, Z, Mask, AVL}); 106 } 107 case VEISD::VVP_SELECT: { 108 auto Mask = Op->getOperand(0); 109 auto OnTrue = Op->getOperand(1); 110 auto OnFalse = Op->getOperand(2); 111 return CDAG.getNode(VVPOpcode, LegalVecVT, {OnTrue, OnFalse, Mask, AVL}); 112 } 113 case VEISD::VVP_SETCC: { 114 EVT LegalResVT = getTypeToTransformTo(*DAG.getContext(), Op.getValueType()); 115 auto LHS = Op->getOperand(0); 116 auto RHS = Op->getOperand(1); 117 auto Pred = Op->getOperand(2); 118 return CDAG.getNode(VVPOpcode, LegalResVT, {LHS, RHS, Pred, Mask, AVL}); 119 } 120 } 121 } 122 123 SDValue VETargetLowering::lowerVVP_LOAD_STORE(SDValue Op, 124 VECustomDAG &CDAG) const { 125 auto VVPOpc = *getVVPOpcode(Op->getOpcode()); 126 const bool IsLoad = (VVPOpc == VEISD::VVP_LOAD); 127 128 // Shares. 129 SDValue BasePtr = getMemoryPtr(Op); 130 SDValue Mask = getNodeMask(Op); 131 SDValue Chain = getNodeChain(Op); 132 SDValue AVL = getNodeAVL(Op); 133 // Store specific. 134 SDValue Data = getStoredValue(Op); 135 // Load specific. 136 SDValue PassThru = getNodePassthru(Op); 137 138 SDValue StrideV = getLoadStoreStride(Op, CDAG); 139 140 auto DataVT = *getIdiomaticVectorType(Op.getNode()); 141 auto Packing = getTypePacking(DataVT); 142 143 // TODO: Infer lower AVL from mask. 144 if (!AVL) 145 AVL = CDAG.getConstant(DataVT.getVectorNumElements(), MVT::i32); 146 147 // Default to the all-true mask. 148 if (!Mask) 149 Mask = CDAG.getConstantMask(Packing, true); 150 151 if (IsLoad) { 152 MVT LegalDataVT = getLegalVectorType( 153 Packing, DataVT.getVectorElementType().getSimpleVT()); 154 155 auto NewLoadV = CDAG.getNode(VEISD::VVP_LOAD, {LegalDataVT, MVT::Other}, 156 {Chain, BasePtr, StrideV, Mask, AVL}); 157 158 if (!PassThru || PassThru->isUndef()) 159 return NewLoadV; 160 161 // Convert passthru to an explicit select node. 162 SDValue DataV = CDAG.getNode(VEISD::VVP_SELECT, DataVT, 163 {NewLoadV, PassThru, Mask, AVL}); 164 SDValue NewLoadChainV = SDValue(NewLoadV.getNode(), 1); 165 166 // Merge them back into one node. 167 return CDAG.getMergeValues({DataV, NewLoadChainV}); 168 } 169 170 // VVP_STORE 171 assert(VVPOpc == VEISD::VVP_STORE); 172 return CDAG.getNode(VEISD::VVP_STORE, Op.getNode()->getVTList(), 173 {Chain, Data, BasePtr, StrideV, Mask, AVL}); 174 } 175 176 SDValue VETargetLowering::splitPackedLoadStore(SDValue Op, 177 VECustomDAG &CDAG) const { 178 auto VVPOC = *getVVPOpcode(Op.getOpcode()); 179 assert((VVPOC == VEISD::VVP_LOAD) || (VVPOC == VEISD::VVP_STORE)); 180 181 MVT DataVT = getIdiomaticVectorType(Op.getNode())->getSimpleVT(); 182 assert(getTypePacking(DataVT) == Packing::Dense && 183 "Can only split packed load/store"); 184 MVT SplitDataVT = splitVectorType(DataVT); 185 186 assert(!getNodePassthru(Op) && 187 "Should have been folded in lowering to VVP layer"); 188 189 // Analyze the operation 190 SDValue PackedMask = getNodeMask(Op); 191 SDValue PackedAVL = getAnnotatedNodeAVL(Op).first; 192 SDValue PackPtr = getMemoryPtr(Op); 193 SDValue PackData = getStoredValue(Op); 194 SDValue PackStride = getLoadStoreStride(Op, CDAG); 195 196 unsigned ChainResIdx = PackData ? 0 : 1; 197 198 SDValue PartOps[2]; 199 200 SDValue UpperPartAVL; // we will use this for packing things back together 201 for (PackElem Part : {PackElem::Hi, PackElem::Lo}) { 202 // VP ops already have an explicit mask and AVL. When expanding from non-VP 203 // attach those additional inputs here. 204 auto SplitTM = CDAG.getTargetSplitMask(PackedMask, PackedAVL, Part); 205 206 // Keep track of the (higher) lvl. 207 if (Part == PackElem::Hi) 208 UpperPartAVL = SplitTM.AVL; 209 210 // Attach non-predicating value operands 211 SmallVector<SDValue, 4> OpVec; 212 213 // Chain 214 OpVec.push_back(getNodeChain(Op)); 215 216 // Data 217 if (PackData) { 218 SDValue PartData = 219 CDAG.getUnpack(SplitDataVT, PackData, Part, SplitTM.AVL); 220 OpVec.push_back(PartData); 221 } 222 223 // Ptr & Stride 224 // Push (ptr + ElemBytes * <Part>, 2 * ElemBytes) 225 // Stride info 226 // EVT DataVT = LegalizeVectorType(getMemoryDataVT(Op), Op, DAG, Mode); 227 OpVec.push_back(CDAG.getSplitPtrOffset(PackPtr, PackStride, Part)); 228 OpVec.push_back(CDAG.getSplitPtrStride(PackStride)); 229 230 // Add predicating args and generate part node 231 OpVec.push_back(SplitTM.Mask); 232 OpVec.push_back(SplitTM.AVL); 233 234 if (PackData) { 235 // Store 236 PartOps[(int)Part] = CDAG.getNode(VVPOC, MVT::Other, OpVec); 237 } else { 238 // Load 239 PartOps[(int)Part] = 240 CDAG.getNode(VVPOC, {SplitDataVT, MVT::Other}, OpVec); 241 } 242 } 243 244 // Merge the chains 245 SDValue LowChain = SDValue(PartOps[(int)PackElem::Lo].getNode(), ChainResIdx); 246 SDValue HiChain = SDValue(PartOps[(int)PackElem::Hi].getNode(), ChainResIdx); 247 SDValue FusedChains = 248 CDAG.getNode(ISD::TokenFactor, MVT::Other, {LowChain, HiChain}); 249 250 // Chain only [store] 251 if (PackData) 252 return FusedChains; 253 254 // Re-pack into full packed vector result 255 MVT PackedVT = 256 getLegalVectorType(Packing::Dense, DataVT.getVectorElementType()); 257 SDValue PackedVals = CDAG.getPack(PackedVT, PartOps[(int)PackElem::Lo], 258 PartOps[(int)PackElem::Hi], UpperPartAVL); 259 260 return CDAG.getMergeValues({PackedVals, FusedChains}); 261 } 262 263 SDValue VETargetLowering::lowerVVP_GATHER_SCATTER(SDValue Op, 264 VECustomDAG &CDAG) const { 265 EVT DataVT = *getIdiomaticVectorType(Op.getNode()); 266 auto Packing = getTypePacking(DataVT); 267 MVT LegalDataVT = 268 getLegalVectorType(Packing, DataVT.getVectorElementType().getSimpleVT()); 269 270 SDValue AVL = getAnnotatedNodeAVL(Op).first; 271 SDValue Index = getGatherScatterIndex(Op); 272 SDValue BasePtr = getMemoryPtr(Op); 273 SDValue Mask = getNodeMask(Op); 274 SDValue Chain = getNodeChain(Op); 275 SDValue Scale = getGatherScatterScale(Op); 276 SDValue PassThru = getNodePassthru(Op); 277 SDValue StoredValue = getStoredValue(Op); 278 if (PassThru && PassThru->isUndef()) 279 PassThru = SDValue(); 280 281 bool IsScatter = (bool)StoredValue; 282 283 // TODO: Infer lower AVL from mask. 284 if (!AVL) 285 AVL = CDAG.getConstant(DataVT.getVectorNumElements(), MVT::i32); 286 287 // Default to the all-true mask. 288 if (!Mask) 289 Mask = CDAG.getConstantMask(Packing, true); 290 291 SDValue AddressVec = 292 CDAG.getGatherScatterAddress(BasePtr, Scale, Index, Mask, AVL); 293 if (IsScatter) 294 return CDAG.getNode(VEISD::VVP_SCATTER, MVT::Other, 295 {Chain, StoredValue, AddressVec, Mask, AVL}); 296 297 // Gather. 298 SDValue NewLoadV = CDAG.getNode(VEISD::VVP_GATHER, {LegalDataVT, MVT::Other}, 299 {Chain, AddressVec, Mask, AVL}); 300 301 if (!PassThru) 302 return NewLoadV; 303 304 // TODO: Use vvp_select 305 SDValue DataV = CDAG.getNode(VEISD::VVP_SELECT, LegalDataVT, 306 {NewLoadV, PassThru, Mask, AVL}); 307 SDValue NewLoadChainV = SDValue(NewLoadV.getNode(), 1); 308 return CDAG.getMergeValues({DataV, NewLoadChainV}); 309 } 310 311 SDValue VETargetLowering::legalizeInternalLoadStoreOp(SDValue Op, 312 VECustomDAG &CDAG) const { 313 LLVM_DEBUG(dbgs() << "::legalizeInternalLoadStoreOp\n";); 314 MVT DataVT = getIdiomaticVectorType(Op.getNode())->getSimpleVT(); 315 316 // TODO: Recognize packable load,store. 317 if (isPackedVectorType(DataVT)) 318 return splitPackedLoadStore(Op, CDAG); 319 320 return legalizePackedAVL(Op, CDAG); 321 } 322 323 SDValue VETargetLowering::legalizeInternalVectorOp(SDValue Op, 324 SelectionDAG &DAG) const { 325 LLVM_DEBUG(dbgs() << "::legalizeInternalVectorOp\n";); 326 VECustomDAG CDAG(DAG, Op); 327 328 // Dispatch to specialized legalization functions. 329 switch (Op->getOpcode()) { 330 case VEISD::VVP_LOAD: 331 case VEISD::VVP_STORE: 332 return legalizeInternalLoadStoreOp(Op, CDAG); 333 } 334 335 EVT IdiomVT = Op.getValueType(); 336 if (isPackedVectorType(IdiomVT) && 337 !supportsPackedMode(Op.getOpcode(), IdiomVT)) 338 return splitVectorOp(Op, CDAG); 339 340 // TODO: Implement odd/even splitting. 341 return legalizePackedAVL(Op, CDAG); 342 } 343 344 SDValue VETargetLowering::splitVectorOp(SDValue Op, VECustomDAG &CDAG) const { 345 MVT ResVT = splitVectorType(Op.getValue(0).getSimpleValueType()); 346 347 auto AVLPos = getAVLPos(Op->getOpcode()); 348 auto MaskPos = getMaskPos(Op->getOpcode()); 349 350 SDValue PackedMask = getNodeMask(Op); 351 auto AVLPair = getAnnotatedNodeAVL(Op); 352 SDValue PackedAVL = AVLPair.first; 353 assert(!AVLPair.second && "Expecting non pack-legalized oepration"); 354 355 // request the parts 356 SDValue PartOps[2]; 357 358 SDValue UpperPartAVL; // we will use this for packing things back together 359 for (PackElem Part : {PackElem::Hi, PackElem::Lo}) { 360 // VP ops already have an explicit mask and AVL. When expanding from non-VP 361 // attach those additional inputs here. 362 auto SplitTM = CDAG.getTargetSplitMask(PackedMask, PackedAVL, Part); 363 364 if (Part == PackElem::Hi) 365 UpperPartAVL = SplitTM.AVL; 366 367 // Attach non-predicating value operands 368 SmallVector<SDValue, 4> OpVec; 369 for (unsigned i = 0; i < Op.getNumOperands(); ++i) { 370 if (AVLPos && ((int)i) == *AVLPos) 371 continue; 372 if (MaskPos && ((int)i) == *MaskPos) 373 continue; 374 375 // Value operand 376 auto PackedOperand = Op.getOperand(i); 377 auto UnpackedOpVT = splitVectorType(PackedOperand.getSimpleValueType()); 378 SDValue PartV = 379 CDAG.getUnpack(UnpackedOpVT, PackedOperand, Part, SplitTM.AVL); 380 OpVec.push_back(PartV); 381 } 382 383 // Add predicating args and generate part node. 384 OpVec.push_back(SplitTM.Mask); 385 OpVec.push_back(SplitTM.AVL); 386 // Emit legal VVP nodes. 387 PartOps[(int)Part] = 388 CDAG.getNode(Op.getOpcode(), ResVT, OpVec, Op->getFlags()); 389 } 390 391 // Re-package vectors. 392 return CDAG.getPack(Op.getValueType(), PartOps[(int)PackElem::Lo], 393 PartOps[(int)PackElem::Hi], UpperPartAVL); 394 } 395 396 SDValue VETargetLowering::legalizePackedAVL(SDValue Op, 397 VECustomDAG &CDAG) const { 398 LLVM_DEBUG(dbgs() << "::legalizePackedAVL\n";); 399 // Only required for VEC and VVP ops. 400 if (!isVVPOrVEC(Op->getOpcode())) 401 return Op; 402 403 // Operation already has a legal AVL. 404 auto AVL = getNodeAVL(Op); 405 if (isLegalAVL(AVL)) 406 return Op; 407 408 // Half and round up EVL for 32bit element types. 409 SDValue LegalAVL = AVL; 410 MVT IdiomVT = getIdiomaticVectorType(Op.getNode())->getSimpleVT(); 411 if (isPackedVectorType(IdiomVT)) { 412 assert(maySafelyIgnoreMask(Op) && 413 "TODO Shift predication from EVL into Mask"); 414 415 if (auto *ConstAVL = dyn_cast<ConstantSDNode>(AVL)) { 416 LegalAVL = CDAG.getConstant((ConstAVL->getZExtValue() + 1) / 2, MVT::i32); 417 } else { 418 auto ConstOne = CDAG.getConstant(1, MVT::i32); 419 auto PlusOne = CDAG.getNode(ISD::ADD, MVT::i32, {AVL, ConstOne}); 420 LegalAVL = CDAG.getNode(ISD::SRL, MVT::i32, {PlusOne, ConstOne}); 421 } 422 } 423 424 SDValue AnnotatedLegalAVL = CDAG.annotateLegalAVL(LegalAVL); 425 426 // Copy the operand list. 427 int NumOp = Op->getNumOperands(); 428 auto AVLPos = getAVLPos(Op->getOpcode()); 429 std::vector<SDValue> FixedOperands; 430 for (int i = 0; i < NumOp; ++i) { 431 if (AVLPos && (i == *AVLPos)) { 432 FixedOperands.push_back(AnnotatedLegalAVL); 433 continue; 434 } 435 FixedOperands.push_back(Op->getOperand(i)); 436 } 437 438 // Clone the operation with fixed operands. 439 auto Flags = Op->getFlags(); 440 SDValue NewN = 441 CDAG.getNode(Op->getOpcode(), Op->getVTList(), FixedOperands, Flags); 442 return NewN; 443 } 444