Lines Matching +full:auto +full:- +full:load

1 //===- ARMParallelDSP.cpp - Parallel DSP Pass -----------------------------===//
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
10 /// Armv6 introduced instructions to perform 32-bit SIMD operations. The purpose of this pass is to do some IR pattern matching to create ACLE
12 /// DSP intrinsics, which map onto these 32-bit SIMD operations.
15 //===----------------------------------------------------------------------===//
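
As a rough orientation (my sketch, not code from the file): the scalar shape the pass looks for is a multiply-accumulate chain over sign-extended 16-bit loads, which is what a dot-product loop produces:

    // Illustrative C++ kernel; names are hypothetical. Each iteration performs
    // two 16x16->32 multiplies on consecutive halfwords, which this pass can
    // fuse into one SMLAD fed by two 32-bit loads.
    int dot_product(const short *a, const short *b, int n) {
      int acc = 0;
      for (int i = 0; i < n; i += 2) {
        acc += a[i] * b[i];
        acc += a[i + 1] * b[i + 1];
      }
      return acc;
    }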
42 #define DEBUG_TYPE "arm-parallel-dsp"
47 DisableParallelDSP("disable-arm-parallel-dsp", cl::Hidden, cl::init(false),
51 NumLoadLimit("arm-parallel-dsp-load-limit", cl::Hidden, cl::init(16),
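
Both option declarations continue past the matched lines; reconstructed approximately (treat the cl::desc strings as best-effort):

    static cl::opt<bool>
    DisableParallelDSP("disable-arm-parallel-dsp", cl::Hidden, cl::init(false),
                       cl::desc("Disable the ARM Parallel DSP pass"));

    static cl::opt<unsigned>
    NumLoadLimit("arm-parallel-dsp-load-limit", cl::Hidden, cl::init(16),
                 cl::desc("Limit the number of loads analysed"));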
84 /// Represent a sequence of multiply-accumulate operations with the aim to perform the multiplications in parallel.
104 auto GetMulOperand = [](Value *V) -> Instruction* { in InsertMuls()
105 if (auto *SExt = dyn_cast<SExtInst>(V)) { in InsertMuls()
106 if (auto *I = dyn_cast<Instruction>(SExt->getOperand(0))) in InsertMuls()
107 if (I->getOpcode() == Instruction::Mul) in InsertMuls()
109 } else if (auto *I = dyn_cast<Instruction>(V)) { in InsertMuls()
110 if (I->getOpcode() == Instruction::Mul) in InsertMuls()
116 auto InsertMul = [this](Instruction *I) { in InsertMuls()
117 Value *LHS = cast<Instruction>(I->getOperand(0))->getOperand(0); in InsertMuls()
118 Value *RHS = cast<Instruction>(I->getOperand(1))->getOperand(0); in InsertMuls()
122 for (auto *Add : Adds) { in InsertMuls()
125 if (auto *Mul = GetMulOperand(Add->getOperand(0))) in InsertMuls()
127 if (auto *Mul = GetMulOperand(Add->getOperand(1))) in InsertMuls()
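
To make the two lambdas concrete, an illustrative IR shape (my example, shown as comments):

    // Given:   %sx  = sext i16 %x to i32
    //          %sy  = sext i16 %y to i32
    //          %mul = mul i32 %sx, %sy
    //          %add = add i32 %acc, %mul
    // GetMulOperand returns %mul directly, or looks through a sext of the mul
    // for 64-bit chains. InsertMul then records %x and %y, the pre-extension
    // operands, as the candidate's LHS and RHS.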
147 << *Mul0->Root << "\n" in AddMulPair()
148 << *Mul1->Root << "\n"); in AddMulPair()
149 Mul0->Paired = true; in AddMulPair()
150 Mul1->Paired = true; in AddMulPair()
152 Mul1->Exchange = true; in AddMulPair()
159 bool is64Bit() const { return Root->getType()->isIntegerTy(64); } in is64Bit()
161 Type *getType() const { return Root->getType(); } in getType()
179 Root->replaceAllUsesWith(SMLAD); in UpdateRoot()
184 for (auto *Add : Adds) in dump()
186 for (auto &Mul : Muls) in dump()
187 LLVM_DEBUG(dbgs() << *Mul->Root << "\n" in dump()
188 << " " << *Mul->LHS << "\n" in dump()
189 << " " << *Mul->RHS << "\n"); in dump()
229 /// Try to match and generate: SMLAD, SMLADX - Signed Multiply Accumulate
230 /// Dual performs two signed 16x16-bit multiplications. It adds the
231 /// products to a 32-bit accumulate operand. Optionally, the instruction can exchange the halfwords of the second operand before performing the arithmetic.
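
A scalar model of those semantics (my sketch, following the architectural definition but ignoring the Q/saturation flag):

    #include <cstdint>

    // smlad: multiply corresponding signed halfwords, accumulate both products.
    int32_t smlad_model(uint32_t a, uint32_t b, int32_t acc) {
      int32_t lo = int16_t(a) * int16_t(b);
      int32_t hi = int16_t(a >> 16) * int16_t(b >> 16);
      return acc + lo + hi;
    }

    // smladx: identical, but the halfwords of the second operand are swapped.
    int32_t smladx_model(uint32_t a, uint32_t b, int32_t acc) {
      return smlad_model(a, (b >> 16) | (b << 16), acc);
    }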
264 auto &TPC = getAnalysis<TargetPassConfig>(); in runOnFunction()
267 DL = &M->getDataLayout(); in runOnFunction()
269 auto &TM = TPC.getTM<TargetMachine>(); in runOnFunction()
270 auto *ST = &TM.getSubtarget<ARMSubtarget>(F); in runOnFunction()
272 if (!ST->allowsUnalignedMem()) { in runOnFunction()
278 if (!ST->hasDSP()) { in runOnFunction()
284 if (!ST->isLittle()) { in runOnFunction()
291 LLVM_DEBUG(dbgs() << " - " << F.getName() << "\n\n"); in runOnFunction()
308 dbgs() << "Ld0:"; Ld0->dump(); in AreSequentialLoads()
309 dbgs() << "Ld1:"; Ld1->dump(); in AreSequentialLoads()
325 if (auto *SExt = dyn_cast<SExtInst>(V)) { in IsNarrowSequence()
326 if (SExt->getSrcTy()->getIntegerBitWidth() != MaxBitWidth) in IsNarrowSequence()
329 if (auto *Ld = dyn_cast<LoadInst>(SExt->getOperand(0))) { in IsNarrowSequence()
330 // Check that this load could be paired. in IsNarrowSequence()
338 /// Iterate through the block and record base, offset pairs of loads which can be widened into a single load.
346 // Collect loads and instructions that may write to memory. For now, we only record loads which are simple, sign-extended and have a single user. in RecordMemoryOps()
347 // TODO: Allow zero-extended loads. in RecordMemoryOps()
348 for (auto &I : *BB) { in RecordMemoryOps()
351 auto *Ld = dyn_cast<LoadInst>(&I); in RecordMemoryOps()
352 if (!Ld || !Ld->isSimple() || in RecordMemoryOps()
353 !Ld->hasOneUse() || !isa<SExtInst>(Ld->user_back())) in RecordMemoryOps()
365 // Record any writes that may alias a load. in RecordMemoryOps()
366 const auto Size = LocationSize::beforeOrAfterPointer(); in RecordMemoryOps()
367 for (auto *Write : Writes) { in RecordMemoryOps()
368 for (auto *Read : Loads) { in RecordMemoryOps()
370 MemoryLocation(Read->getPointerOperand(), Size); in RecordMemoryOps()
372 if (!isModOrRefSet(AA->getModRefInfo(Write, ReadLoc))) in RecordMemoryOps()
374 if (Write->comesBefore(Read)) in RecordMemoryOps()
381 auto SafeToPair = [&](LoadInst *Base, LoadInst *Offset) { in RecordMemoryOps()
382 bool BaseFirst = Base->comesBefore(Offset); in RecordMemoryOps()
389 for (auto *Before : WritesBefore) { in RecordMemoryOps()
390 // We can't move the second load backward, past a write, to merge in RecordMemoryOps()
391 // with the first load. in RecordMemoryOps()
392 if (Dominator->comesBefore(Before)) in RecordMemoryOps()
399 // Record base, offset load pairs. in RecordMemoryOps()
400 for (auto *Base : Loads) { in RecordMemoryOps()
401 for (auto *Offset : Loads) { in RecordMemoryOps()
415 dbgs() << "Consecutive load pairs:\n"; in RecordMemoryOps()
416 for (auto &MapIt : LoadPairs) { in RecordMemoryOps()
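
What LoadPairs ends up holding, illustrated (my example in pseudo-IR, little-endian):

    // Two simple, sign-extended, single-use i16 loads from adjacent addresses,
    //   %ld0 = load i16, ptr %p       ; base
    //   %ld1 = load i16, ptr %p+2     ; offset, 2 bytes above the base
    // are recorded as LoadPairs[%ld0] = %ld1, provided no intervening store
    // (per SafeToPair above) blocks merging them into one i32 load.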
425 // Search recursively back through the operands to find a chain of values that form a multiply-accumulate chain. The search records the Add and Mul instructions that form the chain and allows us to find a single value to be used as the initial input to the chain.
429 // If we find a non-instruction, try to use it as the initial accumulator value. This may have already been found during the search, in which case the function returns false, signaling a search fail. in Search()
432 auto *I = dyn_cast<Instruction>(V); in Search()
436 if (I->getParent() != BB) in Search()
439 switch (I->getOpcode()) { in Search()
450 Value *LHS = I->getOperand(0); in Search()
451 Value *RHS = I->getOperand(1); in Search()
465 Value *MulOp0 = I->getOperand(0); in Search()
466 Value *MulOp1 = I->getOperand(1); in Search()
470 return Search(I->getOperand(0), BB, R); in Search()
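
The recursion, traced on a small chain (illustrative):

    // For  %acc1 = add(%acc0, %add0)  with  %add0 = add(%mul0, %mul1):
    // Search records each add and recurses into both operands; %acc0 is
    // accepted as the initial accumulator (a PHI or a non-instruction), and a
    // mul is accepted only when both of its operands are sign-extended i16
    // loads (IsNarrowSequence<16>). A sext is simply looked through, as the
    // matched SExt case above shows.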
475 // The pass needs to identify integer add/sub reductions of 16-bit vector multiplications.
481 // ld0 = load i16
483 // ld1 = load i16
486 // ld2 = load i16
488 // ld3 = load i16
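
Only the load lines of that comment match the search; the surrounding pattern in the file reads, approximately (reconstructed):

    // acc0 = ...
    // ld0 = load i16
    // sext0 = sext i16 %ld0 to i32
    // ld1 = load i16
    // sext1 = sext i16 %ld1 to i32
    // mul0 = mul %sext0, %sext1
    // ld2 = load i16
    // sext2 = sext i16 %ld2 to i32
    // ld3 = load i16
    // sext3 = sext i16 %ld3 to i32
    // mul1 = mul %sext2, %sext3
    // add0 = add %mul0, %mul1
    // acc1 = add %acc0, %add0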
509 for (auto &BB : F) { in MatchSMLAD()
521 const auto *Ty = I.getType(); in MatchSMLAD()
522 if (!Ty->isIntegerTy(32) && !Ty->isIntegerTy(64)) in MatchSMLAD()
552 for (auto &MulCand : R.getMuls()) { in CreateParallelPairs()
553 if (!MulCand->HasTwoLoadInputs()) in CreateParallelPairs()
557 auto CanPair = [&](Reduction &R, MulCandidate *PMul0, MulCandidate *PMul1) { in CreateParallelPairs()
562 auto Ld0 = static_cast<LoadInst*>(PMul0->LHS); in CreateParallelPairs()
563 auto Ld1 = static_cast<LoadInst*>(PMul1->LHS); in CreateParallelPairs()
564 auto Ld2 = static_cast<LoadInst*>(PMul0->RHS); in CreateParallelPairs()
565 auto Ld3 = static_cast<LoadInst*>(PMul1->RHS); in CreateParallelPairs()
571 if (AreSequentialLoads(Ld0, Ld1, PMul0->VecLd)) { in CreateParallelPairs()
572 if (AreSequentialLoads(Ld2, Ld3, PMul1->VecLd)) { in CreateParallelPairs()
576 } else if (AreSequentialLoads(Ld3, Ld2, PMul1->VecLd)) { in CreateParallelPairs()
582 } else if (AreSequentialLoads(Ld1, Ld0, PMul0->VecLd) && in CreateParallelPairs()
583 AreSequentialLoads(Ld2, Ld3, PMul1->VecLd)) { in CreateParallelPairs()
598 if (PMul0->Paired) in CreateParallelPairs()
606 if (PMul1->Paired) in CreateParallelPairs()
609 const Instruction *Mul0 = PMul0->Root; in CreateParallelPairs()
610 const Instruction *Mul1 = PMul1->Root; in CreateParallelPairs()
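
The four-load case distinction, summarized (my notes):

    // With PMul0 = mul(sext(Ld0), sext(Ld2)) and PMul1 = mul(sext(Ld1),
    // sext(Ld3)): if Ld0,Ld1 and Ld2,Ld3 are each sequential, the muls pair
    // directly (SMLAD). If one pair is sequential only in reversed order
    // (Ld3 before Ld2, or Ld1 before Ld0), AddMulPair is told to set
    // Mul1->Exchange, and an SMLADX is emitted instead.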
625 auto CreateSMLAD = [&](LoadInst* WideLd0, LoadInst *WideLd1, in InsertParallelMACs()
633 SMLAD = Acc->getType()->isIntegerTy(32) ? in InsertParallelMACs()
637 SMLAD = Acc->getType()->isIntegerTy(32) ? in InsertParallelMACs()
641 IRBuilder<NoFolder> Builder(InsertAfter->getParent(), in InsertParallelMACs()
649 auto GetInsertPoint = [this](Value *A, Value *B) { in InsertParallelMACs()
659 V = DT->dominates(cast<Instruction>(A), cast<Instruction>(B)) ? B : A; in InsertParallelMACs()
668 IRBuilder<NoFolder> Builder(R.getRoot()->getParent()); in InsertParallelMACs()
670 for (auto &MulCand : MulCands) { in InsertParallelMACs()
671 if (MulCand->Paired) in InsertParallelMACs()
674 Instruction *Mul = cast<Instruction>(MulCand->Root); in InsertParallelMACs()
677 if (R.getType() != Mul->getType()) { in InsertParallelMACs()
678 assert(R.is64Bit() && "expected 64-bit result"); in InsertParallelMACs()
680 Mul = cast<Instruction>(Builder.CreateSExt(Mul, R.getRoot()->getType())); in InsertParallelMACs()
697 ConstantInt::get(IntegerType::get(M->getContext(), 64), 0) : in InsertParallelMACs()
698 ConstantInt::get(IntegerType::get(M->getContext(), 32), 0); in InsertParallelMACs()
699 } else if (Acc->getType() != R.getType()) { in InsertParallelMACs()
705 llvm::sort(R.getMulPairs(), [](auto &PairA, auto &PairB) { in InsertParallelMACs()
706 const Instruction *A = PairA.first->Root; in InsertParallelMACs()
707 const Instruction *B = PairB.first->Root; in InsertParallelMACs()
708 return A->comesBefore(B); in InsertParallelMACs()
711 IntegerType *Ty = IntegerType::get(M->getContext(), 32); in InsertParallelMACs()
712 for (auto &Pair : R.getMulPairs()) { in InsertParallelMACs()
715 LoadInst *BaseLHS = LHSMul->getBaseLoad(); in InsertParallelMACs()
716 LoadInst *BaseRHS = RHSMul->getBaseLoad(); in InsertParallelMACs()
718 WideLoads[BaseLHS]->getLoad() : CreateWideLoad(LHSMul->VecLd, Ty); in InsertParallelMACs()
720 WideLoads[BaseRHS]->getLoad() : CreateWideLoad(RHSMul->VecLd, Ty); in InsertParallelMACs()
724 Acc = CreateSMLAD(WideLHS, WideRHS, Acc, RHSMul->Exchange, InsertAfter); in InsertParallelMACs()
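
Each CreateSMLAD call above materializes one intrinsic call (sketched):

    // 32-bit accumulator:
    //   %acc1 = call i32 @llvm.arm.smlad(i32 %wideLHS, i32 %wideRHS, i32 %acc0)
    // 64-bit accumulator:
    //   %acc1 = call i64 @llvm.arm.smlald(i32 %wideLHS, i32 %wideRHS, i64 %acc0)
    // The "x" variants (smladx/smlaldx) are chosen when the pair's second mul
    // has Exchange set.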
736 Instruction *BaseSExt = dyn_cast<SExtInst>(Base->user_back()); in CreateWideLoad()
737 Instruction *OffsetSExt = dyn_cast<SExtInst>(Offset->user_back()); in CreateWideLoad()
743 [&](Value *A, Value *B) -> void { in CreateWideLoad()
747 auto *Source = cast<Instruction>(A); in CreateWideLoad()
748 auto *Sink = cast<Instruction>(B); in CreateWideLoad()
750 if (DT->dominates(Source, Sink) || in CreateWideLoad()
751 Source->getParent() != Sink->getParent() || in CreateWideLoad()
755 Source->moveBefore(Sink); in CreateWideLoad()
756 for (auto &Op : Source->operands()) in CreateWideLoad()
760 // Insert the load at the point of the original dominating load. in CreateWideLoad()
761 LoadInst *DomLoad = DT->dominates(Base, Offset) ? Base : Offset; in CreateWideLoad()
762 IRBuilder<NoFolder> IRB(DomLoad->getParent(), in CreateWideLoad()
765 // Create the wide load, while making sure to maintain the original alignment, as this prevents an ldrd from being generated when it could be illegal due to memory alignment. in CreateWideLoad()
768 Value *VecPtr = Base->getPointerOperand(); in CreateWideLoad()
769 LoadInst *WideLoad = IRB.CreateAlignedLoad(LoadTy, VecPtr, Base->getAlign()); in CreateWideLoad()
772 MoveBefore(Base->getPointerOperand(), VecPtr); in CreateWideLoad()
775 // From the wide load, create two values that equal the original two loads. in CreateWideLoad()
777 // TODO: Support big-endian as well. in CreateWideLoad()
778 Value *Bottom = IRB.CreateTrunc(WideLoad, Base->getType()); in CreateWideLoad()
779 Value *NewBaseSExt = IRB.CreateSExt(Bottom, BaseSExt->getType()); in CreateWideLoad()
780 BaseSExt->replaceAllUsesWith(NewBaseSExt); in CreateWideLoad()
782 IntegerType *OffsetTy = cast<IntegerType>(Offset->getType()); in CreateWideLoad()
783 Value *ShiftVal = ConstantInt::get(LoadTy, OffsetTy->getBitWidth()); in CreateWideLoad()
786 Value *NewOffsetSExt = IRB.CreateSExt(Trunc, OffsetSExt->getType()); in CreateWideLoad()
787 OffsetSExt->replaceAllUsesWith(NewOffsetSExt); in CreateWideLoad()
791 << "Created Wide Load:\n" in CreateWideLoad()
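
The trunc/lshr split, modeled as scalar code (my sketch, valid only for the little-endian case this pass already requires):

    #include <cstdint>
    #include <cstring>

    // One 32-bit load replaces two adjacent i16 loads; the low halfword is the
    // base value (trunc), the high halfword the offset value (lshr 16 + trunc).
    void wide_load_model(const int16_t *base, int16_t &bottom, int16_t &top) {
      uint32_t wide;
      std::memcpy(&wide, base, sizeof(wide)); // single, possibly 2-byte-aligned load
      bottom = int16_t(wide);                 // == base[0]
      top = int16_t(wide >> 16);              // == base[1]
    }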
809 INITIALIZE_PASS_BEGIN(ARMParallelDSP, "arm-parallel-dsp",
811 INITIALIZE_PASS_END(ARMParallelDSP, "arm-parallel-dsp",