xref: /linux/tools/testing/selftests/bpf/network_helpers.c (revision 93a3545d812ae7cfe4426374e00a7d8f64ac02e0)
1 // SPDX-License-Identifier: GPL-2.0-only
2 #include <errno.h>
3 #include <stdbool.h>
4 #include <stdio.h>
5 #include <string.h>
6 #include <unistd.h>
7 
8 #include <arpa/inet.h>
9 
10 #include <linux/err.h>
11 #include <linux/in.h>
12 #include <linux/in6.h>
13 
14 #include "bpf_util.h"
15 #include "network_helpers.h"
16 
/* Human-readable description of the current errno, or "None" when errno is 0. */
#define clean_errno() (errno ? strerror(errno) : "None")

/*
 * Print MSG (a string literal, printf-style, with optional arguments) to
 * stderr, prefixed with the source file, line and current errno text.
 * errno is captured before printing and restored afterwards, so a caller
 * can log a failure and still inspect errno from the failing call.
 */
#define log_err(MSG, ...) ({						\
	int log_err_errno_ = errno;					\
									\
	fprintf(stderr, "(%s:%d: errno: %s) " MSG "\n",			\
		__FILE__, __LINE__, clean_errno(), ##__VA_ARGS__);	\
	errno = log_err_errno_;						\
})
25 
26 struct ipv4_packet pkt_v4 = {
27 	.eth.h_proto = __bpf_constant_htons(ETH_P_IP),
28 	.iph.ihl = 5,
29 	.iph.protocol = IPPROTO_TCP,
30 	.iph.tot_len = __bpf_constant_htons(MAGIC_BYTES),
31 	.tcp.urg_ptr = 123,
32 	.tcp.doff = 5,
33 };
34 
35 struct ipv6_packet pkt_v6 = {
36 	.eth.h_proto = __bpf_constant_htons(ETH_P_IPV6),
37 	.iph.nexthdr = IPPROTO_TCP,
38 	.iph.payload_len = __bpf_constant_htons(MAGIC_BYTES),
39 	.tcp.urg_ptr = 123,
40 	.tcp.doff = 5,
41 };
42 
/*
 * Apply one timeout to both SO_RCVTIMEO and SO_SNDTIMEO on @fd.
 *
 * @fd:         socket to configure.
 * @timeout_ms: timeout in milliseconds; a value <= 0 selects the
 *              default of 3 seconds.
 *
 * Returns 0 on success, -1 if either setsockopt() fails (errno is left as
 * set by setsockopt()).
 */
static int settimeo(int fd, int timeout_ms)
{
	struct timeval tv = { 0 };

	if (timeout_ms <= 0) {
		tv.tv_sec = 3;
	} else {
		tv.tv_sec = timeout_ms / 1000;
		tv.tv_usec = (timeout_ms % 1000) * 1000;
	}

	if (setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)) != 0) {
		log_err("Failed to set SO_RCVTIMEO");
		return -1;
	}
	if (setsockopt(fd, SOL_SOCKET, SO_SNDTIMEO, &tv, sizeof(tv)) != 0) {
		log_err("Failed to set SO_SNDTIMEO");
		return -1;
	}

	return 0;
}
66 
/* close(fd) without letting close() overwrite the caller's errno. */
#define save_errno_close(fd) ({			\
	int saved_errno_ = errno;		\
						\
	close(fd);				\
	errno = saved_errno_;			\
})
68 
69 int start_server(int family, int type, const char *addr_str, __u16 port,
70 		 int timeout_ms)
71 {
72 	struct sockaddr_storage addr = {};
73 	socklen_t len;
74 	int fd;
75 
76 	if (family == AF_INET) {
77 		struct sockaddr_in *sin = (void *)&addr;
78 
79 		sin->sin_family = AF_INET;
80 		sin->sin_port = htons(port);
81 		if (addr_str &&
82 		    inet_pton(AF_INET, addr_str, &sin->sin_addr) != 1) {
83 			log_err("inet_pton(AF_INET, %s)", addr_str);
84 			return -1;
85 		}
86 		len = sizeof(*sin);
87 	} else {
88 		struct sockaddr_in6 *sin6 = (void *)&addr;
89 
90 		sin6->sin6_family = AF_INET6;
91 		sin6->sin6_port = htons(port);
92 		if (addr_str &&
93 		    inet_pton(AF_INET6, addr_str, &sin6->sin6_addr) != 1) {
94 			log_err("inet_pton(AF_INET6, %s)", addr_str);
95 			return -1;
96 		}
97 		len = sizeof(*sin6);
98 	}
99 
100 	fd = socket(family, type, 0);
101 	if (fd < 0) {
102 		log_err("Failed to create server socket");
103 		return -1;
104 	}
105 
106 	if (settimeo(fd, timeout_ms))
107 		goto error_close;
108 
109 	if (bind(fd, (const struct sockaddr *)&addr, len) < 0) {
110 		log_err("Failed to bind socket");
111 		goto error_close;
112 	}
113 
114 	if (type == SOCK_STREAM) {
115 		if (listen(fd, 1) < 0) {
116 			log_err("Failed to listed on socket");
117 			goto error_close;
118 		}
119 	}
120 
121 	return fd;
122 
123 error_close:
124 	save_errno_close(fd);
125 	return -1;
126 }
127 
/*
 * connect() @fd to the first @addrlen bytes of @addr.
 * Returns 0 on success, -1 on failure (errno as set by connect()).
 */
static int connect_fd_to_addr(int fd,
			      const struct sockaddr_storage *addr,
			      socklen_t addrlen)
{
	const struct sockaddr *peer = (const struct sockaddr *)addr;
	int rc = connect(fd, peer, addrlen);

	if (rc == 0)
		return 0;

	log_err("Failed to connect to server");
	return -1;
}
139 
/*
 * Create a new socket with the same type and address family as @server_fd
 * and connect it to @server_fd's local address (as reported by
 * getsockname()).
 *
 * @server_fd:  a bound socket, e.g. one returned by start_server().
 * @timeout_ms: send/receive timeout for the new socket, passed to
 *              settimeo().
 *
 * Returns the connected client fd, or -1 on failure.  If the client socket
 * was already created, it is closed without clobbering errno.
 */
int connect_to_fd(int server_fd, int timeout_ms)
{
	struct sockaddr_storage addr;
	socklen_t addrlen = sizeof(addr);
	socklen_t optlen;
	int fd, type;

	optlen = sizeof(type);
	if (getsockopt(server_fd, SOL_SOCKET, SO_TYPE, &type, &optlen)) {
		log_err("getsockopt(SO_TYPE)");
		return -1;
	}

	if (getsockname(server_fd, (struct sockaddr *)&addr, &addrlen)) {
		log_err("Failed to get server addr");
		return -1;
	}

	/* Read the family from the storage header, not via a sockaddr_in cast. */
	fd = socket(addr.ss_family, type, 0);
	if (fd < 0) {
		log_err("Failed to create client socket");
		return -1;
	}

	if (settimeo(fd, timeout_ms))
		goto error_close;

	if (connect_fd_to_addr(fd, &addr, addrlen))
		goto error_close;

	return fd;

error_close:
	save_errno_close(fd);
	return -1;
}
178 
/*
 * Connect the existing socket @client_fd to @server_fd's local address
 * (as reported by getsockname()), after applying @timeout_ms to
 * @client_fd via settimeo().
 *
 * Returns 0 on success, -1 on failure.  @client_fd is never closed here.
 */
int connect_fd_to_fd(int client_fd, int server_fd, int timeout_ms)
{
	struct sockaddr_storage srv_addr;
	socklen_t srv_len = sizeof(srv_addr);

	if (settimeo(client_fd, timeout_ms) != 0)
		return -1;

	if (getsockname(server_fd, (struct sockaddr *)&srv_addr,
			&srv_len) != 0) {
		log_err("Failed to get server addr");
		return -1;
	}

	return connect_fd_to_addr(client_fd, &srv_addr, srv_len) ? -1 : 0;
}
197