/* Copyright (c) 2018, Mellanox Technologies All rights reserved.
 *
 * This software is available to you under a choice of one of two
 * licenses. You may choose to be licensed under the terms of the GNU
 * General Public License (GPL) Version 2, available from the file
 * COPYING in the main directory of this source tree, or the
 * OpenIB.org BSD license below:
 *
 *     Redistribution and use in source and binary forms, with or
 *     without modification, are permitted provided that the following
 *     conditions are met:
 *
 *      - Redistributions of source code must retain the above
 *        copyright notice, this list of conditions and the following
 *        disclaimer.
 *
 *      - Redistributions in binary form must reproduce the above
 *        copyright notice, this list of conditions and the following
 *        disclaimer in the documentation and/or other materials
 *        provided with the distribution.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
 * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
 * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
 * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
 * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
 * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */

#include <crypto/aead.h>
#include <linux/highmem.h>
#include <linux/module.h>
#include <linux/netdevice.h>
#include <net/dst.h>
#include <net/inet_connection_sock.h>
#include <net/tcp.h>
#include <net/tls.h>

/* device_offload_lock is used to synchronize tls_dev_add
 * against NETDEV_DOWN notifications.
 */
static DECLARE_RWSEM(device_offload_lock);

static void tls_device_gc_task(struct work_struct *work);

static DECLARE_WORK(tls_device_gc_work, tls_device_gc_task);
static LIST_HEAD(tls_device_gc_list);
static LIST_HEAD(tls_device_list);
static DEFINE_SPINLOCK(tls_device_lock);

static void tls_device_free_ctx(struct tls_context *ctx)
{
	if (ctx->tx_conf == TLS_HW)
		kfree(tls_offload_ctx_tx(ctx));

	if (ctx->rx_conf == TLS_HW)
		kfree(tls_offload_ctx_rx(ctx));

	kfree(ctx);
}

static void tls_device_gc_task(struct work_struct *work)
{
	struct tls_context *ctx, *tmp;
	unsigned long flags;
	LIST_HEAD(gc_list);

	spin_lock_irqsave(&tls_device_lock, flags);
	list_splice_init(&tls_device_gc_list, &gc_list);
	spin_unlock_irqrestore(&tls_device_lock, flags);

	list_for_each_entry_safe(ctx, tmp, &gc_list, list) {
		struct net_device *netdev = ctx->netdev;

		if (netdev && ctx->tx_conf == TLS_HW) {
			netdev->tlsdev_ops->tls_dev_del(netdev, ctx,
							TLS_OFFLOAD_CTX_DIR_TX);
			dev_put(netdev);
			ctx->netdev = NULL;
		}

		list_del(&ctx->list);
		tls_device_free_ctx(ctx);
	}
}

static void tls_device_attach(struct tls_context *ctx, struct sock *sk,
			      struct net_device *netdev)
{
	if (sk->sk_destruct != tls_device_sk_destruct) {
		refcount_set(&ctx->refcount, 1);
		dev_hold(netdev);
		ctx->netdev = netdev;
		spin_lock_irq(&tls_device_lock);
		list_add_tail(&ctx->list, &tls_device_list);
		spin_unlock_irq(&tls_device_lock);

		ctx->sk_destruct = sk->sk_destruct;
		sk->sk_destruct = tls_device_sk_destruct;
	}
}

static void tls_device_queue_ctx_destruction(struct tls_context *ctx)
{
	unsigned long flags;

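	/* Destruction is deferred to tls_device_gc_work, presumably so that
	 * tls_dev_del() and the final kfree() run from a workqueue rather
	 * than directly in the socket destructor path.
	 */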
	spin_lock_irqsave(&tls_device_lock, flags);
	list_move_tail(&ctx->list, &tls_device_gc_list);

	/* schedule_work inside the spinlock
	 * to make sure tls_device_down waits for that work.
	 */
	schedule_work(&tls_device_gc_work);

	spin_unlock_irqrestore(&tls_device_lock, flags);
}

/* We assume that the socket is already connected */
static struct net_device *get_netdev_for_sock(struct sock *sk)
{
	struct dst_entry *dst = sk_dst_get(sk);
	struct net_device *netdev = NULL;

	if (likely(dst)) {
		netdev = dst->dev;
		dev_hold(netdev);
	}

	dst_release(dst);

	return netdev;
}

static void destroy_record(struct tls_record_info *record)
{
	int nr_frags = record->num_frags;
	skb_frag_t *frag;

	while (nr_frags-- > 0) {
		frag = &record->frags[nr_frags];
		__skb_frag_unref(frag);
	}
	kfree(record);
}

static void delete_all_records(struct tls_offload_context_tx *offload_ctx)
{
	struct tls_record_info *info, *temp;

	list_for_each_entry_safe(info, temp, &offload_ctx->records_list, list) {
		list_del(&info->list);
		destroy_record(info);
	}

	offload_ctx->retransmit_hint = NULL;
}

static void tls_icsk_clean_acked(struct sock *sk, u32 acked_seq)
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	struct tls_record_info *info, *temp;
	struct tls_offload_context_tx *ctx;
	u64 deleted_records = 0;
	unsigned long flags;

	if (!tls_ctx)
		return;

	ctx = tls_offload_ctx_tx(tls_ctx);

	spin_lock_irqsave(&ctx->lock, flags);
	info = ctx->retransmit_hint;
	if (info && !before(acked_seq, info->end_seq)) {
		ctx->retransmit_hint = NULL;
		list_del(&info->list);
		destroy_record(info);
		deleted_records++;
	}

	list_for_each_entry_safe(info, temp, &ctx->records_list, list) {
		if (before(acked_seq, info->end_seq))
			break;
		list_del(&info->list);

		destroy_record(info);
		deleted_records++;
	}

	ctx->unacked_record_sn += deleted_records;
	spin_unlock_irqrestore(&ctx->lock, flags);
}

/* At this point, there should be no references on this
 * socket and no in-flight SKBs associated with this
 * socket, so it is safe to free all the resources.
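 *
 * tls_device_attach() installed this function as sk->sk_destruct and saved
 * the original destructor, which is invoked first below.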
 */
void tls_device_sk_destruct(struct sock *sk)
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	struct tls_offload_context_tx *ctx = tls_offload_ctx_tx(tls_ctx);

	tls_ctx->sk_destruct(sk);

	if (tls_ctx->tx_conf == TLS_HW) {
		if (ctx->open_record)
			destroy_record(ctx->open_record);
		delete_all_records(ctx);
		crypto_free_aead(ctx->aead_send);
		clean_acked_data_disable(inet_csk(sk));
	}

	if (refcount_dec_and_test(&tls_ctx->refcount))
		tls_device_queue_ctx_destruction(tls_ctx);
}
EXPORT_SYMBOL(tls_device_sk_destruct);

static void tls_append_frag(struct tls_record_info *record,
			    struct page_frag *pfrag,
			    int size)
{
	skb_frag_t *frag;

	frag = &record->frags[record->num_frags - 1];
	if (frag->page.p == pfrag->page &&
	    frag->page_offset + frag->size == pfrag->offset) {
		frag->size += size;
	} else {
		++frag;
		frag->page.p = pfrag->page;
		frag->page_offset = pfrag->offset;
		frag->size = size;
		++record->num_frags;
		get_page(pfrag->page);
	}

	pfrag->offset += size;
	record->len += size;
}

static int tls_push_record(struct sock *sk,
			   struct tls_context *ctx,
			   struct tls_offload_context_tx *offload_ctx,
			   struct tls_record_info *record,
			   struct page_frag *pfrag,
			   int flags,
			   unsigned char record_type)
{
	struct tcp_sock *tp = tcp_sk(sk);
	struct page_frag dummy_tag_frag;
	skb_frag_t *frag;
	int i;

	/* fill prepend */
	frag = &record->frags[0];
	tls_fill_prepend(ctx,
			 skb_frag_address(frag),
			 record->len - ctx->tx.prepend_size,
			 record_type,
			 ctx->crypto_send.info.version);

	/* HW doesn't care about the data in the tag, because it fills it. */
	dummy_tag_frag.page = skb_frag_page(frag);
	dummy_tag_frag.offset = 0;

	tls_append_frag(record, &dummy_tag_frag, ctx->tx.tag_size);
	record->end_seq = tp->write_seq + record->len;
	spin_lock_irq(&offload_ctx->lock);
	list_add_tail(&record->list, &offload_ctx->records_list);
	spin_unlock_irq(&offload_ctx->lock);
	offload_ctx->open_record = NULL;
	set_bit(TLS_PENDING_CLOSED_RECORD, &ctx->flags);
	tls_advance_record_sn(sk, &ctx->tx, ctx->crypto_send.info.version);

	for (i = 0; i < record->num_frags; i++) {
		frag = &record->frags[i];
		sg_unmark_end(&offload_ctx->sg_tx_data[i]);
		sg_set_page(&offload_ctx->sg_tx_data[i], skb_frag_page(frag),
			    frag->size, frag->page_offset);
		sk_mem_charge(sk, frag->size);
		get_page(skb_frag_page(frag));
	}
	sg_mark_end(&offload_ctx->sg_tx_data[record->num_frags - 1]);

	/* all ready, send */
	return tls_push_sg(sk, ctx, offload_ctx->sg_tx_data, 0, flags);
}

static int tls_create_new_record(struct tls_offload_context_tx *offload_ctx,
				 struct page_frag *pfrag,
				 size_t prepend_size)
{
	struct tls_record_info *record;
	skb_frag_t *frag;

	record = kmalloc(sizeof(*record), GFP_KERNEL);
	if (!record)
		return -ENOMEM;

	frag = &record->frags[0];
	__skb_frag_set_page(frag, pfrag->page);
	frag->page_offset = pfrag->offset;
	skb_frag_size_set(frag, prepend_size);

	get_page(pfrag->page);
	pfrag->offset += prepend_size;

	record->num_frags = 1;
	record->len = prepend_size;
	offload_ctx->open_record = record;
	return 0;
}

static int tls_do_allocation(struct sock *sk,
			     struct tls_offload_context_tx *offload_ctx,
			     struct page_frag *pfrag,
			     size_t prepend_size)
{
	int ret;

	if (!offload_ctx->open_record) {
		if (unlikely(!skb_page_frag_refill(prepend_size, pfrag,
						   sk->sk_allocation))) {
			sk->sk_prot->enter_memory_pressure(sk);
			sk_stream_moderate_sndbuf(sk);
			return -ENOMEM;
		}

		ret = tls_create_new_record(offload_ctx, pfrag, prepend_size);
		if (ret)
			return ret;

		if (pfrag->size > pfrag->offset)
			return 0;
	}

	if (!sk_page_frag_refill(sk, pfrag))
		return -ENOMEM;

	return 0;
}

static int tls_push_data(struct sock *sk,
			 struct iov_iter *msg_iter,
			 size_t size, int flags,
			 unsigned char record_type)
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	struct tls_offload_context_tx *ctx = tls_offload_ctx_tx(tls_ctx);
	int tls_push_record_flags = flags | MSG_SENDPAGE_NOTLAST;
	int more = flags & (MSG_SENDPAGE_NOTLAST | MSG_MORE);
	struct tls_record_info *record = ctx->open_record;
	struct page_frag *pfrag;
	size_t orig_size = size;
	u32 max_open_record_len;
	int copy, rc = 0;
	bool done = false;
	long timeo;

	if (flags &
	    ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL | MSG_SENDPAGE_NOTLAST))
		return -ENOTSUPP;

	if (sk->sk_err)
		return -sk->sk_err;

	timeo = sock_sndtimeo(sk, flags & MSG_DONTWAIT);
	rc = tls_complete_pending_work(sk, tls_ctx, flags, &timeo);
	if (rc < 0)
		return rc;

	pfrag = sk_page_frag(sk);

	/* TLS_HEADER_SIZE is not counted as part of the TLS record, and
	 * we need to leave room for an authentication tag.
	 */
	max_open_record_len = TLS_MAX_PAYLOAD_SIZE +
			      tls_ctx->tx.prepend_size;
	do {
		rc = tls_do_allocation(sk, ctx, pfrag,
				       tls_ctx->tx.prepend_size);
		if (rc) {
			rc = sk_stream_wait_memory(sk, &timeo);
			if (!rc)
				continue;

			record = ctx->open_record;
			if (!record)
				break;
handle_error:
			if (record_type != TLS_RECORD_TYPE_DATA) {
				/* avoid sending partial
				 * record with type !=
				 * application_data
				 */
				size = orig_size;
				destroy_record(record);
				ctx->open_record = NULL;
			} else if (record->len > tls_ctx->tx.prepend_size) {
				goto last_record;
			}

			break;
		}

		record = ctx->open_record;
		copy = min_t(size_t, size, (pfrag->size - pfrag->offset));
		copy = min_t(size_t, copy, (max_open_record_len - record->len));

		if (copy_from_iter_nocache(page_address(pfrag->page) +
					   pfrag->offset,
					   copy, msg_iter) != copy) {
			rc = -EFAULT;
			goto handle_error;
		}
		tls_append_frag(record, pfrag, copy);

		size -= copy;
		if (!size) {
last_record:
			tls_push_record_flags = flags;
			if (more) {
				tls_ctx->pending_open_record_frags =
						!!record->num_frags;
				break;
			}

			done = true;
		}

		if (done || record->len >= max_open_record_len ||
		    (record->num_frags >= MAX_SKB_FRAGS - 1)) {
			rc = tls_push_record(sk,
					     tls_ctx,
					     ctx,
					     record,
					     pfrag,
					     tls_push_record_flags,
					     record_type);
			if (rc < 0)
				break;
		}
	} while (!done);

	if (orig_size - size > 0)
		rc = orig_size - size;

	return rc;
}

int tls_device_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
{
	unsigned char record_type = TLS_RECORD_TYPE_DATA;
	int rc;

	lock_sock(sk);

	if (unlikely(msg->msg_controllen)) {
		rc = tls_proccess_cmsg(sk, msg, &record_type);
		if (rc)
			goto out;
	}

	rc = tls_push_data(sk, &msg->msg_iter, size,
			   msg->msg_flags, record_type);

out:
	release_sock(sk);
	return rc;
}

int tls_device_sendpage(struct sock *sk, struct page *page,
			int offset, size_t size, int flags)
{
	struct iov_iter msg_iter;
	char *kaddr = kmap(page);
	struct kvec iov;
	int rc;

	if (flags & MSG_SENDPAGE_NOTLAST)
		flags |= MSG_MORE;

	lock_sock(sk);

	if (flags & MSG_OOB) {
		rc = -ENOTSUPP;
		goto out;
	}

	iov.iov_base = kaddr + offset;
	iov.iov_len = size;
	iov_iter_kvec(&msg_iter, WRITE, &iov, 1, size);
	rc = tls_push_data(sk, &msg_iter, size,
			   flags, TLS_RECORD_TYPE_DATA);
	kunmap(page);

out:
	release_sock(sk);
	return rc;
}

struct tls_record_info *tls_get_record(struct tls_offload_context_tx *context,
				       u32 seq, u64 *p_record_sn)
{
	u64 record_sn = context->hint_record_sn;
	struct tls_record_info *info;

	info = context->retransmit_hint;
	if (!info ||
	    before(seq, info->end_seq - info->len)) {
		/* if retransmit_hint is irrelevant start
		 * from the beginning of the list
		 */
		info = list_first_entry(&context->records_list,
					struct tls_record_info, list);
		record_sn = context->unacked_record_sn;
	}

	list_for_each_entry_from(info, &context->records_list, list) {
		if (before(seq, info->end_seq)) {
			if (!context->retransmit_hint ||
			    after(info->end_seq,
				  context->retransmit_hint->end_seq)) {
				context->hint_record_sn = record_sn;
				context->retransmit_hint = info;
			}
			*p_record_sn = record_sn;
			return info;
		}
		record_sn++;
	}

	return NULL;
}
EXPORT_SYMBOL(tls_get_record);

static int tls_device_push_pending_record(struct sock *sk, int flags)
{
	struct iov_iter msg_iter;

	iov_iter_kvec(&msg_iter, WRITE, NULL, 0, 0);
	return tls_push_data(sk, &msg_iter, 0, flags, TLS_RECORD_TYPE_DATA);
}

void handle_device_resync(struct sock *sk, u32 seq, u64 rcd_sn)
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	struct net_device *netdev = tls_ctx->netdev;
	struct tls_offload_context_rx *rx_ctx;
	u32 is_req_pending;
	s64 resync_req;
	u32 req_seq;

	if (tls_ctx->rx_conf != TLS_HW)
		return;

	rx_ctx = tls_offload_ctx_rx(tls_ctx);
	resync_req = atomic64_read(&rx_ctx->resync_req);
	req_seq = ntohl(resync_req >> 32) - ((u32)TLS_HEADER_SIZE - 1);
	is_req_pending = resync_req;

	if (unlikely(is_req_pending) && req_seq == seq &&
	    atomic64_try_cmpxchg(&rx_ctx->resync_req, &resync_req, 0))
		netdev->tlsdev_ops->tls_dev_resync_rx(netdev, sk,
						      seq + TLS_HEADER_SIZE - 1,
						      rcd_sn);
}

static int tls_device_reencrypt(struct sock *sk, struct sk_buff *skb)
{
	struct strp_msg *rxm = strp_msg(skb);
	int err = 0, offset = rxm->offset, copy, nsg;
	struct sk_buff *skb_iter, *unused;
	struct scatterlist sg[1];
	char *orig_buf, *buf;

	orig_buf = kmalloc(rxm->full_len + TLS_HEADER_SIZE +
			   TLS_CIPHER_AES_GCM_128_IV_SIZE, sk->sk_allocation);
	if (!orig_buf)
		return -ENOMEM;
	buf = orig_buf;

	nsg = skb_cow_data(skb, 0, &unused);
	if (unlikely(nsg < 0)) {
		err = nsg;
		goto free_buf;
	}

	sg_init_table(sg, 1);
	sg_set_buf(&sg[0], buf,
		   rxm->full_len + TLS_HEADER_SIZE +
		   TLS_CIPHER_AES_GCM_128_IV_SIZE);
	skb_copy_bits(skb, offset, buf,
		      TLS_HEADER_SIZE + TLS_CIPHER_AES_GCM_128_IV_SIZE);

	/* We are interested only in the decrypted data, not the auth tag. */
	err = decrypt_skb(sk, skb, sg);
	if (err != -EBADMSG)
		goto free_buf;
	else
		err = 0;

	copy = min_t(int, skb_pagelen(skb) - offset,
		     rxm->full_len - TLS_CIPHER_AES_GCM_128_TAG_SIZE);

	if (skb->decrypted)
		skb_store_bits(skb, offset, buf, copy);

	offset += copy;
	buf += copy;

	skb_walk_frags(skb, skb_iter) {
		copy = min_t(int, skb_iter->len,
			     rxm->full_len - offset + rxm->offset -
			     TLS_CIPHER_AES_GCM_128_TAG_SIZE);

		if (skb_iter->decrypted)
			skb_store_bits(skb_iter, offset, buf, copy);

		offset += copy;
		buf += copy;
	}

free_buf:
	kfree(orig_buf);
	return err;
}

int tls_device_decrypted(struct sock *sk, struct sk_buff *skb)
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	struct tls_offload_context_rx *ctx = tls_offload_ctx_rx(tls_ctx);
	int is_decrypted = skb->decrypted;
	int is_encrypted = !is_decrypted;
	struct sk_buff *skb_iter;

	/* Skip if it is already decrypted */
	if (ctx->sw.decrypted)
		return 0;

	/* Check if all the data is decrypted already */
	skb_walk_frags(skb, skb_iter) {
		is_decrypted &= skb_iter->decrypted;
		is_encrypted &= !skb_iter->decrypted;
	}

	ctx->sw.decrypted |= is_decrypted;

	/* Return immediately if the record is either entirely plaintext or
	 * entirely ciphertext. Otherwise handle reencryption of the
	 * partially decrypted record.
	 */
	return (is_encrypted || is_decrypted) ? 0 :
		tls_device_reencrypt(sk, skb);
}

int tls_set_device_offload(struct sock *sk, struct tls_context *ctx)
{
	u16 nonce_size, tag_size, iv_size, rec_seq_size;
	struct tls_record_info *start_marker_record;
	struct tls_offload_context_tx *offload_ctx;
	struct tls_crypto_info *crypto_info;
	struct net_device *netdev;
	char *iv, *rec_seq;
	struct sk_buff *skb;
	int rc = -EINVAL;
	__be64 rcd_sn;

	if (!ctx)
		goto out;

	if (ctx->priv_ctx_tx) {
		rc = -EEXIST;
		goto out;
	}

	start_marker_record = kmalloc(sizeof(*start_marker_record), GFP_KERNEL);
	if (!start_marker_record) {
		rc = -ENOMEM;
		goto out;
	}

	offload_ctx = kzalloc(TLS_OFFLOAD_CONTEXT_SIZE_TX, GFP_KERNEL);
	if (!offload_ctx) {
		rc = -ENOMEM;
		goto free_marker_record;
	}

	crypto_info = &ctx->crypto_send.info;
	switch (crypto_info->cipher_type) {
	case TLS_CIPHER_AES_GCM_128:
		nonce_size = TLS_CIPHER_AES_GCM_128_IV_SIZE;
		tag_size = TLS_CIPHER_AES_GCM_128_TAG_SIZE;
		iv_size = TLS_CIPHER_AES_GCM_128_IV_SIZE;
		iv = ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->iv;
		rec_seq_size = TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE;
		rec_seq =
		 ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->rec_seq;
		break;
	default:
		rc = -EINVAL;
		goto free_offload_ctx;
	}

	ctx->tx.prepend_size = TLS_HEADER_SIZE + nonce_size;
	ctx->tx.tag_size = tag_size;
	ctx->tx.overhead_size = ctx->tx.prepend_size + ctx->tx.tag_size;
	ctx->tx.iv_size = iv_size;
	ctx->tx.iv = kmalloc(iv_size + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
			     GFP_KERNEL);
	if (!ctx->tx.iv) {
		rc = -ENOMEM;
		goto free_offload_ctx;
	}

	memcpy(ctx->tx.iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, iv, iv_size);

	ctx->tx.rec_seq_size = rec_seq_size;
	ctx->tx.rec_seq = kmemdup(rec_seq, rec_seq_size, GFP_KERNEL);
	if (!ctx->tx.rec_seq) {
		rc = -ENOMEM;
		goto free_iv;
	}

	rc = tls_sw_fallback_init(sk, offload_ctx, crypto_info);
	if (rc)
		goto free_rec_seq;

	/* start at rec_seq - 1 to account for the start marker record */
	memcpy(&rcd_sn, ctx->tx.rec_seq, sizeof(rcd_sn));
	offload_ctx->unacked_record_sn = be64_to_cpu(rcd_sn) - 1;

	start_marker_record->end_seq = tcp_sk(sk)->write_seq;
	start_marker_record->len = 0;
	start_marker_record->num_frags = 0;

	INIT_LIST_HEAD(&offload_ctx->records_list);
	list_add_tail(&start_marker_record->list, &offload_ctx->records_list);
	spin_lock_init(&offload_ctx->lock);
	sg_init_table(offload_ctx->sg_tx_data,
		      ARRAY_SIZE(offload_ctx->sg_tx_data));

	clean_acked_data_enable(inet_csk(sk), &tls_icsk_clean_acked);
	ctx->push_pending_record = tls_device_push_pending_record;

	/* TLS offload is greatly simplified if we don't send
	 * SKBs where only part of the payload needs to be encrypted.
	 * So mark the last skb in the write queue as end of record.
	 */
	skb = tcp_write_queue_tail(sk);
	if (skb)
		TCP_SKB_CB(skb)->eor = 1;

	/* We support starting offload on multiple sockets
	 * concurrently, so we only need a read lock here.
	 * This lock must precede get_netdev_for_sock to prevent races between
	 * NETDEV_DOWN and setsockopt.
	 */
	down_read(&device_offload_lock);
	netdev = get_netdev_for_sock(sk);
	if (!netdev) {
		pr_err_ratelimited("%s: netdev not found\n", __func__);
		rc = -EINVAL;
		goto release_lock;
	}

	if (!(netdev->features & NETIF_F_HW_TLS_TX)) {
		rc = -ENOTSUPP;
		goto release_netdev;
	}

	/* Avoid offloading if the device is down
	 * We don't want to offload new flows after
	 * the NETDEV_DOWN event
	 */
	if (!(netdev->flags & IFF_UP)) {
		rc = -EINVAL;
		goto release_netdev;
	}

	ctx->priv_ctx_tx = offload_ctx;
	rc = netdev->tlsdev_ops->tls_dev_add(netdev, sk, TLS_OFFLOAD_CTX_DIR_TX,
					     &ctx->crypto_send.info,
					     tcp_sk(sk)->write_seq);
	if (rc)
		goto release_netdev;

	tls_device_attach(ctx, sk, netdev);

	/* following this assignment tls_is_sk_tx_device_offloaded
	 * will return true and the context might be accessed
	 * by the netdev's xmit function.
	 */
	smp_store_release(&sk->sk_validate_xmit_skb, tls_validate_xmit_skb);
	dev_put(netdev);
	up_read(&device_offload_lock);
	goto out;

release_netdev:
	dev_put(netdev);
release_lock:
	up_read(&device_offload_lock);
	clean_acked_data_disable(inet_csk(sk));
	crypto_free_aead(offload_ctx->aead_send);
free_rec_seq:
	kfree(ctx->tx.rec_seq);
free_iv:
	kfree(ctx->tx.iv);
free_offload_ctx:
	kfree(offload_ctx);
	ctx->priv_ctx_tx = NULL;
free_marker_record:
	kfree(start_marker_record);
out:
	return rc;
}

int tls_set_device_offload_rx(struct sock *sk, struct tls_context *ctx)
{
	struct tls_offload_context_rx *context;
	struct net_device *netdev;
	int rc = 0;

	/* We support starting offload on multiple sockets
	 * concurrently, so we only need a read lock here.
	 * This lock must precede get_netdev_for_sock to prevent races between
	 * NETDEV_DOWN and setsockopt.
	 */
	down_read(&device_offload_lock);
	netdev = get_netdev_for_sock(sk);
	if (!netdev) {
		pr_err_ratelimited("%s: netdev not found\n", __func__);
		rc = -EINVAL;
		goto release_lock;
	}

	if (!(netdev->features & NETIF_F_HW_TLS_RX)) {
		pr_err_ratelimited("%s: netdev %s with no TLS offload\n",
				   __func__, netdev->name);
		rc = -ENOTSUPP;
		goto release_netdev;
	}

	/* Avoid offloading if the device is down
	 * We don't want to offload new flows after
	 * the NETDEV_DOWN event
	 */
	if (!(netdev->flags & IFF_UP)) {
		rc = -EINVAL;
		goto release_netdev;
	}

	context = kzalloc(TLS_OFFLOAD_CONTEXT_SIZE_RX, GFP_KERNEL);
	if (!context) {
		rc = -ENOMEM;
		goto release_netdev;
	}

	ctx->priv_ctx_rx = context;
	rc = tls_set_sw_offload(sk, ctx, 0);
	if (rc)
		goto release_ctx;

	rc = netdev->tlsdev_ops->tls_dev_add(netdev, sk, TLS_OFFLOAD_CTX_DIR_RX,
					     &ctx->crypto_recv.info,
					     tcp_sk(sk)->copied_seq);
	if (rc) {
		pr_err_ratelimited("%s: The netdev has refused to offload this socket\n",
				   __func__);
		goto free_sw_resources;
	}

	tls_device_attach(ctx, sk, netdev);
	goto release_netdev;

free_sw_resources:
	tls_sw_free_resources_rx(sk);
release_ctx:
	ctx->priv_ctx_rx = NULL;
release_netdev:
	dev_put(netdev);
release_lock:
	up_read(&device_offload_lock);
	return rc;
}

void tls_device_offload_cleanup_rx(struct sock *sk)
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	struct net_device *netdev;

	down_read(&device_offload_lock);
	netdev = tls_ctx->netdev;
	if (!netdev)
		goto out;

	if (!(netdev->features & NETIF_F_HW_TLS_RX)) {
		pr_err_ratelimited("%s: device is missing NETIF_F_HW_TLS_RX cap\n",
				   __func__);
		goto out;
	}

	netdev->tlsdev_ops->tls_dev_del(netdev, tls_ctx,
					TLS_OFFLOAD_CTX_DIR_RX);

	if (tls_ctx->tx_conf != TLS_HW) {
		dev_put(netdev);
		tls_ctx->netdev = NULL;
	}
out:
	up_read(&device_offload_lock);
	kfree(tls_ctx->rx.rec_seq);
	kfree(tls_ctx->rx.iv);
	tls_sw_release_resources_rx(sk);
}

static int tls_device_down(struct net_device *netdev)
{
	struct tls_context *ctx, *tmp;
	unsigned long flags;
	LIST_HEAD(list);

	/* Request a write lock to block new offload attempts */
	down_write(&device_offload_lock);

	spin_lock_irqsave(&tls_device_lock, flags);
	list_for_each_entry_safe(ctx, tmp, &tls_device_list, list) {
		if (ctx->netdev != netdev ||
		    !refcount_inc_not_zero(&ctx->refcount))
			continue;

		list_move(&ctx->list, &list);
	}
	spin_unlock_irqrestore(&tls_device_lock, flags);

	list_for_each_entry_safe(ctx, tmp, &list, list) {
		if (ctx->tx_conf == TLS_HW)
			netdev->tlsdev_ops->tls_dev_del(netdev, ctx,
							TLS_OFFLOAD_CTX_DIR_TX);
		if (ctx->rx_conf == TLS_HW)
			netdev->tlsdev_ops->tls_dev_del(netdev, ctx,
							TLS_OFFLOAD_CTX_DIR_RX);
		ctx->netdev = NULL;
		dev_put(netdev);
		list_del_init(&ctx->list);

		if (refcount_dec_and_test(&ctx->refcount))
			tls_device_free_ctx(ctx);
	}

	up_write(&device_offload_lock);

	flush_work(&tls_device_gc_work);

	return NOTIFY_DONE;
}

static int tls_dev_event(struct notifier_block *this, unsigned long event,
			 void *ptr)
{
	struct net_device *dev = netdev_notifier_info_to_dev(ptr);

	if (!(dev->features & (NETIF_F_HW_TLS_RX | NETIF_F_HW_TLS_TX)))
		return NOTIFY_DONE;
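
	/* NETDEV_REGISTER and NETDEV_FEAT_CHANGE only sanity-check that a
	 * driver advertising TLS offload has populated tlsdev_ops;
	 * NETDEV_DOWN tears down every context offloaded to this device.
	 */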

	switch (event) {
	case NETDEV_REGISTER:
	case NETDEV_FEAT_CHANGE:
		if ((dev->features & NETIF_F_HW_TLS_RX) &&
		    !dev->tlsdev_ops->tls_dev_resync_rx)
			return NOTIFY_BAD;

		if (dev->tlsdev_ops &&
		    dev->tlsdev_ops->tls_dev_add &&
		    dev->tlsdev_ops->tls_dev_del)
			return NOTIFY_DONE;
		else
			return NOTIFY_BAD;
	case NETDEV_DOWN:
		return tls_device_down(dev);
	}
	return NOTIFY_DONE;
}

static struct notifier_block tls_dev_notifier = {
	.notifier_call	= tls_dev_event,
};

void __init tls_device_init(void)
{
	register_netdevice_notifier(&tls_dev_notifier);
}

void __exit tls_device_cleanup(void)
{
	unregister_netdevice_notifier(&tls_dev_notifier);
	flush_work(&tls_device_gc_work);
}
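
/* A minimal sketch, with hypothetical driver names, of what tls_dev_event()
 * above expects from a netdev advertising NETIF_F_HW_TLS_TX and/or
 * NETIF_F_HW_TLS_RX: tls_dev_add() and tls_dev_del() must be provided, and
 * tls_dev_resync_rx() as well whenever RX offload is advertised.
 *
 *	static struct tlsdev_ops mydrv_tlsdev_ops = {
 *		.tls_dev_add       = mydrv_tls_add,
 *		.tls_dev_del       = mydrv_tls_del,
 *		.tls_dev_resync_rx = mydrv_tls_resync_rx,
 *	};
 *
 *	netdev->tlsdev_ops = &mydrv_tlsdev_ops;
 *	netdev->features  |= NETIF_F_HW_TLS_TX | NETIF_F_HW_TLS_RX;
 */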