//===- MVEGatherScatterLowering.cpp - Gather/Scatter lowering -------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // /// This pass custom lowers llvm.gather and llvm.scatter instructions to /// arm.mve.gather and arm.mve.scatter intrinsics, optimising the code to /// produce a better final result as we go. // //===----------------------------------------------------------------------===// #include "ARM.h" #include "ARMBaseInstrInfo.h" #include "ARMSubtarget.h" #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/CodeGen/TargetLowering.h" #include "llvm/CodeGen/TargetPassConfig.h" #include "llvm/CodeGen/TargetSubtargetInfo.h" #include "llvm/InitializePasses.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Constant.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Function.h" #include "llvm/IR/InstrTypes.h" #include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Intrinsics.h" #include "llvm/IR/IntrinsicsARM.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/PatternMatch.h" #include "llvm/IR/Type.h" #include "llvm/IR/Value.h" #include "llvm/Pass.h" #include "llvm/Support/Casting.h" #include #include using namespace llvm; #define DEBUG_TYPE "mve-gather-scatter-lowering" cl::opt EnableMaskedGatherScatters( "enable-arm-maskedgatscat", cl::Hidden, cl::init(false), cl::desc("Enable the generation of masked gathers and scatters")); namespace { class MVEGatherScatterLowering : public FunctionPass { public: static char ID; // Pass identification, replacement for typeid explicit MVEGatherScatterLowering() : FunctionPass(ID) { initializeMVEGatherScatterLoweringPass(*PassRegistry::getPassRegistry()); } bool runOnFunction(Function &F) override; StringRef getPassName() const override { return "MVE gather/scatter lowering"; } void getAnalysisUsage(AnalysisUsage &AU) const override { AU.setPreservesCFG(); AU.addRequired(); FunctionPass::getAnalysisUsage(AU); } private: // Check this is a valid gather with correct alignment bool isLegalTypeAndAlignment(unsigned NumElements, unsigned ElemSize, unsigned Alignment); // Check whether Ptr is hidden behind a bitcast and look through it void lookThroughBitcast(Value *&Ptr); // Check for a getelementptr and deduce base and offsets from it, on success // returning the base directly and the offsets indirectly using the Offsets // argument Value *checkGEP(Value *&Offsets, Type *Ty, Value *Ptr, IRBuilder<> Builder); bool lowerGather(IntrinsicInst *I); // Create a gather from a base + vector of offsets Value *tryCreateMaskedGatherOffset(IntrinsicInst *I, Value *Ptr, IRBuilder<> Builder); // Create a gather from a vector of pointers Value *tryCreateMaskedGatherBase(IntrinsicInst *I, Value *Ptr, IRBuilder<> Builder); }; } // end anonymous namespace char MVEGatherScatterLowering::ID = 0; INITIALIZE_PASS(MVEGatherScatterLowering, DEBUG_TYPE, "MVE gather/scattering lowering pass", false, false) Pass *llvm::createMVEGatherScatterLoweringPass() { return new MVEGatherScatterLowering(); } bool MVEGatherScatterLowering::isLegalTypeAndAlignment(unsigned NumElements, unsigned ElemSize, unsigned Alignment) { // Do only allow non-extending gathers for now if (((NumElements == 4 && ElemSize == 32) || (NumElements == 8 && ElemSize == 16) || (NumElements == 16 && ElemSize == 8)) && ElemSize / 8 <= Alignment) return true; LLVM_DEBUG(dbgs() << "masked gathers: instruction does not have valid " << "alignment or vector type \n"); return false; } Value *MVEGatherScatterLowering::checkGEP(Value *&Offsets, Type *Ty, Value *Ptr, IRBuilder<> Builder) { GetElementPtrInst *GEP = dyn_cast(Ptr); if (!GEP) { LLVM_DEBUG(dbgs() << "masked gathers: no getelementpointer found\n"); return nullptr; } LLVM_DEBUG(dbgs() << "masked gathers: getelementpointer found. Loading" << " from base + vector of offsets\n"); Value *GEPPtr = GEP->getPointerOperand(); if (GEPPtr->getType()->isVectorTy()) { LLVM_DEBUG(dbgs() << "masked gathers: gather from a vector of pointers" << " hidden behind a getelementptr currently not" << " supported. Expanding.\n"); return nullptr; } if (GEP->getNumOperands() != 2) { LLVM_DEBUG(dbgs() << "masked gathers: getelementptr with too many" << " operands. Expanding.\n"); return nullptr; } Offsets = GEP->getOperand(1); // SExt offsets inside masked gathers are not permitted by the architecture; // we therefore can't fold them if (ZExtInst *ZextOffs = dyn_cast(Offsets)) Offsets = ZextOffs->getOperand(0); Type *OffsType = VectorType::getInteger(cast(Ty)); // If the offset we found does not have the type the intrinsic expects, // i.e., the same type as the gather itself, we need to convert it (only i // types) or fall back to expanding the gather if (OffsType != Offsets->getType()) { if (OffsType->getScalarSizeInBits() > Offsets->getType()->getScalarSizeInBits()) { LLVM_DEBUG(dbgs() << "masked gathers: extending offsets\n"); Offsets = Builder.CreateZExt(Offsets, OffsType, ""); } else { LLVM_DEBUG(dbgs() << "masked gathers: no correct offset type. Can't" << " create masked gather\n"); return nullptr; } } // If none of the checks failed, return the gep's base pointer return GEPPtr; } void MVEGatherScatterLowering::lookThroughBitcast(Value *&Ptr) { // Look through bitcast instruction if #elements is the same if (auto *BitCast = dyn_cast(Ptr)) { Type *BCTy = BitCast->getType(); Type *BCSrcTy = BitCast->getOperand(0)->getType(); if (BCTy->getVectorNumElements() == BCSrcTy->getVectorNumElements()) { LLVM_DEBUG(dbgs() << "masked gathers: looking through bitcast\n"); Ptr = BitCast->getOperand(0); } } } bool MVEGatherScatterLowering::lowerGather(IntrinsicInst *I) { using namespace PatternMatch; LLVM_DEBUG(dbgs() << "masked gathers: checking transform preconditions\n"); // @llvm.masked.gather.*(Ptrs, alignment, Mask, Src0) // Attempt to turn the masked gather in I into a MVE intrinsic // Potentially optimising the addressing modes as we do so. Type *Ty = I->getType(); Value *Ptr = I->getArgOperand(0); unsigned Alignment = cast(I->getArgOperand(1))->getZExtValue(); Value *Mask = I->getArgOperand(2); Value *PassThru = I->getArgOperand(3); if (!isLegalTypeAndAlignment(Ty->getVectorNumElements(), Ty->getScalarSizeInBits(), Alignment)) return false; lookThroughBitcast(Ptr); assert(Ptr->getType()->isVectorTy() && "Unexpected pointer type"); IRBuilder<> Builder(I->getContext()); Builder.SetInsertPoint(I); Builder.SetCurrentDebugLocation(I->getDebugLoc()); Value *Load = tryCreateMaskedGatherOffset(I, Ptr, Builder); if (!Load) Load = tryCreateMaskedGatherBase(I, Ptr, Builder); if (!Load) return false; if (!isa(PassThru) && !match(PassThru, m_Zero())) { LLVM_DEBUG(dbgs() << "masked gathers: found non-trivial passthru - " << "creating select\n"); Load = Builder.CreateSelect(Mask, Load, PassThru); } LLVM_DEBUG(dbgs() << "masked gathers: successfully built masked gather\n"); I->replaceAllUsesWith(Load); I->eraseFromParent(); return true; } Value *MVEGatherScatterLowering::tryCreateMaskedGatherBase( IntrinsicInst *I, Value *Ptr, IRBuilder<> Builder) { using namespace PatternMatch; LLVM_DEBUG(dbgs() << "masked gathers: loading from vector of pointers\n"); Type *Ty = I->getType(); if (Ty->getVectorNumElements() != 4) // Can't build an intrinsic for this return nullptr; Value *Mask = I->getArgOperand(2); if (match(Mask, m_One())) return Builder.CreateIntrinsic(Intrinsic::arm_mve_vldr_gather_base, {Ty, Ptr->getType()}, {Ptr, Builder.getInt32(0)}); else return Builder.CreateIntrinsic( Intrinsic::arm_mve_vldr_gather_base_predicated, {Ty, Ptr->getType(), Mask->getType()}, {Ptr, Builder.getInt32(0), Mask}); } Value *MVEGatherScatterLowering::tryCreateMaskedGatherOffset( IntrinsicInst *I, Value *Ptr, IRBuilder<> Builder) { using namespace PatternMatch; Type *Ty = I->getType(); Value *Offsets; Value *BasePtr = checkGEP(Offsets, Ty, Ptr, Builder); if (!BasePtr) return nullptr; unsigned Scale; int GEPElemSize = BasePtr->getType()->getPointerElementType()->getPrimitiveSizeInBits(); int ResultElemSize = Ty->getScalarSizeInBits(); // This can be a 32bit load scaled by 4, a 16bit load scaled by 2, or a // 8bit, 16bit or 32bit load scaled by 1 if (GEPElemSize == 32 && ResultElemSize == 32) { Scale = 2; } else if (GEPElemSize == 16 && ResultElemSize == 16) { Scale = 1; } else if (GEPElemSize == 8) { Scale = 0; } else { LLVM_DEBUG(dbgs() << "masked gathers: incorrect scale for load. Can't" << " create masked gather\n"); return nullptr; } Value *Mask = I->getArgOperand(2); if (!match(Mask, m_One())) return Builder.CreateIntrinsic( Intrinsic::arm_mve_vldr_gather_offset_predicated, {Ty, BasePtr->getType(), Offsets->getType(), Mask->getType()}, {BasePtr, Offsets, Builder.getInt32(Ty->getScalarSizeInBits()), Builder.getInt32(Scale), Builder.getInt32(1), Mask}); else return Builder.CreateIntrinsic( Intrinsic::arm_mve_vldr_gather_offset, {Ty, BasePtr->getType(), Offsets->getType()}, {BasePtr, Offsets, Builder.getInt32(Ty->getScalarSizeInBits()), Builder.getInt32(Scale), Builder.getInt32(1)}); } bool MVEGatherScatterLowering::runOnFunction(Function &F) { if (!EnableMaskedGatherScatters) return false; auto &TPC = getAnalysis(); auto &TM = TPC.getTM(); auto *ST = &TM.getSubtarget(F); if (!ST->hasMVEIntegerOps()) return false; SmallVector Gathers; for (BasicBlock &BB : F) { for (Instruction &I : BB) { IntrinsicInst *II = dyn_cast(&I); if (II && II->getIntrinsicID() == Intrinsic::masked_gather) Gathers.push_back(II); } } if (Gathers.empty()) return false; for (IntrinsicInst *I : Gathers) lowerGather(I); return true; }