xref: /linux/tools/testing/selftests/net/lib/xdp_native.bpf.c (revision 4f38da1f027ea2c9f01bb71daa7a299c191b6940)
1 // SPDX-License-Identifier: GPL-2.0
2 
3 #include <stddef.h>
4 #include <linux/bpf.h>
5 #include <linux/in.h>
6 #include <linux/if_ether.h>
7 #include <linux/ip.h>
8 #include <linux/ipv6.h>
9 #include <linux/udp.h>
10 #include <bpf/bpf_endian.h>
11 #include <bpf/bpf_helpers.h>
12 
/* Upper bound on the number of bytes added or removed by one head/tail
 * adjustment; also sizes the on-stack staging buffers.
 */
#define MAX_ADJST_OFFSET 256
/* Not referenced in this file. */
#define MAX_PAYLOAD_LEN 5000
/* Size of the buffer used to save the Ethernet + IP + UDP headers while
 * growing the packet head.
 */
#define MAX_HDR_LEN 64

/* kfunc called before every direct header access. Declared __weak;
 * NOTE(review): callers do not guard calls with bpf_ksym_exists().
 */
extern int bpf_xdp_pull_data(struct xdp_md *xdp, __u32 len) __ksym __weak;
18 
/* Indices into map_xdp_setup.
 * Note: each of these anonymous enums is followed by a variable name, so
 * the declarations also define global variables of the enum types.
 * NOTE(review): presumably so the enums are visible to user space via BTF
 * or the skeleton; confirm.
 */
enum {
	XDP_MODE = 0,		/* which handler to run, one of XDP_MODE_* */
	XDP_PORT = 1,		/* UDP destination port to match */
	XDP_ADJST_OFFSET = 2,	/* signed byte count for head/tail adjustment */
	XDP_ADJST_TAG = 3,	/* byte value written into grown regions */
} xdp_map_setup_keys;

/* Values for the XDP_MODE entry, dispatched in xdp_prog_common(). */
enum {
	XDP_MODE_PASS = 0,
	XDP_MODE_DROP = 1,
	XDP_MODE_TX = 2,
	XDP_MODE_TAIL_ADJST = 3,
	XDP_MODE_HEAD_ADJST = 4,
} xdp_map_modes;

/* Indices into map_xdp_stats. */
enum {
	STATS_RX = 0,		/* packets matching the configured UDP port */
	STATS_PASS = 1,
	STATS_DROP = 2,
	STATS_TX = 3,
	STATS_ABORT = 4,	/* adjustment failures returning XDP_ABORTED */
} xdp_stats;
41 
/* Configuration array indexed by the XDP_MODE/XDP_PORT/XDP_ADJST_* keys.
 * NOTE(review): presumably populated by the user-space test harness.
 */
struct {
	__uint(type, BPF_MAP_TYPE_ARRAY);
	__uint(max_entries, 5);
	__type(key, __u32);
	__type(value, __s32);
} map_xdp_setup SEC(".maps");

/* Packet counters indexed by STATS_*; incremented atomically by
 * record_stats().
 */
struct {
	__uint(type, BPF_MAP_TYPE_ARRAY);
	__uint(max_entries, 5);
	__type(key, __u32);
	__type(value, __u64);
} map_xdp_stats SEC(".maps");
55 
/* Return the smaller of two unsigned 32-bit values. */
static __u32 min(__u32 a, __u32 b)
{
	if (b < a)
		return b;

	return a;
}
60 
61 static void record_stats(struct xdp_md *ctx, __u32 stat_type)
62 {
63 	__u64 *count;
64 
65 	count = bpf_map_lookup_elem(&map_xdp_stats, &stat_type);
66 
67 	if (count)
68 		__sync_fetch_and_add(count, 1);
69 }
70 
71 static struct udphdr *filter_udphdr(struct xdp_md *ctx, __u16 port)
72 {
73 	struct udphdr *udph = NULL;
74 	void *data, *data_end;
75 	struct ethhdr *eth;
76 	int err;
77 
78 	err = bpf_xdp_pull_data(ctx, sizeof(*eth));
79 	if (err)
80 		return NULL;
81 
82 	data_end = (void *)(long)ctx->data_end;
83 	data = eth = (void *)(long)ctx->data;
84 
85 	if (data + sizeof(*eth) > data_end)
86 		return NULL;
87 
88 	if (eth->h_proto == bpf_htons(ETH_P_IP)) {
89 		struct iphdr *iph;
90 
91 		err = bpf_xdp_pull_data(ctx, sizeof(*eth) + sizeof(*iph) +
92 					     sizeof(*udph));
93 		if (err)
94 			return NULL;
95 
96 		data_end = (void *)(long)ctx->data_end;
97 		data = (void *)(long)ctx->data;
98 
99 		iph = data + sizeof(*eth);
100 
101 		if (iph + 1 > (struct iphdr *)data_end ||
102 		    iph->protocol != IPPROTO_UDP)
103 			return NULL;
104 
105 		udph = data + sizeof(*iph) + sizeof(*eth);
106 	} else if (eth->h_proto == bpf_htons(ETH_P_IPV6)) {
107 		struct ipv6hdr *ipv6h;
108 
109 		err = bpf_xdp_pull_data(ctx, sizeof(*eth) + sizeof(*ipv6h) +
110 					     sizeof(*udph));
111 		if (err)
112 			return NULL;
113 
114 		data_end = (void *)(long)ctx->data_end;
115 		data = (void *)(long)ctx->data;
116 
117 		ipv6h = data + sizeof(*eth);
118 
119 		if (ipv6h + 1 > (struct ipv6hdr *)data_end ||
120 		    ipv6h->nexthdr != IPPROTO_UDP)
121 			return NULL;
122 
123 		udph = data + sizeof(*ipv6h) + sizeof(*eth);
124 	} else {
125 		return NULL;
126 	}
127 
128 	if (udph + 1 > (struct udphdr *)data_end)
129 		return NULL;
130 
131 	if (udph->dest != bpf_htons(port))
132 		return NULL;
133 
134 	record_stats(ctx, STATS_RX);
135 
136 	return udph;
137 }
138 
139 static int xdp_mode_pass(struct xdp_md *ctx, __u16 port)
140 {
141 	struct udphdr *udph = NULL;
142 
143 	udph = filter_udphdr(ctx, port);
144 	if (!udph)
145 		return XDP_PASS;
146 
147 	record_stats(ctx, STATS_PASS);
148 
149 	return XDP_PASS;
150 }
151 
152 static int xdp_mode_drop_handler(struct xdp_md *ctx, __u16 port)
153 {
154 	struct udphdr *udph = NULL;
155 
156 	udph = filter_udphdr(ctx, port);
157 	if (!udph)
158 		return XDP_PASS;
159 
160 	record_stats(ctx, STATS_DROP);
161 
162 	return XDP_DROP;
163 }
164 
165 static void swap_machdr(void *data)
166 {
167 	struct ethhdr *eth = data;
168 	__u8 tmp_mac[ETH_ALEN];
169 
170 	__builtin_memcpy(tmp_mac, eth->h_source, ETH_ALEN);
171 	__builtin_memcpy(eth->h_source, eth->h_dest, ETH_ALEN);
172 	__builtin_memcpy(eth->h_dest, tmp_mac, ETH_ALEN);
173 }
174 
175 static int xdp_mode_tx_handler(struct xdp_md *ctx, __u16 port)
176 {
177 	struct udphdr *udph = NULL;
178 	void *data, *data_end;
179 	struct ethhdr *eth;
180 	int err;
181 
182 	err = bpf_xdp_pull_data(ctx, sizeof(*eth));
183 	if (err)
184 		return XDP_PASS;
185 
186 	data_end = (void *)(long)ctx->data_end;
187 	data = eth = (void *)(long)ctx->data;
188 
189 	if (data + sizeof(*eth) > data_end)
190 		return XDP_PASS;
191 
192 	if (eth->h_proto == bpf_htons(ETH_P_IP)) {
193 		struct iphdr *iph;
194 		__be32 tmp_ip;
195 
196 		err = bpf_xdp_pull_data(ctx, sizeof(*eth) + sizeof(*iph) +
197 					     sizeof(*udph));
198 		if (err)
199 			return XDP_PASS;
200 
201 		data_end = (void *)(long)ctx->data_end;
202 		data = (void *)(long)ctx->data;
203 
204 		iph = data + sizeof(*eth);
205 
206 		if (iph + 1 > (struct iphdr *)data_end ||
207 		    iph->protocol != IPPROTO_UDP)
208 			return XDP_PASS;
209 
210 		udph = data + sizeof(*iph) + sizeof(*eth);
211 
212 		if (udph + 1 > (struct udphdr *)data_end)
213 			return XDP_PASS;
214 		if (udph->dest != bpf_htons(port))
215 			return XDP_PASS;
216 
217 		record_stats(ctx, STATS_RX);
218 		eth = data;
219 		swap_machdr((void *)eth);
220 
221 		tmp_ip = iph->saddr;
222 		iph->saddr = iph->daddr;
223 		iph->daddr = tmp_ip;
224 
225 		record_stats(ctx, STATS_TX);
226 
227 		return XDP_TX;
228 
229 	} else if (eth->h_proto == bpf_htons(ETH_P_IPV6)) {
230 		struct in6_addr tmp_ipv6;
231 		struct ipv6hdr *ipv6h;
232 
233 		err = bpf_xdp_pull_data(ctx, sizeof(*eth) + sizeof(*ipv6h) +
234 					     sizeof(*udph));
235 		if (err)
236 			return XDP_PASS;
237 
238 		data_end = (void *)(long)ctx->data_end;
239 		data = (void *)(long)ctx->data;
240 
241 		ipv6h = data + sizeof(*eth);
242 
243 		if (ipv6h + 1 > (struct ipv6hdr *)data_end ||
244 		    ipv6h->nexthdr != IPPROTO_UDP)
245 			return XDP_PASS;
246 
247 		udph = data + sizeof(*ipv6h) + sizeof(*eth);
248 
249 		if (udph + 1 > (struct udphdr *)data_end)
250 			return XDP_PASS;
251 		if (udph->dest != bpf_htons(port))
252 			return XDP_PASS;
253 
254 		record_stats(ctx, STATS_RX);
255 		eth = data;
256 		swap_machdr((void *)eth);
257 
258 		__builtin_memcpy(&tmp_ipv6, &ipv6h->saddr, sizeof(tmp_ipv6));
259 		__builtin_memcpy(&ipv6h->saddr, &ipv6h->daddr,
260 				 sizeof(tmp_ipv6));
261 		__builtin_memcpy(&ipv6h->daddr, &tmp_ipv6, sizeof(tmp_ipv6));
262 
263 		record_stats(ctx, STATS_TX);
264 
265 		return XDP_TX;
266 	}
267 
268 	return XDP_PASS;
269 }
270 
271 static void *update_pkt(struct xdp_md *ctx, __s16 offset, __u32 *udp_csum)
272 {
273 	void *data_end = (void *)(long)ctx->data_end;
274 	void *data = (void *)(long)ctx->data;
275 	struct udphdr *udph = NULL;
276 	struct ethhdr *eth = data;
277 	__u32 len, len_new;
278 
279 	if (data + sizeof(*eth) > data_end)
280 		return NULL;
281 
282 	if (eth->h_proto == bpf_htons(ETH_P_IP)) {
283 		struct iphdr *iph = data + sizeof(*eth);
284 		__u16 total_len;
285 
286 		if (iph + 1 > (struct iphdr *)data_end)
287 			return NULL;
288 
289 		iph->tot_len = bpf_htons(bpf_ntohs(iph->tot_len) + offset);
290 
291 		udph = (void *)eth + sizeof(*iph) + sizeof(*eth);
292 		if (!udph || udph + 1 > (struct udphdr *)data_end)
293 			return NULL;
294 
295 		len_new = bpf_htons(bpf_ntohs(udph->len) + offset);
296 	} else if (eth->h_proto  == bpf_htons(ETH_P_IPV6)) {
297 		struct ipv6hdr *ipv6h = data + sizeof(*eth);
298 		__u16 payload_len;
299 
300 		if (ipv6h + 1 > (struct ipv6hdr *)data_end)
301 			return NULL;
302 
303 		udph = (void *)eth + sizeof(*ipv6h) + sizeof(*eth);
304 		if (!udph || udph + 1 > (struct udphdr *)data_end)
305 			return NULL;
306 
307 		*udp_csum = ~((__u32)udph->check);
308 
309 		len = ipv6h->payload_len;
310 		len_new = bpf_htons(bpf_ntohs(len) + offset);
311 		ipv6h->payload_len = len_new;
312 
313 		*udp_csum = bpf_csum_diff(&len, sizeof(len), &len_new,
314 					  sizeof(len_new), *udp_csum);
315 
316 		len = udph->len;
317 		len_new = bpf_htons(bpf_ntohs(udph->len) + offset);
318 		*udp_csum = bpf_csum_diff(&len, sizeof(len), &len_new,
319 					  sizeof(len_new), *udp_csum);
320 	} else {
321 		return NULL;
322 	}
323 
324 	udph->len = len_new;
325 
326 	return udph;
327 }
328 
/*
 * csum_fold_helper - turn a 32-bit partial sum into a UDP checksum value.
 * @csum: 32-bit ones'-complement partial sum (e.g. from bpf_csum_diff())
 *
 * Folds @csum to 16 bits and returns its ones'-complement. The fold is done
 * twice because the first one can carry into bit 16. A computed checksum
 * of zero is sent as 0xffff, since zero means "no checksum" for UDP
 * (RFC 768).
 */
static __u16 csum_fold_helper(__u32 csum)
{
	__u16 folded;

	csum = (csum & 0xffff) + (csum >> 16);
	csum = (csum & 0xffff) + (csum >> 16);
	folded = (__u16)~csum;

	return folded ? folded : 0xffff;
}
333 
334 static int xdp_adjst_tail_shrnk_data(struct xdp_md *ctx, __u16 offset,
335 				     __u32 hdr_len)
336 {
337 	char tmp_buff[MAX_ADJST_OFFSET];
338 	__u32 buff_pos, udp_csum = 0;
339 	struct udphdr *udph = NULL;
340 	__u32 buff_len;
341 
342 	udph = update_pkt(ctx, 0 - offset, &udp_csum);
343 	if (!udph)
344 		return -1;
345 
346 	buff_len = bpf_xdp_get_buff_len(ctx);
347 
348 	offset = (offset & 0x1ff) >= MAX_ADJST_OFFSET ? MAX_ADJST_OFFSET :
349 				     offset & 0xff;
350 	if (offset == 0)
351 		return -1;
352 
353 	/* Make sure we have enough data to avoid eating the header */
354 	if (buff_len - offset < hdr_len)
355 		return -1;
356 
357 	buff_pos = buff_len - offset;
358 	if (bpf_xdp_load_bytes(ctx, buff_pos, tmp_buff, offset) < 0)
359 		return -1;
360 
361 	udp_csum = bpf_csum_diff((__be32 *)tmp_buff, offset, 0, 0, udp_csum);
362 	udph->check = (__u16)csum_fold_helper(udp_csum);
363 
364 	if (bpf_xdp_adjust_tail(ctx, 0 - offset) < 0)
365 		return -1;
366 
367 	return 0;
368 }
369 
370 static int xdp_adjst_tail_grow_data(struct xdp_md *ctx, __u16 offset)
371 {
372 	char tmp_buff[MAX_ADJST_OFFSET];
373 	__u32 buff_pos, udp_csum = 0;
374 	__u32 buff_len, hdr_len, key;
375 	struct udphdr *udph;
376 	__s32 *val;
377 	__u8 tag;
378 
379 	/* Proceed to update the packet headers before attempting to adjuste
380 	 * the tail. Once the tail is adjusted we lose access to the offset
381 	 * amount of data at the end of the packet which is crucial to update
382 	 * the checksum.
383 	 * Since any failure beyond this would abort the packet, we should
384 	 * not worry about passing a packet up the stack with wrong headers
385 	 */
386 	udph = update_pkt(ctx, offset, &udp_csum);
387 	if (!udph)
388 		return -1;
389 
390 	key = XDP_ADJST_TAG;
391 	val = bpf_map_lookup_elem(&map_xdp_setup, &key);
392 	if (!val)
393 		return -1;
394 
395 	tag = (__u8)(*val);
396 
397 	for (int i = 0; i < MAX_ADJST_OFFSET; i++)
398 		__builtin_memcpy(&tmp_buff[i], &tag, 1);
399 
400 	offset = (offset & 0x1ff) >= MAX_ADJST_OFFSET ? MAX_ADJST_OFFSET :
401 				     offset & 0xff;
402 	if (offset == 0)
403 		return -1;
404 
405 	udp_csum = bpf_csum_diff(0, 0, (__be32 *)tmp_buff, offset, udp_csum);
406 	udph->check = (__u16)csum_fold_helper(udp_csum);
407 
408 	buff_len = bpf_xdp_get_buff_len(ctx);
409 
410 	if (bpf_xdp_adjust_tail(ctx, offset) < 0) {
411 		bpf_printk("Failed to adjust tail\n");
412 		return -1;
413 	}
414 
415 	if (bpf_xdp_store_bytes(ctx, buff_len, tmp_buff, offset) < 0)
416 		return -1;
417 
418 	return 0;
419 }
420 
421 static int xdp_adjst_tail(struct xdp_md *ctx, __u16 port)
422 {
423 	struct udphdr *udph = NULL;
424 	__s32 *adjust_offset, *val;
425 	__u32 key, hdr_len;
426 	void *offset_ptr;
427 	__u8 tag;
428 	int ret;
429 
430 	udph = filter_udphdr(ctx, port);
431 	if (!udph)
432 		return XDP_PASS;
433 
434 	hdr_len = (void *)udph - (void *)(long)ctx->data +
435 		  sizeof(struct udphdr);
436 	key = XDP_ADJST_OFFSET;
437 	adjust_offset = bpf_map_lookup_elem(&map_xdp_setup, &key);
438 	if (!adjust_offset)
439 		return XDP_PASS;
440 
441 	if (*adjust_offset < 0)
442 		ret = xdp_adjst_tail_shrnk_data(ctx,
443 						(__u16)(0 - *adjust_offset),
444 						hdr_len);
445 	else
446 		ret = xdp_adjst_tail_grow_data(ctx, (__u16)(*adjust_offset));
447 	if (ret)
448 		goto abort_pkt;
449 
450 	record_stats(ctx, STATS_PASS);
451 	return XDP_PASS;
452 
453 abort_pkt:
454 	record_stats(ctx, STATS_ABORT);
455 	return XDP_ABORTED;
456 }
457 
458 static int xdp_adjst_head_shrnk_data(struct xdp_md *ctx, __u64 hdr_len,
459 				     __u32 offset)
460 {
461 	char tmp_buff[MAX_ADJST_OFFSET];
462 	struct udphdr *udph;
463 	void *offset_ptr;
464 	__u32 udp_csum = 0;
465 
466 	/* Update the length information in the IP and UDP headers before
467 	 * adjusting the headroom. This simplifies accessing the relevant
468 	 * fields in the IP and UDP headers for fragmented packets. Any
469 	 * failure beyond this point will result in the packet being aborted,
470 	 * so we don't need to worry about incorrect length information for
471 	 * passed packets.
472 	 */
473 	udph = update_pkt(ctx, (__s16)(0 - offset), &udp_csum);
474 	if (!udph)
475 		return -1;
476 
477 	offset = (offset & 0x1ff) >= MAX_ADJST_OFFSET ? MAX_ADJST_OFFSET :
478 				     offset & 0xff;
479 	if (offset == 0)
480 		return -1;
481 
482 	if (bpf_xdp_load_bytes(ctx, hdr_len, tmp_buff, offset) < 0)
483 		return -1;
484 
485 	udp_csum = bpf_csum_diff((__be32 *)tmp_buff, offset, 0, 0, udp_csum);
486 
487 	udph->check = (__u16)csum_fold_helper(udp_csum);
488 
489 	if (bpf_xdp_load_bytes(ctx, 0, tmp_buff, MAX_ADJST_OFFSET) < 0)
490 		return -1;
491 
492 	if (bpf_xdp_adjust_head(ctx, offset) < 0)
493 		return -1;
494 
495 	if (offset > MAX_ADJST_OFFSET)
496 		return -1;
497 
498 	if (hdr_len > MAX_ADJST_OFFSET || hdr_len == 0)
499 		return -1;
500 
501 	/* Added here to handle clang complain about negative value */
502 	hdr_len = hdr_len & 0xff;
503 
504 	if (hdr_len == 0)
505 		return -1;
506 
507 	if (bpf_xdp_store_bytes(ctx, 0, tmp_buff, hdr_len) < 0)
508 		return -1;
509 
510 	return 0;
511 }
512 
513 static int xdp_adjst_head_grow_data(struct xdp_md *ctx, __u64 hdr_len,
514 				    __u32 offset)
515 {
516 	char hdr_buff[MAX_HDR_LEN];
517 	char data_buff[MAX_ADJST_OFFSET];
518 	void *offset_ptr;
519 	__s32 *val;
520 	__u32 key;
521 	__u8 tag;
522 	__u32 udp_csum = 0;
523 	struct udphdr *udph;
524 
525 	udph = update_pkt(ctx, (__s16)(offset), &udp_csum);
526 	if (!udph)
527 		return -1;
528 
529 	key = XDP_ADJST_TAG;
530 	val = bpf_map_lookup_elem(&map_xdp_setup, &key);
531 	if (!val)
532 		return -1;
533 
534 	tag = (__u8)(*val);
535 	for (int i = 0; i < MAX_ADJST_OFFSET; i++)
536 		__builtin_memcpy(&data_buff[i], &tag, 1);
537 
538 	offset = (offset & 0x1ff) >= MAX_ADJST_OFFSET ? MAX_ADJST_OFFSET :
539 				     offset & 0xff;
540 	if (offset == 0)
541 		return -1;
542 
543 	udp_csum = bpf_csum_diff(0, 0, (__be32 *)data_buff, offset, udp_csum);
544 	udph->check = (__u16)csum_fold_helper(udp_csum);
545 
546 	if (hdr_len > MAX_ADJST_OFFSET || hdr_len == 0)
547 		return -1;
548 
549 	/* Added here to handle clang complain about negative value */
550 	hdr_len = hdr_len & 0xff;
551 
552 	if (hdr_len == 0)
553 		return -1;
554 
555 	if (bpf_xdp_load_bytes(ctx, 0, hdr_buff, hdr_len) < 0)
556 		return -1;
557 
558 	if (offset > MAX_ADJST_OFFSET)
559 		return -1;
560 
561 	if (bpf_xdp_adjust_head(ctx, 0 - offset) < 0)
562 		return -1;
563 
564 	if (bpf_xdp_store_bytes(ctx, 0, hdr_buff, hdr_len) < 0)
565 		return -1;
566 
567 	if (bpf_xdp_store_bytes(ctx, hdr_len, data_buff, offset) < 0)
568 		return -1;
569 
570 	return 0;
571 }
572 
573 static int xdp_head_adjst(struct xdp_md *ctx, __u16 port)
574 {
575 	struct udphdr *udph_ptr = NULL;
576 	__u32 key, size, hdr_len;
577 	__s32 *val;
578 	int res;
579 
580 	/* Filter packets based on UDP port */
581 	udph_ptr = filter_udphdr(ctx, port);
582 	if (!udph_ptr)
583 		return XDP_PASS;
584 
585 	hdr_len = (void *)udph_ptr - (void *)(long)ctx->data +
586 		  sizeof(struct udphdr);
587 
588 	key = XDP_ADJST_OFFSET;
589 	val = bpf_map_lookup_elem(&map_xdp_setup, &key);
590 	if (!val)
591 		return XDP_PASS;
592 
593 	switch (*val) {
594 	case -16:
595 	case 16:
596 		size = 16;
597 		break;
598 	case -32:
599 	case 32:
600 		size = 32;
601 		break;
602 	case -64:
603 	case 64:
604 		size = 64;
605 		break;
606 	case -128:
607 	case 128:
608 		size = 128;
609 		break;
610 	case -256:
611 	case 256:
612 		size = 256;
613 		break;
614 	default:
615 		bpf_printk("Invalid adjustment offset: %d\n", *val);
616 		goto abort;
617 	}
618 
619 	if (*val < 0)
620 		res = xdp_adjst_head_grow_data(ctx, hdr_len, size);
621 	else
622 		res = xdp_adjst_head_shrnk_data(ctx, hdr_len, size);
623 
624 	if (res)
625 		goto abort;
626 
627 	record_stats(ctx, STATS_PASS);
628 	return XDP_PASS;
629 
630 abort:
631 	record_stats(ctx, STATS_ABORT);
632 	return XDP_ABORTED;
633 }
634 
635 static int xdp_prog_common(struct xdp_md *ctx)
636 {
637 	__u32 key, *port;
638 	__s32 *mode;
639 
640 	key = XDP_MODE;
641 	mode = bpf_map_lookup_elem(&map_xdp_setup, &key);
642 	if (!mode)
643 		return XDP_PASS;
644 
645 	key = XDP_PORT;
646 	port = bpf_map_lookup_elem(&map_xdp_setup, &key);
647 	if (!port)
648 		return XDP_PASS;
649 
650 	switch (*mode) {
651 	case XDP_MODE_PASS:
652 		return xdp_mode_pass(ctx, (__u16)(*port));
653 	case XDP_MODE_DROP:
654 		return xdp_mode_drop_handler(ctx, (__u16)(*port));
655 	case XDP_MODE_TX:
656 		return xdp_mode_tx_handler(ctx, (__u16)(*port));
657 	case XDP_MODE_TAIL_ADJST:
658 		return xdp_adjst_tail(ctx, (__u16)(*port));
659 	case XDP_MODE_HEAD_ADJST:
660 		return xdp_head_adjst(ctx, (__u16)(*port));
661 	}
662 
663 	/* Default action is to simple pass */
664 	return XDP_PASS;
665 }
666 
/* Entry point placed in the "xdp" section (single-buffer attachment). */
SEC("xdp")
int xdp_prog(struct xdp_md *ctx)
{
	return xdp_prog_common(ctx);
}

/* Same logic placed in the "xdp.frags" section, which marks the program
 * as able to handle multi-buffer frames.
 */
SEC("xdp.frags")
int xdp_prog_frags(struct xdp_md *ctx)
{
	return xdp_prog_common(ctx);
}

/* GPL is required to use GPL-only helpers and kfuncs. */
char _license[] SEC("license") = "GPL";
680