1 /* 2 * Copyright (c) 2016-2017, Mellanox Technologies. All rights reserved. 3 * Copyright (c) 2016-2017, Dave Watson <davejwatson@fb.com>. All rights reserved. 4 * 5 * This software is available to you under a choice of one of two 6 * licenses. You may choose to be licensed under the terms of the GNU 7 * General Public License (GPL) Version 2, available from the file 8 * COPYING in the main directory of this source tree, or the 9 * OpenIB.org BSD license below: 10 * 11 * Redistribution and use in source and binary forms, with or 12 * without modification, are permitted provided that the following 13 * conditions are met: 14 * 15 * - Redistributions of source code must retain the above 16 * copyright notice, this list of conditions and the following 17 * disclaimer. 18 * 19 * - Redistributions in binary form must reproduce the above 20 * copyright notice, this list of conditions and the following 21 * disclaimer in the documentation and/or other materials 22 * provided with the distribution. 23 * 24 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 25 * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 26 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 27 * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS 28 * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN 29 * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 30 * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 31 * SOFTWARE. 32 */ 33 34 #include <linux/module.h> 35 36 #include <net/tcp.h> 37 #include <net/inet_common.h> 38 #include <linux/highmem.h> 39 #include <linux/netdevice.h> 40 #include <linux/sched/signal.h> 41 #include <linux/inetdevice.h> 42 #include <linux/inet_diag.h> 43 44 #include <net/snmp.h> 45 #include <net/tls.h> 46 #include <net/tls_toe.h> 47 48 #include "tls.h" 49 50 MODULE_AUTHOR("Mellanox Technologies"); 51 MODULE_DESCRIPTION("Transport Layer Security Support"); 52 MODULE_LICENSE("Dual BSD/GPL"); 53 MODULE_ALIAS_TCP_ULP("tls"); 54 55 enum { 56 TLSV4, 57 TLSV6, 58 TLS_NUM_PROTS, 59 }; 60 61 #define CHECK_CIPHER_DESC(cipher,ci) \ 62 static_assert(cipher ## _IV_SIZE <= TLS_MAX_IV_SIZE); \ 63 static_assert(cipher ## _SALT_SIZE <= TLS_MAX_SALT_SIZE); \ 64 static_assert(cipher ## _REC_SEQ_SIZE <= TLS_MAX_REC_SEQ_SIZE); \ 65 static_assert(cipher ## _TAG_SIZE == TLS_TAG_SIZE); \ 66 static_assert(sizeof_field(struct ci, iv) == cipher ## _IV_SIZE); \ 67 static_assert(sizeof_field(struct ci, key) == cipher ## _KEY_SIZE); \ 68 static_assert(sizeof_field(struct ci, salt) == cipher ## _SALT_SIZE); \ 69 static_assert(sizeof_field(struct ci, rec_seq) == cipher ## _REC_SEQ_SIZE); 70 71 #define __CIPHER_DESC(ci) \ 72 .iv_offset = offsetof(struct ci, iv), \ 73 .key_offset = offsetof(struct ci, key), \ 74 .salt_offset = offsetof(struct ci, salt), \ 75 .rec_seq_offset = offsetof(struct ci, rec_seq), \ 76 .crypto_info = sizeof(struct ci) 77 78 #define CIPHER_DESC(cipher,ci,algname,_offloadable) [cipher - TLS_CIPHER_MIN] = { \ 79 .nonce = cipher ## _IV_SIZE, \ 80 .iv = cipher ## _IV_SIZE, \ 81 .key = cipher ## _KEY_SIZE, \ 82 .salt = cipher ## _SALT_SIZE, \ 83 .tag = cipher ## _TAG_SIZE, \ 84 .rec_seq = cipher ## _REC_SEQ_SIZE, \ 85 .cipher_name = algname, \ 86 .offloadable = _offloadable, \ 87 __CIPHER_DESC(ci), \ 88 } 89 90 #define CIPHER_DESC_NONCE0(cipher,ci,algname,_offloadable) [cipher - TLS_CIPHER_MIN] = { \ 91 .nonce = 0, \ 92 .iv = cipher ## _IV_SIZE, \ 93 .key = cipher ## _KEY_SIZE, \ 94 .salt = cipher ## _SALT_SIZE, \ 95 .tag = cipher ## _TAG_SIZE, \ 96 .rec_seq = cipher ## _REC_SEQ_SIZE, \ 97 .cipher_name = algname, \ 98 .offloadable = _offloadable, \ 99 __CIPHER_DESC(ci), \ 100 } 101 102 const struct tls_cipher_desc tls_cipher_desc[TLS_CIPHER_MAX + 1 - TLS_CIPHER_MIN] = { 103 CIPHER_DESC(TLS_CIPHER_AES_GCM_128, tls12_crypto_info_aes_gcm_128, "gcm(aes)", true), 104 CIPHER_DESC(TLS_CIPHER_AES_GCM_256, tls12_crypto_info_aes_gcm_256, "gcm(aes)", true), 105 CIPHER_DESC(TLS_CIPHER_AES_CCM_128, tls12_crypto_info_aes_ccm_128, "ccm(aes)", false), 106 CIPHER_DESC_NONCE0(TLS_CIPHER_CHACHA20_POLY1305, tls12_crypto_info_chacha20_poly1305, "rfc7539(chacha20,poly1305)", false), 107 CIPHER_DESC(TLS_CIPHER_SM4_GCM, tls12_crypto_info_sm4_gcm, "gcm(sm4)", false), 108 CIPHER_DESC(TLS_CIPHER_SM4_CCM, tls12_crypto_info_sm4_ccm, "ccm(sm4)", false), 109 CIPHER_DESC(TLS_CIPHER_ARIA_GCM_128, tls12_crypto_info_aria_gcm_128, "gcm(aria)", false), 110 CIPHER_DESC(TLS_CIPHER_ARIA_GCM_256, tls12_crypto_info_aria_gcm_256, "gcm(aria)", false), 111 }; 112 113 CHECK_CIPHER_DESC(TLS_CIPHER_AES_GCM_128, tls12_crypto_info_aes_gcm_128); 114 CHECK_CIPHER_DESC(TLS_CIPHER_AES_GCM_256, tls12_crypto_info_aes_gcm_256); 115 CHECK_CIPHER_DESC(TLS_CIPHER_AES_CCM_128, tls12_crypto_info_aes_ccm_128); 116 CHECK_CIPHER_DESC(TLS_CIPHER_CHACHA20_POLY1305, tls12_crypto_info_chacha20_poly1305); 117 CHECK_CIPHER_DESC(TLS_CIPHER_SM4_GCM, tls12_crypto_info_sm4_gcm); 118 CHECK_CIPHER_DESC(TLS_CIPHER_SM4_CCM, tls12_crypto_info_sm4_ccm); 119 CHECK_CIPHER_DESC(TLS_CIPHER_ARIA_GCM_128, tls12_crypto_info_aria_gcm_128); 120 CHECK_CIPHER_DESC(TLS_CIPHER_ARIA_GCM_256, tls12_crypto_info_aria_gcm_256); 121 122 static const struct proto *saved_tcpv6_prot; 123 static DEFINE_MUTEX(tcpv6_prot_mutex); 124 static const struct proto *saved_tcpv4_prot; 125 static DEFINE_MUTEX(tcpv4_prot_mutex); 126 static struct proto tls_prots[TLS_NUM_PROTS][TLS_NUM_CONFIG][TLS_NUM_CONFIG]; 127 static struct proto_ops tls_proto_ops[TLS_NUM_PROTS][TLS_NUM_CONFIG][TLS_NUM_CONFIG]; 128 static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG], 129 const struct proto *base); 130 131 void update_sk_prot(struct sock *sk, struct tls_context *ctx) 132 { 133 int ip_ver = sk->sk_family == AF_INET6 ? TLSV6 : TLSV4; 134 135 WRITE_ONCE(sk->sk_prot, 136 &tls_prots[ip_ver][ctx->tx_conf][ctx->rx_conf]); 137 WRITE_ONCE(sk->sk_socket->ops, 138 &tls_proto_ops[ip_ver][ctx->tx_conf][ctx->rx_conf]); 139 } 140 141 int wait_on_pending_writer(struct sock *sk, long *timeo) 142 { 143 DEFINE_WAIT_FUNC(wait, woken_wake_function); 144 int ret, rc = 0; 145 146 add_wait_queue(sk_sleep(sk), &wait); 147 while (1) { 148 if (!*timeo) { 149 rc = -EAGAIN; 150 break; 151 } 152 153 if (signal_pending(current)) { 154 rc = sock_intr_errno(*timeo); 155 break; 156 } 157 158 ret = sk_wait_event(sk, timeo, 159 !READ_ONCE(sk->sk_write_pending), &wait); 160 if (ret) { 161 if (ret < 0) 162 rc = ret; 163 break; 164 } 165 } 166 remove_wait_queue(sk_sleep(sk), &wait); 167 return rc; 168 } 169 170 int tls_push_sg(struct sock *sk, 171 struct tls_context *ctx, 172 struct scatterlist *sg, 173 u16 first_offset, 174 int flags) 175 { 176 struct bio_vec bvec; 177 struct msghdr msg = { 178 .msg_flags = MSG_SPLICE_PAGES | flags, 179 }; 180 int ret = 0; 181 struct page *p; 182 size_t size; 183 int offset = first_offset; 184 185 size = sg->length - offset; 186 offset += sg->offset; 187 188 ctx->splicing_pages = true; 189 while (1) { 190 /* is sending application-limited? */ 191 tcp_rate_check_app_limited(sk); 192 p = sg_page(sg); 193 retry: 194 bvec_set_page(&bvec, p, size, offset); 195 iov_iter_bvec(&msg.msg_iter, ITER_SOURCE, &bvec, 1, size); 196 197 ret = tcp_sendmsg_locked(sk, &msg, size); 198 199 if (ret != size) { 200 if (ret > 0) { 201 offset += ret; 202 size -= ret; 203 goto retry; 204 } 205 206 offset -= sg->offset; 207 ctx->partially_sent_offset = offset; 208 ctx->partially_sent_record = (void *)sg; 209 ctx->splicing_pages = false; 210 return ret; 211 } 212 213 put_page(p); 214 sk_mem_uncharge(sk, sg->length); 215 sg = sg_next(sg); 216 if (!sg) 217 break; 218 219 offset = sg->offset; 220 size = sg->length; 221 } 222 223 ctx->splicing_pages = false; 224 225 return 0; 226 } 227 228 static int tls_handle_open_record(struct sock *sk, int flags) 229 { 230 struct tls_context *ctx = tls_get_ctx(sk); 231 232 if (tls_is_pending_open_record(ctx)) 233 return ctx->push_pending_record(sk, flags); 234 235 return 0; 236 } 237 238 int tls_process_cmsg(struct sock *sk, struct msghdr *msg, 239 unsigned char *record_type) 240 { 241 struct cmsghdr *cmsg; 242 int rc = -EINVAL; 243 244 for_each_cmsghdr(cmsg, msg) { 245 if (!CMSG_OK(msg, cmsg)) 246 return -EINVAL; 247 if (cmsg->cmsg_level != SOL_TLS) 248 continue; 249 250 switch (cmsg->cmsg_type) { 251 case TLS_SET_RECORD_TYPE: 252 if (cmsg->cmsg_len < CMSG_LEN(sizeof(*record_type))) 253 return -EINVAL; 254 255 if (msg->msg_flags & MSG_MORE) 256 return -EINVAL; 257 258 rc = tls_handle_open_record(sk, msg->msg_flags); 259 if (rc) 260 return rc; 261 262 *record_type = *(unsigned char *)CMSG_DATA(cmsg); 263 rc = 0; 264 break; 265 default: 266 return -EINVAL; 267 } 268 } 269 270 return rc; 271 } 272 273 int tls_push_partial_record(struct sock *sk, struct tls_context *ctx, 274 int flags) 275 { 276 struct scatterlist *sg; 277 u16 offset; 278 279 sg = ctx->partially_sent_record; 280 offset = ctx->partially_sent_offset; 281 282 ctx->partially_sent_record = NULL; 283 return tls_push_sg(sk, ctx, sg, offset, flags); 284 } 285 286 void tls_free_partial_record(struct sock *sk, struct tls_context *ctx) 287 { 288 struct scatterlist *sg; 289 290 for (sg = ctx->partially_sent_record; sg; sg = sg_next(sg)) { 291 put_page(sg_page(sg)); 292 sk_mem_uncharge(sk, sg->length); 293 } 294 ctx->partially_sent_record = NULL; 295 } 296 297 static void tls_write_space(struct sock *sk) 298 { 299 struct tls_context *ctx = tls_get_ctx(sk); 300 301 /* If splicing_pages call lower protocol write space handler 302 * to ensure we wake up any waiting operations there. For example 303 * if splicing pages where to call sk_wait_event. 304 */ 305 if (ctx->splicing_pages) { 306 ctx->sk_write_space(sk); 307 return; 308 } 309 310 #ifdef CONFIG_TLS_DEVICE 311 if (ctx->tx_conf == TLS_HW) 312 tls_device_write_space(sk, ctx); 313 else 314 #endif 315 tls_sw_write_space(sk, ctx); 316 317 ctx->sk_write_space(sk); 318 } 319 320 /** 321 * tls_ctx_free() - free TLS ULP context 322 * @sk: socket to with @ctx is attached 323 * @ctx: TLS context structure 324 * 325 * Free TLS context. If @sk is %NULL caller guarantees that the socket 326 * to which @ctx was attached has no outstanding references. 327 */ 328 void tls_ctx_free(struct sock *sk, struct tls_context *ctx) 329 { 330 if (!ctx) 331 return; 332 333 memzero_explicit(&ctx->crypto_send, sizeof(ctx->crypto_send)); 334 memzero_explicit(&ctx->crypto_recv, sizeof(ctx->crypto_recv)); 335 mutex_destroy(&ctx->tx_lock); 336 337 if (sk) 338 kfree_rcu(ctx, rcu); 339 else 340 kfree(ctx); 341 } 342 343 static void tls_sk_proto_cleanup(struct sock *sk, 344 struct tls_context *ctx, long timeo) 345 { 346 if (unlikely(sk->sk_write_pending) && 347 !wait_on_pending_writer(sk, &timeo)) 348 tls_handle_open_record(sk, 0); 349 350 /* We need these for tls_sw_fallback handling of other packets */ 351 if (ctx->tx_conf == TLS_SW) { 352 tls_sw_release_resources_tx(sk); 353 TLS_DEC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXSW); 354 } else if (ctx->tx_conf == TLS_HW) { 355 tls_device_free_resources_tx(sk); 356 TLS_DEC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXDEVICE); 357 } 358 359 if (ctx->rx_conf == TLS_SW) { 360 tls_sw_release_resources_rx(sk); 361 TLS_DEC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXSW); 362 } else if (ctx->rx_conf == TLS_HW) { 363 tls_device_offload_cleanup_rx(sk); 364 TLS_DEC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXDEVICE); 365 } 366 } 367 368 static void tls_sk_proto_close(struct sock *sk, long timeout) 369 { 370 struct inet_connection_sock *icsk = inet_csk(sk); 371 struct tls_context *ctx = tls_get_ctx(sk); 372 long timeo = sock_sndtimeo(sk, 0); 373 bool free_ctx; 374 375 if (ctx->tx_conf == TLS_SW) 376 tls_sw_cancel_work_tx(ctx); 377 378 lock_sock(sk); 379 free_ctx = ctx->tx_conf != TLS_HW && ctx->rx_conf != TLS_HW; 380 381 if (ctx->tx_conf != TLS_BASE || ctx->rx_conf != TLS_BASE) 382 tls_sk_proto_cleanup(sk, ctx, timeo); 383 384 write_lock_bh(&sk->sk_callback_lock); 385 if (free_ctx) 386 rcu_assign_pointer(icsk->icsk_ulp_data, NULL); 387 WRITE_ONCE(sk->sk_prot, ctx->sk_proto); 388 if (sk->sk_write_space == tls_write_space) 389 sk->sk_write_space = ctx->sk_write_space; 390 write_unlock_bh(&sk->sk_callback_lock); 391 release_sock(sk); 392 if (ctx->tx_conf == TLS_SW) 393 tls_sw_free_ctx_tx(ctx); 394 if (ctx->rx_conf == TLS_SW || ctx->rx_conf == TLS_HW) 395 tls_sw_strparser_done(ctx); 396 if (ctx->rx_conf == TLS_SW) 397 tls_sw_free_ctx_rx(ctx); 398 ctx->sk_proto->close(sk, timeout); 399 400 if (free_ctx) 401 tls_ctx_free(sk, ctx); 402 } 403 404 static __poll_t tls_sk_poll(struct file *file, struct socket *sock, 405 struct poll_table_struct *wait) 406 { 407 struct tls_sw_context_rx *ctx; 408 struct tls_context *tls_ctx; 409 struct sock *sk = sock->sk; 410 struct sk_psock *psock; 411 __poll_t mask = 0; 412 u8 shutdown; 413 int state; 414 415 mask = tcp_poll(file, sock, wait); 416 417 state = inet_sk_state_load(sk); 418 shutdown = READ_ONCE(sk->sk_shutdown); 419 if (unlikely(state != TCP_ESTABLISHED || shutdown & RCV_SHUTDOWN)) 420 return mask; 421 422 tls_ctx = tls_get_ctx(sk); 423 ctx = tls_sw_ctx_rx(tls_ctx); 424 psock = sk_psock_get(sk); 425 426 if (skb_queue_empty_lockless(&ctx->rx_list) && 427 !tls_strp_msg_ready(ctx) && 428 sk_psock_queue_empty(psock)) 429 mask &= ~(EPOLLIN | EPOLLRDNORM); 430 431 if (psock) 432 sk_psock_put(sk, psock); 433 434 return mask; 435 } 436 437 static int do_tls_getsockopt_conf(struct sock *sk, char __user *optval, 438 int __user *optlen, int tx) 439 { 440 int rc = 0; 441 const struct tls_cipher_desc *cipher_desc; 442 struct tls_context *ctx = tls_get_ctx(sk); 443 struct tls_crypto_info *crypto_info; 444 struct cipher_context *cctx; 445 int len; 446 447 if (get_user(len, optlen)) 448 return -EFAULT; 449 450 if (!optval || (len < sizeof(*crypto_info))) { 451 rc = -EINVAL; 452 goto out; 453 } 454 455 if (!ctx) { 456 rc = -EBUSY; 457 goto out; 458 } 459 460 /* get user crypto info */ 461 if (tx) { 462 crypto_info = &ctx->crypto_send.info; 463 cctx = &ctx->tx; 464 } else { 465 crypto_info = &ctx->crypto_recv.info; 466 cctx = &ctx->rx; 467 } 468 469 if (!TLS_CRYPTO_INFO_READY(crypto_info)) { 470 rc = -EBUSY; 471 goto out; 472 } 473 474 if (len == sizeof(*crypto_info)) { 475 if (copy_to_user(optval, crypto_info, sizeof(*crypto_info))) 476 rc = -EFAULT; 477 goto out; 478 } 479 480 cipher_desc = get_cipher_desc(crypto_info->cipher_type); 481 if (!cipher_desc || len != cipher_desc->crypto_info) { 482 rc = -EINVAL; 483 goto out; 484 } 485 486 memcpy(crypto_info_iv(crypto_info, cipher_desc), 487 cctx->iv + cipher_desc->salt, cipher_desc->iv); 488 memcpy(crypto_info_rec_seq(crypto_info, cipher_desc), 489 cctx->rec_seq, cipher_desc->rec_seq); 490 491 if (copy_to_user(optval, crypto_info, cipher_desc->crypto_info)) 492 rc = -EFAULT; 493 494 out: 495 return rc; 496 } 497 498 static int do_tls_getsockopt_tx_zc(struct sock *sk, char __user *optval, 499 int __user *optlen) 500 { 501 struct tls_context *ctx = tls_get_ctx(sk); 502 unsigned int value; 503 int len; 504 505 if (get_user(len, optlen)) 506 return -EFAULT; 507 508 if (len != sizeof(value)) 509 return -EINVAL; 510 511 value = ctx->zerocopy_sendfile; 512 if (copy_to_user(optval, &value, sizeof(value))) 513 return -EFAULT; 514 515 return 0; 516 } 517 518 static int do_tls_getsockopt_no_pad(struct sock *sk, char __user *optval, 519 int __user *optlen) 520 { 521 struct tls_context *ctx = tls_get_ctx(sk); 522 int value, len; 523 524 if (ctx->prot_info.version != TLS_1_3_VERSION) 525 return -EINVAL; 526 527 if (get_user(len, optlen)) 528 return -EFAULT; 529 if (len < sizeof(value)) 530 return -EINVAL; 531 532 value = -EINVAL; 533 if (ctx->rx_conf == TLS_SW || ctx->rx_conf == TLS_HW) 534 value = ctx->rx_no_pad; 535 if (value < 0) 536 return value; 537 538 if (put_user(sizeof(value), optlen)) 539 return -EFAULT; 540 if (copy_to_user(optval, &value, sizeof(value))) 541 return -EFAULT; 542 543 return 0; 544 } 545 546 static int do_tls_getsockopt(struct sock *sk, int optname, 547 char __user *optval, int __user *optlen) 548 { 549 int rc = 0; 550 551 lock_sock(sk); 552 553 switch (optname) { 554 case TLS_TX: 555 case TLS_RX: 556 rc = do_tls_getsockopt_conf(sk, optval, optlen, 557 optname == TLS_TX); 558 break; 559 case TLS_TX_ZEROCOPY_RO: 560 rc = do_tls_getsockopt_tx_zc(sk, optval, optlen); 561 break; 562 case TLS_RX_EXPECT_NO_PAD: 563 rc = do_tls_getsockopt_no_pad(sk, optval, optlen); 564 break; 565 default: 566 rc = -ENOPROTOOPT; 567 break; 568 } 569 570 release_sock(sk); 571 572 return rc; 573 } 574 575 static int tls_getsockopt(struct sock *sk, int level, int optname, 576 char __user *optval, int __user *optlen) 577 { 578 struct tls_context *ctx = tls_get_ctx(sk); 579 580 if (level != SOL_TLS) 581 return ctx->sk_proto->getsockopt(sk, level, 582 optname, optval, optlen); 583 584 return do_tls_getsockopt(sk, optname, optval, optlen); 585 } 586 587 static int validate_crypto_info(const struct tls_crypto_info *crypto_info, 588 const struct tls_crypto_info *alt_crypto_info) 589 { 590 if (crypto_info->version != TLS_1_2_VERSION && 591 crypto_info->version != TLS_1_3_VERSION) 592 return -EINVAL; 593 594 switch (crypto_info->cipher_type) { 595 case TLS_CIPHER_ARIA_GCM_128: 596 case TLS_CIPHER_ARIA_GCM_256: 597 if (crypto_info->version != TLS_1_2_VERSION) 598 return -EINVAL; 599 break; 600 } 601 602 /* Ensure that TLS version and ciphers are same in both directions */ 603 if (TLS_CRYPTO_INFO_READY(alt_crypto_info)) { 604 if (alt_crypto_info->version != crypto_info->version || 605 alt_crypto_info->cipher_type != crypto_info->cipher_type) 606 return -EINVAL; 607 } 608 609 return 0; 610 } 611 612 static int do_tls_setsockopt_conf(struct sock *sk, sockptr_t optval, 613 unsigned int optlen, int tx) 614 { 615 struct tls_crypto_info *crypto_info; 616 struct tls_crypto_info *alt_crypto_info; 617 struct tls_context *ctx = tls_get_ctx(sk); 618 const struct tls_cipher_desc *cipher_desc; 619 int rc = 0; 620 int conf; 621 622 if (sockptr_is_null(optval) || (optlen < sizeof(*crypto_info))) 623 return -EINVAL; 624 625 if (tx) { 626 crypto_info = &ctx->crypto_send.info; 627 alt_crypto_info = &ctx->crypto_recv.info; 628 } else { 629 crypto_info = &ctx->crypto_recv.info; 630 alt_crypto_info = &ctx->crypto_send.info; 631 } 632 633 /* Currently we don't support set crypto info more than one time */ 634 if (TLS_CRYPTO_INFO_READY(crypto_info)) 635 return -EBUSY; 636 637 rc = copy_from_sockptr(crypto_info, optval, sizeof(*crypto_info)); 638 if (rc) { 639 rc = -EFAULT; 640 goto err_crypto_info; 641 } 642 643 rc = validate_crypto_info(crypto_info, alt_crypto_info); 644 if (rc) 645 goto err_crypto_info; 646 647 cipher_desc = get_cipher_desc(crypto_info->cipher_type); 648 if (!cipher_desc) { 649 rc = -EINVAL; 650 goto err_crypto_info; 651 } 652 653 if (optlen != cipher_desc->crypto_info) { 654 rc = -EINVAL; 655 goto err_crypto_info; 656 } 657 658 rc = copy_from_sockptr_offset(crypto_info + 1, optval, 659 sizeof(*crypto_info), 660 optlen - sizeof(*crypto_info)); 661 if (rc) { 662 rc = -EFAULT; 663 goto err_crypto_info; 664 } 665 666 if (tx) { 667 rc = tls_set_device_offload(sk); 668 conf = TLS_HW; 669 if (!rc) { 670 TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSTXDEVICE); 671 TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXDEVICE); 672 } else { 673 rc = tls_set_sw_offload(sk, 1); 674 if (rc) 675 goto err_crypto_info; 676 TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSTXSW); 677 TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXSW); 678 conf = TLS_SW; 679 } 680 } else { 681 rc = tls_set_device_offload_rx(sk, ctx); 682 conf = TLS_HW; 683 if (!rc) { 684 TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXDEVICE); 685 TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXDEVICE); 686 } else { 687 rc = tls_set_sw_offload(sk, 0); 688 if (rc) 689 goto err_crypto_info; 690 TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXSW); 691 TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXSW); 692 conf = TLS_SW; 693 } 694 tls_sw_strparser_arm(sk, ctx); 695 } 696 697 if (tx) 698 ctx->tx_conf = conf; 699 else 700 ctx->rx_conf = conf; 701 update_sk_prot(sk, ctx); 702 if (tx) { 703 ctx->sk_write_space = sk->sk_write_space; 704 sk->sk_write_space = tls_write_space; 705 } else { 706 struct tls_sw_context_rx *rx_ctx = tls_sw_ctx_rx(ctx); 707 708 tls_strp_check_rcv(&rx_ctx->strp); 709 } 710 return 0; 711 712 err_crypto_info: 713 memzero_explicit(crypto_info, sizeof(union tls_crypto_context)); 714 return rc; 715 } 716 717 static int do_tls_setsockopt_tx_zc(struct sock *sk, sockptr_t optval, 718 unsigned int optlen) 719 { 720 struct tls_context *ctx = tls_get_ctx(sk); 721 unsigned int value; 722 723 if (sockptr_is_null(optval) || optlen != sizeof(value)) 724 return -EINVAL; 725 726 if (copy_from_sockptr(&value, optval, sizeof(value))) 727 return -EFAULT; 728 729 if (value > 1) 730 return -EINVAL; 731 732 ctx->zerocopy_sendfile = value; 733 734 return 0; 735 } 736 737 static int do_tls_setsockopt_no_pad(struct sock *sk, sockptr_t optval, 738 unsigned int optlen) 739 { 740 struct tls_context *ctx = tls_get_ctx(sk); 741 u32 val; 742 int rc; 743 744 if (ctx->prot_info.version != TLS_1_3_VERSION || 745 sockptr_is_null(optval) || optlen < sizeof(val)) 746 return -EINVAL; 747 748 rc = copy_from_sockptr(&val, optval, sizeof(val)); 749 if (rc) 750 return -EFAULT; 751 if (val > 1) 752 return -EINVAL; 753 rc = check_zeroed_sockptr(optval, sizeof(val), optlen - sizeof(val)); 754 if (rc < 1) 755 return rc == 0 ? -EINVAL : rc; 756 757 lock_sock(sk); 758 rc = -EINVAL; 759 if (ctx->rx_conf == TLS_SW || ctx->rx_conf == TLS_HW) { 760 ctx->rx_no_pad = val; 761 tls_update_rx_zc_capable(ctx); 762 rc = 0; 763 } 764 release_sock(sk); 765 766 return rc; 767 } 768 769 static int do_tls_setsockopt(struct sock *sk, int optname, sockptr_t optval, 770 unsigned int optlen) 771 { 772 int rc = 0; 773 774 switch (optname) { 775 case TLS_TX: 776 case TLS_RX: 777 lock_sock(sk); 778 rc = do_tls_setsockopt_conf(sk, optval, optlen, 779 optname == TLS_TX); 780 release_sock(sk); 781 break; 782 case TLS_TX_ZEROCOPY_RO: 783 lock_sock(sk); 784 rc = do_tls_setsockopt_tx_zc(sk, optval, optlen); 785 release_sock(sk); 786 break; 787 case TLS_RX_EXPECT_NO_PAD: 788 rc = do_tls_setsockopt_no_pad(sk, optval, optlen); 789 break; 790 default: 791 rc = -ENOPROTOOPT; 792 break; 793 } 794 return rc; 795 } 796 797 static int tls_setsockopt(struct sock *sk, int level, int optname, 798 sockptr_t optval, unsigned int optlen) 799 { 800 struct tls_context *ctx = tls_get_ctx(sk); 801 802 if (level != SOL_TLS) 803 return ctx->sk_proto->setsockopt(sk, level, optname, optval, 804 optlen); 805 806 return do_tls_setsockopt(sk, optname, optval, optlen); 807 } 808 809 struct tls_context *tls_ctx_create(struct sock *sk) 810 { 811 struct inet_connection_sock *icsk = inet_csk(sk); 812 struct tls_context *ctx; 813 814 ctx = kzalloc(sizeof(*ctx), GFP_ATOMIC); 815 if (!ctx) 816 return NULL; 817 818 mutex_init(&ctx->tx_lock); 819 rcu_assign_pointer(icsk->icsk_ulp_data, ctx); 820 ctx->sk_proto = READ_ONCE(sk->sk_prot); 821 ctx->sk = sk; 822 return ctx; 823 } 824 825 static void build_proto_ops(struct proto_ops ops[TLS_NUM_CONFIG][TLS_NUM_CONFIG], 826 const struct proto_ops *base) 827 { 828 ops[TLS_BASE][TLS_BASE] = *base; 829 830 ops[TLS_SW ][TLS_BASE] = ops[TLS_BASE][TLS_BASE]; 831 ops[TLS_SW ][TLS_BASE].splice_eof = tls_sw_splice_eof; 832 833 ops[TLS_BASE][TLS_SW ] = ops[TLS_BASE][TLS_BASE]; 834 ops[TLS_BASE][TLS_SW ].splice_read = tls_sw_splice_read; 835 ops[TLS_BASE][TLS_SW ].poll = tls_sk_poll; 836 ops[TLS_BASE][TLS_SW ].read_sock = tls_sw_read_sock; 837 838 ops[TLS_SW ][TLS_SW ] = ops[TLS_SW ][TLS_BASE]; 839 ops[TLS_SW ][TLS_SW ].splice_read = tls_sw_splice_read; 840 ops[TLS_SW ][TLS_SW ].poll = tls_sk_poll; 841 ops[TLS_SW ][TLS_SW ].read_sock = tls_sw_read_sock; 842 843 #ifdef CONFIG_TLS_DEVICE 844 ops[TLS_HW ][TLS_BASE] = ops[TLS_BASE][TLS_BASE]; 845 846 ops[TLS_HW ][TLS_SW ] = ops[TLS_BASE][TLS_SW ]; 847 848 ops[TLS_BASE][TLS_HW ] = ops[TLS_BASE][TLS_SW ]; 849 850 ops[TLS_SW ][TLS_HW ] = ops[TLS_SW ][TLS_SW ]; 851 852 ops[TLS_HW ][TLS_HW ] = ops[TLS_HW ][TLS_SW ]; 853 #endif 854 #ifdef CONFIG_TLS_TOE 855 ops[TLS_HW_RECORD][TLS_HW_RECORD] = *base; 856 #endif 857 } 858 859 static void tls_build_proto(struct sock *sk) 860 { 861 int ip_ver = sk->sk_family == AF_INET6 ? TLSV6 : TLSV4; 862 struct proto *prot = READ_ONCE(sk->sk_prot); 863 864 /* Build IPv6 TLS whenever the address of tcpv6 _prot changes */ 865 if (ip_ver == TLSV6 && 866 unlikely(prot != smp_load_acquire(&saved_tcpv6_prot))) { 867 mutex_lock(&tcpv6_prot_mutex); 868 if (likely(prot != saved_tcpv6_prot)) { 869 build_protos(tls_prots[TLSV6], prot); 870 build_proto_ops(tls_proto_ops[TLSV6], 871 sk->sk_socket->ops); 872 smp_store_release(&saved_tcpv6_prot, prot); 873 } 874 mutex_unlock(&tcpv6_prot_mutex); 875 } 876 877 if (ip_ver == TLSV4 && 878 unlikely(prot != smp_load_acquire(&saved_tcpv4_prot))) { 879 mutex_lock(&tcpv4_prot_mutex); 880 if (likely(prot != saved_tcpv4_prot)) { 881 build_protos(tls_prots[TLSV4], prot); 882 build_proto_ops(tls_proto_ops[TLSV4], 883 sk->sk_socket->ops); 884 smp_store_release(&saved_tcpv4_prot, prot); 885 } 886 mutex_unlock(&tcpv4_prot_mutex); 887 } 888 } 889 890 static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG], 891 const struct proto *base) 892 { 893 prot[TLS_BASE][TLS_BASE] = *base; 894 prot[TLS_BASE][TLS_BASE].setsockopt = tls_setsockopt; 895 prot[TLS_BASE][TLS_BASE].getsockopt = tls_getsockopt; 896 prot[TLS_BASE][TLS_BASE].close = tls_sk_proto_close; 897 898 prot[TLS_SW][TLS_BASE] = prot[TLS_BASE][TLS_BASE]; 899 prot[TLS_SW][TLS_BASE].sendmsg = tls_sw_sendmsg; 900 prot[TLS_SW][TLS_BASE].splice_eof = tls_sw_splice_eof; 901 902 prot[TLS_BASE][TLS_SW] = prot[TLS_BASE][TLS_BASE]; 903 prot[TLS_BASE][TLS_SW].recvmsg = tls_sw_recvmsg; 904 prot[TLS_BASE][TLS_SW].sock_is_readable = tls_sw_sock_is_readable; 905 prot[TLS_BASE][TLS_SW].close = tls_sk_proto_close; 906 907 prot[TLS_SW][TLS_SW] = prot[TLS_SW][TLS_BASE]; 908 prot[TLS_SW][TLS_SW].recvmsg = tls_sw_recvmsg; 909 prot[TLS_SW][TLS_SW].sock_is_readable = tls_sw_sock_is_readable; 910 prot[TLS_SW][TLS_SW].close = tls_sk_proto_close; 911 912 #ifdef CONFIG_TLS_DEVICE 913 prot[TLS_HW][TLS_BASE] = prot[TLS_BASE][TLS_BASE]; 914 prot[TLS_HW][TLS_BASE].sendmsg = tls_device_sendmsg; 915 prot[TLS_HW][TLS_BASE].splice_eof = tls_device_splice_eof; 916 917 prot[TLS_HW][TLS_SW] = prot[TLS_BASE][TLS_SW]; 918 prot[TLS_HW][TLS_SW].sendmsg = tls_device_sendmsg; 919 prot[TLS_HW][TLS_SW].splice_eof = tls_device_splice_eof; 920 921 prot[TLS_BASE][TLS_HW] = prot[TLS_BASE][TLS_SW]; 922 923 prot[TLS_SW][TLS_HW] = prot[TLS_SW][TLS_SW]; 924 925 prot[TLS_HW][TLS_HW] = prot[TLS_HW][TLS_SW]; 926 #endif 927 #ifdef CONFIG_TLS_TOE 928 prot[TLS_HW_RECORD][TLS_HW_RECORD] = *base; 929 prot[TLS_HW_RECORD][TLS_HW_RECORD].hash = tls_toe_hash; 930 prot[TLS_HW_RECORD][TLS_HW_RECORD].unhash = tls_toe_unhash; 931 #endif 932 } 933 934 static int tls_init(struct sock *sk) 935 { 936 struct tls_context *ctx; 937 int rc = 0; 938 939 tls_build_proto(sk); 940 941 #ifdef CONFIG_TLS_TOE 942 if (tls_toe_bypass(sk)) 943 return 0; 944 #endif 945 946 /* The TLS ulp is currently supported only for TCP sockets 947 * in ESTABLISHED state. 948 * Supporting sockets in LISTEN state will require us 949 * to modify the accept implementation to clone rather then 950 * share the ulp context. 951 */ 952 if (sk->sk_state != TCP_ESTABLISHED) 953 return -ENOTCONN; 954 955 /* allocate tls context */ 956 write_lock_bh(&sk->sk_callback_lock); 957 ctx = tls_ctx_create(sk); 958 if (!ctx) { 959 rc = -ENOMEM; 960 goto out; 961 } 962 963 ctx->tx_conf = TLS_BASE; 964 ctx->rx_conf = TLS_BASE; 965 update_sk_prot(sk, ctx); 966 out: 967 write_unlock_bh(&sk->sk_callback_lock); 968 return rc; 969 } 970 971 static void tls_update(struct sock *sk, struct proto *p, 972 void (*write_space)(struct sock *sk)) 973 { 974 struct tls_context *ctx; 975 976 WARN_ON_ONCE(sk->sk_prot == p); 977 978 ctx = tls_get_ctx(sk); 979 if (likely(ctx)) { 980 ctx->sk_write_space = write_space; 981 ctx->sk_proto = p; 982 } else { 983 /* Pairs with lockless read in sk_clone_lock(). */ 984 WRITE_ONCE(sk->sk_prot, p); 985 sk->sk_write_space = write_space; 986 } 987 } 988 989 static u16 tls_user_config(struct tls_context *ctx, bool tx) 990 { 991 u16 config = tx ? ctx->tx_conf : ctx->rx_conf; 992 993 switch (config) { 994 case TLS_BASE: 995 return TLS_CONF_BASE; 996 case TLS_SW: 997 return TLS_CONF_SW; 998 case TLS_HW: 999 return TLS_CONF_HW; 1000 case TLS_HW_RECORD: 1001 return TLS_CONF_HW_RECORD; 1002 } 1003 return 0; 1004 } 1005 1006 static int tls_get_info(const struct sock *sk, struct sk_buff *skb) 1007 { 1008 u16 version, cipher_type; 1009 struct tls_context *ctx; 1010 struct nlattr *start; 1011 int err; 1012 1013 start = nla_nest_start_noflag(skb, INET_ULP_INFO_TLS); 1014 if (!start) 1015 return -EMSGSIZE; 1016 1017 rcu_read_lock(); 1018 ctx = rcu_dereference(inet_csk(sk)->icsk_ulp_data); 1019 if (!ctx) { 1020 err = 0; 1021 goto nla_failure; 1022 } 1023 version = ctx->prot_info.version; 1024 if (version) { 1025 err = nla_put_u16(skb, TLS_INFO_VERSION, version); 1026 if (err) 1027 goto nla_failure; 1028 } 1029 cipher_type = ctx->prot_info.cipher_type; 1030 if (cipher_type) { 1031 err = nla_put_u16(skb, TLS_INFO_CIPHER, cipher_type); 1032 if (err) 1033 goto nla_failure; 1034 } 1035 err = nla_put_u16(skb, TLS_INFO_TXCONF, tls_user_config(ctx, true)); 1036 if (err) 1037 goto nla_failure; 1038 1039 err = nla_put_u16(skb, TLS_INFO_RXCONF, tls_user_config(ctx, false)); 1040 if (err) 1041 goto nla_failure; 1042 1043 if (ctx->tx_conf == TLS_HW && ctx->zerocopy_sendfile) { 1044 err = nla_put_flag(skb, TLS_INFO_ZC_RO_TX); 1045 if (err) 1046 goto nla_failure; 1047 } 1048 if (ctx->rx_no_pad) { 1049 err = nla_put_flag(skb, TLS_INFO_RX_NO_PAD); 1050 if (err) 1051 goto nla_failure; 1052 } 1053 1054 rcu_read_unlock(); 1055 nla_nest_end(skb, start); 1056 return 0; 1057 1058 nla_failure: 1059 rcu_read_unlock(); 1060 nla_nest_cancel(skb, start); 1061 return err; 1062 } 1063 1064 static size_t tls_get_info_size(const struct sock *sk) 1065 { 1066 size_t size = 0; 1067 1068 size += nla_total_size(0) + /* INET_ULP_INFO_TLS */ 1069 nla_total_size(sizeof(u16)) + /* TLS_INFO_VERSION */ 1070 nla_total_size(sizeof(u16)) + /* TLS_INFO_CIPHER */ 1071 nla_total_size(sizeof(u16)) + /* TLS_INFO_RXCONF */ 1072 nla_total_size(sizeof(u16)) + /* TLS_INFO_TXCONF */ 1073 nla_total_size(0) + /* TLS_INFO_ZC_RO_TX */ 1074 nla_total_size(0) + /* TLS_INFO_RX_NO_PAD */ 1075 0; 1076 1077 return size; 1078 } 1079 1080 static int __net_init tls_init_net(struct net *net) 1081 { 1082 int err; 1083 1084 net->mib.tls_statistics = alloc_percpu(struct linux_tls_mib); 1085 if (!net->mib.tls_statistics) 1086 return -ENOMEM; 1087 1088 err = tls_proc_init(net); 1089 if (err) 1090 goto err_free_stats; 1091 1092 return 0; 1093 err_free_stats: 1094 free_percpu(net->mib.tls_statistics); 1095 return err; 1096 } 1097 1098 static void __net_exit tls_exit_net(struct net *net) 1099 { 1100 tls_proc_fini(net); 1101 free_percpu(net->mib.tls_statistics); 1102 } 1103 1104 static struct pernet_operations tls_proc_ops = { 1105 .init = tls_init_net, 1106 .exit = tls_exit_net, 1107 }; 1108 1109 static struct tcp_ulp_ops tcp_tls_ulp_ops __read_mostly = { 1110 .name = "tls", 1111 .owner = THIS_MODULE, 1112 .init = tls_init, 1113 .update = tls_update, 1114 .get_info = tls_get_info, 1115 .get_info_size = tls_get_info_size, 1116 }; 1117 1118 static int __init tls_register(void) 1119 { 1120 int err; 1121 1122 err = register_pernet_subsys(&tls_proc_ops); 1123 if (err) 1124 return err; 1125 1126 err = tls_strp_dev_init(); 1127 if (err) 1128 goto err_pernet; 1129 1130 err = tls_device_init(); 1131 if (err) 1132 goto err_strp; 1133 1134 tcp_register_ulp(&tcp_tls_ulp_ops); 1135 1136 return 0; 1137 err_strp: 1138 tls_strp_dev_exit(); 1139 err_pernet: 1140 unregister_pernet_subsys(&tls_proc_ops); 1141 return err; 1142 } 1143 1144 static void __exit tls_unregister(void) 1145 { 1146 tcp_unregister_ulp(&tcp_tls_ulp_ops); 1147 tls_strp_dev_exit(); 1148 tls_device_cleanup(); 1149 unregister_pernet_subsys(&tls_proc_ops); 1150 } 1151 1152 module_init(tls_register); 1153 module_exit(tls_unregister); 1154