1 // SPDX-License-Identifier: GPL-2.0
2
3 #include <stddef.h>
4 #include <linux/bpf.h>
5 #include <linux/in.h>
6 #include <linux/if_ether.h>
7 #include <linux/ip.h>
8 #include <linux/ipv6.h>
9 #include <linux/udp.h>
10 #include <bpf/bpf_endian.h>
11 #include <bpf/bpf_helpers.h>
12
/* Upper bound, in bytes, on a single head/tail adjustment; also the size of
 * the on-stack scratch buffers used by the adjust handlers.
 */
#define MAX_ADJST_OFFSET 256
/* Not referenced in the visible part of this file. */
#define MAX_PAYLOAD_LEN 5000
/* Size of the on-stack buffer that holds the saved packet headers while the
 * headroom is being grown.
 */
#define MAX_HDR_LEN 64

/* Kernel function (kfunc) resolved at load time; declared weak. Callers in
 * this file pass the number of leading packet bytes they are about to parse
 * and treat a non-zero return value as failure.
 */
extern int bpf_xdp_pull_data(struct xdp_md *xdp, __u32 len) __ksym __weak;
18
/* Keys into map_xdp_setup. Note that the trailing identifier declares a
 * file-scope variable of this anonymous enum type, not a type name.
 */
enum {
	XDP_MODE = 0,		/* selects the handler in xdp_prog_common() */
	XDP_PORT = 1,		/* UDP destination port to match */
	XDP_ADJST_OFFSET = 2,	/* signed head/tail adjustment in bytes */
	XDP_ADJST_TAG = 3,	/* byte value used to fill grown regions */
} xdp_map_setup_keys;
25
/* Values stored under the XDP_MODE key; xdp_prog_common() dispatches on
 * them. The trailing identifier declares a variable, not a type name.
 */
enum {
	XDP_MODE_PASS = 0,		/* count matching packets, pass everything */
	XDP_MODE_DROP = 1,		/* drop matching packets */
	XDP_MODE_TX = 2,		/* swap addresses and transmit back */
	XDP_MODE_TAIL_ADJST = 3,	/* grow/shrink the packet tail */
	XDP_MODE_HEAD_ADJST = 4,	/* grow/shrink right after the headers */
} xdp_map_modes;
33
/* Indices into map_xdp_stats, one counter per event kind. The trailing
 * identifier declares a variable, not a type name.
 */
enum {
	STATS_RX = 0,		/* packet matched the UDP port filter */
	STATS_PASS = 1,		/* matching packet returned XDP_PASS */
	STATS_DROP = 2,		/* matching packet returned XDP_DROP */
	STATS_TX = 3,		/* matching packet returned XDP_TX */
	STATS_ABORT = 4,	/* adjustment failed, XDP_ABORTED returned */
} xdp_stats;
41
/* Configuration array read by the program on every packet. It is indexed by
 * the XDP_MODE / XDP_PORT / XDP_ADJST_OFFSET / XDP_ADJST_TAG keys and holds
 * signed 32-bit values.
 */
struct {
	__uint(type, BPF_MAP_TYPE_ARRAY);
	__uint(max_entries, 5);
	__type(key, __u32);
	__type(value, __s32);
} map_xdp_setup SEC(".maps");
48
/* Event counters, indexed by the STATS_* values and incremented atomically
 * by record_stats().
 */
struct {
	__uint(type, BPF_MAP_TYPE_ARRAY);
	__uint(max_entries, 5);
	__type(key, __u32);
	__type(value, __u64);
} map_xdp_stats SEC(".maps");
55
/* Return the smaller of two unsigned 32-bit values. */
static __u32 min(__u32 a, __u32 b)
{
	if (b <= a)
		return b;

	return a;
}
60
/* Atomically bump the map_xdp_stats counter selected by @stat_type.
 * Silently does nothing when the index has no entry. @ctx is unused.
 */
static void record_stats(struct xdp_md *ctx, __u32 stat_type)
{
	__u64 *counter = bpf_map_lookup_elem(&map_xdp_stats, &stat_type);

	if (!counter)
		return;

	__sync_fetch_and_add(counter, 1);
}
70
/* Locate the UDP header of an Ethernet + IPv4/IPv6 packet and match its
 * destination port.
 *
 * @ctx:  XDP context of the packet being processed.
 * @port: UDP destination port to match, in host byte order.
 *
 * The UDP header is taken at a fixed offset right after the IP header, so
 * IPv4 options and IPv6 extension headers are not accounted for.
 *
 * Returns a pointer to the UDP header and bumps STATS_RX on a match,
 * otherwise returns NULL.
 */
static struct udphdr *filter_udphdr(struct xdp_md *ctx, __u16 port)
{
	void *pkt, *pkt_end;
	struct ethhdr *ethh;
	struct udphdr *uh;

	if (bpf_xdp_pull_data(ctx, sizeof(*ethh)))
		return NULL;

	pkt = (void *)(long)ctx->data;
	pkt_end = (void *)(long)ctx->data_end;
	ethh = pkt;

	if (pkt + sizeof(*ethh) > pkt_end)
		return NULL;

	switch (ethh->h_proto) {
	case bpf_htons(ETH_P_IP): {
		struct iphdr *ip4;

		if (bpf_xdp_pull_data(ctx, sizeof(*ethh) + sizeof(*ip4) +
					   sizeof(*uh)))
			return NULL;

		/* Packet pointers must be reloaded after pulling data */
		pkt = (void *)(long)ctx->data;
		pkt_end = (void *)(long)ctx->data_end;

		ip4 = pkt + sizeof(*ethh);
		if (ip4 + 1 > (struct iphdr *)pkt_end)
			return NULL;
		if (ip4->protocol != IPPROTO_UDP)
			return NULL;

		uh = pkt + sizeof(*ethh) + sizeof(*ip4);
		break;
	}
	case bpf_htons(ETH_P_IPV6): {
		struct ipv6hdr *ip6;

		if (bpf_xdp_pull_data(ctx, sizeof(*ethh) + sizeof(*ip6) +
					   sizeof(*uh)))
			return NULL;

		/* Packet pointers must be reloaded after pulling data */
		pkt = (void *)(long)ctx->data;
		pkt_end = (void *)(long)ctx->data_end;

		ip6 = pkt + sizeof(*ethh);
		if (ip6 + 1 > (struct ipv6hdr *)pkt_end)
			return NULL;
		if (ip6->nexthdr != IPPROTO_UDP)
			return NULL;

		uh = pkt + sizeof(*ethh) + sizeof(*ip6);
		break;
	}
	default:
		return NULL;
	}

	if (uh + 1 > (struct udphdr *)pkt_end)
		return NULL;

	if (uh->dest != bpf_htons(port))
		return NULL;

	record_stats(ctx, STATS_RX);

	return uh;
}
138
/* XDP_MODE_PASS handler: every packet is passed up the stack; packets that
 * match the UDP port filter are additionally counted under STATS_PASS.
 */
static int xdp_mode_pass(struct xdp_md *ctx, __u16 port)
{
	if (filter_udphdr(ctx, port))
		record_stats(ctx, STATS_PASS);

	return XDP_PASS;
}
151
/* XDP_MODE_DROP handler: packets matching the UDP port filter are counted
 * under STATS_DROP and dropped; everything else is passed.
 */
static int xdp_mode_drop_handler(struct xdp_md *ctx, __u16 port)
{
	struct udphdr *match = filter_udphdr(ctx, port);

	if (match) {
		record_stats(ctx, STATS_DROP);
		return XDP_DROP;
	}

	return XDP_PASS;
}
164
/* Exchange the source and destination MAC addresses of the Ethernet header
 * that starts at @data.
 */
static void swap_machdr(void *data)
{
	struct ethhdr *hdr = data;
	__u8 saved_dst[ETH_ALEN];

	__builtin_memcpy(saved_dst, hdr->h_dest, ETH_ALEN);
	__builtin_memcpy(hdr->h_dest, hdr->h_source, ETH_ALEN);
	__builtin_memcpy(hdr->h_source, saved_dst, ETH_ALEN);
}
174
/* XDP_MODE_TX handler: bounce matching UDP packets back out.
 *
 * For an Ethernet + IPv4/IPv6 UDP packet whose destination port equals
 * @port, the MAC addresses and the IP addresses are swapped in place,
 * STATS_RX and STATS_TX are bumped and XDP_TX is returned. UDP ports are
 * left untouched. Any other packet, or any parse failure, yields XDP_PASS.
 */
static int xdp_mode_tx_handler(struct xdp_md *ctx, __u16 port)
{
	void *pkt, *pkt_end;
	struct ethhdr *ethh;
	struct udphdr *uh;

	if (bpf_xdp_pull_data(ctx, sizeof(*ethh)))
		return XDP_PASS;

	pkt = (void *)(long)ctx->data;
	pkt_end = (void *)(long)ctx->data_end;
	ethh = pkt;

	if (pkt + sizeof(*ethh) > pkt_end)
		return XDP_PASS;

	switch (ethh->h_proto) {
	case bpf_htons(ETH_P_IP): {
		struct iphdr *ip4;
		__be32 saved_addr;

		if (bpf_xdp_pull_data(ctx, sizeof(*ethh) + sizeof(*ip4) +
					   sizeof(*uh)))
			return XDP_PASS;

		/* Packet pointers must be reloaded after pulling data */
		pkt = (void *)(long)ctx->data;
		pkt_end = (void *)(long)ctx->data_end;

		ip4 = pkt + sizeof(*ethh);
		if (ip4 + 1 > (struct iphdr *)pkt_end)
			return XDP_PASS;
		if (ip4->protocol != IPPROTO_UDP)
			return XDP_PASS;

		uh = pkt + sizeof(*ethh) + sizeof(*ip4);
		if (uh + 1 > (struct udphdr *)pkt_end)
			return XDP_PASS;
		if (uh->dest != bpf_htons(port))
			return XDP_PASS;

		record_stats(ctx, STATS_RX);
		swap_machdr(pkt);

		saved_addr = ip4->daddr;
		ip4->daddr = ip4->saddr;
		ip4->saddr = saved_addr;

		record_stats(ctx, STATS_TX);

		return XDP_TX;
	}
	case bpf_htons(ETH_P_IPV6): {
		struct in6_addr saved_addr6;
		struct ipv6hdr *ip6;

		if (bpf_xdp_pull_data(ctx, sizeof(*ethh) + sizeof(*ip6) +
					   sizeof(*uh)))
			return XDP_PASS;

		/* Packet pointers must be reloaded after pulling data */
		pkt = (void *)(long)ctx->data;
		pkt_end = (void *)(long)ctx->data_end;

		ip6 = pkt + sizeof(*ethh);
		if (ip6 + 1 > (struct ipv6hdr *)pkt_end)
			return XDP_PASS;
		if (ip6->nexthdr != IPPROTO_UDP)
			return XDP_PASS;

		uh = pkt + sizeof(*ethh) + sizeof(*ip6);
		if (uh + 1 > (struct udphdr *)pkt_end)
			return XDP_PASS;
		if (uh->dest != bpf_htons(port))
			return XDP_PASS;

		record_stats(ctx, STATS_RX);
		swap_machdr(pkt);

		__builtin_memcpy(&saved_addr6, &ip6->daddr,
				 sizeof(saved_addr6));
		__builtin_memcpy(&ip6->daddr, &ip6->saddr,
				 sizeof(saved_addr6));
		__builtin_memcpy(&ip6->saddr, &saved_addr6,
				 sizeof(saved_addr6));

		record_stats(ctx, STATS_TX);

		return XDP_TX;
	}
	default:
		return XDP_PASS;
	}
}
270
/* Fold a 32-bit one's-complement accumulator into 16 bits and return its
 * complement. Two folding rounds absorb the carry of the first addition.
 */
static __always_inline __u16 csum_fold_helper(__u32 csum)
{
	__u32 folded = (csum >> 16) + (csum & 0xffff);

	folded = (folded >> 16) + (folded & 0xffff);

	return ~folded;
}
276
csum_fold_udp_helper(__u32 csum)277 static __always_inline __u16 csum_fold_udp_helper(__u32 csum)
278 {
279 return csum_fold_helper(csum) ? : 0xffff;
280 }
281
/* Apply a length change of @offset bytes to the IP and UDP headers.
 *
 * @ctx:      XDP context; the headers must already be directly accessible.
 * @offset:   signed number of bytes being added to (or removed from) the
 *            packet.
 * @udp_csum: out parameter. Receives the unfolded UDP checksum accumulator,
 *            already adjusted for the length change. udph->check itself is
 *            NOT rewritten here; the caller folds and stores it after also
 *            accounting for the payload bytes.
 *
 * IPv4: tot_len and the header checksum are updated. IPv6: payload_len is
 * updated. In both cases the UDP length field is updated. The UDP header is
 * assumed to follow the IP header directly.
 *
 * Returns the UDP header, or NULL if the packet is not IPv4/IPv6 or the
 * headers do not fit in the accessible data.
 */
static void *update_pkt(struct xdp_md *ctx, __s16 offset, __u32 *udp_csum)
{
	void *pkt = (void *)(long)ctx->data;
	void *pkt_end = (void *)(long)ctx->data_end;
	struct ethhdr *ethh = pkt;
	struct udphdr *uh;
	/* 32-bit temporaries: they are handed to bpf_csum_diff() as 4-byte
	 * buffers.
	 */
	__u32 old_len, new_len;

	if (pkt + sizeof(*ethh) > pkt_end)
		return NULL;

	switch (ethh->h_proto) {
	case bpf_htons(ETH_P_IP): {
		struct iphdr *ip4 = pkt + sizeof(*ethh);

		if (ip4 + 1 > (struct iphdr *)pkt_end)
			return NULL;

		uh = pkt + sizeof(*ethh) + sizeof(*ip4);
		if (uh + 1 > (struct udphdr *)pkt_end)
			return NULL;

		old_len = ip4->tot_len;
		new_len = bpf_htons(bpf_ntohs(old_len) + offset);
		ip4->tot_len = new_len;
		/* Incremental update of the IPv4 header checksum */
		ip4->check = csum_fold_helper(
			bpf_csum_diff(&old_len, sizeof(old_len), &new_len,
				      sizeof(new_len),
				      ~((__u32)ip4->check)));
		break;
	}
	case bpf_htons(ETH_P_IPV6): {
		struct ipv6hdr *ip6 = pkt + sizeof(*ethh);

		if (ip6 + 1 > (struct ipv6hdr *)pkt_end)
			return NULL;

		uh = pkt + sizeof(*ethh) + sizeof(*ip6);
		if (uh + 1 > (struct udphdr *)pkt_end)
			return NULL;

		old_len = ip6->payload_len;
		new_len = bpf_htons(bpf_ntohs(old_len) + offset);
		ip6->payload_len = new_len;
		break;
	}
	default:
		return NULL;
	}

	old_len = uh->len;
	new_len = bpf_htons(bpf_ntohs(old_len) + offset);

	/* The length delta is folded in twice; presumably because the UDP
	 * length is covered both in the UDP header and in the pseudo-header.
	 */
	*udp_csum = ~((__u32)uh->check);
	*udp_csum = bpf_csum_diff(&old_len, sizeof(old_len), &new_len,
				  sizeof(new_len), *udp_csum);
	*udp_csum = bpf_csum_diff(&old_len, sizeof(old_len), &new_len,
				  sizeof(new_len), *udp_csum);

	uh->len = new_len;

	return uh;
}
339
/* Remove @offset bytes from the end of the packet.
 *
 * @ctx:     XDP context.
 * @offset:  number of bytes to trim (positive). Header lengths are updated
 *           with the value as passed; the actual trim is clamped to
 *           MAX_ADJST_OFFSET.
 * @hdr_len: length of the Ethernet + IP + UDP headers, which must survive.
 *
 * The trimmed bytes are read first so they can be subtracted from the UDP
 * checksum, then the tail is shrunk.
 *
 * Returns 0 on success, -1 on any failure.
 */
static int xdp_adjst_tail_shrnk_data(struct xdp_md *ctx, __u16 offset,
				     unsigned long hdr_len)
{
	char removed[MAX_ADJST_OFFSET];
	__u32 pkt_len, tail_start;
	struct udphdr *uh;
	__u32 csum = 0;

	uh = update_pkt(ctx, 0 - offset, &csum);
	if (!uh)
		return -1;

	pkt_len = bpf_xdp_get_buff_len(ctx);

	/* Clamp to the scratch buffer size in a form the verifier can bound */
	offset = (offset & 0x1ff) >= MAX_ADJST_OFFSET ? MAX_ADJST_OFFSET :
							offset & 0xff;
	if (!offset)
		return -1;

	/* Make sure we have enough data to avoid eating the header */
	tail_start = pkt_len - offset;
	if (tail_start < hdr_len)
		return -1;

	if (bpf_xdp_load_bytes(ctx, tail_start, removed, offset) < 0)
		return -1;

	/* Subtract the bytes about to be removed from the UDP checksum */
	csum = bpf_csum_diff((__be32 *)removed, offset, 0, 0, csum);
	uh->check = (__u16)csum_fold_udp_helper(csum);

	return bpf_xdp_adjust_tail(ctx, 0 - offset) < 0 ? -1 : 0;
}
375
/* Append @offset bytes, each set to the configured tag, to the packet.
 *
 * @ctx:    XDP context.
 * @offset: number of bytes to append. Header lengths are updated with the
 *          value as passed; the actual growth is clamped to
 *          MAX_ADJST_OFFSET.
 *
 * The fill byte is read from map_xdp_setup[XDP_ADJST_TAG].
 *
 * Returns 0 on success, -1 on any failure.
 */
static int xdp_adjst_tail_grow_data(struct xdp_md *ctx, __u16 offset)
{
	char tmp_buff[MAX_ADJST_OFFSET];
	struct udphdr *udph;
	__u32 udp_csum = 0;
	__u32 buff_len, key;
	__s32 *val;
	__u8 tag;

	/* Update the packet headers before attempting to adjust the tail:
	 * the UDP header pointer obtained here is only used until the
	 * adjustment is made.
	 * Since any failure beyond this point aborts the packet, there is
	 * no risk of passing a packet with wrong headers up the stack.
	 */
	udph = update_pkt(ctx, offset, &udp_csum);
	if (!udph)
		return -1;

	key = XDP_ADJST_TAG;
	val = bpf_map_lookup_elem(&map_xdp_setup, &key);
	if (!val)
		return -1;

	tag = (__u8)(*val);

	/* Byte-wise fill of the whole scratch buffer with the tag */
	for (int i = 0; i < MAX_ADJST_OFFSET; i++)
		__builtin_memcpy(&tmp_buff[i], &tag, 1);

	/* Clamp to the scratch buffer size in a form the verifier can bound */
	offset = (offset & 0x1ff) >= MAX_ADJST_OFFSET ? MAX_ADJST_OFFSET :
							offset & 0xff;
	if (offset == 0)
		return -1;

	/* Add the bytes about to be appended to the UDP checksum */
	udp_csum = bpf_csum_diff(0, 0, (__be32 *)tmp_buff, offset, udp_csum);
	udph->check = (__u16)csum_fold_udp_helper(udp_csum);

	/* Length before growing == position of the first appended byte */
	buff_len = bpf_xdp_get_buff_len(ctx);

	if (bpf_xdp_adjust_tail(ctx, offset) < 0) {
		bpf_printk("Failed to adjust tail\n");
		return -1;
	}

	if (bpf_xdp_store_bytes(ctx, buff_len, tmp_buff, offset) < 0)
		return -1;

	return 0;
}
426
/* XDP_MODE_TAIL_ADJST handler.
 *
 * Packets that do not match the UDP port filter are passed untouched. For
 * matching packets the tail is adjusted by map_xdp_setup[XDP_ADJST_OFFSET]:
 * a negative value shrinks the packet, any other value grows it.
 *
 * Returns XDP_PASS (bumping STATS_PASS) on success, or XDP_ABORTED (bumping
 * STATS_ABORT) when the adjustment fails.
 */
static int xdp_adjst_tail(struct xdp_md *ctx, __u16 port)
{
	__s32 *adjust_offset, adjust;
	unsigned long hdr_len;
	struct udphdr *udph;
	__u32 key;
	int ret;

	udph = filter_udphdr(ctx, port);
	if (!udph)
		return XDP_PASS;

	/* Ethernet + IP + UDP header length */
	hdr_len = (void *)udph - (void *)(long)ctx->data +
		  sizeof(struct udphdr);
	key = XDP_ADJST_OFFSET;
	adjust_offset = bpf_map_lookup_elem(&map_xdp_setup, &key);
	if (!adjust_offset)
		return XDP_PASS;

	/* Read the map value exactly once: it can be rewritten while the
	 * program runs, and sign test and magnitude must agree.
	 */
	adjust = *(volatile __s32 *)adjust_offset;

	if (adjust < 0)
		/* Negate as unsigned so that INT_MIN cannot overflow */
		ret = xdp_adjst_tail_shrnk_data(ctx,
						(__u16)(0U - (__u32)adjust),
						hdr_len);
	else
		ret = xdp_adjst_tail_grow_data(ctx, (__u16)adjust);
	if (ret)
		goto abort_pkt;

	record_stats(ctx, STATS_PASS);
	return XDP_PASS;

abort_pkt:
	record_stats(ctx, STATS_ABORT);
	return XDP_ABORTED;
}
464
/* Remove @offset bytes located right after the packet headers.
 *
 * @ctx:     XDP context.
 * @hdr_len: length of the Ethernet + IP + UDP headers.
 * @offset:  number of bytes to remove, clamped to MAX_ADJST_OFFSET.
 *
 * The bytes to be removed are subtracted from the UDP checksum, the start
 * of the packet is saved, the head is moved forward by @offset and the
 * saved headers are written back at the new start of the packet.
 *
 * Returns 0 on success, -1 on any failure.
 */
static int xdp_adjst_head_shrnk_data(struct xdp_md *ctx, __u64 hdr_len,
				     __u32 offset)
{
	char tmp_buff[MAX_ADJST_OFFSET];
	struct udphdr *udph;
	__u32 udp_csum = 0;

	/* Update the length information in the IP and UDP headers before
	 * adjusting the headroom. This simplifies accessing the relevant
	 * fields in the IP and UDP headers for fragmented packets. Any
	 * failure beyond this point will result in the packet being aborted,
	 * so we don't need to worry about incorrect length information for
	 * passed packets.
	 */
	udph = update_pkt(ctx, (__s16)(0 - offset), &udp_csum);
	if (!udph)
		return -1;

	/* Clamp to the scratch buffer size in a form the verifier can bound */
	offset = (offset & 0x1ff) >= MAX_ADJST_OFFSET ? MAX_ADJST_OFFSET :
							offset & 0xff;
	if (offset == 0)
		return -1;

	/* Fetch the bytes that follow the headers: they are the ones that
	 * disappear, so take them out of the UDP checksum.
	 */
	if (bpf_xdp_load_bytes(ctx, hdr_len, tmp_buff, offset) < 0)
		return -1;

	udp_csum = bpf_csum_diff((__be32 *)tmp_buff, offset, 0, 0, udp_csum);
	udph->check = (__u16)csum_fold_udp_helper(udp_csum);

	/* Save the start of the packet, including the updated headers.
	 * NOTE(review): a fixed MAX_ADJST_OFFSET bytes are copied, so this
	 * presumably fails for packets shorter than that - confirm.
	 */
	if (bpf_xdp_load_bytes(ctx, 0, tmp_buff, MAX_ADJST_OFFSET) < 0)
		return -1;

	if (bpf_xdp_adjust_head(ctx, offset) < 0)
		return -1;

	if (offset > MAX_ADJST_OFFSET)
		return -1;

	if (hdr_len > MAX_ADJST_OFFSET || hdr_len == 0)
		return -1;

	/* Added here to handle clang complaint about negative value */
	hdr_len = hdr_len & 0xff;

	if (hdr_len == 0)
		return -1;

	/* Put the saved headers back at the new start of the packet */
	if (bpf_xdp_store_bytes(ctx, 0, tmp_buff, hdr_len) < 0)
		return -1;

	return 0;
}
518
/* Insert @offset tag bytes right after the packet headers.
 *
 * @ctx:     XDP context.
 * @hdr_len: length of the Ethernet + IP + UDP headers; must be in the range
 *           1..MAX_HDR_LEN because the headers are staged in a
 *           MAX_HDR_LEN-byte stack buffer.
 * @offset:  number of bytes to insert, clamped to MAX_ADJST_OFFSET.
 *
 * The headers (with lengths and checksums already updated) are saved, the
 * head is moved back by @offset, the headers are rewritten at the new start
 * and the gap behind them is filled with map_xdp_setup[XDP_ADJST_TAG].
 *
 * Returns 0 on success, -1 on any failure.
 */
static int xdp_adjst_head_grow_data(struct xdp_md *ctx, __u64 hdr_len,
				    __u32 offset)
{
	char hdr_buff[MAX_HDR_LEN];
	char data_buff[MAX_ADJST_OFFSET];
	struct udphdr *udph;
	__u32 udp_csum = 0;
	__s32 *val;
	__u32 key;
	__u8 tag;

	udph = update_pkt(ctx, (__s16)(offset), &udp_csum);
	if (!udph)
		return -1;

	key = XDP_ADJST_TAG;
	val = bpf_map_lookup_elem(&map_xdp_setup, &key);
	if (!val)
		return -1;

	tag = (__u8)(*val);
	/* Byte-wise fill of the whole scratch buffer with the tag */
	for (int i = 0; i < MAX_ADJST_OFFSET; i++)
		__builtin_memcpy(&data_buff[i], &tag, 1);

	/* Clamp to the scratch buffer size in a form the verifier can bound */
	offset = (offset & 0x1ff) >= MAX_ADJST_OFFSET ? MAX_ADJST_OFFSET :
							offset & 0xff;
	if (offset == 0)
		return -1;

	/* Add the bytes about to be inserted to the UDP checksum */
	udp_csum = bpf_csum_diff(0, 0, (__be32 *)data_buff, offset, udp_csum);
	udph->check = (__u16)csum_fold_udp_helper(udp_csum);

	/* The headers are staged in hdr_buff, so they must fit in it */
	if (hdr_len > MAX_HDR_LEN || hdr_len == 0)
		return -1;

	/* Added here to handle clang complaint about negative value */
	hdr_len = hdr_len & 0xff;

	if (hdr_len == 0)
		return -1;

	if (bpf_xdp_load_bytes(ctx, 0, hdr_buff, hdr_len) < 0)
		return -1;

	if (offset > MAX_ADJST_OFFSET)
		return -1;

	/* A negative delta moves the packet start back, growing it */
	if (bpf_xdp_adjust_head(ctx, 0 - offset) < 0)
		return -1;

	if (bpf_xdp_store_bytes(ctx, 0, hdr_buff, hdr_len) < 0)
		return -1;

	if (bpf_xdp_store_bytes(ctx, hdr_len, data_buff, offset) < 0)
		return -1;

	return 0;
}
578
/* XDP_MODE_HEAD_ADJST handler.
 *
 * Packets that do not match the UDP port filter are passed untouched. For
 * matching packets map_xdp_setup[XDP_ADJST_OFFSET] must be one of
 * +/-16, 32, 64, 128 or 256: a negative value inserts that many bytes
 * after the headers, a positive value removes that many.
 *
 * Returns XDP_PASS (bumping STATS_PASS) on success, or XDP_ABORTED (bumping
 * STATS_ABORT) for an unsupported offset or a failed adjustment.
 */
static int xdp_head_adjst(struct xdp_md *ctx, __u16 port)
{
	struct udphdr *udph_ptr;
	__u32 key, size, hdr_len;
	__s32 *val, adjust;
	int res;

	/* Filter packets based on UDP port */
	udph_ptr = filter_udphdr(ctx, port);
	if (!udph_ptr)
		return XDP_PASS;

	/* Ethernet + IP + UDP header length */
	hdr_len = (void *)udph_ptr - (void *)(long)ctx->data +
		  sizeof(struct udphdr);

	key = XDP_ADJST_OFFSET;
	val = bpf_map_lookup_elem(&map_xdp_setup, &key);
	if (!val)
		return XDP_PASS;

	/* Read the map value exactly once: it can be rewritten while the
	 * program runs, and size and direction must come from the same
	 * value.
	 */
	adjust = *(volatile __s32 *)val;

	switch (adjust) {
	case -16:
	case 16:
		size = 16;
		break;
	case -32:
	case 32:
		size = 32;
		break;
	case -64:
	case 64:
		size = 64;
		break;
	case -128:
	case 128:
		size = 128;
		break;
	case -256:
	case 256:
		size = 256;
		break;
	default:
		bpf_printk("Invalid adjustment offset: %d\n", adjust);
		goto abort;
	}

	if (adjust < 0)
		res = xdp_adjst_head_grow_data(ctx, hdr_len, size);
	else
		res = xdp_adjst_head_shrnk_data(ctx, hdr_len, size);

	if (res)
		goto abort;

	record_stats(ctx, STATS_PASS);
	return XDP_PASS;

abort:
	record_stats(ctx, STATS_ABORT);
	return XDP_ABORTED;
}
640
xdp_prog_common(struct xdp_md * ctx)641 static int xdp_prog_common(struct xdp_md *ctx)
642 {
643 __u32 key, *port;
644 __s32 *mode;
645
646 key = XDP_MODE;
647 mode = bpf_map_lookup_elem(&map_xdp_setup, &key);
648 if (!mode)
649 return XDP_PASS;
650
651 key = XDP_PORT;
652 port = bpf_map_lookup_elem(&map_xdp_setup, &key);
653 if (!port)
654 return XDP_PASS;
655
656 switch (*mode) {
657 case XDP_MODE_PASS:
658 return xdp_mode_pass(ctx, (__u16)(*port));
659 case XDP_MODE_DROP:
660 return xdp_mode_drop_handler(ctx, (__u16)(*port));
661 case XDP_MODE_TX:
662 return xdp_mode_tx_handler(ctx, (__u16)(*port));
663 case XDP_MODE_TAIL_ADJST:
664 return xdp_adjst_tail(ctx, (__u16)(*port));
665 case XDP_MODE_HEAD_ADJST:
666 return xdp_head_adjst(ctx, (__u16)(*port));
667 }
668
669 /* Default action is to simple pass */
670 return XDP_PASS;
671 }
672
/* Program placed in the "xdp" section; all work is delegated to
 * xdp_prog_common().
 */
SEC("xdp")
int xdp_prog(struct xdp_md *ctx)
{
	return xdp_prog_common(ctx);
}
678
/* Same logic as xdp_prog(), placed in the "xdp.frags" section. */
SEC("xdp.frags")
int xdp_prog_frags(struct xdp_md *ctx)
{
	return xdp_prog_common(ctx);
}
684
/* License string placed in the "license" section of the object. */
char _license[] SEC("license") = "GPL";
686