xref: /linux/tools/testing/selftests/bpf/prog_tests/socket_helpers.h (revision 9c707ba99f1b638e32724691b18fd1429e23b7f4)
1  /* SPDX-License-Identifier: GPL-2.0 */
2  
3  #ifndef __SOCKET_HELPERS__
4  #define __SOCKET_HELPERS__
5  
6  #include <linux/vm_sockets.h>
7  
/* Mask selecting the socket type from a socket() type argument, dropping
 * flags such as SOCK_NONBLOCK; same value as include/linux/net.h. */
#define SOCK_TYPE_MASK 0xf

/* Timeout in seconds used by xaccept_nonblock(), xrecv_nonblock() and
 * create_pair(). */
#define IO_TIMEOUT_SEC 30
/* Size of the message buffer handed to libbpf_strerror() by FAIL_LIBBPF(). */
#define MAX_STRERR_LEN 256

/* Older <linux/vm_sockets.h> do not define VMADDR_CID_LOCAL; provide it. */
#ifndef VMADDR_CID_LOCAL
#define VMADDR_CID_LOCAL 1
#endif
18  
/*
 * Modelled on include/linux/cleanup.h: store @nullvalue into the lvalue @p
 * and evaluate to the value @p held before, moving ownership out of @p.
 */
#define __get_and_null(p, nullvalue)                                           \
	({                                                                     \
		__auto_type __gan_slot = &(p);                                 \
		__auto_type __gan_old = *__gan_slot;                           \
		*__gan_slot = (nullvalue);                                     \
		__gan_old;                                                     \
	})

/* Hand off the descriptor in @fd: yields it and leaves -EBADF behind, so a
 * later close_fd() on the variable sees a negative value and does nothing. */
#define take_fd(fd) __get_and_null(fd, -EBADF)
29  
30  /* Wrappers that fail the test on error and report it. */
31  
/*
 * Failure reporters. Each one prints a message through error_at_line() with
 * status 0, so the process keeps running. The calling function name is passed
 * as error_at_line()'s file-name argument, together with __LINE__. Each then
 * marks the current test failed via CHECK_FAIL(true).
 *   FAIL(fmt, ...)        - printf-style message only
 *   FAIL_ERRNO(fmt, ...)  - message plus the description of the current errno
 *   FAIL_LIBBPF(err, msg) - "msg: <libbpf_strerror(err) text>"
 */
#define _FAIL(errnum, ...)                                                     \
	({                                                                     \
		error_at_line(0, (errnum), __func__, __LINE__, __VA_ARGS__);   \
		CHECK_FAIL(true);                                              \
	})
#define FAIL(...) _FAIL(0, __VA_ARGS__)
#define FAIL_ERRNO(...) _FAIL(errno, __VA_ARGS__)
#define FAIL_LIBBPF(err, msg)                                                  \
	({                                                                     \
		char __strerr_buf[MAX_STRERR_LEN];                             \
		libbpf_strerror((err), __strerr_buf, sizeof(__strerr_buf));    \
		FAIL("%s: %s", (msg), __strerr_buf);                           \
	})
45  
46  
/* Accept through accept_timeout(), waiting at most IO_TIMEOUT_SEC seconds.
 * A -1 result (timeout included) is reported via FAIL_ERRNO("accept").
 * Evaluates to the accepted fd or -1. */
#define xaccept_nonblock(fd, addr, len)                                        \
	({                                                                     \
		int __acc_fd;                                                  \
		__acc_fd = accept_timeout((fd), (addr), (len), IO_TIMEOUT_SEC);\
		if (__acc_fd == -1)                                            \
			FAIL_ERRNO("accept");                                  \
		__acc_fd;                                                      \
	})
55  
/* bind() that reports a -1 result via FAIL_ERRNO("bind"); evaluates to
 * bind()'s return value. */
#define xbind(fd, addr, len)                                                   \
	({                                                                     \
		int __bind_rc;                                                 \
		__bind_rc = bind((fd), (addr), (len));                         \
		if (__bind_rc == -1)                                           \
			FAIL_ERRNO("bind");                                    \
		__bind_rc;                                                     \
	})
63  
/* close() that reports a -1 result via FAIL_ERRNO("close"); evaluates to
 * close()'s return value. */
#define xclose(fd)                                                             \
	({                                                                     \
		int __close_rc;                                                \
		__close_rc = close((fd));                                      \
		if (__close_rc == -1)                                          \
			FAIL_ERRNO("close");                                   \
		__close_rc;                                                    \
	})
71  
/* connect() that reports a -1 result via FAIL_ERRNO("connect"); evaluates
 * to connect()'s return value. */
#define xconnect(fd, addr, len)                                                \
	({                                                                     \
		int __conn_rc;                                                 \
		__conn_rc = connect((fd), (addr), (len));                      \
		if (__conn_rc == -1)                                           \
			FAIL_ERRNO("connect");                                 \
		__conn_rc;                                                     \
	})
79  
/* getsockname() that reports a -1 result via FAIL_ERRNO("getsockname");
 * evaluates to getsockname()'s return value. */
#define xgetsockname(fd, addr, len)                                            \
	({                                                                     \
		int __gsn_rc;                                                  \
		__gsn_rc = getsockname((fd), (addr), (len));                   \
		if (__gsn_rc == -1)                                            \
			FAIL_ERRNO("getsockname");                             \
		__gsn_rc;                                                      \
	})
87  
/* getsockopt() that reports a -1 result via FAIL_ERRNO, naming the option
 * (e.g. "getsockopt(SO_ERROR)"); evaluates to getsockopt()'s return value. */
#define xgetsockopt(fd, level, name, val, len)                                 \
	({                                                                     \
		int __gso_rc;                                                  \
		__gso_rc = getsockopt((fd), (level), (name), (val), (len));    \
		if (__gso_rc == -1)                                            \
			FAIL_ERRNO("getsockopt(" #name ")");                   \
		__gso_rc;                                                      \
	})
95  
/* listen() that reports a -1 result via FAIL_ERRNO("listen"); evaluates to
 * listen()'s return value. */
#define xlisten(fd, backlog)                                                   \
	({                                                                     \
		int __listen_rc;                                               \
		__listen_rc = listen((fd), (backlog));                         \
		if (__listen_rc == -1)                                         \
			FAIL_ERRNO("listen");                                  \
		__listen_rc;                                                   \
	})
103  
/* setsockopt() that reports a -1 result via FAIL_ERRNO, naming the option
 * (e.g. "setsockopt(SO_REUSEPORT)"); evaluates to setsockopt()'s return
 * value. */
#define xsetsockopt(fd, level, name, val, len)                                 \
	({                                                                     \
		int __sso_rc;                                                  \
		__sso_rc = setsockopt((fd), (level), (name), (val), (len));    \
		if (__sso_rc == -1)                                            \
			FAIL_ERRNO("setsockopt(" #name ")");                   \
		__sso_rc;                                                      \
	})
111  
/* send() that reports a -1 result via FAIL_ERRNO("send"); evaluates to the
 * ssize_t returned by send(). */
#define xsend(fd, buf, len, flags)                                             \
	({                                                                     \
		ssize_t __sent;                                                \
		__sent = send((fd), (buf), (len), (flags));                    \
		if (__sent == -1)                                              \
			FAIL_ERRNO("send");                                    \
		__sent;                                                        \
	})
119  
/* Receive through recv_timeout(), waiting at most IO_TIMEOUT_SEC seconds.
 * A -1 result (timeout included) is reported via FAIL_ERRNO("recv").
 * Evaluates to the byte count or -1. */
#define xrecv_nonblock(fd, buf, len, flags)                                    \
	({                                                                     \
		ssize_t __rcvd;                                                \
		__rcvd = recv_timeout((fd), (buf), (len), (flags),             \
				      IO_TIMEOUT_SEC);                         \
		if (__rcvd == -1)                                              \
			FAIL_ERRNO("recv");                                    \
		__rcvd;                                                        \
	})
128  
/* socket() that reports a -1 result via FAIL_ERRNO("socket"); evaluates to
 * the new descriptor or -1. Arguments are parenthesized like the other
 * x*() wrappers. */
#define xsocket(family, sotype, flags)                                         \
	({                                                                     \
		int __sock_fd;                                                 \
		__sock_fd = socket((family), (sotype), (flags));               \
		if (__sock_fd == -1)                                           \
			FAIL_ERRNO("socket");                                  \
		__sock_fd;                                                     \
	})
136  
/*
 * Cleanup handler behind __close_fd. Closes *fd through xclose(), unless the
 * value is negative: never opened (-1) or already handed off with take_fd()
 * (-EBADF).
 */
static inline void close_fd(int *fd)
{
	int cur = *fd;

	if (cur < 0)
		return;
	xclose(cur);
}

/* Variable attribute: run close_fd() on the int fd when it leaves scope. */
#define __close_fd __attribute__((__cleanup__(close_fd)))
144  
/* View a sockaddr_storage through the generic struct sockaddr pointer taken
 * by bind(), connect() and getsockname(). */
static inline struct sockaddr *sockaddr(struct sockaddr_storage *ss)
{
	struct sockaddr *generic = (struct sockaddr *)(void *)ss;

	return generic;
}
149  
/*
 * Zero *ss and fill it as 127.0.0.1 with port 0, so bind() picks a port.
 * Sets *len to sizeof(struct sockaddr_in).
 */
static inline void init_addr_loopback4(struct sockaddr_storage *ss,
				       socklen_t *len)
{
	struct sockaddr_in *sin;

	memset(ss, 0, sizeof(*ss));
	sin = (struct sockaddr_in *)ss;
	sin->sin_addr.s_addr = htonl(INADDR_LOOPBACK);
	sin->sin_port = 0;
	sin->sin_family = AF_INET;
	*len = (socklen_t)sizeof(struct sockaddr_in);
}
160  
/*
 * Zero *ss and fill it as [::1] with port 0, so bind() picks a port.
 * Sets *len to sizeof(struct sockaddr_in6).
 */
static inline void init_addr_loopback6(struct sockaddr_storage *ss,
				       socklen_t *len)
{
	struct sockaddr_in6 *sin6;

	memset(ss, 0, sizeof(*ss));
	sin6 = (struct sockaddr_in6 *)ss;
	sin6->sin6_addr = in6addr_loopback;
	sin6->sin6_port = 0;
	sin6->sin6_family = AF_INET6;
	*len = (socklen_t)sizeof(struct sockaddr_in6);
}
171  
/*
 * Zero *ss and fill it as a vsock address with CID VMADDR_CID_LOCAL and port
 * VMADDR_PORT_ANY. Sets *len to sizeof(struct sockaddr_vm).
 */
static inline void init_addr_loopback_vsock(struct sockaddr_storage *ss,
					    socklen_t *len)
{
	struct sockaddr_vm *svm;

	memset(ss, 0, sizeof(*ss));
	svm = (struct sockaddr_vm *)ss;
	svm->svm_cid = VMADDR_CID_LOCAL;
	svm->svm_port = VMADDR_PORT_ANY;
	svm->svm_family = AF_VSOCK;
	*len = (socklen_t)sizeof(struct sockaddr_vm);
}
182  
/*
 * Fill *ss and *len with the loopback address of @family (AF_INET, AF_INET6
 * or AF_VSOCK). Any other family is reported with FAIL() and leaves *ss and
 * *len untouched.
 */
static inline void init_addr_loopback(int family, struct sockaddr_storage *ss,
				      socklen_t *len)
{
	if (family == AF_INET)
		init_addr_loopback4(ss, len);
	else if (family == AF_INET6)
		init_addr_loopback6(ss, len);
	else if (family == AF_VSOCK)
		init_addr_loopback_vsock(ss, len);
	else
		FAIL("unsupported address family %d", family);
}
200  
enable_reuseport(int s,int progfd)201  static inline int enable_reuseport(int s, int progfd)
202  {
203  	int err, one = 1;
204  
205  	err = xsetsockopt(s, SOL_SOCKET, SO_REUSEPORT, &one, sizeof(one));
206  	if (err)
207  		return -1;
208  	err = xsetsockopt(s, SOL_SOCKET, SO_ATTACH_REUSEPORT_EBPF, &progfd,
209  			  sizeof(progfd));
210  	if (err)
211  		return -1;
212  
213  	return 0;
214  }
215  
socket_loopback_reuseport(int family,int sotype,int progfd)216  static inline int socket_loopback_reuseport(int family, int sotype, int progfd)
217  {
218  	struct sockaddr_storage addr;
219  	socklen_t len = 0;
220  	int err, s;
221  
222  	init_addr_loopback(family, &addr, &len);
223  
224  	s = xsocket(family, sotype, 0);
225  	if (s == -1)
226  		return -1;
227  
228  	if (progfd >= 0)
229  		enable_reuseport(s, progfd);
230  
231  	err = xbind(s, sockaddr(&addr), len);
232  	if (err)
233  		goto close;
234  
235  	if (sotype & SOCK_DGRAM)
236  		return s;
237  
238  	err = xlisten(s, SOMAXCONN);
239  	if (err)
240  		goto close;
241  
242  	return s;
243  close:
244  	xclose(s);
245  	return -1;
246  }
247  
/* Loopback socket with no reuseport program attached; equivalent to
 * socket_loopback_reuseport() with progfd == -1. */
static inline int socket_loopback(int family, int sotype)
{
	const int no_reuseport_prog = -1;

	return socket_loopback_reuseport(family, sotype, no_reuseport_prog);
}
252  
/*
 * Wait up to @timeout_sec seconds for @fd to become writable, then read its
 * pending error with SO_ERROR. This is used after a connect() that reported
 * EINPROGRESS.
 *
 * Returns 0 when SO_ERROR is 0. Otherwise returns -1 with errno set to one of:
 * ETIME on timeout, the SO_ERROR value, EINVAL if @fd cannot be stored in an
 * fd_set, or whatever select()/getsockopt() reported.
 */
static inline int poll_connect(int fd, unsigned int timeout_sec)
{
	struct timeval timeout = { .tv_sec = timeout_sec };
	fd_set wfds;
	int r, eval;
	socklen_t esize = sizeof(eval);

	/* FD_SET() on a descriptor outside [0, FD_SETSIZE) overruns the set. */
	if (fd < 0 || fd >= FD_SETSIZE) {
		errno = EINVAL;
		return -1;
	}

	FD_ZERO(&wfds);
	FD_SET(fd, &wfds);

	r = select(fd + 1, NULL, &wfds, NULL, &timeout);
	if (r == 0)
		errno = ETIME;
	if (r != 1)
		return -1;

	if (getsockopt(fd, SOL_SOCKET, SO_ERROR, &eval, &esize) < 0)
		return -1;
	if (eval != 0) {
		errno = eval;
		return -1;
	}

	return 0;
}
278  
/*
 * Wait up to @timeout_sec seconds for @fd to become readable.
 *
 * Returns 0 when select() reports @fd readable. Otherwise returns -1 with
 * errno set to one of: ETIME on timeout, EINVAL if @fd cannot be stored in an
 * fd_set, or select()'s own errno.
 */
static inline int poll_read(int fd, unsigned int timeout_sec)
{
	struct timeval timeout = { .tv_sec = timeout_sec };
	fd_set rfds;
	int r;

	/* FD_SET() on a descriptor outside [0, FD_SETSIZE) overruns the set. */
	if (fd < 0 || fd >= FD_SETSIZE) {
		errno = EINVAL;
		return -1;
	}

	FD_ZERO(&rfds);
	FD_SET(fd, &rfds);

	r = select(fd + 1, &rfds, NULL, NULL, &timeout);
	if (r == 0)
		errno = ETIME;

	return r == 1 ? 0 : -1;
}
294  
/*
 * accept() on @fd (presumably a listening socket), but only once poll_read()
 * reports it readable within @timeout_sec seconds.
 *
 * Returns accept()'s result, or -1 if the wait fails (errno ETIME on timeout).
 */
static inline int accept_timeout(int fd, struct sockaddr *addr, socklen_t *len,
				 unsigned int timeout_sec)
{
	return poll_read(fd, timeout_sec) ? -1 : accept(fd, addr, len);
}
303  
/*
 * recv() on @fd, but only once poll_read() reports it readable within
 * @timeout_sec seconds.
 *
 * Returns recv()'s result narrowed to int, or -1 if the wait fails (errno
 * ETIME on timeout).
 */
static inline int recv_timeout(int fd, void *buf, size_t len, int flags,
			       unsigned int timeout_sec)
{
	int not_ready = poll_read(fd, timeout_sec);

	return not_ready ? -1 : recv(fd, buf, len, flags);
}
312  
313  
create_pair(int family,int sotype,int * p0,int * p1)314  static inline int create_pair(int family, int sotype, int *p0, int *p1)
315  {
316  	__close_fd int s, c = -1, p = -1;
317  	struct sockaddr_storage addr;
318  	socklen_t len = sizeof(addr);
319  	int err;
320  
321  	s = socket_loopback(family, sotype);
322  	if (s < 0)
323  		return s;
324  
325  	err = xgetsockname(s, sockaddr(&addr), &len);
326  	if (err)
327  		return err;
328  
329  	c = xsocket(family, sotype, 0);
330  	if (c < 0)
331  		return c;
332  
333  	err = connect(c, sockaddr(&addr), len);
334  	if (err) {
335  		if (errno != EINPROGRESS) {
336  			FAIL_ERRNO("connect");
337  			return err;
338  		}
339  
340  		err = poll_connect(c, IO_TIMEOUT_SEC);
341  		if (err) {
342  			FAIL_ERRNO("poll_connect");
343  			return err;
344  		}
345  	}
346  
347  	switch (sotype & SOCK_TYPE_MASK) {
348  	case SOCK_DGRAM:
349  		err = xgetsockname(c, sockaddr(&addr), &len);
350  		if (err)
351  			return err;
352  
353  		err = xconnect(s, sockaddr(&addr), len);
354  		if (err)
355  			return err;
356  
357  		*p0 = take_fd(s);
358  		break;
359  	case SOCK_STREAM:
360  	case SOCK_SEQPACKET:
361  		p = xaccept_nonblock(s, NULL, NULL);
362  		if (p < 0)
363  			return p;
364  
365  		*p0 = take_fd(p);
366  		break;
367  	default:
368  		FAIL("Unsupported socket type %#x", sotype);
369  		return -EOPNOTSUPP;
370  	}
371  
372  	*p1 = take_fd(c);
373  	return 0;
374  }
375  
/*
 * Build two independent pairs with create_pair(): one into (*c0, *p0) and one
 * into (*c1, *p1).
 *
 * Returns 0 on success, or the error from the failing create_pair() call.
 * If the second pair fails, the first pair is closed and *c0/*p0 are reset to
 * -1. This stops a caller from closing those descriptor numbers a second time,
 * since they may already have been reused.
 */
static inline int create_socket_pairs(int family, int sotype, int *c0, int *c1,
				      int *p0, int *p1)
{
	int err;

	err = create_pair(family, sotype, c0, p0);
	if (err)
		return err;

	err = create_pair(family, sotype, c1, p1);
	if (err) {
		close(*c0);
		close(*p0);
		*c0 = -1;
		*p0 = -1;
	}

	return err;
}
393  
394  #endif // __SOCKET_HELPERS__
395