xref: /freebsd/contrib/llvm-project/llvm/lib/Transforms/Vectorize/LoopIdiomVectorize.cpp (revision 700637cbb5e582861067a11aaca4d053546871d2)
1 //===-------- LoopIdiomVectorize.cpp - Loop idiom vectorization -----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This pass recognizes certain loop idioms and transforms them into more
10 // optimized versions of the same loop. When the transformation applies, it
11 // can be a significant performance win.
12 //
13 // We currently support two loops:
14 //
15 // 1. A loop that finds the first mismatched byte in an array and returns the
16 // index, i.e. something like:
17 //
18 //  while (++i != n) {
19 //    if (a[i] != b[i])
20 //      break;
21 //  }
22 //
23 // In this example we can actually vectorize the loop despite the early exit,
24 // although the loop vectorizer does not support it. It requires some extra
25 // checks to deal with the possibility of faulting loads when crossing page
26 // boundaries. However, even with these checks it is still profitable to do the
27 // transformation.
28 //
29 // TODO List:
30 //
31 // * Add support for the inverse case where we scan for a matching element.
32 // * Permit 64-bit induction variable types.
33 // * Recognize loops that increment the IV *after* comparing bytes.
34 // * Allow 32-bit sign-extends of the IV used by the GEP.
35 //
36 // 2. A loop that finds the first matching character in an array among a set of
37 // possible matches, e.g.:
38 //
39 //   for (; first != last; ++first)
40 //     for (s_it = s_first; s_it != s_last; ++s_it)
41 //       if (*first == *s_it)
42 //         return first;
43 //   return last;
44 //
45 // This corresponds to std::find_first_of (for arrays of bytes) from the C++
46 // standard library. This function can be implemented efficiently for targets
47 // that support @llvm.experimental.vector.match. For example, on AArch64 targets
48 // that implement SVE2, this lowers to a MATCH instruction, which enables us to
49 // perform up to 16x16=256 comparisons in one go. This can lead to very
50 // significant speedups.
51 //
52 // TODO:
53 //
54 // * Add support for `find_first_not_of` loops (i.e. with not-equal comparison).
55 // * Make VF a configurable parameter (right now we assume 128-bit vectors).
56 // * Potentially adjust the cost model to let the transformation kick-in even if
57 //   @llvm.experimental.vector.match doesn't have direct support in hardware.
58 //
59 //===----------------------------------------------------------------------===//
60 //
61 // NOTE: This Pass matches really specific loop patterns because it's only
62 // supposed to be a temporary solution until our LoopVectorizer is powerful
63 // enough to vectorize them automatically.
64 //
65 //===----------------------------------------------------------------------===//
66 
67 #include "llvm/Transforms/Vectorize/LoopIdiomVectorize.h"
68 #include "llvm/Analysis/DomTreeUpdater.h"
69 #include "llvm/Analysis/LoopPass.h"
70 #include "llvm/Analysis/TargetTransformInfo.h"
71 #include "llvm/IR/Dominators.h"
72 #include "llvm/IR/IRBuilder.h"
73 #include "llvm/IR/Intrinsics.h"
74 #include "llvm/IR/MDBuilder.h"
75 #include "llvm/IR/PatternMatch.h"
76 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
77 
78 using namespace llvm;
79 using namespace PatternMatch;
80 
81 #define DEBUG_TYPE "loop-idiom-vectorize"
82 
83 static cl::opt<bool> DisableAll("disable-loop-idiom-vectorize-all", cl::Hidden,
84                                 cl::init(false),
85                                 cl::desc("Disable Loop Idiom Vectorize Pass."));
86 
87 static cl::opt<LoopIdiomVectorizeStyle>
88     LITVecStyle("loop-idiom-vectorize-style", cl::Hidden,
89                 cl::desc("The vectorization style for loop idiom transform."),
90                 cl::values(clEnumValN(LoopIdiomVectorizeStyle::Masked, "masked",
91                                       "Use masked vector intrinsics"),
92                            clEnumValN(LoopIdiomVectorizeStyle::Predicated,
93                                       "predicated", "Use VP intrinsics")),
94                 cl::init(LoopIdiomVectorizeStyle::Masked));
95 
96 static cl::opt<bool>
97     DisableByteCmp("disable-loop-idiom-vectorize-bytecmp", cl::Hidden,
98                    cl::init(false),
99                    cl::desc("Proceed with Loop Idiom Vectorize Pass, but do "
100                             "not convert byte-compare loop(s)."));
101 
102 static cl::opt<unsigned>
103     ByteCmpVF("loop-idiom-vectorize-bytecmp-vf", cl::Hidden,
104               cl::desc("The vectorization factor for byte-compare patterns."),
105               cl::init(16));
106 
107 static cl::opt<bool>
108     DisableFindFirstByte("disable-loop-idiom-vectorize-find-first-byte",
109                          cl::Hidden, cl::init(false),
110                          cl::desc("Do not convert find-first-byte loop(s)."));
111 
112 static cl::opt<bool>
113     VerifyLoops("loop-idiom-vectorize-verify", cl::Hidden, cl::init(false),
114                 cl::desc("Verify loops generated Loop Idiom Vectorize Pass."));
115 
116 namespace {
117 class LoopIdiomVectorize {
118   LoopIdiomVectorizeStyle VectorizeStyle;
119   unsigned ByteCompareVF;
120   Loop *CurLoop = nullptr;
121   DominatorTree *DT;
122   LoopInfo *LI;
123   const TargetTransformInfo *TTI;
124   const DataLayout *DL;
125 
126   // Blocks that will be used for inserting vectorized code.
127   BasicBlock *EndBlock = nullptr;
128   BasicBlock *VectorLoopPreheaderBlock = nullptr;
129   BasicBlock *VectorLoopStartBlock = nullptr;
130   BasicBlock *VectorLoopMismatchBlock = nullptr;
131   BasicBlock *VectorLoopIncBlock = nullptr;
132 
133 public:
LoopIdiomVectorize(LoopIdiomVectorizeStyle S,unsigned VF,DominatorTree * DT,LoopInfo * LI,const TargetTransformInfo * TTI,const DataLayout * DL)134   LoopIdiomVectorize(LoopIdiomVectorizeStyle S, unsigned VF, DominatorTree *DT,
135                      LoopInfo *LI, const TargetTransformInfo *TTI,
136                      const DataLayout *DL)
137       : VectorizeStyle(S), ByteCompareVF(VF), DT(DT), LI(LI), TTI(TTI), DL(DL) {
138   }
139 
140   bool run(Loop *L);
141 
142 private:
143   /// \name Countable Loop Idiom Handling
144   /// @{
145 
146   bool runOnCountableLoop();
147   bool runOnLoopBlock(BasicBlock *BB, const SCEV *BECount,
148                       SmallVectorImpl<BasicBlock *> &ExitBlocks);
149 
150   bool recognizeByteCompare();
151 
152   Value *expandFindMismatch(IRBuilder<> &Builder, DomTreeUpdater &DTU,
153                             GetElementPtrInst *GEPA, GetElementPtrInst *GEPB,
154                             Instruction *Index, Value *Start, Value *MaxLen);
155 
156   Value *createMaskedFindMismatch(IRBuilder<> &Builder, DomTreeUpdater &DTU,
157                                   GetElementPtrInst *GEPA,
158                                   GetElementPtrInst *GEPB, Value *ExtStart,
159                                   Value *ExtEnd);
160   Value *createPredicatedFindMismatch(IRBuilder<> &Builder, DomTreeUpdater &DTU,
161                                       GetElementPtrInst *GEPA,
162                                       GetElementPtrInst *GEPB, Value *ExtStart,
163                                       Value *ExtEnd);
164 
165   void transformByteCompare(GetElementPtrInst *GEPA, GetElementPtrInst *GEPB,
166                             PHINode *IndPhi, Value *MaxLen, Instruction *Index,
167                             Value *Start, bool IncIdx, BasicBlock *FoundBB,
168                             BasicBlock *EndBB);
169 
170   bool recognizeFindFirstByte();
171 
172   Value *expandFindFirstByte(IRBuilder<> &Builder, DomTreeUpdater &DTU,
173                              unsigned VF, Type *CharTy, BasicBlock *ExitSucc,
174                              BasicBlock *ExitFail, Value *SearchStart,
175                              Value *SearchEnd, Value *NeedleStart,
176                              Value *NeedleEnd);
177 
178   void transformFindFirstByte(PHINode *IndPhi, unsigned VF, Type *CharTy,
179                               BasicBlock *ExitSucc, BasicBlock *ExitFail,
180                               Value *SearchStart, Value *SearchEnd,
181                               Value *NeedleStart, Value *NeedleEnd);
182   /// @}
183 };
184 } // anonymous namespace
185 
run(Loop & L,LoopAnalysisManager & AM,LoopStandardAnalysisResults & AR,LPMUpdater &)186 PreservedAnalyses LoopIdiomVectorizePass::run(Loop &L, LoopAnalysisManager &AM,
187                                               LoopStandardAnalysisResults &AR,
188                                               LPMUpdater &) {
189   if (DisableAll)
190     return PreservedAnalyses::all();
191 
192   const auto *DL = &L.getHeader()->getDataLayout();
193 
194   LoopIdiomVectorizeStyle VecStyle = VectorizeStyle;
195   if (LITVecStyle.getNumOccurrences())
196     VecStyle = LITVecStyle;
197 
198   unsigned BCVF = ByteCompareVF;
199   if (ByteCmpVF.getNumOccurrences())
200     BCVF = ByteCmpVF;
201 
202   LoopIdiomVectorize LIV(VecStyle, BCVF, &AR.DT, &AR.LI, &AR.TTI, DL);
203   if (!LIV.run(&L))
204     return PreservedAnalyses::all();
205 
206   return PreservedAnalyses::none();
207 }
208 
209 //===----------------------------------------------------------------------===//
210 //
211 //          Implementation of LoopIdiomVectorize
212 //
213 //===----------------------------------------------------------------------===//
214 
run(Loop * L)215 bool LoopIdiomVectorize::run(Loop *L) {
216   CurLoop = L;
217 
218   Function &F = *L->getHeader()->getParent();
219   if (DisableAll || F.hasOptSize())
220     return false;
221 
222   if (F.hasFnAttribute(Attribute::NoImplicitFloat)) {
223     LLVM_DEBUG(dbgs() << DEBUG_TYPE << " is disabled on " << F.getName()
224                       << " due to its NoImplicitFloat attribute");
225     return false;
226   }
227 
228   // If the loop could not be converted to canonical form, it must have an
229   // indirectbr in it, just give up.
230   if (!L->getLoopPreheader())
231     return false;
232 
233   LLVM_DEBUG(dbgs() << DEBUG_TYPE " Scanning: F[" << F.getName() << "] Loop %"
234                     << CurLoop->getHeader()->getName() << "\n");
235 
236   if (recognizeByteCompare())
237     return true;
238 
239   if (recognizeFindFirstByte())
240     return true;
241 
242   return false;
243 }
244 
recognizeByteCompare()245 bool LoopIdiomVectorize::recognizeByteCompare() {
246   // Currently the transformation only works on scalable vector types, although
247   // there is no fundamental reason why it cannot be made to work for fixed
248   // width too.
249 
250   // We also need to know the minimum page size for the target in order to
251   // generate runtime memory checks to ensure the vector version won't fault.
252   if (!TTI->supportsScalableVectors() || !TTI->getMinPageSize().has_value() ||
253       DisableByteCmp)
254     return false;
255 
256   BasicBlock *Header = CurLoop->getHeader();
257 
258   // In LoopIdiomVectorize::run we have already checked that the loop
259   // has a preheader so we can assume it's in a canonical form.
260   if (CurLoop->getNumBackEdges() != 1 || CurLoop->getNumBlocks() != 2)
261     return false;
262 
263   PHINode *PN = dyn_cast<PHINode>(&Header->front());
264   if (!PN || PN->getNumIncomingValues() != 2)
265     return false;
266 
267   auto LoopBlocks = CurLoop->getBlocks();
268   // The first block in the loop should contain only 4 instructions, e.g.
269   //
270   //  while.cond:
271   //   %res.phi = phi i32 [ %start, %ph ], [ %inc, %while.body ]
272   //   %inc = add i32 %res.phi, 1
273   //   %cmp.not = icmp eq i32 %inc, %n
274   //   br i1 %cmp.not, label %while.end, label %while.body
275   //
276   if (LoopBlocks[0]->sizeWithoutDebug() > 4)
277     return false;
278 
279   // The second block should contain 7 instructions, e.g.
280   //
281   // while.body:
282   //   %idx = zext i32 %inc to i64
283   //   %idx.a = getelementptr inbounds i8, ptr %a, i64 %idx
284   //   %load.a = load i8, ptr %idx.a
285   //   %idx.b = getelementptr inbounds i8, ptr %b, i64 %idx
286   //   %load.b = load i8, ptr %idx.b
287   //   %cmp.not.ld = icmp eq i8 %load.a, %load.b
288   //   br i1 %cmp.not.ld, label %while.cond, label %while.end
289   //
290   if (LoopBlocks[1]->sizeWithoutDebug() > 7)
291     return false;
292 
293   // The incoming value to the PHI node from the loop should be an add of 1.
294   Value *StartIdx = nullptr;
295   Instruction *Index = nullptr;
296   if (!CurLoop->contains(PN->getIncomingBlock(0))) {
297     StartIdx = PN->getIncomingValue(0);
298     Index = dyn_cast<Instruction>(PN->getIncomingValue(1));
299   } else {
300     StartIdx = PN->getIncomingValue(1);
301     Index = dyn_cast<Instruction>(PN->getIncomingValue(0));
302   }
303 
304   // Limit to 32-bit types for now
305   if (!Index || !Index->getType()->isIntegerTy(32) ||
306       !match(Index, m_c_Add(m_Specific(PN), m_One())))
307     return false;
308 
309   // If we match the pattern, PN and Index will be replaced with the result of
310   // the cttz.elts intrinsic. If any other instructions are used outside of
311   // the loop, we cannot replace it.
312   for (BasicBlock *BB : LoopBlocks)
313     for (Instruction &I : *BB)
314       if (&I != PN && &I != Index)
315         for (User *U : I.users())
316           if (!CurLoop->contains(cast<Instruction>(U)))
317             return false;
318 
319   // Match the branch instruction for the header
320   Value *MaxLen;
321   BasicBlock *EndBB, *WhileBB;
322   if (!match(Header->getTerminator(),
323              m_Br(m_SpecificICmp(ICmpInst::ICMP_EQ, m_Specific(Index),
324                                  m_Value(MaxLen)),
325                   m_BasicBlock(EndBB), m_BasicBlock(WhileBB))) ||
326       !CurLoop->contains(WhileBB))
327     return false;
328 
329   // WhileBB should contain the pattern of load & compare instructions. Match
330   // the pattern and find the GEP instructions used by the loads.
331   BasicBlock *FoundBB;
332   BasicBlock *TrueBB;
333   Value *LoadA, *LoadB;
334   if (!match(WhileBB->getTerminator(),
335              m_Br(m_SpecificICmp(ICmpInst::ICMP_EQ, m_Value(LoadA),
336                                  m_Value(LoadB)),
337                   m_BasicBlock(TrueBB), m_BasicBlock(FoundBB))) ||
338       !CurLoop->contains(TrueBB))
339     return false;
340 
341   Value *A, *B;
342   if (!match(LoadA, m_Load(m_Value(A))) || !match(LoadB, m_Load(m_Value(B))))
343     return false;
344 
345   LoadInst *LoadAI = cast<LoadInst>(LoadA);
346   LoadInst *LoadBI = cast<LoadInst>(LoadB);
347   if (!LoadAI->isSimple() || !LoadBI->isSimple())
348     return false;
349 
350   GetElementPtrInst *GEPA = dyn_cast<GetElementPtrInst>(A);
351   GetElementPtrInst *GEPB = dyn_cast<GetElementPtrInst>(B);
352 
353   if (!GEPA || !GEPB)
354     return false;
355 
356   Value *PtrA = GEPA->getPointerOperand();
357   Value *PtrB = GEPB->getPointerOperand();
358 
359   // Check we are loading i8 values from two loop invariant pointers
360   if (!CurLoop->isLoopInvariant(PtrA) || !CurLoop->isLoopInvariant(PtrB) ||
361       !GEPA->getResultElementType()->isIntegerTy(8) ||
362       !GEPB->getResultElementType()->isIntegerTy(8) ||
363       !LoadAI->getType()->isIntegerTy(8) ||
364       !LoadBI->getType()->isIntegerTy(8) || PtrA == PtrB)
365     return false;
366 
367   // Check that the index to the GEPs is the index we found earlier
368   if (GEPA->getNumIndices() > 1 || GEPB->getNumIndices() > 1)
369     return false;
370 
371   Value *IdxA = GEPA->getOperand(GEPA->getNumIndices());
372   Value *IdxB = GEPB->getOperand(GEPB->getNumIndices());
373   if (IdxA != IdxB || !match(IdxA, m_ZExt(m_Specific(Index))))
374     return false;
375 
376   // We only ever expect the pre-incremented index value to be used inside the
377   // loop.
378   if (!PN->hasOneUse())
379     return false;
380 
381   // Ensure that when the Found and End blocks are identical the PHIs have the
382   // supported format. We don't currently allow cases like this:
383   // while.cond:
384   //   ...
385   //   br i1 %cmp.not, label %while.end, label %while.body
386   //
387   // while.body:
388   //   ...
389   //   br i1 %cmp.not2, label %while.cond, label %while.end
390   //
391   // while.end:
392   //   %final_ptr = phi ptr [ %c, %while.body ], [ %d, %while.cond ]
393   //
394   // Where the incoming values for %final_ptr are unique and from each of the
395   // loop blocks, but not actually defined in the loop. This requires extra
396   // work setting up the byte.compare block, i.e. by introducing a select to
397   // choose the correct value.
398   // TODO: We could add support for this in future.
399   if (FoundBB == EndBB) {
400     for (PHINode &EndPN : EndBB->phis()) {
401       Value *WhileCondVal = EndPN.getIncomingValueForBlock(Header);
402       Value *WhileBodyVal = EndPN.getIncomingValueForBlock(WhileBB);
403 
404       // The value of the index when leaving the while.cond block is always the
405       // same as the end value (MaxLen) so we permit either. The value when
406       // leaving the while.body block should only be the index. Otherwise for
407       // any other values we only allow ones that are same for both blocks.
408       if (WhileCondVal != WhileBodyVal &&
409           ((WhileCondVal != Index && WhileCondVal != MaxLen) ||
410            (WhileBodyVal != Index)))
411         return false;
412     }
413   }
414 
415   LLVM_DEBUG(dbgs() << "FOUND IDIOM IN LOOP: \n"
416                     << *(EndBB->getParent()) << "\n\n");
417 
418   // The index is incremented before the GEP/Load pair so we need to
419   // add 1 to the start value.
420   transformByteCompare(GEPA, GEPB, PN, MaxLen, Index, StartIdx, /*IncIdx=*/true,
421                        FoundBB, EndBB);
422   return true;
423 }
424 
createMaskedFindMismatch(IRBuilder<> & Builder,DomTreeUpdater & DTU,GetElementPtrInst * GEPA,GetElementPtrInst * GEPB,Value * ExtStart,Value * ExtEnd)425 Value *LoopIdiomVectorize::createMaskedFindMismatch(
426     IRBuilder<> &Builder, DomTreeUpdater &DTU, GetElementPtrInst *GEPA,
427     GetElementPtrInst *GEPB, Value *ExtStart, Value *ExtEnd) {
428   Type *I64Type = Builder.getInt64Ty();
429   Type *ResType = Builder.getInt32Ty();
430   Type *LoadType = Builder.getInt8Ty();
431   Value *PtrA = GEPA->getPointerOperand();
432   Value *PtrB = GEPB->getPointerOperand();
433 
434   ScalableVectorType *PredVTy =
435       ScalableVectorType::get(Builder.getInt1Ty(), ByteCompareVF);
436 
437   Value *InitialPred = Builder.CreateIntrinsic(
438       Intrinsic::get_active_lane_mask, {PredVTy, I64Type}, {ExtStart, ExtEnd});
439 
440   Value *VecLen = Builder.CreateVScale(I64Type);
441   VecLen =
442       Builder.CreateMul(VecLen, ConstantInt::get(I64Type, ByteCompareVF), "",
443                         /*HasNUW=*/true, /*HasNSW=*/true);
444 
445   Value *PFalse = Builder.CreateVectorSplat(PredVTy->getElementCount(),
446                                             Builder.getInt1(false));
447 
448   BranchInst *JumpToVectorLoop = BranchInst::Create(VectorLoopStartBlock);
449   Builder.Insert(JumpToVectorLoop);
450 
451   DTU.applyUpdates({{DominatorTree::Insert, VectorLoopPreheaderBlock,
452                      VectorLoopStartBlock}});
453 
454   // Set up the first vector loop block by creating the PHIs, doing the vector
455   // loads and comparing the vectors.
456   Builder.SetInsertPoint(VectorLoopStartBlock);
457   PHINode *LoopPred = Builder.CreatePHI(PredVTy, 2, "mismatch_vec_loop_pred");
458   LoopPred->addIncoming(InitialPred, VectorLoopPreheaderBlock);
459   PHINode *VectorIndexPhi = Builder.CreatePHI(I64Type, 2, "mismatch_vec_index");
460   VectorIndexPhi->addIncoming(ExtStart, VectorLoopPreheaderBlock);
461   Type *VectorLoadType =
462       ScalableVectorType::get(Builder.getInt8Ty(), ByteCompareVF);
463   Value *Passthru = ConstantInt::getNullValue(VectorLoadType);
464 
465   Value *VectorLhsGep =
466       Builder.CreateGEP(LoadType, PtrA, VectorIndexPhi, "", GEPA->isInBounds());
467   Value *VectorLhsLoad = Builder.CreateMaskedLoad(VectorLoadType, VectorLhsGep,
468                                                   Align(1), LoopPred, Passthru);
469 
470   Value *VectorRhsGep =
471       Builder.CreateGEP(LoadType, PtrB, VectorIndexPhi, "", GEPB->isInBounds());
472   Value *VectorRhsLoad = Builder.CreateMaskedLoad(VectorLoadType, VectorRhsGep,
473                                                   Align(1), LoopPred, Passthru);
474 
475   Value *VectorMatchCmp = Builder.CreateICmpNE(VectorLhsLoad, VectorRhsLoad);
476   VectorMatchCmp = Builder.CreateSelect(LoopPred, VectorMatchCmp, PFalse);
477   Value *VectorMatchHasActiveLanes = Builder.CreateOrReduce(VectorMatchCmp);
478   BranchInst *VectorEarlyExit = BranchInst::Create(
479       VectorLoopMismatchBlock, VectorLoopIncBlock, VectorMatchHasActiveLanes);
480   Builder.Insert(VectorEarlyExit);
481 
482   DTU.applyUpdates(
483       {{DominatorTree::Insert, VectorLoopStartBlock, VectorLoopMismatchBlock},
484        {DominatorTree::Insert, VectorLoopStartBlock, VectorLoopIncBlock}});
485 
486   // Increment the index counter and calculate the predicate for the next
487   // iteration of the loop. We branch back to the start of the loop if there
488   // is at least one active lane.
489   Builder.SetInsertPoint(VectorLoopIncBlock);
490   Value *NewVectorIndexPhi =
491       Builder.CreateAdd(VectorIndexPhi, VecLen, "",
492                         /*HasNUW=*/true, /*HasNSW=*/true);
493   VectorIndexPhi->addIncoming(NewVectorIndexPhi, VectorLoopIncBlock);
494   Value *NewPred =
495       Builder.CreateIntrinsic(Intrinsic::get_active_lane_mask,
496                               {PredVTy, I64Type}, {NewVectorIndexPhi, ExtEnd});
497   LoopPred->addIncoming(NewPred, VectorLoopIncBlock);
498 
499   Value *PredHasActiveLanes =
500       Builder.CreateExtractElement(NewPred, uint64_t(0));
501   BranchInst *VectorLoopBranchBack =
502       BranchInst::Create(VectorLoopStartBlock, EndBlock, PredHasActiveLanes);
503   Builder.Insert(VectorLoopBranchBack);
504 
505   DTU.applyUpdates(
506       {{DominatorTree::Insert, VectorLoopIncBlock, VectorLoopStartBlock},
507        {DominatorTree::Insert, VectorLoopIncBlock, EndBlock}});
508 
509   // If we found a mismatch then we need to calculate which lane in the vector
510   // had a mismatch and add that on to the current loop index.
511   Builder.SetInsertPoint(VectorLoopMismatchBlock);
512   PHINode *FoundPred = Builder.CreatePHI(PredVTy, 1, "mismatch_vec_found_pred");
513   FoundPred->addIncoming(VectorMatchCmp, VectorLoopStartBlock);
514   PHINode *LastLoopPred =
515       Builder.CreatePHI(PredVTy, 1, "mismatch_vec_last_loop_pred");
516   LastLoopPred->addIncoming(LoopPred, VectorLoopStartBlock);
517   PHINode *VectorFoundIndex =
518       Builder.CreatePHI(I64Type, 1, "mismatch_vec_found_index");
519   VectorFoundIndex->addIncoming(VectorIndexPhi, VectorLoopStartBlock);
520 
521   Value *PredMatchCmp = Builder.CreateAnd(LastLoopPred, FoundPred);
522   Value *Ctz = Builder.CreateCountTrailingZeroElems(ResType, PredMatchCmp);
523   Ctz = Builder.CreateZExt(Ctz, I64Type);
524   Value *VectorLoopRes64 = Builder.CreateAdd(VectorFoundIndex, Ctz, "",
525                                              /*HasNUW=*/true, /*HasNSW=*/true);
526   return Builder.CreateTrunc(VectorLoopRes64, ResType);
527 }
528 
createPredicatedFindMismatch(IRBuilder<> & Builder,DomTreeUpdater & DTU,GetElementPtrInst * GEPA,GetElementPtrInst * GEPB,Value * ExtStart,Value * ExtEnd)529 Value *LoopIdiomVectorize::createPredicatedFindMismatch(
530     IRBuilder<> &Builder, DomTreeUpdater &DTU, GetElementPtrInst *GEPA,
531     GetElementPtrInst *GEPB, Value *ExtStart, Value *ExtEnd) {
532   Type *I64Type = Builder.getInt64Ty();
533   Type *I32Type = Builder.getInt32Ty();
534   Type *ResType = I32Type;
535   Type *LoadType = Builder.getInt8Ty();
536   Value *PtrA = GEPA->getPointerOperand();
537   Value *PtrB = GEPB->getPointerOperand();
538 
539   auto *JumpToVectorLoop = BranchInst::Create(VectorLoopStartBlock);
540   Builder.Insert(JumpToVectorLoop);
541 
542   DTU.applyUpdates({{DominatorTree::Insert, VectorLoopPreheaderBlock,
543                      VectorLoopStartBlock}});
544 
545   // Set up the first Vector loop block by creating the PHIs, doing the vector
546   // loads and comparing the vectors.
547   Builder.SetInsertPoint(VectorLoopStartBlock);
548   auto *VectorIndexPhi = Builder.CreatePHI(I64Type, 2, "mismatch_vector_index");
549   VectorIndexPhi->addIncoming(ExtStart, VectorLoopPreheaderBlock);
550 
551   // Calculate AVL by subtracting the vector loop index from the trip count
552   Value *AVL = Builder.CreateSub(ExtEnd, VectorIndexPhi, "avl", /*HasNUW=*/true,
553                                  /*HasNSW=*/true);
554 
555   auto *VectorLoadType = ScalableVectorType::get(LoadType, ByteCompareVF);
556   auto *VF = ConstantInt::get(I32Type, ByteCompareVF);
557 
558   Value *VL = Builder.CreateIntrinsic(Intrinsic::experimental_get_vector_length,
559                                       {I64Type}, {AVL, VF, Builder.getTrue()});
560   Value *GepOffset = VectorIndexPhi;
561 
562   Value *VectorLhsGep =
563       Builder.CreateGEP(LoadType, PtrA, GepOffset, "", GEPA->isInBounds());
564   VectorType *TrueMaskTy =
565       VectorType::get(Builder.getInt1Ty(), VectorLoadType->getElementCount());
566   Value *AllTrueMask = Constant::getAllOnesValue(TrueMaskTy);
567   Value *VectorLhsLoad = Builder.CreateIntrinsic(
568       Intrinsic::vp_load, {VectorLoadType, VectorLhsGep->getType()},
569       {VectorLhsGep, AllTrueMask, VL}, nullptr, "lhs.load");
570 
571   Value *VectorRhsGep =
572       Builder.CreateGEP(LoadType, PtrB, GepOffset, "", GEPB->isInBounds());
573   Value *VectorRhsLoad = Builder.CreateIntrinsic(
574       Intrinsic::vp_load, {VectorLoadType, VectorLhsGep->getType()},
575       {VectorRhsGep, AllTrueMask, VL}, nullptr, "rhs.load");
576 
577   StringRef PredicateStr = CmpInst::getPredicateName(CmpInst::ICMP_NE);
578   auto *PredicateMDS = MDString::get(VectorLhsLoad->getContext(), PredicateStr);
579   Value *Pred = MetadataAsValue::get(VectorLhsLoad->getContext(), PredicateMDS);
580   Value *VectorMatchCmp = Builder.CreateIntrinsic(
581       Intrinsic::vp_icmp, {VectorLhsLoad->getType()},
582       {VectorLhsLoad, VectorRhsLoad, Pred, AllTrueMask, VL}, nullptr,
583       "mismatch.cmp");
584   Value *CTZ = Builder.CreateIntrinsic(
585       Intrinsic::vp_cttz_elts, {ResType, VectorMatchCmp->getType()},
586       {VectorMatchCmp, /*ZeroIsPoison=*/Builder.getInt1(false), AllTrueMask,
587        VL});
588   Value *MismatchFound = Builder.CreateICmpNE(CTZ, VL);
589   auto *VectorEarlyExit = BranchInst::Create(VectorLoopMismatchBlock,
590                                              VectorLoopIncBlock, MismatchFound);
591   Builder.Insert(VectorEarlyExit);
592 
593   DTU.applyUpdates(
594       {{DominatorTree::Insert, VectorLoopStartBlock, VectorLoopMismatchBlock},
595        {DominatorTree::Insert, VectorLoopStartBlock, VectorLoopIncBlock}});
596 
597   // Increment the index counter and calculate the predicate for the next
598   // iteration of the loop. We branch back to the start of the loop if there
599   // is at least one active lane.
600   Builder.SetInsertPoint(VectorLoopIncBlock);
601   Value *VL64 = Builder.CreateZExt(VL, I64Type);
602   Value *NewVectorIndexPhi =
603       Builder.CreateAdd(VectorIndexPhi, VL64, "",
604                         /*HasNUW=*/true, /*HasNSW=*/true);
605   VectorIndexPhi->addIncoming(NewVectorIndexPhi, VectorLoopIncBlock);
606   Value *ExitCond = Builder.CreateICmpNE(NewVectorIndexPhi, ExtEnd);
607   auto *VectorLoopBranchBack =
608       BranchInst::Create(VectorLoopStartBlock, EndBlock, ExitCond);
609   Builder.Insert(VectorLoopBranchBack);
610 
611   DTU.applyUpdates(
612       {{DominatorTree::Insert, VectorLoopIncBlock, VectorLoopStartBlock},
613        {DominatorTree::Insert, VectorLoopIncBlock, EndBlock}});
614 
615   // If we found a mismatch then we need to calculate which lane in the vector
616   // had a mismatch and add that on to the current loop index.
617   Builder.SetInsertPoint(VectorLoopMismatchBlock);
618 
619   // Add LCSSA phis for CTZ and VectorIndexPhi.
620   auto *CTZLCSSAPhi = Builder.CreatePHI(CTZ->getType(), 1, "ctz");
621   CTZLCSSAPhi->addIncoming(CTZ, VectorLoopStartBlock);
622   auto *VectorIndexLCSSAPhi =
623       Builder.CreatePHI(VectorIndexPhi->getType(), 1, "mismatch_vector_index");
624   VectorIndexLCSSAPhi->addIncoming(VectorIndexPhi, VectorLoopStartBlock);
625 
626   Value *CTZI64 = Builder.CreateZExt(CTZLCSSAPhi, I64Type);
627   Value *VectorLoopRes64 = Builder.CreateAdd(VectorIndexLCSSAPhi, CTZI64, "",
628                                              /*HasNUW=*/true, /*HasNSW=*/true);
629   return Builder.CreateTrunc(VectorLoopRes64, ResType);
630 }
631 
expandFindMismatch(IRBuilder<> & Builder,DomTreeUpdater & DTU,GetElementPtrInst * GEPA,GetElementPtrInst * GEPB,Instruction * Index,Value * Start,Value * MaxLen)632 Value *LoopIdiomVectorize::expandFindMismatch(
633     IRBuilder<> &Builder, DomTreeUpdater &DTU, GetElementPtrInst *GEPA,
634     GetElementPtrInst *GEPB, Instruction *Index, Value *Start, Value *MaxLen) {
635   Value *PtrA = GEPA->getPointerOperand();
636   Value *PtrB = GEPB->getPointerOperand();
637 
638   // Get the arguments and types for the intrinsic.
639   BasicBlock *Preheader = CurLoop->getLoopPreheader();
640   BranchInst *PHBranch = cast<BranchInst>(Preheader->getTerminator());
641   LLVMContext &Ctx = PHBranch->getContext();
642   Type *LoadType = Type::getInt8Ty(Ctx);
643   Type *ResType = Builder.getInt32Ty();
644 
645   // Split block in the original loop preheader.
646   EndBlock = SplitBlock(Preheader, PHBranch, DT, LI, nullptr, "mismatch_end");
647 
648   // Create the blocks that we're going to need:
649   //  1. A block for checking the zero-extended length exceeds 0
650   //  2. A block to check that the start and end addresses of a given array
651   //     lie on the same page.
652   //  3. The vector loop preheader.
653   //  4. The first vector loop block.
654   //  5. The vector loop increment block.
655   //  6. A block we can jump to from the vector loop when a mismatch is found.
656   //  7. The first block of the scalar loop itself, containing PHIs , loads
657   //  and cmp.
658   //  8. A scalar loop increment block to increment the PHIs and go back
659   //  around the loop.
660 
661   BasicBlock *MinItCheckBlock = BasicBlock::Create(
662       Ctx, "mismatch_min_it_check", EndBlock->getParent(), EndBlock);
663 
664   // Update the terminator added by SplitBlock to branch to the first block
665   Preheader->getTerminator()->setSuccessor(0, MinItCheckBlock);
666 
667   BasicBlock *MemCheckBlock = BasicBlock::Create(
668       Ctx, "mismatch_mem_check", EndBlock->getParent(), EndBlock);
669 
670   VectorLoopPreheaderBlock = BasicBlock::Create(
671       Ctx, "mismatch_vec_loop_preheader", EndBlock->getParent(), EndBlock);
672 
673   VectorLoopStartBlock = BasicBlock::Create(Ctx, "mismatch_vec_loop",
674                                             EndBlock->getParent(), EndBlock);
675 
676   VectorLoopIncBlock = BasicBlock::Create(Ctx, "mismatch_vec_loop_inc",
677                                           EndBlock->getParent(), EndBlock);
678 
679   VectorLoopMismatchBlock = BasicBlock::Create(Ctx, "mismatch_vec_loop_found",
680                                                EndBlock->getParent(), EndBlock);
681 
682   BasicBlock *LoopPreHeaderBlock = BasicBlock::Create(
683       Ctx, "mismatch_loop_pre", EndBlock->getParent(), EndBlock);
684 
685   BasicBlock *LoopStartBlock =
686       BasicBlock::Create(Ctx, "mismatch_loop", EndBlock->getParent(), EndBlock);
687 
688   BasicBlock *LoopIncBlock = BasicBlock::Create(
689       Ctx, "mismatch_loop_inc", EndBlock->getParent(), EndBlock);
690 
691   DTU.applyUpdates({{DominatorTree::Insert, Preheader, MinItCheckBlock},
692                     {DominatorTree::Delete, Preheader, EndBlock}});
693 
694   // Update LoopInfo with the new vector & scalar loops.
695   auto VectorLoop = LI->AllocateLoop();
696   auto ScalarLoop = LI->AllocateLoop();
697 
698   if (CurLoop->getParentLoop()) {
699     CurLoop->getParentLoop()->addBasicBlockToLoop(MinItCheckBlock, *LI);
700     CurLoop->getParentLoop()->addBasicBlockToLoop(MemCheckBlock, *LI);
701     CurLoop->getParentLoop()->addBasicBlockToLoop(VectorLoopPreheaderBlock,
702                                                   *LI);
703     CurLoop->getParentLoop()->addChildLoop(VectorLoop);
704     CurLoop->getParentLoop()->addBasicBlockToLoop(VectorLoopMismatchBlock, *LI);
705     CurLoop->getParentLoop()->addBasicBlockToLoop(LoopPreHeaderBlock, *LI);
706     CurLoop->getParentLoop()->addChildLoop(ScalarLoop);
707   } else {
708     LI->addTopLevelLoop(VectorLoop);
709     LI->addTopLevelLoop(ScalarLoop);
710   }
711 
712   // Add the new basic blocks to their associated loops.
713   VectorLoop->addBasicBlockToLoop(VectorLoopStartBlock, *LI);
714   VectorLoop->addBasicBlockToLoop(VectorLoopIncBlock, *LI);
715 
716   ScalarLoop->addBasicBlockToLoop(LoopStartBlock, *LI);
717   ScalarLoop->addBasicBlockToLoop(LoopIncBlock, *LI);
718 
719   // Set up some types and constants that we intend to reuse.
720   Type *I64Type = Builder.getInt64Ty();
721 
722   // Check the zero-extended iteration count > 0
723   Builder.SetInsertPoint(MinItCheckBlock);
724   Value *ExtStart = Builder.CreateZExt(Start, I64Type);
725   Value *ExtEnd = Builder.CreateZExt(MaxLen, I64Type);
726   // This check doesn't really cost us very much.
727 
728   Value *LimitCheck = Builder.CreateICmpULE(Start, MaxLen);
729   BranchInst *MinItCheckBr =
730       BranchInst::Create(MemCheckBlock, LoopPreHeaderBlock, LimitCheck);
731   MinItCheckBr->setMetadata(
732       LLVMContext::MD_prof,
733       MDBuilder(MinItCheckBr->getContext()).createBranchWeights(99, 1));
734   Builder.Insert(MinItCheckBr);
735 
736   DTU.applyUpdates(
737       {{DominatorTree::Insert, MinItCheckBlock, MemCheckBlock},
738        {DominatorTree::Insert, MinItCheckBlock, LoopPreHeaderBlock}});
739 
740   // For each of the arrays, check the start/end addresses are on the same
741   // page.
742   Builder.SetInsertPoint(MemCheckBlock);
743 
744   // The early exit in the original loop means that when performing vector
745   // loads we are potentially reading ahead of the early exit. So we could
746   // fault if crossing a page boundary. Therefore, we create runtime memory
747   // checks based on the minimum page size as follows:
748   //   1. Calculate the addresses of the first memory accesses in the loop,
749   //      i.e. LhsStart and RhsStart.
750   //   2. Get the last accessed addresses in the loop, i.e. LhsEnd and RhsEnd.
751   //   3. Determine which pages correspond to all the memory accesses, i.e
752   //      LhsStartPage, LhsEndPage, RhsStartPage, RhsEndPage.
753   //   4. If LhsStartPage == LhsEndPage and RhsStartPage == RhsEndPage, then
754   //      we know we won't cross any page boundaries in the loop so we can
755   //      enter the vector loop! Otherwise we fall back on the scalar loop.
756   Value *LhsStartGEP = Builder.CreateGEP(LoadType, PtrA, ExtStart);
757   Value *RhsStartGEP = Builder.CreateGEP(LoadType, PtrB, ExtStart);
758   Value *RhsStart = Builder.CreatePtrToInt(RhsStartGEP, I64Type);
759   Value *LhsStart = Builder.CreatePtrToInt(LhsStartGEP, I64Type);
760   Value *LhsEndGEP = Builder.CreateGEP(LoadType, PtrA, ExtEnd);
761   Value *RhsEndGEP = Builder.CreateGEP(LoadType, PtrB, ExtEnd);
762   Value *LhsEnd = Builder.CreatePtrToInt(LhsEndGEP, I64Type);
763   Value *RhsEnd = Builder.CreatePtrToInt(RhsEndGEP, I64Type);
764 
765   const uint64_t MinPageSize = TTI->getMinPageSize().value();
766   const uint64_t AddrShiftAmt = llvm::Log2_64(MinPageSize);
767   Value *LhsStartPage = Builder.CreateLShr(LhsStart, AddrShiftAmt);
768   Value *LhsEndPage = Builder.CreateLShr(LhsEnd, AddrShiftAmt);
769   Value *RhsStartPage = Builder.CreateLShr(RhsStart, AddrShiftAmt);
770   Value *RhsEndPage = Builder.CreateLShr(RhsEnd, AddrShiftAmt);
771   Value *LhsPageCmp = Builder.CreateICmpNE(LhsStartPage, LhsEndPage);
772   Value *RhsPageCmp = Builder.CreateICmpNE(RhsStartPage, RhsEndPage);
773 
774   Value *CombinedPageCmp = Builder.CreateOr(LhsPageCmp, RhsPageCmp);
775   BranchInst *CombinedPageCmpCmpBr = BranchInst::Create(
776       LoopPreHeaderBlock, VectorLoopPreheaderBlock, CombinedPageCmp);
777   CombinedPageCmpCmpBr->setMetadata(
778       LLVMContext::MD_prof, MDBuilder(CombinedPageCmpCmpBr->getContext())
779                                 .createBranchWeights(10, 90));
780   Builder.Insert(CombinedPageCmpCmpBr);
781 
782   DTU.applyUpdates(
783       {{DominatorTree::Insert, MemCheckBlock, LoopPreHeaderBlock},
784        {DominatorTree::Insert, MemCheckBlock, VectorLoopPreheaderBlock}});
785 
786   // Set up the vector loop preheader, i.e. calculate initial loop predicate,
787   // zero-extend MaxLen to 64-bits, determine the number of vector elements
788   // processed in each iteration, etc.
789   Builder.SetInsertPoint(VectorLoopPreheaderBlock);
790 
791   // At this point we know two things must be true:
792   //  1. Start <= End
793   //  2. ExtMaxLen <= MinPageSize due to the page checks.
794   // Therefore, we know that we can use a 64-bit induction variable that
795   // starts from 0 -> ExtMaxLen and it will not overflow.
796   Value *VectorLoopRes = nullptr;
797   switch (VectorizeStyle) {
798   case LoopIdiomVectorizeStyle::Masked:
799     VectorLoopRes =
800         createMaskedFindMismatch(Builder, DTU, GEPA, GEPB, ExtStart, ExtEnd);
801     break;
802   case LoopIdiomVectorizeStyle::Predicated:
803     VectorLoopRes = createPredicatedFindMismatch(Builder, DTU, GEPA, GEPB,
804                                                  ExtStart, ExtEnd);
805     break;
806   }
807 
808   Builder.Insert(BranchInst::Create(EndBlock));
809 
810   DTU.applyUpdates(
811       {{DominatorTree::Insert, VectorLoopMismatchBlock, EndBlock}});
812 
813   // Generate code for scalar loop.
814   Builder.SetInsertPoint(LoopPreHeaderBlock);
815   Builder.Insert(BranchInst::Create(LoopStartBlock));
816 
817   DTU.applyUpdates(
818       {{DominatorTree::Insert, LoopPreHeaderBlock, LoopStartBlock}});
819 
820   Builder.SetInsertPoint(LoopStartBlock);
821   PHINode *IndexPhi = Builder.CreatePHI(ResType, 2, "mismatch_index");
822   IndexPhi->addIncoming(Start, LoopPreHeaderBlock);
823 
824   // Otherwise compare the values
825   // Load bytes from each array and compare them.
826   Value *GepOffset = Builder.CreateZExt(IndexPhi, I64Type);
827 
828   Value *LhsGep =
829       Builder.CreateGEP(LoadType, PtrA, GepOffset, "", GEPA->isInBounds());
830   Value *LhsLoad = Builder.CreateLoad(LoadType, LhsGep);
831 
832   Value *RhsGep =
833       Builder.CreateGEP(LoadType, PtrB, GepOffset, "", GEPB->isInBounds());
834   Value *RhsLoad = Builder.CreateLoad(LoadType, RhsGep);
835 
836   Value *MatchCmp = Builder.CreateICmpEQ(LhsLoad, RhsLoad);
837   // If we have a mismatch then exit the loop ...
838   BranchInst *MatchCmpBr = BranchInst::Create(LoopIncBlock, EndBlock, MatchCmp);
839   Builder.Insert(MatchCmpBr);
840 
841   DTU.applyUpdates({{DominatorTree::Insert, LoopStartBlock, LoopIncBlock},
842                     {DominatorTree::Insert, LoopStartBlock, EndBlock}});
843 
844   // Have we reached the maximum permitted length for the loop?
845   Builder.SetInsertPoint(LoopIncBlock);
846   Value *PhiInc = Builder.CreateAdd(IndexPhi, ConstantInt::get(ResType, 1), "",
847                                     /*HasNUW=*/Index->hasNoUnsignedWrap(),
848                                     /*HasNSW=*/Index->hasNoSignedWrap());
849   IndexPhi->addIncoming(PhiInc, LoopIncBlock);
850   Value *IVCmp = Builder.CreateICmpEQ(PhiInc, MaxLen);
851   BranchInst *IVCmpBr = BranchInst::Create(EndBlock, LoopStartBlock, IVCmp);
852   Builder.Insert(IVCmpBr);
853 
854   DTU.applyUpdates({{DominatorTree::Insert, LoopIncBlock, EndBlock},
855                     {DominatorTree::Insert, LoopIncBlock, LoopStartBlock}});
856 
857   // In the end block we need to insert a PHI node to deal with three cases:
858   //  1. We didn't find a mismatch in the scalar loop, so we return MaxLen.
859   //  2. We exitted the scalar loop early due to a mismatch and need to return
860   //  the index that we found.
861   //  3. We didn't find a mismatch in the vector loop, so we return MaxLen.
862   //  4. We exitted the vector loop early due to a mismatch and need to return
863   //  the index that we found.
864   Builder.SetInsertPoint(EndBlock, EndBlock->getFirstInsertionPt());
865   PHINode *ResPhi = Builder.CreatePHI(ResType, 4, "mismatch_result");
866   ResPhi->addIncoming(MaxLen, LoopIncBlock);
867   ResPhi->addIncoming(IndexPhi, LoopStartBlock);
868   ResPhi->addIncoming(MaxLen, VectorLoopIncBlock);
869   ResPhi->addIncoming(VectorLoopRes, VectorLoopMismatchBlock);
870 
871   Value *FinalRes = Builder.CreateTrunc(ResPhi, ResType);
872 
873   if (VerifyLoops) {
874     ScalarLoop->verifyLoop();
875     VectorLoop->verifyLoop();
876     if (!VectorLoop->isRecursivelyLCSSAForm(*DT, *LI))
877       report_fatal_error("Loops must remain in LCSSA form!");
878     if (!ScalarLoop->isRecursivelyLCSSAForm(*DT, *LI))
879       report_fatal_error("Loops must remain in LCSSA form!");
880   }
881 
882   return FinalRes;
883 }
884 
transformByteCompare(GetElementPtrInst * GEPA,GetElementPtrInst * GEPB,PHINode * IndPhi,Value * MaxLen,Instruction * Index,Value * Start,bool IncIdx,BasicBlock * FoundBB,BasicBlock * EndBB)885 void LoopIdiomVectorize::transformByteCompare(GetElementPtrInst *GEPA,
886                                               GetElementPtrInst *GEPB,
887                                               PHINode *IndPhi, Value *MaxLen,
888                                               Instruction *Index, Value *Start,
889                                               bool IncIdx, BasicBlock *FoundBB,
890                                               BasicBlock *EndBB) {
891 
892   // Insert the byte compare code at the end of the preheader block
893   BasicBlock *Preheader = CurLoop->getLoopPreheader();
894   BasicBlock *Header = CurLoop->getHeader();
895   BranchInst *PHBranch = cast<BranchInst>(Preheader->getTerminator());
896   IRBuilder<> Builder(PHBranch);
897   DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
898   Builder.SetCurrentDebugLocation(PHBranch->getDebugLoc());
899 
900   // Increment the pointer if this was done before the loads in the loop.
901   if (IncIdx)
902     Start = Builder.CreateAdd(Start, ConstantInt::get(Start->getType(), 1));
903 
904   Value *ByteCmpRes =
905       expandFindMismatch(Builder, DTU, GEPA, GEPB, Index, Start, MaxLen);
906 
907   // Replaces uses of index & induction Phi with intrinsic (we already
908   // checked that the the first instruction of Header is the Phi above).
909   assert(IndPhi->hasOneUse() && "Index phi node has more than one use!");
910   Index->replaceAllUsesWith(ByteCmpRes);
911 
912   assert(PHBranch->isUnconditional() &&
913          "Expected preheader to terminate with an unconditional branch.");
914 
915   // If no mismatch was found, we can jump to the end block. Create a
916   // new basic block for the compare instruction.
917   auto *CmpBB = BasicBlock::Create(Preheader->getContext(), "byte.compare",
918                                    Preheader->getParent());
919   CmpBB->moveBefore(EndBB);
920 
921   // Replace the branch in the preheader with an always-true conditional branch.
922   // This ensures there is still a reference to the original loop.
923   Builder.CreateCondBr(Builder.getTrue(), CmpBB, Header);
924   PHBranch->eraseFromParent();
925 
926   BasicBlock *MismatchEnd = cast<Instruction>(ByteCmpRes)->getParent();
927   DTU.applyUpdates({{DominatorTree::Insert, MismatchEnd, CmpBB}});
928 
929   // Create the branch to either the end or found block depending on the value
930   // returned by the intrinsic.
931   Builder.SetInsertPoint(CmpBB);
932   if (FoundBB != EndBB) {
933     Value *FoundCmp = Builder.CreateICmpEQ(ByteCmpRes, MaxLen);
934     Builder.CreateCondBr(FoundCmp, EndBB, FoundBB);
935     DTU.applyUpdates({{DominatorTree::Insert, CmpBB, FoundBB},
936                       {DominatorTree::Insert, CmpBB, EndBB}});
937 
938   } else {
939     Builder.CreateBr(FoundBB);
940     DTU.applyUpdates({{DominatorTree::Insert, CmpBB, FoundBB}});
941   }
942 
943   auto fixSuccessorPhis = [&](BasicBlock *SuccBB) {
944     for (PHINode &PN : SuccBB->phis()) {
945       // At this point we've already replaced all uses of the result from the
946       // loop with ByteCmp. Look through the incoming values to find ByteCmp,
947       // meaning this is a Phi collecting the results of the byte compare.
948       bool ResPhi = false;
949       for (Value *Op : PN.incoming_values())
950         if (Op == ByteCmpRes) {
951           ResPhi = true;
952           break;
953         }
954 
955       // Any PHI that depended upon the result of the byte compare needs a new
956       // incoming value from CmpBB. This is because the original loop will get
957       // deleted.
958       if (ResPhi)
959         PN.addIncoming(ByteCmpRes, CmpBB);
960       else {
961         // There should be no other outside uses of other values in the
962         // original loop. Any incoming values should either:
963         //   1. Be for blocks outside the loop, which aren't interesting. Or ..
964         //   2. These are from blocks in the loop with values defined outside
965         //      the loop. We should a similar incoming value from CmpBB.
966         for (BasicBlock *BB : PN.blocks())
967           if (CurLoop->contains(BB)) {
968             PN.addIncoming(PN.getIncomingValueForBlock(BB), CmpBB);
969             break;
970           }
971       }
972     }
973   };
974 
975   // Ensure all Phis in the successors of CmpBB have an incoming value from it.
976   fixSuccessorPhis(EndBB);
977   if (EndBB != FoundBB)
978     fixSuccessorPhis(FoundBB);
979 
980   // The new CmpBB block isn't part of the loop, but will need to be added to
981   // the outer loop if there is one.
982   if (!CurLoop->isOutermost())
983     CurLoop->getParentLoop()->addBasicBlockToLoop(CmpBB, *LI);
984 
985   if (VerifyLoops && CurLoop->getParentLoop()) {
986     CurLoop->getParentLoop()->verifyLoop();
987     if (!CurLoop->getParentLoop()->isRecursivelyLCSSAForm(*DT, *LI))
988       report_fatal_error("Loops must remain in LCSSA form!");
989   }
990 }
991 
recognizeFindFirstByte()992 bool LoopIdiomVectorize::recognizeFindFirstByte() {
993   // Currently the transformation only works on scalable vector types, although
994   // there is no fundamental reason why it cannot be made to work for fixed
995   // vectors. We also need to know the target's minimum page size in order to
996   // generate runtime memory checks to ensure the vector version won't fault.
997   if (!TTI->supportsScalableVectors() || !TTI->getMinPageSize().has_value() ||
998       DisableFindFirstByte)
999     return false;
1000 
1001   // Define some constants we need throughout.
1002   BasicBlock *Header = CurLoop->getHeader();
1003   LLVMContext &Ctx = Header->getContext();
1004 
1005   // We are expecting the four blocks defined below: Header, MatchBB, InnerBB,
1006   // and OuterBB. For now, we will bail our for almost anything else. The Four
1007   // blocks contain one nested loop.
1008   if (CurLoop->getNumBackEdges() != 1 || CurLoop->getNumBlocks() != 4 ||
1009       CurLoop->getSubLoops().size() != 1)
1010     return false;
1011 
1012   auto *InnerLoop = CurLoop->getSubLoops().front();
1013   PHINode *IndPhi = dyn_cast<PHINode>(&Header->front());
1014   if (!IndPhi || IndPhi->getNumIncomingValues() != 2)
1015     return false;
1016 
1017   // Check instruction counts.
1018   auto LoopBlocks = CurLoop->getBlocks();
1019   if (LoopBlocks[0]->sizeWithoutDebug() > 3 ||
1020       LoopBlocks[1]->sizeWithoutDebug() > 4 ||
1021       LoopBlocks[2]->sizeWithoutDebug() > 3 ||
1022       LoopBlocks[3]->sizeWithoutDebug() > 3)
1023     return false;
1024 
1025   // Check that no instruction other than IndPhi has outside uses.
1026   for (BasicBlock *BB : LoopBlocks)
1027     for (Instruction &I : *BB)
1028       if (&I != IndPhi)
1029         for (User *U : I.users())
1030           if (!CurLoop->contains(cast<Instruction>(U)))
1031             return false;
1032 
1033   // Match the branch instruction in the header. We are expecting an
1034   // unconditional branch to the inner loop.
1035   //
1036   // Header:
1037   //   %14 = phi ptr [ %24, %OuterBB ], [ %3, %Header.preheader ]
1038   //   %15 = load i8, ptr %14, align 1
1039   //   br label %MatchBB
1040   BasicBlock *MatchBB;
1041   if (!match(Header->getTerminator(), m_UnconditionalBr(MatchBB)) ||
1042       !InnerLoop->contains(MatchBB))
1043     return false;
1044 
1045   // MatchBB should be the entrypoint into the inner loop containing the
1046   // comparison between a search element and a needle.
1047   //
1048   // MatchBB:
1049   //   %20 = phi ptr [ %7, %Header ], [ %17, %InnerBB ]
1050   //   %21 = load i8, ptr %20, align 1
1051   //   %22 = icmp eq i8 %15, %21
1052   //   br i1 %22, label %ExitSucc, label %InnerBB
1053   BasicBlock *ExitSucc, *InnerBB;
1054   Value *LoadSearch, *LoadNeedle;
1055   CmpPredicate MatchPred;
1056   if (!match(MatchBB->getTerminator(),
1057              m_Br(m_ICmp(MatchPred, m_Value(LoadSearch), m_Value(LoadNeedle)),
1058                   m_BasicBlock(ExitSucc), m_BasicBlock(InnerBB))) ||
1059       MatchPred != ICmpInst::ICMP_EQ || !InnerLoop->contains(InnerBB))
1060     return false;
1061 
1062   // We expect outside uses of `IndPhi' in ExitSucc (and only there).
1063   for (User *U : IndPhi->users())
1064     if (!CurLoop->contains(cast<Instruction>(U))) {
1065       auto *PN = dyn_cast<PHINode>(U);
1066       if (!PN || PN->getParent() != ExitSucc)
1067         return false;
1068     }
1069 
1070   // Match the loads and check they are simple.
1071   Value *Search, *Needle;
1072   if (!match(LoadSearch, m_Load(m_Value(Search))) ||
1073       !match(LoadNeedle, m_Load(m_Value(Needle))) ||
1074       !cast<LoadInst>(LoadSearch)->isSimple() ||
1075       !cast<LoadInst>(LoadNeedle)->isSimple())
1076     return false;
1077 
1078   // Check we are loading valid characters.
1079   Type *CharTy = LoadSearch->getType();
1080   if (!CharTy->isIntegerTy() || LoadNeedle->getType() != CharTy)
1081     return false;
1082 
1083   // Pick the vectorisation factor based on CharTy, work out the cost of the
1084   // match intrinsic and decide if we should use it.
1085   // Note: For the time being we assume 128-bit vectors.
1086   unsigned VF = 128 / CharTy->getIntegerBitWidth();
1087   SmallVector<Type *> Args = {
1088       ScalableVectorType::get(CharTy, VF), FixedVectorType::get(CharTy, VF),
1089       ScalableVectorType::get(Type::getInt1Ty(Ctx), VF)};
1090   IntrinsicCostAttributes Attrs(Intrinsic::experimental_vector_match, Args[2],
1091                                 Args);
1092   if (TTI->getIntrinsicInstrCost(Attrs, TTI::TCK_SizeAndLatency) > 4)
1093     return false;
1094 
1095   // The loads come from two PHIs, each with two incoming values.
1096   PHINode *PSearch = dyn_cast<PHINode>(Search);
1097   PHINode *PNeedle = dyn_cast<PHINode>(Needle);
1098   if (!PSearch || PSearch->getNumIncomingValues() != 2 || !PNeedle ||
1099       PNeedle->getNumIncomingValues() != 2)
1100     return false;
1101 
1102   // One PHI comes from the outer loop (PSearch), the other one from the inner
1103   // loop (PNeedle). PSearch effectively corresponds to IndPhi.
1104   if (InnerLoop->contains(PSearch))
1105     std::swap(PSearch, PNeedle);
1106   if (PSearch != &Header->front() || PNeedle != &MatchBB->front())
1107     return false;
1108 
1109   // The incoming values of both PHI nodes should be a gep of 1.
1110   Value *SearchStart = PSearch->getIncomingValue(0);
1111   Value *SearchIndex = PSearch->getIncomingValue(1);
1112   if (CurLoop->contains(PSearch->getIncomingBlock(0)))
1113     std::swap(SearchStart, SearchIndex);
1114 
1115   Value *NeedleStart = PNeedle->getIncomingValue(0);
1116   Value *NeedleIndex = PNeedle->getIncomingValue(1);
1117   if (InnerLoop->contains(PNeedle->getIncomingBlock(0)))
1118     std::swap(NeedleStart, NeedleIndex);
1119 
1120   // Match the GEPs.
1121   if (!match(SearchIndex, m_GEP(m_Specific(PSearch), m_One())) ||
1122       !match(NeedleIndex, m_GEP(m_Specific(PNeedle), m_One())))
1123     return false;
1124 
1125   // Check the GEPs result type matches `CharTy'.
1126   GetElementPtrInst *GEPSearch = cast<GetElementPtrInst>(SearchIndex);
1127   GetElementPtrInst *GEPNeedle = cast<GetElementPtrInst>(NeedleIndex);
1128   if (GEPSearch->getResultElementType() != CharTy ||
1129       GEPNeedle->getResultElementType() != CharTy)
1130     return false;
1131 
1132   // InnerBB should increment the address of the needle pointer.
1133   //
1134   // InnerBB:
1135   //   %17 = getelementptr inbounds i8, ptr %20, i64 1
1136   //   %18 = icmp eq ptr %17, %10
1137   //   br i1 %18, label %OuterBB, label %MatchBB
1138   BasicBlock *OuterBB;
1139   Value *NeedleEnd;
1140   if (!match(InnerBB->getTerminator(),
1141              m_Br(m_SpecificICmp(ICmpInst::ICMP_EQ, m_Specific(GEPNeedle),
1142                                  m_Value(NeedleEnd)),
1143                   m_BasicBlock(OuterBB), m_Specific(MatchBB))) ||
1144       !CurLoop->contains(OuterBB))
1145     return false;
1146 
1147   // OuterBB should increment the address of the search element pointer.
1148   //
1149   // OuterBB:
1150   //   %24 = getelementptr inbounds i8, ptr %14, i64 1
1151   //   %25 = icmp eq ptr %24, %6
1152   //   br i1 %25, label %ExitFail, label %Header
1153   BasicBlock *ExitFail;
1154   Value *SearchEnd;
1155   if (!match(OuterBB->getTerminator(),
1156              m_Br(m_SpecificICmp(ICmpInst::ICMP_EQ, m_Specific(GEPSearch),
1157                                  m_Value(SearchEnd)),
1158                   m_BasicBlock(ExitFail), m_Specific(Header))))
1159     return false;
1160 
1161   if (!CurLoop->isLoopInvariant(SearchStart) ||
1162       !CurLoop->isLoopInvariant(SearchEnd) ||
1163       !CurLoop->isLoopInvariant(NeedleStart) ||
1164       !CurLoop->isLoopInvariant(NeedleEnd))
1165     return false;
1166 
1167   LLVM_DEBUG(dbgs() << "Found idiom in loop: \n" << *CurLoop << "\n\n");
1168 
1169   transformFindFirstByte(IndPhi, VF, CharTy, ExitSucc, ExitFail, SearchStart,
1170                          SearchEnd, NeedleStart, NeedleEnd);
1171   return true;
1172 }
1173 
expandFindFirstByte(IRBuilder<> & Builder,DomTreeUpdater & DTU,unsigned VF,Type * CharTy,BasicBlock * ExitSucc,BasicBlock * ExitFail,Value * SearchStart,Value * SearchEnd,Value * NeedleStart,Value * NeedleEnd)1174 Value *LoopIdiomVectorize::expandFindFirstByte(
1175     IRBuilder<> &Builder, DomTreeUpdater &DTU, unsigned VF, Type *CharTy,
1176     BasicBlock *ExitSucc, BasicBlock *ExitFail, Value *SearchStart,
1177     Value *SearchEnd, Value *NeedleStart, Value *NeedleEnd) {
1178   // Set up some types and constants that we intend to reuse.
1179   auto *PtrTy = Builder.getPtrTy();
1180   auto *I64Ty = Builder.getInt64Ty();
1181   auto *PredVTy = ScalableVectorType::get(Builder.getInt1Ty(), VF);
1182   auto *CharVTy = ScalableVectorType::get(CharTy, VF);
1183   auto *ConstVF = ConstantInt::get(I64Ty, VF);
1184 
1185   // Other common arguments.
1186   BasicBlock *Preheader = CurLoop->getLoopPreheader();
1187   LLVMContext &Ctx = Preheader->getContext();
1188   Value *Passthru = ConstantInt::getNullValue(CharVTy);
1189 
1190   // Split block in the original loop preheader.
1191   // SPH is the new preheader to the old scalar loop.
1192   BasicBlock *SPH = SplitBlock(Preheader, Preheader->getTerminator(), DT, LI,
1193                                nullptr, "scalar_preheader");
1194 
1195   // Create the blocks that we're going to use.
1196   //
1197   // We will have the following loops:
1198   // (O) Outer loop where we iterate over the elements of the search array.
1199   // (I) Inner loop where we iterate over the elements of the needle array.
1200   //
1201   // Overall, the blocks do the following:
1202   // (0) Check if the arrays can't cross page boundaries. If so go to (1),
1203   //     otherwise fall back to the original scalar loop.
1204   // (1) Load the search array. Go to (2).
1205   // (2) (a) Load the needle array.
1206   //     (b) Splat the first element to the inactive lanes.
1207   //     (c) Check if any elements match. If so go to (3), otherwise go to (4).
1208   // (3) Compute the index of the first match and exit.
1209   // (4) Check if we've reached the end of the needle array. If not loop back to
1210   //     (2), otherwise go to (5).
1211   // (5) Check if we've reached the end of the search array. If not loop back to
1212   //     (1), otherwise exit.
1213   // Blocks (0,3) are not part of any loop. Blocks (1,5) and (2,4) belong to
1214   // the outer and inner loops, respectively.
1215   BasicBlock *BB0 = BasicBlock::Create(Ctx, "mem_check", SPH->getParent(), SPH);
1216   BasicBlock *BB1 =
1217       BasicBlock::Create(Ctx, "find_first_vec_header", SPH->getParent(), SPH);
1218   BasicBlock *BB2 =
1219       BasicBlock::Create(Ctx, "match_check_vec", SPH->getParent(), SPH);
1220   BasicBlock *BB3 =
1221       BasicBlock::Create(Ctx, "calculate_match", SPH->getParent(), SPH);
1222   BasicBlock *BB4 =
1223       BasicBlock::Create(Ctx, "needle_check_vec", SPH->getParent(), SPH);
1224   BasicBlock *BB5 =
1225       BasicBlock::Create(Ctx, "search_check_vec", SPH->getParent(), SPH);
1226 
1227   // Update LoopInfo with the new loops.
1228   auto OuterLoop = LI->AllocateLoop();
1229   auto InnerLoop = LI->AllocateLoop();
1230 
1231   if (auto ParentLoop = CurLoop->getParentLoop()) {
1232     ParentLoop->addBasicBlockToLoop(BB0, *LI);
1233     ParentLoop->addChildLoop(OuterLoop);
1234     ParentLoop->addBasicBlockToLoop(BB3, *LI);
1235   } else {
1236     LI->addTopLevelLoop(OuterLoop);
1237   }
1238 
1239   // Add the inner loop to the outer.
1240   OuterLoop->addChildLoop(InnerLoop);
1241 
1242   // Add the new basic blocks to the corresponding loops.
1243   OuterLoop->addBasicBlockToLoop(BB1, *LI);
1244   OuterLoop->addBasicBlockToLoop(BB5, *LI);
1245   InnerLoop->addBasicBlockToLoop(BB2, *LI);
1246   InnerLoop->addBasicBlockToLoop(BB4, *LI);
1247 
1248   // Update the terminator added by SplitBlock to branch to the first block.
1249   Preheader->getTerminator()->setSuccessor(0, BB0);
1250   DTU.applyUpdates({{DominatorTree::Delete, Preheader, SPH},
1251                     {DominatorTree::Insert, Preheader, BB0}});
1252 
1253   // (0) Check if we could be crossing a page boundary; if so, fallback to the
1254   // old scalar loops. Also create a predicate of VF elements to be used in the
1255   // vector loops.
1256   Builder.SetInsertPoint(BB0);
1257   Value *ISearchStart =
1258       Builder.CreatePtrToInt(SearchStart, I64Ty, "search_start_int");
1259   Value *ISearchEnd =
1260       Builder.CreatePtrToInt(SearchEnd, I64Ty, "search_end_int");
1261   Value *INeedleStart =
1262       Builder.CreatePtrToInt(NeedleStart, I64Ty, "needle_start_int");
1263   Value *INeedleEnd =
1264       Builder.CreatePtrToInt(NeedleEnd, I64Ty, "needle_end_int");
1265   Value *PredVF =
1266       Builder.CreateIntrinsic(Intrinsic::get_active_lane_mask, {PredVTy, I64Ty},
1267                               {ConstantInt::get(I64Ty, 0), ConstVF});
1268 
1269   const uint64_t MinPageSize = TTI->getMinPageSize().value();
1270   const uint64_t AddrShiftAmt = llvm::Log2_64(MinPageSize);
1271   Value *SearchStartPage =
1272       Builder.CreateLShr(ISearchStart, AddrShiftAmt, "search_start_page");
1273   Value *SearchEndPage =
1274       Builder.CreateLShr(ISearchEnd, AddrShiftAmt, "search_end_page");
1275   Value *NeedleStartPage =
1276       Builder.CreateLShr(INeedleStart, AddrShiftAmt, "needle_start_page");
1277   Value *NeedleEndPage =
1278       Builder.CreateLShr(INeedleEnd, AddrShiftAmt, "needle_end_page");
1279   Value *SearchPageCmp =
1280       Builder.CreateICmpNE(SearchStartPage, SearchEndPage, "search_page_cmp");
1281   Value *NeedlePageCmp =
1282       Builder.CreateICmpNE(NeedleStartPage, NeedleEndPage, "needle_page_cmp");
1283 
1284   Value *CombinedPageCmp =
1285       Builder.CreateOr(SearchPageCmp, NeedlePageCmp, "combined_page_cmp");
1286   BranchInst *CombinedPageBr = Builder.CreateCondBr(CombinedPageCmp, SPH, BB1);
1287   CombinedPageBr->setMetadata(LLVMContext::MD_prof,
1288                               MDBuilder(Ctx).createBranchWeights(10, 90));
1289   DTU.applyUpdates(
1290       {{DominatorTree::Insert, BB0, SPH}, {DominatorTree::Insert, BB0, BB1}});
1291 
1292   // (1) Load the search array and branch to the inner loop.
1293   Builder.SetInsertPoint(BB1);
1294   PHINode *Search = Builder.CreatePHI(PtrTy, 2, "psearch");
1295   Value *PredSearch = Builder.CreateIntrinsic(
1296       Intrinsic::get_active_lane_mask, {PredVTy, I64Ty},
1297       {Builder.CreatePtrToInt(Search, I64Ty), ISearchEnd}, nullptr,
1298       "search_pred");
1299   PredSearch = Builder.CreateAnd(PredVF, PredSearch, "search_masked");
1300   Value *LoadSearch = Builder.CreateMaskedLoad(
1301       CharVTy, Search, Align(1), PredSearch, Passthru, "search_load_vec");
1302   Builder.CreateBr(BB2);
1303   DTU.applyUpdates({{DominatorTree::Insert, BB1, BB2}});
1304 
1305   // (2) Inner loop.
1306   Builder.SetInsertPoint(BB2);
1307   PHINode *Needle = Builder.CreatePHI(PtrTy, 2, "pneedle");
1308 
1309   // (2.a) Load the needle array.
1310   Value *PredNeedle = Builder.CreateIntrinsic(
1311       Intrinsic::get_active_lane_mask, {PredVTy, I64Ty},
1312       {Builder.CreatePtrToInt(Needle, I64Ty), INeedleEnd}, nullptr,
1313       "needle_pred");
1314   PredNeedle = Builder.CreateAnd(PredVF, PredNeedle, "needle_masked");
1315   Value *LoadNeedle = Builder.CreateMaskedLoad(
1316       CharVTy, Needle, Align(1), PredNeedle, Passthru, "needle_load_vec");
1317 
1318   // (2.b) Splat the first element to the inactive lanes.
1319   Value *Needle0 =
1320       Builder.CreateExtractElement(LoadNeedle, uint64_t(0), "needle0");
1321   Value *Needle0Splat = Builder.CreateVectorSplat(ElementCount::getScalable(VF),
1322                                                   Needle0, "needle0");
1323   LoadNeedle = Builder.CreateSelect(PredNeedle, LoadNeedle, Needle0Splat,
1324                                     "needle_splat");
1325   LoadNeedle = Builder.CreateExtractVector(
1326       FixedVectorType::get(CharTy, VF), LoadNeedle, uint64_t(0), "needle_vec");
1327 
1328   // (2.c) Test if there's a match.
1329   Value *MatchPred = Builder.CreateIntrinsic(
1330       Intrinsic::experimental_vector_match, {CharVTy, LoadNeedle->getType()},
1331       {LoadSearch, LoadNeedle, PredSearch}, nullptr, "match_pred");
1332   Value *IfAnyMatch = Builder.CreateOrReduce(MatchPred);
1333   Builder.CreateCondBr(IfAnyMatch, BB3, BB4);
1334   DTU.applyUpdates(
1335       {{DominatorTree::Insert, BB2, BB3}, {DominatorTree::Insert, BB2, BB4}});
1336 
1337   // (3) We found a match. Compute the index of its location and exit.
1338   Builder.SetInsertPoint(BB3);
1339   PHINode *MatchLCSSA = Builder.CreatePHI(PtrTy, 1, "match_start");
1340   PHINode *MatchPredLCSSA =
1341       Builder.CreatePHI(MatchPred->getType(), 1, "match_vec");
1342   Value *MatchCnt = Builder.CreateIntrinsic(
1343       Intrinsic::experimental_cttz_elts, {I64Ty, MatchPred->getType()},
1344       {MatchPredLCSSA, /*ZeroIsPoison=*/Builder.getInt1(true)}, nullptr,
1345       "match_idx");
1346   Value *MatchVal =
1347       Builder.CreateGEP(CharTy, MatchLCSSA, MatchCnt, "match_res");
1348   Builder.CreateBr(ExitSucc);
1349   DTU.applyUpdates({{DominatorTree::Insert, BB3, ExitSucc}});
1350 
1351   // (4) Check if we've reached the end of the needle array.
1352   Builder.SetInsertPoint(BB4);
1353   Value *NextNeedle =
1354       Builder.CreateGEP(CharTy, Needle, ConstVF, "needle_next_vec");
1355   Builder.CreateCondBr(Builder.CreateICmpULT(NextNeedle, NeedleEnd), BB2, BB5);
1356   DTU.applyUpdates(
1357       {{DominatorTree::Insert, BB4, BB2}, {DominatorTree::Insert, BB4, BB5}});
1358 
1359   // (5) Check if we've reached the end of the search array.
1360   Builder.SetInsertPoint(BB5);
1361   Value *NextSearch =
1362       Builder.CreateGEP(CharTy, Search, ConstVF, "search_next_vec");
1363   Builder.CreateCondBr(Builder.CreateICmpULT(NextSearch, SearchEnd), BB1,
1364                        ExitFail);
1365   DTU.applyUpdates({{DominatorTree::Insert, BB5, BB1},
1366                     {DominatorTree::Insert, BB5, ExitFail}});
1367 
1368   // Set up the PHI nodes.
1369   Search->addIncoming(SearchStart, BB0);
1370   Search->addIncoming(NextSearch, BB5);
1371   Needle->addIncoming(NeedleStart, BB1);
1372   Needle->addIncoming(NextNeedle, BB4);
1373   // These are needed to retain LCSSA form.
1374   MatchLCSSA->addIncoming(Search, BB2);
1375   MatchPredLCSSA->addIncoming(MatchPred, BB2);
1376 
1377   if (VerifyLoops) {
1378     OuterLoop->verifyLoop();
1379     InnerLoop->verifyLoop();
1380     if (!OuterLoop->isRecursivelyLCSSAForm(*DT, *LI))
1381       report_fatal_error("Loops must remain in LCSSA form!");
1382   }
1383 
1384   return MatchVal;
1385 }
1386 
transformFindFirstByte(PHINode * IndPhi,unsigned VF,Type * CharTy,BasicBlock * ExitSucc,BasicBlock * ExitFail,Value * SearchStart,Value * SearchEnd,Value * NeedleStart,Value * NeedleEnd)1387 void LoopIdiomVectorize::transformFindFirstByte(
1388     PHINode *IndPhi, unsigned VF, Type *CharTy, BasicBlock *ExitSucc,
1389     BasicBlock *ExitFail, Value *SearchStart, Value *SearchEnd,
1390     Value *NeedleStart, Value *NeedleEnd) {
1391   // Insert the find first byte code at the end of the preheader block.
1392   BasicBlock *Preheader = CurLoop->getLoopPreheader();
1393   BranchInst *PHBranch = cast<BranchInst>(Preheader->getTerminator());
1394   IRBuilder<> Builder(PHBranch);
1395   DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
1396   Builder.SetCurrentDebugLocation(PHBranch->getDebugLoc());
1397 
1398   Value *MatchVal =
1399       expandFindFirstByte(Builder, DTU, VF, CharTy, ExitSucc, ExitFail,
1400                           SearchStart, SearchEnd, NeedleStart, NeedleEnd);
1401 
1402   assert(PHBranch->isUnconditional() &&
1403          "Expected preheader to terminate with an unconditional branch.");
1404 
1405   // Add new incoming values with the result of the transformation to PHINodes
1406   // of ExitSucc that use IndPhi.
1407   for (auto *U : llvm::make_early_inc_range(IndPhi->users())) {
1408     auto *PN = dyn_cast<PHINode>(U);
1409     if (PN && PN->getParent() == ExitSucc)
1410       PN->addIncoming(MatchVal, cast<Instruction>(MatchVal)->getParent());
1411   }
1412 
1413   if (VerifyLoops && CurLoop->getParentLoop()) {
1414     CurLoop->getParentLoop()->verifyLoop();
1415     if (!CurLoop->getParentLoop()->isRecursivelyLCSSAForm(*DT, *LI))
1416       report_fatal_error("Loops must remain in LCSSA form!");
1417   }
1418 }
1419