xref: /freebsd/contrib/llvm-project/llvm/include/llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h (revision 700637cbb5e582861067a11aaca4d053546871d2)
1 //===- Legality.h -----------------------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Legality checks for the Sandbox Vectorizer.
10 //
11 
12 #ifndef LLVM_TRANSFORMS_VECTORIZE_SANDBOXVECTORIZER_LEGALITY_H
13 #define LLVM_TRANSFORMS_VECTORIZE_SANDBOXVECTORIZER_LEGALITY_H
14 
15 #include "llvm/ADT/ArrayRef.h"
16 #include "llvm/Analysis/ScalarEvolution.h"
17 #include "llvm/IR/DataLayout.h"
18 #include "llvm/Support/Casting.h"
19 #include "llvm/Support/Compiler.h"
20 #include "llvm/Support/raw_ostream.h"
21 #include "llvm/Transforms/Vectorize/SandboxVectorizer/InstrMaps.h"
22 #include "llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h"
23 
24 namespace llvm::sandboxir {
25 
26 class LegalityAnalysis;
27 class Value;
28 class InstrMaps;
29 
30 class ShuffleMask {
31 public:
32   using IndicesVecT = SmallVector<int, 8>;
33 
34 private:
35   IndicesVecT Indices;
36 
37 public:
ShuffleMask(SmallVectorImpl<int> && Indices)38   ShuffleMask(SmallVectorImpl<int> &&Indices) : Indices(std::move(Indices)) {}
ShuffleMask(std::initializer_list<int> Indices)39   ShuffleMask(std::initializer_list<int> Indices) : Indices(Indices) {}
ShuffleMask(ArrayRef<int> Indices)40   explicit ShuffleMask(ArrayRef<int> Indices) : Indices(Indices) {}
41   operator ArrayRef<int>() const { return Indices; }
42   /// Creates and returns an identity shuffle mask of size \p Sz.
43   /// For example if Sz == 4 the returned mask is {0, 1, 2, 3}.
getIdentity(unsigned Sz)44   static ShuffleMask getIdentity(unsigned Sz) {
45     IndicesVecT Indices;
46     Indices.reserve(Sz);
47     llvm::append_range(Indices, seq<int>(0, (int)Sz));
48     return ShuffleMask(std::move(Indices));
49   }
50   /// \Returns true if the mask is a perfect identity mask with consecutive
51   /// indices, i.e., performs no lane shuffling, like 0,1,2,3...
isIdentity()52   bool isIdentity() const {
53     for (auto [Idx, Elm] : enumerate(Indices)) {
54       if ((int)Idx != Elm)
55         return false;
56     }
57     return true;
58   }
59   bool operator==(const ShuffleMask &Other) const {
60     return Indices == Other.Indices;
61   }
62   bool operator!=(const ShuffleMask &Other) const { return !(*this == Other); }
size()63   size_t size() const { return Indices.size(); }
64   int operator[](int Idx) const { return Indices[Idx]; }
65   using const_iterator = IndicesVecT::const_iterator;
begin()66   const_iterator begin() const { return Indices.begin(); }
end()67   const_iterator end() const { return Indices.end(); }
68 #ifndef NDEBUG
69   friend raw_ostream &operator<<(raw_ostream &OS, const ShuffleMask &Mask) {
70     Mask.print(OS);
71     return OS;
72   }
print(raw_ostream & OS)73   void print(raw_ostream &OS) const {
74     interleave(Indices, OS, [&OS](auto Elm) { OS << Elm; }, ",");
75   }
76   LLVM_DUMP_METHOD void dump() const;
77 #endif
78 };
79 
/// Identifies the concrete kind of a LegalityResult.
enum class LegalityResultID {
  /// Collect scalar values.
  Pack,
  /// Vectorize by combining scalars to a vector.
  Widen,
  /// Don't generate new code, reuse existing vector.
  DiamondReuse,
  /// Reuse the existing vector but add a shuffle.
  DiamondReuseWithShuffle,
  /// Reuse more than one vector and/or scalars.
  DiamondReuseMultiInput,
};
87 
/// The reason for vectorizing or not vectorizing.
/// Every enumerator has a matching string in ToStr::getVecReason(), so keep
/// the two in sync when adding a new reason.
enum class ResultReason {
  NotInstructions,
  DiffOpcodes,
  DiffTypes,
  DiffMathFlags,
  DiffWrapFlags,
  DiffBBs,
  RepeatedInstrs,
  NotConsecutive,
  CantSchedule,
  Unimplemented,
  Infeasible,
  ForcePackForDebugging,
};
103 
104 #ifndef NDEBUG
105 struct ToStr {
getLegalityResultIDToStr106   static const char *getLegalityResultID(LegalityResultID ID) {
107     switch (ID) {
108     case LegalityResultID::Pack:
109       return "Pack";
110     case LegalityResultID::Widen:
111       return "Widen";
112     case LegalityResultID::DiamondReuse:
113       return "DiamondReuse";
114     case LegalityResultID::DiamondReuseWithShuffle:
115       return "DiamondReuseWithShuffle";
116     case LegalityResultID::DiamondReuseMultiInput:
117       return "DiamondReuseMultiInput";
118     }
119     llvm_unreachable("Unknown LegalityResultID enum");
120   }
121 
getVecReasonToStr122   static const char *getVecReason(ResultReason Reason) {
123     switch (Reason) {
124     case ResultReason::NotInstructions:
125       return "NotInstructions";
126     case ResultReason::DiffOpcodes:
127       return "DiffOpcodes";
128     case ResultReason::DiffTypes:
129       return "DiffTypes";
130     case ResultReason::DiffMathFlags:
131       return "DiffMathFlags";
132     case ResultReason::DiffWrapFlags:
133       return "DiffWrapFlags";
134     case ResultReason::DiffBBs:
135       return "DiffBBs";
136     case ResultReason::RepeatedInstrs:
137       return "RepeatedInstrs";
138     case ResultReason::NotConsecutive:
139       return "NotConsecutive";
140     case ResultReason::CantSchedule:
141       return "CantSchedule";
142     case ResultReason::Unimplemented:
143       return "Unimplemented";
144     case ResultReason::Infeasible:
145       return "Infeasible";
146     case ResultReason::ForcePackForDebugging:
147       return "ForcePackForDebugging";
148     }
149     llvm_unreachable("Unknown ResultReason enum");
150   }
151 };
152 #endif // NDEBUG
153 
154 /// The legality outcome is represented by a class rather than an enum class
155 /// because in some cases the legality checks are expensive and look for a
156 /// particular instruction that can be passed along to the vectorizer to avoid
157 /// repeating the same expensive computation.
158 class LegalityResult {
159 protected:
160   LegalityResultID ID;
161   /// Only Legality can create LegalityResults.
LegalityResult(LegalityResultID ID)162   LegalityResult(LegalityResultID ID) : ID(ID) {}
163   friend class LegalityAnalysis;
164 
165   /// We shouldn't need copies.
166   LegalityResult(const LegalityResult &) = delete;
167   LegalityResult &operator=(const LegalityResult &) = delete;
168 
169 public:
~LegalityResult()170   virtual ~LegalityResult() {}
getSubclassID()171   LegalityResultID getSubclassID() const { return ID; }
172 #ifndef NDEBUG
print(raw_ostream & OS)173   virtual void print(raw_ostream &OS) const {
174     OS << ToStr::getLegalityResultID(ID);
175   }
176   LLVM_DUMP_METHOD void dump() const;
177   friend raw_ostream &operator<<(raw_ostream &OS, const LegalityResult &LR) {
178     LR.print(OS);
179     return OS;
180   }
181 #endif // NDEBUG
182 };
183 
184 /// Base class for results with reason.
185 class LegalityResultWithReason : public LegalityResult {
186   [[maybe_unused]] ResultReason Reason;
LegalityResultWithReason(LegalityResultID ID,ResultReason Reason)187   LegalityResultWithReason(LegalityResultID ID, ResultReason Reason)
188       : LegalityResult(ID), Reason(Reason) {}
189   friend class Pack; // For constructor.
190 
191 public:
getReason()192   ResultReason getReason() const { return Reason; }
193 #ifndef NDEBUG
print(raw_ostream & OS)194   void print(raw_ostream &OS) const override {
195     LegalityResult::print(OS);
196     OS << " Reason: " << ToStr::getVecReason(Reason);
197   }
198 #endif
199 };
200 
201 class Widen final : public LegalityResult {
202   friend class LegalityAnalysis;
Widen()203   Widen() : LegalityResult(LegalityResultID::Widen) {}
204 
205 public:
classof(const LegalityResult * From)206   static bool classof(const LegalityResult *From) {
207     return From->getSubclassID() == LegalityResultID::Widen;
208   }
209 };
210 
211 class DiamondReuse final : public LegalityResult {
212   friend class LegalityAnalysis;
213   Action *Vec;
DiamondReuse(Action * Vec)214   DiamondReuse(Action *Vec)
215       : LegalityResult(LegalityResultID::DiamondReuse), Vec(Vec) {}
216 
217 public:
classof(const LegalityResult * From)218   static bool classof(const LegalityResult *From) {
219     return From->getSubclassID() == LegalityResultID::DiamondReuse;
220   }
getVector()221   Action *getVector() const { return Vec; }
222 };
223 
224 class DiamondReuseWithShuffle final : public LegalityResult {
225   friend class LegalityAnalysis;
226   Action *Vec;
227   ShuffleMask Mask;
DiamondReuseWithShuffle(Action * Vec,const ShuffleMask & Mask)228   DiamondReuseWithShuffle(Action *Vec, const ShuffleMask &Mask)
229       : LegalityResult(LegalityResultID::DiamondReuseWithShuffle), Vec(Vec),
230         Mask(Mask) {}
231 
232 public:
classof(const LegalityResult * From)233   static bool classof(const LegalityResult *From) {
234     return From->getSubclassID() == LegalityResultID::DiamondReuseWithShuffle;
235   }
getVector()236   Action *getVector() const { return Vec; }
getMask()237   const ShuffleMask &getMask() const { return Mask; }
238 };
239 
240 class Pack final : public LegalityResultWithReason {
Pack(ResultReason Reason)241   Pack(ResultReason Reason)
242       : LegalityResultWithReason(LegalityResultID::Pack, Reason) {}
243   friend class LegalityAnalysis; // For constructor.
244 
245 public:
classof(const LegalityResult * From)246   static bool classof(const LegalityResult *From) {
247     return From->getSubclassID() == LegalityResultID::Pack;
248   }
249 };
250 
251 /// Describes how to collect the values needed by each lane.
252 class CollectDescr {
253 public:
254   /// Describes how to get a value element. If the value is a vector then it
255   /// also provides the index to extract it from.
256   class ExtractElementDescr {
257     PointerUnion<Action *, Value *> V = nullptr;
258     /// The index in `V` that the value can be extracted from.
259     int ExtractIdx = 0;
260 
261   public:
ExtractElementDescr(Action * V,int ExtractIdx)262     ExtractElementDescr(Action *V, int ExtractIdx)
263         : V(V), ExtractIdx(ExtractIdx) {}
ExtractElementDescr(Value * V)264     ExtractElementDescr(Value *V) : V(V) {}
getValue()265     Action *getValue() const { return cast<Action *>(V); }
getScalar()266     Value *getScalar() const { return cast<Value *>(V); }
needsExtract()267     bool needsExtract() const { return isa<Action *>(V); }
getExtractIdx()268     int getExtractIdx() const { return ExtractIdx; }
269   };
270 
271   using DescrVecT = SmallVector<ExtractElementDescr, 4>;
272   DescrVecT Descrs;
273 
274 public:
CollectDescr(SmallVectorImpl<ExtractElementDescr> && Descrs)275   CollectDescr(SmallVectorImpl<ExtractElementDescr> &&Descrs)
276       : Descrs(std::move(Descrs)) {}
277   /// If all elements come from a single vector input, then return that vector
278   /// and also the shuffle mask required to get them in order.
getSingleInput()279   std::optional<std::pair<Action *, ShuffleMask>> getSingleInput() const {
280     const auto &Descr0 = *Descrs.begin();
281     if (!Descr0.needsExtract())
282       return std::nullopt;
283     auto *V0 = Descr0.getValue();
284     ShuffleMask::IndicesVecT MaskIndices;
285     MaskIndices.push_back(Descr0.getExtractIdx());
286     for (const auto &Descr : drop_begin(Descrs)) {
287       if (!Descr.needsExtract())
288         return std::nullopt;
289       if (Descr.getValue() != V0)
290         return std::nullopt;
291       MaskIndices.push_back(Descr.getExtractIdx());
292     }
293     return std::make_pair(V0, ShuffleMask(std::move(MaskIndices)));
294   }
hasVectorInputs()295   bool hasVectorInputs() const {
296     return any_of(Descrs, [](const auto &D) { return D.needsExtract(); });
297   }
getDescrs()298   const SmallVector<ExtractElementDescr, 4> &getDescrs() const {
299     return Descrs;
300   }
301 };
302 
303 class DiamondReuseMultiInput final : public LegalityResult {
304   friend class LegalityAnalysis;
305   CollectDescr Descr;
DiamondReuseMultiInput(CollectDescr && Descr)306   DiamondReuseMultiInput(CollectDescr &&Descr)
307       : LegalityResult(LegalityResultID::DiamondReuseMultiInput),
308         Descr(std::move(Descr)) {}
309 
310 public:
classof(const LegalityResult * From)311   static bool classof(const LegalityResult *From) {
312     return From->getSubclassID() == LegalityResultID::DiamondReuseMultiInput;
313   }
getCollectDescr()314   const CollectDescr &getCollectDescr() const { return Descr; }
315 };
316 
317 /// Performs the legality analysis and returns a LegalityResult object.
318 class LegalityAnalysis {
319   Scheduler Sched;
320   /// Owns the legality result objects created by createLegalityResult().
321   SmallVector<std::unique_ptr<LegalityResult>> ResultPool;
322   /// Checks opcodes, types and other IR-specifics and returns a ResultReason
323   /// object if not vectorizable, or nullptr otherwise.
324   std::optional<ResultReason>
325   notVectorizableBasedOnOpcodesAndTypes(ArrayRef<Value *> Bndl);
326 
327   ScalarEvolution &SE;
328   const DataLayout &DL;
329   InstrMaps &IMaps;
330 
331   /// Finds how we can collect the values in \p Bndl from the vectorized or
332   /// non-vectorized code. It returns a map of the value we should extract from
333   /// and the corresponding shuffle mask we need to use.
334   CollectDescr getHowToCollectValues(ArrayRef<Value *> Bndl) const;
335 
336 public:
LegalityAnalysis(AAResults & AA,ScalarEvolution & SE,const DataLayout & DL,Context & Ctx,InstrMaps & IMaps)337   LegalityAnalysis(AAResults &AA, ScalarEvolution &SE, const DataLayout &DL,
338                    Context &Ctx, InstrMaps &IMaps)
339       : Sched(AA, Ctx), SE(SE), DL(DL), IMaps(IMaps) {}
340   /// A LegalityResult factory.
341   template <typename ResultT, typename... ArgsT>
createLegalityResult(ArgsT &&...Args)342   ResultT &createLegalityResult(ArgsT &&...Args) {
343     ResultPool.push_back(
344         std::unique_ptr<ResultT>(new ResultT(std::move(Args)...)));
345     return cast<ResultT>(*ResultPool.back());
346   }
347   /// Checks if it's legal to vectorize the instructions in \p Bndl.
348   /// \Returns a LegalityResult object owned by LegalityAnalysis.
349   /// \p SkipScheduling skips the scheduler check and is only meant for testing.
350   // TODO: Try to remove the SkipScheduling argument by refactoring the tests.
351   LLVM_ABI const LegalityResult &canVectorize(ArrayRef<Value *> Bndl,
352                                               bool SkipScheduling = false);
353   /// \Returns a Pack with reason 'ForcePackForDebugging'.
getForcedPackForDebugging()354   const LegalityResult &getForcedPackForDebugging() {
355     return createLegalityResult<Pack>(ResultReason::ForcePackForDebugging);
356   }
357   LLVM_ABI void clear();
358 };
359 
360 } // namespace llvm::sandboxir
361 
362 #endif // LLVM_TRANSFORMS_VECTORIZE_SANDBOXVECTORIZER_LEGALITY_H
363