//===- RISCVGatherScatterLowering.cpp - Gather/Scatter lowering -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass custom lowers llvm.masked.gather and llvm.masked.scatter
// intrinsics to RISCV intrinsics.
//
//===----------------------------------------------------------------------===//

#include "RISCV.h"
#include "RISCVTargetMachine.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/IR/GetElementPtrTypeIterator.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/IntrinsicsRISCV.h"
#include "llvm/Transforms/Utils/Local.h"

using namespace llvm;

#define DEBUG_TYPE "riscv-gather-scatter-lowering"

namespace {

class RISCVGatherScatterLowering : public FunctionPass {
  const RISCVSubtarget *ST = nullptr;
  const RISCVTargetLowering *TLI = nullptr;
  LoopInfo *LI = nullptr;
  const DataLayout *DL = nullptr;

  SmallVector<WeakTrackingVH> MaybeDeadPHIs;

public:
  static char ID; // Pass identification, replacement for typeid

  RISCVGatherScatterLowering() : FunctionPass(ID) {}

  bool runOnFunction(Function &F) override;

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesCFG();
    AU.addRequired<TargetPassConfig>();
    AU.addRequired<LoopInfoWrapperPass>();
  }

  StringRef getPassName() const override {
    return "RISCV gather/scatter lowering";
  }

private:
  bool isLegalTypeAndAlignment(Type *DataType, Value *AlignOp);

  bool tryCreateStridedLoadStore(IntrinsicInst *II, Type *DataType, Value *Ptr,
                                 Value *AlignOp);

  std::pair<Value *, Value *> determineBaseAndStride(GetElementPtrInst *GEP,
                                                     IRBuilder<> &Builder);

  bool matchStridedRecurrence(Value *Index, Loop *L, Value *&Stride,
                              PHINode *&BasePtr, BinaryOperator *&Inc,
                              IRBuilder<> &Builder);
};

} // end anonymous namespace

char RISCVGatherScatterLowering::ID = 0;

INITIALIZE_PASS(RISCVGatherScatterLowering, DEBUG_TYPE,
                "RISCV gather/scatter lowering pass", false, false)

FunctionPass *llvm::createRISCVGatherScatterLoweringPass() {
  return new RISCVGatherScatterLowering();
}

bool RISCVGatherScatterLowering::isLegalTypeAndAlignment(Type *DataType,
                                                         Value *AlignOp) {
  Type *ScalarType = DataType->getScalarType();
  if (!TLI->isLegalElementTypeForRVV(ScalarType))
    return false;

  MaybeAlign MA = cast<ConstantInt>(AlignOp)->getMaybeAlignValue();
  if (MA && MA->value() < DL->getTypeStoreSize(ScalarType).getFixedSize())
    return false;

  // FIXME: Let the backend type legalize by splitting/widening?
  EVT DataVT = TLI->getValueType(*DL, DataType);
  if (!TLI->isTypeLegal(DataVT))
    return false;

  return true;
}
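// A strided vector constant is one whose lanes form an arithmetic sequence,
//   <C, C+S, C+2*S, ...>
// e.g. <i64 3, i64 5, i64 7, i64 9> has start 3 and stride 2. The helpers
// below decompose such index vectors into a scalar (start, stride) pair.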
// TODO: Should we consider the mask when looking for a stride?
static std::pair<Value *, Value *> matchStridedConstant(Constant *StartC) {
  unsigned NumElts = cast<FixedVectorType>(StartC->getType())->getNumElements();

  // Check that the start value is a strided constant.
  auto *StartVal =
      dyn_cast_or_null<ConstantInt>(StartC->getAggregateElement((unsigned)0));
  if (!StartVal)
    return std::make_pair(nullptr, nullptr);
  APInt StrideVal(StartVal->getValue().getBitWidth(), 0);
  ConstantInt *Prev = StartVal;
  for (unsigned i = 1; i != NumElts; ++i) {
    auto *C = dyn_cast_or_null<ConstantInt>(StartC->getAggregateElement(i));
    if (!C)
      return std::make_pair(nullptr, nullptr);

    APInt LocalStride = C->getValue() - Prev->getValue();
    if (i == 1)
      StrideVal = LocalStride;
    else if (StrideVal != LocalStride)
      return std::make_pair(nullptr, nullptr);

    Prev = C;
  }

  Value *Stride = ConstantInt::get(StartVal->getType(), StrideVal);

  return std::make_pair(StartVal, Stride);
}

static std::pair<Value *, Value *> matchStridedStart(Value *Start,
                                                     IRBuilder<> &Builder) {
  // Base case, start is a strided constant.
  auto *StartC = dyn_cast<Constant>(Start);
  if (StartC)
    return matchStridedConstant(StartC);

  // Not a constant, maybe it's a strided constant with a splat added to it.
  auto *BO = dyn_cast<BinaryOperator>(Start);
  if (!BO || BO->getOpcode() != Instruction::Add)
    return std::make_pair(nullptr, nullptr);

  // Look for an operand that is splatted.
  unsigned OtherIndex = 1;
  Value *Splat = getSplatValue(BO->getOperand(0));
  if (!Splat) {
    Splat = getSplatValue(BO->getOperand(1));
    OtherIndex = 0;
  }
  if (!Splat)
    return std::make_pair(nullptr, nullptr);

  Value *Stride;
  std::tie(Start, Stride) = matchStridedStart(BO->getOperand(OtherIndex),
                                              Builder);
  if (!Start)
    return std::make_pair(nullptr, nullptr);

  // Add the splat value to the start.
  Builder.SetInsertPoint(BO);
  Builder.SetCurrentDebugLocation(DebugLoc());
  Start = Builder.CreateAdd(Start, Splat);
  return std::make_pair(Start, Stride);
}
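// For example, matchStridedStart above decomposes a strided constant plus a
// splat, such as
//   %start = add <4 x i64> %x.splat, <i64 0, i64 1, i64 2, i64 3>
// into a scalar start (%x added to the constant start 0) and a stride of 1,
// with the add rebuilt on scalars.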
// Recursively walk back up the use-def chain until we find a Phi with a
// strided start value. Build and update a scalar recurrence as we unwind the
// recursion. We also update the Stride as we unwind. Our goal is to move all
// of the arithmetic out of the loop.
bool RISCVGatherScatterLowering::matchStridedRecurrence(Value *Index, Loop *L,
                                                        Value *&Stride,
                                                        PHINode *&BasePtr,
                                                        BinaryOperator *&Inc,
                                                        IRBuilder<> &Builder) {
  // Our base case is a Phi.
  if (auto *Phi = dyn_cast<PHINode>(Index)) {
    // A phi node we want to perform this function on should be from the
    // loop header.
    if (Phi->getParent() != L->getHeader())
      return false;

    Value *Step, *Start;
    if (!matchSimpleRecurrence(Phi, Inc, Start, Step) ||
        Inc->getOpcode() != Instruction::Add)
      return false;
    assert(Phi->getNumIncomingValues() == 2 && "Expected 2 operand phi.");
    unsigned IncrementingBlock = Phi->getIncomingValue(0) == Inc ? 0 : 1;
    assert(Phi->getIncomingValue(IncrementingBlock) == Inc &&
           "Expected one operand of phi to be Inc");

    // Only proceed if the step is loop invariant.
    if (!L->isLoopInvariant(Step))
      return false;

    // Step should be a splat.
    Step = getSplatValue(Step);
    if (!Step)
      return false;

    std::tie(Start, Stride) = matchStridedStart(Start, Builder);
    if (!Start)
      return false;
    assert(Stride != nullptr);

    // Build scalar phi and increment.
    BasePtr =
        PHINode::Create(Start->getType(), 2, Phi->getName() + ".scalar", Phi);
    Inc = BinaryOperator::CreateAdd(BasePtr, Step, Inc->getName() + ".scalar",
                                    Inc);
    BasePtr->addIncoming(Start, Phi->getIncomingBlock(1 - IncrementingBlock));
    BasePtr->addIncoming(Inc, Phi->getIncomingBlock(IncrementingBlock));

    // Note that this Phi might be eligible for removal.
    MaybeDeadPHIs.push_back(Phi);
    return true;
  }

  // Otherwise look for binary operator.
  auto *BO = dyn_cast<BinaryOperator>(Index);
  if (!BO)
    return false;

  if (BO->getOpcode() != Instruction::Add &&
      BO->getOpcode() != Instruction::Or &&
      BO->getOpcode() != Instruction::Mul &&
      BO->getOpcode() != Instruction::Shl)
    return false;

  // Only support shift by constant.
  if (BO->getOpcode() == Instruction::Shl && !isa<Constant>(BO->getOperand(1)))
    return false;

  // We need to be able to treat Or as Add.
  if (BO->getOpcode() == Instruction::Or &&
      !haveNoCommonBitsSet(BO->getOperand(0), BO->getOperand(1), *DL))
    return false;

  // We should have one operand in the loop and one splat.
  Value *OtherOp;
  if (isa<Instruction>(BO->getOperand(0)) &&
      L->contains(cast<Instruction>(BO->getOperand(0)))) {
    Index = cast<Instruction>(BO->getOperand(0));
    OtherOp = BO->getOperand(1);
  } else if (isa<Instruction>(BO->getOperand(1)) &&
             L->contains(cast<Instruction>(BO->getOperand(1)))) {
    Index = cast<Instruction>(BO->getOperand(1));
    OtherOp = BO->getOperand(0);
  } else {
    return false;
  }

  // Make sure other op is loop invariant.
  if (!L->isLoopInvariant(OtherOp))
    return false;

  // Make sure we have a splat.
  Value *SplatOp = getSplatValue(OtherOp);
  if (!SplatOp)
    return false;

  // Recurse up the use-def chain.
  if (!matchStridedRecurrence(Index, L, Stride, BasePtr, Inc, Builder))
    return false;

  // Locate the Step and Start values from the recurrence.
  unsigned StepIndex = Inc->getOperand(0) == BasePtr ? 1 : 0;
  unsigned StartBlock = BasePtr->getOperand(0) == Inc ? 1 : 0;
  Value *Step = Inc->getOperand(StepIndex);
  Value *Start = BasePtr->getOperand(StartBlock);

  // We need to adjust the start value in the preheader.
  Builder.SetInsertPoint(
      BasePtr->getIncomingBlock(StartBlock)->getTerminator());
  Builder.SetCurrentDebugLocation(DebugLoc());

  switch (BO->getOpcode()) {
  default:
    llvm_unreachable("Unexpected opcode!");
  case Instruction::Add:
  case Instruction::Or: {
    // An add only affects the start value. It's ok to do this for Or because
    // we already checked that there are no common set bits.

    // If the start value is Zero, just take the SplatOp.
    if (isa<ConstantInt>(Start) && cast<ConstantInt>(Start)->isZero())
      Start = SplatOp;
    else
      Start = Builder.CreateAdd(Start, SplatOp, "start");
    BasePtr->setIncomingValue(StartBlock, Start);
    break;
  }
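  // Multiplying the affine index (Start + i * Step) by a splat scales every
  // term:
  //   (Start + i * Step) * Splat == Start * Splat + i * (Step * Splat)
  // so the recurrence's start and step, and the extracted Stride, must all be
  // scaled by the splat value. The same reasoning applies to Shl, which is a
  // multiply by a power of 2.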
  case Instruction::Mul: {
    // If the start is zero we don't need to multiply.
    if (!isa<ConstantInt>(Start) || !cast<ConstantInt>(Start)->isZero())
      Start = Builder.CreateMul(Start, SplatOp, "start");

    Step = Builder.CreateMul(Step, SplatOp, "step");

    // If the Stride is 1 just take the SplatOp.
    if (isa<ConstantInt>(Stride) && cast<ConstantInt>(Stride)->isOne())
      Stride = SplatOp;
    else
      Stride = Builder.CreateMul(Stride, SplatOp, "stride");
    Inc->setOperand(StepIndex, Step);
    BasePtr->setIncomingValue(StartBlock, Start);
    break;
  }
  case Instruction::Shl: {
    // If the start is zero we don't need to shift.
    if (!isa<ConstantInt>(Start) || !cast<ConstantInt>(Start)->isZero())
      Start = Builder.CreateShl(Start, SplatOp, "start");
    Step = Builder.CreateShl(Step, SplatOp, "step");
    Stride = Builder.CreateShl(Stride, SplatOp, "stride");
    Inc->setOperand(StepIndex, Step);
    BasePtr->setIncomingValue(StartBlock, Start);
    break;
  }
  }

  return true;
}

std::pair<Value *, Value *>
RISCVGatherScatterLowering::determineBaseAndStride(GetElementPtrInst *GEP,
                                                   IRBuilder<> &Builder) {

  SmallVector<Value *, 2> Ops(GEP->operands());

  // Base pointer needs to be a scalar.
  if (Ops[0]->getType()->isVectorTy())
    return std::make_pair(nullptr, nullptr);

  // Make sure we're in a loop and it is in loop simplify form.
  Loop *L = LI->getLoopFor(GEP->getParent());
  if (!L || !L->isLoopSimplifyForm())
    return std::make_pair(nullptr, nullptr);

  Optional<unsigned> VecOperand;
  unsigned TypeScale = 0;

  // Look for a vector operand and scale.
  gep_type_iterator GTI = gep_type_begin(GEP);
  for (unsigned i = 1, e = GEP->getNumOperands(); i != e; ++i, ++GTI) {
    if (!Ops[i]->getType()->isVectorTy())
      continue;

    if (VecOperand)
      return std::make_pair(nullptr, nullptr);

    VecOperand = i;

    TypeSize TS = DL->getTypeAllocSize(GTI.getIndexedType());
    if (TS.isScalable())
      return std::make_pair(nullptr, nullptr);

    TypeScale = TS.getFixedSize();
  }

  // We need to find a vector index to simplify.
  if (!VecOperand)
    return std::make_pair(nullptr, nullptr);

  // We can't extract the stride if the arithmetic is done at a different size
  // than the pointer type. Adding the stride later may not wrap correctly.
  // Technically we could handle wider indices, but I don't expect that in
  // practice.
  Value *VecIndex = Ops[*VecOperand];
  Type *VecIntPtrTy = DL->getIntPtrType(GEP->getType());
  if (VecIndex->getType() != VecIntPtrTy)
    return std::make_pair(nullptr, nullptr);

  Value *Stride;
  BinaryOperator *Inc;
  PHINode *BasePhi;
  if (!matchStridedRecurrence(VecIndex, L, Stride, BasePhi, Inc, Builder))
    return std::make_pair(nullptr, nullptr);

  assert(BasePhi->getNumIncomingValues() == 2 && "Expected 2 operand phi.");
  unsigned IncrementingBlock = BasePhi->getOperand(0) == Inc ? 0 : 1;
  assert(BasePhi->getIncomingValue(IncrementingBlock) == Inc &&
         "Expected one operand of phi to be Inc");

  Builder.SetInsertPoint(GEP);

  // Replace the vector index with the scalar phi and build a scalar GEP.
  Ops[*VecOperand] = BasePhi;
  Type *SourceTy = GEP->getSourceElementType();
  Value *BasePtr =
      Builder.CreateGEP(SourceTy, Ops[0], makeArrayRef(Ops).drop_front());

  // Cast the GEP to an i8*.
  LLVMContext &Ctx = GEP->getContext();
  Type *I8PtrTy =
      Type::getInt8PtrTy(Ctx, GEP->getType()->getPointerAddressSpace());
  if (BasePtr->getType() != I8PtrTy)
    BasePtr = Builder.CreatePointerCast(BasePtr, I8PtrTy);

  // Final adjustments to stride should go in the start block.
  Builder.SetInsertPoint(
      BasePhi->getIncomingBlock(1 - IncrementingBlock)->getTerminator());

  // The stride is already a pointer-sized integer; we rejected any other
  // index width above.
  Type *IntPtrTy = DL->getIntPtrType(BasePtr->getType());
  assert(Stride->getType() == IntPtrTy && "Unexpected type");

  // Scale the stride by the size of the indexed type.
  if (TypeScale != 1)
    Stride = Builder.CreateMul(Stride, ConstantInt::get(IntPtrTy, TypeScale));

  return std::make_pair(BasePtr, Stride);
}
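// For example (types and operands abbreviated), a gather whose per-lane
// addresses advance by a uniform distance:
//   %ptrs = getelementptr i32, i32* %base, <4 x i64> %strided.index
//   %v = call <4 x i32> @llvm.masked.gather.*(<4 x i32*> %ptrs, ..., %mask,
//                                             %passthru)
// becomes, roughly:
//   %v = call <4 x i32> @llvm.riscv.masked.strided.load.*(%passthru, i8* %p,
//                                                         i64 %stride, %mask)
// where %p is the scalar base recurrence and %stride is measured in bytes.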
bool RISCVGatherScatterLowering::tryCreateStridedLoadStore(IntrinsicInst *II,
                                                           Type *DataType,
                                                           Value *Ptr,
                                                           Value *AlignOp) {
  // Make sure the operation will be supported by the backend.
  if (!isLegalTypeAndAlignment(DataType, AlignOp))
    return false;

  // Pointer should be a GEP.
  auto *GEP = dyn_cast<GetElementPtrInst>(Ptr);
  if (!GEP)
    return false;

  IRBuilder<> Builder(GEP);

  Value *BasePtr, *Stride;
  std::tie(BasePtr, Stride) = determineBaseAndStride(GEP, Builder);
  if (!BasePtr)
    return false;
  assert(Stride != nullptr);

  Builder.SetInsertPoint(II);

  CallInst *Call;
  if (II->getIntrinsicID() == Intrinsic::masked_gather)
    Call = Builder.CreateIntrinsic(
        Intrinsic::riscv_masked_strided_load,
        {DataType, BasePtr->getType(), Stride->getType()},
        {II->getArgOperand(3), BasePtr, Stride, II->getArgOperand(2)});
  else
    Call = Builder.CreateIntrinsic(
        Intrinsic::riscv_masked_strided_store,
        {DataType, BasePtr->getType(), Stride->getType()},
        {II->getArgOperand(0), BasePtr, Stride, II->getArgOperand(3)});

  Call->takeName(II);
  II->replaceAllUsesWith(Call);
  II->eraseFromParent();

  if (GEP->use_empty())
    RecursivelyDeleteTriviallyDeadInstructions(GEP);

  return true;
}

bool RISCVGatherScatterLowering::runOnFunction(Function &F) {
  if (skipFunction(F))
    return false;

  auto &TPC = getAnalysis<TargetPassConfig>();
  auto &TM = TPC.getTM<RISCVTargetMachine>();
  ST = &TM.getSubtarget<RISCVSubtarget>(F);
  if (!ST->hasVInstructions() || !ST->useRVVForFixedLengthVectors())
    return false;

  TLI = ST->getTargetLowering();
  DL = &F.getParent()->getDataLayout();
  LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();

  SmallVector<IntrinsicInst *, 4> Gathers;
  SmallVector<IntrinsicInst *, 4> Scatters;

  bool Changed = false;

  for (BasicBlock &BB : F) {
    for (Instruction &I : BB) {
      IntrinsicInst *II = dyn_cast<IntrinsicInst>(&I);
      if (II && II->getIntrinsicID() == Intrinsic::masked_gather &&
          isa<FixedVectorType>(II->getType())) {
        Gathers.push_back(II);
      } else if (II && II->getIntrinsicID() == Intrinsic::masked_scatter &&
                 isa<FixedVectorType>(II->getArgOperand(0)->getType())) {
        Scatters.push_back(II);
      }
    }
  }

  // Rewrite gather/scatter to form strided load/store if possible.
  for (auto *II : Gathers)
    Changed |= tryCreateStridedLoadStore(
        II, II->getType(), II->getArgOperand(0), II->getArgOperand(1));
  for (auto *II : Scatters)
    Changed |=
        tryCreateStridedLoadStore(II, II->getArgOperand(0)->getType(),
                                  II->getArgOperand(1), II->getArgOperand(2));

  // Remove any dead phis.
  while (!MaybeDeadPHIs.empty()) {
    if (auto *Phi = dyn_cast_or_null<PHINode>(MaybeDeadPHIs.pop_back_val()))
      RecursivelyDeleteDeadPHINode(Phi);
  }

  return Changed;
}