//===- MVEGatherScatterLowering.cpp - Gather/Scatter lowering -------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // /// This pass custom lowers llvm.gather and llvm.scatter instructions to /// arm.mve.gather and arm.mve.scatter intrinsics, optimising the code to /// produce a better final result as we go. // //===----------------------------------------------------------------------===// #include "ARM.h" #include "ARMBaseInstrInfo.h" #include "ARMSubtarget.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/CodeGen/TargetLowering.h" #include "llvm/CodeGen/TargetPassConfig.h" #include "llvm/CodeGen/TargetSubtargetInfo.h" #include "llvm/InitializePasses.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Constant.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Function.h" #include "llvm/IR/InstrTypes.h" #include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Intrinsics.h" #include "llvm/IR/IntrinsicsARM.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/PatternMatch.h" #include "llvm/IR/Type.h" #include "llvm/IR/Value.h" #include "llvm/Pass.h" #include "llvm/Support/Casting.h" #include "llvm/Transforms/Utils/Local.h" #include #include using namespace llvm; #define DEBUG_TYPE "arm-mve-gather-scatter-lowering" cl::opt EnableMaskedGatherScatters( "enable-arm-maskedgatscat", cl::Hidden, cl::init(true), cl::desc("Enable the generation of masked gathers and scatters")); namespace { class MVEGatherScatterLowering : public FunctionPass { public: static char ID; // Pass identification, replacement for typeid explicit MVEGatherScatterLowering() : FunctionPass(ID) { initializeMVEGatherScatterLoweringPass(*PassRegistry::getPassRegistry()); } bool runOnFunction(Function &F) override; StringRef getPassName() const override { return "MVE gather/scatter lowering"; } void getAnalysisUsage(AnalysisUsage &AU) const override { AU.setPreservesCFG(); AU.addRequired(); AU.addRequired(); FunctionPass::getAnalysisUsage(AU); } private: LoopInfo *LI = nullptr; // Check this is a valid gather with correct alignment bool isLegalTypeAndAlignment(unsigned NumElements, unsigned ElemSize, Align Alignment); // Check whether Ptr is hidden behind a bitcast and look through it void lookThroughBitcast(Value *&Ptr); // Decompose a ptr into Base and Offsets, potentially using a GEP to return a // scalar base and vector offsets, or else fallback to using a base of 0 and // offset of Ptr where possible. Value *decomposePtr(Value *Ptr, Value *&Offsets, int &Scale, FixedVectorType *Ty, Type *MemoryTy, IRBuilder<> &Builder); // Check for a getelementptr and deduce base and offsets from it, on success // returning the base directly and the offsets indirectly using the Offsets // argument Value *decomposeGEP(Value *&Offsets, FixedVectorType *Ty, GetElementPtrInst *GEP, IRBuilder<> &Builder); // Compute the scale of this gather/scatter instruction int computeScale(unsigned GEPElemSize, unsigned MemoryElemSize); // If the value is a constant, or derived from constants via additions // and multilications, return its numeric value Optional getIfConst(const Value *V); // If Inst is an add instruction, check whether one summand is a // constant. If so, scale this constant and return it together with // the other summand. std::pair getVarAndConst(Value *Inst, int TypeScale); Instruction *lowerGather(IntrinsicInst *I); // Create a gather from a base + vector of offsets Instruction *tryCreateMaskedGatherOffset(IntrinsicInst *I, Value *Ptr, Instruction *&Root, IRBuilder<> &Builder); // Create a gather from a vector of pointers Instruction *tryCreateMaskedGatherBase(IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment = 0); // Create an incrementing gather from a vector of pointers Instruction *tryCreateMaskedGatherBaseWB(IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment = 0); Instruction *lowerScatter(IntrinsicInst *I); // Create a scatter to a base + vector of offsets Instruction *tryCreateMaskedScatterOffset(IntrinsicInst *I, Value *Offsets, IRBuilder<> &Builder); // Create a scatter to a vector of pointers Instruction *tryCreateMaskedScatterBase(IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment = 0); // Create an incrementing scatter from a vector of pointers Instruction *tryCreateMaskedScatterBaseWB(IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment = 0); // QI gathers and scatters can increment their offsets on their own if // the increment is a constant value (digit) Instruction *tryCreateIncrementingGatScat(IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder); // QI gathers/scatters can increment their offsets on their own if the // increment is a constant value (digit) - this creates a writeback QI // gather/scatter Instruction *tryCreateIncrementingWBGatScat(IntrinsicInst *I, Value *BasePtr, Value *Ptr, unsigned TypeScale, IRBuilder<> &Builder); // Optimise the base and offsets of the given address bool optimiseAddress(Value *Address, BasicBlock *BB, LoopInfo *LI); // Try to fold consecutive geps together into one Value *foldGEP(GetElementPtrInst *GEP, Value *&Offsets, IRBuilder<> &Builder); // Check whether these offsets could be moved out of the loop they're in bool optimiseOffsets(Value *Offsets, BasicBlock *BB, LoopInfo *LI); // Pushes the given add out of the loop void pushOutAdd(PHINode *&Phi, Value *OffsSecondOperand, unsigned StartIndex); // Pushes the given mul out of the loop void pushOutMul(PHINode *&Phi, Value *IncrementPerRound, Value *OffsSecondOperand, unsigned LoopIncrement, IRBuilder<> &Builder); }; } // end anonymous namespace char MVEGatherScatterLowering::ID = 0; INITIALIZE_PASS(MVEGatherScatterLowering, DEBUG_TYPE, "MVE gather/scattering lowering pass", false, false) Pass *llvm::createMVEGatherScatterLoweringPass() { return new MVEGatherScatterLowering(); } bool MVEGatherScatterLowering::isLegalTypeAndAlignment(unsigned NumElements, unsigned ElemSize, Align Alignment) { if (((NumElements == 4 && (ElemSize == 32 || ElemSize == 16 || ElemSize == 8)) || (NumElements == 8 && (ElemSize == 16 || ElemSize == 8)) || (NumElements == 16 && ElemSize == 8)) && Alignment >= ElemSize / 8) return true; LLVM_DEBUG(dbgs() << "masked gathers/scatters: instruction does not have " << "valid alignment or vector type \n"); return false; } static bool checkOffsetSize(Value *Offsets, unsigned TargetElemCount) { // Offsets that are not of type are sign extended by the // getelementptr instruction, and MVE gathers/scatters treat the offset as // unsigned. Thus, if the element size is smaller than 32, we can only allow // positive offsets - i.e., the offsets are not allowed to be variables we // can't look into. // Additionally, offsets have to either originate from a zext of a // vector with element types smaller or equal the type of the gather we're // looking at, or consist of constants that we can check are small enough // to fit into the gather type. // Thus we check that 0 < value < 2^TargetElemSize. unsigned TargetElemSize = 128 / TargetElemCount; unsigned OffsetElemSize = cast(Offsets->getType()) ->getElementType() ->getScalarSizeInBits(); if (OffsetElemSize != TargetElemSize || OffsetElemSize != 32) { Constant *ConstOff = dyn_cast(Offsets); if (!ConstOff) return false; int64_t TargetElemMaxSize = (1ULL << TargetElemSize); auto CheckValueSize = [TargetElemMaxSize](Value *OffsetElem) { ConstantInt *OConst = dyn_cast(OffsetElem); if (!OConst) return false; int SExtValue = OConst->getSExtValue(); if (SExtValue >= TargetElemMaxSize || SExtValue < 0) return false; return true; }; if (isa(ConstOff->getType())) { for (unsigned i = 0; i < TargetElemCount; i++) { if (!CheckValueSize(ConstOff->getAggregateElement(i))) return false; } } else { if (!CheckValueSize(ConstOff)) return false; } } return true; } Value *MVEGatherScatterLowering::decomposePtr(Value *Ptr, Value *&Offsets, int &Scale, FixedVectorType *Ty, Type *MemoryTy, IRBuilder<> &Builder) { if (auto *GEP = dyn_cast(Ptr)) { if (Value *V = decomposeGEP(Offsets, Ty, GEP, Builder)) { Scale = computeScale(GEP->getSourceElementType()->getPrimitiveSizeInBits(), MemoryTy->getScalarSizeInBits()); return Scale == -1 ? nullptr : V; } } // If we couldn't use the GEP (or it doesn't exist), attempt to use a // BasePtr of 0 with Ptr as the Offsets, so long as there are only 4 // elements. FixedVectorType *PtrTy = cast(Ptr->getType()); if (PtrTy->getNumElements() != 4 || MemoryTy->getScalarSizeInBits() == 32) return nullptr; Value *Zero = ConstantInt::get(Builder.getInt32Ty(), 0); Value *BasePtr = Builder.CreateIntToPtr(Zero, Builder.getInt8PtrTy()); Offsets = Builder.CreatePtrToInt( Ptr, FixedVectorType::get(Builder.getInt32Ty(), 4)); Scale = 0; return BasePtr; } Value *MVEGatherScatterLowering::decomposeGEP(Value *&Offsets, FixedVectorType *Ty, GetElementPtrInst *GEP, IRBuilder<> &Builder) { if (!GEP) { LLVM_DEBUG(dbgs() << "masked gathers/scatters: no getelementpointer " << "found\n"); return nullptr; } LLVM_DEBUG(dbgs() << "masked gathers/scatters: getelementpointer found." << " Looking at intrinsic for base + vector of offsets\n"); Value *GEPPtr = GEP->getPointerOperand(); Offsets = GEP->getOperand(1); if (GEPPtr->getType()->isVectorTy() || !isa(Offsets->getType())) return nullptr; if (GEP->getNumOperands() != 2) { LLVM_DEBUG(dbgs() << "masked gathers/scatters: getelementptr with too many" << " operands. Expanding.\n"); return nullptr; } Offsets = GEP->getOperand(1); unsigned OffsetsElemCount = cast(Offsets->getType())->getNumElements(); // Paranoid check whether the number of parallel lanes is the same assert(Ty->getNumElements() == OffsetsElemCount); ZExtInst *ZextOffs = dyn_cast(Offsets); if (ZextOffs) Offsets = ZextOffs->getOperand(0); FixedVectorType *OffsetType = cast(Offsets->getType()); // If the offsets are already being zext-ed to , that relieves us of // having to make sure that they won't overflow. if (!ZextOffs || cast(ZextOffs->getDestTy()) ->getElementType() ->getScalarSizeInBits() != 32) if (!checkOffsetSize(Offsets, OffsetsElemCount)) return nullptr; // The offset sizes have been checked; if any truncating or zext-ing is // required to fix them, do that now if (Ty != Offsets->getType()) { if ((Ty->getElementType()->getScalarSizeInBits() < OffsetType->getElementType()->getScalarSizeInBits())) { Offsets = Builder.CreateTrunc(Offsets, Ty); } else { Offsets = Builder.CreateZExt(Offsets, VectorType::getInteger(Ty)); } } // If none of the checks failed, return the gep's base pointer LLVM_DEBUG(dbgs() << "masked gathers/scatters: found correct offsets\n"); return GEPPtr; } void MVEGatherScatterLowering::lookThroughBitcast(Value *&Ptr) { // Look through bitcast instruction if #elements is the same if (auto *BitCast = dyn_cast(Ptr)) { auto *BCTy = cast(BitCast->getType()); auto *BCSrcTy = cast(BitCast->getOperand(0)->getType()); if (BCTy->getNumElements() == BCSrcTy->getNumElements()) { LLVM_DEBUG(dbgs() << "masked gathers/scatters: looking through " << "bitcast\n"); Ptr = BitCast->getOperand(0); } } } int MVEGatherScatterLowering::computeScale(unsigned GEPElemSize, unsigned MemoryElemSize) { // This can be a 32bit load/store scaled by 4, a 16bit load/store scaled by 2, // or a 8bit, 16bit or 32bit load/store scaled by 1 if (GEPElemSize == 32 && MemoryElemSize == 32) return 2; else if (GEPElemSize == 16 && MemoryElemSize == 16) return 1; else if (GEPElemSize == 8) return 0; LLVM_DEBUG(dbgs() << "masked gathers/scatters: incorrect scale. Can't " << "create intrinsic\n"); return -1; } Optional MVEGatherScatterLowering::getIfConst(const Value *V) { const Constant *C = dyn_cast(V); if (C != nullptr) return Optional{C->getUniqueInteger().getSExtValue()}; if (!isa(V)) return Optional{}; const Instruction *I = cast(V); if (I->getOpcode() == Instruction::Add || I->getOpcode() == Instruction::Mul) { Optional Op0 = getIfConst(I->getOperand(0)); Optional Op1 = getIfConst(I->getOperand(1)); if (!Op0 || !Op1) return Optional{}; if (I->getOpcode() == Instruction::Add) return Optional{Op0.getValue() + Op1.getValue()}; if (I->getOpcode() == Instruction::Mul) return Optional{Op0.getValue() * Op1.getValue()}; } return Optional{}; } std::pair MVEGatherScatterLowering::getVarAndConst(Value *Inst, int TypeScale) { std::pair ReturnFalse = std::pair(nullptr, 0); // At this point, the instruction we're looking at must be an add or we // bail out Instruction *Add = dyn_cast(Inst); if (Add == nullptr || Add->getOpcode() != Instruction::Add) return ReturnFalse; Value *Summand; Optional Const; // Find out which operand the value that is increased is if ((Const = getIfConst(Add->getOperand(0)))) Summand = Add->getOperand(1); else if ((Const = getIfConst(Add->getOperand(1)))) Summand = Add->getOperand(0); else return ReturnFalse; // Check that the constant is small enough for an incrementing gather int64_t Immediate = Const.getValue() << TypeScale; if (Immediate > 512 || Immediate < -512 || Immediate % 4 != 0) return ReturnFalse; return std::pair(Summand, Immediate); } Instruction *MVEGatherScatterLowering::lowerGather(IntrinsicInst *I) { using namespace PatternMatch; LLVM_DEBUG(dbgs() << "masked gathers: checking transform preconditions\n" << *I << "\n"); // @llvm.masked.gather.*(Ptrs, alignment, Mask, Src0) // Attempt to turn the masked gather in I into a MVE intrinsic // Potentially optimising the addressing modes as we do so. auto *Ty = cast(I->getType()); Value *Ptr = I->getArgOperand(0); Align Alignment = cast(I->getArgOperand(1))->getAlignValue(); Value *Mask = I->getArgOperand(2); Value *PassThru = I->getArgOperand(3); if (!isLegalTypeAndAlignment(Ty->getNumElements(), Ty->getScalarSizeInBits(), Alignment)) return nullptr; lookThroughBitcast(Ptr); assert(Ptr->getType()->isVectorTy() && "Unexpected pointer type"); IRBuilder<> Builder(I->getContext()); Builder.SetInsertPoint(I); Builder.SetCurrentDebugLocation(I->getDebugLoc()); Instruction *Root = I; Instruction *Load = tryCreateIncrementingGatScat(I, Ptr, Builder); if (!Load) Load = tryCreateMaskedGatherOffset(I, Ptr, Root, Builder); if (!Load) Load = tryCreateMaskedGatherBase(I, Ptr, Builder); if (!Load) return nullptr; if (!isa(PassThru) && !match(PassThru, m_Zero())) { LLVM_DEBUG(dbgs() << "masked gathers: found non-trivial passthru - " << "creating select\n"); Load = SelectInst::Create(Mask, Load, PassThru); Builder.Insert(Load); } Root->replaceAllUsesWith(Load); Root->eraseFromParent(); if (Root != I) // If this was an extending gather, we need to get rid of the sext/zext // sext/zext as well as of the gather itself I->eraseFromParent(); LLVM_DEBUG(dbgs() << "masked gathers: successfully built masked gather\n" << *Load << "\n"); return Load; } Instruction *MVEGatherScatterLowering::tryCreateMaskedGatherBase( IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) { using namespace PatternMatch; auto *Ty = cast(I->getType()); LLVM_DEBUG(dbgs() << "masked gathers: loading from vector of pointers\n"); if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32) // Can't build an intrinsic for this return nullptr; Value *Mask = I->getArgOperand(2); if (match(Mask, m_One())) return Builder.CreateIntrinsic(Intrinsic::arm_mve_vldr_gather_base, {Ty, Ptr->getType()}, {Ptr, Builder.getInt32(Increment)}); else return Builder.CreateIntrinsic( Intrinsic::arm_mve_vldr_gather_base_predicated, {Ty, Ptr->getType(), Mask->getType()}, {Ptr, Builder.getInt32(Increment), Mask}); } Instruction *MVEGatherScatterLowering::tryCreateMaskedGatherBaseWB( IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) { using namespace PatternMatch; auto *Ty = cast(I->getType()); LLVM_DEBUG(dbgs() << "masked gathers: loading from vector of pointers with " << "writeback\n"); if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32) // Can't build an intrinsic for this return nullptr; Value *Mask = I->getArgOperand(2); if (match(Mask, m_One())) return Builder.CreateIntrinsic(Intrinsic::arm_mve_vldr_gather_base_wb, {Ty, Ptr->getType()}, {Ptr, Builder.getInt32(Increment)}); else return Builder.CreateIntrinsic( Intrinsic::arm_mve_vldr_gather_base_wb_predicated, {Ty, Ptr->getType(), Mask->getType()}, {Ptr, Builder.getInt32(Increment), Mask}); } Instruction *MVEGatherScatterLowering::tryCreateMaskedGatherOffset( IntrinsicInst *I, Value *Ptr, Instruction *&Root, IRBuilder<> &Builder) { using namespace PatternMatch; Type *MemoryTy = I->getType(); Type *ResultTy = MemoryTy; unsigned Unsigned = 1; // The size of the gather was already checked in isLegalTypeAndAlignment; // if it was not a full vector width an appropriate extend should follow. auto *Extend = Root; bool TruncResult = false; if (MemoryTy->getPrimitiveSizeInBits() < 128) { if (I->hasOneUse()) { // If the gather has a single extend of the correct type, use an extending // gather and replace the ext. In which case the correct root to replace // is not the CallInst itself, but the instruction which extends it. Instruction* User = cast(*I->users().begin()); if (isa(User) && User->getType()->getPrimitiveSizeInBits() == 128) { LLVM_DEBUG(dbgs() << "masked gathers: Incorporating extend: " << *User << "\n"); Extend = User; ResultTy = User->getType(); Unsigned = 0; } else if (isa(User) && User->getType()->getPrimitiveSizeInBits() == 128) { LLVM_DEBUG(dbgs() << "masked gathers: Incorporating extend: " << *ResultTy << "\n"); Extend = User; ResultTy = User->getType(); } } // If an extend hasn't been found and the type is an integer, create an // extending gather and truncate back to the original type. if (ResultTy->getPrimitiveSizeInBits() < 128 && ResultTy->isIntOrIntVectorTy()) { ResultTy = ResultTy->getWithNewBitWidth( 128 / cast(ResultTy)->getNumElements()); TruncResult = true; LLVM_DEBUG(dbgs() << "masked gathers: Small input type, truncing to: " << *ResultTy << "\n"); } // The final size of the gather must be a full vector width if (ResultTy->getPrimitiveSizeInBits() != 128) { LLVM_DEBUG(dbgs() << "masked gathers: Extend needed but not provided " "from the correct type. Expanding\n"); return nullptr; } } Value *Offsets; int Scale; Value *BasePtr = decomposePtr( Ptr, Offsets, Scale, cast(ResultTy), MemoryTy, Builder); if (!BasePtr) return nullptr; Root = Extend; Value *Mask = I->getArgOperand(2); Instruction *Load = nullptr; if (!match(Mask, m_One())) Load = Builder.CreateIntrinsic( Intrinsic::arm_mve_vldr_gather_offset_predicated, {ResultTy, BasePtr->getType(), Offsets->getType(), Mask->getType()}, {BasePtr, Offsets, Builder.getInt32(MemoryTy->getScalarSizeInBits()), Builder.getInt32(Scale), Builder.getInt32(Unsigned), Mask}); else Load = Builder.CreateIntrinsic( Intrinsic::arm_mve_vldr_gather_offset, {ResultTy, BasePtr->getType(), Offsets->getType()}, {BasePtr, Offsets, Builder.getInt32(MemoryTy->getScalarSizeInBits()), Builder.getInt32(Scale), Builder.getInt32(Unsigned)}); if (TruncResult) { Load = TruncInst::Create(Instruction::Trunc, Load, MemoryTy); Builder.Insert(Load); } return Load; } Instruction *MVEGatherScatterLowering::lowerScatter(IntrinsicInst *I) { using namespace PatternMatch; LLVM_DEBUG(dbgs() << "masked scatters: checking transform preconditions\n" << *I << "\n"); // @llvm.masked.scatter.*(data, ptrs, alignment, mask) // Attempt to turn the masked scatter in I into a MVE intrinsic // Potentially optimising the addressing modes as we do so. Value *Input = I->getArgOperand(0); Value *Ptr = I->getArgOperand(1); Align Alignment = cast(I->getArgOperand(2))->getAlignValue(); auto *Ty = cast(Input->getType()); if (!isLegalTypeAndAlignment(Ty->getNumElements(), Ty->getScalarSizeInBits(), Alignment)) return nullptr; lookThroughBitcast(Ptr); assert(Ptr->getType()->isVectorTy() && "Unexpected pointer type"); IRBuilder<> Builder(I->getContext()); Builder.SetInsertPoint(I); Builder.SetCurrentDebugLocation(I->getDebugLoc()); Instruction *Store = tryCreateIncrementingGatScat(I, Ptr, Builder); if (!Store) Store = tryCreateMaskedScatterOffset(I, Ptr, Builder); if (!Store) Store = tryCreateMaskedScatterBase(I, Ptr, Builder); if (!Store) return nullptr; LLVM_DEBUG(dbgs() << "masked scatters: successfully built masked scatter\n" << *Store << "\n"); I->eraseFromParent(); return Store; } Instruction *MVEGatherScatterLowering::tryCreateMaskedScatterBase( IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) { using namespace PatternMatch; Value *Input = I->getArgOperand(0); auto *Ty = cast(Input->getType()); // Only QR variants allow truncating if (!(Ty->getNumElements() == 4 && Ty->getScalarSizeInBits() == 32)) { // Can't build an intrinsic for this return nullptr; } Value *Mask = I->getArgOperand(3); // int_arm_mve_vstr_scatter_base(_predicated) addr, offset, data(, mask) LLVM_DEBUG(dbgs() << "masked scatters: storing to a vector of pointers\n"); if (match(Mask, m_One())) return Builder.CreateIntrinsic(Intrinsic::arm_mve_vstr_scatter_base, {Ptr->getType(), Input->getType()}, {Ptr, Builder.getInt32(Increment), Input}); else return Builder.CreateIntrinsic( Intrinsic::arm_mve_vstr_scatter_base_predicated, {Ptr->getType(), Input->getType(), Mask->getType()}, {Ptr, Builder.getInt32(Increment), Input, Mask}); } Instruction *MVEGatherScatterLowering::tryCreateMaskedScatterBaseWB( IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) { using namespace PatternMatch; Value *Input = I->getArgOperand(0); auto *Ty = cast(Input->getType()); LLVM_DEBUG(dbgs() << "masked scatters: storing to a vector of pointers " << "with writeback\n"); if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32) // Can't build an intrinsic for this return nullptr; Value *Mask = I->getArgOperand(3); if (match(Mask, m_One())) return Builder.CreateIntrinsic(Intrinsic::arm_mve_vstr_scatter_base_wb, {Ptr->getType(), Input->getType()}, {Ptr, Builder.getInt32(Increment), Input}); else return Builder.CreateIntrinsic( Intrinsic::arm_mve_vstr_scatter_base_wb_predicated, {Ptr->getType(), Input->getType(), Mask->getType()}, {Ptr, Builder.getInt32(Increment), Input, Mask}); } Instruction *MVEGatherScatterLowering::tryCreateMaskedScatterOffset( IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder) { using namespace PatternMatch; Value *Input = I->getArgOperand(0); Value *Mask = I->getArgOperand(3); Type *InputTy = Input->getType(); Type *MemoryTy = InputTy; LLVM_DEBUG(dbgs() << "masked scatters: getelementpointer found. Storing" << " to base + vector of offsets\n"); // If the input has been truncated, try to integrate that trunc into the // scatter instruction (we don't care about alignment here) if (TruncInst *Trunc = dyn_cast(Input)) { Value *PreTrunc = Trunc->getOperand(0); Type *PreTruncTy = PreTrunc->getType(); if (PreTruncTy->getPrimitiveSizeInBits() == 128) { Input = PreTrunc; InputTy = PreTruncTy; } } bool ExtendInput = false; if (InputTy->getPrimitiveSizeInBits() < 128 && InputTy->isIntOrIntVectorTy()) { // If we can't find a trunc to incorporate into the instruction, create an // implicit one with a zext, so that we can still create a scatter. We know // that the input type is 4x/8x/16x and of type i8/i16/i32, so any type // smaller than 128 bits will divide evenly into a 128bit vector. InputTy = InputTy->getWithNewBitWidth( 128 / cast(InputTy)->getNumElements()); ExtendInput = true; LLVM_DEBUG(dbgs() << "masked scatters: Small input type, will extend:\n" << *Input << "\n"); } if (InputTy->getPrimitiveSizeInBits() != 128) { LLVM_DEBUG(dbgs() << "masked scatters: cannot create scatters for " "non-standard input types. Expanding.\n"); return nullptr; } Value *Offsets; int Scale; Value *BasePtr = decomposePtr( Ptr, Offsets, Scale, cast(InputTy), MemoryTy, Builder); if (!BasePtr) return nullptr; if (ExtendInput) Input = Builder.CreateZExt(Input, InputTy); if (!match(Mask, m_One())) return Builder.CreateIntrinsic( Intrinsic::arm_mve_vstr_scatter_offset_predicated, {BasePtr->getType(), Offsets->getType(), Input->getType(), Mask->getType()}, {BasePtr, Offsets, Input, Builder.getInt32(MemoryTy->getScalarSizeInBits()), Builder.getInt32(Scale), Mask}); else return Builder.CreateIntrinsic( Intrinsic::arm_mve_vstr_scatter_offset, {BasePtr->getType(), Offsets->getType(), Input->getType()}, {BasePtr, Offsets, Input, Builder.getInt32(MemoryTy->getScalarSizeInBits()), Builder.getInt32(Scale)}); } Instruction *MVEGatherScatterLowering::tryCreateIncrementingGatScat( IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder) { FixedVectorType *Ty; if (I->getIntrinsicID() == Intrinsic::masked_gather) Ty = cast(I->getType()); else Ty = cast(I->getArgOperand(0)->getType()); // Incrementing gathers only exist for v4i32 if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32) return nullptr; // Incrementing gathers are not beneficial outside of a loop Loop *L = LI->getLoopFor(I->getParent()); if (L == nullptr) return nullptr; // Decompose the GEP into Base and Offsets GetElementPtrInst *GEP = dyn_cast(Ptr); Value *Offsets; Value *BasePtr = decomposeGEP(Offsets, Ty, GEP, Builder); if (!BasePtr) return nullptr; LLVM_DEBUG(dbgs() << "masked gathers/scatters: trying to build incrementing " "wb gather/scatter\n"); // The gep was in charge of making sure the offsets are scaled correctly // - calculate that factor so it can be applied by hand DataLayout DT = I->getParent()->getParent()->getParent()->getDataLayout(); int TypeScale = computeScale(DT.getTypeSizeInBits(GEP->getOperand(0)->getType()), DT.getTypeSizeInBits(GEP->getType()) / cast(GEP->getType())->getNumElements()); if (TypeScale == -1) return nullptr; if (GEP->hasOneUse()) { // Only in this case do we want to build a wb gather, because the wb will // change the phi which does affect other users of the gep (which will still // be using the phi in the old way) if (auto *Load = tryCreateIncrementingWBGatScat(I, BasePtr, Offsets, TypeScale, Builder)) return Load; } LLVM_DEBUG(dbgs() << "masked gathers/scatters: trying to build incrementing " "non-wb gather/scatter\n"); std::pair Add = getVarAndConst(Offsets, TypeScale); if (Add.first == nullptr) return nullptr; Value *OffsetsIncoming = Add.first; int64_t Immediate = Add.second; // Make sure the offsets are scaled correctly Instruction *ScaledOffsets = BinaryOperator::Create( Instruction::Shl, OffsetsIncoming, Builder.CreateVectorSplat(Ty->getNumElements(), Builder.getInt32(TypeScale)), "ScaledIndex", I); // Add the base to the offsets OffsetsIncoming = BinaryOperator::Create( Instruction::Add, ScaledOffsets, Builder.CreateVectorSplat( Ty->getNumElements(), Builder.CreatePtrToInt( BasePtr, cast(ScaledOffsets->getType())->getElementType())), "StartIndex", I); if (I->getIntrinsicID() == Intrinsic::masked_gather) return tryCreateMaskedGatherBase(I, OffsetsIncoming, Builder, Immediate); else return tryCreateMaskedScatterBase(I, OffsetsIncoming, Builder, Immediate); } Instruction *MVEGatherScatterLowering::tryCreateIncrementingWBGatScat( IntrinsicInst *I, Value *BasePtr, Value *Offsets, unsigned TypeScale, IRBuilder<> &Builder) { // Check whether this gather's offset is incremented by a constant - if so, // and the load is of the right type, we can merge this into a QI gather Loop *L = LI->getLoopFor(I->getParent()); // Offsets that are worth merging into this instruction will be incremented // by a constant, thus we're looking for an add of a phi and a constant PHINode *Phi = dyn_cast(Offsets); if (Phi == nullptr || Phi->getNumIncomingValues() != 2 || Phi->getParent() != L->getHeader() || Phi->getNumUses() != 2) // No phi means no IV to write back to; if there is a phi, we expect it // to have exactly two incoming values; the only phis we are interested in // will be loop IV's and have exactly two uses, one in their increment and // one in the gather's gep return nullptr; unsigned IncrementIndex = Phi->getIncomingBlock(0) == L->getLoopLatch() ? 0 : 1; // Look through the phi to the phi increment Offsets = Phi->getIncomingValue(IncrementIndex); std::pair Add = getVarAndConst(Offsets, TypeScale); if (Add.first == nullptr) return nullptr; Value *OffsetsIncoming = Add.first; int64_t Immediate = Add.second; if (OffsetsIncoming != Phi) // Then the increment we are looking at is not an increment of the // induction variable, and we don't want to do a writeback return nullptr; Builder.SetInsertPoint(&Phi->getIncomingBlock(1 - IncrementIndex)->back()); unsigned NumElems = cast(OffsetsIncoming->getType())->getNumElements(); // Make sure the offsets are scaled correctly Instruction *ScaledOffsets = BinaryOperator::Create( Instruction::Shl, Phi->getIncomingValue(1 - IncrementIndex), Builder.CreateVectorSplat(NumElems, Builder.getInt32(TypeScale)), "ScaledIndex", &Phi->getIncomingBlock(1 - IncrementIndex)->back()); // Add the base to the offsets OffsetsIncoming = BinaryOperator::Create( Instruction::Add, ScaledOffsets, Builder.CreateVectorSplat( NumElems, Builder.CreatePtrToInt( BasePtr, cast(ScaledOffsets->getType())->getElementType())), "StartIndex", &Phi->getIncomingBlock(1 - IncrementIndex)->back()); // The gather is pre-incrementing OffsetsIncoming = BinaryOperator::Create( Instruction::Sub, OffsetsIncoming, Builder.CreateVectorSplat(NumElems, Builder.getInt32(Immediate)), "PreIncrementStartIndex", &Phi->getIncomingBlock(1 - IncrementIndex)->back()); Phi->setIncomingValue(1 - IncrementIndex, OffsetsIncoming); Builder.SetInsertPoint(I); Instruction *EndResult; Instruction *NewInduction; if (I->getIntrinsicID() == Intrinsic::masked_gather) { // Build the incrementing gather Value *Load = tryCreateMaskedGatherBaseWB(I, Phi, Builder, Immediate); // One value to be handed to whoever uses the gather, one is the loop // increment EndResult = ExtractValueInst::Create(Load, 0, "Gather"); NewInduction = ExtractValueInst::Create(Load, 1, "GatherIncrement"); Builder.Insert(EndResult); Builder.Insert(NewInduction); } else { // Build the incrementing scatter EndResult = NewInduction = tryCreateMaskedScatterBaseWB(I, Phi, Builder, Immediate); } Instruction *AddInst = cast(Offsets); AddInst->replaceAllUsesWith(NewInduction); AddInst->eraseFromParent(); Phi->setIncomingValue(IncrementIndex, NewInduction); return EndResult; } void MVEGatherScatterLowering::pushOutAdd(PHINode *&Phi, Value *OffsSecondOperand, unsigned StartIndex) { LLVM_DEBUG(dbgs() << "masked gathers/scatters: optimising add instruction\n"); Instruction *InsertionPoint = &cast(Phi->getIncomingBlock(StartIndex)->back()); // Initialize the phi with a vector that contains a sum of the constants Instruction *NewIndex = BinaryOperator::Create( Instruction::Add, Phi->getIncomingValue(StartIndex), OffsSecondOperand, "PushedOutAdd", InsertionPoint); unsigned IncrementIndex = StartIndex == 0 ? 1 : 0; // Order such that start index comes first (this reduces mov's) Phi->addIncoming(NewIndex, Phi->getIncomingBlock(StartIndex)); Phi->addIncoming(Phi->getIncomingValue(IncrementIndex), Phi->getIncomingBlock(IncrementIndex)); Phi->removeIncomingValue(IncrementIndex); Phi->removeIncomingValue(StartIndex); } void MVEGatherScatterLowering::pushOutMul(PHINode *&Phi, Value *IncrementPerRound, Value *OffsSecondOperand, unsigned LoopIncrement, IRBuilder<> &Builder) { LLVM_DEBUG(dbgs() << "masked gathers/scatters: optimising mul instruction\n"); // Create a new scalar add outside of the loop and transform it to a splat // by which loop variable can be incremented Instruction *InsertionPoint = &cast( Phi->getIncomingBlock(LoopIncrement == 1 ? 0 : 1)->back()); // Create a new index Value *StartIndex = BinaryOperator::Create( Instruction::Mul, Phi->getIncomingValue(LoopIncrement == 1 ? 0 : 1), OffsSecondOperand, "PushedOutMul", InsertionPoint); Instruction *Product = BinaryOperator::Create(Instruction::Mul, IncrementPerRound, OffsSecondOperand, "Product", InsertionPoint); // Increment NewIndex by Product instead of the multiplication Instruction *NewIncrement = BinaryOperator::Create( Instruction::Add, Phi, Product, "IncrementPushedOutMul", cast(Phi->getIncomingBlock(LoopIncrement)->back()) .getPrevNode()); Phi->addIncoming(StartIndex, Phi->getIncomingBlock(LoopIncrement == 1 ? 0 : 1)); Phi->addIncoming(NewIncrement, Phi->getIncomingBlock(LoopIncrement)); Phi->removeIncomingValue((unsigned)0); Phi->removeIncomingValue((unsigned)0); } // Check whether all usages of this instruction are as offsets of // gathers/scatters or simple arithmetics only used by gathers/scatters static bool hasAllGatScatUsers(Instruction *I) { if (I->hasNUses(0)) { return false; } bool Gatscat = true; for (User *U : I->users()) { if (!isa(U)) return false; if (isa(U) || isGatherScatter(dyn_cast(U))) { return Gatscat; } else { unsigned OpCode = cast(U)->getOpcode(); if ((OpCode == Instruction::Add || OpCode == Instruction::Mul) && hasAllGatScatUsers(cast(U))) { continue; } return false; } } return Gatscat; } bool MVEGatherScatterLowering::optimiseOffsets(Value *Offsets, BasicBlock *BB, LoopInfo *LI) { LLVM_DEBUG(dbgs() << "masked gathers/scatters: trying to optimize\n" << *Offsets << "\n"); // Optimise the addresses of gathers/scatters by moving invariant // calculations out of the loop if (!isa(Offsets)) return false; Instruction *Offs = cast(Offsets); if (Offs->getOpcode() != Instruction::Add && Offs->getOpcode() != Instruction::Mul) return false; Loop *L = LI->getLoopFor(BB); if (L == nullptr) return false; if (!Offs->hasOneUse()) { if (!hasAllGatScatUsers(Offs)) return false; } // Find out which, if any, operand of the instruction // is a phi node PHINode *Phi; int OffsSecondOp; if (isa(Offs->getOperand(0))) { Phi = cast(Offs->getOperand(0)); OffsSecondOp = 1; } else if (isa(Offs->getOperand(1))) { Phi = cast(Offs->getOperand(1)); OffsSecondOp = 0; } else { bool Changed = false; if (isa(Offs->getOperand(0)) && L->contains(cast(Offs->getOperand(0)))) Changed |= optimiseOffsets(Offs->getOperand(0), BB, LI); if (isa(Offs->getOperand(1)) && L->contains(cast(Offs->getOperand(1)))) Changed |= optimiseOffsets(Offs->getOperand(1), BB, LI); if (!Changed) return false; if (isa(Offs->getOperand(0))) { Phi = cast(Offs->getOperand(0)); OffsSecondOp = 1; } else if (isa(Offs->getOperand(1))) { Phi = cast(Offs->getOperand(1)); OffsSecondOp = 0; } else { return false; } } // A phi node we want to perform this function on should be from the // loop header. if (Phi->getParent() != L->getHeader()) return false; // We're looking for a simple add recurrence. BinaryOperator *IncInstruction; Value *Start, *IncrementPerRound; if (!matchSimpleRecurrence(Phi, IncInstruction, Start, IncrementPerRound) || IncInstruction->getOpcode() != Instruction::Add) return false; int IncrementingBlock = Phi->getIncomingValue(0) == IncInstruction ? 0 : 1; // Get the value that is added to/multiplied with the phi Value *OffsSecondOperand = Offs->getOperand(OffsSecondOp); if (IncrementPerRound->getType() != OffsSecondOperand->getType() || !L->isLoopInvariant(OffsSecondOperand)) // Something has gone wrong, abort return false; // Only proceed if the increment per round is a constant or an instruction // which does not originate from within the loop if (!isa(IncrementPerRound) && !(isa(IncrementPerRound) && !L->contains(cast(IncrementPerRound)))) return false; // If the phi is not used by anything else, we can just adapt it when // replacing the instruction; if it is, we'll have to duplicate it PHINode *NewPhi; if (Phi->getNumUses() == 2) { // No other users -> reuse existing phi (One user is the instruction // we're looking at, the other is the phi increment) if (IncInstruction->getNumUses() != 1) { // If the incrementing instruction does have more users than // our phi, we need to copy it IncInstruction = BinaryOperator::Create( Instruction::BinaryOps(IncInstruction->getOpcode()), Phi, IncrementPerRound, "LoopIncrement", IncInstruction); Phi->setIncomingValue(IncrementingBlock, IncInstruction); } NewPhi = Phi; } else { // There are other users -> create a new phi NewPhi = PHINode::Create(Phi->getType(), 2, "NewPhi", Phi); // Copy the incoming values of the old phi NewPhi->addIncoming(Phi->getIncomingValue(IncrementingBlock == 1 ? 0 : 1), Phi->getIncomingBlock(IncrementingBlock == 1 ? 0 : 1)); IncInstruction = BinaryOperator::Create( Instruction::BinaryOps(IncInstruction->getOpcode()), NewPhi, IncrementPerRound, "LoopIncrement", IncInstruction); NewPhi->addIncoming(IncInstruction, Phi->getIncomingBlock(IncrementingBlock)); IncrementingBlock = 1; } IRBuilder<> Builder(BB->getContext()); Builder.SetInsertPoint(Phi); Builder.SetCurrentDebugLocation(Offs->getDebugLoc()); switch (Offs->getOpcode()) { case Instruction::Add: pushOutAdd(NewPhi, OffsSecondOperand, IncrementingBlock == 1 ? 0 : 1); break; case Instruction::Mul: pushOutMul(NewPhi, IncrementPerRound, OffsSecondOperand, IncrementingBlock, Builder); break; default: return false; } LLVM_DEBUG(dbgs() << "masked gathers/scatters: simplified loop variable " << "add/mul\n"); // The instruction has now been "absorbed" into the phi value Offs->replaceAllUsesWith(NewPhi); if (Offs->hasNUses(0)) Offs->eraseFromParent(); // Clean up the old increment in case it's unused because we built a new // one if (IncInstruction->hasNUses(0)) IncInstruction->eraseFromParent(); return true; } static Value *CheckAndCreateOffsetAdd(Value *X, Value *Y, Value *GEP, IRBuilder<> &Builder) { // Splat the non-vector value to a vector of the given type - if the value is // a constant (and its value isn't too big), we can even use this opportunity // to scale it to the size of the vector elements auto FixSummands = [&Builder](FixedVectorType *&VT, Value *&NonVectorVal) { ConstantInt *Const; if ((Const = dyn_cast(NonVectorVal)) && VT->getElementType() != NonVectorVal->getType()) { unsigned TargetElemSize = VT->getElementType()->getPrimitiveSizeInBits(); uint64_t N = Const->getZExtValue(); if (N < (unsigned)(1 << (TargetElemSize - 1))) { NonVectorVal = Builder.CreateVectorSplat( VT->getNumElements(), Builder.getIntN(TargetElemSize, N)); return; } } NonVectorVal = Builder.CreateVectorSplat(VT->getNumElements(), NonVectorVal); }; FixedVectorType *XElType = dyn_cast(X->getType()); FixedVectorType *YElType = dyn_cast(Y->getType()); // If one of X, Y is not a vector, we have to splat it in order // to add the two of them. if (XElType && !YElType) { FixSummands(XElType, Y); YElType = cast(Y->getType()); } else if (YElType && !XElType) { FixSummands(YElType, X); XElType = cast(X->getType()); } assert(XElType && YElType && "Unknown vector types"); // Check that the summands are of compatible types if (XElType != YElType) { LLVM_DEBUG(dbgs() << "masked gathers/scatters: incompatible gep offsets\n"); return nullptr; } if (XElType->getElementType()->getScalarSizeInBits() != 32) { // Check that by adding the vectors we do not accidentally // create an overflow Constant *ConstX = dyn_cast(X); Constant *ConstY = dyn_cast(Y); if (!ConstX || !ConstY) return nullptr; unsigned TargetElemSize = 128 / XElType->getNumElements(); for (unsigned i = 0; i < XElType->getNumElements(); i++) { ConstantInt *ConstXEl = dyn_cast(ConstX->getAggregateElement(i)); ConstantInt *ConstYEl = dyn_cast(ConstY->getAggregateElement(i)); if (!ConstXEl || !ConstYEl || ConstXEl->getZExtValue() + ConstYEl->getZExtValue() >= (unsigned)(1 << (TargetElemSize - 1))) return nullptr; } } Value *Add = Builder.CreateAdd(X, Y); FixedVectorType *GEPType = cast(GEP->getType()); if (checkOffsetSize(Add, GEPType->getNumElements())) return Add; else return nullptr; } Value *MVEGatherScatterLowering::foldGEP(GetElementPtrInst *GEP, Value *&Offsets, IRBuilder<> &Builder) { Value *GEPPtr = GEP->getPointerOperand(); Offsets = GEP->getOperand(1); // We only merge geps with constant offsets, because only for those // we can make sure that we do not cause an overflow if (!isa(Offsets)) return nullptr; GetElementPtrInst *BaseGEP; if ((BaseGEP = dyn_cast(GEPPtr))) { // Merge the two geps into one Value *BaseBasePtr = foldGEP(BaseGEP, Offsets, Builder); if (!BaseBasePtr) return nullptr; Offsets = CheckAndCreateOffsetAdd(Offsets, GEP->getOperand(1), GEP, Builder); if (Offsets == nullptr) return nullptr; return BaseBasePtr; } return GEPPtr; } bool MVEGatherScatterLowering::optimiseAddress(Value *Address, BasicBlock *BB, LoopInfo *LI) { GetElementPtrInst *GEP = dyn_cast(Address); if (!GEP) return false; bool Changed = false; if (GEP->hasOneUse() && dyn_cast(GEP->getPointerOperand())) { IRBuilder<> Builder(GEP->getContext()); Builder.SetInsertPoint(GEP); Builder.SetCurrentDebugLocation(GEP->getDebugLoc()); Value *Offsets; Value *Base = foldGEP(GEP, Offsets, Builder); // We only want to merge the geps if there is a real chance that they can be // used by an MVE gather; thus the offset has to have the correct size // (always i32 if it is not of vector type) and the base has to be a // pointer. if (Offsets && Base && Base != GEP) { GetElementPtrInst *NewAddress = GetElementPtrInst::Create( GEP->getSourceElementType(), Base, Offsets, "gep.merged", GEP); GEP->replaceAllUsesWith(NewAddress); GEP = NewAddress; Changed = true; } } Changed |= optimiseOffsets(GEP->getOperand(1), GEP->getParent(), LI); return Changed; } bool MVEGatherScatterLowering::runOnFunction(Function &F) { if (!EnableMaskedGatherScatters) return false; auto &TPC = getAnalysis(); auto &TM = TPC.getTM(); auto *ST = &TM.getSubtarget(F); if (!ST->hasMVEIntegerOps()) return false; LI = &getAnalysis().getLoopInfo(); SmallVector Gathers; SmallVector Scatters; bool Changed = false; for (BasicBlock &BB : F) { Changed |= SimplifyInstructionsInBlock(&BB); for (Instruction &I : BB) { IntrinsicInst *II = dyn_cast(&I); if (II && II->getIntrinsicID() == Intrinsic::masked_gather && isa(II->getType())) { Gathers.push_back(II); Changed |= optimiseAddress(II->getArgOperand(0), II->getParent(), LI); } else if (II && II->getIntrinsicID() == Intrinsic::masked_scatter && isa(II->getArgOperand(0)->getType())) { Scatters.push_back(II); Changed |= optimiseAddress(II->getArgOperand(1), II->getParent(), LI); } } } for (unsigned i = 0; i < Gathers.size(); i++) { IntrinsicInst *I = Gathers[i]; Instruction *L = lowerGather(I); if (L == nullptr) continue; // Get rid of any now dead instructions SimplifyInstructionsInBlock(L->getParent()); Changed = true; } for (unsigned i = 0; i < Scatters.size(); i++) { IntrinsicInst *I = Scatters[i]; Instruction *S = lowerScatter(I); if (S == nullptr) continue; // Get rid of any now dead instructions SimplifyInstructionsInBlock(S->getParent()); Changed = true; } return Changed; }