1 //===- MVEGatherScatterLowering.cpp - Gather/Scatter lowering -------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 /// This pass custom lowers llvm.gather and llvm.scatter instructions to 10 /// arm.mve.gather and arm.mve.scatter intrinsics, optimising the code to 11 /// produce a better final result as we go. 12 // 13 //===----------------------------------------------------------------------===// 14 15 #include "ARM.h" 16 #include "ARMBaseInstrInfo.h" 17 #include "ARMSubtarget.h" 18 #include "llvm/Analysis/TargetTransformInfo.h" 19 #include "llvm/CodeGen/TargetLowering.h" 20 #include "llvm/CodeGen/TargetPassConfig.h" 21 #include "llvm/CodeGen/TargetSubtargetInfo.h" 22 #include "llvm/InitializePasses.h" 23 #include "llvm/IR/BasicBlock.h" 24 #include "llvm/IR/Constant.h" 25 #include "llvm/IR/Constants.h" 26 #include "llvm/IR/DerivedTypes.h" 27 #include "llvm/IR/Function.h" 28 #include "llvm/IR/InstrTypes.h" 29 #include "llvm/IR/Instruction.h" 30 #include "llvm/IR/Instructions.h" 31 #include "llvm/IR/IntrinsicInst.h" 32 #include "llvm/IR/Intrinsics.h" 33 #include "llvm/IR/IntrinsicsARM.h" 34 #include "llvm/IR/IRBuilder.h" 35 #include "llvm/IR/PatternMatch.h" 36 #include "llvm/IR/Type.h" 37 #include "llvm/IR/Value.h" 38 #include "llvm/Pass.h" 39 #include "llvm/Support/Casting.h" 40 #include <algorithm> 41 #include <cassert> 42 43 using namespace llvm; 44 45 #define DEBUG_TYPE "mve-gather-scatter-lowering" 46 47 cl::opt<bool> EnableMaskedGatherScatters( 48 "enable-arm-maskedgatscat", cl::Hidden, cl::init(false), 49 cl::desc("Enable the generation of masked gathers and scatters")); 50 51 namespace { 52 53 class MVEGatherScatterLowering : public FunctionPass { 54 public: 55 static char ID; // Pass identification, replacement for typeid 56 57 explicit MVEGatherScatterLowering() : FunctionPass(ID) { 58 initializeMVEGatherScatterLoweringPass(*PassRegistry::getPassRegistry()); 59 } 60 61 bool runOnFunction(Function &F) override; 62 63 StringRef getPassName() const override { 64 return "MVE gather/scatter lowering"; 65 } 66 67 void getAnalysisUsage(AnalysisUsage &AU) const override { 68 AU.setPreservesCFG(); 69 AU.addRequired<TargetPassConfig>(); 70 FunctionPass::getAnalysisUsage(AU); 71 } 72 73 private: 74 // Check this is a valid gather with correct alignment 75 bool isLegalTypeAndAlignment(unsigned NumElements, unsigned ElemSize, 76 unsigned Alignment); 77 // Check whether Ptr is hidden behind a bitcast and look through it 78 void lookThroughBitcast(Value *&Ptr); 79 // Check for a getelementptr and deduce base and offsets from it, on success 80 // returning the base directly and the offsets indirectly using the Offsets 81 // argument 82 Value *checkGEP(Value *&Offsets, Type *Ty, Value *Ptr, IRBuilder<> Builder); 83 84 bool lowerGather(IntrinsicInst *I); 85 // Create a gather from a base + vector of offsets 86 Value *tryCreateMaskedGatherOffset(IntrinsicInst *I, Value *Ptr, 87 IRBuilder<> Builder); 88 // Create a gather from a vector of pointers 89 Value *tryCreateMaskedGatherBase(IntrinsicInst *I, Value *Ptr, 90 IRBuilder<> Builder); 91 }; 92 93 } // end anonymous namespace 94 95 char MVEGatherScatterLowering::ID = 0; 96 97 INITIALIZE_PASS(MVEGatherScatterLowering, DEBUG_TYPE, 98 "MVE gather/scattering lowering pass", false, false) 99 100 Pass *llvm::createMVEGatherScatterLoweringPass() { 101 return new MVEGatherScatterLowering(); 102 } 103 104 bool MVEGatherScatterLowering::isLegalTypeAndAlignment(unsigned NumElements, 105 unsigned ElemSize, 106 unsigned Alignment) { 107 // Do only allow non-extending gathers for now 108 if (((NumElements == 4 && ElemSize == 32) || 109 (NumElements == 8 && ElemSize == 16) || 110 (NumElements == 16 && ElemSize == 8)) && 111 ElemSize / 8 <= Alignment) 112 return true; 113 LLVM_DEBUG(dbgs() << "masked gathers: instruction does not have valid " 114 << "alignment or vector type \n"); 115 return false; 116 } 117 118 Value *MVEGatherScatterLowering::checkGEP(Value *&Offsets, Type *Ty, Value *Ptr, 119 IRBuilder<> Builder) { 120 GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr); 121 if (!GEP) { 122 LLVM_DEBUG(dbgs() << "masked gathers: no getelementpointer found\n"); 123 return nullptr; 124 } 125 LLVM_DEBUG(dbgs() << "masked gathers: getelementpointer found. Loading" 126 << " from base + vector of offsets\n"); 127 Value *GEPPtr = GEP->getPointerOperand(); 128 if (GEPPtr->getType()->isVectorTy()) { 129 LLVM_DEBUG(dbgs() << "masked gathers: gather from a vector of pointers" 130 << " hidden behind a getelementptr currently not" 131 << " supported. Expanding.\n"); 132 return nullptr; 133 } 134 if (GEP->getNumOperands() != 2) { 135 LLVM_DEBUG(dbgs() << "masked gathers: getelementptr with too many" 136 << " operands. Expanding.\n"); 137 return nullptr; 138 } 139 Offsets = GEP->getOperand(1); 140 // SExt offsets inside masked gathers are not permitted by the architecture; 141 // we therefore can't fold them 142 if (ZExtInst *ZextOffs = dyn_cast<ZExtInst>(Offsets)) 143 Offsets = ZextOffs->getOperand(0); 144 Type *OffsType = VectorType::getInteger(cast<VectorType>(Ty)); 145 // If the offset we found does not have the type the intrinsic expects, 146 // i.e., the same type as the gather itself, we need to convert it (only i 147 // types) or fall back to expanding the gather 148 if (OffsType != Offsets->getType()) { 149 if (OffsType->getScalarSizeInBits() > 150 Offsets->getType()->getScalarSizeInBits()) { 151 LLVM_DEBUG(dbgs() << "masked gathers: extending offsets\n"); 152 Offsets = Builder.CreateZExt(Offsets, OffsType, ""); 153 } else { 154 LLVM_DEBUG(dbgs() << "masked gathers: no correct offset type. Can't" 155 << " create masked gather\n"); 156 return nullptr; 157 } 158 } 159 // If none of the checks failed, return the gep's base pointer 160 return GEPPtr; 161 } 162 163 void MVEGatherScatterLowering::lookThroughBitcast(Value *&Ptr) { 164 // Look through bitcast instruction if #elements is the same 165 if (auto *BitCast = dyn_cast<BitCastInst>(Ptr)) { 166 Type *BCTy = BitCast->getType(); 167 Type *BCSrcTy = BitCast->getOperand(0)->getType(); 168 if (BCTy->getVectorNumElements() == BCSrcTy->getVectorNumElements()) { 169 LLVM_DEBUG(dbgs() << "masked gathers: looking through bitcast\n"); 170 Ptr = BitCast->getOperand(0); 171 } 172 } 173 } 174 175 bool MVEGatherScatterLowering::lowerGather(IntrinsicInst *I) { 176 using namespace PatternMatch; 177 LLVM_DEBUG(dbgs() << "masked gathers: checking transform preconditions\n"); 178 179 // @llvm.masked.gather.*(Ptrs, alignment, Mask, Src0) 180 // Attempt to turn the masked gather in I into a MVE intrinsic 181 // Potentially optimising the addressing modes as we do so. 182 Type *Ty = I->getType(); 183 Value *Ptr = I->getArgOperand(0); 184 unsigned Alignment = cast<ConstantInt>(I->getArgOperand(1))->getZExtValue(); 185 Value *Mask = I->getArgOperand(2); 186 Value *PassThru = I->getArgOperand(3); 187 188 if (!isLegalTypeAndAlignment(Ty->getVectorNumElements(), 189 Ty->getScalarSizeInBits(), Alignment)) 190 return false; 191 lookThroughBitcast(Ptr); 192 assert(Ptr->getType()->isVectorTy() && "Unexpected pointer type"); 193 194 IRBuilder<> Builder(I->getContext()); 195 Builder.SetInsertPoint(I); 196 Builder.SetCurrentDebugLocation(I->getDebugLoc()); 197 Value *Load = tryCreateMaskedGatherOffset(I, Ptr, Builder); 198 if (!Load) 199 Load = tryCreateMaskedGatherBase(I, Ptr, Builder); 200 if (!Load) 201 return false; 202 203 if (!isa<UndefValue>(PassThru) && !match(PassThru, m_Zero())) { 204 LLVM_DEBUG(dbgs() << "masked gathers: found non-trivial passthru - " 205 << "creating select\n"); 206 Load = Builder.CreateSelect(Mask, Load, PassThru); 207 } 208 209 LLVM_DEBUG(dbgs() << "masked gathers: successfully built masked gather\n"); 210 I->replaceAllUsesWith(Load); 211 I->eraseFromParent(); 212 return true; 213 } 214 215 Value *MVEGatherScatterLowering::tryCreateMaskedGatherBase( 216 IntrinsicInst *I, Value *Ptr, IRBuilder<> Builder) { 217 using namespace PatternMatch; 218 LLVM_DEBUG(dbgs() << "masked gathers: loading from vector of pointers\n"); 219 Type *Ty = I->getType(); 220 if (Ty->getVectorNumElements() != 4) 221 // Can't build an intrinsic for this 222 return nullptr; 223 Value *Mask = I->getArgOperand(2); 224 if (match(Mask, m_One())) 225 return Builder.CreateIntrinsic(Intrinsic::arm_mve_vldr_gather_base, 226 {Ty, Ptr->getType()}, 227 {Ptr, Builder.getInt32(0)}); 228 else 229 return Builder.CreateIntrinsic( 230 Intrinsic::arm_mve_vldr_gather_base_predicated, 231 {Ty, Ptr->getType(), Mask->getType()}, 232 {Ptr, Builder.getInt32(0), Mask}); 233 } 234 235 Value *MVEGatherScatterLowering::tryCreateMaskedGatherOffset( 236 IntrinsicInst *I, Value *Ptr, IRBuilder<> Builder) { 237 using namespace PatternMatch; 238 Type *Ty = I->getType(); 239 Value *Offsets; 240 Value *BasePtr = checkGEP(Offsets, Ty, Ptr, Builder); 241 if (!BasePtr) 242 return nullptr; 243 244 unsigned Scale; 245 int GEPElemSize = 246 BasePtr->getType()->getPointerElementType()->getPrimitiveSizeInBits(); 247 int ResultElemSize = Ty->getScalarSizeInBits(); 248 // This can be a 32bit load scaled by 4, a 16bit load scaled by 2, or a 249 // 8bit, 16bit or 32bit load scaled by 1 250 if (GEPElemSize == 32 && ResultElemSize == 32) { 251 Scale = 2; 252 } else if (GEPElemSize == 16 && ResultElemSize == 16) { 253 Scale = 1; 254 } else if (GEPElemSize == 8) { 255 Scale = 0; 256 } else { 257 LLVM_DEBUG(dbgs() << "masked gathers: incorrect scale for load. Can't" 258 << " create masked gather\n"); 259 return nullptr; 260 } 261 262 Value *Mask = I->getArgOperand(2); 263 if (!match(Mask, m_One())) 264 return Builder.CreateIntrinsic( 265 Intrinsic::arm_mve_vldr_gather_offset_predicated, 266 {Ty, BasePtr->getType(), Offsets->getType(), Mask->getType()}, 267 {BasePtr, Offsets, Builder.getInt32(Ty->getScalarSizeInBits()), 268 Builder.getInt32(Scale), Builder.getInt32(1), Mask}); 269 else 270 return Builder.CreateIntrinsic( 271 Intrinsic::arm_mve_vldr_gather_offset, 272 {Ty, BasePtr->getType(), Offsets->getType()}, 273 {BasePtr, Offsets, Builder.getInt32(Ty->getScalarSizeInBits()), 274 Builder.getInt32(Scale), Builder.getInt32(1)}); 275 } 276 277 bool MVEGatherScatterLowering::runOnFunction(Function &F) { 278 if (!EnableMaskedGatherScatters) 279 return false; 280 auto &TPC = getAnalysis<TargetPassConfig>(); 281 auto &TM = TPC.getTM<TargetMachine>(); 282 auto *ST = &TM.getSubtarget<ARMSubtarget>(F); 283 if (!ST->hasMVEIntegerOps()) 284 return false; 285 SmallVector<IntrinsicInst *, 4> Gathers; 286 for (BasicBlock &BB : F) { 287 for (Instruction &I : BB) { 288 IntrinsicInst *II = dyn_cast<IntrinsicInst>(&I); 289 if (II && II->getIntrinsicID() == Intrinsic::masked_gather) 290 Gathers.push_back(II); 291 } 292 } 293 294 if (Gathers.empty()) 295 return false; 296 297 for (IntrinsicInst *I : Gathers) 298 lowerGather(I); 299 300 return true; 301 } 302