Lines Matching +full:scatter +full:- +full:gather
1 //===- RISCVGatherScatterLowering.cpp - Gather/Scatter lowering -----------===//
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // This pass custom lowers llvm.gather and llvm.scatter instructions to
10 // RISC-V intrinsics.
12 //===----------------------------------------------------------------------===//
32 #define DEBUG_TYPE "riscv-gather-scatter-lowering"
46 // instructions we created for the first gather/scatter for the others.
63 return "RISC-V gather/scatter lowering"; in getPassName()
83 "RISC-V gather/scatter lowering pass", false, false)
91 if (!isa<FixedVectorType>(StartC->getType())) in matchStridedConstant()
94 unsigned NumElts = cast<FixedVectorType>(StartC->getType())->getNumElements(); in matchStridedConstant()
98 dyn_cast_or_null<ConstantInt>(StartC->getAggregateElement((unsigned)0)); in matchStridedConstant()
101 APInt StrideVal(StartVal->getValue().getBitWidth(), 0); in matchStridedConstant()
104 auto *C = dyn_cast_or_null<ConstantInt>(StartC->getAggregateElement(i)); in matchStridedConstant()
108 APInt LocalStride = C->getValue() - Prev->getValue(); in matchStridedConstant()
117 Value *Stride = ConstantInt::get(StartVal->getType(), StrideVal); in matchStridedConstant()
131 auto *Ty = Start->getType()->getScalarType(); in matchStridedStart()
138 if (!BO || (BO->getOpcode() != Instruction::Add && in matchStridedStart()
139 BO->getOpcode() != Instruction::Or && in matchStridedStart()
140 BO->getOpcode() != Instruction::Shl && in matchStridedStart()
141 BO->getOpcode() != Instruction::Mul)) in matchStridedStart()
144 if (BO->getOpcode() == Instruction::Or && in matchStridedStart()
145 !cast<PossiblyDisjointInst>(BO)->isDisjoint()) in matchStridedStart()
150 Value *Splat = getSplatValue(BO->getOperand(1)); in matchStridedStart()
151 if (!Splat && Instruction::isCommutative(BO->getOpcode())) { in matchStridedStart()
152 Splat = getSplatValue(BO->getOperand(0)); in matchStridedStart()
159 std::tie(Start, Stride) = matchStridedStart(BO->getOperand(OtherIndex), in matchStridedStart()
168 switch (BO->getOpcode()) { in matchStridedStart()
191 // Recursively, walk about the use-def chain until we find a Phi with a strided
204 if (Phi->getParent() != L->getHeader()) in matchStridedRecurrence()
209 Inc->getOpcode() != Instruction::Add) in matchStridedRecurrence()
211 assert(Phi->getNumIncomingValues() == 2 && "Expected 2 operand phi."); in matchStridedRecurrence()
212 unsigned IncrementingBlock = Phi->getIncomingValue(0) == Inc ? 0 : 1; in matchStridedRecurrence()
213 assert(Phi->getIncomingValue(IncrementingBlock) == Inc && in matchStridedRecurrence()
217 if (!L->isLoopInvariant(Step)) in matchStridedRecurrence()
232 PHINode::Create(Start->getType(), 2, Phi->getName() + ".scalar", Phi->getIterator()); in matchStridedRecurrence()
233 Inc = BinaryOperator::CreateAdd(BasePtr, Step, Inc->getName() + ".scalar", in matchStridedRecurrence()
234 Inc->getIterator()); in matchStridedRecurrence()
235 BasePtr->addIncoming(Start, Phi->getIncomingBlock(1 - IncrementingBlock)); in matchStridedRecurrence()
236 BasePtr->addIncoming(Inc, Phi->getIncomingBlock(IncrementingBlock)); in matchStridedRecurrence()
248 switch (BO->getOpcode()) { in matchStridedRecurrence()
253 if (!cast<PossiblyDisjointInst>(BO)->isDisjoint()) in matchStridedRecurrence()
266 if (isa<Instruction>(BO->getOperand(0)) && in matchStridedRecurrence()
267 L->contains(cast<Instruction>(BO->getOperand(0)))) { in matchStridedRecurrence()
268 Index = cast<Instruction>(BO->getOperand(0)); in matchStridedRecurrence()
269 OtherOp = BO->getOperand(1); in matchStridedRecurrence()
270 } else if (isa<Instruction>(BO->getOperand(1)) && in matchStridedRecurrence()
271 L->contains(cast<Instruction>(BO->getOperand(1))) && in matchStridedRecurrence()
272 Instruction::isCommutative(BO->getOpcode())) { in matchStridedRecurrence()
273 Index = cast<Instruction>(BO->getOperand(1)); in matchStridedRecurrence()
274 OtherOp = BO->getOperand(0); in matchStridedRecurrence()
280 if (!L->isLoopInvariant(OtherOp)) in matchStridedRecurrence()
288 // Recurse up the use-def chain. in matchStridedRecurrence()
293 unsigned StepIndex = Inc->getOperand(0) == BasePtr ? 1 : 0; in matchStridedRecurrence()
294 unsigned StartBlock = BasePtr->getOperand(0) == Inc ? 1 : 0; in matchStridedRecurrence()
295 Value *Step = Inc->getOperand(StepIndex); in matchStridedRecurrence()
296 Value *Start = BasePtr->getOperand(StartBlock); in matchStridedRecurrence()
300 BasePtr->getIncomingBlock(StartBlock)->getTerminator()); in matchStridedRecurrence()
303 switch (BO->getOpcode()) { in matchStridedRecurrence()
327 Inc->setOperand(StepIndex, Step); in matchStridedRecurrence()
328 BasePtr->setIncomingValue(StartBlock, Start); in matchStridedRecurrence()
336 // A gather/scatter of a splat is a zero strided load/store. in determineBaseAndStride()
338 Type *IntPtrTy = DL->getIntPtrType(BasePtr->getType()); in determineBaseAndStride()
348 return I->second; in determineBaseAndStride()
350 SmallVector<Value *, 2> Ops(GEP->operands()); in determineBaseAndStride()
353 Value *Base = GEP->getPointerOperand(); in determineBaseAndStride()
355 BaseInst && BaseInst->getType()->isVectorTy()) { in determineBaseAndStride()
357 auto IsScalar = [](Value *Idx) { return !Idx->getType()->isVectorTy(); }; in determineBaseAndStride()
358 if (all_of(GEP->indices(), IsScalar)) { in determineBaseAndStride()
362 SmallVector<Value *> Indices(GEP->indices()); in determineBaseAndStride()
364 Builder.CreateGEP(GEP->getSourceElementType(), BaseBase, Indices, in determineBaseAndStride()
365 GEP->getName() + "offset", GEP->isInBounds()); in determineBaseAndStride()
373 if (ScalarBase->getType()->isVectorTy()) { in determineBaseAndStride()
384 for (unsigned i = 1, e = GEP->getNumOperands(); i != e; ++i, ++GTI) { in determineBaseAndStride()
385 if (!Ops[i]->getType()->isVectorTy()) in determineBaseAndStride()
407 // practice. Handle one special case here - constants. This simplifies in determineBaseAndStride()
410 Type *VecIntPtrTy = DL->getIntPtrType(GEP->getType()); in determineBaseAndStride()
411 if (VecIndex->getType() != VecIntPtrTy) { in determineBaseAndStride()
415 if (VecIndex->getType()->getScalarSizeInBits() > VecIntPtrTy->getScalarSizeInBits()) in determineBaseAndStride()
421 // Handle the non-recursive case. This is what we see if the vectorizer in determineBaseAndStride()
430 Type *SourceTy = GEP->getSourceElementType(); in determineBaseAndStride()
435 Type *IntPtrTy = DL->getIntPtrType(BasePtr->getType()); in determineBaseAndStride()
436 assert(Stride->getType() == IntPtrTy && "Unexpected type"); in determineBaseAndStride()
447 // Make sure we're in a loop and that has a pre-header and a single latch. in determineBaseAndStride()
448 Loop *L = LI->getLoopFor(GEP->getParent()); in determineBaseAndStride()
449 if (!L || !L->getLoopPreheader() || !L->getLoopLatch()) in determineBaseAndStride()
457 assert(BasePhi->getNumIncomingValues() == 2 && "Expected 2 operand phi."); in determineBaseAndStride()
458 unsigned IncrementingBlock = BasePhi->getOperand(0) == Inc ? 0 : 1; in determineBaseAndStride()
459 assert(BasePhi->getIncomingValue(IncrementingBlock) == Inc && in determineBaseAndStride()
466 Type *SourceTy = GEP->getSourceElementType(); in determineBaseAndStride()
472 BasePhi->getIncomingBlock(1 - IncrementingBlock)->getTerminator()); in determineBaseAndStride()
475 Type *IntPtrTy = DL->getIntPtrType(BasePtr->getType()); in determineBaseAndStride()
476 assert(Stride->getType() == IntPtrTy && "Unexpected type"); in determineBaseAndStride()
492 MaybeAlign MA = cast<ConstantInt>(AlignOp)->getMaybeAlignValue(); in tryCreateStridedLoadStore()
493 EVT DataTypeVT = TLI->getValueType(*DL, DataType); in tryCreateStridedLoadStore()
494 if (!MA || !TLI->isLegalStridedLoadStore(DataTypeVT, *MA)) in tryCreateStridedLoadStore()
498 if (!TLI->isTypeLegal(DataTypeVT)) in tryCreateStridedLoadStore()
506 LLVMContext &Ctx = PtrI->getContext(); in tryCreateStridedLoadStore()
519 if (II->getIntrinsicID() == Intrinsic::masked_gather) in tryCreateStridedLoadStore()
522 {DataType, BasePtr->getType(), Stride->getType()}, in tryCreateStridedLoadStore()
523 {II->getArgOperand(3), BasePtr, Stride, II->getArgOperand(2)}); in tryCreateStridedLoadStore()
527 {DataType, BasePtr->getType(), Stride->getType()}, in tryCreateStridedLoadStore()
528 {II->getArgOperand(0), BasePtr, Stride, II->getArgOperand(3)}); in tryCreateStridedLoadStore()
530 Call->takeName(II); in tryCreateStridedLoadStore()
531 II->replaceAllUsesWith(Call); in tryCreateStridedLoadStore()
532 II->eraseFromParent(); in tryCreateStridedLoadStore()
534 if (PtrI->use_empty()) in tryCreateStridedLoadStore()
547 if (!ST->hasVInstructions() || !ST->useRVVForFixedLengthVectors()) in runOnFunction()
550 TLI = ST->getTargetLowering(); in runOnFunction()
564 if (II && II->getIntrinsicID() == Intrinsic::masked_gather) { in runOnFunction()
566 } else if (II && II->getIntrinsicID() == Intrinsic::masked_scatter) { in runOnFunction()
572 // Rewrite gather/scatter to form strided load/store if possible. in runOnFunction()
575 II, II->getType(), II->getArgOperand(0), II->getArgOperand(1)); in runOnFunction()
578 tryCreateStridedLoadStore(II, II->getArgOperand(0)->getType(), in runOnFunction()
579 II->getArgOperand(1), II->getArgOperand(2)); in runOnFunction()