1 //===- MVEGatherScatterLowering.cpp - Gather/Scatter lowering -------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 /// This pass custom lowers llvm.gather and llvm.scatter instructions to 10 /// arm.mve.gather and arm.mve.scatter intrinsics, optimising the code to 11 /// produce a better final result as we go. 12 // 13 //===----------------------------------------------------------------------===// 14 15 #include "ARM.h" 16 #include "ARMBaseInstrInfo.h" 17 #include "ARMSubtarget.h" 18 #include "llvm/Analysis/LoopInfo.h" 19 #include "llvm/Analysis/TargetTransformInfo.h" 20 #include "llvm/CodeGen/TargetLowering.h" 21 #include "llvm/CodeGen/TargetPassConfig.h" 22 #include "llvm/CodeGen/TargetSubtargetInfo.h" 23 #include "llvm/InitializePasses.h" 24 #include "llvm/IR/BasicBlock.h" 25 #include "llvm/IR/Constant.h" 26 #include "llvm/IR/Constants.h" 27 #include "llvm/IR/DerivedTypes.h" 28 #include "llvm/IR/Function.h" 29 #include "llvm/IR/InstrTypes.h" 30 #include "llvm/IR/Instruction.h" 31 #include "llvm/IR/Instructions.h" 32 #include "llvm/IR/IntrinsicInst.h" 33 #include "llvm/IR/Intrinsics.h" 34 #include "llvm/IR/IntrinsicsARM.h" 35 #include "llvm/IR/IRBuilder.h" 36 #include "llvm/IR/PatternMatch.h" 37 #include "llvm/IR/Type.h" 38 #include "llvm/IR/Value.h" 39 #include "llvm/Pass.h" 40 #include "llvm/Support/Casting.h" 41 #include "llvm/Transforms/Utils/Local.h" 42 #include <algorithm> 43 #include <cassert> 44 45 using namespace llvm; 46 47 #define DEBUG_TYPE "mve-gather-scatter-lowering" 48 49 cl::opt<bool> EnableMaskedGatherScatters( 50 "enable-arm-maskedgatscat", cl::Hidden, cl::init(false), 51 cl::desc("Enable the generation of masked gathers and scatters")); 52 53 namespace { 54 55 class MVEGatherScatterLowering : public FunctionPass { 56 public: 57 static char ID; // Pass identification, replacement for typeid 58 59 explicit MVEGatherScatterLowering() : FunctionPass(ID) { 60 initializeMVEGatherScatterLoweringPass(*PassRegistry::getPassRegistry()); 61 } 62 63 bool runOnFunction(Function &F) override; 64 65 StringRef getPassName() const override { 66 return "MVE gather/scatter lowering"; 67 } 68 69 void getAnalysisUsage(AnalysisUsage &AU) const override { 70 AU.setPreservesCFG(); 71 AU.addRequired<TargetPassConfig>(); 72 AU.addRequired<LoopInfoWrapperPass>(); 73 FunctionPass::getAnalysisUsage(AU); 74 } 75 76 private: 77 LoopInfo *LI = nullptr; 78 79 // Check this is a valid gather with correct alignment 80 bool isLegalTypeAndAlignment(unsigned NumElements, unsigned ElemSize, 81 Align Alignment); 82 // Check whether Ptr is hidden behind a bitcast and look through it 83 void lookThroughBitcast(Value *&Ptr); 84 // Check for a getelementptr and deduce base and offsets from it, on success 85 // returning the base directly and the offsets indirectly using the Offsets 86 // argument 87 Value *checkGEP(Value *&Offsets, Type *Ty, GetElementPtrInst *GEP, 88 IRBuilder<> &Builder); 89 // Compute the scale of this gather/scatter instruction 90 int computeScale(unsigned GEPElemSize, unsigned MemoryElemSize); 91 // If the value is a constant, or derived from constants via additions 92 // and multilications, return its numeric value 93 Optional<int64_t> getIfConst(const Value *V); 94 // If Inst is an add instruction, check whether one summand is a 95 // constant. If so, scale this constant and return it together with 96 // the other summand. 97 std::pair<Value *, int64_t> getVarAndConst(Value *Inst, int TypeScale); 98 99 Value *lowerGather(IntrinsicInst *I); 100 // Create a gather from a base + vector of offsets 101 Value *tryCreateMaskedGatherOffset(IntrinsicInst *I, Value *Ptr, 102 Instruction *&Root, IRBuilder<> &Builder); 103 // Create a gather from a vector of pointers 104 Value *tryCreateMaskedGatherBase(IntrinsicInst *I, Value *Ptr, 105 IRBuilder<> &Builder, int64_t Increment = 0); 106 // Create an incrementing gather from a vector of pointers 107 Value *tryCreateMaskedGatherBaseWB(IntrinsicInst *I, Value *Ptr, 108 IRBuilder<> &Builder, 109 int64_t Increment = 0); 110 111 Value *lowerScatter(IntrinsicInst *I); 112 // Create a scatter to a base + vector of offsets 113 Value *tryCreateMaskedScatterOffset(IntrinsicInst *I, Value *Offsets, 114 IRBuilder<> &Builder); 115 // Create a scatter to a vector of pointers 116 Value *tryCreateMaskedScatterBase(IntrinsicInst *I, Value *Ptr, 117 IRBuilder<> &Builder, 118 int64_t Increment = 0); 119 // Create an incrementing scatter from a vector of pointers 120 Value *tryCreateMaskedScatterBaseWB(IntrinsicInst *I, Value *Ptr, 121 IRBuilder<> &Builder, 122 int64_t Increment = 0); 123 124 // QI gathers and scatters can increment their offsets on their own if 125 // the increment is a constant value (digit) 126 Value *tryCreateIncrementingGatScat(IntrinsicInst *I, Value *BasePtr, 127 Value *Ptr, GetElementPtrInst *GEP, 128 IRBuilder<> &Builder); 129 // QI gathers/scatters can increment their offsets on their own if the 130 // increment is a constant value (digit) - this creates a writeback QI 131 // gather/scatter 132 Value *tryCreateIncrementingWBGatScat(IntrinsicInst *I, Value *BasePtr, 133 Value *Ptr, unsigned TypeScale, 134 IRBuilder<> &Builder); 135 // Check whether these offsets could be moved out of the loop they're in 136 bool optimiseOffsets(Value *Offsets, BasicBlock *BB, LoopInfo *LI); 137 // Pushes the given add out of the loop 138 void pushOutAdd(PHINode *&Phi, Value *OffsSecondOperand, unsigned StartIndex); 139 // Pushes the given mul out of the loop 140 void pushOutMul(PHINode *&Phi, Value *IncrementPerRound, 141 Value *OffsSecondOperand, unsigned LoopIncrement, 142 IRBuilder<> &Builder); 143 }; 144 145 } // end anonymous namespace 146 147 char MVEGatherScatterLowering::ID = 0; 148 149 INITIALIZE_PASS(MVEGatherScatterLowering, DEBUG_TYPE, 150 "MVE gather/scattering lowering pass", false, false) 151 152 Pass *llvm::createMVEGatherScatterLoweringPass() { 153 return new MVEGatherScatterLowering(); 154 } 155 156 bool MVEGatherScatterLowering::isLegalTypeAndAlignment(unsigned NumElements, 157 unsigned ElemSize, 158 Align Alignment) { 159 if (((NumElements == 4 && 160 (ElemSize == 32 || ElemSize == 16 || ElemSize == 8)) || 161 (NumElements == 8 && (ElemSize == 16 || ElemSize == 8)) || 162 (NumElements == 16 && ElemSize == 8)) && 163 Alignment >= ElemSize / 8) 164 return true; 165 LLVM_DEBUG(dbgs() << "masked gathers/scatters: instruction does not have " 166 << "valid alignment or vector type \n"); 167 return false; 168 } 169 170 Value *MVEGatherScatterLowering::checkGEP(Value *&Offsets, Type *Ty, 171 GetElementPtrInst *GEP, 172 IRBuilder<> &Builder) { 173 if (!GEP) { 174 LLVM_DEBUG( 175 dbgs() << "masked gathers/scatters: no getelementpointer found\n"); 176 return nullptr; 177 } 178 LLVM_DEBUG(dbgs() << "masked gathers/scatters: getelementpointer found." 179 << " Looking at intrinsic for base + vector of offsets\n"); 180 Value *GEPPtr = GEP->getPointerOperand(); 181 if (GEPPtr->getType()->isVectorTy()) { 182 return nullptr; 183 } 184 if (GEP->getNumOperands() != 2) { 185 LLVM_DEBUG(dbgs() << "masked gathers/scatters: getelementptr with too many" 186 << " operands. Expanding.\n"); 187 return nullptr; 188 } 189 Offsets = GEP->getOperand(1); 190 // Paranoid check whether the number of parallel lanes is the same 191 assert(cast<FixedVectorType>(Ty)->getNumElements() == 192 cast<FixedVectorType>(Offsets->getType())->getNumElements()); 193 // Only <N x i32> offsets can be integrated into an arm gather, any smaller 194 // type would have to be sign extended by the gep - and arm gathers can only 195 // zero extend. Additionally, the offsets do have to originate from a zext of 196 // a vector with element types smaller or equal the type of the gather we're 197 // looking at 198 if (Offsets->getType()->getScalarSizeInBits() != 32) 199 return nullptr; 200 if (ZExtInst *ZextOffs = dyn_cast<ZExtInst>(Offsets)) 201 Offsets = ZextOffs->getOperand(0); 202 else if (!(cast<FixedVectorType>(Offsets->getType())->getNumElements() == 4 && 203 Offsets->getType()->getScalarSizeInBits() == 32)) 204 return nullptr; 205 206 if (Ty != Offsets->getType()) { 207 if ((Ty->getScalarSizeInBits() < 208 Offsets->getType()->getScalarSizeInBits())) { 209 LLVM_DEBUG(dbgs() << "masked gathers/scatters: no correct offset type." 210 << " Can't create intrinsic.\n"); 211 return nullptr; 212 } else { 213 Offsets = Builder.CreateZExt( 214 Offsets, VectorType::getInteger(cast<VectorType>(Ty))); 215 } 216 } 217 // If none of the checks failed, return the gep's base pointer 218 LLVM_DEBUG(dbgs() << "masked gathers/scatters: found correct offsets\n"); 219 return GEPPtr; 220 } 221 222 void MVEGatherScatterLowering::lookThroughBitcast(Value *&Ptr) { 223 // Look through bitcast instruction if #elements is the same 224 if (auto *BitCast = dyn_cast<BitCastInst>(Ptr)) { 225 auto *BCTy = cast<FixedVectorType>(BitCast->getType()); 226 auto *BCSrcTy = cast<FixedVectorType>(BitCast->getOperand(0)->getType()); 227 if (BCTy->getNumElements() == BCSrcTy->getNumElements()) { 228 LLVM_DEBUG( 229 dbgs() << "masked gathers/scatters: looking through bitcast\n"); 230 Ptr = BitCast->getOperand(0); 231 } 232 } 233 } 234 235 int MVEGatherScatterLowering::computeScale(unsigned GEPElemSize, 236 unsigned MemoryElemSize) { 237 // This can be a 32bit load/store scaled by 4, a 16bit load/store scaled by 2, 238 // or a 8bit, 16bit or 32bit load/store scaled by 1 239 if (GEPElemSize == 32 && MemoryElemSize == 32) 240 return 2; 241 else if (GEPElemSize == 16 && MemoryElemSize == 16) 242 return 1; 243 else if (GEPElemSize == 8) 244 return 0; 245 LLVM_DEBUG(dbgs() << "masked gathers/scatters: incorrect scale. Can't " 246 << "create intrinsic\n"); 247 return -1; 248 } 249 250 Optional<int64_t> MVEGatherScatterLowering::getIfConst(const Value *V) { 251 const Constant *C = dyn_cast<Constant>(V); 252 if (C != nullptr) 253 return Optional<int64_t>{C->getUniqueInteger().getSExtValue()}; 254 if (!isa<Instruction>(V)) 255 return Optional<int64_t>{}; 256 257 const Instruction *I = cast<Instruction>(V); 258 if (I->getOpcode() == Instruction::Add || 259 I->getOpcode() == Instruction::Mul) { 260 Optional<int64_t> Op0 = getIfConst(I->getOperand(0)); 261 Optional<int64_t> Op1 = getIfConst(I->getOperand(1)); 262 if (!Op0 || !Op1) 263 return Optional<int64_t>{}; 264 if (I->getOpcode() == Instruction::Add) 265 return Optional<int64_t>{Op0.getValue() + Op1.getValue()}; 266 if (I->getOpcode() == Instruction::Mul) 267 return Optional<int64_t>{Op0.getValue() * Op1.getValue()}; 268 } 269 return Optional<int64_t>{}; 270 } 271 272 std::pair<Value *, int64_t> 273 MVEGatherScatterLowering::getVarAndConst(Value *Inst, int TypeScale) { 274 std::pair<Value *, int64_t> ReturnFalse = 275 std::pair<Value *, int64_t>(nullptr, 0); 276 // At this point, the instruction we're looking at must be an add or we 277 // bail out 278 Instruction *Add = dyn_cast<Instruction>(Inst); 279 if (Add == nullptr || Add->getOpcode() != Instruction::Add) 280 return ReturnFalse; 281 282 Value *Summand; 283 Optional<int64_t> Const; 284 // Find out which operand the value that is increased is 285 if ((Const = getIfConst(Add->getOperand(0)))) 286 Summand = Add->getOperand(1); 287 else if ((Const = getIfConst(Add->getOperand(1)))) 288 Summand = Add->getOperand(0); 289 else 290 return ReturnFalse; 291 292 // Check that the constant is small enough for an incrementing gather 293 int64_t Immediate = Const.getValue() << TypeScale; 294 if (Immediate > 512 || Immediate < -512 || Immediate % 4 != 0) 295 return ReturnFalse; 296 297 return std::pair<Value *, int64_t>(Summand, Immediate); 298 } 299 300 Value *MVEGatherScatterLowering::lowerGather(IntrinsicInst *I) { 301 using namespace PatternMatch; 302 LLVM_DEBUG(dbgs() << "masked gathers: checking transform preconditions\n"); 303 304 // @llvm.masked.gather.*(Ptrs, alignment, Mask, Src0) 305 // Attempt to turn the masked gather in I into a MVE intrinsic 306 // Potentially optimising the addressing modes as we do so. 307 auto *Ty = cast<FixedVectorType>(I->getType()); 308 Value *Ptr = I->getArgOperand(0); 309 Align Alignment = cast<ConstantInt>(I->getArgOperand(1))->getAlignValue(); 310 Value *Mask = I->getArgOperand(2); 311 Value *PassThru = I->getArgOperand(3); 312 313 if (!isLegalTypeAndAlignment(Ty->getNumElements(), Ty->getScalarSizeInBits(), 314 Alignment)) 315 return nullptr; 316 lookThroughBitcast(Ptr); 317 assert(Ptr->getType()->isVectorTy() && "Unexpected pointer type"); 318 319 IRBuilder<> Builder(I->getContext()); 320 Builder.SetInsertPoint(I); 321 Builder.SetCurrentDebugLocation(I->getDebugLoc()); 322 323 Instruction *Root = I; 324 Value *Load = tryCreateMaskedGatherOffset(I, Ptr, Root, Builder); 325 if (!Load) 326 Load = tryCreateMaskedGatherBase(I, Ptr, Builder); 327 if (!Load) 328 return nullptr; 329 330 if (!isa<UndefValue>(PassThru) && !match(PassThru, m_Zero())) { 331 LLVM_DEBUG(dbgs() << "masked gathers: found non-trivial passthru - " 332 << "creating select\n"); 333 Load = Builder.CreateSelect(Mask, Load, PassThru); 334 } 335 336 Root->replaceAllUsesWith(Load); 337 Root->eraseFromParent(); 338 if (Root != I) 339 // If this was an extending gather, we need to get rid of the sext/zext 340 // sext/zext as well as of the gather itself 341 I->eraseFromParent(); 342 343 LLVM_DEBUG(dbgs() << "masked gathers: successfully built masked gather\n"); 344 return Load; 345 } 346 347 Value *MVEGatherScatterLowering::tryCreateMaskedGatherBase(IntrinsicInst *I, 348 Value *Ptr, 349 IRBuilder<> &Builder, 350 int64_t Increment) { 351 using namespace PatternMatch; 352 auto *Ty = cast<FixedVectorType>(I->getType()); 353 LLVM_DEBUG(dbgs() << "masked gathers: loading from vector of pointers\n"); 354 if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32) 355 // Can't build an intrinsic for this 356 return nullptr; 357 Value *Mask = I->getArgOperand(2); 358 if (match(Mask, m_One())) 359 return Builder.CreateIntrinsic(Intrinsic::arm_mve_vldr_gather_base, 360 {Ty, Ptr->getType()}, 361 {Ptr, Builder.getInt32(Increment)}); 362 else 363 return Builder.CreateIntrinsic( 364 Intrinsic::arm_mve_vldr_gather_base_predicated, 365 {Ty, Ptr->getType(), Mask->getType()}, 366 {Ptr, Builder.getInt32(Increment), Mask}); 367 } 368 369 Value *MVEGatherScatterLowering::tryCreateMaskedGatherBaseWB( 370 IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) { 371 using namespace PatternMatch; 372 auto *Ty = cast<FixedVectorType>(I->getType()); 373 LLVM_DEBUG( 374 dbgs() 375 << "masked gathers: loading from vector of pointers with writeback\n"); 376 if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32) 377 // Can't build an intrinsic for this 378 return nullptr; 379 Value *Mask = I->getArgOperand(2); 380 if (match(Mask, m_One())) 381 return Builder.CreateIntrinsic(Intrinsic::arm_mve_vldr_gather_base_wb, 382 {Ty, Ptr->getType()}, 383 {Ptr, Builder.getInt32(Increment)}); 384 else 385 return Builder.CreateIntrinsic( 386 Intrinsic::arm_mve_vldr_gather_base_wb_predicated, 387 {Ty, Ptr->getType(), Mask->getType()}, 388 {Ptr, Builder.getInt32(Increment), Mask}); 389 } 390 391 Value *MVEGatherScatterLowering::tryCreateMaskedGatherOffset( 392 IntrinsicInst *I, Value *Ptr, Instruction *&Root, IRBuilder<> &Builder) { 393 using namespace PatternMatch; 394 395 Type *OriginalTy = I->getType(); 396 Type *ResultTy = OriginalTy; 397 398 unsigned Unsigned = 1; 399 // The size of the gather was already checked in isLegalTypeAndAlignment; 400 // if it was not a full vector width an appropriate extend should follow. 401 auto *Extend = Root; 402 if (OriginalTy->getPrimitiveSizeInBits() < 128) { 403 // Only transform gathers with exactly one use 404 if (!I->hasOneUse()) 405 return nullptr; 406 407 // The correct root to replace is not the CallInst itself, but the 408 // instruction which extends it 409 Extend = cast<Instruction>(*I->users().begin()); 410 if (isa<SExtInst>(Extend)) { 411 Unsigned = 0; 412 } else if (!isa<ZExtInst>(Extend)) { 413 LLVM_DEBUG(dbgs() << "masked gathers: extend needed but not provided. " 414 << "Expanding\n"); 415 return nullptr; 416 } 417 LLVM_DEBUG(dbgs() << "masked gathers: found an extending gather\n"); 418 ResultTy = Extend->getType(); 419 // The final size of the gather must be a full vector width 420 if (ResultTy->getPrimitiveSizeInBits() != 128) { 421 LLVM_DEBUG(dbgs() << "masked gathers: extending from the wrong type. " 422 << "Expanding\n"); 423 return nullptr; 424 } 425 } 426 427 GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr); 428 Value *Offsets; 429 Value *BasePtr = checkGEP(Offsets, ResultTy, GEP, Builder); 430 if (!BasePtr) 431 return nullptr; 432 // Check whether the offset is a constant increment that could be merged into 433 // a QI gather 434 Value *Load = tryCreateIncrementingGatScat(I, BasePtr, Offsets, GEP, Builder); 435 if (Load) 436 return Load; 437 438 int Scale = computeScale( 439 BasePtr->getType()->getPointerElementType()->getPrimitiveSizeInBits(), 440 OriginalTy->getScalarSizeInBits()); 441 if (Scale == -1) 442 return nullptr; 443 Root = Extend; 444 445 Value *Mask = I->getArgOperand(2); 446 if (!match(Mask, m_One())) 447 return Builder.CreateIntrinsic( 448 Intrinsic::arm_mve_vldr_gather_offset_predicated, 449 {ResultTy, BasePtr->getType(), Offsets->getType(), Mask->getType()}, 450 {BasePtr, Offsets, Builder.getInt32(OriginalTy->getScalarSizeInBits()), 451 Builder.getInt32(Scale), Builder.getInt32(Unsigned), Mask}); 452 else 453 return Builder.CreateIntrinsic( 454 Intrinsic::arm_mve_vldr_gather_offset, 455 {ResultTy, BasePtr->getType(), Offsets->getType()}, 456 {BasePtr, Offsets, Builder.getInt32(OriginalTy->getScalarSizeInBits()), 457 Builder.getInt32(Scale), Builder.getInt32(Unsigned)}); 458 } 459 460 Value *MVEGatherScatterLowering::lowerScatter(IntrinsicInst *I) { 461 using namespace PatternMatch; 462 LLVM_DEBUG(dbgs() << "masked scatters: checking transform preconditions\n"); 463 464 // @llvm.masked.scatter.*(data, ptrs, alignment, mask) 465 // Attempt to turn the masked scatter in I into a MVE intrinsic 466 // Potentially optimising the addressing modes as we do so. 467 Value *Input = I->getArgOperand(0); 468 Value *Ptr = I->getArgOperand(1); 469 Align Alignment = cast<ConstantInt>(I->getArgOperand(2))->getAlignValue(); 470 auto *Ty = cast<FixedVectorType>(Input->getType()); 471 472 if (!isLegalTypeAndAlignment(Ty->getNumElements(), Ty->getScalarSizeInBits(), 473 Alignment)) 474 return nullptr; 475 476 lookThroughBitcast(Ptr); 477 assert(Ptr->getType()->isVectorTy() && "Unexpected pointer type"); 478 479 IRBuilder<> Builder(I->getContext()); 480 Builder.SetInsertPoint(I); 481 Builder.SetCurrentDebugLocation(I->getDebugLoc()); 482 483 Value *Store = tryCreateMaskedScatterOffset(I, Ptr, Builder); 484 if (!Store) 485 Store = tryCreateMaskedScatterBase(I, Ptr, Builder); 486 if (!Store) 487 return nullptr; 488 489 LLVM_DEBUG(dbgs() << "masked scatters: successfully built masked scatter\n"); 490 I->eraseFromParent(); 491 return Store; 492 } 493 494 Value *MVEGatherScatterLowering::tryCreateMaskedScatterBase( 495 IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) { 496 using namespace PatternMatch; 497 Value *Input = I->getArgOperand(0); 498 auto *Ty = cast<FixedVectorType>(Input->getType()); 499 // Only QR variants allow truncating 500 if (!(Ty->getNumElements() == 4 && Ty->getScalarSizeInBits() == 32)) { 501 // Can't build an intrinsic for this 502 return nullptr; 503 } 504 Value *Mask = I->getArgOperand(3); 505 // int_arm_mve_vstr_scatter_base(_predicated) addr, offset, data(, mask) 506 LLVM_DEBUG(dbgs() << "masked scatters: storing to a vector of pointers\n"); 507 if (match(Mask, m_One())) 508 return Builder.CreateIntrinsic(Intrinsic::arm_mve_vstr_scatter_base, 509 {Ptr->getType(), Input->getType()}, 510 {Ptr, Builder.getInt32(Increment), Input}); 511 else 512 return Builder.CreateIntrinsic( 513 Intrinsic::arm_mve_vstr_scatter_base_predicated, 514 {Ptr->getType(), Input->getType(), Mask->getType()}, 515 {Ptr, Builder.getInt32(Increment), Input, Mask}); 516 } 517 518 Value *MVEGatherScatterLowering::tryCreateMaskedScatterBaseWB( 519 IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) { 520 using namespace PatternMatch; 521 Value *Input = I->getArgOperand(0); 522 auto *Ty = cast<FixedVectorType>(Input->getType()); 523 LLVM_DEBUG( 524 dbgs() 525 << "masked scatters: storing to a vector of pointers with writeback\n"); 526 if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32) 527 // Can't build an intrinsic for this 528 return nullptr; 529 Value *Mask = I->getArgOperand(3); 530 if (match(Mask, m_One())) 531 return Builder.CreateIntrinsic(Intrinsic::arm_mve_vstr_scatter_base_wb, 532 {Ptr->getType(), Input->getType()}, 533 {Ptr, Builder.getInt32(Increment), Input}); 534 else 535 return Builder.CreateIntrinsic( 536 Intrinsic::arm_mve_vstr_scatter_base_wb_predicated, 537 {Ptr->getType(), Input->getType(), Mask->getType()}, 538 {Ptr, Builder.getInt32(Increment), Input, Mask}); 539 } 540 541 Value *MVEGatherScatterLowering::tryCreateMaskedScatterOffset( 542 IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder) { 543 using namespace PatternMatch; 544 Value *Input = I->getArgOperand(0); 545 Value *Mask = I->getArgOperand(3); 546 Type *InputTy = Input->getType(); 547 Type *MemoryTy = InputTy; 548 LLVM_DEBUG(dbgs() << "masked scatters: getelementpointer found. Storing" 549 << " to base + vector of offsets\n"); 550 // If the input has been truncated, try to integrate that trunc into the 551 // scatter instruction (we don't care about alignment here) 552 if (TruncInst *Trunc = dyn_cast<TruncInst>(Input)) { 553 Value *PreTrunc = Trunc->getOperand(0); 554 Type *PreTruncTy = PreTrunc->getType(); 555 if (PreTruncTy->getPrimitiveSizeInBits() == 128) { 556 Input = PreTrunc; 557 InputTy = PreTruncTy; 558 } 559 } 560 if (InputTy->getPrimitiveSizeInBits() != 128) { 561 LLVM_DEBUG( 562 dbgs() << "masked scatters: cannot create scatters for non-standard" 563 << " input types. Expanding.\n"); 564 return nullptr; 565 } 566 567 GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr); 568 Value *Offsets; 569 Value *BasePtr = checkGEP(Offsets, InputTy, GEP, Builder); 570 if (!BasePtr) 571 return nullptr; 572 // Check whether the offset is a constant increment that could be merged into 573 // a QI gather 574 Value *Store = 575 tryCreateIncrementingGatScat(I, BasePtr, Offsets, GEP, Builder); 576 if (Store) 577 return Store; 578 int Scale = computeScale( 579 BasePtr->getType()->getPointerElementType()->getPrimitiveSizeInBits(), 580 MemoryTy->getScalarSizeInBits()); 581 if (Scale == -1) 582 return nullptr; 583 584 if (!match(Mask, m_One())) 585 return Builder.CreateIntrinsic( 586 Intrinsic::arm_mve_vstr_scatter_offset_predicated, 587 {BasePtr->getType(), Offsets->getType(), Input->getType(), 588 Mask->getType()}, 589 {BasePtr, Offsets, Input, 590 Builder.getInt32(MemoryTy->getScalarSizeInBits()), 591 Builder.getInt32(Scale), Mask}); 592 else 593 return Builder.CreateIntrinsic( 594 Intrinsic::arm_mve_vstr_scatter_offset, 595 {BasePtr->getType(), Offsets->getType(), Input->getType()}, 596 {BasePtr, Offsets, Input, 597 Builder.getInt32(MemoryTy->getScalarSizeInBits()), 598 Builder.getInt32(Scale)}); 599 } 600 601 Value *MVEGatherScatterLowering::tryCreateIncrementingGatScat( 602 IntrinsicInst *I, Value *BasePtr, Value *Offsets, GetElementPtrInst *GEP, 603 IRBuilder<> &Builder) { 604 FixedVectorType *Ty; 605 if (I->getIntrinsicID() == Intrinsic::masked_gather) 606 Ty = cast<FixedVectorType>(I->getType()); 607 else 608 Ty = cast<FixedVectorType>(I->getArgOperand(0)->getType()); 609 // Incrementing gathers only exist for v4i32 610 if (Ty->getNumElements() != 4 || 611 Ty->getScalarSizeInBits() != 32) 612 return nullptr; 613 Loop *L = LI->getLoopFor(I->getParent()); 614 if (L == nullptr) 615 // Incrementing gathers are not beneficial outside of a loop 616 return nullptr; 617 LLVM_DEBUG(dbgs() << "masked gathers/scatters: trying to build incrementing " 618 "wb gather/scatter\n"); 619 620 // The gep was in charge of making sure the offsets are scaled correctly 621 // - calculate that factor so it can be applied by hand 622 DataLayout DT = I->getParent()->getParent()->getParent()->getDataLayout(); 623 int TypeScale = 624 computeScale(DT.getTypeSizeInBits(GEP->getOperand(0)->getType()), 625 DT.getTypeSizeInBits(GEP->getType()) / 626 cast<FixedVectorType>(GEP->getType())->getNumElements()); 627 if (TypeScale == -1) 628 return nullptr; 629 630 if (GEP->hasOneUse()) { 631 // Only in this case do we want to build a wb gather, because the wb will 632 // change the phi which does affect other users of the gep (which will still 633 // be using the phi in the old way) 634 Value *Load = 635 tryCreateIncrementingWBGatScat(I, BasePtr, Offsets, TypeScale, Builder); 636 if (Load != nullptr) 637 return Load; 638 } 639 LLVM_DEBUG(dbgs() << "masked gathers/scatters: trying to build incrementing " 640 "non-wb gather/scatter\n"); 641 642 std::pair<Value *, int64_t> Add = getVarAndConst(Offsets, TypeScale); 643 if (Add.first == nullptr) 644 return nullptr; 645 Value *OffsetsIncoming = Add.first; 646 int64_t Immediate = Add.second; 647 648 // Make sure the offsets are scaled correctly 649 Instruction *ScaledOffsets = BinaryOperator::Create( 650 Instruction::Shl, OffsetsIncoming, 651 Builder.CreateVectorSplat(Ty->getNumElements(), Builder.getInt32(TypeScale)), 652 "ScaledIndex", I); 653 // Add the base to the offsets 654 OffsetsIncoming = BinaryOperator::Create( 655 Instruction::Add, ScaledOffsets, 656 Builder.CreateVectorSplat( 657 Ty->getNumElements(), 658 Builder.CreatePtrToInt( 659 BasePtr, 660 cast<VectorType>(ScaledOffsets->getType())->getElementType())), 661 "StartIndex", I); 662 663 if (I->getIntrinsicID() == Intrinsic::masked_gather) 664 return cast<IntrinsicInst>( 665 tryCreateMaskedGatherBase(I, OffsetsIncoming, Builder, Immediate)); 666 else 667 return cast<IntrinsicInst>( 668 tryCreateMaskedScatterBase(I, OffsetsIncoming, Builder, Immediate)); 669 } 670 671 Value *MVEGatherScatterLowering::tryCreateIncrementingWBGatScat( 672 IntrinsicInst *I, Value *BasePtr, Value *Offsets, unsigned TypeScale, 673 IRBuilder<> &Builder) { 674 // Check whether this gather's offset is incremented by a constant - if so, 675 // and the load is of the right type, we can merge this into a QI gather 676 Loop *L = LI->getLoopFor(I->getParent()); 677 // Offsets that are worth merging into this instruction will be incremented 678 // by a constant, thus we're looking for an add of a phi and a constant 679 PHINode *Phi = dyn_cast<PHINode>(Offsets); 680 if (Phi == nullptr || Phi->getNumIncomingValues() != 2 || 681 Phi->getParent() != L->getHeader() || Phi->getNumUses() != 2) 682 // No phi means no IV to write back to; if there is a phi, we expect it 683 // to have exactly two incoming values; the only phis we are interested in 684 // will be loop IV's and have exactly two uses, one in their increment and 685 // one in the gather's gep 686 return nullptr; 687 688 unsigned IncrementIndex = 689 Phi->getIncomingBlock(0) == L->getLoopLatch() ? 0 : 1; 690 // Look through the phi to the phi increment 691 Offsets = Phi->getIncomingValue(IncrementIndex); 692 693 std::pair<Value *, int64_t> Add = getVarAndConst(Offsets, TypeScale); 694 if (Add.first == nullptr) 695 return nullptr; 696 Value *OffsetsIncoming = Add.first; 697 int64_t Immediate = Add.second; 698 if (OffsetsIncoming != Phi) 699 // Then the increment we are looking at is not an increment of the 700 // induction variable, and we don't want to do a writeback 701 return nullptr; 702 703 Builder.SetInsertPoint(&Phi->getIncomingBlock(1 - IncrementIndex)->back()); 704 unsigned NumElems = 705 cast<FixedVectorType>(OffsetsIncoming->getType())->getNumElements(); 706 707 // Make sure the offsets are scaled correctly 708 Instruction *ScaledOffsets = BinaryOperator::Create( 709 Instruction::Shl, Phi->getIncomingValue(1 - IncrementIndex), 710 Builder.CreateVectorSplat(NumElems, Builder.getInt32(TypeScale)), 711 "ScaledIndex", &Phi->getIncomingBlock(1 - IncrementIndex)->back()); 712 // Add the base to the offsets 713 OffsetsIncoming = BinaryOperator::Create( 714 Instruction::Add, ScaledOffsets, 715 Builder.CreateVectorSplat( 716 NumElems, 717 Builder.CreatePtrToInt( 718 BasePtr, 719 cast<VectorType>(ScaledOffsets->getType())->getElementType())), 720 "StartIndex", &Phi->getIncomingBlock(1 - IncrementIndex)->back()); 721 // The gather is pre-incrementing 722 OffsetsIncoming = BinaryOperator::Create( 723 Instruction::Sub, OffsetsIncoming, 724 Builder.CreateVectorSplat(NumElems, Builder.getInt32(Immediate)), 725 "PreIncrementStartIndex", 726 &Phi->getIncomingBlock(1 - IncrementIndex)->back()); 727 Phi->setIncomingValue(1 - IncrementIndex, OffsetsIncoming); 728 729 Builder.SetInsertPoint(I); 730 731 Value *EndResult; 732 Value *NewInduction; 733 if (I->getIntrinsicID() == Intrinsic::masked_gather) { 734 // Build the incrementing gather 735 Value *Load = tryCreateMaskedGatherBaseWB(I, Phi, Builder, Immediate); 736 // One value to be handed to whoever uses the gather, one is the loop 737 // increment 738 EndResult = Builder.CreateExtractValue(Load, 0, "Gather"); 739 NewInduction = Builder.CreateExtractValue(Load, 1, "GatherIncrement"); 740 } else { 741 // Build the incrementing scatter 742 NewInduction = tryCreateMaskedScatterBaseWB(I, Phi, Builder, Immediate); 743 EndResult = NewInduction; 744 } 745 Instruction *AddInst = cast<Instruction>(Offsets); 746 AddInst->replaceAllUsesWith(NewInduction); 747 AddInst->eraseFromParent(); 748 Phi->setIncomingValue(IncrementIndex, NewInduction); 749 750 return EndResult; 751 } 752 753 void MVEGatherScatterLowering::pushOutAdd(PHINode *&Phi, 754 Value *OffsSecondOperand, 755 unsigned StartIndex) { 756 LLVM_DEBUG(dbgs() << "masked gathers/scatters: optimising add instruction\n"); 757 Instruction *InsertionPoint = 758 &cast<Instruction>(Phi->getIncomingBlock(StartIndex)->back()); 759 // Initialize the phi with a vector that contains a sum of the constants 760 Instruction *NewIndex = BinaryOperator::Create( 761 Instruction::Add, Phi->getIncomingValue(StartIndex), OffsSecondOperand, 762 "PushedOutAdd", InsertionPoint); 763 unsigned IncrementIndex = StartIndex == 0 ? 1 : 0; 764 765 // Order such that start index comes first (this reduces mov's) 766 Phi->addIncoming(NewIndex, Phi->getIncomingBlock(StartIndex)); 767 Phi->addIncoming(Phi->getIncomingValue(IncrementIndex), 768 Phi->getIncomingBlock(IncrementIndex)); 769 Phi->removeIncomingValue(IncrementIndex); 770 Phi->removeIncomingValue(StartIndex); 771 } 772 773 void MVEGatherScatterLowering::pushOutMul(PHINode *&Phi, 774 Value *IncrementPerRound, 775 Value *OffsSecondOperand, 776 unsigned LoopIncrement, 777 IRBuilder<> &Builder) { 778 LLVM_DEBUG(dbgs() << "masked gathers/scatters: optimising mul instruction\n"); 779 780 // Create a new scalar add outside of the loop and transform it to a splat 781 // by which loop variable can be incremented 782 Instruction *InsertionPoint = &cast<Instruction>( 783 Phi->getIncomingBlock(LoopIncrement == 1 ? 0 : 1)->back()); 784 785 // Create a new index 786 Value *StartIndex = BinaryOperator::Create( 787 Instruction::Mul, Phi->getIncomingValue(LoopIncrement == 1 ? 0 : 1), 788 OffsSecondOperand, "PushedOutMul", InsertionPoint); 789 790 Instruction *Product = 791 BinaryOperator::Create(Instruction::Mul, IncrementPerRound, 792 OffsSecondOperand, "Product", InsertionPoint); 793 // Increment NewIndex by Product instead of the multiplication 794 Instruction *NewIncrement = BinaryOperator::Create( 795 Instruction::Add, Phi, Product, "IncrementPushedOutMul", 796 cast<Instruction>(Phi->getIncomingBlock(LoopIncrement)->back()) 797 .getPrevNode()); 798 799 Phi->addIncoming(StartIndex, 800 Phi->getIncomingBlock(LoopIncrement == 1 ? 0 : 1)); 801 Phi->addIncoming(NewIncrement, Phi->getIncomingBlock(LoopIncrement)); 802 Phi->removeIncomingValue((unsigned)0); 803 Phi->removeIncomingValue((unsigned)0); 804 return; 805 } 806 807 // Check whether all usages of this instruction are as offsets of 808 // gathers/scatters or simple arithmetics only used by gathers/scatters 809 static bool hasAllGatScatUsers(Instruction *I) { 810 if (I->hasNUses(0)) { 811 return false; 812 } 813 bool Gatscat = true; 814 for (User *U : I->users()) { 815 if (!isa<Instruction>(U)) 816 return false; 817 if (isa<GetElementPtrInst>(U) || 818 isGatherScatter(dyn_cast<IntrinsicInst>(U))) { 819 return Gatscat; 820 } else { 821 unsigned OpCode = cast<Instruction>(U)->getOpcode(); 822 if ((OpCode == Instruction::Add || OpCode == Instruction::Mul) && 823 hasAllGatScatUsers(cast<Instruction>(U))) { 824 continue; 825 } 826 return false; 827 } 828 } 829 return Gatscat; 830 } 831 832 bool MVEGatherScatterLowering::optimiseOffsets(Value *Offsets, BasicBlock *BB, 833 LoopInfo *LI) { 834 LLVM_DEBUG(dbgs() << "masked gathers/scatters: trying to optimize\n"); 835 // Optimise the addresses of gathers/scatters by moving invariant 836 // calculations out of the loop 837 if (!isa<Instruction>(Offsets)) 838 return false; 839 Instruction *Offs = cast<Instruction>(Offsets); 840 if (Offs->getOpcode() != Instruction::Add && 841 Offs->getOpcode() != Instruction::Mul) 842 return false; 843 Loop *L = LI->getLoopFor(BB); 844 if (L == nullptr) 845 return false; 846 if (!Offs->hasOneUse()) { 847 if (!hasAllGatScatUsers(Offs)) 848 return false; 849 } 850 851 // Find out which, if any, operand of the instruction 852 // is a phi node 853 PHINode *Phi; 854 int OffsSecondOp; 855 if (isa<PHINode>(Offs->getOperand(0))) { 856 Phi = cast<PHINode>(Offs->getOperand(0)); 857 OffsSecondOp = 1; 858 } else if (isa<PHINode>(Offs->getOperand(1))) { 859 Phi = cast<PHINode>(Offs->getOperand(1)); 860 OffsSecondOp = 0; 861 } else { 862 bool Changed = true; 863 if (isa<Instruction>(Offs->getOperand(0)) && 864 L->contains(cast<Instruction>(Offs->getOperand(0)))) 865 Changed |= optimiseOffsets(Offs->getOperand(0), BB, LI); 866 if (isa<Instruction>(Offs->getOperand(1)) && 867 L->contains(cast<Instruction>(Offs->getOperand(1)))) 868 Changed |= optimiseOffsets(Offs->getOperand(1), BB, LI); 869 if (!Changed) { 870 return false; 871 } else { 872 if (isa<PHINode>(Offs->getOperand(0))) { 873 Phi = cast<PHINode>(Offs->getOperand(0)); 874 OffsSecondOp = 1; 875 } else if (isa<PHINode>(Offs->getOperand(1))) { 876 Phi = cast<PHINode>(Offs->getOperand(1)); 877 OffsSecondOp = 0; 878 } else { 879 return false; 880 } 881 } 882 } 883 // A phi node we want to perform this function on should be from the 884 // loop header, and shouldn't have more than 2 incoming values 885 if (Phi->getParent() != L->getHeader() || 886 Phi->getNumIncomingValues() != 2) 887 return false; 888 889 // The phi must be an induction variable 890 Instruction *Op; 891 int IncrementingBlock = -1; 892 893 for (int i = 0; i < 2; i++) 894 if ((Op = dyn_cast<Instruction>(Phi->getIncomingValue(i))) != nullptr) 895 if (Op->getOpcode() == Instruction::Add && 896 (Op->getOperand(0) == Phi || Op->getOperand(1) == Phi)) 897 IncrementingBlock = i; 898 if (IncrementingBlock == -1) 899 return false; 900 901 Instruction *IncInstruction = 902 cast<Instruction>(Phi->getIncomingValue(IncrementingBlock)); 903 904 // If the phi is not used by anything else, we can just adapt it when 905 // replacing the instruction; if it is, we'll have to duplicate it 906 PHINode *NewPhi; 907 Value *IncrementPerRound = IncInstruction->getOperand( 908 (IncInstruction->getOperand(0) == Phi) ? 1 : 0); 909 910 // Get the value that is added to/multiplied with the phi 911 Value *OffsSecondOperand = Offs->getOperand(OffsSecondOp); 912 913 if (IncrementPerRound->getType() != OffsSecondOperand->getType()) 914 // Something has gone wrong, abort 915 return false; 916 917 // Only proceed if the increment per round is a constant or an instruction 918 // which does not originate from within the loop 919 if (!isa<Constant>(IncrementPerRound) && 920 !(isa<Instruction>(IncrementPerRound) && 921 !L->contains(cast<Instruction>(IncrementPerRound)))) 922 return false; 923 924 if (Phi->getNumUses() == 2) { 925 // No other users -> reuse existing phi (One user is the instruction 926 // we're looking at, the other is the phi increment) 927 if (IncInstruction->getNumUses() != 1) { 928 // If the incrementing instruction does have more users than 929 // our phi, we need to copy it 930 IncInstruction = BinaryOperator::Create( 931 Instruction::BinaryOps(IncInstruction->getOpcode()), Phi, 932 IncrementPerRound, "LoopIncrement", IncInstruction); 933 Phi->setIncomingValue(IncrementingBlock, IncInstruction); 934 } 935 NewPhi = Phi; 936 } else { 937 // There are other users -> create a new phi 938 NewPhi = PHINode::Create(Phi->getType(), 0, "NewPhi", Phi); 939 std::vector<Value *> Increases; 940 // Copy the incoming values of the old phi 941 NewPhi->addIncoming(Phi->getIncomingValue(IncrementingBlock == 1 ? 0 : 1), 942 Phi->getIncomingBlock(IncrementingBlock == 1 ? 0 : 1)); 943 IncInstruction = BinaryOperator::Create( 944 Instruction::BinaryOps(IncInstruction->getOpcode()), NewPhi, 945 IncrementPerRound, "LoopIncrement", IncInstruction); 946 NewPhi->addIncoming(IncInstruction, 947 Phi->getIncomingBlock(IncrementingBlock)); 948 IncrementingBlock = 1; 949 } 950 951 IRBuilder<> Builder(BB->getContext()); 952 Builder.SetInsertPoint(Phi); 953 Builder.SetCurrentDebugLocation(Offs->getDebugLoc()); 954 955 switch (Offs->getOpcode()) { 956 case Instruction::Add: 957 pushOutAdd(NewPhi, OffsSecondOperand, IncrementingBlock == 1 ? 0 : 1); 958 break; 959 case Instruction::Mul: 960 pushOutMul(NewPhi, IncrementPerRound, OffsSecondOperand, IncrementingBlock, 961 Builder); 962 break; 963 default: 964 return false; 965 } 966 LLVM_DEBUG( 967 dbgs() << "masked gathers/scatters: simplified loop variable add/mul\n"); 968 969 // The instruction has now been "absorbed" into the phi value 970 Offs->replaceAllUsesWith(NewPhi); 971 if (Offs->hasNUses(0)) 972 Offs->eraseFromParent(); 973 // Clean up the old increment in case it's unused because we built a new 974 // one 975 if (IncInstruction->hasNUses(0)) 976 IncInstruction->eraseFromParent(); 977 978 return true; 979 } 980 981 bool MVEGatherScatterLowering::runOnFunction(Function &F) { 982 if (!EnableMaskedGatherScatters) 983 return false; 984 auto &TPC = getAnalysis<TargetPassConfig>(); 985 auto &TM = TPC.getTM<TargetMachine>(); 986 auto *ST = &TM.getSubtarget<ARMSubtarget>(F); 987 if (!ST->hasMVEIntegerOps()) 988 return false; 989 LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); 990 SmallVector<IntrinsicInst *, 4> Gathers; 991 SmallVector<IntrinsicInst *, 4> Scatters; 992 993 bool Changed = false; 994 995 for (BasicBlock &BB : F) { 996 for (Instruction &I : BB) { 997 IntrinsicInst *II = dyn_cast<IntrinsicInst>(&I); 998 if (II && II->getIntrinsicID() == Intrinsic::masked_gather) { 999 Gathers.push_back(II); 1000 if (isa<GetElementPtrInst>(II->getArgOperand(0))) 1001 Changed |= optimiseOffsets( 1002 cast<Instruction>(II->getArgOperand(0))->getOperand(1), 1003 II->getParent(), LI); 1004 } else if (II && II->getIntrinsicID() == Intrinsic::masked_scatter) { 1005 Scatters.push_back(II); 1006 if (isa<GetElementPtrInst>(II->getArgOperand(1))) 1007 Changed |= optimiseOffsets( 1008 cast<Instruction>(II->getArgOperand(1))->getOperand(1), 1009 II->getParent(), LI); 1010 } 1011 } 1012 } 1013 1014 for (unsigned i = 0; i < Gathers.size(); i++) { 1015 IntrinsicInst *I = Gathers[i]; 1016 Value *L = lowerGather(I); 1017 if (L == nullptr) 1018 continue; 1019 1020 // Get rid of any now dead instructions 1021 SimplifyInstructionsInBlock(cast<Instruction>(L)->getParent()); 1022 Changed = true; 1023 } 1024 1025 for (unsigned i = 0; i < Scatters.size(); i++) { 1026 IntrinsicInst *I = Scatters[i]; 1027 Value *S = lowerScatter(I); 1028 if (S == nullptr) 1029 continue; 1030 1031 // Get rid of any now dead instructions 1032 SimplifyInstructionsInBlock(cast<Instruction>(S)->getParent()); 1033 Changed = true; 1034 } 1035 return Changed; 1036 } 1037