xref: /freebsd/contrib/llvm-project/llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp (revision 700637cbb5e582861067a11aaca4d053546871d2)
1 //===- Legality.cpp -------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h"
10 #include "llvm/SandboxIR/Instruction.h"
11 #include "llvm/SandboxIR/Operator.h"
12 #include "llvm/SandboxIR/Utils.h"
13 #include "llvm/SandboxIR/Value.h"
14 #include "llvm/Support/Debug.h"
15 #include "llvm/Transforms/Vectorize/SandboxVectorizer/InstrMaps.h"
16 #include "llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h"
17 
18 namespace llvm::sandboxir {
19 
20 #ifndef NDEBUG
dump() const21 void ShuffleMask::dump() const {
22   print(dbgs());
23   dbgs() << "\n";
24 }
25 
dump() const26 void LegalityResult::dump() const {
27   print(dbgs());
28   dbgs() << "\n";
29 }
30 #endif // NDEBUG
31 
32 std::optional<ResultReason>
notVectorizableBasedOnOpcodesAndTypes(ArrayRef<Value * > Bndl)33 LegalityAnalysis::notVectorizableBasedOnOpcodesAndTypes(
34     ArrayRef<Value *> Bndl) {
35   auto *I0 = cast<Instruction>(Bndl[0]);
36   auto Opcode = I0->getOpcode();
37   // If they have different opcodes, then we cannot form a vector (for now).
38   if (any_of(drop_begin(Bndl), [Opcode](Value *V) {
39         return cast<Instruction>(V)->getOpcode() != Opcode;
40       }))
41     return ResultReason::DiffOpcodes;
42 
43   // If not the same scalar type, Pack. This will accept scalars and vectors as
44   // long as the element type is the same.
45   Type *ElmTy0 = VecUtils::getElementType(Utils::getExpectedType(I0));
46   if (any_of(drop_begin(Bndl), [ElmTy0](Value *V) {
47         return VecUtils::getElementType(Utils::getExpectedType(V)) != ElmTy0;
48       }))
49     return ResultReason::DiffTypes;
50 
51   // TODO: Allow vectorization of instrs with different flags as long as we
52   // change them to the least common one.
53   // For now pack if differnt FastMathFlags.
54   if (isa<FPMathOperator>(I0)) {
55     FastMathFlags FMF0 = cast<Instruction>(Bndl[0])->getFastMathFlags();
56     if (any_of(drop_begin(Bndl), [FMF0](auto *V) {
57           return cast<Instruction>(V)->getFastMathFlags() != FMF0;
58         }))
59       return ResultReason::DiffMathFlags;
60   }
61 
62   // TODO: Allow vectorization by using common flags.
63   // For now Pack if they don't have the same wrap flags.
64   bool CanHaveWrapFlags =
65       isa<OverflowingBinaryOperator>(I0) || isa<TruncInst>(I0);
66   if (CanHaveWrapFlags) {
67     bool NUW0 = I0->hasNoUnsignedWrap();
68     bool NSW0 = I0->hasNoSignedWrap();
69     if (any_of(drop_begin(Bndl), [NUW0, NSW0](auto *V) {
70           return cast<Instruction>(V)->hasNoUnsignedWrap() != NUW0 ||
71                  cast<Instruction>(V)->hasNoSignedWrap() != NSW0;
72         })) {
73       return ResultReason::DiffWrapFlags;
74     }
75   }
76 
77   // Now we need to do further checks for specific opcodes.
78   switch (Opcode) {
79   case Instruction::Opcode::ZExt:
80   case Instruction::Opcode::SExt:
81   case Instruction::Opcode::FPToUI:
82   case Instruction::Opcode::FPToSI:
83   case Instruction::Opcode::FPExt:
84   case Instruction::Opcode::PtrToInt:
85   case Instruction::Opcode::IntToPtr:
86   case Instruction::Opcode::SIToFP:
87   case Instruction::Opcode::UIToFP:
88   case Instruction::Opcode::Trunc:
89   case Instruction::Opcode::FPTrunc:
90   case Instruction::Opcode::BitCast: {
91     // We have already checked that they are of the same opcode.
92     assert(all_of(Bndl,
93                   [Opcode](Value *V) {
94                     return cast<Instruction>(V)->getOpcode() == Opcode;
95                   }) &&
96            "Different opcodes, should have early returned!");
97     // But for these opcodes we should also check the operand type.
98     Type *FromTy0 = Utils::getExpectedType(I0->getOperand(0));
99     if (any_of(drop_begin(Bndl), [FromTy0](Value *V) {
100           return Utils::getExpectedType(cast<User>(V)->getOperand(0)) !=
101                  FromTy0;
102         }))
103       return ResultReason::DiffTypes;
104     return std::nullopt;
105   }
106   case Instruction::Opcode::FCmp:
107   case Instruction::Opcode::ICmp: {
108     // We need the same predicate..
109     auto Pred0 = cast<CmpInst>(I0)->getPredicate();
110     bool Same = all_of(Bndl, [Pred0](Value *V) {
111       return cast<CmpInst>(V)->getPredicate() == Pred0;
112     });
113     if (Same)
114       return std::nullopt;
115     return ResultReason::DiffOpcodes;
116   }
117   case Instruction::Opcode::Select: {
118     auto *Sel0 = cast<SelectInst>(Bndl[0]);
119     auto *Cond0 = Sel0->getCondition();
120     if (VecUtils::getNumLanes(Cond0) != VecUtils::getNumLanes(Sel0))
121       // TODO: For now we don't vectorize if the lanes in the condition don't
122       // match those of the select instruction.
123       return ResultReason::Unimplemented;
124     return std::nullopt;
125   }
126   case Instruction::Opcode::FNeg:
127   case Instruction::Opcode::Add:
128   case Instruction::Opcode::FAdd:
129   case Instruction::Opcode::Sub:
130   case Instruction::Opcode::FSub:
131   case Instruction::Opcode::Mul:
132   case Instruction::Opcode::FMul:
133   case Instruction::Opcode::FRem:
134   case Instruction::Opcode::UDiv:
135   case Instruction::Opcode::SDiv:
136   case Instruction::Opcode::FDiv:
137   case Instruction::Opcode::URem:
138   case Instruction::Opcode::SRem:
139   case Instruction::Opcode::Shl:
140   case Instruction::Opcode::LShr:
141   case Instruction::Opcode::AShr:
142   case Instruction::Opcode::And:
143   case Instruction::Opcode::Or:
144   case Instruction::Opcode::Xor:
145     return std::nullopt;
146   case Instruction::Opcode::Load:
147     if (VecUtils::areConsecutive<LoadInst>(Bndl, SE, DL))
148       return std::nullopt;
149     return ResultReason::NotConsecutive;
150   case Instruction::Opcode::Store:
151     if (VecUtils::areConsecutive<StoreInst>(Bndl, SE, DL))
152       return std::nullopt;
153     return ResultReason::NotConsecutive;
154   case Instruction::Opcode::PHI:
155     return ResultReason::Unimplemented;
156   case Instruction::Opcode::Opaque:
157     return ResultReason::Unimplemented;
158   case Instruction::Opcode::Br:
159   case Instruction::Opcode::Ret:
160   case Instruction::Opcode::AddrSpaceCast:
161   case Instruction::Opcode::InsertElement:
162   case Instruction::Opcode::InsertValue:
163   case Instruction::Opcode::ExtractElement:
164   case Instruction::Opcode::ExtractValue:
165   case Instruction::Opcode::ShuffleVector:
166   case Instruction::Opcode::Call:
167   case Instruction::Opcode::GetElementPtr:
168   case Instruction::Opcode::Switch:
169     return ResultReason::Unimplemented;
170   case Instruction::Opcode::VAArg:
171   case Instruction::Opcode::Freeze:
172   case Instruction::Opcode::Fence:
173   case Instruction::Opcode::Invoke:
174   case Instruction::Opcode::CallBr:
175   case Instruction::Opcode::LandingPad:
176   case Instruction::Opcode::CatchPad:
177   case Instruction::Opcode::CleanupPad:
178   case Instruction::Opcode::CatchRet:
179   case Instruction::Opcode::CleanupRet:
180   case Instruction::Opcode::Resume:
181   case Instruction::Opcode::CatchSwitch:
182   case Instruction::Opcode::AtomicRMW:
183   case Instruction::Opcode::AtomicCmpXchg:
184   case Instruction::Opcode::Alloca:
185   case Instruction::Opcode::Unreachable:
186     return ResultReason::Infeasible;
187   }
188 
189   return std::nullopt;
190 }
191 
192 CollectDescr
getHowToCollectValues(ArrayRef<Value * > Bndl) const193 LegalityAnalysis::getHowToCollectValues(ArrayRef<Value *> Bndl) const {
194   SmallVector<CollectDescr::ExtractElementDescr, 4> Vec;
195   Vec.reserve(Bndl.size());
196   for (auto [Elm, V] : enumerate(Bndl)) {
197     if (auto *VecOp = IMaps.getVectorForOrig(V)) {
198       // If there is a vector containing `V`, then get the lane it came from.
199       std::optional<int> ExtractIdxOpt = IMaps.getOrigLane(VecOp, V);
200       // This could be a vector, like <2 x float> in which case the mask needs
201       // to enumerate all lanes.
202       for (unsigned Ln = 0, Lanes = VecUtils::getNumLanes(V); Ln != Lanes; ++Ln)
203         Vec.emplace_back(VecOp, ExtractIdxOpt ? *ExtractIdxOpt + Ln : -1);
204     } else {
205       Vec.emplace_back(V);
206     }
207   }
208   return CollectDescr(std::move(Vec));
209 }
210 
canVectorize(ArrayRef<Value * > Bndl,bool SkipScheduling)211 const LegalityResult &LegalityAnalysis::canVectorize(ArrayRef<Value *> Bndl,
212                                                      bool SkipScheduling) {
213   // If Bndl contains values other than instructions, we need to Pack.
214   if (any_of(Bndl, [](auto *V) { return !isa<Instruction>(V); }))
215     return createLegalityResult<Pack>(ResultReason::NotInstructions);
216   // Pack if not in the same BB.
217   auto *BB = cast<Instruction>(Bndl[0])->getParent();
218   if (any_of(drop_begin(Bndl),
219              [BB](auto *V) { return cast<Instruction>(V)->getParent() != BB; }))
220     return createLegalityResult<Pack>(ResultReason::DiffBBs);
221   // Pack if instructions repeat, i.e., require some sort of broadcast.
222   SmallPtrSet<Value *, 8> Unique(llvm::from_range, Bndl);
223   if (Unique.size() != Bndl.size())
224     return createLegalityResult<Pack>(ResultReason::RepeatedInstrs);
225 
226   auto CollectDescrs = getHowToCollectValues(Bndl);
227   if (CollectDescrs.hasVectorInputs()) {
228     if (auto ValueShuffleOpt = CollectDescrs.getSingleInput()) {
229       auto [Vec, Mask] = *ValueShuffleOpt;
230       if (Mask.isIdentity())
231         return createLegalityResult<DiamondReuse>(Vec);
232       return createLegalityResult<DiamondReuseWithShuffle>(Vec, Mask);
233     }
234     return createLegalityResult<DiamondReuseMultiInput>(
235         std::move(CollectDescrs));
236   }
237 
238   if (auto ReasonOpt = notVectorizableBasedOnOpcodesAndTypes(Bndl))
239     return createLegalityResult<Pack>(*ReasonOpt);
240 
241   if (!SkipScheduling) {
242     // TODO: Try to remove the IBndl vector.
243     SmallVector<Instruction *, 8> IBndl;
244     IBndl.reserve(Bndl.size());
245     for (auto *V : Bndl)
246       IBndl.push_back(cast<Instruction>(V));
247     if (!Sched.trySchedule(IBndl))
248       return createLegalityResult<Pack>(ResultReason::CantSchedule);
249   }
250 
251   return createLegalityResult<Widen>();
252 }
253 
clear()254 void LegalityAnalysis::clear() {
255   Sched.clear();
256   IMaps.clear();
257 }
258 } // namespace llvm::sandboxir
259