//===- RISCVGatherScatterLowering.cpp - Gather/Scatter lowering -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass custom lowers llvm.gather and llvm.scatter instructions to
// RISCV intrinsics.
//
//===----------------------------------------------------------------------===//

#include "RISCV.h"
#include "RISCVTargetMachine.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/IR/GetElementPtrTypeIterator.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/IntrinsicsRISCV.h"
#include "llvm/Transforms/Utils/Local.h"

using namespace llvm;

#define DEBUG_TYPE "riscv-gather-scatter-lowering"

namespace {

class RISCVGatherScatterLowering : public FunctionPass {
  const RISCVSubtarget *ST = nullptr;       // Subtarget of current function.
  const RISCVTargetLowering *TLI = nullptr; // For type legality queries.
  LoopInfo *LI = nullptr;                   // Required loop analysis.
  const DataLayout *DL = nullptr;           // DataLayout of current module.

  // Vector index phis that were replaced by scalar recurrences and may now be
  // dead. They are deleted (if actually dead) at the end of runOnFunction.
  SmallVector<WeakTrackingVH> MaybeDeadPHIs;

  // Cache of the BasePtr and Stride determined from this GEP. When a GEP is
  // used by multiple gathers/scatters, this allow us to reuse the scalar
  // instructions we created for the first gather/scatter for the others.
  DenseMap<GetElementPtrInst *, std::pair<Value *, Value *>> StridedAddrs;

public:
  static char ID; // Pass identification, replacement for typeid

  RISCVGatherScatterLowering() : FunctionPass(ID) {}

  bool runOnFunction(Function &F) override;

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesCFG();
    AU.addRequired<TargetPassConfig>();
    AU.addRequired<LoopInfoWrapperPass>();
  }

  StringRef getPassName() const override {
    return "RISCV gather/scatter lowering";
  }

private:
  // Returns true if a gather/scatter of DataType with the alignment given by
  // the constant AlignOp can be supported by the backend.
  bool isLegalTypeAndAlignment(Type *DataType, Value *AlignOp);

  // Try to rewrite the masked gather/scatter II (loading/storing DataType
  // through the vector-of-pointers Ptr) into a RISCV strided load/store
  // intrinsic. Returns true if the IR changed.
  bool tryCreateStridedLoadStore(IntrinsicInst *II, Type *DataType, Value *Ptr,
                                 Value *AlignOp);

  // Decompose GEP's vector address computation into a scalar base pointer and
  // a scalar byte stride. Returns {nullptr, nullptr} on failure.
  std::pair<Value *, Value *> determineBaseAndStride(GetElementPtrInst *GEP,
                                                     IRBuilder<> &Builder);

  // Match Index against a strided recurrence in loop L, building an
  // equivalent scalar recurrence (BasePtr phi + Inc add) and extracting the
  // per-iteration Stride as we go.
  bool matchStridedRecurrence(Value *Index, Loop *L, Value *&Stride,
                              PHINode *&BasePtr, BinaryOperator *&Inc,
                              IRBuilder<> &Builder);
};

} // end anonymous namespace

char RISCVGatherScatterLowering::ID = 0;

INITIALIZE_PASS(RISCVGatherScatterLowering, DEBUG_TYPE,
                "RISCV gather/scatter lowering pass", false, false)

FunctionPass *llvm::createRISCVGatherScatterLoweringPass() {
  return new RISCVGatherScatterLowering();
}

// Check that the element type is supported by RVV, the alignment immediate is
// at least the element store size, and the whole vector type is legal (no
// splitting/widening is attempted here).
bool RISCVGatherScatterLowering::isLegalTypeAndAlignment(Type *DataType,
                                                         Value *AlignOp) {
  Type *ScalarType = DataType->getScalarType();
  if (!TLI->isLegalElementTypeForRVV(ScalarType))
    return false;

  // AlignOp is the alignment immediate of the masked gather/scatter call.
  MaybeAlign MA = cast<ConstantInt>(AlignOp)->getMaybeAlignValue();
  if (MA && MA->value() < DL->getTypeStoreSize(ScalarType).getFixedSize())
    return false;

  // FIXME: Let the backend type legalize by splitting/widening?
  EVT DataVT = TLI->getValueType(*DL, DataType);
  if (!TLI->isTypeLegal(DataVT))
    return false;

  return true;
}

// TODO: Should we consider the mask when looking for a stride?
// If StartC's elements form the arithmetic sequence Start, Start+Stride,
// Start+2*Stride, ..., return the scalar {Start, Stride} ConstantInts.
// Otherwise return {nullptr, nullptr}. A single-element vector trivially
// matches with a stride of 0 (the loop body never runs).
static std::pair<Value *, Value *> matchStridedConstant(Constant *StartC) {
  unsigned NumElts = cast<FixedVectorType>(StartC->getType())->getNumElements();

  // Check that the start value is a strided constant.
  auto *StartVal =
      dyn_cast_or_null<ConstantInt>(StartC->getAggregateElement((unsigned)0));
  if (!StartVal)
    return std::make_pair(nullptr, nullptr);
  APInt StrideVal(StartVal->getValue().getBitWidth(), 0);
  ConstantInt *Prev = StartVal;
  for (unsigned i = 1; i != NumElts; ++i) {
    auto *C = dyn_cast_or_null<ConstantInt>(StartC->getAggregateElement(i));
    if (!C)
      return std::make_pair(nullptr, nullptr);

    // Every pair of adjacent elements must differ by the same amount.
    APInt LocalStride = C->getValue() - Prev->getValue();
    if (i == 1)
      StrideVal = LocalStride;
    else if (StrideVal != LocalStride)
      return std::make_pair(nullptr, nullptr);

    Prev = C;
  }

  Value *Stride = ConstantInt::get(StartVal->getType(), StrideVal);

  return std::make_pair(StartVal, Stride);
}

// Recursively match Start as a strided constant vector, possibly with one or
// more splat values added to it. Returns the scalar {Start, Stride} pair, or
// {nullptr, nullptr} if no strided pattern was found. Any scalar adds needed
// to fold the splats into the start are emitted before the vector add they
// replace.
static std::pair<Value *, Value *> matchStridedStart(Value *Start,
                                                     IRBuilder<> &Builder) {
  // Base case, start is a strided constant.
  auto *StartC = dyn_cast<Constant>(Start);
  if (StartC)
    return matchStridedConstant(StartC);

  // Not a constant, maybe it's a strided constant with a splat added to it.
  auto *BO = dyn_cast<BinaryOperator>(Start);
  if (!BO || BO->getOpcode() != Instruction::Add)
    return std::make_pair(nullptr, nullptr);

  // Look for an operand that is splatted.
  unsigned OtherIndex = 1;
  Value *Splat = getSplatValue(BO->getOperand(0));
  if (!Splat) {
    Splat = getSplatValue(BO->getOperand(1));
    OtherIndex = 0;
  }
  if (!Splat)
    return std::make_pair(nullptr, nullptr);

  // Recurse on the non-splat operand; adding a splat shifts the start value
  // but leaves the stride unchanged.
  Value *Stride;
  std::tie(Start, Stride) = matchStridedStart(BO->getOperand(OtherIndex),
                                              Builder);
  if (!Start)
    return std::make_pair(nullptr, nullptr);

  // Add the splat value to the start. Drop the debug location since the new
  // scalar add does not correspond directly to the original vector add.
  Builder.SetInsertPoint(BO);
  Builder.SetCurrentDebugLocation(DebugLoc());
  Start = Builder.CreateAdd(Start, Splat);
  return std::make_pair(Start, Stride);
}

// Recursively, walk about the use-def chain until we find a Phi with a strided
// start value. Build and update a scalar recurrence as we unwind the recursion.
// We also update the Stride as we unwind. Our goal is to move all of the
// arithmetic out of the loop.
bool RISCVGatherScatterLowering::matchStridedRecurrence(Value *Index, Loop *L,
                                                        Value *&Stride,
                                                        PHINode *&BasePtr,
                                                        BinaryOperator *&Inc,
                                                        IRBuilder<> &Builder) {
  // Our base case is a Phi.
  if (auto *Phi = dyn_cast<PHINode>(Index)) {
    // A phi node we want to perform this function on should be from the
    // loop header.
    if (Phi->getParent() != L->getHeader())
      return false;

    // The phi must be a simple add recurrence: Phi = phi(Start, Inc) with
    // Inc = add Phi, Step.
    Value *Step, *Start;
    if (!matchSimpleRecurrence(Phi, Inc, Start, Step) ||
        Inc->getOpcode() != Instruction::Add)
      return false;
    assert(Phi->getNumIncomingValues() == 2 && "Expected 2 operand phi.");
    unsigned IncrementingBlock = Phi->getIncomingValue(0) == Inc ? 0 : 1;
    assert(Phi->getIncomingValue(IncrementingBlock) == Inc &&
           "Expected one operand of phi to be Inc");

    // Only proceed if the step is loop invariant.
    if (!L->isLoopInvariant(Step))
      return false;

    // Step should be a splat.
    Step = getSplatValue(Step);
    if (!Step)
      return false;

    // The start value must itself be a strided vector.
    std::tie(Start, Stride) = matchStridedStart(Start, Builder);
    if (!Start)
      return false;
    assert(Stride != nullptr);

    // Build scalar phi and increment.
    // Create the scalar recurrence mirroring the vector one: BasePtr is the
    // scalar phi, Inc the scalar add placed right before the old vector add.
    BasePtr =
        PHINode::Create(Start->getType(), 2, Phi->getName() + ".scalar", Phi);
    Inc = BinaryOperator::CreateAdd(BasePtr, Step, Inc->getName() + ".scalar",
                                    Inc);
    BasePtr->addIncoming(Start, Phi->getIncomingBlock(1 - IncrementingBlock));
    BasePtr->addIncoming(Inc, Phi->getIncomingBlock(IncrementingBlock));

    // Note that this Phi might be eligible for removal.
    MaybeDeadPHIs.push_back(Phi);
    return true;
  }

  // Otherwise look for binary operator.
  auto *BO = dyn_cast<BinaryOperator>(Index);
  if (!BO)
    return false;

  if (BO->getOpcode() != Instruction::Add &&
      BO->getOpcode() != Instruction::Or &&
      BO->getOpcode() != Instruction::Mul &&
      BO->getOpcode() != Instruction::Shl)
    return false;

  // Only support shift by constant.
  if (BO->getOpcode() == Instruction::Shl && !isa<Constant>(BO->getOperand(1)))
    return false;

  // We need to be able to treat Or as Add.
  if (BO->getOpcode() == Instruction::Or &&
      !haveNoCommonBitsSet(BO->getOperand(0), BO->getOperand(1), *DL))
    return false;

  // We should have one operand in the loop and one splat.
  Value *OtherOp;
  if (isa<Instruction>(BO->getOperand(0)) &&
      L->contains(cast<Instruction>(BO->getOperand(0)))) {
    Index = cast<Instruction>(BO->getOperand(0));
    OtherOp = BO->getOperand(1);
  } else if (isa<Instruction>(BO->getOperand(1)) &&
             L->contains(cast<Instruction>(BO->getOperand(1)))) {
    Index = cast<Instruction>(BO->getOperand(1));
    OtherOp = BO->getOperand(0);
  } else {
    return false;
  }

  // Make sure other op is loop invariant.
  if (!L->isLoopInvariant(OtherOp))
    return false;

  // Make sure we have a splat.
  Value *SplatOp = getSplatValue(OtherOp);
  if (!SplatOp)
    return false;

  // Recurse up the use-def chain.
  if (!matchStridedRecurrence(Index, L, Stride, BasePtr, Inc, Builder))
    return false;

  // Locate the Step and Start values from the recurrence.
  unsigned StepIndex = Inc->getOperand(0) == BasePtr ? 1 : 0;
  unsigned StartBlock = BasePtr->getOperand(0) == Inc ? 1 : 0;
  Value *Step = Inc->getOperand(StepIndex);
  Value *Start = BasePtr->getOperand(StartBlock);

  // We need to adjust the start value in the preheader.
  Builder.SetInsertPoint(
      BasePtr->getIncomingBlock(StartBlock)->getTerminator());
  Builder.SetCurrentDebugLocation(DebugLoc());

  // Fold this binary operator into the scalar recurrence built by the
  // recursive call, updating start/step/stride as required by the opcode.
  switch (BO->getOpcode()) {
  default:
    llvm_unreachable("Unexpected opcode!");
  case Instruction::Add:
  case Instruction::Or: {
    // An add only affects the start value. It's ok to do this for Or because
    // we already checked that there are no common set bits.

    // If the start value is Zero, just take the SplatOp.
    if (isa<ConstantInt>(Start) && cast<ConstantInt>(Start)->isZero())
      Start = SplatOp;
    else
      Start = Builder.CreateAdd(Start, SplatOp, "start");
    BasePtr->setIncomingValue(StartBlock, Start);
    break;
  }
  case Instruction::Mul: {
    // A multiply scales the start, the step, and the stride.
    // If the start is zero we don't need to multiply.
    if (!isa<ConstantInt>(Start) || !cast<ConstantInt>(Start)->isZero())
      Start = Builder.CreateMul(Start, SplatOp, "start");

    Step = Builder.CreateMul(Step, SplatOp, "step");

    // If the Stride is 1 just take the SplatOp.
    if (isa<ConstantInt>(Stride) && cast<ConstantInt>(Stride)->isOne())
      Stride = SplatOp;
    else
      Stride = Builder.CreateMul(Stride, SplatOp, "stride");
    Inc->setOperand(StepIndex, Step);
    BasePtr->setIncomingValue(StartBlock, Start);
    break;
  }
  case Instruction::Shl: {
    // A shift scales start, step, and stride like a multiply by a power of 2.
    // If the start is zero we don't need to shift.
    if (!isa<ConstantInt>(Start) || !cast<ConstantInt>(Start)->isZero())
      Start = Builder.CreateShl(Start, SplatOp, "start");
    Step = Builder.CreateShl(Step, SplatOp, "step");
    Stride = Builder.CreateShl(Stride, SplatOp, "stride");
    Inc->setOperand(StepIndex, Step);
    BasePtr->setIncomingValue(StartBlock, Start);
    break;
  }
  }

  return true;
}

std::pair<Value *, Value *>
RISCVGatherScatterLowering::determineBaseAndStride(GetElementPtrInst *GEP,
                                                   IRBuilder<> &Builder) {

  // Reuse a previously computed base/stride for this GEP if available.
  auto I = StridedAddrs.find(GEP);
  if (I != StridedAddrs.end())
    return I->second;

  SmallVector<Value *, 2> Ops(GEP->operands());

  // Base pointer needs to be a scalar.
  if (Ops[0]->getType()->isVectorTy())
    return std::make_pair(nullptr, nullptr);

  // Make sure we're in a loop and that has a pre-header and a single latch.
  Loop *L = LI->getLoopFor(GEP->getParent());
  if (!L || !L->getLoopPreheader() || !L->getLoopLatch())
    return std::make_pair(nullptr, nullptr);

  Optional<unsigned> VecOperand;
  unsigned TypeScale = 0;

  // Look for a vector operand and scale. Exactly one GEP index may be a
  // vector, and the type it indexes must have a fixed (non-scalable) size.
  gep_type_iterator GTI = gep_type_begin(GEP);
  for (unsigned i = 1, e = GEP->getNumOperands(); i != e; ++i, ++GTI) {
    if (!Ops[i]->getType()->isVectorTy())
      continue;

    if (VecOperand)
      return std::make_pair(nullptr, nullptr);

    VecOperand = i;

    TypeSize TS = DL->getTypeAllocSize(GTI.getIndexedType());
    if (TS.isScalable())
      return std::make_pair(nullptr, nullptr);

    TypeScale = TS.getFixedSize();
  }

  // We need to find a vector index to simplify.
  if (!VecOperand)
    return std::make_pair(nullptr, nullptr);

  // We can't extract the stride if the arithmetic is done at a different size
  // than the pointer type. Adding the stride later may not wrap correctly.
  // Technically we could handle wider indices, but I don't expect that in
  // practice.
  // Reject index vectors whose element width differs from the pointer width.
  Value *VecIndex = Ops[*VecOperand];
  Type *VecIntPtrTy = DL->getIntPtrType(GEP->getType());
  if (VecIndex->getType() != VecIntPtrTy)
    return std::make_pair(nullptr, nullptr);

  // Try to rewrite the vector index as an equivalent scalar recurrence.
  Value *Stride;
  BinaryOperator *Inc;
  PHINode *BasePhi;
  if (!matchStridedRecurrence(VecIndex, L, Stride, BasePhi, Inc, Builder))
    return std::make_pair(nullptr, nullptr);

  assert(BasePhi->getNumIncomingValues() == 2 && "Expected 2 operand phi.");
  unsigned IncrementingBlock = BasePhi->getOperand(0) == Inc ? 0 : 1;
  assert(BasePhi->getIncomingValue(IncrementingBlock) == Inc &&
         "Expected one operand of phi to be Inc");

  Builder.SetInsertPoint(GEP);

  // Replace the vector index with the scalar phi and build a scalar GEP.
  Ops[*VecOperand] = BasePhi;
  Type *SourceTy = GEP->getSourceElementType();
  Value *BasePtr =
      Builder.CreateGEP(SourceTy, Ops[0], makeArrayRef(Ops).drop_front());

  // Final adjustments to stride should go in the start block.
  Builder.SetInsertPoint(
      BasePhi->getIncomingBlock(1 - IncrementingBlock)->getTerminator());

  // Convert stride to pointer size if needed.
  Type *IntPtrTy = DL->getIntPtrType(BasePtr->getType());
  assert(Stride->getType() == IntPtrTy && "Unexpected type");

  // Scale the stride by the size of the indexed type.
  if (TypeScale != 1)
    Stride = Builder.CreateMul(Stride, ConstantInt::get(IntPtrTy, TypeScale));

  // Cache the result so other gathers/scatters using this GEP reuse it.
  auto P = std::make_pair(BasePtr, Stride);
  StridedAddrs[GEP] = P;
  return P;
}

// Try to rewrite the masked gather/scatter II into a riscv_masked_strided
// load/store intrinsic. DataType is the vector type loaded or stored, Ptr the
// vector of addresses, and AlignOp the alignment immediate of the call.
// Returns true if the IR was changed.
bool RISCVGatherScatterLowering::tryCreateStridedLoadStore(IntrinsicInst *II,
                                                           Type *DataType,
                                                           Value *Ptr,
                                                           Value *AlignOp) {
  // Make sure the operation will be supported by the backend.
  if (!isLegalTypeAndAlignment(DataType, AlignOp))
    return false;

  // Pointer should be a GEP.
  auto *GEP = dyn_cast<GetElementPtrInst>(Ptr);
  if (!GEP)
    return false;

  IRBuilder<> Builder(GEP);

  // Decompose the vector address into scalar base + scalar byte stride.
  Value *BasePtr, *Stride;
  std::tie(BasePtr, Stride) = determineBaseAndStride(GEP, Builder);
  if (!BasePtr)
    return false;
  assert(Stride != nullptr);

  Builder.SetInsertPoint(II);

  // Emit the strided intrinsic. Per LangRef, masked.gather's operands are
  // (ptrs, align, mask, passthru) and masked.scatter's are
  // (value, ptrs, align, mask), which is where the operand indices come from.
  CallInst *Call;
  if (II->getIntrinsicID() == Intrinsic::masked_gather)
    Call = Builder.CreateIntrinsic(
        Intrinsic::riscv_masked_strided_load,
        {DataType, BasePtr->getType(), Stride->getType()},
        {II->getArgOperand(3), BasePtr, Stride, II->getArgOperand(2)});
  else
    Call = Builder.CreateIntrinsic(
        Intrinsic::riscv_masked_strided_store,
        {DataType, BasePtr->getType(), Stride->getType()},
        {II->getArgOperand(0), BasePtr, Stride, II->getArgOperand(3)});

  Call->takeName(II);
  II->replaceAllUsesWith(Call);
  II->eraseFromParent();

  // The vector GEP may now be dead; clean it and its feeding instructions up.
  if (GEP->use_empty())
    RecursivelyDeleteTriviallyDeadInstructions(GEP);

  return true;
}

bool RISCVGatherScatterLowering::runOnFunction(Function &F) {
  if (skipFunction(F))
    return false;

  auto &TPC = getAnalysis<TargetPassConfig>();
  auto &TM = TPC.getTM<RISCVTargetMachine>();
  ST = &TM.getSubtarget<RISCVSubtarget>(F);
  // Only run when RVV is enabled and fixed-length vectors are in use.
  if (!ST->hasVInstructions() || !ST->useRVVForFixedLengthVectors())
    return false;

  TLI = ST->getTargetLowering();
  DL = &F.getParent()->getDataLayout();
  LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();

  StridedAddrs.clear();

  SmallVector<IntrinsicInst *, 4> Gathers;
  SmallVector<IntrinsicInst *, 4> Scatters;

  bool Changed = false;

  // Collect candidate fixed-vector gathers/scatters first so rewriting does
  // not invalidate the instruction iterator.
  for (BasicBlock &BB : F) {
    for (Instruction &I : BB) {
      IntrinsicInst *II = dyn_cast<IntrinsicInst>(&I);
      if (II && II->getIntrinsicID() == Intrinsic::masked_gather &&
          isa<FixedVectorType>(II->getType())) {
        Gathers.push_back(II);
      } else if (II && II->getIntrinsicID() == Intrinsic::masked_scatter &&
                 isa<FixedVectorType>(II->getArgOperand(0)->getType())) {
        Scatters.push_back(II);
      }
    }
  }

  // Rewrite gather/scatter to form strided load/store if possible.
  for (auto *II : Gathers)
    Changed |= tryCreateStridedLoadStore(
        II, II->getType(), II->getArgOperand(0), II->getArgOperand(1));
  for (auto *II : Scatters)
    Changed |=
        tryCreateStridedLoadStore(II, II->getArgOperand(0)->getType(),
                                  II->getArgOperand(1), II->getArgOperand(2));

  // Remove any dead phis.
  while (!MaybeDeadPHIs.empty()) {
    if (auto *Phi = dyn_cast_or_null<PHINode>(MaybeDeadPHIs.pop_back_val()))
      RecursivelyDeleteDeadPHINode(Phi);
  }

  return Changed;
}