xref: /linux/drivers/infiniband/sw/rdmavt/mcast.c (revision 0ad53fe3ae82443c74ff8cfd7bd13377cc1134a3)
1 // SPDX-License-Identifier: GPL-2.0 or BSD-3-Clause
2 /*
3  * Copyright(c) 2016 Intel Corporation.
4  */
5 
6 #include <linux/slab.h>
7 #include <linux/sched.h>
8 #include <linux/rculist.h>
9 #include <rdma/rdma_vt.h>
10 #include <rdma/rdmavt_qp.h>
11 
12 #include "mcast.h"
13 
14 /**
15  * rvt_driver_mcast_init - init resources for multicast
16  * @rdi: rvt dev struct
17  *
18  * This is per device that registers with rdmavt
19  */
20 void rvt_driver_mcast_init(struct rvt_dev_info *rdi)
21 {
22 	/*
23 	 * Anything that needs setup for multicast on a per driver or per rdi
24 	 * basis should be done in here.
25 	 */
26 	spin_lock_init(&rdi->n_mcast_grps_lock);
27 }
28 
29 /**
30  * rvt_mcast_qp_alloc - alloc a struct to link a QP to mcast GID struct
31  * @qp: the QP to link
32  */
33 static struct rvt_mcast_qp *rvt_mcast_qp_alloc(struct rvt_qp *qp)
34 {
35 	struct rvt_mcast_qp *mqp;
36 
37 	mqp = kmalloc(sizeof(*mqp), GFP_KERNEL);
38 	if (!mqp)
39 		goto bail;
40 
41 	mqp->qp = qp;
42 	rvt_get_qp(qp);
43 
44 bail:
45 	return mqp;
46 }
47 
48 static void rvt_mcast_qp_free(struct rvt_mcast_qp *mqp)
49 {
50 	struct rvt_qp *qp = mqp->qp;
51 
52 	/* Notify hfi1_destroy_qp() if it is waiting. */
53 	rvt_put_qp(qp);
54 
55 	kfree(mqp);
56 }
57 
58 /**
59  * rvt_mcast_alloc - allocate the multicast GID structure
60  * @mgid: the multicast GID
61  * @lid: the muilticast LID (host order)
62  *
63  * A list of QPs will be attached to this structure.
64  */
65 static struct rvt_mcast *rvt_mcast_alloc(union ib_gid *mgid, u16 lid)
66 {
67 	struct rvt_mcast *mcast;
68 
69 	mcast = kzalloc(sizeof(*mcast), GFP_KERNEL);
70 	if (!mcast)
71 		goto bail;
72 
73 	mcast->mcast_addr.mgid = *mgid;
74 	mcast->mcast_addr.lid = lid;
75 
76 	INIT_LIST_HEAD(&mcast->qp_list);
77 	init_waitqueue_head(&mcast->wait);
78 	atomic_set(&mcast->refcount, 0);
79 
80 bail:
81 	return mcast;
82 }
83 
84 static void rvt_mcast_free(struct rvt_mcast *mcast)
85 {
86 	struct rvt_mcast_qp *p, *tmp;
87 
88 	list_for_each_entry_safe(p, tmp, &mcast->qp_list, list)
89 		rvt_mcast_qp_free(p);
90 
91 	kfree(mcast);
92 }
93 
94 /**
95  * rvt_mcast_find - search the global table for the given multicast GID/LID
96  * NOTE: It is valid to have 1 MLID with multiple MGIDs.  It is not valid
97  * to have 1 MGID with multiple MLIDs.
98  * @ibp: the IB port structure
99  * @mgid: the multicast GID to search for
100  * @lid: the multicast LID portion of the multicast address (host order)
101  *
102  * The caller is responsible for decrementing the reference count if found.
103  *
104  * Return: NULL if not found.
105  */
106 struct rvt_mcast *rvt_mcast_find(struct rvt_ibport *ibp, union ib_gid *mgid,
107 				 u16 lid)
108 {
109 	struct rb_node *n;
110 	unsigned long flags;
111 	struct rvt_mcast *found = NULL;
112 
113 	spin_lock_irqsave(&ibp->lock, flags);
114 	n = ibp->mcast_tree.rb_node;
115 	while (n) {
116 		int ret;
117 		struct rvt_mcast *mcast;
118 
119 		mcast = rb_entry(n, struct rvt_mcast, rb_node);
120 
121 		ret = memcmp(mgid->raw, mcast->mcast_addr.mgid.raw,
122 			     sizeof(*mgid));
123 		if (ret < 0) {
124 			n = n->rb_left;
125 		} else if (ret > 0) {
126 			n = n->rb_right;
127 		} else {
128 			/* MGID/MLID must match */
129 			if (mcast->mcast_addr.lid == lid) {
130 				atomic_inc(&mcast->refcount);
131 				found = mcast;
132 			}
133 			break;
134 		}
135 	}
136 	spin_unlock_irqrestore(&ibp->lock, flags);
137 	return found;
138 }
139 EXPORT_SYMBOL(rvt_mcast_find);
140 
141 /*
142  * rvt_mcast_add - insert mcast GID into table and attach QP struct
143  * @mcast: the mcast GID table
144  * @mqp: the QP to attach
145  *
146  * Return: zero if both were added.  Return EEXIST if the GID was already in
147  * the table but the QP was added.  Return ESRCH if the QP was already
148  * attached and neither structure was added. Return EINVAL if the MGID was
149  * found, but the MLID did NOT match.
150  */
151 static int rvt_mcast_add(struct rvt_dev_info *rdi, struct rvt_ibport *ibp,
152 			 struct rvt_mcast *mcast, struct rvt_mcast_qp *mqp)
153 {
154 	struct rb_node **n = &ibp->mcast_tree.rb_node;
155 	struct rb_node *pn = NULL;
156 	int ret;
157 
158 	spin_lock_irq(&ibp->lock);
159 
160 	while (*n) {
161 		struct rvt_mcast *tmcast;
162 		struct rvt_mcast_qp *p;
163 
164 		pn = *n;
165 		tmcast = rb_entry(pn, struct rvt_mcast, rb_node);
166 
167 		ret = memcmp(mcast->mcast_addr.mgid.raw,
168 			     tmcast->mcast_addr.mgid.raw,
169 			     sizeof(mcast->mcast_addr.mgid));
170 		if (ret < 0) {
171 			n = &pn->rb_left;
172 			continue;
173 		}
174 		if (ret > 0) {
175 			n = &pn->rb_right;
176 			continue;
177 		}
178 
179 		if (tmcast->mcast_addr.lid != mcast->mcast_addr.lid) {
180 			ret = EINVAL;
181 			goto bail;
182 		}
183 
184 		/* Search the QP list to see if this is already there. */
185 		list_for_each_entry_rcu(p, &tmcast->qp_list, list) {
186 			if (p->qp == mqp->qp) {
187 				ret = ESRCH;
188 				goto bail;
189 			}
190 		}
191 		if (tmcast->n_attached ==
192 		    rdi->dparms.props.max_mcast_qp_attach) {
193 			ret = ENOMEM;
194 			goto bail;
195 		}
196 
197 		tmcast->n_attached++;
198 
199 		list_add_tail_rcu(&mqp->list, &tmcast->qp_list);
200 		ret = EEXIST;
201 		goto bail;
202 	}
203 
204 	spin_lock(&rdi->n_mcast_grps_lock);
205 	if (rdi->n_mcast_grps_allocated == rdi->dparms.props.max_mcast_grp) {
206 		spin_unlock(&rdi->n_mcast_grps_lock);
207 		ret = ENOMEM;
208 		goto bail;
209 	}
210 
211 	rdi->n_mcast_grps_allocated++;
212 	spin_unlock(&rdi->n_mcast_grps_lock);
213 
214 	mcast->n_attached++;
215 
216 	list_add_tail_rcu(&mqp->list, &mcast->qp_list);
217 
218 	atomic_inc(&mcast->refcount);
219 	rb_link_node(&mcast->rb_node, pn, n);
220 	rb_insert_color(&mcast->rb_node, &ibp->mcast_tree);
221 
222 	ret = 0;
223 
224 bail:
225 	spin_unlock_irq(&ibp->lock);
226 
227 	return ret;
228 }
229 
230 /**
231  * rvt_attach_mcast - attach a qp to a multicast group
232  * @ibqp: Infiniband qp
233  * @gid: multicast guid
234  * @lid: multicast lid
235  *
236  * Return: 0 on success
237  */
238 int rvt_attach_mcast(struct ib_qp *ibqp, union ib_gid *gid, u16 lid)
239 {
240 	struct rvt_qp *qp = ibqp_to_rvtqp(ibqp);
241 	struct rvt_dev_info *rdi = ib_to_rvt(ibqp->device);
242 	struct rvt_ibport *ibp = rdi->ports[qp->port_num - 1];
243 	struct rvt_mcast *mcast;
244 	struct rvt_mcast_qp *mqp;
245 	int ret = -ENOMEM;
246 
247 	if (ibqp->qp_num <= 1 || qp->state == IB_QPS_RESET)
248 		return -EINVAL;
249 
250 	/*
251 	 * Allocate data structures since its better to do this outside of
252 	 * spin locks and it will most likely be needed.
253 	 */
254 	mcast = rvt_mcast_alloc(gid, lid);
255 	if (!mcast)
256 		return -ENOMEM;
257 
258 	mqp = rvt_mcast_qp_alloc(qp);
259 	if (!mqp)
260 		goto bail_mcast;
261 
262 	switch (rvt_mcast_add(rdi, ibp, mcast, mqp)) {
263 	case ESRCH:
264 		/* Neither was used: OK to attach the same QP twice. */
265 		ret = 0;
266 		goto bail_mqp;
267 	case EEXIST: /* The mcast wasn't used */
268 		ret = 0;
269 		goto bail_mcast;
270 	case ENOMEM:
271 		/* Exceeded the maximum number of mcast groups. */
272 		ret = -ENOMEM;
273 		goto bail_mqp;
274 	case EINVAL:
275 		/* Invalid MGID/MLID pair */
276 		ret = -EINVAL;
277 		goto bail_mqp;
278 	default:
279 		break;
280 	}
281 
282 	return 0;
283 
284 bail_mqp:
285 	rvt_mcast_qp_free(mqp);
286 
287 bail_mcast:
288 	rvt_mcast_free(mcast);
289 
290 	return ret;
291 }
292 
293 /**
294  * rvt_detach_mcast - remove a qp from a multicast group
295  * @ibqp: Infiniband qp
296  * @gid: multicast guid
297  * @lid: multicast lid
298  *
299  * Return: 0 on success
300  */
301 int rvt_detach_mcast(struct ib_qp *ibqp, union ib_gid *gid, u16 lid)
302 {
303 	struct rvt_qp *qp = ibqp_to_rvtqp(ibqp);
304 	struct rvt_dev_info *rdi = ib_to_rvt(ibqp->device);
305 	struct rvt_ibport *ibp = rdi->ports[qp->port_num - 1];
306 	struct rvt_mcast *mcast = NULL;
307 	struct rvt_mcast_qp *p, *tmp, *delp = NULL;
308 	struct rb_node *n;
309 	int last = 0;
310 	int ret = 0;
311 
312 	if (ibqp->qp_num <= 1)
313 		return -EINVAL;
314 
315 	spin_lock_irq(&ibp->lock);
316 
317 	/* Find the GID in the mcast table. */
318 	n = ibp->mcast_tree.rb_node;
319 	while (1) {
320 		if (!n) {
321 			spin_unlock_irq(&ibp->lock);
322 			return -EINVAL;
323 		}
324 
325 		mcast = rb_entry(n, struct rvt_mcast, rb_node);
326 		ret = memcmp(gid->raw, mcast->mcast_addr.mgid.raw,
327 			     sizeof(*gid));
328 		if (ret < 0) {
329 			n = n->rb_left;
330 		} else if (ret > 0) {
331 			n = n->rb_right;
332 		} else {
333 			/* MGID/MLID must match */
334 			if (mcast->mcast_addr.lid != lid) {
335 				spin_unlock_irq(&ibp->lock);
336 				return -EINVAL;
337 			}
338 			break;
339 		}
340 	}
341 
342 	/* Search the QP list. */
343 	list_for_each_entry_safe(p, tmp, &mcast->qp_list, list) {
344 		if (p->qp != qp)
345 			continue;
346 		/*
347 		 * We found it, so remove it, but don't poison the forward
348 		 * link until we are sure there are no list walkers.
349 		 */
350 		list_del_rcu(&p->list);
351 		mcast->n_attached--;
352 		delp = p;
353 
354 		/* If this was the last attached QP, remove the GID too. */
355 		if (list_empty(&mcast->qp_list)) {
356 			rb_erase(&mcast->rb_node, &ibp->mcast_tree);
357 			last = 1;
358 		}
359 		break;
360 	}
361 
362 	spin_unlock_irq(&ibp->lock);
363 	/* QP not attached */
364 	if (!delp)
365 		return -EINVAL;
366 
367 	/*
368 	 * Wait for any list walkers to finish before freeing the
369 	 * list element.
370 	 */
371 	wait_event(mcast->wait, atomic_read(&mcast->refcount) <= 1);
372 	rvt_mcast_qp_free(delp);
373 
374 	if (last) {
375 		atomic_dec(&mcast->refcount);
376 		wait_event(mcast->wait, !atomic_read(&mcast->refcount));
377 		rvt_mcast_free(mcast);
378 		spin_lock_irq(&rdi->n_mcast_grps_lock);
379 		rdi->n_mcast_grps_allocated--;
380 		spin_unlock_irq(&rdi->n_mcast_grps_lock);
381 	}
382 
383 	return 0;
384 }
385 
386 /**
387  * rvt_mcast_tree_empty - determine if any qps are attached to any mcast group
388  * @rdi: rvt dev struct
389  *
390  * Return: in use count
391  */
392 int rvt_mcast_tree_empty(struct rvt_dev_info *rdi)
393 {
394 	int i;
395 	int in_use = 0;
396 
397 	for (i = 0; i < rdi->dparms.nports; i++)
398 		if (rdi->ports[i]->mcast_tree.rb_node)
399 			in_use++;
400 	return in_use;
401 }
402