1 //===- ProtocolServerMCP.cpp ----------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "ProtocolServerMCP.h" 10 #include "MCPError.h" 11 #include "lldb/Core/PluginManager.h" 12 #include "lldb/Utility/LLDBLog.h" 13 #include "lldb/Utility/Log.h" 14 #include "llvm/ADT/StringExtras.h" 15 #include "llvm/Support/Threading.h" 16 #include <thread> 17 #include <variant> 18 19 using namespace lldb_private; 20 using namespace lldb_private::mcp; 21 using namespace llvm; 22 23 LLDB_PLUGIN_DEFINE(ProtocolServerMCP) 24 25 static constexpr size_t kChunkSize = 1024; 26 27 ProtocolServerMCP::ProtocolServerMCP() : ProtocolServer() { 28 AddRequestHandler("initialize", 29 std::bind(&ProtocolServerMCP::InitializeHandler, this, 30 std::placeholders::_1)); 31 32 AddRequestHandler("tools/list", 33 std::bind(&ProtocolServerMCP::ToolsListHandler, this, 34 std::placeholders::_1)); 35 AddRequestHandler("tools/call", 36 std::bind(&ProtocolServerMCP::ToolsCallHandler, this, 37 std::placeholders::_1)); 38 39 AddRequestHandler("resources/list", 40 std::bind(&ProtocolServerMCP::ResourcesListHandler, this, 41 std::placeholders::_1)); 42 AddRequestHandler("resources/read", 43 std::bind(&ProtocolServerMCP::ResourcesReadHandler, this, 44 std::placeholders::_1)); 45 AddNotificationHandler( 46 "notifications/initialized", [](const protocol::Notification &) { 47 LLDB_LOG(GetLog(LLDBLog::Host), "MCP initialization complete"); 48 }); 49 50 AddTool( 51 std::make_unique<CommandTool>("lldb_command", "Run an lldb command.")); 52 53 AddResourceProvider(std::make_unique<DebuggerResourceProvider>()); 54 } 55 56 ProtocolServerMCP::~ProtocolServerMCP() { llvm::consumeError(Stop()); } 57 58 void ProtocolServerMCP::Initialize() { 59 PluginManager::RegisterPlugin(GetPluginNameStatic(), 60 GetPluginDescriptionStatic(), CreateInstance); 61 } 62 63 void ProtocolServerMCP::Terminate() { 64 PluginManager::UnregisterPlugin(CreateInstance); 65 } 66 67 lldb::ProtocolServerUP ProtocolServerMCP::CreateInstance() { 68 return std::make_unique<ProtocolServerMCP>(); 69 } 70 71 llvm::StringRef ProtocolServerMCP::GetPluginDescriptionStatic() { 72 return "MCP Server."; 73 } 74 75 llvm::Expected<protocol::Response> 76 ProtocolServerMCP::Handle(protocol::Request request) { 77 auto it = m_request_handlers.find(request.method); 78 if (it != m_request_handlers.end()) { 79 llvm::Expected<protocol::Response> response = it->second(request); 80 if (!response) 81 return response; 82 response->id = request.id; 83 return *response; 84 } 85 86 return make_error<MCPError>( 87 llvm::formatv("no handler for request: {0}", request.method).str()); 88 } 89 90 void ProtocolServerMCP::Handle(protocol::Notification notification) { 91 auto it = m_notification_handlers.find(notification.method); 92 if (it != m_notification_handlers.end()) { 93 it->second(notification); 94 return; 95 } 96 97 LLDB_LOG(GetLog(LLDBLog::Host), "MPC notification: {0} ({1})", 98 notification.method, notification.params); 99 } 100 101 void ProtocolServerMCP::AcceptCallback(std::unique_ptr<Socket> socket) { 102 LLDB_LOG(GetLog(LLDBLog::Host), "New MCP client ({0}) connected", 103 m_clients.size() + 1); 104 105 lldb::IOObjectSP io_sp = std::move(socket); 106 auto client_up = std::make_unique<Client>(); 107 client_up->io_sp = io_sp; 108 Client *client = client_up.get(); 109 110 Status status; 111 auto read_handle_up = m_loop.RegisterReadObject( 112 io_sp, 113 [this, client](MainLoopBase &loop) { 114 if (Error error = ReadCallback(*client)) { 115 LLDB_LOG_ERROR(GetLog(LLDBLog::Host), std::move(error), "{0}"); 116 client->read_handle_up.reset(); 117 } 118 }, 119 status); 120 if (status.Fail()) 121 return; 122 123 client_up->read_handle_up = std::move(read_handle_up); 124 m_clients.emplace_back(std::move(client_up)); 125 } 126 127 llvm::Error ProtocolServerMCP::ReadCallback(Client &client) { 128 char chunk[kChunkSize]; 129 size_t bytes_read = sizeof(chunk); 130 if (Status status = client.io_sp->Read(chunk, bytes_read); status.Fail()) 131 return status.takeError(); 132 client.buffer.append(chunk, bytes_read); 133 134 for (std::string::size_type pos; 135 (pos = client.buffer.find('\n')) != std::string::npos;) { 136 llvm::Expected<std::optional<protocol::Message>> message = 137 HandleData(StringRef(client.buffer.data(), pos)); 138 client.buffer = client.buffer.erase(0, pos + 1); 139 if (!message) 140 return message.takeError(); 141 142 if (*message) { 143 std::string Output; 144 llvm::raw_string_ostream OS(Output); 145 OS << llvm::formatv("{0}", toJSON(**message)) << '\n'; 146 size_t num_bytes = Output.size(); 147 return client.io_sp->Write(Output.data(), num_bytes).takeError(); 148 } 149 } 150 151 return llvm::Error::success(); 152 } 153 154 llvm::Error ProtocolServerMCP::Start(ProtocolServer::Connection connection) { 155 std::lock_guard<std::mutex> guard(m_server_mutex); 156 157 if (m_running) 158 return llvm::createStringError("the MCP server is already running"); 159 160 Status status; 161 m_listener = Socket::Create(connection.protocol, status); 162 if (status.Fail()) 163 return status.takeError(); 164 165 status = m_listener->Listen(connection.name, /*backlog=*/5); 166 if (status.Fail()) 167 return status.takeError(); 168 169 auto handles = 170 m_listener->Accept(m_loop, std::bind(&ProtocolServerMCP::AcceptCallback, 171 this, std::placeholders::_1)); 172 if (llvm::Error error = handles.takeError()) 173 return error; 174 175 m_running = true; 176 m_listen_handlers = std::move(*handles); 177 m_loop_thread = std::thread([=] { 178 llvm::set_thread_name("protocol-server.mcp"); 179 m_loop.Run(); 180 }); 181 182 return llvm::Error::success(); 183 } 184 185 llvm::Error ProtocolServerMCP::Stop() { 186 { 187 std::lock_guard<std::mutex> guard(m_server_mutex); 188 if (!m_running) 189 return createStringError("the MCP sever is not running"); 190 m_running = false; 191 } 192 193 // Stop the main loop. 194 m_loop.AddPendingCallback( 195 [](MainLoopBase &loop) { loop.RequestTermination(); }); 196 197 // Wait for the main loop to exit. 198 if (m_loop_thread.joinable()) 199 m_loop_thread.join(); 200 201 { 202 std::lock_guard<std::mutex> guard(m_server_mutex); 203 m_listener.reset(); 204 m_listen_handlers.clear(); 205 m_clients.clear(); 206 } 207 208 return llvm::Error::success(); 209 } 210 211 llvm::Expected<std::optional<protocol::Message>> 212 ProtocolServerMCP::HandleData(llvm::StringRef data) { 213 auto message = llvm::json::parse<protocol::Message>(/*JSON=*/data); 214 if (!message) 215 return message.takeError(); 216 217 if (const protocol::Request *request = 218 std::get_if<protocol::Request>(&(*message))) { 219 llvm::Expected<protocol::Response> response = Handle(*request); 220 221 // Handle failures by converting them into an Error message. 222 if (!response) { 223 protocol::Error protocol_error; 224 llvm::handleAllErrors( 225 response.takeError(), 226 [&](const MCPError &err) { protocol_error = err.toProtcolError(); }, 227 [&](const llvm::ErrorInfoBase &err) { 228 protocol_error.error.code = MCPError::kInternalError; 229 protocol_error.error.message = err.message(); 230 }); 231 protocol_error.id = request->id; 232 return protocol_error; 233 } 234 235 return *response; 236 } 237 238 if (const protocol::Notification *notification = 239 std::get_if<protocol::Notification>(&(*message))) { 240 Handle(*notification); 241 return std::nullopt; 242 } 243 244 if (std::get_if<protocol::Error>(&(*message))) 245 return llvm::createStringError("unexpected MCP message: error"); 246 247 if (std::get_if<protocol::Response>(&(*message))) 248 return llvm::createStringError("unexpected MCP message: response"); 249 250 llvm_unreachable("all message types handled"); 251 } 252 253 protocol::Capabilities ProtocolServerMCP::GetCapabilities() { 254 protocol::Capabilities capabilities; 255 capabilities.tools.listChanged = true; 256 // FIXME: Support sending notifications when a debugger/target are 257 // added/removed. 258 capabilities.resources.listChanged = false; 259 return capabilities; 260 } 261 262 void ProtocolServerMCP::AddTool(std::unique_ptr<Tool> tool) { 263 std::lock_guard<std::mutex> guard(m_server_mutex); 264 265 if (!tool) 266 return; 267 m_tools[tool->GetName()] = std::move(tool); 268 } 269 270 void ProtocolServerMCP::AddResourceProvider( 271 std::unique_ptr<ResourceProvider> resource_provider) { 272 std::lock_guard<std::mutex> guard(m_server_mutex); 273 274 if (!resource_provider) 275 return; 276 m_resource_providers.push_back(std::move(resource_provider)); 277 } 278 279 void ProtocolServerMCP::AddRequestHandler(llvm::StringRef method, 280 RequestHandler handler) { 281 std::lock_guard<std::mutex> guard(m_server_mutex); 282 m_request_handlers[method] = std::move(handler); 283 } 284 285 void ProtocolServerMCP::AddNotificationHandler(llvm::StringRef method, 286 NotificationHandler handler) { 287 std::lock_guard<std::mutex> guard(m_server_mutex); 288 m_notification_handlers[method] = std::move(handler); 289 } 290 291 llvm::Expected<protocol::Response> 292 ProtocolServerMCP::InitializeHandler(const protocol::Request &request) { 293 protocol::Response response; 294 response.result.emplace(llvm::json::Object{ 295 {"protocolVersion", protocol::kVersion}, 296 {"capabilities", GetCapabilities()}, 297 {"serverInfo", 298 llvm::json::Object{{"name", kName}, {"version", kVersion}}}}); 299 return response; 300 } 301 302 llvm::Expected<protocol::Response> 303 ProtocolServerMCP::ToolsListHandler(const protocol::Request &request) { 304 protocol::Response response; 305 306 llvm::json::Array tools; 307 for (const auto &tool : m_tools) 308 tools.emplace_back(toJSON(tool.second->GetDefinition())); 309 310 response.result.emplace(llvm::json::Object{{"tools", std::move(tools)}}); 311 312 return response; 313 } 314 315 llvm::Expected<protocol::Response> 316 ProtocolServerMCP::ToolsCallHandler(const protocol::Request &request) { 317 protocol::Response response; 318 319 if (!request.params) 320 return llvm::createStringError("no tool parameters"); 321 322 const json::Object *param_obj = request.params->getAsObject(); 323 if (!param_obj) 324 return llvm::createStringError("no tool parameters"); 325 326 const json::Value *name = param_obj->get("name"); 327 if (!name) 328 return llvm::createStringError("no tool name"); 329 330 llvm::StringRef tool_name = name->getAsString().value_or(""); 331 if (tool_name.empty()) 332 return llvm::createStringError("no tool name"); 333 334 auto it = m_tools.find(tool_name); 335 if (it == m_tools.end()) 336 return llvm::createStringError(llvm::formatv("no tool \"{0}\"", tool_name)); 337 338 protocol::ToolArguments tool_args; 339 if (const json::Value *args = param_obj->get("arguments")) 340 tool_args = *args; 341 342 llvm::Expected<protocol::TextResult> text_result = 343 it->second->Call(tool_args); 344 if (!text_result) 345 return text_result.takeError(); 346 347 response.result.emplace(toJSON(*text_result)); 348 349 return response; 350 } 351 352 llvm::Expected<protocol::Response> 353 ProtocolServerMCP::ResourcesListHandler(const protocol::Request &request) { 354 protocol::Response response; 355 356 llvm::json::Array resources; 357 358 std::lock_guard<std::mutex> guard(m_server_mutex); 359 for (std::unique_ptr<ResourceProvider> &resource_provider_up : 360 m_resource_providers) { 361 for (const protocol::Resource &resource : 362 resource_provider_up->GetResources()) 363 resources.push_back(resource); 364 } 365 response.result.emplace( 366 llvm::json::Object{{"resources", std::move(resources)}}); 367 368 return response; 369 } 370 371 llvm::Expected<protocol::Response> 372 ProtocolServerMCP::ResourcesReadHandler(const protocol::Request &request) { 373 protocol::Response response; 374 375 if (!request.params) 376 return llvm::createStringError("no resource parameters"); 377 378 const json::Object *param_obj = request.params->getAsObject(); 379 if (!param_obj) 380 return llvm::createStringError("no resource parameters"); 381 382 const json::Value *uri = param_obj->get("uri"); 383 if (!uri) 384 return llvm::createStringError("no resource uri"); 385 386 llvm::StringRef uri_str = uri->getAsString().value_or(""); 387 if (uri_str.empty()) 388 return llvm::createStringError("no resource uri"); 389 390 std::lock_guard<std::mutex> guard(m_server_mutex); 391 for (std::unique_ptr<ResourceProvider> &resource_provider_up : 392 m_resource_providers) { 393 llvm::Expected<protocol::ResourceResult> result = 394 resource_provider_up->ReadResource(uri_str); 395 if (result.errorIsA<UnsupportedURI>()) { 396 llvm::consumeError(result.takeError()); 397 continue; 398 } 399 if (!result) 400 return result.takeError(); 401 402 protocol::Response response; 403 response.result.emplace(std::move(*result)); 404 return response; 405 } 406 407 return make_error<MCPError>( 408 llvm::formatv("no resource handler for uri: {0}", uri_str).str(), 409 MCPError::kResourceNotFound); 410 } 411