xref: /linux/net/ipv4/fou_core.c (revision c48a7c44a1d02516309015b6134c9bb982e17008)
1 // SPDX-License-Identifier: GPL-2.0-only
2 #include <linux/module.h>
3 #include <linux/errno.h>
4 #include <linux/socket.h>
5 #include <linux/skbuff.h>
6 #include <linux/ip.h>
7 #include <linux/icmp.h>
8 #include <linux/udp.h>
9 #include <linux/types.h>
10 #include <linux/kernel.h>
11 #include <net/genetlink.h>
12 #include <net/gro.h>
13 #include <net/gue.h>
14 #include <net/fou.h>
15 #include <net/ip.h>
16 #include <net/protocol.h>
17 #include <net/udp.h>
18 #include <net/udp_tunnel.h>
19 #include <uapi/linux/fou.h>
20 #include <uapi/linux/genetlink.h>
21 
22 #include "fou_nl.h"
23 
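/* Each fou instance represents one bound UDP encapsulation socket,
 * i.e. one local receive port configured for FOU or GUE decapsulation.
 * Instances are kept on a per-network-namespace list (struct fou_net
 * below).
 */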
24 struct fou {
25 	struct socket *sock;
26 	u8 protocol;
27 	u8 flags;
28 	__be16 port;
29 	u8 family;
30 	u16 type;
31 	struct list_head list;
32 	struct rcu_head rcu;
33 };
34 
35 #define FOU_F_REMCSUM_NOPARTIAL BIT(0)
36 
37 struct fou_cfg {
38 	u16 type;
39 	u8 protocol;
40 	u8 flags;
41 	struct udp_port_cfg udp_config;
42 };
43 
44 static unsigned int fou_net_id;
45 
46 struct fou_net {
47 	struct list_head fou_list;
48 	struct mutex fou_lock;
49 };
50 
51 static inline struct fou *fou_from_sock(struct sock *sk)
52 {
53 	return sk->sk_user_data;
54 }
55 
56 static int fou_recv_pull(struct sk_buff *skb, struct fou *fou, size_t len)
57 {
58 	/* Remove 'len' bytes from the packet (UDP header and
59 	 * FOU header if present).
60 	 */
61 	if (fou->family == AF_INET)
62 		ip_hdr(skb)->tot_len = htons(ntohs(ip_hdr(skb)->tot_len) - len);
63 	else
64 		ipv6_hdr(skb)->payload_len =
65 		    htons(ntohs(ipv6_hdr(skb)->payload_len) - len);
66 
67 	__skb_pull(skb, len);
68 	skb_postpull_rcsum(skb, udp_hdr(skb), len);
69 	skb_reset_transport_header(skb);
70 	return iptunnel_pull_offloads(skb);
71 }
72 
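/* encap_rcv callback for a plain FOU socket. The return convention is
 * that of UDP encap_rcv handlers: 1 means "not for us, process as
 * normal UDP", 0 means the skb was consumed (or freed), and a negative
 * value asks the IP layer to resubmit the de-encapsulated packet to
 * the protocol handler for protocol number -ret. gue_udp_recv() below
 * follows the same convention.
 */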
73 static int fou_udp_recv(struct sock *sk, struct sk_buff *skb)
74 {
75 	struct fou *fou = fou_from_sock(sk);
76 
77 	if (!fou)
78 		return 1;
79 
80 	if (fou_recv_pull(skb, fou, sizeof(struct udphdr)))
81 		goto drop;
82 
83 	return -fou->protocol;
84 
85 drop:
86 	kfree_skb(skb);
87 	return 0;
88 }
89 
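/* Remote checksum offload handling on receive. The option data holds
 * two 16-bit values: pd[0] is the offset at which the inner checksum
 * calculation starts and pd[1] the offset at which the checksum is
 * stored, both relative to the end of the GUE header (see
 * __gue_build_header() for the transmit side).
 */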
90 static struct guehdr *gue_remcsum(struct sk_buff *skb, struct guehdr *guehdr,
91 				  void *data, size_t hdrlen, u8 ipproto,
92 				  bool nopartial)
93 {
94 	__be16 *pd = data;
95 	size_t start = ntohs(pd[0]);
96 	size_t offset = ntohs(pd[1]);
97 	size_t plen = sizeof(struct udphdr) + hdrlen +
98 	    max_t(size_t, offset + sizeof(u16), start);
99 
100 	if (skb->remcsum_offload)
101 		return guehdr;
102 
103 	if (!pskb_may_pull(skb, plen))
104 		return NULL;
105 	guehdr = (struct guehdr *)&udp_hdr(skb)[1];
106 
107 	skb_remcsum_process(skb, (void *)guehdr + hdrlen,
108 			    start, offset, nopartial);
109 
110 	return guehdr;
111 }
112 
113 static int gue_control_message(struct sk_buff *skb, struct guehdr *guehdr)
114 {
115 	/* No support yet */
116 	kfree_skb(skb);
117 	return 0;
118 }
119 
120 static int gue_udp_recv(struct sock *sk, struct sk_buff *skb)
121 {
122 	struct fou *fou = fou_from_sock(sk);
123 	size_t len, optlen, hdrlen;
124 	struct guehdr *guehdr;
125 	void *data;
126 	u16 doffset = 0;
127 	u8 proto_ctype;
128 
129 	if (!fou)
130 		return 1;
131 
132 	len = sizeof(struct udphdr) + sizeof(struct guehdr);
133 	if (!pskb_may_pull(skb, len))
134 		goto drop;
135 
136 	guehdr = (struct guehdr *)&udp_hdr(skb)[1];
137 
138 	switch (guehdr->version) {
139 	case 0: /* Full GUE header present */
140 		break;
141 
142 	case 1: {
143 		/* Direct encapsulation of IPv4 or IPv6 */
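		/* GUE version 1 carries no GUE header at all: the version
		 * bits overlap the IP version nibble of the inner packet,
		 * so the UDP payload starts directly with an IPv4 or IPv6
		 * header and we read the full version nibble to tell which.
		 */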
144 
145 		int prot;
146 
147 		switch (((struct iphdr *)guehdr)->version) {
148 		case 4:
149 			prot = IPPROTO_IPIP;
150 			break;
151 		case 6:
152 			prot = IPPROTO_IPV6;
153 			break;
154 		default:
155 			goto drop;
156 		}
157 
158 		if (fou_recv_pull(skb, fou, sizeof(struct udphdr)))
159 			goto drop;
160 
161 		return -prot;
162 	}
163 
164 	default: /* Undefined version */
165 		goto drop;
166 	}
167 
168 	optlen = guehdr->hlen << 2;
169 	len += optlen;
170 
171 	if (!pskb_may_pull(skb, len))
172 		goto drop;
173 
174 	/* guehdr may change after pull */
175 	guehdr = (struct guehdr *)&udp_hdr(skb)[1];
176 
177 	if (validate_gue_flags(guehdr, optlen))
178 		goto drop;
179 
180 	hdrlen = sizeof(struct guehdr) + optlen;
181 
182 	if (fou->family == AF_INET)
183 		ip_hdr(skb)->tot_len = htons(ntohs(ip_hdr(skb)->tot_len) - len);
184 	else
185 		ipv6_hdr(skb)->payload_len =
186 		    htons(ntohs(ipv6_hdr(skb)->payload_len) - len);
187 
188 	/* Pull csum through the guehdr now. This can be used if
189 	 * there is a remote checksum offload.
190 	 */
191 	skb_postpull_rcsum(skb, udp_hdr(skb), len);
192 
193 	data = &guehdr[1];
194 
195 	if (guehdr->flags & GUE_FLAG_PRIV) {
196 		__be32 flags = *(__be32 *)(data + doffset);
197 
198 		doffset += GUE_LEN_PRIV;
199 
200 		if (flags & GUE_PFLAG_REMCSUM) {
201 			guehdr = gue_remcsum(skb, guehdr, data + doffset,
202 					     hdrlen, guehdr->proto_ctype,
203 					     !!(fou->flags &
204 						FOU_F_REMCSUM_NOPARTIAL));
205 			if (!guehdr)
206 				goto drop;
207 
208 			data = &guehdr[1];
209 
210 			doffset += GUE_PLEN_REMCSUM;
211 		}
212 	}
213 
214 	if (unlikely(guehdr->control))
215 		return gue_control_message(skb, guehdr);
216 
217 	proto_ctype = guehdr->proto_ctype;
218 	__skb_pull(skb, sizeof(struct udphdr) + hdrlen);
219 	skb_reset_transport_header(skb);
220 
221 	if (iptunnel_pull_offloads(skb))
222 		goto drop;
223 
224 	return -proto_ctype;
225 
226 drop:
227 	kfree_skb(skb);
228 	return 0;
229 }
230 
231 static struct sk_buff *fou_gro_receive(struct sock *sk,
232 				       struct list_head *head,
233 				       struct sk_buff *skb)
234 {
235 	const struct net_offload __rcu **offloads;
236 	u8 proto = fou_from_sock(sk)->protocol;
237 	const struct net_offload *ops;
238 	struct sk_buff *pp = NULL;
239 
240 	/* We can clear the encap_mark for FOU as we are essentially doing
241 	 * one of two possible things.  We are either adding an L4 tunnel
242 	 * header to the outer L3 tunnel header, or we are simply
243 	 * treating the GRE tunnel header as though it is a UDP protocol
244 	 * specific header such as VXLAN or GENEVE.
245 	 */
246 	NAPI_GRO_CB(skb)->encap_mark = 0;
247 
248 	/* Flag this frame as already having an outer encap header */
249 	NAPI_GRO_CB(skb)->is_fou = 1;
250 
251 	offloads = NAPI_GRO_CB(skb)->is_ipv6 ? inet6_offloads : inet_offloads;
252 	ops = rcu_dereference(offloads[proto]);
253 	if (!ops || !ops->callbacks.gro_receive)
254 		goto out;
255 
256 	pp = call_gro_receive(ops->callbacks.gro_receive, head, skb);
257 
258 out:
259 	return pp;
260 }
261 
262 static int fou_gro_complete(struct sock *sk, struct sk_buff *skb,
263 			    int nhoff)
264 {
265 	const struct net_offload __rcu **offloads;
266 	u8 proto = fou_from_sock(sk)->protocol;
267 	const struct net_offload *ops;
268 	int err = -ENOSYS;
269 
270 	offloads = NAPI_GRO_CB(skb)->is_ipv6 ? inet6_offloads : inet_offloads;
271 	ops = rcu_dereference(offloads[proto]);
272 	if (WARN_ON(!ops || !ops->callbacks.gro_complete))
273 		goto out;
274 
275 	err = ops->callbacks.gro_complete(skb, nhoff);
276 
277 	skb_set_inner_mac_header(skb, nhoff);
278 
279 out:
280 	return err;
281 }
282 
283 static struct guehdr *gue_gro_remcsum(struct sk_buff *skb, unsigned int off,
284 				      struct guehdr *guehdr, void *data,
285 				      size_t hdrlen, struct gro_remcsum *grc,
286 				      bool nopartial)
287 {
288 	__be16 *pd = data;
289 	size_t start = ntohs(pd[0]);
290 	size_t offset = ntohs(pd[1]);
291 
292 	if (skb->remcsum_offload)
293 		return guehdr;
294 
295 	if (!NAPI_GRO_CB(skb)->csum_valid)
296 		return NULL;
297 
298 	guehdr = skb_gro_remcsum_process(skb, (void *)guehdr, off, hdrlen,
299 					 start, offset, grc, nopartial);
300 
301 	skb->remcsum_offload = 1;
302 
303 	return guehdr;
304 }
305 
306 static struct sk_buff *gue_gro_receive(struct sock *sk,
307 				       struct list_head *head,
308 				       struct sk_buff *skb)
309 {
310 	const struct net_offload __rcu **offloads;
311 	const struct net_offload *ops;
312 	struct sk_buff *pp = NULL;
313 	struct sk_buff *p;
314 	struct guehdr *guehdr;
315 	size_t len, optlen, hdrlen, off;
316 	void *data;
317 	u16 doffset = 0;
318 	int flush = 1;
319 	struct fou *fou = fou_from_sock(sk);
320 	struct gro_remcsum grc;
321 	u8 proto;
322 
323 	skb_gro_remcsum_init(&grc);
324 
325 	off = skb_gro_offset(skb);
326 	len = off + sizeof(*guehdr);
327 
328 	guehdr = skb_gro_header(skb, len, off);
329 	if (unlikely(!guehdr))
330 		goto out;
331 
332 	switch (guehdr->version) {
333 	case 0:
334 		break;
335 	case 1:
336 		switch (((struct iphdr *)guehdr)->version) {
337 		case 4:
338 			proto = IPPROTO_IPIP;
339 			break;
340 		case 6:
341 			proto = IPPROTO_IPV6;
342 			break;
343 		default:
344 			goto out;
345 		}
346 		goto next_proto;
347 	default:
348 		goto out;
349 	}
350 
351 	optlen = guehdr->hlen << 2;
352 	len += optlen;
353 
354 	if (skb_gro_header_hard(skb, len)) {
355 		guehdr = skb_gro_header_slow(skb, len, off);
356 		if (unlikely(!guehdr))
357 			goto out;
358 	}
359 
360 	if (unlikely(guehdr->control) || guehdr->version != 0 ||
361 	    validate_gue_flags(guehdr, optlen))
362 		goto out;
363 
364 	hdrlen = sizeof(*guehdr) + optlen;
365 
366 	/* Adjust NAPI_GRO_CB(skb)->csum to account for guehdr,
367 	 * this is needed if there is a remote checksum offload.
368 	 */
369 	skb_gro_postpull_rcsum(skb, guehdr, hdrlen);
370 
371 	data = &guehdr[1];
372 
373 	if (guehdr->flags & GUE_FLAG_PRIV) {
374 		__be32 flags = *(__be32 *)(data + doffset);
375 
376 		doffset += GUE_LEN_PRIV;
377 
378 		if (flags & GUE_PFLAG_REMCSUM) {
379 			guehdr = gue_gro_remcsum(skb, off, guehdr,
380 						 data + doffset, hdrlen, &grc,
381 						 !!(fou->flags &
382 						    FOU_F_REMCSUM_NOPARTIAL));
383 
384 			if (!guehdr)
385 				goto out;
386 
387 			data = &guehdr[1];
388 
389 			doffset += GUE_PLEN_REMCSUM;
390 		}
391 	}
392 
393 	skb_gro_pull(skb, hdrlen);
394 
395 	list_for_each_entry(p, head, list) {
396 		const struct guehdr *guehdr2;
397 
398 		if (!NAPI_GRO_CB(p)->same_flow)
399 			continue;
400 
401 		guehdr2 = (struct guehdr *)(p->data + off);
402 
403 		/* Compare the base GUE headers for equality (covers
404 		 * hlen, version, proto_ctype, and flags).
405 		 */
406 		if (guehdr->word != guehdr2->word) {
407 			NAPI_GRO_CB(p)->same_flow = 0;
408 			continue;
409 		}
410 
411 		/* Check that the optional fields are the same. */
412 		if (guehdr->hlen && memcmp(&guehdr[1], &guehdr2[1],
413 					   guehdr->hlen << 2)) {
414 			NAPI_GRO_CB(p)->same_flow = 0;
415 			continue;
416 		}
417 	}
418 
419 	proto = guehdr->proto_ctype;
420 
421 next_proto:
422 
423 	/* We can clear the encap_mark for GUE as we are essentially doing
424 	 * one of two possible things.  We are either adding an L4 tunnel
425 	 * header to the outer L3 tunnel header, or we are simply
426 	 * treating the GRE tunnel header as though it is a UDP protocol
427 	 * specific header such as VXLAN or GENEVE.
428 	 */
429 	NAPI_GRO_CB(skb)->encap_mark = 0;
430 
431 	/* Flag this frame as already having an outer encap header */
432 	NAPI_GRO_CB(skb)->is_fou = 1;
433 
434 	offloads = NAPI_GRO_CB(skb)->is_ipv6 ? inet6_offloads : inet_offloads;
435 	ops = rcu_dereference(offloads[proto]);
436 	if (WARN_ON_ONCE(!ops || !ops->callbacks.gro_receive))
437 		goto out;
438 
439 	pp = call_gro_receive(ops->callbacks.gro_receive, head, skb);
440 	flush = 0;
441 
442 out:
443 	skb_gro_flush_final_remcsum(skb, pp, flush, &grc);
444 
445 	return pp;
446 }
447 
448 static int gue_gro_complete(struct sock *sk, struct sk_buff *skb, int nhoff)
449 {
450 	struct guehdr *guehdr = (struct guehdr *)(skb->data + nhoff);
451 	const struct net_offload __rcu **offloads;
452 	const struct net_offload *ops;
453 	unsigned int guehlen = 0;
454 	u8 proto;
455 	int err = -ENOENT;
456 
457 	switch (guehdr->version) {
458 	case 0:
459 		proto = guehdr->proto_ctype;
460 		guehlen = sizeof(*guehdr) + (guehdr->hlen << 2);
461 		break;
462 	case 1:
463 		switch (((struct iphdr *)guehdr)->version) {
464 		case 4:
465 			proto = IPPROTO_IPIP;
466 			break;
467 		case 6:
468 			proto = IPPROTO_IPV6;
469 			break;
470 		default:
471 			return err;
472 		}
473 		break;
474 	default:
475 		return err;
476 	}
477 
478 	offloads = NAPI_GRO_CB(skb)->is_ipv6 ? inet6_offloads : inet_offloads;
479 	ops = rcu_dereference(offloads[proto]);
480 	if (WARN_ON(!ops || !ops->callbacks.gro_complete))
481 		goto out;
482 
483 	err = ops->callbacks.gro_complete(skb, nhoff + guehlen);
484 
485 	skb_set_inner_mac_header(skb, nhoff + guehlen);
486 
487 out:
488 	return err;
489 }
490 
491 static bool fou_cfg_cmp(struct fou *fou, struct fou_cfg *cfg)
492 {
493 	struct sock *sk = fou->sock->sk;
494 	struct udp_port_cfg *udp_cfg = &cfg->udp_config;
495 
496 	if (fou->family != udp_cfg->family ||
497 	    fou->port != udp_cfg->local_udp_port ||
498 	    sk->sk_dport != udp_cfg->peer_udp_port ||
499 	    sk->sk_bound_dev_if != udp_cfg->bind_ifindex)
500 		return false;
501 
502 	if (fou->family == AF_INET) {
503 		if (sk->sk_rcv_saddr != udp_cfg->local_ip.s_addr ||
504 		    sk->sk_daddr != udp_cfg->peer_ip.s_addr)
505 			return false;
506 		else
507 			return true;
508 #if IS_ENABLED(CONFIG_IPV6)
509 	} else {
510 		if (ipv6_addr_cmp(&sk->sk_v6_rcv_saddr, &udp_cfg->local_ip6) ||
511 		    ipv6_addr_cmp(&sk->sk_v6_daddr, &udp_cfg->peer_ip6))
512 			return false;
513 		else
514 			return true;
515 #endif
516 	}
517 
518 	return false;
519 }
520 
521 static int fou_add_to_port_list(struct net *net, struct fou *fou,
522 				struct fou_cfg *cfg)
523 {
524 	struct fou_net *fn = net_generic(net, fou_net_id);
525 	struct fou *fout;
526 
527 	mutex_lock(&fn->fou_lock);
528 	list_for_each_entry(fout, &fn->fou_list, list) {
529 		if (fou_cfg_cmp(fout, cfg)) {
530 			mutex_unlock(&fn->fou_lock);
531 			return -EALREADY;
532 		}
533 	}
534 
535 	list_add(&fou->list, &fn->fou_list);
536 	mutex_unlock(&fn->fou_lock);
537 
538 	return 0;
539 }
540 
541 static void fou_release(struct fou *fou)
542 {
543 	struct socket *sock = fou->sock;
544 
545 	list_del(&fou->list);
546 	udp_tunnel_sock_release(sock);
547 
548 	kfree_rcu(fou, rcu);
549 }
550 
551 static int fou_create(struct net *net, struct fou_cfg *cfg,
552 		      struct socket **sockp)
553 {
554 	struct socket *sock = NULL;
555 	struct fou *fou = NULL;
556 	struct sock *sk;
557 	struct udp_tunnel_sock_cfg tunnel_cfg;
558 	int err;
559 
560 	/* Open UDP socket */
561 	err = udp_sock_create(net, &cfg->udp_config, &sock);
562 	if (err < 0)
563 		goto error;
564 
565 	/* Allocate FOU port structure */
566 	fou = kzalloc(sizeof(*fou), GFP_KERNEL);
567 	if (!fou) {
568 		err = -ENOMEM;
569 		goto error;
570 	}
571 
572 	sk = sock->sk;
573 
574 	fou->port = cfg->udp_config.local_udp_port;
575 	fou->family = cfg->udp_config.family;
576 	fou->flags = cfg->flags;
577 	fou->type = cfg->type;
578 	fou->sock = sock;
579 
580 	memset(&tunnel_cfg, 0, sizeof(tunnel_cfg));
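	/* A non-zero encap_type enables the UDP encap_rcv path on this
	 * socket, so received packets are handed to the callbacks below.
	 */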
581 	tunnel_cfg.encap_type = 1;
582 	tunnel_cfg.sk_user_data = fou;
583 	tunnel_cfg.encap_destroy = NULL;
584 
585 	/* Initialize the encap callbacks based on the FOU type */
586 	switch (cfg->type) {
587 	case FOU_ENCAP_DIRECT:
588 		tunnel_cfg.encap_rcv = fou_udp_recv;
589 		tunnel_cfg.gro_receive = fou_gro_receive;
590 		tunnel_cfg.gro_complete = fou_gro_complete;
591 		fou->protocol = cfg->protocol;
592 		break;
593 	case FOU_ENCAP_GUE:
594 		tunnel_cfg.encap_rcv = gue_udp_recv;
595 		tunnel_cfg.gro_receive = gue_gro_receive;
596 		tunnel_cfg.gro_complete = gue_gro_complete;
597 		break;
598 	default:
599 		err = -EINVAL;
600 		goto error;
601 	}
602 
603 	setup_udp_tunnel_sock(net, sock, &tunnel_cfg);
604 
605 	sk->sk_allocation = GFP_ATOMIC;
606 
607 	err = fou_add_to_port_list(net, fou, cfg);
608 	if (err)
609 		goto error;
610 
611 	if (sockp)
612 		*sockp = sock;
613 
614 	return 0;
615 
616 error:
617 	kfree(fou);
618 	if (sock)
619 		udp_tunnel_sock_release(sock);
620 
621 	return err;
622 }
623 
624 static int fou_destroy(struct net *net, struct fou_cfg *cfg)
625 {
626 	struct fou_net *fn = net_generic(net, fou_net_id);
627 	int err = -EINVAL;
628 	struct fou *fou;
629 
630 	mutex_lock(&fn->fou_lock);
631 	list_for_each_entry(fou, &fn->fou_list, list) {
632 		if (fou_cfg_cmp(fou, cfg)) {
633 			fou_release(fou);
634 			err = 0;
635 			break;
636 		}
637 	}
638 	mutex_unlock(&fn->fou_lock);
639 
640 	return err;
641 }
642 
643 static struct genl_family fou_nl_family;
644 
645 static int parse_nl_config(struct genl_info *info,
646 			   struct fou_cfg *cfg)
647 {
648 	bool has_local = false, has_peer = false;
649 	struct nlattr *attr;
650 	int ifindex;
651 	__be16 port;
652 
653 	memset(cfg, 0, sizeof(*cfg));
654 
655 	cfg->udp_config.family = AF_INET;
656 
657 	if (info->attrs[FOU_ATTR_AF]) {
658 		u8 family = nla_get_u8(info->attrs[FOU_ATTR_AF]);
659 
660 		switch (family) {
661 		case AF_INET:
662 			break;
663 		case AF_INET6:
664 			cfg->udp_config.ipv6_v6only = 1;
665 			break;
666 		default:
667 			return -EAFNOSUPPORT;
668 		}
669 
670 		cfg->udp_config.family = family;
671 	}
672 
673 	if (info->attrs[FOU_ATTR_PORT]) {
674 		port = nla_get_be16(info->attrs[FOU_ATTR_PORT]);
675 		cfg->udp_config.local_udp_port = port;
676 	}
677 
678 	if (info->attrs[FOU_ATTR_IPPROTO])
679 		cfg->protocol = nla_get_u8(info->attrs[FOU_ATTR_IPPROTO]);
680 
681 	if (info->attrs[FOU_ATTR_TYPE])
682 		cfg->type = nla_get_u8(info->attrs[FOU_ATTR_TYPE]);
683 
684 	if (info->attrs[FOU_ATTR_REMCSUM_NOPARTIAL])
685 		cfg->flags |= FOU_F_REMCSUM_NOPARTIAL;
686 
687 	if (cfg->udp_config.family == AF_INET) {
688 		if (info->attrs[FOU_ATTR_LOCAL_V4]) {
689 			attr = info->attrs[FOU_ATTR_LOCAL_V4];
690 			cfg->udp_config.local_ip.s_addr = nla_get_in_addr(attr);
691 			has_local = true;
692 		}
693 
694 		if (info->attrs[FOU_ATTR_PEER_V4]) {
695 			attr = info->attrs[FOU_ATTR_PEER_V4];
696 			cfg->udp_config.peer_ip.s_addr = nla_get_in_addr(attr);
697 			has_peer = true;
698 		}
699 #if IS_ENABLED(CONFIG_IPV6)
700 	} else {
701 		if (info->attrs[FOU_ATTR_LOCAL_V6]) {
702 			attr = info->attrs[FOU_ATTR_LOCAL_V6];
703 			cfg->udp_config.local_ip6 = nla_get_in6_addr(attr);
704 			has_local = true;
705 		}
706 
707 		if (info->attrs[FOU_ATTR_PEER_V6]) {
708 			attr = info->attrs[FOU_ATTR_PEER_V6];
709 			cfg->udp_config.peer_ip6 = nla_get_in6_addr(attr);
710 			has_peer = true;
711 		}
712 #endif
713 	}
714 
715 	if (has_peer) {
716 		if (info->attrs[FOU_ATTR_PEER_PORT]) {
717 			port = nla_get_be16(info->attrs[FOU_ATTR_PEER_PORT]);
718 			cfg->udp_config.peer_udp_port = port;
719 		} else {
720 			return -EINVAL;
721 		}
722 	}
723 
724 	if (info->attrs[FOU_ATTR_IFINDEX]) {
725 		if (!has_local)
726 			return -EINVAL;
727 
728 		ifindex = nla_get_s32(info->attrs[FOU_ATTR_IFINDEX]);
729 
730 		cfg->udp_config.bind_ifindex = ifindex;
731 	}
732 
733 	return 0;
734 }
735 
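/* Generic netlink handlers backing the FOU configuration interface.
 * These are typically exercised through iproute2, e.g.:
 *
 *   ip fou add port 7777 ipproto 47   # plain FOU, decap GRE-in-UDP
 *   ip fou add port 5555 gue          # GUE receive port
 *   ip fou del port 5555
 *
 * (command syntax as documented in ip-fou(8))
 */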
736 int fou_nl_add_doit(struct sk_buff *skb, struct genl_info *info)
737 {
738 	struct net *net = genl_info_net(info);
739 	struct fou_cfg cfg;
740 	int err;
741 
742 	err = parse_nl_config(info, &cfg);
743 	if (err)
744 		return err;
745 
746 	return fou_create(net, &cfg, NULL);
747 }
748 
749 int fou_nl_del_doit(struct sk_buff *skb, struct genl_info *info)
750 {
751 	struct net *net = genl_info_net(info);
752 	struct fou_cfg cfg;
753 	int err;
754 
755 	err = parse_nl_config(info, &cfg);
756 	if (err)
757 		return err;
758 
759 	return fou_destroy(net, &cfg);
760 }
761 
762 static int fou_fill_info(struct fou *fou, struct sk_buff *msg)
763 {
764 	struct sock *sk = fou->sock->sk;
765 
766 	if (nla_put_u8(msg, FOU_ATTR_AF, fou->sock->sk->sk_family) ||
767 	    nla_put_be16(msg, FOU_ATTR_PORT, fou->port) ||
768 	    nla_put_be16(msg, FOU_ATTR_PEER_PORT, sk->sk_dport) ||
769 	    nla_put_u8(msg, FOU_ATTR_IPPROTO, fou->protocol) ||
770 	    nla_put_u8(msg, FOU_ATTR_TYPE, fou->type) ||
771 	    nla_put_s32(msg, FOU_ATTR_IFINDEX, sk->sk_bound_dev_if))
772 		return -1;
773 
774 	if (fou->flags & FOU_F_REMCSUM_NOPARTIAL)
775 		if (nla_put_flag(msg, FOU_ATTR_REMCSUM_NOPARTIAL))
776 			return -1;
777 
778 	if (fou->sock->sk->sk_family == AF_INET) {
779 		if (nla_put_in_addr(msg, FOU_ATTR_LOCAL_V4, sk->sk_rcv_saddr))
780 			return -1;
781 
782 		if (nla_put_in_addr(msg, FOU_ATTR_PEER_V4, sk->sk_daddr))
783 			return -1;
784 #if IS_ENABLED(CONFIG_IPV6)
785 	} else {
786 		if (nla_put_in6_addr(msg, FOU_ATTR_LOCAL_V6,
787 				     &sk->sk_v6_rcv_saddr))
788 			return -1;
789 
790 		if (nla_put_in6_addr(msg, FOU_ATTR_PEER_V6, &sk->sk_v6_daddr))
791 			return -1;
792 #endif
793 	}
794 
795 	return 0;
796 }
797 
798 static int fou_dump_info(struct fou *fou, u32 portid, u32 seq,
799 			 u32 flags, struct sk_buff *skb, u8 cmd)
800 {
801 	void *hdr;
802 
803 	hdr = genlmsg_put(skb, portid, seq, &fou_nl_family, flags, cmd);
804 	if (!hdr)
805 		return -ENOMEM;
806 
807 	if (fou_fill_info(fou, skb) < 0)
808 		goto nla_put_failure;
809 
810 	genlmsg_end(skb, hdr);
811 	return 0;
812 
813 nla_put_failure:
814 	genlmsg_cancel(skb, hdr);
815 	return -EMSGSIZE;
816 }
817 
818 int fou_nl_get_doit(struct sk_buff *skb, struct genl_info *info)
819 {
820 	struct net *net = genl_info_net(info);
821 	struct fou_net *fn = net_generic(net, fou_net_id);
822 	struct sk_buff *msg;
823 	struct fou_cfg cfg;
824 	struct fou *fout;
825 	__be16 port;
826 	u8 family;
827 	int ret;
828 
829 	ret = parse_nl_config(info, &cfg);
830 	if (ret)
831 		return ret;
832 	port = cfg.udp_config.local_udp_port;
833 	if (port == 0)
834 		return -EINVAL;
835 
836 	family = cfg.udp_config.family;
837 	if (family != AF_INET && family != AF_INET6)
838 		return -EINVAL;
839 
840 	msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL);
841 	if (!msg)
842 		return -ENOMEM;
843 
844 	ret = -ESRCH;
845 	mutex_lock(&fn->fou_lock);
846 	list_for_each_entry(fout, &fn->fou_list, list) {
847 		if (fou_cfg_cmp(fout, &cfg)) {
848 			ret = fou_dump_info(fout, info->snd_portid,
849 					    info->snd_seq, 0, msg,
850 					    info->genlhdr->cmd);
851 			break;
852 		}
853 	}
854 	mutex_unlock(&fn->fou_lock);
855 	if (ret < 0)
856 		goto out_free;
857 
858 	return genlmsg_reply(msg, info);
859 
860 out_free:
861 	nlmsg_free(msg);
862 	return ret;
863 }
864 
865 int fou_nl_get_dumpit(struct sk_buff *skb, struct netlink_callback *cb)
866 {
867 	struct net *net = sock_net(skb->sk);
868 	struct fou_net *fn = net_generic(net, fou_net_id);
869 	struct fou *fout;
870 	int idx = 0, ret;
871 
872 	mutex_lock(&fn->fou_lock);
873 	list_for_each_entry(fout, &fn->fou_list, list) {
874 		if (idx++ < cb->args[0])
875 			continue;
876 		ret = fou_dump_info(fout, NETLINK_CB(cb->skb).portid,
877 				    cb->nlh->nlmsg_seq, NLM_F_MULTI,
878 				    skb, FOU_CMD_GET);
879 		if (ret)
880 			break;
881 	}
882 	mutex_unlock(&fn->fou_lock);
883 
884 	cb->args[0] = idx;
885 	return skb->len;
886 }
887 
888 static struct genl_family fou_nl_family __ro_after_init = {
889 	.hdrsize	= 0,
890 	.name		= FOU_GENL_NAME,
891 	.version	= FOU_GENL_VERSION,
892 	.maxattr	= FOU_ATTR_MAX,
893 	.policy		= fou_nl_policy,
894 	.netnsok	= true,
895 	.module		= THIS_MODULE,
896 	.small_ops	= fou_nl_ops,
897 	.n_small_ops	= ARRAY_SIZE(fou_nl_ops),
898 	.resv_start_op	= FOU_CMD_GET + 1,
899 };
900 
901 size_t fou_encap_hlen(struct ip_tunnel_encap *e)
902 {
903 	return sizeof(struct udphdr);
904 }
905 EXPORT_SYMBOL(fou_encap_hlen);
906 
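/* For example, with TUNNEL_ENCAP_FLAG_REMCSUM set the GUE encap
 * overhead is sizeof(struct udphdr) + sizeof(struct guehdr) +
 * GUE_LEN_PRIV + GUE_PLEN_REMCSUM (8 + 4 + 4 + 4 = 20 bytes); without
 * it, just the UDP header plus the base GUE header (12 bytes).
 */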
907 size_t gue_encap_hlen(struct ip_tunnel_encap *e)
908 {
909 	size_t len;
910 	bool need_priv = false;
911 
912 	len = sizeof(struct udphdr) + sizeof(struct guehdr);
913 
914 	if (e->flags & TUNNEL_ENCAP_FLAG_REMCSUM) {
915 		len += GUE_PLEN_REMCSUM;
916 		need_priv = true;
917 	}
918 
919 	len += need_priv ? GUE_LEN_PRIV : 0;
920 
921 	return len;
922 }
923 EXPORT_SYMBOL(gue_encap_hlen);
924 
925 int __fou_build_header(struct sk_buff *skb, struct ip_tunnel_encap *e,
926 		       u8 *protocol, __be16 *sport, int type)
927 {
928 	int err;
929 
930 	err = iptunnel_handle_offloads(skb, type);
931 	if (err)
932 		return err;
933 
934 	*sport = e->sport ? : udp_flow_src_port(dev_net(skb->dev),
935 						skb, 0, 0, false);
936 
937 	return 0;
938 }
939 EXPORT_SYMBOL(__fou_build_header);
940 
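/* Push the GUE header (and any private options) in front of the inner
 * packet. The outer UDP header is added separately by the caller (see
 * fou_build_udp() below for the ip_tunnel case), giving an on-wire
 * layout of:
 *
 *   outer IP | UDP | GUE | [priv flags | remcsum data] | inner packet
 */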
941 int __gue_build_header(struct sk_buff *skb, struct ip_tunnel_encap *e,
942 		       u8 *protocol, __be16 *sport, int type)
943 {
944 	struct guehdr *guehdr;
945 	size_t hdrlen, optlen = 0;
946 	void *data;
947 	bool need_priv = false;
948 	int err;
949 
950 	if ((e->flags & TUNNEL_ENCAP_FLAG_REMCSUM) &&
951 	    skb->ip_summed == CHECKSUM_PARTIAL) {
952 		optlen += GUE_PLEN_REMCSUM;
953 		type |= SKB_GSO_TUNNEL_REMCSUM;
954 		need_priv = true;
955 	}
956 
957 	optlen += need_priv ? GUE_LEN_PRIV : 0;
958 
959 	err = iptunnel_handle_offloads(skb, type);
960 	if (err)
961 		return err;
962 
963 	/* Get source port (based on flow hash) before skb_push */
964 	*sport = e->sport ? : udp_flow_src_port(dev_net(skb->dev),
965 						skb, 0, 0, false);
966 
967 	hdrlen = sizeof(struct guehdr) + optlen;
968 
969 	skb_push(skb, hdrlen);
970 
971 	guehdr = (struct guehdr *)skb->data;
972 
973 	guehdr->control = 0;
974 	guehdr->version = 0;
975 	guehdr->hlen = optlen >> 2;
976 	guehdr->flags = 0;
977 	guehdr->proto_ctype = *protocol;
978 
979 	data = &guehdr[1];
980 
981 	if (need_priv) {
982 		__be32 *flags = data;
983 
984 		guehdr->flags |= GUE_FLAG_PRIV;
985 		*flags = 0;
986 		data += GUE_LEN_PRIV;
987 
988 		if (type & SKB_GSO_TUNNEL_REMCSUM) {
989 			u16 csum_start = skb_checksum_start_offset(skb);
990 			__be16 *pd = data;
991 
992 			if (csum_start < hdrlen)
993 				return -EINVAL;
994 
995 			csum_start -= hdrlen;
996 			pd[0] = htons(csum_start);
997 			pd[1] = htons(csum_start + skb->csum_offset);
998 
999 			if (!skb_is_gso(skb)) {
1000 				skb->ip_summed = CHECKSUM_NONE;
1001 				skb->encapsulation = 0;
1002 			}
1003 
1004 			*flags |= GUE_PFLAG_REMCSUM;
1005 			data += GUE_PLEN_REMCSUM;
1006 		}
1007 
1008 	}
1009 
1010 	return 0;
1011 }
1012 EXPORT_SYMBOL(__gue_build_header);
1013 
1014 #ifdef CONFIG_NET_FOU_IP_TUNNELS
1015 
1016 static void fou_build_udp(struct sk_buff *skb, struct ip_tunnel_encap *e,
1017 			  struct flowi4 *fl4, u8 *protocol, __be16 sport)
1018 {
1019 	struct udphdr *uh;
1020 
1021 	skb_push(skb, sizeof(struct udphdr));
1022 	skb_reset_transport_header(skb);
1023 
1024 	uh = udp_hdr(skb);
1025 
1026 	uh->dest = e->dport;
1027 	uh->source = sport;
1028 	uh->len = htons(skb->len);
1029 	udp_set_csum(!(e->flags & TUNNEL_ENCAP_FLAG_CSUM), skb,
1030 		     fl4->saddr, fl4->daddr, skb->len);
1031 
1032 	*protocol = IPPROTO_UDP;
1033 }
1034 
1035 static int fou_build_header(struct sk_buff *skb, struct ip_tunnel_encap *e,
1036 			    u8 *protocol, struct flowi4 *fl4)
1037 {
1038 	int type = e->flags & TUNNEL_ENCAP_FLAG_CSUM ? SKB_GSO_UDP_TUNNEL_CSUM :
1039 						       SKB_GSO_UDP_TUNNEL;
1040 	__be16 sport;
1041 	int err;
1042 
1043 	err = __fou_build_header(skb, e, protocol, &sport, type);
1044 	if (err)
1045 		return err;
1046 
1047 	fou_build_udp(skb, e, fl4, protocol, sport);
1048 
1049 	return 0;
1050 }
1051 
1052 static int gue_build_header(struct sk_buff *skb, struct ip_tunnel_encap *e,
1053 			    u8 *protocol, struct flowi4 *fl4)
1054 {
1055 	int type = e->flags & TUNNEL_ENCAP_FLAG_CSUM ? SKB_GSO_UDP_TUNNEL_CSUM :
1056 						       SKB_GSO_UDP_TUNNEL;
1057 	__be16 sport;
1058 	int err;
1059 
1060 	err = __gue_build_header(skb, e, protocol, &sport, type);
1061 	if (err)
1062 		return err;
1063 
1064 	fou_build_udp(skb, e, fl4, protocol, sport);
1065 
1066 	return 0;
1067 }
1068 
1069 static int gue_err_proto_handler(int proto, struct sk_buff *skb, u32 info)
1070 {
1071 	const struct net_protocol *ipprot = rcu_dereference(inet_protos[proto]);
1072 
1073 	if (ipprot && ipprot->err_handler) {
1074 		if (!ipprot->err_handler(skb, info))
1075 			return 0;
1076 	}
1077 
1078 	return -ENOENT;
1079 }
1080 
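/* Handle an ICMP error received for the outer UDP encapsulation. The
 * skb here contains the original packet quoted in the ICMP payload, so
 * parse its GUE header and pass the error on to the err_handler of the
 * encapsulated protocol.
 */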
1081 static int gue_err(struct sk_buff *skb, u32 info)
1082 {
1083 	int transport_offset = skb_transport_offset(skb);
1084 	struct guehdr *guehdr;
1085 	size_t len, optlen;
1086 	int ret;
1087 
1088 	len = sizeof(struct udphdr) + sizeof(struct guehdr);
1089 	if (!pskb_may_pull(skb, transport_offset + len))
1090 		return -EINVAL;
1091 
1092 	guehdr = (struct guehdr *)&udp_hdr(skb)[1];
1093 
1094 	switch (guehdr->version) {
1095 	case 0: /* Full GUE header present */
1096 		break;
1097 	case 1: {
1098 		/* Direct encapsulation of IPv4 or IPv6 */
1099 		skb_set_transport_header(skb, -(int)sizeof(struct icmphdr));
1100 
1101 		switch (((struct iphdr *)guehdr)->version) {
1102 		case 4:
1103 			ret = gue_err_proto_handler(IPPROTO_IPIP, skb, info);
1104 			goto out;
1105 #if IS_ENABLED(CONFIG_IPV6)
1106 		case 6:
1107 			ret = gue_err_proto_handler(IPPROTO_IPV6, skb, info);
1108 			goto out;
1109 #endif
1110 		default:
1111 			ret = -EOPNOTSUPP;
1112 			goto out;
1113 		}
1114 	}
1115 	default: /* Undefined version */
1116 		return -EOPNOTSUPP;
1117 	}
1118 
1119 	if (guehdr->control)
1120 		return -ENOENT;
1121 
1122 	optlen = guehdr->hlen << 2;
1123 
1124 	if (!pskb_may_pull(skb, transport_offset + len + optlen))
1125 		return -EINVAL;
1126 
1127 	guehdr = (struct guehdr *)&udp_hdr(skb)[1];
1128 	if (validate_gue_flags(guehdr, optlen))
1129 		return -EINVAL;
1130 
1131 	/* Handling exceptions for direct UDP encapsulation in GUE would lead to
1132 	 * recursion. Besides, this kind of encapsulation can't even be
1133 	 * configured currently. Discard this.
1134 	 */
1135 	if (guehdr->proto_ctype == IPPROTO_UDP ||
1136 	    guehdr->proto_ctype == IPPROTO_UDPLITE)
1137 		return -EOPNOTSUPP;
1138 
1139 	skb_set_transport_header(skb, -(int)sizeof(struct icmphdr));
1140 	ret = gue_err_proto_handler(guehdr->proto_ctype, skb, info);
1141 
1142 out:
1143 	skb_set_transport_header(skb, transport_offset);
1144 	return ret;
1145 }
1146 
1147 
1148 static const struct ip_tunnel_encap_ops fou_iptun_ops = {
1149 	.encap_hlen = fou_encap_hlen,
1150 	.build_header = fou_build_header,
1151 	.err_handler = gue_err,
1152 };
1153 
1154 static const struct ip_tunnel_encap_ops gue_iptun_ops = {
1155 	.encap_hlen = gue_encap_hlen,
1156 	.build_header = gue_build_header,
1157 	.err_handler = gue_err,
1158 };
1159 
1160 static int ip_tunnel_encap_add_fou_ops(void)
1161 {
1162 	int ret;
1163 
1164 	ret = ip_tunnel_encap_add_ops(&fou_iptun_ops, TUNNEL_ENCAP_FOU);
1165 	if (ret < 0) {
1166 		pr_err("can't add fou ops\n");
1167 		return ret;
1168 	}
1169 
1170 	ret = ip_tunnel_encap_add_ops(&gue_iptun_ops, TUNNEL_ENCAP_GUE);
1171 	if (ret < 0) {
1172 		pr_err("can't add gue ops\n");
1173 		ip_tunnel_encap_del_ops(&fou_iptun_ops, TUNNEL_ENCAP_FOU);
1174 		return ret;
1175 	}
1176 
1177 	return 0;
1178 }
1179 
1180 static void ip_tunnel_encap_del_fou_ops(void)
1181 {
1182 	ip_tunnel_encap_del_ops(&fou_iptun_ops, TUNNEL_ENCAP_FOU);
1183 	ip_tunnel_encap_del_ops(&gue_iptun_ops, TUNNEL_ENCAP_GUE);
1184 }
1185 
1186 #else
1187 
1188 static int ip_tunnel_encap_add_fou_ops(void)
1189 {
1190 	return 0;
1191 }
1192 
1193 static void ip_tunnel_encap_del_fou_ops(void)
1194 {
1195 }
1196 
1197 #endif
1198 
1199 static __net_init int fou_init_net(struct net *net)
1200 {
1201 	struct fou_net *fn = net_generic(net, fou_net_id);
1202 
1203 	INIT_LIST_HEAD(&fn->fou_list);
1204 	mutex_init(&fn->fou_lock);
1205 	return 0;
1206 }
1207 
1208 static __net_exit void fou_exit_net(struct net *net)
1209 {
1210 	struct fou_net *fn = net_generic(net, fou_net_id);
1211 	struct fou *fou, *next;
1212 
1213 	/* Close all the FOU sockets */
1214 	mutex_lock(&fn->fou_lock);
1215 	list_for_each_entry_safe(fou, next, &fn->fou_list, list)
1216 		fou_release(fou);
1217 	mutex_unlock(&fn->fou_lock);
1218 }
1219 
1220 static struct pernet_operations fou_net_ops = {
1221 	.init = fou_init_net,
1222 	.exit = fou_exit_net,
1223 	.id   = &fou_net_id,
1224 	.size = sizeof(struct fou_net),
1225 };
1226 
1227 static int __init fou_init(void)
1228 {
1229 	int ret;
1230 
1231 	ret = register_pernet_device(&fou_net_ops);
1232 	if (ret)
1233 		goto exit;
1234 
1235 	ret = genl_register_family(&fou_nl_family);
1236 	if (ret < 0)
1237 		goto unregister;
1238 
1239 	ret = register_fou_bpf();
1240 	if (ret < 0)
1241 		goto kfunc_failed;
1242 
1243 	ret = ip_tunnel_encap_add_fou_ops();
1244 	if (ret == 0)
1245 		return 0;
1246 
1247 kfunc_failed:
1248 	genl_unregister_family(&fou_nl_family);
1249 unregister:
1250 	unregister_pernet_device(&fou_net_ops);
1251 exit:
1252 	return ret;
1253 }
1254 
1255 static void __exit fou_fini(void)
1256 {
1257 	ip_tunnel_encap_del_fou_ops();
1258 	genl_unregister_family(&fou_nl_family);
1259 	unregister_pernet_device(&fou_net_ops);
1260 }
1261 
1262 module_init(fou_init);
1263 module_exit(fou_fini);
1264 MODULE_AUTHOR("Tom Herbert <therbert@google.com>");
1265 MODULE_LICENSE("GPL");
1266 MODULE_DESCRIPTION("Foo over UDP");
1267