/*
 * Copyright (c) 2016-2017, Mellanox Technologies. All rights reserved.
 * Copyright (c) 2016-2017, Dave Watson <davejwatson@fb.com>. All rights reserved.
 *
 * This software is available to you under a choice of one of two
 * licenses. You may choose to be licensed under the terms of the GNU
 * General Public License (GPL) Version 2, available from the file
 * COPYING in the main directory of this source tree, or the
 * OpenIB.org BSD license below:
 *
 *     Redistribution and use in source and binary forms, with or
 *     without modification, are permitted provided that the following
 *     conditions are met:
 *
 *      - Redistributions of source code must retain the above
 *        copyright notice, this list of conditions and the following
 *        disclaimer.
 *
 *      - Redistributions in binary form must reproduce the above
 *        copyright notice, this list of conditions and the following
 *        disclaimer in the documentation and/or other materials
 *        provided with the distribution.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
 * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
 * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
 * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
 * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
 * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
32 */ 33 34 #include <linux/module.h> 35 36 #include <net/tcp.h> 37 #include <net/inet_common.h> 38 #include <linux/highmem.h> 39 #include <linux/netdevice.h> 40 #include <linux/sched/signal.h> 41 #include <linux/inetdevice.h> 42 #include <linux/inet_diag.h> 43 44 #include <net/snmp.h> 45 #include <net/tls.h> 46 #include <net/tls_toe.h> 47 48 #include "tls.h" 49 50 MODULE_AUTHOR("Mellanox Technologies"); 51 MODULE_DESCRIPTION("Transport Layer Security Support"); 52 MODULE_LICENSE("Dual BSD/GPL"); 53 MODULE_ALIAS_TCP_ULP("tls"); 54 55 enum { 56 TLSV4, 57 TLSV6, 58 TLS_NUM_PROTS, 59 }; 60 61 #define CHECK_CIPHER_DESC(cipher,ci) \ 62 static_assert(cipher ## _IV_SIZE <= TLS_MAX_IV_SIZE); \ 63 static_assert(cipher ## _SALT_SIZE <= TLS_MAX_SALT_SIZE); \ 64 static_assert(cipher ## _REC_SEQ_SIZE <= TLS_MAX_REC_SEQ_SIZE); \ 65 static_assert(cipher ## _TAG_SIZE == TLS_TAG_SIZE); \ 66 static_assert(sizeof_field(struct ci, iv) == cipher ## _IV_SIZE); \ 67 static_assert(sizeof_field(struct ci, key) == cipher ## _KEY_SIZE); \ 68 static_assert(sizeof_field(struct ci, salt) == cipher ## _SALT_SIZE); \ 69 static_assert(sizeof_field(struct ci, rec_seq) == cipher ## _REC_SEQ_SIZE); 70 71 #define __CIPHER_DESC(ci) \ 72 .iv_offset = offsetof(struct ci, iv), \ 73 .key_offset = offsetof(struct ci, key), \ 74 .salt_offset = offsetof(struct ci, salt), \ 75 .rec_seq_offset = offsetof(struct ci, rec_seq), \ 76 .crypto_info = sizeof(struct ci) 77 78 #define CIPHER_DESC(cipher,ci,algname,_offloadable) [cipher - TLS_CIPHER_MIN] = { \ 79 .nonce = cipher ## _IV_SIZE, \ 80 .iv = cipher ## _IV_SIZE, \ 81 .key = cipher ## _KEY_SIZE, \ 82 .salt = cipher ## _SALT_SIZE, \ 83 .tag = cipher ## _TAG_SIZE, \ 84 .rec_seq = cipher ## _REC_SEQ_SIZE, \ 85 .cipher_name = algname, \ 86 .offloadable = _offloadable, \ 87 __CIPHER_DESC(ci), \ 88 } 89 90 #define CIPHER_DESC_NONCE0(cipher,ci,algname,_offloadable) [cipher - TLS_CIPHER_MIN] = { \ 91 .nonce = 0, \ 92 .iv = cipher ## _IV_SIZE, \ 93 .key = cipher ## 
_KEY_SIZE, \ 94 .salt = cipher ## _SALT_SIZE, \ 95 .tag = cipher ## _TAG_SIZE, \ 96 .rec_seq = cipher ## _REC_SEQ_SIZE, \ 97 .cipher_name = algname, \ 98 .offloadable = _offloadable, \ 99 __CIPHER_DESC(ci), \ 100 } 101 102 const struct tls_cipher_desc tls_cipher_desc[TLS_CIPHER_MAX + 1 - TLS_CIPHER_MIN] = { 103 CIPHER_DESC(TLS_CIPHER_AES_GCM_128, tls12_crypto_info_aes_gcm_128, "gcm(aes)", true), 104 CIPHER_DESC(TLS_CIPHER_AES_GCM_256, tls12_crypto_info_aes_gcm_256, "gcm(aes)", true), 105 CIPHER_DESC(TLS_CIPHER_AES_CCM_128, tls12_crypto_info_aes_ccm_128, "ccm(aes)", false), 106 CIPHER_DESC_NONCE0(TLS_CIPHER_CHACHA20_POLY1305, tls12_crypto_info_chacha20_poly1305, "rfc7539(chacha20,poly1305)", false), 107 CIPHER_DESC(TLS_CIPHER_SM4_GCM, tls12_crypto_info_sm4_gcm, "gcm(sm4)", false), 108 CIPHER_DESC(TLS_CIPHER_SM4_CCM, tls12_crypto_info_sm4_ccm, "ccm(sm4)", false), 109 CIPHER_DESC(TLS_CIPHER_ARIA_GCM_128, tls12_crypto_info_aria_gcm_128, "gcm(aria)", false), 110 CIPHER_DESC(TLS_CIPHER_ARIA_GCM_256, tls12_crypto_info_aria_gcm_256, "gcm(aria)", false), 111 }; 112 113 CHECK_CIPHER_DESC(TLS_CIPHER_AES_GCM_128, tls12_crypto_info_aes_gcm_128); 114 CHECK_CIPHER_DESC(TLS_CIPHER_AES_GCM_256, tls12_crypto_info_aes_gcm_256); 115 CHECK_CIPHER_DESC(TLS_CIPHER_AES_CCM_128, tls12_crypto_info_aes_ccm_128); 116 CHECK_CIPHER_DESC(TLS_CIPHER_CHACHA20_POLY1305, tls12_crypto_info_chacha20_poly1305); 117 CHECK_CIPHER_DESC(TLS_CIPHER_SM4_GCM, tls12_crypto_info_sm4_gcm); 118 CHECK_CIPHER_DESC(TLS_CIPHER_SM4_CCM, tls12_crypto_info_sm4_ccm); 119 CHECK_CIPHER_DESC(TLS_CIPHER_ARIA_GCM_128, tls12_crypto_info_aria_gcm_128); 120 CHECK_CIPHER_DESC(TLS_CIPHER_ARIA_GCM_256, tls12_crypto_info_aria_gcm_256); 121 122 static const struct proto *saved_tcpv6_prot; 123 static DEFINE_MUTEX(tcpv6_prot_mutex); 124 static const struct proto *saved_tcpv4_prot; 125 static DEFINE_MUTEX(tcpv4_prot_mutex); 126 static struct proto tls_prots[TLS_NUM_PROTS][TLS_NUM_CONFIG][TLS_NUM_CONFIG]; 127 static struct proto_ops 
tls_proto_ops[TLS_NUM_PROTS][TLS_NUM_CONFIG][TLS_NUM_CONFIG]; 128 static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG], 129 const struct proto *base); 130 131 void update_sk_prot(struct sock *sk, struct tls_context *ctx) 132 { 133 int ip_ver = sk->sk_family == AF_INET6 ? TLSV6 : TLSV4; 134 135 WRITE_ONCE(sk->sk_prot, 136 &tls_prots[ip_ver][ctx->tx_conf][ctx->rx_conf]); 137 WRITE_ONCE(sk->sk_socket->ops, 138 &tls_proto_ops[ip_ver][ctx->tx_conf][ctx->rx_conf]); 139 } 140 141 int wait_on_pending_writer(struct sock *sk, long *timeo) 142 { 143 int rc = 0; 144 DEFINE_WAIT_FUNC(wait, woken_wake_function); 145 146 add_wait_queue(sk_sleep(sk), &wait); 147 while (1) { 148 if (!*timeo) { 149 rc = -EAGAIN; 150 break; 151 } 152 153 if (signal_pending(current)) { 154 rc = sock_intr_errno(*timeo); 155 break; 156 } 157 158 if (sk_wait_event(sk, timeo, 159 !READ_ONCE(sk->sk_write_pending), &wait)) 160 break; 161 } 162 remove_wait_queue(sk_sleep(sk), &wait); 163 return rc; 164 } 165 166 int tls_push_sg(struct sock *sk, 167 struct tls_context *ctx, 168 struct scatterlist *sg, 169 u16 first_offset, 170 int flags) 171 { 172 struct bio_vec bvec; 173 struct msghdr msg = { 174 .msg_flags = MSG_SPLICE_PAGES | flags, 175 }; 176 int ret = 0; 177 struct page *p; 178 size_t size; 179 int offset = first_offset; 180 181 size = sg->length - offset; 182 offset += sg->offset; 183 184 ctx->splicing_pages = true; 185 while (1) { 186 /* is sending application-limited? 
*/ 187 tcp_rate_check_app_limited(sk); 188 p = sg_page(sg); 189 retry: 190 bvec_set_page(&bvec, p, size, offset); 191 iov_iter_bvec(&msg.msg_iter, ITER_SOURCE, &bvec, 1, size); 192 193 ret = tcp_sendmsg_locked(sk, &msg, size); 194 195 if (ret != size) { 196 if (ret > 0) { 197 offset += ret; 198 size -= ret; 199 goto retry; 200 } 201 202 offset -= sg->offset; 203 ctx->partially_sent_offset = offset; 204 ctx->partially_sent_record = (void *)sg; 205 ctx->splicing_pages = false; 206 return ret; 207 } 208 209 put_page(p); 210 sk_mem_uncharge(sk, sg->length); 211 sg = sg_next(sg); 212 if (!sg) 213 break; 214 215 offset = sg->offset; 216 size = sg->length; 217 } 218 219 ctx->splicing_pages = false; 220 221 return 0; 222 } 223 224 static int tls_handle_open_record(struct sock *sk, int flags) 225 { 226 struct tls_context *ctx = tls_get_ctx(sk); 227 228 if (tls_is_pending_open_record(ctx)) 229 return ctx->push_pending_record(sk, flags); 230 231 return 0; 232 } 233 234 int tls_process_cmsg(struct sock *sk, struct msghdr *msg, 235 unsigned char *record_type) 236 { 237 struct cmsghdr *cmsg; 238 int rc = -EINVAL; 239 240 for_each_cmsghdr(cmsg, msg) { 241 if (!CMSG_OK(msg, cmsg)) 242 return -EINVAL; 243 if (cmsg->cmsg_level != SOL_TLS) 244 continue; 245 246 switch (cmsg->cmsg_type) { 247 case TLS_SET_RECORD_TYPE: 248 if (cmsg->cmsg_len < CMSG_LEN(sizeof(*record_type))) 249 return -EINVAL; 250 251 if (msg->msg_flags & MSG_MORE) 252 return -EINVAL; 253 254 rc = tls_handle_open_record(sk, msg->msg_flags); 255 if (rc) 256 return rc; 257 258 *record_type = *(unsigned char *)CMSG_DATA(cmsg); 259 rc = 0; 260 break; 261 default: 262 return -EINVAL; 263 } 264 } 265 266 return rc; 267 } 268 269 int tls_push_partial_record(struct sock *sk, struct tls_context *ctx, 270 int flags) 271 { 272 struct scatterlist *sg; 273 u16 offset; 274 275 sg = ctx->partially_sent_record; 276 offset = ctx->partially_sent_offset; 277 278 ctx->partially_sent_record = NULL; 279 return tls_push_sg(sk, ctx, sg, 
offset, flags); 280 } 281 282 void tls_free_partial_record(struct sock *sk, struct tls_context *ctx) 283 { 284 struct scatterlist *sg; 285 286 for (sg = ctx->partially_sent_record; sg; sg = sg_next(sg)) { 287 put_page(sg_page(sg)); 288 sk_mem_uncharge(sk, sg->length); 289 } 290 ctx->partially_sent_record = NULL; 291 } 292 293 static void tls_write_space(struct sock *sk) 294 { 295 struct tls_context *ctx = tls_get_ctx(sk); 296 297 /* If splicing_pages call lower protocol write space handler 298 * to ensure we wake up any waiting operations there. For example 299 * if splicing pages where to call sk_wait_event. 300 */ 301 if (ctx->splicing_pages) { 302 ctx->sk_write_space(sk); 303 return; 304 } 305 306 #ifdef CONFIG_TLS_DEVICE 307 if (ctx->tx_conf == TLS_HW) 308 tls_device_write_space(sk, ctx); 309 else 310 #endif 311 tls_sw_write_space(sk, ctx); 312 313 ctx->sk_write_space(sk); 314 } 315 316 /** 317 * tls_ctx_free() - free TLS ULP context 318 * @sk: socket to with @ctx is attached 319 * @ctx: TLS context structure 320 * 321 * Free TLS context. If @sk is %NULL caller guarantees that the socket 322 * to which @ctx was attached has no outstanding references. 
323 */ 324 void tls_ctx_free(struct sock *sk, struct tls_context *ctx) 325 { 326 if (!ctx) 327 return; 328 329 memzero_explicit(&ctx->crypto_send, sizeof(ctx->crypto_send)); 330 memzero_explicit(&ctx->crypto_recv, sizeof(ctx->crypto_recv)); 331 mutex_destroy(&ctx->tx_lock); 332 333 if (sk) 334 kfree_rcu(ctx, rcu); 335 else 336 kfree(ctx); 337 } 338 339 static void tls_sk_proto_cleanup(struct sock *sk, 340 struct tls_context *ctx, long timeo) 341 { 342 if (unlikely(sk->sk_write_pending) && 343 !wait_on_pending_writer(sk, &timeo)) 344 tls_handle_open_record(sk, 0); 345 346 /* We need these for tls_sw_fallback handling of other packets */ 347 if (ctx->tx_conf == TLS_SW) { 348 tls_sw_release_resources_tx(sk); 349 TLS_DEC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXSW); 350 } else if (ctx->tx_conf == TLS_HW) { 351 tls_device_free_resources_tx(sk); 352 TLS_DEC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXDEVICE); 353 } 354 355 if (ctx->rx_conf == TLS_SW) { 356 tls_sw_release_resources_rx(sk); 357 TLS_DEC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXSW); 358 } else if (ctx->rx_conf == TLS_HW) { 359 tls_device_offload_cleanup_rx(sk); 360 TLS_DEC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXDEVICE); 361 } 362 } 363 364 static void tls_sk_proto_close(struct sock *sk, long timeout) 365 { 366 struct inet_connection_sock *icsk = inet_csk(sk); 367 struct tls_context *ctx = tls_get_ctx(sk); 368 long timeo = sock_sndtimeo(sk, 0); 369 bool free_ctx; 370 371 if (ctx->tx_conf == TLS_SW) 372 tls_sw_cancel_work_tx(ctx); 373 374 lock_sock(sk); 375 free_ctx = ctx->tx_conf != TLS_HW && ctx->rx_conf != TLS_HW; 376 377 if (ctx->tx_conf != TLS_BASE || ctx->rx_conf != TLS_BASE) 378 tls_sk_proto_cleanup(sk, ctx, timeo); 379 380 write_lock_bh(&sk->sk_callback_lock); 381 if (free_ctx) 382 rcu_assign_pointer(icsk->icsk_ulp_data, NULL); 383 WRITE_ONCE(sk->sk_prot, ctx->sk_proto); 384 if (sk->sk_write_space == tls_write_space) 385 sk->sk_write_space = ctx->sk_write_space; 386 write_unlock_bh(&sk->sk_callback_lock); 387 
release_sock(sk); 388 if (ctx->tx_conf == TLS_SW) 389 tls_sw_free_ctx_tx(ctx); 390 if (ctx->rx_conf == TLS_SW || ctx->rx_conf == TLS_HW) 391 tls_sw_strparser_done(ctx); 392 if (ctx->rx_conf == TLS_SW) 393 tls_sw_free_ctx_rx(ctx); 394 ctx->sk_proto->close(sk, timeout); 395 396 if (free_ctx) 397 tls_ctx_free(sk, ctx); 398 } 399 400 static __poll_t tls_sk_poll(struct file *file, struct socket *sock, 401 struct poll_table_struct *wait) 402 { 403 struct tls_sw_context_rx *ctx; 404 struct tls_context *tls_ctx; 405 struct sock *sk = sock->sk; 406 struct sk_psock *psock; 407 __poll_t mask = 0; 408 u8 shutdown; 409 int state; 410 411 mask = tcp_poll(file, sock, wait); 412 413 state = inet_sk_state_load(sk); 414 shutdown = READ_ONCE(sk->sk_shutdown); 415 if (unlikely(state != TCP_ESTABLISHED || shutdown & RCV_SHUTDOWN)) 416 return mask; 417 418 tls_ctx = tls_get_ctx(sk); 419 ctx = tls_sw_ctx_rx(tls_ctx); 420 psock = sk_psock_get(sk); 421 422 if (skb_queue_empty_lockless(&ctx->rx_list) && 423 !tls_strp_msg_ready(ctx) && 424 sk_psock_queue_empty(psock)) 425 mask &= ~(EPOLLIN | EPOLLRDNORM); 426 427 if (psock) 428 sk_psock_put(sk, psock); 429 430 return mask; 431 } 432 433 static int do_tls_getsockopt_conf(struct sock *sk, char __user *optval, 434 int __user *optlen, int tx) 435 { 436 int rc = 0; 437 const struct tls_cipher_desc *cipher_desc; 438 struct tls_context *ctx = tls_get_ctx(sk); 439 struct tls_crypto_info *crypto_info; 440 struct cipher_context *cctx; 441 int len; 442 443 if (get_user(len, optlen)) 444 return -EFAULT; 445 446 if (!optval || (len < sizeof(*crypto_info))) { 447 rc = -EINVAL; 448 goto out; 449 } 450 451 if (!ctx) { 452 rc = -EBUSY; 453 goto out; 454 } 455 456 /* get user crypto info */ 457 if (tx) { 458 crypto_info = &ctx->crypto_send.info; 459 cctx = &ctx->tx; 460 } else { 461 crypto_info = &ctx->crypto_recv.info; 462 cctx = &ctx->rx; 463 } 464 465 if (!TLS_CRYPTO_INFO_READY(crypto_info)) { 466 rc = -EBUSY; 467 goto out; 468 } 469 470 if (len == 
sizeof(*crypto_info)) { 471 if (copy_to_user(optval, crypto_info, sizeof(*crypto_info))) 472 rc = -EFAULT; 473 goto out; 474 } 475 476 cipher_desc = get_cipher_desc(crypto_info->cipher_type); 477 if (!cipher_desc || len != cipher_desc->crypto_info) { 478 rc = -EINVAL; 479 goto out; 480 } 481 482 memcpy(crypto_info_iv(crypto_info, cipher_desc), 483 cctx->iv + cipher_desc->salt, cipher_desc->iv); 484 memcpy(crypto_info_rec_seq(crypto_info, cipher_desc), 485 cctx->rec_seq, cipher_desc->rec_seq); 486 487 if (copy_to_user(optval, crypto_info, cipher_desc->crypto_info)) 488 rc = -EFAULT; 489 490 out: 491 return rc; 492 } 493 494 static int do_tls_getsockopt_tx_zc(struct sock *sk, char __user *optval, 495 int __user *optlen) 496 { 497 struct tls_context *ctx = tls_get_ctx(sk); 498 unsigned int value; 499 int len; 500 501 if (get_user(len, optlen)) 502 return -EFAULT; 503 504 if (len != sizeof(value)) 505 return -EINVAL; 506 507 value = ctx->zerocopy_sendfile; 508 if (copy_to_user(optval, &value, sizeof(value))) 509 return -EFAULT; 510 511 return 0; 512 } 513 514 static int do_tls_getsockopt_no_pad(struct sock *sk, char __user *optval, 515 int __user *optlen) 516 { 517 struct tls_context *ctx = tls_get_ctx(sk); 518 int value, len; 519 520 if (ctx->prot_info.version != TLS_1_3_VERSION) 521 return -EINVAL; 522 523 if (get_user(len, optlen)) 524 return -EFAULT; 525 if (len < sizeof(value)) 526 return -EINVAL; 527 528 value = -EINVAL; 529 if (ctx->rx_conf == TLS_SW || ctx->rx_conf == TLS_HW) 530 value = ctx->rx_no_pad; 531 if (value < 0) 532 return value; 533 534 if (put_user(sizeof(value), optlen)) 535 return -EFAULT; 536 if (copy_to_user(optval, &value, sizeof(value))) 537 return -EFAULT; 538 539 return 0; 540 } 541 542 static int do_tls_getsockopt(struct sock *sk, int optname, 543 char __user *optval, int __user *optlen) 544 { 545 int rc = 0; 546 547 lock_sock(sk); 548 549 switch (optname) { 550 case TLS_TX: 551 case TLS_RX: 552 rc = do_tls_getsockopt_conf(sk, optval, 
optlen, 553 optname == TLS_TX); 554 break; 555 case TLS_TX_ZEROCOPY_RO: 556 rc = do_tls_getsockopt_tx_zc(sk, optval, optlen); 557 break; 558 case TLS_RX_EXPECT_NO_PAD: 559 rc = do_tls_getsockopt_no_pad(sk, optval, optlen); 560 break; 561 default: 562 rc = -ENOPROTOOPT; 563 break; 564 } 565 566 release_sock(sk); 567 568 return rc; 569 } 570 571 static int tls_getsockopt(struct sock *sk, int level, int optname, 572 char __user *optval, int __user *optlen) 573 { 574 struct tls_context *ctx = tls_get_ctx(sk); 575 576 if (level != SOL_TLS) 577 return ctx->sk_proto->getsockopt(sk, level, 578 optname, optval, optlen); 579 580 return do_tls_getsockopt(sk, optname, optval, optlen); 581 } 582 583 static int validate_crypto_info(const struct tls_crypto_info *crypto_info, 584 const struct tls_crypto_info *alt_crypto_info) 585 { 586 if (crypto_info->version != TLS_1_2_VERSION && 587 crypto_info->version != TLS_1_3_VERSION) 588 return -EINVAL; 589 590 switch (crypto_info->cipher_type) { 591 case TLS_CIPHER_ARIA_GCM_128: 592 case TLS_CIPHER_ARIA_GCM_256: 593 if (crypto_info->version != TLS_1_2_VERSION) 594 return -EINVAL; 595 break; 596 } 597 598 /* Ensure that TLS version and ciphers are same in both directions */ 599 if (TLS_CRYPTO_INFO_READY(alt_crypto_info)) { 600 if (alt_crypto_info->version != crypto_info->version || 601 alt_crypto_info->cipher_type != crypto_info->cipher_type) 602 return -EINVAL; 603 } 604 605 return 0; 606 } 607 608 static int do_tls_setsockopt_conf(struct sock *sk, sockptr_t optval, 609 unsigned int optlen, int tx) 610 { 611 struct tls_crypto_info *crypto_info; 612 struct tls_crypto_info *alt_crypto_info; 613 struct tls_context *ctx = tls_get_ctx(sk); 614 const struct tls_cipher_desc *cipher_desc; 615 int rc = 0; 616 int conf; 617 618 if (sockptr_is_null(optval) || (optlen < sizeof(*crypto_info))) 619 return -EINVAL; 620 621 if (tx) { 622 crypto_info = &ctx->crypto_send.info; 623 alt_crypto_info = &ctx->crypto_recv.info; 624 } else { 625 crypto_info = 
&ctx->crypto_recv.info; 626 alt_crypto_info = &ctx->crypto_send.info; 627 } 628 629 /* Currently we don't support set crypto info more than one time */ 630 if (TLS_CRYPTO_INFO_READY(crypto_info)) 631 return -EBUSY; 632 633 rc = copy_from_sockptr(crypto_info, optval, sizeof(*crypto_info)); 634 if (rc) { 635 rc = -EFAULT; 636 goto err_crypto_info; 637 } 638 639 rc = validate_crypto_info(crypto_info, alt_crypto_info); 640 if (rc) 641 goto err_crypto_info; 642 643 cipher_desc = get_cipher_desc(crypto_info->cipher_type); 644 if (!cipher_desc) { 645 rc = -EINVAL; 646 goto err_crypto_info; 647 } 648 649 if (optlen != cipher_desc->crypto_info) { 650 rc = -EINVAL; 651 goto err_crypto_info; 652 } 653 654 rc = copy_from_sockptr_offset(crypto_info + 1, optval, 655 sizeof(*crypto_info), 656 optlen - sizeof(*crypto_info)); 657 if (rc) { 658 rc = -EFAULT; 659 goto err_crypto_info; 660 } 661 662 if (tx) { 663 rc = tls_set_device_offload(sk); 664 conf = TLS_HW; 665 if (!rc) { 666 TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSTXDEVICE); 667 TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXDEVICE); 668 } else { 669 rc = tls_set_sw_offload(sk, 1); 670 if (rc) 671 goto err_crypto_info; 672 TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSTXSW); 673 TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXSW); 674 conf = TLS_SW; 675 } 676 } else { 677 rc = tls_set_device_offload_rx(sk, ctx); 678 conf = TLS_HW; 679 if (!rc) { 680 TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXDEVICE); 681 TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXDEVICE); 682 } else { 683 rc = tls_set_sw_offload(sk, 0); 684 if (rc) 685 goto err_crypto_info; 686 TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXSW); 687 TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXSW); 688 conf = TLS_SW; 689 } 690 tls_sw_strparser_arm(sk, ctx); 691 } 692 693 if (tx) 694 ctx->tx_conf = conf; 695 else 696 ctx->rx_conf = conf; 697 update_sk_prot(sk, ctx); 698 if (tx) { 699 ctx->sk_write_space = sk->sk_write_space; 700 sk->sk_write_space = tls_write_space; 701 } else { 
702 struct tls_sw_context_rx *rx_ctx = tls_sw_ctx_rx(ctx); 703 704 tls_strp_check_rcv(&rx_ctx->strp); 705 } 706 return 0; 707 708 err_crypto_info: 709 memzero_explicit(crypto_info, sizeof(union tls_crypto_context)); 710 return rc; 711 } 712 713 static int do_tls_setsockopt_tx_zc(struct sock *sk, sockptr_t optval, 714 unsigned int optlen) 715 { 716 struct tls_context *ctx = tls_get_ctx(sk); 717 unsigned int value; 718 719 if (sockptr_is_null(optval) || optlen != sizeof(value)) 720 return -EINVAL; 721 722 if (copy_from_sockptr(&value, optval, sizeof(value))) 723 return -EFAULT; 724 725 if (value > 1) 726 return -EINVAL; 727 728 ctx->zerocopy_sendfile = value; 729 730 return 0; 731 } 732 733 static int do_tls_setsockopt_no_pad(struct sock *sk, sockptr_t optval, 734 unsigned int optlen) 735 { 736 struct tls_context *ctx = tls_get_ctx(sk); 737 u32 val; 738 int rc; 739 740 if (ctx->prot_info.version != TLS_1_3_VERSION || 741 sockptr_is_null(optval) || optlen < sizeof(val)) 742 return -EINVAL; 743 744 rc = copy_from_sockptr(&val, optval, sizeof(val)); 745 if (rc) 746 return -EFAULT; 747 if (val > 1) 748 return -EINVAL; 749 rc = check_zeroed_sockptr(optval, sizeof(val), optlen - sizeof(val)); 750 if (rc < 1) 751 return rc == 0 ? 
-EINVAL : rc; 752 753 lock_sock(sk); 754 rc = -EINVAL; 755 if (ctx->rx_conf == TLS_SW || ctx->rx_conf == TLS_HW) { 756 ctx->rx_no_pad = val; 757 tls_update_rx_zc_capable(ctx); 758 rc = 0; 759 } 760 release_sock(sk); 761 762 return rc; 763 } 764 765 static int do_tls_setsockopt(struct sock *sk, int optname, sockptr_t optval, 766 unsigned int optlen) 767 { 768 int rc = 0; 769 770 switch (optname) { 771 case TLS_TX: 772 case TLS_RX: 773 lock_sock(sk); 774 rc = do_tls_setsockopt_conf(sk, optval, optlen, 775 optname == TLS_TX); 776 release_sock(sk); 777 break; 778 case TLS_TX_ZEROCOPY_RO: 779 lock_sock(sk); 780 rc = do_tls_setsockopt_tx_zc(sk, optval, optlen); 781 release_sock(sk); 782 break; 783 case TLS_RX_EXPECT_NO_PAD: 784 rc = do_tls_setsockopt_no_pad(sk, optval, optlen); 785 break; 786 default: 787 rc = -ENOPROTOOPT; 788 break; 789 } 790 return rc; 791 } 792 793 static int tls_setsockopt(struct sock *sk, int level, int optname, 794 sockptr_t optval, unsigned int optlen) 795 { 796 struct tls_context *ctx = tls_get_ctx(sk); 797 798 if (level != SOL_TLS) 799 return ctx->sk_proto->setsockopt(sk, level, optname, optval, 800 optlen); 801 802 return do_tls_setsockopt(sk, optname, optval, optlen); 803 } 804 805 struct tls_context *tls_ctx_create(struct sock *sk) 806 { 807 struct inet_connection_sock *icsk = inet_csk(sk); 808 struct tls_context *ctx; 809 810 ctx = kzalloc(sizeof(*ctx), GFP_ATOMIC); 811 if (!ctx) 812 return NULL; 813 814 mutex_init(&ctx->tx_lock); 815 rcu_assign_pointer(icsk->icsk_ulp_data, ctx); 816 ctx->sk_proto = READ_ONCE(sk->sk_prot); 817 ctx->sk = sk; 818 return ctx; 819 } 820 821 static void build_proto_ops(struct proto_ops ops[TLS_NUM_CONFIG][TLS_NUM_CONFIG], 822 const struct proto_ops *base) 823 { 824 ops[TLS_BASE][TLS_BASE] = *base; 825 826 ops[TLS_SW ][TLS_BASE] = ops[TLS_BASE][TLS_BASE]; 827 ops[TLS_SW ][TLS_BASE].splice_eof = tls_sw_splice_eof; 828 829 ops[TLS_BASE][TLS_SW ] = ops[TLS_BASE][TLS_BASE]; 830 ops[TLS_BASE][TLS_SW ].splice_read = 
tls_sw_splice_read; 831 ops[TLS_BASE][TLS_SW ].poll = tls_sk_poll; 832 ops[TLS_BASE][TLS_SW ].read_sock = tls_sw_read_sock; 833 834 ops[TLS_SW ][TLS_SW ] = ops[TLS_SW ][TLS_BASE]; 835 ops[TLS_SW ][TLS_SW ].splice_read = tls_sw_splice_read; 836 ops[TLS_SW ][TLS_SW ].poll = tls_sk_poll; 837 ops[TLS_SW ][TLS_SW ].read_sock = tls_sw_read_sock; 838 839 #ifdef CONFIG_TLS_DEVICE 840 ops[TLS_HW ][TLS_BASE] = ops[TLS_BASE][TLS_BASE]; 841 842 ops[TLS_HW ][TLS_SW ] = ops[TLS_BASE][TLS_SW ]; 843 844 ops[TLS_BASE][TLS_HW ] = ops[TLS_BASE][TLS_SW ]; 845 846 ops[TLS_SW ][TLS_HW ] = ops[TLS_SW ][TLS_SW ]; 847 848 ops[TLS_HW ][TLS_HW ] = ops[TLS_HW ][TLS_SW ]; 849 #endif 850 #ifdef CONFIG_TLS_TOE 851 ops[TLS_HW_RECORD][TLS_HW_RECORD] = *base; 852 #endif 853 } 854 855 static void tls_build_proto(struct sock *sk) 856 { 857 int ip_ver = sk->sk_family == AF_INET6 ? TLSV6 : TLSV4; 858 struct proto *prot = READ_ONCE(sk->sk_prot); 859 860 /* Build IPv6 TLS whenever the address of tcpv6 _prot changes */ 861 if (ip_ver == TLSV6 && 862 unlikely(prot != smp_load_acquire(&saved_tcpv6_prot))) { 863 mutex_lock(&tcpv6_prot_mutex); 864 if (likely(prot != saved_tcpv6_prot)) { 865 build_protos(tls_prots[TLSV6], prot); 866 build_proto_ops(tls_proto_ops[TLSV6], 867 sk->sk_socket->ops); 868 smp_store_release(&saved_tcpv6_prot, prot); 869 } 870 mutex_unlock(&tcpv6_prot_mutex); 871 } 872 873 if (ip_ver == TLSV4 && 874 unlikely(prot != smp_load_acquire(&saved_tcpv4_prot))) { 875 mutex_lock(&tcpv4_prot_mutex); 876 if (likely(prot != saved_tcpv4_prot)) { 877 build_protos(tls_prots[TLSV4], prot); 878 build_proto_ops(tls_proto_ops[TLSV4], 879 sk->sk_socket->ops); 880 smp_store_release(&saved_tcpv4_prot, prot); 881 } 882 mutex_unlock(&tcpv4_prot_mutex); 883 } 884 } 885 886 static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG], 887 const struct proto *base) 888 { 889 prot[TLS_BASE][TLS_BASE] = *base; 890 prot[TLS_BASE][TLS_BASE].setsockopt = tls_setsockopt; 891 
prot[TLS_BASE][TLS_BASE].getsockopt = tls_getsockopt; 892 prot[TLS_BASE][TLS_BASE].close = tls_sk_proto_close; 893 894 prot[TLS_SW][TLS_BASE] = prot[TLS_BASE][TLS_BASE]; 895 prot[TLS_SW][TLS_BASE].sendmsg = tls_sw_sendmsg; 896 prot[TLS_SW][TLS_BASE].splice_eof = tls_sw_splice_eof; 897 898 prot[TLS_BASE][TLS_SW] = prot[TLS_BASE][TLS_BASE]; 899 prot[TLS_BASE][TLS_SW].recvmsg = tls_sw_recvmsg; 900 prot[TLS_BASE][TLS_SW].sock_is_readable = tls_sw_sock_is_readable; 901 prot[TLS_BASE][TLS_SW].close = tls_sk_proto_close; 902 903 prot[TLS_SW][TLS_SW] = prot[TLS_SW][TLS_BASE]; 904 prot[TLS_SW][TLS_SW].recvmsg = tls_sw_recvmsg; 905 prot[TLS_SW][TLS_SW].sock_is_readable = tls_sw_sock_is_readable; 906 prot[TLS_SW][TLS_SW].close = tls_sk_proto_close; 907 908 #ifdef CONFIG_TLS_DEVICE 909 prot[TLS_HW][TLS_BASE] = prot[TLS_BASE][TLS_BASE]; 910 prot[TLS_HW][TLS_BASE].sendmsg = tls_device_sendmsg; 911 prot[TLS_HW][TLS_BASE].splice_eof = tls_device_splice_eof; 912 913 prot[TLS_HW][TLS_SW] = prot[TLS_BASE][TLS_SW]; 914 prot[TLS_HW][TLS_SW].sendmsg = tls_device_sendmsg; 915 prot[TLS_HW][TLS_SW].splice_eof = tls_device_splice_eof; 916 917 prot[TLS_BASE][TLS_HW] = prot[TLS_BASE][TLS_SW]; 918 919 prot[TLS_SW][TLS_HW] = prot[TLS_SW][TLS_SW]; 920 921 prot[TLS_HW][TLS_HW] = prot[TLS_HW][TLS_SW]; 922 #endif 923 #ifdef CONFIG_TLS_TOE 924 prot[TLS_HW_RECORD][TLS_HW_RECORD] = *base; 925 prot[TLS_HW_RECORD][TLS_HW_RECORD].hash = tls_toe_hash; 926 prot[TLS_HW_RECORD][TLS_HW_RECORD].unhash = tls_toe_unhash; 927 #endif 928 } 929 930 static int tls_init(struct sock *sk) 931 { 932 struct tls_context *ctx; 933 int rc = 0; 934 935 tls_build_proto(sk); 936 937 #ifdef CONFIG_TLS_TOE 938 if (tls_toe_bypass(sk)) 939 return 0; 940 #endif 941 942 /* The TLS ulp is currently supported only for TCP sockets 943 * in ESTABLISHED state. 944 * Supporting sockets in LISTEN state will require us 945 * to modify the accept implementation to clone rather then 946 * share the ulp context. 
947 */ 948 if (sk->sk_state != TCP_ESTABLISHED) 949 return -ENOTCONN; 950 951 /* allocate tls context */ 952 write_lock_bh(&sk->sk_callback_lock); 953 ctx = tls_ctx_create(sk); 954 if (!ctx) { 955 rc = -ENOMEM; 956 goto out; 957 } 958 959 ctx->tx_conf = TLS_BASE; 960 ctx->rx_conf = TLS_BASE; 961 update_sk_prot(sk, ctx); 962 out: 963 write_unlock_bh(&sk->sk_callback_lock); 964 return rc; 965 } 966 967 static void tls_update(struct sock *sk, struct proto *p, 968 void (*write_space)(struct sock *sk)) 969 { 970 struct tls_context *ctx; 971 972 WARN_ON_ONCE(sk->sk_prot == p); 973 974 ctx = tls_get_ctx(sk); 975 if (likely(ctx)) { 976 ctx->sk_write_space = write_space; 977 ctx->sk_proto = p; 978 } else { 979 /* Pairs with lockless read in sk_clone_lock(). */ 980 WRITE_ONCE(sk->sk_prot, p); 981 sk->sk_write_space = write_space; 982 } 983 } 984 985 static u16 tls_user_config(struct tls_context *ctx, bool tx) 986 { 987 u16 config = tx ? ctx->tx_conf : ctx->rx_conf; 988 989 switch (config) { 990 case TLS_BASE: 991 return TLS_CONF_BASE; 992 case TLS_SW: 993 return TLS_CONF_SW; 994 case TLS_HW: 995 return TLS_CONF_HW; 996 case TLS_HW_RECORD: 997 return TLS_CONF_HW_RECORD; 998 } 999 return 0; 1000 } 1001 1002 static int tls_get_info(const struct sock *sk, struct sk_buff *skb) 1003 { 1004 u16 version, cipher_type; 1005 struct tls_context *ctx; 1006 struct nlattr *start; 1007 int err; 1008 1009 start = nla_nest_start_noflag(skb, INET_ULP_INFO_TLS); 1010 if (!start) 1011 return -EMSGSIZE; 1012 1013 rcu_read_lock(); 1014 ctx = rcu_dereference(inet_csk(sk)->icsk_ulp_data); 1015 if (!ctx) { 1016 err = 0; 1017 goto nla_failure; 1018 } 1019 version = ctx->prot_info.version; 1020 if (version) { 1021 err = nla_put_u16(skb, TLS_INFO_VERSION, version); 1022 if (err) 1023 goto nla_failure; 1024 } 1025 cipher_type = ctx->prot_info.cipher_type; 1026 if (cipher_type) { 1027 err = nla_put_u16(skb, TLS_INFO_CIPHER, cipher_type); 1028 if (err) 1029 goto nla_failure; 1030 } 1031 err = 
nla_put_u16(skb, TLS_INFO_TXCONF, tls_user_config(ctx, true)); 1032 if (err) 1033 goto nla_failure; 1034 1035 err = nla_put_u16(skb, TLS_INFO_RXCONF, tls_user_config(ctx, false)); 1036 if (err) 1037 goto nla_failure; 1038 1039 if (ctx->tx_conf == TLS_HW && ctx->zerocopy_sendfile) { 1040 err = nla_put_flag(skb, TLS_INFO_ZC_RO_TX); 1041 if (err) 1042 goto nla_failure; 1043 } 1044 if (ctx->rx_no_pad) { 1045 err = nla_put_flag(skb, TLS_INFO_RX_NO_PAD); 1046 if (err) 1047 goto nla_failure; 1048 } 1049 1050 rcu_read_unlock(); 1051 nla_nest_end(skb, start); 1052 return 0; 1053 1054 nla_failure: 1055 rcu_read_unlock(); 1056 nla_nest_cancel(skb, start); 1057 return err; 1058 } 1059 1060 static size_t tls_get_info_size(const struct sock *sk) 1061 { 1062 size_t size = 0; 1063 1064 size += nla_total_size(0) + /* INET_ULP_INFO_TLS */ 1065 nla_total_size(sizeof(u16)) + /* TLS_INFO_VERSION */ 1066 nla_total_size(sizeof(u16)) + /* TLS_INFO_CIPHER */ 1067 nla_total_size(sizeof(u16)) + /* TLS_INFO_RXCONF */ 1068 nla_total_size(sizeof(u16)) + /* TLS_INFO_TXCONF */ 1069 nla_total_size(0) + /* TLS_INFO_ZC_RO_TX */ 1070 nla_total_size(0) + /* TLS_INFO_RX_NO_PAD */ 1071 0; 1072 1073 return size; 1074 } 1075 1076 static int __net_init tls_init_net(struct net *net) 1077 { 1078 int err; 1079 1080 net->mib.tls_statistics = alloc_percpu(struct linux_tls_mib); 1081 if (!net->mib.tls_statistics) 1082 return -ENOMEM; 1083 1084 err = tls_proc_init(net); 1085 if (err) 1086 goto err_free_stats; 1087 1088 return 0; 1089 err_free_stats: 1090 free_percpu(net->mib.tls_statistics); 1091 return err; 1092 } 1093 1094 static void __net_exit tls_exit_net(struct net *net) 1095 { 1096 tls_proc_fini(net); 1097 free_percpu(net->mib.tls_statistics); 1098 } 1099 1100 static struct pernet_operations tls_proc_ops = { 1101 .init = tls_init_net, 1102 .exit = tls_exit_net, 1103 }; 1104 1105 static struct tcp_ulp_ops tcp_tls_ulp_ops __read_mostly = { 1106 .name = "tls", 1107 .owner = THIS_MODULE, 1108 .init = tls_init, 
1109 .update = tls_update, 1110 .get_info = tls_get_info, 1111 .get_info_size = tls_get_info_size, 1112 }; 1113 1114 static int __init tls_register(void) 1115 { 1116 int err; 1117 1118 err = register_pernet_subsys(&tls_proc_ops); 1119 if (err) 1120 return err; 1121 1122 err = tls_strp_dev_init(); 1123 if (err) 1124 goto err_pernet; 1125 1126 err = tls_device_init(); 1127 if (err) 1128 goto err_strp; 1129 1130 tcp_register_ulp(&tcp_tls_ulp_ops); 1131 1132 return 0; 1133 err_strp: 1134 tls_strp_dev_exit(); 1135 err_pernet: 1136 unregister_pernet_subsys(&tls_proc_ops); 1137 return err; 1138 } 1139 1140 static void __exit tls_unregister(void) 1141 { 1142 tcp_unregister_ulp(&tcp_tls_ulp_ops); 1143 tls_strp_dev_exit(); 1144 tls_device_cleanup(); 1145 unregister_pernet_subsys(&tls_proc_ops); 1146 } 1147 1148 module_init(tls_register); 1149 module_exit(tls_unregister); 1150