xref: /freebsd/contrib/llvm-project/llvm/lib/Target/AArch64/SVEIntrinsicOpts.cpp (revision 5e801ac66d24704442eba426ed13c3effb8a34e7)
1 //===----- SVEIntrinsicOpts - SVE ACLE Intrinsics Opts --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Performs general IR level optimizations on SVE intrinsics.
10 //
11 // This pass performs the following optimizations:
12 //
13 // - removes unnecessary ptrue intrinsics (llvm.aarch64.sve.ptrue), e.g:
14 //     %1 = @llvm.aarch64.sve.ptrue.nxv4i1(i32 31)
15 //     %2 = @llvm.aarch64.sve.ptrue.nxv8i1(i32 31)
16 //     ; (%1 can be replaced with a reinterpret of %2)
17 //
18 // - optimizes ptest intrinsics where the operands are being needlessly
19 //   converted to and from svbool_t.
20 //
21 //===----------------------------------------------------------------------===//
22 
23 #include "AArch64.h"
24 #include "Utils/AArch64BaseInfo.h"
25 #include "llvm/ADT/PostOrderIterator.h"
26 #include "llvm/ADT/SetVector.h"
27 #include "llvm/IR/Constants.h"
28 #include "llvm/IR/Dominators.h"
29 #include "llvm/IR/IRBuilder.h"
30 #include "llvm/IR/Instructions.h"
31 #include "llvm/IR/IntrinsicInst.h"
32 #include "llvm/IR/IntrinsicsAArch64.h"
33 #include "llvm/IR/LLVMContext.h"
34 #include "llvm/IR/PatternMatch.h"
35 #include "llvm/InitializePasses.h"
36 #include "llvm/Support/Debug.h"
37 
38 using namespace llvm;
39 using namespace llvm::PatternMatch;
40 
41 #define DEBUG_TYPE "aarch64-sve-intrinsic-opts"
42 
43 namespace llvm {
44 void initializeSVEIntrinsicOptsPass(PassRegistry &);
45 }
46 
47 namespace {
48 struct SVEIntrinsicOpts : public ModulePass {
49   static char ID; // Pass identification, replacement for typeid
50   SVEIntrinsicOpts() : ModulePass(ID) {
51     initializeSVEIntrinsicOptsPass(*PassRegistry::getPassRegistry());
52   }
53 
54   bool runOnModule(Module &M) override;
55   void getAnalysisUsage(AnalysisUsage &AU) const override;
56 
57 private:
58   bool coalescePTrueIntrinsicCalls(BasicBlock &BB,
59                                    SmallSetVector<IntrinsicInst *, 4> &PTrues);
60   bool optimizePTrueIntrinsicCalls(SmallSetVector<Function *, 4> &Functions);
61   bool optimizePredicateStore(Instruction *I);
62   bool optimizePredicateLoad(Instruction *I);
63 
64   bool optimizeInstructions(SmallSetVector<Function *, 4> &Functions);
65 
66   /// Operates at the function-scope. I.e., optimizations are applied local to
67   /// the functions themselves.
68   bool optimizeFunctions(SmallSetVector<Function *, 4> &Functions);
69 };
70 } // end anonymous namespace
71 
72 void SVEIntrinsicOpts::getAnalysisUsage(AnalysisUsage &AU) const {
73   AU.addRequired<DominatorTreeWrapperPass>();
74   AU.setPreservesCFG();
75 }
76 
77 char SVEIntrinsicOpts::ID = 0;
78 static const char *name = "SVE intrinsics optimizations";
79 INITIALIZE_PASS_BEGIN(SVEIntrinsicOpts, DEBUG_TYPE, name, false, false)
80 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass);
81 INITIALIZE_PASS_END(SVEIntrinsicOpts, DEBUG_TYPE, name, false, false)
82 
83 ModulePass *llvm::createSVEIntrinsicOptsPass() {
84   return new SVEIntrinsicOpts();
85 }
86 
87 /// Checks if a ptrue intrinsic call is promoted. The act of promoting a
88 /// ptrue will introduce zeroing. For example:
89 ///
90 ///     %1 = <vscale x 4 x i1> call @llvm.aarch64.sve.ptrue.nxv4i1(i32 31)
91 ///     %2 = <vscale x 16 x i1> call @llvm.aarch64.sve.convert.to.svbool.nxv4i1(<vscale x 4 x i1> %1)
92 ///     %3 = <vscale x 8 x i1> call @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %2)
93 ///
94 /// %1 is promoted, because it is converted:
95 ///
96 ///     <vscale x 4 x i1> => <vscale x 16 x i1> => <vscale x 8 x i1>
97 ///
98 /// via a sequence of the SVE reinterpret intrinsics convert.{to,from}.svbool.
99 static bool isPTruePromoted(IntrinsicInst *PTrue) {
100   // Find all users of this intrinsic that are calls to convert-to-svbool
101   // reinterpret intrinsics.
102   SmallVector<IntrinsicInst *, 4> ConvertToUses;
103   for (User *User : PTrue->users()) {
104     if (match(User, m_Intrinsic<Intrinsic::aarch64_sve_convert_to_svbool>())) {
105       ConvertToUses.push_back(cast<IntrinsicInst>(User));
106     }
107   }
108 
109   // If no such calls were found, this is ptrue is not promoted.
110   if (ConvertToUses.empty())
111     return false;
112 
113   // Otherwise, try to find users of the convert-to-svbool intrinsics that are
114   // calls to the convert-from-svbool intrinsic, and would result in some lanes
115   // being zeroed.
116   const auto *PTrueVTy = cast<ScalableVectorType>(PTrue->getType());
117   for (IntrinsicInst *ConvertToUse : ConvertToUses) {
118     for (User *User : ConvertToUse->users()) {
119       auto *IntrUser = dyn_cast<IntrinsicInst>(User);
120       if (IntrUser && IntrUser->getIntrinsicID() ==
121                           Intrinsic::aarch64_sve_convert_from_svbool) {
122         const auto *IntrUserVTy = cast<ScalableVectorType>(IntrUser->getType());
123 
124         // Would some lanes become zeroed by the conversion?
125         if (IntrUserVTy->getElementCount().getKnownMinValue() >
126             PTrueVTy->getElementCount().getKnownMinValue())
127           // This is a promoted ptrue.
128           return true;
129       }
130     }
131   }
132 
133   // If no matching calls were found, this is not a promoted ptrue.
134   return false;
135 }
136 
137 /// Attempts to coalesce ptrues in a basic block.
138 bool SVEIntrinsicOpts::coalescePTrueIntrinsicCalls(
139     BasicBlock &BB, SmallSetVector<IntrinsicInst *, 4> &PTrues) {
140   if (PTrues.size() <= 1)
141     return false;
142 
143   // Find the ptrue with the most lanes.
144   auto *MostEncompassingPTrue = *std::max_element(
145       PTrues.begin(), PTrues.end(), [](auto *PTrue1, auto *PTrue2) {
146         auto *PTrue1VTy = cast<ScalableVectorType>(PTrue1->getType());
147         auto *PTrue2VTy = cast<ScalableVectorType>(PTrue2->getType());
148         return PTrue1VTy->getElementCount().getKnownMinValue() <
149                PTrue2VTy->getElementCount().getKnownMinValue();
150       });
151 
152   // Remove the most encompassing ptrue, as well as any promoted ptrues, leaving
153   // behind only the ptrues to be coalesced.
154   PTrues.remove(MostEncompassingPTrue);
155   PTrues.remove_if([](auto *PTrue) { return isPTruePromoted(PTrue); });
156 
157   // Hoist MostEncompassingPTrue to the start of the basic block. It is always
158   // safe to do this, since ptrue intrinsic calls are guaranteed to have no
159   // predecessors.
160   MostEncompassingPTrue->moveBefore(BB, BB.getFirstInsertionPt());
161 
162   LLVMContext &Ctx = BB.getContext();
163   IRBuilder<> Builder(Ctx);
164   Builder.SetInsertPoint(&BB, ++MostEncompassingPTrue->getIterator());
165 
166   auto *MostEncompassingPTrueVTy =
167       cast<VectorType>(MostEncompassingPTrue->getType());
168   auto *ConvertToSVBool = Builder.CreateIntrinsic(
169       Intrinsic::aarch64_sve_convert_to_svbool, {MostEncompassingPTrueVTy},
170       {MostEncompassingPTrue});
171 
172   bool ConvertFromCreated = false;
173   for (auto *PTrue : PTrues) {
174     auto *PTrueVTy = cast<VectorType>(PTrue->getType());
175 
176     // Only create the converts if the types are not already the same, otherwise
177     // just use the most encompassing ptrue.
178     if (MostEncompassingPTrueVTy != PTrueVTy) {
179       ConvertFromCreated = true;
180 
181       Builder.SetInsertPoint(&BB, ++ConvertToSVBool->getIterator());
182       auto *ConvertFromSVBool =
183           Builder.CreateIntrinsic(Intrinsic::aarch64_sve_convert_from_svbool,
184                                   {PTrueVTy}, {ConvertToSVBool});
185       PTrue->replaceAllUsesWith(ConvertFromSVBool);
186     } else
187       PTrue->replaceAllUsesWith(MostEncompassingPTrue);
188 
189     PTrue->eraseFromParent();
190   }
191 
192   // We never used the ConvertTo so remove it
193   if (!ConvertFromCreated)
194     ConvertToSVBool->eraseFromParent();
195 
196   return true;
197 }
198 
199 /// The goal of this function is to remove redundant calls to the SVE ptrue
200 /// intrinsic in each basic block within the given functions.
201 ///
202 /// SVE ptrues have two representations in LLVM IR:
203 /// - a logical representation -- an arbitrary-width scalable vector of i1s,
204 ///   i.e. <vscale x N x i1>.
205 /// - a physical representation (svbool, <vscale x 16 x i1>) -- a 16-element
206 ///   scalable vector of i1s, i.e. <vscale x 16 x i1>.
207 ///
208 /// The SVE ptrue intrinsic is used to create a logical representation of an SVE
209 /// predicate. Suppose that we have two SVE ptrue intrinsic calls: P1 and P2. If
210 /// P1 creates a logical SVE predicate that is at least as wide as the logical
211 /// SVE predicate created by P2, then all of the bits that are true in the
212 /// physical representation of P2 are necessarily also true in the physical
213 /// representation of P1. P1 'encompasses' P2, therefore, the intrinsic call to
214 /// P2 is redundant and can be replaced by an SVE reinterpret of P1 via
215 /// convert.{to,from}.svbool.
216 ///
217 /// Currently, this pass only coalesces calls to SVE ptrue intrinsics
218 /// if they match the following conditions:
219 ///
220 /// - the call to the intrinsic uses either the SV_ALL or SV_POW2 patterns.
221 ///   SV_ALL indicates that all bits of the predicate vector are to be set to
222 ///   true. SV_POW2 indicates that all bits of the predicate vector up to the
223 ///   largest power-of-two are to be set to true.
224 /// - the result of the call to the intrinsic is not promoted to a wider
225 ///   predicate. In this case, keeping the extra ptrue leads to better codegen
226 ///   -- coalescing here would create an irreducible chain of SVE reinterprets
227 ///   via convert.{to,from}.svbool.
228 ///
229 /// EXAMPLE:
230 ///
231 ///     %1 = <vscale x 8 x i1> ptrue(i32 SV_ALL)
232 ///     ; Logical:  <1, 1, 1, 1, 1, 1, 1, 1>
233 ///     ; Physical: <1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0>
234 ///     ...
235 ///
236 ///     %2 = <vscale x 4 x i1> ptrue(i32 SV_ALL)
237 ///     ; Logical:  <1, 1, 1, 1>
238 ///     ; Physical: <1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0>
239 ///     ...
240 ///
241 /// Here, %2 can be replaced by an SVE reinterpret of %1, giving, for instance:
242 ///
243 ///     %1 = <vscale x 8 x i1> ptrue(i32 i31)
244 ///     %2 = <vscale x 16 x i1> convert.to.svbool(<vscale x 8 x i1> %1)
245 ///     %3 = <vscale x 4 x i1> convert.from.svbool(<vscale x 16 x i1> %2)
246 ///
247 bool SVEIntrinsicOpts::optimizePTrueIntrinsicCalls(
248     SmallSetVector<Function *, 4> &Functions) {
249   bool Changed = false;
250 
251   for (auto *F : Functions) {
252     for (auto &BB : *F) {
253       SmallSetVector<IntrinsicInst *, 4> SVAllPTrues;
254       SmallSetVector<IntrinsicInst *, 4> SVPow2PTrues;
255 
256       // For each basic block, collect the used ptrues and try to coalesce them.
257       for (Instruction &I : BB) {
258         if (I.use_empty())
259           continue;
260 
261         auto *IntrI = dyn_cast<IntrinsicInst>(&I);
262         if (!IntrI || IntrI->getIntrinsicID() != Intrinsic::aarch64_sve_ptrue)
263           continue;
264 
265         const auto PTruePattern =
266             cast<ConstantInt>(IntrI->getOperand(0))->getZExtValue();
267 
268         if (PTruePattern == AArch64SVEPredPattern::all)
269           SVAllPTrues.insert(IntrI);
270         if (PTruePattern == AArch64SVEPredPattern::pow2)
271           SVPow2PTrues.insert(IntrI);
272       }
273 
274       Changed |= coalescePTrueIntrinsicCalls(BB, SVAllPTrues);
275       Changed |= coalescePTrueIntrinsicCalls(BB, SVPow2PTrues);
276     }
277   }
278 
279   return Changed;
280 }
281 
282 // This is done in SVEIntrinsicOpts rather than InstCombine so that we introduce
283 // scalable stores as late as possible
284 bool SVEIntrinsicOpts::optimizePredicateStore(Instruction *I) {
285   auto *F = I->getFunction();
286   auto Attr = F->getFnAttribute(Attribute::VScaleRange);
287   if (!Attr.isValid())
288     return false;
289 
290   unsigned MinVScale, MaxVScale;
291   std::tie(MinVScale, MaxVScale) = Attr.getVScaleRangeArgs();
292   // The transform needs to know the exact runtime length of scalable vectors
293   if (MinVScale != MaxVScale || MinVScale == 0)
294     return false;
295 
296   auto *PredType =
297       ScalableVectorType::get(Type::getInt1Ty(I->getContext()), 16);
298   auto *FixedPredType =
299       FixedVectorType::get(Type::getInt8Ty(I->getContext()), MinVScale * 2);
300 
301   // If we have a store..
302   auto *Store = dyn_cast<StoreInst>(I);
303   if (!Store || !Store->isSimple())
304     return false;
305 
306   // ..that is storing a predicate vector sized worth of bits..
307   if (Store->getOperand(0)->getType() != FixedPredType)
308     return false;
309 
310   // ..where the value stored comes from a vector extract..
311   auto *IntrI = dyn_cast<IntrinsicInst>(Store->getOperand(0));
312   if (!IntrI ||
313       IntrI->getIntrinsicID() != Intrinsic::experimental_vector_extract)
314     return false;
315 
316   // ..that is extracting from index 0..
317   if (!cast<ConstantInt>(IntrI->getOperand(1))->isZero())
318     return false;
319 
320   // ..where the value being extract from comes from a bitcast
321   auto *BitCast = dyn_cast<BitCastInst>(IntrI->getOperand(0));
322   if (!BitCast)
323     return false;
324 
325   // ..and the bitcast is casting from predicate type
326   if (BitCast->getOperand(0)->getType() != PredType)
327     return false;
328 
329   IRBuilder<> Builder(I->getContext());
330   Builder.SetInsertPoint(I);
331 
332   auto *PtrBitCast = Builder.CreateBitCast(
333       Store->getPointerOperand(),
334       PredType->getPointerTo(Store->getPointerAddressSpace()));
335   Builder.CreateStore(BitCast->getOperand(0), PtrBitCast);
336 
337   Store->eraseFromParent();
338   if (IntrI->getNumUses() == 0)
339     IntrI->eraseFromParent();
340   if (BitCast->getNumUses() == 0)
341     BitCast->eraseFromParent();
342 
343   return true;
344 }
345 
346 // This is done in SVEIntrinsicOpts rather than InstCombine so that we introduce
347 // scalable loads as late as possible
348 bool SVEIntrinsicOpts::optimizePredicateLoad(Instruction *I) {
349   auto *F = I->getFunction();
350   auto Attr = F->getFnAttribute(Attribute::VScaleRange);
351   if (!Attr.isValid())
352     return false;
353 
354   unsigned MinVScale, MaxVScale;
355   std::tie(MinVScale, MaxVScale) = Attr.getVScaleRangeArgs();
356   // The transform needs to know the exact runtime length of scalable vectors
357   if (MinVScale != MaxVScale || MinVScale == 0)
358     return false;
359 
360   auto *PredType =
361       ScalableVectorType::get(Type::getInt1Ty(I->getContext()), 16);
362   auto *FixedPredType =
363       FixedVectorType::get(Type::getInt8Ty(I->getContext()), MinVScale * 2);
364 
365   // If we have a bitcast..
366   auto *BitCast = dyn_cast<BitCastInst>(I);
367   if (!BitCast || BitCast->getType() != PredType)
368     return false;
369 
370   // ..whose operand is a vector_insert..
371   auto *IntrI = dyn_cast<IntrinsicInst>(BitCast->getOperand(0));
372   if (!IntrI ||
373       IntrI->getIntrinsicID() != Intrinsic::experimental_vector_insert)
374     return false;
375 
376   // ..that is inserting into index zero of an undef vector..
377   if (!isa<UndefValue>(IntrI->getOperand(0)) ||
378       !cast<ConstantInt>(IntrI->getOperand(2))->isZero())
379     return false;
380 
381   // ..where the value inserted comes from a load..
382   auto *Load = dyn_cast<LoadInst>(IntrI->getOperand(1));
383   if (!Load || !Load->isSimple())
384     return false;
385 
386   // ..that is loading a predicate vector sized worth of bits..
387   if (Load->getType() != FixedPredType)
388     return false;
389 
390   IRBuilder<> Builder(I->getContext());
391   Builder.SetInsertPoint(Load);
392 
393   auto *PtrBitCast = Builder.CreateBitCast(
394       Load->getPointerOperand(),
395       PredType->getPointerTo(Load->getPointerAddressSpace()));
396   auto *LoadPred = Builder.CreateLoad(PredType, PtrBitCast);
397 
398   BitCast->replaceAllUsesWith(LoadPred);
399   BitCast->eraseFromParent();
400   if (IntrI->getNumUses() == 0)
401     IntrI->eraseFromParent();
402   if (Load->getNumUses() == 0)
403     Load->eraseFromParent();
404 
405   return true;
406 }
407 
408 bool SVEIntrinsicOpts::optimizeInstructions(
409     SmallSetVector<Function *, 4> &Functions) {
410   bool Changed = false;
411 
412   for (auto *F : Functions) {
413     DominatorTree *DT = &getAnalysis<DominatorTreeWrapperPass>(*F).getDomTree();
414 
415     // Traverse the DT with an rpo walk so we see defs before uses, allowing
416     // simplification to be done incrementally.
417     BasicBlock *Root = DT->getRoot();
418     ReversePostOrderTraversal<BasicBlock *> RPOT(Root);
419     for (auto *BB : RPOT) {
420       for (Instruction &I : make_early_inc_range(*BB)) {
421         switch (I.getOpcode()) {
422         case Instruction::Store:
423           Changed |= optimizePredicateStore(&I);
424           break;
425         case Instruction::BitCast:
426           Changed |= optimizePredicateLoad(&I);
427           break;
428         }
429       }
430     }
431   }
432 
433   return Changed;
434 }
435 
436 bool SVEIntrinsicOpts::optimizeFunctions(
437     SmallSetVector<Function *, 4> &Functions) {
438   bool Changed = false;
439 
440   Changed |= optimizePTrueIntrinsicCalls(Functions);
441   Changed |= optimizeInstructions(Functions);
442 
443   return Changed;
444 }
445 
446 bool SVEIntrinsicOpts::runOnModule(Module &M) {
447   bool Changed = false;
448   SmallSetVector<Function *, 4> Functions;
449 
450   // Check for SVE intrinsic declarations first so that we only iterate over
451   // relevant functions. Where an appropriate declaration is found, store the
452   // function(s) where it is used so we can target these only.
453   for (auto &F : M.getFunctionList()) {
454     if (!F.isDeclaration())
455       continue;
456 
457     switch (F.getIntrinsicID()) {
458     case Intrinsic::experimental_vector_extract:
459     case Intrinsic::experimental_vector_insert:
460     case Intrinsic::aarch64_sve_ptrue:
461       for (User *U : F.users())
462         Functions.insert(cast<Instruction>(U)->getFunction());
463       break;
464     default:
465       break;
466     }
467   }
468 
469   if (!Functions.empty())
470     Changed |= optimizeFunctions(Functions);
471 
472   return Changed;
473 }
474