1 //===- HashRecognize.cpp ----------------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // The HashRecognize analysis recognizes unoptimized polynomial hash functions
10 // with operations over a Galois field of characteristic 2, also called binary
11 // fields, or GF(2^n): this class of hash functions can be optimized using a
12 // lookup-table-driven implementation, or with target-specific instructions.
13 // Examples:
14 //
15 // 1. Cyclic redundancy check (CRC), which is a polynomial division in GF(2).
16 // 2. Rabin fingerprint, a component of the Rabin-Karp algorithm, which is a
17 // rolling hash polynomial division in GF(2).
18 // 3. Rijndael MixColumns, a step in AES computation, which is a polynomial
// multiplication in GF(2^8).
20 // 4. GHASH, the authentication mechanism in AES Galois/Counter Mode (GCM),
21 // which is a polynomial evaluation in GF(2^128).
22 //
23 // All of them use an irreducible generating polynomial of degree m,
24 //
25 // c_m * x^m + c_(m-1) * x^(m-1) + ... + c_0 * x^0
26 //
// where each coefficient c can take values in GF(2^n), where 2^n is termed
// the order of the Galois field. For GF(2), each coefficient can take values
// either 0 or 1, and the polynomial is simply represented by m+1 bits,
// corresponding to the coefficients. The different variants of CRC are named by
// the degree of the generating polynomial used: so CRC-32 would use a
// polynomial of degree 32.
33 //
// The reason algorithms on GF(2^n) can be optimized with a lookup-table is the
// following: in such fields, polynomial addition and subtraction are identical
// and equivalent to XOR, multiplication of GF(2) coefficients is an AND, and
// division by the only non-zero GF(2) coefficient, 1, is the identity: the XOR
// and AND operations in unoptimized implementations are performed bit-wise,
// and can be optimized to be performed chunk-wise, by interleaving copies of
// the generating polynomial, and storing the pre-computed values in a table.
41 //
// A generating polynomial of degree m (m+1 bits) always has the MSB set, so we
// usually omit it. An example of a 16-bit polynomial is the CRC-16-CCITT
// polynomial:
44 //
45 // (x^16) + x^12 + x^5 + 1 = (1) 0001 0000 0010 0001 = 0x1021
46 //
47 // Transmissions are either in big-endian or little-endian form, and hash
48 // algorithms are written according to this. For example, IEEE 802 and RS-232
49 // specify little-endian transmission.
50 //
51 //===----------------------------------------------------------------------===//
52 //
53 // At the moment, we only recognize the CRC algorithm.
54 // Documentation on CRC32 from the kernel:
55 // https://www.kernel.org/doc/Documentation/crc32.txt
56 //
57 //
58 //===----------------------------------------------------------------------===//
59
#include "llvm/Analysis/HashRecognize.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Analysis/LoopAnalysisManager.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/Analysis/ScalarEvolutionPatternMatch.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/Support/KnownBits.h"
69
70 using namespace llvm;
71 using namespace PatternMatch;
72 using namespace SCEVPatternMatch;
73
74 #define DEBUG_TYPE "hash-recognize"
75
76 // KnownBits for a PHI node. There are at most two PHI nodes, corresponding to
77 // the Simple Recurrence and Conditional Recurrence. The IndVar PHI is not
78 // relevant.
79 using KnownPhiMap = SmallDenseMap<const PHINode *, KnownBits, 2>;
80
81 // A pair of a PHI node along with its incoming value from within a loop.
82 using PhiStepPair = std::pair<const PHINode *, const Instruction *>;
83
84 /// A much simpler version of ValueTracking, in that it computes KnownBits of
85 /// values, except that it computes the evolution of KnownBits in a loop with a
86 /// given trip count, and predication is specialized for a significant-bit
87 /// check.
88 class ValueEvolution {
89 const unsigned TripCount;
90 const bool ByteOrderSwapped;
91 APInt GenPoly;
92 StringRef ErrStr;
93
94 // Compute the KnownBits of a BinaryOperator.
95 KnownBits computeBinOp(const BinaryOperator *I);
96
97 // Compute the KnownBits of an Instruction.
98 KnownBits computeInstr(const Instruction *I);
99
100 // Compute the KnownBits of a Value.
101 KnownBits compute(const Value *V);
102
103 public:
104 // ValueEvolution is meant to be constructed with the TripCount of the loop,
105 // and whether the polynomial algorithm is big-endian, for the significant-bit
106 // check.
107 ValueEvolution(unsigned TripCount, bool ByteOrderSwapped);
108
109 // Given a list of PHI nodes along with their incoming value from within the
110 // loop, computeEvolutions computes the KnownBits of each of the PHI nodes on
111 // the final iteration. Returns true on success and false on error.
112 bool computeEvolutions(ArrayRef<PhiStepPair> PhiEvolutions);
113
114 // In case ValueEvolution encounters an error, this is meant to be used for a
115 // precise error message.
getError() const116 StringRef getError() const { return ErrStr; }
117
118 // The computed KnownBits for each PHI node, which is populated after
119 // computeEvolutions is called.
120 KnownPhiMap KnownPhis;
121 };
122
ValueEvolution(unsigned TripCount,bool ByteOrderSwapped)123 ValueEvolution::ValueEvolution(unsigned TripCount, bool ByteOrderSwapped)
124 : TripCount(TripCount), ByteOrderSwapped(ByteOrderSwapped) {}
125
computeBinOp(const BinaryOperator * I)126 KnownBits ValueEvolution::computeBinOp(const BinaryOperator *I) {
127 KnownBits KnownL(compute(I->getOperand(0)));
128 KnownBits KnownR(compute(I->getOperand(1)));
129
130 switch (I->getOpcode()) {
131 case Instruction::BinaryOps::And:
132 return KnownL & KnownR;
133 case Instruction::BinaryOps::Or:
134 return KnownL | KnownR;
135 case Instruction::BinaryOps::Xor:
136 return KnownL ^ KnownR;
137 case Instruction::BinaryOps::Shl: {
138 auto *OBO = cast<OverflowingBinaryOperator>(I);
139 return KnownBits::shl(KnownL, KnownR, OBO->hasNoUnsignedWrap(),
140 OBO->hasNoSignedWrap());
141 }
142 case Instruction::BinaryOps::LShr:
143 return KnownBits::lshr(KnownL, KnownR);
144 case Instruction::BinaryOps::AShr:
145 return KnownBits::ashr(KnownL, KnownR);
146 case Instruction::BinaryOps::Add: {
147 auto *OBO = cast<OverflowingBinaryOperator>(I);
148 return KnownBits::add(KnownL, KnownR, OBO->hasNoUnsignedWrap(),
149 OBO->hasNoSignedWrap());
150 }
151 case Instruction::BinaryOps::Sub: {
152 auto *OBO = cast<OverflowingBinaryOperator>(I);
153 return KnownBits::sub(KnownL, KnownR, OBO->hasNoUnsignedWrap(),
154 OBO->hasNoSignedWrap());
155 }
156 case Instruction::BinaryOps::Mul: {
157 Value *Op0 = I->getOperand(0);
158 Value *Op1 = I->getOperand(1);
159 bool SelfMultiply = Op0 == Op1 && isGuaranteedNotToBeUndef(Op0);
160 return KnownBits::mul(KnownL, KnownR, SelfMultiply);
161 }
162 case Instruction::BinaryOps::UDiv:
163 return KnownBits::udiv(KnownL, KnownR);
164 case Instruction::BinaryOps::SDiv:
165 return KnownBits::sdiv(KnownL, KnownR);
166 case Instruction::BinaryOps::URem:
167 return KnownBits::urem(KnownL, KnownR);
168 case Instruction::BinaryOps::SRem:
169 return KnownBits::srem(KnownL, KnownR);
170 default:
171 ErrStr = "Unknown BinaryOperator";
172 unsigned BitWidth = I->getType()->getScalarSizeInBits();
173 return {BitWidth};
174 }
175 }
176
computeInstr(const Instruction * I)177 KnownBits ValueEvolution::computeInstr(const Instruction *I) {
178 unsigned BitWidth = I->getType()->getScalarSizeInBits();
179
180 // We look up in the map that contains the KnownBits of the PHI from the
181 // previous iteration.
182 if (const PHINode *P = dyn_cast<PHINode>(I))
183 return KnownPhis.lookup_or(P, BitWidth);
184
185 // Compute the KnownBits for a Select(Cmp()), forcing it to take the branch
186 // that is predicated on the (least|most)-significant-bit check.
187 CmpPredicate Pred;
188 Value *L, *R, *TV, *FV;
189 if (match(I, m_Select(m_ICmp(Pred, m_Value(L), m_Value(R)), m_Value(TV),
190 m_Value(FV)))) {
191 // We need to check LCR against [0, 2) in the little-endian case, because
192 // the RCR check is insufficient: it is simply [0, 1).
193 if (!ByteOrderSwapped) {
194 KnownBits KnownL = compute(L);
195 unsigned ICmpBW = KnownL.getBitWidth();
196 auto LCR = ConstantRange::fromKnownBits(KnownL, false);
197 auto CheckLCR = ConstantRange(APInt::getZero(ICmpBW), APInt(ICmpBW, 2));
198 if (LCR != CheckLCR) {
199 ErrStr = "Bad LHS of significant-bit-check";
200 return {BitWidth};
201 }
202 }
203
204 // Check that the predication is on (most|least) significant bit.
205 KnownBits KnownR = compute(R);
206 unsigned ICmpBW = KnownR.getBitWidth();
207 auto RCR = ConstantRange::fromKnownBits(KnownR, false);
208 auto AllowedR = ConstantRange::makeAllowedICmpRegion(Pred, RCR);
209 ConstantRange CheckRCR(APInt::getZero(ICmpBW),
210 ByteOrderSwapped ? APInt::getSignedMinValue(ICmpBW)
211 : APInt(ICmpBW, 1));
212 if (AllowedR == CheckRCR)
213 return compute(TV);
214 if (AllowedR.inverse() == CheckRCR)
215 return compute(FV);
216
217 ErrStr = "Bad RHS of significant-bit-check";
218 return {BitWidth};
219 }
220
221 if (auto *BO = dyn_cast<BinaryOperator>(I))
222 return computeBinOp(BO);
223
224 switch (I->getOpcode()) {
225 case Instruction::CastOps::Trunc:
226 return compute(I->getOperand(0)).trunc(BitWidth);
227 case Instruction::CastOps::ZExt:
228 return compute(I->getOperand(0)).zext(BitWidth);
229 case Instruction::CastOps::SExt:
230 return compute(I->getOperand(0)).sext(BitWidth);
231 default:
232 ErrStr = "Unknown Instruction";
233 return {BitWidth};
234 }
235 }
236
compute(const Value * V)237 KnownBits ValueEvolution::compute(const Value *V) {
238 if (auto *CI = dyn_cast<ConstantInt>(V))
239 return KnownBits::makeConstant(CI->getValue());
240
241 if (auto *I = dyn_cast<Instruction>(V))
242 return computeInstr(I);
243
244 ErrStr = "Unknown Value";
245 unsigned BitWidth = V->getType()->getScalarSizeInBits();
246 return {BitWidth};
247 }
248
computeEvolutions(ArrayRef<PhiStepPair> PhiEvolutions)249 bool ValueEvolution::computeEvolutions(ArrayRef<PhiStepPair> PhiEvolutions) {
250 for (unsigned I = 0; I < TripCount; ++I)
251 for (auto [Phi, Step] : PhiEvolutions)
252 KnownPhis.emplace_or_assign(Phi, computeInstr(Step));
253
254 return ErrStr.empty();
255 }
256
257 /// A structure that can hold either a Simple Recurrence or a Conditional
258 /// Recurrence. Note that in the case of a Simple Recurrence, Step is an operand
259 /// of the BO, while in a Conditional Recurrence, it is a SelectInst.
260 struct RecurrenceInfo {
261 const Loop &L;
262 const PHINode *Phi = nullptr;
263 BinaryOperator *BO = nullptr;
264 Value *Start = nullptr;
265 Value *Step = nullptr;
266 std::optional<APInt> ExtraConst;
267
RecurrenceInfoRecurrenceInfo268 RecurrenceInfo(const Loop &L) : L(L) {}
operator boolRecurrenceInfo269 operator bool() const { return BO; }
270
printRecurrenceInfo271 void print(raw_ostream &OS, unsigned Indent = 0) const {
272 OS.indent(Indent) << "Phi: ";
273 Phi->print(OS);
274 OS << "\n";
275 OS.indent(Indent) << "BinaryOperator: ";
276 BO->print(OS);
277 OS << "\n";
278 OS.indent(Indent) << "Start: ";
279 Start->print(OS);
280 OS << "\n";
281 OS.indent(Indent) << "Step: ";
282 Step->print(OS);
283 OS << "\n";
284 if (ExtraConst) {
285 OS.indent(Indent) << "ExtraConst: ";
286 ExtraConst->print(OS, false);
287 OS << "\n";
288 }
289 }
290
291 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
dumpRecurrenceInfo292 LLVM_DUMP_METHOD void dump() const { print(dbgs()); }
293 #endif
294
295 bool matchSimpleRecurrence(const PHINode *P);
296 bool matchConditionalRecurrence(
297 const PHINode *P,
298 Instruction::BinaryOps BOWithConstOpToMatch = Instruction::BinaryOpsEnd);
299
300 private:
301 BinaryOperator *digRecurrence(
302 Instruction *V,
303 Instruction::BinaryOps BOWithConstOpToMatch = Instruction::BinaryOpsEnd);
304 };
305
306 /// Wraps llvm::matchSimpleRecurrence. Match a simple first order recurrence
307 /// cycle of the form:
308 ///
309 /// loop:
310 /// %rec = phi [%start, %entry], [%BO, %loop]
311 /// ...
312 /// %BO = binop %rec, %step
313 ///
314 /// or
315 ///
316 /// loop:
317 /// %rec = phi [%start, %entry], [%BO, %loop]
318 /// ...
319 /// %BO = binop %step, %rec
320 ///
matchSimpleRecurrence(const PHINode * P)321 bool RecurrenceInfo::matchSimpleRecurrence(const PHINode *P) {
322 Phi = P;
323 return llvm::matchSimpleRecurrence(Phi, BO, Start, Step);
324 }
325
326 /// Digs for a recurrence starting with \p V hitting the PHI node in a use-def
327 /// chain. Used by matchConditionalRecurrence.
328 BinaryOperator *
digRecurrence(Instruction * V,Instruction::BinaryOps BOWithConstOpToMatch)329 RecurrenceInfo::digRecurrence(Instruction *V,
330 Instruction::BinaryOps BOWithConstOpToMatch) {
331 SmallVector<Instruction *> Worklist;
332 Worklist.push_back(V);
333 while (!Worklist.empty()) {
334 Instruction *I = Worklist.pop_back_val();
335
336 // Don't add a PHI's operands to the Worklist.
337 if (isa<PHINode>(I))
338 continue;
339
340 // Find a recurrence over a BinOp, by matching either of its operands
341 // with with the PHINode.
342 if (match(I, m_c_BinOp(m_Value(), m_Specific(Phi))))
343 return cast<BinaryOperator>(I);
344
345 // Bind to ExtraConst, if we match exactly one.
346 if (I->getOpcode() == BOWithConstOpToMatch) {
347 if (ExtraConst)
348 return nullptr;
349 const APInt *C = nullptr;
350 if (match(I, m_c_BinOp(m_APInt(C), m_Value())))
351 ExtraConst = *C;
352 }
353
354 // Continue along the use-def chain.
355 for (Use &U : I->operands())
356 if (auto *UI = dyn_cast<Instruction>(U))
357 if (L.contains(UI))
358 Worklist.push_back(UI);
359 }
360 return nullptr;
361 }
362
363 /// A Conditional Recurrence is a recurrence of the form:
364 ///
365 /// loop:
366 /// %rec = phi [%start, %entry], [%step, %loop]
367 /// ...
368 /// %step = select _, %tv, %fv
369 ///
370 /// where %tv and %fv ultimately end up using %rec via the same %BO instruction,
371 /// after digging through the use-def chain.
372 ///
373 /// ExtraConst is relevant if \p BOWithConstOpToMatch is supplied: when digging
374 /// the use-def chain, a BinOp with opcode \p BOWithConstOpToMatch is matched,
375 /// and ExtraConst is a constant operand of that BinOp. This peculiarity exists,
376 /// because in a CRC algorithm, the \p BOWithConstOpToMatch is an XOR, and the
377 /// ExtraConst ends up being the generating polynomial.
matchConditionalRecurrence(const PHINode * P,Instruction::BinaryOps BOWithConstOpToMatch)378 bool RecurrenceInfo::matchConditionalRecurrence(
379 const PHINode *P, Instruction::BinaryOps BOWithConstOpToMatch) {
380 Phi = P;
381 if (Phi->getNumIncomingValues() != 2)
382 return false;
383
384 for (unsigned Idx = 0; Idx != 2; ++Idx) {
385 Value *FoundStep = Phi->getIncomingValue(Idx);
386 Value *FoundStart = Phi->getIncomingValue(!Idx);
387
388 Instruction *TV, *FV;
389 if (!match(FoundStep,
390 m_Select(m_Cmp(), m_Instruction(TV), m_Instruction(FV))))
391 continue;
392
393 // For a conditional recurrence, both the true and false values of the
394 // select must ultimately end up in the same recurrent BinOp.
395 BinaryOperator *FoundBO = digRecurrence(TV, BOWithConstOpToMatch);
396 BinaryOperator *AltBO = digRecurrence(FV, BOWithConstOpToMatch);
397 if (!FoundBO || FoundBO != AltBO)
398 return false;
399
400 if (BOWithConstOpToMatch != Instruction::BinaryOpsEnd && !ExtraConst) {
401 LLVM_DEBUG(dbgs() << "HashRecognize: Unable to match single BinaryOp "
402 "with constant in conditional recurrence\n");
403 return false;
404 }
405
406 BO = FoundBO;
407 Start = FoundStart;
408 Step = FoundStep;
409 return true;
410 }
411 return false;
412 }
413
414 /// Iterates over all the phis in \p LoopLatch, and attempts to extract a
415 /// Conditional Recurrence and an optional Simple Recurrence.
416 static std::optional<std::pair<RecurrenceInfo, RecurrenceInfo>>
getRecurrences(BasicBlock * LoopLatch,const PHINode * IndVar,const Loop & L)417 getRecurrences(BasicBlock *LoopLatch, const PHINode *IndVar, const Loop &L) {
418 auto Phis = LoopLatch->phis();
419 unsigned NumPhis = std::distance(Phis.begin(), Phis.end());
420 if (NumPhis != 2 && NumPhis != 3)
421 return {};
422
423 RecurrenceInfo SimpleRecurrence(L);
424 RecurrenceInfo ConditionalRecurrence(L);
425 for (PHINode &P : Phis) {
426 if (&P == IndVar)
427 continue;
428 if (!SimpleRecurrence)
429 SimpleRecurrence.matchSimpleRecurrence(&P);
430 if (!ConditionalRecurrence)
431 ConditionalRecurrence.matchConditionalRecurrence(
432 &P, Instruction::BinaryOps::Xor);
433 }
434 if (NumPhis == 3 && (!SimpleRecurrence || !ConditionalRecurrence))
435 return {};
436 return std::make_pair(SimpleRecurrence, ConditionalRecurrence);
437 }
438
PolynomialInfo(unsigned TripCount,Value * LHS,const APInt & RHS,Value * ComputedValue,bool ByteOrderSwapped,Value * LHSAux)439 PolynomialInfo::PolynomialInfo(unsigned TripCount, Value *LHS, const APInt &RHS,
440 Value *ComputedValue, bool ByteOrderSwapped,
441 Value *LHSAux)
442 : TripCount(TripCount), LHS(LHS), RHS(RHS), ComputedValue(ComputedValue),
443 ByteOrderSwapped(ByteOrderSwapped), LHSAux(LHSAux) {}
444
445 /// In the big-endian case, checks the bottom N bits against CheckFn, and that
446 /// the rest are unknown. In the little-endian case, checks the top N bits
447 /// against CheckFn, and that the rest are unknown. Callers usually call this
448 /// function with N = TripCount, and CheckFn checking that the remainder bits of
449 /// the CRC polynomial division are zero.
checkExtractBits(const KnownBits & Known,unsigned N,function_ref<bool (const KnownBits &)> CheckFn,bool ByteOrderSwapped)450 static bool checkExtractBits(const KnownBits &Known, unsigned N,
451 function_ref<bool(const KnownBits &)> CheckFn,
452 bool ByteOrderSwapped) {
453 // Check that the entire thing is a constant.
454 if (N == Known.getBitWidth())
455 return CheckFn(Known.extractBits(N, 0));
456
457 // Check that the {top, bottom} N bits are not unknown and that the {bottom,
458 // top} N bits are known.
459 unsigned BitPos = ByteOrderSwapped ? 0 : Known.getBitWidth() - N;
460 unsigned SwappedBitPos = ByteOrderSwapped ? N : 0;
461 return CheckFn(Known.extractBits(N, BitPos)) &&
462 Known.extractBits(Known.getBitWidth() - N, SwappedBitPos).isUnknown();
463 }
464
465 /// Generate a lookup table of 256 entries by interleaving the generating
466 /// polynomial. The optimization technique of table-lookup for CRC is also
467 /// called the Sarwate algorithm.
genSarwateTable(const APInt & GenPoly,bool ByteOrderSwapped)468 CRCTable HashRecognize::genSarwateTable(const APInt &GenPoly,
469 bool ByteOrderSwapped) {
470 unsigned BW = GenPoly.getBitWidth();
471 CRCTable Table;
472 Table[0] = APInt::getZero(BW);
473
474 if (ByteOrderSwapped) {
475 APInt CRCInit = APInt::getSignedMinValue(BW);
476 for (unsigned I = 1; I < 256; I <<= 1) {
477 CRCInit = CRCInit.shl(1) ^
478 (CRCInit.isSignBitSet() ? GenPoly : APInt::getZero(BW));
479 for (unsigned J = 0; J < I; ++J)
480 Table[I + J] = CRCInit ^ Table[J];
481 }
482 return Table;
483 }
484
485 APInt CRCInit(BW, 1);
486 for (unsigned I = 128; I; I >>= 1) {
487 CRCInit = CRCInit.lshr(1) ^ (CRCInit[0] ? GenPoly : APInt::getZero(BW));
488 for (unsigned J = 0; J < 256; J += (I << 1))
489 Table[I + J] = CRCInit ^ Table[J];
490 }
491 return Table;
492 }
493
494 /// Checks that \p P1 and \p P2 are used together in an XOR in the use-def chain
495 /// of \p SI's condition, ignoring any casts. The purpose of this function is to
496 /// ensure that LHSAux from the SimpleRecurrence is used correctly in the CRC
497 /// computation. We cannot check the correctness of casts at this point, and
498 /// rely on the KnownBits propagation to check correctness of the CRC
499 /// computation.
500 ///
501 /// In other words, it checks for the following pattern:
502 ///
503 /// loop:
504 /// %P1 = phi [_, %entry], [%P1.next, %loop]
505 /// %P2 = phi [_, %entry], [%P2.next, %loop]
506 /// ...
507 /// %xor = xor (CastOrSelf %P1), (CastOrSelf %P2)
508 ///
509 /// where %xor is in the use-def chain of \p SI's condition.
isConditionalOnXorOfPHIs(const SelectInst * SI,const PHINode * P1,const PHINode * P2,const Loop & L)510 static bool isConditionalOnXorOfPHIs(const SelectInst *SI, const PHINode *P1,
511 const PHINode *P2, const Loop &L) {
512 SmallVector<const Instruction *> Worklist;
513
514 // matchConditionalRecurrence has already ensured that the SelectInst's
515 // condition is an Instruction.
516 Worklist.push_back(cast<Instruction>(SI->getCondition()));
517
518 while (!Worklist.empty()) {
519 const Instruction *I = Worklist.pop_back_val();
520
521 // Don't add a PHI's operands to the Worklist.
522 if (isa<PHINode>(I))
523 continue;
524
525 // If we match an XOR of the two PHIs ignoring casts, we're done.
526 if (match(I, m_c_Xor(m_CastOrSelf(m_Specific(P1)),
527 m_CastOrSelf(m_Specific(P2)))))
528 return true;
529
530 // Continue along the use-def chain.
531 for (const Use &U : I->operands())
532 if (auto *UI = dyn_cast<Instruction>(U))
533 if (L.contains(UI))
534 Worklist.push_back(UI);
535 }
536 return false;
537 }
538
539 // Recognizes a multiplication or division by the constant two, using SCEV. By
540 // doing this, we're immune to whether the IR expression is mul/udiv or
541 // equivalently shl/lshr. Return false when it is a UDiv, true when it is a Mul,
542 // and std::nullopt otherwise.
isBigEndianBitShift(Value * V,ScalarEvolution & SE)543 static std::optional<bool> isBigEndianBitShift(Value *V, ScalarEvolution &SE) {
544 if (!V->getType()->isIntegerTy())
545 return {};
546
547 const SCEV *E = SE.getSCEV(V);
548 if (match(E, m_scev_UDiv(m_SCEV(), m_scev_SpecificInt(2))))
549 return false;
550 if (match(E, m_scev_Mul(m_scev_SpecificInt(2), m_SCEV())))
551 return true;
552 return {};
553 }
554
/// The main entry point for analyzing a loop and recognizing the CRC algorithm.
/// Returns a PolynomialInfo on success, and either an ErrBits or a StringRef on
/// failure.
///
/// The structural checks run in this order. Each returns a StringRef reason on
/// failure:
///  - the loop is innermost, has a single block, a latch, an exit block and a
///    canonical induction variable;
///  - the constant trip count TC is non-zero, at most 256 and a multiple of 8;
///  - the latch PHIs form a Conditional Recurrence and, optionally, a Simple
///    Recurrence (see getRecurrences);
///  - every recurrence is a single bit-shift in the same direction, which also
///    determines ByteOrderSwapped;
///  - with a Simple Recurrence, both PHIs have exactly two uses and the select
///    condition is derived from an XOR of the two PHIs;
///  - TC does not exceed the bit width of the data (LHSAux if present, else
///    LHS);
///  - the select that computes the CRC has a user in the exit block.
///
/// Finally, the KnownBits of the PHIs are evolved over TC iterations with
/// ValueEvolution. Any error it records is returned as a StringRef. An ErrBits
/// is returned when the expected slice of the final CRC value is not known to
/// be zero with all other bits unknown (see checkExtractBits).
std::variant<PolynomialInfo, ErrBits, StringRef>
HashRecognize::recognizeCRC() const {
  if (!L.isInnermost())
    return "Loop is not innermost";
  BasicBlock *Latch = L.getLoopLatch();
  BasicBlock *Exit = L.getExitBlock();
  const PHINode *IndVar = L.getCanonicalInductionVariable();
  if (!Latch || !Exit || !IndVar || L.getNumBlocks() != 1)
    return "Loop not in canonical form";
  // A TC of 0 is treated as "no usable constant trip count".
  unsigned TC = SE.getSmallConstantTripCount(&L);
  if (!TC || TC > 256 || TC % 8)
    return "Unable to find a small constant byte-multiple trip count";

  auto R = getRecurrences(Latch, IndVar, L);
  if (!R)
    return "Found stray PHI";
  auto [SimpleRecurrence, ConditionalRecurrence] = *R;
  if (!ConditionalRecurrence)
    return "Unable to find conditional recurrence";

  // Make sure that all recurrences are either all SCEVMul with two or SCEVDiv
  // with two, or in other words, that they're single bit-shifts.
  std::optional<bool> ByteOrderSwapped =
      isBigEndianBitShift(ConditionalRecurrence.BO, SE);
  if (!ByteOrderSwapped)
    return "Loop with non-unit bitshifts";
  if (SimpleRecurrence) {
    if (isBigEndianBitShift(SimpleRecurrence.BO, SE) != ByteOrderSwapped)
      return "Loop with non-unit bitshifts";

    // Ensure that the PHIs have exactly two uses:
    // the bit-shift, and the XOR (or a cast feeding into the XOR).
    if (!ConditionalRecurrence.Phi->hasNUses(2) ||
        !SimpleRecurrence.Phi->hasNUses(2))
      return "Recurrences have stray uses";

    // Check that the SelectInst ConditionalRecurrence.Step is conditional on
    // the XOR of SimpleRecurrence.Phi and ConditionalRecurrence.Phi.
    if (!isConditionalOnXorOfPHIs(cast<SelectInst>(ConditionalRecurrence.Step),
                                  SimpleRecurrence.Phi,
                                  ConditionalRecurrence.Phi, L))
      return "Recurrences not intertwined with XOR";
  }

  // Make sure that the TC doesn't exceed the bitwidth of LHSAux, or LHS.
  Value *LHS = ConditionalRecurrence.Start;
  Value *LHSAux = SimpleRecurrence ? SimpleRecurrence.Start : nullptr;
  if (TC > (LHSAux ? LHSAux->getType()->getIntegerBitWidth()
                   : LHS->getType()->getIntegerBitWidth()))
    return "Loop iterations exceed bitwidth of data";

  // Make sure that the computed value is used in the exit block: this should be
  // true even if it is only really used in an outer loop's exit block, since
  // the loop is in LCSSA form.
  auto *ComputedValue = cast<SelectInst>(ConditionalRecurrence.Step);
  if (none_of(ComputedValue->users(), [Exit](User *U) {
        auto *UI = dyn_cast<Instruction>(U);
        return UI && UI->getParent() == Exit;
      }))
    return "Unable to find use of computed value in loop exit block";

  // getRecurrences requests an XOR constant for the conditional recurrence, so
  // a successful match implies ExtraConst is bound: it is the generating
  // polynomial.
  assert(ConditionalRecurrence.ExtraConst &&
         "Expected ExtraConst in conditional recurrence");
  const APInt &GenPoly = *ConditionalRecurrence.ExtraConst;

  // PhiEvolutions are pairs of PHINodes along with their incoming value from
  // within the loop, which we term as their step. Note that in the case of a
  // Simple Recurrence, Step is an operand of the BO, while in a Conditional
  // Recurrence, it is a SelectInst.
  SmallVector<PhiStepPair, 2> PhiEvolutions;
  PhiEvolutions.emplace_back(ConditionalRecurrence.Phi, ComputedValue);
  if (SimpleRecurrence)
    PhiEvolutions.emplace_back(SimpleRecurrence.Phi, SimpleRecurrence.BO);

  // The PHIs start out with no recorded KnownBits (fully unknown), and are
  // evolved for TC iterations.
  ValueEvolution VE(TC, *ByteOrderSwapped);
  if (!VE.computeEvolutions(PhiEvolutions))
    return VE.getError();
  KnownBits ResultBits = VE.KnownPhis.at(ConditionalRecurrence.Phi);

  // With LHSAux, TC may exceed the CRC width, so clamp the checked slice.
  unsigned N = std::min(TC, ResultBits.getBitWidth());
  auto IsZero = [](const KnownBits &K) { return K.isZero(); };
  // NOTE(review): ErrBits carries TC, although only N = min(TC, bit width) bits
  // were checked; the printed reason can overstate the count when TC exceeds
  // the CRC width.
  if (!checkExtractBits(ResultBits, N, IsZero, *ByteOrderSwapped))
    return ErrBits(ResultBits, TC, *ByteOrderSwapped);

  return PolynomialInfo(TC, LHS, GenPoly, ComputedValue, *ByteOrderSwapped,
                        LHSAux);
}
645
print(raw_ostream & OS) const646 void CRCTable::print(raw_ostream &OS) const {
647 for (unsigned I = 0; I < 256; I++) {
648 (*this)[I].print(OS, false);
649 OS << (I % 16 == 15 ? '\n' : ' ');
650 }
651 }
652
653 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
dump() const654 void CRCTable::dump() const { print(dbgs()); }
655 #endif
656
print(raw_ostream & OS) const657 void HashRecognize::print(raw_ostream &OS) const {
658 if (!L.isInnermost())
659 return;
660 OS << "HashRecognize: Checking a loop in '"
661 << L.getHeader()->getParent()->getName() << "' from " << L.getLocStr()
662 << "\n";
663 auto Ret = recognizeCRC();
664 if (!std::holds_alternative<PolynomialInfo>(Ret)) {
665 OS << "Did not find a hash algorithm\n";
666 if (std::holds_alternative<StringRef>(Ret))
667 OS << "Reason: " << std::get<StringRef>(Ret) << "\n";
668 if (std::holds_alternative<ErrBits>(Ret)) {
669 auto [Actual, Iter, ByteOrderSwapped] = std::get<ErrBits>(Ret);
670 OS << "Reason: Expected " << (ByteOrderSwapped ? "bottom " : "top ")
671 << Iter << " bits zero (";
672 Actual.print(OS);
673 OS << ")\n";
674 }
675 return;
676 }
677
678 auto Info = std::get<PolynomialInfo>(Ret);
679 OS << "Found" << (Info.ByteOrderSwapped ? " big-endian " : " little-endian ")
680 << "CRC-" << Info.RHS.getBitWidth() << " loop with trip count "
681 << Info.TripCount << "\n";
682 OS.indent(2) << "Initial CRC: ";
683 Info.LHS->print(OS);
684 OS << "\n";
685 OS.indent(2) << "Generating polynomial: ";
686 Info.RHS.print(OS, false);
687 OS << "\n";
688 OS.indent(2) << "Computed CRC: ";
689 Info.ComputedValue->print(OS);
690 OS << "\n";
691 if (Info.LHSAux) {
692 OS.indent(2) << "Auxiliary data: ";
693 Info.LHSAux->print(OS);
694 OS << "\n";
695 }
696 OS.indent(2) << "Computed CRC lookup table:\n";
697 genSarwateTable(Info.RHS, Info.ByteOrderSwapped).print(OS);
698 }
699
700 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
dump() const701 void HashRecognize::dump() const { print(dbgs()); }
702 #endif
703
getResult() const704 std::optional<PolynomialInfo> HashRecognize::getResult() const {
705 auto Res = HashRecognize(L, SE).recognizeCRC();
706 if (std::holds_alternative<PolynomialInfo>(Res))
707 return std::get<PolynomialInfo>(Res);
708 return std::nullopt;
709 }
710
HashRecognize(const Loop & L,ScalarEvolution & SE)711 HashRecognize::HashRecognize(const Loop &L, ScalarEvolution &SE)
712 : L(L), SE(SE) {}
713
run(Loop & L,LoopAnalysisManager & AM,LoopStandardAnalysisResults & AR,LPMUpdater &)714 PreservedAnalyses HashRecognizePrinterPass::run(Loop &L,
715 LoopAnalysisManager &AM,
716 LoopStandardAnalysisResults &AR,
717 LPMUpdater &) {
718 HashRecognize(L, AR.SE).print(OS);
719 return PreservedAnalyses::all();
720 }
721