xref: /linux/tools/testing/selftests/bpf/network_helpers.c (revision a6a6a98094116b60e5523a571d9443c53325f5b1)
1 // SPDX-License-Identifier: GPL-2.0-only
2 #define _GNU_SOURCE
3 
4 #include <errno.h>
5 #include <stdbool.h>
6 #include <stdio.h>
7 #include <string.h>
8 #include <unistd.h>
9 #include <sched.h>
10 
11 #include <arpa/inet.h>
12 #include <sys/mount.h>
13 #include <sys/stat.h>
14 #include <sys/un.h>
15 
16 #include <linux/err.h>
17 #include <linux/in.h>
18 #include <linux/in6.h>
19 #include <linux/limits.h>
20 
21 #include "bpf_util.h"
22 #include "network_helpers.h"
23 #include "test_progs.h"
24 
25 #ifndef IPPROTO_MPTCP
26 #define IPPROTO_MPTCP 262
27 #endif
28 
/* Text for the current errno, or "None" when errno is 0. */
#define clean_errno() (errno ? strerror(errno) : "None")

/*
 * Print a printf-style MSG to stderr, prefixed with the source location and
 * the current errno text, followed by a newline. errno is saved before and
 * restored after the print so the caller can still inspect it.
 */
#define log_err(MSG, ...) ({						\
	int __saved_errno = errno;					\
	fprintf(stderr, "(%s:%d: errno: %s) " MSG "\n",			\
		__FILE__, __LINE__, clean_errno(), ##__VA_ARGS__);	\
	errno = __saved_errno;						\
})
37 
38 struct ipv4_packet pkt_v4 = {
39 	.eth.h_proto = __bpf_constant_htons(ETH_P_IP),
40 	.iph.ihl = 5,
41 	.iph.protocol = IPPROTO_TCP,
42 	.iph.tot_len = __bpf_constant_htons(MAGIC_BYTES),
43 	.tcp.urg_ptr = 123,
44 	.tcp.doff = 5,
45 };
46 
47 struct ipv6_packet pkt_v6 = {
48 	.eth.h_proto = __bpf_constant_htons(ETH_P_IPV6),
49 	.iph.nexthdr = IPPROTO_TCP,
50 	.iph.payload_len = __bpf_constant_htons(MAGIC_BYTES),
51 	.tcp.urg_ptr = 123,
52 	.tcp.doff = 5,
53 };
54 
55 static const struct network_helper_opts default_opts;
56 
/*
 * Apply the same receive (SO_RCVTIMEO) and send (SO_SNDTIMEO) timeout to @fd.
 *
 * A positive @timeout_ms is used as given; zero or negative selects a 3 second
 * default. Returns 0 on success, or -1 with errno from setsockopt() left intact.
 */
int settimeo(int fd, int timeout_ms)
{
	const bool use_default = timeout_ms <= 0;
	struct timeval tv = {
		.tv_sec = use_default ? 3 : timeout_ms / 1000,
		.tv_usec = use_default ? 0 : (timeout_ms % 1000) * 1000,
	};

	if (setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)) != 0) {
		log_err("Failed to set SO_RCVTIMEO");
		return -1;
	}

	if (setsockopt(fd, SOL_SOCKET, SO_SNDTIMEO, &tv, sizeof(tv)) != 0) {
		log_err("Failed to set SO_SNDTIMEO");
		return -1;
	}

	return 0;
}
80 
/* close(fd) while keeping the caller's errno (close() may overwrite it). */
#define save_errno_close(fd) ({						\
	int __saved_errno = errno;					\
	close(fd);							\
	errno = __saved_errno;						\
})
82 
83 static int __start_server(int type, const struct sockaddr *addr, socklen_t addrlen,
84 			  const struct network_helper_opts *opts)
85 {
86 	int fd;
87 
88 	fd = socket(addr->sa_family, type, opts->proto);
89 	if (fd < 0) {
90 		log_err("Failed to create server socket");
91 		return -1;
92 	}
93 
94 	if (settimeo(fd, opts->timeout_ms))
95 		goto error_close;
96 
97 	if (opts->post_socket_cb &&
98 	    opts->post_socket_cb(fd, opts->cb_opts)) {
99 		log_err("Failed to call post_socket_cb");
100 		goto error_close;
101 	}
102 
103 	if (bind(fd, addr, addrlen) < 0) {
104 		log_err("Failed to bind socket");
105 		goto error_close;
106 	}
107 
108 	if (type == SOCK_STREAM) {
109 		if (listen(fd, 1) < 0) {
110 			log_err("Failed to listed on socket");
111 			goto error_close;
112 		}
113 	}
114 
115 	return fd;
116 
117 error_close:
118 	save_errno_close(fd);
119 	return -1;
120 }
121 
122 int start_server_str(int family, int type, const char *addr_str, __u16 port,
123 		     const struct network_helper_opts *opts)
124 {
125 	struct sockaddr_storage addr;
126 	socklen_t addrlen;
127 
128 	if (!opts)
129 		opts = &default_opts;
130 
131 	if (make_sockaddr(family, addr_str, port, &addr, &addrlen))
132 		return -1;
133 
134 	return __start_server(type, (struct sockaddr *)&addr, addrlen, opts);
135 }
136 
137 int start_server(int family, int type, const char *addr_str, __u16 port,
138 		 int timeout_ms)
139 {
140 	struct network_helper_opts opts = {
141 		.timeout_ms	= timeout_ms,
142 	};
143 
144 	return start_server_str(family, type, addr_str, port, &opts);
145 }
146 
/*
 * post_socket_cb hook that turns on SO_REUSEPORT for @fd. @opts is unused.
 * Returns setsockopt()'s result (0 on success, -1 on failure).
 */
static int reuseport_cb(int fd, void *opts)
{
	const int enable = 1;

	(void)opts;
	return setsockopt(fd, SOL_SOCKET, SO_REUSEPORT, &enable, sizeof(enable));
}
153 
154 int *start_reuseport_server(int family, int type, const char *addr_str,
155 			    __u16 port, int timeout_ms, unsigned int nr_listens)
156 {
157 	struct network_helper_opts opts = {
158 		.timeout_ms = timeout_ms,
159 		.post_socket_cb = reuseport_cb,
160 	};
161 	struct sockaddr_storage addr;
162 	unsigned int nr_fds = 0;
163 	socklen_t addrlen;
164 	int *fds;
165 
166 	if (!nr_listens)
167 		return NULL;
168 
169 	if (make_sockaddr(family, addr_str, port, &addr, &addrlen))
170 		return NULL;
171 
172 	fds = malloc(sizeof(*fds) * nr_listens);
173 	if (!fds)
174 		return NULL;
175 
176 	fds[0] = __start_server(type, (struct sockaddr *)&addr, addrlen, &opts);
177 	if (fds[0] == -1)
178 		goto close_fds;
179 	nr_fds = 1;
180 
181 	if (getsockname(fds[0], (struct sockaddr *)&addr, &addrlen))
182 		goto close_fds;
183 
184 	for (; nr_fds < nr_listens; nr_fds++) {
185 		fds[nr_fds] = __start_server(type, (struct sockaddr *)&addr, addrlen, &opts);
186 		if (fds[nr_fds] == -1)
187 			goto close_fds;
188 	}
189 
190 	return fds;
191 
192 close_fds:
193 	free_fds(fds, nr_fds);
194 	return NULL;
195 }
196 
197 int start_server_addr(int type, const struct sockaddr_storage *addr, socklen_t len,
198 		      const struct network_helper_opts *opts)
199 {
200 	if (!opts)
201 		opts = &default_opts;
202 
203 	return __start_server(type, (struct sockaddr *)addr, len, opts);
204 }
205 
/*
 * Close the first @nr_close_fds entries of @fds (last entry first) and free
 * the array. A NULL @fds is ignored.
 */
void free_fds(int *fds, unsigned int nr_close_fds)
{
	unsigned int i;

	if (!fds)
		return;

	for (i = nr_close_fds; i > 0; i--)
		close(fds[i - 1]);
	free(fds);
}
214 
/*
 * Create a TCP client socket in @server_fd's address family and send @data
 * (@data_len bytes) to @server_fd's bound address using sendto() with
 * MSG_FASTOPEN. @timeout_ms is applied via settimeo().
 *
 * Returns the client fd when all @data_len bytes were sent, otherwise -1
 * (the client socket, if created, is closed).
 */
int fastopen_connect(int server_fd, const char *data, unsigned int data_len,
		     int timeout_ms)
{
	struct sockaddr_storage addr;
	socklen_t addrlen = sizeof(addr);
	ssize_t ret;
	int fd;

	if (getsockname(server_fd, (struct sockaddr *)&addr, &addrlen)) {
		log_err("Failed to get server addr");
		return -1;
	}

	fd = socket(addr.ss_family, SOCK_STREAM, 0);
	if (fd < 0) {
		log_err("Failed to create client socket");
		return -1;
	}

	if (settimeo(fd, timeout_ms))
		goto error_close;

	ret = sendto(fd, data, data_len, MSG_FASTOPEN, (struct sockaddr *)&addr,
		     addrlen);
	/* sendto() returns ssize_t; rule out -1 before comparing to the unsigned length. */
	if (ret < 0 || (size_t)ret != data_len) {
		/* log_err() already appends the newline. */
		log_err("sendto(data, %u) != %zd", data_len, ret);
		goto error_close;
	}

	return fd;

error_close:
	save_errno_close(fd);
	return -1;
}
251 
/*
 * connect() @fd to @addr (@addrlen bytes) and check the outcome.
 *
 * Without @must_fail the connect must succeed. With @must_fail it must be
 * rejected with EPERM; success or any other errno counts as a failure.
 * Returns 0 when the outcome matches the expectation, -1 otherwise.
 */
static int connect_fd_to_addr(int fd,
			      const struct sockaddr_storage *addr,
			      socklen_t addrlen, const bool must_fail)
{
	int rc;

	errno = 0;
	rc = connect(fd, (const struct sockaddr *)addr, addrlen);

	if (!must_fail) {
		if (rc == 0)
			return 0;
		log_err("Failed to connect to server");
		return -1;
	}

	if (rc == 0) {
		log_err("Unexpected success to connect to server");
		return -1;
	}

	if (errno == EPERM)
		return 0;

	log_err("Unexpected error from connect to server");
	return -1;
}
278 
279 int connect_to_addr(int type, const struct sockaddr_storage *addr, socklen_t addrlen,
280 		    const struct network_helper_opts *opts)
281 {
282 	int fd;
283 
284 	if (!opts)
285 		opts = &default_opts;
286 
287 	fd = socket(addr->ss_family, type, opts->proto);
288 	if (fd < 0) {
289 		log_err("Failed to create client socket");
290 		return -1;
291 	}
292 
293 	if (settimeo(fd, opts->timeout_ms))
294 		goto error_close;
295 
296 	if (connect_fd_to_addr(fd, addr, addrlen, opts->must_fail))
297 		goto error_close;
298 
299 	return fd;
300 
301 error_close:
302 	save_errno_close(fd);
303 	return -1;
304 }
305 
306 int connect_to_fd_opts(int server_fd, const struct network_helper_opts *opts)
307 {
308 	struct sockaddr_storage addr;
309 	struct sockaddr_in *addr_in;
310 	socklen_t addrlen, optlen;
311 	int fd, type, protocol;
312 
313 	if (!opts)
314 		opts = &default_opts;
315 
316 	optlen = sizeof(type);
317 
318 	if (opts->type) {
319 		type = opts->type;
320 	} else {
321 		if (getsockopt(server_fd, SOL_SOCKET, SO_TYPE, &type, &optlen)) {
322 			log_err("getsockopt(SOL_TYPE)");
323 			return -1;
324 		}
325 	}
326 
327 	if (opts->proto) {
328 		protocol = opts->proto;
329 	} else {
330 		if (getsockopt(server_fd, SOL_SOCKET, SO_PROTOCOL, &protocol, &optlen)) {
331 			log_err("getsockopt(SOL_PROTOCOL)");
332 			return -1;
333 		}
334 	}
335 
336 	addrlen = sizeof(addr);
337 	if (getsockname(server_fd, (struct sockaddr *)&addr, &addrlen)) {
338 		log_err("Failed to get server addr");
339 		return -1;
340 	}
341 
342 	addr_in = (struct sockaddr_in *)&addr;
343 	fd = socket(addr_in->sin_family, type, protocol);
344 	if (fd < 0) {
345 		log_err("Failed to create client socket");
346 		return -1;
347 	}
348 
349 	if (settimeo(fd, opts->timeout_ms))
350 		goto error_close;
351 
352 	if (opts->post_socket_cb &&
353 	    opts->post_socket_cb(fd, opts->cb_opts))
354 		goto error_close;
355 
356 	if (!opts->noconnect)
357 		if (connect_fd_to_addr(fd, &addr, addrlen, opts->must_fail))
358 			goto error_close;
359 
360 	return fd;
361 
362 error_close:
363 	save_errno_close(fd);
364 	return -1;
365 }
366 
367 int connect_to_fd(int server_fd, int timeout_ms)
368 {
369 	struct network_helper_opts opts = {
370 		.timeout_ms = timeout_ms,
371 	};
372 
373 	return connect_to_fd_opts(server_fd, &opts);
374 }
375 
/*
 * Apply @timeout_ms to the existing @client_fd and connect it to the address
 * @server_fd is bound to. Returns 0 on success, -1 on failure.
 */
int connect_fd_to_fd(int client_fd, int server_fd, int timeout_ms)
{
	struct sockaddr_storage srv;
	socklen_t srv_len = sizeof(srv);

	if (settimeo(client_fd, timeout_ms) != 0)
		return -1;

	if (getsockname(server_fd, (struct sockaddr *)&srv, &srv_len) != 0) {
		log_err("Failed to get server addr");
		return -1;
	}

	return connect_fd_to_addr(client_fd, &srv, srv_len, false) ? -1 : 0;
}
394 
/*
 * Build a socket address for @family in @addr.
 *
 * AF_INET / AF_INET6: @port is given in host byte order and stored with
 * htons(); @addr_str is parsed with inet_pton(), and a NULL @addr_str leaves
 * the all-zero (wildcard) address.
 * AF_UNIX: @addr_str names an abstract socket (sun_path starts with a NUL
 * byte, so no filesystem entry is created). @port is ignored, and @addr_str
 * must be non-NULL and fit in sun_path after the leading NUL.
 *
 * If @len is non-NULL it receives the address length to pass to
 * bind()/connect(). Returns 0 on success and -1 on any failure, including an
 * unsupported @family.
 */
int make_sockaddr(int family, const char *addr_str, __u16 port,
		  struct sockaddr_storage *addr, socklen_t *len)
{
	if (family == AF_INET) {
		struct sockaddr_in *sin = (void *)addr;

		memset(addr, 0, sizeof(*sin));
		sin->sin_family = AF_INET;
		sin->sin_port = htons(port);
		if (addr_str &&
		    inet_pton(AF_INET, addr_str, &sin->sin_addr) != 1) {
			log_err("inet_pton(AF_INET, %s)", addr_str);
			return -1;
		}
		if (len)
			*len = sizeof(*sin);
		return 0;
	} else if (family == AF_INET6) {
		struct sockaddr_in6 *sin6 = (void *)addr;

		memset(addr, 0, sizeof(*sin6));
		sin6->sin6_family = AF_INET6;
		sin6->sin6_port = htons(port);
		if (addr_str &&
		    inet_pton(AF_INET6, addr_str, &sin6->sin6_addr) != 1) {
			log_err("inet_pton(AF_INET6, %s)", addr_str);
			return -1;
		}
		if (len)
			*len = sizeof(*sin6);
		return 0;
	} else if (family == AF_UNIX) {
		/* Note that we always use abstract unix sockets to avoid having
		 * to clean up leftover files.
		 */
		struct sockaddr_un *sun = (void *)addr;
		size_t name_len;

		if (!addr_str) {
			log_err("AF_UNIX address requires a name");
			return -1;
		}

		/* sun_path[0] is the abstract-namespace NUL; the name follows it. */
		name_len = strlen(addr_str);
		if (name_len > sizeof(sun->sun_path) - 1) {
			log_err("AF_UNIX name too long (%zu bytes): %s", name_len, addr_str);
			return -1;
		}

		memset(addr, 0, sizeof(*sun));
		sun->sun_family = family;
		sun->sun_path[0] = 0;
		/* Abstract names are delimited by *len, so no trailing NUL has
		 * to fit in sun_path.
		 */
		memcpy(sun->sun_path + 1, addr_str, name_len);
		if (len)
			*len = offsetof(struct sockaddr_un, sun_path) + 1 + name_len;
		return 0;
	}

	log_err("Unsupported address family %d", family);
	return -1;
}
442 
/*
 * Return the shell command used to ping an address of @family.
 * For AF_INET6, "ping6" is returned when `which ping6` succeeds, because on
 * some systems 'ping' lacks IPv6 support; otherwise "ping -6". Any other
 * family gets plain "ping".
 */
char *ping_command(int family)
{
	if (family != AF_INET6)
		return "ping";

	return system("which ping6 >/dev/null 2>&1") == 0 ? "ping6" : "ping -6";
}
454 
/*
 * Token returned by open_netns() and consumed by close_netns():
 * orig_netns_fd is an fd opened on /proc/self/ns/net before switching, used
 * to switch back to the original network namespace.
 */
struct nstoken { int orig_netns_fd; };
458 
459 struct nstoken *open_netns(const char *name)
460 {
461 	int nsfd;
462 	char nspath[PATH_MAX];
463 	int err;
464 	struct nstoken *token;
465 
466 	token = calloc(1, sizeof(struct nstoken));
467 	if (!token) {
468 		log_err("Failed to malloc token");
469 		return NULL;
470 	}
471 
472 	token->orig_netns_fd = open("/proc/self/ns/net", O_RDONLY);
473 	if (token->orig_netns_fd == -1) {
474 		log_err("Failed to open(/proc/self/ns/net)");
475 		goto fail;
476 	}
477 
478 	snprintf(nspath, sizeof(nspath), "%s/%s", "/var/run/netns", name);
479 	nsfd = open(nspath, O_RDONLY | O_CLOEXEC);
480 	if (nsfd == -1) {
481 		log_err("Failed to open(%s)", nspath);
482 		goto fail;
483 	}
484 
485 	err = setns(nsfd, CLONE_NEWNET);
486 	close(nsfd);
487 	if (err) {
488 		log_err("Failed to setns(nsfd)");
489 		goto fail;
490 	}
491 
492 	return token;
493 fail:
494 	if (token->orig_netns_fd != -1)
495 		close(token->orig_netns_fd);
496 	free(token);
497 	return NULL;
498 }
499 
500 void close_netns(struct nstoken *token)
501 {
502 	if (!token)
503 		return;
504 
505 	if (setns(token->orig_netns_fd, CLONE_NEWNET))
506 		log_err("Failed to setns(orig_netns_fd)");
507 	close(token->orig_netns_fd);
508 	free(token);
509 }
510 
/*
 * Return the local port of @sock_fd exactly as stored in sockaddr_in /
 * sockaddr_in6 (network byte order; no ntohs() is applied).
 * Returns getsockname()'s negative result if it fails, and -1 for sockets
 * that are neither AF_INET nor AF_INET6.
 */
int get_socket_local_port(int sock_fd)
{
	struct sockaddr_storage ss;
	socklen_t ss_len = sizeof(ss);
	int ret = getsockname(sock_fd, (struct sockaddr *)&ss, &ss_len);

	if (ret < 0)
		return ret;

	switch (ss.ss_family) {
	case AF_INET:
		return ((struct sockaddr_in *)&ss)->sin_port;
	case AF_INET6:
		return ((struct sockaddr_in6 *)&ss)->sin6_port;
	default:
		return -1;
	}
}
533 
/*
 * Issue the SIOCETHTOOL ring-parameter request @cmd (ETHTOOL_GRINGPARAM or
 * ETHTOOL_SRINGPARAM) for interface @ifname, with @ring_param as the
 * request/response buffer (its cmd field is set to @cmd).
 *
 * Returns 0 on success, -ENAMETOOLONG if @ifname does not fit in ifr_name
 * with its NUL terminator, or the negative errno of the failing call.
 */
static int hw_ring_size_ioctl(const char *ifname, struct ethtool_ringparam *ring_param,
			      __u32 cmd)
{
	struct ifreq ifr = {0};
	const char *nul;
	int sockfd, err = 0;

	/* memchr() stops at the first NUL, so short names are not over-read. */
	nul = memchr(ifname, '\0', sizeof(ifr.ifr_name));
	if (!nul)
		return -ENAMETOOLONG;

	sockfd = socket(AF_INET, SOCK_DGRAM, 0);
	if (sockfd < 0)
		return -errno;

	/* Copy the name including its NUL terminator. */
	memcpy(ifr.ifr_name, ifname, (size_t)(nul - ifname) + 1);

	ring_param->cmd = cmd;
	ifr.ifr_data = (char *)ring_param;

	if (ioctl(sockfd, SIOCETHTOOL, &ifr) < 0)
		err = -errno;

	close(sockfd);
	return err;
}

/* Read @ifname's ring parameters into @ring_param. Returns 0 or a negative errno. */
int get_hw_ring_size(char *ifname, struct ethtool_ringparam *ring_param)
{
	return hw_ring_size_ioctl(ifname, ring_param, ETHTOOL_GRINGPARAM);
}

/* Apply the ring parameters in @ring_param to @ifname. Returns 0 or a negative errno. */
int set_hw_ring_size(char *ifname, struct ethtool_ringparam *ring_param)
{
	return hw_ring_size_ioctl(ifname, ring_param, ETHTOOL_SRINGPARAM);
}
581 
/* State shared between send_recv_data() and its send_recv_server() thread. */
struct send_recv_arg {
	int fd;		/* listening socket the server thread accept()s on */
	uint32_t bytes;	/* number of bytes the server thread must send */
	int stop;	/* set with WRITE_ONCE() by either side to end the transfer */
};
587 
588 static void *send_recv_server(void *arg)
589 {
590 	struct send_recv_arg *a = (struct send_recv_arg *)arg;
591 	ssize_t nr_sent = 0, bytes = 0;
592 	char batch[1500];
593 	int err = 0, fd;
594 
595 	fd = accept(a->fd, NULL, NULL);
596 	while (fd == -1) {
597 		if (errno == EINTR)
598 			continue;
599 		err = -errno;
600 		goto done;
601 	}
602 
603 	if (settimeo(fd, 0)) {
604 		err = -errno;
605 		goto done;
606 	}
607 
608 	while (bytes < a->bytes && !READ_ONCE(a->stop)) {
609 		nr_sent = send(fd, &batch,
610 			       MIN(a->bytes - bytes, sizeof(batch)), 0);
611 		if (nr_sent == -1 && errno == EINTR)
612 			continue;
613 		if (nr_sent == -1) {
614 			err = -errno;
615 			break;
616 		}
617 		bytes += nr_sent;
618 	}
619 
620 	if (bytes != a->bytes) {
621 		log_err("send %zd expected %u", bytes, a->bytes);
622 		if (!err)
623 			err = bytes > a->bytes ? -E2BIG : -EINTR;
624 	}
625 
626 done:
627 	if (fd >= 0)
628 		close(fd);
629 	if (err) {
630 		WRITE_ONCE(a->stop, 1);
631 		return ERR_PTR(err);
632 	}
633 	return NULL;
634 }
635 
636 int send_recv_data(int lfd, int fd, uint32_t total_bytes)
637 {
638 	ssize_t nr_recv = 0, bytes = 0;
639 	struct send_recv_arg arg = {
640 		.fd	= lfd,
641 		.bytes	= total_bytes,
642 		.stop	= 0,
643 	};
644 	pthread_t srv_thread;
645 	void *thread_ret;
646 	char batch[1500];
647 	int err = 0;
648 
649 	err = pthread_create(&srv_thread, NULL, send_recv_server, (void *)&arg);
650 	if (err) {
651 		log_err("Failed to pthread_create");
652 		return err;
653 	}
654 
655 	/* recv total_bytes */
656 	while (bytes < total_bytes && !READ_ONCE(arg.stop)) {
657 		nr_recv = recv(fd, &batch,
658 			       MIN(total_bytes - bytes, sizeof(batch)), 0);
659 		if (nr_recv == -1 && errno == EINTR)
660 			continue;
661 		if (nr_recv == -1) {
662 			err = -errno;
663 			break;
664 		}
665 		bytes += nr_recv;
666 	}
667 
668 	if (bytes != total_bytes) {
669 		log_err("recv %zd expected %u", bytes, total_bytes);
670 		if (!err)
671 			err = bytes > total_bytes ? -E2BIG : -EINTR;
672 	}
673 
674 	WRITE_ONCE(arg.stop, 1);
675 	pthread_join(srv_thread, &thread_ret);
676 	if (IS_ERR(thread_ret)) {
677 		log_err("Failed in thread_ret %ld", PTR_ERR(thread_ret));
678 		err = err ? : PTR_ERR(thread_ret);
679 	}
680 
681 	return err;
682 }
683