1 #ifndef __SOCKMAP_HELPERS__ 2 #define __SOCKMAP_HELPERS__ 3 4 #include <linux/vm_sockets.h> 5 6 /* include/linux/net.h */ 7 #define SOCK_TYPE_MASK 0xf 8 9 #define IO_TIMEOUT_SEC 30 10 #define MAX_STRERR_LEN 256 11 #define MAX_TEST_NAME 80 12 13 /* workaround for older vm_sockets.h */ 14 #ifndef VMADDR_CID_LOCAL 15 #define VMADDR_CID_LOCAL 1 16 #endif 17 18 #define __always_unused __attribute__((__unused__)) 19 20 /* include/linux/cleanup.h */ 21 #define __get_and_null(p, nullvalue) \ 22 ({ \ 23 __auto_type __ptr = &(p); \ 24 __auto_type __val = *__ptr; \ 25 *__ptr = nullvalue; \ 26 __val; \ 27 }) 28 29 #define take_fd(fd) __get_and_null(fd, -EBADF) 30 31 #define _FAIL(errnum, fmt...) \ 32 ({ \ 33 error_at_line(0, (errnum), __func__, __LINE__, fmt); \ 34 CHECK_FAIL(true); \ 35 }) 36 #define FAIL(fmt...) _FAIL(0, fmt) 37 #define FAIL_ERRNO(fmt...) _FAIL(errno, fmt) 38 #define FAIL_LIBBPF(err, msg) \ 39 ({ \ 40 char __buf[MAX_STRERR_LEN]; \ 41 libbpf_strerror((err), __buf, sizeof(__buf)); \ 42 FAIL("%s: %s", (msg), __buf); \ 43 }) 44 45 /* Wrappers that fail the test on error and report it. */ 46 47 #define xaccept_nonblock(fd, addr, len) \ 48 ({ \ 49 int __ret = \ 50 accept_timeout((fd), (addr), (len), IO_TIMEOUT_SEC); \ 51 if (__ret == -1) \ 52 FAIL_ERRNO("accept"); \ 53 __ret; \ 54 }) 55 56 #define xbind(fd, addr, len) \ 57 ({ \ 58 int __ret = bind((fd), (addr), (len)); \ 59 if (__ret == -1) \ 60 FAIL_ERRNO("bind"); \ 61 __ret; \ 62 }) 63 64 #define xclose(fd) \ 65 ({ \ 66 int __ret = close((fd)); \ 67 if (__ret == -1) \ 68 FAIL_ERRNO("close"); \ 69 __ret; \ 70 }) 71 72 #define xconnect(fd, addr, len) \ 73 ({ \ 74 int __ret = connect((fd), (addr), (len)); \ 75 if (__ret == -1) \ 76 FAIL_ERRNO("connect"); \ 77 __ret; \ 78 }) 79 80 #define xgetsockname(fd, addr, len) \ 81 ({ \ 82 int __ret = getsockname((fd), (addr), (len)); \ 83 if (__ret == -1) \ 84 FAIL_ERRNO("getsockname"); \ 85 __ret; \ 86 }) 87 88 #define xgetsockopt(fd, level, name, val, len) \ 89 ({ \ 90 int __ret = getsockopt((fd), (level), (name), (val), (len)); \ 91 if (__ret == -1) \ 92 FAIL_ERRNO("getsockopt(" #name ")"); \ 93 __ret; \ 94 }) 95 96 #define xlisten(fd, backlog) \ 97 ({ \ 98 int __ret = listen((fd), (backlog)); \ 99 if (__ret == -1) \ 100 FAIL_ERRNO("listen"); \ 101 __ret; \ 102 }) 103 104 #define xsetsockopt(fd, level, name, val, len) \ 105 ({ \ 106 int __ret = setsockopt((fd), (level), (name), (val), (len)); \ 107 if (__ret == -1) \ 108 FAIL_ERRNO("setsockopt(" #name ")"); \ 109 __ret; \ 110 }) 111 112 #define xsend(fd, buf, len, flags) \ 113 ({ \ 114 ssize_t __ret = send((fd), (buf), (len), (flags)); \ 115 if (__ret == -1) \ 116 FAIL_ERRNO("send"); \ 117 __ret; \ 118 }) 119 120 #define xrecv_nonblock(fd, buf, len, flags) \ 121 ({ \ 122 ssize_t __ret = recv_timeout((fd), (buf), (len), (flags), \ 123 IO_TIMEOUT_SEC); \ 124 if (__ret == -1) \ 125 FAIL_ERRNO("recv"); \ 126 __ret; \ 127 }) 128 129 #define xsocket(family, sotype, flags) \ 130 ({ \ 131 int __ret = socket(family, sotype, flags); \ 132 if (__ret == -1) \ 133 FAIL_ERRNO("socket"); \ 134 __ret; \ 135 }) 136 137 #define xbpf_map_delete_elem(fd, key) \ 138 ({ \ 139 int __ret = bpf_map_delete_elem((fd), (key)); \ 140 if (__ret < 0) \ 141 FAIL_ERRNO("map_delete"); \ 142 __ret; \ 143 }) 144 145 #define xbpf_map_lookup_elem(fd, key, val) \ 146 ({ \ 147 int __ret = bpf_map_lookup_elem((fd), (key), (val)); \ 148 if (__ret < 0) \ 149 FAIL_ERRNO("map_lookup"); \ 150 __ret; \ 151 }) 152 153 #define xbpf_map_update_elem(fd, key, val, flags) \ 154 ({ \ 155 int __ret = bpf_map_update_elem((fd), (key), (val), (flags)); \ 156 if (__ret < 0) \ 157 FAIL_ERRNO("map_update"); \ 158 __ret; \ 159 }) 160 161 #define xbpf_prog_attach(prog, target, type, flags) \ 162 ({ \ 163 int __ret = \ 164 bpf_prog_attach((prog), (target), (type), (flags)); \ 165 if (__ret < 0) \ 166 FAIL_ERRNO("prog_attach(" #type ")"); \ 167 __ret; \ 168 }) 169 170 #define xbpf_prog_detach2(prog, target, type) \ 171 ({ \ 172 int __ret = bpf_prog_detach2((prog), (target), (type)); \ 173 if (__ret < 0) \ 174 FAIL_ERRNO("prog_detach2(" #type ")"); \ 175 __ret; \ 176 }) 177 178 #define xpthread_create(thread, attr, func, arg) \ 179 ({ \ 180 int __ret = pthread_create((thread), (attr), (func), (arg)); \ 181 errno = __ret; \ 182 if (__ret) \ 183 FAIL_ERRNO("pthread_create"); \ 184 __ret; \ 185 }) 186 187 #define xpthread_join(thread, retval) \ 188 ({ \ 189 int __ret = pthread_join((thread), (retval)); \ 190 errno = __ret; \ 191 if (__ret) \ 192 FAIL_ERRNO("pthread_join"); \ 193 __ret; \ 194 }) 195 196 static inline void close_fd(int *fd) 197 { 198 if (*fd >= 0) 199 xclose(*fd); 200 } 201 202 #define __close_fd __attribute__((cleanup(close_fd))) 203 204 static inline int poll_connect(int fd, unsigned int timeout_sec) 205 { 206 struct timeval timeout = { .tv_sec = timeout_sec }; 207 fd_set wfds; 208 int r, eval; 209 socklen_t esize = sizeof(eval); 210 211 FD_ZERO(&wfds); 212 FD_SET(fd, &wfds); 213 214 r = select(fd + 1, NULL, &wfds, NULL, &timeout); 215 if (r == 0) 216 errno = ETIME; 217 if (r != 1) 218 return -1; 219 220 if (getsockopt(fd, SOL_SOCKET, SO_ERROR, &eval, &esize) < 0) 221 return -1; 222 if (eval != 0) { 223 errno = eval; 224 return -1; 225 } 226 227 return 0; 228 } 229 230 static inline int poll_read(int fd, unsigned int timeout_sec) 231 { 232 struct timeval timeout = { .tv_sec = timeout_sec }; 233 fd_set rfds; 234 int r; 235 236 FD_ZERO(&rfds); 237 FD_SET(fd, &rfds); 238 239 r = select(fd + 1, &rfds, NULL, NULL, &timeout); 240 if (r == 0) 241 errno = ETIME; 242 243 return r == 1 ? 0 : -1; 244 } 245 246 static inline int accept_timeout(int fd, struct sockaddr *addr, socklen_t *len, 247 unsigned int timeout_sec) 248 { 249 if (poll_read(fd, timeout_sec)) 250 return -1; 251 252 return accept(fd, addr, len); 253 } 254 255 static inline int recv_timeout(int fd, void *buf, size_t len, int flags, 256 unsigned int timeout_sec) 257 { 258 if (poll_read(fd, timeout_sec)) 259 return -1; 260 261 return recv(fd, buf, len, flags); 262 } 263 264 static inline void init_addr_loopback4(struct sockaddr_storage *ss, 265 socklen_t *len) 266 { 267 struct sockaddr_in *addr4 = memset(ss, 0, sizeof(*ss)); 268 269 addr4->sin_family = AF_INET; 270 addr4->sin_port = 0; 271 addr4->sin_addr.s_addr = htonl(INADDR_LOOPBACK); 272 *len = sizeof(*addr4); 273 } 274 275 static inline void init_addr_loopback6(struct sockaddr_storage *ss, 276 socklen_t *len) 277 { 278 struct sockaddr_in6 *addr6 = memset(ss, 0, sizeof(*ss)); 279 280 addr6->sin6_family = AF_INET6; 281 addr6->sin6_port = 0; 282 addr6->sin6_addr = in6addr_loopback; 283 *len = sizeof(*addr6); 284 } 285 286 static inline void init_addr_loopback_vsock(struct sockaddr_storage *ss, 287 socklen_t *len) 288 { 289 struct sockaddr_vm *addr = memset(ss, 0, sizeof(*ss)); 290 291 addr->svm_family = AF_VSOCK; 292 addr->svm_port = VMADDR_PORT_ANY; 293 addr->svm_cid = VMADDR_CID_LOCAL; 294 *len = sizeof(*addr); 295 } 296 297 static inline void init_addr_loopback(int family, struct sockaddr_storage *ss, 298 socklen_t *len) 299 { 300 switch (family) { 301 case AF_INET: 302 init_addr_loopback4(ss, len); 303 return; 304 case AF_INET6: 305 init_addr_loopback6(ss, len); 306 return; 307 case AF_VSOCK: 308 init_addr_loopback_vsock(ss, len); 309 return; 310 default: 311 FAIL("unsupported address family %d", family); 312 } 313 } 314 315 static inline struct sockaddr *sockaddr(struct sockaddr_storage *ss) 316 { 317 return (struct sockaddr *)ss; 318 } 319 320 static inline int add_to_sockmap(int sock_mapfd, int fd1, int fd2) 321 { 322 u64 value; 323 u32 key; 324 int err; 325 326 key = 0; 327 value = fd1; 328 err = xbpf_map_update_elem(sock_mapfd, &key, &value, BPF_NOEXIST); 329 if (err) 330 return err; 331 332 key = 1; 333 value = fd2; 334 return xbpf_map_update_elem(sock_mapfd, &key, &value, BPF_NOEXIST); 335 } 336 337 static inline int enable_reuseport(int s, int progfd) 338 { 339 int err, one = 1; 340 341 err = xsetsockopt(s, SOL_SOCKET, SO_REUSEPORT, &one, sizeof(one)); 342 if (err) 343 return -1; 344 err = xsetsockopt(s, SOL_SOCKET, SO_ATTACH_REUSEPORT_EBPF, &progfd, 345 sizeof(progfd)); 346 if (err) 347 return -1; 348 349 return 0; 350 } 351 352 static inline int socket_loopback_reuseport(int family, int sotype, int progfd) 353 { 354 struct sockaddr_storage addr; 355 socklen_t len = 0; 356 int err, s; 357 358 init_addr_loopback(family, &addr, &len); 359 360 s = xsocket(family, sotype, 0); 361 if (s == -1) 362 return -1; 363 364 if (progfd >= 0) 365 enable_reuseport(s, progfd); 366 367 err = xbind(s, sockaddr(&addr), len); 368 if (err) 369 goto close; 370 371 if (sotype & SOCK_DGRAM) 372 return s; 373 374 err = xlisten(s, SOMAXCONN); 375 if (err) 376 goto close; 377 378 return s; 379 close: 380 xclose(s); 381 return -1; 382 } 383 384 static inline int socket_loopback(int family, int sotype) 385 { 386 return socket_loopback_reuseport(family, sotype, -1); 387 } 388 389 static inline int create_pair(int family, int sotype, int *p0, int *p1) 390 { 391 __close_fd int s, c = -1, p = -1; 392 struct sockaddr_storage addr; 393 socklen_t len = sizeof(addr); 394 int err; 395 396 s = socket_loopback(family, sotype); 397 if (s < 0) 398 return s; 399 400 err = xgetsockname(s, sockaddr(&addr), &len); 401 if (err) 402 return err; 403 404 c = xsocket(family, sotype, 0); 405 if (c < 0) 406 return c; 407 408 err = connect(c, sockaddr(&addr), len); 409 if (err) { 410 if (errno != EINPROGRESS) { 411 FAIL_ERRNO("connect"); 412 return err; 413 } 414 415 err = poll_connect(c, IO_TIMEOUT_SEC); 416 if (err) { 417 FAIL_ERRNO("poll_connect"); 418 return err; 419 } 420 } 421 422 switch (sotype & SOCK_TYPE_MASK) { 423 case SOCK_DGRAM: 424 err = xgetsockname(c, sockaddr(&addr), &len); 425 if (err) 426 return err; 427 428 err = xconnect(s, sockaddr(&addr), len); 429 if (err) 430 return err; 431 432 *p0 = take_fd(s); 433 break; 434 case SOCK_STREAM: 435 case SOCK_SEQPACKET: 436 p = xaccept_nonblock(s, NULL, NULL); 437 if (p < 0) 438 return p; 439 440 *p0 = take_fd(p); 441 break; 442 default: 443 FAIL("Unsupported socket type %#x", sotype); 444 return -EOPNOTSUPP; 445 } 446 447 *p1 = take_fd(c); 448 return 0; 449 } 450 451 static inline int create_socket_pairs(int family, int sotype, int *c0, int *c1, 452 int *p0, int *p1) 453 { 454 int err; 455 456 err = create_pair(family, sotype, c0, p0); 457 if (err) 458 return err; 459 460 err = create_pair(family, sotype, c1, p1); 461 if (err) { 462 close(*c0); 463 close(*p0); 464 } 465 466 return err; 467 } 468 469 #endif // __SOCKMAP_HELPERS__ 470