xref: /freebsd/contrib/llvm-project/llvm/utils/TableGen/Common/VarLenCodeEmitterGen.cpp (revision 770cf0a5f02dc8983a89c6568d741fbc25baa999)
1 //===- VarLenCodeEmitterGen.cpp - CEG for variable-length insts -----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // The CodeEmitterGen component for variable-length instructions.
10 //
11 // The basic CodeEmitterGen is almost exclusively designed for fixed-
12 // length instructions. A good analogy for its encoding scheme is how printf
// works: The (immutable) formatting string represents the fixed values in the
14 // encoded instruction. Placeholders (i.e. %something), on the other hand,
15 // represent encoding for instruction operands.
16 // ```
17 // printf("1101 %src 1001 %dst", <encoded value for operand `src`>,
18 //                               <encoded value for operand `dst`>);
19 // ```
20 // VarLenCodeEmitterGen in this file provides an alternative encoding scheme
21 // that works more like a C++ stream operator:
22 // ```
23 // OS << 0b1101;
24 // if (Cond)
25 //   OS << OperandEncoding0;
26 // OS << 0b1001 << OperandEncoding1;
27 // ```
// You are free to concatenate encoding fragments of arbitrary types (and
// sizes) at any bit position, which gives more flexibility in defining the
// encoding of variable-length instructions.
31 //
// More specifically, instruction encoding is represented by a DAG-typed
// `Inst` field. Here is an example:
34 // ```
35 // dag Inst = (descend 0b1101, (operand "$src", 4), 0b1001,
36 //                     (operand "$dst", 4));
37 // ```
38 // It represents the following instruction encoding:
39 // ```
40 // MSB                                                     LSB
41 // 1101<encoding for operand src>1001<encoding for operand dst>
42 // ```
43 // For more details about DAG operators in the above snippet, please
44 // refer to \file include/llvm/Target/Target.td.
45 //
46 // VarLenCodeEmitter will convert the above DAG into the same helper function
// generated by CodeEmitter, `MCCodeEmitter::getBinaryCodeForInstr` (except
// for a few details).
49 //
50 //===----------------------------------------------------------------------===//
51 
52 #include "VarLenCodeEmitterGen.h"
53 #include "CodeGenHwModes.h"
54 #include "CodeGenInstruction.h"
55 #include "CodeGenTarget.h"
56 #include "InfoByHwMode.h"
57 #include "llvm/ADT/ArrayRef.h"
58 #include "llvm/ADT/DenseMap.h"
59 #include "llvm/Support/raw_ostream.h"
60 #include "llvm/TableGen/Error.h"
61 #include "llvm/TableGen/Record.h"
62 
63 #include <algorithm>
64 
65 using namespace llvm;
66 
67 namespace {
68 
69 class VarLenCodeEmitterGen {
70   const RecordKeeper &Records;
71 
72   // Representaton of alternative encodings used for HwModes.
73   using AltEncodingTy = int;
74   // Mode identifier when only one encoding is defined.
75   const AltEncodingTy Universal = -1;
76   // The set of alternative instruction encodings with a descriptive
77   // name suffix to improve readability of the generated code.
78   std::map<AltEncodingTy, std::string> Modes;
79 
80   DenseMap<const Record *, DenseMap<AltEncodingTy, VarLenInst>> VarLenInsts;
81 
82   // Emit based values (i.e. fixed bits in the encoded instructions)
83   void emitInstructionBaseValues(
84       raw_ostream &OS,
85       ArrayRef<const CodeGenInstruction *> NumberedInstructions,
86       const CodeGenTarget &Target, AltEncodingTy Mode);
87 
88   std::string getInstructionCases(const Record *R, const CodeGenTarget &Target);
89   std::string getInstructionCaseForEncoding(const Record *R, AltEncodingTy Mode,
90                                             const VarLenInst &VLI,
91                                             const CodeGenTarget &Target,
92                                             int Indent);
93 
94 public:
95   explicit VarLenCodeEmitterGen(const RecordKeeper &R) : Records(R) {}
96 
97   void run(raw_ostream &OS);
98 };
99 } // end anonymous namespace
100 
101 // Get the name of custom encoder or decoder, if there is any.
102 // Returns `{encoder name, decoder name}`.
103 static std::pair<StringRef, StringRef>
104 getCustomCoders(ArrayRef<const Init *> Args) {
105   std::pair<StringRef, StringRef> Result;
106   for (const auto *Arg : Args) {
107     const auto *DI = dyn_cast<DagInit>(Arg);
108     if (!DI)
109       continue;
110     const Init *Op = DI->getOperator();
111     if (!isa<DefInit>(Op))
112       continue;
113     // syntax: `(<encoder | decoder> "function name")`
114     StringRef OpName = cast<DefInit>(Op)->getDef()->getName();
115     if (OpName != "encoder" && OpName != "decoder")
116       continue;
117     if (!DI->getNumArgs() || !isa<StringInit>(DI->getArg(0)))
118       PrintFatalError("expected '" + OpName +
119                       "' directive to be followed by a custom function name.");
120     StringRef FuncName = cast<StringInit>(DI->getArg(0))->getValue();
121     if (OpName == "encoder")
122       Result.first = FuncName;
123     else
124       Result.second = FuncName;
125   }
126   return Result;
127 }
128 
129 VarLenInst::VarLenInst(const DagInit *DI, const RecordVal *TheDef)
130     : TheDef(TheDef), NumBits(0U), HasDynamicSegment(false) {
131   buildRec(DI);
132   for (const auto &S : Segments)
133     NumBits += S.BitWidth;
134 }
135 
136 void VarLenInst::buildRec(const DagInit *DI) {
137   assert(TheDef && "The def record is nullptr ?");
138 
139   std::string Op = DI->getOperator()->getAsString();
140 
141   if (Op == "ascend" || Op == "descend") {
142     bool Reverse = Op == "descend";
143     int i = Reverse ? DI->getNumArgs() - 1 : 0;
144     int e = Reverse ? -1 : DI->getNumArgs();
145     int s = Reverse ? -1 : 1;
146     for (; i != e; i += s) {
147       const Init *Arg = DI->getArg(i);
148       if (const auto *BI = dyn_cast<BitsInit>(Arg)) {
149         if (!BI->isComplete())
150           PrintFatalError(TheDef->getLoc(),
151                           "Expecting complete bits init in `" + Op + "`");
152         Segments.push_back({BI->getNumBits(), BI});
153       } else if (const auto *BI = dyn_cast<BitInit>(Arg)) {
154         if (!BI->isConcrete())
155           PrintFatalError(TheDef->getLoc(),
156                           "Expecting concrete bit init in `" + Op + "`");
157         Segments.push_back({1, BI});
158       } else if (const auto *SubDI = dyn_cast<DagInit>(Arg)) {
159         buildRec(SubDI);
160       } else {
161         PrintFatalError(TheDef->getLoc(), "Unrecognized type of argument in `" +
162                                               Op + "`: " + Arg->getAsString());
163       }
164     }
165   } else if (Op == "operand") {
166     // (operand <operand name>, <# of bits>,
167     //          [(encoder <custom encoder>)][, (decoder <custom decoder>)])
168     if (DI->getNumArgs() < 2)
169       PrintFatalError(TheDef->getLoc(),
170                       "Expecting at least 2 arguments for `operand`");
171     HasDynamicSegment = true;
172     const Init *OperandName = DI->getArg(0), *NumBits = DI->getArg(1);
173     if (!isa<StringInit>(OperandName) || !isa<IntInit>(NumBits))
174       PrintFatalError(TheDef->getLoc(), "Invalid argument types for `operand`");
175 
176     auto NumBitsVal = cast<IntInit>(NumBits)->getValue();
177     if (NumBitsVal <= 0)
178       PrintFatalError(TheDef->getLoc(), "Invalid number of bits for `operand`");
179 
180     auto [CustomEncoder, CustomDecoder] =
181         getCustomCoders(DI->getArgs().slice(2));
182     Segments.push_back({static_cast<unsigned>(NumBitsVal), OperandName,
183                         CustomEncoder, CustomDecoder});
184   } else if (Op == "slice") {
185     // (slice <operand name>, <high / low bit>, <low / high bit>,
186     //        [(encoder <custom encoder>)][, (decoder <custom decoder>)])
187     if (DI->getNumArgs() < 3)
188       PrintFatalError(TheDef->getLoc(),
189                       "Expecting at least 3 arguments for `slice`");
190     HasDynamicSegment = true;
191     const Init *OperandName = DI->getArg(0), *HiBit = DI->getArg(1),
192                *LoBit = DI->getArg(2);
193     if (!isa<StringInit>(OperandName) || !isa<IntInit>(HiBit) ||
194         !isa<IntInit>(LoBit))
195       PrintFatalError(TheDef->getLoc(), "Invalid argument types for `slice`");
196 
197     auto HiBitVal = cast<IntInit>(HiBit)->getValue(),
198          LoBitVal = cast<IntInit>(LoBit)->getValue();
199     if (HiBitVal < 0 || LoBitVal < 0)
200       PrintFatalError(TheDef->getLoc(), "Invalid bit range for `slice`");
201     bool NeedSwap = false;
202     unsigned NumBits = 0U;
203     if (HiBitVal < LoBitVal) {
204       NeedSwap = true;
205       NumBits = static_cast<unsigned>(LoBitVal - HiBitVal + 1);
206     } else {
207       NumBits = static_cast<unsigned>(HiBitVal - LoBitVal + 1);
208     }
209 
210     auto [CustomEncoder, CustomDecoder] =
211         getCustomCoders(DI->getArgs().slice(3));
212 
213     if (NeedSwap) {
214       // Normalization: Hi bit should always be the second argument.
215       SmallVector<std::pair<const Init *, const StringInit *>> NewArgs(
216           DI->getArgAndNames());
217       std::swap(NewArgs[1], NewArgs[2]);
218       Segments.push_back({NumBits, DagInit::get(DI->getOperator(), NewArgs),
219                           CustomEncoder, CustomDecoder});
220     } else {
221       Segments.push_back({NumBits, DI, CustomEncoder, CustomDecoder});
222     }
223   }
224 }
225 
226 void VarLenCodeEmitterGen::run(raw_ostream &OS) {
227   CodeGenTarget Target(Records);
228 
229   auto NumberedInstructions = Target.getInstructions();
230 
231   for (const CodeGenInstruction *CGI : NumberedInstructions) {
232     const Record *R = CGI->TheDef;
233     // Create the corresponding VarLenInst instance.
234     if (R->getValueAsString("Namespace") == "TargetOpcode" ||
235         R->getValueAsBit("isPseudo"))
236       continue;
237 
238     // Setup alternative encodings according to HwModes
239     if (const RecordVal *RV = R->getValue("EncodingInfos")) {
240       if (auto *DI = dyn_cast_or_null<DefInit>(RV->getValue())) {
241         const CodeGenHwModes &HWM = Target.getHwModes();
242         EncodingInfoByHwMode EBM(DI->getDef(), HWM);
243         for (const auto [Mode, EncodingDef] : EBM) {
244           Modes.try_emplace(Mode, "_" + HWM.getMode(Mode).Name.str());
245           const RecordVal *RV = EncodingDef->getValue("Inst");
246           const DagInit *DI = cast<DagInit>(RV->getValue());
247           VarLenInsts[R].try_emplace(Mode, VarLenInst(DI, RV));
248         }
249         continue;
250       }
251     }
252     const RecordVal *RV = R->getValue("Inst");
253     const DagInit *DI = cast<DagInit>(RV->getValue());
254     VarLenInsts[R].try_emplace(Universal, VarLenInst(DI, RV));
255   }
256 
257   if (Modes.empty())
258     Modes.try_emplace(Universal, ""); // Base case, skip suffix.
259 
260   // Emit function declaration
261   OS << "void " << Target.getName()
262      << "MCCodeEmitter::getBinaryCodeForInstr(const MCInst &MI,\n"
263      << "    SmallVectorImpl<MCFixup> &Fixups,\n"
264      << "    APInt &Inst,\n"
265      << "    APInt &Scratch,\n"
266      << "    const MCSubtargetInfo &STI) const {\n";
267 
268   // Emit instruction base values
269   for (const auto &Mode : Modes)
270     emitInstructionBaseValues(OS, NumberedInstructions, Target, Mode.first);
271 
272   if (Modes.size() > 1) {
273     OS << "  unsigned Mode = STI.getHwMode();\n";
274   }
275 
276   for (const auto &Mode : Modes) {
277     // Emit helper function to retrieve base values.
278     OS << "  auto getInstBits" << Mode.second
279        << " = [&](unsigned Opcode) -> APInt {\n"
280        << "    unsigned NumBits = Index" << Mode.second << "[Opcode][0];\n"
281        << "    if (!NumBits)\n"
282        << "      return APInt::getZeroWidth();\n"
283        << "    unsigned Idx = Index" << Mode.second << "[Opcode][1];\n"
284        << "    ArrayRef<uint64_t> Data(&InstBits" << Mode.second << "[Idx], "
285        << "APInt::getNumWords(NumBits));\n"
286        << "    return APInt(NumBits, Data);\n"
287        << "  };\n";
288   }
289 
290   // Map to accumulate all the cases.
291   std::map<std::string, std::vector<std::string>> CaseMap;
292 
293   // Construct all cases statement for each opcode
294   for (const Record *R : Records.getAllDerivedDefinitions("Instruction")) {
295     if (R->getValueAsString("Namespace") == "TargetOpcode" ||
296         R->getValueAsBit("isPseudo"))
297       continue;
298     std::string InstName =
299         (R->getValueAsString("Namespace") + "::" + R->getName()).str();
300     std::string Case = getInstructionCases(R, Target);
301 
302     CaseMap[Case].push_back(std::move(InstName));
303   }
304 
305   // Emit initial function code
306   OS << "  const unsigned opcode = MI.getOpcode();\n"
307      << "  switch (opcode) {\n";
308 
309   // Emit each case statement
310   for (const auto &C : CaseMap) {
311     const std::string &Case = C.first;
312     const auto &InstList = C.second;
313 
314     ListSeparator LS("\n");
315     for (const auto &InstName : InstList)
316       OS << LS << "    case " << InstName << ":";
317 
318     OS << " {\n";
319     OS << Case;
320     OS << "      break;\n"
321        << "    }\n";
322   }
323   // Default case: unhandled opcode
324   OS << "  default:\n"
325      << "    std::string msg;\n"
326      << "    raw_string_ostream Msg(msg);\n"
327      << "    Msg << \"Not supported instr: \" << MI;\n"
328      << "    report_fatal_error(Msg.str().c_str());\n"
329      << "  }\n";
330   OS << "}\n\n";
331 }
332 
333 static void emitInstBits(raw_ostream &IS, raw_ostream &SS, const APInt &Bits,
334                          unsigned &Index) {
335   if (!Bits.getNumWords()) {
336     IS.indent(4) << "{/*NumBits*/0, /*Index*/0},";
337     return;
338   }
339 
340   IS.indent(4) << "{/*NumBits*/" << Bits.getBitWidth() << ", " << "/*Index*/"
341                << Index << "},";
342 
343   SS.indent(4);
344   for (unsigned I = 0; I < Bits.getNumWords(); ++I, ++Index)
345     SS << "UINT64_C(" << utostr(Bits.getRawData()[I]) << "),";
346 }
347 
348 void VarLenCodeEmitterGen::emitInstructionBaseValues(
349     raw_ostream &OS, ArrayRef<const CodeGenInstruction *> NumberedInstructions,
350     const CodeGenTarget &Target, AltEncodingTy Mode) {
351   std::string IndexArray, StorageArray;
352   raw_string_ostream IS(IndexArray), SS(StorageArray);
353 
354   IS << "  static const unsigned Index" << Modes[Mode] << "[][2] = {\n";
355   SS << "  static const uint64_t InstBits" << Modes[Mode] << "[] = {\n";
356 
357   unsigned NumFixedValueWords = 0U;
358   for (const CodeGenInstruction *CGI : NumberedInstructions) {
359     const Record *R = CGI->TheDef;
360 
361     if (R->getValueAsString("Namespace") == "TargetOpcode" ||
362         R->getValueAsBit("isPseudo")) {
363       IS.indent(4) << "{/*NumBits*/0, /*Index*/0},\n";
364       continue;
365     }
366 
367     const auto InstIt = VarLenInsts.find(R);
368     if (InstIt == VarLenInsts.end())
369       PrintFatalError(R, "VarLenInst not found for this record");
370     auto ModeIt = InstIt->second.find(Mode);
371     if (ModeIt == InstIt->second.end())
372       ModeIt = InstIt->second.find(Universal);
373     if (ModeIt == InstIt->second.end()) {
374       IS.indent(4) << "{/*NumBits*/0, /*Index*/0},\t" << "// " << R->getName()
375                    << " no encoding\n";
376       continue;
377     }
378     const VarLenInst &VLI = ModeIt->second;
379     unsigned i = 0U, BitWidth = VLI.size();
380 
381     // Start by filling in fixed values.
382     APInt Value(BitWidth, 0);
383     auto SI = VLI.begin(), SE = VLI.end();
384     // Scan through all the segments that have fixed-bits values.
385     while (i < BitWidth && SI != SE) {
386       unsigned SegmentNumBits = SI->BitWidth;
387       if (const auto *BI = dyn_cast<BitsInit>(SI->Value)) {
388         for (unsigned Idx = 0U; Idx != SegmentNumBits; ++Idx) {
389           auto *B = cast<BitInit>(BI->getBit(Idx));
390           Value.setBitVal(i + Idx, B->getValue());
391         }
392       }
393       if (const auto *BI = dyn_cast<BitInit>(SI->Value))
394         Value.setBitVal(i, BI->getValue());
395 
396       i += SegmentNumBits;
397       ++SI;
398     }
399 
400     emitInstBits(IS, SS, Value, NumFixedValueWords);
401     IS << '\t' << "// " << R->getName() << "\n";
402     if (Value.getNumWords())
403       SS << '\t' << "// " << R->getName() << "\n";
404   }
405   IS.indent(4) << "{/*NumBits*/0, /*Index*/0}\n  };\n";
406   SS.indent(4) << "UINT64_C(0)\n  };\n";
407 
408   OS << IndexArray << StorageArray;
409 }
410 
411 std::string
412 VarLenCodeEmitterGen::getInstructionCases(const Record *R,
413                                           const CodeGenTarget &Target) {
414   auto It = VarLenInsts.find(R);
415   if (It == VarLenInsts.end())
416     PrintFatalError(R, "Parsed encoding record not found");
417   const auto &Map = It->second;
418 
419   // Is this instructions encoding universal (same for all modes)?
420   // Allways true if there is only one mode.
421   if (Map.size() == 1 && Map.begin()->first == Universal) {
422     // Universal, just pick the first mode.
423     AltEncodingTy Mode = Modes.begin()->first;
424     const auto &Encoding = Map.begin()->second;
425     return getInstructionCaseForEncoding(R, Mode, Encoding, Target,
426                                          /*Indent=*/6);
427   }
428 
429   std::string Case;
430   Case += "      switch (Mode) {\n";
431   Case += "      default: llvm_unreachable(\"Unhandled Mode\");\n";
432   for (const auto &Mode : Modes) {
433     Case += "      case " + itostr(Mode.first) + ": {\n";
434     const auto &It = Map.find(Mode.first);
435     if (It == Map.end()) {
436       Case +=
437           "        llvm_unreachable(\"Undefined encoding in this mode\");\n";
438     } else {
439       Case += getInstructionCaseForEncoding(R, It->first, It->second, Target,
440                                             /*Indent=*/8);
441     }
442     Case += "        break;\n";
443     Case += "      }\n";
444   }
445   Case += "      }\n";
446   return Case;
447 }
448 
449 std::string VarLenCodeEmitterGen::getInstructionCaseForEncoding(
450     const Record *R, AltEncodingTy Mode, const VarLenInst &VLI,
451     const CodeGenTarget &Target, int Indent) {
452   CodeGenInstruction &CGI = Target.getInstruction(R);
453 
454   std::string Case;
455   raw_string_ostream SS(Case);
456   // Populate based value.
457   SS.indent(Indent) << "Inst = getInstBits" << Modes[Mode] << "(opcode);\n";
458 
459   // Process each segment in VLI.
460   size_t Offset = 0U;
461   unsigned HighScratchAccess = 0U;
462   for (const auto &ES : VLI) {
463     unsigned NumBits = ES.BitWidth;
464     const Init *Val = ES.Value;
465     // If it's a StringInit or DagInit, it's a reference to an operand
466     // or part of an operand.
467     if (isa<StringInit>(Val) || isa<DagInit>(Val)) {
468       StringRef OperandName;
469       unsigned LoBit = 0U;
470       if (const auto *SV = dyn_cast<StringInit>(Val)) {
471         OperandName = SV->getValue();
472       } else {
473         // Normalized: (slice <operand name>, <high bit>, <low bit>)
474         const auto *DV = cast<DagInit>(Val);
475         OperandName = cast<StringInit>(DV->getArg(0))->getValue();
476         LoBit = static_cast<unsigned>(cast<IntInit>(DV->getArg(2))->getValue());
477       }
478 
479       auto OpIdx = CGI.Operands.ParseOperandName(OperandName);
480       unsigned FlatOpIdx = CGI.Operands.getFlattenedOperandNumber(OpIdx);
481       StringRef CustomEncoder =
482           CGI.Operands[OpIdx.first].EncoderMethodNames[OpIdx.second];
483       if (ES.CustomEncoder.size())
484         CustomEncoder = ES.CustomEncoder;
485 
486       SS.indent(Indent) << "Scratch.clearAllBits();\n";
487       SS.indent(Indent) << "// op: " << OperandName.drop_front(1) << "\n";
488       if (CustomEncoder.empty())
489         SS.indent(Indent) << "getMachineOpValue(MI, MI.getOperand("
490                           << utostr(FlatOpIdx) << ")";
491       else
492         SS.indent(Indent) << CustomEncoder << "(MI, /*OpIdx=*/"
493                           << utostr(FlatOpIdx);
494 
495       SS << ", /*Pos=*/" << utostr(Offset) << ", Scratch, Fixups, STI);\n";
496 
497       SS.indent(Indent) << "Inst.insertBits("
498                         << "Scratch.extractBits(" << utostr(NumBits) << ", "
499                         << utostr(LoBit) << ")"
500                         << ", " << Offset << ");\n";
501 
502       HighScratchAccess = std::max(HighScratchAccess, NumBits + LoBit);
503     }
504     Offset += NumBits;
505   }
506 
507   StringRef PostEmitter = R->getValueAsString("PostEncoderMethod");
508   if (!PostEmitter.empty())
509     SS.indent(Indent) << "Inst = " << PostEmitter << "(MI, Inst, STI);\n";
510 
511   // Resize the scratch buffer if it's to small.
512   std::string ScratchResizeStr;
513   if (VLI.size() && !VLI.isFixedValueOnly()) {
514     raw_string_ostream RS(ScratchResizeStr);
515     RS.indent(Indent) << "if (Scratch.getBitWidth() < " << HighScratchAccess
516                       << ") { Scratch = Scratch.zext(" << HighScratchAccess
517                       << "); }\n";
518   }
519 
520   return ScratchResizeStr + Case;
521 }
522 
523 void llvm::emitVarLenCodeEmitter(const RecordKeeper &R, raw_ostream &OS) {
524   VarLenCodeEmitterGen(R).run(OS);
525 }
526