xref: /freebsd/contrib/llvm-project/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.cpp (revision 0fca6ea1d4eea4c934cfff25ac9ee8ad6fe95583)
1 //==- WebAssemblyAsmTypeCheck.cpp - Assembler for WebAssembly -*- C++ -*-==//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
9 /// \file
10 /// This file is part of the WebAssembly Assembler.
11 ///
/// It contains the type checker that validates the MCInsts built from a parsed .s file against a simulated WebAssembly value stack.
13 ///
14 //===----------------------------------------------------------------------===//
15 
16 #include "AsmParser/WebAssemblyAsmTypeCheck.h"
17 #include "MCTargetDesc/WebAssemblyMCTargetDesc.h"
18 #include "MCTargetDesc/WebAssemblyMCTypeUtilities.h"
19 #include "MCTargetDesc/WebAssemblyTargetStreamer.h"
20 #include "TargetInfo/WebAssemblyTargetInfo.h"
21 #include "WebAssembly.h"
22 #include "llvm/MC/MCContext.h"
23 #include "llvm/MC/MCExpr.h"
24 #include "llvm/MC/MCInst.h"
25 #include "llvm/MC/MCInstrInfo.h"
26 #include "llvm/MC/MCParser/MCParsedAsmOperand.h"
27 #include "llvm/MC/MCParser/MCTargetAsmParser.h"
28 #include "llvm/MC/MCSectionWasm.h"
29 #include "llvm/MC/MCStreamer.h"
30 #include "llvm/MC/MCSubtargetInfo.h"
31 #include "llvm/MC/MCSymbol.h"
32 #include "llvm/MC/MCSymbolWasm.h"
33 #include "llvm/MC/TargetRegistry.h"
34 #include "llvm/Support/Compiler.h"
35 #include "llvm/Support/SourceMgr.h"
36 
37 using namespace llvm;
38 
39 #define DEBUG_TYPE "wasm-asm-parser"
40 
41 extern StringRef GetMnemonic(unsigned Opc);
42 
43 namespace llvm {
44 
WebAssemblyAsmTypeCheck(MCAsmParser & Parser,const MCInstrInfo & MII,bool is64)45 WebAssemblyAsmTypeCheck::WebAssemblyAsmTypeCheck(MCAsmParser &Parser,
46                                                  const MCInstrInfo &MII,
47                                                  bool is64)
48     : Parser(Parser), MII(MII), is64(is64) {}
49 
funcDecl(const wasm::WasmSignature & Sig)50 void WebAssemblyAsmTypeCheck::funcDecl(const wasm::WasmSignature &Sig) {
51   LocalTypes.assign(Sig.Params.begin(), Sig.Params.end());
52   ReturnTypes.assign(Sig.Returns.begin(), Sig.Returns.end());
53   BrStack.emplace_back(Sig.Returns.begin(), Sig.Returns.end());
54 }
55 
localDecl(const SmallVectorImpl<wasm::ValType> & Locals)56 void WebAssemblyAsmTypeCheck::localDecl(
57     const SmallVectorImpl<wasm::ValType> &Locals) {
58   LocalTypes.insert(LocalTypes.end(), Locals.begin(), Locals.end());
59 }
60 
dumpTypeStack(Twine Msg)61 void WebAssemblyAsmTypeCheck::dumpTypeStack(Twine Msg) {
62   LLVM_DEBUG({
63     std::string s;
64     for (auto VT : Stack) {
65       s += WebAssembly::typeToString(VT);
66       s += " ";
67     }
68     dbgs() << Msg << s << '\n';
69   });
70 }
71 
typeError(SMLoc ErrorLoc,const Twine & Msg)72 bool WebAssemblyAsmTypeCheck::typeError(SMLoc ErrorLoc, const Twine &Msg) {
73   // Once you get one type error in a function, it will likely trigger more
74   // which are mostly not helpful.
75   if (TypeErrorThisFunction)
76     return true;
77   // If we're currently in unreachable code, we suppress errors completely.
78   if (Unreachable)
79     return false;
80   TypeErrorThisFunction = true;
81   dumpTypeStack("current stack: ");
82   return Parser.Error(ErrorLoc, Msg);
83 }
84 
popType(SMLoc ErrorLoc,std::optional<wasm::ValType> EVT)85 bool WebAssemblyAsmTypeCheck::popType(SMLoc ErrorLoc,
86                                       std::optional<wasm::ValType> EVT) {
87   if (Stack.empty()) {
88     return typeError(ErrorLoc,
89                      EVT ? StringRef("empty stack while popping ") +
90                                WebAssembly::typeToString(*EVT)
91                          : StringRef("empty stack while popping value"));
92   }
93   auto PVT = Stack.pop_back_val();
94   if (EVT && *EVT != PVT) {
95     return typeError(ErrorLoc,
96                      StringRef("popped ") + WebAssembly::typeToString(PVT) +
97                          ", expected " + WebAssembly::typeToString(*EVT));
98   }
99   return false;
100 }
101 
popRefType(SMLoc ErrorLoc)102 bool WebAssemblyAsmTypeCheck::popRefType(SMLoc ErrorLoc) {
103   if (Stack.empty()) {
104     return typeError(ErrorLoc, StringRef("empty stack while popping reftype"));
105   }
106   auto PVT = Stack.pop_back_val();
107   if (!WebAssembly::isRefType(PVT)) {
108     return typeError(ErrorLoc, StringRef("popped ") +
109                                    WebAssembly::typeToString(PVT) +
110                                    ", expected reftype");
111   }
112   return false;
113 }
114 
getLocal(SMLoc ErrorLoc,const MCInst & Inst,wasm::ValType & Type)115 bool WebAssemblyAsmTypeCheck::getLocal(SMLoc ErrorLoc, const MCInst &Inst,
116                                        wasm::ValType &Type) {
117   auto Local = static_cast<size_t>(Inst.getOperand(0).getImm());
118   if (Local >= LocalTypes.size())
119     return typeError(ErrorLoc, StringRef("no local type specified for index ") +
120                                    std::to_string(Local));
121   Type = LocalTypes[Local];
122   return false;
123 }
124 
125 static std::optional<std::string>
checkStackTop(const SmallVectorImpl<wasm::ValType> & ExpectedStackTop,const SmallVectorImpl<wasm::ValType> & Got)126 checkStackTop(const SmallVectorImpl<wasm::ValType> &ExpectedStackTop,
127               const SmallVectorImpl<wasm::ValType> &Got) {
128   for (size_t I = 0; I < ExpectedStackTop.size(); I++) {
129     auto EVT = ExpectedStackTop[I];
130     auto PVT = Got[Got.size() - ExpectedStackTop.size() + I];
131     if (PVT != EVT)
132       return std::string{"got "} + WebAssembly::typeToString(PVT) +
133              ", expected " + WebAssembly::typeToString(EVT);
134   }
135   return std::nullopt;
136 }
137 
checkBr(SMLoc ErrorLoc,size_t Level)138 bool WebAssemblyAsmTypeCheck::checkBr(SMLoc ErrorLoc, size_t Level) {
139   if (Level >= BrStack.size())
140     return typeError(ErrorLoc,
141                      StringRef("br: invalid depth ") + std::to_string(Level));
142   const SmallVector<wasm::ValType, 4> &Expected =
143       BrStack[BrStack.size() - Level - 1];
144   if (Expected.size() > Stack.size())
145     return typeError(ErrorLoc, "br: insufficient values on the type stack");
146   auto IsStackTopInvalid = checkStackTop(Expected, Stack);
147   if (IsStackTopInvalid)
148     return typeError(ErrorLoc, "br " + IsStackTopInvalid.value());
149   return false;
150 }
151 
checkEnd(SMLoc ErrorLoc,bool PopVals)152 bool WebAssemblyAsmTypeCheck::checkEnd(SMLoc ErrorLoc, bool PopVals) {
153   if (!PopVals)
154     BrStack.pop_back();
155   if (LastSig.Returns.size() > Stack.size())
156     return typeError(ErrorLoc, "end: insufficient values on the type stack");
157 
158   if (PopVals) {
159     for (auto VT : llvm::reverse(LastSig.Returns)) {
160       if (popType(ErrorLoc, VT))
161         return true;
162     }
163     return false;
164   }
165 
166   auto IsStackTopInvalid = checkStackTop(LastSig.Returns, Stack);
167   if (IsStackTopInvalid)
168     return typeError(ErrorLoc, "end " + IsStackTopInvalid.value());
169   return false;
170 }
171 
checkSig(SMLoc ErrorLoc,const wasm::WasmSignature & Sig)172 bool WebAssemblyAsmTypeCheck::checkSig(SMLoc ErrorLoc,
173                                        const wasm::WasmSignature &Sig) {
174   for (auto VT : llvm::reverse(Sig.Params))
175     if (popType(ErrorLoc, VT))
176       return true;
177   Stack.insert(Stack.end(), Sig.Returns.begin(), Sig.Returns.end());
178   return false;
179 }
180 
getSymRef(SMLoc ErrorLoc,const MCInst & Inst,const MCSymbolRefExpr * & SymRef)181 bool WebAssemblyAsmTypeCheck::getSymRef(SMLoc ErrorLoc, const MCInst &Inst,
182                                         const MCSymbolRefExpr *&SymRef) {
183   auto Op = Inst.getOperand(0);
184   if (!Op.isExpr())
185     return typeError(ErrorLoc, StringRef("expected expression operand"));
186   SymRef = dyn_cast<MCSymbolRefExpr>(Op.getExpr());
187   if (!SymRef)
188     return typeError(ErrorLoc, StringRef("expected symbol operand"));
189   return false;
190 }
191 
getGlobal(SMLoc ErrorLoc,const MCInst & Inst,wasm::ValType & Type)192 bool WebAssemblyAsmTypeCheck::getGlobal(SMLoc ErrorLoc, const MCInst &Inst,
193                                         wasm::ValType &Type) {
194   const MCSymbolRefExpr *SymRef;
195   if (getSymRef(ErrorLoc, Inst, SymRef))
196     return true;
197   auto WasmSym = cast<MCSymbolWasm>(&SymRef->getSymbol());
198   switch (WasmSym->getType().value_or(wasm::WASM_SYMBOL_TYPE_DATA)) {
199   case wasm::WASM_SYMBOL_TYPE_GLOBAL:
200     Type = static_cast<wasm::ValType>(WasmSym->getGlobalType().Type);
201     break;
202   case wasm::WASM_SYMBOL_TYPE_FUNCTION:
203   case wasm::WASM_SYMBOL_TYPE_DATA:
204     switch (SymRef->getKind()) {
205     case MCSymbolRefExpr::VK_GOT:
206     case MCSymbolRefExpr::VK_WASM_GOT_TLS:
207       Type = is64 ? wasm::ValType::I64 : wasm::ValType::I32;
208       return false;
209     default:
210       break;
211     }
212     [[fallthrough]];
213   default:
214     return typeError(ErrorLoc, StringRef("symbol ") + WasmSym->getName() +
215                                    " missing .globaltype");
216   }
217   return false;
218 }
219 
getTable(SMLoc ErrorLoc,const MCInst & Inst,wasm::ValType & Type)220 bool WebAssemblyAsmTypeCheck::getTable(SMLoc ErrorLoc, const MCInst &Inst,
221                                        wasm::ValType &Type) {
222   const MCSymbolRefExpr *SymRef;
223   if (getSymRef(ErrorLoc, Inst, SymRef))
224     return true;
225   auto WasmSym = cast<MCSymbolWasm>(&SymRef->getSymbol());
226   if (WasmSym->getType().value_or(wasm::WASM_SYMBOL_TYPE_DATA) !=
227       wasm::WASM_SYMBOL_TYPE_TABLE)
228     return typeError(ErrorLoc, StringRef("symbol ") + WasmSym->getName() +
229                                    " missing .tabletype");
230   Type = static_cast<wasm::ValType>(WasmSym->getTableType().ElemType);
231   return false;
232 }
233 
endOfFunction(SMLoc ErrorLoc)234 bool WebAssemblyAsmTypeCheck::endOfFunction(SMLoc ErrorLoc) {
235   // Check the return types.
236   for (auto RVT : llvm::reverse(ReturnTypes)) {
237     if (popType(ErrorLoc, RVT))
238       return true;
239   }
240   if (!Stack.empty()) {
241     return typeError(ErrorLoc, std::to_string(Stack.size()) +
242                                    " superfluous return values");
243   }
244   Unreachable = true;
245   return false;
246 }
247 
typeCheck(SMLoc ErrorLoc,const MCInst & Inst,OperandVector & Operands)248 bool WebAssemblyAsmTypeCheck::typeCheck(SMLoc ErrorLoc, const MCInst &Inst,
249                                         OperandVector &Operands) {
250   auto Opc = Inst.getOpcode();
251   auto Name = GetMnemonic(Opc);
252   dumpTypeStack("typechecking " + Name + ": ");
253   wasm::ValType Type;
254   if (Name == "local.get") {
255     if (getLocal(Operands[1]->getStartLoc(), Inst, Type))
256       return true;
257     Stack.push_back(Type);
258   } else if (Name == "local.set") {
259     if (getLocal(Operands[1]->getStartLoc(), Inst, Type))
260       return true;
261     if (popType(ErrorLoc, Type))
262       return true;
263   } else if (Name == "local.tee") {
264     if (getLocal(Operands[1]->getStartLoc(), Inst, Type))
265       return true;
266     if (popType(ErrorLoc, Type))
267       return true;
268     Stack.push_back(Type);
269   } else if (Name == "global.get") {
270     if (getGlobal(Operands[1]->getStartLoc(), Inst, Type))
271       return true;
272     Stack.push_back(Type);
273   } else if (Name == "global.set") {
274     if (getGlobal(Operands[1]->getStartLoc(), Inst, Type))
275       return true;
276     if (popType(ErrorLoc, Type))
277       return true;
278   } else if (Name == "table.get") {
279     if (getTable(Operands[1]->getStartLoc(), Inst, Type))
280       return true;
281     if (popType(ErrorLoc, wasm::ValType::I32))
282       return true;
283     Stack.push_back(Type);
284   } else if (Name == "table.set") {
285     if (getTable(Operands[1]->getStartLoc(), Inst, Type))
286       return true;
287     if (popType(ErrorLoc, Type))
288       return true;
289     if (popType(ErrorLoc, wasm::ValType::I32))
290       return true;
291   } else if (Name == "table.size") {
292     if (getTable(Operands[1]->getStartLoc(), Inst, Type))
293       return true;
294     Stack.push_back(wasm::ValType::I32);
295   } else if (Name == "table.grow") {
296     if (getTable(Operands[1]->getStartLoc(), Inst, Type))
297       return true;
298     if (popType(ErrorLoc, wasm::ValType::I32))
299       return true;
300     if (popType(ErrorLoc, Type))
301       return true;
302     Stack.push_back(wasm::ValType::I32);
303   } else if (Name == "table.fill") {
304     if (getTable(Operands[1]->getStartLoc(), Inst, Type))
305       return true;
306     if (popType(ErrorLoc, wasm::ValType::I32))
307       return true;
308     if (popType(ErrorLoc, Type))
309       return true;
310     if (popType(ErrorLoc, wasm::ValType::I32))
311       return true;
312   } else if (Name == "memory.fill") {
313     Type = is64 ? wasm::ValType::I64 : wasm::ValType::I32;
314     if (popType(ErrorLoc, Type))
315       return true;
316     if (popType(ErrorLoc, wasm::ValType::I32))
317       return true;
318     if (popType(ErrorLoc, Type))
319       return true;
320   } else if (Name == "memory.copy") {
321     Type = is64 ? wasm::ValType::I64 : wasm::ValType::I32;
322     if (popType(ErrorLoc, Type))
323       return true;
324     if (popType(ErrorLoc, Type))
325       return true;
326     if (popType(ErrorLoc, Type))
327       return true;
328   } else if (Name == "memory.init") {
329     Type = is64 ? wasm::ValType::I64 : wasm::ValType::I32;
330     if (popType(ErrorLoc, wasm::ValType::I32))
331       return true;
332     if (popType(ErrorLoc, wasm::ValType::I32))
333       return true;
334     if (popType(ErrorLoc, Type))
335       return true;
336   } else if (Name == "drop") {
337     if (popType(ErrorLoc, {}))
338       return true;
339   } else if (Name == "try" || Name == "block" || Name == "loop" ||
340              Name == "if") {
341     if (Name == "if" && popType(ErrorLoc, wasm::ValType::I32))
342       return true;
343     if (Name == "loop")
344       BrStack.emplace_back(LastSig.Params.begin(), LastSig.Params.end());
345     else
346       BrStack.emplace_back(LastSig.Returns.begin(), LastSig.Returns.end());
347   } else if (Name == "end_block" || Name == "end_loop" || Name == "end_if" ||
348              Name == "else" || Name == "end_try" || Name == "catch" ||
349              Name == "catch_all" || Name == "delegate") {
350     if (checkEnd(ErrorLoc,
351                  Name == "else" || Name == "catch" || Name == "catch_all"))
352       return true;
353     Unreachable = false;
354     if (Name == "catch") {
355       const MCSymbolRefExpr *SymRef;
356       if (getSymRef(Operands[1]->getStartLoc(), Inst, SymRef))
357         return true;
358       const auto *WasmSym = cast<MCSymbolWasm>(&SymRef->getSymbol());
359       const auto *Sig = WasmSym->getSignature();
360       if (!Sig || WasmSym->getType() != wasm::WASM_SYMBOL_TYPE_TAG)
361         return typeError(Operands[1]->getStartLoc(), StringRef("symbol ") +
362                                                          WasmSym->getName() +
363                                                          " missing .tagtype");
364       // catch instruction pushes values whose types are specified in the tag's
365       // "params" part
366       Stack.insert(Stack.end(), Sig->Params.begin(), Sig->Params.end());
367     }
368   } else if (Name == "br") {
369     const MCOperand &Operand = Inst.getOperand(0);
370     if (!Operand.isImm())
371       return false;
372     if (checkBr(ErrorLoc, static_cast<size_t>(Operand.getImm())))
373       return true;
374   } else if (Name == "return") {
375     if (endOfFunction(ErrorLoc))
376       return true;
377   } else if (Name == "call_indirect" || Name == "return_call_indirect") {
378     // Function value.
379     if (popType(ErrorLoc, wasm::ValType::I32))
380       return true;
381     if (checkSig(ErrorLoc, LastSig))
382       return true;
383     if (Name == "return_call_indirect" && endOfFunction(ErrorLoc))
384       return true;
385   } else if (Name == "call" || Name == "return_call") {
386     const MCSymbolRefExpr *SymRef;
387     if (getSymRef(Operands[1]->getStartLoc(), Inst, SymRef))
388       return true;
389     auto WasmSym = cast<MCSymbolWasm>(&SymRef->getSymbol());
390     auto Sig = WasmSym->getSignature();
391     if (!Sig || WasmSym->getType() != wasm::WASM_SYMBOL_TYPE_FUNCTION)
392       return typeError(Operands[1]->getStartLoc(), StringRef("symbol ") +
393                                                        WasmSym->getName() +
394                                                        " missing .functype");
395     if (checkSig(ErrorLoc, *Sig))
396       return true;
397     if (Name == "return_call" && endOfFunction(ErrorLoc))
398       return true;
399   } else if (Name == "unreachable") {
400     Unreachable = true;
401   } else if (Name == "ref.is_null") {
402     if (popRefType(ErrorLoc))
403       return true;
404     Stack.push_back(wasm::ValType::I32);
405   } else {
406     // The current instruction is a stack instruction which doesn't have
407     // explicit operands that indicate push/pop types, so we get those from
408     // the register version of the same instruction.
409     auto RegOpc = WebAssembly::getRegisterOpcode(Opc);
410     assert(RegOpc != -1 && "Failed to get register version of MC instruction");
411     const auto &II = MII.get(RegOpc);
412     // First pop all the uses off the stack and check them.
413     for (unsigned I = II.getNumOperands(); I > II.getNumDefs(); I--) {
414       const auto &Op = II.operands()[I - 1];
415       if (Op.OperandType == MCOI::OPERAND_REGISTER) {
416         auto VT = WebAssembly::regClassToValType(Op.RegClass);
417         if (popType(ErrorLoc, VT))
418           return true;
419       }
420     }
421     // Now push all the defs onto the stack.
422     for (unsigned I = 0; I < II.getNumDefs(); I++) {
423       const auto &Op = II.operands()[I];
424       assert(Op.OperandType == MCOI::OPERAND_REGISTER && "Register expected");
425       auto VT = WebAssembly::regClassToValType(Op.RegClass);
426       Stack.push_back(VT);
427     }
428   }
429   return false;
430 }
431 
432 } // end namespace llvm
433