//===----- CodeGen/ExpandVectorPredication.cpp - Expand VP intrinsics -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass implements IR expansion for vector predication intrinsics, allowing
// targets to enable vector predication until just before codegen.
//
//===----------------------------------------------------------------------===//

#include "llvm/CodeGen/ExpandVectorPredication.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Compiler.h"
#include "llvm/Support/Debug.h"
#include <optional>

using namespace llvm;

using VPLegalization = TargetTransformInfo::VPLegalization;
using VPTransform = TargetTransformInfo::VPLegalization::VPTransform;

// Keep this in sync with TargetTransformInfo::VPLegalization.
#define VPINTERNAL_VPLEGAL_CASES                                               \
  VPINTERNAL_CASE(Legal)                                                       \
  VPINTERNAL_CASE(Discard)                                                     \
  VPINTERNAL_CASE(Convert)

#define VPINTERNAL_CASE(X) "|" #X

// Override options.
static cl::opt<std::string> EVLTransformOverride(
    "expandvp-override-evl-transform", cl::init(""), cl::Hidden,
    cl::desc("Options: <empty>" VPINTERNAL_VPLEGAL_CASES
             ". If non-empty, ignore "
             "TargetTransformInfo and "
             "always use this transformation for the %evl parameter (used in "
             "testing)."));

static cl::opt<std::string> MaskTransformOverride(
    "expandvp-override-mask-transform", cl::init(""), cl::Hidden,
    cl::desc("Options: <empty>" VPINTERNAL_VPLEGAL_CASES
             ". If non-empty, ignore "
             "TargetTransformInfo and "
             "always use this transformation for the %mask parameter (used in "
             "testing)."));

#undef VPINTERNAL_CASE
#define VPINTERNAL_CASE(X) .Case(#X, VPLegalization::X)

static VPTransform parseOverrideOption(const std::string &TextOpt) {
  return StringSwitch<VPTransform>(TextOpt) VPINTERNAL_VPLEGAL_CASES;
}

#undef VPINTERNAL_VPLEGAL_CASES

// Whether any override options are set.
static bool anyExpandVPOverridesSet() {
  return !EVLTransformOverride.empty() || !MaskTransformOverride.empty();
}
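
// Illustrative use of the override flags above (testing only; the value must
// be one of Legal, Discard or Convert). For example, passing
//   -expandvp-override-evl-transform=Convert
//   -expandvp-override-mask-transform=Convert
// requests full expansion of every VP intrinsic in the function, regardless of
// what TargetTransformInfo reports for the target.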

#define DEBUG_TYPE "expandvp"

STATISTIC(NumFoldedVL, "Number of folded vector length params");
STATISTIC(NumLoweredVPOps, "Number of vector predication operations lowered");

///// Helpers {

/// \returns Whether the vector mask \p MaskVal has all lane bits set.
static bool isAllTrueMask(Value *MaskVal) {
  if (Value *SplattedVal = getSplatValue(MaskVal))
    if (auto *ConstValue = dyn_cast<Constant>(SplattedVal))
      return ConstValue->isAllOnesValue();

  return false;
}

/// \returns A non-excepting divisor constant for this type.
static Constant *getSafeDivisor(Type *DivTy) {
  assert(DivTy->isIntOrIntVectorTy() && "Unsupported divisor type");
  return ConstantInt::get(DivTy, 1u, false);
}

/// Transfer operation properties from \p VPI to \p NewVal.
static void transferDecorations(Value &NewVal, VPIntrinsic &VPI) {
  auto *NewInst = dyn_cast<Instruction>(&NewVal);
  if (!NewInst || !isa<FPMathOperator>(NewVal))
    return;

  auto *OldFMOp = dyn_cast<FPMathOperator>(&VPI);
  if (!OldFMOp)
    return;

  NewInst->setFastMathFlags(OldFMOp->getFastMathFlags());
}

/// Transfer all properties from \p OldOp to \p NewOp and replace all uses.
/// \p OldOp gets erased.
static void replaceOperation(Value &NewOp, VPIntrinsic &OldOp) {
  transferDecorations(NewOp, OldOp);
  OldOp.replaceAllUsesWith(&NewOp);
  OldOp.eraseFromParent();
}

static bool maySpeculateLanes(VPIntrinsic &VPI) {
  // The result of VP reductions depends on the mask and evl.
  if (isa<VPReductionIntrinsic>(VPI))
    return false;
  // Fallback to whether the intrinsic is speculatable.
  if (auto IntrID = VPI.getFunctionalIntrinsicID())
    return Intrinsic::getAttributes(VPI.getContext(), *IntrID)
        .hasFnAttr(Attribute::AttrKind::Speculatable);
  if (auto Opc = VPI.getFunctionalOpcode())
    return isSafeToSpeculativelyExecuteWithOpcode(*Opc, &VPI);
  return false;
}

//// } Helpers

namespace {

// Expansion pass state at function scope.
struct CachingVPExpander {
  Function &F;
  const TargetTransformInfo &TTI;

  /// \returns A (fixed length) vector with ascending integer indices
  /// (<0, 1, ..., NumElems-1>).
  /// \p Builder
  ///    Used for instruction creation.
  /// \p LaneTy
  ///    Integer element type of the result vector.
  /// \p NumElems
  ///    Number of vector elements.
  Value *createStepVector(IRBuilder<> &Builder, Type *LaneTy,
                          unsigned NumElems);

  /// \returns A bitmask that is true where the lane position is less-than
  /// \p EVLParam.
  ///
  /// \p Builder
  ///    Used for instruction creation.
  /// \p EVLParam
  ///    The explicit vector length parameter to test against the lane
  ///    positions.
  /// \p ElemCount
  ///    Static (potentially scalable) number of vector elements.
  Value *convertEVLToMask(IRBuilder<> &Builder, Value *EVLParam,
                          ElementCount ElemCount);

  Value *foldEVLIntoMask(VPIntrinsic &VPI);

  /// "Remove" the %evl parameter of \p PI by setting it to the static vector
  /// length of the operation.
  void discardEVLParameter(VPIntrinsic &PI);

  /// Lower this VP binary operator to an unpredicated binary operator.
  Value *expandPredicationInBinaryOperator(IRBuilder<> &Builder,
                                           VPIntrinsic &PI);

  /// Lower this VP int call to an unpredicated int call.
  Value *expandPredicationToIntCall(IRBuilder<> &Builder, VPIntrinsic &PI,
                                    unsigned UnpredicatedIntrinsicID);

  /// Lower this VP fp call to an unpredicated fp call.
  Value *expandPredicationToFPCall(IRBuilder<> &Builder, VPIntrinsic &PI,
                                   unsigned UnpredicatedIntrinsicID);

  /// Lower this VP reduction to a call to an unpredicated reduction intrinsic.
  Value *expandPredicationInReduction(IRBuilder<> &Builder,
                                      VPReductionIntrinsic &PI);

  /// Lower this VP cast operation to an unpredicated cast instruction.
  Value *expandPredicationToCastIntrinsic(IRBuilder<> &Builder,
                                          VPIntrinsic &VPI);

  /// Lower this VP memory operation to a non-VP intrinsic.
  Value *expandPredicationInMemoryIntrinsic(IRBuilder<> &Builder,
                                            VPIntrinsic &VPI);

  /// Lower this VP comparison to an unpredicated comparison.
  Value *expandPredicationInComparison(IRBuilder<> &Builder,
                                       VPCmpIntrinsic &PI);

  /// Query TTI and expand the vector predication in \p PI accordingly.
  Value *expandPredication(VPIntrinsic &PI);

  /// Determine how and whether the VPIntrinsic \p VPI shall be expanded. This
  /// overrides TTI with the cl::opts listed at the top of this file.
  VPLegalization getVPLegalizationStrategy(const VPIntrinsic &VPI) const;
  bool UsingTTIOverrides;

public:
  CachingVPExpander(Function &F, const TargetTransformInfo &TTI)
      : F(F), TTI(TTI), UsingTTIOverrides(anyExpandVPOverridesSet()) {}

  bool expandVectorPredication();
};

//// CachingVPExpander {

Value *CachingVPExpander::createStepVector(IRBuilder<> &Builder, Type *LaneTy,
                                           unsigned NumElems) {
  // TODO add caching
  SmallVector<Constant *, 16> ConstElems;

  for (unsigned Idx = 0; Idx < NumElems; ++Idx)
    ConstElems.push_back(ConstantInt::get(LaneTy, Idx, false));

  return ConstantVector::get(ConstElems);
}

Value *CachingVPExpander::convertEVLToMask(IRBuilder<> &Builder,
                                           Value *EVLParam,
                                           ElementCount ElemCount) {
  // TODO add caching
  // Scalable vector %evl conversion.
  if (ElemCount.isScalable()) {
    auto *M = Builder.GetInsertBlock()->getModule();
    Type *BoolVecTy = VectorType::get(Builder.getInt1Ty(), ElemCount);
    Function *ActiveMaskFunc = Intrinsic::getDeclaration(
        M, Intrinsic::get_active_lane_mask, {BoolVecTy, EVLParam->getType()});
    // `get_active_lane_mask` performs an implicit less-than comparison.
    Value *ConstZero = Builder.getInt32(0);
    return Builder.CreateCall(ActiveMaskFunc, {ConstZero, EVLParam});
  }

  // Fixed vector %evl conversion.
  Type *LaneTy = EVLParam->getType();
  unsigned NumElems = ElemCount.getFixedValue();
  Value *VLSplat = Builder.CreateVectorSplat(NumElems, EVLParam);
  Value *IdxVec = createStepVector(Builder, LaneTy, NumElems);
  return Builder.CreateICmp(CmpInst::ICMP_ULT, IdxVec, VLSplat);
}
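
// Sketch of the masks convertEVLToMask() produces (illustrative IR, assuming a
// 4-element vector and an i32 %evl):
//
//   ; fixed-width: compare a step vector against the splatted %evl
//   %splat = ... splat of %evl ...
//   %mask  = icmp ult <4 x i32> <i32 0, i32 1, i32 2, i32 3>, %splat
//
//   ; scalable: defer the same less-than comparison to an intrinsic
//   %mask  = call <vscale x 4 x i1>
//            @llvm.get.active.lane.mask.nxv4i1.i32(i32 0, i32 %evl)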

Value *
CachingVPExpander::expandPredicationInBinaryOperator(IRBuilder<> &Builder,
                                                     VPIntrinsic &VPI) {
  assert((maySpeculateLanes(VPI) || VPI.canIgnoreVectorLengthParam()) &&
         "Implicitly dropping %evl in non-speculatable operator!");

  auto OC = static_cast<Instruction::BinaryOps>(*VPI.getFunctionalOpcode());
  assert(Instruction::isBinaryOp(OC));

  Value *Op0 = VPI.getOperand(0);
  Value *Op1 = VPI.getOperand(1);
  Value *Mask = VPI.getMaskParam();

  // Blend in safe operands.
  if (Mask && !isAllTrueMask(Mask)) {
    switch (OC) {
    default:
      // Can safely ignore the predicate.
      break;

    // Division operators need a safe divisor on masked-off lanes (1).
    case Instruction::UDiv:
    case Instruction::SDiv:
    case Instruction::URem:
    case Instruction::SRem:
      // 2nd operand must not be zero.
      Value *SafeDivisor = getSafeDivisor(VPI.getType());
      Op1 = Builder.CreateSelect(Mask, Op1, SafeDivisor);
    }
  }

  Value *NewBinOp = Builder.CreateBinOp(OC, Op0, Op1, VPI.getName());

  replaceOperation(*NewBinOp, VPI);
  return NewBinOp;
}
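
// Illustrative expansion of a masked integer division (a sketch, not taken
// from a test): because udiv may trap on masked-off lanes, the divisor is
// first blended with the safe value 1 before the predicate is dropped:
//
//   %r = call <4 x i32> @llvm.vp.udiv.v4i32(<4 x i32> %a, <4 x i32> %b,
//                                           <4 x i1> %m, i32 %evl)
// becomes (once %evl is no longer effective)
//   %safe = select <4 x i1> %m, <4 x i32> %b, <4 x i32> <i32 1, i32 1, i32 1, i32 1>
//   %r    = udiv <4 x i32> %a, %safe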

Value *CachingVPExpander::expandPredicationToIntCall(
    IRBuilder<> &Builder, VPIntrinsic &VPI, unsigned UnpredicatedIntrinsicID) {
  switch (UnpredicatedIntrinsicID) {
  case Intrinsic::abs:
  case Intrinsic::smax:
  case Intrinsic::smin:
  case Intrinsic::umax:
  case Intrinsic::umin: {
    Value *Op0 = VPI.getOperand(0);
    Value *Op1 = VPI.getOperand(1);
    Function *Fn = Intrinsic::getDeclaration(
        VPI.getModule(), UnpredicatedIntrinsicID, {VPI.getType()});
    Value *NewOp = Builder.CreateCall(Fn, {Op0, Op1}, VPI.getName());
    replaceOperation(*NewOp, VPI);
    return NewOp;
  }
  case Intrinsic::bswap:
  case Intrinsic::bitreverse: {
    Value *Op = VPI.getOperand(0);
    Function *Fn = Intrinsic::getDeclaration(
        VPI.getModule(), UnpredicatedIntrinsicID, {VPI.getType()});
    Value *NewOp = Builder.CreateCall(Fn, {Op}, VPI.getName());
    replaceOperation(*NewOp, VPI);
    return NewOp;
  }
  }
  return nullptr;
}

Value *CachingVPExpander::expandPredicationToFPCall(
    IRBuilder<> &Builder, VPIntrinsic &VPI, unsigned UnpredicatedIntrinsicID) {
  assert((maySpeculateLanes(VPI) || VPI.canIgnoreVectorLengthParam()) &&
         "Implicitly dropping %evl in non-speculatable operator!");

  switch (UnpredicatedIntrinsicID) {
  case Intrinsic::fabs:
  case Intrinsic::sqrt: {
    Value *Op0 = VPI.getOperand(0);
    Function *Fn = Intrinsic::getDeclaration(
        VPI.getModule(), UnpredicatedIntrinsicID, {VPI.getType()});
    Value *NewOp = Builder.CreateCall(Fn, {Op0}, VPI.getName());
    replaceOperation(*NewOp, VPI);
    return NewOp;
  }
  case Intrinsic::maxnum:
  case Intrinsic::minnum: {
    Value *Op0 = VPI.getOperand(0);
    Value *Op1 = VPI.getOperand(1);
    Function *Fn = Intrinsic::getDeclaration(
        VPI.getModule(), UnpredicatedIntrinsicID, {VPI.getType()});
    Value *NewOp = Builder.CreateCall(Fn, {Op0, Op1}, VPI.getName());
    replaceOperation(*NewOp, VPI);
    return NewOp;
  }
  case Intrinsic::fma:
  case Intrinsic::fmuladd:
  case Intrinsic::experimental_constrained_fma:
  case Intrinsic::experimental_constrained_fmuladd: {
    Value *Op0 = VPI.getOperand(0);
    Value *Op1 = VPI.getOperand(1);
    Value *Op2 = VPI.getOperand(2);
    Function *Fn = Intrinsic::getDeclaration(
        VPI.getModule(), UnpredicatedIntrinsicID, {VPI.getType()});
    Value *NewOp;
    if (Intrinsic::isConstrainedFPIntrinsic(UnpredicatedIntrinsicID))
      NewOp =
          Builder.CreateConstrainedFPCall(Fn, {Op0, Op1, Op2}, VPI.getName());
    else
      NewOp = Builder.CreateCall(Fn, {Op0, Op1, Op2}, VPI.getName());
    replaceOperation(*NewOp, VPI);
    return NewOp;
  }
  }

  return nullptr;
}

static Value *getNeutralReductionElement(const VPReductionIntrinsic &VPI,
                                         Type *EltTy) {
  bool Negative = false;
  unsigned EltBits = EltTy->getScalarSizeInBits();
  Intrinsic::ID VID = VPI.getIntrinsicID();
  switch (VID) {
  default:
    llvm_unreachable("Expecting a VP reduction intrinsic");
  case Intrinsic::vp_reduce_add:
  case Intrinsic::vp_reduce_or:
  case Intrinsic::vp_reduce_xor:
  case Intrinsic::vp_reduce_umax:
    return Constant::getNullValue(EltTy);
  case Intrinsic::vp_reduce_mul:
    return ConstantInt::get(EltTy, 1, /*IsSigned*/ false);
  case Intrinsic::vp_reduce_and:
  case Intrinsic::vp_reduce_umin:
    return ConstantInt::getAllOnesValue(EltTy);
  case Intrinsic::vp_reduce_smin:
    return ConstantInt::get(EltTy->getContext(),
                            APInt::getSignedMaxValue(EltBits));
  case Intrinsic::vp_reduce_smax:
    return ConstantInt::get(EltTy->getContext(),
                            APInt::getSignedMinValue(EltBits));
  case Intrinsic::vp_reduce_fmax:
  case Intrinsic::vp_reduce_fmaximum:
    Negative = true;
    [[fallthrough]];
  case Intrinsic::vp_reduce_fmin:
  case Intrinsic::vp_reduce_fminimum: {
    bool PropagatesNaN = VID == Intrinsic::vp_reduce_fminimum ||
                         VID == Intrinsic::vp_reduce_fmaximum;
    FastMathFlags Flags = VPI.getFastMathFlags();
    const fltSemantics &Semantics = EltTy->getFltSemantics();
    return (!Flags.noNaNs() && !PropagatesNaN)
               ? ConstantFP::getQNaN(EltTy, Negative)
           : !Flags.noInfs()
               ? ConstantFP::getInfinity(EltTy, Negative)
               : ConstantFP::get(EltTy,
                                 APFloat::getLargest(Semantics, Negative));
  }
  case Intrinsic::vp_reduce_fadd:
    return ConstantFP::getNegativeZero(EltTy);
  case Intrinsic::vp_reduce_fmul:
    return ConstantFP::get(EltTy, 1.0);
  }
}

Value *
CachingVPExpander::expandPredicationInReduction(IRBuilder<> &Builder,
                                                VPReductionIntrinsic &VPI) {
  assert((maySpeculateLanes(VPI) || VPI.canIgnoreVectorLengthParam()) &&
         "Implicitly dropping %evl in non-speculatable operator!");

  Value *Mask = VPI.getMaskParam();
  Value *RedOp = VPI.getOperand(VPI.getVectorParamPos());

  // Insert neutral element in masked-out positions.
  if (Mask && !isAllTrueMask(Mask)) {
    auto *NeutralElt = getNeutralReductionElement(VPI, VPI.getType());
    auto *NeutralVector = Builder.CreateVectorSplat(
        cast<VectorType>(RedOp->getType())->getElementCount(), NeutralElt);
    RedOp = Builder.CreateSelect(Mask, RedOp, NeutralVector);
  }

  Value *Reduction;
  Value *Start = VPI.getOperand(VPI.getStartParamPos());

  switch (VPI.getIntrinsicID()) {
  default:
    llvm_unreachable("Impossible reduction kind");
  case Intrinsic::vp_reduce_add:
    Reduction = Builder.CreateAddReduce(RedOp);
    Reduction = Builder.CreateAdd(Reduction, Start);
    break;
  case Intrinsic::vp_reduce_mul:
    Reduction = Builder.CreateMulReduce(RedOp);
    Reduction = Builder.CreateMul(Reduction, Start);
    break;
  case Intrinsic::vp_reduce_and:
    Reduction = Builder.CreateAndReduce(RedOp);
    Reduction = Builder.CreateAnd(Reduction, Start);
    break;
  case Intrinsic::vp_reduce_or:
    Reduction = Builder.CreateOrReduce(RedOp);
    Reduction = Builder.CreateOr(Reduction, Start);
    break;
  case Intrinsic::vp_reduce_xor:
    Reduction = Builder.CreateXorReduce(RedOp);
    Reduction = Builder.CreateXor(Reduction, Start);
    break;
  case Intrinsic::vp_reduce_smax:
    Reduction = Builder.CreateIntMaxReduce(RedOp, /*IsSigned*/ true);
    Reduction =
        Builder.CreateBinaryIntrinsic(Intrinsic::smax, Reduction, Start);
    break;
  case Intrinsic::vp_reduce_smin:
    Reduction = Builder.CreateIntMinReduce(RedOp, /*IsSigned*/ true);
    Reduction =
        Builder.CreateBinaryIntrinsic(Intrinsic::smin, Reduction, Start);
    break;
  case Intrinsic::vp_reduce_umax:
    Reduction = Builder.CreateIntMaxReduce(RedOp, /*IsSigned*/ false);
    Reduction =
        Builder.CreateBinaryIntrinsic(Intrinsic::umax, Reduction, Start);
    break;
  case Intrinsic::vp_reduce_umin:
    Reduction = Builder.CreateIntMinReduce(RedOp, /*IsSigned*/ false);
    Reduction =
        Builder.CreateBinaryIntrinsic(Intrinsic::umin, Reduction, Start);
    break;
  case Intrinsic::vp_reduce_fmax:
    Reduction = Builder.CreateFPMaxReduce(RedOp);
    transferDecorations(*Reduction, VPI);
    Reduction =
        Builder.CreateBinaryIntrinsic(Intrinsic::maxnum, Reduction, Start);
    break;
  case Intrinsic::vp_reduce_fmin:
    Reduction = Builder.CreateFPMinReduce(RedOp);
    transferDecorations(*Reduction, VPI);
    Reduction =
        Builder.CreateBinaryIntrinsic(Intrinsic::minnum, Reduction, Start);
    break;
  case Intrinsic::vp_reduce_fmaximum:
    Reduction = Builder.CreateFPMaximumReduce(RedOp);
    transferDecorations(*Reduction, VPI);
    Reduction =
        Builder.CreateBinaryIntrinsic(Intrinsic::maximum, Reduction, Start);
    break;
  case Intrinsic::vp_reduce_fminimum:
    Reduction = Builder.CreateFPMinimumReduce(RedOp);
    transferDecorations(*Reduction, VPI);
    Reduction =
        Builder.CreateBinaryIntrinsic(Intrinsic::minimum, Reduction, Start);
    break;
  case Intrinsic::vp_reduce_fadd:
    Reduction = Builder.CreateFAddReduce(Start, RedOp);
    break;
  case Intrinsic::vp_reduce_fmul:
    Reduction = Builder.CreateFMulReduce(Start, RedOp);
    break;
  }

  replaceOperation(*Reduction, VPI);
  return Reduction;
}
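
// Illustrative reduction expansion (a sketch): with an active mask, masked-off
// lanes are first replaced by the neutral element so they cannot change the
// result, e.g. for vp.reduce.add
//
//   %s = call i32 @llvm.vp.reduce.add.v4i32(i32 %start, <4 x i32> %v,
//                                           <4 x i1> %m, i32 %evl)
// becomes (once %evl is no longer effective)
//   %masked = select <4 x i1> %m, <4 x i32> %v, <4 x i32> zeroinitializer
//   %red    = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %masked)
//   %s      = add i32 %red, %start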

Value *CachingVPExpander::expandPredicationToCastIntrinsic(IRBuilder<> &Builder,
                                                           VPIntrinsic &VPI) {
  Value *CastOp = nullptr;
  switch (VPI.getIntrinsicID()) {
  default:
    llvm_unreachable("Not a VP cast intrinsic");
  case Intrinsic::vp_sext:
    CastOp =
        Builder.CreateSExt(VPI.getOperand(0), VPI.getType(), VPI.getName());
    break;
  case Intrinsic::vp_zext:
    CastOp =
        Builder.CreateZExt(VPI.getOperand(0), VPI.getType(), VPI.getName());
    break;
  case Intrinsic::vp_trunc:
    CastOp =
        Builder.CreateTrunc(VPI.getOperand(0), VPI.getType(), VPI.getName());
    break;
  case Intrinsic::vp_inttoptr:
    CastOp =
        Builder.CreateIntToPtr(VPI.getOperand(0), VPI.getType(), VPI.getName());
    break;
  case Intrinsic::vp_ptrtoint:
    CastOp =
        Builder.CreatePtrToInt(VPI.getOperand(0), VPI.getType(), VPI.getName());
    break;
  case Intrinsic::vp_fptosi:
    CastOp =
        Builder.CreateFPToSI(VPI.getOperand(0), VPI.getType(), VPI.getName());
    break;
  case Intrinsic::vp_fptoui:
    CastOp =
        Builder.CreateFPToUI(VPI.getOperand(0), VPI.getType(), VPI.getName());
    break;
  case Intrinsic::vp_sitofp:
    CastOp =
        Builder.CreateSIToFP(VPI.getOperand(0), VPI.getType(), VPI.getName());
    break;
  case Intrinsic::vp_uitofp:
    CastOp =
        Builder.CreateUIToFP(VPI.getOperand(0), VPI.getType(), VPI.getName());
    break;
  case Intrinsic::vp_fptrunc:
    CastOp =
        Builder.CreateFPTrunc(VPI.getOperand(0), VPI.getType(), VPI.getName());
    break;
  case Intrinsic::vp_fpext:
    CastOp =
        Builder.CreateFPExt(VPI.getOperand(0), VPI.getType(), VPI.getName());
    break;
  }
  replaceOperation(*CastOp, VPI);
  return CastOp;
}

Value *
CachingVPExpander::expandPredicationInMemoryIntrinsic(IRBuilder<> &Builder,
                                                      VPIntrinsic &VPI) {
  assert(VPI.canIgnoreVectorLengthParam());

  const auto &DL = F.getDataLayout();

  Value *MaskParam = VPI.getMaskParam();
  Value *PtrParam = VPI.getMemoryPointerParam();
  Value *DataParam = VPI.getMemoryDataParam();
  bool IsUnmasked = isAllTrueMask(MaskParam);

  MaybeAlign AlignOpt = VPI.getPointerAlignment();

  Value *NewMemoryInst = nullptr;
  switch (VPI.getIntrinsicID()) {
  default:
    llvm_unreachable("Not a VP memory intrinsic");
  case Intrinsic::vp_store:
    if (IsUnmasked) {
      StoreInst *NewStore =
          Builder.CreateStore(DataParam, PtrParam, /*IsVolatile*/ false);
      if (AlignOpt.has_value())
        NewStore->setAlignment(*AlignOpt);
      NewMemoryInst = NewStore;
    } else
      NewMemoryInst = Builder.CreateMaskedStore(
          DataParam, PtrParam, AlignOpt.valueOrOne(), MaskParam);

    break;
  case Intrinsic::vp_load:
    if (IsUnmasked) {
      LoadInst *NewLoad =
          Builder.CreateLoad(VPI.getType(), PtrParam, /*IsVolatile*/ false);
      if (AlignOpt.has_value())
        NewLoad->setAlignment(*AlignOpt);
      NewMemoryInst = NewLoad;
    } else
      NewMemoryInst = Builder.CreateMaskedLoad(
          VPI.getType(), PtrParam, AlignOpt.valueOrOne(), MaskParam);

    break;
  case Intrinsic::vp_scatter: {
    auto *ElementType =
        cast<VectorType>(DataParam->getType())->getElementType();
    NewMemoryInst = Builder.CreateMaskedScatter(
        DataParam, PtrParam,
        AlignOpt.value_or(DL.getPrefTypeAlign(ElementType)), MaskParam);
    break;
  }
  case Intrinsic::vp_gather: {
    auto *ElementType = cast<VectorType>(VPI.getType())->getElementType();
    NewMemoryInst = Builder.CreateMaskedGather(
        VPI.getType(), PtrParam,
        AlignOpt.value_or(DL.getPrefTypeAlign(ElementType)), MaskParam, nullptr,
        VPI.getName());
    break;
  }
  }

  assert(NewMemoryInst);
  replaceOperation(*NewMemoryInst, VPI);
  return NewMemoryInst;
}
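
// Illustrative memory expansion (a sketch; alignment shown as 4 and the
// passthru operand chosen for illustration only): a masked vp.load becomes a
// masked load intrinsic, an all-true-mask vp.load becomes a plain load, e.g.
//
//   %v = call <4 x i32> @llvm.vp.load.v4i32.p0(ptr %p, <4 x i1> %m, i32 %evl)
// becomes (once %evl is no longer effective)
//   %v = call <4 x i32> @llvm.masked.load.v4i32.p0(ptr %p, i32 4,
//                                                  <4 x i1> %m, <4 x i32> poison)
// or, when %m is known all-true, simply
//   %v = load <4 x i32>, ptr %p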

Value *CachingVPExpander::expandPredicationInComparison(IRBuilder<> &Builder,
                                                        VPCmpIntrinsic &VPI) {
  assert((maySpeculateLanes(VPI) || VPI.canIgnoreVectorLengthParam()) &&
         "Implicitly dropping %evl in non-speculatable operator!");

  assert(*VPI.getFunctionalOpcode() == Instruction::ICmp ||
         *VPI.getFunctionalOpcode() == Instruction::FCmp);

  Value *Op0 = VPI.getOperand(0);
  Value *Op1 = VPI.getOperand(1);
  auto Pred = VPI.getPredicate();

  auto *NewCmp = Builder.CreateCmp(Pred, Op0, Op1);

  replaceOperation(*NewCmp, VPI);
  return NewCmp;
}

void CachingVPExpander::discardEVLParameter(VPIntrinsic &VPI) {
  LLVM_DEBUG(dbgs() << "Discard EVL parameter in " << VPI << "\n");

  if (VPI.canIgnoreVectorLengthParam())
    return;

  Value *EVLParam = VPI.getVectorLengthParam();
  if (!EVLParam)
    return;

  ElementCount StaticElemCount = VPI.getStaticVectorLength();
  Value *MaxEVL = nullptr;
  Type *Int32Ty = Type::getInt32Ty(VPI.getContext());
  if (StaticElemCount.isScalable()) {
    // TODO add caching
    auto *M = VPI.getModule();
    Function *VScaleFunc =
        Intrinsic::getDeclaration(M, Intrinsic::vscale, Int32Ty);
    IRBuilder<> Builder(VPI.getParent(), VPI.getIterator());
    Value *FactorConst = Builder.getInt32(StaticElemCount.getKnownMinValue());
    Value *VScale = Builder.CreateCall(VScaleFunc, {}, "vscale");
    MaxEVL = Builder.CreateMul(VScale, FactorConst, "scalable_size",
                               /*NUW*/ true, /*NSW*/ false);
  } else {
    MaxEVL = ConstantInt::get(Int32Ty, StaticElemCount.getFixedValue(), false);
  }
  VPI.setVectorLengthParam(MaxEVL);
}
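
// Sketch of what discardEVLParameter() rewrites %evl to: for a fixed vector
// type it is the constant lane count, for a scalable type it is computed at
// runtime, e.g. for <vscale x 4 x i32>
//
//   %vscale        = call i32 @llvm.vscale.i32()
//   %scalable_size = mul nuw i32 %vscale, 4
//
// which makes the %evl trivially ignorable because it now covers all lanes.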

Value *CachingVPExpander::foldEVLIntoMask(VPIntrinsic &VPI) {
  LLVM_DEBUG(dbgs() << "Folding vlen for " << VPI << '\n');

  IRBuilder<> Builder(&VPI);

  // The %evl parameter is ineffective, so there is nothing to do here.
  if (VPI.canIgnoreVectorLengthParam())
    return &VPI;

  // Only VP intrinsics can have an %evl parameter.
  Value *OldMaskParam = VPI.getMaskParam();
  Value *OldEVLParam = VPI.getVectorLengthParam();
  assert(OldMaskParam && "no mask param to fold the vl param into");
  assert(OldEVLParam && "no EVL param to fold away");

  LLVM_DEBUG(dbgs() << "OLD evl: " << *OldEVLParam << '\n');
  LLVM_DEBUG(dbgs() << "OLD mask: " << *OldMaskParam << '\n');

  // Convert the %evl predication into vector mask predication.
  ElementCount ElemCount = VPI.getStaticVectorLength();
  Value *VLMask = convertEVLToMask(Builder, OldEVLParam, ElemCount);
  Value *NewMaskParam = Builder.CreateAnd(VLMask, OldMaskParam);
  VPI.setMaskParam(NewMaskParam);

  // Drop the %evl parameter.
  discardEVLParameter(VPI);
  assert(VPI.canIgnoreVectorLengthParam() &&
         "transformation did not render the evl param ineffective!");

  // Reassess the modified instruction.
  return &VPI;
}
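
// Illustrative fold (a sketch): when the strategy requests Convert for %evl,
// the %evl constraint is merged into the mask and the %evl parameter is then
// discarded, e.g.
//
//   %r = call <4 x i32> @llvm.vp.sdiv.v4i32(<4 x i32> %a, <4 x i32> %b,
//                                           <4 x i1> %m, i32 %evl)
// becomes
//   %evlmask = icmp ult <4 x i32> <i32 0, i32 1, i32 2, i32 3>, %evl.splat
//   %newmask = and <4 x i1> %evlmask, %m
//   %r = call <4 x i32> @llvm.vp.sdiv.v4i32(<4 x i32> %a, <4 x i32> %b,
//                                           <4 x i1> %newmask, i32 4)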

Value *CachingVPExpander::expandPredication(VPIntrinsic &VPI) {
  LLVM_DEBUG(dbgs() << "Lowering to unpredicated op: " << VPI << '\n');

  IRBuilder<> Builder(&VPI);

  // Try lowering to an LLVM instruction first.
  auto OC = VPI.getFunctionalOpcode();

  if (OC && Instruction::isBinaryOp(*OC))
    return expandPredicationInBinaryOperator(Builder, VPI);

  if (auto *VPRI = dyn_cast<VPReductionIntrinsic>(&VPI))
    return expandPredicationInReduction(Builder, *VPRI);

  if (auto *VPCmp = dyn_cast<VPCmpIntrinsic>(&VPI))
    return expandPredicationInComparison(Builder, *VPCmp);

  if (VPCastIntrinsic::isVPCast(VPI.getIntrinsicID())) {
    return expandPredicationToCastIntrinsic(Builder, VPI);
  }

  switch (VPI.getIntrinsicID()) {
  default:
    break;
  case Intrinsic::vp_fneg: {
    Value *NewNegOp = Builder.CreateFNeg(VPI.getOperand(0), VPI.getName());
    replaceOperation(*NewNegOp, VPI);
    return NewNegOp;
  }
  case Intrinsic::vp_abs:
  case Intrinsic::vp_smax:
  case Intrinsic::vp_smin:
  case Intrinsic::vp_umax:
  case Intrinsic::vp_umin:
  case Intrinsic::vp_bswap:
  case Intrinsic::vp_bitreverse:
    return expandPredicationToIntCall(Builder, VPI,
                                      VPI.getFunctionalIntrinsicID().value());
  case Intrinsic::vp_fabs:
  case Intrinsic::vp_sqrt:
  case Intrinsic::vp_maxnum:
  case Intrinsic::vp_minnum:
  case Intrinsic::vp_maximum:
  case Intrinsic::vp_minimum:
  case Intrinsic::vp_fma:
  case Intrinsic::vp_fmuladd:
    return expandPredicationToFPCall(Builder, VPI,
                                     VPI.getFunctionalIntrinsicID().value());
  case Intrinsic::vp_load:
  case Intrinsic::vp_store:
  case Intrinsic::vp_gather:
  case Intrinsic::vp_scatter:
    return expandPredicationInMemoryIntrinsic(Builder, VPI);
  }

  if (auto CID = VPI.getConstrainedIntrinsicID())
    if (Value *Call = expandPredicationToFPCall(Builder, VPI, *CID))
      return Call;

  return &VPI;
}

//// } CachingVPExpander

struct TransformJob {
  VPIntrinsic *PI;
  TargetTransformInfo::VPLegalization Strategy;
  TransformJob(VPIntrinsic *PI, TargetTransformInfo::VPLegalization InitStrat)
      : PI(PI), Strategy(InitStrat) {}

  bool isDone() const { return Strategy.shouldDoNothing(); }
};

void sanitizeStrategy(VPIntrinsic &VPI, VPLegalization &LegalizeStrat) {
  // Operations with speculatable lanes do not strictly need predication.
  if (maySpeculateLanes(VPI)) {
    // Converting a speculatable VP intrinsic means dropping %mask and %evl.
    // No need to expand %evl into the %mask only to ignore that code.
    if (LegalizeStrat.OpStrategy == VPLegalization::Convert)
      LegalizeStrat.EVLParamStrategy = VPLegalization::Discard;
    return;
  }

  // We have to preserve the predicating effect of %evl for this
  // non-speculatable VP intrinsic.
  // 1) Never discard %evl.
  // 2) If this VP intrinsic will be expanded to non-VP code, make sure that
  //    %evl gets folded into %mask.
  if ((LegalizeStrat.EVLParamStrategy == VPLegalization::Discard) ||
      (LegalizeStrat.OpStrategy == VPLegalization::Convert)) {
    LegalizeStrat.EVLParamStrategy = VPLegalization::Convert;
  }
}
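
// For example (illustrative): a speculatable operation such as vp.add that is
// about to be converted has its %evl merely discarded, while for a
// non-speculatable operation such as vp.sdiv a requested Discard of %evl (or a
// Convert of the operation) is upgraded to Convert, so the %evl ends up folded
// into the mask before the predicate is dropped.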

VPLegalization
CachingVPExpander::getVPLegalizationStrategy(const VPIntrinsic &VPI) const {
  auto VPStrat = TTI.getVPLegalizationStrategy(VPI);
  if (LLVM_LIKELY(!UsingTTIOverrides)) {
    // No overrides - we are in production.
    return VPStrat;
  }

  // Overrides set - we are in testing; the following does not need to be
  // efficient.
  VPStrat.EVLParamStrategy = parseOverrideOption(EVLTransformOverride);
  VPStrat.OpStrategy = parseOverrideOption(MaskTransformOverride);
  return VPStrat;
}

/// Expand llvm.vp.* intrinsics as requested by \p TTI.
bool CachingVPExpander::expandVectorPredication() {
  SmallVector<TransformJob, 16> Worklist;

  // Collect all VPIntrinsics that need expansion and determine their expansion
  // strategy.
  for (auto &I : instructions(F)) {
    auto *VPI = dyn_cast<VPIntrinsic>(&I);
    if (!VPI)
      continue;
    auto VPStrat = getVPLegalizationStrategy(*VPI);
    sanitizeStrategy(*VPI, VPStrat);
    if (!VPStrat.shouldDoNothing())
      Worklist.emplace_back(VPI, VPStrat);
  }
  if (Worklist.empty())
    return false;

  // Transform all VPIntrinsics on the worklist.
  LLVM_DEBUG(dbgs() << "\n:::: Transforming " << Worklist.size()
                    << " instructions ::::\n");
  for (TransformJob Job : Worklist) {
    // Transform the EVL parameter.
    switch (Job.Strategy.EVLParamStrategy) {
    case VPLegalization::Legal:
      break;
    case VPLegalization::Discard:
      discardEVLParameter(*Job.PI);
      break;
    case VPLegalization::Convert:
      if (foldEVLIntoMask(*Job.PI))
        ++NumFoldedVL;
      break;
    }
    Job.Strategy.EVLParamStrategy = VPLegalization::Legal;

    // Replace with a non-predicated operation.
    switch (Job.Strategy.OpStrategy) {
    case VPLegalization::Legal:
      break;
    case VPLegalization::Discard:
      llvm_unreachable("Invalid strategy for operators.");
    case VPLegalization::Convert:
      expandPredication(*Job.PI);
      ++NumLoweredVPOps;
      break;
    }
    Job.Strategy.OpStrategy = VPLegalization::Legal;

    assert(Job.isDone() && "incomplete transformation");
  }

  return true;
}

class ExpandVectorPredication : public FunctionPass {
public:
  static char ID;
  ExpandVectorPredication() : FunctionPass(ID) {
    initializeExpandVectorPredicationPass(*PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override {
    const auto *TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
    CachingVPExpander VPExpander(F, *TTI);
    return VPExpander.expandVectorPredication();
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetTransformInfoWrapperPass>();
    AU.setPreservesCFG();
  }
};
} // namespace

char ExpandVectorPredication::ID;
INITIALIZE_PASS_BEGIN(ExpandVectorPredication, "expandvp",
                      "Expand vector predication intrinsics", false, false)
INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_END(ExpandVectorPredication, "expandvp",
                    "Expand vector predication intrinsics", false, false)

FunctionPass *llvm::createExpandVectorPredicationPass() {
  return new ExpandVectorPredication();
}

PreservedAnalyses
ExpandVectorPredicationPass::run(Function &F, FunctionAnalysisManager &AM) {
  const auto &TTI = AM.getResult<TargetIRAnalysis>(F);
  CachingVPExpander VPExpander(F, TTI);
  if (!VPExpander.expandVectorPredication())
    return PreservedAnalyses::all();
  PreservedAnalyses PA;
  PA.preserveSet<CFGAnalyses>();
  return PA;
}