xref: /linux/net/ipv4/ipmr_base.c (revision 77de28cd7cf172e782319a144bf64e693794d78b)
1 // SPDX-License-Identifier: GPL-2.0
2 /* Linux multicast routing support
3  * Common logic shared by IPv4 [ipmr] and IPv6 [ip6mr] implementation
4  */
5 
6 #include <linux/rhashtable.h>
7 #include <linux/mroute_base.h>
8 
9 /* Sets everything common except 'dev', since that is done under locking */
10 void vif_device_init(struct vif_device *v,
11 		     struct net_device *dev,
12 		     unsigned long rate_limit,
13 		     unsigned char threshold,
14 		     unsigned short flags,
15 		     unsigned short get_iflink_mask)
16 {
17 	RCU_INIT_POINTER(v->dev, NULL);
18 	v->bytes_in = 0;
19 	v->bytes_out = 0;
20 	v->pkt_in = 0;
21 	v->pkt_out = 0;
22 	v->rate_limit = rate_limit;
23 	v->flags = flags;
24 	v->threshold = threshold;
25 	if (v->flags & get_iflink_mask)
26 		v->link = dev_get_iflink(dev);
27 	else
28 		v->link = dev->ifindex;
29 }
30 EXPORT_SYMBOL(vif_device_init);
31 
32 struct mr_table *
33 mr_table_alloc(struct net *net, u32 id,
34 	       struct mr_table_ops *ops,
35 	       void (*expire_func)(struct timer_list *t),
36 	       void (*table_set)(struct mr_table *mrt,
37 				 struct net *net))
38 {
39 	struct mr_table *mrt;
40 	int err;
41 
42 	mrt = kzalloc_obj(*mrt);
43 	if (!mrt)
44 		return ERR_PTR(-ENOMEM);
45 	mrt->id = id;
46 	write_pnet(&mrt->net, net);
47 
48 	mrt->ops = *ops;
49 	err = rhltable_init(&mrt->mfc_hash, mrt->ops.rht_params);
50 	if (err) {
51 		kfree(mrt);
52 		return ERR_PTR(err);
53 	}
54 	INIT_LIST_HEAD(&mrt->mfc_cache_list);
55 	INIT_LIST_HEAD(&mrt->mfc_unres_queue);
56 
57 	timer_setup(&mrt->ipmr_expire_timer, expire_func, 0);
58 
59 	mrt->mroute_reg_vif_num = -1;
60 	table_set(mrt, net);
61 	return mrt;
62 }
63 EXPORT_SYMBOL(mr_table_alloc);
64 
65 void *mr_mfc_find_parent(struct mr_table *mrt, void *hasharg, int parent)
66 {
67 	struct rhlist_head *tmp, *list;
68 	struct mr_mfc *c;
69 
70 	list = rhltable_lookup(&mrt->mfc_hash, hasharg, *mrt->ops.rht_params);
71 	rhl_for_each_entry_rcu(c, tmp, list, mnode)
72 		if (parent == -1 || parent == c->mfc_parent)
73 			return c;
74 
75 	return NULL;
76 }
77 EXPORT_SYMBOL(mr_mfc_find_parent);
78 
79 void *mr_mfc_find_any_parent(struct mr_table *mrt, int vifi)
80 {
81 	struct rhlist_head *tmp, *list;
82 	struct mr_mfc *c;
83 
84 	list = rhltable_lookup(&mrt->mfc_hash, mrt->ops.cmparg_any,
85 			       *mrt->ops.rht_params);
86 	rhl_for_each_entry_rcu(c, tmp, list, mnode)
87 		if (c->mfc_un.res.ttls[vifi] < 255)
88 			return c;
89 
90 	return NULL;
91 }
92 EXPORT_SYMBOL(mr_mfc_find_any_parent);
93 
94 void *mr_mfc_find_any(struct mr_table *mrt, int vifi, void *hasharg)
95 {
96 	struct rhlist_head *tmp, *list;
97 	struct mr_mfc *c, *proxy;
98 
99 	list = rhltable_lookup(&mrt->mfc_hash, hasharg, *mrt->ops.rht_params);
100 	rhl_for_each_entry_rcu(c, tmp, list, mnode) {
101 		if (c->mfc_un.res.ttls[vifi] < 255)
102 			return c;
103 
104 		/* It's ok if the vifi is part of the static tree */
105 		proxy = mr_mfc_find_any_parent(mrt, c->mfc_parent);
106 		if (proxy && proxy->mfc_un.res.ttls[vifi] < 255)
107 			return c;
108 	}
109 
110 	return mr_mfc_find_any_parent(mrt, vifi);
111 }
112 EXPORT_SYMBOL(mr_mfc_find_any);
113 
114 #ifdef CONFIG_PROC_FS
115 void *mr_vif_seq_idx(struct net *net, struct mr_vif_iter *iter, loff_t pos)
116 {
117 	struct mr_table *mrt = iter->mrt;
118 
119 	for (iter->ct = 0; iter->ct < mrt->maxvif; ++iter->ct) {
120 		if (!VIF_EXISTS(mrt, iter->ct))
121 			continue;
122 		if (pos-- == 0)
123 			return &mrt->vif_table[iter->ct];
124 	}
125 	return NULL;
126 }
127 EXPORT_SYMBOL(mr_vif_seq_idx);
128 
129 void *mr_vif_seq_next(struct seq_file *seq, void *v, loff_t *pos)
130 {
131 	struct mr_vif_iter *iter = seq->private;
132 	struct net *net = seq_file_net(seq);
133 	struct mr_table *mrt = iter->mrt;
134 
135 	++*pos;
136 	if (v == SEQ_START_TOKEN)
137 		return mr_vif_seq_idx(net, iter, 0);
138 
139 	while (++iter->ct < mrt->maxvif) {
140 		if (!VIF_EXISTS(mrt, iter->ct))
141 			continue;
142 		return &mrt->vif_table[iter->ct];
143 	}
144 	return NULL;
145 }
146 EXPORT_SYMBOL(mr_vif_seq_next);
147 
/**
 * mr_mfc_seq_idx - locate the @pos-th MFC entry for a /proc seq_file walk
 * @net: unused here
 * @it: iterator; it->mrt selects the table, it->lock guards the unresolved
 *	queue, and it->cache is set to the list the result was taken from
 * @pos: 0-based position across mfc_cache_list followed by mfc_unres_queue
 *
 * Return: the entry at @pos, or NULL (with it->cache set to NULL) if there
 * is none.
 *
 * Locking: when an entry is found, this function returns with a lock still
 * held. If it->cache == &mrt->mfc_cache_list, rcu_read_lock() is held. If
 * it->cache == &mrt->mfc_unres_queue, spin_lock_bh(it->lock) is held.
 * Nothing is held when NULL is returned. The release is left to the
 * seq_file stop callback, which is not in this file; presumably it keys off
 * it->cache. NOTE(review): confirm.
 */
148 void *mr_mfc_seq_idx(struct net *net,
149 		     struct mr_mfc_iter *it, loff_t pos)
150 {
151 	struct mr_table *mrt = it->mrt;
152 	struct mr_mfc *mfc;
153 
	/* Resolved entries first, under RCU; the lock stays held on a hit. */
154 	rcu_read_lock();
155 	it->cache = &mrt->mfc_cache_list;
156 	list_for_each_entry_rcu(mfc, &mrt->mfc_cache_list, list)
157 		if (pos-- == 0)
158 			return mfc;
159 	rcu_read_unlock();
160 
	/* Then unresolved entries, under it->lock; also kept held on a hit. */
161 	spin_lock_bh(it->lock);
162 	it->cache = &mrt->mfc_unres_queue;
163 	list_for_each_entry(mfc, it->cache, list)
164 		if (pos-- == 0)
165 			return mfc;
166 	spin_unlock_bh(it->lock);
167 
	/* Past the end of both lists: no lock held, no current list. */
168 	it->cache = NULL;
169 	return NULL;
170 }
171 EXPORT_SYMBOL(mr_mfc_seq_idx);
172 
/**
 * mr_mfc_seq_next - advance a /proc MFC seq_file walk
 * @seq: seq_file whose ->private is a struct mr_mfc_iter
 * @v: current entry, or SEQ_START_TOKEN
 * @pos: position, always incremented
 *
 * Walks mfc_cache_list and then mfc_unres_queue. The lock held depends on
 * it->cache:
 *  - &mrt->mfc_cache_list: rcu_read_lock() is held;
 *  - &mrt->mfc_unres_queue: spin_lock_bh(it->lock) is held.
 * Crossing from the first list to the second drops the RCU read lock and
 * takes the spinlock. Running off the end of the second list drops the
 * spinlock and sets it->cache to NULL.
 *
 * Return: the next entry, or NULL once both lists are exhausted.
 */
173 void *mr_mfc_seq_next(struct seq_file *seq, void *v,
174 		      loff_t *pos)
175 {
176 	struct mr_mfc_iter *it = seq->private;
177 	struct net *net = seq_file_net(seq);
178 	struct mr_table *mrt = it->mrt;
179 	struct mr_mfc *c = v;
180 
181 	++*pos;
182 
183 	if (v == SEQ_START_TOKEN)
184 		return mr_mfc_seq_idx(net, seq->private, 0);
185 
	/* Not the last entry of the current list: step to its successor. */
186 	if (c->list.next != it->cache)
187 		return list_entry(c->list.next, struct mr_mfc, list);
188 
	/* Last unresolved entry: it->lock is held, so release it and stop. */
189 	if (it->cache == &mrt->mfc_unres_queue)
190 		goto end_of_list;
191 
192 	/* exhausted mfc_cache_list: drop RCU, continue with the unresolved queue */
193 	rcu_read_unlock();
194 	it->cache = &mrt->mfc_unres_queue;
195 
196 	spin_lock_bh(it->lock);
197 	if (!list_empty(it->cache))
198 		return list_first_entry(it->cache, struct mr_mfc, list);
199 
200 end_of_list:
201 	spin_unlock_bh(it->lock);
202 	it->cache = NULL;
203 
204 	return NULL;
205 }
206 EXPORT_SYMBOL(mr_mfc_seq_next);
207 #endif
208 
/**
 * mr_fill_mroute - append a resolved MFC entry's route attributes to @skb
 * @mrt: table owning @c; its vif_table maps vif indexes to devices
 * @skb: netlink message being built
 * @c: entry to describe
 * @rtm: route header of the message; rtm_flags and rtm_type may be updated
 *
 * Emits, in order:
 *  - RTA_IIF: the parent vif's ifindex, if that vif has a device;
 *  - RTA_MULTIPATH: a nest with one rtnexthop per vif in [minvif, maxvif)
 *    that has a device and ttls[] below 255 (the TTL goes in rtnh_hops);
 *  - RTA_MFC_STATS;
 *  - RTA_EXPIRES.
 * RTNH_F_OFFLOAD is added to @rtm->rtm_flags when the entry has MFC_OFFLOAD.
 *
 * Return: 1 on success, with @rtm->rtm_type set to RTN_MULTICAST.
 * -ENOENT for an unresolved entry (mfc_parent >= MAXVIFS); RTNH_F_UNRESOLVED
 * is set in @rtm and nothing is added to @skb.
 * -EMSGSIZE when @skb runs out of room. Only a half-built RTA_MULTIPATH nest
 * is cancelled; attributes already added stay in @skb, so the caller
 * presumably discards the whole message on error. NOTE(review): confirm.
 */
209 int mr_fill_mroute(struct mr_table *mrt, struct sk_buff *skb,
210 		   struct mr_mfc *c, struct rtmsg *rtm)
211 {
212 	struct net_device *vif_dev;
213 	struct rta_mfc_stats mfcs;
214 	struct nlattr *mp_attr;
215 	struct rtnexthop *nhp;
216 	unsigned long lastuse;
217 	int ct;
218 
219 	/* If cache is unresolved, don't try to parse IIF and OIF */
220 	if (c->mfc_parent >= MAXVIFS) {
221 		rtm->rtm_flags |= RTNH_F_UNRESOLVED;
222 		return -ENOENT;
223 	}
224 
	/* vif devices are RCU-protected; report the input interface. */
225 	rcu_read_lock();
226 	vif_dev = rcu_dereference(mrt->vif_table[c->mfc_parent].dev);
227 	if (vif_dev && nla_put_u32(skb, RTA_IIF, READ_ONCE(vif_dev->ifindex)) < 0) {
228 		rcu_read_unlock();
229 		return -EMSGSIZE;
230 	}
231 	rcu_read_unlock();
232 
233 	if (c->mfc_flags & MFC_OFFLOAD)
234 		rtm->rtm_flags |= RTNH_F_OFFLOAD;
235 
	/* One rtnexthop per outgoing vif (ttls[] < 255 means "forward here"). */
236 	mp_attr = nla_nest_start_noflag(skb, RTA_MULTIPATH);
237 	if (!mp_attr)
238 		return -EMSGSIZE;
239 
240 	rcu_read_lock();
241 	for (ct = c->mfc_un.res.minvif; ct < c->mfc_un.res.maxvif; ct++) {
242 		struct vif_device *vif = &mrt->vif_table[ct];
243 
244 		vif_dev = rcu_dereference(vif->dev);
245 		if (vif_dev && c->mfc_un.res.ttls[ct] < 255) {
246 
247 			nhp = nla_reserve_nohdr(skb, sizeof(*nhp));
248 			if (!nhp) {
249 				rcu_read_unlock();
250 				nla_nest_cancel(skb, mp_attr);
251 				return -EMSGSIZE;
252 			}
253 
254 			nhp->rtnh_flags = 0;
255 			nhp->rtnh_hops = c->mfc_un.res.ttls[ct];
256 			nhp->rtnh_ifindex = READ_ONCE(vif_dev->ifindex);
257 			nhp->rtnh_len = sizeof(*nhp);
258 		}
259 	}
260 	rcu_read_unlock();
261 
262 	nla_nest_end(skb, mp_attr);
263 
	/*
	 * RTA_EXPIRES carries the time since last use, clamped to 0 if
	 * lastuse is ahead of jiffies, converted to clock_t.
	 */
264 	lastuse = READ_ONCE(c->mfc_un.res.lastuse);
265 	lastuse = time_after_eq(jiffies, lastuse) ? jiffies - lastuse : 0;
266 
267 	mfcs.mfcs_packets = atomic_long_read(&c->mfc_un.res.pkt);
268 	mfcs.mfcs_bytes = atomic_long_read(&c->mfc_un.res.bytes);
269 	mfcs.mfcs_wrong_if = atomic_long_read(&c->mfc_un.res.wrong_if);
270 	if (nla_put_64bit(skb, RTA_MFC_STATS, sizeof(mfcs), &mfcs, RTA_PAD) ||
271 	    nla_put_u64_64bit(skb, RTA_EXPIRES, jiffies_to_clock_t(lastuse),
272 			      RTA_PAD))
273 		return -EMSGSIZE;
274 
275 	rtm->rtm_type = RTN_MULTICAST;
276 	return 1;
277 }
278 EXPORT_SYMBOL(mr_fill_mroute);
279 
280 static bool mr_mfc_uses_dev(const struct mr_table *mrt,
281 			    const struct mr_mfc *c,
282 			    const struct net_device *dev)
283 {
284 	int ct;
285 
286 	for (ct = c->mfc_un.res.minvif; ct < c->mfc_un.res.maxvif; ct++) {
287 		const struct net_device *vif_dev;
288 		const struct vif_device *vif;
289 
290 		vif = &mrt->vif_table[ct];
291 		vif_dev = rcu_access_pointer(vif->dev);
292 		if (vif_dev && c->mfc_un.res.ttls[ct] < 255 &&
293 		    vif_dev == dev)
294 			return true;
295 	}
296 	return false;
297 }
298 
/**
 * mr_table_dump - dump one table's MFC entries into a netlink skb
 * @mrt: table to dump
 * @skb: output buffer
 * @cb: netlink callback; cb->args[1] is the entry index to resume from and
 *	is updated before returning
 * @fill: builds one RTM_NEWROUTE message for an entry
 * @lock: taken with spin_lock_bh() while walking @mrt->mfc_unres_queue
 * @filter: dump filter. filter_set adds NLM_F_DUMP_FILTERED to the message
 *	flags; filter->dev restricts resolved entries to those forwarding on
 *	that device.
 *
 * Entries are numbered continuously across mfc_cache_list and then
 * mfc_unres_queue. Entries below cb->args[1], or filtered out, still advance
 * the counter, so the resume index stays stable between calls.
 *
 * Return: 0 once every entry has been emitted, or the negative value from
 * @fill that stopped the dump. In that case cb->args[1] points at the failing
 * entry, so it is retried first on the next call.
 */
299 int mr_table_dump(struct mr_table *mrt, struct sk_buff *skb,
300 		  struct netlink_callback *cb,
301 		  int (*fill)(struct mr_table *mrt, struct sk_buff *skb,
302 			      u32 portid, u32 seq, struct mr_mfc *c,
303 			      int cmd, int flags),
304 		  spinlock_t *lock, struct fib_dump_filter *filter)
305 {
306 	unsigned int e = 0, s_e = cb->args[1];
307 	unsigned int flags = NLM_F_MULTI;
308 	struct mr_mfc *mfc;
309 	int err;
310 
311 	if (filter->filter_set)
312 		flags |= NLM_F_DUMP_FILTERED;
313 
	/* Resolved entries: RCU list; lockdep also accepts RTNL as protection. */
314 	list_for_each_entry_rcu(mfc, &mrt->mfc_cache_list, list,
315 				lockdep_rtnl_is_held()) {
316 		if (e < s_e)
317 			goto next_entry;
318 		if (filter->dev &&
319 		    !mr_mfc_uses_dev(mrt, mfc, filter->dev))
320 			goto next_entry;
321 
322 		err = fill(mrt, skb, NETLINK_CB(cb->skb).portid,
323 			   cb->nlh->nlmsg_seq, mfc, RTM_NEWROUTE, flags);
324 		if (err < 0)
325 			goto out;
326 next_entry:
327 		e++;
328 	}
329 
	/*
	 * Unresolved entries, under @lock. The device filter is deliberately
	 * not applied here. mr_mfc_uses_dev() reads c->mfc_un.res
	 * (minvif/maxvif/ttls), which presumably does not hold valid vif data
	 * for unresolved entries; mr_fill_mroute() likewise refuses to parse
	 * IIF/OIF for them.
	 * NOTE(review): check the struct mr_mfc union layout before adding a
	 * filter here.
	 */
330 	spin_lock_bh(lock);
331 	list_for_each_entry(mfc, &mrt->mfc_unres_queue, list) {
332 		if (e < s_e)
333 			goto next_entry2;
334 
335 		err = fill(mrt, skb, NETLINK_CB(cb->skb).portid,
336 			   cb->nlh->nlmsg_seq, mfc, RTM_NEWROUTE, flags);
337 		if (err < 0) {
338 			spin_unlock_bh(lock);
339 			goto out;
340 		}
341 next_entry2:
342 		e++;
343 	}
344 	spin_unlock_bh(lock);
345 	err = 0;
346 out:
	/* On error, e has not been advanced past the entry that failed. */
347 	cb->args[1] = e;
348 	return err;
349 }
350 EXPORT_SYMBOL(mr_table_dump);
351 
352 int mr_rtm_dumproute(struct sk_buff *skb, struct netlink_callback *cb,
353 		     struct mr_table *(*iter)(struct net *net,
354 					      struct mr_table *mrt),
355 		     int (*fill)(struct mr_table *mrt,
356 				 struct sk_buff *skb,
357 				 u32 portid, u32 seq, struct mr_mfc *c,
358 				 int cmd, int flags),
359 		     spinlock_t *lock, struct fib_dump_filter *filter)
360 {
361 	unsigned int t = 0, s_t = cb->args[0];
362 	struct net *net = sock_net(skb->sk);
363 	struct mr_table *mrt;
364 	int err;
365 
366 	/* multicast does not track protocol or have route type other
367 	 * than RTN_MULTICAST
368 	 */
369 	if (filter->filter_set) {
370 		if (filter->protocol || filter->flags ||
371 		    (filter->rt_type && filter->rt_type != RTN_MULTICAST))
372 			return skb->len;
373 	}
374 
375 	rcu_read_lock();
376 	for (mrt = iter(net, NULL); mrt; mrt = iter(net, mrt)) {
377 		if (t < s_t)
378 			goto next_table;
379 
380 		err = mr_table_dump(mrt, skb, cb, fill, lock, filter);
381 		if (err < 0)
382 			break;
383 		cb->args[1] = 0;
384 next_table:
385 		t++;
386 	}
387 	rcu_read_unlock();
388 
389 	cb->args[0] = t;
390 
391 	return skb->len;
392 }
393 EXPORT_SYMBOL(mr_rtm_dumproute);
394 
395 int mr_dump(struct net *net, struct notifier_block *nb, unsigned short family,
396 	    int (*rules_dump)(struct net *net,
397 			      struct notifier_block *nb,
398 			      struct netlink_ext_ack *extack),
399 	    struct mr_table *(*mr_iter)(struct net *net,
400 					struct mr_table *mrt),
401 	    struct netlink_ext_ack *extack)
402 {
403 	struct mr_table *mrt;
404 	int err;
405 
406 	err = rules_dump(net, nb, extack);
407 	if (err)
408 		return err;
409 
410 	for (mrt = mr_iter(net, NULL); mrt; mrt = mr_iter(net, mrt)) {
411 		struct vif_device *v = &mrt->vif_table[0];
412 		struct net_device *vif_dev;
413 		struct mr_mfc *mfc;
414 		int vifi;
415 
416 		/* Notifiy on table VIF entries */
417 		rcu_read_lock();
418 		for (vifi = 0; vifi < mrt->maxvif; vifi++, v++) {
419 			vif_dev = rcu_dereference(v->dev);
420 			if (!vif_dev)
421 				continue;
422 
423 			err = mr_call_vif_notifier(nb, family,
424 						   FIB_EVENT_VIF_ADD, v,
425 						   vif_dev, vifi,
426 						   mrt->id, extack);
427 			if (err)
428 				break;
429 		}
430 		rcu_read_unlock();
431 
432 		if (err)
433 			return err;
434 
435 		/* Notify on table MFC entries */
436 		list_for_each_entry_rcu(mfc, &mrt->mfc_cache_list, list) {
437 			err = mr_call_mfc_notifier(nb, family,
438 						   FIB_EVENT_ENTRY_ADD,
439 						   mfc, mrt->id, extack);
440 			if (err)
441 				return err;
442 		}
443 	}
444 
445 	return 0;
446 }
447 EXPORT_SYMBOL(mr_dump);
448