Lines Matching +full:multiply +full:- +full:accumulate
1 //===- ARMParallelDSP.cpp - Parallel DSP Pass -----------------------------===//
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
10 /// Armv6 introduced instructions to perform 32-bit SIMD operations. The
12 /// DSP intrinsics, which map on these 32-bit SIMD operations.
15 //===----------------------------------------------------------------------===//
42 #define DEBUG_TYPE "arm-parallel-dsp"
47 DisableParallelDSP("disable-arm-parallel-dsp", cl::Hidden, cl::init(false),
51 NumLoadLimit("arm-parallel-dsp-load-limit", cl::Hidden, cl::init(16),
84 /// Represent a sequence of multiply-accumulate operations with the aim to
104 auto GetMulOperand = [](Value *V) -> Instruction* { in InsertMuls()
106 if (auto *I = dyn_cast<Instruction>(SExt->getOperand(0))) in InsertMuls()
107 if (I->getOpcode() == Instruction::Mul) in InsertMuls()
110 if (I->getOpcode() == Instruction::Mul) in InsertMuls()
117 Value *LHS = cast<Instruction>(I->getOperand(0))->getOperand(0); in InsertMuls()
118 Value *RHS = cast<Instruction>(I->getOperand(1))->getOperand(0); in InsertMuls()
125 if (auto *Mul = GetMulOperand(Add->getOperand(0))) in InsertMuls()
127 if (auto *Mul = GetMulOperand(Add->getOperand(1))) in InsertMuls()
147 << *Mul0->Root << "\n" in AddMulPair()
148 << *Mul1->Root << "\n"); in AddMulPair()
149 Mul0->Paired = true; in AddMulPair()
150 Mul1->Paired = true; in AddMulPair()
152 Mul1->Exchange = true; in AddMulPair()
159 bool is64Bit() const { return Root->getType()->isIntegerTy(64); } in is64Bit()
161 Type *getType() const { return Root->getType(); } in getType()
179 Root->replaceAllUsesWith(SMLAD); in UpdateRoot()
187 LLVM_DEBUG(dbgs() << *Mul->Root << "\n" in dump()
188 << " " << *Mul->LHS << "\n" in dump()
189 << " " << *Mul->RHS << "\n"); in dump()
229 /// Try to match and generate: SMLAD, SMLADX - Signed Multiply Accumulate
230 /// Dual performs two signed 16x16-bit multiplications. It adds the
231 /// products to a 32-bit accumulate operand. Optionally, the instruction can
267 DL = &M->getDataLayout(); in runOnFunction()
272 if (!ST->allowsUnalignedMem()) { in runOnFunction()
278 if (!ST->hasDSP()) { in runOnFunction()
284 if (!ST->isLittle()) { in runOnFunction()
291 LLVM_DEBUG(dbgs() << " - " << F.getName() << "\n\n"); in runOnFunction()
308 dbgs() << "Ld0:"; Ld0->dump(); in AreSequentialLoads()
309 dbgs() << "Ld1:"; Ld1->dump(); in AreSequentialLoads()
326 if (SExt->getSrcTy()->getIntegerBitWidth() != MaxBitWidth) in IsNarrowSequence()
329 if (auto *Ld = dyn_cast<LoadInst>(SExt->getOperand(0))) { in IsNarrowSequence()
346 // record loads which are simple, sign-extended and have a single user. in RecordMemoryOps()
347 // TODO: Allow zero-extended loads. in RecordMemoryOps()
352 if (!Ld || !Ld->isSimple() || in RecordMemoryOps()
353 !Ld->hasOneUse() || !isa<SExtInst>(Ld->user_back())) in RecordMemoryOps()
370 MemoryLocation(Read->getPointerOperand(), Size); in RecordMemoryOps()
372 if (!isModOrRefSet(AA->getModRefInfo(Write, ReadLoc))) in RecordMemoryOps()
374 if (Write->comesBefore(Read)) in RecordMemoryOps()
382 bool BaseFirst = Base->comesBefore(Offset); in RecordMemoryOps()
392 if (Dominator->comesBefore(Before)) in RecordMemoryOps()
425 // form a multiply-accumulate chain. The search records the Add and Mul
429 // If we find a non-instruction, try to use it as the initial accumulator in Search()
436 if (I->getParent() != BB) in Search()
439 switch (I->getOpcode()) { in Search()
450 Value *LHS = I->getOperand(0); in Search()
451 Value *RHS = I->getOperand(1); in Search()
465 Value *MulOp0 = I->getOperand(0); in Search()
466 Value *MulOp1 = I->getOperand(1); in Search()
470 return Search(I->getOperand(0), BB, R); in Search()
475 // The pass needs to identify integer add/sub reductions of 16-bit vector
522 if (!Ty->isIntegerTy(32) && !Ty->isIntegerTy(64)) in MatchSMLAD()
553 if (!MulCand->HasTwoLoadInputs()) in CreateParallelPairs()
562 auto Ld0 = static_cast<LoadInst*>(PMul0->LHS); in CreateParallelPairs()
563 auto Ld1 = static_cast<LoadInst*>(PMul1->LHS); in CreateParallelPairs()
564 auto Ld2 = static_cast<LoadInst*>(PMul0->RHS); in CreateParallelPairs()
565 auto Ld3 = static_cast<LoadInst*>(PMul1->RHS); in CreateParallelPairs()
571 if (AreSequentialLoads(Ld0, Ld1, PMul0->VecLd)) { in CreateParallelPairs()
572 if (AreSequentialLoads(Ld2, Ld3, PMul1->VecLd)) { in CreateParallelPairs()
576 } else if (AreSequentialLoads(Ld3, Ld2, PMul1->VecLd)) { in CreateParallelPairs()
582 } else if (AreSequentialLoads(Ld1, Ld0, PMul0->VecLd) && in CreateParallelPairs()
583 AreSequentialLoads(Ld2, Ld3, PMul1->VecLd)) { in CreateParallelPairs()
598 if (PMul0->Paired) in CreateParallelPairs()
606 if (PMul1->Paired) in CreateParallelPairs()
609 const Instruction *Mul0 = PMul0->Root; in CreateParallelPairs()
610 const Instruction *Mul1 = PMul1->Root; in CreateParallelPairs()
633 SMLAD = Acc->getType()->isIntegerTy(32) ? in InsertParallelMACs()
637 SMLAD = Acc->getType()->isIntegerTy(32) ? in InsertParallelMACs()
641 IRBuilder<NoFolder> Builder(InsertAfter->getParent(), in InsertParallelMACs()
659 V = DT->dominates(cast<Instruction>(A), cast<Instruction>(B)) ? B : A; in InsertParallelMACs()
666 // For any muls that were discovered but not paired, accumulate their values in InsertParallelMACs()
668 IRBuilder<NoFolder> Builder(R.getRoot()->getParent()); in InsertParallelMACs()
671 if (MulCand->Paired) in InsertParallelMACs()
674 Instruction *Mul = cast<Instruction>(MulCand->Root); in InsertParallelMACs()
677 if (R.getType() != Mul->getType()) { in InsertParallelMACs()
678 assert(R.is64Bit() && "expected 64-bit result"); in InsertParallelMACs()
680 Mul = cast<Instruction>(Builder.CreateSExt(Mul, R.getRoot()->getType())); in InsertParallelMACs()
697 ConstantInt::get(IntegerType::get(M->getContext(), 64), 0) : in InsertParallelMACs()
698 ConstantInt::get(IntegerType::get(M->getContext(), 32), 0); in InsertParallelMACs()
699 } else if (Acc->getType() != R.getType()) { in InsertParallelMACs()
706 const Instruction *A = PairA.first->Root; in InsertParallelMACs()
707 const Instruction *B = PairB.first->Root; in InsertParallelMACs()
708 return A->comesBefore(B); in InsertParallelMACs()
711 IntegerType *Ty = IntegerType::get(M->getContext(), 32); in InsertParallelMACs()
715 LoadInst *BaseLHS = LHSMul->getBaseLoad(); in InsertParallelMACs()
716 LoadInst *BaseRHS = RHSMul->getBaseLoad(); in InsertParallelMACs()
718 WideLoads[BaseLHS]->getLoad() : CreateWideLoad(LHSMul->VecLd, Ty); in InsertParallelMACs()
720 WideLoads[BaseRHS]->getLoad() : CreateWideLoad(RHSMul->VecLd, Ty); in InsertParallelMACs()
724 Acc = CreateSMLAD(WideLHS, WideRHS, Acc, RHSMul->Exchange, InsertAfter); in InsertParallelMACs()
736 Instruction *BaseSExt = dyn_cast<SExtInst>(Base->user_back()); in CreateWideLoad()
737 Instruction *OffsetSExt = dyn_cast<SExtInst>(Offset->user_back()); in CreateWideLoad()
743 [&](Value *A, Value *B) -> void { in CreateWideLoad()
750 if (DT->dominates(Source, Sink) || in CreateWideLoad()
751 Source->getParent() != Sink->getParent() || in CreateWideLoad()
755 Source->moveBefore(Sink); in CreateWideLoad()
756 for (auto &Op : Source->operands()) in CreateWideLoad()
761 LoadInst *DomLoad = DT->dominates(Base, Offset) ? Base : Offset; in CreateWideLoad()
762 IRBuilder<NoFolder> IRB(DomLoad->getParent(), in CreateWideLoad()
768 Value *VecPtr = Base->getPointerOperand(); in CreateWideLoad()
769 LoadInst *WideLoad = IRB.CreateAlignedLoad(LoadTy, VecPtr, Base->getAlign()); in CreateWideLoad()
772 MoveBefore(Base->getPointerOperand(), VecPtr); in CreateWideLoad()
777 // TODO: Support big-endian as well. in CreateWideLoad()
778 Value *Bottom = IRB.CreateTrunc(WideLoad, Base->getType()); in CreateWideLoad()
779 Value *NewBaseSExt = IRB.CreateSExt(Bottom, BaseSExt->getType()); in CreateWideLoad()
780 BaseSExt->replaceAllUsesWith(NewBaseSExt); in CreateWideLoad()
782 IntegerType *OffsetTy = cast<IntegerType>(Offset->getType()); in CreateWideLoad()
783 Value *ShiftVal = ConstantInt::get(LoadTy, OffsetTy->getBitWidth()); in CreateWideLoad()
786 Value *NewOffsetSExt = IRB.CreateSExt(Trunc, OffsetSExt->getType()); in CreateWideLoad()
787 OffsetSExt->replaceAllUsesWith(NewOffsetSExt); in CreateWideLoad()
809 INITIALIZE_PASS_BEGIN(ARMParallelDSP, "arm-parallel-dsp",
811 INITIALIZE_PASS_END(ARMParallelDSP, "arm-parallel-dsp",