xref: /freebsd/contrib/llvm-project/lldb/source/Host/common/Socket.cpp (revision 700637cbb5e582861067a11aaca4d053546871d2)
1 //===-- Socket.cpp --------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "lldb/Host/Socket.h"
10 
11 #include "lldb/Host/Config.h"
12 #include "lldb/Host/Host.h"
13 #include "lldb/Host/MainLoop.h"
14 #include "lldb/Host/SocketAddress.h"
15 #include "lldb/Host/common/TCPSocket.h"
16 #include "lldb/Host/common/UDPSocket.h"
17 #include "lldb/Utility/LLDBLog.h"
18 #include "lldb/Utility/Log.h"
19 
20 #include "llvm/ADT/STLExtras.h"
21 #include "llvm/ADT/StringExtras.h"
22 #include "llvm/Support/Errno.h"
23 #include "llvm/Support/Error.h"
24 #include "llvm/Support/Regex.h"
25 #include "llvm/Support/WindowsError.h"
26 
27 #if LLDB_ENABLE_POSIX
28 #include "lldb/Host/posix/DomainSocket.h"
29 
30 #include <arpa/inet.h>
31 #include <netdb.h>
32 #include <netinet/in.h>
33 #include <netinet/tcp.h>
34 #include <sys/socket.h>
35 #include <sys/un.h>
36 #include <unistd.h>
37 #endif
38 
39 #ifdef __linux__
40 #include "lldb/Host/linux/AbstractSocket.h"
41 #endif
42 
43 using namespace lldb;
44 using namespace lldb_private;
45 
46 #if defined(_WIN32)
47 typedef const char *set_socket_option_arg_type;
48 typedef char *get_socket_option_arg_type;
49 const NativeSocket Socket::kInvalidSocketValue = INVALID_SOCKET;
50 const shared_fd_t SharedSocket::kInvalidFD = LLDB_INVALID_PIPE;
51 #else  // #if defined(_WIN32)
52 typedef const void *set_socket_option_arg_type;
53 typedef void *get_socket_option_arg_type;
54 const NativeSocket Socket::kInvalidSocketValue = -1;
55 const shared_fd_t SharedSocket::kInvalidFD = Socket::kInvalidSocketValue;
56 #endif // #if defined(_WIN32)
57 
// Returns true when the thread's last socket error code says the call was
// interrupted (WSAEINTR from WSAGetLastError() on Windows, EINTR in errno
// everywhere else).
static bool IsInterrupted() {
#if !defined(_WIN32)
  const int last_error = errno;
  return last_error == EINTR;
#else
  const int last_error = ::WSAGetLastError();
  return last_error == WSAEINTR;
#endif
}
65 
SharedSocket(const Socket * socket,Status & error)66 SharedSocket::SharedSocket(const Socket *socket, Status &error) {
67 #ifdef _WIN32
68   m_socket = socket->GetNativeSocket();
69   m_fd = kInvalidFD;
70 
71   // Create a pipe to transfer WSAPROTOCOL_INFO to the child process.
72   error = m_socket_pipe.CreateNew();
73   if (error.Fail())
74     return;
75 
76   m_fd = m_socket_pipe.GetReadPipe();
77 #else
78   m_fd = socket->GetNativeSocket();
79   error = Status();
80 #endif
81 }
82 
CompleteSending(lldb::pid_t child_pid)83 Status SharedSocket::CompleteSending(lldb::pid_t child_pid) {
84 #ifdef _WIN32
85   // Transfer WSAPROTOCOL_INFO to the child process.
86   m_socket_pipe.CloseReadFileDescriptor();
87 
88   WSAPROTOCOL_INFO protocol_info;
89   if (::WSADuplicateSocket(m_socket, child_pid, &protocol_info) ==
90       SOCKET_ERROR) {
91     int last_error = ::WSAGetLastError();
92     return Status::FromErrorStringWithFormat(
93         "WSADuplicateSocket() failed, error: %d", last_error);
94   }
95 
96   llvm::Expected<size_t> num_bytes = m_socket_pipe.Write(
97       &protocol_info, sizeof(protocol_info), std::chrono::seconds(10));
98   if (!num_bytes)
99     return Status::FromError(num_bytes.takeError());
100   if (*num_bytes != sizeof(protocol_info))
101     return Status::FromErrorStringWithFormatv(
102         "Write(WSAPROTOCOL_INFO) failed: wrote {0}/{1} bytes", *num_bytes,
103         sizeof(protocol_info));
104 #endif
105   return Status();
106 }
107 
GetNativeSocket(shared_fd_t fd,NativeSocket & socket)108 Status SharedSocket::GetNativeSocket(shared_fd_t fd, NativeSocket &socket) {
109 #ifdef _WIN32
110   socket = Socket::kInvalidSocketValue;
111   // Read WSAPROTOCOL_INFO from the parent process and create NativeSocket.
112   WSAPROTOCOL_INFO protocol_info;
113   {
114     Pipe socket_pipe(fd, LLDB_INVALID_PIPE);
115     llvm::Expected<size_t> num_bytes = socket_pipe.Read(
116         &protocol_info, sizeof(protocol_info), std::chrono::seconds(10));
117     if (!num_bytes)
118       return Status::FromError(num_bytes.takeError());
119     if (*num_bytes != sizeof(protocol_info)) {
120       return Status::FromErrorStringWithFormatv(
121           "Read(WSAPROTOCOL_INFO) failed: read {0}/{1} bytes", *num_bytes,
122           sizeof(protocol_info));
123     }
124   }
125   socket = ::WSASocket(FROM_PROTOCOL_INFO, FROM_PROTOCOL_INFO,
126                        FROM_PROTOCOL_INFO, &protocol_info, 0, 0);
127   if (socket == INVALID_SOCKET) {
128     return Status::FromErrorStringWithFormatv(
129         "WSASocket(FROM_PROTOCOL_INFO) failed: error {0}", ::WSAGetLastError());
130   }
131   return Status();
132 #else
133   socket = fd;
134   return Status();
135 #endif
136 }
137 
138 struct SocketScheme {
139   const char *m_scheme;
140   const Socket::SocketProtocol m_protocol;
141 };
142 
143 static SocketScheme socket_schemes[] = {
144     {"tcp", Socket::ProtocolTcp},
145     {"udp", Socket::ProtocolUdp},
146     {"unix", Socket::ProtocolUnixDomain},
147     {"unix-abstract", Socket::ProtocolUnixAbstract},
148 };
149 
150 const char *
FindSchemeByProtocol(const Socket::SocketProtocol protocol)151 Socket::FindSchemeByProtocol(const Socket::SocketProtocol protocol) {
152   for (auto s : socket_schemes) {
153     if (s.m_protocol == protocol)
154       return s.m_scheme;
155   }
156   return nullptr;
157 }
158 
FindProtocolByScheme(const char * scheme,Socket::SocketProtocol & protocol)159 bool Socket::FindProtocolByScheme(const char *scheme,
160                                   Socket::SocketProtocol &protocol) {
161   for (auto s : socket_schemes) {
162     if (!strcmp(s.m_scheme, scheme)) {
163       protocol = s.m_protocol;
164       return true;
165     }
166   }
167   return false;
168 }
169 
Socket(SocketProtocol protocol,bool should_close)170 Socket::Socket(SocketProtocol protocol, bool should_close)
171     : IOObject(eFDTypeSocket), m_protocol(protocol),
172       m_socket(kInvalidSocketValue), m_should_close_fd(should_close) {}
173 
~Socket()174 Socket::~Socket() { Close(); }
175 
Initialize()176 llvm::Error Socket::Initialize() {
177 #if defined(_WIN32)
178   auto wVersion = WINSOCK_VERSION;
179   WSADATA wsaData;
180   int err = ::WSAStartup(wVersion, &wsaData);
181   if (err == 0) {
182     if (wsaData.wVersion < wVersion) {
183       WSACleanup();
184       return llvm::createStringError("WSASock version is not expected.");
185     }
186   } else {
187     return llvm::errorCodeToError(llvm::mapWindowsError(::WSAGetLastError()));
188   }
189 #endif
190 
191   return llvm::Error::success();
192 }
193 
Terminate()194 void Socket::Terminate() {
195 #if defined(_WIN32)
196   ::WSACleanup();
197 #endif
198 }
199 
Create(const SocketProtocol protocol,Status & error)200 std::unique_ptr<Socket> Socket::Create(const SocketProtocol protocol,
201                                        Status &error) {
202   error.Clear();
203 
204   const bool should_close = true;
205   std::unique_ptr<Socket> socket_up;
206   switch (protocol) {
207   case ProtocolTcp:
208     socket_up = std::make_unique<TCPSocket>(should_close);
209     break;
210   case ProtocolUdp:
211     socket_up = std::make_unique<UDPSocket>(should_close);
212     break;
213   case ProtocolUnixDomain:
214 #if LLDB_ENABLE_POSIX
215     socket_up = std::make_unique<DomainSocket>(should_close);
216 #else
217     error = Status::FromErrorString(
218         "Unix domain sockets are not supported on this platform.");
219 #endif
220     break;
221   case ProtocolUnixAbstract:
222 #ifdef __linux__
223     socket_up = std::make_unique<AbstractSocket>();
224 #else
225     error = Status::FromErrorString(
226         "Abstract domain sockets are not supported on this platform.");
227 #endif
228     break;
229   }
230 
231   if (error.Fail())
232     socket_up.reset();
233 
234   return socket_up;
235 }
236 
237 llvm::Expected<Socket::Pair>
CreatePair(std::optional<SocketProtocol> protocol)238 Socket::CreatePair(std::optional<SocketProtocol> protocol) {
239   constexpr SocketProtocol kBestProtocol =
240       LLDB_ENABLE_POSIX ? ProtocolUnixDomain : ProtocolTcp;
241   switch (protocol.value_or(kBestProtocol)) {
242   case ProtocolTcp:
243     return TCPSocket::CreatePair();
244 #if LLDB_ENABLE_POSIX
245   case ProtocolUnixDomain:
246   case ProtocolUnixAbstract:
247     return DomainSocket::CreatePair();
248 #endif
249   default:
250     return llvm::createStringError("Unsupported protocol");
251   }
252 }
253 
254 llvm::Expected<std::unique_ptr<Socket>>
TcpConnect(llvm::StringRef host_and_port)255 Socket::TcpConnect(llvm::StringRef host_and_port) {
256   Log *log = GetLog(LLDBLog::Connection);
257   LLDB_LOG(log, "host_and_port = {0}", host_and_port);
258 
259   Status error;
260   std::unique_ptr<Socket> connect_socket = Create(ProtocolTcp, error);
261   if (error.Fail())
262     return error.ToError();
263 
264   error = connect_socket->Connect(host_and_port);
265   if (error.Success())
266     return std::move(connect_socket);
267 
268   return error.ToError();
269 }
270 
271 llvm::Expected<std::unique_ptr<TCPSocket>>
TcpListen(llvm::StringRef host_and_port,int backlog)272 Socket::TcpListen(llvm::StringRef host_and_port, int backlog) {
273   Log *log = GetLog(LLDBLog::Connection);
274   LLDB_LOG(log, "host_and_port = {0}", host_and_port);
275 
276   std::unique_ptr<TCPSocket> listen_socket(
277       new TCPSocket(/*should_close=*/true));
278 
279   Status error = listen_socket->Listen(host_and_port, backlog);
280   if (error.Fail())
281     return error.ToError();
282 
283   return std::move(listen_socket);
284 }
285 
286 llvm::Expected<std::unique_ptr<UDPSocket>>
UdpConnect(llvm::StringRef host_and_port)287 Socket::UdpConnect(llvm::StringRef host_and_port) {
288   return UDPSocket::CreateConnected(host_and_port);
289 }
290 
291 llvm::Expected<Socket::HostAndPort>
DecodeHostAndPort(llvm::StringRef host_and_port)292 Socket::DecodeHostAndPort(llvm::StringRef host_and_port) {
293   static llvm::Regex g_regex("([^:]+|\\[[0-9a-fA-F:]+.*\\]):([0-9]+)");
294   HostAndPort ret;
295   llvm::SmallVector<llvm::StringRef, 3> matches;
296   if (g_regex.match(host_and_port, &matches)) {
297     ret.hostname = matches[1].str();
298     // IPv6 addresses are wrapped in [] when specified with ports
299     if (ret.hostname.front() == '[' && ret.hostname.back() == ']')
300       ret.hostname = ret.hostname.substr(1, ret.hostname.size() - 2);
301     if (to_integer(matches[2], ret.port, 10))
302       return ret;
303   } else {
304     // If this was unsuccessful, then check if it's simply an unsigned 16-bit
305     // integer, representing a port with an empty host.
306     if (to_integer(host_and_port, ret.port, 10))
307       return ret;
308   }
309 
310   return llvm::createStringError(llvm::inconvertibleErrorCode(),
311                                  "invalid host:port specification: '%s'",
312                                  host_and_port.str().c_str());
313 }
314 
GetWaitableHandle()315 IOObject::WaitableHandle Socket::GetWaitableHandle() {
316   return (IOObject::WaitableHandle)m_socket;
317 }
318 
Read(void * buf,size_t & num_bytes)319 Status Socket::Read(void *buf, size_t &num_bytes) {
320   Status error;
321   int bytes_received = 0;
322   do {
323     bytes_received = ::recv(m_socket, static_cast<char *>(buf), num_bytes, 0);
324   } while (bytes_received < 0 && IsInterrupted());
325 
326   if (bytes_received < 0) {
327     SetLastError(error);
328     num_bytes = 0;
329   } else
330     num_bytes = bytes_received;
331 
332   Log *log = GetLog(LLDBLog::Communication);
333   if (log) {
334     LLDB_LOGF(log,
335               "%p Socket::Read() (socket = %" PRIu64
336               ", src = %p, src_len = %" PRIu64 ", flags = 0) => %" PRIi64
337               " (error = %s)",
338               static_cast<void *>(this), static_cast<uint64_t>(m_socket), buf,
339               static_cast<uint64_t>(num_bytes),
340               static_cast<int64_t>(bytes_received), error.AsCString());
341   }
342 
343   return error;
344 }
345 
Write(const void * buf,size_t & num_bytes)346 Status Socket::Write(const void *buf, size_t &num_bytes) {
347   const size_t src_len = num_bytes;
348   Status error;
349   int bytes_sent = 0;
350   do {
351     bytes_sent = Send(buf, num_bytes);
352   } while (bytes_sent < 0 && IsInterrupted());
353 
354   if (bytes_sent < 0) {
355     SetLastError(error);
356     num_bytes = 0;
357   } else
358     num_bytes = bytes_sent;
359 
360   Log *log = GetLog(LLDBLog::Communication);
361   if (log) {
362     LLDB_LOGF(log,
363               "%p Socket::Write() (socket = %" PRIu64
364               ", src = %p, src_len = %" PRIu64 ", flags = 0) => %" PRIi64
365               " (error = %s)",
366               static_cast<void *>(this), static_cast<uint64_t>(m_socket), buf,
367               static_cast<uint64_t>(src_len), static_cast<int64_t>(bytes_sent),
368               error.AsCString());
369   }
370 
371   return error;
372 }
373 
Close()374 Status Socket::Close() {
375   Status error;
376   if (!IsValid() || !m_should_close_fd)
377     return error;
378 
379   Log *log = GetLog(LLDBLog::Connection);
380   LLDB_LOGF(log, "%p Socket::Close (fd = %" PRIu64 ")",
381             static_cast<void *>(this), static_cast<uint64_t>(m_socket));
382 
383   bool success = CloseSocket(m_socket) == 0;
384   // A reference to a FD was passed in, set it to an invalid value
385   m_socket = kInvalidSocketValue;
386   if (!success) {
387     SetLastError(error);
388   }
389 
390   return error;
391 }
392 
GetOption(NativeSocket sockfd,int level,int option_name,int & option_value)393 int Socket::GetOption(NativeSocket sockfd, int level, int option_name,
394                       int &option_value) {
395   get_socket_option_arg_type option_value_p =
396       reinterpret_cast<get_socket_option_arg_type>(&option_value);
397   socklen_t option_value_size = sizeof(int);
398   return ::getsockopt(sockfd, level, option_name, option_value_p,
399                       &option_value_size);
400 }
401 
SetOption(NativeSocket sockfd,int level,int option_name,int option_value)402 int Socket::SetOption(NativeSocket sockfd, int level, int option_name,
403                       int option_value) {
404   set_socket_option_arg_type option_value_p =
405       reinterpret_cast<set_socket_option_arg_type>(&option_value);
406   return ::setsockopt(sockfd, level, option_name, option_value_p,
407                       sizeof(option_value));
408 }
409 
Send(const void * buf,const size_t num_bytes)410 size_t Socket::Send(const void *buf, const size_t num_bytes) {
411   return ::send(m_socket, static_cast<const char *>(buf), num_bytes, 0);
412 }
413 
SetLastError(Status & error)414 void Socket::SetLastError(Status &error) {
415 #if defined(_WIN32)
416   error = Status(::WSAGetLastError(), lldb::eErrorTypeWin32);
417 #else
418   error = Status::FromErrno();
419 #endif
420 }
421 
GetLastError()422 Status Socket::GetLastError() {
423   std::error_code EC;
424 #ifdef _WIN32
425   EC = llvm::mapWindowsError(WSAGetLastError());
426 #else
427   EC = std::error_code(errno, std::generic_category());
428 #endif
429   return EC;
430 }
431 
CloseSocket(NativeSocket sockfd)432 int Socket::CloseSocket(NativeSocket sockfd) {
433 #ifdef _WIN32
434   return ::closesocket(sockfd);
435 #else
436   return ::close(sockfd);
437 #endif
438 }
439 
CreateSocket(const int domain,const int type,const int protocol,Status & error)440 NativeSocket Socket::CreateSocket(const int domain, const int type,
441                                   const int protocol, Status &error) {
442   error.Clear();
443   auto socket_type = type;
444 #ifdef SOCK_CLOEXEC
445   socket_type |= SOCK_CLOEXEC;
446 #endif
447   auto sock = ::socket(domain, socket_type, protocol);
448   if (sock == kInvalidSocketValue)
449     SetLastError(error);
450 
451   return sock;
452 }
453 
Accept(const Timeout<std::micro> & timeout,Socket * & socket)454 Status Socket::Accept(const Timeout<std::micro> &timeout, Socket *&socket) {
455   socket = nullptr;
456   MainLoop accept_loop;
457   llvm::Expected<std::vector<MainLoopBase::ReadHandleUP>> expected_handles =
458       Accept(accept_loop,
459              [&accept_loop, &socket](std::unique_ptr<Socket> sock) {
460                socket = sock.release();
461                accept_loop.RequestTermination();
462              });
463   if (!expected_handles)
464     return Status::FromError(expected_handles.takeError());
465   if (timeout) {
466     accept_loop.AddCallback(
467         [](MainLoopBase &loop) { loop.RequestTermination(); }, *timeout);
468   }
469   if (Status status = accept_loop.Run(); status.Fail())
470     return status;
471   if (socket)
472     return Status();
473   return Status(std::make_error_code(std::errc::timed_out));
474 }
475 
AcceptSocket(NativeSocket sockfd,struct sockaddr * addr,socklen_t * addrlen,Status & error)476 NativeSocket Socket::AcceptSocket(NativeSocket sockfd, struct sockaddr *addr,
477                                   socklen_t *addrlen, Status &error) {
478   error.Clear();
479 #if defined(SOCK_CLOEXEC) && defined(HAVE_ACCEPT4)
480   int flags = SOCK_CLOEXEC;
481   NativeSocket fd = llvm::sys::RetryAfterSignal(
482       static_cast<NativeSocket>(-1), ::accept4, sockfd, addr, addrlen, flags);
483 #else
484   NativeSocket fd = llvm::sys::RetryAfterSignal(
485       static_cast<NativeSocket>(-1), ::accept, sockfd, addr, addrlen);
486 #endif
487   if (fd == kInvalidSocketValue)
488     SetLastError(error);
489   return fd;
490 }
491 
operator <<(llvm::raw_ostream & OS,const Socket::HostAndPort & HP)492 llvm::raw_ostream &lldb_private::operator<<(llvm::raw_ostream &OS,
493                                             const Socket::HostAndPort &HP) {
494   return OS << '[' << HP.hostname << ']' << ':' << HP.port;
495 }
496 
497 std::optional<Socket::ProtocolModePair>
GetProtocolAndMode(llvm::StringRef scheme)498 Socket::GetProtocolAndMode(llvm::StringRef scheme) {
499   // Keep in sync with ConnectionFileDescriptor::Connect.
500   return llvm::StringSwitch<std::optional<ProtocolModePair>>(scheme)
501       .Case("listen", ProtocolModePair{SocketProtocol::ProtocolTcp,
502                                        SocketMode::ModeAccept})
503       .Cases("accept", "unix-accept",
504              ProtocolModePair{SocketProtocol::ProtocolUnixDomain,
505                               SocketMode::ModeAccept})
506       .Case("unix-abstract-accept",
507             ProtocolModePair{SocketProtocol::ProtocolUnixAbstract,
508                              SocketMode::ModeAccept})
509       .Cases("connect", "tcp-connect",
510              ProtocolModePair{SocketProtocol::ProtocolTcp,
511                               SocketMode::ModeConnect})
512       .Case("udp", ProtocolModePair{SocketProtocol::ProtocolTcp,
513                                     SocketMode::ModeConnect})
514       .Case("unix-connect", ProtocolModePair{SocketProtocol::ProtocolUnixDomain,
515                                              SocketMode::ModeConnect})
516       .Case("unix-abstract-connect",
517             ProtocolModePair{SocketProtocol::ProtocolUnixAbstract,
518                              SocketMode::ModeConnect})
519       .Default(std::nullopt);
520 }
521