xref: /linux/drivers/net/vxlan/vxlan_mdb.c (revision 02091cbe9cc4f18167208eec1d6de636cc731817)
1 // SPDX-License-Identifier: GPL-2.0-only
2 
3 #include <linux/if_bridge.h>
4 #include <linux/in.h>
5 #include <linux/list.h>
6 #include <linux/netdevice.h>
7 #include <linux/netlink.h>
8 #include <linux/rhashtable.h>
9 #include <linux/rhashtable-types.h>
10 #include <linux/rtnetlink.h>
11 #include <linux/skbuff.h>
12 #include <linux/types.h>
13 #include <net/netlink.h>
14 #include <net/vxlan.h>
15 
16 #include "vxlan_private.h"
17 
18 struct vxlan_mdb_entry_key {
19 	union vxlan_addr src;
20 	union vxlan_addr dst;
21 	__be32 vni;
22 };
23 
24 struct vxlan_mdb_entry {
25 	struct rhash_head rhnode;
26 	struct list_head remotes;
27 	struct vxlan_mdb_entry_key key;
28 	struct hlist_node mdb_node;
29 	struct rcu_head rcu;
30 };
31 
32 #define VXLAN_MDB_REMOTE_F_BLOCKED	BIT(0)
33 
34 struct vxlan_mdb_remote {
35 	struct list_head list;
36 	struct vxlan_rdst __rcu *rd;
37 	u8 flags;
38 	u8 filter_mode;
39 	u8 rt_protocol;
40 	struct hlist_head src_list;
41 	struct rcu_head rcu;
42 };
43 
44 #define VXLAN_SGRP_F_DELETE	BIT(0)
45 
46 struct vxlan_mdb_src_entry {
47 	struct hlist_node node;
48 	union vxlan_addr addr;
49 	u8 flags;
50 };
51 
52 struct vxlan_mdb_dump_ctx {
53 	long reserved;
54 	long entry_idx;
55 	long remote_idx;
56 };
57 
58 struct vxlan_mdb_config_src_entry {
59 	union vxlan_addr addr;
60 	struct list_head node;
61 };
62 
63 struct vxlan_mdb_config {
64 	struct vxlan_dev *vxlan;
65 	struct vxlan_mdb_entry_key group;
66 	struct list_head src_list;
67 	union vxlan_addr remote_ip;
68 	u32 remote_ifindex;
69 	__be32 remote_vni;
70 	__be16 remote_port;
71 	u16 nlflags;
72 	u8 flags;
73 	u8 filter_mode;
74 	u8 rt_protocol;
75 };
76 
77 static const struct rhashtable_params vxlan_mdb_rht_params = {
78 	.head_offset = offsetof(struct vxlan_mdb_entry, rhnode),
79 	.key_offset = offsetof(struct vxlan_mdb_entry, key),
80 	.key_len = sizeof(struct vxlan_mdb_entry_key),
81 	.automatic_shrinking = true,
82 };
83 
84 static int __vxlan_mdb_add(const struct vxlan_mdb_config *cfg,
85 			   struct netlink_ext_ack *extack);
86 static int __vxlan_mdb_del(const struct vxlan_mdb_config *cfg,
87 			   struct netlink_ext_ack *extack);
88 
89 static void vxlan_br_mdb_entry_fill(const struct vxlan_dev *vxlan,
90 				    const struct vxlan_mdb_entry *mdb_entry,
91 				    const struct vxlan_mdb_remote *remote,
92 				    struct br_mdb_entry *e)
93 {
94 	const union vxlan_addr *dst = &mdb_entry->key.dst;
95 
96 	memset(e, 0, sizeof(*e));
97 	e->ifindex = vxlan->dev->ifindex;
98 	e->state = MDB_PERMANENT;
99 
100 	if (remote->flags & VXLAN_MDB_REMOTE_F_BLOCKED)
101 		e->flags |= MDB_FLAGS_BLOCKED;
102 
103 	switch (dst->sa.sa_family) {
104 	case AF_INET:
105 		e->addr.u.ip4 = dst->sin.sin_addr.s_addr;
106 		e->addr.proto = htons(ETH_P_IP);
107 		break;
108 #if IS_ENABLED(CONFIG_IPV6)
109 	case AF_INET6:
110 		e->addr.u.ip6 = dst->sin6.sin6_addr;
111 		e->addr.proto = htons(ETH_P_IPV6);
112 		break;
113 #endif
114 	}
115 }
116 
117 static int vxlan_mdb_entry_info_fill_srcs(struct sk_buff *skb,
118 					  const struct vxlan_mdb_remote *remote)
119 {
120 	struct vxlan_mdb_src_entry *ent;
121 	struct nlattr *nest;
122 
123 	if (hlist_empty(&remote->src_list))
124 		return 0;
125 
126 	nest = nla_nest_start(skb, MDBA_MDB_EATTR_SRC_LIST);
127 	if (!nest)
128 		return -EMSGSIZE;
129 
130 	hlist_for_each_entry(ent, &remote->src_list, node) {
131 		struct nlattr *nest_ent;
132 
133 		nest_ent = nla_nest_start(skb, MDBA_MDB_SRCLIST_ENTRY);
134 		if (!nest_ent)
135 			goto out_cancel_err;
136 
137 		if (vxlan_nla_put_addr(skb, MDBA_MDB_SRCATTR_ADDRESS,
138 				       &ent->addr) ||
139 		    nla_put_u32(skb, MDBA_MDB_SRCATTR_TIMER, 0))
140 			goto out_cancel_err;
141 
142 		nla_nest_end(skb, nest_ent);
143 	}
144 
145 	nla_nest_end(skb, nest);
146 
147 	return 0;
148 
149 out_cancel_err:
150 	nla_nest_cancel(skb, nest);
151 	return -EMSGSIZE;
152 }
153 
154 static int vxlan_mdb_entry_info_fill(const struct vxlan_dev *vxlan,
155 				     struct sk_buff *skb,
156 				     const struct vxlan_mdb_entry *mdb_entry,
157 				     const struct vxlan_mdb_remote *remote)
158 {
159 	struct vxlan_rdst *rd = rtnl_dereference(remote->rd);
160 	struct br_mdb_entry e;
161 	struct nlattr *nest;
162 
163 	nest = nla_nest_start_noflag(skb, MDBA_MDB_ENTRY_INFO);
164 	if (!nest)
165 		return -EMSGSIZE;
166 
167 	vxlan_br_mdb_entry_fill(vxlan, mdb_entry, remote, &e);
168 
169 	if (nla_put_nohdr(skb, sizeof(e), &e) ||
170 	    nla_put_u32(skb, MDBA_MDB_EATTR_TIMER, 0))
171 		goto nest_err;
172 
173 	if (!vxlan_addr_any(&mdb_entry->key.src) &&
174 	    vxlan_nla_put_addr(skb, MDBA_MDB_EATTR_SOURCE, &mdb_entry->key.src))
175 		goto nest_err;
176 
177 	if (nla_put_u8(skb, MDBA_MDB_EATTR_RTPROT, remote->rt_protocol) ||
178 	    nla_put_u8(skb, MDBA_MDB_EATTR_GROUP_MODE, remote->filter_mode) ||
179 	    vxlan_mdb_entry_info_fill_srcs(skb, remote) ||
180 	    vxlan_nla_put_addr(skb, MDBA_MDB_EATTR_DST, &rd->remote_ip))
181 		goto nest_err;
182 
183 	if (rd->remote_port && rd->remote_port != vxlan->cfg.dst_port &&
184 	    nla_put_u16(skb, MDBA_MDB_EATTR_DST_PORT,
185 			be16_to_cpu(rd->remote_port)))
186 		goto nest_err;
187 
188 	if (rd->remote_vni != vxlan->default_dst.remote_vni &&
189 	    nla_put_u32(skb, MDBA_MDB_EATTR_VNI, be32_to_cpu(rd->remote_vni)))
190 		goto nest_err;
191 
192 	if (rd->remote_ifindex &&
193 	    nla_put_u32(skb, MDBA_MDB_EATTR_IFINDEX, rd->remote_ifindex))
194 		goto nest_err;
195 
196 	if ((vxlan->cfg.flags & VXLAN_F_COLLECT_METADATA) &&
197 	    mdb_entry->key.vni && nla_put_u32(skb, MDBA_MDB_EATTR_SRC_VNI,
198 					      be32_to_cpu(mdb_entry->key.vni)))
199 		goto nest_err;
200 
201 	nla_nest_end(skb, nest);
202 
203 	return 0;
204 
205 nest_err:
206 	nla_nest_cancel(skb, nest);
207 	return -EMSGSIZE;
208 }
209 
210 static int vxlan_mdb_entry_fill(const struct vxlan_dev *vxlan,
211 				struct sk_buff *skb,
212 				struct vxlan_mdb_dump_ctx *ctx,
213 				const struct vxlan_mdb_entry *mdb_entry)
214 {
215 	int remote_idx = 0, s_remote_idx = ctx->remote_idx;
216 	struct vxlan_mdb_remote *remote;
217 	struct nlattr *nest;
218 	int err = 0;
219 
220 	nest = nla_nest_start_noflag(skb, MDBA_MDB_ENTRY);
221 	if (!nest)
222 		return -EMSGSIZE;
223 
224 	list_for_each_entry(remote, &mdb_entry->remotes, list) {
225 		if (remote_idx < s_remote_idx)
226 			goto skip;
227 
228 		err = vxlan_mdb_entry_info_fill(vxlan, skb, mdb_entry, remote);
229 		if (err)
230 			break;
231 skip:
232 		remote_idx++;
233 	}
234 
235 	ctx->remote_idx = err ? remote_idx : 0;
236 	nla_nest_end(skb, nest);
237 	return err;
238 }
239 
240 static int vxlan_mdb_fill(const struct vxlan_dev *vxlan, struct sk_buff *skb,
241 			  struct vxlan_mdb_dump_ctx *ctx)
242 {
243 	int entry_idx = 0, s_entry_idx = ctx->entry_idx;
244 	struct vxlan_mdb_entry *mdb_entry;
245 	struct nlattr *nest;
246 	int err = 0;
247 
248 	nest = nla_nest_start_noflag(skb, MDBA_MDB);
249 	if (!nest)
250 		return -EMSGSIZE;
251 
252 	hlist_for_each_entry(mdb_entry, &vxlan->mdb_list, mdb_node) {
253 		if (entry_idx < s_entry_idx)
254 			goto skip;
255 
256 		err = vxlan_mdb_entry_fill(vxlan, skb, ctx, mdb_entry);
257 		if (err)
258 			break;
259 skip:
260 		entry_idx++;
261 	}
262 
263 	ctx->entry_idx = err ? entry_idx : 0;
264 	nla_nest_end(skb, nest);
265 	return err;
266 }
267 
268 int vxlan_mdb_dump(struct net_device *dev, struct sk_buff *skb,
269 		   struct netlink_callback *cb)
270 {
271 	struct vxlan_mdb_dump_ctx *ctx = (void *)cb->ctx;
272 	struct vxlan_dev *vxlan = netdev_priv(dev);
273 	struct br_port_msg *bpm;
274 	struct nlmsghdr *nlh;
275 	int err;
276 
277 	ASSERT_RTNL();
278 
279 	NL_ASSERT_DUMP_CTX_FITS(struct vxlan_mdb_dump_ctx);
280 
281 	nlh = nlmsg_put(skb, NETLINK_CB(cb->skb).portid,
282 			cb->nlh->nlmsg_seq, RTM_NEWMDB, sizeof(*bpm),
283 			NLM_F_MULTI);
284 	if (!nlh)
285 		return -EMSGSIZE;
286 
287 	bpm = nlmsg_data(nlh);
288 	memset(bpm, 0, sizeof(*bpm));
289 	bpm->family = AF_BRIDGE;
290 	bpm->ifindex = dev->ifindex;
291 
292 	err = vxlan_mdb_fill(vxlan, skb, ctx);
293 
294 	nlmsg_end(skb, nlh);
295 
296 	cb->seq = vxlan->mdb_seq;
297 	nl_dump_check_consistent(cb, nlh);
298 
299 	return err;
300 }
301 
302 static const struct nla_policy
303 vxlan_mdbe_src_list_entry_pol[MDBE_SRCATTR_MAX + 1] = {
304 	[MDBE_SRCATTR_ADDRESS] = NLA_POLICY_RANGE(NLA_BINARY,
305 						  sizeof(struct in_addr),
306 						  sizeof(struct in6_addr)),
307 };
308 
309 static const struct nla_policy
310 vxlan_mdbe_src_list_pol[MDBE_SRC_LIST_MAX + 1] = {
311 	[MDBE_SRC_LIST_ENTRY] = NLA_POLICY_NESTED(vxlan_mdbe_src_list_entry_pol),
312 };
313 
314 static struct netlink_range_validation vni_range = {
315 	.max = VXLAN_N_VID - 1,
316 };
317 
318 static const struct nla_policy vxlan_mdbe_attrs_pol[MDBE_ATTR_MAX + 1] = {
319 	[MDBE_ATTR_SOURCE] = NLA_POLICY_RANGE(NLA_BINARY,
320 					      sizeof(struct in_addr),
321 					      sizeof(struct in6_addr)),
322 	[MDBE_ATTR_GROUP_MODE] = NLA_POLICY_RANGE(NLA_U8, MCAST_EXCLUDE,
323 						  MCAST_INCLUDE),
324 	[MDBE_ATTR_SRC_LIST] = NLA_POLICY_NESTED(vxlan_mdbe_src_list_pol),
325 	[MDBE_ATTR_RTPROT] = NLA_POLICY_MIN(NLA_U8, RTPROT_STATIC),
326 	[MDBE_ATTR_DST] = NLA_POLICY_RANGE(NLA_BINARY,
327 					   sizeof(struct in_addr),
328 					   sizeof(struct in6_addr)),
329 	[MDBE_ATTR_DST_PORT] = { .type = NLA_U16 },
330 	[MDBE_ATTR_VNI] = NLA_POLICY_FULL_RANGE(NLA_U32, &vni_range),
331 	[MDBE_ATTR_IFINDEX] = NLA_POLICY_MIN(NLA_S32, 1),
332 	[MDBE_ATTR_SRC_VNI] = NLA_POLICY_FULL_RANGE(NLA_U32, &vni_range),
333 };
334 
335 static bool vxlan_mdb_is_valid_source(const struct nlattr *attr, __be16 proto,
336 				      struct netlink_ext_ack *extack)
337 {
338 	switch (proto) {
339 	case htons(ETH_P_IP):
340 		if (nla_len(attr) != sizeof(struct in_addr)) {
341 			NL_SET_ERR_MSG_MOD(extack, "IPv4 invalid source address length");
342 			return false;
343 		}
344 		if (ipv4_is_multicast(nla_get_in_addr(attr))) {
345 			NL_SET_ERR_MSG_MOD(extack, "IPv4 multicast source address is not allowed");
346 			return false;
347 		}
348 		break;
349 #if IS_ENABLED(CONFIG_IPV6)
350 	case htons(ETH_P_IPV6): {
351 		struct in6_addr src;
352 
353 		if (nla_len(attr) != sizeof(struct in6_addr)) {
354 			NL_SET_ERR_MSG_MOD(extack, "IPv6 invalid source address length");
355 			return false;
356 		}
357 		src = nla_get_in6_addr(attr);
358 		if (ipv6_addr_is_multicast(&src)) {
359 			NL_SET_ERR_MSG_MOD(extack, "IPv6 multicast source address is not allowed");
360 			return false;
361 		}
362 		break;
363 	}
364 #endif
365 	default:
366 		NL_SET_ERR_MSG_MOD(extack, "Invalid protocol used with source address");
367 		return false;
368 	}
369 
370 	return true;
371 }
372 
373 static void vxlan_mdb_config_group_set(struct vxlan_mdb_config *cfg,
374 				       const struct br_mdb_entry *entry,
375 				       const struct nlattr *source_attr)
376 {
377 	struct vxlan_mdb_entry_key *group = &cfg->group;
378 
379 	switch (entry->addr.proto) {
380 	case htons(ETH_P_IP):
381 		group->dst.sa.sa_family = AF_INET;
382 		group->dst.sin.sin_addr.s_addr = entry->addr.u.ip4;
383 		break;
384 #if IS_ENABLED(CONFIG_IPV6)
385 	case htons(ETH_P_IPV6):
386 		group->dst.sa.sa_family = AF_INET6;
387 		group->dst.sin6.sin6_addr = entry->addr.u.ip6;
388 		break;
389 #endif
390 	}
391 
392 	if (source_attr)
393 		vxlan_nla_get_addr(&group->src, source_attr);
394 }
395 
396 static bool vxlan_mdb_is_star_g(const struct vxlan_mdb_entry_key *group)
397 {
398 	return !vxlan_addr_any(&group->dst) && vxlan_addr_any(&group->src);
399 }
400 
401 static bool vxlan_mdb_is_sg(const struct vxlan_mdb_entry_key *group)
402 {
403 	return !vxlan_addr_any(&group->dst) && !vxlan_addr_any(&group->src);
404 }
405 
406 static int vxlan_mdb_config_src_entry_init(struct vxlan_mdb_config *cfg,
407 					   __be16 proto,
408 					   const struct nlattr *src_entry,
409 					   struct netlink_ext_ack *extack)
410 {
411 	struct nlattr *tb[MDBE_SRCATTR_MAX + 1];
412 	struct vxlan_mdb_config_src_entry *src;
413 	int err;
414 
415 	err = nla_parse_nested(tb, MDBE_SRCATTR_MAX, src_entry,
416 			       vxlan_mdbe_src_list_entry_pol, extack);
417 	if (err)
418 		return err;
419 
420 	if (NL_REQ_ATTR_CHECK(extack, src_entry, tb, MDBE_SRCATTR_ADDRESS))
421 		return -EINVAL;
422 
423 	if (!vxlan_mdb_is_valid_source(tb[MDBE_SRCATTR_ADDRESS], proto,
424 				       extack))
425 		return -EINVAL;
426 
427 	src = kzalloc(sizeof(*src), GFP_KERNEL);
428 	if (!src)
429 		return -ENOMEM;
430 
431 	err = vxlan_nla_get_addr(&src->addr, tb[MDBE_SRCATTR_ADDRESS]);
432 	if (err)
433 		goto err_free_src;
434 
435 	list_add_tail(&src->node, &cfg->src_list);
436 
437 	return 0;
438 
439 err_free_src:
440 	kfree(src);
441 	return err;
442 }
443 
444 static void
445 vxlan_mdb_config_src_entry_fini(struct vxlan_mdb_config_src_entry *src)
446 {
447 	list_del(&src->node);
448 	kfree(src);
449 }
450 
451 static int vxlan_mdb_config_src_list_init(struct vxlan_mdb_config *cfg,
452 					  __be16 proto,
453 					  const struct nlattr *src_list,
454 					  struct netlink_ext_ack *extack)
455 {
456 	struct vxlan_mdb_config_src_entry *src, *tmp;
457 	struct nlattr *src_entry;
458 	int rem, err;
459 
460 	nla_for_each_nested(src_entry, src_list, rem) {
461 		err = vxlan_mdb_config_src_entry_init(cfg, proto, src_entry,
462 						      extack);
463 		if (err)
464 			goto err_src_entry_init;
465 	}
466 
467 	return 0;
468 
469 err_src_entry_init:
470 	list_for_each_entry_safe_reverse(src, tmp, &cfg->src_list, node)
471 		vxlan_mdb_config_src_entry_fini(src);
472 	return err;
473 }
474 
475 static void vxlan_mdb_config_src_list_fini(struct vxlan_mdb_config *cfg)
476 {
477 	struct vxlan_mdb_config_src_entry *src, *tmp;
478 
479 	list_for_each_entry_safe_reverse(src, tmp, &cfg->src_list, node)
480 		vxlan_mdb_config_src_entry_fini(src);
481 }
482 
483 static int vxlan_mdb_config_attrs_init(struct vxlan_mdb_config *cfg,
484 				       const struct br_mdb_entry *entry,
485 				       const struct nlattr *set_attrs,
486 				       struct netlink_ext_ack *extack)
487 {
488 	struct nlattr *mdbe_attrs[MDBE_ATTR_MAX + 1];
489 	int err;
490 
491 	err = nla_parse_nested(mdbe_attrs, MDBE_ATTR_MAX, set_attrs,
492 			       vxlan_mdbe_attrs_pol, extack);
493 	if (err)
494 		return err;
495 
496 	if (NL_REQ_ATTR_CHECK(extack, set_attrs, mdbe_attrs, MDBE_ATTR_DST)) {
497 		NL_SET_ERR_MSG_MOD(extack, "Missing remote destination IP address");
498 		return -EINVAL;
499 	}
500 
501 	if (mdbe_attrs[MDBE_ATTR_SOURCE] &&
502 	    !vxlan_mdb_is_valid_source(mdbe_attrs[MDBE_ATTR_SOURCE],
503 				       entry->addr.proto, extack))
504 		return -EINVAL;
505 
506 	vxlan_mdb_config_group_set(cfg, entry, mdbe_attrs[MDBE_ATTR_SOURCE]);
507 
508 	/* rtnetlink code only validates that IPv4 group address is
509 	 * multicast.
510 	 */
511 	if (!vxlan_addr_is_multicast(&cfg->group.dst) &&
512 	    !vxlan_addr_any(&cfg->group.dst)) {
513 		NL_SET_ERR_MSG_MOD(extack, "Group address is not multicast");
514 		return -EINVAL;
515 	}
516 
517 	if (vxlan_addr_any(&cfg->group.dst) &&
518 	    mdbe_attrs[MDBE_ATTR_SOURCE]) {
519 		NL_SET_ERR_MSG_MOD(extack, "Source cannot be specified for the all-zeros entry");
520 		return -EINVAL;
521 	}
522 
523 	if (vxlan_mdb_is_sg(&cfg->group))
524 		cfg->filter_mode = MCAST_INCLUDE;
525 
526 	if (mdbe_attrs[MDBE_ATTR_GROUP_MODE]) {
527 		if (!vxlan_mdb_is_star_g(&cfg->group)) {
528 			NL_SET_ERR_MSG_MOD(extack, "Filter mode can only be set for (*, G) entries");
529 			return -EINVAL;
530 		}
531 		cfg->filter_mode = nla_get_u8(mdbe_attrs[MDBE_ATTR_GROUP_MODE]);
532 	}
533 
534 	if (mdbe_attrs[MDBE_ATTR_SRC_LIST]) {
535 		if (!vxlan_mdb_is_star_g(&cfg->group)) {
536 			NL_SET_ERR_MSG_MOD(extack, "Source list can only be set for (*, G) entries");
537 			return -EINVAL;
538 		}
539 		if (!mdbe_attrs[MDBE_ATTR_GROUP_MODE]) {
540 			NL_SET_ERR_MSG_MOD(extack, "Source list cannot be set without filter mode");
541 			return -EINVAL;
542 		}
543 		err = vxlan_mdb_config_src_list_init(cfg, entry->addr.proto,
544 						     mdbe_attrs[MDBE_ATTR_SRC_LIST],
545 						     extack);
546 		if (err)
547 			return err;
548 	}
549 
550 	if (vxlan_mdb_is_star_g(&cfg->group) && list_empty(&cfg->src_list) &&
551 	    cfg->filter_mode == MCAST_INCLUDE) {
552 		NL_SET_ERR_MSG_MOD(extack, "Cannot add (*, G) INCLUDE with an empty source list");
553 		return -EINVAL;
554 	}
555 
556 	if (mdbe_attrs[MDBE_ATTR_RTPROT])
557 		cfg->rt_protocol = nla_get_u8(mdbe_attrs[MDBE_ATTR_RTPROT]);
558 
559 	err = vxlan_nla_get_addr(&cfg->remote_ip, mdbe_attrs[MDBE_ATTR_DST]);
560 	if (err) {
561 		NL_SET_ERR_MSG_MOD(extack, "Invalid remote destination address");
562 		goto err_src_list_fini;
563 	}
564 
565 	if (mdbe_attrs[MDBE_ATTR_DST_PORT])
566 		cfg->remote_port =
567 			cpu_to_be16(nla_get_u16(mdbe_attrs[MDBE_ATTR_DST_PORT]));
568 
569 	if (mdbe_attrs[MDBE_ATTR_VNI])
570 		cfg->remote_vni =
571 			cpu_to_be32(nla_get_u32(mdbe_attrs[MDBE_ATTR_VNI]));
572 
573 	if (mdbe_attrs[MDBE_ATTR_IFINDEX]) {
574 		cfg->remote_ifindex =
575 			nla_get_s32(mdbe_attrs[MDBE_ATTR_IFINDEX]);
576 		if (!__dev_get_by_index(cfg->vxlan->net, cfg->remote_ifindex)) {
577 			NL_SET_ERR_MSG_MOD(extack, "Outgoing interface not found");
578 			err = -EINVAL;
579 			goto err_src_list_fini;
580 		}
581 	}
582 
583 	if (mdbe_attrs[MDBE_ATTR_SRC_VNI])
584 		cfg->group.vni =
585 			cpu_to_be32(nla_get_u32(mdbe_attrs[MDBE_ATTR_SRC_VNI]));
586 
587 	return 0;
588 
589 err_src_list_fini:
590 	vxlan_mdb_config_src_list_fini(cfg);
591 	return err;
592 }
593 
594 static int vxlan_mdb_config_init(struct vxlan_mdb_config *cfg,
595 				 struct net_device *dev, struct nlattr *tb[],
596 				 u16 nlmsg_flags,
597 				 struct netlink_ext_ack *extack)
598 {
599 	struct br_mdb_entry *entry = nla_data(tb[MDBA_SET_ENTRY]);
600 	struct vxlan_dev *vxlan = netdev_priv(dev);
601 
602 	memset(cfg, 0, sizeof(*cfg));
603 	cfg->vxlan = vxlan;
604 	cfg->group.vni = vxlan->default_dst.remote_vni;
605 	INIT_LIST_HEAD(&cfg->src_list);
606 	cfg->nlflags = nlmsg_flags;
607 	cfg->filter_mode = MCAST_EXCLUDE;
608 	cfg->rt_protocol = RTPROT_STATIC;
609 	cfg->remote_vni = vxlan->default_dst.remote_vni;
610 	cfg->remote_port = vxlan->cfg.dst_port;
611 
612 	if (entry->ifindex != dev->ifindex) {
613 		NL_SET_ERR_MSG_MOD(extack, "Port net device must be the VXLAN net device");
614 		return -EINVAL;
615 	}
616 
617 	/* State is not part of the entry key and can be ignored on deletion
618 	 * requests.
619 	 */
620 	if ((nlmsg_flags & (NLM_F_CREATE | NLM_F_REPLACE)) &&
621 	    entry->state != MDB_PERMANENT) {
622 		NL_SET_ERR_MSG_MOD(extack, "MDB entry must be permanent");
623 		return -EINVAL;
624 	}
625 
626 	if (entry->flags) {
627 		NL_SET_ERR_MSG_MOD(extack, "Invalid MDB entry flags");
628 		return -EINVAL;
629 	}
630 
631 	if (entry->vid) {
632 		NL_SET_ERR_MSG_MOD(extack, "VID must not be specified");
633 		return -EINVAL;
634 	}
635 
636 	if (entry->addr.proto != htons(ETH_P_IP) &&
637 	    entry->addr.proto != htons(ETH_P_IPV6)) {
638 		NL_SET_ERR_MSG_MOD(extack, "Group address must be an IPv4 / IPv6 address");
639 		return -EINVAL;
640 	}
641 
642 	if (NL_REQ_ATTR_CHECK(extack, NULL, tb, MDBA_SET_ENTRY_ATTRS)) {
643 		NL_SET_ERR_MSG_MOD(extack, "Missing MDBA_SET_ENTRY_ATTRS attribute");
644 		return -EINVAL;
645 	}
646 
647 	return vxlan_mdb_config_attrs_init(cfg, entry, tb[MDBA_SET_ENTRY_ATTRS],
648 					   extack);
649 }
650 
651 static void vxlan_mdb_config_fini(struct vxlan_mdb_config *cfg)
652 {
653 	vxlan_mdb_config_src_list_fini(cfg);
654 }
655 
656 static struct vxlan_mdb_entry *
657 vxlan_mdb_entry_lookup(struct vxlan_dev *vxlan,
658 		       const struct vxlan_mdb_entry_key *group)
659 {
660 	return rhashtable_lookup_fast(&vxlan->mdb_tbl, group,
661 				      vxlan_mdb_rht_params);
662 }
663 
664 static struct vxlan_mdb_remote *
665 vxlan_mdb_remote_lookup(const struct vxlan_mdb_entry *mdb_entry,
666 			const union vxlan_addr *addr)
667 {
668 	struct vxlan_mdb_remote *remote;
669 
670 	list_for_each_entry(remote, &mdb_entry->remotes, list) {
671 		struct vxlan_rdst *rd = rtnl_dereference(remote->rd);
672 
673 		if (vxlan_addr_equal(addr, &rd->remote_ip))
674 			return remote;
675 	}
676 
677 	return NULL;
678 }
679 
680 static void vxlan_mdb_rdst_free(struct rcu_head *head)
681 {
682 	struct vxlan_rdst *rd = container_of(head, struct vxlan_rdst, rcu);
683 
684 	dst_cache_destroy(&rd->dst_cache);
685 	kfree(rd);
686 }
687 
688 static int vxlan_mdb_remote_rdst_init(const struct vxlan_mdb_config *cfg,
689 				      struct vxlan_mdb_remote *remote)
690 {
691 	struct vxlan_rdst *rd;
692 	int err;
693 
694 	rd = kzalloc(sizeof(*rd), GFP_KERNEL);
695 	if (!rd)
696 		return -ENOMEM;
697 
698 	err = dst_cache_init(&rd->dst_cache, GFP_KERNEL);
699 	if (err)
700 		goto err_free_rdst;
701 
702 	rd->remote_ip = cfg->remote_ip;
703 	rd->remote_port = cfg->remote_port;
704 	rd->remote_vni = cfg->remote_vni;
705 	rd->remote_ifindex = cfg->remote_ifindex;
706 	rcu_assign_pointer(remote->rd, rd);
707 
708 	return 0;
709 
710 err_free_rdst:
711 	kfree(rd);
712 	return err;
713 }
714 
715 static void vxlan_mdb_remote_rdst_fini(struct vxlan_rdst *rd)
716 {
717 	call_rcu(&rd->rcu, vxlan_mdb_rdst_free);
718 }
719 
720 static int vxlan_mdb_remote_init(const struct vxlan_mdb_config *cfg,
721 				 struct vxlan_mdb_remote *remote)
722 {
723 	int err;
724 
725 	err = vxlan_mdb_remote_rdst_init(cfg, remote);
726 	if (err)
727 		return err;
728 
729 	remote->flags = cfg->flags;
730 	remote->filter_mode = cfg->filter_mode;
731 	remote->rt_protocol = cfg->rt_protocol;
732 	INIT_HLIST_HEAD(&remote->src_list);
733 
734 	return 0;
735 }
736 
737 static void vxlan_mdb_remote_fini(struct vxlan_dev *vxlan,
738 				  struct vxlan_mdb_remote *remote)
739 {
740 	WARN_ON_ONCE(!hlist_empty(&remote->src_list));
741 	vxlan_mdb_remote_rdst_fini(rtnl_dereference(remote->rd));
742 }
743 
744 static struct vxlan_mdb_src_entry *
745 vxlan_mdb_remote_src_entry_lookup(const struct vxlan_mdb_remote *remote,
746 				  const union vxlan_addr *addr)
747 {
748 	struct vxlan_mdb_src_entry *ent;
749 
750 	hlist_for_each_entry(ent, &remote->src_list, node) {
751 		if (vxlan_addr_equal(&ent->addr, addr))
752 			return ent;
753 	}
754 
755 	return NULL;
756 }
757 
758 static struct vxlan_mdb_src_entry *
759 vxlan_mdb_remote_src_entry_add(struct vxlan_mdb_remote *remote,
760 			       const union vxlan_addr *addr)
761 {
762 	struct vxlan_mdb_src_entry *ent;
763 
764 	ent = kzalloc(sizeof(*ent), GFP_KERNEL);
765 	if (!ent)
766 		return NULL;
767 
768 	ent->addr = *addr;
769 	hlist_add_head(&ent->node, &remote->src_list);
770 
771 	return ent;
772 }
773 
774 static void
775 vxlan_mdb_remote_src_entry_del(struct vxlan_mdb_src_entry *ent)
776 {
777 	hlist_del(&ent->node);
778 	kfree(ent);
779 }
780 
781 static int
782 vxlan_mdb_remote_src_fwd_add(const struct vxlan_mdb_config *cfg,
783 			     const union vxlan_addr *addr,
784 			     struct netlink_ext_ack *extack)
785 {
786 	struct vxlan_mdb_config sg_cfg;
787 
788 	memset(&sg_cfg, 0, sizeof(sg_cfg));
789 	sg_cfg.vxlan = cfg->vxlan;
790 	sg_cfg.group.src = *addr;
791 	sg_cfg.group.dst = cfg->group.dst;
792 	sg_cfg.group.vni = cfg->group.vni;
793 	INIT_LIST_HEAD(&sg_cfg.src_list);
794 	sg_cfg.remote_ip = cfg->remote_ip;
795 	sg_cfg.remote_ifindex = cfg->remote_ifindex;
796 	sg_cfg.remote_vni = cfg->remote_vni;
797 	sg_cfg.remote_port = cfg->remote_port;
798 	sg_cfg.nlflags = cfg->nlflags;
799 	sg_cfg.filter_mode = MCAST_INCLUDE;
800 	if (cfg->filter_mode == MCAST_EXCLUDE)
801 		sg_cfg.flags = VXLAN_MDB_REMOTE_F_BLOCKED;
802 	sg_cfg.rt_protocol = cfg->rt_protocol;
803 
804 	return __vxlan_mdb_add(&sg_cfg, extack);
805 }
806 
807 static void
808 vxlan_mdb_remote_src_fwd_del(struct vxlan_dev *vxlan,
809 			     const struct vxlan_mdb_entry_key *group,
810 			     const struct vxlan_mdb_remote *remote,
811 			     const union vxlan_addr *addr)
812 {
813 	struct vxlan_rdst *rd = rtnl_dereference(remote->rd);
814 	struct vxlan_mdb_config sg_cfg;
815 
816 	memset(&sg_cfg, 0, sizeof(sg_cfg));
817 	sg_cfg.vxlan = vxlan;
818 	sg_cfg.group.src = *addr;
819 	sg_cfg.group.dst = group->dst;
820 	sg_cfg.group.vni = group->vni;
821 	INIT_LIST_HEAD(&sg_cfg.src_list);
822 	sg_cfg.remote_ip = rd->remote_ip;
823 
824 	__vxlan_mdb_del(&sg_cfg, NULL);
825 }
826 
827 static int
828 vxlan_mdb_remote_src_add(const struct vxlan_mdb_config *cfg,
829 			 struct vxlan_mdb_remote *remote,
830 			 const struct vxlan_mdb_config_src_entry *src,
831 			 struct netlink_ext_ack *extack)
832 {
833 	struct vxlan_mdb_src_entry *ent;
834 	int err;
835 
836 	ent = vxlan_mdb_remote_src_entry_lookup(remote, &src->addr);
837 	if (!ent) {
838 		ent = vxlan_mdb_remote_src_entry_add(remote, &src->addr);
839 		if (!ent)
840 			return -ENOMEM;
841 	} else if (!(cfg->nlflags & NLM_F_REPLACE)) {
842 		NL_SET_ERR_MSG_MOD(extack, "Source entry already exists");
843 		return -EEXIST;
844 	}
845 
846 	err = vxlan_mdb_remote_src_fwd_add(cfg, &ent->addr, extack);
847 	if (err)
848 		goto err_src_del;
849 
850 	/* Clear flags in case source entry was marked for deletion as part of
851 	 * replace flow.
852 	 */
853 	ent->flags = 0;
854 
855 	return 0;
856 
857 err_src_del:
858 	vxlan_mdb_remote_src_entry_del(ent);
859 	return err;
860 }
861 
862 static void vxlan_mdb_remote_src_del(struct vxlan_dev *vxlan,
863 				     const struct vxlan_mdb_entry_key *group,
864 				     const struct vxlan_mdb_remote *remote,
865 				     struct vxlan_mdb_src_entry *ent)
866 {
867 	vxlan_mdb_remote_src_fwd_del(vxlan, group, remote, &ent->addr);
868 	vxlan_mdb_remote_src_entry_del(ent);
869 }
870 
871 static int vxlan_mdb_remote_srcs_add(const struct vxlan_mdb_config *cfg,
872 				     struct vxlan_mdb_remote *remote,
873 				     struct netlink_ext_ack *extack)
874 {
875 	struct vxlan_mdb_config_src_entry *src;
876 	struct vxlan_mdb_src_entry *ent;
877 	struct hlist_node *tmp;
878 	int err;
879 
880 	list_for_each_entry(src, &cfg->src_list, node) {
881 		err = vxlan_mdb_remote_src_add(cfg, remote, src, extack);
882 		if (err)
883 			goto err_src_del;
884 	}
885 
886 	return 0;
887 
888 err_src_del:
889 	hlist_for_each_entry_safe(ent, tmp, &remote->src_list, node)
890 		vxlan_mdb_remote_src_del(cfg->vxlan, &cfg->group, remote, ent);
891 	return err;
892 }
893 
894 static void vxlan_mdb_remote_srcs_del(struct vxlan_dev *vxlan,
895 				      const struct vxlan_mdb_entry_key *group,
896 				      struct vxlan_mdb_remote *remote)
897 {
898 	struct vxlan_mdb_src_entry *ent;
899 	struct hlist_node *tmp;
900 
901 	hlist_for_each_entry_safe(ent, tmp, &remote->src_list, node)
902 		vxlan_mdb_remote_src_del(vxlan, group, remote, ent);
903 }
904 
905 static size_t
906 vxlan_mdb_nlmsg_src_list_size(const struct vxlan_mdb_entry_key *group,
907 			      const struct vxlan_mdb_remote *remote)
908 {
909 	struct vxlan_mdb_src_entry *ent;
910 	size_t nlmsg_size;
911 
912 	if (hlist_empty(&remote->src_list))
913 		return 0;
914 
915 	/* MDBA_MDB_EATTR_SRC_LIST */
916 	nlmsg_size = nla_total_size(0);
917 
918 	hlist_for_each_entry(ent, &remote->src_list, node) {
919 			      /* MDBA_MDB_SRCLIST_ENTRY */
920 		nlmsg_size += nla_total_size(0) +
921 			      /* MDBA_MDB_SRCATTR_ADDRESS */
922 			      nla_total_size(vxlan_addr_size(&group->dst)) +
923 			      /* MDBA_MDB_SRCATTR_TIMER */
924 			      nla_total_size(sizeof(u8));
925 	}
926 
927 	return nlmsg_size;
928 }
929 
930 static size_t vxlan_mdb_nlmsg_size(const struct vxlan_dev *vxlan,
931 				   const struct vxlan_mdb_entry *mdb_entry,
932 				   const struct vxlan_mdb_remote *remote)
933 {
934 	const struct vxlan_mdb_entry_key *group = &mdb_entry->key;
935 	struct vxlan_rdst *rd = rtnl_dereference(remote->rd);
936 	size_t nlmsg_size;
937 
938 	nlmsg_size = NLMSG_ALIGN(sizeof(struct br_port_msg)) +
939 		     /* MDBA_MDB */
940 		     nla_total_size(0) +
941 		     /* MDBA_MDB_ENTRY */
942 		     nla_total_size(0) +
943 		     /* MDBA_MDB_ENTRY_INFO */
944 		     nla_total_size(sizeof(struct br_mdb_entry)) +
945 		     /* MDBA_MDB_EATTR_TIMER */
946 		     nla_total_size(sizeof(u32));
947 	/* MDBA_MDB_EATTR_SOURCE */
948 	if (vxlan_mdb_is_sg(group))
949 		nlmsg_size += nla_total_size(vxlan_addr_size(&group->dst));
950 	/* MDBA_MDB_EATTR_RTPROT */
951 	nlmsg_size += nla_total_size(sizeof(u8));
952 	/* MDBA_MDB_EATTR_SRC_LIST */
953 	nlmsg_size += vxlan_mdb_nlmsg_src_list_size(group, remote);
954 	/* MDBA_MDB_EATTR_GROUP_MODE */
955 	nlmsg_size += nla_total_size(sizeof(u8));
956 	/* MDBA_MDB_EATTR_DST */
957 	nlmsg_size += nla_total_size(vxlan_addr_size(&rd->remote_ip));
958 	/* MDBA_MDB_EATTR_DST_PORT */
959 	if (rd->remote_port && rd->remote_port != vxlan->cfg.dst_port)
960 		nlmsg_size += nla_total_size(sizeof(u16));
961 	/* MDBA_MDB_EATTR_VNI */
962 	if (rd->remote_vni != vxlan->default_dst.remote_vni)
963 		nlmsg_size += nla_total_size(sizeof(u32));
964 	/* MDBA_MDB_EATTR_IFINDEX */
965 	if (rd->remote_ifindex)
966 		nlmsg_size += nla_total_size(sizeof(u32));
967 	/* MDBA_MDB_EATTR_SRC_VNI */
968 	if ((vxlan->cfg.flags & VXLAN_F_COLLECT_METADATA) && group->vni)
969 		nlmsg_size += nla_total_size(sizeof(u32));
970 
971 	return nlmsg_size;
972 }
973 
974 static int vxlan_mdb_nlmsg_fill(const struct vxlan_dev *vxlan,
975 				struct sk_buff *skb,
976 				const struct vxlan_mdb_entry *mdb_entry,
977 				const struct vxlan_mdb_remote *remote,
978 				int type)
979 {
980 	struct nlattr *mdb_nest, *mdb_entry_nest;
981 	struct br_port_msg *bpm;
982 	struct nlmsghdr *nlh;
983 
984 	nlh = nlmsg_put(skb, 0, 0, type, sizeof(*bpm), 0);
985 	if (!nlh)
986 		return -EMSGSIZE;
987 
988 	bpm = nlmsg_data(nlh);
989 	memset(bpm, 0, sizeof(*bpm));
990 	bpm->family  = AF_BRIDGE;
991 	bpm->ifindex = vxlan->dev->ifindex;
992 
993 	mdb_nest = nla_nest_start_noflag(skb, MDBA_MDB);
994 	if (!mdb_nest)
995 		goto cancel;
996 	mdb_entry_nest = nla_nest_start_noflag(skb, MDBA_MDB_ENTRY);
997 	if (!mdb_entry_nest)
998 		goto cancel;
999 
1000 	if (vxlan_mdb_entry_info_fill(vxlan, skb, mdb_entry, remote))
1001 		goto cancel;
1002 
1003 	nla_nest_end(skb, mdb_entry_nest);
1004 	nla_nest_end(skb, mdb_nest);
1005 	nlmsg_end(skb, nlh);
1006 
1007 	return 0;
1008 
1009 cancel:
1010 	nlmsg_cancel(skb, nlh);
1011 	return -EMSGSIZE;
1012 }
1013 
1014 static void vxlan_mdb_remote_notify(const struct vxlan_dev *vxlan,
1015 				    const struct vxlan_mdb_entry *mdb_entry,
1016 				    const struct vxlan_mdb_remote *remote,
1017 				    int type)
1018 {
1019 	struct net *net = dev_net(vxlan->dev);
1020 	struct sk_buff *skb;
1021 	int err = -ENOBUFS;
1022 
1023 	skb = nlmsg_new(vxlan_mdb_nlmsg_size(vxlan, mdb_entry, remote),
1024 			GFP_KERNEL);
1025 	if (!skb)
1026 		goto errout;
1027 
1028 	err = vxlan_mdb_nlmsg_fill(vxlan, skb, mdb_entry, remote, type);
1029 	if (err) {
1030 		kfree_skb(skb);
1031 		goto errout;
1032 	}
1033 
1034 	rtnl_notify(skb, net, 0, RTNLGRP_MDB, NULL, GFP_KERNEL);
1035 	return;
1036 errout:
1037 	rtnl_set_sk_err(net, RTNLGRP_MDB, err);
1038 }
1039 
1040 static int
1041 vxlan_mdb_remote_srcs_replace(const struct vxlan_mdb_config *cfg,
1042 			      const struct vxlan_mdb_entry *mdb_entry,
1043 			      struct vxlan_mdb_remote *remote,
1044 			      struct netlink_ext_ack *extack)
1045 {
1046 	struct vxlan_dev *vxlan = cfg->vxlan;
1047 	struct vxlan_mdb_src_entry *ent;
1048 	struct hlist_node *tmp;
1049 	int err;
1050 
1051 	hlist_for_each_entry(ent, &remote->src_list, node)
1052 		ent->flags |= VXLAN_SGRP_F_DELETE;
1053 
1054 	err = vxlan_mdb_remote_srcs_add(cfg, remote, extack);
1055 	if (err)
1056 		goto err_clear_delete;
1057 
1058 	hlist_for_each_entry_safe(ent, tmp, &remote->src_list, node) {
1059 		if (ent->flags & VXLAN_SGRP_F_DELETE)
1060 			vxlan_mdb_remote_src_del(vxlan, &mdb_entry->key, remote,
1061 						 ent);
1062 	}
1063 
1064 	return 0;
1065 
1066 err_clear_delete:
1067 	hlist_for_each_entry(ent, &remote->src_list, node)
1068 		ent->flags &= ~VXLAN_SGRP_F_DELETE;
1069 	return err;
1070 }
1071 
1072 static int vxlan_mdb_remote_replace(const struct vxlan_mdb_config *cfg,
1073 				    const struct vxlan_mdb_entry *mdb_entry,
1074 				    struct vxlan_mdb_remote *remote,
1075 				    struct netlink_ext_ack *extack)
1076 {
1077 	struct vxlan_rdst *new_rd, *old_rd = rtnl_dereference(remote->rd);
1078 	struct vxlan_dev *vxlan = cfg->vxlan;
1079 	int err;
1080 
1081 	err = vxlan_mdb_remote_rdst_init(cfg, remote);
1082 	if (err)
1083 		return err;
1084 	new_rd = rtnl_dereference(remote->rd);
1085 
1086 	err = vxlan_mdb_remote_srcs_replace(cfg, mdb_entry, remote, extack);
1087 	if (err)
1088 		goto err_rdst_reset;
1089 
1090 	WRITE_ONCE(remote->flags, cfg->flags);
1091 	WRITE_ONCE(remote->filter_mode, cfg->filter_mode);
1092 	remote->rt_protocol = cfg->rt_protocol;
1093 	vxlan_mdb_remote_notify(vxlan, mdb_entry, remote, RTM_NEWMDB);
1094 
1095 	vxlan_mdb_remote_rdst_fini(old_rd);
1096 
1097 	return 0;
1098 
1099 err_rdst_reset:
1100 	rcu_assign_pointer(remote->rd, old_rd);
1101 	vxlan_mdb_remote_rdst_fini(new_rd);
1102 	return err;
1103 }
1104 
1105 static int vxlan_mdb_remote_add(const struct vxlan_mdb_config *cfg,
1106 				struct vxlan_mdb_entry *mdb_entry,
1107 				struct netlink_ext_ack *extack)
1108 {
1109 	struct vxlan_mdb_remote *remote;
1110 	int err;
1111 
1112 	remote = vxlan_mdb_remote_lookup(mdb_entry, &cfg->remote_ip);
1113 	if (remote) {
1114 		if (!(cfg->nlflags & NLM_F_REPLACE)) {
1115 			NL_SET_ERR_MSG_MOD(extack, "Replace not specified and MDB remote entry already exists");
1116 			return -EEXIST;
1117 		}
1118 		return vxlan_mdb_remote_replace(cfg, mdb_entry, remote, extack);
1119 	}
1120 
1121 	if (!(cfg->nlflags & NLM_F_CREATE)) {
1122 		NL_SET_ERR_MSG_MOD(extack, "Create not specified and entry does not exist");
1123 		return -ENOENT;
1124 	}
1125 
1126 	remote = kzalloc(sizeof(*remote), GFP_KERNEL);
1127 	if (!remote)
1128 		return -ENOMEM;
1129 
1130 	err = vxlan_mdb_remote_init(cfg, remote);
1131 	if (err) {
1132 		NL_SET_ERR_MSG_MOD(extack, "Failed to initialize remote MDB entry");
1133 		goto err_free_remote;
1134 	}
1135 
1136 	err = vxlan_mdb_remote_srcs_add(cfg, remote, extack);
1137 	if (err)
1138 		goto err_remote_fini;
1139 
1140 	list_add_rcu(&remote->list, &mdb_entry->remotes);
1141 	vxlan_mdb_remote_notify(cfg->vxlan, mdb_entry, remote, RTM_NEWMDB);
1142 
1143 	return 0;
1144 
1145 err_remote_fini:
1146 	vxlan_mdb_remote_fini(cfg->vxlan, remote);
1147 err_free_remote:
1148 	kfree(remote);
1149 	return err;
1150 }
1151 
1152 static void vxlan_mdb_remote_del(struct vxlan_dev *vxlan,
1153 				 struct vxlan_mdb_entry *mdb_entry,
1154 				 struct vxlan_mdb_remote *remote)
1155 {
1156 	vxlan_mdb_remote_notify(vxlan, mdb_entry, remote, RTM_DELMDB);
1157 	list_del_rcu(&remote->list);
1158 	vxlan_mdb_remote_srcs_del(vxlan, &mdb_entry->key, remote);
1159 	vxlan_mdb_remote_fini(vxlan, remote);
1160 	kfree_rcu(remote, rcu);
1161 }
1162 
1163 static struct vxlan_mdb_entry *
1164 vxlan_mdb_entry_get(struct vxlan_dev *vxlan,
1165 		    const struct vxlan_mdb_entry_key *group)
1166 {
1167 	struct vxlan_mdb_entry *mdb_entry;
1168 	int err;
1169 
1170 	mdb_entry = vxlan_mdb_entry_lookup(vxlan, group);
1171 	if (mdb_entry)
1172 		return mdb_entry;
1173 
1174 	mdb_entry = kzalloc(sizeof(*mdb_entry), GFP_KERNEL);
1175 	if (!mdb_entry)
1176 		return ERR_PTR(-ENOMEM);
1177 
1178 	INIT_LIST_HEAD(&mdb_entry->remotes);
1179 	memcpy(&mdb_entry->key, group, sizeof(mdb_entry->key));
1180 	hlist_add_head(&mdb_entry->mdb_node, &vxlan->mdb_list);
1181 
1182 	err = rhashtable_lookup_insert_fast(&vxlan->mdb_tbl,
1183 					    &mdb_entry->rhnode,
1184 					    vxlan_mdb_rht_params);
1185 	if (err)
1186 		goto err_free_entry;
1187 
1188 	if (hlist_is_singular_node(&mdb_entry->mdb_node, &vxlan->mdb_list))
1189 		vxlan->cfg.flags |= VXLAN_F_MDB;
1190 
1191 	return mdb_entry;
1192 
1193 err_free_entry:
1194 	hlist_del(&mdb_entry->mdb_node);
1195 	kfree(mdb_entry);
1196 	return ERR_PTR(err);
1197 }
1198 
1199 static void vxlan_mdb_entry_put(struct vxlan_dev *vxlan,
1200 				struct vxlan_mdb_entry *mdb_entry)
1201 {
1202 	if (!list_empty(&mdb_entry->remotes))
1203 		return;
1204 
1205 	if (hlist_is_singular_node(&mdb_entry->mdb_node, &vxlan->mdb_list))
1206 		vxlan->cfg.flags &= ~VXLAN_F_MDB;
1207 
1208 	rhashtable_remove_fast(&vxlan->mdb_tbl, &mdb_entry->rhnode,
1209 			       vxlan_mdb_rht_params);
1210 	hlist_del(&mdb_entry->mdb_node);
1211 	kfree_rcu(mdb_entry, rcu);
1212 }
1213 
1214 static int __vxlan_mdb_add(const struct vxlan_mdb_config *cfg,
1215 			   struct netlink_ext_ack *extack)
1216 {
1217 	struct vxlan_dev *vxlan = cfg->vxlan;
1218 	struct vxlan_mdb_entry *mdb_entry;
1219 	int err;
1220 
1221 	mdb_entry = vxlan_mdb_entry_get(vxlan, &cfg->group);
1222 	if (IS_ERR(mdb_entry))
1223 		return PTR_ERR(mdb_entry);
1224 
1225 	err = vxlan_mdb_remote_add(cfg, mdb_entry, extack);
1226 	if (err)
1227 		goto err_entry_put;
1228 
1229 	vxlan->mdb_seq++;
1230 
1231 	return 0;
1232 
1233 err_entry_put:
1234 	vxlan_mdb_entry_put(vxlan, mdb_entry);
1235 	return err;
1236 }
1237 
1238 static int __vxlan_mdb_del(const struct vxlan_mdb_config *cfg,
1239 			   struct netlink_ext_ack *extack)
1240 {
1241 	struct vxlan_dev *vxlan = cfg->vxlan;
1242 	struct vxlan_mdb_entry *mdb_entry;
1243 	struct vxlan_mdb_remote *remote;
1244 
1245 	mdb_entry = vxlan_mdb_entry_lookup(vxlan, &cfg->group);
1246 	if (!mdb_entry) {
1247 		NL_SET_ERR_MSG_MOD(extack, "Did not find MDB entry");
1248 		return -ENOENT;
1249 	}
1250 
1251 	remote = vxlan_mdb_remote_lookup(mdb_entry, &cfg->remote_ip);
1252 	if (!remote) {
1253 		NL_SET_ERR_MSG_MOD(extack, "Did not find MDB remote entry");
1254 		return -ENOENT;
1255 	}
1256 
1257 	vxlan_mdb_remote_del(vxlan, mdb_entry, remote);
1258 	vxlan_mdb_entry_put(vxlan, mdb_entry);
1259 
1260 	vxlan->mdb_seq++;
1261 
1262 	return 0;
1263 }
1264 
1265 int vxlan_mdb_add(struct net_device *dev, struct nlattr *tb[], u16 nlmsg_flags,
1266 		  struct netlink_ext_ack *extack)
1267 {
1268 	struct vxlan_mdb_config cfg;
1269 	int err;
1270 
1271 	ASSERT_RTNL();
1272 
1273 	err = vxlan_mdb_config_init(&cfg, dev, tb, nlmsg_flags, extack);
1274 	if (err)
1275 		return err;
1276 
1277 	err = __vxlan_mdb_add(&cfg, extack);
1278 
1279 	vxlan_mdb_config_fini(&cfg);
1280 	return err;
1281 }
1282 
1283 int vxlan_mdb_del(struct net_device *dev, struct nlattr *tb[],
1284 		  struct netlink_ext_ack *extack)
1285 {
1286 	struct vxlan_mdb_config cfg;
1287 	int err;
1288 
1289 	ASSERT_RTNL();
1290 
1291 	err = vxlan_mdb_config_init(&cfg, dev, tb, 0, extack);
1292 	if (err)
1293 		return err;
1294 
1295 	err = __vxlan_mdb_del(&cfg, extack);
1296 
1297 	vxlan_mdb_config_fini(&cfg);
1298 	return err;
1299 }
1300 
1301 struct vxlan_mdb_entry *vxlan_mdb_entry_skb_get(struct vxlan_dev *vxlan,
1302 						struct sk_buff *skb,
1303 						__be32 src_vni)
1304 {
1305 	struct vxlan_mdb_entry *mdb_entry;
1306 	struct vxlan_mdb_entry_key group;
1307 
1308 	if (!is_multicast_ether_addr(eth_hdr(skb)->h_dest) ||
1309 	    is_broadcast_ether_addr(eth_hdr(skb)->h_dest))
1310 		return NULL;
1311 
1312 	/* When not in collect metadata mode, 'src_vni' is zero, but MDB
1313 	 * entries are stored with the VNI of the VXLAN device.
1314 	 */
1315 	if (!(vxlan->cfg.flags & VXLAN_F_COLLECT_METADATA))
1316 		src_vni = vxlan->default_dst.remote_vni;
1317 
1318 	memset(&group, 0, sizeof(group));
1319 	group.vni = src_vni;
1320 
1321 	switch (skb->protocol) {
1322 	case htons(ETH_P_IP):
1323 		if (!pskb_may_pull(skb, sizeof(struct iphdr)))
1324 			return NULL;
1325 		group.dst.sa.sa_family = AF_INET;
1326 		group.dst.sin.sin_addr.s_addr = ip_hdr(skb)->daddr;
1327 		group.src.sa.sa_family = AF_INET;
1328 		group.src.sin.sin_addr.s_addr = ip_hdr(skb)->saddr;
1329 		break;
1330 #if IS_ENABLED(CONFIG_IPV6)
1331 	case htons(ETH_P_IPV6):
1332 		if (!pskb_may_pull(skb, sizeof(struct ipv6hdr)))
1333 			return NULL;
1334 		group.dst.sa.sa_family = AF_INET6;
1335 		group.dst.sin6.sin6_addr = ipv6_hdr(skb)->daddr;
1336 		group.src.sa.sa_family = AF_INET6;
1337 		group.src.sin6.sin6_addr = ipv6_hdr(skb)->saddr;
1338 		break;
1339 #endif
1340 	default:
1341 		return NULL;
1342 	}
1343 
1344 	mdb_entry = vxlan_mdb_entry_lookup(vxlan, &group);
1345 	if (mdb_entry)
1346 		return mdb_entry;
1347 
1348 	memset(&group.src, 0, sizeof(group.src));
1349 	mdb_entry = vxlan_mdb_entry_lookup(vxlan, &group);
1350 	if (mdb_entry)
1351 		return mdb_entry;
1352 
1353 	/* No (S, G) or (*, G) found. Look up the all-zeros entry, but only if
1354 	 * the destination IP address is not link-local multicast since we want
1355 	 * to transmit such traffic together with broadcast and unknown unicast
1356 	 * traffic.
1357 	 */
1358 	switch (skb->protocol) {
1359 	case htons(ETH_P_IP):
1360 		if (ipv4_is_local_multicast(group.dst.sin.sin_addr.s_addr))
1361 			return NULL;
1362 		group.dst.sin.sin_addr.s_addr = 0;
1363 		break;
1364 #if IS_ENABLED(CONFIG_IPV6)
1365 	case htons(ETH_P_IPV6):
1366 		if (ipv6_addr_type(&group.dst.sin6.sin6_addr) &
1367 		    IPV6_ADDR_LINKLOCAL)
1368 			return NULL;
1369 		memset(&group.dst.sin6.sin6_addr, 0,
1370 		       sizeof(group.dst.sin6.sin6_addr));
1371 		break;
1372 #endif
1373 	default:
1374 		return NULL;
1375 	}
1376 
1377 	return vxlan_mdb_entry_lookup(vxlan, &group);
1378 }
1379 
1380 netdev_tx_t vxlan_mdb_xmit(struct vxlan_dev *vxlan,
1381 			   const struct vxlan_mdb_entry *mdb_entry,
1382 			   struct sk_buff *skb)
1383 {
1384 	struct vxlan_mdb_remote *remote, *fremote = NULL;
1385 	__be32 src_vni = mdb_entry->key.vni;
1386 
1387 	list_for_each_entry_rcu(remote, &mdb_entry->remotes, list) {
1388 		struct sk_buff *skb1;
1389 
1390 		if ((vxlan_mdb_is_star_g(&mdb_entry->key) &&
1391 		     READ_ONCE(remote->filter_mode) == MCAST_INCLUDE) ||
1392 		    (READ_ONCE(remote->flags) & VXLAN_MDB_REMOTE_F_BLOCKED))
1393 			continue;
1394 
1395 		if (!fremote) {
1396 			fremote = remote;
1397 			continue;
1398 		}
1399 
1400 		skb1 = skb_clone(skb, GFP_ATOMIC);
1401 		if (skb1)
1402 			vxlan_xmit_one(skb1, vxlan->dev, src_vni,
1403 				       rcu_dereference(remote->rd), false);
1404 	}
1405 
1406 	if (fremote)
1407 		vxlan_xmit_one(skb, vxlan->dev, src_vni,
1408 			       rcu_dereference(fremote->rd), false);
1409 	else
1410 		kfree_skb(skb);
1411 
1412 	return NETDEV_TX_OK;
1413 }
1414 
1415 static void vxlan_mdb_check_empty(void *ptr, void *arg)
1416 {
1417 	WARN_ON_ONCE(1);
1418 }
1419 
1420 static void vxlan_mdb_remotes_flush(struct vxlan_dev *vxlan,
1421 				    struct vxlan_mdb_entry *mdb_entry)
1422 {
1423 	struct vxlan_mdb_remote *remote, *tmp;
1424 
1425 	list_for_each_entry_safe(remote, tmp, &mdb_entry->remotes, list)
1426 		vxlan_mdb_remote_del(vxlan, mdb_entry, remote);
1427 }
1428 
1429 static void vxlan_mdb_entries_flush(struct vxlan_dev *vxlan)
1430 {
1431 	struct vxlan_mdb_entry *mdb_entry;
1432 	struct hlist_node *tmp;
1433 
1434 	/* The removal of an entry cannot trigger the removal of another entry
1435 	 * since entries are always added to the head of the list.
1436 	 */
1437 	hlist_for_each_entry_safe(mdb_entry, tmp, &vxlan->mdb_list, mdb_node) {
1438 		vxlan_mdb_remotes_flush(vxlan, mdb_entry);
1439 		vxlan_mdb_entry_put(vxlan, mdb_entry);
1440 	}
1441 }
1442 
1443 int vxlan_mdb_init(struct vxlan_dev *vxlan)
1444 {
1445 	int err;
1446 
1447 	err = rhashtable_init(&vxlan->mdb_tbl, &vxlan_mdb_rht_params);
1448 	if (err)
1449 		return err;
1450 
1451 	INIT_HLIST_HEAD(&vxlan->mdb_list);
1452 
1453 	return 0;
1454 }
1455 
1456 void vxlan_mdb_fini(struct vxlan_dev *vxlan)
1457 {
1458 	vxlan_mdb_entries_flush(vxlan);
1459 	WARN_ON_ONCE(vxlan->cfg.flags & VXLAN_F_MDB);
1460 	rhashtable_free_and_destroy(&vxlan->mdb_tbl, vxlan_mdb_check_empty,
1461 				    NULL);
1462 }
1463