xref: /linux/drivers/infiniband/core/multicast.c (revision c537b994505099b7197e7d3125b942ecbcc51eb6)
1 /*
2  * Copyright (c) 2006 Intel Corporation.  All rights reserved.
3  *
4  * This software is available to you under a choice of one of two
5  * licenses.  You may choose to be licensed under the terms of the GNU
6  * General Public License (GPL) Version 2, available from the file
7  * COPYING in the main directory of this source tree, or the
8  * OpenIB.org BSD license below:
9  *
10  *     Redistribution and use in source and binary forms, with or
11  *     without modification, are permitted provided that the following
12  *     conditions are met:
13  *
14  *      - Redistributions of source code must retain the above
15  *        copyright notice, this list of conditions and the following
16  *        disclaimer.
17  *
18  *      - Redistributions in binary form must reproduce the above
19  *        copyright notice, this list of conditions and the following
20  *        disclaimer in the documentation and/or other materials
21  *        provided with the distribution.
22  *
23  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
24  * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
25  * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
26  * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
27  * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
28  * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
29  * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
30  * SOFTWARE.
31  */
32 
33 #include <linux/completion.h>
34 #include <linux/dma-mapping.h>
35 #include <linux/err.h>
36 #include <linux/interrupt.h>
37 #include <linux/pci.h>
38 #include <linux/bitops.h>
39 #include <linux/random.h>
40 
41 #include <rdma/ib_cache.h>
42 #include "sa.h"
43 
44 static void mcast_add_one(struct ib_device *device);
45 static void mcast_remove_one(struct ib_device *device);
46 
47 static struct ib_client mcast_client = {
48 	.name   = "ib_multicast",
49 	.add    = mcast_add_one,
50 	.remove = mcast_remove_one
51 };
52 
53 static struct ib_sa_client	sa_client;
54 static struct workqueue_struct	*mcast_wq;
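/*
 * mgid0 is the all-zero MGID.  Joining with a zero MGID asks the SA to
 * create the group and assign an MGID, so such joins get special handling
 * in acquire_group() and join_handler().
 */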
55 static union ib_gid mgid0;
56 
57 struct mcast_device;
58 
59 struct mcast_port {
60 	struct mcast_device	*dev;
61 	spinlock_t		lock;
62 	struct rb_root		table;
63 	atomic_t		refcount;
64 	struct completion	comp;
65 	u8			port_num;
66 };
67 
68 struct mcast_device {
69 	struct ib_device	*device;
70 	struct ib_event_handler	event_handler;
71 	int			start_port;
72 	int			end_port;
73 	struct mcast_port	port[0];
74 };
75 
76 enum mcast_state {
77 	MCAST_IDLE,
78 	MCAST_JOINING,
79 	MCAST_MEMBER,
80 	MCAST_BUSY,
81 	MCAST_ERROR
82 };
83 
84 struct mcast_member;
85 
86 struct mcast_group {
87 	struct ib_sa_mcmember_rec rec;
88 	struct rb_node		node;
89 	struct mcast_port	*port;
90 	spinlock_t		lock;
91 	struct work_struct	work;
92 	struct list_head	pending_list;
93 	struct list_head	active_list;
94 	struct mcast_member	*last_join;
95 	int			members[3];	/* full, non, send only member counts */
96 	atomic_t		refcount;
97 	enum mcast_state	state;
98 	struct ib_sa_query	*query;
99 	int			query_id;
100 };
101 
102 struct mcast_member {
103 	struct ib_sa_multicast	multicast;
104 	struct ib_sa_client	*client;
105 	struct mcast_group	*group;
106 	struct list_head	list;
107 	enum mcast_state	state;
108 	atomic_t		refcount;
109 	struct completion	comp;
110 };
111 
112 static void join_handler(int status, struct ib_sa_mcmember_rec *rec,
113 			 void *context);
114 static void leave_handler(int status, struct ib_sa_mcmember_rec *rec,
115 			  void *context);
116 
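/* Caller must hold port->lock.  Returns the group matching mgid, or NULL. */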
117 static struct mcast_group *mcast_find(struct mcast_port *port,
118 				      union ib_gid *mgid)
119 {
120 	struct rb_node *node = port->table.rb_node;
121 	struct mcast_group *group;
122 	int ret;
123 
124 	while (node) {
125 		group = rb_entry(node, struct mcast_group, node);
126 		ret = memcmp(mgid->raw, group->rec.mgid.raw, sizeof *mgid);
127 		if (!ret)
128 			return group;
129 
130 		if (ret < 0)
131 			node = node->rb_left;
132 		else
133 			node = node->rb_right;
134 	}
135 	return NULL;
136 }
137 
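/*
 * Caller must hold port->lock.  Links group into the port's tree and
 * returns NULL, or returns an existing group with the same MGID.
 * allow_duplicates is set for groups joined through the zero (SA-assigned)
 * MGID, where colliding keys must be tolerated rather than reused.
 */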
138 static struct mcast_group *mcast_insert(struct mcast_port *port,
139 					struct mcast_group *group,
140 					int allow_duplicates)
141 {
142 	struct rb_node **link = &port->table.rb_node;
143 	struct rb_node *parent = NULL;
144 	struct mcast_group *cur_group;
145 	int ret;
146 
147 	while (*link) {
148 		parent = *link;
149 		cur_group = rb_entry(parent, struct mcast_group, node);
150 
151 		ret = memcmp(group->rec.mgid.raw, cur_group->rec.mgid.raw,
152 			     sizeof group->rec.mgid);
153 		if (ret < 0)
154 			link = &(*link)->rb_left;
155 		else if (ret > 0)
156 			link = &(*link)->rb_right;
157 		else if (allow_duplicates)
158 			link = &(*link)->rb_left;
159 		else
160 			return cur_group;
161 	}
162 	rb_link_node(&group->node, parent, link);
163 	rb_insert_color(&group->node, &port->table);
164 	return NULL;
165 }
166 
167 static void deref_port(struct mcast_port *port)
168 {
169 	if (atomic_dec_and_test(&port->refcount))
170 		complete(&port->comp);
171 }
172 
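/*
 * Drop a reference on the group.  The last reference removes the group
 * from the port's tree, frees it, and releases the group's reference on
 * the port.
 */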
173 static void release_group(struct mcast_group *group)
174 {
175 	struct mcast_port *port = group->port;
176 	unsigned long flags;
177 
178 	spin_lock_irqsave(&port->lock, flags);
179 	if (atomic_dec_and_test(&group->refcount)) {
180 		rb_erase(&group->node, &port->table);
181 		spin_unlock_irqrestore(&port->lock, flags);
182 		kfree(group);
183 		deref_port(port);
184 	} else
185 		spin_unlock_irqrestore(&port->lock, flags);
186 }
187 
188 static void deref_member(struct mcast_member *member)
189 {
190 	if (atomic_dec_and_test(&member->refcount))
191 		complete(&member->comp);
192 }
193 
194 static void queue_join(struct mcast_member *member)
195 {
196 	struct mcast_group *group = member->group;
197 	unsigned long flags;
198 
199 	spin_lock_irqsave(&group->lock, flags);
200 	list_add(&member->list, &group->pending_list);
201 	if (group->state == MCAST_IDLE) {
202 		group->state = MCAST_BUSY;
203 		atomic_inc(&group->refcount);
204 		queue_work(mcast_wq, &group->work);
205 	}
206 	spin_unlock_irqrestore(&group->lock, flags);
207 }
208 
209 /*
210  * A multicast group has three types of members: full member, non member, and
211  * send only member.  We need to keep track of the number of members of each
212  * type based on their join state.  Adjust the number of members that belong to
213  * the specified join states.
214  */
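/*
 * In the MCMemberRecord JoinState field, bit 0 selects full membership,
 * bit 1 non membership, and bit 2 send only non membership; the members[]
 * array in struct mcast_group holds one counter per bit.  For example,
 * joining with a join state of 0x5 increments both the full member and
 * send only counters.
 */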
215 static void adjust_membership(struct mcast_group *group, u8 join_state, int inc)
216 {
217 	int i;
218 
219 	for (i = 0; i < 3; i++, join_state >>= 1)
220 		if (join_state & 0x1)
221 			group->members[i] += inc;
222 }
223 
224 /*
225  * If a multicast group has zero members left for a particular join state, but
226  * the group is still a member with the SA, we need to leave that join state.
227  * Determine which join states we still belong to, but that do not have any
228  * active members.
229  */
230 static u8 get_leave_state(struct mcast_group *group)
231 {
232 	u8 leave_state = 0;
233 	int i;
234 
235 	for (i = 0; i < 3; i++)
236 		if (!group->members[i])
237 			leave_state |= (0x1 << i);
238 
239 	return leave_state & group->rec.join_state;
240 }
241 
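/*
 * Compare the group's existing value (src_value) against a requested value
 * (dst_value) using the selector taken from the request.  Returns 0 if the
 * values are compatible, or if the component mask does not cover both the
 * selector and the value; non-zero means the existing group cannot satisfy
 * the request.
 */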
242 static int check_selector(ib_sa_comp_mask comp_mask,
243 			  ib_sa_comp_mask selector_mask,
244 			  ib_sa_comp_mask value_mask,
245 			  u8 selector, u8 src_value, u8 dst_value)
246 {
247 	int err;
248 
249 	if (!(comp_mask & selector_mask) || !(comp_mask & value_mask))
250 		return 0;
251 
252 	switch (selector) {
253 	case IB_SA_GT:
254 		err = (src_value <= dst_value);
255 		break;
256 	case IB_SA_LT:
257 		err = (src_value >= dst_value);
258 		break;
259 	case IB_SA_EQ:
260 		err = (src_value != dst_value);
261 		break;
262 	default:
263 		err = 0;
264 		break;
265 	}
266 
267 	return err;
268 }
269 
270 static int cmp_rec(struct ib_sa_mcmember_rec *src,
271 		   struct ib_sa_mcmember_rec *dst, ib_sa_comp_mask comp_mask)
272 {
273 	/* MGID must already match */
274 
275 	if (comp_mask & IB_SA_MCMEMBER_REC_PORT_GID &&
276 	    memcmp(&src->port_gid, &dst->port_gid, sizeof src->port_gid))
277 		return -EINVAL;
278 	if (comp_mask & IB_SA_MCMEMBER_REC_QKEY && src->qkey != dst->qkey)
279 		return -EINVAL;
280 	if (comp_mask & IB_SA_MCMEMBER_REC_MLID && src->mlid != dst->mlid)
281 		return -EINVAL;
282 	if (check_selector(comp_mask, IB_SA_MCMEMBER_REC_MTU_SELECTOR,
283 			   IB_SA_MCMEMBER_REC_MTU, dst->mtu_selector,
284 			   src->mtu, dst->mtu))
285 		return -EINVAL;
286 	if (comp_mask & IB_SA_MCMEMBER_REC_TRAFFIC_CLASS &&
287 	    src->traffic_class != dst->traffic_class)
288 		return -EINVAL;
289 	if (comp_mask & IB_SA_MCMEMBER_REC_PKEY && src->pkey != dst->pkey)
290 		return -EINVAL;
291 	if (check_selector(comp_mask, IB_SA_MCMEMBER_REC_RATE_SELECTOR,
292 			   IB_SA_MCMEMBER_REC_RATE, dst->rate_selector,
293 			   src->rate, dst->rate))
294 		return -EINVAL;
295 	if (check_selector(comp_mask,
296 			   IB_SA_MCMEMBER_REC_PACKET_LIFE_TIME_SELECTOR,
297 			   IB_SA_MCMEMBER_REC_PACKET_LIFE_TIME,
298 			   dst->packet_life_time_selector,
299 			   src->packet_life_time, dst->packet_life_time))
300 		return -EINVAL;
301 	if (comp_mask & IB_SA_MCMEMBER_REC_SL && src->sl != dst->sl)
302 		return -EINVAL;
303 	if (comp_mask & IB_SA_MCMEMBER_REC_FLOW_LABEL &&
304 	    src->flow_label != dst->flow_label)
305 		return -EINVAL;
306 	if (comp_mask & IB_SA_MCMEMBER_REC_HOP_LIMIT &&
307 	    src->hop_limit != dst->hop_limit)
308 		return -EINVAL;
309 	if (comp_mask & IB_SA_MCMEMBER_REC_SCOPE && src->scope != dst->scope)
310 		return -EINVAL;
311 
312 	/* join_state checked separately, proxy_join ignored */
313 
314 	return 0;
315 }
316 
317 static int send_join(struct mcast_group *group, struct mcast_member *member)
318 {
319 	struct mcast_port *port = group->port;
320 	int ret;
321 
322 	group->last_join = member;
323 	ret = ib_sa_mcmember_rec_query(&sa_client, port->dev->device,
324 				       port->port_num, IB_MGMT_METHOD_SET,
325 				       &member->multicast.rec,
326 				       member->multicast.comp_mask,
327 				       3000, GFP_KERNEL, join_handler, group,
328 				       &group->query);
329 	if (ret >= 0) {
330 		group->query_id = ret;
331 		ret = 0;
332 	}
333 	return ret;
334 }
335 
336 static int send_leave(struct mcast_group *group, u8 leave_state)
337 {
338 	struct mcast_port *port = group->port;
339 	struct ib_sa_mcmember_rec rec;
340 	int ret;
341 
342 	rec = group->rec;
343 	rec.join_state = leave_state;
344 
345 	ret = ib_sa_mcmember_rec_query(&sa_client, port->dev->device,
346 				       port->port_num, IB_SA_METHOD_DELETE, &rec,
347 				       IB_SA_MCMEMBER_REC_MGID     |
348 				       IB_SA_MCMEMBER_REC_PORT_GID |
349 				       IB_SA_MCMEMBER_REC_JOIN_STATE,
350 				       3000, GFP_KERNEL, leave_handler,
351 				       group, &group->query);
352 	if (ret >= 0) {
353 		group->query_id = ret;
354 		ret = 0;
355 	}
356 	return ret;
357 }
358 
359 static void join_group(struct mcast_group *group, struct mcast_member *member,
360 		       u8 join_state)
361 {
362 	member->state = MCAST_MEMBER;
363 	adjust_membership(group, join_state, 1);
364 	group->rec.join_state |= join_state;
365 	member->multicast.rec = group->rec;
366 	member->multicast.rec.join_state = join_state;
367 	list_move(&member->list, &group->active_list);
368 }
369 
370 static int fail_join(struct mcast_group *group, struct mcast_member *member,
371 		     int status)
372 {
373 	spin_lock_irq(&group->lock);
374 	list_del_init(&member->list);
375 	spin_unlock_irq(&group->lock);
376 	return member->multicast.callback(status, &member->multicast);
377 }
378 
379 static void process_group_error(struct mcast_group *group)
380 {
381 	struct mcast_member *member;
382 	int ret;
383 
384 	spin_lock_irq(&group->lock);
385 	while (!list_empty(&group->active_list)) {
386 		member = list_entry(group->active_list.next,
387 				    struct mcast_member, list);
388 		atomic_inc(&member->refcount);
389 		list_del_init(&member->list);
390 		adjust_membership(group, member->multicast.rec.join_state, -1);
391 		member->state = MCAST_ERROR;
392 		spin_unlock_irq(&group->lock);
393 
394 		ret = member->multicast.callback(-ENETRESET,
395 						 &member->multicast);
396 		deref_member(member);
397 		if (ret)
398 			ib_sa_free_multicast(&member->multicast);
399 		spin_lock_irq(&group->lock);
400 	}
401 
402 	group->rec.join_state = 0;
403 	group->state = MCAST_BUSY;
404 	spin_unlock_irq(&group->lock);
405 }
406 
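/*
 * Serialized state machine for a group, run from the workqueue.  Pending
 * join requests are processed one at a time; whenever an SA query is
 * issued, the handler returns and the query's completion callback invokes
 * it again.  Once no pending work remains, any join states that have lost
 * all of their members are left, otherwise the group goes idle and the
 * reference taken when the work was queued is dropped.
 */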
407 static void mcast_work_handler(struct work_struct *work)
408 {
409 	struct mcast_group *group;
410 	struct mcast_member *member;
411 	struct ib_sa_multicast *multicast;
412 	int status, ret;
413 	u8 join_state;
414 
415 	group = container_of(work, typeof(*group), work);
416 retest:
417 	spin_lock_irq(&group->lock);
418 	while (!list_empty(&group->pending_list) ||
419 	       (group->state == MCAST_ERROR)) {
420 
421 		if (group->state == MCAST_ERROR) {
422 			spin_unlock_irq(&group->lock);
423 			process_group_error(group);
424 			goto retest;
425 		}
426 
427 		member = list_entry(group->pending_list.next,
428 				    struct mcast_member, list);
429 		multicast = &member->multicast;
430 		join_state = multicast->rec.join_state;
431 		atomic_inc(&member->refcount);
432 
433 		if (join_state == (group->rec.join_state & join_state)) {
434 			status = cmp_rec(&group->rec, &multicast->rec,
435 					 multicast->comp_mask);
436 			if (!status)
437 				join_group(group, member, join_state);
438 			else
439 				list_del_init(&member->list);
440 			spin_unlock_irq(&group->lock);
441 			ret = multicast->callback(status, multicast);
442 		} else {
443 			spin_unlock_irq(&group->lock);
444 			status = send_join(group, member);
445 			if (!status) {
446 				deref_member(member);
447 				return;
448 			}
449 			ret = fail_join(group, member, status);
450 		}
451 
452 		deref_member(member);
453 		if (ret)
454 			ib_sa_free_multicast(&member->multicast);
455 		spin_lock_irq(&group->lock);
456 	}
457 
458 	join_state = get_leave_state(group);
459 	if (join_state) {
460 		group->rec.join_state &= ~join_state;
461 		spin_unlock_irq(&group->lock);
462 		if (send_leave(group, join_state))
463 			goto retest;
464 	} else {
465 		group->state = MCAST_IDLE;
466 		spin_unlock_irq(&group->lock);
467 		release_group(group);
468 	}
469 }
470 
471 /*
472  * Fail a join request if it is still active, i.e. at the head of the pending queue.
473  */
474 static void process_join_error(struct mcast_group *group, int status)
475 {
476 	struct mcast_member *member;
477 	int ret;
478 
479 	spin_lock_irq(&group->lock);
480 	member = list_entry(group->pending_list.next,
481 			    struct mcast_member, list);
482 	if (group->last_join == member) {
483 		atomic_inc(&member->refcount);
484 		list_del_init(&member->list);
485 		spin_unlock_irq(&group->lock);
486 		ret = member->multicast.callback(status, &member->multicast);
487 		deref_member(member);
488 		if (ret)
489 			ib_sa_free_multicast(&member->multicast);
490 	} else
491 		spin_unlock_irq(&group->lock);
492 }
493 
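/*
 * Completion handler for an SA join.  On success the group's record is
 * replaced by the SA's response, with special-case handling for groups
 * joined through the zero (SA-assigned) MGID, and the work handler is run
 * directly to continue processing the group.
 */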
494 static void join_handler(int status, struct ib_sa_mcmember_rec *rec,
495 			 void *context)
496 {
497 	struct mcast_group *group = context;
498 
499 	if (status)
500 		process_join_error(group, status);
501 	else {
502 		spin_lock_irq(&group->port->lock);
503 		group->rec = *rec;
504 		if (!memcmp(&mgid0, &group->rec.mgid, sizeof mgid0)) {
505 			rb_erase(&group->node, &group->port->table);
506 			mcast_insert(group->port, group, 1);
507 		}
508 		spin_unlock_irq(&group->port->lock);
509 	}
510 	mcast_work_handler(&group->work);
511 }
512 
513 static void leave_handler(int status, struct ib_sa_mcmember_rec *rec,
514 			  void *context)
515 {
516 	struct mcast_group *group = context;
517 
518 	mcast_work_handler(&group->work);
519 }
520 
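/*
 * Find or create the group for mgid and take a reference on it.  A zero
 * MGID always allocates a new group, since the real MGID is known only
 * once the SA responds to the join.
 */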
521 static struct mcast_group *acquire_group(struct mcast_port *port,
522 					 union ib_gid *mgid, gfp_t gfp_mask)
523 {
524 	struct mcast_group *group, *cur_group;
525 	unsigned long flags;
526 	int is_mgid0;
527 
528 	is_mgid0 = !memcmp(&mgid0, mgid, sizeof mgid0);
529 	if (!is_mgid0) {
530 		spin_lock_irqsave(&port->lock, flags);
531 		group = mcast_find(port, mgid);
532 		if (group)
533 			goto found;
534 		spin_unlock_irqrestore(&port->lock, flags);
535 	}
536 
537 	group = kzalloc(sizeof *group, gfp_mask);
538 	if (!group)
539 		return NULL;
540 
541 	group->port = port;
542 	group->rec.mgid = *mgid;
543 	INIT_LIST_HEAD(&group->pending_list);
544 	INIT_LIST_HEAD(&group->active_list);
545 	INIT_WORK(&group->work, mcast_work_handler);
546 	spin_lock_init(&group->lock);
547 
548 	spin_lock_irqsave(&port->lock, flags);
549 	cur_group = mcast_insert(port, group, is_mgid0);
550 	if (cur_group) {
551 		kfree(group);
552 		group = cur_group;
553 	} else
554 		atomic_inc(&port->refcount);
555 found:
556 	atomic_inc(&group->refcount);
557 	spin_unlock_irqrestore(&port->lock, flags);
558 	return group;
559 }
560 
561 /*
562  * We serialize all join requests to a single group to make our lives much
563  * easier.  Otherwise, two users could try to join the same group
564  * simultaneously, with different configurations, one could leave while the
565  * join is in progress, etc., which makes locking around error recovery
566  * difficult.
567  */
568 struct ib_sa_multicast *
569 ib_sa_join_multicast(struct ib_sa_client *client,
570 		     struct ib_device *device, u8 port_num,
571 		     struct ib_sa_mcmember_rec *rec,
572 		     ib_sa_comp_mask comp_mask, gfp_t gfp_mask,
573 		     int (*callback)(int status,
574 				     struct ib_sa_multicast *multicast),
575 		     void *context)
576 {
577 	struct mcast_device *dev;
578 	struct mcast_member *member;
579 	struct ib_sa_multicast *multicast;
580 	int ret;
581 
582 	dev = ib_get_client_data(device, &mcast_client);
583 	if (!dev)
584 		return ERR_PTR(-ENODEV);
585 
586 	member = kmalloc(sizeof *member, gfp_mask);
587 	if (!member)
588 		return ERR_PTR(-ENOMEM);
589 
590 	ib_sa_client_get(client);
591 	member->client = client;
592 	member->multicast.rec = *rec;
593 	member->multicast.comp_mask = comp_mask;
594 	member->multicast.callback = callback;
595 	member->multicast.context = context;
596 	init_completion(&member->comp);
597 	atomic_set(&member->refcount, 1);
598 	member->state = MCAST_JOINING;
599 
600 	member->group = acquire_group(&dev->port[port_num - dev->start_port],
601 				      &rec->mgid, gfp_mask);
602 	if (!member->group) {
603 		ret = -ENOMEM;
604 		goto err;
605 	}
606 
607 	/*
608 	 * The user will get the multicast structure in their callback.  They
609 	 * could then free the multicast structure before we can return from
610 	 * this routine.  So we save the pointer to return before queuing
611 	 * any callback.
612 	 */
613 	multicast = &member->multicast;
614 	queue_join(member);
615 	return multicast;
616 
617 err:
618 	ib_sa_client_put(client);
619 	kfree(member);
620 	return ERR_PTR(ret);
621 }
622 EXPORT_SYMBOL(ib_sa_join_multicast);
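
/*
 * Illustrative sketch only, not part of the original file: roughly how a
 * kernel ULP might join a group through this interface.  The names
 * my_sa_client, my_join_done() and my_join(), and the minimal component
 * mask used below are assumptions made for the example, not requirements
 * of the API.  Note that the callback may run before ib_sa_join_multicast()
 * returns, and that returning non-zero from the callback frees the
 * membership.
 */
#if 0
static struct ib_sa_client my_sa_client;  /* ib_sa_register_client() elsewhere */

static int my_join_done(int status, struct ib_sa_multicast *multicast)
{
	if (status)
		printk(KERN_ERR "multicast join failed: %d\n", status);
	/* A non-zero return asks the core to free the membership. */
	return status;
}

static int my_join(struct ib_device *device, u8 port_num,
		   union ib_gid *mgid, union ib_gid *port_gid)
{
	struct ib_sa_mcmember_rec rec = {};
	struct ib_sa_multicast *mc;

	rec.mgid = *mgid;
	rec.port_gid = *port_gid;
	rec.join_state = 0x1;		/* full member */

	mc = ib_sa_join_multicast(&my_sa_client, device, port_num, &rec,
				  IB_SA_MCMEMBER_REC_MGID |
				  IB_SA_MCMEMBER_REC_PORT_GID |
				  IB_SA_MCMEMBER_REC_JOIN_STATE,
				  GFP_KERNEL, my_join_done, NULL);
	return IS_ERR(mc) ? PTR_ERR(mc) : 0;
}
#endif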
623 
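/*
 * Leave the group and free the membership.  This blocks until any
 * outstanding callback for the member has completed, so it must not be
 * called from the member's own callback; returning a non-zero value from
 * the callback frees the membership instead.
 */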
624 void ib_sa_free_multicast(struct ib_sa_multicast *multicast)
625 {
626 	struct mcast_member *member;
627 	struct mcast_group *group;
628 
629 	member = container_of(multicast, struct mcast_member, multicast);
630 	group = member->group;
631 
632 	spin_lock_irq(&group->lock);
633 	if (member->state == MCAST_MEMBER)
634 		adjust_membership(group, multicast->rec.join_state, -1);
635 
636 	list_del_init(&member->list);
637 
638 	if (group->state == MCAST_IDLE) {
639 		group->state = MCAST_BUSY;
640 		spin_unlock_irq(&group->lock);
641 		/* Continue to hold reference on group until callback */
642 		queue_work(mcast_wq, &group->work);
643 	} else {
644 		spin_unlock_irq(&group->lock);
645 		release_group(group);
646 	}
647 
648 	deref_member(member);
649 	wait_for_completion(&member->comp);
650 	ib_sa_client_put(member->client);
651 	kfree(member);
652 }
653 EXPORT_SYMBOL(ib_sa_free_multicast);
654 
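/*
 * Return the cached member record for mgid on the given port.  Only groups
 * that some local user has joined (or is joining) are tracked here; for
 * anything else -EADDRNOTAVAIL is returned.
 */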
655 int ib_sa_get_mcmember_rec(struct ib_device *device, u8 port_num,
656 			   union ib_gid *mgid, struct ib_sa_mcmember_rec *rec)
657 {
658 	struct mcast_device *dev;
659 	struct mcast_port *port;
660 	struct mcast_group *group;
661 	unsigned long flags;
662 	int ret = 0;
663 
664 	dev = ib_get_client_data(device, &mcast_client);
665 	if (!dev)
666 		return -ENODEV;
667 
668 	port = &dev->port[port_num - dev->start_port];
669 	spin_lock_irqsave(&port->lock, flags);
670 	group = mcast_find(port, mgid);
671 	if (group)
672 		*rec = group->rec;
673 	else
674 		ret = -EADDRNOTAVAIL;
675 	spin_unlock_irqrestore(&port->lock, flags);
676 
677 	return ret;
678 }
679 EXPORT_SYMBOL(ib_sa_get_mcmember_rec);
680 
681 int ib_init_ah_from_mcmember(struct ib_device *device, u8 port_num,
682 			     struct ib_sa_mcmember_rec *rec,
683 			     struct ib_ah_attr *ah_attr)
684 {
685 	int ret;
686 	u16 gid_index;
687 	u8 p;
688 
689 	ret = ib_find_cached_gid(device, &rec->port_gid, &p, &gid_index);
690 	if (ret)
691 		return ret;
692 
693 	memset(ah_attr, 0, sizeof *ah_attr);
694 	ah_attr->dlid = be16_to_cpu(rec->mlid);
695 	ah_attr->sl = rec->sl;
696 	ah_attr->port_num = port_num;
697 	ah_attr->static_rate = rec->rate;
698 
699 	ah_attr->ah_flags = IB_AH_GRH;
700 	ah_attr->grh.dgid = rec->mgid;
701 
702 	ah_attr->grh.sgid_index = (u8) gid_index;
703 	ah_attr->grh.flow_label = be32_to_cpu(rec->flow_label);
704 	ah_attr->grh.hop_limit = rec->hop_limit;
705 	ah_attr->grh.traffic_class = rec->traffic_class;
706 
707 	return 0;
708 }
709 EXPORT_SYMBOL(ib_init_ah_from_mcmember);
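
/*
 * Illustrative sketch only: building an address handle for sends to a
 * joined group.  my_create_group_ah() and its parameters are assumptions
 * made for the example; rec would typically be the member record reported
 * to the ib_sa_join_multicast() callback.  A UD QP must additionally be
 * attached to the group with ib_attach_mcast() before it can receive.
 */
#if 0
static struct ib_ah *my_create_group_ah(struct ib_device *device, u8 port_num,
					struct ib_pd *pd,
					struct ib_sa_mcmember_rec *rec)
{
	struct ib_ah_attr ah_attr;
	int ret;

	ret = ib_init_ah_from_mcmember(device, port_num, rec, &ah_attr);
	if (ret)
		return ERR_PTR(ret);

	return ib_create_ah(pd, &ah_attr);
}
#endif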
710 
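/*
 * A port event (port error, LID/SM change, client reregister) invalidates
 * all existing memberships on the port.  Mark every group as being in
 * error so the work handler reports -ENETRESET to its members, which are
 * expected to rejoin.
 */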
711 static void mcast_groups_lost(struct mcast_port *port)
712 {
713 	struct mcast_group *group;
714 	struct rb_node *node;
715 	unsigned long flags;
716 
717 	spin_lock_irqsave(&port->lock, flags);
718 	for (node = rb_first(&port->table); node; node = rb_next(node)) {
719 		group = rb_entry(node, struct mcast_group, node);
720 		spin_lock(&group->lock);
721 		if (group->state == MCAST_IDLE) {
722 			atomic_inc(&group->refcount);
723 			queue_work(mcast_wq, &group->work);
724 		}
725 		group->state = MCAST_ERROR;
726 		spin_unlock(&group->lock);
727 	}
728 	spin_unlock_irqrestore(&port->lock, flags);
729 }
730 
731 static void mcast_event_handler(struct ib_event_handler *handler,
732 				struct ib_event *event)
733 {
734 	struct mcast_device *dev;
735 
736 	dev = container_of(handler, struct mcast_device, event_handler);
737 
738 	switch (event->event) {
739 	case IB_EVENT_PORT_ERR:
740 	case IB_EVENT_LID_CHANGE:
741 	case IB_EVENT_SM_CHANGE:
742 	case IB_EVENT_CLIENT_REREGISTER:
743 		mcast_groups_lost(&dev->port[event->element.port_num -
744 					     dev->start_port]);
745 		break;
746 	default:
747 		break;
748 	}
749 }
750 
751 static void mcast_add_one(struct ib_device *device)
752 {
753 	struct mcast_device *dev;
754 	struct mcast_port *port;
755 	int i;
756 
757 	if (rdma_node_get_transport(device->node_type) != RDMA_TRANSPORT_IB)
758 		return;
759 
760 	dev = kmalloc(sizeof *dev + device->phys_port_cnt * sizeof *port,
761 		      GFP_KERNEL);
762 	if (!dev)
763 		return;
764 
765 	if (device->node_type == RDMA_NODE_IB_SWITCH)
766 		dev->start_port = dev->end_port = 0;
767 	else {
768 		dev->start_port = 1;
769 		dev->end_port = device->phys_port_cnt;
770 	}
771 
772 	for (i = 0; i <= dev->end_port - dev->start_port; i++) {
773 		port = &dev->port[i];
774 		port->dev = dev;
775 		port->port_num = dev->start_port + i;
776 		spin_lock_init(&port->lock);
777 		port->table = RB_ROOT;
778 		init_completion(&port->comp);
779 		atomic_set(&port->refcount, 1);
780 	}
781 
782 	dev->device = device;
783 	ib_set_client_data(device, &mcast_client, dev);
784 
785 	INIT_IB_EVENT_HANDLER(&dev->event_handler, device, mcast_event_handler);
786 	ib_register_event_handler(&dev->event_handler);
787 }
788 
789 static void mcast_remove_one(struct ib_device *device)
790 {
791 	struct mcast_device *dev;
792 	struct mcast_port *port;
793 	int i;
794 
795 	dev = ib_get_client_data(device, &mcast_client);
796 	if (!dev)
797 		return;
798 
799 	ib_unregister_event_handler(&dev->event_handler);
800 	flush_workqueue(mcast_wq);
801 
802 	for (i = 0; i <= dev->end_port - dev->start_port; i++) {
803 		port = &dev->port[i];
804 		deref_port(port);
805 		wait_for_completion(&port->comp);
806 	}
807 
808 	kfree(dev);
809 }
810 
811 int mcast_init(void)
812 {
813 	int ret;
814 
815 	mcast_wq = create_singlethread_workqueue("ib_mcast");
816 	if (!mcast_wq)
817 		return -ENOMEM;
818 
819 	ib_sa_register_client(&sa_client);
820 
821 	ret = ib_register_client(&mcast_client);
822 	if (ret)
823 		goto err;
824 	return 0;
825 
826 err:
827 	ib_sa_unregister_client(&sa_client);
828 	destroy_workqueue(mcast_wq);
829 	return ret;
830 }
831 
832 void mcast_cleanup(void)
833 {
834 	ib_unregister_client(&mcast_client);
835 	ib_sa_unregister_client(&sa_client);
836 	destroy_workqueue(mcast_wq);
837 }
838