xref: /freebsd/contrib/ofed/librdmacm/rsocket.c (revision 87181516ef48be852d5e5fee53c6e0dbfc62f21e)
1 /*
2  * Copyright (c) 2008-2014 Intel Corporation.  All rights reserved.
3  *
4  * This software is available to you under a choice of one of two
5  * licenses.  You may choose to be licensed under the terms of the GNU
6  * General Public License (GPL) Version 2, available from the file
7  * COPYING in the main directory of this source tree, or the
8  * OpenIB.org BSD license below:
9  *
10  *     Redistribution and use in source and binary forms, with or
11  *     without modification, are permitted provided that the following
12  *     conditions are met:
13  *
14  *      - Redistributions of source code must retain the above
15  *        copyright notice, this list of conditions and the following
16  *        disclaimer.
17  *
18  *      - Redistributions in binary form must reproduce the above
19  *        copyright notice, this list of conditions and the following
20  *        disclaimer in the documentation and/or other materials
21  *        provided with the distribution.
22  *
23  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
24  * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
25  * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
26  * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
27  * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
28  * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
29  * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
30  * SOFTWARE.
31  *
32  */
#define _GNU_SOURCE
#include <config.h>

#include <sys/types.h>
#include <sys/socket.h>
#include <sys/time.h>
#include <infiniband/endian.h>
#include <assert.h>
#include <errno.h>
#include <stdarg.h>
#include <netdb.h>
#include <unistd.h>
#include <fcntl.h>
#include <stdio.h>
#include <stddef.h>
#include <string.h>
#include <netinet/tcp.h>
#include <sys/epoll.h>
#include <search.h>
#include <byteswap.h>
#include <util/compiler.h>

#include <rdma/rdma_cma.h>
#include <rdma/rdma_verbs.h>
#include <rdma/rsocket.h>
#include "cma.h"
#include "indexer.h"
58 
59 #define RS_OLAP_START_SIZE 2048
60 #define RS_MAX_TRANSFER 65536
61 #define RS_SNDLOWAT 2048
62 #define RS_QP_MIN_SIZE 16
63 #define RS_QP_MAX_SIZE 0xFFFE
64 #define RS_QP_CTRL_SIZE 4	/* must be power of 2 */
65 #define RS_CONN_RETRIES 6
66 #define RS_SGL_SIZE 2
67 static struct index_map idm;
68 static pthread_mutex_t mut = PTHREAD_MUTEX_INITIALIZER;
69 
70 struct rsocket;
71 
72 enum {
73 	RS_SVC_NOOP,
74 	RS_SVC_ADD_DGRAM,
75 	RS_SVC_REM_DGRAM,
76 	RS_SVC_ADD_KEEPALIVE,
77 	RS_SVC_REM_KEEPALIVE,
78 	RS_SVC_MOD_KEEPALIVE
79 };
80 
81 struct rs_svc_msg {
82 	uint32_t cmd;
83 	uint32_t status;
84 	struct rsocket *rs;
85 };
86 
87 struct rs_svc {
88 	pthread_t id;
89 	int sock[2];
90 	int cnt;
91 	int size;
92 	int context_size;
93 	void *(*run)(void *svc);
94 	struct rsocket **rss;
95 	void *contexts;
96 };
97 
98 static struct pollfd *udp_svc_fds;
99 static void *udp_svc_run(void *arg);
100 static struct rs_svc udp_svc = {
101 	.context_size = sizeof(*udp_svc_fds),
102 	.run = udp_svc_run
103 };
104 static uint32_t *tcp_svc_timeouts;
105 static void *tcp_svc_run(void *arg);
106 static struct rs_svc tcp_svc = {
107 	.context_size = sizeof(*tcp_svc_timeouts),
108 	.run = tcp_svc_run
109 };
110 
111 static uint16_t def_iomap_size = 0;
112 static uint16_t def_inline = 64;
113 static uint16_t def_sqsize = 384;
114 static uint16_t def_rqsize = 384;
115 static uint32_t def_mem = (1 << 17);
116 static uint32_t def_wmem = (1 << 17);
117 static uint32_t polling_time = 10;
118 
119 /*
120  * Immediate data format is determined by the upper bits
121  * bit 31: message type, 0 - data, 1 - control
122  * bit 30: buffers updated, 0 - target, 1 - direct-receive
123  * bit 29: more data, 0 - end of transfer, 1 - more data available
124  *
125  * for data transfers:
126  * bits [28:0]: bytes transferred
127  * for control messages:
128  * SGL, CTRL
129  * bits [28-0]: receive credits granted
130  * IOMAP_SGL
131  * bits [28-16]: reserved, bits [15-0]: index
132  */
133 
134 enum {
135 	RS_OP_DATA,
136 	RS_OP_RSVD_DATA_MORE,
137 	RS_OP_WRITE, /* opcode is not transmitted over the network */
138 	RS_OP_RSVD_DRA_MORE,
139 	RS_OP_SGL,
140 	RS_OP_RSVD,
141 	RS_OP_IOMAP_SGL,
142 	RS_OP_CTRL
143 };
144 #define rs_msg_set(op, data)  ((op << 29) | (uint32_t) (data))
145 #define rs_msg_op(imm_data)   (imm_data >> 29)
146 #define rs_msg_data(imm_data) (imm_data & 0x1FFFFFFF)
147 #define RS_MSG_SIZE	      sizeof(uint32_t)
148 
149 #define RS_WR_ID_FLAG_RECV (((uint64_t) 1) << 63)
150 #define RS_WR_ID_FLAG_MSG_SEND (((uint64_t) 1) << 62) /* See RS_OPT_MSG_SEND */
151 #define rs_send_wr_id(data) ((uint64_t) data)
152 #define rs_recv_wr_id(data) (RS_WR_ID_FLAG_RECV | (uint64_t) data)
153 #define rs_wr_is_recv(wr_id) (wr_id & RS_WR_ID_FLAG_RECV)
154 #define rs_wr_is_msg_send(wr_id) (wr_id & RS_WR_ID_FLAG_MSG_SEND)
155 #define rs_wr_data(wr_id) ((uint32_t) wr_id)
156 
157 enum {
158 	RS_CTRL_DISCONNECT,
159 	RS_CTRL_KEEPALIVE,
160 	RS_CTRL_SHUTDOWN
161 };
162 
163 struct rs_msg {
164 	uint32_t op;
165 	uint32_t data;
166 };
167 
168 struct ds_qp;
169 
170 struct ds_rmsg {
171 	struct ds_qp	*qp;
172 	uint32_t	offset;
173 	uint32_t	length;
174 };
175 
176 struct ds_smsg {
177 	struct ds_smsg	*next;
178 };
179 
180 struct rs_sge {
181 	uint64_t addr;
182 	uint32_t key;
183 	uint32_t length;
184 };
185 
186 struct rs_iomap {
187 	uint64_t offset;
188 	struct rs_sge sge;
189 };
190 
191 struct rs_iomap_mr {
192 	uint64_t offset;
193 	struct ibv_mr *mr;
194 	dlist_entry entry;
195 	_Atomic(int) refcnt;
196 	int index;	/* -1 if mapping is local and not in iomap_list */
197 };
198 
199 #define RS_MAX_CTRL_MSG    (sizeof(struct rs_sge))
200 #define rs_host_is_net()   (__BYTE_ORDER == __BIG_ENDIAN)
201 #define RS_CONN_FLAG_NET   (1 << 0)
202 #define RS_CONN_FLAG_IOMAP (1 << 1)
203 
204 struct rs_conn_data {
205 	uint8_t		  version;
206 	uint8_t		  flags;
207 	__be16		  credits;
208 	uint8_t		  reserved[3];
209 	uint8_t		  target_iomap_size;
210 	struct rs_sge	  target_sgl;
211 	struct rs_sge	  data_buf;
212 };
213 
214 struct rs_conn_private_data {
215 	union {
216 		struct rs_conn_data		conn_data;
217 		struct {
218 			struct ib_connect_hdr	ib_hdr;
219 			struct rs_conn_data	conn_data;
220 		} af_ib;
221 	};
222 };
223 
224 /*
225  * rsocket states are ordered as passive, connecting, connected, disconnected.
226  */
227 enum rs_state {
228 	rs_init,
229 	rs_bound	   =		    0x0001,
230 	rs_listening	   =		    0x0002,
231 	rs_opening	   =		    0x0004,
232 	rs_resolving_addr  = rs_opening |   0x0010,
233 	rs_resolving_route = rs_opening |   0x0020,
234 	rs_connecting      = rs_opening |   0x0040,
235 	rs_accepting       = rs_opening |   0x0080,
236 	rs_connected	   =		    0x0100,
237 	rs_writable 	   =		    0x0200,
238 	rs_readable	   =		    0x0400,
239 	rs_connect_rdwr    = rs_connected | rs_readable | rs_writable,
240 	rs_connect_error   =		    0x0800,
241 	rs_disconnected	   =		    0x1000,
242 	rs_error	   =		    0x2000,
243 };
244 
245 #define RS_OPT_SWAP_SGL   (1 << 0)
246 /*
247  * iWarp does not support RDMA write with immediate data.  For iWarp, we
248  * transfer rsocket messages as inline sends.
249  */
250 #define RS_OPT_MSG_SEND   (1 << 1)
251 #define RS_OPT_SVC_ACTIVE (1 << 2)
252 
253 union socket_addr {
254 	struct sockaddr		sa;
255 	struct sockaddr_in	sin;
256 	struct sockaddr_in6	sin6;
257 };
258 
259 struct ds_header {
260 	uint8_t		  version;
261 	uint8_t		  length;
262 	__be16		  port;
263 	union {
264 		__be32  ipv4;
265 		struct {
266 			__be32 flowinfo;
267 			uint8_t  addr[16];
268 		} ipv6;
269 	} addr;
270 };
271 
272 #define DS_IPV4_HDR_LEN  8
273 #define DS_IPV6_HDR_LEN 24
274 
275 struct ds_dest {
276 	union socket_addr addr;	/* must be first */
277 	struct ds_qp	  *qp;
278 	struct ibv_ah	  *ah;
279 	uint32_t	   qpn;
280 };
281 
282 struct ds_qp {
283 	dlist_entry	  list;
284 	struct rsocket	  *rs;
285 	struct rdma_cm_id *cm_id;
286 	struct ds_header  hdr;
287 	struct ds_dest	  dest;
288 
289 	struct ibv_mr	  *smr;
290 	struct ibv_mr	  *rmr;
291 	uint8_t		  *rbuf;
292 
293 	int		  cq_armed;
294 };
295 
296 struct rsocket {
297 	int		  type;
298 	int		  index;
299 	fastlock_t	  slock;
300 	fastlock_t	  rlock;
301 	fastlock_t	  cq_lock;
302 	fastlock_t	  cq_wait_lock;
303 	fastlock_t	  map_lock; /* acquire slock first if needed */
304 
305 	union {
306 		/* data stream */
307 		struct {
308 			struct rdma_cm_id *cm_id;
309 			uint64_t	  tcp_opts;
310 			unsigned int	  keepalive_time;
311 
312 			unsigned int	  ctrl_seqno;
313 			unsigned int	  ctrl_max_seqno;
314 			uint16_t	  sseq_no;
315 			uint16_t	  sseq_comp;
316 			uint16_t	  rseq_no;
317 			uint16_t	  rseq_comp;
318 
319 			int		  remote_sge;
320 			struct rs_sge	  remote_sgl;
321 			struct rs_sge	  remote_iomap;
322 
323 			struct ibv_mr	  *target_mr;
324 			int		  target_sge;
325 			int		  target_iomap_size;
326 			void		  *target_buffer_list;
327 			volatile struct rs_sge	  *target_sgl;
328 			struct rs_iomap   *target_iomap;
329 
330 			int		  rbuf_msg_index;
331 			int		  rbuf_bytes_avail;
332 			int		  rbuf_free_offset;
333 			int		  rbuf_offset;
334 			struct ibv_mr	  *rmr;
335 			uint8_t		  *rbuf;
336 
337 			int		  sbuf_bytes_avail;
338 			struct ibv_mr	  *smr;
339 			struct ibv_sge	  ssgl[2];
340 		};
341 		/* datagram */
342 		struct {
343 			struct ds_qp	  *qp_list;
344 			void		  *dest_map;
345 			struct ds_dest    *conn_dest;
346 
347 			int		  udp_sock;
348 			int		  epfd;
349 			int		  rqe_avail;
350 			struct ds_smsg	  *smsg_free;
351 		};
352 	};
353 
354 	int		  opts;
355 	int		  fd_flags;
356 	uint64_t	  so_opts;
357 	uint64_t	  ipv6_opts;
358 	void		  *optval;
359 	size_t		  optlen;
360 	int		  state;
361 	int		  cq_armed;
362 	int		  retries;
363 	int		  err;
364 
365 	int		  sqe_avail;
366 	uint32_t	  sbuf_size;
367 	uint16_t	  sq_size;
368 	uint16_t	  sq_inline;
369 
370 	uint32_t	  rbuf_size;
371 	uint16_t	  rq_size;
372 	int		  rmsg_head;
373 	int		  rmsg_tail;
374 	union {
375 		struct rs_msg	  *rmsg;
376 		struct ds_rmsg	  *dmsg;
377 	};
378 
379 	uint8_t		  *sbuf;
380 	struct rs_iomap_mr *remote_iomappings;
381 	dlist_entry	  iomap_list;
382 	dlist_entry	  iomap_queue;
383 	int		  iomap_pending;
384 	int		  unack_cqe;
385 };
386 
387 #define DS_UDP_TAG 0x55555555
388 
389 struct ds_udp_header {
390 	__be32		  tag;
391 	uint8_t		  version;
392 	uint8_t		  op;
393 	uint8_t		  length;
394 	uint8_t		  reserved;
395 	__be32		  qpn;  /* lower 8-bits reserved */
396 	union {
397 		__be32	 ipv4;
398 		uint8_t  ipv6[16];
399 	} addr;
400 };
401 
402 #define DS_UDP_IPV4_HDR_LEN 16
403 #define DS_UDP_IPV6_HDR_LEN 28
404 
405 #define ds_next_qp(qp) container_of((qp)->list.next, struct ds_qp, list)
406 
/*
 * Write exactly @len bytes from @msg to @fd.
 *
 * A write() on a socket may transfer fewer bytes than requested or fail
 * with EINTR; both cases are retried until everything is written.  Any
 * other failure is an invariant violation (assert), as before; with NDEBUG
 * the remaining bytes are dropped instead of looping forever.
 */
static void write_all(int fd, const void *msg, size_t len)
{
	const char *pos = msg;

	while (len > 0) {
		ssize_t rc = write(fd, pos, len);

		if (rc < 0 && errno == EINTR)
			continue;
		assert(rc > 0);
		if (rc <= 0)
			return;
		pos += rc;
		len -= (size_t) rc;
	}
}
412 }
413 
/*
 * Read exactly @len bytes from @fd into @msg.
 *
 * Short reads and EINTR are retried until the whole record has arrived.
 * End-of-file (0) or any other error before that is an invariant violation
 * (assert), as before; with NDEBUG the rest of @msg is left untouched.
 */
static void read_all(int fd, void *msg, size_t len)
{
	char *pos = msg;

	while (len > 0) {
		ssize_t rc = read(fd, pos, len);

		if (rc < 0 && errno == EINTR)
			continue;
		assert(rc > 0);
		if (rc <= 0)
			return;
		pos += rc;
		len -= (size_t) rc;
	}
}
419 }
420 
ds_insert_qp(struct rsocket * rs,struct ds_qp * qp)421 static void ds_insert_qp(struct rsocket *rs, struct ds_qp *qp)
422 {
423 	if (!rs->qp_list)
424 		dlist_init(&qp->list);
425 	else
426 		dlist_insert_head(&qp->list, &rs->qp_list->list);
427 	rs->qp_list = qp;
428 }
429 
ds_remove_qp(struct rsocket * rs,struct ds_qp * qp)430 static void ds_remove_qp(struct rsocket *rs, struct ds_qp *qp)
431 {
432 	if (qp->list.next != &qp->list) {
433 		rs->qp_list = ds_next_qp(qp);
434 		dlist_remove(&qp->list);
435 	} else {
436 		rs->qp_list = NULL;
437 	}
438 }
439 
rs_notify_svc(struct rs_svc * svc,struct rsocket * rs,int cmd)440 static int rs_notify_svc(struct rs_svc *svc, struct rsocket *rs, int cmd)
441 {
442 	struct rs_svc_msg msg;
443 	int ret;
444 
445 	pthread_mutex_lock(&mut);
446 	if (!svc->cnt) {
447 		ret = socketpair(AF_UNIX, SOCK_STREAM, 0, svc->sock);
448 		if (ret)
449 			goto unlock;
450 
451 		ret = pthread_create(&svc->id, NULL, svc->run, svc);
452 		if (ret) {
453 			ret = ERR(ret);
454 			goto closepair;
455 		}
456 	}
457 
458 	msg.cmd = cmd;
459 	msg.status = EINVAL;
460 	msg.rs = rs;
461 	write_all(svc->sock[0], &msg, sizeof msg);
462 	read_all(svc->sock[0], &msg, sizeof msg);
463 	ret = rdma_seterrno(msg.status);
464 	if (svc->cnt)
465 		goto unlock;
466 
467 	pthread_join(svc->id, NULL);
468 closepair:
469 	close(svc->sock[0]);
470 	close(svc->sock[1]);
471 unlock:
472 	pthread_mutex_unlock(&mut);
473 	return ret;
474 }
475 
/*
 * tsearch() comparator for datagram destinations keyed by socket address.
 * The full sockaddr_in6 is compared only when both sides are IPv6;
 * otherwise sizeof(struct sockaddr_in) bytes are compared.
 */
static int ds_compare_addr(const void *dst1, const void *dst2)
{
	const struct sockaddr *a = dst1;
	const struct sockaddr *b = dst2;
	size_t cmp_len = sizeof(struct sockaddr_in);

	if (a->sa_family == AF_INET6 && b->sa_family == AF_INET6)
		cmp_len = sizeof(struct sockaddr_in6);

	return memcmp(dst1, dst2, cmp_len);
}
487 }
488 
/*
 * Encode @value into a @bits-wide scaled form.  Values up to
 * 2^(bits-1) are stored directly; larger ones set the top bit and keep
 * value >> bits (lossy: low bits are dropped).
 */
static int rs_value_to_scale(int value, int bits)
{
	const int limit = 1 << (bits - 1);

	if (value > limit)
		return limit | (value >> bits);
	return value;
}
493 }
494 
/*
 * Decode a @bits-wide scaled number produced by rs_value_to_scale().
 * Encodings up to 2^(bits-1) are literal; above that the top bit is
 * cleared and the remainder is shifted left by @bits.
 */
static int rs_scale_to_value(int value, int bits)
{
	const int flag = 1 << (bits - 1);

	if (value > flag)
		return (value & ~flag) << bits;
	return value;
}
499 }
500 
501 /* gcc > ~5 will not allow (void)fscanf to suppress -Wunused-result, but this
502    will do it.  In this case ignoring the result is OK (but horribly
503    unfriendly to user) since the library has a sane default. */
504 #define failable_fscanf(f, fmt, ...)                                           \
505 	{                                                                      \
506 		int rc = fscanf(f, fmt, __VA_ARGS__);                          \
507 		(void) rc;                                                     \
508 	}
509 
rs_configure(void)510 static void rs_configure(void)
511 {
512 	FILE *f;
513 	static int init;
514 
515 	if (init)
516 		return;
517 
518 	pthread_mutex_lock(&mut);
519 	if (init)
520 		goto out;
521 
522 	if (ucma_init())
523 		goto out;
524 	ucma_ib_init();
525 
526 	if ((f = fopen(RS_CONF_DIR "/polling_time", "r"))) {
527 		failable_fscanf(f, "%u", &polling_time);
528 		fclose(f);
529 	}
530 
531 	if ((f = fopen(RS_CONF_DIR "/inline_default", "r"))) {
532 		failable_fscanf(f, "%hu", &def_inline);
533 		fclose(f);
534 	}
535 
536 	if ((f = fopen(RS_CONF_DIR "/sqsize_default", "r"))) {
537 		failable_fscanf(f, "%hu", &def_sqsize);
538 		fclose(f);
539 	}
540 
541 	if ((f = fopen(RS_CONF_DIR "/rqsize_default", "r"))) {
542 		failable_fscanf(f, "%hu", &def_rqsize);
543 		fclose(f);
544 	}
545 
546 	if ((f = fopen(RS_CONF_DIR "/mem_default", "r"))) {
547 		failable_fscanf(f, "%u", &def_mem);
548 		fclose(f);
549 
550 		if (def_mem < 1)
551 			def_mem = 1;
552 	}
553 
554 	if ((f = fopen(RS_CONF_DIR "/wmem_default", "r"))) {
555 		failable_fscanf(f, "%u", &def_wmem);
556 		fclose(f);
557 		if (def_wmem < RS_SNDLOWAT)
558 			def_wmem = RS_SNDLOWAT << 1;
559 	}
560 
561 	if ((f = fopen(RS_CONF_DIR "/iomap_size", "r"))) {
562 		failable_fscanf(f, "%hu", &def_iomap_size);
563 		fclose(f);
564 
565 		/* round to supported values */
566 		def_iomap_size = (uint8_t) rs_value_to_scale(
567 			(uint16_t) rs_scale_to_value(def_iomap_size, 8), 8);
568 	}
569 	init = 1;
570 out:
571 	pthread_mutex_unlock(&mut);
572 }
573 
rs_insert(struct rsocket * rs,int index)574 static int rs_insert(struct rsocket *rs, int index)
575 {
576 	pthread_mutex_lock(&mut);
577 	rs->index = idm_set(&idm, index, rs);
578 	pthread_mutex_unlock(&mut);
579 	return rs->index;
580 }
581 
rs_remove(struct rsocket * rs)582 static void rs_remove(struct rsocket *rs)
583 {
584 	pthread_mutex_lock(&mut);
585 	idm_clear(&idm, rs->index);
586 	pthread_mutex_unlock(&mut);
587 }
588 
589 /* We only inherit from listening sockets */
rs_alloc(struct rsocket * inherited_rs,int type)590 static struct rsocket *rs_alloc(struct rsocket *inherited_rs, int type)
591 {
592 	struct rsocket *rs;
593 
594 	rs = calloc(1, sizeof(*rs));
595 	if (!rs)
596 		return NULL;
597 
598 	rs->type = type;
599 	rs->index = -1;
600 	if (type == SOCK_DGRAM) {
601 		rs->udp_sock = -1;
602 		rs->epfd = -1;
603 	}
604 
605 	if (inherited_rs) {
606 		rs->sbuf_size = inherited_rs->sbuf_size;
607 		rs->rbuf_size = inherited_rs->rbuf_size;
608 		rs->sq_inline = inherited_rs->sq_inline;
609 		rs->sq_size = inherited_rs->sq_size;
610 		rs->rq_size = inherited_rs->rq_size;
611 		if (type == SOCK_STREAM) {
612 			rs->ctrl_max_seqno = inherited_rs->ctrl_max_seqno;
613 			rs->target_iomap_size = inherited_rs->target_iomap_size;
614 		}
615 	} else {
616 		rs->sbuf_size = def_wmem;
617 		rs->rbuf_size = def_mem;
618 		rs->sq_inline = def_inline;
619 		rs->sq_size = def_sqsize;
620 		rs->rq_size = def_rqsize;
621 		if (type == SOCK_STREAM) {
622 			rs->ctrl_max_seqno = RS_QP_CTRL_SIZE;
623 			rs->target_iomap_size = def_iomap_size;
624 		}
625 	}
626 	fastlock_init(&rs->slock);
627 	fastlock_init(&rs->rlock);
628 	fastlock_init(&rs->cq_lock);
629 	fastlock_init(&rs->cq_wait_lock);
630 	fastlock_init(&rs->map_lock);
631 	dlist_init(&rs->iomap_list);
632 	dlist_init(&rs->iomap_queue);
633 	return rs;
634 }
635 
rs_set_nonblocking(struct rsocket * rs,int arg)636 static int rs_set_nonblocking(struct rsocket *rs, int arg)
637 {
638 	struct ds_qp *qp;
639 	int ret = 0;
640 
641 	if (rs->type == SOCK_STREAM) {
642 		if (rs->cm_id->recv_cq_channel)
643 			ret = fcntl(rs->cm_id->recv_cq_channel->fd, F_SETFL, arg);
644 
645 		if (!ret && rs->state < rs_connected)
646 			ret = fcntl(rs->cm_id->channel->fd, F_SETFL, arg);
647 	} else {
648 		ret = fcntl(rs->epfd, F_SETFL, arg);
649 		if (!ret && rs->qp_list) {
650 			qp = rs->qp_list;
651 			do {
652 				ret = fcntl(qp->cm_id->recv_cq_channel->fd,
653 					    F_SETFL, arg);
654 				qp = ds_next_qp(qp);
655 			} while (qp != rs->qp_list && !ret);
656 		}
657 	}
658 
659 	return ret;
660 }
661 
rs_set_qp_size(struct rsocket * rs)662 static void rs_set_qp_size(struct rsocket *rs)
663 {
664 	uint16_t max_size;
665 
666 	max_size = min(ucma_max_qpsize(rs->cm_id), RS_QP_MAX_SIZE);
667 
668 	if (rs->sq_size > max_size)
669 		rs->sq_size = max_size;
670 	else if (rs->sq_size < RS_QP_MIN_SIZE)
671 		rs->sq_size = RS_QP_MIN_SIZE;
672 
673 	if (rs->rq_size > max_size)
674 		rs->rq_size = max_size;
675 	else if (rs->rq_size < RS_QP_MIN_SIZE)
676 		rs->rq_size = RS_QP_MIN_SIZE;
677 }
678 
ds_set_qp_size(struct rsocket * rs)679 static void ds_set_qp_size(struct rsocket *rs)
680 {
681 	uint16_t max_size;
682 
683 	max_size = min(ucma_max_qpsize(NULL), RS_QP_MAX_SIZE);
684 
685 	if (rs->sq_size > max_size)
686 		rs->sq_size = max_size;
687 	if (rs->rq_size > max_size)
688 		rs->rq_size = max_size;
689 
690 	if (rs->rq_size > (rs->rbuf_size / RS_SNDLOWAT))
691 		rs->rq_size = rs->rbuf_size / RS_SNDLOWAT;
692 	else
693 		rs->rbuf_size = rs->rq_size * RS_SNDLOWAT;
694 
695 	if (rs->sq_size > (rs->sbuf_size / RS_SNDLOWAT))
696 		rs->sq_size = rs->sbuf_size / RS_SNDLOWAT;
697 	else
698 		rs->sbuf_size = rs->sq_size * RS_SNDLOWAT;
699 }
700 
rs_init_bufs(struct rsocket * rs)701 static int rs_init_bufs(struct rsocket *rs)
702 {
703 	uint32_t total_rbuf_size, total_sbuf_size;
704 	size_t len;
705 
706 	rs->rmsg = calloc(rs->rq_size + 1, sizeof(*rs->rmsg));
707 	if (!rs->rmsg)
708 		return ERR(ENOMEM);
709 
710 	total_sbuf_size = rs->sbuf_size;
711 	if (rs->sq_inline < RS_MAX_CTRL_MSG)
712 		total_sbuf_size += RS_MAX_CTRL_MSG * RS_QP_CTRL_SIZE;
713 	rs->sbuf = calloc(total_sbuf_size, 1);
714 	if (!rs->sbuf)
715 		return ERR(ENOMEM);
716 
717 	rs->smr = rdma_reg_msgs(rs->cm_id, rs->sbuf, total_sbuf_size);
718 	if (!rs->smr)
719 		return -1;
720 
721 	len = sizeof(*rs->target_sgl) * RS_SGL_SIZE +
722 	      sizeof(*rs->target_iomap) * rs->target_iomap_size;
723 	rs->target_buffer_list = malloc(len);
724 	if (!rs->target_buffer_list)
725 		return ERR(ENOMEM);
726 
727 	rs->target_mr = rdma_reg_write(rs->cm_id, rs->target_buffer_list, len);
728 	if (!rs->target_mr)
729 		return -1;
730 
731 	memset(rs->target_buffer_list, 0, len);
732 	rs->target_sgl = rs->target_buffer_list;
733 	if (rs->target_iomap_size)
734 		rs->target_iomap = (struct rs_iomap *) (rs->target_sgl + RS_SGL_SIZE);
735 
736 	total_rbuf_size = rs->rbuf_size;
737 	if (rs->opts & RS_OPT_MSG_SEND)
738 		total_rbuf_size += rs->rq_size * RS_MSG_SIZE;
739 	rs->rbuf = calloc(total_rbuf_size, 1);
740 	if (!rs->rbuf)
741 		return ERR(ENOMEM);
742 
743 	rs->rmr = rdma_reg_write(rs->cm_id, rs->rbuf, total_rbuf_size);
744 	if (!rs->rmr)
745 		return -1;
746 
747 	rs->ssgl[0].addr = rs->ssgl[1].addr = (uintptr_t) rs->sbuf;
748 	rs->sbuf_bytes_avail = rs->sbuf_size;
749 	rs->ssgl[0].lkey = rs->ssgl[1].lkey = rs->smr->lkey;
750 
751 	rs->rbuf_free_offset = rs->rbuf_size >> 1;
752 	rs->rbuf_bytes_avail = rs->rbuf_size >> 1;
753 	rs->sqe_avail = rs->sq_size - rs->ctrl_max_seqno;
754 	rs->rseq_comp = rs->rq_size >> 1;
755 	return 0;
756 }
757 
ds_init_bufs(struct ds_qp * qp)758 static int ds_init_bufs(struct ds_qp *qp)
759 {
760 	qp->rbuf = calloc(qp->rs->rbuf_size + sizeof(struct ibv_grh), 1);
761 	if (!qp->rbuf)
762 		return ERR(ENOMEM);
763 
764 	qp->smr = rdma_reg_msgs(qp->cm_id, qp->rs->sbuf, qp->rs->sbuf_size);
765 	if (!qp->smr)
766 		return -1;
767 
768 	qp->rmr = rdma_reg_msgs(qp->cm_id, qp->rbuf, qp->rs->rbuf_size +
769 						     sizeof(struct ibv_grh));
770 	if (!qp->rmr)
771 		return -1;
772 
773 	return 0;
774 }
775 
776 /*
777  * If a user is waiting on a datagram rsocket through poll or select, then
778  * we need the first completion to generate an event on the related epoll fd
779  * in order to signal the user.  We arm the CQ on creation for this purpose
780  */
rs_create_cq(struct rsocket * rs,struct rdma_cm_id * cm_id)781 static int rs_create_cq(struct rsocket *rs, struct rdma_cm_id *cm_id)
782 {
783 	cm_id->recv_cq_channel = ibv_create_comp_channel(cm_id->verbs);
784 	if (!cm_id->recv_cq_channel)
785 		return -1;
786 
787 	cm_id->recv_cq = ibv_create_cq(cm_id->verbs, rs->sq_size + rs->rq_size,
788 				       cm_id, cm_id->recv_cq_channel, 0);
789 	if (!cm_id->recv_cq)
790 		goto err1;
791 
792 	if (rs->fd_flags & O_NONBLOCK) {
793 		if (fcntl(cm_id->recv_cq_channel->fd, F_SETFL, O_NONBLOCK))
794 			goto err2;
795 	}
796 
797 	ibv_req_notify_cq(cm_id->recv_cq, 0);
798 	cm_id->send_cq_channel = cm_id->recv_cq_channel;
799 	cm_id->send_cq = cm_id->recv_cq;
800 	return 0;
801 
802 err2:
803 	ibv_destroy_cq(cm_id->recv_cq);
804 	cm_id->recv_cq = NULL;
805 err1:
806 	ibv_destroy_comp_channel(cm_id->recv_cq_channel);
807 	cm_id->recv_cq_channel = NULL;
808 	return -1;
809 }
810 
rs_post_recv(struct rsocket * rs)811 static inline int rs_post_recv(struct rsocket *rs)
812 {
813 	struct ibv_recv_wr wr, *bad;
814 	struct ibv_sge sge;
815 
816 	wr.next = NULL;
817 	if (!(rs->opts & RS_OPT_MSG_SEND)) {
818 		wr.wr_id = rs_recv_wr_id(0);
819 		wr.sg_list = NULL;
820 		wr.num_sge = 0;
821 	} else {
822 		wr.wr_id = rs_recv_wr_id(rs->rbuf_msg_index);
823 		sge.addr = (uintptr_t) rs->rbuf + rs->rbuf_size +
824 			   (rs->rbuf_msg_index * RS_MSG_SIZE);
825 		sge.length = RS_MSG_SIZE;
826 		sge.lkey = rs->rmr->lkey;
827 
828 		wr.sg_list = &sge;
829 		wr.num_sge = 1;
830 		if(++rs->rbuf_msg_index == rs->rq_size)
831 			rs->rbuf_msg_index = 0;
832 	}
833 
834 	return rdma_seterrno(ibv_post_recv(rs->cm_id->qp, &wr, &bad));
835 }
836 
ds_post_recv(struct rsocket * rs,struct ds_qp * qp,uint32_t offset)837 static inline int ds_post_recv(struct rsocket *rs, struct ds_qp *qp, uint32_t offset)
838 {
839 	struct ibv_recv_wr wr, *bad;
840 	struct ibv_sge sge[2];
841 
842 	sge[0].addr = (uintptr_t) qp->rbuf + rs->rbuf_size;
843 	sge[0].length = sizeof(struct ibv_grh);
844 	sge[0].lkey = qp->rmr->lkey;
845 	sge[1].addr = (uintptr_t) qp->rbuf + offset;
846 	sge[1].length = RS_SNDLOWAT;
847 	sge[1].lkey = qp->rmr->lkey;
848 
849 	wr.wr_id = rs_recv_wr_id(offset);
850 	wr.next = NULL;
851 	wr.sg_list = sge;
852 	wr.num_sge = 2;
853 
854 	return rdma_seterrno(ibv_post_recv(qp->cm_id->qp, &wr, &bad));
855 }
856 
rs_create_ep(struct rsocket * rs)857 static int rs_create_ep(struct rsocket *rs)
858 {
859 	struct ibv_qp_init_attr qp_attr;
860 	int i, ret;
861 
862 	rs_set_qp_size(rs);
863 	if (rs->cm_id->verbs->device->transport_type == IBV_TRANSPORT_IWARP)
864 		rs->opts |= RS_OPT_MSG_SEND;
865 	ret = rs_create_cq(rs, rs->cm_id);
866 	if (ret)
867 		return ret;
868 
869 	memset(&qp_attr, 0, sizeof qp_attr);
870 	qp_attr.qp_context = rs;
871 	qp_attr.send_cq = rs->cm_id->send_cq;
872 	qp_attr.recv_cq = rs->cm_id->recv_cq;
873 	qp_attr.qp_type = IBV_QPT_RC;
874 	qp_attr.sq_sig_all = 1;
875 	qp_attr.cap.max_send_wr = rs->sq_size;
876 	qp_attr.cap.max_recv_wr = rs->rq_size;
877 	qp_attr.cap.max_send_sge = 2;
878 	qp_attr.cap.max_recv_sge = 1;
879 	qp_attr.cap.max_inline_data = rs->sq_inline;
880 
881 	ret = rdma_create_qp(rs->cm_id, NULL, &qp_attr);
882 	if (ret)
883 		return ret;
884 
885 	rs->sq_inline = qp_attr.cap.max_inline_data;
886 	if ((rs->opts & RS_OPT_MSG_SEND) && (rs->sq_inline < RS_MSG_SIZE))
887 		return ERR(ENOTSUP);
888 
889 	ret = rs_init_bufs(rs);
890 	if (ret)
891 		return ret;
892 
893 	for (i = 0; i < rs->rq_size; i++) {
894 		ret = rs_post_recv(rs);
895 		if (ret)
896 			return ret;
897 	}
898 	return 0;
899 }
900 
rs_release_iomap_mr(struct rs_iomap_mr * iomr)901 static void rs_release_iomap_mr(struct rs_iomap_mr *iomr)
902 {
903 	if (atomic_fetch_sub(&iomr->refcnt, 1) != 1)
904 		return;
905 
906 	dlist_remove(&iomr->entry);
907 	ibv_dereg_mr(iomr->mr);
908 	if (iomr->index >= 0)
909 		iomr->mr = NULL;
910 	else
911 		free(iomr);
912 }
913 
rs_free_iomappings(struct rsocket * rs)914 static void rs_free_iomappings(struct rsocket *rs)
915 {
916 	struct rs_iomap_mr *iomr;
917 
918 	while (!dlist_empty(&rs->iomap_list)) {
919 		iomr = container_of(rs->iomap_list.next,
920 				    struct rs_iomap_mr, entry);
921 		riounmap(rs->index, iomr->mr->addr, iomr->mr->length);
922 	}
923 	while (!dlist_empty(&rs->iomap_queue)) {
924 		iomr = container_of(rs->iomap_queue.next,
925 				    struct rs_iomap_mr, entry);
926 		riounmap(rs->index, iomr->mr->addr, iomr->mr->length);
927 	}
928 }
929 
ds_free_qp(struct ds_qp * qp)930 static void ds_free_qp(struct ds_qp *qp)
931 {
932 	if (qp->smr)
933 		rdma_dereg_mr(qp->smr);
934 
935 	if (qp->rbuf) {
936 		if (qp->rmr)
937 			rdma_dereg_mr(qp->rmr);
938 		free(qp->rbuf);
939 	}
940 
941 	if (qp->cm_id) {
942 		if (qp->cm_id->qp) {
943 			tdelete(&qp->dest.addr, &qp->rs->dest_map, ds_compare_addr);
944 			epoll_ctl(qp->rs->epfd, EPOLL_CTL_DEL,
945 				  qp->cm_id->recv_cq_channel->fd, NULL);
946 			rdma_destroy_qp(qp->cm_id);
947 		}
948 		rdma_destroy_id(qp->cm_id);
949 	}
950 
951 	free(qp);
952 }
953 
ds_free(struct rsocket * rs)954 static void ds_free(struct rsocket *rs)
955 {
956 	struct ds_qp *qp;
957 
958 	if (rs->udp_sock >= 0)
959 		close(rs->udp_sock);
960 
961 	if (rs->index >= 0)
962 		rs_remove(rs);
963 
964 	if (rs->dmsg)
965 		free(rs->dmsg);
966 
967 	while ((qp = rs->qp_list)) {
968 		ds_remove_qp(rs, qp);
969 		ds_free_qp(qp);
970 	}
971 
972 	if (rs->epfd >= 0)
973 		close(rs->epfd);
974 
975 	if (rs->sbuf)
976 		free(rs->sbuf);
977 
978 	tdestroy(rs->dest_map, free);
979 	fastlock_destroy(&rs->map_lock);
980 	fastlock_destroy(&rs->cq_wait_lock);
981 	fastlock_destroy(&rs->cq_lock);
982 	fastlock_destroy(&rs->rlock);
983 	fastlock_destroy(&rs->slock);
984 	free(rs);
985 }
986 
rs_free(struct rsocket * rs)987 static void rs_free(struct rsocket *rs)
988 {
989 	if (rs->type == SOCK_DGRAM) {
990 		ds_free(rs);
991 		return;
992 	}
993 
994 	if (rs->rmsg)
995 		free(rs->rmsg);
996 
997 	if (rs->sbuf) {
998 		if (rs->smr)
999 			rdma_dereg_mr(rs->smr);
1000 		free(rs->sbuf);
1001 	}
1002 
1003 	if (rs->rbuf) {
1004 		if (rs->rmr)
1005 			rdma_dereg_mr(rs->rmr);
1006 		free(rs->rbuf);
1007 	}
1008 
1009 	if (rs->target_buffer_list) {
1010 		if (rs->target_mr)
1011 			rdma_dereg_mr(rs->target_mr);
1012 		free(rs->target_buffer_list);
1013 	}
1014 
1015 	if (rs->cm_id) {
1016 		rs_free_iomappings(rs);
1017 		if (rs->cm_id->qp) {
1018 			ibv_ack_cq_events(rs->cm_id->recv_cq, rs->unack_cqe);
1019 			rdma_destroy_qp(rs->cm_id);
1020 		}
1021 		rdma_destroy_id(rs->cm_id);
1022 	}
1023 
1024 	if (rs->index >= 0)
1025 		rs_remove(rs);
1026 
1027 	fastlock_destroy(&rs->map_lock);
1028 	fastlock_destroy(&rs->cq_wait_lock);
1029 	fastlock_destroy(&rs->cq_lock);
1030 	fastlock_destroy(&rs->rlock);
1031 	fastlock_destroy(&rs->slock);
1032 	free(rs);
1033 }
1034 
rs_conn_data_offset(struct rsocket * rs)1035 static size_t rs_conn_data_offset(struct rsocket *rs)
1036 {
1037 	return (rs->cm_id->route.addr.src_addr.sa_family == AF_IB) ?
1038 		sizeof(struct ib_connect_hdr) : 0;
1039 }
1040 
rs_format_conn_data(struct rsocket * rs,struct rs_conn_data * conn)1041 static void rs_format_conn_data(struct rsocket *rs, struct rs_conn_data *conn)
1042 {
1043 	conn->version = 1;
1044 	conn->flags = RS_CONN_FLAG_IOMAP |
1045 		      (rs_host_is_net() ? RS_CONN_FLAG_NET : 0);
1046 	conn->credits = htobe16(rs->rq_size);
1047 	memset(conn->reserved, 0, sizeof conn->reserved);
1048 	conn->target_iomap_size = (uint8_t) rs_value_to_scale(rs->target_iomap_size, 8);
1049 
1050 	conn->target_sgl.addr = (__force uint64_t)htobe64((uintptr_t) rs->target_sgl);
1051 	conn->target_sgl.length = (__force uint32_t)htobe32(RS_SGL_SIZE);
1052 	conn->target_sgl.key = (__force uint32_t)htobe32(rs->target_mr->rkey);
1053 
1054 	conn->data_buf.addr = (__force uint64_t)htobe64((uintptr_t) rs->rbuf);
1055 	conn->data_buf.length = (__force uint32_t)htobe32(rs->rbuf_size >> 1);
1056 	conn->data_buf.key = (__force uint32_t)htobe32(rs->rmr->rkey);
1057 }
1058 
rs_save_conn_data(struct rsocket * rs,struct rs_conn_data * conn)1059 static void rs_save_conn_data(struct rsocket *rs, struct rs_conn_data *conn)
1060 {
1061 	rs->remote_sgl.addr = be64toh((__force __be64)conn->target_sgl.addr);
1062 	rs->remote_sgl.length = be32toh((__force __be32)conn->target_sgl.length);
1063 	rs->remote_sgl.key = be32toh((__force __be32)conn->target_sgl.key);
1064 	rs->remote_sge = 1;
1065 	if ((rs_host_is_net() && !(conn->flags & RS_CONN_FLAG_NET)) ||
1066 	    (!rs_host_is_net() && (conn->flags & RS_CONN_FLAG_NET)))
1067 		rs->opts = RS_OPT_SWAP_SGL;
1068 
1069 	if (conn->flags & RS_CONN_FLAG_IOMAP) {
1070 		rs->remote_iomap.addr = rs->remote_sgl.addr +
1071 					sizeof(rs->remote_sgl) * rs->remote_sgl.length;
1072 		rs->remote_iomap.length = rs_scale_to_value(conn->target_iomap_size, 8);
1073 		rs->remote_iomap.key = rs->remote_sgl.key;
1074 	}
1075 
1076 	rs->target_sgl[0].addr = be64toh((__force __be64)conn->data_buf.addr);
1077 	rs->target_sgl[0].length = be32toh((__force __be32)conn->data_buf.length);
1078 	rs->target_sgl[0].key = be32toh((__force __be32)conn->data_buf.key);
1079 
1080 	rs->sseq_comp = be16toh(conn->credits);
1081 }
1082 
ds_init(struct rsocket * rs,int domain)1083 static int ds_init(struct rsocket *rs, int domain)
1084 {
1085 	rs->udp_sock = socket(domain, SOCK_DGRAM, 0);
1086 	if (rs->udp_sock < 0)
1087 		return rs->udp_sock;
1088 
1089 	rs->epfd = epoll_create(2);
1090 	if (rs->epfd < 0)
1091 		return rs->epfd;
1092 
1093 	return 0;
1094 }
1095 
ds_init_ep(struct rsocket * rs)1096 static int ds_init_ep(struct rsocket *rs)
1097 {
1098 	struct ds_smsg *msg;
1099 	int i, ret;
1100 
1101 	ds_set_qp_size(rs);
1102 
1103 	rs->sbuf = calloc(rs->sq_size, RS_SNDLOWAT);
1104 	if (!rs->sbuf)
1105 		return ERR(ENOMEM);
1106 
1107 	rs->dmsg = calloc(rs->rq_size + 1, sizeof(*rs->dmsg));
1108 	if (!rs->dmsg)
1109 		return ERR(ENOMEM);
1110 
1111 	rs->sqe_avail = rs->sq_size;
1112 	rs->rqe_avail = rs->rq_size;
1113 
1114 	rs->smsg_free = (struct ds_smsg *) rs->sbuf;
1115 	msg = rs->smsg_free;
1116 	for (i = 0; i < rs->sq_size - 1; i++) {
1117 		msg->next = (void *) msg + RS_SNDLOWAT;
1118 		msg = msg->next;
1119 	}
1120 	msg->next = NULL;
1121 
1122 	ret = rs_notify_svc(&udp_svc, rs, RS_SVC_ADD_DGRAM);
1123 	if (ret)
1124 		return ret;
1125 
1126 	rs->state = rs_readable | rs_writable;
1127 	return 0;
1128 }
1129 
rsocket(int domain,int type,int protocol)1130 int rsocket(int domain, int type, int protocol)
1131 {
1132 	struct rsocket *rs;
1133 	int index, ret;
1134 
1135 	if ((domain != AF_INET && domain != AF_INET6 && domain != AF_IB) ||
1136 	    ((type != SOCK_STREAM) && (type != SOCK_DGRAM)) ||
1137 	    (type == SOCK_STREAM && protocol && protocol != IPPROTO_TCP) ||
1138 	    (type == SOCK_DGRAM && protocol && protocol != IPPROTO_UDP))
1139 		return ERR(ENOTSUP);
1140 
1141 	rs_configure();
1142 	rs = rs_alloc(NULL, type);
1143 	if (!rs)
1144 		return ERR(ENOMEM);
1145 
1146 	if (type == SOCK_STREAM) {
1147 		ret = rdma_create_id(NULL, &rs->cm_id, rs, RDMA_PS_TCP);
1148 		if (ret)
1149 			goto err;
1150 
1151 		rs->cm_id->route.addr.src_addr.sa_family = domain;
1152 		index = rs->cm_id->channel->fd;
1153 	} else {
1154 		ret = ds_init(rs, domain);
1155 		if (ret)
1156 			goto err;
1157 
1158 		index = rs->udp_sock;
1159 	}
1160 
1161 	ret = rs_insert(rs, index);
1162 	if (ret < 0)
1163 		goto err;
1164 
1165 	return rs->index;
1166 
1167 err:
1168 	rs_free(rs);
1169 	return ret;
1170 }
1171 
rbind(int socket,const struct sockaddr * addr,socklen_t addrlen)1172 int rbind(int socket, const struct sockaddr *addr, socklen_t addrlen)
1173 {
1174 	struct rsocket *rs;
1175 	int ret;
1176 
1177 	rs = idm_lookup(&idm, socket);
1178 	if (!rs)
1179 		return ERR(EBADF);
1180 	if (rs->type == SOCK_STREAM) {
1181 		ret = rdma_bind_addr(rs->cm_id, (struct sockaddr *) addr);
1182 		if (!ret)
1183 			rs->state = rs_bound;
1184 	} else {
1185 		if (rs->state == rs_init) {
1186 			ret = ds_init_ep(rs);
1187 			if (ret)
1188 				return ret;
1189 		}
1190 		ret = bind(rs->udp_sock, addr, addrlen);
1191 	}
1192 	return ret;
1193 }
1194 
rlisten(int socket,int backlog)1195 int rlisten(int socket, int backlog)
1196 {
1197 	struct rsocket *rs;
1198 	int ret;
1199 
1200 	rs = idm_lookup(&idm, socket);
1201 	if (!rs)
1202 		return ERR(EBADF);
1203 
1204 	if (rs->state != rs_listening) {
1205 		ret = rdma_listen(rs->cm_id, backlog);
1206 		if (!ret)
1207 			rs->state = rs_listening;
1208 	} else {
1209 		ret = 0;
1210 	}
1211 	return ret;
1212 }
1213 
1214 /*
1215  * Nonblocking is usually not inherited between sockets, but we need to
1216  * inherit it here to establish the connection only.  This is needed to
1217  * prevent rdma_accept from blocking until the remote side finishes
1218  * establishing the connection.  If we were to allow rdma_accept to block,
1219  * then a single thread cannot establish a connection with itself, or
1220  * two threads which try to connect to each other can deadlock trying to
1221  * form a connection.
1222  *
1223  * Data transfers on the new socket remain blocking unless the user
1224  * specifies otherwise through rfcntl.
1225  */
raccept(int socket,struct sockaddr * addr,socklen_t * addrlen)1226 int raccept(int socket, struct sockaddr *addr, socklen_t *addrlen)
1227 {
1228 	struct rsocket *rs, *new_rs;
1229 	struct rdma_conn_param param;
1230 	struct rs_conn_data *creq, cresp;
1231 	int ret;
1232 
1233 	rs = idm_lookup(&idm, socket);
1234 	if (!rs)
1235 		return ERR(EBADF);
1236 	new_rs = rs_alloc(rs, rs->type);
1237 	if (!new_rs)
1238 		return ERR(ENOMEM);
1239 
1240 	ret = rdma_get_request(rs->cm_id, &new_rs->cm_id);
1241 	if (ret)
1242 		goto err;
1243 
1244 	ret = rs_insert(new_rs, new_rs->cm_id->channel->fd);
1245 	if (ret < 0)
1246 		goto err;
1247 
1248 	creq = (struct rs_conn_data *)
1249 	       (new_rs->cm_id->event->param.conn.private_data + rs_conn_data_offset(rs));
1250 	if (creq->version != 1) {
1251 		ret = ERR(ENOTSUP);
1252 		goto err;
1253 	}
1254 
1255 	if (rs->fd_flags & O_NONBLOCK)
1256 		fcntl(new_rs->cm_id->channel->fd, F_SETFL, O_NONBLOCK);
1257 
1258 	ret = rs_create_ep(new_rs);
1259 	if (ret)
1260 		goto err;
1261 
1262 	rs_save_conn_data(new_rs, creq);
1263 	param = new_rs->cm_id->event->param.conn;
1264 	rs_format_conn_data(new_rs, &cresp);
1265 	param.private_data = &cresp;
1266 	param.private_data_len = sizeof cresp;
1267 	ret = rdma_accept(new_rs->cm_id, &param);
1268 	if (!ret)
1269 		new_rs->state = rs_connect_rdwr;
1270 	else if (errno == EAGAIN || errno == EWOULDBLOCK)
1271 		new_rs->state = rs_accepting;
1272 	else
1273 		goto err;
1274 
1275 	if (addr && addrlen)
1276 		rgetpeername(new_rs->index, addr, addrlen);
1277 	return new_rs->index;
1278 
1279 err:
1280 	rs_free(new_rs);
1281 	return ret;
1282 }
1283 
/*
 * Drive the connection state machine of a stream rsocket.  rconnect uses
 * it for the active side.  The rs_accepting case finishes an accept that
 * raccept left pending.
 *
 * Each call performs as much of this sequence as can complete without
 * blocking, entering the switch at the step matching rs->state:
 *   address resolution -> route resolution (or a user-supplied path)
 *   -> rs_create_ep + rdma_connect -> save the peer's rs_conn_data
 *   -> rs_connect_rdwr
 * The goto labels let a step that completes immediately fall straight
 * into the next one within the same call.
 *
 * Address/route resolution that fails with ETIMEDOUT is retried while
 * rs->retries <= RS_CONN_RETRIES; the timeout doubles on every attempt.
 *
 * Returns 0 on success.  On failure returns nonzero with errno set:
 *  - EAGAIN/EWOULDBLOCK (a step is still in progress) is reported as
 *    EINPROGRESS, and the state is kept so a later call can resume;
 *  - any other error moves the socket to rs_connect_error and records
 *    errno in rs->err.
 */
rs_do_connect(struct rsocket * rs)1284 static int rs_do_connect(struct rsocket *rs)
1285 {
1286 	struct rdma_conn_param param;
1287 	struct rs_conn_private_data cdata;
1288 	struct rs_conn_data *creq, *cresp;
1289 	int to, ret;
1290 
1291 	switch (rs->state) {
1292 	case rs_init:
1293 	case rs_bound:
1294 resolve_addr:
		/* Timeout (ms) doubles with each retry: 1000 << retries. */
1295 		to = 1000 << rs->retries++;
1296 		ret = rdma_resolve_addr(rs->cm_id, NULL,
1297 					&rs->cm_id->route.addr.dst_addr, to);
1298 		if (!ret)
1299 			goto resolve_route;
1300 		if (errno == EAGAIN || errno == EWOULDBLOCK)
1301 			rs->state = rs_resolving_addr;
1302 		break;
1303 	case rs_resolving_addr:
1304 		ret = ucma_complete(rs->cm_id);
1305 		if (ret) {
			/* Resolution timed out: try again while retries remain. */
1306 			if (errno == ETIMEDOUT && rs->retries <= RS_CONN_RETRIES)
1307 				goto resolve_addr;
1308 			break;
1309 		}
1310 
		/* Address resolved: route resolution gets its own retry budget. */
1311 		rs->retries = 0;
1312 resolve_route:
1313 		to = 1000 << rs->retries++;
		/*
		 * A path record supplied by the user (rs->optval) is installed
		 * with rdma_set_option instead of resolving the route.  It is
		 * used for one attempt only and then freed.
		 */
1314 		if (rs->optval) {
1315 			ret = rdma_set_option(rs->cm_id,  RDMA_OPTION_IB,
1316 					      RDMA_OPTION_IB_PATH, rs->optval,
1317 					      rs->optlen);
1318 			free(rs->optval);
1319 			rs->optval = NULL;
1320 			if (!ret) {
1321 				rs->state = rs_resolving_route;
1322 				goto resolving_route;
1323 			}
1324 		} else {
1325 			ret = rdma_resolve_route(rs->cm_id, to);
1326 			if (!ret)
1327 				goto do_connect;
1328 		}
1329 		if (errno == EAGAIN || errno == EWOULDBLOCK)
1330 			rs->state = rs_resolving_route;
1331 		break;
1332 	case rs_resolving_route:
1333 resolving_route:
1334 		ret = ucma_complete(rs->cm_id);
1335 		if (ret) {
1336 			if (errno == ETIMEDOUT && rs->retries <= RS_CONN_RETRIES)
1337 				goto resolve_route;
1338 			break;
1339 		}
1340 do_connect:
1341 		ret = rs_create_ep(rs);
1342 		if (ret)
1343 			break;
1344 
1345 		memset(&param, 0, sizeof param);
		/*
		 * Our rs_conn_data is placed rs_conn_data_offset() bytes into
		 * cdata.  The private data sent starts at &cdata and covers that
		 * offset plus the rs_conn_data.
		 */
1346 		creq = (void *) &cdata + rs_conn_data_offset(rs);
1347 		rs_format_conn_data(rs, creq);
1348 		param.private_data = (void *) creq - rs_conn_data_offset(rs);
1349 		param.private_data_len = sizeof(*creq) + rs_conn_data_offset(rs);
1350 		param.flow_control = 1;
1351 		param.retry_count = 7;
1352 		param.rnr_retry_count = 7;
1353 		/* work-around: iWarp issues an RDMA read during connection setup, so allow one outstanding read */
1354 		if (rs->opts & RS_OPT_MSG_SEND)
1355 			param.initiator_depth = 1;
1356 		rs->retries = 0;
1357 
1358 		ret = rdma_connect(rs->cm_id, &param);
1359 		if (!ret)
1360 			goto connected;
1361 		if (errno == EAGAIN || errno == EWOULDBLOCK)
1362 			rs->state = rs_connecting;
1363 		break;
1364 	case rs_connecting:
1365 		ret = ucma_complete(rs->cm_id);
1366 		if (ret)
1367 			break;
1368 connected:
		/* The peer's reply carries its rs_conn_data at offset 0 of the private data. */
1369 		cresp = (struct rs_conn_data *) rs->cm_id->event->param.conn.private_data;
1370 		if (cresp->version != 1) {
1371 			ret = ERR(ENOTSUP);
1372 			break;
1373 		}
1374 
1375 		rs_save_conn_data(rs, cresp);
1376 		rs->state = rs_connect_rdwr;
1377 		break;
1378 	case rs_accepting:
		/*
		 * raccept may have made the CM channel non-blocking.  Restore
		 * blocking mode unless this rsocket's fd_flags request
		 * O_NONBLOCK.
		 */
1379 		if (!(rs->fd_flags & O_NONBLOCK))
1380 			fcntl(rs->cm_id->channel->fd, F_SETFL, 0);
1381 
1382 		ret = ucma_complete(rs->cm_id);
1383 		if (ret)
1384 			break;
1385 
1386 		rs->state = rs_connect_rdwr;
1387 		break;
1388 	default:
1389 		ret = ERR(EINVAL);
1390 		break;
1391 	}
1392 
	/* In-progress steps become EINPROGRESS; anything else is a hard failure. */
1393 	if (ret) {
1394 		if (errno == EAGAIN || errno == EWOULDBLOCK) {
1395 			errno = EINPROGRESS;
1396 		} else {
1397 			rs->state = rs_connect_error;
1398 			rs->err = errno;
1399 		}
1400 	}
1401 	return ret;
1402 }
1403 
rs_any_addr(const union socket_addr * addr)1404 static int rs_any_addr(const union socket_addr *addr)
1405 {
1406 	if (addr->sa.sa_family == AF_INET) {
1407 		return (addr->sin.sin_addr.s_addr == htobe32(INADDR_ANY) ||
1408 			addr->sin.sin_addr.s_addr == htobe32(INADDR_LOOPBACK));
1409 	} else {
1410 		return (!memcmp(&addr->sin6.sin6_addr, &in6addr_any, 16) ||
1411 			!memcmp(&addr->sin6.sin6_addr, &in6addr_loopback, 16));
1412 	}
1413 }
1414 
ds_get_src_addr(struct rsocket * rs,const struct sockaddr * dest_addr,socklen_t dest_len,union socket_addr * src_addr,socklen_t * src_len)1415 static int ds_get_src_addr(struct rsocket *rs,
1416 			   const struct sockaddr *dest_addr, socklen_t dest_len,
1417 			   union socket_addr *src_addr, socklen_t *src_len)
1418 {
1419 	int sock, ret;
1420 	__be16 port;
1421 
1422 	*src_len = sizeof(*src_addr);
1423 	ret = getsockname(rs->udp_sock, &src_addr->sa, src_len);
1424 	if (ret || !rs_any_addr(src_addr))
1425 		return ret;
1426 
1427 	port = src_addr->sin.sin_port;
1428 	sock = socket(dest_addr->sa_family, SOCK_DGRAM, 0);
1429 	if (sock < 0)
1430 		return sock;
1431 
1432 	ret = connect(sock, dest_addr, dest_len);
1433 	if (ret)
1434 		goto out;
1435 
1436 	*src_len = sizeof(*src_addr);
1437 	ret = getsockname(sock, &src_addr->sa, src_len);
1438 	src_addr->sin.sin_port = port;
1439 out:
1440 	close(sock);
1441 	return ret;
1442 }
1443 
ds_format_hdr(struct ds_header * hdr,union socket_addr * addr)1444 static void ds_format_hdr(struct ds_header *hdr, union socket_addr *addr)
1445 {
1446 	if (addr->sa.sa_family == AF_INET) {
1447 		hdr->version = 4;
1448 		hdr->length = DS_IPV4_HDR_LEN;
1449 		hdr->port = addr->sin.sin_port;
1450 		hdr->addr.ipv4 = addr->sin.sin_addr.s_addr;
1451 	} else {
1452 		hdr->version = 6;
1453 		hdr->length = DS_IPV6_HDR_LEN;
1454 		hdr->port = addr->sin6.sin6_port;
1455 		hdr->addr.ipv6.flowinfo= addr->sin6.sin6_flowinfo;
1456 		memcpy(&hdr->addr.ipv6.addr, &addr->sin6.sin6_addr, 16);
1457 	}
1458 }
1459 
/*
 * Initialize the QP's own destination entry (qp->dest) and register it in
 * the rsocket's dest_map, keyed by addr (the QP's local address).  This
 * lets datagrams addressed to that local address be sent through this QP.
 *
 * The address handle targets the LID of the QP's own port.
 *
 * Returns 0 on success; otherwise nonzero (errno set for the ENOMEM
 * cases).  tsearch returns NULL when it cannot allocate a tree node.  That
 * case is reported as ENOMEM, and the just-created AH is released, rather
 * than silently leaving the QP's address out of dest_map.  Without the
 * entry, ds_get_dest would later create a destination for this address
 * with no AH or QPN.
 */
static int ds_add_qp_dest(struct ds_qp *qp, union socket_addr *addr,
			  socklen_t addrlen)
{
	struct ibv_port_attr port_attr;
	struct ibv_ah_attr attr;
	int ret;

	memcpy(&qp->dest.addr, addr, addrlen);
	qp->dest.qp = qp;
	qp->dest.qpn = qp->cm_id->qp->qp_num;

	ret = ibv_query_port(qp->cm_id->verbs, qp->cm_id->port_num, &port_attr);
	if (ret)
		return ret;

	memset(&attr, 0, sizeof attr);
	attr.dlid = port_attr.lid;
	attr.port_num = qp->cm_id->port_num;
	qp->dest.ah = ibv_create_ah(qp->cm_id->pd, &attr);
	if (!qp->dest.ah)
		return ERR(ENOMEM);

	if (!tsearch(&qp->dest.addr, &qp->rs->dest_map, ds_compare_addr)) {
		ibv_destroy_ah(qp->dest.ah);
		qp->dest.ah = NULL;
		return ERR(ENOMEM);
	}
	return 0;
}
1485 
/*
 * Create a UD QP for a datagram rsocket, bound to src_addr, and link it
 * into the rsocket's QP list (returned through *new_qp).
 *
 * Steps, in order:
 *  1. create an RDMA_PS_UDP cm_id;
 *  2. format the datagram header for src_addr and bind the id;
 *  3. set up the QP buffers (ds_init_bufs) and create the CQ(s);
 *  4. create the UD QP and record the inline size actually granted;
 *  5. add the QP's own destination entry (ds_add_qp_dest);
 *  6. add the receive CQ channel to the rsocket's epoll set;
 *  7. post rq_size receives.
 *
 * Returns 0 on success.  If the initial allocation fails, returns ENOMEM
 * via ERR.  Any later failure releases the partially built QP with
 * ds_free_qp and passes back the failing step's result.
 */
static int ds_create_qp(struct rsocket *rs, union socket_addr *src_addr,
			socklen_t addrlen, struct ds_qp **new_qp)
{
	struct ibv_qp_init_attr init;
	struct epoll_event ev;
	struct ds_qp *dqp;
	int slot, err;

	dqp = calloc(1, sizeof(*dqp));
	if (!dqp)
		return ERR(ENOMEM);
	dqp->rs = rs;

	err = rdma_create_id(NULL, &dqp->cm_id, dqp, RDMA_PS_UDP);
	if (err)
		goto fail;

	ds_format_hdr(&dqp->hdr, src_addr);

	err = rdma_bind_addr(dqp->cm_id, &src_addr->sa);
	if (err)
		goto fail;

	err = ds_init_bufs(dqp);
	if (err)
		goto fail;

	err = rs_create_cq(rs, dqp->cm_id);
	if (err)
		goto fail;

	memset(&init, 0, sizeof init);
	init.qp_context = dqp;
	init.send_cq = dqp->cm_id->send_cq;
	init.recv_cq = dqp->cm_id->recv_cq;
	init.qp_type = IBV_QPT_UD;
	init.sq_sig_all = 1;
	init.cap.max_send_wr = rs->sq_size;
	init.cap.max_recv_wr = rs->rq_size;
	init.cap.max_send_sge = 1;
	init.cap.max_recv_sge = 2;
	init.cap.max_inline_data = rs->sq_inline;

	err = rdma_create_qp(dqp->cm_id, NULL, &init);
	if (err)
		goto fail;

	/* Keep the inline size reported back in the init attributes. */
	rs->sq_inline = init.cap.max_inline_data;

	err = ds_add_qp_dest(dqp, src_addr, addrlen);
	if (err)
		goto fail;

	ev.events = EPOLLIN;
	ev.data.ptr = dqp;
	err = epoll_ctl(rs->epfd, EPOLL_CTL_ADD,
			dqp->cm_id->recv_cq_channel->fd, &ev);
	if (err)
		goto fail;

	for (slot = 0; slot < rs->rq_size; slot++) {
		err = ds_post_recv(rs, dqp, slot * RS_SNDLOWAT);
		if (err)
			goto fail;
	}

	ds_insert_qp(rs, dqp);
	*new_qp = dqp;
	return 0;

fail:
	ds_free_qp(dqp);
	return err;
}
1556 
ds_get_qp(struct rsocket * rs,union socket_addr * src_addr,socklen_t addrlen,struct ds_qp ** qp)1557 static int ds_get_qp(struct rsocket *rs, union socket_addr *src_addr,
1558 		     socklen_t addrlen, struct ds_qp **qp)
1559 {
1560 	if (rs->qp_list) {
1561 		*qp = rs->qp_list;
1562 		do {
1563 			if (!ds_compare_addr(rdma_get_local_addr((*qp)->cm_id),
1564 					     src_addr))
1565 				return 0;
1566 
1567 			*qp = ds_next_qp(*qp);
1568 		} while (*qp != rs->qp_list);
1569 	}
1570 
1571 	return ds_create_qp(rs, src_addr, addrlen, qp);
1572 }
1573 
ds_get_dest(struct rsocket * rs,const struct sockaddr * addr,socklen_t addrlen,struct ds_dest ** dest)1574 static int ds_get_dest(struct rsocket *rs, const struct sockaddr *addr,
1575 		       socklen_t addrlen, struct ds_dest **dest)
1576 {
1577 	union socket_addr src_addr;
1578 	socklen_t src_len;
1579 	struct ds_qp *qp;
1580 	struct ds_dest **tdest, *new_dest;
1581 	int ret = 0;
1582 
1583 	fastlock_acquire(&rs->map_lock);
1584 	tdest = tfind(addr, &rs->dest_map, ds_compare_addr);
1585 	if (tdest)
1586 		goto found;
1587 
1588 	ret = ds_get_src_addr(rs, addr, addrlen, &src_addr, &src_len);
1589 	if (ret)
1590 		goto out;
1591 
1592 	ret = ds_get_qp(rs, &src_addr, src_len, &qp);
1593 	if (ret)
1594 		goto out;
1595 
1596 	tdest = tfind(addr, &rs->dest_map, ds_compare_addr);
1597 	if (!tdest) {
1598 		new_dest = calloc(1, sizeof(*new_dest));
1599 		if (!new_dest) {
1600 			ret = ERR(ENOMEM);
1601 			goto out;
1602 		}
1603 
1604 		memcpy(&new_dest->addr, addr, addrlen);
1605 		new_dest->qp = qp;
1606 		tdest = tsearch(&new_dest->addr, &rs->dest_map, ds_compare_addr);
1607 	}
1608 
1609 found:
1610 	*dest = *tdest;
1611 out:
1612 	fastlock_release(&rs->map_lock);
1613 	return ret;
1614 }
1615 
rconnect(int socket,const struct sockaddr * addr,socklen_t addrlen)1616 int rconnect(int socket, const struct sockaddr *addr, socklen_t addrlen)
1617 {
1618 	struct rsocket *rs;
1619 	int ret;
1620 
1621 	rs = idm_lookup(&idm, socket);
1622 	if (!rs)
1623 		return ERR(EBADF);
1624 	if (rs->type == SOCK_STREAM) {
1625 		memcpy(&rs->cm_id->route.addr.dst_addr, addr, addrlen);
1626 		ret = rs_do_connect(rs);
1627 	} else {
1628 		if (rs->state == rs_init) {
1629 			ret = ds_init_ep(rs);
1630 			if (ret)
1631 				return ret;
1632 		}
1633 
1634 		fastlock_acquire(&rs->slock);
1635 		ret = connect(rs->udp_sock, addr, addrlen);
1636 		if (!ret)
1637 			ret = ds_get_dest(rs, addr, addrlen, &rs->conn_dest);
1638 		fastlock_release(&rs->slock);
1639 	}
1640 	return ret;
1641 }
1642 
rs_get_ctrl_buf(struct rsocket * rs)1643 static void *rs_get_ctrl_buf(struct rsocket *rs)
1644 {
1645 	return rs->sbuf + rs->sbuf_size +
1646 		RS_MAX_CTRL_MSG * (rs->ctrl_seqno & (RS_QP_CTRL_SIZE - 1));
1647 }
1648 
rs_post_msg(struct rsocket * rs,uint32_t msg)1649 static int rs_post_msg(struct rsocket *rs, uint32_t msg)
1650 {
1651 	struct ibv_send_wr wr, *bad;
1652 	struct ibv_sge sge;
1653 
1654 	wr.wr_id = rs_send_wr_id(msg);
1655 	wr.next = NULL;
1656 	if (!(rs->opts & RS_OPT_MSG_SEND)) {
1657 		wr.sg_list = NULL;
1658 		wr.num_sge = 0;
1659 		wr.opcode = IBV_WR_RDMA_WRITE_WITH_IMM;
1660 		wr.send_flags = 0;
1661 		wr.imm_data = htobe32(msg);
1662 	} else {
1663 		sge.addr = (uintptr_t) &msg;
1664 		sge.lkey = 0;
1665 		sge.length = sizeof msg;
1666 		wr.sg_list = &sge;
1667 		wr.num_sge = 1;
1668 		wr.opcode = IBV_WR_SEND;
1669 		wr.send_flags = IBV_SEND_INLINE;
1670 	}
1671 
1672 	return rdma_seterrno(ibv_post_send(rs->cm_id->qp, &wr, &bad));
1673 }
1674 
rs_post_write(struct rsocket * rs,struct ibv_sge * sgl,int nsge,uint32_t wr_data,int flags,uint64_t addr,uint32_t rkey)1675 static int rs_post_write(struct rsocket *rs,
1676 			 struct ibv_sge *sgl, int nsge,
1677 			 uint32_t wr_data, int flags,
1678 			 uint64_t addr, uint32_t rkey)
1679 {
1680 	struct ibv_send_wr wr, *bad;
1681 
1682 	wr.wr_id = rs_send_wr_id(wr_data);
1683 	wr.next = NULL;
1684 	wr.sg_list = sgl;
1685 	wr.num_sge = nsge;
1686 	wr.opcode = IBV_WR_RDMA_WRITE;
1687 	wr.send_flags = flags;
1688 	wr.wr.rdma.remote_addr = addr;
1689 	wr.wr.rdma.rkey = rkey;
1690 
1691 	return rdma_seterrno(ibv_post_send(rs->cm_id->qp, &wr, &bad));
1692 }
1693 
rs_post_write_msg(struct rsocket * rs,struct ibv_sge * sgl,int nsge,uint32_t msg,int flags,uint64_t addr,uint32_t rkey)1694 static int rs_post_write_msg(struct rsocket *rs,
1695 			 struct ibv_sge *sgl, int nsge,
1696 			 uint32_t msg, int flags,
1697 			 uint64_t addr, uint32_t rkey)
1698 {
1699 	struct ibv_send_wr wr, *bad;
1700 	struct ibv_sge sge;
1701 	int ret;
1702 
1703 	wr.next = NULL;
1704 	if (!(rs->opts & RS_OPT_MSG_SEND)) {
1705 		wr.wr_id = rs_send_wr_id(msg);
1706 		wr.sg_list = sgl;
1707 		wr.num_sge = nsge;
1708 		wr.opcode = IBV_WR_RDMA_WRITE_WITH_IMM;
1709 		wr.send_flags = flags;
1710 		wr.imm_data = htobe32(msg);
1711 		wr.wr.rdma.remote_addr = addr;
1712 		wr.wr.rdma.rkey = rkey;
1713 
1714 		return rdma_seterrno(ibv_post_send(rs->cm_id->qp, &wr, &bad));
1715 	} else {
1716 		ret = rs_post_write(rs, sgl, nsge, msg, flags, addr, rkey);
1717 		if (!ret) {
1718 			wr.wr_id = rs_send_wr_id(rs_msg_set(rs_msg_op(msg), 0)) |
1719 				   RS_WR_ID_FLAG_MSG_SEND;
1720 			sge.addr = (uintptr_t) &msg;
1721 			sge.lkey = 0;
1722 			sge.length = sizeof msg;
1723 			wr.sg_list = &sge;
1724 			wr.num_sge = 1;
1725 			wr.opcode = IBV_WR_SEND;
1726 			wr.send_flags = IBV_SEND_INLINE;
1727 
1728 			ret = rdma_seterrno(ibv_post_send(rs->cm_id->qp, &wr, &bad));
1729 		}
1730 		return ret;
1731 	}
1732 }
1733 
ds_post_send(struct rsocket * rs,struct ibv_sge * sge,uint32_t wr_data)1734 static int ds_post_send(struct rsocket *rs, struct ibv_sge *sge,
1735 			uint32_t wr_data)
1736 {
1737 	struct ibv_send_wr wr, *bad;
1738 
1739 	wr.wr_id = rs_send_wr_id(wr_data);
1740 	wr.next = NULL;
1741 	wr.sg_list = sge;
1742 	wr.num_sge = 1;
1743 	wr.opcode = IBV_WR_SEND;
1744 	wr.send_flags = (sge->length <= rs->sq_inline) ? IBV_SEND_INLINE : 0;
1745 	wr.wr.ud.ah = rs->conn_dest->ah;
1746 	wr.wr.ud.remote_qpn = rs->conn_dest->qpn;
1747 	wr.wr.ud.remote_qkey = RDMA_UDP_QKEY;
1748 
1749 	return rdma_seterrno(ibv_post_send(rs->conn_dest->qp->cm_id->qp, &wr, &bad));
1750 }
1751 
1752 /*
1753  * Update target SGE before sending data.  Otherwise the remote side may
1754  * update the entry before we do.
1755  */
rs_write_data(struct rsocket * rs,struct ibv_sge * sgl,int nsge,uint32_t length,int flags)1756 static int rs_write_data(struct rsocket *rs,
1757 			 struct ibv_sge *sgl, int nsge,
1758 			 uint32_t length, int flags)
1759 {
1760 	uint64_t addr;
1761 	uint32_t rkey;
1762 
1763 	rs->sseq_no++;
1764 	rs->sqe_avail--;
1765 	if (rs->opts & RS_OPT_MSG_SEND)
1766 		rs->sqe_avail--;
1767 	rs->sbuf_bytes_avail -= length;
1768 
1769 	addr = rs->target_sgl[rs->target_sge].addr;
1770 	rkey = rs->target_sgl[rs->target_sge].key;
1771 
1772 	rs->target_sgl[rs->target_sge].addr += length;
1773 	rs->target_sgl[rs->target_sge].length -= length;
1774 
1775 	if (!rs->target_sgl[rs->target_sge].length) {
1776 		if (++rs->target_sge == RS_SGL_SIZE)
1777 			rs->target_sge = 0;
1778 	}
1779 
1780 	return rs_post_write_msg(rs, sgl, nsge, rs_msg_set(RS_OP_DATA, length),
1781 				 flags, addr, rkey);
1782 }
1783 
rs_write_direct(struct rsocket * rs,struct rs_iomap * iom,uint64_t offset,struct ibv_sge * sgl,int nsge,uint32_t length,int flags)1784 static int rs_write_direct(struct rsocket *rs, struct rs_iomap *iom, uint64_t offset,
1785 			   struct ibv_sge *sgl, int nsge, uint32_t length, int flags)
1786 {
1787 	uint64_t addr;
1788 
1789 	rs->sqe_avail--;
1790 	rs->sbuf_bytes_avail -= length;
1791 
1792 	addr = iom->sge.addr + offset - iom->offset;
1793 	return rs_post_write(rs, sgl, nsge, rs_msg_set(RS_OP_WRITE, length),
1794 			     flags, addr, iom->sge.key);
1795 }
1796 
rs_write_iomap(struct rsocket * rs,struct rs_iomap_mr * iomr,struct ibv_sge * sgl,int nsge,int flags)1797 static int rs_write_iomap(struct rsocket *rs, struct rs_iomap_mr *iomr,
1798 			  struct ibv_sge *sgl, int nsge, int flags)
1799 {
1800 	uint64_t addr;
1801 
1802 	rs->sseq_no++;
1803 	rs->sqe_avail--;
1804 	if (rs->opts & RS_OPT_MSG_SEND)
1805 		rs->sqe_avail--;
1806 	rs->sbuf_bytes_avail -= sizeof(struct rs_iomap);
1807 
1808 	addr = rs->remote_iomap.addr + iomr->index * sizeof(struct rs_iomap);
1809 	return rs_post_write_msg(rs, sgl, nsge, rs_msg_set(RS_OP_IOMAP_SGL, iomr->index),
1810 				 flags, addr, rs->remote_iomap.key);
1811 }
1812 
rs_sbuf_left(struct rsocket * rs)1813 static uint32_t rs_sbuf_left(struct rsocket *rs)
1814 {
1815 	return (uint32_t) (((uint64_t) (uintptr_t) &rs->sbuf[rs->sbuf_size]) -
1816 			   rs->ssgl[0].addr);
1817 }
1818 
rs_send_credits(struct rsocket * rs)1819 static void rs_send_credits(struct rsocket *rs)
1820 {
1821 	struct ibv_sge ibsge;
1822 	struct rs_sge sge, *sge_buf;
1823 	int flags;
1824 
1825 	rs->ctrl_seqno++;
1826 	rs->rseq_comp = rs->rseq_no + (rs->rq_size >> 1);
1827 	if (rs->rbuf_bytes_avail >= (rs->rbuf_size >> 1)) {
1828 		if (rs->opts & RS_OPT_MSG_SEND)
1829 			rs->ctrl_seqno++;
1830 
1831 		if (!(rs->opts & RS_OPT_SWAP_SGL)) {
1832 			sge.addr = (uintptr_t) &rs->rbuf[rs->rbuf_free_offset];
1833 			sge.key = rs->rmr->rkey;
1834 			sge.length = rs->rbuf_size >> 1;
1835 		} else {
1836 			sge.addr = bswap_64((uintptr_t) &rs->rbuf[rs->rbuf_free_offset]);
1837 			sge.key = bswap_32(rs->rmr->rkey);
1838 			sge.length = bswap_32(rs->rbuf_size >> 1);
1839 		}
1840 
1841 		if (rs->sq_inline < sizeof sge) {
1842 			sge_buf = rs_get_ctrl_buf(rs);
1843 			memcpy(sge_buf, &sge, sizeof sge);
1844 			ibsge.addr = (uintptr_t) sge_buf;
1845 			ibsge.lkey = rs->smr->lkey;
1846 			flags = 0;
1847 		} else {
1848 			ibsge.addr = (uintptr_t) &sge;
1849 			ibsge.lkey = 0;
1850 			flags = IBV_SEND_INLINE;
1851 		}
1852 		ibsge.length = sizeof(sge);
1853 
1854 		rs_post_write_msg(rs, &ibsge, 1,
1855 			rs_msg_set(RS_OP_SGL, rs->rseq_no + rs->rq_size), flags,
1856 			rs->remote_sgl.addr + rs->remote_sge * sizeof(struct rs_sge),
1857 			rs->remote_sgl.key);
1858 
1859 		rs->rbuf_bytes_avail -= rs->rbuf_size >> 1;
1860 		rs->rbuf_free_offset += rs->rbuf_size >> 1;
1861 		if (rs->rbuf_free_offset >= rs->rbuf_size)
1862 			rs->rbuf_free_offset = 0;
1863 		if (++rs->remote_sge == rs->remote_sgl.length)
1864 			rs->remote_sge = 0;
1865 	} else {
1866 		rs_post_msg(rs, rs_msg_set(RS_OP_SGL, rs->rseq_no + rs->rq_size));
1867 	}
1868 }
1869 
rs_ctrl_avail(struct rsocket * rs)1870 static inline int rs_ctrl_avail(struct rsocket *rs)
1871 {
1872 	return rs->ctrl_seqno != rs->ctrl_max_seqno;
1873 }
1874 
1875 /* Protocols that do not support RDMA write with immediate may require 2 msgs */
rs_2ctrl_avail(struct rsocket * rs)1876 static inline int rs_2ctrl_avail(struct rsocket *rs)
1877 {
1878 	return (int)((rs->ctrl_seqno + 1) - rs->ctrl_max_seqno) < 0;
1879 }
1880 
rs_give_credits(struct rsocket * rs)1881 static int rs_give_credits(struct rsocket *rs)
1882 {
1883 	if (!(rs->opts & RS_OPT_MSG_SEND)) {
1884 		return ((rs->rbuf_bytes_avail >= (rs->rbuf_size >> 1)) ||
1885 			((short) ((short) rs->rseq_no - (short) rs->rseq_comp) >= 0)) &&
1886 		       rs_ctrl_avail(rs) && (rs->state & rs_connected);
1887 	} else {
1888 		return ((rs->rbuf_bytes_avail >= (rs->rbuf_size >> 1)) ||
1889 			((short) ((short) rs->rseq_no - (short) rs->rseq_comp) >= 0)) &&
1890 		       rs_2ctrl_avail(rs) && (rs->state & rs_connected);
1891 	}
1892 }
1893 
/* Send a credit update to the peer when rs_give_credits says one is due. */
static void rs_update_credits(struct rsocket *rs)
{
	if (!rs_give_credits(rs))
		return;
	rs_send_credits(rs);
}
1899 
/*
 * Drain completions from rs->cm_id->recv_cq.  This function handles both
 * receive and send completions from that CQ.
 *
 * Receive completions:
 *  - Failed receives are skipped and not counted, so they are not
 *    reposted below.
 *  - For successful receives, the control word is read from the immediate
 *    data.  When there is no immediate data, it comes from the receive's
 *    uint32_t slot located after the receive buffer (rbuf + rbuf_size).
 *  - Handling by op:
 *      RS_OP_SGL       - new send credit limit (sseq_comp)
 *      RS_OP_IOMAP_SGL - ignored
 *      RS_OP_CTRL      - DISCONNECT, or SHUTDOWN (half-close if still
 *                        writable, otherwise disconnect); both disconnect
 *                        paths return 0 at once, without reposting
 *      RS_OP_WRITE     - ignored
 *      any other op    - queued in the rmsg ring (rq_size + 1 entries)
 *
 * Send completions release what was reserved when the request was posted:
 *  - a control slot (ctrl_max_seqno) for RS_OP_SGL / RS_OP_CTRL;
 *  - otherwise a send queue entry plus send-buffer bytes.
 * A failed send while connected moves the socket to rs_error with EIO.
 *
 * While connected, one receive is reposted per counted receive completion.
 *
 * Returns 0, or nonzero if polling or reposting failed.  While connected,
 * such a failure also moves the socket to rs_error with rs->err = errno.
 * NOTE(review): ibv_poll_cq reports failure through a negative return and
 * is not documented to set errno, so rs->err may be stale in that case.
 */
rs_poll_cq(struct rsocket * rs)1900 static int rs_poll_cq(struct rsocket *rs)
1901 {
1902 	struct ibv_wc wc;
1903 	uint32_t msg;
1904 	int ret, rcnt = 0;
1905 
1906 	while ((ret = ibv_poll_cq(rs->cm_id->recv_cq, 1, &wc)) > 0) {
1907 		if (rs_wr_is_recv(wc.wr_id)) {
1908 			if (wc.status != IBV_WC_SUCCESS)
1909 				continue;
1910 			rcnt++;
1911 
1912 			if (wc.wc_flags & IBV_WC_WITH_IMM) {
1913 				msg = be32toh(wc.imm_data);
1914 			} else {
				/* No immediate: the word arrived as a SEND into this receive's slot. */
1915 				msg = ((uint32_t *) (rs->rbuf + rs->rbuf_size))
1916 					[rs_wr_data(wc.wr_id)];
1917 
1918 			}
1919 			switch (rs_msg_op(msg)) {
1920 			case RS_OP_SGL:
1921 				rs->sseq_comp = (uint16_t) rs_msg_data(msg);
1922 				break;
1923 			case RS_OP_IOMAP_SGL:
1924 				/* The iomap was updated; nothing to do here. */
1925 				break;
1926 			case RS_OP_CTRL:
1927 				if (rs_msg_data(msg) == RS_CTRL_DISCONNECT) {
1928 					rs->state = rs_disconnected;
1929 					return 0;
1930 				} else if (rs_msg_data(msg) == RS_CTRL_SHUTDOWN) {
					/* Peer stopped sending: keep writing if we still may. */
1931 					if (rs->state & rs_writable) {
1932 						rs->state &= ~rs_readable;
1933 					} else {
1934 						rs->state = rs_disconnected;
1935 						return 0;
1936 					}
1937 				}
1938 				break;
1939 			case RS_OP_WRITE:
1940 				/* Not expected as a receive; ignored. */
1941 				break;
1942 			default:
1943 				rs->rmsg[rs->rmsg_tail].op = rs_msg_op(msg);
1944 				rs->rmsg[rs->rmsg_tail].data = rs_msg_data(msg);
1945 				if (++rs->rmsg_tail == rs->rq_size + 1)
1946 					rs->rmsg_tail = 0;
1947 				break;
1948 			}
1949 		} else {
			/* Send completion: the op/data posted is recovered from wr_id. */
1950 			switch  (rs_msg_op(rs_wr_data(wc.wr_id))) {
1951 			case RS_OP_SGL:
1952 				rs->ctrl_max_seqno++;
1953 				break;
1954 			case RS_OP_CTRL:
1955 				rs->ctrl_max_seqno++;
1956 				if (rs_msg_data(rs_wr_data(wc.wr_id)) == RS_CTRL_DISCONNECT)
1957 					rs->state = rs_disconnected;
1958 				break;
1959 			case RS_OP_IOMAP_SGL:
1960 				rs->sqe_avail++;
				/* The extra MSG_SEND notification carried no iomap bytes. */
1961 				if (!rs_wr_is_msg_send(wc.wr_id))
1962 					rs->sbuf_bytes_avail += sizeof(struct rs_iomap);
1963 				break;
1964 			default:
1965 				rs->sqe_avail++;
1966 				rs->sbuf_bytes_avail += rs_msg_data(rs_wr_data(wc.wr_id));
1967 				break;
1968 			}
1969 			if (wc.status != IBV_WC_SUCCESS && (rs->state & rs_connected)) {
1970 				rs->state = rs_error;
1971 				rs->err = EIO;
1972 			}
1973 		}
1974 	}
1975 
	/* Replace the receive buffers consumed above (only while connected). */
1976 	if (rs->state & rs_connected) {
1977 		while (!ret && rcnt--)
1978 			ret = rs_post_recv(rs);
1979 
1980 		if (ret) {
1981 			rs->state = rs_error;
1982 			rs->err = errno;
1983 		}
1984 	}
1985 	return ret;
1986 }
1987 
rs_get_cq_event(struct rsocket * rs)1988 static int rs_get_cq_event(struct rsocket *rs)
1989 {
1990 	struct ibv_cq *cq;
1991 	void *context;
1992 	int ret;
1993 
1994 	if (!rs->cq_armed)
1995 		return 0;
1996 
1997 	ret = ibv_get_cq_event(rs->cm_id->recv_cq_channel, &cq, &context);
1998 	if (!ret) {
1999 		if (++rs->unack_cqe >= rs->sq_size + rs->rq_size) {
2000 			ibv_ack_cq_events(rs->cm_id->recv_cq, rs->unack_cqe);
2001 			rs->unack_cqe = 0;
2002 		}
2003 		rs->cq_armed = 0;
2004 	} else if (!(errno == EAGAIN || errno == EINTR)) {
2005 		rs->state = rs_error;
2006 	}
2007 
2008 	return ret;
2009 }
2010 
2011 /*
2012  * Although we serialize rsend and rrecv calls with respect to themselves,
2013  * both calls may run simultaneously and need to poll the CQ for completions.
2014  * We need to serialize access to the CQ, but rsend and rrecv need to
2015  * allow each other to make forward progress.
2016  *
2017  * For example, rsend may need to wait for credits from the remote side,
2018  * which could be stalled until the remote process calls rrecv.  This should
2019  * not block rrecv from receiving data from the remote side however.
2020  *
2021  * We handle this by using two locks.  The cq_lock protects against polling
2022  * the CQ and processing completions.  The cq_wait_lock serializes access to
2023  * waiting on the CQ.
2024  */
rs_process_cq(struct rsocket * rs,int nonblock,int (* test)(struct rsocket * rs))2025 static int rs_process_cq(struct rsocket *rs, int nonblock, int (*test)(struct rsocket *rs))
2026 {
2027 	int ret;
2028 
2029 	fastlock_acquire(&rs->cq_lock);
2030 	do {
2031 		rs_update_credits(rs);
2032 		ret = rs_poll_cq(rs);
2033 		if (test(rs)) {
2034 			ret = 0;
2035 			break;
2036 		} else if (ret) {
2037 			break;
2038 		} else if (nonblock) {
2039 			ret = ERR(EWOULDBLOCK);
2040 		} else if (!rs->cq_armed) {
2041 			ibv_req_notify_cq(rs->cm_id->recv_cq, 0);
2042 			rs->cq_armed = 1;
2043 		} else {
2044 			rs_update_credits(rs);
2045 			fastlock_acquire(&rs->cq_wait_lock);
2046 			fastlock_release(&rs->cq_lock);
2047 
2048 			ret = rs_get_cq_event(rs);
2049 			fastlock_release(&rs->cq_wait_lock);
2050 			fastlock_acquire(&rs->cq_lock);
2051 		}
2052 	} while (!ret);
2053 
2054 	rs_update_credits(rs);
2055 	fastlock_release(&rs->cq_lock);
2056 	return ret;
2057 }
2058 
rs_get_comp(struct rsocket * rs,int nonblock,int (* test)(struct rsocket * rs))2059 static int rs_get_comp(struct rsocket *rs, int nonblock, int (*test)(struct rsocket *rs))
2060 {
2061 	struct timeval s, e;
2062 	uint32_t poll_time = 0;
2063 	int ret;
2064 
2065 	do {
2066 		ret = rs_process_cq(rs, 1, test);
2067 		if (!ret || nonblock || errno != EWOULDBLOCK)
2068 			return ret;
2069 
2070 		if (!poll_time)
2071 			gettimeofday(&s, NULL);
2072 
2073 		gettimeofday(&e, NULL);
2074 		poll_time = (e.tv_sec - s.tv_sec) * 1000000 +
2075 			    (e.tv_usec - s.tv_usec) + 1;
2076 	} while (poll_time <= polling_time);
2077 
2078 	ret = rs_process_cq(rs, 0, test);
2079 	return ret;
2080 }
2081 
ds_valid_recv(struct ds_qp * qp,struct ibv_wc * wc)2082 static int ds_valid_recv(struct ds_qp *qp, struct ibv_wc *wc)
2083 {
2084 	struct ds_header *hdr;
2085 
2086 	hdr = (struct ds_header *) (qp->rbuf + rs_wr_data(wc->wr_id));
2087 	return ((wc->byte_len >= sizeof(struct ibv_grh) + DS_IPV4_HDR_LEN) &&
2088 		((hdr->version == 4 && hdr->length == DS_IPV4_HDR_LEN) ||
2089 		 (hdr->version == 6 && hdr->length == DS_IPV6_HDR_LEN)));
2090 }
2091 
2092 /*
 * Poll the receive CQs of all QPs belonging to a datagram rsocket.
 *
 * Receive completions are queued in rs->dmsg (a ring of rq_size + 1
 * entries) only when all of these hold:
 *  - there is room (rqe_avail);
 *  - the completion succeeded;
 *  - ds_valid_recv accepts the header.
 * Any other receive is dropped by reposting its buffer immediately.  Send
 * completions return their ds_smsg buffer to the smsg_free list.
 *
 * To limit drops, polling stops as soon as there is no room to store a
 * receive while send buffers are still available.
 *
 * For fairness the QPs are polled round robin, one completion per QP per
 * pass.  When stopping early, rs->qp_list is left at the next QP so the
 * next call resumes there.  Otherwise passes repeat until a pass finds no
 * completions.
 */
ds_poll_cqs(struct rsocket * rs)2099 static void ds_poll_cqs(struct rsocket *rs)
2100 {
2101 	struct ds_qp *qp;
2102 	struct ds_smsg *smsg;
2103 	struct ds_rmsg *rmsg;
2104 	struct ibv_wc wc;
2105 	int ret, cnt;
2106 
2107 	if (!(qp = rs->qp_list))
2108 		return;
2109 
2110 	do {
2111 		cnt = 0;
2112 		do {
2113 			ret = ibv_poll_cq(qp->cm_id->recv_cq, 1, &wc);
			/* Nothing (or an error) on this QP: move on; continue re-tests the lap condition. */
2114 			if (ret <= 0) {
2115 				qp = ds_next_qp(qp);
2116 				continue;
2117 			}
2118 
2119 			if (rs_wr_is_recv(wc.wr_id)) {
2120 				if (rs->rqe_avail && wc.status == IBV_WC_SUCCESS &&
2121 				    ds_valid_recv(qp, &wc)) {
2122 					rs->rqe_avail--;
2123 					rmsg = &rs->dmsg[rs->rmsg_tail];
2124 					rmsg->qp = qp;
2125 					rmsg->offset = rs_wr_data(wc.wr_id);
					/* Stored length excludes the GRH. */
2126 					rmsg->length = wc.byte_len - sizeof(struct ibv_grh);
2127 					if (++rs->rmsg_tail == rs->rq_size + 1)
2128 						rs->rmsg_tail = 0;
2129 				} else {
					/* Drop: hand the buffer straight back to the QP. */
2130 					ds_post_recv(rs, qp, rs_wr_data(wc.wr_id));
2131 				}
2132 			} else {
				/* Send done: its RS_SNDLOWAT slot in sbuf is free again. */
2133 				smsg = (struct ds_smsg *) (rs->sbuf + rs_wr_data(wc.wr_id));
2134 				smsg->next = rs->smsg_free;
2135 				rs->smsg_free = smsg;
2136 				rs->sqe_avail++;
2137 			}
2138 
2139 			qp = ds_next_qp(qp);
			/* No room for receives and no need for send buffers: stop, resume at qp next time. */
2140 			if (!rs->rqe_avail && rs->sqe_avail) {
2141 				rs->qp_list = qp;
2142 				return;
2143 			}
2144 			cnt++;
2145 		} while (qp != rs->qp_list);
2146 	} while (cnt);
2147 }
2148 
ds_req_notify_cqs(struct rsocket * rs)2149 static void ds_req_notify_cqs(struct rsocket *rs)
2150 {
2151 	struct ds_qp *qp;
2152 
2153 	if (!(qp = rs->qp_list))
2154 		return;
2155 
2156 	do {
2157 		if (!qp->cq_armed) {
2158 			ibv_req_notify_cq(qp->cm_id->recv_cq, 0);
2159 			qp->cq_armed = 1;
2160 		}
2161 		qp = ds_next_qp(qp);
2162 	} while (qp != rs->qp_list);
2163 }
2164 
ds_get_cq_event(struct rsocket * rs)2165 static int ds_get_cq_event(struct rsocket *rs)
2166 {
2167 	struct epoll_event event;
2168 	struct ds_qp *qp;
2169 	struct ibv_cq *cq;
2170 	void *context;
2171 	int ret;
2172 
2173 	if (!rs->cq_armed)
2174 		return 0;
2175 
2176 	ret = epoll_wait(rs->epfd, &event, 1, -1);
2177 	if (ret <= 0)
2178 		return ret;
2179 
2180 	qp = event.data.ptr;
2181 	ret = ibv_get_cq_event(qp->cm_id->recv_cq_channel, &cq, &context);
2182 	if (!ret) {
2183 		ibv_ack_cq_events(qp->cm_id->recv_cq, 1);
2184 		qp->cq_armed = 0;
2185 		rs->cq_armed = 0;
2186 	}
2187 
2188 	return ret;
2189 }
2190 
ds_process_cqs(struct rsocket * rs,int nonblock,int (* test)(struct rsocket * rs))2191 static int ds_process_cqs(struct rsocket *rs, int nonblock, int (*test)(struct rsocket *rs))
2192 {
2193 	int ret = 0;
2194 
2195 	fastlock_acquire(&rs->cq_lock);
2196 	do {
2197 		ds_poll_cqs(rs);
2198 		if (test(rs)) {
2199 			ret = 0;
2200 			break;
2201 		} else if (nonblock) {
2202 			ret = ERR(EWOULDBLOCK);
2203 		} else if (!rs->cq_armed) {
2204 			ds_req_notify_cqs(rs);
2205 			rs->cq_armed = 1;
2206 		} else {
2207 			fastlock_acquire(&rs->cq_wait_lock);
2208 			fastlock_release(&rs->cq_lock);
2209 
2210 			ret = ds_get_cq_event(rs);
2211 			fastlock_release(&rs->cq_wait_lock);
2212 			fastlock_acquire(&rs->cq_lock);
2213 		}
2214 	} while (!ret);
2215 
2216 	fastlock_release(&rs->cq_lock);
2217 	return ret;
2218 }
2219 
ds_get_comp(struct rsocket * rs,int nonblock,int (* test)(struct rsocket * rs))2220 static int ds_get_comp(struct rsocket *rs, int nonblock, int (*test)(struct rsocket *rs))
2221 {
2222 	struct timeval s, e;
2223 	uint32_t poll_time = 0;
2224 	int ret;
2225 
2226 	do {
2227 		ret = ds_process_cqs(rs, 1, test);
2228 		if (!ret || nonblock || errno != EWOULDBLOCK)
2229 			return ret;
2230 
2231 		if (!poll_time)
2232 			gettimeofday(&s, NULL);
2233 
2234 		gettimeofday(&e, NULL);
2235 		poll_time = (e.tv_sec - s.tv_sec) * 1000000 +
2236 			    (e.tv_usec - s.tv_usec) + 1;
2237 	} while (poll_time <= polling_time);
2238 
2239 	ret = ds_process_cqs(rs, 0, test);
2240 	return ret;
2241 }
2242 
rs_nonblocking(struct rsocket * rs,int flags)2243 static int rs_nonblocking(struct rsocket *rs, int flags)
2244 {
2245 	return (rs->fd_flags & O_NONBLOCK) || (flags & MSG_DONTWAIT);
2246 }
2247 
rs_is_cq_armed(struct rsocket * rs)2248 static int rs_is_cq_armed(struct rsocket *rs)
2249 {
2250 	return rs->cq_armed;
2251 }
2252 
rs_poll_all(struct rsocket * rs)2253 static int rs_poll_all(struct rsocket *rs)
2254 {
2255 	return 1;
2256 }
2257 
2258 /*
 * We use hardware flow control to prevent overrunning the remote
 * receive queue.  However, data transfers still require space in
 * the remote rmsg queue, or we risk losing notification that data
 * has been transferred.
2263  *
2264  * Be careful with race conditions in the check below.  The target SGL
2265  * may be updated by a remote RDMA write.
2266  */
rs_can_send(struct rsocket * rs)2267 static int rs_can_send(struct rsocket *rs)
2268 {
2269 	if (!(rs->opts & RS_OPT_MSG_SEND)) {
2270 		return rs->sqe_avail && (rs->sbuf_bytes_avail >= RS_SNDLOWAT) &&
2271 		       (rs->sseq_no != rs->sseq_comp) &&
2272 		       (rs->target_sgl[rs->target_sge].length != 0);
2273 	} else {
2274 		return (rs->sqe_avail >= 2) && (rs->sbuf_bytes_avail >= RS_SNDLOWAT) &&
2275 		       (rs->sseq_no != rs->sseq_comp) &&
2276 		       (rs->target_sgl[rs->target_sge].length != 0);
2277 	}
2278 }
2279 
ds_can_send(struct rsocket * rs)2280 static int ds_can_send(struct rsocket *rs)
2281 {
2282 	return rs->sqe_avail;
2283 }
2284 
ds_all_sends_done(struct rsocket * rs)2285 static int ds_all_sends_done(struct rsocket *rs)
2286 {
2287 	return rs->sqe_avail == rs->sq_size;
2288 }
2289 
rs_conn_can_send(struct rsocket * rs)2290 static int rs_conn_can_send(struct rsocket *rs)
2291 {
2292 	return rs_can_send(rs) || !(rs->state & rs_writable);
2293 }
2294 
rs_conn_can_send_ctrl(struct rsocket * rs)2295 static int rs_conn_can_send_ctrl(struct rsocket *rs)
2296 {
2297 	return rs_ctrl_avail(rs) || !(rs->state & rs_connected);
2298 }
2299 
rs_have_rdata(struct rsocket * rs)2300 static int rs_have_rdata(struct rsocket *rs)
2301 {
2302 	return (rs->rmsg_head != rs->rmsg_tail);
2303 }
2304 
rs_conn_have_rdata(struct rsocket * rs)2305 static int rs_conn_have_rdata(struct rsocket *rs)
2306 {
2307 	return rs_have_rdata(rs) || !(rs->state & rs_readable);
2308 }
2309 
rs_conn_all_sends_done(struct rsocket * rs)2310 static int rs_conn_all_sends_done(struct rsocket *rs)
2311 {
2312 	return ((((int) rs->ctrl_max_seqno) - ((int) rs->ctrl_seqno)) +
2313 		rs->sqe_avail == rs->sq_size) ||
2314 	       !(rs->state & rs_connected);
2315 }
2316 
ds_set_src(struct sockaddr * addr,socklen_t * addrlen,struct ds_header * hdr)2317 static void ds_set_src(struct sockaddr *addr, socklen_t *addrlen,
2318 		       struct ds_header *hdr)
2319 {
2320 	union socket_addr sa;
2321 
2322 	memset(&sa, 0, sizeof sa);
2323 	if (hdr->version == 4) {
2324 		if (*addrlen > sizeof(sa.sin))
2325 			*addrlen = sizeof(sa.sin);
2326 
2327 		sa.sin.sin_family = AF_INET;
2328 		sa.sin.sin_port = hdr->port;
2329 		sa.sin.sin_addr.s_addr =  hdr->addr.ipv4;
2330 	} else {
2331 		if (*addrlen > sizeof(sa.sin6))
2332 			*addrlen = sizeof(sa.sin6);
2333 
2334 		sa.sin6.sin6_family = AF_INET6;
2335 		sa.sin6.sin6_port = hdr->port;
2336 		sa.sin6.sin6_flowinfo = hdr->addr.ipv6.flowinfo;
2337 		memcpy(&sa.sin6.sin6_addr, &hdr->addr.ipv6.addr, 16);
2338 	}
2339 	memcpy(addr, &sa, *addrlen);
2340 }
2341 
ds_recvfrom(struct rsocket * rs,void * buf,size_t len,int flags,struct sockaddr * src_addr,socklen_t * addrlen)2342 static ssize_t ds_recvfrom(struct rsocket *rs, void *buf, size_t len, int flags,
2343 			   struct sockaddr *src_addr, socklen_t *addrlen)
2344 {
2345 	struct ds_rmsg *rmsg;
2346 	struct ds_header *hdr;
2347 	int ret;
2348 
2349 	if (!(rs->state & rs_readable))
2350 		return ERR(EINVAL);
2351 
2352 	if (!rs_have_rdata(rs)) {
2353 		ret = ds_get_comp(rs, rs_nonblocking(rs, flags),
2354 				  rs_have_rdata);
2355 		if (ret)
2356 			return ret;
2357 	}
2358 
2359 	rmsg = &rs->dmsg[rs->rmsg_head];
2360 	hdr = (struct ds_header *) (rmsg->qp->rbuf + rmsg->offset);
2361 	if (len > rmsg->length - hdr->length)
2362 		len = rmsg->length - hdr->length;
2363 
2364 	memcpy(buf, (void *) hdr + hdr->length, len);
2365 	if (addrlen)
2366 		ds_set_src(src_addr, addrlen, hdr);
2367 
2368 	if (!(flags & MSG_PEEK)) {
2369 		ds_post_recv(rs, rmsg->qp, rmsg->offset);
2370 		if (++rs->rmsg_head == rs->rq_size + 1)
2371 			rs->rmsg_head = 0;
2372 		rs->rqe_avail++;
2373 	}
2374 
2375 	return len;
2376 }
2377 
rs_peek(struct rsocket * rs,void * buf,size_t len)2378 static ssize_t rs_peek(struct rsocket *rs, void *buf, size_t len)
2379 {
2380 	size_t left = len;
2381 	uint32_t end_size, rsize;
2382 	int rmsg_head, rbuf_offset;
2383 
2384 	rmsg_head = rs->rmsg_head;
2385 	rbuf_offset = rs->rbuf_offset;
2386 
2387 	for (; left && (rmsg_head != rs->rmsg_tail); left -= rsize) {
2388 		if (left < rs->rmsg[rmsg_head].data) {
2389 			rsize = left;
2390 		} else {
2391 			rsize = rs->rmsg[rmsg_head].data;
2392 			if (++rmsg_head == rs->rq_size + 1)
2393 				rmsg_head = 0;
2394 		}
2395 
2396 		end_size = rs->rbuf_size - rbuf_offset;
2397 		if (rsize > end_size) {
2398 			memcpy(buf, &rs->rbuf[rbuf_offset], end_size);
2399 			rbuf_offset = 0;
2400 			buf += end_size;
2401 			rsize -= end_size;
2402 			left -= end_size;
2403 		}
2404 		memcpy(buf, &rs->rbuf[rbuf_offset], rsize);
2405 		rbuf_offset += rsize;
2406 		buf += rsize;
2407 	}
2408 
2409 	return len - left;
2410 }
2411 
2412 /*
2413  * Continue to receive any queued data even if the remote side has disconnected.
2414  */
rrecv(int socket,void * buf,size_t len,int flags)2415 ssize_t rrecv(int socket, void *buf, size_t len, int flags)
2416 {
2417 	struct rsocket *rs;
2418 	size_t left = len;
2419 	uint32_t end_size, rsize;
2420 	int ret = 0;
2421 
2422 	rs = idm_at(&idm, socket);
2423 	if (rs->type == SOCK_DGRAM) {
2424 		fastlock_acquire(&rs->rlock);
2425 		ret = ds_recvfrom(rs, buf, len, flags, NULL, NULL);
2426 		fastlock_release(&rs->rlock);
2427 		return ret;
2428 	}
2429 
2430 	if (rs->state & rs_opening) {
2431 		ret = rs_do_connect(rs);
2432 		if (ret) {
2433 			if (errno == EINPROGRESS)
2434 				errno = EAGAIN;
2435 			return ret;
2436 		}
2437 	}
2438 	fastlock_acquire(&rs->rlock);
2439 	do {
2440 		if (!rs_have_rdata(rs)) {
2441 			ret = rs_get_comp(rs, rs_nonblocking(rs, flags),
2442 					  rs_conn_have_rdata);
2443 			if (ret)
2444 				break;
2445 		}
2446 
2447 		if (flags & MSG_PEEK) {
2448 			left = len - rs_peek(rs, buf, left);
2449 			break;
2450 		}
2451 
2452 		for (; left && rs_have_rdata(rs); left -= rsize) {
2453 			if (left < rs->rmsg[rs->rmsg_head].data) {
2454 				rsize = left;
2455 				rs->rmsg[rs->rmsg_head].data -= left;
2456 			} else {
2457 				rs->rseq_no++;
2458 				rsize = rs->rmsg[rs->rmsg_head].data;
2459 				if (++rs->rmsg_head == rs->rq_size + 1)
2460 					rs->rmsg_head = 0;
2461 			}
2462 
2463 			end_size = rs->rbuf_size - rs->rbuf_offset;
2464 			if (rsize > end_size) {
2465 				memcpy(buf, &rs->rbuf[rs->rbuf_offset], end_size);
2466 				rs->rbuf_offset = 0;
2467 				buf += end_size;
2468 				rsize -= end_size;
2469 				left -= end_size;
2470 				rs->rbuf_bytes_avail += end_size;
2471 			}
2472 			memcpy(buf, &rs->rbuf[rs->rbuf_offset], rsize);
2473 			rs->rbuf_offset += rsize;
2474 			buf += rsize;
2475 			rs->rbuf_bytes_avail += rsize;
2476 		}
2477 
2478 	} while (left && (flags & MSG_WAITALL) && (rs->state & rs_readable));
2479 
2480 	fastlock_release(&rs->rlock);
2481 	return (ret && left == len) ? ret : len - left;
2482 }
2483 
rrecvfrom(int socket,void * buf,size_t len,int flags,struct sockaddr * src_addr,socklen_t * addrlen)2484 ssize_t rrecvfrom(int socket, void *buf, size_t len, int flags,
2485 		  struct sockaddr *src_addr, socklen_t *addrlen)
2486 {
2487 	struct rsocket *rs;
2488 	int ret;
2489 
2490 	rs = idm_at(&idm, socket);
2491 	if (rs->type == SOCK_DGRAM) {
2492 		fastlock_acquire(&rs->rlock);
2493 		ret = ds_recvfrom(rs, buf, len, flags, src_addr, addrlen);
2494 		fastlock_release(&rs->rlock);
2495 		return ret;
2496 	}
2497 
2498 	ret = rrecv(socket, buf, len, flags);
2499 	if (ret > 0 && src_addr)
2500 		rgetpeername(socket, src_addr, addrlen);
2501 
2502 	return ret;
2503 }
2504 
2505 /*
2506  * Simple, straightforward implementation for now that only tries to fill
2507  * in the first vector.
2508  */
rrecvv(int socket,const struct iovec * iov,int iovcnt,int flags)2509 static ssize_t rrecvv(int socket, const struct iovec *iov, int iovcnt, int flags)
2510 {
2511 	return rrecv(socket, iov[0].iov_base, iov[0].iov_len, flags);
2512 }
2513 
rrecvmsg(int socket,struct msghdr * msg,int flags)2514 ssize_t rrecvmsg(int socket, struct msghdr *msg, int flags)
2515 {
2516 	if (msg->msg_control && msg->msg_controllen)
2517 		return ERR(ENOTSUP);
2518 
2519 	return rrecvv(socket, msg->msg_iov, (int) msg->msg_iovlen, msg->msg_flags);
2520 }
2521 
/* read() for rsockets: a plain receive with no flags. */
ssize_t rread(int socket, void *buf, size_t count)
{
	const int no_flags = 0;

	return rrecv(socket, buf, count, no_flags);
}

/* readv() for rsockets: a vectored receive with no flags. */
ssize_t rreadv(int socket, const struct iovec *iov, int iovcnt)
{
	const int no_flags = 0;

	return rrecvv(socket, iov, iovcnt, no_flags);
}
2531 
rs_send_iomaps(struct rsocket * rs,int flags)2532 static int rs_send_iomaps(struct rsocket *rs, int flags)
2533 {
2534 	struct rs_iomap_mr *iomr;
2535 	struct ibv_sge sge;
2536 	struct rs_iomap iom;
2537 	int ret;
2538 
2539 	fastlock_acquire(&rs->map_lock);
2540 	while (!dlist_empty(&rs->iomap_queue)) {
2541 		if (!rs_can_send(rs)) {
2542 			ret = rs_get_comp(rs, rs_nonblocking(rs, flags),
2543 					  rs_conn_can_send);
2544 			if (ret)
2545 				break;
2546 			if (!(rs->state & rs_writable)) {
2547 				ret = ERR(ECONNRESET);
2548 				break;
2549 			}
2550 		}
2551 
2552 		iomr = container_of(rs->iomap_queue.next, struct rs_iomap_mr, entry);
2553 		if (!(rs->opts & RS_OPT_SWAP_SGL)) {
2554 			iom.offset = iomr->offset;
2555 			iom.sge.addr = (uintptr_t) iomr->mr->addr;
2556 			iom.sge.length = iomr->mr->length;
2557 			iom.sge.key = iomr->mr->rkey;
2558 		} else {
2559 			iom.offset = bswap_64(iomr->offset);
2560 			iom.sge.addr = bswap_64((uintptr_t) iomr->mr->addr);
2561 			iom.sge.length = bswap_32(iomr->mr->length);
2562 			iom.sge.key = bswap_32(iomr->mr->rkey);
2563 		}
2564 
2565 		if (rs->sq_inline >= sizeof iom) {
2566 			sge.addr = (uintptr_t) &iom;
2567 			sge.length = sizeof iom;
2568 			sge.lkey = 0;
2569 			ret = rs_write_iomap(rs, iomr, &sge, 1, IBV_SEND_INLINE);
2570 		} else if (rs_sbuf_left(rs) >= sizeof iom) {
2571 			memcpy((void *) (uintptr_t) rs->ssgl[0].addr, &iom, sizeof iom);
2572 			rs->ssgl[0].length = sizeof iom;
2573 			ret = rs_write_iomap(rs, iomr, rs->ssgl, 1, 0);
2574 			if (rs_sbuf_left(rs) > sizeof iom)
2575 				rs->ssgl[0].addr += sizeof iom;
2576 			else
2577 				rs->ssgl[0].addr = (uintptr_t) rs->sbuf;
2578 		} else {
2579 			rs->ssgl[0].length = rs_sbuf_left(rs);
2580 			memcpy((void *) (uintptr_t) rs->ssgl[0].addr, &iom,
2581 				rs->ssgl[0].length);
2582 			rs->ssgl[1].length = sizeof iom - rs->ssgl[0].length;
2583 			memcpy(rs->sbuf, ((void *) &iom) + rs->ssgl[0].length,
2584 			       rs->ssgl[1].length);
2585 			ret = rs_write_iomap(rs, iomr, rs->ssgl, 2, 0);
2586 			rs->ssgl[0].addr = (uintptr_t) rs->sbuf + rs->ssgl[1].length;
2587 		}
2588 		dlist_remove(&iomr->entry);
2589 		dlist_insert_tail(&iomr->entry, &rs->iomap_list);
2590 		if (ret)
2591 			break;
2592 	}
2593 
2594 	rs->iomap_pending = !dlist_empty(&rs->iomap_queue);
2595 	fastlock_release(&rs->map_lock);
2596 	return ret;
2597 }
2598 
/*
 * Send a datagram over the companion UDP socket (rs->udp_sock) to
 * rs->conn_dest.  The message is a ds_udp_header (tag, version, op, QP
 * number and local address) followed by the caller's vectors.
 *
 * Returns the payload bytes sent (header excluded).  Returns -1 with errno
 * set on failure: EINVAL for a negative iovcnt, ENOTSUP for more than 8
 * vectors, or sendmsg()'s error.
 */
static ssize_t ds_sendv_udp(struct rsocket *rs, const struct iovec *iov,
			    int iovcnt, int flags, uint8_t op)
{
	struct ds_udp_header hdr;
	struct msghdr msg;
	/*
	 * Entry 0 carries the header and up to 8 caller vectors follow.  The
	 * old miov[8] overflowed by one entry when iovcnt == 8.
	 */
	struct iovec miov[1 + 8];
	ssize_t ret;

	if (iovcnt < 0)
		return ERR(EINVAL);
	if (iovcnt > (int) (sizeof(miov) / sizeof(miov[0])) - 1)
		return ERR(ENOTSUP);

	hdr.tag = htobe32(DS_UDP_TAG);
	hdr.version = rs->conn_dest->qp->hdr.version;
	hdr.op = op;
	hdr.reserved = 0;
	hdr.qpn = htobe32(rs->conn_dest->qp->cm_id->qp->qp_num & 0xFFFFFF);
	if (rs->conn_dest->qp->hdr.version == 4) {
		hdr.length = DS_UDP_IPV4_HDR_LEN;
		hdr.addr.ipv4 = rs->conn_dest->qp->hdr.addr.ipv4;
	} else {
		hdr.length = DS_UDP_IPV6_HDR_LEN;
		memcpy(hdr.addr.ipv6, &rs->conn_dest->qp->hdr.addr.ipv6, 16);
	}

	miov[0].iov_base = &hdr;
	miov[0].iov_len = hdr.length;
	if (iov && iovcnt)
		memcpy(&miov[1], iov, sizeof(*iov) * iovcnt);

	memset(&msg, 0, sizeof msg);
	msg.msg_name = &rs->conn_dest->addr;
	msg.msg_namelen = ucma_addrlen(&rs->conn_dest->addr.sa);
	msg.msg_iov = miov;
	msg.msg_iovlen = iovcnt + 1;
	ret = sendmsg(rs->udp_sock, &msg, flags);
	/* Report only the caller's payload, not our header. */
	return ret > 0 ? ret - hdr.length : ret;
}
ds_send_udp(struct rsocket * rs,const void * buf,size_t len,int flags,uint8_t op)2637 static ssize_t ds_send_udp(struct rsocket *rs, const void *buf, size_t len,
2638 			   int flags, uint8_t op)
2639 {
2640 	struct iovec iov;
2641 	if (buf && len) {
2642 		iov.iov_base = (void *) buf;
2643 		iov.iov_len = len;
2644 		return ds_sendv_udp(rs, &iov, 1, flags, op);
2645 	} else {
2646 		return ds_sendv_udp(rs, NULL, 0, flags, op);
2647 	}
2648 }
2649 
dsend(struct rsocket * rs,const void * buf,size_t len,int flags)2650 static ssize_t dsend(struct rsocket *rs, const void *buf, size_t len, int flags)
2651 {
2652 	struct ds_smsg *msg;
2653 	struct ibv_sge sge;
2654 	uint64_t offset;
2655 	int ret = 0;
2656 
2657 	if (!rs->conn_dest->ah)
2658 		return ds_send_udp(rs, buf, len, flags, RS_OP_DATA);
2659 
2660 	if (!ds_can_send(rs)) {
2661 		ret = ds_get_comp(rs, rs_nonblocking(rs, flags), ds_can_send);
2662 		if (ret)
2663 			return ret;
2664 	}
2665 
2666 	msg = rs->smsg_free;
2667 	rs->smsg_free = msg->next;
2668 	rs->sqe_avail--;
2669 
2670 	memcpy((void *) msg, &rs->conn_dest->qp->hdr, rs->conn_dest->qp->hdr.length);
2671 	memcpy((void *) msg + rs->conn_dest->qp->hdr.length, buf, len);
2672 	sge.addr = (uintptr_t) msg;
2673 	sge.length = rs->conn_dest->qp->hdr.length + len;
2674 	sge.lkey = rs->conn_dest->qp->smr->lkey;
2675 	offset = (uint8_t *) msg - rs->sbuf;
2676 
2677 	ret = ds_post_send(rs, &sge, offset);
2678 	return ret ? ret : len;
2679 }
2680 
2681 /*
2682  * We overlap sending the data, by posting a small work request immediately,
2683  * then increasing the size of the send on each iteration.
2684  */
rsend(int socket,const void * buf,size_t len,int flags)2685 ssize_t rsend(int socket, const void *buf, size_t len, int flags)
2686 {
2687 	struct rsocket *rs;
2688 	struct ibv_sge sge;
2689 	size_t left = len;
2690 	uint32_t xfer_size, olen = RS_OLAP_START_SIZE;
2691 	int ret = 0;
2692 
2693 	rs = idm_at(&idm, socket);
2694 	if (rs->type == SOCK_DGRAM) {
2695 		fastlock_acquire(&rs->slock);
2696 		ret = dsend(rs, buf, len, flags);
2697 		fastlock_release(&rs->slock);
2698 		return ret;
2699 	}
2700 
2701 	if (rs->state & rs_opening) {
2702 		ret = rs_do_connect(rs);
2703 		if (ret) {
2704 			if (errno == EINPROGRESS)
2705 				errno = EAGAIN;
2706 			return ret;
2707 		}
2708 	}
2709 
2710 	fastlock_acquire(&rs->slock);
2711 	if (rs->iomap_pending) {
2712 		ret = rs_send_iomaps(rs, flags);
2713 		if (ret)
2714 			goto out;
2715 	}
2716 	for (; left; left -= xfer_size, buf += xfer_size) {
2717 		if (!rs_can_send(rs)) {
2718 			ret = rs_get_comp(rs, rs_nonblocking(rs, flags),
2719 					  rs_conn_can_send);
2720 			if (ret)
2721 				break;
2722 			if (!(rs->state & rs_writable)) {
2723 				ret = ERR(ECONNRESET);
2724 				break;
2725 			}
2726 		}
2727 
2728 		if (olen < left) {
2729 			xfer_size = olen;
2730 			if (olen < RS_MAX_TRANSFER)
2731 				olen <<= 1;
2732 		} else {
2733 			xfer_size = left;
2734 		}
2735 
2736 		if (xfer_size > rs->sbuf_bytes_avail)
2737 			xfer_size = rs->sbuf_bytes_avail;
2738 		if (xfer_size > rs->target_sgl[rs->target_sge].length)
2739 			xfer_size = rs->target_sgl[rs->target_sge].length;
2740 
2741 		if (xfer_size <= rs->sq_inline) {
2742 			sge.addr = (uintptr_t) buf;
2743 			sge.length = xfer_size;
2744 			sge.lkey = 0;
2745 			ret = rs_write_data(rs, &sge, 1, xfer_size, IBV_SEND_INLINE);
2746 		} else if (xfer_size <= rs_sbuf_left(rs)) {
2747 			memcpy((void *) (uintptr_t) rs->ssgl[0].addr, buf, xfer_size);
2748 			rs->ssgl[0].length = xfer_size;
2749 			ret = rs_write_data(rs, rs->ssgl, 1, xfer_size, 0);
2750 			if (xfer_size < rs_sbuf_left(rs))
2751 				rs->ssgl[0].addr += xfer_size;
2752 			else
2753 				rs->ssgl[0].addr = (uintptr_t) rs->sbuf;
2754 		} else {
2755 			rs->ssgl[0].length = rs_sbuf_left(rs);
2756 			memcpy((void *) (uintptr_t) rs->ssgl[0].addr, buf,
2757 				rs->ssgl[0].length);
2758 			rs->ssgl[1].length = xfer_size - rs->ssgl[0].length;
2759 			memcpy(rs->sbuf, buf + rs->ssgl[0].length, rs->ssgl[1].length);
2760 			ret = rs_write_data(rs, rs->ssgl, 2, xfer_size, 0);
2761 			rs->ssgl[0].addr = (uintptr_t) rs->sbuf + rs->ssgl[1].length;
2762 		}
2763 		if (ret)
2764 			break;
2765 	}
2766 out:
2767 	fastlock_release(&rs->slock);
2768 
2769 	return (ret && left == len) ? ret : len - left;
2770 }
2771 
rsendto(int socket,const void * buf,size_t len,int flags,const struct sockaddr * dest_addr,socklen_t addrlen)2772 ssize_t rsendto(int socket, const void *buf, size_t len, int flags,
2773 		const struct sockaddr *dest_addr, socklen_t addrlen)
2774 {
2775 	struct rsocket *rs;
2776 	int ret;
2777 
2778 	rs = idm_at(&idm, socket);
2779 	if (rs->type == SOCK_STREAM) {
2780 		if (dest_addr || addrlen)
2781 			return ERR(EISCONN);
2782 
2783 		return rsend(socket, buf, len, flags);
2784 	}
2785 
2786 	if (rs->state == rs_init) {
2787 		ret = ds_init_ep(rs);
2788 		if (ret)
2789 			return ret;
2790 	}
2791 
2792 	fastlock_acquire(&rs->slock);
2793 	if (!rs->conn_dest || ds_compare_addr(dest_addr, &rs->conn_dest->addr)) {
2794 		ret = ds_get_dest(rs, dest_addr, addrlen, &rs->conn_dest);
2795 		if (ret)
2796 			goto out;
2797 	}
2798 
2799 	ret = dsend(rs, buf, len, flags);
2800 out:
2801 	fastlock_release(&rs->slock);
2802 	return ret;
2803 }
2804 
/*
 * Gather len bytes from an iovec array into dst.
 *
 * *iov and *offset form a cursor (current vector and byte position within
 * it).  The cursor is advanced past the copied data so that consecutive
 * calls continue where the previous one stopped.
 */
static void rs_copy_iov(void *dst, const struct iovec **iov, size_t *offset, size_t len)
{
	unsigned char *out = dst;

	while (len) {
		const struct iovec *cur = *iov;
		size_t avail = cur->iov_len - *offset;
		const unsigned char *src = (const unsigned char *) cur->iov_base + *offset;

		if (avail > len) {
			/* Request ends inside this vector: stay on it. */
			memcpy(out, src, len);
			*offset += len;
			return;
		}

		/* Drain the rest of this vector and move to the next one. */
		memcpy(out, src, avail);
		out += avail;
		len -= avail;
		++*iov;
		*offset = 0;
	}
}
2824 
rsendv(int socket,const struct iovec * iov,int iovcnt,int flags)2825 static ssize_t rsendv(int socket, const struct iovec *iov, int iovcnt, int flags)
2826 {
2827 	struct rsocket *rs;
2828 	const struct iovec *cur_iov;
2829 	size_t left, len, offset = 0;
2830 	uint32_t xfer_size, olen = RS_OLAP_START_SIZE;
2831 	int i, ret = 0;
2832 
2833 	rs = idm_at(&idm, socket);
2834 	if (rs->state & rs_opening) {
2835 		ret = rs_do_connect(rs);
2836 		if (ret) {
2837 			if (errno == EINPROGRESS)
2838 				errno = EAGAIN;
2839 			return ret;
2840 		}
2841 	}
2842 
2843 	cur_iov = iov;
2844 	len = iov[0].iov_len;
2845 	for (i = 1; i < iovcnt; i++)
2846 		len += iov[i].iov_len;
2847 	left = len;
2848 
2849 	fastlock_acquire(&rs->slock);
2850 	if (rs->iomap_pending) {
2851 		ret = rs_send_iomaps(rs, flags);
2852 		if (ret)
2853 			goto out;
2854 	}
2855 	for (; left; left -= xfer_size) {
2856 		if (!rs_can_send(rs)) {
2857 			ret = rs_get_comp(rs, rs_nonblocking(rs, flags),
2858 					  rs_conn_can_send);
2859 			if (ret)
2860 				break;
2861 			if (!(rs->state & rs_writable)) {
2862 				ret = ERR(ECONNRESET);
2863 				break;
2864 			}
2865 		}
2866 
2867 		if (olen < left) {
2868 			xfer_size = olen;
2869 			if (olen < RS_MAX_TRANSFER)
2870 				olen <<= 1;
2871 		} else {
2872 			xfer_size = left;
2873 		}
2874 
2875 		if (xfer_size > rs->sbuf_bytes_avail)
2876 			xfer_size = rs->sbuf_bytes_avail;
2877 		if (xfer_size > rs->target_sgl[rs->target_sge].length)
2878 			xfer_size = rs->target_sgl[rs->target_sge].length;
2879 
2880 		if (xfer_size <= rs_sbuf_left(rs)) {
2881 			rs_copy_iov((void *) (uintptr_t) rs->ssgl[0].addr,
2882 				    &cur_iov, &offset, xfer_size);
2883 			rs->ssgl[0].length = xfer_size;
2884 			ret = rs_write_data(rs, rs->ssgl, 1, xfer_size,
2885 					    xfer_size <= rs->sq_inline ? IBV_SEND_INLINE : 0);
2886 			if (xfer_size < rs_sbuf_left(rs))
2887 				rs->ssgl[0].addr += xfer_size;
2888 			else
2889 				rs->ssgl[0].addr = (uintptr_t) rs->sbuf;
2890 		} else {
2891 			rs->ssgl[0].length = rs_sbuf_left(rs);
2892 			rs_copy_iov((void *) (uintptr_t) rs->ssgl[0].addr, &cur_iov,
2893 				    &offset, rs->ssgl[0].length);
2894 			rs->ssgl[1].length = xfer_size - rs->ssgl[0].length;
2895 			rs_copy_iov(rs->sbuf, &cur_iov, &offset, rs->ssgl[1].length);
2896 			ret = rs_write_data(rs, rs->ssgl, 2, xfer_size,
2897 					    xfer_size <= rs->sq_inline ? IBV_SEND_INLINE : 0);
2898 			rs->ssgl[0].addr = (uintptr_t) rs->sbuf + rs->ssgl[1].length;
2899 		}
2900 		if (ret)
2901 			break;
2902 	}
2903 out:
2904 	fastlock_release(&rs->slock);
2905 
2906 	return (ret && left == len) ? ret : len - left;
2907 }
2908 
rsendmsg(int socket,const struct msghdr * msg,int flags)2909 ssize_t rsendmsg(int socket, const struct msghdr *msg, int flags)
2910 {
2911 	if (msg->msg_control && msg->msg_controllen)
2912 		return ERR(ENOTSUP);
2913 
2914 	return rsendv(socket, msg->msg_iov, (int) msg->msg_iovlen, flags);
2915 }
2916 
rwrite(int socket,const void * buf,size_t count)2917 ssize_t rwrite(int socket, const void *buf, size_t count)
2918 {
2919 	return rsend(socket, buf, count, 0);
2920 }
2921 
rwritev(int socket,const struct iovec * iov,int iovcnt)2922 ssize_t rwritev(int socket, const struct iovec *iov, int iovcnt)
2923 {
2924 	return rsendv(socket, iov, iovcnt, 0);
2925 }
2926 
/*
 * Return a per-thread scratch pollfd array with room for at least nfds
 * entries.  The array only grows.  If growth fails, the old array has
 * already been freed and NULL is returned.  A request that fits the
 * cached array returns it unchanged (possibly NULL if nothing was ever
 * allocated).
 */
static struct pollfd *rs_fds_alloc(nfds_t nfds)
{
	static __thread struct pollfd *cache;
	static __thread nfds_t cache_len;

	if (nfds <= cache_len)
		return cache;

	free(cache);	/* free(NULL) is a no-op */
	cache = malloc(sizeof(*cache) * nfds);
	cache_len = cache ? nfds : 0;
	return cache;
}
2942 
rs_poll_rs(struct rsocket * rs,int events,int nonblock,int (* test)(struct rsocket * rs))2943 static int rs_poll_rs(struct rsocket *rs, int events,
2944 		      int nonblock, int (*test)(struct rsocket *rs))
2945 {
2946 	struct pollfd fds;
2947 	short revents;
2948 	int ret;
2949 
2950 check_cq:
2951 	if ((rs->type == SOCK_STREAM) && ((rs->state & rs_connected) ||
2952 	     (rs->state == rs_disconnected) || (rs->state & rs_error))) {
2953 		rs_process_cq(rs, nonblock, test);
2954 
2955 		revents = 0;
2956 		if ((events & POLLIN) && rs_conn_have_rdata(rs))
2957 			revents |= POLLIN;
2958 		if ((events & POLLOUT) && rs_can_send(rs))
2959 			revents |= POLLOUT;
2960 		if (!(rs->state & rs_connected)) {
2961 			if (rs->state == rs_disconnected)
2962 				revents |= POLLHUP;
2963 			else
2964 				revents |= POLLERR;
2965 		}
2966 
2967 		return revents;
2968 	} else if (rs->type == SOCK_DGRAM) {
2969 		ds_process_cqs(rs, nonblock, test);
2970 
2971 		revents = 0;
2972 		if ((events & POLLIN) && rs_have_rdata(rs))
2973 			revents |= POLLIN;
2974 		if ((events & POLLOUT) && ds_can_send(rs))
2975 			revents |= POLLOUT;
2976 
2977 		return revents;
2978 	}
2979 
2980 	if (rs->state == rs_listening) {
2981 		fds.fd = rs->cm_id->channel->fd;
2982 		fds.events = events;
2983 		fds.revents = 0;
2984 		poll(&fds, 1, 0);
2985 		return fds.revents;
2986 	}
2987 
2988 	if (rs->state & rs_opening) {
2989 		ret = rs_do_connect(rs);
2990 		if (ret && (errno == EINPROGRESS)) {
2991 			errno = 0;
2992 		} else {
2993 			goto check_cq;
2994 		}
2995 	}
2996 
2997 	if (rs->state == rs_connect_error) {
2998 		revents = 0;
2999 		if (events & POLLOUT)
3000 			revents |= POLLOUT;
3001 		if (events & POLLIN)
3002 			revents |= POLLIN;
3003 		revents |= POLLERR;
3004 		return revents;
3005 	}
3006 
3007 	return 0;
3008 }
3009 
rs_poll_check(struct pollfd * fds,nfds_t nfds)3010 static int rs_poll_check(struct pollfd *fds, nfds_t nfds)
3011 {
3012 	struct rsocket *rs;
3013 	int i, cnt = 0;
3014 
3015 	for (i = 0; i < nfds; i++) {
3016 		rs = idm_lookup(&idm, fds[i].fd);
3017 		if (rs)
3018 			fds[i].revents = rs_poll_rs(rs, fds[i].events, 1, rs_poll_all);
3019 		else
3020 			poll(&fds[i], 1, 0);
3021 
3022 		if (fds[i].revents)
3023 			cnt++;
3024 	}
3025 	return cnt;
3026 }
3027 
rs_poll_arm(struct pollfd * rfds,struct pollfd * fds,nfds_t nfds)3028 static int rs_poll_arm(struct pollfd *rfds, struct pollfd *fds, nfds_t nfds)
3029 {
3030 	struct rsocket *rs;
3031 	int i;
3032 
3033 	for (i = 0; i < nfds; i++) {
3034 		rs = idm_lookup(&idm, fds[i].fd);
3035 		if (rs) {
3036 			fds[i].revents = rs_poll_rs(rs, fds[i].events, 0, rs_is_cq_armed);
3037 			if (fds[i].revents)
3038 				return 1;
3039 
3040 			if (rs->type == SOCK_STREAM) {
3041 				if (rs->state >= rs_connected)
3042 					rfds[i].fd = rs->cm_id->recv_cq_channel->fd;
3043 				else
3044 					rfds[i].fd = rs->cm_id->channel->fd;
3045 			} else {
3046 				rfds[i].fd = rs->epfd;
3047 			}
3048 			rfds[i].events = POLLIN;
3049 		} else {
3050 			rfds[i].fd = fds[i].fd;
3051 			rfds[i].events = fds[i].events;
3052 		}
3053 		rfds[i].revents = 0;
3054 	}
3055 	return 0;
3056 }
3057 
rs_poll_events(struct pollfd * rfds,struct pollfd * fds,nfds_t nfds)3058 static int rs_poll_events(struct pollfd *rfds, struct pollfd *fds, nfds_t nfds)
3059 {
3060 	struct rsocket *rs;
3061 	int i, cnt = 0;
3062 
3063 	for (i = 0; i < nfds; i++) {
3064 		if (!rfds[i].revents)
3065 			continue;
3066 
3067 		rs = idm_lookup(&idm, fds[i].fd);
3068 		if (rs) {
3069 			fastlock_acquire(&rs->cq_wait_lock);
3070 			if (rs->type == SOCK_STREAM)
3071 				rs_get_cq_event(rs);
3072 			else
3073 				ds_get_cq_event(rs);
3074 			fastlock_release(&rs->cq_wait_lock);
3075 			fds[i].revents = rs_poll_rs(rs, fds[i].events, 1, rs_poll_all);
3076 		} else {
3077 			fds[i].revents = rfds[i].revents;
3078 		}
3079 		if (fds[i].revents)
3080 			cnt++;
3081 	}
3082 	return cnt;
3083 }
3084 
3085 /*
3086  * We need to poll *all* fd's that the user specifies at least once.
3087  * Note that we may receive events on an rsocket that may not be reported
3088  * to the user (e.g. connection events or credit updates).  Process those
3089  * events, then return to polling until we find ones of interest.
3090  */
rpoll(struct pollfd * fds,nfds_t nfds,int timeout)3091 int rpoll(struct pollfd *fds, nfds_t nfds, int timeout)
3092 {
3093 	struct timeval s, e;
3094 	struct pollfd *rfds;
3095 	uint32_t poll_time = 0;
3096 	int ret;
3097 
3098 	do {
3099 		ret = rs_poll_check(fds, nfds);
3100 		if (ret || !timeout)
3101 			return ret;
3102 
3103 		if (!poll_time)
3104 			gettimeofday(&s, NULL);
3105 
3106 		gettimeofday(&e, NULL);
3107 		poll_time = (e.tv_sec - s.tv_sec) * 1000000 +
3108 			    (e.tv_usec - s.tv_usec) + 1;
3109 	} while (poll_time <= polling_time);
3110 
3111 	rfds = rs_fds_alloc(nfds);
3112 	if (!rfds)
3113 		return ERR(ENOMEM);
3114 
3115 	do {
3116 		ret = rs_poll_arm(rfds, fds, nfds);
3117 		if (ret)
3118 			break;
3119 
3120 		ret = poll(rfds, nfds, timeout);
3121 		if (ret <= 0)
3122 			break;
3123 
3124 		ret = rs_poll_events(rfds, fds, nfds);
3125 	} while (!ret);
3126 
3127 	return ret;
3128 }
3129 
/*
 * Convert select()-style fd_sets into a calloc'ed pollfd array, which the
 * caller frees.
 *
 * On entry *nfds is select()'s nfds (highest fd + 1).  On return it is the
 * number of pollfd entries filled in.  Read-set membership maps to POLLIN
 * and write-set membership to POLLOUT.  Descriptors in the except set only
 * are included with no requested events, since poll() reports error and
 * hangup conditions regardless.  Returns NULL if allocation fails.
 */
static struct pollfd *
rs_select_to_poll(int *nfds, fd_set *readfds, fd_set *writefds, fd_set *exceptfds)
{
	struct pollfd *fds;
	int fd, i = 0;

	/*
	 * Allocate at least one entry.  calloc(0, ...) may return NULL, which
	 * the caller would misreport as ENOMEM for the common
	 * select(0, NULL, NULL, NULL, &tv) sleep idiom.
	 */
	fds = calloc(*nfds ? *nfds : 1, sizeof(*fds));
	if (!fds)
		return NULL;

	for (fd = 0; fd < *nfds; fd++) {
		short events = 0;
		int wanted = 0;

		if (readfds && FD_ISSET(fd, readfds)) {
			events |= POLLIN;
			wanted = 1;
		}

		if (writefds && FD_ISSET(fd, writefds)) {
			events |= POLLOUT;
			wanted = 1;
		}

		if (exceptfds && FD_ISSET(fd, exceptfds))
			wanted = 1;

		/*
		 * Use an explicit flag rather than "fds[i].fd != 0".  fd 0 is a
		 * valid descriptor, so the old test dropped it, and its POLLIN
		 * then leaked into the events of the next selected fd.
		 */
		if (wanted) {
			fds[i].fd = fd;
			fds[i].events = events;
			i++;
		}
	}

	*nfds = i;
	return fds;
}
3161 
3162 static int
rs_poll_to_select(int nfds,struct pollfd * fds,fd_set * readfds,fd_set * writefds,fd_set * exceptfds)3163 rs_poll_to_select(int nfds, struct pollfd *fds, fd_set *readfds,
3164 		  fd_set *writefds, fd_set *exceptfds)
3165 {
3166 	int i, cnt = 0;
3167 
3168 	for (i = 0; i < nfds; i++) {
3169 		if (readfds && (fds[i].revents & (POLLIN | POLLHUP))) {
3170 			FD_SET(fds[i].fd, readfds);
3171 			cnt++;
3172 		}
3173 
3174 		if (writefds && (fds[i].revents & POLLOUT)) {
3175 			FD_SET(fds[i].fd, writefds);
3176 			cnt++;
3177 		}
3178 
3179 		if (exceptfds && (fds[i].revents & ~(POLLIN | POLLOUT))) {
3180 			FD_SET(fds[i].fd, exceptfds);
3181 			cnt++;
3182 		}
3183 	}
3184 	return cnt;
3185 }
3186 
/*
 * Convert a select() timeval into a poll() millisecond timeout.
 *
 * A NULL timeout means "wait forever" and maps to -1.
 * Partial milliseconds are rounded up: select() must not return before the
 * timeout expires, and truncation turned sub-millisecond waits into 0 (an
 * immediate return).
 * The arithmetic is done in long long and clamped to [0, INT_MAX]. An
 * unchecked tv_sec * 1000 overflows int for timeouts beyond ~24.8 days. A
 * negative value would otherwise reach poll() as an infinite wait.
 */
static int rs_convert_timeout(struct timeval *timeout)
{
	/* INT_MAX, spelled without requiring <limits.h>. */
	const long long max_ms = (long long) (int) (~0U >> 1);
	long long ms;

	if (!timeout)
		return -1;

	ms = (long long) timeout->tv_sec * 1000 +
	     ((long long) timeout->tv_usec + 999) / 1000;
	if (ms < 0)
		return 0;
	if (ms > max_ms)
		return (int) max_ms;
	return (int) ms;
}
3192 
rselect(int nfds,fd_set * readfds,fd_set * writefds,fd_set * exceptfds,struct timeval * timeout)3193 int rselect(int nfds, fd_set *readfds, fd_set *writefds,
3194 	    fd_set *exceptfds, struct timeval *timeout)
3195 {
3196 	struct pollfd *fds;
3197 	int ret;
3198 
3199 	fds = rs_select_to_poll(&nfds, readfds, writefds, exceptfds);
3200 	if (!fds)
3201 		return ERR(ENOMEM);
3202 
3203 	ret = rpoll(fds, nfds, rs_convert_timeout(timeout));
3204 
3205 	if (readfds)
3206 		FD_ZERO(readfds);
3207 	if (writefds)
3208 		FD_ZERO(writefds);
3209 	if (exceptfds)
3210 		FD_ZERO(exceptfds);
3211 
3212 	if (ret > 0)
3213 		ret = rs_poll_to_select(nfds, fds, readfds, writefds, exceptfds);
3214 
3215 	free(fds);
3216 	return ret;
3217 }
3218 
3219 /*
3220  * For graceful disconnect, notify the remote side that we're
3221  * disconnecting and wait until all outstanding sends complete, provided
3222  * that the remote side has not sent a disconnect message.
3223  */
rshutdown(int socket,int how)3224 int rshutdown(int socket, int how)
3225 {
3226 	struct rsocket *rs;
3227 	int ctrl, ret = 0;
3228 
3229 	rs = idm_lookup(&idm, socket);
3230 	if (!rs)
3231 		return ERR(EBADF);
3232 	if (rs->opts & RS_OPT_SVC_ACTIVE)
3233 		rs_notify_svc(&tcp_svc, rs, RS_SVC_REM_KEEPALIVE);
3234 
3235 	if (rs->fd_flags & O_NONBLOCK)
3236 		rs_set_nonblocking(rs, 0);
3237 
3238 	if (rs->state & rs_connected) {
3239 		if (how == SHUT_RDWR) {
3240 			ctrl = RS_CTRL_DISCONNECT;
3241 			rs->state &= ~(rs_readable | rs_writable);
3242 		} else if (how == SHUT_WR) {
3243 			rs->state &= ~rs_writable;
3244 			ctrl = (rs->state & rs_readable) ?
3245 				RS_CTRL_SHUTDOWN : RS_CTRL_DISCONNECT;
3246 		} else {
3247 			rs->state &= ~rs_readable;
3248 			if (rs->state & rs_writable)
3249 				goto out;
3250 			ctrl = RS_CTRL_DISCONNECT;
3251 		}
3252 		if (!rs_ctrl_avail(rs)) {
3253 			ret = rs_process_cq(rs, 0, rs_conn_can_send_ctrl);
3254 			if (ret)
3255 				goto out;
3256 		}
3257 
3258 		if ((rs->state & rs_connected) && rs_ctrl_avail(rs)) {
3259 			rs->ctrl_seqno++;
3260 			ret = rs_post_msg(rs, rs_msg_set(RS_OP_CTRL, ctrl));
3261 		}
3262 	}
3263 
3264 	if (rs->state & rs_connected)
3265 		rs_process_cq(rs, 0, rs_conn_all_sends_done);
3266 
3267 out:
3268 	if ((rs->fd_flags & O_NONBLOCK) && (rs->state & rs_connected))
3269 		rs_set_nonblocking(rs, rs->fd_flags);
3270 
3271 	if (rs->state & rs_disconnected) {
3272 		/* Generate event by flushing receives to unblock rpoll */
3273 		ibv_req_notify_cq(rs->cm_id->recv_cq, 0);
3274 		ucma_shutdown(rs->cm_id);
3275 	}
3276 
3277 	return ret;
3278 }
3279 
ds_shutdown(struct rsocket * rs)3280 static void ds_shutdown(struct rsocket *rs)
3281 {
3282 	if (rs->opts & RS_OPT_SVC_ACTIVE)
3283 		rs_notify_svc(&udp_svc, rs, RS_SVC_REM_DGRAM);
3284 
3285 	if (rs->fd_flags & O_NONBLOCK)
3286 		rs_set_nonblocking(rs, 0);
3287 
3288 	rs->state &= ~(rs_readable | rs_writable);
3289 	ds_process_cqs(rs, 0, ds_all_sends_done);
3290 
3291 	if (rs->fd_flags & O_NONBLOCK)
3292 		rs_set_nonblocking(rs, rs->fd_flags);
3293 }
3294 
rclose(int socket)3295 int rclose(int socket)
3296 {
3297 	struct rsocket *rs;
3298 
3299 	rs = idm_lookup(&idm, socket);
3300 	if (!rs)
3301 		return EBADF;
3302 	if (rs->type == SOCK_STREAM) {
3303 		if (rs->state & rs_connected)
3304 			rshutdown(socket, SHUT_RDWR);
3305 		else if (rs->opts & RS_OPT_SVC_ACTIVE)
3306 			rs_notify_svc(&tcp_svc, rs, RS_SVC_REM_KEEPALIVE);
3307 	} else {
3308 		ds_shutdown(rs);
3309 	}
3310 
3311 	rs_free(rs);
3312 	return 0;
3313 }
3314 
/*
 * Copy a socket address into a caller buffer with getsockname()-style
 * truncation semantics.
 *
 * An AF_INET source is treated as a sockaddr_in; anything else is treated
 * as a sockaddr_in6. At most *len bytes are copied. On return *len holds
 * the full size of the address, which may exceed what was copied.
 */
static void rs_copy_addr(struct sockaddr *dst, struct sockaddr *src, socklen_t *len)
{
	socklen_t full = (src->sa_family == AF_INET) ?
			 (socklen_t) sizeof(struct sockaddr_in) :
			 (socklen_t) sizeof(struct sockaddr_in6);
	socklen_t copy = (*len < full) ? *len : full;

	*len = full;
	memcpy(dst, src, copy);
}
3328 
rgetpeername(int socket,struct sockaddr * addr,socklen_t * addrlen)3329 int rgetpeername(int socket, struct sockaddr *addr, socklen_t *addrlen)
3330 {
3331 	struct rsocket *rs;
3332 
3333 	rs = idm_lookup(&idm, socket);
3334 	if (!rs)
3335 		return ERR(EBADF);
3336 	if (rs->type == SOCK_STREAM) {
3337 		rs_copy_addr(addr, rdma_get_peer_addr(rs->cm_id), addrlen);
3338 		return 0;
3339 	} else {
3340 		return getpeername(rs->udp_sock, addr, addrlen);
3341 	}
3342 }
3343 
rgetsockname(int socket,struct sockaddr * addr,socklen_t * addrlen)3344 int rgetsockname(int socket, struct sockaddr *addr, socklen_t *addrlen)
3345 {
3346 	struct rsocket *rs;
3347 
3348 	rs = idm_lookup(&idm, socket);
3349 	if (!rs)
3350 		return ERR(EBADF);
3351 	if (rs->type == SOCK_STREAM) {
3352 		rs_copy_addr(addr, rdma_get_local_addr(rs->cm_id), addrlen);
3353 		return 0;
3354 	} else {
3355 		return getsockname(rs->udp_sock, addr, addrlen);
3356 	}
3357 }
3358 
rs_set_keepalive(struct rsocket * rs,int on)3359 static int rs_set_keepalive(struct rsocket *rs, int on)
3360 {
3361 	FILE *f;
3362 	int ret;
3363 
3364 	if ((on && (rs->opts & RS_OPT_SVC_ACTIVE)) ||
3365 	    (!on && !(rs->opts & RS_OPT_SVC_ACTIVE)))
3366 		return 0;
3367 
3368 	if (on) {
3369 		if (!rs->keepalive_time) {
3370 			if ((f = fopen("/proc/sys/net/ipv4/tcp_keepalive_time", "r"))) {
3371 				if (fscanf(f, "%u", &rs->keepalive_time) != 1)
3372 					rs->keepalive_time = 7200;
3373 				fclose(f);
3374 			} else {
3375 				rs->keepalive_time = 7200;
3376 			}
3377 		}
3378 		ret = rs_notify_svc(&tcp_svc, rs, RS_SVC_ADD_KEEPALIVE);
3379 	} else {
3380 		ret = rs_notify_svc(&tcp_svc, rs, RS_SVC_REM_KEEPALIVE);
3381 	}
3382 
3383 	return ret;
3384 }
3385 
rsetsockopt(int socket,int level,int optname,const void * optval,socklen_t optlen)3386 int rsetsockopt(int socket, int level, int optname,
3387 		const void *optval, socklen_t optlen)
3388 {
3389 	struct rsocket *rs;
3390 	int ret, opt_on = 0;
3391 	uint64_t *opts = NULL;
3392 
3393 	ret = ERR(ENOTSUP);
3394 	rs = idm_lookup(&idm, socket);
3395 	if (!rs)
3396 		return ERR(EBADF);
3397 	if (rs->type == SOCK_DGRAM && level != SOL_RDMA) {
3398 		ret = setsockopt(rs->udp_sock, level, optname, optval, optlen);
3399 		if (ret)
3400 			return ret;
3401 	}
3402 
3403 	switch (level) {
3404 	case SOL_SOCKET:
3405 		opts = &rs->so_opts;
3406 		switch (optname) {
3407 		case SO_REUSEADDR:
3408 			if (rs->type == SOCK_STREAM) {
3409 				ret = rdma_set_option(rs->cm_id, RDMA_OPTION_ID,
3410 						      RDMA_OPTION_ID_REUSEADDR,
3411 						      (void *) optval, optlen);
3412 				if (ret && ((errno == ENOSYS) || ((rs->state != rs_init) &&
3413 				    rs->cm_id->context &&
3414 				    (rs->cm_id->verbs->device->transport_type == IBV_TRANSPORT_IB))))
3415 					ret = 0;
3416 			}
3417 			opt_on = *(int *) optval;
3418 			break;
3419 		case SO_RCVBUF:
3420 			if ((rs->type == SOCK_STREAM && !rs->rbuf) ||
3421 			    (rs->type == SOCK_DGRAM && !rs->qp_list))
3422 				rs->rbuf_size = (*(uint32_t *) optval) << 1;
3423 			ret = 0;
3424 			break;
3425 		case SO_SNDBUF:
3426 			if (!rs->sbuf)
3427 				rs->sbuf_size = (*(uint32_t *) optval) << 1;
3428 			if (rs->sbuf_size < RS_SNDLOWAT)
3429 				rs->sbuf_size = RS_SNDLOWAT << 1;
3430 			ret = 0;
3431 			break;
3432 		case SO_LINGER:
3433 			/* Invert value so default so_opt = 0 is on */
3434 			opt_on =  !((struct linger *) optval)->l_onoff;
3435 			ret = 0;
3436 			break;
3437 		case SO_KEEPALIVE:
3438 			ret = rs_set_keepalive(rs, *(int *) optval);
3439 			opt_on = rs->opts & RS_OPT_SVC_ACTIVE;
3440 			break;
3441 		case SO_OOBINLINE:
3442 			opt_on = *(int *) optval;
3443 			ret = 0;
3444 			break;
3445 		default:
3446 			break;
3447 		}
3448 		break;
3449 	case IPPROTO_TCP:
3450 		opts = &rs->tcp_opts;
3451 		switch (optname) {
3452 		case TCP_KEEPCNT:
3453 		case TCP_KEEPINTVL:
3454 			ret = 0;   /* N/A - we're using a reliable connection */
3455 			break;
3456 		case TCP_KEEPIDLE:
3457 			if (*(int *) optval <= 0) {
3458 				ret = ERR(EINVAL);
3459 				break;
3460 			}
3461 			rs->keepalive_time = *(int *) optval;
3462 			ret = (rs->opts & RS_OPT_SVC_ACTIVE) ?
3463 			      rs_notify_svc(&tcp_svc, rs, RS_SVC_MOD_KEEPALIVE) : 0;
3464 			break;
3465 		case TCP_NODELAY:
3466 			opt_on = *(int *) optval;
3467 			ret = 0;
3468 			break;
3469 		case TCP_MAXSEG:
3470 			ret = 0;
3471 			break;
3472 		default:
3473 			break;
3474 		}
3475 		break;
3476 	case IPPROTO_IPV6:
3477 		opts = &rs->ipv6_opts;
3478 		switch (optname) {
3479 		case IPV6_V6ONLY:
3480 			if (rs->type == SOCK_STREAM) {
3481 				ret = rdma_set_option(rs->cm_id, RDMA_OPTION_ID,
3482 						      RDMA_OPTION_ID_AFONLY,
3483 						      (void *) optval, optlen);
3484 			}
3485 			opt_on = *(int *) optval;
3486 			break;
3487 		default:
3488 			break;
3489 		}
3490 		break;
3491 	case SOL_RDMA:
3492 		if (rs->state >= rs_opening) {
3493 			ret = ERR(EINVAL);
3494 			break;
3495 		}
3496 
3497 		switch (optname) {
3498 		case RDMA_SQSIZE:
3499 			rs->sq_size = min_t(uint32_t, (*(uint32_t *)optval),
3500 					    RS_QP_MAX_SIZE);
3501 			ret = 0;
3502 			break;
3503 		case RDMA_RQSIZE:
3504 			rs->rq_size = min_t(uint32_t, (*(uint32_t *)optval),
3505 					    RS_QP_MAX_SIZE);
3506 			ret = 0;
3507 			break;
3508 		case RDMA_INLINE:
3509 			rs->sq_inline = min_t(uint32_t, *(uint32_t *)optval,
3510 					      RS_QP_MAX_SIZE);
3511 			ret = 0;
3512 			break;
3513 		case RDMA_IOMAPSIZE:
3514 			rs->target_iomap_size = (uint16_t) rs_scale_to_value(
3515 				(uint8_t) rs_value_to_scale(*(int *) optval, 8), 8);
3516 			ret = 0;
3517 			break;
3518 		case RDMA_ROUTE:
3519 			if ((rs->optval = malloc(optlen))) {
3520 				memcpy(rs->optval, optval, optlen);
3521 				rs->optlen = optlen;
3522 				ret = 0;
3523 			} else {
3524 				ret = ERR(ENOMEM);
3525 			}
3526 			break;
3527 		default:
3528 			break;
3529 		}
3530 		break;
3531 	default:
3532 		break;
3533 	}
3534 
3535 	if (!ret && opts) {
3536 		if (opt_on)
3537 			*opts |= (1 << optname);
3538 		else
3539 			*opts &= ~(1 << optname);
3540 	}
3541 
3542 	return ret;
3543 }
3544 
rs_convert_sa_path(struct ibv_sa_path_rec * sa_path,struct ibv_path_data * path_data)3545 static void rs_convert_sa_path(struct ibv_sa_path_rec *sa_path,
3546 			       struct ibv_path_data *path_data)
3547 {
3548 	uint32_t fl_hop;
3549 
3550 	memset(path_data, 0, sizeof(*path_data));
3551 	path_data->path.dgid = sa_path->dgid;
3552 	path_data->path.sgid = sa_path->sgid;
3553 	path_data->path.dlid = sa_path->dlid;
3554 	path_data->path.slid = sa_path->slid;
3555 	fl_hop = be32toh(sa_path->flow_label) << 8;
3556 	path_data->path.flowlabel_hoplimit = htobe32(fl_hop | sa_path->hop_limit);
3557 	path_data->path.tclass = sa_path->traffic_class;
3558 	path_data->path.reversible_numpath = sa_path->reversible << 7 | 1;
3559 	path_data->path.pkey = sa_path->pkey;
3560 	path_data->path.qosclass_sl = htobe16(sa_path->sl);
3561 	path_data->path.mtu = sa_path->mtu | 2 << 6;	/* exactly */
3562 	path_data->path.rate = sa_path->rate | 2 << 6;
3563 	path_data->path.packetlifetime = sa_path->packet_life_time | 2 << 6;
3564 	path_data->flags= sa_path->preference;
3565 }
3566 
rgetsockopt(int socket,int level,int optname,void * optval,socklen_t * optlen)3567 int rgetsockopt(int socket, int level, int optname,
3568 		void *optval, socklen_t *optlen)
3569 {
3570 	struct rsocket *rs;
3571 	void *opt;
3572 	struct ibv_sa_path_rec *path_rec;
3573 	struct ibv_path_data path_data;
3574 	socklen_t len;
3575 	int ret = 0;
3576 	int num_paths;
3577 
3578 	rs = idm_lookup(&idm, socket);
3579 	if (!rs)
3580 		return ERR(EBADF);
3581 	switch (level) {
3582 	case SOL_SOCKET:
3583 		switch (optname) {
3584 		case SO_REUSEADDR:
3585 		case SO_KEEPALIVE:
3586 		case SO_OOBINLINE:
3587 			*((int *) optval) = !!(rs->so_opts & (1 << optname));
3588 			*optlen = sizeof(int);
3589 			break;
3590 		case SO_RCVBUF:
3591 			*((int *) optval) = rs->rbuf_size;
3592 			*optlen = sizeof(int);
3593 			break;
3594 		case SO_SNDBUF:
3595 			*((int *) optval) = rs->sbuf_size;
3596 			*optlen = sizeof(int);
3597 			break;
3598 		case SO_LINGER:
3599 			/* Value is inverted so default so_opt = 0 is on */
3600 			((struct linger *) optval)->l_onoff =
3601 					!(rs->so_opts & (1 << optname));
3602 			((struct linger *) optval)->l_linger = 0;
3603 			*optlen = sizeof(struct linger);
3604 			break;
3605 		case SO_ERROR:
3606 			*((int *) optval) = rs->err;
3607 			*optlen = sizeof(int);
3608 			rs->err = 0;
3609 			break;
3610 		default:
3611 			ret = ENOTSUP;
3612 			break;
3613 		}
3614 		break;
3615 	case IPPROTO_TCP:
3616 		switch (optname) {
3617 		case TCP_KEEPCNT:
3618 		case TCP_KEEPINTVL:
3619 			*((int *) optval) = 1;   /* N/A */
3620 			break;
3621 		case TCP_KEEPIDLE:
3622 			*((int *) optval) = (int) rs->keepalive_time;
3623 			*optlen = sizeof(int);
3624 			break;
3625 		case TCP_NODELAY:
3626 			*((int *) optval) = !!(rs->tcp_opts & (1 << optname));
3627 			*optlen = sizeof(int);
3628 			break;
3629 		case TCP_MAXSEG:
3630 			*((int *) optval) = (rs->cm_id && rs->cm_id->route.num_paths) ?
3631 					    1 << (7 + rs->cm_id->route.path_rec->mtu) :
3632 					    2048;
3633 			*optlen = sizeof(int);
3634 			break;
3635 		default:
3636 			ret = ENOTSUP;
3637 			break;
3638 		}
3639 		break;
3640 	case IPPROTO_IPV6:
3641 		switch (optname) {
3642 		case IPV6_V6ONLY:
3643 			*((int *) optval) = !!(rs->ipv6_opts & (1 << optname));
3644 			*optlen = sizeof(int);
3645 			break;
3646 		default:
3647 			ret = ENOTSUP;
3648 			break;
3649 		}
3650 		break;
3651 	case SOL_RDMA:
3652 		switch (optname) {
3653 		case RDMA_SQSIZE:
3654 			*((int *) optval) = rs->sq_size;
3655 			*optlen = sizeof(int);
3656 			break;
3657 		case RDMA_RQSIZE:
3658 			*((int *) optval) = rs->rq_size;
3659 			*optlen = sizeof(int);
3660 			break;
3661 		case RDMA_INLINE:
3662 			*((int *) optval) = rs->sq_inline;
3663 			*optlen = sizeof(int);
3664 			break;
3665 		case RDMA_IOMAPSIZE:
3666 			*((int *) optval) = rs->target_iomap_size;
3667 			*optlen = sizeof(int);
3668 			break;
3669 		case RDMA_ROUTE:
3670 			if (rs->optval) {
3671 				if (*optlen < rs->optlen) {
3672 					ret = EINVAL;
3673 				} else {
3674 					memcpy(rs->optval, optval, rs->optlen);
3675 					*optlen = rs->optlen;
3676 				}
3677 			} else {
3678 				if (*optlen < sizeof(path_data)) {
3679 					ret = EINVAL;
3680 				} else {
3681 					len = 0;
3682 					opt = optval;
3683 					path_rec = rs->cm_id->route.path_rec;
3684 					num_paths = 0;
3685 					while (len + sizeof(path_data) <= *optlen &&
3686 					       num_paths < rs->cm_id->route.num_paths) {
3687 						rs_convert_sa_path(path_rec, &path_data);
3688 						memcpy(opt, &path_data, sizeof(path_data));
3689 						len += sizeof(path_data);
3690 						opt += sizeof(path_data);
3691 						path_rec++;
3692 						num_paths++;
3693 					}
3694 					*optlen = len;
3695 					ret = 0;
3696 				}
3697 			}
3698 			break;
3699 		default:
3700 			ret = ENOTSUP;
3701 			break;
3702 		}
3703 		break;
3704 	default:
3705 		ret = ENOTSUP;
3706 		break;
3707 	}
3708 
3709 	return rdma_seterrno(ret);
3710 }
3711 
rfcntl(int socket,int cmd,...)3712 int rfcntl(int socket, int cmd, ... /* arg */ )
3713 {
3714 	struct rsocket *rs;
3715 	va_list args;
3716 	int param;
3717 	int ret = 0;
3718 
3719 	rs = idm_lookup(&idm, socket);
3720 	if (!rs)
3721 		return ERR(EBADF);
3722 	va_start(args, cmd);
3723 	switch (cmd) {
3724 	case F_GETFL:
3725 		ret = rs->fd_flags;
3726 		break;
3727 	case F_SETFL:
3728 		param = va_arg(args, int);
3729 		if ((rs->fd_flags & O_NONBLOCK) != (param & O_NONBLOCK))
3730 			ret = rs_set_nonblocking(rs, param & O_NONBLOCK);
3731 
3732 		if (!ret)
3733 			rs->fd_flags = param;
3734 		break;
3735 	default:
3736 		ret = ERR(ENOTSUP);
3737 		break;
3738 	}
3739 	va_end(args);
3740 	return ret;
3741 }
3742 
rs_get_iomap_mr(struct rsocket * rs)3743 static struct rs_iomap_mr *rs_get_iomap_mr(struct rsocket *rs)
3744 {
3745 	int i;
3746 
3747 	if (!rs->remote_iomappings) {
3748 		rs->remote_iomappings = calloc(rs->remote_iomap.length,
3749 					       sizeof(*rs->remote_iomappings));
3750 		if (!rs->remote_iomappings)
3751 			return NULL;
3752 
3753 		for (i = 0; i < rs->remote_iomap.length; i++)
3754 			rs->remote_iomappings[i].index = i;
3755 	}
3756 
3757 	for (i = 0; i < rs->remote_iomap.length; i++) {
3758 		if (!rs->remote_iomappings[i].mr)
3759 			return &rs->remote_iomappings[i];
3760 	}
3761 	return NULL;
3762 }
3763 
3764 /*
3765  * If an offset is given, we map to it.  If offset is -1, then we map the
3766  * offset to the address of buf.  We do not check for conflicts, which must
3767  * be fixed at some point.
3768  */
riomap(int socket,void * buf,size_t len,int prot,int flags,off_t offset)3769 off_t riomap(int socket, void *buf, size_t len, int prot, int flags, off_t offset)
3770 {
3771 	struct rsocket *rs;
3772 	struct rs_iomap_mr *iomr;
3773 	int access = IBV_ACCESS_LOCAL_WRITE;
3774 
3775 	rs = idm_at(&idm, socket);
3776 	if (!rs->cm_id->pd || (prot & ~(PROT_WRITE | PROT_NONE)))
3777 		return ERR(EINVAL);
3778 
3779 	fastlock_acquire(&rs->map_lock);
3780 	if (prot & PROT_WRITE) {
3781 		iomr = rs_get_iomap_mr(rs);
3782 		access |= IBV_ACCESS_REMOTE_WRITE;
3783 	} else {
3784 		iomr = calloc(1, sizeof(*iomr));
3785 		iomr->index = -1;
3786 	}
3787 	if (!iomr) {
3788 		offset = ERR(ENOMEM);
3789 		goto out;
3790 	}
3791 
3792 	iomr->mr = ibv_reg_mr(rs->cm_id->pd, buf, len, access);
3793 	if (!iomr->mr) {
3794 		if (iomr->index < 0)
3795 			free(iomr);
3796 		offset = -1;
3797 		goto out;
3798 	}
3799 
3800 	if (offset == -1)
3801 		offset = (uintptr_t) buf;
3802 	iomr->offset = offset;
3803 	atomic_store(&iomr->refcnt, 1);
3804 
3805 	if (iomr->index >= 0) {
3806 		dlist_insert_tail(&iomr->entry, &rs->iomap_queue);
3807 		rs->iomap_pending = 1;
3808 	} else {
3809 		dlist_insert_tail(&iomr->entry, &rs->iomap_list);
3810 	}
3811 out:
3812 	fastlock_release(&rs->map_lock);
3813 	return offset;
3814 }
3815 
riounmap(int socket,void * buf,size_t len)3816 int riounmap(int socket, void *buf, size_t len)
3817 {
3818 	struct rsocket *rs;
3819 	struct rs_iomap_mr *iomr;
3820 	dlist_entry *entry;
3821 	int ret = 0;
3822 
3823 	rs = idm_at(&idm, socket);
3824 	fastlock_acquire(&rs->map_lock);
3825 
3826 	for (entry = rs->iomap_list.next; entry != &rs->iomap_list;
3827 	     entry = entry->next) {
3828 		iomr = container_of(entry, struct rs_iomap_mr, entry);
3829 		if (iomr->mr->addr == buf && iomr->mr->length == len) {
3830 			rs_release_iomap_mr(iomr);
3831 			goto out;
3832 		}
3833 	}
3834 
3835 	for (entry = rs->iomap_queue.next; entry != &rs->iomap_queue;
3836 	     entry = entry->next) {
3837 		iomr = container_of(entry, struct rs_iomap_mr, entry);
3838 		if (iomr->mr->addr == buf && iomr->mr->length == len) {
3839 			rs_release_iomap_mr(iomr);
3840 			goto out;
3841 		}
3842 	}
3843 	ret = ERR(EINVAL);
3844 out:
3845 	fastlock_release(&rs->map_lock);
3846 	return ret;
3847 }
3848 
rs_find_iomap(struct rsocket * rs,off_t offset)3849 static struct rs_iomap *rs_find_iomap(struct rsocket *rs, off_t offset)
3850 {
3851 	int i;
3852 
3853 	for (i = 0; i < rs->target_iomap_size; i++) {
3854 		if (offset >= rs->target_iomap[i].offset &&
3855 		    offset < rs->target_iomap[i].offset + rs->target_iomap[i].sge.length)
3856 			return &rs->target_iomap[i];
3857 	}
3858 	return NULL;
3859 }
3860 
riowrite(int socket,const void * buf,size_t count,off_t offset,int flags)3861 size_t riowrite(int socket, const void *buf, size_t count, off_t offset, int flags)
3862 {
3863 	struct rsocket *rs;
3864 	struct rs_iomap *iom = NULL;
3865 	struct ibv_sge sge;
3866 	size_t left = count;
3867 	uint32_t xfer_size, olen = RS_OLAP_START_SIZE;
3868 	int ret = 0;
3869 
3870 	rs = idm_at(&idm, socket);
3871 	fastlock_acquire(&rs->slock);
3872 	if (rs->iomap_pending) {
3873 		ret = rs_send_iomaps(rs, flags);
3874 		if (ret)
3875 			goto out;
3876 	}
3877 	for (; left; left -= xfer_size, buf += xfer_size, offset += xfer_size) {
3878 		if (!iom || offset > iom->offset + iom->sge.length) {
3879 			iom = rs_find_iomap(rs, offset);
3880 			if (!iom)
3881 				break;
3882 		}
3883 
3884 		if (!rs_can_send(rs)) {
3885 			ret = rs_get_comp(rs, rs_nonblocking(rs, flags),
3886 					  rs_conn_can_send);
3887 			if (ret)
3888 				break;
3889 			if (!(rs->state & rs_writable)) {
3890 				ret = ERR(ECONNRESET);
3891 				break;
3892 			}
3893 		}
3894 
3895 		if (olen < left) {
3896 			xfer_size = olen;
3897 			if (olen < RS_MAX_TRANSFER)
3898 				olen <<= 1;
3899 		} else {
3900 			xfer_size = left;
3901 		}
3902 
3903 		if (xfer_size > rs->sbuf_bytes_avail)
3904 			xfer_size = rs->sbuf_bytes_avail;
3905 		if (xfer_size > iom->offset + iom->sge.length - offset)
3906 			xfer_size = iom->offset + iom->sge.length - offset;
3907 
3908 		if (xfer_size <= rs->sq_inline) {
3909 			sge.addr = (uintptr_t) buf;
3910 			sge.length = xfer_size;
3911 			sge.lkey = 0;
3912 			ret = rs_write_direct(rs, iom, offset, &sge, 1,
3913 					      xfer_size, IBV_SEND_INLINE);
3914 		} else if (xfer_size <= rs_sbuf_left(rs)) {
3915 			memcpy((void *) (uintptr_t) rs->ssgl[0].addr, buf, xfer_size);
3916 			rs->ssgl[0].length = xfer_size;
3917 			ret = rs_write_direct(rs, iom, offset, rs->ssgl, 1, xfer_size, 0);
3918 			if (xfer_size < rs_sbuf_left(rs))
3919 				rs->ssgl[0].addr += xfer_size;
3920 			else
3921 				rs->ssgl[0].addr = (uintptr_t) rs->sbuf;
3922 		} else {
3923 			rs->ssgl[0].length = rs_sbuf_left(rs);
3924 			memcpy((void *) (uintptr_t) rs->ssgl[0].addr, buf,
3925 				rs->ssgl[0].length);
3926 			rs->ssgl[1].length = xfer_size - rs->ssgl[0].length;
3927 			memcpy(rs->sbuf, buf + rs->ssgl[0].length, rs->ssgl[1].length);
3928 			ret = rs_write_direct(rs, iom, offset, rs->ssgl, 2, xfer_size, 0);
3929 			rs->ssgl[0].addr = (uintptr_t) rs->sbuf + rs->ssgl[1].length;
3930 		}
3931 		if (ret)
3932 			break;
3933 	}
3934 out:
3935 	fastlock_release(&rs->slock);
3936 
3937 	return (ret && left == count) ? ret : count - left;
3938 }
3939 
3940 /****************************************************************************
3941  * Service Processing Threads
3942  ****************************************************************************/
3943 
rs_svc_grow_sets(struct rs_svc * svc,int grow_size)3944 static int rs_svc_grow_sets(struct rs_svc *svc, int grow_size)
3945 {
3946 	struct rsocket **rss;
3947 	void *set, *contexts;
3948 
3949 	set = calloc(svc->size + grow_size, sizeof(*rss) + svc->context_size);
3950 	if (!set)
3951 		return ENOMEM;
3952 
3953 	svc->size += grow_size;
3954 	rss = set;
3955 	contexts = set + sizeof(*rss) * svc->size;
3956 	if (svc->cnt) {
3957 		memcpy(rss, svc->rss, sizeof(*rss) * (svc->cnt + 1));
3958 		memcpy(contexts, svc->contexts, svc->context_size * (svc->cnt + 1));
3959 	}
3960 
3961 	free(svc->rss);
3962 	svc->rss = rss;
3963 	svc->contexts = contexts;
3964 	return 0;
3965 }
3966 
3967 /*
3968  * Index 0 is reserved for the service's communication socket.
3969  */
rs_svc_add_rs(struct rs_svc * svc,struct rsocket * rs)3970 static int rs_svc_add_rs(struct rs_svc *svc, struct rsocket *rs)
3971 {
3972 	int ret;
3973 
3974 	if (svc->cnt >= svc->size - 1) {
3975 		ret = rs_svc_grow_sets(svc, 4);
3976 		if (ret)
3977 			return ret;
3978 	}
3979 
3980 	svc->rss[++svc->cnt] = rs;
3981 	return 0;
3982 }
3983 
rs_svc_index(struct rs_svc * svc,struct rsocket * rs)3984 static int rs_svc_index(struct rs_svc *svc, struct rsocket *rs)
3985 {
3986 	int i;
3987 
3988 	for (i = 1; i <= svc->cnt; i++) {
3989 		if (svc->rss[i] == rs)
3990 			return i;
3991 	}
3992 	return -1;
3993 }
3994 
rs_svc_rm_rs(struct rs_svc * svc,struct rsocket * rs)3995 static int rs_svc_rm_rs(struct rs_svc *svc, struct rsocket *rs)
3996 {
3997 	int i;
3998 
3999 	if ((i = rs_svc_index(svc, rs)) >= 0) {
4000 		svc->rss[i] = svc->rss[svc->cnt];
4001 		memcpy(svc->contexts + i * svc->context_size,
4002 		       svc->contexts + svc->cnt * svc->context_size,
4003 		       svc->context_size);
4004 		svc->cnt--;
4005 		return 0;
4006 	}
4007 	return EBADF;
4008 }
4009 
udp_svc_process_sock(struct rs_svc * svc)4010 static void udp_svc_process_sock(struct rs_svc *svc)
4011 {
4012 	struct rs_svc_msg msg;
4013 
4014 	read_all(svc->sock[1], &msg, sizeof msg);
4015 	switch (msg.cmd) {
4016 	case RS_SVC_ADD_DGRAM:
4017 		msg.status = rs_svc_add_rs(svc, msg.rs);
4018 		if (!msg.status) {
4019 			msg.rs->opts |= RS_OPT_SVC_ACTIVE;
4020 			udp_svc_fds = svc->contexts;
4021 			udp_svc_fds[svc->cnt].fd = msg.rs->udp_sock;
4022 			udp_svc_fds[svc->cnt].events = POLLIN;
4023 			udp_svc_fds[svc->cnt].revents = 0;
4024 		}
4025 		break;
4026 	case RS_SVC_REM_DGRAM:
4027 		msg.status = rs_svc_rm_rs(svc, msg.rs);
4028 		if (!msg.status)
4029 			msg.rs->opts &= ~RS_OPT_SVC_ACTIVE;
4030 		break;
4031 	case RS_SVC_NOOP:
4032 		msg.status = 0;
4033 		break;
4034 	default:
4035 		break;
4036 	}
4037 
4038 	write_all(svc->sock[1], &msg, sizeof msg);
4039 }
4040 
udp_svc_sgid_index(struct ds_dest * dest,union ibv_gid * sgid)4041 static uint8_t udp_svc_sgid_index(struct ds_dest *dest, union ibv_gid *sgid)
4042 {
4043 	union ibv_gid gid;
4044 	int i;
4045 
4046 	for (i = 0; i < 16; i++) {
4047 		ibv_query_gid(dest->qp->cm_id->verbs, dest->qp->cm_id->port_num,
4048 			      i, &gid);
4049 		if (!memcmp(sgid, &gid, sizeof gid))
4050 			return i;
4051 	}
4052 	return 0;
4053 }
4054 
udp_svc_path_bits(struct ds_dest * dest)4055 static uint8_t udp_svc_path_bits(struct ds_dest *dest)
4056 {
4057 	struct ibv_port_attr attr;
4058 
4059 	if (!ibv_query_port(dest->qp->cm_id->verbs, dest->qp->cm_id->port_num, &attr))
4060 		return (uint8_t) ((1 << attr.lmc) - 1);
4061 	return 0x7f;
4062 }
4063 
udp_svc_create_ah(struct rsocket * rs,struct ds_dest * dest,uint32_t qpn)4064 static void udp_svc_create_ah(struct rsocket *rs, struct ds_dest *dest, uint32_t qpn)
4065 {
4066 	union socket_addr saddr;
4067 	struct rdma_cm_id *id;
4068 	struct ibv_ah_attr attr;
4069 	int ret;
4070 
4071 	if (dest->ah) {
4072 		fastlock_acquire(&rs->slock);
4073 		ibv_destroy_ah(dest->ah);
4074 		dest->ah = NULL;
4075 		fastlock_release(&rs->slock);
4076 	}
4077 
4078 	ret = rdma_create_id(NULL, &id, NULL, dest->qp->cm_id->ps);
4079 	if  (ret)
4080 		return;
4081 
4082 	memcpy(&saddr, rdma_get_local_addr(dest->qp->cm_id),
4083 	       ucma_addrlen(rdma_get_local_addr(dest->qp->cm_id)));
4084 	if (saddr.sa.sa_family == AF_INET)
4085 		saddr.sin.sin_port = 0;
4086 	else
4087 		saddr.sin6.sin6_port = 0;
4088 	ret = rdma_resolve_addr(id, &saddr.sa, &dest->addr.sa, 2000);
4089 	if (ret)
4090 		goto out;
4091 
4092 	ret = rdma_resolve_route(id, 2000);
4093 	if (ret)
4094 		goto out;
4095 
4096 	memset(&attr, 0, sizeof attr);
4097 	if (id->route.path_rec->hop_limit > 1) {
4098 		attr.is_global = 1;
4099 		attr.grh.dgid = id->route.path_rec->dgid;
4100 		attr.grh.flow_label = be32toh(id->route.path_rec->flow_label);
4101 		attr.grh.sgid_index = udp_svc_sgid_index(dest, &id->route.path_rec->sgid);
4102 		attr.grh.hop_limit = id->route.path_rec->hop_limit;
4103 		attr.grh.traffic_class = id->route.path_rec->traffic_class;
4104 	}
4105 	attr.dlid = be16toh(id->route.path_rec->dlid);
4106 	attr.sl = id->route.path_rec->sl;
4107 	attr.src_path_bits = be16toh(id->route.path_rec->slid) & udp_svc_path_bits(dest);
4108 	attr.static_rate = id->route.path_rec->rate;
4109 	attr.port_num  = id->port_num;
4110 
4111 	fastlock_acquire(&rs->slock);
4112 	dest->qpn = qpn;
4113 	dest->ah = ibv_create_ah(dest->qp->cm_id->pd, &attr);
4114 	fastlock_release(&rs->slock);
4115 out:
4116 	rdma_destroy_id(id);
4117 }
4118 
udp_svc_valid_udp_hdr(struct ds_udp_header * udp_hdr,union socket_addr * addr)4119 static int udp_svc_valid_udp_hdr(struct ds_udp_header *udp_hdr,
4120 				 union socket_addr *addr)
4121 {
4122 	return (udp_hdr->tag == htobe32(DS_UDP_TAG)) &&
4123 		((udp_hdr->version == 4 && addr->sa.sa_family == AF_INET &&
4124 		  udp_hdr->length == DS_UDP_IPV4_HDR_LEN) ||
4125 		 (udp_hdr->version == 6 && addr->sa.sa_family == AF_INET6 &&
4126 		  udp_hdr->length == DS_UDP_IPV6_HDR_LEN));
4127 }
4128 
udp_svc_forward(struct rsocket * rs,void * buf,size_t len,union socket_addr * src)4129 static void udp_svc_forward(struct rsocket *rs, void *buf, size_t len,
4130 			    union socket_addr *src)
4131 {
4132 	struct ds_header hdr;
4133 	struct ds_smsg *msg;
4134 	struct ibv_sge sge;
4135 	uint64_t offset;
4136 
4137 	if (!ds_can_send(rs)) {
4138 		if (ds_get_comp(rs, 0, ds_can_send))
4139 			return;
4140 	}
4141 
4142 	msg = rs->smsg_free;
4143 	rs->smsg_free = msg->next;
4144 	rs->sqe_avail--;
4145 
4146 	ds_format_hdr(&hdr, src);
4147 	memcpy((void *) msg, &hdr, hdr.length);
4148 	memcpy((void *) msg + hdr.length, buf, len);
4149 	sge.addr = (uintptr_t) msg;
4150 	sge.length = hdr.length + len;
4151 	sge.lkey = rs->conn_dest->qp->smr->lkey;
4152 	offset = (uint8_t *) msg - rs->sbuf;
4153 
4154 	ds_post_send(rs, &sge, offset);
4155 }
4156 
udp_svc_process_rs(struct rsocket * rs)4157 static void udp_svc_process_rs(struct rsocket *rs)
4158 {
4159 	static uint8_t buf[RS_SNDLOWAT];
4160 	struct ds_dest *dest, *cur_dest;
4161 	struct ds_udp_header *udp_hdr;
4162 	union socket_addr addr;
4163 	socklen_t addrlen = sizeof addr;
4164 	int len, ret;
4165 	uint32_t qpn;
4166 
4167 	ret = recvfrom(rs->udp_sock, buf, sizeof buf, 0, &addr.sa, &addrlen);
4168 	if (ret < DS_UDP_IPV4_HDR_LEN)
4169 		return;
4170 
4171 	udp_hdr = (struct ds_udp_header *) buf;
4172 	if (!udp_svc_valid_udp_hdr(udp_hdr, &addr))
4173 		return;
4174 
4175 	len = ret - udp_hdr->length;
4176 	qpn = be32toh(udp_hdr->qpn) & 0xFFFFFF;
4177 
4178 	udp_hdr->tag = (__force __be32)be32toh(udp_hdr->tag);
4179 	udp_hdr->qpn = (__force __be32)qpn;
4180 
4181 	ret = ds_get_dest(rs, &addr.sa, addrlen, &dest);
4182 	if (ret)
4183 		return;
4184 
4185 	if (udp_hdr->op == RS_OP_DATA) {
4186 		fastlock_acquire(&rs->slock);
4187 		cur_dest = rs->conn_dest;
4188 		rs->conn_dest = dest;
4189 		ds_send_udp(rs, NULL, 0, 0, RS_OP_CTRL);
4190 		rs->conn_dest = cur_dest;
4191 		fastlock_release(&rs->slock);
4192 	}
4193 
4194 	if (!dest->ah || (dest->qpn != qpn))
4195 		udp_svc_create_ah(rs, dest, qpn);
4196 
4197 	/* to do: handle when dest local ip address doesn't match udp ip */
4198 	if (udp_hdr->op == RS_OP_DATA) {
4199 		fastlock_acquire(&rs->slock);
4200 		cur_dest = rs->conn_dest;
4201 		rs->conn_dest = &dest->qp->dest;
4202 		udp_svc_forward(rs, buf + udp_hdr->length, len, &addr);
4203 		rs->conn_dest = cur_dest;
4204 		fastlock_release(&rs->slock);
4205 	}
4206 }
4207 
udp_svc_run(void * arg)4208 static void *udp_svc_run(void *arg)
4209 {
4210 	struct rs_svc *svc = arg;
4211 	struct rs_svc_msg msg;
4212 	int i, ret;
4213 
4214 	ret = rs_svc_grow_sets(svc, 4);
4215 	if (ret) {
4216 		msg.status = ret;
4217 		write_all(svc->sock[1], &msg, sizeof msg);
4218 		return (void *) (uintptr_t) ret;
4219 	}
4220 
4221 	udp_svc_fds = svc->contexts;
4222 	udp_svc_fds[0].fd = svc->sock[1];
4223 	udp_svc_fds[0].events = POLLIN;
4224 	do {
4225 		for (i = 0; i <= svc->cnt; i++)
4226 			udp_svc_fds[i].revents = 0;
4227 
4228 		poll(udp_svc_fds, svc->cnt + 1, -1);
4229 		if (udp_svc_fds[0].revents)
4230 			udp_svc_process_sock(svc);
4231 
4232 		for (i = 1; i <= svc->cnt; i++) {
4233 			if (udp_svc_fds[i].revents)
4234 				udp_svc_process_rs(svc->rss[i]);
4235 		}
4236 	} while (svc->cnt >= 1);
4237 
4238 	return NULL;
4239 }
4240 
/*
 * Current wall-clock time in whole seconds, truncated to 32 bits.
 * The keep-alive service uses it to compute and compare deadlines.
 * NOTE(review): gettimeofday() follows wall-clock adjustments (e.g. the
 * system time being stepped), so pending deadlines can shift when the
 * clock changes; a monotonic source may be preferable -- confirm first.
 */
static uint32_t rs_get_time(void)
{
	struct timeval tv = { 0 };

	(void) gettimeofday(&tv, NULL);
	return (uint32_t) tv.tv_sec;
}
4249 
tcp_svc_process_sock(struct rs_svc * svc)4250 static void tcp_svc_process_sock(struct rs_svc *svc)
4251 {
4252 	struct rs_svc_msg msg;
4253 	int i;
4254 
4255 	read_all(svc->sock[1], &msg, sizeof msg);
4256 	switch (msg.cmd) {
4257 	case RS_SVC_ADD_KEEPALIVE:
4258 		msg.status = rs_svc_add_rs(svc, msg.rs);
4259 		if (!msg.status) {
4260 			msg.rs->opts |= RS_OPT_SVC_ACTIVE;
4261 			tcp_svc_timeouts = svc->contexts;
4262 			tcp_svc_timeouts[svc->cnt] = rs_get_time() +
4263 						     msg.rs->keepalive_time;
4264 		}
4265 		break;
4266 	case RS_SVC_REM_KEEPALIVE:
4267 		msg.status = rs_svc_rm_rs(svc, msg.rs);
4268 		if (!msg.status)
4269 			msg.rs->opts &= ~RS_OPT_SVC_ACTIVE;
4270 		break;
4271 	case RS_SVC_MOD_KEEPALIVE:
4272 		i = rs_svc_index(svc, msg.rs);
4273 		if (i >= 0) {
4274 			tcp_svc_timeouts[i] = rs_get_time() + msg.rs->keepalive_time;
4275 			msg.status = 0;
4276 		} else {
4277 			msg.status = EBADF;
4278 		}
4279 		break;
4280 	case RS_SVC_NOOP:
4281 		msg.status = 0;
4282 		break;
4283 	default:
4284 		break;
4285 	}
4286 	write_all(svc->sock[1], &msg, sizeof msg);
4287 }
4288 
4289 /*
4290  * Send a 0 byte RDMA write with immediate as keep-alive message.
4291  * This avoids the need for the receive side to do any acknowledgment.
4292  */
tcp_svc_send_keepalive(struct rsocket * rs)4293 static void tcp_svc_send_keepalive(struct rsocket *rs)
4294 {
4295 	fastlock_acquire(&rs->cq_lock);
4296 	if (rs_ctrl_avail(rs) && (rs->state & rs_connected)) {
4297 		rs->ctrl_seqno++;
4298 		rs_post_write(rs, NULL, 0, rs_msg_set(RS_OP_CTRL, RS_CTRL_KEEPALIVE),
4299 			      0, (uintptr_t) NULL, (uintptr_t) NULL);
4300 	}
4301 	fastlock_release(&rs->cq_lock);
4302 }
4303 
tcp_svc_run(void * arg)4304 static void *tcp_svc_run(void *arg)
4305 {
4306 	struct rs_svc *svc = arg;
4307 	struct rs_svc_msg msg;
4308 	struct pollfd fds;
4309 	uint32_t now, next_timeout;
4310 	int i, ret, timeout;
4311 
4312 	ret = rs_svc_grow_sets(svc, 16);
4313 	if (ret) {
4314 		msg.status = ret;
4315 		write_all(svc->sock[1], &msg, sizeof msg);
4316 		return (void *) (uintptr_t) ret;
4317 	}
4318 
4319 	tcp_svc_timeouts = svc->contexts;
4320 	fds.fd = svc->sock[1];
4321 	fds.events = POLLIN;
4322 	timeout = -1;
4323 	do {
4324 		poll(&fds, 1, timeout * 1000);
4325 		if (fds.revents)
4326 			tcp_svc_process_sock(svc);
4327 
4328 		now = rs_get_time();
4329 		next_timeout = ~0;
4330 		for (i = 1; i <= svc->cnt; i++) {
4331 			if (tcp_svc_timeouts[i] <= now) {
4332 				tcp_svc_send_keepalive(svc->rss[i]);
4333 				tcp_svc_timeouts[i] =
4334 					now + svc->rss[i]->keepalive_time;
4335 			}
4336 			if (tcp_svc_timeouts[i] < next_timeout)
4337 				next_timeout = tcp_svc_timeouts[i];
4338 		}
4339 		timeout = (int) (next_timeout - now);
4340 	} while (svc->cnt >= 1);
4341 
4342 	return NULL;
4343 }
4344