/*
 * Copyright (c) 2016-2017, Mellanox Technologies. All rights reserved.
 * Copyright (c) 2016-2017, Dave Watson <davejwatson@fb.com>. All rights reserved.
 *
 * This software is available to you under a choice of one of two
 * licenses. You may choose to be licensed under the terms of the GNU
 * General Public License (GPL) Version 2, available from the file
 * COPYING in the main directory of this source tree, or the
 * OpenIB.org BSD license below:
 *
 *     Redistribution and use in source and binary forms, with or
 *     without modification, are permitted provided that the following
 *     conditions are met:
 *
 *      - Redistributions of source code must retain the above
 *        copyright notice, this list of conditions and the following
 *        disclaimer.
 *
 *      - Redistributions in binary form must reproduce the above
 *        copyright notice, this list of conditions and the following
 *        disclaimer in the documentation and/or other materials
 *        provided with the distribution.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
 * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
 * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
 * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
 * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
 * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */

#include <linux/module.h>

#include <net/tcp.h>
#include <net/inet_common.h>
#include <linux/highmem.h>
#include <linux/netdevice.h>
#include <linux/sched/signal.h>
#include <linux/inetdevice.h>
#include <linux/inet_diag.h>

#include <net/snmp.h>
#include <net/tls.h>
#include <net/tls_toe.h>

#include "tls.h"

MODULE_AUTHOR("Mellanox Technologies");
MODULE_DESCRIPTION("Transport Layer Security Support");
MODULE_LICENSE("Dual BSD/GPL");
MODULE_ALIAS_TCP_ULP("tls");

enum {
	TLSV4,
	TLSV6,
	TLS_NUM_PROTS,
};

#define CIPHER_SIZE_DESC(cipher) [cipher] = { \
	.iv = cipher ## _IV_SIZE, \
	.key = cipher ## _KEY_SIZE, \
	.salt = cipher ## _SALT_SIZE, \
	.tag = cipher ## _TAG_SIZE, \
	.rec_seq = cipher ## _REC_SEQ_SIZE, \
}

const struct tls_cipher_size_desc tls_cipher_size_desc[] = {
	CIPHER_SIZE_DESC(TLS_CIPHER_AES_GCM_128),
	CIPHER_SIZE_DESC(TLS_CIPHER_AES_GCM_256),
	CIPHER_SIZE_DESC(TLS_CIPHER_AES_CCM_128),
	CIPHER_SIZE_DESC(TLS_CIPHER_CHACHA20_POLY1305),
	CIPHER_SIZE_DESC(TLS_CIPHER_SM4_GCM),
	CIPHER_SIZE_DESC(TLS_CIPHER_SM4_CCM),
};
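
/* Illustrative only (editor's note, not part of the original file): the
 * token-pasting macro above builds each table entry from the cipher's UAPI
 * size constants. Assuming the TLS_CIPHER_AES_GCM_128_* values from
 * include/uapi/linux/tls.h, the first entry is equivalent to:
 *
 *	[TLS_CIPHER_AES_GCM_128] = {
 *		.iv      = TLS_CIPHER_AES_GCM_128_IV_SIZE,      // 8
 *		.key     = TLS_CIPHER_AES_GCM_128_KEY_SIZE,     // 16
 *		.salt    = TLS_CIPHER_AES_GCM_128_SALT_SIZE,    // 4
 *		.tag     = TLS_CIPHER_AES_GCM_128_TAG_SIZE,     // 16
 *		.rec_seq = TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE, // 8
 *	},
 */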

static const struct proto *saved_tcpv6_prot;
static DEFINE_MUTEX(tcpv6_prot_mutex);
static const struct proto *saved_tcpv4_prot;
static DEFINE_MUTEX(tcpv4_prot_mutex);
static struct proto tls_prots[TLS_NUM_PROTS][TLS_NUM_CONFIG][TLS_NUM_CONFIG];
static struct proto_ops tls_proto_ops[TLS_NUM_PROTS][TLS_NUM_CONFIG][TLS_NUM_CONFIG];
static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG],
			 const struct proto *base);

void update_sk_prot(struct sock *sk, struct tls_context *ctx)
{
	int ip_ver = sk->sk_family == AF_INET6 ? TLSV6 : TLSV4;

	WRITE_ONCE(sk->sk_prot,
		   &tls_prots[ip_ver][ctx->tx_conf][ctx->rx_conf]);
	WRITE_ONCE(sk->sk_socket->ops,
		   &tls_proto_ops[ip_ver][ctx->tx_conf][ctx->rx_conf]);
}

int wait_on_pending_writer(struct sock *sk, long *timeo)
{
	int rc = 0;
	DEFINE_WAIT_FUNC(wait, woken_wake_function);

	add_wait_queue(sk_sleep(sk), &wait);
	while (1) {
		if (!*timeo) {
			rc = -EAGAIN;
			break;
		}

		if (signal_pending(current)) {
			rc = sock_intr_errno(*timeo);
			break;
		}

		if (sk_wait_event(sk, timeo,
				  !READ_ONCE(sk->sk_write_pending), &wait))
			break;
	}
	remove_wait_queue(sk_sleep(sk), &wait);
	return rc;
}

int tls_push_sg(struct sock *sk,
		struct tls_context *ctx,
		struct scatterlist *sg,
		u16 first_offset,
		int flags)
{
	int sendpage_flags = flags | MSG_SENDPAGE_NOTLAST;
	int ret = 0;
	struct page *p;
	size_t size;
	int offset = first_offset;

	size = sg->length - offset;
	offset += sg->offset;

	ctx->in_tcp_sendpages = true;
	while (1) {
		if (sg_is_last(sg))
			sendpage_flags = flags;

		/* is sending application-limited? */
		tcp_rate_check_app_limited(sk);
		p = sg_page(sg);
retry:
		ret = do_tcp_sendpages(sk, p, offset, size, sendpage_flags);

		if (ret != size) {
			if (ret > 0) {
				offset += ret;
				size -= ret;
				goto retry;
			}

			offset -= sg->offset;
			ctx->partially_sent_offset = offset;
			ctx->partially_sent_record = (void *)sg;
			ctx->in_tcp_sendpages = false;
			return ret;
		}

		put_page(p);
		sk_mem_uncharge(sk, sg->length);
		sg = sg_next(sg);
		if (!sg)
			break;

		offset = sg->offset;
		size = sg->length;
	}

	ctx->in_tcp_sendpages = false;

	return 0;
}

static int tls_handle_open_record(struct sock *sk, int flags)
{
	struct tls_context *ctx = tls_get_ctx(sk);

	if (tls_is_pending_open_record(ctx))
		return ctx->push_pending_record(sk, flags);

	return 0;
}

int tls_process_cmsg(struct sock *sk, struct msghdr *msg,
		     unsigned char *record_type)
{
	struct cmsghdr *cmsg;
	int rc = -EINVAL;

	for_each_cmsghdr(cmsg, msg) {
		if (!CMSG_OK(msg, cmsg))
			return -EINVAL;
		if (cmsg->cmsg_level != SOL_TLS)
			continue;

		switch (cmsg->cmsg_type) {
		case TLS_SET_RECORD_TYPE:
			if (cmsg->cmsg_len < CMSG_LEN(sizeof(*record_type)))
				return -EINVAL;

			if (msg->msg_flags & MSG_MORE)
				return -EINVAL;

			rc = tls_handle_open_record(sk, msg->msg_flags);
			if (rc)
				return rc;

			*record_type = *(unsigned char *)CMSG_DATA(cmsg);
			rc = 0;
			break;
		default:
			return -EINVAL;
		}
	}

	return rc;
}
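
/* Illustrative only (editor's note, not part of the original file): a minimal
 * userspace sketch of the control-message interface parsed above. A non-DATA
 * record type (here a TLS alert, content type 21) is passed through a
 * SOL_TLS / TLS_SET_RECORD_TYPE cmsg on sendmsg(); values are placeholders:
 *
 *	char type = 21;                       // alert record
 *	char buf[] = { 1, 0 };                // warning, close_notify
 *	char cbuf[CMSG_SPACE(sizeof(type))];
 *	struct iovec iov = { buf, sizeof(buf) };
 *	struct msghdr msg = { .msg_iov = &iov, .msg_iovlen = 1,
 *			      .msg_control = cbuf,
 *			      .msg_controllen = sizeof(cbuf) };
 *	struct cmsghdr *cmsg = CMSG_FIRSTHDR(&msg);
 *
 *	cmsg->cmsg_level = SOL_TLS;
 *	cmsg->cmsg_type = TLS_SET_RECORD_TYPE;
 *	cmsg->cmsg_len = CMSG_LEN(sizeof(type));
 *	*CMSG_DATA(cmsg) = type;
 *	sendmsg(fd, &msg, 0);
 */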

int tls_push_partial_record(struct sock *sk, struct tls_context *ctx,
			    int flags)
{
	struct scatterlist *sg;
	u16 offset;

	sg = ctx->partially_sent_record;
	offset = ctx->partially_sent_offset;

	ctx->partially_sent_record = NULL;
	return tls_push_sg(sk, ctx, sg, offset, flags);
}

void tls_free_partial_record(struct sock *sk, struct tls_context *ctx)
{
	struct scatterlist *sg;

	for (sg = ctx->partially_sent_record; sg; sg = sg_next(sg)) {
		put_page(sg_page(sg));
		sk_mem_uncharge(sk, sg->length);
	}
	ctx->partially_sent_record = NULL;
}

static void tls_write_space(struct sock *sk)
{
	struct tls_context *ctx = tls_get_ctx(sk);

	/* If in_tcp_sendpages is set, call the lower protocol write space
	 * handler to ensure we wake up any waiting operations there. For
	 * example, if do_tcp_sendpages were to call sk_wait_event.
	 */
	if (ctx->in_tcp_sendpages) {
		ctx->sk_write_space(sk);
		return;
	}

#ifdef CONFIG_TLS_DEVICE
	if (ctx->tx_conf == TLS_HW)
		tls_device_write_space(sk, ctx);
	else
#endif
		tls_sw_write_space(sk, ctx);

	ctx->sk_write_space(sk);
}

/**
 * tls_ctx_free() - free TLS ULP context
 * @sk:  socket to which @ctx is attached
 * @ctx: TLS context structure
 *
 * Free TLS context. If @sk is %NULL caller guarantees that the socket
 * to which @ctx was attached has no outstanding references.
 */
void tls_ctx_free(struct sock *sk, struct tls_context *ctx)
{
	if (!ctx)
		return;

	memzero_explicit(&ctx->crypto_send, sizeof(ctx->crypto_send));
	memzero_explicit(&ctx->crypto_recv, sizeof(ctx->crypto_recv));
	mutex_destroy(&ctx->tx_lock);

	if (sk)
		kfree_rcu(ctx, rcu);
	else
		kfree(ctx);
}

static void tls_sk_proto_cleanup(struct sock *sk,
				 struct tls_context *ctx, long timeo)
{
	if (unlikely(sk->sk_write_pending) &&
	    !wait_on_pending_writer(sk, &timeo))
		tls_handle_open_record(sk, 0);

	/* We need these for tls_sw_fallback handling of other packets */
	if (ctx->tx_conf == TLS_SW) {
		kfree(ctx->tx.rec_seq);
		kfree(ctx->tx.iv);
		tls_sw_release_resources_tx(sk);
		TLS_DEC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXSW);
	} else if (ctx->tx_conf == TLS_HW) {
		tls_device_free_resources_tx(sk);
		TLS_DEC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXDEVICE);
	}

	if (ctx->rx_conf == TLS_SW) {
		tls_sw_release_resources_rx(sk);
		TLS_DEC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXSW);
	} else if (ctx->rx_conf == TLS_HW) {
		tls_device_offload_cleanup_rx(sk);
		TLS_DEC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXDEVICE);
	}
}

static void tls_sk_proto_close(struct sock *sk, long timeout)
{
	struct inet_connection_sock *icsk = inet_csk(sk);
	struct tls_context *ctx = tls_get_ctx(sk);
	long timeo = sock_sndtimeo(sk, 0);
	bool free_ctx;

	if (ctx->tx_conf == TLS_SW)
		tls_sw_cancel_work_tx(ctx);

	lock_sock(sk);
	free_ctx = ctx->tx_conf != TLS_HW && ctx->rx_conf != TLS_HW;

	if (ctx->tx_conf != TLS_BASE || ctx->rx_conf != TLS_BASE)
		tls_sk_proto_cleanup(sk, ctx, timeo);

	write_lock_bh(&sk->sk_callback_lock);
	if (free_ctx)
		rcu_assign_pointer(icsk->icsk_ulp_data, NULL);
	WRITE_ONCE(sk->sk_prot, ctx->sk_proto);
	if (sk->sk_write_space == tls_write_space)
		sk->sk_write_space = ctx->sk_write_space;
	write_unlock_bh(&sk->sk_callback_lock);
	release_sock(sk);
	if (ctx->tx_conf == TLS_SW)
		tls_sw_free_ctx_tx(ctx);
	if (ctx->rx_conf == TLS_SW || ctx->rx_conf == TLS_HW)
		tls_sw_strparser_done(ctx);
	if (ctx->rx_conf == TLS_SW)
		tls_sw_free_ctx_rx(ctx);
	ctx->sk_proto->close(sk, timeout);

	if (free_ctx)
		tls_ctx_free(sk, ctx);
}
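
/* Illustrative only (editor's note, not part of the original file): a
 * userspace sketch of the getsockopt() path handled below. Passing a
 * struct tls_crypto_info-sized buffer returns just version and cipher_type;
 * passing the full per-cipher struct also returns the current iv/rec_seq:
 *
 *	struct tls12_crypto_info_aes_gcm_128 info;
 *	socklen_t len = sizeof(info);
 *
 *	if (getsockopt(fd, SOL_TLS, TLS_TX, &info, &len) == 0) {
 *		// info.info.version, info.info.cipher_type,
 *		// info.iv, info.rec_seq, ...
 *	}
 */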

static int do_tls_getsockopt_conf(struct sock *sk, char __user *optval,
				  int __user *optlen, int tx)
{
	int rc = 0;
	struct tls_context *ctx = tls_get_ctx(sk);
	struct tls_crypto_info *crypto_info;
	struct cipher_context *cctx;
	int len;

	if (get_user(len, optlen))
		return -EFAULT;

	if (!optval || (len < sizeof(*crypto_info))) {
		rc = -EINVAL;
		goto out;
	}

	if (!ctx) {
		rc = -EBUSY;
		goto out;
	}

	/* get user crypto info */
	if (tx) {
		crypto_info = &ctx->crypto_send.info;
		cctx = &ctx->tx;
	} else {
		crypto_info = &ctx->crypto_recv.info;
		cctx = &ctx->rx;
	}

	if (!TLS_CRYPTO_INFO_READY(crypto_info)) {
		rc = -EBUSY;
		goto out;
	}

	if (len == sizeof(*crypto_info)) {
		if (copy_to_user(optval, crypto_info, sizeof(*crypto_info)))
			rc = -EFAULT;
		goto out;
	}

	switch (crypto_info->cipher_type) {
	case TLS_CIPHER_AES_GCM_128: {
		struct tls12_crypto_info_aes_gcm_128 *crypto_info_aes_gcm_128 =
			container_of(crypto_info,
				     struct tls12_crypto_info_aes_gcm_128,
				     info);

		if (len != sizeof(*crypto_info_aes_gcm_128)) {
			rc = -EINVAL;
			goto out;
		}
		memcpy(crypto_info_aes_gcm_128->iv,
		       cctx->iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
		       TLS_CIPHER_AES_GCM_128_IV_SIZE);
		memcpy(crypto_info_aes_gcm_128->rec_seq, cctx->rec_seq,
		       TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE);
		if (copy_to_user(optval,
				 crypto_info_aes_gcm_128,
				 sizeof(*crypto_info_aes_gcm_128)))
			rc = -EFAULT;
		break;
	}
	case TLS_CIPHER_AES_GCM_256: {
		struct tls12_crypto_info_aes_gcm_256 *crypto_info_aes_gcm_256 =
			container_of(crypto_info,
				     struct tls12_crypto_info_aes_gcm_256,
				     info);

		if (len != sizeof(*crypto_info_aes_gcm_256)) {
			rc = -EINVAL;
			goto out;
		}
		memcpy(crypto_info_aes_gcm_256->iv,
		       cctx->iv + TLS_CIPHER_AES_GCM_256_SALT_SIZE,
		       TLS_CIPHER_AES_GCM_256_IV_SIZE);
		memcpy(crypto_info_aes_gcm_256->rec_seq, cctx->rec_seq,
		       TLS_CIPHER_AES_GCM_256_REC_SEQ_SIZE);
		if (copy_to_user(optval,
				 crypto_info_aes_gcm_256,
				 sizeof(*crypto_info_aes_gcm_256)))
			rc = -EFAULT;
		break;
	}
	case TLS_CIPHER_AES_CCM_128: {
		struct tls12_crypto_info_aes_ccm_128 *aes_ccm_128 =
			container_of(crypto_info,
				     struct tls12_crypto_info_aes_ccm_128, info);

		if (len != sizeof(*aes_ccm_128)) {
			rc = -EINVAL;
			goto out;
		}
		memcpy(aes_ccm_128->iv,
		       cctx->iv + TLS_CIPHER_AES_CCM_128_SALT_SIZE,
		       TLS_CIPHER_AES_CCM_128_IV_SIZE);
		memcpy(aes_ccm_128->rec_seq, cctx->rec_seq,
		       TLS_CIPHER_AES_CCM_128_REC_SEQ_SIZE);
		if (copy_to_user(optval, aes_ccm_128, sizeof(*aes_ccm_128)))
			rc = -EFAULT;
		break;
	}
	case TLS_CIPHER_CHACHA20_POLY1305: {
		struct tls12_crypto_info_chacha20_poly1305 *chacha20_poly1305 =
			container_of(crypto_info,
				     struct tls12_crypto_info_chacha20_poly1305,
				     info);

		if (len != sizeof(*chacha20_poly1305)) {
			rc = -EINVAL;
			goto out;
		}
		memcpy(chacha20_poly1305->iv,
		       cctx->iv + TLS_CIPHER_CHACHA20_POLY1305_SALT_SIZE,
		       TLS_CIPHER_CHACHA20_POLY1305_IV_SIZE);
		memcpy(chacha20_poly1305->rec_seq, cctx->rec_seq,
		       TLS_CIPHER_CHACHA20_POLY1305_REC_SEQ_SIZE);
		if (copy_to_user(optval, chacha20_poly1305,
				 sizeof(*chacha20_poly1305)))
			rc = -EFAULT;
		break;
	}
	case TLS_CIPHER_SM4_GCM: {
		struct tls12_crypto_info_sm4_gcm *sm4_gcm_info =
			container_of(crypto_info,
				     struct tls12_crypto_info_sm4_gcm, info);

		if (len != sizeof(*sm4_gcm_info)) {
			rc = -EINVAL;
			goto out;
		}
		memcpy(sm4_gcm_info->iv,
		       cctx->iv + TLS_CIPHER_SM4_GCM_SALT_SIZE,
		       TLS_CIPHER_SM4_GCM_IV_SIZE);
		memcpy(sm4_gcm_info->rec_seq, cctx->rec_seq,
		       TLS_CIPHER_SM4_GCM_REC_SEQ_SIZE);
		if (copy_to_user(optval, sm4_gcm_info, sizeof(*sm4_gcm_info)))
			rc = -EFAULT;
		break;
	}
	case TLS_CIPHER_SM4_CCM: {
		struct tls12_crypto_info_sm4_ccm *sm4_ccm_info =
			container_of(crypto_info,
				     struct tls12_crypto_info_sm4_ccm, info);

		if (len != sizeof(*sm4_ccm_info)) {
			rc = -EINVAL;
			goto out;
		}
		memcpy(sm4_ccm_info->iv,
		       cctx->iv + TLS_CIPHER_SM4_CCM_SALT_SIZE,
		       TLS_CIPHER_SM4_CCM_IV_SIZE);
		memcpy(sm4_ccm_info->rec_seq, cctx->rec_seq,
		       TLS_CIPHER_SM4_CCM_REC_SEQ_SIZE);
		if (copy_to_user(optval, sm4_ccm_info, sizeof(*sm4_ccm_info)))
			rc = -EFAULT;
		break;
	}
	case TLS_CIPHER_ARIA_GCM_128: {
		struct tls12_crypto_info_aria_gcm_128 *crypto_info_aria_gcm_128 =
			container_of(crypto_info,
				     struct tls12_crypto_info_aria_gcm_128,
				     info);

		if (len != sizeof(*crypto_info_aria_gcm_128)) {
			rc = -EINVAL;
			goto out;
		}
		memcpy(crypto_info_aria_gcm_128->iv,
		       cctx->iv + TLS_CIPHER_ARIA_GCM_128_SALT_SIZE,
		       TLS_CIPHER_ARIA_GCM_128_IV_SIZE);
		memcpy(crypto_info_aria_gcm_128->rec_seq, cctx->rec_seq,
		       TLS_CIPHER_ARIA_GCM_128_REC_SEQ_SIZE);
		if (copy_to_user(optval,
				 crypto_info_aria_gcm_128,
				 sizeof(*crypto_info_aria_gcm_128)))
			rc = -EFAULT;
		break;
	}
	case TLS_CIPHER_ARIA_GCM_256: {
		struct tls12_crypto_info_aria_gcm_256 *crypto_info_aria_gcm_256 =
			container_of(crypto_info,
				     struct tls12_crypto_info_aria_gcm_256,
				     info);

		if (len != sizeof(*crypto_info_aria_gcm_256)) {
			rc = -EINVAL;
			goto out;
		}
		memcpy(crypto_info_aria_gcm_256->iv,
		       cctx->iv + TLS_CIPHER_ARIA_GCM_256_SALT_SIZE,
		       TLS_CIPHER_ARIA_GCM_256_IV_SIZE);
		memcpy(crypto_info_aria_gcm_256->rec_seq, cctx->rec_seq,
		       TLS_CIPHER_ARIA_GCM_256_REC_SEQ_SIZE);
		if (copy_to_user(optval,
				 crypto_info_aria_gcm_256,
				 sizeof(*crypto_info_aria_gcm_256)))
			rc = -EFAULT;
		break;
	}
	default:
		rc = -EINVAL;
	}

out:
	return rc;
}

static int do_tls_getsockopt_tx_zc(struct sock *sk, char __user *optval,
				   int __user *optlen)
{
	struct tls_context *ctx = tls_get_ctx(sk);
	unsigned int value;
	int len;

	if (get_user(len, optlen))
		return -EFAULT;

	if (len != sizeof(value))
		return -EINVAL;

	value = ctx->zerocopy_sendfile;
	if (copy_to_user(optval, &value, sizeof(value)))
		return -EFAULT;

	return 0;
}

static int do_tls_getsockopt_no_pad(struct sock *sk, char __user *optval,
				    int __user *optlen)
{
	struct tls_context *ctx = tls_get_ctx(sk);
	int value, len;

	if (ctx->prot_info.version != TLS_1_3_VERSION)
		return -EINVAL;

	if (get_user(len, optlen))
		return -EFAULT;
	if (len < sizeof(value))
		return -EINVAL;

	value = -EINVAL;
	if (ctx->rx_conf == TLS_SW || ctx->rx_conf == TLS_HW)
		value = ctx->rx_no_pad;
	if (value < 0)
		return value;

	if (put_user(sizeof(value), optlen))
		return -EFAULT;
	if (copy_to_user(optval, &value, sizeof(value)))
		return -EFAULT;

	return 0;
}

static int do_tls_getsockopt(struct sock *sk, int optname,
			     char __user *optval, int __user *optlen)
{
	int rc = 0;

	lock_sock(sk);

	switch (optname) {
	case TLS_TX:
	case TLS_RX:
		rc = do_tls_getsockopt_conf(sk, optval, optlen,
					    optname == TLS_TX);
		break;
	case TLS_TX_ZEROCOPY_RO:
		rc = do_tls_getsockopt_tx_zc(sk, optval, optlen);
		break;
	case TLS_RX_EXPECT_NO_PAD:
		rc = do_tls_getsockopt_no_pad(sk, optval, optlen);
		break;
	default:
		rc = -ENOPROTOOPT;
		break;
	}

	release_sock(sk);

	return rc;
}

static int tls_getsockopt(struct sock *sk, int level, int optname,
			  char __user *optval, int __user *optlen)
{
	struct tls_context *ctx = tls_get_ctx(sk);

	if (level != SOL_TLS)
		return ctx->sk_proto->getsockopt(sk, level,
						 optname, optval, optlen);

	return do_tls_getsockopt(sk, optname, optval, optlen);
}
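
/* Illustrative only (editor's note, not part of the original file): a minimal
 * userspace sketch of the setsockopt() path handled below, enabling TX with
 * AES-GCM-128 after the "tls" ULP has been attached. The key/IV fields are
 * placeholders; real code copies the secrets negotiated by the handshake:
 *
 *	struct tls12_crypto_info_aes_gcm_128 ci = {
 *		.info.version = TLS_1_2_VERSION,
 *		.info.cipher_type = TLS_CIPHER_AES_GCM_128,
 *	};
 *	// memcpy(ci.key, ...); memcpy(ci.iv, ...);
 *	// memcpy(ci.salt, ...); memcpy(ci.rec_seq, ...);
 *	setsockopt(fd, SOL_TLS, TLS_TX, &ci, sizeof(ci));
 */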

static int do_tls_setsockopt_conf(struct sock *sk, sockptr_t optval,
				  unsigned int optlen, int tx)
{
	struct tls_crypto_info *crypto_info;
	struct tls_crypto_info *alt_crypto_info;
	struct tls_context *ctx = tls_get_ctx(sk);
	size_t optsize;
	int rc = 0;
	int conf;

	if (sockptr_is_null(optval) || (optlen < sizeof(*crypto_info)))
		return -EINVAL;

	if (tx) {
		crypto_info = &ctx->crypto_send.info;
		alt_crypto_info = &ctx->crypto_recv.info;
	} else {
		crypto_info = &ctx->crypto_recv.info;
		alt_crypto_info = &ctx->crypto_send.info;
	}

	/* Currently we don't support setting crypto info more than once */
	if (TLS_CRYPTO_INFO_READY(crypto_info))
		return -EBUSY;

	rc = copy_from_sockptr(crypto_info, optval, sizeof(*crypto_info));
	if (rc) {
		rc = -EFAULT;
		goto err_crypto_info;
	}

	/* check version */
	if (crypto_info->version != TLS_1_2_VERSION &&
	    crypto_info->version != TLS_1_3_VERSION) {
		rc = -EINVAL;
		goto err_crypto_info;
	}

	/* Ensure that TLS version and ciphers are the same in both directions */
	if (TLS_CRYPTO_INFO_READY(alt_crypto_info)) {
		if (alt_crypto_info->version != crypto_info->version ||
		    alt_crypto_info->cipher_type != crypto_info->cipher_type) {
			rc = -EINVAL;
			goto err_crypto_info;
		}
	}

	switch (crypto_info->cipher_type) {
	case TLS_CIPHER_AES_GCM_128:
		optsize = sizeof(struct tls12_crypto_info_aes_gcm_128);
		break;
	case TLS_CIPHER_AES_GCM_256:
		optsize = sizeof(struct tls12_crypto_info_aes_gcm_256);
		break;
	case TLS_CIPHER_AES_CCM_128:
		optsize = sizeof(struct tls12_crypto_info_aes_ccm_128);
		break;
	case TLS_CIPHER_CHACHA20_POLY1305:
		optsize = sizeof(struct tls12_crypto_info_chacha20_poly1305);
		break;
	case TLS_CIPHER_SM4_GCM:
		optsize = sizeof(struct tls12_crypto_info_sm4_gcm);
		break;
	case TLS_CIPHER_SM4_CCM:
		optsize = sizeof(struct tls12_crypto_info_sm4_ccm);
		break;
	case TLS_CIPHER_ARIA_GCM_128:
		if (crypto_info->version != TLS_1_2_VERSION) {
			rc = -EINVAL;
			goto err_crypto_info;
		}
		optsize = sizeof(struct tls12_crypto_info_aria_gcm_128);
		break;
	case TLS_CIPHER_ARIA_GCM_256:
		if (crypto_info->version != TLS_1_2_VERSION) {
			rc = -EINVAL;
			goto err_crypto_info;
		}
		optsize = sizeof(struct tls12_crypto_info_aria_gcm_256);
		break;
	default:
		rc = -EINVAL;
		goto err_crypto_info;
	}

	if (optlen != optsize) {
		rc = -EINVAL;
		goto err_crypto_info;
	}

	rc = copy_from_sockptr_offset(crypto_info + 1, optval,
				      sizeof(*crypto_info),
				      optlen - sizeof(*crypto_info));
	if (rc) {
		rc = -EFAULT;
		goto err_crypto_info;
	}

	if (tx) {
		rc = tls_set_device_offload(sk, ctx);
		conf = TLS_HW;
		if (!rc) {
			TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSTXDEVICE);
			TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXDEVICE);
		} else {
			rc = tls_set_sw_offload(sk, ctx, 1);
			if (rc)
				goto err_crypto_info;
			TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSTXSW);
			TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXSW);
			conf = TLS_SW;
		}
	} else {
		rc = tls_set_device_offload_rx(sk, ctx);
		conf = TLS_HW;
		if (!rc) {
			TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXDEVICE);
			TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXDEVICE);
		} else {
			rc = tls_set_sw_offload(sk, ctx, 0);
			if (rc)
				goto err_crypto_info;
			TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXSW);
			TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXSW);
			conf = TLS_SW;
		}
		tls_sw_strparser_arm(sk, ctx);
	}

	if (tx)
		ctx->tx_conf = conf;
	else
		ctx->rx_conf = conf;
	update_sk_prot(sk, ctx);
	if (tx) {
		ctx->sk_write_space = sk->sk_write_space;
		sk->sk_write_space = tls_write_space;
	} else {
		struct tls_sw_context_rx *rx_ctx = tls_sw_ctx_rx(ctx);

		tls_strp_check_rcv(&rx_ctx->strp);
	}
	return 0;

err_crypto_info:
	memzero_explicit(crypto_info, sizeof(union tls_crypto_context));
	return rc;
}
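
/* Illustrative only (editor's note, not part of the original file): the
 * zerocopy-sendfile knob handled below takes a plain 0/1 unsigned int, e.g.:
 *
 *	unsigned int one = 1;
 *
 *	setsockopt(fd, SOL_TLS, TLS_TX_ZEROCOPY_RO, &one, sizeof(one));
 *
 * The _RO suffix reflects that the caller promises the transmitted file
 * pages are effectively read-only while the send is in flight.
 */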

static int do_tls_setsockopt_tx_zc(struct sock *sk, sockptr_t optval,
				   unsigned int optlen)
{
	struct tls_context *ctx = tls_get_ctx(sk);
	unsigned int value;

	if (sockptr_is_null(optval) || optlen != sizeof(value))
		return -EINVAL;

	if (copy_from_sockptr(&value, optval, sizeof(value)))
		return -EFAULT;

	if (value > 1)
		return -EINVAL;

	ctx->zerocopy_sendfile = value;

	return 0;
}

static int do_tls_setsockopt_no_pad(struct sock *sk, sockptr_t optval,
				    unsigned int optlen)
{
	struct tls_context *ctx = tls_get_ctx(sk);
	u32 val;
	int rc;

	if (ctx->prot_info.version != TLS_1_3_VERSION ||
	    sockptr_is_null(optval) || optlen < sizeof(val))
		return -EINVAL;

	rc = copy_from_sockptr(&val, optval, sizeof(val));
	if (rc)
		return -EFAULT;
	if (val > 1)
		return -EINVAL;
	rc = check_zeroed_sockptr(optval, sizeof(val), optlen - sizeof(val));
	if (rc < 1)
		return rc == 0 ? -EINVAL : rc;

	lock_sock(sk);
	rc = -EINVAL;
	if (ctx->rx_conf == TLS_SW || ctx->rx_conf == TLS_HW) {
		ctx->rx_no_pad = val;
		tls_update_rx_zc_capable(ctx);
		rc = 0;
	}
	release_sock(sk);

	return rc;
}

static int do_tls_setsockopt(struct sock *sk, int optname, sockptr_t optval,
			     unsigned int optlen)
{
	int rc = 0;

	switch (optname) {
	case TLS_TX:
	case TLS_RX:
		lock_sock(sk);
		rc = do_tls_setsockopt_conf(sk, optval, optlen,
					    optname == TLS_TX);
		release_sock(sk);
		break;
	case TLS_TX_ZEROCOPY_RO:
		lock_sock(sk);
		rc = do_tls_setsockopt_tx_zc(sk, optval, optlen);
		release_sock(sk);
		break;
	case TLS_RX_EXPECT_NO_PAD:
		rc = do_tls_setsockopt_no_pad(sk, optval, optlen);
		break;
	default:
		rc = -ENOPROTOOPT;
		break;
	}
	return rc;
}

static int tls_setsockopt(struct sock *sk, int level, int optname,
			  sockptr_t optval, unsigned int optlen)
{
	struct tls_context *ctx = tls_get_ctx(sk);

	if (level != SOL_TLS)
		return ctx->sk_proto->setsockopt(sk, level, optname, optval,
						 optlen);

	return do_tls_setsockopt(sk, optname, optval, optlen);
}

struct tls_context *tls_ctx_create(struct sock *sk)
{
	struct inet_connection_sock *icsk = inet_csk(sk);
	struct tls_context *ctx;

	ctx = kzalloc(sizeof(*ctx), GFP_ATOMIC);
	if (!ctx)
		return NULL;

	mutex_init(&ctx->tx_lock);
	rcu_assign_pointer(icsk->icsk_ulp_data, ctx);
	ctx->sk_proto = READ_ONCE(sk->sk_prot);
	ctx->sk = sk;
	return ctx;
}

static void build_proto_ops(struct proto_ops ops[TLS_NUM_CONFIG][TLS_NUM_CONFIG],
			    const struct proto_ops *base)
{
	ops[TLS_BASE][TLS_BASE] = *base;

	ops[TLS_SW  ][TLS_BASE] = ops[TLS_BASE][TLS_BASE];
	ops[TLS_SW  ][TLS_BASE].sendpage_locked = tls_sw_sendpage_locked;

	ops[TLS_BASE][TLS_SW  ] = ops[TLS_BASE][TLS_BASE];
	ops[TLS_BASE][TLS_SW  ].splice_read = tls_sw_splice_read;

	ops[TLS_SW  ][TLS_SW  ] = ops[TLS_SW  ][TLS_BASE];
	ops[TLS_SW  ][TLS_SW  ].splice_read = tls_sw_splice_read;

#ifdef CONFIG_TLS_DEVICE
	ops[TLS_HW  ][TLS_BASE] = ops[TLS_BASE][TLS_BASE];
	ops[TLS_HW  ][TLS_BASE].sendpage_locked = NULL;

	ops[TLS_HW  ][TLS_SW  ] = ops[TLS_BASE][TLS_SW  ];
	ops[TLS_HW  ][TLS_SW  ].sendpage_locked = NULL;

	ops[TLS_BASE][TLS_HW  ] = ops[TLS_BASE][TLS_SW  ];

	ops[TLS_SW  ][TLS_HW  ] = ops[TLS_SW  ][TLS_SW  ];

	ops[TLS_HW  ][TLS_HW  ] = ops[TLS_HW  ][TLS_SW  ];
	ops[TLS_HW  ][TLS_HW  ].sendpage_locked = NULL;
#endif
#ifdef CONFIG_TLS_TOE
	ops[TLS_HW_RECORD][TLS_HW_RECORD] = *base;
#endif
}
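
/* Illustrative only (editor's note, not part of the original file): both
 * per-family tables are indexed as [TX config][RX config]. For example,
 * after enabling software TX on an IPv4 socket while leaving RX untouched,
 * update_sk_prot() points the socket at:
 *
 *	&tls_prots[TLSV4][TLS_SW][TLS_BASE]
 *	&tls_proto_ops[TLSV4][TLS_SW][TLS_BASE]
 */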

static void tls_build_proto(struct sock *sk)
{
	int ip_ver = sk->sk_family == AF_INET6 ? TLSV6 : TLSV4;
	struct proto *prot = READ_ONCE(sk->sk_prot);

	/* Build IPv6 TLS whenever the address of tcpv6_prot changes */
	if (ip_ver == TLSV6 &&
	    unlikely(prot != smp_load_acquire(&saved_tcpv6_prot))) {
		mutex_lock(&tcpv6_prot_mutex);
		if (likely(prot != saved_tcpv6_prot)) {
			build_protos(tls_prots[TLSV6], prot);
			build_proto_ops(tls_proto_ops[TLSV6],
					sk->sk_socket->ops);
			smp_store_release(&saved_tcpv6_prot, prot);
		}
		mutex_unlock(&tcpv6_prot_mutex);
	}

	if (ip_ver == TLSV4 &&
	    unlikely(prot != smp_load_acquire(&saved_tcpv4_prot))) {
		mutex_lock(&tcpv4_prot_mutex);
		if (likely(prot != saved_tcpv4_prot)) {
			build_protos(tls_prots[TLSV4], prot);
			build_proto_ops(tls_proto_ops[TLSV4],
					sk->sk_socket->ops);
			smp_store_release(&saved_tcpv4_prot, prot);
		}
		mutex_unlock(&tcpv4_prot_mutex);
	}
}

static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG],
			 const struct proto *base)
{
	prot[TLS_BASE][TLS_BASE] = *base;
	prot[TLS_BASE][TLS_BASE].setsockopt = tls_setsockopt;
	prot[TLS_BASE][TLS_BASE].getsockopt = tls_getsockopt;
	prot[TLS_BASE][TLS_BASE].close = tls_sk_proto_close;

	prot[TLS_SW][TLS_BASE] = prot[TLS_BASE][TLS_BASE];
	prot[TLS_SW][TLS_BASE].sendmsg = tls_sw_sendmsg;
	prot[TLS_SW][TLS_BASE].sendpage = tls_sw_sendpage;

	prot[TLS_BASE][TLS_SW] = prot[TLS_BASE][TLS_BASE];
	prot[TLS_BASE][TLS_SW].recvmsg = tls_sw_recvmsg;
	prot[TLS_BASE][TLS_SW].sock_is_readable = tls_sw_sock_is_readable;
	prot[TLS_BASE][TLS_SW].close = tls_sk_proto_close;

	prot[TLS_SW][TLS_SW] = prot[TLS_SW][TLS_BASE];
	prot[TLS_SW][TLS_SW].recvmsg = tls_sw_recvmsg;
	prot[TLS_SW][TLS_SW].sock_is_readable = tls_sw_sock_is_readable;
	prot[TLS_SW][TLS_SW].close = tls_sk_proto_close;

#ifdef CONFIG_TLS_DEVICE
	prot[TLS_HW][TLS_BASE] = prot[TLS_BASE][TLS_BASE];
	prot[TLS_HW][TLS_BASE].sendmsg = tls_device_sendmsg;
	prot[TLS_HW][TLS_BASE].sendpage = tls_device_sendpage;

	prot[TLS_HW][TLS_SW] = prot[TLS_BASE][TLS_SW];
	prot[TLS_HW][TLS_SW].sendmsg = tls_device_sendmsg;
	prot[TLS_HW][TLS_SW].sendpage = tls_device_sendpage;

	prot[TLS_BASE][TLS_HW] = prot[TLS_BASE][TLS_SW];

	prot[TLS_SW][TLS_HW] = prot[TLS_SW][TLS_SW];

	prot[TLS_HW][TLS_HW] = prot[TLS_HW][TLS_SW];
#endif
#ifdef CONFIG_TLS_TOE
	prot[TLS_HW_RECORD][TLS_HW_RECORD] = *base;
	prot[TLS_HW_RECORD][TLS_HW_RECORD].hash = tls_toe_hash;
	prot[TLS_HW_RECORD][TLS_HW_RECORD].unhash = tls_toe_unhash;
#endif
}
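
/* Illustrative only (editor's note, not part of the original file): tls_init()
 * below runs when userspace attaches this ULP to an established TCP socket:
 *
 *	setsockopt(fd, SOL_TCP, TCP_ULP, "tls", sizeof("tls"));
 *
 * Only after that can SOL_TLS options such as TLS_TX / TLS_RX be set.
 */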

static int tls_init(struct sock *sk)
{
	struct tls_context *ctx;
	int rc = 0;

	tls_build_proto(sk);

#ifdef CONFIG_TLS_TOE
	if (tls_toe_bypass(sk))
		return 0;
#endif

	/* The TLS ulp is currently supported only for TCP sockets
	 * in ESTABLISHED state.
	 * Supporting sockets in LISTEN state will require us
	 * to modify the accept implementation to clone rather than
	 * share the ulp context.
	 */
	if (sk->sk_state != TCP_ESTABLISHED)
		return -ENOTCONN;

	/* allocate tls context */
	write_lock_bh(&sk->sk_callback_lock);
	ctx = tls_ctx_create(sk);
	if (!ctx) {
		rc = -ENOMEM;
		goto out;
	}

	ctx->tx_conf = TLS_BASE;
	ctx->rx_conf = TLS_BASE;
	update_sk_prot(sk, ctx);
out:
	write_unlock_bh(&sk->sk_callback_lock);
	return rc;
}

static void tls_update(struct sock *sk, struct proto *p,
		       void (*write_space)(struct sock *sk))
{
	struct tls_context *ctx;

	WARN_ON_ONCE(sk->sk_prot == p);

	ctx = tls_get_ctx(sk);
	if (likely(ctx)) {
		ctx->sk_write_space = write_space;
		ctx->sk_proto = p;
	} else {
		/* Pairs with lockless read in sk_clone_lock(). */
		WRITE_ONCE(sk->sk_prot, p);
		sk->sk_write_space = write_space;
	}
}

static u16 tls_user_config(struct tls_context *ctx, bool tx)
{
	u16 config = tx ? ctx->tx_conf : ctx->rx_conf;

	switch (config) {
	case TLS_BASE:
		return TLS_CONF_BASE;
	case TLS_SW:
		return TLS_CONF_SW;
	case TLS_HW:
		return TLS_CONF_HW;
	case TLS_HW_RECORD:
		return TLS_CONF_HW_RECORD;
	}
	return 0;
}

static int tls_get_info(const struct sock *sk, struct sk_buff *skb)
{
	u16 version, cipher_type;
	struct tls_context *ctx;
	struct nlattr *start;
	int err;

	start = nla_nest_start_noflag(skb, INET_ULP_INFO_TLS);
	if (!start)
		return -EMSGSIZE;

	rcu_read_lock();
	ctx = rcu_dereference(inet_csk(sk)->icsk_ulp_data);
	if (!ctx) {
		err = 0;
		goto nla_failure;
	}
	version = ctx->prot_info.version;
	if (version) {
		err = nla_put_u16(skb, TLS_INFO_VERSION, version);
		if (err)
			goto nla_failure;
	}
	cipher_type = ctx->prot_info.cipher_type;
	if (cipher_type) {
		err = nla_put_u16(skb, TLS_INFO_CIPHER, cipher_type);
		if (err)
			goto nla_failure;
	}
	err = nla_put_u16(skb, TLS_INFO_TXCONF, tls_user_config(ctx, true));
	if (err)
		goto nla_failure;

	err = nla_put_u16(skb, TLS_INFO_RXCONF, tls_user_config(ctx, false));
	if (err)
		goto nla_failure;

	if (ctx->tx_conf == TLS_HW && ctx->zerocopy_sendfile) {
		err = nla_put_flag(skb, TLS_INFO_ZC_RO_TX);
		if (err)
			goto nla_failure;
	}
	if (ctx->rx_no_pad) {
		err = nla_put_flag(skb, TLS_INFO_RX_NO_PAD);
		if (err)
			goto nla_failure;
	}

	rcu_read_unlock();
	nla_nest_end(skb, start);
	return 0;

nla_failure:
	rcu_read_unlock();
	nla_nest_cancel(skb, start);
	return err;
}

static size_t tls_get_info_size(const struct sock *sk)
{
	size_t size = 0;

	size += nla_total_size(0) +		/* INET_ULP_INFO_TLS */
		nla_total_size(sizeof(u16)) +	/* TLS_INFO_VERSION */
		nla_total_size(sizeof(u16)) +	/* TLS_INFO_CIPHER */
		nla_total_size(sizeof(u16)) +	/* TLS_INFO_RXCONF */
		nla_total_size(sizeof(u16)) +	/* TLS_INFO_TXCONF */
		nla_total_size(0) +		/* TLS_INFO_ZC_RO_TX */
		nla_total_size(0) +		/* TLS_INFO_RX_NO_PAD */
		0;

	return size;
}

static int __net_init tls_init_net(struct net *net)
{
	int err;

	net->mib.tls_statistics = alloc_percpu(struct linux_tls_mib);
	if (!net->mib.tls_statistics)
		return -ENOMEM;

	err = tls_proc_init(net);
	if (err)
		goto err_free_stats;

	return 0;
err_free_stats:
	free_percpu(net->mib.tls_statistics);
	return err;
}

static void __net_exit tls_exit_net(struct net *net)
{
	tls_proc_fini(net);
	free_percpu(net->mib.tls_statistics);
}

static struct pernet_operations tls_proc_ops = {
	.init = tls_init_net,
	.exit = tls_exit_net,
};

static struct tcp_ulp_ops tcp_tls_ulp_ops __read_mostly = {
	.name			= "tls",
	.owner			= THIS_MODULE,
	.init			= tls_init,
	.update			= tls_update,
	.get_info		= tls_get_info,
	.get_info_size		= tls_get_info_size,
};

static int __init tls_register(void)
{
	int err;

	err = register_pernet_subsys(&tls_proc_ops);
	if (err)
		return err;

	err = tls_strp_dev_init();
	if (err)
		goto err_pernet;

	err = tls_device_init();
	if (err)
		goto err_strp;

	tcp_register_ulp(&tcp_tls_ulp_ops);

	return 0;
err_strp:
	tls_strp_dev_exit();
err_pernet:
	unregister_pernet_subsys(&tls_proc_ops);
	return err;
}

static void __exit tls_unregister(void)
{
	tcp_unregister_ulp(&tcp_tls_ulp_ops);
	tls_strp_dev_exit();
	tls_device_cleanup();
	unregister_pernet_subsys(&tls_proc_ops);
}

module_init(tls_register);
module_exit(tls_unregister);