/* SPDX-License-Identifier: GPL-2.0 */

#ifndef __SOCKET_HELPERS__
#define __SOCKET_HELPERS__

#include <linux/vm_sockets.h>

/*
 * Socket helpers for BPF selftests.
 *
 * Only <linux/vm_sockets.h> is included here. The including file must
 * already provide the libc socket/select/errno/error_at_line() APIs, as
 * well as CHECK_FAIL() and libbpf_strerror(), which this header uses but
 * does not define.
 */

/* include/linux/net.h */
#define SOCK_TYPE_MASK 0xf

#define IO_TIMEOUT_SEC 30
#define MAX_STRERR_LEN 256

/* workaround for older vm_sockets.h */
#ifndef VMADDR_CID_LOCAL
#define VMADDR_CID_LOCAL 1
#endif

/*
 * include/linux/cleanup.h
 *
 * Yield the current value of lvalue @p and overwrite @p with @nullvalue.
 */
#define __get_and_null(p, nullvalue) \
	({ \
		__auto_type __ptr = &(p); \
		__auto_type __val = *__ptr; \
		*__ptr = nullvalue; \
		__val; \
	})

/*
 * Move an fd out of a variable: returns the fd and leaves -EBADF behind.
 * Because -EBADF is negative, close_fd() skips it afterwards, so this is
 * how ownership is transferred out of a __close_fd variable.
 */
#define take_fd(fd) __get_and_null(fd, -EBADF)

/* Wrappers that fail the test on error and report it. */

/*
 * Report @fmt via error_at_line(), tagged with the calling function and
 * line, and mark the test as failed via CHECK_FAIL().
 */
#define _FAIL(errnum, fmt...) \
	({ \
		error_at_line(0, (errnum), __func__, __LINE__, fmt); \
		CHECK_FAIL(true); \
	})
#define FAIL(fmt...) _FAIL(0, fmt)
#define FAIL_ERRNO(fmt...) _FAIL(errno, fmt)
/* Fail the test with @msg followed by libbpf's description of @err. */
#define FAIL_LIBBPF(err, msg) \
	({ \
		char __buf[MAX_STRERR_LEN]; \
		libbpf_strerror((err), __buf, sizeof(__buf)); \
		FAIL("%s: %s", (msg), __buf); \
	})

/*
 * Each x*() wrapper below calls the underlying function once, fails the
 * test (reporting errno) when it returns -1, and evaluates to its result.
 */

#define xaccept_nonblock(fd, addr, len) \
	({ \
		int __ret = \
			accept_timeout((fd), (addr), (len), IO_TIMEOUT_SEC); \
		if (__ret == -1) \
			FAIL_ERRNO("accept"); \
		__ret; \
	})

#define xbind(fd, addr, len) \
	({ \
		int __ret = bind((fd), (addr), (len)); \
		if (__ret == -1) \
			FAIL_ERRNO("bind"); \
		__ret; \
	})

#define xclose(fd) \
	({ \
		int __ret = close((fd)); \
		if (__ret == -1) \
			FAIL_ERRNO("close"); \
		__ret; \
	})

#define xconnect(fd, addr, len) \
	({ \
		int __ret = connect((fd), (addr), (len)); \
		if (__ret == -1) \
			FAIL_ERRNO("connect"); \
		__ret; \
	})

#define xgetsockname(fd, addr, len) \
	({ \
		int __ret = getsockname((fd), (addr), (len)); \
		if (__ret == -1) \
			FAIL_ERRNO("getsockname"); \
		__ret; \
	})

#define xgetsockopt(fd, level, name, val, len) \
	({ \
		int __ret = getsockopt((fd), (level), (name), (val), (len)); \
		if (__ret == -1) \
			FAIL_ERRNO("getsockopt(" #name ")"); \
		__ret; \
	})

#define xlisten(fd, backlog) \
	({ \
		int __ret = listen((fd), (backlog)); \
		if (__ret == -1) \
			FAIL_ERRNO("listen"); \
		__ret; \
	})

#define xsetsockopt(fd, level, name, val, len) \
	({ \
		int __ret = setsockopt((fd), (level), (name), (val), (len)); \
		if (__ret == -1) \
			FAIL_ERRNO("setsockopt(" #name ")"); \
		__ret; \
	})

#define xsend(fd, buf, len, flags) \
	({ \
		ssize_t __ret = send((fd), (buf), (len), (flags)); \
		if (__ret == -1) \
			FAIL_ERRNO("send"); \
		__ret; \
	})

#define xrecv_nonblock(fd, buf, len, flags) \
	({ \
		ssize_t __ret = recv_timeout((fd), (buf), (len), (flags), \
					     IO_TIMEOUT_SEC); \
		if (__ret == -1) \
			FAIL_ERRNO("recv"); \
		__ret; \
	})

#define xsocket(family, sotype, flags) \
	({ \
		int __ret = socket((family), (sotype), (flags)); \
		if (__ret == -1) \
			FAIL_ERRNO("socket"); \
		__ret; \
	})

/* Cleanup handler for __close_fd: close *fd unless it holds a negative value. */
static inline void close_fd(int *fd)
{
	if (*fd >= 0)
		xclose(*fd);
}

/* Variable attribute: close the fd automatically when it goes out of scope. */
#define __close_fd __attribute__((cleanup(close_fd)))

/* View a sockaddr_storage as the generic struct sockaddr expected by the socket API. */
static inline struct sockaddr *sockaddr(struct sockaddr_storage *ss)
{
	return (struct sockaddr *)ss;
}

/* Zero @ss, fill it with 127.0.0.1 port 0, and store its length in @len. */
static inline void init_addr_loopback4(struct sockaddr_storage *ss,
				       socklen_t *len)
{
	struct sockaddr_in *addr4 = memset(ss, 0, sizeof(*ss));

	addr4->sin_family = AF_INET;
	addr4->sin_port = 0;
	addr4->sin_addr.s_addr = htonl(INADDR_LOOPBACK);
	*len = sizeof(*addr4);
}

/* Zero @ss, fill it with ::1 port 0, and store its length in @len. */
static inline void init_addr_loopback6(struct sockaddr_storage *ss,
				       socklen_t *len)
{
	struct sockaddr_in6 *addr6 = memset(ss, 0, sizeof(*ss));

	addr6->sin6_family = AF_INET6;
	addr6->sin6_port = 0;
	addr6->sin6_addr = in6addr_loopback;
	*len = sizeof(*addr6);
}

/*
 * Zero @ss, fill it with the local vsock CID and VMADDR_PORT_ANY, and store
 * its length in @len.
 */
static inline void init_addr_loopback_vsock(struct sockaddr_storage *ss,
					    socklen_t *len)
{
	struct sockaddr_vm *addr = memset(ss, 0, sizeof(*ss));

	addr->svm_family = AF_VSOCK;
	addr->svm_port = VMADDR_PORT_ANY;
	addr->svm_cid = VMADDR_CID_LOCAL;
	*len = sizeof(*addr);
}

/*
 * Fill @ss with the loopback address for @family (AF_INET, AF_INET6 or
 * AF_VSOCK) and store its length in @len. Any other family fails the test
 * and leaves *len set to 0.
 */
static inline void init_addr_loopback(int family, struct sockaddr_storage *ss,
				      socklen_t *len)
{
	switch (family) {
	case AF_INET:
		init_addr_loopback4(ss, len);
		return;
	case AF_INET6:
		init_addr_loopback6(ss, len);
		return;
	case AF_VSOCK:
		init_addr_loopback_vsock(ss, len);
		return;
	default:
		FAIL("unsupported address family %d", family);
		/* Don't leave the caller with an uninitialized length. */
		*len = 0;
	}
}

/*
 * Enable SO_REUSEPORT on @s and attach the BPF program @progfd with
 * SO_ATTACH_REUSEPORT_EBPF. Returns 0 on success, or -1 after the failing
 * setsockopt() has already been reported.
 */
static inline int enable_reuseport(int s, int progfd)
{
	int err, one = 1;

	err = xsetsockopt(s, SOL_SOCKET, SO_REUSEPORT, &one, sizeof(one));
	if (err)
		return -1;
	err = xsetsockopt(s, SOL_SOCKET, SO_ATTACH_REUSEPORT_EBPF, &progfd,
			  sizeof(progfd));
	if (err)
		return -1;

	return 0;
}

/*
 * Create a socket of @family/@sotype bound to the loopback address with an
 * ephemeral port. If @progfd >= 0, also enable reuseport with that BPF
 * program attached first. Sockets whose type has the SOCK_DGRAM bit set are
 * returned right after bind(); all others are also put into listen().
 *
 * Returns the socket fd, or -1 on failure (the socket is closed).
 */
static inline int socket_loopback_reuseport(int family, int sotype, int progfd)
{
	struct sockaddr_storage addr;
	socklen_t len = 0;
	int err, s;

	init_addr_loopback(family, &addr, &len);

	s = xsocket(family, sotype, 0);
	if (s == -1)
		return -1;

	if (progfd >= 0) {
		/* Don't hand back a socket that lacks the requested reuseport setup. */
		err = enable_reuseport(s, progfd);
		if (err)
			goto close;
	}

	err = xbind(s, sockaddr(&addr), len);
	if (err)
		goto close;

	if (sotype & SOCK_DGRAM)
		return s;

	err = xlisten(s, SOMAXCONN);
	if (err)
		goto close;

	return s;
close:
	xclose(s);
	return -1;
}

/* socket_loopback_reuseport() without reuseport setup. */
static inline int socket_loopback(int family, int sotype)
{
	return socket_loopback_reuseport(family, sotype, -1);
}

/*
 * Wait up to @timeout_sec seconds for a pending connect() on @fd to
 * complete.
 *
 * Returns 0 on success. On failure returns -1 with errno set as follows:
 * - ETIME on timeout;
 * - EINVAL if @fd cannot be stored in an fd_set;
 * - the socket's SO_ERROR value if the connection attempt failed;
 * - otherwise, whatever select()/getsockopt() reported.
 */
static inline int poll_connect(int fd, unsigned int timeout_sec)
{
	struct timeval timeout = { .tv_sec = timeout_sec };
	socklen_t esize = sizeof(int);
	fd_set wfds;
	int r, eval;

	/* FD_SET() with an fd outside [0, FD_SETSIZE) is undefined behavior. */
	if (fd < 0 || fd >= FD_SETSIZE) {
		errno = EINVAL;
		return -1;
	}

	FD_ZERO(&wfds);
	FD_SET(fd, &wfds);

	r = select(fd + 1, NULL, &wfds, NULL, &timeout);
	if (r == 0)
		errno = ETIME;
	if (r != 1)
		return -1;

	/* Writability only says the attempt finished; SO_ERROR says how. */
	if (getsockopt(fd, SOL_SOCKET, SO_ERROR, &eval, &esize) < 0)
		return -1;
	if (eval != 0) {
		errno = eval;
		return -1;
	}

	return 0;
}

/*
 * Wait up to @timeout_sec seconds for @fd to become readable.
 *
 * Returns 0 when readable. On failure returns -1 with errno set as follows:
 * - ETIME on timeout;
 * - EINVAL if @fd cannot be stored in an fd_set;
 * - otherwise, whatever select() reported.
 */
static inline int poll_read(int fd, unsigned int timeout_sec)
{
	struct timeval timeout = { .tv_sec = timeout_sec };
	fd_set rfds;
	int r;

	/* FD_SET() with an fd outside [0, FD_SETSIZE) is undefined behavior. */
	if (fd < 0 || fd >= FD_SETSIZE) {
		errno = EINVAL;
		return -1;
	}

	FD_ZERO(&rfds);
	FD_SET(fd, &rfds);

	r = select(fd + 1, &rfds, NULL, NULL, &timeout);
	if (r == 0)
		errno = ETIME;

	return r == 1 ? 0 : -1;
}

/* accept() on @fd once it is readable, or -1 if that doesn't happen in time. */
static inline int accept_timeout(int fd, struct sockaddr *addr, socklen_t *len,
				 unsigned int timeout_sec)
{
	if (poll_read(fd, timeout_sec))
		return -1;

	return accept(fd, addr, len);
}

/*
 * recv() from @fd once it is readable, or -1 if that doesn't happen in time.
 * Returns ssize_t, like recv(), so large reads are not truncated.
 */
static inline ssize_t recv_timeout(int fd, void *buf, size_t len, int flags,
				   unsigned int timeout_sec)
{
	if (poll_read(fd, timeout_sec))
		return -1;

	return recv(fd, buf, len, flags);
}

/*
 * Create a connected loopback socket pair of @family/@sotype.
 *
 * On success returns 0 and sets the outputs as follows:
 * - *p0: the peer socket. For SOCK_DGRAM this is the bound socket, now
 *   connected back to the client; for SOCK_STREAM/SOCK_SEQPACKET it is the
 *   accepted connection.
 * - *p1: the client socket.
 * Ownership of both fds passes to the caller.
 *
 * On failure returns a negative value, leaves *p0 and *p1 untouched, and
 * closes every socket created here via __close_fd.
 */
static inline int create_pair(int family, int sotype, int *p0, int *p1)
{
	__close_fd int s = -1, c = -1, p = -1;
	struct sockaddr_storage addr;
	socklen_t len = sizeof(addr);
	int err;

	s = socket_loopback(family, sotype);
	if (s < 0)
		return s;

	err = xgetsockname(s, sockaddr(&addr), &len);
	if (err)
		return err;

	c = xsocket(family, sotype, 0);
	if (c < 0)
		return c;

	/* A non-blocking @sotype may report EINPROGRESS; wait for completion. */
	err = connect(c, sockaddr(&addr), len);
	if (err) {
		if (errno != EINPROGRESS) {
			FAIL_ERRNO("connect");
			return err;
		}

		err = poll_connect(c, IO_TIMEOUT_SEC);
		if (err) {
			FAIL_ERRNO("poll_connect");
			return err;
		}
	}

	switch (sotype & SOCK_TYPE_MASK) {
	case SOCK_DGRAM:
		/* Connect the bound socket back to the client's address. */
		err = xgetsockname(c, sockaddr(&addr), &len);
		if (err)
			return err;

		err = xconnect(s, sockaddr(&addr), len);
		if (err)
			return err;

		*p0 = take_fd(s);
		break;
	case SOCK_STREAM:
	case SOCK_SEQPACKET:
		p = xaccept_nonblock(s, NULL, NULL);
		if (p < 0)
			return p;

		*p0 = take_fd(p);
		break;
	default:
		FAIL("Unsupported socket type %#x", sotype);
		return -EOPNOTSUPP;
	}

	*p1 = take_fd(c);
	return 0;
}

/*
 * Create two independent socket pairs with create_pair(). Its first
 * output goes to *c0 (first pair) and *c1 (second pair); its second
 * output goes to *p0 and *p1 respectively.
 *
 * Returns 0 on success. On failure returns create_pair()'s error. If the
 * second pair fails, the first pair is closed and *c0/*p0 are reset to -1,
 * so the caller is never left holding stale descriptors.
 */
static inline int create_socket_pairs(int family, int sotype, int *c0, int *c1,
				      int *p0, int *p1)
{
	int err;

	err = create_pair(family, sotype, c0, p0);
	if (err)
		return err;

	err = create_pair(family, sotype, c1, p1);
	if (err) {
		close(*c0);
		close(*p0);
		*c0 = -1;
		*p0 = -1;
	}

	return err;
}

#endif // __SOCKET_HELPERS__