1 //===------------------- RISCVCustomBehaviour.cpp ---------------*-C++ -* -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 /// \file
9 ///
10 /// This file implements methods from the RISCVCustomBehaviour class.
11 ///
12 //===----------------------------------------------------------------------===//
13
14 #include "RISCVCustomBehaviour.h"
15 #include "MCTargetDesc/RISCVMCTargetDesc.h"
16 #include "RISCV.h"
17 #include "TargetInfo/RISCVTargetInfo.h"
18 #include "llvm/MC/TargetRegistry.h"
19 #include "llvm/Support/Debug.h"
20
21 #define DEBUG_TYPE "llvm-mca-riscv-custombehaviour"
22
23 namespace llvm {
24 namespace mca {
25
26 const llvm::StringRef RISCVLMULInstrument::DESC_NAME = "RISCV-LMUL";
27
isDataValid(llvm::StringRef Data)28 bool RISCVLMULInstrument::isDataValid(llvm::StringRef Data) {
29 // Return true if not one of the valid LMUL strings
30 return StringSwitch<bool>(Data)
31 .Cases("M1", "M2", "M4", "M8", "MF2", "MF4", "MF8", true)
32 .Default(false);
33 }
34
getLMUL() const35 uint8_t RISCVLMULInstrument::getLMUL() const {
36 // assertion prevents us from needing llvm_unreachable in the StringSwitch
37 // below
38 assert(isDataValid(getData()) &&
39 "Cannot get LMUL because invalid Data value");
40 // These are the LMUL values that are used in RISC-V tablegen
41 return StringSwitch<uint8_t>(getData())
42 .Case("M1", 0b000)
43 .Case("M2", 0b001)
44 .Case("M4", 0b010)
45 .Case("M8", 0b011)
46 .Case("MF2", 0b111)
47 .Case("MF4", 0b110)
48 .Case("MF8", 0b101);
49 }
50
51 const llvm::StringRef RISCVSEWInstrument::DESC_NAME = "RISCV-SEW";
52
isDataValid(llvm::StringRef Data)53 bool RISCVSEWInstrument::isDataValid(llvm::StringRef Data) {
54 // Return true if not one of the valid SEW strings
55 return StringSwitch<bool>(Data)
56 .Cases("E8", "E16", "E32", "E64", true)
57 .Default(false);
58 }
59
getSEW() const60 uint8_t RISCVSEWInstrument::getSEW() const {
61 // assertion prevents us from needing llvm_unreachable in the StringSwitch
62 // below
63 assert(isDataValid(getData()) && "Cannot get SEW because invalid Data value");
64 // These are the LMUL values that are used in RISC-V tablegen
65 return StringSwitch<uint8_t>(getData())
66 .Case("E8", 8)
67 .Case("E16", 16)
68 .Case("E32", 32)
69 .Case("E64", 64);
70 }
71
supportsInstrumentType(llvm::StringRef Type) const72 bool RISCVInstrumentManager::supportsInstrumentType(
73 llvm::StringRef Type) const {
74 return Type == RISCVLMULInstrument::DESC_NAME ||
75 Type == RISCVSEWInstrument::DESC_NAME;
76 }
77
78 UniqueInstrument
createInstrument(llvm::StringRef Desc,llvm::StringRef Data)79 RISCVInstrumentManager::createInstrument(llvm::StringRef Desc,
80 llvm::StringRef Data) {
81 if (Desc == RISCVLMULInstrument::DESC_NAME) {
82 if (!RISCVLMULInstrument::isDataValid(Data)) {
83 LLVM_DEBUG(dbgs() << "RVCB: Bad data for instrument kind " << Desc << ": "
84 << Data << '\n');
85 return nullptr;
86 }
87 return std::make_unique<RISCVLMULInstrument>(Data);
88 }
89
90 if (Desc == RISCVSEWInstrument::DESC_NAME) {
91 if (!RISCVSEWInstrument::isDataValid(Data)) {
92 LLVM_DEBUG(dbgs() << "RVCB: Bad data for instrument kind " << Desc << ": "
93 << Data << '\n');
94 return nullptr;
95 }
96 return std::make_unique<RISCVSEWInstrument>(Data);
97 }
98
99 LLVM_DEBUG(dbgs() << "RVCB: Unknown instrumentation Desc: " << Desc << '\n');
100 return nullptr;
101 }
102
103 SmallVector<UniqueInstrument>
createInstruments(const MCInst & Inst)104 RISCVInstrumentManager::createInstruments(const MCInst &Inst) {
105 if (Inst.getOpcode() == RISCV::VSETVLI ||
106 Inst.getOpcode() == RISCV::VSETIVLI) {
107 LLVM_DEBUG(dbgs() << "RVCB: Found VSETVLI and creating instrument for it: "
108 << Inst << "\n");
109 unsigned VTypeI = Inst.getOperand(2).getImm();
110 RISCVII::VLMUL VLMUL = RISCVVType::getVLMUL(VTypeI);
111
112 StringRef LMUL;
113 switch (VLMUL) {
114 case RISCVII::LMUL_1:
115 LMUL = "M1";
116 break;
117 case RISCVII::LMUL_2:
118 LMUL = "M2";
119 break;
120 case RISCVII::LMUL_4:
121 LMUL = "M4";
122 break;
123 case RISCVII::LMUL_8:
124 LMUL = "M8";
125 break;
126 case RISCVII::LMUL_F2:
127 LMUL = "MF2";
128 break;
129 case RISCVII::LMUL_F4:
130 LMUL = "MF4";
131 break;
132 case RISCVII::LMUL_F8:
133 LMUL = "MF8";
134 break;
135 case RISCVII::LMUL_RESERVED:
136 llvm_unreachable("Cannot create instrument for LMUL_RESERVED");
137 }
138 SmallVector<UniqueInstrument> Instruments;
139 Instruments.emplace_back(
140 createInstrument(RISCVLMULInstrument::DESC_NAME, LMUL));
141
142 unsigned SEW = RISCVVType::getSEW(VTypeI);
143 StringRef SEWStr;
144 switch (SEW) {
145 case 8:
146 SEWStr = "E8";
147 break;
148 case 16:
149 SEWStr = "E16";
150 break;
151 case 32:
152 SEWStr = "E32";
153 break;
154 case 64:
155 SEWStr = "E64";
156 break;
157 default:
158 llvm_unreachable("Cannot create instrument for SEW");
159 }
160 Instruments.emplace_back(
161 createInstrument(RISCVSEWInstrument::DESC_NAME, SEWStr));
162
163 return Instruments;
164 }
165 return SmallVector<UniqueInstrument>();
166 }
167
168 static std::pair<uint8_t, uint8_t>
getEEWAndEMUL(unsigned Opcode,RISCVII::VLMUL LMUL,uint8_t SEW)169 getEEWAndEMUL(unsigned Opcode, RISCVII::VLMUL LMUL, uint8_t SEW) {
170 uint8_t EEW;
171 switch (Opcode) {
172 case RISCV::VLM_V:
173 case RISCV::VSM_V:
174 case RISCV::VLE8_V:
175 case RISCV::VSE8_V:
176 case RISCV::VLSE8_V:
177 case RISCV::VSSE8_V:
178 EEW = 8;
179 break;
180 case RISCV::VLE16_V:
181 case RISCV::VSE16_V:
182 case RISCV::VLSE16_V:
183 case RISCV::VSSE16_V:
184 EEW = 16;
185 break;
186 case RISCV::VLE32_V:
187 case RISCV::VSE32_V:
188 case RISCV::VLSE32_V:
189 case RISCV::VSSE32_V:
190 EEW = 32;
191 break;
192 case RISCV::VLE64_V:
193 case RISCV::VSE64_V:
194 case RISCV::VLSE64_V:
195 case RISCV::VSSE64_V:
196 EEW = 64;
197 break;
198 default:
199 llvm_unreachable("Could not determine EEW from Opcode");
200 }
201
202 auto EMUL = RISCVVType::getSameRatioLMUL(SEW, LMUL, EEW);
203 if (!EEW)
204 llvm_unreachable("Invalid SEW or LMUL for new ratio");
205 return std::make_pair(EEW, *EMUL);
206 }
207
opcodeHasEEWAndEMULInfo(unsigned short Opcode)208 bool opcodeHasEEWAndEMULInfo(unsigned short Opcode) {
209 return Opcode == RISCV::VLM_V || Opcode == RISCV::VSM_V ||
210 Opcode == RISCV::VLE8_V || Opcode == RISCV::VSE8_V ||
211 Opcode == RISCV::VLE16_V || Opcode == RISCV::VSE16_V ||
212 Opcode == RISCV::VLE32_V || Opcode == RISCV::VSE32_V ||
213 Opcode == RISCV::VLE64_V || Opcode == RISCV::VSE64_V ||
214 Opcode == RISCV::VLSE8_V || Opcode == RISCV::VSSE8_V ||
215 Opcode == RISCV::VLSE16_V || Opcode == RISCV::VSSE16_V ||
216 Opcode == RISCV::VLSE32_V || Opcode == RISCV::VSSE32_V ||
217 Opcode == RISCV::VLSE64_V || Opcode == RISCV::VSSE64_V;
218 }
219
getSchedClassID(const MCInstrInfo & MCII,const MCInst & MCI,const llvm::SmallVector<Instrument * > & IVec) const220 unsigned RISCVInstrumentManager::getSchedClassID(
221 const MCInstrInfo &MCII, const MCInst &MCI,
222 const llvm::SmallVector<Instrument *> &IVec) const {
223 unsigned short Opcode = MCI.getOpcode();
224 unsigned SchedClassID = MCII.get(Opcode).getSchedClass();
225
226 // Unpack all possible RISC-V instruments from IVec.
227 RISCVLMULInstrument *LI = nullptr;
228 RISCVSEWInstrument *SI = nullptr;
229 for (auto &I : IVec) {
230 if (I->getDesc() == RISCVLMULInstrument::DESC_NAME)
231 LI = static_cast<RISCVLMULInstrument *>(I);
232 else if (I->getDesc() == RISCVSEWInstrument::DESC_NAME)
233 SI = static_cast<RISCVSEWInstrument *>(I);
234 }
235
236 // Need LMUL or LMUL, SEW in order to override opcode. If no LMUL is provided,
237 // then no option to override.
238 if (!LI) {
239 LLVM_DEBUG(
240 dbgs() << "RVCB: Did not use instrumentation to override Opcode.\n");
241 return SchedClassID;
242 }
243 uint8_t LMUL = LI->getLMUL();
244
245 // getBaseInfo works with (Opcode, LMUL, 0) if no SEW instrument,
246 // or (Opcode, LMUL, SEW) if SEW instrument is active, and depends on LMUL
247 // and SEW, or (Opcode, LMUL, 0) if does not depend on SEW.
248 uint8_t SEW = SI ? SI->getSEW() : 0;
249
250 const RISCVVInversePseudosTable::PseudoInfo *RVV = nullptr;
251 if (opcodeHasEEWAndEMULInfo(Opcode)) {
252 RISCVII::VLMUL VLMUL = static_cast<RISCVII::VLMUL>(LMUL);
253 auto [EEW, EMUL] = getEEWAndEMUL(Opcode, VLMUL, SEW);
254 RVV = RISCVVInversePseudosTable::getBaseInfo(Opcode, EMUL, EEW);
255 } else {
256 // Check if it depends on LMUL and SEW
257 RVV = RISCVVInversePseudosTable::getBaseInfo(Opcode, LMUL, SEW);
258 // Check if it depends only on LMUL
259 if (!RVV)
260 RVV = RISCVVInversePseudosTable::getBaseInfo(Opcode, LMUL, 0);
261 }
262
263 // Not a RVV instr
264 if (!RVV) {
265 LLVM_DEBUG(
266 dbgs() << "RVCB: Could not find PseudoInstruction for Opcode "
267 << MCII.getName(Opcode)
268 << ", LMUL=" << (LI ? LI->getData() : "Unspecified")
269 << ", SEW=" << (SI ? SI->getData() : "Unspecified")
270 << ". Ignoring instrumentation and using original SchedClassID="
271 << SchedClassID << '\n');
272 return SchedClassID;
273 }
274
275 // Override using pseudo
276 LLVM_DEBUG(dbgs() << "RVCB: Found Pseudo Instruction for Opcode "
277 << MCII.getName(Opcode) << ", LMUL=" << LI->getData()
278 << ", SEW=" << (SI ? SI->getData() : "Unspecified")
279 << ". Overriding original SchedClassID=" << SchedClassID
280 << " with " << MCII.getName(RVV->Pseudo) << '\n');
281 return MCII.get(RVV->Pseudo).getSchedClass();
282 }
283
284 } // namespace mca
285 } // namespace llvm
286
287 using namespace llvm;
288 using namespace mca;
289
290 static InstrumentManager *
createRISCVInstrumentManager(const MCSubtargetInfo & STI,const MCInstrInfo & MCII)291 createRISCVInstrumentManager(const MCSubtargetInfo &STI,
292 const MCInstrInfo &MCII) {
293 return new RISCVInstrumentManager(STI, MCII);
294 }
295
296 /// Extern function to initialize the targets for the RISC-V backend
LLVMInitializeRISCVTargetMCA()297 extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeRISCVTargetMCA() {
298 TargetRegistry::RegisterInstrumentManager(getTheRISCV32Target(),
299 createRISCVInstrumentManager);
300 TargetRegistry::RegisterInstrumentManager(getTheRISCV64Target(),
301 createRISCVInstrumentManager);
302 }
303