xref: /linux/net/sched/act_nat.c (revision 0526b56cbc3c489642bd6a5fe4b718dea7ef0ee8)
1 // SPDX-License-Identifier: GPL-2.0-or-later
2 /*
3  * Stateless NAT actions
4  *
5  * Copyright (c) 2007 Herbert Xu <herbert@gondor.apana.org.au>
6  */
7 
8 #include <linux/errno.h>
9 #include <linux/init.h>
10 #include <linux/kernel.h>
11 #include <linux/module.h>
12 #include <linux/netfilter.h>
13 #include <linux/rtnetlink.h>
14 #include <linux/skbuff.h>
15 #include <linux/slab.h>
16 #include <linux/spinlock.h>
17 #include <linux/string.h>
18 #include <linux/tc_act/tc_nat.h>
19 #include <net/act_api.h>
20 #include <net/pkt_cls.h>
21 #include <net/icmp.h>
22 #include <net/ip.h>
23 #include <net/netlink.h>
24 #include <net/tc_act/tc_nat.h>
25 #include <net/tcp.h>
26 #include <net/udp.h>
27 #include <net/tc_wrapper.h>
28 
29 static struct tc_action_ops act_nat_ops;
30 
31 static const struct nla_policy nat_policy[TCA_NAT_MAX + 1] = {
32 	[TCA_NAT_PARMS]	= { .len = sizeof(struct tc_nat) },
33 };
34 
35 static int tcf_nat_init(struct net *net, struct nlattr *nla, struct nlattr *est,
36 			struct tc_action **a, struct tcf_proto *tp,
37 			u32 flags, struct netlink_ext_ack *extack)
38 {
39 	struct tc_action_net *tn = net_generic(net, act_nat_ops.net_id);
40 	bool bind = flags & TCA_ACT_FLAGS_BIND;
41 	struct tcf_nat_parms *nparm, *oparm;
42 	struct nlattr *tb[TCA_NAT_MAX + 1];
43 	struct tcf_chain *goto_ch = NULL;
44 	struct tc_nat *parm;
45 	int ret = 0, err;
46 	struct tcf_nat *p;
47 	u32 index;
48 
49 	if (nla == NULL)
50 		return -EINVAL;
51 
52 	err = nla_parse_nested_deprecated(tb, TCA_NAT_MAX, nla, nat_policy,
53 					  NULL);
54 	if (err < 0)
55 		return err;
56 
57 	if (tb[TCA_NAT_PARMS] == NULL)
58 		return -EINVAL;
59 	parm = nla_data(tb[TCA_NAT_PARMS]);
60 	index = parm->index;
61 	err = tcf_idr_check_alloc(tn, &index, a, bind);
62 	if (!err) {
63 		ret = tcf_idr_create_from_flags(tn, index, est, a, &act_nat_ops,
64 						bind, flags);
65 		if (ret) {
66 			tcf_idr_cleanup(tn, index);
67 			return ret;
68 		}
69 		ret = ACT_P_CREATED;
70 	} else if (err > 0) {
71 		if (bind)
72 			return 0;
73 		if (!(flags & TCA_ACT_FLAGS_REPLACE)) {
74 			tcf_idr_release(*a, bind);
75 			return -EEXIST;
76 		}
77 	} else {
78 		return err;
79 	}
80 	err = tcf_action_check_ctrlact(parm->action, tp, &goto_ch, extack);
81 	if (err < 0)
82 		goto release_idr;
83 
84 	nparm = kzalloc(sizeof(*nparm), GFP_KERNEL);
85 	if (!nparm) {
86 		err = -ENOMEM;
87 		goto release_idr;
88 	}
89 
90 	nparm->old_addr = parm->old_addr;
91 	nparm->new_addr = parm->new_addr;
92 	nparm->mask = parm->mask;
93 	nparm->flags = parm->flags;
94 
95 	p = to_tcf_nat(*a);
96 
97 	spin_lock_bh(&p->tcf_lock);
98 	goto_ch = tcf_action_set_ctrlact(*a, parm->action, goto_ch);
99 	oparm = rcu_replace_pointer(p->parms, nparm, lockdep_is_held(&p->tcf_lock));
100 	spin_unlock_bh(&p->tcf_lock);
101 
102 	if (goto_ch)
103 		tcf_chain_put_by_act(goto_ch);
104 
105 	if (oparm)
106 		kfree_rcu(oparm, rcu);
107 
108 	return ret;
109 release_idr:
110 	tcf_idr_release(*a, bind);
111 	return err;
112 }
113 
114 TC_INDIRECT_SCOPE int tcf_nat_act(struct sk_buff *skb,
115 				  const struct tc_action *a,
116 				  struct tcf_result *res)
117 {
118 	struct tcf_nat *p = to_tcf_nat(a);
119 	struct tcf_nat_parms *parms;
120 	struct iphdr *iph;
121 	__be32 old_addr;
122 	__be32 new_addr;
123 	__be32 mask;
124 	__be32 addr;
125 	int egress;
126 	int action;
127 	int ihl;
128 	int noff;
129 
130 	tcf_lastuse_update(&p->tcf_tm);
131 	tcf_action_update_bstats(&p->common, skb);
132 
133 	action = READ_ONCE(p->tcf_action);
134 
135 	parms = rcu_dereference_bh(p->parms);
136 	old_addr = parms->old_addr;
137 	new_addr = parms->new_addr;
138 	mask = parms->mask;
139 	egress = parms->flags & TCA_NAT_FLAG_EGRESS;
140 
141 	if (unlikely(action == TC_ACT_SHOT))
142 		goto drop;
143 
144 	noff = skb_network_offset(skb);
145 	if (!pskb_may_pull(skb, sizeof(*iph) + noff))
146 		goto drop;
147 
148 	iph = ip_hdr(skb);
149 
150 	if (egress)
151 		addr = iph->saddr;
152 	else
153 		addr = iph->daddr;
154 
155 	if (!((old_addr ^ addr) & mask)) {
156 		if (skb_try_make_writable(skb, sizeof(*iph) + noff))
157 			goto drop;
158 
159 		new_addr &= mask;
160 		new_addr |= addr & ~mask;
161 
162 		/* Rewrite IP header */
163 		iph = ip_hdr(skb);
164 		if (egress)
165 			iph->saddr = new_addr;
166 		else
167 			iph->daddr = new_addr;
168 
169 		csum_replace4(&iph->check, addr, new_addr);
170 	} else if ((iph->frag_off & htons(IP_OFFSET)) ||
171 		   iph->protocol != IPPROTO_ICMP) {
172 		goto out;
173 	}
174 
175 	ihl = iph->ihl * 4;
176 
177 	/* It would be nice to share code with stateful NAT. */
178 	switch (iph->frag_off & htons(IP_OFFSET) ? 0 : iph->protocol) {
179 	case IPPROTO_TCP:
180 	{
181 		struct tcphdr *tcph;
182 
183 		if (!pskb_may_pull(skb, ihl + sizeof(*tcph) + noff) ||
184 		    skb_try_make_writable(skb, ihl + sizeof(*tcph) + noff))
185 			goto drop;
186 
187 		tcph = (void *)(skb_network_header(skb) + ihl);
188 		inet_proto_csum_replace4(&tcph->check, skb, addr, new_addr,
189 					 true);
190 		break;
191 	}
192 	case IPPROTO_UDP:
193 	{
194 		struct udphdr *udph;
195 
196 		if (!pskb_may_pull(skb, ihl + sizeof(*udph) + noff) ||
197 		    skb_try_make_writable(skb, ihl + sizeof(*udph) + noff))
198 			goto drop;
199 
200 		udph = (void *)(skb_network_header(skb) + ihl);
201 		if (udph->check || skb->ip_summed == CHECKSUM_PARTIAL) {
202 			inet_proto_csum_replace4(&udph->check, skb, addr,
203 						 new_addr, true);
204 			if (!udph->check)
205 				udph->check = CSUM_MANGLED_0;
206 		}
207 		break;
208 	}
209 	case IPPROTO_ICMP:
210 	{
211 		struct icmphdr *icmph;
212 
213 		if (!pskb_may_pull(skb, ihl + sizeof(*icmph) + noff))
214 			goto drop;
215 
216 		icmph = (void *)(skb_network_header(skb) + ihl);
217 
218 		if (!icmp_is_err(icmph->type))
219 			break;
220 
221 		if (!pskb_may_pull(skb, ihl + sizeof(*icmph) + sizeof(*iph) +
222 					noff))
223 			goto drop;
224 
225 		icmph = (void *)(skb_network_header(skb) + ihl);
226 		iph = (void *)(icmph + 1);
227 		if (egress)
228 			addr = iph->daddr;
229 		else
230 			addr = iph->saddr;
231 
232 		if ((old_addr ^ addr) & mask)
233 			break;
234 
235 		if (skb_try_make_writable(skb, ihl + sizeof(*icmph) +
236 					  sizeof(*iph) + noff))
237 			goto drop;
238 
239 		icmph = (void *)(skb_network_header(skb) + ihl);
240 		iph = (void *)(icmph + 1);
241 
242 		new_addr &= mask;
243 		new_addr |= addr & ~mask;
244 
245 		/* XXX Fix up the inner checksums. */
246 		if (egress)
247 			iph->daddr = new_addr;
248 		else
249 			iph->saddr = new_addr;
250 
251 		inet_proto_csum_replace4(&icmph->checksum, skb, addr, new_addr,
252 					 false);
253 		break;
254 	}
255 	default:
256 		break;
257 	}
258 
259 out:
260 	return action;
261 
262 drop:
263 	tcf_action_inc_drop_qstats(&p->common);
264 	return TC_ACT_SHOT;
265 }
266 
267 static int tcf_nat_dump(struct sk_buff *skb, struct tc_action *a,
268 			int bind, int ref)
269 {
270 	unsigned char *b = skb_tail_pointer(skb);
271 	struct tcf_nat *p = to_tcf_nat(a);
272 	struct tc_nat opt = {
273 		.index    = p->tcf_index,
274 		.refcnt   = refcount_read(&p->tcf_refcnt) - ref,
275 		.bindcnt  = atomic_read(&p->tcf_bindcnt) - bind,
276 	};
277 	struct tcf_nat_parms *parms;
278 	struct tcf_t t;
279 
280 	spin_lock_bh(&p->tcf_lock);
281 
282 	opt.action = p->tcf_action;
283 
284 	parms = rcu_dereference_protected(p->parms, lockdep_is_held(&p->tcf_lock));
285 
286 	opt.old_addr = parms->old_addr;
287 	opt.new_addr = parms->new_addr;
288 	opt.mask = parms->mask;
289 	opt.flags = parms->flags;
290 
291 	if (nla_put(skb, TCA_NAT_PARMS, sizeof(opt), &opt))
292 		goto nla_put_failure;
293 
294 	tcf_tm_dump(&t, &p->tcf_tm);
295 	if (nla_put_64bit(skb, TCA_NAT_TM, sizeof(t), &t, TCA_NAT_PAD))
296 		goto nla_put_failure;
297 	spin_unlock_bh(&p->tcf_lock);
298 
299 	return skb->len;
300 
301 nla_put_failure:
302 	spin_unlock_bh(&p->tcf_lock);
303 	nlmsg_trim(skb, b);
304 	return -1;
305 }
306 
307 static void tcf_nat_cleanup(struct tc_action *a)
308 {
309 	struct tcf_nat *p = to_tcf_nat(a);
310 	struct tcf_nat_parms *parms;
311 
312 	parms = rcu_dereference_protected(p->parms, 1);
313 	if (parms)
314 		kfree_rcu(parms, rcu);
315 }
316 
317 static struct tc_action_ops act_nat_ops = {
318 	.kind		=	"nat",
319 	.id		=	TCA_ID_NAT,
320 	.owner		=	THIS_MODULE,
321 	.act		=	tcf_nat_act,
322 	.dump		=	tcf_nat_dump,
323 	.init		=	tcf_nat_init,
324 	.cleanup	=	tcf_nat_cleanup,
325 	.size		=	sizeof(struct tcf_nat),
326 };
327 
328 static __net_init int nat_init_net(struct net *net)
329 {
330 	struct tc_action_net *tn = net_generic(net, act_nat_ops.net_id);
331 
332 	return tc_action_net_init(net, tn, &act_nat_ops);
333 }
334 
335 static void __net_exit nat_exit_net(struct list_head *net_list)
336 {
337 	tc_action_net_exit(net_list, act_nat_ops.net_id);
338 }
339 
340 static struct pernet_operations nat_net_ops = {
341 	.init = nat_init_net,
342 	.exit_batch = nat_exit_net,
343 	.id   = &act_nat_ops.net_id,
344 	.size = sizeof(struct tc_action_net),
345 };
346 
347 MODULE_DESCRIPTION("Stateless NAT actions");
348 MODULE_LICENSE("GPL");
349 
350 static int __init nat_init_module(void)
351 {
352 	return tcf_register_action(&act_nat_ops, &nat_net_ops);
353 }
354 
355 static void __exit nat_cleanup_module(void)
356 {
357 	tcf_unregister_action(&act_nat_ops, &nat_net_ops);
358 }
359 
360 module_init(nat_init_module);
361 module_exit(nat_cleanup_module);
362