1 //===- MVEGatherScatterLowering.cpp - Gather/Scatter lowering -------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 /// This pass custom lowers llvm.gather and llvm.scatter instructions to 10 /// arm.mve.gather and arm.mve.scatter intrinsics, optimising the code to 11 /// produce a better final result as we go. 12 // 13 //===----------------------------------------------------------------------===// 14 15 #include "ARM.h" 16 #include "ARMBaseInstrInfo.h" 17 #include "ARMSubtarget.h" 18 #include "llvm/Analysis/LoopInfo.h" 19 #include "llvm/Analysis/TargetTransformInfo.h" 20 #include "llvm/Analysis/ValueTracking.h" 21 #include "llvm/CodeGen/TargetLowering.h" 22 #include "llvm/CodeGen/TargetPassConfig.h" 23 #include "llvm/CodeGen/TargetSubtargetInfo.h" 24 #include "llvm/InitializePasses.h" 25 #include "llvm/IR/BasicBlock.h" 26 #include "llvm/IR/Constant.h" 27 #include "llvm/IR/Constants.h" 28 #include "llvm/IR/DerivedTypes.h" 29 #include "llvm/IR/Function.h" 30 #include "llvm/IR/InstrTypes.h" 31 #include "llvm/IR/Instruction.h" 32 #include "llvm/IR/Instructions.h" 33 #include "llvm/IR/IntrinsicInst.h" 34 #include "llvm/IR/Intrinsics.h" 35 #include "llvm/IR/IntrinsicsARM.h" 36 #include "llvm/IR/IRBuilder.h" 37 #include "llvm/IR/PatternMatch.h" 38 #include "llvm/IR/Type.h" 39 #include "llvm/IR/Value.h" 40 #include "llvm/Pass.h" 41 #include "llvm/Support/Casting.h" 42 #include "llvm/Transforms/Utils/Local.h" 43 #include <algorithm> 44 #include <cassert> 45 46 using namespace llvm; 47 48 #define DEBUG_TYPE "arm-mve-gather-scatter-lowering" 49 50 cl::opt<bool> EnableMaskedGatherScatters( 51 "enable-arm-maskedgatscat", cl::Hidden, cl::init(true), 52 cl::desc("Enable the generation of masked gathers and scatters")); 53 54 namespace { 55 56 class MVEGatherScatterLowering : public FunctionPass { 57 public: 58 static char ID; // Pass identification, replacement for typeid 59 60 explicit MVEGatherScatterLowering() : FunctionPass(ID) { 61 initializeMVEGatherScatterLoweringPass(*PassRegistry::getPassRegistry()); 62 } 63 64 bool runOnFunction(Function &F) override; 65 66 StringRef getPassName() const override { 67 return "MVE gather/scatter lowering"; 68 } 69 70 void getAnalysisUsage(AnalysisUsage &AU) const override { 71 AU.setPreservesCFG(); 72 AU.addRequired<TargetPassConfig>(); 73 AU.addRequired<LoopInfoWrapperPass>(); 74 FunctionPass::getAnalysisUsage(AU); 75 } 76 77 private: 78 LoopInfo *LI = nullptr; 79 const DataLayout *DL; 80 81 // Check this is a valid gather with correct alignment 82 bool isLegalTypeAndAlignment(unsigned NumElements, unsigned ElemSize, 83 Align Alignment); 84 // Check whether Ptr is hidden behind a bitcast and look through it 85 void lookThroughBitcast(Value *&Ptr); 86 // Decompose a ptr into Base and Offsets, potentially using a GEP to return a 87 // scalar base and vector offsets, or else fallback to using a base of 0 and 88 // offset of Ptr where possible. 89 Value *decomposePtr(Value *Ptr, Value *&Offsets, int &Scale, 90 FixedVectorType *Ty, Type *MemoryTy, 91 IRBuilder<> &Builder); 92 // Check for a getelementptr and deduce base and offsets from it, on success 93 // returning the base directly and the offsets indirectly using the Offsets 94 // argument 95 Value *decomposeGEP(Value *&Offsets, FixedVectorType *Ty, 96 GetElementPtrInst *GEP, IRBuilder<> &Builder); 97 // Compute the scale of this gather/scatter instruction 98 int computeScale(unsigned GEPElemSize, unsigned MemoryElemSize); 99 // If the value is a constant, or derived from constants via additions 100 // and multilications, return its numeric value 101 std::optional<int64_t> getIfConst(const Value *V); 102 // If Inst is an add instruction, check whether one summand is a 103 // constant. If so, scale this constant and return it together with 104 // the other summand. 105 std::pair<Value *, int64_t> getVarAndConst(Value *Inst, int TypeScale); 106 107 Instruction *lowerGather(IntrinsicInst *I); 108 // Create a gather from a base + vector of offsets 109 Instruction *tryCreateMaskedGatherOffset(IntrinsicInst *I, Value *Ptr, 110 Instruction *&Root, 111 IRBuilder<> &Builder); 112 // Create a gather from a vector of pointers 113 Instruction *tryCreateMaskedGatherBase(IntrinsicInst *I, Value *Ptr, 114 IRBuilder<> &Builder, 115 int64_t Increment = 0); 116 // Create an incrementing gather from a vector of pointers 117 Instruction *tryCreateMaskedGatherBaseWB(IntrinsicInst *I, Value *Ptr, 118 IRBuilder<> &Builder, 119 int64_t Increment = 0); 120 121 Instruction *lowerScatter(IntrinsicInst *I); 122 // Create a scatter to a base + vector of offsets 123 Instruction *tryCreateMaskedScatterOffset(IntrinsicInst *I, Value *Offsets, 124 IRBuilder<> &Builder); 125 // Create a scatter to a vector of pointers 126 Instruction *tryCreateMaskedScatterBase(IntrinsicInst *I, Value *Ptr, 127 IRBuilder<> &Builder, 128 int64_t Increment = 0); 129 // Create an incrementing scatter from a vector of pointers 130 Instruction *tryCreateMaskedScatterBaseWB(IntrinsicInst *I, Value *Ptr, 131 IRBuilder<> &Builder, 132 int64_t Increment = 0); 133 134 // QI gathers and scatters can increment their offsets on their own if 135 // the increment is a constant value (digit) 136 Instruction *tryCreateIncrementingGatScat(IntrinsicInst *I, Value *Ptr, 137 IRBuilder<> &Builder); 138 // QI gathers/scatters can increment their offsets on their own if the 139 // increment is a constant value (digit) - this creates a writeback QI 140 // gather/scatter 141 Instruction *tryCreateIncrementingWBGatScat(IntrinsicInst *I, Value *BasePtr, 142 Value *Ptr, unsigned TypeScale, 143 IRBuilder<> &Builder); 144 145 // Optimise the base and offsets of the given address 146 bool optimiseAddress(Value *Address, BasicBlock *BB, LoopInfo *LI); 147 // Try to fold consecutive geps together into one 148 Value *foldGEP(GetElementPtrInst *GEP, Value *&Offsets, unsigned &Scale, 149 IRBuilder<> &Builder); 150 // Check whether these offsets could be moved out of the loop they're in 151 bool optimiseOffsets(Value *Offsets, BasicBlock *BB, LoopInfo *LI); 152 // Pushes the given add out of the loop 153 void pushOutAdd(PHINode *&Phi, Value *OffsSecondOperand, unsigned StartIndex); 154 // Pushes the given mul or shl out of the loop 155 void pushOutMulShl(unsigned Opc, PHINode *&Phi, Value *IncrementPerRound, 156 Value *OffsSecondOperand, unsigned LoopIncrement, 157 IRBuilder<> &Builder); 158 }; 159 160 } // end anonymous namespace 161 162 char MVEGatherScatterLowering::ID = 0; 163 164 INITIALIZE_PASS(MVEGatherScatterLowering, DEBUG_TYPE, 165 "MVE gather/scattering lowering pass", false, false) 166 167 Pass *llvm::createMVEGatherScatterLoweringPass() { 168 return new MVEGatherScatterLowering(); 169 } 170 171 bool MVEGatherScatterLowering::isLegalTypeAndAlignment(unsigned NumElements, 172 unsigned ElemSize, 173 Align Alignment) { 174 if (((NumElements == 4 && 175 (ElemSize == 32 || ElemSize == 16 || ElemSize == 8)) || 176 (NumElements == 8 && (ElemSize == 16 || ElemSize == 8)) || 177 (NumElements == 16 && ElemSize == 8)) && 178 Alignment >= ElemSize / 8) 179 return true; 180 LLVM_DEBUG(dbgs() << "masked gathers/scatters: instruction does not have " 181 << "valid alignment or vector type \n"); 182 return false; 183 } 184 185 static bool checkOffsetSize(Value *Offsets, unsigned TargetElemCount) { 186 // Offsets that are not of type <N x i32> are sign extended by the 187 // getelementptr instruction, and MVE gathers/scatters treat the offset as 188 // unsigned. Thus, if the element size is smaller than 32, we can only allow 189 // positive offsets - i.e., the offsets are not allowed to be variables we 190 // can't look into. 191 // Additionally, <N x i32> offsets have to either originate from a zext of a 192 // vector with element types smaller or equal the type of the gather we're 193 // looking at, or consist of constants that we can check are small enough 194 // to fit into the gather type. 195 // Thus we check that 0 < value < 2^TargetElemSize. 196 unsigned TargetElemSize = 128 / TargetElemCount; 197 unsigned OffsetElemSize = cast<FixedVectorType>(Offsets->getType()) 198 ->getElementType() 199 ->getScalarSizeInBits(); 200 if (OffsetElemSize != TargetElemSize || OffsetElemSize != 32) { 201 Constant *ConstOff = dyn_cast<Constant>(Offsets); 202 if (!ConstOff) 203 return false; 204 int64_t TargetElemMaxSize = (1ULL << TargetElemSize); 205 auto CheckValueSize = [TargetElemMaxSize](Value *OffsetElem) { 206 ConstantInt *OConst = dyn_cast<ConstantInt>(OffsetElem); 207 if (!OConst) 208 return false; 209 int SExtValue = OConst->getSExtValue(); 210 if (SExtValue >= TargetElemMaxSize || SExtValue < 0) 211 return false; 212 return true; 213 }; 214 if (isa<FixedVectorType>(ConstOff->getType())) { 215 for (unsigned i = 0; i < TargetElemCount; i++) { 216 if (!CheckValueSize(ConstOff->getAggregateElement(i))) 217 return false; 218 } 219 } else { 220 if (!CheckValueSize(ConstOff)) 221 return false; 222 } 223 } 224 return true; 225 } 226 227 Value *MVEGatherScatterLowering::decomposePtr(Value *Ptr, Value *&Offsets, 228 int &Scale, FixedVectorType *Ty, 229 Type *MemoryTy, 230 IRBuilder<> &Builder) { 231 if (auto *GEP = dyn_cast<GetElementPtrInst>(Ptr)) { 232 if (Value *V = decomposeGEP(Offsets, Ty, GEP, Builder)) { 233 Scale = 234 computeScale(GEP->getSourceElementType()->getPrimitiveSizeInBits(), 235 MemoryTy->getScalarSizeInBits()); 236 return Scale == -1 ? nullptr : V; 237 } 238 } 239 240 // If we couldn't use the GEP (or it doesn't exist), attempt to use a 241 // BasePtr of 0 with Ptr as the Offsets, so long as there are only 4 242 // elements. 243 FixedVectorType *PtrTy = cast<FixedVectorType>(Ptr->getType()); 244 if (PtrTy->getNumElements() != 4 || MemoryTy->getScalarSizeInBits() == 32) 245 return nullptr; 246 Value *Zero = ConstantInt::get(Builder.getInt32Ty(), 0); 247 Value *BasePtr = Builder.CreateIntToPtr(Zero, Builder.getPtrTy()); 248 Offsets = Builder.CreatePtrToInt( 249 Ptr, FixedVectorType::get(Builder.getInt32Ty(), 4)); 250 Scale = 0; 251 return BasePtr; 252 } 253 254 Value *MVEGatherScatterLowering::decomposeGEP(Value *&Offsets, 255 FixedVectorType *Ty, 256 GetElementPtrInst *GEP, 257 IRBuilder<> &Builder) { 258 if (!GEP) { 259 LLVM_DEBUG(dbgs() << "masked gathers/scatters: no getelementpointer " 260 << "found\n"); 261 return nullptr; 262 } 263 LLVM_DEBUG(dbgs() << "masked gathers/scatters: getelementpointer found." 264 << " Looking at intrinsic for base + vector of offsets\n"); 265 Value *GEPPtr = GEP->getPointerOperand(); 266 Offsets = GEP->getOperand(1); 267 if (GEPPtr->getType()->isVectorTy() || 268 !isa<FixedVectorType>(Offsets->getType())) 269 return nullptr; 270 271 if (GEP->getNumOperands() != 2) { 272 LLVM_DEBUG(dbgs() << "masked gathers/scatters: getelementptr with too many" 273 << " operands. Expanding.\n"); 274 return nullptr; 275 } 276 Offsets = GEP->getOperand(1); 277 unsigned OffsetsElemCount = 278 cast<FixedVectorType>(Offsets->getType())->getNumElements(); 279 // Paranoid check whether the number of parallel lanes is the same 280 assert(Ty->getNumElements() == OffsetsElemCount); 281 282 ZExtInst *ZextOffs = dyn_cast<ZExtInst>(Offsets); 283 if (ZextOffs) 284 Offsets = ZextOffs->getOperand(0); 285 FixedVectorType *OffsetType = cast<FixedVectorType>(Offsets->getType()); 286 287 // If the offsets are already being zext-ed to <N x i32>, that relieves us of 288 // having to make sure that they won't overflow. 289 if (!ZextOffs || cast<FixedVectorType>(ZextOffs->getDestTy()) 290 ->getElementType() 291 ->getScalarSizeInBits() != 32) 292 if (!checkOffsetSize(Offsets, OffsetsElemCount)) 293 return nullptr; 294 295 // The offset sizes have been checked; if any truncating or zext-ing is 296 // required to fix them, do that now 297 if (Ty != Offsets->getType()) { 298 if ((Ty->getElementType()->getScalarSizeInBits() < 299 OffsetType->getElementType()->getScalarSizeInBits())) { 300 Offsets = Builder.CreateTrunc(Offsets, Ty); 301 } else { 302 Offsets = Builder.CreateZExt(Offsets, VectorType::getInteger(Ty)); 303 } 304 } 305 // If none of the checks failed, return the gep's base pointer 306 LLVM_DEBUG(dbgs() << "masked gathers/scatters: found correct offsets\n"); 307 return GEPPtr; 308 } 309 310 void MVEGatherScatterLowering::lookThroughBitcast(Value *&Ptr) { 311 // Look through bitcast instruction if #elements is the same 312 if (auto *BitCast = dyn_cast<BitCastInst>(Ptr)) { 313 auto *BCTy = cast<FixedVectorType>(BitCast->getType()); 314 auto *BCSrcTy = cast<FixedVectorType>(BitCast->getOperand(0)->getType()); 315 if (BCTy->getNumElements() == BCSrcTy->getNumElements()) { 316 LLVM_DEBUG(dbgs() << "masked gathers/scatters: looking through " 317 << "bitcast\n"); 318 Ptr = BitCast->getOperand(0); 319 } 320 } 321 } 322 323 int MVEGatherScatterLowering::computeScale(unsigned GEPElemSize, 324 unsigned MemoryElemSize) { 325 // This can be a 32bit load/store scaled by 4, a 16bit load/store scaled by 2, 326 // or a 8bit, 16bit or 32bit load/store scaled by 1 327 if (GEPElemSize == 32 && MemoryElemSize == 32) 328 return 2; 329 else if (GEPElemSize == 16 && MemoryElemSize == 16) 330 return 1; 331 else if (GEPElemSize == 8) 332 return 0; 333 LLVM_DEBUG(dbgs() << "masked gathers/scatters: incorrect scale. Can't " 334 << "create intrinsic\n"); 335 return -1; 336 } 337 338 std::optional<int64_t> MVEGatherScatterLowering::getIfConst(const Value *V) { 339 const Constant *C = dyn_cast<Constant>(V); 340 if (C && C->getSplatValue()) 341 return std::optional<int64_t>{C->getUniqueInteger().getSExtValue()}; 342 if (!isa<Instruction>(V)) 343 return std::optional<int64_t>{}; 344 345 const Instruction *I = cast<Instruction>(V); 346 if (I->getOpcode() == Instruction::Add || I->getOpcode() == Instruction::Or || 347 I->getOpcode() == Instruction::Mul || 348 I->getOpcode() == Instruction::Shl) { 349 std::optional<int64_t> Op0 = getIfConst(I->getOperand(0)); 350 std::optional<int64_t> Op1 = getIfConst(I->getOperand(1)); 351 if (!Op0 || !Op1) 352 return std::optional<int64_t>{}; 353 if (I->getOpcode() == Instruction::Add) 354 return std::optional<int64_t>{*Op0 + *Op1}; 355 if (I->getOpcode() == Instruction::Mul) 356 return std::optional<int64_t>{*Op0 * *Op1}; 357 if (I->getOpcode() == Instruction::Shl) 358 return std::optional<int64_t>{*Op0 << *Op1}; 359 if (I->getOpcode() == Instruction::Or) 360 return std::optional<int64_t>{*Op0 | *Op1}; 361 } 362 return std::optional<int64_t>{}; 363 } 364 365 // Return true if I is an Or instruction that is equivalent to an add, due to 366 // the operands having no common bits set. 367 static bool isAddLikeOr(Instruction *I, const DataLayout &DL) { 368 return I->getOpcode() == Instruction::Or && 369 haveNoCommonBitsSet(I->getOperand(0), I->getOperand(1), DL); 370 } 371 372 std::pair<Value *, int64_t> 373 MVEGatherScatterLowering::getVarAndConst(Value *Inst, int TypeScale) { 374 std::pair<Value *, int64_t> ReturnFalse = 375 std::pair<Value *, int64_t>(nullptr, 0); 376 // At this point, the instruction we're looking at must be an add or an 377 // add-like-or. 378 Instruction *Add = dyn_cast<Instruction>(Inst); 379 if (Add == nullptr || 380 (Add->getOpcode() != Instruction::Add && !isAddLikeOr(Add, *DL))) 381 return ReturnFalse; 382 383 Value *Summand; 384 std::optional<int64_t> Const; 385 // Find out which operand the value that is increased is 386 if ((Const = getIfConst(Add->getOperand(0)))) 387 Summand = Add->getOperand(1); 388 else if ((Const = getIfConst(Add->getOperand(1)))) 389 Summand = Add->getOperand(0); 390 else 391 return ReturnFalse; 392 393 // Check that the constant is small enough for an incrementing gather 394 int64_t Immediate = *Const << TypeScale; 395 if (Immediate > 512 || Immediate < -512 || Immediate % 4 != 0) 396 return ReturnFalse; 397 398 return std::pair<Value *, int64_t>(Summand, Immediate); 399 } 400 401 Instruction *MVEGatherScatterLowering::lowerGather(IntrinsicInst *I) { 402 using namespace PatternMatch; 403 LLVM_DEBUG(dbgs() << "masked gathers: checking transform preconditions\n" 404 << *I << "\n"); 405 406 // @llvm.masked.gather.*(Ptrs, alignment, Mask, Src0) 407 // Attempt to turn the masked gather in I into a MVE intrinsic 408 // Potentially optimising the addressing modes as we do so. 409 auto *Ty = cast<FixedVectorType>(I->getType()); 410 Value *Ptr = I->getArgOperand(0); 411 Align Alignment = cast<ConstantInt>(I->getArgOperand(1))->getAlignValue(); 412 Value *Mask = I->getArgOperand(2); 413 Value *PassThru = I->getArgOperand(3); 414 415 if (!isLegalTypeAndAlignment(Ty->getNumElements(), Ty->getScalarSizeInBits(), 416 Alignment)) 417 return nullptr; 418 lookThroughBitcast(Ptr); 419 assert(Ptr->getType()->isVectorTy() && "Unexpected pointer type"); 420 421 IRBuilder<> Builder(I->getContext()); 422 Builder.SetInsertPoint(I); 423 Builder.SetCurrentDebugLocation(I->getDebugLoc()); 424 425 Instruction *Root = I; 426 427 Instruction *Load = tryCreateIncrementingGatScat(I, Ptr, Builder); 428 if (!Load) 429 Load = tryCreateMaskedGatherOffset(I, Ptr, Root, Builder); 430 if (!Load) 431 Load = tryCreateMaskedGatherBase(I, Ptr, Builder); 432 if (!Load) 433 return nullptr; 434 435 if (!isa<UndefValue>(PassThru) && !match(PassThru, m_Zero())) { 436 LLVM_DEBUG(dbgs() << "masked gathers: found non-trivial passthru - " 437 << "creating select\n"); 438 Load = SelectInst::Create(Mask, Load, PassThru); 439 Builder.Insert(Load); 440 } 441 442 Root->replaceAllUsesWith(Load); 443 Root->eraseFromParent(); 444 if (Root != I) 445 // If this was an extending gather, we need to get rid of the sext/zext 446 // sext/zext as well as of the gather itself 447 I->eraseFromParent(); 448 449 LLVM_DEBUG(dbgs() << "masked gathers: successfully built masked gather\n" 450 << *Load << "\n"); 451 return Load; 452 } 453 454 Instruction *MVEGatherScatterLowering::tryCreateMaskedGatherBase( 455 IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) { 456 using namespace PatternMatch; 457 auto *Ty = cast<FixedVectorType>(I->getType()); 458 LLVM_DEBUG(dbgs() << "masked gathers: loading from vector of pointers\n"); 459 if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32) 460 // Can't build an intrinsic for this 461 return nullptr; 462 Value *Mask = I->getArgOperand(2); 463 if (match(Mask, m_One())) 464 return Builder.CreateIntrinsic(Intrinsic::arm_mve_vldr_gather_base, 465 {Ty, Ptr->getType()}, 466 {Ptr, Builder.getInt32(Increment)}); 467 else 468 return Builder.CreateIntrinsic( 469 Intrinsic::arm_mve_vldr_gather_base_predicated, 470 {Ty, Ptr->getType(), Mask->getType()}, 471 {Ptr, Builder.getInt32(Increment), Mask}); 472 } 473 474 Instruction *MVEGatherScatterLowering::tryCreateMaskedGatherBaseWB( 475 IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) { 476 using namespace PatternMatch; 477 auto *Ty = cast<FixedVectorType>(I->getType()); 478 LLVM_DEBUG(dbgs() << "masked gathers: loading from vector of pointers with " 479 << "writeback\n"); 480 if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32) 481 // Can't build an intrinsic for this 482 return nullptr; 483 Value *Mask = I->getArgOperand(2); 484 if (match(Mask, m_One())) 485 return Builder.CreateIntrinsic(Intrinsic::arm_mve_vldr_gather_base_wb, 486 {Ty, Ptr->getType()}, 487 {Ptr, Builder.getInt32(Increment)}); 488 else 489 return Builder.CreateIntrinsic( 490 Intrinsic::arm_mve_vldr_gather_base_wb_predicated, 491 {Ty, Ptr->getType(), Mask->getType()}, 492 {Ptr, Builder.getInt32(Increment), Mask}); 493 } 494 495 Instruction *MVEGatherScatterLowering::tryCreateMaskedGatherOffset( 496 IntrinsicInst *I, Value *Ptr, Instruction *&Root, IRBuilder<> &Builder) { 497 using namespace PatternMatch; 498 499 Type *MemoryTy = I->getType(); 500 Type *ResultTy = MemoryTy; 501 502 unsigned Unsigned = 1; 503 // The size of the gather was already checked in isLegalTypeAndAlignment; 504 // if it was not a full vector width an appropriate extend should follow. 505 auto *Extend = Root; 506 bool TruncResult = false; 507 if (MemoryTy->getPrimitiveSizeInBits() < 128) { 508 if (I->hasOneUse()) { 509 // If the gather has a single extend of the correct type, use an extending 510 // gather and replace the ext. In which case the correct root to replace 511 // is not the CallInst itself, but the instruction which extends it. 512 Instruction* User = cast<Instruction>(*I->users().begin()); 513 if (isa<SExtInst>(User) && 514 User->getType()->getPrimitiveSizeInBits() == 128) { 515 LLVM_DEBUG(dbgs() << "masked gathers: Incorporating extend: " 516 << *User << "\n"); 517 Extend = User; 518 ResultTy = User->getType(); 519 Unsigned = 0; 520 } else if (isa<ZExtInst>(User) && 521 User->getType()->getPrimitiveSizeInBits() == 128) { 522 LLVM_DEBUG(dbgs() << "masked gathers: Incorporating extend: " 523 << *ResultTy << "\n"); 524 Extend = User; 525 ResultTy = User->getType(); 526 } 527 } 528 529 // If an extend hasn't been found and the type is an integer, create an 530 // extending gather and truncate back to the original type. 531 if (ResultTy->getPrimitiveSizeInBits() < 128 && 532 ResultTy->isIntOrIntVectorTy()) { 533 ResultTy = ResultTy->getWithNewBitWidth( 534 128 / cast<FixedVectorType>(ResultTy)->getNumElements()); 535 TruncResult = true; 536 LLVM_DEBUG(dbgs() << "masked gathers: Small input type, truncing to: " 537 << *ResultTy << "\n"); 538 } 539 540 // The final size of the gather must be a full vector width 541 if (ResultTy->getPrimitiveSizeInBits() != 128) { 542 LLVM_DEBUG(dbgs() << "masked gathers: Extend needed but not provided " 543 "from the correct type. Expanding\n"); 544 return nullptr; 545 } 546 } 547 548 Value *Offsets; 549 int Scale; 550 Value *BasePtr = decomposePtr( 551 Ptr, Offsets, Scale, cast<FixedVectorType>(ResultTy), MemoryTy, Builder); 552 if (!BasePtr) 553 return nullptr; 554 555 Root = Extend; 556 Value *Mask = I->getArgOperand(2); 557 Instruction *Load = nullptr; 558 if (!match(Mask, m_One())) 559 Load = Builder.CreateIntrinsic( 560 Intrinsic::arm_mve_vldr_gather_offset_predicated, 561 {ResultTy, BasePtr->getType(), Offsets->getType(), Mask->getType()}, 562 {BasePtr, Offsets, Builder.getInt32(MemoryTy->getScalarSizeInBits()), 563 Builder.getInt32(Scale), Builder.getInt32(Unsigned), Mask}); 564 else 565 Load = Builder.CreateIntrinsic( 566 Intrinsic::arm_mve_vldr_gather_offset, 567 {ResultTy, BasePtr->getType(), Offsets->getType()}, 568 {BasePtr, Offsets, Builder.getInt32(MemoryTy->getScalarSizeInBits()), 569 Builder.getInt32(Scale), Builder.getInt32(Unsigned)}); 570 571 if (TruncResult) { 572 Load = TruncInst::Create(Instruction::Trunc, Load, MemoryTy); 573 Builder.Insert(Load); 574 } 575 return Load; 576 } 577 578 Instruction *MVEGatherScatterLowering::lowerScatter(IntrinsicInst *I) { 579 using namespace PatternMatch; 580 LLVM_DEBUG(dbgs() << "masked scatters: checking transform preconditions\n" 581 << *I << "\n"); 582 583 // @llvm.masked.scatter.*(data, ptrs, alignment, mask) 584 // Attempt to turn the masked scatter in I into a MVE intrinsic 585 // Potentially optimising the addressing modes as we do so. 586 Value *Input = I->getArgOperand(0); 587 Value *Ptr = I->getArgOperand(1); 588 Align Alignment = cast<ConstantInt>(I->getArgOperand(2))->getAlignValue(); 589 auto *Ty = cast<FixedVectorType>(Input->getType()); 590 591 if (!isLegalTypeAndAlignment(Ty->getNumElements(), Ty->getScalarSizeInBits(), 592 Alignment)) 593 return nullptr; 594 595 lookThroughBitcast(Ptr); 596 assert(Ptr->getType()->isVectorTy() && "Unexpected pointer type"); 597 598 IRBuilder<> Builder(I->getContext()); 599 Builder.SetInsertPoint(I); 600 Builder.SetCurrentDebugLocation(I->getDebugLoc()); 601 602 Instruction *Store = tryCreateIncrementingGatScat(I, Ptr, Builder); 603 if (!Store) 604 Store = tryCreateMaskedScatterOffset(I, Ptr, Builder); 605 if (!Store) 606 Store = tryCreateMaskedScatterBase(I, Ptr, Builder); 607 if (!Store) 608 return nullptr; 609 610 LLVM_DEBUG(dbgs() << "masked scatters: successfully built masked scatter\n" 611 << *Store << "\n"); 612 I->eraseFromParent(); 613 return Store; 614 } 615 616 Instruction *MVEGatherScatterLowering::tryCreateMaskedScatterBase( 617 IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) { 618 using namespace PatternMatch; 619 Value *Input = I->getArgOperand(0); 620 auto *Ty = cast<FixedVectorType>(Input->getType()); 621 // Only QR variants allow truncating 622 if (!(Ty->getNumElements() == 4 && Ty->getScalarSizeInBits() == 32)) { 623 // Can't build an intrinsic for this 624 return nullptr; 625 } 626 Value *Mask = I->getArgOperand(3); 627 // int_arm_mve_vstr_scatter_base(_predicated) addr, offset, data(, mask) 628 LLVM_DEBUG(dbgs() << "masked scatters: storing to a vector of pointers\n"); 629 if (match(Mask, m_One())) 630 return Builder.CreateIntrinsic(Intrinsic::arm_mve_vstr_scatter_base, 631 {Ptr->getType(), Input->getType()}, 632 {Ptr, Builder.getInt32(Increment), Input}); 633 else 634 return Builder.CreateIntrinsic( 635 Intrinsic::arm_mve_vstr_scatter_base_predicated, 636 {Ptr->getType(), Input->getType(), Mask->getType()}, 637 {Ptr, Builder.getInt32(Increment), Input, Mask}); 638 } 639 640 Instruction *MVEGatherScatterLowering::tryCreateMaskedScatterBaseWB( 641 IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) { 642 using namespace PatternMatch; 643 Value *Input = I->getArgOperand(0); 644 auto *Ty = cast<FixedVectorType>(Input->getType()); 645 LLVM_DEBUG(dbgs() << "masked scatters: storing to a vector of pointers " 646 << "with writeback\n"); 647 if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32) 648 // Can't build an intrinsic for this 649 return nullptr; 650 Value *Mask = I->getArgOperand(3); 651 if (match(Mask, m_One())) 652 return Builder.CreateIntrinsic(Intrinsic::arm_mve_vstr_scatter_base_wb, 653 {Ptr->getType(), Input->getType()}, 654 {Ptr, Builder.getInt32(Increment), Input}); 655 else 656 return Builder.CreateIntrinsic( 657 Intrinsic::arm_mve_vstr_scatter_base_wb_predicated, 658 {Ptr->getType(), Input->getType(), Mask->getType()}, 659 {Ptr, Builder.getInt32(Increment), Input, Mask}); 660 } 661 662 Instruction *MVEGatherScatterLowering::tryCreateMaskedScatterOffset( 663 IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder) { 664 using namespace PatternMatch; 665 Value *Input = I->getArgOperand(0); 666 Value *Mask = I->getArgOperand(3); 667 Type *InputTy = Input->getType(); 668 Type *MemoryTy = InputTy; 669 670 LLVM_DEBUG(dbgs() << "masked scatters: getelementpointer found. Storing" 671 << " to base + vector of offsets\n"); 672 // If the input has been truncated, try to integrate that trunc into the 673 // scatter instruction (we don't care about alignment here) 674 if (TruncInst *Trunc = dyn_cast<TruncInst>(Input)) { 675 Value *PreTrunc = Trunc->getOperand(0); 676 Type *PreTruncTy = PreTrunc->getType(); 677 if (PreTruncTy->getPrimitiveSizeInBits() == 128) { 678 Input = PreTrunc; 679 InputTy = PreTruncTy; 680 } 681 } 682 bool ExtendInput = false; 683 if (InputTy->getPrimitiveSizeInBits() < 128 && 684 InputTy->isIntOrIntVectorTy()) { 685 // If we can't find a trunc to incorporate into the instruction, create an 686 // implicit one with a zext, so that we can still create a scatter. We know 687 // that the input type is 4x/8x/16x and of type i8/i16/i32, so any type 688 // smaller than 128 bits will divide evenly into a 128bit vector. 689 InputTy = InputTy->getWithNewBitWidth( 690 128 / cast<FixedVectorType>(InputTy)->getNumElements()); 691 ExtendInput = true; 692 LLVM_DEBUG(dbgs() << "masked scatters: Small input type, will extend:\n" 693 << *Input << "\n"); 694 } 695 if (InputTy->getPrimitiveSizeInBits() != 128) { 696 LLVM_DEBUG(dbgs() << "masked scatters: cannot create scatters for " 697 "non-standard input types. Expanding.\n"); 698 return nullptr; 699 } 700 701 Value *Offsets; 702 int Scale; 703 Value *BasePtr = decomposePtr( 704 Ptr, Offsets, Scale, cast<FixedVectorType>(InputTy), MemoryTy, Builder); 705 if (!BasePtr) 706 return nullptr; 707 708 if (ExtendInput) 709 Input = Builder.CreateZExt(Input, InputTy); 710 if (!match(Mask, m_One())) 711 return Builder.CreateIntrinsic( 712 Intrinsic::arm_mve_vstr_scatter_offset_predicated, 713 {BasePtr->getType(), Offsets->getType(), Input->getType(), 714 Mask->getType()}, 715 {BasePtr, Offsets, Input, 716 Builder.getInt32(MemoryTy->getScalarSizeInBits()), 717 Builder.getInt32(Scale), Mask}); 718 else 719 return Builder.CreateIntrinsic( 720 Intrinsic::arm_mve_vstr_scatter_offset, 721 {BasePtr->getType(), Offsets->getType(), Input->getType()}, 722 {BasePtr, Offsets, Input, 723 Builder.getInt32(MemoryTy->getScalarSizeInBits()), 724 Builder.getInt32(Scale)}); 725 } 726 727 Instruction *MVEGatherScatterLowering::tryCreateIncrementingGatScat( 728 IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder) { 729 FixedVectorType *Ty; 730 if (I->getIntrinsicID() == Intrinsic::masked_gather) 731 Ty = cast<FixedVectorType>(I->getType()); 732 else 733 Ty = cast<FixedVectorType>(I->getArgOperand(0)->getType()); 734 735 // Incrementing gathers only exist for v4i32 736 if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32) 737 return nullptr; 738 // Incrementing gathers are not beneficial outside of a loop 739 Loop *L = LI->getLoopFor(I->getParent()); 740 if (L == nullptr) 741 return nullptr; 742 743 // Decompose the GEP into Base and Offsets 744 GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr); 745 Value *Offsets; 746 Value *BasePtr = decomposeGEP(Offsets, Ty, GEP, Builder); 747 if (!BasePtr) 748 return nullptr; 749 750 LLVM_DEBUG(dbgs() << "masked gathers/scatters: trying to build incrementing " 751 "wb gather/scatter\n"); 752 753 // The gep was in charge of making sure the offsets are scaled correctly 754 // - calculate that factor so it can be applied by hand 755 int TypeScale = 756 computeScale(DL->getTypeSizeInBits(GEP->getOperand(0)->getType()), 757 DL->getTypeSizeInBits(GEP->getType()) / 758 cast<FixedVectorType>(GEP->getType())->getNumElements()); 759 if (TypeScale == -1) 760 return nullptr; 761 762 if (GEP->hasOneUse()) { 763 // Only in this case do we want to build a wb gather, because the wb will 764 // change the phi which does affect other users of the gep (which will still 765 // be using the phi in the old way) 766 if (auto *Load = tryCreateIncrementingWBGatScat(I, BasePtr, Offsets, 767 TypeScale, Builder)) 768 return Load; 769 } 770 771 LLVM_DEBUG(dbgs() << "masked gathers/scatters: trying to build incrementing " 772 "non-wb gather/scatter\n"); 773 774 std::pair<Value *, int64_t> Add = getVarAndConst(Offsets, TypeScale); 775 if (Add.first == nullptr) 776 return nullptr; 777 Value *OffsetsIncoming = Add.first; 778 int64_t Immediate = Add.second; 779 780 // Make sure the offsets are scaled correctly 781 Instruction *ScaledOffsets = BinaryOperator::Create( 782 Instruction::Shl, OffsetsIncoming, 783 Builder.CreateVectorSplat(Ty->getNumElements(), 784 Builder.getInt32(TypeScale)), 785 "ScaledIndex", I->getIterator()); 786 // Add the base to the offsets 787 OffsetsIncoming = BinaryOperator::Create( 788 Instruction::Add, ScaledOffsets, 789 Builder.CreateVectorSplat( 790 Ty->getNumElements(), 791 Builder.CreatePtrToInt( 792 BasePtr, 793 cast<VectorType>(ScaledOffsets->getType())->getElementType())), 794 "StartIndex", I->getIterator()); 795 796 if (I->getIntrinsicID() == Intrinsic::masked_gather) 797 return tryCreateMaskedGatherBase(I, OffsetsIncoming, Builder, Immediate); 798 else 799 return tryCreateMaskedScatterBase(I, OffsetsIncoming, Builder, Immediate); 800 } 801 802 Instruction *MVEGatherScatterLowering::tryCreateIncrementingWBGatScat( 803 IntrinsicInst *I, Value *BasePtr, Value *Offsets, unsigned TypeScale, 804 IRBuilder<> &Builder) { 805 // Check whether this gather's offset is incremented by a constant - if so, 806 // and the load is of the right type, we can merge this into a QI gather 807 Loop *L = LI->getLoopFor(I->getParent()); 808 // Offsets that are worth merging into this instruction will be incremented 809 // by a constant, thus we're looking for an add of a phi and a constant 810 PHINode *Phi = dyn_cast<PHINode>(Offsets); 811 if (Phi == nullptr || Phi->getNumIncomingValues() != 2 || 812 Phi->getParent() != L->getHeader() || Phi->getNumUses() != 2) 813 // No phi means no IV to write back to; if there is a phi, we expect it 814 // to have exactly two incoming values; the only phis we are interested in 815 // will be loop IV's and have exactly two uses, one in their increment and 816 // one in the gather's gep 817 return nullptr; 818 819 unsigned IncrementIndex = 820 Phi->getIncomingBlock(0) == L->getLoopLatch() ? 0 : 1; 821 // Look through the phi to the phi increment 822 Offsets = Phi->getIncomingValue(IncrementIndex); 823 824 std::pair<Value *, int64_t> Add = getVarAndConst(Offsets, TypeScale); 825 if (Add.first == nullptr) 826 return nullptr; 827 Value *OffsetsIncoming = Add.first; 828 int64_t Immediate = Add.second; 829 if (OffsetsIncoming != Phi) 830 // Then the increment we are looking at is not an increment of the 831 // induction variable, and we don't want to do a writeback 832 return nullptr; 833 834 Builder.SetInsertPoint(&Phi->getIncomingBlock(1 - IncrementIndex)->back()); 835 unsigned NumElems = 836 cast<FixedVectorType>(OffsetsIncoming->getType())->getNumElements(); 837 838 // Make sure the offsets are scaled correctly 839 Instruction *ScaledOffsets = BinaryOperator::Create( 840 Instruction::Shl, Phi->getIncomingValue(1 - IncrementIndex), 841 Builder.CreateVectorSplat(NumElems, Builder.getInt32(TypeScale)), 842 "ScaledIndex", 843 Phi->getIncomingBlock(1 - IncrementIndex)->back().getIterator()); 844 // Add the base to the offsets 845 OffsetsIncoming = BinaryOperator::Create( 846 Instruction::Add, ScaledOffsets, 847 Builder.CreateVectorSplat( 848 NumElems, 849 Builder.CreatePtrToInt( 850 BasePtr, 851 cast<VectorType>(ScaledOffsets->getType())->getElementType())), 852 "StartIndex", 853 Phi->getIncomingBlock(1 - IncrementIndex)->back().getIterator()); 854 // The gather is pre-incrementing 855 OffsetsIncoming = BinaryOperator::Create( 856 Instruction::Sub, OffsetsIncoming, 857 Builder.CreateVectorSplat(NumElems, Builder.getInt32(Immediate)), 858 "PreIncrementStartIndex", 859 Phi->getIncomingBlock(1 - IncrementIndex)->back().getIterator()); 860 Phi->setIncomingValue(1 - IncrementIndex, OffsetsIncoming); 861 862 Builder.SetInsertPoint(I); 863 864 Instruction *EndResult; 865 Instruction *NewInduction; 866 if (I->getIntrinsicID() == Intrinsic::masked_gather) { 867 // Build the incrementing gather 868 Value *Load = tryCreateMaskedGatherBaseWB(I, Phi, Builder, Immediate); 869 // One value to be handed to whoever uses the gather, one is the loop 870 // increment 871 EndResult = ExtractValueInst::Create(Load, 0, "Gather"); 872 NewInduction = ExtractValueInst::Create(Load, 1, "GatherIncrement"); 873 Builder.Insert(EndResult); 874 Builder.Insert(NewInduction); 875 } else { 876 // Build the incrementing scatter 877 EndResult = NewInduction = 878 tryCreateMaskedScatterBaseWB(I, Phi, Builder, Immediate); 879 } 880 Instruction *AddInst = cast<Instruction>(Offsets); 881 AddInst->replaceAllUsesWith(NewInduction); 882 AddInst->eraseFromParent(); 883 Phi->setIncomingValue(IncrementIndex, NewInduction); 884 885 return EndResult; 886 } 887 888 void MVEGatherScatterLowering::pushOutAdd(PHINode *&Phi, 889 Value *OffsSecondOperand, 890 unsigned StartIndex) { 891 LLVM_DEBUG(dbgs() << "masked gathers/scatters: optimising add instruction\n"); 892 BasicBlock::iterator InsertionPoint = 893 Phi->getIncomingBlock(StartIndex)->back().getIterator(); 894 // Initialize the phi with a vector that contains a sum of the constants 895 Instruction *NewIndex = BinaryOperator::Create( 896 Instruction::Add, Phi->getIncomingValue(StartIndex), OffsSecondOperand, 897 "PushedOutAdd", InsertionPoint); 898 unsigned IncrementIndex = StartIndex == 0 ? 1 : 0; 899 900 // Order such that start index comes first (this reduces mov's) 901 Phi->addIncoming(NewIndex, Phi->getIncomingBlock(StartIndex)); 902 Phi->addIncoming(Phi->getIncomingValue(IncrementIndex), 903 Phi->getIncomingBlock(IncrementIndex)); 904 Phi->removeIncomingValue(1); 905 Phi->removeIncomingValue((unsigned)0); 906 } 907 908 void MVEGatherScatterLowering::pushOutMulShl(unsigned Opcode, PHINode *&Phi, 909 Value *IncrementPerRound, 910 Value *OffsSecondOperand, 911 unsigned LoopIncrement, 912 IRBuilder<> &Builder) { 913 LLVM_DEBUG(dbgs() << "masked gathers/scatters: optimising mul instruction\n"); 914 915 // Create a new scalar add outside of the loop and transform it to a splat 916 // by which loop variable can be incremented 917 BasicBlock::iterator InsertionPoint = 918 Phi->getIncomingBlock(LoopIncrement == 1 ? 0 : 1)->back().getIterator(); 919 920 // Create a new index 921 Value *StartIndex = 922 BinaryOperator::Create((Instruction::BinaryOps)Opcode, 923 Phi->getIncomingValue(LoopIncrement == 1 ? 0 : 1), 924 OffsSecondOperand, "PushedOutMul", InsertionPoint); 925 926 Instruction *Product = 927 BinaryOperator::Create((Instruction::BinaryOps)Opcode, IncrementPerRound, 928 OffsSecondOperand, "Product", InsertionPoint); 929 930 BasicBlock::iterator NewIncrInsertPt = 931 Phi->getIncomingBlock(LoopIncrement)->back().getIterator(); 932 NewIncrInsertPt = std::prev(NewIncrInsertPt); 933 934 // Increment NewIndex by Product instead of the multiplication 935 Instruction *NewIncrement = BinaryOperator::Create( 936 Instruction::Add, Phi, Product, "IncrementPushedOutMul", NewIncrInsertPt); 937 938 Phi->addIncoming(StartIndex, 939 Phi->getIncomingBlock(LoopIncrement == 1 ? 0 : 1)); 940 Phi->addIncoming(NewIncrement, Phi->getIncomingBlock(LoopIncrement)); 941 Phi->removeIncomingValue((unsigned)0); 942 Phi->removeIncomingValue((unsigned)0); 943 } 944 945 // Check whether all usages of this instruction are as offsets of 946 // gathers/scatters or simple arithmetics only used by gathers/scatters 947 static bool hasAllGatScatUsers(Instruction *I, const DataLayout &DL) { 948 if (I->hasNUses(0)) { 949 return false; 950 } 951 bool Gatscat = true; 952 for (User *U : I->users()) { 953 if (!isa<Instruction>(U)) 954 return false; 955 if (isa<GetElementPtrInst>(U) || 956 isGatherScatter(dyn_cast<IntrinsicInst>(U))) { 957 return Gatscat; 958 } else { 959 unsigned OpCode = cast<Instruction>(U)->getOpcode(); 960 if ((OpCode == Instruction::Add || OpCode == Instruction::Mul || 961 OpCode == Instruction::Shl || 962 isAddLikeOr(cast<Instruction>(U), DL)) && 963 hasAllGatScatUsers(cast<Instruction>(U), DL)) { 964 continue; 965 } 966 return false; 967 } 968 } 969 return Gatscat; 970 } 971 972 bool MVEGatherScatterLowering::optimiseOffsets(Value *Offsets, BasicBlock *BB, 973 LoopInfo *LI) { 974 LLVM_DEBUG(dbgs() << "masked gathers/scatters: trying to optimize: " 975 << *Offsets << "\n"); 976 // Optimise the addresses of gathers/scatters by moving invariant 977 // calculations out of the loop 978 if (!isa<Instruction>(Offsets)) 979 return false; 980 Instruction *Offs = cast<Instruction>(Offsets); 981 if (Offs->getOpcode() != Instruction::Add && !isAddLikeOr(Offs, *DL) && 982 Offs->getOpcode() != Instruction::Mul && 983 Offs->getOpcode() != Instruction::Shl) 984 return false; 985 Loop *L = LI->getLoopFor(BB); 986 if (L == nullptr) 987 return false; 988 if (!Offs->hasOneUse()) { 989 if (!hasAllGatScatUsers(Offs, *DL)) 990 return false; 991 } 992 993 // Find out which, if any, operand of the instruction 994 // is a phi node 995 PHINode *Phi; 996 int OffsSecondOp; 997 if (isa<PHINode>(Offs->getOperand(0))) { 998 Phi = cast<PHINode>(Offs->getOperand(0)); 999 OffsSecondOp = 1; 1000 } else if (isa<PHINode>(Offs->getOperand(1))) { 1001 Phi = cast<PHINode>(Offs->getOperand(1)); 1002 OffsSecondOp = 0; 1003 } else { 1004 bool Changed = false; 1005 if (isa<Instruction>(Offs->getOperand(0)) && 1006 L->contains(cast<Instruction>(Offs->getOperand(0)))) 1007 Changed |= optimiseOffsets(Offs->getOperand(0), BB, LI); 1008 if (isa<Instruction>(Offs->getOperand(1)) && 1009 L->contains(cast<Instruction>(Offs->getOperand(1)))) 1010 Changed |= optimiseOffsets(Offs->getOperand(1), BB, LI); 1011 if (!Changed) 1012 return false; 1013 if (isa<PHINode>(Offs->getOperand(0))) { 1014 Phi = cast<PHINode>(Offs->getOperand(0)); 1015 OffsSecondOp = 1; 1016 } else if (isa<PHINode>(Offs->getOperand(1))) { 1017 Phi = cast<PHINode>(Offs->getOperand(1)); 1018 OffsSecondOp = 0; 1019 } else { 1020 return false; 1021 } 1022 } 1023 // A phi node we want to perform this function on should be from the 1024 // loop header. 1025 if (Phi->getParent() != L->getHeader()) 1026 return false; 1027 1028 // We're looking for a simple add recurrence. 1029 BinaryOperator *IncInstruction; 1030 Value *Start, *IncrementPerRound; 1031 if (!matchSimpleRecurrence(Phi, IncInstruction, Start, IncrementPerRound) || 1032 IncInstruction->getOpcode() != Instruction::Add) 1033 return false; 1034 1035 int IncrementingBlock = Phi->getIncomingValue(0) == IncInstruction ? 0 : 1; 1036 1037 // Get the value that is added to/multiplied with the phi 1038 Value *OffsSecondOperand = Offs->getOperand(OffsSecondOp); 1039 1040 if (IncrementPerRound->getType() != OffsSecondOperand->getType() || 1041 !L->isLoopInvariant(OffsSecondOperand)) 1042 // Something has gone wrong, abort 1043 return false; 1044 1045 // Only proceed if the increment per round is a constant or an instruction 1046 // which does not originate from within the loop 1047 if (!isa<Constant>(IncrementPerRound) && 1048 !(isa<Instruction>(IncrementPerRound) && 1049 !L->contains(cast<Instruction>(IncrementPerRound)))) 1050 return false; 1051 1052 // If the phi is not used by anything else, we can just adapt it when 1053 // replacing the instruction; if it is, we'll have to duplicate it 1054 PHINode *NewPhi; 1055 if (Phi->getNumUses() == 2) { 1056 // No other users -> reuse existing phi (One user is the instruction 1057 // we're looking at, the other is the phi increment) 1058 if (IncInstruction->getNumUses() != 1) { 1059 // If the incrementing instruction does have more users than 1060 // our phi, we need to copy it 1061 IncInstruction = BinaryOperator::Create( 1062 Instruction::BinaryOps(IncInstruction->getOpcode()), Phi, 1063 IncrementPerRound, "LoopIncrement", IncInstruction->getIterator()); 1064 Phi->setIncomingValue(IncrementingBlock, IncInstruction); 1065 } 1066 NewPhi = Phi; 1067 } else { 1068 // There are other users -> create a new phi 1069 NewPhi = PHINode::Create(Phi->getType(), 2, "NewPhi", Phi->getIterator()); 1070 // Copy the incoming values of the old phi 1071 NewPhi->addIncoming(Phi->getIncomingValue(IncrementingBlock == 1 ? 0 : 1), 1072 Phi->getIncomingBlock(IncrementingBlock == 1 ? 0 : 1)); 1073 IncInstruction = BinaryOperator::Create( 1074 Instruction::BinaryOps(IncInstruction->getOpcode()), NewPhi, 1075 IncrementPerRound, "LoopIncrement", IncInstruction->getIterator()); 1076 NewPhi->addIncoming(IncInstruction, 1077 Phi->getIncomingBlock(IncrementingBlock)); 1078 IncrementingBlock = 1; 1079 } 1080 1081 IRBuilder<> Builder(BB->getContext()); 1082 Builder.SetInsertPoint(Phi); 1083 Builder.SetCurrentDebugLocation(Offs->getDebugLoc()); 1084 1085 switch (Offs->getOpcode()) { 1086 case Instruction::Add: 1087 case Instruction::Or: 1088 pushOutAdd(NewPhi, OffsSecondOperand, IncrementingBlock == 1 ? 0 : 1); 1089 break; 1090 case Instruction::Mul: 1091 case Instruction::Shl: 1092 pushOutMulShl(Offs->getOpcode(), NewPhi, IncrementPerRound, 1093 OffsSecondOperand, IncrementingBlock, Builder); 1094 break; 1095 default: 1096 return false; 1097 } 1098 LLVM_DEBUG(dbgs() << "masked gathers/scatters: simplified loop variable " 1099 << "add/mul\n"); 1100 1101 // The instruction has now been "absorbed" into the phi value 1102 Offs->replaceAllUsesWith(NewPhi); 1103 if (Offs->hasNUses(0)) 1104 Offs->eraseFromParent(); 1105 // Clean up the old increment in case it's unused because we built a new 1106 // one 1107 if (IncInstruction->hasNUses(0)) 1108 IncInstruction->eraseFromParent(); 1109 1110 return true; 1111 } 1112 1113 static Value *CheckAndCreateOffsetAdd(Value *X, unsigned ScaleX, Value *Y, 1114 unsigned ScaleY, IRBuilder<> &Builder) { 1115 // Splat the non-vector value to a vector of the given type - if the value is 1116 // a constant (and its value isn't too big), we can even use this opportunity 1117 // to scale it to the size of the vector elements 1118 auto FixSummands = [&Builder](FixedVectorType *&VT, Value *&NonVectorVal) { 1119 ConstantInt *Const; 1120 if ((Const = dyn_cast<ConstantInt>(NonVectorVal)) && 1121 VT->getElementType() != NonVectorVal->getType()) { 1122 unsigned TargetElemSize = VT->getElementType()->getPrimitiveSizeInBits(); 1123 uint64_t N = Const->getZExtValue(); 1124 if (N < (unsigned)(1 << (TargetElemSize - 1))) { 1125 NonVectorVal = Builder.CreateVectorSplat( 1126 VT->getNumElements(), Builder.getIntN(TargetElemSize, N)); 1127 return; 1128 } 1129 } 1130 NonVectorVal = 1131 Builder.CreateVectorSplat(VT->getNumElements(), NonVectorVal); 1132 }; 1133 1134 FixedVectorType *XElType = dyn_cast<FixedVectorType>(X->getType()); 1135 FixedVectorType *YElType = dyn_cast<FixedVectorType>(Y->getType()); 1136 // If one of X, Y is not a vector, we have to splat it in order 1137 // to add the two of them. 1138 if (XElType && !YElType) { 1139 FixSummands(XElType, Y); 1140 YElType = cast<FixedVectorType>(Y->getType()); 1141 } else if (YElType && !XElType) { 1142 FixSummands(YElType, X); 1143 XElType = cast<FixedVectorType>(X->getType()); 1144 } 1145 assert(XElType && YElType && "Unknown vector types"); 1146 // Check that the summands are of compatible types 1147 if (XElType != YElType) { 1148 LLVM_DEBUG(dbgs() << "masked gathers/scatters: incompatible gep offsets\n"); 1149 return nullptr; 1150 } 1151 1152 if (XElType->getElementType()->getScalarSizeInBits() != 32) { 1153 // Check that by adding the vectors we do not accidentally 1154 // create an overflow 1155 Constant *ConstX = dyn_cast<Constant>(X); 1156 Constant *ConstY = dyn_cast<Constant>(Y); 1157 if (!ConstX || !ConstY) 1158 return nullptr; 1159 unsigned TargetElemSize = 128 / XElType->getNumElements(); 1160 for (unsigned i = 0; i < XElType->getNumElements(); i++) { 1161 ConstantInt *ConstXEl = 1162 dyn_cast<ConstantInt>(ConstX->getAggregateElement(i)); 1163 ConstantInt *ConstYEl = 1164 dyn_cast<ConstantInt>(ConstY->getAggregateElement(i)); 1165 if (!ConstXEl || !ConstYEl || 1166 ConstXEl->getZExtValue() * ScaleX + 1167 ConstYEl->getZExtValue() * ScaleY >= 1168 (unsigned)(1 << (TargetElemSize - 1))) 1169 return nullptr; 1170 } 1171 } 1172 1173 Value *XScale = Builder.CreateVectorSplat( 1174 XElType->getNumElements(), 1175 Builder.getIntN(XElType->getScalarSizeInBits(), ScaleX)); 1176 Value *YScale = Builder.CreateVectorSplat( 1177 YElType->getNumElements(), 1178 Builder.getIntN(YElType->getScalarSizeInBits(), ScaleY)); 1179 Value *Add = Builder.CreateAdd(Builder.CreateMul(X, XScale), 1180 Builder.CreateMul(Y, YScale)); 1181 1182 if (checkOffsetSize(Add, XElType->getNumElements())) 1183 return Add; 1184 else 1185 return nullptr; 1186 } 1187 1188 Value *MVEGatherScatterLowering::foldGEP(GetElementPtrInst *GEP, 1189 Value *&Offsets, unsigned &Scale, 1190 IRBuilder<> &Builder) { 1191 Value *GEPPtr = GEP->getPointerOperand(); 1192 Offsets = GEP->getOperand(1); 1193 Scale = DL->getTypeAllocSize(GEP->getSourceElementType()); 1194 // We only merge geps with constant offsets, because only for those 1195 // we can make sure that we do not cause an overflow 1196 if (GEP->getNumIndices() != 1 || !isa<Constant>(Offsets)) 1197 return nullptr; 1198 if (GetElementPtrInst *BaseGEP = dyn_cast<GetElementPtrInst>(GEPPtr)) { 1199 // Merge the two geps into one 1200 Value *BaseBasePtr = foldGEP(BaseGEP, Offsets, Scale, Builder); 1201 if (!BaseBasePtr) 1202 return nullptr; 1203 Offsets = CheckAndCreateOffsetAdd( 1204 Offsets, Scale, GEP->getOperand(1), 1205 DL->getTypeAllocSize(GEP->getSourceElementType()), Builder); 1206 if (Offsets == nullptr) 1207 return nullptr; 1208 Scale = 1; // Scale is always an i8 at this point. 1209 return BaseBasePtr; 1210 } 1211 return GEPPtr; 1212 } 1213 1214 bool MVEGatherScatterLowering::optimiseAddress(Value *Address, BasicBlock *BB, 1215 LoopInfo *LI) { 1216 GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Address); 1217 if (!GEP) 1218 return false; 1219 bool Changed = false; 1220 if (GEP->hasOneUse() && isa<GetElementPtrInst>(GEP->getPointerOperand())) { 1221 IRBuilder<> Builder(GEP->getContext()); 1222 Builder.SetInsertPoint(GEP); 1223 Builder.SetCurrentDebugLocation(GEP->getDebugLoc()); 1224 Value *Offsets; 1225 unsigned Scale; 1226 Value *Base = foldGEP(GEP, Offsets, Scale, Builder); 1227 // We only want to merge the geps if there is a real chance that they can be 1228 // used by an MVE gather; thus the offset has to have the correct size 1229 // (always i32 if it is not of vector type) and the base has to be a 1230 // pointer. 1231 if (Offsets && Base && Base != GEP) { 1232 assert(Scale == 1 && "Expected to fold GEP to a scale of 1"); 1233 Type *BaseTy = Builder.getPtrTy(); 1234 if (auto *VecTy = dyn_cast<FixedVectorType>(Base->getType())) 1235 BaseTy = FixedVectorType::get(BaseTy, VecTy); 1236 GetElementPtrInst *NewAddress = GetElementPtrInst::Create( 1237 Builder.getInt8Ty(), Builder.CreateBitCast(Base, BaseTy), Offsets, 1238 "gep.merged", GEP->getIterator()); 1239 LLVM_DEBUG(dbgs() << "Folded GEP: " << *GEP 1240 << "\n new : " << *NewAddress << "\n"); 1241 GEP->replaceAllUsesWith( 1242 Builder.CreateBitCast(NewAddress, GEP->getType())); 1243 GEP = NewAddress; 1244 Changed = true; 1245 } 1246 } 1247 Changed |= optimiseOffsets(GEP->getOperand(1), GEP->getParent(), LI); 1248 return Changed; 1249 } 1250 1251 bool MVEGatherScatterLowering::runOnFunction(Function &F) { 1252 if (!EnableMaskedGatherScatters) 1253 return false; 1254 auto &TPC = getAnalysis<TargetPassConfig>(); 1255 auto &TM = TPC.getTM<TargetMachine>(); 1256 auto *ST = &TM.getSubtarget<ARMSubtarget>(F); 1257 if (!ST->hasMVEIntegerOps()) 1258 return false; 1259 LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); 1260 DL = &F.getDataLayout(); 1261 SmallVector<IntrinsicInst *, 4> Gathers; 1262 SmallVector<IntrinsicInst *, 4> Scatters; 1263 1264 bool Changed = false; 1265 1266 for (BasicBlock &BB : F) { 1267 Changed |= SimplifyInstructionsInBlock(&BB); 1268 1269 for (Instruction &I : BB) { 1270 IntrinsicInst *II = dyn_cast<IntrinsicInst>(&I); 1271 if (II && II->getIntrinsicID() == Intrinsic::masked_gather && 1272 isa<FixedVectorType>(II->getType())) { 1273 Gathers.push_back(II); 1274 Changed |= optimiseAddress(II->getArgOperand(0), II->getParent(), LI); 1275 } else if (II && II->getIntrinsicID() == Intrinsic::masked_scatter && 1276 isa<FixedVectorType>(II->getArgOperand(0)->getType())) { 1277 Scatters.push_back(II); 1278 Changed |= optimiseAddress(II->getArgOperand(1), II->getParent(), LI); 1279 } 1280 } 1281 } 1282 for (IntrinsicInst *I : Gathers) { 1283 Instruction *L = lowerGather(I); 1284 if (L == nullptr) 1285 continue; 1286 1287 // Get rid of any now dead instructions 1288 SimplifyInstructionsInBlock(L->getParent()); 1289 Changed = true; 1290 } 1291 1292 for (IntrinsicInst *I : Scatters) { 1293 Instruction *S = lowerScatter(I); 1294 if (S == nullptr) 1295 continue; 1296 1297 // Get rid of any now dead instructions 1298 SimplifyInstructionsInBlock(S->getParent()); 1299 Changed = true; 1300 } 1301 return Changed; 1302 } 1303