xref: /linux/net/ipv6/seg6_iptunnel.c (revision f9c52a6ba9780bd27e0bf4c044fd91c13c778b6e)
1 // SPDX-License-Identifier: GPL-2.0-or-later
2 /*
3  *  SR-IPv6 implementation
4  *
5  *  Author:
6  *  David Lebrun <david.lebrun@uclouvain.be>
7  */
8 
9 #include <linux/types.h>
10 #include <linux/skbuff.h>
11 #include <linux/net.h>
12 #include <linux/module.h>
13 #include <net/ip.h>
14 #include <net/ip_tunnels.h>
15 #include <net/lwtunnel.h>
16 #include <net/netevent.h>
17 #include <net/netns/generic.h>
18 #include <net/ip6_fib.h>
19 #include <net/route.h>
20 #include <net/seg6.h>
21 #include <linux/seg6.h>
22 #include <linux/seg6_iptunnel.h>
23 #include <net/addrconf.h>
24 #include <net/ip6_route.h>
25 #include <net/dst_cache.h>
26 #ifdef CONFIG_IPV6_SEG6_HMAC
27 #include <net/seg6_hmac.h>
28 #endif
29 #include <linux/netfilter.h>
30 
31 static size_t seg6_lwt_headroom(struct seg6_iptunnel_encap *tuninfo)
32 {
33 	int head = 0;
34 
35 	switch (tuninfo->mode) {
36 	case SEG6_IPTUN_MODE_INLINE:
37 		break;
38 	case SEG6_IPTUN_MODE_ENCAP:
39 	case SEG6_IPTUN_MODE_ENCAP_RED:
40 		head = sizeof(struct ipv6hdr);
41 		break;
42 	case SEG6_IPTUN_MODE_L2ENCAP:
43 	case SEG6_IPTUN_MODE_L2ENCAP_RED:
44 		return 0;
45 	}
46 
47 	return ((tuninfo->srh->hdrlen + 1) << 3) + head;
48 }
49 
/* Per-route SRv6 lwtunnel state, stored in lwtunnel_state::data.
 * tuninfo is a flexible array member carrying the netlink-provided encap
 * info (mode + SRH) and therefore must remain the last field.
 */
struct seg6_lwt {
	struct dst_cache cache_input;	/* cached dst for the input path */
	struct dst_cache cache_output;	/* cached dst for the output path */
	struct in6_addr tunsrc;		/* per-route tunnel source; any-addr when unset */
	struct seg6_iptunnel_encap tuninfo[];
};
56 
57 static inline struct seg6_lwt *seg6_lwt_lwtunnel(struct lwtunnel_state *lwt)
58 {
59 	return (struct seg6_lwt *)lwt->data;
60 }
61 
62 static inline struct seg6_iptunnel_encap *
63 seg6_encap_lwtunnel(struct lwtunnel_state *lwt)
64 {
65 	return seg6_lwt_lwtunnel(lwt)->tuninfo;
66 }
67 
/* Netlink attribute policy for SEG6 lwtunnel configuration.  The SRH blob
 * is NLA_BINARY here; its length and contents are checked later in
 * seg6_build_state() via seg6_validate_srh().
 */
static const struct nla_policy seg6_iptunnel_policy[SEG6_IPTUNNEL_MAX + 1] = {
	[SEG6_IPTUNNEL_SRH]	= { .type = NLA_BINARY },
	[SEG6_IPTUNNEL_SRC]	= NLA_POLICY_EXACT_LEN(sizeof(struct in6_addr)),
};
72 
73 static int nla_put_srh(struct sk_buff *skb, int attrtype,
74 		       struct seg6_iptunnel_encap *tuninfo)
75 {
76 	struct seg6_iptunnel_encap *data;
77 	struct nlattr *nla;
78 	int len;
79 
80 	len = SEG6_IPTUN_ENCAP_SIZE(tuninfo);
81 
82 	nla = nla_reserve(skb, attrtype, len);
83 	if (!nla)
84 		return -EMSGSIZE;
85 
86 	data = nla_data(nla);
87 	memcpy(data, tuninfo, len);
88 
89 	return 0;
90 }
91 
92 static void set_tun_src(struct net *net, struct net_device *dev,
93 			struct in6_addr *daddr, struct in6_addr *saddr,
94 			struct in6_addr *route_tunsrc)
95 {
96 	struct seg6_pernet_data *sdata = seg6_pernet(net);
97 	struct in6_addr *tun_src;
98 
99 	/* Priority order to select tunnel source address:
100 	 *  1. per route source address (if configured)
101 	 *  2. per network namespace source address (if configured)
102 	 *  3. dynamic resolution
103 	 */
104 	if (route_tunsrc && !ipv6_addr_any(route_tunsrc)) {
105 		memcpy(saddr, route_tunsrc, sizeof(struct in6_addr));
106 	} else {
107 		rcu_read_lock();
108 		tun_src = rcu_dereference(sdata->tun_src);
109 
110 		if (!ipv6_addr_any(tun_src)) {
111 			memcpy(saddr, tun_src, sizeof(struct in6_addr));
112 		} else {
113 			ipv6_dev_get_saddr(net, dev, daddr,
114 					   IPV6_PREFER_SRC_PUBLIC, saddr);
115 		}
116 
117 		rcu_read_unlock();
118 	}
119 }
120 
121 /* Compute flowlabel for outer IPv6 header */
122 static __be32 seg6_make_flowlabel(struct net *net, struct sk_buff *skb,
123 				  struct ipv6hdr *inner_hdr)
124 {
125 	int do_flowlabel = net->ipv6.sysctl.seg6_flowlabel;
126 	__be32 flowlabel = 0;
127 	u32 hash;
128 
129 	if (do_flowlabel > 0) {
130 		hash = skb_get_hash(skb);
131 		hash = rol32(hash, 16);
132 		flowlabel = (__force __be32)hash & IPV6_FLOWLABEL_MASK;
133 	} else if (!do_flowlabel && skb->protocol == htons(ETH_P_IPV6)) {
134 		flowlabel = ip6_flowlabel(inner_hdr);
135 	}
136 	return flowlabel;
137 }
138 
/* Push an outer IPv6 header followed by a full copy of @osrh in front of
 * the packet held in @skb.
 *
 * @proto: protocol of the inner packet, recorded in the SRH nexthdr field.
 * @cache_dst: cached egress dst (may be NULL); used only to account for
 *             the egress device overhead when growing the headroom.
 * @route_tunsrc: optional per-route tunnel source, forwarded to
 *                set_tun_src() for outer saddr selection.
 *
 * Returns 0 on success or a negative errno from skb_cow_head() /
 * seg6_push_hmac().
 */
static int __seg6_do_srh_encap(struct sk_buff *skb, struct ipv6_sr_hdr *osrh,
			       int proto, struct dst_entry *cache_dst,
			       struct in6_addr *route_tunsrc)
{
	struct dst_entry *dst = skb_dst(skb);
	struct net_device *dev = dst_dev(dst);
	struct net *net = dev_net(dev);
	struct ipv6hdr *hdr, *inner_hdr;
	struct ipv6_sr_hdr *isrh;
	int hdrlen, tot_len, err;
	__be32 flowlabel;

	/* SRH length is encoded as 8-byte units beyond the first 8 bytes */
	hdrlen = (osrh->hdrlen + 1) << 3;
	tot_len = hdrlen + sizeof(*hdr);

	err = skb_cow_head(skb, tot_len + dst_dev_overhead(cache_dst, skb));
	if (unlikely(err))
		return err;

	/* compute the flowlabel while the inner header is still on top */
	inner_hdr = ipv6_hdr(skb);
	flowlabel = seg6_make_flowlabel(net, skb, inner_hdr);

	skb_push(skb, tot_len);
	skb_reset_network_header(skb);
	skb_mac_header_rebuild(skb);
	hdr = ipv6_hdr(skb);

	/* inherit tc, flowlabel and hlim
	 * hlim will be decremented in ip6_forward() afterwards and
	 * decapsulation will overwrite inner hlim with outer hlim
	 */

	if (skb->protocol == htons(ETH_P_IPV6)) {
		ip6_flow_hdr(hdr, ip6_tclass(ip6_flowinfo(inner_hdr)),
			     flowlabel);
		hdr->hop_limit = inner_hdr->hop_limit;
	} else {
		ip6_flow_hdr(hdr, 0, flowlabel);
		hdr->hop_limit = ip6_dst_hoplimit(skb_dst(skb));

		memset(IP6CB(skb), 0, sizeof(*IP6CB(skb)));

		/* the control block has been erased, so we have to set the
		 * iif once again.
		 * We read the receiving interface index directly from the
		 * skb->skb_iif as it is done in the IPv4 receiving path (i.e.:
		 * ip_rcv_core(...)).
		 */
		IP6CB(skb)->iif = skb->skb_iif;
	}

	hdr->nexthdr = NEXTHDR_ROUTING;

	/* the SRH sits immediately after the new outer IPv6 header */
	isrh = (void *)hdr + sizeof(*hdr);
	memcpy(isrh, osrh, hdrlen);

	isrh->nexthdr = proto;

	/* outer daddr is the active segment, i.e. segments[first_segment] */
	hdr->daddr = isrh->segments[isrh->first_segment];
	set_tun_src(net, dev, &hdr->daddr, &hdr->saddr, route_tunsrc);

#ifdef CONFIG_IPV6_SEG6_HMAC
	if (sr_has_hmac(isrh)) {
		err = seg6_push_hmac(net, &hdr->saddr, isrh);
		if (unlikely(err))
			return err;
	}
#endif

	hdr->payload_len = htons(skb->len - sizeof(struct ipv6hdr));

	skb_postpush_rcsum(skb, hdr, tot_len);

	return 0;
}
214 
/* encapsulate an IPv6 packet within an outer IPv6 header with a given SRH;
 * exported entry point: no cached-dst headroom hint and no per-route
 * tunnel source address.
 */
int seg6_do_srh_encap(struct sk_buff *skb, struct ipv6_sr_hdr *osrh, int proto)
{
	return __seg6_do_srh_encap(skb, osrh, proto, NULL, NULL);
}
EXPORT_SYMBOL_GPL(seg6_do_srh_encap);
221 
/* encapsulate a packet within an outer IPv6 header with a reduced SRH:
 * the active segment is carried only in the outer daddr and stripped from
 * the pushed SRH.  If the reduced SRH would carry nothing useful (single
 * segment, no HMAC TLV) it is skipped entirely.
 *
 * Parameters match __seg6_do_srh_encap(); returns 0 or a negative errno.
 */
static int seg6_do_srh_encap_red(struct sk_buff *skb,
				 struct ipv6_sr_hdr *osrh, int proto,
				 struct dst_entry *cache_dst,
				 struct in6_addr *route_tunsrc)
{
	__u8 first_seg = osrh->first_segment;
	struct dst_entry *dst = skb_dst(skb);
	struct net_device *dev = dst_dev(dst);
	struct net *net = dev_net(dev);
	struct ipv6hdr *hdr, *inner_hdr;
	int hdrlen = ipv6_optlen(osrh);
	int red_tlv_offset, tlv_offset;
	struct ipv6_sr_hdr *isrh;
	bool skip_srh = false;
	__be32 flowlabel;
	int tot_len, err;
	int red_hdrlen;
	int tlvs_len;

	if (first_seg > 0) {
		/* the reduced SRH omits one segment (the active one) */
		red_hdrlen = hdrlen - sizeof(struct in6_addr);
	} else {
		/* NOTE: if tag/flags and/or other TLVs are introduced in the
		 * seg6_iptunnel infrastructure, they should be considered when
		 * deciding to skip the SRH.
		 */
		skip_srh = !sr_has_hmac(osrh);

		red_hdrlen = skip_srh ? 0 : hdrlen;
	}

	tot_len = red_hdrlen + sizeof(struct ipv6hdr);

	err = skb_cow_head(skb, tot_len + dst_dev_overhead(cache_dst, skb));
	if (unlikely(err))
		return err;

	inner_hdr = ipv6_hdr(skb);
	flowlabel = seg6_make_flowlabel(net, skb, inner_hdr);

	skb_push(skb, tot_len);
	skb_reset_network_header(skb);
	skb_mac_header_rebuild(skb);
	hdr = ipv6_hdr(skb);

	/* based on seg6_do_srh_encap() */
	if (skb->protocol == htons(ETH_P_IPV6)) {
		ip6_flow_hdr(hdr, ip6_tclass(ip6_flowinfo(inner_hdr)),
			     flowlabel);
		hdr->hop_limit = inner_hdr->hop_limit;
	} else {
		ip6_flow_hdr(hdr, 0, flowlabel);
		hdr->hop_limit = ip6_dst_hoplimit(skb_dst(skb));

		memset(IP6CB(skb), 0, sizeof(*IP6CB(skb)));
		IP6CB(skb)->iif = skb->skb_iif;
	}

	/* no matter if we have to skip the SRH or not, the first segment
	 * always comes in the pushed IPv6 header.
	 */
	hdr->daddr = osrh->segments[first_seg];

	if (skip_srh) {
		hdr->nexthdr = proto;

		set_tun_src(net, dev, &hdr->daddr, &hdr->saddr, route_tunsrc);
		goto out;
	}

	/* we cannot skip the SRH, slow path */

	hdr->nexthdr = NEXTHDR_ROUTING;
	isrh = (void *)hdr + sizeof(struct ipv6hdr);

	if (unlikely(!first_seg)) {
		/* this is a very rare case; we have only one SID but
		 * we cannot skip the SRH since we are carrying some
		 * other info.
		 */
		memcpy(isrh, osrh, hdrlen);
		goto srcaddr;
	}

	/* copy the SRH minus the active segment, then move any trailing
	 * TLVs up so they follow the remaining segment list
	 */
	tlv_offset = sizeof(*osrh) + (first_seg + 1) * sizeof(struct in6_addr);
	red_tlv_offset = tlv_offset - sizeof(struct in6_addr);

	memcpy(isrh, osrh, red_tlv_offset);

	tlvs_len = hdrlen - tlv_offset;
	if (unlikely(tlvs_len > 0)) {
		const void *s = (const void *)osrh + tlv_offset;
		void *d = (void *)isrh + red_tlv_offset;

		memcpy(d, s, tlvs_len);
	}

	/* one 16-byte segment removed == two 8-byte hdrlen units */
	--isrh->first_segment;
	isrh->hdrlen -= 2;

srcaddr:
	isrh->nexthdr = proto;
	set_tun_src(net, dev, &hdr->daddr, &hdr->saddr, route_tunsrc);

#ifdef CONFIG_IPV6_SEG6_HMAC
	if (unlikely(!skip_srh && sr_has_hmac(isrh))) {
		err = seg6_push_hmac(net, &hdr->saddr, isrh);
		if (unlikely(err))
			return err;
	}
#endif

out:
	hdr->payload_len = htons(skb->len - sizeof(struct ipv6hdr));

	skb_postpush_rcsum(skb, hdr, tot_len);

	return 0;
}
342 
/* Insert @osrh between the existing IPv6 header and its payload (inline
 * mode): the original daddr is preserved in segments[0] and the active
 * segment becomes the new daddr.
 *
 * @cache_dst: cached egress dst (may be NULL); used only to size the
 *             headroom reservation.
 *
 * Returns 0 on success or a negative errno from skb_cow_head() /
 * seg6_push_hmac().
 */
static int __seg6_do_srh_inline(struct sk_buff *skb, struct ipv6_sr_hdr *osrh,
				struct dst_entry *cache_dst)
{
	struct ipv6hdr *hdr, *oldhdr;
	struct ipv6_sr_hdr *isrh;
	int hdrlen, err;

	/* SRH length is encoded as 8-byte units beyond the first 8 bytes */
	hdrlen = (osrh->hdrlen + 1) << 3;

	err = skb_cow_head(skb, hdrlen + dst_dev_overhead(cache_dst, skb));
	if (unlikely(err))
		return err;

	oldhdr = ipv6_hdr(skb);

	/* temporarily strip the IPv6 header ... */
	skb_pull(skb, sizeof(struct ipv6hdr));
	skb_postpull_rcsum(skb, skb_network_header(skb),
			   sizeof(struct ipv6hdr));

	/* ... then push it back with room for the SRH behind it */
	skb_push(skb, sizeof(struct ipv6hdr) + hdrlen);
	skb_reset_network_header(skb);
	skb_mac_header_rebuild(skb);

	hdr = ipv6_hdr(skb);

	memmove(hdr, oldhdr, sizeof(*hdr));

	isrh = (void *)hdr + sizeof(*hdr);
	memcpy(isrh, osrh, hdrlen);

	isrh->nexthdr = hdr->nexthdr;
	hdr->nexthdr = NEXTHDR_ROUTING;

	/* keep the original destination as the final segment entry */
	isrh->segments[0] = hdr->daddr;
	hdr->daddr = isrh->segments[isrh->first_segment];

#ifdef CONFIG_IPV6_SEG6_HMAC
	if (sr_has_hmac(isrh)) {
		struct net *net = skb_dst_dev_net(skb);

		err = seg6_push_hmac(net, &hdr->saddr, isrh);
		if (unlikely(err))
			return err;
	}
#endif

	hdr->payload_len = htons(skb->len - sizeof(struct ipv6hdr));

	skb_postpush_rcsum(skb, hdr, sizeof(struct ipv6hdr) + hdrlen);

	return 0;
}
395 
/* Apply the configured SRv6 transformation (inline, encap, reduced encap,
 * or L2 encap) to @skb according to the route's tunnel state.
 *
 * @cache_dst: cached egress dst (may be NULL); forwarded to the helpers
 *             only to size their headroom reservation.
 *
 * Returns 0 on success, -EINVAL on protocol/mode mismatch, -ENOMEM on
 * head expansion failure, or any error from the encap helpers.
 */
static int seg6_do_srh(struct sk_buff *skb, struct dst_entry *cache_dst)
{
	struct dst_entry *dst = skb_dst(skb);
	struct seg6_iptunnel_encap *tinfo;
	struct seg6_lwt *slwt;
	int proto, err = 0;

	slwt = seg6_lwt_lwtunnel(dst->lwtstate);
	tinfo = slwt->tuninfo;

	switch (tinfo->mode) {
	case SEG6_IPTUN_MODE_INLINE:
		/* SRH insertion only applies to IPv6 packets */
		if (skb->protocol != htons(ETH_P_IPV6))
			return -EINVAL;

		err = __seg6_do_srh_inline(skb, tinfo->srh, cache_dst);
		if (err)
			return err;
		break;
	case SEG6_IPTUN_MODE_ENCAP:
	case SEG6_IPTUN_MODE_ENCAP_RED:
		err = iptunnel_handle_offloads(skb, SKB_GSO_IPXIP6);
		if (err)
			return err;

		/* inner protocol value carried in the SRH nexthdr field */
		if (skb->protocol == htons(ETH_P_IPV6))
			proto = IPPROTO_IPV6;
		else if (skb->protocol == htons(ETH_P_IP))
			proto = IPPROTO_IPIP;
		else
			return -EINVAL;

		if (tinfo->mode == SEG6_IPTUN_MODE_ENCAP)
			err = __seg6_do_srh_encap(skb, tinfo->srh, proto,
						  cache_dst, &slwt->tunsrc);
		else
			err = seg6_do_srh_encap_red(skb, tinfo->srh, proto,
						    cache_dst, &slwt->tunsrc);

		if (err)
			return err;

		skb_set_inner_transport_header(skb, skb_transport_offset(skb));
		skb_set_inner_protocol(skb, skb->protocol);
		skb->protocol = htons(ETH_P_IPV6);
		break;
	case SEG6_IPTUN_MODE_L2ENCAP:
	case SEG6_IPTUN_MODE_L2ENCAP_RED:
		if (!skb_mac_header_was_set(skb))
			return -EINVAL;

		/* pull the L2 header into the encapsulated payload */
		if (pskb_expand_head(skb, skb->mac_len, 0, GFP_ATOMIC) < 0)
			return -ENOMEM;

		skb_mac_header_rebuild(skb);
		skb_push(skb, skb->mac_len);

		if (tinfo->mode == SEG6_IPTUN_MODE_L2ENCAP)
			err = __seg6_do_srh_encap(skb, tinfo->srh,
						  IPPROTO_ETHERNET, cache_dst,
						  &slwt->tunsrc);
		else
			err = seg6_do_srh_encap_red(skb, tinfo->srh,
						    IPPROTO_ETHERNET, cache_dst,
						    &slwt->tunsrc);

		if (err)
			return err;

		skb->protocol = htons(ETH_P_IPV6);
		break;
	}

	/* the outermost header is an IPv6 header in every mode */
	skb_set_transport_header(skb, sizeof(struct ipv6hdr));
	nf_reset_ct(skb);

	return 0;
}
474 
/* insert an SRH within an IPv6 packet, just after the IPv6 header;
 * exported entry point: no cached-dst headroom hint.
 */
int seg6_do_srh_inline(struct sk_buff *skb, struct ipv6_sr_hdr *osrh)
{
	return __seg6_do_srh_inline(skb, osrh, NULL);
}
EXPORT_SYMBOL_GPL(seg6_do_srh_inline);
481 
/* NF_HOOK okfn for the input path: hand the re-routed packet to its dst
 * input handler.  The (net, sk) parameters are required by the okfn
 * signature but unused here.
 */
static int seg6_input_finish(struct net *net, struct sock *sk,
			     struct sk_buff *skb)
{
	return dst_input(skb);
}
487 
/* Input-path core: apply the SRv6 transformation, then reuse the cached
 * input dst or re-route the transformed packet and cache the result
 * (unless caching would create a dst reference loop).  Consumes the skb
 * on error.
 */
static int seg6_input_core(struct net *net, struct sock *sk,
			   struct sk_buff *skb)
{
	struct dst_entry *orig_dst = skb_dst(skb);
	struct dst_entry *dst = NULL;
	struct lwtunnel_state *lwtst;
	struct seg6_lwt *slwt;
	int err;

	/* We cannot dereference "orig_dst" once ip6_route_input() or
	 * skb_dst_drop() is called. However, in order to detect a dst loop, we
	 * need the address of its lwtstate. So, save the address of lwtstate
	 * now and use it later as a comparison.
	 */
	lwtst = orig_dst->lwtstate;

	slwt = seg6_lwt_lwtunnel(lwtst);

	/* the dst cache is accessed with BHs disabled */
	local_bh_disable();
	dst = dst_cache_get(&slwt->cache_input);
	local_bh_enable();

	err = seg6_do_srh(skb, dst);
	if (unlikely(err)) {
		dst_release(dst);
		goto drop;
	}

	if (!dst) {
		ip6_route_input(skb);

		/* ip6_route_input() sets a NOREF dst; force a refcount on it
		 * before caching or further use.
		 */
		skb_dst_force(skb);
		dst = skb_dst(skb);
		if (unlikely(!dst)) {
			err = -ENETUNREACH;
			goto drop;
		}

		/* cache only if we don't create a dst reference loop */
		if (!dst->error && lwtst != dst->lwtstate) {
			local_bh_disable();
			dst_cache_set_ip6(&slwt->cache_input, dst,
					  &ipv6_hdr(skb)->saddr);
			local_bh_enable();
		}

		/* reserve room for the egress device's link-layer header */
		err = skb_cow_head(skb, LL_RESERVED_SPACE(dst_dev(dst)));
		if (unlikely(err))
			goto drop;
	} else {
		/* cached dst: replace the skb's dst with it */
		skb_dst_drop(skb);
		skb_dst_set(skb, dst);
	}

	if (static_branch_unlikely(&nf_hooks_lwtunnel_enabled))
		return NF_HOOK(NFPROTO_IPV6, NF_INET_LOCAL_OUT,
			       dev_net(skb->dev), NULL, skb, NULL,
			       skb_dst_dev(skb), seg6_input_finish);

	return seg6_input_finish(dev_net(skb->dev), NULL, skb);
drop:
	kfree_skb(skb);
	return err;
}
555 
556 static int seg6_input_nf(struct sk_buff *skb)
557 {
558 	struct net_device *dev = skb_dst_dev(skb);
559 	struct net *net = dev_net(skb->dev);
560 
561 	switch (skb->protocol) {
562 	case htons(ETH_P_IP):
563 		return NF_HOOK(NFPROTO_IPV4, NF_INET_POST_ROUTING, net, NULL,
564 			       skb, NULL, dev, seg6_input_core);
565 	case htons(ETH_P_IPV6):
566 		return NF_HOOK(NFPROTO_IPV6, NF_INET_POST_ROUTING, net, NULL,
567 			       skb, NULL, dev, seg6_input_core);
568 	}
569 
570 	return -EINVAL;
571 }
572 
573 static int seg6_input(struct sk_buff *skb)
574 {
575 	if (static_branch_unlikely(&nf_hooks_lwtunnel_enabled))
576 		return seg6_input_nf(skb);
577 
578 	return seg6_input_core(dev_net(skb->dev), NULL, skb);
579 }
580 
/* Output-path core: apply the SRv6 transformation, look up (or take from
 * the output cache) a route for the new outer header, then hand the
 * packet to dst_output().  Consumes the skb on error.
 */
static int seg6_output_core(struct net *net, struct sock *sk,
			    struct sk_buff *skb)
{
	struct dst_entry *orig_dst = skb_dst(skb);
	struct dst_entry *dst = NULL;
	struct seg6_lwt *slwt;
	int err;

	slwt = seg6_lwt_lwtunnel(orig_dst->lwtstate);

	/* the dst cache is accessed with BHs disabled */
	local_bh_disable();
	dst = dst_cache_get(&slwt->cache_output);
	local_bh_enable();

	err = seg6_do_srh(skb, dst);
	if (unlikely(err))
		goto drop;

	if (unlikely(!dst)) {
		/* no cached route: build a flow from the new outer header */
		struct ipv6hdr *hdr = ipv6_hdr(skb);
		struct flowi6 fl6;

		memset(&fl6, 0, sizeof(fl6));
		fl6.daddr = hdr->daddr;
		fl6.saddr = hdr->saddr;
		fl6.flowlabel = ip6_flowinfo(hdr);
		fl6.flowi6_mark = skb->mark;
		fl6.flowi6_proto = hdr->nexthdr;

		dst = ip6_route_output(net, NULL, &fl6);
		if (dst->error) {
			err = dst->error;
			goto drop;
		}

		/* cache only if we don't create a dst reference loop */
		if (orig_dst->lwtstate != dst->lwtstate) {
			local_bh_disable();
			dst_cache_set_ip6(&slwt->cache_output, dst, &fl6.saddr);
			local_bh_enable();
		}

		/* reserve room for the egress device's link-layer header */
		err = skb_cow_head(skb, LL_RESERVED_SPACE(dst_dev(dst)));
		if (unlikely(err))
			goto drop;
	}

	skb_dst_drop(skb);
	skb_dst_set(skb, dst);

	if (static_branch_unlikely(&nf_hooks_lwtunnel_enabled))
		return NF_HOOK(NFPROTO_IPV6, NF_INET_LOCAL_OUT, net, sk, skb,
			       NULL, dst_dev(dst), dst_output);

	return dst_output(net, sk, skb);
drop:
	dst_release(dst);
	kfree_skb(skb);
	return err;
}
641 
642 static int seg6_output_nf(struct net *net, struct sock *sk, struct sk_buff *skb)
643 {
644 	struct net_device *dev = skb_dst_dev(skb);
645 
646 	switch (skb->protocol) {
647 	case htons(ETH_P_IP):
648 		return NF_HOOK(NFPROTO_IPV4, NF_INET_POST_ROUTING, net, sk, skb,
649 			       NULL, dev, seg6_output_core);
650 	case htons(ETH_P_IPV6):
651 		return NF_HOOK(NFPROTO_IPV6, NF_INET_POST_ROUTING, net, sk, skb,
652 			       NULL, dev, seg6_output_core);
653 	}
654 
655 	return -EINVAL;
656 }
657 
658 static int seg6_output(struct net *net, struct sock *sk, struct sk_buff *skb)
659 {
660 	if (static_branch_unlikely(&nf_hooks_lwtunnel_enabled))
661 		return seg6_output_nf(net, sk, skb);
662 
663 	return seg6_output_core(net, sk, skb);
664 }
665 
/* lwtunnel build_state hook: parse and validate the netlink configuration,
 * allocate the seg6_lwt state (two dst caches plus a copy of the encap
 * info) and set the lwtunnel flags/headroom.
 *
 * Returns 0 with the new state in @ts, or a negative errno.
 */
static int seg6_build_state(struct net *net, struct nlattr *nla,
			    unsigned int family, const void *cfg,
			    struct lwtunnel_state **ts,
			    struct netlink_ext_ack *extack)
{
	struct nlattr *tb[SEG6_IPTUNNEL_MAX + 1];
	struct seg6_iptunnel_encap *tuninfo;
	struct lwtunnel_state *newts;
	int tuninfo_len, min_size;
	struct seg6_lwt *slwt;
	int err;

	if (family != AF_INET && family != AF_INET6)
		return -EINVAL;

	err = nla_parse_nested_deprecated(tb, SEG6_IPTUNNEL_MAX, nla,
					  seg6_iptunnel_policy, extack);

	if (err < 0)
		return err;

	if (!tb[SEG6_IPTUNNEL_SRH])
		return -EINVAL;

	tuninfo = nla_data(tb[SEG6_IPTUNNEL_SRH]);
	tuninfo_len = nla_len(tb[SEG6_IPTUNNEL_SRH]);

	/* tuninfo must contain at least the iptunnel encap structure,
	 * the SRH and one segment
	 */
	min_size = sizeof(*tuninfo) + sizeof(struct ipv6_sr_hdr) +
		   sizeof(struct in6_addr);
	if (tuninfo_len < min_size)
		return -EINVAL;

	switch (tuninfo->mode) {
	case SEG6_IPTUN_MODE_INLINE:
		if (family != AF_INET6)
			return -EINVAL;

		/* inline mode builds no outer header, so a per-route
		 * tunnel source address cannot apply
		 */
		if (tb[SEG6_IPTUNNEL_SRC]) {
			NL_SET_ERR_MSG(extack, "incompatible mode for tunsrc");
			return -EINVAL;
		}
		break;
	case SEG6_IPTUN_MODE_ENCAP:
		break;
	case SEG6_IPTUN_MODE_L2ENCAP:
		break;
	case SEG6_IPTUN_MODE_ENCAP_RED:
		break;
	case SEG6_IPTUN_MODE_L2ENCAP_RED:
		break;
	default:
		return -EINVAL;
	}

	/* verify that SRH is consistent */
	if (!seg6_validate_srh(tuninfo->srh, tuninfo_len - sizeof(*tuninfo), false))
		return -EINVAL;

	/* the encap info is copied into the flexible tail of seg6_lwt */
	newts = lwtunnel_state_alloc(tuninfo_len + sizeof(*slwt));
	if (!newts)
		return -ENOMEM;

	slwt = seg6_lwt_lwtunnel(newts);

	err = dst_cache_init(&slwt->cache_input, GFP_ATOMIC);
	if (err)
		goto err_free_newts;

	err = dst_cache_init(&slwt->cache_output, GFP_ATOMIC);
	if (err)
		goto err_destroy_input;

	memcpy(&slwt->tuninfo, tuninfo, tuninfo_len);

	if (tb[SEG6_IPTUNNEL_SRC]) {
		slwt->tunsrc = nla_get_in6_addr(tb[SEG6_IPTUNNEL_SRC]);

		if (ipv6_addr_any(&slwt->tunsrc) ||
		    ipv6_addr_is_multicast(&slwt->tunsrc) ||
		    ipv6_addr_loopback(&slwt->tunsrc)) {
			NL_SET_ERR_MSG(extack, "invalid tunsrc address");
			err = -EINVAL;
			goto err_destroy_output;
		}
	}

	newts->type = LWTUNNEL_ENCAP_SEG6;
	newts->flags |= LWTUNNEL_STATE_INPUT_REDIRECT;

	/* L2 modes are input-only: no output redirect for them */
	if (tuninfo->mode != SEG6_IPTUN_MODE_L2ENCAP &&
	    tuninfo->mode != SEG6_IPTUN_MODE_L2ENCAP_RED)
		newts->flags |= LWTUNNEL_STATE_OUTPUT_REDIRECT;

	newts->headroom = seg6_lwt_headroom(tuninfo);

	*ts = newts;

	return 0;

err_destroy_output:
	dst_cache_destroy(&slwt->cache_output);
err_destroy_input:
	dst_cache_destroy(&slwt->cache_input);
err_free_newts:
	kfree(newts);
	return err;
}
776 
777 static void seg6_destroy_state(struct lwtunnel_state *lwt)
778 {
779 	struct seg6_lwt *slwt = seg6_lwt_lwtunnel(lwt);
780 
781 	dst_cache_destroy(&slwt->cache_input);
782 	dst_cache_destroy(&slwt->cache_output);
783 }
784 
785 static int seg6_fill_encap_info(struct sk_buff *skb,
786 				struct lwtunnel_state *lwtstate)
787 {
788 	struct seg6_iptunnel_encap *tuninfo = seg6_encap_lwtunnel(lwtstate);
789 	struct seg6_lwt *slwt = seg6_lwt_lwtunnel(lwtstate);
790 
791 	if (nla_put_srh(skb, SEG6_IPTUNNEL_SRH, tuninfo))
792 		return -EMSGSIZE;
793 
794 	if (!ipv6_addr_any(&slwt->tunsrc) &&
795 	    nla_put_in6_addr(skb, SEG6_IPTUNNEL_SRC, &slwt->tunsrc))
796 		return -EMSGSIZE;
797 
798 	return 0;
799 }
800 
801 static int seg6_encap_nlsize(struct lwtunnel_state *lwtstate)
802 {
803 	struct seg6_iptunnel_encap *tuninfo = seg6_encap_lwtunnel(lwtstate);
804 	struct seg6_lwt *slwt = seg6_lwt_lwtunnel(lwtstate);
805 	int nlsize;
806 
807 	nlsize = nla_total_size(SEG6_IPTUN_ENCAP_SIZE(tuninfo));
808 
809 	if (!ipv6_addr_any(&slwt->tunsrc))
810 		nlsize += nla_total_size(sizeof(slwt->tunsrc));
811 
812 	return nlsize;
813 }
814 
815 static int seg6_encap_cmp(struct lwtunnel_state *a, struct lwtunnel_state *b)
816 {
817 	struct seg6_iptunnel_encap *a_hdr = seg6_encap_lwtunnel(a);
818 	struct seg6_iptunnel_encap *b_hdr = seg6_encap_lwtunnel(b);
819 	struct seg6_lwt *a_slwt = seg6_lwt_lwtunnel(a);
820 	struct seg6_lwt *b_slwt = seg6_lwt_lwtunnel(b);
821 	int len = SEG6_IPTUN_ENCAP_SIZE(a_hdr);
822 
823 	if (len != SEG6_IPTUN_ENCAP_SIZE(b_hdr))
824 		return 1;
825 
826 	if (!ipv6_addr_equal(&a_slwt->tunsrc, &b_slwt->tunsrc))
827 		return 1;
828 
829 	return memcmp(a_hdr, b_hdr, len);
830 }
831 
/* lwtunnel operations implementing the SEG6 (SRv6) encap type. */
static const struct lwtunnel_encap_ops seg6_iptun_ops = {
	.build_state = seg6_build_state,	/* parse netlink config */
	.destroy_state = seg6_destroy_state,	/* release dst caches */
	.output = seg6_output,
	.input = seg6_input,
	.fill_encap = seg6_fill_encap_info,	/* dump config to netlink */
	.get_encap_size = seg6_encap_nlsize,
	.cmp_encap = seg6_encap_cmp,
	.owner = THIS_MODULE,
};
842 
/* Register the SEG6 lwtunnel encap ops; called at seg6 init time. */
int __init seg6_iptunnel_init(void)
{
	return lwtunnel_encap_add_ops(&seg6_iptun_ops, LWTUNNEL_ENCAP_SEG6);
}
847 
/* Unregister the SEG6 lwtunnel encap ops on teardown. */
void seg6_iptunnel_exit(void)
{
	lwtunnel_encap_del_ops(&seg6_iptun_ops, LWTUNNEL_ENCAP_SEG6);
}
852