1 #ifndef __SOCKMAP_HELPERS__
2 #define __SOCKMAP_HELPERS__
3
4 #include <linux/vm_sockets.h>
5
/* include/linux/net.h */
/* Selects the base socket type from a sotype value, stripping any flag bits
 * OR-ed into it (create_pair() switches on sotype & SOCK_TYPE_MASK).
 */
#define SOCK_TYPE_MASK 0xf

/* Seconds the *_nonblock wrappers and create_pair() wait for I/O readiness. */
#define IO_TIMEOUT_SEC 30
/* Size of the buffer FAIL_LIBBPF() formats the libbpf error string into. */
#define MAX_STRERR_LEN 256
/* NOTE(review): not referenced in this header; presumably used by includers. */
#define MAX_TEST_NAME 80

/* workaround for older vm_sockets.h */
#ifndef VMADDR_CID_LOCAL
#define VMADDR_CID_LOCAL 1
#endif

/* Mark a variable or parameter as intentionally unused. */
#define __always_unused __attribute__((__unused__))
19
/* include/linux/cleanup.h */
/*
 * __get_and_null() - read an lvalue and overwrite it with a sentinel.
 * @p:         lvalue to read; evaluated once (only its address is taken).
 * @nullvalue: value stored into @p after it has been read.
 *
 * GNU statement expression that evaluates to the value @p held before it
 * was overwritten.
 */
#define __get_and_null(p, nullvalue) \
	({ \
		__auto_type __ptr = &(p); \
		__auto_type __val = *__ptr; \
		*__ptr = nullvalue; \
		__val; \
	})

/*
 * take_fd() - move a file descriptor out of @fd.
 *
 * Evaluates to the descriptor held in @fd and leaves -EBADF behind, so a
 * __close_fd cleanup on @fd will not close it (close_fd() ignores negative
 * values).
 */
#define take_fd(fd) __get_and_null(fd, -EBADF)
30
/*
 * _FAIL() - print a diagnostic and mark the running test as failed.
 * @errnum: errno value whose description error_at_line() appends to the
 *          message; 0 appends nothing.
 * @fmt:    printf-style format string followed by its arguments.
 *
 * A status of 0 makes error_at_line() print without exiting. __func__ is
 * passed in the file-name argument, so the message is tagged with the
 * calling function and line instead of the source file.
 * NOTE(review): CHECK_FAIL() is not defined in this header; presumably it
 * comes from the test harness that includes it.
 */
#define _FAIL(errnum, fmt...) \
	({ \
		error_at_line(0, (errnum), __func__, __LINE__, fmt); \
		CHECK_FAIL(true); \
	})
/* Report a failure without an errno description. */
#define FAIL(fmt...) _FAIL(0, fmt)
/* Report a failure followed by the description of the current errno. */
#define FAIL_ERRNO(fmt...) _FAIL(errno, fmt)
/* Report a failure as "<msg>: <libbpf_strerror() text for err>". */
#define FAIL_LIBBPF(err, msg) \
	({ \
		char __buf[MAX_STRERR_LEN]; \
		libbpf_strerror((err), __buf, sizeof(__buf)); \
		FAIL("%s: %s", (msg), __buf); \
	})
44
/* Wrappers that fail the test on error and report it.
 *
 * Each wrapper is a GNU statement expression that evaluates the underlying
 * call exactly once, reports a failure through FAIL_ERRNO() when the call
 * returns -1, and evaluates to the call's return value so callers can still
 * branch on it.
 */

/* accept() after waiting up to IO_TIMEOUT_SEC for a pending connection. */
#define xaccept_nonblock(fd, addr, len) \
	({ \
		int __ret = \
			accept_timeout((fd), (addr), (len), IO_TIMEOUT_SEC); \
		if (__ret == -1) \
			FAIL_ERRNO("accept"); \
		__ret; \
	})

#define xbind(fd, addr, len) \
	({ \
		int __ret = bind((fd), (addr), (len)); \
		if (__ret == -1) \
			FAIL_ERRNO("bind"); \
		__ret; \
	})

#define xclose(fd) \
	({ \
		int __ret = close((fd)); \
		if (__ret == -1) \
			FAIL_ERRNO("close"); \
		__ret; \
	})

#define xconnect(fd, addr, len) \
	({ \
		int __ret = connect((fd), (addr), (len)); \
		if (__ret == -1) \
			FAIL_ERRNO("connect"); \
		__ret; \
	})

#define xgetsockname(fd, addr, len) \
	({ \
		int __ret = getsockname((fd), (addr), (len)); \
		if (__ret == -1) \
			FAIL_ERRNO("getsockname"); \
		__ret; \
	})

/* The option name is stringified into the failure message. */
#define xgetsockopt(fd, level, name, val, len) \
	({ \
		int __ret = getsockopt((fd), (level), (name), (val), (len)); \
		if (__ret == -1) \
			FAIL_ERRNO("getsockopt(" #name ")"); \
		__ret; \
	})

#define xlisten(fd, backlog) \
	({ \
		int __ret = listen((fd), (backlog)); \
		if (__ret == -1) \
			FAIL_ERRNO("listen"); \
		__ret; \
	})

/* The option name is stringified into the failure message. */
#define xsetsockopt(fd, level, name, val, len) \
	({ \
		int __ret = setsockopt((fd), (level), (name), (val), (len)); \
		if (__ret == -1) \
			FAIL_ERRNO("setsockopt(" #name ")"); \
		__ret; \
	})

#define xsend(fd, buf, len, flags) \
	({ \
		ssize_t __ret = send((fd), (buf), (len), (flags)); \
		if (__ret == -1) \
			FAIL_ERRNO("send"); \
		__ret; \
	})

/* recv() after waiting up to IO_TIMEOUT_SEC for the socket to be readable. */
#define xrecv_nonblock(fd, buf, len, flags) \
	({ \
		ssize_t __ret = recv_timeout((fd), (buf), (len), (flags), \
					     IO_TIMEOUT_SEC); \
		if (__ret == -1) \
			FAIL_ERRNO("recv"); \
		__ret; \
	})

/* socket(); the third argument is the protocol (callers here pass 0).
 * Arguments are parenthesized like in every other wrapper.
 */
#define xsocket(family, sotype, proto) \
	({ \
		int __ret = socket((family), (sotype), (proto)); \
		if (__ret == -1) \
			FAIL_ERRNO("socket"); \
		__ret; \
	})
136
/*
 * libbpf wrappers. Failure is detected as any negative return value (not
 * only -1) and reported through FAIL_ERRNO(); each wrapper evaluates to the
 * libbpf call's return value.
 * NOTE(review): FAIL_ERRNO() prints errno, so this assumes libbpf also sets
 * errno on failure (libbpf 1.0 behaviour) — confirm for the libbpf in use.
 */
#define xbpf_map_delete_elem(fd, key) \
	({ \
		int __ret = bpf_map_delete_elem((fd), (key)); \
		if (__ret < 0) \
			FAIL_ERRNO("map_delete"); \
		__ret; \
	})

#define xbpf_map_lookup_elem(fd, key, val) \
	({ \
		int __ret = bpf_map_lookup_elem((fd), (key), (val)); \
		if (__ret < 0) \
			FAIL_ERRNO("map_lookup"); \
		__ret; \
	})

#define xbpf_map_update_elem(fd, key, val, flags) \
	({ \
		int __ret = bpf_map_update_elem((fd), (key), (val), (flags)); \
		if (__ret < 0) \
			FAIL_ERRNO("map_update"); \
		__ret; \
	})

/* @type is stringified as written at the call site into the message. */
#define xbpf_prog_attach(prog, target, type, flags) \
	({ \
		int __ret = \
			bpf_prog_attach((prog), (target), (type), (flags)); \
		if (__ret < 0) \
			FAIL_ERRNO("prog_attach(" #type ")"); \
		__ret; \
	})

/* @type is stringified as written at the call site into the message. */
#define xbpf_prog_detach2(prog, target, type) \
	({ \
		int __ret = bpf_prog_detach2((prog), (target), (type)); \
		if (__ret < 0) \
			FAIL_ERRNO("prog_detach2(" #type ")"); \
		__ret; \
	})
177
/*
 * pthread wrappers. pthread functions return an error number instead of
 * setting errno, so on failure the number is copied into errno before
 * FAIL_ERRNO() reports it. On success errno is left untouched, as with the
 * other x* wrappers. Each wrapper evaluates to the pthread return value.
 */
#define xpthread_create(thread, attr, func, arg) \
	({ \
		int __ret = pthread_create((thread), (attr), (func), (arg)); \
		if (__ret) { \
			errno = __ret; \
			FAIL_ERRNO("pthread_create"); \
		} \
		__ret; \
	})

#define xpthread_join(thread, retval) \
	({ \
		int __ret = pthread_join((thread), (retval)); \
		if (__ret) { \
			errno = __ret; \
			FAIL_ERRNO("pthread_join"); \
		} \
		__ret; \
	})
195
/*
 * close_fd() - cleanup handler used by the __close_fd variable attribute.
 * @fd: address of the descriptor variable going out of scope.
 *
 * A negative value (e.g. -1, or the -EBADF left by take_fd()) means there
 * is nothing to close. Otherwise the descriptor is closed through xclose(),
 * so a failing close() is reported.
 */
static inline void close_fd(int *fd)
{
	int victim = *fd;

	if (victim < 0)
		return;

	xclose(victim);
}

#define __close_fd __attribute__((cleanup(close_fd)))
203
/*
 * poll_connect() - wait for an in-progress connect() on @fd to finish.
 * @fd:          socket to wait on.
 * @timeout_sec: seconds to wait for @fd to become writable.
 *
 * Returns 0 once @fd is writable and SO_ERROR reports no pending error.
 * Returns -1 otherwise, with errno set to ETIME on timeout, to the pending
 * SO_ERROR value, or by select()/getsockopt() when those calls fail.
 */
static inline int poll_connect(int fd, unsigned int timeout_sec)
{
	struct timeval tv = { .tv_sec = timeout_sec };
	int sock_err = 0;
	socklen_t sock_err_len = sizeof(sock_err);
	fd_set writable;
	int nready;

	FD_ZERO(&writable);
	FD_SET(fd, &writable);

	nready = select(fd + 1, NULL, &writable, NULL, &tv);
	if (nready != 1) {
		if (nready == 0)
			errno = ETIME;
		return -1;
	}

	if (getsockopt(fd, SOL_SOCKET, SO_ERROR, &sock_err, &sock_err_len) < 0)
		return -1;

	if (sock_err) {
		errno = sock_err;
		return -1;
	}

	return 0;
}
229
/*
 * poll_read() - wait for @fd to become readable.
 * @fd:          descriptor to wait on.
 * @timeout_sec: seconds to wait.
 *
 * Returns 0 if @fd became readable, otherwise -1 with errno set to ETIME on
 * timeout or left as set by select() on error.
 */
static inline int poll_read(int fd, unsigned int timeout_sec)
{
	struct timeval tv = { .tv_sec = timeout_sec };
	fd_set readable;
	int nready;

	FD_ZERO(&readable);
	FD_SET(fd, &readable);

	nready = select(fd + 1, &readable, NULL, NULL, &tv);
	switch (nready) {
	case 1:
		return 0;
	case 0:
		errno = ETIME;
		return -1;
	default:
		return -1;
	}
}
245
/*
 * accept_timeout() - accept() once @fd is readable, waiting at most
 * @timeout_sec seconds.
 *
 * Returns the accepted descriptor, or -1 with errno set by poll_read()
 * (ETIME on timeout) or by accept().
 */
static inline int accept_timeout(int fd, struct sockaddr *addr, socklen_t *len,
				 unsigned int timeout_sec)
{
	int not_ready = poll_read(fd, timeout_sec);

	return not_ready ? -1 : accept(fd, addr, len);
}
254
/*
 * recv_timeout() - recv() once @fd is readable, waiting at most
 * @timeout_sec seconds.
 *
 * Returns the byte count from recv(), or -1 with errno set by poll_read()
 * (ETIME on timeout) or by recv(). The return type is ssize_t to match
 * recv() (and xrecv_nonblock(), which stores the result in an ssize_t), so
 * large byte counts are not truncated to int.
 */
static inline ssize_t recv_timeout(int fd, void *buf, size_t len, int flags,
				   unsigned int timeout_sec)
{
	if (poll_read(fd, timeout_sec))
		return -1;

	return recv(fd, buf, len, flags);
}
263
/*
 * init_addr_loopback4() - fill @ss with 127.0.0.1, port 0.
 * @ss:  storage to initialise; zeroed first.
 * @len: out: size of struct sockaddr_in.
 */
static inline void init_addr_loopback4(struct sockaddr_storage *ss,
				       socklen_t *len)
{
	struct sockaddr_in *sin = (struct sockaddr_in *)ss;

	memset(ss, 0, sizeof(*ss));
	sin->sin_family = AF_INET;
	sin->sin_addr.s_addr = htonl(INADDR_LOOPBACK);
	sin->sin_port = 0;

	*len = sizeof(struct sockaddr_in);
}
274
/*
 * init_addr_loopback6() - fill @ss with ::1, port 0.
 * @ss:  storage to initialise; zeroed first.
 * @len: out: size of struct sockaddr_in6.
 */
static inline void init_addr_loopback6(struct sockaddr_storage *ss,
				       socklen_t *len)
{
	struct sockaddr_in6 *sin6 = (struct sockaddr_in6 *)ss;

	memset(ss, 0, sizeof(*ss));
	sin6->sin6_family = AF_INET6;
	sin6->sin6_addr = in6addr_loopback;
	sin6->sin6_port = 0;

	*len = sizeof(struct sockaddr_in6);
}
285
/*
 * init_addr_loopback_vsock() - fill @ss with the local vsock CID
 * (VMADDR_CID_LOCAL) and VMADDR_PORT_ANY.
 * @ss:  storage to initialise; zeroed first.
 * @len: out: size of struct sockaddr_vm.
 */
static inline void init_addr_loopback_vsock(struct sockaddr_storage *ss,
					    socklen_t *len)
{
	struct sockaddr_vm *svm = (struct sockaddr_vm *)ss;

	memset(ss, 0, sizeof(*ss));
	svm->svm_cid = VMADDR_CID_LOCAL;
	svm->svm_port = VMADDR_PORT_ANY;
	svm->svm_family = AF_VSOCK;

	*len = sizeof(struct sockaddr_vm);
}
296
/*
 * init_addr_loopback() - fill @ss with the loopback address of @family.
 * @family: AF_INET, AF_INET6 or AF_VSOCK.
 * @ss:     storage to initialise.
 * @len:    out: size of the family-specific address.
 *
 * Any other family is reported with FAIL(), and @ss and @len are left
 * unmodified.
 */
static inline void init_addr_loopback(int family, struct sockaddr_storage *ss,
				      socklen_t *len)
{
	if (family == AF_INET)
		init_addr_loopback4(ss, len);
	else if (family == AF_INET6)
		init_addr_loopback6(ss, len);
	else if (family == AF_VSOCK)
		init_addr_loopback_vsock(ss, len);
	else
		FAIL("unsupported address family %d", family);
}
314
/* View a struct sockaddr_storage as the generic struct sockaddr expected by
 * bind(), connect() and getsockname().
 */
static inline struct sockaddr *sockaddr(struct sockaddr_storage *ss)
{
	void *generic = ss;

	return generic;
}
319
add_to_sockmap(int sock_mapfd,int fd1,int fd2)320 static inline int add_to_sockmap(int sock_mapfd, int fd1, int fd2)
321 {
322 u64 value;
323 u32 key;
324 int err;
325
326 key = 0;
327 value = fd1;
328 err = xbpf_map_update_elem(sock_mapfd, &key, &value, BPF_NOEXIST);
329 if (err)
330 return err;
331
332 key = 1;
333 value = fd2;
334 return xbpf_map_update_elem(sock_mapfd, &key, &value, BPF_NOEXIST);
335 }
336
enable_reuseport(int s,int progfd)337 static inline int enable_reuseport(int s, int progfd)
338 {
339 int err, one = 1;
340
341 err = xsetsockopt(s, SOL_SOCKET, SO_REUSEPORT, &one, sizeof(one));
342 if (err)
343 return -1;
344 err = xsetsockopt(s, SOL_SOCKET, SO_ATTACH_REUSEPORT_EBPF, &progfd,
345 sizeof(progfd));
346 if (err)
347 return -1;
348
349 return 0;
350 }
351
socket_loopback_reuseport(int family,int sotype,int progfd)352 static inline int socket_loopback_reuseport(int family, int sotype, int progfd)
353 {
354 struct sockaddr_storage addr;
355 socklen_t len = 0;
356 int err, s;
357
358 init_addr_loopback(family, &addr, &len);
359
360 s = xsocket(family, sotype, 0);
361 if (s == -1)
362 return -1;
363
364 if (progfd >= 0)
365 enable_reuseport(s, progfd);
366
367 err = xbind(s, sockaddr(&addr), len);
368 if (err)
369 goto close;
370
371 if (sotype & SOCK_DGRAM)
372 return s;
373
374 err = xlisten(s, SOMAXCONN);
375 if (err)
376 goto close;
377
378 return s;
379 close:
380 xclose(s);
381 return -1;
382 }
383
/*
 * socket_loopback() - socket_loopback_reuseport() without a reuseport
 * program attached.
 */
static inline int socket_loopback(int family, int sotype)
{
	const int no_prog = -1;

	return socket_loopback_reuseport(family, sotype, no_prog);
}
388
create_pair(int family,int sotype,int * p0,int * p1)389 static inline int create_pair(int family, int sotype, int *p0, int *p1)
390 {
391 __close_fd int s, c = -1, p = -1;
392 struct sockaddr_storage addr;
393 socklen_t len = sizeof(addr);
394 int err;
395
396 s = socket_loopback(family, sotype);
397 if (s < 0)
398 return s;
399
400 err = xgetsockname(s, sockaddr(&addr), &len);
401 if (err)
402 return err;
403
404 c = xsocket(family, sotype, 0);
405 if (c < 0)
406 return c;
407
408 err = connect(c, sockaddr(&addr), len);
409 if (err) {
410 if (errno != EINPROGRESS) {
411 FAIL_ERRNO("connect");
412 return err;
413 }
414
415 err = poll_connect(c, IO_TIMEOUT_SEC);
416 if (err) {
417 FAIL_ERRNO("poll_connect");
418 return err;
419 }
420 }
421
422 switch (sotype & SOCK_TYPE_MASK) {
423 case SOCK_DGRAM:
424 err = xgetsockname(c, sockaddr(&addr), &len);
425 if (err)
426 return err;
427
428 err = xconnect(s, sockaddr(&addr), len);
429 if (err)
430 return err;
431
432 *p0 = take_fd(s);
433 break;
434 case SOCK_STREAM:
435 case SOCK_SEQPACKET:
436 p = xaccept_nonblock(s, NULL, NULL);
437 if (p < 0)
438 return p;
439
440 *p0 = take_fd(p);
441 break;
442 default:
443 FAIL("Unsupported socket type %#x", sotype);
444 return -EOPNOTSUPP;
445 }
446
447 *p1 = take_fd(c);
448 return 0;
449 }
450
/*
 * create_socket_pairs() - create two connected loopback socket pairs.
 * @family, @sotype: passed to create_pair() for both pairs.
 * @c0, @p0: first pair, filled by create_pair(family, sotype, c0, p0).
 * @c1, @p1: second pair, filled by create_pair(family, sotype, c1, p1).
 *
 * Returns 0 on success, or create_pair()'s error. If the second pair
 * fails, the first pair is closed and *c0/*p0 are reset to -1, so a caller
 * cleanup (e.g. __close_fd) cannot close those descriptor numbers again
 * after they may have been reused.
 */
static inline int create_socket_pairs(int family, int sotype, int *c0, int *c1,
				      int *p0, int *p1)
{
	int err;

	err = create_pair(family, sotype, c0, p0);
	if (err)
		return err;

	err = create_pair(family, sotype, c1, p1);
	if (err) {
		/* Best-effort teardown; the failure is already reported. */
		close(*c0);
		close(*p0);
		*c0 = -1;
		*p0 = -1;
	}

	return err;
}
468
469 #endif // __SOCKMAP_HELPERS__
470