1 /*
2 * Copyright (c) 2011-2012 Intel Corporation. All rights reserved.
3 *
4 * This software is available to you under a choice of one of two
5 * licenses. You may choose to be licensed under the terms of the GNU
6 * General Public License (GPL) Version 2, available from the file
7 * COPYING in the main directory of this source tree, or the
8 * OpenIB.org BSD license below:
9 *
10 * Redistribution and use in source and binary forms, with or
11 * without modification, are permitted provided that the following
12 * conditions are met:
13 *
14 * - Redistributions of source code must retain the above
15 * copyright notice, this list of conditions and the following
16 * disclaimer.
17 *
18 * - Redistributions in binary form must reproduce the above
19 * copyright notice, this list of conditions and the following
20 * disclaimer in the documentation and/or other materials
21 * provided with the distribution.
22 *
23 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
24 * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
25 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
26 * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
27 * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
28 * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
29 * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
30 * SOFTWARE.
31 *
32 */
33 #define _GNU_SOURCE
34 #include <config.h>
35
#include <sys/types.h>
#include <sys/socket.h>
#include <sys/uio.h>
#include <sys/stat.h>
#include <sys/mman.h>
#include <stdarg.h>
#include <dlfcn.h>
#include <netdb.h>
#include <unistd.h>
#include <fcntl.h>
#include <string.h>
#include <netinet/tcp.h>
#include <unistd.h>
#include <semaphore.h>
#include <ctype.h>
#include <stdlib.h>
#include <stdio.h>
#include <errno.h>
#include <limits.h>
#include <stdint.h>
53
54 #include <rdma/rdma_cma.h>
55 #include <rdma/rdma_verbs.h>
56 #include <rdma/rsocket.h>
57 #include "cma.h"
58 #include "indexer.h"
59
/*
 * Table of socket-related entry points.  This file keeps two instances:
 * "real", filled with the next-in-chain symbols, and "rs", filled with the
 * corresponding rsocket symbols.
 */
struct socket_calls {
	int (*socket)(int domain, int type, int protocol);
	int (*bind)(int fd, const struct sockaddr *addr, socklen_t addrlen);
	int (*listen)(int fd, int backlog);
	int (*accept)(int fd, struct sockaddr *addr, socklen_t *addrlen);
	int (*connect)(int fd, const struct sockaddr *addr, socklen_t addrlen);
	ssize_t (*recv)(int fd, void *buf, size_t len, int flags);
	ssize_t (*recvfrom)(int fd, void *buf, size_t len, int flags,
			    struct sockaddr *src_addr, socklen_t *addrlen);
	ssize_t (*recvmsg)(int fd, struct msghdr *msg, int flags);
	ssize_t (*read)(int fd, void *buf, size_t count);
	ssize_t (*readv)(int fd, const struct iovec *iov, int iovcnt);
	ssize_t (*send)(int fd, const void *buf, size_t len, int flags);
	ssize_t (*sendto)(int fd, const void *buf, size_t len, int flags,
			  const struct sockaddr *dest_addr, socklen_t addrlen);
	ssize_t (*sendmsg)(int fd, const struct msghdr *msg, int flags);
	ssize_t (*write)(int fd, const void *buf, size_t count);
	ssize_t (*writev)(int fd, const struct iovec *iov, int iovcnt);
	int (*poll)(struct pollfd *fds, nfds_t nfds, int timeout);
	int (*shutdown)(int fd, int how);
	int (*close)(int fd);
	int (*getpeername)(int fd, struct sockaddr *addr, socklen_t *addrlen);
	int (*getsockname)(int fd, struct sockaddr *addr, socklen_t *addrlen);
	int (*setsockopt)(int fd, int level, int optname,
			  const void *optval, socklen_t optlen);
	int (*getsockopt)(int fd, int level, int optname,
			  void *optval, socklen_t *optlen);
	int (*fcntl)(int fd, int cmd, ... /* arg */);
	int (*dup2)(int oldfd, int newfd);
	ssize_t (*sendfile)(int out_fd, int in_fd, off_t *offset, size_t count);
	int (*fxstat)(int ver, int fd, struct stat *buf);
};
92
/* Next-in-chain (RTLD_NEXT) implementations of the intercepted calls. */
static struct socket_calls real;
/* rsocket counterparts, looked up with RTLD_DEFAULT in init_preload(). */
static struct socket_calls rs;

/* Maps an application-visible descriptor to its struct fd_info. */
static struct index_map idm;
/* Held around one-time initialization and around idm_set() calls. */
static pthread_mutex_t mut = PTHREAD_MUTEX_INITIALIZER;

/* Tunables filled from the environment by getenv_options(). */
static int sq_size;		/* RS_SQ_SIZE; applied as RDMA_SQSIZE when nonzero */
static int rq_size;		/* RS_RQ_SIZE; applied as RDMA_RQSIZE when nonzero */
static int sq_inline;		/* RS_INLINE; applied as RDMA_INLINE when nonzero */
static int fork_support;	/* RDMAV_FORK_SAFE; enables the deferred fork path */
103
/*
 * What the descriptor stored in struct fd_info refers to.  fd_normal is
 * explicitly 0 so that a zero-filled fd_info describes a normal descriptor.
 */
enum fd_type {
	fd_normal = 0,
	fd_rsocket = 1
};
108
/*
 * Progress of the deferred (fork-support) conversion of a real TCP socket
 * into an rsocket.  fd_ready is explicitly 0 so that a zero-filled fd_info
 * has no conversion pending.
 */
enum fd_fork_state {
	fd_ready = 0,		/* nothing pending */
	fd_fork = 1,		/* real socket created in fork-support mode */
	fd_fork_listen = 2,	/* such a socket after a successful listen() */
	fd_fork_active = 3,	/* such a socket after connect() was requested */
	fd_fork_passive = 4	/* socket returned by accept() on a fork listener */
};
116
/*
 * Per-descriptor bookkeeping stored in idm, keyed by the descriptor that is
 * handed to the application.
 */
struct fd_info {
	enum fd_type type;		/* what 'fd' refers to */
	enum fd_fork_state state;	/* pending fork conversion, if any */
	int fd;				/* underlying rsocket or real descriptor */
	int dupfd;			/* -1, or the descriptor this entry was dup2()'ed from */
	_Atomic(int) refcnt;		/* starts at 1; raised by dup2(), dropped by close() */
};
124
/*
 * One accepted line of the preload_config file.  A NULL name or a zero
 * domain/type/protocol acts as a wildcard when entries are matched.
 */
struct config_entry {
	char *name;	/* program name to match, heap-allocated; NULL for '*' */
	int domain;	/* AF_INET, AF_INET6, AF_IB, or 0 */
	int type;	/* SOCK_STREAM, SOCK_DGRAM, or 0 */
	int protocol;	/* IPPROTO_TCP, IPPROTO_UDP, or 0 */
};

/* Table of parsed entries and the number of valid entries in it. */
static struct config_entry *config;
static int config_cnt;
134
free_config(void)135 static void free_config(void)
136 {
137 while (config_cnt)
138 free(config[--config_cnt].name);
139
140 free(config);
141 }
142
143 /*
144 * Config file format:
145 * # Starting '#' indicates comment
146 * # wild card values are supported using '*'
147 * # domain - *, INET, INET6, IB
148 * # type - *, STREAM, DGRAM
149 * # protocol - *, TCP, UDP
150 * program_name domain type protocol
151 */
scan_config(void)152 static void scan_config(void)
153 {
154 struct config_entry *new_config;
155 FILE *fp;
156 char line[120], prog[64], dom[16], type[16], proto[16];
157
158 fp = fopen(RS_CONF_DIR "/preload_config", "r");
159 if (!fp)
160 return;
161
162 while (fgets(line, sizeof(line), fp)) {
163 if (line[0] == '#')
164 continue;
165
166 if (sscanf(line, "%64s%16s%16s%16s", prog, dom, type, proto) != 4)
167 continue;
168
169 new_config = realloc(config, (config_cnt + 1) *
170 sizeof(struct config_entry));
171 if (!new_config)
172 break;
173
174 config = new_config;
175 memset(&config[config_cnt], 0, sizeof(struct config_entry));
176
177 if (!strcasecmp(dom, "INET") ||
178 !strcasecmp(dom, "AF_INET") ||
179 !strcasecmp(dom, "PF_INET")) {
180 config[config_cnt].domain = AF_INET;
181 } else if (!strcasecmp(dom, "INET6") ||
182 !strcasecmp(dom, "AF_INET6") ||
183 !strcasecmp(dom, "PF_INET6")) {
184 config[config_cnt].domain = AF_INET6;
185 } else if (!strcasecmp(dom, "IB") ||
186 !strcasecmp(dom, "AF_IB") ||
187 !strcasecmp(dom, "PF_IB")) {
188 config[config_cnt].domain = AF_IB;
189 } else if (strcmp(dom, "*")) {
190 continue;
191 }
192
193 if (!strcasecmp(type, "STREAM") ||
194 !strcasecmp(type, "SOCK_STREAM")) {
195 config[config_cnt].type = SOCK_STREAM;
196 } else if (!strcasecmp(type, "DGRAM") ||
197 !strcasecmp(type, "SOCK_DGRAM")) {
198 config[config_cnt].type = SOCK_DGRAM;
199 } else if (strcmp(type, "*")) {
200 continue;
201 }
202
203 if (!strcasecmp(proto, "TCP") ||
204 !strcasecmp(proto, "IPPROTO_TCP")) {
205 config[config_cnt].protocol = IPPROTO_TCP;
206 } else if (!strcasecmp(proto, "UDP") ||
207 !strcasecmp(proto, "IPPROTO_UDP")) {
208 config[config_cnt].protocol = IPPROTO_UDP;
209 } else if (strcmp(proto, "*")) {
210 continue;
211 }
212
213 if (strcmp(prog, "*")) {
214 if (!(config[config_cnt].name = strdup(prog)))
215 continue;
216 }
217
218 config_cnt++;
219 }
220
221 fclose(fp);
222 if (config_cnt)
223 atexit(free_config);
224 }
225
intercept_socket(int domain,int type,int protocol)226 static int intercept_socket(int domain, int type, int protocol)
227 {
228 int i;
229
230 if (!config_cnt)
231 return 1;
232
233 if (!protocol) {
234 if (type == SOCK_STREAM)
235 protocol = IPPROTO_TCP;
236 else if (type == SOCK_DGRAM)
237 protocol = IPPROTO_UDP;
238 }
239
240 for (i = 0; i < config_cnt; i++) {
241 if ((!config[i].name ||
242 !strncasecmp(config[i].name, program_invocation_short_name,
243 strlen(config[i].name))) &&
244 (!config[i].domain || config[i].domain == domain) &&
245 (!config[i].type || config[i].type == type) &&
246 (!config[i].protocol || config[i].protocol == protocol))
247 return 1;
248 }
249
250 return 0;
251 }
252
fd_open(void)253 static int fd_open(void)
254 {
255 struct fd_info *fdi;
256 int ret, index;
257
258 fdi = calloc(1, sizeof(*fdi));
259 if (!fdi)
260 return ERR(ENOMEM);
261
262 index = open("/dev/null", O_RDONLY);
263 if (index < 0) {
264 ret = index;
265 goto err1;
266 }
267
268 fdi->dupfd = -1;
269 atomic_store(&fdi->refcnt, 1);
270 pthread_mutex_lock(&mut);
271 ret = idm_set(&idm, index, fdi);
272 pthread_mutex_unlock(&mut);
273 if (ret < 0)
274 goto err2;
275
276 return index;
277
278 err2:
279 real.close(index);
280 err1:
281 free(fdi);
282 return ret;
283 }
284
fd_store(int index,int fd,enum fd_type type,enum fd_fork_state state)285 static void fd_store(int index, int fd, enum fd_type type, enum fd_fork_state state)
286 {
287 struct fd_info *fdi;
288
289 fdi = idm_at(&idm, index);
290 fdi->fd = fd;
291 fdi->type = type;
292 fdi->state = state;
293 }
294
fd_get(int index,int * fd)295 static inline enum fd_type fd_get(int index, int *fd)
296 {
297 struct fd_info *fdi;
298
299 fdi = idm_lookup(&idm, index);
300 if (fdi) {
301 *fd = fdi->fd;
302 return fdi->type;
303
304 } else {
305 *fd = index;
306 return fd_normal;
307 }
308 }
309
fd_getd(int index)310 static inline int fd_getd(int index)
311 {
312 struct fd_info *fdi;
313
314 fdi = idm_lookup(&idm, index);
315 return fdi ? fdi->fd : index;
316 }
317
fd_gets(int index)318 static inline enum fd_fork_state fd_gets(int index)
319 {
320 struct fd_info *fdi;
321
322 fdi = idm_lookup(&idm, index);
323 return fdi ? fdi->state : fd_ready;
324 }
325
fd_gett(int index)326 static inline enum fd_type fd_gett(int index)
327 {
328 struct fd_info *fdi;
329
330 fdi = idm_lookup(&idm, index);
331 return fdi ? fdi->type : fd_normal;
332 }
333
fd_close(int index,int * fd)334 static enum fd_type fd_close(int index, int *fd)
335 {
336 struct fd_info *fdi;
337 enum fd_type type;
338
339 fdi = idm_lookup(&idm, index);
340 if (fdi) {
341 idm_clear(&idm, index);
342 *fd = fdi->fd;
343 type = fdi->type;
344 real.close(index);
345 free(fdi);
346 } else {
347 *fd = index;
348 type = fd_normal;
349 }
350 return type;
351 }
352
getenv_options(void)353 static void getenv_options(void)
354 {
355 char *var;
356
357 var = getenv("RS_SQ_SIZE");
358 if (var)
359 sq_size = atoi(var);
360
361 var = getenv("RS_RQ_SIZE");
362 if (var)
363 rq_size = atoi(var);
364
365 var = getenv("RS_INLINE");
366 if (var)
367 sq_inline = atoi(var);
368
369 var = getenv("RDMAV_FORK_SAFE");
370 if (var)
371 fork_support = atoi(var);
372 }
373
init_preload(void)374 static void init_preload(void)
375 {
376 static int init;
377
378 /* Quick check without lock */
379 if (init)
380 return;
381
382 pthread_mutex_lock(&mut);
383 if (init)
384 goto out;
385
386 real.socket = dlsym(RTLD_NEXT, "socket");
387 real.bind = dlsym(RTLD_NEXT, "bind");
388 real.listen = dlsym(RTLD_NEXT, "listen");
389 real.accept = dlsym(RTLD_NEXT, "accept");
390 real.connect = dlsym(RTLD_NEXT, "connect");
391 real.recv = dlsym(RTLD_NEXT, "recv");
392 real.recvfrom = dlsym(RTLD_NEXT, "recvfrom");
393 real.recvmsg = dlsym(RTLD_NEXT, "recvmsg");
394 real.read = dlsym(RTLD_NEXT, "read");
395 real.readv = dlsym(RTLD_NEXT, "readv");
396 real.send = dlsym(RTLD_NEXT, "send");
397 real.sendto = dlsym(RTLD_NEXT, "sendto");
398 real.sendmsg = dlsym(RTLD_NEXT, "sendmsg");
399 real.write = dlsym(RTLD_NEXT, "write");
400 real.writev = dlsym(RTLD_NEXT, "writev");
401 real.poll = dlsym(RTLD_NEXT, "poll");
402 real.shutdown = dlsym(RTLD_NEXT, "shutdown");
403 real.close = dlsym(RTLD_NEXT, "close");
404 real.getpeername = dlsym(RTLD_NEXT, "getpeername");
405 real.getsockname = dlsym(RTLD_NEXT, "getsockname");
406 real.setsockopt = dlsym(RTLD_NEXT, "setsockopt");
407 real.getsockopt = dlsym(RTLD_NEXT, "getsockopt");
408 real.fcntl = dlsym(RTLD_NEXT, "fcntl");
409 real.dup2 = dlsym(RTLD_NEXT, "dup2");
410 real.sendfile = dlsym(RTLD_NEXT, "sendfile");
411 real.fxstat = dlsym(RTLD_NEXT, "__fxstat");
412
413 rs.socket = dlsym(RTLD_DEFAULT, "rsocket");
414 rs.bind = dlsym(RTLD_DEFAULT, "rbind");
415 rs.listen = dlsym(RTLD_DEFAULT, "rlisten");
416 rs.accept = dlsym(RTLD_DEFAULT, "raccept");
417 rs.connect = dlsym(RTLD_DEFAULT, "rconnect");
418 rs.recv = dlsym(RTLD_DEFAULT, "rrecv");
419 rs.recvfrom = dlsym(RTLD_DEFAULT, "rrecvfrom");
420 rs.recvmsg = dlsym(RTLD_DEFAULT, "rrecvmsg");
421 rs.read = dlsym(RTLD_DEFAULT, "rread");
422 rs.readv = dlsym(RTLD_DEFAULT, "rreadv");
423 rs.send = dlsym(RTLD_DEFAULT, "rsend");
424 rs.sendto = dlsym(RTLD_DEFAULT, "rsendto");
425 rs.sendmsg = dlsym(RTLD_DEFAULT, "rsendmsg");
426 rs.write = dlsym(RTLD_DEFAULT, "rwrite");
427 rs.writev = dlsym(RTLD_DEFAULT, "rwritev");
428 rs.poll = dlsym(RTLD_DEFAULT, "rpoll");
429 rs.shutdown = dlsym(RTLD_DEFAULT, "rshutdown");
430 rs.close = dlsym(RTLD_DEFAULT, "rclose");
431 rs.getpeername = dlsym(RTLD_DEFAULT, "rgetpeername");
432 rs.getsockname = dlsym(RTLD_DEFAULT, "rgetsockname");
433 rs.setsockopt = dlsym(RTLD_DEFAULT, "rsetsockopt");
434 rs.getsockopt = dlsym(RTLD_DEFAULT, "rgetsockopt");
435 rs.fcntl = dlsym(RTLD_DEFAULT, "rfcntl");
436
437 getenv_options();
438 scan_config();
439 init = 1;
440 out:
441 pthread_mutex_unlock(&mut);
442 }
443
444 /*
445 * We currently only handle copying a few common values.
446 */
copysockopts(int dfd,int sfd,struct socket_calls * dapi,struct socket_calls * sapi)447 static int copysockopts(int dfd, int sfd, struct socket_calls *dapi,
448 struct socket_calls *sapi)
449 {
450 socklen_t len;
451 int param, ret;
452
453 ret = sapi->fcntl(sfd, F_GETFL);
454 if (ret > 0)
455 ret = dapi->fcntl(dfd, F_SETFL, ret);
456 if (ret)
457 return ret;
458
459 len = sizeof param;
460 ret = sapi->getsockopt(sfd, SOL_SOCKET, SO_REUSEADDR, ¶m, &len);
461 if (param && !ret)
462 ret = dapi->setsockopt(dfd, SOL_SOCKET, SO_REUSEADDR, ¶m, len);
463 if (ret)
464 return ret;
465
466 len = sizeof param;
467 ret = sapi->getsockopt(sfd, IPPROTO_TCP, TCP_NODELAY, ¶m, &len);
468 if (param && !ret)
469 ret = dapi->setsockopt(dfd, IPPROTO_TCP, TCP_NODELAY, ¶m, len);
470 if (ret)
471 return ret;
472
473 return 0;
474 }
475
476 /*
477 * Convert between an rsocket and a normal socket.
478 */
transpose_socket(int socket,enum fd_type new_type)479 static int transpose_socket(int socket, enum fd_type new_type)
480 {
481 socklen_t len = 0;
482 int sfd, dfd, param, ret;
483 struct socket_calls *sapi, *dapi;
484
485 sfd = fd_getd(socket);
486 if (new_type == fd_rsocket) {
487 dapi = &rs;
488 sapi = ℜ
489 } else {
490 dapi = ℜ
491 sapi = &rs;
492 }
493
494 ret = sapi->getsockname(sfd, NULL, &len);
495 if (ret)
496 return ret;
497
498 param = (len == sizeof(struct sockaddr_in6)) ? PF_INET6 : PF_INET;
499 dfd = dapi->socket(param, SOCK_STREAM, 0);
500 if (dfd < 0)
501 return dfd;
502
503 ret = copysockopts(dfd, sfd, dapi, sapi);
504 if (ret)
505 goto err;
506
507 fd_store(socket, dfd, new_type, fd_ready);
508 return dfd;
509
510 err:
511 dapi->close(dfd);
512 return ret;
513 }
514
515 /*
516 * Use defaults on failure.
517 */
set_rsocket_options(int rsocket)518 static void set_rsocket_options(int rsocket)
519 {
520 if (sq_size)
521 rsetsockopt(rsocket, SOL_RDMA, RDMA_SQSIZE, &sq_size, sizeof sq_size);
522
523 if (rq_size)
524 rsetsockopt(rsocket, SOL_RDMA, RDMA_RQSIZE, &rq_size, sizeof rq_size);
525
526 if (sq_inline)
527 rsetsockopt(rsocket, SOL_RDMA, RDMA_INLINE, &sq_inline, sizeof sq_inline);
528 }
529
socket(int domain,int type,int protocol)530 int socket(int domain, int type, int protocol)
531 {
532 static __thread int recursive;
533 int index, ret;
534
535 init_preload();
536
537 if (recursive || !intercept_socket(domain, type, protocol))
538 goto real;
539
540 index = fd_open();
541 if (index < 0)
542 return index;
543
544 if (fork_support && (domain == PF_INET || domain == PF_INET6) &&
545 (type == SOCK_STREAM) && (!protocol || protocol == IPPROTO_TCP)) {
546 ret = real.socket(domain, type, protocol);
547 if (ret < 0)
548 return ret;
549 fd_store(index, ret, fd_normal, fd_fork);
550 return index;
551 }
552
553 recursive = 1;
554 ret = rsocket(domain, type, protocol);
555 recursive = 0;
556 if (ret >= 0) {
557 fd_store(index, ret, fd_rsocket, fd_ready);
558 set_rsocket_options(ret);
559 return index;
560 }
561 fd_close(index, &ret);
562 real:
563 return real.socket(domain, type, protocol);
564 }
565
bind(int socket,const struct sockaddr * addr,socklen_t addrlen)566 int bind(int socket, const struct sockaddr *addr, socklen_t addrlen)
567 {
568 int fd;
569 return (fd_get(socket, &fd) == fd_rsocket) ?
570 rbind(fd, addr, addrlen) : real.bind(fd, addr, addrlen);
571 }
572
listen(int socket,int backlog)573 int listen(int socket, int backlog)
574 {
575 int fd, ret;
576 if (fd_get(socket, &fd) == fd_rsocket) {
577 ret = rlisten(fd, backlog);
578 } else {
579 ret = real.listen(fd, backlog);
580 if (!ret && fd_gets(socket) == fd_fork)
581 fd_store(socket, fd, fd_normal, fd_fork_listen);
582 }
583 return ret;
584 }
585
accept(int socket,struct sockaddr * addr,socklen_t * addrlen)586 int accept(int socket, struct sockaddr *addr, socklen_t *addrlen)
587 {
588 int fd, index, ret;
589
590 if (fd_get(socket, &fd) == fd_rsocket) {
591 index = fd_open();
592 if (index < 0)
593 return index;
594
595 ret = raccept(fd, addr, addrlen);
596 if (ret < 0) {
597 fd_close(index, &fd);
598 return ret;
599 }
600
601 fd_store(index, ret, fd_rsocket, fd_ready);
602 return index;
603 } else if (fd_gets(socket) == fd_fork_listen) {
604 index = fd_open();
605 if (index < 0)
606 return index;
607
608 ret = real.accept(fd, addr, addrlen);
609 if (ret < 0) {
610 fd_close(index, &fd);
611 return ret;
612 }
613
614 fd_store(index, ret, fd_normal, fd_fork_passive);
615 return index;
616 } else {
617 return real.accept(fd, addr, addrlen);
618 }
619 }
620
621 /*
622 * We can't fork RDMA connections and pass them from the parent to the child
623 * process. Instead, we need to establish the RDMA connection after calling
624 * fork. To do this, we delay establishing the RDMA connection until we try
625 * to send/receive on the server side.
626 */
fork_active(int socket)627 static void fork_active(int socket)
628 {
629 struct sockaddr_storage addr;
630 int sfd, dfd, ret;
631 socklen_t len;
632 uint32_t msg;
633 long flags;
634
635 sfd = fd_getd(socket);
636
637 flags = real.fcntl(sfd, F_GETFL);
638 real.fcntl(sfd, F_SETFL, 0);
639 ret = real.recv(sfd, &msg, sizeof msg, MSG_PEEK);
640 real.fcntl(sfd, F_SETFL, flags);
641 if ((ret != sizeof msg) || msg)
642 goto err1;
643
644 len = sizeof addr;
645 ret = real.getpeername(sfd, (struct sockaddr *) &addr, &len);
646 if (ret)
647 goto err1;
648
649 dfd = rsocket(addr.ss_family, SOCK_STREAM, 0);
650 if (dfd < 0)
651 goto err1;
652
653 ret = rconnect(dfd, (struct sockaddr *) &addr, len);
654 if (ret)
655 goto err2;
656
657 set_rsocket_options(dfd);
658 copysockopts(dfd, sfd, &rs, &real);
659 real.shutdown(sfd, SHUT_RDWR);
660 real.close(sfd);
661 fd_store(socket, dfd, fd_rsocket, fd_ready);
662 return;
663
664 err2:
665 rclose(dfd);
666 err1:
667 fd_store(socket, sfd, fd_normal, fd_ready);
668 }
669
670 /*
671 * The server will start listening for the new connection, then send a
672 * message to the active side when the listen is ready. This does leave
673 * fork unsupported in the following case: the server is nonblocking and
674 * calls select/poll waiting to receive data from the client.
675 */
fork_passive(int socket)676 static void fork_passive(int socket)
677 {
678 struct sockaddr_in6 sin6;
679 sem_t *sem;
680 int lfd, sfd, dfd, ret, param;
681 socklen_t len;
682 uint32_t msg;
683
684 sfd = fd_getd(socket);
685
686 len = sizeof sin6;
687 ret = real.getsockname(sfd, (struct sockaddr *) &sin6, &len);
688 if (ret)
689 goto out;
690 sin6.sin6_flowinfo = 0;
691 sin6.sin6_scope_id = 0;
692 memset(&sin6.sin6_addr, 0, sizeof sin6.sin6_addr);
693
694 sem = sem_open("/rsocket_fork", O_CREAT | O_RDWR,
695 S_IRWXU | S_IRWXG, 1);
696 if (sem == SEM_FAILED) {
697 ret = -1;
698 goto out;
699 }
700
701 lfd = rsocket(sin6.sin6_family, SOCK_STREAM, 0);
702 if (lfd < 0) {
703 ret = lfd;
704 goto sclose;
705 }
706
707 param = 1;
708 rsetsockopt(lfd, SOL_SOCKET, SO_REUSEADDR, ¶m, sizeof param);
709
710 sem_wait(sem);
711 ret = rbind(lfd, (struct sockaddr *) &sin6, sizeof sin6);
712 if (ret)
713 goto lclose;
714
715 ret = rlisten(lfd, 1);
716 if (ret)
717 goto lclose;
718
719 msg = 0;
720 len = real.write(sfd, &msg, sizeof msg);
721 if (len != sizeof msg)
722 goto lclose;
723
724 dfd = raccept(lfd, NULL, NULL);
725 if (dfd < 0) {
726 ret = dfd;
727 goto lclose;
728 }
729
730 set_rsocket_options(dfd);
731 copysockopts(dfd, sfd, &rs, &real);
732 real.shutdown(sfd, SHUT_RDWR);
733 real.close(sfd);
734 fd_store(socket, dfd, fd_rsocket, fd_ready);
735
736 lclose:
737 rclose(lfd);
738 sem_post(sem);
739 sclose:
740 sem_close(sem);
741 out:
742 if (ret)
743 fd_store(socket, sfd, fd_normal, fd_ready);
744 }
745
fd_fork_get(int index,int * fd)746 static inline enum fd_type fd_fork_get(int index, int *fd)
747 {
748 struct fd_info *fdi;
749
750 fdi = idm_lookup(&idm, index);
751 if (fdi) {
752 if (fdi->state == fd_fork_passive)
753 fork_passive(index);
754 else if (fdi->state == fd_fork_active)
755 fork_active(index);
756 *fd = fdi->fd;
757 return fdi->type;
758
759 } else {
760 *fd = index;
761 return fd_normal;
762 }
763 }
764
connect(int socket,const struct sockaddr * addr,socklen_t addrlen)765 int connect(int socket, const struct sockaddr *addr, socklen_t addrlen)
766 {
767 int fd, ret;
768
769 if (fd_get(socket, &fd) == fd_rsocket) {
770 ret = rconnect(fd, addr, addrlen);
771 if (!ret || errno == EINPROGRESS)
772 return ret;
773
774 ret = transpose_socket(socket, fd_normal);
775 if (ret < 0)
776 return ret;
777
778 rclose(fd);
779 fd = ret;
780 } else if (fd_gets(socket) == fd_fork) {
781 fd_store(socket, fd, fd_normal, fd_fork_active);
782 }
783
784 return real.connect(fd, addr, addrlen);
785 }
786
recv(int socket,void * buf,size_t len,int flags)787 ssize_t recv(int socket, void *buf, size_t len, int flags)
788 {
789 int fd;
790 return (fd_fork_get(socket, &fd) == fd_rsocket) ?
791 rrecv(fd, buf, len, flags) : real.recv(fd, buf, len, flags);
792 }
793
recvfrom(int socket,void * buf,size_t len,int flags,struct sockaddr * src_addr,socklen_t * addrlen)794 ssize_t recvfrom(int socket, void *buf, size_t len, int flags,
795 struct sockaddr *src_addr, socklen_t *addrlen)
796 {
797 int fd;
798 return (fd_fork_get(socket, &fd) == fd_rsocket) ?
799 rrecvfrom(fd, buf, len, flags, src_addr, addrlen) :
800 real.recvfrom(fd, buf, len, flags, src_addr, addrlen);
801 }
802
recvmsg(int socket,struct msghdr * msg,int flags)803 ssize_t recvmsg(int socket, struct msghdr *msg, int flags)
804 {
805 int fd;
806 return (fd_fork_get(socket, &fd) == fd_rsocket) ?
807 rrecvmsg(fd, msg, flags) : real.recvmsg(fd, msg, flags);
808 }
809
read(int socket,void * buf,size_t count)810 ssize_t read(int socket, void *buf, size_t count)
811 {
812 int fd;
813 init_preload();
814 return (fd_fork_get(socket, &fd) == fd_rsocket) ?
815 rread(fd, buf, count) : real.read(fd, buf, count);
816 }
817
readv(int socket,const struct iovec * iov,int iovcnt)818 ssize_t readv(int socket, const struct iovec *iov, int iovcnt)
819 {
820 int fd;
821 init_preload();
822 return (fd_fork_get(socket, &fd) == fd_rsocket) ?
823 rreadv(fd, iov, iovcnt) : real.readv(fd, iov, iovcnt);
824 }
825
send(int socket,const void * buf,size_t len,int flags)826 ssize_t send(int socket, const void *buf, size_t len, int flags)
827 {
828 int fd;
829 return (fd_fork_get(socket, &fd) == fd_rsocket) ?
830 rsend(fd, buf, len, flags) : real.send(fd, buf, len, flags);
831 }
832
sendto(int socket,const void * buf,size_t len,int flags,const struct sockaddr * dest_addr,socklen_t addrlen)833 ssize_t sendto(int socket, const void *buf, size_t len, int flags,
834 const struct sockaddr *dest_addr, socklen_t addrlen)
835 {
836 int fd;
837 return (fd_fork_get(socket, &fd) == fd_rsocket) ?
838 rsendto(fd, buf, len, flags, dest_addr, addrlen) :
839 real.sendto(fd, buf, len, flags, dest_addr, addrlen);
840 }
841
sendmsg(int socket,const struct msghdr * msg,int flags)842 ssize_t sendmsg(int socket, const struct msghdr *msg, int flags)
843 {
844 int fd;
845 return (fd_fork_get(socket, &fd) == fd_rsocket) ?
846 rsendmsg(fd, msg, flags) : real.sendmsg(fd, msg, flags);
847 }
848
write(int socket,const void * buf,size_t count)849 ssize_t write(int socket, const void *buf, size_t count)
850 {
851 int fd;
852 init_preload();
853 return (fd_fork_get(socket, &fd) == fd_rsocket) ?
854 rwrite(fd, buf, count) : real.write(fd, buf, count);
855 }
856
writev(int socket,const struct iovec * iov,int iovcnt)857 ssize_t writev(int socket, const struct iovec *iov, int iovcnt)
858 {
859 int fd;
860 init_preload();
861 return (fd_fork_get(socket, &fd) == fd_rsocket) ?
862 rwritev(fd, iov, iovcnt) : real.writev(fd, iov, iovcnt);
863 }
864
/*
 * Return a per-thread scratch array with room for at least nfds pollfd
 * entries.  The array is reused between calls and only reallocated when it
 * has to grow; its previous contents are not preserved when it grows.
 *
 * A request for zero entries still yields a valid (one-entry) array, so a
 * NULL return always means that memory could not be obtained.
 */
static struct pollfd *fds_alloc(nfds_t nfds)
{
	static __thread struct pollfd *rfds;
	static __thread nfds_t rnfds;
	nfds_t want = nfds ? nfds : 1;

	/* Refuse sizes whose byte count would overflow size_t. */
	if (want > SIZE_MAX / sizeof(*rfds))
		return NULL;

	if (want > rnfds) {
		free(rfds);
		rfds = malloc(sizeof(*rfds) * want);
		rnfds = rfds ? want : 0;
	}

	return rfds;
}
880
poll(struct pollfd * fds,nfds_t nfds,int timeout)881 int poll(struct pollfd *fds, nfds_t nfds, int timeout)
882 {
883 struct pollfd *rfds;
884 int i, ret;
885
886 init_preload();
887 for (i = 0; i < nfds; i++) {
888 if (fd_gett(fds[i].fd) == fd_rsocket)
889 goto use_rpoll;
890 }
891
892 return real.poll(fds, nfds, timeout);
893
894 use_rpoll:
895 rfds = fds_alloc(nfds);
896 if (!rfds)
897 return ERR(ENOMEM);
898
899 for (i = 0; i < nfds; i++) {
900 rfds[i].fd = fd_getd(fds[i].fd);
901 rfds[i].events = fds[i].events;
902 rfds[i].revents = 0;
903 }
904
905 ret = rpoll(rfds, nfds, timeout);
906
907 for (i = 0; i < nfds; i++)
908 fds[i].revents = rfds[i].revents;
909
910 return ret;
911 }
912
/*
 * Build a pollfd array from select()-style sets.  Every descriptor below
 * *nfds that appears in any of the three sets gets one entry holding its
 * underlying descriptor; POLLIN/POLLOUT are requested for members of the
 * read/write sets.  On return *nfds holds the number of entries written.
 */
static void select_to_rpoll(struct pollfd *fds, int *nfds,
			    fd_set *readfds, fd_set *writefds, fd_set *exceptfds)
{
	int fd, events, used = 0;

	for (fd = 0; fd < *nfds; fd++) {
		events = 0;
		if (readfds && FD_ISSET(fd, readfds))
			events |= POLLIN;
		if (writefds && FD_ISSET(fd, writefds))
			events |= POLLOUT;

		/* Skip descriptors that are in none of the sets. */
		if (!events && !(exceptfds && FD_ISSET(fd, exceptfds)))
			continue;

		fds[used].fd = fd_getd(fd);
		fds[used].events = events;
		used++;
	}

	*nfds = used;
}
931
/*
 * Translate rpoll() results back into select()-style sets.  Application
 * descriptors are walked upward from 0 and paired, in order, with the
 * pollfd entries whose underlying descriptor they map to.
 *
 * Returns the number of bits set across the three sets.
 */
static int rpoll_to_select(struct pollfd *fds, int nfds,
			   fd_set *readfds, fd_set *writefds, fd_set *exceptfds)
{
	int fd = 0, i = 0, cnt = 0;
	short revents;

	while (i < nfds) {
		if (fd_getd(fd) == fds[i].fd) {
			revents = fds[i].revents;

			if (readfds && (revents & POLLIN)) {
				FD_SET(fd, readfds);
				cnt++;
			}

			if (writefds && (revents & POLLOUT)) {
				FD_SET(fd, writefds);
				cnt++;
			}

			/* Anything other than IN/OUT is reported as exceptional. */
			if (exceptfds && (revents & ~(POLLIN | POLLOUT))) {
				FD_SET(fd, exceptfds);
				cnt++;
			}

			i++;
		}
		fd++;
	}

	return cnt;
}
961
/*
 * Convert a select()-style timeout into milliseconds for rpoll().
 *
 * Returns -1 when timeout is NULL.  Sub-millisecond remainders are
 * truncated.  Values too large for an int are clamped to INT_MAX instead of
 * being truncated into an unrelated (possibly negative) timeout.
 */
static int rs_convert_timeout(struct timeval *timeout)
{
	long long msec;

	if (!timeout)
		return -1;

	/* Guard the multiplication below against overflow. */
	if (timeout->tv_sec > INT_MAX / 1000)
		return INT_MAX;

	msec = (long long) timeout->tv_sec * 1000 + timeout->tv_usec / 1000;
	return (msec > INT_MAX) ? INT_MAX : (int) msec;
}
966
/*
 * Interposed select(), implemented on top of rpoll().  The sets are turned
 * into a pollfd array, polled, cleared, and then refilled from the poll
 * results when at least one descriptor is ready.
 */
int select(int nfds, fd_set *readfds, fd_set *writefds,
	   fd_set *exceptfds, struct timeval *timeout)
{
	struct pollfd *pfds;
	int ret;

	pfds = fds_alloc(nfds);
	if (!pfds)
		return ERR(ENOMEM);

	select_to_rpoll(pfds, &nfds, readfds, writefds, exceptfds);
	ret = rpoll(pfds, nfds, rs_convert_timeout(timeout));

	if (readfds)
		FD_ZERO(readfds);
	if (writefds)
		FD_ZERO(writefds);
	if (exceptfds)
		FD_ZERO(exceptfds);

	if (ret <= 0)
		return ret;

	return rpoll_to_select(pfds, nfds, readfds, writefds, exceptfds);
}
992
shutdown(int socket,int how)993 int shutdown(int socket, int how)
994 {
995 int fd;
996 return (fd_get(socket, &fd) == fd_rsocket) ?
997 rshutdown(fd, how) : real.shutdown(fd, how);
998 }
999
close(int socket)1000 int close(int socket)
1001 {
1002 struct fd_info *fdi;
1003 int ret;
1004
1005 init_preload();
1006 fdi = idm_lookup(&idm, socket);
1007 if (!fdi)
1008 return real.close(socket);
1009
1010 if (fdi->dupfd != -1) {
1011 ret = close(fdi->dupfd);
1012 if (ret)
1013 return ret;
1014 }
1015
1016 if (atomic_fetch_sub(&fdi->refcnt, 1) != 1)
1017 return 0;
1018
1019 idm_clear(&idm, socket);
1020 real.close(socket);
1021 ret = (fdi->type == fd_rsocket) ? rclose(fdi->fd) : real.close(fdi->fd);
1022 free(fdi);
1023 return ret;
1024 }
1025
getpeername(int socket,struct sockaddr * addr,socklen_t * addrlen)1026 int getpeername(int socket, struct sockaddr *addr, socklen_t *addrlen)
1027 {
1028 int fd;
1029 return (fd_get(socket, &fd) == fd_rsocket) ?
1030 rgetpeername(fd, addr, addrlen) :
1031 real.getpeername(fd, addr, addrlen);
1032 }
1033
getsockname(int socket,struct sockaddr * addr,socklen_t * addrlen)1034 int getsockname(int socket, struct sockaddr *addr, socklen_t *addrlen)
1035 {
1036 int fd;
1037 init_preload();
1038 return (fd_get(socket, &fd) == fd_rsocket) ?
1039 rgetsockname(fd, addr, addrlen) :
1040 real.getsockname(fd, addr, addrlen);
1041 }
1042
setsockopt(int socket,int level,int optname,const void * optval,socklen_t optlen)1043 int setsockopt(int socket, int level, int optname,
1044 const void *optval, socklen_t optlen)
1045 {
1046 int fd;
1047 return (fd_get(socket, &fd) == fd_rsocket) ?
1048 rsetsockopt(fd, level, optname, optval, optlen) :
1049 real.setsockopt(fd, level, optname, optval, optlen);
1050 }
1051
getsockopt(int socket,int level,int optname,void * optval,socklen_t * optlen)1052 int getsockopt(int socket, int level, int optname,
1053 void *optval, socklen_t *optlen)
1054 {
1055 int fd;
1056 return (fd_get(socket, &fd) == fd_rsocket) ?
1057 rgetsockopt(fd, level, optname, optval, optlen) :
1058 real.getsockopt(fd, level, optname, optval, optlen);
1059 }
1060
fcntl(int socket,int cmd,...)1061 int fcntl(int socket, int cmd, ... /* arg */)
1062 {
1063 va_list args;
1064 long lparam;
1065 void *pparam;
1066 int fd, ret;
1067
1068 init_preload();
1069 va_start(args, cmd);
1070 switch (cmd) {
1071 case F_GETFD:
1072 case F_GETFL:
1073 case F_GETOWN:
1074 case F_GETSIG:
1075 case F_GETLEASE:
1076 ret = (fd_get(socket, &fd) == fd_rsocket) ?
1077 rfcntl(fd, cmd) : real.fcntl(fd, cmd);
1078 break;
1079 case F_DUPFD:
1080 /*case F_DUPFD_CLOEXEC:*/
1081 case F_SETFD:
1082 case F_SETFL:
1083 case F_SETOWN:
1084 case F_SETSIG:
1085 case F_SETLEASE:
1086 case F_NOTIFY:
1087 lparam = va_arg(args, long);
1088 ret = (fd_get(socket, &fd) == fd_rsocket) ?
1089 rfcntl(fd, cmd, lparam) : real.fcntl(fd, cmd, lparam);
1090 break;
1091 default:
1092 pparam = va_arg(args, void *);
1093 ret = (fd_get(socket, &fd) == fd_rsocket) ?
1094 rfcntl(fd, cmd, pparam) : real.fcntl(fd, cmd, pparam);
1095 break;
1096 }
1097 va_end(args);
1098 return ret;
1099 }
1100
1101 /*
1102 * dup2 is not thread safe
1103 */
dup2(int oldfd,int newfd)1104 int dup2(int oldfd, int newfd)
1105 {
1106 struct fd_info *oldfdi, *newfdi;
1107 int ret;
1108
1109 init_preload();
1110 oldfdi = idm_lookup(&idm, oldfd);
1111 if (oldfdi) {
1112 if (oldfdi->state == fd_fork_passive)
1113 fork_passive(oldfd);
1114 else if (oldfdi->state == fd_fork_active)
1115 fork_active(oldfd);
1116 }
1117
1118 newfdi = idm_lookup(&idm, newfd);
1119 if (newfdi) {
1120 /* newfd cannot have been dup'ed directly */
1121 if (atomic_load(&newfdi->refcnt) > 1)
1122 return ERR(EBUSY);
1123 close(newfd);
1124 }
1125
1126 ret = real.dup2(oldfd, newfd);
1127 if (!oldfdi || ret != newfd)
1128 return ret;
1129
1130 newfdi = calloc(1, sizeof(*newfdi));
1131 if (!newfdi) {
1132 close(newfd);
1133 return ERR(ENOMEM);
1134 }
1135
1136 pthread_mutex_lock(&mut);
1137 idm_set(&idm, newfd, newfdi);
1138 pthread_mutex_unlock(&mut);
1139
1140 newfdi->fd = oldfdi->fd;
1141 newfdi->type = oldfdi->type;
1142 if (oldfdi->dupfd != -1) {
1143 newfdi->dupfd = oldfdi->dupfd;
1144 oldfdi = idm_lookup(&idm, oldfdi->dupfd);
1145 } else {
1146 newfdi->dupfd = oldfd;
1147 }
1148 atomic_store(&newfdi->refcnt, 1);
1149 atomic_fetch_add(&oldfdi->refcnt, 1);
1150 return newfd;
1151 }
1152
sendfile(int out_fd,int in_fd,off_t * offset,size_t count)1153 ssize_t sendfile(int out_fd, int in_fd, off_t *offset, size_t count)
1154 {
1155 void *file_addr;
1156 int fd;
1157 size_t ret;
1158
1159 if (fd_get(out_fd, &fd) != fd_rsocket)
1160 return real.sendfile(fd, in_fd, offset, count);
1161
1162 file_addr = mmap(NULL, count, PROT_READ, 0, in_fd, offset ? *offset : 0);
1163 if (file_addr == (void *) -1)
1164 return -1;
1165
1166 ret = rwrite(fd, file_addr, count);
1167 if ((ret > 0) && offset)
1168 lseek(in_fd, ret, SEEK_CUR);
1169 munmap(file_addr, count);
1170 return ret;
1171 }
1172
__fxstat(int ver,int socket,struct stat * buf)1173 int __fxstat(int ver, int socket, struct stat *buf)
1174 {
1175 int fd, ret;
1176
1177 init_preload();
1178 if (fd_get(socket, &fd) == fd_rsocket) {
1179 ret = real.fxstat(ver, socket, buf);
1180 if (!ret)
1181 buf->st_mode = (buf->st_mode & ~S_IFMT) | __S_IFSOCK;
1182 } else {
1183 ret = real.fxstat(ver, fd, buf);
1184 }
1185 return ret;
1186 }
1187