xref: /freebsd/contrib/llvm-project/llvm/lib/Analysis/DemandedBits.cpp (revision 2c2ec6bbc9cc7762a250ffe903bda6c2e44d25ff)
1 //===- DemandedBits.cpp - Determine demanded bits -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This pass implements a demanded bits analysis. A demanded bit is one that
10 // contributes to a result; bits that are not demanded can be either zero or
11 // one without affecting control or data flow. For example in this sequence:
12 //
13 //   %1 = add i32 %x, %y
14 //   %2 = trunc i32 %1 to i16
15 //
16 // Only the lowest 16 bits of %1 are demanded; the rest are removed by the
17 // trunc.
18 //
19 //===----------------------------------------------------------------------===//
20 
21 #include "llvm/Analysis/DemandedBits.h"
22 #include "llvm/ADT/APInt.h"
23 #include "llvm/ADT/SetVector.h"
24 #include "llvm/Analysis/AssumptionCache.h"
25 #include "llvm/Analysis/ValueTracking.h"
26 #include "llvm/IR/DataLayout.h"
27 #include "llvm/IR/Dominators.h"
28 #include "llvm/IR/InstIterator.h"
29 #include "llvm/IR/Instruction.h"
30 #include "llvm/IR/IntrinsicInst.h"
31 #include "llvm/IR/Operator.h"
32 #include "llvm/IR/PassManager.h"
33 #include "llvm/IR/PatternMatch.h"
34 #include "llvm/IR/Type.h"
35 #include "llvm/IR/Use.h"
36 #include "llvm/Support/Casting.h"
37 #include "llvm/Support/Debug.h"
38 #include "llvm/Support/KnownBits.h"
39 #include "llvm/Support/raw_ostream.h"
40 #include <algorithm>
41 #include <cstdint>
42 
43 using namespace llvm;
44 using namespace llvm::PatternMatch;
45 
46 #define DEBUG_TYPE "demanded-bits"
47 
48 static bool isAlwaysLive(Instruction *I) {
49   return I->isTerminator() || I->isEHPad() || I->mayHaveSideEffects();
50 }
51 
52 void DemandedBits::determineLiveOperandBits(
53     const Instruction *UserI, const Value *Val, unsigned OperandNo,
54     const APInt &AOut, APInt &AB, KnownBits &Known, KnownBits &Known2,
55     bool &KnownBitsComputed) {
56   unsigned BitWidth = AB.getBitWidth();
57 
58   // We're called once per operand, but for some instructions, we need to
59   // compute known bits of both operands in order to determine the live bits of
60   // either (when both operands are instructions themselves). We don't,
61   // however, want to do this twice, so we cache the result in APInts that live
62   // in the caller. For the two-relevant-operands case, both operand values are
63   // provided here.
64   auto ComputeKnownBits =
65       [&](unsigned BitWidth, const Value *V1, const Value *V2) {
66         if (KnownBitsComputed)
67           return;
68         KnownBitsComputed = true;
69 
70         const DataLayout &DL = UserI->getDataLayout();
71         Known = KnownBits(BitWidth);
72         computeKnownBits(V1, Known, DL, &AC, UserI, &DT);
73 
74         if (V2) {
75           Known2 = KnownBits(BitWidth);
76           computeKnownBits(V2, Known2, DL, &AC, UserI, &DT);
77         }
78       };
79 
80   switch (UserI->getOpcode()) {
81   default: break;
82   case Instruction::Call:
83   case Instruction::Invoke:
84     if (const auto *II = dyn_cast<IntrinsicInst>(UserI)) {
85       switch (II->getIntrinsicID()) {
86       default: break;
87       case Intrinsic::bswap:
88         // The alive bits of the input are the swapped alive bits of
89         // the output.
90         AB = AOut.byteSwap();
91         break;
92       case Intrinsic::bitreverse:
93         // The alive bits of the input are the reversed alive bits of
94         // the output.
95         AB = AOut.reverseBits();
96         break;
97       case Intrinsic::ctlz:
98         if (OperandNo == 0) {
99           // We need some output bits, so we need all bits of the
100           // input to the left of, and including, the leftmost bit
101           // known to be one.
102           ComputeKnownBits(BitWidth, Val, nullptr);
103           AB = APInt::getHighBitsSet(BitWidth,
104                  std::min(BitWidth, Known.countMaxLeadingZeros()+1));
105         }
106         break;
107       case Intrinsic::cttz:
108         if (OperandNo == 0) {
109           // We need some output bits, so we need all bits of the
110           // input to the right of, and including, the rightmost bit
111           // known to be one.
112           ComputeKnownBits(BitWidth, Val, nullptr);
113           AB = APInt::getLowBitsSet(BitWidth,
114                  std::min(BitWidth, Known.countMaxTrailingZeros()+1));
115         }
116         break;
117       case Intrinsic::fshl:
118       case Intrinsic::fshr: {
119         const APInt *SA;
120         if (OperandNo == 2) {
121           // Shift amount is modulo the bitwidth. For powers of two we have
122           // SA % BW == SA & (BW - 1).
123           if (isPowerOf2_32(BitWidth))
124             AB = BitWidth - 1;
125         } else if (match(II->getOperand(2), m_APInt(SA))) {
126           // Normalize to funnel shift left. APInt shifts of BitWidth are well-
127           // defined, so no need to special-case zero shifts here.
128           uint64_t ShiftAmt = SA->urem(BitWidth);
129           if (II->getIntrinsicID() == Intrinsic::fshr)
130             ShiftAmt = BitWidth - ShiftAmt;
131 
132           if (OperandNo == 0)
133             AB = AOut.lshr(ShiftAmt);
134           else if (OperandNo == 1)
135             AB = AOut.shl(BitWidth - ShiftAmt);
136         }
137         break;
138       }
139       case Intrinsic::umax:
140       case Intrinsic::umin:
141       case Intrinsic::smax:
142       case Intrinsic::smin:
143         // If low bits of result are not demanded, they are also not demanded
144         // for the min/max operands.
145         AB = APInt::getBitsSetFrom(BitWidth, AOut.countr_zero());
146         break;
147       }
148     }
149     break;
150   case Instruction::Add:
151     if (AOut.isMask()) {
152       AB = AOut;
153     } else {
154       ComputeKnownBits(BitWidth, UserI->getOperand(0), UserI->getOperand(1));
155       AB = determineLiveOperandBitsAdd(OperandNo, AOut, Known, Known2);
156     }
157     break;
158   case Instruction::Sub:
159     if (AOut.isMask()) {
160       AB = AOut;
161     } else {
162       ComputeKnownBits(BitWidth, UserI->getOperand(0), UserI->getOperand(1));
163       AB = determineLiveOperandBitsSub(OperandNo, AOut, Known, Known2);
164     }
165     break;
166   case Instruction::Mul:
167     // Find the highest live output bit. We don't need any more input
168     // bits than that (adds, and thus subtracts, ripple only to the
169     // left).
170     AB = APInt::getLowBitsSet(BitWidth, AOut.getActiveBits());
171     break;
172   case Instruction::Shl:
173     if (OperandNo == 0) {
174       const APInt *ShiftAmtC;
175       if (match(UserI->getOperand(1), m_APInt(ShiftAmtC))) {
176         uint64_t ShiftAmt = ShiftAmtC->getLimitedValue(BitWidth - 1);
177         AB = AOut.lshr(ShiftAmt);
178 
179         // If the shift is nuw/nsw, then the high bits are not dead
180         // (because we've promised that they *must* be zero).
181         const auto *S = cast<ShlOperator>(UserI);
182         if (S->hasNoSignedWrap())
183           AB |= APInt::getHighBitsSet(BitWidth, ShiftAmt+1);
184         else if (S->hasNoUnsignedWrap())
185           AB |= APInt::getHighBitsSet(BitWidth, ShiftAmt);
186       }
187     }
188     break;
189   case Instruction::LShr:
190     if (OperandNo == 0) {
191       const APInt *ShiftAmtC;
192       if (match(UserI->getOperand(1), m_APInt(ShiftAmtC))) {
193         uint64_t ShiftAmt = ShiftAmtC->getLimitedValue(BitWidth - 1);
194         AB = AOut.shl(ShiftAmt);
195 
196         // If the shift is exact, then the low bits are not dead
197         // (they must be zero).
198         if (cast<LShrOperator>(UserI)->isExact())
199           AB |= APInt::getLowBitsSet(BitWidth, ShiftAmt);
200       }
201     }
202     break;
203   case Instruction::AShr:
204     if (OperandNo == 0) {
205       const APInt *ShiftAmtC;
206       if (match(UserI->getOperand(1), m_APInt(ShiftAmtC))) {
207         uint64_t ShiftAmt = ShiftAmtC->getLimitedValue(BitWidth - 1);
208         AB = AOut.shl(ShiftAmt);
209         // Because the high input bit is replicated into the
210         // high-order bits of the result, if we need any of those
211         // bits, then we must keep the highest input bit.
212         if ((AOut & APInt::getHighBitsSet(BitWidth, ShiftAmt))
213             .getBoolValue())
214           AB.setSignBit();
215 
216         // If the shift is exact, then the low bits are not dead
217         // (they must be zero).
218         if (cast<AShrOperator>(UserI)->isExact())
219           AB |= APInt::getLowBitsSet(BitWidth, ShiftAmt);
220       }
221     }
222     break;
223   case Instruction::And:
224     AB = AOut;
225 
226     // For bits that are known zero, the corresponding bits in the
227     // other operand are dead (unless they're both zero, in which
228     // case they can't both be dead, so just mark the LHS bits as
229     // dead).
230     ComputeKnownBits(BitWidth, UserI->getOperand(0), UserI->getOperand(1));
231     if (OperandNo == 0)
232       AB &= ~Known2.Zero;
233     else
234       AB &= ~(Known.Zero & ~Known2.Zero);
235     break;
236   case Instruction::Or:
237     AB = AOut;
238 
239     // For bits that are known one, the corresponding bits in the
240     // other operand are dead (unless they're both one, in which
241     // case they can't both be dead, so just mark the LHS bits as
242     // dead).
243     ComputeKnownBits(BitWidth, UserI->getOperand(0), UserI->getOperand(1));
244     if (OperandNo == 0)
245       AB &= ~Known2.One;
246     else
247       AB &= ~(Known.One & ~Known2.One);
248     break;
249   case Instruction::Xor:
250   case Instruction::PHI:
251     AB = AOut;
252     break;
253   case Instruction::Trunc:
254     AB = AOut.zext(BitWidth);
255     break;
256   case Instruction::ZExt:
257     AB = AOut.trunc(BitWidth);
258     break;
259   case Instruction::SExt:
260     AB = AOut.trunc(BitWidth);
261     // Because the high input bit is replicated into the
262     // high-order bits of the result, if we need any of those
263     // bits, then we must keep the highest input bit.
264     if ((AOut & APInt::getHighBitsSet(AOut.getBitWidth(),
265                                       AOut.getBitWidth() - BitWidth))
266         .getBoolValue())
267       AB.setSignBit();
268     break;
269   case Instruction::Select:
270     if (OperandNo != 0)
271       AB = AOut;
272     break;
273   case Instruction::ExtractElement:
274     if (OperandNo == 0)
275       AB = AOut;
276     break;
277   case Instruction::InsertElement:
278   case Instruction::ShuffleVector:
279     if (OperandNo == 0 || OperandNo == 1)
280       AB = AOut;
281     break;
282   }
283 }
284 
285 void DemandedBits::performAnalysis() {
286   if (Analyzed)
287     // Analysis already completed for this function.
288     return;
289   Analyzed = true;
290 
291   Visited.clear();
292   AliveBits.clear();
293   DeadUses.clear();
294 
295   SmallSetVector<Instruction*, 16> Worklist;
296 
297   // Collect the set of "root" instructions that are known live.
298   for (Instruction &I : instructions(F)) {
299     if (!isAlwaysLive(&I))
300       continue;
301 
302     LLVM_DEBUG(dbgs() << "DemandedBits: Root: " << I << "\n");
303     // For integer-valued instructions, set up an initial empty set of alive
304     // bits and add the instruction to the work list. For other instructions
305     // add their operands to the work list (for integer values operands, mark
306     // all bits as live).
307     Type *T = I.getType();
308     if (T->isIntOrIntVectorTy()) {
309       if (AliveBits.try_emplace(&I, T->getScalarSizeInBits(), 0).second)
310         Worklist.insert(&I);
311 
312       continue;
313     }
314 
315     // Non-integer-typed instructions...
316     for (Use &OI : I.operands()) {
317       if (auto *J = dyn_cast<Instruction>(OI)) {
318         Type *T = J->getType();
319         if (T->isIntOrIntVectorTy())
320           AliveBits[J] = APInt::getAllOnes(T->getScalarSizeInBits());
321         else
322           Visited.insert(J);
323         Worklist.insert(J);
324       }
325     }
326     // To save memory, we don't add I to the Visited set here. Instead, we
327     // check isAlwaysLive on every instruction when searching for dead
328     // instructions later (we need to check isAlwaysLive for the
329     // integer-typed instructions anyway).
330   }
331 
332   // Propagate liveness backwards to operands.
333   while (!Worklist.empty()) {
334     Instruction *UserI = Worklist.pop_back_val();
335 
336     LLVM_DEBUG(dbgs() << "DemandedBits: Visiting: " << *UserI);
337     APInt AOut;
338     bool InputIsKnownDead = false;
339     if (UserI->getType()->isIntOrIntVectorTy()) {
340       AOut = AliveBits[UserI];
341       LLVM_DEBUG(dbgs() << " Alive Out: 0x"
342                         << Twine::utohexstr(AOut.getLimitedValue()));
343 
344       // If all bits of the output are dead, then all bits of the input
345       // are also dead.
346       InputIsKnownDead = !AOut && !isAlwaysLive(UserI);
347     }
348     LLVM_DEBUG(dbgs() << "\n");
349 
350     KnownBits Known, Known2;
351     bool KnownBitsComputed = false;
352     // Compute the set of alive bits for each operand. These are anded into the
353     // existing set, if any, and if that changes the set of alive bits, the
354     // operand is added to the work-list.
355     for (Use &OI : UserI->operands()) {
356       // We also want to detect dead uses of arguments, but will only store
357       // demanded bits for instructions.
358       auto *I = dyn_cast<Instruction>(OI);
359       if (!I && !isa<Argument>(OI))
360         continue;
361 
362       Type *T = OI->getType();
363       if (T->isIntOrIntVectorTy()) {
364         unsigned BitWidth = T->getScalarSizeInBits();
365         APInt AB = APInt::getAllOnes(BitWidth);
366         if (InputIsKnownDead) {
367           AB = APInt(BitWidth, 0);
368         } else {
369           // Bits of each operand that are used to compute alive bits of the
370           // output are alive, all others are dead.
371           determineLiveOperandBits(UserI, OI, OI.getOperandNo(), AOut, AB,
372                                    Known, Known2, KnownBitsComputed);
373 
374           // Keep track of uses which have no demanded bits.
375           if (AB.isZero())
376             DeadUses.insert(&OI);
377           else
378             DeadUses.erase(&OI);
379         }
380 
381         if (I) {
382           // If we've added to the set of alive bits (or the operand has not
383           // been previously visited), then re-queue the operand to be visited
384           // again.
385           auto Res = AliveBits.try_emplace(I);
386           if (Res.second || (AB |= Res.first->second) != Res.first->second) {
387             Res.first->second = std::move(AB);
388             Worklist.insert(I);
389           }
390         }
391       } else if (I && Visited.insert(I).second) {
392         Worklist.insert(I);
393       }
394     }
395   }
396 }
397 
398 APInt DemandedBits::getDemandedBits(Instruction *I) {
399   performAnalysis();
400 
401   auto Found = AliveBits.find(I);
402   if (Found != AliveBits.end())
403     return Found->second;
404 
405   const DataLayout &DL = I->getDataLayout();
406   return APInt::getAllOnes(DL.getTypeSizeInBits(I->getType()->getScalarType()));
407 }
408 
409 APInt DemandedBits::getDemandedBits(Use *U) {
410   Type *T = (*U)->getType();
411   auto *UserI = cast<Instruction>(U->getUser());
412   const DataLayout &DL = UserI->getDataLayout();
413   unsigned BitWidth = DL.getTypeSizeInBits(T->getScalarType());
414 
415   // We only track integer uses, everything else produces a mask with all bits
416   // set
417   if (!T->isIntOrIntVectorTy())
418     return APInt::getAllOnes(BitWidth);
419 
420   if (isUseDead(U))
421     return APInt(BitWidth, 0);
422 
423   performAnalysis();
424 
425   APInt AOut = getDemandedBits(UserI);
426   APInt AB = APInt::getAllOnes(BitWidth);
427   KnownBits Known, Known2;
428   bool KnownBitsComputed = false;
429 
430   determineLiveOperandBits(UserI, *U, U->getOperandNo(), AOut, AB, Known,
431                            Known2, KnownBitsComputed);
432 
433   return AB;
434 }
435 
436 bool DemandedBits::isInstructionDead(Instruction *I) {
437   performAnalysis();
438 
439   return !Visited.count(I) && !AliveBits.contains(I) && !isAlwaysLive(I);
440 }
441 
442 bool DemandedBits::isUseDead(Use *U) {
443   // We only track integer uses, everything else is assumed live.
444   if (!(*U)->getType()->isIntOrIntVectorTy())
445     return false;
446 
447   // Uses by always-live instructions are never dead.
448   auto *UserI = cast<Instruction>(U->getUser());
449   if (isAlwaysLive(UserI))
450     return false;
451 
452   performAnalysis();
453   if (DeadUses.count(U))
454     return true;
455 
456   // If no output bits are demanded, no input bits are demanded and the use
457   // is dead. These uses might not be explicitly present in the DeadUses map.
458   if (UserI->getType()->isIntOrIntVectorTy()) {
459     auto Found = AliveBits.find(UserI);
460     if (Found != AliveBits.end() && Found->second.isZero())
461       return true;
462   }
463 
464   return false;
465 }
466 
467 void DemandedBits::print(raw_ostream &OS) {
468   auto PrintDB = [&](const Instruction *I, const APInt &A, Value *V = nullptr) {
469     OS << "DemandedBits: 0x" << Twine::utohexstr(A.getLimitedValue())
470        << " for ";
471     if (V) {
472       V->printAsOperand(OS, false);
473       OS << " in ";
474     }
475     OS << *I << '\n';
476   };
477 
478   OS << "Printing analysis 'Demanded Bits Analysis' for function '" << F.getName() << "':\n";
479   performAnalysis();
480   for (auto &KV : AliveBits) {
481     Instruction *I = KV.first;
482     PrintDB(I, KV.second);
483 
484     for (Use &OI : I->operands()) {
485       PrintDB(I, getDemandedBits(&OI), OI);
486     }
487   }
488 }
489 
490 static APInt determineLiveOperandBitsAddCarry(unsigned OperandNo,
491                                               const APInt &AOut,
492                                               const KnownBits &LHS,
493                                               const KnownBits &RHS,
494                                               bool CarryZero, bool CarryOne) {
495   assert(!(CarryZero && CarryOne) &&
496          "Carry can't be zero and one at the same time");
497 
498   // The following check should be done by the caller, as it also indicates
499   // that LHS and RHS don't need to be computed.
500   //
501   // if (AOut.isMask())
502   //   return AOut;
503 
504   // Boundary bits' carry out is unaffected by their carry in.
505   APInt Bound = (LHS.Zero & RHS.Zero) | (LHS.One & RHS.One);
506 
507   // First, the alive carry bits are determined from the alive output bits:
508   // Let demand ripple to the right but only up to any set bit in Bound.
509   //   AOut         = -1----
510   //   Bound        = ----1-
511   //   ACarry&~AOut = --111-
512   APInt RBound = Bound.reverseBits();
513   APInt RAOut = AOut.reverseBits();
514   APInt RProp = RAOut + (RAOut | ~RBound);
515   APInt RACarry = RProp ^ ~RBound;
516   APInt ACarry = RACarry.reverseBits();
517 
518   // Then, the alive input bits are determined from the alive carry bits:
519   APInt NeededToMaintainCarryZero;
520   APInt NeededToMaintainCarryOne;
521   if (OperandNo == 0) {
522     NeededToMaintainCarryZero = LHS.Zero | ~RHS.Zero;
523     NeededToMaintainCarryOne = LHS.One | ~RHS.One;
524   } else {
525     NeededToMaintainCarryZero = RHS.Zero | ~LHS.Zero;
526     NeededToMaintainCarryOne = RHS.One | ~LHS.One;
527   }
528 
529   // As in computeForAddCarry
530   APInt PossibleSumZero = ~LHS.Zero + ~RHS.Zero + !CarryZero;
531   APInt PossibleSumOne = LHS.One + RHS.One + CarryOne;
532 
533   // The below is simplified from
534   //
535   // APInt CarryKnownZero = ~(PossibleSumZero ^ LHS.Zero ^ RHS.Zero);
536   // APInt CarryKnownOne = PossibleSumOne ^ LHS.One ^ RHS.One;
537   // APInt CarryUnknown = ~(CarryKnownZero | CarryKnownOne);
538   //
539   // APInt NeededToMaintainCarry =
540   //   (CarryKnownZero & NeededToMaintainCarryZero) |
541   //   (CarryKnownOne  & NeededToMaintainCarryOne) |
542   //   CarryUnknown;
543 
544   APInt NeededToMaintainCarry = (~PossibleSumZero | NeededToMaintainCarryZero) &
545                                 (PossibleSumOne | NeededToMaintainCarryOne);
546 
547   APInt AB = AOut | (ACarry & NeededToMaintainCarry);
548   return AB;
549 }
550 
551 APInt DemandedBits::determineLiveOperandBitsAdd(unsigned OperandNo,
552                                                 const APInt &AOut,
553                                                 const KnownBits &LHS,
554                                                 const KnownBits &RHS) {
555   return determineLiveOperandBitsAddCarry(OperandNo, AOut, LHS, RHS, true,
556                                           false);
557 }
558 
559 APInt DemandedBits::determineLiveOperandBitsSub(unsigned OperandNo,
560                                                 const APInt &AOut,
561                                                 const KnownBits &LHS,
562                                                 const KnownBits &RHS) {
563   KnownBits NRHS;
564   NRHS.Zero = RHS.One;
565   NRHS.One = RHS.Zero;
566   return determineLiveOperandBitsAddCarry(OperandNo, AOut, LHS, NRHS, false,
567                                           true);
568 }
569 
570 AnalysisKey DemandedBitsAnalysis::Key;
571 
572 DemandedBits DemandedBitsAnalysis::run(Function &F,
573                                              FunctionAnalysisManager &AM) {
574   auto &AC = AM.getResult<AssumptionAnalysis>(F);
575   auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
576   return DemandedBits(F, AC, DT);
577 }
578 
579 PreservedAnalyses DemandedBitsPrinterPass::run(Function &F,
580                                                FunctionAnalysisManager &AM) {
581   AM.getResult<DemandedBitsAnalysis>(F).print(OS);
582   return PreservedAnalyses::all();
583 }
584