1 //==- WebAssemblyAsmTypeCheck.cpp - Assembler for WebAssembly -*- C++ -*-==//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
9 /// \file
10 /// This file is part of the WebAssembly Assembler.
11 ///
/// It contains the type checker for instructions parsed from a .s file.
13 ///
14 //===----------------------------------------------------------------------===//
15
16 #include "AsmParser/WebAssemblyAsmTypeCheck.h"
17 #include "MCTargetDesc/WebAssemblyMCTargetDesc.h"
18 #include "MCTargetDesc/WebAssemblyMCTypeUtilities.h"
19 #include "MCTargetDesc/WebAssemblyTargetStreamer.h"
20 #include "TargetInfo/WebAssemblyTargetInfo.h"
21 #include "WebAssembly.h"
22 #include "llvm/MC/MCContext.h"
23 #include "llvm/MC/MCExpr.h"
24 #include "llvm/MC/MCInst.h"
25 #include "llvm/MC/MCInstrInfo.h"
26 #include "llvm/MC/MCParser/MCParsedAsmOperand.h"
27 #include "llvm/MC/MCParser/MCTargetAsmParser.h"
28 #include "llvm/MC/MCSectionWasm.h"
29 #include "llvm/MC/MCStreamer.h"
30 #include "llvm/MC/MCSubtargetInfo.h"
31 #include "llvm/MC/MCSymbol.h"
32 #include "llvm/MC/MCSymbolWasm.h"
33 #include "llvm/MC/TargetRegistry.h"
34 #include "llvm/Support/Compiler.h"
35 #include "llvm/Support/SourceMgr.h"
36
37 using namespace llvm;
38
39 #define DEBUG_TYPE "wasm-asm-parser"
40
41 extern StringRef GetMnemonic(unsigned Opc);
42
43 namespace llvm {
44
WebAssemblyAsmTypeCheck(MCAsmParser & Parser,const MCInstrInfo & MII,bool is64)45 WebAssemblyAsmTypeCheck::WebAssemblyAsmTypeCheck(MCAsmParser &Parser,
46 const MCInstrInfo &MII,
47 bool is64)
48 : Parser(Parser), MII(MII), is64(is64) {}
49
funcDecl(const wasm::WasmSignature & Sig)50 void WebAssemblyAsmTypeCheck::funcDecl(const wasm::WasmSignature &Sig) {
51 LocalTypes.assign(Sig.Params.begin(), Sig.Params.end());
52 ReturnTypes.assign(Sig.Returns.begin(), Sig.Returns.end());
53 BrStack.emplace_back(Sig.Returns.begin(), Sig.Returns.end());
54 }
55
localDecl(const SmallVectorImpl<wasm::ValType> & Locals)56 void WebAssemblyAsmTypeCheck::localDecl(
57 const SmallVectorImpl<wasm::ValType> &Locals) {
58 LocalTypes.insert(LocalTypes.end(), Locals.begin(), Locals.end());
59 }
60
dumpTypeStack(Twine Msg)61 void WebAssemblyAsmTypeCheck::dumpTypeStack(Twine Msg) {
62 LLVM_DEBUG({
63 std::string s;
64 for (auto VT : Stack) {
65 s += WebAssembly::typeToString(VT);
66 s += " ";
67 }
68 dbgs() << Msg << s << '\n';
69 });
70 }
71
typeError(SMLoc ErrorLoc,const Twine & Msg)72 bool WebAssemblyAsmTypeCheck::typeError(SMLoc ErrorLoc, const Twine &Msg) {
73 // Once you get one type error in a function, it will likely trigger more
74 // which are mostly not helpful.
75 if (TypeErrorThisFunction)
76 return true;
77 // If we're currently in unreachable code, we suppress errors completely.
78 if (Unreachable)
79 return false;
80 TypeErrorThisFunction = true;
81 dumpTypeStack("current stack: ");
82 return Parser.Error(ErrorLoc, Msg);
83 }
84
popType(SMLoc ErrorLoc,std::optional<wasm::ValType> EVT)85 bool WebAssemblyAsmTypeCheck::popType(SMLoc ErrorLoc,
86 std::optional<wasm::ValType> EVT) {
87 if (Stack.empty()) {
88 return typeError(ErrorLoc,
89 EVT ? StringRef("empty stack while popping ") +
90 WebAssembly::typeToString(*EVT)
91 : StringRef("empty stack while popping value"));
92 }
93 auto PVT = Stack.pop_back_val();
94 if (EVT && *EVT != PVT) {
95 return typeError(ErrorLoc,
96 StringRef("popped ") + WebAssembly::typeToString(PVT) +
97 ", expected " + WebAssembly::typeToString(*EVT));
98 }
99 return false;
100 }
101
popRefType(SMLoc ErrorLoc)102 bool WebAssemblyAsmTypeCheck::popRefType(SMLoc ErrorLoc) {
103 if (Stack.empty()) {
104 return typeError(ErrorLoc, StringRef("empty stack while popping reftype"));
105 }
106 auto PVT = Stack.pop_back_val();
107 if (!WebAssembly::isRefType(PVT)) {
108 return typeError(ErrorLoc, StringRef("popped ") +
109 WebAssembly::typeToString(PVT) +
110 ", expected reftype");
111 }
112 return false;
113 }
114
getLocal(SMLoc ErrorLoc,const MCInst & Inst,wasm::ValType & Type)115 bool WebAssemblyAsmTypeCheck::getLocal(SMLoc ErrorLoc, const MCInst &Inst,
116 wasm::ValType &Type) {
117 auto Local = static_cast<size_t>(Inst.getOperand(0).getImm());
118 if (Local >= LocalTypes.size())
119 return typeError(ErrorLoc, StringRef("no local type specified for index ") +
120 std::to_string(Local));
121 Type = LocalTypes[Local];
122 return false;
123 }
124
125 static std::optional<std::string>
checkStackTop(const SmallVectorImpl<wasm::ValType> & ExpectedStackTop,const SmallVectorImpl<wasm::ValType> & Got)126 checkStackTop(const SmallVectorImpl<wasm::ValType> &ExpectedStackTop,
127 const SmallVectorImpl<wasm::ValType> &Got) {
128 for (size_t I = 0; I < ExpectedStackTop.size(); I++) {
129 auto EVT = ExpectedStackTop[I];
130 auto PVT = Got[Got.size() - ExpectedStackTop.size() + I];
131 if (PVT != EVT)
132 return std::string{"got "} + WebAssembly::typeToString(PVT) +
133 ", expected " + WebAssembly::typeToString(EVT);
134 }
135 return std::nullopt;
136 }
137
checkBr(SMLoc ErrorLoc,size_t Level)138 bool WebAssemblyAsmTypeCheck::checkBr(SMLoc ErrorLoc, size_t Level) {
139 if (Level >= BrStack.size())
140 return typeError(ErrorLoc,
141 StringRef("br: invalid depth ") + std::to_string(Level));
142 const SmallVector<wasm::ValType, 4> &Expected =
143 BrStack[BrStack.size() - Level - 1];
144 if (Expected.size() > Stack.size())
145 return typeError(ErrorLoc, "br: insufficient values on the type stack");
146 auto IsStackTopInvalid = checkStackTop(Expected, Stack);
147 if (IsStackTopInvalid)
148 return typeError(ErrorLoc, "br " + IsStackTopInvalid.value());
149 return false;
150 }
151
checkEnd(SMLoc ErrorLoc,bool PopVals)152 bool WebAssemblyAsmTypeCheck::checkEnd(SMLoc ErrorLoc, bool PopVals) {
153 if (!PopVals)
154 BrStack.pop_back();
155 if (LastSig.Returns.size() > Stack.size())
156 return typeError(ErrorLoc, "end: insufficient values on the type stack");
157
158 if (PopVals) {
159 for (auto VT : llvm::reverse(LastSig.Returns)) {
160 if (popType(ErrorLoc, VT))
161 return true;
162 }
163 return false;
164 }
165
166 auto IsStackTopInvalid = checkStackTop(LastSig.Returns, Stack);
167 if (IsStackTopInvalid)
168 return typeError(ErrorLoc, "end " + IsStackTopInvalid.value());
169 return false;
170 }
171
checkSig(SMLoc ErrorLoc,const wasm::WasmSignature & Sig)172 bool WebAssemblyAsmTypeCheck::checkSig(SMLoc ErrorLoc,
173 const wasm::WasmSignature &Sig) {
174 for (auto VT : llvm::reverse(Sig.Params))
175 if (popType(ErrorLoc, VT))
176 return true;
177 Stack.insert(Stack.end(), Sig.Returns.begin(), Sig.Returns.end());
178 return false;
179 }
180
getSymRef(SMLoc ErrorLoc,const MCInst & Inst,const MCSymbolRefExpr * & SymRef)181 bool WebAssemblyAsmTypeCheck::getSymRef(SMLoc ErrorLoc, const MCInst &Inst,
182 const MCSymbolRefExpr *&SymRef) {
183 auto Op = Inst.getOperand(0);
184 if (!Op.isExpr())
185 return typeError(ErrorLoc, StringRef("expected expression operand"));
186 SymRef = dyn_cast<MCSymbolRefExpr>(Op.getExpr());
187 if (!SymRef)
188 return typeError(ErrorLoc, StringRef("expected symbol operand"));
189 return false;
190 }
191
getGlobal(SMLoc ErrorLoc,const MCInst & Inst,wasm::ValType & Type)192 bool WebAssemblyAsmTypeCheck::getGlobal(SMLoc ErrorLoc, const MCInst &Inst,
193 wasm::ValType &Type) {
194 const MCSymbolRefExpr *SymRef;
195 if (getSymRef(ErrorLoc, Inst, SymRef))
196 return true;
197 auto WasmSym = cast<MCSymbolWasm>(&SymRef->getSymbol());
198 switch (WasmSym->getType().value_or(wasm::WASM_SYMBOL_TYPE_DATA)) {
199 case wasm::WASM_SYMBOL_TYPE_GLOBAL:
200 Type = static_cast<wasm::ValType>(WasmSym->getGlobalType().Type);
201 break;
202 case wasm::WASM_SYMBOL_TYPE_FUNCTION:
203 case wasm::WASM_SYMBOL_TYPE_DATA:
204 switch (SymRef->getKind()) {
205 case MCSymbolRefExpr::VK_GOT:
206 case MCSymbolRefExpr::VK_WASM_GOT_TLS:
207 Type = is64 ? wasm::ValType::I64 : wasm::ValType::I32;
208 return false;
209 default:
210 break;
211 }
212 [[fallthrough]];
213 default:
214 return typeError(ErrorLoc, StringRef("symbol ") + WasmSym->getName() +
215 " missing .globaltype");
216 }
217 return false;
218 }
219
getTable(SMLoc ErrorLoc,const MCInst & Inst,wasm::ValType & Type)220 bool WebAssemblyAsmTypeCheck::getTable(SMLoc ErrorLoc, const MCInst &Inst,
221 wasm::ValType &Type) {
222 const MCSymbolRefExpr *SymRef;
223 if (getSymRef(ErrorLoc, Inst, SymRef))
224 return true;
225 auto WasmSym = cast<MCSymbolWasm>(&SymRef->getSymbol());
226 if (WasmSym->getType().value_or(wasm::WASM_SYMBOL_TYPE_DATA) !=
227 wasm::WASM_SYMBOL_TYPE_TABLE)
228 return typeError(ErrorLoc, StringRef("symbol ") + WasmSym->getName() +
229 " missing .tabletype");
230 Type = static_cast<wasm::ValType>(WasmSym->getTableType().ElemType);
231 return false;
232 }
233
endOfFunction(SMLoc ErrorLoc)234 bool WebAssemblyAsmTypeCheck::endOfFunction(SMLoc ErrorLoc) {
235 // Check the return types.
236 for (auto RVT : llvm::reverse(ReturnTypes)) {
237 if (popType(ErrorLoc, RVT))
238 return true;
239 }
240 if (!Stack.empty()) {
241 return typeError(ErrorLoc, std::to_string(Stack.size()) +
242 " superfluous return values");
243 }
244 Unreachable = true;
245 return false;
246 }
247
typeCheck(SMLoc ErrorLoc,const MCInst & Inst,OperandVector & Operands)248 bool WebAssemblyAsmTypeCheck::typeCheck(SMLoc ErrorLoc, const MCInst &Inst,
249 OperandVector &Operands) {
250 auto Opc = Inst.getOpcode();
251 auto Name = GetMnemonic(Opc);
252 dumpTypeStack("typechecking " + Name + ": ");
253 wasm::ValType Type;
254 if (Name == "local.get") {
255 if (getLocal(Operands[1]->getStartLoc(), Inst, Type))
256 return true;
257 Stack.push_back(Type);
258 } else if (Name == "local.set") {
259 if (getLocal(Operands[1]->getStartLoc(), Inst, Type))
260 return true;
261 if (popType(ErrorLoc, Type))
262 return true;
263 } else if (Name == "local.tee") {
264 if (getLocal(Operands[1]->getStartLoc(), Inst, Type))
265 return true;
266 if (popType(ErrorLoc, Type))
267 return true;
268 Stack.push_back(Type);
269 } else if (Name == "global.get") {
270 if (getGlobal(Operands[1]->getStartLoc(), Inst, Type))
271 return true;
272 Stack.push_back(Type);
273 } else if (Name == "global.set") {
274 if (getGlobal(Operands[1]->getStartLoc(), Inst, Type))
275 return true;
276 if (popType(ErrorLoc, Type))
277 return true;
278 } else if (Name == "table.get") {
279 if (getTable(Operands[1]->getStartLoc(), Inst, Type))
280 return true;
281 if (popType(ErrorLoc, wasm::ValType::I32))
282 return true;
283 Stack.push_back(Type);
284 } else if (Name == "table.set") {
285 if (getTable(Operands[1]->getStartLoc(), Inst, Type))
286 return true;
287 if (popType(ErrorLoc, Type))
288 return true;
289 if (popType(ErrorLoc, wasm::ValType::I32))
290 return true;
291 } else if (Name == "table.size") {
292 if (getTable(Operands[1]->getStartLoc(), Inst, Type))
293 return true;
294 Stack.push_back(wasm::ValType::I32);
295 } else if (Name == "table.grow") {
296 if (getTable(Operands[1]->getStartLoc(), Inst, Type))
297 return true;
298 if (popType(ErrorLoc, wasm::ValType::I32))
299 return true;
300 if (popType(ErrorLoc, Type))
301 return true;
302 Stack.push_back(wasm::ValType::I32);
303 } else if (Name == "table.fill") {
304 if (getTable(Operands[1]->getStartLoc(), Inst, Type))
305 return true;
306 if (popType(ErrorLoc, wasm::ValType::I32))
307 return true;
308 if (popType(ErrorLoc, Type))
309 return true;
310 if (popType(ErrorLoc, wasm::ValType::I32))
311 return true;
312 } else if (Name == "memory.fill") {
313 Type = is64 ? wasm::ValType::I64 : wasm::ValType::I32;
314 if (popType(ErrorLoc, Type))
315 return true;
316 if (popType(ErrorLoc, wasm::ValType::I32))
317 return true;
318 if (popType(ErrorLoc, Type))
319 return true;
320 } else if (Name == "memory.copy") {
321 Type = is64 ? wasm::ValType::I64 : wasm::ValType::I32;
322 if (popType(ErrorLoc, Type))
323 return true;
324 if (popType(ErrorLoc, Type))
325 return true;
326 if (popType(ErrorLoc, Type))
327 return true;
328 } else if (Name == "memory.init") {
329 Type = is64 ? wasm::ValType::I64 : wasm::ValType::I32;
330 if (popType(ErrorLoc, wasm::ValType::I32))
331 return true;
332 if (popType(ErrorLoc, wasm::ValType::I32))
333 return true;
334 if (popType(ErrorLoc, Type))
335 return true;
336 } else if (Name == "drop") {
337 if (popType(ErrorLoc, {}))
338 return true;
339 } else if (Name == "try" || Name == "block" || Name == "loop" ||
340 Name == "if") {
341 if (Name == "if" && popType(ErrorLoc, wasm::ValType::I32))
342 return true;
343 if (Name == "loop")
344 BrStack.emplace_back(LastSig.Params.begin(), LastSig.Params.end());
345 else
346 BrStack.emplace_back(LastSig.Returns.begin(), LastSig.Returns.end());
347 } else if (Name == "end_block" || Name == "end_loop" || Name == "end_if" ||
348 Name == "else" || Name == "end_try" || Name == "catch" ||
349 Name == "catch_all" || Name == "delegate") {
350 if (checkEnd(ErrorLoc,
351 Name == "else" || Name == "catch" || Name == "catch_all"))
352 return true;
353 Unreachable = false;
354 if (Name == "catch") {
355 const MCSymbolRefExpr *SymRef;
356 if (getSymRef(Operands[1]->getStartLoc(), Inst, SymRef))
357 return true;
358 const auto *WasmSym = cast<MCSymbolWasm>(&SymRef->getSymbol());
359 const auto *Sig = WasmSym->getSignature();
360 if (!Sig || WasmSym->getType() != wasm::WASM_SYMBOL_TYPE_TAG)
361 return typeError(Operands[1]->getStartLoc(), StringRef("symbol ") +
362 WasmSym->getName() +
363 " missing .tagtype");
364 // catch instruction pushes values whose types are specified in the tag's
365 // "params" part
366 Stack.insert(Stack.end(), Sig->Params.begin(), Sig->Params.end());
367 }
368 } else if (Name == "br") {
369 const MCOperand &Operand = Inst.getOperand(0);
370 if (!Operand.isImm())
371 return false;
372 if (checkBr(ErrorLoc, static_cast<size_t>(Operand.getImm())))
373 return true;
374 } else if (Name == "return") {
375 if (endOfFunction(ErrorLoc))
376 return true;
377 } else if (Name == "call_indirect" || Name == "return_call_indirect") {
378 // Function value.
379 if (popType(ErrorLoc, wasm::ValType::I32))
380 return true;
381 if (checkSig(ErrorLoc, LastSig))
382 return true;
383 if (Name == "return_call_indirect" && endOfFunction(ErrorLoc))
384 return true;
385 } else if (Name == "call" || Name == "return_call") {
386 const MCSymbolRefExpr *SymRef;
387 if (getSymRef(Operands[1]->getStartLoc(), Inst, SymRef))
388 return true;
389 auto WasmSym = cast<MCSymbolWasm>(&SymRef->getSymbol());
390 auto Sig = WasmSym->getSignature();
391 if (!Sig || WasmSym->getType() != wasm::WASM_SYMBOL_TYPE_FUNCTION)
392 return typeError(Operands[1]->getStartLoc(), StringRef("symbol ") +
393 WasmSym->getName() +
394 " missing .functype");
395 if (checkSig(ErrorLoc, *Sig))
396 return true;
397 if (Name == "return_call" && endOfFunction(ErrorLoc))
398 return true;
399 } else if (Name == "unreachable") {
400 Unreachable = true;
401 } else if (Name == "ref.is_null") {
402 if (popRefType(ErrorLoc))
403 return true;
404 Stack.push_back(wasm::ValType::I32);
405 } else {
406 // The current instruction is a stack instruction which doesn't have
407 // explicit operands that indicate push/pop types, so we get those from
408 // the register version of the same instruction.
409 auto RegOpc = WebAssembly::getRegisterOpcode(Opc);
410 assert(RegOpc != -1 && "Failed to get register version of MC instruction");
411 const auto &II = MII.get(RegOpc);
412 // First pop all the uses off the stack and check them.
413 for (unsigned I = II.getNumOperands(); I > II.getNumDefs(); I--) {
414 const auto &Op = II.operands()[I - 1];
415 if (Op.OperandType == MCOI::OPERAND_REGISTER) {
416 auto VT = WebAssembly::regClassToValType(Op.RegClass);
417 if (popType(ErrorLoc, VT))
418 return true;
419 }
420 }
421 // Now push all the defs onto the stack.
422 for (unsigned I = 0; I < II.getNumDefs(); I++) {
423 const auto &Op = II.operands()[I];
424 assert(Op.OperandType == MCOI::OPERAND_REGISTER && "Register expected");
425 auto VT = WebAssembly::regClassToValType(Op.RegClass);
426 Stack.push_back(VT);
427 }
428 }
429 return false;
430 }
431
432 } // end namespace llvm
433