xref: /linux/tools/testing/selftests/bpf/prog_tests/sockmap_helpers.h (revision b00f7f4f8e936da55f2e6c7fd96391ef54c145fc)
1 #ifndef __SOCKMAP_HELPERS__
2 #define __SOCKMAP_HELPERS__
3 
4 #include <linux/vm_sockets.h>
5 
6 /* include/linux/net.h */
7 #define SOCK_TYPE_MASK 0xf
8 
9 #define IO_TIMEOUT_SEC 30
10 #define MAX_STRERR_LEN 256
11 #define MAX_TEST_NAME 80
12 
13 /* workaround for older vm_sockets.h */
14 #ifndef VMADDR_CID_LOCAL
15 #define VMADDR_CID_LOCAL 1
16 #endif
17 
18 #define __always_unused	__attribute__((__unused__))
19 
20 /* include/linux/cleanup.h */
21 #define __get_and_null(p, nullvalue)                                           \
22 	({                                                                     \
23 		__auto_type __ptr = &(p);                                      \
24 		__auto_type __val = *__ptr;                                    \
25 		*__ptr = nullvalue;                                            \
26 		__val;                                                         \
27 	})
28 
29 #define take_fd(fd) __get_and_null(fd, -EBADF)
30 
31 #define _FAIL(errnum, fmt...)                                                  \
32 	({                                                                     \
33 		error_at_line(0, (errnum), __func__, __LINE__, fmt);           \
34 		CHECK_FAIL(true);                                              \
35 	})
36 #define FAIL(fmt...) _FAIL(0, fmt)
37 #define FAIL_ERRNO(fmt...) _FAIL(errno, fmt)
38 #define FAIL_LIBBPF(err, msg)                                                  \
39 	({                                                                     \
40 		char __buf[MAX_STRERR_LEN];                                    \
41 		libbpf_strerror((err), __buf, sizeof(__buf));                  \
42 		FAIL("%s: %s", (msg), __buf);                                  \
43 	})
44 
45 /* Wrappers that fail the test on error and report it. */
46 
47 #define xaccept_nonblock(fd, addr, len)                                        \
48 	({                                                                     \
49 		int __ret =                                                    \
50 			accept_timeout((fd), (addr), (len), IO_TIMEOUT_SEC);   \
51 		if (__ret == -1)                                               \
52 			FAIL_ERRNO("accept");                                  \
53 		__ret;                                                         \
54 	})
55 
56 #define xbind(fd, addr, len)                                                   \
57 	({                                                                     \
58 		int __ret = bind((fd), (addr), (len));                         \
59 		if (__ret == -1)                                               \
60 			FAIL_ERRNO("bind");                                    \
61 		__ret;                                                         \
62 	})
63 
64 #define xclose(fd)                                                             \
65 	({                                                                     \
66 		int __ret = close((fd));                                       \
67 		if (__ret == -1)                                               \
68 			FAIL_ERRNO("close");                                   \
69 		__ret;                                                         \
70 	})
71 
72 #define xconnect(fd, addr, len)                                                \
73 	({                                                                     \
74 		int __ret = connect((fd), (addr), (len));                      \
75 		if (__ret == -1)                                               \
76 			FAIL_ERRNO("connect");                                 \
77 		__ret;                                                         \
78 	})
79 
80 #define xgetsockname(fd, addr, len)                                            \
81 	({                                                                     \
82 		int __ret = getsockname((fd), (addr), (len));                  \
83 		if (__ret == -1)                                               \
84 			FAIL_ERRNO("getsockname");                             \
85 		__ret;                                                         \
86 	})
87 
88 #define xgetsockopt(fd, level, name, val, len)                                 \
89 	({                                                                     \
90 		int __ret = getsockopt((fd), (level), (name), (val), (len));   \
91 		if (__ret == -1)                                               \
92 			FAIL_ERRNO("getsockopt(" #name ")");                   \
93 		__ret;                                                         \
94 	})
95 
96 #define xlisten(fd, backlog)                                                   \
97 	({                                                                     \
98 		int __ret = listen((fd), (backlog));                           \
99 		if (__ret == -1)                                               \
100 			FAIL_ERRNO("listen");                                  \
101 		__ret;                                                         \
102 	})
103 
104 #define xsetsockopt(fd, level, name, val, len)                                 \
105 	({                                                                     \
106 		int __ret = setsockopt((fd), (level), (name), (val), (len));   \
107 		if (__ret == -1)                                               \
108 			FAIL_ERRNO("setsockopt(" #name ")");                   \
109 		__ret;                                                         \
110 	})
111 
112 #define xsend(fd, buf, len, flags)                                             \
113 	({                                                                     \
114 		ssize_t __ret = send((fd), (buf), (len), (flags));             \
115 		if (__ret == -1)                                               \
116 			FAIL_ERRNO("send");                                    \
117 		__ret;                                                         \
118 	})
119 
120 #define xrecv_nonblock(fd, buf, len, flags)                                    \
121 	({                                                                     \
122 		ssize_t __ret = recv_timeout((fd), (buf), (len), (flags),      \
123 					     IO_TIMEOUT_SEC);                  \
124 		if (__ret == -1)                                               \
125 			FAIL_ERRNO("recv");                                    \
126 		__ret;                                                         \
127 	})
128 
129 #define xsocket(family, sotype, flags)                                         \
130 	({                                                                     \
131 		int __ret = socket(family, sotype, flags);                     \
132 		if (__ret == -1)                                               \
133 			FAIL_ERRNO("socket");                                  \
134 		__ret;                                                         \
135 	})
136 
137 #define xbpf_map_delete_elem(fd, key)                                          \
138 	({                                                                     \
139 		int __ret = bpf_map_delete_elem((fd), (key));                  \
140 		if (__ret < 0)                                               \
141 			FAIL_ERRNO("map_delete");                              \
142 		__ret;                                                         \
143 	})
144 
145 #define xbpf_map_lookup_elem(fd, key, val)                                     \
146 	({                                                                     \
147 		int __ret = bpf_map_lookup_elem((fd), (key), (val));           \
148 		if (__ret < 0)                                               \
149 			FAIL_ERRNO("map_lookup");                              \
150 		__ret;                                                         \
151 	})
152 
153 #define xbpf_map_update_elem(fd, key, val, flags)                              \
154 	({                                                                     \
155 		int __ret = bpf_map_update_elem((fd), (key), (val), (flags));  \
156 		if (__ret < 0)                                               \
157 			FAIL_ERRNO("map_update");                              \
158 		__ret;                                                         \
159 	})
160 
161 #define xbpf_prog_attach(prog, target, type, flags)                            \
162 	({                                                                     \
163 		int __ret =                                                    \
164 			bpf_prog_attach((prog), (target), (type), (flags));    \
165 		if (__ret < 0)                                               \
166 			FAIL_ERRNO("prog_attach(" #type ")");                  \
167 		__ret;                                                         \
168 	})
169 
170 #define xbpf_prog_detach2(prog, target, type)                                  \
171 	({                                                                     \
172 		int __ret = bpf_prog_detach2((prog), (target), (type));        \
173 		if (__ret < 0)                                               \
174 			FAIL_ERRNO("prog_detach2(" #type ")");                 \
175 		__ret;                                                         \
176 	})
177 
178 #define xpthread_create(thread, attr, func, arg)                               \
179 	({                                                                     \
180 		int __ret = pthread_create((thread), (attr), (func), (arg));   \
181 		errno = __ret;                                                 \
182 		if (__ret)                                                     \
183 			FAIL_ERRNO("pthread_create");                          \
184 		__ret;                                                         \
185 	})
186 
187 #define xpthread_join(thread, retval)                                          \
188 	({                                                                     \
189 		int __ret = pthread_join((thread), (retval));                  \
190 		errno = __ret;                                                 \
191 		if (__ret)                                                     \
192 			FAIL_ERRNO("pthread_join");                            \
193 		__ret;                                                         \
194 	})
195 
196 static inline void close_fd(int *fd)
197 {
198 	if (*fd >= 0)
199 		xclose(*fd);
200 }
201 
202 #define __close_fd __attribute__((cleanup(close_fd)))
203 
204 static inline int poll_connect(int fd, unsigned int timeout_sec)
205 {
206 	struct timeval timeout = { .tv_sec = timeout_sec };
207 	fd_set wfds;
208 	int r, eval;
209 	socklen_t esize = sizeof(eval);
210 
211 	FD_ZERO(&wfds);
212 	FD_SET(fd, &wfds);
213 
214 	r = select(fd + 1, NULL, &wfds, NULL, &timeout);
215 	if (r == 0)
216 		errno = ETIME;
217 	if (r != 1)
218 		return -1;
219 
220 	if (getsockopt(fd, SOL_SOCKET, SO_ERROR, &eval, &esize) < 0)
221 		return -1;
222 	if (eval != 0) {
223 		errno = eval;
224 		return -1;
225 	}
226 
227 	return 0;
228 }
229 
230 static inline int poll_read(int fd, unsigned int timeout_sec)
231 {
232 	struct timeval timeout = { .tv_sec = timeout_sec };
233 	fd_set rfds;
234 	int r;
235 
236 	FD_ZERO(&rfds);
237 	FD_SET(fd, &rfds);
238 
239 	r = select(fd + 1, &rfds, NULL, NULL, &timeout);
240 	if (r == 0)
241 		errno = ETIME;
242 
243 	return r == 1 ? 0 : -1;
244 }
245 
246 static inline int accept_timeout(int fd, struct sockaddr *addr, socklen_t *len,
247 				 unsigned int timeout_sec)
248 {
249 	if (poll_read(fd, timeout_sec))
250 		return -1;
251 
252 	return accept(fd, addr, len);
253 }
254 
255 static inline int recv_timeout(int fd, void *buf, size_t len, int flags,
256 			       unsigned int timeout_sec)
257 {
258 	if (poll_read(fd, timeout_sec))
259 		return -1;
260 
261 	return recv(fd, buf, len, flags);
262 }
263 
264 static inline void init_addr_loopback4(struct sockaddr_storage *ss,
265 				       socklen_t *len)
266 {
267 	struct sockaddr_in *addr4 = memset(ss, 0, sizeof(*ss));
268 
269 	addr4->sin_family = AF_INET;
270 	addr4->sin_port = 0;
271 	addr4->sin_addr.s_addr = htonl(INADDR_LOOPBACK);
272 	*len = sizeof(*addr4);
273 }
274 
275 static inline void init_addr_loopback6(struct sockaddr_storage *ss,
276 				       socklen_t *len)
277 {
278 	struct sockaddr_in6 *addr6 = memset(ss, 0, sizeof(*ss));
279 
280 	addr6->sin6_family = AF_INET6;
281 	addr6->sin6_port = 0;
282 	addr6->sin6_addr = in6addr_loopback;
283 	*len = sizeof(*addr6);
284 }
285 
286 static inline void init_addr_loopback_vsock(struct sockaddr_storage *ss,
287 					    socklen_t *len)
288 {
289 	struct sockaddr_vm *addr = memset(ss, 0, sizeof(*ss));
290 
291 	addr->svm_family = AF_VSOCK;
292 	addr->svm_port = VMADDR_PORT_ANY;
293 	addr->svm_cid = VMADDR_CID_LOCAL;
294 	*len = sizeof(*addr);
295 }
296 
297 static inline void init_addr_loopback(int family, struct sockaddr_storage *ss,
298 				      socklen_t *len)
299 {
300 	switch (family) {
301 	case AF_INET:
302 		init_addr_loopback4(ss, len);
303 		return;
304 	case AF_INET6:
305 		init_addr_loopback6(ss, len);
306 		return;
307 	case AF_VSOCK:
308 		init_addr_loopback_vsock(ss, len);
309 		return;
310 	default:
311 		FAIL("unsupported address family %d", family);
312 	}
313 }
314 
315 static inline struct sockaddr *sockaddr(struct sockaddr_storage *ss)
316 {
317 	return (struct sockaddr *)ss;
318 }
319 
320 static inline int add_to_sockmap(int sock_mapfd, int fd1, int fd2)
321 {
322 	u64 value;
323 	u32 key;
324 	int err;
325 
326 	key = 0;
327 	value = fd1;
328 	err = xbpf_map_update_elem(sock_mapfd, &key, &value, BPF_NOEXIST);
329 	if (err)
330 		return err;
331 
332 	key = 1;
333 	value = fd2;
334 	return xbpf_map_update_elem(sock_mapfd, &key, &value, BPF_NOEXIST);
335 }
336 
337 static inline int enable_reuseport(int s, int progfd)
338 {
339 	int err, one = 1;
340 
341 	err = xsetsockopt(s, SOL_SOCKET, SO_REUSEPORT, &one, sizeof(one));
342 	if (err)
343 		return -1;
344 	err = xsetsockopt(s, SOL_SOCKET, SO_ATTACH_REUSEPORT_EBPF, &progfd,
345 			  sizeof(progfd));
346 	if (err)
347 		return -1;
348 
349 	return 0;
350 }
351 
352 static inline int socket_loopback_reuseport(int family, int sotype, int progfd)
353 {
354 	struct sockaddr_storage addr;
355 	socklen_t len = 0;
356 	int err, s;
357 
358 	init_addr_loopback(family, &addr, &len);
359 
360 	s = xsocket(family, sotype, 0);
361 	if (s == -1)
362 		return -1;
363 
364 	if (progfd >= 0)
365 		enable_reuseport(s, progfd);
366 
367 	err = xbind(s, sockaddr(&addr), len);
368 	if (err)
369 		goto close;
370 
371 	if (sotype & SOCK_DGRAM)
372 		return s;
373 
374 	err = xlisten(s, SOMAXCONN);
375 	if (err)
376 		goto close;
377 
378 	return s;
379 close:
380 	xclose(s);
381 	return -1;
382 }
383 
384 static inline int socket_loopback(int family, int sotype)
385 {
386 	return socket_loopback_reuseport(family, sotype, -1);
387 }
388 
389 static inline int create_pair(int family, int sotype, int *p0, int *p1)
390 {
391 	__close_fd int s, c = -1, p = -1;
392 	struct sockaddr_storage addr;
393 	socklen_t len = sizeof(addr);
394 	int err;
395 
396 	s = socket_loopback(family, sotype);
397 	if (s < 0)
398 		return s;
399 
400 	err = xgetsockname(s, sockaddr(&addr), &len);
401 	if (err)
402 		return err;
403 
404 	c = xsocket(family, sotype, 0);
405 	if (c < 0)
406 		return c;
407 
408 	err = connect(c, sockaddr(&addr), len);
409 	if (err) {
410 		if (errno != EINPROGRESS) {
411 			FAIL_ERRNO("connect");
412 			return err;
413 		}
414 
415 		err = poll_connect(c, IO_TIMEOUT_SEC);
416 		if (err) {
417 			FAIL_ERRNO("poll_connect");
418 			return err;
419 		}
420 	}
421 
422 	switch (sotype & SOCK_TYPE_MASK) {
423 	case SOCK_DGRAM:
424 		err = xgetsockname(c, sockaddr(&addr), &len);
425 		if (err)
426 			return err;
427 
428 		err = xconnect(s, sockaddr(&addr), len);
429 		if (err)
430 			return err;
431 
432 		*p0 = take_fd(s);
433 		break;
434 	case SOCK_STREAM:
435 	case SOCK_SEQPACKET:
436 		p = xaccept_nonblock(s, NULL, NULL);
437 		if (p < 0)
438 			return p;
439 
440 		*p0 = take_fd(p);
441 		break;
442 	default:
443 		FAIL("Unsupported socket type %#x", sotype);
444 		return -EOPNOTSUPP;
445 	}
446 
447 	*p1 = take_fd(c);
448 	return 0;
449 }
450 
451 static inline int create_socket_pairs(int family, int sotype, int *c0, int *c1,
452 				      int *p0, int *p1)
453 {
454 	int err;
455 
456 	err = create_pair(family, sotype, c0, p0);
457 	if (err)
458 		return err;
459 
460 	err = create_pair(family, sotype, c1, p1);
461 	if (err) {
462 		close(*c0);
463 		close(*p0);
464 	}
465 
466 	return err;
467 }
468 
469 #endif // __SOCKMAP_HELPERS__
470