xref: /freebsd/contrib/llvm-project/llvm/utils/TableGen/DAGISelMatcherEmitter.cpp (revision 700637cbb5e582861067a11aaca4d053546871d2)
1 //===- DAGISelMatcherEmitter.cpp - Matcher Emitter ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains code to generate C++ code for a matcher.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "Basic/SDNodeProperties.h"
14 #include "Common/CodeGenDAGPatterns.h"
15 #include "Common/CodeGenInstruction.h"
16 #include "Common/CodeGenRegisters.h"
17 #include "Common/CodeGenTarget.h"
18 #include "Common/DAGISelMatcher.h"
19 #include "llvm/ADT/DenseMap.h"
20 #include "llvm/ADT/MapVector.h"
21 #include "llvm/ADT/StringMap.h"
22 #include "llvm/ADT/TinyPtrVector.h"
23 #include "llvm/Support/CommandLine.h"
24 #include "llvm/Support/Format.h"
25 #include "llvm/Support/SourceMgr.h"
26 #include "llvm/TableGen/Error.h"
27 #include "llvm/TableGen/Record.h"
28 
29 using namespace llvm;
30 
31 enum {
32   IndexWidth = 6,
33   FullIndexWidth = IndexWidth + 4,
34   HistOpcWidth = 40,
35 };
36 
37 static cl::OptionCategory DAGISelCat("Options for -gen-dag-isel");
38 
39 // To reduce generated source code size.
40 static cl::opt<bool> OmitComments("omit-comments",
41                                   cl::desc("Do not generate comments"),
42                                   cl::init(false), cl::cat(DAGISelCat));
43 
44 static cl::opt<bool> InstrumentCoverage(
45     "instrument-coverage",
46     cl::desc("Generates tables to help identify patterns matched"),
47     cl::init(false), cl::cat(DAGISelCat));
48 
49 namespace {
50 class MatcherTableEmitter {
51   const CodeGenDAGPatterns &CGP;
52 
53   SmallVector<unsigned, Matcher::HighestKind + 1> OpcodeCounts;
54 
55   std::vector<TreePattern *> NodePredicates;
56   std::vector<TreePattern *> NodePredicatesWithOperands;
57 
58   // We de-duplicate the predicates by code string, and use this map to track
59   // all the patterns with "identical" predicates.
60   MapVector<std::string, TinyPtrVector<TreePattern *>, StringMap<unsigned>>
61       NodePredicatesByCodeToRun;
62 
63   std::vector<std::string> PatternPredicates;
64 
65   std::vector<const ComplexPattern *> ComplexPatterns;
66 
67   DenseMap<const Record *, unsigned> NodeXFormMap;
68   std::vector<const Record *> NodeXForms;
69 
70   std::vector<std::string> VecIncludeStrings;
71   MapVector<std::string, unsigned, StringMap<unsigned>> VecPatterns;
72 
getPatternIdxFromTable(std::string && P,std::string && include_loc)73   unsigned getPatternIdxFromTable(std::string &&P, std::string &&include_loc) {
74     const auto [It, Inserted] =
75         VecPatterns.try_emplace(std::move(P), VecPatterns.size());
76     if (Inserted) {
77       VecIncludeStrings.push_back(std::move(include_loc));
78       return VecIncludeStrings.size() - 1;
79     }
80     return It->second;
81   }
82 
83 public:
MatcherTableEmitter(const Matcher * TheMatcher,const CodeGenDAGPatterns & cgp)84   MatcherTableEmitter(const Matcher *TheMatcher, const CodeGenDAGPatterns &cgp)
85       : CGP(cgp), OpcodeCounts(Matcher::HighestKind + 1, 0) {
86     // Record the usage of ComplexPattern.
87     MapVector<const ComplexPattern *, unsigned> ComplexPatternUsage;
88     // Record the usage of PatternPredicate.
89     MapVector<StringRef, unsigned> PatternPredicateUsage;
90     // Record the usage of Predicate.
91     MapVector<TreePattern *, unsigned> PredicateUsage;
92 
93     // Iterate the whole MatcherTable once and do some statistics.
94     std::function<void(const Matcher *)> Statistic = [&](const Matcher *N) {
95       while (N) {
96         if (auto *SM = dyn_cast<ScopeMatcher>(N))
97           for (unsigned I = 0; I < SM->getNumChildren(); I++)
98             Statistic(SM->getChild(I));
99         else if (auto *SOM = dyn_cast<SwitchOpcodeMatcher>(N))
100           for (unsigned I = 0; I < SOM->getNumCases(); I++)
101             Statistic(SOM->getCaseMatcher(I));
102         else if (auto *STM = dyn_cast<SwitchTypeMatcher>(N))
103           for (unsigned I = 0; I < STM->getNumCases(); I++)
104             Statistic(STM->getCaseMatcher(I));
105         else if (auto *CPM = dyn_cast<CheckComplexPatMatcher>(N))
106           ++ComplexPatternUsage[&CPM->getPattern()];
107         else if (auto *CPPM = dyn_cast<CheckPatternPredicateMatcher>(N))
108           ++PatternPredicateUsage[CPPM->getPredicate()];
109         else if (auto *PM = dyn_cast<CheckPredicateMatcher>(N))
110           ++PredicateUsage[PM->getPredicate().getOrigPatFragRecord()];
111         N = N->getNext();
112       }
113     };
114     Statistic(TheMatcher);
115 
116     // Sort ComplexPatterns by usage.
117     std::vector<std::pair<const ComplexPattern *, unsigned>> ComplexPatternList(
118         ComplexPatternUsage.begin(), ComplexPatternUsage.end());
119     stable_sort(ComplexPatternList, [](const auto &A, const auto &B) {
120       return A.second > B.second;
121     });
122     for (const auto &ComplexPattern : ComplexPatternList)
123       ComplexPatterns.push_back(ComplexPattern.first);
124 
125     // Sort PatternPredicates by usage.
126     std::vector<std::pair<std::string, unsigned>> PatternPredicateList(
127         PatternPredicateUsage.begin(), PatternPredicateUsage.end());
128     stable_sort(PatternPredicateList, [](const auto &A, const auto &B) {
129       return A.second > B.second;
130     });
131     for (const auto &PatternPredicate : PatternPredicateList)
132       PatternPredicates.push_back(PatternPredicate.first);
133 
134     // Sort Predicates by usage.
135     // Merge predicates with same code.
136     for (const auto &Usage : PredicateUsage) {
137       TreePattern *TP = Usage.first;
138       TreePredicateFn Pred(TP);
139       NodePredicatesByCodeToRun[Pred.getCodeToRunOnSDNode()].push_back(TP);
140     }
141 
142     std::vector<std::pair<TreePattern *, unsigned>> PredicateList;
143     // Sum the usage.
144     for (auto &Predicate : NodePredicatesByCodeToRun) {
145       TinyPtrVector<TreePattern *> &TPs = Predicate.second;
146       stable_sort(TPs, [](const auto *A, const auto *B) {
147         return A->getRecord()->getName() < B->getRecord()->getName();
148       });
149       unsigned Uses = 0;
150       for (TreePattern *TP : TPs)
151         Uses += PredicateUsage[TP];
152 
153       // We only add the first predicate here since they are with the same code.
154       PredicateList.emplace_back(TPs[0], Uses);
155     }
156 
157     stable_sort(PredicateList, [](const auto &A, const auto &B) {
158       return A.second > B.second;
159     });
160     for (const auto &Predicate : PredicateList) {
161       TreePattern *TP = Predicate.first;
162       if (TreePredicateFn(TP).usesOperands())
163         NodePredicatesWithOperands.push_back(TP);
164       else
165         NodePredicates.push_back(TP);
166     }
167   }
168 
169   unsigned EmitMatcherList(const Matcher *N, const unsigned Indent,
170                            unsigned StartIdx, raw_ostream &OS);
171 
172   unsigned SizeMatcherList(Matcher *N, raw_ostream &OS);
173 
174   void EmitPredicateFunctions(raw_ostream &OS);
175 
176   void EmitHistogram(const Matcher *N, raw_ostream &OS);
177 
178   void EmitPatternMatchTable(raw_ostream &OS);
179 
180 private:
181   void EmitNodePredicatesFunction(const std::vector<TreePattern *> &Preds,
182                                   StringRef Decl, raw_ostream &OS);
183 
184   unsigned SizeMatcher(Matcher *N, raw_ostream &OS);
185 
186   unsigned EmitMatcher(const Matcher *N, const unsigned Indent,
187                        unsigned CurrentIdx, raw_ostream &OS);
188 
getNodePredicate(TreePredicateFn Pred)189   unsigned getNodePredicate(TreePredicateFn Pred) {
190     // We use the first predicate.
191     TreePattern *PredPat =
192         NodePredicatesByCodeToRun[Pred.getCodeToRunOnSDNode()][0];
193     return Pred.usesOperands()
194                ? llvm::find(NodePredicatesWithOperands, PredPat) -
195                      NodePredicatesWithOperands.begin()
196                : llvm::find(NodePredicates, PredPat) - NodePredicates.begin();
197   }
198 
getPatternPredicate(StringRef PredName)199   unsigned getPatternPredicate(StringRef PredName) {
200     return llvm::find(PatternPredicates, PredName) - PatternPredicates.begin();
201   }
getComplexPat(const ComplexPattern & P)202   unsigned getComplexPat(const ComplexPattern &P) {
203     return llvm::find(ComplexPatterns, &P) - ComplexPatterns.begin();
204   }
205 
getNodeXFormID(const Record * Rec)206   unsigned getNodeXFormID(const Record *Rec) {
207     unsigned &Entry = NodeXFormMap[Rec];
208     if (Entry == 0) {
209       NodeXForms.push_back(Rec);
210       Entry = NodeXForms.size();
211     }
212     return Entry - 1;
213   }
214 };
215 } // end anonymous namespace.
216 
GetPatFromTreePatternNode(const TreePatternNode & N)217 static std::string GetPatFromTreePatternNode(const TreePatternNode &N) {
218   std::string str;
219   raw_string_ostream Stream(str);
220   Stream << N;
221   return str;
222 }
223 
/// Returns how many bytes EmitVBRValue() writes for \p Val. Each byte carries
/// 7 bits of payload, and even 0 takes one byte.
///
/// The parameter is uint64_t, matching EmitVBRValue(), so the size computed
/// here agrees with what is actually emitted for every value. Narrower
/// callers (unsigned sizes, MVT enum values) convert implicitly.
static unsigned GetVBRSize(uint64_t Val) {
  unsigned NumBytes = 1;
  for (; Val >= 128; Val >>= 7)
    ++NumBytes;
  return NumBytes;
}
235 
236 /// EmitVBRValue - Emit the specified value as a VBR, returning the number of
237 /// bytes emitted.
EmitVBRValue(uint64_t Val,raw_ostream & OS)238 static unsigned EmitVBRValue(uint64_t Val, raw_ostream &OS) {
239   if (Val <= 127) {
240     OS << Val << ", ";
241     return 1;
242   }
243 
244   uint64_t InVal = Val;
245   unsigned NumBytes = 0;
246   while (Val >= 128) {
247     OS << (Val & 127) << "|128,";
248     Val >>= 7;
249     ++NumBytes;
250   }
251   OS << Val;
252   if (!OmitComments)
253     OS << "/*" << InVal << "*/";
254   OS << ", ";
255   return NumBytes + 1;
256 }
257 
258 /// Emit the specified signed value as a VBR. To improve compression we encode
259 /// positive numbers shifted left by 1 and negative numbers negated and shifted
260 /// left by 1 with bit 0 set.
EmitSignedVBRValue(uint64_t Val,raw_ostream & OS)261 static unsigned EmitSignedVBRValue(uint64_t Val, raw_ostream &OS) {
262   if ((int64_t)Val >= 0)
263     Val = Val << 1;
264   else
265     Val = (-Val << 1) | 1;
266 
267   return EmitVBRValue(Val, OS);
268 }
269 
270 // This is expensive and slow.
getIncludePath(const Record * R)271 static std::string getIncludePath(const Record *R) {
272   std::string str;
273   raw_string_ostream Stream(str);
274   auto Locs = R->getLoc();
275   SMLoc L;
276   if (Locs.size() > 1) {
277     // Get where the pattern prototype was instantiated
278     L = Locs[1];
279   } else if (Locs.size() == 1) {
280     L = Locs[0];
281   }
282   unsigned CurBuf = SrcMgr.FindBufferContainingLoc(L);
283   assert(CurBuf && "Invalid or unspecified location!");
284 
285   Stream << SrcMgr.getBufferInfo(CurBuf).Buffer->getBufferIdentifier() << ":"
286          << SrcMgr.FindLineNumber(L, CurBuf);
287   return str;
288 }
289 
290 /// This function traverses the matcher tree and sizes all the nodes
291 /// that are children of the three kinds of nodes that have them.
SizeMatcherList(Matcher * N,raw_ostream & OS)292 unsigned MatcherTableEmitter::SizeMatcherList(Matcher *N, raw_ostream &OS) {
293   unsigned Size = 0;
294   while (N) {
295     Size += SizeMatcher(N, OS);
296     N = N->getNext();
297   }
298   return Size;
299 }
300 
301 /// This function sizes the children of the three kinds of nodes that
302 /// have them. It does so by using special cases for those three
303 /// nodes, but sharing the code in EmitMatcher() for the other kinds.
SizeMatcher(Matcher * N,raw_ostream & OS)304 unsigned MatcherTableEmitter::SizeMatcher(Matcher *N, raw_ostream &OS) {
305   unsigned Idx = 0;
306 
307   ++OpcodeCounts[N->getKind()];
308   switch (N->getKind()) {
309   // The Scope matcher has its kind, a series of child size + child,
310   // and a trailing zero.
311   case Matcher::Scope: {
312     ScopeMatcher *SM = cast<ScopeMatcher>(N);
313     assert(SM->getNext() == nullptr && "Scope matcher should not have next");
314     unsigned Size = 1; // Count the kind.
315     for (unsigned i = 0, e = SM->getNumChildren(); i != e; ++i) {
316       const unsigned ChildSize = SizeMatcherList(SM->getChild(i), OS);
317       assert(ChildSize != 0 && "Matcher cannot have child of size 0");
318       SM->getChild(i)->setSize(ChildSize);
319       Size += GetVBRSize(ChildSize) + ChildSize; // Count VBR and child size.
320     }
321     ++Size; // Count the zero sentinel.
322     return Size;
323   }
324 
325   // SwitchOpcode and SwitchType have their kind, a series of child size +
326   // opcode/type + child, and a trailing zero.
327   case Matcher::SwitchOpcode:
328   case Matcher::SwitchType: {
329     unsigned Size = 1; // Count the kind.
330     unsigned NumCases;
331     if (const SwitchOpcodeMatcher *SOM = dyn_cast<SwitchOpcodeMatcher>(N))
332       NumCases = SOM->getNumCases();
333     else
334       NumCases = cast<SwitchTypeMatcher>(N)->getNumCases();
335     for (unsigned i = 0, e = NumCases; i != e; ++i) {
336       Matcher *Child;
337       if (SwitchOpcodeMatcher *SOM = dyn_cast<SwitchOpcodeMatcher>(N)) {
338         Child = SOM->getCaseMatcher(i);
339         Size += 2; // Count the child's opcode.
340       } else {
341         Child = cast<SwitchTypeMatcher>(N)->getCaseMatcher(i);
342         Size += GetVBRSize(cast<SwitchTypeMatcher>(N)->getCaseType(
343             i)); // Count the child's type.
344       }
345       const unsigned ChildSize = SizeMatcherList(Child, OS);
346       assert(ChildSize != 0 && "Matcher cannot have child of size 0");
347       Child->setSize(ChildSize);
348       Size += GetVBRSize(ChildSize) + ChildSize; // Count VBR and child size.
349     }
350     ++Size; // Count the zero sentinel.
351     return Size;
352   }
353 
354   default:
355     // Employ the matcher emitter to size other matchers.
356     return EmitMatcher(N, 0, Idx, OS);
357   }
358   llvm_unreachable("Unreachable");
359 }
360 
BeginEmitFunction(raw_ostream & OS,StringRef RetType,StringRef Decl,bool AddOverride)361 static void BeginEmitFunction(raw_ostream &OS, StringRef RetType,
362                               StringRef Decl, bool AddOverride) {
363   OS << "#ifdef GET_DAGISEL_DECL\n";
364   OS << RetType << ' ' << Decl;
365   if (AddOverride)
366     OS << " override";
367   OS << ";\n"
368         "#endif\n"
369         "#if defined(GET_DAGISEL_BODY) || DAGISEL_INLINE\n";
370   OS << RetType << " DAGISEL_CLASS_COLONCOLON " << Decl << "\n";
371   if (AddOverride) {
372     OS << "#if DAGISEL_INLINE\n"
373           "  override\n"
374           "#endif\n";
375   }
376 }
377 
EndEmitFunction(raw_ostream & OS)378 static void EndEmitFunction(raw_ostream &OS) {
379   OS << "#endif // GET_DAGISEL_BODY\n\n";
380 }
381 
EmitPatternMatchTable(raw_ostream & OS)382 void MatcherTableEmitter::EmitPatternMatchTable(raw_ostream &OS) {
383 
384   if (!isUInt<32>(VecPatterns.size()))
385     report_fatal_error("More patterns defined that can fit into 32-bit Pattern "
386                        "Table index encoding");
387 
388   assert(VecPatterns.size() == VecIncludeStrings.size() &&
389          "The sizes of Pattern and include vectors should be the same");
390 
391   BeginEmitFunction(OS, "StringRef", "getPatternForIndex(unsigned Index)",
392                     true /*AddOverride*/);
393   OS << "{\n";
394   OS << "static const char *PATTERN_MATCH_TABLE[] = {\n";
395 
396   for (const auto &It : VecPatterns) {
397     OS << "\"" << It.first << "\",\n";
398   }
399 
400   OS << "\n};";
401   OS << "\nreturn StringRef(PATTERN_MATCH_TABLE[Index]);";
402   OS << "\n}\n";
403   EndEmitFunction(OS);
404 
405   BeginEmitFunction(OS, "StringRef", "getIncludePathForIndex(unsigned Index)",
406                     true /*AddOverride*/);
407   OS << "{\n";
408   OS << "static const char *INCLUDE_PATH_TABLE[] = {\n";
409 
410   for (const auto &It : VecIncludeStrings) {
411     OS << "\"" << It << "\",\n";
412   }
413 
414   OS << "\n};";
415   OS << "\nreturn StringRef(INCLUDE_PATH_TABLE[Index]);";
416   OS << "\n}\n";
417   EndEmitFunction(OS);
418 }
419 
420 /// EmitMatcher - Emit bytes for the specified matcher and return
421 /// the number of bytes emitted.
EmitMatcher(const Matcher * N,const unsigned Indent,unsigned CurrentIdx,raw_ostream & OS)422 unsigned MatcherTableEmitter::EmitMatcher(const Matcher *N,
423                                           const unsigned Indent,
424                                           unsigned CurrentIdx,
425                                           raw_ostream &OS) {
426   OS.indent(Indent);
427 
428   switch (N->getKind()) {
429   case Matcher::Scope: {
430     const ScopeMatcher *SM = cast<ScopeMatcher>(N);
431     unsigned StartIdx = CurrentIdx;
432 
433     // Emit all of the children.
434     for (unsigned i = 0, e = SM->getNumChildren(); i != e; ++i) {
435       if (i == 0) {
436         OS << "OPC_Scope, ";
437         ++CurrentIdx;
438       } else {
439         if (!OmitComments) {
440           OS << "/*" << format_decimal(CurrentIdx, IndexWidth) << "*/";
441           OS.indent(Indent) << "/*Scope*/ ";
442         } else {
443           OS.indent(Indent);
444         }
445       }
446 
447       unsigned ChildSize = SM->getChild(i)->getSize();
448       unsigned VBRSize = EmitVBRValue(ChildSize, OS);
449       if (!OmitComments) {
450         OS << "/*->" << CurrentIdx + VBRSize + ChildSize << "*/";
451         if (i == 0)
452           OS << " // " << SM->getNumChildren() << " children in Scope";
453       }
454       OS << '\n';
455 
456       ChildSize = EmitMatcherList(SM->getChild(i), Indent + 1,
457                                   CurrentIdx + VBRSize, OS);
458       assert(ChildSize == SM->getChild(i)->getSize() &&
459              "Emitted child size does not match calculated size");
460       CurrentIdx += VBRSize + ChildSize;
461     }
462 
463     // Emit a zero as a sentinel indicating end of 'Scope'.
464     if (!OmitComments)
465       OS << "/*" << format_decimal(CurrentIdx, IndexWidth) << "*/";
466     OS.indent(Indent) << "0, ";
467     if (!OmitComments)
468       OS << "/*End of Scope*/";
469     OS << '\n';
470     return CurrentIdx - StartIdx + 1;
471   }
472 
473   case Matcher::RecordNode:
474     OS << "OPC_RecordNode,";
475     if (!OmitComments)
476       OS << " // #" << cast<RecordMatcher>(N)->getResultNo() << " = "
477          << cast<RecordMatcher>(N)->getWhatFor();
478     OS << '\n';
479     return 1;
480 
481   case Matcher::RecordChild:
482     OS << "OPC_RecordChild" << cast<RecordChildMatcher>(N)->getChildNo() << ',';
483     if (!OmitComments)
484       OS << " // #" << cast<RecordChildMatcher>(N)->getResultNo() << " = "
485          << cast<RecordChildMatcher>(N)->getWhatFor();
486     OS << '\n';
487     return 1;
488 
489   case Matcher::RecordMemRef:
490     OS << "OPC_RecordMemRef,\n";
491     return 1;
492 
493   case Matcher::CaptureGlueInput:
494     OS << "OPC_CaptureGlueInput,\n";
495     return 1;
496 
497   case Matcher::MoveChild: {
498     const auto *MCM = cast<MoveChildMatcher>(N);
499 
500     OS << "OPC_MoveChild";
501     // Handle the specialized forms.
502     if (MCM->getChildNo() >= 8)
503       OS << ", ";
504     OS << MCM->getChildNo() << ",\n";
505     return (MCM->getChildNo() >= 8) ? 2 : 1;
506   }
507 
508   case Matcher::MoveSibling: {
509     const auto *MSM = cast<MoveSiblingMatcher>(N);
510 
511     OS << "OPC_MoveSibling";
512     // Handle the specialized forms.
513     if (MSM->getSiblingNo() >= 8)
514       OS << ", ";
515     OS << MSM->getSiblingNo() << ",\n";
516     return (MSM->getSiblingNo() >= 8) ? 2 : 1;
517   }
518 
519   case Matcher::MoveParent:
520     OS << "OPC_MoveParent,\n";
521     return 1;
522 
523   case Matcher::CheckSame:
524     OS << "OPC_CheckSame, " << cast<CheckSameMatcher>(N)->getMatchNumber()
525        << ",\n";
526     return 2;
527 
528   case Matcher::CheckChildSame:
529     OS << "OPC_CheckChild" << cast<CheckChildSameMatcher>(N)->getChildNo()
530        << "Same, " << cast<CheckChildSameMatcher>(N)->getMatchNumber() << ",\n";
531     return 2;
532 
533   case Matcher::CheckPatternPredicate: {
534     StringRef Pred = cast<CheckPatternPredicateMatcher>(N)->getPredicate();
535     unsigned PredNo = getPatternPredicate(Pred);
536     if (PredNo > 255)
537       OS << "OPC_CheckPatternPredicateTwoByte, TARGET_VAL(" << PredNo << "),";
538     else if (PredNo < 8)
539       OS << "OPC_CheckPatternPredicate" << PredNo << ',';
540     else
541       OS << "OPC_CheckPatternPredicate, " << PredNo << ',';
542     if (!OmitComments)
543       OS << " // " << Pred;
544     OS << '\n';
545     return 2 + (PredNo > 255) - (PredNo < 8);
546   }
547   case Matcher::CheckPredicate: {
548     TreePredicateFn Pred = cast<CheckPredicateMatcher>(N)->getPredicate();
549     unsigned OperandBytes = 0;
550     unsigned PredNo = getNodePredicate(Pred);
551 
552     if (Pred.usesOperands()) {
553       unsigned NumOps = cast<CheckPredicateMatcher>(N)->getNumOperands();
554       OS << "OPC_CheckPredicateWithOperands, " << NumOps << "/*#Ops*/, ";
555       for (unsigned i = 0; i < NumOps; ++i)
556         OS << cast<CheckPredicateMatcher>(N)->getOperandNo(i) << ", ";
557       OperandBytes = 1 + NumOps;
558     } else {
559       if (PredNo < 8) {
560         OperandBytes = -1;
561         OS << "OPC_CheckPredicate" << PredNo << ", ";
562       } else {
563         OS << "OPC_CheckPredicate, ";
564       }
565     }
566 
567     if (PredNo >= 8 || Pred.usesOperands())
568       OS << PredNo << ',';
569     if (!OmitComments)
570       OS << " // " << Pred.getFnName();
571     OS << '\n';
572     return 2 + OperandBytes;
573   }
574 
575   case Matcher::CheckOpcode:
576     OS << "OPC_CheckOpcode, TARGET_VAL("
577        << cast<CheckOpcodeMatcher>(N)->getOpcode().getEnumName() << "),\n";
578     return 3;
579 
580   case Matcher::SwitchOpcode:
581   case Matcher::SwitchType: {
582     unsigned StartIdx = CurrentIdx;
583 
584     unsigned NumCases;
585     if (const SwitchOpcodeMatcher *SOM = dyn_cast<SwitchOpcodeMatcher>(N)) {
586       OS << "OPC_SwitchOpcode ";
587       NumCases = SOM->getNumCases();
588     } else {
589       OS << "OPC_SwitchType ";
590       NumCases = cast<SwitchTypeMatcher>(N)->getNumCases();
591     }
592 
593     if (!OmitComments)
594       OS << "/*" << NumCases << " cases */";
595     OS << ", ";
596     ++CurrentIdx;
597 
598     // For each case we emit the size, then the opcode, then the matcher.
599     for (unsigned i = 0, e = NumCases; i != e; ++i) {
600       const Matcher *Child;
601       unsigned IdxSize;
602       if (const SwitchOpcodeMatcher *SOM = dyn_cast<SwitchOpcodeMatcher>(N)) {
603         Child = SOM->getCaseMatcher(i);
604         IdxSize = 2; // size of opcode in table is 2 bytes.
605       } else {
606         Child = cast<SwitchTypeMatcher>(N)->getCaseMatcher(i);
607         IdxSize = GetVBRSize(cast<SwitchTypeMatcher>(N)->getCaseType(
608             i)); // size of type in table is sizeof(VBR(MVT)) byte.
609       }
610 
611       if (i != 0) {
612         if (!OmitComments)
613           OS << "/*" << format_decimal(CurrentIdx, IndexWidth) << "*/";
614         OS.indent(Indent);
615         if (!OmitComments)
616           OS << (isa<SwitchOpcodeMatcher>(N) ? "/*SwitchOpcode*/ "
617                                              : "/*SwitchType*/ ");
618       }
619 
620       unsigned ChildSize = Child->getSize();
621       CurrentIdx += EmitVBRValue(ChildSize, OS) + IdxSize;
622       if (const SwitchOpcodeMatcher *SOM = dyn_cast<SwitchOpcodeMatcher>(N))
623         OS << "TARGET_VAL(" << SOM->getCaseOpcode(i).getEnumName() << "),";
624       else {
625         if (!OmitComments)
626           OS << "/*" << getEnumName(cast<SwitchTypeMatcher>(N)->getCaseType(i))
627              << "*/";
628         EmitVBRValue(cast<SwitchTypeMatcher>(N)->getCaseType(i),
629                      OS);
630       }
631       if (!OmitComments)
632         OS << "// ->" << CurrentIdx + ChildSize;
633       OS << '\n';
634 
635       ChildSize = EmitMatcherList(Child, Indent + 1, CurrentIdx, OS);
636       assert(ChildSize == Child->getSize() &&
637              "Emitted child size does not match calculated size");
638       CurrentIdx += ChildSize;
639     }
640 
641     // Emit the final zero to terminate the switch.
642     if (!OmitComments)
643       OS << "/*" << format_decimal(CurrentIdx, IndexWidth) << "*/";
644     OS.indent(Indent) << "0,";
645     if (!OmitComments)
646       OS << (isa<SwitchOpcodeMatcher>(N) ? " // EndSwitchOpcode"
647                                          : " // EndSwitchType");
648 
649     OS << '\n';
650     return CurrentIdx - StartIdx + 1;
651   }
652 
653   case Matcher::CheckType: {
654     if (cast<CheckTypeMatcher>(N)->getResNo() == 0) {
655       MVT::SimpleValueType VT = cast<CheckTypeMatcher>(N)->getType();
656       switch (VT) {
657       case MVT::i32:
658       case MVT::i64:
659         OS << "OPC_CheckTypeI" << MVT(VT).getSizeInBits() << ",\n";
660         return 1;
661       default:
662         OS << "OPC_CheckType, ";
663         if (!OmitComments)
664           OS << "/*" << getEnumName(VT) << "*/";
665         unsigned NumBytes = EmitVBRValue(VT, OS);
666         OS << "\n";
667         return NumBytes + 1;
668       }
669     }
670     OS << "OPC_CheckTypeRes, " << cast<CheckTypeMatcher>(N)->getResNo() << ", ";
671     if (!OmitComments)
672       OS << "/*" << getEnumName(cast<CheckTypeMatcher>(N)->getType()) << "*/";
673     unsigned NumBytes = EmitVBRValue(cast<CheckTypeMatcher>(N)->getType(), OS);
674     OS << "\n";
675     return NumBytes + 2;
676   }
677 
678   case Matcher::CheckChildType: {
679     MVT::SimpleValueType VT = cast<CheckChildTypeMatcher>(N)->getType();
680     switch (VT) {
681     case MVT::i32:
682     case MVT::i64:
683       OS << "OPC_CheckChild" << cast<CheckChildTypeMatcher>(N)->getChildNo()
684          << "TypeI" << MVT(VT).getSizeInBits() << ",\n";
685       return 1;
686     default:
687       OS << "OPC_CheckChild" << cast<CheckChildTypeMatcher>(N)->getChildNo()
688          << "Type, ";
689       if (!OmitComments)
690         OS << "/*" << getEnumName(VT) << "*/";
691       unsigned NumBytes = EmitVBRValue(VT, OS);
692       OS << "\n";
693       return NumBytes + 1;
694     }
695   }
696 
697   case Matcher::CheckInteger: {
698     OS << "OPC_CheckInteger, ";
699     unsigned Bytes =
700         1 + EmitSignedVBRValue(cast<CheckIntegerMatcher>(N)->getValue(), OS);
701     OS << '\n';
702     return Bytes;
703   }
704   case Matcher::CheckChildInteger: {
705     OS << "OPC_CheckChild" << cast<CheckChildIntegerMatcher>(N)->getChildNo()
706        << "Integer, ";
707     unsigned Bytes = 1 + EmitSignedVBRValue(
708                              cast<CheckChildIntegerMatcher>(N)->getValue(), OS);
709     OS << '\n';
710     return Bytes;
711   }
712   case Matcher::CheckCondCode:
713     OS << "OPC_CheckCondCode, ISD::"
714        << cast<CheckCondCodeMatcher>(N)->getCondCodeName() << ",\n";
715     return 2;
716 
717   case Matcher::CheckChild2CondCode:
718     OS << "OPC_CheckChild2CondCode, ISD::"
719        << cast<CheckChild2CondCodeMatcher>(N)->getCondCodeName() << ",\n";
720     return 2;
721 
722   case Matcher::CheckValueType: {
723     OS << "OPC_CheckValueType, ";
724     if (!OmitComments)
725       OS << "/*" << getEnumName(cast<CheckValueTypeMatcher>(N)->getVT())
726          << "*/";
727     unsigned NumBytes =
728         EmitVBRValue(cast<CheckValueTypeMatcher>(N)->getVT(), OS);
729     OS << "\n";
730     return NumBytes + 1;
731   }
732 
733   case Matcher::CheckComplexPat: {
734     const CheckComplexPatMatcher *CCPM = cast<CheckComplexPatMatcher>(N);
735     const ComplexPattern &Pattern = CCPM->getPattern();
736     unsigned PatternNo = getComplexPat(Pattern);
737     if (PatternNo < 8)
738       OS << "OPC_CheckComplexPat" << PatternNo << ", /*#*/"
739          << CCPM->getMatchNumber() << ',';
740     else
741       OS << "OPC_CheckComplexPat, /*CP*/" << PatternNo << ", /*#*/"
742          << CCPM->getMatchNumber() << ',';
743 
744     if (!OmitComments) {
745       OS << " // " << Pattern.getSelectFunc();
746       OS << ":$" << CCPM->getName();
747       for (unsigned i = 0, e = Pattern.getNumOperands(); i != e; ++i)
748         OS << " #" << CCPM->getFirstResult() + i;
749 
750       if (Pattern.hasProperty(SDNPHasChain))
751         OS << " + chain result";
752     }
753     OS << '\n';
754     return PatternNo < 8 ? 2 : 3;
755   }
756 
757   case Matcher::CheckAndImm: {
758     OS << "OPC_CheckAndImm, ";
759     unsigned Bytes =
760         1 + EmitVBRValue(cast<CheckAndImmMatcher>(N)->getValue(), OS);
761     OS << '\n';
762     return Bytes;
763   }
764 
765   case Matcher::CheckOrImm: {
766     OS << "OPC_CheckOrImm, ";
767     unsigned Bytes =
768         1 + EmitVBRValue(cast<CheckOrImmMatcher>(N)->getValue(), OS);
769     OS << '\n';
770     return Bytes;
771   }
772 
773   case Matcher::CheckFoldableChainNode:
774     OS << "OPC_CheckFoldableChainNode,\n";
775     return 1;
776 
777   case Matcher::CheckImmAllOnesV:
778     OS << "OPC_CheckImmAllOnesV,\n";
779     return 1;
780 
781   case Matcher::CheckImmAllZerosV:
782     OS << "OPC_CheckImmAllZerosV,\n";
783     return 1;
784 
785   case Matcher::EmitInteger: {
786     int64_t Val = cast<EmitIntegerMatcher>(N)->getValue();
787     MVT::SimpleValueType VT = cast<EmitIntegerMatcher>(N)->getVT();
788     unsigned OpBytes;
789     switch (VT) {
790     case MVT::i8:
791     case MVT::i16:
792     case MVT::i32:
793     case MVT::i64:
794       OpBytes = 1;
795       OS << "OPC_EmitInteger" << MVT(VT).getSizeInBits() << ", ";
796       break;
797     default:
798       OS << "OPC_EmitInteger, ";
799       if (!OmitComments)
800         OS << "/*" << getEnumName(VT) << "*/";
801       OpBytes = EmitVBRValue(VT, OS) + 1;
802       break;
803     }
804     unsigned Bytes = OpBytes + EmitSignedVBRValue(Val, OS);
805     if (!OmitComments)
806       OS << " // " << Val << " #" << cast<EmitIntegerMatcher>(N)->getResultNo();
807     OS << '\n';
808     return Bytes;
809   }
810   case Matcher::EmitStringInteger: {
811     const std::string &Val = cast<EmitStringIntegerMatcher>(N)->getValue();
812     MVT::SimpleValueType VT = cast<EmitStringIntegerMatcher>(N)->getVT();
813     // These should always fit into 7 bits.
814     unsigned OpBytes;
815     switch (VT) {
816     case MVT::i32:
817       OpBytes = 1;
818       OS << "OPC_EmitStringInteger" << MVT(VT).getSizeInBits() << ", ";
819       break;
820     default:
821       OS << "OPC_EmitStringInteger, ";
822       if (!OmitComments)
823         OS << "/*" << getEnumName(VT) << "*/";
824       OpBytes = EmitVBRValue(VT, OS) + 1;
825       break;
826     }
827     OS << Val << ',';
828     if (!OmitComments)
829       OS << " // #" << cast<EmitStringIntegerMatcher>(N)->getResultNo();
830     OS << '\n';
831     return OpBytes + 1;
832   }
833 
834   case Matcher::EmitRegister: {
835     const EmitRegisterMatcher *Matcher = cast<EmitRegisterMatcher>(N);
836     const CodeGenRegister *Reg = Matcher->getReg();
837     MVT::SimpleValueType VT = Matcher->getVT();
838     unsigned OpBytes;
839     // If the enum value of the register is larger than one byte can handle,
840     // use EmitRegister2.
841     if (Reg && Reg->EnumValue > 255) {
842       OS << "OPC_EmitRegister2, ";
843       if (!OmitComments)
844         OS << "/*" << getEnumName(VT) << "*/";
845       OpBytes = EmitVBRValue(VT, OS);
846       OS << "TARGET_VAL(" << getQualifiedName(Reg->TheDef) << "),\n";
847       return OpBytes + 3;
848     }
849     switch (VT) {
850     case MVT::i32:
851     case MVT::i64:
852       OpBytes = 1;
853       OS << "OPC_EmitRegisterI" << MVT(VT).getSizeInBits() << ", ";
854       break;
855     default:
856       OS << "OPC_EmitRegister, ";
857       if (!OmitComments)
858         OS << "/*" << getEnumName(VT) << "*/";
859       OpBytes = EmitVBRValue(VT, OS) + 1;
860       break;
861     }
862     if (Reg) {
863       OS << getQualifiedName(Reg->TheDef);
864     } else {
865       OS << "0 ";
866       if (!OmitComments)
867         OS << "/*zero_reg*/";
868     }
869 
870     OS << ',';
871     if (!OmitComments)
872       OS << " // #" << Matcher->getResultNo();
873     OS << '\n';
874     return OpBytes + 1;
875   }
876 
877   case Matcher::EmitConvertToTarget: {
878     const auto *CTTM = cast<EmitConvertToTargetMatcher>(N);
879     unsigned Slot = CTTM->getSlot();
880     OS << "OPC_EmitConvertToTarget";
881     if (Slot >= 8)
882       OS << ", ";
883     OS << Slot << ',';
884     if (!OmitComments)
885       OS << " // #" << CTTM->getResultNo();
886     OS << '\n';
887     return 1 + (Slot >= 8);
888   }
889 
890   case Matcher::EmitMergeInputChains: {
891     const EmitMergeInputChainsMatcher *MN =
892         cast<EmitMergeInputChainsMatcher>(N);
893 
894     // Handle the specialized forms OPC_EmitMergeInputChains1_0, 1_1, and 1_2.
895     if (MN->getNumNodes() == 1 && MN->getNode(0) < 3) {
896       OS << "OPC_EmitMergeInputChains1_" << MN->getNode(0) << ",\n";
897       return 1;
898     }
899 
900     OS << "OPC_EmitMergeInputChains, " << MN->getNumNodes() << ", ";
901     for (unsigned i = 0, e = MN->getNumNodes(); i != e; ++i)
902       OS << MN->getNode(i) << ", ";
903     OS << '\n';
904     return 2 + MN->getNumNodes();
905   }
906   case Matcher::EmitCopyToReg: {
907     const auto *C2RMatcher = cast<EmitCopyToRegMatcher>(N);
908     int Bytes = 3;
909     const CodeGenRegister *Reg = C2RMatcher->getDestPhysReg();
910     unsigned Slot = C2RMatcher->getSrcSlot();
911     if (Reg->EnumValue > 255) {
912       assert(isUInt<16>(Reg->EnumValue) && "not handled");
913       OS << "OPC_EmitCopyToRegTwoByte, " << Slot << ", "
914          << "TARGET_VAL(" << getQualifiedName(Reg->TheDef) << "),\n";
915       ++Bytes;
916     } else {
917       if (Slot < 8) {
918         OS << "OPC_EmitCopyToReg" << Slot << ", "
919            << getQualifiedName(Reg->TheDef) << ",\n";
920         --Bytes;
921       } else {
922         OS << "OPC_EmitCopyToReg, " << Slot << ", "
923            << getQualifiedName(Reg->TheDef) << ",\n";
924       }
925     }
926 
927     return Bytes;
928   }
929   case Matcher::EmitNodeXForm: {
930     const EmitNodeXFormMatcher *XF = cast<EmitNodeXFormMatcher>(N);
931     OS << "OPC_EmitNodeXForm, " << getNodeXFormID(XF->getNodeXForm()) << ", "
932        << XF->getSlot() << ',';
933     if (!OmitComments)
934       OS << " // " << XF->getNodeXForm()->getName() << " #"
935          << XF->getResultNo();
936     OS << '\n';
937     return 3;
938   }
939 
940   case Matcher::EmitNode:
941   case Matcher::MorphNodeTo: {
942     auto NumCoveredBytes = 0;
943     if (InstrumentCoverage) {
944       if (const MorphNodeToMatcher *SNT = dyn_cast<MorphNodeToMatcher>(N)) {
945         NumCoveredBytes = 3;
946         OS << "OPC_Coverage, ";
947         std::string src =
948             GetPatFromTreePatternNode(SNT->getPattern().getSrcPattern());
949         std::string dst =
950             GetPatFromTreePatternNode(SNT->getPattern().getDstPattern());
951         const Record *PatRecord = SNT->getPattern().getSrcRecord();
952         std::string include_src = getIncludePath(PatRecord);
953         unsigned Offset =
954             getPatternIdxFromTable(src + " -> " + dst, std::move(include_src));
955         OS << "COVERAGE_IDX_VAL(" << Offset << "),\n";
956         OS.indent(FullIndexWidth + Indent);
957       }
958     }
959     const EmitNodeMatcherCommon *EN = cast<EmitNodeMatcherCommon>(N);
960     bool IsEmitNode = isa<EmitNodeMatcher>(EN);
961     OS << (IsEmitNode ? "OPC_EmitNode" : "OPC_MorphNodeTo");
962     bool CompressVTs = EN->getNumVTs() < 3;
963     bool CompressNodeInfo = false;
964     if (CompressVTs) {
965       OS << EN->getNumVTs();
966       if (!EN->hasChain() && !EN->hasInGlue() && !EN->hasOutGlue() &&
967           !EN->hasMemRefs() && EN->getNumFixedArityOperands() == -1) {
968         CompressNodeInfo = true;
969         OS << "None";
970       } else if (EN->hasChain() && !EN->hasInGlue() && !EN->hasOutGlue() &&
971                  !EN->hasMemRefs() && EN->getNumFixedArityOperands() == -1) {
972         CompressNodeInfo = true;
973         OS << "Chain";
974       } else if (!IsEmitNode && !EN->hasChain() && EN->hasInGlue() &&
975                  !EN->hasOutGlue() && !EN->hasMemRefs() &&
976                  EN->getNumFixedArityOperands() == -1) {
977         CompressNodeInfo = true;
978         OS << "GlueInput";
979       } else if (!IsEmitNode && !EN->hasChain() && !EN->hasInGlue() &&
980                  EN->hasOutGlue() && !EN->hasMemRefs() &&
981                  EN->getNumFixedArityOperands() == -1) {
982         CompressNodeInfo = true;
983         OS << "GlueOutput";
984       }
985     }
986 
987     const CodeGenInstruction &CGI = EN->getInstruction();
988     OS << ", TARGET_VAL(" << CGI.Namespace << "::" << CGI.TheDef->getName()
989        << ")";
990 
991     if (!CompressNodeInfo) {
992       OS << ", 0";
993       if (EN->hasChain())
994         OS << "|OPFL_Chain";
995       if (EN->hasInGlue())
996         OS << "|OPFL_GlueInput";
997       if (EN->hasOutGlue())
998         OS << "|OPFL_GlueOutput";
999       if (EN->hasMemRefs())
1000         OS << "|OPFL_MemRefs";
1001       if (EN->getNumFixedArityOperands() != -1)
1002         OS << "|OPFL_Variadic" << EN->getNumFixedArityOperands();
1003     }
1004     OS << ",\n";
1005 
1006     OS.indent(FullIndexWidth + Indent + 4);
1007     if (!CompressVTs) {
1008       OS << EN->getNumVTs();
1009       if (!OmitComments)
1010         OS << "/*#VTs*/";
1011       OS << ", ";
1012     }
1013     unsigned NumTypeBytes = 0;
1014     for (unsigned i = 0, e = EN->getNumVTs(); i != e; ++i) {
1015       if (!OmitComments)
1016         OS << "/*" << getEnumName(EN->getVT(i)) << "*/";
1017       NumTypeBytes += EmitVBRValue(EN->getVT(i), OS);
1018     }
1019 
1020     OS << EN->getNumOperands();
1021     if (!OmitComments)
1022       OS << "/*#Ops*/";
1023     OS << ", ";
1024     unsigned NumOperandBytes = 0;
1025     for (unsigned i = 0, e = EN->getNumOperands(); i != e; ++i)
1026       NumOperandBytes += EmitVBRValue(EN->getOperand(i), OS);
1027 
1028     if (!OmitComments) {
1029       // Print the result #'s for EmitNode.
1030       if (const EmitNodeMatcher *E = dyn_cast<EmitNodeMatcher>(EN)) {
1031         if (unsigned NumResults = EN->getNumVTs()) {
1032           OS << " // Results =";
1033           unsigned First = E->getFirstResultSlot();
1034           for (unsigned i = 0; i != NumResults; ++i)
1035             OS << " #" << First + i;
1036         }
1037       }
1038       OS << '\n';
1039 
1040       if (const MorphNodeToMatcher *SNT = dyn_cast<MorphNodeToMatcher>(N)) {
1041         OS.indent(FullIndexWidth + Indent)
1042             << "// Src: " << SNT->getPattern().getSrcPattern()
1043             << " - Complexity = " << SNT->getPattern().getPatternComplexity(CGP)
1044             << '\n';
1045         OS.indent(FullIndexWidth + Indent)
1046             << "// Dst: " << SNT->getPattern().getDstPattern() << '\n';
1047       }
1048     } else {
1049       OS << '\n';
1050     }
1051 
1052     return 4 + !CompressVTs + !CompressNodeInfo + NumTypeBytes +
1053            NumOperandBytes + NumCoveredBytes;
1054   }
1055   case Matcher::CompleteMatch: {
1056     const CompleteMatchMatcher *CM = cast<CompleteMatchMatcher>(N);
1057     auto NumCoveredBytes = 0;
1058     if (InstrumentCoverage) {
1059       NumCoveredBytes = 3;
1060       OS << "OPC_Coverage, ";
1061       std::string src =
1062           GetPatFromTreePatternNode(CM->getPattern().getSrcPattern());
1063       std::string dst =
1064           GetPatFromTreePatternNode(CM->getPattern().getDstPattern());
1065       const Record *PatRecord = CM->getPattern().getSrcRecord();
1066       std::string include_src = getIncludePath(PatRecord);
1067       unsigned Offset =
1068           getPatternIdxFromTable(src + " -> " + dst, std::move(include_src));
1069       OS << "COVERAGE_IDX_VAL(" << Offset << "),\n";
1070       OS.indent(FullIndexWidth + Indent);
1071     }
1072     OS << "OPC_CompleteMatch, " << CM->getNumResults() << ", ";
1073     unsigned NumResultBytes = 0;
1074     for (unsigned i = 0, e = CM->getNumResults(); i != e; ++i)
1075       NumResultBytes += EmitVBRValue(CM->getResult(i), OS);
1076     OS << '\n';
1077     if (!OmitComments) {
1078       OS.indent(FullIndexWidth + Indent)
1079           << " // Src: " << CM->getPattern().getSrcPattern()
1080           << " - Complexity = " << CM->getPattern().getPatternComplexity(CGP)
1081           << '\n';
1082       OS.indent(FullIndexWidth + Indent)
1083           << " // Dst: " << CM->getPattern().getDstPattern();
1084     }
1085     OS << '\n';
1086     return 2 + NumResultBytes + NumCoveredBytes;
1087   }
1088   }
1089   llvm_unreachable("Unreachable");
1090 }
1091 
1092 /// This function traverses the matcher tree and emits all the nodes.
1093 /// The nodes have already been sized.
EmitMatcherList(const Matcher * N,const unsigned Indent,unsigned CurrentIdx,raw_ostream & OS)1094 unsigned MatcherTableEmitter::EmitMatcherList(const Matcher *N,
1095                                               const unsigned Indent,
1096                                               unsigned CurrentIdx,
1097                                               raw_ostream &OS) {
1098   unsigned Size = 0;
1099   while (N) {
1100     if (!OmitComments)
1101       OS << "/*" << format_decimal(CurrentIdx, IndexWidth) << "*/";
1102     unsigned MatcherSize = EmitMatcher(N, Indent, CurrentIdx, OS);
1103     Size += MatcherSize;
1104     CurrentIdx += MatcherSize;
1105 
1106     // If there are other nodes in this list, iterate to them, otherwise we're
1107     // done.
1108     N = N->getNext();
1109   }
1110   return Size;
1111 }
1112 
EmitNodePredicatesFunction(const std::vector<TreePattern * > & Preds,StringRef Decl,raw_ostream & OS)1113 void MatcherTableEmitter::EmitNodePredicatesFunction(
1114     const std::vector<TreePattern *> &Preds, StringRef Decl, raw_ostream &OS) {
1115   if (Preds.empty())
1116     return;
1117 
1118   BeginEmitFunction(OS, "bool", Decl, true /*AddOverride*/);
1119   OS << "{\n";
1120   OS << "  switch (PredNo) {\n";
1121   OS << "  default: llvm_unreachable(\"Invalid predicate in table?\");\n";
1122   for (unsigned i = 0, e = Preds.size(); i != e; ++i) {
1123     // Emit the predicate code corresponding to this pattern.
1124     TreePredicateFn PredFn(Preds[i]);
1125     assert(!PredFn.isAlwaysTrue() && "No code in this predicate");
1126     std::string PredFnCodeStr = PredFn.getCodeToRunOnSDNode();
1127 
1128     OS << "  case " << i << ": {\n";
1129     for (auto *SimilarPred : NodePredicatesByCodeToRun[PredFnCodeStr])
1130       OS << "    // " << TreePredicateFn(SimilarPred).getFnName() << '\n';
1131     OS << PredFnCodeStr << "\n  }\n";
1132   }
1133   OS << "  }\n";
1134   OS << "}\n";
1135   EndEmitFunction(OS);
1136 }
1137 
EmitPredicateFunctions(raw_ostream & OS)1138 void MatcherTableEmitter::EmitPredicateFunctions(raw_ostream &OS) {
1139   // Emit pattern predicates.
1140   if (!PatternPredicates.empty()) {
1141     BeginEmitFunction(OS, "bool",
1142                       "CheckPatternPredicate(unsigned PredNo) const",
1143                       true /*AddOverride*/);
1144     OS << "{\n";
1145     OS << "  switch (PredNo) {\n";
1146     OS << "  default: llvm_unreachable(\"Invalid predicate in table?\");\n";
1147     for (unsigned i = 0, e = PatternPredicates.size(); i != e; ++i)
1148       OS << "  case " << i << ": return " << PatternPredicates[i] << ";\n";
1149     OS << "  }\n";
1150     OS << "}\n";
1151     EndEmitFunction(OS);
1152   }
1153 
1154   // Emit Node predicates.
1155   EmitNodePredicatesFunction(
1156       NodePredicates, "CheckNodePredicate(SDValue Op, unsigned PredNo) const",
1157       OS);
1158   EmitNodePredicatesFunction(
1159       NodePredicatesWithOperands,
1160       "CheckNodePredicateWithOperands(SDValue Op, unsigned PredNo, "
1161       "ArrayRef<SDValue> Operands) const",
1162       OS);
1163 
1164   // Emit CompletePattern matchers.
1165   // FIXME: This should be const.
1166   if (!ComplexPatterns.empty()) {
1167     BeginEmitFunction(
1168         OS, "bool",
1169         "CheckComplexPattern(SDNode *Root, SDNode *Parent,\n"
1170         "      SDValue N, unsigned PatternNo,\n"
1171         "      SmallVectorImpl<std::pair<SDValue, SDNode *>> &Result)",
1172         true /*AddOverride*/);
1173     OS << "{\n";
1174     OS << "  unsigned NextRes = Result.size();\n";
1175     OS << "  switch (PatternNo) {\n";
1176     OS << "  default: llvm_unreachable(\"Invalid pattern # in table?\");\n";
1177     for (unsigned i = 0, e = ComplexPatterns.size(); i != e; ++i) {
1178       const ComplexPattern &P = *ComplexPatterns[i];
1179       unsigned NumOps = P.getNumOperands();
1180 
1181       if (P.hasProperty(SDNPHasChain))
1182         ++NumOps; // Get the chained node too.
1183 
1184       OS << "  case " << i << ":\n";
1185       if (InstrumentCoverage)
1186         OS << "  {\n";
1187       OS << "    Result.resize(NextRes+" << NumOps << ");\n";
1188       if (InstrumentCoverage)
1189         OS << "    bool Succeeded = " << P.getSelectFunc();
1190       else
1191         OS << "  return " << P.getSelectFunc();
1192 
1193       OS << "(";
1194       // If the complex pattern wants the root of the match, pass it in as the
1195       // first argument.
1196       if (P.wantsRoot())
1197         OS << "Root, ";
1198 
1199       // If the complex pattern wants the parent of the operand being matched,
1200       // pass it in as the next argument.
1201       if (P.wantsParent())
1202         OS << "Parent, ";
1203 
1204       OS << "N";
1205       for (unsigned i = 0; i != NumOps; ++i)
1206         OS << ", Result[NextRes+" << i << "].first";
1207       OS << ");\n";
1208       if (InstrumentCoverage) {
1209         OS << "    if (Succeeded)\n";
1210         OS << "       dbgs() << \"\\nCOMPLEX_PATTERN: " << P.getSelectFunc()
1211            << "\\n\" ;\n";
1212         OS << "    return Succeeded;\n";
1213         OS << "    }\n";
1214       }
1215     }
1216     OS << "  }\n";
1217     OS << "}\n";
1218     EndEmitFunction(OS);
1219   }
1220 
1221   // Emit SDNodeXForm handlers.
1222   // FIXME: This should be const.
1223   if (!NodeXForms.empty()) {
1224     BeginEmitFunction(OS, "SDValue",
1225                       "RunSDNodeXForm(SDValue V, unsigned XFormNo)",
1226                       true /*AddOverride*/);
1227     OS << "{\n";
1228     OS << "  switch (XFormNo) {\n";
1229     OS << "  default: llvm_unreachable(\"Invalid xform # in table?\");\n";
1230 
1231     // FIXME: The node xform could take SDValue's instead of SDNode*'s.
1232     for (unsigned i = 0, e = NodeXForms.size(); i != e; ++i) {
1233       const CodeGenDAGPatterns::NodeXForm &Entry =
1234           CGP.getSDNodeTransform(NodeXForms[i]);
1235 
1236       const Record *SDNode = Entry.first;
1237       const std::string &Code = Entry.second;
1238 
1239       OS << "  case " << i << ": {  ";
1240       if (!OmitComments)
1241         OS << "// " << NodeXForms[i]->getName();
1242       OS << '\n';
1243 
1244       std::string ClassName = CGP.getSDNodeInfo(SDNode).getSDClassName().str();
1245       if (ClassName == "SDNode")
1246         OS << "    SDNode *N = V.getNode();\n";
1247       else
1248         OS << "    " << ClassName << " *N = cast<" << ClassName
1249            << ">(V.getNode());\n";
1250       OS << Code << "\n  }\n";
1251     }
1252     OS << "  }\n";
1253     OS << "}\n";
1254     EndEmitFunction(OS);
1255   }
1256 }
1257 
getOpcodeString(Matcher::KindTy Kind)1258 static StringRef getOpcodeString(Matcher::KindTy Kind) {
1259   switch (Kind) {
1260   case Matcher::Scope:
1261     return "OPC_Scope";
1262   case Matcher::RecordNode:
1263     return "OPC_RecordNode";
1264   case Matcher::RecordChild:
1265     return "OPC_RecordChild";
1266   case Matcher::RecordMemRef:
1267     return "OPC_RecordMemRef";
1268   case Matcher::CaptureGlueInput:
1269     return "OPC_CaptureGlueInput";
1270   case Matcher::MoveChild:
1271     return "OPC_MoveChild";
1272   case Matcher::MoveSibling:
1273     return "OPC_MoveSibling";
1274   case Matcher::MoveParent:
1275     return "OPC_MoveParent";
1276   case Matcher::CheckSame:
1277     return "OPC_CheckSame";
1278   case Matcher::CheckChildSame:
1279     return "OPC_CheckChildSame";
1280   case Matcher::CheckPatternPredicate:
1281     return "OPC_CheckPatternPredicate";
1282   case Matcher::CheckPredicate:
1283     return "OPC_CheckPredicate";
1284   case Matcher::CheckOpcode:
1285     return "OPC_CheckOpcode";
1286   case Matcher::SwitchOpcode:
1287     return "OPC_SwitchOpcode";
1288   case Matcher::CheckType:
1289     return "OPC_CheckType";
1290   case Matcher::SwitchType:
1291     return "OPC_SwitchType";
1292   case Matcher::CheckChildType:
1293     return "OPC_CheckChildType";
1294   case Matcher::CheckInteger:
1295     return "OPC_CheckInteger";
1296   case Matcher::CheckChildInteger:
1297     return "OPC_CheckChildInteger";
1298   case Matcher::CheckCondCode:
1299     return "OPC_CheckCondCode";
1300   case Matcher::CheckChild2CondCode:
1301     return "OPC_CheckChild2CondCode";
1302   case Matcher::CheckValueType:
1303     return "OPC_CheckValueType";
1304   case Matcher::CheckComplexPat:
1305     return "OPC_CheckComplexPat";
1306   case Matcher::CheckAndImm:
1307     return "OPC_CheckAndImm";
1308   case Matcher::CheckOrImm:
1309     return "OPC_CheckOrImm";
1310   case Matcher::CheckFoldableChainNode:
1311     return "OPC_CheckFoldableChainNode";
1312   case Matcher::CheckImmAllOnesV:
1313     return "OPC_CheckImmAllOnesV";
1314   case Matcher::CheckImmAllZerosV:
1315     return "OPC_CheckImmAllZerosV";
1316   case Matcher::EmitInteger:
1317     return "OPC_EmitInteger";
1318   case Matcher::EmitStringInteger:
1319     return "OPC_EmitStringInteger";
1320   case Matcher::EmitRegister:
1321     return "OPC_EmitRegister";
1322   case Matcher::EmitConvertToTarget:
1323     return "OPC_EmitConvertToTarget";
1324   case Matcher::EmitMergeInputChains:
1325     return "OPC_EmitMergeInputChains";
1326   case Matcher::EmitCopyToReg:
1327     return "OPC_EmitCopyToReg";
1328   case Matcher::EmitNode:
1329     return "OPC_EmitNode";
1330   case Matcher::MorphNodeTo:
1331     return "OPC_MorphNodeTo";
1332   case Matcher::EmitNodeXForm:
1333     return "OPC_EmitNodeXForm";
1334   case Matcher::CompleteMatch:
1335     return "OPC_CompleteMatch";
1336   }
1337 
1338   llvm_unreachable("Unhandled opcode?");
1339 }
1340 
EmitHistogram(const Matcher * M,raw_ostream & OS)1341 void MatcherTableEmitter::EmitHistogram(const Matcher *M, raw_ostream &OS) {
1342   if (OmitComments)
1343     return;
1344 
1345   OS << "  // Opcode Histogram:\n";
1346   for (unsigned i = 0, e = OpcodeCounts.size(); i != e; ++i) {
1347     OS << "  // #"
1348        << left_justify(getOpcodeString((Matcher::KindTy)i), HistOpcWidth)
1349        << " = " << OpcodeCounts[i] << '\n';
1350   }
1351   OS << '\n';
1352 }
1353 
EmitMatcherTable(Matcher * TheMatcher,const CodeGenDAGPatterns & CGP,raw_ostream & OS)1354 void llvm::EmitMatcherTable(Matcher *TheMatcher, const CodeGenDAGPatterns &CGP,
1355                             raw_ostream &OS) {
1356   OS << "#if defined(GET_DAGISEL_DECL) && defined(GET_DAGISEL_BODY)\n";
1357   OS << "#error GET_DAGISEL_DECL and GET_DAGISEL_BODY cannot be both defined, ";
1358   OS << "undef both for inline definitions\n";
1359   OS << "#endif\n\n";
1360 
1361   // Emit a check for omitted class name.
1362   OS << "#ifdef GET_DAGISEL_BODY\n";
1363   OS << "#define LOCAL_DAGISEL_STRINGIZE(X) LOCAL_DAGISEL_STRINGIZE_(X)\n";
1364   OS << "#define LOCAL_DAGISEL_STRINGIZE_(X) #X\n";
1365   OS << "static_assert(sizeof(LOCAL_DAGISEL_STRINGIZE(GET_DAGISEL_BODY)) > 1,"
1366         "\n";
1367   OS << "   \"GET_DAGISEL_BODY is empty: it should be defined with the class "
1368         "name\");\n";
1369   OS << "#undef LOCAL_DAGISEL_STRINGIZE_\n";
1370   OS << "#undef LOCAL_DAGISEL_STRINGIZE\n";
1371   OS << "#endif\n\n";
1372 
1373   OS << "#if !defined(GET_DAGISEL_DECL) && !defined(GET_DAGISEL_BODY)\n";
1374   OS << "#define DAGISEL_INLINE 1\n";
1375   OS << "#else\n";
1376   OS << "#define DAGISEL_INLINE 0\n";
1377   OS << "#endif\n\n";
1378 
1379   OS << "#if !DAGISEL_INLINE\n";
1380   OS << "#define DAGISEL_CLASS_COLONCOLON GET_DAGISEL_BODY ::\n";
1381   OS << "#else\n";
1382   OS << "#define DAGISEL_CLASS_COLONCOLON\n";
1383   OS << "#endif\n\n";
1384 
1385   BeginEmitFunction(OS, "void", "SelectCode(SDNode *N)", false /*AddOverride*/);
1386   MatcherTableEmitter MatcherEmitter(TheMatcher, CGP);
1387 
1388   // First we size all the children of the three kinds of matchers that have
1389   // them. This is done by sharing the code in EmitMatcher(). but we don't
1390   // want to emit anything, so we turn off comments and use a null stream.
1391   bool SaveOmitComments = OmitComments;
1392   OmitComments = true;
1393   raw_null_ostream NullOS;
1394   unsigned TotalSize = MatcherEmitter.SizeMatcherList(TheMatcher, NullOS);
1395   OmitComments = SaveOmitComments;
1396 
1397   // Now that the matchers are sized, we can emit the code for them to the
1398   // final stream.
1399   OS << "{\n";
1400   OS << "  // Some target values are emitted as 2 bytes, TARGET_VAL handles\n";
1401   OS << "  // this. Coverage indexes are emitted as 4 bytes,\n";
1402   OS << "  // COVERAGE_IDX_VAL handles this.\n";
1403   OS << "  #define TARGET_VAL(X) X & 255, unsigned(X) >> 8\n";
1404   OS << "  #define COVERAGE_IDX_VAL(X) X & 255, (unsigned(X) >> 8) & 255, ";
1405   OS << "(unsigned(X) >> 16) & 255, (unsigned(X) >> 24) & 255\n";
1406   OS << "  static const unsigned char MatcherTable[] = {\n";
1407   TotalSize = MatcherEmitter.EmitMatcherList(TheMatcher, 1, 0, OS);
1408   OS << "    0\n  }; // Total Array size is " << (TotalSize + 1)
1409      << " bytes\n\n";
1410 
1411   MatcherEmitter.EmitHistogram(TheMatcher, OS);
1412 
1413   OS << "  #undef COVERAGE_IDX_VAL\n";
1414   OS << "  #undef TARGET_VAL\n";
1415   OS << "  SelectCodeCommon(N, MatcherTable, sizeof(MatcherTable));\n";
1416   OS << "}\n";
1417   EndEmitFunction(OS);
1418 
1419   // Next up, emit the function for node and pattern predicates:
1420   MatcherEmitter.EmitPredicateFunctions(OS);
1421 
1422   if (InstrumentCoverage)
1423     MatcherEmitter.EmitPatternMatchTable(OS);
1424 
1425   // Clean up the preprocessor macros.
1426   OS << "\n";
1427   OS << "#ifdef DAGISEL_INLINE\n";
1428   OS << "#undef DAGISEL_INLINE\n";
1429   OS << "#endif\n";
1430   OS << "#ifdef DAGISEL_CLASS_COLONCOLON\n";
1431   OS << "#undef DAGISEL_CLASS_COLONCOLON\n";
1432   OS << "#endif\n";
1433   OS << "#ifdef GET_DAGISEL_DECL\n";
1434   OS << "#undef GET_DAGISEL_DECL\n";
1435   OS << "#endif\n";
1436   OS << "#ifdef GET_DAGISEL_BODY\n";
1437   OS << "#undef GET_DAGISEL_BODY\n";
1438   OS << "#endif\n";
1439 }
1440