xref: /linux/tools/testing/selftests/bpf/prog_tests/socket_helpers.h (revision c752c21c90b808a059ae8e0070ff7566a65f8577)
1 /* SPDX-License-Identifier: GPL-2.0 */
2 
3 #ifndef __SOCKET_HELPERS__
4 #define __SOCKET_HELPERS__
5 
6 #include <sys/un.h>
7 #include <linux/vm_sockets.h>
8 
9 /* include/linux/net.h */
10 #define SOCK_TYPE_MASK 0xf
11 
12 #define IO_TIMEOUT_SEC 30
13 #define MAX_STRERR_LEN 256
14 
15 /* workaround for older vm_sockets.h */
16 #ifndef VMADDR_CID_LOCAL
17 #define VMADDR_CID_LOCAL 1
18 #endif
19 
20 /* include/linux/compiler_types.h */
21 #if __STDC_VERSION__ < 202311L && !defined(auto)
22 # define auto __auto_type
23 #endif
24 
25 /* include/linux/cleanup.h */
26 #define __get_and_null(p, nullvalue)                                           \
27 	({                                                                     \
28 		auto __ptr = &(p);					       \
29 		auto __val = *__ptr;                                           \
30 		*__ptr = nullvalue;                                            \
31 		__val;                                                         \
32 	})
33 
34 #define take_fd(fd) __get_and_null(fd, -EBADF)
35 
36 /* Wrappers that fail the test on error and report it. */
37 
38 #define _FAIL(errnum, fmt...)                                                  \
39 	({                                                                     \
40 		error_at_line(0, (errnum), __func__, __LINE__, fmt);           \
41 		CHECK_FAIL(true);                                              \
42 	})
43 #define FAIL(fmt...) _FAIL(0, fmt)
44 #define FAIL_ERRNO(fmt...) _FAIL(errno, fmt)
45 #define FAIL_LIBBPF(err, msg)                                                  \
46 	({                                                                     \
47 		char __buf[MAX_STRERR_LEN];                                    \
48 		libbpf_strerror((err), __buf, sizeof(__buf));                  \
49 		FAIL("%s: %s", (msg), __buf);                                  \
50 	})
51 
52 
53 #define xaccept_nonblock(fd, addr, len)                                        \
54 	({                                                                     \
55 		int __ret =                                                    \
56 			accept_timeout((fd), (addr), (len), IO_TIMEOUT_SEC);   \
57 		if (__ret == -1)                                               \
58 			FAIL_ERRNO("accept");                                  \
59 		__ret;                                                         \
60 	})
61 
62 #define xbind(fd, addr, len)                                                   \
63 	({                                                                     \
64 		int __ret = bind((fd), (addr), (len));                         \
65 		if (__ret == -1)                                               \
66 			FAIL_ERRNO("bind");                                    \
67 		__ret;                                                         \
68 	})
69 
70 #define xclose(fd)                                                             \
71 	({                                                                     \
72 		int __ret = close((fd));                                       \
73 		if (__ret == -1)                                               \
74 			FAIL_ERRNO("close");                                   \
75 		__ret;                                                         \
76 	})
77 
78 #define xconnect(fd, addr, len)                                                \
79 	({                                                                     \
80 		int __ret = connect((fd), (addr), (len));                      \
81 		if (__ret == -1)                                               \
82 			FAIL_ERRNO("connect");                                 \
83 		__ret;                                                         \
84 	})
85 
86 #define xgetsockname(fd, addr, len)                                            \
87 	({                                                                     \
88 		int __ret = getsockname((fd), (addr), (len));                  \
89 		if (__ret == -1)                                               \
90 			FAIL_ERRNO("getsockname");                             \
91 		__ret;                                                         \
92 	})
93 
94 #define xgetsockopt(fd, level, name, val, len)                                 \
95 	({                                                                     \
96 		int __ret = getsockopt((fd), (level), (name), (val), (len));   \
97 		if (__ret == -1)                                               \
98 			FAIL_ERRNO("getsockopt(" #name ")");                   \
99 		__ret;                                                         \
100 	})
101 
102 #define xlisten(fd, backlog)                                                   \
103 	({                                                                     \
104 		int __ret = listen((fd), (backlog));                           \
105 		if (__ret == -1)                                               \
106 			FAIL_ERRNO("listen");                                  \
107 		__ret;                                                         \
108 	})
109 
110 #define xsetsockopt(fd, level, name, val, len)                                 \
111 	({                                                                     \
112 		int __ret = setsockopt((fd), (level), (name), (val), (len));   \
113 		if (__ret == -1)                                               \
114 			FAIL_ERRNO("setsockopt(" #name ")");                   \
115 		__ret;                                                         \
116 	})
117 
118 #define xsend(fd, buf, len, flags)                                             \
119 	({                                                                     \
120 		ssize_t __ret = send((fd), (buf), (len), (flags));             \
121 		if (__ret == -1)                                               \
122 			FAIL_ERRNO("send");                                    \
123 		__ret;                                                         \
124 	})
125 
126 #define xrecv_nonblock(fd, buf, len, flags)                                    \
127 	({                                                                     \
128 		ssize_t __ret = recv_timeout((fd), (buf), (len), (flags),      \
129 					     IO_TIMEOUT_SEC);                  \
130 		if (__ret == -1)                                               \
131 			FAIL_ERRNO("recv");                                    \
132 		__ret;                                                         \
133 	})
134 
135 #define xsocket(family, sotype, flags)                                         \
136 	({                                                                     \
137 		int __ret = socket(family, sotype, flags);                     \
138 		if (__ret == -1)                                               \
139 			FAIL_ERRNO("socket");                                  \
140 		__ret;                                                         \
141 	})
142 
close_fd(int * fd)143 static inline void close_fd(int *fd)
144 {
145 	if (*fd >= 0)
146 		xclose(*fd);
147 }
148 
149 #define __close_fd __attribute__((cleanup(close_fd)))
150 
sockaddr(struct sockaddr_storage * ss)151 static inline struct sockaddr *sockaddr(struct sockaddr_storage *ss)
152 {
153 	return (struct sockaddr *)ss;
154 }
155 
init_addr_loopback4(struct sockaddr_storage * ss,socklen_t * len)156 static inline void init_addr_loopback4(struct sockaddr_storage *ss,
157 				       socklen_t *len)
158 {
159 	struct sockaddr_in *addr4 = memset(ss, 0, sizeof(*ss));
160 
161 	addr4->sin_family = AF_INET;
162 	addr4->sin_port = 0;
163 	addr4->sin_addr.s_addr = htonl(INADDR_LOOPBACK);
164 	*len = sizeof(*addr4);
165 }
166 
init_addr_loopback6(struct sockaddr_storage * ss,socklen_t * len)167 static inline void init_addr_loopback6(struct sockaddr_storage *ss,
168 				       socklen_t *len)
169 {
170 	struct sockaddr_in6 *addr6 = memset(ss, 0, sizeof(*ss));
171 
172 	addr6->sin6_family = AF_INET6;
173 	addr6->sin6_port = 0;
174 	addr6->sin6_addr = in6addr_loopback;
175 	*len = sizeof(*addr6);
176 }
177 
init_addr_loopback_unix(struct sockaddr_storage * ss,socklen_t * len)178 static inline void init_addr_loopback_unix(struct sockaddr_storage *ss,
179 					   socklen_t *len)
180 {
181 	struct sockaddr_un *addr = memset(ss, 0, sizeof(*ss));
182 
183 	addr->sun_family = AF_UNIX;
184 	*len = sizeof(sa_family_t);
185 }
186 
init_addr_loopback_vsock(struct sockaddr_storage * ss,socklen_t * len)187 static inline void init_addr_loopback_vsock(struct sockaddr_storage *ss,
188 					    socklen_t *len)
189 {
190 	struct sockaddr_vm *addr = memset(ss, 0, sizeof(*ss));
191 
192 	addr->svm_family = AF_VSOCK;
193 	addr->svm_port = VMADDR_PORT_ANY;
194 	addr->svm_cid = VMADDR_CID_LOCAL;
195 	*len = sizeof(*addr);
196 }
197 
init_addr_loopback(int family,struct sockaddr_storage * ss,socklen_t * len)198 static inline void init_addr_loopback(int family, struct sockaddr_storage *ss,
199 				      socklen_t *len)
200 {
201 	switch (family) {
202 	case AF_INET:
203 		init_addr_loopback4(ss, len);
204 		return;
205 	case AF_INET6:
206 		init_addr_loopback6(ss, len);
207 		return;
208 	case AF_UNIX:
209 		init_addr_loopback_unix(ss, len);
210 		return;
211 	case AF_VSOCK:
212 		init_addr_loopback_vsock(ss, len);
213 		return;
214 	default:
215 		FAIL("unsupported address family %d", family);
216 	}
217 }
218 
enable_reuseport(int s,int progfd)219 static inline int enable_reuseport(int s, int progfd)
220 {
221 	int err, one = 1;
222 
223 	err = xsetsockopt(s, SOL_SOCKET, SO_REUSEPORT, &one, sizeof(one));
224 	if (err)
225 		return -1;
226 	err = xsetsockopt(s, SOL_SOCKET, SO_ATTACH_REUSEPORT_EBPF, &progfd,
227 			  sizeof(progfd));
228 	if (err)
229 		return -1;
230 
231 	return 0;
232 }
233 
socket_loopback_reuseport(int family,int sotype,int progfd)234 static inline int socket_loopback_reuseport(int family, int sotype, int progfd)
235 {
236 	struct sockaddr_storage addr;
237 	socklen_t len = 0;
238 	int err, s;
239 
240 	init_addr_loopback(family, &addr, &len);
241 
242 	s = xsocket(family, sotype, 0);
243 	if (s == -1)
244 		return -1;
245 
246 	if (progfd >= 0)
247 		enable_reuseport(s, progfd);
248 
249 	err = xbind(s, sockaddr(&addr), len);
250 	if (err)
251 		goto close;
252 
253 	if (sotype & SOCK_DGRAM)
254 		return s;
255 
256 	err = xlisten(s, SOMAXCONN);
257 	if (err)
258 		goto close;
259 
260 	return s;
261 close:
262 	xclose(s);
263 	return -1;
264 }
265 
socket_loopback(int family,int sotype)266 static inline int socket_loopback(int family, int sotype)
267 {
268 	return socket_loopback_reuseport(family, sotype, -1);
269 }
270 
poll_connect(int fd,unsigned int timeout_sec)271 static inline int poll_connect(int fd, unsigned int timeout_sec)
272 {
273 	struct timeval timeout = { .tv_sec = timeout_sec };
274 	fd_set wfds;
275 	int r, eval;
276 	socklen_t esize = sizeof(eval);
277 
278 	FD_ZERO(&wfds);
279 	FD_SET(fd, &wfds);
280 
281 	r = select(fd + 1, NULL, &wfds, NULL, &timeout);
282 	if (r == 0)
283 		errno = ETIME;
284 	if (r != 1)
285 		return -1;
286 
287 	if (getsockopt(fd, SOL_SOCKET, SO_ERROR, &eval, &esize) < 0)
288 		return -1;
289 	if (eval != 0) {
290 		errno = eval;
291 		return -1;
292 	}
293 
294 	return 0;
295 }
296 
poll_read(int fd,unsigned int timeout_sec)297 static inline int poll_read(int fd, unsigned int timeout_sec)
298 {
299 	struct timeval timeout = { .tv_sec = timeout_sec };
300 	fd_set rfds;
301 	int r;
302 
303 	FD_ZERO(&rfds);
304 	FD_SET(fd, &rfds);
305 
306 	r = select(fd + 1, &rfds, NULL, NULL, &timeout);
307 	if (r == 0)
308 		errno = ETIME;
309 
310 	return r == 1 ? 0 : -1;
311 }
312 
accept_timeout(int fd,struct sockaddr * addr,socklen_t * len,unsigned int timeout_sec)313 static inline int accept_timeout(int fd, struct sockaddr *addr, socklen_t *len,
314 				 unsigned int timeout_sec)
315 {
316 	if (poll_read(fd, timeout_sec))
317 		return -1;
318 
319 	return accept(fd, addr, len);
320 }
321 
recv_timeout(int fd,void * buf,size_t len,int flags,unsigned int timeout_sec)322 static inline int recv_timeout(int fd, void *buf, size_t len, int flags,
323 			       unsigned int timeout_sec)
324 {
325 	if (poll_read(fd, timeout_sec))
326 		return -1;
327 
328 	return recv(fd, buf, len, flags);
329 }
330 
331 
create_pair(int family,int sotype,int * p0,int * p1)332 static inline int create_pair(int family, int sotype, int *p0, int *p1)
333 {
334 	__close_fd int s, c = -1, p = -1;
335 	struct sockaddr_storage addr;
336 	socklen_t len;
337 	int err;
338 
339 	s = socket_loopback(family, sotype);
340 	if (s < 0)
341 		return s;
342 
343 	c = xsocket(family, sotype, 0);
344 	if (c < 0)
345 		return c;
346 
347 	init_addr_loopback(family, &addr, &len);
348 	err = xbind(c, sockaddr(&addr), len);
349 	if (err)
350 		return err;
351 
352 	len = sizeof(addr);
353 	err = xgetsockname(s, sockaddr(&addr), &len);
354 	if (err)
355 		return err;
356 
357 	err = connect(c, sockaddr(&addr), len);
358 	if (err) {
359 		if (errno != EINPROGRESS) {
360 			FAIL_ERRNO("connect");
361 			return err;
362 		}
363 
364 		err = poll_connect(c, IO_TIMEOUT_SEC);
365 		if (err) {
366 			FAIL_ERRNO("poll_connect");
367 			return err;
368 		}
369 	}
370 
371 	switch (sotype & SOCK_TYPE_MASK) {
372 	case SOCK_DGRAM:
373 		err = xgetsockname(c, sockaddr(&addr), &len);
374 		if (err)
375 			return err;
376 
377 		err = xconnect(s, sockaddr(&addr), len);
378 		if (err)
379 			return err;
380 
381 		*p0 = take_fd(s);
382 		break;
383 	case SOCK_STREAM:
384 	case SOCK_SEQPACKET:
385 		p = xaccept_nonblock(s, NULL, NULL);
386 		if (p < 0)
387 			return p;
388 
389 		*p0 = take_fd(p);
390 		break;
391 	default:
392 		FAIL("Unsupported socket type %#x", sotype);
393 		return -EOPNOTSUPP;
394 	}
395 
396 	*p1 = take_fd(c);
397 	return 0;
398 }
399 
create_socket_pairs(int family,int sotype,int * c0,int * c1,int * p0,int * p1)400 static inline int create_socket_pairs(int family, int sotype, int *c0, int *c1,
401 				      int *p0, int *p1)
402 {
403 	int err;
404 
405 	err = create_pair(family, sotype, c0, p0);
406 	if (err)
407 		return err;
408 
409 	err = create_pair(family, sotype, c1, p1);
410 	if (err) {
411 		close(*c0);
412 		close(*p0);
413 	}
414 
415 	return err;
416 }
417 
socket_kind_to_str(int sock_fd)418 static inline const char *socket_kind_to_str(int sock_fd)
419 {
420 	socklen_t opt_len;
421 	int domain, type;
422 
423 	opt_len = sizeof(domain);
424 	if (getsockopt(sock_fd, SOL_SOCKET, SO_DOMAIN, &domain, &opt_len))
425 		FAIL_ERRNO("getsockopt(SO_DOMAIN)");
426 
427 	opt_len = sizeof(type);
428 	if (getsockopt(sock_fd, SOL_SOCKET, SO_TYPE, &type, &opt_len))
429 		FAIL_ERRNO("getsockopt(SO_TYPE)");
430 
431 	switch (domain) {
432 	case AF_INET:
433 		switch (type) {
434 		case SOCK_STREAM:
435 			return "tcp4";
436 		case SOCK_DGRAM:
437 			return "udp4";
438 		}
439 		break;
440 	case AF_INET6:
441 		switch (type) {
442 		case SOCK_STREAM:
443 			return "tcp6";
444 		case SOCK_DGRAM:
445 			return "udp6";
446 		}
447 		break;
448 	case AF_UNIX:
449 		switch (type) {
450 		case SOCK_STREAM:
451 			return "u_str";
452 		case SOCK_DGRAM:
453 			return "u_dgr";
454 		case SOCK_SEQPACKET:
455 			return "u_seq";
456 		}
457 		break;
458 	case AF_VSOCK:
459 		switch (type) {
460 		case SOCK_STREAM:
461 			return "v_str";
462 		case SOCK_DGRAM:
463 			return "v_dgr";
464 		case SOCK_SEQPACKET:
465 			return "v_seq";
466 		}
467 		break;
468 	}
469 
470 	return "???";
471 }
472 
473 #endif // __SOCKET_HELPERS__
474