/*
 * Checksum updating actions
 *
 * Copyright (c) 2010 Gregoire Baron <baronchon@n7mm.org>
 *
 * This program is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License as published by the Free
 * Software Foundation; either version 2 of the License, or (at your option)
 * any later version.
 *
 */

#include <linux/types.h>
#include <linux/init.h>
#include <linux/kernel.h>
#include <linux/module.h>
#include <linux/spinlock.h>

#include <linux/netlink.h>
#include <net/netlink.h>
#include <linux/rtnetlink.h>

#include <linux/skbuff.h>

#include <net/ip.h>
#include <net/ipv6.h>
#include <net/icmp.h>
#include <linux/icmpv6.h>
#include <linux/igmp.h>
#include <net/tcp.h>
#include <net/udp.h>
#include <net/ip6_checksum.h>

#include <net/act_api.h>

#include <linux/tc_act/tc_csum.h>
#include <net/tc_act/tc_csum.h>

#define CSUM_TAB_MASK 15

static const struct nla_policy csum_policy[TCA_CSUM_MAX + 1] = {
	[TCA_CSUM_PARMS] = { .len = sizeof(struct tc_csum), },
};

static int csum_net_id;

static int tcf_csum_init(struct net *net, struct nlattr *nla,
			 struct nlattr *est, struct tc_action *a, int ovr,
			 int bind)
{
	struct tc_action_net *tn = net_generic(net, csum_net_id);
	struct nlattr *tb[TCA_CSUM_MAX + 1];
	struct tc_csum *parm;
	struct tcf_csum *p;
	int ret = 0, err;

	if (nla == NULL)
		return -EINVAL;

	err = nla_parse_nested(tb, TCA_CSUM_MAX, nla, csum_policy);
	if (err < 0)
		return err;

	if (tb[TCA_CSUM_PARMS] == NULL)
		return -EINVAL;
	parm = nla_data(tb[TCA_CSUM_PARMS]);

	if (!tcf_hash_check(tn, parm->index, a, bind)) {
		ret = tcf_hash_create(tn, parm->index, est, a,
				      sizeof(*p), bind, false);
		if (ret)
			return ret;
		ret = ACT_P_CREATED;
	} else {
		if (bind)	/* don't override defaults */
			return 0;
		tcf_hash_release(a, bind);
		if (!ovr)
			return -EEXIST;
	}

	p = to_tcf_csum(a);
	spin_lock_bh(&p->tcf_lock);
	p->tcf_action = parm->action;
	p->update_flags = parm->update_flags;
	spin_unlock_bh(&p->tcf_lock);

	if (ret == ACT_P_CREATED)
		tcf_hash_insert(tn, a);

	return ret;
}
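/*
 * Usage sketch (illustrative comment, not part of the original file):
 * instances of this action are normally created from userspace with
 * iproute2, e.g. something along the lines of
 *
 *   tc filter add dev eth0 parent 1: protocol ip prio 10 \
 *      u32 match ip protocol 17 0xff flowid 1:1 \
 *      action csum udp
 *
 * The exact tc syntax depends on the iproute2 version; the request
 * arrives here as a TCA_CSUM_PARMS attribute whose update_flags would
 * then carry TCA_CSUM_UPDATE_FLAG_UDP.
 */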
/**
 * tcf_csum_skb_nextlayer - Get next layer pointer
 * @skb: sk_buff to use
 * @ihl: previous summed headers length
 * @ipl: complete packet length
 * @jhl: next header length
 *
 * Check that the expected next layer is available in the specified sk_buff.
 * Return the next layer pointer on success, NULL otherwise.
 */
static void *tcf_csum_skb_nextlayer(struct sk_buff *skb,
				    unsigned int ihl, unsigned int ipl,
				    unsigned int jhl)
{
	int ntkoff = skb_network_offset(skb);
	int hl = ihl + jhl;

	if (!pskb_may_pull(skb, ipl + ntkoff) || (ipl < hl) ||
	    skb_try_make_writable(skb, hl + ntkoff))
		return NULL;
	else
		return (void *)(skb_network_header(skb) + ihl);
}

static int tcf_csum_ipv4_icmp(struct sk_buff *skb,
			      unsigned int ihl, unsigned int ipl)
{
	struct icmphdr *icmph;

	icmph = tcf_csum_skb_nextlayer(skb, ihl, ipl, sizeof(*icmph));
	if (icmph == NULL)
		return 0;

	icmph->checksum = 0;
	skb->csum = csum_partial(icmph, ipl - ihl, 0);
	icmph->checksum = csum_fold(skb->csum);

	skb->ip_summed = CHECKSUM_NONE;

	return 1;
}

static int tcf_csum_ipv4_igmp(struct sk_buff *skb,
			      unsigned int ihl, unsigned int ipl)
{
	struct igmphdr *igmph;

	igmph = tcf_csum_skb_nextlayer(skb, ihl, ipl, sizeof(*igmph));
	if (igmph == NULL)
		return 0;

	igmph->csum = 0;
	skb->csum = csum_partial(igmph, ipl - ihl, 0);
	igmph->csum = csum_fold(skb->csum);

	skb->ip_summed = CHECKSUM_NONE;

	return 1;
}

static int tcf_csum_ipv6_icmp(struct sk_buff *skb,
			      unsigned int ihl, unsigned int ipl)
{
	struct icmp6hdr *icmp6h;
	const struct ipv6hdr *ip6h;

	icmp6h = tcf_csum_skb_nextlayer(skb, ihl, ipl, sizeof(*icmp6h));
	if (icmp6h == NULL)
		return 0;

	ip6h = ipv6_hdr(skb);
	icmp6h->icmp6_cksum = 0;
	skb->csum = csum_partial(icmp6h, ipl - ihl, 0);
	icmp6h->icmp6_cksum = csum_ipv6_magic(&ip6h->saddr, &ip6h->daddr,
					      ipl - ihl, IPPROTO_ICMPV6,
					      skb->csum);

	skb->ip_summed = CHECKSUM_NONE;

	return 1;
}

static int tcf_csum_ipv4_tcp(struct sk_buff *skb,
			     unsigned int ihl, unsigned int ipl)
{
	struct tcphdr *tcph;
	const struct iphdr *iph;

	tcph = tcf_csum_skb_nextlayer(skb, ihl, ipl, sizeof(*tcph));
	if (tcph == NULL)
		return 0;

	iph = ip_hdr(skb);
	tcph->check = 0;
	skb->csum = csum_partial(tcph, ipl - ihl, 0);
	tcph->check = tcp_v4_check(ipl - ihl,
				   iph->saddr, iph->daddr, skb->csum);

	skb->ip_summed = CHECKSUM_NONE;

	return 1;
}

static int tcf_csum_ipv6_tcp(struct sk_buff *skb,
			     unsigned int ihl, unsigned int ipl)
{
	struct tcphdr *tcph;
	const struct ipv6hdr *ip6h;

	tcph = tcf_csum_skb_nextlayer(skb, ihl, ipl, sizeof(*tcph));
	if (tcph == NULL)
		return 0;

	ip6h = ipv6_hdr(skb);
	tcph->check = 0;
	skb->csum = csum_partial(tcph, ipl - ihl, 0);
	tcph->check = csum_ipv6_magic(&ip6h->saddr, &ip6h->daddr,
				      ipl - ihl, IPPROTO_TCP,
				      skb->csum);

	skb->ip_summed = CHECKSUM_NONE;

	return 1;
}
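/*
 * Note on the helpers above (explanatory comment added here, not in the
 * original): each one zeroes the checksum field, accumulates a fresh sum
 * over the transport span with csum_partial(), then either folds it
 * directly (ICMP/IGMP) or feeds it to a pseudo-header helper such as
 * tcp_v4_check() or csum_ipv6_magic(). As a worked example, csum_fold()
 * on an accumulator of 0x12345 gives 0x1 + 0x2345 = 0x2346, complemented
 * to 0xdcb9. Setting skb->ip_summed to CHECKSUM_NONE afterwards tells
 * the stack the packet now carries a complete, valid checksum.
 */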
static int tcf_csum_ipv4_udp(struct sk_buff *skb,
			     unsigned int ihl, unsigned int ipl, int udplite)
{
	struct udphdr *udph;
	const struct iphdr *iph;
	u16 ul;

	/*
	 * Support both UDP and UDPLITE checksum algorithms. Don't use
	 * udph->len to get the real length without a protocol check:
	 * UDPLITE repurposes udph->len as the checksum coverage.
	 * Use iph->tot_len, or just ipl.
	 */

	udph = tcf_csum_skb_nextlayer(skb, ihl, ipl, sizeof(*udph));
	if (udph == NULL)
		return 0;

	iph = ip_hdr(skb);
	ul = ntohs(udph->len);

	if (udplite || udph->check) {

		udph->check = 0;

		if (udplite) {
			if (ul == 0)
				skb->csum = csum_partial(udph, ipl - ihl, 0);
			else if ((ul >= sizeof(*udph)) && (ul <= ipl - ihl))
				skb->csum = csum_partial(udph, ul, 0);
			else
				goto ignore_obscure_skb;
		} else {
			if (ul != ipl - ihl)
				goto ignore_obscure_skb;

			skb->csum = csum_partial(udph, ul, 0);
		}

		udph->check = csum_tcpudp_magic(iph->saddr, iph->daddr,
						ul, iph->protocol,
						skb->csum);

		if (!udph->check)
			udph->check = CSUM_MANGLED_0;
	}

	skb->ip_summed = CHECKSUM_NONE;

ignore_obscure_skb:
	return 1;
}

static int tcf_csum_ipv6_udp(struct sk_buff *skb,
			     unsigned int ihl, unsigned int ipl, int udplite)
{
	struct udphdr *udph;
	const struct ipv6hdr *ip6h;
	u16 ul;

	/*
	 * Support both UDP and UDPLITE checksum algorithms. Don't use
	 * udph->len to get the real length without a protocol check:
	 * UDPLITE repurposes udph->len as the checksum coverage.
	 * Use ip6h->payload_len + sizeof(*ip6h) ..., or just ipl.
	 */

	udph = tcf_csum_skb_nextlayer(skb, ihl, ipl, sizeof(*udph));
	if (udph == NULL)
		return 0;

	ip6h = ipv6_hdr(skb);
	ul = ntohs(udph->len);

	udph->check = 0;

	if (udplite) {
		if (ul == 0)
			skb->csum = csum_partial(udph, ipl - ihl, 0);

		else if ((ul >= sizeof(*udph)) && (ul <= ipl - ihl))
			skb->csum = csum_partial(udph, ul, 0);

		else
			goto ignore_obscure_skb;
	} else {
		if (ul != ipl - ihl)
			goto ignore_obscure_skb;

		skb->csum = csum_partial(udph, ul, 0);
	}

	udph->check = csum_ipv6_magic(&ip6h->saddr, &ip6h->daddr, ul,
				      udplite ? IPPROTO_UDPLITE : IPPROTO_UDP,
				      skb->csum);

	if (!udph->check)
		udph->check = CSUM_MANGLED_0;

	skb->ip_summed = CHECKSUM_NONE;

ignore_obscure_skb:
	return 1;
}
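/*
 * Explanatory note (added comment, not in the original): over IPv4 a UDP
 * checksum of zero means "no checksum was computed", which is why
 * tcf_csum_ipv4_udp() only rewrites when udph->check is already non-zero
 * (or for UDPLITE, where the checksum is mandatory), and why a freshly
 * computed sum of zero is stored as CSUM_MANGLED_0 (0xffff) so it cannot
 * be mistaken for "no checksum". IPv6 grants no such exemption, so
 * tcf_csum_ipv6_udp() always rewrites the field.
 */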
static int tcf_csum_ipv4(struct sk_buff *skb, u32 update_flags)
{
	const struct iphdr *iph;
	int ntkoff;

	ntkoff = skb_network_offset(skb);

	if (!pskb_may_pull(skb, sizeof(*iph) + ntkoff))
		goto fail;

	iph = ip_hdr(skb);

	switch (iph->frag_off & htons(IP_OFFSET) ? 0 : iph->protocol) {
	case IPPROTO_ICMP:
		if (update_flags & TCA_CSUM_UPDATE_FLAG_ICMP)
			if (!tcf_csum_ipv4_icmp(skb, iph->ihl * 4,
						ntohs(iph->tot_len)))
				goto fail;
		break;
	case IPPROTO_IGMP:
		if (update_flags & TCA_CSUM_UPDATE_FLAG_IGMP)
			if (!tcf_csum_ipv4_igmp(skb, iph->ihl * 4,
						ntohs(iph->tot_len)))
				goto fail;
		break;
	case IPPROTO_TCP:
		if (update_flags & TCA_CSUM_UPDATE_FLAG_TCP)
			if (!tcf_csum_ipv4_tcp(skb, iph->ihl * 4,
					       ntohs(iph->tot_len)))
				goto fail;
		break;
	case IPPROTO_UDP:
		if (update_flags & TCA_CSUM_UPDATE_FLAG_UDP)
			if (!tcf_csum_ipv4_udp(skb, iph->ihl * 4,
					       ntohs(iph->tot_len), 0))
				goto fail;
		break;
	case IPPROTO_UDPLITE:
		if (update_flags & TCA_CSUM_UPDATE_FLAG_UDPLITE)
			if (!tcf_csum_ipv4_udp(skb, iph->ihl * 4,
					       ntohs(iph->tot_len), 1))
				goto fail;
		break;
	}

	if (update_flags & TCA_CSUM_UPDATE_FLAG_IPV4HDR) {
		if (skb_try_make_writable(skb, sizeof(*iph) + ntkoff))
			goto fail;

		ip_send_check(ip_hdr(skb));
	}

	return 1;

fail:
	return 0;
}

static int tcf_csum_ipv6_hopopts(struct ipv6_opt_hdr *ip6xh,
				 unsigned int ixhl, unsigned int *pl)
{
	int off, len, optlen;
	unsigned char *xh = (void *)ip6xh;

	off = sizeof(*ip6xh);
	len = ixhl - off;

	while (len > 1) {
		switch (xh[off]) {
		case IPV6_TLV_PAD1:
			optlen = 1;
			break;
		case IPV6_TLV_JUMBO:
			optlen = xh[off + 1] + 2;
			if (optlen != 6 || len < 6 || (off & 3) != 2)
				/* wrong jumbo option length/alignment */
				return 0;
			*pl = ntohl(*(__be32 *)(xh + off + 2));
			goto done;
		default:
			optlen = xh[off + 1] + 2;
			if (optlen > len)
				/* ignore obscure options */
				goto done;
			break;
		}
		off += optlen;
		len -= optlen;
	}

done:
	return 1;
}
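/*
 * Explanatory note (added comment, not in the original): the walk below
 * follows the IPv6 extension-header chain until it reaches a transport
 * protocol it knows how to checksum. Hop-by-hop options are scanned by
 * tcf_csum_ipv6_hopopts() above because a jumbogram (RFC 2675) carries
 * payload_len == 0 in the fixed header and the real length in the
 * IPV6_TLV_JUMBO option. Fragment headers abort the walk: only the first
 * fragment contains the transport header, so nothing can be fixed up.
 */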
static int tcf_csum_ipv6(struct sk_buff *skb, u32 update_flags)
{
	struct ipv6hdr *ip6h;
	struct ipv6_opt_hdr *ip6xh;
	unsigned int hl, ixhl;
	unsigned int pl;
	int ntkoff;
	u8 nexthdr;

	ntkoff = skb_network_offset(skb);

	hl = sizeof(*ip6h);

	if (!pskb_may_pull(skb, hl + ntkoff))
		goto fail;

	ip6h = ipv6_hdr(skb);

	pl = ntohs(ip6h->payload_len);
	nexthdr = ip6h->nexthdr;

	do {
		switch (nexthdr) {
		case NEXTHDR_FRAGMENT:
			goto ignore_skb;
		case NEXTHDR_ROUTING:
		case NEXTHDR_HOP:
		case NEXTHDR_DEST:
			if (!pskb_may_pull(skb, hl + sizeof(*ip6xh) + ntkoff))
				goto fail;
			ip6xh = (void *)(skb_network_header(skb) + hl);
			ixhl = ipv6_optlen(ip6xh);
			if (!pskb_may_pull(skb, hl + ixhl + ntkoff))
				goto fail;
			ip6xh = (void *)(skb_network_header(skb) + hl);
			if ((nexthdr == NEXTHDR_HOP) &&
			    !(tcf_csum_ipv6_hopopts(ip6xh, ixhl, &pl)))
				goto fail;
			nexthdr = ip6xh->nexthdr;
			hl += ixhl;
			break;
		case IPPROTO_ICMPV6:
			if (update_flags & TCA_CSUM_UPDATE_FLAG_ICMP)
				if (!tcf_csum_ipv6_icmp(skb,
							hl, pl + sizeof(*ip6h)))
					goto fail;
			goto done;
		case IPPROTO_TCP:
			if (update_flags & TCA_CSUM_UPDATE_FLAG_TCP)
				if (!tcf_csum_ipv6_tcp(skb,
						       hl, pl + sizeof(*ip6h)))
					goto fail;
			goto done;
		case IPPROTO_UDP:
			if (update_flags & TCA_CSUM_UPDATE_FLAG_UDP)
				if (!tcf_csum_ipv6_udp(skb, hl,
						       pl + sizeof(*ip6h), 0))
					goto fail;
			goto done;
		case IPPROTO_UDPLITE:
			if (update_flags & TCA_CSUM_UPDATE_FLAG_UDPLITE)
				if (!tcf_csum_ipv6_udp(skb, hl,
						       pl + sizeof(*ip6h), 1))
					goto fail;
			goto done;
		default:
			goto ignore_skb;
		}
	} while (pskb_may_pull(skb, hl + 1 + ntkoff));

done:
ignore_skb:
	return 1;

fail:
	return 0;
}

static int tcf_csum(struct sk_buff *skb,
		    const struct tc_action *a, struct tcf_result *res)
{
	struct tcf_csum *p = a->priv;
	int action;
	u32 update_flags;

	spin_lock(&p->tcf_lock);
	p->tcf_tm.lastuse = jiffies;
	bstats_update(&p->tcf_bstats, skb);
	action = p->tcf_action;
	update_flags = p->update_flags;
	spin_unlock(&p->tcf_lock);

	if (unlikely(action == TC_ACT_SHOT))
		goto drop;

	switch (tc_skb_protocol(skb)) {
	case cpu_to_be16(ETH_P_IP):
		if (!tcf_csum_ipv4(skb, update_flags))
			goto drop;
		break;
	case cpu_to_be16(ETH_P_IPV6):
		if (!tcf_csum_ipv6(skb, update_flags))
			goto drop;
		break;
	}

	return action;

drop:
	spin_lock(&p->tcf_lock);
	p->tcf_qstats.drops++;
	spin_unlock(&p->tcf_lock);
	return TC_ACT_SHOT;
}

static int tcf_csum_dump(struct sk_buff *skb,
			 struct tc_action *a, int bind, int ref)
{
	unsigned char *b = skb_tail_pointer(skb);
	struct tcf_csum *p = a->priv;
	struct tc_csum opt = {
		.update_flags = p->update_flags,
		.index = p->tcf_index,
		.action = p->tcf_action,
		.refcnt = p->tcf_refcnt - ref,
		.bindcnt = p->tcf_bindcnt - bind,
	};
	struct tcf_t t;

	if (nla_put(skb, TCA_CSUM_PARMS, sizeof(opt), &opt))
		goto nla_put_failure;
	t.install = jiffies_to_clock_t(jiffies - p->tcf_tm.install);
	t.lastuse = jiffies_to_clock_t(jiffies - p->tcf_tm.lastuse);
	t.expires = jiffies_to_clock_t(p->tcf_tm.expires);
	if (nla_put_64bit(skb, TCA_CSUM_TM, sizeof(t), &t, TCA_CSUM_PAD))
		goto nla_put_failure;

	return skb->len;

nla_put_failure:
	nlmsg_trim(skb, b);
	return -1;
}

static int tcf_csum_walker(struct net *net, struct sk_buff *skb,
			   struct netlink_callback *cb, int type,
			   struct tc_action *a)
{
	struct tc_action_net *tn = net_generic(net, csum_net_id);

	return tcf_generic_walker(tn, skb, cb, type, a);
}

static int tcf_csum_search(struct net *net, struct tc_action *a, u32 index)
{
	struct tc_action_net *tn = net_generic(net, csum_net_id);

	return tcf_hash_search(tn, a, index);
}

static struct tc_action_ops act_csum_ops = {
	.kind		= "csum",
	.type		= TCA_ACT_CSUM,
	.owner		= THIS_MODULE,
	.act		= tcf_csum,
	.dump		= tcf_csum_dump,
	.init		= tcf_csum_init,
	.walk		= tcf_csum_walker,
	.lookup		= tcf_csum_search,
};

static __net_init int csum_init_net(struct net *net)
{
	struct tc_action_net *tn = net_generic(net, csum_net_id);

	return tc_action_net_init(tn, &act_csum_ops, CSUM_TAB_MASK);
}

static void __net_exit csum_exit_net(struct net *net)
{
	struct tc_action_net *tn = net_generic(net, csum_net_id);

	tc_action_net_exit(tn);
}

static struct pernet_operations csum_net_ops = {
	.init = csum_init_net,
	.exit = csum_exit_net,
	.id   = &csum_net_id,
	.size = sizeof(struct tc_action_net),
};

MODULE_DESCRIPTION("Checksum updating actions");
MODULE_LICENSE("GPL");

static int __init csum_init_module(void)
{
	return tcf_register_action(&act_csum_ops, &csum_net_ops);
}

static void __exit csum_cleanup_module(void)
{
	tcf_unregister_action(&act_csum_ops, &csum_net_ops);
}
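/*
 * Module plumbing (explanatory comment, not in the original):
 * tcf_register_action() hooks act_csum_ops into the tc action subsystem
 * and registers csum_net_ops so each network namespace gets its own
 * action table (sized by CSUM_TAB_MASK).
 */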
module_init(csum_init_module);
module_exit(csum_cleanup_module);
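/*
 * Build note (assumption based on the usual in-tree layout): this file
 * is compiled as net/sched/act_csum.o when CONFIG_NET_ACT_CSUM is
 * enabled, and when built as a module it can be loaded with
 * "modprobe act_csum".
 */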