1 /* 2 * Copyright (c) 2016-2017, Mellanox Technologies. All rights reserved. 3 * Copyright (c) 2016-2017, Dave Watson <davejwatson@fb.com>. All rights reserved. 4 * 5 * This software is available to you under a choice of one of two 6 * licenses. You may choose to be licensed under the terms of the GNU 7 * General Public License (GPL) Version 2, available from the file 8 * COPYING in the main directory of this source tree, or the 9 * OpenIB.org BSD license below: 10 * 11 * Redistribution and use in source and binary forms, with or 12 * without modification, are permitted provided that the following 13 * conditions are met: 14 * 15 * - Redistributions of source code must retain the above 16 * copyright notice, this list of conditions and the following 17 * disclaimer. 18 * 19 * - Redistributions in binary form must reproduce the above 20 * copyright notice, this list of conditions and the following 21 * disclaimer in the documentation and/or other materials 22 * provided with the distribution. 23 * 24 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 25 * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 26 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 27 * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS 28 * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN 29 * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 30 * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 31 * SOFTWARE. 32 */ 33 34 #include <linux/module.h> 35 36 #include <net/tcp.h> 37 #include <net/inet_common.h> 38 #include <linux/highmem.h> 39 #include <linux/netdevice.h> 40 #include <linux/sched/signal.h> 41 #include <linux/inetdevice.h> 42 #include <linux/inet_diag.h> 43 44 #include <net/snmp.h> 45 #include <net/tls.h> 46 #include <net/tls_toe.h> 47 48 #include "tls.h" 49 50 MODULE_AUTHOR("Mellanox Technologies"); 51 MODULE_DESCRIPTION("Transport Layer Security Support"); 52 MODULE_LICENSE("Dual BSD/GPL"); 53 MODULE_ALIAS_TCP_ULP("tls"); 54 55 enum { 56 TLSV4, 57 TLSV6, 58 TLS_NUM_PROTS, 59 }; 60 61 #define CHECK_CIPHER_DESC(cipher,ci) \ 62 static_assert(cipher ## _IV_SIZE <= MAX_IV_SIZE); \ 63 static_assert(cipher ## _REC_SEQ_SIZE <= TLS_MAX_REC_SEQ_SIZE); \ 64 static_assert(cipher ## _TAG_SIZE == TLS_TAG_SIZE); \ 65 static_assert(sizeof_field(struct ci, iv) == cipher ## _IV_SIZE); \ 66 static_assert(sizeof_field(struct ci, key) == cipher ## _KEY_SIZE); \ 67 static_assert(sizeof_field(struct ci, salt) == cipher ## _SALT_SIZE); \ 68 static_assert(sizeof_field(struct ci, rec_seq) == cipher ## _REC_SEQ_SIZE); 69 70 #define __CIPHER_DESC(ci) \ 71 .iv_offset = offsetof(struct ci, iv), \ 72 .key_offset = offsetof(struct ci, key), \ 73 .salt_offset = offsetof(struct ci, salt), \ 74 .rec_seq_offset = offsetof(struct ci, rec_seq), \ 75 .crypto_info = sizeof(struct ci) 76 77 #define CIPHER_DESC(cipher,ci,algname,_offloadable) [cipher - TLS_CIPHER_MIN] = { \ 78 .nonce = cipher ## _IV_SIZE, \ 79 .iv = cipher ## _IV_SIZE, \ 80 .key = cipher ## _KEY_SIZE, \ 81 .salt = cipher ## _SALT_SIZE, \ 82 .tag = cipher ## _TAG_SIZE, \ 83 .rec_seq = cipher ## _REC_SEQ_SIZE, \ 84 .cipher_name = algname, \ 85 .offloadable = _offloadable, \ 86 __CIPHER_DESC(ci), \ 87 } 88 89 #define CIPHER_DESC_NONCE0(cipher,ci,algname,_offloadable) [cipher - TLS_CIPHER_MIN] = { \ 90 .nonce = 0, \ 91 .iv = cipher ## _IV_SIZE, \ 92 .key = cipher ## _KEY_SIZE, \ 93 .salt = cipher ## _SALT_SIZE, \ 94 .tag = cipher ## _TAG_SIZE, \ 95 .rec_seq = cipher ## _REC_SEQ_SIZE, \ 96 .cipher_name = algname, \ 97 .offloadable = _offloadable, \ 98 __CIPHER_DESC(ci), \ 99 } 100 101 const struct tls_cipher_desc tls_cipher_desc[TLS_CIPHER_MAX + 1 - TLS_CIPHER_MIN] = { 102 CIPHER_DESC(TLS_CIPHER_AES_GCM_128, tls12_crypto_info_aes_gcm_128, "gcm(aes)", true), 103 CIPHER_DESC(TLS_CIPHER_AES_GCM_256, tls12_crypto_info_aes_gcm_256, "gcm(aes)", true), 104 CIPHER_DESC(TLS_CIPHER_AES_CCM_128, tls12_crypto_info_aes_ccm_128, "ccm(aes)", false), 105 CIPHER_DESC_NONCE0(TLS_CIPHER_CHACHA20_POLY1305, tls12_crypto_info_chacha20_poly1305, "rfc7539(chacha20,poly1305)", false), 106 CIPHER_DESC(TLS_CIPHER_SM4_GCM, tls12_crypto_info_sm4_gcm, "gcm(sm4)", false), 107 CIPHER_DESC(TLS_CIPHER_SM4_CCM, tls12_crypto_info_sm4_ccm, "ccm(sm4)", false), 108 CIPHER_DESC(TLS_CIPHER_ARIA_GCM_128, tls12_crypto_info_aria_gcm_128, "gcm(aria)", false), 109 CIPHER_DESC(TLS_CIPHER_ARIA_GCM_256, tls12_crypto_info_aria_gcm_256, "gcm(aria)", false), 110 }; 111 112 CHECK_CIPHER_DESC(TLS_CIPHER_AES_GCM_128, tls12_crypto_info_aes_gcm_128); 113 CHECK_CIPHER_DESC(TLS_CIPHER_AES_GCM_256, tls12_crypto_info_aes_gcm_256); 114 CHECK_CIPHER_DESC(TLS_CIPHER_AES_CCM_128, tls12_crypto_info_aes_ccm_128); 115 CHECK_CIPHER_DESC(TLS_CIPHER_CHACHA20_POLY1305, tls12_crypto_info_chacha20_poly1305); 116 CHECK_CIPHER_DESC(TLS_CIPHER_SM4_GCM, tls12_crypto_info_sm4_gcm); 117 CHECK_CIPHER_DESC(TLS_CIPHER_SM4_CCM, tls12_crypto_info_sm4_ccm); 118 CHECK_CIPHER_DESC(TLS_CIPHER_ARIA_GCM_128, tls12_crypto_info_aria_gcm_128); 119 CHECK_CIPHER_DESC(TLS_CIPHER_ARIA_GCM_256, tls12_crypto_info_aria_gcm_256); 120 121 static const struct proto *saved_tcpv6_prot; 122 static DEFINE_MUTEX(tcpv6_prot_mutex); 123 static const struct proto *saved_tcpv4_prot; 124 static DEFINE_MUTEX(tcpv4_prot_mutex); 125 static struct proto tls_prots[TLS_NUM_PROTS][TLS_NUM_CONFIG][TLS_NUM_CONFIG]; 126 static struct proto_ops tls_proto_ops[TLS_NUM_PROTS][TLS_NUM_CONFIG][TLS_NUM_CONFIG]; 127 static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG], 128 const struct proto *base); 129 130 void update_sk_prot(struct sock *sk, struct tls_context *ctx) 131 { 132 int ip_ver = sk->sk_family == AF_INET6 ? TLSV6 : TLSV4; 133 134 WRITE_ONCE(sk->sk_prot, 135 &tls_prots[ip_ver][ctx->tx_conf][ctx->rx_conf]); 136 WRITE_ONCE(sk->sk_socket->ops, 137 &tls_proto_ops[ip_ver][ctx->tx_conf][ctx->rx_conf]); 138 } 139 140 int wait_on_pending_writer(struct sock *sk, long *timeo) 141 { 142 int rc = 0; 143 DEFINE_WAIT_FUNC(wait, woken_wake_function); 144 145 add_wait_queue(sk_sleep(sk), &wait); 146 while (1) { 147 if (!*timeo) { 148 rc = -EAGAIN; 149 break; 150 } 151 152 if (signal_pending(current)) { 153 rc = sock_intr_errno(*timeo); 154 break; 155 } 156 157 if (sk_wait_event(sk, timeo, 158 !READ_ONCE(sk->sk_write_pending), &wait)) 159 break; 160 } 161 remove_wait_queue(sk_sleep(sk), &wait); 162 return rc; 163 } 164 165 int tls_push_sg(struct sock *sk, 166 struct tls_context *ctx, 167 struct scatterlist *sg, 168 u16 first_offset, 169 int flags) 170 { 171 struct bio_vec bvec; 172 struct msghdr msg = { 173 .msg_flags = MSG_SPLICE_PAGES | flags, 174 }; 175 int ret = 0; 176 struct page *p; 177 size_t size; 178 int offset = first_offset; 179 180 size = sg->length - offset; 181 offset += sg->offset; 182 183 ctx->splicing_pages = true; 184 while (1) { 185 /* is sending application-limited? */ 186 tcp_rate_check_app_limited(sk); 187 p = sg_page(sg); 188 retry: 189 bvec_set_page(&bvec, p, size, offset); 190 iov_iter_bvec(&msg.msg_iter, ITER_SOURCE, &bvec, 1, size); 191 192 ret = tcp_sendmsg_locked(sk, &msg, size); 193 194 if (ret != size) { 195 if (ret > 0) { 196 offset += ret; 197 size -= ret; 198 goto retry; 199 } 200 201 offset -= sg->offset; 202 ctx->partially_sent_offset = offset; 203 ctx->partially_sent_record = (void *)sg; 204 ctx->splicing_pages = false; 205 return ret; 206 } 207 208 put_page(p); 209 sk_mem_uncharge(sk, sg->length); 210 sg = sg_next(sg); 211 if (!sg) 212 break; 213 214 offset = sg->offset; 215 size = sg->length; 216 } 217 218 ctx->splicing_pages = false; 219 220 return 0; 221 } 222 223 static int tls_handle_open_record(struct sock *sk, int flags) 224 { 225 struct tls_context *ctx = tls_get_ctx(sk); 226 227 if (tls_is_pending_open_record(ctx)) 228 return ctx->push_pending_record(sk, flags); 229 230 return 0; 231 } 232 233 int tls_process_cmsg(struct sock *sk, struct msghdr *msg, 234 unsigned char *record_type) 235 { 236 struct cmsghdr *cmsg; 237 int rc = -EINVAL; 238 239 for_each_cmsghdr(cmsg, msg) { 240 if (!CMSG_OK(msg, cmsg)) 241 return -EINVAL; 242 if (cmsg->cmsg_level != SOL_TLS) 243 continue; 244 245 switch (cmsg->cmsg_type) { 246 case TLS_SET_RECORD_TYPE: 247 if (cmsg->cmsg_len < CMSG_LEN(sizeof(*record_type))) 248 return -EINVAL; 249 250 if (msg->msg_flags & MSG_MORE) 251 return -EINVAL; 252 253 rc = tls_handle_open_record(sk, msg->msg_flags); 254 if (rc) 255 return rc; 256 257 *record_type = *(unsigned char *)CMSG_DATA(cmsg); 258 rc = 0; 259 break; 260 default: 261 return -EINVAL; 262 } 263 } 264 265 return rc; 266 } 267 268 int tls_push_partial_record(struct sock *sk, struct tls_context *ctx, 269 int flags) 270 { 271 struct scatterlist *sg; 272 u16 offset; 273 274 sg = ctx->partially_sent_record; 275 offset = ctx->partially_sent_offset; 276 277 ctx->partially_sent_record = NULL; 278 return tls_push_sg(sk, ctx, sg, offset, flags); 279 } 280 281 void tls_free_partial_record(struct sock *sk, struct tls_context *ctx) 282 { 283 struct scatterlist *sg; 284 285 for (sg = ctx->partially_sent_record; sg; sg = sg_next(sg)) { 286 put_page(sg_page(sg)); 287 sk_mem_uncharge(sk, sg->length); 288 } 289 ctx->partially_sent_record = NULL; 290 } 291 292 static void tls_write_space(struct sock *sk) 293 { 294 struct tls_context *ctx = tls_get_ctx(sk); 295 296 /* If splicing_pages call lower protocol write space handler 297 * to ensure we wake up any waiting operations there. For example 298 * if splicing pages where to call sk_wait_event. 299 */ 300 if (ctx->splicing_pages) { 301 ctx->sk_write_space(sk); 302 return; 303 } 304 305 #ifdef CONFIG_TLS_DEVICE 306 if (ctx->tx_conf == TLS_HW) 307 tls_device_write_space(sk, ctx); 308 else 309 #endif 310 tls_sw_write_space(sk, ctx); 311 312 ctx->sk_write_space(sk); 313 } 314 315 /** 316 * tls_ctx_free() - free TLS ULP context 317 * @sk: socket to with @ctx is attached 318 * @ctx: TLS context structure 319 * 320 * Free TLS context. If @sk is %NULL caller guarantees that the socket 321 * to which @ctx was attached has no outstanding references. 322 */ 323 void tls_ctx_free(struct sock *sk, struct tls_context *ctx) 324 { 325 if (!ctx) 326 return; 327 328 memzero_explicit(&ctx->crypto_send, sizeof(ctx->crypto_send)); 329 memzero_explicit(&ctx->crypto_recv, sizeof(ctx->crypto_recv)); 330 mutex_destroy(&ctx->tx_lock); 331 332 if (sk) 333 kfree_rcu(ctx, rcu); 334 else 335 kfree(ctx); 336 } 337 338 static void tls_sk_proto_cleanup(struct sock *sk, 339 struct tls_context *ctx, long timeo) 340 { 341 if (unlikely(sk->sk_write_pending) && 342 !wait_on_pending_writer(sk, &timeo)) 343 tls_handle_open_record(sk, 0); 344 345 /* We need these for tls_sw_fallback handling of other packets */ 346 if (ctx->tx_conf == TLS_SW) { 347 kfree(ctx->tx.rec_seq); 348 kfree(ctx->tx.iv); 349 tls_sw_release_resources_tx(sk); 350 TLS_DEC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXSW); 351 } else if (ctx->tx_conf == TLS_HW) { 352 tls_device_free_resources_tx(sk); 353 TLS_DEC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXDEVICE); 354 } 355 356 if (ctx->rx_conf == TLS_SW) { 357 tls_sw_release_resources_rx(sk); 358 TLS_DEC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXSW); 359 } else if (ctx->rx_conf == TLS_HW) { 360 tls_device_offload_cleanup_rx(sk); 361 TLS_DEC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXDEVICE); 362 } 363 } 364 365 static void tls_sk_proto_close(struct sock *sk, long timeout) 366 { 367 struct inet_connection_sock *icsk = inet_csk(sk); 368 struct tls_context *ctx = tls_get_ctx(sk); 369 long timeo = sock_sndtimeo(sk, 0); 370 bool free_ctx; 371 372 if (ctx->tx_conf == TLS_SW) 373 tls_sw_cancel_work_tx(ctx); 374 375 lock_sock(sk); 376 free_ctx = ctx->tx_conf != TLS_HW && ctx->rx_conf != TLS_HW; 377 378 if (ctx->tx_conf != TLS_BASE || ctx->rx_conf != TLS_BASE) 379 tls_sk_proto_cleanup(sk, ctx, timeo); 380 381 write_lock_bh(&sk->sk_callback_lock); 382 if (free_ctx) 383 rcu_assign_pointer(icsk->icsk_ulp_data, NULL); 384 WRITE_ONCE(sk->sk_prot, ctx->sk_proto); 385 if (sk->sk_write_space == tls_write_space) 386 sk->sk_write_space = ctx->sk_write_space; 387 write_unlock_bh(&sk->sk_callback_lock); 388 release_sock(sk); 389 if (ctx->tx_conf == TLS_SW) 390 tls_sw_free_ctx_tx(ctx); 391 if (ctx->rx_conf == TLS_SW || ctx->rx_conf == TLS_HW) 392 tls_sw_strparser_done(ctx); 393 if (ctx->rx_conf == TLS_SW) 394 tls_sw_free_ctx_rx(ctx); 395 ctx->sk_proto->close(sk, timeout); 396 397 if (free_ctx) 398 tls_ctx_free(sk, ctx); 399 } 400 401 static __poll_t tls_sk_poll(struct file *file, struct socket *sock, 402 struct poll_table_struct *wait) 403 { 404 struct tls_sw_context_rx *ctx; 405 struct tls_context *tls_ctx; 406 struct sock *sk = sock->sk; 407 struct sk_psock *psock; 408 __poll_t mask = 0; 409 u8 shutdown; 410 int state; 411 412 mask = tcp_poll(file, sock, wait); 413 414 state = inet_sk_state_load(sk); 415 shutdown = READ_ONCE(sk->sk_shutdown); 416 if (unlikely(state != TCP_ESTABLISHED || shutdown & RCV_SHUTDOWN)) 417 return mask; 418 419 tls_ctx = tls_get_ctx(sk); 420 ctx = tls_sw_ctx_rx(tls_ctx); 421 psock = sk_psock_get(sk); 422 423 if (skb_queue_empty_lockless(&ctx->rx_list) && 424 !tls_strp_msg_ready(ctx) && 425 sk_psock_queue_empty(psock)) 426 mask &= ~(EPOLLIN | EPOLLRDNORM); 427 428 if (psock) 429 sk_psock_put(sk, psock); 430 431 return mask; 432 } 433 434 static int do_tls_getsockopt_conf(struct sock *sk, char __user *optval, 435 int __user *optlen, int tx) 436 { 437 int rc = 0; 438 const struct tls_cipher_desc *cipher_desc; 439 struct tls_context *ctx = tls_get_ctx(sk); 440 struct tls_crypto_info *crypto_info; 441 struct cipher_context *cctx; 442 int len; 443 444 if (get_user(len, optlen)) 445 return -EFAULT; 446 447 if (!optval || (len < sizeof(*crypto_info))) { 448 rc = -EINVAL; 449 goto out; 450 } 451 452 if (!ctx) { 453 rc = -EBUSY; 454 goto out; 455 } 456 457 /* get user crypto info */ 458 if (tx) { 459 crypto_info = &ctx->crypto_send.info; 460 cctx = &ctx->tx; 461 } else { 462 crypto_info = &ctx->crypto_recv.info; 463 cctx = &ctx->rx; 464 } 465 466 if (!TLS_CRYPTO_INFO_READY(crypto_info)) { 467 rc = -EBUSY; 468 goto out; 469 } 470 471 if (len == sizeof(*crypto_info)) { 472 if (copy_to_user(optval, crypto_info, sizeof(*crypto_info))) 473 rc = -EFAULT; 474 goto out; 475 } 476 477 cipher_desc = get_cipher_desc(crypto_info->cipher_type); 478 if (!cipher_desc || len != cipher_desc->crypto_info) { 479 rc = -EINVAL; 480 goto out; 481 } 482 483 memcpy(crypto_info_iv(crypto_info, cipher_desc), 484 cctx->iv + cipher_desc->salt, cipher_desc->iv); 485 memcpy(crypto_info_rec_seq(crypto_info, cipher_desc), 486 cctx->rec_seq, cipher_desc->rec_seq); 487 488 if (copy_to_user(optval, crypto_info, cipher_desc->crypto_info)) 489 rc = -EFAULT; 490 491 out: 492 return rc; 493 } 494 495 static int do_tls_getsockopt_tx_zc(struct sock *sk, char __user *optval, 496 int __user *optlen) 497 { 498 struct tls_context *ctx = tls_get_ctx(sk); 499 unsigned int value; 500 int len; 501 502 if (get_user(len, optlen)) 503 return -EFAULT; 504 505 if (len != sizeof(value)) 506 return -EINVAL; 507 508 value = ctx->zerocopy_sendfile; 509 if (copy_to_user(optval, &value, sizeof(value))) 510 return -EFAULT; 511 512 return 0; 513 } 514 515 static int do_tls_getsockopt_no_pad(struct sock *sk, char __user *optval, 516 int __user *optlen) 517 { 518 struct tls_context *ctx = tls_get_ctx(sk); 519 int value, len; 520 521 if (ctx->prot_info.version != TLS_1_3_VERSION) 522 return -EINVAL; 523 524 if (get_user(len, optlen)) 525 return -EFAULT; 526 if (len < sizeof(value)) 527 return -EINVAL; 528 529 value = -EINVAL; 530 if (ctx->rx_conf == TLS_SW || ctx->rx_conf == TLS_HW) 531 value = ctx->rx_no_pad; 532 if (value < 0) 533 return value; 534 535 if (put_user(sizeof(value), optlen)) 536 return -EFAULT; 537 if (copy_to_user(optval, &value, sizeof(value))) 538 return -EFAULT; 539 540 return 0; 541 } 542 543 static int do_tls_getsockopt(struct sock *sk, int optname, 544 char __user *optval, int __user *optlen) 545 { 546 int rc = 0; 547 548 lock_sock(sk); 549 550 switch (optname) { 551 case TLS_TX: 552 case TLS_RX: 553 rc = do_tls_getsockopt_conf(sk, optval, optlen, 554 optname == TLS_TX); 555 break; 556 case TLS_TX_ZEROCOPY_RO: 557 rc = do_tls_getsockopt_tx_zc(sk, optval, optlen); 558 break; 559 case TLS_RX_EXPECT_NO_PAD: 560 rc = do_tls_getsockopt_no_pad(sk, optval, optlen); 561 break; 562 default: 563 rc = -ENOPROTOOPT; 564 break; 565 } 566 567 release_sock(sk); 568 569 return rc; 570 } 571 572 static int tls_getsockopt(struct sock *sk, int level, int optname, 573 char __user *optval, int __user *optlen) 574 { 575 struct tls_context *ctx = tls_get_ctx(sk); 576 577 if (level != SOL_TLS) 578 return ctx->sk_proto->getsockopt(sk, level, 579 optname, optval, optlen); 580 581 return do_tls_getsockopt(sk, optname, optval, optlen); 582 } 583 584 static int do_tls_setsockopt_conf(struct sock *sk, sockptr_t optval, 585 unsigned int optlen, int tx) 586 { 587 struct tls_crypto_info *crypto_info; 588 struct tls_crypto_info *alt_crypto_info; 589 struct tls_context *ctx = tls_get_ctx(sk); 590 const struct tls_cipher_desc *cipher_desc; 591 int rc = 0; 592 int conf; 593 594 if (sockptr_is_null(optval) || (optlen < sizeof(*crypto_info))) 595 return -EINVAL; 596 597 if (tx) { 598 crypto_info = &ctx->crypto_send.info; 599 alt_crypto_info = &ctx->crypto_recv.info; 600 } else { 601 crypto_info = &ctx->crypto_recv.info; 602 alt_crypto_info = &ctx->crypto_send.info; 603 } 604 605 /* Currently we don't support set crypto info more than one time */ 606 if (TLS_CRYPTO_INFO_READY(crypto_info)) 607 return -EBUSY; 608 609 rc = copy_from_sockptr(crypto_info, optval, sizeof(*crypto_info)); 610 if (rc) { 611 rc = -EFAULT; 612 goto err_crypto_info; 613 } 614 615 /* check version */ 616 if (crypto_info->version != TLS_1_2_VERSION && 617 crypto_info->version != TLS_1_3_VERSION) { 618 rc = -EINVAL; 619 goto err_crypto_info; 620 } 621 622 /* Ensure that TLS version and ciphers are same in both directions */ 623 if (TLS_CRYPTO_INFO_READY(alt_crypto_info)) { 624 if (alt_crypto_info->version != crypto_info->version || 625 alt_crypto_info->cipher_type != crypto_info->cipher_type) { 626 rc = -EINVAL; 627 goto err_crypto_info; 628 } 629 } 630 631 cipher_desc = get_cipher_desc(crypto_info->cipher_type); 632 if (!cipher_desc) { 633 rc = -EINVAL; 634 goto err_crypto_info; 635 } 636 637 switch (crypto_info->cipher_type) { 638 case TLS_CIPHER_ARIA_GCM_128: 639 case TLS_CIPHER_ARIA_GCM_256: 640 if (crypto_info->version != TLS_1_2_VERSION) { 641 rc = -EINVAL; 642 goto err_crypto_info; 643 } 644 break; 645 } 646 647 if (optlen != cipher_desc->crypto_info) { 648 rc = -EINVAL; 649 goto err_crypto_info; 650 } 651 652 rc = copy_from_sockptr_offset(crypto_info + 1, optval, 653 sizeof(*crypto_info), 654 optlen - sizeof(*crypto_info)); 655 if (rc) { 656 rc = -EFAULT; 657 goto err_crypto_info; 658 } 659 660 if (tx) { 661 rc = tls_set_device_offload(sk, ctx); 662 conf = TLS_HW; 663 if (!rc) { 664 TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSTXDEVICE); 665 TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXDEVICE); 666 } else { 667 rc = tls_set_sw_offload(sk, ctx, 1); 668 if (rc) 669 goto err_crypto_info; 670 TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSTXSW); 671 TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXSW); 672 conf = TLS_SW; 673 } 674 } else { 675 rc = tls_set_device_offload_rx(sk, ctx); 676 conf = TLS_HW; 677 if (!rc) { 678 TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXDEVICE); 679 TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXDEVICE); 680 } else { 681 rc = tls_set_sw_offload(sk, ctx, 0); 682 if (rc) 683 goto err_crypto_info; 684 TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXSW); 685 TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXSW); 686 conf = TLS_SW; 687 } 688 tls_sw_strparser_arm(sk, ctx); 689 } 690 691 if (tx) 692 ctx->tx_conf = conf; 693 else 694 ctx->rx_conf = conf; 695 update_sk_prot(sk, ctx); 696 if (tx) { 697 ctx->sk_write_space = sk->sk_write_space; 698 sk->sk_write_space = tls_write_space; 699 } else { 700 struct tls_sw_context_rx *rx_ctx = tls_sw_ctx_rx(ctx); 701 702 tls_strp_check_rcv(&rx_ctx->strp); 703 } 704 return 0; 705 706 err_crypto_info: 707 memzero_explicit(crypto_info, sizeof(union tls_crypto_context)); 708 return rc; 709 } 710 711 static int do_tls_setsockopt_tx_zc(struct sock *sk, sockptr_t optval, 712 unsigned int optlen) 713 { 714 struct tls_context *ctx = tls_get_ctx(sk); 715 unsigned int value; 716 717 if (sockptr_is_null(optval) || optlen != sizeof(value)) 718 return -EINVAL; 719 720 if (copy_from_sockptr(&value, optval, sizeof(value))) 721 return -EFAULT; 722 723 if (value > 1) 724 return -EINVAL; 725 726 ctx->zerocopy_sendfile = value; 727 728 return 0; 729 } 730 731 static int do_tls_setsockopt_no_pad(struct sock *sk, sockptr_t optval, 732 unsigned int optlen) 733 { 734 struct tls_context *ctx = tls_get_ctx(sk); 735 u32 val; 736 int rc; 737 738 if (ctx->prot_info.version != TLS_1_3_VERSION || 739 sockptr_is_null(optval) || optlen < sizeof(val)) 740 return -EINVAL; 741 742 rc = copy_from_sockptr(&val, optval, sizeof(val)); 743 if (rc) 744 return -EFAULT; 745 if (val > 1) 746 return -EINVAL; 747 rc = check_zeroed_sockptr(optval, sizeof(val), optlen - sizeof(val)); 748 if (rc < 1) 749 return rc == 0 ? -EINVAL : rc; 750 751 lock_sock(sk); 752 rc = -EINVAL; 753 if (ctx->rx_conf == TLS_SW || ctx->rx_conf == TLS_HW) { 754 ctx->rx_no_pad = val; 755 tls_update_rx_zc_capable(ctx); 756 rc = 0; 757 } 758 release_sock(sk); 759 760 return rc; 761 } 762 763 static int do_tls_setsockopt(struct sock *sk, int optname, sockptr_t optval, 764 unsigned int optlen) 765 { 766 int rc = 0; 767 768 switch (optname) { 769 case TLS_TX: 770 case TLS_RX: 771 lock_sock(sk); 772 rc = do_tls_setsockopt_conf(sk, optval, optlen, 773 optname == TLS_TX); 774 release_sock(sk); 775 break; 776 case TLS_TX_ZEROCOPY_RO: 777 lock_sock(sk); 778 rc = do_tls_setsockopt_tx_zc(sk, optval, optlen); 779 release_sock(sk); 780 break; 781 case TLS_RX_EXPECT_NO_PAD: 782 rc = do_tls_setsockopt_no_pad(sk, optval, optlen); 783 break; 784 default: 785 rc = -ENOPROTOOPT; 786 break; 787 } 788 return rc; 789 } 790 791 static int tls_setsockopt(struct sock *sk, int level, int optname, 792 sockptr_t optval, unsigned int optlen) 793 { 794 struct tls_context *ctx = tls_get_ctx(sk); 795 796 if (level != SOL_TLS) 797 return ctx->sk_proto->setsockopt(sk, level, optname, optval, 798 optlen); 799 800 return do_tls_setsockopt(sk, optname, optval, optlen); 801 } 802 803 struct tls_context *tls_ctx_create(struct sock *sk) 804 { 805 struct inet_connection_sock *icsk = inet_csk(sk); 806 struct tls_context *ctx; 807 808 ctx = kzalloc(sizeof(*ctx), GFP_ATOMIC); 809 if (!ctx) 810 return NULL; 811 812 mutex_init(&ctx->tx_lock); 813 rcu_assign_pointer(icsk->icsk_ulp_data, ctx); 814 ctx->sk_proto = READ_ONCE(sk->sk_prot); 815 ctx->sk = sk; 816 return ctx; 817 } 818 819 static void build_proto_ops(struct proto_ops ops[TLS_NUM_CONFIG][TLS_NUM_CONFIG], 820 const struct proto_ops *base) 821 { 822 ops[TLS_BASE][TLS_BASE] = *base; 823 824 ops[TLS_SW ][TLS_BASE] = ops[TLS_BASE][TLS_BASE]; 825 ops[TLS_SW ][TLS_BASE].splice_eof = tls_sw_splice_eof; 826 827 ops[TLS_BASE][TLS_SW ] = ops[TLS_BASE][TLS_BASE]; 828 ops[TLS_BASE][TLS_SW ].splice_read = tls_sw_splice_read; 829 ops[TLS_BASE][TLS_SW ].poll = tls_sk_poll; 830 ops[TLS_BASE][TLS_SW ].read_sock = tls_sw_read_sock; 831 832 ops[TLS_SW ][TLS_SW ] = ops[TLS_SW ][TLS_BASE]; 833 ops[TLS_SW ][TLS_SW ].splice_read = tls_sw_splice_read; 834 ops[TLS_SW ][TLS_SW ].poll = tls_sk_poll; 835 ops[TLS_SW ][TLS_SW ].read_sock = tls_sw_read_sock; 836 837 #ifdef CONFIG_TLS_DEVICE 838 ops[TLS_HW ][TLS_BASE] = ops[TLS_BASE][TLS_BASE]; 839 840 ops[TLS_HW ][TLS_SW ] = ops[TLS_BASE][TLS_SW ]; 841 842 ops[TLS_BASE][TLS_HW ] = ops[TLS_BASE][TLS_SW ]; 843 844 ops[TLS_SW ][TLS_HW ] = ops[TLS_SW ][TLS_SW ]; 845 846 ops[TLS_HW ][TLS_HW ] = ops[TLS_HW ][TLS_SW ]; 847 #endif 848 #ifdef CONFIG_TLS_TOE 849 ops[TLS_HW_RECORD][TLS_HW_RECORD] = *base; 850 #endif 851 } 852 853 static void tls_build_proto(struct sock *sk) 854 { 855 int ip_ver = sk->sk_family == AF_INET6 ? TLSV6 : TLSV4; 856 struct proto *prot = READ_ONCE(sk->sk_prot); 857 858 /* Build IPv6 TLS whenever the address of tcpv6 _prot changes */ 859 if (ip_ver == TLSV6 && 860 unlikely(prot != smp_load_acquire(&saved_tcpv6_prot))) { 861 mutex_lock(&tcpv6_prot_mutex); 862 if (likely(prot != saved_tcpv6_prot)) { 863 build_protos(tls_prots[TLSV6], prot); 864 build_proto_ops(tls_proto_ops[TLSV6], 865 sk->sk_socket->ops); 866 smp_store_release(&saved_tcpv6_prot, prot); 867 } 868 mutex_unlock(&tcpv6_prot_mutex); 869 } 870 871 if (ip_ver == TLSV4 && 872 unlikely(prot != smp_load_acquire(&saved_tcpv4_prot))) { 873 mutex_lock(&tcpv4_prot_mutex); 874 if (likely(prot != saved_tcpv4_prot)) { 875 build_protos(tls_prots[TLSV4], prot); 876 build_proto_ops(tls_proto_ops[TLSV4], 877 sk->sk_socket->ops); 878 smp_store_release(&saved_tcpv4_prot, prot); 879 } 880 mutex_unlock(&tcpv4_prot_mutex); 881 } 882 } 883 884 static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG], 885 const struct proto *base) 886 { 887 prot[TLS_BASE][TLS_BASE] = *base; 888 prot[TLS_BASE][TLS_BASE].setsockopt = tls_setsockopt; 889 prot[TLS_BASE][TLS_BASE].getsockopt = tls_getsockopt; 890 prot[TLS_BASE][TLS_BASE].close = tls_sk_proto_close; 891 892 prot[TLS_SW][TLS_BASE] = prot[TLS_BASE][TLS_BASE]; 893 prot[TLS_SW][TLS_BASE].sendmsg = tls_sw_sendmsg; 894 prot[TLS_SW][TLS_BASE].splice_eof = tls_sw_splice_eof; 895 896 prot[TLS_BASE][TLS_SW] = prot[TLS_BASE][TLS_BASE]; 897 prot[TLS_BASE][TLS_SW].recvmsg = tls_sw_recvmsg; 898 prot[TLS_BASE][TLS_SW].sock_is_readable = tls_sw_sock_is_readable; 899 prot[TLS_BASE][TLS_SW].close = tls_sk_proto_close; 900 901 prot[TLS_SW][TLS_SW] = prot[TLS_SW][TLS_BASE]; 902 prot[TLS_SW][TLS_SW].recvmsg = tls_sw_recvmsg; 903 prot[TLS_SW][TLS_SW].sock_is_readable = tls_sw_sock_is_readable; 904 prot[TLS_SW][TLS_SW].close = tls_sk_proto_close; 905 906 #ifdef CONFIG_TLS_DEVICE 907 prot[TLS_HW][TLS_BASE] = prot[TLS_BASE][TLS_BASE]; 908 prot[TLS_HW][TLS_BASE].sendmsg = tls_device_sendmsg; 909 prot[TLS_HW][TLS_BASE].splice_eof = tls_device_splice_eof; 910 911 prot[TLS_HW][TLS_SW] = prot[TLS_BASE][TLS_SW]; 912 prot[TLS_HW][TLS_SW].sendmsg = tls_device_sendmsg; 913 prot[TLS_HW][TLS_SW].splice_eof = tls_device_splice_eof; 914 915 prot[TLS_BASE][TLS_HW] = prot[TLS_BASE][TLS_SW]; 916 917 prot[TLS_SW][TLS_HW] = prot[TLS_SW][TLS_SW]; 918 919 prot[TLS_HW][TLS_HW] = prot[TLS_HW][TLS_SW]; 920 #endif 921 #ifdef CONFIG_TLS_TOE 922 prot[TLS_HW_RECORD][TLS_HW_RECORD] = *base; 923 prot[TLS_HW_RECORD][TLS_HW_RECORD].hash = tls_toe_hash; 924 prot[TLS_HW_RECORD][TLS_HW_RECORD].unhash = tls_toe_unhash; 925 #endif 926 } 927 928 static int tls_init(struct sock *sk) 929 { 930 struct tls_context *ctx; 931 int rc = 0; 932 933 tls_build_proto(sk); 934 935 #ifdef CONFIG_TLS_TOE 936 if (tls_toe_bypass(sk)) 937 return 0; 938 #endif 939 940 /* The TLS ulp is currently supported only for TCP sockets 941 * in ESTABLISHED state. 942 * Supporting sockets in LISTEN state will require us 943 * to modify the accept implementation to clone rather then 944 * share the ulp context. 945 */ 946 if (sk->sk_state != TCP_ESTABLISHED) 947 return -ENOTCONN; 948 949 /* allocate tls context */ 950 write_lock_bh(&sk->sk_callback_lock); 951 ctx = tls_ctx_create(sk); 952 if (!ctx) { 953 rc = -ENOMEM; 954 goto out; 955 } 956 957 ctx->tx_conf = TLS_BASE; 958 ctx->rx_conf = TLS_BASE; 959 update_sk_prot(sk, ctx); 960 out: 961 write_unlock_bh(&sk->sk_callback_lock); 962 return rc; 963 } 964 965 static void tls_update(struct sock *sk, struct proto *p, 966 void (*write_space)(struct sock *sk)) 967 { 968 struct tls_context *ctx; 969 970 WARN_ON_ONCE(sk->sk_prot == p); 971 972 ctx = tls_get_ctx(sk); 973 if (likely(ctx)) { 974 ctx->sk_write_space = write_space; 975 ctx->sk_proto = p; 976 } else { 977 /* Pairs with lockless read in sk_clone_lock(). */ 978 WRITE_ONCE(sk->sk_prot, p); 979 sk->sk_write_space = write_space; 980 } 981 } 982 983 static u16 tls_user_config(struct tls_context *ctx, bool tx) 984 { 985 u16 config = tx ? ctx->tx_conf : ctx->rx_conf; 986 987 switch (config) { 988 case TLS_BASE: 989 return TLS_CONF_BASE; 990 case TLS_SW: 991 return TLS_CONF_SW; 992 case TLS_HW: 993 return TLS_CONF_HW; 994 case TLS_HW_RECORD: 995 return TLS_CONF_HW_RECORD; 996 } 997 return 0; 998 } 999 1000 static int tls_get_info(const struct sock *sk, struct sk_buff *skb) 1001 { 1002 u16 version, cipher_type; 1003 struct tls_context *ctx; 1004 struct nlattr *start; 1005 int err; 1006 1007 start = nla_nest_start_noflag(skb, INET_ULP_INFO_TLS); 1008 if (!start) 1009 return -EMSGSIZE; 1010 1011 rcu_read_lock(); 1012 ctx = rcu_dereference(inet_csk(sk)->icsk_ulp_data); 1013 if (!ctx) { 1014 err = 0; 1015 goto nla_failure; 1016 } 1017 version = ctx->prot_info.version; 1018 if (version) { 1019 err = nla_put_u16(skb, TLS_INFO_VERSION, version); 1020 if (err) 1021 goto nla_failure; 1022 } 1023 cipher_type = ctx->prot_info.cipher_type; 1024 if (cipher_type) { 1025 err = nla_put_u16(skb, TLS_INFO_CIPHER, cipher_type); 1026 if (err) 1027 goto nla_failure; 1028 } 1029 err = nla_put_u16(skb, TLS_INFO_TXCONF, tls_user_config(ctx, true)); 1030 if (err) 1031 goto nla_failure; 1032 1033 err = nla_put_u16(skb, TLS_INFO_RXCONF, tls_user_config(ctx, false)); 1034 if (err) 1035 goto nla_failure; 1036 1037 if (ctx->tx_conf == TLS_HW && ctx->zerocopy_sendfile) { 1038 err = nla_put_flag(skb, TLS_INFO_ZC_RO_TX); 1039 if (err) 1040 goto nla_failure; 1041 } 1042 if (ctx->rx_no_pad) { 1043 err = nla_put_flag(skb, TLS_INFO_RX_NO_PAD); 1044 if (err) 1045 goto nla_failure; 1046 } 1047 1048 rcu_read_unlock(); 1049 nla_nest_end(skb, start); 1050 return 0; 1051 1052 nla_failure: 1053 rcu_read_unlock(); 1054 nla_nest_cancel(skb, start); 1055 return err; 1056 } 1057 1058 static size_t tls_get_info_size(const struct sock *sk) 1059 { 1060 size_t size = 0; 1061 1062 size += nla_total_size(0) + /* INET_ULP_INFO_TLS */ 1063 nla_total_size(sizeof(u16)) + /* TLS_INFO_VERSION */ 1064 nla_total_size(sizeof(u16)) + /* TLS_INFO_CIPHER */ 1065 nla_total_size(sizeof(u16)) + /* TLS_INFO_RXCONF */ 1066 nla_total_size(sizeof(u16)) + /* TLS_INFO_TXCONF */ 1067 nla_total_size(0) + /* TLS_INFO_ZC_RO_TX */ 1068 nla_total_size(0) + /* TLS_INFO_RX_NO_PAD */ 1069 0; 1070 1071 return size; 1072 } 1073 1074 static int __net_init tls_init_net(struct net *net) 1075 { 1076 int err; 1077 1078 net->mib.tls_statistics = alloc_percpu(struct linux_tls_mib); 1079 if (!net->mib.tls_statistics) 1080 return -ENOMEM; 1081 1082 err = tls_proc_init(net); 1083 if (err) 1084 goto err_free_stats; 1085 1086 return 0; 1087 err_free_stats: 1088 free_percpu(net->mib.tls_statistics); 1089 return err; 1090 } 1091 1092 static void __net_exit tls_exit_net(struct net *net) 1093 { 1094 tls_proc_fini(net); 1095 free_percpu(net->mib.tls_statistics); 1096 } 1097 1098 static struct pernet_operations tls_proc_ops = { 1099 .init = tls_init_net, 1100 .exit = tls_exit_net, 1101 }; 1102 1103 static struct tcp_ulp_ops tcp_tls_ulp_ops __read_mostly = { 1104 .name = "tls", 1105 .owner = THIS_MODULE, 1106 .init = tls_init, 1107 .update = tls_update, 1108 .get_info = tls_get_info, 1109 .get_info_size = tls_get_info_size, 1110 }; 1111 1112 static int __init tls_register(void) 1113 { 1114 int err; 1115 1116 err = register_pernet_subsys(&tls_proc_ops); 1117 if (err) 1118 return err; 1119 1120 err = tls_strp_dev_init(); 1121 if (err) 1122 goto err_pernet; 1123 1124 err = tls_device_init(); 1125 if (err) 1126 goto err_strp; 1127 1128 tcp_register_ulp(&tcp_tls_ulp_ops); 1129 1130 return 0; 1131 err_strp: 1132 tls_strp_dev_exit(); 1133 err_pernet: 1134 unregister_pernet_subsys(&tls_proc_ops); 1135 return err; 1136 } 1137 1138 static void __exit tls_unregister(void) 1139 { 1140 tcp_unregister_ulp(&tcp_tls_ulp_ops); 1141 tls_strp_dev_exit(); 1142 tls_device_cleanup(); 1143 unregister_pernet_subsys(&tls_proc_ops); 1144 } 1145 1146 module_init(tls_register); 1147 module_exit(tls_unregister); 1148