xref: /freebsd/contrib/llvm-project/llvm/lib/Target/AArch64/SVEIntrinsicOpts.cpp (revision 0eae32dcef82f6f06de6419a0d623d7def0cc8f6)
15ffd83dbSDimitry Andric //===----- SVEIntrinsicOpts - SVE ACLE Intrinsics Opts --------------------===//
25ffd83dbSDimitry Andric //
3349cc55cSDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4349cc55cSDimitry Andric // See https://llvm.org/LICENSE.txt for license information.
5349cc55cSDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
65ffd83dbSDimitry Andric //
75ffd83dbSDimitry Andric //===----------------------------------------------------------------------===//
85ffd83dbSDimitry Andric //
95ffd83dbSDimitry Andric // Performs general IR level optimizations on SVE intrinsics.
105ffd83dbSDimitry Andric //
11fe6060f1SDimitry Andric // This pass performs the following optimizations:
125ffd83dbSDimitry Andric //
13fe6060f1SDimitry Andric // - removes unnecessary ptrue intrinsics (llvm.aarch64.sve.ptrue), e.g:
14fe6060f1SDimitry Andric //     %1 = @llvm.aarch64.sve.ptrue.nxv4i1(i32 31)
15fe6060f1SDimitry Andric //     %2 = @llvm.aarch64.sve.ptrue.nxv8i1(i32 31)
16fe6060f1SDimitry Andric //     ; (%1 can be replaced with a reinterpret of %2)
175ffd83dbSDimitry Andric //
// - rewrites loads and stores of fixed-length predicate vectors into loads and
//   stores of the scalable svbool type (<vscale x 16 x i1>) when the function's
//   vscale_range attribute pins vscale to a single value.
205ffd83dbSDimitry Andric //
215ffd83dbSDimitry Andric //===----------------------------------------------------------------------===//
225ffd83dbSDimitry Andric 
23fe6060f1SDimitry Andric #include "AArch64.h"
245ffd83dbSDimitry Andric #include "Utils/AArch64BaseInfo.h"
255ffd83dbSDimitry Andric #include "llvm/ADT/PostOrderIterator.h"
265ffd83dbSDimitry Andric #include "llvm/ADT/SetVector.h"
275ffd83dbSDimitry Andric #include "llvm/IR/Constants.h"
285ffd83dbSDimitry Andric #include "llvm/IR/Dominators.h"
295ffd83dbSDimitry Andric #include "llvm/IR/IRBuilder.h"
305ffd83dbSDimitry Andric #include "llvm/IR/Instructions.h"
315ffd83dbSDimitry Andric #include "llvm/IR/IntrinsicInst.h"
325ffd83dbSDimitry Andric #include "llvm/IR/IntrinsicsAArch64.h"
335ffd83dbSDimitry Andric #include "llvm/IR/LLVMContext.h"
345ffd83dbSDimitry Andric #include "llvm/IR/PatternMatch.h"
355ffd83dbSDimitry Andric #include "llvm/InitializePasses.h"
365ffd83dbSDimitry Andric #include "llvm/Support/Debug.h"
375ffd83dbSDimitry Andric 
385ffd83dbSDimitry Andric using namespace llvm;
395ffd83dbSDimitry Andric using namespace llvm::PatternMatch;
405ffd83dbSDimitry Andric 
41e8d8bef9SDimitry Andric #define DEBUG_TYPE "aarch64-sve-intrinsic-opts"
425ffd83dbSDimitry Andric 
435ffd83dbSDimitry Andric namespace llvm {
445ffd83dbSDimitry Andric void initializeSVEIntrinsicOptsPass(PassRegistry &);
455ffd83dbSDimitry Andric }
465ffd83dbSDimitry Andric 
475ffd83dbSDimitry Andric namespace {
485ffd83dbSDimitry Andric struct SVEIntrinsicOpts : public ModulePass {
495ffd83dbSDimitry Andric   static char ID; // Pass identification, replacement for typeid
505ffd83dbSDimitry Andric   SVEIntrinsicOpts() : ModulePass(ID) {
515ffd83dbSDimitry Andric     initializeSVEIntrinsicOptsPass(*PassRegistry::getPassRegistry());
525ffd83dbSDimitry Andric   }
535ffd83dbSDimitry Andric 
545ffd83dbSDimitry Andric   bool runOnModule(Module &M) override;
555ffd83dbSDimitry Andric   void getAnalysisUsage(AnalysisUsage &AU) const override;
565ffd83dbSDimitry Andric 
575ffd83dbSDimitry Andric private:
58fe6060f1SDimitry Andric   bool coalescePTrueIntrinsicCalls(BasicBlock &BB,
59fe6060f1SDimitry Andric                                    SmallSetVector<IntrinsicInst *, 4> &PTrues);
60fe6060f1SDimitry Andric   bool optimizePTrueIntrinsicCalls(SmallSetVector<Function *, 4> &Functions);
61349cc55cSDimitry Andric   bool optimizePredicateStore(Instruction *I);
62349cc55cSDimitry Andric   bool optimizePredicateLoad(Instruction *I);
63349cc55cSDimitry Andric 
64349cc55cSDimitry Andric   bool optimizeInstructions(SmallSetVector<Function *, 4> &Functions);
655ffd83dbSDimitry Andric 
66fe6060f1SDimitry Andric   /// Operates at the function-scope. I.e., optimizations are applied local to
67fe6060f1SDimitry Andric   /// the functions themselves.
685ffd83dbSDimitry Andric   bool optimizeFunctions(SmallSetVector<Function *, 4> &Functions);
695ffd83dbSDimitry Andric };
705ffd83dbSDimitry Andric } // end anonymous namespace
715ffd83dbSDimitry Andric 
725ffd83dbSDimitry Andric void SVEIntrinsicOpts::getAnalysisUsage(AnalysisUsage &AU) const {
735ffd83dbSDimitry Andric   AU.addRequired<DominatorTreeWrapperPass>();
745ffd83dbSDimitry Andric   AU.setPreservesCFG();
755ffd83dbSDimitry Andric }
765ffd83dbSDimitry Andric 
775ffd83dbSDimitry Andric char SVEIntrinsicOpts::ID = 0;
785ffd83dbSDimitry Andric static const char *name = "SVE intrinsics optimizations";
795ffd83dbSDimitry Andric INITIALIZE_PASS_BEGIN(SVEIntrinsicOpts, DEBUG_TYPE, name, false, false)
805ffd83dbSDimitry Andric INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass);
815ffd83dbSDimitry Andric INITIALIZE_PASS_END(SVEIntrinsicOpts, DEBUG_TYPE, name, false, false)
825ffd83dbSDimitry Andric 
83fe6060f1SDimitry Andric ModulePass *llvm::createSVEIntrinsicOptsPass() {
84fe6060f1SDimitry Andric   return new SVEIntrinsicOpts();
855ffd83dbSDimitry Andric }
865ffd83dbSDimitry Andric 
87fe6060f1SDimitry Andric /// Checks if a ptrue intrinsic call is promoted. The act of promoting a
88fe6060f1SDimitry Andric /// ptrue will introduce zeroing. For example:
89fe6060f1SDimitry Andric ///
90fe6060f1SDimitry Andric ///     %1 = <vscale x 4 x i1> call @llvm.aarch64.sve.ptrue.nxv4i1(i32 31)
91fe6060f1SDimitry Andric ///     %2 = <vscale x 16 x i1> call @llvm.aarch64.sve.convert.to.svbool.nxv4i1(<vscale x 4 x i1> %1)
92fe6060f1SDimitry Andric ///     %3 = <vscale x 8 x i1> call @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %2)
93fe6060f1SDimitry Andric ///
94fe6060f1SDimitry Andric /// %1 is promoted, because it is converted:
95fe6060f1SDimitry Andric ///
96fe6060f1SDimitry Andric ///     <vscale x 4 x i1> => <vscale x 16 x i1> => <vscale x 8 x i1>
97fe6060f1SDimitry Andric ///
98fe6060f1SDimitry Andric /// via a sequence of the SVE reinterpret intrinsics convert.{to,from}.svbool.
99fe6060f1SDimitry Andric static bool isPTruePromoted(IntrinsicInst *PTrue) {
100fe6060f1SDimitry Andric   // Find all users of this intrinsic that are calls to convert-to-svbool
101fe6060f1SDimitry Andric   // reinterpret intrinsics.
102fe6060f1SDimitry Andric   SmallVector<IntrinsicInst *, 4> ConvertToUses;
103fe6060f1SDimitry Andric   for (User *User : PTrue->users()) {
104fe6060f1SDimitry Andric     if (match(User, m_Intrinsic<Intrinsic::aarch64_sve_convert_to_svbool>())) {
105fe6060f1SDimitry Andric       ConvertToUses.push_back(cast<IntrinsicInst>(User));
106fe6060f1SDimitry Andric     }
107fe6060f1SDimitry Andric   }
1085ffd83dbSDimitry Andric 
109fe6060f1SDimitry Andric   // If no such calls were found, this is ptrue is not promoted.
110fe6060f1SDimitry Andric   if (ConvertToUses.empty())
1115ffd83dbSDimitry Andric     return false;
1125ffd83dbSDimitry Andric 
113fe6060f1SDimitry Andric   // Otherwise, try to find users of the convert-to-svbool intrinsics that are
114fe6060f1SDimitry Andric   // calls to the convert-from-svbool intrinsic, and would result in some lanes
115fe6060f1SDimitry Andric   // being zeroed.
116fe6060f1SDimitry Andric   const auto *PTrueVTy = cast<ScalableVectorType>(PTrue->getType());
117fe6060f1SDimitry Andric   for (IntrinsicInst *ConvertToUse : ConvertToUses) {
118fe6060f1SDimitry Andric     for (User *User : ConvertToUse->users()) {
119fe6060f1SDimitry Andric       auto *IntrUser = dyn_cast<IntrinsicInst>(User);
120fe6060f1SDimitry Andric       if (IntrUser && IntrUser->getIntrinsicID() ==
121fe6060f1SDimitry Andric                           Intrinsic::aarch64_sve_convert_from_svbool) {
122fe6060f1SDimitry Andric         const auto *IntrUserVTy = cast<ScalableVectorType>(IntrUser->getType());
123fe6060f1SDimitry Andric 
124fe6060f1SDimitry Andric         // Would some lanes become zeroed by the conversion?
125fe6060f1SDimitry Andric         if (IntrUserVTy->getElementCount().getKnownMinValue() >
126fe6060f1SDimitry Andric             PTrueVTy->getElementCount().getKnownMinValue())
127fe6060f1SDimitry Andric           // This is a promoted ptrue.
128fe6060f1SDimitry Andric           return true;
129fe6060f1SDimitry Andric       }
130fe6060f1SDimitry Andric     }
131fe6060f1SDimitry Andric   }
132fe6060f1SDimitry Andric 
133fe6060f1SDimitry Andric   // If no matching calls were found, this is not a promoted ptrue.
1345ffd83dbSDimitry Andric   return false;
1355ffd83dbSDimitry Andric }
1365ffd83dbSDimitry Andric 
137fe6060f1SDimitry Andric /// Attempts to coalesce ptrues in a basic block.
138fe6060f1SDimitry Andric bool SVEIntrinsicOpts::coalescePTrueIntrinsicCalls(
139fe6060f1SDimitry Andric     BasicBlock &BB, SmallSetVector<IntrinsicInst *, 4> &PTrues) {
140fe6060f1SDimitry Andric   if (PTrues.size() <= 1)
141fe6060f1SDimitry Andric     return false;
142fe6060f1SDimitry Andric 
143fe6060f1SDimitry Andric   // Find the ptrue with the most lanes.
144fe6060f1SDimitry Andric   auto *MostEncompassingPTrue = *std::max_element(
145fe6060f1SDimitry Andric       PTrues.begin(), PTrues.end(), [](auto *PTrue1, auto *PTrue2) {
146fe6060f1SDimitry Andric         auto *PTrue1VTy = cast<ScalableVectorType>(PTrue1->getType());
147fe6060f1SDimitry Andric         auto *PTrue2VTy = cast<ScalableVectorType>(PTrue2->getType());
148fe6060f1SDimitry Andric         return PTrue1VTy->getElementCount().getKnownMinValue() <
149fe6060f1SDimitry Andric                PTrue2VTy->getElementCount().getKnownMinValue();
150fe6060f1SDimitry Andric       });
151fe6060f1SDimitry Andric 
152fe6060f1SDimitry Andric   // Remove the most encompassing ptrue, as well as any promoted ptrues, leaving
153fe6060f1SDimitry Andric   // behind only the ptrues to be coalesced.
154fe6060f1SDimitry Andric   PTrues.remove(MostEncompassingPTrue);
155*0eae32dcSDimitry Andric   PTrues.remove_if(isPTruePromoted);
156fe6060f1SDimitry Andric 
157fe6060f1SDimitry Andric   // Hoist MostEncompassingPTrue to the start of the basic block. It is always
158fe6060f1SDimitry Andric   // safe to do this, since ptrue intrinsic calls are guaranteed to have no
159fe6060f1SDimitry Andric   // predecessors.
160fe6060f1SDimitry Andric   MostEncompassingPTrue->moveBefore(BB, BB.getFirstInsertionPt());
161fe6060f1SDimitry Andric 
162fe6060f1SDimitry Andric   LLVMContext &Ctx = BB.getContext();
1635ffd83dbSDimitry Andric   IRBuilder<> Builder(Ctx);
164fe6060f1SDimitry Andric   Builder.SetInsertPoint(&BB, ++MostEncompassingPTrue->getIterator());
1655ffd83dbSDimitry Andric 
166fe6060f1SDimitry Andric   auto *MostEncompassingPTrueVTy =
167fe6060f1SDimitry Andric       cast<VectorType>(MostEncompassingPTrue->getType());
168fe6060f1SDimitry Andric   auto *ConvertToSVBool = Builder.CreateIntrinsic(
169fe6060f1SDimitry Andric       Intrinsic::aarch64_sve_convert_to_svbool, {MostEncompassingPTrueVTy},
170fe6060f1SDimitry Andric       {MostEncompassingPTrue});
171fe6060f1SDimitry Andric 
172fe6060f1SDimitry Andric   bool ConvertFromCreated = false;
173fe6060f1SDimitry Andric   for (auto *PTrue : PTrues) {
174fe6060f1SDimitry Andric     auto *PTrueVTy = cast<VectorType>(PTrue->getType());
175fe6060f1SDimitry Andric 
176fe6060f1SDimitry Andric     // Only create the converts if the types are not already the same, otherwise
177fe6060f1SDimitry Andric     // just use the most encompassing ptrue.
178fe6060f1SDimitry Andric     if (MostEncompassingPTrueVTy != PTrueVTy) {
179fe6060f1SDimitry Andric       ConvertFromCreated = true;
180fe6060f1SDimitry Andric 
181fe6060f1SDimitry Andric       Builder.SetInsertPoint(&BB, ++ConvertToSVBool->getIterator());
182fe6060f1SDimitry Andric       auto *ConvertFromSVBool =
183fe6060f1SDimitry Andric           Builder.CreateIntrinsic(Intrinsic::aarch64_sve_convert_from_svbool,
184fe6060f1SDimitry Andric                                   {PTrueVTy}, {ConvertToSVBool});
185fe6060f1SDimitry Andric       PTrue->replaceAllUsesWith(ConvertFromSVBool);
186fe6060f1SDimitry Andric     } else
187fe6060f1SDimitry Andric       PTrue->replaceAllUsesWith(MostEncompassingPTrue);
188fe6060f1SDimitry Andric 
189fe6060f1SDimitry Andric     PTrue->eraseFromParent();
1905ffd83dbSDimitry Andric   }
1915ffd83dbSDimitry Andric 
192fe6060f1SDimitry Andric   // We never used the ConvertTo so remove it
193fe6060f1SDimitry Andric   if (!ConvertFromCreated)
194fe6060f1SDimitry Andric     ConvertToSVBool->eraseFromParent();
1955ffd83dbSDimitry Andric 
1965ffd83dbSDimitry Andric   return true;
1975ffd83dbSDimitry Andric }
1985ffd83dbSDimitry Andric 
199fe6060f1SDimitry Andric /// The goal of this function is to remove redundant calls to the SVE ptrue
200fe6060f1SDimitry Andric /// intrinsic in each basic block within the given functions.
201fe6060f1SDimitry Andric ///
202fe6060f1SDimitry Andric /// SVE ptrues have two representations in LLVM IR:
203fe6060f1SDimitry Andric /// - a logical representation -- an arbitrary-width scalable vector of i1s,
204fe6060f1SDimitry Andric ///   i.e. <vscale x N x i1>.
205fe6060f1SDimitry Andric /// - a physical representation (svbool, <vscale x 16 x i1>) -- a 16-element
206fe6060f1SDimitry Andric ///   scalable vector of i1s, i.e. <vscale x 16 x i1>.
207fe6060f1SDimitry Andric ///
208fe6060f1SDimitry Andric /// The SVE ptrue intrinsic is used to create a logical representation of an SVE
209fe6060f1SDimitry Andric /// predicate. Suppose that we have two SVE ptrue intrinsic calls: P1 and P2. If
210fe6060f1SDimitry Andric /// P1 creates a logical SVE predicate that is at least as wide as the logical
211fe6060f1SDimitry Andric /// SVE predicate created by P2, then all of the bits that are true in the
212fe6060f1SDimitry Andric /// physical representation of P2 are necessarily also true in the physical
213fe6060f1SDimitry Andric /// representation of P1. P1 'encompasses' P2, therefore, the intrinsic call to
214fe6060f1SDimitry Andric /// P2 is redundant and can be replaced by an SVE reinterpret of P1 via
215fe6060f1SDimitry Andric /// convert.{to,from}.svbool.
216fe6060f1SDimitry Andric ///
217fe6060f1SDimitry Andric /// Currently, this pass only coalesces calls to SVE ptrue intrinsics
218fe6060f1SDimitry Andric /// if they match the following conditions:
219fe6060f1SDimitry Andric ///
220fe6060f1SDimitry Andric /// - the call to the intrinsic uses either the SV_ALL or SV_POW2 patterns.
221fe6060f1SDimitry Andric ///   SV_ALL indicates that all bits of the predicate vector are to be set to
222fe6060f1SDimitry Andric ///   true. SV_POW2 indicates that all bits of the predicate vector up to the
223fe6060f1SDimitry Andric ///   largest power-of-two are to be set to true.
224fe6060f1SDimitry Andric /// - the result of the call to the intrinsic is not promoted to a wider
225fe6060f1SDimitry Andric ///   predicate. In this case, keeping the extra ptrue leads to better codegen
226fe6060f1SDimitry Andric ///   -- coalescing here would create an irreducible chain of SVE reinterprets
227fe6060f1SDimitry Andric ///   via convert.{to,from}.svbool.
228fe6060f1SDimitry Andric ///
229fe6060f1SDimitry Andric /// EXAMPLE:
230fe6060f1SDimitry Andric ///
231fe6060f1SDimitry Andric ///     %1 = <vscale x 8 x i1> ptrue(i32 SV_ALL)
232fe6060f1SDimitry Andric ///     ; Logical:  <1, 1, 1, 1, 1, 1, 1, 1>
233fe6060f1SDimitry Andric ///     ; Physical: <1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0>
234fe6060f1SDimitry Andric ///     ...
235fe6060f1SDimitry Andric ///
236fe6060f1SDimitry Andric ///     %2 = <vscale x 4 x i1> ptrue(i32 SV_ALL)
237fe6060f1SDimitry Andric ///     ; Logical:  <1, 1, 1, 1>
238fe6060f1SDimitry Andric ///     ; Physical: <1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0>
239fe6060f1SDimitry Andric ///     ...
240fe6060f1SDimitry Andric ///
241fe6060f1SDimitry Andric /// Here, %2 can be replaced by an SVE reinterpret of %1, giving, for instance:
242fe6060f1SDimitry Andric ///
243fe6060f1SDimitry Andric ///     %1 = <vscale x 8 x i1> ptrue(i32 i31)
244fe6060f1SDimitry Andric ///     %2 = <vscale x 16 x i1> convert.to.svbool(<vscale x 8 x i1> %1)
245fe6060f1SDimitry Andric ///     %3 = <vscale x 4 x i1> convert.from.svbool(<vscale x 16 x i1> %2)
246fe6060f1SDimitry Andric ///
247fe6060f1SDimitry Andric bool SVEIntrinsicOpts::optimizePTrueIntrinsicCalls(
248fe6060f1SDimitry Andric     SmallSetVector<Function *, 4> &Functions) {
249fe6060f1SDimitry Andric   bool Changed = false;
2505ffd83dbSDimitry Andric 
251fe6060f1SDimitry Andric   for (auto *F : Functions) {
252fe6060f1SDimitry Andric     for (auto &BB : *F) {
253fe6060f1SDimitry Andric       SmallSetVector<IntrinsicInst *, 4> SVAllPTrues;
254fe6060f1SDimitry Andric       SmallSetVector<IntrinsicInst *, 4> SVPow2PTrues;
2555ffd83dbSDimitry Andric 
256fe6060f1SDimitry Andric       // For each basic block, collect the used ptrues and try to coalesce them.
257fe6060f1SDimitry Andric       for (Instruction &I : BB) {
258fe6060f1SDimitry Andric         if (I.use_empty())
259fe6060f1SDimitry Andric           continue;
2605ffd83dbSDimitry Andric 
261fe6060f1SDimitry Andric         auto *IntrI = dyn_cast<IntrinsicInst>(&I);
262fe6060f1SDimitry Andric         if (!IntrI || IntrI->getIntrinsicID() != Intrinsic::aarch64_sve_ptrue)
263fe6060f1SDimitry Andric           continue;
2645ffd83dbSDimitry Andric 
265fe6060f1SDimitry Andric         const auto PTruePattern =
266fe6060f1SDimitry Andric             cast<ConstantInt>(IntrI->getOperand(0))->getZExtValue();
2675ffd83dbSDimitry Andric 
268fe6060f1SDimitry Andric         if (PTruePattern == AArch64SVEPredPattern::all)
269fe6060f1SDimitry Andric           SVAllPTrues.insert(IntrI);
270fe6060f1SDimitry Andric         if (PTruePattern == AArch64SVEPredPattern::pow2)
271fe6060f1SDimitry Andric           SVPow2PTrues.insert(IntrI);
2725ffd83dbSDimitry Andric       }
2735ffd83dbSDimitry Andric 
274fe6060f1SDimitry Andric       Changed |= coalescePTrueIntrinsicCalls(BB, SVAllPTrues);
275fe6060f1SDimitry Andric       Changed |= coalescePTrueIntrinsicCalls(BB, SVPow2PTrues);
276fe6060f1SDimitry Andric     }
2775ffd83dbSDimitry Andric   }
2785ffd83dbSDimitry Andric 
279fe6060f1SDimitry Andric   return Changed;
2805ffd83dbSDimitry Andric }
2815ffd83dbSDimitry Andric 
282349cc55cSDimitry Andric // This is done in SVEIntrinsicOpts rather than InstCombine so that we introduce
283349cc55cSDimitry Andric // scalable stores as late as possible
284349cc55cSDimitry Andric bool SVEIntrinsicOpts::optimizePredicateStore(Instruction *I) {
285349cc55cSDimitry Andric   auto *F = I->getFunction();
286349cc55cSDimitry Andric   auto Attr = F->getFnAttribute(Attribute::VScaleRange);
287349cc55cSDimitry Andric   if (!Attr.isValid())
288349cc55cSDimitry Andric     return false;
289349cc55cSDimitry Andric 
290*0eae32dcSDimitry Andric   unsigned MinVScale = Attr.getVScaleRangeMin();
291*0eae32dcSDimitry Andric   Optional<unsigned> MaxVScale = Attr.getVScaleRangeMax();
292349cc55cSDimitry Andric   // The transform needs to know the exact runtime length of scalable vectors
293*0eae32dcSDimitry Andric   if (!MaxVScale || MinVScale != MaxVScale)
294349cc55cSDimitry Andric     return false;
295349cc55cSDimitry Andric 
296349cc55cSDimitry Andric   auto *PredType =
297349cc55cSDimitry Andric       ScalableVectorType::get(Type::getInt1Ty(I->getContext()), 16);
298349cc55cSDimitry Andric   auto *FixedPredType =
299349cc55cSDimitry Andric       FixedVectorType::get(Type::getInt8Ty(I->getContext()), MinVScale * 2);
300349cc55cSDimitry Andric 
301349cc55cSDimitry Andric   // If we have a store..
302349cc55cSDimitry Andric   auto *Store = dyn_cast<StoreInst>(I);
303349cc55cSDimitry Andric   if (!Store || !Store->isSimple())
304349cc55cSDimitry Andric     return false;
305349cc55cSDimitry Andric 
306349cc55cSDimitry Andric   // ..that is storing a predicate vector sized worth of bits..
307349cc55cSDimitry Andric   if (Store->getOperand(0)->getType() != FixedPredType)
308349cc55cSDimitry Andric     return false;
309349cc55cSDimitry Andric 
310349cc55cSDimitry Andric   // ..where the value stored comes from a vector extract..
311349cc55cSDimitry Andric   auto *IntrI = dyn_cast<IntrinsicInst>(Store->getOperand(0));
312349cc55cSDimitry Andric   if (!IntrI ||
313349cc55cSDimitry Andric       IntrI->getIntrinsicID() != Intrinsic::experimental_vector_extract)
314349cc55cSDimitry Andric     return false;
315349cc55cSDimitry Andric 
316349cc55cSDimitry Andric   // ..that is extracting from index 0..
317349cc55cSDimitry Andric   if (!cast<ConstantInt>(IntrI->getOperand(1))->isZero())
318349cc55cSDimitry Andric     return false;
319349cc55cSDimitry Andric 
320349cc55cSDimitry Andric   // ..where the value being extract from comes from a bitcast
321349cc55cSDimitry Andric   auto *BitCast = dyn_cast<BitCastInst>(IntrI->getOperand(0));
322349cc55cSDimitry Andric   if (!BitCast)
323349cc55cSDimitry Andric     return false;
324349cc55cSDimitry Andric 
325349cc55cSDimitry Andric   // ..and the bitcast is casting from predicate type
326349cc55cSDimitry Andric   if (BitCast->getOperand(0)->getType() != PredType)
327349cc55cSDimitry Andric     return false;
328349cc55cSDimitry Andric 
329349cc55cSDimitry Andric   IRBuilder<> Builder(I->getContext());
330349cc55cSDimitry Andric   Builder.SetInsertPoint(I);
331349cc55cSDimitry Andric 
332349cc55cSDimitry Andric   auto *PtrBitCast = Builder.CreateBitCast(
333349cc55cSDimitry Andric       Store->getPointerOperand(),
334349cc55cSDimitry Andric       PredType->getPointerTo(Store->getPointerAddressSpace()));
335349cc55cSDimitry Andric   Builder.CreateStore(BitCast->getOperand(0), PtrBitCast);
336349cc55cSDimitry Andric 
337349cc55cSDimitry Andric   Store->eraseFromParent();
338349cc55cSDimitry Andric   if (IntrI->getNumUses() == 0)
339349cc55cSDimitry Andric     IntrI->eraseFromParent();
340349cc55cSDimitry Andric   if (BitCast->getNumUses() == 0)
341349cc55cSDimitry Andric     BitCast->eraseFromParent();
342349cc55cSDimitry Andric 
343349cc55cSDimitry Andric   return true;
344349cc55cSDimitry Andric }
345349cc55cSDimitry Andric 
346349cc55cSDimitry Andric // This is done in SVEIntrinsicOpts rather than InstCombine so that we introduce
347349cc55cSDimitry Andric // scalable loads as late as possible
348349cc55cSDimitry Andric bool SVEIntrinsicOpts::optimizePredicateLoad(Instruction *I) {
349349cc55cSDimitry Andric   auto *F = I->getFunction();
350349cc55cSDimitry Andric   auto Attr = F->getFnAttribute(Attribute::VScaleRange);
351349cc55cSDimitry Andric   if (!Attr.isValid())
352349cc55cSDimitry Andric     return false;
353349cc55cSDimitry Andric 
354*0eae32dcSDimitry Andric   unsigned MinVScale = Attr.getVScaleRangeMin();
355*0eae32dcSDimitry Andric   Optional<unsigned> MaxVScale = Attr.getVScaleRangeMax();
356349cc55cSDimitry Andric   // The transform needs to know the exact runtime length of scalable vectors
357*0eae32dcSDimitry Andric   if (!MaxVScale || MinVScale != MaxVScale)
358349cc55cSDimitry Andric     return false;
359349cc55cSDimitry Andric 
360349cc55cSDimitry Andric   auto *PredType =
361349cc55cSDimitry Andric       ScalableVectorType::get(Type::getInt1Ty(I->getContext()), 16);
362349cc55cSDimitry Andric   auto *FixedPredType =
363349cc55cSDimitry Andric       FixedVectorType::get(Type::getInt8Ty(I->getContext()), MinVScale * 2);
364349cc55cSDimitry Andric 
365349cc55cSDimitry Andric   // If we have a bitcast..
366349cc55cSDimitry Andric   auto *BitCast = dyn_cast<BitCastInst>(I);
367349cc55cSDimitry Andric   if (!BitCast || BitCast->getType() != PredType)
368349cc55cSDimitry Andric     return false;
369349cc55cSDimitry Andric 
370349cc55cSDimitry Andric   // ..whose operand is a vector_insert..
371349cc55cSDimitry Andric   auto *IntrI = dyn_cast<IntrinsicInst>(BitCast->getOperand(0));
372349cc55cSDimitry Andric   if (!IntrI ||
373349cc55cSDimitry Andric       IntrI->getIntrinsicID() != Intrinsic::experimental_vector_insert)
374349cc55cSDimitry Andric     return false;
375349cc55cSDimitry Andric 
376349cc55cSDimitry Andric   // ..that is inserting into index zero of an undef vector..
377349cc55cSDimitry Andric   if (!isa<UndefValue>(IntrI->getOperand(0)) ||
378349cc55cSDimitry Andric       !cast<ConstantInt>(IntrI->getOperand(2))->isZero())
379349cc55cSDimitry Andric     return false;
380349cc55cSDimitry Andric 
381349cc55cSDimitry Andric   // ..where the value inserted comes from a load..
382349cc55cSDimitry Andric   auto *Load = dyn_cast<LoadInst>(IntrI->getOperand(1));
383349cc55cSDimitry Andric   if (!Load || !Load->isSimple())
384349cc55cSDimitry Andric     return false;
385349cc55cSDimitry Andric 
386349cc55cSDimitry Andric   // ..that is loading a predicate vector sized worth of bits..
387349cc55cSDimitry Andric   if (Load->getType() != FixedPredType)
388349cc55cSDimitry Andric     return false;
389349cc55cSDimitry Andric 
390349cc55cSDimitry Andric   IRBuilder<> Builder(I->getContext());
391349cc55cSDimitry Andric   Builder.SetInsertPoint(Load);
392349cc55cSDimitry Andric 
393349cc55cSDimitry Andric   auto *PtrBitCast = Builder.CreateBitCast(
394349cc55cSDimitry Andric       Load->getPointerOperand(),
395349cc55cSDimitry Andric       PredType->getPointerTo(Load->getPointerAddressSpace()));
396349cc55cSDimitry Andric   auto *LoadPred = Builder.CreateLoad(PredType, PtrBitCast);
397349cc55cSDimitry Andric 
398349cc55cSDimitry Andric   BitCast->replaceAllUsesWith(LoadPred);
399349cc55cSDimitry Andric   BitCast->eraseFromParent();
400349cc55cSDimitry Andric   if (IntrI->getNumUses() == 0)
401349cc55cSDimitry Andric     IntrI->eraseFromParent();
402349cc55cSDimitry Andric   if (Load->getNumUses() == 0)
403349cc55cSDimitry Andric     Load->eraseFromParent();
404349cc55cSDimitry Andric 
405349cc55cSDimitry Andric   return true;
406349cc55cSDimitry Andric }
407349cc55cSDimitry Andric 
408349cc55cSDimitry Andric bool SVEIntrinsicOpts::optimizeInstructions(
409349cc55cSDimitry Andric     SmallSetVector<Function *, 4> &Functions) {
410349cc55cSDimitry Andric   bool Changed = false;
411349cc55cSDimitry Andric 
412349cc55cSDimitry Andric   for (auto *F : Functions) {
413349cc55cSDimitry Andric     DominatorTree *DT = &getAnalysis<DominatorTreeWrapperPass>(*F).getDomTree();
414349cc55cSDimitry Andric 
415349cc55cSDimitry Andric     // Traverse the DT with an rpo walk so we see defs before uses, allowing
416349cc55cSDimitry Andric     // simplification to be done incrementally.
417349cc55cSDimitry Andric     BasicBlock *Root = DT->getRoot();
418349cc55cSDimitry Andric     ReversePostOrderTraversal<BasicBlock *> RPOT(Root);
419349cc55cSDimitry Andric     for (auto *BB : RPOT) {
420349cc55cSDimitry Andric       for (Instruction &I : make_early_inc_range(*BB)) {
421349cc55cSDimitry Andric         switch (I.getOpcode()) {
422349cc55cSDimitry Andric         case Instruction::Store:
423349cc55cSDimitry Andric           Changed |= optimizePredicateStore(&I);
424349cc55cSDimitry Andric           break;
425349cc55cSDimitry Andric         case Instruction::BitCast:
426349cc55cSDimitry Andric           Changed |= optimizePredicateLoad(&I);
427349cc55cSDimitry Andric           break;
428349cc55cSDimitry Andric         }
429349cc55cSDimitry Andric       }
430349cc55cSDimitry Andric     }
431349cc55cSDimitry Andric   }
432349cc55cSDimitry Andric 
433349cc55cSDimitry Andric   return Changed;
434349cc55cSDimitry Andric }
435349cc55cSDimitry Andric 
4365ffd83dbSDimitry Andric bool SVEIntrinsicOpts::optimizeFunctions(
4375ffd83dbSDimitry Andric     SmallSetVector<Function *, 4> &Functions) {
4385ffd83dbSDimitry Andric   bool Changed = false;
4395ffd83dbSDimitry Andric 
440fe6060f1SDimitry Andric   Changed |= optimizePTrueIntrinsicCalls(Functions);
441349cc55cSDimitry Andric   Changed |= optimizeInstructions(Functions);
442fe6060f1SDimitry Andric 
4435ffd83dbSDimitry Andric   return Changed;
4445ffd83dbSDimitry Andric }
4455ffd83dbSDimitry Andric 
4465ffd83dbSDimitry Andric bool SVEIntrinsicOpts::runOnModule(Module &M) {
4475ffd83dbSDimitry Andric   bool Changed = false;
4485ffd83dbSDimitry Andric   SmallSetVector<Function *, 4> Functions;
4495ffd83dbSDimitry Andric 
4505ffd83dbSDimitry Andric   // Check for SVE intrinsic declarations first so that we only iterate over
4515ffd83dbSDimitry Andric   // relevant functions. Where an appropriate declaration is found, store the
4525ffd83dbSDimitry Andric   // function(s) where it is used so we can target these only.
4535ffd83dbSDimitry Andric   for (auto &F : M.getFunctionList()) {
4545ffd83dbSDimitry Andric     if (!F.isDeclaration())
4555ffd83dbSDimitry Andric       continue;
4565ffd83dbSDimitry Andric 
4575ffd83dbSDimitry Andric     switch (F.getIntrinsicID()) {
458349cc55cSDimitry Andric     case Intrinsic::experimental_vector_extract:
459349cc55cSDimitry Andric     case Intrinsic::experimental_vector_insert:
460fe6060f1SDimitry Andric     case Intrinsic::aarch64_sve_ptrue:
461e8d8bef9SDimitry Andric       for (User *U : F.users())
462e8d8bef9SDimitry Andric         Functions.insert(cast<Instruction>(U)->getFunction());
4635ffd83dbSDimitry Andric       break;
4645ffd83dbSDimitry Andric     default:
4655ffd83dbSDimitry Andric       break;
4665ffd83dbSDimitry Andric     }
4675ffd83dbSDimitry Andric   }
4685ffd83dbSDimitry Andric 
4695ffd83dbSDimitry Andric   if (!Functions.empty())
4705ffd83dbSDimitry Andric     Changed |= optimizeFunctions(Functions);
4715ffd83dbSDimitry Andric 
4725ffd83dbSDimitry Andric   return Changed;
4735ffd83dbSDimitry Andric }
474