1 /* SPDX-License-Identifier: GPL-2.0 */
2
3 #ifndef __SOCKET_HELPERS__
4 #define __SOCKET_HELPERS__
5
#include <limits.h>
#include <poll.h>
#include <sys/un.h>
#include <linux/vm_sockets.h>
8
/* include/linux/net.h */
/*
 * Extracts the base socket type (SOCK_STREAM, SOCK_DGRAM, ...) from a sotype
 * value that may also carry creation flags; create_pair() switches on
 * (sotype & SOCK_TYPE_MASK).
 */
#define SOCK_TYPE_MASK 0xf

/* Seconds to wait in the *_nonblock wrappers and in create_pair()'s connect. */
#define IO_TIMEOUT_SEC 30
/* Size of the stack buffer FAIL_LIBBPF() hands to libbpf_strerror(). */
#define MAX_STRERR_LEN 256

/* workaround for older vm_sockets.h */
/* Older uapi headers do not define VMADDR_CID_LOCAL; fall back to 1. */
#ifndef VMADDR_CID_LOCAL
#define VMADDR_CID_LOCAL 1
#endif
19
/* include/linux/compiler_types.h */
/* Before C23 'auto' does not infer types; map it to GNU __auto_type. */
#if __STDC_VERSION__ < 202311L && !defined(auto)
# define auto __auto_type
#endif

/* include/linux/cleanup.h */
/*
 * Evaluate to the current value of lvalue @p and overwrite @p with
 * @nullvalue. @p is evaluated only once (its address is taken), so the
 * macro is safe to use on expressions like *ptr.
 */
#define __get_and_null(p, nullvalue) \
	({ \
		auto __ptr = &(p); \
		auto __val = *__ptr; \
		*__ptr = nullvalue; \
		__val; \
	})

/*
 * Move ownership of an fd out of a variable: yields the fd and leaves
 * -EBADF behind, so a __close_fd cleanup on that variable (close_fd() skips
 * negative values) no longer closes it.
 */
#define take_fd(fd) __get_and_null(fd, -EBADF)
35
/* Wrappers that fail the test on error and report it. */

/*
 * Report a failure through error_at_line() and mark the current test failed
 * with CHECK_FAIL() (provided by the test framework, not this header).
 * __func__ is passed where error_at_line() expects a file name, so messages
 * are tagged "<function>:<line>". Status 0 makes error_at_line() return
 * rather than exit; a non-zero @errnum appends the matching strerror() text.
 */
#define _FAIL(errnum, fmt...) \
	({ \
		error_at_line(0, (errnum), __func__, __LINE__, fmt); \
		CHECK_FAIL(true); \
	})
/* Report a failure with no errno text. */
#define FAIL(fmt...) _FAIL(0, fmt)
/* Report a failure including the text for the current errno. */
#define FAIL_ERRNO(fmt...) _FAIL(errno, fmt)
/* Report "<msg>: <libbpf error string>" for libbpf error code @err. */
#define FAIL_LIBBPF(err, msg) \
	({ \
		char __buf[MAX_STRERR_LEN]; \
		libbpf_strerror((err), __buf, sizeof(__buf)); \
		FAIL("%s: %s", (msg), __buf); \
	})
51
52
/*
 * Each x*() wrapper below evaluates to the return value of the call it wraps.
 * When that value is -1, the failure is reported with FAIL_ERRNO() and the
 * test is marked failed. The caller still has to check the returned value
 * and decide how to proceed.
 */

/* accept() on @fd after waiting up to IO_TIMEOUT_SEC seconds for it to become
 * readable (see accept_timeout()). */
#define xaccept_nonblock(fd, addr, len) \
	({ \
		int __ret = \
			accept_timeout((fd), (addr), (len), IO_TIMEOUT_SEC); \
		if (__ret == -1) \
			FAIL_ERRNO("accept"); \
		__ret; \
	})

#define xbind(fd, addr, len) \
	({ \
		int __ret = bind((fd), (addr), (len)); \
		if (__ret == -1) \
			FAIL_ERRNO("bind"); \
		__ret; \
	})

#define xclose(fd) \
	({ \
		int __ret = close((fd)); \
		if (__ret == -1) \
			FAIL_ERRNO("close"); \
		__ret; \
	})

#define xconnect(fd, addr, len) \
	({ \
		int __ret = connect((fd), (addr), (len)); \
		if (__ret == -1) \
			FAIL_ERRNO("connect"); \
		__ret; \
	})

#define xgetsockname(fd, addr, len) \
	({ \
		int __ret = getsockname((fd), (addr), (len)); \
		if (__ret == -1) \
			FAIL_ERRNO("getsockname"); \
		__ret; \
	})

/* The option name token is stringified into the failure message. */
#define xgetsockopt(fd, level, name, val, len) \
	({ \
		int __ret = getsockopt((fd), (level), (name), (val), (len)); \
		if (__ret == -1) \
			FAIL_ERRNO("getsockopt(" #name ")"); \
		__ret; \
	})

#define xlisten(fd, backlog) \
	({ \
		int __ret = listen((fd), (backlog)); \
		if (__ret == -1) \
			FAIL_ERRNO("listen"); \
		__ret; \
	})

/* The option name token is stringified into the failure message. */
#define xsetsockopt(fd, level, name, val, len) \
	({ \
		int __ret = setsockopt((fd), (level), (name), (val), (len)); \
		if (__ret == -1) \
			FAIL_ERRNO("setsockopt(" #name ")"); \
		__ret; \
	})

/* Short sends are not detected; only a -1 return is reported. */
#define xsend(fd, buf, len, flags) \
	({ \
		ssize_t __ret = send((fd), (buf), (len), (flags)); \
		if (__ret == -1) \
			FAIL_ERRNO("send"); \
		__ret; \
	})

/* recv() after waiting up to IO_TIMEOUT_SEC seconds for @fd to become
 * readable (see recv_timeout()). A return of 0 is not reported. */
#define xrecv_nonblock(fd, buf, len, flags) \
	({ \
		ssize_t __ret = recv_timeout((fd), (buf), (len), (flags), \
					     IO_TIMEOUT_SEC); \
		if (__ret == -1) \
			FAIL_ERRNO("recv"); \
		__ret; \
	})

#define xsocket(family, sotype, flags) \
	({ \
		int __ret = socket(family, sotype, flags); \
		if (__ret == -1) \
			FAIL_ERRNO("socket"); \
		__ret; \
	})
142
/*
 * Cleanup handler behind __close_fd. Closes *fd unless it holds a negative
 * value (never opened, or ownership already moved out with take_fd()).
 * Errors from close() are reported via xclose().
 */
static inline void close_fd(int *fd)
{
	if (*fd < 0)
		return;

	xclose(*fd);
}

/* Declare an int fd variable that close_fd() closes when it goes out of scope. */
#define __close_fd __attribute__((cleanup(close_fd)))
150
/* View a sockaddr_storage as the generic struct sockaddr taken by the socket API. */
static inline struct sockaddr *sockaddr(struct sockaddr_storage *ss)
{
	struct sockaddr *generic = (struct sockaddr *)ss;

	return generic;
}
155
/*
 * Fill *ss with the IPv4 loopback address 127.0.0.1 and port 0. The whole
 * storage is zeroed first. *len is set to sizeof(struct sockaddr_in).
 */
static inline void init_addr_loopback4(struct sockaddr_storage *ss,
				       socklen_t *len)
{
	struct sockaddr_in *sin = (struct sockaddr_in *)ss;

	memset(ss, 0, sizeof(*ss));
	sin->sin_family = AF_INET;
	sin->sin_addr.s_addr = htonl(INADDR_LOOPBACK);
	sin->sin_port = 0;
	*len = sizeof(*sin);
}
166
/*
 * Fill *ss with the IPv6 loopback address ::1 and port 0. The whole storage
 * is zeroed first. *len is set to sizeof(struct sockaddr_in6).
 */
static inline void init_addr_loopback6(struct sockaddr_storage *ss,
				       socklen_t *len)
{
	struct sockaddr_in6 *sin6 = (struct sockaddr_in6 *)ss;

	memset(ss, 0, sizeof(*ss));
	sin6->sin6_family = AF_INET6;
	sin6->sin6_addr = in6addr_loopback;
	sin6->sin6_port = 0;
	*len = sizeof(*sin6);
}
177
/*
 * Prepare an AF_UNIX address with no path. *len covers only the family
 * field; on Linux, binding with such a length requests autobind. The whole
 * storage is zeroed first.
 */
static inline void init_addr_loopback_unix(struct sockaddr_storage *ss,
					   socklen_t *len)
{
	struct sockaddr_un *un_addr = (struct sockaddr_un *)ss;

	memset(ss, 0, sizeof(*ss));
	un_addr->sun_family = AF_UNIX;
	*len = sizeof(sa_family_t);
}
186
/*
 * Prepare an AF_VSOCK address for the local CID (VMADDR_CID_LOCAL) with
 * VMADDR_PORT_ANY. The whole storage is zeroed first. *len is set to
 * sizeof(struct sockaddr_vm).
 */
static inline void init_addr_loopback_vsock(struct sockaddr_storage *ss,
					    socklen_t *len)
{
	struct sockaddr_vm *svm = (struct sockaddr_vm *)ss;

	memset(ss, 0, sizeof(*ss));
	svm->svm_family = AF_VSOCK;
	svm->svm_cid = VMADDR_CID_LOCAL;
	svm->svm_port = VMADDR_PORT_ANY;
	*len = sizeof(*svm);
}
197
/*
 * Fill *ss with the loopback address for @family (AF_INET, AF_INET6, AF_UNIX
 * or AF_VSOCK) and store its length in *len.
 *
 * An unsupported family fails the test. In that case *ss is zeroed and *len
 * is set to 0, so callers never read uninitialized output (create_pair()
 * passes a len it has not initialized).
 */
static inline void init_addr_loopback(int family, struct sockaddr_storage *ss,
				      socklen_t *len)
{
	switch (family) {
	case AF_INET:
		init_addr_loopback4(ss, len);
		return;
	case AF_INET6:
		init_addr_loopback6(ss, len);
		return;
	case AF_UNIX:
		init_addr_loopback_unix(ss, len);
		return;
	case AF_VSOCK:
		init_addr_loopback_vsock(ss, len);
		return;
	default:
		memset(ss, 0, sizeof(*ss));
		*len = 0;
		FAIL("unsupported address family %d", family);
	}
}
218
enable_reuseport(int s,int progfd)219 static inline int enable_reuseport(int s, int progfd)
220 {
221 int err, one = 1;
222
223 err = xsetsockopt(s, SOL_SOCKET, SO_REUSEPORT, &one, sizeof(one));
224 if (err)
225 return -1;
226 err = xsetsockopt(s, SOL_SOCKET, SO_ATTACH_REUSEPORT_EBPF, &progfd,
227 sizeof(progfd));
228 if (err)
229 return -1;
230
231 return 0;
232 }
233
socket_loopback_reuseport(int family,int sotype,int progfd)234 static inline int socket_loopback_reuseport(int family, int sotype, int progfd)
235 {
236 struct sockaddr_storage addr;
237 socklen_t len = 0;
238 int err, s;
239
240 init_addr_loopback(family, &addr, &len);
241
242 s = xsocket(family, sotype, 0);
243 if (s == -1)
244 return -1;
245
246 if (progfd >= 0)
247 enable_reuseport(s, progfd);
248
249 err = xbind(s, sockaddr(&addr), len);
250 if (err)
251 goto close;
252
253 if (sotype & SOCK_DGRAM)
254 return s;
255
256 err = xlisten(s, SOMAXCONN);
257 if (err)
258 goto close;
259
260 return s;
261 close:
262 xclose(s);
263 return -1;
264 }
265
/* Same as socket_loopback_reuseport() with no reuseport program attached. */
static inline int socket_loopback(int family, int sotype)
{
	const int no_reuseport_prog = -1;

	return socket_loopback_reuseport(family, sotype, no_reuseport_prog);
}
270
/*
 * Wait up to @timeout_sec seconds for a non-blocking connect() on @fd to
 * finish, then fetch its outcome with SO_ERROR.
 *
 * Returns 0 when the connection is established. Returns -1 with errno set
 * when:
 *   - the wait timed out (ETIME);
 *   - @fd is not an open descriptor (EBADF);
 *   - the connect failed (errno is the pending socket error);
 *   - poll() or getsockopt() itself failed (their errno is kept).
 *
 * poll() is used rather than select(): FD_SET() on a descriptor that is
 * negative or >= FD_SETSIZE writes outside the fd_set.
 */
static inline int poll_connect(int fd, unsigned int timeout_sec)
{
	struct pollfd pfd = { .fd = fd, .events = POLLOUT };
	/* Clamp so the seconds-to-milliseconds conversion cannot overflow. */
	int timeout_ms = timeout_sec > (unsigned int)(INT_MAX / 1000) ?
			 INT_MAX : (int)(timeout_sec * 1000);
	socklen_t esize;
	int r, eval;

	r = poll(&pfd, 1, timeout_ms);
	if (r == 0)
		errno = ETIME;
	if (r != 1)
		return -1;

	/* select() reported EBADF for a closed fd; keep that contract. */
	if (pfd.revents & POLLNVAL) {
		errno = EBADF;
		return -1;
	}

	esize = sizeof(eval);
	if (getsockopt(fd, SOL_SOCKET, SO_ERROR, &eval, &esize) < 0)
		return -1;
	if (eval != 0) {
		errno = eval;
		return -1;
	}

	return 0;
}
296
/*
 * Wait up to @timeout_sec seconds for @fd to become readable. This includes
 * a pending connection on a listening socket and EOF/error conditions.
 *
 * Returns 0 when @fd is ready. Returns -1 with errno set when:
 *   - the wait timed out (ETIME);
 *   - @fd is not an open descriptor (EBADF);
 *   - poll() itself failed (its errno is kept).
 *
 * poll() is used rather than select(): FD_SET() on a descriptor that is
 * negative or >= FD_SETSIZE writes outside the fd_set.
 */
static inline int poll_read(int fd, unsigned int timeout_sec)
{
	struct pollfd pfd = { .fd = fd, .events = POLLIN };
	/* Clamp so the seconds-to-milliseconds conversion cannot overflow. */
	int timeout_ms = timeout_sec > (unsigned int)(INT_MAX / 1000) ?
			 INT_MAX : (int)(timeout_sec * 1000);
	int r;

	r = poll(&pfd, 1, timeout_ms);
	if (r == 0)
		errno = ETIME;
	if (r != 1)
		return -1;

	/* select() reported EBADF for a closed fd; keep that contract. */
	if (pfd.revents & POLLNVAL) {
		errno = EBADF;
		return -1;
	}

	return 0;
}
312
/*
 * Wait up to @timeout_sec seconds for @fd to become readable (poll_read()),
 * then accept() on it. Returns accept()'s result, or -1 with errno set by
 * poll_read() (ETIME on timeout) if the wait failed.
 */
static inline int accept_timeout(int fd, struct sockaddr *addr, socklen_t *len,
				 unsigned int timeout_sec)
{
	int not_ready = poll_read(fd, timeout_sec);

	return not_ready ? -1 : accept(fd, addr, len);
}
321
/*
 * Wait up to @timeout_sec seconds for @fd to become readable (poll_read()),
 * then recv() from it. Returns recv()'s result converted to int, or -1 with
 * errno set by poll_read() (ETIME on timeout) if the wait failed.
 */
static inline int recv_timeout(int fd, void *buf, size_t len, int flags,
			       unsigned int timeout_sec)
{
	ssize_t received = -1;

	if (!poll_read(fd, timeout_sec))
		received = recv(fd, buf, len, flags);

	return received;
}
330
331
create_pair(int family,int sotype,int * p0,int * p1)332 static inline int create_pair(int family, int sotype, int *p0, int *p1)
333 {
334 __close_fd int s, c = -1, p = -1;
335 struct sockaddr_storage addr;
336 socklen_t len;
337 int err;
338
339 s = socket_loopback(family, sotype);
340 if (s < 0)
341 return s;
342
343 c = xsocket(family, sotype, 0);
344 if (c < 0)
345 return c;
346
347 init_addr_loopback(family, &addr, &len);
348 err = xbind(c, sockaddr(&addr), len);
349 if (err)
350 return err;
351
352 len = sizeof(addr);
353 err = xgetsockname(s, sockaddr(&addr), &len);
354 if (err)
355 return err;
356
357 err = connect(c, sockaddr(&addr), len);
358 if (err) {
359 if (errno != EINPROGRESS) {
360 FAIL_ERRNO("connect");
361 return err;
362 }
363
364 err = poll_connect(c, IO_TIMEOUT_SEC);
365 if (err) {
366 FAIL_ERRNO("poll_connect");
367 return err;
368 }
369 }
370
371 switch (sotype & SOCK_TYPE_MASK) {
372 case SOCK_DGRAM:
373 err = xgetsockname(c, sockaddr(&addr), &len);
374 if (err)
375 return err;
376
377 err = xconnect(s, sockaddr(&addr), len);
378 if (err)
379 return err;
380
381 *p0 = take_fd(s);
382 break;
383 case SOCK_STREAM:
384 case SOCK_SEQPACKET:
385 p = xaccept_nonblock(s, NULL, NULL);
386 if (p < 0)
387 return p;
388
389 *p0 = take_fd(p);
390 break;
391 default:
392 FAIL("Unsupported socket type %#x", sotype);
393 return -EOPNOTSUPP;
394 }
395
396 *p1 = take_fd(c);
397 return 0;
398 }
399
/*
 * Create two independent pairs with create_pair(), filling (*c0, *p0) and
 * then (*c1, *p1). Returns 0 on success. If the second pair cannot be
 * created, both fds of the first pair are closed and create_pair()'s error
 * is returned. *c0/*p0 still hold the now-closed fd numbers in that case.
 */
static inline int create_socket_pairs(int family, int sotype, int *c0, int *c1,
				      int *p0, int *p1)
{
	int ret = create_pair(family, sotype, c0, p0);

	if (ret == 0) {
		ret = create_pair(family, sotype, c1, p1);
		if (ret != 0) {
			close(*c0);
			close(*p0);
		}
	}

	return ret;
}
417
socket_kind_to_str(int sock_fd)418 static inline const char *socket_kind_to_str(int sock_fd)
419 {
420 socklen_t opt_len;
421 int domain, type;
422
423 opt_len = sizeof(domain);
424 if (getsockopt(sock_fd, SOL_SOCKET, SO_DOMAIN, &domain, &opt_len))
425 FAIL_ERRNO("getsockopt(SO_DOMAIN)");
426
427 opt_len = sizeof(type);
428 if (getsockopt(sock_fd, SOL_SOCKET, SO_TYPE, &type, &opt_len))
429 FAIL_ERRNO("getsockopt(SO_TYPE)");
430
431 switch (domain) {
432 case AF_INET:
433 switch (type) {
434 case SOCK_STREAM:
435 return "tcp4";
436 case SOCK_DGRAM:
437 return "udp4";
438 }
439 break;
440 case AF_INET6:
441 switch (type) {
442 case SOCK_STREAM:
443 return "tcp6";
444 case SOCK_DGRAM:
445 return "udp6";
446 }
447 break;
448 case AF_UNIX:
449 switch (type) {
450 case SOCK_STREAM:
451 return "u_str";
452 case SOCK_DGRAM:
453 return "u_dgr";
454 case SOCK_SEQPACKET:
455 return "u_seq";
456 }
457 break;
458 case AF_VSOCK:
459 switch (type) {
460 case SOCK_STREAM:
461 return "v_str";
462 case SOCK_DGRAM:
463 return "v_dgr";
464 case SOCK_SEQPACKET:
465 return "v_seq";
466 }
467 break;
468 }
469
470 return "???";
471 }
472
473 #endif // __SOCKET_HELPERS__
474