xref: /linux/tools/testing/selftests/bpf/prog_tests/socket_helpers.h (revision 90b83efa6701656e02c86e7df2cb1765ea602d07)
1 /* SPDX-License-Identifier: GPL-2.0 */
2 
3 #ifndef __SOCKET_HELPERS__
4 #define __SOCKET_HELPERS__
5 
6 #include <sys/un.h>
7 #include <linux/vm_sockets.h>
8 
9 /* include/linux/net.h */
10 #define SOCK_TYPE_MASK 0xf
11 
12 #define IO_TIMEOUT_SEC 30
13 #define MAX_STRERR_LEN 256
14 
15 /* workaround for older vm_sockets.h */
16 #ifndef VMADDR_CID_LOCAL
17 #define VMADDR_CID_LOCAL 1
18 #endif
19 
20 /* include/linux/cleanup.h */
21 #define __get_and_null(p, nullvalue)                                           \
22 	({                                                                     \
23 		__auto_type __ptr = &(p);                                      \
24 		__auto_type __val = *__ptr;                                    \
25 		*__ptr = nullvalue;                                            \
26 		__val;                                                         \
27 	})
28 
29 #define take_fd(fd) __get_and_null(fd, -EBADF)
30 
31 /* Wrappers that fail the test on error and report it. */
32 
33 #define _FAIL(errnum, fmt...)                                                  \
34 	({                                                                     \
35 		error_at_line(0, (errnum), __func__, __LINE__, fmt);           \
36 		CHECK_FAIL(true);                                              \
37 	})
38 #define FAIL(fmt...) _FAIL(0, fmt)
39 #define FAIL_ERRNO(fmt...) _FAIL(errno, fmt)
40 #define FAIL_LIBBPF(err, msg)                                                  \
41 	({                                                                     \
42 		char __buf[MAX_STRERR_LEN];                                    \
43 		libbpf_strerror((err), __buf, sizeof(__buf));                  \
44 		FAIL("%s: %s", (msg), __buf);                                  \
45 	})
46 
47 
48 #define xaccept_nonblock(fd, addr, len)                                        \
49 	({                                                                     \
50 		int __ret =                                                    \
51 			accept_timeout((fd), (addr), (len), IO_TIMEOUT_SEC);   \
52 		if (__ret == -1)                                               \
53 			FAIL_ERRNO("accept");                                  \
54 		__ret;                                                         \
55 	})
56 
57 #define xbind(fd, addr, len)                                                   \
58 	({                                                                     \
59 		int __ret = bind((fd), (addr), (len));                         \
60 		if (__ret == -1)                                               \
61 			FAIL_ERRNO("bind");                                    \
62 		__ret;                                                         \
63 	})
64 
65 #define xclose(fd)                                                             \
66 	({                                                                     \
67 		int __ret = close((fd));                                       \
68 		if (__ret == -1)                                               \
69 			FAIL_ERRNO("close");                                   \
70 		__ret;                                                         \
71 	})
72 
73 #define xconnect(fd, addr, len)                                                \
74 	({                                                                     \
75 		int __ret = connect((fd), (addr), (len));                      \
76 		if (__ret == -1)                                               \
77 			FAIL_ERRNO("connect");                                 \
78 		__ret;                                                         \
79 	})
80 
81 #define xgetsockname(fd, addr, len)                                            \
82 	({                                                                     \
83 		int __ret = getsockname((fd), (addr), (len));                  \
84 		if (__ret == -1)                                               \
85 			FAIL_ERRNO("getsockname");                             \
86 		__ret;                                                         \
87 	})
88 
89 #define xgetsockopt(fd, level, name, val, len)                                 \
90 	({                                                                     \
91 		int __ret = getsockopt((fd), (level), (name), (val), (len));   \
92 		if (__ret == -1)                                               \
93 			FAIL_ERRNO("getsockopt(" #name ")");                   \
94 		__ret;                                                         \
95 	})
96 
97 #define xlisten(fd, backlog)                                                   \
98 	({                                                                     \
99 		int __ret = listen((fd), (backlog));                           \
100 		if (__ret == -1)                                               \
101 			FAIL_ERRNO("listen");                                  \
102 		__ret;                                                         \
103 	})
104 
105 #define xsetsockopt(fd, level, name, val, len)                                 \
106 	({                                                                     \
107 		int __ret = setsockopt((fd), (level), (name), (val), (len));   \
108 		if (__ret == -1)                                               \
109 			FAIL_ERRNO("setsockopt(" #name ")");                   \
110 		__ret;                                                         \
111 	})
112 
113 #define xsend(fd, buf, len, flags)                                             \
114 	({                                                                     \
115 		ssize_t __ret = send((fd), (buf), (len), (flags));             \
116 		if (__ret == -1)                                               \
117 			FAIL_ERRNO("send");                                    \
118 		__ret;                                                         \
119 	})
120 
121 #define xrecv_nonblock(fd, buf, len, flags)                                    \
122 	({                                                                     \
123 		ssize_t __ret = recv_timeout((fd), (buf), (len), (flags),      \
124 					     IO_TIMEOUT_SEC);                  \
125 		if (__ret == -1)                                               \
126 			FAIL_ERRNO("recv");                                    \
127 		__ret;                                                         \
128 	})
129 
130 #define xsocket(family, sotype, flags)                                         \
131 	({                                                                     \
132 		int __ret = socket(family, sotype, flags);                     \
133 		if (__ret == -1)                                               \
134 			FAIL_ERRNO("socket");                                  \
135 		__ret;                                                         \
136 	})
137 
138 static inline void close_fd(int *fd)
139 {
140 	if (*fd >= 0)
141 		xclose(*fd);
142 }
143 
144 #define __close_fd __attribute__((cleanup(close_fd)))
145 
146 static inline struct sockaddr *sockaddr(struct sockaddr_storage *ss)
147 {
148 	return (struct sockaddr *)ss;
149 }
150 
151 static inline void init_addr_loopback4(struct sockaddr_storage *ss,
152 				       socklen_t *len)
153 {
154 	struct sockaddr_in *addr4 = memset(ss, 0, sizeof(*ss));
155 
156 	addr4->sin_family = AF_INET;
157 	addr4->sin_port = 0;
158 	addr4->sin_addr.s_addr = htonl(INADDR_LOOPBACK);
159 	*len = sizeof(*addr4);
160 }
161 
162 static inline void init_addr_loopback6(struct sockaddr_storage *ss,
163 				       socklen_t *len)
164 {
165 	struct sockaddr_in6 *addr6 = memset(ss, 0, sizeof(*ss));
166 
167 	addr6->sin6_family = AF_INET6;
168 	addr6->sin6_port = 0;
169 	addr6->sin6_addr = in6addr_loopback;
170 	*len = sizeof(*addr6);
171 }
172 
173 static inline void init_addr_loopback_unix(struct sockaddr_storage *ss,
174 					   socklen_t *len)
175 {
176 	struct sockaddr_un *addr = memset(ss, 0, sizeof(*ss));
177 
178 	addr->sun_family = AF_UNIX;
179 	*len = sizeof(sa_family_t);
180 }
181 
182 static inline void init_addr_loopback_vsock(struct sockaddr_storage *ss,
183 					    socklen_t *len)
184 {
185 	struct sockaddr_vm *addr = memset(ss, 0, sizeof(*ss));
186 
187 	addr->svm_family = AF_VSOCK;
188 	addr->svm_port = VMADDR_PORT_ANY;
189 	addr->svm_cid = VMADDR_CID_LOCAL;
190 	*len = sizeof(*addr);
191 }
192 
193 static inline void init_addr_loopback(int family, struct sockaddr_storage *ss,
194 				      socklen_t *len)
195 {
196 	switch (family) {
197 	case AF_INET:
198 		init_addr_loopback4(ss, len);
199 		return;
200 	case AF_INET6:
201 		init_addr_loopback6(ss, len);
202 		return;
203 	case AF_UNIX:
204 		init_addr_loopback_unix(ss, len);
205 		return;
206 	case AF_VSOCK:
207 		init_addr_loopback_vsock(ss, len);
208 		return;
209 	default:
210 		FAIL("unsupported address family %d", family);
211 	}
212 }
213 
214 static inline int enable_reuseport(int s, int progfd)
215 {
216 	int err, one = 1;
217 
218 	err = xsetsockopt(s, SOL_SOCKET, SO_REUSEPORT, &one, sizeof(one));
219 	if (err)
220 		return -1;
221 	err = xsetsockopt(s, SOL_SOCKET, SO_ATTACH_REUSEPORT_EBPF, &progfd,
222 			  sizeof(progfd));
223 	if (err)
224 		return -1;
225 
226 	return 0;
227 }
228 
229 static inline int socket_loopback_reuseport(int family, int sotype, int progfd)
230 {
231 	struct sockaddr_storage addr;
232 	socklen_t len = 0;
233 	int err, s;
234 
235 	init_addr_loopback(family, &addr, &len);
236 
237 	s = xsocket(family, sotype, 0);
238 	if (s == -1)
239 		return -1;
240 
241 	if (progfd >= 0)
242 		enable_reuseport(s, progfd);
243 
244 	err = xbind(s, sockaddr(&addr), len);
245 	if (err)
246 		goto close;
247 
248 	if (sotype & SOCK_DGRAM)
249 		return s;
250 
251 	err = xlisten(s, SOMAXCONN);
252 	if (err)
253 		goto close;
254 
255 	return s;
256 close:
257 	xclose(s);
258 	return -1;
259 }
260 
261 static inline int socket_loopback(int family, int sotype)
262 {
263 	return socket_loopback_reuseport(family, sotype, -1);
264 }
265 
266 static inline int poll_connect(int fd, unsigned int timeout_sec)
267 {
268 	struct timeval timeout = { .tv_sec = timeout_sec };
269 	fd_set wfds;
270 	int r, eval;
271 	socklen_t esize = sizeof(eval);
272 
273 	FD_ZERO(&wfds);
274 	FD_SET(fd, &wfds);
275 
276 	r = select(fd + 1, NULL, &wfds, NULL, &timeout);
277 	if (r == 0)
278 		errno = ETIME;
279 	if (r != 1)
280 		return -1;
281 
282 	if (getsockopt(fd, SOL_SOCKET, SO_ERROR, &eval, &esize) < 0)
283 		return -1;
284 	if (eval != 0) {
285 		errno = eval;
286 		return -1;
287 	}
288 
289 	return 0;
290 }
291 
292 static inline int poll_read(int fd, unsigned int timeout_sec)
293 {
294 	struct timeval timeout = { .tv_sec = timeout_sec };
295 	fd_set rfds;
296 	int r;
297 
298 	FD_ZERO(&rfds);
299 	FD_SET(fd, &rfds);
300 
301 	r = select(fd + 1, &rfds, NULL, NULL, &timeout);
302 	if (r == 0)
303 		errno = ETIME;
304 
305 	return r == 1 ? 0 : -1;
306 }
307 
308 static inline int accept_timeout(int fd, struct sockaddr *addr, socklen_t *len,
309 				 unsigned int timeout_sec)
310 {
311 	if (poll_read(fd, timeout_sec))
312 		return -1;
313 
314 	return accept(fd, addr, len);
315 }
316 
317 static inline int recv_timeout(int fd, void *buf, size_t len, int flags,
318 			       unsigned int timeout_sec)
319 {
320 	if (poll_read(fd, timeout_sec))
321 		return -1;
322 
323 	return recv(fd, buf, len, flags);
324 }
325 
326 
327 static inline int create_pair(int family, int sotype, int *p0, int *p1)
328 {
329 	__close_fd int s, c = -1, p = -1;
330 	struct sockaddr_storage addr;
331 	socklen_t len;
332 	int err;
333 
334 	s = socket_loopback(family, sotype);
335 	if (s < 0)
336 		return s;
337 
338 	c = xsocket(family, sotype, 0);
339 	if (c < 0)
340 		return c;
341 
342 	init_addr_loopback(family, &addr, &len);
343 	err = xbind(c, sockaddr(&addr), len);
344 	if (err)
345 		return err;
346 
347 	len = sizeof(addr);
348 	err = xgetsockname(s, sockaddr(&addr), &len);
349 	if (err)
350 		return err;
351 
352 	err = connect(c, sockaddr(&addr), len);
353 	if (err) {
354 		if (errno != EINPROGRESS) {
355 			FAIL_ERRNO("connect");
356 			return err;
357 		}
358 
359 		err = poll_connect(c, IO_TIMEOUT_SEC);
360 		if (err) {
361 			FAIL_ERRNO("poll_connect");
362 			return err;
363 		}
364 	}
365 
366 	switch (sotype & SOCK_TYPE_MASK) {
367 	case SOCK_DGRAM:
368 		err = xgetsockname(c, sockaddr(&addr), &len);
369 		if (err)
370 			return err;
371 
372 		err = xconnect(s, sockaddr(&addr), len);
373 		if (err)
374 			return err;
375 
376 		*p0 = take_fd(s);
377 		break;
378 	case SOCK_STREAM:
379 	case SOCK_SEQPACKET:
380 		p = xaccept_nonblock(s, NULL, NULL);
381 		if (p < 0)
382 			return p;
383 
384 		*p0 = take_fd(p);
385 		break;
386 	default:
387 		FAIL("Unsupported socket type %#x", sotype);
388 		return -EOPNOTSUPP;
389 	}
390 
391 	*p1 = take_fd(c);
392 	return 0;
393 }
394 
395 static inline int create_socket_pairs(int family, int sotype, int *c0, int *c1,
396 				      int *p0, int *p1)
397 {
398 	int err;
399 
400 	err = create_pair(family, sotype, c0, p0);
401 	if (err)
402 		return err;
403 
404 	err = create_pair(family, sotype, c1, p1);
405 	if (err) {
406 		close(*c0);
407 		close(*p0);
408 	}
409 
410 	return err;
411 }
412 
413 static inline const char *socket_kind_to_str(int sock_fd)
414 {
415 	socklen_t opt_len;
416 	int domain, type;
417 
418 	opt_len = sizeof(domain);
419 	if (getsockopt(sock_fd, SOL_SOCKET, SO_DOMAIN, &domain, &opt_len))
420 		FAIL_ERRNO("getsockopt(SO_DOMAIN)");
421 
422 	opt_len = sizeof(type);
423 	if (getsockopt(sock_fd, SOL_SOCKET, SO_TYPE, &type, &opt_len))
424 		FAIL_ERRNO("getsockopt(SO_TYPE)");
425 
426 	switch (domain) {
427 	case AF_INET:
428 		switch (type) {
429 		case SOCK_STREAM:
430 			return "tcp4";
431 		case SOCK_DGRAM:
432 			return "udp4";
433 		}
434 		break;
435 	case AF_INET6:
436 		switch (type) {
437 		case SOCK_STREAM:
438 			return "tcp6";
439 		case SOCK_DGRAM:
440 			return "udp6";
441 		}
442 		break;
443 	case AF_UNIX:
444 		switch (type) {
445 		case SOCK_STREAM:
446 			return "u_str";
447 		case SOCK_DGRAM:
448 			return "u_dgr";
449 		case SOCK_SEQPACKET:
450 			return "u_seq";
451 		}
452 		break;
453 	case AF_VSOCK:
454 		switch (type) {
455 		case SOCK_STREAM:
456 			return "v_str";
457 		case SOCK_DGRAM:
458 			return "v_dgr";
459 		case SOCK_SEQPACKET:
460 			return "v_seq";
461 		}
462 		break;
463 	}
464 
465 	return "???";
466 }
467 
468 #endif // __SOCKET_HELPERS__
469