1 //===- MVEGatherScatterLowering.cpp - Gather/Scatter lowering -------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 /// This pass custom lowers llvm.gather and llvm.scatter instructions to 10 /// arm.mve.gather and arm.mve.scatter intrinsics, optimising the code to 11 /// produce a better final result as we go. 12 // 13 //===----------------------------------------------------------------------===// 14 15 #include "ARM.h" 16 #include "ARMBaseInstrInfo.h" 17 #include "ARMSubtarget.h" 18 #include "llvm/Analysis/LoopInfo.h" 19 #include "llvm/Analysis/TargetTransformInfo.h" 20 #include "llvm/CodeGen/TargetLowering.h" 21 #include "llvm/CodeGen/TargetPassConfig.h" 22 #include "llvm/CodeGen/TargetSubtargetInfo.h" 23 #include "llvm/InitializePasses.h" 24 #include "llvm/IR/BasicBlock.h" 25 #include "llvm/IR/Constant.h" 26 #include "llvm/IR/Constants.h" 27 #include "llvm/IR/DerivedTypes.h" 28 #include "llvm/IR/Function.h" 29 #include "llvm/IR/InstrTypes.h" 30 #include "llvm/IR/Instruction.h" 31 #include "llvm/IR/Instructions.h" 32 #include "llvm/IR/IntrinsicInst.h" 33 #include "llvm/IR/Intrinsics.h" 34 #include "llvm/IR/IntrinsicsARM.h" 35 #include "llvm/IR/IRBuilder.h" 36 #include "llvm/IR/PatternMatch.h" 37 #include "llvm/IR/Type.h" 38 #include "llvm/IR/Value.h" 39 #include "llvm/Pass.h" 40 #include "llvm/Support/Casting.h" 41 #include "llvm/Transforms/Utils/Local.h" 42 #include <algorithm> 43 #include <cassert> 44 45 using namespace llvm; 46 47 #define DEBUG_TYPE "arm-mve-gather-scatter-lowering" 48 49 cl::opt<bool> EnableMaskedGatherScatters( 50 "enable-arm-maskedgatscat", cl::Hidden, cl::init(true), 51 cl::desc("Enable the generation of masked gathers and scatters")); 52 53 namespace { 54 55 class MVEGatherScatterLowering : public FunctionPass { 56 public: 57 static char ID; // Pass identification, replacement for typeid 58 59 explicit MVEGatherScatterLowering() : FunctionPass(ID) { 60 initializeMVEGatherScatterLoweringPass(*PassRegistry::getPassRegistry()); 61 } 62 63 bool runOnFunction(Function &F) override; 64 65 StringRef getPassName() const override { 66 return "MVE gather/scatter lowering"; 67 } 68 69 void getAnalysisUsage(AnalysisUsage &AU) const override { 70 AU.setPreservesCFG(); 71 AU.addRequired<TargetPassConfig>(); 72 AU.addRequired<LoopInfoWrapperPass>(); 73 FunctionPass::getAnalysisUsage(AU); 74 } 75 76 private: 77 LoopInfo *LI = nullptr; 78 79 // Check this is a valid gather with correct alignment 80 bool isLegalTypeAndAlignment(unsigned NumElements, unsigned ElemSize, 81 Align Alignment); 82 // Check whether Ptr is hidden behind a bitcast and look through it 83 void lookThroughBitcast(Value *&Ptr); 84 // Check for a getelementptr and deduce base and offsets from it, on success 85 // returning the base directly and the offsets indirectly using the Offsets 86 // argument 87 Value *checkGEP(Value *&Offsets, FixedVectorType *Ty, GetElementPtrInst *GEP, 88 IRBuilder<> &Builder); 89 // Compute the scale of this gather/scatter instruction 90 int computeScale(unsigned GEPElemSize, unsigned MemoryElemSize); 91 // If the value is a constant, or derived from constants via additions 92 // and multilications, return its numeric value 93 Optional<int64_t> getIfConst(const Value *V); 94 // If Inst is an add instruction, check whether one summand is a 95 // constant. If so, scale this constant and return it together with 96 // the other summand. 97 std::pair<Value *, int64_t> getVarAndConst(Value *Inst, int TypeScale); 98 99 Value *lowerGather(IntrinsicInst *I); 100 // Create a gather from a base + vector of offsets 101 Value *tryCreateMaskedGatherOffset(IntrinsicInst *I, Value *Ptr, 102 Instruction *&Root, IRBuilder<> &Builder); 103 // Create a gather from a vector of pointers 104 Value *tryCreateMaskedGatherBase(IntrinsicInst *I, Value *Ptr, 105 IRBuilder<> &Builder, int64_t Increment = 0); 106 // Create an incrementing gather from a vector of pointers 107 Value *tryCreateMaskedGatherBaseWB(IntrinsicInst *I, Value *Ptr, 108 IRBuilder<> &Builder, 109 int64_t Increment = 0); 110 111 Value *lowerScatter(IntrinsicInst *I); 112 // Create a scatter to a base + vector of offsets 113 Value *tryCreateMaskedScatterOffset(IntrinsicInst *I, Value *Offsets, 114 IRBuilder<> &Builder); 115 // Create a scatter to a vector of pointers 116 Value *tryCreateMaskedScatterBase(IntrinsicInst *I, Value *Ptr, 117 IRBuilder<> &Builder, 118 int64_t Increment = 0); 119 // Create an incrementing scatter from a vector of pointers 120 Value *tryCreateMaskedScatterBaseWB(IntrinsicInst *I, Value *Ptr, 121 IRBuilder<> &Builder, 122 int64_t Increment = 0); 123 124 // QI gathers and scatters can increment their offsets on their own if 125 // the increment is a constant value (digit) 126 Value *tryCreateIncrementingGatScat(IntrinsicInst *I, Value *BasePtr, 127 Value *Ptr, GetElementPtrInst *GEP, 128 IRBuilder<> &Builder); 129 // QI gathers/scatters can increment their offsets on their own if the 130 // increment is a constant value (digit) - this creates a writeback QI 131 // gather/scatter 132 Value *tryCreateIncrementingWBGatScat(IntrinsicInst *I, Value *BasePtr, 133 Value *Ptr, unsigned TypeScale, 134 IRBuilder<> &Builder); 135 136 // Optimise the base and offsets of the given address 137 bool optimiseAddress(Value *Address, BasicBlock *BB, LoopInfo *LI); 138 // Try to fold consecutive geps together into one 139 Value *foldGEP(GetElementPtrInst *GEP, Value *&Offsets, IRBuilder<> &Builder); 140 // Check whether these offsets could be moved out of the loop they're in 141 bool optimiseOffsets(Value *Offsets, BasicBlock *BB, LoopInfo *LI); 142 // Pushes the given add out of the loop 143 void pushOutAdd(PHINode *&Phi, Value *OffsSecondOperand, unsigned StartIndex); 144 // Pushes the given mul out of the loop 145 void pushOutMul(PHINode *&Phi, Value *IncrementPerRound, 146 Value *OffsSecondOperand, unsigned LoopIncrement, 147 IRBuilder<> &Builder); 148 }; 149 150 } // end anonymous namespace 151 152 char MVEGatherScatterLowering::ID = 0; 153 154 INITIALIZE_PASS(MVEGatherScatterLowering, DEBUG_TYPE, 155 "MVE gather/scattering lowering pass", false, false) 156 157 Pass *llvm::createMVEGatherScatterLoweringPass() { 158 return new MVEGatherScatterLowering(); 159 } 160 161 bool MVEGatherScatterLowering::isLegalTypeAndAlignment(unsigned NumElements, 162 unsigned ElemSize, 163 Align Alignment) { 164 if (((NumElements == 4 && 165 (ElemSize == 32 || ElemSize == 16 || ElemSize == 8)) || 166 (NumElements == 8 && (ElemSize == 16 || ElemSize == 8)) || 167 (NumElements == 16 && ElemSize == 8)) && 168 Alignment >= ElemSize / 8) 169 return true; 170 LLVM_DEBUG(dbgs() << "masked gathers/scatters: instruction does not have " 171 << "valid alignment or vector type \n"); 172 return false; 173 } 174 175 static bool checkOffsetSize(Value *Offsets, unsigned TargetElemCount) { 176 // Offsets that are not of type <N x i32> are sign extended by the 177 // getelementptr instruction, and MVE gathers/scatters treat the offset as 178 // unsigned. Thus, if the element size is smaller than 32, we can only allow 179 // positive offsets - i.e., the offsets are not allowed to be variables we 180 // can't look into. 181 // Additionally, <N x i32> offsets have to either originate from a zext of a 182 // vector with element types smaller or equal the type of the gather we're 183 // looking at, or consist of constants that we can check are small enough 184 // to fit into the gather type. 185 // Thus we check that 0 < value < 2^TargetElemSize. 186 unsigned TargetElemSize = 128 / TargetElemCount; 187 unsigned OffsetElemSize = cast<FixedVectorType>(Offsets->getType()) 188 ->getElementType() 189 ->getScalarSizeInBits(); 190 if (OffsetElemSize != TargetElemSize || OffsetElemSize != 32) { 191 Constant *ConstOff = dyn_cast<Constant>(Offsets); 192 if (!ConstOff) 193 return false; 194 int64_t TargetElemMaxSize = (1ULL << TargetElemSize); 195 auto CheckValueSize = [TargetElemMaxSize](Value *OffsetElem) { 196 ConstantInt *OConst = dyn_cast<ConstantInt>(OffsetElem); 197 if (!OConst) 198 return false; 199 int SExtValue = OConst->getSExtValue(); 200 if (SExtValue >= TargetElemMaxSize || SExtValue < 0) 201 return false; 202 return true; 203 }; 204 if (isa<FixedVectorType>(ConstOff->getType())) { 205 for (unsigned i = 0; i < TargetElemCount; i++) { 206 if (!CheckValueSize(ConstOff->getAggregateElement(i))) 207 return false; 208 } 209 } else { 210 if (!CheckValueSize(ConstOff)) 211 return false; 212 } 213 } 214 return true; 215 } 216 217 Value *MVEGatherScatterLowering::checkGEP(Value *&Offsets, FixedVectorType *Ty, 218 GetElementPtrInst *GEP, 219 IRBuilder<> &Builder) { 220 if (!GEP) { 221 LLVM_DEBUG( 222 dbgs() << "masked gathers/scatters: no getelementpointer found\n"); 223 return nullptr; 224 } 225 LLVM_DEBUG(dbgs() << "masked gathers/scatters: getelementpointer found." 226 << " Looking at intrinsic for base + vector of offsets\n"); 227 Value *GEPPtr = GEP->getPointerOperand(); 228 Offsets = GEP->getOperand(1); 229 if (GEPPtr->getType()->isVectorTy() || 230 !isa<FixedVectorType>(Offsets->getType())) 231 return nullptr; 232 233 if (GEP->getNumOperands() != 2) { 234 LLVM_DEBUG(dbgs() << "masked gathers/scatters: getelementptr with too many" 235 << " operands. Expanding.\n"); 236 return nullptr; 237 } 238 Offsets = GEP->getOperand(1); 239 unsigned OffsetsElemCount = 240 cast<FixedVectorType>(Offsets->getType())->getNumElements(); 241 // Paranoid check whether the number of parallel lanes is the same 242 assert(Ty->getNumElements() == OffsetsElemCount); 243 244 ZExtInst *ZextOffs = dyn_cast<ZExtInst>(Offsets); 245 if (ZextOffs) 246 Offsets = ZextOffs->getOperand(0); 247 FixedVectorType *OffsetType = cast<FixedVectorType>(Offsets->getType()); 248 249 // If the offsets are already being zext-ed to <N x i32>, that relieves us of 250 // having to make sure that they won't overflow. 251 if (!ZextOffs || cast<FixedVectorType>(ZextOffs->getDestTy()) 252 ->getElementType() 253 ->getScalarSizeInBits() != 32) 254 if (!checkOffsetSize(Offsets, OffsetsElemCount)) 255 return nullptr; 256 257 // The offset sizes have been checked; if any truncating or zext-ing is 258 // required to fix them, do that now 259 if (Ty != Offsets->getType()) { 260 if ((Ty->getElementType()->getScalarSizeInBits() < 261 OffsetType->getElementType()->getScalarSizeInBits())) { 262 Offsets = Builder.CreateTrunc(Offsets, Ty); 263 } else { 264 Offsets = Builder.CreateZExt(Offsets, VectorType::getInteger(Ty)); 265 } 266 } 267 // If none of the checks failed, return the gep's base pointer 268 LLVM_DEBUG(dbgs() << "masked gathers/scatters: found correct offsets\n"); 269 return GEPPtr; 270 } 271 272 void MVEGatherScatterLowering::lookThroughBitcast(Value *&Ptr) { 273 // Look through bitcast instruction if #elements is the same 274 if (auto *BitCast = dyn_cast<BitCastInst>(Ptr)) { 275 auto *BCTy = cast<FixedVectorType>(BitCast->getType()); 276 auto *BCSrcTy = cast<FixedVectorType>(BitCast->getOperand(0)->getType()); 277 if (BCTy->getNumElements() == BCSrcTy->getNumElements()) { 278 LLVM_DEBUG( 279 dbgs() << "masked gathers/scatters: looking through bitcast\n"); 280 Ptr = BitCast->getOperand(0); 281 } 282 } 283 } 284 285 int MVEGatherScatterLowering::computeScale(unsigned GEPElemSize, 286 unsigned MemoryElemSize) { 287 // This can be a 32bit load/store scaled by 4, a 16bit load/store scaled by 2, 288 // or a 8bit, 16bit or 32bit load/store scaled by 1 289 if (GEPElemSize == 32 && MemoryElemSize == 32) 290 return 2; 291 else if (GEPElemSize == 16 && MemoryElemSize == 16) 292 return 1; 293 else if (GEPElemSize == 8) 294 return 0; 295 LLVM_DEBUG(dbgs() << "masked gathers/scatters: incorrect scale. Can't " 296 << "create intrinsic\n"); 297 return -1; 298 } 299 300 Optional<int64_t> MVEGatherScatterLowering::getIfConst(const Value *V) { 301 const Constant *C = dyn_cast<Constant>(V); 302 if (C != nullptr) 303 return Optional<int64_t>{C->getUniqueInteger().getSExtValue()}; 304 if (!isa<Instruction>(V)) 305 return Optional<int64_t>{}; 306 307 const Instruction *I = cast<Instruction>(V); 308 if (I->getOpcode() == Instruction::Add || 309 I->getOpcode() == Instruction::Mul) { 310 Optional<int64_t> Op0 = getIfConst(I->getOperand(0)); 311 Optional<int64_t> Op1 = getIfConst(I->getOperand(1)); 312 if (!Op0 || !Op1) 313 return Optional<int64_t>{}; 314 if (I->getOpcode() == Instruction::Add) 315 return Optional<int64_t>{Op0.getValue() + Op1.getValue()}; 316 if (I->getOpcode() == Instruction::Mul) 317 return Optional<int64_t>{Op0.getValue() * Op1.getValue()}; 318 } 319 return Optional<int64_t>{}; 320 } 321 322 std::pair<Value *, int64_t> 323 MVEGatherScatterLowering::getVarAndConst(Value *Inst, int TypeScale) { 324 std::pair<Value *, int64_t> ReturnFalse = 325 std::pair<Value *, int64_t>(nullptr, 0); 326 // At this point, the instruction we're looking at must be an add or we 327 // bail out 328 Instruction *Add = dyn_cast<Instruction>(Inst); 329 if (Add == nullptr || Add->getOpcode() != Instruction::Add) 330 return ReturnFalse; 331 332 Value *Summand; 333 Optional<int64_t> Const; 334 // Find out which operand the value that is increased is 335 if ((Const = getIfConst(Add->getOperand(0)))) 336 Summand = Add->getOperand(1); 337 else if ((Const = getIfConst(Add->getOperand(1)))) 338 Summand = Add->getOperand(0); 339 else 340 return ReturnFalse; 341 342 // Check that the constant is small enough for an incrementing gather 343 int64_t Immediate = Const.getValue() << TypeScale; 344 if (Immediate > 512 || Immediate < -512 || Immediate % 4 != 0) 345 return ReturnFalse; 346 347 return std::pair<Value *, int64_t>(Summand, Immediate); 348 } 349 350 Value *MVEGatherScatterLowering::lowerGather(IntrinsicInst *I) { 351 using namespace PatternMatch; 352 LLVM_DEBUG(dbgs() << "masked gathers: checking transform preconditions\n"); 353 354 // @llvm.masked.gather.*(Ptrs, alignment, Mask, Src0) 355 // Attempt to turn the masked gather in I into a MVE intrinsic 356 // Potentially optimising the addressing modes as we do so. 357 auto *Ty = cast<FixedVectorType>(I->getType()); 358 Value *Ptr = I->getArgOperand(0); 359 Align Alignment = cast<ConstantInt>(I->getArgOperand(1))->getAlignValue(); 360 Value *Mask = I->getArgOperand(2); 361 Value *PassThru = I->getArgOperand(3); 362 363 if (!isLegalTypeAndAlignment(Ty->getNumElements(), Ty->getScalarSizeInBits(), 364 Alignment)) 365 return nullptr; 366 lookThroughBitcast(Ptr); 367 assert(Ptr->getType()->isVectorTy() && "Unexpected pointer type"); 368 369 IRBuilder<> Builder(I->getContext()); 370 Builder.SetInsertPoint(I); 371 Builder.SetCurrentDebugLocation(I->getDebugLoc()); 372 373 Instruction *Root = I; 374 Value *Load = tryCreateMaskedGatherOffset(I, Ptr, Root, Builder); 375 if (!Load) 376 Load = tryCreateMaskedGatherBase(I, Ptr, Builder); 377 if (!Load) 378 return nullptr; 379 380 if (!isa<UndefValue>(PassThru) && !match(PassThru, m_Zero())) { 381 LLVM_DEBUG(dbgs() << "masked gathers: found non-trivial passthru - " 382 << "creating select\n"); 383 Load = Builder.CreateSelect(Mask, Load, PassThru); 384 } 385 386 Root->replaceAllUsesWith(Load); 387 Root->eraseFromParent(); 388 if (Root != I) 389 // If this was an extending gather, we need to get rid of the sext/zext 390 // sext/zext as well as of the gather itself 391 I->eraseFromParent(); 392 393 LLVM_DEBUG(dbgs() << "masked gathers: successfully built masked gather\n"); 394 return Load; 395 } 396 397 Value *MVEGatherScatterLowering::tryCreateMaskedGatherBase(IntrinsicInst *I, 398 Value *Ptr, 399 IRBuilder<> &Builder, 400 int64_t Increment) { 401 using namespace PatternMatch; 402 auto *Ty = cast<FixedVectorType>(I->getType()); 403 LLVM_DEBUG(dbgs() << "masked gathers: loading from vector of pointers\n"); 404 if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32) 405 // Can't build an intrinsic for this 406 return nullptr; 407 Value *Mask = I->getArgOperand(2); 408 if (match(Mask, m_One())) 409 return Builder.CreateIntrinsic(Intrinsic::arm_mve_vldr_gather_base, 410 {Ty, Ptr->getType()}, 411 {Ptr, Builder.getInt32(Increment)}); 412 else 413 return Builder.CreateIntrinsic( 414 Intrinsic::arm_mve_vldr_gather_base_predicated, 415 {Ty, Ptr->getType(), Mask->getType()}, 416 {Ptr, Builder.getInt32(Increment), Mask}); 417 } 418 419 Value *MVEGatherScatterLowering::tryCreateMaskedGatherBaseWB( 420 IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) { 421 using namespace PatternMatch; 422 auto *Ty = cast<FixedVectorType>(I->getType()); 423 LLVM_DEBUG( 424 dbgs() 425 << "masked gathers: loading from vector of pointers with writeback\n"); 426 if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32) 427 // Can't build an intrinsic for this 428 return nullptr; 429 Value *Mask = I->getArgOperand(2); 430 if (match(Mask, m_One())) 431 return Builder.CreateIntrinsic(Intrinsic::arm_mve_vldr_gather_base_wb, 432 {Ty, Ptr->getType()}, 433 {Ptr, Builder.getInt32(Increment)}); 434 else 435 return Builder.CreateIntrinsic( 436 Intrinsic::arm_mve_vldr_gather_base_wb_predicated, 437 {Ty, Ptr->getType(), Mask->getType()}, 438 {Ptr, Builder.getInt32(Increment), Mask}); 439 } 440 441 Value *MVEGatherScatterLowering::tryCreateMaskedGatherOffset( 442 IntrinsicInst *I, Value *Ptr, Instruction *&Root, IRBuilder<> &Builder) { 443 using namespace PatternMatch; 444 445 Type *OriginalTy = I->getType(); 446 Type *ResultTy = OriginalTy; 447 448 unsigned Unsigned = 1; 449 // The size of the gather was already checked in isLegalTypeAndAlignment; 450 // if it was not a full vector width an appropriate extend should follow. 451 auto *Extend = Root; 452 if (OriginalTy->getPrimitiveSizeInBits() < 128) { 453 // Only transform gathers with exactly one use 454 if (!I->hasOneUse()) 455 return nullptr; 456 457 // The correct root to replace is not the CallInst itself, but the 458 // instruction which extends it 459 Extend = cast<Instruction>(*I->users().begin()); 460 if (isa<SExtInst>(Extend)) { 461 Unsigned = 0; 462 } else if (!isa<ZExtInst>(Extend)) { 463 LLVM_DEBUG(dbgs() << "masked gathers: extend needed but not provided. " 464 << "Expanding\n"); 465 return nullptr; 466 } 467 LLVM_DEBUG(dbgs() << "masked gathers: found an extending gather\n"); 468 ResultTy = Extend->getType(); 469 // The final size of the gather must be a full vector width 470 if (ResultTy->getPrimitiveSizeInBits() != 128) { 471 LLVM_DEBUG(dbgs() << "masked gathers: extending from the wrong type. " 472 << "Expanding\n"); 473 return nullptr; 474 } 475 } 476 477 GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr); 478 Value *Offsets; 479 Value *BasePtr = 480 checkGEP(Offsets, cast<FixedVectorType>(ResultTy), GEP, Builder); 481 if (!BasePtr) 482 return nullptr; 483 // Check whether the offset is a constant increment that could be merged into 484 // a QI gather 485 Value *Load = tryCreateIncrementingGatScat(I, BasePtr, Offsets, GEP, Builder); 486 if (Load) 487 return Load; 488 489 int Scale = computeScale( 490 BasePtr->getType()->getPointerElementType()->getPrimitiveSizeInBits(), 491 OriginalTy->getScalarSizeInBits()); 492 if (Scale == -1) 493 return nullptr; 494 Root = Extend; 495 496 Value *Mask = I->getArgOperand(2); 497 if (!match(Mask, m_One())) 498 return Builder.CreateIntrinsic( 499 Intrinsic::arm_mve_vldr_gather_offset_predicated, 500 {ResultTy, BasePtr->getType(), Offsets->getType(), Mask->getType()}, 501 {BasePtr, Offsets, Builder.getInt32(OriginalTy->getScalarSizeInBits()), 502 Builder.getInt32(Scale), Builder.getInt32(Unsigned), Mask}); 503 else 504 return Builder.CreateIntrinsic( 505 Intrinsic::arm_mve_vldr_gather_offset, 506 {ResultTy, BasePtr->getType(), Offsets->getType()}, 507 {BasePtr, Offsets, Builder.getInt32(OriginalTy->getScalarSizeInBits()), 508 Builder.getInt32(Scale), Builder.getInt32(Unsigned)}); 509 } 510 511 Value *MVEGatherScatterLowering::lowerScatter(IntrinsicInst *I) { 512 using namespace PatternMatch; 513 LLVM_DEBUG(dbgs() << "masked scatters: checking transform preconditions\n"); 514 515 // @llvm.masked.scatter.*(data, ptrs, alignment, mask) 516 // Attempt to turn the masked scatter in I into a MVE intrinsic 517 // Potentially optimising the addressing modes as we do so. 518 Value *Input = I->getArgOperand(0); 519 Value *Ptr = I->getArgOperand(1); 520 Align Alignment = cast<ConstantInt>(I->getArgOperand(2))->getAlignValue(); 521 auto *Ty = cast<FixedVectorType>(Input->getType()); 522 523 if (!isLegalTypeAndAlignment(Ty->getNumElements(), Ty->getScalarSizeInBits(), 524 Alignment)) 525 return nullptr; 526 527 lookThroughBitcast(Ptr); 528 assert(Ptr->getType()->isVectorTy() && "Unexpected pointer type"); 529 530 IRBuilder<> Builder(I->getContext()); 531 Builder.SetInsertPoint(I); 532 Builder.SetCurrentDebugLocation(I->getDebugLoc()); 533 534 Value *Store = tryCreateMaskedScatterOffset(I, Ptr, Builder); 535 if (!Store) 536 Store = tryCreateMaskedScatterBase(I, Ptr, Builder); 537 if (!Store) 538 return nullptr; 539 540 LLVM_DEBUG(dbgs() << "masked scatters: successfully built masked scatter\n"); 541 I->eraseFromParent(); 542 return Store; 543 } 544 545 Value *MVEGatherScatterLowering::tryCreateMaskedScatterBase( 546 IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) { 547 using namespace PatternMatch; 548 Value *Input = I->getArgOperand(0); 549 auto *Ty = cast<FixedVectorType>(Input->getType()); 550 // Only QR variants allow truncating 551 if (!(Ty->getNumElements() == 4 && Ty->getScalarSizeInBits() == 32)) { 552 // Can't build an intrinsic for this 553 return nullptr; 554 } 555 Value *Mask = I->getArgOperand(3); 556 // int_arm_mve_vstr_scatter_base(_predicated) addr, offset, data(, mask) 557 LLVM_DEBUG(dbgs() << "masked scatters: storing to a vector of pointers\n"); 558 if (match(Mask, m_One())) 559 return Builder.CreateIntrinsic(Intrinsic::arm_mve_vstr_scatter_base, 560 {Ptr->getType(), Input->getType()}, 561 {Ptr, Builder.getInt32(Increment), Input}); 562 else 563 return Builder.CreateIntrinsic( 564 Intrinsic::arm_mve_vstr_scatter_base_predicated, 565 {Ptr->getType(), Input->getType(), Mask->getType()}, 566 {Ptr, Builder.getInt32(Increment), Input, Mask}); 567 } 568 569 Value *MVEGatherScatterLowering::tryCreateMaskedScatterBaseWB( 570 IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) { 571 using namespace PatternMatch; 572 Value *Input = I->getArgOperand(0); 573 auto *Ty = cast<FixedVectorType>(Input->getType()); 574 LLVM_DEBUG( 575 dbgs() 576 << "masked scatters: storing to a vector of pointers with writeback\n"); 577 if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32) 578 // Can't build an intrinsic for this 579 return nullptr; 580 Value *Mask = I->getArgOperand(3); 581 if (match(Mask, m_One())) 582 return Builder.CreateIntrinsic(Intrinsic::arm_mve_vstr_scatter_base_wb, 583 {Ptr->getType(), Input->getType()}, 584 {Ptr, Builder.getInt32(Increment), Input}); 585 else 586 return Builder.CreateIntrinsic( 587 Intrinsic::arm_mve_vstr_scatter_base_wb_predicated, 588 {Ptr->getType(), Input->getType(), Mask->getType()}, 589 {Ptr, Builder.getInt32(Increment), Input, Mask}); 590 } 591 592 Value *MVEGatherScatterLowering::tryCreateMaskedScatterOffset( 593 IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder) { 594 using namespace PatternMatch; 595 Value *Input = I->getArgOperand(0); 596 Value *Mask = I->getArgOperand(3); 597 Type *InputTy = Input->getType(); 598 Type *MemoryTy = InputTy; 599 LLVM_DEBUG(dbgs() << "masked scatters: getelementpointer found. Storing" 600 << " to base + vector of offsets\n"); 601 // If the input has been truncated, try to integrate that trunc into the 602 // scatter instruction (we don't care about alignment here) 603 if (TruncInst *Trunc = dyn_cast<TruncInst>(Input)) { 604 Value *PreTrunc = Trunc->getOperand(0); 605 Type *PreTruncTy = PreTrunc->getType(); 606 if (PreTruncTy->getPrimitiveSizeInBits() == 128) { 607 Input = PreTrunc; 608 InputTy = PreTruncTy; 609 } 610 } 611 if (InputTy->getPrimitiveSizeInBits() != 128) { 612 LLVM_DEBUG( 613 dbgs() << "masked scatters: cannot create scatters for non-standard" 614 << " input types. Expanding.\n"); 615 return nullptr; 616 } 617 618 GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr); 619 Value *Offsets; 620 Value *BasePtr = 621 checkGEP(Offsets, cast<FixedVectorType>(InputTy), GEP, Builder); 622 if (!BasePtr) 623 return nullptr; 624 // Check whether the offset is a constant increment that could be merged into 625 // a QI gather 626 Value *Store = 627 tryCreateIncrementingGatScat(I, BasePtr, Offsets, GEP, Builder); 628 if (Store) 629 return Store; 630 int Scale = computeScale( 631 BasePtr->getType()->getPointerElementType()->getPrimitiveSizeInBits(), 632 MemoryTy->getScalarSizeInBits()); 633 if (Scale == -1) 634 return nullptr; 635 636 if (!match(Mask, m_One())) 637 return Builder.CreateIntrinsic( 638 Intrinsic::arm_mve_vstr_scatter_offset_predicated, 639 {BasePtr->getType(), Offsets->getType(), Input->getType(), 640 Mask->getType()}, 641 {BasePtr, Offsets, Input, 642 Builder.getInt32(MemoryTy->getScalarSizeInBits()), 643 Builder.getInt32(Scale), Mask}); 644 else 645 return Builder.CreateIntrinsic( 646 Intrinsic::arm_mve_vstr_scatter_offset, 647 {BasePtr->getType(), Offsets->getType(), Input->getType()}, 648 {BasePtr, Offsets, Input, 649 Builder.getInt32(MemoryTy->getScalarSizeInBits()), 650 Builder.getInt32(Scale)}); 651 } 652 653 Value *MVEGatherScatterLowering::tryCreateIncrementingGatScat( 654 IntrinsicInst *I, Value *BasePtr, Value *Offsets, GetElementPtrInst *GEP, 655 IRBuilder<> &Builder) { 656 FixedVectorType *Ty; 657 if (I->getIntrinsicID() == Intrinsic::masked_gather) 658 Ty = cast<FixedVectorType>(I->getType()); 659 else 660 Ty = cast<FixedVectorType>(I->getArgOperand(0)->getType()); 661 // Incrementing gathers only exist for v4i32 662 if (Ty->getNumElements() != 4 || 663 Ty->getScalarSizeInBits() != 32) 664 return nullptr; 665 Loop *L = LI->getLoopFor(I->getParent()); 666 if (L == nullptr) 667 // Incrementing gathers are not beneficial outside of a loop 668 return nullptr; 669 LLVM_DEBUG(dbgs() << "masked gathers/scatters: trying to build incrementing " 670 "wb gather/scatter\n"); 671 672 // The gep was in charge of making sure the offsets are scaled correctly 673 // - calculate that factor so it can be applied by hand 674 DataLayout DT = I->getParent()->getParent()->getParent()->getDataLayout(); 675 int TypeScale = 676 computeScale(DT.getTypeSizeInBits(GEP->getOperand(0)->getType()), 677 DT.getTypeSizeInBits(GEP->getType()) / 678 cast<FixedVectorType>(GEP->getType())->getNumElements()); 679 if (TypeScale == -1) 680 return nullptr; 681 682 if (GEP->hasOneUse()) { 683 // Only in this case do we want to build a wb gather, because the wb will 684 // change the phi which does affect other users of the gep (which will still 685 // be using the phi in the old way) 686 Value *Load = 687 tryCreateIncrementingWBGatScat(I, BasePtr, Offsets, TypeScale, Builder); 688 if (Load != nullptr) 689 return Load; 690 } 691 LLVM_DEBUG(dbgs() << "masked gathers/scatters: trying to build incrementing " 692 "non-wb gather/scatter\n"); 693 694 std::pair<Value *, int64_t> Add = getVarAndConst(Offsets, TypeScale); 695 if (Add.first == nullptr) 696 return nullptr; 697 Value *OffsetsIncoming = Add.first; 698 int64_t Immediate = Add.second; 699 700 // Make sure the offsets are scaled correctly 701 Instruction *ScaledOffsets = BinaryOperator::Create( 702 Instruction::Shl, OffsetsIncoming, 703 Builder.CreateVectorSplat(Ty->getNumElements(), Builder.getInt32(TypeScale)), 704 "ScaledIndex", I); 705 // Add the base to the offsets 706 OffsetsIncoming = BinaryOperator::Create( 707 Instruction::Add, ScaledOffsets, 708 Builder.CreateVectorSplat( 709 Ty->getNumElements(), 710 Builder.CreatePtrToInt( 711 BasePtr, 712 cast<VectorType>(ScaledOffsets->getType())->getElementType())), 713 "StartIndex", I); 714 715 if (I->getIntrinsicID() == Intrinsic::masked_gather) 716 return cast<IntrinsicInst>( 717 tryCreateMaskedGatherBase(I, OffsetsIncoming, Builder, Immediate)); 718 else 719 return cast<IntrinsicInst>( 720 tryCreateMaskedScatterBase(I, OffsetsIncoming, Builder, Immediate)); 721 } 722 723 Value *MVEGatherScatterLowering::tryCreateIncrementingWBGatScat( 724 IntrinsicInst *I, Value *BasePtr, Value *Offsets, unsigned TypeScale, 725 IRBuilder<> &Builder) { 726 // Check whether this gather's offset is incremented by a constant - if so, 727 // and the load is of the right type, we can merge this into a QI gather 728 Loop *L = LI->getLoopFor(I->getParent()); 729 // Offsets that are worth merging into this instruction will be incremented 730 // by a constant, thus we're looking for an add of a phi and a constant 731 PHINode *Phi = dyn_cast<PHINode>(Offsets); 732 if (Phi == nullptr || Phi->getNumIncomingValues() != 2 || 733 Phi->getParent() != L->getHeader() || Phi->getNumUses() != 2) 734 // No phi means no IV to write back to; if there is a phi, we expect it 735 // to have exactly two incoming values; the only phis we are interested in 736 // will be loop IV's and have exactly two uses, one in their increment and 737 // one in the gather's gep 738 return nullptr; 739 740 unsigned IncrementIndex = 741 Phi->getIncomingBlock(0) == L->getLoopLatch() ? 0 : 1; 742 // Look through the phi to the phi increment 743 Offsets = Phi->getIncomingValue(IncrementIndex); 744 745 std::pair<Value *, int64_t> Add = getVarAndConst(Offsets, TypeScale); 746 if (Add.first == nullptr) 747 return nullptr; 748 Value *OffsetsIncoming = Add.first; 749 int64_t Immediate = Add.second; 750 if (OffsetsIncoming != Phi) 751 // Then the increment we are looking at is not an increment of the 752 // induction variable, and we don't want to do a writeback 753 return nullptr; 754 755 Builder.SetInsertPoint(&Phi->getIncomingBlock(1 - IncrementIndex)->back()); 756 unsigned NumElems = 757 cast<FixedVectorType>(OffsetsIncoming->getType())->getNumElements(); 758 759 // Make sure the offsets are scaled correctly 760 Instruction *ScaledOffsets = BinaryOperator::Create( 761 Instruction::Shl, Phi->getIncomingValue(1 - IncrementIndex), 762 Builder.CreateVectorSplat(NumElems, Builder.getInt32(TypeScale)), 763 "ScaledIndex", &Phi->getIncomingBlock(1 - IncrementIndex)->back()); 764 // Add the base to the offsets 765 OffsetsIncoming = BinaryOperator::Create( 766 Instruction::Add, ScaledOffsets, 767 Builder.CreateVectorSplat( 768 NumElems, 769 Builder.CreatePtrToInt( 770 BasePtr, 771 cast<VectorType>(ScaledOffsets->getType())->getElementType())), 772 "StartIndex", &Phi->getIncomingBlock(1 - IncrementIndex)->back()); 773 // The gather is pre-incrementing 774 OffsetsIncoming = BinaryOperator::Create( 775 Instruction::Sub, OffsetsIncoming, 776 Builder.CreateVectorSplat(NumElems, Builder.getInt32(Immediate)), 777 "PreIncrementStartIndex", 778 &Phi->getIncomingBlock(1 - IncrementIndex)->back()); 779 Phi->setIncomingValue(1 - IncrementIndex, OffsetsIncoming); 780 781 Builder.SetInsertPoint(I); 782 783 Value *EndResult; 784 Value *NewInduction; 785 if (I->getIntrinsicID() == Intrinsic::masked_gather) { 786 // Build the incrementing gather 787 Value *Load = tryCreateMaskedGatherBaseWB(I, Phi, Builder, Immediate); 788 // One value to be handed to whoever uses the gather, one is the loop 789 // increment 790 EndResult = Builder.CreateExtractValue(Load, 0, "Gather"); 791 NewInduction = Builder.CreateExtractValue(Load, 1, "GatherIncrement"); 792 } else { 793 // Build the incrementing scatter 794 NewInduction = tryCreateMaskedScatterBaseWB(I, Phi, Builder, Immediate); 795 EndResult = NewInduction; 796 } 797 Instruction *AddInst = cast<Instruction>(Offsets); 798 AddInst->replaceAllUsesWith(NewInduction); 799 AddInst->eraseFromParent(); 800 Phi->setIncomingValue(IncrementIndex, NewInduction); 801 802 return EndResult; 803 } 804 805 void MVEGatherScatterLowering::pushOutAdd(PHINode *&Phi, 806 Value *OffsSecondOperand, 807 unsigned StartIndex) { 808 LLVM_DEBUG(dbgs() << "masked gathers/scatters: optimising add instruction\n"); 809 Instruction *InsertionPoint = 810 &cast<Instruction>(Phi->getIncomingBlock(StartIndex)->back()); 811 // Initialize the phi with a vector that contains a sum of the constants 812 Instruction *NewIndex = BinaryOperator::Create( 813 Instruction::Add, Phi->getIncomingValue(StartIndex), OffsSecondOperand, 814 "PushedOutAdd", InsertionPoint); 815 unsigned IncrementIndex = StartIndex == 0 ? 1 : 0; 816 817 // Order such that start index comes first (this reduces mov's) 818 Phi->addIncoming(NewIndex, Phi->getIncomingBlock(StartIndex)); 819 Phi->addIncoming(Phi->getIncomingValue(IncrementIndex), 820 Phi->getIncomingBlock(IncrementIndex)); 821 Phi->removeIncomingValue(IncrementIndex); 822 Phi->removeIncomingValue(StartIndex); 823 } 824 825 void MVEGatherScatterLowering::pushOutMul(PHINode *&Phi, 826 Value *IncrementPerRound, 827 Value *OffsSecondOperand, 828 unsigned LoopIncrement, 829 IRBuilder<> &Builder) { 830 LLVM_DEBUG(dbgs() << "masked gathers/scatters: optimising mul instruction\n"); 831 832 // Create a new scalar add outside of the loop and transform it to a splat 833 // by which loop variable can be incremented 834 Instruction *InsertionPoint = &cast<Instruction>( 835 Phi->getIncomingBlock(LoopIncrement == 1 ? 0 : 1)->back()); 836 837 // Create a new index 838 Value *StartIndex = BinaryOperator::Create( 839 Instruction::Mul, Phi->getIncomingValue(LoopIncrement == 1 ? 0 : 1), 840 OffsSecondOperand, "PushedOutMul", InsertionPoint); 841 842 Instruction *Product = 843 BinaryOperator::Create(Instruction::Mul, IncrementPerRound, 844 OffsSecondOperand, "Product", InsertionPoint); 845 // Increment NewIndex by Product instead of the multiplication 846 Instruction *NewIncrement = BinaryOperator::Create( 847 Instruction::Add, Phi, Product, "IncrementPushedOutMul", 848 cast<Instruction>(Phi->getIncomingBlock(LoopIncrement)->back()) 849 .getPrevNode()); 850 851 Phi->addIncoming(StartIndex, 852 Phi->getIncomingBlock(LoopIncrement == 1 ? 0 : 1)); 853 Phi->addIncoming(NewIncrement, Phi->getIncomingBlock(LoopIncrement)); 854 Phi->removeIncomingValue((unsigned)0); 855 Phi->removeIncomingValue((unsigned)0); 856 } 857 858 // Check whether all usages of this instruction are as offsets of 859 // gathers/scatters or simple arithmetics only used by gathers/scatters 860 static bool hasAllGatScatUsers(Instruction *I) { 861 if (I->hasNUses(0)) { 862 return false; 863 } 864 bool Gatscat = true; 865 for (User *U : I->users()) { 866 if (!isa<Instruction>(U)) 867 return false; 868 if (isa<GetElementPtrInst>(U) || 869 isGatherScatter(dyn_cast<IntrinsicInst>(U))) { 870 return Gatscat; 871 } else { 872 unsigned OpCode = cast<Instruction>(U)->getOpcode(); 873 if ((OpCode == Instruction::Add || OpCode == Instruction::Mul) && 874 hasAllGatScatUsers(cast<Instruction>(U))) { 875 continue; 876 } 877 return false; 878 } 879 } 880 return Gatscat; 881 } 882 883 bool MVEGatherScatterLowering::optimiseOffsets(Value *Offsets, BasicBlock *BB, 884 LoopInfo *LI) { 885 LLVM_DEBUG(dbgs() << "masked gathers/scatters: trying to optimize\n"); 886 // Optimise the addresses of gathers/scatters by moving invariant 887 // calculations out of the loop 888 if (!isa<Instruction>(Offsets)) 889 return false; 890 Instruction *Offs = cast<Instruction>(Offsets); 891 if (Offs->getOpcode() != Instruction::Add && 892 Offs->getOpcode() != Instruction::Mul) 893 return false; 894 Loop *L = LI->getLoopFor(BB); 895 if (L == nullptr) 896 return false; 897 if (!Offs->hasOneUse()) { 898 if (!hasAllGatScatUsers(Offs)) 899 return false; 900 } 901 902 // Find out which, if any, operand of the instruction 903 // is a phi node 904 PHINode *Phi; 905 int OffsSecondOp; 906 if (isa<PHINode>(Offs->getOperand(0))) { 907 Phi = cast<PHINode>(Offs->getOperand(0)); 908 OffsSecondOp = 1; 909 } else if (isa<PHINode>(Offs->getOperand(1))) { 910 Phi = cast<PHINode>(Offs->getOperand(1)); 911 OffsSecondOp = 0; 912 } else { 913 bool Changed = true; 914 if (isa<Instruction>(Offs->getOperand(0)) && 915 L->contains(cast<Instruction>(Offs->getOperand(0)))) 916 Changed |= optimiseOffsets(Offs->getOperand(0), BB, LI); 917 if (isa<Instruction>(Offs->getOperand(1)) && 918 L->contains(cast<Instruction>(Offs->getOperand(1)))) 919 Changed |= optimiseOffsets(Offs->getOperand(1), BB, LI); 920 if (!Changed) { 921 return false; 922 } else { 923 if (isa<PHINode>(Offs->getOperand(0))) { 924 Phi = cast<PHINode>(Offs->getOperand(0)); 925 OffsSecondOp = 1; 926 } else if (isa<PHINode>(Offs->getOperand(1))) { 927 Phi = cast<PHINode>(Offs->getOperand(1)); 928 OffsSecondOp = 0; 929 } else { 930 return false; 931 } 932 } 933 } 934 // A phi node we want to perform this function on should be from the 935 // loop header, and shouldn't have more than 2 incoming values 936 if (Phi->getParent() != L->getHeader() || 937 Phi->getNumIncomingValues() != 2) 938 return false; 939 940 // The phi must be an induction variable 941 int IncrementingBlock = -1; 942 943 for (int i = 0; i < 2; i++) 944 if (auto *Op = dyn_cast<Instruction>(Phi->getIncomingValue(i))) 945 if (Op->getOpcode() == Instruction::Add && 946 (Op->getOperand(0) == Phi || Op->getOperand(1) == Phi)) 947 IncrementingBlock = i; 948 if (IncrementingBlock == -1) 949 return false; 950 951 Instruction *IncInstruction = 952 cast<Instruction>(Phi->getIncomingValue(IncrementingBlock)); 953 954 // If the phi is not used by anything else, we can just adapt it when 955 // replacing the instruction; if it is, we'll have to duplicate it 956 PHINode *NewPhi; 957 Value *IncrementPerRound = IncInstruction->getOperand( 958 (IncInstruction->getOperand(0) == Phi) ? 1 : 0); 959 960 // Get the value that is added to/multiplied with the phi 961 Value *OffsSecondOperand = Offs->getOperand(OffsSecondOp); 962 963 if (IncrementPerRound->getType() != OffsSecondOperand->getType() || 964 !L->isLoopInvariant(OffsSecondOperand)) 965 // Something has gone wrong, abort 966 return false; 967 968 // Only proceed if the increment per round is a constant or an instruction 969 // which does not originate from within the loop 970 if (!isa<Constant>(IncrementPerRound) && 971 !(isa<Instruction>(IncrementPerRound) && 972 !L->contains(cast<Instruction>(IncrementPerRound)))) 973 return false; 974 975 if (Phi->getNumUses() == 2) { 976 // No other users -> reuse existing phi (One user is the instruction 977 // we're looking at, the other is the phi increment) 978 if (IncInstruction->getNumUses() != 1) { 979 // If the incrementing instruction does have more users than 980 // our phi, we need to copy it 981 IncInstruction = BinaryOperator::Create( 982 Instruction::BinaryOps(IncInstruction->getOpcode()), Phi, 983 IncrementPerRound, "LoopIncrement", IncInstruction); 984 Phi->setIncomingValue(IncrementingBlock, IncInstruction); 985 } 986 NewPhi = Phi; 987 } else { 988 // There are other users -> create a new phi 989 NewPhi = PHINode::Create(Phi->getType(), 0, "NewPhi", Phi); 990 std::vector<Value *> Increases; 991 // Copy the incoming values of the old phi 992 NewPhi->addIncoming(Phi->getIncomingValue(IncrementingBlock == 1 ? 0 : 1), 993 Phi->getIncomingBlock(IncrementingBlock == 1 ? 0 : 1)); 994 IncInstruction = BinaryOperator::Create( 995 Instruction::BinaryOps(IncInstruction->getOpcode()), NewPhi, 996 IncrementPerRound, "LoopIncrement", IncInstruction); 997 NewPhi->addIncoming(IncInstruction, 998 Phi->getIncomingBlock(IncrementingBlock)); 999 IncrementingBlock = 1; 1000 } 1001 1002 IRBuilder<> Builder(BB->getContext()); 1003 Builder.SetInsertPoint(Phi); 1004 Builder.SetCurrentDebugLocation(Offs->getDebugLoc()); 1005 1006 switch (Offs->getOpcode()) { 1007 case Instruction::Add: 1008 pushOutAdd(NewPhi, OffsSecondOperand, IncrementingBlock == 1 ? 0 : 1); 1009 break; 1010 case Instruction::Mul: 1011 pushOutMul(NewPhi, IncrementPerRound, OffsSecondOperand, IncrementingBlock, 1012 Builder); 1013 break; 1014 default: 1015 return false; 1016 } 1017 LLVM_DEBUG( 1018 dbgs() << "masked gathers/scatters: simplified loop variable add/mul\n"); 1019 1020 // The instruction has now been "absorbed" into the phi value 1021 Offs->replaceAllUsesWith(NewPhi); 1022 if (Offs->hasNUses(0)) 1023 Offs->eraseFromParent(); 1024 // Clean up the old increment in case it's unused because we built a new 1025 // one 1026 if (IncInstruction->hasNUses(0)) 1027 IncInstruction->eraseFromParent(); 1028 1029 return true; 1030 } 1031 1032 static Value *CheckAndCreateOffsetAdd(Value *X, Value *Y, Value *GEP, 1033 IRBuilder<> &Builder) { 1034 // Splat the non-vector value to a vector of the given type - if the value is 1035 // a constant (and its value isn't too big), we can even use this opportunity 1036 // to scale it to the size of the vector elements 1037 auto FixSummands = [&Builder](FixedVectorType *&VT, Value *&NonVectorVal) { 1038 ConstantInt *Const; 1039 if ((Const = dyn_cast<ConstantInt>(NonVectorVal)) && 1040 VT->getElementType() != NonVectorVal->getType()) { 1041 unsigned TargetElemSize = VT->getElementType()->getPrimitiveSizeInBits(); 1042 uint64_t N = Const->getZExtValue(); 1043 if (N < (unsigned)(1 << (TargetElemSize - 1))) { 1044 NonVectorVal = Builder.CreateVectorSplat( 1045 VT->getNumElements(), Builder.getIntN(TargetElemSize, N)); 1046 return; 1047 } 1048 } 1049 NonVectorVal = 1050 Builder.CreateVectorSplat(VT->getNumElements(), NonVectorVal); 1051 }; 1052 1053 FixedVectorType *XElType = dyn_cast<FixedVectorType>(X->getType()); 1054 FixedVectorType *YElType = dyn_cast<FixedVectorType>(Y->getType()); 1055 // If one of X, Y is not a vector, we have to splat it in order 1056 // to add the two of them. 1057 if (XElType && !YElType) { 1058 FixSummands(XElType, Y); 1059 YElType = cast<FixedVectorType>(Y->getType()); 1060 } else if (YElType && !XElType) { 1061 FixSummands(YElType, X); 1062 XElType = cast<FixedVectorType>(X->getType()); 1063 } 1064 assert(XElType && YElType && "Unknown vector types"); 1065 // Check that the summands are of compatible types 1066 if (XElType != YElType) { 1067 LLVM_DEBUG(dbgs() << "masked gathers/scatters: incompatible gep offsets\n"); 1068 return nullptr; 1069 } 1070 1071 if (XElType->getElementType()->getScalarSizeInBits() != 32) { 1072 // Check that by adding the vectors we do not accidentally 1073 // create an overflow 1074 Constant *ConstX = dyn_cast<Constant>(X); 1075 Constant *ConstY = dyn_cast<Constant>(Y); 1076 if (!ConstX || !ConstY) 1077 return nullptr; 1078 unsigned TargetElemSize = 128 / XElType->getNumElements(); 1079 for (unsigned i = 0; i < XElType->getNumElements(); i++) { 1080 ConstantInt *ConstXEl = 1081 dyn_cast<ConstantInt>(ConstX->getAggregateElement(i)); 1082 ConstantInt *ConstYEl = 1083 dyn_cast<ConstantInt>(ConstY->getAggregateElement(i)); 1084 if (!ConstXEl || !ConstYEl || 1085 ConstXEl->getZExtValue() + ConstYEl->getZExtValue() >= 1086 (unsigned)(1 << (TargetElemSize - 1))) 1087 return nullptr; 1088 } 1089 } 1090 1091 Value *Add = Builder.CreateAdd(X, Y); 1092 1093 FixedVectorType *GEPType = cast<FixedVectorType>(GEP->getType()); 1094 if (checkOffsetSize(Add, GEPType->getNumElements())) 1095 return Add; 1096 else 1097 return nullptr; 1098 } 1099 1100 Value *MVEGatherScatterLowering::foldGEP(GetElementPtrInst *GEP, 1101 Value *&Offsets, 1102 IRBuilder<> &Builder) { 1103 Value *GEPPtr = GEP->getPointerOperand(); 1104 Offsets = GEP->getOperand(1); 1105 // We only merge geps with constant offsets, because only for those 1106 // we can make sure that we do not cause an overflow 1107 if (!isa<Constant>(Offsets)) 1108 return nullptr; 1109 GetElementPtrInst *BaseGEP; 1110 if ((BaseGEP = dyn_cast<GetElementPtrInst>(GEPPtr))) { 1111 // Merge the two geps into one 1112 Value *BaseBasePtr = foldGEP(BaseGEP, Offsets, Builder); 1113 if (!BaseBasePtr) 1114 return nullptr; 1115 Offsets = 1116 CheckAndCreateOffsetAdd(Offsets, GEP->getOperand(1), GEP, Builder); 1117 if (Offsets == nullptr) 1118 return nullptr; 1119 return BaseBasePtr; 1120 } 1121 return GEPPtr; 1122 } 1123 1124 bool MVEGatherScatterLowering::optimiseAddress(Value *Address, BasicBlock *BB, 1125 LoopInfo *LI) { 1126 GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Address); 1127 if (!GEP) 1128 return false; 1129 bool Changed = false; 1130 if (GEP->hasOneUse() && 1131 dyn_cast<GetElementPtrInst>(GEP->getPointerOperand())) { 1132 IRBuilder<> Builder(GEP->getContext()); 1133 Builder.SetInsertPoint(GEP); 1134 Builder.SetCurrentDebugLocation(GEP->getDebugLoc()); 1135 Value *Offsets; 1136 Value *Base = foldGEP(GEP, Offsets, Builder); 1137 // We only want to merge the geps if there is a real chance that they can be 1138 // used by an MVE gather; thus the offset has to have the correct size 1139 // (always i32 if it is not of vector type) and the base has to be a 1140 // pointer. 1141 if (Offsets && Base && Base != GEP) { 1142 PointerType *BaseType = cast<PointerType>(Base->getType()); 1143 GetElementPtrInst *NewAddress = GetElementPtrInst::Create( 1144 BaseType->getPointerElementType(), Base, Offsets, "gep.merged", GEP); 1145 GEP->replaceAllUsesWith(NewAddress); 1146 GEP = NewAddress; 1147 Changed = true; 1148 } 1149 } 1150 Changed |= optimiseOffsets(GEP->getOperand(1), GEP->getParent(), LI); 1151 return Changed; 1152 } 1153 1154 bool MVEGatherScatterLowering::runOnFunction(Function &F) { 1155 if (!EnableMaskedGatherScatters) 1156 return false; 1157 auto &TPC = getAnalysis<TargetPassConfig>(); 1158 auto &TM = TPC.getTM<TargetMachine>(); 1159 auto *ST = &TM.getSubtarget<ARMSubtarget>(F); 1160 if (!ST->hasMVEIntegerOps()) 1161 return false; 1162 LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); 1163 SmallVector<IntrinsicInst *, 4> Gathers; 1164 SmallVector<IntrinsicInst *, 4> Scatters; 1165 1166 bool Changed = false; 1167 1168 for (BasicBlock &BB : F) { 1169 Changed |= SimplifyInstructionsInBlock(&BB); 1170 1171 for (Instruction &I : BB) { 1172 IntrinsicInst *II = dyn_cast<IntrinsicInst>(&I); 1173 if (II && II->getIntrinsicID() == Intrinsic::masked_gather && 1174 isa<FixedVectorType>(II->getType())) { 1175 Gathers.push_back(II); 1176 Changed |= optimiseAddress(II->getArgOperand(0), II->getParent(), LI); 1177 } else if (II && II->getIntrinsicID() == Intrinsic::masked_scatter && 1178 isa<FixedVectorType>(II->getArgOperand(0)->getType())) { 1179 Scatters.push_back(II); 1180 Changed |= optimiseAddress(II->getArgOperand(1), II->getParent(), LI); 1181 } 1182 } 1183 } 1184 for (unsigned i = 0; i < Gathers.size(); i++) { 1185 IntrinsicInst *I = Gathers[i]; 1186 Value *L = lowerGather(I); 1187 if (L == nullptr) 1188 continue; 1189 1190 // Get rid of any now dead instructions 1191 SimplifyInstructionsInBlock(cast<Instruction>(L)->getParent()); 1192 Changed = true; 1193 } 1194 1195 for (unsigned i = 0; i < Scatters.size(); i++) { 1196 IntrinsicInst *I = Scatters[i]; 1197 Value *S = lowerScatter(I); 1198 if (S == nullptr) 1199 continue; 1200 1201 // Get rid of any now dead instructions 1202 SimplifyInstructionsInBlock(cast<Instruction>(S)->getParent()); 1203 Changed = true; 1204 } 1205 return Changed; 1206 } 1207