xref: /freebsd/contrib/ofed/librdmacm/preload.c (revision c66ec88fed842fbaad62c30d510644ceb7bd2d71)
1 /*
2  * Copyright (c) 2011-2012 Intel Corporation.  All rights reserved.
3  *
4  * This software is available to you under a choice of one of two
5  * licenses.  You may choose to be licensed under the terms of the GNU
6  * General Public License (GPL) Version 2, available from the file
7  * COPYING in the main directory of this source tree, or the
8  * OpenIB.org BSD license below:
9  *
10  *     Redistribution and use in source and binary forms, with or
11  *     without modification, are permitted provided that the following
12  *     conditions are met:
13  *
14  *      - Redistributions of source code must retain the above
15  *        copyright notice, this list of conditions and the following
16  *        disclaimer.
17  *
18  *      - Redistributions in binary form must reproduce the above
19  *        copyright notice, this list of conditions and the following
20  *        disclaimer in the documentation and/or other materials
21  *        provided with the distribution.
22  *
23  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
24  * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
25  * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
26  * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
27  * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
28  * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
29  * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
30  * SOFTWARE.
31  *
32  */
33 #define _GNU_SOURCE
34 #include <config.h>
35 
36 #include <sys/types.h>
37 #include <sys/socket.h>
38 #include <sys/uio.h>
39 #include <sys/stat.h>
40 #include <sys/mman.h>
41 #include <stdarg.h>
42 #include <dlfcn.h>
43 #include <netdb.h>
44 #include <unistd.h>
45 #include <fcntl.h>
46 #include <string.h>
47 #include <netinet/tcp.h>
48 #include <unistd.h>
49 #include <semaphore.h>
50 #include <ctype.h>
51 #include <stdlib.h>
52 #include <stdio.h>
53 
54 #include <rdma/rdma_cma.h>
55 #include <rdma/rdma_verbs.h>
56 #include <rdma/rsocket.h>
57 #include "cma.h"
58 #include "indexer.h"
59 
/*
 * Table of socket-API entry points.  One instance holds the next-in-chain
 * (libc) symbols and another holds the rsocket equivalents, so helpers can
 * select an implementation through a pointer to this table.
 */
struct socket_calls {
	/* connection setup */
	int (*socket)(int domain, int type, int protocol);
	int (*bind)(int fd, const struct sockaddr *addr, socklen_t addrlen);
	int (*listen)(int fd, int backlog);
	int (*accept)(int fd, struct sockaddr *addr, socklen_t *addrlen);
	int (*connect)(int fd, const struct sockaddr *addr, socklen_t addrlen);

	/* receive side */
	ssize_t (*recv)(int fd, void *buf, size_t len, int flags);
	ssize_t (*recvfrom)(int fd, void *buf, size_t len, int flags,
			    struct sockaddr *src_addr, socklen_t *addrlen);
	ssize_t (*recvmsg)(int fd, struct msghdr *msg, int flags);
	ssize_t (*read)(int fd, void *buf, size_t count);
	ssize_t (*readv)(int fd, const struct iovec *iov, int iovcnt);

	/* send side */
	ssize_t (*send)(int fd, const void *buf, size_t len, int flags);
	ssize_t (*sendto)(int fd, const void *buf, size_t len, int flags,
			  const struct sockaddr *dest_addr, socklen_t addrlen);
	ssize_t (*sendmsg)(int fd, const struct msghdr *msg, int flags);
	ssize_t (*write)(int fd, const void *buf, size_t count);
	ssize_t (*writev)(int fd, const struct iovec *iov, int iovcnt);

	/* readiness and teardown */
	int (*poll)(struct pollfd *fds, nfds_t nfds, int timeout);
	int (*shutdown)(int fd, int how);
	int (*close)(int fd);

	/* addresses and options */
	int (*getpeername)(int fd, struct sockaddr *addr, socklen_t *addrlen);
	int (*getsockname)(int fd, struct sockaddr *addr, socklen_t *addrlen);
	int (*setsockopt)(int fd, int level, int optname,
			  const void *optval, socklen_t optlen);
	int (*getsockopt)(int fd, int level, int optname,
			  void *optval, socklen_t *optlen);

	/* descriptor-level operations */
	int (*fcntl)(int fd, int cmd, ... /* arg */);
	int (*dup2)(int oldfd, int newfd);
	ssize_t (*sendfile)(int out_fd, int in_fd, off_t *offset, size_t count);
	int (*fxstat)(int ver, int fd, struct stat *buf);
};
92 
93 static struct socket_calls real;
94 static struct socket_calls rs;
95 
96 static struct index_map idm;
97 static pthread_mutex_t mut = PTHREAD_MUTEX_INITIALIZER;
98 
99 static int sq_size;
100 static int rq_size;
101 static int sq_inline;
102 static int fork_support;
103 
/* Which API currently backs a tracked descriptor. */
enum fd_type {
	fd_normal = 0,	/* handled through the real (libc) calls */
	fd_rsocket	/* handled through the rsocket calls */
};
108 
/* Progress of the delayed (fork-support) switch from a normal socket to an rsocket. */
enum fd_fork_state {
	fd_ready = 0,		/* nothing pending */
	fd_fork,		/* created by socket() in fork-support mode */
	fd_fork_listen,		/* listen() succeeded on a fd_fork socket */
	fd_fork_active,		/* connect() was called on a fd_fork socket */
	fd_fork_passive		/* returned by accept() on a fd_fork_listen socket */
};
116 
117 struct fd_info {
118 	enum fd_type type;
119 	enum fd_fork_state state;
120 	int fd;
121 	int dupfd;
122 	_Atomic(int) refcnt;
123 };
124 
/*
 * One rule from the preload configuration file.  A zero/NULL field acts as
 * a wildcard when the rule is matched in intercept_socket().
 */
struct config_entry {
	char *name;	/* program name prefix (heap-allocated), or NULL */
	int domain;	/* address family, or 0 */
	int type;	/* socket type, or 0 */
	int protocol;	/* IP protocol, or 0 */
};

/* Rule array built by scan_config() and the number of valid entries in it. */
static struct config_entry *config;
static int config_cnt;
134 
135 static void free_config(void)
136 {
137 	while (config_cnt)
138 		free(config[--config_cnt].name);
139 
140 	free(config);
141 }
142 
143 /*
144  * Config file format:
145  * # Starting '#' indicates comment
146  * # wild card values are supported using '*'
147  * # domain - *, INET, INET6, IB
148  * # type - *, STREAM, DGRAM
149  * # protocol - *, TCP, UDP
150  * program_name domain type protocol
151  */
152 static void scan_config(void)
153 {
154 	struct config_entry *new_config;
155 	FILE *fp;
156 	char line[120], prog[64], dom[16], type[16], proto[16];
157 
158 	fp = fopen(RS_CONF_DIR "/preload_config", "r");
159 	if (!fp)
160 		return;
161 
162 	while (fgets(line, sizeof(line), fp)) {
163 		if (line[0] == '#')
164 			continue;
165 
166 		if (sscanf(line, "%64s%16s%16s%16s", prog, dom, type, proto) != 4)
167 			continue;
168 
169 		new_config = realloc(config, (config_cnt + 1) *
170 					     sizeof(struct config_entry));
171 		if (!new_config)
172 			break;
173 
174 		config = new_config;
175 		memset(&config[config_cnt], 0, sizeof(struct config_entry));
176 
177 		if (!strcasecmp(dom, "INET") ||
178 		    !strcasecmp(dom, "AF_INET") ||
179 		    !strcasecmp(dom, "PF_INET")) {
180 			config[config_cnt].domain = AF_INET;
181 		} else if (!strcasecmp(dom, "INET6") ||
182 			   !strcasecmp(dom, "AF_INET6") ||
183 			   !strcasecmp(dom, "PF_INET6")) {
184 			config[config_cnt].domain = AF_INET6;
185 		} else if (!strcasecmp(dom, "IB") ||
186 			   !strcasecmp(dom, "AF_IB") ||
187 			   !strcasecmp(dom, "PF_IB")) {
188 			config[config_cnt].domain = AF_IB;
189 		} else if (strcmp(dom, "*")) {
190 			continue;
191 		}
192 
193 		if (!strcasecmp(type, "STREAM") ||
194 		    !strcasecmp(type, "SOCK_STREAM")) {
195 			config[config_cnt].type = SOCK_STREAM;
196 		} else if (!strcasecmp(type, "DGRAM") ||
197 			   !strcasecmp(type, "SOCK_DGRAM")) {
198 			config[config_cnt].type = SOCK_DGRAM;
199 		} else if (strcmp(type, "*")) {
200 			continue;
201 		}
202 
203 		if (!strcasecmp(proto, "TCP") ||
204 		    !strcasecmp(proto, "IPPROTO_TCP")) {
205 			config[config_cnt].protocol = IPPROTO_TCP;
206 		} else if (!strcasecmp(proto, "UDP") ||
207 			   !strcasecmp(proto, "IPPROTO_UDP")) {
208 			config[config_cnt].protocol = IPPROTO_UDP;
209 		} else if (strcmp(proto, "*")) {
210 			continue;
211 		}
212 
213 		if (strcmp(prog, "*")) {
214 		    if (!(config[config_cnt].name = strdup(prog)))
215 			    continue;
216 		}
217 
218 		config_cnt++;
219 	}
220 
221 	fclose(fp);
222 	if (config_cnt)
223 		atexit(free_config);
224 }
225 
226 static int intercept_socket(int domain, int type, int protocol)
227 {
228 	int i;
229 
230 	if (!config_cnt)
231 		return 1;
232 
233 	if (!protocol) {
234 		if (type == SOCK_STREAM)
235 			protocol = IPPROTO_TCP;
236 		else if (type == SOCK_DGRAM)
237 			protocol = IPPROTO_UDP;
238 	}
239 
240 	for (i = 0; i < config_cnt; i++) {
241 		if ((!config[i].name ||
242 		     !strncasecmp(config[i].name, program_invocation_short_name,
243 				  strlen(config[i].name))) &&
244 		    (!config[i].domain || config[i].domain == domain) &&
245 		    (!config[i].type || config[i].type == type) &&
246 		    (!config[i].protocol || config[i].protocol == protocol))
247 			return 1;
248 	}
249 
250 	return 0;
251 }
252 
253 static int fd_open(void)
254 {
255 	struct fd_info *fdi;
256 	int ret, index;
257 
258 	fdi = calloc(1, sizeof(*fdi));
259 	if (!fdi)
260 		return ERR(ENOMEM);
261 
262 	index = open("/dev/null", O_RDONLY);
263 	if (index < 0) {
264 		ret = index;
265 		goto err1;
266 	}
267 
268 	fdi->dupfd = -1;
269 	atomic_store(&fdi->refcnt, 1);
270 	pthread_mutex_lock(&mut);
271 	ret = idm_set(&idm, index, fdi);
272 	pthread_mutex_unlock(&mut);
273 	if (ret < 0)
274 		goto err2;
275 
276 	return index;
277 
278 err2:
279 	real.close(index);
280 err1:
281 	free(fdi);
282 	return ret;
283 }
284 
285 static void fd_store(int index, int fd, enum fd_type type, enum fd_fork_state state)
286 {
287 	struct fd_info *fdi;
288 
289 	fdi = idm_at(&idm, index);
290 	fdi->fd = fd;
291 	fdi->type = type;
292 	fdi->state = state;
293 }
294 
295 static inline enum fd_type fd_get(int index, int *fd)
296 {
297 	struct fd_info *fdi;
298 
299 	fdi = idm_lookup(&idm, index);
300 	if (fdi) {
301 		*fd = fdi->fd;
302 		return fdi->type;
303 
304 	} else {
305 		*fd = index;
306 		return fd_normal;
307 	}
308 }
309 
310 static inline int fd_getd(int index)
311 {
312 	struct fd_info *fdi;
313 
314 	fdi = idm_lookup(&idm, index);
315 	return fdi ? fdi->fd : index;
316 }
317 
318 static inline enum fd_fork_state fd_gets(int index)
319 {
320 	struct fd_info *fdi;
321 
322 	fdi = idm_lookup(&idm, index);
323 	return fdi ? fdi->state : fd_ready;
324 }
325 
326 static inline enum fd_type fd_gett(int index)
327 {
328 	struct fd_info *fdi;
329 
330 	fdi = idm_lookup(&idm, index);
331 	return fdi ? fdi->type : fd_normal;
332 }
333 
334 static enum fd_type fd_close(int index, int *fd)
335 {
336 	struct fd_info *fdi;
337 	enum fd_type type;
338 
339 	fdi = idm_lookup(&idm, index);
340 	if (fdi) {
341 		idm_clear(&idm, index);
342 		*fd = fdi->fd;
343 		type = fdi->type;
344 		real.close(index);
345 		free(fdi);
346 	} else {
347 		*fd = index;
348 		type = fd_normal;
349 	}
350 	return type;
351 }
352 
353 static void getenv_options(void)
354 {
355 	char *var;
356 
357 	var = getenv("RS_SQ_SIZE");
358 	if (var)
359 		sq_size = atoi(var);
360 
361 	var = getenv("RS_RQ_SIZE");
362 	if (var)
363 		rq_size = atoi(var);
364 
365 	var = getenv("RS_INLINE");
366 	if (var)
367 		sq_inline = atoi(var);
368 
369 	var = getenv("RDMAV_FORK_SAFE");
370 	if (var)
371 		fork_support = atoi(var);
372 }
373 
374 static void init_preload(void)
375 {
376 	static int init;
377 
378 	/* Quick check without lock */
379 	if (init)
380 		return;
381 
382 	pthread_mutex_lock(&mut);
383 	if (init)
384 		goto out;
385 
386 	real.socket = dlsym(RTLD_NEXT, "socket");
387 	real.bind = dlsym(RTLD_NEXT, "bind");
388 	real.listen = dlsym(RTLD_NEXT, "listen");
389 	real.accept = dlsym(RTLD_NEXT, "accept");
390 	real.connect = dlsym(RTLD_NEXT, "connect");
391 	real.recv = dlsym(RTLD_NEXT, "recv");
392 	real.recvfrom = dlsym(RTLD_NEXT, "recvfrom");
393 	real.recvmsg = dlsym(RTLD_NEXT, "recvmsg");
394 	real.read = dlsym(RTLD_NEXT, "read");
395 	real.readv = dlsym(RTLD_NEXT, "readv");
396 	real.send = dlsym(RTLD_NEXT, "send");
397 	real.sendto = dlsym(RTLD_NEXT, "sendto");
398 	real.sendmsg = dlsym(RTLD_NEXT, "sendmsg");
399 	real.write = dlsym(RTLD_NEXT, "write");
400 	real.writev = dlsym(RTLD_NEXT, "writev");
401 	real.poll = dlsym(RTLD_NEXT, "poll");
402 	real.shutdown = dlsym(RTLD_NEXT, "shutdown");
403 	real.close = dlsym(RTLD_NEXT, "close");
404 	real.getpeername = dlsym(RTLD_NEXT, "getpeername");
405 	real.getsockname = dlsym(RTLD_NEXT, "getsockname");
406 	real.setsockopt = dlsym(RTLD_NEXT, "setsockopt");
407 	real.getsockopt = dlsym(RTLD_NEXT, "getsockopt");
408 	real.fcntl = dlsym(RTLD_NEXT, "fcntl");
409 	real.dup2 = dlsym(RTLD_NEXT, "dup2");
410 	real.sendfile = dlsym(RTLD_NEXT, "sendfile");
411 	real.fxstat = dlsym(RTLD_NEXT, "__fxstat");
412 
413 	rs.socket = dlsym(RTLD_DEFAULT, "rsocket");
414 	rs.bind = dlsym(RTLD_DEFAULT, "rbind");
415 	rs.listen = dlsym(RTLD_DEFAULT, "rlisten");
416 	rs.accept = dlsym(RTLD_DEFAULT, "raccept");
417 	rs.connect = dlsym(RTLD_DEFAULT, "rconnect");
418 	rs.recv = dlsym(RTLD_DEFAULT, "rrecv");
419 	rs.recvfrom = dlsym(RTLD_DEFAULT, "rrecvfrom");
420 	rs.recvmsg = dlsym(RTLD_DEFAULT, "rrecvmsg");
421 	rs.read = dlsym(RTLD_DEFAULT, "rread");
422 	rs.readv = dlsym(RTLD_DEFAULT, "rreadv");
423 	rs.send = dlsym(RTLD_DEFAULT, "rsend");
424 	rs.sendto = dlsym(RTLD_DEFAULT, "rsendto");
425 	rs.sendmsg = dlsym(RTLD_DEFAULT, "rsendmsg");
426 	rs.write = dlsym(RTLD_DEFAULT, "rwrite");
427 	rs.writev = dlsym(RTLD_DEFAULT, "rwritev");
428 	rs.poll = dlsym(RTLD_DEFAULT, "rpoll");
429 	rs.shutdown = dlsym(RTLD_DEFAULT, "rshutdown");
430 	rs.close = dlsym(RTLD_DEFAULT, "rclose");
431 	rs.getpeername = dlsym(RTLD_DEFAULT, "rgetpeername");
432 	rs.getsockname = dlsym(RTLD_DEFAULT, "rgetsockname");
433 	rs.setsockopt = dlsym(RTLD_DEFAULT, "rsetsockopt");
434 	rs.getsockopt = dlsym(RTLD_DEFAULT, "rgetsockopt");
435 	rs.fcntl = dlsym(RTLD_DEFAULT, "rfcntl");
436 
437 	getenv_options();
438 	scan_config();
439 	init = 1;
440 out:
441 	pthread_mutex_unlock(&mut);
442 }
443 
444 /*
445  * We currently only handle copying a few common values.
446  */
447 static int copysockopts(int dfd, int sfd, struct socket_calls *dapi,
448 			struct socket_calls *sapi)
449 {
450 	socklen_t len;
451 	int param, ret;
452 
453 	ret = sapi->fcntl(sfd, F_GETFL);
454 	if (ret > 0)
455 		ret = dapi->fcntl(dfd, F_SETFL, ret);
456 	if (ret)
457 		return ret;
458 
459 	len = sizeof param;
460 	ret = sapi->getsockopt(sfd, SOL_SOCKET, SO_REUSEADDR, &param, &len);
461 	if (param && !ret)
462 		ret = dapi->setsockopt(dfd, SOL_SOCKET, SO_REUSEADDR, &param, len);
463 	if (ret)
464 		return ret;
465 
466 	len = sizeof param;
467 	ret = sapi->getsockopt(sfd, IPPROTO_TCP, TCP_NODELAY, &param, &len);
468 	if (param && !ret)
469 		ret = dapi->setsockopt(dfd, IPPROTO_TCP, TCP_NODELAY, &param, len);
470 	if (ret)
471 		return ret;
472 
473 	return 0;
474 }
475 
476 /*
477  * Convert between an rsocket and a normal socket.
478  */
479 static int transpose_socket(int socket, enum fd_type new_type)
480 {
481 	socklen_t len = 0;
482 	int sfd, dfd, param, ret;
483 	struct socket_calls *sapi, *dapi;
484 
485 	sfd = fd_getd(socket);
486 	if (new_type == fd_rsocket) {
487 		dapi = &rs;
488 		sapi = &real;
489 	} else {
490 		dapi = &real;
491 		sapi = &rs;
492 	}
493 
494 	ret = sapi->getsockname(sfd, NULL, &len);
495 	if (ret)
496 		return ret;
497 
498 	param = (len == sizeof(struct sockaddr_in6)) ? PF_INET6 : PF_INET;
499 	dfd = dapi->socket(param, SOCK_STREAM, 0);
500 	if (dfd < 0)
501 		return dfd;
502 
503 	ret = copysockopts(dfd, sfd, dapi, sapi);
504 	if (ret)
505 		goto err;
506 
507 	fd_store(socket, dfd, new_type, fd_ready);
508 	return dfd;
509 
510 err:
511 	dapi->close(dfd);
512 	return ret;
513 }
514 
515 /*
516  * Use defaults on failure.
517  */
518 static void set_rsocket_options(int rsocket)
519 {
520 	if (sq_size)
521 		rsetsockopt(rsocket, SOL_RDMA, RDMA_SQSIZE, &sq_size, sizeof sq_size);
522 
523 	if (rq_size)
524 		rsetsockopt(rsocket, SOL_RDMA, RDMA_RQSIZE, &rq_size, sizeof rq_size);
525 
526 	if (sq_inline)
527 		rsetsockopt(rsocket, SOL_RDMA, RDMA_INLINE, &sq_inline, sizeof sq_inline);
528 }
529 
530 int socket(int domain, int type, int protocol)
531 {
532 	static __thread int recursive;
533 	int index, ret;
534 
535 	init_preload();
536 
537 	if (recursive || !intercept_socket(domain, type, protocol))
538 		goto real;
539 
540 	index = fd_open();
541 	if (index < 0)
542 		return index;
543 
544 	if (fork_support && (domain == PF_INET || domain == PF_INET6) &&
545 	    (type == SOCK_STREAM) && (!protocol || protocol == IPPROTO_TCP)) {
546 		ret = real.socket(domain, type, protocol);
547 		if (ret < 0)
548 			return ret;
549 		fd_store(index, ret, fd_normal, fd_fork);
550 		return index;
551 	}
552 
553 	recursive = 1;
554 	ret = rsocket(domain, type, protocol);
555 	recursive = 0;
556 	if (ret >= 0) {
557 		fd_store(index, ret, fd_rsocket, fd_ready);
558 		set_rsocket_options(ret);
559 		return index;
560 	}
561 	fd_close(index, &ret);
562 real:
563 	return real.socket(domain, type, protocol);
564 }
565 
566 int bind(int socket, const struct sockaddr *addr, socklen_t addrlen)
567 {
568 	int fd;
569 	return (fd_get(socket, &fd) == fd_rsocket) ?
570 		rbind(fd, addr, addrlen) : real.bind(fd, addr, addrlen);
571 }
572 
573 int listen(int socket, int backlog)
574 {
575 	int fd, ret;
576 	if (fd_get(socket, &fd) == fd_rsocket) {
577 		ret = rlisten(fd, backlog);
578 	} else {
579 		ret = real.listen(fd, backlog);
580 		if (!ret && fd_gets(socket) == fd_fork)
581 			fd_store(socket, fd, fd_normal, fd_fork_listen);
582 	}
583 	return ret;
584 }
585 
586 int accept(int socket, struct sockaddr *addr, socklen_t *addrlen)
587 {
588 	int fd, index, ret;
589 
590 	if (fd_get(socket, &fd) == fd_rsocket) {
591 		index = fd_open();
592 		if (index < 0)
593 			return index;
594 
595 		ret = raccept(fd, addr, addrlen);
596 		if (ret < 0) {
597 			fd_close(index, &fd);
598 			return ret;
599 		}
600 
601 		fd_store(index, ret, fd_rsocket, fd_ready);
602 		return index;
603 	} else if (fd_gets(socket) == fd_fork_listen) {
604 		index = fd_open();
605 		if (index < 0)
606 			return index;
607 
608 		ret = real.accept(fd, addr, addrlen);
609 		if (ret < 0) {
610 			fd_close(index, &fd);
611 			return ret;
612 		}
613 
614 		fd_store(index, ret, fd_normal, fd_fork_passive);
615 		return index;
616 	} else {
617 		return real.accept(fd, addr, addrlen);
618 	}
619 }
620 
621 /*
622  * We can't fork RDMA connections and pass them from the parent to the child
623  * process.  Instead, we need to establish the RDMA connection after calling
624  * fork.  To do this, we delay establishing the RDMA connection until we try
625  * to send/receive on the server side.
626  */
627 static void fork_active(int socket)
628 {
629 	struct sockaddr_storage addr;
630 	int sfd, dfd, ret;
631 	socklen_t len;
632 	uint32_t msg;
633 	long flags;
634 
635 	sfd = fd_getd(socket);
636 
637 	flags = real.fcntl(sfd, F_GETFL);
638 	real.fcntl(sfd, F_SETFL, 0);
639 	ret = real.recv(sfd, &msg, sizeof msg, MSG_PEEK);
640 	real.fcntl(sfd, F_SETFL, flags);
641 	if ((ret != sizeof msg) || msg)
642 		goto err1;
643 
644 	len = sizeof addr;
645 	ret = real.getpeername(sfd, (struct sockaddr *) &addr, &len);
646 	if (ret)
647 		goto err1;
648 
649 	dfd = rsocket(addr.ss_family, SOCK_STREAM, 0);
650 	if (dfd < 0)
651 		goto err1;
652 
653 	ret = rconnect(dfd, (struct sockaddr *) &addr, len);
654 	if (ret)
655 		goto err2;
656 
657 	set_rsocket_options(dfd);
658 	copysockopts(dfd, sfd, &rs, &real);
659 	real.shutdown(sfd, SHUT_RDWR);
660 	real.close(sfd);
661 	fd_store(socket, dfd, fd_rsocket, fd_ready);
662 	return;
663 
664 err2:
665 	rclose(dfd);
666 err1:
667 	fd_store(socket, sfd, fd_normal, fd_ready);
668 }
669 
670 /*
671  * The server will start listening for the new connection, then send a
672  * message to the active side when the listen is ready.  This does leave
673  * fork unsupported in the following case: the server is nonblocking and
674  * calls select/poll waiting to receive data from the client.
675  */
676 static void fork_passive(int socket)
677 {
678 	struct sockaddr_in6 sin6;
679 	sem_t *sem;
680 	int lfd, sfd, dfd, ret, param;
681 	socklen_t len;
682 	uint32_t msg;
683 
684 	sfd = fd_getd(socket);
685 
686 	len = sizeof sin6;
687 	ret = real.getsockname(sfd, (struct sockaddr *) &sin6, &len);
688 	if (ret)
689 		goto out;
690 	sin6.sin6_flowinfo = 0;
691 	sin6.sin6_scope_id = 0;
692 	memset(&sin6.sin6_addr, 0, sizeof sin6.sin6_addr);
693 
694 	sem = sem_open("/rsocket_fork", O_CREAT | O_RDWR,
695 		       S_IRWXU | S_IRWXG, 1);
696 	if (sem == SEM_FAILED) {
697 		ret = -1;
698 		goto out;
699 	}
700 
701 	lfd = rsocket(sin6.sin6_family, SOCK_STREAM, 0);
702 	if (lfd < 0) {
703 		ret = lfd;
704 		goto sclose;
705 	}
706 
707 	param = 1;
708 	rsetsockopt(lfd, SOL_SOCKET, SO_REUSEADDR, &param, sizeof param);
709 
710 	sem_wait(sem);
711 	ret = rbind(lfd, (struct sockaddr *) &sin6, sizeof sin6);
712 	if (ret)
713 		goto lclose;
714 
715 	ret = rlisten(lfd, 1);
716 	if (ret)
717 		goto lclose;
718 
719 	msg = 0;
720 	len = real.write(sfd, &msg, sizeof msg);
721 	if (len != sizeof msg)
722 		goto lclose;
723 
724 	dfd = raccept(lfd, NULL, NULL);
725 	if (dfd < 0) {
726 		ret  = dfd;
727 		goto lclose;
728 	}
729 
730 	set_rsocket_options(dfd);
731 	copysockopts(dfd, sfd, &rs, &real);
732 	real.shutdown(sfd, SHUT_RDWR);
733 	real.close(sfd);
734 	fd_store(socket, dfd, fd_rsocket, fd_ready);
735 
736 lclose:
737 	rclose(lfd);
738 	sem_post(sem);
739 sclose:
740 	sem_close(sem);
741 out:
742 	if (ret)
743 		fd_store(socket, sfd, fd_normal, fd_ready);
744 }
745 
746 static inline enum fd_type fd_fork_get(int index, int *fd)
747 {
748 	struct fd_info *fdi;
749 
750 	fdi = idm_lookup(&idm, index);
751 	if (fdi) {
752 		if (fdi->state == fd_fork_passive)
753 			fork_passive(index);
754 		else if (fdi->state == fd_fork_active)
755 			fork_active(index);
756 		*fd = fdi->fd;
757 		return fdi->type;
758 
759 	} else {
760 		*fd = index;
761 		return fd_normal;
762 	}
763 }
764 
765 int connect(int socket, const struct sockaddr *addr, socklen_t addrlen)
766 {
767 	int fd, ret;
768 
769 	if (fd_get(socket, &fd) == fd_rsocket) {
770 		ret = rconnect(fd, addr, addrlen);
771 		if (!ret || errno == EINPROGRESS)
772 			return ret;
773 
774 		ret = transpose_socket(socket, fd_normal);
775 		if (ret < 0)
776 			return ret;
777 
778 		rclose(fd);
779 		fd = ret;
780 	} else if (fd_gets(socket) == fd_fork) {
781 		fd_store(socket, fd, fd_normal, fd_fork_active);
782 	}
783 
784 	return real.connect(fd, addr, addrlen);
785 }
786 
787 ssize_t recv(int socket, void *buf, size_t len, int flags)
788 {
789 	int fd;
790 	return (fd_fork_get(socket, &fd) == fd_rsocket) ?
791 		rrecv(fd, buf, len, flags) : real.recv(fd, buf, len, flags);
792 }
793 
794 ssize_t recvfrom(int socket, void *buf, size_t len, int flags,
795 		 struct sockaddr *src_addr, socklen_t *addrlen)
796 {
797 	int fd;
798 	return (fd_fork_get(socket, &fd) == fd_rsocket) ?
799 		rrecvfrom(fd, buf, len, flags, src_addr, addrlen) :
800 		real.recvfrom(fd, buf, len, flags, src_addr, addrlen);
801 }
802 
803 ssize_t recvmsg(int socket, struct msghdr *msg, int flags)
804 {
805 	int fd;
806 	return (fd_fork_get(socket, &fd) == fd_rsocket) ?
807 		rrecvmsg(fd, msg, flags) : real.recvmsg(fd, msg, flags);
808 }
809 
810 ssize_t read(int socket, void *buf, size_t count)
811 {
812 	int fd;
813 	init_preload();
814 	return (fd_fork_get(socket, &fd) == fd_rsocket) ?
815 		rread(fd, buf, count) : real.read(fd, buf, count);
816 }
817 
818 ssize_t readv(int socket, const struct iovec *iov, int iovcnt)
819 {
820 	int fd;
821 	init_preload();
822 	return (fd_fork_get(socket, &fd) == fd_rsocket) ?
823 		rreadv(fd, iov, iovcnt) : real.readv(fd, iov, iovcnt);
824 }
825 
826 ssize_t send(int socket, const void *buf, size_t len, int flags)
827 {
828 	int fd;
829 	return (fd_fork_get(socket, &fd) == fd_rsocket) ?
830 		rsend(fd, buf, len, flags) : real.send(fd, buf, len, flags);
831 }
832 
833 ssize_t sendto(int socket, const void *buf, size_t len, int flags,
834 		const struct sockaddr *dest_addr, socklen_t addrlen)
835 {
836 	int fd;
837 	return (fd_fork_get(socket, &fd) == fd_rsocket) ?
838 		rsendto(fd, buf, len, flags, dest_addr, addrlen) :
839 		real.sendto(fd, buf, len, flags, dest_addr, addrlen);
840 }
841 
842 ssize_t sendmsg(int socket, const struct msghdr *msg, int flags)
843 {
844 	int fd;
845 	return (fd_fork_get(socket, &fd) == fd_rsocket) ?
846 		rsendmsg(fd, msg, flags) : real.sendmsg(fd, msg, flags);
847 }
848 
849 ssize_t write(int socket, const void *buf, size_t count)
850 {
851 	int fd;
852 	init_preload();
853 	return (fd_fork_get(socket, &fd) == fd_rsocket) ?
854 		rwrite(fd, buf, count) : real.write(fd, buf, count);
855 }
856 
857 ssize_t writev(int socket, const struct iovec *iov, int iovcnt)
858 {
859 	int fd;
860 	init_preload();
861 	return (fd_fork_get(socket, &fd) == fd_rsocket) ?
862 		rwritev(fd, iov, iovcnt) : real.writev(fd, iov, iovcnt);
863 }
864 
/*
 * Return a per-thread scratch array with room for at least 'nfds' pollfd
 * entries.  The array is cached between calls and only reallocated when a
 * larger one is needed; previous contents are not preserved across a
 * reallocation.  The caller must not free it.
 *
 * A request for zero entries still yields a valid (one-element) array, so
 * that callers treating NULL as out-of-memory do not fail on nfds == 0.
 *
 * Returns NULL only when the allocation fails.
 */
static struct pollfd *fds_alloc(nfds_t nfds)
{
	static __thread struct pollfd *rfds;
	static __thread nfds_t rnfds;

	if (!nfds)
		nfds = 1;

	if (nfds > rnfds) {
		free(rfds);

		/* calloc() rejects element counts whose byte size would overflow. */
		rfds = calloc(nfds, sizeof(*rfds));
		rnfds = rfds ? nfds : 0;
	}

	return rfds;
}
880 
881 int poll(struct pollfd *fds, nfds_t nfds, int timeout)
882 {
883 	struct pollfd *rfds;
884 	int i, ret;
885 
886 	init_preload();
887 	for (i = 0; i < nfds; i++) {
888 		if (fd_gett(fds[i].fd) == fd_rsocket)
889 			goto use_rpoll;
890 	}
891 
892 	return real.poll(fds, nfds, timeout);
893 
894 use_rpoll:
895 	rfds = fds_alloc(nfds);
896 	if (!rfds)
897 		return ERR(ENOMEM);
898 
899 	for (i = 0; i < nfds; i++) {
900 		rfds[i].fd = fd_getd(fds[i].fd);
901 		rfds[i].events = fds[i].events;
902 		rfds[i].revents = 0;
903 	}
904 
905 	ret = rpoll(rfds, nfds, timeout);
906 
907 	for (i = 0; i < nfds; i++)
908 		fds[i].revents = rfds[i].revents;
909 
910 	return ret;
911 }
912 
/*
 * Build a pollfd array from select()-style sets.  Every descriptor below
 * *nfds that appears in any of the sets gets one entry holding its
 * underlying descriptor and POLLIN/POLLOUT as requested; a descriptor that
 * is only in exceptfds gets an entry with no events.  On return *nfds is
 * the number of entries written.
 */
static void select_to_rpoll(struct pollfd *fds, int *nfds,
			    fd_set *readfds, fd_set *writefds, fd_set *exceptfds)
{
	int fd, want, used = 0;

	for (fd = 0; fd < *nfds; fd++) {
		want = 0;
		if (readfds && FD_ISSET(fd, readfds))
			want = POLLIN;
		if (writefds && FD_ISSET(fd, writefds))
			want |= POLLOUT;

		if (!want && !(exceptfds && FD_ISSET(fd, exceptfds)))
			continue;

		fds[used].fd = fd_getd(fd);
		fds[used].events = want;
		used++;
	}

	*nfds = used;
}
931 
/*
 * Translate rpoll() results back into select()-style sets.  Application
 * descriptors are walked upward from 0 and paired, in order, with the
 * pollfd entries whose underlying descriptor they map to.  POLLIN sets
 * readfds, POLLOUT sets writefds, and any other revents bit sets
 * exceptfds.  Returns the total number of bits set.
 */
static int rpoll_to_select(struct pollfd *fds, int nfds,
			   fd_set *readfds, fd_set *writefds, fd_set *exceptfds)
{
	int fd = 0, slot = 0, hits = 0;
	short got;

	while (slot < nfds) {
		if (fd_getd(fd) != fds[slot].fd) {
			fd++;
			continue;
		}

		got = fds[slot].revents;

		if (readfds && (got & POLLIN)) {
			FD_SET(fd, readfds);
			hits++;
		}

		if (writefds && (got & POLLOUT)) {
			FD_SET(fd, writefds);
			hits++;
		}

		if (exceptfds && (got & ~(POLLIN | POLLOUT))) {
			FD_SET(fd, exceptfds);
			hits++;
		}

		slot++;
		fd++;
	}

	return hits;
}
961 
/*
 * Convert a select() timeout to poll() milliseconds.  A NULL timeout
 * becomes -1; fractions of a millisecond are truncated.
 */
static int rs_convert_timeout(struct timeval *timeout)
{
	if (!timeout)
		return -1;

	return timeout->tv_sec * 1000 + timeout->tv_usec / 1000;
}
966 
/*
 * Interposed select(), implemented on top of rpoll().  The sets are
 * converted to a pollfd array, polled, cleared, and then refilled from the
 * poll results when at least one descriptor is ready.  Returns the number
 * of bits set in the output sets, 0 on timeout, or a negative value on
 * error.
 */
int select(int nfds, fd_set *readfds, fd_set *writefds,
	   fd_set *exceptfds, struct timeval *timeout)
{
	struct pollfd *pfds;
	int nready;

	pfds = fds_alloc(nfds);
	if (!pfds)
		return ERR(ENOMEM);

	/* nfds is rewritten to the number of pollfd entries actually filled. */
	select_to_rpoll(pfds, &nfds, readfds, writefds, exceptfds);
	nready = rpoll(pfds, nfds, rs_convert_timeout(timeout));

	if (readfds)
		FD_ZERO(readfds);
	if (writefds)
		FD_ZERO(writefds);
	if (exceptfds)
		FD_ZERO(exceptfds);

	if (nready <= 0)
		return nready;

	return rpoll_to_select(pfds, nfds, readfds, writefds, exceptfds);
}
992 
993 int shutdown(int socket, int how)
994 {
995 	int fd;
996 	return (fd_get(socket, &fd) == fd_rsocket) ?
997 		rshutdown(fd, how) : real.shutdown(fd, how);
998 }
999 
1000 int close(int socket)
1001 {
1002 	struct fd_info *fdi;
1003 	int ret;
1004 
1005 	init_preload();
1006 	fdi = idm_lookup(&idm, socket);
1007 	if (!fdi)
1008 		return real.close(socket);
1009 
1010 	if (fdi->dupfd != -1) {
1011 		ret = close(fdi->dupfd);
1012 		if (ret)
1013 			return ret;
1014 	}
1015 
1016 	if (atomic_fetch_sub(&fdi->refcnt, 1) != 1)
1017 		return 0;
1018 
1019 	idm_clear(&idm, socket);
1020 	real.close(socket);
1021 	ret = (fdi->type == fd_rsocket) ? rclose(fdi->fd) : real.close(fdi->fd);
1022 	free(fdi);
1023 	return ret;
1024 }
1025 
1026 int getpeername(int socket, struct sockaddr *addr, socklen_t *addrlen)
1027 {
1028 	int fd;
1029 	return (fd_get(socket, &fd) == fd_rsocket) ?
1030 		rgetpeername(fd, addr, addrlen) :
1031 		real.getpeername(fd, addr, addrlen);
1032 }
1033 
1034 int getsockname(int socket, struct sockaddr *addr, socklen_t *addrlen)
1035 {
1036 	int fd;
1037 	init_preload();
1038 	return (fd_get(socket, &fd) == fd_rsocket) ?
1039 		rgetsockname(fd, addr, addrlen) :
1040 		real.getsockname(fd, addr, addrlen);
1041 }
1042 
1043 int setsockopt(int socket, int level, int optname,
1044 		const void *optval, socklen_t optlen)
1045 {
1046 	int fd;
1047 	return (fd_get(socket, &fd) == fd_rsocket) ?
1048 		rsetsockopt(fd, level, optname, optval, optlen) :
1049 		real.setsockopt(fd, level, optname, optval, optlen);
1050 }
1051 
1052 int getsockopt(int socket, int level, int optname,
1053 		void *optval, socklen_t *optlen)
1054 {
1055 	int fd;
1056 	return (fd_get(socket, &fd) == fd_rsocket) ?
1057 		rgetsockopt(fd, level, optname, optval, optlen) :
1058 		real.getsockopt(fd, level, optname, optval, optlen);
1059 }
1060 
1061 int fcntl(int socket, int cmd, ... /* arg */)
1062 {
1063 	va_list args;
1064 	long lparam;
1065 	void *pparam;
1066 	int fd, ret;
1067 
1068 	init_preload();
1069 	va_start(args, cmd);
1070 	switch (cmd) {
1071 	case F_GETFD:
1072 	case F_GETFL:
1073 	case F_GETOWN:
1074 	case F_GETSIG:
1075 	case F_GETLEASE:
1076 		ret = (fd_get(socket, &fd) == fd_rsocket) ?
1077 			rfcntl(fd, cmd) : real.fcntl(fd, cmd);
1078 		break;
1079 	case F_DUPFD:
1080 	/*case F_DUPFD_CLOEXEC:*/
1081 	case F_SETFD:
1082 	case F_SETFL:
1083 	case F_SETOWN:
1084 	case F_SETSIG:
1085 	case F_SETLEASE:
1086 	case F_NOTIFY:
1087 		lparam = va_arg(args, long);
1088 		ret = (fd_get(socket, &fd) == fd_rsocket) ?
1089 			rfcntl(fd, cmd, lparam) : real.fcntl(fd, cmd, lparam);
1090 		break;
1091 	default:
1092 		pparam = va_arg(args, void *);
1093 		ret = (fd_get(socket, &fd) == fd_rsocket) ?
1094 			rfcntl(fd, cmd, pparam) : real.fcntl(fd, cmd, pparam);
1095 		break;
1096 	}
1097 	va_end(args);
1098 	return ret;
1099 }
1100 
1101 /*
1102  * dup2 is not thread safe
1103  */
1104 int dup2(int oldfd, int newfd)
1105 {
1106 	struct fd_info *oldfdi, *newfdi;
1107 	int ret;
1108 
1109 	init_preload();
1110 	oldfdi = idm_lookup(&idm, oldfd);
1111 	if (oldfdi) {
1112 		if (oldfdi->state == fd_fork_passive)
1113 			fork_passive(oldfd);
1114 		else if (oldfdi->state == fd_fork_active)
1115 			fork_active(oldfd);
1116 	}
1117 
1118 	newfdi = idm_lookup(&idm, newfd);
1119 	if (newfdi) {
1120 		 /* newfd cannot have been dup'ed directly */
1121 		if (atomic_load(&newfdi->refcnt) > 1)
1122 			return ERR(EBUSY);
1123 		close(newfd);
1124 	}
1125 
1126 	ret = real.dup2(oldfd, newfd);
1127 	if (!oldfdi || ret != newfd)
1128 		return ret;
1129 
1130 	newfdi = calloc(1, sizeof(*newfdi));
1131 	if (!newfdi) {
1132 		close(newfd);
1133 		return ERR(ENOMEM);
1134 	}
1135 
1136 	pthread_mutex_lock(&mut);
1137 	idm_set(&idm, newfd, newfdi);
1138 	pthread_mutex_unlock(&mut);
1139 
1140 	newfdi->fd = oldfdi->fd;
1141 	newfdi->type = oldfdi->type;
1142 	if (oldfdi->dupfd != -1) {
1143 		newfdi->dupfd = oldfdi->dupfd;
1144 		oldfdi = idm_lookup(&idm, oldfdi->dupfd);
1145 	} else {
1146 		newfdi->dupfd = oldfd;
1147 	}
1148 	atomic_store(&newfdi->refcnt, 1);
1149 	atomic_fetch_add(&oldfdi->refcnt, 1);
1150 	return newfd;
1151 }
1152 
1153 ssize_t sendfile(int out_fd, int in_fd, off_t *offset, size_t count)
1154 {
1155 	void *file_addr;
1156 	int fd;
1157 	size_t ret;
1158 
1159 	if (fd_get(out_fd, &fd) != fd_rsocket)
1160 		return real.sendfile(fd, in_fd, offset, count);
1161 
1162 	file_addr = mmap(NULL, count, PROT_READ, 0, in_fd, offset ? *offset : 0);
1163 	if (file_addr == (void *) -1)
1164 		return -1;
1165 
1166 	ret = rwrite(fd, file_addr, count);
1167 	if ((ret > 0) && offset)
1168 		lseek(in_fd, ret, SEEK_CUR);
1169 	munmap(file_addr, count);
1170 	return ret;
1171 }
1172 
1173 int __fxstat(int ver, int socket, struct stat *buf)
1174 {
1175 	int fd, ret;
1176 
1177 	init_preload();
1178 	if (fd_get(socket, &fd) == fd_rsocket) {
1179 		ret = real.fxstat(ver, socket, buf);
1180 		if (!ret)
1181 			buf->st_mode = (buf->st_mode & ~S_IFMT) | __S_IFSOCK;
1182 	} else {
1183 		ret = real.fxstat(ver, fd, buf);
1184 	}
1185 	return ret;
1186 }
1187