xref: /freebsd/contrib/llvm-project/llvm/lib/Analysis/DemandedBits.cpp (revision 2f513db72b034fd5ef7f080b11be5c711c15186a)
1 //===- DemandedBits.cpp - Determine demanded bits -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This pass implements a demanded bits analysis. A demanded bit is one that
10 // contributes to a result; bits that are not demanded can be either zero or
11 // one without affecting control or data flow. For example in this sequence:
12 //
13 //   %1 = add i32 %x, %y
14 //   %2 = trunc i32 %1 to i16
15 //
16 // Only the lowest 16 bits of %1 are demanded; the rest are removed by the
17 // trunc.
18 //
19 //===----------------------------------------------------------------------===//
20 
21 #include "llvm/Analysis/DemandedBits.h"
22 #include "llvm/ADT/APInt.h"
23 #include "llvm/ADT/SetVector.h"
24 #include "llvm/ADT/StringExtras.h"
25 #include "llvm/Analysis/AssumptionCache.h"
26 #include "llvm/Analysis/ValueTracking.h"
27 #include "llvm/IR/BasicBlock.h"
28 #include "llvm/IR/Constants.h"
29 #include "llvm/IR/DataLayout.h"
30 #include "llvm/IR/DerivedTypes.h"
31 #include "llvm/IR/Dominators.h"
32 #include "llvm/IR/InstIterator.h"
33 #include "llvm/IR/InstrTypes.h"
34 #include "llvm/IR/Instruction.h"
35 #include "llvm/IR/IntrinsicInst.h"
36 #include "llvm/IR/Intrinsics.h"
37 #include "llvm/IR/Module.h"
38 #include "llvm/IR/Operator.h"
39 #include "llvm/IR/PassManager.h"
40 #include "llvm/IR/PatternMatch.h"
41 #include "llvm/IR/Type.h"
42 #include "llvm/IR/Use.h"
43 #include "llvm/Pass.h"
44 #include "llvm/Support/Casting.h"
45 #include "llvm/Support/Debug.h"
46 #include "llvm/Support/KnownBits.h"
47 #include "llvm/Support/raw_ostream.h"
48 #include <algorithm>
49 #include <cstdint>
50 
51 using namespace llvm;
52 using namespace llvm::PatternMatch;
53 
54 #define DEBUG_TYPE "demanded-bits"
55 
56 char DemandedBitsWrapperPass::ID = 0;
57 
58 INITIALIZE_PASS_BEGIN(DemandedBitsWrapperPass, "demanded-bits",
59                       "Demanded bits analysis", false, false)
60 INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
61 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
62 INITIALIZE_PASS_END(DemandedBitsWrapperPass, "demanded-bits",
63                     "Demanded bits analysis", false, false)
64 
65 DemandedBitsWrapperPass::DemandedBitsWrapperPass() : FunctionPass(ID) {
66   initializeDemandedBitsWrapperPassPass(*PassRegistry::getPassRegistry());
67 }
68 
69 void DemandedBitsWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const {
70   AU.setPreservesCFG();
71   AU.addRequired<AssumptionCacheTracker>();
72   AU.addRequired<DominatorTreeWrapperPass>();
73   AU.setPreservesAll();
74 }
75 
76 void DemandedBitsWrapperPass::print(raw_ostream &OS, const Module *M) const {
77   DB->print(OS);
78 }
79 
80 static bool isAlwaysLive(Instruction *I) {
81   return I->isTerminator() || isa<DbgInfoIntrinsic>(I) || I->isEHPad() ||
82          I->mayHaveSideEffects();
83 }
84 
85 void DemandedBits::determineLiveOperandBits(
86     const Instruction *UserI, const Value *Val, unsigned OperandNo,
87     const APInt &AOut, APInt &AB, KnownBits &Known, KnownBits &Known2,
88     bool &KnownBitsComputed) {
89   unsigned BitWidth = AB.getBitWidth();
90 
91   // We're called once per operand, but for some instructions, we need to
92   // compute known bits of both operands in order to determine the live bits of
93   // either (when both operands are instructions themselves). We don't,
94   // however, want to do this twice, so we cache the result in APInts that live
95   // in the caller. For the two-relevant-operands case, both operand values are
96   // provided here.
97   auto ComputeKnownBits =
98       [&](unsigned BitWidth, const Value *V1, const Value *V2) {
99         if (KnownBitsComputed)
100           return;
101         KnownBitsComputed = true;
102 
103         const DataLayout &DL = UserI->getModule()->getDataLayout();
104         Known = KnownBits(BitWidth);
105         computeKnownBits(V1, Known, DL, 0, &AC, UserI, &DT);
106 
107         if (V2) {
108           Known2 = KnownBits(BitWidth);
109           computeKnownBits(V2, Known2, DL, 0, &AC, UserI, &DT);
110         }
111       };
112 
113   switch (UserI->getOpcode()) {
114   default: break;
115   case Instruction::Call:
116   case Instruction::Invoke:
117     if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(UserI))
118       switch (II->getIntrinsicID()) {
119       default: break;
120       case Intrinsic::bswap:
121         // The alive bits of the input are the swapped alive bits of
122         // the output.
123         AB = AOut.byteSwap();
124         break;
125       case Intrinsic::bitreverse:
126         // The alive bits of the input are the reversed alive bits of
127         // the output.
128         AB = AOut.reverseBits();
129         break;
130       case Intrinsic::ctlz:
131         if (OperandNo == 0) {
132           // We need some output bits, so we need all bits of the
133           // input to the left of, and including, the leftmost bit
134           // known to be one.
135           ComputeKnownBits(BitWidth, Val, nullptr);
136           AB = APInt::getHighBitsSet(BitWidth,
137                  std::min(BitWidth, Known.countMaxLeadingZeros()+1));
138         }
139         break;
140       case Intrinsic::cttz:
141         if (OperandNo == 0) {
142           // We need some output bits, so we need all bits of the
143           // input to the right of, and including, the rightmost bit
144           // known to be one.
145           ComputeKnownBits(BitWidth, Val, nullptr);
146           AB = APInt::getLowBitsSet(BitWidth,
147                  std::min(BitWidth, Known.countMaxTrailingZeros()+1));
148         }
149         break;
150       case Intrinsic::fshl:
151       case Intrinsic::fshr: {
152         const APInt *SA;
153         if (OperandNo == 2) {
154           // Shift amount is modulo the bitwidth. For powers of two we have
155           // SA % BW == SA & (BW - 1).
156           if (isPowerOf2_32(BitWidth))
157             AB = BitWidth - 1;
158         } else if (match(II->getOperand(2), m_APInt(SA))) {
159           // Normalize to funnel shift left. APInt shifts of BitWidth are well-
160           // defined, so no need to special-case zero shifts here.
161           uint64_t ShiftAmt = SA->urem(BitWidth);
162           if (II->getIntrinsicID() == Intrinsic::fshr)
163             ShiftAmt = BitWidth - ShiftAmt;
164 
165           if (OperandNo == 0)
166             AB = AOut.lshr(ShiftAmt);
167           else if (OperandNo == 1)
168             AB = AOut.shl(BitWidth - ShiftAmt);
169         }
170         break;
171       }
172       }
173     break;
174   case Instruction::Add:
175   case Instruction::Sub:
176   case Instruction::Mul:
177     // Find the highest live output bit. We don't need any more input
178     // bits than that (adds, and thus subtracts, ripple only to the
179     // left).
180     AB = APInt::getLowBitsSet(BitWidth, AOut.getActiveBits());
181     break;
182   case Instruction::Shl:
183     if (OperandNo == 0) {
184       const APInt *ShiftAmtC;
185       if (match(UserI->getOperand(1), m_APInt(ShiftAmtC))) {
186         uint64_t ShiftAmt = ShiftAmtC->getLimitedValue(BitWidth - 1);
187         AB = AOut.lshr(ShiftAmt);
188 
189         // If the shift is nuw/nsw, then the high bits are not dead
190         // (because we've promised that they *must* be zero).
191         const ShlOperator *S = cast<ShlOperator>(UserI);
192         if (S->hasNoSignedWrap())
193           AB |= APInt::getHighBitsSet(BitWidth, ShiftAmt+1);
194         else if (S->hasNoUnsignedWrap())
195           AB |= APInt::getHighBitsSet(BitWidth, ShiftAmt);
196       }
197     }
198     break;
199   case Instruction::LShr:
200     if (OperandNo == 0) {
201       const APInt *ShiftAmtC;
202       if (match(UserI->getOperand(1), m_APInt(ShiftAmtC))) {
203         uint64_t ShiftAmt = ShiftAmtC->getLimitedValue(BitWidth - 1);
204         AB = AOut.shl(ShiftAmt);
205 
206         // If the shift is exact, then the low bits are not dead
207         // (they must be zero).
208         if (cast<LShrOperator>(UserI)->isExact())
209           AB |= APInt::getLowBitsSet(BitWidth, ShiftAmt);
210       }
211     }
212     break;
213   case Instruction::AShr:
214     if (OperandNo == 0) {
215       const APInt *ShiftAmtC;
216       if (match(UserI->getOperand(1), m_APInt(ShiftAmtC))) {
217         uint64_t ShiftAmt = ShiftAmtC->getLimitedValue(BitWidth - 1);
218         AB = AOut.shl(ShiftAmt);
219         // Because the high input bit is replicated into the
220         // high-order bits of the result, if we need any of those
221         // bits, then we must keep the highest input bit.
222         if ((AOut & APInt::getHighBitsSet(BitWidth, ShiftAmt))
223             .getBoolValue())
224           AB.setSignBit();
225 
226         // If the shift is exact, then the low bits are not dead
227         // (they must be zero).
228         if (cast<AShrOperator>(UserI)->isExact())
229           AB |= APInt::getLowBitsSet(BitWidth, ShiftAmt);
230       }
231     }
232     break;
233   case Instruction::And:
234     AB = AOut;
235 
236     // For bits that are known zero, the corresponding bits in the
237     // other operand are dead (unless they're both zero, in which
238     // case they can't both be dead, so just mark the LHS bits as
239     // dead).
240     ComputeKnownBits(BitWidth, UserI->getOperand(0), UserI->getOperand(1));
241     if (OperandNo == 0)
242       AB &= ~Known2.Zero;
243     else
244       AB &= ~(Known.Zero & ~Known2.Zero);
245     break;
246   case Instruction::Or:
247     AB = AOut;
248 
249     // For bits that are known one, the corresponding bits in the
250     // other operand are dead (unless they're both one, in which
251     // case they can't both be dead, so just mark the LHS bits as
252     // dead).
253     ComputeKnownBits(BitWidth, UserI->getOperand(0), UserI->getOperand(1));
254     if (OperandNo == 0)
255       AB &= ~Known2.One;
256     else
257       AB &= ~(Known.One & ~Known2.One);
258     break;
259   case Instruction::Xor:
260   case Instruction::PHI:
261     AB = AOut;
262     break;
263   case Instruction::Trunc:
264     AB = AOut.zext(BitWidth);
265     break;
266   case Instruction::ZExt:
267     AB = AOut.trunc(BitWidth);
268     break;
269   case Instruction::SExt:
270     AB = AOut.trunc(BitWidth);
271     // Because the high input bit is replicated into the
272     // high-order bits of the result, if we need any of those
273     // bits, then we must keep the highest input bit.
274     if ((AOut & APInt::getHighBitsSet(AOut.getBitWidth(),
275                                       AOut.getBitWidth() - BitWidth))
276         .getBoolValue())
277       AB.setSignBit();
278     break;
279   case Instruction::Select:
280     if (OperandNo != 0)
281       AB = AOut;
282     break;
283   case Instruction::ExtractElement:
284     if (OperandNo == 0)
285       AB = AOut;
286     break;
287   case Instruction::InsertElement:
288   case Instruction::ShuffleVector:
289     if (OperandNo == 0 || OperandNo == 1)
290       AB = AOut;
291     break;
292   }
293 }
294 
295 bool DemandedBitsWrapperPass::runOnFunction(Function &F) {
296   auto &AC = getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
297   auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
298   DB.emplace(F, AC, DT);
299   return false;
300 }
301 
302 void DemandedBitsWrapperPass::releaseMemory() {
303   DB.reset();
304 }
305 
306 void DemandedBits::performAnalysis() {
307   if (Analyzed)
308     // Analysis already completed for this function.
309     return;
310   Analyzed = true;
311 
312   Visited.clear();
313   AliveBits.clear();
314   DeadUses.clear();
315 
316   SmallSetVector<Instruction*, 16> Worklist;
317 
318   // Collect the set of "root" instructions that are known live.
319   for (Instruction &I : instructions(F)) {
320     if (!isAlwaysLive(&I))
321       continue;
322 
323     LLVM_DEBUG(dbgs() << "DemandedBits: Root: " << I << "\n");
324     // For integer-valued instructions, set up an initial empty set of alive
325     // bits and add the instruction to the work list. For other instructions
326     // add their operands to the work list (for integer values operands, mark
327     // all bits as live).
328     Type *T = I.getType();
329     if (T->isIntOrIntVectorTy()) {
330       if (AliveBits.try_emplace(&I, T->getScalarSizeInBits(), 0).second)
331         Worklist.insert(&I);
332 
333       continue;
334     }
335 
336     // Non-integer-typed instructions...
337     for (Use &OI : I.operands()) {
338       if (Instruction *J = dyn_cast<Instruction>(OI)) {
339         Type *T = J->getType();
340         if (T->isIntOrIntVectorTy())
341           AliveBits[J] = APInt::getAllOnesValue(T->getScalarSizeInBits());
342         else
343           Visited.insert(J);
344         Worklist.insert(J);
345       }
346     }
347     // To save memory, we don't add I to the Visited set here. Instead, we
348     // check isAlwaysLive on every instruction when searching for dead
349     // instructions later (we need to check isAlwaysLive for the
350     // integer-typed instructions anyway).
351   }
352 
353   // Propagate liveness backwards to operands.
354   while (!Worklist.empty()) {
355     Instruction *UserI = Worklist.pop_back_val();
356 
357     LLVM_DEBUG(dbgs() << "DemandedBits: Visiting: " << *UserI);
358     APInt AOut;
359     bool InputIsKnownDead = false;
360     if (UserI->getType()->isIntOrIntVectorTy()) {
361       AOut = AliveBits[UserI];
362       LLVM_DEBUG(dbgs() << " Alive Out: 0x"
363                         << Twine::utohexstr(AOut.getLimitedValue()));
364 
365       // If all bits of the output are dead, then all bits of the input
366       // are also dead.
367       InputIsKnownDead = !AOut && !isAlwaysLive(UserI);
368     }
369     LLVM_DEBUG(dbgs() << "\n");
370 
371     KnownBits Known, Known2;
372     bool KnownBitsComputed = false;
373     // Compute the set of alive bits for each operand. These are anded into the
374     // existing set, if any, and if that changes the set of alive bits, the
375     // operand is added to the work-list.
376     for (Use &OI : UserI->operands()) {
377       // We also want to detect dead uses of arguments, but will only store
378       // demanded bits for instructions.
379       Instruction *I = dyn_cast<Instruction>(OI);
380       if (!I && !isa<Argument>(OI))
381         continue;
382 
383       Type *T = OI->getType();
384       if (T->isIntOrIntVectorTy()) {
385         unsigned BitWidth = T->getScalarSizeInBits();
386         APInt AB = APInt::getAllOnesValue(BitWidth);
387         if (InputIsKnownDead) {
388           AB = APInt(BitWidth, 0);
389         } else {
390           // Bits of each operand that are used to compute alive bits of the
391           // output are alive, all others are dead.
392           determineLiveOperandBits(UserI, OI, OI.getOperandNo(), AOut, AB,
393                                    Known, Known2, KnownBitsComputed);
394 
395           // Keep track of uses which have no demanded bits.
396           if (AB.isNullValue())
397             DeadUses.insert(&OI);
398           else
399             DeadUses.erase(&OI);
400         }
401 
402         if (I) {
403           // If we've added to the set of alive bits (or the operand has not
404           // been previously visited), then re-queue the operand to be visited
405           // again.
406           auto Res = AliveBits.try_emplace(I);
407           if (Res.second || (AB |= Res.first->second) != Res.first->second) {
408             Res.first->second = std::move(AB);
409             Worklist.insert(I);
410           }
411         }
412       } else if (I && Visited.insert(I).second) {
413         Worklist.insert(I);
414       }
415     }
416   }
417 }
418 
419 APInt DemandedBits::getDemandedBits(Instruction *I) {
420   performAnalysis();
421 
422   auto Found = AliveBits.find(I);
423   if (Found != AliveBits.end())
424     return Found->second;
425 
426   const DataLayout &DL = I->getModule()->getDataLayout();
427   return APInt::getAllOnesValue(
428       DL.getTypeSizeInBits(I->getType()->getScalarType()));
429 }
430 
431 bool DemandedBits::isInstructionDead(Instruction *I) {
432   performAnalysis();
433 
434   return !Visited.count(I) && AliveBits.find(I) == AliveBits.end() &&
435     !isAlwaysLive(I);
436 }
437 
438 bool DemandedBits::isUseDead(Use *U) {
439   // We only track integer uses, everything else is assumed live.
440   if (!(*U)->getType()->isIntOrIntVectorTy())
441     return false;
442 
443   // Uses by always-live instructions are never dead.
444   Instruction *UserI = cast<Instruction>(U->getUser());
445   if (isAlwaysLive(UserI))
446     return false;
447 
448   performAnalysis();
449   if (DeadUses.count(U))
450     return true;
451 
452   // If no output bits are demanded, no input bits are demanded and the use
453   // is dead. These uses might not be explicitly present in the DeadUses map.
454   if (UserI->getType()->isIntOrIntVectorTy()) {
455     auto Found = AliveBits.find(UserI);
456     if (Found != AliveBits.end() && Found->second.isNullValue())
457       return true;
458   }
459 
460   return false;
461 }
462 
463 void DemandedBits::print(raw_ostream &OS) {
464   performAnalysis();
465   for (auto &KV : AliveBits) {
466     OS << "DemandedBits: 0x" << Twine::utohexstr(KV.second.getLimitedValue())
467        << " for " << *KV.first << '\n';
468   }
469 }
470 
471 FunctionPass *llvm::createDemandedBitsWrapperPass() {
472   return new DemandedBitsWrapperPass();
473 }
474 
475 AnalysisKey DemandedBitsAnalysis::Key;
476 
477 DemandedBits DemandedBitsAnalysis::run(Function &F,
478                                              FunctionAnalysisManager &AM) {
479   auto &AC = AM.getResult<AssumptionAnalysis>(F);
480   auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
481   return DemandedBits(F, AC, DT);
482 }
483 
484 PreservedAnalyses DemandedBitsPrinterPass::run(Function &F,
485                                                FunctionAnalysisManager &AM) {
486   AM.getResult<DemandedBitsAnalysis>(F).print(OS);
487   return PreservedAnalyses::all();
488 }
489