//===- RISCVGatherScatterLowering.cpp - Gather/Scatter lowering -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass custom lowers llvm.masked.gather and llvm.masked.scatter
// intrinsics to RISC-V strided load/store intrinsics.
//
//===----------------------------------------------------------------------===//

#include "RISCV.h"
#include "RISCVTargetMachine.h"
#include "llvm/Analysis/InstSimplifyFolder.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/IR/ConstantFold.h"
#include "llvm/IR/GetElementPtrTypeIterator.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/IntrinsicsRISCV.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/Transforms/Utils/Local.h"
#include <optional>

using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "riscv-gather-scatter-lowering"

namespace {

class RISCVGatherScatterLowering : public FunctionPass {
  const RISCVSubtarget *ST = nullptr;
  const RISCVTargetLowering *TLI = nullptr;
  LoopInfo *LI = nullptr;
  const DataLayout *DL = nullptr;

  SmallVector<WeakTrackingVH> MaybeDeadPHIs;

  // Cache of the BasePtr and Stride determined from this GEP. When a GEP is
  // used by multiple gathers/scatters, this allows us to reuse the scalar
  // instructions we created for the first gather/scatter for the others.
  DenseMap<GetElementPtrInst *, std::pair<Value *, Value *>> StridedAddrs;

public:
  static char ID; // Pass identification, replacement for typeid

  RISCVGatherScatterLowering() : FunctionPass(ID) {}

  bool runOnFunction(Function &F) override;

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesCFG();
    AU.addRequired<TargetPassConfig>();
    AU.addRequired<LoopInfoWrapperPass>();
  }

  StringRef getPassName() const override {
    return "RISC-V gather/scatter lowering";
  }

private:
  bool tryCreateStridedLoadStore(IntrinsicInst *II, Type *DataType, Value *Ptr,
                                 Value *AlignOp);

  std::pair<Value *, Value *> determineBaseAndStride(Instruction *Ptr,
                                                     IRBuilderBase &Builder);

  bool matchStridedRecurrence(Value *Index, Loop *L, Value *&Stride,
                              PHINode *&BasePtr, BinaryOperator *&Inc,
                              IRBuilderBase &Builder);
};

} // end anonymous namespace

char RISCVGatherScatterLowering::ID = 0;

INITIALIZE_PASS(RISCVGatherScatterLowering, DEBUG_TYPE,
                "RISC-V gather/scatter lowering pass", false, false)

FunctionPass *llvm::createRISCVGatherScatterLoweringPass() {
  return new RISCVGatherScatterLowering();
}

// TODO: Should we consider the mask when looking for a stride?
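//
// Try to decompose a constant vector into a (start, stride) pair. As an
// illustrative (hand-written) example,
//   <4 x i64> <i64 10, i64 13, i64 16, i64 19>
// decomposes as start = 10, stride = 3, while a vector whose element deltas
// are not all equal returns {nullptr, nullptr}.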
static std::pair<Value *, Value *> matchStridedConstant(Constant *StartC) {
  if (!isa<FixedVectorType>(StartC->getType()))
    return std::make_pair(nullptr, nullptr);

  unsigned NumElts = cast<FixedVectorType>(StartC->getType())->getNumElements();

  // Check that the start value is a strided constant.
  auto *StartVal =
      dyn_cast_or_null<ConstantInt>(StartC->getAggregateElement((unsigned)0));
  if (!StartVal)
    return std::make_pair(nullptr, nullptr);
  APInt StrideVal(StartVal->getValue().getBitWidth(), 0);
  ConstantInt *Prev = StartVal;
  for (unsigned i = 1; i != NumElts; ++i) {
    auto *C = dyn_cast_or_null<ConstantInt>(StartC->getAggregateElement(i));
    if (!C)
      return std::make_pair(nullptr, nullptr);

    APInt LocalStride = C->getValue() - Prev->getValue();
    if (i == 1)
      StrideVal = LocalStride;
    else if (StrideVal != LocalStride)
      return std::make_pair(nullptr, nullptr);

    Prev = C;
  }

  Value *Stride = ConstantInt::get(StartVal->getType(), StrideVal);

  return std::make_pair(StartVal, Stride);
}

static std::pair<Value *, Value *> matchStridedStart(Value *Start,
                                                     IRBuilderBase &Builder) {
  // Base case, start is a strided constant.
  auto *StartC = dyn_cast<Constant>(Start);
  if (StartC)
    return matchStridedConstant(StartC);

  // Base case, start is a stepvector.
  if (match(Start, m_Intrinsic<Intrinsic::experimental_stepvector>())) {
    auto *Ty = Start->getType()->getScalarType();
    return std::make_pair(ConstantInt::get(Ty, 0), ConstantInt::get(Ty, 1));
  }

  // Not a constant, maybe it's a strided constant with a splat added or
  // multiplied.
  auto *BO = dyn_cast<BinaryOperator>(Start);
  if (!BO || (BO->getOpcode() != Instruction::Add &&
              BO->getOpcode() != Instruction::Or &&
              BO->getOpcode() != Instruction::Shl &&
              BO->getOpcode() != Instruction::Mul))
    return std::make_pair(nullptr, nullptr);

  if (BO->getOpcode() == Instruction::Or &&
      !cast<PossiblyDisjointInst>(BO)->isDisjoint())
    return std::make_pair(nullptr, nullptr);

  // Look for an operand that is splatted.
  unsigned OtherIndex = 0;
  Value *Splat = getSplatValue(BO->getOperand(1));
  if (!Splat && Instruction::isCommutative(BO->getOpcode())) {
    Splat = getSplatValue(BO->getOperand(0));
    OtherIndex = 1;
  }
  if (!Splat)
    return std::make_pair(nullptr, nullptr);

  Value *Stride;
  std::tie(Start, Stride) =
      matchStridedStart(BO->getOperand(OtherIndex), Builder);
  if (!Start)
    return std::make_pair(nullptr, nullptr);

  Builder.SetInsertPoint(BO);
  Builder.SetCurrentDebugLocation(DebugLoc());
  // Add the splat value to the start or multiply the start and stride by the
  // splat.
  switch (BO->getOpcode()) {
  default:
    llvm_unreachable("Unexpected opcode");
  case Instruction::Or:
    // TODO: We'd be better off creating a disjoint or here, but we don't yet
    // have an IRBuilder API for that.
    [[fallthrough]];
  case Instruction::Add:
    Start = Builder.CreateAdd(Start, Splat);
    break;
  case Instruction::Mul:
    Start = Builder.CreateMul(Start, Splat);
    Stride = Builder.CreateMul(Stride, Splat);
    break;
  case Instruction::Shl:
    Start = Builder.CreateShl(Start, Splat);
    Stride = Builder.CreateShl(Stride, Splat);
    break;
  }

  return std::make_pair(Start, Stride);
}

// Recursively walk up the use-def chain until we find a Phi with a strided
// start value. Build and update a scalar recurrence as we unwind the
// recursion. We also update the Stride as we unwind. Our goal is to move all
// of the arithmetic out of the loop.
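//
// For example (illustrative, hand-written IR), given a vector induction
//   %vec.ind  = phi <4 x i64> [ <i64 0, i64 1, i64 2, i64 3>, %ph ],
//                             [ %vec.next, %loop ]
//   %vec.next = add <4 x i64> %vec.ind, <i64 4, i64 4, i64 4, i64 4>
// we build the scalar recurrence
//   %vec.ind.scalar  = phi i64 [ 0, %ph ], [ %vec.next.scalar, %loop ]
//   %vec.next.scalar = add i64 %vec.ind.scalar, 4
// and return Stride = 1, since adjacent lanes differ by one element.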
bool RISCVGatherScatterLowering::matchStridedRecurrence(Value *Index, Loop *L,
                                                        Value *&Stride,
                                                        PHINode *&BasePtr,
                                                        BinaryOperator *&Inc,
                                                        IRBuilderBase &Builder) {
  // Our base case is a Phi.
  if (auto *Phi = dyn_cast<PHINode>(Index)) {
    // A phi node we want to perform this function on should be from the
    // loop header.
    if (Phi->getParent() != L->getHeader())
      return false;

    Value *Step, *Start;
    if (!matchSimpleRecurrence(Phi, Inc, Start, Step) ||
        Inc->getOpcode() != Instruction::Add)
      return false;
    assert(Phi->getNumIncomingValues() == 2 && "Expected 2 operand phi.");
    unsigned IncrementingBlock = Phi->getIncomingValue(0) == Inc ? 0 : 1;
    assert(Phi->getIncomingValue(IncrementingBlock) == Inc &&
           "Expected one operand of phi to be Inc");

    // Only proceed if the step is loop invariant.
    if (!L->isLoopInvariant(Step))
      return false;

    // Step should be a splat.
    Step = getSplatValue(Step);
    if (!Step)
      return false;

    std::tie(Start, Stride) = matchStridedStart(Start, Builder);
    if (!Start)
      return false;
    assert(Stride != nullptr);

    // Build scalar phi and increment.
    BasePtr = PHINode::Create(Start->getType(), 2, Phi->getName() + ".scalar",
                              Phi->getIterator());
    Inc = BinaryOperator::CreateAdd(BasePtr, Step, Inc->getName() + ".scalar",
                                    Inc->getIterator());
    BasePtr->addIncoming(Start, Phi->getIncomingBlock(1 - IncrementingBlock));
    BasePtr->addIncoming(Inc, Phi->getIncomingBlock(IncrementingBlock));

    // Note that this Phi might be eligible for removal.
    MaybeDeadPHIs.push_back(Phi);
    return true;
  }

  // Otherwise look for a binary operator.
  auto *BO = dyn_cast<BinaryOperator>(Index);
  if (!BO)
    return false;

  switch (BO->getOpcode()) {
  default:
    return false;
  case Instruction::Or:
    // We need to be able to treat Or as Add.
    if (!cast<PossiblyDisjointInst>(BO)->isDisjoint())
      return false;
    break;
  case Instruction::Add:
  case Instruction::Shl:
  case Instruction::Mul:
    break;
  }

  // We should have one operand in the loop and one splat.
  Value *OtherOp;
  if (isa<Instruction>(BO->getOperand(0)) &&
      L->contains(cast<Instruction>(BO->getOperand(0)))) {
    Index = cast<Instruction>(BO->getOperand(0));
    OtherOp = BO->getOperand(1);
  } else if (isa<Instruction>(BO->getOperand(1)) &&
             L->contains(cast<Instruction>(BO->getOperand(1))) &&
             Instruction::isCommutative(BO->getOpcode())) {
    Index = cast<Instruction>(BO->getOperand(1));
    OtherOp = BO->getOperand(0);
  } else {
    return false;
  }

  // Make sure the other operand is loop invariant.
  if (!L->isLoopInvariant(OtherOp))
    return false;

  // Make sure we have a splat.
  Value *SplatOp = getSplatValue(OtherOp);
  if (!SplatOp)
    return false;

  // Recurse up the use-def chain.
  if (!matchStridedRecurrence(Index, L, Stride, BasePtr, Inc, Builder))
    return false;

  // Locate the Step and Start values from the recurrence.
  unsigned StepIndex = Inc->getOperand(0) == BasePtr ? 1 : 0;
  unsigned StartBlock = BasePtr->getOperand(0) == Inc ? 1 : 0;
  Value *Step = Inc->getOperand(StepIndex);
  Value *Start = BasePtr->getOperand(StartBlock);

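  // Illustrative example (hand-written): if we recursed through
  //   %idx = mul <4 x i64> %vec.ind, <i64 8, i64 8, i64 8, i64 8>
  // then SplatOp is 8, and the switch below scales Start, Step and Stride by
  // 8 in the preheader so the loop body stays free of the multiply.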
  // We need to adjust the start value in the preheader.
  Builder.SetInsertPoint(
      BasePtr->getIncomingBlock(StartBlock)->getTerminator());
  Builder.SetCurrentDebugLocation(DebugLoc());

  switch (BO->getOpcode()) {
  default:
    llvm_unreachable("Unexpected opcode!");
  case Instruction::Add:
  case Instruction::Or: {
    // An add only affects the start value. It's ok to do this for Or because
    // we already checked that there are no common set bits.
    Start = Builder.CreateAdd(Start, SplatOp, "start");
    break;
  }
  case Instruction::Mul: {
    Start = Builder.CreateMul(Start, SplatOp, "start");
    Step = Builder.CreateMul(Step, SplatOp, "step");
    Stride = Builder.CreateMul(Stride, SplatOp, "stride");
    break;
  }
  case Instruction::Shl: {
    Start = Builder.CreateShl(Start, SplatOp, "start");
    Step = Builder.CreateShl(Step, SplatOp, "step");
    Stride = Builder.CreateShl(Stride, SplatOp, "stride");
    break;
  }
  }

  Inc->setOperand(StepIndex, Step);
  BasePtr->setIncomingValue(StartBlock, Start);
  return true;
}

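// Illustrative decomposition (hand-written, not from the test suite): for
//   %ptrs = getelementptr i32, ptr %base,
//                         <4 x i64> <i64 0, i64 2, i64 4, i64 6>
// the function below returns BasePtr = %base and Stride = 8, i.e. the
// constant index stride of 2 scaled by sizeof(i32).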
std::pair<Value *, Value *>
RISCVGatherScatterLowering::determineBaseAndStride(Instruction *Ptr,
                                                   IRBuilderBase &Builder) {

  // A gather/scatter of a splat is a zero strided load/store.
  if (auto *BasePtr = getSplatValue(Ptr)) {
    Type *IntPtrTy = DL->getIntPtrType(BasePtr->getType());
    return std::make_pair(BasePtr, ConstantInt::get(IntPtrTy, 0));
  }

  auto *GEP = dyn_cast<GetElementPtrInst>(Ptr);
  if (!GEP)
    return std::make_pair(nullptr, nullptr);

  auto I = StridedAddrs.find(GEP);
  if (I != StridedAddrs.end())
    return I->second;

  SmallVector<Value *, 2> Ops(GEP->operands());

  // If the base pointer is a vector, check if it's strided.
  Value *Base = GEP->getPointerOperand();
  if (auto *BaseInst = dyn_cast<Instruction>(Base);
      BaseInst && BaseInst->getType()->isVectorTy()) {
    // If the GEP's offset is scalar then we can add it to the base pointer's
    // base.
    auto IsScalar = [](Value *Idx) { return !Idx->getType()->isVectorTy(); };
    if (all_of(GEP->indices(), IsScalar)) {
      auto [BaseBase, Stride] = determineBaseAndStride(BaseInst, Builder);
      if (BaseBase) {
        Builder.SetInsertPoint(GEP);
        SmallVector<Value *> Indices(GEP->indices());
        Value *OffsetBase =
            Builder.CreateGEP(GEP->getSourceElementType(), BaseBase, Indices,
                              GEP->getName() + "offset", GEP->isInBounds());
        return {OffsetBase, Stride};
      }
    }
  }

  // The base pointer needs to be a scalar.
  Value *ScalarBase = Base;
  if (ScalarBase->getType()->isVectorTy()) {
    ScalarBase = getSplatValue(ScalarBase);
    if (!ScalarBase)
      return std::make_pair(nullptr, nullptr);
  }

  std::optional<unsigned> VecOperand;
  unsigned TypeScale = 0;

  // Look for a vector operand and scale.
  gep_type_iterator GTI = gep_type_begin(GEP);
  for (unsigned i = 1, e = GEP->getNumOperands(); i != e; ++i, ++GTI) {
    if (!Ops[i]->getType()->isVectorTy())
      continue;

    if (VecOperand)
      return std::make_pair(nullptr, nullptr);

    VecOperand = i;

    TypeSize TS = GTI.getSequentialElementStride(*DL);
    if (TS.isScalable())
      return std::make_pair(nullptr, nullptr);

    TypeScale = TS.getFixedValue();
  }

  // We need to find a vector index to simplify.
  if (!VecOperand)
    return std::make_pair(nullptr, nullptr);

  // We can't extract the stride if the arithmetic is done at a different size
  // than the pointer type. Adding the stride later may not wrap correctly.
  // Technically we could handle wider indices, but I don't expect that in
  // practice. Handle one special case here - constants. This simplifies
  // writing test cases.
  Value *VecIndex = Ops[*VecOperand];
  Type *VecIntPtrTy = DL->getIntPtrType(GEP->getType());
  if (VecIndex->getType() != VecIntPtrTy) {
    auto *VecIndexC = dyn_cast<Constant>(VecIndex);
    if (!VecIndexC)
      return std::make_pair(nullptr, nullptr);
    if (VecIndex->getType()->getScalarSizeInBits() >
        VecIntPtrTy->getScalarSizeInBits())
      VecIndex = ConstantFoldCastInstruction(Instruction::Trunc, VecIndexC,
                                             VecIntPtrTy);
    else
      VecIndex = ConstantFoldCastInstruction(Instruction::SExt, VecIndexC,
                                             VecIntPtrTy);
  }

  // Handle the non-recursive case. This is what we see if the vectorizer
  // decides to use a scalar IV + vid on demand instead of a vector IV.
  auto [Start, Stride] = matchStridedStart(VecIndex, Builder);
  if (Start) {
    assert(Stride);
    Builder.SetInsertPoint(GEP);

    // Replace the vector index with the scalar start and build a scalar GEP.
    Ops[*VecOperand] = Start;
    Type *SourceTy = GEP->getSourceElementType();
    Value *BasePtr =
        Builder.CreateGEP(SourceTy, ScalarBase, ArrayRef(Ops).drop_front());

    // Convert stride to pointer size if needed.
    Type *IntPtrTy = DL->getIntPtrType(BasePtr->getType());
    assert(Stride->getType() == IntPtrTy && "Unexpected type");

    // Scale the stride by the size of the indexed type.
    if (TypeScale != 1)
      Stride =
          Builder.CreateMul(Stride, ConstantInt::get(IntPtrTy, TypeScale));

    auto P = std::make_pair(BasePtr, Stride);
    StridedAddrs[GEP] = P;
    return P;
  }

  // Make sure we're in a loop and that it has a preheader and a single latch.
  Loop *L = LI->getLoopFor(GEP->getParent());
  if (!L || !L->getLoopPreheader() || !L->getLoopLatch())
    return std::make_pair(nullptr, nullptr);

  BinaryOperator *Inc;
  PHINode *BasePhi;
  if (!matchStridedRecurrence(VecIndex, L, Stride, BasePhi, Inc, Builder))
    return std::make_pair(nullptr, nullptr);

  assert(BasePhi->getNumIncomingValues() == 2 && "Expected 2 operand phi.");
  unsigned IncrementingBlock = BasePhi->getOperand(0) == Inc ? 0 : 1;
  assert(BasePhi->getIncomingValue(IncrementingBlock) == Inc &&
         "Expected one operand of phi to be Inc");

  Builder.SetInsertPoint(GEP);

  // Replace the vector index with the scalar phi and build a scalar GEP.
  Ops[*VecOperand] = BasePhi;
  Type *SourceTy = GEP->getSourceElementType();
  Value *BasePtr =
      Builder.CreateGEP(SourceTy, ScalarBase, ArrayRef(Ops).drop_front());

  // Final adjustments to stride should go in the start block.
  Builder.SetInsertPoint(
      BasePhi->getIncomingBlock(1 - IncrementingBlock)->getTerminator());

  // Convert stride to pointer size if needed.
  Type *IntPtrTy = DL->getIntPtrType(BasePtr->getType());
  assert(Stride->getType() == IntPtrTy && "Unexpected type");

  // Scale the stride by the size of the indexed type.
  if (TypeScale != 1)
    Stride = Builder.CreateMul(Stride, ConstantInt::get(IntPtrTy, TypeScale));

  auto P = std::make_pair(BasePtr, Stride);
  StridedAddrs[GEP] = P;
  return P;
}

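// Illustrative rewrite (hand-written; operand order matches the intrinsic
// calls built below). Once the pointer decomposes into a base and a stride,
//   %v = call <4 x i32> @llvm.masked.gather...(<4 x ptr> %ptrs, i32 4,
//                                              <4 x i1> %mask, <4 x i32> %pt)
// becomes
//   %v = call <4 x i32> @llvm.riscv.masked.strided.load...(<4 x i32> %pt,
//                                                          ptr %base,
//                                                          i64 %stride,
//                                                          <4 x i1> %mask)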
bool RISCVGatherScatterLowering::tryCreateStridedLoadStore(IntrinsicInst *II,
                                                           Type *DataType,
                                                           Value *Ptr,
                                                           Value *AlignOp) {
  // Make sure the operation will be supported by the backend.
  MaybeAlign MA = cast<ConstantInt>(AlignOp)->getMaybeAlignValue();
  EVT DataTypeVT = TLI->getValueType(*DL, DataType);
  if (!MA || !TLI->isLegalStridedLoadStore(DataTypeVT, *MA))
    return false;

  // FIXME: Let the backend type legalize by splitting/widening?
  if (!TLI->isTypeLegal(DataTypeVT))
    return false;

  // Pointer should be an instruction.
  auto *PtrI = dyn_cast<Instruction>(Ptr);
  if (!PtrI)
    return false;

  LLVMContext &Ctx = PtrI->getContext();
  IRBuilder<InstSimplifyFolder> Builder(Ctx, *DL);
  Builder.SetInsertPoint(PtrI);

  Value *BasePtr, *Stride;
  std::tie(BasePtr, Stride) = determineBaseAndStride(PtrI, Builder);
  if (!BasePtr)
    return false;
  assert(Stride != nullptr);

  Builder.SetInsertPoint(II);

  CallInst *Call;
  if (II->getIntrinsicID() == Intrinsic::masked_gather)
    Call = Builder.CreateIntrinsic(
        Intrinsic::riscv_masked_strided_load,
        {DataType, BasePtr->getType(), Stride->getType()},
        {II->getArgOperand(3), BasePtr, Stride, II->getArgOperand(2)});
  else
    Call = Builder.CreateIntrinsic(
        Intrinsic::riscv_masked_strided_store,
        {DataType, BasePtr->getType(), Stride->getType()},
        {II->getArgOperand(0), BasePtr, Stride, II->getArgOperand(3)});

  Call->takeName(II);
  II->replaceAllUsesWith(Call);
  II->eraseFromParent();

  if (PtrI->use_empty())
    RecursivelyDeleteTriviallyDeadInstructions(PtrI);

  return true;
}

bool RISCVGatherScatterLowering::runOnFunction(Function &F) {
  if (skipFunction(F))
    return false;

  auto &TPC = getAnalysis<TargetPassConfig>();
  auto &TM = TPC.getTM<RISCVTargetMachine>();
  ST = &TM.getSubtarget<RISCVSubtarget>(F);
  if (!ST->hasVInstructions() || !ST->useRVVForFixedLengthVectors())
    return false;

  TLI = ST->getTargetLowering();
  DL = &F.getDataLayout();
  LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();

  StridedAddrs.clear();

  SmallVector<IntrinsicInst *, 4> Gathers;
  SmallVector<IntrinsicInst *, 4> Scatters;

  bool Changed = false;

  for (BasicBlock &BB : F) {
    for (Instruction &I : BB) {
      IntrinsicInst *II = dyn_cast<IntrinsicInst>(&I);
      if (II && II->getIntrinsicID() == Intrinsic::masked_gather) {
        Gathers.push_back(II);
      } else if (II && II->getIntrinsicID() == Intrinsic::masked_scatter) {
        Scatters.push_back(II);
      }
    }
  }

  // Rewrite gathers/scatters to strided loads/stores if possible.
  for (auto *II : Gathers)
    Changed |= tryCreateStridedLoadStore(
        II, II->getType(), II->getArgOperand(0), II->getArgOperand(1));
  for (auto *II : Scatters)
    Changed |=
        tryCreateStridedLoadStore(II, II->getArgOperand(0)->getType(),
                                  II->getArgOperand(1), II->getArgOperand(2));

  // Remove any dead phis.
  while (!MaybeDeadPHIs.empty()) {
    if (auto *Phi = dyn_cast_or_null<PHINode>(MaybeDeadPHIs.pop_back_val()))
      RecursivelyDeleteDeadPHINode(Phi);
  }

  return Changed;
}