xref: /linux/net/ipv6/seg6_iptunnel.c (revision 8be4d31cb8aaeea27bde4b7ddb26e28a89062ebf)
1 // SPDX-License-Identifier: GPL-2.0-or-later
2 /*
3  *  SR-IPv6 implementation
4  *
5  *  Author:
6  *  David Lebrun <david.lebrun@uclouvain.be>
7  */
8 
9 #include <linux/types.h>
10 #include <linux/skbuff.h>
11 #include <linux/net.h>
12 #include <linux/module.h>
13 #include <net/ip.h>
14 #include <net/ip_tunnels.h>
15 #include <net/lwtunnel.h>
16 #include <net/netevent.h>
17 #include <net/netns/generic.h>
18 #include <net/ip6_fib.h>
19 #include <net/route.h>
20 #include <net/seg6.h>
21 #include <linux/seg6.h>
22 #include <linux/seg6_iptunnel.h>
23 #include <net/addrconf.h>
24 #include <net/ip6_route.h>
25 #include <net/dst_cache.h>
26 #ifdef CONFIG_IPV6_SEG6_HMAC
27 #include <net/seg6_hmac.h>
28 #endif
29 #include <linux/netfilter.h>
30 
seg6_lwt_headroom(struct seg6_iptunnel_encap * tuninfo)31 static size_t seg6_lwt_headroom(struct seg6_iptunnel_encap *tuninfo)
32 {
33 	int head = 0;
34 
35 	switch (tuninfo->mode) {
36 	case SEG6_IPTUN_MODE_INLINE:
37 		break;
38 	case SEG6_IPTUN_MODE_ENCAP:
39 	case SEG6_IPTUN_MODE_ENCAP_RED:
40 		head = sizeof(struct ipv6hdr);
41 		break;
42 	case SEG6_IPTUN_MODE_L2ENCAP:
43 	case SEG6_IPTUN_MODE_L2ENCAP_RED:
44 		return 0;
45 	}
46 
47 	return ((tuninfo->srh->hdrlen + 1) << 3) + head;
48 }
49 
/* Per-route SRv6 lwtunnel state: a cache for the resolved nexthop dst,
 * followed by the variable-length encap configuration (mode + SRH).
 */
struct seg6_lwt {
	struct dst_cache cache;
	struct seg6_iptunnel_encap tuninfo[];
};
54 
/* Retrieve the seg6_lwt private state embedded in an lwtunnel state. */
static inline struct seg6_lwt *seg6_lwt_lwtunnel(struct lwtunnel_state *lwt)
{
	return (struct seg6_lwt *)lwt->data;
}
59 
/* Shortcut to the encap configuration stored in an lwtunnel state. */
static inline struct seg6_iptunnel_encap *
seg6_encap_lwtunnel(struct lwtunnel_state *lwt)
{
	return seg6_lwt_lwtunnel(lwt)->tuninfo;
}
65 
/* Netlink policy: the SRH attribute is an opaque binary blob here; its
 * internal consistency is validated later via seg6_validate_srh().
 */
static const struct nla_policy seg6_iptunnel_policy[SEG6_IPTUNNEL_MAX + 1] = {
	[SEG6_IPTUNNEL_SRH]	= { .type = NLA_BINARY },
};
69 
nla_put_srh(struct sk_buff * skb,int attrtype,struct seg6_iptunnel_encap * tuninfo)70 static int nla_put_srh(struct sk_buff *skb, int attrtype,
71 		       struct seg6_iptunnel_encap *tuninfo)
72 {
73 	struct seg6_iptunnel_encap *data;
74 	struct nlattr *nla;
75 	int len;
76 
77 	len = SEG6_IPTUN_ENCAP_SIZE(tuninfo);
78 
79 	nla = nla_reserve(skb, attrtype, len);
80 	if (!nla)
81 		return -EMSGSIZE;
82 
83 	data = nla_data(nla);
84 	memcpy(data, tuninfo, len);
85 
86 	return 0;
87 }
88 
set_tun_src(struct net * net,struct net_device * dev,struct in6_addr * daddr,struct in6_addr * saddr)89 static void set_tun_src(struct net *net, struct net_device *dev,
90 			struct in6_addr *daddr, struct in6_addr *saddr)
91 {
92 	struct seg6_pernet_data *sdata = seg6_pernet(net);
93 	struct in6_addr *tun_src;
94 
95 	rcu_read_lock();
96 
97 	tun_src = rcu_dereference(sdata->tun_src);
98 
99 	if (!ipv6_addr_any(tun_src)) {
100 		memcpy(saddr, tun_src, sizeof(struct in6_addr));
101 	} else {
102 		ipv6_dev_get_saddr(net, dev, daddr, IPV6_PREFER_SRC_PUBLIC,
103 				   saddr);
104 	}
105 
106 	rcu_read_unlock();
107 }
108 
109 /* Compute flowlabel for outer IPv6 header */
seg6_make_flowlabel(struct net * net,struct sk_buff * skb,struct ipv6hdr * inner_hdr)110 static __be32 seg6_make_flowlabel(struct net *net, struct sk_buff *skb,
111 				  struct ipv6hdr *inner_hdr)
112 {
113 	int do_flowlabel = net->ipv6.sysctl.seg6_flowlabel;
114 	__be32 flowlabel = 0;
115 	u32 hash;
116 
117 	if (do_flowlabel > 0) {
118 		hash = skb_get_hash(skb);
119 		hash = rol32(hash, 16);
120 		flowlabel = (__force __be32)hash & IPV6_FLOWLABEL_MASK;
121 	} else if (!do_flowlabel && skb->protocol == htons(ETH_P_IPV6)) {
122 		flowlabel = ip6_flowlabel(inner_hdr);
123 	}
124 	return flowlabel;
125 }
126 
/* Encapsulate @skb within an outer IPv6 header carrying a full copy of
 * @osrh. @proto is recorded as the SRH nexthdr (the inner payload type).
 * @cache_dst, when non-NULL, is the cached output route used to size the
 * extra device headroom. Returns 0 on success or a negative errno.
 */
static int __seg6_do_srh_encap(struct sk_buff *skb, struct ipv6_sr_hdr *osrh,
			       int proto, struct dst_entry *cache_dst)
{
	struct dst_entry *dst = skb_dst(skb);
	struct net_device *dev = dst_dev(dst);
	struct net *net = dev_net(dev);
	struct ipv6hdr *hdr, *inner_hdr;
	struct ipv6_sr_hdr *isrh;
	int hdrlen, tot_len, err;
	__be32 flowlabel;

	/* SRH hdrlen counts 8-octet units beyond the first 8 octets */
	hdrlen = (osrh->hdrlen + 1) << 3;
	tot_len = hdrlen + sizeof(*hdr);

	err = skb_cow_head(skb, tot_len + dst_dev_overhead(cache_dst, skb));
	if (unlikely(err))
		return err;

	/* compute the flowlabel before pushing: it may be derived from
	 * the inner IPv6 header
	 */
	inner_hdr = ipv6_hdr(skb);
	flowlabel = seg6_make_flowlabel(net, skb, inner_hdr);

	skb_push(skb, tot_len);
	skb_reset_network_header(skb);
	skb_mac_header_rebuild(skb);
	hdr = ipv6_hdr(skb);

	/* inherit tc, flowlabel and hlim
	 * hlim will be decremented in ip6_forward() afterwards and
	 * decapsulation will overwrite inner hlim with outer hlim
	 */

	if (skb->protocol == htons(ETH_P_IPV6)) {
		ip6_flow_hdr(hdr, ip6_tclass(ip6_flowinfo(inner_hdr)),
			     flowlabel);
		hdr->hop_limit = inner_hdr->hop_limit;
	} else {
		/* non-IPv6 payload: nothing to inherit from */
		ip6_flow_hdr(hdr, 0, flowlabel);
		hdr->hop_limit = ip6_dst_hoplimit(skb_dst(skb));

		memset(IP6CB(skb), 0, sizeof(*IP6CB(skb)));

		/* the control block has been erased, so we have to set the
		 * iif once again.
		 * We read the receiving interface index directly from the
		 * skb->skb_iif as it is done in the IPv4 receiving path (i.e.:
		 * ip_rcv_core(...)).
		 */
		IP6CB(skb)->iif = skb->skb_iif;
	}

	hdr->nexthdr = NEXTHDR_ROUTING;

	/* copy the whole SRH right after the outer IPv6 header */
	isrh = (void *)hdr + sizeof(*hdr);
	memcpy(isrh, osrh, hdrlen);

	isrh->nexthdr = proto;

	/* the outer destination is the first segment to visit; the SRH
	 * stores segments in reverse order (RFC 8754)
	 */
	hdr->daddr = isrh->segments[isrh->first_segment];
	set_tun_src(net, dev, &hdr->daddr, &hdr->saddr);

#ifdef CONFIG_IPV6_SEG6_HMAC
	if (sr_has_hmac(isrh)) {
		err = seg6_push_hmac(net, &hdr->saddr, isrh);
		if (unlikely(err))
			return err;
	}
#endif

	hdr->payload_len = htons(skb->len - sizeof(struct ipv6hdr));

	skb_postpush_rcsum(skb, hdr, tot_len);

	return 0;
}
201 
/* encapsulate an IPv6 packet within an outer IPv6 header with a given SRH */
int seg6_do_srh_encap(struct sk_buff *skb, struct ipv6_sr_hdr *osrh, int proto)
{
	/* exported variant: no cached dst is available */
	return __seg6_do_srh_encap(skb, osrh, proto, NULL);
}
EXPORT_SYMBOL_GPL(seg6_do_srh_encap);
208 
/* encapsulate an IPv6 packet within an outer IPv6 header with reduced SRH.
 * In reduced mode the first segment to visit is carried only in the outer
 * destination address and is elided from the pushed SRH; when the SRH
 * would carry nothing else useful, it is skipped entirely.
 */
static int seg6_do_srh_encap_red(struct sk_buff *skb,
				 struct ipv6_sr_hdr *osrh, int proto,
				 struct dst_entry *cache_dst)
{
	__u8 first_seg = osrh->first_segment;
	struct dst_entry *dst = skb_dst(skb);
	struct net_device *dev = dst_dev(dst);
	struct net *net = dev_net(dev);
	struct ipv6hdr *hdr, *inner_hdr;
	int hdrlen = ipv6_optlen(osrh);
	int red_tlv_offset, tlv_offset;
	struct ipv6_sr_hdr *isrh;
	bool skip_srh = false;
	__be32 flowlabel;
	int tot_len, err;
	int red_hdrlen;
	int tlvs_len;

	if (first_seg > 0) {
		/* more than one segment: push the SRH minus one address */
		red_hdrlen = hdrlen - sizeof(struct in6_addr);
	} else {
		/* NOTE: if tag/flags and/or other TLVs are introduced in the
		 * seg6_iptunnel infrastructure, they should be considered when
		 * deciding to skip the SRH.
		 */
		skip_srh = !sr_has_hmac(osrh);

		red_hdrlen = skip_srh ? 0 : hdrlen;
	}

	tot_len = red_hdrlen + sizeof(struct ipv6hdr);

	err = skb_cow_head(skb, tot_len + dst_dev_overhead(cache_dst, skb));
	if (unlikely(err))
		return err;

	/* flowlabel may depend on the inner IPv6 header, so compute it
	 * before pushing the outer one
	 */
	inner_hdr = ipv6_hdr(skb);
	flowlabel = seg6_make_flowlabel(net, skb, inner_hdr);

	skb_push(skb, tot_len);
	skb_reset_network_header(skb);
	skb_mac_header_rebuild(skb);
	hdr = ipv6_hdr(skb);

	/* based on seg6_do_srh_encap() */
	if (skb->protocol == htons(ETH_P_IPV6)) {
		ip6_flow_hdr(hdr, ip6_tclass(ip6_flowinfo(inner_hdr)),
			     flowlabel);
		hdr->hop_limit = inner_hdr->hop_limit;
	} else {
		ip6_flow_hdr(hdr, 0, flowlabel);
		hdr->hop_limit = ip6_dst_hoplimit(skb_dst(skb));

		/* control block wiped: restore iif as in
		 * __seg6_do_srh_encap()
		 */
		memset(IP6CB(skb), 0, sizeof(*IP6CB(skb)));
		IP6CB(skb)->iif = skb->skb_iif;
	}

	/* no matter if we have to skip the SRH or not, the first segment
	 * always comes in the pushed IPv6 header.
	 */
	hdr->daddr = osrh->segments[first_seg];

	if (skip_srh) {
		hdr->nexthdr = proto;

		set_tun_src(net, dev, &hdr->daddr, &hdr->saddr);
		goto out;
	}

	/* we cannot skip the SRH, slow path */

	hdr->nexthdr = NEXTHDR_ROUTING;
	isrh = (void *)hdr + sizeof(struct ipv6hdr);

	if (unlikely(!first_seg)) {
		/* this is a very rare case; we have only one SID but
		 * we cannot skip the SRH since we are carrying some
		 * other info.
		 */
		memcpy(isrh, osrh, hdrlen);
		goto srcaddr;
	}

	/* copy the SRH excluding the elided address (segments[first_seg],
	 * the last one in the segment list), then append any trailing TLVs
	 * right after the shortened segment list
	 */
	tlv_offset = sizeof(*osrh) + (first_seg + 1) * sizeof(struct in6_addr);
	red_tlv_offset = tlv_offset - sizeof(struct in6_addr);

	memcpy(isrh, osrh, red_tlv_offset);

	tlvs_len = hdrlen - tlv_offset;
	if (unlikely(tlvs_len > 0)) {
		const void *s = (const void *)osrh + tlv_offset;
		void *d = (void *)isrh + red_tlv_offset;

		memcpy(d, s, tlvs_len);
	}

	/* account for the elided 16-byte segment: one fewer segment and
	 * two fewer 8-octet units of header length
	 */
	--isrh->first_segment;
	isrh->hdrlen -= 2;

srcaddr:
	isrh->nexthdr = proto;
	set_tun_src(net, dev, &hdr->daddr, &hdr->saddr);

#ifdef CONFIG_IPV6_SEG6_HMAC
	if (unlikely(!skip_srh && sr_has_hmac(isrh))) {
		err = seg6_push_hmac(net, &hdr->saddr, isrh);
		if (unlikely(err))
			return err;
	}
#endif

out:
	hdr->payload_len = htons(skb->len - sizeof(struct ipv6hdr));

	skb_postpush_rcsum(skb, hdr, tot_len);

	return 0;
}
328 
/* Insert @osrh just after the IPv6 header of @skb, shifting the original
 * header forward. The original destination address is preserved as the
 * final segment (segments[0]) and the header's daddr becomes the first
 * segment to visit. Returns 0 on success or a negative errno.
 */
static int __seg6_do_srh_inline(struct sk_buff *skb, struct ipv6_sr_hdr *osrh,
				struct dst_entry *cache_dst)
{
	struct ipv6hdr *hdr, *oldhdr;
	struct ipv6_sr_hdr *isrh;
	int hdrlen, err;

	/* SRH hdrlen counts 8-octet units beyond the first 8 octets */
	hdrlen = (osrh->hdrlen + 1) << 3;

	err = skb_cow_head(skb, hdrlen + dst_dev_overhead(cache_dst, skb));
	if (unlikely(err))
		return err;

	oldhdr = ipv6_hdr(skb);

	/* temporarily strip the IPv6 header so the checksum can be
	 * adjusted around the insertion
	 */
	skb_pull(skb, sizeof(struct ipv6hdr));
	skb_postpull_rcsum(skb, skb_network_header(skb),
			   sizeof(struct ipv6hdr));

	skb_push(skb, sizeof(struct ipv6hdr) + hdrlen);
	skb_reset_network_header(skb);
	skb_mac_header_rebuild(skb);

	hdr = ipv6_hdr(skb);

	/* relocate the original IPv6 header to the new front; memmove as
	 * source and destination may overlap
	 */
	memmove(hdr, oldhdr, sizeof(*hdr));

	isrh = (void *)hdr + sizeof(*hdr);
	memcpy(isrh, osrh, hdrlen);

	/* splice the SRH into the extension-header chain */
	isrh->nexthdr = hdr->nexthdr;
	hdr->nexthdr = NEXTHDR_ROUTING;

	/* original destination becomes the last segment to visit */
	isrh->segments[0] = hdr->daddr;
	hdr->daddr = isrh->segments[isrh->first_segment];

#ifdef CONFIG_IPV6_SEG6_HMAC
	if (sr_has_hmac(isrh)) {
		struct net *net = skb_dst_dev_net(skb);

		err = seg6_push_hmac(net, &hdr->saddr, isrh);
		if (unlikely(err))
			return err;
	}
#endif

	hdr->payload_len = htons(skb->len - sizeof(struct ipv6hdr));

	skb_postpush_rcsum(skb, hdr, sizeof(struct ipv6hdr) + hdrlen);

	return 0;
}
381 
/* Apply the SRv6 policy attached to skb's dst (via its lwtstate) to the
 * packet, dispatching on the configured tunnel mode: inline insertion,
 * IPv4/IPv6 encapsulation (full or reduced SRH), or L2 encapsulation of
 * the whole frame. @cache_dst may be NULL; it is only used to size the
 * required headroom. Returns 0 on success or a negative errno.
 */
static int seg6_do_srh(struct sk_buff *skb, struct dst_entry *cache_dst)
{
	struct dst_entry *dst = skb_dst(skb);
	struct seg6_iptunnel_encap *tinfo;
	int proto, err = 0;

	tinfo = seg6_encap_lwtunnel(dst->lwtstate);

	switch (tinfo->mode) {
	case SEG6_IPTUN_MODE_INLINE:
		/* inline SRH insertion only works on IPv6 packets */
		if (skb->protocol != htons(ETH_P_IPV6))
			return -EINVAL;

		err = __seg6_do_srh_inline(skb, tinfo->srh, cache_dst);
		if (err)
			return err;
		break;
	case SEG6_IPTUN_MODE_ENCAP:
	case SEG6_IPTUN_MODE_ENCAP_RED:
		err = iptunnel_handle_offloads(skb, SKB_GSO_IPXIP6);
		if (err)
			return err;

		/* inner payload type recorded in the SRH nexthdr */
		if (skb->protocol == htons(ETH_P_IPV6))
			proto = IPPROTO_IPV6;
		else if (skb->protocol == htons(ETH_P_IP))
			proto = IPPROTO_IPIP;
		else
			return -EINVAL;

		if (tinfo->mode == SEG6_IPTUN_MODE_ENCAP)
			err = __seg6_do_srh_encap(skb, tinfo->srh,
						  proto, cache_dst);
		else
			err = seg6_do_srh_encap_red(skb, tinfo->srh,
						    proto, cache_dst);

		if (err)
			return err;

		skb_set_inner_transport_header(skb, skb_transport_offset(skb));
		skb_set_inner_protocol(skb, skb->protocol);
		skb->protocol = htons(ETH_P_IPV6);
		break;
	case SEG6_IPTUN_MODE_L2ENCAP:
	case SEG6_IPTUN_MODE_L2ENCAP_RED:
		/* L2 modes encapsulate the whole frame, so a MAC header
		 * must be present
		 */
		if (!skb_mac_header_was_set(skb))
			return -EINVAL;

		if (pskb_expand_head(skb, skb->mac_len, 0, GFP_ATOMIC) < 0)
			return -ENOMEM;

		/* pull the MAC header back into the payload to be
		 * encapsulated
		 */
		skb_mac_header_rebuild(skb);
		skb_push(skb, skb->mac_len);

		if (tinfo->mode == SEG6_IPTUN_MODE_L2ENCAP)
			err = __seg6_do_srh_encap(skb, tinfo->srh,
						  IPPROTO_ETHERNET,
						  cache_dst);
		else
			err = seg6_do_srh_encap_red(skb, tinfo->srh,
						    IPPROTO_ETHERNET,
						    cache_dst);

		if (err)
			return err;

		skb->protocol = htons(ETH_P_IPV6);
		break;
	}

	skb_set_transport_header(skb, sizeof(struct ipv6hdr));
	nf_reset_ct(skb);

	return 0;
}
458 
/* insert an SRH within an IPv6 packet, just after the IPv6 header */
int seg6_do_srh_inline(struct sk_buff *skb, struct ipv6_sr_hdr *osrh)
{
	/* exported variant: no cached dst is available */
	return __seg6_do_srh_inline(skb, osrh, NULL);
}
EXPORT_SYMBOL_GPL(seg6_do_srh_inline);
465 
/* NF_HOOK continuation: hand the re-routed packet to its new dst input. */
static int seg6_input_finish(struct net *net, struct sock *sk,
			     struct sk_buff *skb)
{
	return dst_input(skb);
}
471 
/* Input path: apply the SRv6 policy to a received packet, re-route it
 * on the (possibly new) destination address, and pass it on. A
 * per-tunnel dst cache avoids one route lookup per packet.
 */
static int seg6_input_core(struct net *net, struct sock *sk,
			   struct sk_buff *skb)
{
	struct dst_entry *orig_dst = skb_dst(skb);
	struct dst_entry *dst = NULL;
	struct lwtunnel_state *lwtst;
	struct seg6_lwt *slwt;
	int err;

	/* We cannot dereference "orig_dst" once ip6_route_input() or
	 * skb_dst_drop() is called. However, in order to detect a dst loop, we
	 * need the address of its lwtstate. So, save the address of lwtstate
	 * now and use it later as a comparison.
	 */
	lwtst = orig_dst->lwtstate;

	slwt = seg6_lwt_lwtunnel(lwtst);

	/* dst cache is accessed with BHs disabled */
	local_bh_disable();
	dst = dst_cache_get(&slwt->cache);
	local_bh_enable();

	err = seg6_do_srh(skb, dst);
	if (unlikely(err)) {
		/* dst_release(NULL) is a no-op */
		dst_release(dst);
		goto drop;
	}

	if (!dst) {
		/* cache miss: resolve a route for the new daddr */
		ip6_route_input(skb);
		dst = skb_dst(skb);

		/* cache only if we don't create a dst reference loop */
		if (!dst->error && lwtst != dst->lwtstate) {
			local_bh_disable();
			dst_cache_set_ip6(&slwt->cache, dst,
					  &ipv6_hdr(skb)->saddr);
			local_bh_enable();
		}

		err = skb_cow_head(skb, LL_RESERVED_SPACE(dst_dev(dst)));
		if (unlikely(err))
			goto drop;
	} else {
		/* cache hit: replace the skb's dst with the cached one */
		skb_dst_drop(skb);
		skb_dst_set(skb, dst);
	}

	if (static_branch_unlikely(&nf_hooks_lwtunnel_enabled))
		return NF_HOOK(NFPROTO_IPV6, NF_INET_LOCAL_OUT,
			       dev_net(skb->dev), NULL, skb, NULL,
			       skb_dst_dev(skb), seg6_input_finish);

	return seg6_input_finish(dev_net(skb->dev), NULL, skb);
drop:
	kfree_skb(skb);
	return err;
}
530 
seg6_input_nf(struct sk_buff * skb)531 static int seg6_input_nf(struct sk_buff *skb)
532 {
533 	struct net_device *dev = skb_dst_dev(skb);
534 	struct net *net = dev_net(skb->dev);
535 
536 	switch (skb->protocol) {
537 	case htons(ETH_P_IP):
538 		return NF_HOOK(NFPROTO_IPV4, NF_INET_POST_ROUTING, net, NULL,
539 			       skb, NULL, dev, seg6_input_core);
540 	case htons(ETH_P_IPV6):
541 		return NF_HOOK(NFPROTO_IPV6, NF_INET_POST_ROUTING, net, NULL,
542 			       skb, NULL, dev, seg6_input_core);
543 	}
544 
545 	return -EINVAL;
546 }
547 
/* lwtunnel input handler: dispatch to the netfilter-aware path when the
 * lwtunnel netfilter hooks are enabled, otherwise go straight to the
 * core processing.
 */
static int seg6_input(struct sk_buff *skb)
{
	if (static_branch_unlikely(&nf_hooks_lwtunnel_enabled))
		return seg6_input_nf(skb);

	return seg6_input_core(dev_net(skb->dev), NULL, skb);
}
555 
/* Output path: apply the SRv6 policy to an outgoing packet and re-route
 * it via the new outer destination, using the per-tunnel dst cache when
 * possible.
 */
static int seg6_output_core(struct net *net, struct sock *sk,
			    struct sk_buff *skb)
{
	struct dst_entry *orig_dst = skb_dst(skb);
	struct dst_entry *dst = NULL;
	struct seg6_lwt *slwt;
	int err;

	slwt = seg6_lwt_lwtunnel(orig_dst->lwtstate);

	/* dst cache is accessed with BHs disabled */
	local_bh_disable();
	dst = dst_cache_get(&slwt->cache);
	local_bh_enable();

	err = seg6_do_srh(skb, dst);
	if (unlikely(err))
		goto drop;

	if (unlikely(!dst)) {
		/* cache miss: build a flow from the new outer header and
		 * resolve an output route
		 */
		struct ipv6hdr *hdr = ipv6_hdr(skb);
		struct flowi6 fl6;

		memset(&fl6, 0, sizeof(fl6));
		fl6.daddr = hdr->daddr;
		fl6.saddr = hdr->saddr;
		fl6.flowlabel = ip6_flowinfo(hdr);
		fl6.flowi6_mark = skb->mark;
		fl6.flowi6_proto = hdr->nexthdr;

		dst = ip6_route_output(net, NULL, &fl6);
		if (dst->error) {
			err = dst->error;
			goto drop;
		}

		/* cache only if we don't create a dst reference loop */
		if (orig_dst->lwtstate != dst->lwtstate) {
			local_bh_disable();
			dst_cache_set_ip6(&slwt->cache, dst, &fl6.saddr);
			local_bh_enable();
		}

		err = skb_cow_head(skb, LL_RESERVED_SPACE(dst_dev(dst)));
		if (unlikely(err))
			goto drop;
	}

	skb_dst_drop(skb);
	skb_dst_set(skb, dst);

	if (static_branch_unlikely(&nf_hooks_lwtunnel_enabled))
		return NF_HOOK(NFPROTO_IPV6, NF_INET_LOCAL_OUT, net, sk, skb,
			       NULL, dst_dev(dst), dst_output);

	return dst_output(net, sk, skb);
drop:
	/* dst_release(NULL) is a no-op */
	dst_release(dst);
	kfree_skb(skb);
	return err;
}
616 
seg6_output_nf(struct net * net,struct sock * sk,struct sk_buff * skb)617 static int seg6_output_nf(struct net *net, struct sock *sk, struct sk_buff *skb)
618 {
619 	struct net_device *dev = skb_dst_dev(skb);
620 
621 	switch (skb->protocol) {
622 	case htons(ETH_P_IP):
623 		return NF_HOOK(NFPROTO_IPV4, NF_INET_POST_ROUTING, net, sk, skb,
624 			       NULL, dev, seg6_output_core);
625 	case htons(ETH_P_IPV6):
626 		return NF_HOOK(NFPROTO_IPV6, NF_INET_POST_ROUTING, net, sk, skb,
627 			       NULL, dev, seg6_output_core);
628 	}
629 
630 	return -EINVAL;
631 }
632 
/* lwtunnel output handler: dispatch to the netfilter-aware path when the
 * lwtunnel netfilter hooks are enabled, otherwise go straight to the
 * core processing.
 */
static int seg6_output(struct net *net, struct sock *sk, struct sk_buff *skb)
{
	if (static_branch_unlikely(&nf_hooks_lwtunnel_enabled))
		return seg6_output_nf(net, sk, skb);

	return seg6_output_core(net, sk, skb);
}
640 
seg6_build_state(struct net * net,struct nlattr * nla,unsigned int family,const void * cfg,struct lwtunnel_state ** ts,struct netlink_ext_ack * extack)641 static int seg6_build_state(struct net *net, struct nlattr *nla,
642 			    unsigned int family, const void *cfg,
643 			    struct lwtunnel_state **ts,
644 			    struct netlink_ext_ack *extack)
645 {
646 	struct nlattr *tb[SEG6_IPTUNNEL_MAX + 1];
647 	struct seg6_iptunnel_encap *tuninfo;
648 	struct lwtunnel_state *newts;
649 	int tuninfo_len, min_size;
650 	struct seg6_lwt *slwt;
651 	int err;
652 
653 	if (family != AF_INET && family != AF_INET6)
654 		return -EINVAL;
655 
656 	err = nla_parse_nested_deprecated(tb, SEG6_IPTUNNEL_MAX, nla,
657 					  seg6_iptunnel_policy, extack);
658 
659 	if (err < 0)
660 		return err;
661 
662 	if (!tb[SEG6_IPTUNNEL_SRH])
663 		return -EINVAL;
664 
665 	tuninfo = nla_data(tb[SEG6_IPTUNNEL_SRH]);
666 	tuninfo_len = nla_len(tb[SEG6_IPTUNNEL_SRH]);
667 
668 	/* tuninfo must contain at least the iptunnel encap structure,
669 	 * the SRH and one segment
670 	 */
671 	min_size = sizeof(*tuninfo) + sizeof(struct ipv6_sr_hdr) +
672 		   sizeof(struct in6_addr);
673 	if (tuninfo_len < min_size)
674 		return -EINVAL;
675 
676 	switch (tuninfo->mode) {
677 	case SEG6_IPTUN_MODE_INLINE:
678 		if (family != AF_INET6)
679 			return -EINVAL;
680 
681 		break;
682 	case SEG6_IPTUN_MODE_ENCAP:
683 		break;
684 	case SEG6_IPTUN_MODE_L2ENCAP:
685 		break;
686 	case SEG6_IPTUN_MODE_ENCAP_RED:
687 		break;
688 	case SEG6_IPTUN_MODE_L2ENCAP_RED:
689 		break;
690 	default:
691 		return -EINVAL;
692 	}
693 
694 	/* verify that SRH is consistent */
695 	if (!seg6_validate_srh(tuninfo->srh, tuninfo_len - sizeof(*tuninfo), false))
696 		return -EINVAL;
697 
698 	newts = lwtunnel_state_alloc(tuninfo_len + sizeof(*slwt));
699 	if (!newts)
700 		return -ENOMEM;
701 
702 	slwt = seg6_lwt_lwtunnel(newts);
703 
704 	err = dst_cache_init(&slwt->cache, GFP_ATOMIC);
705 	if (err) {
706 		kfree(newts);
707 		return err;
708 	}
709 
710 	memcpy(&slwt->tuninfo, tuninfo, tuninfo_len);
711 
712 	newts->type = LWTUNNEL_ENCAP_SEG6;
713 	newts->flags |= LWTUNNEL_STATE_INPUT_REDIRECT;
714 
715 	if (tuninfo->mode != SEG6_IPTUN_MODE_L2ENCAP)
716 		newts->flags |= LWTUNNEL_STATE_OUTPUT_REDIRECT;
717 
718 	newts->headroom = seg6_lwt_headroom(tuninfo);
719 
720 	*ts = newts;
721 
722 	return 0;
723 }
724 
/* Release the per-tunnel dst cache when the lwtunnel state is destroyed. */
static void seg6_destroy_state(struct lwtunnel_state *lwt)
{
	dst_cache_destroy(&seg6_lwt_lwtunnel(lwt)->cache);
}
729 
seg6_fill_encap_info(struct sk_buff * skb,struct lwtunnel_state * lwtstate)730 static int seg6_fill_encap_info(struct sk_buff *skb,
731 				struct lwtunnel_state *lwtstate)
732 {
733 	struct seg6_iptunnel_encap *tuninfo = seg6_encap_lwtunnel(lwtstate);
734 
735 	if (nla_put_srh(skb, SEG6_IPTUNNEL_SRH, tuninfo))
736 		return -EMSGSIZE;
737 
738 	return 0;
739 }
740 
/* Netlink message space needed to dump this tunnel's SRH attribute. */
static int seg6_encap_nlsize(struct lwtunnel_state *lwtstate)
{
	struct seg6_iptunnel_encap *tuninfo = seg6_encap_lwtunnel(lwtstate);

	return nla_total_size(SEG6_IPTUN_ENCAP_SIZE(tuninfo));
}
747 
/* Compare the encap configuration of two lwtunnel states. Returns 0 when
 * they are byte-identical, non-zero otherwise.
 */
static int seg6_encap_cmp(struct lwtunnel_state *a, struct lwtunnel_state *b)
{
	struct seg6_iptunnel_encap *enc_a = seg6_encap_lwtunnel(a);
	struct seg6_iptunnel_encap *enc_b = seg6_encap_lwtunnel(b);
	int len_a = SEG6_IPTUN_ENCAP_SIZE(enc_a);

	/* different sizes can never be equal */
	return len_a != SEG6_IPTUN_ENCAP_SIZE(enc_b) ?
	       1 : memcmp(enc_a, enc_b, len_a);
}
759 
/* lwtunnel operations registered for LWTUNNEL_ENCAP_SEG6 routes */
static const struct lwtunnel_encap_ops seg6_iptun_ops = {
	.build_state = seg6_build_state,
	.destroy_state = seg6_destroy_state,
	.output = seg6_output,
	.input = seg6_input,
	.fill_encap = seg6_fill_encap_info,
	.get_encap_size = seg6_encap_nlsize,
	.cmp_encap = seg6_encap_cmp,
	.owner = THIS_MODULE,
};
770 
/* Register the SRv6 lwtunnel encap operations. */
int __init seg6_iptunnel_init(void)
{
	return lwtunnel_encap_add_ops(&seg6_iptun_ops, LWTUNNEL_ENCAP_SEG6);
}
775 
/* Unregister the SRv6 lwtunnel encap operations. */
void seg6_iptunnel_exit(void)
{
	lwtunnel_encap_del_ops(&seg6_iptun_ops, LWTUNNEL_ENCAP_SEG6);
}
780