/*
 * Copyright (c) 2016-2017, Mellanox Technologies. All rights reserved.
 * Copyright (c) 2016-2017, Dave Watson <davejwatson@fb.com>. All rights reserved.
 * Copyright (c) 2016-2017, Lance Chao <lancerchao@fb.com>. All rights reserved.
 * Copyright (c) 2016, Fridolin Pokorny <fridolin.pokorny@gmail.com>. All rights reserved.
 * Copyright (c) 2016, Nikos Mavrogiannopoulos <nmav@gnutls.org>. All rights reserved.
 *
 * This software is available to you under a choice of one of two
 * licenses. You may choose to be licensed under the terms of the GNU
 * General Public License (GPL) Version 2, available from the file
 * COPYING in the main directory of this source tree, or the
 * OpenIB.org BSD license below:
 *
 *     Redistribution and use in source and binary forms, with or
 *     without modification, are permitted provided that the following
 *     conditions are met:
 *
 *      - Redistributions of source code must retain the above
 *        copyright notice, this list of conditions and the following
 *        disclaimer.
 *
 *      - Redistributions in binary form must reproduce the above
 *        copyright notice, this list of conditions and the following
 *        disclaimer in the documentation and/or other materials
 *        provided with the distribution.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
 * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
 * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
 * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
 * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
 * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
35 */ 36 37 #include <linux/sched/signal.h> 38 #include <linux/module.h> 39 #include <crypto/aead.h> 40 41 #include <net/strparser.h> 42 #include <net/tls.h> 43 44 #define MAX_IV_SIZE TLS_CIPHER_AES_GCM_128_IV_SIZE 45 46 static int __skb_nsg(struct sk_buff *skb, int offset, int len, 47 unsigned int recursion_level) 48 { 49 int start = skb_headlen(skb); 50 int i, chunk = start - offset; 51 struct sk_buff *frag_iter; 52 int elt = 0; 53 54 if (unlikely(recursion_level >= 24)) 55 return -EMSGSIZE; 56 57 if (chunk > 0) { 58 if (chunk > len) 59 chunk = len; 60 elt++; 61 len -= chunk; 62 if (len == 0) 63 return elt; 64 offset += chunk; 65 } 66 67 for (i = 0; i < skb_shinfo(skb)->nr_frags; i++) { 68 int end; 69 70 WARN_ON(start > offset + len); 71 72 end = start + skb_frag_size(&skb_shinfo(skb)->frags[i]); 73 chunk = end - offset; 74 if (chunk > 0) { 75 if (chunk > len) 76 chunk = len; 77 elt++; 78 len -= chunk; 79 if (len == 0) 80 return elt; 81 offset += chunk; 82 } 83 start = end; 84 } 85 86 if (unlikely(skb_has_frag_list(skb))) { 87 skb_walk_frags(skb, frag_iter) { 88 int end, ret; 89 90 WARN_ON(start > offset + len); 91 92 end = start + frag_iter->len; 93 chunk = end - offset; 94 if (chunk > 0) { 95 if (chunk > len) 96 chunk = len; 97 ret = __skb_nsg(frag_iter, offset - start, chunk, 98 recursion_level + 1); 99 if (unlikely(ret < 0)) 100 return ret; 101 elt += ret; 102 len -= chunk; 103 if (len == 0) 104 return elt; 105 offset += chunk; 106 } 107 start = end; 108 } 109 } 110 BUG_ON(len); 111 return elt; 112 } 113 114 /* Return the number of scatterlist elements required to completely map the 115 * skb, or -EMSGSIZE if the recursion depth is exceeded. 
116 */ 117 static int skb_nsg(struct sk_buff *skb, int offset, int len) 118 { 119 return __skb_nsg(skb, offset, len, 0); 120 } 121 122 static void tls_decrypt_done(struct crypto_async_request *req, int err) 123 { 124 struct aead_request *aead_req = (struct aead_request *)req; 125 struct decrypt_req_ctx *req_ctx = 126 (struct decrypt_req_ctx *)(aead_req + 1); 127 128 struct scatterlist *sgout = aead_req->dst; 129 130 struct tls_context *tls_ctx = tls_get_ctx(req_ctx->sk); 131 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); 132 int pending = atomic_dec_return(&ctx->decrypt_pending); 133 struct scatterlist *sg; 134 unsigned int pages; 135 136 /* Propagate if there was an err */ 137 if (err) { 138 ctx->async_wait.err = err; 139 tls_err_abort(req_ctx->sk, err); 140 } 141 142 /* Release the skb, pages and memory allocated for crypto req */ 143 kfree_skb(req->data); 144 145 /* Skip the first S/G entry as it points to AAD */ 146 for_each_sg(sg_next(sgout), sg, UINT_MAX, pages) { 147 if (!sg) 148 break; 149 put_page(sg_page(sg)); 150 } 151 152 kfree(aead_req); 153 154 if (!pending && READ_ONCE(ctx->async_notify)) 155 complete(&ctx->async_wait.completion); 156 } 157 158 static int tls_do_decryption(struct sock *sk, 159 struct sk_buff *skb, 160 struct scatterlist *sgin, 161 struct scatterlist *sgout, 162 char *iv_recv, 163 size_t data_len, 164 struct aead_request *aead_req, 165 bool async) 166 { 167 struct tls_context *tls_ctx = tls_get_ctx(sk); 168 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); 169 int ret; 170 171 aead_request_set_tfm(aead_req, ctx->aead_recv); 172 aead_request_set_ad(aead_req, TLS_AAD_SPACE_SIZE); 173 aead_request_set_crypt(aead_req, sgin, sgout, 174 data_len + tls_ctx->rx.tag_size, 175 (u8 *)iv_recv); 176 177 if (async) { 178 struct decrypt_req_ctx *req_ctx; 179 180 req_ctx = (struct decrypt_req_ctx *)(aead_req + 1); 181 req_ctx->sk = sk; 182 183 aead_request_set_callback(aead_req, 184 CRYPTO_TFM_REQ_MAY_BACKLOG, 185 tls_decrypt_done, 
skb); 186 atomic_inc(&ctx->decrypt_pending); 187 } else { 188 aead_request_set_callback(aead_req, 189 CRYPTO_TFM_REQ_MAY_BACKLOG, 190 crypto_req_done, &ctx->async_wait); 191 } 192 193 ret = crypto_aead_decrypt(aead_req); 194 if (ret == -EINPROGRESS) { 195 if (async) 196 return ret; 197 198 ret = crypto_wait_req(ret, &ctx->async_wait); 199 } 200 201 if (async) 202 atomic_dec(&ctx->decrypt_pending); 203 204 return ret; 205 } 206 207 static void trim_sg(struct sock *sk, struct scatterlist *sg, 208 int *sg_num_elem, unsigned int *sg_size, int target_size) 209 { 210 int i = *sg_num_elem - 1; 211 int trim = *sg_size - target_size; 212 213 if (trim <= 0) { 214 WARN_ON(trim < 0); 215 return; 216 } 217 218 *sg_size = target_size; 219 while (trim >= sg[i].length) { 220 trim -= sg[i].length; 221 sk_mem_uncharge(sk, sg[i].length); 222 put_page(sg_page(&sg[i])); 223 i--; 224 225 if (i < 0) 226 goto out; 227 } 228 229 sg[i].length -= trim; 230 sk_mem_uncharge(sk, trim); 231 232 out: 233 *sg_num_elem = i + 1; 234 } 235 236 static void trim_both_sgl(struct sock *sk, int target_size) 237 { 238 struct tls_context *tls_ctx = tls_get_ctx(sk); 239 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); 240 241 trim_sg(sk, ctx->sg_plaintext_data, 242 &ctx->sg_plaintext_num_elem, 243 &ctx->sg_plaintext_size, 244 target_size); 245 246 if (target_size > 0) 247 target_size += tls_ctx->tx.overhead_size; 248 249 trim_sg(sk, ctx->sg_encrypted_data, 250 &ctx->sg_encrypted_num_elem, 251 &ctx->sg_encrypted_size, 252 target_size); 253 } 254 255 static int alloc_encrypted_sg(struct sock *sk, int len) 256 { 257 struct tls_context *tls_ctx = tls_get_ctx(sk); 258 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); 259 int rc = 0; 260 261 rc = sk_alloc_sg(sk, len, 262 ctx->sg_encrypted_data, 0, 263 &ctx->sg_encrypted_num_elem, 264 &ctx->sg_encrypted_size, 0); 265 266 return rc; 267 } 268 269 static int alloc_plaintext_sg(struct sock *sk, int len) 270 { 271 struct tls_context *tls_ctx = 
tls_get_ctx(sk); 272 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); 273 int rc = 0; 274 275 rc = sk_alloc_sg(sk, len, ctx->sg_plaintext_data, 0, 276 &ctx->sg_plaintext_num_elem, &ctx->sg_plaintext_size, 277 tls_ctx->pending_open_record_frags); 278 279 return rc; 280 } 281 282 static void free_sg(struct sock *sk, struct scatterlist *sg, 283 int *sg_num_elem, unsigned int *sg_size) 284 { 285 int i, n = *sg_num_elem; 286 287 for (i = 0; i < n; ++i) { 288 sk_mem_uncharge(sk, sg[i].length); 289 put_page(sg_page(&sg[i])); 290 } 291 *sg_num_elem = 0; 292 *sg_size = 0; 293 } 294 295 static void tls_free_both_sg(struct sock *sk) 296 { 297 struct tls_context *tls_ctx = tls_get_ctx(sk); 298 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); 299 300 free_sg(sk, ctx->sg_encrypted_data, &ctx->sg_encrypted_num_elem, 301 &ctx->sg_encrypted_size); 302 303 free_sg(sk, ctx->sg_plaintext_data, &ctx->sg_plaintext_num_elem, 304 &ctx->sg_plaintext_size); 305 } 306 307 static int tls_do_encryption(struct tls_context *tls_ctx, 308 struct tls_sw_context_tx *ctx, 309 struct aead_request *aead_req, 310 size_t data_len) 311 { 312 int rc; 313 314 ctx->sg_encrypted_data[0].offset += tls_ctx->tx.prepend_size; 315 ctx->sg_encrypted_data[0].length -= tls_ctx->tx.prepend_size; 316 317 aead_request_set_tfm(aead_req, ctx->aead_send); 318 aead_request_set_ad(aead_req, TLS_AAD_SPACE_SIZE); 319 aead_request_set_crypt(aead_req, ctx->sg_aead_in, ctx->sg_aead_out, 320 data_len, tls_ctx->tx.iv); 321 322 aead_request_set_callback(aead_req, CRYPTO_TFM_REQ_MAY_BACKLOG, 323 crypto_req_done, &ctx->async_wait); 324 325 rc = crypto_wait_req(crypto_aead_encrypt(aead_req), &ctx->async_wait); 326 327 ctx->sg_encrypted_data[0].offset -= tls_ctx->tx.prepend_size; 328 ctx->sg_encrypted_data[0].length += tls_ctx->tx.prepend_size; 329 330 return rc; 331 } 332 333 static int tls_push_record(struct sock *sk, int flags, 334 unsigned char record_type) 335 { 336 struct tls_context *tls_ctx = tls_get_ctx(sk); 337 
struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); 338 struct aead_request *req; 339 int rc; 340 341 req = aead_request_alloc(ctx->aead_send, sk->sk_allocation); 342 if (!req) 343 return -ENOMEM; 344 345 sg_mark_end(ctx->sg_plaintext_data + ctx->sg_plaintext_num_elem - 1); 346 sg_mark_end(ctx->sg_encrypted_data + ctx->sg_encrypted_num_elem - 1); 347 348 tls_make_aad(ctx->aad_space, ctx->sg_plaintext_size, 349 tls_ctx->tx.rec_seq, tls_ctx->tx.rec_seq_size, 350 record_type); 351 352 tls_fill_prepend(tls_ctx, 353 page_address(sg_page(&ctx->sg_encrypted_data[0])) + 354 ctx->sg_encrypted_data[0].offset, 355 ctx->sg_plaintext_size, record_type); 356 357 tls_ctx->pending_open_record_frags = 0; 358 set_bit(TLS_PENDING_CLOSED_RECORD, &tls_ctx->flags); 359 360 rc = tls_do_encryption(tls_ctx, ctx, req, ctx->sg_plaintext_size); 361 if (rc < 0) { 362 /* If we are called from write_space and 363 * we fail, we need to set this SOCK_NOSPACE 364 * to trigger another write_space in the future. 365 */ 366 set_bit(SOCK_NOSPACE, &sk->sk_socket->flags); 367 goto out_req; 368 } 369 370 free_sg(sk, ctx->sg_plaintext_data, &ctx->sg_plaintext_num_elem, 371 &ctx->sg_plaintext_size); 372 373 ctx->sg_encrypted_num_elem = 0; 374 ctx->sg_encrypted_size = 0; 375 376 /* Only pass through MSG_DONTWAIT and MSG_NOSIGNAL flags */ 377 rc = tls_push_sg(sk, tls_ctx, ctx->sg_encrypted_data, 0, flags); 378 if (rc < 0 && rc != -EAGAIN) 379 tls_err_abort(sk, EBADMSG); 380 381 tls_advance_record_sn(sk, &tls_ctx->tx); 382 out_req: 383 aead_request_free(req); 384 return rc; 385 } 386 387 static int tls_sw_push_pending_record(struct sock *sk, int flags) 388 { 389 return tls_push_record(sk, flags, TLS_RECORD_TYPE_DATA); 390 } 391 392 static int zerocopy_from_iter(struct sock *sk, struct iov_iter *from, 393 int length, int *pages_used, 394 unsigned int *size_used, 395 struct scatterlist *to, int to_max_pages, 396 bool charge) 397 { 398 struct page *pages[MAX_SKB_FRAGS]; 399 400 size_t offset; 401 ssize_t 
copied, use; 402 int i = 0; 403 unsigned int size = *size_used; 404 int num_elem = *pages_used; 405 int rc = 0; 406 int maxpages; 407 408 while (length > 0) { 409 i = 0; 410 maxpages = to_max_pages - num_elem; 411 if (maxpages == 0) { 412 rc = -EFAULT; 413 goto out; 414 } 415 copied = iov_iter_get_pages(from, pages, 416 length, 417 maxpages, &offset); 418 if (copied <= 0) { 419 rc = -EFAULT; 420 goto out; 421 } 422 423 iov_iter_advance(from, copied); 424 425 length -= copied; 426 size += copied; 427 while (copied) { 428 use = min_t(int, copied, PAGE_SIZE - offset); 429 430 sg_set_page(&to[num_elem], 431 pages[i], use, offset); 432 sg_unmark_end(&to[num_elem]); 433 if (charge) 434 sk_mem_charge(sk, use); 435 436 offset = 0; 437 copied -= use; 438 439 ++i; 440 ++num_elem; 441 } 442 } 443 444 /* Mark the end in the last sg entry if newly added */ 445 if (num_elem > *pages_used) 446 sg_mark_end(&to[num_elem - 1]); 447 out: 448 if (rc) 449 iov_iter_revert(from, size - *size_used); 450 *size_used = size; 451 *pages_used = num_elem; 452 453 return rc; 454 } 455 456 static int memcopy_from_iter(struct sock *sk, struct iov_iter *from, 457 int bytes) 458 { 459 struct tls_context *tls_ctx = tls_get_ctx(sk); 460 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); 461 struct scatterlist *sg = ctx->sg_plaintext_data; 462 int copy, i, rc = 0; 463 464 for (i = tls_ctx->pending_open_record_frags; 465 i < ctx->sg_plaintext_num_elem; ++i) { 466 copy = sg[i].length; 467 if (copy_from_iter( 468 page_address(sg_page(&sg[i])) + sg[i].offset, 469 copy, from) != copy) { 470 rc = -EFAULT; 471 goto out; 472 } 473 bytes -= copy; 474 475 ++tls_ctx->pending_open_record_frags; 476 477 if (!bytes) 478 break; 479 } 480 481 out: 482 return rc; 483 } 484 485 int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size) 486 { 487 struct tls_context *tls_ctx = tls_get_ctx(sk); 488 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); 489 int ret = 0; 490 int required_size; 491 long timeo 
= sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT); 492 bool eor = !(msg->msg_flags & MSG_MORE); 493 size_t try_to_copy, copied = 0; 494 unsigned char record_type = TLS_RECORD_TYPE_DATA; 495 int record_room; 496 bool full_record; 497 int orig_size; 498 bool is_kvec = msg->msg_iter.type & ITER_KVEC; 499 500 if (msg->msg_flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL)) 501 return -ENOTSUPP; 502 503 lock_sock(sk); 504 505 if (tls_complete_pending_work(sk, tls_ctx, msg->msg_flags, &timeo)) 506 goto send_end; 507 508 if (unlikely(msg->msg_controllen)) { 509 ret = tls_proccess_cmsg(sk, msg, &record_type); 510 if (ret) 511 goto send_end; 512 } 513 514 while (msg_data_left(msg)) { 515 if (sk->sk_err) { 516 ret = -sk->sk_err; 517 goto send_end; 518 } 519 520 orig_size = ctx->sg_plaintext_size; 521 full_record = false; 522 try_to_copy = msg_data_left(msg); 523 record_room = TLS_MAX_PAYLOAD_SIZE - ctx->sg_plaintext_size; 524 if (try_to_copy >= record_room) { 525 try_to_copy = record_room; 526 full_record = true; 527 } 528 529 required_size = ctx->sg_plaintext_size + try_to_copy + 530 tls_ctx->tx.overhead_size; 531 532 if (!sk_stream_memory_free(sk)) 533 goto wait_for_sndbuf; 534 alloc_encrypted: 535 ret = alloc_encrypted_sg(sk, required_size); 536 if (ret) { 537 if (ret != -ENOSPC) 538 goto wait_for_memory; 539 540 /* Adjust try_to_copy according to the amount that was 541 * actually allocated. 
The difference is due 542 * to max sg elements limit 543 */ 544 try_to_copy -= required_size - ctx->sg_encrypted_size; 545 full_record = true; 546 } 547 if (!is_kvec && (full_record || eor)) { 548 ret = zerocopy_from_iter(sk, &msg->msg_iter, 549 try_to_copy, &ctx->sg_plaintext_num_elem, 550 &ctx->sg_plaintext_size, 551 ctx->sg_plaintext_data, 552 ARRAY_SIZE(ctx->sg_plaintext_data), 553 true); 554 if (ret) 555 goto fallback_to_reg_send; 556 557 copied += try_to_copy; 558 ret = tls_push_record(sk, msg->msg_flags, record_type); 559 if (ret) 560 goto send_end; 561 continue; 562 563 fallback_to_reg_send: 564 trim_sg(sk, ctx->sg_plaintext_data, 565 &ctx->sg_plaintext_num_elem, 566 &ctx->sg_plaintext_size, 567 orig_size); 568 } 569 570 required_size = ctx->sg_plaintext_size + try_to_copy; 571 alloc_plaintext: 572 ret = alloc_plaintext_sg(sk, required_size); 573 if (ret) { 574 if (ret != -ENOSPC) 575 goto wait_for_memory; 576 577 /* Adjust try_to_copy according to the amount that was 578 * actually allocated. 
The difference is due 579 * to max sg elements limit 580 */ 581 try_to_copy -= required_size - ctx->sg_plaintext_size; 582 full_record = true; 583 584 trim_sg(sk, ctx->sg_encrypted_data, 585 &ctx->sg_encrypted_num_elem, 586 &ctx->sg_encrypted_size, 587 ctx->sg_plaintext_size + 588 tls_ctx->tx.overhead_size); 589 } 590 591 ret = memcopy_from_iter(sk, &msg->msg_iter, try_to_copy); 592 if (ret) 593 goto trim_sgl; 594 595 copied += try_to_copy; 596 if (full_record || eor) { 597 push_record: 598 ret = tls_push_record(sk, msg->msg_flags, record_type); 599 if (ret) { 600 if (ret == -ENOMEM) 601 goto wait_for_memory; 602 603 goto send_end; 604 } 605 } 606 607 continue; 608 609 wait_for_sndbuf: 610 set_bit(SOCK_NOSPACE, &sk->sk_socket->flags); 611 wait_for_memory: 612 ret = sk_stream_wait_memory(sk, &timeo); 613 if (ret) { 614 trim_sgl: 615 trim_both_sgl(sk, orig_size); 616 goto send_end; 617 } 618 619 if (tls_is_pending_closed_record(tls_ctx)) 620 goto push_record; 621 622 if (ctx->sg_encrypted_size < required_size) 623 goto alloc_encrypted; 624 625 goto alloc_plaintext; 626 } 627 628 send_end: 629 ret = sk_stream_error(sk, msg->msg_flags, ret); 630 631 release_sock(sk); 632 return copied ? 
copied : ret; 633 } 634 635 int tls_sw_sendpage(struct sock *sk, struct page *page, 636 int offset, size_t size, int flags) 637 { 638 struct tls_context *tls_ctx = tls_get_ctx(sk); 639 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); 640 int ret = 0; 641 long timeo = sock_sndtimeo(sk, flags & MSG_DONTWAIT); 642 bool eor; 643 size_t orig_size = size; 644 unsigned char record_type = TLS_RECORD_TYPE_DATA; 645 struct scatterlist *sg; 646 bool full_record; 647 int record_room; 648 649 if (flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL | 650 MSG_SENDPAGE_NOTLAST)) 651 return -ENOTSUPP; 652 653 /* No MSG_EOR from splice, only look at MSG_MORE */ 654 eor = !(flags & (MSG_MORE | MSG_SENDPAGE_NOTLAST)); 655 656 lock_sock(sk); 657 658 sk_clear_bit(SOCKWQ_ASYNC_NOSPACE, sk); 659 660 if (tls_complete_pending_work(sk, tls_ctx, flags, &timeo)) 661 goto sendpage_end; 662 663 /* Call the sk_stream functions to manage the sndbuf mem. */ 664 while (size > 0) { 665 size_t copy, required_size; 666 667 if (sk->sk_err) { 668 ret = -sk->sk_err; 669 goto sendpage_end; 670 } 671 672 full_record = false; 673 record_room = TLS_MAX_PAYLOAD_SIZE - ctx->sg_plaintext_size; 674 copy = size; 675 if (copy >= record_room) { 676 copy = record_room; 677 full_record = true; 678 } 679 required_size = ctx->sg_plaintext_size + copy + 680 tls_ctx->tx.overhead_size; 681 682 if (!sk_stream_memory_free(sk)) 683 goto wait_for_sndbuf; 684 alloc_payload: 685 ret = alloc_encrypted_sg(sk, required_size); 686 if (ret) { 687 if (ret != -ENOSPC) 688 goto wait_for_memory; 689 690 /* Adjust copy according to the amount that was 691 * actually allocated. 
The difference is due 692 * to max sg elements limit 693 */ 694 copy -= required_size - ctx->sg_plaintext_size; 695 full_record = true; 696 } 697 698 get_page(page); 699 sg = ctx->sg_plaintext_data + ctx->sg_plaintext_num_elem; 700 sg_set_page(sg, page, copy, offset); 701 sg_unmark_end(sg); 702 703 ctx->sg_plaintext_num_elem++; 704 705 sk_mem_charge(sk, copy); 706 offset += copy; 707 size -= copy; 708 ctx->sg_plaintext_size += copy; 709 tls_ctx->pending_open_record_frags = ctx->sg_plaintext_num_elem; 710 711 if (full_record || eor || 712 ctx->sg_plaintext_num_elem == 713 ARRAY_SIZE(ctx->sg_plaintext_data)) { 714 push_record: 715 ret = tls_push_record(sk, flags, record_type); 716 if (ret) { 717 if (ret == -ENOMEM) 718 goto wait_for_memory; 719 720 goto sendpage_end; 721 } 722 } 723 continue; 724 wait_for_sndbuf: 725 set_bit(SOCK_NOSPACE, &sk->sk_socket->flags); 726 wait_for_memory: 727 ret = sk_stream_wait_memory(sk, &timeo); 728 if (ret) { 729 trim_both_sgl(sk, ctx->sg_plaintext_size); 730 goto sendpage_end; 731 } 732 733 if (tls_is_pending_closed_record(tls_ctx)) 734 goto push_record; 735 736 goto alloc_payload; 737 } 738 739 sendpage_end: 740 if (orig_size > size) 741 ret = orig_size - size; 742 else 743 ret = sk_stream_error(sk, flags, ret); 744 745 release_sock(sk); 746 return ret; 747 } 748 749 static struct sk_buff *tls_wait_data(struct sock *sk, int flags, 750 long timeo, int *err) 751 { 752 struct tls_context *tls_ctx = tls_get_ctx(sk); 753 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); 754 struct sk_buff *skb; 755 DEFINE_WAIT_FUNC(wait, woken_wake_function); 756 757 while (!(skb = ctx->recv_pkt)) { 758 if (sk->sk_err) { 759 *err = sock_error(sk); 760 return NULL; 761 } 762 763 if (sk->sk_shutdown & RCV_SHUTDOWN) 764 return NULL; 765 766 if (sock_flag(sk, SOCK_DONE)) 767 return NULL; 768 769 if ((flags & MSG_DONTWAIT) || !timeo) { 770 *err = -EAGAIN; 771 return NULL; 772 } 773 774 add_wait_queue(sk_sleep(sk), &wait); 775 
sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk); 776 sk_wait_event(sk, &timeo, ctx->recv_pkt != skb, &wait); 777 sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk); 778 remove_wait_queue(sk_sleep(sk), &wait); 779 780 /* Handle signals */ 781 if (signal_pending(current)) { 782 *err = sock_intr_errno(timeo); 783 return NULL; 784 } 785 } 786 787 return skb; 788 } 789 790 /* This function decrypts the input skb into either out_iov or in out_sg 791 * or in skb buffers itself. The input parameter 'zc' indicates if 792 * zero-copy mode needs to be tried or not. With zero-copy mode, either 793 * out_iov or out_sg must be non-NULL. In case both out_iov and out_sg are 794 * NULL, then the decryption happens inside skb buffers itself, i.e. 795 * zero-copy gets disabled and 'zc' is updated. 796 */ 797 798 static int decrypt_internal(struct sock *sk, struct sk_buff *skb, 799 struct iov_iter *out_iov, 800 struct scatterlist *out_sg, 801 int *chunk, bool *zc) 802 { 803 struct tls_context *tls_ctx = tls_get_ctx(sk); 804 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); 805 struct strp_msg *rxm = strp_msg(skb); 806 int n_sgin, n_sgout, nsg, mem_size, aead_size, err, pages = 0; 807 struct aead_request *aead_req; 808 struct sk_buff *unused; 809 u8 *aad, *iv, *mem = NULL; 810 struct scatterlist *sgin = NULL; 811 struct scatterlist *sgout = NULL; 812 const int data_len = rxm->full_len - tls_ctx->rx.overhead_size; 813 814 if (*zc && (out_iov || out_sg)) { 815 if (out_iov) 816 n_sgout = iov_iter_npages(out_iov, INT_MAX) + 1; 817 else 818 n_sgout = sg_nents(out_sg); 819 n_sgin = skb_nsg(skb, rxm->offset + tls_ctx->rx.prepend_size, 820 rxm->full_len - tls_ctx->rx.prepend_size); 821 } else { 822 n_sgout = 0; 823 *zc = false; 824 n_sgin = skb_cow_data(skb, 0, &unused); 825 } 826 827 if (n_sgin < 1) 828 return -EBADMSG; 829 830 /* Increment to accommodate AAD */ 831 n_sgin = n_sgin + 1; 832 833 nsg = n_sgin + n_sgout; 834 835 aead_size = sizeof(*aead_req) + crypto_aead_reqsize(ctx->aead_recv); 836 mem_size 
= aead_size + (nsg * sizeof(struct scatterlist)); 837 mem_size = mem_size + TLS_AAD_SPACE_SIZE; 838 mem_size = mem_size + crypto_aead_ivsize(ctx->aead_recv); 839 840 /* Allocate a single block of memory which contains 841 * aead_req || sgin[] || sgout[] || aad || iv. 842 * This order achieves correct alignment for aead_req, sgin, sgout. 843 */ 844 mem = kmalloc(mem_size, sk->sk_allocation); 845 if (!mem) 846 return -ENOMEM; 847 848 /* Segment the allocated memory */ 849 aead_req = (struct aead_request *)mem; 850 sgin = (struct scatterlist *)(mem + aead_size); 851 sgout = sgin + n_sgin; 852 aad = (u8 *)(sgout + n_sgout); 853 iv = aad + TLS_AAD_SPACE_SIZE; 854 855 /* Prepare IV */ 856 err = skb_copy_bits(skb, rxm->offset + TLS_HEADER_SIZE, 857 iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, 858 tls_ctx->rx.iv_size); 859 if (err < 0) { 860 kfree(mem); 861 return err; 862 } 863 memcpy(iv, tls_ctx->rx.iv, TLS_CIPHER_AES_GCM_128_SALT_SIZE); 864 865 /* Prepare AAD */ 866 tls_make_aad(aad, rxm->full_len - tls_ctx->rx.overhead_size, 867 tls_ctx->rx.rec_seq, tls_ctx->rx.rec_seq_size, 868 ctx->control); 869 870 /* Prepare sgin */ 871 sg_init_table(sgin, n_sgin); 872 sg_set_buf(&sgin[0], aad, TLS_AAD_SPACE_SIZE); 873 err = skb_to_sgvec(skb, &sgin[1], 874 rxm->offset + tls_ctx->rx.prepend_size, 875 rxm->full_len - tls_ctx->rx.prepend_size); 876 if (err < 0) { 877 kfree(mem); 878 return err; 879 } 880 881 if (n_sgout) { 882 if (out_iov) { 883 sg_init_table(sgout, n_sgout); 884 sg_set_buf(&sgout[0], aad, TLS_AAD_SPACE_SIZE); 885 886 *chunk = 0; 887 err = zerocopy_from_iter(sk, out_iov, data_len, &pages, 888 chunk, &sgout[1], 889 (n_sgout - 1), false); 890 if (err < 0) 891 goto fallback_to_reg_recv; 892 } else if (out_sg) { 893 memcpy(sgout, out_sg, n_sgout * sizeof(*sgout)); 894 } else { 895 goto fallback_to_reg_recv; 896 } 897 } else { 898 fallback_to_reg_recv: 899 sgout = sgin; 900 pages = 0; 901 *chunk = 0; 902 *zc = false; 903 } 904 905 /* Prepare and submit AEAD request */ 906 err = 
tls_do_decryption(sk, skb, sgin, sgout, iv, 907 data_len, aead_req, *zc); 908 if (err == -EINPROGRESS) 909 return err; 910 911 /* Release the pages in case iov was mapped to pages */ 912 for (; pages > 0; pages--) 913 put_page(sg_page(&sgout[pages])); 914 915 kfree(mem); 916 return err; 917 } 918 919 static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb, 920 struct iov_iter *dest, int *chunk, bool *zc) 921 { 922 struct tls_context *tls_ctx = tls_get_ctx(sk); 923 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); 924 struct strp_msg *rxm = strp_msg(skb); 925 int err = 0; 926 927 #ifdef CONFIG_TLS_DEVICE 928 err = tls_device_decrypted(sk, skb); 929 if (err < 0) 930 return err; 931 #endif 932 if (!ctx->decrypted) { 933 err = decrypt_internal(sk, skb, dest, NULL, chunk, zc); 934 if (err < 0) { 935 if (err == -EINPROGRESS) 936 tls_advance_record_sn(sk, &tls_ctx->rx); 937 938 return err; 939 } 940 } else { 941 *zc = false; 942 } 943 944 rxm->offset += tls_ctx->rx.prepend_size; 945 rxm->full_len -= tls_ctx->rx.overhead_size; 946 tls_advance_record_sn(sk, &tls_ctx->rx); 947 ctx->decrypted = true; 948 ctx->saved_data_ready(sk); 949 950 return err; 951 } 952 953 int decrypt_skb(struct sock *sk, struct sk_buff *skb, 954 struct scatterlist *sgout) 955 { 956 bool zc = true; 957 int chunk; 958 959 return decrypt_internal(sk, skb, NULL, sgout, &chunk, &zc); 960 } 961 962 static bool tls_sw_advance_skb(struct sock *sk, struct sk_buff *skb, 963 unsigned int len) 964 { 965 struct tls_context *tls_ctx = tls_get_ctx(sk); 966 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); 967 968 if (skb) { 969 struct strp_msg *rxm = strp_msg(skb); 970 971 if (len < rxm->full_len) { 972 rxm->offset += len; 973 rxm->full_len -= len; 974 return false; 975 } 976 kfree_skb(skb); 977 } 978 979 /* Finished with message */ 980 ctx->recv_pkt = NULL; 981 __strp_unpause(&ctx->strp); 982 983 return true; 984 } 985 986 int tls_sw_recvmsg(struct sock *sk, 987 struct msghdr *msg, 988 
size_t len, 989 int nonblock, 990 int flags, 991 int *addr_len) 992 { 993 struct tls_context *tls_ctx = tls_get_ctx(sk); 994 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); 995 unsigned char control; 996 struct strp_msg *rxm; 997 struct sk_buff *skb; 998 ssize_t copied = 0; 999 bool cmsg = false; 1000 int target, err = 0; 1001 long timeo; 1002 bool is_kvec = msg->msg_iter.type & ITER_KVEC; 1003 int num_async = 0; 1004 1005 flags |= nonblock; 1006 1007 if (unlikely(flags & MSG_ERRQUEUE)) 1008 return sock_recv_errqueue(sk, msg, len, SOL_IP, IP_RECVERR); 1009 1010 lock_sock(sk); 1011 1012 target = sock_rcvlowat(sk, flags & MSG_WAITALL, len); 1013 timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT); 1014 do { 1015 bool zc = false; 1016 bool async = false; 1017 int chunk = 0; 1018 1019 skb = tls_wait_data(sk, flags, timeo, &err); 1020 if (!skb) 1021 goto recv_end; 1022 1023 rxm = strp_msg(skb); 1024 1025 if (!cmsg) { 1026 int cerr; 1027 1028 cerr = put_cmsg(msg, SOL_TLS, TLS_GET_RECORD_TYPE, 1029 sizeof(ctx->control), &ctx->control); 1030 cmsg = true; 1031 control = ctx->control; 1032 if (ctx->control != TLS_RECORD_TYPE_DATA) { 1033 if (cerr || msg->msg_flags & MSG_CTRUNC) { 1034 err = -EIO; 1035 goto recv_end; 1036 } 1037 } 1038 } else if (control != ctx->control) { 1039 goto recv_end; 1040 } 1041 1042 if (!ctx->decrypted) { 1043 int to_copy = rxm->full_len - tls_ctx->rx.overhead_size; 1044 1045 if (!is_kvec && to_copy <= len && 1046 likely(!(flags & MSG_PEEK))) 1047 zc = true; 1048 1049 err = decrypt_skb_update(sk, skb, &msg->msg_iter, 1050 &chunk, &zc); 1051 if (err < 0 && err != -EINPROGRESS) { 1052 tls_err_abort(sk, EBADMSG); 1053 goto recv_end; 1054 } 1055 1056 if (err == -EINPROGRESS) { 1057 async = true; 1058 num_async++; 1059 goto pick_next_record; 1060 } 1061 1062 ctx->decrypted = true; 1063 } 1064 1065 if (!zc) { 1066 chunk = min_t(unsigned int, rxm->full_len, len); 1067 1068 err = skb_copy_datagram_msg(skb, rxm->offset, msg, 1069 chunk); 1070 if (err < 
0) 1071 goto recv_end; 1072 } 1073 1074 pick_next_record: 1075 copied += chunk; 1076 len -= chunk; 1077 if (likely(!(flags & MSG_PEEK))) { 1078 u8 control = ctx->control; 1079 1080 /* For async, drop current skb reference */ 1081 if (async) 1082 skb = NULL; 1083 1084 if (tls_sw_advance_skb(sk, skb, chunk)) { 1085 /* Return full control message to 1086 * userspace before trying to parse 1087 * another message type 1088 */ 1089 msg->msg_flags |= MSG_EOR; 1090 if (control != TLS_RECORD_TYPE_DATA) 1091 goto recv_end; 1092 } else { 1093 break; 1094 } 1095 } 1096 1097 /* If we have a new message from strparser, continue now. */ 1098 if (copied >= target && !ctx->recv_pkt) 1099 break; 1100 } while (len); 1101 1102 recv_end: 1103 if (num_async) { 1104 /* Wait for all previously submitted records to be decrypted */ 1105 smp_store_mb(ctx->async_notify, true); 1106 if (atomic_read(&ctx->decrypt_pending)) { 1107 err = crypto_wait_req(-EINPROGRESS, &ctx->async_wait); 1108 if (err) { 1109 /* one of async decrypt failed */ 1110 tls_err_abort(sk, err); 1111 copied = 0; 1112 } 1113 } else { 1114 reinit_completion(&ctx->async_wait.completion); 1115 } 1116 WRITE_ONCE(ctx->async_notify, false); 1117 } 1118 1119 release_sock(sk); 1120 return copied ? 
: err; 1121 } 1122 1123 ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos, 1124 struct pipe_inode_info *pipe, 1125 size_t len, unsigned int flags) 1126 { 1127 struct tls_context *tls_ctx = tls_get_ctx(sock->sk); 1128 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); 1129 struct strp_msg *rxm = NULL; 1130 struct sock *sk = sock->sk; 1131 struct sk_buff *skb; 1132 ssize_t copied = 0; 1133 int err = 0; 1134 long timeo; 1135 int chunk; 1136 bool zc = false; 1137 1138 lock_sock(sk); 1139 1140 timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT); 1141 1142 skb = tls_wait_data(sk, flags, timeo, &err); 1143 if (!skb) 1144 goto splice_read_end; 1145 1146 /* splice does not support reading control messages */ 1147 if (ctx->control != TLS_RECORD_TYPE_DATA) { 1148 err = -ENOTSUPP; 1149 goto splice_read_end; 1150 } 1151 1152 if (!ctx->decrypted) { 1153 err = decrypt_skb_update(sk, skb, NULL, &chunk, &zc); 1154 1155 if (err < 0) { 1156 tls_err_abort(sk, EBADMSG); 1157 goto splice_read_end; 1158 } 1159 ctx->decrypted = true; 1160 } 1161 rxm = strp_msg(skb); 1162 1163 chunk = min_t(unsigned int, rxm->full_len, len); 1164 copied = skb_splice_bits(skb, sk, rxm->offset, pipe, chunk, flags); 1165 if (copied < 0) 1166 goto splice_read_end; 1167 1168 if (likely(!(flags & MSG_PEEK))) 1169 tls_sw_advance_skb(sk, skb, copied); 1170 1171 splice_read_end: 1172 release_sock(sk); 1173 return copied ? 
: err; 1174 } 1175 1176 unsigned int tls_sw_poll(struct file *file, struct socket *sock, 1177 struct poll_table_struct *wait) 1178 { 1179 unsigned int ret; 1180 struct sock *sk = sock->sk; 1181 struct tls_context *tls_ctx = tls_get_ctx(sk); 1182 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); 1183 1184 /* Grab POLLOUT and POLLHUP from the underlying socket */ 1185 ret = ctx->sk_poll(file, sock, wait); 1186 1187 /* Clear POLLIN bits, and set based on recv_pkt */ 1188 ret &= ~(POLLIN | POLLRDNORM); 1189 if (ctx->recv_pkt) 1190 ret |= POLLIN | POLLRDNORM; 1191 1192 return ret; 1193 } 1194 1195 static int tls_read_size(struct strparser *strp, struct sk_buff *skb) 1196 { 1197 struct tls_context *tls_ctx = tls_get_ctx(strp->sk); 1198 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); 1199 char header[TLS_HEADER_SIZE + MAX_IV_SIZE]; 1200 struct strp_msg *rxm = strp_msg(skb); 1201 size_t cipher_overhead; 1202 size_t data_len = 0; 1203 int ret; 1204 1205 /* Verify that we have a full TLS header, or wait for more data */ 1206 if (rxm->offset + tls_ctx->rx.prepend_size > skb->len) 1207 return 0; 1208 1209 /* Sanity-check size of on-stack buffer. 
*/ 1210 if (WARN_ON(tls_ctx->rx.prepend_size > sizeof(header))) { 1211 ret = -EINVAL; 1212 goto read_failure; 1213 } 1214 1215 /* Linearize header to local buffer */ 1216 ret = skb_copy_bits(skb, rxm->offset, header, tls_ctx->rx.prepend_size); 1217 1218 if (ret < 0) 1219 goto read_failure; 1220 1221 ctx->control = header[0]; 1222 1223 data_len = ((header[4] & 0xFF) | (header[3] << 8)); 1224 1225 cipher_overhead = tls_ctx->rx.tag_size + tls_ctx->rx.iv_size; 1226 1227 if (data_len > TLS_MAX_PAYLOAD_SIZE + cipher_overhead) { 1228 ret = -EMSGSIZE; 1229 goto read_failure; 1230 } 1231 if (data_len < cipher_overhead) { 1232 ret = -EBADMSG; 1233 goto read_failure; 1234 } 1235 1236 if (header[1] != TLS_VERSION_MINOR(tls_ctx->crypto_recv.version) || 1237 header[2] != TLS_VERSION_MAJOR(tls_ctx->crypto_recv.version)) { 1238 ret = -EINVAL; 1239 goto read_failure; 1240 } 1241 1242 #ifdef CONFIG_TLS_DEVICE 1243 handle_device_resync(strp->sk, TCP_SKB_CB(skb)->seq + rxm->offset, 1244 *(u64*)tls_ctx->rx.rec_seq); 1245 #endif 1246 return data_len + TLS_HEADER_SIZE; 1247 1248 read_failure: 1249 tls_err_abort(strp->sk, ret); 1250 1251 return ret; 1252 } 1253 1254 static void tls_queue(struct strparser *strp, struct sk_buff *skb) 1255 { 1256 struct tls_context *tls_ctx = tls_get_ctx(strp->sk); 1257 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); 1258 1259 ctx->decrypted = false; 1260 1261 ctx->recv_pkt = skb; 1262 strp_pause(strp); 1263 1264 ctx->saved_data_ready(strp->sk); 1265 } 1266 1267 static void tls_data_ready(struct sock *sk) 1268 { 1269 struct tls_context *tls_ctx = tls_get_ctx(sk); 1270 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); 1271 1272 strp_data_ready(&ctx->strp); 1273 } 1274 1275 void tls_sw_free_resources_tx(struct sock *sk) 1276 { 1277 struct tls_context *tls_ctx = tls_get_ctx(sk); 1278 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); 1279 1280 crypto_free_aead(ctx->aead_send); 1281 tls_free_both_sg(sk); 1282 1283 kfree(ctx); 1284 } 1285 1286 
void tls_sw_release_resources_rx(struct sock *sk) 1287 { 1288 struct tls_context *tls_ctx = tls_get_ctx(sk); 1289 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); 1290 1291 if (ctx->aead_recv) { 1292 kfree_skb(ctx->recv_pkt); 1293 ctx->recv_pkt = NULL; 1294 crypto_free_aead(ctx->aead_recv); 1295 strp_stop(&ctx->strp); 1296 write_lock_bh(&sk->sk_callback_lock); 1297 sk->sk_data_ready = ctx->saved_data_ready; 1298 write_unlock_bh(&sk->sk_callback_lock); 1299 release_sock(sk); 1300 strp_done(&ctx->strp); 1301 lock_sock(sk); 1302 } 1303 } 1304 1305 void tls_sw_free_resources_rx(struct sock *sk) 1306 { 1307 struct tls_context *tls_ctx = tls_get_ctx(sk); 1308 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); 1309 1310 tls_sw_release_resources_rx(sk); 1311 1312 kfree(ctx); 1313 } 1314 1315 int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx) 1316 { 1317 char keyval[TLS_CIPHER_AES_GCM_128_KEY_SIZE]; 1318 struct tls_crypto_info *crypto_info; 1319 struct tls12_crypto_info_aes_gcm_128 *gcm_128_info; 1320 struct tls_sw_context_tx *sw_ctx_tx = NULL; 1321 struct tls_sw_context_rx *sw_ctx_rx = NULL; 1322 struct cipher_context *cctx; 1323 struct crypto_aead **aead; 1324 struct strp_callbacks cb; 1325 u16 nonce_size, tag_size, iv_size, rec_seq_size; 1326 char *iv, *rec_seq; 1327 int rc = 0; 1328 1329 if (!ctx) { 1330 rc = -EINVAL; 1331 goto out; 1332 } 1333 1334 if (tx) { 1335 if (!ctx->priv_ctx_tx) { 1336 sw_ctx_tx = kzalloc(sizeof(*sw_ctx_tx), GFP_KERNEL); 1337 if (!sw_ctx_tx) { 1338 rc = -ENOMEM; 1339 goto out; 1340 } 1341 ctx->priv_ctx_tx = sw_ctx_tx; 1342 } else { 1343 sw_ctx_tx = 1344 (struct tls_sw_context_tx *)ctx->priv_ctx_tx; 1345 } 1346 } else { 1347 if (!ctx->priv_ctx_rx) { 1348 sw_ctx_rx = kzalloc(sizeof(*sw_ctx_rx), GFP_KERNEL); 1349 if (!sw_ctx_rx) { 1350 rc = -ENOMEM; 1351 goto out; 1352 } 1353 ctx->priv_ctx_rx = sw_ctx_rx; 1354 } else { 1355 sw_ctx_rx = 1356 (struct tls_sw_context_rx *)ctx->priv_ctx_rx; 1357 } 1358 } 1359 1360 if 
(tx) { 1361 crypto_init_wait(&sw_ctx_tx->async_wait); 1362 crypto_info = &ctx->crypto_send; 1363 cctx = &ctx->tx; 1364 aead = &sw_ctx_tx->aead_send; 1365 } else { 1366 crypto_init_wait(&sw_ctx_rx->async_wait); 1367 crypto_info = &ctx->crypto_recv; 1368 cctx = &ctx->rx; 1369 aead = &sw_ctx_rx->aead_recv; 1370 } 1371 1372 switch (crypto_info->cipher_type) { 1373 case TLS_CIPHER_AES_GCM_128: { 1374 nonce_size = TLS_CIPHER_AES_GCM_128_IV_SIZE; 1375 tag_size = TLS_CIPHER_AES_GCM_128_TAG_SIZE; 1376 iv_size = TLS_CIPHER_AES_GCM_128_IV_SIZE; 1377 iv = ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->iv; 1378 rec_seq_size = TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE; 1379 rec_seq = 1380 ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->rec_seq; 1381 gcm_128_info = 1382 (struct tls12_crypto_info_aes_gcm_128 *)crypto_info; 1383 break; 1384 } 1385 default: 1386 rc = -EINVAL; 1387 goto free_priv; 1388 } 1389 1390 /* Sanity-check the IV size for stack allocations. */ 1391 if (iv_size > MAX_IV_SIZE || nonce_size > MAX_IV_SIZE) { 1392 rc = -EINVAL; 1393 goto free_priv; 1394 } 1395 1396 cctx->prepend_size = TLS_HEADER_SIZE + nonce_size; 1397 cctx->tag_size = tag_size; 1398 cctx->overhead_size = cctx->prepend_size + cctx->tag_size; 1399 cctx->iv_size = iv_size; 1400 cctx->iv = kmalloc(iv_size + TLS_CIPHER_AES_GCM_128_SALT_SIZE, 1401 GFP_KERNEL); 1402 if (!cctx->iv) { 1403 rc = -ENOMEM; 1404 goto free_priv; 1405 } 1406 memcpy(cctx->iv, gcm_128_info->salt, TLS_CIPHER_AES_GCM_128_SALT_SIZE); 1407 memcpy(cctx->iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, iv, iv_size); 1408 cctx->rec_seq_size = rec_seq_size; 1409 cctx->rec_seq = kmemdup(rec_seq, rec_seq_size, GFP_KERNEL); 1410 if (!cctx->rec_seq) { 1411 rc = -ENOMEM; 1412 goto free_iv; 1413 } 1414 1415 if (sw_ctx_tx) { 1416 sg_init_table(sw_ctx_tx->sg_encrypted_data, 1417 ARRAY_SIZE(sw_ctx_tx->sg_encrypted_data)); 1418 sg_init_table(sw_ctx_tx->sg_plaintext_data, 1419 ARRAY_SIZE(sw_ctx_tx->sg_plaintext_data)); 1420 1421 
sg_init_table(sw_ctx_tx->sg_aead_in, 2); 1422 sg_set_buf(&sw_ctx_tx->sg_aead_in[0], sw_ctx_tx->aad_space, 1423 sizeof(sw_ctx_tx->aad_space)); 1424 sg_unmark_end(&sw_ctx_tx->sg_aead_in[1]); 1425 sg_chain(sw_ctx_tx->sg_aead_in, 2, 1426 sw_ctx_tx->sg_plaintext_data); 1427 sg_init_table(sw_ctx_tx->sg_aead_out, 2); 1428 sg_set_buf(&sw_ctx_tx->sg_aead_out[0], sw_ctx_tx->aad_space, 1429 sizeof(sw_ctx_tx->aad_space)); 1430 sg_unmark_end(&sw_ctx_tx->sg_aead_out[1]); 1431 sg_chain(sw_ctx_tx->sg_aead_out, 2, 1432 sw_ctx_tx->sg_encrypted_data); 1433 } 1434 1435 if (!*aead) { 1436 *aead = crypto_alloc_aead("gcm(aes)", 0, 0); 1437 if (IS_ERR(*aead)) { 1438 rc = PTR_ERR(*aead); 1439 *aead = NULL; 1440 goto free_rec_seq; 1441 } 1442 } 1443 1444 ctx->push_pending_record = tls_sw_push_pending_record; 1445 1446 memcpy(keyval, gcm_128_info->key, TLS_CIPHER_AES_GCM_128_KEY_SIZE); 1447 1448 rc = crypto_aead_setkey(*aead, keyval, 1449 TLS_CIPHER_AES_GCM_128_KEY_SIZE); 1450 if (rc) 1451 goto free_aead; 1452 1453 rc = crypto_aead_setauthsize(*aead, cctx->tag_size); 1454 if (rc) 1455 goto free_aead; 1456 1457 if (sw_ctx_rx) { 1458 (*aead)->reqsize = sizeof(struct decrypt_req_ctx); 1459 1460 /* Set up strparser */ 1461 memset(&cb, 0, sizeof(cb)); 1462 cb.rcv_msg = tls_queue; 1463 cb.parse_msg = tls_read_size; 1464 1465 strp_init(&sw_ctx_rx->strp, sk, &cb); 1466 1467 write_lock_bh(&sk->sk_callback_lock); 1468 sw_ctx_rx->saved_data_ready = sk->sk_data_ready; 1469 sk->sk_data_ready = tls_data_ready; 1470 write_unlock_bh(&sk->sk_callback_lock); 1471 1472 sw_ctx_rx->sk_poll = sk->sk_socket->ops->poll; 1473 1474 strp_check_rcv(&sw_ctx_rx->strp); 1475 } 1476 1477 goto out; 1478 1479 free_aead: 1480 crypto_free_aead(*aead); 1481 *aead = NULL; 1482 free_rec_seq: 1483 kfree(cctx->rec_seq); 1484 cctx->rec_seq = NULL; 1485 free_iv: 1486 kfree(cctx->iv); 1487 cctx->iv = NULL; 1488 free_priv: 1489 if (tx) { 1490 kfree(ctx->priv_ctx_tx); 1491 ctx->priv_ctx_tx = NULL; 1492 } else { 1493 
kfree(ctx->priv_ctx_rx); 1494 ctx->priv_ctx_rx = NULL; 1495 } 1496 out: 1497 return rc; 1498 } 1499