/*
 * Checksum updating actions
 *
 * Copyright (c) 2010 Gregoire Baron <baronchon@n7mm.org>
 *
 * This program is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License as published by the Free
 * Software Foundation; either version 2 of the License, or (at your option)
 * any later version.
 *
 */

#include <linux/types.h>
#include <linux/init.h>
#include <linux/kernel.h>
#include <linux/module.h>
#include <linux/spinlock.h>

#include <linux/netlink.h>
#include <net/netlink.h>
#include <linux/rtnetlink.h>

#include <linux/skbuff.h>

#include <net/ip.h>
#include <net/ipv6.h>
#include <net/icmp.h>
#include <linux/icmpv6.h>
#include <linux/igmp.h>
#include <net/tcp.h>
#include <net/udp.h>
#include <net/ip6_checksum.h>
#include <net/sctp/checksum.h>

#include <net/act_api.h>

#include <linux/tc_act/tc_csum.h>
#include <net/tc_act/tc_csum.h>

static const struct nla_policy csum_policy[TCA_CSUM_MAX + 1] = {
	[TCA_CSUM_PARMS] = { .len = sizeof(struct tc_csum), },
};

static unsigned int csum_net_id;
static struct tc_action_ops act_csum_ops;

static int tcf_csum_init(struct net *net, struct nlattr *nla,
			 struct nlattr *est, struct tc_action **a, int ovr,
			 int bind, struct netlink_ext_ack *extack)
{
	struct tc_action_net *tn = net_generic(net, csum_net_id);
	struct tcf_csum_params *params_old, *params_new;
	struct nlattr *tb[TCA_CSUM_MAX + 1];
	struct tc_csum *parm;
	struct tcf_csum *p;
	int ret = 0, err;

	if (nla == NULL)
		return -EINVAL;

	err = nla_parse_nested(tb, TCA_CSUM_MAX, nla, csum_policy, NULL);
	if (err < 0)
		return err;

	if (tb[TCA_CSUM_PARMS] == NULL)
		return -EINVAL;
	parm = nla_data(tb[TCA_CSUM_PARMS]);

	if (!tcf_idr_check(tn, parm->index, a, bind)) {
		ret = tcf_idr_create(tn, parm->index, est, a,
				     &act_csum_ops, bind, true);
		if (ret)
			return ret;
		ret = ACT_P_CREATED;
	} else {
		if (bind)	/* don't override defaults */
			return 0;
		tcf_idr_release(*a, bind);
		if (!ovr)
			return -EEXIST;
	}

	p = to_tcf_csum(*a);
	ASSERT_RTNL();

	params_new = kzalloc(sizeof(*params_new), GFP_KERNEL);
	if (unlikely(!params_new)) {
		if (ret == ACT_P_CREATED)
			tcf_idr_release(*a, bind);
		return -ENOMEM;
	}
	params_old = rtnl_dereference(p->params);

	p->tcf_action = parm->action;
	params_new->update_flags = parm->update_flags;
	rcu_assign_pointer(p->params, params_new);
	if (params_old)
		kfree_rcu(params_old, rcu);

	if (ret == ACT_P_CREATED)
		tcf_idr_insert(tn, *a);

	return ret;
}
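
/*
 * Note on parameter updates: the control path above allocates a fresh
 * tcf_csum_params, publishes it with rcu_assign_pointer() under RTNL and
 * frees the previous copy with kfree_rcu() once readers have drained.
 * The per-packet path (tcf_csum() below) therefore only needs an RCU read
 * lock to get a consistent snapshot, roughly:
 *
 *	rcu_read_lock();
 *	params = rcu_dereference(p->params);
 *	... read params->update_flags ...
 *	rcu_read_unlock();
 *
 * keeping the fast path lock-free.
 */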
/**
 * tcf_csum_skb_nextlayer - Get next layer pointer
 * @skb: sk_buff to use
 * @ihl: previous summed headers length
 * @ipl: complete packet length
 * @jhl: next header length
 *
 * Check the expected next layer availability in the specified sk_buff.
 * Return the next layer pointer on success, NULL otherwise.
 */
static void *tcf_csum_skb_nextlayer(struct sk_buff *skb,
				    unsigned int ihl, unsigned int ipl,
				    unsigned int jhl)
{
	int ntkoff = skb_network_offset(skb);
	int hl = ihl + jhl;

	if (!pskb_may_pull(skb, ipl + ntkoff) || (ipl < hl) ||
	    skb_try_make_writable(skb, hl + ntkoff))
		return NULL;
	else
		return (void *)(skb_network_header(skb) + ihl);
}

static int tcf_csum_ipv4_icmp(struct sk_buff *skb, unsigned int ihl,
			      unsigned int ipl)
{
	struct icmphdr *icmph;

	icmph = tcf_csum_skb_nextlayer(skb, ihl, ipl, sizeof(*icmph));
	if (icmph == NULL)
		return 0;

	icmph->checksum = 0;
	skb->csum = csum_partial(icmph, ipl - ihl, 0);
	icmph->checksum = csum_fold(skb->csum);

	skb->ip_summed = CHECKSUM_NONE;

	return 1;
}

static int tcf_csum_ipv4_igmp(struct sk_buff *skb,
			      unsigned int ihl, unsigned int ipl)
{
	struct igmphdr *igmph;

	igmph = tcf_csum_skb_nextlayer(skb, ihl, ipl, sizeof(*igmph));
	if (igmph == NULL)
		return 0;

	igmph->csum = 0;
	skb->csum = csum_partial(igmph, ipl - ihl, 0);
	igmph->csum = csum_fold(skb->csum);

	skb->ip_summed = CHECKSUM_NONE;

	return 1;
}

static int tcf_csum_ipv6_icmp(struct sk_buff *skb, unsigned int ihl,
			      unsigned int ipl)
{
	struct icmp6hdr *icmp6h;
	const struct ipv6hdr *ip6h;

	icmp6h = tcf_csum_skb_nextlayer(skb, ihl, ipl, sizeof(*icmp6h));
	if (icmp6h == NULL)
		return 0;

	ip6h = ipv6_hdr(skb);
	icmp6h->icmp6_cksum = 0;
	skb->csum = csum_partial(icmp6h, ipl - ihl, 0);
	icmp6h->icmp6_cksum = csum_ipv6_magic(&ip6h->saddr, &ip6h->daddr,
					      ipl - ihl, IPPROTO_ICMPV6,
					      skb->csum);

	skb->ip_summed = CHECKSUM_NONE;

	return 1;
}

static int tcf_csum_ipv4_tcp(struct sk_buff *skb, unsigned int ihl,
			     unsigned int ipl)
{
	struct tcphdr *tcph;
	const struct iphdr *iph;

	if (skb_is_gso(skb) && skb_shinfo(skb)->gso_type & SKB_GSO_TCPV4)
		return 1;

	tcph = tcf_csum_skb_nextlayer(skb, ihl, ipl, sizeof(*tcph));
	if (tcph == NULL)
		return 0;

	iph = ip_hdr(skb);
	tcph->check = 0;
	skb->csum = csum_partial(tcph, ipl - ihl, 0);
	tcph->check = tcp_v4_check(ipl - ihl,
				   iph->saddr, iph->daddr, skb->csum);

	skb->ip_summed = CHECKSUM_NONE;

	return 1;
}

static int tcf_csum_ipv6_tcp(struct sk_buff *skb, unsigned int ihl,
			     unsigned int ipl)
{
	struct tcphdr *tcph;
	const struct ipv6hdr *ip6h;

	if (skb_is_gso(skb) && skb_shinfo(skb)->gso_type & SKB_GSO_TCPV6)
		return 1;

	tcph = tcf_csum_skb_nextlayer(skb, ihl, ipl, sizeof(*tcph));
	if (tcph == NULL)
		return 0;

	ip6h = ipv6_hdr(skb);
	tcph->check = 0;
	skb->csum = csum_partial(tcph, ipl - ihl, 0);
	tcph->check = csum_ipv6_magic(&ip6h->saddr, &ip6h->daddr,
				      ipl - ihl, IPPROTO_TCP,
				      skb->csum);

	skb->ip_summed = CHECKSUM_NONE;

	return 1;
}
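
/*
 * UDP notes for the two helpers below: in UDP over IPv4 a checksum of zero
 * means "not computed", so it is left alone unless the packet is UDP-Lite;
 * a recomputed checksum that happens to fold to zero is transmitted as
 * CSUM_MANGLED_0 (all ones) instead. For UDP-Lite, udph->len holds the
 * checksum coverage rather than the datagram length, hence the separate
 * validation of ul against the payload size.
 */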
static int tcf_csum_ipv4_udp(struct sk_buff *skb, unsigned int ihl,
			     unsigned int ipl, int udplite)
{
	struct udphdr *udph;
	const struct iphdr *iph;
	u16 ul;

	if (skb_is_gso(skb) && skb_shinfo(skb)->gso_type & SKB_GSO_UDP)
		return 1;

	/*
	 * Support both UDP and UDP-Lite checksum algorithms. Don't use
	 * udph->len to get the real length without a protocol check:
	 * UDP-Lite uses udph->len for the checksum coverage instead.
	 * Use iph->tot_len, or just ipl.
	 */

	udph = tcf_csum_skb_nextlayer(skb, ihl, ipl, sizeof(*udph));
	if (udph == NULL)
		return 0;

	iph = ip_hdr(skb);
	ul = ntohs(udph->len);

	if (udplite || udph->check) {
		udph->check = 0;

		if (udplite) {
			if (ul == 0)
				skb->csum = csum_partial(udph, ipl - ihl, 0);
			else if ((ul >= sizeof(*udph)) && (ul <= ipl - ihl))
				skb->csum = csum_partial(udph, ul, 0);
			else
				goto ignore_obscure_skb;
		} else {
			if (ul != ipl - ihl)
				goto ignore_obscure_skb;

			skb->csum = csum_partial(udph, ul, 0);
		}

		udph->check = csum_tcpudp_magic(iph->saddr, iph->daddr,
						ul, iph->protocol,
						skb->csum);

		if (!udph->check)
			udph->check = CSUM_MANGLED_0;
	}

	skb->ip_summed = CHECKSUM_NONE;

ignore_obscure_skb:
	return 1;
}

static int tcf_csum_ipv6_udp(struct sk_buff *skb, unsigned int ihl,
			     unsigned int ipl, int udplite)
{
	struct udphdr *udph;
	const struct ipv6hdr *ip6h;
	u16 ul;

	if (skb_is_gso(skb) && skb_shinfo(skb)->gso_type & SKB_GSO_UDP)
		return 1;

	/*
	 * Support both UDP and UDP-Lite checksum algorithms. Don't use
	 * udph->len to get the real length without a protocol check:
	 * UDP-Lite uses udph->len for the checksum coverage instead.
	 * Use ip6h->payload_len + sizeof(*ip6h) ..., or just ipl.
	 */

	udph = tcf_csum_skb_nextlayer(skb, ihl, ipl, sizeof(*udph));
	if (udph == NULL)
		return 0;

	ip6h = ipv6_hdr(skb);
	ul = ntohs(udph->len);

	udph->check = 0;

	if (udplite) {
		if (ul == 0)
			skb->csum = csum_partial(udph, ipl - ihl, 0);
		else if ((ul >= sizeof(*udph)) && (ul <= ipl - ihl))
			skb->csum = csum_partial(udph, ul, 0);
		else
			goto ignore_obscure_skb;
	} else {
		if (ul != ipl - ihl)
			goto ignore_obscure_skb;

		skb->csum = csum_partial(udph, ul, 0);
	}

	udph->check = csum_ipv6_magic(&ip6h->saddr, &ip6h->daddr, ul,
				      udplite ? IPPROTO_UDPLITE : IPPROTO_UDP,
				      skb->csum);

	if (!udph->check)
		udph->check = CSUM_MANGLED_0;

	skb->ip_summed = CHECKSUM_NONE;

ignore_obscure_skb:
	return 1;
}

static int tcf_csum_sctp(struct sk_buff *skb, unsigned int ihl,
			 unsigned int ipl)
{
	struct sctphdr *sctph;

	if (skb_is_gso(skb) && skb_is_gso_sctp(skb))
		return 1;

	sctph = tcf_csum_skb_nextlayer(skb, ihl, ipl, sizeof(*sctph));
	if (!sctph)
		return 0;

	sctph->checksum = sctp_compute_cksum(skb,
					     skb_network_offset(skb) + ihl);
	skb->ip_summed = CHECKSUM_NONE;
	skb->csum_not_inet = 0;

	return 1;
}
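
/*
 * IPv4 dispatch below: non-first fragments carry no transport header, so
 * the switch keys on 0 (skipping all L4 updates) whenever IP_OFFSET bits
 * are set in frag_off; the IPv4 header checksum itself can still be
 * refreshed at the end via ip_send_check(). Header length comes from
 * iph->ihl * 4 and total length from iph->tot_len.
 */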
static int tcf_csum_ipv4(struct sk_buff *skb, u32 update_flags)
{
	const struct iphdr *iph;
	int ntkoff;

	ntkoff = skb_network_offset(skb);

	if (!pskb_may_pull(skb, sizeof(*iph) + ntkoff))
		goto fail;

	iph = ip_hdr(skb);

	switch (iph->frag_off & htons(IP_OFFSET) ? 0 : iph->protocol) {
	case IPPROTO_ICMP:
		if (update_flags & TCA_CSUM_UPDATE_FLAG_ICMP)
			if (!tcf_csum_ipv4_icmp(skb, iph->ihl * 4,
						ntohs(iph->tot_len)))
				goto fail;
		break;
	case IPPROTO_IGMP:
		if (update_flags & TCA_CSUM_UPDATE_FLAG_IGMP)
			if (!tcf_csum_ipv4_igmp(skb, iph->ihl * 4,
						ntohs(iph->tot_len)))
				goto fail;
		break;
	case IPPROTO_TCP:
		if (update_flags & TCA_CSUM_UPDATE_FLAG_TCP)
			if (!tcf_csum_ipv4_tcp(skb, iph->ihl * 4,
					       ntohs(iph->tot_len)))
				goto fail;
		break;
	case IPPROTO_UDP:
		if (update_flags & TCA_CSUM_UPDATE_FLAG_UDP)
			if (!tcf_csum_ipv4_udp(skb, iph->ihl * 4,
					       ntohs(iph->tot_len), 0))
				goto fail;
		break;
	case IPPROTO_UDPLITE:
		if (update_flags & TCA_CSUM_UPDATE_FLAG_UDPLITE)
			if (!tcf_csum_ipv4_udp(skb, iph->ihl * 4,
					       ntohs(iph->tot_len), 1))
				goto fail;
		break;
	case IPPROTO_SCTP:
		if ((update_flags & TCA_CSUM_UPDATE_FLAG_SCTP) &&
		    !tcf_csum_sctp(skb, iph->ihl * 4, ntohs(iph->tot_len)))
			goto fail;
		break;
	}

	if (update_flags & TCA_CSUM_UPDATE_FLAG_IPV4HDR) {
		if (skb_try_make_writable(skb, sizeof(*iph) + ntkoff))
			goto fail;

		ip_send_check(ip_hdr(skb));
	}

	return 1;

fail:
	return 0;
}

static int tcf_csum_ipv6_hopopts(struct ipv6_opt_hdr *ip6xh, unsigned int ixhl,
				 unsigned int *pl)
{
	int off, len, optlen;
	unsigned char *xh = (void *)ip6xh;

	off = sizeof(*ip6xh);
	len = ixhl - off;

	while (len > 1) {
		switch (xh[off]) {
		case IPV6_TLV_PAD1:
			optlen = 1;
			break;
		case IPV6_TLV_JUMBO:
			optlen = xh[off + 1] + 2;
			if (optlen != 6 || len < 6 || (off & 3) != 2)
				/* wrong jumbo option length/alignment */
				return 0;
			*pl = ntohl(*(__be32 *)(xh + off + 2));
			goto done;
		default:
			optlen = xh[off + 1] + 2;
			if (optlen > len)
				/* ignore obscure options */
				goto done;
			break;
		}
		off += optlen;
		len -= optlen;
	}

done:
	return 1;
}
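
/*
 * The IPv6 walk below follows the extension header chain (hop-by-hop,
 * routing, destination options) until it reaches a transport protocol it
 * can handle, bailing out on fragments and unknown next headers. The
 * hop-by-hop parser above exists to pick up a Jumbo Payload option, whose
 * 32-bit length overrides the 16-bit payload_len of the fixed header.
 */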
static int tcf_csum_ipv6(struct sk_buff *skb, u32 update_flags)
{
	struct ipv6hdr *ip6h;
	struct ipv6_opt_hdr *ip6xh;
	unsigned int hl, ixhl;
	unsigned int pl;
	int ntkoff;
	u8 nexthdr;

	ntkoff = skb_network_offset(skb);

	hl = sizeof(*ip6h);

	if (!pskb_may_pull(skb, hl + ntkoff))
		goto fail;

	ip6h = ipv6_hdr(skb);

	pl = ntohs(ip6h->payload_len);
	nexthdr = ip6h->nexthdr;

	do {
		switch (nexthdr) {
		case NEXTHDR_FRAGMENT:
			goto ignore_skb;
		case NEXTHDR_ROUTING:
		case NEXTHDR_HOP:
		case NEXTHDR_DEST:
			if (!pskb_may_pull(skb, hl + sizeof(*ip6xh) + ntkoff))
				goto fail;
			ip6xh = (void *)(skb_network_header(skb) + hl);
			ixhl = ipv6_optlen(ip6xh);
			if (!pskb_may_pull(skb, hl + ixhl + ntkoff))
				goto fail;
			ip6xh = (void *)(skb_network_header(skb) + hl);
			if ((nexthdr == NEXTHDR_HOP) &&
			    !(tcf_csum_ipv6_hopopts(ip6xh, ixhl, &pl)))
				goto fail;
			nexthdr = ip6xh->nexthdr;
			hl += ixhl;
			break;
		case IPPROTO_ICMPV6:
			if (update_flags & TCA_CSUM_UPDATE_FLAG_ICMP)
				if (!tcf_csum_ipv6_icmp(skb,
							hl, pl + sizeof(*ip6h)))
					goto fail;
			goto done;
		case IPPROTO_TCP:
			if (update_flags & TCA_CSUM_UPDATE_FLAG_TCP)
				if (!tcf_csum_ipv6_tcp(skb,
						       hl, pl + sizeof(*ip6h)))
					goto fail;
			goto done;
		case IPPROTO_UDP:
			if (update_flags & TCA_CSUM_UPDATE_FLAG_UDP)
				if (!tcf_csum_ipv6_udp(skb, hl,
						       pl + sizeof(*ip6h), 0))
					goto fail;
			goto done;
		case IPPROTO_UDPLITE:
			if (update_flags & TCA_CSUM_UPDATE_FLAG_UDPLITE)
				if (!tcf_csum_ipv6_udp(skb, hl,
						       pl + sizeof(*ip6h), 1))
					goto fail;
			goto done;
		case IPPROTO_SCTP:
			if ((update_flags & TCA_CSUM_UPDATE_FLAG_SCTP) &&
			    !tcf_csum_sctp(skb, hl, pl + sizeof(*ip6h)))
				goto fail;
			goto done;
		default:
			goto ignore_skb;
		}
	} while (pskb_may_pull(skb, hl + 1 + ntkoff));

done:
ignore_skb:
	return 1;

fail:
	return 0;
}

static int tcf_csum(struct sk_buff *skb, const struct tc_action *a,
		    struct tcf_result *res)
{
	struct tcf_csum *p = to_tcf_csum(a);
	struct tcf_csum_params *params;
	u32 update_flags;
	int action;

	rcu_read_lock();
	params = rcu_dereference(p->params);

	tcf_lastuse_update(&p->tcf_tm);
	bstats_cpu_update(this_cpu_ptr(p->common.cpu_bstats), skb);

	action = READ_ONCE(p->tcf_action);
	if (unlikely(action == TC_ACT_SHOT))
		goto drop_stats;

	update_flags = params->update_flags;
	switch (tc_skb_protocol(skb)) {
	case cpu_to_be16(ETH_P_IP):
		if (!tcf_csum_ipv4(skb, update_flags))
			goto drop;
		break;
	case cpu_to_be16(ETH_P_IPV6):
		if (!tcf_csum_ipv6(skb, update_flags))
			goto drop;
		break;
	}

unlock:
	rcu_read_unlock();
	return action;

drop:
	action = TC_ACT_SHOT;

drop_stats:
	qstats_drop_inc(this_cpu_ptr(p->common.cpu_qstats));
	goto unlock;
}

static int tcf_csum_dump(struct sk_buff *skb, struct tc_action *a, int bind,
			 int ref)
{
	unsigned char *b = skb_tail_pointer(skb);
	struct tcf_csum *p = to_tcf_csum(a);
	struct tcf_csum_params *params;
	struct tc_csum opt = {
		.index   = p->tcf_index,
		.refcnt  = p->tcf_refcnt - ref,
		.bindcnt = p->tcf_bindcnt - bind,
		.action  = p->tcf_action,
	};
	struct tcf_t t;

	params = rtnl_dereference(p->params);
	opt.update_flags = params->update_flags;

	if (nla_put(skb, TCA_CSUM_PARMS, sizeof(opt), &opt))
		goto nla_put_failure;

	tcf_tm_dump(&t, &p->tcf_tm);
	if (nla_put_64bit(skb, TCA_CSUM_TM, sizeof(t), &t, TCA_CSUM_PAD))
		goto nla_put_failure;

	return skb->len;

nla_put_failure:
	nlmsg_trim(skb, b);
	return -1;
}

static void tcf_csum_cleanup(struct tc_action *a)
{
	struct tcf_csum *p = to_tcf_csum(a);
	struct tcf_csum_params *params;

	params = rcu_dereference_protected(p->params, 1);
	if (params)
		kfree_rcu(params, rcu);
}

static int tcf_csum_walker(struct net *net, struct sk_buff *skb,
			   struct netlink_callback *cb, int type,
			   const struct tc_action_ops *ops,
			   struct netlink_ext_ack *extack)
{
	struct tc_action_net *tn = net_generic(net, csum_net_id);

	return tcf_generic_walker(tn, skb, cb, type, ops, extack);
}

static int tcf_csum_search(struct net *net, struct tc_action **a, u32 index,
			   struct netlink_ext_ack *extack)
{
	struct tc_action_net *tn = net_generic(net, csum_net_id);

	return tcf_idr_search(tn, a, index);
}

static size_t tcf_csum_get_fill_size(const struct tc_action *act)
{
	return nla_total_size(sizeof(struct tc_csum));
}
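
/*
 * Callback table wiring this action into the act_api core: .act runs on
 * the per-packet fast path, .init/.dump/.lookup/.walk serve the netlink
 * control plane, and .cleanup drops the RCU-managed parameter block when
 * the action instance is destroyed.
 */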
static struct tc_action_ops act_csum_ops = {
	.kind		= "csum",
	.type		= TCA_ACT_CSUM,
	.owner		= THIS_MODULE,
	.act		= tcf_csum,
	.dump		= tcf_csum_dump,
	.init		= tcf_csum_init,
	.cleanup	= tcf_csum_cleanup,
	.walk		= tcf_csum_walker,
	.lookup		= tcf_csum_search,
	.get_fill_size	= tcf_csum_get_fill_size,
	.size		= sizeof(struct tcf_csum),
};

static __net_init int csum_init_net(struct net *net)
{
	struct tc_action_net *tn = net_generic(net, csum_net_id);

	return tc_action_net_init(tn, &act_csum_ops);
}

static void __net_exit csum_exit_net(struct list_head *net_list)
{
	tc_action_net_exit(net_list, csum_net_id);
}

static struct pernet_operations csum_net_ops = {
	.init = csum_init_net,
	.exit_batch = csum_exit_net,
	.id   = &csum_net_id,
	.size = sizeof(struct tc_action_net),
};

MODULE_DESCRIPTION("Checksum updating actions");
MODULE_LICENSE("GPL");

static int __init csum_init_module(void)
{
	return tcf_register_action(&act_csum_ops, &csum_net_ops);
}

static void __exit csum_cleanup_module(void)
{
	tcf_unregister_action(&act_csum_ops, &csum_net_ops);
}

module_init(csum_init_module);
module_exit(csum_cleanup_module);
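
/*
 * Userspace usage sketch (assumption: the exact iproute2 flag syntax may
 * vary by version), recomputing the IPv4 header and TCP checksums on
 * egress of a hypothetical eth0:
 *
 *	tc qdisc add dev eth0 clsact
 *	tc filter add dev eth0 egress matchall \
 *		action csum iph and tcp
 */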