xref: /linux/fs/smb/smbdirect/socket.c (revision 0fc8f6200d2313278fbf4539bbab74677c685531)
1 // SPDX-License-Identifier: GPL-2.0-or-later
2 /*
3  *   Copyright (C) 2017, Microsoft Corporation.
4  *   Copyright (c) 2025, Stefan Metzmacher
5  */
6 
7 #include "internal.h"
8 
9 bool smbdirect_frwr_is_supported(const struct ib_device_attr *attrs)
10 {
11 	/*
12 	 * Test if FRWR (Fast Registration Work Requests) is supported on the
13 	 * device This implementation requires FRWR on RDMA read/write return
14 	 * value: true if it is supported
15 	 */
16 
17 	if (!(attrs->device_cap_flags & IB_DEVICE_MEM_MGT_EXTENSIONS))
18 		return false;
19 	if (attrs->max_fast_reg_page_list_len == 0)
20 		return false;
21 	return true;
22 }
23 __SMBDIRECT_EXPORT_SYMBOL__(smbdirect_frwr_is_supported);
24 
25 static void smbdirect_socket_cleanup_work(struct work_struct *work);
26 
static int smbdirect_socket_rdma_event_handler(struct rdma_cm_id *id,
					       struct rdma_cm_event *event)
{
	struct smbdirect_socket *sc = id->context;
	int ret = -ESTALE;

	/*
	 * Placeholder RDMA CM event handler installed at socket creation.
	 * This should be replaced before any real work
	 * starts! So it should never be called!
	 */

	/* Translate the event into the error reported in the log below */
	if (event->event == RDMA_CM_EVENT_DEVICE_REMOVAL)
		ret = -ENETDOWN;
	if (IS_ERR(SMBDIRECT_DEBUG_ERR_PTR(event->status)))
		ret = event->status;
	pr_err("%s (first_error=%1pe, expected=%s) => event=%s status=%d => ret=%1pe\n",
		smbdirect_socket_status_string(sc->status),
		SMBDIRECT_DEBUG_ERR_PTR(sc->first_error),
		rdma_event_msg(sc->rdma.expected_event),
		rdma_event_msg(event->event),
		event->status,
		SMBDIRECT_DEBUG_ERR_PTR(ret));
	WARN_ONCE(1, "%s should not be called!\n", __func__);
	/*
	 * NOTE(review): 'ret' is only used for the log message above and
	 * -ESTALE is always returned — confirm that is intentional.
	 * Presumably the non-zero return lets the RDMA CM core destroy
	 * the id, which is why our reference is dropped here.
	 */
	sc->rdma.cm_id = NULL;
	return -ESTALE;
}
53 
54 int smbdirect_socket_init_new(struct net *net, struct smbdirect_socket *sc)
55 {
56 	struct rdma_cm_id *id;
57 	int ret;
58 
59 	smbdirect_socket_init(sc);
60 
61 	id = rdma_create_id(net,
62 			    smbdirect_socket_rdma_event_handler,
63 			    sc,
64 			    RDMA_PS_TCP,
65 			    IB_QPT_RC);
66 	if (IS_ERR(id)) {
67 		pr_err("%s: rdma_create_id() failed %1pe\n", __func__, id);
68 		return PTR_ERR(id);
69 	}
70 
71 	ret = rdma_set_afonly(id, 1);
72 	if (ret) {
73 		rdma_destroy_id(id);
74 		pr_err("%s: rdma_set_afonly() failed %1pe\n",
75 		       __func__, SMBDIRECT_DEBUG_ERR_PTR(ret));
76 		return ret;
77 	}
78 
79 	sc->rdma.cm_id = id;
80 
81 	INIT_WORK(&sc->disconnect_work, smbdirect_socket_cleanup_work);
82 
83 	return 0;
84 }
85 
86 int smbdirect_socket_create_kern(struct net *net, struct smbdirect_socket **_sc)
87 {
88 	struct smbdirect_socket *sc;
89 	int ret;
90 
91 	ret = -ENOMEM;
92 	sc = kzalloc_obj(*sc);
93 	if (!sc)
94 		goto alloc_failed;
95 
96 	ret = smbdirect_socket_init_new(net, sc);
97 	if (ret)
98 		goto init_failed;
99 
100 	kref_init(&sc->refs.destroy);
101 
102 	*_sc = sc;
103 	return 0;
104 
105 init_failed:
106 	kfree(sc);
107 alloc_failed:
108 	return ret;
109 }
110 __SMBDIRECT_EXPORT_SYMBOL__(smbdirect_socket_create_kern);
111 
112 int smbdirect_socket_init_accepting(struct rdma_cm_id *id, struct smbdirect_socket *sc)
113 {
114 	smbdirect_socket_init(sc);
115 
116 	sc->rdma.cm_id = id;
117 	sc->rdma.cm_id->context = sc;
118 	sc->rdma.cm_id->event_handler = smbdirect_socket_rdma_event_handler;
119 
120 	sc->ib.dev = sc->rdma.cm_id->device;
121 
122 	INIT_WORK(&sc->disconnect_work, smbdirect_socket_cleanup_work);
123 
124 	return 0;
125 }
126 
127 int smbdirect_socket_create_accepting(struct rdma_cm_id *id, struct smbdirect_socket **_sc)
128 {
129 	struct smbdirect_socket *sc;
130 	int ret;
131 
132 	ret = -ENOMEM;
133 	sc = kzalloc_obj(*sc);
134 	if (!sc)
135 		goto alloc_failed;
136 
137 	ret = smbdirect_socket_init_accepting(id, sc);
138 	if (ret)
139 		goto init_failed;
140 
141 	kref_init(&sc->refs.destroy);
142 
143 	*_sc = sc;
144 	return 0;
145 
146 init_failed:
147 	kfree(sc);
148 alloc_failed:
149 	return ret;
150 }
151 __SMBDIRECT_EXPORT_SYMBOL__(smbdirect_socket_create_accepting);
152 
153 int smbdirect_socket_set_initial_parameters(struct smbdirect_socket *sc,
154 					    const struct smbdirect_socket_parameters *sp)
155 {
156 	/*
157 	 * This is only allowed before connect or accept
158 	 */
159 	WARN_ONCE(sc->status != SMBDIRECT_SOCKET_CREATED,
160 		  "status=%s first_error=%1pe",
161 		  smbdirect_socket_status_string(sc->status),
162 		  SMBDIRECT_DEBUG_ERR_PTR(sc->first_error));
163 	if (sc->status != SMBDIRECT_SOCKET_CREATED)
164 		return -EINVAL;
165 
166 	if (sp->flags & ~SMBDIRECT_FLAG_PORT_RANGE_MASK)
167 		return -EINVAL;
168 
169 	if (sp->initiator_depth > U8_MAX)
170 		return -EINVAL;
171 	if (sp->responder_resources > U8_MAX)
172 		return -EINVAL;
173 
174 	if (sp->flags & SMBDIRECT_FLAG_PORT_RANGE_ONLY_IB &&
175 	    sp->flags & SMBDIRECT_FLAG_PORT_RANGE_ONLY_IW)
176 		return -EINVAL;
177 	else if (sp->flags & SMBDIRECT_FLAG_PORT_RANGE_ONLY_IB)
178 		rdma_restrict_node_type(sc->rdma.cm_id, RDMA_NODE_IB_CA);
179 	else if (sp->flags & SMBDIRECT_FLAG_PORT_RANGE_ONLY_IW)
180 		rdma_restrict_node_type(sc->rdma.cm_id, RDMA_NODE_RNIC);
181 
182 	/*
183 	 * Make a copy of the callers parameters
184 	 * from here we only work on the copy
185 	 *
186 	 * TODO: do we want consistency checking?
187 	 */
188 	sc->parameters = *sp;
189 
190 	return 0;
191 }
192 __SMBDIRECT_EXPORT_SYMBOL__(smbdirect_socket_set_initial_parameters);
193 
/*
 * Return the socket's private parameter copy, as stored by
 * smbdirect_socket_set_initial_parameters().
 */
const struct smbdirect_socket_parameters *
smbdirect_socket_get_current_parameters(struct smbdirect_socket *sc)
{
	return &sc->parameters;
}
__SMBDIRECT_EXPORT_SYMBOL__(smbdirect_socket_get_current_parameters);
200 
201 int smbdirect_socket_set_kernel_settings(struct smbdirect_socket *sc,
202 					 enum ib_poll_context poll_ctx,
203 					 gfp_t gfp_mask)
204 {
205 	/*
206 	 * This is only allowed before connect or accept
207 	 */
208 	WARN_ONCE(sc->status != SMBDIRECT_SOCKET_CREATED,
209 		  "status=%s first_error=%1pe",
210 		  smbdirect_socket_status_string(sc->status),
211 		  SMBDIRECT_DEBUG_ERR_PTR(sc->first_error));
212 	if (sc->status != SMBDIRECT_SOCKET_CREATED)
213 		return -EINVAL;
214 
215 	sc->ib.poll_ctx = poll_ctx;
216 
217 	sc->send_io.mem.gfp_mask = gfp_mask;
218 	sc->recv_io.mem.gfp_mask = gfp_mask;
219 	sc->rw_io.mem.gfp_mask = gfp_mask;
220 
221 	return 0;
222 }
223 __SMBDIRECT_EXPORT_SYMBOL__(smbdirect_socket_set_kernel_settings);
224 
/*
 * Install the caller's logging callbacks.
 *
 * @private_ptr: opaque pointer passed back to both callbacks
 * @needed:      decides whether a message of the given level/class
 *               should be emitted at all
 * @vaprintf:    emits the formatted message
 */
void smbdirect_socket_set_logging(struct smbdirect_socket *sc,
				  void *private_ptr,
				  bool (*needed)(struct smbdirect_socket *sc,
						 void *private_ptr,
						 unsigned int lvl,
						 unsigned int cls),
				  void (*vaprintf)(struct smbdirect_socket *sc,
						   const char *func,
						   unsigned int line,
						   void *private_ptr,
						   unsigned int lvl,
						   unsigned int cls,
						   struct va_format *vaf))
{
	sc->logging.private_ptr = private_ptr;
	sc->logging.needed = needed;
	sc->logging.vaprintf = vaprintf;
}
__SMBDIRECT_EXPORT_SYMBOL__(smbdirect_socket_set_logging);
244 
static void smbdirect_socket_wake_up_all(struct smbdirect_socket *sc)
{
	/*
	 * Wake up all waiters in all wait queues
	 * in order to notice the broken connection.
	 *
	 * This covers status changes, listener accepts, every credit
	 * class, pending sends, the reassembly queue and MR readiness.
	 */
	wake_up_all(&sc->status_wait);
	wake_up_all(&sc->listen.wait_queue);
	wake_up_all(&sc->send_io.bcredits.wait_queue);
	wake_up_all(&sc->send_io.lcredits.wait_queue);
	wake_up_all(&sc->send_io.credits.wait_queue);
	wake_up_all(&sc->send_io.pending.zero_wait_queue);
	wake_up_all(&sc->recv_io.reassembly.wait_queue);
	wake_up_all(&sc->rw_io.credits.wait_queue);
	wake_up_all(&sc->mr_io.ready.wait_queue);
}
261 
/*
 * Record the first error, move the socket into the matching failure
 * state and queue smbdirect_socket_cleanup_work() for the actual
 * teardown.  Safe to call from contexts that must not block.
 *
 * Normally invoked via the smbdirect_socket_schedule_cleanup*()
 * macros, which supply @macro_name, @func and @line for logging.
 *
 * @error:        error to record as sc->first_error (0 is mapped to
 *                -ECONNABORTED)
 * @force_status: optional status to force instead of the derived one,
 *                may be NULL
 */
void __smbdirect_socket_schedule_cleanup(struct smbdirect_socket *sc,
					 const char *macro_name,
					 unsigned int lvl,
					 const char *func,
					 unsigned int line,
					 int error,
					 enum smbdirect_socket_status *force_status)
{
	struct smbdirect_socket *psc, *tsc;
	unsigned long flags;
	bool was_first = false;

	/* Only the very first error is recorded and logged */
	if (!sc->first_error) {
		___smbdirect_log_generic(sc, func, line,
			lvl,
			SMBDIRECT_LOG_RDMA_EVENT,
			"%s(%1pe%s%s) called from %s in line=%u status=%s\n",
			macro_name,
			SMBDIRECT_DEBUG_ERR_PTR(error),
			force_status ? ", " : "",
			force_status ? smbdirect_socket_status_string(*force_status) : "",
			func, line,
			smbdirect_socket_status_string(sc->status));
		if (error)
			sc->first_error = error;
		else
			sc->first_error = -ECONNABORTED;
		was_first = true;
	}

	/*
	 * make sure other work (than disconnect_work)
	 * is not queued again but here we don't block and avoid
	 * disable[_delayed]_work_sync()
	 */
	disable_work(&sc->connect.work);
	disable_work(&sc->recv_io.posted.refill_work);
	disable_work(&sc->idle.immediate_work);
	sc->idle.keepalive = SMBDIRECT_KEEPALIVE_NONE;
	disable_delayed_work(&sc->idle.timer_work);

	/*
	 * In case we were a listener we need to
	 * disconnect all pending and ready sockets
	 *
	 * First we move ready sockets to pending again.
	 */
	spin_lock_irqsave(&sc->listen.lock, flags);
	list_splice_init(&sc->listen.ready, &sc->listen.pending);
	list_for_each_entry_safe(psc, tsc, &sc->listen.pending, accept.list)
		smbdirect_socket_schedule_cleanup(psc, sc->first_error);
	spin_unlock_irqrestore(&sc->listen.lock, flags);

	/* Map the current state to the matching failure state */
	switch (sc->status) {
	case SMBDIRECT_SOCKET_RESOLVE_ADDR_FAILED:
	case SMBDIRECT_SOCKET_RESOLVE_ROUTE_FAILED:
	case SMBDIRECT_SOCKET_RDMA_CONNECT_FAILED:
	case SMBDIRECT_SOCKET_NEGOTIATE_FAILED:
	case SMBDIRECT_SOCKET_ERROR:
	case SMBDIRECT_SOCKET_DISCONNECTING:
	case SMBDIRECT_SOCKET_DISCONNECTED:
	case SMBDIRECT_SOCKET_DESTROYED:
		/*
		 * Keep the current error status
		 */
		break;

	case SMBDIRECT_SOCKET_RESOLVE_ADDR_NEEDED:
	case SMBDIRECT_SOCKET_RESOLVE_ADDR_RUNNING:
		sc->status = SMBDIRECT_SOCKET_RESOLVE_ADDR_FAILED;
		break;

	case SMBDIRECT_SOCKET_RESOLVE_ROUTE_NEEDED:
	case SMBDIRECT_SOCKET_RESOLVE_ROUTE_RUNNING:
		sc->status = SMBDIRECT_SOCKET_RESOLVE_ROUTE_FAILED;
		break;

	case SMBDIRECT_SOCKET_RDMA_CONNECT_NEEDED:
	case SMBDIRECT_SOCKET_RDMA_CONNECT_RUNNING:
		sc->status = SMBDIRECT_SOCKET_RDMA_CONNECT_FAILED;
		break;

	case SMBDIRECT_SOCKET_NEGOTIATE_NEEDED:
	case SMBDIRECT_SOCKET_NEGOTIATE_RUNNING:
		sc->status = SMBDIRECT_SOCKET_NEGOTIATE_FAILED;
		break;

	case SMBDIRECT_SOCKET_CREATED:
	case SMBDIRECT_SOCKET_LISTENING:
		sc->status = SMBDIRECT_SOCKET_DISCONNECTED;
		break;

	case SMBDIRECT_SOCKET_CONNECTED:
		sc->status = SMBDIRECT_SOCKET_ERROR;
		break;
	}

	/*
	 * A forced status is applied on the first error, or when it
	 * advances the state further than the derived one.
	 */
	if (force_status && (was_first || *force_status > sc->status))
		sc->status = *force_status;

	/*
	 * Wake up all waiters in all wait queues
	 * in order to notice the broken connection.
	 */
	smbdirect_socket_wake_up_all(sc);

	/* The blocking part of the teardown runs from the workqueue */
	queue_work(sc->workqueues.cleanup, &sc->disconnect_work);
}
370 
/*
 * Workqueue handler (sc->disconnect_work) doing the blocking part of
 * the teardown started by __smbdirect_socket_schedule_cleanup():
 * it disables all other work items, propagates the cleanup to any
 * pending/ready child sockets of a listener and, if the connection
 * got far enough, calls rdma_disconnect().
 */
static void smbdirect_socket_cleanup_work(struct work_struct *work)
{
	struct smbdirect_socket *sc =
		container_of(work, struct smbdirect_socket, disconnect_work);
	struct smbdirect_socket *psc, *tsc;
	unsigned long flags;

	/*
	 * This should never be called in an interrupt!
	 */
	WARN_ON_ONCE(in_interrupt());

	/* The scheduler should have recorded an error before queueing us */
	if (!sc->first_error) {
		smbdirect_log_rdma_event(sc, SMBDIRECT_LOG_ERR,
			"%s called with first_error==0\n",
			smbdirect_socket_status_string(sc->status));

		sc->first_error = -ECONNABORTED;
	}

	/*
	 * make sure this and other work is not queued again
	 * but here we don't block and avoid
	 * disable[_delayed]_work_sync()
	 */
	disable_work(&sc->disconnect_work);
	disable_work(&sc->connect.work);
	disable_work(&sc->recv_io.posted.refill_work);
	disable_work(&sc->idle.immediate_work);
	sc->idle.keepalive = SMBDIRECT_KEEPALIVE_NONE;
	disable_delayed_work(&sc->idle.timer_work);

	/*
	 * In case we were a listener we need to
	 * disconnect all pending and ready sockets
	 *
	 * First we move ready sockets to pending again.
	 */
	spin_lock_irqsave(&sc->listen.lock, flags);
	list_splice_init(&sc->listen.ready, &sc->listen.pending);
	list_for_each_entry_safe(psc, tsc, &sc->listen.pending, accept.list)
		smbdirect_socket_schedule_cleanup(psc, sc->first_error);
	spin_unlock_irqrestore(&sc->listen.lock, flags);

	switch (sc->status) {
	case SMBDIRECT_SOCKET_NEGOTIATE_NEEDED:
	case SMBDIRECT_SOCKET_NEGOTIATE_RUNNING:
	case SMBDIRECT_SOCKET_NEGOTIATE_FAILED:
	case SMBDIRECT_SOCKET_CONNECTED:
	case SMBDIRECT_SOCKET_ERROR:
		sc->status = SMBDIRECT_SOCKET_DISCONNECTING;
		/*
		 * Make sure we hold the callback lock
		 * in order to coordinate with the
		 * rdma_event handlers, typically
		 * smbdirect_connection_rdma_event_handler(),
		 * and smbdirect_socket_destroy().
		 *
		 * So that the order of ib_drain_qp()
		 * and rdma_disconnect() is controlled
		 * by the mutex.
		 */
		rdma_lock_handler(sc->rdma.cm_id);
		rdma_disconnect(sc->rdma.cm_id);
		rdma_unlock_handler(sc->rdma.cm_id);
		break;

	case SMBDIRECT_SOCKET_CREATED:
	case SMBDIRECT_SOCKET_LISTENING:
	case SMBDIRECT_SOCKET_RESOLVE_ADDR_NEEDED:
	case SMBDIRECT_SOCKET_RESOLVE_ADDR_RUNNING:
	case SMBDIRECT_SOCKET_RESOLVE_ADDR_FAILED:
	case SMBDIRECT_SOCKET_RESOLVE_ROUTE_NEEDED:
	case SMBDIRECT_SOCKET_RESOLVE_ROUTE_RUNNING:
	case SMBDIRECT_SOCKET_RESOLVE_ROUTE_FAILED:
	case SMBDIRECT_SOCKET_RDMA_CONNECT_NEEDED:
	case SMBDIRECT_SOCKET_RDMA_CONNECT_RUNNING:
	case SMBDIRECT_SOCKET_RDMA_CONNECT_FAILED:
		/*
		 * rdma_{accept,connect}() never reached
		 * RDMA_CM_EVENT_ESTABLISHED
		 */
		sc->status = SMBDIRECT_SOCKET_DISCONNECTED;
		break;

	case SMBDIRECT_SOCKET_DISCONNECTING:
	case SMBDIRECT_SOCKET_DISCONNECTED:
	case SMBDIRECT_SOCKET_DESTROYED:
		break;
	}

	/*
	 * Wake up all waiters in all wait queues
	 * in order to notice the broken connection.
	 */
	smbdirect_socket_wake_up_all(sc);
}
468 
/*
 * Final teardown: release all pending child sockets, drain and destroy
 * the queue pair, the cm_id, the MR list and the memory pools, then
 * mark the socket SMBDIRECT_SOCKET_DESTROYED.
 *
 * Expects the socket to be SMBDIRECT_SOCKET_DISCONNECTED already (a
 * WARN fires otherwise); a no-op if it was destroyed before.
 */
static void smbdirect_socket_destroy(struct smbdirect_socket *sc)
{
	struct smbdirect_socket *psc, *tsc;
	size_t psockets;
	struct smbdirect_recv_io *recv_io;
	struct smbdirect_recv_io *recv_tmp;
	LIST_HEAD(all_list);
	unsigned long flags;

	smbdirect_log_rdma_event(sc, SMBDIRECT_LOG_INFO,
		"status=%s first_error=%1pe",
		smbdirect_socket_status_string(sc->status),
		SMBDIRECT_DEBUG_ERR_PTR(sc->first_error));

	/*
	 * This should never be called in an interrupt!
	 */
	WARN_ON_ONCE(in_interrupt());

	if (sc->status == SMBDIRECT_SOCKET_DESTROYED)
		return;

	WARN_ONCE(sc->status != SMBDIRECT_SOCKET_DISCONNECTED,
		  "status=%s first_error=%1pe",
		  smbdirect_socket_status_string(sc->status),
		  SMBDIRECT_DEBUG_ERR_PTR(sc->first_error));

	/*
	 * The listener should clear this before we reach this
	 */
	WARN_ONCE(sc->accept.listener,
		  "status=%s first_error=%1pe",
		  smbdirect_socket_status_string(sc->status),
		  SMBDIRECT_DEBUG_ERR_PTR(sc->first_error));

	/*
	 * Wake up all waiters in all wait queues
	 * in order to notice the broken connection.
	 *
	 * Most likely this was already called via
	 * smbdirect_socket_cleanup_work(), but call it again...
	 */
	smbdirect_socket_wake_up_all(sc);

	/* From here on we may block, so the _sync variants are fine */
	disable_work_sync(&sc->disconnect_work);
	disable_work_sync(&sc->connect.work);
	disable_work_sync(&sc->recv_io.posted.refill_work);
	disable_work_sync(&sc->idle.immediate_work);
	disable_delayed_work_sync(&sc->idle.timer_work);

	/* Keep CM event handlers out while we drain and destroy */
	if (sc->rdma.cm_id)
		rdma_lock_handler(sc->rdma.cm_id);

	if (sc->ib.qp) {
		smbdirect_log_rdma_event(sc, SMBDIRECT_LOG_INFO,
			"drain qp\n");
		ib_drain_qp(sc->ib.qp);
	}

	/*
	 * In case we were a listener we need to
	 * disconnect all pending and ready sockets
	 *
	 * We move ready sockets to pending again.
	 */
	spin_lock_irqsave(&sc->listen.lock, flags);
	list_splice_tail_init(&sc->listen.ready, &all_list);
	list_splice_tail_init(&sc->listen.pending, &all_list);
	spin_unlock_irqrestore(&sc->listen.lock, flags);
	psockets = list_count_nodes(&all_list);
	if (sc->listen.backlog != -1) /* was a listener */
		smbdirect_log_rdma_event(sc, SMBDIRECT_LOG_INFO,
			"release %zu pending sockets\n", psockets);
	list_for_each_entry_safe(psc, tsc, &all_list, accept.list) {
		list_del_init(&psc->accept.list);
		psc->accept.listener = NULL;
		smbdirect_socket_release(psc);
	}
	if (sc->listen.backlog != -1) /* was a listener */
		smbdirect_log_rdma_event(sc, SMBDIRECT_LOG_INFO,
			"released %zu pending sockets\n", psockets);
	/* Reuse all_list for the reassembly queue below */
	INIT_LIST_HEAD(&all_list);

	/* It's not possible for upper layer to get to reassembly */
	if (sc->listen.backlog == -1) /* was not a listener */
		smbdirect_log_rdma_event(sc, SMBDIRECT_LOG_INFO,
			"drain the reassembly queue\n");
	spin_lock_irqsave(&sc->recv_io.reassembly.lock, flags);
	list_splice_tail_init(&sc->recv_io.reassembly.list, &all_list);
	spin_unlock_irqrestore(&sc->recv_io.reassembly.lock, flags);
	list_for_each_entry_safe(recv_io, recv_tmp, &all_list, list)
		smbdirect_connection_put_recv_io(recv_io);
	sc->recv_io.reassembly.data_length = 0;

	if (sc->listen.backlog == -1) /* was not a listener */
		smbdirect_log_rdma_event(sc, SMBDIRECT_LOG_INFO,
			"freeing mr list\n");
	smbdirect_connection_destroy_mr_list(sc);

	if (sc->listen.backlog == -1) /* was not a listener */
		smbdirect_log_rdma_event(sc, SMBDIRECT_LOG_INFO,
			"destroying qp\n");
	smbdirect_connection_destroy_qp(sc);
	if (sc->rdma.cm_id) {
		rdma_unlock_handler(sc->rdma.cm_id);
		smbdirect_log_rdma_event(sc, SMBDIRECT_LOG_INFO,
			"destroying cm_id\n");
		rdma_destroy_id(sc->rdma.cm_id);
		sc->rdma.cm_id = NULL;
	}

	if (sc->listen.backlog == -1) /* was not a listener */
		smbdirect_log_rdma_event(sc, SMBDIRECT_LOG_INFO,
			"destroying mem pools\n");
	smbdirect_connection_destroy_mem_pools(sc);

	sc->status = SMBDIRECT_SOCKET_DESTROYED;

	smbdirect_log_rdma_event(sc, SMBDIRECT_LOG_INFO,
		"rdma session destroyed\n");
}
590 
/*
 * Synchronously disconnect (if needed) and fully destroy the socket:
 * schedule the cleanup, wait until SMBDIRECT_SOCKET_DISCONNECTED is
 * reached and then call smbdirect_socket_destroy().  Blocks; must not
 * be called from interrupt context.
 */
void smbdirect_socket_destroy_sync(struct smbdirect_socket *sc)
{
	smbdirect_log_rdma_event(sc, SMBDIRECT_LOG_INFO,
		"status=%s first_error=%1pe",
		smbdirect_socket_status_string(sc->status),
		SMBDIRECT_DEBUG_ERR_PTR(sc->first_error));

	/*
	 * This should never be called in an interrupt!
	 */
	WARN_ON_ONCE(in_interrupt());

	/*
	 * First we try to disable the work
	 * without disable_work_sync() in a
	 * non blocking way, if it's already
	 * running it will be handled by
	 * disable_work_sync() below.
	 *
	 * Here we just want to make sure queue_work() in
	 * smbdirect_socket_schedule_cleanup_lvl()
	 * is a no-op.
	 */
	disable_work(&sc->disconnect_work);

	if (!sc->first_error)
		/*
		 * SMBDIRECT_LOG_INFO is enough here
		 * as this is the typical case where
		 * we terminate the connection ourself.
		 */
		smbdirect_socket_schedule_cleanup_lvl(sc,
						      SMBDIRECT_LOG_INFO,
						      -ESHUTDOWN);

	smbdirect_log_rdma_event(sc, SMBDIRECT_LOG_INFO,
		"cancelling and disable disconnect_work\n");
	disable_work_sync(&sc->disconnect_work);

	smbdirect_log_rdma_event(sc, SMBDIRECT_LOG_INFO,
		"destroying rdma session\n");
	/* Run the cleanup inline as the work item is disabled now */
	if (sc->status < SMBDIRECT_SOCKET_DISCONNECTING)
		smbdirect_socket_cleanup_work(&sc->disconnect_work);
	if (sc->status < SMBDIRECT_SOCKET_DISCONNECTED) {
		smbdirect_log_rdma_event(sc, SMBDIRECT_LOG_INFO,
			"wait for transport being disconnected\n");
		wait_event(sc->status_wait, sc->status == SMBDIRECT_SOCKET_DISCONNECTED);
		smbdirect_log_rdma_event(sc, SMBDIRECT_LOG_INFO,
			"waited for transport being disconnected\n");
	}

	/*
	 * Once we reached SMBDIRECT_SOCKET_DISCONNECTED,
	 * we should call smbdirect_socket_destroy()
	 */
	smbdirect_socket_destroy(sc);
	smbdirect_log_rdma_event(sc, SMBDIRECT_LOG_INFO,
		"status=%s first_error=%1pe",
		smbdirect_socket_status_string(sc->status),
		SMBDIRECT_DEBUG_ERR_PTR(sc->first_error));
}
652 
653 int smbdirect_socket_bind(struct smbdirect_socket *sc, struct sockaddr *addr)
654 {
655 	int ret;
656 
657 	if (sc->status != SMBDIRECT_SOCKET_CREATED)
658 		return -EINVAL;
659 
660 	ret = rdma_bind_addr(sc->rdma.cm_id, addr);
661 	if (ret)
662 		return ret;
663 
664 	return 0;
665 }
666 __SMBDIRECT_EXPORT_SYMBOL__(smbdirect_socket_bind);
667 
/*
 * Initiate an orderly shutdown: record -ESHUTDOWN as the first error
 * (if none yet) and queue the asynchronous cleanup.
 */
void smbdirect_socket_shutdown(struct smbdirect_socket *sc)
{
	smbdirect_socket_schedule_cleanup_lvl(sc, SMBDIRECT_LOG_INFO, -ESHUTDOWN);
}
__SMBDIRECT_EXPORT_SYMBOL__(smbdirect_socket_shutdown);
673 
/* kref release callback for refs.disconnect */
static void smbdirect_socket_release_disconnect(struct kref *kref)
{
	struct smbdirect_socket *sc =
		container_of(kref, struct smbdirect_socket, refs.disconnect);

	/*
	 * For now do a sync disconnect/destroy
	 */
	smbdirect_socket_destroy_sync(sc);
}
684 
/* kref release callback for refs.destroy — frees the socket memory */
static void smbdirect_socket_release_destroy(struct kref *kref)
{
	struct smbdirect_socket *sc =
		container_of(kref, struct smbdirect_socket, refs.destroy);

	/*
	 * Do a sync disconnect/destroy...
	 * hopefully a no-op, as it should be already
	 * in DESTROYED state, before we free the memory.
	 */
	smbdirect_socket_destroy_sync(sc);
	kfree(sc);
}
698 
/*
 * Drop the caller's references: the single disconnect reference (which
 * triggers the synchronous disconnect/destroy) and one destroy
 * reference (which frees the memory for standalone allocations).
 */
void smbdirect_socket_release(struct smbdirect_socket *sc)
{
	/*
	 * We expect only 1 disconnect reference
	 * and if it is already 0, it's a use after free!
	 */
	WARN_ON_ONCE(kref_read(&sc->refs.disconnect) != 1);
	WARN_ON(!kref_put(&sc->refs.disconnect, smbdirect_socket_release_disconnect));

	/*
	 * This may not trigger smbdirect_socket_release_destroy(),
	 * if struct smbdirect_socket is embedded in another structure
	 * indicated by REFCOUNT_MAX.
	 */
	kref_put(&sc->refs.destroy, smbdirect_socket_release_destroy);
}
__SMBDIRECT_EXPORT_SYMBOL__(smbdirect_socket_release);
716 
/*
 * Reserve @needed credits from @total_credits, sleeping interruptibly
 * on @waitq until they are available.
 *
 * Returns 0 on success (credits taken), @unexpected_errno if the
 * socket left @expected_status while waiting, a negative errno if the
 * wait was interrupted, or -EINVAL for a negative @needed.
 */
int smbdirect_socket_wait_for_credits(struct smbdirect_socket *sc,
				      enum smbdirect_socket_status expected_status,
				      int unexpected_errno,
				      wait_queue_head_t *waitq,
				      atomic_t *total_credits,
				      int needed)
{
	int ret;

	if (WARN_ON_ONCE(needed < 0))
		return -EINVAL;

	do {
		/*
		 * Optimistically take the credits; a non-negative
		 * result means they were all available.
		 */
		if (atomic_sub_return(needed, total_credits) >= 0)
			return 0;

		/* Not enough: give them back and wait for more */
		atomic_add(needed, total_credits);
		ret = wait_event_interruptible(*waitq,
					       atomic_read(total_credits) >= needed ||
					       sc->status != expected_status);

		/*
		 * A status change (e.g. disconnect) takes
		 * precedence over an interrupted wait.
		 */
		if (sc->status != expected_status)
			return unexpected_errno;
		else if (ret < 0)
			return ret;
	} while (true);
}
744