xref: /freebsd/contrib/llvm-project/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp (revision 700637cbb5e582861067a11aaca4d053546871d2)
1 //===- ProtocolServerMCP.cpp ----------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "ProtocolServerMCP.h"
10 #include "MCPError.h"
11 #include "lldb/Core/PluginManager.h"
12 #include "lldb/Utility/LLDBLog.h"
13 #include "lldb/Utility/Log.h"
14 #include "llvm/ADT/StringExtras.h"
15 #include "llvm/Support/Threading.h"
16 #include <thread>
17 #include <variant>
18 
19 using namespace lldb_private;
20 using namespace lldb_private::mcp;
21 using namespace llvm;
22 
23 LLDB_PLUGIN_DEFINE(ProtocolServerMCP)
24 
/// Upper bound on the number of bytes pulled from a client socket per read
/// event in ReadCallback().
constexpr std::size_t kChunkSize{1024};
26 
ProtocolServerMCP()27 ProtocolServerMCP::ProtocolServerMCP() : ProtocolServer() {
28   AddRequestHandler("initialize",
29                     std::bind(&ProtocolServerMCP::InitializeHandler, this,
30                               std::placeholders::_1));
31 
32   AddRequestHandler("tools/list",
33                     std::bind(&ProtocolServerMCP::ToolsListHandler, this,
34                               std::placeholders::_1));
35   AddRequestHandler("tools/call",
36                     std::bind(&ProtocolServerMCP::ToolsCallHandler, this,
37                               std::placeholders::_1));
38 
39   AddRequestHandler("resources/list",
40                     std::bind(&ProtocolServerMCP::ResourcesListHandler, this,
41                               std::placeholders::_1));
42   AddRequestHandler("resources/read",
43                     std::bind(&ProtocolServerMCP::ResourcesReadHandler, this,
44                               std::placeholders::_1));
45   AddNotificationHandler(
46       "notifications/initialized", [](const protocol::Notification &) {
47         LLDB_LOG(GetLog(LLDBLog::Host), "MCP initialization complete");
48       });
49 
50   AddTool(
51       std::make_unique<CommandTool>("lldb_command", "Run an lldb command."));
52 
53   AddResourceProvider(std::make_unique<DebuggerResourceProvider>());
54 }
55 
~ProtocolServerMCP()56 ProtocolServerMCP::~ProtocolServerMCP() { llvm::consumeError(Stop()); }
57 
Initialize()58 void ProtocolServerMCP::Initialize() {
59   PluginManager::RegisterPlugin(GetPluginNameStatic(),
60                                 GetPluginDescriptionStatic(), CreateInstance);
61 }
62 
Terminate()63 void ProtocolServerMCP::Terminate() {
64   PluginManager::UnregisterPlugin(CreateInstance);
65 }
66 
CreateInstance()67 lldb::ProtocolServerUP ProtocolServerMCP::CreateInstance() {
68   return std::make_unique<ProtocolServerMCP>();
69 }
70 
GetPluginDescriptionStatic()71 llvm::StringRef ProtocolServerMCP::GetPluginDescriptionStatic() {
72   return "MCP Server.";
73 }
74 
75 llvm::Expected<protocol::Response>
Handle(protocol::Request request)76 ProtocolServerMCP::Handle(protocol::Request request) {
77   auto it = m_request_handlers.find(request.method);
78   if (it != m_request_handlers.end()) {
79     llvm::Expected<protocol::Response> response = it->second(request);
80     if (!response)
81       return response;
82     response->id = request.id;
83     return *response;
84   }
85 
86   return make_error<MCPError>(
87       llvm::formatv("no handler for request: {0}", request.method).str());
88 }
89 
Handle(protocol::Notification notification)90 void ProtocolServerMCP::Handle(protocol::Notification notification) {
91   auto it = m_notification_handlers.find(notification.method);
92   if (it != m_notification_handlers.end()) {
93     it->second(notification);
94     return;
95   }
96 
97   LLDB_LOG(GetLog(LLDBLog::Host), "MPC notification: {0} ({1})",
98            notification.method, notification.params);
99 }
100 
AcceptCallback(std::unique_ptr<Socket> socket)101 void ProtocolServerMCP::AcceptCallback(std::unique_ptr<Socket> socket) {
102   LLDB_LOG(GetLog(LLDBLog::Host), "New MCP client ({0}) connected",
103            m_clients.size() + 1);
104 
105   lldb::IOObjectSP io_sp = std::move(socket);
106   auto client_up = std::make_unique<Client>();
107   client_up->io_sp = io_sp;
108   Client *client = client_up.get();
109 
110   Status status;
111   auto read_handle_up = m_loop.RegisterReadObject(
112       io_sp,
113       [this, client](MainLoopBase &loop) {
114         if (Error error = ReadCallback(*client)) {
115           LLDB_LOG_ERROR(GetLog(LLDBLog::Host), std::move(error), "{0}");
116           client->read_handle_up.reset();
117         }
118       },
119       status);
120   if (status.Fail())
121     return;
122 
123   client_up->read_handle_up = std::move(read_handle_up);
124   m_clients.emplace_back(std::move(client_up));
125 }
126 
ReadCallback(Client & client)127 llvm::Error ProtocolServerMCP::ReadCallback(Client &client) {
128   char chunk[kChunkSize];
129   size_t bytes_read = sizeof(chunk);
130   if (Status status = client.io_sp->Read(chunk, bytes_read); status.Fail())
131     return status.takeError();
132   client.buffer.append(chunk, bytes_read);
133 
134   for (std::string::size_type pos;
135        (pos = client.buffer.find('\n')) != std::string::npos;) {
136     llvm::Expected<std::optional<protocol::Message>> message =
137         HandleData(StringRef(client.buffer.data(), pos));
138     client.buffer = client.buffer.erase(0, pos + 1);
139     if (!message)
140       return message.takeError();
141 
142     if (*message) {
143       std::string Output;
144       llvm::raw_string_ostream OS(Output);
145       OS << llvm::formatv("{0}", toJSON(**message)) << '\n';
146       size_t num_bytes = Output.size();
147       return client.io_sp->Write(Output.data(), num_bytes).takeError();
148     }
149   }
150 
151   return llvm::Error::success();
152 }
153 
Start(ProtocolServer::Connection connection)154 llvm::Error ProtocolServerMCP::Start(ProtocolServer::Connection connection) {
155   std::lock_guard<std::mutex> guard(m_server_mutex);
156 
157   if (m_running)
158     return llvm::createStringError("the MCP server is already running");
159 
160   Status status;
161   m_listener = Socket::Create(connection.protocol, status);
162   if (status.Fail())
163     return status.takeError();
164 
165   status = m_listener->Listen(connection.name, /*backlog=*/5);
166   if (status.Fail())
167     return status.takeError();
168 
169   auto handles =
170       m_listener->Accept(m_loop, std::bind(&ProtocolServerMCP::AcceptCallback,
171                                            this, std::placeholders::_1));
172   if (llvm::Error error = handles.takeError())
173     return error;
174 
175   m_running = true;
176   m_listen_handlers = std::move(*handles);
177   m_loop_thread = std::thread([=] {
178     llvm::set_thread_name("protocol-server.mcp");
179     m_loop.Run();
180   });
181 
182   return llvm::Error::success();
183 }
184 
Stop()185 llvm::Error ProtocolServerMCP::Stop() {
186   {
187     std::lock_guard<std::mutex> guard(m_server_mutex);
188     if (!m_running)
189       return createStringError("the MCP sever is not running");
190     m_running = false;
191   }
192 
193   // Stop the main loop.
194   m_loop.AddPendingCallback(
195       [](MainLoopBase &loop) { loop.RequestTermination(); });
196 
197   // Wait for the main loop to exit.
198   if (m_loop_thread.joinable())
199     m_loop_thread.join();
200 
201   {
202     std::lock_guard<std::mutex> guard(m_server_mutex);
203     m_listener.reset();
204     m_listen_handlers.clear();
205     m_clients.clear();
206   }
207 
208   return llvm::Error::success();
209 }
210 
211 llvm::Expected<std::optional<protocol::Message>>
HandleData(llvm::StringRef data)212 ProtocolServerMCP::HandleData(llvm::StringRef data) {
213   auto message = llvm::json::parse<protocol::Message>(/*JSON=*/data);
214   if (!message)
215     return message.takeError();
216 
217   if (const protocol::Request *request =
218           std::get_if<protocol::Request>(&(*message))) {
219     llvm::Expected<protocol::Response> response = Handle(*request);
220 
221     // Handle failures by converting them into an Error message.
222     if (!response) {
223       protocol::Error protocol_error;
224       llvm::handleAllErrors(
225           response.takeError(),
226           [&](const MCPError &err) { protocol_error = err.toProtcolError(); },
227           [&](const llvm::ErrorInfoBase &err) {
228             protocol_error.error.code = MCPError::kInternalError;
229             protocol_error.error.message = err.message();
230           });
231       protocol_error.id = request->id;
232       return protocol_error;
233     }
234 
235     return *response;
236   }
237 
238   if (const protocol::Notification *notification =
239           std::get_if<protocol::Notification>(&(*message))) {
240     Handle(*notification);
241     return std::nullopt;
242   }
243 
244   if (std::get_if<protocol::Error>(&(*message)))
245     return llvm::createStringError("unexpected MCP message: error");
246 
247   if (std::get_if<protocol::Response>(&(*message)))
248     return llvm::createStringError("unexpected MCP message: response");
249 
250   llvm_unreachable("all message types handled");
251 }
252 
GetCapabilities()253 protocol::Capabilities ProtocolServerMCP::GetCapabilities() {
254   protocol::Capabilities capabilities;
255   capabilities.tools.listChanged = true;
256   // FIXME: Support sending notifications when a debugger/target are
257   // added/removed.
258   capabilities.resources.listChanged = false;
259   return capabilities;
260 }
261 
AddTool(std::unique_ptr<Tool> tool)262 void ProtocolServerMCP::AddTool(std::unique_ptr<Tool> tool) {
263   std::lock_guard<std::mutex> guard(m_server_mutex);
264 
265   if (!tool)
266     return;
267   m_tools[tool->GetName()] = std::move(tool);
268 }
269 
AddResourceProvider(std::unique_ptr<ResourceProvider> resource_provider)270 void ProtocolServerMCP::AddResourceProvider(
271     std::unique_ptr<ResourceProvider> resource_provider) {
272   std::lock_guard<std::mutex> guard(m_server_mutex);
273 
274   if (!resource_provider)
275     return;
276   m_resource_providers.push_back(std::move(resource_provider));
277 }
278 
AddRequestHandler(llvm::StringRef method,RequestHandler handler)279 void ProtocolServerMCP::AddRequestHandler(llvm::StringRef method,
280                                           RequestHandler handler) {
281   std::lock_guard<std::mutex> guard(m_server_mutex);
282   m_request_handlers[method] = std::move(handler);
283 }
284 
AddNotificationHandler(llvm::StringRef method,NotificationHandler handler)285 void ProtocolServerMCP::AddNotificationHandler(llvm::StringRef method,
286                                                NotificationHandler handler) {
287   std::lock_guard<std::mutex> guard(m_server_mutex);
288   m_notification_handlers[method] = std::move(handler);
289 }
290 
291 llvm::Expected<protocol::Response>
InitializeHandler(const protocol::Request & request)292 ProtocolServerMCP::InitializeHandler(const protocol::Request &request) {
293   protocol::Response response;
294   response.result.emplace(llvm::json::Object{
295       {"protocolVersion", protocol::kVersion},
296       {"capabilities", GetCapabilities()},
297       {"serverInfo",
298        llvm::json::Object{{"name", kName}, {"version", kVersion}}}});
299   return response;
300 }
301 
302 llvm::Expected<protocol::Response>
ToolsListHandler(const protocol::Request & request)303 ProtocolServerMCP::ToolsListHandler(const protocol::Request &request) {
304   protocol::Response response;
305 
306   llvm::json::Array tools;
307   for (const auto &tool : m_tools)
308     tools.emplace_back(toJSON(tool.second->GetDefinition()));
309 
310   response.result.emplace(llvm::json::Object{{"tools", std::move(tools)}});
311 
312   return response;
313 }
314 
315 llvm::Expected<protocol::Response>
ToolsCallHandler(const protocol::Request & request)316 ProtocolServerMCP::ToolsCallHandler(const protocol::Request &request) {
317   protocol::Response response;
318 
319   if (!request.params)
320     return llvm::createStringError("no tool parameters");
321 
322   const json::Object *param_obj = request.params->getAsObject();
323   if (!param_obj)
324     return llvm::createStringError("no tool parameters");
325 
326   const json::Value *name = param_obj->get("name");
327   if (!name)
328     return llvm::createStringError("no tool name");
329 
330   llvm::StringRef tool_name = name->getAsString().value_or("");
331   if (tool_name.empty())
332     return llvm::createStringError("no tool name");
333 
334   auto it = m_tools.find(tool_name);
335   if (it == m_tools.end())
336     return llvm::createStringError(llvm::formatv("no tool \"{0}\"", tool_name));
337 
338   protocol::ToolArguments tool_args;
339   if (const json::Value *args = param_obj->get("arguments"))
340     tool_args = *args;
341 
342   llvm::Expected<protocol::TextResult> text_result =
343       it->second->Call(tool_args);
344   if (!text_result)
345     return text_result.takeError();
346 
347   response.result.emplace(toJSON(*text_result));
348 
349   return response;
350 }
351 
352 llvm::Expected<protocol::Response>
ResourcesListHandler(const protocol::Request & request)353 ProtocolServerMCP::ResourcesListHandler(const protocol::Request &request) {
354   protocol::Response response;
355 
356   llvm::json::Array resources;
357 
358   std::lock_guard<std::mutex> guard(m_server_mutex);
359   for (std::unique_ptr<ResourceProvider> &resource_provider_up :
360        m_resource_providers) {
361     for (const protocol::Resource &resource :
362          resource_provider_up->GetResources())
363       resources.push_back(resource);
364   }
365   response.result.emplace(
366       llvm::json::Object{{"resources", std::move(resources)}});
367 
368   return response;
369 }
370 
371 llvm::Expected<protocol::Response>
ResourcesReadHandler(const protocol::Request & request)372 ProtocolServerMCP::ResourcesReadHandler(const protocol::Request &request) {
373   protocol::Response response;
374 
375   if (!request.params)
376     return llvm::createStringError("no resource parameters");
377 
378   const json::Object *param_obj = request.params->getAsObject();
379   if (!param_obj)
380     return llvm::createStringError("no resource parameters");
381 
382   const json::Value *uri = param_obj->get("uri");
383   if (!uri)
384     return llvm::createStringError("no resource uri");
385 
386   llvm::StringRef uri_str = uri->getAsString().value_or("");
387   if (uri_str.empty())
388     return llvm::createStringError("no resource uri");
389 
390   std::lock_guard<std::mutex> guard(m_server_mutex);
391   for (std::unique_ptr<ResourceProvider> &resource_provider_up :
392        m_resource_providers) {
393     llvm::Expected<protocol::ResourceResult> result =
394         resource_provider_up->ReadResource(uri_str);
395     if (result.errorIsA<UnsupportedURI>()) {
396       llvm::consumeError(result.takeError());
397       continue;
398     }
399     if (!result)
400       return result.takeError();
401 
402     protocol::Response response;
403     response.result.emplace(std::move(*result));
404     return response;
405   }
406 
407   return make_error<MCPError>(
408       llvm::formatv("no resource handler for uri: {0}", uri_str).str(),
409       MCPError::kResourceNotFound);
410 }
411