1 // SPDX-License-Identifier: GPL-2.0-or-later
2 /*
3 * Copyright (C) 2017, Microsoft Corporation.
4 * Copyright (c) 2025, Stefan Metzmacher
5 */
6
7 #include "internal.h"
8
smbdirect_frwr_is_supported(const struct ib_device_attr * attrs)9 bool smbdirect_frwr_is_supported(const struct ib_device_attr *attrs)
10 {
11 /*
12 * Test if FRWR (Fast Registration Work Requests) is supported on the
13 * device This implementation requires FRWR on RDMA read/write return
14 * value: true if it is supported
15 */
16
17 if (!(attrs->device_cap_flags & IB_DEVICE_MEM_MGT_EXTENSIONS))
18 return false;
19 if (attrs->max_fast_reg_page_list_len == 0)
20 return false;
21 return true;
22 }
23 EXPORT_SYMBOL_GPL(smbdirect_frwr_is_supported);
24
25 static void smbdirect_socket_cleanup_work(struct work_struct *work);
26
/*
 * Placeholder RDMA CM event handler installed at socket creation time.
 * It is supposed to be replaced with a real handler before any work
 * starts, so reaching this function is always a bug (hence WARN_ONCE).
 */
static int smbdirect_socket_rdma_event_handler(struct rdma_cm_id *id,
					       struct rdma_cm_event *event)
{
	struct smbdirect_socket *sc = id->context;
	int ret = -ESTALE;

	/*
	 * This should be replaced before any real work
	 * starts! So it should never be called!
	 */

	/* 'ret' is only computed for the diagnostic message below. */
	if (event->event == RDMA_CM_EVENT_DEVICE_REMOVAL)
		ret = -ENETDOWN;
	if (IS_ERR(SMBDIRECT_DEBUG_ERR_PTR(event->status)))
		ret = event->status;
	pr_err("%s (first_error=%1pe, expected=%s) => event=%s status=%d => ret=%1pe\n",
	       smbdirect_socket_status_string(sc->status),
	       SMBDIRECT_DEBUG_ERR_PTR(sc->first_error),
	       rdma_event_msg(sc->rdma.expected_event),
	       rdma_event_msg(event->event),
	       event->status,
	       SMBDIRECT_DEBUG_ERR_PTR(ret));
	WARN_ONCE(1, "%s should not be called!\n", __func__);
	/*
	 * NOTE(review): returning non-zero from a CM event handler lets the
	 * RDMA CM core destroy the id, so forget our reference — confirm
	 * against rdma_cm semantics.
	 */
	sc->rdma.cm_id = NULL;
	return -ESTALE;
}
53
smbdirect_socket_init_new(struct net * net,struct smbdirect_socket * sc)54 int smbdirect_socket_init_new(struct net *net, struct smbdirect_socket *sc)
55 {
56 struct rdma_cm_id *id;
57 int ret;
58
59 smbdirect_socket_init(sc);
60
61 id = rdma_create_id(net,
62 smbdirect_socket_rdma_event_handler,
63 sc,
64 RDMA_PS_TCP,
65 IB_QPT_RC);
66 if (IS_ERR(id)) {
67 pr_err("%s: rdma_create_id() failed %1pe\n", __func__, id);
68 return PTR_ERR(id);
69 }
70
71 ret = rdma_set_afonly(id, 1);
72 if (ret) {
73 rdma_destroy_id(id);
74 pr_err("%s: rdma_set_afonly() failed %1pe\n",
75 __func__, SMBDIRECT_DEBUG_ERR_PTR(ret));
76 return ret;
77 }
78
79 sc->rdma.cm_id = id;
80
81 INIT_WORK(&sc->disconnect_work, smbdirect_socket_cleanup_work);
82
83 return 0;
84 }
85
smbdirect_socket_create_kern(struct net * net,struct smbdirect_socket ** _sc)86 int smbdirect_socket_create_kern(struct net *net, struct smbdirect_socket **_sc)
87 {
88 struct smbdirect_socket *sc;
89 int ret;
90
91 ret = -ENOMEM;
92 sc = kzalloc_obj(*sc);
93 if (!sc)
94 goto alloc_failed;
95
96 ret = smbdirect_socket_init_new(net, sc);
97 if (ret)
98 goto init_failed;
99
100 kref_init(&sc->refs.destroy);
101
102 *_sc = sc;
103 return 0;
104
105 init_failed:
106 kfree(sc);
107 alloc_failed:
108 return ret;
109 }
110 EXPORT_SYMBOL_GPL(smbdirect_socket_create_kern);
111
smbdirect_socket_init_accepting(struct rdma_cm_id * id,struct smbdirect_socket * sc)112 int smbdirect_socket_init_accepting(struct rdma_cm_id *id, struct smbdirect_socket *sc)
113 {
114 smbdirect_socket_init(sc);
115
116 sc->rdma.cm_id = id;
117 sc->rdma.cm_id->context = sc;
118 sc->rdma.cm_id->event_handler = smbdirect_socket_rdma_event_handler;
119
120 sc->ib.dev = sc->rdma.cm_id->device;
121
122 INIT_WORK(&sc->disconnect_work, smbdirect_socket_cleanup_work);
123
124 return 0;
125 }
126
smbdirect_socket_create_accepting(struct rdma_cm_id * id,struct smbdirect_socket ** _sc)127 int smbdirect_socket_create_accepting(struct rdma_cm_id *id, struct smbdirect_socket **_sc)
128 {
129 struct smbdirect_socket *sc;
130 int ret;
131
132 ret = -ENOMEM;
133 sc = kzalloc_obj(*sc);
134 if (!sc)
135 goto alloc_failed;
136
137 ret = smbdirect_socket_init_accepting(id, sc);
138 if (ret)
139 goto init_failed;
140
141 kref_init(&sc->refs.destroy);
142
143 *_sc = sc;
144 return 0;
145
146 init_failed:
147 kfree(sc);
148 alloc_failed:
149 return ret;
150 }
151 EXPORT_SYMBOL_GPL(smbdirect_socket_create_accepting);
152
smbdirect_socket_set_initial_parameters(struct smbdirect_socket * sc,const struct smbdirect_socket_parameters * sp)153 int smbdirect_socket_set_initial_parameters(struct smbdirect_socket *sc,
154 const struct smbdirect_socket_parameters *sp)
155 {
156 /*
157 * This is only allowed before connect or accept
158 */
159 WARN_ONCE(sc->status != SMBDIRECT_SOCKET_CREATED,
160 "status=%s first_error=%1pe",
161 smbdirect_socket_status_string(sc->status),
162 SMBDIRECT_DEBUG_ERR_PTR(sc->first_error));
163 if (sc->status != SMBDIRECT_SOCKET_CREATED)
164 return -EINVAL;
165
166 if (sp->flags & ~SMBDIRECT_FLAG_PORT_RANGE_MASK)
167 return -EINVAL;
168
169 if (sp->initiator_depth > U8_MAX)
170 return -EINVAL;
171 if (sp->responder_resources > U8_MAX)
172 return -EINVAL;
173
174 if (sp->flags & SMBDIRECT_FLAG_PORT_RANGE_ONLY_IB &&
175 sp->flags & SMBDIRECT_FLAG_PORT_RANGE_ONLY_IW)
176 return -EINVAL;
177 else if (sp->flags & SMBDIRECT_FLAG_PORT_RANGE_ONLY_IB)
178 rdma_restrict_node_type(sc->rdma.cm_id, RDMA_NODE_IB_CA);
179 else if (sp->flags & SMBDIRECT_FLAG_PORT_RANGE_ONLY_IW)
180 rdma_restrict_node_type(sc->rdma.cm_id, RDMA_NODE_RNIC);
181
182 /*
183 * Make a copy of the callers parameters
184 * from here we only work on the copy
185 *
186 * TODO: do we want consistency checking?
187 */
188 sc->parameters = *sp;
189
190 return 0;
191 }
192 EXPORT_SYMBOL_GPL(smbdirect_socket_set_initial_parameters);
193
/*
 * Return a pointer to the socket's current parameters.
 * The pointer refers to the socket's own copy and stays valid for the
 * lifetime of the socket.
 */
const struct smbdirect_socket_parameters *
smbdirect_socket_get_current_parameters(struct smbdirect_socket *sc)
{
	return &sc->parameters;
}
EXPORT_SYMBOL_GPL(smbdirect_socket_get_current_parameters);
200
/*
 * Apply kernel-side settings: the CQ polling context and the gfp mask
 * used by the send/recv/rw memory pools.
 *
 * Only allowed before connect or accept (socket still in CREATED state).
 *
 * Return: 0 on success, -EINVAL otherwise.
 */
int smbdirect_socket_set_kernel_settings(struct smbdirect_socket *sc,
					 enum ib_poll_context poll_ctx,
					 gfp_t gfp_mask)
{
	/* Only allowed before connect or accept. */
	if (WARN_ONCE(sc->status != SMBDIRECT_SOCKET_CREATED,
		      "status=%s first_error=%1pe",
		      smbdirect_socket_status_string(sc->status),
		      SMBDIRECT_DEBUG_ERR_PTR(sc->first_error)))
		return -EINVAL;

	sc->ib.poll_ctx = poll_ctx;

	/* All three memory pools share the same allocation mask. */
	sc->rw_io.mem.gfp_mask = gfp_mask;
	sc->recv_io.mem.gfp_mask = gfp_mask;
	sc->send_io.mem.gfp_mask = gfp_mask;

	return 0;
}
EXPORT_SYMBOL_GPL(smbdirect_socket_set_kernel_settings);
224
/*
 * Install the caller's logging hooks.
 *
 * @private_ptr: opaque cookie passed back to both callbacks.
 * @needed:      fast check whether a message of (lvl, cls) would be
 *               emitted at all.
 * @vaprintf:    actually formats/emits the message.
 */
void smbdirect_socket_set_logging(struct smbdirect_socket *sc,
				  void *private_ptr,
				  bool (*needed)(struct smbdirect_socket *sc,
						 void *private_ptr,
						 unsigned int lvl,
						 unsigned int cls),
				  void (*vaprintf)(struct smbdirect_socket *sc,
						   const char *func,
						   unsigned int line,
						   void *private_ptr,
						   unsigned int lvl,
						   unsigned int cls,
						   struct va_format *vaf))
{
	sc->logging.vaprintf = vaprintf;
	sc->logging.needed = needed;
	sc->logging.private_ptr = private_ptr;
}
EXPORT_SYMBOL_GPL(smbdirect_socket_set_logging);
244
/*
 * Wake every waiter on all of the socket's wait queues, so that blocked
 * callers re-check the socket state (typically a broken connection).
 */
static void smbdirect_socket_wake_up_all(struct smbdirect_socket *sc)
{
	/*
	 * Wake up all waiters in all wait queues
	 * in order to notice the broken connection.
	 */
	wake_up_all(&sc->status_wait);
	wake_up_all(&sc->listen.wait_queue);
	wake_up_all(&sc->send_io.bcredits.wait_queue);
	wake_up_all(&sc->send_io.lcredits.wait_queue);
	wake_up_all(&sc->send_io.credits.wait_queue);
	wake_up_all(&sc->send_io.pending.zero_wait_queue);
	wake_up_all(&sc->recv_io.reassembly.wait_queue);
	wake_up_all(&sc->rw_io.credits.wait_queue);
	wake_up_all(&sc->mr_io.ready.wait_queue);
}
261
/*
 * Non-blocking half of the teardown: record the first error, disable
 * further (non-disconnect) work, propagate cleanup to listener children,
 * move the status to the matching failure state, wake all waiters and
 * queue disconnect_work for the blocking part.
 *
 * @macro_name, @func, @line, @lvl: identify the calling site for logging.
 * @error: first error to record; 0 is mapped to -ECONNABORTED.
 * @force_status: optional status override, applied when this call records
 *                the first error or the override moves the status forward.
 */
void __smbdirect_socket_schedule_cleanup(struct smbdirect_socket *sc,
					 const char *macro_name,
					 unsigned int lvl,
					 const char *func,
					 unsigned int line,
					 int error,
					 enum smbdirect_socket_status *force_status)
{
	struct smbdirect_socket *psc, *tsc;
	unsigned long flags;
	bool was_first = false;

	/* Only the first caller records the error and logs the event. */
	if (!sc->first_error) {
		___smbdirect_log_generic(sc, func, line,
					 lvl,
					 SMBDIRECT_LOG_RDMA_EVENT,
					 "%s(%1pe%s%s) called from %s in line=%u status=%s\n",
					 macro_name,
					 SMBDIRECT_DEBUG_ERR_PTR(error),
					 force_status ? ", " : "",
					 force_status ? smbdirect_socket_status_string(*force_status) : "",
					 func, line,
					 smbdirect_socket_status_string(sc->status));
		if (error)
			sc->first_error = error;
		else
			sc->first_error = -ECONNABORTED;
		was_first = true;
	}

	/*
	 * make sure other work (than disconnect_work)
	 * is not queued again but here we don't block and avoid
	 * disable[_delayed]_work_sync()
	 */
	disable_work(&sc->connect.work);
	disable_work(&sc->recv_io.posted.refill_work);
	disable_work(&sc->idle.immediate_work);
	sc->idle.keepalive = SMBDIRECT_KEEPALIVE_NONE;
	disable_delayed_work(&sc->idle.timer_work);

	/*
	 * In case we were a listener we need to
	 * disconnect all pending and ready sockets
	 *
	 * First we move ready sockets to pending again.
	 */
	spin_lock_irqsave(&sc->listen.lock, flags);
	list_splice_init(&sc->listen.ready, &sc->listen.pending);
	list_for_each_entry_safe(psc, tsc, &sc->listen.pending, accept.list)
		smbdirect_socket_schedule_cleanup(psc, sc->first_error);
	spin_unlock_irqrestore(&sc->listen.lock, flags);

	/* Map the current status to the matching failure/terminal state. */
	switch (sc->status) {
	case SMBDIRECT_SOCKET_RESOLVE_ADDR_FAILED:
	case SMBDIRECT_SOCKET_RESOLVE_ROUTE_FAILED:
	case SMBDIRECT_SOCKET_RDMA_CONNECT_FAILED:
	case SMBDIRECT_SOCKET_NEGOTIATE_FAILED:
	case SMBDIRECT_SOCKET_ERROR:
	case SMBDIRECT_SOCKET_DISCONNECTING:
	case SMBDIRECT_SOCKET_DISCONNECTED:
	case SMBDIRECT_SOCKET_DESTROYED:
		/*
		 * Keep the current error status
		 */
		break;

	case SMBDIRECT_SOCKET_RESOLVE_ADDR_NEEDED:
	case SMBDIRECT_SOCKET_RESOLVE_ADDR_RUNNING:
		sc->status = SMBDIRECT_SOCKET_RESOLVE_ADDR_FAILED;
		break;

	case SMBDIRECT_SOCKET_RESOLVE_ROUTE_NEEDED:
	case SMBDIRECT_SOCKET_RESOLVE_ROUTE_RUNNING:
		sc->status = SMBDIRECT_SOCKET_RESOLVE_ROUTE_FAILED;
		break;

	case SMBDIRECT_SOCKET_RDMA_CONNECT_NEEDED:
	case SMBDIRECT_SOCKET_RDMA_CONNECT_RUNNING:
		sc->status = SMBDIRECT_SOCKET_RDMA_CONNECT_FAILED;
		break;

	case SMBDIRECT_SOCKET_NEGOTIATE_NEEDED:
	case SMBDIRECT_SOCKET_NEGOTIATE_RUNNING:
		sc->status = SMBDIRECT_SOCKET_NEGOTIATE_FAILED;
		break;

	case SMBDIRECT_SOCKET_CREATED:
	case SMBDIRECT_SOCKET_LISTENING:
		sc->status = SMBDIRECT_SOCKET_DISCONNECTED;
		break;

	case SMBDIRECT_SOCKET_CONNECTED:
		sc->status = SMBDIRECT_SOCKET_ERROR;
		break;
	}

	if (force_status && (was_first || *force_status > sc->status))
		sc->status = *force_status;

	/*
	 * Wake up all waiters in all wait queues
	 * in order to notice the broken connection.
	 */
	smbdirect_socket_wake_up_all(sc);

	/* The blocking part of the teardown runs from the cleanup workqueue. */
	queue_work(sc->workqueues.cleanup, &sc->disconnect_work);
}
370
/*
 * Workqueue handler (sc->disconnect_work) for the blocking half of the
 * teardown: disable all work items (including itself), propagate cleanup
 * to listener children and, if the connection got far enough, call
 * rdma_disconnect() under the CM handler lock.
 */
static void smbdirect_socket_cleanup_work(struct work_struct *work)
{
	struct smbdirect_socket *sc =
		container_of(work, struct smbdirect_socket, disconnect_work);
	struct smbdirect_socket *psc, *tsc;
	unsigned long flags;

	/*
	 * This should never be called in an interrupt!
	 */
	WARN_ON_ONCE(in_interrupt());

	/* The scheduler normally records the error; recover if it didn't. */
	if (!sc->first_error) {
		smbdirect_log_rdma_event(sc, SMBDIRECT_LOG_ERR,
					 "%s called with first_error==0\n",
					 smbdirect_socket_status_string(sc->status));

		sc->first_error = -ECONNABORTED;
	}

	/*
	 * make sure this and other work is not queued again
	 * but here we don't block and avoid
	 * disable[_delayed]_work_sync()
	 */
	disable_work(&sc->disconnect_work);
	disable_work(&sc->connect.work);
	disable_work(&sc->recv_io.posted.refill_work);
	disable_work(&sc->idle.immediate_work);
	sc->idle.keepalive = SMBDIRECT_KEEPALIVE_NONE;
	disable_delayed_work(&sc->idle.timer_work);

	/*
	 * In case we were a listener we need to
	 * disconnect all pending and ready sockets
	 *
	 * First we move ready sockets to pending again.
	 */
	spin_lock_irqsave(&sc->listen.lock, flags);
	list_splice_init(&sc->listen.ready, &sc->listen.pending);
	list_for_each_entry_safe(psc, tsc, &sc->listen.pending, accept.list)
		smbdirect_socket_schedule_cleanup(psc, sc->first_error);
	spin_unlock_irqrestore(&sc->listen.lock, flags);

	switch (sc->status) {
	case SMBDIRECT_SOCKET_NEGOTIATE_NEEDED:
	case SMBDIRECT_SOCKET_NEGOTIATE_RUNNING:
	case SMBDIRECT_SOCKET_NEGOTIATE_FAILED:
	case SMBDIRECT_SOCKET_CONNECTED:
	case SMBDIRECT_SOCKET_ERROR:
		sc->status = SMBDIRECT_SOCKET_DISCONNECTING;
		/*
		 * Make sure we hold the callback lock
		 * im order to coordinate with the
		 * rdma_event handlers, typically
		 * smbdirect_connection_rdma_event_handler(),
		 * and smbdirect_socket_destroy().
		 *
		 * So that the order of ib_drain_qp()
		 * and rdma_disconnect() is controlled
		 * by the mutex.
		 */
		rdma_lock_handler(sc->rdma.cm_id);
		rdma_disconnect(sc->rdma.cm_id);
		rdma_unlock_handler(sc->rdma.cm_id);
		break;

	case SMBDIRECT_SOCKET_CREATED:
	case SMBDIRECT_SOCKET_LISTENING:
	case SMBDIRECT_SOCKET_RESOLVE_ADDR_NEEDED:
	case SMBDIRECT_SOCKET_RESOLVE_ADDR_RUNNING:
	case SMBDIRECT_SOCKET_RESOLVE_ADDR_FAILED:
	case SMBDIRECT_SOCKET_RESOLVE_ROUTE_NEEDED:
	case SMBDIRECT_SOCKET_RESOLVE_ROUTE_RUNNING:
	case SMBDIRECT_SOCKET_RESOLVE_ROUTE_FAILED:
	case SMBDIRECT_SOCKET_RDMA_CONNECT_NEEDED:
	case SMBDIRECT_SOCKET_RDMA_CONNECT_RUNNING:
	case SMBDIRECT_SOCKET_RDMA_CONNECT_FAILED:
		/*
		 * rdma_{accept,connect}() never reached
		 * RDMA_CM_EVENT_ESTABLISHED
		 */
		sc->status = SMBDIRECT_SOCKET_DISCONNECTED;
		break;

	case SMBDIRECT_SOCKET_DISCONNECTING:
	case SMBDIRECT_SOCKET_DISCONNECTED:
	case SMBDIRECT_SOCKET_DESTROYED:
		break;
	}

	/*
	 * Wake up all waiters in all wait queues
	 * in order to notice the broken connection.
	 */
	smbdirect_socket_wake_up_all(sc);
}
468
/*
 * Final (blocking) teardown of a socket that reached DISCONNECTED:
 * drain the QP, release listener children, drop the reassembly queue,
 * destroy MR list, QP, CM id and memory pools, and mark the socket
 * DESTROYED. Idempotent once the DESTROYED state is reached.
 */
static void smbdirect_socket_destroy(struct smbdirect_socket *sc)
{
	struct smbdirect_socket *psc, *tsc;
	size_t psockets;
	struct smbdirect_recv_io *recv_io;
	struct smbdirect_recv_io *recv_tmp;
	LIST_HEAD(all_list);
	unsigned long flags;

	smbdirect_log_rdma_event(sc, SMBDIRECT_LOG_INFO,
				 "status=%s first_error=%1pe",
				 smbdirect_socket_status_string(sc->status),
				 SMBDIRECT_DEBUG_ERR_PTR(sc->first_error));

	/*
	 * This should never be called in an interrupt!
	 */
	WARN_ON_ONCE(in_interrupt());

	/* Already fully torn down: nothing left to do. */
	if (sc->status == SMBDIRECT_SOCKET_DESTROYED)
		return;

	WARN_ONCE(sc->status != SMBDIRECT_SOCKET_DISCONNECTED,
		  "status=%s first_error=%1pe",
		  smbdirect_socket_status_string(sc->status),
		  SMBDIRECT_DEBUG_ERR_PTR(sc->first_error));

	/*
	 * The listener should clear this before we reach this
	 */
	WARN_ONCE(sc->accept.listener,
		  "status=%s first_error=%1pe",
		  smbdirect_socket_status_string(sc->status),
		  SMBDIRECT_DEBUG_ERR_PTR(sc->first_error));

	/*
	 * Wake up all waiters in all wait queues
	 * in order to notice the broken connection.
	 *
	 * Most likely this was already called via
	 * smbdirect_socket_cleanup_work(), but call it again...
	 */
	smbdirect_socket_wake_up_all(sc);

	/* From here on it is safe to block: disable everything synchronously. */
	disable_work_sync(&sc->disconnect_work);
	disable_work_sync(&sc->connect.work);
	disable_work_sync(&sc->recv_io.posted.refill_work);
	disable_work_sync(&sc->idle.immediate_work);
	disable_delayed_work_sync(&sc->idle.timer_work);

	/* Block CM event handlers while we drain and destroy (see above). */
	if (sc->rdma.cm_id)
		rdma_lock_handler(sc->rdma.cm_id);

	if (sc->ib.qp) {
		smbdirect_log_rdma_event(sc, SMBDIRECT_LOG_INFO,
					 "drain qp\n");
		ib_drain_qp(sc->ib.qp);
	}

	/*
	 * In case we were a listener we need to
	 * disconnect all pending and ready sockets
	 *
	 * We move ready sockets to pending again.
	 */
	spin_lock_irqsave(&sc->listen.lock, flags);
	list_splice_tail_init(&sc->listen.ready, &all_list);
	list_splice_tail_init(&sc->listen.pending, &all_list);
	spin_unlock_irqrestore(&sc->listen.lock, flags);
	psockets = list_count_nodes(&all_list);
	if (sc->listen.backlog != -1) /* was a listener */
		smbdirect_log_rdma_event(sc, SMBDIRECT_LOG_INFO,
					 "release %zu pending sockets\n", psockets);
	list_for_each_entry_safe(psc, tsc, &all_list, accept.list) {
		list_del_init(&psc->accept.list);
		psc->accept.listener = NULL;
		smbdirect_socket_release(psc);
	}
	if (sc->listen.backlog != -1) /* was a listener */
		smbdirect_log_rdma_event(sc, SMBDIRECT_LOG_INFO,
					 "released %zu pending sockets\n", psockets);
	/* Reuse all_list for the reassembly queue below. */
	INIT_LIST_HEAD(&all_list);

	/* It's not possible for upper layer to get to reassembly */
	if (sc->listen.backlog == -1) /* was not a listener */
		smbdirect_log_rdma_event(sc, SMBDIRECT_LOG_INFO,
					 "drain the reassembly queue\n");
	spin_lock_irqsave(&sc->recv_io.reassembly.lock, flags);
	list_splice_tail_init(&sc->recv_io.reassembly.list, &all_list);
	spin_unlock_irqrestore(&sc->recv_io.reassembly.lock, flags);
	list_for_each_entry_safe(recv_io, recv_tmp, &all_list, list)
		smbdirect_connection_put_recv_io(recv_io);
	sc->recv_io.reassembly.data_length = 0;

	if (sc->listen.backlog == -1) /* was not a listener */
		smbdirect_log_rdma_event(sc, SMBDIRECT_LOG_INFO,
					 "freeing mr list\n");
	smbdirect_connection_destroy_mr_list(sc);

	if (sc->listen.backlog == -1) /* was not a listener */
		smbdirect_log_rdma_event(sc, SMBDIRECT_LOG_INFO,
					 "destroying qp\n");
	smbdirect_connection_destroy_qp(sc);
	if (sc->rdma.cm_id) {
		rdma_unlock_handler(sc->rdma.cm_id);
		smbdirect_log_rdma_event(sc, SMBDIRECT_LOG_INFO,
					 "destroying cm_id\n");
		rdma_destroy_id(sc->rdma.cm_id);
		sc->rdma.cm_id = NULL;
	}

	if (sc->listen.backlog == -1) /* was not a listener */
		smbdirect_log_rdma_event(sc, SMBDIRECT_LOG_INFO,
					 "destroying mem pools\n");
	smbdirect_connection_destroy_mem_pools(sc);

	sc->status = SMBDIRECT_SOCKET_DESTROYED;

	smbdirect_log_rdma_event(sc, SMBDIRECT_LOG_INFO,
				 "rdma session destroyed\n");
}
590
/*
 * Synchronously disconnect and destroy a socket: schedule the cleanup
 * (recording -ESHUTDOWN if no error exists yet), run/await the
 * disconnect work, wait until the socket reaches DISCONNECTED and then
 * call smbdirect_socket_destroy().
 */
void smbdirect_socket_destroy_sync(struct smbdirect_socket *sc)
{
	smbdirect_log_rdma_event(sc, SMBDIRECT_LOG_INFO,
				 "status=%s first_error=%1pe",
				 smbdirect_socket_status_string(sc->status),
				 SMBDIRECT_DEBUG_ERR_PTR(sc->first_error));

	/*
	 * This should never be called in an interrupt!
	 */
	WARN_ON_ONCE(in_interrupt());

	/*
	 * First we try to disable the work
	 * without disable_work_sync() in a
	 * non blocking way, if it's already
	 * running it will be handled by
	 * disable_work_sync() below.
	 *
	 * Here we just want to make sure queue_work() in
	 * smbdirect_socket_schedule_cleanup_lvl()
	 * is a no-op.
	 */
	disable_work(&sc->disconnect_work);

	if (!sc->first_error)
		/*
		 * SMBDIRECT_LOG_INFO is enough here
		 * as this is the typical case where
		 * we terminate the connection ourself.
		 */
		smbdirect_socket_schedule_cleanup_lvl(sc,
						      SMBDIRECT_LOG_INFO,
						      -ESHUTDOWN);

	smbdirect_log_rdma_event(sc, SMBDIRECT_LOG_INFO,
				 "cancelling and disable disconnect_work\n");
	disable_work_sync(&sc->disconnect_work);

	smbdirect_log_rdma_event(sc, SMBDIRECT_LOG_INFO,
				 "destroying rdma session\n");
	/* Run the (now disabled) cleanup work directly if it never ran. */
	if (sc->status < SMBDIRECT_SOCKET_DISCONNECTING)
		smbdirect_socket_cleanup_work(&sc->disconnect_work);
	if (sc->status < SMBDIRECT_SOCKET_DISCONNECTED) {
		smbdirect_log_rdma_event(sc, SMBDIRECT_LOG_INFO,
					 "wait for transport being disconnected\n");
		wait_event(sc->status_wait, sc->status == SMBDIRECT_SOCKET_DISCONNECTED);
		smbdirect_log_rdma_event(sc, SMBDIRECT_LOG_INFO,
					 "waited for transport being disconnected\n");
	}

	/*
	 * Once we reached SMBDIRECT_SOCKET_DISCONNECTED,
	 * we should call smbdirect_socket_destroy()
	 */
	smbdirect_socket_destroy(sc);
	smbdirect_log_rdma_event(sc, SMBDIRECT_LOG_INFO,
				 "status=%s first_error=%1pe",
				 smbdirect_socket_status_string(sc->status),
				 SMBDIRECT_DEBUG_ERR_PTR(sc->first_error));
}
652
smbdirect_socket_bind(struct smbdirect_socket * sc,struct sockaddr * addr)653 int smbdirect_socket_bind(struct smbdirect_socket *sc, struct sockaddr *addr)
654 {
655 int ret;
656
657 if (sc->status != SMBDIRECT_SOCKET_CREATED)
658 return -EINVAL;
659
660 ret = rdma_bind_addr(sc->rdma.cm_id, addr);
661 if (ret)
662 return ret;
663
664 return 0;
665 }
666 EXPORT_SYMBOL_GPL(smbdirect_socket_bind);
667
/*
 * Initiate an orderly shutdown of the socket: record -ESHUTDOWN as the
 * first error (logged at INFO level, as this is the expected local
 * termination path) and schedule the cleanup work.
 */
void smbdirect_socket_shutdown(struct smbdirect_socket *sc)
{
	smbdirect_socket_schedule_cleanup_lvl(sc, SMBDIRECT_LOG_INFO, -ESHUTDOWN);
}
EXPORT_SYMBOL_GPL(smbdirect_socket_shutdown);
673
/*
 * kref release callback for refs.disconnect: runs when the last
 * 'disconnect' reference is dropped.
 */
static void smbdirect_socket_release_disconnect(struct kref *kref)
{
	struct smbdirect_socket *sc =
		container_of(kref, struct smbdirect_socket, refs.disconnect);

	/*
	 * For now do a sync disconnect/destroy
	 */
	smbdirect_socket_destroy_sync(sc);
}
684
/*
 * kref release callback for refs.destroy: runs when the last 'destroy'
 * reference is dropped and frees the socket memory.
 */
static void smbdirect_socket_release_destroy(struct kref *kref)
{
	struct smbdirect_socket *sc =
		container_of(kref, struct smbdirect_socket, refs.destroy);

	/*
	 * Do a sync disconnect/destroy...
	 * hopefully a no-op, as it should be already
	 * in DESTROYED state, before we free the memory.
	 */
	smbdirect_socket_destroy_sync(sc);
	kfree(sc);
}
698
/*
 * Drop the caller's references on the socket: the single 'disconnect'
 * reference (triggering a synchronous disconnect/destroy) and one
 * 'destroy' reference (freeing the memory, unless the socket is embedded
 * in a containing object — see below).
 */
void smbdirect_socket_release(struct smbdirect_socket *sc)
{
	/*
	 * We expect only 1 disconnect reference
	 * and if it is already 0, it's a use after free!
	 */
	WARN_ON_ONCE(kref_read(&sc->refs.disconnect) != 1);
	WARN_ON(!kref_put(&sc->refs.disconnect, smbdirect_socket_release_disconnect));

	/*
	 * This may not trigger smbdirect_socket_release_destroy(),
	 * if struct smbdirect_socket is embedded in another structure
	 * indicated by REFCOUNT_MAX.
	 */
	kref_put(&sc->refs.destroy, smbdirect_socket_release_destroy);
}
EXPORT_SYMBOL_GPL(smbdirect_socket_release);
716
/*
 * Atomically consume @needed credits from @total_credits, sleeping
 * (interruptibly) on @waitq until they become available.
 *
 * Return: 0 on success, @unexpected_errno if the socket left
 * @expected_status, a negative errno if the wait was interrupted, or
 * -EINVAL for a negative @needed.
 */
int smbdirect_socket_wait_for_credits(struct smbdirect_socket *sc,
				      enum smbdirect_socket_status expected_status,
				      int unexpected_errno,
				      wait_queue_head_t *waitq,
				      atomic_t *total_credits,
				      int needed)
{
	if (WARN_ON_ONCE(needed < 0))
		return -EINVAL;

	for (;;) {
		int rc;

		/* Optimistically grab the credits. */
		if (atomic_sub_return(needed, total_credits) >= 0)
			return 0;

		/* Not enough: give them back and wait for a refill. */
		atomic_add(needed, total_credits);
		rc = wait_event_interruptible(*waitq,
					      atomic_read(total_credits) >= needed ||
					      sc->status != expected_status);

		/* A status change takes priority over -ERESTARTSYS. */
		if (sc->status != expected_status)
			return unexpected_errno;
		if (rc < 0)
			return rc;
	}
}
744