xref: /freebsd/contrib/llvm-project/llvm/lib/CodeGen/GlobalISel/GISelKnownBits.cpp (revision a521f2116473fbd8c09db395518f060a27d02334)
1 //===- lib/CodeGen/GlobalISel/GISelKnownBits.cpp --------------*- C++ *-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 /// Provides analysis for querying information about KnownBits during GISel
10 /// passes.
11 //
//===----------------------------------------------------------------------===//
13 #include "llvm/CodeGen/GlobalISel/GISelKnownBits.h"
14 #include "llvm/Analysis/TargetTransformInfo.h"
15 #include "llvm/Analysis/ValueTracking.h"
16 #include "llvm/CodeGen/GlobalISel/Utils.h"
17 #include "llvm/CodeGen/MachineFrameInfo.h"
18 #include "llvm/CodeGen/MachineRegisterInfo.h"
19 #include "llvm/CodeGen/TargetLowering.h"
20 #include "llvm/CodeGen/TargetOpcodes.h"
21 
22 #define DEBUG_TYPE "gisel-known-bits"
23 
24 using namespace llvm;
25 
26 char llvm::GISelKnownBitsAnalysis::ID = 0;
27 
28 INITIALIZE_PASS(GISelKnownBitsAnalysis, DEBUG_TYPE,
29                 "Analysis for ComputingKnownBits", false, true)
30 
31 GISelKnownBits::GISelKnownBits(MachineFunction &MF, unsigned MaxDepth)
32     : MF(MF), MRI(MF.getRegInfo()), TL(*MF.getSubtarget().getTargetLowering()),
33       DL(MF.getFunction().getParent()->getDataLayout()), MaxDepth(MaxDepth) {}
34 
35 Align GISelKnownBits::computeKnownAlignment(Register R, unsigned Depth) {
36   const MachineInstr *MI = MRI.getVRegDef(R);
37   switch (MI->getOpcode()) {
38   case TargetOpcode::COPY:
39     return computeKnownAlignment(MI->getOperand(1).getReg(), Depth);
40   case TargetOpcode::G_FRAME_INDEX: {
41     int FrameIdx = MI->getOperand(1).getIndex();
42     return MF.getFrameInfo().getObjectAlign(FrameIdx);
43   }
44   case TargetOpcode::G_INTRINSIC:
45   case TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS:
46   default:
47     return TL.computeKnownAlignForTargetInstr(*this, R, MRI, Depth + 1);
48   }
49 }
50 
51 KnownBits GISelKnownBits::getKnownBits(MachineInstr &MI) {
52   assert(MI.getNumExplicitDefs() == 1 &&
53          "expected single return generic instruction");
54   return getKnownBits(MI.getOperand(0).getReg());
55 }
56 
57 KnownBits GISelKnownBits::getKnownBits(Register R) {
58   const LLT Ty = MRI.getType(R);
59   APInt DemandedElts =
60       Ty.isVector() ? APInt::getAllOnesValue(Ty.getNumElements()) : APInt(1, 1);
61   return getKnownBits(R, DemandedElts);
62 }
63 
64 KnownBits GISelKnownBits::getKnownBits(Register R, const APInt &DemandedElts,
65                                        unsigned Depth) {
66   // For now, we only maintain the cache during one request.
67   assert(ComputeKnownBitsCache.empty() && "Cache should have been cleared");
68 
69   KnownBits Known;
70   computeKnownBitsImpl(R, Known, DemandedElts);
71   ComputeKnownBitsCache.clear();
72   return Known;
73 }
74 
75 bool GISelKnownBits::signBitIsZero(Register R) {
76   LLT Ty = MRI.getType(R);
77   unsigned BitWidth = Ty.getScalarSizeInBits();
78   return maskedValueIsZero(R, APInt::getSignMask(BitWidth));
79 }
80 
81 APInt GISelKnownBits::getKnownZeroes(Register R) {
82   return getKnownBits(R).Zero;
83 }
84 
85 APInt GISelKnownBits::getKnownOnes(Register R) { return getKnownBits(R).One; }
86 
87 LLVM_ATTRIBUTE_UNUSED static void
88 dumpResult(const MachineInstr &MI, const KnownBits &Known, unsigned Depth) {
89   dbgs() << "[" << Depth << "] Compute known bits: " << MI << "[" << Depth
90          << "] Computed for: " << MI << "[" << Depth << "] Known: 0x"
91          << (Known.Zero | Known.One).toString(16, false) << "\n"
92          << "[" << Depth << "] Zero: 0x" << Known.Zero.toString(16, false)
93          << "\n"
94          << "[" << Depth << "] One:  0x" << Known.One.toString(16, false)
95          << "\n";
96 }
97 
98 void GISelKnownBits::computeKnownBitsImpl(Register R, KnownBits &Known,
99                                           const APInt &DemandedElts,
100                                           unsigned Depth) {
101   MachineInstr &MI = *MRI.getVRegDef(R);
102   unsigned Opcode = MI.getOpcode();
103   LLT DstTy = MRI.getType(R);
104 
105   // Handle the case where this is called on a register that does not have a
106   // type constraint (i.e. it has a register class constraint instead). This is
107   // unlikely to occur except by looking through copies but it is possible for
108   // the initial register being queried to be in this state.
109   if (!DstTy.isValid()) {
110     Known = KnownBits();
111     return;
112   }
113 
114   unsigned BitWidth = DstTy.getSizeInBits();
115   auto CacheEntry = ComputeKnownBitsCache.find(R);
116   if (CacheEntry != ComputeKnownBitsCache.end()) {
117     Known = CacheEntry->second;
118     LLVM_DEBUG(dbgs() << "Cache hit at ");
119     LLVM_DEBUG(dumpResult(MI, Known, Depth));
120     assert(Known.getBitWidth() == BitWidth && "Cache entry size doesn't match");
121     return;
122   }
123   Known = KnownBits(BitWidth); // Don't know anything
124 
125   if (DstTy.isVector())
126     return; // TODO: Handle vectors.
127 
128   // Depth may get bigger than max depth if it gets passed to a different
129   // GISelKnownBits object.
130   // This may happen when say a generic part uses a GISelKnownBits object
131   // with some max depth, but then we hit TL.computeKnownBitsForTargetInstr
132   // which creates a new GISelKnownBits object with a different and smaller
133   // depth. If we just check for equality, we would never exit if the depth
134   // that is passed down to the target specific GISelKnownBits object is
135   // already bigger than its max depth.
136   if (Depth >= getMaxDepth())
137     return;
138 
139   if (!DemandedElts)
140     return; // No demanded elts, better to assume we don't know anything.
141 
142   KnownBits Known2;
143 
144   switch (Opcode) {
145   default:
146     TL.computeKnownBitsForTargetInstr(*this, R, Known, DemandedElts, MRI,
147                                       Depth);
148     break;
149   case TargetOpcode::COPY:
150   case TargetOpcode::G_PHI:
151   case TargetOpcode::PHI: {
152     Known.One = APInt::getAllOnesValue(BitWidth);
153     Known.Zero = APInt::getAllOnesValue(BitWidth);
154     // Destination registers should not have subregisters at this
155     // point of the pipeline, otherwise the main live-range will be
156     // defined more than once, which is against SSA.
157     assert(MI.getOperand(0).getSubReg() == 0 && "Is this code in SSA?");
158     // Record in the cache that we know nothing for MI.
159     // This will get updated later and in the meantime, if we reach that
160     // phi again, because of a loop, we will cut the search thanks to this
161     // cache entry.
162     // We could actually build up more information on the phi by not cutting
163     // the search, but that additional information is more a side effect
164     // than an intended choice.
165     // Therefore, for now, save on compile time until we derive a proper way
166     // to derive known bits for PHIs within loops.
167     ComputeKnownBitsCache[R] = KnownBits(BitWidth);
168     // PHI's operand are a mix of registers and basic blocks interleaved.
169     // We only care about the register ones.
170     for (unsigned Idx = 1; Idx < MI.getNumOperands(); Idx += 2) {
171       const MachineOperand &Src = MI.getOperand(Idx);
172       Register SrcReg = Src.getReg();
173       // Look through trivial copies and phis but don't look through trivial
174       // copies or phis of the form `%1:(s32) = OP %0:gpr32`, known-bits
175       // analysis is currently unable to determine the bit width of a
176       // register class.
177       //
178       // We can't use NoSubRegister by name as it's defined by each target but
179       // it's always defined to be 0 by tablegen.
180       if (SrcReg.isVirtual() && Src.getSubReg() == 0 /*NoSubRegister*/ &&
181           MRI.getType(SrcReg).isValid()) {
182         // For COPYs we don't do anything, don't increase the depth.
183         computeKnownBitsImpl(SrcReg, Known2, DemandedElts,
184                              Depth + (Opcode != TargetOpcode::COPY));
185         Known.One &= Known2.One;
186         Known.Zero &= Known2.Zero;
187         // If we reach a point where we don't know anything
188         // just stop looking through the operands.
189         if (Known.One == 0 && Known.Zero == 0)
190           break;
191       } else {
192         // We know nothing.
193         Known = KnownBits(BitWidth);
194         break;
195       }
196     }
197     break;
198   }
199   case TargetOpcode::G_CONSTANT: {
200     auto CstVal = getConstantVRegVal(R, MRI);
201     if (!CstVal)
202       break;
203     Known.One = *CstVal;
204     Known.Zero = ~Known.One;
205     break;
206   }
207   case TargetOpcode::G_FRAME_INDEX: {
208     int FrameIdx = MI.getOperand(1).getIndex();
209     TL.computeKnownBitsForFrameIndex(FrameIdx, Known, MF);
210     break;
211   }
212   case TargetOpcode::G_SUB: {
213     computeKnownBitsImpl(MI.getOperand(1).getReg(), Known, DemandedElts,
214                          Depth + 1);
215     computeKnownBitsImpl(MI.getOperand(2).getReg(), Known2, DemandedElts,
216                          Depth + 1);
217     Known = KnownBits::computeForAddSub(/*Add*/ false, /*NSW*/ false, Known,
218                                         Known2);
219     break;
220   }
221   case TargetOpcode::G_XOR: {
222     computeKnownBitsImpl(MI.getOperand(2).getReg(), Known, DemandedElts,
223                          Depth + 1);
224     computeKnownBitsImpl(MI.getOperand(1).getReg(), Known2, DemandedElts,
225                          Depth + 1);
226 
227     Known ^= Known2;
228     break;
229   }
230   case TargetOpcode::G_PTR_ADD: {
231     // G_PTR_ADD is like G_ADD. FIXME: Is this true for all targets?
232     LLT Ty = MRI.getType(MI.getOperand(1).getReg());
233     if (DL.isNonIntegralAddressSpace(Ty.getAddressSpace()))
234       break;
235     LLVM_FALLTHROUGH;
236   }
237   case TargetOpcode::G_ADD: {
238     computeKnownBitsImpl(MI.getOperand(1).getReg(), Known, DemandedElts,
239                          Depth + 1);
240     computeKnownBitsImpl(MI.getOperand(2).getReg(), Known2, DemandedElts,
241                          Depth + 1);
242     Known =
243         KnownBits::computeForAddSub(/*Add*/ true, /*NSW*/ false, Known, Known2);
244     break;
245   }
246   case TargetOpcode::G_AND: {
247     // If either the LHS or the RHS are Zero, the result is zero.
248     computeKnownBitsImpl(MI.getOperand(2).getReg(), Known, DemandedElts,
249                          Depth + 1);
250     computeKnownBitsImpl(MI.getOperand(1).getReg(), Known2, DemandedElts,
251                          Depth + 1);
252 
253     Known &= Known2;
254     break;
255   }
256   case TargetOpcode::G_OR: {
257     // If either the LHS or the RHS are Zero, the result is zero.
258     computeKnownBitsImpl(MI.getOperand(2).getReg(), Known, DemandedElts,
259                          Depth + 1);
260     computeKnownBitsImpl(MI.getOperand(1).getReg(), Known2, DemandedElts,
261                          Depth + 1);
262 
263     Known |= Known2;
264     break;
265   }
266   case TargetOpcode::G_MUL: {
267     computeKnownBitsImpl(MI.getOperand(2).getReg(), Known, DemandedElts,
268                          Depth + 1);
269     computeKnownBitsImpl(MI.getOperand(1).getReg(), Known2, DemandedElts,
270                          Depth + 1);
271     // If low bits are zero in either operand, output low known-0 bits.
272     // Also compute a conservative estimate for high known-0 bits.
273     // More trickiness is possible, but this is sufficient for the
274     // interesting case of alignment computation.
275     unsigned TrailZ =
276         Known.countMinTrailingZeros() + Known2.countMinTrailingZeros();
277     unsigned LeadZ =
278         std::max(Known.countMinLeadingZeros() + Known2.countMinLeadingZeros(),
279                  BitWidth) -
280         BitWidth;
281 
282     Known.resetAll();
283     Known.Zero.setLowBits(std::min(TrailZ, BitWidth));
284     Known.Zero.setHighBits(std::min(LeadZ, BitWidth));
285     break;
286   }
287   case TargetOpcode::G_SELECT: {
288     computeKnownBitsImpl(MI.getOperand(3).getReg(), Known, DemandedElts,
289                          Depth + 1);
290     // If we don't know any bits, early out.
291     if (Known.isUnknown())
292       break;
293     computeKnownBitsImpl(MI.getOperand(2).getReg(), Known2, DemandedElts,
294                          Depth + 1);
295     // Only known if known in both the LHS and RHS.
296     Known.One &= Known2.One;
297     Known.Zero &= Known2.Zero;
298     break;
299   }
300   case TargetOpcode::G_FCMP:
301   case TargetOpcode::G_ICMP: {
302     if (TL.getBooleanContents(DstTy.isVector(),
303                               Opcode == TargetOpcode::G_FCMP) ==
304             TargetLowering::ZeroOrOneBooleanContent &&
305         BitWidth > 1)
306       Known.Zero.setBitsFrom(1);
307     break;
308   }
309   case TargetOpcode::G_SEXT: {
310     computeKnownBitsImpl(MI.getOperand(1).getReg(), Known, DemandedElts,
311                          Depth + 1);
312     // If the sign bit is known to be zero or one, then sext will extend
313     // it to the top bits, else it will just zext.
314     Known = Known.sext(BitWidth);
315     break;
316   }
317   case TargetOpcode::G_ANYEXT: {
318     computeKnownBitsImpl(MI.getOperand(1).getReg(), Known, DemandedElts,
319                          Depth + 1);
320     Known = Known.zext(BitWidth);
321     break;
322   }
323   case TargetOpcode::G_LOAD: {
324     if (MI.hasOneMemOperand()) {
325       const MachineMemOperand *MMO = *MI.memoperands_begin();
326       if (const MDNode *Ranges = MMO->getRanges()) {
327         computeKnownBitsFromRangeMetadata(*Ranges, Known);
328       }
329     }
330     break;
331   }
332   case TargetOpcode::G_ZEXTLOAD: {
333     // Everything above the retrieved bits is zero
334     if (MI.hasOneMemOperand())
335       Known.Zero.setBitsFrom((*MI.memoperands_begin())->getSizeInBits());
336     break;
337   }
338   case TargetOpcode::G_ASHR:
339   case TargetOpcode::G_LSHR:
340   case TargetOpcode::G_SHL: {
341     KnownBits RHSKnown;
342     computeKnownBitsImpl(MI.getOperand(2).getReg(), RHSKnown, DemandedElts,
343                          Depth + 1);
344     if (!RHSKnown.isConstant()) {
345       LLVM_DEBUG(
346           MachineInstr *RHSMI = MRI.getVRegDef(MI.getOperand(2).getReg());
347           dbgs() << '[' << Depth << "] Shift not known constant: " << *RHSMI);
348       break;
349     }
350     uint64_t Shift = RHSKnown.getConstant().getZExtValue();
351     LLVM_DEBUG(dbgs() << '[' << Depth << "] Shift is " << Shift << '\n');
352 
353     computeKnownBitsImpl(MI.getOperand(1).getReg(), Known, DemandedElts,
354                          Depth + 1);
355 
356     switch (Opcode) {
357     case TargetOpcode::G_ASHR:
358       Known.Zero = Known.Zero.ashr(Shift);
359       Known.One = Known.One.ashr(Shift);
360       break;
361     case TargetOpcode::G_LSHR:
362       Known.Zero = Known.Zero.lshr(Shift);
363       Known.One = Known.One.lshr(Shift);
364       Known.Zero.setBitsFrom(Known.Zero.getBitWidth() - Shift);
365       break;
366     case TargetOpcode::G_SHL:
367       Known.Zero = Known.Zero.shl(Shift);
368       Known.One = Known.One.shl(Shift);
369       Known.Zero.setBits(0, Shift);
370       break;
371     }
372     break;
373   }
374   case TargetOpcode::G_INTTOPTR:
375   case TargetOpcode::G_PTRTOINT:
376     // Fall through and handle them the same as zext/trunc.
377     LLVM_FALLTHROUGH;
378   case TargetOpcode::G_ZEXT:
379   case TargetOpcode::G_TRUNC: {
380     Register SrcReg = MI.getOperand(1).getReg();
381     LLT SrcTy = MRI.getType(SrcReg);
382     unsigned SrcBitWidth = SrcTy.isPointer()
383                                ? DL.getIndexSizeInBits(SrcTy.getAddressSpace())
384                                : SrcTy.getSizeInBits();
385     assert(SrcBitWidth && "SrcBitWidth can't be zero");
386     Known = Known.zextOrTrunc(SrcBitWidth);
387     computeKnownBitsImpl(SrcReg, Known, DemandedElts, Depth + 1);
388     Known = Known.zextOrTrunc(BitWidth);
389     if (BitWidth > SrcBitWidth)
390       Known.Zero.setBitsFrom(SrcBitWidth);
391     break;
392   }
393   }
394 
395   assert(!Known.hasConflict() && "Bits known to be one AND zero?");
396   LLVM_DEBUG(dumpResult(MI, Known, Depth));
397 
398   // Update the cache.
399   ComputeKnownBitsCache[R] = Known;
400 }
401 
402 unsigned GISelKnownBits::computeNumSignBits(Register R,
403                                             const APInt &DemandedElts,
404                                             unsigned Depth) {
405   MachineInstr &MI = *MRI.getVRegDef(R);
406   unsigned Opcode = MI.getOpcode();
407 
408   if (Opcode == TargetOpcode::G_CONSTANT)
409     return MI.getOperand(1).getCImm()->getValue().getNumSignBits();
410 
411   if (Depth == getMaxDepth())
412     return 1;
413 
414   if (!DemandedElts)
415     return 1; // No demanded elts, better to assume we don't know anything.
416 
417   LLT DstTy = MRI.getType(R);
418   const unsigned TyBits = DstTy.getScalarSizeInBits();
419 
420   // Handle the case where this is called on a register that does not have a
421   // type constraint. This is unlikely to occur except by looking through copies
422   // but it is possible for the initial register being queried to be in this
423   // state.
424   if (!DstTy.isValid())
425     return 1;
426 
427   unsigned FirstAnswer = 1;
428   switch (Opcode) {
429   case TargetOpcode::COPY: {
430     MachineOperand &Src = MI.getOperand(1);
431     if (Src.getReg().isVirtual() && Src.getSubReg() == 0 &&
432         MRI.getType(Src.getReg()).isValid()) {
433       // Don't increment Depth for this one since we didn't do any work.
434       return computeNumSignBits(Src.getReg(), DemandedElts, Depth);
435     }
436 
437     return 1;
438   }
439   case TargetOpcode::G_SEXT: {
440     Register Src = MI.getOperand(1).getReg();
441     LLT SrcTy = MRI.getType(Src);
442     unsigned Tmp = DstTy.getScalarSizeInBits() - SrcTy.getScalarSizeInBits();
443     return computeNumSignBits(Src, DemandedElts, Depth + 1) + Tmp;
444   }
445   case TargetOpcode::G_SEXTLOAD: {
446     Register Dst = MI.getOperand(0).getReg();
447     LLT Ty = MRI.getType(Dst);
448     // TODO: add vector support
449     if (Ty.isVector())
450       break;
451     if (MI.hasOneMemOperand())
452       return Ty.getSizeInBits() - (*MI.memoperands_begin())->getSizeInBits();
453     break;
454   }
455   case TargetOpcode::G_TRUNC: {
456     Register Src = MI.getOperand(1).getReg();
457     LLT SrcTy = MRI.getType(Src);
458 
459     // Check if the sign bits of source go down as far as the truncated value.
460     unsigned DstTyBits = DstTy.getScalarSizeInBits();
461     unsigned NumSrcBits = SrcTy.getScalarSizeInBits();
462     unsigned NumSrcSignBits = computeNumSignBits(Src, DemandedElts, Depth + 1);
463     if (NumSrcSignBits > (NumSrcBits - DstTyBits))
464       return NumSrcSignBits - (NumSrcBits - DstTyBits);
465     break;
466   }
467   case TargetOpcode::G_INTRINSIC:
468   case TargetOpcode::G_INTRINSIC_W_SIDE_EFFECTS:
469   default: {
470     unsigned NumBits =
471       TL.computeNumSignBitsForTargetInstr(*this, R, DemandedElts, MRI, Depth);
472     if (NumBits > 1)
473       FirstAnswer = std::max(FirstAnswer, NumBits);
474     break;
475   }
476   }
477 
478   // Finally, if we can prove that the top bits of the result are 0's or 1's,
479   // use this information.
480   KnownBits Known = getKnownBits(R, DemandedElts, Depth);
481   APInt Mask;
482   if (Known.isNonNegative()) {        // sign bit is 0
483     Mask = Known.Zero;
484   } else if (Known.isNegative()) {  // sign bit is 1;
485     Mask = Known.One;
486   } else {
487     // Nothing known.
488     return FirstAnswer;
489   }
490 
491   // Okay, we know that the sign bit in Mask is set.  Use CLO to determine
492   // the number of identical bits in the top of the input value.
493   Mask <<= Mask.getBitWidth() - TyBits;
494   return std::max(FirstAnswer, Mask.countLeadingOnes());
495 }
496 
497 unsigned GISelKnownBits::computeNumSignBits(Register R, unsigned Depth) {
498   LLT Ty = MRI.getType(R);
499   APInt DemandedElts = Ty.isVector()
500                            ? APInt::getAllOnesValue(Ty.getNumElements())
501                            : APInt(1, 1);
502   return computeNumSignBits(R, DemandedElts, Depth);
503 }
504 
505 void GISelKnownBitsAnalysis::getAnalysisUsage(AnalysisUsage &AU) const {
506   AU.setPreservesAll();
507   MachineFunctionPass::getAnalysisUsage(AU);
508 }
509 
510 bool GISelKnownBitsAnalysis::runOnMachineFunction(MachineFunction &MF) {
511   return false;
512 }
513