xref: /freebsd/contrib/llvm-project/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp (revision 700637cbb5e582861067a11aaca4d053546871d2)
//===- ScalarizeMaskedMemIntrin.cpp - Scalarize unsupported masked mem ----===//
//                                    intrinsics
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass replaces masked memory intrinsics - when unsupported by the target
// - with a chain of basic blocks, that deal with the elements one-by-one if the
// appropriate mask bit is set.
//
//===----------------------------------------------------------------------===//
15 
16 #include "llvm/Transforms/Scalar/ScalarizeMaskedMemIntrin.h"
17 #include "llvm/ADT/Twine.h"
18 #include "llvm/Analysis/DomTreeUpdater.h"
19 #include "llvm/Analysis/TargetTransformInfo.h"
20 #include "llvm/Analysis/VectorUtils.h"
21 #include "llvm/IR/BasicBlock.h"
22 #include "llvm/IR/Constant.h"
23 #include "llvm/IR/Constants.h"
24 #include "llvm/IR/DerivedTypes.h"
25 #include "llvm/IR/Dominators.h"
26 #include "llvm/IR/Function.h"
27 #include "llvm/IR/IRBuilder.h"
28 #include "llvm/IR/Instruction.h"
29 #include "llvm/IR/Instructions.h"
30 #include "llvm/IR/IntrinsicInst.h"
31 #include "llvm/IR/Type.h"
32 #include "llvm/IR/Value.h"
33 #include "llvm/InitializePasses.h"
34 #include "llvm/Pass.h"
35 #include "llvm/Support/Casting.h"
36 #include "llvm/Transforms/Scalar.h"
37 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
38 #include <cassert>
39 #include <optional>
40 
41 using namespace llvm;
42 
43 #define DEBUG_TYPE "scalarize-masked-mem-intrin"
44 
45 namespace {
46 
47 class ScalarizeMaskedMemIntrinLegacyPass : public FunctionPass {
48 public:
49   static char ID; // Pass identification, replacement for typeid
50 
ScalarizeMaskedMemIntrinLegacyPass()51   explicit ScalarizeMaskedMemIntrinLegacyPass() : FunctionPass(ID) {
52     initializeScalarizeMaskedMemIntrinLegacyPassPass(
53         *PassRegistry::getPassRegistry());
54   }
55 
56   bool runOnFunction(Function &F) override;
57 
getPassName() const58   StringRef getPassName() const override {
59     return "Scalarize Masked Memory Intrinsics";
60   }
61 
getAnalysisUsage(AnalysisUsage & AU) const62   void getAnalysisUsage(AnalysisUsage &AU) const override {
63     AU.addRequired<TargetTransformInfoWrapperPass>();
64     AU.addPreserved<DominatorTreeWrapperPass>();
65   }
66 };
67 
68 } // end anonymous namespace
69 
70 static bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT,
71                           const TargetTransformInfo &TTI, const DataLayout &DL,
72                           bool HasBranchDivergence, DomTreeUpdater *DTU);
73 static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
74                              const TargetTransformInfo &TTI,
75                              const DataLayout &DL, bool HasBranchDivergence,
76                              DomTreeUpdater *DTU);
77 
78 char ScalarizeMaskedMemIntrinLegacyPass::ID = 0;
79 
80 INITIALIZE_PASS_BEGIN(ScalarizeMaskedMemIntrinLegacyPass, DEBUG_TYPE,
81                       "Scalarize unsupported masked memory intrinsics", false,
82                       false)
INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)83 INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
84 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
85 INITIALIZE_PASS_END(ScalarizeMaskedMemIntrinLegacyPass, DEBUG_TYPE,
86                     "Scalarize unsupported masked memory intrinsics", false,
87                     false)
88 
89 FunctionPass *llvm::createScalarizeMaskedMemIntrinLegacyPass() {
90   return new ScalarizeMaskedMemIntrinLegacyPass();
91 }
92 
isConstantIntVector(Value * Mask)93 static bool isConstantIntVector(Value *Mask) {
94   Constant *C = dyn_cast<Constant>(Mask);
95   if (!C)
96     return false;
97 
98   unsigned NumElts = cast<FixedVectorType>(Mask->getType())->getNumElements();
99   for (unsigned i = 0; i != NumElts; ++i) {
100     Constant *CElt = C->getAggregateElement(i);
101     if (!CElt || !isa<ConstantInt>(CElt))
102       return false;
103   }
104 
105   return true;
106 }
107 
adjustForEndian(const DataLayout & DL,unsigned VectorWidth,unsigned Idx)108 static unsigned adjustForEndian(const DataLayout &DL, unsigned VectorWidth,
109                                 unsigned Idx) {
110   return DL.isBigEndian() ? VectorWidth - 1 - Idx : Idx;
111 }
112 
// Translate a masked load intrinsic like
// <16 x i32 > @llvm.masked.load( <16 x i32>* %addr, i32 align,
//                               <16 x i1> %mask, <16 x i32> %passthru)
// to a chain of basic blocks, with loading element one-by-one if
// the appropriate mask bit is set
//
//  %1 = bitcast i8* %addr to i32*
//  %2 = extractelement <16 x i1> %mask, i32 0
//  br i1 %2, label %cond.load, label %else
//
// cond.load:                                        ; preds = %0
//  %3 = getelementptr i32* %1, i32 0
//  %4 = load i32* %3
//  %5 = insertelement <16 x i32> %passthru, i32 %4, i32 0
//  br label %else
//
// else:                                             ; preds = %0, %cond.load
//  %res.phi.else = phi <16 x i32> [ %5, %cond.load ], [ poison, %0 ]
//  %6 = extractelement <16 x i1> %mask, i32 1
//  br i1 %6, label %cond.load1, label %else2
//
// cond.load1:                                       ; preds = %else
//  %7 = getelementptr i32* %1, i32 1
//  %8 = load i32* %7
//  %9 = insertelement <16 x i32> %res.phi.else, i32 %8, i32 1
//  br label %else2
//
// else2:                                          ; preds = %else, %cond.load1
//  %res.phi.else3 = phi <16 x i32> [ %9, %cond.load1 ], [ %res.phi.else, %else ]
//  %10 = extractelement <16 x i1> %mask, i32 2
//  br i1 %10, label %cond.load4, label %else5
//
scalarizeMaskedLoad(const DataLayout & DL,bool HasBranchDivergence,CallInst * CI,DomTreeUpdater * DTU,bool & ModifiedDT)145 static void scalarizeMaskedLoad(const DataLayout &DL, bool HasBranchDivergence,
146                                 CallInst *CI, DomTreeUpdater *DTU,
147                                 bool &ModifiedDT) {
148   Value *Ptr = CI->getArgOperand(0);
149   Value *Alignment = CI->getArgOperand(1);
150   Value *Mask = CI->getArgOperand(2);
151   Value *Src0 = CI->getArgOperand(3);
152 
153   const Align AlignVal = cast<ConstantInt>(Alignment)->getAlignValue();
154   VectorType *VecType = cast<FixedVectorType>(CI->getType());
155 
156   Type *EltTy = VecType->getElementType();
157 
158   IRBuilder<> Builder(CI->getContext());
159   Instruction *InsertPt = CI;
160   BasicBlock *IfBlock = CI->getParent();
161 
162   Builder.SetInsertPoint(InsertPt);
163   Builder.SetCurrentDebugLocation(CI->getDebugLoc());
164 
165   // Short-cut if the mask is all-true.
166   if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
167     LoadInst *NewI = Builder.CreateAlignedLoad(VecType, Ptr, AlignVal);
168     NewI->copyMetadata(*CI);
169     NewI->takeName(CI);
170     CI->replaceAllUsesWith(NewI);
171     CI->eraseFromParent();
172     return;
173   }
174 
175   // Adjust alignment for the scalar instruction.
176   const Align AdjustedAlignVal =
177       commonAlignment(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
178   unsigned VectorWidth = cast<FixedVectorType>(VecType)->getNumElements();
179 
180   // The result vector
181   Value *VResult = Src0;
182 
183   if (isConstantIntVector(Mask)) {
184     for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
185       if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
186         continue;
187       Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, Idx);
188       LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AdjustedAlignVal);
189       VResult = Builder.CreateInsertElement(VResult, Load, Idx);
190     }
191     CI->replaceAllUsesWith(VResult);
192     CI->eraseFromParent();
193     return;
194   }
195 
196   // Optimize the case where the "masked load" is a predicated load - that is,
197   // where the mask is the splat of a non-constant scalar boolean. In that case,
198   // use that splated value as the guard on a conditional vector load.
199   if (isSplatValue(Mask, /*Index=*/0)) {
200     Value *Predicate = Builder.CreateExtractElement(Mask, uint64_t(0ull),
201                                                     Mask->getName() + ".first");
202     Instruction *ThenTerm =
203         SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
204                                   /*BranchWeights=*/nullptr, DTU);
205 
206     BasicBlock *CondBlock = ThenTerm->getParent();
207     CondBlock->setName("cond.load");
208     Builder.SetInsertPoint(CondBlock->getTerminator());
209     LoadInst *Load = Builder.CreateAlignedLoad(VecType, Ptr, AlignVal,
210                                                CI->getName() + ".cond.load");
211     Load->copyMetadata(*CI);
212 
213     BasicBlock *PostLoad = ThenTerm->getSuccessor(0);
214     Builder.SetInsertPoint(PostLoad, PostLoad->begin());
215     PHINode *Phi = Builder.CreatePHI(VecType, /*NumReservedValues=*/2);
216     Phi->addIncoming(Load, CondBlock);
217     Phi->addIncoming(Src0, IfBlock);
218     Phi->takeName(CI);
219 
220     CI->replaceAllUsesWith(Phi);
221     CI->eraseFromParent();
222     ModifiedDT = true;
223     return;
224   }
225   // If the mask is not v1i1, use scalar bit test operations. This generates
226   // better results on X86 at least. However, don't do this on GPUs and other
227   // machines with divergence, as there each i1 needs a vector register.
228   Value *SclrMask = nullptr;
229   if (VectorWidth != 1 && !HasBranchDivergence) {
230     Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
231     SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
232   }
233 
234   for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
235     // Fill the "else" block, created in the previous iteration
236     //
237     //  %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else,
238     //  %else ] %mask_1 = and i16 %scalar_mask, i32 1 << Idx %cond = icmp ne i16
239     //  %mask_1, 0 br i1 %mask_1, label %cond.load, label %else
240     //
241     // On GPUs, use
242     //  %cond = extrectelement %mask, Idx
243     // instead
244     Value *Predicate;
245     if (SclrMask != nullptr) {
246       Value *Mask = Builder.getInt(APInt::getOneBitSet(
247           VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
248       Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
249                                        Builder.getIntN(VectorWidth, 0));
250     } else {
251       Predicate = Builder.CreateExtractElement(Mask, Idx);
252     }
253 
254     // Create "cond" block
255     //
256     //  %EltAddr = getelementptr i32* %1, i32 0
257     //  %Elt = load i32* %EltAddr
258     //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
259     //
260     Instruction *ThenTerm =
261         SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
262                                   /*BranchWeights=*/nullptr, DTU);
263 
264     BasicBlock *CondBlock = ThenTerm->getParent();
265     CondBlock->setName("cond.load");
266 
267     Builder.SetInsertPoint(CondBlock->getTerminator());
268     Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, Idx);
269     LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AdjustedAlignVal);
270     Value *NewVResult = Builder.CreateInsertElement(VResult, Load, Idx);
271 
272     // Create "else" block, fill it in the next iteration
273     BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
274     NewIfBlock->setName("else");
275     BasicBlock *PrevIfBlock = IfBlock;
276     IfBlock = NewIfBlock;
277 
278     // Create the phi to join the new and previous value.
279     Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
280     PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
281     Phi->addIncoming(NewVResult, CondBlock);
282     Phi->addIncoming(VResult, PrevIfBlock);
283     VResult = Phi;
284   }
285 
286   CI->replaceAllUsesWith(VResult);
287   CI->eraseFromParent();
288 
289   ModifiedDT = true;
290 }
291 
// Translate a masked store intrinsic, like
// void @llvm.masked.store(<16 x i32> %src, <16 x i32>* %addr, i32 align,
//                               <16 x i1> %mask)
// to a chain of basic blocks, that stores element one-by-one if
// the appropriate mask bit is set
//
//   %1 = bitcast i8* %addr to i32*
//   %2 = extractelement <16 x i1> %mask, i32 0
//   br i1 %2, label %cond.store, label %else
//
// cond.store:                                       ; preds = %0
//   %3 = extractelement <16 x i32> %val, i32 0
//   %4 = getelementptr i32* %1, i32 0
//   store i32 %3, i32* %4
//   br label %else
//
// else:                                             ; preds = %0, %cond.store
//   %5 = extractelement <16 x i1> %mask, i32 1
//   br i1 %5, label %cond.store1, label %else2
//
// cond.store1:                                      ; preds = %else
//   %6 = extractelement <16 x i32> %val, i32 1
//   %7 = getelementptr i32* %1, i32 1
//   store i32 %6, i32* %7
//   br label %else2
//   . . .
scalarizeMaskedStore(const DataLayout & DL,bool HasBranchDivergence,CallInst * CI,DomTreeUpdater * DTU,bool & ModifiedDT)318 static void scalarizeMaskedStore(const DataLayout &DL, bool HasBranchDivergence,
319                                  CallInst *CI, DomTreeUpdater *DTU,
320                                  bool &ModifiedDT) {
321   Value *Src = CI->getArgOperand(0);
322   Value *Ptr = CI->getArgOperand(1);
323   Value *Alignment = CI->getArgOperand(2);
324   Value *Mask = CI->getArgOperand(3);
325 
326   const Align AlignVal = cast<ConstantInt>(Alignment)->getAlignValue();
327   auto *VecType = cast<VectorType>(Src->getType());
328 
329   Type *EltTy = VecType->getElementType();
330 
331   IRBuilder<> Builder(CI->getContext());
332   Instruction *InsertPt = CI;
333   Builder.SetInsertPoint(InsertPt);
334   Builder.SetCurrentDebugLocation(CI->getDebugLoc());
335 
336   // Short-cut if the mask is all-true.
337   if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
338     StoreInst *Store = Builder.CreateAlignedStore(Src, Ptr, AlignVal);
339     Store->takeName(CI);
340     Store->copyMetadata(*CI);
341     CI->eraseFromParent();
342     return;
343   }
344 
345   // Adjust alignment for the scalar instruction.
346   const Align AdjustedAlignVal =
347       commonAlignment(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
348   unsigned VectorWidth = cast<FixedVectorType>(VecType)->getNumElements();
349 
350   if (isConstantIntVector(Mask)) {
351     for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
352       if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
353         continue;
354       Value *OneElt = Builder.CreateExtractElement(Src, Idx);
355       Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, Idx);
356       Builder.CreateAlignedStore(OneElt, Gep, AdjustedAlignVal);
357     }
358     CI->eraseFromParent();
359     return;
360   }
361 
362   // Optimize the case where the "masked store" is a predicated store - that is,
363   // when the mask is the splat of a non-constant scalar boolean. In that case,
364   // optimize to a conditional store.
365   if (isSplatValue(Mask, /*Index=*/0)) {
366     Value *Predicate = Builder.CreateExtractElement(Mask, uint64_t(0ull),
367                                                     Mask->getName() + ".first");
368     Instruction *ThenTerm =
369         SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
370                                   /*BranchWeights=*/nullptr, DTU);
371     BasicBlock *CondBlock = ThenTerm->getParent();
372     CondBlock->setName("cond.store");
373     Builder.SetInsertPoint(CondBlock->getTerminator());
374 
375     StoreInst *Store = Builder.CreateAlignedStore(Src, Ptr, AlignVal);
376     Store->takeName(CI);
377     Store->copyMetadata(*CI);
378 
379     CI->eraseFromParent();
380     ModifiedDT = true;
381     return;
382   }
383 
384   // If the mask is not v1i1, use scalar bit test operations. This generates
385   // better results on X86 at least. However, don't do this on GPUs or other
386   // machines with branch divergence, as there each i1 takes up a register.
387   Value *SclrMask = nullptr;
388   if (VectorWidth != 1 && !HasBranchDivergence) {
389     Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
390     SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
391   }
392 
393   for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
394     // Fill the "else" block, created in the previous iteration
395     //
396     //  %mask_1 = and i16 %scalar_mask, i32 1 << Idx
397     //  %cond = icmp ne i16 %mask_1, 0
398     //  br i1 %mask_1, label %cond.store, label %else
399     //
400     // On GPUs, use
401     //  %cond = extrectelement %mask, Idx
402     // instead
403     Value *Predicate;
404     if (SclrMask != nullptr) {
405       Value *Mask = Builder.getInt(APInt::getOneBitSet(
406           VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
407       Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
408                                        Builder.getIntN(VectorWidth, 0));
409     } else {
410       Predicate = Builder.CreateExtractElement(Mask, Idx);
411     }
412 
413     // Create "cond" block
414     //
415     //  %OneElt = extractelement <16 x i32> %Src, i32 Idx
416     //  %EltAddr = getelementptr i32* %1, i32 0
417     //  %store i32 %OneElt, i32* %EltAddr
418     //
419     Instruction *ThenTerm =
420         SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
421                                   /*BranchWeights=*/nullptr, DTU);
422 
423     BasicBlock *CondBlock = ThenTerm->getParent();
424     CondBlock->setName("cond.store");
425 
426     Builder.SetInsertPoint(CondBlock->getTerminator());
427     Value *OneElt = Builder.CreateExtractElement(Src, Idx);
428     Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, Idx);
429     Builder.CreateAlignedStore(OneElt, Gep, AdjustedAlignVal);
430 
431     // Create "else" block, fill it in the next iteration
432     BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
433     NewIfBlock->setName("else");
434 
435     Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
436   }
437   CI->eraseFromParent();
438 
439   ModifiedDT = true;
440 }
441 
// Translate a masked gather intrinsic like
// <16 x i32 > @llvm.masked.gather.v16i32( <16 x i32*> %Ptrs, i32 4,
//                               <16 x i1> %Mask, <16 x i32> %Src)
// to a chain of basic blocks, with loading element one-by-one if
// the appropriate mask bit is set
//
// %Ptrs = getelementptr i32, i32* %base, <16 x i64> %ind
// %Mask0 = extractelement <16 x i1> %Mask, i32 0
// br i1 %Mask0, label %cond.load, label %else
//
// cond.load:
// %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
// %Load0 = load i32, i32* %Ptr0, align 4
// %Res0 = insertelement <16 x i32> poison, i32 %Load0, i32 0
// br label %else
//
// else:
// %res.phi.else = phi <16 x i32>[%Res0, %cond.load], [poison, %0]
// %Mask1 = extractelement <16 x i1> %Mask, i32 1
// br i1 %Mask1, label %cond.load1, label %else2
//
// cond.load1:
// %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
// %Load1 = load i32, i32* %Ptr1, align 4
// %Res1 = insertelement <16 x i32> %res.phi.else, i32 %Load1, i32 1
// br label %else2
// . . .
// %Result = select <16 x i1> %Mask, <16 x i32> %res.phi.select, <16 x i32> %Src
// ret <16 x i32> %Result
scalarizeMaskedGather(const DataLayout & DL,bool HasBranchDivergence,CallInst * CI,DomTreeUpdater * DTU,bool & ModifiedDT)471 static void scalarizeMaskedGather(const DataLayout &DL,
472                                   bool HasBranchDivergence, CallInst *CI,
473                                   DomTreeUpdater *DTU, bool &ModifiedDT) {
474   Value *Ptrs = CI->getArgOperand(0);
475   Value *Alignment = CI->getArgOperand(1);
476   Value *Mask = CI->getArgOperand(2);
477   Value *Src0 = CI->getArgOperand(3);
478 
479   auto *VecType = cast<FixedVectorType>(CI->getType());
480   Type *EltTy = VecType->getElementType();
481 
482   IRBuilder<> Builder(CI->getContext());
483   Instruction *InsertPt = CI;
484   BasicBlock *IfBlock = CI->getParent();
485   Builder.SetInsertPoint(InsertPt);
486   MaybeAlign AlignVal = cast<ConstantInt>(Alignment)->getMaybeAlignValue();
487 
488   Builder.SetCurrentDebugLocation(CI->getDebugLoc());
489 
490   // The result vector
491   Value *VResult = Src0;
492   unsigned VectorWidth = VecType->getNumElements();
493 
494   // Shorten the way if the mask is a vector of constants.
495   if (isConstantIntVector(Mask)) {
496     for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
497       if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
498         continue;
499       Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
500       LoadInst *Load =
501           Builder.CreateAlignedLoad(EltTy, Ptr, AlignVal, "Load" + Twine(Idx));
502       VResult =
503           Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx));
504     }
505     CI->replaceAllUsesWith(VResult);
506     CI->eraseFromParent();
507     return;
508   }
509 
510   // If the mask is not v1i1, use scalar bit test operations. This generates
511   // better results on X86 at least. However, don't do this on GPUs or other
512   // machines with branch divergence, as there, each i1 takes up a register.
513   Value *SclrMask = nullptr;
514   if (VectorWidth != 1 && !HasBranchDivergence) {
515     Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
516     SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
517   }
518 
519   for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
520     // Fill the "else" block, created in the previous iteration
521     //
522     //  %Mask1 = and i16 %scalar_mask, i32 1 << Idx
523     //  %cond = icmp ne i16 %mask_1, 0
524     //  br i1 %Mask1, label %cond.load, label %else
525     //
526     // On GPUs, use
527     //  %cond = extrectelement %mask, Idx
528     // instead
529 
530     Value *Predicate;
531     if (SclrMask != nullptr) {
532       Value *Mask = Builder.getInt(APInt::getOneBitSet(
533           VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
534       Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
535                                        Builder.getIntN(VectorWidth, 0));
536     } else {
537       Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
538     }
539 
540     // Create "cond" block
541     //
542     //  %EltAddr = getelementptr i32* %1, i32 0
543     //  %Elt = load i32* %EltAddr
544     //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
545     //
546     Instruction *ThenTerm =
547         SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
548                                   /*BranchWeights=*/nullptr, DTU);
549 
550     BasicBlock *CondBlock = ThenTerm->getParent();
551     CondBlock->setName("cond.load");
552 
553     Builder.SetInsertPoint(CondBlock->getTerminator());
554     Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
555     LoadInst *Load =
556         Builder.CreateAlignedLoad(EltTy, Ptr, AlignVal, "Load" + Twine(Idx));
557     Value *NewVResult =
558         Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx));
559 
560     // Create "else" block, fill it in the next iteration
561     BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
562     NewIfBlock->setName("else");
563     BasicBlock *PrevIfBlock = IfBlock;
564     IfBlock = NewIfBlock;
565 
566     // Create the phi to join the new and previous value.
567     Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
568     PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
569     Phi->addIncoming(NewVResult, CondBlock);
570     Phi->addIncoming(VResult, PrevIfBlock);
571     VResult = Phi;
572   }
573 
574   CI->replaceAllUsesWith(VResult);
575   CI->eraseFromParent();
576 
577   ModifiedDT = true;
578 }
579 
// Translate a masked scatter intrinsic, like
// void @llvm.masked.scatter.v16i32(<16 x i32> %Src, <16 x i32*>* %Ptrs, i32 4,
//                                  <16 x i1> %Mask)
// to a chain of basic blocks, that stores element one-by-one if
// the appropriate mask bit is set.
//
// %Ptrs = getelementptr i32, i32* %ptr, <16 x i64> %ind
// %Mask0 = extractelement <16 x i1> %Mask, i32 0
// br i1 %Mask0, label %cond.store, label %else
//
// cond.store:
// %Elt0 = extractelement <16 x i32> %Src, i32 0
// %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
// store i32 %Elt0, i32* %Ptr0, align 4
// br label %else
//
// else:
// %Mask1 = extractelement <16 x i1> %Mask, i32 1
// br i1 %Mask1, label %cond.store1, label %else2
//
// cond.store1:
// %Elt1 = extractelement <16 x i32> %Src, i32 1
// %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
// store i32 %Elt1, i32* %Ptr1, align 4
// br label %else2
//   . . .
scalarizeMaskedScatter(const DataLayout & DL,bool HasBranchDivergence,CallInst * CI,DomTreeUpdater * DTU,bool & ModifiedDT)606 static void scalarizeMaskedScatter(const DataLayout &DL,
607                                    bool HasBranchDivergence, CallInst *CI,
608                                    DomTreeUpdater *DTU, bool &ModifiedDT) {
609   Value *Src = CI->getArgOperand(0);
610   Value *Ptrs = CI->getArgOperand(1);
611   Value *Alignment = CI->getArgOperand(2);
612   Value *Mask = CI->getArgOperand(3);
613 
614   auto *SrcFVTy = cast<FixedVectorType>(Src->getType());
615 
616   assert(
617       isa<VectorType>(Ptrs->getType()) &&
618       isa<PointerType>(cast<VectorType>(Ptrs->getType())->getElementType()) &&
619       "Vector of pointers is expected in masked scatter intrinsic");
620 
621   IRBuilder<> Builder(CI->getContext());
622   Instruction *InsertPt = CI;
623   Builder.SetInsertPoint(InsertPt);
624   Builder.SetCurrentDebugLocation(CI->getDebugLoc());
625 
626   MaybeAlign AlignVal = cast<ConstantInt>(Alignment)->getMaybeAlignValue();
627   unsigned VectorWidth = SrcFVTy->getNumElements();
628 
629   // Shorten the way if the mask is a vector of constants.
630   if (isConstantIntVector(Mask)) {
631     for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
632       if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
633         continue;
634       Value *OneElt =
635           Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
636       Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
637       Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);
638     }
639     CI->eraseFromParent();
640     return;
641   }
642 
643   // If the mask is not v1i1, use scalar bit test operations. This generates
644   // better results on X86 at least.
645   Value *SclrMask = nullptr;
646   if (VectorWidth != 1 && !HasBranchDivergence) {
647     Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
648     SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
649   }
650 
651   for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
652     // Fill the "else" block, created in the previous iteration
653     //
654     //  %Mask1 = and i16 %scalar_mask, i32 1 << Idx
655     //  %cond = icmp ne i16 %mask_1, 0
656     //  br i1 %Mask1, label %cond.store, label %else
657     //
658     // On GPUs, use
659     //  %cond = extrectelement %mask, Idx
660     // instead
661     Value *Predicate;
662     if (SclrMask != nullptr) {
663       Value *Mask = Builder.getInt(APInt::getOneBitSet(
664           VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
665       Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
666                                        Builder.getIntN(VectorWidth, 0));
667     } else {
668       Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
669     }
670 
671     // Create "cond" block
672     //
673     //  %Elt1 = extractelement <16 x i32> %Src, i32 1
674     //  %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
675     //  %store i32 %Elt1, i32* %Ptr1
676     //
677     Instruction *ThenTerm =
678         SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
679                                   /*BranchWeights=*/nullptr, DTU);
680 
681     BasicBlock *CondBlock = ThenTerm->getParent();
682     CondBlock->setName("cond.store");
683 
684     Builder.SetInsertPoint(CondBlock->getTerminator());
685     Value *OneElt = Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
686     Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
687     Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);
688 
689     // Create "else" block, fill it in the next iteration
690     BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
691     NewIfBlock->setName("else");
692 
693     Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
694   }
695   CI->eraseFromParent();
696 
697   ModifiedDT = true;
698 }
699 
scalarizeMaskedExpandLoad(const DataLayout & DL,bool HasBranchDivergence,CallInst * CI,DomTreeUpdater * DTU,bool & ModifiedDT)700 static void scalarizeMaskedExpandLoad(const DataLayout &DL,
701                                       bool HasBranchDivergence, CallInst *CI,
702                                       DomTreeUpdater *DTU, bool &ModifiedDT) {
703   Value *Ptr = CI->getArgOperand(0);
704   Value *Mask = CI->getArgOperand(1);
705   Value *PassThru = CI->getArgOperand(2);
706   Align Alignment = CI->getParamAlign(0).valueOrOne();
707 
708   auto *VecType = cast<FixedVectorType>(CI->getType());
709 
710   Type *EltTy = VecType->getElementType();
711 
712   IRBuilder<> Builder(CI->getContext());
713   Instruction *InsertPt = CI;
714   BasicBlock *IfBlock = CI->getParent();
715 
716   Builder.SetInsertPoint(InsertPt);
717   Builder.SetCurrentDebugLocation(CI->getDebugLoc());
718 
719   unsigned VectorWidth = VecType->getNumElements();
720 
721   // The result vector
722   Value *VResult = PassThru;
723 
724   // Adjust alignment for the scalar instruction.
725   const Align AdjustedAlignment =
726       commonAlignment(Alignment, EltTy->getPrimitiveSizeInBits() / 8);
727 
728   // Shorten the way if the mask is a vector of constants.
729   // Create a build_vector pattern, with loads/poisons as necessary and then
730   // shuffle blend with the pass through value.
731   if (isConstantIntVector(Mask)) {
732     unsigned MemIndex = 0;
733     VResult = PoisonValue::get(VecType);
734     SmallVector<int, 16> ShuffleMask(VectorWidth, PoisonMaskElem);
735     for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
736       Value *InsertElt;
737       if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue()) {
738         InsertElt = PoisonValue::get(EltTy);
739         ShuffleMask[Idx] = Idx + VectorWidth;
740       } else {
741         Value *NewPtr =
742             Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, MemIndex);
743         InsertElt = Builder.CreateAlignedLoad(EltTy, NewPtr, AdjustedAlignment,
744                                               "Load" + Twine(Idx));
745         ShuffleMask[Idx] = Idx;
746         ++MemIndex;
747       }
748       VResult = Builder.CreateInsertElement(VResult, InsertElt, Idx,
749                                             "Res" + Twine(Idx));
750     }
751     VResult = Builder.CreateShuffleVector(VResult, PassThru, ShuffleMask);
752     CI->replaceAllUsesWith(VResult);
753     CI->eraseFromParent();
754     return;
755   }
756 
757   // If the mask is not v1i1, use scalar bit test operations. This generates
758   // better results on X86 at least. However, don't do this on GPUs or other
759   // machines with branch divergence, as there, each i1 takes up a register.
760   Value *SclrMask = nullptr;
761   if (VectorWidth != 1 && !HasBranchDivergence) {
762     Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
763     SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
764   }
765 
766   for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
767     // Fill the "else" block, created in the previous iteration
768     //
769     //  %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else,
770     //  %else ] %mask_1 = extractelement <16 x i1> %mask, i32 Idx br i1 %mask_1,
771     //  label %cond.load, label %else
772     //
773     // On GPUs, use
774     //  %cond = extrectelement %mask, Idx
775     // instead
776 
777     Value *Predicate;
778     if (SclrMask != nullptr) {
779       Value *Mask = Builder.getInt(APInt::getOneBitSet(
780           VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
781       Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
782                                        Builder.getIntN(VectorWidth, 0));
783     } else {
784       Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
785     }
786 
787     // Create "cond" block
788     //
789     //  %EltAddr = getelementptr i32* %1, i32 0
790     //  %Elt = load i32* %EltAddr
791     //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
792     //
793     Instruction *ThenTerm =
794         SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
795                                   /*BranchWeights=*/nullptr, DTU);
796 
797     BasicBlock *CondBlock = ThenTerm->getParent();
798     CondBlock->setName("cond.load");
799 
800     Builder.SetInsertPoint(CondBlock->getTerminator());
801     LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Ptr, AdjustedAlignment);
802     Value *NewVResult = Builder.CreateInsertElement(VResult, Load, Idx);
803 
804     // Move the pointer if there are more blocks to come.
805     Value *NewPtr;
806     if ((Idx + 1) != VectorWidth)
807       NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, 1);
808 
809     // Create "else" block, fill it in the next iteration
810     BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
811     NewIfBlock->setName("else");
812     BasicBlock *PrevIfBlock = IfBlock;
813     IfBlock = NewIfBlock;
814 
815     // Create the phi to join the new and previous value.
816     Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
817     PHINode *ResultPhi = Builder.CreatePHI(VecType, 2, "res.phi.else");
818     ResultPhi->addIncoming(NewVResult, CondBlock);
819     ResultPhi->addIncoming(VResult, PrevIfBlock);
820     VResult = ResultPhi;
821 
822     // Add a PHI for the pointer if this isn't the last iteration.
823     if ((Idx + 1) != VectorWidth) {
824       PHINode *PtrPhi = Builder.CreatePHI(Ptr->getType(), 2, "ptr.phi.else");
825       PtrPhi->addIncoming(NewPtr, CondBlock);
826       PtrPhi->addIncoming(Ptr, PrevIfBlock);
827       Ptr = PtrPhi;
828     }
829   }
830 
831   CI->replaceAllUsesWith(VResult);
832   CI->eraseFromParent();
833 
834   ModifiedDT = true;
835 }
836 
scalarizeMaskedCompressStore(const DataLayout & DL,bool HasBranchDivergence,CallInst * CI,DomTreeUpdater * DTU,bool & ModifiedDT)837 static void scalarizeMaskedCompressStore(const DataLayout &DL,
838                                          bool HasBranchDivergence, CallInst *CI,
839                                          DomTreeUpdater *DTU,
840                                          bool &ModifiedDT) {
841   Value *Src = CI->getArgOperand(0);
842   Value *Ptr = CI->getArgOperand(1);
843   Value *Mask = CI->getArgOperand(2);
844   Align Alignment = CI->getParamAlign(1).valueOrOne();
845 
846   auto *VecType = cast<FixedVectorType>(Src->getType());
847 
848   IRBuilder<> Builder(CI->getContext());
849   Instruction *InsertPt = CI;
850   BasicBlock *IfBlock = CI->getParent();
851 
852   Builder.SetInsertPoint(InsertPt);
853   Builder.SetCurrentDebugLocation(CI->getDebugLoc());
854 
855   Type *EltTy = VecType->getElementType();
856 
857   // Adjust alignment for the scalar instruction.
858   const Align AdjustedAlignment =
859       commonAlignment(Alignment, EltTy->getPrimitiveSizeInBits() / 8);
860 
861   unsigned VectorWidth = VecType->getNumElements();
862 
863   // Shorten the way if the mask is a vector of constants.
864   if (isConstantIntVector(Mask)) {
865     unsigned MemIndex = 0;
866     for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
867       if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
868         continue;
869       Value *OneElt =
870           Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
871       Value *NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, MemIndex);
872       Builder.CreateAlignedStore(OneElt, NewPtr, AdjustedAlignment);
873       ++MemIndex;
874     }
875     CI->eraseFromParent();
876     return;
877   }
878 
879   // If the mask is not v1i1, use scalar bit test operations. This generates
880   // better results on X86 at least. However, don't do this on GPUs or other
881   // machines with branch divergence, as there, each i1 takes up a register.
882   Value *SclrMask = nullptr;
883   if (VectorWidth != 1 && !HasBranchDivergence) {
884     Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
885     SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
886   }
887 
888   for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
889     // Fill the "else" block, created in the previous iteration
890     //
891     //  %mask_1 = extractelement <16 x i1> %mask, i32 Idx
892     //  br i1 %mask_1, label %cond.store, label %else
893     //
894     // On GPUs, use
895     //  %cond = extrectelement %mask, Idx
896     // instead
897     Value *Predicate;
898     if (SclrMask != nullptr) {
899       Value *Mask = Builder.getInt(APInt::getOneBitSet(
900           VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
901       Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
902                                        Builder.getIntN(VectorWidth, 0));
903     } else {
904       Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
905     }
906 
907     // Create "cond" block
908     //
909     //  %OneElt = extractelement <16 x i32> %Src, i32 Idx
910     //  %EltAddr = getelementptr i32* %1, i32 0
911     //  %store i32 %OneElt, i32* %EltAddr
912     //
913     Instruction *ThenTerm =
914         SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
915                                   /*BranchWeights=*/nullptr, DTU);
916 
917     BasicBlock *CondBlock = ThenTerm->getParent();
918     CondBlock->setName("cond.store");
919 
920     Builder.SetInsertPoint(CondBlock->getTerminator());
921     Value *OneElt = Builder.CreateExtractElement(Src, Idx);
922     Builder.CreateAlignedStore(OneElt, Ptr, AdjustedAlignment);
923 
924     // Move the pointer if there are more blocks to come.
925     Value *NewPtr;
926     if ((Idx + 1) != VectorWidth)
927       NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, 1);
928 
929     // Create "else" block, fill it in the next iteration
930     BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
931     NewIfBlock->setName("else");
932     BasicBlock *PrevIfBlock = IfBlock;
933     IfBlock = NewIfBlock;
934 
935     Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
936 
937     // Add a PHI for the pointer if this isn't the last iteration.
938     if ((Idx + 1) != VectorWidth) {
939       PHINode *PtrPhi = Builder.CreatePHI(Ptr->getType(), 2, "ptr.phi.else");
940       PtrPhi->addIncoming(NewPtr, CondBlock);
941       PtrPhi->addIncoming(Ptr, PrevIfBlock);
942       Ptr = PtrPhi;
943     }
944   }
945   CI->eraseFromParent();
946 
947   ModifiedDT = true;
948 }
949 
// Scalarize an experimental.vector.histogram.* intrinsic into per-lane
// load/update/store sequences: a straight-line sequence for a constant mask,
// or one conditional block per lane for a variable mask.
// NOTE(review): the DL parameter is currently unused here; it is kept so the
// signature parallels the other scalarize* helpers in this file.
static void scalarizeMaskedVectorHistogram(const DataLayout &DL, CallInst *CI,
                                           DomTreeUpdater *DTU,
                                           bool &ModifiedDT) {
  // If we extend histogram to return a result someday (like the updated vector)
  // then we'll need to support it here.
  assert(CI->getType()->isVoidTy() && "Histogram with non-void return.");
  Value *Ptrs = CI->getArgOperand(0);
  Value *Inc = CI->getArgOperand(1);
  Value *Mask = CI->getArgOperand(2);

  auto *AddrType = cast<FixedVectorType>(Ptrs->getType());
  Type *EltTy = Inc->getType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  Builder.SetInsertPoint(InsertPt);

  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  // FIXME: Do we need to add an alignment parameter to the intrinsic?
  unsigned VectorWidth = AddrType->getNumElements();
  // Build the scalar update for one bucket: combines the loaded value with
  // Inc according to which histogram flavor the intrinsic is.
  auto CreateHistogramUpdateValue = [&](IntrinsicInst *CI, Value *Load,
                                        Value *Inc) -> Value * {
    Value *UpdateOp;
    switch (CI->getIntrinsicID()) {
    case Intrinsic::experimental_vector_histogram_add:
      UpdateOp = Builder.CreateAdd(Load, Inc);
      break;
    case Intrinsic::experimental_vector_histogram_uadd_sat:
      UpdateOp =
          Builder.CreateIntrinsic(Intrinsic::uadd_sat, {EltTy}, {Load, Inc});
      break;
    case Intrinsic::experimental_vector_histogram_umin:
      UpdateOp = Builder.CreateIntrinsic(Intrinsic::umin, {EltTy}, {Load, Inc});
      break;
    case Intrinsic::experimental_vector_histogram_umax:
      UpdateOp = Builder.CreateIntrinsic(Intrinsic::umax, {EltTy}, {Load, Inc});
      break;

    default:
      llvm_unreachable("Unexpected histogram intrinsic");
    }
    return UpdateOp;
  };

  // Shorten the way if the mask is a vector of constants: emit an
  // unconditional load/update/store only for the enabled lanes.
  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
      LoadInst *Load = Builder.CreateLoad(EltTy, Ptr, "Load" + Twine(Idx));
      Value *Update =
          CreateHistogramUpdateValue(cast<IntrinsicInst>(CI), Load, Inc);
      Builder.CreateStore(Update, Ptr);
    }
    CI->eraseFromParent();
    return;
  }

  // Variable mask: guard each lane's update with its own conditional block.
  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    Value *Predicate =
        Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));

    Instruction *ThenTerm =
        SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
                                  /*BranchWeights=*/nullptr, DTU);

    BasicBlock *CondBlock = ThenTerm->getParent();
    CondBlock->setName("cond.histogram.update");

    Builder.SetInsertPoint(CondBlock->getTerminator());
    Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
    LoadInst *Load = Builder.CreateLoad(EltTy, Ptr, "Load" + Twine(Idx));
    Value *UpdateOp =
        CreateHistogramUpdateValue(cast<IntrinsicInst>(CI), Load, Inc);
    Builder.CreateStore(UpdateOp, Ptr);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
    NewIfBlock->setName("else");
    Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
  }

  CI->eraseFromParent();
  ModifiedDT = true;
}
1037 
runImpl(Function & F,const TargetTransformInfo & TTI,DominatorTree * DT)1038 static bool runImpl(Function &F, const TargetTransformInfo &TTI,
1039                     DominatorTree *DT) {
1040   std::optional<DomTreeUpdater> DTU;
1041   if (DT)
1042     DTU.emplace(DT, DomTreeUpdater::UpdateStrategy::Lazy);
1043 
1044   bool EverMadeChange = false;
1045   bool MadeChange = true;
1046   auto &DL = F.getDataLayout();
1047   bool HasBranchDivergence = TTI.hasBranchDivergence(&F);
1048   while (MadeChange) {
1049     MadeChange = false;
1050     for (BasicBlock &BB : llvm::make_early_inc_range(F)) {
1051       bool ModifiedDTOnIteration = false;
1052       MadeChange |= optimizeBlock(BB, ModifiedDTOnIteration, TTI, DL,
1053                                   HasBranchDivergence, DTU ? &*DTU : nullptr);
1054 
1055       // Restart BB iteration if the dominator tree of the Function was changed
1056       if (ModifiedDTOnIteration)
1057         break;
1058     }
1059 
1060     EverMadeChange |= MadeChange;
1061   }
1062   return EverMadeChange;
1063 }
1064 
runOnFunction(Function & F)1065 bool ScalarizeMaskedMemIntrinLegacyPass::runOnFunction(Function &F) {
1066   auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
1067   DominatorTree *DT = nullptr;
1068   if (auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>())
1069     DT = &DTWP->getDomTree();
1070   return runImpl(F, TTI, DT);
1071 }
1072 
1073 PreservedAnalyses
run(Function & F,FunctionAnalysisManager & AM)1074 ScalarizeMaskedMemIntrinPass::run(Function &F, FunctionAnalysisManager &AM) {
1075   auto &TTI = AM.getResult<TargetIRAnalysis>(F);
1076   auto *DT = AM.getCachedResult<DominatorTreeAnalysis>(F);
1077   if (!runImpl(F, TTI, DT))
1078     return PreservedAnalyses::all();
1079   PreservedAnalyses PA;
1080   PA.preserve<TargetIRAnalysis>();
1081   PA.preserve<DominatorTreeAnalysis>();
1082   return PA;
1083 }
1084 
optimizeBlock(BasicBlock & BB,bool & ModifiedDT,const TargetTransformInfo & TTI,const DataLayout & DL,bool HasBranchDivergence,DomTreeUpdater * DTU)1085 static bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT,
1086                           const TargetTransformInfo &TTI, const DataLayout &DL,
1087                           bool HasBranchDivergence, DomTreeUpdater *DTU) {
1088   bool MadeChange = false;
1089 
1090   BasicBlock::iterator CurInstIterator = BB.begin();
1091   while (CurInstIterator != BB.end()) {
1092     if (CallInst *CI = dyn_cast<CallInst>(&*CurInstIterator++))
1093       MadeChange |=
1094           optimizeCallInst(CI, ModifiedDT, TTI, DL, HasBranchDivergence, DTU);
1095     if (ModifiedDT)
1096       return true;
1097   }
1098 
1099   return MadeChange;
1100 }
1101 
// Dispatch on the masked-memory intrinsic kind: if the target reports the
// operation as legal (i.e. it has native support), leave the call alone;
// otherwise scalarize it with the matching helper above. Returns true when
// the call was rewritten; ModifiedDT is set if the CFG changed.
static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
                             const TargetTransformInfo &TTI,
                             const DataLayout &DL, bool HasBranchDivergence,
                             DomTreeUpdater *DTU) {
  IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI);
  if (II) {
    // The scalarization code below does not work for scalable vectors.
    if (isa<ScalableVectorType>(II->getType()) ||
        any_of(II->args(),
               [](Value *V) { return isa<ScalableVectorType>(V->getType()); }))
      return false;
    switch (II->getIntrinsicID()) {
    default:
      break;
    case Intrinsic::experimental_vector_histogram_add:
    case Intrinsic::experimental_vector_histogram_uadd_sat:
    case Intrinsic::experimental_vector_histogram_umin:
    case Intrinsic::experimental_vector_histogram_umax:
      // Args: (0) vector of bucket pointers, (1) scalar increment, (2) mask.
      if (TTI.isLegalMaskedVectorHistogram(CI->getArgOperand(0)->getType(),
                                           CI->getArgOperand(1)->getType()))
        return false;
      scalarizeMaskedVectorHistogram(DL, CI, DTU, ModifiedDT);
      return true;
    case Intrinsic::masked_load:
      // Scalarize unsupported vector masked load
      // Args: (0) pointer, (1) alignment immediate, (2) mask, (3) passthru.
      if (TTI.isLegalMaskedLoad(
              CI->getType(),
              cast<ConstantInt>(CI->getArgOperand(1))->getAlignValue(),
              cast<PointerType>(CI->getArgOperand(0)->getType())
                  ->getAddressSpace()))
        return false;
      scalarizeMaskedLoad(DL, HasBranchDivergence, CI, DTU, ModifiedDT);
      return true;
    case Intrinsic::masked_store:
      // Args: (0) value, (1) pointer, (2) alignment immediate, (3) mask.
      if (TTI.isLegalMaskedStore(
              CI->getArgOperand(0)->getType(),
              cast<ConstantInt>(CI->getArgOperand(2))->getAlignValue(),
              cast<PointerType>(CI->getArgOperand(1)->getType())
                  ->getAddressSpace()))
        return false;
      scalarizeMaskedStore(DL, HasBranchDivergence, CI, DTU, ModifiedDT);
      return true;
    case Intrinsic::masked_gather: {
      // Args: (0) vector of pointers, (1) alignment immediate, (2) mask,
      // (3) passthru. A legal gather may still be force-scalarized by the
      // target (e.g. when emulation beats the native instruction).
      MaybeAlign MA =
          cast<ConstantInt>(CI->getArgOperand(1))->getMaybeAlignValue();
      Type *LoadTy = CI->getType();
      Align Alignment = DL.getValueOrABITypeAlignment(MA,
                                                      LoadTy->getScalarType());
      if (TTI.isLegalMaskedGather(LoadTy, Alignment) &&
          !TTI.forceScalarizeMaskedGather(cast<VectorType>(LoadTy), Alignment))
        return false;
      scalarizeMaskedGather(DL, HasBranchDivergence, CI, DTU, ModifiedDT);
      return true;
    }
    case Intrinsic::masked_scatter: {
      // Args: (0) value, (1) vector of pointers, (2) alignment immediate,
      // (3) mask.
      MaybeAlign MA =
          cast<ConstantInt>(CI->getArgOperand(2))->getMaybeAlignValue();
      Type *StoreTy = CI->getArgOperand(0)->getType();
      Align Alignment = DL.getValueOrABITypeAlignment(MA,
                                                      StoreTy->getScalarType());
      if (TTI.isLegalMaskedScatter(StoreTy, Alignment) &&
          !TTI.forceScalarizeMaskedScatter(cast<VectorType>(StoreTy),
                                           Alignment))
        return false;
      scalarizeMaskedScatter(DL, HasBranchDivergence, CI, DTU, ModifiedDT);
      return true;
    }
    case Intrinsic::masked_expandload:
      // Alignment comes from the pointer parameter attribute (param 0).
      if (TTI.isLegalMaskedExpandLoad(
              CI->getType(),
              CI->getAttributes().getParamAttrs(0).getAlignment().valueOrOne()))
        return false;
      scalarizeMaskedExpandLoad(DL, HasBranchDivergence, CI, DTU, ModifiedDT);
      return true;
    case Intrinsic::masked_compressstore:
      // Alignment comes from the pointer parameter attribute (param 1).
      if (TTI.isLegalMaskedCompressStore(
              CI->getArgOperand(0)->getType(),
              CI->getAttributes().getParamAttrs(1).getAlignment().valueOrOne()))
        return false;
      scalarizeMaskedCompressStore(DL, HasBranchDivergence, CI, DTU,
                                   ModifiedDT);
      return true;
    }
  }

  return false;
}
1189