xref: /freebsd/contrib/llvm-project/llvm/lib/Target/RISCV/MCA/RISCVCustomBehaviour.cpp (revision 7d0873ebb83b19ba1e8a89e679470d885efe12e3)
1 //===------------------- RISCVCustomBehaviour.cpp ---------------*-C++ -* -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 /// \file
9 ///
10 /// This file implements methods from the RISCVCustomBehaviour class.
11 ///
12 //===----------------------------------------------------------------------===//
13 
14 #include "RISCVCustomBehaviour.h"
15 #include "MCTargetDesc/RISCVMCTargetDesc.h"
16 #include "RISCV.h"
17 #include "TargetInfo/RISCVTargetInfo.h"
18 #include "llvm/MC/TargetRegistry.h"
19 #include "llvm/Support/Debug.h"
20 
21 #define DEBUG_TYPE "llvm-mca-riscv-custombehaviour"
22 
23 namespace llvm {
24 namespace mca {
25 
26 const llvm::StringRef RISCVLMULInstrument::DESC_NAME = "RISCV-LMUL";
27 
28 bool RISCVLMULInstrument::isDataValid(llvm::StringRef Data) {
29   // Return true if not one of the valid LMUL strings
30   return StringSwitch<bool>(Data)
31       .Cases("M1", "M2", "M4", "M8", "MF2", "MF4", "MF8", true)
32       .Default(false);
33 }
34 
35 uint8_t RISCVLMULInstrument::getLMUL() const {
36   // assertion prevents us from needing llvm_unreachable in the StringSwitch
37   // below
38   assert(isDataValid(getData()) &&
39          "Cannot get LMUL because invalid Data value");
40   // These are the LMUL values that are used in RISC-V tablegen
41   return StringSwitch<uint8_t>(getData())
42       .Case("M1", 0b000)
43       .Case("M2", 0b001)
44       .Case("M4", 0b010)
45       .Case("M8", 0b011)
46       .Case("MF2", 0b111)
47       .Case("MF4", 0b110)
48       .Case("MF8", 0b101);
49 }
50 
51 const llvm::StringRef RISCVSEWInstrument::DESC_NAME = "RISCV-SEW";
52 
53 bool RISCVSEWInstrument::isDataValid(llvm::StringRef Data) {
54   // Return true if not one of the valid SEW strings
55   return StringSwitch<bool>(Data)
56       .Cases("E8", "E16", "E32", "E64", true)
57       .Default(false);
58 }
59 
60 uint8_t RISCVSEWInstrument::getSEW() const {
61   // assertion prevents us from needing llvm_unreachable in the StringSwitch
62   // below
63   assert(isDataValid(getData()) && "Cannot get SEW because invalid Data value");
64   // These are the LMUL values that are used in RISC-V tablegen
65   return StringSwitch<uint8_t>(getData())
66       .Case("E8", 8)
67       .Case("E16", 16)
68       .Case("E32", 32)
69       .Case("E64", 64);
70 }
71 
72 bool RISCVInstrumentManager::supportsInstrumentType(
73     llvm::StringRef Type) const {
74   return Type == RISCVLMULInstrument::DESC_NAME ||
75          Type == RISCVSEWInstrument::DESC_NAME;
76 }
77 
78 UniqueInstrument
79 RISCVInstrumentManager::createInstrument(llvm::StringRef Desc,
80                                          llvm::StringRef Data) {
81   if (Desc == RISCVLMULInstrument::DESC_NAME) {
82     if (!RISCVLMULInstrument::isDataValid(Data)) {
83       LLVM_DEBUG(dbgs() << "RVCB: Bad data for instrument kind " << Desc << ": "
84                         << Data << '\n');
85       return nullptr;
86     }
87     return std::make_unique<RISCVLMULInstrument>(Data);
88   }
89 
90   if (Desc == RISCVSEWInstrument::DESC_NAME) {
91     if (!RISCVSEWInstrument::isDataValid(Data)) {
92       LLVM_DEBUG(dbgs() << "RVCB: Bad data for instrument kind " << Desc << ": "
93                         << Data << '\n');
94       return nullptr;
95     }
96     return std::make_unique<RISCVSEWInstrument>(Data);
97   }
98 
99   LLVM_DEBUG(dbgs() << "RVCB: Unknown instrumentation Desc: " << Desc << '\n');
100   return nullptr;
101 }
102 
103 SmallVector<UniqueInstrument>
104 RISCVInstrumentManager::createInstruments(const MCInst &Inst) {
105   if (Inst.getOpcode() == RISCV::VSETVLI ||
106       Inst.getOpcode() == RISCV::VSETIVLI) {
107     LLVM_DEBUG(dbgs() << "RVCB: Found VSETVLI and creating instrument for it: "
108                       << Inst << "\n");
109     unsigned VTypeI = Inst.getOperand(2).getImm();
110     RISCVII::VLMUL VLMUL = RISCVVType::getVLMUL(VTypeI);
111 
112     StringRef LMUL;
113     switch (VLMUL) {
114     case RISCVII::LMUL_1:
115       LMUL = "M1";
116       break;
117     case RISCVII::LMUL_2:
118       LMUL = "M2";
119       break;
120     case RISCVII::LMUL_4:
121       LMUL = "M4";
122       break;
123     case RISCVII::LMUL_8:
124       LMUL = "M8";
125       break;
126     case RISCVII::LMUL_F2:
127       LMUL = "MF2";
128       break;
129     case RISCVII::LMUL_F4:
130       LMUL = "MF4";
131       break;
132     case RISCVII::LMUL_F8:
133       LMUL = "MF8";
134       break;
135     case RISCVII::LMUL_RESERVED:
136       llvm_unreachable("Cannot create instrument for LMUL_RESERVED");
137     }
138     SmallVector<UniqueInstrument> Instruments;
139     Instruments.emplace_back(
140         createInstrument(RISCVLMULInstrument::DESC_NAME, LMUL));
141 
142     unsigned SEW = RISCVVType::getSEW(VTypeI);
143     StringRef SEWStr;
144     switch (SEW) {
145     case 8:
146       SEWStr = "E8";
147       break;
148     case 16:
149       SEWStr = "E16";
150       break;
151     case 32:
152       SEWStr = "E32";
153       break;
154     case 64:
155       SEWStr = "E64";
156       break;
157     default:
158       llvm_unreachable("Cannot create instrument for SEW");
159     }
160     Instruments.emplace_back(
161         createInstrument(RISCVSEWInstrument::DESC_NAME, SEWStr));
162 
163     return Instruments;
164   }
165   return SmallVector<UniqueInstrument>();
166 }
167 
168 static std::pair<uint8_t, uint8_t>
169 getEEWAndEMUL(unsigned Opcode, RISCVII::VLMUL LMUL, uint8_t SEW) {
170   uint8_t EEW;
171   switch (Opcode) {
172   case RISCV::VLM_V:
173   case RISCV::VSM_V:
174   case RISCV::VLE8_V:
175   case RISCV::VSE8_V:
176   case RISCV::VLSE8_V:
177   case RISCV::VSSE8_V:
178     EEW = 8;
179     break;
180   case RISCV::VLE16_V:
181   case RISCV::VSE16_V:
182   case RISCV::VLSE16_V:
183   case RISCV::VSSE16_V:
184     EEW = 16;
185     break;
186   case RISCV::VLE32_V:
187   case RISCV::VSE32_V:
188   case RISCV::VLSE32_V:
189   case RISCV::VSSE32_V:
190     EEW = 32;
191     break;
192   case RISCV::VLE64_V:
193   case RISCV::VSE64_V:
194   case RISCV::VLSE64_V:
195   case RISCV::VSSE64_V:
196     EEW = 64;
197     break;
198   default:
199     llvm_unreachable("Could not determine EEW from Opcode");
200   }
201 
202   auto EMUL = RISCVVType::getSameRatioLMUL(SEW, LMUL, EEW);
203   if (!EEW)
204     llvm_unreachable("Invalid SEW or LMUL for new ratio");
205   return std::make_pair(EEW, *EMUL);
206 }
207 
208 bool opcodeHasEEWAndEMULInfo(unsigned short Opcode) {
209   return Opcode == RISCV::VLM_V || Opcode == RISCV::VSM_V ||
210          Opcode == RISCV::VLE8_V || Opcode == RISCV::VSE8_V ||
211          Opcode == RISCV::VLE16_V || Opcode == RISCV::VSE16_V ||
212          Opcode == RISCV::VLE32_V || Opcode == RISCV::VSE32_V ||
213          Opcode == RISCV::VLE64_V || Opcode == RISCV::VSE64_V ||
214          Opcode == RISCV::VLSE8_V || Opcode == RISCV::VSSE8_V ||
215          Opcode == RISCV::VLSE16_V || Opcode == RISCV::VSSE16_V ||
216          Opcode == RISCV::VLSE32_V || Opcode == RISCV::VSSE32_V ||
217          Opcode == RISCV::VLSE64_V || Opcode == RISCV::VSSE64_V;
218 }
219 
220 unsigned RISCVInstrumentManager::getSchedClassID(
221     const MCInstrInfo &MCII, const MCInst &MCI,
222     const llvm::SmallVector<Instrument *> &IVec) const {
223   unsigned short Opcode = MCI.getOpcode();
224   unsigned SchedClassID = MCII.get(Opcode).getSchedClass();
225 
226   // Unpack all possible RISC-V instruments from IVec.
227   RISCVLMULInstrument *LI = nullptr;
228   RISCVSEWInstrument *SI = nullptr;
229   for (auto &I : IVec) {
230     if (I->getDesc() == RISCVLMULInstrument::DESC_NAME)
231       LI = static_cast<RISCVLMULInstrument *>(I);
232     else if (I->getDesc() == RISCVSEWInstrument::DESC_NAME)
233       SI = static_cast<RISCVSEWInstrument *>(I);
234   }
235 
236   // Need LMUL or LMUL, SEW in order to override opcode. If no LMUL is provided,
237   // then no option to override.
238   if (!LI) {
239     LLVM_DEBUG(
240         dbgs() << "RVCB: Did not use instrumentation to override Opcode.\n");
241     return SchedClassID;
242   }
243   uint8_t LMUL = LI->getLMUL();
244 
245   // getBaseInfo works with (Opcode, LMUL, 0) if no SEW instrument,
246   // or (Opcode, LMUL, SEW) if SEW instrument is active, and depends on LMUL
247   // and SEW, or (Opcode, LMUL, 0) if does not depend on SEW.
248   uint8_t SEW = SI ? SI->getSEW() : 0;
249 
250   const RISCVVInversePseudosTable::PseudoInfo *RVV = nullptr;
251   if (opcodeHasEEWAndEMULInfo(Opcode)) {
252     RISCVII::VLMUL VLMUL = static_cast<RISCVII::VLMUL>(LMUL);
253     auto [EEW, EMUL] = getEEWAndEMUL(Opcode, VLMUL, SEW);
254     RVV = RISCVVInversePseudosTable::getBaseInfo(Opcode, EMUL, EEW);
255   } else {
256     // Check if it depends on LMUL and SEW
257     RVV = RISCVVInversePseudosTable::getBaseInfo(Opcode, LMUL, SEW);
258     // Check if it depends only on LMUL
259     if (!RVV)
260       RVV = RISCVVInversePseudosTable::getBaseInfo(Opcode, LMUL, 0);
261   }
262 
263   // Not a RVV instr
264   if (!RVV) {
265     LLVM_DEBUG(
266         dbgs() << "RVCB: Could not find PseudoInstruction for Opcode "
267                << MCII.getName(Opcode)
268                << ", LMUL=" << (LI ? LI->getData() : "Unspecified")
269                << ", SEW=" << (SI ? SI->getData() : "Unspecified")
270                << ". Ignoring instrumentation and using original SchedClassID="
271                << SchedClassID << '\n');
272     return SchedClassID;
273   }
274 
275   // Override using pseudo
276   LLVM_DEBUG(dbgs() << "RVCB: Found Pseudo Instruction for Opcode "
277                     << MCII.getName(Opcode) << ", LMUL=" << LI->getData()
278                     << ", SEW=" << (SI ? SI->getData() : "Unspecified")
279                     << ". Overriding original SchedClassID=" << SchedClassID
280                     << " with " << MCII.getName(RVV->Pseudo) << '\n');
281   return MCII.get(RVV->Pseudo).getSchedClass();
282 }
283 
284 } // namespace mca
285 } // namespace llvm
286 
287 using namespace llvm;
288 using namespace mca;
289 
290 static InstrumentManager *
291 createRISCVInstrumentManager(const MCSubtargetInfo &STI,
292                              const MCInstrInfo &MCII) {
293   return new RISCVInstrumentManager(STI, MCII);
294 }
295 
296 /// Extern function to initialize the targets for the RISC-V backend
297 extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeRISCVTargetMCA() {
298   TargetRegistry::RegisterInstrumentManager(getTheRISCV32Target(),
299                                             createRISCVInstrumentManager);
300   TargetRegistry::RegisterInstrumentManager(getTheRISCV64Target(),
301                                             createRISCVInstrumentManager);
302 }
303