xref: /freebsd/contrib/llvm-project/llvm/utils/TableGen/DAGISelMatcherEmitter.cpp (revision 0fca6ea1d4eea4c934cfff25ac9ee8ad6fe95583)
1 //===- DAGISelMatcherEmitter.cpp - Matcher Emitter ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains code to generate C++ code for a matcher.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "Basic/SDNodeProperties.h"
14 #include "Common/CodeGenDAGPatterns.h"
15 #include "Common/CodeGenInstruction.h"
16 #include "Common/CodeGenRegisters.h"
17 #include "Common/CodeGenTarget.h"
18 #include "Common/DAGISelMatcher.h"
19 #include "llvm/ADT/DenseMap.h"
20 #include "llvm/ADT/MapVector.h"
21 #include "llvm/ADT/StringMap.h"
22 #include "llvm/ADT/TinyPtrVector.h"
23 #include "llvm/Support/CommandLine.h"
24 #include "llvm/Support/Format.h"
25 #include "llvm/Support/SourceMgr.h"
26 #include "llvm/TableGen/Error.h"
27 #include "llvm/TableGen/Record.h"
28 
29 using namespace llvm;
30 
31 enum {
32   IndexWidth = 6,
33   FullIndexWidth = IndexWidth + 4,
34   HistOpcWidth = 40,
35 };
36 
37 cl::OptionCategory DAGISelCat("Options for -gen-dag-isel");
38 
39 // To reduce generated source code size.
40 static cl::opt<bool> OmitComments("omit-comments",
41                                   cl::desc("Do not generate comments"),
42                                   cl::init(false), cl::cat(DAGISelCat));
43 
44 static cl::opt<bool> InstrumentCoverage(
45     "instrument-coverage",
46     cl::desc("Generates tables to help identify patterns matched"),
47     cl::init(false), cl::cat(DAGISelCat));
48 
49 namespace {
50 class MatcherTableEmitter {
51   const CodeGenDAGPatterns &CGP;
52 
53   SmallVector<unsigned, Matcher::HighestKind + 1> OpcodeCounts;
54 
55   std::vector<TreePattern *> NodePredicates;
56   std::vector<TreePattern *> NodePredicatesWithOperands;
57 
58   // We de-duplicate the predicates by code string, and use this map to track
59   // all the patterns with "identical" predicates.
60   MapVector<std::string, TinyPtrVector<TreePattern *>, StringMap<unsigned>>
61       NodePredicatesByCodeToRun;
62 
63   std::vector<std::string> PatternPredicates;
64 
65   std::vector<const ComplexPattern *> ComplexPatterns;
66 
67   DenseMap<Record *, unsigned> NodeXFormMap;
68   std::vector<Record *> NodeXForms;
69 
70   std::vector<std::string> VecIncludeStrings;
71   MapVector<std::string, unsigned, StringMap<unsigned>> VecPatterns;
72 
getPatternIdxFromTable(std::string && P,std::string && include_loc)73   unsigned getPatternIdxFromTable(std::string &&P, std::string &&include_loc) {
74     const auto It = VecPatterns.find(P);
75     if (It == VecPatterns.end()) {
76       VecPatterns.insert(std::pair(std::move(P), VecPatterns.size()));
77       VecIncludeStrings.push_back(std::move(include_loc));
78       return VecIncludeStrings.size() - 1;
79     }
80     return It->second;
81   }
82 
83 public:
MatcherTableEmitter(const Matcher * TheMatcher,const CodeGenDAGPatterns & cgp)84   MatcherTableEmitter(const Matcher *TheMatcher, const CodeGenDAGPatterns &cgp)
85       : CGP(cgp), OpcodeCounts(Matcher::HighestKind + 1, 0) {
86     // Record the usage of ComplexPattern.
87     MapVector<const ComplexPattern *, unsigned> ComplexPatternUsage;
88     // Record the usage of PatternPredicate.
89     MapVector<StringRef, unsigned> PatternPredicateUsage;
90     // Record the usage of Predicate.
91     MapVector<TreePattern *, unsigned> PredicateUsage;
92 
93     // Iterate the whole MatcherTable once and do some statistics.
94     std::function<void(const Matcher *)> Statistic = [&](const Matcher *N) {
95       while (N) {
96         if (auto *SM = dyn_cast<ScopeMatcher>(N))
97           for (unsigned I = 0; I < SM->getNumChildren(); I++)
98             Statistic(SM->getChild(I));
99         else if (auto *SOM = dyn_cast<SwitchOpcodeMatcher>(N))
100           for (unsigned I = 0; I < SOM->getNumCases(); I++)
101             Statistic(SOM->getCaseMatcher(I));
102         else if (auto *STM = dyn_cast<SwitchTypeMatcher>(N))
103           for (unsigned I = 0; I < STM->getNumCases(); I++)
104             Statistic(STM->getCaseMatcher(I));
105         else if (auto *CPM = dyn_cast<CheckComplexPatMatcher>(N))
106           ++ComplexPatternUsage[&CPM->getPattern()];
107         else if (auto *CPPM = dyn_cast<CheckPatternPredicateMatcher>(N))
108           ++PatternPredicateUsage[CPPM->getPredicate()];
109         else if (auto *PM = dyn_cast<CheckPredicateMatcher>(N))
110           ++PredicateUsage[PM->getPredicate().getOrigPatFragRecord()];
111         N = N->getNext();
112       }
113     };
114     Statistic(TheMatcher);
115 
116     // Sort ComplexPatterns by usage.
117     std::vector<std::pair<const ComplexPattern *, unsigned>> ComplexPatternList(
118         ComplexPatternUsage.begin(), ComplexPatternUsage.end());
119     stable_sort(ComplexPatternList, [](const auto &A, const auto &B) {
120       return A.second > B.second;
121     });
122     for (const auto &ComplexPattern : ComplexPatternList)
123       ComplexPatterns.push_back(ComplexPattern.first);
124 
125     // Sort PatternPredicates by usage.
126     std::vector<std::pair<std::string, unsigned>> PatternPredicateList(
127         PatternPredicateUsage.begin(), PatternPredicateUsage.end());
128     stable_sort(PatternPredicateList, [](const auto &A, const auto &B) {
129       return A.second > B.second;
130     });
131     for (const auto &PatternPredicate : PatternPredicateList)
132       PatternPredicates.push_back(PatternPredicate.first);
133 
134     // Sort Predicates by usage.
135     // Merge predicates with same code.
136     for (const auto &Usage : PredicateUsage) {
137       TreePattern *TP = Usage.first;
138       TreePredicateFn Pred(TP);
139       NodePredicatesByCodeToRun[Pred.getCodeToRunOnSDNode()].push_back(TP);
140     }
141 
142     std::vector<std::pair<TreePattern *, unsigned>> PredicateList;
143     // Sum the usage.
144     for (auto &Predicate : NodePredicatesByCodeToRun) {
145       TinyPtrVector<TreePattern *> &TPs = Predicate.second;
146       stable_sort(TPs, [](const auto *A, const auto *B) {
147         return A->getRecord()->getName() < B->getRecord()->getName();
148       });
149       unsigned Uses = 0;
150       for (TreePattern *TP : TPs)
151         Uses += PredicateUsage[TP];
152 
153       // We only add the first predicate here since they are with the same code.
154       PredicateList.push_back({TPs[0], Uses});
155     }
156 
157     stable_sort(PredicateList, [](const auto &A, const auto &B) {
158       return A.second > B.second;
159     });
160     for (const auto &Predicate : PredicateList) {
161       TreePattern *TP = Predicate.first;
162       if (TreePredicateFn(TP).usesOperands())
163         NodePredicatesWithOperands.push_back(TP);
164       else
165         NodePredicates.push_back(TP);
166     }
167   }
168 
169   unsigned EmitMatcherList(const Matcher *N, const unsigned Indent,
170                            unsigned StartIdx, raw_ostream &OS);
171 
172   unsigned SizeMatcherList(Matcher *N, raw_ostream &OS);
173 
174   void EmitPredicateFunctions(raw_ostream &OS);
175 
176   void EmitHistogram(const Matcher *N, raw_ostream &OS);
177 
178   void EmitPatternMatchTable(raw_ostream &OS);
179 
180 private:
181   void EmitNodePredicatesFunction(const std::vector<TreePattern *> &Preds,
182                                   StringRef Decl, raw_ostream &OS);
183 
184   unsigned SizeMatcher(Matcher *N, raw_ostream &OS);
185 
186   unsigned EmitMatcher(const Matcher *N, const unsigned Indent,
187                        unsigned CurrentIdx, raw_ostream &OS);
188 
getNodePredicate(TreePredicateFn Pred)189   unsigned getNodePredicate(TreePredicateFn Pred) {
190     // We use the first predicate.
191     TreePattern *PredPat =
192         NodePredicatesByCodeToRun[Pred.getCodeToRunOnSDNode()][0];
193     return Pred.usesOperands()
194                ? llvm::find(NodePredicatesWithOperands, PredPat) -
195                      NodePredicatesWithOperands.begin()
196                : llvm::find(NodePredicates, PredPat) - NodePredicates.begin();
197   }
198 
getPatternPredicate(StringRef PredName)199   unsigned getPatternPredicate(StringRef PredName) {
200     return llvm::find(PatternPredicates, PredName) - PatternPredicates.begin();
201   }
getComplexPat(const ComplexPattern & P)202   unsigned getComplexPat(const ComplexPattern &P) {
203     return llvm::find(ComplexPatterns, &P) - ComplexPatterns.begin();
204   }
205 
getNodeXFormID(Record * Rec)206   unsigned getNodeXFormID(Record *Rec) {
207     unsigned &Entry = NodeXFormMap[Rec];
208     if (Entry == 0) {
209       NodeXForms.push_back(Rec);
210       Entry = NodeXForms.size();
211     }
212     return Entry - 1;
213   }
214 };
215 } // end anonymous namespace.
216 
GetPatFromTreePatternNode(const TreePatternNode & N)217 static std::string GetPatFromTreePatternNode(const TreePatternNode &N) {
218   std::string str;
219   raw_string_ostream Stream(str);
220   Stream << N;
221   return str;
222 }
223 
GetVBRSize(unsigned Val)224 static unsigned GetVBRSize(unsigned Val) {
225   if (Val <= 127)
226     return 1;
227 
228   unsigned NumBytes = 0;
229   while (Val >= 128) {
230     Val >>= 7;
231     ++NumBytes;
232   }
233   return NumBytes + 1;
234 }
235 
236 /// EmitVBRValue - Emit the specified value as a VBR, returning the number of
237 /// bytes emitted.
EmitVBRValue(uint64_t Val,raw_ostream & OS)238 static unsigned EmitVBRValue(uint64_t Val, raw_ostream &OS) {
239   if (Val <= 127) {
240     OS << Val << ", ";
241     return 1;
242   }
243 
244   uint64_t InVal = Val;
245   unsigned NumBytes = 0;
246   while (Val >= 128) {
247     OS << (Val & 127) << "|128,";
248     Val >>= 7;
249     ++NumBytes;
250   }
251   OS << Val;
252   if (!OmitComments)
253     OS << "/*" << InVal << "*/";
254   OS << ", ";
255   return NumBytes + 1;
256 }
257 
258 /// Emit the specified signed value as a VBR. To improve compression we encode
259 /// positive numbers shifted left by 1 and negative numbers negated and shifted
260 /// left by 1 with bit 0 set.
EmitSignedVBRValue(uint64_t Val,raw_ostream & OS)261 static unsigned EmitSignedVBRValue(uint64_t Val, raw_ostream &OS) {
262   if ((int64_t)Val >= 0)
263     Val = Val << 1;
264   else
265     Val = (-Val << 1) | 1;
266 
267   return EmitVBRValue(Val, OS);
268 }
269 
270 // This is expensive and slow.
getIncludePath(const Record * R)271 static std::string getIncludePath(const Record *R) {
272   std::string str;
273   raw_string_ostream Stream(str);
274   auto Locs = R->getLoc();
275   SMLoc L;
276   if (Locs.size() > 1) {
277     // Get where the pattern prototype was instantiated
278     L = Locs[1];
279   } else if (Locs.size() == 1) {
280     L = Locs[0];
281   }
282   unsigned CurBuf = SrcMgr.FindBufferContainingLoc(L);
283   assert(CurBuf && "Invalid or unspecified location!");
284 
285   Stream << SrcMgr.getBufferInfo(CurBuf).Buffer->getBufferIdentifier() << ":"
286          << SrcMgr.FindLineNumber(L, CurBuf);
287   return str;
288 }
289 
290 /// This function traverses the matcher tree and sizes all the nodes
291 /// that are children of the three kinds of nodes that have them.
SizeMatcherList(Matcher * N,raw_ostream & OS)292 unsigned MatcherTableEmitter::SizeMatcherList(Matcher *N, raw_ostream &OS) {
293   unsigned Size = 0;
294   while (N) {
295     Size += SizeMatcher(N, OS);
296     N = N->getNext();
297   }
298   return Size;
299 }
300 
301 /// This function sizes the children of the three kinds of nodes that
302 /// have them. It does so by using special cases for those three
303 /// nodes, but sharing the code in EmitMatcher() for the other kinds.
SizeMatcher(Matcher * N,raw_ostream & OS)304 unsigned MatcherTableEmitter::SizeMatcher(Matcher *N, raw_ostream &OS) {
305   unsigned Idx = 0;
306 
307   ++OpcodeCounts[N->getKind()];
308   switch (N->getKind()) {
309   // The Scope matcher has its kind, a series of child size + child,
310   // and a trailing zero.
311   case Matcher::Scope: {
312     ScopeMatcher *SM = cast<ScopeMatcher>(N);
313     assert(SM->getNext() == nullptr && "Scope matcher should not have next");
314     unsigned Size = 1; // Count the kind.
315     for (unsigned i = 0, e = SM->getNumChildren(); i != e; ++i) {
316       const unsigned ChildSize = SizeMatcherList(SM->getChild(i), OS);
317       assert(ChildSize != 0 && "Matcher cannot have child of size 0");
318       SM->getChild(i)->setSize(ChildSize);
319       Size += GetVBRSize(ChildSize) + ChildSize; // Count VBR and child size.
320     }
321     ++Size; // Count the zero sentinel.
322     return Size;
323   }
324 
325   // SwitchOpcode and SwitchType have their kind, a series of child size +
326   // opcode/type + child, and a trailing zero.
327   case Matcher::SwitchOpcode:
328   case Matcher::SwitchType: {
329     unsigned Size = 1; // Count the kind.
330     unsigned NumCases;
331     if (const SwitchOpcodeMatcher *SOM = dyn_cast<SwitchOpcodeMatcher>(N))
332       NumCases = SOM->getNumCases();
333     else
334       NumCases = cast<SwitchTypeMatcher>(N)->getNumCases();
335     for (unsigned i = 0, e = NumCases; i != e; ++i) {
336       Matcher *Child;
337       if (SwitchOpcodeMatcher *SOM = dyn_cast<SwitchOpcodeMatcher>(N)) {
338         Child = SOM->getCaseMatcher(i);
339         Size += 2; // Count the child's opcode.
340       } else {
341         Child = cast<SwitchTypeMatcher>(N)->getCaseMatcher(i);
342         ++Size; // Count the child's type.
343       }
344       const unsigned ChildSize = SizeMatcherList(Child, OS);
345       assert(ChildSize != 0 && "Matcher cannot have child of size 0");
346       Child->setSize(ChildSize);
347       Size += GetVBRSize(ChildSize) + ChildSize; // Count VBR and child size.
348     }
349     ++Size; // Count the zero sentinel.
350     return Size;
351   }
352 
353   default:
354     // Employ the matcher emitter to size other matchers.
355     return EmitMatcher(N, 0, Idx, OS);
356   }
357   llvm_unreachable("Unreachable");
358 }
359 
BeginEmitFunction(raw_ostream & OS,StringRef RetType,StringRef Decl,bool AddOverride)360 static void BeginEmitFunction(raw_ostream &OS, StringRef RetType,
361                               StringRef Decl, bool AddOverride) {
362   OS << "#ifdef GET_DAGISEL_DECL\n";
363   OS << RetType << ' ' << Decl;
364   if (AddOverride)
365     OS << " override";
366   OS << ";\n"
367         "#endif\n"
368         "#if defined(GET_DAGISEL_BODY) || DAGISEL_INLINE\n";
369   OS << RetType << " DAGISEL_CLASS_COLONCOLON " << Decl << "\n";
370   if (AddOverride) {
371     OS << "#if DAGISEL_INLINE\n"
372           "  override\n"
373           "#endif\n";
374   }
375 }
376 
EndEmitFunction(raw_ostream & OS)377 static void EndEmitFunction(raw_ostream &OS) {
378   OS << "#endif // GET_DAGISEL_BODY\n\n";
379 }
380 
EmitPatternMatchTable(raw_ostream & OS)381 void MatcherTableEmitter::EmitPatternMatchTable(raw_ostream &OS) {
382 
383   assert(isUInt<16>(VecPatterns.size()) &&
384          "Using only 16 bits to encode offset into Pattern Table");
385   assert(VecPatterns.size() == VecIncludeStrings.size() &&
386          "The sizes of Pattern and include vectors should be the same");
387 
388   BeginEmitFunction(OS, "StringRef", "getPatternForIndex(unsigned Index)",
389                     true /*AddOverride*/);
390   OS << "{\n";
391   OS << "static const char *PATTERN_MATCH_TABLE[] = {\n";
392 
393   for (const auto &It : VecPatterns) {
394     OS << "\"" << It.first << "\",\n";
395   }
396 
397   OS << "\n};";
398   OS << "\nreturn StringRef(PATTERN_MATCH_TABLE[Index]);";
399   OS << "\n}\n";
400   EndEmitFunction(OS);
401 
402   BeginEmitFunction(OS, "StringRef", "getIncludePathForIndex(unsigned Index)",
403                     true /*AddOverride*/);
404   OS << "{\n";
405   OS << "static const char *INCLUDE_PATH_TABLE[] = {\n";
406 
407   for (const auto &It : VecIncludeStrings) {
408     OS << "\"" << It << "\",\n";
409   }
410 
411   OS << "\n};";
412   OS << "\nreturn StringRef(INCLUDE_PATH_TABLE[Index]);";
413   OS << "\n}\n";
414   EndEmitFunction(OS);
415 }
416 
417 /// EmitMatcher - Emit bytes for the specified matcher and return
418 /// the number of bytes emitted.
EmitMatcher(const Matcher * N,const unsigned Indent,unsigned CurrentIdx,raw_ostream & OS)419 unsigned MatcherTableEmitter::EmitMatcher(const Matcher *N,
420                                           const unsigned Indent,
421                                           unsigned CurrentIdx,
422                                           raw_ostream &OS) {
423   OS.indent(Indent);
424 
425   switch (N->getKind()) {
426   case Matcher::Scope: {
427     const ScopeMatcher *SM = cast<ScopeMatcher>(N);
428     unsigned StartIdx = CurrentIdx;
429 
430     // Emit all of the children.
431     for (unsigned i = 0, e = SM->getNumChildren(); i != e; ++i) {
432       if (i == 0) {
433         OS << "OPC_Scope, ";
434         ++CurrentIdx;
435       } else {
436         if (!OmitComments) {
437           OS << "/*" << format_decimal(CurrentIdx, IndexWidth) << "*/";
438           OS.indent(Indent) << "/*Scope*/ ";
439         } else
440           OS.indent(Indent);
441       }
442 
443       unsigned ChildSize = SM->getChild(i)->getSize();
444       unsigned VBRSize = EmitVBRValue(ChildSize, OS);
445       if (!OmitComments) {
446         OS << "/*->" << CurrentIdx + VBRSize + ChildSize << "*/";
447         if (i == 0)
448           OS << " // " << SM->getNumChildren() << " children in Scope";
449       }
450       OS << '\n';
451 
452       ChildSize = EmitMatcherList(SM->getChild(i), Indent + 1,
453                                   CurrentIdx + VBRSize, OS);
454       assert(ChildSize == SM->getChild(i)->getSize() &&
455              "Emitted child size does not match calculated size");
456       CurrentIdx += VBRSize + ChildSize;
457     }
458 
459     // Emit a zero as a sentinel indicating end of 'Scope'.
460     if (!OmitComments)
461       OS << "/*" << format_decimal(CurrentIdx, IndexWidth) << "*/";
462     OS.indent(Indent) << "0, ";
463     if (!OmitComments)
464       OS << "/*End of Scope*/";
465     OS << '\n';
466     return CurrentIdx - StartIdx + 1;
467   }
468 
469   case Matcher::RecordNode:
470     OS << "OPC_RecordNode,";
471     if (!OmitComments)
472       OS << " // #" << cast<RecordMatcher>(N)->getResultNo() << " = "
473          << cast<RecordMatcher>(N)->getWhatFor();
474     OS << '\n';
475     return 1;
476 
477   case Matcher::RecordChild:
478     OS << "OPC_RecordChild" << cast<RecordChildMatcher>(N)->getChildNo() << ',';
479     if (!OmitComments)
480       OS << " // #" << cast<RecordChildMatcher>(N)->getResultNo() << " = "
481          << cast<RecordChildMatcher>(N)->getWhatFor();
482     OS << '\n';
483     return 1;
484 
485   case Matcher::RecordMemRef:
486     OS << "OPC_RecordMemRef,\n";
487     return 1;
488 
489   case Matcher::CaptureGlueInput:
490     OS << "OPC_CaptureGlueInput,\n";
491     return 1;
492 
493   case Matcher::MoveChild: {
494     const auto *MCM = cast<MoveChildMatcher>(N);
495 
496     OS << "OPC_MoveChild";
497     // Handle the specialized forms.
498     if (MCM->getChildNo() >= 8)
499       OS << ", ";
500     OS << MCM->getChildNo() << ",\n";
501     return (MCM->getChildNo() >= 8) ? 2 : 1;
502   }
503 
504   case Matcher::MoveSibling: {
505     const auto *MSM = cast<MoveSiblingMatcher>(N);
506 
507     OS << "OPC_MoveSibling";
508     // Handle the specialized forms.
509     if (MSM->getSiblingNo() >= 8)
510       OS << ", ";
511     OS << MSM->getSiblingNo() << ",\n";
512     return (MSM->getSiblingNo() >= 8) ? 2 : 1;
513   }
514 
515   case Matcher::MoveParent:
516     OS << "OPC_MoveParent,\n";
517     return 1;
518 
519   case Matcher::CheckSame:
520     OS << "OPC_CheckSame, " << cast<CheckSameMatcher>(N)->getMatchNumber()
521        << ",\n";
522     return 2;
523 
524   case Matcher::CheckChildSame:
525     OS << "OPC_CheckChild" << cast<CheckChildSameMatcher>(N)->getChildNo()
526        << "Same, " << cast<CheckChildSameMatcher>(N)->getMatchNumber() << ",\n";
527     return 2;
528 
529   case Matcher::CheckPatternPredicate: {
530     StringRef Pred = cast<CheckPatternPredicateMatcher>(N)->getPredicate();
531     unsigned PredNo = getPatternPredicate(Pred);
532     if (PredNo > 255)
533       OS << "OPC_CheckPatternPredicateTwoByte, TARGET_VAL(" << PredNo << "),";
534     else if (PredNo < 8)
535       OS << "OPC_CheckPatternPredicate" << PredNo << ',';
536     else
537       OS << "OPC_CheckPatternPredicate, " << PredNo << ',';
538     if (!OmitComments)
539       OS << " // " << Pred;
540     OS << '\n';
541     return 2 + (PredNo > 255) - (PredNo < 8);
542   }
543   case Matcher::CheckPredicate: {
544     TreePredicateFn Pred = cast<CheckPredicateMatcher>(N)->getPredicate();
545     unsigned OperandBytes = 0;
546     unsigned PredNo = getNodePredicate(Pred);
547 
548     if (Pred.usesOperands()) {
549       unsigned NumOps = cast<CheckPredicateMatcher>(N)->getNumOperands();
550       OS << "OPC_CheckPredicateWithOperands, " << NumOps << "/*#Ops*/, ";
551       for (unsigned i = 0; i < NumOps; ++i)
552         OS << cast<CheckPredicateMatcher>(N)->getOperandNo(i) << ", ";
553       OperandBytes = 1 + NumOps;
554     } else {
555       if (PredNo < 8) {
556         OperandBytes = -1;
557         OS << "OPC_CheckPredicate" << PredNo << ", ";
558       } else
559         OS << "OPC_CheckPredicate, ";
560     }
561 
562     if (PredNo >= 8 || Pred.usesOperands())
563       OS << PredNo << ',';
564     if (!OmitComments)
565       OS << " // " << Pred.getFnName();
566     OS << '\n';
567     return 2 + OperandBytes;
568   }
569 
570   case Matcher::CheckOpcode:
571     OS << "OPC_CheckOpcode, TARGET_VAL("
572        << cast<CheckOpcodeMatcher>(N)->getOpcode().getEnumName() << "),\n";
573     return 3;
574 
575   case Matcher::SwitchOpcode:
576   case Matcher::SwitchType: {
577     unsigned StartIdx = CurrentIdx;
578 
579     unsigned NumCases;
580     if (const SwitchOpcodeMatcher *SOM = dyn_cast<SwitchOpcodeMatcher>(N)) {
581       OS << "OPC_SwitchOpcode ";
582       NumCases = SOM->getNumCases();
583     } else {
584       OS << "OPC_SwitchType ";
585       NumCases = cast<SwitchTypeMatcher>(N)->getNumCases();
586     }
587 
588     if (!OmitComments)
589       OS << "/*" << NumCases << " cases */";
590     OS << ", ";
591     ++CurrentIdx;
592 
593     // For each case we emit the size, then the opcode, then the matcher.
594     for (unsigned i = 0, e = NumCases; i != e; ++i) {
595       const Matcher *Child;
596       unsigned IdxSize;
597       if (const SwitchOpcodeMatcher *SOM = dyn_cast<SwitchOpcodeMatcher>(N)) {
598         Child = SOM->getCaseMatcher(i);
599         IdxSize = 2; // size of opcode in table is 2 bytes.
600       } else {
601         Child = cast<SwitchTypeMatcher>(N)->getCaseMatcher(i);
602         IdxSize = 1; // size of type in table is 1 byte.
603       }
604 
605       if (i != 0) {
606         if (!OmitComments)
607           OS << "/*" << format_decimal(CurrentIdx, IndexWidth) << "*/";
608         OS.indent(Indent);
609         if (!OmitComments)
610           OS << (isa<SwitchOpcodeMatcher>(N) ? "/*SwitchOpcode*/ "
611                                              : "/*SwitchType*/ ");
612       }
613 
614       unsigned ChildSize = Child->getSize();
615       CurrentIdx += EmitVBRValue(ChildSize, OS) + IdxSize;
616       if (const SwitchOpcodeMatcher *SOM = dyn_cast<SwitchOpcodeMatcher>(N))
617         OS << "TARGET_VAL(" << SOM->getCaseOpcode(i).getEnumName() << "),";
618       else
619         OS << getEnumName(cast<SwitchTypeMatcher>(N)->getCaseType(i)) << ',';
620       if (!OmitComments)
621         OS << "// ->" << CurrentIdx + ChildSize;
622       OS << '\n';
623 
624       ChildSize = EmitMatcherList(Child, Indent + 1, CurrentIdx, OS);
625       assert(ChildSize == Child->getSize() &&
626              "Emitted child size does not match calculated size");
627       CurrentIdx += ChildSize;
628     }
629 
630     // Emit the final zero to terminate the switch.
631     if (!OmitComments)
632       OS << "/*" << format_decimal(CurrentIdx, IndexWidth) << "*/";
633     OS.indent(Indent) << "0,";
634     if (!OmitComments)
635       OS << (isa<SwitchOpcodeMatcher>(N) ? " // EndSwitchOpcode"
636                                          : " // EndSwitchType");
637 
638     OS << '\n';
639     return CurrentIdx - StartIdx + 1;
640   }
641 
642   case Matcher::CheckType:
643     if (cast<CheckTypeMatcher>(N)->getResNo() == 0) {
644       MVT::SimpleValueType VT = cast<CheckTypeMatcher>(N)->getType();
645       switch (VT) {
646       case MVT::i32:
647       case MVT::i64:
648         OS << "OPC_CheckTypeI" << MVT(VT).getSizeInBits() << ",\n";
649         return 1;
650       default:
651         OS << "OPC_CheckType, " << getEnumName(VT) << ",\n";
652         return 2;
653       }
654     }
655     OS << "OPC_CheckTypeRes, " << cast<CheckTypeMatcher>(N)->getResNo() << ", "
656        << getEnumName(cast<CheckTypeMatcher>(N)->getType()) << ",\n";
657     return 3;
658 
659   case Matcher::CheckChildType: {
660     MVT::SimpleValueType VT = cast<CheckChildTypeMatcher>(N)->getType();
661     switch (VT) {
662     case MVT::i32:
663     case MVT::i64:
664       OS << "OPC_CheckChild" << cast<CheckChildTypeMatcher>(N)->getChildNo()
665          << "TypeI" << MVT(VT).getSizeInBits() << ",\n";
666       return 1;
667     default:
668       OS << "OPC_CheckChild" << cast<CheckChildTypeMatcher>(N)->getChildNo()
669          << "Type, " << getEnumName(VT) << ",\n";
670       return 2;
671     }
672   }
673 
674   case Matcher::CheckInteger: {
675     OS << "OPC_CheckInteger, ";
676     unsigned Bytes =
677         1 + EmitSignedVBRValue(cast<CheckIntegerMatcher>(N)->getValue(), OS);
678     OS << '\n';
679     return Bytes;
680   }
681   case Matcher::CheckChildInteger: {
682     OS << "OPC_CheckChild" << cast<CheckChildIntegerMatcher>(N)->getChildNo()
683        << "Integer, ";
684     unsigned Bytes = 1 + EmitSignedVBRValue(
685                              cast<CheckChildIntegerMatcher>(N)->getValue(), OS);
686     OS << '\n';
687     return Bytes;
688   }
689   case Matcher::CheckCondCode:
690     OS << "OPC_CheckCondCode, ISD::"
691        << cast<CheckCondCodeMatcher>(N)->getCondCodeName() << ",\n";
692     return 2;
693 
694   case Matcher::CheckChild2CondCode:
695     OS << "OPC_CheckChild2CondCode, ISD::"
696        << cast<CheckChild2CondCodeMatcher>(N)->getCondCodeName() << ",\n";
697     return 2;
698 
699   case Matcher::CheckValueType:
700     OS << "OPC_CheckValueType, "
701        << getEnumName(cast<CheckValueTypeMatcher>(N)->getVT()) << ",\n";
702     return 2;
703 
704   case Matcher::CheckComplexPat: {
705     const CheckComplexPatMatcher *CCPM = cast<CheckComplexPatMatcher>(N);
706     const ComplexPattern &Pattern = CCPM->getPattern();
707     unsigned PatternNo = getComplexPat(Pattern);
708     if (PatternNo < 8)
709       OS << "OPC_CheckComplexPat" << PatternNo << ", /*#*/"
710          << CCPM->getMatchNumber() << ',';
711     else
712       OS << "OPC_CheckComplexPat, /*CP*/" << PatternNo << ", /*#*/"
713          << CCPM->getMatchNumber() << ',';
714 
715     if (!OmitComments) {
716       OS << " // " << Pattern.getSelectFunc();
717       OS << ":$" << CCPM->getName();
718       for (unsigned i = 0, e = Pattern.getNumOperands(); i != e; ++i)
719         OS << " #" << CCPM->getFirstResult() + i;
720 
721       if (Pattern.hasProperty(SDNPHasChain))
722         OS << " + chain result";
723     }
724     OS << '\n';
725     return PatternNo < 8 ? 2 : 3;
726   }
727 
728   case Matcher::CheckAndImm: {
729     OS << "OPC_CheckAndImm, ";
730     unsigned Bytes =
731         1 + EmitVBRValue(cast<CheckAndImmMatcher>(N)->getValue(), OS);
732     OS << '\n';
733     return Bytes;
734   }
735 
736   case Matcher::CheckOrImm: {
737     OS << "OPC_CheckOrImm, ";
738     unsigned Bytes =
739         1 + EmitVBRValue(cast<CheckOrImmMatcher>(N)->getValue(), OS);
740     OS << '\n';
741     return Bytes;
742   }
743 
744   case Matcher::CheckFoldableChainNode:
745     OS << "OPC_CheckFoldableChainNode,\n";
746     return 1;
747 
748   case Matcher::CheckImmAllOnesV:
749     OS << "OPC_CheckImmAllOnesV,\n";
750     return 1;
751 
752   case Matcher::CheckImmAllZerosV:
753     OS << "OPC_CheckImmAllZerosV,\n";
754     return 1;
755 
756   case Matcher::EmitInteger: {
757     int64_t Val = cast<EmitIntegerMatcher>(N)->getValue();
758     MVT::SimpleValueType VT = cast<EmitIntegerMatcher>(N)->getVT();
759     unsigned OpBytes;
760     switch (VT) {
761     case MVT::i8:
762     case MVT::i16:
763     case MVT::i32:
764     case MVT::i64:
765       OpBytes = 1;
766       OS << "OPC_EmitInteger" << MVT(VT).getSizeInBits() << ", ";
767       break;
768     default:
769       OpBytes = 2;
770       OS << "OPC_EmitInteger, " << getEnumName(VT) << ", ";
771       break;
772     }
773     unsigned Bytes = OpBytes + EmitSignedVBRValue(Val, OS);
774     OS << '\n';
775     return Bytes;
776   }
777   case Matcher::EmitStringInteger: {
778     const std::string &Val = cast<EmitStringIntegerMatcher>(N)->getValue();
779     MVT::SimpleValueType VT = cast<EmitStringIntegerMatcher>(N)->getVT();
780     // These should always fit into 7 bits.
781     unsigned OpBytes;
782     switch (VT) {
783     case MVT::i32:
784       OpBytes = 1;
785       OS << "OPC_EmitStringInteger" << MVT(VT).getSizeInBits() << ", ";
786       break;
787     default:
788       OpBytes = 2;
789       OS << "OPC_EmitStringInteger, " << getEnumName(VT) << ", ";
790       break;
791     }
792     OS << Val << ",\n";
793     return OpBytes + 1;
794   }
795 
796   case Matcher::EmitRegister: {
797     const EmitRegisterMatcher *Matcher = cast<EmitRegisterMatcher>(N);
798     const CodeGenRegister *Reg = Matcher->getReg();
799     MVT::SimpleValueType VT = Matcher->getVT();
800     // If the enum value of the register is larger than one byte can handle,
801     // use EmitRegister2.
802     if (Reg && Reg->EnumValue > 255) {
803       OS << "OPC_EmitRegister2, " << getEnumName(VT) << ", ";
804       OS << "TARGET_VAL(" << getQualifiedName(Reg->TheDef) << "),\n";
805       return 4;
806     }
807     unsigned OpBytes;
808     switch (VT) {
809     case MVT::i32:
810     case MVT::i64:
811       OpBytes = 1;
812       OS << "OPC_EmitRegisterI" << MVT(VT).getSizeInBits() << ", ";
813       break;
814     default:
815       OpBytes = 2;
816       OS << "OPC_EmitRegister, " << getEnumName(VT) << ", ";
817       break;
818     }
819     if (Reg) {
820       OS << getQualifiedName(Reg->TheDef) << ",\n";
821     } else {
822       OS << "0 ";
823       if (!OmitComments)
824         OS << "/*zero_reg*/";
825       OS << ",\n";
826     }
827     return OpBytes + 1;
828   }
829 
830   case Matcher::EmitConvertToTarget: {
831     unsigned Slot = cast<EmitConvertToTargetMatcher>(N)->getSlot();
832     if (Slot < 8) {
833       OS << "OPC_EmitConvertToTarget" << Slot << ",\n";
834       return 1;
835     }
836     OS << "OPC_EmitConvertToTarget, " << Slot << ",\n";
837     return 2;
838   }
839 
840   case Matcher::EmitMergeInputChains: {
841     const EmitMergeInputChainsMatcher *MN =
842         cast<EmitMergeInputChainsMatcher>(N);
843 
844     // Handle the specialized forms OPC_EmitMergeInputChains1_0, 1_1, and 1_2.
845     if (MN->getNumNodes() == 1 && MN->getNode(0) < 3) {
846       OS << "OPC_EmitMergeInputChains1_" << MN->getNode(0) << ",\n";
847       return 1;
848     }
849 
850     OS << "OPC_EmitMergeInputChains, " << MN->getNumNodes() << ", ";
851     for (unsigned i = 0, e = MN->getNumNodes(); i != e; ++i)
852       OS << MN->getNode(i) << ", ";
853     OS << '\n';
854     return 2 + MN->getNumNodes();
855   }
856   case Matcher::EmitCopyToReg: {
857     const auto *C2RMatcher = cast<EmitCopyToRegMatcher>(N);
858     int Bytes = 3;
859     const CodeGenRegister *Reg = C2RMatcher->getDestPhysReg();
860     unsigned Slot = C2RMatcher->getSrcSlot();
861     if (Reg->EnumValue > 255) {
862       assert(isUInt<16>(Reg->EnumValue) && "not handled");
863       OS << "OPC_EmitCopyToRegTwoByte, " << Slot << ", "
864          << "TARGET_VAL(" << getQualifiedName(Reg->TheDef) << "),\n";
865       ++Bytes;
866     } else {
867       if (Slot < 8) {
868         OS << "OPC_EmitCopyToReg" << Slot << ", "
869            << getQualifiedName(Reg->TheDef) << ",\n";
870         --Bytes;
871       } else
872         OS << "OPC_EmitCopyToReg, " << Slot << ", "
873            << getQualifiedName(Reg->TheDef) << ",\n";
874     }
875 
876     return Bytes;
877   }
878   case Matcher::EmitNodeXForm: {
879     const EmitNodeXFormMatcher *XF = cast<EmitNodeXFormMatcher>(N);
880     OS << "OPC_EmitNodeXForm, " << getNodeXFormID(XF->getNodeXForm()) << ", "
881        << XF->getSlot() << ',';
882     if (!OmitComments)
883       OS << " // " << XF->getNodeXForm()->getName();
884     OS << '\n';
885     return 3;
886   }
887 
888   case Matcher::EmitNode:
889   case Matcher::MorphNodeTo: {
890     auto NumCoveredBytes = 0;
891     if (InstrumentCoverage) {
892       if (const MorphNodeToMatcher *SNT = dyn_cast<MorphNodeToMatcher>(N)) {
893         NumCoveredBytes = 3;
894         OS << "OPC_Coverage, ";
895         std::string src =
896             GetPatFromTreePatternNode(SNT->getPattern().getSrcPattern());
897         std::string dst =
898             GetPatFromTreePatternNode(SNT->getPattern().getDstPattern());
899         Record *PatRecord = SNT->getPattern().getSrcRecord();
900         std::string include_src = getIncludePath(PatRecord);
901         unsigned Offset =
902             getPatternIdxFromTable(src + " -> " + dst, std::move(include_src));
903         OS << "TARGET_VAL(" << Offset << "),\n";
904         OS.indent(FullIndexWidth + Indent);
905       }
906     }
907     const EmitNodeMatcherCommon *EN = cast<EmitNodeMatcherCommon>(N);
908     bool IsEmitNode = isa<EmitNodeMatcher>(EN);
909     OS << (IsEmitNode ? "OPC_EmitNode" : "OPC_MorphNodeTo");
910     bool CompressVTs = EN->getNumVTs() < 3;
911     bool CompressNodeInfo = false;
912     if (CompressVTs) {
913       OS << EN->getNumVTs();
914       if (!EN->hasChain() && !EN->hasInGlue() && !EN->hasOutGlue() &&
915           !EN->hasMemRefs() && EN->getNumFixedArityOperands() == -1) {
916         CompressNodeInfo = true;
917         OS << "None";
918       } else if (EN->hasChain() && !EN->hasInGlue() && !EN->hasOutGlue() &&
919                  !EN->hasMemRefs() && EN->getNumFixedArityOperands() == -1) {
920         CompressNodeInfo = true;
921         OS << "Chain";
922       } else if (!IsEmitNode && !EN->hasChain() && EN->hasInGlue() &&
923                  !EN->hasOutGlue() && !EN->hasMemRefs() &&
924                  EN->getNumFixedArityOperands() == -1) {
925         CompressNodeInfo = true;
926         OS << "GlueInput";
927       } else if (!IsEmitNode && !EN->hasChain() && !EN->hasInGlue() &&
928                  EN->hasOutGlue() && !EN->hasMemRefs() &&
929                  EN->getNumFixedArityOperands() == -1) {
930         CompressNodeInfo = true;
931         OS << "GlueOutput";
932       }
933     }
934 
935     const CodeGenInstruction &CGI = EN->getInstruction();
936     OS << ", TARGET_VAL(" << CGI.Namespace << "::" << CGI.TheDef->getName()
937        << ")";
938 
939     if (!CompressNodeInfo) {
940       OS << ", 0";
941       if (EN->hasChain())
942         OS << "|OPFL_Chain";
943       if (EN->hasInGlue())
944         OS << "|OPFL_GlueInput";
945       if (EN->hasOutGlue())
946         OS << "|OPFL_GlueOutput";
947       if (EN->hasMemRefs())
948         OS << "|OPFL_MemRefs";
949       if (EN->getNumFixedArityOperands() != -1)
950         OS << "|OPFL_Variadic" << EN->getNumFixedArityOperands();
951     }
952     OS << ",\n";
953 
954     OS.indent(FullIndexWidth + Indent + 4);
955     if (!CompressVTs) {
956       OS << EN->getNumVTs();
957       if (!OmitComments)
958         OS << "/*#VTs*/";
959       OS << ", ";
960     }
961     for (unsigned i = 0, e = EN->getNumVTs(); i != e; ++i)
962       OS << getEnumName(EN->getVT(i)) << ", ";
963 
964     OS << EN->getNumOperands();
965     if (!OmitComments)
966       OS << "/*#Ops*/";
967     OS << ", ";
968     unsigned NumOperandBytes = 0;
969     for (unsigned i = 0, e = EN->getNumOperands(); i != e; ++i)
970       NumOperandBytes += EmitVBRValue(EN->getOperand(i), OS);
971 
972     if (!OmitComments) {
973       // Print the result #'s for EmitNode.
974       if (const EmitNodeMatcher *E = dyn_cast<EmitNodeMatcher>(EN)) {
975         if (unsigned NumResults = EN->getNumVTs()) {
976           OS << " // Results =";
977           unsigned First = E->getFirstResultSlot();
978           for (unsigned i = 0; i != NumResults; ++i)
979             OS << " #" << First + i;
980         }
981       }
982       OS << '\n';
983 
984       if (const MorphNodeToMatcher *SNT = dyn_cast<MorphNodeToMatcher>(N)) {
985         OS.indent(FullIndexWidth + Indent)
986             << "// Src: " << SNT->getPattern().getSrcPattern()
987             << " - Complexity = " << SNT->getPattern().getPatternComplexity(CGP)
988             << '\n';
989         OS.indent(FullIndexWidth + Indent)
990             << "// Dst: " << SNT->getPattern().getDstPattern() << '\n';
991       }
992     } else
993       OS << '\n';
994 
995     return 4 + !CompressVTs + !CompressNodeInfo + EN->getNumVTs() +
996            NumOperandBytes + NumCoveredBytes;
997   }
998   case Matcher::CompleteMatch: {
999     const CompleteMatchMatcher *CM = cast<CompleteMatchMatcher>(N);
1000     auto NumCoveredBytes = 0;
1001     if (InstrumentCoverage) {
1002       NumCoveredBytes = 3;
1003       OS << "OPC_Coverage, ";
1004       std::string src =
1005           GetPatFromTreePatternNode(CM->getPattern().getSrcPattern());
1006       std::string dst =
1007           GetPatFromTreePatternNode(CM->getPattern().getDstPattern());
1008       Record *PatRecord = CM->getPattern().getSrcRecord();
1009       std::string include_src = getIncludePath(PatRecord);
1010       unsigned Offset =
1011           getPatternIdxFromTable(src + " -> " + dst, std::move(include_src));
1012       OS << "TARGET_VAL(" << Offset << "),\n";
1013       OS.indent(FullIndexWidth + Indent);
1014     }
1015     OS << "OPC_CompleteMatch, " << CM->getNumResults() << ", ";
1016     unsigned NumResultBytes = 0;
1017     for (unsigned i = 0, e = CM->getNumResults(); i != e; ++i)
1018       NumResultBytes += EmitVBRValue(CM->getResult(i), OS);
1019     OS << '\n';
1020     if (!OmitComments) {
1021       OS.indent(FullIndexWidth + Indent)
1022           << " // Src: " << CM->getPattern().getSrcPattern()
1023           << " - Complexity = " << CM->getPattern().getPatternComplexity(CGP)
1024           << '\n';
1025       OS.indent(FullIndexWidth + Indent)
1026           << " // Dst: " << CM->getPattern().getDstPattern();
1027     }
1028     OS << '\n';
1029     return 2 + NumResultBytes + NumCoveredBytes;
1030   }
1031   }
1032   llvm_unreachable("Unreachable");
1033 }
1034 
1035 /// This function traverses the matcher tree and emits all the nodes.
1036 /// The nodes have already been sized.
EmitMatcherList(const Matcher * N,const unsigned Indent,unsigned CurrentIdx,raw_ostream & OS)1037 unsigned MatcherTableEmitter::EmitMatcherList(const Matcher *N,
1038                                               const unsigned Indent,
1039                                               unsigned CurrentIdx,
1040                                               raw_ostream &OS) {
1041   unsigned Size = 0;
1042   while (N) {
1043     if (!OmitComments)
1044       OS << "/*" << format_decimal(CurrentIdx, IndexWidth) << "*/";
1045     unsigned MatcherSize = EmitMatcher(N, Indent, CurrentIdx, OS);
1046     Size += MatcherSize;
1047     CurrentIdx += MatcherSize;
1048 
1049     // If there are other nodes in this list, iterate to them, otherwise we're
1050     // done.
1051     N = N->getNext();
1052   }
1053   return Size;
1054 }
1055 
EmitNodePredicatesFunction(const std::vector<TreePattern * > & Preds,StringRef Decl,raw_ostream & OS)1056 void MatcherTableEmitter::EmitNodePredicatesFunction(
1057     const std::vector<TreePattern *> &Preds, StringRef Decl, raw_ostream &OS) {
1058   if (Preds.empty())
1059     return;
1060 
1061   BeginEmitFunction(OS, "bool", Decl, true /*AddOverride*/);
1062   OS << "{\n";
1063   OS << "  switch (PredNo) {\n";
1064   OS << "  default: llvm_unreachable(\"Invalid predicate in table?\");\n";
1065   for (unsigned i = 0, e = Preds.size(); i != e; ++i) {
1066     // Emit the predicate code corresponding to this pattern.
1067     TreePredicateFn PredFn(Preds[i]);
1068     assert(!PredFn.isAlwaysTrue() && "No code in this predicate");
1069     std::string PredFnCodeStr = PredFn.getCodeToRunOnSDNode();
1070 
1071     OS << "  case " << i << ": {\n";
1072     for (auto *SimilarPred : NodePredicatesByCodeToRun[PredFnCodeStr])
1073       OS << "    // " << TreePredicateFn(SimilarPred).getFnName() << '\n';
1074     OS << PredFnCodeStr << "\n  }\n";
1075   }
1076   OS << "  }\n";
1077   OS << "}\n";
1078   EndEmitFunction(OS);
1079 }
1080 
EmitPredicateFunctions(raw_ostream & OS)1081 void MatcherTableEmitter::EmitPredicateFunctions(raw_ostream &OS) {
1082   // Emit pattern predicates.
1083   if (!PatternPredicates.empty()) {
1084     BeginEmitFunction(OS, "bool",
1085                       "CheckPatternPredicate(unsigned PredNo) const",
1086                       true /*AddOverride*/);
1087     OS << "{\n";
1088     OS << "  switch (PredNo) {\n";
1089     OS << "  default: llvm_unreachable(\"Invalid predicate in table?\");\n";
1090     for (unsigned i = 0, e = PatternPredicates.size(); i != e; ++i)
1091       OS << "  case " << i << ": return " << PatternPredicates[i] << ";\n";
1092     OS << "  }\n";
1093     OS << "}\n";
1094     EndEmitFunction(OS);
1095   }
1096 
1097   // Emit Node predicates.
1098   EmitNodePredicatesFunction(
1099       NodePredicates, "CheckNodePredicate(SDNode *Node, unsigned PredNo) const",
1100       OS);
1101   EmitNodePredicatesFunction(
1102       NodePredicatesWithOperands,
1103       "CheckNodePredicateWithOperands(SDNode *Node, unsigned PredNo, "
1104       "const SmallVectorImpl<SDValue> &Operands) const",
1105       OS);
1106 
1107   // Emit CompletePattern matchers.
1108   // FIXME: This should be const.
1109   if (!ComplexPatterns.empty()) {
1110     BeginEmitFunction(
1111         OS, "bool",
1112         "CheckComplexPattern(SDNode *Root, SDNode *Parent,\n"
1113         "      SDValue N, unsigned PatternNo,\n"
1114         "      SmallVectorImpl<std::pair<SDValue, SDNode *>> &Result)",
1115         true /*AddOverride*/);
1116     OS << "{\n";
1117     OS << "  unsigned NextRes = Result.size();\n";
1118     OS << "  switch (PatternNo) {\n";
1119     OS << "  default: llvm_unreachable(\"Invalid pattern # in table?\");\n";
1120     for (unsigned i = 0, e = ComplexPatterns.size(); i != e; ++i) {
1121       const ComplexPattern &P = *ComplexPatterns[i];
1122       unsigned NumOps = P.getNumOperands();
1123 
1124       if (P.hasProperty(SDNPHasChain))
1125         ++NumOps; // Get the chained node too.
1126 
1127       OS << "  case " << i << ":\n";
1128       if (InstrumentCoverage)
1129         OS << "  {\n";
1130       OS << "    Result.resize(NextRes+" << NumOps << ");\n";
1131       if (InstrumentCoverage)
1132         OS << "    bool Succeeded = " << P.getSelectFunc();
1133       else
1134         OS << "  return " << P.getSelectFunc();
1135 
1136       OS << "(";
1137       // If the complex pattern wants the root of the match, pass it in as the
1138       // first argument.
1139       if (P.hasProperty(SDNPWantRoot))
1140         OS << "Root, ";
1141 
1142       // If the complex pattern wants the parent of the operand being matched,
1143       // pass it in as the next argument.
1144       if (P.hasProperty(SDNPWantParent))
1145         OS << "Parent, ";
1146 
1147       OS << "N";
1148       for (unsigned i = 0; i != NumOps; ++i)
1149         OS << ", Result[NextRes+" << i << "].first";
1150       OS << ");\n";
1151       if (InstrumentCoverage) {
1152         OS << "    if (Succeeded)\n";
1153         OS << "       dbgs() << \"\\nCOMPLEX_PATTERN: " << P.getSelectFunc()
1154            << "\\n\" ;\n";
1155         OS << "    return Succeeded;\n";
1156         OS << "    }\n";
1157       }
1158     }
1159     OS << "  }\n";
1160     OS << "}\n";
1161     EndEmitFunction(OS);
1162   }
1163 
1164   // Emit SDNodeXForm handlers.
1165   // FIXME: This should be const.
1166   if (!NodeXForms.empty()) {
1167     BeginEmitFunction(OS, "SDValue",
1168                       "RunSDNodeXForm(SDValue V, unsigned XFormNo)",
1169                       true /*AddOverride*/);
1170     OS << "{\n";
1171     OS << "  switch (XFormNo) {\n";
1172     OS << "  default: llvm_unreachable(\"Invalid xform # in table?\");\n";
1173 
1174     // FIXME: The node xform could take SDValue's instead of SDNode*'s.
1175     for (unsigned i = 0, e = NodeXForms.size(); i != e; ++i) {
1176       const CodeGenDAGPatterns::NodeXForm &Entry =
1177           CGP.getSDNodeTransform(NodeXForms[i]);
1178 
1179       Record *SDNode = Entry.first;
1180       const std::string &Code = Entry.second;
1181 
1182       OS << "  case " << i << ": {  ";
1183       if (!OmitComments)
1184         OS << "// " << NodeXForms[i]->getName();
1185       OS << '\n';
1186 
1187       std::string ClassName =
1188           std::string(CGP.getSDNodeInfo(SDNode).getSDClassName());
1189       if (ClassName == "SDNode")
1190         OS << "    SDNode *N = V.getNode();\n";
1191       else
1192         OS << "    " << ClassName << " *N = cast<" << ClassName
1193            << ">(V.getNode());\n";
1194       OS << Code << "\n  }\n";
1195     }
1196     OS << "  }\n";
1197     OS << "}\n";
1198     EndEmitFunction(OS);
1199   }
1200 }
1201 
getOpcodeString(Matcher::KindTy Kind)1202 static StringRef getOpcodeString(Matcher::KindTy Kind) {
1203   switch (Kind) {
1204   case Matcher::Scope:
1205     return "OPC_Scope";
1206   case Matcher::RecordNode:
1207     return "OPC_RecordNode";
1208   case Matcher::RecordChild:
1209     return "OPC_RecordChild";
1210   case Matcher::RecordMemRef:
1211     return "OPC_RecordMemRef";
1212   case Matcher::CaptureGlueInput:
1213     return "OPC_CaptureGlueInput";
1214   case Matcher::MoveChild:
1215     return "OPC_MoveChild";
1216   case Matcher::MoveSibling:
1217     return "OPC_MoveSibling";
1218   case Matcher::MoveParent:
1219     return "OPC_MoveParent";
1220   case Matcher::CheckSame:
1221     return "OPC_CheckSame";
1222   case Matcher::CheckChildSame:
1223     return "OPC_CheckChildSame";
1224   case Matcher::CheckPatternPredicate:
1225     return "OPC_CheckPatternPredicate";
1226   case Matcher::CheckPredicate:
1227     return "OPC_CheckPredicate";
1228   case Matcher::CheckOpcode:
1229     return "OPC_CheckOpcode";
1230   case Matcher::SwitchOpcode:
1231     return "OPC_SwitchOpcode";
1232   case Matcher::CheckType:
1233     return "OPC_CheckType";
1234   case Matcher::SwitchType:
1235     return "OPC_SwitchType";
1236   case Matcher::CheckChildType:
1237     return "OPC_CheckChildType";
1238   case Matcher::CheckInteger:
1239     return "OPC_CheckInteger";
1240   case Matcher::CheckChildInteger:
1241     return "OPC_CheckChildInteger";
1242   case Matcher::CheckCondCode:
1243     return "OPC_CheckCondCode";
1244   case Matcher::CheckChild2CondCode:
1245     return "OPC_CheckChild2CondCode";
1246   case Matcher::CheckValueType:
1247     return "OPC_CheckValueType";
1248   case Matcher::CheckComplexPat:
1249     return "OPC_CheckComplexPat";
1250   case Matcher::CheckAndImm:
1251     return "OPC_CheckAndImm";
1252   case Matcher::CheckOrImm:
1253     return "OPC_CheckOrImm";
1254   case Matcher::CheckFoldableChainNode:
1255     return "OPC_CheckFoldableChainNode";
1256   case Matcher::CheckImmAllOnesV:
1257     return "OPC_CheckImmAllOnesV";
1258   case Matcher::CheckImmAllZerosV:
1259     return "OPC_CheckImmAllZerosV";
1260   case Matcher::EmitInteger:
1261     return "OPC_EmitInteger";
1262   case Matcher::EmitStringInteger:
1263     return "OPC_EmitStringInteger";
1264   case Matcher::EmitRegister:
1265     return "OPC_EmitRegister";
1266   case Matcher::EmitConvertToTarget:
1267     return "OPC_EmitConvertToTarget";
1268   case Matcher::EmitMergeInputChains:
1269     return "OPC_EmitMergeInputChains";
1270   case Matcher::EmitCopyToReg:
1271     return "OPC_EmitCopyToReg";
1272   case Matcher::EmitNode:
1273     return "OPC_EmitNode";
1274   case Matcher::MorphNodeTo:
1275     return "OPC_MorphNodeTo";
1276   case Matcher::EmitNodeXForm:
1277     return "OPC_EmitNodeXForm";
1278   case Matcher::CompleteMatch:
1279     return "OPC_CompleteMatch";
1280   }
1281 
1282   llvm_unreachable("Unhandled opcode?");
1283 }
1284 
EmitHistogram(const Matcher * M,raw_ostream & OS)1285 void MatcherTableEmitter::EmitHistogram(const Matcher *M, raw_ostream &OS) {
1286   if (OmitComments)
1287     return;
1288 
1289   OS << "  // Opcode Histogram:\n";
1290   for (unsigned i = 0, e = OpcodeCounts.size(); i != e; ++i) {
1291     OS << "  // #"
1292        << left_justify(getOpcodeString((Matcher::KindTy)i), HistOpcWidth)
1293        << " = " << OpcodeCounts[i] << '\n';
1294   }
1295   OS << '\n';
1296 }
1297 
EmitMatcherTable(Matcher * TheMatcher,const CodeGenDAGPatterns & CGP,raw_ostream & OS)1298 void llvm::EmitMatcherTable(Matcher *TheMatcher, const CodeGenDAGPatterns &CGP,
1299                             raw_ostream &OS) {
1300   OS << "#if defined(GET_DAGISEL_DECL) && defined(GET_DAGISEL_BODY)\n";
1301   OS << "#error GET_DAGISEL_DECL and GET_DAGISEL_BODY cannot be both defined, ";
1302   OS << "undef both for inline definitions\n";
1303   OS << "#endif\n\n";
1304 
1305   // Emit a check for omitted class name.
1306   OS << "#ifdef GET_DAGISEL_BODY\n";
1307   OS << "#define LOCAL_DAGISEL_STRINGIZE(X) LOCAL_DAGISEL_STRINGIZE_(X)\n";
1308   OS << "#define LOCAL_DAGISEL_STRINGIZE_(X) #X\n";
1309   OS << "static_assert(sizeof(LOCAL_DAGISEL_STRINGIZE(GET_DAGISEL_BODY)) > 1,"
1310         "\n";
1311   OS << "   \"GET_DAGISEL_BODY is empty: it should be defined with the class "
1312         "name\");\n";
1313   OS << "#undef LOCAL_DAGISEL_STRINGIZE_\n";
1314   OS << "#undef LOCAL_DAGISEL_STRINGIZE\n";
1315   OS << "#endif\n\n";
1316 
1317   OS << "#if !defined(GET_DAGISEL_DECL) && !defined(GET_DAGISEL_BODY)\n";
1318   OS << "#define DAGISEL_INLINE 1\n";
1319   OS << "#else\n";
1320   OS << "#define DAGISEL_INLINE 0\n";
1321   OS << "#endif\n\n";
1322 
1323   OS << "#if !DAGISEL_INLINE\n";
1324   OS << "#define DAGISEL_CLASS_COLONCOLON GET_DAGISEL_BODY ::\n";
1325   OS << "#else\n";
1326   OS << "#define DAGISEL_CLASS_COLONCOLON\n";
1327   OS << "#endif\n\n";
1328 
1329   BeginEmitFunction(OS, "void", "SelectCode(SDNode *N)", false /*AddOverride*/);
1330   MatcherTableEmitter MatcherEmitter(TheMatcher, CGP);
1331 
1332   // First we size all the children of the three kinds of matchers that have
1333   // them. This is done by sharing the code in EmitMatcher(). but we don't
1334   // want to emit anything, so we turn off comments and use a null stream.
1335   bool SaveOmitComments = OmitComments;
1336   OmitComments = true;
1337   raw_null_ostream NullOS;
1338   unsigned TotalSize = MatcherEmitter.SizeMatcherList(TheMatcher, NullOS);
1339   OmitComments = SaveOmitComments;
1340 
1341   // Now that the matchers are sized, we can emit the code for them to the
1342   // final stream.
1343   OS << "{\n";
1344   OS << "  // Some target values are emitted as 2 bytes, TARGET_VAL handles\n";
1345   OS << "  // this.\n";
1346   OS << "  #define TARGET_VAL(X) X & 255, unsigned(X) >> 8\n";
1347   OS << "  static const unsigned char MatcherTable[] = {\n";
1348   TotalSize = MatcherEmitter.EmitMatcherList(TheMatcher, 1, 0, OS);
1349   OS << "    0\n  }; // Total Array size is " << (TotalSize + 1)
1350      << " bytes\n\n";
1351 
1352   MatcherEmitter.EmitHistogram(TheMatcher, OS);
1353 
1354   OS << "  #undef TARGET_VAL\n";
1355   OS << "  SelectCodeCommon(N, MatcherTable, sizeof(MatcherTable));\n";
1356   OS << "}\n";
1357   EndEmitFunction(OS);
1358 
1359   // Next up, emit the function for node and pattern predicates:
1360   MatcherEmitter.EmitPredicateFunctions(OS);
1361 
1362   if (InstrumentCoverage)
1363     MatcherEmitter.EmitPatternMatchTable(OS);
1364 
1365   // Clean up the preprocessor macros.
1366   OS << "\n";
1367   OS << "#ifdef DAGISEL_INLINE\n";
1368   OS << "#undef DAGISEL_INLINE\n";
1369   OS << "#endif\n";
1370   OS << "#ifdef DAGISEL_CLASS_COLONCOLON\n";
1371   OS << "#undef DAGISEL_CLASS_COLONCOLON\n";
1372   OS << "#endif\n";
1373   OS << "#ifdef GET_DAGISEL_DECL\n";
1374   OS << "#undef GET_DAGISEL_DECL\n";
1375   OS << "#endif\n";
1376   OS << "#ifdef GET_DAGISEL_BODY\n";
1377   OS << "#undef GET_DAGISEL_BODY\n";
1378   OS << "#endif\n";
1379 }
1380