1 //===-------- LoopIdiomVectorize.cpp - Loop idiom vectorization -----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This pass recognizes certain loop idioms and transforms them into more
// optimized versions of the same loop. In cases where this happens, it can be
// a significant performance win.
12 //
// We currently support two loop idioms:
14 //
15 // 1. A loop that finds the first mismatched byte in an array and returns the
16 // index, i.e. something like:
17 //
18 // while (++i != n) {
19 // if (a[i] != b[i])
20 // break;
21 // }
22 //
23 // In this example we can actually vectorize the loop despite the early exit,
24 // although the loop vectorizer does not support it. It requires some extra
25 // checks to deal with the possibility of faulting loads when crossing page
26 // boundaries. However, even with these checks it is still profitable to do the
27 // transformation.
28 //
29 // TODO List:
30 //
31 // * Add support for the inverse case where we scan for a matching element.
32 // * Permit 64-bit induction variable types.
33 // * Recognize loops that increment the IV *after* comparing bytes.
34 // * Allow 32-bit sign-extends of the IV used by the GEP.
35 //
36 // 2. A loop that finds the first matching character in an array among a set of
37 // possible matches, e.g.:
38 //
39 // for (; first != last; ++first)
40 // for (s_it = s_first; s_it != s_last; ++s_it)
41 // if (*first == *s_it)
42 // return first;
43 // return last;
44 //
45 // This corresponds to std::find_first_of (for arrays of bytes) from the C++
46 // standard library. This function can be implemented efficiently for targets
47 // that support @llvm.experimental.vector.match. For example, on AArch64 targets
// that implement SVE2, this lowers to a MATCH instruction, which enables us to
49 // perform up to 16x16=256 comparisons in one go. This can lead to very
50 // significant speedups.
51 //
52 // TODO:
53 //
// * Add support for `find_first_not_of` loops (i.e. with not-equal comparison).
55 // * Make VF a configurable parameter (right now we assume 128-bit vectors).
56 // * Potentially adjust the cost model to let the transformation kick-in even if
57 // @llvm.experimental.vector.match doesn't have direct support in hardware.
58 //
59 //===----------------------------------------------------------------------===//
60 //
61 // NOTE: This Pass matches really specific loop patterns because it's only
62 // supposed to be a temporary solution until our LoopVectorizer is powerful
63 // enough to vectorize them automatically.
64 //
65 //===----------------------------------------------------------------------===//
66
67 #include "llvm/Transforms/Vectorize/LoopIdiomVectorize.h"
68 #include "llvm/Analysis/DomTreeUpdater.h"
69 #include "llvm/Analysis/LoopPass.h"
70 #include "llvm/Analysis/TargetTransformInfo.h"
71 #include "llvm/IR/Dominators.h"
72 #include "llvm/IR/IRBuilder.h"
73 #include "llvm/IR/Intrinsics.h"
74 #include "llvm/IR/MDBuilder.h"
75 #include "llvm/IR/PatternMatch.h"
76 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
77
78 using namespace llvm;
79 using namespace PatternMatch;
80
81 #define DEBUG_TYPE "loop-idiom-vectorize"
82
83 static cl::opt<bool> DisableAll("disable-loop-idiom-vectorize-all", cl::Hidden,
84 cl::init(false),
85 cl::desc("Disable Loop Idiom Vectorize Pass."));
86
87 static cl::opt<LoopIdiomVectorizeStyle>
88 LITVecStyle("loop-idiom-vectorize-style", cl::Hidden,
89 cl::desc("The vectorization style for loop idiom transform."),
90 cl::values(clEnumValN(LoopIdiomVectorizeStyle::Masked, "masked",
91 "Use masked vector intrinsics"),
92 clEnumValN(LoopIdiomVectorizeStyle::Predicated,
93 "predicated", "Use VP intrinsics")),
94 cl::init(LoopIdiomVectorizeStyle::Masked));
95
96 static cl::opt<bool>
97 DisableByteCmp("disable-loop-idiom-vectorize-bytecmp", cl::Hidden,
98 cl::init(false),
99 cl::desc("Proceed with Loop Idiom Vectorize Pass, but do "
100 "not convert byte-compare loop(s)."));
101
102 static cl::opt<unsigned>
103 ByteCmpVF("loop-idiom-vectorize-bytecmp-vf", cl::Hidden,
104 cl::desc("The vectorization factor for byte-compare patterns."),
105 cl::init(16));
106
107 static cl::opt<bool>
108 DisableFindFirstByte("disable-loop-idiom-vectorize-find-first-byte",
109 cl::Hidden, cl::init(false),
110 cl::desc("Do not convert find-first-byte loop(s)."));
111
112 static cl::opt<bool>
113 VerifyLoops("loop-idiom-vectorize-verify", cl::Hidden, cl::init(false),
114 cl::desc("Verify loops generated Loop Idiom Vectorize Pass."));
115
116 namespace {
117 class LoopIdiomVectorize {
118 LoopIdiomVectorizeStyle VectorizeStyle;
119 unsigned ByteCompareVF;
120 Loop *CurLoop = nullptr;
121 DominatorTree *DT;
122 LoopInfo *LI;
123 const TargetTransformInfo *TTI;
124 const DataLayout *DL;
125
126 // Blocks that will be used for inserting vectorized code.
127 BasicBlock *EndBlock = nullptr;
128 BasicBlock *VectorLoopPreheaderBlock = nullptr;
129 BasicBlock *VectorLoopStartBlock = nullptr;
130 BasicBlock *VectorLoopMismatchBlock = nullptr;
131 BasicBlock *VectorLoopIncBlock = nullptr;
132
133 public:
LoopIdiomVectorize(LoopIdiomVectorizeStyle S,unsigned VF,DominatorTree * DT,LoopInfo * LI,const TargetTransformInfo * TTI,const DataLayout * DL)134 LoopIdiomVectorize(LoopIdiomVectorizeStyle S, unsigned VF, DominatorTree *DT,
135 LoopInfo *LI, const TargetTransformInfo *TTI,
136 const DataLayout *DL)
137 : VectorizeStyle(S), ByteCompareVF(VF), DT(DT), LI(LI), TTI(TTI), DL(DL) {
138 }
139
140 bool run(Loop *L);
141
142 private:
143 /// \name Countable Loop Idiom Handling
144 /// @{
145
146 bool runOnCountableLoop();
147 bool runOnLoopBlock(BasicBlock *BB, const SCEV *BECount,
148 SmallVectorImpl<BasicBlock *> &ExitBlocks);
149
150 bool recognizeByteCompare();
151
152 Value *expandFindMismatch(IRBuilder<> &Builder, DomTreeUpdater &DTU,
153 GetElementPtrInst *GEPA, GetElementPtrInst *GEPB,
154 Instruction *Index, Value *Start, Value *MaxLen);
155
156 Value *createMaskedFindMismatch(IRBuilder<> &Builder, DomTreeUpdater &DTU,
157 GetElementPtrInst *GEPA,
158 GetElementPtrInst *GEPB, Value *ExtStart,
159 Value *ExtEnd);
160 Value *createPredicatedFindMismatch(IRBuilder<> &Builder, DomTreeUpdater &DTU,
161 GetElementPtrInst *GEPA,
162 GetElementPtrInst *GEPB, Value *ExtStart,
163 Value *ExtEnd);
164
165 void transformByteCompare(GetElementPtrInst *GEPA, GetElementPtrInst *GEPB,
166 PHINode *IndPhi, Value *MaxLen, Instruction *Index,
167 Value *Start, bool IncIdx, BasicBlock *FoundBB,
168 BasicBlock *EndBB);
169
170 bool recognizeFindFirstByte();
171
172 Value *expandFindFirstByte(IRBuilder<> &Builder, DomTreeUpdater &DTU,
173 unsigned VF, Type *CharTy, BasicBlock *ExitSucc,
174 BasicBlock *ExitFail, Value *SearchStart,
175 Value *SearchEnd, Value *NeedleStart,
176 Value *NeedleEnd);
177
178 void transformFindFirstByte(PHINode *IndPhi, unsigned VF, Type *CharTy,
179 BasicBlock *ExitSucc, BasicBlock *ExitFail,
180 Value *SearchStart, Value *SearchEnd,
181 Value *NeedleStart, Value *NeedleEnd);
182 /// @}
183 };
184 } // anonymous namespace
185
run(Loop & L,LoopAnalysisManager & AM,LoopStandardAnalysisResults & AR,LPMUpdater &)186 PreservedAnalyses LoopIdiomVectorizePass::run(Loop &L, LoopAnalysisManager &AM,
187 LoopStandardAnalysisResults &AR,
188 LPMUpdater &) {
189 if (DisableAll)
190 return PreservedAnalyses::all();
191
192 const auto *DL = &L.getHeader()->getDataLayout();
193
194 LoopIdiomVectorizeStyle VecStyle = VectorizeStyle;
195 if (LITVecStyle.getNumOccurrences())
196 VecStyle = LITVecStyle;
197
198 unsigned BCVF = ByteCompareVF;
199 if (ByteCmpVF.getNumOccurrences())
200 BCVF = ByteCmpVF;
201
202 LoopIdiomVectorize LIV(VecStyle, BCVF, &AR.DT, &AR.LI, &AR.TTI, DL);
203 if (!LIV.run(&L))
204 return PreservedAnalyses::all();
205
206 return PreservedAnalyses::none();
207 }
208
209 //===----------------------------------------------------------------------===//
210 //
211 // Implementation of LoopIdiomVectorize
212 //
213 //===----------------------------------------------------------------------===//
214
run(Loop * L)215 bool LoopIdiomVectorize::run(Loop *L) {
216 CurLoop = L;
217
218 Function &F = *L->getHeader()->getParent();
219 if (DisableAll || F.hasOptSize())
220 return false;
221
222 if (F.hasFnAttribute(Attribute::NoImplicitFloat)) {
223 LLVM_DEBUG(dbgs() << DEBUG_TYPE << " is disabled on " << F.getName()
224 << " due to its NoImplicitFloat attribute");
225 return false;
226 }
227
228 // If the loop could not be converted to canonical form, it must have an
229 // indirectbr in it, just give up.
230 if (!L->getLoopPreheader())
231 return false;
232
233 LLVM_DEBUG(dbgs() << DEBUG_TYPE " Scanning: F[" << F.getName() << "] Loop %"
234 << CurLoop->getHeader()->getName() << "\n");
235
236 if (recognizeByteCompare())
237 return true;
238
239 if (recognizeFindFirstByte())
240 return true;
241
242 return false;
243 }
244
recognizeByteCompare()245 bool LoopIdiomVectorize::recognizeByteCompare() {
246 // Currently the transformation only works on scalable vector types, although
247 // there is no fundamental reason why it cannot be made to work for fixed
248 // width too.
249
250 // We also need to know the minimum page size for the target in order to
251 // generate runtime memory checks to ensure the vector version won't fault.
252 if (!TTI->supportsScalableVectors() || !TTI->getMinPageSize().has_value() ||
253 DisableByteCmp)
254 return false;
255
256 BasicBlock *Header = CurLoop->getHeader();
257
258 // In LoopIdiomVectorize::run we have already checked that the loop
259 // has a preheader so we can assume it's in a canonical form.
260 if (CurLoop->getNumBackEdges() != 1 || CurLoop->getNumBlocks() != 2)
261 return false;
262
263 PHINode *PN = dyn_cast<PHINode>(&Header->front());
264 if (!PN || PN->getNumIncomingValues() != 2)
265 return false;
266
267 auto LoopBlocks = CurLoop->getBlocks();
268 // The first block in the loop should contain only 4 instructions, e.g.
269 //
270 // while.cond:
271 // %res.phi = phi i32 [ %start, %ph ], [ %inc, %while.body ]
272 // %inc = add i32 %res.phi, 1
273 // %cmp.not = icmp eq i32 %inc, %n
274 // br i1 %cmp.not, label %while.end, label %while.body
275 //
276 if (LoopBlocks[0]->sizeWithoutDebug() > 4)
277 return false;
278
279 // The second block should contain 7 instructions, e.g.
280 //
281 // while.body:
282 // %idx = zext i32 %inc to i64
283 // %idx.a = getelementptr inbounds i8, ptr %a, i64 %idx
284 // %load.a = load i8, ptr %idx.a
285 // %idx.b = getelementptr inbounds i8, ptr %b, i64 %idx
286 // %load.b = load i8, ptr %idx.b
287 // %cmp.not.ld = icmp eq i8 %load.a, %load.b
288 // br i1 %cmp.not.ld, label %while.cond, label %while.end
289 //
290 if (LoopBlocks[1]->sizeWithoutDebug() > 7)
291 return false;
292
293 // The incoming value to the PHI node from the loop should be an add of 1.
294 Value *StartIdx = nullptr;
295 Instruction *Index = nullptr;
296 if (!CurLoop->contains(PN->getIncomingBlock(0))) {
297 StartIdx = PN->getIncomingValue(0);
298 Index = dyn_cast<Instruction>(PN->getIncomingValue(1));
299 } else {
300 StartIdx = PN->getIncomingValue(1);
301 Index = dyn_cast<Instruction>(PN->getIncomingValue(0));
302 }
303
304 // Limit to 32-bit types for now
305 if (!Index || !Index->getType()->isIntegerTy(32) ||
306 !match(Index, m_c_Add(m_Specific(PN), m_One())))
307 return false;
308
309 // If we match the pattern, PN and Index will be replaced with the result of
310 // the cttz.elts intrinsic. If any other instructions are used outside of
311 // the loop, we cannot replace it.
312 for (BasicBlock *BB : LoopBlocks)
313 for (Instruction &I : *BB)
314 if (&I != PN && &I != Index)
315 for (User *U : I.users())
316 if (!CurLoop->contains(cast<Instruction>(U)))
317 return false;
318
319 // Match the branch instruction for the header
320 Value *MaxLen;
321 BasicBlock *EndBB, *WhileBB;
322 if (!match(Header->getTerminator(),
323 m_Br(m_SpecificICmp(ICmpInst::ICMP_EQ, m_Specific(Index),
324 m_Value(MaxLen)),
325 m_BasicBlock(EndBB), m_BasicBlock(WhileBB))) ||
326 !CurLoop->contains(WhileBB))
327 return false;
328
329 // WhileBB should contain the pattern of load & compare instructions. Match
330 // the pattern and find the GEP instructions used by the loads.
331 BasicBlock *FoundBB;
332 BasicBlock *TrueBB;
333 Value *LoadA, *LoadB;
334 if (!match(WhileBB->getTerminator(),
335 m_Br(m_SpecificICmp(ICmpInst::ICMP_EQ, m_Value(LoadA),
336 m_Value(LoadB)),
337 m_BasicBlock(TrueBB), m_BasicBlock(FoundBB))) ||
338 !CurLoop->contains(TrueBB))
339 return false;
340
341 Value *A, *B;
342 if (!match(LoadA, m_Load(m_Value(A))) || !match(LoadB, m_Load(m_Value(B))))
343 return false;
344
345 LoadInst *LoadAI = cast<LoadInst>(LoadA);
346 LoadInst *LoadBI = cast<LoadInst>(LoadB);
347 if (!LoadAI->isSimple() || !LoadBI->isSimple())
348 return false;
349
350 GetElementPtrInst *GEPA = dyn_cast<GetElementPtrInst>(A);
351 GetElementPtrInst *GEPB = dyn_cast<GetElementPtrInst>(B);
352
353 if (!GEPA || !GEPB)
354 return false;
355
356 Value *PtrA = GEPA->getPointerOperand();
357 Value *PtrB = GEPB->getPointerOperand();
358
359 // Check we are loading i8 values from two loop invariant pointers
360 if (!CurLoop->isLoopInvariant(PtrA) || !CurLoop->isLoopInvariant(PtrB) ||
361 !GEPA->getResultElementType()->isIntegerTy(8) ||
362 !GEPB->getResultElementType()->isIntegerTy(8) ||
363 !LoadAI->getType()->isIntegerTy(8) ||
364 !LoadBI->getType()->isIntegerTy(8) || PtrA == PtrB)
365 return false;
366
367 // Check that the index to the GEPs is the index we found earlier
368 if (GEPA->getNumIndices() > 1 || GEPB->getNumIndices() > 1)
369 return false;
370
371 Value *IdxA = GEPA->getOperand(GEPA->getNumIndices());
372 Value *IdxB = GEPB->getOperand(GEPB->getNumIndices());
373 if (IdxA != IdxB || !match(IdxA, m_ZExt(m_Specific(Index))))
374 return false;
375
376 // We only ever expect the pre-incremented index value to be used inside the
377 // loop.
378 if (!PN->hasOneUse())
379 return false;
380
381 // Ensure that when the Found and End blocks are identical the PHIs have the
382 // supported format. We don't currently allow cases like this:
383 // while.cond:
384 // ...
385 // br i1 %cmp.not, label %while.end, label %while.body
386 //
387 // while.body:
388 // ...
389 // br i1 %cmp.not2, label %while.cond, label %while.end
390 //
391 // while.end:
392 // %final_ptr = phi ptr [ %c, %while.body ], [ %d, %while.cond ]
393 //
394 // Where the incoming values for %final_ptr are unique and from each of the
395 // loop blocks, but not actually defined in the loop. This requires extra
396 // work setting up the byte.compare block, i.e. by introducing a select to
397 // choose the correct value.
398 // TODO: We could add support for this in future.
399 if (FoundBB == EndBB) {
400 for (PHINode &EndPN : EndBB->phis()) {
401 Value *WhileCondVal = EndPN.getIncomingValueForBlock(Header);
402 Value *WhileBodyVal = EndPN.getIncomingValueForBlock(WhileBB);
403
404 // The value of the index when leaving the while.cond block is always the
405 // same as the end value (MaxLen) so we permit either. The value when
406 // leaving the while.body block should only be the index. Otherwise for
407 // any other values we only allow ones that are same for both blocks.
408 if (WhileCondVal != WhileBodyVal &&
409 ((WhileCondVal != Index && WhileCondVal != MaxLen) ||
410 (WhileBodyVal != Index)))
411 return false;
412 }
413 }
414
415 LLVM_DEBUG(dbgs() << "FOUND IDIOM IN LOOP: \n"
416 << *(EndBB->getParent()) << "\n\n");
417
418 // The index is incremented before the GEP/Load pair so we need to
419 // add 1 to the start value.
420 transformByteCompare(GEPA, GEPB, PN, MaxLen, Index, StartIdx, /*IncIdx=*/true,
421 FoundBB, EndBB);
422 return true;
423 }
424
createMaskedFindMismatch(IRBuilder<> & Builder,DomTreeUpdater & DTU,GetElementPtrInst * GEPA,GetElementPtrInst * GEPB,Value * ExtStart,Value * ExtEnd)425 Value *LoopIdiomVectorize::createMaskedFindMismatch(
426 IRBuilder<> &Builder, DomTreeUpdater &DTU, GetElementPtrInst *GEPA,
427 GetElementPtrInst *GEPB, Value *ExtStart, Value *ExtEnd) {
428 Type *I64Type = Builder.getInt64Ty();
429 Type *ResType = Builder.getInt32Ty();
430 Type *LoadType = Builder.getInt8Ty();
431 Value *PtrA = GEPA->getPointerOperand();
432 Value *PtrB = GEPB->getPointerOperand();
433
434 ScalableVectorType *PredVTy =
435 ScalableVectorType::get(Builder.getInt1Ty(), ByteCompareVF);
436
437 Value *InitialPred = Builder.CreateIntrinsic(
438 Intrinsic::get_active_lane_mask, {PredVTy, I64Type}, {ExtStart, ExtEnd});
439
440 Value *VecLen = Builder.CreateVScale(I64Type);
441 VecLen =
442 Builder.CreateMul(VecLen, ConstantInt::get(I64Type, ByteCompareVF), "",
443 /*HasNUW=*/true, /*HasNSW=*/true);
444
445 Value *PFalse = Builder.CreateVectorSplat(PredVTy->getElementCount(),
446 Builder.getInt1(false));
447
448 BranchInst *JumpToVectorLoop = BranchInst::Create(VectorLoopStartBlock);
449 Builder.Insert(JumpToVectorLoop);
450
451 DTU.applyUpdates({{DominatorTree::Insert, VectorLoopPreheaderBlock,
452 VectorLoopStartBlock}});
453
454 // Set up the first vector loop block by creating the PHIs, doing the vector
455 // loads and comparing the vectors.
456 Builder.SetInsertPoint(VectorLoopStartBlock);
457 PHINode *LoopPred = Builder.CreatePHI(PredVTy, 2, "mismatch_vec_loop_pred");
458 LoopPred->addIncoming(InitialPred, VectorLoopPreheaderBlock);
459 PHINode *VectorIndexPhi = Builder.CreatePHI(I64Type, 2, "mismatch_vec_index");
460 VectorIndexPhi->addIncoming(ExtStart, VectorLoopPreheaderBlock);
461 Type *VectorLoadType =
462 ScalableVectorType::get(Builder.getInt8Ty(), ByteCompareVF);
463 Value *Passthru = ConstantInt::getNullValue(VectorLoadType);
464
465 Value *VectorLhsGep =
466 Builder.CreateGEP(LoadType, PtrA, VectorIndexPhi, "", GEPA->isInBounds());
467 Value *VectorLhsLoad = Builder.CreateMaskedLoad(VectorLoadType, VectorLhsGep,
468 Align(1), LoopPred, Passthru);
469
470 Value *VectorRhsGep =
471 Builder.CreateGEP(LoadType, PtrB, VectorIndexPhi, "", GEPB->isInBounds());
472 Value *VectorRhsLoad = Builder.CreateMaskedLoad(VectorLoadType, VectorRhsGep,
473 Align(1), LoopPred, Passthru);
474
475 Value *VectorMatchCmp = Builder.CreateICmpNE(VectorLhsLoad, VectorRhsLoad);
476 VectorMatchCmp = Builder.CreateSelect(LoopPred, VectorMatchCmp, PFalse);
477 Value *VectorMatchHasActiveLanes = Builder.CreateOrReduce(VectorMatchCmp);
478 BranchInst *VectorEarlyExit = BranchInst::Create(
479 VectorLoopMismatchBlock, VectorLoopIncBlock, VectorMatchHasActiveLanes);
480 Builder.Insert(VectorEarlyExit);
481
482 DTU.applyUpdates(
483 {{DominatorTree::Insert, VectorLoopStartBlock, VectorLoopMismatchBlock},
484 {DominatorTree::Insert, VectorLoopStartBlock, VectorLoopIncBlock}});
485
486 // Increment the index counter and calculate the predicate for the next
487 // iteration of the loop. We branch back to the start of the loop if there
488 // is at least one active lane.
489 Builder.SetInsertPoint(VectorLoopIncBlock);
490 Value *NewVectorIndexPhi =
491 Builder.CreateAdd(VectorIndexPhi, VecLen, "",
492 /*HasNUW=*/true, /*HasNSW=*/true);
493 VectorIndexPhi->addIncoming(NewVectorIndexPhi, VectorLoopIncBlock);
494 Value *NewPred =
495 Builder.CreateIntrinsic(Intrinsic::get_active_lane_mask,
496 {PredVTy, I64Type}, {NewVectorIndexPhi, ExtEnd});
497 LoopPred->addIncoming(NewPred, VectorLoopIncBlock);
498
499 Value *PredHasActiveLanes =
500 Builder.CreateExtractElement(NewPred, uint64_t(0));
501 BranchInst *VectorLoopBranchBack =
502 BranchInst::Create(VectorLoopStartBlock, EndBlock, PredHasActiveLanes);
503 Builder.Insert(VectorLoopBranchBack);
504
505 DTU.applyUpdates(
506 {{DominatorTree::Insert, VectorLoopIncBlock, VectorLoopStartBlock},
507 {DominatorTree::Insert, VectorLoopIncBlock, EndBlock}});
508
509 // If we found a mismatch then we need to calculate which lane in the vector
510 // had a mismatch and add that on to the current loop index.
511 Builder.SetInsertPoint(VectorLoopMismatchBlock);
512 PHINode *FoundPred = Builder.CreatePHI(PredVTy, 1, "mismatch_vec_found_pred");
513 FoundPred->addIncoming(VectorMatchCmp, VectorLoopStartBlock);
514 PHINode *LastLoopPred =
515 Builder.CreatePHI(PredVTy, 1, "mismatch_vec_last_loop_pred");
516 LastLoopPred->addIncoming(LoopPred, VectorLoopStartBlock);
517 PHINode *VectorFoundIndex =
518 Builder.CreatePHI(I64Type, 1, "mismatch_vec_found_index");
519 VectorFoundIndex->addIncoming(VectorIndexPhi, VectorLoopStartBlock);
520
521 Value *PredMatchCmp = Builder.CreateAnd(LastLoopPred, FoundPred);
522 Value *Ctz = Builder.CreateCountTrailingZeroElems(ResType, PredMatchCmp);
523 Ctz = Builder.CreateZExt(Ctz, I64Type);
524 Value *VectorLoopRes64 = Builder.CreateAdd(VectorFoundIndex, Ctz, "",
525 /*HasNUW=*/true, /*HasNSW=*/true);
526 return Builder.CreateTrunc(VectorLoopRes64, ResType);
527 }
528
createPredicatedFindMismatch(IRBuilder<> & Builder,DomTreeUpdater & DTU,GetElementPtrInst * GEPA,GetElementPtrInst * GEPB,Value * ExtStart,Value * ExtEnd)529 Value *LoopIdiomVectorize::createPredicatedFindMismatch(
530 IRBuilder<> &Builder, DomTreeUpdater &DTU, GetElementPtrInst *GEPA,
531 GetElementPtrInst *GEPB, Value *ExtStart, Value *ExtEnd) {
532 Type *I64Type = Builder.getInt64Ty();
533 Type *I32Type = Builder.getInt32Ty();
534 Type *ResType = I32Type;
535 Type *LoadType = Builder.getInt8Ty();
536 Value *PtrA = GEPA->getPointerOperand();
537 Value *PtrB = GEPB->getPointerOperand();
538
539 auto *JumpToVectorLoop = BranchInst::Create(VectorLoopStartBlock);
540 Builder.Insert(JumpToVectorLoop);
541
542 DTU.applyUpdates({{DominatorTree::Insert, VectorLoopPreheaderBlock,
543 VectorLoopStartBlock}});
544
545 // Set up the first Vector loop block by creating the PHIs, doing the vector
546 // loads and comparing the vectors.
547 Builder.SetInsertPoint(VectorLoopStartBlock);
548 auto *VectorIndexPhi = Builder.CreatePHI(I64Type, 2, "mismatch_vector_index");
549 VectorIndexPhi->addIncoming(ExtStart, VectorLoopPreheaderBlock);
550
551 // Calculate AVL by subtracting the vector loop index from the trip count
552 Value *AVL = Builder.CreateSub(ExtEnd, VectorIndexPhi, "avl", /*HasNUW=*/true,
553 /*HasNSW=*/true);
554
555 auto *VectorLoadType = ScalableVectorType::get(LoadType, ByteCompareVF);
556 auto *VF = ConstantInt::get(I32Type, ByteCompareVF);
557
558 Value *VL = Builder.CreateIntrinsic(Intrinsic::experimental_get_vector_length,
559 {I64Type}, {AVL, VF, Builder.getTrue()});
560 Value *GepOffset = VectorIndexPhi;
561
562 Value *VectorLhsGep =
563 Builder.CreateGEP(LoadType, PtrA, GepOffset, "", GEPA->isInBounds());
564 VectorType *TrueMaskTy =
565 VectorType::get(Builder.getInt1Ty(), VectorLoadType->getElementCount());
566 Value *AllTrueMask = Constant::getAllOnesValue(TrueMaskTy);
567 Value *VectorLhsLoad = Builder.CreateIntrinsic(
568 Intrinsic::vp_load, {VectorLoadType, VectorLhsGep->getType()},
569 {VectorLhsGep, AllTrueMask, VL}, nullptr, "lhs.load");
570
571 Value *VectorRhsGep =
572 Builder.CreateGEP(LoadType, PtrB, GepOffset, "", GEPB->isInBounds());
573 Value *VectorRhsLoad = Builder.CreateIntrinsic(
574 Intrinsic::vp_load, {VectorLoadType, VectorLhsGep->getType()},
575 {VectorRhsGep, AllTrueMask, VL}, nullptr, "rhs.load");
576
577 StringRef PredicateStr = CmpInst::getPredicateName(CmpInst::ICMP_NE);
578 auto *PredicateMDS = MDString::get(VectorLhsLoad->getContext(), PredicateStr);
579 Value *Pred = MetadataAsValue::get(VectorLhsLoad->getContext(), PredicateMDS);
580 Value *VectorMatchCmp = Builder.CreateIntrinsic(
581 Intrinsic::vp_icmp, {VectorLhsLoad->getType()},
582 {VectorLhsLoad, VectorRhsLoad, Pred, AllTrueMask, VL}, nullptr,
583 "mismatch.cmp");
584 Value *CTZ = Builder.CreateIntrinsic(
585 Intrinsic::vp_cttz_elts, {ResType, VectorMatchCmp->getType()},
586 {VectorMatchCmp, /*ZeroIsPoison=*/Builder.getInt1(false), AllTrueMask,
587 VL});
588 Value *MismatchFound = Builder.CreateICmpNE(CTZ, VL);
589 auto *VectorEarlyExit = BranchInst::Create(VectorLoopMismatchBlock,
590 VectorLoopIncBlock, MismatchFound);
591 Builder.Insert(VectorEarlyExit);
592
593 DTU.applyUpdates(
594 {{DominatorTree::Insert, VectorLoopStartBlock, VectorLoopMismatchBlock},
595 {DominatorTree::Insert, VectorLoopStartBlock, VectorLoopIncBlock}});
596
597 // Increment the index counter and calculate the predicate for the next
598 // iteration of the loop. We branch back to the start of the loop if there
599 // is at least one active lane.
600 Builder.SetInsertPoint(VectorLoopIncBlock);
601 Value *VL64 = Builder.CreateZExt(VL, I64Type);
602 Value *NewVectorIndexPhi =
603 Builder.CreateAdd(VectorIndexPhi, VL64, "",
604 /*HasNUW=*/true, /*HasNSW=*/true);
605 VectorIndexPhi->addIncoming(NewVectorIndexPhi, VectorLoopIncBlock);
606 Value *ExitCond = Builder.CreateICmpNE(NewVectorIndexPhi, ExtEnd);
607 auto *VectorLoopBranchBack =
608 BranchInst::Create(VectorLoopStartBlock, EndBlock, ExitCond);
609 Builder.Insert(VectorLoopBranchBack);
610
611 DTU.applyUpdates(
612 {{DominatorTree::Insert, VectorLoopIncBlock, VectorLoopStartBlock},
613 {DominatorTree::Insert, VectorLoopIncBlock, EndBlock}});
614
615 // If we found a mismatch then we need to calculate which lane in the vector
616 // had a mismatch and add that on to the current loop index.
617 Builder.SetInsertPoint(VectorLoopMismatchBlock);
618
619 // Add LCSSA phis for CTZ and VectorIndexPhi.
620 auto *CTZLCSSAPhi = Builder.CreatePHI(CTZ->getType(), 1, "ctz");
621 CTZLCSSAPhi->addIncoming(CTZ, VectorLoopStartBlock);
622 auto *VectorIndexLCSSAPhi =
623 Builder.CreatePHI(VectorIndexPhi->getType(), 1, "mismatch_vector_index");
624 VectorIndexLCSSAPhi->addIncoming(VectorIndexPhi, VectorLoopStartBlock);
625
626 Value *CTZI64 = Builder.CreateZExt(CTZLCSSAPhi, I64Type);
627 Value *VectorLoopRes64 = Builder.CreateAdd(VectorIndexLCSSAPhi, CTZI64, "",
628 /*HasNUW=*/true, /*HasNSW=*/true);
629 return Builder.CreateTrunc(VectorLoopRes64, ResType);
630 }
631
expandFindMismatch(IRBuilder<> & Builder,DomTreeUpdater & DTU,GetElementPtrInst * GEPA,GetElementPtrInst * GEPB,Instruction * Index,Value * Start,Value * MaxLen)632 Value *LoopIdiomVectorize::expandFindMismatch(
633 IRBuilder<> &Builder, DomTreeUpdater &DTU, GetElementPtrInst *GEPA,
634 GetElementPtrInst *GEPB, Instruction *Index, Value *Start, Value *MaxLen) {
635 Value *PtrA = GEPA->getPointerOperand();
636 Value *PtrB = GEPB->getPointerOperand();
637
638 // Get the arguments and types for the intrinsic.
639 BasicBlock *Preheader = CurLoop->getLoopPreheader();
640 BranchInst *PHBranch = cast<BranchInst>(Preheader->getTerminator());
641 LLVMContext &Ctx = PHBranch->getContext();
642 Type *LoadType = Type::getInt8Ty(Ctx);
643 Type *ResType = Builder.getInt32Ty();
644
645 // Split block in the original loop preheader.
646 EndBlock = SplitBlock(Preheader, PHBranch, DT, LI, nullptr, "mismatch_end");
647
648 // Create the blocks that we're going to need:
649 // 1. A block for checking the zero-extended length exceeds 0
650 // 2. A block to check that the start and end addresses of a given array
651 // lie on the same page.
652 // 3. The vector loop preheader.
653 // 4. The first vector loop block.
654 // 5. The vector loop increment block.
655 // 6. A block we can jump to from the vector loop when a mismatch is found.
656 // 7. The first block of the scalar loop itself, containing PHIs , loads
657 // and cmp.
658 // 8. A scalar loop increment block to increment the PHIs and go back
659 // around the loop.
660
661 BasicBlock *MinItCheckBlock = BasicBlock::Create(
662 Ctx, "mismatch_min_it_check", EndBlock->getParent(), EndBlock);
663
664 // Update the terminator added by SplitBlock to branch to the first block
665 Preheader->getTerminator()->setSuccessor(0, MinItCheckBlock);
666
667 BasicBlock *MemCheckBlock = BasicBlock::Create(
668 Ctx, "mismatch_mem_check", EndBlock->getParent(), EndBlock);
669
670 VectorLoopPreheaderBlock = BasicBlock::Create(
671 Ctx, "mismatch_vec_loop_preheader", EndBlock->getParent(), EndBlock);
672
673 VectorLoopStartBlock = BasicBlock::Create(Ctx, "mismatch_vec_loop",
674 EndBlock->getParent(), EndBlock);
675
676 VectorLoopIncBlock = BasicBlock::Create(Ctx, "mismatch_vec_loop_inc",
677 EndBlock->getParent(), EndBlock);
678
679 VectorLoopMismatchBlock = BasicBlock::Create(Ctx, "mismatch_vec_loop_found",
680 EndBlock->getParent(), EndBlock);
681
682 BasicBlock *LoopPreHeaderBlock = BasicBlock::Create(
683 Ctx, "mismatch_loop_pre", EndBlock->getParent(), EndBlock);
684
685 BasicBlock *LoopStartBlock =
686 BasicBlock::Create(Ctx, "mismatch_loop", EndBlock->getParent(), EndBlock);
687
688 BasicBlock *LoopIncBlock = BasicBlock::Create(
689 Ctx, "mismatch_loop_inc", EndBlock->getParent(), EndBlock);
690
691 DTU.applyUpdates({{DominatorTree::Insert, Preheader, MinItCheckBlock},
692 {DominatorTree::Delete, Preheader, EndBlock}});
693
694 // Update LoopInfo with the new vector & scalar loops.
695 auto VectorLoop = LI->AllocateLoop();
696 auto ScalarLoop = LI->AllocateLoop();
697
698 if (CurLoop->getParentLoop()) {
699 CurLoop->getParentLoop()->addBasicBlockToLoop(MinItCheckBlock, *LI);
700 CurLoop->getParentLoop()->addBasicBlockToLoop(MemCheckBlock, *LI);
701 CurLoop->getParentLoop()->addBasicBlockToLoop(VectorLoopPreheaderBlock,
702 *LI);
703 CurLoop->getParentLoop()->addChildLoop(VectorLoop);
704 CurLoop->getParentLoop()->addBasicBlockToLoop(VectorLoopMismatchBlock, *LI);
705 CurLoop->getParentLoop()->addBasicBlockToLoop(LoopPreHeaderBlock, *LI);
706 CurLoop->getParentLoop()->addChildLoop(ScalarLoop);
707 } else {
708 LI->addTopLevelLoop(VectorLoop);
709 LI->addTopLevelLoop(ScalarLoop);
710 }
711
712 // Add the new basic blocks to their associated loops.
713 VectorLoop->addBasicBlockToLoop(VectorLoopStartBlock, *LI);
714 VectorLoop->addBasicBlockToLoop(VectorLoopIncBlock, *LI);
715
716 ScalarLoop->addBasicBlockToLoop(LoopStartBlock, *LI);
717 ScalarLoop->addBasicBlockToLoop(LoopIncBlock, *LI);
718
719 // Set up some types and constants that we intend to reuse.
720 Type *I64Type = Builder.getInt64Ty();
721
722 // Check the zero-extended iteration count > 0
723 Builder.SetInsertPoint(MinItCheckBlock);
724 Value *ExtStart = Builder.CreateZExt(Start, I64Type);
725 Value *ExtEnd = Builder.CreateZExt(MaxLen, I64Type);
726 // This check doesn't really cost us very much.
727
728 Value *LimitCheck = Builder.CreateICmpULE(Start, MaxLen);
729 BranchInst *MinItCheckBr =
730 BranchInst::Create(MemCheckBlock, LoopPreHeaderBlock, LimitCheck);
731 MinItCheckBr->setMetadata(
732 LLVMContext::MD_prof,
733 MDBuilder(MinItCheckBr->getContext()).createBranchWeights(99, 1));
734 Builder.Insert(MinItCheckBr);
735
736 DTU.applyUpdates(
737 {{DominatorTree::Insert, MinItCheckBlock, MemCheckBlock},
738 {DominatorTree::Insert, MinItCheckBlock, LoopPreHeaderBlock}});
739
740 // For each of the arrays, check the start/end addresses are on the same
741 // page.
742 Builder.SetInsertPoint(MemCheckBlock);
743
744 // The early exit in the original loop means that when performing vector
745 // loads we are potentially reading ahead of the early exit. So we could
746 // fault if crossing a page boundary. Therefore, we create runtime memory
747 // checks based on the minimum page size as follows:
748 // 1. Calculate the addresses of the first memory accesses in the loop,
749 // i.e. LhsStart and RhsStart.
750 // 2. Get the last accessed addresses in the loop, i.e. LhsEnd and RhsEnd.
751 // 3. Determine which pages correspond to all the memory accesses, i.e
752 // LhsStartPage, LhsEndPage, RhsStartPage, RhsEndPage.
753 // 4. If LhsStartPage == LhsEndPage and RhsStartPage == RhsEndPage, then
754 // we know we won't cross any page boundaries in the loop so we can
755 // enter the vector loop! Otherwise we fall back on the scalar loop.
756 Value *LhsStartGEP = Builder.CreateGEP(LoadType, PtrA, ExtStart);
757 Value *RhsStartGEP = Builder.CreateGEP(LoadType, PtrB, ExtStart);
758 Value *RhsStart = Builder.CreatePtrToInt(RhsStartGEP, I64Type);
759 Value *LhsStart = Builder.CreatePtrToInt(LhsStartGEP, I64Type);
760 Value *LhsEndGEP = Builder.CreateGEP(LoadType, PtrA, ExtEnd);
761 Value *RhsEndGEP = Builder.CreateGEP(LoadType, PtrB, ExtEnd);
762 Value *LhsEnd = Builder.CreatePtrToInt(LhsEndGEP, I64Type);
763 Value *RhsEnd = Builder.CreatePtrToInt(RhsEndGEP, I64Type);
764
765 const uint64_t MinPageSize = TTI->getMinPageSize().value();
766 const uint64_t AddrShiftAmt = llvm::Log2_64(MinPageSize);
767 Value *LhsStartPage = Builder.CreateLShr(LhsStart, AddrShiftAmt);
768 Value *LhsEndPage = Builder.CreateLShr(LhsEnd, AddrShiftAmt);
769 Value *RhsStartPage = Builder.CreateLShr(RhsStart, AddrShiftAmt);
770 Value *RhsEndPage = Builder.CreateLShr(RhsEnd, AddrShiftAmt);
771 Value *LhsPageCmp = Builder.CreateICmpNE(LhsStartPage, LhsEndPage);
772 Value *RhsPageCmp = Builder.CreateICmpNE(RhsStartPage, RhsEndPage);
773
774 Value *CombinedPageCmp = Builder.CreateOr(LhsPageCmp, RhsPageCmp);
775 BranchInst *CombinedPageCmpCmpBr = BranchInst::Create(
776 LoopPreHeaderBlock, VectorLoopPreheaderBlock, CombinedPageCmp);
777 CombinedPageCmpCmpBr->setMetadata(
778 LLVMContext::MD_prof, MDBuilder(CombinedPageCmpCmpBr->getContext())
779 .createBranchWeights(10, 90));
780 Builder.Insert(CombinedPageCmpCmpBr);
781
782 DTU.applyUpdates(
783 {{DominatorTree::Insert, MemCheckBlock, LoopPreHeaderBlock},
784 {DominatorTree::Insert, MemCheckBlock, VectorLoopPreheaderBlock}});
785
786 // Set up the vector loop preheader, i.e. calculate initial loop predicate,
787 // zero-extend MaxLen to 64-bits, determine the number of vector elements
788 // processed in each iteration, etc.
789 Builder.SetInsertPoint(VectorLoopPreheaderBlock);
790
791 // At this point we know two things must be true:
792 // 1. Start <= End
793 // 2. ExtMaxLen <= MinPageSize due to the page checks.
794 // Therefore, we know that we can use a 64-bit induction variable that
795 // starts from 0 -> ExtMaxLen and it will not overflow.
796 Value *VectorLoopRes = nullptr;
797 switch (VectorizeStyle) {
798 case LoopIdiomVectorizeStyle::Masked:
799 VectorLoopRes =
800 createMaskedFindMismatch(Builder, DTU, GEPA, GEPB, ExtStart, ExtEnd);
801 break;
802 case LoopIdiomVectorizeStyle::Predicated:
803 VectorLoopRes = createPredicatedFindMismatch(Builder, DTU, GEPA, GEPB,
804 ExtStart, ExtEnd);
805 break;
806 }
807
808 Builder.Insert(BranchInst::Create(EndBlock));
809
810 DTU.applyUpdates(
811 {{DominatorTree::Insert, VectorLoopMismatchBlock, EndBlock}});
812
813 // Generate code for scalar loop.
814 Builder.SetInsertPoint(LoopPreHeaderBlock);
815 Builder.Insert(BranchInst::Create(LoopStartBlock));
816
817 DTU.applyUpdates(
818 {{DominatorTree::Insert, LoopPreHeaderBlock, LoopStartBlock}});
819
820 Builder.SetInsertPoint(LoopStartBlock);
821 PHINode *IndexPhi = Builder.CreatePHI(ResType, 2, "mismatch_index");
822 IndexPhi->addIncoming(Start, LoopPreHeaderBlock);
823
824 // Otherwise compare the values
825 // Load bytes from each array and compare them.
826 Value *GepOffset = Builder.CreateZExt(IndexPhi, I64Type);
827
828 Value *LhsGep =
829 Builder.CreateGEP(LoadType, PtrA, GepOffset, "", GEPA->isInBounds());
830 Value *LhsLoad = Builder.CreateLoad(LoadType, LhsGep);
831
832 Value *RhsGep =
833 Builder.CreateGEP(LoadType, PtrB, GepOffset, "", GEPB->isInBounds());
834 Value *RhsLoad = Builder.CreateLoad(LoadType, RhsGep);
835
836 Value *MatchCmp = Builder.CreateICmpEQ(LhsLoad, RhsLoad);
837 // If we have a mismatch then exit the loop ...
838 BranchInst *MatchCmpBr = BranchInst::Create(LoopIncBlock, EndBlock, MatchCmp);
839 Builder.Insert(MatchCmpBr);
840
841 DTU.applyUpdates({{DominatorTree::Insert, LoopStartBlock, LoopIncBlock},
842 {DominatorTree::Insert, LoopStartBlock, EndBlock}});
843
844 // Have we reached the maximum permitted length for the loop?
845 Builder.SetInsertPoint(LoopIncBlock);
846 Value *PhiInc = Builder.CreateAdd(IndexPhi, ConstantInt::get(ResType, 1), "",
847 /*HasNUW=*/Index->hasNoUnsignedWrap(),
848 /*HasNSW=*/Index->hasNoSignedWrap());
849 IndexPhi->addIncoming(PhiInc, LoopIncBlock);
850 Value *IVCmp = Builder.CreateICmpEQ(PhiInc, MaxLen);
851 BranchInst *IVCmpBr = BranchInst::Create(EndBlock, LoopStartBlock, IVCmp);
852 Builder.Insert(IVCmpBr);
853
854 DTU.applyUpdates({{DominatorTree::Insert, LoopIncBlock, EndBlock},
855 {DominatorTree::Insert, LoopIncBlock, LoopStartBlock}});
856
857 // In the end block we need to insert a PHI node to deal with three cases:
858 // 1. We didn't find a mismatch in the scalar loop, so we return MaxLen.
859 // 2. We exitted the scalar loop early due to a mismatch and need to return
860 // the index that we found.
861 // 3. We didn't find a mismatch in the vector loop, so we return MaxLen.
862 // 4. We exitted the vector loop early due to a mismatch and need to return
863 // the index that we found.
864 Builder.SetInsertPoint(EndBlock, EndBlock->getFirstInsertionPt());
865 PHINode *ResPhi = Builder.CreatePHI(ResType, 4, "mismatch_result");
866 ResPhi->addIncoming(MaxLen, LoopIncBlock);
867 ResPhi->addIncoming(IndexPhi, LoopStartBlock);
868 ResPhi->addIncoming(MaxLen, VectorLoopIncBlock);
869 ResPhi->addIncoming(VectorLoopRes, VectorLoopMismatchBlock);
870
871 Value *FinalRes = Builder.CreateTrunc(ResPhi, ResType);
872
873 if (VerifyLoops) {
874 ScalarLoop->verifyLoop();
875 VectorLoop->verifyLoop();
876 if (!VectorLoop->isRecursivelyLCSSAForm(*DT, *LI))
877 report_fatal_error("Loops must remain in LCSSA form!");
878 if (!ScalarLoop->isRecursivelyLCSSAForm(*DT, *LI))
879 report_fatal_error("Loops must remain in LCSSA form!");
880 }
881
882 return FinalRes;
883 }
884
/// Replace the mismatch-finding loop CurLoop with the code emitted by
/// expandFindMismatch and route its result to the loop's successors.
///
/// \param GEPA, GEPB Address computations of the two compared arrays; they are
///                   forwarded to expandFindMismatch.
/// \param IndPhi     Induction PHI of the scalar loop (only inspected by an
///                   assertion here).
/// \param MaxLen     Upper bound of the index; a result equal to MaxLen means
///                   no mismatch was found and control goes to EndBB.
/// \param Index      Scalar index value whose uses are replaced by the result.
/// \param Start      Initial index value.
/// \param IncIdx     If true, Start is incremented by one first because the
///                   scalar loop incremented the index before its loads.
/// \param FoundBB    Successor taken when a mismatch was found.
/// \param EndBB      Successor taken when no mismatch was found; may be the
///                   same block as FoundBB.
void LoopIdiomVectorize::transformByteCompare(GetElementPtrInst *GEPA,
                                              GetElementPtrInst *GEPB,
                                              PHINode *IndPhi, Value *MaxLen,
                                              Instruction *Index, Value *Start,
                                              bool IncIdx, BasicBlock *FoundBB,
                                              BasicBlock *EndBB) {

  // Insert the byte compare code at the end of the preheader block.
  BasicBlock *Preheader = CurLoop->getLoopPreheader();
  BasicBlock *Header = CurLoop->getHeader();
  BranchInst *PHBranch = cast<BranchInst>(Preheader->getTerminator());
  IRBuilder<> Builder(PHBranch);
  DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
  Builder.SetCurrentDebugLocation(PHBranch->getDebugLoc());

  // Increment the start index if the scalar loop did so before its loads.
  if (IncIdx)
    Start = Builder.CreateAdd(Start, ConstantInt::get(Start->getType(), 1));

  Value *ByteCmpRes =
      expandFindMismatch(Builder, DTU, GEPA, GEPB, Index, Start, MaxLen);

  // Replace uses of the index & induction Phi with the result (we already
  // checked that the first instruction of Header is the Phi above).
  assert(IndPhi->hasOneUse() && "Index phi node has more than one use!");
  Index->replaceAllUsesWith(ByteCmpRes);

  assert(PHBranch->isUnconditional() &&
         "Expected preheader to terminate with an unconditional branch.");

  // If no mismatch was found, we can jump to the end block. Create a
  // new basic block for the compare instruction.
  auto *CmpBB = BasicBlock::Create(Preheader->getContext(), "byte.compare",
                                   Preheader->getParent());
  // Place the new block just before EndBB in the function's block list.
  CmpBB->moveBefore(EndBB);

  // Replace the branch in the preheader with an always-true conditional branch.
  // This ensures there is still a reference to the original loop.
  Builder.CreateCondBr(Builder.getTrue(), CmpBB, Header);
  PHBranch->eraseFromParent();

  // The block defining the result is the one that now branches to CmpBB.
  BasicBlock *MismatchEnd = cast<Instruction>(ByteCmpRes)->getParent();
  DTU.applyUpdates({{DominatorTree::Insert, MismatchEnd, CmpBB}});

  // Create the branch to either the end or found block depending on the value
  // returned by the intrinsic.
  Builder.SetInsertPoint(CmpBB);
  if (FoundBB != EndBB) {
    Value *FoundCmp = Builder.CreateICmpEQ(ByteCmpRes, MaxLen);
    Builder.CreateCondBr(FoundCmp, EndBB, FoundBB);
    DTU.applyUpdates({{DominatorTree::Insert, CmpBB, FoundBB},
                      {DominatorTree::Insert, CmpBB, EndBB}});

  } else {
    Builder.CreateBr(FoundBB);
    DTU.applyUpdates({{DominatorTree::Insert, CmpBB, FoundBB}});
  }

  // Give every PHI in SuccBB an incoming value for the new predecessor CmpBB.
  auto fixSuccessorPhis = [&](BasicBlock *SuccBB) {
    for (PHINode &PN : SuccBB->phis()) {
      // At this point we've already replaced all uses of the result from the
      // loop with ByteCmp. Look through the incoming values to find ByteCmp,
      // meaning this is a Phi collecting the results of the byte compare.
      bool ResPhi = false;
      for (Value *Op : PN.incoming_values())
        if (Op == ByteCmpRes) {
          ResPhi = true;
          break;
        }

      // Any PHI that depended upon the result of the byte compare needs a new
      // incoming value from CmpBB. This is because the original loop will get
      // deleted.
      if (ResPhi)
        PN.addIncoming(ByteCmpRes, CmpBB);
      else {
        // There should be no other outside uses of other values in the
        // original loop. Any incoming values should either:
        //   1. Be for blocks outside the loop, which aren't interesting. Or ..
        //   2. These are from blocks in the loop with values defined outside
        //      the loop. We should add a similar incoming value from CmpBB.
        // The loop breaks right after addIncoming so the range over PN's
        // blocks is not used after it may have been invalidated.
        for (BasicBlock *BB : PN.blocks())
          if (CurLoop->contains(BB)) {
            PN.addIncoming(PN.getIncomingValueForBlock(BB), CmpBB);
            break;
          }
      }
    }
  };

  // Ensure all Phis in the successors of CmpBB have an incoming value from it.
  fixSuccessorPhis(EndBB);
  if (EndBB != FoundBB)
    fixSuccessorPhis(FoundBB);

  // The new CmpBB block isn't part of the loop, but will need to be added to
  // the outer loop if there is one.
  if (!CurLoop->isOutermost())
    CurLoop->getParentLoop()->addBasicBlockToLoop(CmpBB, *LI);

  if (VerifyLoops && CurLoop->getParentLoop()) {
    CurLoop->getParentLoop()->verifyLoop();
    if (!CurLoop->getParentLoop()->isRecursivelyLCSSAForm(*DT, *LI))
      report_fatal_error("Loops must remain in LCSSA form!");
  }
}
991
recognizeFindFirstByte()992 bool LoopIdiomVectorize::recognizeFindFirstByte() {
993 // Currently the transformation only works on scalable vector types, although
994 // there is no fundamental reason why it cannot be made to work for fixed
995 // vectors. We also need to know the target's minimum page size in order to
996 // generate runtime memory checks to ensure the vector version won't fault.
997 if (!TTI->supportsScalableVectors() || !TTI->getMinPageSize().has_value() ||
998 DisableFindFirstByte)
999 return false;
1000
1001 // Define some constants we need throughout.
1002 BasicBlock *Header = CurLoop->getHeader();
1003 LLVMContext &Ctx = Header->getContext();
1004
1005 // We are expecting the four blocks defined below: Header, MatchBB, InnerBB,
1006 // and OuterBB. For now, we will bail our for almost anything else. The Four
1007 // blocks contain one nested loop.
1008 if (CurLoop->getNumBackEdges() != 1 || CurLoop->getNumBlocks() != 4 ||
1009 CurLoop->getSubLoops().size() != 1)
1010 return false;
1011
1012 auto *InnerLoop = CurLoop->getSubLoops().front();
1013 PHINode *IndPhi = dyn_cast<PHINode>(&Header->front());
1014 if (!IndPhi || IndPhi->getNumIncomingValues() != 2)
1015 return false;
1016
1017 // Check instruction counts.
1018 auto LoopBlocks = CurLoop->getBlocks();
1019 if (LoopBlocks[0]->sizeWithoutDebug() > 3 ||
1020 LoopBlocks[1]->sizeWithoutDebug() > 4 ||
1021 LoopBlocks[2]->sizeWithoutDebug() > 3 ||
1022 LoopBlocks[3]->sizeWithoutDebug() > 3)
1023 return false;
1024
1025 // Check that no instruction other than IndPhi has outside uses.
1026 for (BasicBlock *BB : LoopBlocks)
1027 for (Instruction &I : *BB)
1028 if (&I != IndPhi)
1029 for (User *U : I.users())
1030 if (!CurLoop->contains(cast<Instruction>(U)))
1031 return false;
1032
1033 // Match the branch instruction in the header. We are expecting an
1034 // unconditional branch to the inner loop.
1035 //
1036 // Header:
1037 // %14 = phi ptr [ %24, %OuterBB ], [ %3, %Header.preheader ]
1038 // %15 = load i8, ptr %14, align 1
1039 // br label %MatchBB
1040 BasicBlock *MatchBB;
1041 if (!match(Header->getTerminator(), m_UnconditionalBr(MatchBB)) ||
1042 !InnerLoop->contains(MatchBB))
1043 return false;
1044
1045 // MatchBB should be the entrypoint into the inner loop containing the
1046 // comparison between a search element and a needle.
1047 //
1048 // MatchBB:
1049 // %20 = phi ptr [ %7, %Header ], [ %17, %InnerBB ]
1050 // %21 = load i8, ptr %20, align 1
1051 // %22 = icmp eq i8 %15, %21
1052 // br i1 %22, label %ExitSucc, label %InnerBB
1053 BasicBlock *ExitSucc, *InnerBB;
1054 Value *LoadSearch, *LoadNeedle;
1055 CmpPredicate MatchPred;
1056 if (!match(MatchBB->getTerminator(),
1057 m_Br(m_ICmp(MatchPred, m_Value(LoadSearch), m_Value(LoadNeedle)),
1058 m_BasicBlock(ExitSucc), m_BasicBlock(InnerBB))) ||
1059 MatchPred != ICmpInst::ICMP_EQ || !InnerLoop->contains(InnerBB))
1060 return false;
1061
1062 // We expect outside uses of `IndPhi' in ExitSucc (and only there).
1063 for (User *U : IndPhi->users())
1064 if (!CurLoop->contains(cast<Instruction>(U))) {
1065 auto *PN = dyn_cast<PHINode>(U);
1066 if (!PN || PN->getParent() != ExitSucc)
1067 return false;
1068 }
1069
1070 // Match the loads and check they are simple.
1071 Value *Search, *Needle;
1072 if (!match(LoadSearch, m_Load(m_Value(Search))) ||
1073 !match(LoadNeedle, m_Load(m_Value(Needle))) ||
1074 !cast<LoadInst>(LoadSearch)->isSimple() ||
1075 !cast<LoadInst>(LoadNeedle)->isSimple())
1076 return false;
1077
1078 // Check we are loading valid characters.
1079 Type *CharTy = LoadSearch->getType();
1080 if (!CharTy->isIntegerTy() || LoadNeedle->getType() != CharTy)
1081 return false;
1082
1083 // Pick the vectorisation factor based on CharTy, work out the cost of the
1084 // match intrinsic and decide if we should use it.
1085 // Note: For the time being we assume 128-bit vectors.
1086 unsigned VF = 128 / CharTy->getIntegerBitWidth();
1087 SmallVector<Type *> Args = {
1088 ScalableVectorType::get(CharTy, VF), FixedVectorType::get(CharTy, VF),
1089 ScalableVectorType::get(Type::getInt1Ty(Ctx), VF)};
1090 IntrinsicCostAttributes Attrs(Intrinsic::experimental_vector_match, Args[2],
1091 Args);
1092 if (TTI->getIntrinsicInstrCost(Attrs, TTI::TCK_SizeAndLatency) > 4)
1093 return false;
1094
1095 // The loads come from two PHIs, each with two incoming values.
1096 PHINode *PSearch = dyn_cast<PHINode>(Search);
1097 PHINode *PNeedle = dyn_cast<PHINode>(Needle);
1098 if (!PSearch || PSearch->getNumIncomingValues() != 2 || !PNeedle ||
1099 PNeedle->getNumIncomingValues() != 2)
1100 return false;
1101
1102 // One PHI comes from the outer loop (PSearch), the other one from the inner
1103 // loop (PNeedle). PSearch effectively corresponds to IndPhi.
1104 if (InnerLoop->contains(PSearch))
1105 std::swap(PSearch, PNeedle);
1106 if (PSearch != &Header->front() || PNeedle != &MatchBB->front())
1107 return false;
1108
1109 // The incoming values of both PHI nodes should be a gep of 1.
1110 Value *SearchStart = PSearch->getIncomingValue(0);
1111 Value *SearchIndex = PSearch->getIncomingValue(1);
1112 if (CurLoop->contains(PSearch->getIncomingBlock(0)))
1113 std::swap(SearchStart, SearchIndex);
1114
1115 Value *NeedleStart = PNeedle->getIncomingValue(0);
1116 Value *NeedleIndex = PNeedle->getIncomingValue(1);
1117 if (InnerLoop->contains(PNeedle->getIncomingBlock(0)))
1118 std::swap(NeedleStart, NeedleIndex);
1119
1120 // Match the GEPs.
1121 if (!match(SearchIndex, m_GEP(m_Specific(PSearch), m_One())) ||
1122 !match(NeedleIndex, m_GEP(m_Specific(PNeedle), m_One())))
1123 return false;
1124
1125 // Check the GEPs result type matches `CharTy'.
1126 GetElementPtrInst *GEPSearch = cast<GetElementPtrInst>(SearchIndex);
1127 GetElementPtrInst *GEPNeedle = cast<GetElementPtrInst>(NeedleIndex);
1128 if (GEPSearch->getResultElementType() != CharTy ||
1129 GEPNeedle->getResultElementType() != CharTy)
1130 return false;
1131
1132 // InnerBB should increment the address of the needle pointer.
1133 //
1134 // InnerBB:
1135 // %17 = getelementptr inbounds i8, ptr %20, i64 1
1136 // %18 = icmp eq ptr %17, %10
1137 // br i1 %18, label %OuterBB, label %MatchBB
1138 BasicBlock *OuterBB;
1139 Value *NeedleEnd;
1140 if (!match(InnerBB->getTerminator(),
1141 m_Br(m_SpecificICmp(ICmpInst::ICMP_EQ, m_Specific(GEPNeedle),
1142 m_Value(NeedleEnd)),
1143 m_BasicBlock(OuterBB), m_Specific(MatchBB))) ||
1144 !CurLoop->contains(OuterBB))
1145 return false;
1146
1147 // OuterBB should increment the address of the search element pointer.
1148 //
1149 // OuterBB:
1150 // %24 = getelementptr inbounds i8, ptr %14, i64 1
1151 // %25 = icmp eq ptr %24, %6
1152 // br i1 %25, label %ExitFail, label %Header
1153 BasicBlock *ExitFail;
1154 Value *SearchEnd;
1155 if (!match(OuterBB->getTerminator(),
1156 m_Br(m_SpecificICmp(ICmpInst::ICMP_EQ, m_Specific(GEPSearch),
1157 m_Value(SearchEnd)),
1158 m_BasicBlock(ExitFail), m_Specific(Header))))
1159 return false;
1160
1161 if (!CurLoop->isLoopInvariant(SearchStart) ||
1162 !CurLoop->isLoopInvariant(SearchEnd) ||
1163 !CurLoop->isLoopInvariant(NeedleStart) ||
1164 !CurLoop->isLoopInvariant(NeedleEnd))
1165 return false;
1166
1167 LLVM_DEBUG(dbgs() << "Found idiom in loop: \n" << *CurLoop << "\n\n");
1168
1169 transformFindFirstByte(IndPhi, VF, CharTy, ExitSucc, ExitFail, SearchStart,
1170 SearchEnd, NeedleStart, NeedleEnd);
1171 return true;
1172 }
1173
expandFindFirstByte(IRBuilder<> & Builder,DomTreeUpdater & DTU,unsigned VF,Type * CharTy,BasicBlock * ExitSucc,BasicBlock * ExitFail,Value * SearchStart,Value * SearchEnd,Value * NeedleStart,Value * NeedleEnd)1174 Value *LoopIdiomVectorize::expandFindFirstByte(
1175 IRBuilder<> &Builder, DomTreeUpdater &DTU, unsigned VF, Type *CharTy,
1176 BasicBlock *ExitSucc, BasicBlock *ExitFail, Value *SearchStart,
1177 Value *SearchEnd, Value *NeedleStart, Value *NeedleEnd) {
1178 // Set up some types and constants that we intend to reuse.
1179 auto *PtrTy = Builder.getPtrTy();
1180 auto *I64Ty = Builder.getInt64Ty();
1181 auto *PredVTy = ScalableVectorType::get(Builder.getInt1Ty(), VF);
1182 auto *CharVTy = ScalableVectorType::get(CharTy, VF);
1183 auto *ConstVF = ConstantInt::get(I64Ty, VF);
1184
1185 // Other common arguments.
1186 BasicBlock *Preheader = CurLoop->getLoopPreheader();
1187 LLVMContext &Ctx = Preheader->getContext();
1188 Value *Passthru = ConstantInt::getNullValue(CharVTy);
1189
1190 // Split block in the original loop preheader.
1191 // SPH is the new preheader to the old scalar loop.
1192 BasicBlock *SPH = SplitBlock(Preheader, Preheader->getTerminator(), DT, LI,
1193 nullptr, "scalar_preheader");
1194
1195 // Create the blocks that we're going to use.
1196 //
1197 // We will have the following loops:
1198 // (O) Outer loop where we iterate over the elements of the search array.
1199 // (I) Inner loop where we iterate over the elements of the needle array.
1200 //
1201 // Overall, the blocks do the following:
1202 // (0) Check if the arrays can't cross page boundaries. If so go to (1),
1203 // otherwise fall back to the original scalar loop.
1204 // (1) Load the search array. Go to (2).
1205 // (2) (a) Load the needle array.
1206 // (b) Splat the first element to the inactive lanes.
1207 // (c) Check if any elements match. If so go to (3), otherwise go to (4).
1208 // (3) Compute the index of the first match and exit.
1209 // (4) Check if we've reached the end of the needle array. If not loop back to
1210 // (2), otherwise go to (5).
1211 // (5) Check if we've reached the end of the search array. If not loop back to
1212 // (1), otherwise exit.
1213 // Blocks (0,3) are not part of any loop. Blocks (1,5) and (2,4) belong to
1214 // the outer and inner loops, respectively.
1215 BasicBlock *BB0 = BasicBlock::Create(Ctx, "mem_check", SPH->getParent(), SPH);
1216 BasicBlock *BB1 =
1217 BasicBlock::Create(Ctx, "find_first_vec_header", SPH->getParent(), SPH);
1218 BasicBlock *BB2 =
1219 BasicBlock::Create(Ctx, "match_check_vec", SPH->getParent(), SPH);
1220 BasicBlock *BB3 =
1221 BasicBlock::Create(Ctx, "calculate_match", SPH->getParent(), SPH);
1222 BasicBlock *BB4 =
1223 BasicBlock::Create(Ctx, "needle_check_vec", SPH->getParent(), SPH);
1224 BasicBlock *BB5 =
1225 BasicBlock::Create(Ctx, "search_check_vec", SPH->getParent(), SPH);
1226
1227 // Update LoopInfo with the new loops.
1228 auto OuterLoop = LI->AllocateLoop();
1229 auto InnerLoop = LI->AllocateLoop();
1230
1231 if (auto ParentLoop = CurLoop->getParentLoop()) {
1232 ParentLoop->addBasicBlockToLoop(BB0, *LI);
1233 ParentLoop->addChildLoop(OuterLoop);
1234 ParentLoop->addBasicBlockToLoop(BB3, *LI);
1235 } else {
1236 LI->addTopLevelLoop(OuterLoop);
1237 }
1238
1239 // Add the inner loop to the outer.
1240 OuterLoop->addChildLoop(InnerLoop);
1241
1242 // Add the new basic blocks to the corresponding loops.
1243 OuterLoop->addBasicBlockToLoop(BB1, *LI);
1244 OuterLoop->addBasicBlockToLoop(BB5, *LI);
1245 InnerLoop->addBasicBlockToLoop(BB2, *LI);
1246 InnerLoop->addBasicBlockToLoop(BB4, *LI);
1247
1248 // Update the terminator added by SplitBlock to branch to the first block.
1249 Preheader->getTerminator()->setSuccessor(0, BB0);
1250 DTU.applyUpdates({{DominatorTree::Delete, Preheader, SPH},
1251 {DominatorTree::Insert, Preheader, BB0}});
1252
1253 // (0) Check if we could be crossing a page boundary; if so, fallback to the
1254 // old scalar loops. Also create a predicate of VF elements to be used in the
1255 // vector loops.
1256 Builder.SetInsertPoint(BB0);
1257 Value *ISearchStart =
1258 Builder.CreatePtrToInt(SearchStart, I64Ty, "search_start_int");
1259 Value *ISearchEnd =
1260 Builder.CreatePtrToInt(SearchEnd, I64Ty, "search_end_int");
1261 Value *INeedleStart =
1262 Builder.CreatePtrToInt(NeedleStart, I64Ty, "needle_start_int");
1263 Value *INeedleEnd =
1264 Builder.CreatePtrToInt(NeedleEnd, I64Ty, "needle_end_int");
1265 Value *PredVF =
1266 Builder.CreateIntrinsic(Intrinsic::get_active_lane_mask, {PredVTy, I64Ty},
1267 {ConstantInt::get(I64Ty, 0), ConstVF});
1268
1269 const uint64_t MinPageSize = TTI->getMinPageSize().value();
1270 const uint64_t AddrShiftAmt = llvm::Log2_64(MinPageSize);
1271 Value *SearchStartPage =
1272 Builder.CreateLShr(ISearchStart, AddrShiftAmt, "search_start_page");
1273 Value *SearchEndPage =
1274 Builder.CreateLShr(ISearchEnd, AddrShiftAmt, "search_end_page");
1275 Value *NeedleStartPage =
1276 Builder.CreateLShr(INeedleStart, AddrShiftAmt, "needle_start_page");
1277 Value *NeedleEndPage =
1278 Builder.CreateLShr(INeedleEnd, AddrShiftAmt, "needle_end_page");
1279 Value *SearchPageCmp =
1280 Builder.CreateICmpNE(SearchStartPage, SearchEndPage, "search_page_cmp");
1281 Value *NeedlePageCmp =
1282 Builder.CreateICmpNE(NeedleStartPage, NeedleEndPage, "needle_page_cmp");
1283
1284 Value *CombinedPageCmp =
1285 Builder.CreateOr(SearchPageCmp, NeedlePageCmp, "combined_page_cmp");
1286 BranchInst *CombinedPageBr = Builder.CreateCondBr(CombinedPageCmp, SPH, BB1);
1287 CombinedPageBr->setMetadata(LLVMContext::MD_prof,
1288 MDBuilder(Ctx).createBranchWeights(10, 90));
1289 DTU.applyUpdates(
1290 {{DominatorTree::Insert, BB0, SPH}, {DominatorTree::Insert, BB0, BB1}});
1291
1292 // (1) Load the search array and branch to the inner loop.
1293 Builder.SetInsertPoint(BB1);
1294 PHINode *Search = Builder.CreatePHI(PtrTy, 2, "psearch");
1295 Value *PredSearch = Builder.CreateIntrinsic(
1296 Intrinsic::get_active_lane_mask, {PredVTy, I64Ty},
1297 {Builder.CreatePtrToInt(Search, I64Ty), ISearchEnd}, nullptr,
1298 "search_pred");
1299 PredSearch = Builder.CreateAnd(PredVF, PredSearch, "search_masked");
1300 Value *LoadSearch = Builder.CreateMaskedLoad(
1301 CharVTy, Search, Align(1), PredSearch, Passthru, "search_load_vec");
1302 Builder.CreateBr(BB2);
1303 DTU.applyUpdates({{DominatorTree::Insert, BB1, BB2}});
1304
1305 // (2) Inner loop.
1306 Builder.SetInsertPoint(BB2);
1307 PHINode *Needle = Builder.CreatePHI(PtrTy, 2, "pneedle");
1308
1309 // (2.a) Load the needle array.
1310 Value *PredNeedle = Builder.CreateIntrinsic(
1311 Intrinsic::get_active_lane_mask, {PredVTy, I64Ty},
1312 {Builder.CreatePtrToInt(Needle, I64Ty), INeedleEnd}, nullptr,
1313 "needle_pred");
1314 PredNeedle = Builder.CreateAnd(PredVF, PredNeedle, "needle_masked");
1315 Value *LoadNeedle = Builder.CreateMaskedLoad(
1316 CharVTy, Needle, Align(1), PredNeedle, Passthru, "needle_load_vec");
1317
1318 // (2.b) Splat the first element to the inactive lanes.
1319 Value *Needle0 =
1320 Builder.CreateExtractElement(LoadNeedle, uint64_t(0), "needle0");
1321 Value *Needle0Splat = Builder.CreateVectorSplat(ElementCount::getScalable(VF),
1322 Needle0, "needle0");
1323 LoadNeedle = Builder.CreateSelect(PredNeedle, LoadNeedle, Needle0Splat,
1324 "needle_splat");
1325 LoadNeedle = Builder.CreateExtractVector(
1326 FixedVectorType::get(CharTy, VF), LoadNeedle, uint64_t(0), "needle_vec");
1327
1328 // (2.c) Test if there's a match.
1329 Value *MatchPred = Builder.CreateIntrinsic(
1330 Intrinsic::experimental_vector_match, {CharVTy, LoadNeedle->getType()},
1331 {LoadSearch, LoadNeedle, PredSearch}, nullptr, "match_pred");
1332 Value *IfAnyMatch = Builder.CreateOrReduce(MatchPred);
1333 Builder.CreateCondBr(IfAnyMatch, BB3, BB4);
1334 DTU.applyUpdates(
1335 {{DominatorTree::Insert, BB2, BB3}, {DominatorTree::Insert, BB2, BB4}});
1336
1337 // (3) We found a match. Compute the index of its location and exit.
1338 Builder.SetInsertPoint(BB3);
1339 PHINode *MatchLCSSA = Builder.CreatePHI(PtrTy, 1, "match_start");
1340 PHINode *MatchPredLCSSA =
1341 Builder.CreatePHI(MatchPred->getType(), 1, "match_vec");
1342 Value *MatchCnt = Builder.CreateIntrinsic(
1343 Intrinsic::experimental_cttz_elts, {I64Ty, MatchPred->getType()},
1344 {MatchPredLCSSA, /*ZeroIsPoison=*/Builder.getInt1(true)}, nullptr,
1345 "match_idx");
1346 Value *MatchVal =
1347 Builder.CreateGEP(CharTy, MatchLCSSA, MatchCnt, "match_res");
1348 Builder.CreateBr(ExitSucc);
1349 DTU.applyUpdates({{DominatorTree::Insert, BB3, ExitSucc}});
1350
1351 // (4) Check if we've reached the end of the needle array.
1352 Builder.SetInsertPoint(BB4);
1353 Value *NextNeedle =
1354 Builder.CreateGEP(CharTy, Needle, ConstVF, "needle_next_vec");
1355 Builder.CreateCondBr(Builder.CreateICmpULT(NextNeedle, NeedleEnd), BB2, BB5);
1356 DTU.applyUpdates(
1357 {{DominatorTree::Insert, BB4, BB2}, {DominatorTree::Insert, BB4, BB5}});
1358
1359 // (5) Check if we've reached the end of the search array.
1360 Builder.SetInsertPoint(BB5);
1361 Value *NextSearch =
1362 Builder.CreateGEP(CharTy, Search, ConstVF, "search_next_vec");
1363 Builder.CreateCondBr(Builder.CreateICmpULT(NextSearch, SearchEnd), BB1,
1364 ExitFail);
1365 DTU.applyUpdates({{DominatorTree::Insert, BB5, BB1},
1366 {DominatorTree::Insert, BB5, ExitFail}});
1367
1368 // Set up the PHI nodes.
1369 Search->addIncoming(SearchStart, BB0);
1370 Search->addIncoming(NextSearch, BB5);
1371 Needle->addIncoming(NeedleStart, BB1);
1372 Needle->addIncoming(NextNeedle, BB4);
1373 // These are needed to retain LCSSA form.
1374 MatchLCSSA->addIncoming(Search, BB2);
1375 MatchPredLCSSA->addIncoming(MatchPred, BB2);
1376
1377 if (VerifyLoops) {
1378 OuterLoop->verifyLoop();
1379 InnerLoop->verifyLoop();
1380 if (!OuterLoop->isRecursivelyLCSSAForm(*DT, *LI))
1381 report_fatal_error("Loops must remain in LCSSA form!");
1382 }
1383
1384 return MatchVal;
1385 }
1386
transformFindFirstByte(PHINode * IndPhi,unsigned VF,Type * CharTy,BasicBlock * ExitSucc,BasicBlock * ExitFail,Value * SearchStart,Value * SearchEnd,Value * NeedleStart,Value * NeedleEnd)1387 void LoopIdiomVectorize::transformFindFirstByte(
1388 PHINode *IndPhi, unsigned VF, Type *CharTy, BasicBlock *ExitSucc,
1389 BasicBlock *ExitFail, Value *SearchStart, Value *SearchEnd,
1390 Value *NeedleStart, Value *NeedleEnd) {
1391 // Insert the find first byte code at the end of the preheader block.
1392 BasicBlock *Preheader = CurLoop->getLoopPreheader();
1393 BranchInst *PHBranch = cast<BranchInst>(Preheader->getTerminator());
1394 IRBuilder<> Builder(PHBranch);
1395 DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
1396 Builder.SetCurrentDebugLocation(PHBranch->getDebugLoc());
1397
1398 Value *MatchVal =
1399 expandFindFirstByte(Builder, DTU, VF, CharTy, ExitSucc, ExitFail,
1400 SearchStart, SearchEnd, NeedleStart, NeedleEnd);
1401
1402 assert(PHBranch->isUnconditional() &&
1403 "Expected preheader to terminate with an unconditional branch.");
1404
1405 // Add new incoming values with the result of the transformation to PHINodes
1406 // of ExitSucc that use IndPhi.
1407 for (auto *U : llvm::make_early_inc_range(IndPhi->users())) {
1408 auto *PN = dyn_cast<PHINode>(U);
1409 if (PN && PN->getParent() == ExitSucc)
1410 PN->addIncoming(MatchVal, cast<Instruction>(MatchVal)->getParent());
1411 }
1412
1413 if (VerifyLoops && CurLoop->getParentLoop()) {
1414 CurLoop->getParentLoop()->verifyLoop();
1415 if (!CurLoop->getParentLoop()->isRecursivelyLCSSAForm(*DT, *LI))
1416 report_fatal_error("Loops must remain in LCSSA form!");
1417 }
1418 }
1419