//===----- CodeGen/ExpandVectorPredication.cpp - Expand VP intrinsics -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass implements IR expansion for vector predication intrinsics, allowing
// targets to enable vector predication until just before codegen.
//
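// As a rough example (a sketch with invented value names, not a description
// of guaranteed output): with a legalization strategy of 'Convert' for both
// the %evl and the %mask parameter, a call such as
//
//   %r = call <4 x i32> @llvm.vp.add.v4i32(<4 x i32> %a, <4 x i32> %b,
//                                          <4 x i1> %m, i32 %evl)
//
// ends up as the unpredicated instruction
//
//   %r = add <4 x i32> %a, %b
//
// since an add can safely ignore its predicate, while division-like
// operations additionally get a safe divisor blended into their masked-off
// lanes (see expandPredicationInBinaryOperator below).
//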
//===----------------------------------------------------------------------===//

#include "llvm/CodeGen/ExpandVectorPredication.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/Module.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Compiler.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/MathExtras.h"

using namespace llvm;

using VPLegalization = TargetTransformInfo::VPLegalization;
using VPTransform = TargetTransformInfo::VPLegalization::VPTransform;

// Keep this in sync with TargetTransformInfo::VPLegalization.
#define VPINTERNAL_VPLEGAL_CASES                                               \
  VPINTERNAL_CASE(Legal)                                                       \
  VPINTERNAL_CASE(Discard)                                                     \
  VPINTERNAL_CASE(Convert)

#define VPINTERNAL_CASE(X) "|" #X

// Override options.
static cl::opt<std::string> EVLTransformOverride(
    "expandvp-override-evl-transform", cl::init(""), cl::Hidden,
    cl::desc("Options: <empty>" VPINTERNAL_VPLEGAL_CASES
             ". If non-empty, ignore TargetTransformInfo and always use this "
             "transformation for the %evl parameter (used in testing)."));

static cl::opt<std::string> MaskTransformOverride(
    "expandvp-override-mask-transform", cl::init(""), cl::Hidden,
    cl::desc("Options: <empty>" VPINTERNAL_VPLEGAL_CASES
             ". If non-empty, ignore TargetTransformInfo and always use this "
             "transformation for the %mask parameter (used in testing)."));

#undef VPINTERNAL_CASE
#define VPINTERNAL_CASE(X) .Case(#X, VPLegalization::X)

static VPTransform parseOverrideOption(const std::string &TextOpt) {
  return StringSwitch<VPTransform>(TextOpt) VPINTERNAL_VPLEGAL_CASES;
}

#undef VPINTERNAL_VPLEGAL_CASES

// Whether any override options are set.
static bool anyExpandVPOverridesSet() {
  return !EVLTransformOverride.empty() || !MaskTransformOverride.empty();
}

#define DEBUG_TYPE "expandvp"

STATISTIC(NumFoldedVL, "Number of folded vector length params");
STATISTIC(NumLoweredVPOps,
          "Number of vector predication operations lowered into unpredicated "
          "operations");

//// Helpers {

/// \returns Whether the vector mask \p MaskVal has all lane bits set.
static bool isAllTrueMask(Value *MaskVal) {
  auto *ConstVec = dyn_cast<ConstantVector>(MaskVal);
  return ConstVec && ConstVec->isAllOnesValue();
}

/// \returns A non-excepting divisor constant for the type \p DivTy.
static Constant *getSafeDivisor(Type *DivTy) {
  assert(DivTy->isIntOrIntVectorTy() && "Unsupported divisor type");
  return ConstantInt::get(DivTy, 1u, false);
}

/// Transfer operation properties from \p VPI to \p NewVal.
static void transferDecorations(Value &NewVal, VPIntrinsic &VPI) {
  auto *NewInst = dyn_cast<Instruction>(&NewVal);
  if (!NewInst || !isa<FPMathOperator>(NewVal))
    return;

  auto *OldFMOp = dyn_cast<FPMathOperator>(&VPI);
  if (!OldFMOp)
    return;

  NewInst->setFastMathFlags(OldFMOp->getFastMathFlags());
}

/// Transfer all properties from \p OldOp to \p NewOp and replace all uses.
/// \p OldOp gets erased.
static void replaceOperation(Value &NewOp, VPIntrinsic &OldOp) {
  transferDecorations(NewOp, OldOp);
  OldOp.replaceAllUsesWith(&NewOp);
  OldOp.eraseFromParent();
}

//// } Helpers

namespace {

// Expansion pass state at function scope.
struct CachingVPExpander {
  Function &F;
  const TargetTransformInfo &TTI;

  /// \returns A (fixed length) vector with ascending integer indices
  /// (<0, 1, ..., NumElems-1>).
  /// \p Builder
  ///    Used for instruction creation.
  /// \p LaneTy
  ///    Integer element type of the result vector.
  /// \p NumElems
  ///    Number of vector elements.
  Value *createStepVector(IRBuilder<> &Builder, Type *LaneTy,
                          unsigned NumElems);

  /// \returns A bitmask that is true where the lane position is less-than
  /// \p EVLParam.
  ///
  /// \p Builder
  ///    Used for instruction creation.
  /// \p EVLParam
  ///    The explicit vector length parameter to test against the lane
  ///    positions.
  /// \p ElemCount
  ///    Static (potentially scalable) number of vector elements.
  Value *convertEVLToMask(IRBuilder<> &Builder, Value *EVLParam,
                          ElementCount ElemCount);

  Value *foldEVLIntoMask(VPIntrinsic &VPI);

  /// "Remove" the %evl parameter of \p PI by setting it to the static vector
  /// length of the operation.
  void discardEVLParameter(VPIntrinsic &PI);

  /// Lower this VP binary operator to an unpredicated binary operator.
  Value *expandPredicationInBinaryOperator(IRBuilder<> &Builder,
                                           VPIntrinsic &PI);

  /// Query TTI and expand the vector predication in \p PI accordingly.
  Value *expandPredication(VPIntrinsic &PI);

  /// Determine how and whether the VPIntrinsic \p VPI shall be expanded. This
  /// overrides TTI with the cl::opts listed at the top of this file.
  VPLegalization getVPLegalizationStrategy(const VPIntrinsic &VPI) const;
  bool UsingTTIOverrides;

public:
  CachingVPExpander(Function &F, const TargetTransformInfo &TTI)
      : F(F), TTI(TTI), UsingTTIOverrides(anyExpandVPOverridesSet()) {}

  bool expandVectorPredication();
};

//// CachingVPExpander {

Value *CachingVPExpander::createStepVector(IRBuilder<> &Builder, Type *LaneTy,
                                           unsigned NumElems) {
  // TODO add caching
  SmallVector<Constant *, 16> ConstElems;

  for (unsigned Idx = 0; Idx < NumElems; ++Idx)
    ConstElems.push_back(ConstantInt::get(LaneTy, Idx, false));

  return ConstantVector::get(ConstElems);
}
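
// For illustration only (a sketch with invented value names), the mask that
// convertEVLToMask builds for an i32 %evl looks roughly like:
//
//   ; fixed-width vectors: compare a step vector against a splat of %evl
//   %evl.splat = ... splat of %evl into <4 x i32> ...
//   %evl.mask  = icmp ult <4 x i32> <i32 0, i32 1, i32 2, i32 3>, %evl.splat
//
//   ; scalable vectors: defer to the target via get.active.lane.mask
//   %evl.mask  = call <vscale x 4 x i1>
//                @llvm.get.active.lane.mask.nxv4i1.i32(i32 0, i32 %evl)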
Value *CachingVPExpander::convertEVLToMask(IRBuilder<> &Builder,
                                           Value *EVLParam,
                                           ElementCount ElemCount) {
  // TODO add caching
  // Scalable vector %evl conversion.
  if (ElemCount.isScalable()) {
    auto *M = Builder.GetInsertBlock()->getModule();
    Type *BoolVecTy = VectorType::get(Builder.getInt1Ty(), ElemCount);
    Function *ActiveMaskFunc = Intrinsic::getDeclaration(
        M, Intrinsic::get_active_lane_mask, {BoolVecTy, EVLParam->getType()});
    // `get_active_lane_mask` performs an implicit less-than comparison.
    Value *ConstZero = Builder.getInt32(0);
    return Builder.CreateCall(ActiveMaskFunc, {ConstZero, EVLParam});
  }

  // Fixed vector %evl conversion.
  Type *LaneTy = EVLParam->getType();
  unsigned NumElems = ElemCount.getFixedValue();
  Value *VLSplat = Builder.CreateVectorSplat(NumElems, EVLParam);
  Value *IdxVec = createStepVector(Builder, LaneTy, NumElems);
  return Builder.CreateICmp(CmpInst::ICMP_ULT, IdxVec, VLSplat);
}

Value *
CachingVPExpander::expandPredicationInBinaryOperator(IRBuilder<> &Builder,
                                                     VPIntrinsic &VPI) {
  assert((isSafeToSpeculativelyExecute(&VPI) ||
          VPI.canIgnoreVectorLengthParam()) &&
         "Implicitly dropping %evl in non-speculatable operator!");

  auto OC = static_cast<Instruction::BinaryOps>(*VPI.getFunctionalOpcode());
  assert(Instruction::isBinaryOp(OC));

  Value *Op0 = VPI.getOperand(0);
  Value *Op1 = VPI.getOperand(1);
  Value *Mask = VPI.getMaskParam();

  // Blend in safe operands.
  if (Mask && !isAllTrueMask(Mask)) {
    switch (OC) {
    default:
      // Can safely ignore the predicate.
      break;

    // Division operators need a safe divisor on masked-off lanes (1).
    case Instruction::UDiv:
    case Instruction::SDiv:
    case Instruction::URem:
    case Instruction::SRem:
      // 2nd operand must not be zero.
      Value *SafeDivisor = getSafeDivisor(VPI.getType());
      Op1 = Builder.CreateSelect(Mask, Op1, SafeDivisor);
    }
  }

  Value *NewBinOp = Builder.CreateBinOp(OC, Op0, Op1, VPI.getName());

  replaceOperation(*NewBinOp, VPI);
  return NewBinOp;
}

void CachingVPExpander::discardEVLParameter(VPIntrinsic &VPI) {
  LLVM_DEBUG(dbgs() << "Discard EVL parameter in " << VPI << "\n");

  if (VPI.canIgnoreVectorLengthParam())
    return;

  Value *EVLParam = VPI.getVectorLengthParam();
  if (!EVLParam)
    return;

  ElementCount StaticElemCount = VPI.getStaticVectorLength();
  Value *MaxEVL = nullptr;
  Type *Int32Ty = Type::getInt32Ty(VPI.getContext());
  if (StaticElemCount.isScalable()) {
    // TODO add caching
    auto *M = VPI.getModule();
    Function *VScaleFunc =
        Intrinsic::getDeclaration(M, Intrinsic::vscale, Int32Ty);
    IRBuilder<> Builder(VPI.getParent(), VPI.getIterator());
    Value *FactorConst = Builder.getInt32(StaticElemCount.getKnownMinValue());
    Value *VScale = Builder.CreateCall(VScaleFunc, {}, "vscale");
    MaxEVL = Builder.CreateMul(VScale, FactorConst, "scalable_size",
                               /*NUW*/ true, /*NSW*/ false);
  } else {
    MaxEVL = ConstantInt::get(Int32Ty, StaticElemCount.getFixedValue(), false);
  }
  VPI.setVectorLengthParam(MaxEVL);
}
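
// Putting the pieces together, a hedged sketch (value names invented) of the
// Convert/Convert path for a non-speculatable operation:
//
//   %q = call <4 x i32> @llvm.vp.sdiv.v4i32(<4 x i32> %a, <4 x i32> %b,
//                                           <4 x i1> %m, i32 %evl)
//
// first folds %evl into the mask and then blends a safe divisor into the
// masked-off lanes before emitting a plain division:
//
//   ; %evl.splat is a splat of %evl, %m is the original mask
//   %evl.mask = icmp ult <4 x i32> <i32 0, i32 1, i32 2, i32 3>, %evl.splat
//   %mask     = and <4 x i1> %evl.mask, %m
//   %safe.b   = select <4 x i1> %mask, <4 x i32> %b,
//                      <4 x i32> <i32 1, i32 1, i32 1, i32 1>
//   %q        = sdiv <4 x i32> %a, %safe.b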
Value *CachingVPExpander::foldEVLIntoMask(VPIntrinsic &VPI) {
  LLVM_DEBUG(dbgs() << "Folding vlen for " << VPI << '\n');

  IRBuilder<> Builder(&VPI);

  // The %evl parameter is ineffective, so there is nothing to do here.
  if (VPI.canIgnoreVectorLengthParam())
    return &VPI;

  // Only VP intrinsics can have an %evl parameter.
  Value *OldMaskParam = VPI.getMaskParam();
  Value *OldEVLParam = VPI.getVectorLengthParam();
  assert(OldMaskParam && "no mask param to fold the vl param into");
  assert(OldEVLParam && "no EVL param to fold away");

  LLVM_DEBUG(dbgs() << "OLD evl: " << *OldEVLParam << '\n');
  LLVM_DEBUG(dbgs() << "OLD mask: " << *OldMaskParam << '\n');

  // Convert the %evl predication into vector mask predication.
  ElementCount ElemCount = VPI.getStaticVectorLength();
  Value *VLMask = convertEVLToMask(Builder, OldEVLParam, ElemCount);
  Value *NewMaskParam = Builder.CreateAnd(VLMask, OldMaskParam);
  VPI.setMaskParam(NewMaskParam);

  // Drop the %evl parameter.
  discardEVLParameter(VPI);
  assert(VPI.canIgnoreVectorLengthParam() &&
         "transformation did not render the evl param ineffective!");

  // Reassess the modified instruction.
  return &VPI;
}

Value *CachingVPExpander::expandPredication(VPIntrinsic &VPI) {
  LLVM_DEBUG(dbgs() << "Lowering to unpredicated op: " << VPI << '\n');

  IRBuilder<> Builder(&VPI);

  // Try lowering to an LLVM instruction first.
  auto OC = VPI.getFunctionalOpcode();

  if (OC && Instruction::isBinaryOp(*OC))
    return expandPredicationInBinaryOperator(Builder, VPI);

  return &VPI;
}

//// } CachingVPExpander

struct TransformJob {
  VPIntrinsic *PI;
  TargetTransformInfo::VPLegalization Strategy;
  TransformJob(VPIntrinsic *PI, TargetTransformInfo::VPLegalization InitStrat)
      : PI(PI), Strategy(InitStrat) {}

  bool isDone() const { return Strategy.shouldDoNothing(); }
};

void sanitizeStrategy(Instruction &I, VPLegalization &LegalizeStrat) {
  // Speculatable instructions do not strictly need predication.
  if (isSafeToSpeculativelyExecute(&I)) {
    // Converting a speculatable VP intrinsic means dropping %mask and %evl.
    // No need to expand %evl into the %mask only to ignore that code.
    if (LegalizeStrat.OpStrategy == VPLegalization::Convert)
      LegalizeStrat.EVLParamStrategy = VPLegalization::Discard;
    return;
  }

  // We have to preserve the predicating effect of %evl for this
  // non-speculatable VP intrinsic.
  // 1) Never discard %evl.
  // 2) If this VP intrinsic will be expanded to non-VP code, make sure that
  //    %evl gets folded into %mask.
  if ((LegalizeStrat.EVLParamStrategy == VPLegalization::Discard) ||
      (LegalizeStrat.OpStrategy == VPLegalization::Convert)) {
    LegalizeStrat.EVLParamStrategy = VPLegalization::Convert;
  }
}

VPLegalization
CachingVPExpander::getVPLegalizationStrategy(const VPIntrinsic &VPI) const {
  auto VPStrat = TTI.getVPLegalizationStrategy(VPI);
  if (LLVM_LIKELY(!UsingTTIOverrides)) {
    // No overrides - we are in production.
    return VPStrat;
  }

  // Overrides set - we are in testing; the following does not need to be
  // efficient.
  VPStrat.EVLParamStrategy = parseOverrideOption(EVLTransformOverride);
  VPStrat.OpStrategy = parseOverrideOption(MaskTransformOverride);
  return VPStrat;
}
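
// A sketch of how the strategies interact (not an exhaustive specification):
// if TTI requests {EVLParam: Legal, Op: Convert} for an operation that is not
// safe to speculate (e.g. a division that may trap), sanitizeStrategy
// upgrades the %evl strategy to Convert so the predicating effect of %evl is
// folded into the mask before the operation is lowered. For operations that
// isSafeToSpeculativelyExecute accepts, the %evl strategy is downgraded to
// Discard instead, since both the mask and %evl may simply be dropped.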
/// Expand llvm.vp.* intrinsics as requested by \p TTI.
bool CachingVPExpander::expandVectorPredication() {
  SmallVector<TransformJob, 16> Worklist;

  // Collect all VPIntrinsics that need expansion and determine their expansion
  // strategy.
  for (auto &I : instructions(F)) {
    auto *VPI = dyn_cast<VPIntrinsic>(&I);
    if (!VPI)
      continue;
    auto VPStrat = getVPLegalizationStrategy(*VPI);
    sanitizeStrategy(I, VPStrat);
    if (!VPStrat.shouldDoNothing())
      Worklist.emplace_back(VPI, VPStrat);
  }
  if (Worklist.empty())
    return false;

  // Transform all VPIntrinsics on the worklist.
  LLVM_DEBUG(dbgs() << "\n:::: Transforming " << Worklist.size()
                    << " instructions ::::\n");
  for (TransformJob Job : Worklist) {
    // Transform the EVL parameter.
    switch (Job.Strategy.EVLParamStrategy) {
    case VPLegalization::Legal:
      break;
    case VPLegalization::Discard:
      discardEVLParameter(*Job.PI);
      break;
    case VPLegalization::Convert:
      if (foldEVLIntoMask(*Job.PI))
        ++NumFoldedVL;
      break;
    }
    Job.Strategy.EVLParamStrategy = VPLegalization::Legal;

    // Replace with a non-predicated operation.
    switch (Job.Strategy.OpStrategy) {
    case VPLegalization::Legal:
      break;
    case VPLegalization::Discard:
      llvm_unreachable("Invalid strategy for operators.");
    case VPLegalization::Convert:
      expandPredication(*Job.PI);
      ++NumLoweredVPOps;
      break;
    }
    Job.Strategy.OpStrategy = VPLegalization::Legal;

    assert(Job.isDone() && "incomplete transformation");
  }

  return true;
}

class ExpandVectorPredication : public FunctionPass {
public:
  static char ID;
  ExpandVectorPredication() : FunctionPass(ID) {
    initializeExpandVectorPredicationPass(*PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override {
    const auto *TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
    CachingVPExpander VPExpander(F, *TTI);
    return VPExpander.expandVectorPredication();
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetTransformInfoWrapperPass>();
    AU.setPreservesCFG();
  }
};
} // namespace

char ExpandVectorPredication::ID;
INITIALIZE_PASS_BEGIN(ExpandVectorPredication, "expandvp",
                      "Expand vector predication intrinsics", false, false)
INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_END(ExpandVectorPredication, "expandvp",
                    "Expand vector predication intrinsics", false, false)

FunctionPass *llvm::createExpandVectorPredicationPass() {
  return new ExpandVectorPredication();
}

PreservedAnalyses
ExpandVectorPredicationPass::run(Function &F, FunctionAnalysisManager &AM) {
  const auto &TTI = AM.getResult<TargetIRAnalysis>(F);
  CachingVPExpander VPExpander(F, TTI);
  if (!VPExpander.expandVectorPredication())
    return PreservedAnalyses::all();
  PreservedAnalyses PA;
  PA.preserveSet<CFGAnalyses>();
  return PA;
}