//===- MVEGatherScatterLowering.cpp - Gather/Scatter lowering -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// This pass custom lowers llvm.masked.gather and llvm.masked.scatter
/// intrinsics to arm.mve.gather and arm.mve.scatter intrinsics, optimising the
/// code to produce a better final result as we go.
//
//===----------------------------------------------------------------------===//

#include "ARM.h"
#include "ARMBaseInstrInfo.h"
#include "ARMSubtarget.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/CodeGen/TargetLowering.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/CodeGen/TargetSubtargetInfo.h"
#include "llvm/InitializePasses.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/IntrinsicsARM.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include "llvm/Transforms/Utils/Local.h"
#include <algorithm>
#include <cassert>

using namespace llvm;

#define DEBUG_TYPE "arm-mve-gather-scatter-lowering"

cl::opt<bool> EnableMaskedGatherScatters(
    "enable-arm-maskedgatscat", cl::Hidden, cl::init(true),
    cl::desc("Enable the generation of masked gathers and scatters"));

namespace {

class MVEGatherScatterLowering : public FunctionPass {
public:
  static char ID; // Pass identification, replacement for typeid

  explicit MVEGatherScatterLowering() : FunctionPass(ID) {
    initializeMVEGatherScatterLoweringPass(*PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override;

  StringRef getPassName() const override {
    return "MVE gather/scatter lowering";
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesCFG();
    AU.addRequired<TargetPassConfig>();
    AU.addRequired<LoopInfoWrapperPass>();
    FunctionPass::getAnalysisUsage(AU);
  }

private:
  LoopInfo *LI = nullptr;

  // Check this is a valid gather with correct alignment
  bool isLegalTypeAndAlignment(unsigned NumElements, unsigned ElemSize,
                               Align Alignment);
  // Check whether Ptr is hidden behind a bitcast and look through it
  void lookThroughBitcast(Value *&Ptr);
  // Decompose a ptr into Base and Offsets, potentially using a GEP to return a
  // scalar base and vector offsets, or else fallback to using a base of 0 and
  // offset of Ptr where possible.
  Value *decomposePtr(Value *Ptr, Value *&Offsets, int &Scale,
                      FixedVectorType *Ty, Type *MemoryTy,
                      IRBuilder<> &Builder);
  // Check for a getelementptr and deduce base and offsets from it, on success
  // returning the base directly and the offsets indirectly using the Offsets
  // argument
  Value *decomposeGEP(Value *&Offsets, FixedVectorType *Ty,
                      GetElementPtrInst *GEP, IRBuilder<> &Builder);
  // Compute the scale of this gather/scatter instruction
  int computeScale(unsigned GEPElemSize, unsigned MemoryElemSize);
  // If the value is a constant, or derived from constants via additions
  // and multiplications, return its numeric value
  Optional<int64_t> getIfConst(const Value *V);
  // If Inst is an add instruction, check whether one summand is a
  // constant. If so, scale this constant and return it together with
  // the other summand.
  std::pair<Value *, int64_t> getVarAndConst(Value *Inst, int TypeScale);

  Instruction *lowerGather(IntrinsicInst *I);
  // Create a gather from a base + vector of offsets
  Instruction *tryCreateMaskedGatherOffset(IntrinsicInst *I, Value *Ptr,
                                           Instruction *&Root,
                                           IRBuilder<> &Builder);
  // Create a gather from a vector of pointers
  Instruction *tryCreateMaskedGatherBase(IntrinsicInst *I, Value *Ptr,
                                         IRBuilder<> &Builder,
                                         int64_t Increment = 0);
  // Create an incrementing gather from a vector of pointers
  Instruction *tryCreateMaskedGatherBaseWB(IntrinsicInst *I, Value *Ptr,
                                           IRBuilder<> &Builder,
                                           int64_t Increment = 0);

  Instruction *lowerScatter(IntrinsicInst *I);
  // Create a scatter to a base + vector of offsets
  Instruction *tryCreateMaskedScatterOffset(IntrinsicInst *I, Value *Offsets,
                                            IRBuilder<> &Builder);
  // Create a scatter to a vector of pointers
  Instruction *tryCreateMaskedScatterBase(IntrinsicInst *I, Value *Ptr,
                                          IRBuilder<> &Builder,
                                          int64_t Increment = 0);
  // Create an incrementing scatter from a vector of pointers
  Instruction *tryCreateMaskedScatterBaseWB(IntrinsicInst *I, Value *Ptr,
                                            IRBuilder<> &Builder,
                                            int64_t Increment = 0);

  // QI gathers and scatters can increment their offsets on their own if
  // the increment is a constant value (digit)
  Instruction *tryCreateIncrementingGatScat(IntrinsicInst *I, Value *Ptr,
                                            IRBuilder<> &Builder);
  // QI gathers/scatters can increment their offsets on their own if the
  // increment is a constant value (digit) - this creates a writeback QI
  // gather/scatter
  Instruction *tryCreateIncrementingWBGatScat(IntrinsicInst *I, Value *BasePtr,
                                              Value *Ptr, unsigned TypeScale,
                                              IRBuilder<> &Builder);

  // Optimise the base and offsets of the given address
  bool optimiseAddress(Value *Address, BasicBlock *BB, LoopInfo *LI);
  // Try to fold consecutive geps together into one
  Value *foldGEP(GetElementPtrInst *GEP, Value *&Offsets, IRBuilder<> &Builder);
  // Check whether these offsets could be moved out of the loop they're in
  bool optimiseOffsets(Value *Offsets, BasicBlock *BB, LoopInfo *LI);
  // Pushes the given add out of the loop
  void pushOutAdd(PHINode *&Phi, Value *OffsSecondOperand, unsigned StartIndex);
  // Pushes the given mul out of the loop
  void pushOutMul(PHINode *&Phi, Value *IncrementPerRound,
                  Value *OffsSecondOperand, unsigned LoopIncrement,
                  IRBuilder<> &Builder);
};

} // end anonymous namespace

char MVEGatherScatterLowering::ID = 0;

INITIALIZE_PASS(MVEGatherScatterLowering, DEBUG_TYPE,
                "MVE gather/scattering lowering pass", false, false)

Pass *llvm::createMVEGatherScatterLoweringPass() {
  return new MVEGatherScatterLowering();
}

bool MVEGatherScatterLowering::isLegalTypeAndAlignment(unsigned NumElements,
                                                       unsigned ElemSize,
                                                       Align Alignment) {
  if (((NumElements == 4 &&
        (ElemSize == 32 || ElemSize == 16 || ElemSize == 8)) ||
       (NumElements == 8 && (ElemSize == 16 || ElemSize == 8)) ||
       (NumElements == 16 && ElemSize == 8)) &&
      Alignment >= ElemSize / 8)
    return true;
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: instruction does not have "
                    << "valid alignment or vector type \n");
  return false;
}

static bool checkOffsetSize(Value *Offsets, unsigned TargetElemCount) {
  // Offsets that are not of type <N x i32> are sign extended by the
  // getelementptr instruction, and MVE gathers/scatters treat the offset as
  // unsigned. Thus, if the element size is smaller than 32, we can only allow
  // positive offsets - i.e., the offsets are not allowed to be variables we
  // can't look into.
  // Additionally, <N x i32> offsets have to either originate from a zext of a
  // vector with element types smaller or equal the type of the gather we're
  // looking at, or consist of constants that we can check are small enough
  // to fit into the gather type.
  // Thus we check that 0 < value < 2^TargetElemSize.
  unsigned TargetElemSize = 128 / TargetElemCount;
  unsigned OffsetElemSize = cast<FixedVectorType>(Offsets->getType())
                                ->getElementType()
                                ->getScalarSizeInBits();
  if (OffsetElemSize != TargetElemSize || OffsetElemSize != 32) {
    Constant *ConstOff = dyn_cast<Constant>(Offsets);
    if (!ConstOff)
      return false;
    int64_t TargetElemMaxSize = (1ULL << TargetElemSize);
    auto CheckValueSize = [TargetElemMaxSize](Value *OffsetElem) {
      ConstantInt *OConst = dyn_cast<ConstantInt>(OffsetElem);
      if (!OConst)
        return false;
      int SExtValue = OConst->getSExtValue();
      if (SExtValue >= TargetElemMaxSize || SExtValue < 0)
        return false;
      return true;
    };
    if (isa<FixedVectorType>(ConstOff->getType())) {
      for (unsigned i = 0; i < TargetElemCount; i++) {
        if (!CheckValueSize(ConstOff->getAggregateElement(i)))
          return false;
      }
    } else {
      if (!CheckValueSize(ConstOff))
        return false;
    }
  }
  return true;
}

Value *MVEGatherScatterLowering::decomposePtr(Value *Ptr, Value *&Offsets,
                                              int &Scale, FixedVectorType *Ty,
                                              Type *MemoryTy,
                                              IRBuilder<> &Builder) {
  if (auto *GEP = dyn_cast<GetElementPtrInst>(Ptr)) {
    if (Value *V = decomposeGEP(Offsets, Ty, GEP, Builder)) {
      Scale =
          computeScale(GEP->getSourceElementType()->getPrimitiveSizeInBits(),
                       MemoryTy->getScalarSizeInBits());
      return Scale == -1 ? nullptr : V;
    }
  }

  // If we couldn't use the GEP (or it doesn't exist), attempt to use a
  // BasePtr of 0 with Ptr as the Offsets, so long as there are only 4
  // elements.
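  // (This fallback reinterprets the pointer lanes themselves as 32-bit
  // offsets from a zero base, so only the 4 x i32 offset form of the MVE
  // gather/scatter intrinsics can express it.)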
  FixedVectorType *PtrTy = cast<FixedVectorType>(Ptr->getType());
  if (PtrTy->getNumElements() != 4 || MemoryTy->getScalarSizeInBits() == 32)
    return nullptr;
  Value *Zero = ConstantInt::get(Builder.getInt32Ty(), 0);
  Value *BasePtr = Builder.CreateIntToPtr(Zero, Builder.getInt8PtrTy());
  Offsets = Builder.CreatePtrToInt(
      Ptr, FixedVectorType::get(Builder.getInt32Ty(), 4));
  Scale = 0;
  return BasePtr;
}

Value *MVEGatherScatterLowering::decomposeGEP(Value *&Offsets,
                                              FixedVectorType *Ty,
                                              GetElementPtrInst *GEP,
                                              IRBuilder<> &Builder) {
  if (!GEP) {
    LLVM_DEBUG(dbgs() << "masked gathers/scatters: no getelementpointer "
                      << "found\n");
    return nullptr;
  }
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: getelementpointer found."
                    << " Looking at intrinsic for base + vector of offsets\n");
  Value *GEPPtr = GEP->getPointerOperand();
  Offsets = GEP->getOperand(1);
  if (GEPPtr->getType()->isVectorTy() ||
      !isa<FixedVectorType>(Offsets->getType()))
    return nullptr;

  if (GEP->getNumOperands() != 2) {
    LLVM_DEBUG(dbgs() << "masked gathers/scatters: getelementptr with too many"
                      << " operands. Expanding.\n");
    return nullptr;
  }
  Offsets = GEP->getOperand(1);
  unsigned OffsetsElemCount =
      cast<FixedVectorType>(Offsets->getType())->getNumElements();
  // Paranoid check whether the number of parallel lanes is the same
  assert(Ty->getNumElements() == OffsetsElemCount);

  ZExtInst *ZextOffs = dyn_cast<ZExtInst>(Offsets);
  if (ZextOffs)
    Offsets = ZextOffs->getOperand(0);
  FixedVectorType *OffsetType = cast<FixedVectorType>(Offsets->getType());

  // If the offsets are already being zext-ed to <N x i32>, that relieves us of
  // having to make sure that they won't overflow.
  if (!ZextOffs || cast<FixedVectorType>(ZextOffs->getDestTy())
                           ->getElementType()
                           ->getScalarSizeInBits() != 32)
    if (!checkOffsetSize(Offsets, OffsetsElemCount))
      return nullptr;

  // The offset sizes have been checked; if any truncating or zext-ing is
  // required to fix them, do that now
  if (Ty != Offsets->getType()) {
    if ((Ty->getElementType()->getScalarSizeInBits() <
         OffsetType->getElementType()->getScalarSizeInBits())) {
      Offsets = Builder.CreateTrunc(Offsets, Ty);
    } else {
      Offsets = Builder.CreateZExt(Offsets, VectorType::getInteger(Ty));
    }
  }
  // If none of the checks failed, return the gep's base pointer
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: found correct offsets\n");
  return GEPPtr;
}

void MVEGatherScatterLowering::lookThroughBitcast(Value *&Ptr) {
  // Look through bitcast instruction if #elements is the same
  if (auto *BitCast = dyn_cast<BitCastInst>(Ptr)) {
    auto *BCTy = cast<FixedVectorType>(BitCast->getType());
    auto *BCSrcTy = cast<FixedVectorType>(BitCast->getOperand(0)->getType());
    if (BCTy->getNumElements() == BCSrcTy->getNumElements()) {
      LLVM_DEBUG(dbgs() << "masked gathers/scatters: looking through "
                        << "bitcast\n");
      Ptr = BitCast->getOperand(0);
    }
  }
}

int MVEGatherScatterLowering::computeScale(unsigned GEPElemSize,
                                           unsigned MemoryElemSize) {
  // This can be a 32bit load/store scaled by 4, a 16bit load/store scaled by
  // 2, or an 8bit, 16bit or 32bit load/store scaled by 1
  if (GEPElemSize == 32 && MemoryElemSize == 32)
    return 2;
  else if (GEPElemSize == 16 && MemoryElemSize == 16)
    return 1;
  else if (GEPElemSize == 8)
    return 0;
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: incorrect scale. Can't "
                    << "create intrinsic\n");
  return -1;
}

Optional<int64_t> MVEGatherScatterLowering::getIfConst(const Value *V) {
  const Constant *C = dyn_cast<Constant>(V);
  if (C != nullptr)
    return Optional<int64_t>{C->getUniqueInteger().getSExtValue()};
  if (!isa<Instruction>(V))
    return Optional<int64_t>{};

  const Instruction *I = cast<Instruction>(V);
  if (I->getOpcode() == Instruction::Add ||
      I->getOpcode() == Instruction::Mul) {
    Optional<int64_t> Op0 = getIfConst(I->getOperand(0));
    Optional<int64_t> Op1 = getIfConst(I->getOperand(1));
    if (!Op0 || !Op1)
      return Optional<int64_t>{};
    if (I->getOpcode() == Instruction::Add)
      return Optional<int64_t>{Op0.getValue() + Op1.getValue()};
    if (I->getOpcode() == Instruction::Mul)
      return Optional<int64_t>{Op0.getValue() * Op1.getValue()};
  }
  return Optional<int64_t>{};
}

std::pair<Value *, int64_t>
MVEGatherScatterLowering::getVarAndConst(Value *Inst, int TypeScale) {
  std::pair<Value *, int64_t> ReturnFalse =
      std::pair<Value *, int64_t>(nullptr, 0);
  // At this point, the instruction we're looking at must be an add or we
  // bail out
  Instruction *Add = dyn_cast<Instruction>(Inst);
  if (Add == nullptr || Add->getOpcode() != Instruction::Add)
    return ReturnFalse;

  Value *Summand;
  Optional<int64_t> Const;
  // Find out which operand the value that is increased is
  if ((Const = getIfConst(Add->getOperand(0))))
    Summand = Add->getOperand(1);
  else if ((Const = getIfConst(Add->getOperand(1))))
    Summand = Add->getOperand(0);
  else
    return ReturnFalse;

  // Check that the constant is small enough for an incrementing gather
  int64_t Immediate = Const.getValue() << TypeScale;
  if (Immediate > 512 || Immediate < -512 || Immediate % 4 != 0)
    return ReturnFalse;

  return std::pair<Value *, int64_t>(Summand, Immediate);
}

Instruction *MVEGatherScatterLowering::lowerGather(IntrinsicInst *I) {
  using namespace PatternMatch;
  LLVM_DEBUG(dbgs() << "masked gathers: checking transform preconditions\n"
                    << *I << "\n");

  // @llvm.masked.gather.*(Ptrs, alignment, Mask, Src0)
  // Attempt to turn the masked gather in I into an MVE intrinsic,
  // potentially optimising the addressing modes as we do so.
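  // The lowerings below are attempted in order of preference: an incrementing
  // (QI) gather, then a base + vector of offsets gather, and finally a plain
  // gather from a vector of pointers.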
  auto *Ty = cast<FixedVectorType>(I->getType());
  Value *Ptr = I->getArgOperand(0);
  Align Alignment = cast<ConstantInt>(I->getArgOperand(1))->getAlignValue();
  Value *Mask = I->getArgOperand(2);
  Value *PassThru = I->getArgOperand(3);

  if (!isLegalTypeAndAlignment(Ty->getNumElements(), Ty->getScalarSizeInBits(),
                               Alignment))
    return nullptr;
  lookThroughBitcast(Ptr);
  assert(Ptr->getType()->isVectorTy() && "Unexpected pointer type");

  IRBuilder<> Builder(I->getContext());
  Builder.SetInsertPoint(I);
  Builder.SetCurrentDebugLocation(I->getDebugLoc());

  Instruction *Root = I;

  Instruction *Load = tryCreateIncrementingGatScat(I, Ptr, Builder);
  if (!Load)
    Load = tryCreateMaskedGatherOffset(I, Ptr, Root, Builder);
  if (!Load)
    Load = tryCreateMaskedGatherBase(I, Ptr, Builder);
  if (!Load)
    return nullptr;

  if (!isa<UndefValue>(PassThru) && !match(PassThru, m_Zero())) {
    LLVM_DEBUG(dbgs() << "masked gathers: found non-trivial passthru - "
                      << "creating select\n");
    Load = SelectInst::Create(Mask, Load, PassThru);
    Builder.Insert(Load);
  }

  Root->replaceAllUsesWith(Load);
  Root->eraseFromParent();
  if (Root != I)
    // If this was an extending gather, we need to get rid of the sext/zext
    // as well as of the gather itself
    I->eraseFromParent();

  LLVM_DEBUG(dbgs() << "masked gathers: successfully built masked gather\n"
                    << *Load << "\n");
  return Load;
}

Instruction *MVEGatherScatterLowering::tryCreateMaskedGatherBase(
    IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) {
  using namespace PatternMatch;
  auto *Ty = cast<FixedVectorType>(I->getType());
  LLVM_DEBUG(dbgs() << "masked gathers: loading from vector of pointers\n");
  if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32)
    // Can't build an intrinsic for this
    return nullptr;
  Value *Mask = I->getArgOperand(2);
  if (match(Mask, m_One()))
    return Builder.CreateIntrinsic(Intrinsic::arm_mve_vldr_gather_base,
                                   {Ty, Ptr->getType()},
                                   {Ptr, Builder.getInt32(Increment)});
  else
    return Builder.CreateIntrinsic(
        Intrinsic::arm_mve_vldr_gather_base_predicated,
        {Ty, Ptr->getType(), Mask->getType()},
        {Ptr, Builder.getInt32(Increment), Mask});
}

Instruction *MVEGatherScatterLowering::tryCreateMaskedGatherBaseWB(
    IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) {
  using namespace PatternMatch;
  auto *Ty = cast<FixedVectorType>(I->getType());
  LLVM_DEBUG(dbgs() << "masked gathers: loading from vector of pointers with "
                    << "writeback\n");
  if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32)
    // Can't build an intrinsic for this
    return nullptr;
  Value *Mask = I->getArgOperand(2);
  if (match(Mask, m_One()))
    return Builder.CreateIntrinsic(Intrinsic::arm_mve_vldr_gather_base_wb,
                                   {Ty, Ptr->getType()},
                                   {Ptr, Builder.getInt32(Increment)});
  else
    return Builder.CreateIntrinsic(
        Intrinsic::arm_mve_vldr_gather_base_wb_predicated,
        {Ty, Ptr->getType(), Mask->getType()},
        {Ptr, Builder.getInt32(Increment), Mask});
}

Instruction *MVEGatherScatterLowering::tryCreateMaskedGatherOffset(
    IntrinsicInst *I, Value *Ptr, Instruction *&Root, IRBuilder<> &Builder) {
  using namespace PatternMatch;

  Type *MemoryTy = I->getType();
  Type *ResultTy = MemoryTy;

  unsigned Unsigned = 1;
  // The size of the gather was already checked in isLegalTypeAndAlignment;
  // if it was not a full vector width an appropriate extend should follow.
  auto *Extend = Root;
  bool TruncResult = false;
  if (MemoryTy->getPrimitiveSizeInBits() < 128) {
    if (I->hasOneUse()) {
      // If the gather has a single extend of the correct type, use an
      // extending gather and replace the ext, in which case the correct root
      // to replace is not the CallInst itself, but the instruction which
      // extends it.
      Instruction *User = cast<Instruction>(*I->users().begin());
      if (isa<SExtInst>(User) &&
          User->getType()->getPrimitiveSizeInBits() == 128) {
        LLVM_DEBUG(dbgs() << "masked gathers: Incorporating extend: " << *User
                          << "\n");
        Extend = User;
        ResultTy = User->getType();
        Unsigned = 0;
      } else if (isa<ZExtInst>(User) &&
                 User->getType()->getPrimitiveSizeInBits() == 128) {
        LLVM_DEBUG(dbgs() << "masked gathers: Incorporating extend: "
                          << *ResultTy << "\n");
        Extend = User;
        ResultTy = User->getType();
      }
    }

    // If an extend hasn't been found and the type is an integer, create an
    // extending gather and truncate back to the original type.
    if (ResultTy->getPrimitiveSizeInBits() < 128 &&
        ResultTy->isIntOrIntVectorTy()) {
      ResultTy = ResultTy->getWithNewBitWidth(
          128 / cast<FixedVectorType>(ResultTy)->getNumElements());
      TruncResult = true;
      LLVM_DEBUG(dbgs() << "masked gathers: Small input type, truncing to: "
                        << *ResultTy << "\n");
    }

    // The final size of the gather must be a full vector width
    if (ResultTy->getPrimitiveSizeInBits() != 128) {
      LLVM_DEBUG(dbgs() << "masked gathers: Extend needed but not provided "
                           "from the correct type. Expanding\n");
      return nullptr;
    }
  }

  Value *Offsets;
  int Scale;
  Value *BasePtr = decomposePtr(
      Ptr, Offsets, Scale, cast<FixedVectorType>(ResultTy), MemoryTy, Builder);
  if (!BasePtr)
    return nullptr;

  Root = Extend;
  Value *Mask = I->getArgOperand(2);
  Instruction *Load = nullptr;
  if (!match(Mask, m_One()))
    Load = Builder.CreateIntrinsic(
        Intrinsic::arm_mve_vldr_gather_offset_predicated,
        {ResultTy, BasePtr->getType(), Offsets->getType(), Mask->getType()},
        {BasePtr, Offsets, Builder.getInt32(MemoryTy->getScalarSizeInBits()),
         Builder.getInt32(Scale), Builder.getInt32(Unsigned), Mask});
  else
    Load = Builder.CreateIntrinsic(
        Intrinsic::arm_mve_vldr_gather_offset,
        {ResultTy, BasePtr->getType(), Offsets->getType()},
        {BasePtr, Offsets, Builder.getInt32(MemoryTy->getScalarSizeInBits()),
         Builder.getInt32(Scale), Builder.getInt32(Unsigned)});

  if (TruncResult) {
    Load = TruncInst::Create(Instruction::Trunc, Load, MemoryTy);
    Builder.Insert(Load);
  }
  return Load;
}

Instruction *MVEGatherScatterLowering::lowerScatter(IntrinsicInst *I) {
  using namespace PatternMatch;
  LLVM_DEBUG(dbgs() << "masked scatters: checking transform preconditions\n"
                    << *I << "\n");

  // @llvm.masked.scatter.*(data, ptrs, alignment, mask)
  // Attempt to turn the masked scatter in I into an MVE intrinsic,
  // potentially optimising the addressing modes as we do so.
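  // As with gathers, the scatter forms below are tried in order of
  // preference: incrementing (QI), then base + vector of offsets, then a
  // plain vector of pointers.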
  Value *Input = I->getArgOperand(0);
  Value *Ptr = I->getArgOperand(1);
  Align Alignment = cast<ConstantInt>(I->getArgOperand(2))->getAlignValue();
  auto *Ty = cast<FixedVectorType>(Input->getType());

  if (!isLegalTypeAndAlignment(Ty->getNumElements(), Ty->getScalarSizeInBits(),
                               Alignment))
    return nullptr;

  lookThroughBitcast(Ptr);
  assert(Ptr->getType()->isVectorTy() && "Unexpected pointer type");

  IRBuilder<> Builder(I->getContext());
  Builder.SetInsertPoint(I);
  Builder.SetCurrentDebugLocation(I->getDebugLoc());

  Instruction *Store = tryCreateIncrementingGatScat(I, Ptr, Builder);
  if (!Store)
    Store = tryCreateMaskedScatterOffset(I, Ptr, Builder);
  if (!Store)
    Store = tryCreateMaskedScatterBase(I, Ptr, Builder);
  if (!Store)
    return nullptr;

  LLVM_DEBUG(dbgs() << "masked scatters: successfully built masked scatter\n"
                    << *Store << "\n");
  I->eraseFromParent();
  return Store;
}

Instruction *MVEGatherScatterLowering::tryCreateMaskedScatterBase(
    IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) {
  using namespace PatternMatch;
  Value *Input = I->getArgOperand(0);
  auto *Ty = cast<FixedVectorType>(Input->getType());
  // Only QR variants allow truncating
  if (!(Ty->getNumElements() == 4 && Ty->getScalarSizeInBits() == 32)) {
    // Can't build an intrinsic for this
    return nullptr;
  }
  Value *Mask = I->getArgOperand(3);
  // int_arm_mve_vstr_scatter_base(_predicated) addr, offset, data(, mask)
  LLVM_DEBUG(dbgs() << "masked scatters: storing to a vector of pointers\n");
  if (match(Mask, m_One()))
    return Builder.CreateIntrinsic(Intrinsic::arm_mve_vstr_scatter_base,
                                   {Ptr->getType(), Input->getType()},
                                   {Ptr, Builder.getInt32(Increment), Input});
  else
    return Builder.CreateIntrinsic(
        Intrinsic::arm_mve_vstr_scatter_base_predicated,
        {Ptr->getType(), Input->getType(), Mask->getType()},
        {Ptr, Builder.getInt32(Increment), Input, Mask});
}

Instruction *MVEGatherScatterLowering::tryCreateMaskedScatterBaseWB(
    IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) {
  using namespace PatternMatch;
  Value *Input = I->getArgOperand(0);
  auto *Ty = cast<FixedVectorType>(Input->getType());
  LLVM_DEBUG(dbgs() << "masked scatters: storing to a vector of pointers "
                    << "with writeback\n");
  if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32)
    // Can't build an intrinsic for this
    return nullptr;
  Value *Mask = I->getArgOperand(3);
  if (match(Mask, m_One()))
    return Builder.CreateIntrinsic(Intrinsic::arm_mve_vstr_scatter_base_wb,
                                   {Ptr->getType(), Input->getType()},
                                   {Ptr, Builder.getInt32(Increment), Input});
  else
    return Builder.CreateIntrinsic(
        Intrinsic::arm_mve_vstr_scatter_base_wb_predicated,
        {Ptr->getType(), Input->getType(), Mask->getType()},
        {Ptr, Builder.getInt32(Increment), Input, Mask});
}

Instruction *MVEGatherScatterLowering::tryCreateMaskedScatterOffset(
    IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder) {
  using namespace PatternMatch;
  Value *Input = I->getArgOperand(0);
  Value *Mask = I->getArgOperand(3);
  Type *InputTy = Input->getType();
  Type *MemoryTy = InputTy;

  LLVM_DEBUG(dbgs() << "masked scatters: getelementpointer found. Storing"
                    << " to base + vector of offsets\n");
  // If the input has been truncated, try to integrate that trunc into the
  // scatter instruction (we don't care about alignment here)
  if (TruncInst *Trunc = dyn_cast<TruncInst>(Input)) {
    Value *PreTrunc = Trunc->getOperand(0);
    Type *PreTruncTy = PreTrunc->getType();
    if (PreTruncTy->getPrimitiveSizeInBits() == 128) {
      Input = PreTrunc;
      InputTy = PreTruncTy;
    }
  }
  bool ExtendInput = false;
  if (InputTy->getPrimitiveSizeInBits() < 128 &&
      InputTy->isIntOrIntVectorTy()) {
    // If we can't find a trunc to incorporate into the instruction, create an
    // implicit one with a zext, so that we can still create a scatter. We know
    // that the input type is 4x/8x/16x and of type i8/i16/i32, so any type
    // smaller than 128 bits will divide evenly into a 128bit vector.
    InputTy = InputTy->getWithNewBitWidth(
        128 / cast<FixedVectorType>(InputTy)->getNumElements());
    ExtendInput = true;
    LLVM_DEBUG(dbgs() << "masked scatters: Small input type, will extend:\n"
                      << *Input << "\n");
  }
  if (InputTy->getPrimitiveSizeInBits() != 128) {
    LLVM_DEBUG(dbgs() << "masked scatters: cannot create scatters for "
                         "non-standard input types. Expanding.\n");
    return nullptr;
  }

  Value *Offsets;
  int Scale;
  Value *BasePtr = decomposePtr(
      Ptr, Offsets, Scale, cast<FixedVectorType>(InputTy), MemoryTy, Builder);
  if (!BasePtr)
    return nullptr;

  if (ExtendInput)
    Input = Builder.CreateZExt(Input, InputTy);
  if (!match(Mask, m_One()))
    return Builder.CreateIntrinsic(
        Intrinsic::arm_mve_vstr_scatter_offset_predicated,
        {BasePtr->getType(), Offsets->getType(), Input->getType(),
         Mask->getType()},
        {BasePtr, Offsets, Input,
         Builder.getInt32(MemoryTy->getScalarSizeInBits()),
         Builder.getInt32(Scale), Mask});
  else
    return Builder.CreateIntrinsic(
        Intrinsic::arm_mve_vstr_scatter_offset,
        {BasePtr->getType(), Offsets->getType(), Input->getType()},
        {BasePtr, Offsets, Input,
         Builder.getInt32(MemoryTy->getScalarSizeInBits()),
         Builder.getInt32(Scale)});
}

Instruction *MVEGatherScatterLowering::tryCreateIncrementingGatScat(
    IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder) {
  FixedVectorType *Ty;
  if (I->getIntrinsicID() == Intrinsic::masked_gather)
    Ty = cast<FixedVectorType>(I->getType());
  else
    Ty = cast<FixedVectorType>(I->getArgOperand(0)->getType());

  // Incrementing gathers only exist for v4i32
  if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32)
    return nullptr;
  // Incrementing gathers are not beneficial outside of a loop
  Loop *L = LI->getLoopFor(I->getParent());
  if (L == nullptr)
    return nullptr;

  // Decompose the GEP into Base and Offsets
  GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr);
  Value *Offsets;
  Value *BasePtr = decomposeGEP(Offsets, Ty, GEP, Builder);
  if (!BasePtr)
    return nullptr;

  LLVM_DEBUG(dbgs() << "masked gathers/scatters: trying to build incrementing "
                       "wb gather/scatter\n");

  // The gep was in charge of making sure the offsets are scaled correctly
  // - calculate that factor so it can be applied by hand
  DataLayout DT = I->getParent()->getParent()->getParent()->getDataLayout();
  int TypeScale =
      computeScale(DT.getTypeSizeInBits(GEP->getOperand(0)->getType()),
                   DT.getTypeSizeInBits(GEP->getType()) /
                       cast<FixedVectorType>(GEP->getType())->getNumElements());
  if (TypeScale == -1)
    return nullptr;

  if (GEP->hasOneUse()) {
    // Only in this case do we want to build a wb gather, because the wb will
    // change the phi which does affect other users of the gep (which will
    // still be using the phi in the old way)
    if (auto *Load = tryCreateIncrementingWBGatScat(I, BasePtr, Offsets,
                                                    TypeScale, Builder))
      return Load;
  }

  LLVM_DEBUG(dbgs() << "masked gathers/scatters: trying to build incrementing "
                       "non-wb gather/scatter\n");

  std::pair<Value *, int64_t> Add = getVarAndConst(Offsets, TypeScale);
  if (Add.first == nullptr)
    return nullptr;
  Value *OffsetsIncoming = Add.first;
  int64_t Immediate = Add.second;

  // Make sure the offsets are scaled correctly
  Instruction *ScaledOffsets = BinaryOperator::Create(
      Instruction::Shl, OffsetsIncoming,
      Builder.CreateVectorSplat(Ty->getNumElements(),
                                Builder.getInt32(TypeScale)),
      "ScaledIndex", I);
  // Add the base to the offsets
  OffsetsIncoming = BinaryOperator::Create(
      Instruction::Add, ScaledOffsets,
      Builder.CreateVectorSplat(
          Ty->getNumElements(),
          Builder.CreatePtrToInt(
              BasePtr,
              cast<VectorType>(ScaledOffsets->getType())->getElementType())),
      "StartIndex", I);

  if (I->getIntrinsicID() == Intrinsic::masked_gather)
    return tryCreateMaskedGatherBase(I, OffsetsIncoming, Builder, Immediate);
  else
    return tryCreateMaskedScatterBase(I, OffsetsIncoming, Builder, Immediate);
}

Instruction *MVEGatherScatterLowering::tryCreateIncrementingWBGatScat(
    IntrinsicInst *I, Value *BasePtr, Value *Offsets, unsigned TypeScale,
    IRBuilder<> &Builder) {
  // Check whether this gather's offset is incremented by a constant - if so,
  // and the load is of the right type, we can merge this into a QI gather
  Loop *L = LI->getLoopFor(I->getParent());
  // Offsets that are worth merging into this instruction will be incremented
  // by a constant, thus we're looking for an add of a phi and a constant
  PHINode *Phi = dyn_cast<PHINode>(Offsets);
  if (Phi == nullptr || Phi->getNumIncomingValues() != 2 ||
      Phi->getParent() != L->getHeader() || Phi->getNumUses() != 2)
    // No phi means no IV to write back to; if there is a phi, we expect it
    // to have exactly two incoming values; the only phis we are interested in
    // will be loop IV's and have exactly two uses, one in their increment and
    // one in the gather's gep
    return nullptr;

  unsigned IncrementIndex =
      Phi->getIncomingBlock(0) == L->getLoopLatch() ? 0 : 1;
  // Look through the phi to the phi increment
  Offsets = Phi->getIncomingValue(IncrementIndex);

  std::pair<Value *, int64_t> Add = getVarAndConst(Offsets, TypeScale);
  if (Add.first == nullptr)
    return nullptr;
  Value *OffsetsIncoming = Add.first;
  int64_t Immediate = Add.second;
  if (OffsetsIncoming != Phi)
    // Then the increment we are looking at is not an increment of the
    // induction variable, and we don't want to do a writeback
    return nullptr;

  Builder.SetInsertPoint(&Phi->getIncomingBlock(1 - IncrementIndex)->back());
  unsigned NumElems =
      cast<FixedVectorType>(OffsetsIncoming->getType())->getNumElements();

  // Make sure the offsets are scaled correctly
  Instruction *ScaledOffsets = BinaryOperator::Create(
      Instruction::Shl, Phi->getIncomingValue(1 - IncrementIndex),
      Builder.CreateVectorSplat(NumElems, Builder.getInt32(TypeScale)),
      "ScaledIndex", &Phi->getIncomingBlock(1 - IncrementIndex)->back());
  // Add the base to the offsets
  OffsetsIncoming = BinaryOperator::Create(
      Instruction::Add, ScaledOffsets,
      Builder.CreateVectorSplat(
          NumElems,
          Builder.CreatePtrToInt(
              BasePtr,
              cast<VectorType>(ScaledOffsets->getType())->getElementType())),
      "StartIndex", &Phi->getIncomingBlock(1 - IncrementIndex)->back());
  // The gather is pre-incrementing
  OffsetsIncoming = BinaryOperator::Create(
      Instruction::Sub, OffsetsIncoming,
      Builder.CreateVectorSplat(NumElems, Builder.getInt32(Immediate)),
      "PreIncrementStartIndex",
      &Phi->getIncomingBlock(1 - IncrementIndex)->back());
  Phi->setIncomingValue(1 - IncrementIndex, OffsetsIncoming);

  Builder.SetInsertPoint(I);

  Instruction *EndResult;
  Instruction *NewInduction;
  if (I->getIntrinsicID() == Intrinsic::masked_gather) {
    // Build the incrementing gather
    Value *Load = tryCreateMaskedGatherBaseWB(I, Phi, Builder, Immediate);
    // One value to be handed to whoever uses the gather, one is the loop
    // increment
    EndResult = ExtractValueInst::Create(Load, 0, "Gather");
    NewInduction = ExtractValueInst::Create(Load, 1, "GatherIncrement");
    Builder.Insert(EndResult);
    Builder.Insert(NewInduction);
  } else {
    // Build the incrementing scatter
    EndResult = NewInduction =
        tryCreateMaskedScatterBaseWB(I, Phi, Builder, Immediate);
  }
  Instruction *AddInst = cast<Instruction>(Offsets);
  AddInst->replaceAllUsesWith(NewInduction);
  AddInst->eraseFromParent();
  Phi->setIncomingValue(IncrementIndex, NewInduction);

  return EndResult;
}

void MVEGatherScatterLowering::pushOutAdd(PHINode *&Phi,
                                          Value *OffsSecondOperand,
                                          unsigned StartIndex) {
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: optimising add instruction\n");
  Instruction *InsertionPoint =
      &cast<Instruction>(Phi->getIncomingBlock(StartIndex)->back());
  // Initialize the phi with a vector that contains a sum of the constants
  Instruction *NewIndex = BinaryOperator::Create(
      Instruction::Add, Phi->getIncomingValue(StartIndex), OffsSecondOperand,
      "PushedOutAdd", InsertionPoint);
  unsigned IncrementIndex = StartIndex == 0 ? 1 : 0;

  // Order such that start index comes first (this reduces mov's)
  Phi->addIncoming(NewIndex, Phi->getIncomingBlock(StartIndex));
  Phi->addIncoming(Phi->getIncomingValue(IncrementIndex),
                   Phi->getIncomingBlock(IncrementIndex));
  Phi->removeIncomingValue(IncrementIndex);
  Phi->removeIncomingValue(StartIndex);
}

void MVEGatherScatterLowering::pushOutMul(PHINode *&Phi,
                                          Value *IncrementPerRound,
                                          Value *OffsSecondOperand,
                                          unsigned LoopIncrement,
                                          IRBuilder<> &Builder) {
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: optimising mul instruction\n");

  // Create a new scalar add outside of the loop and transform it to a splat
  // by which loop variable can be incremented
  Instruction *InsertionPoint = &cast<Instruction>(
      Phi->getIncomingBlock(LoopIncrement == 1 ? 0 : 1)->back());

  // Create a new index
  Value *StartIndex = BinaryOperator::Create(
      Instruction::Mul, Phi->getIncomingValue(LoopIncrement == 1 ? 0 : 1),
      OffsSecondOperand, "PushedOutMul", InsertionPoint);

  Instruction *Product =
      BinaryOperator::Create(Instruction::Mul, IncrementPerRound,
                             OffsSecondOperand, "Product", InsertionPoint);
  // Increment NewIndex by Product instead of the multiplication
  Instruction *NewIncrement = BinaryOperator::Create(
      Instruction::Add, Phi, Product, "IncrementPushedOutMul",
      cast<Instruction>(Phi->getIncomingBlock(LoopIncrement)->back())
          .getPrevNode());

  Phi->addIncoming(StartIndex,
                   Phi->getIncomingBlock(LoopIncrement == 1 ? 0 : 1));
  Phi->addIncoming(NewIncrement, Phi->getIncomingBlock(LoopIncrement));
  Phi->removeIncomingValue((unsigned)0);
  Phi->removeIncomingValue((unsigned)0);
}

// Check whether all usages of this instruction are as offsets of
// gathers/scatters or simple arithmetic only used by gathers/scatters
static bool hasAllGatScatUsers(Instruction *I) {
  if (I->hasNUses(0)) {
    return false;
  }
  bool Gatscat = true;
  for (User *U : I->users()) {
    if (!isa<Instruction>(U))
      return false;
    if (isa<GetElementPtrInst>(U) ||
        isGatherScatter(dyn_cast<IntrinsicInst>(U))) {
      return Gatscat;
    } else {
      unsigned OpCode = cast<Instruction>(U)->getOpcode();
      if ((OpCode == Instruction::Add || OpCode == Instruction::Mul) &&
          hasAllGatScatUsers(cast<Instruction>(U))) {
        continue;
      }
      return false;
    }
  }
  return Gatscat;
}

bool MVEGatherScatterLowering::optimiseOffsets(Value *Offsets, BasicBlock *BB,
                                               LoopInfo *LI) {
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: trying to optimize\n"
                    << *Offsets << "\n");
  // Optimise the addresses of gathers/scatters by moving invariant
  // calculations out of the loop
  if (!isa<Instruction>(Offsets))
    return false;
  Instruction *Offs = cast<Instruction>(Offsets);
  if (Offs->getOpcode() != Instruction::Add &&
      Offs->getOpcode() != Instruction::Mul)
    return false;
  Loop *L = LI->getLoopFor(BB);
  if (L == nullptr)
    return false;
  if (!Offs->hasOneUse()) {
    if (!hasAllGatScatUsers(Offs))
      return false;
  }

  // Find out which, if any, operand of the instruction
  // is a phi node
  PHINode *Phi;
  int OffsSecondOp;
  if (isa<PHINode>(Offs->getOperand(0))) {
    Phi = cast<PHINode>(Offs->getOperand(0));
    OffsSecondOp = 1;
  } else if (isa<PHINode>(Offs->getOperand(1))) {
    Phi = cast<PHINode>(Offs->getOperand(1));
    OffsSecondOp = 0;
  } else {
    bool Changed = false;
    if (isa<Instruction>(Offs->getOperand(0)) &&
        L->contains(cast<Instruction>(Offs->getOperand(0))))
      Changed |= optimiseOffsets(Offs->getOperand(0), BB, LI);
    if (isa<Instruction>(Offs->getOperand(1)) &&
        L->contains(cast<Instruction>(Offs->getOperand(1))))
      Changed |= optimiseOffsets(Offs->getOperand(1), BB, LI);
    if (!Changed)
      return false;
    if (isa<PHINode>(Offs->getOperand(0))) {
      Phi = cast<PHINode>(Offs->getOperand(0));
      OffsSecondOp = 1;
    } else if (isa<PHINode>(Offs->getOperand(1))) {
      Phi = cast<PHINode>(Offs->getOperand(1));
      OffsSecondOp = 0;
    } else {
      return false;
    }
  }
  // A phi node we want to perform this function on should be from the
  // loop header.
  if (Phi->getParent() != L->getHeader())
    return false;

  // We're looking for a simple add recurrence.
  BinaryOperator *IncInstruction;
  Value *Start, *IncrementPerRound;
  if (!matchSimpleRecurrence(Phi, IncInstruction, Start, IncrementPerRound) ||
      IncInstruction->getOpcode() != Instruction::Add)
    return false;

  int IncrementingBlock = Phi->getIncomingValue(0) == IncInstruction ? 0 : 1;

  // Get the value that is added to/multiplied with the phi
  Value *OffsSecondOperand = Offs->getOperand(OffsSecondOp);

  if (IncrementPerRound->getType() != OffsSecondOperand->getType() ||
      !L->isLoopInvariant(OffsSecondOperand))
    // Something has gone wrong, abort
    return false;

  // Only proceed if the increment per round is a constant or an instruction
  // which does not originate from within the loop
  if (!isa<Constant>(IncrementPerRound) &&
      !(isa<Instruction>(IncrementPerRound) &&
        !L->contains(cast<Instruction>(IncrementPerRound))))
    return false;

  // If the phi is not used by anything else, we can just adapt it when
  // replacing the instruction; if it is, we'll have to duplicate it
  PHINode *NewPhi;
  if (Phi->getNumUses() == 2) {
    // No other users -> reuse existing phi (One user is the instruction
    // we're looking at, the other is the phi increment)
    if (IncInstruction->getNumUses() != 1) {
      // If the incrementing instruction does have more users than
      // our phi, we need to copy it
      IncInstruction = BinaryOperator::Create(
          Instruction::BinaryOps(IncInstruction->getOpcode()), Phi,
          IncrementPerRound, "LoopIncrement", IncInstruction);
      Phi->setIncomingValue(IncrementingBlock, IncInstruction);
    }
    NewPhi = Phi;
  } else {
    // There are other users -> create a new phi
    NewPhi = PHINode::Create(Phi->getType(), 2, "NewPhi", Phi);
    // Copy the incoming values of the old phi
    NewPhi->addIncoming(Phi->getIncomingValue(IncrementingBlock == 1 ? 0 : 1),
                        Phi->getIncomingBlock(IncrementingBlock == 1 ? 0 : 1));
    IncInstruction = BinaryOperator::Create(
        Instruction::BinaryOps(IncInstruction->getOpcode()), NewPhi,
        IncrementPerRound, "LoopIncrement", IncInstruction);
    NewPhi->addIncoming(IncInstruction,
                        Phi->getIncomingBlock(IncrementingBlock));
    IncrementingBlock = 1;
  }

  IRBuilder<> Builder(BB->getContext());
  Builder.SetInsertPoint(Phi);
  Builder.SetCurrentDebugLocation(Offs->getDebugLoc());

  switch (Offs->getOpcode()) {
  case Instruction::Add:
    pushOutAdd(NewPhi, OffsSecondOperand, IncrementingBlock == 1 ? 0 : 1);
    break;
  case Instruction::Mul:
    pushOutMul(NewPhi, IncrementPerRound, OffsSecondOperand, IncrementingBlock,
               Builder);
    break;
  default:
    return false;
  }
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: simplified loop variable "
                    << "add/mul\n");

  // The instruction has now been "absorbed" into the phi value
  Offs->replaceAllUsesWith(NewPhi);
  if (Offs->hasNUses(0))
    Offs->eraseFromParent();
  // Clean up the old increment in case it's unused because we built a new
  // one
  if (IncInstruction->hasNUses(0))
    IncInstruction->eraseFromParent();

  return true;
}

static Value *CheckAndCreateOffsetAdd(Value *X, Value *Y, Value *GEP,
                                      IRBuilder<> &Builder) {
  // Splat the non-vector value to a vector of the given type - if the value is
  // a constant (and its value isn't too big), we can even use this opportunity
  // to scale it to the size of the vector elements
  auto FixSummands = [&Builder](FixedVectorType *&VT, Value *&NonVectorVal) {
    ConstantInt *Const;
    if ((Const = dyn_cast<ConstantInt>(NonVectorVal)) &&
        VT->getElementType() != NonVectorVal->getType()) {
      unsigned TargetElemSize = VT->getElementType()->getPrimitiveSizeInBits();
      uint64_t N = Const->getZExtValue();
      if (N < (unsigned)(1 << (TargetElemSize - 1))) {
        NonVectorVal = Builder.CreateVectorSplat(
            VT->getNumElements(), Builder.getIntN(TargetElemSize, N));
        return;
      }
    }
    NonVectorVal =
        Builder.CreateVectorSplat(VT->getNumElements(), NonVectorVal);
  };

  FixedVectorType *XElType = dyn_cast<FixedVectorType>(X->getType());
  FixedVectorType *YElType = dyn_cast<FixedVectorType>(Y->getType());
  // If one of X, Y is not a vector, we have to splat it in order
  // to add the two of them.
  if (XElType && !YElType) {
    FixSummands(XElType, Y);
    YElType = cast<FixedVectorType>(Y->getType());
  } else if (YElType && !XElType) {
    FixSummands(YElType, X);
    XElType = cast<FixedVectorType>(X->getType());
  }
  assert(XElType && YElType && "Unknown vector types");
  // Check that the summands are of compatible types
  if (XElType != YElType) {
    LLVM_DEBUG(dbgs() << "masked gathers/scatters: incompatible gep offsets\n");
    return nullptr;
  }

  if (XElType->getElementType()->getScalarSizeInBits() != 32) {
    // Check that by adding the vectors we do not accidentally
    // create an overflow
    Constant *ConstX = dyn_cast<Constant>(X);
    Constant *ConstY = dyn_cast<Constant>(Y);
    if (!ConstX || !ConstY)
      return nullptr;
    unsigned TargetElemSize = 128 / XElType->getNumElements();
    for (unsigned i = 0; i < XElType->getNumElements(); i++) {
      ConstantInt *ConstXEl =
          dyn_cast<ConstantInt>(ConstX->getAggregateElement(i));
      ConstantInt *ConstYEl =
          dyn_cast<ConstantInt>(ConstY->getAggregateElement(i));
      if (!ConstXEl || !ConstYEl ||
          ConstXEl->getZExtValue() + ConstYEl->getZExtValue() >=
              (unsigned)(1 << (TargetElemSize - 1)))
        return nullptr;
    }
  }

  Value *Add = Builder.CreateAdd(X, Y);

  FixedVectorType *GEPType = cast<FixedVectorType>(GEP->getType());
  if (checkOffsetSize(Add, GEPType->getNumElements()))
    return Add;
  else
    return nullptr;
}

Value *MVEGatherScatterLowering::foldGEP(GetElementPtrInst *GEP,
                                         Value *&Offsets,
                                         IRBuilder<> &Builder) {
  Value *GEPPtr = GEP->getPointerOperand();
  Offsets = GEP->getOperand(1);
  // We only merge geps with constant offsets, because only for those
  // we can make sure that we do not cause an overflow
  if (!isa<Constant>(Offsets))
    return nullptr;
  GetElementPtrInst *BaseGEP;
  if ((BaseGEP = dyn_cast<GetElementPtrInst>(GEPPtr))) {
    // Merge the two geps into one
    Value *BaseBasePtr = foldGEP(BaseGEP, Offsets, Builder);
    if (!BaseBasePtr)
      return nullptr;
    Offsets =
        CheckAndCreateOffsetAdd(Offsets, GEP->getOperand(1), GEP, Builder);
    if (Offsets == nullptr)
      return nullptr;
    return BaseBasePtr;
  }
  return GEPPtr;
}

bool MVEGatherScatterLowering::optimiseAddress(Value *Address, BasicBlock *BB,
                                               LoopInfo *LI) {
  GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Address);
  if (!GEP)
    return false;
  bool Changed = false;
  if (GEP->hasOneUse() &&
      dyn_cast<GetElementPtrInst>(GEP->getPointerOperand())) {
    IRBuilder<> Builder(GEP->getContext());
    Builder.SetInsertPoint(GEP);
    Builder.SetCurrentDebugLocation(GEP->getDebugLoc());
    Value *Offsets;
    Value *Base = foldGEP(GEP, Offsets, Builder);
    // We only want to merge the geps if there is a real chance that they can
    // be used by an MVE gather; thus the offset has to have the correct size
    // (always i32 if it is not of vector type) and the base has to be a
    // pointer.
    if (Offsets && Base && Base != GEP) {
      GetElementPtrInst *NewAddress = GetElementPtrInst::Create(
          GEP->getSourceElementType(), Base, Offsets, "gep.merged", GEP);
      GEP->replaceAllUsesWith(NewAddress);
      GEP = NewAddress;
      Changed = true;
    }
  }
  Changed |= optimiseOffsets(GEP->getOperand(1), GEP->getParent(), LI);
  return Changed;
}

bool MVEGatherScatterLowering::runOnFunction(Function &F) {
  if (!EnableMaskedGatherScatters)
    return false;
  auto &TPC = getAnalysis<TargetPassConfig>();
  auto &TM = TPC.getTM<TargetMachine>();
  auto *ST = &TM.getSubtarget<ARMSubtarget>(F);
  if (!ST->hasMVEIntegerOps())
    return false;
  LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
  SmallVector<IntrinsicInst *, 4> Gathers;
  SmallVector<IntrinsicInst *, 4> Scatters;

  bool Changed = false;

  for (BasicBlock &BB : F) {
    Changed |= SimplifyInstructionsInBlock(&BB);

    for (Instruction &I : BB) {
      IntrinsicInst *II = dyn_cast<IntrinsicInst>(&I);
      if (II && II->getIntrinsicID() == Intrinsic::masked_gather &&
          isa<FixedVectorType>(II->getType())) {
        Gathers.push_back(II);
        Changed |= optimiseAddress(II->getArgOperand(0), II->getParent(), LI);
      } else if (II && II->getIntrinsicID() == Intrinsic::masked_scatter &&
                 isa<FixedVectorType>(II->getArgOperand(0)->getType())) {
        Scatters.push_back(II);
        Changed |= optimiseAddress(II->getArgOperand(1), II->getParent(), LI);
      }
    }
  }
  for (unsigned i = 0; i < Gathers.size(); i++) {
    IntrinsicInst *I = Gathers[i];
    Instruction *L = lowerGather(I);
    if (L == nullptr)
      continue;

    // Get rid of any now dead instructions
    SimplifyInstructionsInBlock(L->getParent());
    Changed = true;
  }

  for (unsigned i = 0; i < Scatters.size(); i++) {
    IntrinsicInst *I = Scatters[i];
    Instruction *S = lowerScatter(I);
    if (S == nullptr)
      continue;

    // Get rid of any now dead instructions
    SimplifyInstructionsInBlock(S->getParent());
    Changed = true;
  }
  return Changed;
}