//===----- CodeGen/ExpandVectorPredication.cpp - Expand VP intrinsics -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass implements IR expansion for vector predication intrinsics, allowing
// targets to enable vector predication until just before codegen.
//
//===----------------------------------------------------------------------===//

#include "llvm/CodeGen/ExpandVectorPredication.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Compiler.h"
#include "llvm/Support/Debug.h"
#include <optional>

using namespace llvm;

using VPLegalization = TargetTransformInfo::VPLegalization;
using VPTransform = TargetTransformInfo::VPLegalization::VPTransform;

// Keep this in sync with TargetTransformInfo::VPLegalization.
#define VPINTERNAL_VPLEGAL_CASES                                               \
  VPINTERNAL_CASE(Legal)                                                       \
  VPINTERNAL_CASE(Discard)                                                     \
  VPINTERNAL_CASE(Convert)

#define VPINTERNAL_CASE(X) "|" #X

// Override options.
static cl::opt<std::string> EVLTransformOverride(
    "expandvp-override-evl-transform", cl::init(""), cl::Hidden,
    cl::desc("Options: <empty>" VPINTERNAL_VPLEGAL_CASES
             ". If non-empty, ignore "
             "TargetTransformInfo and "
             "always use this transformation for the %evl parameter (Used in "
             "testing)."));

static cl::opt<std::string> MaskTransformOverride(
    "expandvp-override-mask-transform", cl::init(""), cl::Hidden,
    cl::desc("Options: <empty>" VPINTERNAL_VPLEGAL_CASES
             ". If non-empty, ignore "
             "TargetTransformInfo and "
             "always use this transformation for the %mask parameter (Used in "
             "testing)."));

#undef VPINTERNAL_CASE
#define VPINTERNAL_CASE(X) .Case(#X, VPLegalization::X)

static VPTransform parseOverrideOption(const std::string &TextOpt) {
  return StringSwitch<VPTransform>(TextOpt) VPINTERNAL_VPLEGAL_CASES;
}

#undef VPINTERNAL_VPLEGAL_CASES

// Whether any override options are set.
static bool anyExpandVPOverridesSet() {
  return !EVLTransformOverride.empty() || !MaskTransformOverride.empty();
}

#define DEBUG_TYPE "expandvp"

STATISTIC(NumFoldedVL, "Number of folded vector length params");
STATISTIC(NumLoweredVPOps, "Number of lowered vector predication operations");

///// Helpers {

/// \returns Whether the vector mask \p MaskVal has all lane bits set.
static bool isAllTrueMask(Value *MaskVal) {
  if (Value *SplattedVal = getSplatValue(MaskVal))
    if (auto *ConstValue = dyn_cast<Constant>(SplattedVal))
      return ConstValue->isAllOnesValue();

  return false;
}
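
// E.g. (illustrative): a constant mask <4 x i1> <i1 true, i1 true, i1 true,
// i1 true>, or an all-ones splat produced by shufflevector, is recognized as
// all-true; a variable or partially-false mask is not.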

/// \returns A non-excepting divisor constant for this type.
static Constant *getSafeDivisor(Type *DivTy) {
  assert(DivTy->isIntOrIntVectorTy() && "Unsupported divisor type");
  return ConstantInt::get(DivTy, 1u, false);
}

/// Transfer operation properties from \p VPI to \p NewVal.
static void transferDecorations(Value &NewVal, VPIntrinsic &VPI) {
  auto *NewInst = dyn_cast<Instruction>(&NewVal);
  if (!NewInst || !isa<FPMathOperator>(NewVal))
    return;

  auto *OldFMOp = dyn_cast<FPMathOperator>(&VPI);
  if (!OldFMOp)
    return;

  NewInst->setFastMathFlags(OldFMOp->getFastMathFlags());
}

/// Transfer all properties from \p OldOp to \p NewOp and replace all uses.
/// \p OldOp gets erased.
static void replaceOperation(Value &NewOp, VPIntrinsic &OldOp) {
  transferDecorations(NewOp, OldOp);
  OldOp.replaceAllUsesWith(&NewOp);
  OldOp.eraseFromParent();
}

static bool maySpeculateLanes(VPIntrinsic &VPI) {
  // The result of VP reductions depends on the mask and evl.
  if (isa<VPReductionIntrinsic>(VPI))
    return false;
  // Fallback to whether the intrinsic is speculatable.
  std::optional<unsigned> OpcOpt = VPI.getFunctionalOpcode();
  unsigned FunctionalOpc = OpcOpt.value_or((unsigned)Instruction::Call);
  return isSafeToSpeculativelyExecuteWithOpcode(FunctionalOpc, &VPI);
}

//// } Helpers

namespace {

// Expansion pass state at function scope.
struct CachingVPExpander {
  Function &F;
  const TargetTransformInfo &TTI;

  /// \returns A (fixed length) vector with ascending integer indices
  /// (<0, 1, ..., NumElems-1>).
  /// \p Builder
  ///    Used for instruction creation.
  /// \p LaneTy
  ///    Integer element type of the result vector.
  /// \p NumElems
  ///    Number of vector elements.
  Value *createStepVector(IRBuilder<> &Builder, Type *LaneTy,
                          unsigned NumElems);

  /// \returns A bitmask that is true where the lane position is less-than \p
  /// EVLParam.
  ///
  /// \p Builder
  ///    Used for instruction creation.
  /// \p EVLParam
  ///    The explicit vector length parameter to test against the lane
  ///    positions.
  /// \p ElemCount
  ///    Static (potentially scalable) number of vector elements.
  Value *convertEVLToMask(IRBuilder<> &Builder, Value *EVLParam,
                          ElementCount ElemCount);

  Value *foldEVLIntoMask(VPIntrinsic &VPI);

  /// "Remove" the %evl parameter of \p PI by setting it to the static vector
  /// length of the operation.
  void discardEVLParameter(VPIntrinsic &PI);

  /// Lower this VP binary operator to an unpredicated binary operator.
  Value *expandPredicationInBinaryOperator(IRBuilder<> &Builder,
                                           VPIntrinsic &PI);

  /// Lower this VP reduction to a call to an unpredicated reduction intrinsic.
  Value *expandPredicationInReduction(IRBuilder<> &Builder,
                                      VPReductionIntrinsic &PI);

  /// Lower this VP memory operation to a non-VP intrinsic.
  Value *expandPredicationInMemoryIntrinsic(IRBuilder<> &Builder,
                                            VPIntrinsic &VPI);

  /// Lower this VP comparison to a call to an unpredicated comparison.
  Value *expandPredicationInComparison(IRBuilder<> &Builder,
                                       VPCmpIntrinsic &PI);

  /// Query TTI and expand the vector predication in \p PI accordingly.
  Value *expandPredication(VPIntrinsic &PI);
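
  // For example (illustrative): a call such as
  //   %r = call <4 x i32> @llvm.vp.add.v4i32(<4 x i32> %a, <4 x i32> %b,
  //                                          <4 x i1> %m, i32 %evl)
  // may be rewritten into the plain instruction
  //   %r = add <4 x i32> %a, %b
  // because an integer add can be speculated across masked-off lanes.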

  /// Determine how and whether the VPIntrinsic \p VPI shall be expanded. This
  /// overrides TTI with the cl::opts listed at the top of this file.
  VPLegalization getVPLegalizationStrategy(const VPIntrinsic &VPI) const;
  bool UsingTTIOverrides;

public:
  CachingVPExpander(Function &F, const TargetTransformInfo &TTI)
      : F(F), TTI(TTI), UsingTTIOverrides(anyExpandVPOverridesSet()) {}

  bool expandVectorPredication();
};

//// CachingVPExpander {

Value *CachingVPExpander::createStepVector(IRBuilder<> &Builder, Type *LaneTy,
                                           unsigned NumElems) {
  // TODO add caching
  SmallVector<Constant *, 16> ConstElems;

  for (unsigned Idx = 0; Idx < NumElems; ++Idx)
    ConstElems.push_back(ConstantInt::get(LaneTy, Idx, false));

  return ConstantVector::get(ConstElems);
}

Value *CachingVPExpander::convertEVLToMask(IRBuilder<> &Builder,
                                           Value *EVLParam,
                                           ElementCount ElemCount) {
  // TODO add caching
  // Scalable vector %evl conversion.
  if (ElemCount.isScalable()) {
    auto *M = Builder.GetInsertBlock()->getModule();
    Type *BoolVecTy = VectorType::get(Builder.getInt1Ty(), ElemCount);
    Function *ActiveMaskFunc = Intrinsic::getDeclaration(
        M, Intrinsic::get_active_lane_mask, {BoolVecTy, EVLParam->getType()});
    // `get_active_lane_mask` performs an implicit less-than comparison.
    Value *ConstZero = Builder.getInt32(0);
    return Builder.CreateCall(ActiveMaskFunc, {ConstZero, EVLParam});
  }

  // Fixed vector %evl conversion.
  Type *LaneTy = EVLParam->getType();
  unsigned NumElems = ElemCount.getFixedValue();
  Value *VLSplat = Builder.CreateVectorSplat(NumElems, EVLParam);
  Value *IdxVec = createStepVector(Builder, LaneTy, NumElems);
  return Builder.CreateICmp(CmpInst::ICMP_ULT, IdxVec, VLSplat);
}

Value *
CachingVPExpander::expandPredicationInBinaryOperator(IRBuilder<> &Builder,
                                                     VPIntrinsic &VPI) {
  assert((maySpeculateLanes(VPI) || VPI.canIgnoreVectorLengthParam()) &&
         "Implicitly dropping %evl in non-speculatable operator!");

  auto OC = static_cast<Instruction::BinaryOps>(*VPI.getFunctionalOpcode());
  assert(Instruction::isBinaryOp(OC));

  Value *Op0 = VPI.getOperand(0);
  Value *Op1 = VPI.getOperand(1);
  Value *Mask = VPI.getMaskParam();

  // Blend in safe operands.
  if (Mask && !isAllTrueMask(Mask)) {
    switch (OC) {
    default:
      // Can safely ignore the predicate.
      break;

    // Division operators need a safe divisor on masked-off lanes (1).
    case Instruction::UDiv:
    case Instruction::SDiv:
    case Instruction::URem:
    case Instruction::SRem:
      // 2nd operand must not be zero.
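      // E.g. (illustrative): with Mask = <i1 1, i1 1, i1 0, i1 0>, the
      // masked-off lanes 2 and 3 of Op1 are replaced by 1 so the resulting
      // unpredicated division cannot divide by zero.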
      Value *SafeDivisor = getSafeDivisor(VPI.getType());
      Op1 = Builder.CreateSelect(Mask, Op1, SafeDivisor);
    }
  }

  Value *NewBinOp = Builder.CreateBinOp(OC, Op0, Op1, VPI.getName());

  replaceOperation(*NewBinOp, VPI);
  return NewBinOp;
}

static Value *getNeutralReductionElement(const VPReductionIntrinsic &VPI,
                                         Type *EltTy) {
  bool Negative = false;
  unsigned EltBits = EltTy->getScalarSizeInBits();
  switch (VPI.getIntrinsicID()) {
  default:
    llvm_unreachable("Expecting a VP reduction intrinsic");
  case Intrinsic::vp_reduce_add:
  case Intrinsic::vp_reduce_or:
  case Intrinsic::vp_reduce_xor:
  case Intrinsic::vp_reduce_umax:
    return Constant::getNullValue(EltTy);
  case Intrinsic::vp_reduce_mul:
    return ConstantInt::get(EltTy, 1, /*IsSigned*/ false);
  case Intrinsic::vp_reduce_and:
  case Intrinsic::vp_reduce_umin:
    return ConstantInt::getAllOnesValue(EltTy);
  case Intrinsic::vp_reduce_smin:
    return ConstantInt::get(EltTy->getContext(),
                            APInt::getSignedMaxValue(EltBits));
  case Intrinsic::vp_reduce_smax:
    return ConstantInt::get(EltTy->getContext(),
                            APInt::getSignedMinValue(EltBits));
  case Intrinsic::vp_reduce_fmax:
    Negative = true;
    [[fallthrough]];
  case Intrinsic::vp_reduce_fmin: {
    FastMathFlags Flags = VPI.getFastMathFlags();
    const fltSemantics &Semantics = EltTy->getFltSemantics();
    return !Flags.noNaNs() ? ConstantFP::getQNaN(EltTy, Negative)
           : !Flags.noInfs()
               ? ConstantFP::getInfinity(EltTy, Negative)
               : ConstantFP::get(EltTy,
                                 APFloat::getLargest(Semantics, Negative));
  }
  case Intrinsic::vp_reduce_fadd:
    return ConstantFP::getNegativeZero(EltTy);
  case Intrinsic::vp_reduce_fmul:
    return ConstantFP::get(EltTy, 1.0);
  }
}

Value *
CachingVPExpander::expandPredicationInReduction(IRBuilder<> &Builder,
                                                VPReductionIntrinsic &VPI) {
  assert((maySpeculateLanes(VPI) || VPI.canIgnoreVectorLengthParam()) &&
         "Implicitly dropping %evl in non-speculatable operator!");

  Value *Mask = VPI.getMaskParam();
  Value *RedOp = VPI.getOperand(VPI.getVectorParamPos());

  // Insert neutral element in masked-out positions
  if (Mask && !isAllTrueMask(Mask)) {
    auto *NeutralElt = getNeutralReductionElement(VPI, VPI.getType());
    auto *NeutralVector = Builder.CreateVectorSplat(
        cast<VectorType>(RedOp->getType())->getElementCount(), NeutralElt);
    RedOp = Builder.CreateSelect(Mask, RedOp, NeutralVector);
  }

  Value *Reduction;
  Value *Start = VPI.getOperand(VPI.getStartParamPos());

  switch (VPI.getIntrinsicID()) {
  default:
    llvm_unreachable("Impossible reduction kind");
  case Intrinsic::vp_reduce_add:
    Reduction = Builder.CreateAddReduce(RedOp);
    Reduction = Builder.CreateAdd(Reduction, Start);
    break;
  case Intrinsic::vp_reduce_mul:
    Reduction = Builder.CreateMulReduce(RedOp);
    Reduction = Builder.CreateMul(Reduction, Start);
    break;
  case Intrinsic::vp_reduce_and:
    Reduction = Builder.CreateAndReduce(RedOp);
    Reduction = Builder.CreateAnd(Reduction, Start);
    break;
  case Intrinsic::vp_reduce_or:
    Reduction = Builder.CreateOrReduce(RedOp);
    Reduction = Builder.CreateOr(Reduction, Start);
    break;
  case Intrinsic::vp_reduce_xor:
    Reduction = Builder.CreateXorReduce(RedOp);
    Reduction = Builder.CreateXor(Reduction, Start);
    break;
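
  // The min/max and floating-point cases below combine the reduced vector
  // value with the scalar start value via the matching scalar intrinsic,
  // e.g. (illustrative) %r = call i32 @llvm.smax.i32(i32 %red, i32 %start).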
  case Intrinsic::vp_reduce_smax:
    Reduction = Builder.CreateIntMaxReduce(RedOp, /*IsSigned*/ true);
    Reduction =
        Builder.CreateBinaryIntrinsic(Intrinsic::smax, Reduction, Start);
    break;
  case Intrinsic::vp_reduce_smin:
    Reduction = Builder.CreateIntMinReduce(RedOp, /*IsSigned*/ true);
    Reduction =
        Builder.CreateBinaryIntrinsic(Intrinsic::smin, Reduction, Start);
    break;
  case Intrinsic::vp_reduce_umax:
    Reduction = Builder.CreateIntMaxReduce(RedOp, /*IsSigned*/ false);
    Reduction =
        Builder.CreateBinaryIntrinsic(Intrinsic::umax, Reduction, Start);
    break;
  case Intrinsic::vp_reduce_umin:
    Reduction = Builder.CreateIntMinReduce(RedOp, /*IsSigned*/ false);
    Reduction =
        Builder.CreateBinaryIntrinsic(Intrinsic::umin, Reduction, Start);
    break;
  case Intrinsic::vp_reduce_fmax:
    Reduction = Builder.CreateFPMaxReduce(RedOp);
    transferDecorations(*Reduction, VPI);
    Reduction =
        Builder.CreateBinaryIntrinsic(Intrinsic::maxnum, Reduction, Start);
    break;
  case Intrinsic::vp_reduce_fmin:
    Reduction = Builder.CreateFPMinReduce(RedOp);
    transferDecorations(*Reduction, VPI);
    Reduction =
        Builder.CreateBinaryIntrinsic(Intrinsic::minnum, Reduction, Start);
    break;
  case Intrinsic::vp_reduce_fadd:
    Reduction = Builder.CreateFAddReduce(Start, RedOp);
    break;
  case Intrinsic::vp_reduce_fmul:
    Reduction = Builder.CreateFMulReduce(Start, RedOp);
    break;
  }

  replaceOperation(*Reduction, VPI);
  return Reduction;
}

Value *
CachingVPExpander::expandPredicationInMemoryIntrinsic(IRBuilder<> &Builder,
                                                      VPIntrinsic &VPI) {
  assert(VPI.canIgnoreVectorLengthParam());

  const auto &DL = F.getParent()->getDataLayout();

  Value *MaskParam = VPI.getMaskParam();
  Value *PtrParam = VPI.getMemoryPointerParam();
  Value *DataParam = VPI.getMemoryDataParam();
  bool IsUnmasked = isAllTrueMask(MaskParam);

  MaybeAlign AlignOpt = VPI.getPointerAlignment();

  Value *NewMemoryInst = nullptr;
  switch (VPI.getIntrinsicID()) {
  default:
    llvm_unreachable("Not a VP memory intrinsic");
  case Intrinsic::vp_store:
    if (IsUnmasked) {
      StoreInst *NewStore =
          Builder.CreateStore(DataParam, PtrParam, /*IsVolatile*/ false);
      if (AlignOpt.has_value())
        NewStore->setAlignment(*AlignOpt);
      NewMemoryInst = NewStore;
    } else
      NewMemoryInst = Builder.CreateMaskedStore(
          DataParam, PtrParam, AlignOpt.valueOrOne(), MaskParam);

    break;
  case Intrinsic::vp_load:
    if (IsUnmasked) {
      LoadInst *NewLoad =
          Builder.CreateLoad(VPI.getType(), PtrParam, /*IsVolatile*/ false);
      if (AlignOpt.has_value())
        NewLoad->setAlignment(*AlignOpt);
      NewMemoryInst = NewLoad;
    } else
      NewMemoryInst = Builder.CreateMaskedLoad(
          VPI.getType(), PtrParam, AlignOpt.valueOrOne(), MaskParam);

    break;
  case Intrinsic::vp_scatter: {
    auto *ElementType =
        cast<VectorType>(DataParam->getType())->getElementType();
    NewMemoryInst = Builder.CreateMaskedScatter(
        DataParam, PtrParam,
        AlignOpt.value_or(DL.getPrefTypeAlign(ElementType)), MaskParam);
    break;
  }
  case Intrinsic::vp_gather: {
    auto *ElementType = cast<VectorType>(VPI.getType())->getElementType();
    NewMemoryInst = Builder.CreateMaskedGather(
        VPI.getType(), PtrParam,
        AlignOpt.value_or(DL.getPrefTypeAlign(ElementType)), MaskParam, nullptr,
        VPI.getName());
    break;
  }
  }

  assert(NewMemoryInst);
  replaceOperation(*NewMemoryInst, VPI);
  return NewMemoryInst;
}
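
// For example (illustrative), a @llvm.vp.icmp with the `ult` predicate and an
// ineffective %mask/%evl simply becomes `icmp ult <4 x i32> %a, %b`.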

Value *CachingVPExpander::expandPredicationInComparison(IRBuilder<> &Builder,
                                                        VPCmpIntrinsic &VPI) {
  assert((maySpeculateLanes(VPI) || VPI.canIgnoreVectorLengthParam()) &&
         "Implicitly dropping %evl in non-speculatable operator!");

  assert(*VPI.getFunctionalOpcode() == Instruction::ICmp ||
         *VPI.getFunctionalOpcode() == Instruction::FCmp);

  Value *Op0 = VPI.getOperand(0);
  Value *Op1 = VPI.getOperand(1);
  auto Pred = VPI.getPredicate();

  auto *NewCmp = Builder.CreateCmp(Pred, Op0, Op1);

  replaceOperation(*NewCmp, VPI);
  return NewCmp;
}

void CachingVPExpander::discardEVLParameter(VPIntrinsic &VPI) {
  LLVM_DEBUG(dbgs() << "Discard EVL parameter in " << VPI << "\n");

  if (VPI.canIgnoreVectorLengthParam())
    return;

  Value *EVLParam = VPI.getVectorLengthParam();
  if (!EVLParam)
    return;

  ElementCount StaticElemCount = VPI.getStaticVectorLength();
  Value *MaxEVL = nullptr;
  Type *Int32Ty = Type::getInt32Ty(VPI.getContext());
  if (StaticElemCount.isScalable()) {
    // TODO add caching
    auto *M = VPI.getModule();
    Function *VScaleFunc =
        Intrinsic::getDeclaration(M, Intrinsic::vscale, Int32Ty);
    IRBuilder<> Builder(VPI.getParent(), VPI.getIterator());
    Value *FactorConst = Builder.getInt32(StaticElemCount.getKnownMinValue());
    Value *VScale = Builder.CreateCall(VScaleFunc, {}, "vscale");
    MaxEVL = Builder.CreateMul(VScale, FactorConst, "scalable_size",
                               /*NUW*/ true, /*NSW*/ false);
  } else {
    MaxEVL = ConstantInt::get(Int32Ty, StaticElemCount.getFixedValue(), false);
  }
  VPI.setVectorLengthParam(MaxEVL);
}

Value *CachingVPExpander::foldEVLIntoMask(VPIntrinsic &VPI) {
  LLVM_DEBUG(dbgs() << "Folding vlen for " << VPI << '\n');

  IRBuilder<> Builder(&VPI);

  // Ineffective %evl parameter and so nothing to do here.
  if (VPI.canIgnoreVectorLengthParam())
    return &VPI;

  // Only VP intrinsics can have an %evl parameter.
  Value *OldMaskParam = VPI.getMaskParam();
  Value *OldEVLParam = VPI.getVectorLengthParam();
  assert(OldMaskParam && "no mask param to fold the vl param into");
  assert(OldEVLParam && "no EVL param to fold away");

  LLVM_DEBUG(dbgs() << "OLD evl: " << *OldEVLParam << '\n');
  LLVM_DEBUG(dbgs() << "OLD mask: " << *OldMaskParam << '\n');

  // Convert the %evl predication into vector mask predication.
  ElementCount ElemCount = VPI.getStaticVectorLength();
  Value *VLMask = convertEVLToMask(Builder, OldEVLParam, ElemCount);
  Value *NewMaskParam = Builder.CreateAnd(VLMask, OldMaskParam);
  VPI.setMaskParam(NewMaskParam);

  // Drop the %evl parameter.
  discardEVLParameter(VPI);
  assert(VPI.canIgnoreVectorLengthParam() &&
         "transformation did not render the evl param ineffective!");

  // Reassess the modified instruction.
  return &VPI;
}
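
// For example (illustrative), folding %evl into the mask of an operation on
// <8 x i32> with %evl = 5 yields
//   %vlmask  = icmp ult <8 x i32> <i32 0, i32 1, ..., i32 7>, (splat of 5)
//   %newmask = and <8 x i1> %vlmask, %oldmask
// after which the %evl operand is reset to the full vector length (8).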

Value *CachingVPExpander::expandPredication(VPIntrinsic &VPI) {
  LLVM_DEBUG(dbgs() << "Lowering to unpredicated op: " << VPI << '\n');

  IRBuilder<> Builder(&VPI);

  // Try lowering to an LLVM instruction first.
  auto OC = VPI.getFunctionalOpcode();

  if (OC && Instruction::isBinaryOp(*OC))
    return expandPredicationInBinaryOperator(Builder, VPI);

  if (auto *VPRI = dyn_cast<VPReductionIntrinsic>(&VPI))
    return expandPredicationInReduction(Builder, *VPRI);

  if (auto *VPCmp = dyn_cast<VPCmpIntrinsic>(&VPI))
    return expandPredicationInComparison(Builder, *VPCmp);

  switch (VPI.getIntrinsicID()) {
  default:
    break;
  case Intrinsic::vp_load:
  case Intrinsic::vp_store:
  case Intrinsic::vp_gather:
  case Intrinsic::vp_scatter:
    return expandPredicationInMemoryIntrinsic(Builder, VPI);
  }

  return &VPI;
}

//// } CachingVPExpander

struct TransformJob {
  VPIntrinsic *PI;
  TargetTransformInfo::VPLegalization Strategy;
  TransformJob(VPIntrinsic *PI, TargetTransformInfo::VPLegalization InitStrat)
      : PI(PI), Strategy(InitStrat) {}

  bool isDone() const { return Strategy.shouldDoNothing(); }
};

void sanitizeStrategy(VPIntrinsic &VPI, VPLegalization &LegalizeStrat) {
  // Operations with speculatable lanes do not strictly need predication.
  if (maySpeculateLanes(VPI)) {
    // Converting a speculatable VP intrinsic means dropping %mask and %evl.
    // No need to expand %evl into the %mask only to ignore that code.
    if (LegalizeStrat.OpStrategy == VPLegalization::Convert)
      LegalizeStrat.EVLParamStrategy = VPLegalization::Discard;
    return;
  }

  // We have to preserve the predicating effect of %evl for this
  // non-speculatable VP intrinsic.
  // 1) Never discard %evl.
  // 2) If this VP intrinsic will be expanded to non-VP code, make sure that
  //    %evl gets folded into %mask.
  if ((LegalizeStrat.EVLParamStrategy == VPLegalization::Discard) ||
      (LegalizeStrat.OpStrategy == VPLegalization::Convert)) {
    LegalizeStrat.EVLParamStrategy = VPLegalization::Convert;
  }
}

VPLegalization
CachingVPExpander::getVPLegalizationStrategy(const VPIntrinsic &VPI) const {
  auto VPStrat = TTI.getVPLegalizationStrategy(VPI);
  if (LLVM_LIKELY(!UsingTTIOverrides)) {
    // No overrides - we are in production.
    return VPStrat;
  }

  // Overrides set - we are in testing, the following does not need to be
  // efficient.
  VPStrat.EVLParamStrategy = parseOverrideOption(EVLTransformOverride);
  VPStrat.OpStrategy = parseOverrideOption(MaskTransformOverride);
  return VPStrat;
}

/// Expand llvm.vp.* intrinsics as requested by \p TTI.
bool CachingVPExpander::expandVectorPredication() {
  SmallVector<TransformJob, 16> Worklist;

  // Collect all VPIntrinsics that need expansion and determine their expansion
  // strategy.
  for (auto &I : instructions(F)) {
    auto *VPI = dyn_cast<VPIntrinsic>(&I);
    if (!VPI)
      continue;
    auto VPStrat = getVPLegalizationStrategy(*VPI);
    sanitizeStrategy(*VPI, VPStrat);
    if (!VPStrat.shouldDoNothing())
      Worklist.emplace_back(VPI, VPStrat);
  }
  if (Worklist.empty())
    return false;
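
  // E.g. (illustrative): for a non-speculatable vp.sdiv whose operation is to
  // be converted, sanitizeStrategy has already forced the %evl strategy to
  // Convert so the predicating effect of %evl is preserved in the mask.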

  // Transform all VPIntrinsics on the worklist.
  LLVM_DEBUG(dbgs() << "\n:::: Transforming " << Worklist.size()
                    << " instructions ::::\n");
  for (TransformJob Job : Worklist) {
    // Transform the EVL parameter.
    switch (Job.Strategy.EVLParamStrategy) {
    case VPLegalization::Legal:
      break;
    case VPLegalization::Discard:
      discardEVLParameter(*Job.PI);
      break;
    case VPLegalization::Convert:
      if (foldEVLIntoMask(*Job.PI))
        ++NumFoldedVL;
      break;
    }
    Job.Strategy.EVLParamStrategy = VPLegalization::Legal;

    // Replace with a non-predicated operation.
    switch (Job.Strategy.OpStrategy) {
    case VPLegalization::Legal:
      break;
    case VPLegalization::Discard:
      llvm_unreachable("Invalid strategy for operators.");
    case VPLegalization::Convert:
      expandPredication(*Job.PI);
      ++NumLoweredVPOps;
      break;
    }
    Job.Strategy.OpStrategy = VPLegalization::Legal;

    assert(Job.isDone() && "incomplete transformation");
  }

  return true;
}

class ExpandVectorPredication : public FunctionPass {
public:
  static char ID;
  ExpandVectorPredication() : FunctionPass(ID) {
    initializeExpandVectorPredicationPass(*PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override {
    const auto *TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
    CachingVPExpander VPExpander(F, *TTI);
    return VPExpander.expandVectorPredication();
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetTransformInfoWrapperPass>();
    AU.setPreservesCFG();
  }
};
} // namespace

char ExpandVectorPredication::ID;
INITIALIZE_PASS_BEGIN(ExpandVectorPredication, "expandvp",
                      "Expand vector predication intrinsics", false, false)
INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_END(ExpandVectorPredication, "expandvp",
                    "Expand vector predication intrinsics", false, false)

FunctionPass *llvm::createExpandVectorPredicationPass() {
  return new ExpandVectorPredication();
}

PreservedAnalyses
ExpandVectorPredicationPass::run(Function &F, FunctionAnalysisManager &AM) {
  const auto &TTI = AM.getResult<TargetIRAnalysis>(F);
  CachingVPExpander VPExpander(F, TTI);
  if (!VPExpander.expandVectorPredication())
    return PreservedAnalyses::all();
  PreservedAnalyses PA;
  PA.preserveSet<CFGAnalyses>();
  return PA;
}