xref: /linux/net/sched/act_sample.c (revision 9f2c9170934eace462499ba0bfe042cc72900173)
1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  * net/sched/act_sample.c - Packet sampling tc action
4  * Copyright (c) 2017 Yotam Gigi <yotamg@mellanox.com>
5  */
6 
7 #include <linux/types.h>
8 #include <linux/kernel.h>
9 #include <linux/string.h>
10 #include <linux/errno.h>
11 #include <linux/skbuff.h>
12 #include <linux/rtnetlink.h>
13 #include <linux/module.h>
14 #include <linux/init.h>
15 #include <linux/gfp.h>
16 #include <net/net_namespace.h>
17 #include <net/netlink.h>
18 #include <net/pkt_sched.h>
19 #include <linux/tc_act/tc_sample.h>
20 #include <net/tc_act/tc_sample.h>
21 #include <net/psample.h>
22 #include <net/pkt_cls.h>
23 #include <net/tc_wrapper.h>
24 
25 #include <linux/if_arp.h>
26 
/* Forward declaration: tcf_sample_init() needs act_sample_ops (for net_id
 * and tcf_idr_create()) before the ops table itself is defined below.
 */
static struct tc_action_ops act_sample_ops;

/* Netlink policy for the nested TCA_SAMPLE_* attributes parsed in
 * tcf_sample_init().
 * TCA_SAMPLE_PARMS carries a struct tc_sample and must be at least that
 * many bytes; the remaining attributes are plain u32 values.
 */
static const struct nla_policy sample_policy[TCA_SAMPLE_MAX + 1] = {
	[TCA_SAMPLE_PARMS]		= { .len = sizeof(struct tc_sample) },
	[TCA_SAMPLE_RATE]		= { .type = NLA_U32 },
	[TCA_SAMPLE_TRUNC_SIZE]		= { .type = NLA_U32 },
	[TCA_SAMPLE_PSAMPLE_GROUP]	= { .type = NLA_U32 },
};
35 
36 static int tcf_sample_init(struct net *net, struct nlattr *nla,
37 			   struct nlattr *est, struct tc_action **a,
38 			   struct tcf_proto *tp,
39 			   u32 flags, struct netlink_ext_ack *extack)
40 {
41 	struct tc_action_net *tn = net_generic(net, act_sample_ops.net_id);
42 	bool bind = flags & TCA_ACT_FLAGS_BIND;
43 	struct nlattr *tb[TCA_SAMPLE_MAX + 1];
44 	struct psample_group *psample_group;
45 	u32 psample_group_num, rate, index;
46 	struct tcf_chain *goto_ch = NULL;
47 	struct tc_sample *parm;
48 	struct tcf_sample *s;
49 	bool exists = false;
50 	int ret, err;
51 
52 	if (!nla)
53 		return -EINVAL;
54 	ret = nla_parse_nested_deprecated(tb, TCA_SAMPLE_MAX, nla,
55 					  sample_policy, NULL);
56 	if (ret < 0)
57 		return ret;
58 	if (!tb[TCA_SAMPLE_PARMS] || !tb[TCA_SAMPLE_RATE] ||
59 	    !tb[TCA_SAMPLE_PSAMPLE_GROUP])
60 		return -EINVAL;
61 
62 	parm = nla_data(tb[TCA_SAMPLE_PARMS]);
63 	index = parm->index;
64 	err = tcf_idr_check_alloc(tn, &index, a, bind);
65 	if (err < 0)
66 		return err;
67 	exists = err;
68 	if (exists && bind)
69 		return 0;
70 
71 	if (!exists) {
72 		ret = tcf_idr_create(tn, index, est, a,
73 				     &act_sample_ops, bind, true, flags);
74 		if (ret) {
75 			tcf_idr_cleanup(tn, index);
76 			return ret;
77 		}
78 		ret = ACT_P_CREATED;
79 	} else if (!(flags & TCA_ACT_FLAGS_REPLACE)) {
80 		tcf_idr_release(*a, bind);
81 		return -EEXIST;
82 	}
83 	err = tcf_action_check_ctrlact(parm->action, tp, &goto_ch, extack);
84 	if (err < 0)
85 		goto release_idr;
86 
87 	rate = nla_get_u32(tb[TCA_SAMPLE_RATE]);
88 	if (!rate) {
89 		NL_SET_ERR_MSG(extack, "invalid sample rate");
90 		err = -EINVAL;
91 		goto put_chain;
92 	}
93 	psample_group_num = nla_get_u32(tb[TCA_SAMPLE_PSAMPLE_GROUP]);
94 	psample_group = psample_group_get(net, psample_group_num);
95 	if (!psample_group) {
96 		err = -ENOMEM;
97 		goto put_chain;
98 	}
99 
100 	s = to_sample(*a);
101 
102 	spin_lock_bh(&s->tcf_lock);
103 	goto_ch = tcf_action_set_ctrlact(*a, parm->action, goto_ch);
104 	s->rate = rate;
105 	s->psample_group_num = psample_group_num;
106 	psample_group = rcu_replace_pointer(s->psample_group, psample_group,
107 					    lockdep_is_held(&s->tcf_lock));
108 
109 	if (tb[TCA_SAMPLE_TRUNC_SIZE]) {
110 		s->truncate = true;
111 		s->trunc_size = nla_get_u32(tb[TCA_SAMPLE_TRUNC_SIZE]);
112 	}
113 	spin_unlock_bh(&s->tcf_lock);
114 
115 	if (psample_group)
116 		psample_group_put(psample_group);
117 	if (goto_ch)
118 		tcf_chain_put_by_act(goto_ch);
119 
120 	return ret;
121 put_chain:
122 	if (goto_ch)
123 		tcf_chain_put_by_act(goto_ch);
124 release_idr:
125 	tcf_idr_release(*a, bind);
126 	return err;
127 }
128 
129 static void tcf_sample_cleanup(struct tc_action *a)
130 {
131 	struct tcf_sample *s = to_sample(a);
132 	struct psample_group *psample_group;
133 
134 	/* last reference to action, no need to lock */
135 	psample_group = rcu_dereference_protected(s->psample_group, 1);
136 	RCU_INIT_POINTER(s->psample_group, NULL);
137 	if (psample_group)
138 		psample_group_put(psample_group);
139 }
140 
141 static bool tcf_sample_dev_ok_push(struct net_device *dev)
142 {
143 	switch (dev->type) {
144 	case ARPHRD_TUNNEL:
145 	case ARPHRD_TUNNEL6:
146 	case ARPHRD_SIT:
147 	case ARPHRD_IPGRE:
148 	case ARPHRD_IP6GRE:
149 	case ARPHRD_VOID:
150 	case ARPHRD_NONE:
151 		return false;
152 	default:
153 		return true;
154 	}
155 }
156 
157 TC_INDIRECT_SCOPE int tcf_sample_act(struct sk_buff *skb,
158 				     const struct tc_action *a,
159 				     struct tcf_result *res)
160 {
161 	struct tcf_sample *s = to_sample(a);
162 	struct psample_group *psample_group;
163 	struct psample_metadata md = {};
164 	int retval;
165 
166 	tcf_lastuse_update(&s->tcf_tm);
167 	bstats_update(this_cpu_ptr(s->common.cpu_bstats), skb);
168 	retval = READ_ONCE(s->tcf_action);
169 
170 	psample_group = rcu_dereference_bh(s->psample_group);
171 
172 	/* randomly sample packets according to rate */
173 	if (psample_group && (get_random_u32_below(s->rate) == 0)) {
174 		if (!skb_at_tc_ingress(skb)) {
175 			md.in_ifindex = skb->skb_iif;
176 			md.out_ifindex = skb->dev->ifindex;
177 		} else {
178 			md.in_ifindex = skb->dev->ifindex;
179 		}
180 
181 		/* on ingress, the mac header gets popped, so push it back */
182 		if (skb_at_tc_ingress(skb) && tcf_sample_dev_ok_push(skb->dev))
183 			skb_push(skb, skb->mac_len);
184 
185 		md.trunc_size = s->truncate ? s->trunc_size : skb->len;
186 		psample_sample_packet(psample_group, skb, s->rate, &md);
187 
188 		if (skb_at_tc_ingress(skb) && tcf_sample_dev_ok_push(skb->dev))
189 			skb_pull(skb, skb->mac_len);
190 	}
191 
192 	return retval;
193 }
194 
195 static void tcf_sample_stats_update(struct tc_action *a, u64 bytes, u64 packets,
196 				    u64 drops, u64 lastuse, bool hw)
197 {
198 	struct tcf_sample *s = to_sample(a);
199 	struct tcf_t *tm = &s->tcf_tm;
200 
201 	tcf_action_update_stats(a, bytes, packets, drops, hw);
202 	tm->lastuse = max_t(u64, tm->lastuse, lastuse);
203 }
204 
205 static int tcf_sample_dump(struct sk_buff *skb, struct tc_action *a,
206 			   int bind, int ref)
207 {
208 	unsigned char *b = skb_tail_pointer(skb);
209 	struct tcf_sample *s = to_sample(a);
210 	struct tc_sample opt = {
211 		.index      = s->tcf_index,
212 		.refcnt     = refcount_read(&s->tcf_refcnt) - ref,
213 		.bindcnt    = atomic_read(&s->tcf_bindcnt) - bind,
214 	};
215 	struct tcf_t t;
216 
217 	spin_lock_bh(&s->tcf_lock);
218 	opt.action = s->tcf_action;
219 	if (nla_put(skb, TCA_SAMPLE_PARMS, sizeof(opt), &opt))
220 		goto nla_put_failure;
221 
222 	tcf_tm_dump(&t, &s->tcf_tm);
223 	if (nla_put_64bit(skb, TCA_SAMPLE_TM, sizeof(t), &t, TCA_SAMPLE_PAD))
224 		goto nla_put_failure;
225 
226 	if (nla_put_u32(skb, TCA_SAMPLE_RATE, s->rate))
227 		goto nla_put_failure;
228 
229 	if (s->truncate)
230 		if (nla_put_u32(skb, TCA_SAMPLE_TRUNC_SIZE, s->trunc_size))
231 			goto nla_put_failure;
232 
233 	if (nla_put_u32(skb, TCA_SAMPLE_PSAMPLE_GROUP, s->psample_group_num))
234 		goto nla_put_failure;
235 	spin_unlock_bh(&s->tcf_lock);
236 
237 	return skb->len;
238 
239 nla_put_failure:
240 	spin_unlock_bh(&s->tcf_lock);
241 	nlmsg_trim(skb, b);
242 	return -1;
243 }
244 
/*
 * tcf_psample_group_put - destructor callback for offload entries.
 * @priv: the struct psample_group whose reference was taken by
 *        tcf_sample_get_group()
 */
static void tcf_psample_group_put(void *priv)
{
	psample_group_put((struct psample_group *)priv);
}
251 
252 static struct psample_group *
253 tcf_sample_get_group(const struct tc_action *a,
254 		     tc_action_priv_destructor *destructor)
255 {
256 	struct tcf_sample *s = to_sample(a);
257 	struct psample_group *group;
258 
259 	group = rcu_dereference_protected(s->psample_group,
260 					  lockdep_is_held(&s->tcf_lock));
261 	if (group) {
262 		psample_group_take(group);
263 		*destructor = tcf_psample_group_put;
264 	}
265 
266 	return group;
267 }
268 
269 static void tcf_offload_sample_get_group(struct flow_action_entry *entry,
270 					 const struct tc_action *act)
271 {
272 	entry->sample.psample_group =
273 		act->ops->get_psample_group(act, &entry->destructor);
274 	entry->destructor_priv = entry->sample.psample_group;
275 }
276 
277 static int tcf_sample_offload_act_setup(struct tc_action *act, void *entry_data,
278 					u32 *index_inc, bool bind,
279 					struct netlink_ext_ack *extack)
280 {
281 	if (bind) {
282 		struct flow_action_entry *entry = entry_data;
283 
284 		entry->id = FLOW_ACTION_SAMPLE;
285 		entry->sample.trunc_size = tcf_sample_trunc_size(act);
286 		entry->sample.truncate = tcf_sample_truncate(act);
287 		entry->sample.rate = tcf_sample_rate(act);
288 		tcf_offload_sample_get_group(entry, act);
289 		*index_inc = 1;
290 	} else {
291 		struct flow_offload_action *fl_action = entry_data;
292 
293 		fl_action->id = FLOW_ACTION_SAMPLE;
294 	}
295 
296 	return 0;
297 }
298 
299 static struct tc_action_ops act_sample_ops = {
300 	.kind	  = "sample",
301 	.id	  = TCA_ID_SAMPLE,
302 	.owner	  = THIS_MODULE,
303 	.act	  = tcf_sample_act,
304 	.stats_update = tcf_sample_stats_update,
305 	.dump	  = tcf_sample_dump,
306 	.init	  = tcf_sample_init,
307 	.cleanup  = tcf_sample_cleanup,
308 	.get_psample_group = tcf_sample_get_group,
309 	.offload_act_setup    = tcf_sample_offload_act_setup,
310 	.size	  = sizeof(struct tcf_sample),
311 };
312 
313 static __net_init int sample_init_net(struct net *net)
314 {
315 	struct tc_action_net *tn = net_generic(net, act_sample_ops.net_id);
316 
317 	return tc_action_net_init(net, tn, &act_sample_ops);
318 }
319 
320 static void __net_exit sample_exit_net(struct list_head *net_list)
321 {
322 	tc_action_net_exit(net_list, act_sample_ops.net_id);
323 }
324 
325 static struct pernet_operations sample_net_ops = {
326 	.init = sample_init_net,
327 	.exit_batch = sample_exit_net,
328 	.id   = &act_sample_ops.net_id,
329 	.size = sizeof(struct tc_action_net),
330 };
331 
/* Module load: register the sample action ops together with its pernet ops. */
static int __init sample_init_module(void)
{
	return tcf_register_action(&act_sample_ops, &sample_net_ops);
}

/* Module unload: undo the registration done in sample_init_module(). */
static void __exit sample_cleanup_module(void)
{
	tcf_unregister_action(&act_sample_ops, &sample_net_ops);
}

module_init(sample_init_module);
module_exit(sample_cleanup_module);

MODULE_AUTHOR("Yotam Gigi <yotam.gi@gmail.com>");
MODULE_DESCRIPTION("Packet sampling action");
MODULE_LICENSE("GPL v2");
348