xref: /linux/tools/testing/selftests/bpf/prog_tests/socket_helpers.h (revision 643e2e259c2b25a2af0ae4c23c6e16586d9fd19c)
1 /* SPDX-License-Identifier: GPL-2.0 */
2 
3 #ifndef __SOCKET_HELPERS__
4 #define __SOCKET_HELPERS__
5 
6 #include <linux/vm_sockets.h>
7 
8 /* include/linux/net.h */
9 #define SOCK_TYPE_MASK 0xf
10 
11 #define IO_TIMEOUT_SEC 30
12 #define MAX_STRERR_LEN 256
13 
14 /* workaround for older vm_sockets.h */
15 #ifndef VMADDR_CID_LOCAL
16 #define VMADDR_CID_LOCAL 1
17 #endif
18 
19 /* include/linux/cleanup.h */
20 #define __get_and_null(p, nullvalue)                                           \
21 	({                                                                     \
22 		__auto_type __ptr = &(p);                                      \
23 		__auto_type __val = *__ptr;                                    \
24 		*__ptr = nullvalue;                                            \
25 		__val;                                                         \
26 	})
27 
28 #define take_fd(fd) __get_and_null(fd, -EBADF)
29 
30 /* Wrappers that fail the test on error and report it. */
31 
32 #define _FAIL(errnum, fmt...)                                                  \
33 	({                                                                     \
34 		error_at_line(0, (errnum), __func__, __LINE__, fmt);           \
35 		CHECK_FAIL(true);                                              \
36 	})
37 #define FAIL(fmt...) _FAIL(0, fmt)
38 #define FAIL_ERRNO(fmt...) _FAIL(errno, fmt)
39 #define FAIL_LIBBPF(err, msg)                                                  \
40 	({                                                                     \
41 		char __buf[MAX_STRERR_LEN];                                    \
42 		libbpf_strerror((err), __buf, sizeof(__buf));                  \
43 		FAIL("%s: %s", (msg), __buf);                                  \
44 	})
45 
46 
47 #define xaccept_nonblock(fd, addr, len)                                        \
48 	({                                                                     \
49 		int __ret =                                                    \
50 			accept_timeout((fd), (addr), (len), IO_TIMEOUT_SEC);   \
51 		if (__ret == -1)                                               \
52 			FAIL_ERRNO("accept");                                  \
53 		__ret;                                                         \
54 	})
55 
56 #define xbind(fd, addr, len)                                                   \
57 	({                                                                     \
58 		int __ret = bind((fd), (addr), (len));                         \
59 		if (__ret == -1)                                               \
60 			FAIL_ERRNO("bind");                                    \
61 		__ret;                                                         \
62 	})
63 
64 #define xclose(fd)                                                             \
65 	({                                                                     \
66 		int __ret = close((fd));                                       \
67 		if (__ret == -1)                                               \
68 			FAIL_ERRNO("close");                                   \
69 		__ret;                                                         \
70 	})
71 
72 #define xconnect(fd, addr, len)                                                \
73 	({                                                                     \
74 		int __ret = connect((fd), (addr), (len));                      \
75 		if (__ret == -1)                                               \
76 			FAIL_ERRNO("connect");                                 \
77 		__ret;                                                         \
78 	})
79 
80 #define xgetsockname(fd, addr, len)                                            \
81 	({                                                                     \
82 		int __ret = getsockname((fd), (addr), (len));                  \
83 		if (__ret == -1)                                               \
84 			FAIL_ERRNO("getsockname");                             \
85 		__ret;                                                         \
86 	})
87 
88 #define xgetsockopt(fd, level, name, val, len)                                 \
89 	({                                                                     \
90 		int __ret = getsockopt((fd), (level), (name), (val), (len));   \
91 		if (__ret == -1)                                               \
92 			FAIL_ERRNO("getsockopt(" #name ")");                   \
93 		__ret;                                                         \
94 	})
95 
96 #define xlisten(fd, backlog)                                                   \
97 	({                                                                     \
98 		int __ret = listen((fd), (backlog));                           \
99 		if (__ret == -1)                                               \
100 			FAIL_ERRNO("listen");                                  \
101 		__ret;                                                         \
102 	})
103 
104 #define xsetsockopt(fd, level, name, val, len)                                 \
105 	({                                                                     \
106 		int __ret = setsockopt((fd), (level), (name), (val), (len));   \
107 		if (__ret == -1)                                               \
108 			FAIL_ERRNO("setsockopt(" #name ")");                   \
109 		__ret;                                                         \
110 	})
111 
112 #define xsend(fd, buf, len, flags)                                             \
113 	({                                                                     \
114 		ssize_t __ret = send((fd), (buf), (len), (flags));             \
115 		if (__ret == -1)                                               \
116 			FAIL_ERRNO("send");                                    \
117 		__ret;                                                         \
118 	})
119 
120 #define xrecv_nonblock(fd, buf, len, flags)                                    \
121 	({                                                                     \
122 		ssize_t __ret = recv_timeout((fd), (buf), (len), (flags),      \
123 					     IO_TIMEOUT_SEC);                  \
124 		if (__ret == -1)                                               \
125 			FAIL_ERRNO("recv");                                    \
126 		__ret;                                                         \
127 	})
128 
129 #define xsocket(family, sotype, flags)                                         \
130 	({                                                                     \
131 		int __ret = socket(family, sotype, flags);                     \
132 		if (__ret == -1)                                               \
133 			FAIL_ERRNO("socket");                                  \
134 		__ret;                                                         \
135 	})
136 
137 static inline void close_fd(int *fd)
138 {
139 	if (*fd >= 0)
140 		xclose(*fd);
141 }
142 
143 #define __close_fd __attribute__((cleanup(close_fd)))
144 
145 static inline struct sockaddr *sockaddr(struct sockaddr_storage *ss)
146 {
147 	return (struct sockaddr *)ss;
148 }
149 
150 static inline void init_addr_loopback4(struct sockaddr_storage *ss,
151 				       socklen_t *len)
152 {
153 	struct sockaddr_in *addr4 = memset(ss, 0, sizeof(*ss));
154 
155 	addr4->sin_family = AF_INET;
156 	addr4->sin_port = 0;
157 	addr4->sin_addr.s_addr = htonl(INADDR_LOOPBACK);
158 	*len = sizeof(*addr4);
159 }
160 
161 static inline void init_addr_loopback6(struct sockaddr_storage *ss,
162 				       socklen_t *len)
163 {
164 	struct sockaddr_in6 *addr6 = memset(ss, 0, sizeof(*ss));
165 
166 	addr6->sin6_family = AF_INET6;
167 	addr6->sin6_port = 0;
168 	addr6->sin6_addr = in6addr_loopback;
169 	*len = sizeof(*addr6);
170 }
171 
172 static inline void init_addr_loopback_vsock(struct sockaddr_storage *ss,
173 					    socklen_t *len)
174 {
175 	struct sockaddr_vm *addr = memset(ss, 0, sizeof(*ss));
176 
177 	addr->svm_family = AF_VSOCK;
178 	addr->svm_port = VMADDR_PORT_ANY;
179 	addr->svm_cid = VMADDR_CID_LOCAL;
180 	*len = sizeof(*addr);
181 }
182 
183 static inline void init_addr_loopback(int family, struct sockaddr_storage *ss,
184 				      socklen_t *len)
185 {
186 	switch (family) {
187 	case AF_INET:
188 		init_addr_loopback4(ss, len);
189 		return;
190 	case AF_INET6:
191 		init_addr_loopback6(ss, len);
192 		return;
193 	case AF_VSOCK:
194 		init_addr_loopback_vsock(ss, len);
195 		return;
196 	default:
197 		FAIL("unsupported address family %d", family);
198 	}
199 }
200 
201 static inline int enable_reuseport(int s, int progfd)
202 {
203 	int err, one = 1;
204 
205 	err = xsetsockopt(s, SOL_SOCKET, SO_REUSEPORT, &one, sizeof(one));
206 	if (err)
207 		return -1;
208 	err = xsetsockopt(s, SOL_SOCKET, SO_ATTACH_REUSEPORT_EBPF, &progfd,
209 			  sizeof(progfd));
210 	if (err)
211 		return -1;
212 
213 	return 0;
214 }
215 
216 static inline int socket_loopback_reuseport(int family, int sotype, int progfd)
217 {
218 	struct sockaddr_storage addr;
219 	socklen_t len = 0;
220 	int err, s;
221 
222 	init_addr_loopback(family, &addr, &len);
223 
224 	s = xsocket(family, sotype, 0);
225 	if (s == -1)
226 		return -1;
227 
228 	if (progfd >= 0)
229 		enable_reuseport(s, progfd);
230 
231 	err = xbind(s, sockaddr(&addr), len);
232 	if (err)
233 		goto close;
234 
235 	if (sotype & SOCK_DGRAM)
236 		return s;
237 
238 	err = xlisten(s, SOMAXCONN);
239 	if (err)
240 		goto close;
241 
242 	return s;
243 close:
244 	xclose(s);
245 	return -1;
246 }
247 
248 static inline int socket_loopback(int family, int sotype)
249 {
250 	return socket_loopback_reuseport(family, sotype, -1);
251 }
252 
253 static inline int poll_connect(int fd, unsigned int timeout_sec)
254 {
255 	struct timeval timeout = { .tv_sec = timeout_sec };
256 	fd_set wfds;
257 	int r, eval;
258 	socklen_t esize = sizeof(eval);
259 
260 	FD_ZERO(&wfds);
261 	FD_SET(fd, &wfds);
262 
263 	r = select(fd + 1, NULL, &wfds, NULL, &timeout);
264 	if (r == 0)
265 		errno = ETIME;
266 	if (r != 1)
267 		return -1;
268 
269 	if (getsockopt(fd, SOL_SOCKET, SO_ERROR, &eval, &esize) < 0)
270 		return -1;
271 	if (eval != 0) {
272 		errno = eval;
273 		return -1;
274 	}
275 
276 	return 0;
277 }
278 
279 static inline int poll_read(int fd, unsigned int timeout_sec)
280 {
281 	struct timeval timeout = { .tv_sec = timeout_sec };
282 	fd_set rfds;
283 	int r;
284 
285 	FD_ZERO(&rfds);
286 	FD_SET(fd, &rfds);
287 
288 	r = select(fd + 1, &rfds, NULL, NULL, &timeout);
289 	if (r == 0)
290 		errno = ETIME;
291 
292 	return r == 1 ? 0 : -1;
293 }
294 
295 static inline int accept_timeout(int fd, struct sockaddr *addr, socklen_t *len,
296 				 unsigned int timeout_sec)
297 {
298 	if (poll_read(fd, timeout_sec))
299 		return -1;
300 
301 	return accept(fd, addr, len);
302 }
303 
304 static inline int recv_timeout(int fd, void *buf, size_t len, int flags,
305 			       unsigned int timeout_sec)
306 {
307 	if (poll_read(fd, timeout_sec))
308 		return -1;
309 
310 	return recv(fd, buf, len, flags);
311 }
312 
313 
314 static inline int create_pair(int family, int sotype, int *p0, int *p1)
315 {
316 	__close_fd int s, c = -1, p = -1;
317 	struct sockaddr_storage addr;
318 	socklen_t len = sizeof(addr);
319 	int err;
320 
321 	s = socket_loopback(family, sotype);
322 	if (s < 0)
323 		return s;
324 
325 	err = xgetsockname(s, sockaddr(&addr), &len);
326 	if (err)
327 		return err;
328 
329 	c = xsocket(family, sotype, 0);
330 	if (c < 0)
331 		return c;
332 
333 	err = connect(c, sockaddr(&addr), len);
334 	if (err) {
335 		if (errno != EINPROGRESS) {
336 			FAIL_ERRNO("connect");
337 			return err;
338 		}
339 
340 		err = poll_connect(c, IO_TIMEOUT_SEC);
341 		if (err) {
342 			FAIL_ERRNO("poll_connect");
343 			return err;
344 		}
345 	}
346 
347 	switch (sotype & SOCK_TYPE_MASK) {
348 	case SOCK_DGRAM:
349 		err = xgetsockname(c, sockaddr(&addr), &len);
350 		if (err)
351 			return err;
352 
353 		err = xconnect(s, sockaddr(&addr), len);
354 		if (err)
355 			return err;
356 
357 		*p0 = take_fd(s);
358 		break;
359 	case SOCK_STREAM:
360 	case SOCK_SEQPACKET:
361 		p = xaccept_nonblock(s, NULL, NULL);
362 		if (p < 0)
363 			return p;
364 
365 		*p0 = take_fd(p);
366 		break;
367 	default:
368 		FAIL("Unsupported socket type %#x", sotype);
369 		return -EOPNOTSUPP;
370 	}
371 
372 	*p1 = take_fd(c);
373 	return 0;
374 }
375 
376 static inline int create_socket_pairs(int family, int sotype, int *c0, int *c1,
377 				      int *p0, int *p1)
378 {
379 	int err;
380 
381 	err = create_pair(family, sotype, c0, p0);
382 	if (err)
383 		return err;
384 
385 	err = create_pair(family, sotype, c1, p1);
386 	if (err) {
387 		close(*c0);
388 		close(*p0);
389 	}
390 
391 	return err;
392 }
393 
394 #endif // __SOCKET_HELPERS__
395