1 /* SPDX-License-Identifier: GPL-2.0 */
2
3 #ifndef __SOCKET_HELPERS__
4 #define __SOCKET_HELPERS__
5
6 #include <linux/vm_sockets.h>
7
/*
 * Mask that extracts the socket type from a sotype value (see its use in
 * create_pair()). Copied from include/linux/net.h.
 */
#define SOCK_TYPE_MASK 0xf

/*
 * Timeout, in seconds, for the waits done by xaccept_nonblock(),
 * xrecv_nonblock() and create_pair().
 */
#define IO_TIMEOUT_SEC 30
/* Size of the buffer FAIL_LIBBPF() formats libbpf error strings into. */
#define MAX_STRERR_LEN 256

/*
 * Workaround for older <linux/vm_sockets.h> that lack VMADDR_CID_LOCAL, the
 * CID init_addr_loopback_vsock() uses for its loopback address.
 */
#ifndef VMADDR_CID_LOCAL
#define VMADDR_CID_LOCAL 1
#endif
18
/*
 * Modelled on include/linux/cleanup.h.
 *
 * __get_and_null() - store @nullvalue into the lvalue @p and yield the value
 * @p held before. @p is evaluated exactly once (its address is taken).
 */
#define __get_and_null(p, nullvalue)				\
({								\
	__auto_type __gan_slot = &(p);				\
	__auto_type __gan_old = *__gan_slot;			\
								\
	*__gan_slot = (nullvalue);				\
	__gan_old;						\
})

/*
 * take_fd() - hand the descriptor stored in @fd to the caller, leaving the
 * negative -EBADF behind so close_fd() will skip it.
 */
#define take_fd(fd) __get_and_null(fd, -EBADF)
29
/*
 * Wrappers that fail the test on error and report it.
 *
 * _FAIL() reports the printf-style message through error_at_line() with
 * status 0, passing @errnum, the calling function's name and the line number,
 * and then invokes CHECK_FAIL(true).
 */
#define _FAIL(errnum, ...)						\
({									\
	error_at_line(0, (errnum), __func__, __LINE__, __VA_ARGS__);	\
	CHECK_FAIL(true);						\
})
/* Report a failure without an associated error number. */
#define FAIL(...) _FAIL(0, __VA_ARGS__)
/* Report a failure together with the current errno. */
#define FAIL_ERRNO(...) _FAIL(errno, __VA_ARGS__)
/* Report a failure for libbpf error @err, prefixed by @msg. */
#define FAIL_LIBBPF(err, msg)						\
({									\
	char __errstr[MAX_STRERR_LEN];					\
									\
	libbpf_strerror((err), __errstr, sizeof(__errstr));		\
	FAIL("%s: %s", (msg), __errstr);				\
})
45
46
/* accept() after waiting up to IO_TIMEOUT_SEC for a pending connection. */
#define xaccept_nonblock(fd, addr, len)				\
({								\
	int __new_fd = accept_timeout((fd), (addr), (len),	\
				      IO_TIMEOUT_SEC);		\
								\
	if (__new_fd == -1)					\
		FAIL_ERRNO("accept");				\
	__new_fd;						\
})
55
/* bind() that reports a failure; yields bind()'s return value. */
#define xbind(fd, addr, len)					\
({								\
	int __bind_rc = bind((fd), (addr), (len));		\
								\
	if (__bind_rc == -1)					\
		FAIL_ERRNO("bind");				\
	__bind_rc;						\
})
63
/* close() that reports a failure; yields close()'s return value. */
#define xclose(fd)						\
({								\
	int __close_rc = close((fd));				\
								\
	if (__close_rc == -1)					\
		FAIL_ERRNO("close");				\
	__close_rc;						\
})
71
/* connect() that reports a failure; yields connect()'s return value. */
#define xconnect(fd, addr, len)					\
({								\
	int __connect_rc = connect((fd), (addr), (len));	\
								\
	if (__connect_rc == -1)					\
		FAIL_ERRNO("connect");				\
	__connect_rc;						\
})
79
/* getsockname() that reports a failure; yields getsockname()'s return value. */
#define xgetsockname(fd, addr, len)				\
({								\
	int __gsn_rc = getsockname((fd), (addr), (len));	\
								\
	if (__gsn_rc == -1)					\
		FAIL_ERRNO("getsockname");			\
	__gsn_rc;						\
})
87
/* getsockopt() that reports a failure, naming the option in the message. */
#define xgetsockopt(fd, level, name, val, len)				\
({									\
	int __gso_rc = getsockopt((fd), (level), (name), (val), (len));	\
									\
	if (__gso_rc == -1)						\
		FAIL_ERRNO("getsockopt(" #name ")");			\
	__gso_rc;							\
})
95
/* listen() that reports a failure; yields listen()'s return value. */
#define xlisten(fd, backlog)					\
({								\
	int __listen_rc = listen((fd), (backlog));		\
								\
	if (__listen_rc == -1)					\
		FAIL_ERRNO("listen");				\
	__listen_rc;						\
})
103
/* setsockopt() that reports a failure, naming the option in the message. */
#define xsetsockopt(fd, level, name, val, len)				\
({									\
	int __sso_rc = setsockopt((fd), (level), (name), (val), (len));	\
									\
	if (__sso_rc == -1)						\
		FAIL_ERRNO("setsockopt(" #name ")");			\
	__sso_rc;							\
})
111
/* send() that reports a failure; yields the ssize_t send() returned. */
#define xsend(fd, buf, len, flags)				\
({								\
	ssize_t __sent = send((fd), (buf), (len), (flags));	\
								\
	if (__sent == -1)					\
		FAIL_ERRNO("send");				\
	__sent;							\
})
119
/* recv() after waiting up to IO_TIMEOUT_SEC for data; reports a failure. */
#define xrecv_nonblock(fd, buf, len, flags)				\
({									\
	ssize_t __rcvd = recv_timeout((fd), (buf), (len), (flags),	\
				      IO_TIMEOUT_SEC);			\
									\
	if (__rcvd == -1)						\
		FAIL_ERRNO("recv");					\
	__rcvd;								\
})
128
/* socket() that reports a failure; yields the new fd or -1. */
#define xsocket(family, sotype, flags)				\
({								\
	int __sock_fd = socket((family), (sotype), (flags));	\
								\
	if (__sock_fd == -1)					\
		FAIL_ERRNO("socket");				\
	__sock_fd;						\
})
136
/*
 * close_fd() - cleanup handler: close *@fd with xclose() when it holds a
 * non-negative descriptor. Negative values (e.g. the -EBADF left by
 * take_fd()) are skipped.
 */
static inline void close_fd(int *fd)
{
	if (*fd < 0)
		return;

	xclose(*fd);
}

/* Variable attribute: run close_fd() on the variable when it leaves scope. */
#define __close_fd __attribute__((cleanup(close_fd)))
144
/*
 * sockaddr() - view @ss as the generic struct sockaddr pointer that socket
 * calls take. The address is unchanged.
 */
static inline struct sockaddr *sockaddr(struct sockaddr_storage *ss)
{
	void *generic = ss;

	return generic;
}
149
/*
 * init_addr_loopback4() - zero @ss, then fill it with INADDR_LOOPBACK and
 * port 0 as a struct sockaddr_in; *@len receives that struct's size.
 */
static inline void init_addr_loopback4(struct sockaddr_storage *ss,
				       socklen_t *len)
{
	struct sockaddr_in *sin = (struct sockaddr_in *)ss;

	memset(ss, 0, sizeof(*ss));
	sin->sin_family = AF_INET;
	sin->sin_addr.s_addr = htonl(INADDR_LOOPBACK);
	sin->sin_port = 0;
	*len = sizeof(*sin);
}
160
/*
 * init_addr_loopback6() - zero @ss, then fill it with in6addr_loopback and
 * port 0 as a struct sockaddr_in6; *@len receives that struct's size.
 */
static inline void init_addr_loopback6(struct sockaddr_storage *ss,
				       socklen_t *len)
{
	struct sockaddr_in6 *sin6 = (struct sockaddr_in6 *)ss;

	memset(ss, 0, sizeof(*ss));
	sin6->sin6_family = AF_INET6;
	sin6->sin6_addr = in6addr_loopback;
	sin6->sin6_port = 0;
	*len = sizeof(*sin6);
}
171
/*
 * init_addr_loopback_vsock() - zero @ss, then fill it as a struct sockaddr_vm
 * with CID VMADDR_CID_LOCAL and port VMADDR_PORT_ANY; *@len receives that
 * struct's size.
 */
static inline void init_addr_loopback_vsock(struct sockaddr_storage *ss,
					    socklen_t *len)
{
	struct sockaddr_vm *svm = (struct sockaddr_vm *)ss;

	memset(ss, 0, sizeof(*ss));
	svm->svm_family = AF_VSOCK;
	svm->svm_cid = VMADDR_CID_LOCAL;
	svm->svm_port = VMADDR_PORT_ANY;
	*len = sizeof(*svm);
}
182
/*
 * init_addr_loopback() - fill @ss with the loopback address of @family and set
 * *@len to the size of that family's address structure.
 * @family: AF_INET, AF_INET6 or AF_VSOCK
 *
 * Any other family is reported with FAIL(). In that case *@len is set to 0,
 * so callers that go on to use it pass a well-defined (and rejected) length
 * rather than an indeterminate one.
 */
static inline void init_addr_loopback(int family, struct sockaddr_storage *ss,
				      socklen_t *len)
{
	switch (family) {
	case AF_INET:
		init_addr_loopback4(ss, len);
		return;
	case AF_INET6:
		init_addr_loopback6(ss, len);
		return;
	case AF_VSOCK:
		init_addr_loopback_vsock(ss, len);
		return;
	default:
		/* Never leave the output length indeterminate. */
		*len = 0;
		FAIL("unsupported address family %d", family);
	}
}
200
enable_reuseport(int s,int progfd)201 static inline int enable_reuseport(int s, int progfd)
202 {
203 int err, one = 1;
204
205 err = xsetsockopt(s, SOL_SOCKET, SO_REUSEPORT, &one, sizeof(one));
206 if (err)
207 return -1;
208 err = xsetsockopt(s, SOL_SOCKET, SO_ATTACH_REUSEPORT_EBPF, &progfd,
209 sizeof(progfd));
210 if (err)
211 return -1;
212
213 return 0;
214 }
215
socket_loopback_reuseport(int family,int sotype,int progfd)216 static inline int socket_loopback_reuseport(int family, int sotype, int progfd)
217 {
218 struct sockaddr_storage addr;
219 socklen_t len = 0;
220 int err, s;
221
222 init_addr_loopback(family, &addr, &len);
223
224 s = xsocket(family, sotype, 0);
225 if (s == -1)
226 return -1;
227
228 if (progfd >= 0)
229 enable_reuseport(s, progfd);
230
231 err = xbind(s, sockaddr(&addr), len);
232 if (err)
233 goto close;
234
235 if (sotype & SOCK_DGRAM)
236 return s;
237
238 err = xlisten(s, SOMAXCONN);
239 if (err)
240 goto close;
241
242 return s;
243 close:
244 xclose(s);
245 return -1;
246 }
247
/*
 * socket_loopback() - socket_loopback_reuseport() without attaching a
 * reuseport program.
 */
static inline int socket_loopback(int family, int sotype)
{
	const int no_reuseport_prog = -1;

	return socket_loopback_reuseport(family, sotype, no_reuseport_prog);
}
252
/*
 * poll_connect() - wait up to @timeout_sec seconds for an in-progress
 * connect() on @fd to finish.
 *
 * Returns 0 once @fd is writable and SO_ERROR reports no error. Otherwise
 * returns -1 with errno set to one of:
 * - ETIME on timeout;
 * - the pending socket error;
 * - EBADF/EINVAL when @fd cannot be placed in an fd_set;
 * - the error from select() or getsockopt().
 */
static inline int poll_connect(int fd, unsigned int timeout_sec)
{
	struct timeval timeout = { .tv_sec = timeout_sec };
	fd_set wfds;
	int r, eval;
	socklen_t esize = sizeof(eval);

	/* FD_SET() with an fd outside [0, FD_SETSIZE) is undefined behaviour. */
	if (fd < 0) {
		errno = EBADF;
		return -1;
	}
	if (fd >= FD_SETSIZE) {
		errno = EINVAL;
		return -1;
	}

	FD_ZERO(&wfds);
	FD_SET(fd, &wfds);

	r = select(fd + 1, NULL, &wfds, NULL, &timeout);
	if (r == 0)
		errno = ETIME;
	if (r != 1)
		return -1;

	if (getsockopt(fd, SOL_SOCKET, SO_ERROR, &eval, &esize) < 0)
		return -1;
	if (eval != 0) {
		/* Surface the asynchronous connect() failure through errno. */
		errno = eval;
		return -1;
	}

	return 0;
}
278
/*
 * poll_read() - wait up to @timeout_sec seconds for @fd to become readable.
 *
 * Returns 0 when @fd is readable. Otherwise returns -1 with errno set to one
 * of:
 * - ETIME on timeout;
 * - EBADF/EINVAL when @fd cannot be placed in an fd_set;
 * - the error from select().
 */
static inline int poll_read(int fd, unsigned int timeout_sec)
{
	struct timeval timeout = { .tv_sec = timeout_sec };
	fd_set rfds;
	int r;

	/* FD_SET() with an fd outside [0, FD_SETSIZE) is undefined behaviour. */
	if (fd < 0) {
		errno = EBADF;
		return -1;
	}
	if (fd >= FD_SETSIZE) {
		errno = EINVAL;
		return -1;
	}

	FD_ZERO(&rfds);
	FD_SET(fd, &rfds);

	r = select(fd + 1, &rfds, NULL, NULL, &timeout);
	if (r == 0)
		errno = ETIME;

	return r == 1 ? 0 : -1;
}
294
/*
 * accept_timeout() - accept() on @fd after waiting up to @timeout_sec seconds
 * for it to become readable (see poll_read()).
 *
 * Returns accept()'s result, or -1 with errno from poll_read() (ETIME on
 * timeout) when the wait fails.
 */
static inline int accept_timeout(int fd, struct sockaddr *addr, socklen_t *len,
				 unsigned int timeout_sec)
{
	return poll_read(fd, timeout_sec) ? -1 : accept(fd, addr, len);
}
303
/*
 * recv_timeout() - recv() from @fd after waiting up to @timeout_sec seconds
 * for it to become readable (see poll_read()).
 *
 * Returns recv()'s result converted to int, or -1 with errno from poll_read()
 * (ETIME on timeout) when the wait fails.
 */
static inline int recv_timeout(int fd, void *buf, size_t len, int flags,
			       unsigned int timeout_sec)
{
	return poll_read(fd, timeout_sec) ? -1 : recv(fd, buf, len, flags);
}
312
313
create_pair(int family,int sotype,int * p0,int * p1)314 static inline int create_pair(int family, int sotype, int *p0, int *p1)
315 {
316 __close_fd int s, c = -1, p = -1;
317 struct sockaddr_storage addr;
318 socklen_t len = sizeof(addr);
319 int err;
320
321 s = socket_loopback(family, sotype);
322 if (s < 0)
323 return s;
324
325 err = xgetsockname(s, sockaddr(&addr), &len);
326 if (err)
327 return err;
328
329 c = xsocket(family, sotype, 0);
330 if (c < 0)
331 return c;
332
333 err = connect(c, sockaddr(&addr), len);
334 if (err) {
335 if (errno != EINPROGRESS) {
336 FAIL_ERRNO("connect");
337 return err;
338 }
339
340 err = poll_connect(c, IO_TIMEOUT_SEC);
341 if (err) {
342 FAIL_ERRNO("poll_connect");
343 return err;
344 }
345 }
346
347 switch (sotype & SOCK_TYPE_MASK) {
348 case SOCK_DGRAM:
349 err = xgetsockname(c, sockaddr(&addr), &len);
350 if (err)
351 return err;
352
353 err = xconnect(s, sockaddr(&addr), len);
354 if (err)
355 return err;
356
357 *p0 = take_fd(s);
358 break;
359 case SOCK_STREAM:
360 case SOCK_SEQPACKET:
361 p = xaccept_nonblock(s, NULL, NULL);
362 if (p < 0)
363 return p;
364
365 *p0 = take_fd(p);
366 break;
367 default:
368 FAIL("Unsupported socket type %#x", sotype);
369 return -EOPNOTSUPP;
370 }
371
372 *p1 = take_fd(c);
373 return 0;
374 }
375
/*
 * create_socket_pairs() - create two independent socket pairs with
 * create_pair(): (*@c0, *@p0) and (*@c1, *@p1).
 *
 * Note the argument mapping: *@c0 / *@c1 receive create_pair()'s first output
 * (the server-side socket) and *@p0 / *@p1 its second (the client socket).
 *
 * Returns 0 on success, otherwise the error from create_pair(). If the second
 * pair cannot be created, the first pair is closed. *@c0 / *@p0 are then reset
 * to -EBADF so a caller's cleanup won't close those numbers a second time;
 * by then they may have been reused.
 */
static inline int create_socket_pairs(int family, int sotype, int *c0, int *c1,
				      int *p0, int *p1)
{
	int err;

	err = create_pair(family, sotype, c0, p0);
	if (err)
		return err;

	err = create_pair(family, sotype, c1, p1);
	if (err) {
		/* Don't hand back descriptors that are already closed. */
		close(take_fd(*c0));
		close(take_fd(*p0));
	}

	return err;
}
393
394 #endif // __SOCKET_HELPERS__
395