1 /* SPDX-License-Identifier: GPL-2.0 */ 2 3 #ifndef __SOCKET_HELPERS__ 4 #define __SOCKET_HELPERS__ 5 6 #include <sys/un.h> 7 #include <linux/vm_sockets.h> 8 9 /* include/linux/net.h */ 10 #define SOCK_TYPE_MASK 0xf 11 12 #define IO_TIMEOUT_SEC 30 13 #define MAX_STRERR_LEN 256 14 15 /* workaround for older vm_sockets.h */ 16 #ifndef VMADDR_CID_LOCAL 17 #define VMADDR_CID_LOCAL 1 18 #endif 19 20 /* include/linux/compiler_types.h */ 21 #if __STDC_VERSION__ < 202311L && !defined(auto) 22 # define auto __auto_type 23 #endif 24 25 /* include/linux/cleanup.h */ 26 #define __get_and_null(p, nullvalue) \ 27 ({ \ 28 auto __ptr = &(p); \ 29 auto __val = *__ptr; \ 30 *__ptr = nullvalue; \ 31 __val; \ 32 }) 33 34 #define take_fd(fd) __get_and_null(fd, -EBADF) 35 36 /* Wrappers that fail the test on error and report it. */ 37 38 #define _FAIL(errnum, fmt...) \ 39 ({ \ 40 error_at_line(0, (errnum), __func__, __LINE__, fmt); \ 41 CHECK_FAIL(true); \ 42 }) 43 #define FAIL(fmt...) _FAIL(0, fmt) 44 #define FAIL_ERRNO(fmt...) _FAIL(errno, fmt) 45 #define FAIL_LIBBPF(err, msg) \ 46 ({ \ 47 char __buf[MAX_STRERR_LEN]; \ 48 libbpf_strerror((err), __buf, sizeof(__buf)); \ 49 FAIL("%s: %s", (msg), __buf); \ 50 }) 51 52 53 #define xaccept_nonblock(fd, addr, len) \ 54 ({ \ 55 int __ret = \ 56 accept_timeout((fd), (addr), (len), IO_TIMEOUT_SEC); \ 57 if (__ret == -1) \ 58 FAIL_ERRNO("accept"); \ 59 __ret; \ 60 }) 61 62 #define xbind(fd, addr, len) \ 63 ({ \ 64 int __ret = bind((fd), (addr), (len)); \ 65 if (__ret == -1) \ 66 FAIL_ERRNO("bind"); \ 67 __ret; \ 68 }) 69 70 #define xclose(fd) \ 71 ({ \ 72 int __ret = close((fd)); \ 73 if (__ret == -1) \ 74 FAIL_ERRNO("close"); \ 75 __ret; \ 76 }) 77 78 #define xconnect(fd, addr, len) \ 79 ({ \ 80 int __ret = connect((fd), (addr), (len)); \ 81 if (__ret == -1) \ 82 FAIL_ERRNO("connect"); \ 83 __ret; \ 84 }) 85 86 #define xgetsockname(fd, addr, len) \ 87 ({ \ 88 int __ret = getsockname((fd), (addr), (len)); \ 89 if (__ret == -1) \ 90 FAIL_ERRNO("getsockname"); \ 91 __ret; \ 92 }) 93 94 #define xgetsockopt(fd, level, name, val, len) \ 95 ({ \ 96 int __ret = getsockopt((fd), (level), (name), (val), (len)); \ 97 if (__ret == -1) \ 98 FAIL_ERRNO("getsockopt(" #name ")"); \ 99 __ret; \ 100 }) 101 102 #define xlisten(fd, backlog) \ 103 ({ \ 104 int __ret = listen((fd), (backlog)); \ 105 if (__ret == -1) \ 106 FAIL_ERRNO("listen"); \ 107 __ret; \ 108 }) 109 110 #define xsetsockopt(fd, level, name, val, len) \ 111 ({ \ 112 int __ret = setsockopt((fd), (level), (name), (val), (len)); \ 113 if (__ret == -1) \ 114 FAIL_ERRNO("setsockopt(" #name ")"); \ 115 __ret; \ 116 }) 117 118 #define xsend(fd, buf, len, flags) \ 119 ({ \ 120 ssize_t __ret = send((fd), (buf), (len), (flags)); \ 121 if (__ret == -1) \ 122 FAIL_ERRNO("send"); \ 123 __ret; \ 124 }) 125 126 #define xrecv_nonblock(fd, buf, len, flags) \ 127 ({ \ 128 ssize_t __ret = recv_timeout((fd), (buf), (len), (flags), \ 129 IO_TIMEOUT_SEC); \ 130 if (__ret == -1) \ 131 FAIL_ERRNO("recv"); \ 132 __ret; \ 133 }) 134 135 #define xsocket(family, sotype, flags) \ 136 ({ \ 137 int __ret = socket(family, sotype, flags); \ 138 if (__ret == -1) \ 139 FAIL_ERRNO("socket"); \ 140 __ret; \ 141 }) 142 143 static inline void close_fd(int *fd) 144 { 145 if (*fd >= 0) 146 xclose(*fd); 147 } 148 149 #define __close_fd __attribute__((cleanup(close_fd))) 150 151 static inline struct sockaddr *sockaddr(struct sockaddr_storage *ss) 152 { 153 return (struct sockaddr *)ss; 154 } 155 156 static inline void init_addr_loopback4(struct sockaddr_storage *ss, 157 socklen_t *len) 158 { 159 struct sockaddr_in *addr4 = memset(ss, 0, sizeof(*ss)); 160 161 addr4->sin_family = AF_INET; 162 addr4->sin_port = 0; 163 addr4->sin_addr.s_addr = htonl(INADDR_LOOPBACK); 164 *len = sizeof(*addr4); 165 } 166 167 static inline void init_addr_loopback6(struct sockaddr_storage *ss, 168 socklen_t *len) 169 { 170 struct sockaddr_in6 *addr6 = memset(ss, 0, sizeof(*ss)); 171 172 addr6->sin6_family = AF_INET6; 173 addr6->sin6_port = 0; 174 addr6->sin6_addr = in6addr_loopback; 175 *len = sizeof(*addr6); 176 } 177 178 static inline void init_addr_loopback_unix(struct sockaddr_storage *ss, 179 socklen_t *len) 180 { 181 struct sockaddr_un *addr = memset(ss, 0, sizeof(*ss)); 182 183 addr->sun_family = AF_UNIX; 184 *len = sizeof(sa_family_t); 185 } 186 187 static inline void init_addr_loopback_vsock(struct sockaddr_storage *ss, 188 socklen_t *len) 189 { 190 struct sockaddr_vm *addr = memset(ss, 0, sizeof(*ss)); 191 192 addr->svm_family = AF_VSOCK; 193 addr->svm_port = VMADDR_PORT_ANY; 194 addr->svm_cid = VMADDR_CID_LOCAL; 195 *len = sizeof(*addr); 196 } 197 198 static inline void init_addr_loopback(int family, struct sockaddr_storage *ss, 199 socklen_t *len) 200 { 201 switch (family) { 202 case AF_INET: 203 init_addr_loopback4(ss, len); 204 return; 205 case AF_INET6: 206 init_addr_loopback6(ss, len); 207 return; 208 case AF_UNIX: 209 init_addr_loopback_unix(ss, len); 210 return; 211 case AF_VSOCK: 212 init_addr_loopback_vsock(ss, len); 213 return; 214 default: 215 FAIL("unsupported address family %d", family); 216 } 217 } 218 219 static inline int enable_reuseport(int s, int progfd) 220 { 221 int err, one = 1; 222 223 err = xsetsockopt(s, SOL_SOCKET, SO_REUSEPORT, &one, sizeof(one)); 224 if (err) 225 return -1; 226 err = xsetsockopt(s, SOL_SOCKET, SO_ATTACH_REUSEPORT_EBPF, &progfd, 227 sizeof(progfd)); 228 if (err) 229 return -1; 230 231 return 0; 232 } 233 234 static inline int socket_loopback_reuseport(int family, int sotype, int progfd) 235 { 236 struct sockaddr_storage addr; 237 socklen_t len = 0; 238 int err, s; 239 240 init_addr_loopback(family, &addr, &len); 241 242 s = xsocket(family, sotype, 0); 243 if (s == -1) 244 return -1; 245 246 if (progfd >= 0) 247 enable_reuseport(s, progfd); 248 249 err = xbind(s, sockaddr(&addr), len); 250 if (err) 251 goto close; 252 253 if (sotype & SOCK_DGRAM) 254 return s; 255 256 err = xlisten(s, SOMAXCONN); 257 if (err) 258 goto close; 259 260 return s; 261 close: 262 xclose(s); 263 return -1; 264 } 265 266 static inline int socket_loopback(int family, int sotype) 267 { 268 return socket_loopback_reuseport(family, sotype, -1); 269 } 270 271 static inline int poll_connect(int fd, unsigned int timeout_sec) 272 { 273 struct timeval timeout = { .tv_sec = timeout_sec }; 274 fd_set wfds; 275 int r, eval; 276 socklen_t esize = sizeof(eval); 277 278 FD_ZERO(&wfds); 279 FD_SET(fd, &wfds); 280 281 r = select(fd + 1, NULL, &wfds, NULL, &timeout); 282 if (r == 0) 283 errno = ETIME; 284 if (r != 1) 285 return -1; 286 287 if (getsockopt(fd, SOL_SOCKET, SO_ERROR, &eval, &esize) < 0) 288 return -1; 289 if (eval != 0) { 290 errno = eval; 291 return -1; 292 } 293 294 return 0; 295 } 296 297 static inline int poll_read(int fd, unsigned int timeout_sec) 298 { 299 struct timeval timeout = { .tv_sec = timeout_sec }; 300 fd_set rfds; 301 int r; 302 303 FD_ZERO(&rfds); 304 FD_SET(fd, &rfds); 305 306 r = select(fd + 1, &rfds, NULL, NULL, &timeout); 307 if (r == 0) 308 errno = ETIME; 309 310 return r == 1 ? 0 : -1; 311 } 312 313 static inline int accept_timeout(int fd, struct sockaddr *addr, socklen_t *len, 314 unsigned int timeout_sec) 315 { 316 if (poll_read(fd, timeout_sec)) 317 return -1; 318 319 return accept(fd, addr, len); 320 } 321 322 static inline int recv_timeout(int fd, void *buf, size_t len, int flags, 323 unsigned int timeout_sec) 324 { 325 if (poll_read(fd, timeout_sec)) 326 return -1; 327 328 return recv(fd, buf, len, flags); 329 } 330 331 332 static inline int create_pair(int family, int sotype, int *p0, int *p1) 333 { 334 __close_fd int s, c = -1, p = -1; 335 struct sockaddr_storage addr; 336 socklen_t len; 337 int err; 338 339 s = socket_loopback(family, sotype); 340 if (s < 0) 341 return s; 342 343 c = xsocket(family, sotype, 0); 344 if (c < 0) 345 return c; 346 347 init_addr_loopback(family, &addr, &len); 348 err = xbind(c, sockaddr(&addr), len); 349 if (err) 350 return err; 351 352 len = sizeof(addr); 353 err = xgetsockname(s, sockaddr(&addr), &len); 354 if (err) 355 return err; 356 357 err = connect(c, sockaddr(&addr), len); 358 if (err) { 359 if (errno != EINPROGRESS) { 360 FAIL_ERRNO("connect"); 361 return err; 362 } 363 364 err = poll_connect(c, IO_TIMEOUT_SEC); 365 if (err) { 366 FAIL_ERRNO("poll_connect"); 367 return err; 368 } 369 } 370 371 switch (sotype & SOCK_TYPE_MASK) { 372 case SOCK_DGRAM: 373 err = xgetsockname(c, sockaddr(&addr), &len); 374 if (err) 375 return err; 376 377 err = xconnect(s, sockaddr(&addr), len); 378 if (err) 379 return err; 380 381 *p0 = take_fd(s); 382 break; 383 case SOCK_STREAM: 384 case SOCK_SEQPACKET: 385 p = xaccept_nonblock(s, NULL, NULL); 386 if (p < 0) 387 return p; 388 389 *p0 = take_fd(p); 390 break; 391 default: 392 FAIL("Unsupported socket type %#x", sotype); 393 return -EOPNOTSUPP; 394 } 395 396 *p1 = take_fd(c); 397 return 0; 398 } 399 400 static inline int create_socket_pairs(int family, int sotype, int *c0, int *c1, 401 int *p0, int *p1) 402 { 403 int err; 404 405 err = create_pair(family, sotype, c0, p0); 406 if (err) 407 return err; 408 409 err = create_pair(family, sotype, c1, p1); 410 if (err) { 411 close(*c0); 412 close(*p0); 413 } 414 415 return err; 416 } 417 418 static inline const char *socket_kind_to_str(int sock_fd) 419 { 420 socklen_t opt_len; 421 int domain, type; 422 423 opt_len = sizeof(domain); 424 if (getsockopt(sock_fd, SOL_SOCKET, SO_DOMAIN, &domain, &opt_len)) 425 FAIL_ERRNO("getsockopt(SO_DOMAIN)"); 426 427 opt_len = sizeof(type); 428 if (getsockopt(sock_fd, SOL_SOCKET, SO_TYPE, &type, &opt_len)) 429 FAIL_ERRNO("getsockopt(SO_TYPE)"); 430 431 switch (domain) { 432 case AF_INET: 433 switch (type) { 434 case SOCK_STREAM: 435 return "tcp4"; 436 case SOCK_DGRAM: 437 return "udp4"; 438 } 439 break; 440 case AF_INET6: 441 switch (type) { 442 case SOCK_STREAM: 443 return "tcp6"; 444 case SOCK_DGRAM: 445 return "udp6"; 446 } 447 break; 448 case AF_UNIX: 449 switch (type) { 450 case SOCK_STREAM: 451 return "u_str"; 452 case SOCK_DGRAM: 453 return "u_dgr"; 454 case SOCK_SEQPACKET: 455 return "u_seq"; 456 } 457 break; 458 case AF_VSOCK: 459 switch (type) { 460 case SOCK_STREAM: 461 return "v_str"; 462 case SOCK_DGRAM: 463 return "v_dgr"; 464 case SOCK_SEQPACKET: 465 return "v_seq"; 466 } 467 break; 468 } 469 470 return "???"; 471 } 472 473 #endif // __SOCKET_HELPERS__ 474