1 //===------- VectorCombine.cpp - Optimize partial vector operations -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This pass optimizes scalar/vector interactions using target cost models. The
10 // transforms implemented here may not fit in traditional loop-based or SLP
11 // vectorization passes.
12 //
13 //===----------------------------------------------------------------------===//
14
15 #include "llvm/Transforms/Vectorize/VectorCombine.h"
16 #include "llvm/ADT/DenseMap.h"
17 #include "llvm/ADT/STLExtras.h"
18 #include "llvm/ADT/ScopeExit.h"
19 #include "llvm/ADT/Statistic.h"
20 #include "llvm/Analysis/AssumptionCache.h"
21 #include "llvm/Analysis/BasicAliasAnalysis.h"
22 #include "llvm/Analysis/GlobalsModRef.h"
23 #include "llvm/Analysis/InstSimplifyFolder.h"
24 #include "llvm/Analysis/Loads.h"
25 #include "llvm/Analysis/TargetFolder.h"
26 #include "llvm/Analysis/TargetTransformInfo.h"
27 #include "llvm/Analysis/ValueTracking.h"
28 #include "llvm/Analysis/VectorUtils.h"
29 #include "llvm/IR/Dominators.h"
30 #include "llvm/IR/Function.h"
31 #include "llvm/IR/IRBuilder.h"
32 #include "llvm/IR/PatternMatch.h"
33 #include "llvm/Support/CommandLine.h"
34 #include "llvm/Transforms/Utils/Local.h"
35 #include "llvm/Transforms/Utils/LoopUtils.h"
36 #include <numeric>
37 #include <queue>
38 #include <set>
39
40 #define DEBUG_TYPE "vector-combine"
41 #include "llvm/Transforms/Utils/InstructionWorklist.h"
42
43 using namespace llvm;
44 using namespace llvm::PatternMatch;
45
46 STATISTIC(NumVecLoad, "Number of vector loads formed");
47 STATISTIC(NumVecCmp, "Number of vector compares formed");
48 STATISTIC(NumVecBO, "Number of vector binops formed");
49 STATISTIC(NumVecCmpBO, "Number of vector compare + binop formed");
50 STATISTIC(NumShufOfBitcast, "Number of shuffles moved after bitcast");
51 STATISTIC(NumScalarOps, "Number of scalar unary + binary ops formed");
52 STATISTIC(NumScalarCmp, "Number of scalar compares formed");
53 STATISTIC(NumScalarIntrinsic, "Number of scalar intrinsic calls formed");
54
55 static cl::opt<bool> DisableVectorCombine(
56 "disable-vector-combine", cl::init(false), cl::Hidden,
57 cl::desc("Disable all vector combine transforms"));
58
59 static cl::opt<bool> DisableBinopExtractShuffle(
60 "disable-binop-extract-shuffle", cl::init(false), cl::Hidden,
61 cl::desc("Disable binop extract to shuffle transforms"));
62
63 static cl::opt<unsigned> MaxInstrsToScan(
64 "vector-combine-max-scan-instrs", cl::init(30), cl::Hidden,
65 cl::desc("Max number of instructions to scan for vector combining."));
66
67 static const unsigned InvalidIndex = std::numeric_limits<unsigned>::max();
68
69 namespace {
70 class VectorCombine {
71 public:
72   VectorCombine(Function &F, const TargetTransformInfo &TTI,
73 const DominatorTree &DT, AAResults &AA, AssumptionCache &AC,
74 const DataLayout *DL, TTI::TargetCostKind CostKind,
75 bool TryEarlyFoldsOnly)
76 : F(F), Builder(F.getContext(), InstSimplifyFolder(*DL)), TTI(TTI),
77 DT(DT), AA(AA), AC(AC), DL(DL), CostKind(CostKind), SQ(*DL),
78 TryEarlyFoldsOnly(TryEarlyFoldsOnly) {}
79
80 bool run();
81
82 private:
83 Function &F;
84 IRBuilder<InstSimplifyFolder> Builder;
85 const TargetTransformInfo &TTI;
86 const DominatorTree &DT;
87 AAResults &AA;
88 AssumptionCache &AC;
89 const DataLayout *DL;
90 TTI::TargetCostKind CostKind;
91 const SimplifyQuery SQ;
92
93 /// If true, only perform beneficial early IR transforms. Do not introduce new
94 /// vector operations.
95 bool TryEarlyFoldsOnly;
96
97 InstructionWorklist Worklist;
98
99 // TODO: Direct calls from the top-level "run" loop use a plain "Instruction"
100 // parameter. That should be updated to specific sub-classes because the
101 // run loop was changed to dispatch on opcode.
102 bool vectorizeLoadInsert(Instruction &I);
103 bool widenSubvectorLoad(Instruction &I);
104 ExtractElementInst *getShuffleExtract(ExtractElementInst *Ext0,
105 ExtractElementInst *Ext1,
106 unsigned PreferredExtractIndex) const;
107 bool isExtractExtractCheap(ExtractElementInst *Ext0, ExtractElementInst *Ext1,
108 const Instruction &I,
109 ExtractElementInst *&ConvertToShuffle,
110 unsigned PreferredExtractIndex);
111 void foldExtExtCmp(ExtractElementInst *Ext0, ExtractElementInst *Ext1,
112 Instruction &I);
113 void foldExtExtBinop(ExtractElementInst *Ext0, ExtractElementInst *Ext1,
114 Instruction &I);
115 bool foldExtractExtract(Instruction &I);
116 bool foldInsExtFNeg(Instruction &I);
117 bool foldInsExtBinop(Instruction &I);
118 bool foldInsExtVectorToShuffle(Instruction &I);
119 bool foldBitOpOfBitcasts(Instruction &I);
120 bool foldBitcastShuffle(Instruction &I);
121 bool scalarizeOpOrCmp(Instruction &I);
122 bool scalarizeVPIntrinsic(Instruction &I);
123 bool foldExtractedCmps(Instruction &I);
124 bool foldBinopOfReductions(Instruction &I);
125 bool foldSingleElementStore(Instruction &I);
126 bool scalarizeLoadExtract(Instruction &I);
127 bool scalarizeExtExtract(Instruction &I);
128 bool foldConcatOfBoolMasks(Instruction &I);
129 bool foldPermuteOfBinops(Instruction &I);
130 bool foldShuffleOfBinops(Instruction &I);
131 bool foldShuffleOfSelects(Instruction &I);
132 bool foldShuffleOfCastops(Instruction &I);
133 bool foldShuffleOfShuffles(Instruction &I);
134 bool foldShuffleOfIntrinsics(Instruction &I);
135 bool foldShuffleToIdentity(Instruction &I);
136 bool foldShuffleFromReductions(Instruction &I);
137 bool foldCastFromReductions(Instruction &I);
138 bool foldSelectShuffle(Instruction &I, bool FromReduction = false);
139 bool foldInterleaveIntrinsics(Instruction &I);
140 bool shrinkType(Instruction &I);
141
142   void replaceValue(Value &Old, Value &New) {
143 LLVM_DEBUG(dbgs() << "VC: Replacing: " << Old << '\n');
144 LLVM_DEBUG(dbgs() << " With: " << New << '\n');
145 Old.replaceAllUsesWith(&New);
146 if (auto *NewI = dyn_cast<Instruction>(&New)) {
147 New.takeName(&Old);
148 Worklist.pushUsersToWorkList(*NewI);
149 Worklist.pushValue(NewI);
150 }
151 Worklist.pushValue(&Old);
152 }
153
154   void eraseInstruction(Instruction &I) {
155 LLVM_DEBUG(dbgs() << "VC: Erasing: " << I << '\n');
156 SmallVector<Value *> Ops(I.operands());
157 Worklist.remove(&I);
158 I.eraseFromParent();
159
160 // Push remaining users of the operands and then the operands themselves - allows
161 // further folds that were hindered by OneUse limits.
162 for (Value *Op : Ops)
163 if (auto *OpI = dyn_cast<Instruction>(Op)) {
164 Worklist.pushUsersToWorkList(*OpI);
165 Worklist.pushValue(OpI);
166 }
167 }
168 };
169 } // namespace
170
171 /// Return the source operand of a potentially bitcasted value. If there is no
172 /// bitcast, return the input value itself.
173 static Value *peekThroughBitcasts(Value *V) {
174 while (auto *BitCast = dyn_cast<BitCastInst>(V))
175 V = BitCast->getOperand(0);
176 return V;
177 }
178
179 static bool canWidenLoad(LoadInst *Load, const TargetTransformInfo &TTI) {
180 // Do not widen load if atomic/volatile or under asan/hwasan/memtag/tsan.
181 // The widened load may load data from dirty regions or create data races
182 // non-existent in the source.
183 if (!Load || !Load->isSimple() || !Load->hasOneUse() ||
184 Load->getFunction()->hasFnAttribute(Attribute::SanitizeMemTag) ||
185 mustSuppressSpeculation(*Load))
186 return false;
187
188 // We are potentially transforming byte-sized (8-bit) memory accesses, so make
189 // sure we have all of our type-based constraints in place for this target.
190 Type *ScalarTy = Load->getType()->getScalarType();
191 uint64_t ScalarSize = ScalarTy->getPrimitiveSizeInBits();
192 unsigned MinVectorSize = TTI.getMinVectorRegisterBitWidth();
193 if (!ScalarSize || !MinVectorSize || MinVectorSize % ScalarSize != 0 ||
194 ScalarSize % 8 != 0)
195 return false;
196
197 return true;
198 }
199
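/// Try to replace "insertelement poison, (load scalar), 0" with a load of a
/// wider legal vector from the same address. A minimal illustrative example
/// (types and widths depend on the target's cost model):
///   %s = load float, ptr %p
///   %r = insertelement <4 x float> poison, float %s, i64 0
/// -->
///   %r = load <4 x float>, ptr %p
/// This is only done when the extra loaded bytes are known dereferenceable and
/// the cost model considers the wider load at least as cheap.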
200 bool VectorCombine::vectorizeLoadInsert(Instruction &I) {
201 // Match insert into fixed vector of scalar value.
202 // TODO: Handle non-zero insert index.
203 Value *Scalar;
204 if (!match(&I,
205 m_InsertElt(m_Poison(), m_OneUse(m_Value(Scalar)), m_ZeroInt())))
206 return false;
207
208 // Optionally match an extract from another vector.
209 Value *X;
210 bool HasExtract = match(Scalar, m_ExtractElt(m_Value(X), m_ZeroInt()));
211 if (!HasExtract)
212 X = Scalar;
213
214 auto *Load = dyn_cast<LoadInst>(X);
215 if (!canWidenLoad(Load, TTI))
216 return false;
217
218 Type *ScalarTy = Scalar->getType();
219 uint64_t ScalarSize = ScalarTy->getPrimitiveSizeInBits();
220 unsigned MinVectorSize = TTI.getMinVectorRegisterBitWidth();
221
222 // Check safety of replacing the scalar load with a larger vector load.
223 // We use minimal alignment (maximum flexibility) because we only care about
224 // the dereferenceable region. When calculating cost and creating a new op,
225 // we may use a larger value based on alignment attributes.
226 Value *SrcPtr = Load->getPointerOperand()->stripPointerCasts();
227 assert(isa<PointerType>(SrcPtr->getType()) && "Expected a pointer type");
228
229 unsigned MinVecNumElts = MinVectorSize / ScalarSize;
230 auto *MinVecTy = VectorType::get(ScalarTy, MinVecNumElts, false);
231 unsigned OffsetEltIndex = 0;
232 Align Alignment = Load->getAlign();
233 if (!isSafeToLoadUnconditionally(SrcPtr, MinVecTy, Align(1), *DL, Load, &AC,
234 &DT)) {
235 // It is not safe to load directly from the pointer, but we can still peek
236 // through gep offsets and check if it is safe to load from a base address with
237 // updated alignment. If it is, we can shuffle the element(s) into place
238 // after loading.
239 unsigned OffsetBitWidth = DL->getIndexTypeSizeInBits(SrcPtr->getType());
240 APInt Offset(OffsetBitWidth, 0);
241 SrcPtr = SrcPtr->stripAndAccumulateInBoundsConstantOffsets(*DL, Offset);
242
243 // We want to shuffle the result down from a high element of a vector, so
244 // the offset must be positive.
245 if (Offset.isNegative())
246 return false;
247
248 // The offset must be a multiple of the scalar element to shuffle cleanly
249 // in the element's size.
250 uint64_t ScalarSizeInBytes = ScalarSize / 8;
251 if (Offset.urem(ScalarSizeInBytes) != 0)
252 return false;
253
254 // If we load MinVecNumElts, will our target element still be loaded?
255 OffsetEltIndex = Offset.udiv(ScalarSizeInBytes).getZExtValue();
256 if (OffsetEltIndex >= MinVecNumElts)
257 return false;
258
259 if (!isSafeToLoadUnconditionally(SrcPtr, MinVecTy, Align(1), *DL, Load, &AC,
260 &DT))
261 return false;
262
263 // Update alignment with offset value. Note that the offset could be negated
264 // to more accurately represent "(new) SrcPtr - Offset = (old) SrcPtr", but
265 // negation does not change the result of the alignment calculation.
266 Alignment = commonAlignment(Alignment, Offset.getZExtValue());
267 }
268
269 // Original pattern: insertelt undef, load [free casts of] PtrOp, 0
270 // Use the greater of the alignment on the load or its source pointer.
271 Alignment = std::max(SrcPtr->getPointerAlignment(*DL), Alignment);
272 Type *LoadTy = Load->getType();
273 unsigned AS = Load->getPointerAddressSpace();
274 InstructionCost OldCost =
275 TTI.getMemoryOpCost(Instruction::Load, LoadTy, Alignment, AS, CostKind);
276 APInt DemandedElts = APInt::getOneBitSet(MinVecNumElts, 0);
277 OldCost +=
278 TTI.getScalarizationOverhead(MinVecTy, DemandedElts,
279 /* Insert */ true, HasExtract, CostKind);
280
281 // New pattern: load VecPtr
282 InstructionCost NewCost =
283 TTI.getMemoryOpCost(Instruction::Load, MinVecTy, Alignment, AS, CostKind);
284 // Optionally, we are shuffling the loaded vector element(s) into place.
285 // For the mask set everything but element 0 to undef to prevent poison from
286 // propagating from the extra loaded memory. This will also optionally
287 // shrink/grow the vector from the loaded size to the output size.
288 // We assume this operation has no cost in codegen if there was no offset.
289 // Note that we could use freeze to avoid poison problems, but then we might
290 // still need a shuffle to change the vector size.
291 auto *Ty = cast<FixedVectorType>(I.getType());
292 unsigned OutputNumElts = Ty->getNumElements();
293 SmallVector<int, 16> Mask(OutputNumElts, PoisonMaskElem);
294 assert(OffsetEltIndex < MinVecNumElts && "Address offset too big");
295 Mask[0] = OffsetEltIndex;
296 if (OffsetEltIndex)
297 NewCost += TTI.getShuffleCost(TTI::SK_PermuteSingleSrc, Ty, MinVecTy, Mask,
298 CostKind);
299
300 // We can aggressively convert to the vector form because the backend can
301 // invert this transform if it does not result in a performance win.
302 if (OldCost < NewCost || !NewCost.isValid())
303 return false;
304
305 // It is safe and potentially profitable to load a vector directly:
306 // inselt undef, load Scalar, 0 --> load VecPtr
307 IRBuilder<> Builder(Load);
308 Value *CastedPtr =
309 Builder.CreatePointerBitCastOrAddrSpaceCast(SrcPtr, Builder.getPtrTy(AS));
310 Value *VecLd = Builder.CreateAlignedLoad(MinVecTy, CastedPtr, Alignment);
311 VecLd = Builder.CreateShuffleVector(VecLd, Mask);
312
313 replaceValue(I, *VecLd);
314 ++NumVecLoad;
315 return true;
316 }
317
318 /// If we are loading a vector and then inserting it into a larger vector with
319 /// undefined elements, try to load the larger vector and eliminate the insert.
320 /// This removes a shuffle in IR and may allow combining of other loaded values.
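/// Illustrative example (assuming the wider load is known dereferenceable):
///   %v = load <2 x float>, ptr %p
///   %r = shufflevector <2 x float> %v, <2 x float> poison,
///                      <4 x i32> <i32 0, i32 1, i32 poison, i32 poison>
/// -->
///   %r = load <4 x float>, ptr %p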
321 bool VectorCombine::widenSubvectorLoad(Instruction &I) {
322 // Match subvector insert of fixed vector.
323 auto *Shuf = cast<ShuffleVectorInst>(&I);
324 if (!Shuf->isIdentityWithPadding())
325 return false;
326
327 // Allow a non-canonical shuffle mask that is choosing elements from op1.
328 unsigned NumOpElts =
329 cast<FixedVectorType>(Shuf->getOperand(0)->getType())->getNumElements();
330 unsigned OpIndex = any_of(Shuf->getShuffleMask(), [&NumOpElts](int M) {
331 return M >= (int)(NumOpElts);
332 });
333
334 auto *Load = dyn_cast<LoadInst>(Shuf->getOperand(OpIndex));
335 if (!canWidenLoad(Load, TTI))
336 return false;
337
338 // We use minimal alignment (maximum flexibility) because we only care about
339 // the dereferenceable region. When calculating cost and creating a new op,
340 // we may use a larger value based on alignment attributes.
341 auto *Ty = cast<FixedVectorType>(I.getType());
342 Value *SrcPtr = Load->getPointerOperand()->stripPointerCasts();
343 assert(isa<PointerType>(SrcPtr->getType()) && "Expected a pointer type");
344 Align Alignment = Load->getAlign();
345 if (!isSafeToLoadUnconditionally(SrcPtr, Ty, Align(1), *DL, Load, &AC, &DT))
346 return false;
347
348 Alignment = std::max(SrcPtr->getPointerAlignment(*DL), Alignment);
349 Type *LoadTy = Load->getType();
350 unsigned AS = Load->getPointerAddressSpace();
351
352 // Original pattern: insert_subvector (load PtrOp)
353 // This conservatively assumes that the cost of a subvector insert into an
354 // undef value is 0. We could add that cost if the cost model accurately
355 // reflects the real cost of that operation.
356 InstructionCost OldCost =
357 TTI.getMemoryOpCost(Instruction::Load, LoadTy, Alignment, AS, CostKind);
358
359 // New pattern: load PtrOp
360 InstructionCost NewCost =
361 TTI.getMemoryOpCost(Instruction::Load, Ty, Alignment, AS, CostKind);
362
363 // We can aggressively convert to the vector form because the backend can
364 // invert this transform if it does not result in a performance win.
365 if (OldCost < NewCost || !NewCost.isValid())
366 return false;
367
368 IRBuilder<> Builder(Load);
369 Value *CastedPtr =
370 Builder.CreatePointerBitCastOrAddrSpaceCast(SrcPtr, Builder.getPtrTy(AS));
371 Value *VecLd = Builder.CreateAlignedLoad(Ty, CastedPtr, Alignment);
372 replaceValue(I, *VecLd);
373 ++NumVecLoad;
374 return true;
375 }
376
377 /// Determine which, if any, of the inputs should be replaced by a shuffle
378 /// followed by extract from a different index.
379 ExtractElementInst *VectorCombine::getShuffleExtract(
380 ExtractElementInst *Ext0, ExtractElementInst *Ext1,
381 unsigned PreferredExtractIndex = InvalidIndex) const {
382 auto *Index0C = dyn_cast<ConstantInt>(Ext0->getIndexOperand());
383 auto *Index1C = dyn_cast<ConstantInt>(Ext1->getIndexOperand());
384 assert(Index0C && Index1C && "Expected constant extract indexes");
385
386 unsigned Index0 = Index0C->getZExtValue();
387 unsigned Index1 = Index1C->getZExtValue();
388
389 // If the extract indexes are identical, no shuffle is needed.
390 if (Index0 == Index1)
391 return nullptr;
392
393 Type *VecTy = Ext0->getVectorOperand()->getType();
394 assert(VecTy == Ext1->getVectorOperand()->getType() && "Need matching types");
395 InstructionCost Cost0 =
396 TTI.getVectorInstrCost(*Ext0, VecTy, CostKind, Index0);
397 InstructionCost Cost1 =
398 TTI.getVectorInstrCost(*Ext1, VecTy, CostKind, Index1);
399
400 // If both costs are invalid, no shuffle is needed.
401 if (!Cost0.isValid() && !Cost1.isValid())
402 return nullptr;
403
404 // We are extracting from 2 different indexes, so one operand must be shuffled
405 // before performing a vector operation and/or extract. The more expensive
406 // extract will be replaced by a shuffle.
407 if (Cost0 > Cost1)
408 return Ext0;
409 if (Cost1 > Cost0)
410 return Ext1;
411
412 // If the costs are equal and there is a preferred extract index, shuffle the
413 // opposite operand.
414 if (PreferredExtractIndex == Index0)
415 return Ext1;
416 if (PreferredExtractIndex == Index1)
417 return Ext0;
418
419 // Otherwise, replace the extract with the higher index.
420 return Index0 > Index1 ? Ext0 : Ext1;
421 }
422
423 /// Compare the relative costs of 2 extracts followed by scalar operation vs.
424 /// vector operation(s) followed by extract. Return true if the existing
425 /// instructions are cheaper than a vector alternative. Otherwise, return false
426 /// and if one of the extracts should be transformed to a shufflevector, set
427 /// \p ConvertToShuffle to that extract instruction.
428 bool VectorCombine::isExtractExtractCheap(ExtractElementInst *Ext0,
429 ExtractElementInst *Ext1,
430 const Instruction &I,
431 ExtractElementInst *&ConvertToShuffle,
432 unsigned PreferredExtractIndex) {
433 auto *Ext0IndexC = dyn_cast<ConstantInt>(Ext0->getIndexOperand());
434 auto *Ext1IndexC = dyn_cast<ConstantInt>(Ext1->getIndexOperand());
435 assert(Ext0IndexC && Ext1IndexC && "Expected constant extract indexes");
436
437 unsigned Opcode = I.getOpcode();
438 Value *Ext0Src = Ext0->getVectorOperand();
439 Value *Ext1Src = Ext1->getVectorOperand();
440 Type *ScalarTy = Ext0->getType();
441 auto *VecTy = cast<VectorType>(Ext0Src->getType());
442 InstructionCost ScalarOpCost, VectorOpCost;
443
444 // Get cost estimates for scalar and vector versions of the operation.
445 bool IsBinOp = Instruction::isBinaryOp(Opcode);
446 if (IsBinOp) {
447 ScalarOpCost = TTI.getArithmeticInstrCost(Opcode, ScalarTy, CostKind);
448 VectorOpCost = TTI.getArithmeticInstrCost(Opcode, VecTy, CostKind);
449 } else {
450 assert((Opcode == Instruction::ICmp || Opcode == Instruction::FCmp) &&
451 "Expected a compare");
452 CmpInst::Predicate Pred = cast<CmpInst>(I).getPredicate();
453 ScalarOpCost = TTI.getCmpSelInstrCost(
454 Opcode, ScalarTy, CmpInst::makeCmpResultType(ScalarTy), Pred, CostKind);
455 VectorOpCost = TTI.getCmpSelInstrCost(
456 Opcode, VecTy, CmpInst::makeCmpResultType(VecTy), Pred, CostKind);
457 }
458
459 // Get cost estimates for the extract elements. These costs will factor into
460 // both sequences.
461 unsigned Ext0Index = Ext0IndexC->getZExtValue();
462 unsigned Ext1Index = Ext1IndexC->getZExtValue();
463
464 InstructionCost Extract0Cost =
465 TTI.getVectorInstrCost(*Ext0, VecTy, CostKind, Ext0Index);
466 InstructionCost Extract1Cost =
467 TTI.getVectorInstrCost(*Ext1, VecTy, CostKind, Ext1Index);
468
469 // A more expensive extract will always be replaced by a splat shuffle.
470 // For example, if Ext0 is more expensive:
471 // opcode (extelt V0, Ext0), (ext V1, Ext1) -->
472 // extelt (opcode (splat V0, Ext0), V1), Ext1
473 // TODO: Evaluate whether that always results in lowest cost. Alternatively,
474 // check the cost of creating a broadcast shuffle and shuffling both
475 // operands to element 0.
476 unsigned BestExtIndex = Extract0Cost > Extract1Cost ? Ext0Index : Ext1Index;
477 unsigned BestInsIndex = Extract0Cost > Extract1Cost ? Ext1Index : Ext0Index;
478 InstructionCost CheapExtractCost = std::min(Extract0Cost, Extract1Cost);
479
480 // Extra uses of the extracts mean that we include those costs in the
481 // vector total because those instructions will not be eliminated.
482 InstructionCost OldCost, NewCost;
483 if (Ext0Src == Ext1Src && Ext0Index == Ext1Index) {
484 // Handle a special case. If the 2 extracts are identical, adjust the
485 // formulas to account for that. The extra use charge allows for either the
486 // CSE'd pattern or an unoptimized form with identical values:
487 // opcode (extelt V, C), (extelt V, C) --> extelt (opcode V, V), C
488 bool HasUseTax = Ext0 == Ext1 ? !Ext0->hasNUses(2)
489 : !Ext0->hasOneUse() || !Ext1->hasOneUse();
490 OldCost = CheapExtractCost + ScalarOpCost;
491 NewCost = VectorOpCost + CheapExtractCost + HasUseTax * CheapExtractCost;
492 } else {
493 // Handle the general case. Each extract is actually a different value:
494 // opcode (extelt V0, C0), (extelt V1, C1) --> extelt (opcode V0, V1), C
495 OldCost = Extract0Cost + Extract1Cost + ScalarOpCost;
496 NewCost = VectorOpCost + CheapExtractCost +
497 !Ext0->hasOneUse() * Extract0Cost +
498 !Ext1->hasOneUse() * Extract1Cost;
499 }
500
501 ConvertToShuffle = getShuffleExtract(Ext0, Ext1, PreferredExtractIndex);
502 if (ConvertToShuffle) {
503 if (IsBinOp && DisableBinopExtractShuffle)
504 return true;
505
506 // If we are extracting from 2 different indexes, then one operand must be
507 // shuffled before performing the vector operation. The shuffle mask is
508 // poison except for 1 lane that is being translated to the remaining
509 // extraction lane. Therefore, it is a splat shuffle. Ex:
510 // ShufMask = { poison, poison, 0, poison }
511 // TODO: The cost model has an option for a "broadcast" shuffle
512 // (splat-from-element-0), but no option for a more general splat.
513 if (auto *FixedVecTy = dyn_cast<FixedVectorType>(VecTy)) {
514 SmallVector<int> ShuffleMask(FixedVecTy->getNumElements(),
515 PoisonMaskElem);
516 ShuffleMask[BestInsIndex] = BestExtIndex;
517 NewCost += TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc,
518 VecTy, VecTy, ShuffleMask, CostKind, 0,
519 nullptr, {ConvertToShuffle});
520 } else {
521 NewCost += TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc,
522 VecTy, VecTy, {}, CostKind, 0, nullptr,
523 {ConvertToShuffle});
524 }
525 }
526
527 // Aggressively form a vector op if the cost is equal because the transform
528 // may enable further optimization.
529 // Codegen can reverse this transform (scalarize) if it was not profitable.
530 return OldCost < NewCost;
531 }
532
533 /// Create a shuffle that translates (shifts) 1 element from the input vector
534 /// to a new element location.
535 static Value *createShiftShuffle(Value *Vec, unsigned OldIndex,
536 unsigned NewIndex, IRBuilderBase &Builder) {
537 // The shuffle mask is poison except for 1 lane that is being translated
538 // to the new element index. Example for OldIndex == 2 and NewIndex == 0:
539 // ShufMask = { 2, poison, poison, poison }
540 auto *VecTy = cast<FixedVectorType>(Vec->getType());
541 SmallVector<int, 32> ShufMask(VecTy->getNumElements(), PoisonMaskElem);
542 ShufMask[NewIndex] = OldIndex;
543 return Builder.CreateShuffleVector(Vec, ShufMask, "shift");
544 }
545
546 /// Given an extract element instruction with constant index operand, shuffle
547 /// the source vector (shift the scalar element) to a NewIndex for extraction.
548 /// Return null if the input can be constant folded, so that we are not creating
549 /// unnecessary instructions.
550 static ExtractElementInst *translateExtract(ExtractElementInst *ExtElt,
551 unsigned NewIndex,
552 IRBuilderBase &Builder) {
553 // Shufflevectors can only be created for fixed-width vectors.
554 Value *X = ExtElt->getVectorOperand();
555 if (!isa<FixedVectorType>(X->getType()))
556 return nullptr;
557
558 // If the extract can be constant-folded, this code is unsimplified. Defer
559 // to other passes to handle that.
560 Value *C = ExtElt->getIndexOperand();
561 assert(isa<ConstantInt>(C) && "Expected a constant index operand");
562 if (isa<Constant>(X))
563 return nullptr;
564
565 Value *Shuf = createShiftShuffle(X, cast<ConstantInt>(C)->getZExtValue(),
566 NewIndex, Builder);
567 return dyn_cast<ExtractElementInst>(
568 Builder.CreateExtractElement(Shuf, NewIndex));
569 }
570
571 /// Try to reduce extract element costs by converting scalar compares to vector
572 /// compares followed by extract.
573 /// cmp (ext0 V0, C), (ext1 V1, C)
574 void VectorCombine::foldExtExtCmp(ExtractElementInst *Ext0,
575 ExtractElementInst *Ext1, Instruction &I) {
576 assert(isa<CmpInst>(&I) && "Expected a compare");
577 assert(cast<ConstantInt>(Ext0->getIndexOperand())->getZExtValue() ==
578 cast<ConstantInt>(Ext1->getIndexOperand())->getZExtValue() &&
579 "Expected matching constant extract indexes");
580
581 // cmp Pred (extelt V0, C), (extelt V1, C) --> extelt (cmp Pred V0, V1), C
582 ++NumVecCmp;
583 CmpInst::Predicate Pred = cast<CmpInst>(&I)->getPredicate();
584 Value *V0 = Ext0->getVectorOperand(), *V1 = Ext1->getVectorOperand();
585 Value *VecCmp = Builder.CreateCmp(Pred, V0, V1);
586 Value *NewExt = Builder.CreateExtractElement(VecCmp, Ext0->getIndexOperand());
587 replaceValue(I, *NewExt);
588 }
589
590 /// Try to reduce extract element costs by converting scalar binops to vector
591 /// binops followed by extract.
592 /// bo (ext0 V0, C), (ext1 V1, C)
593 void VectorCombine::foldExtExtBinop(ExtractElementInst *Ext0,
594 ExtractElementInst *Ext1, Instruction &I) {
595 assert(isa<BinaryOperator>(&I) && "Expected a binary operator");
596 assert(cast<ConstantInt>(Ext0->getIndexOperand())->getZExtValue() ==
597 cast<ConstantInt>(Ext1->getIndexOperand())->getZExtValue() &&
598 "Expected matching constant extract indexes");
599
600 // bo (extelt V0, C), (extelt V1, C) --> extelt (bo V0, V1), C
601 ++NumVecBO;
602 Value *V0 = Ext0->getVectorOperand(), *V1 = Ext1->getVectorOperand();
603 Value *VecBO =
604 Builder.CreateBinOp(cast<BinaryOperator>(&I)->getOpcode(), V0, V1);
605
606 // All IR flags are safe to back-propagate because any potential poison
607 // created in unused vector elements is discarded by the extract.
608 if (auto *VecBOInst = dyn_cast<Instruction>(VecBO))
609 VecBOInst->copyIRFlags(&I);
610
611 Value *NewExt = Builder.CreateExtractElement(VecBO, Ext0->getIndexOperand());
612 replaceValue(I, *NewExt);
613 }
614
615 /// Match an instruction with extracted vector operands.
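/// Illustrative example for the same-index case:
///   %e0 = extractelement <4 x i32> %x, i64 1
///   %e1 = extractelement <4 x i32> %y, i64 1
///   %r  = add i32 %e0, %e1
/// -->
///   %v = add <4 x i32> %x, %y
///   %r = extractelement <4 x i32> %v, i64 1
/// Mismatched indexes are handled by first shuffling one source so both
/// extracts use the same lane.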
616 bool VectorCombine::foldExtractExtract(Instruction &I) {
617 // It is not safe to transform things like div, urem, etc. because we may
618 // create undefined behavior when executing those on unknown vector elements.
619 if (!isSafeToSpeculativelyExecute(&I))
620 return false;
621
622 Instruction *I0, *I1;
623 CmpPredicate Pred = CmpInst::BAD_ICMP_PREDICATE;
624 if (!match(&I, m_Cmp(Pred, m_Instruction(I0), m_Instruction(I1))) &&
625 !match(&I, m_BinOp(m_Instruction(I0), m_Instruction(I1))))
626 return false;
627
628 Value *V0, *V1;
629 uint64_t C0, C1;
630 if (!match(I0, m_ExtractElt(m_Value(V0), m_ConstantInt(C0))) ||
631 !match(I1, m_ExtractElt(m_Value(V1), m_ConstantInt(C1))) ||
632 V0->getType() != V1->getType())
633 return false;
634
635 // If the scalar value 'I' is going to be re-inserted into a vector, then try
636 // to create an extract to that same element. The extract/insert can be
637 // reduced to a "select shuffle".
638 // TODO: If we add a larger pattern match that starts from an insert, this
639 // probably becomes unnecessary.
640 auto *Ext0 = cast<ExtractElementInst>(I0);
641 auto *Ext1 = cast<ExtractElementInst>(I1);
642 uint64_t InsertIndex = InvalidIndex;
643 if (I.hasOneUse())
644 match(I.user_back(),
645 m_InsertElt(m_Value(), m_Value(), m_ConstantInt(InsertIndex)));
646
647 ExtractElementInst *ExtractToChange;
648 if (isExtractExtractCheap(Ext0, Ext1, I, ExtractToChange, InsertIndex))
649 return false;
650
651 if (ExtractToChange) {
652 unsigned CheapExtractIdx = ExtractToChange == Ext0 ? C1 : C0;
653 ExtractElementInst *NewExtract =
654 translateExtract(ExtractToChange, CheapExtractIdx, Builder);
655 if (!NewExtract)
656 return false;
657 if (ExtractToChange == Ext0)
658 Ext0 = NewExtract;
659 else
660 Ext1 = NewExtract;
661 }
662
663 if (Pred != CmpInst::BAD_ICMP_PREDICATE)
664 foldExtExtCmp(Ext0, Ext1, I);
665 else
666 foldExtExtBinop(Ext0, Ext1, I);
667
668 Worklist.push(Ext0);
669 Worklist.push(Ext1);
670 return true;
671 }
672
673 /// Try to replace an extract + scalar fneg + insert with a vector fneg +
674 /// shuffle.
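/// Illustrative example (same source and destination vector type, Index = 1):
///   %e = extractelement <4 x float> %src, i64 1
///   %n = fneg float %e
///   %r = insertelement <4 x float> %dst, float %n, i64 1
/// -->
///   %vn = fneg <4 x float> %src
///   %r  = shufflevector <4 x float> %dst, <4 x float> %vn,
///                       <4 x i32> <i32 0, i32 5, i32 2, i32 3>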
675 bool VectorCombine::foldInsExtFNeg(Instruction &I) {
676 // Match an insert (op (extract)) pattern.
677 Value *DestVec;
678 uint64_t Index;
679 Instruction *FNeg;
680 if (!match(&I, m_InsertElt(m_Value(DestVec), m_OneUse(m_Instruction(FNeg)),
681 m_ConstantInt(Index))))
682 return false;
683
684 // Note: This handles the canonical fneg instruction and "fsub -0.0, X".
685 Value *SrcVec;
686 Instruction *Extract;
687 if (!match(FNeg, m_FNeg(m_CombineAnd(
688 m_Instruction(Extract),
689 m_ExtractElt(m_Value(SrcVec), m_SpecificInt(Index))))))
690 return false;
691
692 auto *VecTy = cast<FixedVectorType>(I.getType());
693 auto *ScalarTy = VecTy->getScalarType();
694 auto *SrcVecTy = dyn_cast<FixedVectorType>(SrcVec->getType());
695 if (!SrcVecTy || ScalarTy != SrcVecTy->getScalarType())
696 return false;
697
698 // Ignore bogus insert/extract index.
699 unsigned NumElts = VecTy->getNumElements();
700 if (Index >= NumElts)
701 return false;
702
703 // We are inserting the negated element into the same lane that we extracted
704 // from. This is equivalent to a select-shuffle that chooses all but the
705 // negated element from the destination vector.
706 SmallVector<int> Mask(NumElts);
707 std::iota(Mask.begin(), Mask.end(), 0);
708 Mask[Index] = Index + NumElts;
709 InstructionCost OldCost =
710 TTI.getArithmeticInstrCost(Instruction::FNeg, ScalarTy, CostKind) +
711 TTI.getVectorInstrCost(I, VecTy, CostKind, Index);
712
713 // If the extract has one use, it will be eliminated, so count it in the
714 // original cost. If it has more than one use, ignore the cost because it will
715 // be the same before/after.
716 if (Extract->hasOneUse())
717 OldCost += TTI.getVectorInstrCost(*Extract, VecTy, CostKind, Index);
718
719 InstructionCost NewCost =
720 TTI.getArithmeticInstrCost(Instruction::FNeg, VecTy, CostKind) +
721 TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, VecTy, VecTy,
722 Mask, CostKind);
723
724 bool NeedLenChg = SrcVecTy->getNumElements() != NumElts;
725 // If the lengths of the two vectors are not equal,
726 // we need to add a length-change vector. Add this cost.
727 SmallVector<int> SrcMask;
728 if (NeedLenChg) {
729 SrcMask.assign(NumElts, PoisonMaskElem);
730 SrcMask[Index] = Index;
731 NewCost += TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc,
732 VecTy, SrcVecTy, SrcMask, CostKind);
733 }
734
735 if (NewCost > OldCost)
736 return false;
737
738 Value *NewShuf;
739 // insertelt DestVec, (fneg (extractelt SrcVec, Index)), Index
740 Value *VecFNeg = Builder.CreateFNegFMF(SrcVec, FNeg);
741 if (NeedLenChg) {
742 // shuffle DestVec, (shuffle (fneg SrcVec), poison, SrcMask), Mask
743 Value *LenChgShuf = Builder.CreateShuffleVector(VecFNeg, SrcMask);
744 NewShuf = Builder.CreateShuffleVector(DestVec, LenChgShuf, Mask);
745 } else {
746 // shuffle DestVec, (fneg SrcVec), Mask
747 NewShuf = Builder.CreateShuffleVector(DestVec, VecFNeg, Mask);
748 }
749
750 replaceValue(I, *NewShuf);
751 return true;
752 }
753
754 /// Try to fold insert(binop(x,y),binop(a,b),idx)
755 /// --> binop(insert(x,a,idx),insert(y,b,idx))
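/// Illustrative example (idx = 0):
///   %v = add <4 x i32> %x, %y
///   %s = add i32 %a, %b
///   %r = insertelement <4 x i32> %v, i32 %s, i64 0
/// -->
///   %x1 = insertelement <4 x i32> %x, i32 %a, i64 0
///   %y1 = insertelement <4 x i32> %y, i32 %b, i64 0
///   %r  = add <4 x i32> %x1, %y1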
756 bool VectorCombine::foldInsExtBinop(Instruction &I) {
757 BinaryOperator *VecBinOp, *SclBinOp;
758 uint64_t Index;
759 if (!match(&I,
760 m_InsertElt(m_OneUse(m_BinOp(VecBinOp)),
761 m_OneUse(m_BinOp(SclBinOp)), m_ConstantInt(Index))))
762 return false;
763
764 // TODO: Add support for addlike etc.
765 Instruction::BinaryOps BinOpcode = VecBinOp->getOpcode();
766 if (BinOpcode != SclBinOp->getOpcode())
767 return false;
768
769 auto *ResultTy = dyn_cast<FixedVectorType>(I.getType());
770 if (!ResultTy)
771 return false;
772
773 // TODO: Attempt to detect m_ExtractElt for scalar operands and convert to
774 // shuffle?
775
776 InstructionCost OldCost = TTI.getInstructionCost(&I, CostKind) +
777 TTI.getInstructionCost(VecBinOp, CostKind) +
778 TTI.getInstructionCost(SclBinOp, CostKind);
779 InstructionCost NewCost =
780 TTI.getArithmeticInstrCost(BinOpcode, ResultTy, CostKind) +
781 TTI.getVectorInstrCost(Instruction::InsertElement, ResultTy, CostKind,
782 Index, VecBinOp->getOperand(0),
783 SclBinOp->getOperand(0)) +
784 TTI.getVectorInstrCost(Instruction::InsertElement, ResultTy, CostKind,
785 Index, VecBinOp->getOperand(1),
786 SclBinOp->getOperand(1));
787
788 LLVM_DEBUG(dbgs() << "Found an insertion of two binops: " << I
789 << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
790 << "\n");
791 if (NewCost > OldCost)
792 return false;
793
794 Value *NewIns0 = Builder.CreateInsertElement(VecBinOp->getOperand(0),
795 SclBinOp->getOperand(0), Index);
796 Value *NewIns1 = Builder.CreateInsertElement(VecBinOp->getOperand(1),
797 SclBinOp->getOperand(1), Index);
798 Value *NewBO = Builder.CreateBinOp(BinOpcode, NewIns0, NewIns1);
799
800 // Intersect flags from the old binops.
801 if (auto *NewInst = dyn_cast<Instruction>(NewBO)) {
802 NewInst->copyIRFlags(VecBinOp);
803 NewInst->andIRFlags(SclBinOp);
804 }
805
806 Worklist.pushValue(NewIns0);
807 Worklist.pushValue(NewIns1);
808 replaceValue(I, *NewBO);
809 return true;
810 }
811
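/// Try to move a bitwise logic op above bitcasts of its operands:
///   and/or/xor (bitcast X), (bitcast Y) --> bitcast (and/or/xor X, Y)
/// so the logic is performed in the source vector type and only one bitcast
/// remains.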
812 bool VectorCombine::foldBitOpOfBitcasts(Instruction &I) {
813 // Match: bitop(bitcast(x), bitcast(y)) -> bitcast(bitop(x, y))
814 Value *LHSSrc, *RHSSrc;
815 if (!match(&I, m_BitwiseLogic(m_BitCast(m_Value(LHSSrc)),
816 m_BitCast(m_Value(RHSSrc)))))
817 return false;
818
819 // Source types must match
820 if (LHSSrc->getType() != RHSSrc->getType())
821 return false;
822 if (!LHSSrc->getType()->getScalarType()->isIntegerTy())
823 return false;
824
825 // Only handle vector types
826 auto *SrcVecTy = dyn_cast<FixedVectorType>(LHSSrc->getType());
827 auto *DstVecTy = dyn_cast<FixedVectorType>(I.getType());
828 if (!SrcVecTy || !DstVecTy)
829 return false;
830
831 // Same total bit width
832 assert(SrcVecTy->getPrimitiveSizeInBits() ==
833 DstVecTy->getPrimitiveSizeInBits() &&
834 "Bitcast should preserve total bit width");
835
836 // Cost Check :
837 // OldCost = bitlogic + 2*bitcasts
838 // NewCost = bitlogic + bitcast
839 auto *BinOp = cast<BinaryOperator>(&I);
840 InstructionCost OldCost =
841 TTI.getArithmeticInstrCost(BinOp->getOpcode(), DstVecTy) +
842 TTI.getCastInstrCost(Instruction::BitCast, DstVecTy, LHSSrc->getType(),
843 TTI::CastContextHint::None) +
844 TTI.getCastInstrCost(Instruction::BitCast, DstVecTy, RHSSrc->getType(),
845 TTI::CastContextHint::None);
846 InstructionCost NewCost =
847 TTI.getArithmeticInstrCost(BinOp->getOpcode(), SrcVecTy) +
848 TTI.getCastInstrCost(Instruction::BitCast, DstVecTy, SrcVecTy,
849 TTI::CastContextHint::None);
850
851 LLVM_DEBUG(dbgs() << "Found a bitwise logic op of bitcasted values: " << I
852 << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
853 << "\n");
854
855 if (NewCost > OldCost)
856 return false;
857
858 // Create the operation on the source type
859 Value *NewOp = Builder.CreateBinOp(BinOp->getOpcode(), LHSSrc, RHSSrc,
860 BinOp->getName() + ".inner");
861 if (auto *NewBinOp = dyn_cast<BinaryOperator>(NewOp))
862 NewBinOp->copyIRFlags(BinOp);
863
864 Worklist.pushValue(NewOp);
865
866 // Bitcast the result back
867 Value *Result = Builder.CreateBitCast(NewOp, I.getType());
868 replaceValue(I, *Result);
869 return true;
870 }
871
872 /// If this is a bitcast of a shuffle, try to bitcast the source vector to the
873 /// destination type followed by shuffle. This can enable further transforms by
874 /// moving bitcasts or shuffles together.
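/// Illustrative example (<4 x i32> shuffle bitcast to <8 x i16>):
///   bitcast (shufflevector <4 x i32> %v, poison, <3, 2, 1, 0>) to <8 x i16>
/// becomes a bitcast of %v to <8 x i16> followed by a shuffle with the mask
/// rescaled to the narrower elements: <6, 7, 4, 5, 2, 3, 0, 1>.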
875 bool VectorCombine::foldBitcastShuffle(Instruction &I) {
876 Value *V0, *V1;
877 ArrayRef<int> Mask;
878 if (!match(&I, m_BitCast(m_OneUse(
879 m_Shuffle(m_Value(V0), m_Value(V1), m_Mask(Mask))))))
880 return false;
881
882 // 1) Do not fold bitcast shuffle for scalable type. First, shuffle cost for
883 // scalable type is unknown; Second, we cannot reason if the narrowed shuffle
884 // mask for scalable type is a splat or not.
885 // 2) Disallow non-vector casts.
886 // TODO: We could allow any shuffle.
887 auto *DestTy = dyn_cast<FixedVectorType>(I.getType());
888 auto *SrcTy = dyn_cast<FixedVectorType>(V0->getType());
889 if (!DestTy || !SrcTy)
890 return false;
891
892 unsigned DestEltSize = DestTy->getScalarSizeInBits();
893 unsigned SrcEltSize = SrcTy->getScalarSizeInBits();
894 if (SrcTy->getPrimitiveSizeInBits() % DestEltSize != 0)
895 return false;
896
897 bool IsUnary = isa<UndefValue>(V1);
898
899 // For binary shuffles, only fold bitcast(shuffle(X,Y))
900 // if it won't increase the number of bitcasts.
901 if (!IsUnary) {
902 auto *BCTy0 = dyn_cast<FixedVectorType>(peekThroughBitcasts(V0)->getType());
903 auto *BCTy1 = dyn_cast<FixedVectorType>(peekThroughBitcasts(V1)->getType());
904 if (!(BCTy0 && BCTy0->getElementType() == DestTy->getElementType()) &&
905 !(BCTy1 && BCTy1->getElementType() == DestTy->getElementType()))
906 return false;
907 }
908
909 SmallVector<int, 16> NewMask;
910 if (DestEltSize <= SrcEltSize) {
911 // The bitcast is from wide to narrow/equal elements. The shuffle mask can
912 // always be expanded to the equivalent form choosing narrower elements.
913 assert(SrcEltSize % DestEltSize == 0 && "Unexpected shuffle mask");
914 unsigned ScaleFactor = SrcEltSize / DestEltSize;
915 narrowShuffleMaskElts(ScaleFactor, Mask, NewMask);
916 } else {
917 // The bitcast is from narrow elements to wide elements. The shuffle mask
918 // must choose consecutive elements to allow casting first.
919 assert(DestEltSize % SrcEltSize == 0 && "Unexpected shuffle mask");
920 unsigned ScaleFactor = DestEltSize / SrcEltSize;
921 if (!widenShuffleMaskElts(ScaleFactor, Mask, NewMask))
922 return false;
923 }
924
925 // Bitcast the shuffle src - keep its original width but use the destination
926 // scalar type.
927 unsigned NumSrcElts = SrcTy->getPrimitiveSizeInBits() / DestEltSize;
928 auto *NewShuffleTy =
929 FixedVectorType::get(DestTy->getScalarType(), NumSrcElts);
930 auto *OldShuffleTy =
931 FixedVectorType::get(SrcTy->getScalarType(), Mask.size());
932 unsigned NumOps = IsUnary ? 1 : 2;
933
934 // The new shuffle must not cost more than the old shuffle.
935 TargetTransformInfo::ShuffleKind SK =
936 IsUnary ? TargetTransformInfo::SK_PermuteSingleSrc
937 : TargetTransformInfo::SK_PermuteTwoSrc;
938
939 InstructionCost NewCost =
940 TTI.getShuffleCost(SK, DestTy, NewShuffleTy, NewMask, CostKind) +
941 (NumOps * TTI.getCastInstrCost(Instruction::BitCast, NewShuffleTy, SrcTy,
942 TargetTransformInfo::CastContextHint::None,
943 CostKind));
944 InstructionCost OldCost =
945 TTI.getShuffleCost(SK, OldShuffleTy, SrcTy, Mask, CostKind) +
946 TTI.getCastInstrCost(Instruction::BitCast, DestTy, OldShuffleTy,
947 TargetTransformInfo::CastContextHint::None,
948 CostKind);
949
950 LLVM_DEBUG(dbgs() << "Found a bitcasted shuffle: " << I << "\n OldCost: "
951 << OldCost << " vs NewCost: " << NewCost << "\n");
952
953 if (NewCost > OldCost || !NewCost.isValid())
954 return false;
955
956 // bitcast (shuf V0, V1, MaskC) --> shuf (bitcast V0), (bitcast V1), MaskC'
957 ++NumShufOfBitcast;
958 Value *CastV0 = Builder.CreateBitCast(peekThroughBitcasts(V0), NewShuffleTy);
959 Value *CastV1 = Builder.CreateBitCast(peekThroughBitcasts(V1), NewShuffleTy);
960 Value *Shuf = Builder.CreateShuffleVector(CastV0, CastV1, NewMask);
961 replaceValue(I, *Shuf);
962 return true;
963 }
964
965 /// VP Intrinsics whose vector operands are both splat values may be simplified
966 /// into the scalar version of the operation and the result splatted. This
967 /// can lead to scalarization down the line.
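/// Illustrative example: "vp.add(splat %a, splat %b, all-true mask, %evl)" may
/// become "splat (add %a, %b)", provided the scalar op is speculatable or %evl
/// is known to be non-zero.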
968 bool VectorCombine::scalarizeVPIntrinsic(Instruction &I) {
969 if (!isa<VPIntrinsic>(I))
970 return false;
971 VPIntrinsic &VPI = cast<VPIntrinsic>(I);
972 Value *Op0 = VPI.getArgOperand(0);
973 Value *Op1 = VPI.getArgOperand(1);
974
975 if (!isSplatValue(Op0) || !isSplatValue(Op1))
976 return false;
977
978 // Check getSplatValue early in this function, to avoid doing unnecessary
979 // work.
980 Value *ScalarOp0 = getSplatValue(Op0);
981 Value *ScalarOp1 = getSplatValue(Op1);
982 if (!ScalarOp0 || !ScalarOp1)
983 return false;
984
985 // For the binary VP intrinsics supported here, the result on disabled lanes
986 // is a poison value. For now, only do this simplification if all lanes
987 // are active.
988 // TODO: Relax the condition that all lanes are active by using insertelement
989 // on inactive lanes.
990 auto IsAllTrueMask = [](Value *MaskVal) {
991 if (Value *SplattedVal = getSplatValue(MaskVal))
992 if (auto *ConstValue = dyn_cast<Constant>(SplattedVal))
993 return ConstValue->isAllOnesValue();
994 return false;
995 };
996 if (!IsAllTrueMask(VPI.getArgOperand(2)))
997 return false;
998
999 // Check to make sure we support scalarization of the intrinsic
1000 Intrinsic::ID IntrID = VPI.getIntrinsicID();
1001 if (!VPBinOpIntrinsic::isVPBinOp(IntrID))
1002 return false;
1003
1004 // Calculate cost of splatting both operands into vectors and the vector
1005 // intrinsic
1006 VectorType *VecTy = cast<VectorType>(VPI.getType());
1007 SmallVector<int> Mask;
1008 if (auto *FVTy = dyn_cast<FixedVectorType>(VecTy))
1009 Mask.resize(FVTy->getNumElements(), 0);
1010 InstructionCost SplatCost =
1011 TTI.getVectorInstrCost(Instruction::InsertElement, VecTy, CostKind, 0) +
1012 TTI.getShuffleCost(TargetTransformInfo::SK_Broadcast, VecTy, VecTy, Mask,
1013 CostKind);
1014
1015 // Calculate the cost of the VP Intrinsic
1016 SmallVector<Type *, 4> Args;
1017 for (Value *V : VPI.args())
1018 Args.push_back(V->getType());
1019 IntrinsicCostAttributes Attrs(IntrID, VecTy, Args);
1020 InstructionCost VectorOpCost = TTI.getIntrinsicInstrCost(Attrs, CostKind);
1021 InstructionCost OldCost = 2 * SplatCost + VectorOpCost;
1022
1023 // Determine scalar opcode
1024 std::optional<unsigned> FunctionalOpcode =
1025 VPI.getFunctionalOpcode();
1026 std::optional<Intrinsic::ID> ScalarIntrID = std::nullopt;
1027 if (!FunctionalOpcode) {
1028 ScalarIntrID = VPI.getFunctionalIntrinsicID();
1029 if (!ScalarIntrID)
1030 return false;
1031 }
1032
1033 // Calculate cost of scalarizing
1034 InstructionCost ScalarOpCost = 0;
1035 if (ScalarIntrID) {
1036 IntrinsicCostAttributes Attrs(*ScalarIntrID, VecTy->getScalarType(), Args);
1037 ScalarOpCost = TTI.getIntrinsicInstrCost(Attrs, CostKind);
1038 } else {
1039 ScalarOpCost = TTI.getArithmeticInstrCost(*FunctionalOpcode,
1040 VecTy->getScalarType(), CostKind);
1041 }
1042
1043 // The existing splats may be kept around if other instructions use them.
1044 InstructionCost CostToKeepSplats =
1045 (SplatCost * !Op0->hasOneUse()) + (SplatCost * !Op1->hasOneUse());
1046 InstructionCost NewCost = ScalarOpCost + SplatCost + CostToKeepSplats;
1047
1048 LLVM_DEBUG(dbgs() << "Found a VP Intrinsic to scalarize: " << VPI
1049 << "\n");
1050 LLVM_DEBUG(dbgs() << "Cost of Intrinsic: " << OldCost
1051 << ", Cost of scalarizing:" << NewCost << "\n");
1052
1053 // We want to scalarize unless the vector variant actually has lower cost.
1054 if (OldCost < NewCost || !NewCost.isValid())
1055 return false;
1056
1057 // Scalarize the intrinsic
1058 ElementCount EC = cast<VectorType>(Op0->getType())->getElementCount();
1059 Value *EVL = VPI.getArgOperand(3);
1060
1061 // If the VP op might introduce UB or poison, we can scalarize it provided
1062 // that we know the EVL > 0: If the EVL is zero, then the original VP op
1063 // becomes a no-op and thus won't be UB, so make sure we don't introduce UB by
1064 // scalarizing it.
1065 bool SafeToSpeculate;
1066 if (ScalarIntrID)
1067 SafeToSpeculate = Intrinsic::getFnAttributes(I.getContext(), *ScalarIntrID)
1068 .hasAttribute(Attribute::AttrKind::Speculatable);
1069 else
1070 SafeToSpeculate = isSafeToSpeculativelyExecuteWithOpcode(
1071 *FunctionalOpcode, &VPI, nullptr, &AC, &DT);
1072 if (!SafeToSpeculate &&
1073 !isKnownNonZero(EVL, SimplifyQuery(*DL, &DT, &AC, &VPI)))
1074 return false;
1075
1076 Value *ScalarVal =
1077 ScalarIntrID
1078 ? Builder.CreateIntrinsic(VecTy->getScalarType(), *ScalarIntrID,
1079 {ScalarOp0, ScalarOp1})
1080 : Builder.CreateBinOp((Instruction::BinaryOps)(*FunctionalOpcode),
1081 ScalarOp0, ScalarOp1);
1082
1083 replaceValue(VPI, *Builder.CreateVectorSplat(EC, ScalarVal));
1084 return true;
1085 }
1086
1087 /// Match a vector op/compare/intrinsic with at least one
1088 /// inserted scalar operand and convert to scalar op/cmp/intrinsic followed
1089 /// by insertelement.
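/// Illustrative example (Index = 2, one constant operand):
///   %v = insertelement <4 x i32> <i32 0, i32 1, i32 2, i32 3>, i32 %x, i64 2
///   %r = add <4 x i32> %v, <i32 7, i32 7, i32 7, i32 7>
/// -->
///   %s = add i32 %x, 7
///   %r = insertelement <4 x i32> <i32 7, i32 8, i32 9, i32 10>, i32 %s, i64 2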
1090 bool VectorCombine::scalarizeOpOrCmp(Instruction &I) {
1091 auto *UO = dyn_cast<UnaryOperator>(&I);
1092 auto *BO = dyn_cast<BinaryOperator>(&I);
1093 auto *CI = dyn_cast<CmpInst>(&I);
1094 auto *II = dyn_cast<IntrinsicInst>(&I);
1095 if (!UO && !BO && !CI && !II)
1096 return false;
1097
1098 // TODO: Allow intrinsics with different argument types
1099 if (II) {
1100 if (!isTriviallyVectorizable(II->getIntrinsicID()))
1101 return false;
1102 for (auto [Idx, Arg] : enumerate(II->args()))
1103 if (Arg->getType() != II->getType() &&
1104 !isVectorIntrinsicWithScalarOpAtArg(II->getIntrinsicID(), Idx, &TTI))
1105 return false;
1106 }
1107
1108 // Do not convert the vector condition of a vector select into a scalar
1109 // condition. That may cause problems for codegen because of differences in
1110 // boolean formats and register-file transfers.
1111 // TODO: Can we account for that in the cost model?
1112 if (CI)
1113 for (User *U : I.users())
1114 if (match(U, m_Select(m_Specific(&I), m_Value(), m_Value())))
1115 return false;
1116
1117 // Match constant vectors or scalars being inserted into constant vectors:
1118 // vec_op [VecC0 | (inselt VecC0, V0, Index)], ...
1119 SmallVector<Value *> VecCs, ScalarOps;
1120 std::optional<uint64_t> Index;
1121
1122 auto Ops = II ? II->args() : I.operands();
1123 for (auto [OpNum, Op] : enumerate(Ops)) {
1124 Constant *VecC;
1125 Value *V;
1126 uint64_t InsIdx = 0;
1127 if (match(Op.get(), m_InsertElt(m_Constant(VecC), m_Value(V),
1128 m_ConstantInt(InsIdx)))) {
1129 // Bail if any inserts are out of bounds.
1130 VectorType *OpTy = cast<VectorType>(Op->getType());
1131 if (OpTy->getElementCount().getKnownMinValue() <= InsIdx)
1132 return false;
1133 // All inserts must have the same index.
1134 // TODO: Deal with mismatched index constants and variable indexes?
1135 if (!Index)
1136 Index = InsIdx;
1137 else if (InsIdx != *Index)
1138 return false;
1139 VecCs.push_back(VecC);
1140 ScalarOps.push_back(V);
1141 } else if (II && isVectorIntrinsicWithScalarOpAtArg(II->getIntrinsicID(),
1142 OpNum, &TTI)) {
1143 VecCs.push_back(Op.get());
1144 ScalarOps.push_back(Op.get());
1145 } else if (match(Op.get(), m_Constant(VecC))) {
1146 VecCs.push_back(VecC);
1147 ScalarOps.push_back(nullptr);
1148 } else {
1149 return false;
1150 }
1151 }
1152
1153 // Bail if all operands are constant.
1154 if (!Index.has_value())
1155 return false;
1156
1157 VectorType *VecTy = cast<VectorType>(I.getType());
1158 Type *ScalarTy = VecTy->getScalarType();
1159 assert(VecTy->isVectorTy() &&
1160 (ScalarTy->isIntegerTy() || ScalarTy->isFloatingPointTy() ||
1161 ScalarTy->isPointerTy()) &&
1162 "Unexpected types for insert element into binop or cmp");
1163
1164 unsigned Opcode = I.getOpcode();
1165 InstructionCost ScalarOpCost, VectorOpCost;
1166 if (CI) {
1167 CmpInst::Predicate Pred = CI->getPredicate();
1168 ScalarOpCost = TTI.getCmpSelInstrCost(
1169 Opcode, ScalarTy, CmpInst::makeCmpResultType(ScalarTy), Pred, CostKind);
1170 VectorOpCost = TTI.getCmpSelInstrCost(
1171 Opcode, VecTy, CmpInst::makeCmpResultType(VecTy), Pred, CostKind);
1172 } else if (UO || BO) {
1173 ScalarOpCost = TTI.getArithmeticInstrCost(Opcode, ScalarTy, CostKind);
1174 VectorOpCost = TTI.getArithmeticInstrCost(Opcode, VecTy, CostKind);
1175 } else {
1176 IntrinsicCostAttributes ScalarICA(
1177 II->getIntrinsicID(), ScalarTy,
1178 SmallVector<Type *>(II->arg_size(), ScalarTy));
1179 ScalarOpCost = TTI.getIntrinsicInstrCost(ScalarICA, CostKind);
1180 IntrinsicCostAttributes VectorICA(
1181 II->getIntrinsicID(), VecTy,
1182 SmallVector<Type *>(II->arg_size(), VecTy));
1183 VectorOpCost = TTI.getIntrinsicInstrCost(VectorICA, CostKind);
1184 }
1185
1186 // Fold the vector constants in the original vectors into a new base vector to
1187 // get more accurate cost modelling.
1188 Value *NewVecC = nullptr;
1189 if (CI)
1190 NewVecC = simplifyCmpInst(CI->getPredicate(), VecCs[0], VecCs[1], SQ);
1191 else if (UO)
1192 NewVecC =
1193 simplifyUnOp(UO->getOpcode(), VecCs[0], UO->getFastMathFlags(), SQ);
1194 else if (BO)
1195 NewVecC = simplifyBinOp(BO->getOpcode(), VecCs[0], VecCs[1], SQ);
1196 else if (II)
1197 NewVecC = simplifyCall(II, II->getCalledOperand(), VecCs, SQ);
1198
1199 if (!NewVecC)
1200 return false;
1201
1202 // Get cost estimate for the insert element. This cost will factor into
1203 // both sequences.
1204 InstructionCost OldCost = VectorOpCost;
1205 InstructionCost NewCost =
1206 ScalarOpCost + TTI.getVectorInstrCost(Instruction::InsertElement, VecTy,
1207 CostKind, *Index, NewVecC);
1208
1209 for (auto [Idx, Op, VecC, Scalar] : enumerate(Ops, VecCs, ScalarOps)) {
1210 if (!Scalar || (II && isVectorIntrinsicWithScalarOpAtArg(
1211 II->getIntrinsicID(), Idx, &TTI)))
1212 continue;
1213 InstructionCost InsertCost = TTI.getVectorInstrCost(
1214 Instruction::InsertElement, VecTy, CostKind, *Index, VecC, Scalar);
1215 OldCost += InsertCost;
1216 NewCost += !Op->hasOneUse() * InsertCost;
1217 }
1218
1219 // We want to scalarize unless the vector variant actually has lower cost.
1220 if (OldCost < NewCost || !NewCost.isValid())
1221 return false;
1222
1223 // vec_op (inselt VecC0, V0, Index), (inselt VecC1, V1, Index) -->
1224 // inselt NewVecC, (scalar_op V0, V1), Index
1225 if (CI)
1226 ++NumScalarCmp;
1227 else if (UO || BO)
1228 ++NumScalarOps;
1229 else
1230 ++NumScalarIntrinsic;
1231
1232 // For constant cases, extract the scalar element; this should constant fold.
1233 for (auto [OpIdx, Scalar, VecC] : enumerate(ScalarOps, VecCs))
1234 if (!Scalar)
1235 ScalarOps[OpIdx] = ConstantExpr::getExtractElement(
1236 cast<Constant>(VecC), Builder.getInt64(*Index));
1237
1238 Value *Scalar;
1239 if (CI)
1240 Scalar = Builder.CreateCmp(CI->getPredicate(), ScalarOps[0], ScalarOps[1]);
1241 else if (UO || BO)
1242 Scalar = Builder.CreateNAryOp(Opcode, ScalarOps);
1243 else
1244 Scalar = Builder.CreateIntrinsic(ScalarTy, II->getIntrinsicID(), ScalarOps);
1245
1246 Scalar->setName(I.getName() + ".scalar");
1247
1248 // All IR flags are safe to back-propagate. There is no potential for extra
1249 // poison to be created by the scalar instruction.
1250 if (auto *ScalarInst = dyn_cast<Instruction>(Scalar))
1251 ScalarInst->copyIRFlags(&I);
1252
1253 Value *Insert = Builder.CreateInsertElement(NewVecC, Scalar, *Index);
1254 replaceValue(I, *Insert);
1255 return true;
1256 }
1257
1258 /// Try to combine a scalar binop + 2 scalar compares of extracted elements of
1259 /// a vector into vector operations followed by extract. Note: The SLP pass
1260 /// may miss this pattern because of implementation problems.
1261 bool VectorCombine::foldExtractedCmps(Instruction &I) {
1262 auto *BI = dyn_cast<BinaryOperator>(&I);
1263
1264 // We are looking for a scalar binop of booleans.
1265 // binop i1 (cmp Pred I0, C0), (cmp Pred I1, C1)
1266 if (!BI || !I.getType()->isIntegerTy(1))
1267 return false;
1268
1269 // The compare predicates should match, and each compare should have a
1270 // constant operand.
1271 Value *B0 = I.getOperand(0), *B1 = I.getOperand(1);
1272 Instruction *I0, *I1;
1273 Constant *C0, *C1;
1274 CmpPredicate P0, P1;
1275 if (!match(B0, m_Cmp(P0, m_Instruction(I0), m_Constant(C0))) ||
1276 !match(B1, m_Cmp(P1, m_Instruction(I1), m_Constant(C1))))
1277 return false;
1278
1279 auto MatchingPred = CmpPredicate::getMatching(P0, P1);
1280 if (!MatchingPred)
1281 return false;
1282
1283 // The compare operands must be extracts of the same vector with constant
1284 // extract indexes.
1285 Value *X;
1286 uint64_t Index0, Index1;
1287 if (!match(I0, m_ExtractElt(m_Value(X), m_ConstantInt(Index0))) ||
1288 !match(I1, m_ExtractElt(m_Specific(X), m_ConstantInt(Index1))))
1289 return false;
1290
1291 auto *Ext0 = cast<ExtractElementInst>(I0);
1292 auto *Ext1 = cast<ExtractElementInst>(I1);
1293 ExtractElementInst *ConvertToShuf = getShuffleExtract(Ext0, Ext1, CostKind);
1294 if (!ConvertToShuf)
1295 return false;
1296 assert((ConvertToShuf == Ext0 || ConvertToShuf == Ext1) &&
1297 "Unknown ExtractElementInst");
1298
1299 // The original scalar pattern is:
1300 // binop i1 (cmp Pred (ext X, Index0), C0), (cmp Pred (ext X, Index1), C1)
1301 CmpInst::Predicate Pred = *MatchingPred;
1302 unsigned CmpOpcode =
1303 CmpInst::isFPPredicate(Pred) ? Instruction::FCmp : Instruction::ICmp;
1304 auto *VecTy = dyn_cast<FixedVectorType>(X->getType());
1305 if (!VecTy)
1306 return false;
1307
1308 InstructionCost Ext0Cost =
1309 TTI.getVectorInstrCost(*Ext0, VecTy, CostKind, Index0);
1310 InstructionCost Ext1Cost =
1311 TTI.getVectorInstrCost(*Ext1, VecTy, CostKind, Index1);
1312 InstructionCost CmpCost = TTI.getCmpSelInstrCost(
1313 CmpOpcode, I0->getType(), CmpInst::makeCmpResultType(I0->getType()), Pred,
1314 CostKind);
1315
1316 InstructionCost OldCost =
1317 Ext0Cost + Ext1Cost + CmpCost * 2 +
1318 TTI.getArithmeticInstrCost(I.getOpcode(), I.getType(), CostKind);
1319
1320 // The proposed vector pattern is:
1321 // vcmp = cmp Pred X, VecC
1322 // ext (binop vNi1 vcmp, (shuffle vcmp, Index1)), Index0
1323 int CheapIndex = ConvertToShuf == Ext0 ? Index1 : Index0;
1324 int ExpensiveIndex = ConvertToShuf == Ext0 ? Index0 : Index1;
1325 auto *CmpTy = cast<FixedVectorType>(CmpInst::makeCmpResultType(VecTy));
1326 InstructionCost NewCost = TTI.getCmpSelInstrCost(
1327 CmpOpcode, VecTy, CmpInst::makeCmpResultType(VecTy), Pred, CostKind);
1328 SmallVector<int, 32> ShufMask(VecTy->getNumElements(), PoisonMaskElem);
1329 ShufMask[CheapIndex] = ExpensiveIndex;
1330 NewCost += TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, CmpTy,
1331 CmpTy, ShufMask, CostKind);
1332 NewCost += TTI.getArithmeticInstrCost(I.getOpcode(), CmpTy, CostKind);
1333 NewCost += TTI.getVectorInstrCost(*Ext0, CmpTy, CostKind, CheapIndex);
1334 NewCost += Ext0->hasOneUse() ? 0 : Ext0Cost;
1335 NewCost += Ext1->hasOneUse() ? 0 : Ext1Cost;
1336
1337 // Aggressively form vector ops if the cost is equal because the transform
1338 // may enable further optimization.
1339 // Codegen can reverse this transform (scalarize) if it was not profitable.
1340 if (OldCost < NewCost || !NewCost.isValid())
1341 return false;
1342
1343 // Create a vector constant from the 2 scalar constants.
1344 SmallVector<Constant *, 32> CmpC(VecTy->getNumElements(),
1345 PoisonValue::get(VecTy->getElementType()));
1346 CmpC[Index0] = C0;
1347 CmpC[Index1] = C1;
1348 Value *VCmp = Builder.CreateCmp(Pred, X, ConstantVector::get(CmpC));
1349 Value *Shuf = createShiftShuffle(VCmp, ExpensiveIndex, CheapIndex, Builder);
1350 Value *LHS = ConvertToShuf == Ext0 ? Shuf : VCmp;
1351 Value *RHS = ConvertToShuf == Ext0 ? VCmp : Shuf;
1352 Value *VecLogic = Builder.CreateBinOp(BI->getOpcode(), LHS, RHS);
1353 Value *NewExt = Builder.CreateExtractElement(VecLogic, CheapIndex);
1354 replaceValue(I, *NewExt);
1355 ++NumVecCmpBO;
1356 return true;
1357 }
1358
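/// Split the cost of a vector reduction intrinsic into the cost of forming its
/// vector operand (\p CostBeforeReduction) and the cost of the reduction
/// itself (\p CostAfterReduction), recognizing reduce(ext(X)) and
/// reduce.add(ext(mul(ext(A), ext(B)))) so they can be priced as extended or
/// multiply-accumulate reductions.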
analyzeCostOfVecReduction(const IntrinsicInst & II,TTI::TargetCostKind CostKind,const TargetTransformInfo & TTI,InstructionCost & CostBeforeReduction,InstructionCost & CostAfterReduction)1359 static void analyzeCostOfVecReduction(const IntrinsicInst &II,
1360 TTI::TargetCostKind CostKind,
1361 const TargetTransformInfo &TTI,
1362 InstructionCost &CostBeforeReduction,
1363 InstructionCost &CostAfterReduction) {
1364 Instruction *Op0, *Op1;
1365 auto *RedOp = dyn_cast<Instruction>(II.getOperand(0));
1366 auto *VecRedTy = cast<VectorType>(II.getOperand(0)->getType());
1367 unsigned ReductionOpc =
1368 getArithmeticReductionInstruction(II.getIntrinsicID());
1369 if (RedOp && match(RedOp, m_ZExtOrSExt(m_Value()))) {
1370 bool IsUnsigned = isa<ZExtInst>(RedOp);
1371 auto *ExtType = cast<VectorType>(RedOp->getOperand(0)->getType());
1372
1373 CostBeforeReduction =
1374 TTI.getCastInstrCost(RedOp->getOpcode(), VecRedTy, ExtType,
1375 TTI::CastContextHint::None, CostKind, RedOp);
1376 CostAfterReduction =
1377 TTI.getExtendedReductionCost(ReductionOpc, IsUnsigned, II.getType(),
1378 ExtType, FastMathFlags(), CostKind);
1379 return;
1380 }
1381 if (RedOp && II.getIntrinsicID() == Intrinsic::vector_reduce_add &&
1382 match(RedOp,
1383 m_ZExtOrSExt(m_Mul(m_Instruction(Op0), m_Instruction(Op1)))) &&
1384 match(Op0, m_ZExtOrSExt(m_Value())) &&
1385 Op0->getOpcode() == Op1->getOpcode() &&
1386 Op0->getOperand(0)->getType() == Op1->getOperand(0)->getType() &&
1387 (Op0->getOpcode() == RedOp->getOpcode() || Op0 == Op1)) {
1388     // Matched reduce.add(ext(mul(ext(A), ext(B))))
1389 bool IsUnsigned = isa<ZExtInst>(Op0);
1390 auto *ExtType = cast<VectorType>(Op0->getOperand(0)->getType());
1391 VectorType *MulType = VectorType::get(Op0->getType(), VecRedTy);
1392
1393 InstructionCost ExtCost =
1394 TTI.getCastInstrCost(Op0->getOpcode(), MulType, ExtType,
1395 TTI::CastContextHint::None, CostKind, Op0);
1396 InstructionCost MulCost =
1397 TTI.getArithmeticInstrCost(Instruction::Mul, MulType, CostKind);
1398 InstructionCost Ext2Cost =
1399 TTI.getCastInstrCost(RedOp->getOpcode(), VecRedTy, MulType,
1400 TTI::CastContextHint::None, CostKind, RedOp);
1401
1402 CostBeforeReduction = ExtCost * 2 + MulCost + Ext2Cost;
1403 CostAfterReduction =
1404 TTI.getMulAccReductionCost(IsUnsigned, II.getType(), ExtType, CostKind);
1405 return;
1406 }
1407 CostAfterReduction = TTI.getArithmeticReductionCost(ReductionOpc, VecRedTy,
1408 std::nullopt, CostKind);
1409 }
1410
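/// Try to fold a binop of two reductions into a reduction of a binop.
/// Sketch (hypothetical IR, subject to the cost check below):
///   %r0 = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %a)
///   %r1 = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %b)
///   %s  = add i32 %r0, %r1
/// -->
///   %v  = add <4 x i32> %a, %b
///   %s  = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %v)
/// A sub of two add-reductions is handled the same way, with the sub applied
/// to the vector operands before the add-reduction.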
foldBinopOfReductions(Instruction & I)1411 bool VectorCombine::foldBinopOfReductions(Instruction &I) {
1412 Instruction::BinaryOps BinOpOpc = cast<BinaryOperator>(&I)->getOpcode();
1413 Intrinsic::ID ReductionIID = getReductionForBinop(BinOpOpc);
1414 if (BinOpOpc == Instruction::Sub)
1415 ReductionIID = Intrinsic::vector_reduce_add;
1416 if (ReductionIID == Intrinsic::not_intrinsic)
1417 return false;
1418
1419 auto checkIntrinsicAndGetItsArgument = [](Value *V,
1420 Intrinsic::ID IID) -> Value * {
1421 auto *II = dyn_cast<IntrinsicInst>(V);
1422 if (!II)
1423 return nullptr;
1424 if (II->getIntrinsicID() == IID && II->hasOneUse())
1425 return II->getArgOperand(0);
1426 return nullptr;
1427 };
1428
1429 Value *V0 = checkIntrinsicAndGetItsArgument(I.getOperand(0), ReductionIID);
1430 if (!V0)
1431 return false;
1432 Value *V1 = checkIntrinsicAndGetItsArgument(I.getOperand(1), ReductionIID);
1433 if (!V1)
1434 return false;
1435
1436 auto *VTy = cast<VectorType>(V0->getType());
1437 if (V1->getType() != VTy)
1438 return false;
1439 const auto &II0 = *cast<IntrinsicInst>(I.getOperand(0));
1440 const auto &II1 = *cast<IntrinsicInst>(I.getOperand(1));
1441 unsigned ReductionOpc =
1442 getArithmeticReductionInstruction(II0.getIntrinsicID());
1443
1444 InstructionCost OldCost = 0;
1445 InstructionCost NewCost = 0;
1446 InstructionCost CostOfRedOperand0 = 0;
1447 InstructionCost CostOfRed0 = 0;
1448 InstructionCost CostOfRedOperand1 = 0;
1449 InstructionCost CostOfRed1 = 0;
1450 analyzeCostOfVecReduction(II0, CostKind, TTI, CostOfRedOperand0, CostOfRed0);
1451 analyzeCostOfVecReduction(II1, CostKind, TTI, CostOfRedOperand1, CostOfRed1);
1452 OldCost = CostOfRed0 + CostOfRed1 + TTI.getInstructionCost(&I, CostKind);
1453 NewCost =
1454 CostOfRedOperand0 + CostOfRedOperand1 +
1455 TTI.getArithmeticInstrCost(BinOpOpc, VTy, CostKind) +
1456 TTI.getArithmeticReductionCost(ReductionOpc, VTy, std::nullopt, CostKind);
1457 if (NewCost >= OldCost || !NewCost.isValid())
1458 return false;
1459
1460 LLVM_DEBUG(dbgs() << "Found two mergeable reductions: " << I
1461 << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
1462 << "\n");
1463 Value *VectorBO;
1464 if (BinOpOpc == Instruction::Or)
1465 VectorBO = Builder.CreateOr(V0, V1, "",
1466 cast<PossiblyDisjointInst>(I).isDisjoint());
1467 else
1468 VectorBO = Builder.CreateBinOp(BinOpOpc, V0, V1);
1469
1470 Instruction *Rdx = Builder.CreateIntrinsic(ReductionIID, {VTy}, {VectorBO});
1471 replaceValue(I, *Rdx);
1472 return true;
1473 }
1474
1475 // Check if the memory location is modified between two instructions in the same BB.
isMemModifiedBetween(BasicBlock::iterator Begin,BasicBlock::iterator End,const MemoryLocation & Loc,AAResults & AA)1476 static bool isMemModifiedBetween(BasicBlock::iterator Begin,
1477 BasicBlock::iterator End,
1478 const MemoryLocation &Loc, AAResults &AA) {
1479 unsigned NumScanned = 0;
1480 return std::any_of(Begin, End, [&](const Instruction &Instr) {
1481 return isModSet(AA.getModRefInfo(&Instr, Loc)) ||
1482 ++NumScanned > MaxInstrsToScan;
1483 });
1484 }
1485
1486 namespace {
1487 /// Helper class to indicate whether a vector index can be safely scalarized and
1488 /// if a freeze needs to be inserted.
1489 class ScalarizationResult {
1490 enum class StatusTy { Unsafe, Safe, SafeWithFreeze };
1491
1492 StatusTy Status;
1493 Value *ToFreeze;
1494
ScalarizationResult(StatusTy Status,Value * ToFreeze=nullptr)1495 ScalarizationResult(StatusTy Status, Value *ToFreeze = nullptr)
1496 : Status(Status), ToFreeze(ToFreeze) {}
1497
1498 public:
1499 ScalarizationResult(const ScalarizationResult &Other) = default;
~ScalarizationResult()1500 ~ScalarizationResult() {
1501 assert(!ToFreeze && "freeze() not called with ToFreeze being set");
1502 }
1503
unsafe()1504 static ScalarizationResult unsafe() { return {StatusTy::Unsafe}; }
safe()1505 static ScalarizationResult safe() { return {StatusTy::Safe}; }
safeWithFreeze(Value * ToFreeze)1506 static ScalarizationResult safeWithFreeze(Value *ToFreeze) {
1507 return {StatusTy::SafeWithFreeze, ToFreeze};
1508 }
1509
1510   /// Returns true if the index can be scalarized without requiring a freeze.
isSafe() const1511 bool isSafe() const { return Status == StatusTy::Safe; }
1512 /// Returns true if the index cannot be scalarized.
isUnsafe() const1513 bool isUnsafe() const { return Status == StatusTy::Unsafe; }
1514   /// Returns true if the index can be scalarized, but requires inserting a
1515 /// freeze.
isSafeWithFreeze() const1516 bool isSafeWithFreeze() const { return Status == StatusTy::SafeWithFreeze; }
1517
1518   /// Reset the state to Unsafe and clear ToFreeze if set.
discard()1519 void discard() {
1520 ToFreeze = nullptr;
1521 Status = StatusTy::Unsafe;
1522 }
1523
1524   /// Freeze ToFreeze and update the use in \p UserI to use the frozen value.
freeze(IRBuilderBase & Builder,Instruction & UserI)1525 void freeze(IRBuilderBase &Builder, Instruction &UserI) {
1526 assert(isSafeWithFreeze() &&
1527 "should only be used when freezing is required");
1528 assert(is_contained(ToFreeze->users(), &UserI) &&
1529 "UserI must be a user of ToFreeze");
1530 IRBuilder<>::InsertPointGuard Guard(Builder);
1531 Builder.SetInsertPoint(cast<Instruction>(&UserI));
1532 Value *Frozen =
1533 Builder.CreateFreeze(ToFreeze, ToFreeze->getName() + ".frozen");
1534 for (Use &U : make_early_inc_range((UserI.operands())))
1535 if (U.get() == ToFreeze)
1536 U.set(Frozen);
1537
1538 ToFreeze = nullptr;
1539 }
1540 };
1541 } // namespace
1542
1543 /// Check if it is legal to scalarize a memory access to \p VecTy at index \p
1544 /// Idx. \p Idx must access a valid vector element.
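/// For example (hypothetical values): with a <4 x i32> vector, a constant
/// index 0..3 is Safe and a constant index >= 4 is Unsafe, while a variable
/// index that has been masked as "and i64 %i, 3" is SafeWithFreeze: the index
/// base %i must be frozen so a poison index cannot be observed as two
/// different values.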
canScalarizeAccess(VectorType * VecTy,Value * Idx,Instruction * CtxI,AssumptionCache & AC,const DominatorTree & DT)1545 static ScalarizationResult canScalarizeAccess(VectorType *VecTy, Value *Idx,
1546 Instruction *CtxI,
1547 AssumptionCache &AC,
1548 const DominatorTree &DT) {
1549 // We do checks for both fixed vector types and scalable vector types.
1550 // This is the number of elements of fixed vector types,
1551 // or the minimum number of elements of scalable vector types.
1552 uint64_t NumElements = VecTy->getElementCount().getKnownMinValue();
1553 unsigned IntWidth = Idx->getType()->getScalarSizeInBits();
1554
1555 if (auto *C = dyn_cast<ConstantInt>(Idx)) {
1556 if (C->getValue().ult(NumElements))
1557 return ScalarizationResult::safe();
1558 return ScalarizationResult::unsafe();
1559 }
1560
1561   // Always unsafe if the index type can't represent all in-bounds values.
1562 if (!llvm::isUIntN(IntWidth, NumElements))
1563 return ScalarizationResult::unsafe();
1564
1565 APInt Zero(IntWidth, 0);
1566 APInt MaxElts(IntWidth, NumElements);
1567 ConstantRange ValidIndices(Zero, MaxElts);
1568 ConstantRange IdxRange(IntWidth, true);
1569
1570 if (isGuaranteedNotToBePoison(Idx, &AC)) {
1571 if (ValidIndices.contains(computeConstantRange(Idx, /* ForSigned */ false,
1572 true, &AC, CtxI, &DT)))
1573 return ScalarizationResult::safe();
1574 return ScalarizationResult::unsafe();
1575 }
1576
1577 // If the index may be poison, check if we can insert a freeze before the
1578 // range of the index is restricted.
1579 Value *IdxBase;
1580 ConstantInt *CI;
1581 if (match(Idx, m_And(m_Value(IdxBase), m_ConstantInt(CI)))) {
1582 IdxRange = IdxRange.binaryAnd(CI->getValue());
1583 } else if (match(Idx, m_URem(m_Value(IdxBase), m_ConstantInt(CI)))) {
1584 IdxRange = IdxRange.urem(CI->getValue());
1585 }
1586
1587 if (ValidIndices.contains(IdxRange))
1588 return ScalarizationResult::safeWithFreeze(IdxBase);
1589 return ScalarizationResult::unsafe();
1590 }
1591
1592 /// The memory operation on a vector of \p ScalarType had alignment of
1593 /// \p VectorAlignment. Compute the maximal, but conservatively correct,
1594 /// alignment that will be valid for the memory operation on a single scalar
1595 /// element of the same type with index \p Idx.
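/// For example (hypothetical values): a <4 x i32> access with alignment 8 and
/// constant index 1 covers byte offset 1 * 4 = 4, so the scalar access gets
/// commonAlignment(8, 4) = 4; with an unknown index only the element store
/// size can be assumed, which also yields commonAlignment(8, 4) = 4 here.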
computeAlignmentAfterScalarization(Align VectorAlignment,Type * ScalarType,Value * Idx,const DataLayout & DL)1596 static Align computeAlignmentAfterScalarization(Align VectorAlignment,
1597 Type *ScalarType, Value *Idx,
1598 const DataLayout &DL) {
1599 if (auto *C = dyn_cast<ConstantInt>(Idx))
1600 return commonAlignment(VectorAlignment,
1601 C->getZExtValue() * DL.getTypeStoreSize(ScalarType));
1602 return commonAlignment(VectorAlignment, DL.getTypeStoreSize(ScalarType));
1603 }
1604
1605 // Combine patterns like:
1606 //   %0 = load <4 x i32>, ptr %a
1607 //   %1 = insertelement <4 x i32> %0, i32 %b, i32 1
1608 //   store <4 x i32> %1, ptr %a
1609 // to:
1610 //   (no bitcast is needed with opaque pointers; we GEP into the vector type)
1611 //   %2 = getelementptr inbounds <4 x i32>, ptr %a, i32 0, i32 1
1612 //   store i32 %b, ptr %2
foldSingleElementStore(Instruction & I)1613 bool VectorCombine::foldSingleElementStore(Instruction &I) {
1614 auto *SI = cast<StoreInst>(&I);
1615 if (!SI->isSimple() || !isa<VectorType>(SI->getValueOperand()->getType()))
1616 return false;
1617
1618 // TODO: Combine more complicated patterns (multiple insert) by referencing
1619 // TargetTransformInfo.
1620 Instruction *Source;
1621 Value *NewElement;
1622 Value *Idx;
1623 if (!match(SI->getValueOperand(),
1624 m_InsertElt(m_Instruction(Source), m_Value(NewElement),
1625 m_Value(Idx))))
1626 return false;
1627
1628 if (auto *Load = dyn_cast<LoadInst>(Source)) {
1629 auto VecTy = cast<VectorType>(SI->getValueOperand()->getType());
1630 Value *SrcAddr = Load->getPointerOperand()->stripPointerCasts();
1631     // Don't optimize for atomic/volatile load or store. Ensure memory is not
1632     // modified in between, the scalar type size matches its store size, and the index is in bounds.
1633 if (!Load->isSimple() || Load->getParent() != SI->getParent() ||
1634 !DL->typeSizeEqualsStoreSize(Load->getType()->getScalarType()) ||
1635 SrcAddr != SI->getPointerOperand()->stripPointerCasts())
1636 return false;
1637
1638 auto ScalarizableIdx = canScalarizeAccess(VecTy, Idx, Load, AC, DT);
1639 if (ScalarizableIdx.isUnsafe() ||
1640 isMemModifiedBetween(Load->getIterator(), SI->getIterator(),
1641 MemoryLocation::get(SI), AA))
1642 return false;
1643
1644 // Ensure we add the load back to the worklist BEFORE its users so they can
1645     // be erased in the correct order.
1646 Worklist.push(Load);
1647
1648 if (ScalarizableIdx.isSafeWithFreeze())
1649 ScalarizableIdx.freeze(Builder, *cast<Instruction>(Idx));
1650 Value *GEP = Builder.CreateInBoundsGEP(
1651 SI->getValueOperand()->getType(), SI->getPointerOperand(),
1652 {ConstantInt::get(Idx->getType(), 0), Idx});
1653 StoreInst *NSI = Builder.CreateStore(NewElement, GEP);
1654 NSI->copyMetadata(*SI);
1655 Align ScalarOpAlignment = computeAlignmentAfterScalarization(
1656 std::max(SI->getAlign(), Load->getAlign()), NewElement->getType(), Idx,
1657 *DL);
1658 NSI->setAlignment(ScalarOpAlignment);
1659 replaceValue(I, *NSI);
1660 eraseInstruction(I);
1661 return true;
1662 }
1663
1664 return false;
1665 }
1666
1667 /// Try to scalarize vector loads feeding extractelement instructions.
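/// Sketch (hypothetical IR, applied only when the scalar loads are cheaper
/// than the vector load plus extracts):
///   %v  = load <4 x i32>, ptr %p
///   %e1 = extractelement <4 x i32> %v, i64 1
/// -->
///   %gep = getelementptr inbounds <4 x i32>, ptr %p, i32 0, i64 1
///   %e1  = load i32, ptr %gep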
scalarizeLoadExtract(Instruction & I)1668 bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
1669 Value *Ptr;
1670 if (!match(&I, m_Load(m_Value(Ptr))))
1671 return false;
1672
1673 auto *LI = cast<LoadInst>(&I);
1674 auto *VecTy = cast<VectorType>(LI->getType());
1675 if (LI->isVolatile() || !DL->typeSizeEqualsStoreSize(VecTy->getScalarType()))
1676 return false;
1677
1678 InstructionCost OriginalCost =
1679 TTI.getMemoryOpCost(Instruction::Load, VecTy, LI->getAlign(),
1680 LI->getPointerAddressSpace(), CostKind);
1681 InstructionCost ScalarizedCost = 0;
1682
1683 Instruction *LastCheckedInst = LI;
1684 unsigned NumInstChecked = 0;
1685 DenseMap<ExtractElementInst *, ScalarizationResult> NeedFreeze;
1686 auto FailureGuard = make_scope_exit([&]() {
1687 // If the transform is aborted, discard the ScalarizationResults.
1688 for (auto &Pair : NeedFreeze)
1689 Pair.second.discard();
1690 });
1691
1692 // Check if all users of the load are extracts with no memory modifications
1693 // between the load and the extract. Compute the cost of both the original
1694 // code and the scalarized version.
1695 for (User *U : LI->users()) {
1696 auto *UI = dyn_cast<ExtractElementInst>(U);
1697 if (!UI || UI->getParent() != LI->getParent())
1698 return false;
1699
1700 // If any extract is waiting to be erased, then bail out as this will
1701 // distort the cost calculation and possibly lead to infinite loops.
1702 if (UI->use_empty())
1703 return false;
1704
1705 // Check if any instruction between the load and the extract may modify
1706 // memory.
1707 if (LastCheckedInst->comesBefore(UI)) {
1708 for (Instruction &I :
1709 make_range(std::next(LI->getIterator()), UI->getIterator())) {
1710 // Bail out if we reached the check limit or the instruction may write
1711 // to memory.
1712 if (NumInstChecked == MaxInstrsToScan || I.mayWriteToMemory())
1713 return false;
1714 NumInstChecked++;
1715 }
1716 LastCheckedInst = UI;
1717 }
1718
1719 auto ScalarIdx =
1720 canScalarizeAccess(VecTy, UI->getIndexOperand(), LI, AC, DT);
1721 if (ScalarIdx.isUnsafe())
1722 return false;
1723 if (ScalarIdx.isSafeWithFreeze()) {
1724 NeedFreeze.try_emplace(UI, ScalarIdx);
1725 ScalarIdx.discard();
1726 }
1727
1728 auto *Index = dyn_cast<ConstantInt>(UI->getIndexOperand());
1729 OriginalCost +=
1730 TTI.getVectorInstrCost(Instruction::ExtractElement, VecTy, CostKind,
1731 Index ? Index->getZExtValue() : -1);
1732 ScalarizedCost +=
1733 TTI.getMemoryOpCost(Instruction::Load, VecTy->getElementType(),
1734 Align(1), LI->getPointerAddressSpace(), CostKind);
1735 ScalarizedCost += TTI.getAddressComputationCost(VecTy->getElementType());
1736 }
1737
1738 LLVM_DEBUG(dbgs() << "Found all extractions of a vector load: " << I
1739 << "\n LoadExtractCost: " << OriginalCost
1740 << " vs ScalarizedCost: " << ScalarizedCost << "\n");
1741
1742 if (ScalarizedCost >= OriginalCost)
1743 return false;
1744
1745 // Ensure we add the load back to the worklist BEFORE its users so they can
1746   // be erased in the correct order.
1747 Worklist.push(LI);
1748
1749 // Replace extracts with narrow scalar loads.
1750 for (User *U : LI->users()) {
1751 auto *EI = cast<ExtractElementInst>(U);
1752 Value *Idx = EI->getIndexOperand();
1753
1754 // Insert 'freeze' for poison indexes.
1755 auto It = NeedFreeze.find(EI);
1756 if (It != NeedFreeze.end())
1757 It->second.freeze(Builder, *cast<Instruction>(Idx));
1758
1759 Builder.SetInsertPoint(EI);
1760 Value *GEP =
1761 Builder.CreateInBoundsGEP(VecTy, Ptr, {Builder.getInt32(0), Idx});
1762 auto *NewLoad = cast<LoadInst>(Builder.CreateLoad(
1763 VecTy->getElementType(), GEP, EI->getName() + ".scalar"));
1764
1765 Align ScalarOpAlignment = computeAlignmentAfterScalarization(
1766 LI->getAlign(), VecTy->getElementType(), Idx, *DL);
1767 NewLoad->setAlignment(ScalarOpAlignment);
1768
1769 replaceValue(*EI, *NewLoad);
1770 }
1771
1772 FailureGuard.release();
1773 return true;
1774 }
1775
scalarizeExtExtract(Instruction & I)1776 bool VectorCombine::scalarizeExtExtract(Instruction &I) {
1777 auto *Ext = dyn_cast<ZExtInst>(&I);
1778 if (!Ext)
1779 return false;
1780
1781   // Try to convert a vector zext feeding only extracts into a set of scalar
1782   //   (Src >> (ExtIdx * EltBitWidth)) & ((1 << EltBitWidth) - 1)
1783   // operations on the bitcast source, if profitable.
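  // Sketch (hypothetical IR, little-endian, profitable case; %src may also
  // need a freeze if it is not known to be non-poison):
  //   %z  = zext <4 x i8> %src to <4 x i32>
  //   %e2 = extractelement <4 x i32> %z, i64 2
  // -->
  //   %bc  = bitcast <4 x i8> %src to i32
  //   %shr = lshr i32 %bc, 16
  //   %e2  = and i32 %shr, 255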
1784 auto *SrcTy = dyn_cast<FixedVectorType>(Ext->getOperand(0)->getType());
1785 if (!SrcTy)
1786 return false;
1787 auto *DstTy = cast<FixedVectorType>(Ext->getType());
1788
1789 Type *ScalarDstTy = DstTy->getElementType();
1790 if (DL->getTypeSizeInBits(SrcTy) != DL->getTypeSizeInBits(ScalarDstTy))
1791 return false;
1792
1793 InstructionCost VectorCost =
1794 TTI.getCastInstrCost(Instruction::ZExt, DstTy, SrcTy,
1795 TTI::CastContextHint::None, CostKind, Ext);
1796 unsigned ExtCnt = 0;
1797 bool ExtLane0 = false;
1798 for (User *U : Ext->users()) {
1799 const APInt *Idx;
1800 if (!match(U, m_ExtractElt(m_Value(), m_APInt(Idx))))
1801 return false;
1802 if (cast<Instruction>(U)->use_empty())
1803 continue;
1804 ExtCnt += 1;
1805 ExtLane0 |= Idx->isZero();
1806 VectorCost += TTI.getVectorInstrCost(Instruction::ExtractElement, DstTy,
1807 CostKind, Idx->getZExtValue(), U);
1808 }
1809
1810 InstructionCost ScalarCost =
1811 ExtCnt * TTI.getArithmeticInstrCost(
1812 Instruction::And, ScalarDstTy, CostKind,
1813 {TTI::OK_AnyValue, TTI::OP_None},
1814 {TTI::OK_NonUniformConstantValue, TTI::OP_None}) +
1815 (ExtCnt - ExtLane0) *
1816 TTI.getArithmeticInstrCost(
1817 Instruction::LShr, ScalarDstTy, CostKind,
1818 {TTI::OK_AnyValue, TTI::OP_None},
1819 {TTI::OK_NonUniformConstantValue, TTI::OP_None});
1820 if (ScalarCost > VectorCost)
1821 return false;
1822
1823 Value *ScalarV = Ext->getOperand(0);
1824 if (!isGuaranteedNotToBePoison(ScalarV, &AC, dyn_cast<Instruction>(ScalarV),
1825 &DT))
1826 ScalarV = Builder.CreateFreeze(ScalarV);
1827 ScalarV = Builder.CreateBitCast(
1828 ScalarV,
1829 IntegerType::get(SrcTy->getContext(), DL->getTypeSizeInBits(SrcTy)));
1830 uint64_t SrcEltSizeInBits = DL->getTypeSizeInBits(SrcTy->getElementType());
1831 uint64_t EltBitMask = (1ull << SrcEltSizeInBits) - 1;
1832 uint64_t TotalBits = DL->getTypeSizeInBits(SrcTy);
1833 Type *PackedTy = IntegerType::get(SrcTy->getContext(), TotalBits);
1834 Value *Mask = ConstantInt::get(PackedTy, EltBitMask);
1835 for (User *U : Ext->users()) {
1836 auto *Extract = cast<ExtractElementInst>(U);
1837 uint64_t Idx =
1838 cast<ConstantInt>(Extract->getIndexOperand())->getZExtValue();
1839 uint64_t ShiftAmt =
1840 DL->isBigEndian()
1841 ? (TotalBits - SrcEltSizeInBits - Idx * SrcEltSizeInBits)
1842 : (Idx * SrcEltSizeInBits);
1843 Value *LShr = Builder.CreateLShr(ScalarV, ShiftAmt);
1844 Value *And = Builder.CreateAnd(LShr, Mask);
1845 U->replaceAllUsesWith(And);
1846 }
1847 return true;
1848 }
1849
1850 /// Try to fold "(or (zext (bitcast X)), (shl (zext (bitcast Y)), C))"
1851 /// to "(bitcast (concat X, Y))"
1852 /// where X/Y are bitcasted from i1 mask vectors.
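/// Sketch (hypothetical IR, little-endian, disjoint or):
///   %x  = bitcast <8 x i1> %m0 to i8
///   %y  = bitcast <8 x i1> %m1 to i8
///   %zx = zext i8 %x to i16
///   %zy = zext i8 %y to i16
///   %sy = shl i16 %zy, 8
///   %r  = or disjoint i16 %zx, %sy
/// -->
///   %c = shufflevector <8 x i1> %m0, <8 x i1> %m1, <identity concat mask 0..15>
///   %r = bitcast <16 x i1> %c to i16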
foldConcatOfBoolMasks(Instruction & I)1853 bool VectorCombine::foldConcatOfBoolMasks(Instruction &I) {
1854 Type *Ty = I.getType();
1855 if (!Ty->isIntegerTy())
1856 return false;
1857
1858 // TODO: Add big endian test coverage
1859 if (DL->isBigEndian())
1860 return false;
1861
1862 // Restrict to disjoint cases so the mask vectors aren't overlapping.
1863 Instruction *X, *Y;
1864 if (!match(&I, m_DisjointOr(m_Instruction(X), m_Instruction(Y))))
1865 return false;
1866
1867 // Allow both sources to contain shl, to handle more generic pattern:
1868 // "(or (shl (zext (bitcast X)), C1), (shl (zext (bitcast Y)), C2))"
1869 Value *SrcX;
1870 uint64_t ShAmtX = 0;
1871 if (!match(X, m_OneUse(m_ZExt(m_OneUse(m_BitCast(m_Value(SrcX)))))) &&
1872 !match(X, m_OneUse(
1873 m_Shl(m_OneUse(m_ZExt(m_OneUse(m_BitCast(m_Value(SrcX))))),
1874 m_ConstantInt(ShAmtX)))))
1875 return false;
1876
1877 Value *SrcY;
1878 uint64_t ShAmtY = 0;
1879 if (!match(Y, m_OneUse(m_ZExt(m_OneUse(m_BitCast(m_Value(SrcY)))))) &&
1880 !match(Y, m_OneUse(
1881 m_Shl(m_OneUse(m_ZExt(m_OneUse(m_BitCast(m_Value(SrcY))))),
1882 m_ConstantInt(ShAmtY)))))
1883 return false;
1884
1885 // Canonicalize larger shift to the RHS.
1886 if (ShAmtX > ShAmtY) {
1887 std::swap(X, Y);
1888 std::swap(SrcX, SrcY);
1889 std::swap(ShAmtX, ShAmtY);
1890 }
1891
1892 // Ensure both sources are matching vXi1 bool mask types, and that the shift
1893 // difference is the mask width so they can be easily concatenated together.
1894 uint64_t ShAmtDiff = ShAmtY - ShAmtX;
1895 unsigned NumSHL = (ShAmtX > 0) + (ShAmtY > 0);
1896 unsigned BitWidth = Ty->getPrimitiveSizeInBits();
1897 auto *MaskTy = dyn_cast<FixedVectorType>(SrcX->getType());
1898 if (!MaskTy || SrcX->getType() != SrcY->getType() ||
1899 !MaskTy->getElementType()->isIntegerTy(1) ||
1900 MaskTy->getNumElements() != ShAmtDiff ||
1901 MaskTy->getNumElements() > (BitWidth / 2))
1902 return false;
1903
1904 auto *ConcatTy = FixedVectorType::getDoubleElementsVectorType(MaskTy);
1905 auto *ConcatIntTy =
1906 Type::getIntNTy(Ty->getContext(), ConcatTy->getNumElements());
1907 auto *MaskIntTy = Type::getIntNTy(Ty->getContext(), ShAmtDiff);
1908
1909 SmallVector<int, 32> ConcatMask(ConcatTy->getNumElements());
1910 std::iota(ConcatMask.begin(), ConcatMask.end(), 0);
1911
1912 // TODO: Is it worth supporting multi use cases?
1913 InstructionCost OldCost = 0;
1914 OldCost += TTI.getArithmeticInstrCost(Instruction::Or, Ty, CostKind);
1915 OldCost +=
1916 NumSHL * TTI.getArithmeticInstrCost(Instruction::Shl, Ty, CostKind);
1917 OldCost += 2 * TTI.getCastInstrCost(Instruction::ZExt, Ty, MaskIntTy,
1918 TTI::CastContextHint::None, CostKind);
1919 OldCost += 2 * TTI.getCastInstrCost(Instruction::BitCast, MaskIntTy, MaskTy,
1920 TTI::CastContextHint::None, CostKind);
1921
1922 InstructionCost NewCost = 0;
1923 NewCost += TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, ConcatTy,
1924 MaskTy, ConcatMask, CostKind);
1925 NewCost += TTI.getCastInstrCost(Instruction::BitCast, ConcatIntTy, ConcatTy,
1926 TTI::CastContextHint::None, CostKind);
1927 if (Ty != ConcatIntTy)
1928 NewCost += TTI.getCastInstrCost(Instruction::ZExt, Ty, ConcatIntTy,
1929 TTI::CastContextHint::None, CostKind);
1930 if (ShAmtX > 0)
1931 NewCost += TTI.getArithmeticInstrCost(Instruction::Shl, Ty, CostKind);
1932
1933 LLVM_DEBUG(dbgs() << "Found a concatenation of bitcasted bool masks: " << I
1934 << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
1935 << "\n");
1936
1937 if (NewCost > OldCost)
1938 return false;
1939
1940 // Build bool mask concatenation, bitcast back to scalar integer, and perform
1941 // any residual zero-extension or shifting.
1942 Value *Concat = Builder.CreateShuffleVector(SrcX, SrcY, ConcatMask);
1943 Worklist.pushValue(Concat);
1944
1945 Value *Result = Builder.CreateBitCast(Concat, ConcatIntTy);
1946
1947 if (Ty != ConcatIntTy) {
1948 Worklist.pushValue(Result);
1949 Result = Builder.CreateZExt(Result, Ty);
1950 }
1951
1952 if (ShAmtX > 0) {
1953 Worklist.pushValue(Result);
1954 Result = Builder.CreateShl(Result, ShAmtX);
1955 }
1956
1957 replaceValue(I, *Result);
1958 return true;
1959 }
1960
1961 /// Try to convert "shuffle (binop (shuffle, shuffle)), undef"
1962 /// --> "binop (shuffle), (shuffle)".
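/// Sketch (hypothetical IR): merging the outer permute into the one-use
/// operand shuffles lets both masks cancel to the identity:
///   %s0 = shufflevector <4 x i32> %a, <4 x i32> poison, <4 x i32> <i32 3, i32 2, i32 1, i32 0>
///   %s1 = shufflevector <4 x i32> %b, <4 x i32> poison, <4 x i32> <i32 3, i32 2, i32 1, i32 0>
///   %bo = add <4 x i32> %s0, %s1
///   %r  = shufflevector <4 x i32> %bo, <4 x i32> poison, <4 x i32> <i32 3, i32 2, i32 1, i32 0>
/// -->
///   %r = add <4 x i32> %a, %b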
foldPermuteOfBinops(Instruction & I)1963 bool VectorCombine::foldPermuteOfBinops(Instruction &I) {
1964 BinaryOperator *BinOp;
1965 ArrayRef<int> OuterMask;
1966 if (!match(&I,
1967 m_Shuffle(m_OneUse(m_BinOp(BinOp)), m_Undef(), m_Mask(OuterMask))))
1968 return false;
1969
1970 // Don't introduce poison into div/rem.
1971 if (BinOp->isIntDivRem() && llvm::is_contained(OuterMask, PoisonMaskElem))
1972 return false;
1973
1974 Value *Op00, *Op01, *Op10, *Op11;
1975 ArrayRef<int> Mask0, Mask1;
1976 bool Match0 =
1977 match(BinOp->getOperand(0),
1978 m_OneUse(m_Shuffle(m_Value(Op00), m_Value(Op01), m_Mask(Mask0))));
1979 bool Match1 =
1980 match(BinOp->getOperand(1),
1981 m_OneUse(m_Shuffle(m_Value(Op10), m_Value(Op11), m_Mask(Mask1))));
1982 if (!Match0 && !Match1)
1983 return false;
1984
1985 Op00 = Match0 ? Op00 : BinOp->getOperand(0);
1986 Op01 = Match0 ? Op01 : BinOp->getOperand(0);
1987 Op10 = Match1 ? Op10 : BinOp->getOperand(1);
1988 Op11 = Match1 ? Op11 : BinOp->getOperand(1);
1989
1990 Instruction::BinaryOps Opcode = BinOp->getOpcode();
1991 auto *ShuffleDstTy = dyn_cast<FixedVectorType>(I.getType());
1992 auto *BinOpTy = dyn_cast<FixedVectorType>(BinOp->getType());
1993 auto *Op0Ty = dyn_cast<FixedVectorType>(Op00->getType());
1994 auto *Op1Ty = dyn_cast<FixedVectorType>(Op10->getType());
1995 if (!ShuffleDstTy || !BinOpTy || !Op0Ty || !Op1Ty)
1996 return false;
1997
1998 unsigned NumSrcElts = BinOpTy->getNumElements();
1999
2000   // Don't accept shuffles that reference the second shuffle operand when the
2001   // binop is a div/rem or when that operand is not poison.
2002 if ((BinOp->isIntDivRem() || !isa<PoisonValue>(I.getOperand(1))) &&
2003 any_of(OuterMask, [NumSrcElts](int M) { return M >= (int)NumSrcElts; }))
2004 return false;
2005
2006 // Merge outer / inner (or identity if no match) shuffles.
2007 SmallVector<int> NewMask0, NewMask1;
2008 for (int M : OuterMask) {
2009 if (M < 0 || M >= (int)NumSrcElts) {
2010 NewMask0.push_back(PoisonMaskElem);
2011 NewMask1.push_back(PoisonMaskElem);
2012 } else {
2013 NewMask0.push_back(Match0 ? Mask0[M] : M);
2014 NewMask1.push_back(Match1 ? Mask1[M] : M);
2015 }
2016 }
2017
2018 unsigned NumOpElts = Op0Ty->getNumElements();
2019 bool IsIdentity0 = ShuffleDstTy == Op0Ty &&
2020 all_of(NewMask0, [NumOpElts](int M) { return M < (int)NumOpElts; }) &&
2021 ShuffleVectorInst::isIdentityMask(NewMask0, NumOpElts);
2022 bool IsIdentity1 = ShuffleDstTy == Op1Ty &&
2023 all_of(NewMask1, [NumOpElts](int M) { return M < (int)NumOpElts; }) &&
2024 ShuffleVectorInst::isIdentityMask(NewMask1, NumOpElts);
2025
2026 // Try to merge shuffles across the binop if the new shuffles are not costly.
2027 InstructionCost OldCost =
2028 TTI.getArithmeticInstrCost(Opcode, BinOpTy, CostKind) +
2029 TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc, ShuffleDstTy,
2030 BinOpTy, OuterMask, CostKind, 0, nullptr, {BinOp}, &I);
2031 if (Match0)
2032 OldCost += TTI.getShuffleCost(
2033 TargetTransformInfo::SK_PermuteTwoSrc, BinOpTy, Op0Ty, Mask0, CostKind,
2034 0, nullptr, {Op00, Op01}, cast<Instruction>(BinOp->getOperand(0)));
2035 if (Match1)
2036 OldCost += TTI.getShuffleCost(
2037 TargetTransformInfo::SK_PermuteTwoSrc, BinOpTy, Op1Ty, Mask1, CostKind,
2038 0, nullptr, {Op10, Op11}, cast<Instruction>(BinOp->getOperand(1)));
2039
2040 InstructionCost NewCost =
2041 TTI.getArithmeticInstrCost(Opcode, ShuffleDstTy, CostKind);
2042
2043 if (!IsIdentity0)
2044 NewCost +=
2045 TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, ShuffleDstTy,
2046 Op0Ty, NewMask0, CostKind, 0, nullptr, {Op00, Op01});
2047 if (!IsIdentity1)
2048 NewCost +=
2049 TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, ShuffleDstTy,
2050 Op1Ty, NewMask1, CostKind, 0, nullptr, {Op10, Op11});
2051
2052 LLVM_DEBUG(dbgs() << "Found a shuffle feeding a shuffled binop: " << I
2053 << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
2054 << "\n");
2055
2056 // If costs are equal, still fold as we reduce instruction count.
2057 if (NewCost > OldCost)
2058 return false;
2059
2060 Value *LHS =
2061 IsIdentity0 ? Op00 : Builder.CreateShuffleVector(Op00, Op01, NewMask0);
2062 Value *RHS =
2063 IsIdentity1 ? Op10 : Builder.CreateShuffleVector(Op10, Op11, NewMask1);
2064 Value *NewBO = Builder.CreateBinOp(Opcode, LHS, RHS);
2065
2066 // Intersect flags from the old binops.
2067 if (auto *NewInst = dyn_cast<Instruction>(NewBO))
2068 NewInst->copyIRFlags(BinOp);
2069
2070 Worklist.pushValue(LHS);
2071 Worklist.pushValue(RHS);
2072 replaceValue(I, *NewBO);
2073 return true;
2074 }
2075
2076 /// Try to convert "shuffle (binop), (binop)" into "binop (shuffle), (shuffle)".
2077 /// Try to convert "shuffle (cmpop), (cmpop)" into "cmpop (shuffle), (shuffle)".
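/// Sketch (hypothetical IR, both binops single-use):
///   %a0 = add <4 x i32> %x, %y
///   %a1 = add <4 x i32> %z, %w
///   %r  = shufflevector <4 x i32> %a0, <4 x i32> %a1, <4 x i32> <i32 0, i32 1, i32 4, i32 5>
/// -->
///   %s0 = shufflevector <4 x i32> %x, <4 x i32> %z, <4 x i32> <i32 0, i32 1, i32 4, i32 5>
///   %s1 = shufflevector <4 x i32> %y, <4 x i32> %w, <4 x i32> <i32 0, i32 1, i32 4, i32 5>
///   %r  = add <4 x i32> %s0, %s1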
foldShuffleOfBinops(Instruction & I)2078 bool VectorCombine::foldShuffleOfBinops(Instruction &I) {
2079 ArrayRef<int> OldMask;
2080 Instruction *LHS, *RHS;
2081 if (!match(&I, m_Shuffle(m_OneUse(m_Instruction(LHS)),
2082 m_OneUse(m_Instruction(RHS)), m_Mask(OldMask))))
2083 return false;
2084
2085 // TODO: Add support for addlike etc.
2086 if (LHS->getOpcode() != RHS->getOpcode())
2087 return false;
2088
2089 Value *X, *Y, *Z, *W;
2090 bool IsCommutative = false;
2091 CmpPredicate PredLHS = CmpInst::BAD_ICMP_PREDICATE;
2092 CmpPredicate PredRHS = CmpInst::BAD_ICMP_PREDICATE;
2093 if (match(LHS, m_BinOp(m_Value(X), m_Value(Y))) &&
2094 match(RHS, m_BinOp(m_Value(Z), m_Value(W)))) {
2095 auto *BO = cast<BinaryOperator>(LHS);
2096 // Don't introduce poison into div/rem.
2097 if (llvm::is_contained(OldMask, PoisonMaskElem) && BO->isIntDivRem())
2098 return false;
2099 IsCommutative = BinaryOperator::isCommutative(BO->getOpcode());
2100 } else if (match(LHS, m_Cmp(PredLHS, m_Value(X), m_Value(Y))) &&
2101 match(RHS, m_Cmp(PredRHS, m_Value(Z), m_Value(W))) &&
2102 (CmpInst::Predicate)PredLHS == (CmpInst::Predicate)PredRHS) {
2103 IsCommutative = cast<CmpInst>(LHS)->isCommutative();
2104 } else
2105 return false;
2106
2107 auto *ShuffleDstTy = dyn_cast<FixedVectorType>(I.getType());
2108 auto *BinResTy = dyn_cast<FixedVectorType>(LHS->getType());
2109 auto *BinOpTy = dyn_cast<FixedVectorType>(X->getType());
2110 if (!ShuffleDstTy || !BinResTy || !BinOpTy || X->getType() != Z->getType())
2111 return false;
2112
2113 unsigned NumSrcElts = BinOpTy->getNumElements();
2114
2115 // If we have something like "add X, Y" and "add Z, X", swap ops to match.
2116 if (IsCommutative && X != Z && Y != W && (X == W || Y == Z))
2117 std::swap(X, Y);
2118
2119 auto ConvertToUnary = [NumSrcElts](int &M) {
2120 if (M >= (int)NumSrcElts)
2121 M -= NumSrcElts;
2122 };
2123
2124 SmallVector<int> NewMask0(OldMask);
2125 TargetTransformInfo::ShuffleKind SK0 = TargetTransformInfo::SK_PermuteTwoSrc;
2126 if (X == Z) {
2127 llvm::for_each(NewMask0, ConvertToUnary);
2128 SK0 = TargetTransformInfo::SK_PermuteSingleSrc;
2129 Z = PoisonValue::get(BinOpTy);
2130 }
2131
2132 SmallVector<int> NewMask1(OldMask);
2133 TargetTransformInfo::ShuffleKind SK1 = TargetTransformInfo::SK_PermuteTwoSrc;
2134 if (Y == W) {
2135 llvm::for_each(NewMask1, ConvertToUnary);
2136 SK1 = TargetTransformInfo::SK_PermuteSingleSrc;
2137 W = PoisonValue::get(BinOpTy);
2138 }
2139
2140 // Try to replace a binop with a shuffle if the shuffle is not costly.
2141 InstructionCost OldCost =
2142 TTI.getInstructionCost(LHS, CostKind) +
2143 TTI.getInstructionCost(RHS, CostKind) +
2144 TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, ShuffleDstTy,
2145 BinResTy, OldMask, CostKind, 0, nullptr, {LHS, RHS},
2146 &I);
2147
2148 // Handle shuffle(binop(shuffle(x),y),binop(z,shuffle(w))) style patterns
2149 // where one use shuffles have gotten split across the binop/cmp. These
2150 // often allow a major reduction in total cost that wouldn't happen as
2151 // individual folds.
2152 auto MergeInner = [&](Value *&Op, int Offset, MutableArrayRef<int> Mask,
2153 TTI::TargetCostKind CostKind) -> bool {
2154 Value *InnerOp;
2155 ArrayRef<int> InnerMask;
2156 if (match(Op, m_OneUse(m_Shuffle(m_Value(InnerOp), m_Undef(),
2157 m_Mask(InnerMask)))) &&
2158 InnerOp->getType() == Op->getType() &&
2159 all_of(InnerMask,
2160 [NumSrcElts](int M) { return M < (int)NumSrcElts; })) {
2161 for (int &M : Mask)
2162 if (Offset <= M && M < (int)(Offset + NumSrcElts)) {
2163 M = InnerMask[M - Offset];
2164 M = 0 <= M ? M + Offset : M;
2165 }
2166 OldCost += TTI.getInstructionCost(cast<Instruction>(Op), CostKind);
2167 Op = InnerOp;
2168 return true;
2169 }
2170 return false;
2171 };
2172 bool ReducedInstCount = false;
2173 ReducedInstCount |= MergeInner(X, 0, NewMask0, CostKind);
2174 ReducedInstCount |= MergeInner(Y, 0, NewMask1, CostKind);
2175 ReducedInstCount |= MergeInner(Z, NumSrcElts, NewMask0, CostKind);
2176 ReducedInstCount |= MergeInner(W, NumSrcElts, NewMask1, CostKind);
2177
2178 auto *ShuffleCmpTy =
2179 FixedVectorType::get(BinOpTy->getElementType(), ShuffleDstTy);
2180 InstructionCost NewCost =
2181 TTI.getShuffleCost(SK0, ShuffleCmpTy, BinOpTy, NewMask0, CostKind, 0,
2182 nullptr, {X, Z}) +
2183 TTI.getShuffleCost(SK1, ShuffleCmpTy, BinOpTy, NewMask1, CostKind, 0,
2184 nullptr, {Y, W});
2185
2186 if (PredLHS == CmpInst::BAD_ICMP_PREDICATE) {
2187 NewCost +=
2188 TTI.getArithmeticInstrCost(LHS->getOpcode(), ShuffleDstTy, CostKind);
2189 } else {
2190 NewCost += TTI.getCmpSelInstrCost(LHS->getOpcode(), ShuffleCmpTy,
2191 ShuffleDstTy, PredLHS, CostKind);
2192 }
2193
2194 LLVM_DEBUG(dbgs() << "Found a shuffle feeding two binops: " << I
2195 << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
2196 << "\n");
2197
2198 // If either shuffle will constant fold away, then fold for the same cost as
2199 // we will reduce the instruction count.
2200 ReducedInstCount |= (isa<Constant>(X) && isa<Constant>(Z)) ||
2201 (isa<Constant>(Y) && isa<Constant>(W));
2202 if (ReducedInstCount ? (NewCost > OldCost) : (NewCost >= OldCost))
2203 return false;
2204
2205 Value *Shuf0 = Builder.CreateShuffleVector(X, Z, NewMask0);
2206 Value *Shuf1 = Builder.CreateShuffleVector(Y, W, NewMask1);
2207 Value *NewBO = PredLHS == CmpInst::BAD_ICMP_PREDICATE
2208 ? Builder.CreateBinOp(
2209 cast<BinaryOperator>(LHS)->getOpcode(), Shuf0, Shuf1)
2210 : Builder.CreateCmp(PredLHS, Shuf0, Shuf1);
2211
2212 // Intersect flags from the old binops.
2213 if (auto *NewInst = dyn_cast<Instruction>(NewBO)) {
2214 NewInst->copyIRFlags(LHS);
2215 NewInst->andIRFlags(RHS);
2216 }
2217
2218 Worklist.pushValue(Shuf0);
2219 Worklist.pushValue(Shuf1);
2220 replaceValue(I, *NewBO);
2221 return true;
2222 }
2223
2224 /// Try to convert,
2225 /// (shuffle(select(c1,t1,f1)), (select(c2,t2,f2)), m) into
2226 /// (select (shuffle c1,c2,m), (shuffle t1,t2,m), (shuffle f1,f2,m))
foldShuffleOfSelects(Instruction & I)2227 bool VectorCombine::foldShuffleOfSelects(Instruction &I) {
2228 ArrayRef<int> Mask;
2229 Value *C1, *T1, *F1, *C2, *T2, *F2;
2230 if (!match(&I, m_Shuffle(
2231 m_OneUse(m_Select(m_Value(C1), m_Value(T1), m_Value(F1))),
2232 m_OneUse(m_Select(m_Value(C2), m_Value(T2), m_Value(F2))),
2233 m_Mask(Mask))))
2234 return false;
2235
2236 auto *C1VecTy = dyn_cast<FixedVectorType>(C1->getType());
2237 auto *C2VecTy = dyn_cast<FixedVectorType>(C2->getType());
2238 if (!C1VecTy || !C2VecTy || C1VecTy != C2VecTy)
2239 return false;
2240
2241 auto *SI0FOp = dyn_cast<FPMathOperator>(I.getOperand(0));
2242 auto *SI1FOp = dyn_cast<FPMathOperator>(I.getOperand(1));
2243 // SelectInsts must have the same FMF.
2244 if (((SI0FOp == nullptr) != (SI1FOp == nullptr)) ||
2245 ((SI0FOp != nullptr) &&
2246 (SI0FOp->getFastMathFlags() != SI1FOp->getFastMathFlags())))
2247 return false;
2248
2249 auto *SrcVecTy = cast<FixedVectorType>(T1->getType());
2250 auto *DstVecTy = cast<FixedVectorType>(I.getType());
2251 auto SK = TargetTransformInfo::SK_PermuteTwoSrc;
2252 auto SelOp = Instruction::Select;
2253 InstructionCost OldCost = TTI.getCmpSelInstrCost(
2254 SelOp, SrcVecTy, C1VecTy, CmpInst::BAD_ICMP_PREDICATE, CostKind);
2255 OldCost += TTI.getCmpSelInstrCost(SelOp, SrcVecTy, C2VecTy,
2256 CmpInst::BAD_ICMP_PREDICATE, CostKind);
2257 OldCost +=
2258 TTI.getShuffleCost(SK, DstVecTy, SrcVecTy, Mask, CostKind, 0, nullptr,
2259 {I.getOperand(0), I.getOperand(1)}, &I);
2260
2261 InstructionCost NewCost = TTI.getShuffleCost(
2262 SK, FixedVectorType::get(C1VecTy->getScalarType(), Mask.size()), C1VecTy,
2263 Mask, CostKind, 0, nullptr, {C1, C2});
2264 NewCost += TTI.getShuffleCost(SK, DstVecTy, SrcVecTy, Mask, CostKind, 0,
2265 nullptr, {T1, T2});
2266 NewCost += TTI.getShuffleCost(SK, DstVecTy, SrcVecTy, Mask, CostKind, 0,
2267 nullptr, {F1, F2});
2268 auto *C1C2ShuffledVecTy = cast<FixedVectorType>(
2269 toVectorTy(Type::getInt1Ty(I.getContext()), DstVecTy->getNumElements()));
2270 NewCost += TTI.getCmpSelInstrCost(SelOp, DstVecTy, C1C2ShuffledVecTy,
2271 CmpInst::BAD_ICMP_PREDICATE, CostKind);
2272
2273 LLVM_DEBUG(dbgs() << "Found a shuffle feeding two selects: " << I
2274 << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
2275 << "\n");
2276 if (NewCost > OldCost)
2277 return false;
2278
2279 Value *ShuffleCmp = Builder.CreateShuffleVector(C1, C2, Mask);
2280 Value *ShuffleTrue = Builder.CreateShuffleVector(T1, T2, Mask);
2281 Value *ShuffleFalse = Builder.CreateShuffleVector(F1, F2, Mask);
2282 Value *NewSel;
2283 // We presuppose that the SelectInsts have the same FMF.
2284 if (SI0FOp)
2285 NewSel = Builder.CreateSelectFMF(ShuffleCmp, ShuffleTrue, ShuffleFalse,
2286 SI0FOp->getFastMathFlags());
2287 else
2288 NewSel = Builder.CreateSelect(ShuffleCmp, ShuffleTrue, ShuffleFalse);
2289
2290 Worklist.pushValue(ShuffleCmp);
2291 Worklist.pushValue(ShuffleTrue);
2292 Worklist.pushValue(ShuffleFalse);
2293 replaceValue(I, *NewSel);
2294 return true;
2295 }
2296
2297 /// Try to convert "shuffle (castop), (castop)" where both casts have matching
2298 /// opcodes and source types into "castop (shuffle)".
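/// Sketch (hypothetical IR, both casts from the same source type):
///   %c0 = zext <4 x i16> %a to <4 x i32>
///   %c1 = zext <4 x i16> %b to <4 x i32>
///   %r  = shufflevector <4 x i32> %c0, <4 x i32> %c1, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
/// -->
///   %s = shufflevector <4 x i16> %a, <4 x i16> %b, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
///   %r = zext <8 x i16> %s to <8 x i32>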
foldShuffleOfCastops(Instruction & I)2299 bool VectorCombine::foldShuffleOfCastops(Instruction &I) {
2300 Value *V0, *V1;
2301 ArrayRef<int> OldMask;
2302 if (!match(&I, m_Shuffle(m_Value(V0), m_Value(V1), m_Mask(OldMask))))
2303 return false;
2304
2305 auto *C0 = dyn_cast<CastInst>(V0);
2306 auto *C1 = dyn_cast<CastInst>(V1);
2307 if (!C0 || !C1)
2308 return false;
2309
2310 Instruction::CastOps Opcode = C0->getOpcode();
2311 if (C0->getSrcTy() != C1->getSrcTy())
2312 return false;
2313
2314 // Handle shuffle(zext_nneg(x), sext(y)) -> sext(shuffle(x,y)) folds.
2315 if (Opcode != C1->getOpcode()) {
2316 if (match(C0, m_SExtLike(m_Value())) && match(C1, m_SExtLike(m_Value())))
2317 Opcode = Instruction::SExt;
2318 else
2319 return false;
2320 }
2321
2322 auto *ShuffleDstTy = dyn_cast<FixedVectorType>(I.getType());
2323 auto *CastDstTy = dyn_cast<FixedVectorType>(C0->getDestTy());
2324 auto *CastSrcTy = dyn_cast<FixedVectorType>(C0->getSrcTy());
2325 if (!ShuffleDstTy || !CastDstTy || !CastSrcTy)
2326 return false;
2327
2328 unsigned NumSrcElts = CastSrcTy->getNumElements();
2329 unsigned NumDstElts = CastDstTy->getNumElements();
2330 assert((NumDstElts == NumSrcElts || Opcode == Instruction::BitCast) &&
2331 "Only bitcasts expected to alter src/dst element counts");
2332
2333   // Reject bitcasts where neither element count evenly divides the other,
2334 // e.g. <32 x i40> -> <40 x i32>
2335 if (NumDstElts != NumSrcElts && (NumSrcElts % NumDstElts) != 0 &&
2336 (NumDstElts % NumSrcElts) != 0)
2337 return false;
2338
2339 SmallVector<int, 16> NewMask;
2340 if (NumSrcElts >= NumDstElts) {
2341 // The bitcast is from wide to narrow/equal elements. The shuffle mask can
2342 // always be expanded to the equivalent form choosing narrower elements.
2343 assert(NumSrcElts % NumDstElts == 0 && "Unexpected shuffle mask");
2344 unsigned ScaleFactor = NumSrcElts / NumDstElts;
2345 narrowShuffleMaskElts(ScaleFactor, OldMask, NewMask);
2346 } else {
2347 // The bitcast is from narrow elements to wide elements. The shuffle mask
2348 // must choose consecutive elements to allow casting first.
2349 assert(NumDstElts % NumSrcElts == 0 && "Unexpected shuffle mask");
2350 unsigned ScaleFactor = NumDstElts / NumSrcElts;
2351 if (!widenShuffleMaskElts(ScaleFactor, OldMask, NewMask))
2352 return false;
2353 }
2354
2355 auto *NewShuffleDstTy =
2356 FixedVectorType::get(CastSrcTy->getScalarType(), NewMask.size());
2357
2358 // Try to replace a castop with a shuffle if the shuffle is not costly.
2359 InstructionCost CostC0 =
2360 TTI.getCastInstrCost(C0->getOpcode(), CastDstTy, CastSrcTy,
2361 TTI::CastContextHint::None, CostKind);
2362 InstructionCost CostC1 =
2363 TTI.getCastInstrCost(C1->getOpcode(), CastDstTy, CastSrcTy,
2364 TTI::CastContextHint::None, CostKind);
2365 InstructionCost OldCost = CostC0 + CostC1;
2366 OldCost +=
2367 TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, ShuffleDstTy,
2368 CastDstTy, OldMask, CostKind, 0, nullptr, {}, &I);
2369
2370 InstructionCost NewCost =
2371 TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, NewShuffleDstTy,
2372 CastSrcTy, NewMask, CostKind);
2373 NewCost += TTI.getCastInstrCost(Opcode, ShuffleDstTy, NewShuffleDstTy,
2374 TTI::CastContextHint::None, CostKind);
2375 if (!C0->hasOneUse())
2376 NewCost += CostC0;
2377 if (!C1->hasOneUse())
2378 NewCost += CostC1;
2379
2380 LLVM_DEBUG(dbgs() << "Found a shuffle feeding two casts: " << I
2381 << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
2382 << "\n");
2383 if (NewCost > OldCost)
2384 return false;
2385
2386 Value *Shuf = Builder.CreateShuffleVector(C0->getOperand(0),
2387 C1->getOperand(0), NewMask);
2388 Value *Cast = Builder.CreateCast(Opcode, Shuf, ShuffleDstTy);
2389
2390 // Intersect flags from the old casts.
2391 if (auto *NewInst = dyn_cast<Instruction>(Cast)) {
2392 NewInst->copyIRFlags(C0);
2393 NewInst->andIRFlags(C1);
2394 }
2395
2396 Worklist.pushValue(Shuf);
2397 replaceValue(I, *Cast);
2398 return true;
2399 }
2400
2401 /// Try to convert any of:
2402 /// "shuffle (shuffle x, y), (shuffle y, x)"
2403 /// "shuffle (shuffle x, undef), (shuffle y, undef)"
2404 /// "shuffle (shuffle x, undef), y"
2405 /// "shuffle x, (shuffle y, undef)"
2406 /// into "shuffle x, y".
foldShuffleOfShuffles(Instruction & I)2407 bool VectorCombine::foldShuffleOfShuffles(Instruction &I) {
2408 ArrayRef<int> OuterMask;
2409 Value *OuterV0, *OuterV1;
2410 if (!match(&I,
2411 m_Shuffle(m_Value(OuterV0), m_Value(OuterV1), m_Mask(OuterMask))))
2412 return false;
2413
2414 ArrayRef<int> InnerMask0, InnerMask1;
2415 Value *X0, *X1, *Y0, *Y1;
2416 bool Match0 =
2417 match(OuterV0, m_Shuffle(m_Value(X0), m_Value(Y0), m_Mask(InnerMask0)));
2418 bool Match1 =
2419 match(OuterV1, m_Shuffle(m_Value(X1), m_Value(Y1), m_Mask(InnerMask1)));
2420 if (!Match0 && !Match1)
2421 return false;
2422
2423 // If the outer shuffle is a permute, then create a fake inner all-poison
2424 // shuffle. This is easier than accounting for length-changing shuffles below.
2425 SmallVector<int, 16> PoisonMask1;
2426 if (!Match1 && isa<PoisonValue>(OuterV1)) {
2427 X1 = X0;
2428 Y1 = Y0;
2429 PoisonMask1.append(InnerMask0.size(), PoisonMaskElem);
2430 InnerMask1 = PoisonMask1;
2431 Match1 = true; // fake match
2432 }
2433
2434 X0 = Match0 ? X0 : OuterV0;
2435 Y0 = Match0 ? Y0 : OuterV0;
2436 X1 = Match1 ? X1 : OuterV1;
2437 Y1 = Match1 ? Y1 : OuterV1;
2438 auto *ShuffleDstTy = dyn_cast<FixedVectorType>(I.getType());
2439 auto *ShuffleSrcTy = dyn_cast<FixedVectorType>(X0->getType());
2440 auto *ShuffleImmTy = dyn_cast<FixedVectorType>(OuterV0->getType());
2441 if (!ShuffleDstTy || !ShuffleSrcTy || !ShuffleImmTy ||
2442 X0->getType() != X1->getType())
2443 return false;
2444
2445 unsigned NumSrcElts = ShuffleSrcTy->getNumElements();
2446 unsigned NumImmElts = ShuffleImmTy->getNumElements();
2447
2448   // Attempt to merge the shuffles, matching up to 2 source operands.
2449   // Replace any index into a poison arg with PoisonMaskElem.
2450   // Bail if either inner mask references an undef (but non-poison) arg.
2451 SmallVector<int, 16> NewMask(OuterMask);
2452 Value *NewX = nullptr, *NewY = nullptr;
2453 for (int &M : NewMask) {
2454 Value *Src = nullptr;
2455 if (0 <= M && M < (int)NumImmElts) {
2456 Src = OuterV0;
2457 if (Match0) {
2458 M = InnerMask0[M];
2459 Src = M >= (int)NumSrcElts ? Y0 : X0;
2460 M = M >= (int)NumSrcElts ? (M - NumSrcElts) : M;
2461 }
2462 } else if (M >= (int)NumImmElts) {
2463 Src = OuterV1;
2464 M -= NumImmElts;
2465 if (Match1) {
2466 M = InnerMask1[M];
2467 Src = M >= (int)NumSrcElts ? Y1 : X1;
2468 M = M >= (int)NumSrcElts ? (M - NumSrcElts) : M;
2469 }
2470 }
2471 if (Src && M != PoisonMaskElem) {
2472 assert(0 <= M && M < (int)NumSrcElts && "Unexpected shuffle mask index");
2473 if (isa<UndefValue>(Src)) {
2474         // We've referenced an undef element - if it's poison, update the shuffle
2475 // mask, else bail.
2476 if (!isa<PoisonValue>(Src))
2477 return false;
2478 M = PoisonMaskElem;
2479 continue;
2480 }
2481 if (!NewX || NewX == Src) {
2482 NewX = Src;
2483 continue;
2484 }
2485 if (!NewY || NewY == Src) {
2486 M += NumSrcElts;
2487 NewY = Src;
2488 continue;
2489 }
2490 return false;
2491 }
2492 }
2493
2494   if (!NewX) {
2495     // No lane references a real source, so the shuffle folds to poison.
         replaceValue(I, *PoisonValue::get(ShuffleDstTy));
         return true;
       }
2496 if (!NewY)
2497 NewY = PoisonValue::get(ShuffleSrcTy);
2498
2499 // Have we folded to an Identity shuffle?
2500 if (ShuffleVectorInst::isIdentityMask(NewMask, NumSrcElts)) {
2501 replaceValue(I, *NewX);
2502 return true;
2503 }
2504
2505 // Try to merge the shuffles if the new shuffle is not costly.
2506 InstructionCost InnerCost0 = 0;
2507 if (Match0)
2508 InnerCost0 = TTI.getInstructionCost(cast<User>(OuterV0), CostKind);
2509
2510 InstructionCost InnerCost1 = 0;
2511 if (Match1)
2512 InnerCost1 = TTI.getInstructionCost(cast<User>(OuterV1), CostKind);
2513
2514 InstructionCost OuterCost = TTI.getInstructionCost(&I, CostKind);
2515
2516 InstructionCost OldCost = InnerCost0 + InnerCost1 + OuterCost;
2517
2518 bool IsUnary = all_of(NewMask, [&](int M) { return M < (int)NumSrcElts; });
2519 TargetTransformInfo::ShuffleKind SK =
2520 IsUnary ? TargetTransformInfo::SK_PermuteSingleSrc
2521 : TargetTransformInfo::SK_PermuteTwoSrc;
2522 InstructionCost NewCost =
2523 TTI.getShuffleCost(SK, ShuffleDstTy, ShuffleSrcTy, NewMask, CostKind, 0,
2524 nullptr, {NewX, NewY});
2525 if (!OuterV0->hasOneUse())
2526 NewCost += InnerCost0;
2527 if (!OuterV1->hasOneUse())
2528 NewCost += InnerCost1;
2529
2530 LLVM_DEBUG(dbgs() << "Found a shuffle feeding two shuffles: " << I
2531 << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
2532 << "\n");
2533 if (NewCost > OldCost)
2534 return false;
2535
2536 Value *Shuf = Builder.CreateShuffleVector(NewX, NewY, NewMask);
2537 replaceValue(I, *Shuf);
2538 return true;
2539 }
2540
2541 /// Try to convert
2542 /// "shuffle (intrinsic), (intrinsic)" into "intrinsic (shuffle), (shuffle)".
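/// Sketch (hypothetical IR, both calls single-use and of the same intrinsic):
///   %m0 = call <4 x i32> @llvm.smax.v4i32(<4 x i32> %a, <4 x i32> %b)
///   %m1 = call <4 x i32> @llvm.smax.v4i32(<4 x i32> %c, <4 x i32> %d)
///   %r  = shufflevector <4 x i32> %m0, <4 x i32> %m1, <4 x i32> <i32 0, i32 5, i32 2, i32 7>
/// -->
///   %s0 = shufflevector <4 x i32> %a, <4 x i32> %c, <4 x i32> <i32 0, i32 5, i32 2, i32 7>
///   %s1 = shufflevector <4 x i32> %b, <4 x i32> %d, <4 x i32> <i32 0, i32 5, i32 2, i32 7>
///   %r  = call <4 x i32> @llvm.smax.v4i32(<4 x i32> %s0, <4 x i32> %s1)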
foldShuffleOfIntrinsics(Instruction & I)2543 bool VectorCombine::foldShuffleOfIntrinsics(Instruction &I) {
2544 Value *V0, *V1;
2545 ArrayRef<int> OldMask;
2546 if (!match(&I, m_Shuffle(m_OneUse(m_Value(V0)), m_OneUse(m_Value(V1)),
2547 m_Mask(OldMask))))
2548 return false;
2549
2550 auto *II0 = dyn_cast<IntrinsicInst>(V0);
2551 auto *II1 = dyn_cast<IntrinsicInst>(V1);
2552 if (!II0 || !II1)
2553 return false;
2554
2555 Intrinsic::ID IID = II0->getIntrinsicID();
2556 if (IID != II1->getIntrinsicID())
2557 return false;
2558
2559 auto *ShuffleDstTy = dyn_cast<FixedVectorType>(I.getType());
2560 auto *II0Ty = dyn_cast<FixedVectorType>(II0->getType());
2561 if (!ShuffleDstTy || !II0Ty)
2562 return false;
2563
2564 if (!isTriviallyVectorizable(IID))
2565 return false;
2566
2567 for (unsigned I = 0, E = II0->arg_size(); I != E; ++I)
2568 if (isVectorIntrinsicWithScalarOpAtArg(IID, I, &TTI) &&
2569 II0->getArgOperand(I) != II1->getArgOperand(I))
2570 return false;
2571
2572 InstructionCost OldCost =
2573 TTI.getIntrinsicInstrCost(IntrinsicCostAttributes(IID, *II0), CostKind) +
2574 TTI.getIntrinsicInstrCost(IntrinsicCostAttributes(IID, *II1), CostKind) +
2575 TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc, ShuffleDstTy,
2576 II0Ty, OldMask, CostKind, 0, nullptr, {II0, II1}, &I);
2577
2578 SmallVector<Type *> NewArgsTy;
2579 InstructionCost NewCost = 0;
2580 for (unsigned I = 0, E = II0->arg_size(); I != E; ++I) {
2581 if (isVectorIntrinsicWithScalarOpAtArg(IID, I, &TTI)) {
2582 NewArgsTy.push_back(II0->getArgOperand(I)->getType());
2583 } else {
2584 auto *VecTy = cast<FixedVectorType>(II0->getArgOperand(I)->getType());
2585 auto *ArgTy = FixedVectorType::get(VecTy->getElementType(),
2586 ShuffleDstTy->getNumElements());
2587 NewArgsTy.push_back(ArgTy);
2588 NewCost += TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc,
2589 ArgTy, VecTy, OldMask, CostKind);
2590 }
2591 }
2592 IntrinsicCostAttributes NewAttr(IID, ShuffleDstTy, NewArgsTy);
2593 NewCost += TTI.getIntrinsicInstrCost(NewAttr, CostKind);
2594
2595 LLVM_DEBUG(dbgs() << "Found a shuffle feeding two intrinsics: " << I
2596 << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
2597 << "\n");
2598
2599 if (NewCost > OldCost)
2600 return false;
2601
2602 SmallVector<Value *> NewArgs;
2603 for (unsigned I = 0, E = II0->arg_size(); I != E; ++I)
2604 if (isVectorIntrinsicWithScalarOpAtArg(IID, I, &TTI)) {
2605 NewArgs.push_back(II0->getArgOperand(I));
2606 } else {
2607 Value *Shuf = Builder.CreateShuffleVector(II0->getArgOperand(I),
2608 II1->getArgOperand(I), OldMask);
2609 NewArgs.push_back(Shuf);
2610 Worklist.pushValue(Shuf);
2611 }
2612 Value *NewIntrinsic = Builder.CreateIntrinsic(ShuffleDstTy, IID, NewArgs);
2613
2614 // Intersect flags from the old intrinsics.
2615 if (auto *NewInst = dyn_cast<Instruction>(NewIntrinsic)) {
2616 NewInst->copyIRFlags(II0);
2617 NewInst->andIRFlags(II1);
2618 }
2619
2620 replaceValue(I, *NewIntrinsic);
2621 return true;
2622 }
2623
2624 using InstLane = std::pair<Use *, int>;
2625
lookThroughShuffles(Use * U,int Lane)2626 static InstLane lookThroughShuffles(Use *U, int Lane) {
2627 while (auto *SV = dyn_cast<ShuffleVectorInst>(U->get())) {
2628 unsigned NumElts =
2629 cast<FixedVectorType>(SV->getOperand(0)->getType())->getNumElements();
2630 int M = SV->getMaskValue(Lane);
2631 if (M < 0)
2632 return {nullptr, PoisonMaskElem};
2633 if (static_cast<unsigned>(M) < NumElts) {
2634 U = &SV->getOperandUse(0);
2635 Lane = M;
2636 } else {
2637 U = &SV->getOperandUse(1);
2638 Lane = M - NumElts;
2639 }
2640 }
2641 return InstLane{U, Lane};
2642 }
2643
2644 static SmallVector<InstLane>
generateInstLaneVectorFromOperand(ArrayRef<InstLane> Item,int Op)2645 generateInstLaneVectorFromOperand(ArrayRef<InstLane> Item, int Op) {
2646 SmallVector<InstLane> NItem;
2647 for (InstLane IL : Item) {
2648 auto [U, Lane] = IL;
2649 InstLane OpLane =
2650 U ? lookThroughShuffles(&cast<Instruction>(U->get())->getOperandUse(Op),
2651 Lane)
2652 : InstLane{nullptr, PoisonMaskElem};
2653 NItem.emplace_back(OpLane);
2654 }
2655 return NItem;
2656 }
2657
2658 /// Detect concat of multiple values into a vector
isFreeConcat(ArrayRef<InstLane> Item,TTI::TargetCostKind CostKind,const TargetTransformInfo & TTI)2659 static bool isFreeConcat(ArrayRef<InstLane> Item, TTI::TargetCostKind CostKind,
2660 const TargetTransformInfo &TTI) {
2661 auto *Ty = cast<FixedVectorType>(Item.front().first->get()->getType());
2662 unsigned NumElts = Ty->getNumElements();
2663 if (Item.size() == NumElts || NumElts == 1 || Item.size() % NumElts != 0)
2664 return false;
2665
2666 // Check that the concat is free, usually meaning that the type will be split
2667 // during legalization.
2668 SmallVector<int, 16> ConcatMask(NumElts * 2);
2669 std::iota(ConcatMask.begin(), ConcatMask.end(), 0);
2670 if (TTI.getShuffleCost(TTI::SK_PermuteTwoSrc,
2671 FixedVectorType::get(Ty->getScalarType(), NumElts * 2),
2672 Ty, ConcatMask, CostKind) != 0)
2673 return false;
2674
2675 unsigned NumSlices = Item.size() / NumElts;
2676 // Currently we generate a tree of shuffles for the concats, which limits us
2677   // to a power of 2.
2678 if (!isPowerOf2_32(NumSlices))
2679 return false;
2680 for (unsigned Slice = 0; Slice < NumSlices; ++Slice) {
2681 Use *SliceV = Item[Slice * NumElts].first;
2682 if (!SliceV || SliceV->get()->getType() != Ty)
2683 return false;
2684 for (unsigned Elt = 0; Elt < NumElts; ++Elt) {
2685 auto [V, Lane] = Item[Slice * NumElts + Elt];
2686 if (Lane != static_cast<int>(Elt) || SliceV->get() != V->get())
2687 return false;
2688 }
2689 }
2690 return true;
2691 }
2692
generateNewInstTree(ArrayRef<InstLane> Item,FixedVectorType * Ty,const SmallPtrSet<Use *,4> & IdentityLeafs,const SmallPtrSet<Use *,4> & SplatLeafs,const SmallPtrSet<Use *,4> & ConcatLeafs,IRBuilderBase & Builder,const TargetTransformInfo * TTI)2693 static Value *generateNewInstTree(ArrayRef<InstLane> Item, FixedVectorType *Ty,
2694 const SmallPtrSet<Use *, 4> &IdentityLeafs,
2695 const SmallPtrSet<Use *, 4> &SplatLeafs,
2696 const SmallPtrSet<Use *, 4> &ConcatLeafs,
2697 IRBuilderBase &Builder,
2698 const TargetTransformInfo *TTI) {
2699 auto [FrontU, FrontLane] = Item.front();
2700
2701 if (IdentityLeafs.contains(FrontU)) {
2702 return FrontU->get();
2703 }
2704 if (SplatLeafs.contains(FrontU)) {
2705 SmallVector<int, 16> Mask(Ty->getNumElements(), FrontLane);
2706 return Builder.CreateShuffleVector(FrontU->get(), Mask);
2707 }
2708 if (ConcatLeafs.contains(FrontU)) {
2709 unsigned NumElts =
2710 cast<FixedVectorType>(FrontU->get()->getType())->getNumElements();
2711 SmallVector<Value *> Values(Item.size() / NumElts, nullptr);
2712 for (unsigned S = 0; S < Values.size(); ++S)
2713 Values[S] = Item[S * NumElts].first->get();
2714
2715 while (Values.size() > 1) {
2716 NumElts *= 2;
2717 SmallVector<int, 16> Mask(NumElts, 0);
2718 std::iota(Mask.begin(), Mask.end(), 0);
2719 SmallVector<Value *> NewValues(Values.size() / 2, nullptr);
2720 for (unsigned S = 0; S < NewValues.size(); ++S)
2721 NewValues[S] =
2722 Builder.CreateShuffleVector(Values[S * 2], Values[S * 2 + 1], Mask);
2723 Values = NewValues;
2724 }
2725 return Values[0];
2726 }
2727
2728 auto *I = cast<Instruction>(FrontU->get());
2729 auto *II = dyn_cast<IntrinsicInst>(I);
2730 unsigned NumOps = I->getNumOperands() - (II ? 1 : 0);
2731 SmallVector<Value *> Ops(NumOps);
2732 for (unsigned Idx = 0; Idx < NumOps; Idx++) {
2733 if (II &&
2734 isVectorIntrinsicWithScalarOpAtArg(II->getIntrinsicID(), Idx, TTI)) {
2735 Ops[Idx] = II->getOperand(Idx);
2736 continue;
2737 }
2738 Ops[Idx] = generateNewInstTree(generateInstLaneVectorFromOperand(Item, Idx),
2739 Ty, IdentityLeafs, SplatLeafs, ConcatLeafs,
2740 Builder, TTI);
2741 }
2742
2743 SmallVector<Value *, 8> ValueList;
2744 for (const auto &Lane : Item)
2745 if (Lane.first)
2746 ValueList.push_back(Lane.first->get());
2747
2748 Type *DstTy =
2749 FixedVectorType::get(I->getType()->getScalarType(), Ty->getNumElements());
2750 if (auto *BI = dyn_cast<BinaryOperator>(I)) {
2751 auto *Value = Builder.CreateBinOp((Instruction::BinaryOps)BI->getOpcode(),
2752 Ops[0], Ops[1]);
2753 propagateIRFlags(Value, ValueList);
2754 return Value;
2755 }
2756 if (auto *CI = dyn_cast<CmpInst>(I)) {
2757 auto *Value = Builder.CreateCmp(CI->getPredicate(), Ops[0], Ops[1]);
2758 propagateIRFlags(Value, ValueList);
2759 return Value;
2760 }
2761 if (auto *SI = dyn_cast<SelectInst>(I)) {
2762 auto *Value = Builder.CreateSelect(Ops[0], Ops[1], Ops[2], "", SI);
2763 propagateIRFlags(Value, ValueList);
2764 return Value;
2765 }
2766 if (auto *CI = dyn_cast<CastInst>(I)) {
2767 auto *Value = Builder.CreateCast((Instruction::CastOps)CI->getOpcode(),
2768 Ops[0], DstTy);
2769 propagateIRFlags(Value, ValueList);
2770 return Value;
2771 }
2772 if (II) {
2773 auto *Value = Builder.CreateIntrinsic(DstTy, II->getIntrinsicID(), Ops);
2774 propagateIRFlags(Value, ValueList);
2775 return Value;
2776 }
2777 assert(isa<UnaryInstruction>(I) && "Unexpected instruction type in Generate");
2778 auto *Value =
2779 Builder.CreateUnOp((Instruction::UnaryOps)I->getOpcode(), Ops[0]);
2780 propagateIRFlags(Value, ValueList);
2781 return Value;
2782 }
2783
2784 // Starting from a shuffle, look up through the operands, tracking the shuffled
2785 // index of each lane. If we can simplify away the shuffles to identities then
2786 // do so.
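// An illustrative example (assumed, not taken from the original comments): given
//   %a = shufflevector <4 x i32> %x, <4 x i32> poison, <4 x i32> <i32 3, i32 2, i32 1, i32 0>
//   %b = add <4 x i32> %a, <i32 1, i32 1, i32 1, i32 1>
//   %r = shufflevector <4 x i32> %b, <4 x i32> poison, <4 x i32> <i32 3, i32 2, i32 1, i32 0>
// the two reversals cancel lane-for-lane, so %r can be rebuilt as
//   add <4 x i32> %x, <i32 1, i32 1, i32 1, i32 1>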
2787 bool VectorCombine::foldShuffleToIdentity(Instruction &I) {
2788 auto *Ty = dyn_cast<FixedVectorType>(I.getType());
2789 if (!Ty || I.use_empty())
2790 return false;
2791
2792 SmallVector<InstLane> Start(Ty->getNumElements());
2793 for (unsigned M = 0, E = Ty->getNumElements(); M < E; ++M)
2794 Start[M] = lookThroughShuffles(&*I.use_begin(), M);
2795
2796 SmallVector<SmallVector<InstLane>> Worklist;
2797 Worklist.push_back(Start);
2798 SmallPtrSet<Use *, 4> IdentityLeafs, SplatLeafs, ConcatLeafs;
2799 unsigned NumVisited = 0;
2800
2801 while (!Worklist.empty()) {
2802 if (++NumVisited > MaxInstrsToScan)
2803 return false;
2804
2805 SmallVector<InstLane> Item = Worklist.pop_back_val();
2806 auto [FrontU, FrontLane] = Item.front();
2807
2808 // If we found an undef first lane then bail out to keep things simple.
2809 if (!FrontU)
2810 return false;
2811
2812 // Helper to peek through bitcasts to the same value.
2813 auto IsEquiv = [&](Value *X, Value *Y) {
2814 return X->getType() == Y->getType() &&
2815 peekThroughBitcasts(X) == peekThroughBitcasts(Y);
2816 };
2817
2818 // Look for an identity value.
2819 if (FrontLane == 0 &&
2820 cast<FixedVectorType>(FrontU->get()->getType())->getNumElements() ==
2821 Ty->getNumElements() &&
2822 all_of(drop_begin(enumerate(Item)), [IsEquiv, Item](const auto &E) {
2823 Value *FrontV = Item.front().first->get();
2824 return !E.value().first || (IsEquiv(E.value().first->get(), FrontV) &&
2825 E.value().second == (int)E.index());
2826 })) {
2827 IdentityLeafs.insert(FrontU);
2828 continue;
2829 }
2830 // Look for constants, for the moment only supporting constant splats.
2831 if (auto *C = dyn_cast<Constant>(FrontU);
2832 C && C->getSplatValue() &&
2833 all_of(drop_begin(Item), [Item](InstLane &IL) {
2834 Value *FrontV = Item.front().first->get();
2835 Use *U = IL.first;
2836 return !U || (isa<Constant>(U->get()) &&
2837 cast<Constant>(U->get())->getSplatValue() ==
2838 cast<Constant>(FrontV)->getSplatValue());
2839 })) {
2840 SplatLeafs.insert(FrontU);
2841 continue;
2842 }
2843 // Look for a splat value.
2844 if (all_of(drop_begin(Item), [Item](InstLane &IL) {
2845 auto [FrontU, FrontLane] = Item.front();
2846 auto [U, Lane] = IL;
2847 return !U || (U->get() == FrontU->get() && Lane == FrontLane);
2848 })) {
2849 SplatLeafs.insert(FrontU);
2850 continue;
2851 }
2852
2853 // We need each element to be the same type of value, and check that each
2854 // element has a single use.
2855 auto CheckLaneIsEquivalentToFirst = [Item](InstLane IL) {
2856 Value *FrontV = Item.front().first->get();
2857 if (!IL.first)
2858 return true;
2859 Value *V = IL.first->get();
2860 if (auto *I = dyn_cast<Instruction>(V); I && !I->hasOneUse())
2861 return false;
2862 if (V->getValueID() != FrontV->getValueID())
2863 return false;
2864 if (auto *CI = dyn_cast<CmpInst>(V))
2865 if (CI->getPredicate() != cast<CmpInst>(FrontV)->getPredicate())
2866 return false;
2867 if (auto *CI = dyn_cast<CastInst>(V))
2868 if (CI->getSrcTy()->getScalarType() !=
2869 cast<CastInst>(FrontV)->getSrcTy()->getScalarType())
2870 return false;
2871 if (auto *SI = dyn_cast<SelectInst>(V))
2872 if (!isa<VectorType>(SI->getOperand(0)->getType()) ||
2873 SI->getOperand(0)->getType() !=
2874 cast<SelectInst>(FrontV)->getOperand(0)->getType())
2875 return false;
2876 if (isa<CallInst>(V) && !isa<IntrinsicInst>(V))
2877 return false;
2878 auto *II = dyn_cast<IntrinsicInst>(V);
2879 return !II || (isa<IntrinsicInst>(FrontV) &&
2880 II->getIntrinsicID() ==
2881 cast<IntrinsicInst>(FrontV)->getIntrinsicID() &&
2882 !II->hasOperandBundles());
2883 };
2884 if (all_of(drop_begin(Item), CheckLaneIsEquivalentToFirst)) {
2885 // Check the operator is one that we support.
2886 if (isa<BinaryOperator, CmpInst>(FrontU)) {
2887 // We exclude div/rem in case they hit UB from poison lanes.
2888 if (auto *BO = dyn_cast<BinaryOperator>(FrontU);
2889 BO && BO->isIntDivRem())
2890 return false;
2891 Worklist.push_back(generateInstLaneVectorFromOperand(Item, 0));
2892 Worklist.push_back(generateInstLaneVectorFromOperand(Item, 1));
2893 continue;
2894 } else if (isa<UnaryOperator, TruncInst, ZExtInst, SExtInst, FPToSIInst,
2895 FPToUIInst, SIToFPInst, UIToFPInst>(FrontU)) {
2896 Worklist.push_back(generateInstLaneVectorFromOperand(Item, 0));
2897 continue;
2898 } else if (auto *BitCast = dyn_cast<BitCastInst>(FrontU)) {
2899 // TODO: Handle vector widening/narrowing bitcasts.
2900 auto *DstTy = dyn_cast<FixedVectorType>(BitCast->getDestTy());
2901 auto *SrcTy = dyn_cast<FixedVectorType>(BitCast->getSrcTy());
2902 if (DstTy && SrcTy &&
2903 SrcTy->getNumElements() == DstTy->getNumElements()) {
2904 Worklist.push_back(generateInstLaneVectorFromOperand(Item, 0));
2905 continue;
2906 }
2907 } else if (isa<SelectInst>(FrontU)) {
2908 Worklist.push_back(generateInstLaneVectorFromOperand(Item, 0));
2909 Worklist.push_back(generateInstLaneVectorFromOperand(Item, 1));
2910 Worklist.push_back(generateInstLaneVectorFromOperand(Item, 2));
2911 continue;
2912 } else if (auto *II = dyn_cast<IntrinsicInst>(FrontU);
2913 II && isTriviallyVectorizable(II->getIntrinsicID()) &&
2914 !II->hasOperandBundles()) {
2915 for (unsigned Op = 0, E = II->getNumOperands() - 1; Op < E; Op++) {
2916 if (isVectorIntrinsicWithScalarOpAtArg(II->getIntrinsicID(), Op,
2917 &TTI)) {
2918 if (!all_of(drop_begin(Item), [Item, Op](InstLane &IL) {
2919 Value *FrontV = Item.front().first->get();
2920 Use *U = IL.first;
2921 return !U || (cast<Instruction>(U->get())->getOperand(Op) ==
2922 cast<Instruction>(FrontV)->getOperand(Op));
2923 }))
2924 return false;
2925 continue;
2926 }
2927 Worklist.push_back(generateInstLaneVectorFromOperand(Item, Op));
2928 }
2929 continue;
2930 }
2931 }
2932
2933 if (isFreeConcat(Item, CostKind, TTI)) {
2934 ConcatLeafs.insert(FrontU);
2935 continue;
2936 }
2937
2938 return false;
2939 }
2940
2941 if (NumVisited <= 1)
2942 return false;
2943
2944 LLVM_DEBUG(dbgs() << "Found a superfluous identity shuffle: " << I << "\n");
2945
2946 // If we got this far, we know the shuffles are superfluous and can be
2947 // removed. Scan through again and generate the new tree of instructions.
2948 Builder.SetInsertPoint(&I);
2949 Value *V = generateNewInstTree(Start, Ty, IdentityLeafs, SplatLeafs,
2950 ConcatLeafs, Builder, &TTI);
2951 replaceValue(I, *V);
2952 return true;
2953 }
2954
2955 /// Given a commutative reduction, the order of the input lanes does not alter
2956 /// the results. We can use this to remove certain shuffles feeding the
2957 /// reduction, removing the need to shuffle at all.
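/// An illustrative example (assumed, not from the original comment):
///   %s = shufflevector <4 x i32> %x, <4 x i32> poison, <4 x i32> <i32 3, i32 2, i32 1, i32 0>
///   %r = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %s)
/// is equivalent to reducing %x directly, so the shuffle can be dropped (or
/// replaced by a cheaper in-order mask) when the cost model agrees.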
2958 bool VectorCombine::foldShuffleFromReductions(Instruction &I) {
2959 auto *II = dyn_cast<IntrinsicInst>(&I);
2960 if (!II)
2961 return false;
2962 switch (II->getIntrinsicID()) {
2963 case Intrinsic::vector_reduce_add:
2964 case Intrinsic::vector_reduce_mul:
2965 case Intrinsic::vector_reduce_and:
2966 case Intrinsic::vector_reduce_or:
2967 case Intrinsic::vector_reduce_xor:
2968 case Intrinsic::vector_reduce_smin:
2969 case Intrinsic::vector_reduce_smax:
2970 case Intrinsic::vector_reduce_umin:
2971 case Intrinsic::vector_reduce_umax:
2972 break;
2973 default:
2974 return false;
2975 }
2976
2977 // Find all the inputs when looking through operations that do not alter the
2978 // lane order (binops, for example). Currently we look for a single shuffle,
2979 // and can ignore splat values.
2980 std::queue<Value *> Worklist;
2981 SmallPtrSet<Value *, 4> Visited;
2982 ShuffleVectorInst *Shuffle = nullptr;
2983 if (auto *Op = dyn_cast<Instruction>(I.getOperand(0)))
2984 Worklist.push(Op);
2985
2986 while (!Worklist.empty()) {
2987 Value *CV = Worklist.front();
2988 Worklist.pop();
2989 if (Visited.contains(CV))
2990 continue;
2991
2992 // Splats don't change the order, so can be safely ignored.
2993 if (isSplatValue(CV))
2994 continue;
2995
2996 Visited.insert(CV);
2997
2998 if (auto *CI = dyn_cast<Instruction>(CV)) {
2999 if (CI->isBinaryOp()) {
3000 for (auto *Op : CI->operand_values())
3001 Worklist.push(Op);
3002 continue;
3003 } else if (auto *SV = dyn_cast<ShuffleVectorInst>(CI)) {
3004 if (Shuffle && Shuffle != SV)
3005 return false;
3006 Shuffle = SV;
3007 continue;
3008 }
3009 }
3010
3011 // Anything else is currently an unknown node.
3012 return false;
3013 }
3014
3015 if (!Shuffle)
3016 return false;
3017
3018 // Check all uses of the binary ops and shuffles are also included in the
3019 // lane-invariant operations (Visited should be the list of lanewise
3020 // instructions, including the shuffle that we found).
3021 for (auto *V : Visited)
3022 for (auto *U : V->users())
3023 if (!Visited.contains(U) && U != &I)
3024 return false;
3025
3026 FixedVectorType *VecType =
3027 dyn_cast<FixedVectorType>(II->getOperand(0)->getType());
3028 if (!VecType)
3029 return false;
3030 FixedVectorType *ShuffleInputType =
3031 dyn_cast<FixedVectorType>(Shuffle->getOperand(0)->getType());
3032 if (!ShuffleInputType)
3033 return false;
3034 unsigned NumInputElts = ShuffleInputType->getNumElements();
3035
3036 // Find the mask from sorting the lanes into order. This is most likely to
3037 // become an identity or concat mask. Undef elements are pushed to the end.
3038 SmallVector<int> ConcatMask;
3039 Shuffle->getShuffleMask(ConcatMask);
3040 sort(ConcatMask, [](int X, int Y) { return (unsigned)X < (unsigned)Y; });
3041 bool UsesSecondVec =
3042 any_of(ConcatMask, [&](int M) { return M >= (int)NumInputElts; });
3043
3044 InstructionCost OldCost = TTI.getShuffleCost(
3045 UsesSecondVec ? TTI::SK_PermuteTwoSrc : TTI::SK_PermuteSingleSrc, VecType,
3046 ShuffleInputType, Shuffle->getShuffleMask(), CostKind);
3047 InstructionCost NewCost = TTI.getShuffleCost(
3048 UsesSecondVec ? TTI::SK_PermuteTwoSrc : TTI::SK_PermuteSingleSrc, VecType,
3049 ShuffleInputType, ConcatMask, CostKind);
3050
3051 LLVM_DEBUG(dbgs() << "Found a reduction feeding from a shuffle: " << *Shuffle
3052 << "\n");
3053 LLVM_DEBUG(dbgs() << " OldCost: " << OldCost << " vs NewCost: " << NewCost
3054 << "\n");
3055 bool MadeChanges = false;
3056 if (NewCost < OldCost) {
3057 Builder.SetInsertPoint(Shuffle);
3058 Value *NewShuffle = Builder.CreateShuffleVector(
3059 Shuffle->getOperand(0), Shuffle->getOperand(1), ConcatMask);
3060 LLVM_DEBUG(dbgs() << "Created new shuffle: " << *NewShuffle << "\n");
3061 replaceValue(*Shuffle, *NewShuffle);
3062 MadeChanges = true;
3063 }
3064
3065 // See if we can re-use foldSelectShuffle, getting it to reduce the size of
3066 // the shuffle into a nicer order, as it can ignore the order of the shuffles.
3067 MadeChanges |= foldSelectShuffle(*Shuffle, true);
3068 return MadeChanges;
3069 }
3070
3071 /// Determine if it's more efficient to fold:
3072 /// reduce(trunc(x)) -> trunc(reduce(x)).
3073 /// reduce(sext(x)) -> sext(reduce(x)).
3074 /// reduce(zext(x)) -> zext(reduce(x)).
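/// An illustrative example (assumed): with a single-use trunc feeding the
/// reduction,
///   %t = trunc <8 x i32> %x to <8 x i8>
///   %r = call i8 @llvm.vector.reduce.add.v8i8(<8 x i8> %t)
/// may be rewritten as
///   %w = call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> %x)
///   %r = trunc i32 %w to i8
/// when the wider reduction plus the scalar cast is cheaper.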
3075 bool VectorCombine::foldCastFromReductions(Instruction &I) {
3076 auto *II = dyn_cast<IntrinsicInst>(&I);
3077 if (!II)
3078 return false;
3079
3080 bool TruncOnly = false;
3081 Intrinsic::ID IID = II->getIntrinsicID();
3082 switch (IID) {
3083 case Intrinsic::vector_reduce_add:
3084 case Intrinsic::vector_reduce_mul:
3085 TruncOnly = true;
3086 break;
3087 case Intrinsic::vector_reduce_and:
3088 case Intrinsic::vector_reduce_or:
3089 case Intrinsic::vector_reduce_xor:
3090 break;
3091 default:
3092 return false;
3093 }
3094
3095 unsigned ReductionOpc = getArithmeticReductionInstruction(IID);
3096 Value *ReductionSrc = I.getOperand(0);
3097
3098 Value *Src;
3099 if (!match(ReductionSrc, m_OneUse(m_Trunc(m_Value(Src)))) &&
3100 (TruncOnly || !match(ReductionSrc, m_OneUse(m_ZExtOrSExt(m_Value(Src))))))
3101 return false;
3102
3103 auto CastOpc =
3104 (Instruction::CastOps)cast<Instruction>(ReductionSrc)->getOpcode();
3105
3106 auto *SrcTy = cast<VectorType>(Src->getType());
3107 auto *ReductionSrcTy = cast<VectorType>(ReductionSrc->getType());
3108 Type *ResultTy = I.getType();
3109
3110 InstructionCost OldCost = TTI.getArithmeticReductionCost(
3111 ReductionOpc, ReductionSrcTy, std::nullopt, CostKind);
3112 OldCost += TTI.getCastInstrCost(CastOpc, ReductionSrcTy, SrcTy,
3113 TTI::CastContextHint::None, CostKind,
3114 cast<CastInst>(ReductionSrc));
3115 InstructionCost NewCost =
3116 TTI.getArithmeticReductionCost(ReductionOpc, SrcTy, std::nullopt,
3117 CostKind) +
3118 TTI.getCastInstrCost(CastOpc, ResultTy, ReductionSrcTy->getScalarType(),
3119 TTI::CastContextHint::None, CostKind);
3120
3121 if (OldCost <= NewCost || !NewCost.isValid())
3122 return false;
3123
3124 Value *NewReduction = Builder.CreateIntrinsic(SrcTy->getScalarType(),
3125 II->getIntrinsicID(), {Src});
3126 Value *NewCast = Builder.CreateCast(CastOpc, NewReduction, ResultTy);
3127 replaceValue(I, *NewCast);
3128 return true;
3129 }
3130
3131 /// This method looks for groups of shuffles acting on binops, of the form:
3132 /// %x = shuffle ...
3133 /// %y = shuffle ...
3134 /// %a = binop %x, %y
3135 /// %b = binop %x, %y
3136 /// shuffle %a, %b, selectmask
3137 /// We may, especially if the shuffle is wider than legal, be able to convert
3138 /// the shuffle to a form where only parts of a and b need to be computed. On
3139 /// architectures with no obvious "select" shuffle, this can reduce the total
3140 /// number of operations if the target reports them as cheaper.
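/// For instance (illustrative): if selectmask takes its low lanes from %a and
/// its high lanes from %b, only those lanes of each binop are actually used,
/// so the rewrite packs the used lanes together, computes %a and %b on
/// narrower vectors, and shuffles the results back into the original order.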
3141 bool VectorCombine::foldSelectShuffle(Instruction &I, bool FromReduction) {
3142 auto *SVI = cast<ShuffleVectorInst>(&I);
3143 auto *VT = cast<FixedVectorType>(I.getType());
3144 auto *Op0 = dyn_cast<Instruction>(SVI->getOperand(0));
3145 auto *Op1 = dyn_cast<Instruction>(SVI->getOperand(1));
3146 if (!Op0 || !Op1 || Op0 == Op1 || !Op0->isBinaryOp() || !Op1->isBinaryOp() ||
3147 VT != Op0->getType())
3148 return false;
3149
3150 auto *SVI0A = dyn_cast<Instruction>(Op0->getOperand(0));
3151 auto *SVI0B = dyn_cast<Instruction>(Op0->getOperand(1));
3152 auto *SVI1A = dyn_cast<Instruction>(Op1->getOperand(0));
3153 auto *SVI1B = dyn_cast<Instruction>(Op1->getOperand(1));
3154 SmallPtrSet<Instruction *, 4> InputShuffles({SVI0A, SVI0B, SVI1A, SVI1B});
3155 auto checkSVNonOpUses = [&](Instruction *I) {
3156 if (!I || I->getOperand(0)->getType() != VT)
3157 return true;
3158 return any_of(I->users(), [&](User *U) {
3159 return U != Op0 && U != Op1 &&
3160 !(isa<ShuffleVectorInst>(U) &&
3161 (InputShuffles.contains(cast<Instruction>(U)) ||
3162 isInstructionTriviallyDead(cast<Instruction>(U))));
3163 });
3164 };
3165 if (checkSVNonOpUses(SVI0A) || checkSVNonOpUses(SVI0B) ||
3166 checkSVNonOpUses(SVI1A) || checkSVNonOpUses(SVI1B))
3167 return false;
3168
3169 // Collect all the uses that are shuffles that we can transform together. We
3170 // may not have a single shuffle, but a group that can all be transformed
3171 // together profitably.
3172 SmallVector<ShuffleVectorInst *> Shuffles;
3173 auto collectShuffles = [&](Instruction *I) {
3174 for (auto *U : I->users()) {
3175 auto *SV = dyn_cast<ShuffleVectorInst>(U);
3176 if (!SV || SV->getType() != VT)
3177 return false;
3178 if ((SV->getOperand(0) != Op0 && SV->getOperand(0) != Op1) ||
3179 (SV->getOperand(1) != Op0 && SV->getOperand(1) != Op1))
3180 return false;
3181 if (!llvm::is_contained(Shuffles, SV))
3182 Shuffles.push_back(SV);
3183 }
3184 return true;
3185 };
3186 if (!collectShuffles(Op0) || !collectShuffles(Op1))
3187 return false;
3188 // From a reduction, we need to be processing a single shuffle, otherwise the
3189 // other uses will not be lane-invariant.
3190 if (FromReduction && Shuffles.size() > 1)
3191 return false;
3192
3193 // Add any shuffle uses for the shuffles we have found, to include them in our
3194 // cost calculations.
3195 if (!FromReduction) {
3196 for (ShuffleVectorInst *SV : Shuffles) {
3197 for (auto *U : SV->users()) {
3198 ShuffleVectorInst *SSV = dyn_cast<ShuffleVectorInst>(U);
3199 if (SSV && isa<UndefValue>(SSV->getOperand(1)) && SSV->getType() == VT)
3200 Shuffles.push_back(SSV);
3201 }
3202 }
3203 }
3204
3205 // For each of the output shuffles, we try to sort all the first vector
3206 // elements to the beginning, followed by the second vector elements at the
3207 // end. If the binops are legalized to smaller vectors, this may reduce the total
3208 // number of binops. We compute the ReconstructMask mask needed to convert
3209 // back to the original lane order.
3210 SmallVector<std::pair<int, int>> V1, V2;
3211 SmallVector<SmallVector<int>> OrigReconstructMasks;
3212 int MaxV1Elt = 0, MaxV2Elt = 0;
3213 unsigned NumElts = VT->getNumElements();
3214 for (ShuffleVectorInst *SVN : Shuffles) {
3215 SmallVector<int> Mask;
3216 SVN->getShuffleMask(Mask);
3217
3218 // Check the operands are the same as the original, or reversed (in which
3219 // case we need to commute the mask).
3220 Value *SVOp0 = SVN->getOperand(0);
3221 Value *SVOp1 = SVN->getOperand(1);
3222 if (isa<UndefValue>(SVOp1)) {
3223 auto *SSV = cast<ShuffleVectorInst>(SVOp0);
3224 SVOp0 = SSV->getOperand(0);
3225 SVOp1 = SSV->getOperand(1);
3226 for (int &Elem : Mask) {
3227 if (Elem >= static_cast<int>(SSV->getShuffleMask().size()))
3228 return false;
3229 Elem = Elem < 0 ? Elem : SSV->getMaskValue(Elem);
3230 }
3231 }
3232 if (SVOp0 == Op1 && SVOp1 == Op0) {
3233 std::swap(SVOp0, SVOp1);
3234 ShuffleVectorInst::commuteShuffleMask(Mask, NumElts);
3235 }
3236 if (SVOp0 != Op0 || SVOp1 != Op1)
3237 return false;
3238
3239 // Calculate the reconstruction mask for this shuffle, i.e. the mask needed
3240 // to take the packed values from Op0/Op1 and reconstruct the original lane
3241 // order.
3242 SmallVector<int> ReconstructMask;
3243 for (unsigned I = 0; I < Mask.size(); I++) {
3244 if (Mask[I] < 0) {
3245 ReconstructMask.push_back(-1);
3246 } else if (Mask[I] < static_cast<int>(NumElts)) {
3247 MaxV1Elt = std::max(MaxV1Elt, Mask[I]);
3248 auto It = find_if(V1, [&](const std::pair<int, int> &A) {
3249 return Mask[I] == A.first;
3250 });
3251 if (It != V1.end())
3252 ReconstructMask.push_back(It - V1.begin());
3253 else {
3254 ReconstructMask.push_back(V1.size());
3255 V1.emplace_back(Mask[I], V1.size());
3256 }
3257 } else {
3258 MaxV2Elt = std::max<int>(MaxV2Elt, Mask[I] - NumElts);
3259 auto It = find_if(V2, [&](const std::pair<int, int> &A) {
3260 return Mask[I] - static_cast<int>(NumElts) == A.first;
3261 });
3262 if (It != V2.end())
3263 ReconstructMask.push_back(NumElts + It - V2.begin());
3264 else {
3265 ReconstructMask.push_back(NumElts + V2.size());
3266 V2.emplace_back(Mask[I] - NumElts, NumElts + V2.size());
3267 }
3268 }
3269 }
3270
3271 // For reductions, we know that the output lane ordering doesn't alter the
3272 // result. An in-order mask can help simplify the shuffle away.
3273 if (FromReduction)
3274 sort(ReconstructMask);
3275 OrigReconstructMasks.push_back(std::move(ReconstructMask));
3276 }
3277
3278 // If the maximum elements used from V1 and V2 are not larger than the new
3279 // vectors, the vectors are already packed and performing the optimization
3280 // again will likely not help any further. This also prevents us from getting
3281 // stuck in a cycle in case the costs do not also rule it out.
3282 if (V1.empty() || V2.empty() ||
3283 (MaxV1Elt == static_cast<int>(V1.size()) - 1 &&
3284 MaxV2Elt == static_cast<int>(V2.size()) - 1))
3285 return false;
3286
3287 // GetBaseMaskValue takes one of the inputs, which may either be a shuffle, a
3288 // shuffle of another shuffle, or not a shuffle (that is treated like an
3289 // identity shuffle).
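// E.g. (illustrative): for a two-operand shuffle with mask <2, 7, 0, 5>,
// GetBaseMaskValue(SV, 1) returns 7; for a single-operand shuffle of another
// input shuffle, the outer mask is looked through to the underlying index.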
3290 auto GetBaseMaskValue = [&](Instruction *I, int M) {
3291 auto *SV = dyn_cast<ShuffleVectorInst>(I);
3292 if (!SV)
3293 return M;
3294 if (isa<UndefValue>(SV->getOperand(1)))
3295 if (auto *SSV = dyn_cast<ShuffleVectorInst>(SV->getOperand(0)))
3296 if (InputShuffles.contains(SSV))
3297 return SSV->getMaskValue(SV->getMaskValue(M));
3298 return SV->getMaskValue(M);
3299 };
3300
3301 // Attempt to sort the inputs by ascending mask values to make simpler input
3302 // shuffles and push complex shuffles down to the uses. We sort on the first
3303 // of the two input shuffle orders, to try and get at least one input into a
3304 // nice order.
3305 auto SortBase = [&](Instruction *A, std::pair<int, int> X,
3306 std::pair<int, int> Y) {
3307 int MXA = GetBaseMaskValue(A, X.first);
3308 int MYA = GetBaseMaskValue(A, Y.first);
3309 return MXA < MYA;
3310 };
3311 stable_sort(V1, [&](std::pair<int, int> A, std::pair<int, int> B) {
3312 return SortBase(SVI0A, A, B);
3313 });
3314 stable_sort(V2, [&](std::pair<int, int> A, std::pair<int, int> B) {
3315 return SortBase(SVI1A, A, B);
3316 });
3317 // Calculate our ReconstructMasks from the OrigReconstructMasks and the
3318 // modified order of the input shuffles.
3319 SmallVector<SmallVector<int>> ReconstructMasks;
3320 for (const auto &Mask : OrigReconstructMasks) {
3321 SmallVector<int> ReconstructMask;
3322 for (int M : Mask) {
3323 auto FindIndex = [](const SmallVector<std::pair<int, int>> &V, int M) {
3324 auto It = find_if(V, [M](auto A) { return A.second == M; });
3325 assert(It != V.end() && "Expected all entries in Mask");
3326 return std::distance(V.begin(), It);
3327 };
3328 if (M < 0)
3329 ReconstructMask.push_back(-1);
3330 else if (M < static_cast<int>(NumElts)) {
3331 ReconstructMask.push_back(FindIndex(V1, M));
3332 } else {
3333 ReconstructMask.push_back(NumElts + FindIndex(V2, M));
3334 }
3335 }
3336 ReconstructMasks.push_back(std::move(ReconstructMask));
3337 }
3338
3339 // Calculate the masks needed for the new input shuffles, which get padded
3340 // with poison.
3341 SmallVector<int> V1A, V1B, V2A, V2B;
3342 for (unsigned I = 0; I < V1.size(); I++) {
3343 V1A.push_back(GetBaseMaskValue(SVI0A, V1[I].first));
3344 V1B.push_back(GetBaseMaskValue(SVI0B, V1[I].first));
3345 }
3346 for (unsigned I = 0; I < V2.size(); I++) {
3347 V2A.push_back(GetBaseMaskValue(SVI1A, V2[I].first));
3348 V2B.push_back(GetBaseMaskValue(SVI1B, V2[I].first));
3349 }
3350 while (V1A.size() < NumElts) {
3351 V1A.push_back(PoisonMaskElem);
3352 V1B.push_back(PoisonMaskElem);
3353 }
3354 while (V2A.size() < NumElts) {
3355 V2A.push_back(PoisonMaskElem);
3356 V2B.push_back(PoisonMaskElem);
3357 }
3358
3359 auto AddShuffleCost = [&](InstructionCost C, Instruction *I) {
3360 auto *SV = dyn_cast<ShuffleVectorInst>(I);
3361 if (!SV)
3362 return C;
3363 return C + TTI.getShuffleCost(isa<UndefValue>(SV->getOperand(1))
3364 ? TTI::SK_PermuteSingleSrc
3365 : TTI::SK_PermuteTwoSrc,
3366 VT, VT, SV->getShuffleMask(), CostKind);
3367 };
3368 auto AddShuffleMaskCost = [&](InstructionCost C, ArrayRef<int> Mask) {
3369 return C +
3370 TTI.getShuffleCost(TTI::SK_PermuteTwoSrc, VT, VT, Mask, CostKind);
3371 };
3372
3373 // Get the costs of the shuffles + binops before and after with the new
3374 // shuffle masks.
3375 InstructionCost CostBefore =
3376 TTI.getArithmeticInstrCost(Op0->getOpcode(), VT, CostKind) +
3377 TTI.getArithmeticInstrCost(Op1->getOpcode(), VT, CostKind);
3378 CostBefore += std::accumulate(Shuffles.begin(), Shuffles.end(),
3379 InstructionCost(0), AddShuffleCost);
3380 CostBefore += std::accumulate(InputShuffles.begin(), InputShuffles.end(),
3381 InstructionCost(0), AddShuffleCost);
3382
3383 // The new binops will be unused for lanes past the used shuffle lengths.
3384 // These types attempt to get the correct cost for that from the target.
3385 FixedVectorType *Op0SmallVT =
3386 FixedVectorType::get(VT->getScalarType(), V1.size());
3387 FixedVectorType *Op1SmallVT =
3388 FixedVectorType::get(VT->getScalarType(), V2.size());
3389 InstructionCost CostAfter =
3390 TTI.getArithmeticInstrCost(Op0->getOpcode(), Op0SmallVT, CostKind) +
3391 TTI.getArithmeticInstrCost(Op1->getOpcode(), Op1SmallVT, CostKind);
3392 CostAfter += std::accumulate(ReconstructMasks.begin(), ReconstructMasks.end(),
3393 InstructionCost(0), AddShuffleMaskCost);
3394 std::set<SmallVector<int>> OutputShuffleMasks({V1A, V1B, V2A, V2B});
3395 CostAfter +=
3396 std::accumulate(OutputShuffleMasks.begin(), OutputShuffleMasks.end(),
3397 InstructionCost(0), AddShuffleMaskCost);
3398
3399 LLVM_DEBUG(dbgs() << "Found a binop select shuffle pattern: " << I << "\n");
3400 LLVM_DEBUG(dbgs() << " CostBefore: " << CostBefore
3401 << " vs CostAfter: " << CostAfter << "\n");
3402 if (CostBefore <= CostAfter)
3403 return false;
3404
3405 // The cost model has passed, create the new instructions.
3406 auto GetShuffleOperand = [&](Instruction *I, unsigned Op) -> Value * {
3407 auto *SV = dyn_cast<ShuffleVectorInst>(I);
3408 if (!SV)
3409 return I;
3410 if (isa<UndefValue>(SV->getOperand(1)))
3411 if (auto *SSV = dyn_cast<ShuffleVectorInst>(SV->getOperand(0)))
3412 if (InputShuffles.contains(SSV))
3413 return SSV->getOperand(Op);
3414 return SV->getOperand(Op);
3415 };
3416 Builder.SetInsertPoint(*SVI0A->getInsertionPointAfterDef());
3417 Value *NSV0A = Builder.CreateShuffleVector(GetShuffleOperand(SVI0A, 0),
3418 GetShuffleOperand(SVI0A, 1), V1A);
3419 Builder.SetInsertPoint(*SVI0B->getInsertionPointAfterDef());
3420 Value *NSV0B = Builder.CreateShuffleVector(GetShuffleOperand(SVI0B, 0),
3421 GetShuffleOperand(SVI0B, 1), V1B);
3422 Builder.SetInsertPoint(*SVI1A->getInsertionPointAfterDef());
3423 Value *NSV1A = Builder.CreateShuffleVector(GetShuffleOperand(SVI1A, 0),
3424 GetShuffleOperand(SVI1A, 1), V2A);
3425 Builder.SetInsertPoint(*SVI1B->getInsertionPointAfterDef());
3426 Value *NSV1B = Builder.CreateShuffleVector(GetShuffleOperand(SVI1B, 0),
3427 GetShuffleOperand(SVI1B, 1), V2B);
3428 Builder.SetInsertPoint(Op0);
3429 Value *NOp0 = Builder.CreateBinOp((Instruction::BinaryOps)Op0->getOpcode(),
3430 NSV0A, NSV0B);
3431 if (auto *I = dyn_cast<Instruction>(NOp0))
3432 I->copyIRFlags(Op0, true);
3433 Builder.SetInsertPoint(Op1);
3434 Value *NOp1 = Builder.CreateBinOp((Instruction::BinaryOps)Op1->getOpcode(),
3435 NSV1A, NSV1B);
3436 if (auto *I = dyn_cast<Instruction>(NOp1))
3437 I->copyIRFlags(Op1, true);
3438
3439 for (int S = 0, E = ReconstructMasks.size(); S != E; S++) {
3440 Builder.SetInsertPoint(Shuffles[S]);
3441 Value *NSV = Builder.CreateShuffleVector(NOp0, NOp1, ReconstructMasks[S]);
3442 replaceValue(*Shuffles[S], *NSV);
3443 }
3444
3445 Worklist.pushValue(NSV0A);
3446 Worklist.pushValue(NSV0B);
3447 Worklist.pushValue(NSV1A);
3448 Worklist.pushValue(NSV1B);
3449 return true;
3450 }
3451
3452 /// Check if the instruction depends on a ZExt and whether that ZExt can be
3453 /// moved after the instruction. Move the ZExt if it is profitable. For example:
3454 /// logic(zext(x),y) -> zext(logic(x,trunc(y)))
3455 /// lshr(zext(x),y) -> zext(lshr(x,trunc(y)))
3456 /// The cost model calculation takes into account whether zext(x) has other
3457 /// users and whether the zext can be propagated through them too.
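/// An illustrative example (assumed): for
///   %z = zext <8 x i8> %x to <8 x i32>
///   %r = and <8 x i32> %z, <i32 15, i32 15, i32 15, i32 15, i32 15, i32 15, i32 15, i32 15>
/// the result only ever uses the low 8 bits, so this may become
///   %a = and <8 x i8> %x, <i8 15, i8 15, i8 15, i8 15, i8 15, i8 15, i8 15, i8 15>
///   %r = zext <8 x i8> %a to <8 x i32>
/// with the constant operand truncated for free.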
3458 bool VectorCombine::shrinkType(Instruction &I) {
3459 Value *ZExted, *OtherOperand;
3460 if (!match(&I, m_c_BitwiseLogic(m_ZExt(m_Value(ZExted)),
3461 m_Value(OtherOperand))) &&
3462 !match(&I, m_LShr(m_ZExt(m_Value(ZExted)), m_Value(OtherOperand))))
3463 return false;
3464
3465 Value *ZExtOperand = I.getOperand(I.getOperand(0) == OtherOperand ? 1 : 0);
3466
3467 auto *BigTy = cast<FixedVectorType>(I.getType());
3468 auto *SmallTy = cast<FixedVectorType>(ZExted->getType());
3469 unsigned BW = SmallTy->getElementType()->getPrimitiveSizeInBits();
3470
3471 if (I.getOpcode() == Instruction::LShr) {
3472 // Check that the shift amount is less than the number of bits in the
3473 // smaller type. Otherwise, the smaller lshr will return a poison value.
3474 KnownBits ShAmtKB = computeKnownBits(I.getOperand(1), *DL);
3475 if (ShAmtKB.getMaxValue().uge(BW))
3476 return false;
3477 } else {
3478 // Check that the expression overall uses at most the same number of bits as
3479 // ZExted
3480 KnownBits KB = computeKnownBits(&I, *DL);
3481 if (KB.countMaxActiveBits() > BW)
3482 return false;
3483 }
3484
3485 // Calculate the costs of leaving the current IR as it is and of moving the
3486 // ZExt operation later, along with adding truncates if needed.
3487 InstructionCost ZExtCost = TTI.getCastInstrCost(
3488 Instruction::ZExt, BigTy, SmallTy,
3489 TargetTransformInfo::CastContextHint::None, CostKind);
3490 InstructionCost CurrentCost = ZExtCost;
3491 InstructionCost ShrinkCost = 0;
3492
3493 // Calculate total cost and check that we can propagate through all ZExt users
3494 for (User *U : ZExtOperand->users()) {
3495 auto *UI = cast<Instruction>(U);
3496 if (UI == &I) {
3497 CurrentCost +=
3498 TTI.getArithmeticInstrCost(UI->getOpcode(), BigTy, CostKind);
3499 ShrinkCost +=
3500 TTI.getArithmeticInstrCost(UI->getOpcode(), SmallTy, CostKind);
3501 ShrinkCost += ZExtCost;
3502 continue;
3503 }
3504
3505 if (!Instruction::isBinaryOp(UI->getOpcode()))
3506 return false;
3507
3508 // Check if we can propagate ZExt through its other users
3509 KnownBits KB = computeKnownBits(UI, *DL);
3510 if (KB.countMaxActiveBits() > BW)
3511 return false;
3512
3513 CurrentCost += TTI.getArithmeticInstrCost(UI->getOpcode(), BigTy, CostKind);
3514 ShrinkCost +=
3515 TTI.getArithmeticInstrCost(UI->getOpcode(), SmallTy, CostKind);
3516 ShrinkCost += ZExtCost;
3517 }
3518
3519 // If the other instruction operand is not a constant, we'll need to
3520 // generate a truncate instruction, so we have to adjust the cost.
3521 if (!isa<Constant>(OtherOperand))
3522 ShrinkCost += TTI.getCastInstrCost(
3523 Instruction::Trunc, SmallTy, BigTy,
3524 TargetTransformInfo::CastContextHint::None, CostKind);
3525
3526 // If the cost of shrinking types and leaving the IR is the same, we'll lean
3527 // towards modifying the IR because shrinking opens opportunities for other
3528 // shrinking optimisations.
3529 if (ShrinkCost > CurrentCost)
3530 return false;
3531
3532 Builder.SetInsertPoint(&I);
3533 Value *Op0 = ZExted;
3534 Value *Op1 = Builder.CreateTrunc(OtherOperand, SmallTy);
3535 // Keep the order of operands the same
3536 if (I.getOperand(0) == OtherOperand)
3537 std::swap(Op0, Op1);
3538 Value *NewBinOp =
3539 Builder.CreateBinOp((Instruction::BinaryOps)I.getOpcode(), Op0, Op1);
3540 cast<Instruction>(NewBinOp)->copyIRFlags(&I);
3541 cast<Instruction>(NewBinOp)->copyMetadata(I);
3542 Value *NewZExtr = Builder.CreateZExt(NewBinOp, BigTy);
3543 replaceValue(I, *NewZExtr);
3544 return true;
3545 }
3546
3547 /// insert (DstVec, (extract SrcVec, ExtIdx), InsIdx) -->
3548 /// shuffle (DstVec, SrcVec, Mask)
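/// An illustrative example (assumed), with matching 4-element vectors and
/// ExtIdx = 2, InsIdx = 0:
///   insert (%dst, (extract %src, 2), 0) --> shuffle (%dst, %src, <6, 1, 2, 3>)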
3549 bool VectorCombine::foldInsExtVectorToShuffle(Instruction &I) {
3550 Value *DstVec, *SrcVec;
3551 uint64_t ExtIdx, InsIdx;
3552 if (!match(&I,
3553 m_InsertElt(m_Value(DstVec),
3554 m_ExtractElt(m_Value(SrcVec), m_ConstantInt(ExtIdx)),
3555 m_ConstantInt(InsIdx))))
3556 return false;
3557
3558 auto *DstVecTy = dyn_cast<FixedVectorType>(I.getType());
3559 auto *SrcVecTy = dyn_cast<FixedVectorType>(SrcVec->getType());
3560 // The vectors may differ in length, but their element types must match.
3561 if (!DstVecTy || !SrcVecTy ||
3562 SrcVecTy->getElementType() != DstVecTy->getElementType())
3563 return false;
3564
3565 unsigned NumDstElts = DstVecTy->getNumElements();
3566 unsigned NumSrcElts = SrcVecTy->getNumElements();
3567 if (InsIdx >= NumDstElts || ExtIdx >= NumSrcElts || NumDstElts == 1)
3568 return false;
3569
3570 // Insertion into poison is a cheaper single operand shuffle.
3571 TargetTransformInfo::ShuffleKind SK;
3572 SmallVector<int> Mask(NumDstElts, PoisonMaskElem);
3573
3574 bool NeedExpOrNarrow = NumSrcElts != NumDstElts;
3575 bool IsExtIdxInBounds = ExtIdx < NumDstElts;
3576 bool NeedDstSrcSwap = isa<PoisonValue>(DstVec) && !isa<UndefValue>(SrcVec);
3577 if (NeedDstSrcSwap) {
3578 SK = TargetTransformInfo::SK_PermuteSingleSrc;
3579 if (!IsExtIdxInBounds && NeedExpOrNarrow)
3580 Mask[InsIdx] = 0;
3581 else
3582 Mask[InsIdx] = ExtIdx;
3583 std::swap(DstVec, SrcVec);
3584 } else {
3585 SK = TargetTransformInfo::SK_PermuteTwoSrc;
3586 std::iota(Mask.begin(), Mask.end(), 0);
3587 if (!IsExtIdxInBounds && NeedExpOrNarrow)
3588 Mask[InsIdx] = NumDstElts;
3589 else
3590 Mask[InsIdx] = ExtIdx + NumDstElts;
3591 }
3592
3593 // Cost
3594 auto *Ins = cast<InsertElementInst>(&I);
3595 auto *Ext = cast<ExtractElementInst>(I.getOperand(1));
3596 InstructionCost InsCost =
3597 TTI.getVectorInstrCost(*Ins, DstVecTy, CostKind, InsIdx);
3598 InstructionCost ExtCost =
3599 TTI.getVectorInstrCost(*Ext, DstVecTy, CostKind, ExtIdx);
3600 InstructionCost OldCost = ExtCost + InsCost;
3601
3602 InstructionCost NewCost = 0;
3603 SmallVector<int> ExtToVecMask;
3604 if (!NeedExpOrNarrow) {
3605 // Ignore 'free' identity insertion shuffle.
3606 // TODO: getShuffleCost should return TCC_Free for Identity shuffles.
3607 if (!ShuffleVectorInst::isIdentityMask(Mask, NumSrcElts))
3608 NewCost += TTI.getShuffleCost(SK, DstVecTy, DstVecTy, Mask, CostKind, 0,
3609 nullptr, {DstVec, SrcVec});
3610 } else {
3611 // When creating a length-changing vector, always create it with a mask whose
3612 // first element has an ExtIdx, so that the first element of the vector
3613 // being created is always the target to be extracted.
3614 ExtToVecMask.assign(NumDstElts, PoisonMaskElem);
3615 if (IsExtIdxInBounds)
3616 ExtToVecMask[ExtIdx] = ExtIdx;
3617 else
3618 ExtToVecMask[0] = ExtIdx;
3619 // Add cost for expanding or narrowing
3620 NewCost = TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc,
3621 DstVecTy, SrcVecTy, ExtToVecMask, CostKind);
3622 NewCost += TTI.getShuffleCost(SK, DstVecTy, DstVecTy, Mask, CostKind);
3623 }
3624
3625 if (!Ext->hasOneUse())
3626 NewCost += ExtCost;
3627
3628 LLVM_DEBUG(dbgs() << "Found an insert/extract shuffle-like pair: " << I
3629 << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
3630 << "\n");
3631
3632 if (OldCost < NewCost)
3633 return false;
3634
3635 if (NeedExpOrNarrow) {
3636 if (!NeedDstSrcSwap)
3637 SrcVec = Builder.CreateShuffleVector(SrcVec, ExtToVecMask);
3638 else
3639 DstVec = Builder.CreateShuffleVector(DstVec, ExtToVecMask);
3640 }
3641
3642 // Canonicalize undef param to RHS to help further folds.
3643 if (isa<UndefValue>(DstVec) && !isa<UndefValue>(SrcVec)) {
3644 ShuffleVectorInst::commuteShuffleMask(Mask, NumDstElts);
3645 std::swap(DstVec, SrcVec);
3646 }
3647
3648 Value *Shuf = Builder.CreateShuffleVector(DstVec, SrcVec, Mask);
3649 replaceValue(I, *Shuf);
3650
3651 return true;
3652 }
3653
3654 /// If we're interleaving 2 constant splats, for instance `<vscale x 8 x i32>
3655 /// <splat of 666>` and `<vscale x 8 x i32> <splat of 777>`, we can create a
3656 /// larger splat `<vscale x 8 x i64> <splat of ((777 << 32) | 666)>` first
3657 /// before casting it back into `<vscale x 16 x i32>`.
3658 bool VectorCombine::foldInterleaveIntrinsics(Instruction &I) {
3659 const APInt *SplatVal0, *SplatVal1;
3660 if (!match(&I, m_Intrinsic<Intrinsic::vector_interleave2>(
3661 m_APInt(SplatVal0), m_APInt(SplatVal1))))
3662 return false;
3663
3664 LLVM_DEBUG(dbgs() << "VC: Folding interleave2 with two splats: " << I
3665 << "\n");
3666
3667 auto *VTy =
3668 cast<VectorType>(cast<IntrinsicInst>(I).getArgOperand(0)->getType());
3669 auto *ExtVTy = VectorType::getExtendedElementVectorType(VTy);
3670 unsigned Width = VTy->getElementType()->getIntegerBitWidth();
3671
3672 // Just in case the costs of the interleave2 intrinsic and the bitcast are
3673 // both invalid (in which case we want to bail out), we use <= rather than <
3674 // here. Even if they both have valid and equal costs, it's probably not a
3675 // good idea to emit a high-cost constant splat.
3676 if (TTI.getInstructionCost(&I, CostKind) <=
3677 TTI.getCastInstrCost(Instruction::BitCast, I.getType(), ExtVTy,
3678 TTI::CastContextHint::None, CostKind)) {
3679 LLVM_DEBUG(dbgs() << "VC: The cost to cast from " << *ExtVTy << " to "
3680 << *I.getType() << " is too high.\n");
3681 return false;
3682 }
3683
3684 APInt NewSplatVal = SplatVal1->zext(Width * 2);
3685 NewSplatVal <<= Width;
3686 NewSplatVal |= SplatVal0->zext(Width * 2);
3687 auto *NewSplat = ConstantVector::getSplat(
3688 ExtVTy->getElementCount(), ConstantInt::get(F.getContext(), NewSplatVal));
3689
3690 IRBuilder<> Builder(&I);
3691 replaceValue(I, *Builder.CreateBitCast(NewSplat, I.getType()));
3692 return true;
3693 }
3694
3695 /// This is the entry point for all transforms. Pass manager differences are
3696 /// handled in the callers of this function.
3697 bool VectorCombine::run() {
3698 if (DisableVectorCombine)
3699 return false;
3700
3701 // Don't attempt vectorization if the target does not support vectors.
3702 if (!TTI.getNumberOfRegisters(TTI.getRegisterClassForType(/*Vector*/ true)))
3703 return false;
3704
3705 LLVM_DEBUG(dbgs() << "\n\nVECTORCOMBINE on " << F.getName() << "\n");
3706
3707 bool MadeChange = false;
3708 auto FoldInst = [this, &MadeChange](Instruction &I) {
3709 Builder.SetInsertPoint(&I);
3710 bool IsVectorType = isa<VectorType>(I.getType());
3711 bool IsFixedVectorType = isa<FixedVectorType>(I.getType());
3712 auto Opcode = I.getOpcode();
3713
3714 LLVM_DEBUG(dbgs() << "VC: Visiting: " << I << '\n');
3715
3716 // These folds should be beneficial regardless of when this pass is run
3717 // in the optimization pipeline.
3718 // The type checking is for run-time efficiency. We can avoid wasting time
3719 // dispatching to folding functions if there's no chance of matching.
3720 if (IsFixedVectorType) {
3721 switch (Opcode) {
3722 case Instruction::InsertElement:
3723 MadeChange |= vectorizeLoadInsert(I);
3724 break;
3725 case Instruction::ShuffleVector:
3726 MadeChange |= widenSubvectorLoad(I);
3727 break;
3728 default:
3729 break;
3730 }
3731 }
3732
3733 // These transforms work with scalable and fixed vectors.
3734 // TODO: Identify and allow other scalable transforms
3735 if (IsVectorType) {
3736 MadeChange |= scalarizeOpOrCmp(I);
3737 MadeChange |= scalarizeLoadExtract(I);
3738 MadeChange |= scalarizeExtExtract(I);
3739 MadeChange |= scalarizeVPIntrinsic(I);
3740 MadeChange |= foldInterleaveIntrinsics(I);
3741 }
3742
3743 if (Opcode == Instruction::Store)
3744 MadeChange |= foldSingleElementStore(I);
3745
3746 // If this is an early pipeline invocation of this pass, we are done.
3747 if (TryEarlyFoldsOnly)
3748 return;
3749
3750 // Otherwise, try folds that improve codegen but may interfere with
3751 // early IR canonicalizations.
3752 // The type checking is for run-time efficiency. We can avoid wasting time
3753 // dispatching to folding functions if there's no chance of matching.
3754 if (IsFixedVectorType) {
3755 switch (Opcode) {
3756 case Instruction::InsertElement:
3757 MadeChange |= foldInsExtFNeg(I);
3758 MadeChange |= foldInsExtBinop(I);
3759 MadeChange |= foldInsExtVectorToShuffle(I);
3760 break;
3761 case Instruction::ShuffleVector:
3762 MadeChange |= foldPermuteOfBinops(I);
3763 MadeChange |= foldShuffleOfBinops(I);
3764 MadeChange |= foldShuffleOfSelects(I);
3765 MadeChange |= foldShuffleOfCastops(I);
3766 MadeChange |= foldShuffleOfShuffles(I);
3767 MadeChange |= foldShuffleOfIntrinsics(I);
3768 MadeChange |= foldSelectShuffle(I);
3769 MadeChange |= foldShuffleToIdentity(I);
3770 break;
3771 case Instruction::BitCast:
3772 MadeChange |= foldBitcastShuffle(I);
3773 break;
3774 case Instruction::And:
3775 case Instruction::Or:
3776 case Instruction::Xor:
3777 MadeChange |= foldBitOpOfBitcasts(I);
3778 break;
3779 default:
3780 MadeChange |= shrinkType(I);
3781 break;
3782 }
3783 } else {
3784 switch (Opcode) {
3785 case Instruction::Call:
3786 MadeChange |= foldShuffleFromReductions(I);
3787 MadeChange |= foldCastFromReductions(I);
3788 break;
3789 case Instruction::ICmp:
3790 case Instruction::FCmp:
3791 MadeChange |= foldExtractExtract(I);
3792 break;
3793 case Instruction::Or:
3794 MadeChange |= foldConcatOfBoolMasks(I);
3795 [[fallthrough]];
3796 default:
3797 if (Instruction::isBinaryOp(Opcode)) {
3798 MadeChange |= foldExtractExtract(I);
3799 MadeChange |= foldExtractedCmps(I);
3800 MadeChange |= foldBinopOfReductions(I);
3801 }
3802 break;
3803 }
3804 }
3805 };
3806
3807 for (BasicBlock &BB : F) {
3808 // Ignore unreachable basic blocks.
3809 if (!DT.isReachableFromEntry(&BB))
3810 continue;
3811 // Use early increment range so that we can erase instructions in loop.
3812 for (Instruction &I : make_early_inc_range(BB)) {
3813 if (I.isDebugOrPseudoInst())
3814 continue;
3815 FoldInst(I);
3816 }
3817 }
3818
3819 while (!Worklist.isEmpty()) {
3820 Instruction *I = Worklist.removeOne();
3821 if (!I)
3822 continue;
3823
3824 if (isInstructionTriviallyDead(I)) {
3825 eraseInstruction(*I);
3826 continue;
3827 }
3828
3829 FoldInst(*I);
3830 }
3831
3832 return MadeChange;
3833 }
3834
3835 PreservedAnalyses VectorCombinePass::run(Function &F,
3836 FunctionAnalysisManager &FAM) {
3837 auto &AC = FAM.getResult<AssumptionAnalysis>(F);
3838 TargetTransformInfo &TTI = FAM.getResult<TargetIRAnalysis>(F);
3839 DominatorTree &DT = FAM.getResult<DominatorTreeAnalysis>(F);
3840 AAResults &AA = FAM.getResult<AAManager>(F);
3841 const DataLayout *DL = &F.getDataLayout();
3842 VectorCombine Combiner(F, TTI, DT, AA, AC, DL, TTI::TCK_RecipThroughput,
3843 TryEarlyFoldsOnly);
3844 if (!Combiner.run())
3845 return PreservedAnalyses::all();
3846 PreservedAnalyses PA;
3847 PA.preserveSet<CFGAnalyses>();
3848 return PA;
3849 }
3850