xref: /freebsd/contrib/llvm-project/llvm/lib/Target/AArch64/SVEIntrinsicOpts.cpp (revision 9c77fb6aaa366cbabc80ee1b834bcfe4df135491)
1 //===----- SVEIntrinsicOpts - SVE ACLE Intrinsics Opts --------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Performs general IR level optimizations on SVE intrinsics.
10 //
11 // This pass performs the following optimizations:
12 //
13 // - removes unnecessary ptrue intrinsics (llvm.aarch64.sve.ptrue), e.g:
14 //     %1 = @llvm.aarch64.sve.ptrue.nxv4i1(i32 31)
15 //     %2 = @llvm.aarch64.sve.ptrue.nxv8i1(i32 31)
16 //     ; (%1 can be replaced with a reinterpret of %2)
17 //
18 // - optimizes ptest intrinsics where the operands are being needlessly
19 //   converted to and from svbool_t.
20 //
21 //===----------------------------------------------------------------------===//
22 
23 #include "AArch64.h"
24 #include "Utils/AArch64BaseInfo.h"
25 #include "llvm/ADT/PostOrderIterator.h"
26 #include "llvm/ADT/SetVector.h"
27 #include "llvm/IR/Constants.h"
28 #include "llvm/IR/Dominators.h"
29 #include "llvm/IR/IRBuilder.h"
30 #include "llvm/IR/Instructions.h"
31 #include "llvm/IR/IntrinsicInst.h"
32 #include "llvm/IR/IntrinsicsAArch64.h"
33 #include "llvm/IR/LLVMContext.h"
34 #include "llvm/IR/Module.h"
35 #include "llvm/IR/PatternMatch.h"
36 #include "llvm/InitializePasses.h"
37 #include <optional>
38 
39 using namespace llvm;
40 using namespace llvm::PatternMatch;
41 
42 #define DEBUG_TYPE "aarch64-sve-intrinsic-opts"
43 
44 namespace {
45 struct SVEIntrinsicOpts : public ModulePass {
46   static char ID; // Pass identification, replacement for typeid
47   SVEIntrinsicOpts() : ModulePass(ID) {}
48 
49   bool runOnModule(Module &M) override;
50   void getAnalysisUsage(AnalysisUsage &AU) const override;
51 
52 private:
53   bool coalescePTrueIntrinsicCalls(BasicBlock &BB,
54                                    SmallSetVector<IntrinsicInst *, 4> &PTrues);
55   bool optimizePTrueIntrinsicCalls(SmallSetVector<Function *, 4> &Functions);
56   bool optimizePredicateStore(Instruction *I);
57   bool optimizePredicateLoad(Instruction *I);
58 
59   bool optimizeInstructions(SmallSetVector<Function *, 4> &Functions);
60 
61   /// Operates at the function-scope. I.e., optimizations are applied local to
62   /// the functions themselves.
63   bool optimizeFunctions(SmallSetVector<Function *, 4> &Functions);
64 };
65 } // end anonymous namespace
66 
67 void SVEIntrinsicOpts::getAnalysisUsage(AnalysisUsage &AU) const {
68   AU.addRequired<DominatorTreeWrapperPass>();
69   AU.setPreservesCFG();
70 }
71 
72 char SVEIntrinsicOpts::ID = 0;
73 static const char *name = "SVE intrinsics optimizations";
74 INITIALIZE_PASS_BEGIN(SVEIntrinsicOpts, DEBUG_TYPE, name, false, false)
75 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass);
76 INITIALIZE_PASS_END(SVEIntrinsicOpts, DEBUG_TYPE, name, false, false)
77 
78 ModulePass *llvm::createSVEIntrinsicOptsPass() {
79   return new SVEIntrinsicOpts();
80 }
81 
82 /// Checks if a ptrue intrinsic call is promoted. The act of promoting a
83 /// ptrue will introduce zeroing. For example:
84 ///
85 ///     %1 = <vscale x 4 x i1> call @llvm.aarch64.sve.ptrue.nxv4i1(i32 31)
86 ///     %2 = <vscale x 16 x i1> call @llvm.aarch64.sve.convert.to.svbool.nxv4i1(<vscale x 4 x i1> %1)
87 ///     %3 = <vscale x 8 x i1> call @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %2)
88 ///
89 /// %1 is promoted, because it is converted:
90 ///
91 ///     <vscale x 4 x i1> => <vscale x 16 x i1> => <vscale x 8 x i1>
92 ///
93 /// via a sequence of the SVE reinterpret intrinsics convert.{to,from}.svbool.
94 static bool isPTruePromoted(IntrinsicInst *PTrue) {
95   // Find all users of this intrinsic that are calls to convert-to-svbool
96   // reinterpret intrinsics.
97   SmallVector<IntrinsicInst *, 4> ConvertToUses;
98   for (User *User : PTrue->users()) {
99     if (match(User, m_Intrinsic<Intrinsic::aarch64_sve_convert_to_svbool>())) {
100       ConvertToUses.push_back(cast<IntrinsicInst>(User));
101     }
102   }
103 
104   // If no such calls were found, this is ptrue is not promoted.
105   if (ConvertToUses.empty())
106     return false;
107 
108   // Otherwise, try to find users of the convert-to-svbool intrinsics that are
109   // calls to the convert-from-svbool intrinsic, and would result in some lanes
110   // being zeroed.
111   const auto *PTrueVTy = cast<ScalableVectorType>(PTrue->getType());
112   for (IntrinsicInst *ConvertToUse : ConvertToUses) {
113     for (User *User : ConvertToUse->users()) {
114       auto *IntrUser = dyn_cast<IntrinsicInst>(User);
115       if (IntrUser && IntrUser->getIntrinsicID() ==
116                           Intrinsic::aarch64_sve_convert_from_svbool) {
117         const auto *IntrUserVTy = cast<ScalableVectorType>(IntrUser->getType());
118 
119         // Would some lanes become zeroed by the conversion?
120         if (IntrUserVTy->getElementCount().getKnownMinValue() >
121             PTrueVTy->getElementCount().getKnownMinValue())
122           // This is a promoted ptrue.
123           return true;
124       }
125     }
126   }
127 
128   // If no matching calls were found, this is not a promoted ptrue.
129   return false;
130 }
131 
132 /// Attempts to coalesce ptrues in a basic block.
133 bool SVEIntrinsicOpts::coalescePTrueIntrinsicCalls(
134     BasicBlock &BB, SmallSetVector<IntrinsicInst *, 4> &PTrues) {
135   if (PTrues.size() <= 1)
136     return false;
137 
138   // Find the ptrue with the most lanes.
139   auto *MostEncompassingPTrue =
140       *llvm::max_element(PTrues, [](auto *PTrue1, auto *PTrue2) {
141         auto *PTrue1VTy = cast<ScalableVectorType>(PTrue1->getType());
142         auto *PTrue2VTy = cast<ScalableVectorType>(PTrue2->getType());
143         return PTrue1VTy->getElementCount().getKnownMinValue() <
144                PTrue2VTy->getElementCount().getKnownMinValue();
145       });
146 
147   // Remove the most encompassing ptrue, as well as any promoted ptrues, leaving
148   // behind only the ptrues to be coalesced.
149   PTrues.remove(MostEncompassingPTrue);
150   PTrues.remove_if(isPTruePromoted);
151 
152   // Hoist MostEncompassingPTrue to the start of the basic block. It is always
153   // safe to do this, since ptrue intrinsic calls are guaranteed to have no
154   // predecessors.
155   MostEncompassingPTrue->moveBefore(BB, BB.getFirstInsertionPt());
156 
157   LLVMContext &Ctx = BB.getContext();
158   IRBuilder<> Builder(Ctx);
159   Builder.SetInsertPoint(&BB, ++MostEncompassingPTrue->getIterator());
160 
161   auto *MostEncompassingPTrueVTy =
162       cast<VectorType>(MostEncompassingPTrue->getType());
163   auto *ConvertToSVBool = Builder.CreateIntrinsic(
164       Intrinsic::aarch64_sve_convert_to_svbool, {MostEncompassingPTrueVTy},
165       {MostEncompassingPTrue});
166 
167   bool ConvertFromCreated = false;
168   for (auto *PTrue : PTrues) {
169     auto *PTrueVTy = cast<VectorType>(PTrue->getType());
170 
171     // Only create the converts if the types are not already the same, otherwise
172     // just use the most encompassing ptrue.
173     if (MostEncompassingPTrueVTy != PTrueVTy) {
174       ConvertFromCreated = true;
175 
176       Builder.SetInsertPoint(&BB, ++ConvertToSVBool->getIterator());
177       auto *ConvertFromSVBool =
178           Builder.CreateIntrinsic(Intrinsic::aarch64_sve_convert_from_svbool,
179                                   {PTrueVTy}, {ConvertToSVBool});
180       PTrue->replaceAllUsesWith(ConvertFromSVBool);
181     } else
182       PTrue->replaceAllUsesWith(MostEncompassingPTrue);
183 
184     PTrue->eraseFromParent();
185   }
186 
187   // We never used the ConvertTo so remove it
188   if (!ConvertFromCreated)
189     ConvertToSVBool->eraseFromParent();
190 
191   return true;
192 }
193 
194 /// The goal of this function is to remove redundant calls to the SVE ptrue
195 /// intrinsic in each basic block within the given functions.
196 ///
197 /// SVE ptrues have two representations in LLVM IR:
198 /// - a logical representation -- an arbitrary-width scalable vector of i1s,
199 ///   i.e. <vscale x N x i1>.
200 /// - a physical representation (svbool, <vscale x 16 x i1>) -- a 16-element
201 ///   scalable vector of i1s, i.e. <vscale x 16 x i1>.
202 ///
203 /// The SVE ptrue intrinsic is used to create a logical representation of an SVE
204 /// predicate. Suppose that we have two SVE ptrue intrinsic calls: P1 and P2. If
205 /// P1 creates a logical SVE predicate that is at least as wide as the logical
206 /// SVE predicate created by P2, then all of the bits that are true in the
207 /// physical representation of P2 are necessarily also true in the physical
208 /// representation of P1. P1 'encompasses' P2, therefore, the intrinsic call to
209 /// P2 is redundant and can be replaced by an SVE reinterpret of P1 via
210 /// convert.{to,from}.svbool.
211 ///
212 /// Currently, this pass only coalesces calls to SVE ptrue intrinsics
213 /// if they match the following conditions:
214 ///
215 /// - the call to the intrinsic uses either the SV_ALL or SV_POW2 patterns.
216 ///   SV_ALL indicates that all bits of the predicate vector are to be set to
217 ///   true. SV_POW2 indicates that all bits of the predicate vector up to the
218 ///   largest power-of-two are to be set to true.
219 /// - the result of the call to the intrinsic is not promoted to a wider
220 ///   predicate. In this case, keeping the extra ptrue leads to better codegen
221 ///   -- coalescing here would create an irreducible chain of SVE reinterprets
222 ///   via convert.{to,from}.svbool.
223 ///
224 /// EXAMPLE:
225 ///
226 ///     %1 = <vscale x 8 x i1> ptrue(i32 SV_ALL)
227 ///     ; Logical:  <1, 1, 1, 1, 1, 1, 1, 1>
228 ///     ; Physical: <1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0>
229 ///     ...
230 ///
231 ///     %2 = <vscale x 4 x i1> ptrue(i32 SV_ALL)
232 ///     ; Logical:  <1, 1, 1, 1>
233 ///     ; Physical: <1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0>
234 ///     ...
235 ///
236 /// Here, %2 can be replaced by an SVE reinterpret of %1, giving, for instance:
237 ///
238 ///     %1 = <vscale x 8 x i1> ptrue(i32 i31)
239 ///     %2 = <vscale x 16 x i1> convert.to.svbool(<vscale x 8 x i1> %1)
240 ///     %3 = <vscale x 4 x i1> convert.from.svbool(<vscale x 16 x i1> %2)
241 ///
242 bool SVEIntrinsicOpts::optimizePTrueIntrinsicCalls(
243     SmallSetVector<Function *, 4> &Functions) {
244   bool Changed = false;
245 
246   for (auto *F : Functions) {
247     for (auto &BB : *F) {
248       SmallSetVector<IntrinsicInst *, 4> SVAllPTrues;
249       SmallSetVector<IntrinsicInst *, 4> SVPow2PTrues;
250 
251       // For each basic block, collect the used ptrues and try to coalesce them.
252       for (Instruction &I : BB) {
253         if (I.use_empty())
254           continue;
255 
256         auto *IntrI = dyn_cast<IntrinsicInst>(&I);
257         if (!IntrI || IntrI->getIntrinsicID() != Intrinsic::aarch64_sve_ptrue)
258           continue;
259 
260         const auto PTruePattern =
261             cast<ConstantInt>(IntrI->getOperand(0))->getZExtValue();
262 
263         if (PTruePattern == AArch64SVEPredPattern::all)
264           SVAllPTrues.insert(IntrI);
265         if (PTruePattern == AArch64SVEPredPattern::pow2)
266           SVPow2PTrues.insert(IntrI);
267       }
268 
269       Changed |= coalescePTrueIntrinsicCalls(BB, SVAllPTrues);
270       Changed |= coalescePTrueIntrinsicCalls(BB, SVPow2PTrues);
271     }
272   }
273 
274   return Changed;
275 }
276 
277 // This is done in SVEIntrinsicOpts rather than InstCombine so that we introduce
278 // scalable stores as late as possible
279 bool SVEIntrinsicOpts::optimizePredicateStore(Instruction *I) {
280   auto *F = I->getFunction();
281   auto Attr = F->getFnAttribute(Attribute::VScaleRange);
282   if (!Attr.isValid())
283     return false;
284 
285   unsigned MinVScale = Attr.getVScaleRangeMin();
286   std::optional<unsigned> MaxVScale = Attr.getVScaleRangeMax();
287   // The transform needs to know the exact runtime length of scalable vectors
288   if (!MaxVScale || MinVScale != MaxVScale)
289     return false;
290 
291   auto *PredType =
292       ScalableVectorType::get(Type::getInt1Ty(I->getContext()), 16);
293   auto *FixedPredType =
294       FixedVectorType::get(Type::getInt8Ty(I->getContext()), MinVScale * 2);
295 
296   // If we have a store..
297   auto *Store = dyn_cast<StoreInst>(I);
298   if (!Store || !Store->isSimple())
299     return false;
300 
301   // ..that is storing a predicate vector sized worth of bits..
302   if (Store->getOperand(0)->getType() != FixedPredType)
303     return false;
304 
305   // ..where the value stored comes from a vector extract..
306   auto *IntrI = dyn_cast<IntrinsicInst>(Store->getOperand(0));
307   if (!IntrI || IntrI->getIntrinsicID() != Intrinsic::vector_extract)
308     return false;
309 
310   // ..that is extracting from index 0..
311   if (!cast<ConstantInt>(IntrI->getOperand(1))->isZero())
312     return false;
313 
314   // ..where the value being extract from comes from a bitcast
315   auto *BitCast = dyn_cast<BitCastInst>(IntrI->getOperand(0));
316   if (!BitCast)
317     return false;
318 
319   // ..and the bitcast is casting from predicate type
320   if (BitCast->getOperand(0)->getType() != PredType)
321     return false;
322 
323   IRBuilder<> Builder(I->getContext());
324   Builder.SetInsertPoint(I);
325 
326   Builder.CreateStore(BitCast->getOperand(0), Store->getPointerOperand());
327 
328   Store->eraseFromParent();
329   if (IntrI->use_empty())
330     IntrI->eraseFromParent();
331   if (BitCast->use_empty())
332     BitCast->eraseFromParent();
333 
334   return true;
335 }
336 
337 // This is done in SVEIntrinsicOpts rather than InstCombine so that we introduce
338 // scalable loads as late as possible
339 bool SVEIntrinsicOpts::optimizePredicateLoad(Instruction *I) {
340   auto *F = I->getFunction();
341   auto Attr = F->getFnAttribute(Attribute::VScaleRange);
342   if (!Attr.isValid())
343     return false;
344 
345   unsigned MinVScale = Attr.getVScaleRangeMin();
346   std::optional<unsigned> MaxVScale = Attr.getVScaleRangeMax();
347   // The transform needs to know the exact runtime length of scalable vectors
348   if (!MaxVScale || MinVScale != MaxVScale)
349     return false;
350 
351   auto *PredType =
352       ScalableVectorType::get(Type::getInt1Ty(I->getContext()), 16);
353   auto *FixedPredType =
354       FixedVectorType::get(Type::getInt8Ty(I->getContext()), MinVScale * 2);
355 
356   // If we have a bitcast..
357   auto *BitCast = dyn_cast<BitCastInst>(I);
358   if (!BitCast || BitCast->getType() != PredType)
359     return false;
360 
361   // ..whose operand is a vector_insert..
362   auto *IntrI = dyn_cast<IntrinsicInst>(BitCast->getOperand(0));
363   if (!IntrI || IntrI->getIntrinsicID() != Intrinsic::vector_insert)
364     return false;
365 
366   // ..that is inserting into index zero of an undef vector..
367   if (!isa<UndefValue>(IntrI->getOperand(0)) ||
368       !cast<ConstantInt>(IntrI->getOperand(2))->isZero())
369     return false;
370 
371   // ..where the value inserted comes from a load..
372   auto *Load = dyn_cast<LoadInst>(IntrI->getOperand(1));
373   if (!Load || !Load->isSimple())
374     return false;
375 
376   // ..that is loading a predicate vector sized worth of bits..
377   if (Load->getType() != FixedPredType)
378     return false;
379 
380   IRBuilder<> Builder(I->getContext());
381   Builder.SetInsertPoint(Load);
382 
383   auto *LoadPred = Builder.CreateLoad(PredType, Load->getPointerOperand());
384 
385   BitCast->replaceAllUsesWith(LoadPred);
386   BitCast->eraseFromParent();
387   if (IntrI->use_empty())
388     IntrI->eraseFromParent();
389   if (Load->use_empty())
390     Load->eraseFromParent();
391 
392   return true;
393 }
394 
395 bool SVEIntrinsicOpts::optimizeInstructions(
396     SmallSetVector<Function *, 4> &Functions) {
397   bool Changed = false;
398 
399   for (auto *F : Functions) {
400     DominatorTree *DT = &getAnalysis<DominatorTreeWrapperPass>(*F).getDomTree();
401 
402     // Traverse the DT with an rpo walk so we see defs before uses, allowing
403     // simplification to be done incrementally.
404     BasicBlock *Root = DT->getRoot();
405     ReversePostOrderTraversal<BasicBlock *> RPOT(Root);
406     for (auto *BB : RPOT) {
407       for (Instruction &I : make_early_inc_range(*BB)) {
408         switch (I.getOpcode()) {
409         case Instruction::Store:
410           Changed |= optimizePredicateStore(&I);
411           break;
412         case Instruction::BitCast:
413           Changed |= optimizePredicateLoad(&I);
414           break;
415         }
416       }
417     }
418   }
419 
420   return Changed;
421 }
422 
423 bool SVEIntrinsicOpts::optimizeFunctions(
424     SmallSetVector<Function *, 4> &Functions) {
425   bool Changed = false;
426 
427   Changed |= optimizePTrueIntrinsicCalls(Functions);
428   Changed |= optimizeInstructions(Functions);
429 
430   return Changed;
431 }
432 
433 bool SVEIntrinsicOpts::runOnModule(Module &M) {
434   bool Changed = false;
435   SmallSetVector<Function *, 4> Functions;
436 
437   // Check for SVE intrinsic declarations first so that we only iterate over
438   // relevant functions. Where an appropriate declaration is found, store the
439   // function(s) where it is used so we can target these only.
440   for (auto &F : M.getFunctionList()) {
441     if (!F.isDeclaration())
442       continue;
443 
444     switch (F.getIntrinsicID()) {
445     case Intrinsic::vector_extract:
446     case Intrinsic::vector_insert:
447     case Intrinsic::aarch64_sve_ptrue:
448       for (User *U : F.users())
449         Functions.insert(cast<Instruction>(U)->getFunction());
450       break;
451     default:
452       break;
453     }
454   }
455 
456   if (!Functions.empty())
457     Changed |= optimizeFunctions(Functions);
458 
459   return Changed;
460 }
461