//===----- SVEIntrinsicOpts - SVE ACLE Intrinsics Opts --------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Performs general IR level optimizations on SVE intrinsics.
//
// This pass performs the following optimizations:
//
// - removes unnecessary ptrue intrinsics (llvm.aarch64.sve.ptrue), e.g.:
//     %1 = @llvm.aarch64.sve.ptrue.nxv4i1(i32 31)
//     %2 = @llvm.aarch64.sve.ptrue.nxv8i1(i32 31)
//     ; (%1 can be replaced with a reinterpret of %2)
//
// - optimizes ptest intrinsics where the operands are being needlessly
//   converted to and from svbool_t.
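//   e.g. (an illustrative sketch only; the value names and element types
//   here are invented for this example):
//     %1 = @llvm.aarch64.sve.convert.from.svbool.nxv4i1(<vscale x 16 x i1> %in)
//     %2 = @llvm.aarch64.sve.convert.to.svbool.nxv4i1(<vscale x 4 x i1> %1)
//     ; (a ptest taking %2 as an operand can often use %in directly)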
//
//===----------------------------------------------------------------------===//

#include "AArch64.h"
#include "Utils/AArch64BaseInfo.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/IntrinsicsAArch64.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Support/Debug.h"
#include <optional>

using namespace llvm;
using namespace llvm::PatternMatch;

#define DEBUG_TYPE "aarch64-sve-intrinsic-opts"

namespace {
struct SVEIntrinsicOpts : public ModulePass {
  static char ID; // Pass identification, replacement for typeid
  SVEIntrinsicOpts() : ModulePass(ID) {
    initializeSVEIntrinsicOptsPass(*PassRegistry::getPassRegistry());
  }

  bool runOnModule(Module &M) override;
  void getAnalysisUsage(AnalysisUsage &AU) const override;

private:
  bool coalescePTrueIntrinsicCalls(BasicBlock &BB,
                                   SmallSetVector<IntrinsicInst *, 4> &PTrues);
  bool optimizePTrueIntrinsicCalls(SmallSetVector<Function *, 4> &Functions);
  bool optimizePredicateStore(Instruction *I);
  bool optimizePredicateLoad(Instruction *I);

  bool optimizeInstructions(SmallSetVector<Function *, 4> &Functions);

  /// Operates at function scope, i.e. optimizations are applied locally,
  /// within each function.
  bool optimizeFunctions(SmallSetVector<Function *, 4> &Functions);
};
} // end anonymous namespace

void SVEIntrinsicOpts::getAnalysisUsage(AnalysisUsage &AU) const {
  AU.addRequired<DominatorTreeWrapperPass>();
  AU.setPreservesCFG();
}

char SVEIntrinsicOpts::ID = 0;
static const char *name = "SVE intrinsics optimizations";
INITIALIZE_PASS_BEGIN(SVEIntrinsicOpts, DEBUG_TYPE, name, false, false)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass);
INITIALIZE_PASS_END(SVEIntrinsicOpts, DEBUG_TYPE, name, false, false)

ModulePass *llvm::createSVEIntrinsicOptsPass() {
  return new SVEIntrinsicOpts();
}

/// Checks if a ptrue intrinsic call is promoted. The act of promoting a
/// ptrue will introduce zeroing. For example:
///
///     %1 = <vscale x 4 x i1> call @llvm.aarch64.sve.ptrue.nxv4i1(i32 31)
///     %2 = <vscale x 16 x i1> call @llvm.aarch64.sve.convert.to.svbool.nxv4i1(<vscale x 4 x i1> %1)
///     %3 = <vscale x 8 x i1> call @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %2)
///
/// %1 is promoted, because it is converted:
///
///     <vscale x 4 x i1> => <vscale x 16 x i1> => <vscale x 8 x i1>
///
/// via a sequence of the SVE reinterpret intrinsics convert.{to,from}.svbool.
static bool isPTruePromoted(IntrinsicInst *PTrue) {
  // Find all users of this intrinsic that are calls to convert-to-svbool
  // reinterpret intrinsics.
  SmallVector<IntrinsicInst *, 4> ConvertToUses;
  for (User *User : PTrue->users()) {
    if (match(User, m_Intrinsic<Intrinsic::aarch64_sve_convert_to_svbool>())) {
      ConvertToUses.push_back(cast<IntrinsicInst>(User));
    }
  }

  // If no such calls were found, this ptrue is not promoted.
  if (ConvertToUses.empty())
    return false;

  // Otherwise, try to find users of the convert-to-svbool intrinsics that are
  // calls to the convert-from-svbool intrinsic, and would result in some lanes
  // being zeroed.
  const auto *PTrueVTy = cast<ScalableVectorType>(PTrue->getType());
  for (IntrinsicInst *ConvertToUse : ConvertToUses) {
    for (User *User : ConvertToUse->users()) {
      auto *IntrUser = dyn_cast<IntrinsicInst>(User);
      if (IntrUser && IntrUser->getIntrinsicID() ==
                          Intrinsic::aarch64_sve_convert_from_svbool) {
        const auto *IntrUserVTy = cast<ScalableVectorType>(IntrUser->getType());

        // Would some lanes become zeroed by the conversion?
        if (IntrUserVTy->getElementCount().getKnownMinValue() >
            PTrueVTy->getElementCount().getKnownMinValue())
          // This is a promoted ptrue.
          return true;
      }
    }
  }

  // If no matching calls were found, this is not a promoted ptrue.
  return false;
}

/// Attempts to coalesce ptrues in a basic block.
bool SVEIntrinsicOpts::coalescePTrueIntrinsicCalls(
    BasicBlock &BB, SmallSetVector<IntrinsicInst *, 4> &PTrues) {
  if (PTrues.size() <= 1)
    return false;

  // Find the ptrue with the most lanes.
  auto *MostEncompassingPTrue =
      *llvm::max_element(PTrues, [](auto *PTrue1, auto *PTrue2) {
        auto *PTrue1VTy = cast<ScalableVectorType>(PTrue1->getType());
        auto *PTrue2VTy = cast<ScalableVectorType>(PTrue2->getType());
        return PTrue1VTy->getElementCount().getKnownMinValue() <
               PTrue2VTy->getElementCount().getKnownMinValue();
      });

  // Remove the most encompassing ptrue, as well as any promoted ptrues,
  // leaving behind only the ptrues to be coalesced.
  PTrues.remove(MostEncompassingPTrue);
  PTrues.remove_if(isPTruePromoted);
  // Hoist MostEncompassingPTrue to the start of the basic block. It is always
  // safe to do this, since a ptrue intrinsic call's only operand is an
  // immediate pattern, so it has no dependencies on earlier instructions.
  MostEncompassingPTrue->moveBefore(BB, BB.getFirstInsertionPt());

  LLVMContext &Ctx = BB.getContext();
  IRBuilder<> Builder(Ctx);
  Builder.SetInsertPoint(&BB, ++MostEncompassingPTrue->getIterator());

  auto *MostEncompassingPTrueVTy =
      cast<VectorType>(MostEncompassingPTrue->getType());
  auto *ConvertToSVBool = Builder.CreateIntrinsic(
      Intrinsic::aarch64_sve_convert_to_svbool, {MostEncompassingPTrueVTy},
      {MostEncompassingPTrue});

  bool ConvertFromCreated = false;
  for (auto *PTrue : PTrues) {
    auto *PTrueVTy = cast<VectorType>(PTrue->getType());

    // Only create the converts if the types are not already the same,
    // otherwise just use the most encompassing ptrue.
    if (MostEncompassingPTrueVTy != PTrueVTy) {
      ConvertFromCreated = true;

      Builder.SetInsertPoint(&BB, ++ConvertToSVBool->getIterator());
      auto *ConvertFromSVBool =
          Builder.CreateIntrinsic(Intrinsic::aarch64_sve_convert_from_svbool,
                                  {PTrueVTy}, {ConvertToSVBool});
      PTrue->replaceAllUsesWith(ConvertFromSVBool);
    } else
      PTrue->replaceAllUsesWith(MostEncompassingPTrue);

    PTrue->eraseFromParent();
  }

  // We never used the ConvertTo, so remove it.
  if (!ConvertFromCreated)
    ConvertToSVBool->eraseFromParent();

  return true;
}

/// The goal of this function is to remove redundant calls to the SVE ptrue
/// intrinsic in each basic block within the given functions.
///
/// SVE ptrues have two representations in LLVM IR:
/// - a logical representation -- an arbitrary-width scalable vector of i1s,
///   i.e. <vscale x N x i1>.
/// - a physical representation (svbool) -- a 16-element scalable vector of
///   i1s, i.e. <vscale x 16 x i1>.
///
/// The SVE ptrue intrinsic is used to create a logical representation of an SVE
/// predicate. Suppose that we have two SVE ptrue intrinsic calls: P1 and P2. If
/// P1 creates a logical SVE predicate that is at least as wide as the logical
/// SVE predicate created by P2, then all of the bits that are true in the
/// physical representation of P2 are necessarily also true in the physical
/// representation of P1. P1 'encompasses' P2; therefore, the intrinsic call to
/// P2 is redundant and can be replaced by an SVE reinterpret of P1 via
/// convert.{to,from}.svbool.
///
/// Currently, this pass only coalesces calls to SVE ptrue intrinsics
/// if they match the following conditions:
///
/// - the call to the intrinsic uses either the SV_ALL or SV_POW2 patterns.
///   SV_ALL indicates that all bits of the predicate vector are to be set to
///   true. SV_POW2 indicates that all bits of the predicate vector up to the
///   largest power-of-two are to be set to true.
/// - the result of the call to the intrinsic is not promoted to a wider
///   predicate. In this case, keeping the extra ptrue leads to better codegen
///   -- coalescing here would create an irreducible chain of SVE reinterprets
///   via convert.{to,from}.svbool.
///
/// EXAMPLE:
///
///     %1 = <vscale x 8 x i1> ptrue(i32 SV_ALL)
///     ; Logical:  <1, 1, 1, 1, 1, 1, 1, 1>
///     ; Physical: <1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0>
///     ...
///
///     %2 = <vscale x 4 x i1> ptrue(i32 SV_ALL)
///     ; Logical:  <1, 1, 1, 1>
///     ; Physical: <1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0>
///     ...
///
/// Here, %2 can be replaced by an SVE reinterpret of %1, giving, for instance:
///
///     %1 = <vscale x 8 x i1> ptrue(i32 31)
///     %2 = <vscale x 16 x i1> convert.to.svbool(<vscale x 8 x i1> %1)
///     %3 = <vscale x 4 x i1> convert.from.svbool(<vscale x 16 x i1> %2)
///
bool SVEIntrinsicOpts::optimizePTrueIntrinsicCalls(
    SmallSetVector<Function *, 4> &Functions) {
  bool Changed = false;

  for (auto *F : Functions) {
    for (auto &BB : *F) {
      SmallSetVector<IntrinsicInst *, 4> SVAllPTrues;
      SmallSetVector<IntrinsicInst *, 4> SVPow2PTrues;

      // For each basic block, collect the used ptrues and try to coalesce them.
      for (Instruction &I : BB) {
        if (I.use_empty())
          continue;

        auto *IntrI = dyn_cast<IntrinsicInst>(&I);
        if (!IntrI || IntrI->getIntrinsicID() != Intrinsic::aarch64_sve_ptrue)
          continue;

        const auto PTruePattern =
            cast<ConstantInt>(IntrI->getOperand(0))->getZExtValue();

        if (PTruePattern == AArch64SVEPredPattern::all)
          SVAllPTrues.insert(IntrI);
        if (PTruePattern == AArch64SVEPredPattern::pow2)
          SVPow2PTrues.insert(IntrI);
      }

      Changed |= coalescePTrueIntrinsicCalls(BB, SVAllPTrues);
      Changed |= coalescePTrueIntrinsicCalls(BB, SVPow2PTrues);
    }
  }

  return Changed;
}

// This is done in SVEIntrinsicOpts rather than InstCombine so that we
// introduce scalable stores as late as possible.
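//
// For illustration only -- the IR below is a sketch, assuming a function
// carrying the vscale_range(2,2) attribute (so MinVScale == MaxVScale == 2
// and a predicate register holds 4 bytes); the value names are invented:
//
//     %bc  = bitcast <vscale x 16 x i1> %pred to <vscale x 2 x i8>
//     %ext = call <4 x i8> @llvm.vector.extract.v4i8.nxv2i8(<vscale x 2 x i8> %bc, i64 0)
//     store <4 x i8> %ext, ptr %addr
//
// becomes a single store of the predicate itself:
//
//     store <vscale x 16 x i1> %pred, ptr %addr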
bool SVEIntrinsicOpts::optimizePredicateStore(Instruction *I) {
  auto *F = I->getFunction();
  auto Attr = F->getFnAttribute(Attribute::VScaleRange);
  if (!Attr.isValid())
    return false;

  unsigned MinVScale = Attr.getVScaleRangeMin();
  std::optional<unsigned> MaxVScale = Attr.getVScaleRangeMax();
  // The transform needs to know the exact runtime length of scalable vectors.
  if (!MaxVScale || MinVScale != MaxVScale)
    return false;

  auto *PredType =
      ScalableVectorType::get(Type::getInt1Ty(I->getContext()), 16);
  auto *FixedPredType =
      FixedVectorType::get(Type::getInt8Ty(I->getContext()), MinVScale * 2);

  // If we have a store..
  auto *Store = dyn_cast<StoreInst>(I);
  if (!Store || !Store->isSimple())
    return false;

  // ..that is storing a predicate vector's worth of bits..
  if (Store->getOperand(0)->getType() != FixedPredType)
    return false;

  // ..where the value stored comes from a vector extract..
  auto *IntrI = dyn_cast<IntrinsicInst>(Store->getOperand(0));
  if (!IntrI || IntrI->getIntrinsicID() != Intrinsic::vector_extract)
    return false;

  // ..that is extracting from index 0..
  if (!cast<ConstantInt>(IntrI->getOperand(1))->isZero())
    return false;

  // ..where the value being extracted from comes from a bitcast..
  auto *BitCast = dyn_cast<BitCastInst>(IntrI->getOperand(0));
  if (!BitCast)
    return false;

  // ..and the bitcast is casting from predicate type.
  if (BitCast->getOperand(0)->getType() != PredType)
    return false;

  IRBuilder<> Builder(I->getContext());
  Builder.SetInsertPoint(I);

  Builder.CreateStore(BitCast->getOperand(0), Store->getPointerOperand());

  Store->eraseFromParent();
  if (IntrI->getNumUses() == 0)
    IntrI->eraseFromParent();
  if (BitCast->getNumUses() == 0)
    BitCast->eraseFromParent();

  return true;
}

// This is done in SVEIntrinsicOpts rather than InstCombine so that we
// introduce scalable loads as late as possible.
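//
// For illustration only -- the mirror image of the store case, a sketch under
// the same assumed vscale_range(2,2) attribute with invented value names:
//
//     %ld  = load <4 x i8>, ptr %addr
//     %ins = call <vscale x 2 x i8> @llvm.vector.insert.nxv2i8.v4i8(<vscale x 2 x i8> undef, <4 x i8> %ld, i64 0)
//     %bc  = bitcast <vscale x 2 x i8> %ins to <vscale x 16 x i1>
//
// is rewritten into a single load of the predicate:
//
//     %pred = load <vscale x 16 x i1>, ptr %addr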
bool SVEIntrinsicOpts::optimizePredicateLoad(Instruction *I) {
  auto *F = I->getFunction();
  auto Attr = F->getFnAttribute(Attribute::VScaleRange);
  if (!Attr.isValid())
    return false;

  unsigned MinVScale = Attr.getVScaleRangeMin();
  std::optional<unsigned> MaxVScale = Attr.getVScaleRangeMax();
  // The transform needs to know the exact runtime length of scalable vectors.
  if (!MaxVScale || MinVScale != MaxVScale)
    return false;

  auto *PredType =
      ScalableVectorType::get(Type::getInt1Ty(I->getContext()), 16);
  auto *FixedPredType =
      FixedVectorType::get(Type::getInt8Ty(I->getContext()), MinVScale * 2);

  // If we have a bitcast..
  auto *BitCast = dyn_cast<BitCastInst>(I);
  if (!BitCast || BitCast->getType() != PredType)
    return false;

  // ..whose operand is a vector_insert..
  auto *IntrI = dyn_cast<IntrinsicInst>(BitCast->getOperand(0));
  if (!IntrI || IntrI->getIntrinsicID() != Intrinsic::vector_insert)
    return false;

  // ..that is inserting into index zero of an undef vector..
  if (!isa<UndefValue>(IntrI->getOperand(0)) ||
      !cast<ConstantInt>(IntrI->getOperand(2))->isZero())
    return false;

  // ..where the value inserted comes from a load..
  auto *Load = dyn_cast<LoadInst>(IntrI->getOperand(1));
  if (!Load || !Load->isSimple())
    return false;

  // ..that is loading a predicate vector's worth of bits.
  if (Load->getType() != FixedPredType)
    return false;

  IRBuilder<> Builder(I->getContext());
  Builder.SetInsertPoint(Load);

  auto *LoadPred = Builder.CreateLoad(PredType, Load->getPointerOperand());

  BitCast->replaceAllUsesWith(LoadPred);
  BitCast->eraseFromParent();
  if (IntrI->getNumUses() == 0)
    IntrI->eraseFromParent();
  if (Load->getNumUses() == 0)
    Load->eraseFromParent();

  return true;
}

bool SVEIntrinsicOpts::optimizeInstructions(
    SmallSetVector<Function *, 4> &Functions) {
  bool Changed = false;

  for (auto *F : Functions) {
    DominatorTree *DT = &getAnalysis<DominatorTreeWrapperPass>(*F).getDomTree();

    // Traverse the DT with an RPO walk so we see defs before uses, allowing
    // simplification to be done incrementally.
    BasicBlock *Root = DT->getRoot();
    ReversePostOrderTraversal<BasicBlock *> RPOT(Root);
    for (auto *BB : RPOT) {
      for (Instruction &I : make_early_inc_range(*BB)) {
        switch (I.getOpcode()) {
        case Instruction::Store:
          Changed |= optimizePredicateStore(&I);
          break;
        case Instruction::BitCast:
          Changed |= optimizePredicateLoad(&I);
          break;
        }
      }
    }
  }

  return Changed;
}

bool SVEIntrinsicOpts::optimizeFunctions(
    SmallSetVector<Function *, 4> &Functions) {
  bool Changed = false;

  Changed |= optimizePTrueIntrinsicCalls(Functions);
  Changed |= optimizeInstructions(Functions);

  return Changed;
}

bool SVEIntrinsicOpts::runOnModule(Module &M) {
  bool Changed = false;
  SmallSetVector<Function *, 4> Functions;

  // Check for SVE intrinsic declarations first so that we only iterate over
  // relevant functions. Where an appropriate declaration is found, record the
  // function(s) in which it is used so we can target only those.
  for (auto &F : M.getFunctionList()) {
    if (!F.isDeclaration())
      continue;

    switch (F.getIntrinsicID()) {
    case Intrinsic::vector_extract:
    case Intrinsic::vector_insert:
    case Intrinsic::aarch64_sve_ptrue:
      for (User *U : F.users())
        Functions.insert(cast<Instruction>(U)->getFunction());
      break;
    default:
      break;
    }
  }

  if (!Functions.empty())
    Changed |= optimizeFunctions(Functions);

  return Changed;
}