1 //===-- VVPISelLowering.cpp - VE DAG Lowering Implementation --------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the lowering and legalization of vector instructions to 10 // VVP_*layer SDNodes. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "VECustomDAG.h" 15 #include "VEISelLowering.h" 16 17 using namespace llvm; 18 19 #define DEBUG_TYPE "ve-lower" 20 21 SDValue VETargetLowering::splitMaskArithmetic(SDValue Op, 22 SelectionDAG &DAG) const { 23 VECustomDAG CDAG(DAG, Op); 24 SDValue AVL = 25 CDAG.getConstant(Op.getValueType().getVectorNumElements(), MVT::i32); 26 SDValue A = Op->getOperand(0); 27 SDValue B = Op->getOperand(1); 28 SDValue LoA = CDAG.getUnpack(MVT::v256i1, A, PackElem::Lo, AVL); 29 SDValue HiA = CDAG.getUnpack(MVT::v256i1, A, PackElem::Hi, AVL); 30 SDValue LoB = CDAG.getUnpack(MVT::v256i1, B, PackElem::Lo, AVL); 31 SDValue HiB = CDAG.getUnpack(MVT::v256i1, B, PackElem::Hi, AVL); 32 unsigned Opc = Op.getOpcode(); 33 auto LoRes = CDAG.getNode(Opc, MVT::v256i1, {LoA, LoB}); 34 auto HiRes = CDAG.getNode(Opc, MVT::v256i1, {HiA, HiB}); 35 return CDAG.getPack(MVT::v512i1, LoRes, HiRes, AVL); 36 } 37 38 SDValue VETargetLowering::lowerToVVP(SDValue Op, SelectionDAG &DAG) const { 39 // Can we represent this as a VVP node. 40 const unsigned Opcode = Op->getOpcode(); 41 auto VVPOpcodeOpt = getVVPOpcode(Opcode); 42 if (!VVPOpcodeOpt) 43 return SDValue(); 44 unsigned VVPOpcode = *VVPOpcodeOpt; 45 const bool FromVP = ISD::isVPOpcode(Opcode); 46 47 // The representative and legalized vector type of this operation. 48 VECustomDAG CDAG(DAG, Op); 49 // Dispatch to complex lowering functions. 50 switch (VVPOpcode) { 51 case VEISD::VVP_LOAD: 52 case VEISD::VVP_STORE: 53 return lowerVVP_LOAD_STORE(Op, CDAG); 54 case VEISD::VVP_GATHER: 55 case VEISD::VVP_SCATTER: 56 return lowerVVP_GATHER_SCATTER(Op, CDAG); 57 } 58 59 EVT OpVecVT = *getIdiomaticVectorType(Op.getNode()); 60 EVT LegalVecVT = getTypeToTransformTo(*DAG.getContext(), OpVecVT); 61 auto Packing = getTypePacking(LegalVecVT.getSimpleVT()); 62 63 SDValue AVL; 64 SDValue Mask; 65 66 if (FromVP) { 67 // All upstream VP SDNodes always have a mask and avl. 68 auto MaskIdx = ISD::getVPMaskIdx(Opcode); 69 auto AVLIdx = ISD::getVPExplicitVectorLengthIdx(Opcode); 70 if (MaskIdx) 71 Mask = Op->getOperand(*MaskIdx); 72 if (AVLIdx) 73 AVL = Op->getOperand(*AVLIdx); 74 } 75 76 // Materialize default mask and avl. 77 if (!AVL) 78 AVL = CDAG.getConstant(OpVecVT.getVectorNumElements(), MVT::i32); 79 if (!Mask) 80 Mask = CDAG.getConstantMask(Packing, true); 81 82 assert(LegalVecVT.isSimple()); 83 if (isVVPUnaryOp(VVPOpcode)) 84 return CDAG.getNode(VVPOpcode, LegalVecVT, {Op->getOperand(0), Mask, AVL}); 85 if (isVVPBinaryOp(VVPOpcode)) 86 return CDAG.getNode(VVPOpcode, LegalVecVT, 87 {Op->getOperand(0), Op->getOperand(1), Mask, AVL}); 88 if (isVVPReductionOp(VVPOpcode)) { 89 auto SrcHasStart = hasReductionStartParam(Op->getOpcode()); 90 SDValue StartV = SrcHasStart ? Op->getOperand(0) : SDValue(); 91 SDValue VectorV = Op->getOperand(SrcHasStart ? 1 : 0); 92 return CDAG.getLegalReductionOpVVP(VVPOpcode, Op.getValueType(), StartV, 93 VectorV, Mask, AVL, Op->getFlags()); 94 } 95 96 switch (VVPOpcode) { 97 default: 98 llvm_unreachable("lowerToVVP called for unexpected SDNode."); 99 case VEISD::VVP_FFMA: { 100 // VE has a swizzled operand order in FMA (compared to LLVM IR and 101 // SDNodes). 102 auto X = Op->getOperand(2); 103 auto Y = Op->getOperand(0); 104 auto Z = Op->getOperand(1); 105 return CDAG.getNode(VVPOpcode, LegalVecVT, {X, Y, Z, Mask, AVL}); 106 } 107 case VEISD::VVP_SELECT: { 108 auto Mask = Op->getOperand(0); 109 auto OnTrue = Op->getOperand(1); 110 auto OnFalse = Op->getOperand(2); 111 return CDAG.getNode(VVPOpcode, LegalVecVT, {OnTrue, OnFalse, Mask, AVL}); 112 } 113 case VEISD::VVP_SETCC: { 114 EVT LegalResVT = getTypeToTransformTo(*DAG.getContext(), Op.getValueType()); 115 auto LHS = Op->getOperand(0); 116 auto RHS = Op->getOperand(1); 117 auto Pred = Op->getOperand(2); 118 return CDAG.getNode(VVPOpcode, LegalResVT, {LHS, RHS, Pred, Mask, AVL}); 119 } 120 } 121 } 122 123 SDValue VETargetLowering::lowerVVP_LOAD_STORE(SDValue Op, 124 VECustomDAG &CDAG) const { 125 auto VVPOpc = *getVVPOpcode(Op->getOpcode()); 126 const bool IsLoad = (VVPOpc == VEISD::VVP_LOAD); 127 128 // Shares. 129 SDValue BasePtr = getMemoryPtr(Op); 130 SDValue Mask = getNodeMask(Op); 131 SDValue Chain = getNodeChain(Op); 132 SDValue AVL = getNodeAVL(Op); 133 // Store specific. 134 SDValue Data = getStoredValue(Op); 135 // Load specific. 136 SDValue PassThru = getNodePassthru(Op); 137 138 SDValue StrideV = getLoadStoreStride(Op, CDAG); 139 140 auto DataVT = *getIdiomaticVectorType(Op.getNode()); 141 auto Packing = getTypePacking(DataVT); 142 143 // TODO: Infer lower AVL from mask. 144 if (!AVL) 145 AVL = CDAG.getConstant(DataVT.getVectorNumElements(), MVT::i32); 146 147 // Default to the all-true mask. 148 if (!Mask) 149 Mask = CDAG.getConstantMask(Packing, true); 150 151 if (IsLoad) { 152 MVT LegalDataVT = getLegalVectorType( 153 Packing, DataVT.getVectorElementType().getSimpleVT()); 154 155 auto NewLoadV = CDAG.getNode(VEISD::VVP_LOAD, {LegalDataVT, MVT::Other}, 156 {Chain, BasePtr, StrideV, Mask, AVL}); 157 158 if (!PassThru || PassThru->isUndef()) 159 return NewLoadV; 160 161 // Convert passthru to an explicit select node. 162 SDValue DataV = CDAG.getNode(VEISD::VVP_SELECT, DataVT, 163 {NewLoadV, PassThru, Mask, AVL}); 164 SDValue NewLoadChainV = SDValue(NewLoadV.getNode(), 1); 165 166 // Merge them back into one node. 167 return CDAG.getMergeValues({DataV, NewLoadChainV}); 168 } 169 170 // VVP_STORE 171 assert(VVPOpc == VEISD::VVP_STORE); 172 if (getTypeAction(*CDAG.getDAG()->getContext(), Data.getValueType()) != 173 TargetLowering::TypeLegal) 174 // Doesn't lower store instruction if an operand is not lowered yet. 175 // If it isn't, return SDValue(). In this way, LLVM will try to lower 176 // store instruction again after lowering all operands. 177 return SDValue(); 178 return CDAG.getNode(VEISD::VVP_STORE, Op.getNode()->getVTList(), 179 {Chain, Data, BasePtr, StrideV, Mask, AVL}); 180 } 181 182 SDValue VETargetLowering::splitPackedLoadStore(SDValue Op, 183 VECustomDAG &CDAG) const { 184 auto VVPOC = *getVVPOpcode(Op.getOpcode()); 185 assert((VVPOC == VEISD::VVP_LOAD) || (VVPOC == VEISD::VVP_STORE)); 186 187 MVT DataVT = getIdiomaticVectorType(Op.getNode())->getSimpleVT(); 188 assert(getTypePacking(DataVT) == Packing::Dense && 189 "Can only split packed load/store"); 190 MVT SplitDataVT = splitVectorType(DataVT); 191 192 assert(!getNodePassthru(Op) && 193 "Should have been folded in lowering to VVP layer"); 194 195 // Analyze the operation 196 SDValue PackedMask = getNodeMask(Op); 197 SDValue PackedAVL = getAnnotatedNodeAVL(Op).first; 198 SDValue PackPtr = getMemoryPtr(Op); 199 SDValue PackData = getStoredValue(Op); 200 SDValue PackStride = getLoadStoreStride(Op, CDAG); 201 202 unsigned ChainResIdx = PackData ? 0 : 1; 203 204 SDValue PartOps[2]; 205 206 SDValue UpperPartAVL; // we will use this for packing things back together 207 for (PackElem Part : {PackElem::Hi, PackElem::Lo}) { 208 // VP ops already have an explicit mask and AVL. When expanding from non-VP 209 // attach those additional inputs here. 210 auto SplitTM = CDAG.getTargetSplitMask(PackedMask, PackedAVL, Part); 211 212 // Keep track of the (higher) lvl. 213 if (Part == PackElem::Hi) 214 UpperPartAVL = SplitTM.AVL; 215 216 // Attach non-predicating value operands 217 SmallVector<SDValue, 4> OpVec; 218 219 // Chain 220 OpVec.push_back(getNodeChain(Op)); 221 222 // Data 223 if (PackData) { 224 SDValue PartData = 225 CDAG.getUnpack(SplitDataVT, PackData, Part, SplitTM.AVL); 226 OpVec.push_back(PartData); 227 } 228 229 // Ptr & Stride 230 // Push (ptr + ElemBytes * <Part>, 2 * ElemBytes) 231 // Stride info 232 // EVT DataVT = LegalizeVectorType(getMemoryDataVT(Op), Op, DAG, Mode); 233 OpVec.push_back(CDAG.getSplitPtrOffset(PackPtr, PackStride, Part)); 234 OpVec.push_back(CDAG.getSplitPtrStride(PackStride)); 235 236 // Add predicating args and generate part node 237 OpVec.push_back(SplitTM.Mask); 238 OpVec.push_back(SplitTM.AVL); 239 240 if (PackData) { 241 // Store 242 PartOps[(int)Part] = CDAG.getNode(VVPOC, MVT::Other, OpVec); 243 } else { 244 // Load 245 PartOps[(int)Part] = 246 CDAG.getNode(VVPOC, {SplitDataVT, MVT::Other}, OpVec); 247 } 248 } 249 250 // Merge the chains 251 SDValue LowChain = SDValue(PartOps[(int)PackElem::Lo].getNode(), ChainResIdx); 252 SDValue HiChain = SDValue(PartOps[(int)PackElem::Hi].getNode(), ChainResIdx); 253 SDValue FusedChains = 254 CDAG.getNode(ISD::TokenFactor, MVT::Other, {LowChain, HiChain}); 255 256 // Chain only [store] 257 if (PackData) 258 return FusedChains; 259 260 // Re-pack into full packed vector result 261 MVT PackedVT = 262 getLegalVectorType(Packing::Dense, DataVT.getVectorElementType()); 263 SDValue PackedVals = CDAG.getPack(PackedVT, PartOps[(int)PackElem::Lo], 264 PartOps[(int)PackElem::Hi], UpperPartAVL); 265 266 return CDAG.getMergeValues({PackedVals, FusedChains}); 267 } 268 269 SDValue VETargetLowering::lowerVVP_GATHER_SCATTER(SDValue Op, 270 VECustomDAG &CDAG) const { 271 EVT DataVT = *getIdiomaticVectorType(Op.getNode()); 272 auto Packing = getTypePacking(DataVT); 273 MVT LegalDataVT = 274 getLegalVectorType(Packing, DataVT.getVectorElementType().getSimpleVT()); 275 276 SDValue AVL = getAnnotatedNodeAVL(Op).first; 277 SDValue Index = getGatherScatterIndex(Op); 278 SDValue BasePtr = getMemoryPtr(Op); 279 SDValue Mask = getNodeMask(Op); 280 SDValue Chain = getNodeChain(Op); 281 SDValue Scale = getGatherScatterScale(Op); 282 SDValue PassThru = getNodePassthru(Op); 283 SDValue StoredValue = getStoredValue(Op); 284 if (PassThru && PassThru->isUndef()) 285 PassThru = SDValue(); 286 287 bool IsScatter = (bool)StoredValue; 288 289 // TODO: Infer lower AVL from mask. 290 if (!AVL) 291 AVL = CDAG.getConstant(DataVT.getVectorNumElements(), MVT::i32); 292 293 // Default to the all-true mask. 294 if (!Mask) 295 Mask = CDAG.getConstantMask(Packing, true); 296 297 SDValue AddressVec = 298 CDAG.getGatherScatterAddress(BasePtr, Scale, Index, Mask, AVL); 299 if (IsScatter) 300 return CDAG.getNode(VEISD::VVP_SCATTER, MVT::Other, 301 {Chain, StoredValue, AddressVec, Mask, AVL}); 302 303 // Gather. 304 SDValue NewLoadV = CDAG.getNode(VEISD::VVP_GATHER, {LegalDataVT, MVT::Other}, 305 {Chain, AddressVec, Mask, AVL}); 306 307 if (!PassThru) 308 return NewLoadV; 309 310 // TODO: Use vvp_select 311 SDValue DataV = CDAG.getNode(VEISD::VVP_SELECT, LegalDataVT, 312 {NewLoadV, PassThru, Mask, AVL}); 313 SDValue NewLoadChainV = SDValue(NewLoadV.getNode(), 1); 314 return CDAG.getMergeValues({DataV, NewLoadChainV}); 315 } 316 317 SDValue VETargetLowering::legalizeInternalLoadStoreOp(SDValue Op, 318 VECustomDAG &CDAG) const { 319 LLVM_DEBUG(dbgs() << "::legalizeInternalLoadStoreOp\n";); 320 MVT DataVT = getIdiomaticVectorType(Op.getNode())->getSimpleVT(); 321 322 // TODO: Recognize packable load,store. 323 if (isPackedVectorType(DataVT)) 324 return splitPackedLoadStore(Op, CDAG); 325 326 return legalizePackedAVL(Op, CDAG); 327 } 328 329 SDValue VETargetLowering::legalizeInternalVectorOp(SDValue Op, 330 SelectionDAG &DAG) const { 331 LLVM_DEBUG(dbgs() << "::legalizeInternalVectorOp\n";); 332 VECustomDAG CDAG(DAG, Op); 333 334 // Dispatch to specialized legalization functions. 335 switch (Op->getOpcode()) { 336 case VEISD::VVP_LOAD: 337 case VEISD::VVP_STORE: 338 return legalizeInternalLoadStoreOp(Op, CDAG); 339 } 340 341 EVT IdiomVT = Op.getValueType(); 342 if (isPackedVectorType(IdiomVT) && 343 !supportsPackedMode(Op.getOpcode(), IdiomVT)) 344 return splitVectorOp(Op, CDAG); 345 346 // TODO: Implement odd/even splitting. 347 return legalizePackedAVL(Op, CDAG); 348 } 349 350 SDValue VETargetLowering::splitVectorOp(SDValue Op, VECustomDAG &CDAG) const { 351 MVT ResVT = splitVectorType(Op.getValue(0).getSimpleValueType()); 352 353 auto AVLPos = getAVLPos(Op->getOpcode()); 354 auto MaskPos = getMaskPos(Op->getOpcode()); 355 356 SDValue PackedMask = getNodeMask(Op); 357 auto AVLPair = getAnnotatedNodeAVL(Op); 358 SDValue PackedAVL = AVLPair.first; 359 assert(!AVLPair.second && "Expecting non pack-legalized oepration"); 360 361 // request the parts 362 SDValue PartOps[2]; 363 364 SDValue UpperPartAVL; // we will use this for packing things back together 365 for (PackElem Part : {PackElem::Hi, PackElem::Lo}) { 366 // VP ops already have an explicit mask and AVL. When expanding from non-VP 367 // attach those additional inputs here. 368 auto SplitTM = CDAG.getTargetSplitMask(PackedMask, PackedAVL, Part); 369 370 if (Part == PackElem::Hi) 371 UpperPartAVL = SplitTM.AVL; 372 373 // Attach non-predicating value operands 374 SmallVector<SDValue, 4> OpVec; 375 for (unsigned i = 0; i < Op.getNumOperands(); ++i) { 376 if (AVLPos && ((int)i) == *AVLPos) 377 continue; 378 if (MaskPos && ((int)i) == *MaskPos) 379 continue; 380 381 // Value operand 382 auto PackedOperand = Op.getOperand(i); 383 auto UnpackedOpVT = splitVectorType(PackedOperand.getSimpleValueType()); 384 SDValue PartV = 385 CDAG.getUnpack(UnpackedOpVT, PackedOperand, Part, SplitTM.AVL); 386 OpVec.push_back(PartV); 387 } 388 389 // Add predicating args and generate part node. 390 OpVec.push_back(SplitTM.Mask); 391 OpVec.push_back(SplitTM.AVL); 392 // Emit legal VVP nodes. 393 PartOps[(int)Part] = 394 CDAG.getNode(Op.getOpcode(), ResVT, OpVec, Op->getFlags()); 395 } 396 397 // Re-package vectors. 398 return CDAG.getPack(Op.getValueType(), PartOps[(int)PackElem::Lo], 399 PartOps[(int)PackElem::Hi], UpperPartAVL); 400 } 401 402 SDValue VETargetLowering::legalizePackedAVL(SDValue Op, 403 VECustomDAG &CDAG) const { 404 LLVM_DEBUG(dbgs() << "::legalizePackedAVL\n";); 405 // Only required for VEC and VVP ops. 406 if (!isVVPOrVEC(Op->getOpcode())) 407 return Op; 408 409 // Operation already has a legal AVL. 410 auto AVL = getNodeAVL(Op); 411 if (isLegalAVL(AVL)) 412 return Op; 413 414 // Half and round up EVL for 32bit element types. 415 SDValue LegalAVL = AVL; 416 MVT IdiomVT = getIdiomaticVectorType(Op.getNode())->getSimpleVT(); 417 if (isPackedVectorType(IdiomVT)) { 418 assert(maySafelyIgnoreMask(Op) && 419 "TODO Shift predication from EVL into Mask"); 420 421 if (auto *ConstAVL = dyn_cast<ConstantSDNode>(AVL)) { 422 LegalAVL = CDAG.getConstant((ConstAVL->getZExtValue() + 1) / 2, MVT::i32); 423 } else { 424 auto ConstOne = CDAG.getConstant(1, MVT::i32); 425 auto PlusOne = CDAG.getNode(ISD::ADD, MVT::i32, {AVL, ConstOne}); 426 LegalAVL = CDAG.getNode(ISD::SRL, MVT::i32, {PlusOne, ConstOne}); 427 } 428 } 429 430 SDValue AnnotatedLegalAVL = CDAG.annotateLegalAVL(LegalAVL); 431 432 // Copy the operand list. 433 int NumOp = Op->getNumOperands(); 434 auto AVLPos = getAVLPos(Op->getOpcode()); 435 std::vector<SDValue> FixedOperands; 436 for (int i = 0; i < NumOp; ++i) { 437 if (AVLPos && (i == *AVLPos)) { 438 FixedOperands.push_back(AnnotatedLegalAVL); 439 continue; 440 } 441 FixedOperands.push_back(Op->getOperand(i)); 442 } 443 444 // Clone the operation with fixed operands. 445 auto Flags = Op->getFlags(); 446 SDValue NewN = 447 CDAG.getNode(Op->getOpcode(), Op->getVTList(), FixedOperands, Flags); 448 return NewN; 449 } 450