xref: /freebsd/contrib/llvm-project/llvm/lib/Support/raw_socket_stream.cpp (revision 770cf0a5f02dc8983a89c6568d741fbc25baa999)
1 //===-- llvm/Support/raw_socket_stream.cpp - Socket streams --*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains raw_ostream implementations for streams to communicate
10 // via UNIX sockets
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "llvm/Support/raw_socket_stream.h"
15 #include "llvm/Config/config.h"
16 #include "llvm/Support/Error.h"
17 #include "llvm/Support/FileSystem.h"
18 
19 #include <atomic>
20 #include <fcntl.h>
21 #include <functional>
22 
23 #ifndef _WIN32
24 #include <poll.h>
25 #include <sys/socket.h>
26 #include <sys/un.h>
27 #else
28 #include "llvm/Support/Windows/WindowsSupport.h"
29 // winsock2.h must be included before afunix.h. Briefly turn off clang-format to
30 // avoid error.
31 // clang-format off
32 #include <winsock2.h>
33 #include <afunix.h>
34 // clang-format on
35 #include <io.h>
36 #endif // _WIN32
37 
38 #if defined(HAVE_UNISTD_H)
39 #include <unistd.h>
40 #endif
41 
42 using namespace llvm;
43 
44 #ifdef _WIN32
45 WSABalancer::WSABalancer() {
46   WSADATA WsaData;
47   ::memset(&WsaData, 0, sizeof(WsaData));
48   if (WSAStartup(MAKEWORD(2, 2), &WsaData) != 0) {
49     llvm::report_fatal_error("WSAStartup failed");
50   }
51 }
52 
53 WSABalancer::~WSABalancer() { WSACleanup(); }
54 #endif // _WIN32
55 
56 static std::error_code getLastSocketErrorCode() {
57 #ifdef _WIN32
58   return std::error_code(::WSAGetLastError(), std::system_category());
59 #else
60   return errnoAsErrorCode();
61 #endif
62 }
63 
64 static sockaddr_un setSocketAddr(StringRef SocketPath) {
65   struct sockaddr_un Addr;
66   memset(&Addr, 0, sizeof(Addr));
67   Addr.sun_family = AF_UNIX;
68   strncpy(Addr.sun_path, SocketPath.str().c_str(), sizeof(Addr.sun_path) - 1);
69   return Addr;
70 }
71 
72 static Expected<int> getSocketFD(StringRef SocketPath) {
73 #ifdef _WIN32
74   SOCKET Socket = socket(AF_UNIX, SOCK_STREAM, 0);
75   if (Socket == INVALID_SOCKET) {
76 #else
77   int Socket = socket(AF_UNIX, SOCK_STREAM, 0);
78   if (Socket == -1) {
79 #endif // _WIN32
80     return llvm::make_error<StringError>(getLastSocketErrorCode(),
81                                          "Create socket failed");
82   }
83 
84 #ifdef __CYGWIN__
85   // On Cygwin, UNIX sockets involve a handshake between connect and accept
86   // to enable SO_PEERCRED/getpeereid handling.  This necessitates accept being
87   // called before connect can return, but at least the tests in
88   // llvm/unittests/Support/raw_socket_stream_test do both on the same thread
89   // (first connect and then accept), resulting in a deadlock.  This call turns
90   // off the handshake (and SO_PEERCRED/getpeereid support).
91   setsockopt(Socket, SOL_SOCKET, SO_PEERCRED, NULL, 0);
92 #endif
93   struct sockaddr_un Addr = setSocketAddr(SocketPath);
94   if (::connect(Socket, (struct sockaddr *)&Addr, sizeof(Addr)) == -1)
95     return llvm::make_error<StringError>(getLastSocketErrorCode(),
96                                          "Connect socket failed");
97 
98 #ifdef _WIN32
99   return _open_osfhandle(Socket, 0);
100 #else
101   return Socket;
102 #endif // _WIN32
103 }
104 
105 ListeningSocket::ListeningSocket(int SocketFD, StringRef SocketPath,
106                                  int PipeFD[2])
107     : FD(SocketFD), SocketPath(SocketPath), PipeFD{PipeFD[0], PipeFD[1]} {}
108 
109 ListeningSocket::ListeningSocket(ListeningSocket &&LS)
110     : FD(LS.FD.load()), SocketPath(LS.SocketPath),
111       PipeFD{LS.PipeFD[0], LS.PipeFD[1]} {
112 
113   LS.FD = -1;
114   LS.SocketPath.clear();
115   LS.PipeFD[0] = -1;
116   LS.PipeFD[1] = -1;
117 }
118 
119 Expected<ListeningSocket> ListeningSocket::createUnix(StringRef SocketPath,
120                                                       int MaxBacklog) {
121 
122   // Handle instances where the target socket address already exists and
123   // differentiate between a preexisting file with and without a bound socket
124   //
125   // ::bind will return std::errc:address_in_use if a file at the socket address
126   // already exists (e.g., the file was not properly unlinked due to a crash)
127   // even if another socket has not yet binded to that address
128   if (llvm::sys::fs::exists(SocketPath)) {
129     Expected<int> MaybeFD = getSocketFD(SocketPath);
130     if (!MaybeFD) {
131 
132       // Regardless of the error, notify the caller that a file already exists
133       // at the desired socket address and that there is no bound socket at that
134       // address. The file must be removed before ::bind can use the address
135       consumeError(MaybeFD.takeError());
136       return llvm::make_error<StringError>(
137           std::make_error_code(std::errc::file_exists),
138           "Socket address unavailable");
139     }
140     ::close(std::move(*MaybeFD));
141 
142     // Notify caller that the provided socket address already has a bound socket
143     return llvm::make_error<StringError>(
144         std::make_error_code(std::errc::address_in_use),
145         "Socket address unavailable");
146   }
147 
148 #ifdef _WIN32
149   WSABalancer _;
150   SOCKET Socket = socket(AF_UNIX, SOCK_STREAM, 0);
151   if (Socket == INVALID_SOCKET)
152 #else
153   int Socket = socket(AF_UNIX, SOCK_STREAM, 0);
154   if (Socket == -1)
155 #endif
156     return llvm::make_error<StringError>(getLastSocketErrorCode(),
157                                          "socket create failed");
158 
159 #ifdef __CYGWIN__
160   // On Cygwin, UNIX sockets involve a handshake between connect and accept
161   // to enable SO_PEERCRED/getpeereid handling.  This necessitates accept being
162   // called before connect can return, but at least the tests in
163   // llvm/unittests/Support/raw_socket_stream_test do both on the same thread
164   // (first connect and then accept), resulting in a deadlock.  This call turns
165   // off the handshake (and SO_PEERCRED/getpeereid support).
166   setsockopt(Socket, SOL_SOCKET, SO_PEERCRED, NULL, 0);
167 #endif
168   struct sockaddr_un Addr = setSocketAddr(SocketPath);
169   if (::bind(Socket, (struct sockaddr *)&Addr, sizeof(Addr)) == -1) {
170     // Grab error code from call to ::bind before calling ::close
171     std::error_code EC = getLastSocketErrorCode();
172     ::close(Socket);
173     return llvm::make_error<StringError>(EC, "Bind error");
174   }
175 
176   // Mark socket as passive so incoming connections can be accepted
177   if (::listen(Socket, MaxBacklog) == -1)
178     return llvm::make_error<StringError>(getLastSocketErrorCode(),
179                                          "Listen error");
180 
181   int PipeFD[2];
182 #ifdef _WIN32
183   // Reserve 1 byte for the pipe and use default textmode
184   if (::_pipe(PipeFD, 1, 0) == -1)
185 #else
186   if (::pipe(PipeFD) == -1)
187 #endif // _WIN32
188     return llvm::make_error<StringError>(getLastSocketErrorCode(),
189                                          "pipe failed");
190 
191 #ifdef _WIN32
192   return ListeningSocket{_open_osfhandle(Socket, 0), SocketPath, PipeFD};
193 #else
194   return ListeningSocket{Socket, SocketPath, PipeFD};
195 #endif // _WIN32
196 }
197 
198 // If a file descriptor being monitored by ::poll is closed by another thread,
199 // the result is unspecified. In the case ::poll does not unblock and return,
200 // when ActiveFD is closed, you can provide another file descriptor via CancelFD
201 // that when written to will cause poll to return. Typically CancelFD is the
202 // read end of a unidirectional pipe.
203 //
204 // Timeout should be -1 to block indefinitly
205 //
206 // getActiveFD is a callback to handle ActiveFD's of std::atomic<int> and int
207 static std::error_code
208 manageTimeout(const std::chrono::milliseconds &Timeout,
209               const std::function<int()> &getActiveFD,
210               const std::optional<int> &CancelFD = std::nullopt) {
211   struct pollfd FD[2];
212   FD[0].events = POLLIN;
213 #ifdef _WIN32
214   SOCKET WinServerSock = _get_osfhandle(getActiveFD());
215   FD[0].fd = WinServerSock;
216 #else
217   FD[0].fd = getActiveFD();
218 #endif
219   uint8_t FDCount = 1;
220   if (CancelFD.has_value()) {
221     FD[1].events = POLLIN;
222     FD[1].fd = CancelFD.value();
223     FDCount++;
224   }
225 
226   // Keep track of how much time has passed in case ::poll or WSAPoll are
227   // interupted by a signal and need to be recalled
228   auto Start = std::chrono::steady_clock::now();
229   auto RemainingTimeout = Timeout;
230   int PollStatus = 0;
231   do {
232     // If Timeout is -1 then poll should block and RemainingTimeout does not
233     // need to be recalculated
234     if (PollStatus != 0 && Timeout != std::chrono::milliseconds(-1)) {
235       auto TotalElapsedTime =
236           std::chrono::duration_cast<std::chrono::milliseconds>(
237               std::chrono::steady_clock::now() - Start);
238 
239       if (TotalElapsedTime >= Timeout)
240         return std::make_error_code(std::errc::operation_would_block);
241 
242       RemainingTimeout = Timeout - TotalElapsedTime;
243     }
244 #ifdef _WIN32
245     PollStatus = WSAPoll(FD, FDCount, RemainingTimeout.count());
246   } while (PollStatus == SOCKET_ERROR &&
247            getLastSocketErrorCode() == std::errc::interrupted);
248 #else
249     PollStatus = ::poll(FD, FDCount, RemainingTimeout.count());
250   } while (PollStatus == -1 &&
251            getLastSocketErrorCode() == std::errc::interrupted);
252 #endif
253 
254   // If ActiveFD equals -1 or CancelFD has data to be read then the operation
255   // has been canceled by another thread
256   if (getActiveFD() == -1 || (CancelFD.has_value() && FD[1].revents & POLLIN))
257     return std::make_error_code(std::errc::operation_canceled);
258 #if _WIN32
259   if (PollStatus == SOCKET_ERROR)
260 #else
261   if (PollStatus == -1)
262 #endif
263     return getLastSocketErrorCode();
264   if (PollStatus == 0)
265     return std::make_error_code(std::errc::timed_out);
266   if (FD[0].revents & POLLNVAL)
267     return std::make_error_code(std::errc::bad_file_descriptor);
268   return std::error_code();
269 }
270 
271 Expected<std::unique_ptr<raw_socket_stream>>
272 ListeningSocket::accept(const std::chrono::milliseconds &Timeout) {
273   auto getActiveFD = [this]() -> int { return FD; };
274   std::error_code TimeoutErr = manageTimeout(Timeout, getActiveFD, PipeFD[0]);
275   if (TimeoutErr)
276     return llvm::make_error<StringError>(TimeoutErr, "Timeout error");
277 
278   int AcceptFD;
279 #ifdef _WIN32
280   SOCKET WinAcceptSock = ::accept(_get_osfhandle(FD), NULL, NULL);
281   AcceptFD = _open_osfhandle(WinAcceptSock, 0);
282 #else
283   AcceptFD = ::accept(FD, NULL, NULL);
284 #endif
285 
286   if (AcceptFD == -1)
287     return llvm::make_error<StringError>(getLastSocketErrorCode(),
288                                          "Socket accept failed");
289   return std::make_unique<raw_socket_stream>(AcceptFD);
290 }
291 
292 void ListeningSocket::shutdown() {
293   int ObservedFD = FD.load();
294 
295   if (ObservedFD == -1)
296     return;
297 
298   // If FD equals ObservedFD set FD to -1; If FD doesn't equal ObservedFD then
299   // another thread is responsible for shutdown so return
300   if (!FD.compare_exchange_strong(ObservedFD, -1))
301     return;
302 
303   ::close(ObservedFD);
304   ::unlink(SocketPath.c_str());
305 
306   // Ensure ::poll returns if shutdown is called by a separate thread
307   char Byte = 'A';
308   ssize_t written = ::write(PipeFD[1], &Byte, 1);
309 
310   // Ignore any write() error
311   (void)written;
312 }
313 
314 ListeningSocket::~ListeningSocket() {
315   shutdown();
316 
317   // Close the pipe's FDs in the destructor instead of within
318   // ListeningSocket::shutdown to avoid unnecessary synchronization issues that
319   // would occur as PipeFD's values would have to be changed to -1
320   //
321   // The move constructor sets PipeFD to -1
322   if (PipeFD[0] != -1)
323     ::close(PipeFD[0]);
324   if (PipeFD[1] != -1)
325     ::close(PipeFD[1]);
326 }
327 
328 //===----------------------------------------------------------------------===//
329 //  raw_socket_stream
330 //===----------------------------------------------------------------------===//
331 
332 raw_socket_stream::raw_socket_stream(int SocketFD)
333     : raw_fd_stream(SocketFD, true) {}
334 
335 raw_socket_stream::~raw_socket_stream() {}
336 
337 Expected<std::unique_ptr<raw_socket_stream>>
338 raw_socket_stream::createConnectedUnix(StringRef SocketPath) {
339 #ifdef _WIN32
340   WSABalancer _;
341 #endif // _WIN32
342   Expected<int> FD = getSocketFD(SocketPath);
343   if (!FD)
344     return FD.takeError();
345   return std::make_unique<raw_socket_stream>(*FD);
346 }
347 
348 ssize_t raw_socket_stream::read(char *Ptr, size_t Size,
349                                 const std::chrono::milliseconds &Timeout) {
350   auto getActiveFD = [this]() -> int { return this->get_fd(); };
351   std::error_code Err = manageTimeout(Timeout, getActiveFD);
352   // Mimic raw_fd_stream::read error handling behavior
353   if (Err) {
354     raw_fd_stream::error_detected(Err);
355     return -1;
356   }
357   return raw_fd_stream::read(Ptr, Size);
358 }
359