1 //===-- Acceptor.cpp --------------------------------------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "Acceptor.h" 10 11 #include "llvm/ADT/StringRef.h" 12 #include "llvm/Support/ScopedPrinter.h" 13 14 #include "lldb/Host/ConnectionFileDescriptor.h" 15 #include "lldb/Host/common/TCPSocket.h" 16 #include "lldb/Utility/StreamString.h" 17 #include "lldb/Utility/UriParser.h" 18 #include <optional> 19 20 using namespace lldb; 21 using namespace lldb_private; 22 using namespace lldb_private::lldb_server; 23 using namespace llvm; 24 25 namespace { 26 27 struct SocketScheme { 28 const char *m_scheme; 29 const Socket::SocketProtocol m_protocol; 30 }; 31 32 SocketScheme socket_schemes[] = { 33 {"tcp", Socket::ProtocolTcp}, 34 {"udp", Socket::ProtocolUdp}, 35 {"unix", Socket::ProtocolUnixDomain}, 36 {"unix-abstract", Socket::ProtocolUnixAbstract}, 37 }; 38 39 bool FindProtocolByScheme(const char *scheme, 40 Socket::SocketProtocol &protocol) { 41 for (auto s : socket_schemes) { 42 if (!strcmp(s.m_scheme, scheme)) { 43 protocol = s.m_protocol; 44 return true; 45 } 46 } 47 return false; 48 } 49 50 const char *FindSchemeByProtocol(const Socket::SocketProtocol protocol) { 51 for (auto s : socket_schemes) { 52 if (s.m_protocol == protocol) 53 return s.m_scheme; 54 } 55 return nullptr; 56 } 57 } 58 59 Status Acceptor::Listen(int backlog) { 60 return m_listener_socket_up->Listen(StringRef(m_name), backlog); 61 } 62 63 Status Acceptor::Accept(const bool child_processes_inherit, Connection *&conn) { 64 Socket *conn_socket = nullptr; 65 auto error = m_listener_socket_up->Accept(conn_socket); 66 if (error.Success()) 67 conn = new ConnectionFileDescriptor(conn_socket); 68 69 return error; 70 } 71 72 Socket::SocketProtocol Acceptor::GetSocketProtocol() const { 73 return m_listener_socket_up->GetSocketProtocol(); 74 } 75 76 const char *Acceptor::GetSocketScheme() const { 77 return FindSchemeByProtocol(GetSocketProtocol()); 78 } 79 80 std::string Acceptor::GetLocalSocketId() const { return m_local_socket_id(); } 81 82 std::unique_ptr<Acceptor> Acceptor::Create(StringRef name, 83 const bool child_processes_inherit, 84 Status &error) { 85 error.Clear(); 86 87 Socket::SocketProtocol socket_protocol = Socket::ProtocolUnixDomain; 88 // Try to match socket name as URL - e.g., tcp://localhost:5555 89 if (std::optional<URI> res = URI::Parse(name)) { 90 if (!FindProtocolByScheme(res->scheme.str().c_str(), socket_protocol)) 91 error.SetErrorStringWithFormat("Unknown protocol scheme \"%s\"", 92 res->scheme.str().c_str()); 93 else 94 name = name.drop_front(res->scheme.size() + strlen("://")); 95 } else { 96 // Try to match socket name as $host:port - e.g., localhost:5555 97 if (!llvm::errorToBool(Socket::DecodeHostAndPort(name).takeError())) 98 socket_protocol = Socket::ProtocolTcp; 99 } 100 101 if (error.Fail()) 102 return std::unique_ptr<Acceptor>(); 103 104 std::unique_ptr<Socket> listener_socket_up = 105 Socket::Create(socket_protocol, child_processes_inherit, error); 106 107 LocalSocketIdFunc local_socket_id; 108 if (error.Success()) { 109 if (listener_socket_up->GetSocketProtocol() == Socket::ProtocolTcp) { 110 TCPSocket *tcp_socket = 111 static_cast<TCPSocket *>(listener_socket_up.get()); 112 local_socket_id = [tcp_socket]() { 113 auto local_port = tcp_socket->GetLocalPortNumber(); 114 return (local_port != 0) ? llvm::to_string(local_port) : ""; 115 }; 116 } else { 117 const std::string socket_name = std::string(name); 118 local_socket_id = [socket_name]() { return socket_name; }; 119 } 120 121 return std::unique_ptr<Acceptor>( 122 new Acceptor(std::move(listener_socket_up), name, local_socket_id)); 123 } 124 125 return std::unique_ptr<Acceptor>(); 126 } 127 128 Acceptor::Acceptor(std::unique_ptr<Socket> &&listener_socket, StringRef name, 129 const LocalSocketIdFunc &local_socket_id) 130 : m_listener_socket_up(std::move(listener_socket)), m_name(name.str()), 131 m_local_socket_id(local_socket_id) {} 132