//===- RISCVGatherScatterLowering.cpp - Gather/Scatter lowering -----------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This pass custom lowers llvm.gather and llvm.scatter instructions to // RISC-V intrinsics. // //===----------------------------------------------------------------------===// #include "RISCV.h" #include "RISCVTargetMachine.h" #include "llvm/Analysis/InstSimplifyFolder.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/Analysis/VectorUtils.h" #include "llvm/CodeGen/TargetPassConfig.h" #include "llvm/IR/GetElementPtrTypeIterator.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/IntrinsicsRISCV.h" #include "llvm/IR/PatternMatch.h" #include "llvm/Transforms/Utils/Local.h" #include using namespace llvm; using namespace PatternMatch; #define DEBUG_TYPE "riscv-gather-scatter-lowering" namespace { class RISCVGatherScatterLowering : public FunctionPass { const RISCVSubtarget *ST = nullptr; const RISCVTargetLowering *TLI = nullptr; LoopInfo *LI = nullptr; const DataLayout *DL = nullptr; SmallVector MaybeDeadPHIs; // Cache of the BasePtr and Stride determined from this GEP. When a GEP is // used by multiple gathers/scatters, this allow us to reuse the scalar // instructions we created for the first gather/scatter for the others. DenseMap> StridedAddrs; public: static char ID; // Pass identification, replacement for typeid RISCVGatherScatterLowering() : FunctionPass(ID) {} bool runOnFunction(Function &F) override; void getAnalysisUsage(AnalysisUsage &AU) const override { AU.setPreservesCFG(); AU.addRequired(); AU.addRequired(); } StringRef getPassName() const override { return "RISC-V gather/scatter lowering"; } private: bool tryCreateStridedLoadStore(IntrinsicInst *II, Type *DataType, Value *Ptr, Value *AlignOp); std::pair determineBaseAndStride(Instruction *Ptr, IRBuilderBase &Builder); bool matchStridedRecurrence(Value *Index, Loop *L, Value *&Stride, PHINode *&BasePtr, BinaryOperator *&Inc, IRBuilderBase &Builder); }; } // end anonymous namespace char RISCVGatherScatterLowering::ID = 0; INITIALIZE_PASS(RISCVGatherScatterLowering, DEBUG_TYPE, "RISC-V gather/scatter lowering pass", false, false) FunctionPass *llvm::createRISCVGatherScatterLoweringPass() { return new RISCVGatherScatterLowering(); } // TODO: Should we consider the mask when looking for a stride? static std::pair matchStridedConstant(Constant *StartC) { if (!isa(StartC->getType())) return std::make_pair(nullptr, nullptr); unsigned NumElts = cast(StartC->getType())->getNumElements(); // Check that the start value is a strided constant. auto *StartVal = dyn_cast_or_null(StartC->getAggregateElement((unsigned)0)); if (!StartVal) return std::make_pair(nullptr, nullptr); APInt StrideVal(StartVal->getValue().getBitWidth(), 0); ConstantInt *Prev = StartVal; for (unsigned i = 1; i != NumElts; ++i) { auto *C = dyn_cast_or_null(StartC->getAggregateElement(i)); if (!C) return std::make_pair(nullptr, nullptr); APInt LocalStride = C->getValue() - Prev->getValue(); if (i == 1) StrideVal = LocalStride; else if (StrideVal != LocalStride) return std::make_pair(nullptr, nullptr); Prev = C; } Value *Stride = ConstantInt::get(StartVal->getType(), StrideVal); return std::make_pair(StartVal, Stride); } static std::pair matchStridedStart(Value *Start, IRBuilderBase &Builder) { // Base case, start is a strided constant. auto *StartC = dyn_cast(Start); if (StartC) return matchStridedConstant(StartC); // Base case, start is a stepvector if (match(Start, m_Intrinsic())) { auto *Ty = Start->getType()->getScalarType(); return std::make_pair(ConstantInt::get(Ty, 0), ConstantInt::get(Ty, 1)); } // Not a constant, maybe it's a strided constant with a splat added or // multipled. auto *BO = dyn_cast(Start); if (!BO || (BO->getOpcode() != Instruction::Add && BO->getOpcode() != Instruction::Or && BO->getOpcode() != Instruction::Shl && BO->getOpcode() != Instruction::Mul)) return std::make_pair(nullptr, nullptr); if (BO->getOpcode() == Instruction::Or && !cast(BO)->isDisjoint()) return std::make_pair(nullptr, nullptr); // Look for an operand that is splatted. unsigned OtherIndex = 0; Value *Splat = getSplatValue(BO->getOperand(1)); if (!Splat && Instruction::isCommutative(BO->getOpcode())) { Splat = getSplatValue(BO->getOperand(0)); OtherIndex = 1; } if (!Splat) return std::make_pair(nullptr, nullptr); Value *Stride; std::tie(Start, Stride) = matchStridedStart(BO->getOperand(OtherIndex), Builder); if (!Start) return std::make_pair(nullptr, nullptr); Builder.SetInsertPoint(BO); Builder.SetCurrentDebugLocation(DebugLoc()); // Add the splat value to the start or multiply the start and stride by the // splat. switch (BO->getOpcode()) { default: llvm_unreachable("Unexpected opcode"); case Instruction::Or: // TODO: We'd be better off creating disjoint or here, but we don't yet // have an IRBuilder API for that. [[fallthrough]]; case Instruction::Add: Start = Builder.CreateAdd(Start, Splat); break; case Instruction::Mul: Start = Builder.CreateMul(Start, Splat); Stride = Builder.CreateMul(Stride, Splat); break; case Instruction::Shl: Start = Builder.CreateShl(Start, Splat); Stride = Builder.CreateShl(Stride, Splat); break; } return std::make_pair(Start, Stride); } // Recursively, walk about the use-def chain until we find a Phi with a strided // start value. Build and update a scalar recurrence as we unwind the recursion. // We also update the Stride as we unwind. Our goal is to move all of the // arithmetic out of the loop. bool RISCVGatherScatterLowering::matchStridedRecurrence(Value *Index, Loop *L, Value *&Stride, PHINode *&BasePtr, BinaryOperator *&Inc, IRBuilderBase &Builder) { // Our base case is a Phi. if (auto *Phi = dyn_cast(Index)) { // A phi node we want to perform this function on should be from the // loop header. if (Phi->getParent() != L->getHeader()) return false; Value *Step, *Start; if (!matchSimpleRecurrence(Phi, Inc, Start, Step) || Inc->getOpcode() != Instruction::Add) return false; assert(Phi->getNumIncomingValues() == 2 && "Expected 2 operand phi."); unsigned IncrementingBlock = Phi->getIncomingValue(0) == Inc ? 0 : 1; assert(Phi->getIncomingValue(IncrementingBlock) == Inc && "Expected one operand of phi to be Inc"); // Only proceed if the step is loop invariant. if (!L->isLoopInvariant(Step)) return false; // Step should be a splat. Step = getSplatValue(Step); if (!Step) return false; std::tie(Start, Stride) = matchStridedStart(Start, Builder); if (!Start) return false; assert(Stride != nullptr); // Build scalar phi and increment. BasePtr = PHINode::Create(Start->getType(), 2, Phi->getName() + ".scalar", Phi); Inc = BinaryOperator::CreateAdd(BasePtr, Step, Inc->getName() + ".scalar", Inc); BasePtr->addIncoming(Start, Phi->getIncomingBlock(1 - IncrementingBlock)); BasePtr->addIncoming(Inc, Phi->getIncomingBlock(IncrementingBlock)); // Note that this Phi might be eligible for removal. MaybeDeadPHIs.push_back(Phi); return true; } // Otherwise look for binary operator. auto *BO = dyn_cast(Index); if (!BO) return false; switch (BO->getOpcode()) { default: return false; case Instruction::Or: // We need to be able to treat Or as Add. if (!cast(BO)->isDisjoint()) return false; break; case Instruction::Add: break; case Instruction::Shl: break; case Instruction::Mul: break; } // We should have one operand in the loop and one splat. Value *OtherOp; if (isa(BO->getOperand(0)) && L->contains(cast(BO->getOperand(0)))) { Index = cast(BO->getOperand(0)); OtherOp = BO->getOperand(1); } else if (isa(BO->getOperand(1)) && L->contains(cast(BO->getOperand(1))) && Instruction::isCommutative(BO->getOpcode())) { Index = cast(BO->getOperand(1)); OtherOp = BO->getOperand(0); } else { return false; } // Make sure other op is loop invariant. if (!L->isLoopInvariant(OtherOp)) return false; // Make sure we have a splat. Value *SplatOp = getSplatValue(OtherOp); if (!SplatOp) return false; // Recurse up the use-def chain. if (!matchStridedRecurrence(Index, L, Stride, BasePtr, Inc, Builder)) return false; // Locate the Step and Start values from the recurrence. unsigned StepIndex = Inc->getOperand(0) == BasePtr ? 1 : 0; unsigned StartBlock = BasePtr->getOperand(0) == Inc ? 1 : 0; Value *Step = Inc->getOperand(StepIndex); Value *Start = BasePtr->getOperand(StartBlock); // We need to adjust the start value in the preheader. Builder.SetInsertPoint( BasePtr->getIncomingBlock(StartBlock)->getTerminator()); Builder.SetCurrentDebugLocation(DebugLoc()); switch (BO->getOpcode()) { default: llvm_unreachable("Unexpected opcode!"); case Instruction::Add: case Instruction::Or: { // An add only affects the start value. It's ok to do this for Or because // we already checked that there are no common set bits. Start = Builder.CreateAdd(Start, SplatOp, "start"); break; } case Instruction::Mul: { Start = Builder.CreateMul(Start, SplatOp, "start"); Step = Builder.CreateMul(Step, SplatOp, "step"); Stride = Builder.CreateMul(Stride, SplatOp, "stride"); break; } case Instruction::Shl: { Start = Builder.CreateShl(Start, SplatOp, "start"); Step = Builder.CreateShl(Step, SplatOp, "step"); Stride = Builder.CreateShl(Stride, SplatOp, "stride"); break; } } Inc->setOperand(StepIndex, Step); BasePtr->setIncomingValue(StartBlock, Start); return true; } std::pair RISCVGatherScatterLowering::determineBaseAndStride(Instruction *Ptr, IRBuilderBase &Builder) { // A gather/scatter of a splat is a zero strided load/store. if (auto *BasePtr = getSplatValue(Ptr)) { Type *IntPtrTy = DL->getIntPtrType(BasePtr->getType()); return std::make_pair(BasePtr, ConstantInt::get(IntPtrTy, 0)); } auto *GEP = dyn_cast(Ptr); if (!GEP) return std::make_pair(nullptr, nullptr); auto I = StridedAddrs.find(GEP); if (I != StridedAddrs.end()) return I->second; SmallVector Ops(GEP->operands()); // Base pointer needs to be a scalar. Value *ScalarBase = Ops[0]; if (ScalarBase->getType()->isVectorTy()) { ScalarBase = getSplatValue(ScalarBase); if (!ScalarBase) return std::make_pair(nullptr, nullptr); } std::optional VecOperand; unsigned TypeScale = 0; // Look for a vector operand and scale. gep_type_iterator GTI = gep_type_begin(GEP); for (unsigned i = 1, e = GEP->getNumOperands(); i != e; ++i, ++GTI) { if (!Ops[i]->getType()->isVectorTy()) continue; if (VecOperand) return std::make_pair(nullptr, nullptr); VecOperand = i; TypeSize TS = GTI.getSequentialElementStride(*DL); if (TS.isScalable()) return std::make_pair(nullptr, nullptr); TypeScale = TS.getFixedValue(); } // We need to find a vector index to simplify. if (!VecOperand) return std::make_pair(nullptr, nullptr); // We can't extract the stride if the arithmetic is done at a different size // than the pointer type. Adding the stride later may not wrap correctly. // Technically we could handle wider indices, but I don't expect that in // practice. Handle one special case here - constants. This simplifies // writing test cases. Value *VecIndex = Ops[*VecOperand]; Type *VecIntPtrTy = DL->getIntPtrType(GEP->getType()); if (VecIndex->getType() != VecIntPtrTy) { auto *VecIndexC = dyn_cast(VecIndex); if (!VecIndexC) return std::make_pair(nullptr, nullptr); if (VecIndex->getType()->getScalarSizeInBits() > VecIntPtrTy->getScalarSizeInBits()) VecIndex = ConstantFoldCastInstruction(Instruction::Trunc, VecIndexC, VecIntPtrTy); else VecIndex = ConstantFoldCastInstruction(Instruction::SExt, VecIndexC, VecIntPtrTy); } // Handle the non-recursive case. This is what we see if the vectorizer // decides to use a scalar IV + vid on demand instead of a vector IV. auto [Start, Stride] = matchStridedStart(VecIndex, Builder); if (Start) { assert(Stride); Builder.SetInsertPoint(GEP); // Replace the vector index with the scalar start and build a scalar GEP. Ops[*VecOperand] = Start; Type *SourceTy = GEP->getSourceElementType(); Value *BasePtr = Builder.CreateGEP(SourceTy, ScalarBase, ArrayRef(Ops).drop_front()); // Convert stride to pointer size if needed. Type *IntPtrTy = DL->getIntPtrType(BasePtr->getType()); assert(Stride->getType() == IntPtrTy && "Unexpected type"); // Scale the stride by the size of the indexed type. if (TypeScale != 1) Stride = Builder.CreateMul(Stride, ConstantInt::get(IntPtrTy, TypeScale)); auto P = std::make_pair(BasePtr, Stride); StridedAddrs[GEP] = P; return P; } // Make sure we're in a loop and that has a pre-header and a single latch. Loop *L = LI->getLoopFor(GEP->getParent()); if (!L || !L->getLoopPreheader() || !L->getLoopLatch()) return std::make_pair(nullptr, nullptr); BinaryOperator *Inc; PHINode *BasePhi; if (!matchStridedRecurrence(VecIndex, L, Stride, BasePhi, Inc, Builder)) return std::make_pair(nullptr, nullptr); assert(BasePhi->getNumIncomingValues() == 2 && "Expected 2 operand phi."); unsigned IncrementingBlock = BasePhi->getOperand(0) == Inc ? 0 : 1; assert(BasePhi->getIncomingValue(IncrementingBlock) == Inc && "Expected one operand of phi to be Inc"); Builder.SetInsertPoint(GEP); // Replace the vector index with the scalar phi and build a scalar GEP. Ops[*VecOperand] = BasePhi; Type *SourceTy = GEP->getSourceElementType(); Value *BasePtr = Builder.CreateGEP(SourceTy, ScalarBase, ArrayRef(Ops).drop_front()); // Final adjustments to stride should go in the start block. Builder.SetInsertPoint( BasePhi->getIncomingBlock(1 - IncrementingBlock)->getTerminator()); // Convert stride to pointer size if needed. Type *IntPtrTy = DL->getIntPtrType(BasePtr->getType()); assert(Stride->getType() == IntPtrTy && "Unexpected type"); // Scale the stride by the size of the indexed type. if (TypeScale != 1) Stride = Builder.CreateMul(Stride, ConstantInt::get(IntPtrTy, TypeScale)); auto P = std::make_pair(BasePtr, Stride); StridedAddrs[GEP] = P; return P; } bool RISCVGatherScatterLowering::tryCreateStridedLoadStore(IntrinsicInst *II, Type *DataType, Value *Ptr, Value *AlignOp) { // Make sure the operation will be supported by the backend. MaybeAlign MA = cast(AlignOp)->getMaybeAlignValue(); EVT DataTypeVT = TLI->getValueType(*DL, DataType); if (!MA || !TLI->isLegalStridedLoadStore(DataTypeVT, *MA)) return false; // FIXME: Let the backend type legalize by splitting/widening? if (!TLI->isTypeLegal(DataTypeVT)) return false; // Pointer should be an instruction. auto *PtrI = dyn_cast(Ptr); if (!PtrI) return false; LLVMContext &Ctx = PtrI->getContext(); IRBuilder Builder(Ctx, *DL); Builder.SetInsertPoint(PtrI); Value *BasePtr, *Stride; std::tie(BasePtr, Stride) = determineBaseAndStride(PtrI, Builder); if (!BasePtr) return false; assert(Stride != nullptr); Builder.SetInsertPoint(II); CallInst *Call; if (II->getIntrinsicID() == Intrinsic::masked_gather) Call = Builder.CreateIntrinsic( Intrinsic::riscv_masked_strided_load, {DataType, BasePtr->getType(), Stride->getType()}, {II->getArgOperand(3), BasePtr, Stride, II->getArgOperand(2)}); else Call = Builder.CreateIntrinsic( Intrinsic::riscv_masked_strided_store, {DataType, BasePtr->getType(), Stride->getType()}, {II->getArgOperand(0), BasePtr, Stride, II->getArgOperand(3)}); Call->takeName(II); II->replaceAllUsesWith(Call); II->eraseFromParent(); if (PtrI->use_empty()) RecursivelyDeleteTriviallyDeadInstructions(PtrI); return true; } bool RISCVGatherScatterLowering::runOnFunction(Function &F) { if (skipFunction(F)) return false; auto &TPC = getAnalysis(); auto &TM = TPC.getTM(); ST = &TM.getSubtarget(F); if (!ST->hasVInstructions() || !ST->useRVVForFixedLengthVectors()) return false; TLI = ST->getTargetLowering(); DL = &F.getParent()->getDataLayout(); LI = &getAnalysis().getLoopInfo(); StridedAddrs.clear(); SmallVector Gathers; SmallVector Scatters; bool Changed = false; for (BasicBlock &BB : F) { for (Instruction &I : BB) { IntrinsicInst *II = dyn_cast(&I); if (II && II->getIntrinsicID() == Intrinsic::masked_gather) { Gathers.push_back(II); } else if (II && II->getIntrinsicID() == Intrinsic::masked_scatter) { Scatters.push_back(II); } } } // Rewrite gather/scatter to form strided load/store if possible. for (auto *II : Gathers) Changed |= tryCreateStridedLoadStore( II, II->getType(), II->getArgOperand(0), II->getArgOperand(1)); for (auto *II : Scatters) Changed |= tryCreateStridedLoadStore(II, II->getArgOperand(0)->getType(), II->getArgOperand(1), II->getArgOperand(2)); // Remove any dead phis. while (!MaybeDeadPHIs.empty()) { if (auto *Phi = dyn_cast_or_null(MaybeDeadPHIs.pop_back_val())) RecursivelyDeleteDeadPHINode(Phi); } return Changed; }