//===----- CodeGen/ExpandVectorPredication.cpp - Expand VP intrinsics -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass implements IR expansion for vector predication intrinsics, allowing
// targets to enable vector predication until just before codegen.
//
// A worked example of the expansion is sketched in a comment at the end of
// this file.
//
//===----------------------------------------------------------------------===//

#include "llvm/CodeGen/ExpandVectorPredication.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Compiler.h"
#include "llvm/Support/Debug.h"
#include <optional>

using namespace llvm;

using VPLegalization = TargetTransformInfo::VPLegalization;
using VPTransform = TargetTransformInfo::VPLegalization::VPTransform;

// Keep this in sync with TargetTransformInfo::VPLegalization.
#define VPINTERNAL_VPLEGAL_CASES                                               \
  VPINTERNAL_CASE(Legal)                                                       \
  VPINTERNAL_CASE(Discard)                                                     \
  VPINTERNAL_CASE(Convert)

#define VPINTERNAL_CASE(X) "|" #X

// Override options.
static cl::opt<std::string> EVLTransformOverride(
    "expandvp-override-evl-transform", cl::init(""), cl::Hidden,
    cl::desc("Options: <empty>" VPINTERNAL_VPLEGAL_CASES
             ". If non-empty, ignore "
             "TargetTransformInfo and "
             "always use this transformation for the %evl parameter (Used in "
             "testing)."));

static cl::opt<std::string> MaskTransformOverride(
    "expandvp-override-mask-transform", cl::init(""), cl::Hidden,
    cl::desc("Options: <empty>" VPINTERNAL_VPLEGAL_CASES
             ". If non-empty, ignore "
             "TargetTransformInfo and "
             "always use this transformation for the %mask parameter (Used in "
             "testing)."));

#undef VPINTERNAL_CASE
#define VPINTERNAL_CASE(X) .Case(#X, VPLegalization::X)

static VPTransform parseOverrideOption(const std::string &TextOpt) {
  return StringSwitch<VPTransform>(TextOpt) VPINTERNAL_VPLEGAL_CASES;
}

#undef VPINTERNAL_VPLEGAL_CASES

// Whether any override options are set.
static bool anyExpandVPOverridesSet() {
  return !EVLTransformOverride.empty() || !MaskTransformOverride.empty();
}

#define DEBUG_TYPE "expandvp"

STATISTIC(NumFoldedVL, "Number of folded vector length params");
STATISTIC(NumLoweredVPOps, "Number of lowered vector predication operations");

///// Helpers {

/// \returns Whether the vector mask \p MaskVal has all lane bits set.
static bool isAllTrueMask(Value *MaskVal) {
  if (Value *SplattedVal = getSplatValue(MaskVal))
    if (auto *ConstValue = dyn_cast<Constant>(SplattedVal))
      return ConstValue->isAllOnesValue();

  return false;
}

/// \returns A non-excepting divisor constant for this type.
static Constant *getSafeDivisor(Type *DivTy) {
  assert(DivTy->isIntOrIntVectorTy() && "Unsupported divisor type");
  return ConstantInt::get(DivTy, 1u, false);
}

/// Transfer operation properties from \p VPI to \p NewVal.
static void transferDecorations(Value &NewVal, VPIntrinsic &VPI) {
  auto *NewInst = dyn_cast<Instruction>(&NewVal);
  if (!NewInst || !isa<FPMathOperator>(NewVal))
    return;

  auto *OldFMOp = dyn_cast<FPMathOperator>(&VPI);
  if (!OldFMOp)
    return;

  NewInst->setFastMathFlags(OldFMOp->getFastMathFlags());
}

/// Transfer all properties from \p OldOp to \p NewOp and replace all uses.
/// \p OldOp gets erased.
static void replaceOperation(Value &NewOp, VPIntrinsic &OldOp) {
  transferDecorations(NewOp, OldOp);
  OldOp.replaceAllUsesWith(&NewOp);
  OldOp.eraseFromParent();
}

static bool maySpeculateLanes(VPIntrinsic &VPI) {
  // The result of VP reductions depends on the mask and evl.
  if (isa<VPReductionIntrinsic>(VPI))
    return false;
  // Fallback to whether the intrinsic is speculatable.
  std::optional<unsigned> OpcOpt = VPI.getFunctionalOpcode();
  unsigned FunctionalOpc = OpcOpt.value_or((unsigned)Instruction::Call);
  return isSafeToSpeculativelyExecuteWithOpcode(FunctionalOpc, &VPI);
}

//// } Helpers

namespace {

// Expansion pass state at function scope.
struct CachingVPExpander {
  Function &F;
  const TargetTransformInfo &TTI;

  /// \returns A (fixed length) vector with ascending integer indices
  /// (<0, 1, ..., NumElems-1>).
  /// \p Builder
  ///    Used for instruction creation.
  /// \p LaneTy
  ///    Integer element type of the result vector.
  /// \p NumElems
  ///    Number of vector elements.
  Value *createStepVector(IRBuilder<> &Builder, Type *LaneTy,
                          unsigned NumElems);

  /// \returns A bitmask that is true where the lane position is less-than
  /// \p EVLParam.
  ///
  /// \p Builder
  ///    Used for instruction creation.
  /// \p EVLParam
  ///    The explicit vector length parameter to test against the lane
  ///    positions.
  /// \p ElemCount
  ///    Static (potentially scalable) number of vector elements.
  Value *convertEVLToMask(IRBuilder<> &Builder, Value *EVLParam,
                          ElementCount ElemCount);

  Value *foldEVLIntoMask(VPIntrinsic &VPI);

  /// "Remove" the %evl parameter of \p PI by setting it to the static vector
  /// length of the operation.
  void discardEVLParameter(VPIntrinsic &PI);

  /// Lower this VP binary operator to an unpredicated binary operator.
  Value *expandPredicationInBinaryOperator(IRBuilder<> &Builder,
                                           VPIntrinsic &PI);

  /// Lower this VP fp call to an unpredicated fp call.
  Value *expandPredicationToFPCall(IRBuilder<> &Builder, VPIntrinsic &PI,
                                   unsigned UnpredicatedIntrinsicID);

  /// Lower this VP reduction to a call to an unpredicated reduction intrinsic.
  Value *expandPredicationInReduction(IRBuilder<> &Builder,
                                      VPReductionIntrinsic &PI);

  /// Lower this VP memory operation to a non-VP intrinsic.
  Value *expandPredicationInMemoryIntrinsic(IRBuilder<> &Builder,
                                            VPIntrinsic &VPI);

  /// Lower this VP comparison to a call to an unpredicated comparison.
  Value *expandPredicationInComparison(IRBuilder<> &Builder,
                                       VPCmpIntrinsic &PI);

  /// Query TTI and expand the vector predication in \p PI accordingly.
  Value *expandPredication(VPIntrinsic &PI);

  /// Determine how and whether the VPIntrinsic \p VPI shall be expanded. This
  /// overrides TTI with the cl::opts listed at the top of this file.
  VPLegalization getVPLegalizationStrategy(const VPIntrinsic &VPI) const;
  bool UsingTTIOverrides;

public:
  CachingVPExpander(Function &F, const TargetTransformInfo &TTI)
      : F(F), TTI(TTI), UsingTTIOverrides(anyExpandVPOverridesSet()) {}

  bool expandVectorPredication();
};

//// CachingVPExpander {

Value *CachingVPExpander::createStepVector(IRBuilder<> &Builder, Type *LaneTy,
                                           unsigned NumElems) {
  // TODO add caching
  SmallVector<Constant *, 16> ConstElems;

  for (unsigned Idx = 0; Idx < NumElems; ++Idx)
    ConstElems.push_back(ConstantInt::get(LaneTy, Idx, false));

  return ConstantVector::get(ConstElems);
}

Value *CachingVPExpander::convertEVLToMask(IRBuilder<> &Builder,
                                           Value *EVLParam,
                                           ElementCount ElemCount) {
  // TODO add caching
  // Scalable vector %evl conversion.
  if (ElemCount.isScalable()) {
    auto *M = Builder.GetInsertBlock()->getModule();
    Type *BoolVecTy = VectorType::get(Builder.getInt1Ty(), ElemCount);
    Function *ActiveMaskFunc = Intrinsic::getDeclaration(
        M, Intrinsic::get_active_lane_mask, {BoolVecTy, EVLParam->getType()});
    // `get_active_lane_mask` performs an implicit less-than comparison.
    Value *ConstZero = Builder.getInt32(0);
    return Builder.CreateCall(ActiveMaskFunc, {ConstZero, EVLParam});
  }

  // Fixed vector %evl conversion.
  Type *LaneTy = EVLParam->getType();
  unsigned NumElems = ElemCount.getFixedValue();
  Value *VLSplat = Builder.CreateVectorSplat(NumElems, EVLParam);
  Value *IdxVec = createStepVector(Builder, LaneTy, NumElems);
  return Builder.CreateICmp(CmpInst::ICMP_ULT, IdxVec, VLSplat);
}

Value *
CachingVPExpander::expandPredicationInBinaryOperator(IRBuilder<> &Builder,
                                                     VPIntrinsic &VPI) {
  assert((maySpeculateLanes(VPI) || VPI.canIgnoreVectorLengthParam()) &&
         "Implicitly dropping %evl in non-speculatable operator!");

  auto OC = static_cast<Instruction::BinaryOps>(*VPI.getFunctionalOpcode());
  assert(Instruction::isBinaryOp(OC));

  Value *Op0 = VPI.getOperand(0);
  Value *Op1 = VPI.getOperand(1);
  Value *Mask = VPI.getMaskParam();

  // Blend in safe operands.
  if (Mask && !isAllTrueMask(Mask)) {
    switch (OC) {
    default:
      // Can safely ignore the predicate.
      break;

    // Division operators need a safe divisor on masked-off lanes (1).
    case Instruction::UDiv:
    case Instruction::SDiv:
    case Instruction::URem:
    case Instruction::SRem:
      // 2nd operand must not be zero.
      Value *SafeDivisor = getSafeDivisor(VPI.getType());
      Op1 = Builder.CreateSelect(Mask, Op1, SafeDivisor);
    }
  }

  Value *NewBinOp = Builder.CreateBinOp(OC, Op0, Op1, VPI.getName());

  replaceOperation(*NewBinOp, VPI);
  return NewBinOp;
}

Value *CachingVPExpander::expandPredicationToFPCall(
    IRBuilder<> &Builder, VPIntrinsic &VPI, unsigned UnpredicatedIntrinsicID) {
  assert((maySpeculateLanes(VPI) || VPI.canIgnoreVectorLengthParam()) &&
         "Implicitly dropping %evl in non-speculatable operator!");

  switch (UnpredicatedIntrinsicID) {
  case Intrinsic::fabs:
  case Intrinsic::sqrt: {
    Value *Op0 = VPI.getOperand(0);
    Function *Fn = Intrinsic::getDeclaration(
        VPI.getModule(), UnpredicatedIntrinsicID, {VPI.getType()});
    Value *NewOp = Builder.CreateCall(Fn, {Op0}, VPI.getName());
    replaceOperation(*NewOp, VPI);
    return NewOp;
  }
  case Intrinsic::experimental_constrained_fma:
  case Intrinsic::experimental_constrained_fmuladd: {
    Value *Op0 = VPI.getOperand(0);
    Value *Op1 = VPI.getOperand(1);
    Value *Op2 = VPI.getOperand(2);
    Function *Fn = Intrinsic::getDeclaration(
        VPI.getModule(), UnpredicatedIntrinsicID, {VPI.getType()});
    Value *NewOp =
        Builder.CreateConstrainedFPCall(Fn, {Op0, Op1, Op2}, VPI.getName());
    replaceOperation(*NewOp, VPI);
    return NewOp;
  }
  }

  return nullptr;
}

/// \returns The neutral element of the reduction \p VPI, used to fill
/// masked-off lanes before calling the unpredicated reduction intrinsic.
static Value *getNeutralReductionElement(const VPReductionIntrinsic &VPI,
                                         Type *EltTy) {
  bool Negative = false;
  unsigned EltBits = EltTy->getScalarSizeInBits();
  switch (VPI.getIntrinsicID()) {
  default:
    llvm_unreachable("Expecting a VP reduction intrinsic");
  case Intrinsic::vp_reduce_add:
  case Intrinsic::vp_reduce_or:
  case Intrinsic::vp_reduce_xor:
  case Intrinsic::vp_reduce_umax:
    return Constant::getNullValue(EltTy);
  case Intrinsic::vp_reduce_mul:
    return ConstantInt::get(EltTy, 1, /*IsSigned*/ false);
  case Intrinsic::vp_reduce_and:
  case Intrinsic::vp_reduce_umin:
    return ConstantInt::getAllOnesValue(EltTy);
  case Intrinsic::vp_reduce_smin:
    return ConstantInt::get(EltTy->getContext(),
                            APInt::getSignedMaxValue(EltBits));
  case Intrinsic::vp_reduce_smax:
    return ConstantInt::get(EltTy->getContext(),
                            APInt::getSignedMinValue(EltBits));
  case Intrinsic::vp_reduce_fmax:
    Negative = true;
    [[fallthrough]];
  case Intrinsic::vp_reduce_fmin: {
    FastMathFlags Flags = VPI.getFastMathFlags();
    const fltSemantics &Semantics = EltTy->getFltSemantics();
    return !Flags.noNaNs() ? ConstantFP::getQNaN(EltTy, Negative)
           : !Flags.noInfs() ?
                 ConstantFP::getInfinity(EltTy, Negative)
               : ConstantFP::get(EltTy,
                                 APFloat::getLargest(Semantics, Negative));
  }
  case Intrinsic::vp_reduce_fadd:
    return ConstantFP::getNegativeZero(EltTy);
  case Intrinsic::vp_reduce_fmul:
    return ConstantFP::get(EltTy, 1.0);
  }
}

Value *
CachingVPExpander::expandPredicationInReduction(IRBuilder<> &Builder,
                                                VPReductionIntrinsic &VPI) {
  assert((maySpeculateLanes(VPI) || VPI.canIgnoreVectorLengthParam()) &&
         "Implicitly dropping %evl in non-speculatable operator!");

  Value *Mask = VPI.getMaskParam();
  Value *RedOp = VPI.getOperand(VPI.getVectorParamPos());

  // Insert neutral element in masked-out positions.
  if (Mask && !isAllTrueMask(Mask)) {
    auto *NeutralElt = getNeutralReductionElement(VPI, VPI.getType());
    auto *NeutralVector = Builder.CreateVectorSplat(
        cast<VectorType>(RedOp->getType())->getElementCount(), NeutralElt);
    RedOp = Builder.CreateSelect(Mask, RedOp, NeutralVector);
  }

  Value *Reduction;
  Value *Start = VPI.getOperand(VPI.getStartParamPos());

  switch (VPI.getIntrinsicID()) {
  default:
    llvm_unreachable("Impossible reduction kind");
  case Intrinsic::vp_reduce_add:
    Reduction = Builder.CreateAddReduce(RedOp);
    Reduction = Builder.CreateAdd(Reduction, Start);
    break;
  case Intrinsic::vp_reduce_mul:
    Reduction = Builder.CreateMulReduce(RedOp);
    Reduction = Builder.CreateMul(Reduction, Start);
    break;
  case Intrinsic::vp_reduce_and:
    Reduction = Builder.CreateAndReduce(RedOp);
    Reduction = Builder.CreateAnd(Reduction, Start);
    break;
  case Intrinsic::vp_reduce_or:
    Reduction = Builder.CreateOrReduce(RedOp);
    Reduction = Builder.CreateOr(Reduction, Start);
    break;
  case Intrinsic::vp_reduce_xor:
    Reduction = Builder.CreateXorReduce(RedOp);
    Reduction = Builder.CreateXor(Reduction, Start);
    break;
  case Intrinsic::vp_reduce_smax:
    Reduction = Builder.CreateIntMaxReduce(RedOp, /*IsSigned*/ true);
    Reduction =
        Builder.CreateBinaryIntrinsic(Intrinsic::smax, Reduction, Start);
    break;
  case Intrinsic::vp_reduce_smin:
    Reduction = Builder.CreateIntMinReduce(RedOp, /*IsSigned*/ true);
    Reduction =
        Builder.CreateBinaryIntrinsic(Intrinsic::smin, Reduction, Start);
    break;
  case Intrinsic::vp_reduce_umax:
    Reduction = Builder.CreateIntMaxReduce(RedOp, /*IsSigned*/ false);
    Reduction =
        Builder.CreateBinaryIntrinsic(Intrinsic::umax, Reduction, Start);
    break;
  case Intrinsic::vp_reduce_umin:
    Reduction = Builder.CreateIntMinReduce(RedOp, /*IsSigned*/ false);
    Reduction =
        Builder.CreateBinaryIntrinsic(Intrinsic::umin, Reduction, Start);
    break;
  case Intrinsic::vp_reduce_fmax:
    Reduction = Builder.CreateFPMaxReduce(RedOp);
    transferDecorations(*Reduction, VPI);
    Reduction =
        Builder.CreateBinaryIntrinsic(Intrinsic::maxnum, Reduction, Start);
    break;
  case Intrinsic::vp_reduce_fmin:
    Reduction = Builder.CreateFPMinReduce(RedOp);
    transferDecorations(*Reduction, VPI);
    Reduction =
        Builder.CreateBinaryIntrinsic(Intrinsic::minnum, Reduction, Start);
    break;
  case Intrinsic::vp_reduce_fadd:
    Reduction = Builder.CreateFAddReduce(Start, RedOp);
    break;
  case Intrinsic::vp_reduce_fmul:
    Reduction = Builder.CreateFMulReduce(Start, RedOp);
    break;
  }

  replaceOperation(*Reduction, VPI);
  return Reduction;
}

Value *
CachingVPExpander::expandPredicationInMemoryIntrinsic(IRBuilder<> &Builder,
                                                      VPIntrinsic &VPI) {
  assert(VPI.canIgnoreVectorLengthParam());

  const auto &DL = F.getParent()->getDataLayout();

  Value *MaskParam = VPI.getMaskParam();
  Value *PtrParam = VPI.getMemoryPointerParam();
  Value *DataParam = VPI.getMemoryDataParam();
  bool IsUnmasked = isAllTrueMask(MaskParam);

  MaybeAlign AlignOpt = VPI.getPointerAlignment();

  Value *NewMemoryInst = nullptr;
  switch (VPI.getIntrinsicID()) {
  default:
    llvm_unreachable("Not a VP memory intrinsic");
  case Intrinsic::vp_store:
    if (IsUnmasked) {
      StoreInst *NewStore =
          Builder.CreateStore(DataParam, PtrParam, /*IsVolatile*/ false);
      if (AlignOpt.has_value())
        NewStore->setAlignment(*AlignOpt);
      NewMemoryInst = NewStore;
    } else
      NewMemoryInst = Builder.CreateMaskedStore(
          DataParam, PtrParam, AlignOpt.valueOrOne(), MaskParam);

    break;
  case Intrinsic::vp_load:
    if (IsUnmasked) {
      LoadInst *NewLoad =
          Builder.CreateLoad(VPI.getType(), PtrParam, /*IsVolatile*/ false);
      if (AlignOpt.has_value())
        NewLoad->setAlignment(*AlignOpt);
      NewMemoryInst = NewLoad;
    } else
      NewMemoryInst = Builder.CreateMaskedLoad(
          VPI.getType(), PtrParam, AlignOpt.valueOrOne(), MaskParam);

    break;
  case Intrinsic::vp_scatter: {
    auto *ElementType =
        cast<VectorType>(DataParam->getType())->getElementType();
    NewMemoryInst = Builder.CreateMaskedScatter(
        DataParam, PtrParam,
        AlignOpt.value_or(DL.getPrefTypeAlign(ElementType)), MaskParam);
    break;
  }
  case Intrinsic::vp_gather: {
    auto *ElementType = cast<VectorType>(VPI.getType())->getElementType();
    NewMemoryInst = Builder.CreateMaskedGather(
        VPI.getType(), PtrParam,
        AlignOpt.value_or(DL.getPrefTypeAlign(ElementType)), MaskParam, nullptr,
        VPI.getName());
    break;
  }
  }

  assert(NewMemoryInst);
  replaceOperation(*NewMemoryInst, VPI);
  return NewMemoryInst;
}

Value *CachingVPExpander::expandPredicationInComparison(IRBuilder<> &Builder,
                                                        VPCmpIntrinsic &VPI) {
  assert((maySpeculateLanes(VPI) || VPI.canIgnoreVectorLengthParam()) &&
         "Implicitly dropping %evl in non-speculatable operator!");

  assert(*VPI.getFunctionalOpcode() == Instruction::ICmp ||
         *VPI.getFunctionalOpcode() == Instruction::FCmp);

  Value *Op0 = VPI.getOperand(0);
  Value *Op1 = VPI.getOperand(1);
  auto Pred = VPI.getPredicate();

  auto *NewCmp = Builder.CreateCmp(Pred, Op0, Op1);

  replaceOperation(*NewCmp, VPI);
  return NewCmp;
}

void CachingVPExpander::discardEVLParameter(VPIntrinsic &VPI) {
  LLVM_DEBUG(dbgs() << "Discard EVL parameter in " << VPI << "\n");

  if (VPI.canIgnoreVectorLengthParam())
    return;

  Value *EVLParam = VPI.getVectorLengthParam();
  if (!EVLParam)
    return;

  ElementCount StaticElemCount = VPI.getStaticVectorLength();
  Value *MaxEVL = nullptr;
  Type *Int32Ty = Type::getInt32Ty(VPI.getContext());
  if (StaticElemCount.isScalable()) {
    // TODO add caching
    auto *M = VPI.getModule();
    Function *VScaleFunc =
        Intrinsic::getDeclaration(M, Intrinsic::vscale, Int32Ty);
    IRBuilder<> Builder(VPI.getParent(), VPI.getIterator());
    Value *FactorConst = Builder.getInt32(StaticElemCount.getKnownMinValue());
    Value *VScale = Builder.CreateCall(VScaleFunc, {}, "vscale");
    MaxEVL = Builder.CreateMul(VScale, FactorConst, "scalable_size",
                              /*NUW*/ true, /*NSW*/ false);
  } else {
    MaxEVL = ConstantInt::get(Int32Ty, StaticElemCount.getFixedValue(), false);
  }
  VPI.setVectorLengthParam(MaxEVL);
}

Value *CachingVPExpander::foldEVLIntoMask(VPIntrinsic &VPI) {
  LLVM_DEBUG(dbgs() << "Folding vlen for " << VPI << '\n');

  IRBuilder<> Builder(&VPI);

  // Ineffective %evl parameter, so nothing to do here.
  if (VPI.canIgnoreVectorLengthParam())
    return &VPI;

  // Only VP intrinsics can have an %evl parameter.
  Value *OldMaskParam = VPI.getMaskParam();
  Value *OldEVLParam = VPI.getVectorLengthParam();
  assert(OldMaskParam && "no mask param to fold the vl param into");
  assert(OldEVLParam && "no EVL param to fold away");

  LLVM_DEBUG(dbgs() << "OLD evl: " << *OldEVLParam << '\n');
  LLVM_DEBUG(dbgs() << "OLD mask: " << *OldMaskParam << '\n');

  // Convert the %evl predication into vector mask predication.
  ElementCount ElemCount = VPI.getStaticVectorLength();
  Value *VLMask = convertEVLToMask(Builder, OldEVLParam, ElemCount);
  Value *NewMaskParam = Builder.CreateAnd(VLMask, OldMaskParam);
  VPI.setMaskParam(NewMaskParam);

  // Drop the %evl parameter.
  discardEVLParameter(VPI);
  assert(VPI.canIgnoreVectorLengthParam() &&
         "transformation did not render the evl param ineffective!");

  // Reassess the modified instruction.
  return &VPI;
}

Value *CachingVPExpander::expandPredication(VPIntrinsic &VPI) {
  LLVM_DEBUG(dbgs() << "Lowering to unpredicated op: " << VPI << '\n');

  IRBuilder<> Builder(&VPI);

  // Try lowering to an LLVM instruction first.
  auto OC = VPI.getFunctionalOpcode();

  if (OC && Instruction::isBinaryOp(*OC))
    return expandPredicationInBinaryOperator(Builder, VPI);

  if (auto *VPRI = dyn_cast<VPReductionIntrinsic>(&VPI))
    return expandPredicationInReduction(Builder, *VPRI);

  if (auto *VPCmp = dyn_cast<VPCmpIntrinsic>(&VPI))
    return expandPredicationInComparison(Builder, *VPCmp);

  switch (VPI.getIntrinsicID()) {
  default:
    break;
  case Intrinsic::vp_fneg: {
    Value *NewNegOp = Builder.CreateFNeg(VPI.getOperand(0), VPI.getName());
    replaceOperation(*NewNegOp, VPI);
    return NewNegOp;
  }
  case Intrinsic::vp_fabs:
    return expandPredicationToFPCall(Builder, VPI, Intrinsic::fabs);
  case Intrinsic::vp_sqrt:
    return expandPredicationToFPCall(Builder, VPI, Intrinsic::sqrt);
  case Intrinsic::vp_load:
  case Intrinsic::vp_store:
  case Intrinsic::vp_gather:
  case Intrinsic::vp_scatter:
    return expandPredicationInMemoryIntrinsic(Builder, VPI);
  }

  if (auto CID = VPI.getConstrainedIntrinsicID())
    if (Value *Call = expandPredicationToFPCall(Builder, VPI, *CID))
      return Call;

  return &VPI;
}

//// } CachingVPExpander

struct TransformJob {
  VPIntrinsic *PI;
  TargetTransformInfo::VPLegalization Strategy;
  TransformJob(VPIntrinsic *PI, TargetTransformInfo::VPLegalization InitStrat)
      : PI(PI), Strategy(InitStrat) {}

  bool isDone() const { return Strategy.shouldDoNothing(); }
};

void sanitizeStrategy(VPIntrinsic &VPI, VPLegalization &LegalizeStrat) {
  // Operations with speculatable lanes do not strictly need predication.
  if (maySpeculateLanes(VPI)) {
    // Converting a speculatable VP intrinsic means dropping %mask and %evl.
    // No need to expand %evl into the %mask only to ignore that code.
    if (LegalizeStrat.OpStrategy == VPLegalization::Convert)
      LegalizeStrat.EVLParamStrategy = VPLegalization::Discard;
    return;
  }

  // We have to preserve the predicating effect of %evl for this
  // non-speculatable VP intrinsic.
  // 1) Never discard %evl.
  // 2) If this VP intrinsic will be expanded to non-VP code, make sure that
  //    %evl gets folded into %mask.
  if ((LegalizeStrat.EVLParamStrategy == VPLegalization::Discard) ||
      (LegalizeStrat.OpStrategy == VPLegalization::Convert)) {
    LegalizeStrat.EVLParamStrategy = VPLegalization::Convert;
  }
}

VPLegalization
CachingVPExpander::getVPLegalizationStrategy(const VPIntrinsic &VPI) const {
  auto VPStrat = TTI.getVPLegalizationStrategy(VPI);
  if (LLVM_LIKELY(!UsingTTIOverrides)) {
    // No overrides - we are in production.
    return VPStrat;
  }

  // Overrides set - we are in testing; the following does not need to be
  // efficient.
  VPStrat.EVLParamStrategy = parseOverrideOption(EVLTransformOverride);
  VPStrat.OpStrategy = parseOverrideOption(MaskTransformOverride);
  return VPStrat;
}

/// Expand llvm.vp.* intrinsics as requested by \p TTI.
bool CachingVPExpander::expandVectorPredication() {
  SmallVector<TransformJob, 16> Worklist;

  // Collect all VPIntrinsics that need expansion and determine their expansion
  // strategy.
  for (auto &I : instructions(F)) {
    auto *VPI = dyn_cast<VPIntrinsic>(&I);
    if (!VPI)
      continue;
    auto VPStrat = getVPLegalizationStrategy(*VPI);
    sanitizeStrategy(*VPI, VPStrat);
    if (!VPStrat.shouldDoNothing())
      Worklist.emplace_back(VPI, VPStrat);
  }
  if (Worklist.empty())
    return false;

  // Transform all VPIntrinsics on the worklist.
  LLVM_DEBUG(dbgs() << "\n:::: Transforming " << Worklist.size()
                    << " instructions ::::\n");
  for (TransformJob Job : Worklist) {
    // Transform the EVL parameter.
    switch (Job.Strategy.EVLParamStrategy) {
    case VPLegalization::Legal:
      break;
    case VPLegalization::Discard:
      discardEVLParameter(*Job.PI);
      break;
    case VPLegalization::Convert:
      if (foldEVLIntoMask(*Job.PI))
        ++NumFoldedVL;
      break;
    }
    Job.Strategy.EVLParamStrategy = VPLegalization::Legal;

    // Replace with a non-predicated operation.
    switch (Job.Strategy.OpStrategy) {
    case VPLegalization::Legal:
      break;
    case VPLegalization::Discard:
      llvm_unreachable("Invalid strategy for operators.");
    case VPLegalization::Convert:
      expandPredication(*Job.PI);
      ++NumLoweredVPOps;
      break;
    }
    Job.Strategy.OpStrategy = VPLegalization::Legal;

    assert(Job.isDone() && "incomplete transformation");
  }

  return true;
}

class ExpandVectorPredication : public FunctionPass {
public:
  static char ID;
  ExpandVectorPredication() : FunctionPass(ID) {
    initializeExpandVectorPredicationPass(*PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override {
    const auto *TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
    CachingVPExpander VPExpander(F, *TTI);
    return VPExpander.expandVectorPredication();
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetTransformInfoWrapperPass>();
    AU.setPreservesCFG();
  }
};
} // namespace

char ExpandVectorPredication::ID;
INITIALIZE_PASS_BEGIN(ExpandVectorPredication, "expandvp",
                      "Expand vector predication intrinsics", false, false)
INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_END(ExpandVectorPredication, "expandvp",
                    "Expand vector predication intrinsics", false, false)

FunctionPass *llvm::createExpandVectorPredicationPass() {
  return new ExpandVectorPredication();
}

PreservedAnalyses
ExpandVectorPredicationPass::run(Function &F, FunctionAnalysisManager &AM) {
  const auto &TTI = AM.getResult<TargetIRAnalysis>(F);
  CachingVPExpander VPExpander(F, TTI);
  if (!VPExpander.expandVectorPredication())
    return PreservedAnalyses::all();
  PreservedAnalyses PA;
  PA.preserveSet<CFGAnalyses>();
  return PA;
}
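
//===----------------------------------------------------------------------===//
// Worked example (illustrative only; the value names and the assumed target
// legalization are made up for exposition, not produced verbatim by this
// pass). Suppose TTI requests {EVLParamStrategy: Convert, OpStrategy: Convert}
// for
//
//   %q = call <4 x i32> @llvm.vp.sdiv.v4i32(<4 x i32> %a, <4 x i32> %b,
//                                           <4 x i1> %m, i32 %evl)
//
// foldEVLIntoMask() first turns the %evl predication into mask predication by
// and-ing the lane-index compare into the original mask:
//
//   %evlmask = icmp ult <4 x i32> <i32 0, i32 1, i32 2, i32 3>, %evl.splat
//   %newmask = and <4 x i1> %evlmask, %m
//
// and resets the %evl operand of the call to the static vector length (4).
// expandPredication() then lowers the division; because sdiv may trap on
// masked-off lanes, a safe divisor is blended in before the unpredicated
// instruction is emitted:
//
//   %safediv = select <4 x i1> %newmask, <4 x i32> %b,
//                     <4 x i32> <i32 1, i32 1, i32 1, i32 1>
//   %q = sdiv <4 x i32> %a, %safediv
//===----------------------------------------------------------------------===//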