//===- RISCVGatherScatterLowering.cpp - Gather/Scatter lowering -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass custom lowers llvm.masked.gather and llvm.masked.scatter
// instructions to RISCV intrinsics.
//
//===----------------------------------------------------------------------===//

#include "RISCV.h"
#include "RISCVTargetMachine.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/IR/GetElementPtrTypeIterator.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/IntrinsicsRISCV.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/Transforms/Utils/Local.h"
#include <optional>

using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "riscv-gather-scatter-lowering"

namespace {

class RISCVGatherScatterLowering : public FunctionPass {
  const RISCVSubtarget *ST = nullptr;
  const RISCVTargetLowering *TLI = nullptr;
  LoopInfo *LI = nullptr;
  const DataLayout *DL = nullptr;

  SmallVector<WeakTrackingVH> MaybeDeadPHIs;

  // Cache of the BasePtr and Stride determined from this GEP. When a GEP is
  // used by multiple gathers/scatters, this allows us to reuse the scalar
  // instructions we created for the first gather/scatter for the others.
  DenseMap<GetElementPtrInst *, std::pair<Value *, Value *>> StridedAddrs;

public:
  static char ID; // Pass identification, replacement for typeid

  RISCVGatherScatterLowering() : FunctionPass(ID) {}

  bool runOnFunction(Function &F) override;

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesCFG();
    AU.addRequired<TargetPassConfig>();
    AU.addRequired<LoopInfoWrapperPass>();
  }

  StringRef getPassName() const override {
    return "RISCV gather/scatter lowering";
  }

private:
  bool isLegalTypeAndAlignment(Type *DataType, Value *AlignOp);

  bool tryCreateStridedLoadStore(IntrinsicInst *II, Type *DataType, Value *Ptr,
                                 Value *AlignOp);

  std::pair<Value *, Value *> determineBaseAndStride(GetElementPtrInst *GEP,
                                                     IRBuilder<> &Builder);

  bool matchStridedRecurrence(Value *Index, Loop *L, Value *&Stride,
                              PHINode *&BasePtr, BinaryOperator *&Inc,
                              IRBuilder<> &Builder);
};

} // end anonymous namespace

char RISCVGatherScatterLowering::ID = 0;

INITIALIZE_PASS(RISCVGatherScatterLowering, DEBUG_TYPE,
                "RISCV gather/scatter lowering pass", false, false)

FunctionPass *llvm::createRISCVGatherScatterLoweringPass() {
  return new RISCVGatherScatterLowering();
}

bool RISCVGatherScatterLowering::isLegalTypeAndAlignment(Type *DataType,
                                                         Value *AlignOp) {
  Type *ScalarType = DataType->getScalarType();
  if (!TLI->isLegalElementTypeForRVV(ScalarType))
    return false;

  MaybeAlign MA = cast<ConstantInt>(AlignOp)->getMaybeAlignValue();
  if (MA && MA->value() < DL->getTypeStoreSize(ScalarType).getFixedValue())
    return false;

  // FIXME: Let the backend type legalize by splitting/widening?
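  // For now, reject any vector type the backend would have to split or widen.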
  EVT DataVT = TLI->getValueType(*DL, DataType);
  if (!TLI->isTypeLegal(DataVT))
    return false;

  return true;
}

// TODO: Should we consider the mask when looking for a stride?
static std::pair<Value *, Value *> matchStridedConstant(Constant *StartC) {
  unsigned NumElts = cast<FixedVectorType>(StartC->getType())->getNumElements();

  // Check that the start value is a strided constant.
  auto *StartVal =
      dyn_cast_or_null<ConstantInt>(StartC->getAggregateElement((unsigned)0));
  if (!StartVal)
    return std::make_pair(nullptr, nullptr);
  APInt StrideVal(StartVal->getValue().getBitWidth(), 0);
  ConstantInt *Prev = StartVal;
  for (unsigned i = 1; i != NumElts; ++i) {
    auto *C = dyn_cast_or_null<ConstantInt>(StartC->getAggregateElement(i));
    if (!C)
      return std::make_pair(nullptr, nullptr);

    APInt LocalStride = C->getValue() - Prev->getValue();
    if (i == 1)
      StrideVal = LocalStride;
    else if (StrideVal != LocalStride)
      return std::make_pair(nullptr, nullptr);

    Prev = C;
  }

  Value *Stride = ConstantInt::get(StartVal->getType(), StrideVal);

  return std::make_pair(StartVal, Stride);
}

static std::pair<Value *, Value *> matchStridedStart(Value *Start,
                                                     IRBuilder<> &Builder) {
  // Base case, start is a strided constant.
  auto *StartC = dyn_cast<Constant>(Start);
  if (StartC)
    return matchStridedConstant(StartC);

  // Base case, start is a stepvector.
  if (match(Start, m_Intrinsic<Intrinsic::experimental_stepvector>())) {
    auto *Ty = Start->getType()->getScalarType();
    return std::make_pair(ConstantInt::get(Ty, 0), ConstantInt::get(Ty, 1));
  }

  // Not a constant, maybe it's a strided constant with a splat added to it.
  auto *BO = dyn_cast<BinaryOperator>(Start);
  if (!BO || BO->getOpcode() != Instruction::Add)
    return std::make_pair(nullptr, nullptr);

  // Look for an operand that is splatted.
  unsigned OtherIndex = 1;
  Value *Splat = getSplatValue(BO->getOperand(0));
  if (!Splat) {
    Splat = getSplatValue(BO->getOperand(1));
    OtherIndex = 0;
  }
  if (!Splat)
    return std::make_pair(nullptr, nullptr);

  Value *Stride;
  std::tie(Start, Stride) = matchStridedStart(BO->getOperand(OtherIndex),
                                              Builder);
  if (!Start)
    return std::make_pair(nullptr, nullptr);

  // Add the splat value to the start.
  Builder.SetInsertPoint(BO);
  Builder.SetCurrentDebugLocation(DebugLoc());
  Start = Builder.CreateAdd(Start, Splat);
  return std::make_pair(Start, Stride);
}

// Recursively walk back the use-def chain until we find a Phi with a strided
// start value. Build and update a scalar recurrence as we unwind the
// recursion. We also update the Stride as we unwind. Our goal is to move all
// of the arithmetic out of the loop.
bool RISCVGatherScatterLowering::matchStridedRecurrence(Value *Index, Loop *L,
                                                        Value *&Stride,
                                                        PHINode *&BasePtr,
                                                        BinaryOperator *&Inc,
                                                        IRBuilder<> &Builder) {
  // Our base case is a Phi.
  if (auto *Phi = dyn_cast<PHINode>(Index)) {
    // A phi node we want to perform this function on should be from the
    // loop header.
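    // Illustrative shape of the base case (hypothetical IR):
    //   %vec.ind = phi <4 x i64> [ <i64 0, i64 1, i64 2, i64 3>, %ph ],
    //                            [ %vec.ind.next, %loop ]
    //   %vec.ind.next = add <4 x i64> %vec.ind, <i64 4, i64 4, i64 4, i64 4>
    // is rebuilt as a scalar phi starting at 0 with a scalar add of 4 per
    // iteration, and a Stride of 1.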
    if (Phi->getParent() != L->getHeader())
      return false;

    Value *Step, *Start;
    if (!matchSimpleRecurrence(Phi, Inc, Start, Step) ||
        Inc->getOpcode() != Instruction::Add)
      return false;
    assert(Phi->getNumIncomingValues() == 2 && "Expected 2 operand phi.");
    unsigned IncrementingBlock = Phi->getIncomingValue(0) == Inc ? 0 : 1;
    assert(Phi->getIncomingValue(IncrementingBlock) == Inc &&
           "Expected one operand of phi to be Inc");

    // Only proceed if the step is loop invariant.
    if (!L->isLoopInvariant(Step))
      return false;

    // Step should be a splat.
    Step = getSplatValue(Step);
    if (!Step)
      return false;

    std::tie(Start, Stride) = matchStridedStart(Start, Builder);
    if (!Start)
      return false;
    assert(Stride != nullptr);

    // Build scalar phi and increment.
    BasePtr =
        PHINode::Create(Start->getType(), 2, Phi->getName() + ".scalar", Phi);
    Inc = BinaryOperator::CreateAdd(BasePtr, Step, Inc->getName() + ".scalar",
                                    Inc);
    BasePtr->addIncoming(Start, Phi->getIncomingBlock(1 - IncrementingBlock));
    BasePtr->addIncoming(Inc, Phi->getIncomingBlock(IncrementingBlock));

    // Note that this Phi might be eligible for removal.
    MaybeDeadPHIs.push_back(Phi);
    return true;
  }

  // Otherwise look for binary operator.
  auto *BO = dyn_cast<BinaryOperator>(Index);
  if (!BO)
    return false;

  if (BO->getOpcode() != Instruction::Add &&
      BO->getOpcode() != Instruction::Or &&
      BO->getOpcode() != Instruction::Mul &&
      BO->getOpcode() != Instruction::Shl)
    return false;

  // Only support shift by constant.
  if (BO->getOpcode() == Instruction::Shl && !isa<Constant>(BO->getOperand(1)))
    return false;

  // We need to be able to treat Or as Add.
  if (BO->getOpcode() == Instruction::Or &&
      !haveNoCommonBitsSet(BO->getOperand(0), BO->getOperand(1), *DL))
    return false;

  // We should have one operand in the loop and one splat.
  Value *OtherOp;
  if (isa<Instruction>(BO->getOperand(0)) &&
      L->contains(cast<Instruction>(BO->getOperand(0)))) {
    Index = cast<Instruction>(BO->getOperand(0));
    OtherOp = BO->getOperand(1);
  } else if (isa<Instruction>(BO->getOperand(1)) &&
             L->contains(cast<Instruction>(BO->getOperand(1)))) {
    Index = cast<Instruction>(BO->getOperand(1));
    OtherOp = BO->getOperand(0);
  } else {
    return false;
  }

  // Make sure other op is loop invariant.
  if (!L->isLoopInvariant(OtherOp))
    return false;

  // Make sure we have a splat.
  Value *SplatOp = getSplatValue(OtherOp);
  if (!SplatOp)
    return false;

  // Recurse up the use-def chain.
  if (!matchStridedRecurrence(Index, L, Stride, BasePtr, Inc, Builder))
    return false;

  // Locate the Step and Start values from the recurrence.
  unsigned StepIndex = Inc->getOperand(0) == BasePtr ? 1 : 0;
  unsigned StartBlock = BasePtr->getOperand(0) == Inc ? 1 : 0;
  Value *Step = Inc->getOperand(StepIndex);
  Value *Start = BasePtr->getOperand(StartBlock);

  // We need to adjust the start value in the preheader.
  Builder.SetInsertPoint(
      BasePtr->getIncomingBlock(StartBlock)->getTerminator());
  Builder.SetCurrentDebugLocation(DebugLoc());

  switch (BO->getOpcode()) {
  default:
    llvm_unreachable("Unexpected opcode!");
  case Instruction::Add:
  case Instruction::Or: {
    // An add only affects the start value. It's ok to do this for Or because
    // we already checked that there are no common set bits.
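    // E.g. (illustrative) folding `add %index, <splat of 7>` into a
    // recurrence whose start is 2 yields a start of 9; Step and Stride are
    // untouched.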

    // If the start value is Zero, just take the SplatOp.
    if (isa<ConstantInt>(Start) && cast<ConstantInt>(Start)->isZero())
      Start = SplatOp;
    else
      Start = Builder.CreateAdd(Start, SplatOp, "start");
    BasePtr->setIncomingValue(StartBlock, Start);
    break;
  }
  case Instruction::Mul: {
    // If the start is zero we don't need to multiply.
    if (!isa<ConstantInt>(Start) || !cast<ConstantInt>(Start)->isZero())
      Start = Builder.CreateMul(Start, SplatOp, "start");

    Step = Builder.CreateMul(Step, SplatOp, "step");

    // If the Stride is 1, just take the SplatOp.
    if (isa<ConstantInt>(Stride) && cast<ConstantInt>(Stride)->isOne())
      Stride = SplatOp;
    else
      Stride = Builder.CreateMul(Stride, SplatOp, "stride");
    Inc->setOperand(StepIndex, Step);
    BasePtr->setIncomingValue(StartBlock, Start);
    break;
  }
  case Instruction::Shl: {
    // If the start is zero we don't need to shift.
    if (!isa<ConstantInt>(Start) || !cast<ConstantInt>(Start)->isZero())
      Start = Builder.CreateShl(Start, SplatOp, "start");
    Step = Builder.CreateShl(Step, SplatOp, "step");
    Stride = Builder.CreateShl(Stride, SplatOp, "stride");
    Inc->setOperand(StepIndex, Step);
    BasePtr->setIncomingValue(StartBlock, Start);
    break;
  }
  }

  return true;
}

std::pair<Value *, Value *>
RISCVGatherScatterLowering::determineBaseAndStride(GetElementPtrInst *GEP,
                                                   IRBuilder<> &Builder) {

  auto I = StridedAddrs.find(GEP);
  if (I != StridedAddrs.end())
    return I->second;

  SmallVector<Value *, 2> Ops(GEP->operands());

  // Base pointer needs to be a scalar.
  if (Ops[0]->getType()->isVectorTy())
    return std::make_pair(nullptr, nullptr);

  std::optional<unsigned> VecOperand;
  unsigned TypeScale = 0;

  // Look for a vector operand and scale.
  gep_type_iterator GTI = gep_type_begin(GEP);
  for (unsigned i = 1, e = GEP->getNumOperands(); i != e; ++i, ++GTI) {
    if (!Ops[i]->getType()->isVectorTy())
      continue;

    if (VecOperand)
      return std::make_pair(nullptr, nullptr);

    VecOperand = i;

    TypeSize TS = DL->getTypeAllocSize(GTI.getIndexedType());
    if (TS.isScalable())
      return std::make_pair(nullptr, nullptr);

    TypeScale = TS.getFixedValue();
  }

  // We need to find a vector index to simplify.
  if (!VecOperand)
    return std::make_pair(nullptr, nullptr);

  // We can't extract the stride if the arithmetic is done at a different size
  // than the pointer type. Adding the stride later may not wrap correctly.
  // Technically we could handle wider indices, but I don't expect that in
  // practice.
  Value *VecIndex = Ops[*VecOperand];
  Type *VecIntPtrTy = DL->getIntPtrType(GEP->getType());
  if (VecIndex->getType() != VecIntPtrTy)
    return std::make_pair(nullptr, nullptr);

  // Handle the non-recursive case. This is what we see if the vectorizer
  // decides to use a scalar IV + vid on demand instead of a vector IV.
  auto [Start, Stride] = matchStridedStart(VecIndex, Builder);
  if (Start) {
    assert(Stride);
    Builder.SetInsertPoint(GEP);

    // Replace the vector index with the scalar start and build a scalar GEP.
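    // E.g. (illustrative) `getelementptr i32, ptr %p, <4 x i64> <i64 0, i64 2,
    // i64 4, i64 6>` becomes `getelementptr i32, ptr %p, i64 0` with a stride
    // of 2, scaled below by TypeScale (4 bytes for i32) to a byte stride of 8.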
    Ops[*VecOperand] = Start;
    Type *SourceTy = GEP->getSourceElementType();
    Value *BasePtr =
        Builder.CreateGEP(SourceTy, Ops[0], ArrayRef(Ops).drop_front());

    // Convert stride to pointer size if needed.
    Type *IntPtrTy = DL->getIntPtrType(BasePtr->getType());
    assert(Stride->getType() == IntPtrTy && "Unexpected type");

    // Scale the stride by the size of the indexed type.
    if (TypeScale != 1)
      Stride = Builder.CreateMul(Stride, ConstantInt::get(IntPtrTy, TypeScale));

    auto P = std::make_pair(BasePtr, Stride);
    StridedAddrs[GEP] = P;
    return P;
  }

  // Make sure we're in a loop and that it has a preheader and a single latch.
  Loop *L = LI->getLoopFor(GEP->getParent());
  if (!L || !L->getLoopPreheader() || !L->getLoopLatch())
    return std::make_pair(nullptr, nullptr);

  BinaryOperator *Inc;
  PHINode *BasePhi;
  if (!matchStridedRecurrence(VecIndex, L, Stride, BasePhi, Inc, Builder))
    return std::make_pair(nullptr, nullptr);

  assert(BasePhi->getNumIncomingValues() == 2 && "Expected 2 operand phi.");
  unsigned IncrementingBlock = BasePhi->getOperand(0) == Inc ? 0 : 1;
  assert(BasePhi->getIncomingValue(IncrementingBlock) == Inc &&
         "Expected one operand of phi to be Inc");

  Builder.SetInsertPoint(GEP);

  // Replace the vector index with the scalar phi and build a scalar GEP.
  Ops[*VecOperand] = BasePhi;
  Type *SourceTy = GEP->getSourceElementType();
  Value *BasePtr =
      Builder.CreateGEP(SourceTy, Ops[0], ArrayRef(Ops).drop_front());

  // Final adjustments to stride should go in the start block.
  Builder.SetInsertPoint(
      BasePhi->getIncomingBlock(1 - IncrementingBlock)->getTerminator());

  // Convert stride to pointer size if needed.
  Type *IntPtrTy = DL->getIntPtrType(BasePtr->getType());
  assert(Stride->getType() == IntPtrTy && "Unexpected type");

  // Scale the stride by the size of the indexed type.
  if (TypeScale != 1)
    Stride = Builder.CreateMul(Stride, ConstantInt::get(IntPtrTy, TypeScale));

  auto P = std::make_pair(BasePtr, Stride);
  StridedAddrs[GEP] = P;
  return P;
}

bool RISCVGatherScatterLowering::tryCreateStridedLoadStore(IntrinsicInst *II,
                                                           Type *DataType,
                                                           Value *Ptr,
                                                           Value *AlignOp) {
  // Make sure the operation will be supported by the backend.
  if (!isLegalTypeAndAlignment(DataType, AlignOp))
    return false;

  // Pointer should be a GEP.
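  // e.g. (hypothetical IR):
  //   %ptrs = getelementptr i32, ptr %base, <4 x i64> %vec.ind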
  auto *GEP = dyn_cast<GetElementPtrInst>(Ptr);
  if (!GEP)
    return false;

  IRBuilder<> Builder(GEP);

  Value *BasePtr, *Stride;
  std::tie(BasePtr, Stride) = determineBaseAndStride(GEP, Builder);
  if (!BasePtr)
    return false;
  assert(Stride != nullptr);

  Builder.SetInsertPoint(II);

  CallInst *Call;
  if (II->getIntrinsicID() == Intrinsic::masked_gather)
    Call = Builder.CreateIntrinsic(
        Intrinsic::riscv_masked_strided_load,
        {DataType, BasePtr->getType(), Stride->getType()},
        {II->getArgOperand(3), BasePtr, Stride, II->getArgOperand(2)});
  else
    Call = Builder.CreateIntrinsic(
        Intrinsic::riscv_masked_strided_store,
        {DataType, BasePtr->getType(), Stride->getType()},
        {II->getArgOperand(0), BasePtr, Stride, II->getArgOperand(3)});

  Call->takeName(II);
  II->replaceAllUsesWith(Call);
  II->eraseFromParent();

  if (GEP->use_empty())
    RecursivelyDeleteTriviallyDeadInstructions(GEP);

  return true;
}

bool RISCVGatherScatterLowering::runOnFunction(Function &F) {
  if (skipFunction(F))
    return false;

  auto &TPC = getAnalysis<TargetPassConfig>();
  auto &TM = TPC.getTM<RISCVTargetMachine>();
  ST = &TM.getSubtarget<RISCVSubtarget>(F);
  if (!ST->hasVInstructions() || !ST->useRVVForFixedLengthVectors())
    return false;

  TLI = ST->getTargetLowering();
  DL = &F.getParent()->getDataLayout();
  LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();

  StridedAddrs.clear();

  SmallVector<IntrinsicInst *, 4> Gathers;
  SmallVector<IntrinsicInst *, 4> Scatters;

  bool Changed = false;

  for (BasicBlock &BB : F) {
    for (Instruction &I : BB) {
      IntrinsicInst *II = dyn_cast<IntrinsicInst>(&I);
      if (II && II->getIntrinsicID() == Intrinsic::masked_gather) {
        Gathers.push_back(II);
      } else if (II && II->getIntrinsicID() == Intrinsic::masked_scatter) {
        Scatters.push_back(II);
      }
    }
  }

  // Rewrite gather/scatter to form strided load/store if possible.
  for (auto *II : Gathers)
    Changed |= tryCreateStridedLoadStore(
        II, II->getType(), II->getArgOperand(0), II->getArgOperand(1));
  for (auto *II : Scatters)
    Changed |=
        tryCreateStridedLoadStore(II, II->getArgOperand(0)->getType(),
                                  II->getArgOperand(1), II->getArgOperand(2));

  // Remove any dead phis.
  while (!MaybeDeadPHIs.empty()) {
    if (auto *Phi = dyn_cast_or_null<PHINode>(MaybeDeadPHIs.pop_back_val()))
      RecursivelyDeleteDeadPHINode(Phi);
  }

  return Changed;
}