Lines Matching +full:n +full:- +full:factor
1 //===- InterleavedLoadCombine.cpp - Combine Interleaved Loads ---*- C++ -*-===//
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
11 // This file defines the interleaved-load-combine pass. The pass searches for
16 // executed just before InterleavedAccessPass to find any left-over instances
19 //===----------------------------------------------------------------------===//
50 #define DEBUG_TYPE "interleaved-load-combine"
59 "disable-" DEBUG_TYPE, cl::init(false), cl::Hidden,
70 TLI(*TM.getSubtargetImpl(F)->getTargetLowering()), TTI(TTI) {} in InterleavedLoadCombineImpl()
103 /// factor, find a set that represents a 'factor' interleaved load.
105 std::list<VectorInfo> &InterleavedLoad, unsigned Factor,
109 /// First Order Polynomial on an n-Bit Integer Value
111 /// Polynomial(Value) = Value * B + A + E*2^(n-e)
113 /// A and B are the coefficients. E*2^(n-e) is an error within 'e' most
160 // Pb_5 - Pa_5 = 16 #0 | subtract to get the offset
175 unsigned ErrorMSBs = (unsigned)-1;
188 IntegerType *Ty = dyn_cast<IntegerType>(V->getType()); in Polynomial()
191 this->V = V; in Polynomial()
192 A = APInt(Ty->getBitWidth(), 0); in Polynomial()
206 if (ErrorMSBs == (unsigned)-1) in incErrorMSBs()
216 if (ErrorMSBs == (unsigned)-1) in decErrorMSBs()
220 ErrorMSBs -= amt; in decErrorMSBs()
240 // (B + A + E*2^(n-e)) + C = B + (A + C) + E*2^(n-e) in add()
244 ErrorMSBs = (unsigned)-1; in add()
260 // (B+A)*C = in mul()
273 // Let B' and A' be the n-Bit inputs with some unknown errors EA, in mul()
276 // B' = B + 2^(n-e)*EB in mul()
277 // A' = A + 2^(n-e)*EA in mul()
286 // (B'*C' + A'*C') = [B + 2^(n-e)*EB] * C' + [A + 2^(n-e)*EA] * C' = in mul()
287 // = [B + 2^(n-e)*EB + A + 2^(n-e)*EA] * C' = in mul()
289 // = [B + 2^(n-e)*EB + A + 2^(n-e)*EA] * C' = in mul()
290 // = [B + A + 2^(n-e)*EB + 2^(n-e)*EA] * C' = in mul()
291 // = (B + A) * C' + [2^(n-e)*EB + 2^(n-e)*EA] * C' = in mul()
292 // = (B + A) * C' + [2^(n-e)*EB + 2^(n-e)*EA] * C*2^c = in mul()
293 // = (B + A) * C' + C*(EB + EA)*2^(n-e)*2^c = in mul()
297 // = (B + A)*C' + EC*2^(n-e)*2^c = in mul()
298 // = (B + A)*C' + EC*2^(n-(e-c)) in mul()
300 // Since EC is multiplied by 2^(n-(e-c)) the resulting error contains c in mul()
305 ErrorMSBs = (unsigned)-1; in mul()
309 // Multiplying by one is a no-op. in mul()
331 // Theorem(1): (B + A + E*2^(n-e)) >> 1 => (B >> 1) + (A >> 1) + E'*2^(n-e') in lshr()
334 // E is a e-bit number, in lshr()
335 // E' is a e'-bit number, in lshr()
338 // pre(2): e < n, (see Theorem(2) for the trivial case with e=n) in lshr()
343 // B = b_h * 2^(n-1) + b_m * 2 + b_l in lshr()
344 // A = a_h * 2^(n-1) + a_m * 2 (pre(1)) in lshr()
346 // where a_h, b_h, b_l are single bits, and a_m, b_m are (n-2) bit numbers in lshr()
348 // Let X = (B + A + E*2^(n-e)) >> 1 in lshr()
349 // Let Y = (B >> 1) + (A >> 1) + E*2^(n-e) >> 1 in lshr()
351 // X = [B + A + E*2^(n-e)] >> 1 = in lshr()
352 // = [ b_h * 2^(n-1) + b_m * 2 + b_l + in lshr()
353 // + a_h * 2^(n-1) + a_m * 2 + in lshr()
354 // + E * 2^(n-e) ] >> 1 = in lshr()
356 // The sum is built by putting the overflow of [a_m + b_m] into the term in lshr()
357 // 2^(n-1). As there are no more bits beyond 2^(n-1) the overflow within in lshr()
362 // = [ ([b_h + a_h + (b_m + a_m) >> (n-2)] % 2) * 2^(n-1) + in lshr()
363 // + ((b_m + a_m) % 2^(n-2)) * 2 + in lshr()
364 // + b_l + E * 2^(n-e) ] >> 1 = in lshr()
369 // = ([b_h + a_h + (b_m + a_m) >> (n-2)] % 2) * 2^(n-2) + in lshr()
370 // + ((b_m + a_m) % 2^(n-2)) + in lshr()
371 // + E * 2^(n-(e+1)) = in lshr()
375 // = ([b_h + a_h + (b_m + a_m) >> (n-2)] % 2) * 2^(n-2) + in lshr()
376 // + ((b_m + a_m) % 2^(n-2)) + in lshr()
377 // + E * 2^(n-e') = in lshr()
381 // Y = (B >> 1) + (A >> 1) + E*2^(n-e') = in lshr()
382 // = (b_h * 2^(n-1) + b_m * 2 + b_l) >> 1 + in lshr()
383 // + (a_h * 2^(n-1) + a_m * 2) >> 1 + in lshr()
384 // + E * 2^(n-e) >> 1 = in lshr()
389 // = b_h * 2^(n-2) + b_m + in lshr()
390 // + a_h * 2^(n-2) + a_m + in lshr()
391 // + E * 2^(n-(e+1)) = in lshr()
393 // Again, the sum is built by putting the overflow of [a_m + b_m] into in lshr()
394 // the term 2^(n-1). But this time there is room for a second bit in the in lshr()
395 // term 2^(n-2) we add this bit to a new term and denote it o_h in a in lshr()
398 // = ([b_h + a_h + (b_m + a_m) >> (n-2)] >> 1) * 2^(n-1) + in lshr()
399 // + ([b_h + a_h + (b_m + a_m) >> (n-2)] % 2) * 2^(n-2) + in lshr()
400 // + ((b_m + a_m) % 2^(n-2)) + in lshr()
401 // + E * 2^(n-(e+1)) = in lshr()
403 // Let o_h = [b_h + a_h + (b_m + a_m) >> (n-2)] >> 1 in lshr()
406 // = o_h * 2^(n-1) + in lshr()
407 // + ([b_h + a_h + (b_m + a_m) >> (n-2)] % 2) * 2^(n-2) + in lshr()
408 // + ((b_m + a_m) % 2^(n-2)) + in lshr()
409 // + E * 2^(n-e') = in lshr()
412 // no 2^x with negative x, this step requires pre(2) (e < n). in lshr()
414 // = ([b_h + a_h + (b_m + a_m) >> (n-2)] % 2) * 2^(n-2) + in lshr()
415 // + ((b_m + a_m) % 2^(n-2)) + in lshr()
416 // + o_h * 2^(e'-1) * 2^(n-e') + | pre(2), move 2^(e'-1) in lshr()
418 // + E * 2^(n-e') = in lshr()
419 // = ([b_h + a_h + (b_m + a_m) >> (n-2)] % 2) * 2^(n-2) + in lshr()
420 // + ((b_m + a_m) % 2^(n-2)) + in lshr()
421 // + [o_h * 2^(e'-1) + E] * 2^(n-e') + | move 2^(e'-1) out of in lshr()
424 // Let E' = o_h * 2^(e'-1) + E in lshr()
426 // = ([b_h + a_h + (b_m + a_m) >> (n-2)] % 2) * 2^(n-2) + in lshr()
427 // + ((b_m + a_m) % 2^(n-2)) + in lshr()
428 // + E' * 2^(n-e') in lshr()
434 // For completeness, in the case e=n it is also required to show that in lshr()
437 // In this case Theorem(1) transforms to (the pre-condition on A can also be in lshr()
453 // hold. This is trivially the case for E' = X - Y. in lshr()
457 // Remark: Distributing lshr with an arbitrary number n can be expressed as in lshr()
458 // ((((B + A) lshr 1) lshr 1) ... ) {n times}. in lshr()
459 // This construction induces n additional error bits at the left. in lshr()
462 ErrorMSBs = (unsigned)-1; in lshr()
491 /// Apply a sign-extend or truncate operation on the polynomial.
492 Polynomial &sextOrTrunc(unsigned n) { in sextOrTrunc() argument
493 if (n < A.getBitWidth()) { in sextOrTrunc()
496 decErrorMSBs(A.getBitWidth() - n); in sextOrTrunc()
497 A = A.trunc(n); in sextOrTrunc()
498 pushBOperation(Trunc, APInt(sizeof(n) * 8, n)); in sextOrTrunc()
500 if (n > A.getBitWidth()) { in sextOrTrunc()
503 incErrorMSBs(n - A.getBitWidth()); in sextOrTrunc()
504 A = A.sext(n); in sextOrTrunc()
505 pushBOperation(SExt, APInt(sizeof(n) * 8, n)); in sextOrTrunc()
544 Polynomial operator-(const Polynomial &o) const { in operator -()
552 return Polynomial(A - o.A, std::max(ErrorMSBs, o.ErrorMSBs)); in operator -()
556 Polynomial operator-(uint64_t C) const { in operator -()
558 Result.A -= C; in operator -()
572 Polynomial r = *this - o; in isProvenEqualTo()
657 /// Basic-block the load instructions are within
669 /// Final shuffle-vector instruction
679 EI = new ElementInfo[VTy->getNumElements()]; in VectorInfo()
686 unsigned getDimension() const { return VTy->getNumElements(); } in getDimension()
689 /// specified factor.
691 /// \param Factor of the interleave
695 bool isInterleaved(unsigned Factor, const DataLayout &DL) const { in isInterleaved()
696 unsigned Size = DL.getTypeAllocSize(VTy->getElementType()); in isInterleaved()
698 if (!EI[i].Ofs.isProvenEqualTo(EI[0].Ofs + i * Factor * Size)) { in isInterleaved()
734 Instruction *Op = dyn_cast<Instruction>(BCI->getOperand(0)); in computeFromBCI()
739 FixedVectorType *VTy = dyn_cast<FixedVectorType>(Op->getType()); in computeFromBCI()
744 if (Result.VTy->getNumElements() % VTy->getNumElements()) in computeFromBCI()
747 unsigned Factor = Result.VTy->getNumElements() / VTy->getNumElements(); in computeFromBCI() local
748 unsigned NewSize = DL.getTypeAllocSize(Result.VTy->getElementType()); in computeFromBCI()
749 unsigned OldSize = DL.getTypeAllocSize(VTy->getElementType()); in computeFromBCI()
751 if (NewSize * Factor != OldSize) in computeFromBCI()
758 for (unsigned i = 0; i < Result.VTy->getNumElements(); i += Factor) { in computeFromBCI()
759 for (unsigned j = 0; j < Factor; j++) { in computeFromBCI()
761 ElementInfo(Old.EI[i / Factor].Ofs + j * NewSize, in computeFromBCI()
762 j == 0 ? Old.EI[i / Factor].LI : nullptr); in computeFromBCI()
790 cast<FixedVectorType>(SVI->getOperand(0)->getType()); in computeFromSVI()
794 if (!compute(SVI->getOperand(0), LHS, DL)) in computeFromSVI()
799 if (!compute(SVI->getOperand(1), RHS, DL)) in computeFromSVI()
838 for (int i : SVI->getShuffleMask()) { in computeFromSVI()
839 assert((i < 2 * (signed)ArgTy->getNumElements()) && in computeFromSVI()
844 else if (i < (signed)ArgTy->getNumElements()) { in computeFromSVI()
851 Result.EI[j] = RHS.EI[i - ArgTy->getNumElements()]; in computeFromSVI()
874 if (LI->isVolatile()) in computeFromLI()
877 if (LI->isAtomic()) in computeFromLI()
880 if (!DL.typeSizeEqualsStoreSize(Result.VTy->getElementType())) in computeFromLI()
884 computePolynomialFromPointer(*LI->getPointerOperand(), Offset, BasePtr, DL); in computeFromLI()
886 Result.BB = LI->getParent(); in computeFromLI()
893 ConstantInt::get(Type::getInt32Ty(LI->getContext()), 0), in computeFromLI()
894 ConstantInt::get(Type::getInt32Ty(LI->getContext()), i), in computeFromLI()
925 Result.add(C->getValue()); in computePolynomialBinOp()
933 Result.lshr(C->getValue()); in computePolynomialBinOp()
971 DL.getIndexSizeInBits(PtrTy->getPointerAddressSpace()); in computePolynomialFromPointer()
999 // non-constant. in computePolynomialFromPointer()
1037 Polynomial(DL.getIndexSizeInBits(PtrTy->getPointerAddressSpace()), 0); in computePolynomialFromPointer()
1059 unsigned Factor, const DataLayout &DL) { in findPattern() argument
1063 unsigned Size = DL.getTypeAllocSize(C0->VTy->getElementType()); in findPattern()
1066 std::vector<std::list<VectorInfo>::iterator> Res(Factor, Candidates.end()); in findPattern()
1069 if (C->VTy != C0->VTy) in findPattern()
1071 if (C->BB != C0->BB) in findPattern()
1073 if (C->PV != C0->PV) in findPattern()
1076 // Check the current value matches any of factor - 1 remaining lines in findPattern()
1077 for (i = 1; i < Factor; i++) { in findPattern()
1078 if (C->EI[0].Ofs.isProvenEqualTo(C0->EI[0].Ofs + i * Size)) { in findPattern()
1083 for (i = 1; i < Factor; i++) { in findPattern()
1087 if (i == Factor) { in findPattern()
1095 for (unsigned i = 0; i < Factor; i++) { in findPattern()
1110 BasicBlock *BB = (*LIs.begin())->getParent(); in findFirstLoad()
1112 *BB, [&LIs](Instruction &I) -> bool { return is_contained(LIs, &I); }); in findFirstLoad()
1113 assert(FLI != BB->end()); in findFirstLoad()
1120 LLVM_DEBUG(dbgs() << "Checking interleaved load\n"); in combine()
1139 // Get the interleave factor in combine()
1140 unsigned Factor = InterleavedLoad.size(); in combine() local
1173 for (auto *U : I->users()) { in combine()
1192 auto MADef = MSSA.getMemoryAccess(LI)->getDefiningAccess(); in combine()
1207 Type *ETy = InterleavedLoad.front().SVI->getType()->getElementType(); in combine()
1209 cast<FixedVectorType>(InterleavedLoad.front().SVI->getType()) in combine()
1210 ->getNumElements(); in combine()
1211 FixedVectorType *ILTy = FixedVectorType::get(ETy, Factor * ElementsPerSVI); in combine()
1213 auto Indices = llvm::to_vector<4>(llvm::seq<unsigned>(0, Factor)); in combine()
1215 Instruction::Load, ILTy, Factor, Indices, InsertionPoint->getAlign(), in combine()
1216 InsertionPoint->getPointerAddressSpace(), CostKind); in combine()
1223 auto Ptr = InsertionPoint->getPointerOperand(); in combine()
1224 auto LI = Builder.CreateAlignedLoad(ILTy, Ptr, InsertionPoint->getAlign(), in combine()
1236 Mask.push_back(i + j * Factor); in combine()
1240 VI.SVI->replaceAllUsesWith(SVI); in combine()
1247 << "Load interleaved combined with factor " in combine()
1248 << ore::NV("Factor", Factor); in combine()
1261 // Start with the highest factor to avoid combining and recombining. in run()
1262 for (unsigned Factor = MaxFactor; Factor >= 2; Factor--) { in run() local
1269 if (isa<ScalableVectorType>(SVI->getType())) in run()
1272 Candidates.emplace_back(cast<FixedVectorType>(SVI->getType())); in run()
1279 if (!Candidates.back().isInterleaved(Factor, DL)) { in run()
1287 while (findPattern(Candidates, InterleavedLoad, Factor, DL)) { in run()
1327 << "\n"); in runOnFunction()
1333 TPC->getTM<TargetMachine>()) in runOnFunction()