1 // SPDX-License-Identifier: GPL-2.0 2 3 #include <stddef.h> 4 #include <linux/bpf.h> 5 #include <linux/in.h> 6 #include <linux/if_ether.h> 7 #include <linux/ip.h> 8 #include <linux/ipv6.h> 9 #include <linux/udp.h> 10 #include <bpf/bpf_endian.h> 11 #include <bpf/bpf_helpers.h> 12 13 #define MAX_ADJST_OFFSET 256 14 #define MAX_PAYLOAD_LEN 5000 15 #define MAX_HDR_LEN 64 16 17 extern int bpf_xdp_pull_data(struct xdp_md *xdp, __u32 len) __ksym __weak; 18 19 enum { 20 XDP_MODE = 0, 21 XDP_PORT = 1, 22 XDP_ADJST_OFFSET = 2, 23 XDP_ADJST_TAG = 3, 24 } xdp_map_setup_keys; 25 26 enum { 27 XDP_MODE_PASS = 0, 28 XDP_MODE_DROP = 1, 29 XDP_MODE_TX = 2, 30 XDP_MODE_TAIL_ADJST = 3, 31 XDP_MODE_HEAD_ADJST = 4, 32 } xdp_map_modes; 33 34 enum { 35 STATS_RX = 0, 36 STATS_PASS = 1, 37 STATS_DROP = 2, 38 STATS_TX = 3, 39 STATS_ABORT = 4, 40 } xdp_stats; 41 42 struct { 43 __uint(type, BPF_MAP_TYPE_ARRAY); 44 __uint(max_entries, 5); 45 __type(key, __u32); 46 __type(value, __s32); 47 } map_xdp_setup SEC(".maps"); 48 49 struct { 50 __uint(type, BPF_MAP_TYPE_ARRAY); 51 __uint(max_entries, 5); 52 __type(key, __u32); 53 __type(value, __u64); 54 } map_xdp_stats SEC(".maps"); 55 56 static __u32 min(__u32 a, __u32 b) 57 { 58 return a < b ? a : b; 59 } 60 61 static void record_stats(struct xdp_md *ctx, __u32 stat_type) 62 { 63 __u64 *count; 64 65 count = bpf_map_lookup_elem(&map_xdp_stats, &stat_type); 66 67 if (count) 68 __sync_fetch_and_add(count, 1); 69 } 70 71 static struct udphdr *filter_udphdr(struct xdp_md *ctx, __u16 port) 72 { 73 struct udphdr *udph = NULL; 74 void *data, *data_end; 75 struct ethhdr *eth; 76 int err; 77 78 err = bpf_xdp_pull_data(ctx, sizeof(*eth)); 79 if (err) 80 return NULL; 81 82 data_end = (void *)(long)ctx->data_end; 83 data = eth = (void *)(long)ctx->data; 84 85 if (data + sizeof(*eth) > data_end) 86 return NULL; 87 88 if (eth->h_proto == bpf_htons(ETH_P_IP)) { 89 struct iphdr *iph; 90 91 err = bpf_xdp_pull_data(ctx, sizeof(*eth) + sizeof(*iph) + 92 sizeof(*udph)); 93 if (err) 94 return NULL; 95 96 data_end = (void *)(long)ctx->data_end; 97 data = (void *)(long)ctx->data; 98 99 iph = data + sizeof(*eth); 100 101 if (iph + 1 > (struct iphdr *)data_end || 102 iph->protocol != IPPROTO_UDP) 103 return NULL; 104 105 udph = data + sizeof(*iph) + sizeof(*eth); 106 } else if (eth->h_proto == bpf_htons(ETH_P_IPV6)) { 107 struct ipv6hdr *ipv6h; 108 109 err = bpf_xdp_pull_data(ctx, sizeof(*eth) + sizeof(*ipv6h) + 110 sizeof(*udph)); 111 if (err) 112 return NULL; 113 114 data_end = (void *)(long)ctx->data_end; 115 data = (void *)(long)ctx->data; 116 117 ipv6h = data + sizeof(*eth); 118 119 if (ipv6h + 1 > (struct ipv6hdr *)data_end || 120 ipv6h->nexthdr != IPPROTO_UDP) 121 return NULL; 122 123 udph = data + sizeof(*ipv6h) + sizeof(*eth); 124 } else { 125 return NULL; 126 } 127 128 if (udph + 1 > (struct udphdr *)data_end) 129 return NULL; 130 131 if (udph->dest != bpf_htons(port)) 132 return NULL; 133 134 record_stats(ctx, STATS_RX); 135 136 return udph; 137 } 138 139 static int xdp_mode_pass(struct xdp_md *ctx, __u16 port) 140 { 141 struct udphdr *udph = NULL; 142 143 udph = filter_udphdr(ctx, port); 144 if (!udph) 145 return XDP_PASS; 146 147 record_stats(ctx, STATS_PASS); 148 149 return XDP_PASS; 150 } 151 152 static int xdp_mode_drop_handler(struct xdp_md *ctx, __u16 port) 153 { 154 struct udphdr *udph = NULL; 155 156 udph = filter_udphdr(ctx, port); 157 if (!udph) 158 return XDP_PASS; 159 160 record_stats(ctx, STATS_DROP); 161 162 return XDP_DROP; 163 } 164 165 static void swap_machdr(void *data) 166 { 167 struct ethhdr *eth = data; 168 __u8 tmp_mac[ETH_ALEN]; 169 170 __builtin_memcpy(tmp_mac, eth->h_source, ETH_ALEN); 171 __builtin_memcpy(eth->h_source, eth->h_dest, ETH_ALEN); 172 __builtin_memcpy(eth->h_dest, tmp_mac, ETH_ALEN); 173 } 174 175 static int xdp_mode_tx_handler(struct xdp_md *ctx, __u16 port) 176 { 177 struct udphdr *udph = NULL; 178 void *data, *data_end; 179 struct ethhdr *eth; 180 int err; 181 182 err = bpf_xdp_pull_data(ctx, sizeof(*eth)); 183 if (err) 184 return XDP_PASS; 185 186 data_end = (void *)(long)ctx->data_end; 187 data = eth = (void *)(long)ctx->data; 188 189 if (data + sizeof(*eth) > data_end) 190 return XDP_PASS; 191 192 if (eth->h_proto == bpf_htons(ETH_P_IP)) { 193 struct iphdr *iph; 194 __be32 tmp_ip; 195 196 err = bpf_xdp_pull_data(ctx, sizeof(*eth) + sizeof(*iph) + 197 sizeof(*udph)); 198 if (err) 199 return XDP_PASS; 200 201 data_end = (void *)(long)ctx->data_end; 202 data = (void *)(long)ctx->data; 203 204 iph = data + sizeof(*eth); 205 206 if (iph + 1 > (struct iphdr *)data_end || 207 iph->protocol != IPPROTO_UDP) 208 return XDP_PASS; 209 210 udph = data + sizeof(*iph) + sizeof(*eth); 211 212 if (udph + 1 > (struct udphdr *)data_end) 213 return XDP_PASS; 214 if (udph->dest != bpf_htons(port)) 215 return XDP_PASS; 216 217 record_stats(ctx, STATS_RX); 218 eth = data; 219 swap_machdr((void *)eth); 220 221 tmp_ip = iph->saddr; 222 iph->saddr = iph->daddr; 223 iph->daddr = tmp_ip; 224 225 record_stats(ctx, STATS_TX); 226 227 return XDP_TX; 228 229 } else if (eth->h_proto == bpf_htons(ETH_P_IPV6)) { 230 struct in6_addr tmp_ipv6; 231 struct ipv6hdr *ipv6h; 232 233 err = bpf_xdp_pull_data(ctx, sizeof(*eth) + sizeof(*ipv6h) + 234 sizeof(*udph)); 235 if (err) 236 return XDP_PASS; 237 238 data_end = (void *)(long)ctx->data_end; 239 data = (void *)(long)ctx->data; 240 241 ipv6h = data + sizeof(*eth); 242 243 if (ipv6h + 1 > (struct ipv6hdr *)data_end || 244 ipv6h->nexthdr != IPPROTO_UDP) 245 return XDP_PASS; 246 247 udph = data + sizeof(*ipv6h) + sizeof(*eth); 248 249 if (udph + 1 > (struct udphdr *)data_end) 250 return XDP_PASS; 251 if (udph->dest != bpf_htons(port)) 252 return XDP_PASS; 253 254 record_stats(ctx, STATS_RX); 255 eth = data; 256 swap_machdr((void *)eth); 257 258 __builtin_memcpy(&tmp_ipv6, &ipv6h->saddr, sizeof(tmp_ipv6)); 259 __builtin_memcpy(&ipv6h->saddr, &ipv6h->daddr, 260 sizeof(tmp_ipv6)); 261 __builtin_memcpy(&ipv6h->daddr, &tmp_ipv6, sizeof(tmp_ipv6)); 262 263 record_stats(ctx, STATS_TX); 264 265 return XDP_TX; 266 } 267 268 return XDP_PASS; 269 } 270 271 static __always_inline __u16 csum_fold_helper(__u32 csum) 272 { 273 csum = (csum & 0xffff) + (csum >> 16); 274 return ~((csum & 0xffff) + (csum >> 16)); 275 } 276 277 static __always_inline __u16 csum_fold_udp_helper(__u32 csum) 278 { 279 return csum_fold_helper(csum) ? : 0xffff; 280 } 281 282 static void *update_pkt(struct xdp_md *ctx, __s16 offset, __u32 *udp_csum) 283 { 284 void *data_end = (void *)(long)ctx->data_end; 285 void *data = (void *)(long)ctx->data; 286 struct udphdr *udph = NULL; 287 struct ethhdr *eth = data; 288 __u32 len, len_new; 289 290 if (data + sizeof(*eth) > data_end) 291 return NULL; 292 293 if (eth->h_proto == bpf_htons(ETH_P_IP)) { 294 struct iphdr *iph = data + sizeof(*eth); 295 296 if (iph + 1 > (struct iphdr *)data_end) 297 return NULL; 298 299 udph = (void *)eth + sizeof(*iph) + sizeof(*eth); 300 if (!udph || udph + 1 > (struct udphdr *)data_end) 301 return NULL; 302 303 len = iph->tot_len; 304 len_new = bpf_htons(bpf_ntohs(len) + offset); 305 iph->tot_len = len_new; 306 iph->check = csum_fold_helper( 307 bpf_csum_diff(&len, sizeof(len), &len_new, 308 sizeof(len_new), ~((__u32)iph->check))); 309 } else if (eth->h_proto == bpf_htons(ETH_P_IPV6)) { 310 struct ipv6hdr *ipv6h = data + sizeof(*eth); 311 312 if (ipv6h + 1 > (struct ipv6hdr *)data_end) 313 return NULL; 314 315 udph = (void *)eth + sizeof(*ipv6h) + sizeof(*eth); 316 if (!udph || udph + 1 > (struct udphdr *)data_end) 317 return NULL; 318 319 len = ipv6h->payload_len; 320 len_new = bpf_htons(bpf_ntohs(len) + offset); 321 ipv6h->payload_len = len_new; 322 } else { 323 return NULL; 324 } 325 326 len = udph->len; 327 len_new = bpf_htons(bpf_ntohs(len) + offset); 328 329 *udp_csum = ~((__u32)udph->check); 330 *udp_csum = bpf_csum_diff(&len, sizeof(len), &len_new, 331 sizeof(len_new), *udp_csum); 332 *udp_csum = bpf_csum_diff(&len, sizeof(len), &len_new, 333 sizeof(len_new), *udp_csum); 334 335 udph->len = len_new; 336 337 return udph; 338 } 339 340 static int xdp_adjst_tail_shrnk_data(struct xdp_md *ctx, __u16 offset, 341 unsigned long hdr_len) 342 { 343 char tmp_buff[MAX_ADJST_OFFSET]; 344 __u32 buff_pos, udp_csum = 0; 345 struct udphdr *udph = NULL; 346 __u32 buff_len; 347 348 udph = update_pkt(ctx, 0 - offset, &udp_csum); 349 if (!udph) 350 return -1; 351 352 buff_len = bpf_xdp_get_buff_len(ctx); 353 354 offset = (offset & 0x1ff) >= MAX_ADJST_OFFSET ? MAX_ADJST_OFFSET : 355 offset & 0xff; 356 if (offset == 0) 357 return -1; 358 359 /* Make sure we have enough data to avoid eating the header */ 360 if (buff_len - offset < hdr_len) 361 return -1; 362 363 buff_pos = buff_len - offset; 364 if (bpf_xdp_load_bytes(ctx, buff_pos, tmp_buff, offset) < 0) 365 return -1; 366 367 udp_csum = bpf_csum_diff((__be32 *)tmp_buff, offset, 0, 0, udp_csum); 368 udph->check = (__u16)csum_fold_udp_helper(udp_csum); 369 370 if (bpf_xdp_adjust_tail(ctx, 0 - offset) < 0) 371 return -1; 372 373 return 0; 374 } 375 376 static int xdp_adjst_tail_grow_data(struct xdp_md *ctx, __u16 offset) 377 { 378 char tmp_buff[MAX_ADJST_OFFSET]; 379 __u32 buff_pos, udp_csum = 0; 380 __u32 buff_len, hdr_len, key; 381 struct udphdr *udph; 382 __s32 *val; 383 __u8 tag; 384 385 /* Proceed to update the packet headers before attempting to adjuste 386 * the tail. Once the tail is adjusted we lose access to the offset 387 * amount of data at the end of the packet which is crucial to update 388 * the checksum. 389 * Since any failure beyond this would abort the packet, we should 390 * not worry about passing a packet up the stack with wrong headers 391 */ 392 udph = update_pkt(ctx, offset, &udp_csum); 393 if (!udph) 394 return -1; 395 396 key = XDP_ADJST_TAG; 397 val = bpf_map_lookup_elem(&map_xdp_setup, &key); 398 if (!val) 399 return -1; 400 401 tag = (__u8)(*val); 402 403 for (int i = 0; i < MAX_ADJST_OFFSET; i++) 404 __builtin_memcpy(&tmp_buff[i], &tag, 1); 405 406 offset = (offset & 0x1ff) >= MAX_ADJST_OFFSET ? MAX_ADJST_OFFSET : 407 offset & 0xff; 408 if (offset == 0) 409 return -1; 410 411 udp_csum = bpf_csum_diff(0, 0, (__be32 *)tmp_buff, offset, udp_csum); 412 udph->check = (__u16)csum_fold_udp_helper(udp_csum); 413 414 buff_len = bpf_xdp_get_buff_len(ctx); 415 416 if (bpf_xdp_adjust_tail(ctx, offset) < 0) { 417 bpf_printk("Failed to adjust tail\n"); 418 return -1; 419 } 420 421 if (bpf_xdp_store_bytes(ctx, buff_len, tmp_buff, offset) < 0) 422 return -1; 423 424 return 0; 425 } 426 427 static int xdp_adjst_tail(struct xdp_md *ctx, __u16 port) 428 { 429 struct udphdr *udph = NULL; 430 __s32 *adjust_offset, *val; 431 unsigned long hdr_len; 432 void *offset_ptr; 433 __u32 key; 434 __u8 tag; 435 int ret; 436 437 udph = filter_udphdr(ctx, port); 438 if (!udph) 439 return XDP_PASS; 440 441 hdr_len = (void *)udph - (void *)(long)ctx->data + 442 sizeof(struct udphdr); 443 key = XDP_ADJST_OFFSET; 444 adjust_offset = bpf_map_lookup_elem(&map_xdp_setup, &key); 445 if (!adjust_offset) 446 return XDP_PASS; 447 448 if (*adjust_offset < 0) 449 ret = xdp_adjst_tail_shrnk_data(ctx, 450 (__u16)(0 - *adjust_offset), 451 hdr_len); 452 else 453 ret = xdp_adjst_tail_grow_data(ctx, (__u16)(*adjust_offset)); 454 if (ret) 455 goto abort_pkt; 456 457 record_stats(ctx, STATS_PASS); 458 return XDP_PASS; 459 460 abort_pkt: 461 record_stats(ctx, STATS_ABORT); 462 return XDP_ABORTED; 463 } 464 465 static int xdp_adjst_head_shrnk_data(struct xdp_md *ctx, __u64 hdr_len, 466 __u32 offset) 467 { 468 char tmp_buff[MAX_ADJST_OFFSET]; 469 struct udphdr *udph; 470 void *offset_ptr; 471 __u32 udp_csum = 0; 472 473 /* Update the length information in the IP and UDP headers before 474 * adjusting the headroom. This simplifies accessing the relevant 475 * fields in the IP and UDP headers for fragmented packets. Any 476 * failure beyond this point will result in the packet being aborted, 477 * so we don't need to worry about incorrect length information for 478 * passed packets. 479 */ 480 udph = update_pkt(ctx, (__s16)(0 - offset), &udp_csum); 481 if (!udph) 482 return -1; 483 484 offset = (offset & 0x1ff) >= MAX_ADJST_OFFSET ? MAX_ADJST_OFFSET : 485 offset & 0xff; 486 if (offset == 0) 487 return -1; 488 489 if (bpf_xdp_load_bytes(ctx, hdr_len, tmp_buff, offset) < 0) 490 return -1; 491 492 udp_csum = bpf_csum_diff((__be32 *)tmp_buff, offset, 0, 0, udp_csum); 493 udph->check = (__u16)csum_fold_udp_helper(udp_csum); 494 495 if (bpf_xdp_load_bytes(ctx, 0, tmp_buff, MAX_ADJST_OFFSET) < 0) 496 return -1; 497 498 if (bpf_xdp_adjust_head(ctx, offset) < 0) 499 return -1; 500 501 if (offset > MAX_ADJST_OFFSET) 502 return -1; 503 504 if (hdr_len > MAX_ADJST_OFFSET || hdr_len == 0) 505 return -1; 506 507 /* Added here to handle clang complain about negative value */ 508 hdr_len = hdr_len & 0xff; 509 510 if (hdr_len == 0) 511 return -1; 512 513 if (bpf_xdp_store_bytes(ctx, 0, tmp_buff, hdr_len) < 0) 514 return -1; 515 516 return 0; 517 } 518 519 static int xdp_adjst_head_grow_data(struct xdp_md *ctx, __u64 hdr_len, 520 __u32 offset) 521 { 522 char hdr_buff[MAX_HDR_LEN]; 523 char data_buff[MAX_ADJST_OFFSET]; 524 void *offset_ptr; 525 __s32 *val; 526 __u32 key; 527 __u8 tag; 528 __u32 udp_csum = 0; 529 struct udphdr *udph; 530 531 udph = update_pkt(ctx, (__s16)(offset), &udp_csum); 532 if (!udph) 533 return -1; 534 535 key = XDP_ADJST_TAG; 536 val = bpf_map_lookup_elem(&map_xdp_setup, &key); 537 if (!val) 538 return -1; 539 540 tag = (__u8)(*val); 541 for (int i = 0; i < MAX_ADJST_OFFSET; i++) 542 __builtin_memcpy(&data_buff[i], &tag, 1); 543 544 offset = (offset & 0x1ff) >= MAX_ADJST_OFFSET ? MAX_ADJST_OFFSET : 545 offset & 0xff; 546 if (offset == 0) 547 return -1; 548 549 udp_csum = bpf_csum_diff(0, 0, (__be32 *)data_buff, offset, udp_csum); 550 udph->check = (__u16)csum_fold_udp_helper(udp_csum); 551 552 if (hdr_len > MAX_ADJST_OFFSET || hdr_len == 0) 553 return -1; 554 555 /* Added here to handle clang complain about negative value */ 556 hdr_len = hdr_len & 0xff; 557 558 if (hdr_len == 0) 559 return -1; 560 561 if (bpf_xdp_load_bytes(ctx, 0, hdr_buff, hdr_len) < 0) 562 return -1; 563 564 if (offset > MAX_ADJST_OFFSET) 565 return -1; 566 567 if (bpf_xdp_adjust_head(ctx, 0 - offset) < 0) 568 return -1; 569 570 if (bpf_xdp_store_bytes(ctx, 0, hdr_buff, hdr_len) < 0) 571 return -1; 572 573 if (bpf_xdp_store_bytes(ctx, hdr_len, data_buff, offset) < 0) 574 return -1; 575 576 return 0; 577 } 578 579 static int xdp_head_adjst(struct xdp_md *ctx, __u16 port) 580 { 581 struct udphdr *udph_ptr = NULL; 582 __u32 key, size, hdr_len; 583 __s32 *val; 584 int res; 585 586 /* Filter packets based on UDP port */ 587 udph_ptr = filter_udphdr(ctx, port); 588 if (!udph_ptr) 589 return XDP_PASS; 590 591 hdr_len = (void *)udph_ptr - (void *)(long)ctx->data + 592 sizeof(struct udphdr); 593 594 key = XDP_ADJST_OFFSET; 595 val = bpf_map_lookup_elem(&map_xdp_setup, &key); 596 if (!val) 597 return XDP_PASS; 598 599 switch (*val) { 600 case -16: 601 case 16: 602 size = 16; 603 break; 604 case -32: 605 case 32: 606 size = 32; 607 break; 608 case -64: 609 case 64: 610 size = 64; 611 break; 612 case -128: 613 case 128: 614 size = 128; 615 break; 616 case -256: 617 case 256: 618 size = 256; 619 break; 620 default: 621 bpf_printk("Invalid adjustment offset: %d\n", *val); 622 goto abort; 623 } 624 625 if (*val < 0) 626 res = xdp_adjst_head_grow_data(ctx, hdr_len, size); 627 else 628 res = xdp_adjst_head_shrnk_data(ctx, hdr_len, size); 629 630 if (res) 631 goto abort; 632 633 record_stats(ctx, STATS_PASS); 634 return XDP_PASS; 635 636 abort: 637 record_stats(ctx, STATS_ABORT); 638 return XDP_ABORTED; 639 } 640 641 static int xdp_prog_common(struct xdp_md *ctx) 642 { 643 __u32 key, *port; 644 __s32 *mode; 645 646 key = XDP_MODE; 647 mode = bpf_map_lookup_elem(&map_xdp_setup, &key); 648 if (!mode) 649 return XDP_PASS; 650 651 key = XDP_PORT; 652 port = bpf_map_lookup_elem(&map_xdp_setup, &key); 653 if (!port) 654 return XDP_PASS; 655 656 switch (*mode) { 657 case XDP_MODE_PASS: 658 return xdp_mode_pass(ctx, (__u16)(*port)); 659 case XDP_MODE_DROP: 660 return xdp_mode_drop_handler(ctx, (__u16)(*port)); 661 case XDP_MODE_TX: 662 return xdp_mode_tx_handler(ctx, (__u16)(*port)); 663 case XDP_MODE_TAIL_ADJST: 664 return xdp_adjst_tail(ctx, (__u16)(*port)); 665 case XDP_MODE_HEAD_ADJST: 666 return xdp_head_adjst(ctx, (__u16)(*port)); 667 } 668 669 /* Default action is to simple pass */ 670 return XDP_PASS; 671 } 672 673 SEC("xdp") 674 int xdp_prog(struct xdp_md *ctx) 675 { 676 return xdp_prog_common(ctx); 677 } 678 679 SEC("xdp.frags") 680 int xdp_prog_frags(struct xdp_md *ctx) 681 { 682 return xdp_prog_common(ctx); 683 } 684 685 char _license[] SEC("license") = "GPL"; 686