1 /* 2 * Copyright (c) 2016-2017, Mellanox Technologies. All rights reserved. 3 * Copyright (c) 2016-2017, Dave Watson <davejwatson@fb.com>. All rights reserved. 4 * 5 * This software is available to you under a choice of one of two 6 * licenses. You may choose to be licensed under the terms of the GNU 7 * General Public License (GPL) Version 2, available from the file 8 * COPYING in the main directory of this source tree, or the 9 * OpenIB.org BSD license below: 10 * 11 * Redistribution and use in source and binary forms, with or 12 * without modification, are permitted provided that the following 13 * conditions are met: 14 * 15 * - Redistributions of source code must retain the above 16 * copyright notice, this list of conditions and the following 17 * disclaimer. 18 * 19 * - Redistributions in binary form must reproduce the above 20 * copyright notice, this list of conditions and the following 21 * disclaimer in the documentation and/or other materials 22 * provided with the distribution. 23 * 24 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 25 * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 26 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 27 * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS 28 * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN 29 * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 30 * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 31 * SOFTWARE. 32 */ 33 34 #include <linux/module.h> 35 36 #include <net/tcp.h> 37 #include <net/inet_common.h> 38 #include <linux/highmem.h> 39 #include <linux/netdevice.h> 40 #include <linux/sched/signal.h> 41 #include <linux/inetdevice.h> 42 #include <linux/inet_diag.h> 43 44 #include <net/snmp.h> 45 #include <net/tls.h> 46 #include "tls.h" 47 48 MODULE_AUTHOR("Mellanox Technologies"); 49 MODULE_DESCRIPTION("Transport Layer Security Support"); 50 MODULE_LICENSE("Dual BSD/GPL"); 51 MODULE_ALIAS_TCP_ULP("tls"); 52 53 enum { 54 TLSV4, 55 TLSV6, 56 TLS_NUM_PROTS, 57 }; 58 59 #define CHECK_CIPHER_DESC(cipher,ci) \ 60 static_assert(cipher ## _IV_SIZE <= TLS_MAX_IV_SIZE); \ 61 static_assert(cipher ## _SALT_SIZE <= TLS_MAX_SALT_SIZE); \ 62 static_assert(cipher ## _REC_SEQ_SIZE <= TLS_MAX_REC_SEQ_SIZE); \ 63 static_assert(cipher ## _TAG_SIZE == TLS_TAG_SIZE); \ 64 static_assert(sizeof_field(struct ci, iv) == cipher ## _IV_SIZE); \ 65 static_assert(sizeof_field(struct ci, key) == cipher ## _KEY_SIZE); \ 66 static_assert(sizeof_field(struct ci, salt) == cipher ## _SALT_SIZE); \ 67 static_assert(sizeof_field(struct ci, rec_seq) == cipher ## _REC_SEQ_SIZE); 68 69 #define __CIPHER_DESC(ci) \ 70 .iv_offset = offsetof(struct ci, iv), \ 71 .key_offset = offsetof(struct ci, key), \ 72 .salt_offset = offsetof(struct ci, salt), \ 73 .rec_seq_offset = offsetof(struct ci, rec_seq), \ 74 .crypto_info = sizeof(struct ci) 75 76 #define CIPHER_DESC(cipher,ci,algname,_offloadable) [cipher - TLS_CIPHER_MIN] = { \ 77 .nonce = cipher ## _IV_SIZE, \ 78 .iv = cipher ## _IV_SIZE, \ 79 .key = cipher ## _KEY_SIZE, \ 80 .salt = cipher ## _SALT_SIZE, \ 81 .tag = cipher ## _TAG_SIZE, \ 82 .rec_seq = cipher ## _REC_SEQ_SIZE, \ 83 .cipher_name = algname, \ 84 .offloadable = _offloadable, \ 85 __CIPHER_DESC(ci), \ 86 } 87 88 #define CIPHER_DESC_NONCE0(cipher,ci,algname,_offloadable) [cipher - TLS_CIPHER_MIN] = { \ 89 .nonce = 0, \ 90 .iv = cipher ## _IV_SIZE, \ 91 .key = cipher ## _KEY_SIZE, \ 92 .salt = cipher ## _SALT_SIZE, \ 93 .tag = cipher ## _TAG_SIZE, \ 94 .rec_seq = cipher ## _REC_SEQ_SIZE, \ 95 .cipher_name = algname, \ 96 .offloadable = _offloadable, \ 97 __CIPHER_DESC(ci), \ 98 } 99 100 const struct tls_cipher_desc tls_cipher_desc[TLS_CIPHER_MAX + 1 - TLS_CIPHER_MIN] = { 101 CIPHER_DESC(TLS_CIPHER_AES_GCM_128, tls12_crypto_info_aes_gcm_128, "gcm(aes)", true), 102 CIPHER_DESC(TLS_CIPHER_AES_GCM_256, tls12_crypto_info_aes_gcm_256, "gcm(aes)", true), 103 CIPHER_DESC(TLS_CIPHER_AES_CCM_128, tls12_crypto_info_aes_ccm_128, "ccm(aes)", false), 104 CIPHER_DESC_NONCE0(TLS_CIPHER_CHACHA20_POLY1305, tls12_crypto_info_chacha20_poly1305, "rfc7539(chacha20,poly1305)", false), 105 CIPHER_DESC(TLS_CIPHER_SM4_GCM, tls12_crypto_info_sm4_gcm, "gcm(sm4)", false), 106 CIPHER_DESC(TLS_CIPHER_SM4_CCM, tls12_crypto_info_sm4_ccm, "ccm(sm4)", false), 107 CIPHER_DESC(TLS_CIPHER_ARIA_GCM_128, tls12_crypto_info_aria_gcm_128, "gcm(aria)", false), 108 CIPHER_DESC(TLS_CIPHER_ARIA_GCM_256, tls12_crypto_info_aria_gcm_256, "gcm(aria)", false), 109 }; 110 111 CHECK_CIPHER_DESC(TLS_CIPHER_AES_GCM_128, tls12_crypto_info_aes_gcm_128); 112 CHECK_CIPHER_DESC(TLS_CIPHER_AES_GCM_256, tls12_crypto_info_aes_gcm_256); 113 CHECK_CIPHER_DESC(TLS_CIPHER_AES_CCM_128, tls12_crypto_info_aes_ccm_128); 114 CHECK_CIPHER_DESC(TLS_CIPHER_CHACHA20_POLY1305, tls12_crypto_info_chacha20_poly1305); 115 CHECK_CIPHER_DESC(TLS_CIPHER_SM4_GCM, tls12_crypto_info_sm4_gcm); 116 CHECK_CIPHER_DESC(TLS_CIPHER_SM4_CCM, tls12_crypto_info_sm4_ccm); 117 CHECK_CIPHER_DESC(TLS_CIPHER_ARIA_GCM_128, tls12_crypto_info_aria_gcm_128); 118 CHECK_CIPHER_DESC(TLS_CIPHER_ARIA_GCM_256, tls12_crypto_info_aria_gcm_256); 119 120 static const struct proto *saved_tcpv6_prot; 121 static DEFINE_MUTEX(tcpv6_prot_mutex); 122 static const struct proto *saved_tcpv4_prot; 123 static DEFINE_MUTEX(tcpv4_prot_mutex); 124 static struct proto tls_prots[TLS_NUM_PROTS][TLS_NUM_CONFIG][TLS_NUM_CONFIG]; 125 static struct proto_ops tls_proto_ops[TLS_NUM_PROTS][TLS_NUM_CONFIG][TLS_NUM_CONFIG]; 126 static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG], 127 const struct proto *base); 128 129 void update_sk_prot(struct sock *sk, struct tls_context *ctx) 130 { 131 int ip_ver = sk->sk_family == AF_INET6 ? TLSV6 : TLSV4; 132 133 WRITE_ONCE(sk->sk_prot, 134 &tls_prots[ip_ver][ctx->tx_conf][ctx->rx_conf]); 135 WRITE_ONCE(sk->sk_socket->ops, 136 &tls_proto_ops[ip_ver][ctx->tx_conf][ctx->rx_conf]); 137 } 138 139 int wait_on_pending_writer(struct sock *sk, long *timeo) 140 { 141 DEFINE_WAIT_FUNC(wait, woken_wake_function); 142 int ret, rc = 0; 143 144 add_wait_queue(sk_sleep(sk), &wait); 145 while (1) { 146 if (!*timeo) { 147 rc = -EAGAIN; 148 break; 149 } 150 151 if (signal_pending(current)) { 152 rc = sock_intr_errno(*timeo); 153 break; 154 } 155 156 ret = sk_wait_event(sk, timeo, 157 !READ_ONCE(sk->sk_write_pending), &wait); 158 if (ret) { 159 if (ret < 0) 160 rc = ret; 161 break; 162 } 163 } 164 remove_wait_queue(sk_sleep(sk), &wait); 165 return rc; 166 } 167 168 int tls_push_sg(struct sock *sk, 169 struct tls_context *ctx, 170 struct scatterlist *sg, 171 u16 first_offset, 172 int flags) 173 { 174 struct bio_vec bvec; 175 struct msghdr msg = { 176 .msg_flags = MSG_SPLICE_PAGES | flags, 177 }; 178 int ret = 0; 179 struct page *p; 180 size_t size; 181 int offset = first_offset; 182 183 size = sg->length - offset; 184 offset += sg->offset; 185 186 ctx->splicing_pages = true; 187 while (1) { 188 /* is sending application-limited? */ 189 tcp_rate_check_app_limited(sk); 190 p = sg_page(sg); 191 retry: 192 bvec_set_page(&bvec, p, size, offset); 193 iov_iter_bvec(&msg.msg_iter, ITER_SOURCE, &bvec, 1, size); 194 195 ret = tcp_sendmsg_locked(sk, &msg, size); 196 197 if (ret != size) { 198 if (ret > 0) { 199 offset += ret; 200 size -= ret; 201 goto retry; 202 } 203 204 offset -= sg->offset; 205 ctx->partially_sent_offset = offset; 206 ctx->partially_sent_record = (void *)sg; 207 ctx->splicing_pages = false; 208 return ret; 209 } 210 211 put_page(p); 212 sk_mem_uncharge(sk, sg->length); 213 sg = sg_next(sg); 214 if (!sg) 215 break; 216 217 offset = sg->offset; 218 size = sg->length; 219 } 220 221 ctx->splicing_pages = false; 222 223 return 0; 224 } 225 226 static int tls_handle_open_record(struct sock *sk, int flags) 227 { 228 struct tls_context *ctx = tls_get_ctx(sk); 229 230 if (tls_is_pending_open_record(ctx)) 231 return ctx->push_pending_record(sk, flags); 232 233 return 0; 234 } 235 236 int tls_process_cmsg(struct sock *sk, struct msghdr *msg, 237 unsigned char *record_type) 238 { 239 struct cmsghdr *cmsg; 240 int rc = -EINVAL; 241 242 for_each_cmsghdr(cmsg, msg) { 243 if (!CMSG_OK(msg, cmsg)) 244 return -EINVAL; 245 if (cmsg->cmsg_level != SOL_TLS) 246 continue; 247 248 switch (cmsg->cmsg_type) { 249 case TLS_SET_RECORD_TYPE: 250 if (cmsg->cmsg_len < CMSG_LEN(sizeof(*record_type))) 251 return -EINVAL; 252 253 if (msg->msg_flags & MSG_MORE) 254 return -EINVAL; 255 256 *record_type = *(unsigned char *)CMSG_DATA(cmsg); 257 258 rc = tls_handle_open_record(sk, msg->msg_flags); 259 break; 260 default: 261 return -EINVAL; 262 } 263 } 264 265 return rc; 266 } 267 268 int tls_push_partial_record(struct sock *sk, struct tls_context *ctx, 269 int flags) 270 { 271 struct scatterlist *sg; 272 u16 offset; 273 274 sg = ctx->partially_sent_record; 275 offset = ctx->partially_sent_offset; 276 277 ctx->partially_sent_record = NULL; 278 return tls_push_sg(sk, ctx, sg, offset, flags); 279 } 280 281 void tls_free_partial_record(struct sock *sk, struct tls_context *ctx) 282 { 283 struct scatterlist *sg; 284 285 for (sg = ctx->partially_sent_record; sg; sg = sg_next(sg)) { 286 put_page(sg_page(sg)); 287 sk_mem_uncharge(sk, sg->length); 288 } 289 ctx->partially_sent_record = NULL; 290 } 291 292 static void tls_write_space(struct sock *sk) 293 { 294 struct tls_context *ctx = tls_get_ctx(sk); 295 296 /* If splicing_pages call lower protocol write space handler 297 * to ensure we wake up any waiting operations there. For example 298 * if splicing pages where to call sk_wait_event. 299 */ 300 if (ctx->splicing_pages) { 301 ctx->sk_write_space(sk); 302 return; 303 } 304 305 #ifdef CONFIG_TLS_DEVICE 306 if (ctx->tx_conf == TLS_HW) 307 tls_device_write_space(sk, ctx); 308 else 309 #endif 310 tls_sw_write_space(sk, ctx); 311 312 ctx->sk_write_space(sk); 313 } 314 315 /** 316 * tls_ctx_free() - free TLS ULP context 317 * @sk: socket to with @ctx is attached 318 * @ctx: TLS context structure 319 * 320 * Free TLS context. If @sk is %NULL caller guarantees that the socket 321 * to which @ctx was attached has no outstanding references. 322 */ 323 void tls_ctx_free(struct sock *sk, struct tls_context *ctx) 324 { 325 if (!ctx) 326 return; 327 328 memzero_explicit(&ctx->crypto_send, sizeof(ctx->crypto_send)); 329 memzero_explicit(&ctx->crypto_recv, sizeof(ctx->crypto_recv)); 330 mutex_destroy(&ctx->tx_lock); 331 332 if (sk) 333 kfree_rcu(ctx, rcu); 334 else 335 kfree(ctx); 336 } 337 338 static void tls_sk_proto_cleanup(struct sock *sk, 339 struct tls_context *ctx, long timeo) 340 { 341 if (unlikely(sk->sk_write_pending) && 342 !wait_on_pending_writer(sk, &timeo)) 343 tls_handle_open_record(sk, 0); 344 345 /* We need these for tls_sw_fallback handling of other packets */ 346 if (ctx->tx_conf == TLS_SW) { 347 tls_sw_release_resources_tx(sk); 348 TLS_DEC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXSW); 349 } else if (ctx->tx_conf == TLS_HW) { 350 tls_device_free_resources_tx(sk); 351 TLS_DEC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXDEVICE); 352 } 353 354 if (ctx->rx_conf == TLS_SW) { 355 tls_sw_release_resources_rx(sk); 356 TLS_DEC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXSW); 357 } else if (ctx->rx_conf == TLS_HW) { 358 tls_device_offload_cleanup_rx(sk); 359 TLS_DEC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXDEVICE); 360 } 361 } 362 363 static void tls_sk_proto_close(struct sock *sk, long timeout) 364 { 365 struct inet_connection_sock *icsk = inet_csk(sk); 366 struct tls_context *ctx = tls_get_ctx(sk); 367 long timeo = sock_sndtimeo(sk, 0); 368 bool free_ctx; 369 370 if (ctx->tx_conf == TLS_SW) 371 tls_sw_cancel_work_tx(ctx); 372 373 lock_sock(sk); 374 free_ctx = ctx->tx_conf != TLS_HW && ctx->rx_conf != TLS_HW; 375 376 if (ctx->tx_conf != TLS_BASE || ctx->rx_conf != TLS_BASE) 377 tls_sk_proto_cleanup(sk, ctx, timeo); 378 379 write_lock_bh(&sk->sk_callback_lock); 380 if (free_ctx) 381 rcu_assign_pointer(icsk->icsk_ulp_data, NULL); 382 WRITE_ONCE(sk->sk_prot, ctx->sk_proto); 383 if (sk->sk_write_space == tls_write_space) 384 sk->sk_write_space = ctx->sk_write_space; 385 write_unlock_bh(&sk->sk_callback_lock); 386 release_sock(sk); 387 if (ctx->tx_conf == TLS_SW) 388 tls_sw_free_ctx_tx(ctx); 389 if (ctx->rx_conf == TLS_SW || ctx->rx_conf == TLS_HW) 390 tls_sw_strparser_done(ctx); 391 if (ctx->rx_conf == TLS_SW) 392 tls_sw_free_ctx_rx(ctx); 393 ctx->sk_proto->close(sk, timeout); 394 395 if (free_ctx) 396 tls_ctx_free(sk, ctx); 397 } 398 399 static __poll_t tls_sk_poll(struct file *file, struct socket *sock, 400 struct poll_table_struct *wait) 401 { 402 struct tls_sw_context_rx *ctx; 403 struct tls_context *tls_ctx; 404 struct sock *sk = sock->sk; 405 __poll_t mask = 0; 406 u8 shutdown; 407 int state; 408 409 mask = tcp_poll(file, sock, wait); 410 411 state = inet_sk_state_load(sk); 412 shutdown = READ_ONCE(sk->sk_shutdown); 413 if (unlikely(state != TCP_ESTABLISHED || shutdown & RCV_SHUTDOWN)) 414 return mask; 415 416 tls_ctx = tls_get_ctx(sk); 417 ctx = tls_sw_ctx_rx(tls_ctx); 418 419 if ((skb_queue_empty_lockless(&ctx->rx_list) && 420 !tls_strp_msg_ready(ctx)) || 421 READ_ONCE(ctx->key_update_pending)) 422 mask &= ~(EPOLLIN | EPOLLRDNORM); 423 424 return mask; 425 } 426 427 static int do_tls_getsockopt_conf(struct sock *sk, char __user *optval, 428 int __user *optlen, int tx) 429 { 430 int rc = 0; 431 const struct tls_cipher_desc *cipher_desc; 432 struct tls_context *ctx = tls_get_ctx(sk); 433 struct tls_crypto_info *crypto_info; 434 struct cipher_context *cctx; 435 int len; 436 437 if (get_user(len, optlen)) 438 return -EFAULT; 439 440 if (!optval || (len < sizeof(*crypto_info))) { 441 rc = -EINVAL; 442 goto out; 443 } 444 445 if (!ctx) { 446 rc = -EBUSY; 447 goto out; 448 } 449 450 /* get user crypto info */ 451 if (tx) { 452 crypto_info = &ctx->crypto_send.info; 453 cctx = &ctx->tx; 454 } else { 455 crypto_info = &ctx->crypto_recv.info; 456 cctx = &ctx->rx; 457 } 458 459 if (!TLS_CRYPTO_INFO_READY(crypto_info)) { 460 rc = -EBUSY; 461 goto out; 462 } 463 464 if (len == sizeof(*crypto_info)) { 465 if (copy_to_user(optval, crypto_info, sizeof(*crypto_info))) 466 rc = -EFAULT; 467 goto out; 468 } 469 470 cipher_desc = get_cipher_desc(crypto_info->cipher_type); 471 if (!cipher_desc || len != cipher_desc->crypto_info) { 472 rc = -EINVAL; 473 goto out; 474 } 475 476 memcpy(crypto_info_iv(crypto_info, cipher_desc), 477 cctx->iv + cipher_desc->salt, cipher_desc->iv); 478 memcpy(crypto_info_rec_seq(crypto_info, cipher_desc), 479 cctx->rec_seq, cipher_desc->rec_seq); 480 481 if (copy_to_user(optval, crypto_info, cipher_desc->crypto_info)) 482 rc = -EFAULT; 483 484 out: 485 return rc; 486 } 487 488 static int do_tls_getsockopt_tx_zc(struct sock *sk, char __user *optval, 489 int __user *optlen) 490 { 491 struct tls_context *ctx = tls_get_ctx(sk); 492 unsigned int value; 493 int len; 494 495 if (get_user(len, optlen)) 496 return -EFAULT; 497 498 if (len != sizeof(value)) 499 return -EINVAL; 500 501 value = ctx->zerocopy_sendfile; 502 if (copy_to_user(optval, &value, sizeof(value))) 503 return -EFAULT; 504 505 return 0; 506 } 507 508 static int do_tls_getsockopt_no_pad(struct sock *sk, char __user *optval, 509 int __user *optlen) 510 { 511 struct tls_context *ctx = tls_get_ctx(sk); 512 int value, len; 513 514 if (ctx->prot_info.version != TLS_1_3_VERSION) 515 return -EINVAL; 516 517 if (get_user(len, optlen)) 518 return -EFAULT; 519 if (len < sizeof(value)) 520 return -EINVAL; 521 522 value = -EINVAL; 523 if (ctx->rx_conf == TLS_SW || ctx->rx_conf == TLS_HW) 524 value = ctx->rx_no_pad; 525 if (value < 0) 526 return value; 527 528 if (put_user(sizeof(value), optlen)) 529 return -EFAULT; 530 if (copy_to_user(optval, &value, sizeof(value))) 531 return -EFAULT; 532 533 return 0; 534 } 535 536 static int do_tls_getsockopt_tx_payload_len(struct sock *sk, char __user *optval, 537 int __user *optlen) 538 { 539 struct tls_context *ctx = tls_get_ctx(sk); 540 u16 payload_len = ctx->tx_max_payload_len; 541 int len; 542 543 if (get_user(len, optlen)) 544 return -EFAULT; 545 546 if (len < sizeof(payload_len)) 547 return -EINVAL; 548 549 if (put_user(sizeof(payload_len), optlen)) 550 return -EFAULT; 551 552 if (copy_to_user(optval, &payload_len, sizeof(payload_len))) 553 return -EFAULT; 554 555 return 0; 556 } 557 558 static int do_tls_getsockopt(struct sock *sk, int optname, 559 char __user *optval, int __user *optlen) 560 { 561 int rc = 0; 562 563 lock_sock(sk); 564 565 switch (optname) { 566 case TLS_TX: 567 case TLS_RX: 568 rc = do_tls_getsockopt_conf(sk, optval, optlen, 569 optname == TLS_TX); 570 break; 571 case TLS_TX_ZEROCOPY_RO: 572 rc = do_tls_getsockopt_tx_zc(sk, optval, optlen); 573 break; 574 case TLS_RX_EXPECT_NO_PAD: 575 rc = do_tls_getsockopt_no_pad(sk, optval, optlen); 576 break; 577 case TLS_TX_MAX_PAYLOAD_LEN: 578 rc = do_tls_getsockopt_tx_payload_len(sk, optval, optlen); 579 break; 580 default: 581 rc = -ENOPROTOOPT; 582 break; 583 } 584 585 release_sock(sk); 586 587 return rc; 588 } 589 590 static int tls_getsockopt(struct sock *sk, int level, int optname, 591 char __user *optval, int __user *optlen) 592 { 593 struct tls_context *ctx = tls_get_ctx(sk); 594 595 if (level != SOL_TLS) 596 return ctx->sk_proto->getsockopt(sk, level, 597 optname, optval, optlen); 598 599 return do_tls_getsockopt(sk, optname, optval, optlen); 600 } 601 602 static int validate_crypto_info(const struct tls_crypto_info *crypto_info, 603 const struct tls_crypto_info *alt_crypto_info) 604 { 605 if (crypto_info->version != TLS_1_2_VERSION && 606 crypto_info->version != TLS_1_3_VERSION) 607 return -EINVAL; 608 609 switch (crypto_info->cipher_type) { 610 case TLS_CIPHER_ARIA_GCM_128: 611 case TLS_CIPHER_ARIA_GCM_256: 612 if (crypto_info->version != TLS_1_2_VERSION) 613 return -EINVAL; 614 break; 615 } 616 617 /* Ensure that TLS version and ciphers are same in both directions */ 618 if (TLS_CRYPTO_INFO_READY(alt_crypto_info)) { 619 if (alt_crypto_info->version != crypto_info->version || 620 alt_crypto_info->cipher_type != crypto_info->cipher_type) 621 return -EINVAL; 622 } 623 624 return 0; 625 } 626 627 static int do_tls_setsockopt_conf(struct sock *sk, sockptr_t optval, 628 unsigned int optlen, int tx) 629 { 630 struct tls_crypto_info *crypto_info, *alt_crypto_info; 631 struct tls_crypto_info *old_crypto_info = NULL; 632 struct tls_context *ctx = tls_get_ctx(sk); 633 const struct tls_cipher_desc *cipher_desc; 634 union tls_crypto_context *crypto_ctx; 635 union tls_crypto_context tmp = {}; 636 bool update = false; 637 int rc = 0; 638 int conf; 639 640 /* TLS and sockmap are mutually exclusive. A socket already in a 641 * sockmap (i.e. with a psock attached) cannot be upgraded to TLS. 642 * sockmap rejects TLS sockets already (see sk_psock_init()). 643 */ 644 rcu_read_lock(); 645 if (sk_psock(sk)) { 646 rcu_read_unlock(); 647 return -EINVAL; 648 } 649 rcu_read_unlock(); 650 651 if (sockptr_is_null(optval) || (optlen < sizeof(*crypto_info))) 652 return -EINVAL; 653 654 if (tx) { 655 crypto_ctx = &ctx->crypto_send; 656 alt_crypto_info = &ctx->crypto_recv.info; 657 } else { 658 crypto_ctx = &ctx->crypto_recv; 659 alt_crypto_info = &ctx->crypto_send.info; 660 } 661 662 crypto_info = &crypto_ctx->info; 663 664 if (TLS_CRYPTO_INFO_READY(crypto_info)) { 665 /* Currently we only support setting crypto info more 666 * than one time for TLS 1.3 667 */ 668 if (crypto_info->version != TLS_1_3_VERSION) { 669 TLS_INC_STATS(sock_net(sk), tx ? LINUX_MIB_TLSTXREKEYERROR 670 : LINUX_MIB_TLSRXREKEYERROR); 671 return -EBUSY; 672 } 673 674 update = true; 675 old_crypto_info = crypto_info; 676 crypto_info = &tmp.info; 677 crypto_ctx = &tmp; 678 } 679 680 rc = copy_from_sockptr(crypto_info, optval, sizeof(*crypto_info)); 681 if (rc) { 682 rc = -EFAULT; 683 goto err_crypto_info; 684 } 685 686 if (update) { 687 /* Ensure that TLS version and ciphers are not modified */ 688 if (crypto_info->version != old_crypto_info->version || 689 crypto_info->cipher_type != old_crypto_info->cipher_type) 690 rc = -EINVAL; 691 } else { 692 rc = validate_crypto_info(crypto_info, alt_crypto_info); 693 } 694 if (rc) 695 goto err_crypto_info; 696 697 cipher_desc = get_cipher_desc(crypto_info->cipher_type); 698 if (!cipher_desc) { 699 rc = -EINVAL; 700 goto err_crypto_info; 701 } 702 703 if (optlen != cipher_desc->crypto_info) { 704 rc = -EINVAL; 705 goto err_crypto_info; 706 } 707 708 rc = copy_from_sockptr_offset(crypto_info + 1, optval, 709 sizeof(*crypto_info), 710 optlen - sizeof(*crypto_info)); 711 if (rc) { 712 rc = -EFAULT; 713 goto err_crypto_info; 714 } 715 716 if (tx) { 717 rc = tls_set_device_offload(sk); 718 conf = TLS_HW; 719 if (!rc) { 720 TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSTXDEVICE); 721 TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXDEVICE); 722 } else { 723 rc = tls_set_sw_offload(sk, 1, 724 update ? crypto_info : NULL); 725 if (rc) 726 goto err_crypto_info; 727 728 if (update) { 729 TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSTXREKEYOK); 730 } else { 731 TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSTXSW); 732 TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXSW); 733 } 734 conf = TLS_SW; 735 } 736 } else { 737 rc = tls_set_device_offload_rx(sk, ctx); 738 conf = TLS_HW; 739 if (!rc) { 740 TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXDEVICE); 741 TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXDEVICE); 742 } else { 743 rc = tls_set_sw_offload(sk, 0, 744 update ? crypto_info : NULL); 745 if (rc) 746 goto err_crypto_info; 747 748 if (update) { 749 TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXREKEYOK); 750 } else { 751 TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXSW); 752 TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXSW); 753 } 754 conf = TLS_SW; 755 } 756 if (!update) 757 tls_sw_strparser_arm(sk, ctx); 758 } 759 760 if (tx) 761 ctx->tx_conf = conf; 762 else 763 ctx->rx_conf = conf; 764 update_sk_prot(sk, ctx); 765 766 if (update) 767 return 0; 768 769 if (tx) { 770 ctx->sk_write_space = sk->sk_write_space; 771 sk->sk_write_space = tls_write_space; 772 } else { 773 struct tls_sw_context_rx *rx_ctx = tls_sw_ctx_rx(ctx); 774 775 tls_strp_check_rcv(&rx_ctx->strp, true); 776 } 777 return 0; 778 779 err_crypto_info: 780 if (update) { 781 TLS_INC_STATS(sock_net(sk), tx ? LINUX_MIB_TLSTXREKEYERROR 782 : LINUX_MIB_TLSRXREKEYERROR); 783 } 784 memzero_explicit(crypto_ctx, sizeof(*crypto_ctx)); 785 return rc; 786 } 787 788 static int do_tls_setsockopt_tx_zc(struct sock *sk, sockptr_t optval, 789 unsigned int optlen) 790 { 791 struct tls_context *ctx = tls_get_ctx(sk); 792 unsigned int value; 793 794 if (sockptr_is_null(optval) || optlen != sizeof(value)) 795 return -EINVAL; 796 797 if (copy_from_sockptr(&value, optval, sizeof(value))) 798 return -EFAULT; 799 800 if (value > 1) 801 return -EINVAL; 802 803 ctx->zerocopy_sendfile = value; 804 805 return 0; 806 } 807 808 static int do_tls_setsockopt_no_pad(struct sock *sk, sockptr_t optval, 809 unsigned int optlen) 810 { 811 struct tls_context *ctx = tls_get_ctx(sk); 812 u32 val; 813 int rc; 814 815 if (ctx->prot_info.version != TLS_1_3_VERSION || 816 sockptr_is_null(optval) || optlen < sizeof(val)) 817 return -EINVAL; 818 819 rc = copy_from_sockptr(&val, optval, sizeof(val)); 820 if (rc) 821 return -EFAULT; 822 if (val > 1) 823 return -EINVAL; 824 rc = check_zeroed_sockptr(optval, sizeof(val), optlen - sizeof(val)); 825 if (rc < 1) 826 return rc == 0 ? -EINVAL : rc; 827 828 lock_sock(sk); 829 rc = -EINVAL; 830 if (ctx->rx_conf == TLS_SW || ctx->rx_conf == TLS_HW) { 831 ctx->rx_no_pad = val; 832 tls_update_rx_zc_capable(ctx); 833 rc = 0; 834 } 835 release_sock(sk); 836 837 return rc; 838 } 839 840 static int do_tls_setsockopt_tx_payload_len(struct sock *sk, sockptr_t optval, 841 unsigned int optlen) 842 { 843 struct tls_context *ctx = tls_get_ctx(sk); 844 struct tls_sw_context_tx *sw_ctx = tls_sw_ctx_tx(ctx); 845 u16 value; 846 bool tls_13 = ctx->prot_info.version == TLS_1_3_VERSION; 847 848 if (sw_ctx && sw_ctx->open_rec) 849 return -EBUSY; 850 851 if (sockptr_is_null(optval) || optlen != sizeof(value)) 852 return -EINVAL; 853 854 if (copy_from_sockptr(&value, optval, sizeof(value))) 855 return -EFAULT; 856 857 if (value < TLS_MIN_RECORD_SIZE_LIM - (tls_13 ? 1 : 0) || 858 value > TLS_MAX_PAYLOAD_SIZE) 859 return -EINVAL; 860 861 ctx->tx_max_payload_len = value; 862 863 return 0; 864 } 865 866 static int do_tls_setsockopt(struct sock *sk, int optname, sockptr_t optval, 867 unsigned int optlen) 868 { 869 int rc = 0; 870 871 switch (optname) { 872 case TLS_TX: 873 case TLS_RX: 874 lock_sock(sk); 875 rc = do_tls_setsockopt_conf(sk, optval, optlen, 876 optname == TLS_TX); 877 release_sock(sk); 878 break; 879 case TLS_TX_ZEROCOPY_RO: 880 lock_sock(sk); 881 rc = do_tls_setsockopt_tx_zc(sk, optval, optlen); 882 release_sock(sk); 883 break; 884 case TLS_RX_EXPECT_NO_PAD: 885 rc = do_tls_setsockopt_no_pad(sk, optval, optlen); 886 break; 887 case TLS_TX_MAX_PAYLOAD_LEN: 888 lock_sock(sk); 889 rc = do_tls_setsockopt_tx_payload_len(sk, optval, optlen); 890 release_sock(sk); 891 break; 892 default: 893 rc = -ENOPROTOOPT; 894 break; 895 } 896 return rc; 897 } 898 899 static int tls_setsockopt(struct sock *sk, int level, int optname, 900 sockptr_t optval, unsigned int optlen) 901 { 902 struct tls_context *ctx = tls_get_ctx(sk); 903 904 if (level != SOL_TLS) 905 return ctx->sk_proto->setsockopt(sk, level, optname, optval, 906 optlen); 907 908 return do_tls_setsockopt(sk, optname, optval, optlen); 909 } 910 911 static int tls_disconnect(struct sock *sk, int flags) 912 { 913 return -EOPNOTSUPP; 914 } 915 916 struct tls_context *tls_ctx_create(struct sock *sk) 917 { 918 struct inet_connection_sock *icsk = inet_csk(sk); 919 struct tls_context *ctx; 920 921 ctx = kzalloc_obj(*ctx, GFP_ATOMIC); 922 if (!ctx) 923 return NULL; 924 925 mutex_init(&ctx->tx_lock); 926 ctx->sk_proto = READ_ONCE(sk->sk_prot); 927 ctx->sk = sk; 928 /* Release semantic of rcu_assign_pointer() ensures that 929 * ctx->sk_proto is visible before changing sk->sk_prot in 930 * update_sk_prot(), and prevents reading uninitialized value in 931 * tls_{getsockopt, setsockopt}. Note that we do not need a 932 * read barrier in tls_{getsockopt,setsockopt} as there is an 933 * address dependency between sk->sk_proto->{getsockopt,setsockopt} 934 * and ctx->sk_proto. 935 */ 936 rcu_assign_pointer(icsk->icsk_ulp_data, ctx); 937 return ctx; 938 } 939 940 static void build_proto_ops(struct proto_ops ops[TLS_NUM_CONFIG][TLS_NUM_CONFIG], 941 const struct proto_ops *base) 942 { 943 ops[TLS_BASE][TLS_BASE] = *base; 944 945 ops[TLS_SW ][TLS_BASE] = ops[TLS_BASE][TLS_BASE]; 946 ops[TLS_SW ][TLS_BASE].splice_eof = tls_sw_splice_eof; 947 948 ops[TLS_BASE][TLS_SW ] = ops[TLS_BASE][TLS_BASE]; 949 ops[TLS_BASE][TLS_SW ].splice_read = tls_sw_splice_read; 950 ops[TLS_BASE][TLS_SW ].poll = tls_sk_poll; 951 ops[TLS_BASE][TLS_SW ].read_sock = tls_sw_read_sock; 952 953 ops[TLS_SW ][TLS_SW ] = ops[TLS_SW ][TLS_BASE]; 954 ops[TLS_SW ][TLS_SW ].splice_read = tls_sw_splice_read; 955 ops[TLS_SW ][TLS_SW ].poll = tls_sk_poll; 956 ops[TLS_SW ][TLS_SW ].read_sock = tls_sw_read_sock; 957 958 #ifdef CONFIG_TLS_DEVICE 959 ops[TLS_HW ][TLS_BASE] = ops[TLS_BASE][TLS_BASE]; 960 961 ops[TLS_HW ][TLS_SW ] = ops[TLS_BASE][TLS_SW ]; 962 963 ops[TLS_BASE][TLS_HW ] = ops[TLS_BASE][TLS_SW ]; 964 965 ops[TLS_SW ][TLS_HW ] = ops[TLS_SW ][TLS_SW ]; 966 967 ops[TLS_HW ][TLS_HW ] = ops[TLS_HW ][TLS_SW ]; 968 #endif 969 } 970 971 static void tls_build_proto(struct sock *sk) 972 { 973 int ip_ver = sk->sk_family == AF_INET6 ? TLSV6 : TLSV4; 974 struct proto *prot = READ_ONCE(sk->sk_prot); 975 976 /* Build IPv6 TLS whenever the address of tcpv6 _prot changes */ 977 if (ip_ver == TLSV6 && 978 unlikely(prot != smp_load_acquire(&saved_tcpv6_prot))) { 979 mutex_lock(&tcpv6_prot_mutex); 980 if (likely(prot != saved_tcpv6_prot)) { 981 build_protos(tls_prots[TLSV6], prot); 982 build_proto_ops(tls_proto_ops[TLSV6], 983 sk->sk_socket->ops); 984 smp_store_release(&saved_tcpv6_prot, prot); 985 } 986 mutex_unlock(&tcpv6_prot_mutex); 987 } 988 989 if (ip_ver == TLSV4 && 990 unlikely(prot != smp_load_acquire(&saved_tcpv4_prot))) { 991 mutex_lock(&tcpv4_prot_mutex); 992 if (likely(prot != saved_tcpv4_prot)) { 993 build_protos(tls_prots[TLSV4], prot); 994 build_proto_ops(tls_proto_ops[TLSV4], 995 sk->sk_socket->ops); 996 smp_store_release(&saved_tcpv4_prot, prot); 997 } 998 mutex_unlock(&tcpv4_prot_mutex); 999 } 1000 } 1001 1002 static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG], 1003 const struct proto *base) 1004 { 1005 prot[TLS_BASE][TLS_BASE] = *base; 1006 prot[TLS_BASE][TLS_BASE].setsockopt = tls_setsockopt; 1007 prot[TLS_BASE][TLS_BASE].getsockopt = tls_getsockopt; 1008 prot[TLS_BASE][TLS_BASE].disconnect = tls_disconnect; 1009 prot[TLS_BASE][TLS_BASE].close = tls_sk_proto_close; 1010 1011 prot[TLS_SW][TLS_BASE] = prot[TLS_BASE][TLS_BASE]; 1012 prot[TLS_SW][TLS_BASE].sendmsg = tls_sw_sendmsg; 1013 prot[TLS_SW][TLS_BASE].splice_eof = tls_sw_splice_eof; 1014 1015 prot[TLS_BASE][TLS_SW] = prot[TLS_BASE][TLS_BASE]; 1016 prot[TLS_BASE][TLS_SW].recvmsg = tls_sw_recvmsg; 1017 prot[TLS_BASE][TLS_SW].sock_is_readable = tls_sw_sock_is_readable; 1018 prot[TLS_BASE][TLS_SW].close = tls_sk_proto_close; 1019 1020 prot[TLS_SW][TLS_SW] = prot[TLS_SW][TLS_BASE]; 1021 prot[TLS_SW][TLS_SW].recvmsg = tls_sw_recvmsg; 1022 prot[TLS_SW][TLS_SW].sock_is_readable = tls_sw_sock_is_readable; 1023 prot[TLS_SW][TLS_SW].close = tls_sk_proto_close; 1024 1025 #ifdef CONFIG_TLS_DEVICE 1026 prot[TLS_HW][TLS_BASE] = prot[TLS_BASE][TLS_BASE]; 1027 prot[TLS_HW][TLS_BASE].sendmsg = tls_device_sendmsg; 1028 prot[TLS_HW][TLS_BASE].splice_eof = tls_device_splice_eof; 1029 1030 prot[TLS_HW][TLS_SW] = prot[TLS_BASE][TLS_SW]; 1031 prot[TLS_HW][TLS_SW].sendmsg = tls_device_sendmsg; 1032 prot[TLS_HW][TLS_SW].splice_eof = tls_device_splice_eof; 1033 1034 prot[TLS_BASE][TLS_HW] = prot[TLS_BASE][TLS_SW]; 1035 1036 prot[TLS_SW][TLS_HW] = prot[TLS_SW][TLS_SW]; 1037 1038 prot[TLS_HW][TLS_HW] = prot[TLS_HW][TLS_SW]; 1039 #endif 1040 } 1041 1042 static int tls_init(struct sock *sk) 1043 { 1044 struct tls_context *ctx; 1045 int rc = 0; 1046 1047 tls_build_proto(sk); 1048 1049 /* The TLS ulp is currently supported only for TCP sockets 1050 * in ESTABLISHED state. 1051 * Supporting sockets in LISTEN state will require us 1052 * to modify the accept implementation to clone rather then 1053 * share the ulp context. 1054 */ 1055 if (sk->sk_state != TCP_ESTABLISHED) 1056 return -ENOTCONN; 1057 1058 /* allocate tls context */ 1059 write_lock_bh(&sk->sk_callback_lock); 1060 ctx = tls_ctx_create(sk); 1061 if (!ctx) { 1062 rc = -ENOMEM; 1063 goto out; 1064 } 1065 1066 ctx->tx_conf = TLS_BASE; 1067 ctx->rx_conf = TLS_BASE; 1068 ctx->tx_max_payload_len = TLS_MAX_PAYLOAD_SIZE; 1069 update_sk_prot(sk, ctx); 1070 out: 1071 write_unlock_bh(&sk->sk_callback_lock); 1072 return rc; 1073 } 1074 1075 static void tls_update(struct sock *sk, struct proto *p, 1076 void (*write_space)(struct sock *sk)) 1077 { 1078 struct tls_context *ctx; 1079 1080 WARN_ON_ONCE(sk->sk_prot == p); 1081 1082 ctx = tls_get_ctx(sk); 1083 if (likely(ctx)) { 1084 ctx->sk_write_space = write_space; 1085 ctx->sk_proto = p; 1086 } else { 1087 /* Pairs with lockless read in sk_clone_lock(). */ 1088 WRITE_ONCE(sk->sk_prot, p); 1089 sk->sk_write_space = write_space; 1090 } 1091 } 1092 1093 static u16 tls_user_config(struct tls_context *ctx, bool tx) 1094 { 1095 u16 config = tx ? ctx->tx_conf : ctx->rx_conf; 1096 1097 switch (config) { 1098 case TLS_BASE: 1099 return TLS_CONF_BASE; 1100 case TLS_SW: 1101 return TLS_CONF_SW; 1102 case TLS_HW: 1103 return TLS_CONF_HW; 1104 } 1105 return 0; 1106 } 1107 1108 static int tls_get_info(struct sock *sk, struct sk_buff *skb, bool net_admin) 1109 { 1110 u16 version, cipher_type; 1111 struct tls_context *ctx; 1112 struct nlattr *start; 1113 int err; 1114 1115 start = nla_nest_start_noflag(skb, INET_ULP_INFO_TLS); 1116 if (!start) 1117 return -EMSGSIZE; 1118 1119 rcu_read_lock(); 1120 ctx = rcu_dereference(inet_csk(sk)->icsk_ulp_data); 1121 if (!ctx) { 1122 err = 0; 1123 goto nla_failure; 1124 } 1125 version = ctx->prot_info.version; 1126 if (version) { 1127 err = nla_put_u16(skb, TLS_INFO_VERSION, version); 1128 if (err) 1129 goto nla_failure; 1130 } 1131 cipher_type = ctx->prot_info.cipher_type; 1132 if (cipher_type) { 1133 err = nla_put_u16(skb, TLS_INFO_CIPHER, cipher_type); 1134 if (err) 1135 goto nla_failure; 1136 } 1137 err = nla_put_u16(skb, TLS_INFO_TXCONF, tls_user_config(ctx, true)); 1138 if (err) 1139 goto nla_failure; 1140 1141 err = nla_put_u16(skb, TLS_INFO_RXCONF, tls_user_config(ctx, false)); 1142 if (err) 1143 goto nla_failure; 1144 1145 if (ctx->tx_conf == TLS_HW && ctx->zerocopy_sendfile) { 1146 err = nla_put_flag(skb, TLS_INFO_ZC_RO_TX); 1147 if (err) 1148 goto nla_failure; 1149 } 1150 if (ctx->rx_no_pad) { 1151 err = nla_put_flag(skb, TLS_INFO_RX_NO_PAD); 1152 if (err) 1153 goto nla_failure; 1154 } 1155 1156 err = nla_put_u16(skb, TLS_INFO_TX_MAX_PAYLOAD_LEN, 1157 ctx->tx_max_payload_len); 1158 1159 if (err) 1160 goto nla_failure; 1161 1162 rcu_read_unlock(); 1163 nla_nest_end(skb, start); 1164 return 0; 1165 1166 nla_failure: 1167 rcu_read_unlock(); 1168 nla_nest_cancel(skb, start); 1169 return err; 1170 } 1171 1172 static size_t tls_get_info_size(const struct sock *sk, bool net_admin) 1173 { 1174 size_t size = 0; 1175 1176 size += nla_total_size(0) + /* INET_ULP_INFO_TLS */ 1177 nla_total_size(sizeof(u16)) + /* TLS_INFO_VERSION */ 1178 nla_total_size(sizeof(u16)) + /* TLS_INFO_CIPHER */ 1179 nla_total_size(sizeof(u16)) + /* TLS_INFO_RXCONF */ 1180 nla_total_size(sizeof(u16)) + /* TLS_INFO_TXCONF */ 1181 nla_total_size(0) + /* TLS_INFO_ZC_RO_TX */ 1182 nla_total_size(0) + /* TLS_INFO_RX_NO_PAD */ 1183 nla_total_size(sizeof(u16)) + /* TLS_INFO_TX_MAX_PAYLOAD_LEN */ 1184 0; 1185 1186 return size; 1187 } 1188 1189 static int __net_init tls_init_net(struct net *net) 1190 { 1191 int err; 1192 1193 net->mib.tls_statistics = alloc_percpu(struct linux_tls_mib); 1194 if (!net->mib.tls_statistics) 1195 return -ENOMEM; 1196 1197 err = tls_proc_init(net); 1198 if (err) 1199 goto err_free_stats; 1200 1201 return 0; 1202 err_free_stats: 1203 free_percpu(net->mib.tls_statistics); 1204 return err; 1205 } 1206 1207 static void __net_exit tls_exit_net(struct net *net) 1208 { 1209 tls_proc_fini(net); 1210 free_percpu(net->mib.tls_statistics); 1211 } 1212 1213 static struct pernet_operations tls_proc_ops = { 1214 .init = tls_init_net, 1215 .exit = tls_exit_net, 1216 }; 1217 1218 static struct tcp_ulp_ops tcp_tls_ulp_ops __read_mostly = { 1219 .name = "tls", 1220 .owner = THIS_MODULE, 1221 .init = tls_init, 1222 .update = tls_update, 1223 .get_info = tls_get_info, 1224 .get_info_size = tls_get_info_size, 1225 }; 1226 1227 static int __init tls_register(void) 1228 { 1229 int err; 1230 1231 err = register_pernet_subsys(&tls_proc_ops); 1232 if (err) 1233 return err; 1234 1235 err = tls_strp_dev_init(); 1236 if (err) 1237 goto err_pernet; 1238 1239 err = tls_device_init(); 1240 if (err) 1241 goto err_strp; 1242 1243 tcp_register_ulp(&tcp_tls_ulp_ops); 1244 1245 return 0; 1246 err_strp: 1247 tls_strp_dev_exit(); 1248 err_pernet: 1249 unregister_pernet_subsys(&tls_proc_ops); 1250 return err; 1251 } 1252 1253 static void __exit tls_unregister(void) 1254 { 1255 tcp_unregister_ulp(&tcp_tls_ulp_ops); 1256 tls_strp_dev_exit(); 1257 tls_device_cleanup(); 1258 unregister_pernet_subsys(&tls_proc_ops); 1259 } 1260 1261 module_init(tls_register); 1262 module_exit(tls_unregister); 1263