1 // SPDX-License-Identifier: GPL-2.0 2 3 #include <stddef.h> 4 #include <linux/bpf.h> 5 #include <linux/in.h> 6 #include <linux/if_ether.h> 7 #include <linux/ip.h> 8 #include <linux/ipv6.h> 9 #include <linux/udp.h> 10 #include <bpf/bpf_endian.h> 11 #include <bpf/bpf_helpers.h> 12 13 #define MAX_ADJST_OFFSET 256 14 #define MAX_PAYLOAD_LEN 5000 15 #define MAX_HDR_LEN 64 16 17 extern int bpf_xdp_pull_data(struct xdp_md *xdp, __u32 len) __ksym __weak; 18 19 enum { 20 XDP_MODE = 0, 21 XDP_PORT = 1, 22 XDP_ADJST_OFFSET = 2, 23 XDP_ADJST_TAG = 3, 24 } xdp_map_setup_keys; 25 26 enum { 27 XDP_MODE_PASS = 0, 28 XDP_MODE_DROP = 1, 29 XDP_MODE_TX = 2, 30 XDP_MODE_TAIL_ADJST = 3, 31 XDP_MODE_HEAD_ADJST = 4, 32 } xdp_map_modes; 33 34 enum { 35 STATS_RX = 0, 36 STATS_PASS = 1, 37 STATS_DROP = 2, 38 STATS_TX = 3, 39 STATS_ABORT = 4, 40 } xdp_stats; 41 42 struct { 43 __uint(type, BPF_MAP_TYPE_ARRAY); 44 __uint(max_entries, 5); 45 __type(key, __u32); 46 __type(value, __s32); 47 } map_xdp_setup SEC(".maps"); 48 49 struct { 50 __uint(type, BPF_MAP_TYPE_ARRAY); 51 __uint(max_entries, 5); 52 __type(key, __u32); 53 __type(value, __u64); 54 } map_xdp_stats SEC(".maps"); 55 56 static __u32 min(__u32 a, __u32 b) 57 { 58 return a < b ? a : b; 59 } 60 61 static void record_stats(struct xdp_md *ctx, __u32 stat_type) 62 { 63 __u64 *count; 64 65 count = bpf_map_lookup_elem(&map_xdp_stats, &stat_type); 66 67 if (count) 68 __sync_fetch_and_add(count, 1); 69 } 70 71 static struct udphdr *filter_udphdr(struct xdp_md *ctx, __u16 port) 72 { 73 struct udphdr *udph = NULL; 74 void *data, *data_end; 75 struct ethhdr *eth; 76 int err; 77 78 err = bpf_xdp_pull_data(ctx, sizeof(*eth)); 79 if (err) 80 return NULL; 81 82 data_end = (void *)(long)ctx->data_end; 83 data = eth = (void *)(long)ctx->data; 84 85 if (data + sizeof(*eth) > data_end) 86 return NULL; 87 88 if (eth->h_proto == bpf_htons(ETH_P_IP)) { 89 struct iphdr *iph; 90 91 err = bpf_xdp_pull_data(ctx, sizeof(*eth) + sizeof(*iph) + 92 sizeof(*udph)); 93 if (err) 94 return NULL; 95 96 data_end = (void *)(long)ctx->data_end; 97 data = (void *)(long)ctx->data; 98 99 iph = data + sizeof(*eth); 100 101 if (iph + 1 > (struct iphdr *)data_end || 102 iph->protocol != IPPROTO_UDP) 103 return NULL; 104 105 udph = data + sizeof(*iph) + sizeof(*eth); 106 } else if (eth->h_proto == bpf_htons(ETH_P_IPV6)) { 107 struct ipv6hdr *ipv6h; 108 109 err = bpf_xdp_pull_data(ctx, sizeof(*eth) + sizeof(*ipv6h) + 110 sizeof(*udph)); 111 if (err) 112 return NULL; 113 114 data_end = (void *)(long)ctx->data_end; 115 data = (void *)(long)ctx->data; 116 117 ipv6h = data + sizeof(*eth); 118 119 if (ipv6h + 1 > (struct ipv6hdr *)data_end || 120 ipv6h->nexthdr != IPPROTO_UDP) 121 return NULL; 122 123 udph = data + sizeof(*ipv6h) + sizeof(*eth); 124 } else { 125 return NULL; 126 } 127 128 if (udph + 1 > (struct udphdr *)data_end) 129 return NULL; 130 131 if (udph->dest != bpf_htons(port)) 132 return NULL; 133 134 record_stats(ctx, STATS_RX); 135 136 return udph; 137 } 138 139 static int xdp_mode_pass(struct xdp_md *ctx, __u16 port) 140 { 141 struct udphdr *udph = NULL; 142 143 udph = filter_udphdr(ctx, port); 144 if (!udph) 145 return XDP_PASS; 146 147 record_stats(ctx, STATS_PASS); 148 149 return XDP_PASS; 150 } 151 152 static int xdp_mode_drop_handler(struct xdp_md *ctx, __u16 port) 153 { 154 struct udphdr *udph = NULL; 155 156 udph = filter_udphdr(ctx, port); 157 if (!udph) 158 return XDP_PASS; 159 160 record_stats(ctx, STATS_DROP); 161 162 return XDP_DROP; 163 } 164 165 static void swap_machdr(void *data) 166 { 167 struct ethhdr *eth = data; 168 __u8 tmp_mac[ETH_ALEN]; 169 170 __builtin_memcpy(tmp_mac, eth->h_source, ETH_ALEN); 171 __builtin_memcpy(eth->h_source, eth->h_dest, ETH_ALEN); 172 __builtin_memcpy(eth->h_dest, tmp_mac, ETH_ALEN); 173 } 174 175 static int xdp_mode_tx_handler(struct xdp_md *ctx, __u16 port) 176 { 177 struct udphdr *udph = NULL; 178 void *data, *data_end; 179 struct ethhdr *eth; 180 int err; 181 182 err = bpf_xdp_pull_data(ctx, sizeof(*eth)); 183 if (err) 184 return XDP_PASS; 185 186 data_end = (void *)(long)ctx->data_end; 187 data = eth = (void *)(long)ctx->data; 188 189 if (data + sizeof(*eth) > data_end) 190 return XDP_PASS; 191 192 if (eth->h_proto == bpf_htons(ETH_P_IP)) { 193 struct iphdr *iph; 194 __be32 tmp_ip; 195 196 err = bpf_xdp_pull_data(ctx, sizeof(*eth) + sizeof(*iph) + 197 sizeof(*udph)); 198 if (err) 199 return XDP_PASS; 200 201 data_end = (void *)(long)ctx->data_end; 202 data = (void *)(long)ctx->data; 203 204 iph = data + sizeof(*eth); 205 206 if (iph + 1 > (struct iphdr *)data_end || 207 iph->protocol != IPPROTO_UDP) 208 return XDP_PASS; 209 210 udph = data + sizeof(*iph) + sizeof(*eth); 211 212 if (udph + 1 > (struct udphdr *)data_end) 213 return XDP_PASS; 214 if (udph->dest != bpf_htons(port)) 215 return XDP_PASS; 216 217 record_stats(ctx, STATS_RX); 218 eth = data; 219 swap_machdr((void *)eth); 220 221 tmp_ip = iph->saddr; 222 iph->saddr = iph->daddr; 223 iph->daddr = tmp_ip; 224 225 record_stats(ctx, STATS_TX); 226 227 return XDP_TX; 228 229 } else if (eth->h_proto == bpf_htons(ETH_P_IPV6)) { 230 struct in6_addr tmp_ipv6; 231 struct ipv6hdr *ipv6h; 232 233 err = bpf_xdp_pull_data(ctx, sizeof(*eth) + sizeof(*ipv6h) + 234 sizeof(*udph)); 235 if (err) 236 return XDP_PASS; 237 238 data_end = (void *)(long)ctx->data_end; 239 data = (void *)(long)ctx->data; 240 241 ipv6h = data + sizeof(*eth); 242 243 if (ipv6h + 1 > (struct ipv6hdr *)data_end || 244 ipv6h->nexthdr != IPPROTO_UDP) 245 return XDP_PASS; 246 247 udph = data + sizeof(*ipv6h) + sizeof(*eth); 248 249 if (udph + 1 > (struct udphdr *)data_end) 250 return XDP_PASS; 251 if (udph->dest != bpf_htons(port)) 252 return XDP_PASS; 253 254 record_stats(ctx, STATS_RX); 255 eth = data; 256 swap_machdr((void *)eth); 257 258 __builtin_memcpy(&tmp_ipv6, &ipv6h->saddr, sizeof(tmp_ipv6)); 259 __builtin_memcpy(&ipv6h->saddr, &ipv6h->daddr, 260 sizeof(tmp_ipv6)); 261 __builtin_memcpy(&ipv6h->daddr, &tmp_ipv6, sizeof(tmp_ipv6)); 262 263 record_stats(ctx, STATS_TX); 264 265 return XDP_TX; 266 } 267 268 return XDP_PASS; 269 } 270 271 static void *update_pkt(struct xdp_md *ctx, __s16 offset, __u32 *udp_csum) 272 { 273 void *data_end = (void *)(long)ctx->data_end; 274 void *data = (void *)(long)ctx->data; 275 struct udphdr *udph = NULL; 276 struct ethhdr *eth = data; 277 __u32 len, len_new; 278 279 if (data + sizeof(*eth) > data_end) 280 return NULL; 281 282 if (eth->h_proto == bpf_htons(ETH_P_IP)) { 283 struct iphdr *iph = data + sizeof(*eth); 284 __u16 total_len; 285 286 if (iph + 1 > (struct iphdr *)data_end) 287 return NULL; 288 289 iph->tot_len = bpf_htons(bpf_ntohs(iph->tot_len) + offset); 290 291 udph = (void *)eth + sizeof(*iph) + sizeof(*eth); 292 if (!udph || udph + 1 > (struct udphdr *)data_end) 293 return NULL; 294 295 len_new = bpf_htons(bpf_ntohs(udph->len) + offset); 296 } else if (eth->h_proto == bpf_htons(ETH_P_IPV6)) { 297 struct ipv6hdr *ipv6h = data + sizeof(*eth); 298 __u16 payload_len; 299 300 if (ipv6h + 1 > (struct ipv6hdr *)data_end) 301 return NULL; 302 303 udph = (void *)eth + sizeof(*ipv6h) + sizeof(*eth); 304 if (!udph || udph + 1 > (struct udphdr *)data_end) 305 return NULL; 306 307 *udp_csum = ~((__u32)udph->check); 308 309 len = ipv6h->payload_len; 310 len_new = bpf_htons(bpf_ntohs(len) + offset); 311 ipv6h->payload_len = len_new; 312 313 *udp_csum = bpf_csum_diff(&len, sizeof(len), &len_new, 314 sizeof(len_new), *udp_csum); 315 316 len = udph->len; 317 len_new = bpf_htons(bpf_ntohs(udph->len) + offset); 318 *udp_csum = bpf_csum_diff(&len, sizeof(len), &len_new, 319 sizeof(len_new), *udp_csum); 320 } else { 321 return NULL; 322 } 323 324 udph->len = len_new; 325 326 return udph; 327 } 328 329 static __u16 csum_fold_helper(__u32 csum) 330 { 331 return ~((csum & 0xffff) + (csum >> 16)) ? : 0xffff; 332 } 333 334 static int xdp_adjst_tail_shrnk_data(struct xdp_md *ctx, __u16 offset, 335 __u32 hdr_len) 336 { 337 char tmp_buff[MAX_ADJST_OFFSET]; 338 __u32 buff_pos, udp_csum = 0; 339 struct udphdr *udph = NULL; 340 __u32 buff_len; 341 342 udph = update_pkt(ctx, 0 - offset, &udp_csum); 343 if (!udph) 344 return -1; 345 346 buff_len = bpf_xdp_get_buff_len(ctx); 347 348 offset = (offset & 0x1ff) >= MAX_ADJST_OFFSET ? MAX_ADJST_OFFSET : 349 offset & 0xff; 350 if (offset == 0) 351 return -1; 352 353 /* Make sure we have enough data to avoid eating the header */ 354 if (buff_len - offset < hdr_len) 355 return -1; 356 357 buff_pos = buff_len - offset; 358 if (bpf_xdp_load_bytes(ctx, buff_pos, tmp_buff, offset) < 0) 359 return -1; 360 361 udp_csum = bpf_csum_diff((__be32 *)tmp_buff, offset, 0, 0, udp_csum); 362 udph->check = (__u16)csum_fold_helper(udp_csum); 363 364 if (bpf_xdp_adjust_tail(ctx, 0 - offset) < 0) 365 return -1; 366 367 return 0; 368 } 369 370 static int xdp_adjst_tail_grow_data(struct xdp_md *ctx, __u16 offset) 371 { 372 char tmp_buff[MAX_ADJST_OFFSET]; 373 __u32 buff_pos, udp_csum = 0; 374 __u32 buff_len, hdr_len, key; 375 struct udphdr *udph; 376 __s32 *val; 377 __u8 tag; 378 379 /* Proceed to update the packet headers before attempting to adjuste 380 * the tail. Once the tail is adjusted we lose access to the offset 381 * amount of data at the end of the packet which is crucial to update 382 * the checksum. 383 * Since any failure beyond this would abort the packet, we should 384 * not worry about passing a packet up the stack with wrong headers 385 */ 386 udph = update_pkt(ctx, offset, &udp_csum); 387 if (!udph) 388 return -1; 389 390 key = XDP_ADJST_TAG; 391 val = bpf_map_lookup_elem(&map_xdp_setup, &key); 392 if (!val) 393 return -1; 394 395 tag = (__u8)(*val); 396 397 for (int i = 0; i < MAX_ADJST_OFFSET; i++) 398 __builtin_memcpy(&tmp_buff[i], &tag, 1); 399 400 offset = (offset & 0x1ff) >= MAX_ADJST_OFFSET ? MAX_ADJST_OFFSET : 401 offset & 0xff; 402 if (offset == 0) 403 return -1; 404 405 udp_csum = bpf_csum_diff(0, 0, (__be32 *)tmp_buff, offset, udp_csum); 406 udph->check = (__u16)csum_fold_helper(udp_csum); 407 408 buff_len = bpf_xdp_get_buff_len(ctx); 409 410 if (bpf_xdp_adjust_tail(ctx, offset) < 0) { 411 bpf_printk("Failed to adjust tail\n"); 412 return -1; 413 } 414 415 if (bpf_xdp_store_bytes(ctx, buff_len, tmp_buff, offset) < 0) 416 return -1; 417 418 return 0; 419 } 420 421 static int xdp_adjst_tail(struct xdp_md *ctx, __u16 port) 422 { 423 struct udphdr *udph = NULL; 424 __s32 *adjust_offset, *val; 425 __u32 key, hdr_len; 426 void *offset_ptr; 427 __u8 tag; 428 int ret; 429 430 udph = filter_udphdr(ctx, port); 431 if (!udph) 432 return XDP_PASS; 433 434 hdr_len = (void *)udph - (void *)(long)ctx->data + 435 sizeof(struct udphdr); 436 key = XDP_ADJST_OFFSET; 437 adjust_offset = bpf_map_lookup_elem(&map_xdp_setup, &key); 438 if (!adjust_offset) 439 return XDP_PASS; 440 441 if (*adjust_offset < 0) 442 ret = xdp_adjst_tail_shrnk_data(ctx, 443 (__u16)(0 - *adjust_offset), 444 hdr_len); 445 else 446 ret = xdp_adjst_tail_grow_data(ctx, (__u16)(*adjust_offset)); 447 if (ret) 448 goto abort_pkt; 449 450 record_stats(ctx, STATS_PASS); 451 return XDP_PASS; 452 453 abort_pkt: 454 record_stats(ctx, STATS_ABORT); 455 return XDP_ABORTED; 456 } 457 458 static int xdp_adjst_head_shrnk_data(struct xdp_md *ctx, __u64 hdr_len, 459 __u32 offset) 460 { 461 char tmp_buff[MAX_ADJST_OFFSET]; 462 struct udphdr *udph; 463 void *offset_ptr; 464 __u32 udp_csum = 0; 465 466 /* Update the length information in the IP and UDP headers before 467 * adjusting the headroom. This simplifies accessing the relevant 468 * fields in the IP and UDP headers for fragmented packets. Any 469 * failure beyond this point will result in the packet being aborted, 470 * so we don't need to worry about incorrect length information for 471 * passed packets. 472 */ 473 udph = update_pkt(ctx, (__s16)(0 - offset), &udp_csum); 474 if (!udph) 475 return -1; 476 477 offset = (offset & 0x1ff) >= MAX_ADJST_OFFSET ? MAX_ADJST_OFFSET : 478 offset & 0xff; 479 if (offset == 0) 480 return -1; 481 482 if (bpf_xdp_load_bytes(ctx, hdr_len, tmp_buff, offset) < 0) 483 return -1; 484 485 udp_csum = bpf_csum_diff((__be32 *)tmp_buff, offset, 0, 0, udp_csum); 486 487 udph->check = (__u16)csum_fold_helper(udp_csum); 488 489 if (bpf_xdp_load_bytes(ctx, 0, tmp_buff, MAX_ADJST_OFFSET) < 0) 490 return -1; 491 492 if (bpf_xdp_adjust_head(ctx, offset) < 0) 493 return -1; 494 495 if (offset > MAX_ADJST_OFFSET) 496 return -1; 497 498 if (hdr_len > MAX_ADJST_OFFSET || hdr_len == 0) 499 return -1; 500 501 /* Added here to handle clang complain about negative value */ 502 hdr_len = hdr_len & 0xff; 503 504 if (hdr_len == 0) 505 return -1; 506 507 if (bpf_xdp_store_bytes(ctx, 0, tmp_buff, hdr_len) < 0) 508 return -1; 509 510 return 0; 511 } 512 513 static int xdp_adjst_head_grow_data(struct xdp_md *ctx, __u64 hdr_len, 514 __u32 offset) 515 { 516 char hdr_buff[MAX_HDR_LEN]; 517 char data_buff[MAX_ADJST_OFFSET]; 518 void *offset_ptr; 519 __s32 *val; 520 __u32 key; 521 __u8 tag; 522 __u32 udp_csum = 0; 523 struct udphdr *udph; 524 525 udph = update_pkt(ctx, (__s16)(offset), &udp_csum); 526 if (!udph) 527 return -1; 528 529 key = XDP_ADJST_TAG; 530 val = bpf_map_lookup_elem(&map_xdp_setup, &key); 531 if (!val) 532 return -1; 533 534 tag = (__u8)(*val); 535 for (int i = 0; i < MAX_ADJST_OFFSET; i++) 536 __builtin_memcpy(&data_buff[i], &tag, 1); 537 538 offset = (offset & 0x1ff) >= MAX_ADJST_OFFSET ? MAX_ADJST_OFFSET : 539 offset & 0xff; 540 if (offset == 0) 541 return -1; 542 543 udp_csum = bpf_csum_diff(0, 0, (__be32 *)data_buff, offset, udp_csum); 544 udph->check = (__u16)csum_fold_helper(udp_csum); 545 546 if (hdr_len > MAX_ADJST_OFFSET || hdr_len == 0) 547 return -1; 548 549 /* Added here to handle clang complain about negative value */ 550 hdr_len = hdr_len & 0xff; 551 552 if (hdr_len == 0) 553 return -1; 554 555 if (bpf_xdp_load_bytes(ctx, 0, hdr_buff, hdr_len) < 0) 556 return -1; 557 558 if (offset > MAX_ADJST_OFFSET) 559 return -1; 560 561 if (bpf_xdp_adjust_head(ctx, 0 - offset) < 0) 562 return -1; 563 564 if (bpf_xdp_store_bytes(ctx, 0, hdr_buff, hdr_len) < 0) 565 return -1; 566 567 if (bpf_xdp_store_bytes(ctx, hdr_len, data_buff, offset) < 0) 568 return -1; 569 570 return 0; 571 } 572 573 static int xdp_head_adjst(struct xdp_md *ctx, __u16 port) 574 { 575 struct udphdr *udph_ptr = NULL; 576 __u32 key, size, hdr_len; 577 __s32 *val; 578 int res; 579 580 /* Filter packets based on UDP port */ 581 udph_ptr = filter_udphdr(ctx, port); 582 if (!udph_ptr) 583 return XDP_PASS; 584 585 hdr_len = (void *)udph_ptr - (void *)(long)ctx->data + 586 sizeof(struct udphdr); 587 588 key = XDP_ADJST_OFFSET; 589 val = bpf_map_lookup_elem(&map_xdp_setup, &key); 590 if (!val) 591 return XDP_PASS; 592 593 switch (*val) { 594 case -16: 595 case 16: 596 size = 16; 597 break; 598 case -32: 599 case 32: 600 size = 32; 601 break; 602 case -64: 603 case 64: 604 size = 64; 605 break; 606 case -128: 607 case 128: 608 size = 128; 609 break; 610 case -256: 611 case 256: 612 size = 256; 613 break; 614 default: 615 bpf_printk("Invalid adjustment offset: %d\n", *val); 616 goto abort; 617 } 618 619 if (*val < 0) 620 res = xdp_adjst_head_grow_data(ctx, hdr_len, size); 621 else 622 res = xdp_adjst_head_shrnk_data(ctx, hdr_len, size); 623 624 if (res) 625 goto abort; 626 627 record_stats(ctx, STATS_PASS); 628 return XDP_PASS; 629 630 abort: 631 record_stats(ctx, STATS_ABORT); 632 return XDP_ABORTED; 633 } 634 635 static int xdp_prog_common(struct xdp_md *ctx) 636 { 637 __u32 key, *port; 638 __s32 *mode; 639 640 key = XDP_MODE; 641 mode = bpf_map_lookup_elem(&map_xdp_setup, &key); 642 if (!mode) 643 return XDP_PASS; 644 645 key = XDP_PORT; 646 port = bpf_map_lookup_elem(&map_xdp_setup, &key); 647 if (!port) 648 return XDP_PASS; 649 650 switch (*mode) { 651 case XDP_MODE_PASS: 652 return xdp_mode_pass(ctx, (__u16)(*port)); 653 case XDP_MODE_DROP: 654 return xdp_mode_drop_handler(ctx, (__u16)(*port)); 655 case XDP_MODE_TX: 656 return xdp_mode_tx_handler(ctx, (__u16)(*port)); 657 case XDP_MODE_TAIL_ADJST: 658 return xdp_adjst_tail(ctx, (__u16)(*port)); 659 case XDP_MODE_HEAD_ADJST: 660 return xdp_head_adjst(ctx, (__u16)(*port)); 661 } 662 663 /* Default action is to simple pass */ 664 return XDP_PASS; 665 } 666 667 SEC("xdp") 668 int xdp_prog(struct xdp_md *ctx) 669 { 670 return xdp_prog_common(ctx); 671 } 672 673 SEC("xdp.frags") 674 int xdp_prog_frags(struct xdp_md *ctx) 675 { 676 return xdp_prog_common(ctx); 677 } 678 679 char _license[] SEC("license") = "GPL"; 680