xref: /freebsd/contrib/llvm-project/lldb/source/Host/common/TCPSocket.cpp (revision 700637cbb5e582861067a11aaca4d053546871d2)
1 //===-- TCPSocket.cpp -----------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #if defined(_MSC_VER)
10 #define _WINSOCK_DEPRECATED_NO_WARNINGS
11 #endif
12 
13 #include "lldb/Host/common/TCPSocket.h"
14 
15 #include "lldb/Host/Config.h"
16 #include "lldb/Host/MainLoop.h"
17 #include "lldb/Utility/LLDBLog.h"
18 #include "lldb/Utility/Log.h"
19 
20 #include "llvm/Config/llvm-config.h"
21 #include "llvm/Support/Errno.h"
22 #include "llvm/Support/Error.h"
23 #include "llvm/Support/WindowsError.h"
24 #include "llvm/Support/raw_ostream.h"
25 
26 #if LLDB_ENABLE_POSIX
27 #include <arpa/inet.h>
28 #include <netinet/tcp.h>
29 #include <sys/socket.h>
30 #endif
31 
32 #if defined(_WIN32)
33 #include <winsock2.h>
34 #endif
35 
36 using namespace lldb;
37 using namespace lldb_private;
38 
// Every socket created by this file is a stream (TCP) socket.
static constexpr int kType = SOCK_STREAM;
40 
TCPSocket(bool should_close)41 TCPSocket::TCPSocket(bool should_close) : Socket(ProtocolTcp, should_close) {}
42 
TCPSocket(NativeSocket socket,const TCPSocket & listen_socket)43 TCPSocket::TCPSocket(NativeSocket socket, const TCPSocket &listen_socket)
44     : Socket(ProtocolTcp, listen_socket.m_should_close_fd) {
45   m_socket = socket;
46 }
47 
TCPSocket(NativeSocket socket,bool should_close)48 TCPSocket::TCPSocket(NativeSocket socket, bool should_close)
49     : Socket(ProtocolTcp, should_close) {
50   m_socket = socket;
51 }
52 
~TCPSocket()53 TCPSocket::~TCPSocket() { CloseListenSockets(); }
54 
CreatePair()55 llvm::Expected<TCPSocket::Pair> TCPSocket::CreatePair() {
56   auto listen_socket_up = std::make_unique<TCPSocket>(true);
57   if (Status error = listen_socket_up->Listen("localhost:0", 5); error.Fail())
58     return error.takeError();
59 
60   std::string connect_address =
61       llvm::StringRef(listen_socket_up->GetListeningConnectionURI()[0])
62           .split("://")
63           .second.str();
64 
65   auto connect_socket_up = std::make_unique<TCPSocket>(true);
66   if (Status error = connect_socket_up->Connect(connect_address); error.Fail())
67     return error.takeError();
68 
69   // Connection has already been made above, so a short timeout is sufficient.
70   Socket *accept_socket;
71   if (Status error =
72           listen_socket_up->Accept(std::chrono::seconds(1), accept_socket);
73       error.Fail())
74     return error.takeError();
75 
76   return Pair(
77       std::move(connect_socket_up),
78       std::unique_ptr<TCPSocket>(static_cast<TCPSocket *>(accept_socket)));
79 }
80 
IsValid() const81 bool TCPSocket::IsValid() const {
82   return m_socket != kInvalidSocketValue || m_listen_sockets.size() != 0;
83 }
84 
85 // Return the port number that is being used by the socket.
GetLocalPortNumber() const86 uint16_t TCPSocket::GetLocalPortNumber() const {
87   if (m_socket != kInvalidSocketValue) {
88     SocketAddress sock_addr;
89     socklen_t sock_addr_len = sock_addr.GetMaxLength();
90     if (::getsockname(m_socket, sock_addr, &sock_addr_len) == 0)
91       return sock_addr.GetPort();
92   } else if (!m_listen_sockets.empty()) {
93     SocketAddress sock_addr;
94     socklen_t sock_addr_len = sock_addr.GetMaxLength();
95     if (::getsockname(m_listen_sockets.begin()->first, sock_addr,
96                       &sock_addr_len) == 0)
97       return sock_addr.GetPort();
98   }
99   return 0;
100 }
101 
GetLocalIPAddress() const102 std::string TCPSocket::GetLocalIPAddress() const {
103   // We bound to port zero, so we need to figure out which port we actually
104   // bound to
105   if (m_socket != kInvalidSocketValue) {
106     SocketAddress sock_addr;
107     socklen_t sock_addr_len = sock_addr.GetMaxLength();
108     if (::getsockname(m_socket, sock_addr, &sock_addr_len) == 0)
109       return sock_addr.GetIPAddress();
110   }
111   return "";
112 }
113 
GetRemotePortNumber() const114 uint16_t TCPSocket::GetRemotePortNumber() const {
115   if (m_socket != kInvalidSocketValue) {
116     SocketAddress sock_addr;
117     socklen_t sock_addr_len = sock_addr.GetMaxLength();
118     if (::getpeername(m_socket, sock_addr, &sock_addr_len) == 0)
119       return sock_addr.GetPort();
120   }
121   return 0;
122 }
123 
GetRemoteIPAddress() const124 std::string TCPSocket::GetRemoteIPAddress() const {
125   // We bound to port zero, so we need to figure out which port we actually
126   // bound to
127   if (m_socket != kInvalidSocketValue) {
128     SocketAddress sock_addr;
129     socklen_t sock_addr_len = sock_addr.GetMaxLength();
130     if (::getpeername(m_socket, sock_addr, &sock_addr_len) == 0)
131       return sock_addr.GetIPAddress();
132   }
133   return "";
134 }
135 
GetRemoteConnectionURI() const136 std::string TCPSocket::GetRemoteConnectionURI() const {
137   if (m_socket != kInvalidSocketValue) {
138     return std::string(llvm::formatv(
139         "connect://[{0}]:{1}", GetRemoteIPAddress(), GetRemotePortNumber()));
140   }
141   return "";
142 }
143 
GetListeningConnectionURI() const144 std::vector<std::string> TCPSocket::GetListeningConnectionURI() const {
145   std::vector<std::string> URIs;
146   for (const auto &[fd, addr] : m_listen_sockets)
147     URIs.emplace_back(llvm::formatv("connection://[{0}]:{1}",
148                                     addr.GetIPAddress(), addr.GetPort()));
149   return URIs;
150 }
151 
CreateSocket(int domain)152 Status TCPSocket::CreateSocket(int domain) {
153   Status error;
154   if (IsValid())
155     error = Close();
156   if (error.Fail())
157     return error;
158   m_socket = Socket::CreateSocket(domain, kType, IPPROTO_TCP, error);
159   return error;
160 }
161 
Connect(llvm::StringRef name)162 Status TCPSocket::Connect(llvm::StringRef name) {
163 
164   Log *log = GetLog(LLDBLog::Communication);
165   LLDB_LOG(log, "Connect to host/port {0}", name);
166 
167   Status error;
168   llvm::Expected<HostAndPort> host_port = DecodeHostAndPort(name);
169   if (!host_port)
170     return Status::FromError(host_port.takeError());
171 
172   std::vector<SocketAddress> addresses =
173       SocketAddress::GetAddressInfo(host_port->hostname.c_str(), nullptr,
174                                     AF_UNSPEC, SOCK_STREAM, IPPROTO_TCP);
175   for (SocketAddress &address : addresses) {
176     error = CreateSocket(address.GetFamily());
177     if (error.Fail())
178       continue;
179 
180     address.SetPort(host_port->port);
181 
182     if (llvm::sys::RetryAfterSignal(-1, ::connect, GetNativeSocket(),
183                                     &address.sockaddr(),
184                                     address.GetLength()) == -1) {
185       Close();
186       continue;
187     }
188 
189     if (SetOptionNoDelay() == -1) {
190       Close();
191       continue;
192     }
193 
194     error.Clear();
195     return error;
196   }
197 
198   error = Status::FromErrorStringWithFormatv(
199       "Failed to connect to {0}:{1}", host_port->hostname, host_port->port);
200   return error;
201 }
202 
Listen(llvm::StringRef name,int backlog)203 Status TCPSocket::Listen(llvm::StringRef name, int backlog) {
204   Log *log = GetLog(LLDBLog::Connection);
205   LLDB_LOG(log, "Listen to {0}", name);
206 
207   Status error;
208   llvm::Expected<HostAndPort> host_port = DecodeHostAndPort(name);
209   if (!host_port)
210     return Status::FromError(host_port.takeError());
211 
212   if (host_port->hostname == "*")
213     host_port->hostname = "0.0.0.0";
214   std::vector<SocketAddress> addresses = SocketAddress::GetAddressInfo(
215       host_port->hostname.c_str(), nullptr, AF_UNSPEC, SOCK_STREAM, IPPROTO_TCP);
216   for (SocketAddress &address : addresses) {
217     int fd =
218         Socket::CreateSocket(address.GetFamily(), kType, IPPROTO_TCP, error);
219     if (error.Fail() || fd < 0)
220       continue;
221 
222     // enable local address reuse
223     if (SetOption(fd, SOL_SOCKET, SO_REUSEADDR, 1) == -1) {
224       CloseSocket(fd);
225       continue;
226     }
227 
228     SocketAddress listen_address = address;
229     if(!listen_address.IsLocalhost())
230       listen_address.SetToAnyAddress(address.GetFamily(), host_port->port);
231     else
232       listen_address.SetPort(host_port->port);
233 
234     int err =
235         ::bind(fd, &listen_address.sockaddr(), listen_address.GetLength());
236     if (err != -1)
237       err = ::listen(fd, backlog);
238 
239     if (err == -1) {
240       error = GetLastError();
241       CloseSocket(fd);
242       continue;
243     }
244 
245     if (host_port->port == 0) {
246       socklen_t sa_len = listen_address.GetLength();
247       if (getsockname(fd, &listen_address.sockaddr(), &sa_len) == 0)
248         host_port->port = listen_address.GetPort();
249     }
250     m_listen_sockets[fd] = listen_address;
251   }
252 
253   if (m_listen_sockets.empty()) {
254     assert(error.Fail());
255     return error;
256   }
257   return Status();
258 }
259 
CloseListenSockets()260 void TCPSocket::CloseListenSockets() {
261   for (auto socket : m_listen_sockets)
262     CloseSocket(socket.first);
263   m_listen_sockets.clear();
264 }
265 
266 llvm::Expected<std::vector<MainLoopBase::ReadHandleUP>>
Accept(MainLoopBase & loop,std::function<void (std::unique_ptr<Socket> socket)> sock_cb)267 TCPSocket::Accept(MainLoopBase &loop,
268                   std::function<void(std::unique_ptr<Socket> socket)> sock_cb) {
269   if (m_listen_sockets.size() == 0)
270     return llvm::createStringError("No open listening sockets!");
271 
272   std::vector<MainLoopBase::ReadHandleUP> handles;
273   for (auto socket : m_listen_sockets) {
274     auto fd = socket.first;
275     auto io_sp = std::make_shared<TCPSocket>(fd, false);
276     auto cb = [this, fd, sock_cb](MainLoopBase &loop) {
277       lldb_private::SocketAddress AcceptAddr;
278       socklen_t sa_len = AcceptAddr.GetMaxLength();
279       Status error;
280       NativeSocket sock =
281           AcceptSocket(fd, &AcceptAddr.sockaddr(), &sa_len, error);
282       Log *log = GetLog(LLDBLog::Host);
283       if (error.Fail()) {
284         LLDB_LOG(log, "AcceptSocket({0}): {1}", fd, error);
285         return;
286       }
287 
288       const lldb_private::SocketAddress &AddrIn = m_listen_sockets[fd];
289       if (!AddrIn.IsAnyAddr() && AcceptAddr != AddrIn) {
290         CloseSocket(sock);
291         LLDB_LOG(log, "rejecting incoming connection from {0} (expecting {1})",
292                  AcceptAddr.GetIPAddress(), AddrIn.GetIPAddress());
293         return;
294       }
295       std::unique_ptr<TCPSocket> sock_up(new TCPSocket(sock, *this));
296 
297       // Keep our TCP packets coming without any delays.
298       sock_up->SetOptionNoDelay();
299 
300       sock_cb(std::move(sock_up));
301     };
302     Status error;
303     handles.emplace_back(loop.RegisterReadObject(io_sp, cb, error));
304     if (error.Fail())
305       return error.ToError();
306   }
307 
308   return handles;
309 }
310 
SetOptionNoDelay()311 int TCPSocket::SetOptionNoDelay() {
312   return SetOption(IPPROTO_TCP, TCP_NODELAY, 1);
313 }
314 
SetOptionReuseAddress()315 int TCPSocket::SetOptionReuseAddress() {
316   return SetOption(SOL_SOCKET, SO_REUSEADDR, 1);
317 }
318