xref: /freebsd/contrib/llvm-project/llvm/lib/CodeGen/GlobalISel/GISelKnownBits.cpp (revision a134ebd6e63f658f2d3d04ac0c60d23bcaa86dd7)
1 //===- lib/CodeGen/GlobalISel/GISelKnownBits.cpp --------------*- C++ *-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 /// Provides analysis for querying information about KnownBits during GISel
10 /// passes.
11 //
12 //===----------------------------------------------------------------------===//
13 #include "llvm/CodeGen/GlobalISel/GISelKnownBits.h"
14 #include "llvm/Analysis/ValueTracking.h"
15 #include "llvm/CodeGen/GlobalISel/Utils.h"
16 #include "llvm/CodeGen/MachineFrameInfo.h"
17 #include "llvm/CodeGen/MachineRegisterInfo.h"
18 #include "llvm/CodeGen/TargetLowering.h"
19 #include "llvm/CodeGen/TargetOpcodes.h"
20 
// Tag used both by LLVM_DEBUG output in this file and as the command-line
// name under which the analysis pass is registered below.
#define DEBUG_TYPE "gisel-known-bits"

using namespace llvm;

// Pass identifier for GISelKnownBitsAnalysis in the legacy pass registry.
char llvm::GISelKnownBitsAnalysis::ID = 0;

// Register the pass. The trailing (false, true) arguments are the macro's
// cfg-only and is-analysis flags: not a CFG-only pass, but an analysis.
INITIALIZE_PASS_BEGIN(GISelKnownBitsAnalysis, DEBUG_TYPE,
                      "Analysis for ComputingKnownBits", false, true)
INITIALIZE_PASS_END(GISelKnownBitsAnalysis, DEBUG_TYPE,
                    "Analysis for ComputingKnownBits", false, true)
31 
32 GISelKnownBits::GISelKnownBits(MachineFunction &MF)
33     : MF(MF), MRI(MF.getRegInfo()), TL(*MF.getSubtarget().getTargetLowering()),
34       DL(MF.getFunction().getParent()->getDataLayout()) {}
35 
36 Align GISelKnownBits::inferAlignmentForFrameIdx(int FrameIdx, int Offset,
37                                                 const MachineFunction &MF) {
38   const MachineFrameInfo &MFI = MF.getFrameInfo();
39   return commonAlignment(Align(MFI.getObjectAlignment(FrameIdx)), Offset);
40   // TODO: How to handle cases with Base + Offset?
41 }
42 
43 MaybeAlign GISelKnownBits::inferPtrAlignment(const MachineInstr &MI) {
44   if (MI.getOpcode() == TargetOpcode::G_FRAME_INDEX) {
45     int FrameIdx = MI.getOperand(1).getIndex();
46     return inferAlignmentForFrameIdx(FrameIdx, 0, *MI.getMF());
47   }
48   return None;
49 }
50 
51 void GISelKnownBits::computeKnownBitsForFrameIndex(Register R, KnownBits &Known,
52                                                    const APInt &DemandedElts,
53                                                    unsigned Depth) {
54   const MachineInstr &MI = *MRI.getVRegDef(R);
55   computeKnownBitsForAlignment(Known, inferPtrAlignment(MI));
56 }
57 
58 void GISelKnownBits::computeKnownBitsForAlignment(KnownBits &Known,
59                                                   MaybeAlign Alignment) {
60   if (Alignment)
61     // The low bits are known zero if the pointer is aligned.
62     Known.Zero.setLowBits(Log2(Alignment));
63 }
64 
65 KnownBits GISelKnownBits::getKnownBits(MachineInstr &MI) {
66   return getKnownBits(MI.getOperand(0).getReg());
67 }
68 
69 KnownBits GISelKnownBits::getKnownBits(Register R) {
70   KnownBits Known;
71   LLT Ty = MRI.getType(R);
72   APInt DemandedElts =
73       Ty.isVector() ? APInt::getAllOnesValue(Ty.getNumElements()) : APInt(1, 1);
74   computeKnownBitsImpl(R, Known, DemandedElts);
75   return Known;
76 }
77 
78 bool GISelKnownBits::signBitIsZero(Register R) {
79   LLT Ty = MRI.getType(R);
80   unsigned BitWidth = Ty.getScalarSizeInBits();
81   return maskedValueIsZero(R, APInt::getSignMask(BitWidth));
82 }
83 
84 APInt GISelKnownBits::getKnownZeroes(Register R) {
85   return getKnownBits(R).Zero;
86 }
87 
88 APInt GISelKnownBits::getKnownOnes(Register R) { return getKnownBits(R).One; }
89 
90 void GISelKnownBits::computeKnownBitsImpl(Register R, KnownBits &Known,
91                                           const APInt &DemandedElts,
92                                           unsigned Depth) {
93   MachineInstr &MI = *MRI.getVRegDef(R);
94   unsigned Opcode = MI.getOpcode();
95   LLT DstTy = MRI.getType(R);
96 
97   // Handle the case where this is called on a register that does not have a
98   // type constraint (i.e. it has a register class constraint instead). This is
99   // unlikely to occur except by looking through copies but it is possible for
100   // the initial register being queried to be in this state.
101   if (!DstTy.isValid()) {
102     Known = KnownBits();
103     return;
104   }
105 
106   unsigned BitWidth = DstTy.getSizeInBits();
107   Known = KnownBits(BitWidth); // Don't know anything
108 
109   if (DstTy.isVector())
110     return; // TODO: Handle vectors.
111 
112   if (Depth == getMaxDepth())
113     return;
114 
115   if (!DemandedElts)
116     return; // No demanded elts, better to assume we don't know anything.
117 
118   KnownBits Known2;
119 
120   switch (Opcode) {
121   default:
122     TL.computeKnownBitsForTargetInstr(*this, R, Known, DemandedElts, MRI,
123                                       Depth);
124     break;
125   case TargetOpcode::COPY: {
126     MachineOperand Dst = MI.getOperand(0);
127     MachineOperand Src = MI.getOperand(1);
128     // Look through trivial copies but don't look through trivial copies of the
129     // form `%1:(s32) = OP %0:gpr32` known-bits analysis is currently unable to
130     // determine the bit width of a register class.
131     //
132     // We can't use NoSubRegister by name as it's defined by each target but
133     // it's always defined to be 0 by tablegen.
134     if (Dst.getSubReg() == 0 /*NoSubRegister*/ && Src.getReg().isVirtual() &&
135         Src.getSubReg() == 0 /*NoSubRegister*/ &&
136         MRI.getType(Src.getReg()).isValid()) {
137       // Don't increment Depth for this one since we didn't do any work.
138       computeKnownBitsImpl(Src.getReg(), Known, DemandedElts, Depth);
139     }
140     break;
141   }
142   case TargetOpcode::G_CONSTANT: {
143     auto CstVal = getConstantVRegVal(R, MRI);
144     if (!CstVal)
145       break;
146     Known.One = *CstVal;
147     Known.Zero = ~Known.One;
148     break;
149   }
150   case TargetOpcode::G_FRAME_INDEX: {
151     computeKnownBitsForFrameIndex(R, Known, DemandedElts);
152     break;
153   }
154   case TargetOpcode::G_SUB: {
155     // If low bits are known to be zero in both operands, then we know they are
156     // going to be 0 in the result. Both addition and complement operations
157     // preserve the low zero bits.
158     computeKnownBitsImpl(MI.getOperand(1).getReg(), Known2, DemandedElts,
159                          Depth + 1);
160     unsigned KnownZeroLow = Known2.countMinTrailingZeros();
161     if (KnownZeroLow == 0)
162       break;
163     computeKnownBitsImpl(MI.getOperand(2).getReg(), Known2, DemandedElts,
164                          Depth + 1);
165     KnownZeroLow = std::min(KnownZeroLow, Known2.countMinTrailingZeros());
166     Known.Zero.setLowBits(KnownZeroLow);
167     break;
168   }
169   case TargetOpcode::G_XOR: {
170     computeKnownBitsImpl(MI.getOperand(2).getReg(), Known, DemandedElts,
171                          Depth + 1);
172     computeKnownBitsImpl(MI.getOperand(1).getReg(), Known2, DemandedElts,
173                          Depth + 1);
174 
175     // Output known-0 bits are known if clear or set in both the LHS & RHS.
176     APInt KnownZeroOut = (Known.Zero & Known2.Zero) | (Known.One & Known2.One);
177     // Output known-1 are known to be set if set in only one of the LHS, RHS.
178     Known.One = (Known.Zero & Known2.One) | (Known.One & Known2.Zero);
179     Known.Zero = KnownZeroOut;
180     break;
181   }
182   case TargetOpcode::G_PTR_ADD: {
183     // G_PTR_ADD is like G_ADD. FIXME: Is this true for all targets?
184     LLT Ty = MRI.getType(MI.getOperand(1).getReg());
185     if (DL.isNonIntegralAddressSpace(Ty.getAddressSpace()))
186       break;
187     LLVM_FALLTHROUGH;
188   }
189   case TargetOpcode::G_ADD: {
190     // Output known-0 bits are known if clear or set in both the low clear bits
191     // common to both LHS & RHS.  For example, 8+(X<<3) is known to have the
192     // low 3 bits clear.
193     // Output known-0 bits are also known if the top bits of each input are
194     // known to be clear. For example, if one input has the top 10 bits clear
195     // and the other has the top 8 bits clear, we know the top 7 bits of the
196     // output must be clear.
197     computeKnownBitsImpl(MI.getOperand(1).getReg(), Known2, DemandedElts,
198                          Depth + 1);
199     unsigned KnownZeroHigh = Known2.countMinLeadingZeros();
200     unsigned KnownZeroLow = Known2.countMinTrailingZeros();
201     computeKnownBitsImpl(MI.getOperand(2).getReg(), Known2, DemandedElts,
202                          Depth + 1);
203     KnownZeroHigh = std::min(KnownZeroHigh, Known2.countMinLeadingZeros());
204     KnownZeroLow = std::min(KnownZeroLow, Known2.countMinTrailingZeros());
205     Known.Zero.setLowBits(KnownZeroLow);
206     if (KnownZeroHigh > 1)
207       Known.Zero.setHighBits(KnownZeroHigh - 1);
208     break;
209   }
210   case TargetOpcode::G_AND: {
211     // If either the LHS or the RHS are Zero, the result is zero.
212     computeKnownBitsImpl(MI.getOperand(2).getReg(), Known, DemandedElts,
213                          Depth + 1);
214     computeKnownBitsImpl(MI.getOperand(1).getReg(), Known2, DemandedElts,
215                          Depth + 1);
216 
217     // Output known-1 bits are only known if set in both the LHS & RHS.
218     Known.One &= Known2.One;
219     // Output known-0 are known to be clear if zero in either the LHS | RHS.
220     Known.Zero |= Known2.Zero;
221     break;
222   }
223   case TargetOpcode::G_OR: {
224     // If either the LHS or the RHS are Zero, the result is zero.
225     computeKnownBitsImpl(MI.getOperand(2).getReg(), Known, DemandedElts,
226                          Depth + 1);
227     computeKnownBitsImpl(MI.getOperand(1).getReg(), Known2, DemandedElts,
228                          Depth + 1);
229 
230     // Output known-0 bits are only known if clear in both the LHS & RHS.
231     Known.Zero &= Known2.Zero;
232     // Output known-1 are known to be set if set in either the LHS | RHS.
233     Known.One |= Known2.One;
234     break;
235   }
236   case TargetOpcode::G_MUL: {
237     computeKnownBitsImpl(MI.getOperand(2).getReg(), Known, DemandedElts,
238                          Depth + 1);
239     computeKnownBitsImpl(MI.getOperand(1).getReg(), Known2, DemandedElts,
240                          Depth + 1);
241     // If low bits are zero in either operand, output low known-0 bits.
242     // Also compute a conservative estimate for high known-0 bits.
243     // More trickiness is possible, but this is sufficient for the
244     // interesting case of alignment computation.
245     unsigned TrailZ =
246         Known.countMinTrailingZeros() + Known2.countMinTrailingZeros();
247     unsigned LeadZ =
248         std::max(Known.countMinLeadingZeros() + Known2.countMinLeadingZeros(),
249                  BitWidth) -
250         BitWidth;
251 
252     Known.resetAll();
253     Known.Zero.setLowBits(std::min(TrailZ, BitWidth));
254     Known.Zero.setHighBits(std::min(LeadZ, BitWidth));
255     break;
256   }
257   case TargetOpcode::G_SELECT: {
258     computeKnownBitsImpl(MI.getOperand(3).getReg(), Known, DemandedElts,
259                          Depth + 1);
260     // If we don't know any bits, early out.
261     if (Known.isUnknown())
262       break;
263     computeKnownBitsImpl(MI.getOperand(2).getReg(), Known2, DemandedElts,
264                          Depth + 1);
265     // Only known if known in both the LHS and RHS.
266     Known.One &= Known2.One;
267     Known.Zero &= Known2.Zero;
268     break;
269   }
270   case TargetOpcode::G_FCMP:
271   case TargetOpcode::G_ICMP: {
272     if (TL.getBooleanContents(DstTy.isVector(),
273                               Opcode == TargetOpcode::G_FCMP) ==
274             TargetLowering::ZeroOrOneBooleanContent &&
275         BitWidth > 1)
276       Known.Zero.setBitsFrom(1);
277     break;
278   }
279   case TargetOpcode::G_SEXT: {
280     computeKnownBitsImpl(MI.getOperand(1).getReg(), Known, DemandedElts,
281                          Depth + 1);
282     // If the sign bit is known to be zero or one, then sext will extend
283     // it to the top bits, else it will just zext.
284     Known = Known.sext(BitWidth);
285     break;
286   }
287   case TargetOpcode::G_ANYEXT: {
288     computeKnownBitsImpl(MI.getOperand(1).getReg(), Known, DemandedElts,
289                          Depth + 1);
290     Known = Known.zext(BitWidth, true /* ExtendedBitsAreKnownZero */);
291     break;
292   }
293   case TargetOpcode::G_LOAD: {
294     if (MI.hasOneMemOperand()) {
295       const MachineMemOperand *MMO = *MI.memoperands_begin();
296       if (const MDNode *Ranges = MMO->getRanges()) {
297         computeKnownBitsFromRangeMetadata(*Ranges, Known);
298       }
299     }
300     break;
301   }
302   case TargetOpcode::G_ZEXTLOAD: {
303     // Everything above the retrieved bits is zero
304     if (MI.hasOneMemOperand())
305       Known.Zero.setBitsFrom((*MI.memoperands_begin())->getSizeInBits());
306     break;
307   }
308   case TargetOpcode::G_ASHR:
309   case TargetOpcode::G_LSHR:
310   case TargetOpcode::G_SHL: {
311     KnownBits RHSKnown;
312     computeKnownBitsImpl(MI.getOperand(2).getReg(), RHSKnown, DemandedElts,
313                          Depth + 1);
314     if (!RHSKnown.isConstant()) {
315       LLVM_DEBUG(
316           MachineInstr *RHSMI = MRI.getVRegDef(MI.getOperand(2).getReg());
317           dbgs() << '[' << Depth << "] Shift not known constant: " << *RHSMI);
318       break;
319     }
320     uint64_t Shift = RHSKnown.getConstant().getZExtValue();
321     LLVM_DEBUG(dbgs() << '[' << Depth << "] Shift is " << Shift << '\n');
322 
323     computeKnownBitsImpl(MI.getOperand(1).getReg(), Known, DemandedElts,
324                          Depth + 1);
325 
326     switch (Opcode) {
327     case TargetOpcode::G_ASHR:
328       Known.Zero = Known.Zero.ashr(Shift);
329       Known.One = Known.One.ashr(Shift);
330       break;
331     case TargetOpcode::G_LSHR:
332       Known.Zero = Known.Zero.lshr(Shift);
333       Known.One = Known.One.lshr(Shift);
334       Known.Zero.setBitsFrom(Known.Zero.getBitWidth() - Shift);
335       break;
336     case TargetOpcode::G_SHL:
337       Known.Zero = Known.Zero.shl(Shift);
338       Known.One = Known.One.shl(Shift);
339       Known.Zero.setBits(0, Shift);
340       break;
341     }
342     break;
343   }
344   case TargetOpcode::G_INTTOPTR:
345   case TargetOpcode::G_PTRTOINT:
346     // Fall through and handle them the same as zext/trunc.
347     LLVM_FALLTHROUGH;
348   case TargetOpcode::G_ZEXT:
349   case TargetOpcode::G_TRUNC: {
350     Register SrcReg = MI.getOperand(1).getReg();
351     LLT SrcTy = MRI.getType(SrcReg);
352     unsigned SrcBitWidth = SrcTy.isPointer()
353                                ? DL.getIndexSizeInBits(SrcTy.getAddressSpace())
354                                : SrcTy.getSizeInBits();
355     assert(SrcBitWidth && "SrcBitWidth can't be zero");
356     Known = Known.zextOrTrunc(SrcBitWidth, true);
357     computeKnownBitsImpl(SrcReg, Known, DemandedElts, Depth + 1);
358     Known = Known.zextOrTrunc(BitWidth, true);
359     if (BitWidth > SrcBitWidth)
360       Known.Zero.setBitsFrom(SrcBitWidth);
361     break;
362   }
363   }
364 
365   assert(!Known.hasConflict() && "Bits known to be one AND zero?");
366   LLVM_DEBUG(dbgs() << "[" << Depth << "] Compute known bits: " << MI << "["
367                     << Depth << "] Computed for: " << MI << "[" << Depth
368                     << "] Known: 0x"
369                     << (Known.Zero | Known.One).toString(16, false) << "\n"
370                     << "[" << Depth << "] Zero: 0x"
371                     << Known.Zero.toString(16, false) << "\n"
372                     << "[" << Depth << "] One:  0x"
373                     << Known.One.toString(16, false) << "\n");
374 }
375 
376 unsigned GISelKnownBits::computeNumSignBits(Register R,
377                                             const APInt &DemandedElts,
378                                             unsigned Depth) {
379   MachineInstr &MI = *MRI.getVRegDef(R);
380   unsigned Opcode = MI.getOpcode();
381 
382   if (Opcode == TargetOpcode::G_CONSTANT)
383     return MI.getOperand(1).getCImm()->getValue().getNumSignBits();
384 
385   if (Depth == getMaxDepth())
386     return 1;
387 
388   if (!DemandedElts)
389     return 1; // No demanded elts, better to assume we don't know anything.
390 
391   LLT DstTy = MRI.getType(R);
392 
393   // Handle the case where this is called on a register that does not have a
394   // type constraint. This is unlikely to occur except by looking through copies
395   // but it is possible for the initial register being queried to be in this
396   // state.
397   if (!DstTy.isValid())
398     return 1;
399 
400   switch (Opcode) {
401   case TargetOpcode::COPY: {
402     MachineOperand &Src = MI.getOperand(1);
403     if (Src.getReg().isVirtual() && Src.getSubReg() == 0 &&
404         MRI.getType(Src.getReg()).isValid()) {
405       // Don't increment Depth for this one since we didn't do any work.
406       return computeNumSignBits(Src.getReg(), DemandedElts, Depth);
407     }
408 
409     return 1;
410   }
411   case TargetOpcode::G_SEXT: {
412     Register Src = MI.getOperand(1).getReg();
413     LLT SrcTy = MRI.getType(Src);
414     unsigned Tmp = DstTy.getScalarSizeInBits() - SrcTy.getScalarSizeInBits();
415     return computeNumSignBits(Src, DemandedElts, Depth + 1) + Tmp;
416   }
417   case TargetOpcode::G_TRUNC: {
418     Register Src = MI.getOperand(1).getReg();
419     LLT SrcTy = MRI.getType(Src);
420 
421     // Check if the sign bits of source go down as far as the truncated value.
422     unsigned DstTyBits = DstTy.getScalarSizeInBits();
423     unsigned NumSrcBits = SrcTy.getScalarSizeInBits();
424     unsigned NumSrcSignBits = computeNumSignBits(Src, DemandedElts, Depth + 1);
425     if (NumSrcSignBits > (NumSrcBits - DstTyBits))
426       return NumSrcSignBits - (NumSrcBits - DstTyBits);
427     break;
428   }
429   default:
430     break;
431   }
432 
433   // TODO: Handle target instructions
434   // TODO: Fall back to known bits
435   return 1;
436 }
437 
438 unsigned GISelKnownBits::computeNumSignBits(Register R, unsigned Depth) {
439   LLT Ty = MRI.getType(R);
440   APInt DemandedElts = Ty.isVector()
441                            ? APInt::getAllOnesValue(Ty.getNumElements())
442                            : APInt(1, 1);
443   return computeNumSignBits(R, DemandedElts, Depth);
444 }
445 
446 void GISelKnownBitsAnalysis::getAnalysisUsage(AnalysisUsage &AU) const {
447   AU.setPreservesAll();
448   MachineFunctionPass::getAnalysisUsage(AU);
449 }
450 
451 bool GISelKnownBitsAnalysis::runOnMachineFunction(MachineFunction &MF) {
452   return false;
453 }
454