xref: /freebsd/contrib/llvm-project/llvm/utils/TableGen/GlobalISelCombinerEmitter.cpp (revision 0fca6ea1d4eea4c934cfff25ac9ee8ad6fe95583)
1 //===- GlobalISelCombinerEmitter.cpp - ------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 /// \file Generate a combiner implementation for GlobalISel from a declarative
10 /// syntax using GlobalISelMatchTable.
11 ///
12 /// Usually, TableGen backends use "assert is an error" as a means to report
13 /// invalid input. They try to diagnose common cases but don't try very hard and
14 /// crashes can be common. This backend aims to behave closer to how a language
15 /// compiler frontend would behave: we try extra hard to diagnose invalid inputs
16 /// early, and any crash should be considered a bug (= a feature or diagnostic
17 /// is missing).
18 ///
19 /// While this can make the backend a bit more complex than it needs to be, it
20 /// pays off because MIR patterns can get complicated. Giving useful error
21 /// messages to combine writers can help boost their productivity.
22 ///
23 /// As with anything, a good balance has to be found. We also don't want to
24 /// write hundreds of lines of code to detect edge cases. In practice, crashing
25 /// very occasionally, or giving poor errors in some rare instances, is fine.
26 ///
27 //===----------------------------------------------------------------------===//
28 
29 #include "Basic/CodeGenIntrinsics.h"
30 #include "Common/CodeGenInstruction.h"
31 #include "Common/CodeGenTarget.h"
32 #include "Common/GlobalISel/CXXPredicates.h"
33 #include "Common/GlobalISel/CodeExpander.h"
34 #include "Common/GlobalISel/CodeExpansions.h"
35 #include "Common/GlobalISel/CombinerUtils.h"
36 #include "Common/GlobalISel/GlobalISelMatchTable.h"
37 #include "Common/GlobalISel/GlobalISelMatchTableExecutorEmitter.h"
38 #include "Common/GlobalISel/PatternParser.h"
39 #include "Common/GlobalISel/Patterns.h"
40 #include "Common/SubtargetFeatureInfo.h"
41 #include "llvm/ADT/APInt.h"
42 #include "llvm/ADT/EquivalenceClasses.h"
43 #include "llvm/ADT/Hashing.h"
44 #include "llvm/ADT/MapVector.h"
45 #include "llvm/ADT/SetVector.h"
46 #include "llvm/ADT/Statistic.h"
47 #include "llvm/ADT/StringExtras.h"
48 #include "llvm/ADT/StringSet.h"
49 #include "llvm/Support/CommandLine.h"
50 #include "llvm/Support/Debug.h"
51 #include "llvm/Support/PrettyStackTrace.h"
52 #include "llvm/Support/ScopedPrinter.h"
53 #include "llvm/TableGen/Error.h"
54 #include "llvm/TableGen/Record.h"
55 #include "llvm/TableGen/StringMatcher.h"
56 #include "llvm/TableGen/TableGenBackend.h"
57 #include <cstdint>
58 
59 using namespace llvm;
60 using namespace llvm::gi;
61 
62 #define DEBUG_TYPE "gicombiner-emitter"
63 
64 namespace {
65 cl::OptionCategory
66     GICombinerEmitterCat("Options for -gen-global-isel-combiner");
67 cl::opt<bool> StopAfterParse(
68     "gicombiner-stop-after-parse",
69     cl::desc("Stop processing after parsing rules and dump state"),
70     cl::cat(GICombinerEmitterCat));
71 cl::list<std::string>
72     SelectedCombiners("combiners", cl::desc("Emit the specified combiners"),
73                       cl::cat(GICombinerEmitterCat), cl::CommaSeparated);
74 cl::opt<bool> DebugCXXPreds(
75     "gicombiner-debug-cxxpreds",
76     cl::desc("Add Contextual/Debug comments to all C++ predicates"),
77     cl::cat(GICombinerEmitterCat));
78 cl::opt<bool> DebugTypeInfer("gicombiner-debug-typeinfer",
79                              cl::desc("Print type inference debug logs"),
80                              cl::cat(GICombinerEmitterCat));
81 
82 constexpr StringLiteral CXXCustomActionPrefix = "GICXXCustomAction_";
83 constexpr StringLiteral CXXPredPrefix = "GICXXPred_MI_Predicate_";
84 constexpr StringLiteral MatchDataClassName = "GIDefMatchData";
85 
86 //===- CodeExpansions Helpers  --------------------------------------------===//
87 
declareInstExpansion(CodeExpansions & CE,const InstructionMatcher & IM,StringRef Name)88 void declareInstExpansion(CodeExpansions &CE, const InstructionMatcher &IM,
89                           StringRef Name) {
90   CE.declare(Name, "State.MIs[" + to_string(IM.getInsnVarID()) + "]");
91 }
92 
declareInstExpansion(CodeExpansions & CE,const BuildMIAction & A,StringRef Name)93 void declareInstExpansion(CodeExpansions &CE, const BuildMIAction &A,
94                           StringRef Name) {
95   // Note: we use redeclare here because this may overwrite a matcher inst
96   // expansion.
97   CE.redeclare(Name, "OutMIs[" + to_string(A.getInsnID()) + "]");
98 }
99 
declareOperandExpansion(CodeExpansions & CE,const OperandMatcher & OM,StringRef Name)100 void declareOperandExpansion(CodeExpansions &CE, const OperandMatcher &OM,
101                              StringRef Name) {
102   CE.declare(Name, "State.MIs[" + to_string(OM.getInsnVarID()) +
103                        "]->getOperand(" + to_string(OM.getOpIdx()) + ")");
104 }
105 
declareTempRegExpansion(CodeExpansions & CE,unsigned TempRegID,StringRef Name)106 void declareTempRegExpansion(CodeExpansions &CE, unsigned TempRegID,
107                              StringRef Name) {
108   CE.declare(Name, "State.TempRegisters[" + to_string(TempRegID) + "]");
109 }
110 
111 //===- Misc. Helpers  -----------------------------------------------------===//
112 
/// \returns a lazy range over the `.first` member of every entry of \p C.
/// The range refers to \p C, so \p C must outlive it.
template <typename Container> auto keys(Container &&C) {
  auto GetFirst = [](auto &KV) -> auto & { return KV.first; };
  return map_range(C, GetFirst);
}
116 
/// \returns a lazy range over the `.second` member of every entry of \p C.
/// The range refers to \p C, so \p C must outlive it.
template <typename Container> auto values(Container &&C) {
  auto GetSecond = [](auto &KV) -> auto & { return KV.second; };
  return map_range(C, GetSecond);
}
120 
getIsEnabledPredicateEnumName(unsigned CombinerRuleID)121 std::string getIsEnabledPredicateEnumName(unsigned CombinerRuleID) {
122   return "GICXXPred_Simple_IsRule" + to_string(CombinerRuleID) + "Enabled";
123 }
124 
125 //===- MatchTable Helpers  ------------------------------------------------===//
126 
getLLTCodeGen(const PatternType & PT)127 LLTCodeGen getLLTCodeGen(const PatternType &PT) {
128   return *MVTToLLT(getValueType(PT.getLLTRecord()));
129 }
130 
getLLTCodeGenOrTempType(const PatternType & PT,RuleMatcher & RM)131 LLTCodeGenOrTempType getLLTCodeGenOrTempType(const PatternType &PT,
132                                              RuleMatcher &RM) {
133   assert(!PT.isNone());
134 
135   if (PT.isLLT())
136     return getLLTCodeGen(PT);
137 
138   assert(PT.isTypeOf());
139   auto &OM = RM.getOperandMatcher(PT.getTypeOfOpName());
140   return OM.getTempTypeIdx(RM);
141 }
142 
143 //===- PrettyStackTrace Helpers  ------------------------------------------===//
144 
145 class PrettyStackTraceParse : public PrettyStackTraceEntry {
146   const Record &Def;
147 
148 public:
PrettyStackTraceParse(const Record & Def)149   PrettyStackTraceParse(const Record &Def) : Def(Def) {}
150 
print(raw_ostream & OS) const151   void print(raw_ostream &OS) const override {
152     if (Def.isSubClassOf("GICombineRule"))
153       OS << "Parsing GICombineRule '" << Def.getName() << "'";
154     else if (Def.isSubClassOf(PatFrag::ClassName))
155       OS << "Parsing " << PatFrag::ClassName << " '" << Def.getName() << "'";
156     else
157       OS << "Parsing '" << Def.getName() << "'";
158     OS << '\n';
159   }
160 };
161 
162 class PrettyStackTraceEmit : public PrettyStackTraceEntry {
163   const Record &Def;
164   const Pattern *Pat = nullptr;
165 
166 public:
PrettyStackTraceEmit(const Record & Def,const Pattern * Pat=nullptr)167   PrettyStackTraceEmit(const Record &Def, const Pattern *Pat = nullptr)
168       : Def(Def), Pat(Pat) {}
169 
print(raw_ostream & OS) const170   void print(raw_ostream &OS) const override {
171     if (Def.isSubClassOf("GICombineRule"))
172       OS << "Emitting GICombineRule '" << Def.getName() << "'";
173     else if (Def.isSubClassOf(PatFrag::ClassName))
174       OS << "Emitting " << PatFrag::ClassName << " '" << Def.getName() << "'";
175     else
176       OS << "Emitting '" << Def.getName() << "'";
177 
178     if (Pat)
179       OS << " [" << Pat->getKindName() << " '" << Pat->getName() << "']";
180     OS << '\n';
181   }
182 };
183 
184 //===- CombineRuleOperandTypeChecker --------------------------------------===//
185 
186 /// This is a wrapper around OperandTypeChecker specialized for Combiner Rules.
187 /// On top of doing the same things as OperandTypeChecker, this also attempts to
188 /// infer as many types as possible for temporary register defs & immediates in
189 /// apply patterns.
190 ///
191 /// The inference is trivial and leverages the MCOI OperandTypes encoded in
192 /// CodeGenInstructions to infer types across patterns in a CombineRule. It's
193 /// thus very limited and only supports CodeGenInstructions (but that's the main
194 /// use case so it's fine).
195 ///
196 /// We only try to infer untyped operands in apply patterns when they're temp
197 /// reg defs, or immediates. Inference always outputs a `TypeOf<$x>` where $x is
198 /// a named operand from a match pattern.
199 class CombineRuleOperandTypeChecker : private OperandTypeChecker {
200 public:
CombineRuleOperandTypeChecker(const Record & RuleDef,const OperandTable & MatchOpTable)201   CombineRuleOperandTypeChecker(const Record &RuleDef,
202                                 const OperandTable &MatchOpTable)
203       : OperandTypeChecker(RuleDef.getLoc()), RuleDef(RuleDef),
204         MatchOpTable(MatchOpTable) {}
205 
206   /// Records and checks a 'match' pattern.
207   bool processMatchPattern(InstructionPattern &P);
208 
209   /// Records and checks an 'apply' pattern.
210   bool processApplyPattern(InstructionPattern &P);
211 
212   /// Propagates types, then perform type inference and do a second round of
213   /// propagation in the apply patterns only if any types were inferred.
214   void propagateAndInferTypes();
215 
216 private:
217   /// TypeEquivalenceClasses are groups of operands of an instruction that share
218   /// a common type.
219   ///
220   /// e.g. [[a, b], [c, d]] means a and b have the same type, and c and
221   /// d have the same type too. b/c and a/d don't have to have the same type,
222   /// though.
223   using TypeEquivalenceClasses = EquivalenceClasses<StringRef>;
224 
225   /// \returns true for `OPERAND_GENERIC_` 0 through 5.
226   /// These are the MCOI types that can be registers. The other MCOI types are
227   /// either immediates, or fancier operands used only post-ISel, so we don't
228   /// care about them for combiners.
canMCOIOperandTypeBeARegister(StringRef MCOIType)229   static bool canMCOIOperandTypeBeARegister(StringRef MCOIType) {
230     // Assume OPERAND_GENERIC_0 through 5 can be registers. The other MCOI
231     // OperandTypes are either never used in gMIR, or not relevant (e.g.
232     // OPERAND_GENERIC_IMM, which is definitely never a register).
233     return MCOIType.drop_back(1).ends_with("OPERAND_GENERIC_");
234   }
235 
236   /// Finds the "MCOI::"" operand types for each operand of \p CGP.
237   ///
238   /// This is a bit trickier than it looks because we need to handle variadic
239   /// in/outs.
240   ///
241   /// e.g. for
242   ///   (G_BUILD_VECTOR $vec, $x, $y) ->
243   ///   [MCOI::OPERAND_GENERIC_0, MCOI::OPERAND_GENERIC_1,
244   ///    MCOI::OPERAND_GENERIC_1]
245   ///
246   /// For unknown types (which can happen in variadics where varargs types are
247   /// inconsistent), a unique name is given, e.g. "unknown_type_0".
248   static std::vector<std::string>
249   getMCOIOperandTypes(const CodeGenInstructionPattern &CGP);
250 
251   /// Adds the TypeEquivalenceClasses for \p P in \p OutTECs.
252   void getInstEqClasses(const InstructionPattern &P,
253                         TypeEquivalenceClasses &OutTECs) const;
254 
255   /// Calls `getInstEqClasses` on all patterns of the rule to produce the whole
256   /// rule's TypeEquivalenceClasses.
257   TypeEquivalenceClasses getRuleEqClasses() const;
258 
259   /// Tries to infer the type of the \p ImmOpIdx -th operand of \p IP using \p
260   /// TECs.
261   ///
262   /// This is achieved by trying to find a named operand in \p IP that shares
263   /// the same type as \p ImmOpIdx, and using \ref inferNamedOperandType on that
264   /// operand instead.
265   ///
266   /// \returns the inferred type or an empty PatternType if inference didn't
267   /// succeed.
268   PatternType inferImmediateType(const InstructionPattern &IP,
269                                  unsigned ImmOpIdx,
270                                  const TypeEquivalenceClasses &TECs) const;
271 
272   /// Looks inside \p TECs to infer \p OpName's type.
273   ///
274   /// \returns the inferred type or an empty PatternType if inference didn't
275   /// succeed.
276   PatternType inferNamedOperandType(const InstructionPattern &IP,
277                                     StringRef OpName,
278                                     const TypeEquivalenceClasses &TECs,
279                                     bool AllowSelf = false) const;
280 
281   const Record &RuleDef;
282   SmallVector<InstructionPattern *, 8> MatchPats;
283   SmallVector<InstructionPattern *, 8> ApplyPats;
284 
285   const OperandTable &MatchOpTable;
286 };
287 
processMatchPattern(InstructionPattern & P)288 bool CombineRuleOperandTypeChecker::processMatchPattern(InstructionPattern &P) {
289   MatchPats.push_back(&P);
290   return check(P, /*CheckTypeOf*/ [](const auto &) {
291     // GITypeOf in 'match' is currently always rejected by the
292     // CombineRuleBuilder after inference is done.
293     return true;
294   });
295 }
296 
processApplyPattern(InstructionPattern & P)297 bool CombineRuleOperandTypeChecker::processApplyPattern(InstructionPattern &P) {
298   ApplyPats.push_back(&P);
299   return check(P, /*CheckTypeOf*/ [&](const PatternType &Ty) {
300     // GITypeOf<"$x"> can only be used if "$x" is a matched operand.
301     const auto OpName = Ty.getTypeOfOpName();
302     if (MatchOpTable.lookup(OpName).Found)
303       return true;
304 
305     PrintError(RuleDef.getLoc(), "'" + OpName + "' ('" + Ty.str() +
306                                      "') does not refer to a matched operand!");
307     return false;
308   });
309 }
310 
propagateAndInferTypes()311 void CombineRuleOperandTypeChecker::propagateAndInferTypes() {
312   /// First step here is to propagate types using the OperandTypeChecker. That
313   /// way we ensure all uses of a given register have consistent types.
314   propagateTypes();
315 
316   /// Build the TypeEquivalenceClasses for the whole rule.
317   const TypeEquivalenceClasses TECs = getRuleEqClasses();
318 
319   /// Look at the apply patterns and find operands that need to be
320   /// inferred. We then try to find an equivalence class that they're a part of
321   /// and select the best operand to use for the `GITypeOf` type. We prioritize
322   /// defs of matched instructions because those are guaranteed to be registers.
323   bool InferredAny = false;
324   for (auto *Pat : ApplyPats) {
325     for (unsigned K = 0; K < Pat->operands_size(); ++K) {
326       auto &Op = Pat->getOperand(K);
327 
328       // We only want to take a look at untyped defs or immediates.
329       if ((!Op.isDef() && !Op.hasImmValue()) || Op.getType())
330         continue;
331 
332       // Infer defs & named immediates.
333       if (Op.isDef() || Op.isNamedImmediate()) {
334         // Check it's not a redefinition of a matched operand.
335         // In such cases, inference is not necessary because we just copy
336         // operands and don't create temporary registers.
337         if (MatchOpTable.lookup(Op.getOperandName()).Found)
338           continue;
339 
340         // Inference is needed here, so try to do it.
341         if (PatternType Ty =
342                 inferNamedOperandType(*Pat, Op.getOperandName(), TECs)) {
343           if (DebugTypeInfer)
344             errs() << "INFER: " << Op.describe() << " -> " << Ty.str() << '\n';
345           Op.setType(Ty);
346           InferredAny = true;
347         }
348 
349         continue;
350       }
351 
352       // Infer immediates
353       if (Op.hasImmValue()) {
354         if (PatternType Ty = inferImmediateType(*Pat, K, TECs)) {
355           if (DebugTypeInfer)
356             errs() << "INFER: " << Op.describe() << " -> " << Ty.str() << '\n';
357           Op.setType(Ty);
358           InferredAny = true;
359         }
360         continue;
361       }
362     }
363   }
364 
365   // If we've inferred any types, we want to propagate them across the apply
366   // patterns. Type inference only adds GITypeOf types that point to Matched
367   // operands, so we definitely don't want to propagate types into the match
368   // patterns as well, otherwise bad things happen.
369   if (InferredAny) {
370     OperandTypeChecker OTC(RuleDef.getLoc());
371     for (auto *Pat : ApplyPats) {
372       if (!OTC.check(*Pat, [&](const auto &) { return true; }))
373         PrintFatalError(RuleDef.getLoc(),
374                         "OperandTypeChecker unexpectedly failed on '" +
375                             Pat->getName() + "' during Type Inference");
376     }
377     OTC.propagateTypes();
378 
379     if (DebugTypeInfer) {
380       errs() << "Apply patterns for rule " << RuleDef.getName()
381              << " after inference:\n";
382       for (auto *Pat : ApplyPats) {
383         errs() << "  ";
384         Pat->print(errs(), /*PrintName*/ true);
385         errs() << '\n';
386       }
387       errs() << '\n';
388     }
389   }
390 }
391 
inferImmediateType(const InstructionPattern & IP,unsigned ImmOpIdx,const TypeEquivalenceClasses & TECs) const392 PatternType CombineRuleOperandTypeChecker::inferImmediateType(
393     const InstructionPattern &IP, unsigned ImmOpIdx,
394     const TypeEquivalenceClasses &TECs) const {
395   // We can only infer CGPs (except intrinsics).
396   const auto *CGP = dyn_cast<CodeGenInstructionPattern>(&IP);
397   if (!CGP || CGP->isIntrinsic())
398     return {};
399 
400   // For CGPs, we try to infer immediates by trying to infer another named
401   // operand that shares its type.
402   //
403   // e.g.
404   //    Pattern: G_BUILD_VECTOR $x, $y, 0
405   //    MCOIs:   [MCOI::OPERAND_GENERIC_0, MCOI::OPERAND_GENERIC_1,
406   //              MCOI::OPERAND_GENERIC_1]
407   //    $y has the same type as 0, so we can infer $y and get the type 0 should
408   //    have.
409 
410   // We infer immediates by looking for a named operand that shares the same
411   // MCOI type.
412   const auto MCOITypes = getMCOIOperandTypes(*CGP);
413   StringRef ImmOpTy = MCOITypes[ImmOpIdx];
414 
415   for (const auto &[Idx, Ty] : enumerate(MCOITypes)) {
416     if (Idx != ImmOpIdx && Ty == ImmOpTy) {
417       const auto &Op = IP.getOperand(Idx);
418       if (!Op.isNamedOperand())
419         continue;
420 
421       // Named operand with the same name, try to infer that.
422       if (PatternType InferTy = inferNamedOperandType(IP, Op.getOperandName(),
423                                                       TECs, /*AllowSelf=*/true))
424         return InferTy;
425     }
426   }
427 
428   return {};
429 }
430 
inferNamedOperandType(const InstructionPattern & IP,StringRef OpName,const TypeEquivalenceClasses & TECs,bool AllowSelf) const431 PatternType CombineRuleOperandTypeChecker::inferNamedOperandType(
432     const InstructionPattern &IP, StringRef OpName,
433     const TypeEquivalenceClasses &TECs, bool AllowSelf) const {
434   // This is the simplest possible case, we just need to find a TEC that
435   // contains OpName. Look at all operands in equivalence class and try to
436   // find a suitable one. If `AllowSelf` is true, the operand itself is also
437   // considered suitable.
438 
439   // Check for a def of a matched pattern. This is guaranteed to always
440   // be a register so we can blindly use that.
441   StringRef GoodOpName;
442   for (auto It = TECs.findLeader(OpName); It != TECs.member_end(); ++It) {
443     if (!AllowSelf && *It == OpName)
444       continue;
445 
446     const auto LookupRes = MatchOpTable.lookup(*It);
447     if (LookupRes.Def) // Favor defs
448       return PatternType::getTypeOf(*It);
449 
450     // Otherwise just save this in case we don't find any def.
451     if (GoodOpName.empty() && LookupRes.Found)
452       GoodOpName = *It;
453   }
454 
455   if (!GoodOpName.empty())
456     return PatternType::getTypeOf(GoodOpName);
457 
458   // No good operand found, give up.
459   return {};
460 }
461 
getMCOIOperandTypes(const CodeGenInstructionPattern & CGP)462 std::vector<std::string> CombineRuleOperandTypeChecker::getMCOIOperandTypes(
463     const CodeGenInstructionPattern &CGP) {
464   // FIXME?: Should we cache this? We call it twice when inferring immediates.
465 
466   static unsigned UnknownTypeIdx = 0;
467 
468   std::vector<std::string> OpTypes;
469   auto &CGI = CGP.getInst();
470   Record *VarArgsTy = CGI.TheDef->isSubClassOf("GenericInstruction")
471                           ? CGI.TheDef->getValueAsOptionalDef("variadicOpsType")
472                           : nullptr;
473   std::string VarArgsTyName =
474       VarArgsTy ? ("MCOI::" + VarArgsTy->getValueAsString("OperandType")).str()
475                 : ("unknown_type_" + Twine(UnknownTypeIdx++)).str();
476 
477   // First, handle defs.
478   for (unsigned K = 0; K < CGI.Operands.NumDefs; ++K)
479     OpTypes.push_back(CGI.Operands[K].OperandType);
480 
481   // Then, handle variadic defs if there are any.
482   if (CGP.hasVariadicDefs()) {
483     for (unsigned K = CGI.Operands.NumDefs; K < CGP.getNumInstDefs(); ++K)
484       OpTypes.push_back(VarArgsTyName);
485   }
486 
487   // If we had variadic defs, the op idx in the pattern won't match the op idx
488   // in the CGI anymore.
489   int CGIOpOffset = int(CGI.Operands.NumDefs) - CGP.getNumInstDefs();
490   assert(CGP.hasVariadicDefs() ? (CGIOpOffset <= 0) : (CGIOpOffset == 0));
491 
492   // Handle all remaining use operands, including variadic ones.
493   for (unsigned K = CGP.getNumInstDefs(); K < CGP.getNumInstOperands(); ++K) {
494     unsigned CGIOpIdx = K + CGIOpOffset;
495     if (CGIOpIdx >= CGI.Operands.size()) {
496       assert(CGP.isVariadic());
497       OpTypes.push_back(VarArgsTyName);
498     } else {
499       OpTypes.push_back(CGI.Operands[CGIOpIdx].OperandType);
500     }
501   }
502 
503   assert(OpTypes.size() == CGP.operands_size());
504   return OpTypes;
505 }
506 
getInstEqClasses(const InstructionPattern & P,TypeEquivalenceClasses & OutTECs) const507 void CombineRuleOperandTypeChecker::getInstEqClasses(
508     const InstructionPattern &P, TypeEquivalenceClasses &OutTECs) const {
509   // Determine the TypeEquivalenceClasses by:
510   //    - Getting the MCOI Operand Types.
511   //    - Creating a Map of MCOI Type -> [Operand Indexes]
512   //    - Iterating over the map, filtering types we don't like, and just adding
513   //      the array of Operand Indexes to \p OutTECs.
514 
515   // We can only do this on CodeGenInstructions that aren't intrinsics. Other
516   // InstructionPatterns have no type inference information associated with
517   // them.
518   // TODO: We could try to extract some info from CodeGenIntrinsic to
519   //       guide inference.
520 
521   // TODO: Could we add some inference information to builtins at least? e.g.
522   // ReplaceReg should always replace with a reg of the same type, for instance.
523   // Though, those patterns are often used alone so it might not be worth the
524   // trouble to infer their types.
525   auto *CGP = dyn_cast<CodeGenInstructionPattern>(&P);
526   if (!CGP || CGP->isIntrinsic())
527     return;
528 
529   const auto MCOITypes = getMCOIOperandTypes(*CGP);
530   assert(MCOITypes.size() == P.operands_size());
531 
532   MapVector<StringRef, SmallVector<unsigned, 0>> TyToOpIdx;
533   for (const auto &[Idx, Ty] : enumerate(MCOITypes))
534     TyToOpIdx[Ty].push_back(Idx);
535 
536   if (DebugTypeInfer)
537     errs() << "\tGroups for " << P.getName() << ":\t";
538 
539   for (const auto &[Ty, Idxs] : TyToOpIdx) {
540     if (!canMCOIOperandTypeBeARegister(Ty))
541       continue;
542 
543     if (DebugTypeInfer)
544       errs() << '[';
545     StringRef Sep = "";
546 
547     // We only collect named operands.
548     StringRef Leader;
549     for (unsigned Idx : Idxs) {
550       const auto &Op = P.getOperand(Idx);
551       if (!Op.isNamedOperand())
552         continue;
553 
554       const auto OpName = Op.getOperandName();
555       if (DebugTypeInfer) {
556         errs() << Sep << OpName;
557         Sep = ", ";
558       }
559 
560       if (Leader.empty())
561         OutTECs.insert((Leader = OpName));
562       else
563         OutTECs.unionSets(Leader, OpName);
564     }
565 
566     if (DebugTypeInfer)
567       errs() << "] ";
568   }
569 
570   if (DebugTypeInfer)
571     errs() << '\n';
572 }
573 
574 CombineRuleOperandTypeChecker::TypeEquivalenceClasses
getRuleEqClasses() const575 CombineRuleOperandTypeChecker::getRuleEqClasses() const {
576   StringMap<unsigned> OpNameToEqClassIdx;
577   TypeEquivalenceClasses TECs;
578 
579   if (DebugTypeInfer)
580     errs() << "Rule Operand Type Equivalence Classes for " << RuleDef.getName()
581            << ":\n";
582 
583   for (const auto *Pat : MatchPats)
584     getInstEqClasses(*Pat, TECs);
585   for (const auto *Pat : ApplyPats)
586     getInstEqClasses(*Pat, TECs);
587 
588   if (DebugTypeInfer) {
589     errs() << "Final Type Equivalence Classes: ";
590     for (auto ClassIt = TECs.begin(); ClassIt != TECs.end(); ++ClassIt) {
591       // only print non-empty classes.
592       if (auto MembIt = TECs.member_begin(ClassIt);
593           MembIt != TECs.member_end()) {
594         errs() << '[';
595         StringRef Sep = "";
596         for (; MembIt != TECs.member_end(); ++MembIt) {
597           errs() << Sep << *MembIt;
598           Sep = ", ";
599         }
600         errs() << "] ";
601       }
602     }
603     errs() << '\n';
604   }
605 
606   return TECs;
607 }
608 
609 //===- MatchData Handling -------------------------------------------------===//
610 struct MatchDataDef {
MatchDataDef__anon569e2aa20111::MatchDataDef611   MatchDataDef(StringRef Symbol, StringRef Type) : Symbol(Symbol), Type(Type) {}
612 
613   StringRef Symbol;
614   StringRef Type;
615 
616   /// \returns the desired variable name for this MatchData.
getVarName__anon569e2aa20111::MatchDataDef617   std::string getVarName() const {
618     // Add a prefix in case the symbol name is very generic and conflicts with
619     // something else.
620     return "GIMatchData_" + Symbol.str();
621   }
622 };
623 
624 //===- CombineRuleBuilder -------------------------------------------------===//
625 
626 /// Parses combine rule and builds a small intermediate representation to tie
627 /// patterns together and emit RuleMatchers to match them. This may emit more
628 /// than one RuleMatcher, e.g. for `wip_match_opcode`.
629 ///
630 /// Memory management for `Pattern` objects is done through `std::unique_ptr`.
631 /// In most cases, there are two stages to a pattern's lifetime:
632 ///   - Creation in a `parse` function
633 ///     - The unique_ptr is stored in a variable, and may be destroyed if the
634 ///       pattern is found to be semantically invalid.
635 ///   - Ownership transfer into a `PatternMap`
636 ///     - Once a pattern is moved into either the map of Match or Apply
637 ///       patterns, it is known to be valid and it never moves back.
638 class CombineRuleBuilder {
639 public:
640   using PatternMap = MapVector<StringRef, std::unique_ptr<Pattern>>;
641   using PatternAlternatives = DenseMap<const Pattern *, unsigned>;
642 
CombineRuleBuilder(const CodeGenTarget & CGT,SubtargetFeatureInfoMap & SubtargetFeatures,Record & RuleDef,unsigned ID,std::vector<RuleMatcher> & OutRMs)643   CombineRuleBuilder(const CodeGenTarget &CGT,
644                      SubtargetFeatureInfoMap &SubtargetFeatures,
645                      Record &RuleDef, unsigned ID,
646                      std::vector<RuleMatcher> &OutRMs)
647       : Parser(CGT, RuleDef.getLoc()), CGT(CGT),
648         SubtargetFeatures(SubtargetFeatures), RuleDef(RuleDef), RuleID(ID),
649         OutRMs(OutRMs) {}
650 
651   /// Parses all fields in the RuleDef record.
652   bool parseAll();
653 
654   /// Emits all RuleMatchers into the vector of RuleMatchers passed in the
655   /// constructor.
656   bool emitRuleMatchers();
657 
658   void print(raw_ostream &OS) const;
dump() const659   void dump() const { print(dbgs()); }
660 
661   /// Debug-only verification of invariants.
662 #ifndef NDEBUG
663   void verify() const;
664 #endif
665 
666 private:
getGConstant() const667   const CodeGenInstruction &getGConstant() const {
668     return CGT.getInstruction(RuleDef.getRecords().getDef("G_CONSTANT"));
669   }
670 
PrintError(Twine Msg) const671   void PrintError(Twine Msg) const { ::PrintError(&RuleDef, Msg); }
PrintWarning(Twine Msg) const672   void PrintWarning(Twine Msg) const { ::PrintWarning(RuleDef.getLoc(), Msg); }
PrintNote(Twine Msg) const673   void PrintNote(Twine Msg) const { ::PrintNote(RuleDef.getLoc(), Msg); }
674 
675   void print(raw_ostream &OS, const PatternAlternatives &Alts) const;
676 
677   bool addApplyPattern(std::unique_ptr<Pattern> Pat);
678   bool addMatchPattern(std::unique_ptr<Pattern> Pat);
679 
680   /// Adds the expansions from \see MatchDatas to \p CE.
681   void declareAllMatchDatasExpansions(CodeExpansions &CE) const;
682 
683   /// Adds a matcher \p P to \p IM, expanding its code using \p CE.
684   /// Note that the predicate is added on the last InstructionMatcher.
685   ///
686   /// \p Alts is only used if DebugCXXPreds is enabled.
687   void addCXXPredicate(RuleMatcher &M, const CodeExpansions &CE,
688                        const CXXPattern &P, const PatternAlternatives &Alts);
689 
690   bool hasOnlyCXXApplyPatterns() const;
691   bool hasEraseRoot() const;
692 
693   // Infer machine operand types and check their consistency.
694   bool typecheckPatterns();
695 
696   /// For all PatFragPatterns, add a new entry in PatternAlternatives for each
697   /// PatternList it contains. This is multiplicative, so if we have 2
698   /// PatFrags with 3 alternatives each, we get 2*3 permutations added to
699   /// PermutationsToEmit. The "MaxPermutations" field controls how many
700   /// permutations are allowed before an error is emitted and this function
701   /// returns false. This is a simple safeguard to prevent combination of
702   /// PatFrags from generating enormous amounts of rules.
703   bool buildPermutationsToEmit();
704 
705   /// Checks additional semantics of the Patterns.
706   bool checkSemantics();
707 
708   /// Creates a new RuleMatcher with some boilerplate
709   /// settings/actions/predicates, and and adds it to \p OutRMs.
710   /// \see addFeaturePredicates too.
711   ///
712   /// \param Alts Current set of alternatives, for debug comment.
713   /// \param AdditionalComment Comment string to be added to the
714   ///        `DebugCommentAction`.
715   RuleMatcher &addRuleMatcher(const PatternAlternatives &Alts,
716                               Twine AdditionalComment = "");
717   bool addFeaturePredicates(RuleMatcher &M);
718 
719   bool findRoots();
720   bool buildRuleOperandsTable();
721 
722   bool parseDefs(const DagInit &Def);
723 
724   bool emitMatchPattern(CodeExpansions &CE, const PatternAlternatives &Alts,
725                         const InstructionPattern &IP);
726   bool emitMatchPattern(CodeExpansions &CE, const PatternAlternatives &Alts,
727                         const AnyOpcodePattern &AOP);
728 
729   bool emitPatFragMatchPattern(CodeExpansions &CE,
730                                const PatternAlternatives &Alts, RuleMatcher &RM,
731                                InstructionMatcher *IM,
732                                const PatFragPattern &PFP,
733                                DenseSet<const Pattern *> &SeenPats);
734 
735   bool emitApplyPatterns(CodeExpansions &CE, RuleMatcher &M);
736   bool emitCXXMatchApply(CodeExpansions &CE, RuleMatcher &M,
737                          ArrayRef<CXXPattern *> Matchers);
738 
739   // Recursively visits InstructionPatterns from P to build up the
740   // RuleMatcher actions.
741   bool emitInstructionApplyPattern(CodeExpansions &CE, RuleMatcher &M,
742                                    const InstructionPattern &P,
743                                    DenseSet<const Pattern *> &SeenPats,
744                                    StringMap<unsigned> &OperandToTempRegID);
745 
746   bool emitCodeGenInstructionApplyImmOperand(RuleMatcher &M,
747                                              BuildMIAction &DstMI,
748                                              const CodeGenInstructionPattern &P,
749                                              const InstructionOperand &O);
750 
751   bool emitBuiltinApplyPattern(CodeExpansions &CE, RuleMatcher &M,
752                                const BuiltinPattern &P,
753                                StringMap<unsigned> &OperandToTempRegID);
754 
755   // Recursively visits CodeGenInstructionPattern from P to build up the
756   // RuleMatcher/InstructionMatcher. May create new InstructionMatchers as
757   // needed.
758   using OperandMapperFnRef =
759       function_ref<InstructionOperand(const InstructionOperand &)>;
760   using OperandDefLookupFn =
761       function_ref<const InstructionPattern *(StringRef)>;
762   bool emitCodeGenInstructionMatchPattern(
763       CodeExpansions &CE, const PatternAlternatives &Alts, RuleMatcher &M,
764       InstructionMatcher &IM, const CodeGenInstructionPattern &P,
765       DenseSet<const Pattern *> &SeenPats, OperandDefLookupFn LookupOperandDef,
__anon569e2aa20702(const auto &O) 766       OperandMapperFnRef OperandMapper = [](const auto &O) { return O; });
767 
768   PatternParser Parser;
769   const CodeGenTarget &CGT;
770   SubtargetFeatureInfoMap &SubtargetFeatures;
771   Record &RuleDef;
772   const unsigned RuleID;
773   std::vector<RuleMatcher> &OutRMs;
774 
775   // For InstructionMatcher::addOperand
776   unsigned AllocatedTemporariesBaseID = 0;
777 
778   /// The root of the pattern.
779   StringRef RootName;
780 
781   /// These maps have ownership of the actual Pattern objects.
782   /// They both map a Pattern's name to the Pattern instance.
783   PatternMap MatchPats;
784   PatternMap ApplyPats;
785 
786   /// Operand tables to tie match/apply patterns together.
787   OperandTable MatchOpTable;
788   OperandTable ApplyOpTable;
789 
790   /// Set by findRoots.
791   Pattern *MatchRoot = nullptr;
792   SmallDenseSet<InstructionPattern *, 2> ApplyRoots;
793 
794   SmallVector<MatchDataDef, 2> MatchDatas;
795   SmallVector<PatternAlternatives, 1> PermutationsToEmit;
796 };
797 
parseAll()798 bool CombineRuleBuilder::parseAll() {
799   auto StackTrace = PrettyStackTraceParse(RuleDef);
800 
801   if (!parseDefs(*RuleDef.getValueAsDag("Defs")))
802     return false;
803 
804   if (!Parser.parsePatternList(
805           *RuleDef.getValueAsDag("Match"),
806           [this](auto Pat) { return addMatchPattern(std::move(Pat)); }, "match",
807           (RuleDef.getName() + "_match").str()))
808     return false;
809 
810   if (!Parser.parsePatternList(
811           *RuleDef.getValueAsDag("Apply"),
812           [this](auto Pat) { return addApplyPattern(std::move(Pat)); }, "apply",
813           (RuleDef.getName() + "_apply").str()))
814     return false;
815 
816   if (!buildRuleOperandsTable() || !typecheckPatterns() || !findRoots() ||
817       !checkSemantics() || !buildPermutationsToEmit())
818     return false;
819   LLVM_DEBUG(verify());
820   return true;
821 }
822 
emitRuleMatchers()823 bool CombineRuleBuilder::emitRuleMatchers() {
824   auto StackTrace = PrettyStackTraceEmit(RuleDef);
825 
826   assert(MatchRoot);
827   CodeExpansions CE;
828 
829   assert(!PermutationsToEmit.empty());
830   for (const auto &Alts : PermutationsToEmit) {
831     switch (MatchRoot->getKind()) {
832     case Pattern::K_AnyOpcode: {
833       if (!emitMatchPattern(CE, Alts, *cast<AnyOpcodePattern>(MatchRoot)))
834         return false;
835       break;
836     }
837     case Pattern::K_PatFrag:
838     case Pattern::K_Builtin:
839     case Pattern::K_CodeGenInstruction:
840       if (!emitMatchPattern(CE, Alts, *cast<InstructionPattern>(MatchRoot)))
841         return false;
842       break;
843     case Pattern::K_CXX:
844       PrintError("C++ code cannot be the root of a rule!");
845       return false;
846     default:
847       llvm_unreachable("unknown pattern kind!");
848     }
849   }
850 
851   return true;
852 }
853 
print(raw_ostream & OS) const854 void CombineRuleBuilder::print(raw_ostream &OS) const {
855   OS << "(CombineRule name:" << RuleDef.getName() << " id:" << RuleID
856      << " root:" << RootName << '\n';
857 
858   if (!MatchDatas.empty()) {
859     OS << "  (MatchDatas\n";
860     for (const auto &MD : MatchDatas) {
861       OS << "    (MatchDataDef symbol:" << MD.Symbol << " type:" << MD.Type
862          << ")\n";
863     }
864     OS << "  )\n";
865   }
866 
867   const auto &SeenPFs = Parser.getSeenPatFrags();
868   if (!SeenPFs.empty()) {
869     OS << "  (PatFrags\n";
870     for (const auto *PF : Parser.getSeenPatFrags()) {
871       PF->print(OS, /*Indent=*/"    ");
872       OS << '\n';
873     }
874     OS << "  )\n";
875   }
876 
877   const auto DumpPats = [&](StringRef Name, const PatternMap &Pats) {
878     OS << "  (" << Name << " ";
879     if (Pats.empty()) {
880       OS << "<empty>)\n";
881       return;
882     }
883 
884     OS << '\n';
885     for (const auto &[Name, Pat] : Pats) {
886       OS << "    ";
887       if (Pat.get() == MatchRoot)
888         OS << "<match_root>";
889       if (isa<InstructionPattern>(Pat.get()) &&
890           ApplyRoots.contains(cast<InstructionPattern>(Pat.get())))
891         OS << "<apply_root>";
892       OS << Name << ":";
893       Pat->print(OS, /*PrintName=*/false);
894       OS << '\n';
895     }
896     OS << "  )\n";
897   };
898 
899   DumpPats("MatchPats", MatchPats);
900   DumpPats("ApplyPats", ApplyPats);
901 
902   MatchOpTable.print(OS, "MatchPats", /*Indent*/ "  ");
903   ApplyOpTable.print(OS, "ApplyPats", /*Indent*/ "  ");
904 
905   if (PermutationsToEmit.size() > 1) {
906     OS << "  (PermutationsToEmit\n";
907     for (const auto &Perm : PermutationsToEmit) {
908       OS << "    ";
909       print(OS, Perm);
910       OS << ",\n";
911     }
912     OS << "  )\n";
913   }
914 
915   OS << ")\n";
916 }
917 
918 #ifndef NDEBUG
verify() const919 void CombineRuleBuilder::verify() const {
920   const auto VerifyPats = [&](const PatternMap &Pats) {
921     for (const auto &[Name, Pat] : Pats) {
922       if (!Pat)
923         PrintFatalError("null pattern in pattern map!");
924 
925       if (Name != Pat->getName()) {
926         Pat->dump();
927         PrintFatalError("Pattern name mismatch! Map name: " + Name +
928                         ", Pat name: " + Pat->getName());
929       }
930 
931       // Sanity check: the map should point to the same data as the Pattern.
932       // Both strings are allocated in the pool using insertStrRef.
933       if (Name.data() != Pat->getName().data()) {
934         dbgs() << "Map StringRef: '" << Name << "' @ "
935                << (const void *)Name.data() << '\n';
936         dbgs() << "Pat String: '" << Pat->getName() << "' @ "
937                << (const void *)Pat->getName().data() << '\n';
938         PrintFatalError("StringRef stored in the PatternMap is not referencing "
939                         "the same string as its Pattern!");
940       }
941     }
942   };
943 
944   VerifyPats(MatchPats);
945   VerifyPats(ApplyPats);
946 
947   // Check there are no wip_match_opcode patterns in the "apply" patterns.
948   if (any_of(ApplyPats,
949              [&](auto &E) { return isa<AnyOpcodePattern>(E.second.get()); })) {
950     dump();
951     PrintFatalError(
952         "illegal wip_match_opcode pattern in the 'apply' patterns!");
953   }
954 
955   // Check there are no nullptrs in ApplyRoots.
956   if (ApplyRoots.contains(nullptr)) {
957     PrintFatalError(
958         "CombineRuleBuilder's ApplyRoots set contains a null pointer!");
959   }
960 }
961 #endif
962 
print(raw_ostream & OS,const PatternAlternatives & Alts) const963 void CombineRuleBuilder::print(raw_ostream &OS,
964                                const PatternAlternatives &Alts) const {
965   SmallVector<std::string, 1> Strings(
966       map_range(Alts, [](const auto &PatAndPerm) {
967         return PatAndPerm.first->getName().str() + "[" +
968                to_string(PatAndPerm.second) + "]";
969       }));
970   // Sort so output is deterministic for tests. Otherwise it's sorted by pointer
971   // values.
972   sort(Strings);
973   OS << "[" << join(Strings, ", ") << "]";
974 }
975 
addApplyPattern(std::unique_ptr<Pattern> Pat)976 bool CombineRuleBuilder::addApplyPattern(std::unique_ptr<Pattern> Pat) {
977   StringRef Name = Pat->getName();
978   if (ApplyPats.contains(Name)) {
979     PrintError("'" + Name + "' apply pattern defined more than once!");
980     return false;
981   }
982 
983   if (isa<AnyOpcodePattern>(Pat.get())) {
984     PrintError("'" + Name +
985                "': wip_match_opcode is not supported in apply patterns");
986     return false;
987   }
988 
989   if (isa<PatFragPattern>(Pat.get())) {
990     PrintError("'" + Name + "': using " + PatFrag::ClassName +
991                " is not supported in apply patterns");
992     return false;
993   }
994 
995   if (auto *CXXPat = dyn_cast<CXXPattern>(Pat.get()))
996     CXXPat->setIsApply();
997 
998   ApplyPats[Name] = std::move(Pat);
999   return true;
1000 }
1001 
addMatchPattern(std::unique_ptr<Pattern> Pat)1002 bool CombineRuleBuilder::addMatchPattern(std::unique_ptr<Pattern> Pat) {
1003   StringRef Name = Pat->getName();
1004   if (MatchPats.contains(Name)) {
1005     PrintError("'" + Name + "' match pattern defined more than once!");
1006     return false;
1007   }
1008 
1009   // For now, none of the builtins can appear in 'match'.
1010   if (const auto *BP = dyn_cast<BuiltinPattern>(Pat.get())) {
1011     PrintError("'" + BP->getInstName() +
1012                "' cannot be used in a 'match' pattern");
1013     return false;
1014   }
1015 
1016   MatchPats[Name] = std::move(Pat);
1017   return true;
1018 }
1019 
declareAllMatchDatasExpansions(CodeExpansions & CE) const1020 void CombineRuleBuilder::declareAllMatchDatasExpansions(
1021     CodeExpansions &CE) const {
1022   for (const auto &MD : MatchDatas)
1023     CE.declare(MD.Symbol, MD.getVarName());
1024 }
1025 
addCXXPredicate(RuleMatcher & M,const CodeExpansions & CE,const CXXPattern & P,const PatternAlternatives & Alts)1026 void CombineRuleBuilder::addCXXPredicate(RuleMatcher &M,
1027                                          const CodeExpansions &CE,
1028                                          const CXXPattern &P,
1029                                          const PatternAlternatives &Alts) {
1030   // FIXME: Hack so C++ code is executed last. May not work for more complex
1031   // patterns.
1032   auto &IM = *std::prev(M.insnmatchers().end());
1033   auto Loc = RuleDef.getLoc();
1034   const auto AddComment = [&](raw_ostream &OS) {
1035     OS << "// Pattern Alternatives: ";
1036     print(OS, Alts);
1037     OS << '\n';
1038   };
1039   const auto &ExpandedCode =
1040       DebugCXXPreds ? P.expandCode(CE, Loc, AddComment) : P.expandCode(CE, Loc);
1041   IM->addPredicate<GenericInstructionPredicateMatcher>(
1042       ExpandedCode.getEnumNameWithPrefix(CXXPredPrefix));
1043 }
1044 
hasOnlyCXXApplyPatterns() const1045 bool CombineRuleBuilder::hasOnlyCXXApplyPatterns() const {
1046   return all_of(ApplyPats, [&](auto &Entry) {
1047     return isa<CXXPattern>(Entry.second.get());
1048   });
1049 }
1050 
hasEraseRoot() const1051 bool CombineRuleBuilder::hasEraseRoot() const {
1052   return any_of(ApplyPats, [&](auto &Entry) {
1053     if (const auto *BP = dyn_cast<BuiltinPattern>(Entry.second.get()))
1054       return BP->getBuiltinKind() == BI_EraseRoot;
1055     return false;
1056   });
1057 }
1058 
typecheckPatterns()1059 bool CombineRuleBuilder::typecheckPatterns() {
1060   CombineRuleOperandTypeChecker OTC(RuleDef, MatchOpTable);
1061 
1062   for (auto &Pat : values(MatchPats)) {
1063     if (auto *IP = dyn_cast<InstructionPattern>(Pat.get())) {
1064       if (!OTC.processMatchPattern(*IP))
1065         return false;
1066     }
1067   }
1068 
1069   for (auto &Pat : values(ApplyPats)) {
1070     if (auto *IP = dyn_cast<InstructionPattern>(Pat.get())) {
1071       if (!OTC.processApplyPattern(*IP))
1072         return false;
1073     }
1074   }
1075 
1076   OTC.propagateAndInferTypes();
1077 
1078   // Always check this after in case inference adds some special types to the
1079   // match patterns.
1080   for (auto &Pat : values(MatchPats)) {
1081     if (auto *IP = dyn_cast<InstructionPattern>(Pat.get())) {
1082       if (IP->diagnoseAllSpecialTypes(
1083               RuleDef.getLoc(), PatternType::SpecialTyClassName +
1084                                     " is not supported in 'match' patterns")) {
1085         return false;
1086       }
1087     }
1088   }
1089   return true;
1090 }
1091 
buildPermutationsToEmit()1092 bool CombineRuleBuilder::buildPermutationsToEmit() {
1093   PermutationsToEmit.clear();
1094 
1095   // Start with one empty set of alternatives.
1096   PermutationsToEmit.emplace_back();
1097   for (const auto &Pat : values(MatchPats)) {
1098     unsigned NumAlts = 0;
1099     // Note: technically, AnyOpcodePattern also needs permutations, but:
1100     //    - We only allow a single one of them in the root.
1101     //    - They cannot be mixed with any other pattern other than C++ code.
1102     // So we don't really need to take them into account here. We could, but
1103     // that pattern is a hack anyway and the less it's involved, the better.
1104     if (const auto *PFP = dyn_cast<PatFragPattern>(Pat.get()))
1105       NumAlts = PFP->getPatFrag().num_alternatives();
1106     else
1107       continue;
1108 
1109     // For each pattern that needs permutations, multiply the current set of
1110     // alternatives.
1111     auto CurPerms = PermutationsToEmit;
1112     PermutationsToEmit.clear();
1113 
1114     for (const auto &Perm : CurPerms) {
1115       assert(!Perm.count(Pat.get()) && "Pattern already emitted?");
1116       for (unsigned K = 0; K < NumAlts; ++K) {
1117         PatternAlternatives NewPerm = Perm;
1118         NewPerm[Pat.get()] = K;
1119         PermutationsToEmit.emplace_back(std::move(NewPerm));
1120       }
1121     }
1122   }
1123 
1124   if (int64_t MaxPerms = RuleDef.getValueAsInt("MaxPermutations");
1125       MaxPerms > 0) {
1126     if ((int64_t)PermutationsToEmit.size() > MaxPerms) {
1127       PrintError("cannot emit rule '" + RuleDef.getName() + "'; " +
1128                  Twine(PermutationsToEmit.size()) +
1129                  " permutations would be emitted, but the max is " +
1130                  Twine(MaxPerms));
1131       return false;
1132     }
1133   }
1134 
1135   // Ensure we always have a single empty entry, it simplifies the emission
1136   // logic so it doesn't need to handle the case where there are no perms.
1137   if (PermutationsToEmit.empty()) {
1138     PermutationsToEmit.emplace_back();
1139     return true;
1140   }
1141 
1142   return true;
1143 }
1144 
checkSemantics()1145 bool CombineRuleBuilder::checkSemantics() {
1146   assert(MatchRoot && "Cannot call this before findRoots()");
1147 
1148   bool UsesWipMatchOpcode = false;
1149   for (const auto &Match : MatchPats) {
1150     const auto *Pat = Match.second.get();
1151 
1152     if (const auto *CXXPat = dyn_cast<CXXPattern>(Pat)) {
1153       if (!CXXPat->getRawCode().contains("return "))
1154         PrintWarning("'match' C++ code does not seem to return!");
1155       continue;
1156     }
1157 
1158     // MIFlags in match cannot use the following syntax: (MIFlags $mi)
1159     if (const auto *CGP = dyn_cast<CodeGenInstructionPattern>(Pat)) {
1160       if (auto *FI = CGP->getMIFlagsInfo()) {
1161         if (!FI->copy_flags().empty()) {
1162           PrintError(
1163               "'match' patterns cannot refer to flags from other instructions");
1164           PrintNote("MIFlags in '" + CGP->getName() +
1165                     "' refer to: " + join(FI->copy_flags(), ", "));
1166           return false;
1167         }
1168       }
1169     }
1170 
1171     const auto *AOP = dyn_cast<AnyOpcodePattern>(Pat);
1172     if (!AOP)
1173       continue;
1174 
1175     if (UsesWipMatchOpcode) {
1176       PrintError("wip_opcode_match can only be present once");
1177       return false;
1178     }
1179 
1180     UsesWipMatchOpcode = true;
1181   }
1182 
1183   std::optional<bool> IsUsingCXXPatterns;
1184   for (const auto &Apply : ApplyPats) {
1185     Pattern *Pat = Apply.second.get();
1186     if (IsUsingCXXPatterns) {
1187       if (*IsUsingCXXPatterns != isa<CXXPattern>(Pat)) {
1188         PrintError("'apply' patterns cannot mix C++ code with other types of "
1189                    "patterns");
1190         return false;
1191       }
1192     } else
1193       IsUsingCXXPatterns = isa<CXXPattern>(Pat);
1194 
1195     assert(Pat);
1196     const auto *IP = dyn_cast<InstructionPattern>(Pat);
1197     if (!IP)
1198       continue;
1199 
1200     if (UsesWipMatchOpcode) {
1201       PrintError("cannot use wip_match_opcode in combination with apply "
1202                  "instruction patterns!");
1203       return false;
1204     }
1205 
1206     // Check that the insts mentioned in copy_flags exist.
1207     if (const auto *CGP = dyn_cast<CodeGenInstructionPattern>(IP)) {
1208       if (auto *FI = CGP->getMIFlagsInfo()) {
1209         for (auto InstName : FI->copy_flags()) {
1210           auto It = MatchPats.find(InstName);
1211           if (It == MatchPats.end()) {
1212             PrintError("unknown instruction '$" + InstName +
1213                        "' referenced in MIFlags of '" + CGP->getName() + "'");
1214             return false;
1215           }
1216 
1217           if (!isa<CodeGenInstructionPattern>(It->second.get())) {
1218             PrintError(
1219                 "'$" + InstName +
1220                 "' does not refer to a CodeGenInstruction in MIFlags of '" +
1221                 CGP->getName() + "'");
1222             return false;
1223           }
1224         }
1225       }
1226     }
1227 
1228     const auto *BIP = dyn_cast<BuiltinPattern>(IP);
1229     if (!BIP)
1230       continue;
1231     StringRef Name = BIP->getInstName();
1232 
1233     // (GIEraseInst) has to be the only apply pattern, or it can not be used at
1234     // all. The root cannot have any defs either.
1235     switch (BIP->getBuiltinKind()) {
1236     case BI_EraseRoot: {
1237       if (ApplyPats.size() > 1) {
1238         PrintError(Name + " must be the only 'apply' pattern");
1239         return false;
1240       }
1241 
1242       const auto *IRoot = dyn_cast<CodeGenInstructionPattern>(MatchRoot);
1243       if (!IRoot) {
1244         PrintError(Name + " can only be used if the root is a "
1245                           "CodeGenInstruction or Intrinsic");
1246         return false;
1247       }
1248 
1249       if (IRoot->getNumInstDefs() != 0) {
1250         PrintError(Name + " can only be used if on roots that do "
1251                           "not have any output operand");
1252         PrintNote("'" + IRoot->getInstName() + "' has " +
1253                   Twine(IRoot->getNumInstDefs()) + " output operands");
1254         return false;
1255       }
1256       break;
1257     }
1258     case BI_ReplaceReg: {
1259       // (GIReplaceReg can only be used on the root instruction)
1260       // TODO: When we allow rewriting non-root instructions, also allow this.
1261       StringRef OldRegName = BIP->getOperand(0).getOperandName();
1262       auto *Def = MatchOpTable.getDef(OldRegName);
1263       if (!Def) {
1264         PrintError(Name + " cannot find a matched pattern that defines '" +
1265                    OldRegName + "'");
1266         return false;
1267       }
1268       if (MatchOpTable.getDef(OldRegName) != MatchRoot) {
1269         PrintError(Name + " cannot replace '" + OldRegName +
1270                    "': this builtin can only replace a register defined by the "
1271                    "match root");
1272         return false;
1273       }
1274       break;
1275     }
1276     }
1277   }
1278 
1279   if (!hasOnlyCXXApplyPatterns() && !MatchDatas.empty()) {
1280     PrintError(MatchDataClassName +
1281                " can only be used if 'apply' in entirely written in C++");
1282     return false;
1283   }
1284 
1285   return true;
1286 }
1287 
addRuleMatcher(const PatternAlternatives & Alts,Twine AdditionalComment)1288 RuleMatcher &CombineRuleBuilder::addRuleMatcher(const PatternAlternatives &Alts,
1289                                                 Twine AdditionalComment) {
1290   auto &RM = OutRMs.emplace_back(RuleDef.getLoc());
1291   addFeaturePredicates(RM);
1292   RM.setPermanentGISelFlags(GISF_IgnoreCopies);
1293   RM.addRequiredSimplePredicate(getIsEnabledPredicateEnumName(RuleID));
1294 
1295   std::string Comment;
1296   raw_string_ostream CommentOS(Comment);
1297   CommentOS << "Combiner Rule #" << RuleID << ": " << RuleDef.getName();
1298   if (!Alts.empty()) {
1299     CommentOS << " @ ";
1300     print(CommentOS, Alts);
1301   }
1302   if (!AdditionalComment.isTriviallyEmpty())
1303     CommentOS << "; " << AdditionalComment;
1304   RM.addAction<DebugCommentAction>(Comment);
1305   return RM;
1306 }
1307 
addFeaturePredicates(RuleMatcher & M)1308 bool CombineRuleBuilder::addFeaturePredicates(RuleMatcher &M) {
1309   if (!RuleDef.getValue("Predicates"))
1310     return true;
1311 
1312   ListInit *Preds = RuleDef.getValueAsListInit("Predicates");
1313   for (Init *PI : Preds->getValues()) {
1314     DefInit *Pred = dyn_cast<DefInit>(PI);
1315     if (!Pred)
1316       continue;
1317 
1318     Record *Def = Pred->getDef();
1319     if (!Def->isSubClassOf("Predicate")) {
1320       ::PrintError(Def, "Unknown 'Predicate' Type");
1321       return false;
1322     }
1323 
1324     if (Def->getValueAsString("CondString").empty())
1325       continue;
1326 
1327     if (SubtargetFeatures.count(Def) == 0) {
1328       SubtargetFeatures.emplace(
1329           Def, SubtargetFeatureInfo(Def, SubtargetFeatures.size()));
1330     }
1331 
1332     M.addRequiredFeature(Def);
1333   }
1334 
1335   return true;
1336 }
1337 
findRoots()1338 bool CombineRuleBuilder::findRoots() {
1339   const auto Finish = [&]() {
1340     assert(MatchRoot);
1341 
1342     if (hasOnlyCXXApplyPatterns() || hasEraseRoot())
1343       return true;
1344 
1345     auto *IPRoot = dyn_cast<InstructionPattern>(MatchRoot);
1346     if (!IPRoot)
1347       return true;
1348 
1349     if (IPRoot->getNumInstDefs() == 0) {
1350       // No defs to work with -> find the root using the pattern name.
1351       auto It = ApplyPats.find(RootName);
1352       if (It == ApplyPats.end()) {
1353         PrintError("Cannot find root '" + RootName + "' in apply patterns!");
1354         return false;
1355       }
1356 
1357       auto *ApplyRoot = dyn_cast<InstructionPattern>(It->second.get());
1358       if (!ApplyRoot) {
1359         PrintError("apply pattern root '" + RootName +
1360                    "' must be an instruction pattern");
1361         return false;
1362       }
1363 
1364       ApplyRoots.insert(ApplyRoot);
1365       return true;
1366     }
1367 
1368     // Collect all redefinitions of the MatchRoot's defs and put them in
1369     // ApplyRoots.
1370     const auto DefsNeeded = IPRoot->getApplyDefsNeeded();
1371     for (auto &Op : DefsNeeded) {
1372       assert(Op.isDef() && Op.isNamedOperand());
1373       StringRef Name = Op.getOperandName();
1374 
1375       auto *ApplyRedef = ApplyOpTable.getDef(Name);
1376       if (!ApplyRedef) {
1377         PrintError("'" + Name + "' must be redefined in the 'apply' pattern");
1378         return false;
1379       }
1380 
1381       ApplyRoots.insert((InstructionPattern *)ApplyRedef);
1382     }
1383 
1384     if (auto It = ApplyPats.find(RootName); It != ApplyPats.end()) {
1385       if (find(ApplyRoots, It->second.get()) == ApplyRoots.end()) {
1386         PrintError("apply pattern '" + RootName +
1387                    "' is supposed to be a root but it does not redefine any of "
1388                    "the defs of the match root");
1389         return false;
1390       }
1391     }
1392 
1393     return true;
1394   };
1395 
1396   // Look by pattern name, e.g.
1397   //    (G_FNEG $x, $y):$root
1398   if (auto MatchPatIt = MatchPats.find(RootName);
1399       MatchPatIt != MatchPats.end()) {
1400     MatchRoot = MatchPatIt->second.get();
1401     return Finish();
1402   }
1403 
1404   // Look by def:
1405   //    (G_FNEG $root, $y)
1406   auto LookupRes = MatchOpTable.lookup(RootName);
1407   if (!LookupRes.Found) {
1408     PrintError("Cannot find root '" + RootName + "' in match patterns!");
1409     return false;
1410   }
1411 
1412   MatchRoot = LookupRes.Def;
1413   if (!MatchRoot) {
1414     PrintError("Cannot use live-in operand '" + RootName +
1415                "' as match pattern root!");
1416     return false;
1417   }
1418 
1419   return Finish();
1420 }
1421 
buildRuleOperandsTable()1422 bool CombineRuleBuilder::buildRuleOperandsTable() {
1423   const auto DiagnoseRedefMatch = [&](StringRef OpName) {
1424     PrintError("Operand '" + OpName +
1425                "' is defined multiple times in the 'match' patterns");
1426   };
1427 
1428   const auto DiagnoseRedefApply = [&](StringRef OpName) {
1429     PrintError("Operand '" + OpName +
1430                "' is defined multiple times in the 'apply' patterns");
1431   };
1432 
1433   for (auto &Pat : values(MatchPats)) {
1434     auto *IP = dyn_cast<InstructionPattern>(Pat.get());
1435     if (IP && !MatchOpTable.addPattern(IP, DiagnoseRedefMatch))
1436       return false;
1437   }
1438 
1439   for (auto &Pat : values(ApplyPats)) {
1440     auto *IP = dyn_cast<InstructionPattern>(Pat.get());
1441     if (IP && !ApplyOpTable.addPattern(IP, DiagnoseRedefApply))
1442       return false;
1443   }
1444 
1445   return true;
1446 }
1447 
parseDefs(const DagInit & Def)1448 bool CombineRuleBuilder::parseDefs(const DagInit &Def) {
1449   if (Def.getOperatorAsDef(RuleDef.getLoc())->getName() != "defs") {
1450     PrintError("Expected defs operator");
1451     return false;
1452   }
1453 
1454   SmallVector<StringRef> Roots;
1455   for (unsigned I = 0, E = Def.getNumArgs(); I < E; ++I) {
1456     if (isSpecificDef(*Def.getArg(I), "root")) {
1457       Roots.emplace_back(Def.getArgNameStr(I));
1458       continue;
1459     }
1460 
1461     // Subclasses of GIDefMatchData should declare that this rule needs to pass
1462     // data from the match stage to the apply stage, and ensure that the
1463     // generated matcher has a suitable variable for it to do so.
1464     if (Record *MatchDataRec =
1465             getDefOfSubClass(*Def.getArg(I), MatchDataClassName)) {
1466       MatchDatas.emplace_back(Def.getArgNameStr(I),
1467                               MatchDataRec->getValueAsString("Type"));
1468       continue;
1469     }
1470 
1471     // Otherwise emit an appropriate error message.
1472     if (getDefOfSubClass(*Def.getArg(I), "GIDefKind"))
1473       PrintError("This GIDefKind not implemented in tablegen");
1474     else if (getDefOfSubClass(*Def.getArg(I), "GIDefKindWithArgs"))
1475       PrintError("This GIDefKindWithArgs not implemented in tablegen");
1476     else
1477       PrintError("Expected a subclass of GIDefKind or a sub-dag whose "
1478                  "operator is of type GIDefKindWithArgs");
1479     return false;
1480   }
1481 
1482   if (Roots.size() != 1) {
1483     PrintError("Combine rules must have exactly one root");
1484     return false;
1485   }
1486 
1487   RootName = Roots.front();
1488   return true;
1489 }
1490 
emitMatchPattern(CodeExpansions & CE,const PatternAlternatives & Alts,const InstructionPattern & IP)1491 bool CombineRuleBuilder::emitMatchPattern(CodeExpansions &CE,
1492                                           const PatternAlternatives &Alts,
1493                                           const InstructionPattern &IP) {
1494   auto StackTrace = PrettyStackTraceEmit(RuleDef, &IP);
1495 
1496   auto &M = addRuleMatcher(Alts);
1497   InstructionMatcher &IM = M.addInstructionMatcher(IP.getName());
1498   declareInstExpansion(CE, IM, IP.getName());
1499 
1500   DenseSet<const Pattern *> SeenPats;
1501 
1502   const auto FindOperandDef = [&](StringRef Op) -> InstructionPattern * {
1503     return MatchOpTable.getDef(Op);
1504   };
1505 
1506   if (const auto *CGP = dyn_cast<CodeGenInstructionPattern>(&IP)) {
1507     if (!emitCodeGenInstructionMatchPattern(CE, Alts, M, IM, *CGP, SeenPats,
1508                                             FindOperandDef))
1509       return false;
1510   } else if (const auto *PFP = dyn_cast<PatFragPattern>(&IP)) {
1511     if (!PFP->getPatFrag().canBeMatchRoot()) {
1512       PrintError("cannot use '" + PFP->getInstName() + " as match root");
1513       return false;
1514     }
1515 
1516     if (!emitPatFragMatchPattern(CE, Alts, M, &IM, *PFP, SeenPats))
1517       return false;
1518   } else if (isa<BuiltinPattern>(&IP)) {
1519     llvm_unreachable("No match builtins known!");
1520   } else
1521     llvm_unreachable("Unknown kind of InstructionPattern!");
1522 
1523   // Emit remaining patterns
1524   const bool IsUsingCustomCXXAction = hasOnlyCXXApplyPatterns();
1525   SmallVector<CXXPattern *, 2> CXXMatchers;
1526   for (auto &Pat : values(MatchPats)) {
1527     if (SeenPats.contains(Pat.get()))
1528       continue;
1529 
1530     switch (Pat->getKind()) {
1531     case Pattern::K_AnyOpcode:
1532       PrintError("wip_match_opcode can not be used with instruction patterns!");
1533       return false;
1534     case Pattern::K_PatFrag: {
1535       if (!emitPatFragMatchPattern(CE, Alts, M, /*IM*/ nullptr,
1536                                    *cast<PatFragPattern>(Pat.get()), SeenPats))
1537         return false;
1538       continue;
1539     }
1540     case Pattern::K_Builtin:
1541       PrintError("No known match builtins");
1542       return false;
1543     case Pattern::K_CodeGenInstruction:
1544       cast<InstructionPattern>(Pat.get())->reportUnreachable(RuleDef.getLoc());
1545       return false;
1546     case Pattern::K_CXX: {
1547       // Delay emission for top-level C++ matchers (which can use MatchDatas).
1548       if (IsUsingCustomCXXAction)
1549         CXXMatchers.push_back(cast<CXXPattern>(Pat.get()));
1550       else
1551         addCXXPredicate(M, CE, *cast<CXXPattern>(Pat.get()), Alts);
1552       continue;
1553     }
1554     default:
1555       llvm_unreachable("unknown pattern kind!");
1556     }
1557   }
1558 
1559   return IsUsingCustomCXXAction ? emitCXXMatchApply(CE, M, CXXMatchers)
1560                                 : emitApplyPatterns(CE, M);
1561 }
1562 
/// Emits one RuleMatcher per opcode listed by the wip_match_opcode pattern
/// \p AOP.
///
/// Each RuleMatcher matches the root instruction by opcode only. It then
/// emits the remaining match patterns (PatFrags without outputs, C++
/// predicates) and finally the apply side of the rule.
///
/// \returns false if an error was reported, true otherwise.
bool CombineRuleBuilder::emitMatchPattern(CodeExpansions &CE,
                                          const PatternAlternatives &Alts,
                                          const AnyOpcodePattern &AOP) {
  auto StackTrace = PrettyStackTraceEmit(RuleDef, &AOP);

  const bool IsUsingCustomCXXAction = hasOnlyCXXApplyPatterns();
  for (const CodeGenInstruction *CGI : AOP.insts()) {
    // One RuleMatcher per opcode; its description names that opcode.
    auto &M = addRuleMatcher(Alts, "wip_match_opcode '" +
                                       CGI->TheDef->getName() + "'");

    InstructionMatcher &IM = M.addInstructionMatcher(AOP.getName());
    declareInstExpansion(CE, IM, AOP.getName());
    // declareInstExpansion needs to be identical, otherwise we need to create a
    // CodeExpansions object here instead.
    // (CE is shared by every iteration of this loop.)
    assert(IM.getInsnVarID() == 0);

    IM.addPredicate<InstructionOpcodeMatcher>(CGI);

    // Emit remaining patterns.
    SmallVector<CXXPattern *, 2> CXXMatchers;
    for (auto &Pat : values(MatchPats)) {
      // The wip_match_opcode pattern itself was handled above.
      if (Pat.get() == &AOP)
        continue;

      switch (Pat->getKind()) {
      case Pattern::K_AnyOpcode:
        PrintError("wip_match_opcode can only be present once!");
        return false;
      case Pattern::K_PatFrag: {
        // No InstructionMatcher is passed. emitPatFragMatchPattern then only
        // accepts PatFrags without out parameters.
        DenseSet<const Pattern *> SeenPats;
        if (!emitPatFragMatchPattern(CE, Alts, M, /*IM*/ nullptr,
                                     *cast<PatFragPattern>(Pat.get()),
                                     SeenPats))
          return false;
        continue;
      }
      case Pattern::K_Builtin:
        PrintError("No known match builtins");
        return false;
      case Pattern::K_CodeGenInstruction:
        // Instruction patterns cannot be attached to an opcode-only root.
        cast<InstructionPattern>(Pat.get())->reportUnreachable(
            RuleDef.getLoc());
        return false;
      case Pattern::K_CXX: {
        // Delay emission for top-level C++ matchers (which can use MatchDatas).
        if (IsUsingCustomCXXAction)
          CXXMatchers.push_back(cast<CXXPattern>(Pat.get()));
        else
          addCXXPredicate(M, CE, *cast<CXXPattern>(Pat.get()), Alts);
        break;
      }
      default:
        llvm_unreachable("unknown pattern kind!");
      }
    }

    // Emit the apply side for this opcode's RuleMatcher.
    const bool Res = IsUsingCustomCXXAction
                         ? emitCXXMatchApply(CE, M, CXXMatchers)
                         : emitApplyPatterns(CE, M);
    if (!Res)
      return false;
  }

  return true;
}
1628 
emitPatFragMatchPattern(CodeExpansions & CE,const PatternAlternatives & Alts,RuleMatcher & RM,InstructionMatcher * IM,const PatFragPattern & PFP,DenseSet<const Pattern * > & SeenPats)1629 bool CombineRuleBuilder::emitPatFragMatchPattern(
1630     CodeExpansions &CE, const PatternAlternatives &Alts, RuleMatcher &RM,
1631     InstructionMatcher *IM, const PatFragPattern &PFP,
1632     DenseSet<const Pattern *> &SeenPats) {
1633   auto StackTrace = PrettyStackTraceEmit(RuleDef, &PFP);
1634 
1635   if (SeenPats.contains(&PFP))
1636     return true;
1637   SeenPats.insert(&PFP);
1638 
1639   const auto &PF = PFP.getPatFrag();
1640 
1641   if (!IM) {
1642     // When we don't have an IM, this means this PatFrag isn't reachable from
1643     // the root. This is only acceptable if it doesn't define anything (e.g. a
1644     // pure C++ PatFrag).
1645     if (PF.num_out_params() != 0) {
1646       PFP.reportUnreachable(RuleDef.getLoc());
1647       return false;
1648     }
1649   } else {
1650     // When an IM is provided, this is reachable from the root, and we're
1651     // expecting to have output operands.
1652     // TODO: If we want to allow for multiple roots we'll need a map of IMs
1653     // then, and emission becomes a bit more complicated.
1654     assert(PF.num_roots() == 1);
1655   }
1656 
1657   CodeExpansions PatFragCEs;
1658   if (!PFP.mapInputCodeExpansions(CE, PatFragCEs, RuleDef.getLoc()))
1659     return false;
1660 
1661   // List of {ParamName, ArgName}.
1662   // When all patterns have been emitted, find expansions in PatFragCEs named
1663   // ArgName and add their expansion to CE using ParamName as the key.
1664   SmallVector<std::pair<std::string, std::string>, 4> CEsToImport;
1665 
1666   // Map parameter names to the actual argument.
1667   const auto OperandMapper =
1668       [&](const InstructionOperand &O) -> InstructionOperand {
1669     if (!O.isNamedOperand())
1670       return O;
1671 
1672     StringRef ParamName = O.getOperandName();
1673 
1674     // Not sure what to do with those tbh. They should probably never be here.
1675     assert(!O.isNamedImmediate() && "TODO: handle named imms");
1676     unsigned PIdx = PF.getParamIdx(ParamName);
1677 
1678     // Map parameters to the argument values.
1679     if (PIdx == (unsigned)-1) {
1680       // This is a temp of the PatFragPattern, prefix the name to avoid
1681       // conflicts.
1682       return O.withNewName(
1683           insertStrRef((PFP.getName() + "." + ParamName).str()));
1684     }
1685 
1686     // The operand will be added to PatFragCEs's code expansions using the
1687     // parameter's name. If it's bound to some operand during emission of the
1688     // patterns, we'll want to add it to CE.
1689     auto ArgOp = PFP.getOperand(PIdx);
1690     if (ArgOp.isNamedOperand())
1691       CEsToImport.emplace_back(ArgOp.getOperandName().str(), ParamName);
1692 
1693     if (ArgOp.getType() && O.getType() && ArgOp.getType() != O.getType()) {
1694       StringRef PFName = PF.getName();
1695       PrintWarning("impossible type constraints: operand " + Twine(PIdx) +
1696                    " of '" + PFP.getName() + "' has type '" +
1697                    ArgOp.getType().str() + "', but '" + PFName +
1698                    "' constrains it to '" + O.getType().str() + "'");
1699       if (ArgOp.isNamedOperand())
1700         PrintNote("operand " + Twine(PIdx) + " of '" + PFP.getName() +
1701                   "' is '" + ArgOp.getOperandName() + "'");
1702       if (O.isNamedOperand())
1703         PrintNote("argument " + Twine(PIdx) + " of '" + PFName + "' is '" +
1704                   ParamName + "'");
1705     }
1706 
1707     return ArgOp;
1708   };
1709 
1710   // PatFragPatterns are only made of InstructionPatterns or CXXPatterns.
1711   // Emit instructions from the root.
1712   const auto &FragAlt = PF.getAlternative(Alts.lookup(&PFP));
1713   const auto &FragAltOT = FragAlt.OpTable;
1714   const auto LookupOperandDef =
1715       [&](StringRef Op) -> const InstructionPattern * {
1716     return FragAltOT.getDef(Op);
1717   };
1718 
1719   DenseSet<const Pattern *> PatFragSeenPats;
1720   for (const auto &[Idx, InOp] : enumerate(PF.out_params())) {
1721     if (InOp.Kind != PatFrag::PK_Root)
1722       continue;
1723 
1724     StringRef ParamName = InOp.Name;
1725     const auto *Def = FragAltOT.getDef(ParamName);
1726     assert(Def && "PatFrag::checkSemantics should have emitted an error if "
1727                   "an out operand isn't defined!");
1728     assert(isa<CodeGenInstructionPattern>(Def) &&
1729            "Nested PatFrags not supported yet");
1730 
1731     if (!emitCodeGenInstructionMatchPattern(
1732             PatFragCEs, Alts, RM, *IM, *cast<CodeGenInstructionPattern>(Def),
1733             PatFragSeenPats, LookupOperandDef, OperandMapper))
1734       return false;
1735   }
1736 
1737   // Emit leftovers.
1738   for (const auto &Pat : FragAlt.Pats) {
1739     if (PatFragSeenPats.contains(Pat.get()))
1740       continue;
1741 
1742     if (const auto *CXXPat = dyn_cast<CXXPattern>(Pat.get())) {
1743       addCXXPredicate(RM, PatFragCEs, *CXXPat, Alts);
1744       continue;
1745     }
1746 
1747     if (const auto *IP = dyn_cast<InstructionPattern>(Pat.get())) {
1748       IP->reportUnreachable(PF.getLoc());
1749       return false;
1750     }
1751 
1752     llvm_unreachable("Unexpected pattern kind in PatFrag");
1753   }
1754 
1755   for (const auto &[ParamName, ArgName] : CEsToImport) {
1756     // Note: we're find if ParamName already exists. It just means it's been
1757     // bound before, so we prefer to keep the first binding.
1758     CE.declare(ParamName, PatFragCEs.lookup(ArgName));
1759   }
1760 
1761   return true;
1762 }
1763 
emitApplyPatterns(CodeExpansions & CE,RuleMatcher & M)1764 bool CombineRuleBuilder::emitApplyPatterns(CodeExpansions &CE, RuleMatcher &M) {
1765   assert(MatchDatas.empty());
1766 
1767   DenseSet<const Pattern *> SeenPats;
1768   StringMap<unsigned> OperandToTempRegID;
1769 
1770   for (auto *ApplyRoot : ApplyRoots) {
1771     assert(isa<InstructionPattern>(ApplyRoot) &&
1772            "Root can only be a InstructionPattern!");
1773     if (!emitInstructionApplyPattern(CE, M,
1774                                      cast<InstructionPattern>(*ApplyRoot),
1775                                      SeenPats, OperandToTempRegID))
1776       return false;
1777   }
1778 
1779   for (auto &Pat : values(ApplyPats)) {
1780     if (SeenPats.contains(Pat.get()))
1781       continue;
1782 
1783     switch (Pat->getKind()) {
1784     case Pattern::K_AnyOpcode:
1785       llvm_unreachable("Unexpected pattern in apply!");
1786     case Pattern::K_PatFrag:
1787       // TODO: We could support pure C++ PatFrags as a temporary thing.
1788       llvm_unreachable("Unexpected pattern in apply!");
1789     case Pattern::K_Builtin:
1790       if (!emitInstructionApplyPattern(CE, M, cast<BuiltinPattern>(*Pat),
1791                                        SeenPats, OperandToTempRegID))
1792         return false;
1793       break;
1794     case Pattern::K_CodeGenInstruction:
1795       cast<CodeGenInstructionPattern>(*Pat).reportUnreachable(RuleDef.getLoc());
1796       return false;
1797     case Pattern::K_CXX: {
1798       llvm_unreachable(
1799           "CXX Pattern Emission should have been handled earlier!");
1800     }
1801     default:
1802       llvm_unreachable("unknown pattern kind!");
1803     }
1804   }
1805 
1806   // Erase the root.
1807   unsigned RootInsnID =
1808       M.getInsnVarID(M.getInstructionMatcher(MatchRoot->getName()));
1809   M.addAction<EraseInstAction>(RootInsnID);
1810 
1811   return true;
1812 }
1813 
emitCXXMatchApply(CodeExpansions & CE,RuleMatcher & M,ArrayRef<CXXPattern * > Matchers)1814 bool CombineRuleBuilder::emitCXXMatchApply(CodeExpansions &CE, RuleMatcher &M,
1815                                            ArrayRef<CXXPattern *> Matchers) {
1816   assert(hasOnlyCXXApplyPatterns());
1817   declareAllMatchDatasExpansions(CE);
1818 
1819   std::string CodeStr;
1820   raw_string_ostream OS(CodeStr);
1821 
1822   for (auto &MD : MatchDatas)
1823     OS << MD.Type << " " << MD.getVarName() << ";\n";
1824 
1825   if (!Matchers.empty()) {
1826     OS << "// Match Patterns\n";
1827     for (auto *M : Matchers) {
1828       OS << "if(![&](){";
1829       CodeExpander Expander(M->getRawCode(), CE, RuleDef.getLoc(),
1830                             /*ShowExpansions=*/false);
1831       Expander.emit(OS);
1832       OS << "}()) {\n"
1833          << "  return false;\n}\n";
1834     }
1835   }
1836 
1837   OS << "// Apply Patterns\n";
1838   ListSeparator LS("\n");
1839   for (auto &Pat : ApplyPats) {
1840     auto *CXXPat = cast<CXXPattern>(Pat.second.get());
1841     CodeExpander Expander(CXXPat->getRawCode(), CE, RuleDef.getLoc(),
1842                           /*ShowExpansions=*/ false);
1843     OS << LS;
1844     Expander.emit(OS);
1845   }
1846 
1847   const auto &Code = CXXPredicateCode::getCustomActionCode(CodeStr);
1848   M.setCustomCXXAction(Code.getEnumNameWithPrefix(CXXCustomActionPrefix));
1849   return true;
1850 }
1851 
/// Emits the 'apply' actions for \p P.
///
/// The order of emission is:
/// - Use operands of \p P defined by other apply patterns are emitted first,
///   recursively, so their temp registers exist before \p P uses them.
/// - Then \p P itself is emitted. Builtins are delegated to
///   emitBuiltinApplyPattern, and CodeGenInstructionPatterns become a
///   BuildMIAction.
///
/// \param SeenPats patterns already emitted. \p P is skipped if present.
/// \param OperandToTempRegID maps the name of each operand defined only by
///        the apply patterns to its temp register ID. Filled in as new defs
///        are encountered.
/// \returns false if an error was reported, true otherwise.
bool CombineRuleBuilder::emitInstructionApplyPattern(
    CodeExpansions &CE, RuleMatcher &M, const InstructionPattern &P,
    DenseSet<const Pattern *> &SeenPats,
    StringMap<unsigned> &OperandToTempRegID) {
  auto StackTrace = PrettyStackTraceEmit(RuleDef, &P);

  if (SeenPats.contains(&P))
    return true;

  SeenPats.insert(&P);

  // First, render the uses.
  for (auto &Op : P.named_operands()) {
    if (Op.isDef())
      continue;

    StringRef OpName = Op.getOperandName();
    if (const auto *DefPat = ApplyOpTable.getDef(OpName)) {
      if (!emitInstructionApplyPattern(CE, M, *DefPat, SeenPats,
                                       OperandToTempRegID))
        return false;
    } else {
      // If we have no def, check this exists in the MatchRoot.
      if (!Op.isNamedImmediate() && !MatchOpTable.lookup(OpName).Found) {
        PrintError("invalid output operand '" + OpName +
                   "': operand is not a live-in of the match pattern, and it "
                   "has no definition");
        return false;
      }
    }
  }

  if (const auto *BP = dyn_cast<BuiltinPattern>(&P))
    return emitBuiltinApplyPattern(CE, M, *BP, OperandToTempRegID);

  if (isa<PatFragPattern>(&P))
    llvm_unreachable("PatFragPatterns is not supported in 'apply'!");

  auto &CGIP = cast<CodeGenInstructionPattern>(P);

  // Now render this inst.
  auto &DstMI =
      M.addAction<BuildMIAction>(M.allocateOutputInsnID(), &CGIP.getInst());

  bool HasEmittedIntrinsicID = false;
  const auto EmitIntrinsicID = [&]() {
    assert(CGIP.isIntrinsic());
    DstMI.addRenderer<IntrinsicIDRenderer>(CGIP.getIntrinsic());
    HasEmittedIntrinsicID = true;
  };

  for (auto &Op : P.operands()) {
    // Emit the intrinsic ID after the last def.
    if (CGIP.isIntrinsic() && !Op.isDef() && !HasEmittedIntrinsicID)
      EmitIntrinsicID();

    if (Op.isNamedImmediate()) {
      PrintError("invalid output operand '" + Op.getOperandName() +
                 "': output immediates cannot be named");
      PrintNote("while emitting pattern '" + P.getName() + "' (" +
                P.getInstName() + ")");
      return false;
    }

    if (Op.hasImmValue()) {
      if (!emitCodeGenInstructionApplyImmOperand(M, DstMI, CGIP, Op))
        return false;
      continue;
    }

    StringRef OpName = Op.getOperandName();

    // Uses of operand.
    if (!Op.isDef()) {
      if (auto It = OperandToTempRegID.find(OpName);
          It != OperandToTempRegID.end()) {
        assert(!MatchOpTable.lookup(OpName).Found &&
               "Temp reg is also from match pattern?");
        DstMI.addRenderer<TempRegRenderer>(It->second);
      } else {
        // This should be a match live in or a redef of a matched instr.
        // If it's a use of a temporary register, then we messed up somewhere -
        // the previous condition should have passed.
        assert(MatchOpTable.lookup(OpName).Found &&
               !ApplyOpTable.getDef(OpName) && "Temp reg not emitted yet!");
        DstMI.addRenderer<CopyRenderer>(OpName);
      }
      continue;
    }

    // Determine what we're dealing with. Are we replacing a matched
    // instruction? Creating a new one?
    auto OpLookupRes = MatchOpTable.lookup(OpName);
    if (OpLookupRes.Found) {
      if (OpLookupRes.isLiveIn()) {
        // live-in of the match pattern.
        PrintError("Cannot define live-in operand '" + OpName +
                   "' in the 'apply' pattern");
        return false;
      }
      assert(OpLookupRes.Def);

      // TODO: Handle this. We need to mutate the instr, or delete the old
      // one.
      //       Likewise, we also need to ensure we redef everything, if the
      //       instr has more than one def, we need to redef all or nothing.
      if (OpLookupRes.Def != MatchRoot) {
        PrintError("redefining an instruction other than the root is not "
                   "supported (operand '" +
                   OpName + "')");
        return false;
      }
      // redef of a match
      DstMI.addRenderer<CopyRenderer>(OpName);
      continue;
    }

    // Define a new register unique to the apply patterns (AKA a "temp"
    // register).
    unsigned TempRegID;
    if (auto It = OperandToTempRegID.find(OpName);
        It != OperandToTempRegID.end()) {
      TempRegID = It->second;
    } else {
      // This is a brand new register.
      TempRegID = M.allocateTempRegID();
      OperandToTempRegID[OpName] = TempRegID;
      const auto Ty = Op.getType();
      if (!Ty) {
        PrintError("def of a new register '" + OpName +
                   "' in the apply patterns must have a type");
        return false;
      }

      declareTempRegExpansion(CE, TempRegID, OpName);
      // Always insert the action at the beginning, otherwise we may end up
      // using the temp reg before it's available.
      M.insertAction<MakeTempRegisterAction>(
          M.actions_begin(), getLLTCodeGenOrTempType(Ty, M), TempRegID);
    }

    DstMI.addRenderer<TempRegRenderer>(TempRegID, /*IsDef=*/true);
  }

  // Some intrinsics have no in operands, ensure the ID is still emitted in such
  // cases.
  if (CGIP.isIntrinsic() && !HasEmittedIntrinsicID)
    EmitIntrinsicID();

  // Render MIFlags: copy them from the named match instructions, then apply
  // the explicitly set/unset flags.
  if (const auto *FI = CGIP.getMIFlagsInfo()) {
    for (StringRef InstName : FI->copy_flags())
      DstMI.addCopiedMIFlags(M.getInstructionMatcher(InstName));
    for (StringRef F : FI->set_flags())
      DstMI.addSetMIFlags(F);
    for (StringRef F : FI->unset_flags())
      DstMI.addUnsetMIFlags(F);
  }

  // Don't allow mutating opcodes for GISel combiners. We want a more precise
  // handling of MIFlags so we require them to be explicitly preserved.
  //
  // TODO: We don't mutate very often, if at all in combiners, but it'd be nice
  // to re-enable this. We'd then need to always clear MIFlags when mutating
  // opcodes, and never mutate an inst that we copy flags from.
  // DstMI.chooseInsnToMutate(M);
  declareInstExpansion(CE, DstMI, P.getName());

  return true;
}
2022 
emitCodeGenInstructionApplyImmOperand(RuleMatcher & M,BuildMIAction & DstMI,const CodeGenInstructionPattern & P,const InstructionOperand & O)2023 bool CombineRuleBuilder::emitCodeGenInstructionApplyImmOperand(
2024     RuleMatcher &M, BuildMIAction &DstMI, const CodeGenInstructionPattern &P,
2025     const InstructionOperand &O) {
2026   // If we have a type, we implicitly emit a G_CONSTANT, except for G_CONSTANT
2027   // itself where we emit a CImm.
2028   //
2029   // No type means we emit a simple imm.
2030   // G_CONSTANT is a special case and needs a CImm though so this is likely a
2031   // mistake.
2032   const bool isGConstant = P.is("G_CONSTANT");
2033   const auto Ty = O.getType();
2034   if (!Ty) {
2035     if (isGConstant) {
2036       PrintError("'G_CONSTANT' immediate must be typed!");
2037       PrintNote("while emitting pattern '" + P.getName() + "' (" +
2038                 P.getInstName() + ")");
2039       return false;
2040     }
2041 
2042     DstMI.addRenderer<ImmRenderer>(O.getImmValue());
2043     return true;
2044   }
2045 
2046   auto ImmTy = getLLTCodeGenOrTempType(Ty, M);
2047 
2048   if (isGConstant) {
2049     DstMI.addRenderer<ImmRenderer>(O.getImmValue(), ImmTy);
2050     return true;
2051   }
2052 
2053   unsigned TempRegID = M.allocateTempRegID();
2054   // Ensure MakeTempReg & the BuildConstantAction occur at the beginning.
2055   auto InsertIt = M.insertAction<MakeTempRegisterAction>(M.actions_begin(),
2056                                                          ImmTy, TempRegID);
2057   M.insertAction<BuildConstantAction>(++InsertIt, TempRegID, O.getImmValue());
2058   DstMI.addRenderer<TempRegRenderer>(TempRegID);
2059   return true;
2060 }
2061 
emitBuiltinApplyPattern(CodeExpansions & CE,RuleMatcher & M,const BuiltinPattern & P,StringMap<unsigned> & OperandToTempRegID)2062 bool CombineRuleBuilder::emitBuiltinApplyPattern(
2063     CodeExpansions &CE, RuleMatcher &M, const BuiltinPattern &P,
2064     StringMap<unsigned> &OperandToTempRegID) {
2065   const auto Error = [&](Twine Reason) {
2066     PrintError("cannot emit '" + P.getInstName() + "' builtin: " + Reason);
2067     return false;
2068   };
2069 
2070   switch (P.getBuiltinKind()) {
2071   case BI_EraseRoot: {
2072     // Root is always inst 0.
2073     M.addAction<EraseInstAction>(/*InsnID*/ 0);
2074     return true;
2075   }
2076   case BI_ReplaceReg: {
2077     StringRef Old = P.getOperand(0).getOperandName();
2078     StringRef New = P.getOperand(1).getOperandName();
2079 
2080     if (!ApplyOpTable.lookup(New).Found && !MatchOpTable.lookup(New).Found)
2081       return Error("unknown operand '" + Old + "'");
2082 
2083     auto &OldOM = M.getOperandMatcher(Old);
2084     if (auto It = OperandToTempRegID.find(New);
2085         It != OperandToTempRegID.end()) {
2086       // Replace with temp reg.
2087       M.addAction<ReplaceRegAction>(OldOM.getInsnVarID(), OldOM.getOpIdx(),
2088                                     It->second);
2089     } else {
2090       // Replace with matched reg.
2091       auto &NewOM = M.getOperandMatcher(New);
2092       M.addAction<ReplaceRegAction>(OldOM.getInsnVarID(), OldOM.getOpIdx(),
2093                                     NewOM.getInsnVarID(), NewOM.getOpIdx());
2094     }
2095     // checkSemantics should have ensured that we can only rewrite the root.
2096     // Ensure we're deleting it.
2097     assert(MatchOpTable.getDef(Old) == MatchRoot);
2098     return true;
2099   }
2100   }
2101 
2102   llvm_unreachable("Unknown BuiltinKind!");
2103 }
2104 
isLiteralImm(const InstructionPattern & P,unsigned OpIdx)2105 bool isLiteralImm(const InstructionPattern &P, unsigned OpIdx) {
2106   if (const auto *CGP = dyn_cast<CodeGenInstructionPattern>(&P)) {
2107     StringRef InstName = CGP->getInst().TheDef->getName();
2108     return (InstName == "G_CONSTANT" || InstName == "G_FCONSTANT") &&
2109            OpIdx == 1;
2110   }
2111 
2112   llvm_unreachable("TODO");
2113 }
2114 
emitCodeGenInstructionMatchPattern(CodeExpansions & CE,const PatternAlternatives & Alts,RuleMatcher & M,InstructionMatcher & IM,const CodeGenInstructionPattern & P,DenseSet<const Pattern * > & SeenPats,OperandDefLookupFn LookupOperandDef,OperandMapperFnRef OperandMapper)2115 bool CombineRuleBuilder::emitCodeGenInstructionMatchPattern(
2116     CodeExpansions &CE, const PatternAlternatives &Alts, RuleMatcher &M,
2117     InstructionMatcher &IM, const CodeGenInstructionPattern &P,
2118     DenseSet<const Pattern *> &SeenPats, OperandDefLookupFn LookupOperandDef,
2119     OperandMapperFnRef OperandMapper) {
2120   auto StackTrace = PrettyStackTraceEmit(RuleDef, &P);
2121 
2122   if (SeenPats.contains(&P))
2123     return true;
2124 
2125   SeenPats.insert(&P);
2126 
2127   IM.addPredicate<InstructionOpcodeMatcher>(&P.getInst());
2128   declareInstExpansion(CE, IM, P.getName());
2129 
2130   // If this is an intrinsic, check the intrinsic ID.
2131   if (P.isIntrinsic()) {
2132     // The IntrinsicID's operand is the first operand after the defs.
2133     OperandMatcher &OM = IM.addOperand(P.getNumInstDefs(), "$intrinsic_id",
2134                                        AllocatedTemporariesBaseID++);
2135     OM.addPredicate<IntrinsicIDOperandMatcher>(P.getIntrinsic());
2136   }
2137 
2138   // Check flags if needed.
2139   if (const auto *FI = P.getMIFlagsInfo()) {
2140     assert(FI->copy_flags().empty());
2141 
2142     if (const auto &SetF = FI->set_flags(); !SetF.empty())
2143       IM.addPredicate<MIFlagsInstructionPredicateMatcher>(SetF.getArrayRef());
2144     if (const auto &UnsetF = FI->unset_flags(); !UnsetF.empty())
2145       IM.addPredicate<MIFlagsInstructionPredicateMatcher>(UnsetF.getArrayRef(),
2146                                                           /*CheckNot=*/true);
2147   }
2148 
2149   for (auto [Idx, OriginalO] : enumerate(P.operands())) {
2150     // Remap the operand. This is used when emitting InstructionPatterns inside
2151     // PatFrags, so it can remap them to the arguments passed to the pattern.
2152     //
2153     // We use the remapped operand to emit immediates, and for the symbolic
2154     // operand names (in IM.addOperand). CodeExpansions and OperandTable lookups
2155     // still use the original name.
2156     //
2157     // The "def" flag on the remapped operand is always ignored.
2158     auto RemappedO = OperandMapper(OriginalO);
2159     assert(RemappedO.isNamedOperand() == OriginalO.isNamedOperand() &&
2160            "Cannot remap an unnamed operand to a named one!");
2161 
2162     const auto OpName =
2163         RemappedO.isNamedOperand() ? RemappedO.getOperandName().str() : "";
2164 
2165     // For intrinsics, the first use operand is the intrinsic id, so the true
2166     // operand index is shifted by 1.
2167     //
2168     // From now on:
2169     //    Idx = index in the pattern operand list.
2170     //    RealIdx = expected index in the MachineInstr.
2171     const unsigned RealIdx =
2172         (P.isIntrinsic() && !OriginalO.isDef()) ? (Idx + 1) : Idx;
2173     OperandMatcher &OM =
2174         IM.addOperand(RealIdx, OpName, AllocatedTemporariesBaseID++);
2175     if (!OpName.empty())
2176       declareOperandExpansion(CE, OM, OriginalO.getOperandName());
2177 
2178     // Handle immediates.
2179     if (RemappedO.hasImmValue()) {
2180       if (isLiteralImm(P, Idx))
2181         OM.addPredicate<LiteralIntOperandMatcher>(RemappedO.getImmValue());
2182       else
2183         OM.addPredicate<ConstantIntOperandMatcher>(RemappedO.getImmValue());
2184     }
2185 
2186     // Handle typed operands, but only bother to check if it hasn't been done
2187     // before.
2188     //
2189     // getOperandMatcher will always return the first OM to have been created
2190     // for that Operand. "OM" here is always a new OperandMatcher.
2191     //
2192     // Always emit a check for unnamed operands.
2193     if (OpName.empty() ||
2194         !M.getOperandMatcher(OpName).contains<LLTOperandMatcher>()) {
2195       if (const auto Ty = RemappedO.getType()) {
2196         // TODO: We could support GITypeOf here on the condition that the
2197         // OperandMatcher exists already. Though it's clunky to make this work
2198         // and isn't all that useful so it's just rejected in typecheckPatterns
2199         // at this time.
2200         assert(Ty.isLLT() && "Only LLTs are supported in match patterns!");
2201         OM.addPredicate<LLTOperandMatcher>(getLLTCodeGen(Ty));
2202       }
2203     }
2204 
2205     // Stop here if the operand is a def, or if it had no name.
2206     if (OriginalO.isDef() || !OriginalO.isNamedOperand())
2207       continue;
2208 
2209     const auto *DefPat = LookupOperandDef(OriginalO.getOperandName());
2210     if (!DefPat)
2211       continue;
2212 
2213     if (OriginalO.hasImmValue()) {
2214       assert(!OpName.empty());
2215       // This is a named immediate that also has a def, that's not okay.
2216       // e.g.
2217       //    (G_SEXT $y, (i32 0))
2218       //    (COPY $x, 42:$y)
2219       PrintError("'" + OpName +
2220                  "' is a named immediate, it cannot be defined by another "
2221                  "instruction");
2222       PrintNote("'" + OpName + "' is defined by '" + DefPat->getName() + "'");
2223       return false;
2224     }
2225 
2226     // From here we know that the operand defines an instruction, and we need to
2227     // emit it.
2228     auto InstOpM =
2229         OM.addPredicate<InstructionOperandMatcher>(M, DefPat->getName());
2230     if (!InstOpM) {
2231       // TODO: copy-pasted from GlobalISelEmitter.cpp. Is it still relevant
2232       // here?
2233       PrintError("Nested instruction '" + DefPat->getName() +
2234                  "' cannot be the same as another operand '" +
2235                  OriginalO.getOperandName() + "'");
2236       return false;
2237     }
2238 
2239     auto &IM = (*InstOpM)->getInsnMatcher();
2240     if (const auto *CGIDef = dyn_cast<CodeGenInstructionPattern>(DefPat)) {
2241       if (!emitCodeGenInstructionMatchPattern(CE, Alts, M, IM, *CGIDef,
2242                                               SeenPats, LookupOperandDef,
2243                                               OperandMapper))
2244         return false;
2245       continue;
2246     }
2247 
2248     if (const auto *PFPDef = dyn_cast<PatFragPattern>(DefPat)) {
2249       if (!emitPatFragMatchPattern(CE, Alts, M, &IM, *PFPDef, SeenPats))
2250         return false;
2251       continue;
2252     }
2253 
2254     llvm_unreachable("unknown type of InstructionPattern");
2255   }
2256 
2257   return true;
2258 }
2259 
2260 //===- GICombinerEmitter --------------------------------------------------===//
2261 
/// Main implementation class. This emits the tablegenerated output.
///
/// It collects rules, uses `CombineRuleBuilder` to parse them and accumulate
/// RuleMatchers, then takes all the necessary state/data from the various
/// static storage pools and wires them together to emit the match table &
/// associated function/data structures.
class GICombinerEmitter final : public GlobalISelMatchTableExecutorEmitter {
  RecordKeeper &Records;
  StringRef Name;
  const CodeGenTarget &Target;
  // The combiner definition record. getClassName() reads its "Classname"
  // field and getCombineAllMethodName() reads "CombineAllMethodName".
  Record *Combiner;
  // NOTE(review): presumably the ID given to the next gathered combine rule;
  // confirm in gatherRules.
  unsigned NextRuleID = 0;

  // List all combine rules (ID, name) imported.
  // Note that the combiner rule ID is different from the RuleMatcher ID. The
  // latter is internal to the MatchTable, the former is the canonical ID of the
  // combine rule used to disable/enable it.
  std::vector<std::pair<unsigned, std::string>> AllCombineRules;

  // Keep track of all rules we've seen so far to ensure we don't process
  // the same rule twice.
  StringSet<> RulesSeen;

  MatchTable buildMatchTable(MutableArrayRef<RuleMatcher> Rules);

  // Emits the <Classname>RuleConfig struct declaration and the helpers that
  // parse rule identifiers and rule ranges.
  void emitRuleConfigImpl(raw_ostream &OS);

  void emitAdditionalImpl(raw_ostream &OS) override;

  void emitMIPredicateFns(raw_ostream &OS) override;
  void emitI64ImmPredicateFns(raw_ostream &OS) override;
  void emitAPFloatImmPredicateFns(raw_ostream &OS) override;
  void emitAPIntImmPredicateFns(raw_ostream &OS) override;
  void emitTestSimplePredicate(raw_ostream &OS) override;
  void emitRunCustomAction(raw_ostream &OS) override;

  const CodeGenTarget &getTarget() const override { return Target; }
  StringRef getClassName() const override {
    return Combiner->getValueAsString("Classname");
  }

  StringRef getCombineAllMethodName() const {
    return Combiner->getValueAsString("CombineAllMethodName");
  }

  // Name of the generated rule-configuration struct: "<Classname>RuleConfig".
  std::string getRuleConfigClassName() const {
    return getClassName().str() + "RuleConfig";
  }

  // NOTE(review): a `const &&` parameter cannot be moved from, so `const &`
  // would be equivalent and also accept lvalues. Any change must be mirrored
  // in the out-of-line definition.
  void gatherRules(std::vector<RuleMatcher> &Rules,
                   const std::vector<Record *> &&RulesAndGroups);

public:
  explicit GICombinerEmitter(RecordKeeper &RK, const CodeGenTarget &Target,
                             StringRef Name, Record *Combiner);
  // NOTE(review): equivalent to the implicit destructor; could be omitted or
  // written as `= default`.
  ~GICombinerEmitter() {}

  void run(raw_ostream &OS);
};
2321 
emitRuleConfigImpl(raw_ostream & OS)2322 void GICombinerEmitter::emitRuleConfigImpl(raw_ostream &OS) {
2323   OS << "struct " << getRuleConfigClassName() << " {\n"
2324      << "  SparseBitVector<> DisabledRules;\n\n"
2325      << "  bool isRuleEnabled(unsigned RuleID) const;\n"
2326      << "  bool parseCommandLineOption();\n"
2327      << "  bool setRuleEnabled(StringRef RuleIdentifier);\n"
2328      << "  bool setRuleDisabled(StringRef RuleIdentifier);\n"
2329      << "};\n\n";
2330 
2331   std::vector<std::pair<std::string, std::string>> Cases;
2332   Cases.reserve(AllCombineRules.size());
2333 
2334   for (const auto &[ID, Name] : AllCombineRules)
2335     Cases.emplace_back(Name, "return " + to_string(ID) + ";\n");
2336 
2337   OS << "static std::optional<uint64_t> getRuleIdxForIdentifier(StringRef "
2338         "RuleIdentifier) {\n"
2339      << "  uint64_t I;\n"
2340      << "  // getAtInteger(...) returns false on success\n"
2341      << "  bool Parsed = !RuleIdentifier.getAsInteger(0, I);\n"
2342      << "  if (Parsed)\n"
2343      << "    return I;\n\n"
2344      << "#ifndef NDEBUG\n";
2345   StringMatcher Matcher("RuleIdentifier", Cases, OS);
2346   Matcher.Emit();
2347   OS << "#endif // ifndef NDEBUG\n\n"
2348      << "  return std::nullopt;\n"
2349      << "}\n";
2350 
2351   OS << "static std::optional<std::pair<uint64_t, uint64_t>> "
2352         "getRuleRangeForIdentifier(StringRef RuleIdentifier) {\n"
2353      << "  std::pair<StringRef, StringRef> RangePair = "
2354         "RuleIdentifier.split('-');\n"
2355      << "  if (!RangePair.second.empty()) {\n"
2356      << "    const auto First = "
2357         "getRuleIdxForIdentifier(RangePair.first);\n"
2358      << "    const auto Last = "
2359         "getRuleIdxForIdentifier(RangePair.second);\n"
2360      << "    if (!First || !Last)\n"
2361      << "      return std::nullopt;\n"
2362      << "    if (First >= Last)\n"
2363      << "      report_fatal_error(\"Beginning of range should be before "
2364         "end of range\");\n"
2365      << "    return {{*First, *Last + 1}};\n"
2366      << "  }\n"
2367      << "  if (RangePair.first == \"*\") {\n"
2368      << "    return {{0, " << AllCombineRules.size() << "}};\n"
2369      << "  }\n"
2370      << "  const auto I = getRuleIdxForIdentifier(RangePair.first);\n"
2371      << "  if (!I)\n"
2372      << "    return std::nullopt;\n"
2373      << "  return {{*I, *I + 1}};\n"
2374      << "}\n\n";
2375 
2376   for (bool Enabled : {true, false}) {
2377     OS << "bool " << getRuleConfigClassName() << "::setRule"
2378        << (Enabled ? "Enabled" : "Disabled") << "(StringRef RuleIdentifier) {\n"
2379        << "  auto MaybeRange = getRuleRangeForIdentifier(RuleIdentifier);\n"
2380        << "  if (!MaybeRange)\n"
2381        << "    return false;\n"
2382        << "  for (auto I = MaybeRange->first; I < MaybeRange->second; ++I)\n"
2383        << "    DisabledRules." << (Enabled ? "reset" : "set") << "(I);\n"
2384        << "  return true;\n"
2385        << "}\n\n";
2386   }
2387 
2388   OS << "static std::vector<std::string> " << Name << "Option;\n"
2389      << "static cl::list<std::string> " << Name << "DisableOption(\n"
2390      << "    \"" << Name.lower() << "-disable-rule\",\n"
2391      << "    cl::desc(\"Disable one or more combiner rules temporarily in "
2392      << "the " << Name << " pass\"),\n"
2393      << "    cl::CommaSeparated,\n"
2394      << "    cl::Hidden,\n"
2395      << "    cl::cat(GICombinerOptionCategory),\n"
2396      << "    cl::callback([](const std::string &Str) {\n"
2397      << "      " << Name << "Option.push_back(Str);\n"
2398      << "    }));\n"
2399      << "static cl::list<std::string> " << Name << "OnlyEnableOption(\n"
2400      << "    \"" << Name.lower() << "-only-enable-rule\",\n"
2401      << "    cl::desc(\"Disable all rules in the " << Name
2402      << " pass then re-enable the specified ones\"),\n"
2403      << "    cl::Hidden,\n"
2404      << "    cl::cat(GICombinerOptionCategory),\n"
2405      << "    cl::callback([](const std::string &CommaSeparatedArg) {\n"
2406      << "      StringRef Str = CommaSeparatedArg;\n"
2407      << "      " << Name << "Option.push_back(\"*\");\n"
2408      << "      do {\n"
2409      << "        auto X = Str.split(\",\");\n"
2410      << "        " << Name << "Option.push_back((\"!\" + X.first).str());\n"
2411      << "        Str = X.second;\n"
2412      << "      } while (!Str.empty());\n"
2413      << "    }));\n"
2414      << "\n\n"
2415      << "bool " << getRuleConfigClassName()
2416      << "::isRuleEnabled(unsigned RuleID) const {\n"
2417      << "    return  !DisabledRules.test(RuleID);\n"
2418      << "}\n"
2419      << "bool " << getRuleConfigClassName() << "::parseCommandLineOption() {\n"
2420      << "  for (StringRef Identifier : " << Name << "Option) {\n"
2421      << "    bool Enabled = Identifier.consume_front(\"!\");\n"
2422      << "    if (Enabled && !setRuleEnabled(Identifier))\n"
2423      << "      return false;\n"
2424      << "    if (!Enabled && !setRuleDisabled(Identifier))\n"
2425      << "      return false;\n"
2426      << "  }\n"
2427      << "  return true;\n"
2428      << "}\n\n";
2429 }
2430 
emitAdditionalImpl(raw_ostream & OS)2431 void GICombinerEmitter::emitAdditionalImpl(raw_ostream &OS) {
2432   OS << "bool " << getClassName() << "::" << getCombineAllMethodName()
2433      << "(MachineInstr &I) const {\n"
2434      << "  const TargetSubtargetInfo &ST = MF.getSubtarget();\n"
2435      << "  const PredicateBitset AvailableFeatures = "
2436         "getAvailableFeatures();\n"
2437      << "  B.setInstrAndDebugLoc(I);\n"
2438      << "  State.MIs.clear();\n"
2439      << "  State.MIs.push_back(&I);\n"
2440      << "  if (executeMatchTable(*this, State, ExecInfo, B"
2441      << ", getMatchTable(), *ST.getInstrInfo(), MRI, "
2442         "*MRI.getTargetRegisterInfo(), *ST.getRegBankInfo(), AvailableFeatures"
2443      << ", /*CoverageInfo*/ nullptr)) {\n"
2444      << "    return true;\n"
2445      << "  }\n\n"
2446      << "  return false;\n"
2447      << "}\n\n";
2448 }
2449 
emitMIPredicateFns(raw_ostream & OS)2450 void GICombinerEmitter::emitMIPredicateFns(raw_ostream &OS) {
2451   auto MatchCode = CXXPredicateCode::getAllMatchCode();
2452   emitMIPredicateFnsImpl<const CXXPredicateCode *>(
2453       OS, "", ArrayRef<const CXXPredicateCode *>(MatchCode),
2454       [](const CXXPredicateCode *C) -> StringRef { return C->BaseEnumName; },
2455       [](const CXXPredicateCode *C) -> StringRef { return C->Code; });
2456 }
2457 
emitI64ImmPredicateFns(raw_ostream & OS)2458 void GICombinerEmitter::emitI64ImmPredicateFns(raw_ostream &OS) {
2459   // Unused, but still needs to be called.
2460   emitImmPredicateFnsImpl<unsigned>(
2461       OS, "I64", "int64_t", {}, [](unsigned) { return ""; },
2462       [](unsigned) { return ""; });
2463 }
2464 
emitAPFloatImmPredicateFns(raw_ostream & OS)2465 void GICombinerEmitter::emitAPFloatImmPredicateFns(raw_ostream &OS) {
2466   // Unused, but still needs to be called.
2467   emitImmPredicateFnsImpl<unsigned>(
2468       OS, "APFloat", "const APFloat &", {}, [](unsigned) { return ""; },
2469       [](unsigned) { return ""; });
2470 }
2471 
emitAPIntImmPredicateFns(raw_ostream & OS)2472 void GICombinerEmitter::emitAPIntImmPredicateFns(raw_ostream &OS) {
2473   // Unused, but still needs to be called.
2474   emitImmPredicateFnsImpl<unsigned>(
2475       OS, "APInt", "const APInt &", {}, [](unsigned) { return ""; },
2476       [](unsigned) { return ""; });
2477 }
2478 
emitTestSimplePredicate(raw_ostream & OS)2479 void GICombinerEmitter::emitTestSimplePredicate(raw_ostream &OS) {
2480   if (!AllCombineRules.empty()) {
2481     OS << "enum {\n";
2482     std::string EnumeratorSeparator = " = GICXXPred_Invalid + 1,\n";
2483     // To avoid emitting a switch, we expect that all those rules are in order.
2484     // That way we can just get the RuleID from the enum by subtracting
2485     // (GICXXPred_Invalid + 1).
2486     unsigned ExpectedID = 0;
2487     (void)ExpectedID;
2488     for (const auto &ID : keys(AllCombineRules)) {
2489       assert(ExpectedID++ == ID && "combine rules are not ordered!");
2490       OS << "  " << getIsEnabledPredicateEnumName(ID) << EnumeratorSeparator;
2491       EnumeratorSeparator = ",\n";
2492     }
2493     OS << "};\n\n";
2494   }
2495 
2496   OS << "bool " << getClassName()
2497      << "::testSimplePredicate(unsigned Predicate) const {\n"
2498      << "    return RuleConfig.isRuleEnabled(Predicate - "
2499         "GICXXPred_Invalid - "
2500         "1);\n"
2501      << "}\n";
2502 }
2503 
emitRunCustomAction(raw_ostream & OS)2504 void GICombinerEmitter::emitRunCustomAction(raw_ostream &OS) {
2505   const auto CustomActionsCode = CXXPredicateCode::getAllCustomActionsCode();
2506 
2507   if (!CustomActionsCode.empty()) {
2508     OS << "enum {\n";
2509     std::string EnumeratorSeparator = " = GICXXCustomAction_Invalid + 1,\n";
2510     for (const auto &CA : CustomActionsCode) {
2511       OS << "  " << CA->getEnumNameWithPrefix(CXXCustomActionPrefix)
2512          << EnumeratorSeparator;
2513       EnumeratorSeparator = ",\n";
2514     }
2515     OS << "};\n";
2516   }
2517 
2518   OS << "bool " << getClassName()
2519      << "::runCustomAction(unsigned ApplyID, const MatcherState &State, "
2520         "NewMIVector &OutMIs) const "
2521         "{\n  Helper.getBuilder().setInstrAndDebugLoc(*State.MIs[0]);\n";
2522   if (!CustomActionsCode.empty()) {
2523     OS << "  switch(ApplyID) {\n";
2524     for (const auto &CA : CustomActionsCode) {
2525       OS << "  case " << CA->getEnumNameWithPrefix(CXXCustomActionPrefix)
2526          << ":{\n"
2527          << "    " << join(split(CA->Code, '\n'), "\n    ") << '\n'
2528          << "    return true;\n";
2529       OS << "  }\n";
2530     }
2531     OS << "  }\n";
2532   }
2533   OS << "  llvm_unreachable(\"Unknown Apply Action\");\n"
2534      << "}\n";
2535 }
2536 
/// \p Name is the combiner def's name as given to `-combiners`; it is used to
/// derive the generated command-line option names. \p Combiner is the record
/// for that def.
GICombinerEmitter::GICombinerEmitter(RecordKeeper &RK,
                                     const CodeGenTarget &Target,
                                     StringRef Name, Record *Combiner)
    : Records(RK), Name(Name), Target(Target), Combiner(Combiner) {}
2541 
2542 MatchTable
buildMatchTable(MutableArrayRef<RuleMatcher> Rules)2543 GICombinerEmitter::buildMatchTable(MutableArrayRef<RuleMatcher> Rules) {
2544   std::vector<Matcher *> InputRules;
2545   for (Matcher &Rule : Rules)
2546     InputRules.push_back(&Rule);
2547 
2548   unsigned CurrentOrdering = 0;
2549   StringMap<unsigned> OpcodeOrder;
2550   for (RuleMatcher &Rule : Rules) {
2551     const StringRef Opcode = Rule.getOpcode();
2552     assert(!Opcode.empty() && "Didn't expect an undefined opcode");
2553     if (OpcodeOrder.count(Opcode) == 0)
2554       OpcodeOrder[Opcode] = CurrentOrdering++;
2555   }
2556 
2557   llvm::stable_sort(InputRules, [&OpcodeOrder](const Matcher *A,
2558                                                const Matcher *B) {
2559     auto *L = static_cast<const RuleMatcher *>(A);
2560     auto *R = static_cast<const RuleMatcher *>(B);
2561     return std::make_tuple(OpcodeOrder[L->getOpcode()], L->getNumOperands()) <
2562            std::make_tuple(OpcodeOrder[R->getOpcode()], R->getNumOperands());
2563   });
2564 
2565   for (Matcher *Rule : InputRules)
2566     Rule->optimize();
2567 
2568   std::vector<std::unique_ptr<Matcher>> MatcherStorage;
2569   std::vector<Matcher *> OptRules =
2570       optimizeRules<GroupMatcher>(InputRules, MatcherStorage);
2571 
2572   for (Matcher *Rule : OptRules)
2573     Rule->optimize();
2574 
2575   OptRules = optimizeRules<SwitchMatcher>(OptRules, MatcherStorage);
2576 
2577   return MatchTable::buildTable(OptRules, /*WithCoverage*/ false,
2578                                 /*IsCombiner*/ true);
2579 }
2580 
2581 /// Recurse into GICombineGroup's and flatten the ruleset into a simple list.
gatherRules(std::vector<RuleMatcher> & ActiveRules,const std::vector<Record * > && RulesAndGroups)2582 void GICombinerEmitter::gatherRules(
2583     std::vector<RuleMatcher> &ActiveRules,
2584     const std::vector<Record *> &&RulesAndGroups) {
2585   for (Record *Rec : RulesAndGroups) {
2586     if (!Rec->isValueUnset("Rules")) {
2587       gatherRules(ActiveRules, Rec->getValueAsListOfDefs("Rules"));
2588       continue;
2589     }
2590 
2591     StringRef RuleName = Rec->getName();
2592     if (!RulesSeen.insert(RuleName).second) {
2593       PrintWarning(Rec->getLoc(),
2594                    "skipping rule '" + Rec->getName() +
2595                        "' because it has already been processed");
2596       continue;
2597     }
2598 
2599     AllCombineRules.emplace_back(NextRuleID, Rec->getName().str());
2600     CombineRuleBuilder CRB(Target, SubtargetFeatures, *Rec, NextRuleID++,
2601                            ActiveRules);
2602 
2603     if (!CRB.parseAll()) {
2604       assert(ErrorsPrinted && "Parsing failed without errors!");
2605       continue;
2606     }
2607 
2608     if (StopAfterParse) {
2609       CRB.print(outs());
2610       continue;
2611     }
2612 
2613     if (!CRB.emitRuleMatchers()) {
2614       assert(ErrorsPrinted && "Emission failed without errors!");
2615       continue;
2616     }
2617   }
2618 }
2619 
/// Emit the complete generated combiner for this combiner def.
///
/// Gathers and parses every rule. This stops with a fatal error if any rule
/// failed, or returns early under StopAfterParse. It then sorts the rules by
/// priority, builds the MatchTable, and writes each section of the generated
/// file inside its own GET_GICOMBINER_* preprocessor guard.
void GICombinerEmitter::run(raw_ostream &OS) {
  // Initialize the matcher classes' static opcode and type-ID tables.
  InstructionOpcodeMatcher::initOpcodeValuesMap(Target);
  LLTOperandMatcher::initTypeIDValuesMap();

  Records.startTimer("Gather rules");
  std::vector<RuleMatcher> Rules;
  gatherRules(Rules, Combiner->getValueAsListOfDefs("Rules"));
  if (ErrorsPrinted)
    PrintFatalError(Combiner->getLoc(), "Failed to parse one or more rules");

  if (StopAfterParse)
    return;

  Records.startTimer("Creating Match Table");
  // Largest renderer-function count over all rules; used to size the
  // temporaries emitted into the constructor initializers below.
  unsigned MaxTemporaries = 0;
  for (const auto &Rule : Rules)
    MaxTemporaries = std::max(MaxTemporaries, Rule.countRendererFns());

  // Higher-priority rules first. The stable sort keeps declaration order
  // between rules that are not ordered relative to each other.
  llvm::stable_sort(Rules, [&](const RuleMatcher &A, const RuleMatcher &B) {
    if (A.isHigherPriorityThan(B)) {
      assert(!B.isHigherPriorityThan(A) && "Cannot be more important "
                                           "and less important at "
                                           "the same time");
      return true;
    }
    return false;
  });

  const MatchTable Table = buildMatchTable(Rules);

  Records.startTimer("Emit combiner");

  emitSourceFileHeader(getClassName().str() + " Combiner Match Table", OS);

  // Unused: combiners have no custom renderers; an empty list is passed on.
  std::vector<StringRef> CustomRendererFns;
  // Unused: combiners have no complex predicates; an empty list is passed on.
  std::vector<Record *> ComplexPredicates;

  SmallVector<LLTCodeGen, 16> TypeObjects;
  append_range(TypeObjects, KnownTypes);
  llvm::sort(TypeObjects);

  // Hack: Avoid empty declarator.
  // NOTE(review): presumably the executor emits a declaration built from
  // TypeObjects that must not be empty; the dummy s1 keeps it non-empty.
  if (TypeObjects.empty())
    TypeObjects.push_back(LLT::scalar(1));

  // GET_GICOMBINER_DEPS, which pulls in extra dependencies.
  OS << "#ifdef GET_GICOMBINER_DEPS\n"
     << "#include \"llvm/ADT/SparseBitVector.h\"\n"
     << "namespace llvm {\n"
     << "extern cl::OptionCategory GICombinerOptionCategory;\n"
     << "} // end namespace llvm\n"
     << "#endif // ifdef GET_GICOMBINER_DEPS\n\n";

  // GET_GICOMBINER_TYPES, which needs to be included before the declaration of
  // the class.
  OS << "#ifdef GET_GICOMBINER_TYPES\n";
  emitRuleConfigImpl(OS);
  OS << "#endif // ifdef GET_GICOMBINER_TYPES\n\n";
  emitPredicateBitset(OS, "GET_GICOMBINER_TYPES");

  // GET_GICOMBINER_CLASS_MEMBERS, which need to be included inside the class.
  emitPredicatesDecl(OS, "GET_GICOMBINER_CLASS_MEMBERS");
  emitTemporariesDecl(OS, "GET_GICOMBINER_CLASS_MEMBERS");

  // GET_GICOMBINER_IMPL, which needs to be included outside the class.
  emitExecutorImpl(OS, Table, TypeObjects, Rules, ComplexPredicates,
                   CustomRendererFns, "GET_GICOMBINER_IMPL");

  // GET_GICOMBINER_CONSTRUCTOR_INITS, which are in the constructor's
  // initializer list.
  emitPredicatesInit(OS, "GET_GICOMBINER_CONSTRUCTOR_INITS");
  emitTemporariesInit(OS, MaxTemporaries, "GET_GICOMBINER_CONSTRUCTOR_INITS");
}
2695 
2696 } // end anonymous namespace
2697 
2698 //===----------------------------------------------------------------------===//
2699 
EmitGICombiner(RecordKeeper & RK,raw_ostream & OS)2700 static void EmitGICombiner(RecordKeeper &RK, raw_ostream &OS) {
2701   EnablePrettyStackTrace();
2702   CodeGenTarget Target(RK);
2703 
2704   if (SelectedCombiners.empty())
2705     PrintFatalError("No combiners selected with -combiners");
2706   for (const auto &Combiner : SelectedCombiners) {
2707     Record *CombinerDef = RK.getDef(Combiner);
2708     if (!CombinerDef)
2709       PrintFatalError("Could not find " + Combiner);
2710     GICombinerEmitter(RK, Target, Combiner, CombinerDef).run(OS);
2711   }
2712 }
2713 
// Register EmitGICombiner as the TableGen action "gen-global-isel-combiner".
static TableGen::Emitter::Opt X("gen-global-isel-combiner", EmitGICombiner,
                                "Generate GlobalISel Combiner");
2716