1 /* SPDX-License-Identifier: GPL-2.0 */ 2 3 #ifndef __SOCKET_HELPERS__ 4 #define __SOCKET_HELPERS__ 5 6 #include <sys/un.h> 7 #include <linux/vm_sockets.h> 8 9 /* include/linux/net.h */ 10 #define SOCK_TYPE_MASK 0xf 11 12 #define IO_TIMEOUT_SEC 30 13 #define MAX_STRERR_LEN 256 14 15 /* workaround for older vm_sockets.h */ 16 #ifndef VMADDR_CID_LOCAL 17 #define VMADDR_CID_LOCAL 1 18 #endif 19 20 /* include/linux/cleanup.h */ 21 #define __get_and_null(p, nullvalue) \ 22 ({ \ 23 __auto_type __ptr = &(p); \ 24 __auto_type __val = *__ptr; \ 25 *__ptr = nullvalue; \ 26 __val; \ 27 }) 28 29 #define take_fd(fd) __get_and_null(fd, -EBADF) 30 31 /* Wrappers that fail the test on error and report it. */ 32 33 #define _FAIL(errnum, fmt...) \ 34 ({ \ 35 error_at_line(0, (errnum), __func__, __LINE__, fmt); \ 36 CHECK_FAIL(true); \ 37 }) 38 #define FAIL(fmt...) _FAIL(0, fmt) 39 #define FAIL_ERRNO(fmt...) _FAIL(errno, fmt) 40 #define FAIL_LIBBPF(err, msg) \ 41 ({ \ 42 char __buf[MAX_STRERR_LEN]; \ 43 libbpf_strerror((err), __buf, sizeof(__buf)); \ 44 FAIL("%s: %s", (msg), __buf); \ 45 }) 46 47 48 #define xaccept_nonblock(fd, addr, len) \ 49 ({ \ 50 int __ret = \ 51 accept_timeout((fd), (addr), (len), IO_TIMEOUT_SEC); \ 52 if (__ret == -1) \ 53 FAIL_ERRNO("accept"); \ 54 __ret; \ 55 }) 56 57 #define xbind(fd, addr, len) \ 58 ({ \ 59 int __ret = bind((fd), (addr), (len)); \ 60 if (__ret == -1) \ 61 FAIL_ERRNO("bind"); \ 62 __ret; \ 63 }) 64 65 #define xclose(fd) \ 66 ({ \ 67 int __ret = close((fd)); \ 68 if (__ret == -1) \ 69 FAIL_ERRNO("close"); \ 70 __ret; \ 71 }) 72 73 #define xconnect(fd, addr, len) \ 74 ({ \ 75 int __ret = connect((fd), (addr), (len)); \ 76 if (__ret == -1) \ 77 FAIL_ERRNO("connect"); \ 78 __ret; \ 79 }) 80 81 #define xgetsockname(fd, addr, len) \ 82 ({ \ 83 int __ret = getsockname((fd), (addr), (len)); \ 84 if (__ret == -1) \ 85 FAIL_ERRNO("getsockname"); \ 86 __ret; \ 87 }) 88 89 #define xgetsockopt(fd, level, name, val, len) \ 90 ({ \ 91 int __ret = getsockopt((fd), (level), (name), (val), (len)); \ 92 if (__ret == -1) \ 93 FAIL_ERRNO("getsockopt(" #name ")"); \ 94 __ret; \ 95 }) 96 97 #define xlisten(fd, backlog) \ 98 ({ \ 99 int __ret = listen((fd), (backlog)); \ 100 if (__ret == -1) \ 101 FAIL_ERRNO("listen"); \ 102 __ret; \ 103 }) 104 105 #define xsetsockopt(fd, level, name, val, len) \ 106 ({ \ 107 int __ret = setsockopt((fd), (level), (name), (val), (len)); \ 108 if (__ret == -1) \ 109 FAIL_ERRNO("setsockopt(" #name ")"); \ 110 __ret; \ 111 }) 112 113 #define xsend(fd, buf, len, flags) \ 114 ({ \ 115 ssize_t __ret = send((fd), (buf), (len), (flags)); \ 116 if (__ret == -1) \ 117 FAIL_ERRNO("send"); \ 118 __ret; \ 119 }) 120 121 #define xrecv_nonblock(fd, buf, len, flags) \ 122 ({ \ 123 ssize_t __ret = recv_timeout((fd), (buf), (len), (flags), \ 124 IO_TIMEOUT_SEC); \ 125 if (__ret == -1) \ 126 FAIL_ERRNO("recv"); \ 127 __ret; \ 128 }) 129 130 #define xsocket(family, sotype, flags) \ 131 ({ \ 132 int __ret = socket(family, sotype, flags); \ 133 if (__ret == -1) \ 134 FAIL_ERRNO("socket"); \ 135 __ret; \ 136 }) 137 138 static inline void close_fd(int *fd) 139 { 140 if (*fd >= 0) 141 xclose(*fd); 142 } 143 144 #define __close_fd __attribute__((cleanup(close_fd))) 145 146 static inline struct sockaddr *sockaddr(struct sockaddr_storage *ss) 147 { 148 return (struct sockaddr *)ss; 149 } 150 151 static inline void init_addr_loopback4(struct sockaddr_storage *ss, 152 socklen_t *len) 153 { 154 struct sockaddr_in *addr4 = memset(ss, 0, sizeof(*ss)); 155 156 addr4->sin_family = AF_INET; 157 addr4->sin_port = 0; 158 addr4->sin_addr.s_addr = htonl(INADDR_LOOPBACK); 159 *len = sizeof(*addr4); 160 } 161 162 static inline void init_addr_loopback6(struct sockaddr_storage *ss, 163 socklen_t *len) 164 { 165 struct sockaddr_in6 *addr6 = memset(ss, 0, sizeof(*ss)); 166 167 addr6->sin6_family = AF_INET6; 168 addr6->sin6_port = 0; 169 addr6->sin6_addr = in6addr_loopback; 170 *len = sizeof(*addr6); 171 } 172 173 static inline void init_addr_loopback_unix(struct sockaddr_storage *ss, 174 socklen_t *len) 175 { 176 struct sockaddr_un *addr = memset(ss, 0, sizeof(*ss)); 177 178 addr->sun_family = AF_UNIX; 179 *len = sizeof(sa_family_t); 180 } 181 182 static inline void init_addr_loopback_vsock(struct sockaddr_storage *ss, 183 socklen_t *len) 184 { 185 struct sockaddr_vm *addr = memset(ss, 0, sizeof(*ss)); 186 187 addr->svm_family = AF_VSOCK; 188 addr->svm_port = VMADDR_PORT_ANY; 189 addr->svm_cid = VMADDR_CID_LOCAL; 190 *len = sizeof(*addr); 191 } 192 193 static inline void init_addr_loopback(int family, struct sockaddr_storage *ss, 194 socklen_t *len) 195 { 196 switch (family) { 197 case AF_INET: 198 init_addr_loopback4(ss, len); 199 return; 200 case AF_INET6: 201 init_addr_loopback6(ss, len); 202 return; 203 case AF_UNIX: 204 init_addr_loopback_unix(ss, len); 205 return; 206 case AF_VSOCK: 207 init_addr_loopback_vsock(ss, len); 208 return; 209 default: 210 FAIL("unsupported address family %d", family); 211 } 212 } 213 214 static inline int enable_reuseport(int s, int progfd) 215 { 216 int err, one = 1; 217 218 err = xsetsockopt(s, SOL_SOCKET, SO_REUSEPORT, &one, sizeof(one)); 219 if (err) 220 return -1; 221 err = xsetsockopt(s, SOL_SOCKET, SO_ATTACH_REUSEPORT_EBPF, &progfd, 222 sizeof(progfd)); 223 if (err) 224 return -1; 225 226 return 0; 227 } 228 229 static inline int socket_loopback_reuseport(int family, int sotype, int progfd) 230 { 231 struct sockaddr_storage addr; 232 socklen_t len = 0; 233 int err, s; 234 235 init_addr_loopback(family, &addr, &len); 236 237 s = xsocket(family, sotype, 0); 238 if (s == -1) 239 return -1; 240 241 if (progfd >= 0) 242 enable_reuseport(s, progfd); 243 244 err = xbind(s, sockaddr(&addr), len); 245 if (err) 246 goto close; 247 248 if (sotype & SOCK_DGRAM) 249 return s; 250 251 err = xlisten(s, SOMAXCONN); 252 if (err) 253 goto close; 254 255 return s; 256 close: 257 xclose(s); 258 return -1; 259 } 260 261 static inline int socket_loopback(int family, int sotype) 262 { 263 return socket_loopback_reuseport(family, sotype, -1); 264 } 265 266 static inline int poll_connect(int fd, unsigned int timeout_sec) 267 { 268 struct timeval timeout = { .tv_sec = timeout_sec }; 269 fd_set wfds; 270 int r, eval; 271 socklen_t esize = sizeof(eval); 272 273 FD_ZERO(&wfds); 274 FD_SET(fd, &wfds); 275 276 r = select(fd + 1, NULL, &wfds, NULL, &timeout); 277 if (r == 0) 278 errno = ETIME; 279 if (r != 1) 280 return -1; 281 282 if (getsockopt(fd, SOL_SOCKET, SO_ERROR, &eval, &esize) < 0) 283 return -1; 284 if (eval != 0) { 285 errno = eval; 286 return -1; 287 } 288 289 return 0; 290 } 291 292 static inline int poll_read(int fd, unsigned int timeout_sec) 293 { 294 struct timeval timeout = { .tv_sec = timeout_sec }; 295 fd_set rfds; 296 int r; 297 298 FD_ZERO(&rfds); 299 FD_SET(fd, &rfds); 300 301 r = select(fd + 1, &rfds, NULL, NULL, &timeout); 302 if (r == 0) 303 errno = ETIME; 304 305 return r == 1 ? 0 : -1; 306 } 307 308 static inline int accept_timeout(int fd, struct sockaddr *addr, socklen_t *len, 309 unsigned int timeout_sec) 310 { 311 if (poll_read(fd, timeout_sec)) 312 return -1; 313 314 return accept(fd, addr, len); 315 } 316 317 static inline int recv_timeout(int fd, void *buf, size_t len, int flags, 318 unsigned int timeout_sec) 319 { 320 if (poll_read(fd, timeout_sec)) 321 return -1; 322 323 return recv(fd, buf, len, flags); 324 } 325 326 327 static inline int create_pair(int family, int sotype, int *p0, int *p1) 328 { 329 __close_fd int s, c = -1, p = -1; 330 struct sockaddr_storage addr; 331 socklen_t len; 332 int err; 333 334 s = socket_loopback(family, sotype); 335 if (s < 0) 336 return s; 337 338 c = xsocket(family, sotype, 0); 339 if (c < 0) 340 return c; 341 342 init_addr_loopback(family, &addr, &len); 343 err = xbind(c, sockaddr(&addr), len); 344 if (err) 345 return err; 346 347 len = sizeof(addr); 348 err = xgetsockname(s, sockaddr(&addr), &len); 349 if (err) 350 return err; 351 352 err = connect(c, sockaddr(&addr), len); 353 if (err) { 354 if (errno != EINPROGRESS) { 355 FAIL_ERRNO("connect"); 356 return err; 357 } 358 359 err = poll_connect(c, IO_TIMEOUT_SEC); 360 if (err) { 361 FAIL_ERRNO("poll_connect"); 362 return err; 363 } 364 } 365 366 switch (sotype & SOCK_TYPE_MASK) { 367 case SOCK_DGRAM: 368 err = xgetsockname(c, sockaddr(&addr), &len); 369 if (err) 370 return err; 371 372 err = xconnect(s, sockaddr(&addr), len); 373 if (err) 374 return err; 375 376 *p0 = take_fd(s); 377 break; 378 case SOCK_STREAM: 379 case SOCK_SEQPACKET: 380 p = xaccept_nonblock(s, NULL, NULL); 381 if (p < 0) 382 return p; 383 384 *p0 = take_fd(p); 385 break; 386 default: 387 FAIL("Unsupported socket type %#x", sotype); 388 return -EOPNOTSUPP; 389 } 390 391 *p1 = take_fd(c); 392 return 0; 393 } 394 395 static inline int create_socket_pairs(int family, int sotype, int *c0, int *c1, 396 int *p0, int *p1) 397 { 398 int err; 399 400 err = create_pair(family, sotype, c0, p0); 401 if (err) 402 return err; 403 404 err = create_pair(family, sotype, c1, p1); 405 if (err) { 406 close(*c0); 407 close(*p0); 408 } 409 410 return err; 411 } 412 413 static inline const char *socket_kind_to_str(int sock_fd) 414 { 415 socklen_t opt_len; 416 int domain, type; 417 418 opt_len = sizeof(domain); 419 if (getsockopt(sock_fd, SOL_SOCKET, SO_DOMAIN, &domain, &opt_len)) 420 FAIL_ERRNO("getsockopt(SO_DOMAIN)"); 421 422 opt_len = sizeof(type); 423 if (getsockopt(sock_fd, SOL_SOCKET, SO_TYPE, &type, &opt_len)) 424 FAIL_ERRNO("getsockopt(SO_TYPE)"); 425 426 switch (domain) { 427 case AF_INET: 428 switch (type) { 429 case SOCK_STREAM: 430 return "tcp4"; 431 case SOCK_DGRAM: 432 return "udp4"; 433 } 434 break; 435 case AF_INET6: 436 switch (type) { 437 case SOCK_STREAM: 438 return "tcp6"; 439 case SOCK_DGRAM: 440 return "udp6"; 441 } 442 break; 443 case AF_UNIX: 444 switch (type) { 445 case SOCK_STREAM: 446 return "u_str"; 447 case SOCK_DGRAM: 448 return "u_dgr"; 449 case SOCK_SEQPACKET: 450 return "u_seq"; 451 } 452 break; 453 case AF_VSOCK: 454 switch (type) { 455 case SOCK_STREAM: 456 return "v_str"; 457 case SOCK_DGRAM: 458 return "v_dgr"; 459 case SOCK_SEQPACKET: 460 return "v_seq"; 461 } 462 break; 463 } 464 465 return "???"; 466 } 467 468 #endif // __SOCKET_HELPERS__ 469