1 /* 2 * Copyright (c) 2016-2017, Mellanox Technologies. All rights reserved. 3 * Copyright (c) 2016-2017, Dave Watson <davejwatson@fb.com>. All rights reserved. 4 * Copyright (c) 2016-2017, Lance Chao <lancerchao@fb.com>. All rights reserved. 5 * Copyright (c) 2016, Fridolin Pokorny <fridolin.pokorny@gmail.com>. All rights reserved. 6 * Copyright (c) 2016, Nikos Mavrogiannopoulos <nmav@gnutls.org>. All rights reserved. 7 * Copyright (c) 2018, Covalent IO, Inc. http://covalent.io 8 * 9 * This software is available to you under a choice of one of two 10 * licenses. You may choose to be licensed under the terms of the GNU 11 * General Public License (GPL) Version 2, available from the file 12 * COPYING in the main directory of this source tree, or the 13 * OpenIB.org BSD license below: 14 * 15 * Redistribution and use in source and binary forms, with or 16 * without modification, are permitted provided that the following 17 * conditions are met: 18 * 19 * - Redistributions of source code must retain the above 20 * copyright notice, this list of conditions and the following 21 * disclaimer. 22 * 23 * - Redistributions in binary form must reproduce the above 24 * copyright notice, this list of conditions and the following 25 * disclaimer in the documentation and/or other materials 26 * provided with the distribution. 27 * 28 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 29 * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 30 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 31 * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS 32 * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN 33 * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 34 * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 35 * SOFTWARE. 36 */ 37 38 #include <linux/bug.h> 39 #include <linux/sched/signal.h> 40 #include <linux/module.h> 41 #include <linux/splice.h> 42 #include <crypto/aead.h> 43 44 #include <net/strparser.h> 45 #include <net/tls.h> 46 47 noinline void tls_err_abort(struct sock *sk, int err) 48 { 49 WARN_ON_ONCE(err >= 0); 50 /* sk->sk_err should contain a positive error code. 
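 * Callers therefore pass in a negative errno (see the WARN_ON_ONCE
 * above), which is negated before being stored.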
*/ 51 sk->sk_err = -err; 52 sk_error_report(sk); 53 } 54 55 static int __skb_nsg(struct sk_buff *skb, int offset, int len, 56 unsigned int recursion_level) 57 { 58 int start = skb_headlen(skb); 59 int i, chunk = start - offset; 60 struct sk_buff *frag_iter; 61 int elt = 0; 62 63 if (unlikely(recursion_level >= 24)) 64 return -EMSGSIZE; 65 66 if (chunk > 0) { 67 if (chunk > len) 68 chunk = len; 69 elt++; 70 len -= chunk; 71 if (len == 0) 72 return elt; 73 offset += chunk; 74 } 75 76 for (i = 0; i < skb_shinfo(skb)->nr_frags; i++) { 77 int end; 78 79 WARN_ON(start > offset + len); 80 81 end = start + skb_frag_size(&skb_shinfo(skb)->frags[i]); 82 chunk = end - offset; 83 if (chunk > 0) { 84 if (chunk > len) 85 chunk = len; 86 elt++; 87 len -= chunk; 88 if (len == 0) 89 return elt; 90 offset += chunk; 91 } 92 start = end; 93 } 94 95 if (unlikely(skb_has_frag_list(skb))) { 96 skb_walk_frags(skb, frag_iter) { 97 int end, ret; 98 99 WARN_ON(start > offset + len); 100 101 end = start + frag_iter->len; 102 chunk = end - offset; 103 if (chunk > 0) { 104 if (chunk > len) 105 chunk = len; 106 ret = __skb_nsg(frag_iter, offset - start, chunk, 107 recursion_level + 1); 108 if (unlikely(ret < 0)) 109 return ret; 110 elt += ret; 111 len -= chunk; 112 if (len == 0) 113 return elt; 114 offset += chunk; 115 } 116 start = end; 117 } 118 } 119 BUG_ON(len); 120 return elt; 121 } 122 123 /* Return the number of scatterlist elements required to completely map the 124 * skb, or -EMSGSIZE if the recursion depth is exceeded. 125 */ 126 static int skb_nsg(struct sk_buff *skb, int offset, int len) 127 { 128 return __skb_nsg(skb, offset, len, 0); 129 } 130 131 static int padding_length(struct tls_prot_info *prot, struct sk_buff *skb) 132 { 133 struct strp_msg *rxm = strp_msg(skb); 134 struct tls_msg *tlm = tls_msg(skb); 135 int sub = 0; 136 137 /* Determine zero-padding length */ 138 if (prot->version == TLS_1_3_VERSION) { 139 int offset = rxm->full_len - TLS_TAG_SIZE - 1; 140 char content_type = 0; 141 int err; 142 143 while (content_type == 0) { 144 if (offset < prot->prepend_size) 145 return -EBADMSG; 146 err = skb_copy_bits(skb, rxm->offset + offset, 147 &content_type, 1); 148 if (err) 149 return err; 150 if (content_type) 151 break; 152 sub++; 153 offset--; 154 } 155 tlm->control = content_type; 156 } 157 return sub; 158 } 159 160 static void tls_decrypt_done(struct crypto_async_request *req, int err) 161 { 162 struct aead_request *aead_req = (struct aead_request *)req; 163 struct scatterlist *sgout = aead_req->dst; 164 struct scatterlist *sgin = aead_req->src; 165 struct tls_sw_context_rx *ctx; 166 struct tls_context *tls_ctx; 167 struct tls_prot_info *prot; 168 struct scatterlist *sg; 169 struct sk_buff *skb; 170 unsigned int pages; 171 int pending; 172 173 skb = (struct sk_buff *)req->data; 174 tls_ctx = tls_get_ctx(skb->sk); 175 ctx = tls_sw_ctx_rx(tls_ctx); 176 prot = &tls_ctx->prot_info; 177 178 /* Propagate if there was an err */ 179 if (err) { 180 if (err == -EBADMSG) 181 TLS_INC_STATS(sock_net(skb->sk), 182 LINUX_MIB_TLSDECRYPTERROR); 183 ctx->async_wait.err = err; 184 tls_err_abort(skb->sk, err); 185 } else { 186 struct strp_msg *rxm = strp_msg(skb); 187 int pad; 188 189 pad = padding_length(prot, skb); 190 if (pad < 0) { 191 ctx->async_wait.err = pad; 192 tls_err_abort(skb->sk, pad); 193 } else { 194 rxm->full_len -= pad; 195 rxm->offset += prot->prepend_size; 196 rxm->full_len -= prot->overhead_size; 197 } 198 } 199 200 /* After using skb->sk to propagate sk through crypto async callback 201 * we need 
to NULL it again. 202 */ 203 skb->sk = NULL; 204 205 206 /* Free the destination pages if skb was not decrypted inplace */ 207 if (sgout != sgin) { 208 /* Skip the first S/G entry as it points to AAD */ 209 for_each_sg(sg_next(sgout), sg, UINT_MAX, pages) { 210 if (!sg) 211 break; 212 put_page(sg_page(sg)); 213 } 214 } 215 216 kfree(aead_req); 217 218 spin_lock_bh(&ctx->decrypt_compl_lock); 219 pending = atomic_dec_return(&ctx->decrypt_pending); 220 221 if (!pending && ctx->async_notify) 222 complete(&ctx->async_wait.completion); 223 spin_unlock_bh(&ctx->decrypt_compl_lock); 224 } 225 226 static int tls_do_decryption(struct sock *sk, 227 struct sk_buff *skb, 228 struct scatterlist *sgin, 229 struct scatterlist *sgout, 230 char *iv_recv, 231 size_t data_len, 232 struct aead_request *aead_req, 233 bool async) 234 { 235 struct tls_context *tls_ctx = tls_get_ctx(sk); 236 struct tls_prot_info *prot = &tls_ctx->prot_info; 237 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); 238 int ret; 239 240 aead_request_set_tfm(aead_req, ctx->aead_recv); 241 aead_request_set_ad(aead_req, prot->aad_size); 242 aead_request_set_crypt(aead_req, sgin, sgout, 243 data_len + prot->tag_size, 244 (u8 *)iv_recv); 245 246 if (async) { 247 /* Using skb->sk to push sk through to crypto async callback 248 * handler. This allows propagating errors up to the socket 249 * if needed. It _must_ be cleared in the async handler 250 * before consume_skb is called. We _know_ skb->sk is NULL 251 * because it is a clone from strparser. 252 */ 253 skb->sk = sk; 254 aead_request_set_callback(aead_req, 255 CRYPTO_TFM_REQ_MAY_BACKLOG, 256 tls_decrypt_done, skb); 257 atomic_inc(&ctx->decrypt_pending); 258 } else { 259 aead_request_set_callback(aead_req, 260 CRYPTO_TFM_REQ_MAY_BACKLOG, 261 crypto_req_done, &ctx->async_wait); 262 } 263 264 ret = crypto_aead_decrypt(aead_req); 265 if (ret == -EINPROGRESS) { 266 if (async) 267 return ret; 268 269 ret = crypto_wait_req(ret, &ctx->async_wait); 270 } 271 272 if (async) 273 atomic_dec(&ctx->decrypt_pending); 274 275 return ret; 276 } 277 278 static void tls_trim_both_msgs(struct sock *sk, int target_size) 279 { 280 struct tls_context *tls_ctx = tls_get_ctx(sk); 281 struct tls_prot_info *prot = &tls_ctx->prot_info; 282 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); 283 struct tls_rec *rec = ctx->open_rec; 284 285 sk_msg_trim(sk, &rec->msg_plaintext, target_size); 286 if (target_size > 0) 287 target_size += prot->overhead_size; 288 sk_msg_trim(sk, &rec->msg_encrypted, target_size); 289 } 290 291 static int tls_alloc_encrypted_msg(struct sock *sk, int len) 292 { 293 struct tls_context *tls_ctx = tls_get_ctx(sk); 294 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); 295 struct tls_rec *rec = ctx->open_rec; 296 struct sk_msg *msg_en = &rec->msg_encrypted; 297 298 return sk_msg_alloc(sk, msg_en, len, 0); 299 } 300 301 static int tls_clone_plaintext_msg(struct sock *sk, int required) 302 { 303 struct tls_context *tls_ctx = tls_get_ctx(sk); 304 struct tls_prot_info *prot = &tls_ctx->prot_info; 305 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); 306 struct tls_rec *rec = ctx->open_rec; 307 struct sk_msg *msg_pl = &rec->msg_plaintext; 308 struct sk_msg *msg_en = &rec->msg_encrypted; 309 int skip, len; 310 311 /* We add page references worth len bytes from encrypted sg 312 * at the end of plaintext sg. It is guaranteed that msg_en 313 * has enough required room (ensured by caller). 
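 * Only page references are taken here; no payload bytes are copied.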
314 */ 315 len = required - msg_pl->sg.size; 316 317 /* Skip initial bytes in msg_en's data to be able to use 318 * same offset of both plain and encrypted data. 319 */ 320 skip = prot->prepend_size + msg_pl->sg.size; 321 322 return sk_msg_clone(sk, msg_pl, msg_en, skip, len); 323 } 324 325 static struct tls_rec *tls_get_rec(struct sock *sk) 326 { 327 struct tls_context *tls_ctx = tls_get_ctx(sk); 328 struct tls_prot_info *prot = &tls_ctx->prot_info; 329 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); 330 struct sk_msg *msg_pl, *msg_en; 331 struct tls_rec *rec; 332 int mem_size; 333 334 mem_size = sizeof(struct tls_rec) + crypto_aead_reqsize(ctx->aead_send); 335 336 rec = kzalloc(mem_size, sk->sk_allocation); 337 if (!rec) 338 return NULL; 339 340 msg_pl = &rec->msg_plaintext; 341 msg_en = &rec->msg_encrypted; 342 343 sk_msg_init(msg_pl); 344 sk_msg_init(msg_en); 345 346 sg_init_table(rec->sg_aead_in, 2); 347 sg_set_buf(&rec->sg_aead_in[0], rec->aad_space, prot->aad_size); 348 sg_unmark_end(&rec->sg_aead_in[1]); 349 350 sg_init_table(rec->sg_aead_out, 2); 351 sg_set_buf(&rec->sg_aead_out[0], rec->aad_space, prot->aad_size); 352 sg_unmark_end(&rec->sg_aead_out[1]); 353 354 return rec; 355 } 356 357 static void tls_free_rec(struct sock *sk, struct tls_rec *rec) 358 { 359 sk_msg_free(sk, &rec->msg_encrypted); 360 sk_msg_free(sk, &rec->msg_plaintext); 361 kfree(rec); 362 } 363 364 static void tls_free_open_rec(struct sock *sk) 365 { 366 struct tls_context *tls_ctx = tls_get_ctx(sk); 367 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); 368 struct tls_rec *rec = ctx->open_rec; 369 370 if (rec) { 371 tls_free_rec(sk, rec); 372 ctx->open_rec = NULL; 373 } 374 } 375 376 int tls_tx_records(struct sock *sk, int flags) 377 { 378 struct tls_context *tls_ctx = tls_get_ctx(sk); 379 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); 380 struct tls_rec *rec, *tmp; 381 struct sk_msg *msg_en; 382 int tx_flags, rc = 0; 383 384 if (tls_is_partially_sent_record(tls_ctx)) { 385 rec = list_first_entry(&ctx->tx_list, 386 struct tls_rec, list); 387 388 if (flags == -1) 389 tx_flags = rec->tx_flags; 390 else 391 tx_flags = flags; 392 393 rc = tls_push_partial_record(sk, tls_ctx, tx_flags); 394 if (rc) 395 goto tx_err; 396 397 /* Full record has been transmitted. 
398 * Remove the head of tx_list 399 */ 400 list_del(&rec->list); 401 sk_msg_free(sk, &rec->msg_plaintext); 402 kfree(rec); 403 } 404 405 /* Tx all ready records */ 406 list_for_each_entry_safe(rec, tmp, &ctx->tx_list, list) { 407 if (READ_ONCE(rec->tx_ready)) { 408 if (flags == -1) 409 tx_flags = rec->tx_flags; 410 else 411 tx_flags = flags; 412 413 msg_en = &rec->msg_encrypted; 414 rc = tls_push_sg(sk, tls_ctx, 415 &msg_en->sg.data[msg_en->sg.curr], 416 0, tx_flags); 417 if (rc) 418 goto tx_err; 419 420 list_del(&rec->list); 421 sk_msg_free(sk, &rec->msg_plaintext); 422 kfree(rec); 423 } else { 424 break; 425 } 426 } 427 428 tx_err: 429 if (rc < 0 && rc != -EAGAIN) 430 tls_err_abort(sk, -EBADMSG); 431 432 return rc; 433 } 434 435 static void tls_encrypt_done(struct crypto_async_request *req, int err) 436 { 437 struct aead_request *aead_req = (struct aead_request *)req; 438 struct sock *sk = req->data; 439 struct tls_context *tls_ctx = tls_get_ctx(sk); 440 struct tls_prot_info *prot = &tls_ctx->prot_info; 441 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); 442 struct scatterlist *sge; 443 struct sk_msg *msg_en; 444 struct tls_rec *rec; 445 bool ready = false; 446 int pending; 447 448 rec = container_of(aead_req, struct tls_rec, aead_req); 449 msg_en = &rec->msg_encrypted; 450 451 sge = sk_msg_elem(msg_en, msg_en->sg.curr); 452 sge->offset -= prot->prepend_size; 453 sge->length += prot->prepend_size; 454 455 /* Check if error is previously set on socket */ 456 if (err || sk->sk_err) { 457 rec = NULL; 458 459 /* If err is already set on socket, return the same code */ 460 if (sk->sk_err) { 461 ctx->async_wait.err = -sk->sk_err; 462 } else { 463 ctx->async_wait.err = err; 464 tls_err_abort(sk, err); 465 } 466 } 467 468 if (rec) { 469 struct tls_rec *first_rec; 470 471 /* Mark the record as ready for transmission */ 472 smp_store_mb(rec->tx_ready, true); 473 474 /* If received record is at head of tx_list, schedule tx */ 475 first_rec = list_first_entry(&ctx->tx_list, 476 struct tls_rec, list); 477 if (rec == first_rec) 478 ready = true; 479 } 480 481 spin_lock_bh(&ctx->encrypt_compl_lock); 482 pending = atomic_dec_return(&ctx->encrypt_pending); 483 484 if (!pending && ctx->async_notify) 485 complete(&ctx->async_wait.completion); 486 spin_unlock_bh(&ctx->encrypt_compl_lock); 487 488 if (!ready) 489 return; 490 491 /* Schedule the transmission */ 492 if (!test_and_set_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask)) 493 schedule_delayed_work(&ctx->tx_work.work, 1); 494 } 495 496 static int tls_do_encryption(struct sock *sk, 497 struct tls_context *tls_ctx, 498 struct tls_sw_context_tx *ctx, 499 struct aead_request *aead_req, 500 size_t data_len, u32 start) 501 { 502 struct tls_prot_info *prot = &tls_ctx->prot_info; 503 struct tls_rec *rec = ctx->open_rec; 504 struct sk_msg *msg_en = &rec->msg_encrypted; 505 struct scatterlist *sge = sk_msg_elem(msg_en, start); 506 int rc, iv_offset = 0; 507 508 /* For CCM based ciphers, first byte of IV is a constant */ 509 switch (prot->cipher_type) { 510 case TLS_CIPHER_AES_CCM_128: 511 rec->iv_data[0] = TLS_AES_CCM_IV_B0_BYTE; 512 iv_offset = 1; 513 break; 514 case TLS_CIPHER_SM4_CCM: 515 rec->iv_data[0] = TLS_SM4_CCM_IV_B0_BYTE; 516 iv_offset = 1; 517 break; 518 } 519 520 memcpy(&rec->iv_data[iv_offset], tls_ctx->tx.iv, 521 prot->iv_size + prot->salt_size); 522 523 xor_iv_with_seq(prot, rec->iv_data + iv_offset, tls_ctx->tx.rec_seq); 524 525 sge->offset += prot->prepend_size; 526 sge->length -= prot->prepend_size; 527 528 msg_en->sg.curr = start; 529 530 
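	/* Set up the AEAD request: entry 0 of sg_aead_in/out carries the AAD,
	 * the chained entries carry the plaintext/ciphertext record data.
	 */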
aead_request_set_tfm(aead_req, ctx->aead_send); 531 aead_request_set_ad(aead_req, prot->aad_size); 532 aead_request_set_crypt(aead_req, rec->sg_aead_in, 533 rec->sg_aead_out, 534 data_len, rec->iv_data); 535 536 aead_request_set_callback(aead_req, CRYPTO_TFM_REQ_MAY_BACKLOG, 537 tls_encrypt_done, sk); 538 539 /* Add the record in tx_list */ 540 list_add_tail((struct list_head *)&rec->list, &ctx->tx_list); 541 atomic_inc(&ctx->encrypt_pending); 542 543 rc = crypto_aead_encrypt(aead_req); 544 if (!rc || rc != -EINPROGRESS) { 545 atomic_dec(&ctx->encrypt_pending); 546 sge->offset -= prot->prepend_size; 547 sge->length += prot->prepend_size; 548 } 549 550 if (!rc) { 551 WRITE_ONCE(rec->tx_ready, true); 552 } else if (rc != -EINPROGRESS) { 553 list_del(&rec->list); 554 return rc; 555 } 556 557 /* Unhook the record from context if encryption is not failure */ 558 ctx->open_rec = NULL; 559 tls_advance_record_sn(sk, prot, &tls_ctx->tx); 560 return rc; 561 } 562 563 static int tls_split_open_record(struct sock *sk, struct tls_rec *from, 564 struct tls_rec **to, struct sk_msg *msg_opl, 565 struct sk_msg *msg_oen, u32 split_point, 566 u32 tx_overhead_size, u32 *orig_end) 567 { 568 u32 i, j, bytes = 0, apply = msg_opl->apply_bytes; 569 struct scatterlist *sge, *osge, *nsge; 570 u32 orig_size = msg_opl->sg.size; 571 struct scatterlist tmp = { }; 572 struct sk_msg *msg_npl; 573 struct tls_rec *new; 574 int ret; 575 576 new = tls_get_rec(sk); 577 if (!new) 578 return -ENOMEM; 579 ret = sk_msg_alloc(sk, &new->msg_encrypted, msg_opl->sg.size + 580 tx_overhead_size, 0); 581 if (ret < 0) { 582 tls_free_rec(sk, new); 583 return ret; 584 } 585 586 *orig_end = msg_opl->sg.end; 587 i = msg_opl->sg.start; 588 sge = sk_msg_elem(msg_opl, i); 589 while (apply && sge->length) { 590 if (sge->length > apply) { 591 u32 len = sge->length - apply; 592 593 get_page(sg_page(sge)); 594 sg_set_page(&tmp, sg_page(sge), len, 595 sge->offset + apply); 596 sge->length = apply; 597 bytes += apply; 598 apply = 0; 599 } else { 600 apply -= sge->length; 601 bytes += sge->length; 602 } 603 604 sk_msg_iter_var_next(i); 605 if (i == msg_opl->sg.end) 606 break; 607 sge = sk_msg_elem(msg_opl, i); 608 } 609 610 msg_opl->sg.end = i; 611 msg_opl->sg.curr = i; 612 msg_opl->sg.copybreak = 0; 613 msg_opl->apply_bytes = 0; 614 msg_opl->sg.size = bytes; 615 616 msg_npl = &new->msg_plaintext; 617 msg_npl->apply_bytes = apply; 618 msg_npl->sg.size = orig_size - bytes; 619 620 j = msg_npl->sg.start; 621 nsge = sk_msg_elem(msg_npl, j); 622 if (tmp.length) { 623 memcpy(nsge, &tmp, sizeof(*nsge)); 624 sk_msg_iter_var_next(j); 625 nsge = sk_msg_elem(msg_npl, j); 626 } 627 628 osge = sk_msg_elem(msg_opl, i); 629 while (osge->length) { 630 memcpy(nsge, osge, sizeof(*nsge)); 631 sg_unmark_end(nsge); 632 sk_msg_iter_var_next(i); 633 sk_msg_iter_var_next(j); 634 if (i == *orig_end) 635 break; 636 osge = sk_msg_elem(msg_opl, i); 637 nsge = sk_msg_elem(msg_npl, j); 638 } 639 640 msg_npl->sg.end = j; 641 msg_npl->sg.curr = j; 642 msg_npl->sg.copybreak = 0; 643 644 *to = new; 645 return 0; 646 } 647 648 static void tls_merge_open_record(struct sock *sk, struct tls_rec *to, 649 struct tls_rec *from, u32 orig_end) 650 { 651 struct sk_msg *msg_npl = &from->msg_plaintext; 652 struct sk_msg *msg_opl = &to->msg_plaintext; 653 struct scatterlist *osge, *nsge; 654 u32 i, j; 655 656 i = msg_opl->sg.end; 657 sk_msg_iter_var_prev(i); 658 j = msg_npl->sg.start; 659 660 osge = sk_msg_elem(msg_opl, i); 661 nsge = sk_msg_elem(msg_npl, j); 662 663 if (sg_page(osge) == 
sg_page(nsge) && 664 osge->offset + osge->length == nsge->offset) { 665 osge->length += nsge->length; 666 put_page(sg_page(nsge)); 667 } 668 669 msg_opl->sg.end = orig_end; 670 msg_opl->sg.curr = orig_end; 671 msg_opl->sg.copybreak = 0; 672 msg_opl->apply_bytes = msg_opl->sg.size + msg_npl->sg.size; 673 msg_opl->sg.size += msg_npl->sg.size; 674 675 sk_msg_free(sk, &to->msg_encrypted); 676 sk_msg_xfer_full(&to->msg_encrypted, &from->msg_encrypted); 677 678 kfree(from); 679 } 680 681 static int tls_push_record(struct sock *sk, int flags, 682 unsigned char record_type) 683 { 684 struct tls_context *tls_ctx = tls_get_ctx(sk); 685 struct tls_prot_info *prot = &tls_ctx->prot_info; 686 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); 687 struct tls_rec *rec = ctx->open_rec, *tmp = NULL; 688 u32 i, split_point, orig_end; 689 struct sk_msg *msg_pl, *msg_en; 690 struct aead_request *req; 691 bool split; 692 int rc; 693 694 if (!rec) 695 return 0; 696 697 msg_pl = &rec->msg_plaintext; 698 msg_en = &rec->msg_encrypted; 699 700 split_point = msg_pl->apply_bytes; 701 split = split_point && split_point < msg_pl->sg.size; 702 if (unlikely((!split && 703 msg_pl->sg.size + 704 prot->overhead_size > msg_en->sg.size) || 705 (split && 706 split_point + 707 prot->overhead_size > msg_en->sg.size))) { 708 split = true; 709 split_point = msg_en->sg.size; 710 } 711 if (split) { 712 rc = tls_split_open_record(sk, rec, &tmp, msg_pl, msg_en, 713 split_point, prot->overhead_size, 714 &orig_end); 715 if (rc < 0) 716 return rc; 717 /* This can happen if above tls_split_open_record allocates 718 * a single large encryption buffer instead of two smaller 719 * ones. In this case adjust pointers and continue without 720 * split. 721 */ 722 if (!msg_pl->sg.size) { 723 tls_merge_open_record(sk, rec, tmp, orig_end); 724 msg_pl = &rec->msg_plaintext; 725 msg_en = &rec->msg_encrypted; 726 split = false; 727 } 728 sk_msg_trim(sk, msg_en, msg_pl->sg.size + 729 prot->overhead_size); 730 } 731 732 rec->tx_flags = flags; 733 req = &rec->aead_req; 734 735 i = msg_pl->sg.end; 736 sk_msg_iter_var_prev(i); 737 738 rec->content_type = record_type; 739 if (prot->version == TLS_1_3_VERSION) { 740 /* Add content type to end of message. 
No padding added */ 741 sg_set_buf(&rec->sg_content_type, &rec->content_type, 1); 742 sg_mark_end(&rec->sg_content_type); 743 sg_chain(msg_pl->sg.data, msg_pl->sg.end + 1, 744 &rec->sg_content_type); 745 } else { 746 sg_mark_end(sk_msg_elem(msg_pl, i)); 747 } 748 749 if (msg_pl->sg.end < msg_pl->sg.start) { 750 sg_chain(&msg_pl->sg.data[msg_pl->sg.start], 751 MAX_SKB_FRAGS - msg_pl->sg.start + 1, 752 msg_pl->sg.data); 753 } 754 755 i = msg_pl->sg.start; 756 sg_chain(rec->sg_aead_in, 2, &msg_pl->sg.data[i]); 757 758 i = msg_en->sg.end; 759 sk_msg_iter_var_prev(i); 760 sg_mark_end(sk_msg_elem(msg_en, i)); 761 762 i = msg_en->sg.start; 763 sg_chain(rec->sg_aead_out, 2, &msg_en->sg.data[i]); 764 765 tls_make_aad(rec->aad_space, msg_pl->sg.size + prot->tail_size, 766 tls_ctx->tx.rec_seq, record_type, prot); 767 768 tls_fill_prepend(tls_ctx, 769 page_address(sg_page(&msg_en->sg.data[i])) + 770 msg_en->sg.data[i].offset, 771 msg_pl->sg.size + prot->tail_size, 772 record_type); 773 774 tls_ctx->pending_open_record_frags = false; 775 776 rc = tls_do_encryption(sk, tls_ctx, ctx, req, 777 msg_pl->sg.size + prot->tail_size, i); 778 if (rc < 0) { 779 if (rc != -EINPROGRESS) { 780 tls_err_abort(sk, -EBADMSG); 781 if (split) { 782 tls_ctx->pending_open_record_frags = true; 783 tls_merge_open_record(sk, rec, tmp, orig_end); 784 } 785 } 786 ctx->async_capable = 1; 787 return rc; 788 } else if (split) { 789 msg_pl = &tmp->msg_plaintext; 790 msg_en = &tmp->msg_encrypted; 791 sk_msg_trim(sk, msg_en, msg_pl->sg.size + prot->overhead_size); 792 tls_ctx->pending_open_record_frags = true; 793 ctx->open_rec = tmp; 794 } 795 796 return tls_tx_records(sk, flags); 797 } 798 799 static int bpf_exec_tx_verdict(struct sk_msg *msg, struct sock *sk, 800 bool full_record, u8 record_type, 801 ssize_t *copied, int flags) 802 { 803 struct tls_context *tls_ctx = tls_get_ctx(sk); 804 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); 805 struct sk_msg msg_redir = { }; 806 struct sk_psock *psock; 807 struct sock *sk_redir; 808 struct tls_rec *rec; 809 bool enospc, policy; 810 int err = 0, send; 811 u32 delta = 0; 812 813 policy = !(flags & MSG_SENDPAGE_NOPOLICY); 814 psock = sk_psock_get(sk); 815 if (!psock || !policy) { 816 err = tls_push_record(sk, flags, record_type); 817 if (err && sk->sk_err == EBADMSG) { 818 *copied -= sk_msg_free(sk, msg); 819 tls_free_open_rec(sk); 820 err = -sk->sk_err; 821 } 822 if (psock) 823 sk_psock_put(sk, psock); 824 return err; 825 } 826 more_data: 827 enospc = sk_msg_full(msg); 828 if (psock->eval == __SK_NONE) { 829 delta = msg->sg.size; 830 psock->eval = sk_psock_msg_verdict(sk, psock, msg); 831 delta -= msg->sg.size; 832 } 833 if (msg->cork_bytes && msg->cork_bytes > msg->sg.size && 834 !enospc && !full_record) { 835 err = -ENOSPC; 836 goto out_err; 837 } 838 msg->cork_bytes = 0; 839 send = msg->sg.size; 840 if (msg->apply_bytes && msg->apply_bytes < send) 841 send = msg->apply_bytes; 842 843 switch (psock->eval) { 844 case __SK_PASS: 845 err = tls_push_record(sk, flags, record_type); 846 if (err && sk->sk_err == EBADMSG) { 847 *copied -= sk_msg_free(sk, msg); 848 tls_free_open_rec(sk); 849 err = -sk->sk_err; 850 goto out_err; 851 } 852 break; 853 case __SK_REDIRECT: 854 sk_redir = psock->sk_redir; 855 memcpy(&msg_redir, msg, sizeof(*msg)); 856 if (msg->apply_bytes < send) 857 msg->apply_bytes = 0; 858 else 859 msg->apply_bytes -= send; 860 sk_msg_return_zero(sk, msg, send); 861 msg->sg.size -= send; 862 release_sock(sk); 863 err = tcp_bpf_sendmsg_redir(sk_redir, &msg_redir, send, 
flags); 864 lock_sock(sk); 865 if (err < 0) { 866 *copied -= sk_msg_free_nocharge(sk, &msg_redir); 867 msg->sg.size = 0; 868 } 869 if (msg->sg.size == 0) 870 tls_free_open_rec(sk); 871 break; 872 case __SK_DROP: 873 default: 874 sk_msg_free_partial(sk, msg, send); 875 if (msg->apply_bytes < send) 876 msg->apply_bytes = 0; 877 else 878 msg->apply_bytes -= send; 879 if (msg->sg.size == 0) 880 tls_free_open_rec(sk); 881 *copied -= (send + delta); 882 err = -EACCES; 883 } 884 885 if (likely(!err)) { 886 bool reset_eval = !ctx->open_rec; 887 888 rec = ctx->open_rec; 889 if (rec) { 890 msg = &rec->msg_plaintext; 891 if (!msg->apply_bytes) 892 reset_eval = true; 893 } 894 if (reset_eval) { 895 psock->eval = __SK_NONE; 896 if (psock->sk_redir) { 897 sock_put(psock->sk_redir); 898 psock->sk_redir = NULL; 899 } 900 } 901 if (rec) 902 goto more_data; 903 } 904 out_err: 905 sk_psock_put(sk, psock); 906 return err; 907 } 908 909 static int tls_sw_push_pending_record(struct sock *sk, int flags) 910 { 911 struct tls_context *tls_ctx = tls_get_ctx(sk); 912 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); 913 struct tls_rec *rec = ctx->open_rec; 914 struct sk_msg *msg_pl; 915 size_t copied; 916 917 if (!rec) 918 return 0; 919 920 msg_pl = &rec->msg_plaintext; 921 copied = msg_pl->sg.size; 922 if (!copied) 923 return 0; 924 925 return bpf_exec_tx_verdict(msg_pl, sk, true, TLS_RECORD_TYPE_DATA, 926 &copied, flags); 927 } 928 929 int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size) 930 { 931 long timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT); 932 struct tls_context *tls_ctx = tls_get_ctx(sk); 933 struct tls_prot_info *prot = &tls_ctx->prot_info; 934 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); 935 bool async_capable = ctx->async_capable; 936 unsigned char record_type = TLS_RECORD_TYPE_DATA; 937 bool is_kvec = iov_iter_is_kvec(&msg->msg_iter); 938 bool eor = !(msg->msg_flags & MSG_MORE); 939 size_t try_to_copy; 940 ssize_t copied = 0; 941 struct sk_msg *msg_pl, *msg_en; 942 struct tls_rec *rec; 943 int required_size; 944 int num_async = 0; 945 bool full_record; 946 int record_room; 947 int num_zc = 0; 948 int orig_size; 949 int ret = 0; 950 int pending; 951 952 if (msg->msg_flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL | 953 MSG_CMSG_COMPAT)) 954 return -EOPNOTSUPP; 955 956 mutex_lock(&tls_ctx->tx_lock); 957 lock_sock(sk); 958 959 if (unlikely(msg->msg_controllen)) { 960 ret = tls_proccess_cmsg(sk, msg, &record_type); 961 if (ret) { 962 if (ret == -EINPROGRESS) 963 num_async++; 964 else if (ret != -EAGAIN) 965 goto send_end; 966 } 967 } 968 969 while (msg_data_left(msg)) { 970 if (sk->sk_err) { 971 ret = -sk->sk_err; 972 goto send_end; 973 } 974 975 if (ctx->open_rec) 976 rec = ctx->open_rec; 977 else 978 rec = ctx->open_rec = tls_get_rec(sk); 979 if (!rec) { 980 ret = -ENOMEM; 981 goto send_end; 982 } 983 984 msg_pl = &rec->msg_plaintext; 985 msg_en = &rec->msg_encrypted; 986 987 orig_size = msg_pl->sg.size; 988 full_record = false; 989 try_to_copy = msg_data_left(msg); 990 record_room = TLS_MAX_PAYLOAD_SIZE - msg_pl->sg.size; 991 if (try_to_copy >= record_room) { 992 try_to_copy = record_room; 993 full_record = true; 994 } 995 996 required_size = msg_pl->sg.size + try_to_copy + 997 prot->overhead_size; 998 999 if (!sk_stream_memory_free(sk)) 1000 goto wait_for_sndbuf; 1001 1002 alloc_encrypted: 1003 ret = tls_alloc_encrypted_msg(sk, required_size); 1004 if (ret) { 1005 if (ret != -ENOSPC) 1006 goto wait_for_memory; 1007 1008 /* Adjust try_to_copy according to 
the amount that was 1009 * actually allocated. The difference is due 1010 * to max sg elements limit 1011 */ 1012 try_to_copy -= required_size - msg_en->sg.size; 1013 full_record = true; 1014 } 1015 1016 if (!is_kvec && (full_record || eor) && !async_capable) { 1017 u32 first = msg_pl->sg.end; 1018 1019 ret = sk_msg_zerocopy_from_iter(sk, &msg->msg_iter, 1020 msg_pl, try_to_copy); 1021 if (ret) 1022 goto fallback_to_reg_send; 1023 1024 num_zc++; 1025 copied += try_to_copy; 1026 1027 sk_msg_sg_copy_set(msg_pl, first); 1028 ret = bpf_exec_tx_verdict(msg_pl, sk, full_record, 1029 record_type, &copied, 1030 msg->msg_flags); 1031 if (ret) { 1032 if (ret == -EINPROGRESS) 1033 num_async++; 1034 else if (ret == -ENOMEM) 1035 goto wait_for_memory; 1036 else if (ctx->open_rec && ret == -ENOSPC) 1037 goto rollback_iter; 1038 else if (ret != -EAGAIN) 1039 goto send_end; 1040 } 1041 continue; 1042 rollback_iter: 1043 copied -= try_to_copy; 1044 sk_msg_sg_copy_clear(msg_pl, first); 1045 iov_iter_revert(&msg->msg_iter, 1046 msg_pl->sg.size - orig_size); 1047 fallback_to_reg_send: 1048 sk_msg_trim(sk, msg_pl, orig_size); 1049 } 1050 1051 required_size = msg_pl->sg.size + try_to_copy; 1052 1053 ret = tls_clone_plaintext_msg(sk, required_size); 1054 if (ret) { 1055 if (ret != -ENOSPC) 1056 goto send_end; 1057 1058 /* Adjust try_to_copy according to the amount that was 1059 * actually allocated. The difference is due 1060 * to max sg elements limit 1061 */ 1062 try_to_copy -= required_size - msg_pl->sg.size; 1063 full_record = true; 1064 sk_msg_trim(sk, msg_en, 1065 msg_pl->sg.size + prot->overhead_size); 1066 } 1067 1068 if (try_to_copy) { 1069 ret = sk_msg_memcopy_from_iter(sk, &msg->msg_iter, 1070 msg_pl, try_to_copy); 1071 if (ret < 0) 1072 goto trim_sgl; 1073 } 1074 1075 /* Open records defined only if successfully copied, otherwise 1076 * we would trim the sg but not reset the open record frags. 
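 * (pending_open_record_frags is set only once the copy has succeeded).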
1077 */ 1078 tls_ctx->pending_open_record_frags = true; 1079 copied += try_to_copy; 1080 if (full_record || eor) { 1081 ret = bpf_exec_tx_verdict(msg_pl, sk, full_record, 1082 record_type, &copied, 1083 msg->msg_flags); 1084 if (ret) { 1085 if (ret == -EINPROGRESS) 1086 num_async++; 1087 else if (ret == -ENOMEM) 1088 goto wait_for_memory; 1089 else if (ret != -EAGAIN) { 1090 if (ret == -ENOSPC) 1091 ret = 0; 1092 goto send_end; 1093 } 1094 } 1095 } 1096 1097 continue; 1098 1099 wait_for_sndbuf: 1100 set_bit(SOCK_NOSPACE, &sk->sk_socket->flags); 1101 wait_for_memory: 1102 ret = sk_stream_wait_memory(sk, &timeo); 1103 if (ret) { 1104 trim_sgl: 1105 if (ctx->open_rec) 1106 tls_trim_both_msgs(sk, orig_size); 1107 goto send_end; 1108 } 1109 1110 if (ctx->open_rec && msg_en->sg.size < required_size) 1111 goto alloc_encrypted; 1112 } 1113 1114 if (!num_async) { 1115 goto send_end; 1116 } else if (num_zc) { 1117 /* Wait for pending encryptions to get completed */ 1118 spin_lock_bh(&ctx->encrypt_compl_lock); 1119 ctx->async_notify = true; 1120 1121 pending = atomic_read(&ctx->encrypt_pending); 1122 spin_unlock_bh(&ctx->encrypt_compl_lock); 1123 if (pending) 1124 crypto_wait_req(-EINPROGRESS, &ctx->async_wait); 1125 else 1126 reinit_completion(&ctx->async_wait.completion); 1127 1128 /* There can be no concurrent accesses, since we have no 1129 * pending encrypt operations 1130 */ 1131 WRITE_ONCE(ctx->async_notify, false); 1132 1133 if (ctx->async_wait.err) { 1134 ret = ctx->async_wait.err; 1135 copied = 0; 1136 } 1137 } 1138 1139 /* Transmit if any encryptions have completed */ 1140 if (test_and_clear_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask)) { 1141 cancel_delayed_work(&ctx->tx_work.work); 1142 tls_tx_records(sk, msg->msg_flags); 1143 } 1144 1145 send_end: 1146 ret = sk_stream_error(sk, msg->msg_flags, ret); 1147 1148 release_sock(sk); 1149 mutex_unlock(&tls_ctx->tx_lock); 1150 return copied > 0 ? copied : ret; 1151 } 1152 1153 static int tls_sw_do_sendpage(struct sock *sk, struct page *page, 1154 int offset, size_t size, int flags) 1155 { 1156 long timeo = sock_sndtimeo(sk, flags & MSG_DONTWAIT); 1157 struct tls_context *tls_ctx = tls_get_ctx(sk); 1158 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); 1159 struct tls_prot_info *prot = &tls_ctx->prot_info; 1160 unsigned char record_type = TLS_RECORD_TYPE_DATA; 1161 struct sk_msg *msg_pl; 1162 struct tls_rec *rec; 1163 int num_async = 0; 1164 ssize_t copied = 0; 1165 bool full_record; 1166 int record_room; 1167 int ret = 0; 1168 bool eor; 1169 1170 eor = !(flags & MSG_SENDPAGE_NOTLAST); 1171 sk_clear_bit(SOCKWQ_ASYNC_NOSPACE, sk); 1172 1173 /* Call the sk_stream functions to manage the sndbuf mem. 
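 * sk_stream_memory_free() and sk_stream_wait_memory() below provide the
 * send-buffer flow control.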
*/ 1174 while (size > 0) { 1175 size_t copy, required_size; 1176 1177 if (sk->sk_err) { 1178 ret = -sk->sk_err; 1179 goto sendpage_end; 1180 } 1181 1182 if (ctx->open_rec) 1183 rec = ctx->open_rec; 1184 else 1185 rec = ctx->open_rec = tls_get_rec(sk); 1186 if (!rec) { 1187 ret = -ENOMEM; 1188 goto sendpage_end; 1189 } 1190 1191 msg_pl = &rec->msg_plaintext; 1192 1193 full_record = false; 1194 record_room = TLS_MAX_PAYLOAD_SIZE - msg_pl->sg.size; 1195 copy = size; 1196 if (copy >= record_room) { 1197 copy = record_room; 1198 full_record = true; 1199 } 1200 1201 required_size = msg_pl->sg.size + copy + prot->overhead_size; 1202 1203 if (!sk_stream_memory_free(sk)) 1204 goto wait_for_sndbuf; 1205 alloc_payload: 1206 ret = tls_alloc_encrypted_msg(sk, required_size); 1207 if (ret) { 1208 if (ret != -ENOSPC) 1209 goto wait_for_memory; 1210 1211 /* Adjust copy according to the amount that was 1212 * actually allocated. The difference is due 1213 * to max sg elements limit 1214 */ 1215 copy -= required_size - msg_pl->sg.size; 1216 full_record = true; 1217 } 1218 1219 sk_msg_page_add(msg_pl, page, copy, offset); 1220 sk_mem_charge(sk, copy); 1221 1222 offset += copy; 1223 size -= copy; 1224 copied += copy; 1225 1226 tls_ctx->pending_open_record_frags = true; 1227 if (full_record || eor || sk_msg_full(msg_pl)) { 1228 ret = bpf_exec_tx_verdict(msg_pl, sk, full_record, 1229 record_type, &copied, flags); 1230 if (ret) { 1231 if (ret == -EINPROGRESS) 1232 num_async++; 1233 else if (ret == -ENOMEM) 1234 goto wait_for_memory; 1235 else if (ret != -EAGAIN) { 1236 if (ret == -ENOSPC) 1237 ret = 0; 1238 goto sendpage_end; 1239 } 1240 } 1241 } 1242 continue; 1243 wait_for_sndbuf: 1244 set_bit(SOCK_NOSPACE, &sk->sk_socket->flags); 1245 wait_for_memory: 1246 ret = sk_stream_wait_memory(sk, &timeo); 1247 if (ret) { 1248 if (ctx->open_rec) 1249 tls_trim_both_msgs(sk, msg_pl->sg.size); 1250 goto sendpage_end; 1251 } 1252 1253 if (ctx->open_rec) 1254 goto alloc_payload; 1255 } 1256 1257 if (num_async) { 1258 /* Transmit if any encryptions have completed */ 1259 if (test_and_clear_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask)) { 1260 cancel_delayed_work(&ctx->tx_work.work); 1261 tls_tx_records(sk, flags); 1262 } 1263 } 1264 sendpage_end: 1265 ret = sk_stream_error(sk, flags, ret); 1266 return copied > 0 ? 
copied : ret; 1267 } 1268 1269 int tls_sw_sendpage_locked(struct sock *sk, struct page *page, 1270 int offset, size_t size, int flags) 1271 { 1272 if (flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL | 1273 MSG_SENDPAGE_NOTLAST | MSG_SENDPAGE_NOPOLICY | 1274 MSG_NO_SHARED_FRAGS)) 1275 return -EOPNOTSUPP; 1276 1277 return tls_sw_do_sendpage(sk, page, offset, size, flags); 1278 } 1279 1280 int tls_sw_sendpage(struct sock *sk, struct page *page, 1281 int offset, size_t size, int flags) 1282 { 1283 struct tls_context *tls_ctx = tls_get_ctx(sk); 1284 int ret; 1285 1286 if (flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL | 1287 MSG_SENDPAGE_NOTLAST | MSG_SENDPAGE_NOPOLICY)) 1288 return -EOPNOTSUPP; 1289 1290 mutex_lock(&tls_ctx->tx_lock); 1291 lock_sock(sk); 1292 ret = tls_sw_do_sendpage(sk, page, offset, size, flags); 1293 release_sock(sk); 1294 mutex_unlock(&tls_ctx->tx_lock); 1295 return ret; 1296 } 1297 1298 static struct sk_buff *tls_wait_data(struct sock *sk, struct sk_psock *psock, 1299 bool nonblock, long timeo, int *err) 1300 { 1301 struct tls_context *tls_ctx = tls_get_ctx(sk); 1302 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); 1303 struct sk_buff *skb; 1304 DEFINE_WAIT_FUNC(wait, woken_wake_function); 1305 1306 while (!(skb = ctx->recv_pkt) && sk_psock_queue_empty(psock)) { 1307 if (sk->sk_err) { 1308 *err = sock_error(sk); 1309 return NULL; 1310 } 1311 1312 if (!skb_queue_empty(&sk->sk_receive_queue)) { 1313 __strp_unpause(&ctx->strp); 1314 if (ctx->recv_pkt) 1315 return ctx->recv_pkt; 1316 } 1317 1318 if (sk->sk_shutdown & RCV_SHUTDOWN) 1319 return NULL; 1320 1321 if (sock_flag(sk, SOCK_DONE)) 1322 return NULL; 1323 1324 if (nonblock || !timeo) { 1325 *err = -EAGAIN; 1326 return NULL; 1327 } 1328 1329 add_wait_queue(sk_sleep(sk), &wait); 1330 sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk); 1331 sk_wait_event(sk, &timeo, 1332 ctx->recv_pkt != skb || 1333 !sk_psock_queue_empty(psock), 1334 &wait); 1335 sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk); 1336 remove_wait_queue(sk_sleep(sk), &wait); 1337 1338 /* Handle signals */ 1339 if (signal_pending(current)) { 1340 *err = sock_intr_errno(timeo); 1341 return NULL; 1342 } 1343 } 1344 1345 return skb; 1346 } 1347 1348 static int tls_setup_from_iter(struct sock *sk, struct iov_iter *from, 1349 int length, int *pages_used, 1350 unsigned int *size_used, 1351 struct scatterlist *to, 1352 int to_max_pages) 1353 { 1354 int rc = 0, i = 0, num_elem = *pages_used, maxpages; 1355 struct page *pages[MAX_SKB_FRAGS]; 1356 unsigned int size = *size_used; 1357 ssize_t copied, use; 1358 size_t offset; 1359 1360 while (length > 0) { 1361 i = 0; 1362 maxpages = to_max_pages - num_elem; 1363 if (maxpages == 0) { 1364 rc = -EFAULT; 1365 goto out; 1366 } 1367 copied = iov_iter_get_pages(from, pages, 1368 length, 1369 maxpages, &offset); 1370 if (copied <= 0) { 1371 rc = -EFAULT; 1372 goto out; 1373 } 1374 1375 iov_iter_advance(from, copied); 1376 1377 length -= copied; 1378 size += copied; 1379 while (copied) { 1380 use = min_t(int, copied, PAGE_SIZE - offset); 1381 1382 sg_set_page(&to[num_elem], 1383 pages[i], use, offset); 1384 sg_unmark_end(&to[num_elem]); 1385 /* We do not uncharge memory from this API */ 1386 1387 offset = 0; 1388 copied -= use; 1389 1390 i++; 1391 num_elem++; 1392 } 1393 } 1394 /* Mark the end in the last sg entry if newly added */ 1395 if (num_elem > *pages_used) 1396 sg_mark_end(&to[num_elem - 1]); 1397 out: 1398 if (rc) 1399 iov_iter_revert(from, size - *size_used); 1400 *size_used = size; 1401 *pages_used = num_elem; 1402 1403 return 
rc; 1404 } 1405 1406 /* This function decrypts the input skb into either out_iov or in out_sg 1407 * or in skb buffers itself. The input parameter 'zc' indicates if 1408 * zero-copy mode needs to be tried or not. With zero-copy mode, either 1409 * out_iov or out_sg must be non-NULL. In case both out_iov and out_sg are 1410 * NULL, then the decryption happens inside skb buffers itself, i.e. 1411 * zero-copy gets disabled and 'zc' is updated. 1412 */ 1413 1414 static int decrypt_internal(struct sock *sk, struct sk_buff *skb, 1415 struct iov_iter *out_iov, 1416 struct scatterlist *out_sg, 1417 int *chunk, bool *zc, bool async) 1418 { 1419 struct tls_context *tls_ctx = tls_get_ctx(sk); 1420 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); 1421 struct tls_prot_info *prot = &tls_ctx->prot_info; 1422 struct strp_msg *rxm = strp_msg(skb); 1423 struct tls_msg *tlm = tls_msg(skb); 1424 int n_sgin, n_sgout, nsg, mem_size, aead_size, err, pages = 0; 1425 struct aead_request *aead_req; 1426 struct sk_buff *unused; 1427 u8 *aad, *iv, *mem = NULL; 1428 struct scatterlist *sgin = NULL; 1429 struct scatterlist *sgout = NULL; 1430 const int data_len = rxm->full_len - prot->overhead_size + 1431 prot->tail_size; 1432 int iv_offset = 0; 1433 1434 if (*zc && (out_iov || out_sg)) { 1435 if (out_iov) 1436 n_sgout = 1 + 1437 iov_iter_npages_cap(out_iov, INT_MAX, data_len); 1438 else 1439 n_sgout = sg_nents(out_sg); 1440 n_sgin = skb_nsg(skb, rxm->offset + prot->prepend_size, 1441 rxm->full_len - prot->prepend_size); 1442 } else { 1443 n_sgout = 0; 1444 *zc = false; 1445 n_sgin = skb_cow_data(skb, 0, &unused); 1446 } 1447 1448 if (n_sgin < 1) 1449 return -EBADMSG; 1450 1451 /* Increment to accommodate AAD */ 1452 n_sgin = n_sgin + 1; 1453 1454 nsg = n_sgin + n_sgout; 1455 1456 aead_size = sizeof(*aead_req) + crypto_aead_reqsize(ctx->aead_recv); 1457 mem_size = aead_size + (nsg * sizeof(struct scatterlist)); 1458 mem_size = mem_size + prot->aad_size; 1459 mem_size = mem_size + crypto_aead_ivsize(ctx->aead_recv); 1460 1461 /* Allocate a single block of memory which contains 1462 * aead_req || sgin[] || sgout[] || aad || iv. 1463 * This order achieves correct alignment for aead_req, sgin, sgout. 
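 * (aad and iv are plain byte arrays, so they can safely come last).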
1464 */ 1465 mem = kmalloc(mem_size, sk->sk_allocation); 1466 if (!mem) 1467 return -ENOMEM; 1468 1469 /* Segment the allocated memory */ 1470 aead_req = (struct aead_request *)mem; 1471 sgin = (struct scatterlist *)(mem + aead_size); 1472 sgout = sgin + n_sgin; 1473 aad = (u8 *)(sgout + n_sgout); 1474 iv = aad + prot->aad_size; 1475 1476 /* For CCM based ciphers, first byte of nonce+iv is a constant */ 1477 switch (prot->cipher_type) { 1478 case TLS_CIPHER_AES_CCM_128: 1479 iv[0] = TLS_AES_CCM_IV_B0_BYTE; 1480 iv_offset = 1; 1481 break; 1482 case TLS_CIPHER_SM4_CCM: 1483 iv[0] = TLS_SM4_CCM_IV_B0_BYTE; 1484 iv_offset = 1; 1485 break; 1486 } 1487 1488 /* Prepare IV */ 1489 err = skb_copy_bits(skb, rxm->offset + TLS_HEADER_SIZE, 1490 iv + iv_offset + prot->salt_size, 1491 prot->iv_size); 1492 if (err < 0) { 1493 kfree(mem); 1494 return err; 1495 } 1496 if (prot->version == TLS_1_3_VERSION || 1497 prot->cipher_type == TLS_CIPHER_CHACHA20_POLY1305) 1498 memcpy(iv + iv_offset, tls_ctx->rx.iv, 1499 prot->iv_size + prot->salt_size); 1500 else 1501 memcpy(iv + iv_offset, tls_ctx->rx.iv, prot->salt_size); 1502 1503 xor_iv_with_seq(prot, iv + iv_offset, tls_ctx->rx.rec_seq); 1504 1505 /* Prepare AAD */ 1506 tls_make_aad(aad, rxm->full_len - prot->overhead_size + 1507 prot->tail_size, 1508 tls_ctx->rx.rec_seq, tlm->control, prot); 1509 1510 /* Prepare sgin */ 1511 sg_init_table(sgin, n_sgin); 1512 sg_set_buf(&sgin[0], aad, prot->aad_size); 1513 err = skb_to_sgvec(skb, &sgin[1], 1514 rxm->offset + prot->prepend_size, 1515 rxm->full_len - prot->prepend_size); 1516 if (err < 0) { 1517 kfree(mem); 1518 return err; 1519 } 1520 1521 if (n_sgout) { 1522 if (out_iov) { 1523 sg_init_table(sgout, n_sgout); 1524 sg_set_buf(&sgout[0], aad, prot->aad_size); 1525 1526 *chunk = 0; 1527 err = tls_setup_from_iter(sk, out_iov, data_len, 1528 &pages, chunk, &sgout[1], 1529 (n_sgout - 1)); 1530 if (err < 0) 1531 goto fallback_to_reg_recv; 1532 } else if (out_sg) { 1533 memcpy(sgout, out_sg, n_sgout * sizeof(*sgout)); 1534 } else { 1535 goto fallback_to_reg_recv; 1536 } 1537 } else { 1538 fallback_to_reg_recv: 1539 sgout = sgin; 1540 pages = 0; 1541 *chunk = data_len; 1542 *zc = false; 1543 } 1544 1545 /* Prepare and submit AEAD request */ 1546 err = tls_do_decryption(sk, skb, sgin, sgout, iv, 1547 data_len, aead_req, async); 1548 if (err == -EINPROGRESS) 1549 return err; 1550 1551 /* Release the pages in case iov was mapped to pages */ 1552 for (; pages > 0; pages--) 1553 put_page(sg_page(&sgout[pages])); 1554 1555 kfree(mem); 1556 return err; 1557 } 1558 1559 static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb, 1560 struct iov_iter *dest, int *chunk, bool *zc, 1561 bool async) 1562 { 1563 struct tls_context *tls_ctx = tls_get_ctx(sk); 1564 struct tls_prot_info *prot = &tls_ctx->prot_info; 1565 struct strp_msg *rxm = strp_msg(skb); 1566 struct tls_msg *tlm = tls_msg(skb); 1567 int pad, err; 1568 1569 if (tlm->decrypted) { 1570 *zc = false; 1571 return 0; 1572 } 1573 1574 if (tls_ctx->rx_conf == TLS_HW) { 1575 err = tls_device_decrypted(sk, tls_ctx, skb, rxm); 1576 if (err < 0) 1577 return err; 1578 if (err > 0) { 1579 tlm->decrypted = 1; 1580 *zc = false; 1581 goto decrypt_done; 1582 } 1583 } 1584 1585 err = decrypt_internal(sk, skb, dest, NULL, chunk, zc, async); 1586 if (err < 0) { 1587 if (err == -EINPROGRESS) 1588 tls_advance_record_sn(sk, prot, &tls_ctx->rx); 1589 else if (err == -EBADMSG) 1590 TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSDECRYPTERROR); 1591 return err; 1592 } 1593 1594 decrypt_done: 
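	/* TLS 1.3 inner plaintext layout (RFC 8446):
	 *   content || content_type (1 byte) || zero padding
	 * Strip the padding (if any) and advance past the record header so
	 * rxm describes only the plaintext payload.
	 */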
1595 pad = padding_length(prot, skb); 1596 if (pad < 0) 1597 return pad; 1598 1599 rxm->full_len -= pad; 1600 rxm->offset += prot->prepend_size; 1601 rxm->full_len -= prot->overhead_size; 1602 tls_advance_record_sn(sk, prot, &tls_ctx->rx); 1603 tlm->decrypted = 1; 1604 1605 return 0; 1606 } 1607 1608 int decrypt_skb(struct sock *sk, struct sk_buff *skb, 1609 struct scatterlist *sgout) 1610 { 1611 bool zc = true; 1612 int chunk; 1613 1614 return decrypt_internal(sk, skb, NULL, sgout, &chunk, &zc, false); 1615 } 1616 1617 static bool tls_sw_advance_skb(struct sock *sk, struct sk_buff *skb, 1618 unsigned int len) 1619 { 1620 struct tls_context *tls_ctx = tls_get_ctx(sk); 1621 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); 1622 1623 if (skb) { 1624 struct strp_msg *rxm = strp_msg(skb); 1625 1626 if (len < rxm->full_len) { 1627 rxm->offset += len; 1628 rxm->full_len -= len; 1629 return false; 1630 } 1631 consume_skb(skb); 1632 } 1633 1634 /* Finished with message */ 1635 ctx->recv_pkt = NULL; 1636 __strp_unpause(&ctx->strp); 1637 1638 return true; 1639 } 1640 1641 /* This function traverses the rx_list in tls receive context to copies the 1642 * decrypted records into the buffer provided by caller zero copy is not 1643 * true. Further, the records are removed from the rx_list if it is not a peek 1644 * case and the record has been consumed completely. 1645 */ 1646 static int process_rx_list(struct tls_sw_context_rx *ctx, 1647 struct msghdr *msg, 1648 u8 *control, 1649 bool *cmsg, 1650 size_t skip, 1651 size_t len, 1652 bool zc, 1653 bool is_peek) 1654 { 1655 struct sk_buff *skb = skb_peek(&ctx->rx_list); 1656 u8 ctrl = *control; 1657 u8 msgc = *cmsg; 1658 struct tls_msg *tlm; 1659 ssize_t copied = 0; 1660 1661 /* Set the record type in 'control' if caller didn't pass it */ 1662 if (!ctrl && skb) { 1663 tlm = tls_msg(skb); 1664 ctrl = tlm->control; 1665 } 1666 1667 while (skip && skb) { 1668 struct strp_msg *rxm = strp_msg(skb); 1669 tlm = tls_msg(skb); 1670 1671 /* Cannot process a record of different type */ 1672 if (ctrl != tlm->control) 1673 return 0; 1674 1675 if (skip < rxm->full_len) 1676 break; 1677 1678 skip = skip - rxm->full_len; 1679 skb = skb_peek_next(skb, &ctx->rx_list); 1680 } 1681 1682 while (len && skb) { 1683 struct sk_buff *next_skb; 1684 struct strp_msg *rxm = strp_msg(skb); 1685 int chunk = min_t(unsigned int, rxm->full_len - skip, len); 1686 1687 tlm = tls_msg(skb); 1688 1689 /* Cannot process a record of different type */ 1690 if (ctrl != tlm->control) 1691 return 0; 1692 1693 /* Set record type if not already done. For a non-data record, 1694 * do not proceed if record type could not be copied. 1695 */ 1696 if (!msgc) { 1697 int cerr = put_cmsg(msg, SOL_TLS, TLS_GET_RECORD_TYPE, 1698 sizeof(ctrl), &ctrl); 1699 msgc = true; 1700 if (ctrl != TLS_RECORD_TYPE_DATA) { 1701 if (cerr || msg->msg_flags & MSG_CTRUNC) 1702 return -EIO; 1703 1704 *cmsg = msgc; 1705 } 1706 } 1707 1708 if (!zc || (rxm->full_len - skip) > len) { 1709 int err = skb_copy_datagram_msg(skb, rxm->offset + skip, 1710 msg, chunk); 1711 if (err < 0) 1712 return err; 1713 } 1714 1715 len = len - chunk; 1716 copied = copied + chunk; 1717 1718 /* Consume the data from record if it is non-peek case*/ 1719 if (!is_peek) { 1720 rxm->offset = rxm->offset + chunk; 1721 rxm->full_len = rxm->full_len - chunk; 1722 1723 /* Return if there is unconsumed data in the record */ 1724 if (rxm->full_len - skip) 1725 break; 1726 } 1727 1728 /* The remaining skip-bytes must lie in 1st record in rx_list. 
1729 * So from the 2nd record, 'skip' should be 0. 1730 */ 1731 skip = 0; 1732 1733 if (msg) 1734 msg->msg_flags |= MSG_EOR; 1735 1736 next_skb = skb_peek_next(skb, &ctx->rx_list); 1737 1738 if (!is_peek) { 1739 skb_unlink(skb, &ctx->rx_list); 1740 consume_skb(skb); 1741 } 1742 1743 skb = next_skb; 1744 } 1745 1746 *control = ctrl; 1747 return copied; 1748 } 1749 1750 int tls_sw_recvmsg(struct sock *sk, 1751 struct msghdr *msg, 1752 size_t len, 1753 int nonblock, 1754 int flags, 1755 int *addr_len) 1756 { 1757 struct tls_context *tls_ctx = tls_get_ctx(sk); 1758 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); 1759 struct tls_prot_info *prot = &tls_ctx->prot_info; 1760 struct sk_psock *psock; 1761 int num_async, pending; 1762 unsigned char control = 0; 1763 ssize_t decrypted = 0; 1764 struct strp_msg *rxm; 1765 struct tls_msg *tlm; 1766 struct sk_buff *skb; 1767 ssize_t copied = 0; 1768 bool cmsg = false; 1769 int target, err = 0; 1770 long timeo; 1771 bool is_kvec = iov_iter_is_kvec(&msg->msg_iter); 1772 bool is_peek = flags & MSG_PEEK; 1773 bool bpf_strp_enabled; 1774 1775 flags |= nonblock; 1776 1777 if (unlikely(flags & MSG_ERRQUEUE)) 1778 return sock_recv_errqueue(sk, msg, len, SOL_IP, IP_RECVERR); 1779 1780 psock = sk_psock_get(sk); 1781 lock_sock(sk); 1782 bpf_strp_enabled = sk_psock_strp_enabled(psock); 1783 1784 /* Process pending decrypted records. It must be non-zero-copy */ 1785 err = process_rx_list(ctx, msg, &control, &cmsg, 0, len, false, 1786 is_peek); 1787 if (err < 0) { 1788 tls_err_abort(sk, err); 1789 goto end; 1790 } 1791 1792 copied = err; 1793 if (len <= copied) 1794 goto end; 1795 1796 target = sock_rcvlowat(sk, flags & MSG_WAITALL, len); 1797 len = len - copied; 1798 timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT); 1799 1800 decrypted = 0; 1801 num_async = 0; 1802 while (len && (decrypted + copied < target || ctx->recv_pkt)) { 1803 bool retain_skb = false; 1804 bool zc = false; 1805 int to_decrypt; 1806 int chunk = 0; 1807 bool async_capable; 1808 bool async = false; 1809 1810 skb = tls_wait_data(sk, psock, flags & MSG_DONTWAIT, timeo, &err); 1811 if (!skb) { 1812 if (psock) { 1813 int ret = sk_msg_recvmsg(sk, psock, msg, len, 1814 flags); 1815 1816 if (ret > 0) { 1817 decrypted += ret; 1818 len -= ret; 1819 continue; 1820 } 1821 } 1822 goto recv_end; 1823 } 1824 1825 rxm = strp_msg(skb); 1826 tlm = tls_msg(skb); 1827 1828 to_decrypt = rxm->full_len - prot->overhead_size; 1829 1830 if (to_decrypt <= len && !is_kvec && !is_peek && 1831 tlm->control == TLS_RECORD_TYPE_DATA && 1832 prot->version != TLS_1_3_VERSION && 1833 !bpf_strp_enabled) 1834 zc = true; 1835 1836 /* Do not use async mode if record is non-data */ 1837 if (tlm->control == TLS_RECORD_TYPE_DATA && !bpf_strp_enabled) 1838 async_capable = ctx->async_capable; 1839 else 1840 async_capable = false; 1841 1842 err = decrypt_skb_update(sk, skb, &msg->msg_iter, 1843 &chunk, &zc, async_capable); 1844 if (err < 0 && err != -EINPROGRESS) { 1845 tls_err_abort(sk, -EBADMSG); 1846 goto recv_end; 1847 } 1848 1849 if (err == -EINPROGRESS) { 1850 async = true; 1851 num_async++; 1852 } 1853 1854 /* If the type of records being processed is not known yet, 1855 * set it to record type just dequeued. If it is already known, 1856 * but does not match the record type just dequeued, go to end. 1857 * We always get record type here since for tls1.2, record type 1858 * is known just after record is dequeued from stream parser. 1859 * For tls1.3, we disable async. 
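 * (for TLS 1.3 the real content type is only recovered after decryption).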
1860 */ 1861 1862 if (!control) 1863 control = tlm->control; 1864 else if (control != tlm->control) 1865 goto recv_end; 1866 1867 if (!cmsg) { 1868 int cerr; 1869 1870 cerr = put_cmsg(msg, SOL_TLS, TLS_GET_RECORD_TYPE, 1871 sizeof(control), &control); 1872 cmsg = true; 1873 if (control != TLS_RECORD_TYPE_DATA) { 1874 if (cerr || msg->msg_flags & MSG_CTRUNC) { 1875 err = -EIO; 1876 goto recv_end; 1877 } 1878 } 1879 } 1880 1881 if (async) 1882 goto pick_next_record; 1883 1884 if (!zc) { 1885 if (bpf_strp_enabled) { 1886 err = sk_psock_tls_strp_read(psock, skb); 1887 if (err != __SK_PASS) { 1888 rxm->offset = rxm->offset + rxm->full_len; 1889 rxm->full_len = 0; 1890 if (err == __SK_DROP) 1891 consume_skb(skb); 1892 ctx->recv_pkt = NULL; 1893 __strp_unpause(&ctx->strp); 1894 continue; 1895 } 1896 } 1897 1898 if (rxm->full_len > len) { 1899 retain_skb = true; 1900 chunk = len; 1901 } else { 1902 chunk = rxm->full_len; 1903 } 1904 1905 err = skb_copy_datagram_msg(skb, rxm->offset, 1906 msg, chunk); 1907 if (err < 0) 1908 goto recv_end; 1909 1910 if (!is_peek) { 1911 rxm->offset = rxm->offset + chunk; 1912 rxm->full_len = rxm->full_len - chunk; 1913 } 1914 } 1915 1916 pick_next_record: 1917 if (chunk > len) 1918 chunk = len; 1919 1920 decrypted += chunk; 1921 len -= chunk; 1922 1923 /* For async or peek case, queue the current skb */ 1924 if (async || is_peek || retain_skb) { 1925 skb_queue_tail(&ctx->rx_list, skb); 1926 skb = NULL; 1927 } 1928 1929 if (tls_sw_advance_skb(sk, skb, chunk)) { 1930 /* Return full control message to 1931 * userspace before trying to parse 1932 * another message type 1933 */ 1934 msg->msg_flags |= MSG_EOR; 1935 if (control != TLS_RECORD_TYPE_DATA) 1936 goto recv_end; 1937 } else { 1938 break; 1939 } 1940 } 1941 1942 recv_end: 1943 if (num_async) { 1944 /* Wait for all previously submitted records to be decrypted */ 1945 spin_lock_bh(&ctx->decrypt_compl_lock); 1946 ctx->async_notify = true; 1947 pending = atomic_read(&ctx->decrypt_pending); 1948 spin_unlock_bh(&ctx->decrypt_compl_lock); 1949 if (pending) { 1950 err = crypto_wait_req(-EINPROGRESS, &ctx->async_wait); 1951 if (err) { 1952 /* one of async decrypt failed */ 1953 tls_err_abort(sk, err); 1954 copied = 0; 1955 decrypted = 0; 1956 goto end; 1957 } 1958 } else { 1959 reinit_completion(&ctx->async_wait.completion); 1960 } 1961 1962 /* There can be no concurrent accesses, since we have no 1963 * pending decrypt operations 1964 */ 1965 WRITE_ONCE(ctx->async_notify, false); 1966 1967 /* Drain records from the rx_list & copy if required */ 1968 if (is_peek || is_kvec) 1969 err = process_rx_list(ctx, msg, &control, &cmsg, copied, 1970 decrypted, false, is_peek); 1971 else 1972 err = process_rx_list(ctx, msg, &control, &cmsg, 0, 1973 decrypted, true, is_peek); 1974 if (err < 0) { 1975 tls_err_abort(sk, err); 1976 copied = 0; 1977 goto end; 1978 } 1979 } 1980 1981 copied += decrypted; 1982 1983 end: 1984 release_sock(sk); 1985 sk_defer_free_flush(sk); 1986 if (psock) 1987 sk_psock_put(sk, psock); 1988 return copied ? 
: err; 1989 } 1990 1991 ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos, 1992 struct pipe_inode_info *pipe, 1993 size_t len, unsigned int flags) 1994 { 1995 struct tls_context *tls_ctx = tls_get_ctx(sock->sk); 1996 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); 1997 struct strp_msg *rxm = NULL; 1998 struct sock *sk = sock->sk; 1999 struct tls_msg *tlm; 2000 struct sk_buff *skb; 2001 ssize_t copied = 0; 2002 bool from_queue; 2003 int err = 0; 2004 long timeo; 2005 int chunk; 2006 bool zc = false; 2007 2008 lock_sock(sk); 2009 2010 timeo = sock_rcvtimeo(sk, flags & SPLICE_F_NONBLOCK); 2011 2012 from_queue = !skb_queue_empty(&ctx->rx_list); 2013 if (from_queue) { 2014 skb = __skb_dequeue(&ctx->rx_list); 2015 } else { 2016 skb = tls_wait_data(sk, NULL, flags & SPLICE_F_NONBLOCK, timeo, 2017 &err); 2018 if (!skb) 2019 goto splice_read_end; 2020 2021 err = decrypt_skb_update(sk, skb, NULL, &chunk, &zc, false); 2022 if (err < 0) { 2023 tls_err_abort(sk, -EBADMSG); 2024 goto splice_read_end; 2025 } 2026 } 2027 2028 rxm = strp_msg(skb); 2029 tlm = tls_msg(skb); 2030 2031 /* splice does not support reading control messages */ 2032 if (tlm->control != TLS_RECORD_TYPE_DATA) { 2033 err = -EINVAL; 2034 goto splice_read_end; 2035 } 2036 2037 chunk = min_t(unsigned int, rxm->full_len, len); 2038 copied = skb_splice_bits(skb, sk, rxm->offset, pipe, chunk, flags); 2039 if (copied < 0) 2040 goto splice_read_end; 2041 2042 if (!from_queue) { 2043 ctx->recv_pkt = NULL; 2044 __strp_unpause(&ctx->strp); 2045 } 2046 if (chunk < rxm->full_len) { 2047 __skb_queue_head(&ctx->rx_list, skb); 2048 rxm->offset += len; 2049 rxm->full_len -= len; 2050 } else { 2051 consume_skb(skb); 2052 } 2053 2054 splice_read_end: 2055 release_sock(sk); 2056 sk_defer_free_flush(sk); 2057 return copied ? : err; 2058 } 2059 2060 bool tls_sw_sock_is_readable(struct sock *sk) 2061 { 2062 struct tls_context *tls_ctx = tls_get_ctx(sk); 2063 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); 2064 bool ingress_empty = true; 2065 struct sk_psock *psock; 2066 2067 rcu_read_lock(); 2068 psock = sk_psock(sk); 2069 if (psock) 2070 ingress_empty = list_empty(&psock->ingress_msg); 2071 rcu_read_unlock(); 2072 2073 return !ingress_empty || ctx->recv_pkt || 2074 !skb_queue_empty(&ctx->rx_list); 2075 } 2076 2077 static int tls_read_size(struct strparser *strp, struct sk_buff *skb) 2078 { 2079 struct tls_context *tls_ctx = tls_get_ctx(strp->sk); 2080 struct tls_prot_info *prot = &tls_ctx->prot_info; 2081 char header[TLS_HEADER_SIZE + MAX_IV_SIZE]; 2082 struct strp_msg *rxm = strp_msg(skb); 2083 struct tls_msg *tlm = tls_msg(skb); 2084 size_t cipher_overhead; 2085 size_t data_len = 0; 2086 int ret; 2087 2088 /* Verify that we have a full TLS header, or wait for more data */ 2089 if (rxm->offset + prot->prepend_size > skb->len) 2090 return 0; 2091 2092 /* Sanity-check size of on-stack buffer. 
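 * header[] is TLS_HEADER_SIZE + MAX_IV_SIZE bytes and must cover
 * prot->prepend_size.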
*/ 2093 if (WARN_ON(prot->prepend_size > sizeof(header))) { 2094 ret = -EINVAL; 2095 goto read_failure; 2096 } 2097 2098 /* Linearize header to local buffer */ 2099 ret = skb_copy_bits(skb, rxm->offset, header, prot->prepend_size); 2100 if (ret < 0) 2101 goto read_failure; 2102 2103 tlm->decrypted = 0; 2104 tlm->control = header[0]; 2105 2106 data_len = ((header[4] & 0xFF) | (header[3] << 8)); 2107 2108 cipher_overhead = prot->tag_size; 2109 if (prot->version != TLS_1_3_VERSION && 2110 prot->cipher_type != TLS_CIPHER_CHACHA20_POLY1305) 2111 cipher_overhead += prot->iv_size; 2112 2113 if (data_len > TLS_MAX_PAYLOAD_SIZE + cipher_overhead + 2114 prot->tail_size) { 2115 ret = -EMSGSIZE; 2116 goto read_failure; 2117 } 2118 if (data_len < cipher_overhead) { 2119 ret = -EBADMSG; 2120 goto read_failure; 2121 } 2122 2123 /* Note that both TLS1.3 and TLS1.2 use TLS_1_2 version here */ 2124 if (header[1] != TLS_1_2_VERSION_MINOR || 2125 header[2] != TLS_1_2_VERSION_MAJOR) { 2126 ret = -EINVAL; 2127 goto read_failure; 2128 } 2129 2130 tls_device_rx_resync_new_rec(strp->sk, data_len + TLS_HEADER_SIZE, 2131 TCP_SKB_CB(skb)->seq + rxm->offset); 2132 return data_len + TLS_HEADER_SIZE; 2133 2134 read_failure: 2135 tls_err_abort(strp->sk, ret); 2136 2137 return ret; 2138 } 2139 2140 static void tls_queue(struct strparser *strp, struct sk_buff *skb) 2141 { 2142 struct tls_context *tls_ctx = tls_get_ctx(strp->sk); 2143 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); 2144 2145 ctx->recv_pkt = skb; 2146 strp_pause(strp); 2147 2148 ctx->saved_data_ready(strp->sk); 2149 } 2150 2151 static void tls_data_ready(struct sock *sk) 2152 { 2153 struct tls_context *tls_ctx = tls_get_ctx(sk); 2154 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); 2155 struct sk_psock *psock; 2156 2157 strp_data_ready(&ctx->strp); 2158 2159 psock = sk_psock_get(sk); 2160 if (psock) { 2161 if (!list_empty(&psock->ingress_msg)) 2162 ctx->saved_data_ready(sk); 2163 sk_psock_put(sk, psock); 2164 } 2165 } 2166 2167 void tls_sw_cancel_work_tx(struct tls_context *tls_ctx) 2168 { 2169 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); 2170 2171 set_bit(BIT_TX_CLOSING, &ctx->tx_bitmask); 2172 set_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask); 2173 cancel_delayed_work_sync(&ctx->tx_work.work); 2174 } 2175 2176 void tls_sw_release_resources_tx(struct sock *sk) 2177 { 2178 struct tls_context *tls_ctx = tls_get_ctx(sk); 2179 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); 2180 struct tls_rec *rec, *tmp; 2181 int pending; 2182 2183 /* Wait for any pending async encryptions to complete */ 2184 spin_lock_bh(&ctx->encrypt_compl_lock); 2185 ctx->async_notify = true; 2186 pending = atomic_read(&ctx->encrypt_pending); 2187 spin_unlock_bh(&ctx->encrypt_compl_lock); 2188 2189 if (pending) 2190 crypto_wait_req(-EINPROGRESS, &ctx->async_wait); 2191 2192 tls_tx_records(sk, -1); 2193 2194 /* Free up un-sent records in tx_list. First, free 2195 * the partially sent record if any at head of tx_list. 
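 * That head record's ciphertext pages were already handed over as the
 * partially sent record, so only its msg_plaintext is released below;
 * the remaining records still own both their plaintext and ciphertext.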
2196 */ 2197 if (tls_ctx->partially_sent_record) { 2198 tls_free_partial_record(sk, tls_ctx); 2199 rec = list_first_entry(&ctx->tx_list, 2200 struct tls_rec, list); 2201 list_del(&rec->list); 2202 sk_msg_free(sk, &rec->msg_plaintext); 2203 kfree(rec); 2204 } 2205 2206 list_for_each_entry_safe(rec, tmp, &ctx->tx_list, list) { 2207 list_del(&rec->list); 2208 sk_msg_free(sk, &rec->msg_encrypted); 2209 sk_msg_free(sk, &rec->msg_plaintext); 2210 kfree(rec); 2211 } 2212 2213 crypto_free_aead(ctx->aead_send); 2214 tls_free_open_rec(sk); 2215 } 2216 2217 void tls_sw_free_ctx_tx(struct tls_context *tls_ctx) 2218 { 2219 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); 2220 2221 kfree(ctx); 2222 } 2223 2224 void tls_sw_release_resources_rx(struct sock *sk) 2225 { 2226 struct tls_context *tls_ctx = tls_get_ctx(sk); 2227 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); 2228 2229 kfree(tls_ctx->rx.rec_seq); 2230 kfree(tls_ctx->rx.iv); 2231 2232 if (ctx->aead_recv) { 2233 kfree_skb(ctx->recv_pkt); 2234 ctx->recv_pkt = NULL; 2235 skb_queue_purge(&ctx->rx_list); 2236 crypto_free_aead(ctx->aead_recv); 2237 strp_stop(&ctx->strp); 2238 /* If tls_sw_strparser_arm() was not called (cleanup paths) 2239 * we still want to strp_stop(), but sk->sk_data_ready was 2240 * never swapped. 2241 */ 2242 if (ctx->saved_data_ready) { 2243 write_lock_bh(&sk->sk_callback_lock); 2244 sk->sk_data_ready = ctx->saved_data_ready; 2245 write_unlock_bh(&sk->sk_callback_lock); 2246 } 2247 } 2248 } 2249 2250 void tls_sw_strparser_done(struct tls_context *tls_ctx) 2251 { 2252 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); 2253 2254 strp_done(&ctx->strp); 2255 } 2256 2257 void tls_sw_free_ctx_rx(struct tls_context *tls_ctx) 2258 { 2259 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); 2260 2261 kfree(ctx); 2262 } 2263 2264 void tls_sw_free_resources_rx(struct sock *sk) 2265 { 2266 struct tls_context *tls_ctx = tls_get_ctx(sk); 2267 2268 tls_sw_release_resources_rx(sk); 2269 tls_sw_free_ctx_rx(tls_ctx); 2270 } 2271 2272 /* The work handler to transmitt the encrypted records in tx_list */ 2273 static void tx_work_handler(struct work_struct *work) 2274 { 2275 struct delayed_work *delayed_work = to_delayed_work(work); 2276 struct tx_work *tx_work = container_of(delayed_work, 2277 struct tx_work, work); 2278 struct sock *sk = tx_work->sk; 2279 struct tls_context *tls_ctx = tls_get_ctx(sk); 2280 struct tls_sw_context_tx *ctx; 2281 2282 if (unlikely(!tls_ctx)) 2283 return; 2284 2285 ctx = tls_sw_ctx_tx(tls_ctx); 2286 if (test_bit(BIT_TX_CLOSING, &ctx->tx_bitmask)) 2287 return; 2288 2289 if (!test_and_clear_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask)) 2290 return; 2291 mutex_lock(&tls_ctx->tx_lock); 2292 lock_sock(sk); 2293 tls_tx_records(sk, -1); 2294 release_sock(sk); 2295 mutex_unlock(&tls_ctx->tx_lock); 2296 } 2297 2298 void tls_sw_write_space(struct sock *sk, struct tls_context *ctx) 2299 { 2300 struct tls_sw_context_tx *tx_ctx = tls_sw_ctx_tx(ctx); 2301 2302 /* Schedule the transmission if tx list is ready */ 2303 if (is_tx_ready(tx_ctx) && 2304 !test_and_set_bit(BIT_TX_SCHEDULED, &tx_ctx->tx_bitmask)) 2305 schedule_delayed_work(&tx_ctx->tx_work.work, 0); 2306 } 2307 2308 void tls_sw_strparser_arm(struct sock *sk, struct tls_context *tls_ctx) 2309 { 2310 struct tls_sw_context_rx *rx_ctx = tls_sw_ctx_rx(tls_ctx); 2311 2312 write_lock_bh(&sk->sk_callback_lock); 2313 rx_ctx->saved_data_ready = sk->sk_data_ready; 2314 sk->sk_data_ready = tls_data_ready; 2315 write_unlock_bh(&sk->sk_callback_lock); 2316 2317 
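	/* Data may already be queued on the socket from before
	 * sk_data_ready was redirected above; kick the strparser once so
	 * any such early-arriving records get parsed.
	 */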
strp_check_rcv(&rx_ctx->strp); 2318 } 2319 2320 int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx) 2321 { 2322 struct tls_context *tls_ctx = tls_get_ctx(sk); 2323 struct tls_prot_info *prot = &tls_ctx->prot_info; 2324 struct tls_crypto_info *crypto_info; 2325 struct tls_sw_context_tx *sw_ctx_tx = NULL; 2326 struct tls_sw_context_rx *sw_ctx_rx = NULL; 2327 struct cipher_context *cctx; 2328 struct crypto_aead **aead; 2329 struct strp_callbacks cb; 2330 u16 nonce_size, tag_size, iv_size, rec_seq_size, salt_size; 2331 struct crypto_tfm *tfm; 2332 char *iv, *rec_seq, *key, *salt, *cipher_name; 2333 size_t keysize; 2334 int rc = 0; 2335 2336 if (!ctx) { 2337 rc = -EINVAL; 2338 goto out; 2339 } 2340 2341 if (tx) { 2342 if (!ctx->priv_ctx_tx) { 2343 sw_ctx_tx = kzalloc(sizeof(*sw_ctx_tx), GFP_KERNEL); 2344 if (!sw_ctx_tx) { 2345 rc = -ENOMEM; 2346 goto out; 2347 } 2348 ctx->priv_ctx_tx = sw_ctx_tx; 2349 } else { 2350 sw_ctx_tx = 2351 (struct tls_sw_context_tx *)ctx->priv_ctx_tx; 2352 } 2353 } else { 2354 if (!ctx->priv_ctx_rx) { 2355 sw_ctx_rx = kzalloc(sizeof(*sw_ctx_rx), GFP_KERNEL); 2356 if (!sw_ctx_rx) { 2357 rc = -ENOMEM; 2358 goto out; 2359 } 2360 ctx->priv_ctx_rx = sw_ctx_rx; 2361 } else { 2362 sw_ctx_rx = 2363 (struct tls_sw_context_rx *)ctx->priv_ctx_rx; 2364 } 2365 } 2366 2367 if (tx) { 2368 crypto_init_wait(&sw_ctx_tx->async_wait); 2369 spin_lock_init(&sw_ctx_tx->encrypt_compl_lock); 2370 crypto_info = &ctx->crypto_send.info; 2371 cctx = &ctx->tx; 2372 aead = &sw_ctx_tx->aead_send; 2373 INIT_LIST_HEAD(&sw_ctx_tx->tx_list); 2374 INIT_DELAYED_WORK(&sw_ctx_tx->tx_work.work, tx_work_handler); 2375 sw_ctx_tx->tx_work.sk = sk; 2376 } else { 2377 crypto_init_wait(&sw_ctx_rx->async_wait); 2378 spin_lock_init(&sw_ctx_rx->decrypt_compl_lock); 2379 crypto_info = &ctx->crypto_recv.info; 2380 cctx = &ctx->rx; 2381 skb_queue_head_init(&sw_ctx_rx->rx_list); 2382 aead = &sw_ctx_rx->aead_recv; 2383 } 2384 2385 switch (crypto_info->cipher_type) { 2386 case TLS_CIPHER_AES_GCM_128: { 2387 struct tls12_crypto_info_aes_gcm_128 *gcm_128_info; 2388 2389 gcm_128_info = (void *)crypto_info; 2390 nonce_size = TLS_CIPHER_AES_GCM_128_IV_SIZE; 2391 tag_size = TLS_CIPHER_AES_GCM_128_TAG_SIZE; 2392 iv_size = TLS_CIPHER_AES_GCM_128_IV_SIZE; 2393 iv = gcm_128_info->iv; 2394 rec_seq_size = TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE; 2395 rec_seq = gcm_128_info->rec_seq; 2396 keysize = TLS_CIPHER_AES_GCM_128_KEY_SIZE; 2397 key = gcm_128_info->key; 2398 salt = gcm_128_info->salt; 2399 salt_size = TLS_CIPHER_AES_GCM_128_SALT_SIZE; 2400 cipher_name = "gcm(aes)"; 2401 break; 2402 } 2403 case TLS_CIPHER_AES_GCM_256: { 2404 struct tls12_crypto_info_aes_gcm_256 *gcm_256_info; 2405 2406 gcm_256_info = (void *)crypto_info; 2407 nonce_size = TLS_CIPHER_AES_GCM_256_IV_SIZE; 2408 tag_size = TLS_CIPHER_AES_GCM_256_TAG_SIZE; 2409 iv_size = TLS_CIPHER_AES_GCM_256_IV_SIZE; 2410 iv = gcm_256_info->iv; 2411 rec_seq_size = TLS_CIPHER_AES_GCM_256_REC_SEQ_SIZE; 2412 rec_seq = gcm_256_info->rec_seq; 2413 keysize = TLS_CIPHER_AES_GCM_256_KEY_SIZE; 2414 key = gcm_256_info->key; 2415 salt = gcm_256_info->salt; 2416 salt_size = TLS_CIPHER_AES_GCM_256_SALT_SIZE; 2417 cipher_name = "gcm(aes)"; 2418 break; 2419 } 2420 case TLS_CIPHER_AES_CCM_128: { 2421 struct tls12_crypto_info_aes_ccm_128 *ccm_128_info; 2422 2423 ccm_128_info = (void *)crypto_info; 2424 nonce_size = TLS_CIPHER_AES_CCM_128_IV_SIZE; 2425 tag_size = TLS_CIPHER_AES_CCM_128_TAG_SIZE; 2426 iv_size = TLS_CIPHER_AES_CCM_128_IV_SIZE; 2427 iv = ccm_128_info->iv; 2428 
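		/* AES-CCM-128 uses the same nonce layout as AES-GCM-128:
		 * a 4-byte implicit salt plus an 8-byte per-record IV.
		 */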
rec_seq_size = TLS_CIPHER_AES_CCM_128_REC_SEQ_SIZE; 2429 rec_seq = ccm_128_info->rec_seq; 2430 keysize = TLS_CIPHER_AES_CCM_128_KEY_SIZE; 2431 key = ccm_128_info->key; 2432 salt = ccm_128_info->salt; 2433 salt_size = TLS_CIPHER_AES_CCM_128_SALT_SIZE; 2434 cipher_name = "ccm(aes)"; 2435 break; 2436 } 2437 case TLS_CIPHER_CHACHA20_POLY1305: { 2438 struct tls12_crypto_info_chacha20_poly1305 *chacha20_poly1305_info; 2439 2440 chacha20_poly1305_info = (void *)crypto_info; 2441 nonce_size = 0; 2442 tag_size = TLS_CIPHER_CHACHA20_POLY1305_TAG_SIZE; 2443 iv_size = TLS_CIPHER_CHACHA20_POLY1305_IV_SIZE; 2444 iv = chacha20_poly1305_info->iv; 2445 rec_seq_size = TLS_CIPHER_CHACHA20_POLY1305_REC_SEQ_SIZE; 2446 rec_seq = chacha20_poly1305_info->rec_seq; 2447 keysize = TLS_CIPHER_CHACHA20_POLY1305_KEY_SIZE; 2448 key = chacha20_poly1305_info->key; 2449 salt = chacha20_poly1305_info->salt; 2450 salt_size = TLS_CIPHER_CHACHA20_POLY1305_SALT_SIZE; 2451 cipher_name = "rfc7539(chacha20,poly1305)"; 2452 break; 2453 } 2454 case TLS_CIPHER_SM4_GCM: { 2455 struct tls12_crypto_info_sm4_gcm *sm4_gcm_info; 2456 2457 sm4_gcm_info = (void *)crypto_info; 2458 nonce_size = TLS_CIPHER_SM4_GCM_IV_SIZE; 2459 tag_size = TLS_CIPHER_SM4_GCM_TAG_SIZE; 2460 iv_size = TLS_CIPHER_SM4_GCM_IV_SIZE; 2461 iv = sm4_gcm_info->iv; 2462 rec_seq_size = TLS_CIPHER_SM4_GCM_REC_SEQ_SIZE; 2463 rec_seq = sm4_gcm_info->rec_seq; 2464 keysize = TLS_CIPHER_SM4_GCM_KEY_SIZE; 2465 key = sm4_gcm_info->key; 2466 salt = sm4_gcm_info->salt; 2467 salt_size = TLS_CIPHER_SM4_GCM_SALT_SIZE; 2468 cipher_name = "gcm(sm4)"; 2469 break; 2470 } 2471 case TLS_CIPHER_SM4_CCM: { 2472 struct tls12_crypto_info_sm4_ccm *sm4_ccm_info; 2473 2474 sm4_ccm_info = (void *)crypto_info; 2475 nonce_size = TLS_CIPHER_SM4_CCM_IV_SIZE; 2476 tag_size = TLS_CIPHER_SM4_CCM_TAG_SIZE; 2477 iv_size = TLS_CIPHER_SM4_CCM_IV_SIZE; 2478 iv = sm4_ccm_info->iv; 2479 rec_seq_size = TLS_CIPHER_SM4_CCM_REC_SEQ_SIZE; 2480 rec_seq = sm4_ccm_info->rec_seq; 2481 keysize = TLS_CIPHER_SM4_CCM_KEY_SIZE; 2482 key = sm4_ccm_info->key; 2483 salt = sm4_ccm_info->salt; 2484 salt_size = TLS_CIPHER_SM4_CCM_SALT_SIZE; 2485 cipher_name = "ccm(sm4)"; 2486 break; 2487 } 2488 default: 2489 rc = -EINVAL; 2490 goto free_priv; 2491 } 2492 2493 /* Sanity-check the sizes for stack allocations. 
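 * These lengths come from userspace and are later used to fill
 * fixed-size buffers (e.g. header[] in tls_read_size()), so reject
 * anything out of range before copying any key material.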
*/ 2494 if (iv_size > MAX_IV_SIZE || nonce_size > MAX_IV_SIZE || 2495 rec_seq_size > TLS_MAX_REC_SEQ_SIZE || tag_size != TLS_TAG_SIZE) { 2496 rc = -EINVAL; 2497 goto free_priv; 2498 } 2499 2500 if (crypto_info->version == TLS_1_3_VERSION) { 2501 nonce_size = 0; 2502 prot->aad_size = TLS_HEADER_SIZE; 2503 prot->tail_size = 1; 2504 } else { 2505 prot->aad_size = TLS_AAD_SPACE_SIZE; 2506 prot->tail_size = 0; 2507 } 2508 2509 prot->version = crypto_info->version; 2510 prot->cipher_type = crypto_info->cipher_type; 2511 prot->prepend_size = TLS_HEADER_SIZE + nonce_size; 2512 prot->tag_size = tag_size; 2513 prot->overhead_size = prot->prepend_size + 2514 prot->tag_size + prot->tail_size; 2515 prot->iv_size = iv_size; 2516 prot->salt_size = salt_size; 2517 cctx->iv = kmalloc(iv_size + salt_size, GFP_KERNEL); 2518 if (!cctx->iv) { 2519 rc = -ENOMEM; 2520 goto free_priv; 2521 } 2522 /* Note: 128 & 256 bit salt are the same size */ 2523 prot->rec_seq_size = rec_seq_size; 2524 memcpy(cctx->iv, salt, salt_size); 2525 memcpy(cctx->iv + salt_size, iv, iv_size); 2526 cctx->rec_seq = kmemdup(rec_seq, rec_seq_size, GFP_KERNEL); 2527 if (!cctx->rec_seq) { 2528 rc = -ENOMEM; 2529 goto free_iv; 2530 } 2531 2532 if (!*aead) { 2533 *aead = crypto_alloc_aead(cipher_name, 0, 0); 2534 if (IS_ERR(*aead)) { 2535 rc = PTR_ERR(*aead); 2536 *aead = NULL; 2537 goto free_rec_seq; 2538 } 2539 } 2540 2541 ctx->push_pending_record = tls_sw_push_pending_record; 2542 2543 rc = crypto_aead_setkey(*aead, key, keysize); 2544 2545 if (rc) 2546 goto free_aead; 2547 2548 rc = crypto_aead_setauthsize(*aead, prot->tag_size); 2549 if (rc) 2550 goto free_aead; 2551 2552 if (sw_ctx_rx) { 2553 tfm = crypto_aead_tfm(sw_ctx_rx->aead_recv); 2554 2555 if (crypto_info->version == TLS_1_3_VERSION) 2556 sw_ctx_rx->async_capable = 0; 2557 else 2558 sw_ctx_rx->async_capable = 2559 !!(tfm->__crt_alg->cra_flags & 2560 CRYPTO_ALG_ASYNC); 2561 2562 /* Set up strparser */ 2563 memset(&cb, 0, sizeof(cb)); 2564 cb.rcv_msg = tls_queue; 2565 cb.parse_msg = tls_read_size; 2566 2567 strp_init(&sw_ctx_rx->strp, sk, &cb); 2568 } 2569 2570 goto out; 2571 2572 free_aead: 2573 crypto_free_aead(*aead); 2574 *aead = NULL; 2575 free_rec_seq: 2576 kfree(cctx->rec_seq); 2577 cctx->rec_seq = NULL; 2578 free_iv: 2579 kfree(cctx->iv); 2580 cctx->iv = NULL; 2581 free_priv: 2582 if (tx) { 2583 kfree(ctx->priv_ctx_tx); 2584 ctx->priv_ctx_tx = NULL; 2585 } else { 2586 kfree(ctx->priv_ctx_rx); 2587 ctx->priv_ctx_rx = NULL; 2588 } 2589 out: 2590 return rc; 2591 } 2592
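/* Usage sketch (illustrative only, compiled out): the software TX path set
 * up by tls_set_sw_offload() above is reached from userspace by attaching
 * the "tls" ULP and programming per-direction keys via setsockopt(SOL_TLS),
 * as documented in Documentation/networking/tls.rst. The RX direction is
 * programmed the same way with TLS_RX, after which tls_sw_recvmsg() above
 * reports non-data record types through a TLS_GET_RECORD_TYPE cmsg. Key
 * material and error handling are omitted; this block is not kernel code.
 */
#if 0
#include <string.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <linux/tls.h>

#ifndef TCP_ULP
#define TCP_ULP 31		/* from include/uapi/linux/tcp.h */
#endif
#ifndef SOL_TLS
#define SOL_TLS 282		/* from include/linux/socket.h */
#endif

static int enable_ktls_tx(int sock,
			  const unsigned char key[TLS_CIPHER_AES_GCM_128_KEY_SIZE],
			  const unsigned char iv[TLS_CIPHER_AES_GCM_128_IV_SIZE],
			  const unsigned char salt[TLS_CIPHER_AES_GCM_128_SALT_SIZE],
			  const unsigned char rec_seq[TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE])
{
	struct tls12_crypto_info_aes_gcm_128 crypto_info = {
		.info.version = TLS_1_2_VERSION,
		.info.cipher_type = TLS_CIPHER_AES_GCM_128,
	};

	memcpy(crypto_info.key, key, TLS_CIPHER_AES_GCM_128_KEY_SIZE);
	memcpy(crypto_info.iv, iv, TLS_CIPHER_AES_GCM_128_IV_SIZE);
	memcpy(crypto_info.salt, salt, TLS_CIPHER_AES_GCM_128_SALT_SIZE);
	memcpy(crypto_info.rec_seq, rec_seq, TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE);

	/* Attach the TLS upper layer protocol to an established TCP socket */
	if (setsockopt(sock, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls")))
		return -1;

	/* Hand the send-direction key to the kernel; subsequent sendmsg()
	 * data is encrypted by the software path in this file.
	 */
	return setsockopt(sock, SOL_TLS, TLS_TX, &crypto_info,
			  sizeof(crypto_info));
}
#endif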