xref: /linux/fs/smb/client/smbdirect.c (revision 4a93d1ee2d0206970b6eb13fbffe07938cd95948)
1 // SPDX-License-Identifier: GPL-2.0-or-later
2 /*
3  *   Copyright (C) 2017, Microsoft Corporation.
4  *
5  *   Author(s): Long Li <longli@microsoft.com>
6  */
7 #include <linux/module.h>
8 #include <linux/highmem.h>
9 #include <linux/folio_queue.h>
10 #define __SMBDIRECT_SOCKET_DISCONNECT(__sc) smbd_disconnect_rdma_connection(__sc)
11 #include "../common/smbdirect/smbdirect_pdu.h"
12 #include "smbdirect.h"
13 #include "cifs_debug.h"
14 #include "cifsproto.h"
15 #include "smb2proto.h"
16 
17 const struct smbdirect_socket_parameters *smbd_get_parameters(struct smbd_connection *conn)
18 {
19 	struct smbdirect_socket *sc = &conn->socket;
20 
21 	return &sc->parameters;
22 }
23 
24 static struct smbdirect_recv_io *get_receive_buffer(
25 		struct smbdirect_socket *sc);
26 static void put_receive_buffer(
27 		struct smbdirect_socket *sc,
28 		struct smbdirect_recv_io *response);
29 static int allocate_receive_buffers(struct smbdirect_socket *sc, int num_buf);
30 static void destroy_receive_buffers(struct smbdirect_socket *sc);
31 
32 static void enqueue_reassembly(
33 		struct smbdirect_socket *sc,
34 		struct smbdirect_recv_io *response, int data_length);
35 static struct smbdirect_recv_io *_get_first_reassembly(
36 		struct smbdirect_socket *sc);
37 
38 static int smbd_post_send(struct smbdirect_socket *sc,
39 			  struct smbdirect_send_batch *batch,
40 			  struct smbdirect_send_io *request);
41 
42 static int smbd_post_recv(
43 		struct smbdirect_socket *sc,
44 		struct smbdirect_recv_io *response);
45 
46 static int smbd_post_send_empty(struct smbdirect_socket *sc);
47 
48 static void destroy_mr_list(struct smbdirect_socket *sc);
49 static int allocate_mr_list(struct smbdirect_socket *sc);
50 
51 struct smb_extract_to_rdma {
52 	struct ib_sge		*sge;
53 	unsigned int		nr_sge;
54 	unsigned int		max_sge;
55 	struct ib_device	*device;
56 	u32			local_dma_lkey;
57 	enum dma_data_direction	direction;
58 };
59 static ssize_t smb_extract_iter_to_rdma(struct iov_iter *iter, size_t len,
60 					struct smb_extract_to_rdma *rdma);
61 
62 /* Port numbers for SMBD transport */
63 #define SMB_PORT	445
64 #define SMBD_PORT	5445
65 
66 /* Address lookup and resolve timeout in ms */
67 #define RDMA_RESOLVE_TIMEOUT	5000
68 
69 /* SMBD negotiation timeout in seconds */
70 #define SMBD_NEGOTIATE_TIMEOUT	120
71 
72 /* The timeout to wait for a keepalive message from peer in seconds */
73 #define KEEPALIVE_RECV_TIMEOUT 5
74 
75 /* SMBD minimum receive size and fragmented size defined in [MS-SMBD] */
76 #define SMBD_MIN_RECEIVE_SIZE		128
77 #define SMBD_MIN_FRAGMENTED_SIZE	131072
78 
79 /*
80  * Default maximum number of outstanding RDMA read/write operations on this connection
81  * This value may be decreased during QP creation to match the hardware limit
82  */
83 #define SMBD_CM_RESPONDER_RESOURCES	32
84 
85 /* Maximum number of retries on data transfer operations */
86 #define SMBD_CM_RETRY			6
87 /* No need to retry on Receiver Not Ready since SMBD manages credits */
88 #define SMBD_CM_RNR_RETRY		0
89 
90 /*
91  * User configurable initial values per SMBD transport connection
92  * as defined in [MS-SMBD] 3.1.1.1
93  * These may change after SMBD negotiation
94  */
95 /* The maximum number of credits the local peer grants to the remote peer */
96 int smbd_receive_credit_max = 255;
97 
98 /* The number of send credits the local peer requests from the remote peer */
99 int smbd_send_credit_target = 255;
100 
101 /* The maximum single-message size that can be sent to the remote peer */
102 int smbd_max_send_size = 1364;
103 
104 /*
105  * The maximum fragmented upper-layer payload receive size supported
106  *
107  * Assume max_payload_per_credit is
108  * smbd_max_receive_size - 24 = 1340
109  *
110  * The maximum number would be
111  * smbd_receive_credit_max * max_payload_per_credit
112  *
113  *                       1340 * 255 = 341700 (0x536C4)
114  *
115  * The minimum value from the spec is 131072 (0x20000)
116  *
117  * For now we use the logic we used in ksmbd before:
118  *                 (1364 * 255) / 2 = 173910 (0x2A756)
119  */
120 int smbd_max_fragmented_recv_size = (1364 * 255) / 2;
121 
122 /*  The maximum single-message size which can be received */
123 int smbd_max_receive_size = 1364;
124 
125 /* The timeout to initiate send of a keepalive message on idle */
126 int smbd_keep_alive_interval = 120;
127 
128 /*
129  * User configurable initial values for RDMA transport
130  * The actual values used may be lower and are limited to hardware capabilities
131  */
132 /* Default maximum number of pages in a single RDMA write/read */
133 int smbd_max_frmr_depth = 2048;
134 
135 /* If the payload is smaller than this many bytes, use RDMA send/recv instead of read/write */
136 int rdma_readwrite_threshold = 4096;
137 
138 /* Transport logging functions
139  * Log messages are defined as classes. They can be OR'ed together to select
140  * what gets logged via the module parameter smbd_logging_class
141  * e.g. cifs.smbd_logging_class=0xa0 will log all log_rdma_recv() and
142  * log_rdma_event()
143  */
144 #define LOG_OUTGOING			0x1
145 #define LOG_INCOMING			0x2
146 #define LOG_READ			0x4
147 #define LOG_WRITE			0x8
148 #define LOG_RDMA_SEND			0x10
149 #define LOG_RDMA_RECV			0x20
150 #define LOG_KEEP_ALIVE			0x40
151 #define LOG_RDMA_EVENT			0x80
152 #define LOG_RDMA_MR			0x100
153 static unsigned int smbd_logging_class;
154 module_param(smbd_logging_class, uint, 0644);
155 MODULE_PARM_DESC(smbd_logging_class,
156 	"Logging class for SMBD transport 0x0 to 0x100");
157 
158 #define ERR		0x0
159 #define INFO		0x1
160 static unsigned int smbd_logging_level = ERR;
161 module_param(smbd_logging_level, uint, 0644);
162 MODULE_PARM_DESC(smbd_logging_level,
163 	"Logging level for SMBD transport, 0 (default): error, 1: info");
164 
165 #define log_rdma(level, class, fmt, args...)				\
166 do {									\
167 	if (level <= smbd_logging_level || class & smbd_logging_class)	\
168 		cifs_dbg(VFS, "%s:%d " fmt, __func__, __LINE__, ##args);\
169 } while (0)
170 
171 #define log_outgoing(level, fmt, args...) \
172 		log_rdma(level, LOG_OUTGOING, fmt, ##args)
173 #define log_incoming(level, fmt, args...) \
174 		log_rdma(level, LOG_INCOMING, fmt, ##args)
175 #define log_read(level, fmt, args...)	log_rdma(level, LOG_READ, fmt, ##args)
176 #define log_write(level, fmt, args...)	log_rdma(level, LOG_WRITE, fmt, ##args)
177 #define log_rdma_send(level, fmt, args...) \
178 		log_rdma(level, LOG_RDMA_SEND, fmt, ##args)
179 #define log_rdma_recv(level, fmt, args...) \
180 		log_rdma(level, LOG_RDMA_RECV, fmt, ##args)
181 #define log_keep_alive(level, fmt, args...) \
182 		log_rdma(level, LOG_KEEP_ALIVE, fmt, ##args)
183 #define log_rdma_event(level, fmt, args...) \
184 		log_rdma(level, LOG_RDMA_EVENT, fmt, ##args)
185 #define log_rdma_mr(level, fmt, args...) \
186 		log_rdma(level, LOG_RDMA_MR, fmt, ##args)
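
/*
 * A small worked example of the log_rdma() filter above, with purely
 * illustrative parameter values: with smbd_logging_level=0 (ERR) and
 * smbd_logging_class=0x110 (LOG_RDMA_SEND | LOG_RDMA_MR),
 * log_rdma_send(INFO, ...) is emitted because its class matches,
 * log_rdma_mr(ERR, ...) is emitted because ERR always passes the level
 * check, and log_read(INFO, ...) is suppressed because neither the
 * class (0x4) nor the level (INFO) matches the configuration.
 */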
187 
188 static void smbd_disconnect_wake_up_all(struct smbdirect_socket *sc)
189 {
190 	/*
191 	 * Wake up all waiters in all wait queues
192 	 * in order to notice the broken connection.
193 	 */
194 	wake_up_all(&sc->status_wait);
195 	wake_up_all(&sc->send_io.lcredits.wait_queue);
196 	wake_up_all(&sc->send_io.credits.wait_queue);
197 	wake_up_all(&sc->send_io.pending.dec_wait_queue);
198 	wake_up_all(&sc->send_io.pending.zero_wait_queue);
199 	wake_up_all(&sc->recv_io.reassembly.wait_queue);
200 	wake_up_all(&sc->mr_io.ready.wait_queue);
201 	wake_up_all(&sc->mr_io.cleanup.wait_queue);
202 }
203 
204 static void smbd_disconnect_rdma_work(struct work_struct *work)
205 {
206 	struct smbdirect_socket *sc =
207 		container_of(work, struct smbdirect_socket, disconnect_work);
208 
209 	if (sc->first_error == 0)
210 		sc->first_error = -ECONNABORTED;
211 
212 	/*
213 	 * Make sure this and other work is not queued again.
214 	 * We don't block here, so we avoid
215 	 * disable[_delayed]_work_sync().
216 	 */
217 	disable_work(&sc->disconnect_work);
218 	disable_work(&sc->recv_io.posted.refill_work);
219 	disable_work(&sc->mr_io.recovery_work);
220 	disable_work(&sc->idle.immediate_work);
221 	disable_delayed_work(&sc->idle.timer_work);
222 
223 	switch (sc->status) {
224 	case SMBDIRECT_SOCKET_NEGOTIATE_NEEDED:
225 	case SMBDIRECT_SOCKET_NEGOTIATE_RUNNING:
226 	case SMBDIRECT_SOCKET_NEGOTIATE_FAILED:
227 	case SMBDIRECT_SOCKET_CONNECTED:
228 	case SMBDIRECT_SOCKET_ERROR:
229 		sc->status = SMBDIRECT_SOCKET_DISCONNECTING;
230 		rdma_disconnect(sc->rdma.cm_id);
231 		break;
232 
233 	case SMBDIRECT_SOCKET_CREATED:
234 	case SMBDIRECT_SOCKET_RESOLVE_ADDR_NEEDED:
235 	case SMBDIRECT_SOCKET_RESOLVE_ADDR_RUNNING:
236 	case SMBDIRECT_SOCKET_RESOLVE_ADDR_FAILED:
237 	case SMBDIRECT_SOCKET_RESOLVE_ROUTE_NEEDED:
238 	case SMBDIRECT_SOCKET_RESOLVE_ROUTE_RUNNING:
239 	case SMBDIRECT_SOCKET_RESOLVE_ROUTE_FAILED:
240 	case SMBDIRECT_SOCKET_RDMA_CONNECT_NEEDED:
241 	case SMBDIRECT_SOCKET_RDMA_CONNECT_RUNNING:
242 	case SMBDIRECT_SOCKET_RDMA_CONNECT_FAILED:
243 		/*
244 		 * rdma_connect() never reached
245 		 * RDMA_CM_EVENT_ESTABLISHED
246 		 */
247 		sc->status = SMBDIRECT_SOCKET_DISCONNECTED;
248 		break;
249 
250 	case SMBDIRECT_SOCKET_DISCONNECTING:
251 	case SMBDIRECT_SOCKET_DISCONNECTED:
252 	case SMBDIRECT_SOCKET_DESTROYED:
253 		break;
254 	}
255 
256 	/*
257 	 * Wake up all waiters in all wait queues
258 	 * in order to notice the broken connection.
259 	 */
260 	smbd_disconnect_wake_up_all(sc);
261 }
262 
263 static void smbd_disconnect_rdma_connection(struct smbdirect_socket *sc)
264 {
265 	if (sc->first_error == 0)
266 		sc->first_error = -ECONNABORTED;
267 
268 	/*
269 	 * Make sure other work (than disconnect_work) is
270 	 * not queued again. We don't block here, so we avoid
271 	 * disable[_delayed]_work_sync().
272 	 */
273 	disable_work(&sc->recv_io.posted.refill_work);
274 	disable_work(&sc->mr_io.recovery_work);
275 	disable_work(&sc->idle.immediate_work);
276 	disable_delayed_work(&sc->idle.timer_work);
277 
278 	switch (sc->status) {
279 	case SMBDIRECT_SOCKET_RESOLVE_ADDR_FAILED:
280 	case SMBDIRECT_SOCKET_RESOLVE_ROUTE_FAILED:
281 	case SMBDIRECT_SOCKET_RDMA_CONNECT_FAILED:
282 	case SMBDIRECT_SOCKET_NEGOTIATE_FAILED:
283 	case SMBDIRECT_SOCKET_ERROR:
284 	case SMBDIRECT_SOCKET_DISCONNECTING:
285 	case SMBDIRECT_SOCKET_DISCONNECTED:
286 	case SMBDIRECT_SOCKET_DESTROYED:
287 		/*
288 		 * Keep the current error status
289 		 */
290 		break;
291 
292 	case SMBDIRECT_SOCKET_RESOLVE_ADDR_NEEDED:
293 	case SMBDIRECT_SOCKET_RESOLVE_ADDR_RUNNING:
294 		sc->status = SMBDIRECT_SOCKET_RESOLVE_ADDR_FAILED;
295 		break;
296 
297 	case SMBDIRECT_SOCKET_RESOLVE_ROUTE_NEEDED:
298 	case SMBDIRECT_SOCKET_RESOLVE_ROUTE_RUNNING:
299 		sc->status = SMBDIRECT_SOCKET_RESOLVE_ROUTE_FAILED;
300 		break;
301 
302 	case SMBDIRECT_SOCKET_RDMA_CONNECT_NEEDED:
303 	case SMBDIRECT_SOCKET_RDMA_CONNECT_RUNNING:
304 		sc->status = SMBDIRECT_SOCKET_RDMA_CONNECT_FAILED;
305 		break;
306 
307 	case SMBDIRECT_SOCKET_NEGOTIATE_NEEDED:
308 	case SMBDIRECT_SOCKET_NEGOTIATE_RUNNING:
309 		sc->status = SMBDIRECT_SOCKET_NEGOTIATE_FAILED;
310 		break;
311 
312 	case SMBDIRECT_SOCKET_CREATED:
313 		sc->status = SMBDIRECT_SOCKET_DISCONNECTED;
314 		break;
315 
316 	case SMBDIRECT_SOCKET_CONNECTED:
317 		sc->status = SMBDIRECT_SOCKET_ERROR;
318 		break;
319 	}
320 
321 	/*
322 	 * Wake up all waiters in all wait queues
323 	 * in order to notice the broken connection.
324 	 */
325 	smbd_disconnect_wake_up_all(sc);
326 
327 	queue_work(sc->workqueue, &sc->disconnect_work);
328 }
329 
330 /* Upcall from RDMA CM */
331 static int smbd_conn_upcall(
332 		struct rdma_cm_id *id, struct rdma_cm_event *event)
333 {
334 	struct smbdirect_socket *sc = id->context;
335 	struct smbdirect_socket_parameters *sp = &sc->parameters;
336 	const char *event_name = rdma_event_msg(event->event);
337 	u8 peer_initiator_depth;
338 	u8 peer_responder_resources;
339 
340 	log_rdma_event(INFO, "event=%s status=%d\n",
341 		event_name, event->status);
342 
343 	switch (event->event) {
344 	case RDMA_CM_EVENT_ADDR_RESOLVED:
345 		if (SMBDIRECT_CHECK_STATUS_DISCONNECT(sc, SMBDIRECT_SOCKET_RESOLVE_ADDR_RUNNING))
346 			break;
347 		sc->status = SMBDIRECT_SOCKET_RESOLVE_ROUTE_NEEDED;
348 		wake_up(&sc->status_wait);
349 		break;
350 
351 	case RDMA_CM_EVENT_ROUTE_RESOLVED:
352 		if (SMBDIRECT_CHECK_STATUS_DISCONNECT(sc, SMBDIRECT_SOCKET_RESOLVE_ROUTE_RUNNING))
353 			break;
354 		sc->status = SMBDIRECT_SOCKET_RDMA_CONNECT_NEEDED;
355 		wake_up(&sc->status_wait);
356 		break;
357 
358 	case RDMA_CM_EVENT_ADDR_ERROR:
359 		log_rdma_event(ERR, "connecting failed event=%s\n", event_name);
360 		sc->status = SMBDIRECT_SOCKET_RESOLVE_ADDR_FAILED;
361 		smbd_disconnect_rdma_work(&sc->disconnect_work);
362 		break;
363 
364 	case RDMA_CM_EVENT_ROUTE_ERROR:
365 		log_rdma_event(ERR, "connecting failed event=%s\n", event_name);
366 		sc->status = SMBDIRECT_SOCKET_RESOLVE_ROUTE_FAILED;
367 		smbd_disconnect_rdma_work(&sc->disconnect_work);
368 		break;
369 
370 	case RDMA_CM_EVENT_ESTABLISHED:
371 		log_rdma_event(INFO, "connected event=%s\n", event_name);
372 
373 		/*
374 		 * Here we work around an inconsistency between
375 		 * iWarp and other devices (at least rxe and irdma using RoCEv2)
376 		 */
377 		if (rdma_protocol_iwarp(id->device, id->port_num)) {
378 			/*
379 			 * iWarp devices report the peer's values
380 			 * from the peer's perspective here.
381 			 * Tested with siw and irdma (in iWarp mode).
382 			 * We need them from our perspective,
383 			 * so we swap the values.
384 			 */
385 			peer_initiator_depth = event->param.conn.responder_resources;
386 			peer_responder_resources = event->param.conn.initiator_depth;
387 		} else {
388 			/*
389 			 * Non-iWarp devices report the peer's values
390 			 * already converted to our perspective here.
391 			 * Tested with rxe and irdma (in RoCE mode).
392 			 */
393 			peer_initiator_depth = event->param.conn.initiator_depth;
394 			peer_responder_resources = event->param.conn.responder_resources;
395 		}
396 		if (rdma_protocol_iwarp(id->device, id->port_num) &&
397 		    event->param.conn.private_data_len == 8) {
398 			/*
399 			 * Legacy clients with only iWarp MPA v1 support
400 			 * need a private blob in order to negotiate
401 			 * the IRD/ORD values.
402 			 */
403 			const __be32 *ird_ord_hdr = event->param.conn.private_data;
404 			u32 ird32 = be32_to_cpu(ird_ord_hdr[0]);
405 			u32 ord32 = be32_to_cpu(ird_ord_hdr[1]);
406 
407 			/*
408 			 * cifs.ko sends the legacy IRD/ORD negotiation
409 			 * event if iWarp MPA v2 was used.
410 			 *
411 			 * Here we check that the values match and only
412 			 * mark the client as legacy if they don't match.
413 			 */
414 			if ((u32)event->param.conn.initiator_depth != ird32 ||
415 			    (u32)event->param.conn.responder_resources != ord32) {
416 				/*
417 				 * There are broken clients (old cifs.ko)
418 				 * using little endian and also
419 				 * struct rdma_conn_param only uses u8
420 				 * for initiator_depth and responder_resources,
421 				 * so we truncate the value to U8_MAX.
422 				 *
423 				 * smb_direct_accept_client() will then
424 				 * do the real negotiation in order to
425 				 * select the minimum between client and
426 				 * server.
427 				 */
428 				ird32 = min_t(u32, ird32, U8_MAX);
429 				ord32 = min_t(u32, ord32, U8_MAX);
430 
431 				sc->rdma.legacy_iwarp = true;
432 				peer_initiator_depth = (u8)ird32;
433 				peer_responder_resources = (u8)ord32;
434 			}
435 		}
436 
437 		/*
438 		 * Negotiate the values by using the minimum
439 		 * between client and server if the client provided
440 		 * non-zero values.
441 		 */
442 		if (peer_initiator_depth != 0)
443 			sp->initiator_depth =
444 					min_t(u8, sp->initiator_depth,
445 					      peer_initiator_depth);
446 		if (peer_responder_resources != 0)
447 			sp->responder_resources =
448 					min_t(u8, sp->responder_resources,
449 					      peer_responder_resources);
450 
451 		if (SMBDIRECT_CHECK_STATUS_DISCONNECT(sc, SMBDIRECT_SOCKET_RDMA_CONNECT_RUNNING))
452 			break;
453 		sc->status = SMBDIRECT_SOCKET_NEGOTIATE_NEEDED;
454 		wake_up(&sc->status_wait);
455 		break;
456 
457 	case RDMA_CM_EVENT_CONNECT_ERROR:
458 	case RDMA_CM_EVENT_UNREACHABLE:
459 	case RDMA_CM_EVENT_REJECTED:
460 		log_rdma_event(ERR, "connecting failed event=%s\n", event_name);
461 		sc->status = SMBDIRECT_SOCKET_RDMA_CONNECT_FAILED;
462 		smbd_disconnect_rdma_work(&sc->disconnect_work);
463 		break;
464 
465 	case RDMA_CM_EVENT_DEVICE_REMOVAL:
466 	case RDMA_CM_EVENT_DISCONNECTED:
467 		/* This happens when we fail the negotiation */
468 		if (sc->status == SMBDIRECT_SOCKET_NEGOTIATE_FAILED) {
469 			log_rdma_event(ERR, "event=%s during negotiation\n", event_name);
470 		}
471 
472 		sc->status = SMBDIRECT_SOCKET_DISCONNECTED;
473 		smbd_disconnect_rdma_work(&sc->disconnect_work);
474 		break;
475 
476 	default:
477 		log_rdma_event(ERR, "unexpected event=%s status=%d\n",
478 			       event_name, event->status);
479 		break;
480 	}
481 
482 	return 0;
483 }
484 
485 /* Upcall from RDMA QP */
486 static void
487 smbd_qp_async_error_upcall(struct ib_event *event, void *context)
488 {
489 	struct smbdirect_socket *sc = context;
490 
491 	log_rdma_event(ERR, "%s on device %s socket %p\n",
492 		ib_event_msg(event->event), event->device->name, sc);
493 
494 	switch (event->event) {
495 	case IB_EVENT_CQ_ERR:
496 	case IB_EVENT_QP_FATAL:
497 		smbd_disconnect_rdma_connection(sc);
498 		break;
499 
500 	default:
501 		break;
502 	}
503 }
504 
505 static inline void *smbdirect_send_io_payload(struct smbdirect_send_io *request)
506 {
507 	return (void *)request->packet;
508 }
509 
510 static inline void *smbdirect_recv_io_payload(struct smbdirect_recv_io *response)
511 {
512 	return (void *)response->packet;
513 }
514 
515 static struct smbdirect_send_io *smbd_alloc_send_io(struct smbdirect_socket *sc)
516 {
517 	struct smbdirect_send_io *msg;
518 
519 	msg = mempool_alloc(sc->send_io.mem.pool, GFP_KERNEL);
520 	if (!msg)
521 		return ERR_PTR(-ENOMEM);
522 	msg->socket = sc;
523 	INIT_LIST_HEAD(&msg->sibling_list);
524 	msg->num_sge = 0;
525 
526 	return msg;
527 }
528 
529 static void smbd_free_send_io(struct smbdirect_send_io *msg)
530 {
531 	struct smbdirect_socket *sc = msg->socket;
532 	size_t i;
533 
534 	/*
535 	 * The list needs to be empty!
536 	 * The caller should take care of it.
537 	 */
538 	WARN_ON_ONCE(!list_empty(&msg->sibling_list));
539 
540 	/*
541 	 * Note we call ib_dma_unmap_page(), even if some sges are mapped using
542 	 * ib_dma_map_single().
543 	 *
544 	 * The difference between _single() and _page() only matters for the
545 	 * ib_dma_map_*() case.
546 	 *
547 	 * For the ib_dma_unmap_*() case it does not matter as both take the
548 	 * dma_addr_t and dma_unmap_single_attrs() is just an alias to
549 	 * dma_unmap_page_attrs().
550 	 */
551 	for (i = 0; i < msg->num_sge; i++)
552 		ib_dma_unmap_page(sc->ib.dev,
553 				  msg->sge[i].addr,
554 				  msg->sge[i].length,
555 				  DMA_TO_DEVICE);
556 
557 	mempool_free(msg, sc->send_io.mem.pool);
558 }
559 
560 /* Called when a RDMA send is done */
561 static void send_done(struct ib_cq *cq, struct ib_wc *wc)
562 {
563 	struct smbdirect_send_io *request =
564 		container_of(wc->wr_cqe, struct smbdirect_send_io, cqe);
565 	struct smbdirect_socket *sc = request->socket;
566 	struct smbdirect_send_io *sibling, *next;
567 	int lcredits = 0;
568 
569 	log_rdma_send(INFO, "smbdirect_send_io 0x%p completed wc->status=%s\n",
570 		request, ib_wc_status_msg(wc->status));
571 
572 	if (unlikely(!(request->wr.send_flags & IB_SEND_SIGNALED))) {
573 		/*
574 		 * This happens when the smbdirect_send_io is a sibling
575 		 * posted before the final message; it is only
576 		 * signaled on error. In that case we need to skip
577 		 * smbd_free_send_io() here,
578 		 * otherwise it would also destroy the memory
579 		 * of the siblings, which would cause
580 		 * use-after-free problems for the other completions
581 		 * triggered from ib_drain_qp().
582 		 */
583 		if (wc->status != IB_WC_SUCCESS)
584 			goto skip_free;
585 
586 		/*
587 		 * This should not happen!
588 		 * But we better just close the
589 		 * connection...
590 		 */
591 		log_rdma_send(ERR,
592 			"unexpected send completion wc->status=%s (%d) wc->opcode=%d\n",
593 			ib_wc_status_msg(wc->status), wc->status, wc->opcode);
594 		smbd_disconnect_rdma_connection(sc);
595 		return;
596 	}
597 
598 	/*
599 	 * Free possible siblings and then the main send_io
600 	 */
601 	list_for_each_entry_safe(sibling, next, &request->sibling_list, sibling_list) {
602 		list_del_init(&sibling->sibling_list);
603 		smbd_free_send_io(sibling);
604 		lcredits += 1;
605 	}
606 	/* Note this frees wc->wr_cqe, but not wc */
607 	smbd_free_send_io(request);
608 	lcredits += 1;
609 
610 	if (wc->status != IB_WC_SUCCESS || wc->opcode != IB_WC_SEND) {
611 skip_free:
612 		if (wc->status != IB_WC_WR_FLUSH_ERR)
613 			log_rdma_send(ERR, "wc->status=%s wc->opcode=%d\n",
614 				ib_wc_status_msg(wc->status), wc->opcode);
615 		smbd_disconnect_rdma_connection(sc);
616 		return;
617 	}
618 
619 	atomic_add(lcredits, &sc->send_io.lcredits.count);
620 	wake_up(&sc->send_io.lcredits.wait_queue);
621 
622 	if (atomic_dec_and_test(&sc->send_io.pending.count))
623 		wake_up(&sc->send_io.pending.zero_wait_queue);
624 
625 	wake_up(&sc->send_io.pending.dec_wait_queue);
626 }
627 
628 static void dump_smbdirect_negotiate_resp(struct smbdirect_negotiate_resp *resp)
629 {
630 	log_rdma_event(INFO, "resp message min_version %u max_version %u negotiated_version %u credits_requested %u credits_granted %u status %u max_readwrite_size %u preferred_send_size %u max_receive_size %u max_fragmented_size %u\n",
631 		       resp->min_version, resp->max_version,
632 		       resp->negotiated_version, resp->credits_requested,
633 		       resp->credits_granted, resp->status,
634 		       resp->max_readwrite_size, resp->preferred_send_size,
635 		       resp->max_receive_size, resp->max_fragmented_size);
636 }
637 
638 /*
639  * Process a negotiation response message, according to [MS-SMBD]3.1.5.7
640  * response, packet_length: the negotiation response message
641  * return value: true if negotiation is a success, false if failed
642  */
643 static bool process_negotiation_response(
644 		struct smbdirect_recv_io *response, int packet_length)
645 {
646 	struct smbdirect_socket *sc = response->socket;
647 	struct smbdirect_socket_parameters *sp = &sc->parameters;
648 	struct smbdirect_negotiate_resp *packet = smbdirect_recv_io_payload(response);
649 
650 	if (packet_length < sizeof(struct smbdirect_negotiate_resp)) {
651 		log_rdma_event(ERR,
652 			"error: packet_length=%d\n", packet_length);
653 		return false;
654 	}
655 
656 	if (le16_to_cpu(packet->negotiated_version) != SMBDIRECT_V1) {
657 		log_rdma_event(ERR, "error: negotiated_version=%x\n",
658 			le16_to_cpu(packet->negotiated_version));
659 		return false;
660 	}
661 
662 	if (packet->credits_requested == 0) {
663 		log_rdma_event(ERR, "error: credits_requested==0\n");
664 		return false;
665 	}
666 	sc->recv_io.credits.target = le16_to_cpu(packet->credits_requested);
667 	sc->recv_io.credits.target = min_t(u16, sc->recv_io.credits.target, sp->recv_credit_max);
668 
669 	if (packet->credits_granted == 0) {
670 		log_rdma_event(ERR, "error: credits_granted==0\n");
671 		return false;
672 	}
673 	atomic_set(&sc->send_io.lcredits.count, sp->send_credit_target);
674 	atomic_set(&sc->send_io.credits.count, le16_to_cpu(packet->credits_granted));
675 
676 	if (le32_to_cpu(packet->preferred_send_size) > sp->max_recv_size) {
677 		log_rdma_event(ERR, "error: preferred_send_size=%d\n",
678 			le32_to_cpu(packet->preferred_send_size));
679 		return false;
680 	}
681 	sp->max_recv_size = le32_to_cpu(packet->preferred_send_size);
682 
683 	if (le32_to_cpu(packet->max_receive_size) < SMBD_MIN_RECEIVE_SIZE) {
684 		log_rdma_event(ERR, "error: max_receive_size=%d\n",
685 			le32_to_cpu(packet->max_receive_size));
686 		return false;
687 	}
688 	sp->max_send_size = min_t(u32, sp->max_send_size,
689 				  le32_to_cpu(packet->max_receive_size));
690 
691 	if (le32_to_cpu(packet->max_fragmented_size) <
692 			SMBD_MIN_FRAGMENTED_SIZE) {
693 		log_rdma_event(ERR, "error: max_fragmented_size=%d\n",
694 			le32_to_cpu(packet->max_fragmented_size));
695 		return false;
696 	}
697 	sp->max_fragmented_send_size =
698 		le32_to_cpu(packet->max_fragmented_size);
699 
700 
701 	sp->max_read_write_size = min_t(u32,
702 			le32_to_cpu(packet->max_readwrite_size),
703 			sp->max_frmr_depth * PAGE_SIZE);
704 	sp->max_frmr_depth = sp->max_read_write_size / PAGE_SIZE;
705 
706 	atomic_set(&sc->send_io.bcredits.count, 1);
707 	sc->recv_io.expected = SMBDIRECT_EXPECT_DATA_TRANSFER;
708 	return true;
709 }
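
/*
 * A worked example of the negotiation above, using purely illustrative
 * values (not defaults of any particular peer): if the response carries
 * credits_requested=255, credits_granted=255, preferred_send_size=1364,
 * max_receive_size=8192, max_fragmented_size=1048576 and
 * max_readwrite_size=1048576, process_negotiation_response() ends up with:
 *   recv_io.credits.target   = min(255, recv_credit_max)
 *   max_recv_size            = 1364
 *   max_send_size            = min(max_send_size, 8192)
 *   max_fragmented_send_size = 1048576
 *   max_read_write_size      = min(1048576, max_frmr_depth * PAGE_SIZE)
 *   max_frmr_depth           = max_read_write_size / PAGE_SIZE
 */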
710 
711 static void smbd_post_send_credits(struct work_struct *work)
712 {
713 	int rc;
714 	struct smbdirect_recv_io *response;
715 	struct smbdirect_socket *sc =
716 		container_of(work, struct smbdirect_socket, recv_io.posted.refill_work);
717 	int posted = 0;
718 
719 	if (sc->status != SMBDIRECT_SOCKET_CONNECTED) {
720 		return;
721 	}
722 
723 	if (sc->recv_io.credits.target >
724 		atomic_read(&sc->recv_io.credits.count)) {
725 		while (true) {
726 			response = get_receive_buffer(sc);
727 			if (!response)
728 				break;
729 
730 			response->first_segment = false;
731 			rc = smbd_post_recv(sc, response);
732 			if (rc) {
733 				log_rdma_recv(ERR,
734 					"post_recv failed rc=%d\n", rc);
735 				put_receive_buffer(sc, response);
736 				break;
737 			}
738 
739 			atomic_inc(&sc->recv_io.posted.count);
740 			posted += 1;
741 		}
742 	}
743 
744 	atomic_add(posted, &sc->recv_io.credits.available);
745 
746 	/*
747 	 * If a sender holding the last send credit is waiting
748 	 * for receive credits it can grant, we need to wake it up
749 	 */
750 	if (posted &&
751 	    atomic_read(&sc->send_io.bcredits.count) == 0 &&
752 	    atomic_read(&sc->send_io.credits.count) == 0)
753 		wake_up(&sc->send_io.credits.wait_queue);
754 
755 	/* Promptly send an immediate packet as defined in [MS-SMBD] 3.1.1.1 */
756 	if (atomic_read(&sc->recv_io.credits.count) <
757 		sc->recv_io.credits.target - 1) {
758 		log_keep_alive(INFO, "schedule send of an empty message\n");
759 		queue_work(sc->workqueue, &sc->idle.immediate_work);
760 	}
761 }
762 
763 /* Called from softirq, when recv is done */
764 static void recv_done(struct ib_cq *cq, struct ib_wc *wc)
765 {
766 	struct smbdirect_data_transfer *data_transfer;
767 	struct smbdirect_recv_io *response =
768 		container_of(wc->wr_cqe, struct smbdirect_recv_io, cqe);
769 	struct smbdirect_socket *sc = response->socket;
770 	struct smbdirect_socket_parameters *sp = &sc->parameters;
771 	int current_recv_credits;
772 	u16 old_recv_credit_target;
773 	u32 data_offset = 0;
774 	u32 data_length = 0;
775 	u32 remaining_data_length = 0;
776 	bool negotiate_done = false;
777 
778 	log_rdma_recv(INFO,
779 		      "response=0x%p type=%d wc status=%s wc opcode %d byte_len=%d pkey_index=%u\n",
780 		      response, sc->recv_io.expected,
781 		      ib_wc_status_msg(wc->status), wc->opcode,
782 		      wc->byte_len, wc->pkey_index);
783 
784 	if (wc->status != IB_WC_SUCCESS || wc->opcode != IB_WC_RECV) {
785 		if (wc->status != IB_WC_WR_FLUSH_ERR)
786 			log_rdma_recv(ERR, "wc->status=%s opcode=%d\n",
787 				ib_wc_status_msg(wc->status), wc->opcode);
788 		goto error;
789 	}
790 
791 	ib_dma_sync_single_for_cpu(
792 		wc->qp->device,
793 		response->sge.addr,
794 		response->sge.length,
795 		DMA_FROM_DEVICE);
796 
797 	/*
798 	 * Reset timer to the keepalive interval in
799 	 * order to trigger our next keepalive message.
800 	 */
801 	sc->idle.keepalive = SMBDIRECT_KEEPALIVE_NONE;
802 	mod_delayed_work(sc->workqueue, &sc->idle.timer_work,
803 			 msecs_to_jiffies(sp->keepalive_interval_msec));
804 
805 	switch (sc->recv_io.expected) {
806 	/* SMBD negotiation response */
807 	case SMBDIRECT_EXPECT_NEGOTIATE_REP:
808 		dump_smbdirect_negotiate_resp(smbdirect_recv_io_payload(response));
809 		sc->recv_io.reassembly.full_packet_received = true;
810 		negotiate_done =
811 			process_negotiation_response(response, wc->byte_len);
812 		put_receive_buffer(sc, response);
813 		if (SMBDIRECT_CHECK_STATUS_WARN(sc, SMBDIRECT_SOCKET_NEGOTIATE_RUNNING))
814 			negotiate_done = false;
815 		if (!negotiate_done) {
816 			sc->status = SMBDIRECT_SOCKET_NEGOTIATE_FAILED;
817 			smbd_disconnect_rdma_connection(sc);
818 		} else {
819 			sc->status = SMBDIRECT_SOCKET_CONNECTED;
820 			wake_up(&sc->status_wait);
821 		}
822 
823 		return;
824 
825 	/* SMBD data transfer packet */
826 	case SMBDIRECT_EXPECT_DATA_TRANSFER:
827 		data_transfer = smbdirect_recv_io_payload(response);
828 
829 		if (wc->byte_len <
830 		    offsetof(struct smbdirect_data_transfer, padding))
831 			goto error;
832 
833 		remaining_data_length = le32_to_cpu(data_transfer->remaining_data_length);
834 		data_offset = le32_to_cpu(data_transfer->data_offset);
835 		data_length = le32_to_cpu(data_transfer->data_length);
836 		if (wc->byte_len < data_offset ||
837 		    (u64)wc->byte_len < (u64)data_offset + data_length)
838 			goto error;
839 
840 		if (remaining_data_length > sp->max_fragmented_recv_size ||
841 		    data_length > sp->max_fragmented_recv_size ||
842 		    (u64)remaining_data_length + (u64)data_length > (u64)sp->max_fragmented_recv_size)
843 			goto error;
844 
845 		if (data_length) {
846 			if (sc->recv_io.reassembly.full_packet_received)
847 				response->first_segment = true;
848 
849 			if (le32_to_cpu(data_transfer->remaining_data_length))
850 				sc->recv_io.reassembly.full_packet_received = false;
851 			else
852 				sc->recv_io.reassembly.full_packet_received = true;
853 		}
854 
855 		atomic_dec(&sc->recv_io.posted.count);
856 		current_recv_credits = atomic_dec_return(&sc->recv_io.credits.count);
857 
858 		old_recv_credit_target = sc->recv_io.credits.target;
859 		sc->recv_io.credits.target =
860 			le16_to_cpu(data_transfer->credits_requested);
861 		sc->recv_io.credits.target =
862 			min_t(u16, sc->recv_io.credits.target, sp->recv_credit_max);
863 		sc->recv_io.credits.target =
864 			max_t(u16, sc->recv_io.credits.target, 1);
865 		if (le16_to_cpu(data_transfer->credits_granted)) {
866 			atomic_add(le16_to_cpu(data_transfer->credits_granted),
867 				&sc->send_io.credits.count);
868 			/*
869 			 * We have new send credits granted from remote peer
870 			 * If any sender is waiting for credits, unblock it
871 			 */
872 			wake_up(&sc->send_io.credits.wait_queue);
873 		}
874 
875 		log_incoming(INFO, "data flags %d data_offset %d data_length %d remaining_data_length %d\n",
876 			     le16_to_cpu(data_transfer->flags),
877 			     le32_to_cpu(data_transfer->data_offset),
878 			     le32_to_cpu(data_transfer->data_length),
879 			     le32_to_cpu(data_transfer->remaining_data_length));
880 
881 		/* Send an immediate response right away if requested */
882 		if (le16_to_cpu(data_transfer->flags) &
883 				SMBDIRECT_FLAG_RESPONSE_REQUESTED) {
884 			log_keep_alive(INFO, "schedule send of immediate response\n");
885 			queue_work(sc->workqueue, &sc->idle.immediate_work);
886 		}
887 
888 		/*
889 		 * If this is a packet with a data payload, place the data in the
890 		 * reassembly queue and wake up the reading thread
891 		 */
892 		if (data_length) {
893 			if (current_recv_credits <= (sc->recv_io.credits.target / 4) ||
894 			    sc->recv_io.credits.target > old_recv_credit_target)
895 				queue_work(sc->workqueue, &sc->recv_io.posted.refill_work);
896 
897 			enqueue_reassembly(sc, response, data_length);
898 			wake_up(&sc->recv_io.reassembly.wait_queue);
899 		} else
900 			put_receive_buffer(sc, response);
901 
902 		return;
903 
904 	case SMBDIRECT_EXPECT_NEGOTIATE_REQ:
905 		/* Only server... */
906 		break;
907 	}
908 
909 	/*
910 	 * This is an internal error!
911 	 */
912 	log_rdma_recv(ERR, "unexpected response type=%d\n", sc->recv_io.expected);
913 	WARN_ON_ONCE(sc->recv_io.expected != SMBDIRECT_EXPECT_DATA_TRANSFER);
914 error:
915 	put_receive_buffer(sc, response);
916 	smbd_disconnect_rdma_connection(sc);
917 }
918 
919 static struct rdma_cm_id *smbd_create_id(
920 		struct smbdirect_socket *sc,
921 		struct sockaddr *dstaddr, int port)
922 {
923 	struct smbdirect_socket_parameters *sp = &sc->parameters;
924 	struct rdma_cm_id *id;
925 	u8 node_type = RDMA_NODE_UNSPECIFIED;
926 	int rc;
927 	__be16 *sport;
928 
929 	id = rdma_create_id(&init_net, smbd_conn_upcall, sc,
930 		RDMA_PS_TCP, IB_QPT_RC);
931 	if (IS_ERR(id)) {
932 		rc = PTR_ERR(id);
933 		log_rdma_event(ERR, "rdma_create_id() failed %i\n", rc);
934 		return id;
935 	}
936 
937 	switch (port) {
938 	case SMBD_PORT:
939 		/*
940 		 * only allow iWarp devices
941 		 * for port 5445.
942 		 */
943 		node_type = RDMA_NODE_RNIC;
944 		break;
945 	case SMB_PORT:
946 		/*
947 		 * only allow InfiniBand, RoCEv1 or RoCEv2
948 		 * devices for port 445.
949 		 *
950 		 * (Basically don't allow iWarp devices)
951 		 */
952 		node_type = RDMA_NODE_IB_CA;
953 		break;
954 	}
955 	rc = rdma_restrict_node_type(id, node_type);
956 	if (rc) {
957 		log_rdma_event(ERR, "rdma_restrict_node_type(%u) failed %i\n",
958 			       node_type, rc);
959 		goto out;
960 	}
961 
962 	if (dstaddr->sa_family == AF_INET6)
963 		sport = &((struct sockaddr_in6 *)dstaddr)->sin6_port;
964 	else
965 		sport = &((struct sockaddr_in *)dstaddr)->sin_port;
966 
967 	*sport = htons(port);
968 
969 	WARN_ON_ONCE(sc->status != SMBDIRECT_SOCKET_RESOLVE_ADDR_NEEDED);
970 	sc->status = SMBDIRECT_SOCKET_RESOLVE_ADDR_RUNNING;
971 	rc = rdma_resolve_addr(id, NULL, (struct sockaddr *)dstaddr,
972 		sp->resolve_addr_timeout_msec);
973 	if (rc) {
974 		log_rdma_event(ERR, "rdma_resolve_addr() failed %i\n", rc);
975 		goto out;
976 	}
977 	rc = wait_event_interruptible_timeout(
978 		sc->status_wait,
979 		sc->status != SMBDIRECT_SOCKET_RESOLVE_ADDR_RUNNING,
980 		msecs_to_jiffies(sp->resolve_addr_timeout_msec));
981 	/* e.g. if interrupted returns -ERESTARTSYS */
982 	if (rc < 0) {
983 		log_rdma_event(ERR, "rdma_resolve_addr timeout rc: %i\n", rc);
984 		goto out;
985 	}
986 	if (sc->status == SMBDIRECT_SOCKET_RESOLVE_ADDR_RUNNING) {
987 		rc = -ETIMEDOUT;
988 		log_rdma_event(ERR, "rdma_resolve_addr() completed %i\n", rc);
989 		goto out;
990 	}
991 	if (sc->status != SMBDIRECT_SOCKET_RESOLVE_ROUTE_NEEDED) {
992 		rc = -EHOSTUNREACH;
993 		log_rdma_event(ERR, "rdma_resolve_addr() completed %i\n", rc);
994 		goto out;
995 	}
996 
997 	WARN_ON_ONCE(sc->status != SMBDIRECT_SOCKET_RESOLVE_ROUTE_NEEDED);
998 	sc->status = SMBDIRECT_SOCKET_RESOLVE_ROUTE_RUNNING;
999 	rc = rdma_resolve_route(id, sp->resolve_route_timeout_msec);
1000 	if (rc) {
1001 		log_rdma_event(ERR, "rdma_resolve_route() failed %i\n", rc);
1002 		goto out;
1003 	}
1004 	rc = wait_event_interruptible_timeout(
1005 		sc->status_wait,
1006 		sc->status != SMBDIRECT_SOCKET_RESOLVE_ROUTE_RUNNING,
1007 		msecs_to_jiffies(sp->resolve_route_timeout_msec));
1008 	/* e.g. if interrupted returns -ERESTARTSYS */
1009 	if (rc < 0)  {
1010 		log_rdma_event(ERR, "rdma_resolve_addr timeout rc: %i\n", rc);
1011 		goto out;
1012 	}
1013 	if (sc->status == SMBDIRECT_SOCKET_RESOLVE_ROUTE_RUNNING) {
1014 		rc = -ETIMEDOUT;
1015 		log_rdma_event(ERR, "rdma_resolve_route() completed %i\n", rc);
1016 		goto out;
1017 	}
1018 	if (sc->status != SMBDIRECT_SOCKET_RDMA_CONNECT_NEEDED) {
1019 		rc = -ENETUNREACH;
1020 		log_rdma_event(ERR, "rdma_resolve_route() completed %i\n", rc);
1021 		goto out;
1022 	}
1023 
1024 	return id;
1025 
1026 out:
1027 	rdma_destroy_id(id);
1028 	return ERR_PTR(rc);
1029 }
1030 
1031 /*
1032  * Test if FRWR (Fast Registration Work Requests) is supported on the device
1033  * This implementation requires FRWR for RDMA read/write
1034  * return value: true if it is supported
1035  */
1036 static bool frwr_is_supported(struct ib_device_attr *attrs)
1037 {
1038 	if (!(attrs->device_cap_flags & IB_DEVICE_MEM_MGT_EXTENSIONS))
1039 		return false;
1040 	if (attrs->max_fast_reg_page_list_len == 0)
1041 		return false;
1042 	return true;
1043 }
1044 
1045 static int smbd_ia_open(
1046 		struct smbdirect_socket *sc,
1047 		struct sockaddr *dstaddr, int port)
1048 {
1049 	struct smbdirect_socket_parameters *sp = &sc->parameters;
1050 	int rc;
1051 
1052 	WARN_ON_ONCE(sc->status != SMBDIRECT_SOCKET_CREATED);
1053 	sc->status = SMBDIRECT_SOCKET_RESOLVE_ADDR_NEEDED;
1054 
1055 	sc->rdma.cm_id = smbd_create_id(sc, dstaddr, port);
1056 	if (IS_ERR(sc->rdma.cm_id)) {
1057 		rc = PTR_ERR(sc->rdma.cm_id);
1058 		goto out1;
1059 	}
1060 	sc->ib.dev = sc->rdma.cm_id->device;
1061 
1062 	if (!frwr_is_supported(&sc->ib.dev->attrs)) {
1063 		log_rdma_event(ERR, "Fast Registration Work Requests (FRWR) is not supported\n");
1064 		log_rdma_event(ERR, "Device capability flags = %llx max_fast_reg_page_list_len = %u\n",
1065 			       sc->ib.dev->attrs.device_cap_flags,
1066 			       sc->ib.dev->attrs.max_fast_reg_page_list_len);
1067 		rc = -EPROTONOSUPPORT;
1068 		goto out2;
1069 	}
1070 	sp->max_frmr_depth = min_t(u32,
1071 		sp->max_frmr_depth,
1072 		sc->ib.dev->attrs.max_fast_reg_page_list_len);
1073 	sc->mr_io.type = IB_MR_TYPE_MEM_REG;
1074 	if (sc->ib.dev->attrs.kernel_cap_flags & IBK_SG_GAPS_REG)
1075 		sc->mr_io.type = IB_MR_TYPE_SG_GAPS;
1076 
1077 	return 0;
1078 
1079 out2:
1080 	rdma_destroy_id(sc->rdma.cm_id);
1081 	sc->rdma.cm_id = NULL;
1082 
1083 out1:
1084 	return rc;
1085 }
1086 
1087 /*
1088  * Send a negotiation request message to the peer
1089  * The negotiation procedure is in [MS-SMBD] 3.1.5.2 and 3.1.5.3
1090  * After negotiation, the transport is connected and ready for
1091  * carrying upper layer SMB payload
1092  */
1093 static int smbd_post_send_negotiate_req(struct smbdirect_socket *sc)
1094 {
1095 	struct smbdirect_socket_parameters *sp = &sc->parameters;
1096 	int rc;
1097 	struct smbdirect_send_io *request;
1098 	struct smbdirect_negotiate_req *packet;
1099 
1100 	request = smbd_alloc_send_io(sc);
1101 	if (IS_ERR(request))
1102 		return PTR_ERR(request);
1103 
1104 	packet = smbdirect_send_io_payload(request);
1105 	packet->min_version = cpu_to_le16(SMBDIRECT_V1);
1106 	packet->max_version = cpu_to_le16(SMBDIRECT_V1);
1107 	packet->reserved = 0;
1108 	packet->credits_requested = cpu_to_le16(sp->send_credit_target);
1109 	packet->preferred_send_size = cpu_to_le32(sp->max_send_size);
1110 	packet->max_receive_size = cpu_to_le32(sp->max_recv_size);
1111 	packet->max_fragmented_size =
1112 		cpu_to_le32(sp->max_fragmented_recv_size);
1113 
1114 	request->sge[0].addr = ib_dma_map_single(
1115 				sc->ib.dev, (void *)packet,
1116 				sizeof(*packet), DMA_TO_DEVICE);
1117 	if (ib_dma_mapping_error(sc->ib.dev, request->sge[0].addr)) {
1118 		rc = -EIO;
1119 		goto dma_mapping_failed;
1120 	}
1121 	request->num_sge = 1;
1122 
1123 	request->sge[0].length = sizeof(*packet);
1124 	request->sge[0].lkey = sc->ib.pd->local_dma_lkey;
1125 
1126 	rc = smbd_post_send(sc, NULL, request);
1127 	if (!rc)
1128 		return 0;
1129 
1130 	if (rc == -EAGAIN)
1131 		rc = -EIO;
1132 
1133 dma_mapping_failed:
1134 	smbd_free_send_io(request);
1135 	return rc;
1136 }
1137 
1138 /*
1139  * Extend the credits to remote peer
1140  * This implements [MS-SMBD] 3.1.5.9
1141  * The idea is that we should extend credits to remote peer as quickly as
1142  * it's allowed, to maintain data flow. We allocate as much receive
1143  * buffer as possible, and extend the receive credits to remote peer
1144  * return value: the new credtis being granted.
1145  */
1146 static int manage_credits_prior_sending(struct smbdirect_socket *sc)
1147 {
1148 	int missing;
1149 	int available;
1150 	int new_credits;
1151 
1152 	if (atomic_read(&sc->recv_io.credits.count) >= sc->recv_io.credits.target)
1153 		return 0;
1154 
1155 	missing = (int)sc->recv_io.credits.target - atomic_read(&sc->recv_io.credits.count);
1156 	available = atomic_xchg(&sc->recv_io.credits.available, 0);
1157 	new_credits = (u16)min3(U16_MAX, missing, available);
1158 	if (new_credits <= 0) {
1159 		/*
1160 		 * If credits are available, but none were granted,
1161 		 * we need to re-add them.
1162 		 */
1163 		if (available)
1164 			atomic_add(available, &sc->recv_io.credits.available);
1165 		return 0;
1166 	}
1167 
1168 	if (new_credits < available) {
1169 		/*
1170 		 * Re-add the remaining available credits.
1171 		 */
1172 		available -= new_credits;
1173 		atomic_add(available, &sc->recv_io.credits.available);
1174 	}
1175 
1176 	/*
1177 	 * Remember we granted the credits
1178 	 */
1179 	atomic_add(new_credits, &sc->recv_io.credits.count);
1180 	return new_credits;
1181 }
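
/*
 * A worked example with purely illustrative numbers: assume
 * recv_io.credits.target=255, credits.count=200 and credits.available=70.
 * Then missing = 255 - 200 = 55, the 70 available credits are taken
 * atomically and new_credits = min(55, 70) = 55; the 15 credits that are
 * not granted are re-added to credits.available, the granted 55 are added
 * to credits.count (now 255) and 55 is returned, to be placed in the
 * credits_granted field of the next message sent to the peer.
 */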
1182 
1183 /*
1184  * Check if we need to send a KEEP_ALIVE message
1185  * The idle connection timer triggers a KEEP_ALIVE message when expires
1186  * SMBDIRECT_FLAG_RESPONSE_REQUESTED is set in the message flag to have peer send
1187  * back a response.
1188  * return value:
1189  * 1 if SMBDIRECT_FLAG_RESPONSE_REQUESTED needs to be set
1190  * 0: otherwise
1191  */
1192 static int manage_keep_alive_before_sending(struct smbdirect_socket *sc)
1193 {
1194 	struct smbdirect_socket_parameters *sp = &sc->parameters;
1195 
1196 	if (sc->idle.keepalive == SMBDIRECT_KEEPALIVE_PENDING) {
1197 		sc->idle.keepalive = SMBDIRECT_KEEPALIVE_SENT;
1198 		/*
1199 		 * Now use the keepalive timeout (instead of keepalive interval)
1200 		 * in order to wait for a response
1201 		 */
1202 		mod_delayed_work(sc->workqueue, &sc->idle.timer_work,
1203 				 msecs_to_jiffies(sp->keepalive_timeout_msec));
1204 		return 1;
1205 	}
1206 	return 0;
1207 }
1208 
1209 static int smbd_ib_post_send(struct smbdirect_socket *sc,
1210 			     struct ib_send_wr *wr)
1211 {
1212 	int ret;
1213 
1214 	atomic_inc(&sc->send_io.pending.count);
1215 	ret = ib_post_send(sc->ib.qp, wr, NULL);
1216 	if (ret) {
1217 		pr_err("failed to post send: %d\n", ret);
1218 		smbd_disconnect_rdma_connection(sc);
1219 		ret = -EAGAIN;
1220 	}
1221 	return ret;
1222 }
1223 
1224 /* Post the send request */
1225 static int smbd_post_send(struct smbdirect_socket *sc,
1226 			  struct smbdirect_send_batch *batch,
1227 			  struct smbdirect_send_io *request)
1228 {
1229 	int i;
1230 
1231 	for (i = 0; i < request->num_sge; i++) {
1232 		log_rdma_send(INFO,
1233 			"rdma_request sge[%d] addr=0x%llx length=%u\n",
1234 			i, request->sge[i].addr, request->sge[i].length);
1235 		ib_dma_sync_single_for_device(
1236 			sc->ib.dev,
1237 			request->sge[i].addr,
1238 			request->sge[i].length,
1239 			DMA_TO_DEVICE);
1240 	}
1241 
1242 	request->cqe.done = send_done;
1243 	request->wr.next = NULL;
1244 	request->wr.sg_list = request->sge;
1245 	request->wr.num_sge = request->num_sge;
1246 	request->wr.opcode = IB_WR_SEND;
1247 
1248 	if (batch) {
1249 		request->wr.wr_cqe = NULL;
1250 		request->wr.send_flags = 0;
1251 		if (!list_empty(&batch->msg_list)) {
1252 			struct smbdirect_send_io *last;
1253 
1254 			last = list_last_entry(&batch->msg_list,
1255 					       struct smbdirect_send_io,
1256 					       sibling_list);
1257 			last->wr.next = &request->wr;
1258 		}
1259 		list_add_tail(&request->sibling_list, &batch->msg_list);
1260 		batch->wr_cnt++;
1261 		return 0;
1262 	}
1263 
1264 	request->wr.wr_cqe = &request->cqe;
1265 	request->wr.send_flags = IB_SEND_SIGNALED;
1266 	return smbd_ib_post_send(sc, &request->wr);
1267 }
1268 
1269 static void smbd_send_batch_init(struct smbdirect_send_batch *batch,
1270 				 bool need_invalidate_rkey,
1271 				 unsigned int remote_key)
1272 {
1273 	INIT_LIST_HEAD(&batch->msg_list);
1274 	batch->wr_cnt = 0;
1275 	batch->need_invalidate_rkey = need_invalidate_rkey;
1276 	batch->remote_key = remote_key;
1277 	batch->credit = 0;
1278 }
1279 
1280 static int smbd_send_batch_flush(struct smbdirect_socket *sc,
1281 				 struct smbdirect_send_batch *batch,
1282 				 bool is_last)
1283 {
1284 	struct smbdirect_send_io *first, *last;
1285 	int ret = 0;
1286 
1287 	if (list_empty(&batch->msg_list))
1288 		goto release_credit;
1289 
1290 	first = list_first_entry(&batch->msg_list,
1291 				 struct smbdirect_send_io,
1292 				 sibling_list);
1293 	last = list_last_entry(&batch->msg_list,
1294 			       struct smbdirect_send_io,
1295 			       sibling_list);
1296 
1297 	if (batch->need_invalidate_rkey) {
1298 		first->wr.opcode = IB_WR_SEND_WITH_INV;
1299 		first->wr.ex.invalidate_rkey = batch->remote_key;
1300 		batch->need_invalidate_rkey = false;
1301 		batch->remote_key = 0;
1302 	}
1303 
1304 	last->wr.send_flags = IB_SEND_SIGNALED;
1305 	last->wr.wr_cqe = &last->cqe;
1306 
1307 	/*
1308 	 * Remove last from batch->msg_list
1309 	 * and splice the rest of batch->msg_list
1310 	 * to last->sibling_list.
1311 	 *
1312 	 * batch->msg_list is a valid empty list
1313 	 * at the end.
1314 	 */
1315 	list_del_init(&last->sibling_list);
1316 	list_splice_tail_init(&batch->msg_list, &last->sibling_list);
1317 	batch->wr_cnt = 0;
1318 
1319 	ret = smbd_ib_post_send(sc, &first->wr);
1320 	if (ret) {
1321 		struct smbdirect_send_io *sibling, *next;
1322 
1323 		list_for_each_entry_safe(sibling, next, &last->sibling_list, sibling_list) {
1324 			list_del_init(&sibling->sibling_list);
1325 			smbd_free_send_io(sibling);
1326 		}
1327 		smbd_free_send_io(last);
1328 	}
1329 
1330 release_credit:
1331 	if (is_last && !ret && batch->credit) {
1332 		atomic_add(batch->credit, &sc->send_io.bcredits.count);
1333 		batch->credit = 0;
1334 		wake_up(&sc->send_io.bcredits.wait_queue);
1335 	}
1336 
1337 	return ret;
1338 }
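
/*
 * Sketch of a flushed batch: the batched send_io requests are chained via
 * wr.next, only the last one is IB_SEND_SIGNALED and has a wr_cqe, and,
 * if requested, the first one becomes IB_WR_SEND_WITH_INV for the remote
 * key:
 *
 *   first->wr --next--> ... --next--> last->wr   (IB_SEND_SIGNALED)
 *
 * send_done() then runs once for the signaled request, frees the siblings
 * hanging off last->sibling_list and returns one local send credit per
 * freed request.
 */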
1339 
1340 static int wait_for_credits(struct smbdirect_socket *sc,
1341 			    wait_queue_head_t *waitq, atomic_t *total_credits,
1342 			    int needed)
1343 {
1344 	int ret;
1345 
1346 	do {
1347 		if (atomic_sub_return(needed, total_credits) >= 0)
1348 			return 0;
1349 
1350 		atomic_add(needed, total_credits);
1351 		ret = wait_event_interruptible(*waitq,
1352 					       atomic_read(total_credits) >= needed ||
1353 					       sc->status != SMBDIRECT_SOCKET_CONNECTED);
1354 
1355 		if (sc->status != SMBDIRECT_SOCKET_CONNECTED)
1356 			return -ENOTCONN;
1357 		else if (ret < 0)
1358 			return ret;
1359 	} while (true);
1360 }
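
/*
 * Note on the pattern above: the credits are reserved optimistically with
 * atomic_sub_return(); e.g. with count=1 and needed=1 the result is 0
 * (>= 0) and the credit is taken, while with count=0 the result is -1, so
 * the reservation is rolled back and the caller sleeps until enough
 * credits arrive or the socket leaves the connected state, in which case
 * -ENOTCONN is returned.
 */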
1361 
1362 static int wait_for_send_bcredit(struct smbdirect_socket *sc,
1363 				 struct smbdirect_send_batch *batch)
1364 {
1365 	int ret;
1366 
1367 	if (batch->credit)
1368 		return 0;
1369 
1370 	ret = wait_for_credits(sc,
1371 			       &sc->send_io.bcredits.wait_queue,
1372 			       &sc->send_io.bcredits.count,
1373 			       1);
1374 	if (ret)
1375 		return ret;
1376 
1377 	batch->credit = 1;
1378 	return 0;
1379 }
1380 
1381 static int wait_for_send_lcredit(struct smbdirect_socket *sc,
1382 				 struct smbdirect_send_batch *batch)
1383 {
1384 	if (batch && (atomic_read(&sc->send_io.lcredits.count) <= 1)) {
1385 		int ret;
1386 
1387 		ret = smbd_send_batch_flush(sc, batch, false);
1388 		if (ret)
1389 			return ret;
1390 	}
1391 
1392 	return wait_for_credits(sc,
1393 				&sc->send_io.lcredits.wait_queue,
1394 				&sc->send_io.lcredits.count,
1395 				1);
1396 }
1397 
1398 static int wait_for_send_credits(struct smbdirect_socket *sc,
1399 				 struct smbdirect_send_batch *batch)
1400 {
1401 	if (batch &&
1402 	    (batch->wr_cnt >= 16 || atomic_read(&sc->send_io.credits.count) <= 1)) {
1403 		int ret;
1404 
1405 		ret = smbd_send_batch_flush(sc, batch, false);
1406 		if (ret)
1407 			return ret;
1408 	}
1409 
1410 	return wait_for_credits(sc,
1411 				&sc->send_io.credits.wait_queue,
1412 				&sc->send_io.credits.count,
1413 				1);
1414 }
1415 
1416 static int smbd_post_send_iter(struct smbdirect_socket *sc,
1417 			       struct smbdirect_send_batch *batch,
1418 			       struct iov_iter *iter,
1419 			       int *_remaining_data_length)
1420 {
1421 	struct smbdirect_socket_parameters *sp = &sc->parameters;
1422 	int rc;
1423 	int header_length;
1424 	int data_length;
1425 	struct smbdirect_send_io *request;
1426 	struct smbdirect_data_transfer *packet;
1427 	int new_credits = 0;
1428 	struct smbdirect_send_batch _batch;
1429 
1430 	if (!batch) {
1431 		smbd_send_batch_init(&_batch, false, 0);
1432 		batch = &_batch;
1433 	}
1434 
1435 	rc = wait_for_send_bcredit(sc, batch);
1436 	if (rc) {
1437 		log_outgoing(ERR, "disconnected not sending on wait_bcredit\n");
1438 		rc = -EAGAIN;
1439 		goto err_wait_bcredit;
1440 	}
1441 
1442 	rc = wait_for_send_lcredit(sc, batch);
1443 	if (rc) {
1444 		log_outgoing(ERR, "disconnected not sending on wait_lcredit\n");
1445 		rc = -EAGAIN;
1446 		goto err_wait_lcredit;
1447 	}
1448 
1449 	rc = wait_for_send_credits(sc, batch);
1450 	if (rc) {
1451 		log_outgoing(ERR, "disconnected not sending on wait_credit\n");
1452 		rc = -EAGAIN;
1453 		goto err_wait_credit;
1454 	}
1455 
1456 	new_credits = manage_credits_prior_sending(sc);
1457 	if (new_credits == 0 &&
1458 	    atomic_read(&sc->send_io.credits.count) == 0 &&
1459 	    atomic_read(&sc->recv_io.credits.count) == 0) {
1460 		queue_work(sc->workqueue, &sc->recv_io.posted.refill_work);
1461 		rc = wait_event_interruptible(sc->send_io.credits.wait_queue,
1462 					      atomic_read(&sc->send_io.credits.count) >= 1 ||
1463 					      atomic_read(&sc->recv_io.credits.available) >= 1 ||
1464 					      sc->status != SMBDIRECT_SOCKET_CONNECTED);
1465 		if (sc->status != SMBDIRECT_SOCKET_CONNECTED)
1466 			rc = -ENOTCONN;
1467 		if (rc < 0) {
1468 			log_outgoing(ERR, "disconnected not sending on last credit\n");
1469 			rc = -EAGAIN;
1470 			goto err_wait_credit;
1471 		}
1472 
1473 		new_credits = manage_credits_prior_sending(sc);
1474 	}
1475 
1476 	request = smbd_alloc_send_io(sc);
1477 	if (IS_ERR(request)) {
1478 		rc = PTR_ERR(request);
1479 		goto err_alloc;
1480 	}
1481 
1482 	memset(request->sge, 0, sizeof(request->sge));
1483 
1484 	/* Map the packet to DMA */
1485 	header_length = sizeof(struct smbdirect_data_transfer);
1486 	/* If this is a packet without payload, don't send padding */
1487 	if (!iter)
1488 		header_length = offsetof(struct smbdirect_data_transfer, padding);
1489 
1490 	packet = smbdirect_send_io_payload(request);
1491 	request->sge[0].addr = ib_dma_map_single(sc->ib.dev,
1492 						 (void *)packet,
1493 						 header_length,
1494 						 DMA_TO_DEVICE);
1495 	if (ib_dma_mapping_error(sc->ib.dev, request->sge[0].addr)) {
1496 		rc = -EIO;
1497 		goto err_dma;
1498 	}
1499 
1500 	request->sge[0].length = header_length;
1501 	request->sge[0].lkey = sc->ib.pd->local_dma_lkey;
1502 	request->num_sge = 1;
1503 
1504 	/* Fill in the data payload to find out how much data we can add */
1505 	if (iter) {
1506 		struct smb_extract_to_rdma extract = {
1507 			.nr_sge		= request->num_sge,
1508 			.max_sge	= SMBDIRECT_SEND_IO_MAX_SGE,
1509 			.sge		= request->sge,
1510 			.device		= sc->ib.dev,
1511 			.local_dma_lkey	= sc->ib.pd->local_dma_lkey,
1512 			.direction	= DMA_TO_DEVICE,
1513 		};
1514 		size_t payload_len = umin(*_remaining_data_length,
1515 					  sp->max_send_size - sizeof(*packet));
1516 
1517 		rc = smb_extract_iter_to_rdma(iter, payload_len,
1518 					      &extract);
1519 		if (rc < 0)
1520 			goto err_dma;
1521 		data_length = rc;
1522 		request->num_sge = extract.nr_sge;
1523 		*_remaining_data_length -= data_length;
1524 	} else {
1525 		data_length = 0;
1526 	}
1527 
1528 	/* Fill in the packet header */
1529 	packet->credits_requested = cpu_to_le16(sp->send_credit_target);
1530 	packet->credits_granted = cpu_to_le16(new_credits);
1531 
1532 	packet->flags = 0;
1533 	if (manage_keep_alive_before_sending(sc))
1534 		packet->flags |= cpu_to_le16(SMBDIRECT_FLAG_RESPONSE_REQUESTED);
1535 
1536 	packet->reserved = 0;
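	/*
	 * When a payload is present it starts at a fixed offset of 24
	 * bytes: the 20 bytes of fixed header fields plus 4 bytes of
	 * padding, keeping the data 8-byte aligned as [MS-SMBD] requires
	 * for data_offset.
	 */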
1537 	if (!data_length)
1538 		packet->data_offset = 0;
1539 	else
1540 		packet->data_offset = cpu_to_le32(24);
1541 	packet->data_length = cpu_to_le32(data_length);
1542 	packet->remaining_data_length = cpu_to_le32(*_remaining_data_length);
1543 	packet->padding = 0;
1544 
1545 	log_outgoing(INFO, "credits_requested=%d credits_granted=%d data_offset=%d data_length=%d remaining_data_length=%d\n",
1546 		     le16_to_cpu(packet->credits_requested),
1547 		     le16_to_cpu(packet->credits_granted),
1548 		     le32_to_cpu(packet->data_offset),
1549 		     le32_to_cpu(packet->data_length),
1550 		     le32_to_cpu(packet->remaining_data_length));
1551 
1552 	rc = smbd_post_send(sc, batch, request);
1553 	if (!rc) {
1554 		if (batch != &_batch)
1555 			return 0;
1556 
1557 		rc = smbd_send_batch_flush(sc, batch, true);
1558 		if (!rc)
1559 			return 0;
1560 	}
1561 
1562 err_dma:
1563 	smbd_free_send_io(request);
1564 
1565 err_alloc:
1566 	atomic_inc(&sc->send_io.credits.count);
1567 	wake_up(&sc->send_io.credits.wait_queue);
1568 
1569 err_wait_credit:
1570 	atomic_inc(&sc->send_io.lcredits.count);
1571 	wake_up(&sc->send_io.lcredits.wait_queue);
1572 
1573 err_wait_lcredit:
1574 	atomic_add(batch->credit, &sc->send_io.bcredits.count);
1575 	batch->credit = 0;
1576 	wake_up(&sc->send_io.bcredits.wait_queue);
1577 
1578 err_wait_bcredit:
1579 	return rc;
1580 }
1581 
1582 /*
1583  * Send an empty message
1584  * An empty message is used to extend credits to the peer and for keepalive
1585  * while there is no upper layer payload to send at the time
1586  */
1587 static int smbd_post_send_empty(struct smbdirect_socket *sc)
1588 {
1589 	int remaining_data_length = 0;
1590 
1591 	sc->statistics.send_empty++;
1592 	return smbd_post_send_iter(sc, NULL, NULL, &remaining_data_length);
1593 }
1594 
1595 static int smbd_post_send_full_iter(struct smbdirect_socket *sc,
1596 				    struct smbdirect_send_batch *batch,
1597 				    struct iov_iter *iter,
1598 				    int *_remaining_data_length)
1599 {
1600 	int rc = 0;
1601 
1602 	/*
1603 	 * smbd_post_send_iter() respects the
1604 	 * negotiated max_send_size, so we need to
1605 	 * loop until the full iter is posted
1606 	 */
1607 
1608 	while (iov_iter_count(iter) > 0) {
1609 		rc = smbd_post_send_iter(sc, batch, iter, _remaining_data_length);
1610 		if (rc < 0)
1611 			break;
1612 	}
1613 
1614 	return rc;
1615 }
1616 
1617 /*
1618  * Post a receive request to the transport
1619  * The remote peer can only send data when a receive request is posted
1620  * The interaction is controlled by the send/receive credit system
1621  */
1622 static int smbd_post_recv(
1623 		struct smbdirect_socket *sc, struct smbdirect_recv_io *response)
1624 {
1625 	struct smbdirect_socket_parameters *sp = &sc->parameters;
1626 	struct ib_recv_wr recv_wr;
1627 	int rc = -EIO;
1628 
1629 	response->sge.addr = ib_dma_map_single(
1630 				sc->ib.dev, response->packet,
1631 				sp->max_recv_size, DMA_FROM_DEVICE);
1632 	if (ib_dma_mapping_error(sc->ib.dev, response->sge.addr))
1633 		return rc;
1634 
1635 	response->sge.length = sp->max_recv_size;
1636 	response->sge.lkey = sc->ib.pd->local_dma_lkey;
1637 
1638 	response->cqe.done = recv_done;
1639 
1640 	recv_wr.wr_cqe = &response->cqe;
1641 	recv_wr.next = NULL;
1642 	recv_wr.sg_list = &response->sge;
1643 	recv_wr.num_sge = 1;
1644 
1645 	rc = ib_post_recv(sc->ib.qp, &recv_wr, NULL);
1646 	if (rc) {
1647 		ib_dma_unmap_single(sc->ib.dev, response->sge.addr,
1648 				    response->sge.length, DMA_FROM_DEVICE);
1649 		response->sge.length = 0;
1650 		smbd_disconnect_rdma_connection(sc);
1651 		log_rdma_recv(ERR, "ib_post_recv failed rc=%d\n", rc);
1652 	}
1653 
1654 	return rc;
1655 }
1656 
1657 /* Perform SMBD negotiation according to [MS-SMBD] 3.1.5.2 */
1658 static int smbd_negotiate(struct smbdirect_socket *sc)
1659 {
1660 	struct smbdirect_socket_parameters *sp = &sc->parameters;
1661 	int rc;
1662 	struct smbdirect_recv_io *response = get_receive_buffer(sc);
1663 
1664 	WARN_ON_ONCE(sc->status != SMBDIRECT_SOCKET_NEGOTIATE_NEEDED);
1665 	sc->status = SMBDIRECT_SOCKET_NEGOTIATE_RUNNING;
1666 
1667 	sc->recv_io.expected = SMBDIRECT_EXPECT_NEGOTIATE_REP;
1668 	rc = smbd_post_recv(sc, response);
1669 	log_rdma_event(INFO, "smbd_post_recv rc=%d iov.addr=0x%llx iov.length=%u iov.lkey=0x%x\n",
1670 		       rc, response->sge.addr,
1671 		       response->sge.length, response->sge.lkey);
1672 	if (rc) {
1673 		put_receive_buffer(sc, response);
1674 		return rc;
1675 	}
1676 
1677 	rc = smbd_post_send_negotiate_req(sc);
1678 	if (rc)
1679 		return rc;
1680 
1681 	rc = wait_event_interruptible_timeout(
1682 		sc->status_wait,
1683 		sc->status != SMBDIRECT_SOCKET_NEGOTIATE_RUNNING,
1684 		msecs_to_jiffies(sp->negotiate_timeout_msec));
1685 	log_rdma_event(INFO, "wait_event_interruptible_timeout rc=%d\n", rc);
1686 
1687 	if (sc->status == SMBDIRECT_SOCKET_CONNECTED)
1688 		return 0;
1689 
1690 	if (rc == 0)
1691 		rc = -ETIMEDOUT;
1692 	else if (rc == -ERESTARTSYS)
1693 		rc = -EINTR;
1694 	else
1695 		rc = -ENOTCONN;
1696 
1697 	return rc;
1698 }
1699 
1700 /*
1701  * Implement Connection.FragmentReassemblyBuffer defined in [MS-SMBD] 3.1.1.1
1702  * This is a queue for reassembling upper layer payload and presenting it to
1703  * the upper layer. All incoming payload goes to the reassembly queue, regardless
1704  * of whether reassembly is required. The upper layer code reads all incoming
1705  * payloads from the queue.
1706  * Put a received packet on the reassembly queue
1707  * response: the packet received
1708  * data_length: the size of payload in this packet
1709  */
1710 static void enqueue_reassembly(
1711 	struct smbdirect_socket *sc,
1712 	struct smbdirect_recv_io *response,
1713 	int data_length)
1714 {
1715 	unsigned long flags;
1716 
1717 	spin_lock_irqsave(&sc->recv_io.reassembly.lock, flags);
1718 	list_add_tail(&response->list, &sc->recv_io.reassembly.list);
1719 	sc->recv_io.reassembly.queue_length++;
1720 	/*
	 * Make sure reassembly_data_length is updated after the list and
	 * reassembly_queue_length are updated. On the dequeue side
	 * reassembly_data_length is checked without a lock to determine
	 * if reassembly_queue_length and the list are up to date.
	 * This virt_wmb() pairs with the virt_rmb() in smbd_recv().
1725 	 */
1726 	virt_wmb();
1727 	sc->recv_io.reassembly.data_length += data_length;
1728 	spin_unlock_irqrestore(&sc->recv_io.reassembly.lock, flags);
1729 	sc->statistics.enqueue_reassembly_queue++;
1730 }
1731 
1732 /*
1733  * Get the first entry at the front of reassembly queue
1734  * Caller is responsible for locking
1735  * return value: the first entry if any, NULL if queue is empty
1736  */
1737 static struct smbdirect_recv_io *_get_first_reassembly(struct smbdirect_socket *sc)
1738 {
1739 	struct smbdirect_recv_io *ret = NULL;
1740 
1741 	if (!list_empty(&sc->recv_io.reassembly.list)) {
1742 		ret = list_first_entry(
1743 			&sc->recv_io.reassembly.list,
1744 			struct smbdirect_recv_io, list);
1745 	}
1746 	return ret;
1747 }
1748 
1749 /*
1750  * Get a receive buffer
1751  * For each remote send, we need to post a receive. The receive buffers are
 * pre-allocated when the transport is established.
1753  * return value: the receive buffer, NULL if none is available
1754  */
1755 static struct smbdirect_recv_io *get_receive_buffer(struct smbdirect_socket *sc)
1756 {
1757 	struct smbdirect_recv_io *ret = NULL;
1758 	unsigned long flags;
1759 
1760 	spin_lock_irqsave(&sc->recv_io.free.lock, flags);
1761 	if (!list_empty(&sc->recv_io.free.list)) {
1762 		ret = list_first_entry(
1763 			&sc->recv_io.free.list,
1764 			struct smbdirect_recv_io, list);
1765 		list_del(&ret->list);
1766 		sc->statistics.get_receive_buffer++;
1767 	}
1768 	spin_unlock_irqrestore(&sc->recv_io.free.lock, flags);
1769 
1770 	return ret;
1771 }
1772 
1773 /*
1774  * Return a receive buffer
1775  * Upon returning of a receive buffer, we can post new receive and extend
1776  * more receive credits to remote peer. This is done immediately after a
1777  * receive buffer is returned.
1778  */
1779 static void put_receive_buffer(
1780 	struct smbdirect_socket *sc, struct smbdirect_recv_io *response)
1781 {
1782 	unsigned long flags;
1783 
1784 	if (likely(response->sge.length != 0)) {
1785 		ib_dma_unmap_single(sc->ib.dev,
1786 				    response->sge.addr,
1787 				    response->sge.length,
1788 				    DMA_FROM_DEVICE);
1789 		response->sge.length = 0;
1790 	}
1791 
1792 	spin_lock_irqsave(&sc->recv_io.free.lock, flags);
1793 	list_add_tail(&response->list, &sc->recv_io.free.list);
1794 	sc->statistics.put_receive_buffer++;
1795 	spin_unlock_irqrestore(&sc->recv_io.free.lock, flags);
1796 
1797 	queue_work(sc->workqueue, &sc->recv_io.posted.refill_work);
1798 }
1799 
/* Preallocate all receive buffers on transport establishment */
1801 static int allocate_receive_buffers(struct smbdirect_socket *sc, int num_buf)
1802 {
1803 	struct smbdirect_recv_io *response;
1804 	int i;
1805 
1806 	for (i = 0; i < num_buf; i++) {
1807 		response = mempool_alloc(sc->recv_io.mem.pool, GFP_KERNEL);
1808 		if (!response)
1809 			goto allocate_failed;
1810 
1811 		response->socket = sc;
1812 		response->sge.length = 0;
1813 		list_add_tail(&response->list, &sc->recv_io.free.list);
1814 	}
1815 
1816 	return 0;
1817 
1818 allocate_failed:
1819 	while (!list_empty(&sc->recv_io.free.list)) {
1820 		response = list_first_entry(
1821 				&sc->recv_io.free.list,
1822 				struct smbdirect_recv_io, list);
1823 		list_del(&response->list);
1824 
1825 		mempool_free(response, sc->recv_io.mem.pool);
1826 	}
1827 	return -ENOMEM;
1828 }
1829 
1830 static void destroy_receive_buffers(struct smbdirect_socket *sc)
1831 {
1832 	struct smbdirect_recv_io *response;
1833 
1834 	while ((response = get_receive_buffer(sc)))
1835 		mempool_free(response, sc->recv_io.mem.pool);
1836 }
1837 
1838 static void send_immediate_empty_message(struct work_struct *work)
1839 {
1840 	struct smbdirect_socket *sc =
1841 		container_of(work, struct smbdirect_socket, idle.immediate_work);
1842 
1843 	if (sc->status != SMBDIRECT_SOCKET_CONNECTED)
1844 		return;
1845 
1846 	log_keep_alive(INFO, "send an empty message\n");
1847 	smbd_post_send_empty(sc);
1848 }
1849 
1850 /* Implement idle connection timer [MS-SMBD] 3.1.6.2 */
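/*
 * Rough cycle (some steps handled elsewhere in this file, assumed here):
 * the delayed work normally fires after keepalive_interval_msec with
 * idle.keepalive == SMBDIRECT_KEEPALIVE_NONE, marks it PENDING, re-arms
 * itself with keepalive_timeout_msec and queues an empty keepalive message.
 * If nothing resets the state back to NONE (assumed to happen in the
 * receive path) before the timer fires again, the connection is torn down.
 */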
1851 static void idle_connection_timer(struct work_struct *work)
1852 {
1853 	struct smbdirect_socket *sc =
1854 		container_of(work, struct smbdirect_socket, idle.timer_work.work);
1855 	struct smbdirect_socket_parameters *sp = &sc->parameters;
1856 
1857 	if (sc->idle.keepalive != SMBDIRECT_KEEPALIVE_NONE) {
1858 		log_keep_alive(ERR,
1859 			"error status sc->idle.keepalive=%d\n",
1860 			sc->idle.keepalive);
1861 		smbd_disconnect_rdma_connection(sc);
1862 		return;
1863 	}
1864 
1865 	if (sc->status != SMBDIRECT_SOCKET_CONNECTED)
1866 		return;
1867 
1868 	/*
1869 	 * Now use the keepalive timeout (instead of keepalive interval)
1870 	 * in order to wait for a response
1871 	 */
1872 	sc->idle.keepalive = SMBDIRECT_KEEPALIVE_PENDING;
1873 	mod_delayed_work(sc->workqueue, &sc->idle.timer_work,
1874 			 msecs_to_jiffies(sp->keepalive_timeout_msec));
1875 	log_keep_alive(INFO, "schedule send of empty idle message\n");
1876 	queue_work(sc->workqueue, &sc->idle.immediate_work);
1877 }
1878 
1879 /*
1880  * Destroy the transport and related RDMA and memory resources
 * Need to go through all the pending counters and make sure no one is using
 * the transport while it is destroyed.
1883  */
1884 void smbd_destroy(struct TCP_Server_Info *server)
1885 {
1886 	struct smbd_connection *info = server->smbd_conn;
1887 	struct smbdirect_socket *sc;
1888 	struct smbdirect_recv_io *response;
1889 	unsigned long flags;
1890 
1891 	if (!info) {
1892 		log_rdma_event(INFO, "rdma session already destroyed\n");
1893 		return;
1894 	}
1895 	sc = &info->socket;
1896 
1897 	log_rdma_event(INFO, "cancelling and disable disconnect_work\n");
1898 	disable_work_sync(&sc->disconnect_work);
1899 
1900 	log_rdma_event(INFO, "destroying rdma session\n");
1901 	if (sc->status < SMBDIRECT_SOCKET_DISCONNECTING)
1902 		smbd_disconnect_rdma_work(&sc->disconnect_work);
1903 	if (sc->status < SMBDIRECT_SOCKET_DISCONNECTED) {
1904 		log_rdma_event(INFO, "wait for transport being disconnected\n");
1905 		wait_event(sc->status_wait, sc->status == SMBDIRECT_SOCKET_DISCONNECTED);
1906 		log_rdma_event(INFO, "waited for transport being disconnected\n");
1907 	}
1908 
1909 	/*
1910 	 * Wake up all waiters in all wait queues
1911 	 * in order to notice the broken connection.
1912 	 *
1913 	 * Most likely this was already called via
1914 	 * smbd_disconnect_rdma_work(), but call it again...
1915 	 */
1916 	smbd_disconnect_wake_up_all(sc);
1917 
1918 	log_rdma_event(INFO, "cancelling recv_io.posted.refill_work\n");
1919 	disable_work_sync(&sc->recv_io.posted.refill_work);
1920 
1921 	log_rdma_event(INFO, "destroying qp\n");
1922 	ib_drain_qp(sc->ib.qp);
1923 	rdma_destroy_qp(sc->rdma.cm_id);
1924 	sc->ib.qp = NULL;
1925 
1926 	log_rdma_event(INFO, "cancelling idle timer\n");
1927 	disable_delayed_work_sync(&sc->idle.timer_work);
1928 	log_rdma_event(INFO, "cancelling send immediate work\n");
1929 	disable_work_sync(&sc->idle.immediate_work);
1930 
	/* It's not possible for the upper layer to get to the reassembly queue */
1932 	log_rdma_event(INFO, "drain the reassembly queue\n");
1933 	do {
1934 		spin_lock_irqsave(&sc->recv_io.reassembly.lock, flags);
1935 		response = _get_first_reassembly(sc);
1936 		if (response) {
1937 			list_del(&response->list);
1938 			spin_unlock_irqrestore(
1939 				&sc->recv_io.reassembly.lock, flags);
1940 			put_receive_buffer(sc, response);
1941 		} else
1942 			spin_unlock_irqrestore(
1943 				&sc->recv_io.reassembly.lock, flags);
1944 	} while (response);
1945 	sc->recv_io.reassembly.data_length = 0;
1946 
1947 	log_rdma_event(INFO, "free receive buffers\n");
1948 	destroy_receive_buffers(sc);
1949 
1950 	log_rdma_event(INFO, "freeing mr list\n");
1951 	destroy_mr_list(sc);
1952 
1953 	ib_free_cq(sc->ib.send_cq);
1954 	ib_free_cq(sc->ib.recv_cq);
1955 	ib_dealloc_pd(sc->ib.pd);
1956 	rdma_destroy_id(sc->rdma.cm_id);
1957 
1958 	/* free mempools */
1959 	mempool_destroy(sc->send_io.mem.pool);
1960 	kmem_cache_destroy(sc->send_io.mem.cache);
1961 
1962 	mempool_destroy(sc->recv_io.mem.pool);
1963 	kmem_cache_destroy(sc->recv_io.mem.cache);
1964 
1965 	sc->status = SMBDIRECT_SOCKET_DESTROYED;
1966 
1967 	destroy_workqueue(sc->workqueue);
1968 	log_rdma_event(INFO,  "rdma session destroyed\n");
1969 	kfree(info);
1970 	server->smbd_conn = NULL;
1971 }
1972 
1973 /*
1974  * Reconnect this SMBD connection, called from upper layer
1975  * return value: 0 on success, or actual error code
1976  */
1977 int smbd_reconnect(struct TCP_Server_Info *server)
1978 {
1979 	log_rdma_event(INFO, "reconnecting rdma session\n");
1980 
1981 	if (!server->smbd_conn) {
1982 		log_rdma_event(INFO, "rdma session already destroyed\n");
1983 		goto create_conn;
1984 	}
1985 
1986 	/*
	 * This is possible if the transport is disconnected and we haven't received
	 * a notification from RDMA, but the upper layer has detected a timeout
1989 	 */
1990 	if (server->smbd_conn->socket.status == SMBDIRECT_SOCKET_CONNECTED) {
1991 		log_rdma_event(INFO, "disconnecting transport\n");
1992 		smbd_destroy(server);
1993 	}
1994 
1995 create_conn:
1996 	log_rdma_event(INFO, "creating rdma session\n");
1997 	server->smbd_conn = smbd_get_connection(
1998 		server, (struct sockaddr *) &server->dstaddr);
1999 
2000 	if (server->smbd_conn) {
2001 		cifs_dbg(VFS, "RDMA transport re-established\n");
2002 		trace_smb3_smbd_connect_done(server->hostname, server->conn_id, &server->dstaddr);
2003 		return 0;
2004 	}
2005 	trace_smb3_smbd_connect_err(server->hostname, server->conn_id, &server->dstaddr);
2006 	return -ENOENT;
2007 }
2008 
2009 static void destroy_caches(struct smbdirect_socket *sc)
2010 {
2011 	destroy_receive_buffers(sc);
2012 	mempool_destroy(sc->recv_io.mem.pool);
2013 	kmem_cache_destroy(sc->recv_io.mem.cache);
2014 	mempool_destroy(sc->send_io.mem.pool);
2015 	kmem_cache_destroy(sc->send_io.mem.cache);
2016 }
2017 
2018 #define MAX_NAME_LEN	80
2019 static int allocate_caches(struct smbdirect_socket *sc)
2020 {
2021 	struct smbdirect_socket_parameters *sp = &sc->parameters;
2022 	char name[MAX_NAME_LEN];
2023 	int rc;
2024 
2025 	if (WARN_ON_ONCE(sp->max_recv_size < sizeof(struct smbdirect_data_transfer)))
2026 		return -ENOMEM;
2027 
2028 	scnprintf(name, MAX_NAME_LEN, "smbdirect_send_io_%p", sc);
2029 	sc->send_io.mem.cache =
2030 		kmem_cache_create(
2031 			name,
2032 			sizeof(struct smbdirect_send_io) +
2033 				sizeof(struct smbdirect_data_transfer),
2034 			0, SLAB_HWCACHE_ALIGN, NULL);
2035 	if (!sc->send_io.mem.cache)
2036 		return -ENOMEM;
2037 
2038 	sc->send_io.mem.pool =
2039 		mempool_create(sp->send_credit_target, mempool_alloc_slab,
2040 			mempool_free_slab, sc->send_io.mem.cache);
2041 	if (!sc->send_io.mem.pool)
2042 		goto out1;
2043 
2044 	scnprintf(name, MAX_NAME_LEN, "smbdirect_recv_io_%p", sc);
2045 
2046 	struct kmem_cache_args response_args = {
2047 		.align		= __alignof__(struct smbdirect_recv_io),
2048 		.useroffset	= (offsetof(struct smbdirect_recv_io, packet) +
2049 				   sizeof(struct smbdirect_data_transfer)),
2050 		.usersize	= sp->max_recv_size - sizeof(struct smbdirect_data_transfer),
2051 	};
2052 	sc->recv_io.mem.cache =
2053 		kmem_cache_create(name,
2054 				  sizeof(struct smbdirect_recv_io) + sp->max_recv_size,
2055 				  &response_args, SLAB_HWCACHE_ALIGN);
2056 	if (!sc->recv_io.mem.cache)
2057 		goto out2;
2058 
2059 	sc->recv_io.mem.pool =
2060 		mempool_create(sp->recv_credit_max, mempool_alloc_slab,
2061 		       mempool_free_slab, sc->recv_io.mem.cache);
2062 	if (!sc->recv_io.mem.pool)
2063 		goto out3;
2064 
2065 	rc = allocate_receive_buffers(sc, sp->recv_credit_max);
2066 	if (rc) {
2067 		log_rdma_event(ERR, "failed to allocate receive buffers\n");
2068 		goto out4;
2069 	}
2070 
2071 	return 0;
2072 
2073 out4:
2074 	mempool_destroy(sc->recv_io.mem.pool);
2075 out3:
2076 	kmem_cache_destroy(sc->recv_io.mem.cache);
2077 out2:
2078 	mempool_destroy(sc->send_io.mem.pool);
2079 out1:
2080 	kmem_cache_destroy(sc->send_io.mem.cache);
2081 	return -ENOMEM;
2082 }
2083 
2084 /* Create a SMBD connection, called by upper layer */
2085 static struct smbd_connection *_smbd_get_connection(
2086 	struct TCP_Server_Info *server, struct sockaddr *dstaddr, int port)
2087 {
2088 	int rc;
2089 	struct smbd_connection *info;
2090 	struct smbdirect_socket *sc;
2091 	struct smbdirect_socket_parameters *sp;
2092 	struct rdma_conn_param conn_param;
2093 	struct ib_qp_cap qp_cap;
2094 	struct ib_qp_init_attr qp_attr;
2095 	struct sockaddr_in *addr_in = (struct sockaddr_in *) dstaddr;
2096 	struct ib_port_immutable port_immutable;
2097 	__be32 ird_ord_hdr[2];
2098 	char wq_name[80];
2099 	struct workqueue_struct *workqueue;
2100 
2101 	info = kzalloc(sizeof(struct smbd_connection), GFP_KERNEL);
2102 	if (!info)
2103 		return NULL;
2104 	sc = &info->socket;
2105 	scnprintf(wq_name, ARRAY_SIZE(wq_name), "smbd_%p", sc);
2106 	workqueue = create_workqueue(wq_name);
2107 	if (!workqueue)
2108 		goto create_wq_failed;
2109 	smbdirect_socket_init(sc);
2110 	sc->workqueue = workqueue;
2111 	sp = &sc->parameters;
2112 
2113 	INIT_WORK(&sc->disconnect_work, smbd_disconnect_rdma_work);
2114 
2115 	sp->resolve_addr_timeout_msec = RDMA_RESOLVE_TIMEOUT;
2116 	sp->resolve_route_timeout_msec = RDMA_RESOLVE_TIMEOUT;
2117 	sp->rdma_connect_timeout_msec = RDMA_RESOLVE_TIMEOUT;
2118 	sp->negotiate_timeout_msec = SMBD_NEGOTIATE_TIMEOUT * 1000;
2119 	sp->initiator_depth = 1;
2120 	sp->responder_resources = SMBD_CM_RESPONDER_RESOURCES;
2121 	sp->recv_credit_max = smbd_receive_credit_max;
2122 	sp->send_credit_target = smbd_send_credit_target;
2123 	sp->max_send_size = smbd_max_send_size;
2124 	sp->max_fragmented_recv_size = smbd_max_fragmented_recv_size;
2125 	sp->max_recv_size = smbd_max_receive_size;
2126 	sp->max_frmr_depth = smbd_max_frmr_depth;
2127 	sp->keepalive_interval_msec = smbd_keep_alive_interval * 1000;
2128 	sp->keepalive_timeout_msec = KEEPALIVE_RECV_TIMEOUT * 1000;
2129 
2130 	rc = smbd_ia_open(sc, dstaddr, port);
2131 	if (rc) {
2132 		log_rdma_event(INFO, "smbd_ia_open rc=%d\n", rc);
2133 		goto create_id_failed;
2134 	}
2135 
2136 	if (sp->send_credit_target > sc->ib.dev->attrs.max_cqe ||
2137 	    sp->send_credit_target > sc->ib.dev->attrs.max_qp_wr) {
2138 		log_rdma_event(ERR, "consider lowering send_credit_target = %d. Possible CQE overrun, device reporting max_cqe %d max_qp_wr %d\n",
2139 			       sp->send_credit_target,
2140 			       sc->ib.dev->attrs.max_cqe,
2141 			       sc->ib.dev->attrs.max_qp_wr);
2142 		goto config_failed;
2143 	}
2144 
2145 	if (sp->recv_credit_max > sc->ib.dev->attrs.max_cqe ||
2146 	    sp->recv_credit_max > sc->ib.dev->attrs.max_qp_wr) {
2147 		log_rdma_event(ERR, "consider lowering receive_credit_max = %d. Possible CQE overrun, device reporting max_cqe %d max_qp_wr %d\n",
2148 			       sp->recv_credit_max,
2149 			       sc->ib.dev->attrs.max_cqe,
2150 			       sc->ib.dev->attrs.max_qp_wr);
2151 		goto config_failed;
2152 	}
2153 
2154 	if (sc->ib.dev->attrs.max_send_sge < SMBDIRECT_SEND_IO_MAX_SGE ||
2155 	    sc->ib.dev->attrs.max_recv_sge < SMBDIRECT_RECV_IO_MAX_SGE) {
2156 		log_rdma_event(ERR,
2157 			"device %.*s max_send_sge/max_recv_sge = %d/%d too small\n",
2158 			IB_DEVICE_NAME_MAX,
2159 			sc->ib.dev->name,
2160 			sc->ib.dev->attrs.max_send_sge,
2161 			sc->ib.dev->attrs.max_recv_sge);
2162 		goto config_failed;
2163 	}
2164 
2165 	sp->responder_resources =
2166 		min_t(u8, sp->responder_resources,
2167 		      sc->ib.dev->attrs.max_qp_rd_atom);
2168 	log_rdma_mr(INFO, "responder_resources=%d\n",
2169 		sp->responder_resources);
2170 
2171 	/*
	 * We allocate sp->responder_resources * 2 MRs
	 * and each MR needs WRs for REG and INV, so
	 * we use '* 4'.
2175 	 *
2176 	 * +1 for ib_drain_qp()
2177 	 */
2178 	memset(&qp_cap, 0, sizeof(qp_cap));
2179 	qp_cap.max_send_wr = sp->send_credit_target + sp->responder_resources * 4 + 1;
2180 	qp_cap.max_recv_wr = sp->recv_credit_max + 1;
2181 	qp_cap.max_send_sge = SMBDIRECT_SEND_IO_MAX_SGE;
2182 	qp_cap.max_recv_sge = SMBDIRECT_RECV_IO_MAX_SGE;
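	/*
	 * Illustrative sizing (values assumed, before any device clamping):
	 * with send_credit_target = 255 and responder_resources = 32 this
	 * gives max_send_wr = 255 + 32 * 4 + 1 = 384, and with
	 * recv_credit_max = 255, max_recv_wr = 255 + 1 = 256.
	 */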
2183 
2184 	sc->ib.pd = ib_alloc_pd(sc->ib.dev, 0);
2185 	if (IS_ERR(sc->ib.pd)) {
2186 		rc = PTR_ERR(sc->ib.pd);
2187 		sc->ib.pd = NULL;
2188 		log_rdma_event(ERR, "ib_alloc_pd() returned %d\n", rc);
2189 		goto alloc_pd_failed;
2190 	}
2191 
2192 	sc->ib.send_cq =
2193 		ib_alloc_cq_any(sc->ib.dev, sc,
2194 				qp_cap.max_send_wr, IB_POLL_SOFTIRQ);
2195 	if (IS_ERR(sc->ib.send_cq)) {
2196 		sc->ib.send_cq = NULL;
2197 		goto alloc_cq_failed;
2198 	}
2199 
2200 	sc->ib.recv_cq =
2201 		ib_alloc_cq_any(sc->ib.dev, sc,
2202 				qp_cap.max_recv_wr, IB_POLL_SOFTIRQ);
2203 	if (IS_ERR(sc->ib.recv_cq)) {
2204 		sc->ib.recv_cq = NULL;
2205 		goto alloc_cq_failed;
2206 	}
2207 
2208 	memset(&qp_attr, 0, sizeof(qp_attr));
2209 	qp_attr.event_handler = smbd_qp_async_error_upcall;
2210 	qp_attr.qp_context = sc;
2211 	qp_attr.cap = qp_cap;
2212 	qp_attr.sq_sig_type = IB_SIGNAL_REQ_WR;
2213 	qp_attr.qp_type = IB_QPT_RC;
2214 	qp_attr.send_cq = sc->ib.send_cq;
2215 	qp_attr.recv_cq = sc->ib.recv_cq;
2216 	qp_attr.port_num = ~0;
2217 
2218 	rc = rdma_create_qp(sc->rdma.cm_id, sc->ib.pd, &qp_attr);
2219 	if (rc) {
2220 		log_rdma_event(ERR, "rdma_create_qp failed %i\n", rc);
2221 		goto create_qp_failed;
2222 	}
2223 	sc->ib.qp = sc->rdma.cm_id->qp;
2224 
2225 	memset(&conn_param, 0, sizeof(conn_param));
2226 	conn_param.initiator_depth = sp->initiator_depth;
2227 	conn_param.responder_resources = sp->responder_resources;
2228 
2229 	/* Need to send IRD/ORD in private data for iWARP */
2230 	sc->ib.dev->ops.get_port_immutable(
2231 		sc->ib.dev, sc->rdma.cm_id->port_num, &port_immutable);
2232 	if (port_immutable.core_cap_flags & RDMA_CORE_PORT_IWARP) {
2233 		ird_ord_hdr[0] = cpu_to_be32(conn_param.responder_resources);
2234 		ird_ord_hdr[1] = cpu_to_be32(conn_param.initiator_depth);
2235 		conn_param.private_data = ird_ord_hdr;
2236 		conn_param.private_data_len = sizeof(ird_ord_hdr);
2237 	} else {
2238 		conn_param.private_data = NULL;
2239 		conn_param.private_data_len = 0;
2240 	}
2241 
2242 	conn_param.retry_count = SMBD_CM_RETRY;
2243 	conn_param.rnr_retry_count = SMBD_CM_RNR_RETRY;
2244 	conn_param.flow_control = 0;
2245 
2246 	log_rdma_event(INFO, "connecting to IP %pI4 port %d\n",
2247 		&addr_in->sin_addr, port);
2248 
2249 	WARN_ON_ONCE(sc->status != SMBDIRECT_SOCKET_RDMA_CONNECT_NEEDED);
2250 	sc->status = SMBDIRECT_SOCKET_RDMA_CONNECT_RUNNING;
2251 	rc = rdma_connect(sc->rdma.cm_id, &conn_param);
2252 	if (rc) {
2253 		log_rdma_event(ERR, "rdma_connect() failed with %i\n", rc);
2254 		goto rdma_connect_failed;
2255 	}
2256 
2257 	wait_event_interruptible_timeout(
2258 		sc->status_wait,
2259 		sc->status != SMBDIRECT_SOCKET_RDMA_CONNECT_RUNNING,
2260 		msecs_to_jiffies(sp->rdma_connect_timeout_msec));
2261 
2262 	if (sc->status != SMBDIRECT_SOCKET_NEGOTIATE_NEEDED) {
2263 		log_rdma_event(ERR, "rdma_connect failed port=%d\n", port);
2264 		goto rdma_connect_failed;
2265 	}
2266 
2267 	log_rdma_event(INFO, "rdma_connect connected\n");
2268 
2269 	rc = allocate_caches(sc);
2270 	if (rc) {
2271 		log_rdma_event(ERR, "cache allocation failed\n");
2272 		goto allocate_cache_failed;
2273 	}
2274 
2275 	INIT_WORK(&sc->idle.immediate_work, send_immediate_empty_message);
2276 	INIT_DELAYED_WORK(&sc->idle.timer_work, idle_connection_timer);
2277 	/*
	 * Start with the negotiate timeout and SMBDIRECT_KEEPALIVE_PENDING
	 * so that the timer will cause a disconnect if negotiation
	 * does not complete in time.
2280 	 */
2281 	sc->idle.keepalive = SMBDIRECT_KEEPALIVE_PENDING;
2282 	mod_delayed_work(sc->workqueue, &sc->idle.timer_work,
2283 			 msecs_to_jiffies(sp->negotiate_timeout_msec));
2284 
2285 	INIT_WORK(&sc->recv_io.posted.refill_work, smbd_post_send_credits);
2286 
2287 	rc = smbd_negotiate(sc);
2288 	if (rc) {
2289 		log_rdma_event(ERR, "smbd_negotiate rc=%d\n", rc);
2290 		goto negotiation_failed;
2291 	}
2292 
2293 	rc = allocate_mr_list(sc);
2294 	if (rc) {
2295 		log_rdma_mr(ERR, "memory registration allocation failed\n");
2296 		goto allocate_mr_failed;
2297 	}
2298 
2299 	return info;
2300 
2301 allocate_mr_failed:
	/* At this point, we need a full transport shutdown */
2303 	server->smbd_conn = info;
2304 	smbd_destroy(server);
2305 	return NULL;
2306 
2307 negotiation_failed:
2308 	disable_delayed_work_sync(&sc->idle.timer_work);
2309 	destroy_caches(sc);
2310 	sc->status = SMBDIRECT_SOCKET_NEGOTIATE_FAILED;
2311 	rdma_disconnect(sc->rdma.cm_id);
2312 	wait_event(sc->status_wait,
2313 		sc->status == SMBDIRECT_SOCKET_DISCONNECTED);
2314 
2315 allocate_cache_failed:
2316 rdma_connect_failed:
2317 	rdma_destroy_qp(sc->rdma.cm_id);
2318 
2319 create_qp_failed:
2320 alloc_cq_failed:
2321 	if (sc->ib.send_cq)
2322 		ib_free_cq(sc->ib.send_cq);
2323 	if (sc->ib.recv_cq)
2324 		ib_free_cq(sc->ib.recv_cq);
2325 
2326 	ib_dealloc_pd(sc->ib.pd);
2327 
2328 alloc_pd_failed:
2329 config_failed:
2330 	rdma_destroy_id(sc->rdma.cm_id);
2331 
2332 create_id_failed:
2333 	destroy_workqueue(sc->workqueue);
2334 create_wq_failed:
2335 	kfree(info);
2336 	return NULL;
2337 }
2338 
2339 struct smbd_connection *smbd_get_connection(
2340 	struct TCP_Server_Info *server, struct sockaddr *dstaddr)
2341 {
2342 	struct smbd_connection *ret;
2343 	const struct smbdirect_socket_parameters *sp;
2344 	int port = SMBD_PORT;
2345 
2346 try_again:
2347 	ret = _smbd_get_connection(server, dstaddr, port);
2348 
2349 	/* Try SMB_PORT if SMBD_PORT doesn't work */
2350 	if (!ret && port == SMBD_PORT) {
2351 		port = SMB_PORT;
2352 		goto try_again;
2353 	}
2354 	if (!ret)
2355 		return NULL;
2356 
2357 	sp = &ret->socket.parameters;
2358 
2359 	server->rdma_readwrite_threshold =
2360 		rdma_readwrite_threshold > sp->max_fragmented_send_size ?
2361 		sp->max_fragmented_send_size :
2362 		rdma_readwrite_threshold;
2363 
2364 	return ret;
2365 }
2366 
2367 /*
2368  * Receive data from the transport's receive reassembly queue
 * All the incoming data packets are placed in the reassembly queue
2370  * iter: the buffer to read data into
2371  * size: the length of data to read
2372  * return value: actual data read
2373  *
 * Note: this implementation copies the data from the reassembly queue to the
 * receive buffers used by the upper layer. This is not the optimal code path. A
 * better way would be to have the upper layer borrow the buffer from the
 * reassembly queue instead of allocating its own, and return it after the data
 * is consumed. But this would require more changes to the upper layer code, and
 * it would also need to consider packet boundaries while they are still being
 * reassembled.
2380  */
2381 int smbd_recv(struct smbd_connection *info, struct msghdr *msg)
2382 {
2383 	struct smbdirect_socket *sc = &info->socket;
2384 	struct smbdirect_recv_io *response;
2385 	struct smbdirect_data_transfer *data_transfer;
2386 	size_t size = iov_iter_count(&msg->msg_iter);
2387 	int to_copy, to_read, data_read, offset;
2388 	u32 data_length, remaining_data_length, data_offset;
2389 	int rc;
2390 
2391 	if (WARN_ON_ONCE(iov_iter_rw(&msg->msg_iter) == WRITE))
2392 		return -EINVAL; /* It's a bug in upper layer to get there */
2393 
2394 again:
2395 	/*
2396 	 * No need to hold the reassembly queue lock all the time as we are
2397 	 * the only one reading from the front of the queue. The transport
2398 	 * may add more entries to the back of the queue at the same time
2399 	 */
2400 	log_read(INFO, "size=%zd sc->recv_io.reassembly.data_length=%d\n", size,
2401 		sc->recv_io.reassembly.data_length);
2402 	if (sc->recv_io.reassembly.data_length >= size) {
2403 		int queue_length;
2404 		int queue_removed = 0;
2405 		unsigned long flags;
2406 
2407 		/*
2408 		 * Need to make sure reassembly_data_length is read before
2409 		 * reading reassembly_queue_length and calling
		 * _get_first_reassembly. This call is lock free
		 * as we never read entries at the end of the queue, which are
		 * being updated in SOFTIRQ context as more data is received.
2413 		 */
2414 		virt_rmb();
2415 		queue_length = sc->recv_io.reassembly.queue_length;
2416 		data_read = 0;
2417 		to_read = size;
2418 		offset = sc->recv_io.reassembly.first_entry_offset;
2419 		while (data_read < size) {
2420 			response = _get_first_reassembly(sc);
2421 			data_transfer = smbdirect_recv_io_payload(response);
2422 			data_length = le32_to_cpu(data_transfer->data_length);
2423 			remaining_data_length =
2424 				le32_to_cpu(
2425 					data_transfer->remaining_data_length);
2426 			data_offset = le32_to_cpu(data_transfer->data_offset);
2427 
2428 			/*
			 * The upper layer expects an RFC1002 length at the
			 * beginning of the payload. Return it to indicate
			 * the total length of the packet. This minimizes the
			 * changes to the upper layer packet processing logic.
			 * This will eventually be removed when an intermediate
			 * transport layer is added.
2435 			 */
2436 			if (response->first_segment && size == 4) {
2437 				unsigned int rfc1002_len =
2438 					data_length + remaining_data_length;
2439 				__be32 rfc1002_hdr = cpu_to_be32(rfc1002_len);
2440 				if (copy_to_iter(&rfc1002_hdr, sizeof(rfc1002_hdr),
2441 						 &msg->msg_iter) != sizeof(rfc1002_hdr))
2442 					return -EFAULT;
2443 				data_read = 4;
2444 				response->first_segment = false;
2445 				log_read(INFO, "returning rfc1002 length %d\n",
2446 					rfc1002_len);
2447 				goto read_rfc1002_done;
2448 			}
2449 
2450 			to_copy = min_t(int, data_length - offset, to_read);
2451 			if (copy_to_iter((char *)data_transfer + data_offset + offset,
2452 					 to_copy, &msg->msg_iter) != to_copy)
2453 				return -EFAULT;
2454 
2455 			/* move on to the next buffer? */
2456 			if (to_copy == data_length - offset) {
2457 				queue_length--;
2458 				/*
2459 				 * No need to lock if we are not at the
2460 				 * end of the queue
2461 				 */
2462 				if (queue_length)
2463 					list_del(&response->list);
2464 				else {
2465 					spin_lock_irqsave(
2466 						&sc->recv_io.reassembly.lock, flags);
2467 					list_del(&response->list);
2468 					spin_unlock_irqrestore(
2469 						&sc->recv_io.reassembly.lock, flags);
2470 				}
2471 				queue_removed++;
2472 				sc->statistics.dequeue_reassembly_queue++;
2473 				put_receive_buffer(sc, response);
2474 				offset = 0;
2475 				log_read(INFO, "put_receive_buffer offset=0\n");
2476 			} else
2477 				offset += to_copy;
2478 
2479 			to_read -= to_copy;
2480 			data_read += to_copy;
2481 
2482 			log_read(INFO, "_get_first_reassembly memcpy %d bytes data_transfer_length-offset=%d after that to_read=%d data_read=%d offset=%d\n",
2483 				 to_copy, data_length - offset,
2484 				 to_read, data_read, offset);
2485 		}
2486 
2487 		spin_lock_irqsave(&sc->recv_io.reassembly.lock, flags);
2488 		sc->recv_io.reassembly.data_length -= data_read;
2489 		sc->recv_io.reassembly.queue_length -= queue_removed;
2490 		spin_unlock_irqrestore(&sc->recv_io.reassembly.lock, flags);
2491 
2492 		sc->recv_io.reassembly.first_entry_offset = offset;
2493 		log_read(INFO, "returning to thread data_read=%d reassembly_data_length=%d first_entry_offset=%d\n",
2494 			 data_read, sc->recv_io.reassembly.data_length,
2495 			 sc->recv_io.reassembly.first_entry_offset);
2496 read_rfc1002_done:
2497 		return data_read;
2498 	}
2499 
2500 	log_read(INFO, "wait_event on more data\n");
2501 	rc = wait_event_interruptible(
2502 		sc->recv_io.reassembly.wait_queue,
2503 		sc->recv_io.reassembly.data_length >= size ||
2504 			sc->status != SMBDIRECT_SOCKET_CONNECTED);
2505 	/* Don't return any data if interrupted */
2506 	if (rc)
2507 		return rc;
2508 
2509 	if (sc->status != SMBDIRECT_SOCKET_CONNECTED) {
2510 		log_read(ERR, "disconnected\n");
2511 		return -ECONNABORTED;
2512 	}
2513 
2514 	goto again;
2515 }
2516 
2517 /*
2518  * Send data to transport
 * Each rqst is transported as an SMBDirect payload
 * rqst: the data to write
 * return value: 0 if successfully written, otherwise error code
2522  */
2523 int smbd_send(struct TCP_Server_Info *server,
2524 	int num_rqst, struct smb_rqst *rqst_array)
2525 {
2526 	struct smbd_connection *info = server->smbd_conn;
2527 	struct smbdirect_socket *sc = &info->socket;
2528 	struct smbdirect_socket_parameters *sp = &sc->parameters;
2529 	struct smb_rqst *rqst;
2530 	struct iov_iter iter;
2531 	struct smbdirect_send_batch batch;
2532 	unsigned int remaining_data_length, klen;
2533 	int rc, i, rqst_idx;
2534 	int error = 0;
2535 
2536 	if (sc->status != SMBDIRECT_SOCKET_CONNECTED)
2537 		return -EAGAIN;
2538 
2539 	/*
2540 	 * Add in the page array if there is one. The caller needs to set
2541 	 * rq_tailsz to PAGE_SIZE when the buffer has multiple pages and
	 * ends at a page boundary.
2543 	 */
2544 	remaining_data_length = 0;
2545 	for (i = 0; i < num_rqst; i++)
2546 		remaining_data_length += smb_rqst_len(server, &rqst_array[i]);
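	/*
	 * remaining_data_length starts as the total payload of all requests
	 * and is updated by smbd_post_send_iter() as fragments are posted, so
	 * each SMBDirect data transfer message can advertise how many bytes
	 * are still to follow (RemainingDataLength in [MS-SMBD]).
	 */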
2547 
2548 	if (unlikely(remaining_data_length > sp->max_fragmented_send_size)) {
2549 		/* assertion: payload never exceeds negotiated maximum */
2550 		log_write(ERR, "payload size %d > max size %d\n",
2551 			remaining_data_length, sp->max_fragmented_send_size);
2552 		return -EINVAL;
2553 	}
2554 
2555 	log_write(INFO, "num_rqst=%d total length=%u\n",
2556 			num_rqst, remaining_data_length);
2557 
2558 	rqst_idx = 0;
2559 	smbd_send_batch_init(&batch, false, 0);
2560 	do {
2561 		rqst = &rqst_array[rqst_idx];
2562 
2563 		cifs_dbg(FYI, "Sending smb (RDMA): idx=%d smb_len=%lu\n",
2564 			 rqst_idx, smb_rqst_len(server, rqst));
2565 		for (i = 0; i < rqst->rq_nvec; i++)
2566 			dump_smb(rqst->rq_iov[i].iov_base, rqst->rq_iov[i].iov_len);
2567 
2568 		log_write(INFO, "RDMA-WR[%u] nvec=%d len=%u iter=%zu rqlen=%lu\n",
2569 			  rqst_idx, rqst->rq_nvec, remaining_data_length,
2570 			  iov_iter_count(&rqst->rq_iter), smb_rqst_len(server, rqst));
2571 
2572 		/* Send the metadata pages. */
2573 		klen = 0;
2574 		for (i = 0; i < rqst->rq_nvec; i++)
2575 			klen += rqst->rq_iov[i].iov_len;
2576 		iov_iter_kvec(&iter, ITER_SOURCE, rqst->rq_iov, rqst->rq_nvec, klen);
2577 
2578 		rc = smbd_post_send_full_iter(sc, &batch, &iter, &remaining_data_length);
2579 		if (rc < 0) {
2580 			error = rc;
2581 			break;
2582 		}
2583 
2584 		if (iov_iter_count(&rqst->rq_iter) > 0) {
2585 			/* And then the data pages if there are any */
2586 			rc = smbd_post_send_full_iter(sc, &batch, &rqst->rq_iter,
2587 						      &remaining_data_length);
2588 			if (rc < 0) {
2589 				error = rc;
2590 				break;
2591 			}
2592 		}
2593 
2594 	} while (++rqst_idx < num_rqst);
2595 
2596 	rc = smbd_send_batch_flush(sc, &batch, true);
2597 	if (unlikely(!rc && error))
2598 		rc = error;
2599 
2600 	/*
2601 	 * As an optimization, we don't wait for individual I/O to finish
2602 	 * before sending the next one.
	 * Send them all and wait for the pending send count to reach 0,
	 * which means all the I/Os have gone out and we are good to return.
2605 	 */
2606 
2607 	wait_event(sc->send_io.pending.zero_wait_queue,
2608 		atomic_read(&sc->send_io.pending.count) == 0 ||
2609 		sc->status != SMBDIRECT_SOCKET_CONNECTED);
2610 
2611 	if (sc->status != SMBDIRECT_SOCKET_CONNECTED && rc == 0)
2612 		rc = -EAGAIN;
2613 
2614 	return rc;
2615 }
2616 
2617 static void register_mr_done(struct ib_cq *cq, struct ib_wc *wc)
2618 {
2619 	struct smbdirect_mr_io *mr =
2620 		container_of(wc->wr_cqe, struct smbdirect_mr_io, cqe);
2621 	struct smbdirect_socket *sc = mr->socket;
2622 
2623 	if (wc->status) {
2624 		log_rdma_mr(ERR, "status=%d\n", wc->status);
2625 		smbd_disconnect_rdma_connection(sc);
2626 	}
2627 }
2628 
2629 /*
2630  * The work queue function that recovers MRs
2631  * We need to call ib_dereg_mr() and ib_alloc_mr() before this MR can be used
 * again. Both calls are slow, so finish them in a workqueue. This will not
 * block the I/O path.
 * There is one workqueue that recovers MRs, so there is no need to lock as the
2635  * I/O requests calling smbd_register_mr will never update the links in the
2636  * mr_list.
2637  */
2638 static void smbd_mr_recovery_work(struct work_struct *work)
2639 {
2640 	struct smbdirect_socket *sc =
2641 		container_of(work, struct smbdirect_socket, mr_io.recovery_work);
2642 	struct smbdirect_socket_parameters *sp = &sc->parameters;
2643 	struct smbdirect_mr_io *smbdirect_mr;
2644 	int rc;
2645 
2646 	list_for_each_entry(smbdirect_mr, &sc->mr_io.all.list, list) {
2647 		if (smbdirect_mr->state == SMBDIRECT_MR_ERROR) {
2648 
2649 			/* recover this MR entry */
2650 			rc = ib_dereg_mr(smbdirect_mr->mr);
2651 			if (rc) {
2652 				log_rdma_mr(ERR,
2653 					"ib_dereg_mr failed rc=%x\n",
2654 					rc);
2655 				smbd_disconnect_rdma_connection(sc);
2656 				continue;
2657 			}
2658 
2659 			smbdirect_mr->mr = ib_alloc_mr(
2660 				sc->ib.pd, sc->mr_io.type,
2661 				sp->max_frmr_depth);
2662 			if (IS_ERR(smbdirect_mr->mr)) {
2663 				log_rdma_mr(ERR, "ib_alloc_mr failed mr_type=%x max_frmr_depth=%x\n",
2664 					    sc->mr_io.type,
2665 					    sp->max_frmr_depth);
2666 				smbd_disconnect_rdma_connection(sc);
2667 				continue;
2668 			}
2669 		} else
2670 			/* This MR is being used, don't recover it */
2671 			continue;
2672 
2673 		smbdirect_mr->state = SMBDIRECT_MR_READY;
2674 
		/* smbdirect_mr->state is updated by this function
		 * and is read and updated by I/O issuing CPUs trying
		 * to get an MR. The call to atomic_inc_return
		 * implies a memory barrier and guarantees this
		 * value is updated before waking up any calls to
		 * get_mr() from the I/O issuing CPUs.
2681 		 */
2682 		if (atomic_inc_return(&sc->mr_io.ready.count) == 1)
2683 			wake_up(&sc->mr_io.ready.wait_queue);
2684 	}
2685 }
2686 
2687 static void smbd_mr_disable_locked(struct smbdirect_mr_io *mr)
2688 {
2689 	struct smbdirect_socket *sc = mr->socket;
2690 
2691 	lockdep_assert_held(&mr->mutex);
2692 
2693 	if (mr->state == SMBDIRECT_MR_DISABLED)
2694 		return;
2695 
2696 	if (mr->mr)
2697 		ib_dereg_mr(mr->mr);
2698 	if (mr->sgt.nents)
2699 		ib_dma_unmap_sg(sc->ib.dev, mr->sgt.sgl, mr->sgt.nents, mr->dir);
2700 	kfree(mr->sgt.sgl);
2701 
2702 	mr->mr = NULL;
2703 	mr->sgt.sgl = NULL;
2704 	mr->sgt.nents = 0;
2705 
2706 	mr->state = SMBDIRECT_MR_DISABLED;
2707 }
2708 
2709 static void smbd_mr_free_locked(struct kref *kref)
2710 {
2711 	struct smbdirect_mr_io *mr =
2712 		container_of(kref, struct smbdirect_mr_io, kref);
2713 
2714 	lockdep_assert_held(&mr->mutex);
2715 
2716 	/*
2717 	 * smbd_mr_disable_locked() should already be called!
2718 	 */
2719 	if (WARN_ON_ONCE(mr->state != SMBDIRECT_MR_DISABLED))
2720 		smbd_mr_disable_locked(mr);
2721 
2722 	mutex_unlock(&mr->mutex);
2723 	mutex_destroy(&mr->mutex);
2724 	kfree(mr);
2725 }
2726 
2727 static void destroy_mr_list(struct smbdirect_socket *sc)
2728 {
2729 	struct smbdirect_mr_io *mr, *tmp;
2730 	LIST_HEAD(all_list);
2731 	unsigned long flags;
2732 
2733 	disable_work_sync(&sc->mr_io.recovery_work);
2734 
2735 	spin_lock_irqsave(&sc->mr_io.all.lock, flags);
2736 	list_splice_tail_init(&sc->mr_io.all.list, &all_list);
2737 	spin_unlock_irqrestore(&sc->mr_io.all.lock, flags);
2738 
2739 	list_for_each_entry_safe(mr, tmp, &all_list, list) {
2740 		mutex_lock(&mr->mutex);
2741 
2742 		smbd_mr_disable_locked(mr);
2743 		list_del(&mr->list);
2744 		mr->socket = NULL;
2745 
2746 		/*
2747 		 * No kref_put_mutex() as it's already locked.
2748 		 *
2749 		 * If smbd_mr_free_locked() is called
2750 		 * and the mutex is unlocked and mr is gone,
2751 		 * in that case kref_put() returned 1.
2752 		 *
2753 		 * If kref_put() returned 0 we know that
2754 		 * smbd_mr_free_locked() didn't
2755 		 * run. Not by us nor by anyone else, as we
2756 		 * still hold the mutex, so we need to unlock.
2757 		 *
2758 		 * If the mr is still registered it will
2759 		 * be dangling (detached from the connection
2760 		 * waiting for smbd_deregister_mr() to be
2761 		 * called in order to free the memory.
2762 		 */
2763 		if (!kref_put(&mr->kref, smbd_mr_free_locked))
2764 			mutex_unlock(&mr->mutex);
2765 	}
2766 }
2767 
2768 /*
2769  * Allocate MRs used for RDMA read/write
2770  * The number of MRs will not exceed hardware capability in responder_resources
 * All MRs are kept in mr_list. An MR can be recovered after it's used.
 * Recovery is done in smbd_mr_recovery_work. The content of a list entry changes
 * as MRs are used and recovered for I/O, but the list links will not change.
2774  */
2775 static int allocate_mr_list(struct smbdirect_socket *sc)
2776 {
2777 	struct smbdirect_socket_parameters *sp = &sc->parameters;
2778 	struct smbdirect_mr_io *mr;
2779 	int ret;
2780 	u32 i;
2781 
2782 	if (sp->responder_resources == 0) {
2783 		log_rdma_mr(ERR, "responder_resources negotiated as 0\n");
2784 		return -EINVAL;
2785 	}
2786 
2787 	/* Allocate more MRs (2x) than hardware responder_resources */
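	/*
	 * E.g. with responder_resources negotiated as 32 (value assumed),
	 * 64 MRs are allocated; this matches the '* 4' work request budget
	 * reserved for REG/INV when sizing the QP.
	 */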
2788 	for (i = 0; i < sp->responder_resources * 2; i++) {
2789 		mr = kzalloc(sizeof(*mr), GFP_KERNEL);
2790 		if (!mr) {
2791 			ret = -ENOMEM;
2792 			goto kzalloc_mr_failed;
2793 		}
2794 
2795 		kref_init(&mr->kref);
2796 		mutex_init(&mr->mutex);
2797 
2798 		mr->mr = ib_alloc_mr(sc->ib.pd,
2799 				     sc->mr_io.type,
2800 				     sp->max_frmr_depth);
2801 		if (IS_ERR(mr->mr)) {
2802 			ret = PTR_ERR(mr->mr);
2803 			log_rdma_mr(ERR, "ib_alloc_mr failed mr_type=%x max_frmr_depth=%x\n",
2804 				    sc->mr_io.type, sp->max_frmr_depth);
2805 			goto ib_alloc_mr_failed;
2806 		}
2807 
2808 		mr->sgt.sgl = kcalloc(sp->max_frmr_depth,
2809 				      sizeof(struct scatterlist),
2810 				      GFP_KERNEL);
2811 		if (!mr->sgt.sgl) {
2812 			ret = -ENOMEM;
2813 			log_rdma_mr(ERR, "failed to allocate sgl\n");
2814 			goto kcalloc_sgl_failed;
2815 		}
2816 		mr->state = SMBDIRECT_MR_READY;
2817 		mr->socket = sc;
2818 
2819 		list_add_tail(&mr->list, &sc->mr_io.all.list);
2820 		atomic_inc(&sc->mr_io.ready.count);
2821 	}
2822 
2823 	INIT_WORK(&sc->mr_io.recovery_work, smbd_mr_recovery_work);
2824 
2825 	return 0;
2826 
2827 kcalloc_sgl_failed:
2828 	ib_dereg_mr(mr->mr);
2829 ib_alloc_mr_failed:
2830 	mutex_destroy(&mr->mutex);
2831 	kfree(mr);
2832 kzalloc_mr_failed:
2833 	destroy_mr_list(sc);
2834 	return ret;
2835 }
2836 
2837 /*
 * Get an MR from mr_list. This function waits until there is at least one
 * MR available in the list. It may access the list while
 * smbd_mr_recovery_work is recovering the MR list. This doesn't need a lock
 * as they never modify the same places. However, there may be several CPUs
 * issuing I/O trying to get an MR at the same time, so sc->mr_io.all.lock is
 * used to protect against this situation.
2844  */
2845 static struct smbdirect_mr_io *get_mr(struct smbdirect_socket *sc)
2846 {
2847 	struct smbdirect_mr_io *ret;
2848 	unsigned long flags;
2849 	int rc;
2850 again:
2851 	rc = wait_event_interruptible(sc->mr_io.ready.wait_queue,
2852 		atomic_read(&sc->mr_io.ready.count) ||
2853 		sc->status != SMBDIRECT_SOCKET_CONNECTED);
2854 	if (rc) {
2855 		log_rdma_mr(ERR, "wait_event_interruptible rc=%x\n", rc);
2856 		return NULL;
2857 	}
2858 
2859 	if (sc->status != SMBDIRECT_SOCKET_CONNECTED) {
2860 		log_rdma_mr(ERR, "sc->status=%x\n", sc->status);
2861 		return NULL;
2862 	}
2863 
2864 	spin_lock_irqsave(&sc->mr_io.all.lock, flags);
2865 	list_for_each_entry(ret, &sc->mr_io.all.list, list) {
2866 		if (ret->state == SMBDIRECT_MR_READY) {
2867 			ret->state = SMBDIRECT_MR_REGISTERED;
2868 			kref_get(&ret->kref);
2869 			spin_unlock_irqrestore(&sc->mr_io.all.lock, flags);
2870 			atomic_dec(&sc->mr_io.ready.count);
2871 			atomic_inc(&sc->mr_io.used.count);
2872 			return ret;
2873 		}
2874 	}
2875 
2876 	spin_unlock_irqrestore(&sc->mr_io.all.lock, flags);
2877 	/*
	 * It is possible that we could fail to get an MR because other processes may
	 * try to acquire an MR at the same time. If this is the case, retry.
2880 	 */
2881 	goto again;
2882 }
2883 
2884 /*
2885  * Transcribe the pages from an iterator into an MR scatterlist.
2886  */
2887 static int smbd_iter_to_mr(struct iov_iter *iter,
2888 			   struct sg_table *sgt,
2889 			   unsigned int max_sg)
2890 {
2891 	int ret;
2892 
2893 	memset(sgt->sgl, 0, max_sg * sizeof(struct scatterlist));
2894 
2895 	ret = extract_iter_to_sg(iter, iov_iter_count(iter), sgt, max_sg, 0);
2896 	WARN_ON(ret < 0);
2897 	if (sgt->nents > 0)
2898 		sg_mark_end(&sgt->sgl[sgt->nents - 1]);
2899 	return ret;
2900 }
2901 
2902 /*
2903  * Register memory for RDMA read/write
2904  * iter: the buffer to register memory with
 * writing: true if this is an RDMA write (SMB read), false for RDMA read
2906  * need_invalidate: true if this MR needs to be locally invalidated after I/O
2907  * return value: the MR registered, NULL if failed.
2908  */
2909 struct smbdirect_mr_io *smbd_register_mr(struct smbd_connection *info,
2910 				 struct iov_iter *iter,
2911 				 bool writing, bool need_invalidate)
2912 {
2913 	struct smbdirect_socket *sc = &info->socket;
2914 	struct smbdirect_socket_parameters *sp = &sc->parameters;
2915 	struct smbdirect_mr_io *mr;
2916 	int rc, num_pages;
2917 	struct ib_reg_wr *reg_wr;
2918 
2919 	num_pages = iov_iter_npages(iter, sp->max_frmr_depth + 1);
2920 	if (num_pages > sp->max_frmr_depth) {
2921 		log_rdma_mr(ERR, "num_pages=%d max_frmr_depth=%d\n",
2922 			num_pages, sp->max_frmr_depth);
2923 		WARN_ON_ONCE(1);
2924 		return NULL;
2925 	}
2926 
2927 	mr = get_mr(sc);
2928 	if (!mr) {
2929 		log_rdma_mr(ERR, "get_mr returning NULL\n");
2930 		return NULL;
2931 	}
2932 
2933 	mutex_lock(&mr->mutex);
2934 
2935 	mr->dir = writing ? DMA_FROM_DEVICE : DMA_TO_DEVICE;
2936 	mr->need_invalidate = need_invalidate;
2937 	mr->sgt.nents = 0;
2938 	mr->sgt.orig_nents = 0;
2939 
2940 	log_rdma_mr(INFO, "num_pages=0x%x count=0x%zx depth=%u\n",
2941 		    num_pages, iov_iter_count(iter), sp->max_frmr_depth);
2942 	smbd_iter_to_mr(iter, &mr->sgt, sp->max_frmr_depth);
2943 
2944 	rc = ib_dma_map_sg(sc->ib.dev, mr->sgt.sgl, mr->sgt.nents, mr->dir);
2945 	if (!rc) {
2946 		log_rdma_mr(ERR, "ib_dma_map_sg num_pages=%x dir=%x rc=%x\n",
2947 			    num_pages, mr->dir, rc);
2948 		goto dma_map_error;
2949 	}
2950 
2951 	rc = ib_map_mr_sg(mr->mr, mr->sgt.sgl, mr->sgt.nents, NULL, PAGE_SIZE);
2952 	if (rc != mr->sgt.nents) {
2953 		log_rdma_mr(ERR,
2954 			    "ib_map_mr_sg failed rc = %d nents = %x\n",
2955 			    rc, mr->sgt.nents);
2956 		goto map_mr_error;
2957 	}
2958 
2959 	ib_update_fast_reg_key(mr->mr, ib_inc_rkey(mr->mr->rkey));
2960 	reg_wr = &mr->wr;
2961 	reg_wr->wr.opcode = IB_WR_REG_MR;
2962 	mr->cqe.done = register_mr_done;
2963 	reg_wr->wr.wr_cqe = &mr->cqe;
2964 	reg_wr->wr.num_sge = 0;
2965 	reg_wr->wr.send_flags = IB_SEND_SIGNALED;
2966 	reg_wr->mr = mr->mr;
2967 	reg_wr->key = mr->mr->rkey;
2968 	reg_wr->access = writing ?
2969 			IB_ACCESS_REMOTE_WRITE | IB_ACCESS_LOCAL_WRITE :
2970 			IB_ACCESS_REMOTE_READ;
2971 
2972 	/*
	 * There is no need to wait for completion of the ib_post_send
	 * for IB_WR_REG_MR. Hardware enforces a barrier and ordering of execution
	 * on the next ib_post_send when we actually send I/O to the remote peer.
2976 	 */
2977 	rc = ib_post_send(sc->ib.qp, &reg_wr->wr, NULL);
2978 	if (!rc) {
2979 		/*
2980 		 * get_mr() gave us a reference
2981 		 * via kref_get(&mr->kref), we keep that and let
2982 		 * the caller use smbd_deregister_mr()
2983 		 * to remove it again.
2984 		 */
2985 		mutex_unlock(&mr->mutex);
2986 		return mr;
2987 	}
2988 
2989 	log_rdma_mr(ERR, "ib_post_send failed rc=%x reg_wr->key=%x\n",
2990 		rc, reg_wr->key);
2991 
	/* If all failed, attempt to recover this MR by setting it to SMBDIRECT_MR_ERROR */
2993 map_mr_error:
2994 	ib_dma_unmap_sg(sc->ib.dev, mr->sgt.sgl, mr->sgt.nents, mr->dir);
2995 
2996 dma_map_error:
2997 	mr->sgt.nents = 0;
2998 	mr->state = SMBDIRECT_MR_ERROR;
2999 	if (atomic_dec_and_test(&sc->mr_io.used.count))
3000 		wake_up(&sc->mr_io.cleanup.wait_queue);
3001 
3002 	smbd_disconnect_rdma_connection(sc);
3003 
3004 	/*
3005 	 * get_mr() gave us a reference
3006 	 * via kref_get(&mr->kref), we need to remove it again
3007 	 * on error.
3008 	 *
3009 	 * No kref_put_mutex() as it's already locked.
3010 	 *
3011 	 * If smbd_mr_free_locked() is called
3012 	 * and the mutex is unlocked and mr is gone,
3013 	 * in that case kref_put() returned 1.
3014 	 *
3015 	 * If kref_put() returned 0 we know that
3016 	 * smbd_mr_free_locked() didn't
3017 	 * run. Not by us nor by anyone else, as we
3018 	 * still hold the mutex, so we need to unlock.
3019 	 */
3020 	if (!kref_put(&mr->kref, smbd_mr_free_locked))
3021 		mutex_unlock(&mr->mutex);
3022 
3023 	return NULL;
3024 }
3025 
3026 static void local_inv_done(struct ib_cq *cq, struct ib_wc *wc)
3027 {
3028 	struct smbdirect_mr_io *smbdirect_mr;
3029 	struct ib_cqe *cqe;
3030 
3031 	cqe = wc->wr_cqe;
3032 	smbdirect_mr = container_of(cqe, struct smbdirect_mr_io, cqe);
3033 	smbdirect_mr->state = SMBDIRECT_MR_INVALIDATED;
3034 	if (wc->status != IB_WC_SUCCESS) {
3035 		log_rdma_mr(ERR, "invalidate failed status=%x\n", wc->status);
3036 		smbdirect_mr->state = SMBDIRECT_MR_ERROR;
3037 	}
3038 	complete(&smbdirect_mr->invalidate_done);
3039 }
3040 
3041 /*
 * Deregister an MR after I/O is done
 * This function may wait if remote invalidation is not used
 * and we have to locally invalidate the buffer to prevent the data from being
 * modified by the remote peer after the upper layer consumes it.
3046  */
3047 void smbd_deregister_mr(struct smbdirect_mr_io *mr)
3048 {
3049 	struct smbdirect_socket *sc = mr->socket;
3050 
3051 	mutex_lock(&mr->mutex);
3052 	if (mr->state == SMBDIRECT_MR_DISABLED)
3053 		goto put_kref;
3054 
3055 	if (sc->status != SMBDIRECT_SOCKET_CONNECTED) {
3056 		smbd_mr_disable_locked(mr);
3057 		goto put_kref;
3058 	}
3059 
3060 	if (mr->need_invalidate) {
3061 		struct ib_send_wr *wr = &mr->inv_wr;
3062 		int rc;
3063 
3064 		/* Need to finish local invalidation before returning */
3065 		wr->opcode = IB_WR_LOCAL_INV;
3066 		mr->cqe.done = local_inv_done;
3067 		wr->wr_cqe = &mr->cqe;
3068 		wr->num_sge = 0;
3069 		wr->ex.invalidate_rkey = mr->mr->rkey;
3070 		wr->send_flags = IB_SEND_SIGNALED;
3071 
3072 		init_completion(&mr->invalidate_done);
3073 		rc = ib_post_send(sc->ib.qp, wr, NULL);
3074 		if (rc) {
3075 			log_rdma_mr(ERR, "ib_post_send failed rc=%x\n", rc);
3076 			smbd_mr_disable_locked(mr);
3077 			smbd_disconnect_rdma_connection(sc);
3078 			goto done;
3079 		}
3080 		wait_for_completion(&mr->invalidate_done);
3081 		mr->need_invalidate = false;
3082 	} else
3083 		/*
3084 		 * For remote invalidation, just set it to SMBDIRECT_MR_INVALIDATED
3085 		 * and defer to mr_recovery_work to recover the MR for next use
3086 		 */
3087 		mr->state = SMBDIRECT_MR_INVALIDATED;
3088 
3089 	if (mr->sgt.nents) {
3090 		ib_dma_unmap_sg(sc->ib.dev, mr->sgt.sgl, mr->sgt.nents, mr->dir);
3091 		mr->sgt.nents = 0;
3092 	}
3093 
3094 	if (mr->state == SMBDIRECT_MR_INVALIDATED) {
3095 		mr->state = SMBDIRECT_MR_READY;
3096 		if (atomic_inc_return(&sc->mr_io.ready.count) == 1)
3097 			wake_up(&sc->mr_io.ready.wait_queue);
3098 	} else
3099 		/*
		 * Schedule the work to do MR recovery for future I/Os.
		 * MR recovery is slow and we don't want it to block the current I/O.
3102 		 */
3103 		queue_work(sc->workqueue, &sc->mr_io.recovery_work);
3104 
3105 done:
3106 	if (atomic_dec_and_test(&sc->mr_io.used.count))
3107 		wake_up(&sc->mr_io.cleanup.wait_queue);
3108 
3109 put_kref:
3110 	/*
3111 	 * No kref_put_mutex() as it's already locked.
3112 	 *
3113 	 * If smbd_mr_free_locked() is called
3114 	 * and the mutex is unlocked and mr is gone,
3115 	 * in that case kref_put() returned 1.
3116 	 *
3117 	 * If kref_put() returned 0 we know that
3118 	 * smbd_mr_free_locked() didn't
3119 	 * run. Not by us nor by anyone else, as we
3120 	 * still hold the mutex, so we need to unlock
3121 	 * and keep the mr in SMBDIRECT_MR_READY or
3122 	 * SMBDIRECT_MR_ERROR state.
3123 	 */
3124 	if (!kref_put(&mr->kref, smbd_mr_free_locked))
3125 		mutex_unlock(&mr->mutex);
3126 }
3127 
3128 static bool smb_set_sge(struct smb_extract_to_rdma *rdma,
3129 			struct page *lowest_page, size_t off, size_t len)
3130 {
3131 	struct ib_sge *sge = &rdma->sge[rdma->nr_sge];
3132 	u64 addr;
3133 
3134 	addr = ib_dma_map_page(rdma->device, lowest_page,
3135 			       off, len, rdma->direction);
3136 	if (ib_dma_mapping_error(rdma->device, addr))
3137 		return false;
3138 
3139 	sge->addr   = addr;
3140 	sge->length = len;
3141 	sge->lkey   = rdma->local_dma_lkey;
3142 	rdma->nr_sge++;
3143 	return true;
3144 }
3145 
3146 /*
3147  * Extract page fragments from a BVEC-class iterator and add them to an RDMA
3148  * element list.  The pages are not pinned.
3149  */
3150 static ssize_t smb_extract_bvec_to_rdma(struct iov_iter *iter,
3151 					struct smb_extract_to_rdma *rdma,
3152 					ssize_t maxsize)
3153 {
3154 	const struct bio_vec *bv = iter->bvec;
3155 	unsigned long start = iter->iov_offset;
3156 	unsigned int i;
3157 	ssize_t ret = 0;
3158 
3159 	for (i = 0; i < iter->nr_segs; i++) {
3160 		size_t off, len;
3161 
3162 		len = bv[i].bv_len;
3163 		if (start >= len) {
3164 			start -= len;
3165 			continue;
3166 		}
3167 
3168 		len = min_t(size_t, maxsize, len - start);
3169 		off = bv[i].bv_offset + start;
3170 
3171 		if (!smb_set_sge(rdma, bv[i].bv_page, off, len))
3172 			return -EIO;
3173 
3174 		ret += len;
3175 		maxsize -= len;
3176 		if (rdma->nr_sge >= rdma->max_sge || maxsize <= 0)
3177 			break;
3178 		start = 0;
3179 	}
3180 
3181 	if (ret > 0)
3182 		iov_iter_advance(iter, ret);
3183 	return ret;
3184 }
3185 
3186 /*
3187  * Extract fragments from a KVEC-class iterator and add them to an RDMA list.
3188  * This can deal with vmalloc'd buffers as well as kmalloc'd or static buffers.
3189  * The pages are not pinned.
3190  */
3191 static ssize_t smb_extract_kvec_to_rdma(struct iov_iter *iter,
3192 					struct smb_extract_to_rdma *rdma,
3193 					ssize_t maxsize)
3194 {
3195 	const struct kvec *kv = iter->kvec;
3196 	unsigned long start = iter->iov_offset;
3197 	unsigned int i;
3198 	ssize_t ret = 0;
3199 
3200 	for (i = 0; i < iter->nr_segs; i++) {
3201 		struct page *page;
3202 		unsigned long kaddr;
3203 		size_t off, len, seg;
3204 
3205 		len = kv[i].iov_len;
3206 		if (start >= len) {
3207 			start -= len;
3208 			continue;
3209 		}
3210 
3211 		kaddr = (unsigned long)kv[i].iov_base + start;
3212 		off = kaddr & ~PAGE_MASK;
3213 		len = min_t(size_t, maxsize, len - start);
3214 		kaddr &= PAGE_MASK;
3215 
3216 		maxsize -= len;
3217 		do {
3218 			seg = min_t(size_t, len, PAGE_SIZE - off);
3219 
3220 			if (is_vmalloc_or_module_addr((void *)kaddr))
3221 				page = vmalloc_to_page((void *)kaddr);
3222 			else
3223 				page = virt_to_page((void *)kaddr);
3224 
3225 			if (!smb_set_sge(rdma, page, off, seg))
3226 				return -EIO;
3227 
3228 			ret += seg;
3229 			len -= seg;
3230 			kaddr += PAGE_SIZE;
3231 			off = 0;
3232 		} while (len > 0 && rdma->nr_sge < rdma->max_sge);
3233 
3234 		if (rdma->nr_sge >= rdma->max_sge || maxsize <= 0)
3235 			break;
3236 		start = 0;
3237 	}
3238 
3239 	if (ret > 0)
3240 		iov_iter_advance(iter, ret);
3241 	return ret;
3242 }
3243 
3244 /*
3245  * Extract folio fragments from a FOLIOQ-class iterator and add them to an RDMA
3246  * list.  The folios are not pinned.
3247  */
3248 static ssize_t smb_extract_folioq_to_rdma(struct iov_iter *iter,
3249 					  struct smb_extract_to_rdma *rdma,
3250 					  ssize_t maxsize)
3251 {
3252 	const struct folio_queue *folioq = iter->folioq;
3253 	unsigned int slot = iter->folioq_slot;
3254 	ssize_t ret = 0;
3255 	size_t offset = iter->iov_offset;
3256 
3257 	BUG_ON(!folioq);
3258 
3259 	if (slot >= folioq_nr_slots(folioq)) {
3260 		folioq = folioq->next;
3261 		if (WARN_ON_ONCE(!folioq))
3262 			return -EIO;
3263 		slot = 0;
3264 	}
3265 
3266 	do {
3267 		struct folio *folio = folioq_folio(folioq, slot);
3268 		size_t fsize = folioq_folio_size(folioq, slot);
3269 
3270 		if (offset < fsize) {
3271 			size_t part = umin(maxsize, fsize - offset);
3272 
3273 			if (!smb_set_sge(rdma, folio_page(folio, 0), offset, part))
3274 				return -EIO;
3275 
3276 			offset += part;
3277 			ret += part;
3278 			maxsize -= part;
3279 		}
3280 
3281 		if (offset >= fsize) {
3282 			offset = 0;
3283 			slot++;
3284 			if (slot >= folioq_nr_slots(folioq)) {
3285 				if (!folioq->next) {
3286 					WARN_ON_ONCE(ret < iter->count);
3287 					break;
3288 				}
3289 				folioq = folioq->next;
3290 				slot = 0;
3291 			}
3292 		}
3293 	} while (rdma->nr_sge < rdma->max_sge && maxsize > 0);
3294 
3295 	iter->folioq = folioq;
3296 	iter->folioq_slot = slot;
3297 	iter->iov_offset = offset;
3298 	iter->count -= ret;
3299 	return ret;
3300 }
3301 
3302 /*
3303  * Extract page fragments from up to the given amount of the source iterator
3304  * and build up an RDMA list that refers to all of those bits.  The RDMA list
3305  * is appended to, up to the maximum number of elements set in the parameter
3306  * block.
3307  *
3308  * The extracted page fragments are not pinned or ref'd in any way; if an
3309  * IOVEC/UBUF-type iterator is to be used, it should be converted to a
3310  * BVEC-type iterator and the pages pinned, ref'd or otherwise held in some
3311  * way.
3312  */
3313 static ssize_t smb_extract_iter_to_rdma(struct iov_iter *iter, size_t len,
3314 					struct smb_extract_to_rdma *rdma)
3315 {
3316 	ssize_t ret;
3317 	int before = rdma->nr_sge;
3318 
3319 	switch (iov_iter_type(iter)) {
3320 	case ITER_BVEC:
3321 		ret = smb_extract_bvec_to_rdma(iter, rdma, len);
3322 		break;
3323 	case ITER_KVEC:
3324 		ret = smb_extract_kvec_to_rdma(iter, rdma, len);
3325 		break;
3326 	case ITER_FOLIOQ:
3327 		ret = smb_extract_folioq_to_rdma(iter, rdma, len);
3328 		break;
3329 	default:
3330 		WARN_ON_ONCE(1);
3331 		return -EIO;
3332 	}
3333 
3334 	if (ret < 0) {
		/* Unmap only the SGEs that were added by this call */
		while (rdma->nr_sge > before) {
			struct ib_sge *sge = &rdma->sge[--rdma->nr_sge];
3337 
3338 			ib_dma_unmap_single(rdma->device, sge->addr, sge->length,
3339 					    rdma->direction);
3340 			sge->addr = 0;
3341 		}
3342 	}
3343 
3344 	return ret;
3345 }
3346