// SPDX-License-Identifier: GPL-2.0

/*
 * XDP test program with runtime-selectable behavior: pass, drop, TX-bounce,
 * tail adjust (grow/shrink) and head adjust (grow/shrink) of matching UDP
 * packets.  The active mode, the UDP destination port to match, the signed
 * adjustment offset and the fill tag are read from map_xdp_setup; verdict
 * counters are accumulated in map_xdp_stats.
 * NOTE(review): both maps are presumably written/read by a userspace test
 * harness — confirm against the loader.
 */

#include <stddef.h>
#include <linux/bpf.h>
#include <linux/in.h>
#include <linux/if_ether.h>
#include <linux/ip.h>
#include <linux/ipv6.h>
#include <linux/udp.h>
#include <bpf/bpf_endian.h>
#include <bpf/bpf_helpers.h>

/* Upper bound (bytes) on a single head/tail grow or shrink operation. */
#define MAX_ADJST_OFFSET 256
#define MAX_PAYLOAD_LEN 5000
/* Largest header prefix preserved across a head adjustment. */
#define MAX_HDR_LEN 64

/* Keys into map_xdp_setup. */
enum {
	XDP_MODE = 0,		/* one of xdp_map_modes below */
	XDP_PORT = 1,		/* UDP destination port to operate on */
	XDP_ADJST_OFFSET = 2,	/* signed adjustment amount, in bytes */
	XDP_ADJST_TAG = 3,	/* byte value used to fill grown regions */
} xdp_map_setup_keys;

/* Values accepted for the XDP_MODE setup key. */
enum {
	XDP_MODE_PASS = 0,
	XDP_MODE_DROP = 1,
	XDP_MODE_TX = 2,
	XDP_MODE_TAIL_ADJST = 3,
	XDP_MODE_HEAD_ADJST = 4,
} xdp_map_modes;

/* Keys into map_xdp_stats. */
enum {
	STATS_RX = 0,		/* matched packets seen */
	STATS_PASS = 1,
	STATS_DROP = 2,
	STATS_TX = 3,
	STATS_ABORT = 4,	/* packets aborted after a failed adjust */
} xdp_stats;

/* Test configuration, indexed by xdp_map_setup_keys. */
struct {
	__uint(type, BPF_MAP_TYPE_ARRAY);
	__uint(max_entries, 5);
	__type(key, __u32);
	__type(value, __s32);
} map_xdp_setup SEC(".maps");

/* Packet counters, indexed by xdp_stats. */
struct {
	__uint(type, BPF_MAP_TYPE_ARRAY);
	__uint(max_entries, 5);
	__type(key, __u32);
	__type(value, __u64);
} map_xdp_stats SEC(".maps");

/* NOTE(review): currently unreferenced within this file. */
static __u32 min(__u32 a, __u32 b)
{
	return a < b ? a : b;
}

/* Atomically bump the map_xdp_stats counter for @stat_type. */
static void record_stats(struct xdp_md *ctx, __u32 stat_type)
{
	__u64 *count;

	count = bpf_map_lookup_elem(&map_xdp_stats, &stat_type);

	if (count)
		__sync_fetch_and_add(count, 1);
}

/*
 * Parse an Ethernet + IPv4/IPv6 + UDP packet and return a verifier-checked
 * pointer to the UDP header if its destination port equals @port; NULL on
 * any parse/match failure.  Bumps STATS_RX for every match.
 * NOTE(review): the IPv4 path uses sizeof(*iph) rather than iph->ihl, i.e.
 * it assumes a 20-byte IP header with no options — confirm the test traffic
 * guarantees this.
 */
static struct udphdr *filter_udphdr(struct xdp_md *ctx, __u16 port)
{
	void *data_end = (void *)(long)ctx->data_end;
	void *data = (void *)(long)ctx->data;
	struct udphdr *udph = NULL;
	struct ethhdr *eth = data;

	if (data + sizeof(*eth) > data_end)
		return NULL;

	if (eth->h_proto == bpf_htons(ETH_P_IP)) {
		struct iphdr *iph = data + sizeof(*eth);

		if (iph + 1 > (struct iphdr *)data_end ||
		    iph->protocol != IPPROTO_UDP)
			return NULL;

		udph = (void *)eth + sizeof(*iph) + sizeof(*eth);
	} else if (eth->h_proto == bpf_htons(ETH_P_IPV6)) {
		struct ipv6hdr *ipv6h = data + sizeof(*eth);

		if (ipv6h + 1 > (struct ipv6hdr *)data_end ||
		    ipv6h->nexthdr != IPPROTO_UDP)
			return NULL;

		udph = (void *)eth + sizeof(*ipv6h) + sizeof(*eth);
	} else {
		return NULL;
	}

	if (udph + 1 > (struct udphdr *)data_end)
		return NULL;

	if (udph->dest != bpf_htons(port))
		return NULL;

	record_stats(ctx, STATS_RX);

	return udph;
}

/* XDP_MODE_PASS: count matching packets; everything passes either way. */
static int xdp_mode_pass(struct xdp_md *ctx, __u16 port)
{
	struct udphdr *udph = NULL;

	udph = filter_udphdr(ctx, port);
	if (!udph)
		return XDP_PASS;

	record_stats(ctx, STATS_PASS);

	return XDP_PASS;
}

/* XDP_MODE_DROP: drop matching packets; non-matching traffic passes. */
static int xdp_mode_drop_handler(struct xdp_md *ctx, __u16 port)
{
	struct udphdr *udph = NULL;

	udph = filter_udphdr(ctx, port);
	if (!udph)
		return XDP_PASS;

	record_stats(ctx, STATS_DROP);

	return XDP_DROP;
}

/* Swap the source and destination MAC addresses in place. */
static void swap_machdr(void *data)
{
	struct ethhdr *eth = data;
	__u8 tmp_mac[ETH_ALEN];

	__builtin_memcpy(tmp_mac, eth->h_source, ETH_ALEN);
	__builtin_memcpy(eth->h_source, eth->h_dest, ETH_ALEN);
	__builtin_memcpy(eth->h_dest, tmp_mac, ETH_ALEN);
}

/*
 * XDP_MODE_TX: bounce matching UDP packets back out the ingress interface
 * by swapping the MAC addresses and the L3 source/destination addresses.
 * The UDP checksum is not recomputed: swapping saddr/daddr leaves the
 * ones'-complement pseudo-header sum unchanged.
 * Note this open-codes the parsing instead of calling filter_udphdr() so
 * the IP header pointer stays verifier-valid for the address rewrite.
 */
static int xdp_mode_tx_handler(struct xdp_md *ctx, __u16 port)
{
	void *data_end = (void *)(long)ctx->data_end;
	void *data = (void *)(long)ctx->data;
	struct udphdr *udph = NULL;
	struct ethhdr *eth = data;

	if (data + sizeof(*eth) > data_end)
		return XDP_PASS;

	if (eth->h_proto == bpf_htons(ETH_P_IP)) {
		struct iphdr *iph = data + sizeof(*eth);
		__be32 tmp_ip = iph->saddr;

		if (iph + 1 > (struct iphdr *)data_end ||
		    iph->protocol != IPPROTO_UDP)
			return XDP_PASS;

		udph = data + sizeof(*iph) + sizeof(*eth);

		if (udph + 1 > (struct udphdr *)data_end)
			return XDP_PASS;
		if (udph->dest != bpf_htons(port))
			return XDP_PASS;

		record_stats(ctx, STATS_RX);
		swap_machdr((void *)eth);

		iph->saddr = iph->daddr;
		iph->daddr = tmp_ip;

		record_stats(ctx, STATS_TX);

		return XDP_TX;

	} else if (eth->h_proto == bpf_htons(ETH_P_IPV6)) {
		struct ipv6hdr *ipv6h = data + sizeof(*eth);
		struct in6_addr tmp_ipv6;

		if (ipv6h + 1 > (struct ipv6hdr *)data_end ||
		    ipv6h->nexthdr != IPPROTO_UDP)
			return XDP_PASS;

		udph = data + sizeof(*ipv6h) + sizeof(*eth);

		if (udph + 1 > (struct udphdr *)data_end)
			return XDP_PASS;
		if (udph->dest != bpf_htons(port))
			return XDP_PASS;

		record_stats(ctx, STATS_RX);
		swap_machdr((void *)eth);

		__builtin_memcpy(&tmp_ipv6, &ipv6h->saddr, sizeof(tmp_ipv6));
		__builtin_memcpy(&ipv6h->saddr, &ipv6h->daddr,
				 sizeof(tmp_ipv6));
		__builtin_memcpy(&ipv6h->daddr, &tmp_ipv6, sizeof(tmp_ipv6));

		record_stats(ctx, STATS_TX);

		return XDP_TX;
	}

	return XDP_PASS;
}

/*
 * Add @offset (may be negative) to the L3 and UDP length fields of the
 * packet and return the UDP header pointer, or NULL on any parse failure.
 * For IPv6, *udp_csum is seeded with the complement of the current UDP
 * checksum and updated via bpf_csum_diff() for both length rewrites, so
 * the caller can continue folding payload changes into it.
 * NOTE(review): the IPv4 branch leaves *udp_csum untouched and does not
 * refresh iph->check after rewriting tot_len — confirm whether the test
 * relies on checksum offload/recomputation for IPv4, or whether this is a
 * gap.
 */
static void *update_pkt(struct xdp_md *ctx, __s16 offset, __u32 *udp_csum)
{
	void *data_end = (void *)(long)ctx->data_end;
	void *data = (void *)(long)ctx->data;
	struct udphdr *udph = NULL;
	struct ethhdr *eth = data;
	__u32 len, len_new;

	if (data + sizeof(*eth) > data_end)
		return NULL;

	if (eth->h_proto == bpf_htons(ETH_P_IP)) {
		struct iphdr *iph = data + sizeof(*eth);
		__u16 total_len;

		if (iph + 1 > (struct iphdr *)data_end)
			return NULL;

		iph->tot_len = bpf_htons(bpf_ntohs(iph->tot_len) + offset);

		udph = (void *)eth + sizeof(*iph) + sizeof(*eth);
		if (!udph || udph + 1 > (struct udphdr *)data_end)
			return NULL;

		len_new = bpf_htons(bpf_ntohs(udph->len) + offset);
	} else if (eth->h_proto == bpf_htons(ETH_P_IPV6)) {
		struct ipv6hdr *ipv6h = data + sizeof(*eth);
		__u16 payload_len;

		if (ipv6h + 1 > (struct ipv6hdr *)data_end)
			return NULL;

		udph = (void *)eth + sizeof(*ipv6h) + sizeof(*eth);
		if (!udph || udph + 1 > (struct udphdr *)data_end)
			return NULL;

		/* Seed with ~check so later csum_diff folds stay coherent. */
		*udp_csum = ~((__u32)udph->check);

		len = ipv6h->payload_len;
		len_new = bpf_htons(bpf_ntohs(len) + offset);
		ipv6h->payload_len = len_new;

		*udp_csum = bpf_csum_diff(&len, sizeof(len), &len_new,
					  sizeof(len_new), *udp_csum);

		len = udph->len;
		len_new = bpf_htons(bpf_ntohs(udph->len) + offset);
		*udp_csum = bpf_csum_diff(&len, sizeof(len), &len_new,
					  sizeof(len_new), *udp_csum);
	} else {
		return NULL;
	}

	udph->len = len_new;

	return udph;
}

/*
 * Fold a 32-bit checksum accumulator into 16 bits and complement it,
 * substituting 0xffff when the complement is zero (GNU "?:" form; an
 * all-zero transmitted UDP checksum means "no checksum").
 * NOTE(review): only one fold round is applied — a carry out of the low
 * 16 bits produced by the add itself is not folded back in; and the "?:"
 * tests the full 32-bit complement before the return-type truncation.
 * Confirm the accumulators fed in here keep both cases benign.
 */
static __u16 csum_fold_helper(__u32 csum)
{
	return ~((csum & 0xffff) + (csum >> 16)) ? : 0xffff;
}

/*
 * Shrink the packet tail by @offset bytes.  Length fields and the UDP
 * checksum are fixed up *before* bpf_xdp_adjust_tail(), because the
 * trimmed bytes become unreachable afterwards and their contribution must
 * first be subtracted from the checksum.  @hdr_len guards against eating
 * into the headers.  Returns 0 on success, -1 on failure (caller aborts).
 * The "& 0x1ff / & 0xff" clamp re-bounds @offset after the helper calls —
 * presumably so the verifier can prove the tmp_buff access stays in
 * range; confirm before simplifying.
 */
static int xdp_adjst_tail_shrnk_data(struct xdp_md *ctx, __u16 offset,
				     __u32 hdr_len)
{
	char tmp_buff[MAX_ADJST_OFFSET];
	__u32 buff_pos, udp_csum = 0;
	struct udphdr *udph = NULL;
	__u32 buff_len;

	udph = update_pkt(ctx, 0 - offset, &udp_csum);
	if (!udph)
		return -1;

	buff_len = bpf_xdp_get_buff_len(ctx);

	offset = (offset & 0x1ff) >= MAX_ADJST_OFFSET ? MAX_ADJST_OFFSET :
		 offset & 0xff;
	if (offset == 0)
		return -1;

	/* Make sure we have enough data to avoid eating the header */
	if (buff_len - offset < hdr_len)
		return -1;

	buff_pos = buff_len - offset;
	if (bpf_xdp_load_bytes(ctx, buff_pos, tmp_buff, offset) < 0)
		return -1;

	/* Subtract the soon-to-be-removed bytes from the UDP checksum. */
	udp_csum = bpf_csum_diff((__be32 *)tmp_buff, offset, 0, 0, udp_csum);
	udph->check = (__u16)csum_fold_helper(udp_csum);

	if (bpf_xdp_adjust_tail(ctx, 0 - offset) < 0)
		return -1;

	return 0;
}

/*
 * Grow the packet tail by @offset bytes filled with the configured
 * XDP_ADJST_TAG byte, updating lengths and the UDP checksum to match.
 * buff_len is captured before bpf_xdp_adjust_tail() so the store lands at
 * the old end of the packet.  Returns 0 on success, -1 on failure.
 */
static int xdp_adjst_tail_grow_data(struct xdp_md *ctx, __u16 offset)
{
	char tmp_buff[MAX_ADJST_OFFSET];
	__u32 buff_pos, udp_csum = 0;
	__u32 buff_len, hdr_len, key;
	struct udphdr *udph;
	__s32 *val;
	__u8 tag;

	/* Proceed to update the packet headers before attempting to adjuste
	 * the tail. Once the tail is adjusted we lose access to the offset
	 * amount of data at the end of the packet which is crucial to update
	 * the checksum.
	 * Since any failure beyond this would abort the packet, we should
	 * not worry about passing a packet up the stack with wrong headers
	 */
	udph = update_pkt(ctx, offset, &udp_csum);
	if (!udph)
		return -1;

	key = XDP_ADJST_TAG;
	val = bpf_map_lookup_elem(&map_xdp_setup, &key);
	if (!val)
		return -1;

	tag = (__u8)(*val);

	/* Pre-fill the scratch buffer with the tag byte. */
	for (int i = 0; i < MAX_ADJST_OFFSET; i++)
		__builtin_memcpy(&tmp_buff[i], &tag, 1);

	offset = (offset & 0x1ff) >= MAX_ADJST_OFFSET ? MAX_ADJST_OFFSET :
		 offset & 0xff;
	if (offset == 0)
		return -1;

	/* Add the new fill bytes' contribution to the UDP checksum. */
	udp_csum = bpf_csum_diff(0, 0, (__be32 *)tmp_buff, offset, udp_csum);
	udph->check = (__u16)csum_fold_helper(udp_csum);

	buff_len = bpf_xdp_get_buff_len(ctx);

	if (bpf_xdp_adjust_tail(ctx, offset) < 0) {
		bpf_printk("Failed to adjust tail\n");
		return -1;
	}

	if (bpf_xdp_store_bytes(ctx, buff_len, tmp_buff, offset) < 0)
		return -1;

	return 0;
}

/*
 * XDP_MODE_TAIL_ADJST: dispatch on the sign of the configured
 * XDP_ADJST_OFFSET — negative shrinks the tail, non-negative grows it.
 * A failed adjustment aborts the packet (STATS_ABORT / XDP_ABORTED).
 */
static int xdp_adjst_tail(struct xdp_md *ctx, __u16 port)
{
	void *data = (void *)(long)ctx->data;
	struct udphdr *udph = NULL;
	__s32 *adjust_offset, *val;
	__u32 key, hdr_len;
	void *offset_ptr;
	__u8 tag;
	int ret;

	udph = filter_udphdr(ctx, port);
	if (!udph)
		return XDP_PASS;

	/* Byte count from start of frame through the UDP header. */
	hdr_len = (void *)udph - data + sizeof(struct udphdr);
	key = XDP_ADJST_OFFSET;
	adjust_offset = bpf_map_lookup_elem(&map_xdp_setup, &key);
	if (!adjust_offset)
		return XDP_PASS;

	if (*adjust_offset < 0)
		ret = xdp_adjst_tail_shrnk_data(ctx,
						(__u16)(0 - *adjust_offset),
						hdr_len);
	else
		ret = xdp_adjst_tail_grow_data(ctx, (__u16)(*adjust_offset));
	if (ret)
		goto abort_pkt;

	record_stats(ctx, STATS_PASS);
	return XDP_PASS;

abort_pkt:
	record_stats(ctx, STATS_ABORT);
	return XDP_ABORTED;
}

/*
 * Remove @offset payload bytes from just after the headers: fold the
 * removed bytes out of the UDP checksum, save the leading header bytes,
 * bpf_xdp_adjust_head(+offset) to cut the front, then rewrite the saved
 * headers at the new packet start.  Returns 0 on success, -1 on failure.
 * The redundant-looking re-checks of offset/hdr_len after the adjust are
 * presumably verifier bounds hints; confirm before removing.
 */
static int xdp_adjst_head_shrnk_data(struct xdp_md *ctx, __u64 hdr_len,
				     __u32 offset)
{
	char tmp_buff[MAX_ADJST_OFFSET];
	struct udphdr *udph;
	void *offset_ptr;
	__u32 udp_csum = 0;

	/* Update the length information in the IP and UDP headers before
	 * adjusting the headroom. This simplifies accessing the relevant
	 * fields in the IP and UDP headers for fragmented packets. Any
	 * failure beyond this point will result in the packet being aborted,
	 * so we don't need to worry about incorrect length information for
	 * passed packets.
	 */
	udph = update_pkt(ctx, (__s16)(0 - offset), &udp_csum);
	if (!udph)
		return -1;

	offset = (offset & 0x1ff) >= MAX_ADJST_OFFSET ? MAX_ADJST_OFFSET :
		 offset & 0xff;
	if (offset == 0)
		return -1;

	/* Fold the bytes being removed out of the UDP checksum. */
	if (bpf_xdp_load_bytes(ctx, hdr_len, tmp_buff, offset) < 0)
		return -1;

	udp_csum = bpf_csum_diff((__be32 *)tmp_buff, offset, 0, 0, udp_csum);

	udph->check = (__u16)csum_fold_helper(udp_csum);

	/* Snapshot the front of the packet before the head moves. */
	if (bpf_xdp_load_bytes(ctx, 0, tmp_buff, MAX_ADJST_OFFSET) < 0)
		return -1;

	if (bpf_xdp_adjust_head(ctx, offset) < 0)
		return -1;

	if (offset > MAX_ADJST_OFFSET)
		return -1;

	if (hdr_len > MAX_ADJST_OFFSET || hdr_len == 0)
		return -1;

	/* Added here to handle clang complain about negative value */
	hdr_len = hdr_len & 0xff;

	if (hdr_len == 0)
		return -1;

	if (bpf_xdp_store_bytes(ctx, 0, tmp_buff, hdr_len) < 0)
		return -1;

	return 0;
}

/*
 * Insert @offset tag-filled bytes between the headers and the payload:
 * update lengths/checksum, save the first @hdr_len header bytes, extend
 * the front with bpf_xdp_adjust_head(-offset), then write back the
 * headers followed by the fill data.  Returns 0 on success, -1 on failure.
 */
static int xdp_adjst_head_grow_data(struct xdp_md *ctx, __u64 hdr_len,
				    __u32 offset)
{
	char hdr_buff[MAX_HDR_LEN];
	char data_buff[MAX_ADJST_OFFSET];
	void *offset_ptr;
	__s32 *val;
	__u32 key;
	__u8 tag;
	__u32 udp_csum = 0;
	struct udphdr *udph;

	udph = update_pkt(ctx, (__s16)(offset), &udp_csum);
	if (!udph)
		return -1;

	key = XDP_ADJST_TAG;
	val = bpf_map_lookup_elem(&map_xdp_setup, &key);
	if (!val)
		return -1;

	tag = (__u8)(*val);
	for (int i = 0; i < MAX_ADJST_OFFSET; i++)
		__builtin_memcpy(&data_buff[i], &tag, 1);

	offset = (offset & 0x1ff) >= MAX_ADJST_OFFSET ? MAX_ADJST_OFFSET :
		 offset & 0xff;
	if (offset == 0)
		return -1;

	/* Add the inserted fill bytes' contribution to the UDP checksum. */
	udp_csum = bpf_csum_diff(0, 0, (__be32 *)data_buff, offset, udp_csum);
	udph->check = (__u16)csum_fold_helper(udp_csum);

	if (hdr_len > MAX_ADJST_OFFSET || hdr_len == 0)
		return -1;

	/* Added here to handle clang complain about negative value */
	hdr_len = hdr_len & 0xff;

	if (hdr_len == 0)
		return -1;

	if (bpf_xdp_load_bytes(ctx, 0, hdr_buff, hdr_len) < 0)
		return -1;

	if (offset > MAX_ADJST_OFFSET)
		return -1;

	if (bpf_xdp_adjust_head(ctx, 0 - offset) < 0)
		return -1;

	if (bpf_xdp_store_bytes(ctx, 0, hdr_buff, hdr_len) < 0)
		return -1;

	if (bpf_xdp_store_bytes(ctx, hdr_len, data_buff, offset) < 0)
		return -1;

	return 0;
}

/*
 * XDP_MODE_HEAD_ADJST: only |offset| in {16, 32, 64, 128, 256} is
 * accepted.  A negative configured value grows the packet at the head,
 * a positive one shrinks it; anything else (or a failed adjustment)
 * aborts the packet.
 */
static int xdp_head_adjst(struct xdp_md *ctx, __u16 port)
{
	void *data_end = (void *)(long)ctx->data_end;
	void *data = (void *)(long)ctx->data;
	struct udphdr *udph_ptr = NULL;
	__u32 key, size, hdr_len;
	__s32 *val;
	int res;

	/* Filter packets based on UDP port */
	udph_ptr = filter_udphdr(ctx, port);
	if (!udph_ptr)
		return XDP_PASS;

	/* Byte count from start of frame through the UDP header. */
	hdr_len = (void *)udph_ptr - data + sizeof(struct udphdr);

	key = XDP_ADJST_OFFSET;
	val = bpf_map_lookup_elem(&map_xdp_setup, &key);
	if (!val)
		return XDP_PASS;

	switch (*val) {
	case -16:
	case 16:
		size = 16;
		break;
	case -32:
	case 32:
		size = 32;
		break;
	case -64:
	case 64:
		size = 64;
		break;
	case -128:
	case 128:
		size = 128;
		break;
	case -256:
	case 256:
		size = 256;
		break;
	default:
		bpf_printk("Invalid adjustment offset: %d\n", *val);
		goto abort;
	}

	if (*val < 0)
		res = xdp_adjst_head_grow_data(ctx, hdr_len, size);
	else
		res = xdp_adjst_head_shrnk_data(ctx, hdr_len, size);

	if (res)
		goto abort;

	record_stats(ctx, STATS_PASS);
	return XDP_PASS;

abort:
	record_stats(ctx, STATS_ABORT);
	return XDP_ABORTED;
}

/*
 * Shared entry logic: read the configured mode and port from
 * map_xdp_setup and dispatch to the matching handler.  Missing
 * configuration or an unknown mode passes traffic untouched.
 */
static int xdp_prog_common(struct xdp_md *ctx)
{
	__u32 key, *port;
	__s32 *mode;

	key = XDP_MODE;
	mode = bpf_map_lookup_elem(&map_xdp_setup, &key);
	if (!mode)
		return XDP_PASS;

	key = XDP_PORT;
	port = bpf_map_lookup_elem(&map_xdp_setup, &key);
	if (!port)
		return XDP_PASS;

	switch (*mode) {
	case XDP_MODE_PASS:
		return xdp_mode_pass(ctx, (__u16)(*port));
	case XDP_MODE_DROP:
		return xdp_mode_drop_handler(ctx, (__u16)(*port));
	case XDP_MODE_TX:
		return xdp_mode_tx_handler(ctx, (__u16)(*port));
	case XDP_MODE_TAIL_ADJST:
		return xdp_adjst_tail(ctx, (__u16)(*port));
	case XDP_MODE_HEAD_ADJST:
		return xdp_head_adjst(ctx, (__u16)(*port));
	}

	/* Default action is to simple pass */
	return XDP_PASS;
}

/* Single-buffer XDP entry point. */
SEC("xdp")
int xdp_prog(struct xdp_md *ctx)
{
	return xdp_prog_common(ctx);
}

/* Multi-buffer (fragment-aware) XDP entry point; same logic. */
SEC("xdp.frags")
int xdp_prog_frags(struct xdp_md *ctx)
{
	return xdp_prog_common(ctx);
}

char _license[] SEC("license") = "GPL";