/*
 * Copyright (c) 2016-2017, Mellanox Technologies. All rights reserved.
 * Copyright (c) 2016-2017, Dave Watson <davejwatson@fb.com>. All rights reserved.
 *
 * This software is available to you under a choice of one of two
 * licenses.  You may choose to be licensed under the terms of the GNU
 * General Public License (GPL) Version 2, available from the file
 * COPYING in the main directory of this source tree, or the
 * OpenIB.org BSD license below:
 *
 *     Redistribution and use in source and binary forms, with or
 *     without modification, are permitted provided that the following
 *     conditions are met:
 *
 *      - Redistributions of source code must retain the above
 *        copyright notice, this list of conditions and the following
 *        disclaimer.
 *
 *      - Redistributions in binary form must reproduce the above
 *        copyright notice, this list of conditions and the following
 *        disclaimer in the documentation and/or other materials
 *        provided with the distribution.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
 * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
 * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
 * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
 * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
 * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */

#include <linux/module.h>

#include <net/tcp.h>
#include <net/inet_common.h>
#include <linux/highmem.h>
#include <linux/netdevice.h>
#include <linux/sched/signal.h>
#include <linux/inetdevice.h>
#include <linux/inet_diag.h>

#include <net/snmp.h>
#include <net/tls.h>
#include <net/tls_toe.h>

#include "tls.h"

MODULE_AUTHOR("Mellanox Technologies");
MODULE_DESCRIPTION("Transport Layer Security Support");
MODULE_LICENSE("Dual BSD/GPL");
MODULE_ALIAS_TCP_ULP("tls");

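/* Usage sketch (not part of this file; a hedged example that assumes the
 * key material comes from a TLS handshake completed in userspace). The
 * MODULE_ALIAS_TCP_ULP() above lets the first setsockopt() below autoload
 * this module:
 *
 *	setsockopt(fd, SOL_TCP, TCP_ULP, "tls", sizeof("tls"));
 *
 *	struct tls12_crypto_info_aes_gcm_128 ci = {
 *		.info.version = TLS_1_2_VERSION,
 *		.info.cipher_type = TLS_CIPHER_AES_GCM_128,
 *	};
 *	// fill ci.key, ci.iv, ci.salt and ci.rec_seq from the handshake
 *	setsockopt(fd, SOL_TLS, TLS_TX, &ci, sizeof(ci));
 *	setsockopt(fd, SOL_TLS, TLS_RX, &ci, sizeof(ci));
 */
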
enum {
	TLSV4,
	TLSV6,
	TLS_NUM_PROTS,
};

#define CHECK_CIPHER_DESC(cipher,ci)				\
	static_assert(cipher ## _IV_SIZE <= TLS_MAX_IV_SIZE);		\
	static_assert(cipher ## _SALT_SIZE <= TLS_MAX_SALT_SIZE);		\
	static_assert(cipher ## _REC_SEQ_SIZE <= TLS_MAX_REC_SEQ_SIZE);	\
	static_assert(cipher ## _TAG_SIZE == TLS_TAG_SIZE);		\
	static_assert(sizeof_field(struct ci, iv) == cipher ## _IV_SIZE);	\
	static_assert(sizeof_field(struct ci, key) == cipher ## _KEY_SIZE);	\
	static_assert(sizeof_field(struct ci, salt) == cipher ## _SALT_SIZE);	\
	static_assert(sizeof_field(struct ci, rec_seq) == cipher ## _REC_SEQ_SIZE);

#define __CIPHER_DESC(ci) \
	.iv_offset = offsetof(struct ci, iv), \
	.key_offset = offsetof(struct ci, key), \
	.salt_offset = offsetof(struct ci, salt), \
	.rec_seq_offset = offsetof(struct ci, rec_seq), \
	.crypto_info = sizeof(struct ci)

#define CIPHER_DESC(cipher,ci,algname,_offloadable) [cipher - TLS_CIPHER_MIN] = {	\
	.nonce = cipher ## _IV_SIZE, \
	.iv = cipher ## _IV_SIZE, \
	.key = cipher ## _KEY_SIZE, \
	.salt = cipher ## _SALT_SIZE, \
	.tag = cipher ## _TAG_SIZE, \
	.rec_seq = cipher ## _REC_SEQ_SIZE, \
	.cipher_name = algname,	\
	.offloadable = _offloadable, \
	__CIPHER_DESC(ci), \
}

#define CIPHER_DESC_NONCE0(cipher,ci,algname,_offloadable) [cipher - TLS_CIPHER_MIN] = { \
	.nonce = 0, \
	.iv = cipher ## _IV_SIZE, \
	.key = cipher ## _KEY_SIZE, \
	.salt = cipher ## _SALT_SIZE, \
	.tag = cipher ## _TAG_SIZE, \
	.rec_seq = cipher ## _REC_SEQ_SIZE, \
	.cipher_name = algname,	\
	.offloadable = _offloadable, \
	__CIPHER_DESC(ci), \
}

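/* One descriptor per supported record-protection algorithm, indexed by
 * (cipher_type - TLS_CIPHER_MIN). "offloadable" marks the ciphers that the
 * TLS_HW device-offload paths accept; the others are software-only.
 */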
const struct tls_cipher_desc tls_cipher_desc[TLS_CIPHER_MAX + 1 - TLS_CIPHER_MIN] = {
	CIPHER_DESC(TLS_CIPHER_AES_GCM_128, tls12_crypto_info_aes_gcm_128, "gcm(aes)", true),
	CIPHER_DESC(TLS_CIPHER_AES_GCM_256, tls12_crypto_info_aes_gcm_256, "gcm(aes)", true),
	CIPHER_DESC(TLS_CIPHER_AES_CCM_128, tls12_crypto_info_aes_ccm_128, "ccm(aes)", false),
	CIPHER_DESC_NONCE0(TLS_CIPHER_CHACHA20_POLY1305, tls12_crypto_info_chacha20_poly1305, "rfc7539(chacha20,poly1305)", false),
	CIPHER_DESC(TLS_CIPHER_SM4_GCM, tls12_crypto_info_sm4_gcm, "gcm(sm4)", false),
	CIPHER_DESC(TLS_CIPHER_SM4_CCM, tls12_crypto_info_sm4_ccm, "ccm(sm4)", false),
	CIPHER_DESC(TLS_CIPHER_ARIA_GCM_128, tls12_crypto_info_aria_gcm_128, "gcm(aria)", false),
	CIPHER_DESC(TLS_CIPHER_ARIA_GCM_256, tls12_crypto_info_aria_gcm_256, "gcm(aria)", false),
};

CHECK_CIPHER_DESC(TLS_CIPHER_AES_GCM_128, tls12_crypto_info_aes_gcm_128);
CHECK_CIPHER_DESC(TLS_CIPHER_AES_GCM_256, tls12_crypto_info_aes_gcm_256);
CHECK_CIPHER_DESC(TLS_CIPHER_AES_CCM_128, tls12_crypto_info_aes_ccm_128);
CHECK_CIPHER_DESC(TLS_CIPHER_CHACHA20_POLY1305, tls12_crypto_info_chacha20_poly1305);
CHECK_CIPHER_DESC(TLS_CIPHER_SM4_GCM, tls12_crypto_info_sm4_gcm);
CHECK_CIPHER_DESC(TLS_CIPHER_SM4_CCM, tls12_crypto_info_sm4_ccm);
CHECK_CIPHER_DESC(TLS_CIPHER_ARIA_GCM_128, tls12_crypto_info_aria_gcm_128);
CHECK_CIPHER_DESC(TLS_CIPHER_ARIA_GCM_256, tls12_crypto_info_aria_gcm_256);

static const struct proto *saved_tcpv6_prot;
static DEFINE_MUTEX(tcpv6_prot_mutex);
static const struct proto *saved_tcpv4_prot;
static DEFINE_MUTEX(tcpv4_prot_mutex);
static struct proto tls_prots[TLS_NUM_PROTS][TLS_NUM_CONFIG][TLS_NUM_CONFIG];
static struct proto_ops tls_proto_ops[TLS_NUM_PROTS][TLS_NUM_CONFIG][TLS_NUM_CONFIG];
static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG],
			 const struct proto *base);

void update_sk_prot(struct sock *sk, struct tls_context *ctx)
{
	int ip_ver = sk->sk_family == AF_INET6 ? TLSV6 : TLSV4;

	WRITE_ONCE(sk->sk_prot,
		   &tls_prots[ip_ver][ctx->tx_conf][ctx->rx_conf]);
	WRITE_ONCE(sk->sk_socket->ops,
		   &tls_proto_ops[ip_ver][ctx->tx_conf][ctx->rx_conf]);
}

int wait_on_pending_writer(struct sock *sk, long *timeo)
{
	DEFINE_WAIT_FUNC(wait, woken_wake_function);
	int ret, rc = 0;

	add_wait_queue(sk_sleep(sk), &wait);
	while (1) {
		if (!*timeo) {
			rc = -EAGAIN;
			break;
		}

		if (signal_pending(current)) {
			rc = sock_intr_errno(*timeo);
			break;
		}

		ret = sk_wait_event(sk, timeo,
				    !READ_ONCE(sk->sk_write_pending), &wait);
		if (ret) {
			if (ret < 0)
				rc = ret;
			break;
		}
	}
	remove_wait_queue(sk_sleep(sk), &wait);
	return rc;
}

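/* Transmit the pages of an already-encrypted record described by @sg,
 * starting @first_offset bytes into the first entry. On a partial send the
 * position is saved in ctx->partially_sent_record/offset so that
 * tls_push_partial_record() can resume the record later.
 */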
int tls_push_sg(struct sock *sk,
		struct tls_context *ctx,
		struct scatterlist *sg,
		u16 first_offset,
		int flags)
{
	struct bio_vec bvec;
	struct msghdr msg = {
		.msg_flags = MSG_SPLICE_PAGES | flags,
	};
	int ret = 0;
	struct page *p;
	size_t size;
	int offset = first_offset;

	size = sg->length - offset;
	offset += sg->offset;

	ctx->splicing_pages = true;
	while (1) {
		/* is sending application-limited? */
		tcp_rate_check_app_limited(sk);
		p = sg_page(sg);
retry:
		bvec_set_page(&bvec, p, size, offset);
		iov_iter_bvec(&msg.msg_iter, ITER_SOURCE, &bvec, 1, size);

		ret = tcp_sendmsg_locked(sk, &msg, size);

		if (ret != size) {
			if (ret > 0) {
				offset += ret;
				size -= ret;
				goto retry;
			}

			offset -= sg->offset;
			ctx->partially_sent_offset = offset;
			ctx->partially_sent_record = (void *)sg;
			ctx->splicing_pages = false;
			return ret;
		}

		put_page(p);
		sk_mem_uncharge(sk, sg->length);
		sg = sg_next(sg);
		if (!sg)
			break;

		offset = sg->offset;
		size = sg->length;
	}

	ctx->splicing_pages = false;

	return 0;
}

static int tls_handle_open_record(struct sock *sk, int flags)
{
	struct tls_context *ctx = tls_get_ctx(sk);

	if (tls_is_pending_open_record(ctx))
		return ctx->push_pending_record(sk, flags);

	return 0;
}

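/* Usage sketch (hedged, not part of this file): userspace selects a
 * non-data record type, e.g. a TLS alert, with a TLS_SET_RECORD_TYPE
 * control message and sends the payload in the same sendmsg() call:
 *
 *	char cbuf[CMSG_SPACE(sizeof(unsigned char))] = {};
 *	struct msghdr msg = { .msg_control = cbuf,
 *			      .msg_controllen = sizeof(cbuf) };
 *	struct cmsghdr *cmsg = CMSG_FIRSTHDR(&msg);
 *
 *	cmsg->cmsg_level = SOL_TLS;
 *	cmsg->cmsg_type = TLS_SET_RECORD_TYPE;
 *	cmsg->cmsg_len = CMSG_LEN(sizeof(unsigned char));
 *	*CMSG_DATA(cmsg) = 21;		// 21 == alert record type
 *	// point msg.msg_iov at the record payload, then sendmsg(fd, &msg, 0)
 */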
int tls_process_cmsg(struct sock *sk, struct msghdr *msg,
		     unsigned char *record_type)
{
	struct cmsghdr *cmsg;
	int rc = -EINVAL;

	for_each_cmsghdr(cmsg, msg) {
		if (!CMSG_OK(msg, cmsg))
			return -EINVAL;
		if (cmsg->cmsg_level != SOL_TLS)
			continue;

		switch (cmsg->cmsg_type) {
		case TLS_SET_RECORD_TYPE:
			if (cmsg->cmsg_len < CMSG_LEN(sizeof(*record_type)))
				return -EINVAL;

			if (msg->msg_flags & MSG_MORE)
				return -EINVAL;

			*record_type = *(unsigned char *)CMSG_DATA(cmsg);

			rc = tls_handle_open_record(sk, msg->msg_flags);
			break;
		default:
			return -EINVAL;
		}
	}

	return rc;
}

int tls_push_partial_record(struct sock *sk, struct tls_context *ctx,
			    int flags)
{
	struct scatterlist *sg;
	u16 offset;

	sg = ctx->partially_sent_record;
	offset = ctx->partially_sent_offset;

	ctx->partially_sent_record = NULL;
	return tls_push_sg(sk, ctx, sg, offset, flags);
}

void tls_free_partial_record(struct sock *sk, struct tls_context *ctx)
{
	struct scatterlist *sg;

	for (sg = ctx->partially_sent_record; sg; sg = sg_next(sg)) {
		put_page(sg_page(sg));
		sk_mem_uncharge(sk, sg->length);
	}
	ctx->partially_sent_record = NULL;
}

static void tls_write_space(struct sock *sk)
{
	struct tls_context *ctx = tls_get_ctx(sk);

	/* If splicing_pages is set, call the lower protocol's write_space
	 * handler to ensure we wake up any operations waiting there; for
	 * example, if splicing pages were to call sk_wait_event.
	 */
	if (ctx->splicing_pages) {
		ctx->sk_write_space(sk);
		return;
	}

#ifdef CONFIG_TLS_DEVICE
	if (ctx->tx_conf == TLS_HW)
		tls_device_write_space(sk, ctx);
	else
#endif
		tls_sw_write_space(sk, ctx);

	ctx->sk_write_space(sk);
}

/**
 * tls_ctx_free() - free TLS ULP context
 * @sk:  socket to which @ctx is attached
 * @ctx: TLS context structure
 *
 * Free TLS context. If @sk is %NULL, the caller guarantees that the socket
 * to which @ctx was attached has no outstanding references.
 */
void tls_ctx_free(struct sock *sk, struct tls_context *ctx)
{
	if (!ctx)
		return;

	memzero_explicit(&ctx->crypto_send, sizeof(ctx->crypto_send));
	memzero_explicit(&ctx->crypto_recv, sizeof(ctx->crypto_recv));
	mutex_destroy(&ctx->tx_lock);

	if (sk)
		kfree_rcu(ctx, rcu);
	else
		kfree(ctx);
}

static void tls_sk_proto_cleanup(struct sock *sk,
				 struct tls_context *ctx, long timeo)
{
	if (unlikely(sk->sk_write_pending) &&
	    !wait_on_pending_writer(sk, &timeo))
		tls_handle_open_record(sk, 0);

	/* We need these for tls_sw_fallback handling of other packets */
	if (ctx->tx_conf == TLS_SW) {
		tls_sw_release_resources_tx(sk);
		TLS_DEC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXSW);
	} else if (ctx->tx_conf == TLS_HW) {
		tls_device_free_resources_tx(sk);
		TLS_DEC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXDEVICE);
	}

	if (ctx->rx_conf == TLS_SW) {
		tls_sw_release_resources_rx(sk);
		TLS_DEC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXSW);
	} else if (ctx->rx_conf == TLS_HW) {
		tls_device_offload_cleanup_rx(sk);
		TLS_DEC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXDEVICE);
	}
}

static void tls_sk_proto_close(struct sock *sk, long timeout)
{
	struct inet_connection_sock *icsk = inet_csk(sk);
	struct tls_context *ctx = tls_get_ctx(sk);
	long timeo = sock_sndtimeo(sk, 0);
	bool free_ctx;

	if (ctx->tx_conf == TLS_SW)
		tls_sw_cancel_work_tx(ctx);

	lock_sock(sk);
	free_ctx = ctx->tx_conf != TLS_HW && ctx->rx_conf != TLS_HW;

	if (ctx->tx_conf != TLS_BASE || ctx->rx_conf != TLS_BASE)
		tls_sk_proto_cleanup(sk, ctx, timeo);

	write_lock_bh(&sk->sk_callback_lock);
	if (free_ctx)
		rcu_assign_pointer(icsk->icsk_ulp_data, NULL);
	WRITE_ONCE(sk->sk_prot, ctx->sk_proto);
	if (sk->sk_write_space == tls_write_space)
		sk->sk_write_space = ctx->sk_write_space;
	write_unlock_bh(&sk->sk_callback_lock);
	release_sock(sk);
	if (ctx->tx_conf == TLS_SW)
		tls_sw_free_ctx_tx(ctx);
	if (ctx->rx_conf == TLS_SW || ctx->rx_conf == TLS_HW)
		tls_sw_strparser_done(ctx);
	if (ctx->rx_conf == TLS_SW)
		tls_sw_free_ctx_rx(ctx);
	ctx->sk_proto->close(sk, timeout);

	if (free_ctx)
		tls_ctx_free(sk, ctx);
}

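/* Start from tcp_poll(), then mask out EPOLLIN/EPOLLRDNORM unless a fully
 * decrypted record (or psock data) is actually ready to read. A pending key
 * update also blocks reads until the new RX key is installed.
 */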
static __poll_t tls_sk_poll(struct file *file, struct socket *sock,
			    struct poll_table_struct *wait)
{
	struct tls_sw_context_rx *ctx;
	struct tls_context *tls_ctx;
	struct sock *sk = sock->sk;
	struct sk_psock *psock;
	__poll_t mask = 0;
	u8 shutdown;
	int state;

	mask = tcp_poll(file, sock, wait);

	state = inet_sk_state_load(sk);
	shutdown = READ_ONCE(sk->sk_shutdown);
	if (unlikely(state != TCP_ESTABLISHED || shutdown & RCV_SHUTDOWN))
		return mask;

	tls_ctx = tls_get_ctx(sk);
	ctx = tls_sw_ctx_rx(tls_ctx);
	psock = sk_psock_get(sk);

	if ((skb_queue_empty_lockless(&ctx->rx_list) &&
	     !tls_strp_msg_ready(ctx) &&
	     sk_psock_queue_empty(psock)) ||
	    READ_ONCE(ctx->key_update_pending))
		mask &= ~(EPOLLIN | EPOLLRDNORM);

	if (psock)
		sk_psock_put(sk, psock);

	return mask;
}

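/* Usage sketch (hedged): after keys are installed, userspace can read the
 * current crypto state back, including the advancing IV and record sequence:
 *
 *	struct tls12_crypto_info_aes_gcm_128 ci;
 *	socklen_t len = sizeof(ci);
 *
 *	getsockopt(fd, SOL_TLS, TLS_TX, &ci, &len);
 */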
static int do_tls_getsockopt_conf(struct sock *sk, char __user *optval,
				  int __user *optlen, int tx)
{
	int rc = 0;
	const struct tls_cipher_desc *cipher_desc;
	struct tls_context *ctx = tls_get_ctx(sk);
	struct tls_crypto_info *crypto_info;
	struct cipher_context *cctx;
	int len;

	if (get_user(len, optlen))
		return -EFAULT;

	if (!optval || (len < sizeof(*crypto_info))) {
		rc = -EINVAL;
		goto out;
	}

	if (!ctx) {
		rc = -EBUSY;
		goto out;
	}

	/* get user crypto info */
	if (tx) {
		crypto_info = &ctx->crypto_send.info;
		cctx = &ctx->tx;
	} else {
		crypto_info = &ctx->crypto_recv.info;
		cctx = &ctx->rx;
	}

	if (!TLS_CRYPTO_INFO_READY(crypto_info)) {
		rc = -EBUSY;
		goto out;
	}

	if (len == sizeof(*crypto_info)) {
		if (copy_to_user(optval, crypto_info, sizeof(*crypto_info)))
			rc = -EFAULT;
		goto out;
	}

	cipher_desc = get_cipher_desc(crypto_info->cipher_type);
	if (!cipher_desc || len != cipher_desc->crypto_info) {
		rc = -EINVAL;
		goto out;
	}

	memcpy(crypto_info_iv(crypto_info, cipher_desc),
	       cctx->iv + cipher_desc->salt, cipher_desc->iv);
	memcpy(crypto_info_rec_seq(crypto_info, cipher_desc),
	       cctx->rec_seq, cipher_desc->rec_seq);

	if (copy_to_user(optval, crypto_info, cipher_desc->crypto_info))
		rc = -EFAULT;

out:
	return rc;
}

static int do_tls_getsockopt_tx_zc(struct sock *sk, char __user *optval,
				   int __user *optlen)
{
	struct tls_context *ctx = tls_get_ctx(sk);
	unsigned int value;
	int len;

	if (get_user(len, optlen))
		return -EFAULT;

	if (len != sizeof(value))
		return -EINVAL;

	value = ctx->zerocopy_sendfile;
	if (copy_to_user(optval, &value, sizeof(value)))
		return -EFAULT;

	return 0;
}

static int do_tls_getsockopt_no_pad(struct sock *sk, char __user *optval,
				    int __user *optlen)
{
	struct tls_context *ctx = tls_get_ctx(sk);
	int value, len;

	if (ctx->prot_info.version != TLS_1_3_VERSION)
		return -EINVAL;

	if (get_user(len, optlen))
		return -EFAULT;
	if (len < sizeof(value))
		return -EINVAL;

	value = -EINVAL;
	if (ctx->rx_conf == TLS_SW || ctx->rx_conf == TLS_HW)
		value = ctx->rx_no_pad;
	if (value < 0)
		return value;

	if (put_user(sizeof(value), optlen))
		return -EFAULT;
	if (copy_to_user(optval, &value, sizeof(value)))
		return -EFAULT;

	return 0;
}

static int do_tls_getsockopt_tx_payload_len(struct sock *sk, char __user *optval,
					    int __user *optlen)
{
	struct tls_context *ctx = tls_get_ctx(sk);
	u16 payload_len = ctx->tx_max_payload_len;
	int len;

	if (get_user(len, optlen))
		return -EFAULT;

	if (len < sizeof(payload_len))
		return -EINVAL;

	if (put_user(sizeof(payload_len), optlen))
		return -EFAULT;

	if (copy_to_user(optval, &payload_len, sizeof(payload_len)))
		return -EFAULT;

	return 0;
}

static int do_tls_getsockopt(struct sock *sk, int optname,
			     char __user *optval, int __user *optlen)
{
	int rc = 0;

	lock_sock(sk);

	switch (optname) {
	case TLS_TX:
	case TLS_RX:
		rc = do_tls_getsockopt_conf(sk, optval, optlen,
					    optname == TLS_TX);
		break;
	case TLS_TX_ZEROCOPY_RO:
		rc = do_tls_getsockopt_tx_zc(sk, optval, optlen);
		break;
	case TLS_RX_EXPECT_NO_PAD:
		rc = do_tls_getsockopt_no_pad(sk, optval, optlen);
		break;
	case TLS_TX_MAX_PAYLOAD_LEN:
		rc = do_tls_getsockopt_tx_payload_len(sk, optval, optlen);
		break;
	default:
		rc = -ENOPROTOOPT;
		break;
	}

	release_sock(sk);

	return rc;
}

static int tls_getsockopt(struct sock *sk, int level, int optname,
			  char __user *optval, int __user *optlen)
{
	struct tls_context *ctx = tls_get_ctx(sk);

	if (level != SOL_TLS)
		return ctx->sk_proto->getsockopt(sk, level,
						 optname, optval, optlen);

	return do_tls_getsockopt(sk, optname, optval, optlen);
}

static int validate_crypto_info(const struct tls_crypto_info *crypto_info,
				const struct tls_crypto_info *alt_crypto_info)
{
	if (crypto_info->version != TLS_1_2_VERSION &&
	    crypto_info->version != TLS_1_3_VERSION)
		return -EINVAL;

	switch (crypto_info->cipher_type) {
	case TLS_CIPHER_ARIA_GCM_128:
	case TLS_CIPHER_ARIA_GCM_256:
		if (crypto_info->version != TLS_1_2_VERSION)
			return -EINVAL;
		break;
	}

	/* Ensure that the TLS version and ciphers are the same in both directions */
	if (TLS_CRYPTO_INFO_READY(alt_crypto_info)) {
		if (alt_crypto_info->version != crypto_info->version ||
		    alt_crypto_info->cipher_type != crypto_info->cipher_type)
			return -EINVAL;
	}

	return 0;
}

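/* Usage sketch (hedged): for TLS 1.3, calling setsockopt() again on an
 * already-configured direction rekeys it, e.g. after a KeyUpdate exchange.
 * The version and cipher must not change; only the key material does:
 *
 *	// new_ci holds the freshly derived key/iv/salt/rec_seq
 *	setsockopt(fd, SOL_TLS, TLS_TX, &new_ci, sizeof(new_ci));
 */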
static int do_tls_setsockopt_conf(struct sock *sk, sockptr_t optval,
				  unsigned int optlen, int tx)
{
	struct tls_crypto_info *crypto_info, *alt_crypto_info;
	struct tls_crypto_info *old_crypto_info = NULL;
	struct tls_context *ctx = tls_get_ctx(sk);
	const struct tls_cipher_desc *cipher_desc;
	union tls_crypto_context *crypto_ctx;
	union tls_crypto_context tmp = {};
	bool update = false;
	int rc = 0;
	int conf;

	if (sockptr_is_null(optval) || (optlen < sizeof(*crypto_info)))
		return -EINVAL;

	if (tx) {
		crypto_ctx = &ctx->crypto_send;
		alt_crypto_info = &ctx->crypto_recv.info;
	} else {
		crypto_ctx = &ctx->crypto_recv;
		alt_crypto_info = &ctx->crypto_send.info;
	}

	crypto_info = &crypto_ctx->info;

	if (TLS_CRYPTO_INFO_READY(crypto_info)) {
		/* Setting the crypto info more than once (rekeying) is
		 * currently only supported for TLS 1.3.
		 */
		if (crypto_info->version != TLS_1_3_VERSION) {
			TLS_INC_STATS(sock_net(sk), tx ? LINUX_MIB_TLSTXREKEYERROR
						       : LINUX_MIB_TLSRXREKEYERROR);
			return -EBUSY;
		}

		update = true;
		old_crypto_info = crypto_info;
		crypto_info = &tmp.info;
		crypto_ctx = &tmp;
	}

	rc = copy_from_sockptr(crypto_info, optval, sizeof(*crypto_info));
	if (rc) {
		rc = -EFAULT;
		goto err_crypto_info;
	}

	if (update) {
		/* Ensure that TLS version and ciphers are not modified */
		if (crypto_info->version != old_crypto_info->version ||
		    crypto_info->cipher_type != old_crypto_info->cipher_type)
			rc = -EINVAL;
	} else {
		rc = validate_crypto_info(crypto_info, alt_crypto_info);
	}
	if (rc)
		goto err_crypto_info;

	cipher_desc = get_cipher_desc(crypto_info->cipher_type);
	if (!cipher_desc) {
		rc = -EINVAL;
		goto err_crypto_info;
	}

	if (optlen != cipher_desc->crypto_info) {
		rc = -EINVAL;
		goto err_crypto_info;
	}

	rc = copy_from_sockptr_offset(crypto_info + 1, optval,
				      sizeof(*crypto_info),
				      optlen - sizeof(*crypto_info));
	if (rc) {
		rc = -EFAULT;
		goto err_crypto_info;
	}

	if (tx) {
		rc = tls_set_device_offload(sk);
		conf = TLS_HW;
		if (!rc) {
			TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSTXDEVICE);
			TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXDEVICE);
		} else {
			rc = tls_set_sw_offload(sk, 1,
						update ? crypto_info : NULL);
			if (rc)
				goto err_crypto_info;

			if (update) {
				TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSTXREKEYOK);
			} else {
				TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSTXSW);
				TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXSW);
			}
			conf = TLS_SW;
		}
	} else {
		rc = tls_set_device_offload_rx(sk, ctx);
		conf = TLS_HW;
		if (!rc) {
			TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXDEVICE);
			TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXDEVICE);
		} else {
			rc = tls_set_sw_offload(sk, 0,
						update ? crypto_info : NULL);
			if (rc)
				goto err_crypto_info;

			if (update) {
				TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXREKEYOK);
			} else {
				TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXSW);
				TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXSW);
			}
			conf = TLS_SW;
		}
		if (!update)
			tls_sw_strparser_arm(sk, ctx);
	}

	if (tx)
		ctx->tx_conf = conf;
	else
		ctx->rx_conf = conf;
	update_sk_prot(sk, ctx);

	if (update)
		return 0;

	if (tx) {
		ctx->sk_write_space = sk->sk_write_space;
		sk->sk_write_space = tls_write_space;
	} else {
		struct tls_sw_context_rx *rx_ctx = tls_sw_ctx_rx(ctx);

		tls_strp_check_rcv(&rx_ctx->strp);
	}
	return 0;

err_crypto_info:
	if (update) {
		TLS_INC_STATS(sock_net(sk), tx ? LINUX_MIB_TLSTXREKEYERROR
					       : LINUX_MIB_TLSRXREKEYERROR);
	}
	memzero_explicit(crypto_ctx, sizeof(*crypto_ctx));
	return rc;
}

static int do_tls_setsockopt_tx_zc(struct sock *sk, sockptr_t optval,
				   unsigned int optlen)
{
	struct tls_context *ctx = tls_get_ctx(sk);
	unsigned int value;

	if (sockptr_is_null(optval) || optlen != sizeof(value))
		return -EINVAL;

	if (copy_from_sockptr(&value, optval, sizeof(value)))
		return -EFAULT;

	if (value > 1)
		return -EINVAL;

	ctx->zerocopy_sendfile = value;

	return 0;
}

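/* Usage sketch (hedged): a TLS 1.3 receiver that trusts its peer not to pad
 * records can opt into faster no-padding processing:
 *
 *	int val = 1;
 *
 *	setsockopt(fd, SOL_TLS, TLS_RX_EXPECT_NO_PAD, &val, sizeof(val));
 */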
static int do_tls_setsockopt_no_pad(struct sock *sk, sockptr_t optval,
				    unsigned int optlen)
{
	struct tls_context *ctx = tls_get_ctx(sk);
	u32 val;
	int rc;

	if (ctx->prot_info.version != TLS_1_3_VERSION ||
	    sockptr_is_null(optval) || optlen < sizeof(val))
		return -EINVAL;

	rc = copy_from_sockptr(&val, optval, sizeof(val));
	if (rc)
		return -EFAULT;
	if (val > 1)
		return -EINVAL;
	rc = check_zeroed_sockptr(optval, sizeof(val), optlen - sizeof(val));
	if (rc < 1)
		return rc == 0 ? -EINVAL : rc;

	lock_sock(sk);
	rc = -EINVAL;
	if (ctx->rx_conf == TLS_SW || ctx->rx_conf == TLS_HW) {
		ctx->rx_no_pad = val;
		tls_update_rx_zc_capable(ctx);
		rc = 0;
	}
	release_sock(sk);

	return rc;
}

static int do_tls_setsockopt_tx_payload_len(struct sock *sk, sockptr_t optval,
					    unsigned int optlen)
{
	struct tls_context *ctx = tls_get_ctx(sk);
	struct tls_sw_context_tx *sw_ctx = tls_sw_ctx_tx(ctx);
	u16 value;
	bool tls_13 = ctx->prot_info.version == TLS_1_3_VERSION;

	if (sw_ctx && sw_ctx->open_rec)
		return -EBUSY;

	if (sockptr_is_null(optval) || optlen != sizeof(value))
		return -EINVAL;

	if (copy_from_sockptr(&value, optval, sizeof(value)))
		return -EFAULT;

	if (value < TLS_MIN_RECORD_SIZE_LIM - (tls_13 ? 1 : 0) ||
	    value > TLS_MAX_PAYLOAD_SIZE)
		return -EINVAL;

	ctx->tx_max_payload_len = value;

	return 0;
}

static int do_tls_setsockopt(struct sock *sk, int optname, sockptr_t optval,
			     unsigned int optlen)
{
	int rc = 0;

	switch (optname) {
	case TLS_TX:
	case TLS_RX:
		lock_sock(sk);
		rc = do_tls_setsockopt_conf(sk, optval, optlen,
					    optname == TLS_TX);
		release_sock(sk);
		break;
	case TLS_TX_ZEROCOPY_RO:
		lock_sock(sk);
		rc = do_tls_setsockopt_tx_zc(sk, optval, optlen);
		release_sock(sk);
		break;
	case TLS_RX_EXPECT_NO_PAD:
		rc = do_tls_setsockopt_no_pad(sk, optval, optlen);
		break;
	case TLS_TX_MAX_PAYLOAD_LEN:
		lock_sock(sk);
		rc = do_tls_setsockopt_tx_payload_len(sk, optval, optlen);
		release_sock(sk);
		break;
	default:
		rc = -ENOPROTOOPT;
		break;
	}
	return rc;
}

static int tls_setsockopt(struct sock *sk, int level, int optname,
			  sockptr_t optval, unsigned int optlen)
{
	struct tls_context *ctx = tls_get_ctx(sk);

	if (level != SOL_TLS)
		return ctx->sk_proto->setsockopt(sk, level, optname, optval,
						 optlen);

	return do_tls_setsockopt(sk, optname, optval, optlen);
}

static int tls_disconnect(struct sock *sk, int flags)
{
	return -EOPNOTSUPP;
}

struct tls_context *tls_ctx_create(struct sock *sk)
{
	struct inet_connection_sock *icsk = inet_csk(sk);
	struct tls_context *ctx;

	ctx = kzalloc(sizeof(*ctx), GFP_ATOMIC);
	if (!ctx)
		return NULL;

	mutex_init(&ctx->tx_lock);
	ctx->sk_proto = READ_ONCE(sk->sk_prot);
	ctx->sk = sk;
	/* Release semantic of rcu_assign_pointer() ensures that
	 * ctx->sk_proto is visible before changing sk->sk_prot in
	 * update_sk_prot(), and prevents reading uninitialized value in
	 * tls_{getsockopt, setsockopt}. Note that we do not need a
	 * read barrier in tls_{getsockopt,setsockopt} as there is an
	 * address dependency between sk->sk_proto->{getsockopt,setsockopt}
	 * and ctx->sk_proto.
	 */
	rcu_assign_pointer(icsk->icsk_ulp_data, ctx);
	return ctx;
}

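/* The tables built below are indexed [TX config][RX config], so every
 * TLS_BASE/TLS_SW/TLS_HW combination gets a copy of the base TCP callbacks
 * with only the operations that the configuration affects overridden.
 */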
static void build_proto_ops(struct proto_ops ops[TLS_NUM_CONFIG][TLS_NUM_CONFIG],
			    const struct proto_ops *base)
{
	ops[TLS_BASE][TLS_BASE] = *base;

	ops[TLS_SW  ][TLS_BASE] = ops[TLS_BASE][TLS_BASE];
	ops[TLS_SW  ][TLS_BASE].splice_eof	= tls_sw_splice_eof;

	ops[TLS_BASE][TLS_SW  ] = ops[TLS_BASE][TLS_BASE];
	ops[TLS_BASE][TLS_SW  ].splice_read	= tls_sw_splice_read;
	ops[TLS_BASE][TLS_SW  ].poll		= tls_sk_poll;
	ops[TLS_BASE][TLS_SW  ].read_sock	= tls_sw_read_sock;

	ops[TLS_SW  ][TLS_SW  ] = ops[TLS_SW  ][TLS_BASE];
	ops[TLS_SW  ][TLS_SW  ].splice_read	= tls_sw_splice_read;
	ops[TLS_SW  ][TLS_SW  ].poll		= tls_sk_poll;
	ops[TLS_SW  ][TLS_SW  ].read_sock	= tls_sw_read_sock;

#ifdef CONFIG_TLS_DEVICE
	ops[TLS_HW  ][TLS_BASE] = ops[TLS_BASE][TLS_BASE];

	ops[TLS_HW  ][TLS_SW  ] = ops[TLS_BASE][TLS_SW  ];

	ops[TLS_BASE][TLS_HW  ] = ops[TLS_BASE][TLS_SW  ];

	ops[TLS_SW  ][TLS_HW  ] = ops[TLS_SW  ][TLS_SW  ];

	ops[TLS_HW  ][TLS_HW  ] = ops[TLS_HW  ][TLS_SW  ];
#endif
#ifdef CONFIG_TLS_TOE
	ops[TLS_HW_RECORD][TLS_HW_RECORD] = *base;
#endif
}

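/* Build the per-family proto tables lazily, on first use, and rebuild them
 * whenever the address of the underlying tcp(v4|v6) proto changes.
 */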
static void tls_build_proto(struct sock *sk)
{
	int ip_ver = sk->sk_family == AF_INET6 ? TLSV6 : TLSV4;
	struct proto *prot = READ_ONCE(sk->sk_prot);

	/* Build IPv6 TLS whenever the address of tcpv6_prot changes */
	if (ip_ver == TLSV6 &&
	    unlikely(prot != smp_load_acquire(&saved_tcpv6_prot))) {
		mutex_lock(&tcpv6_prot_mutex);
		if (likely(prot != saved_tcpv6_prot)) {
			build_protos(tls_prots[TLSV6], prot);
			build_proto_ops(tls_proto_ops[TLSV6],
					sk->sk_socket->ops);
			smp_store_release(&saved_tcpv6_prot, prot);
		}
		mutex_unlock(&tcpv6_prot_mutex);
	}

	if (ip_ver == TLSV4 &&
	    unlikely(prot != smp_load_acquire(&saved_tcpv4_prot))) {
		mutex_lock(&tcpv4_prot_mutex);
		if (likely(prot != saved_tcpv4_prot)) {
			build_protos(tls_prots[TLSV4], prot);
			build_proto_ops(tls_proto_ops[TLSV4],
					sk->sk_socket->ops);
			smp_store_release(&saved_tcpv4_prot, prot);
		}
		mutex_unlock(&tcpv4_prot_mutex);
	}
}

static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG],
			 const struct proto *base)
{
	prot[TLS_BASE][TLS_BASE] = *base;
	prot[TLS_BASE][TLS_BASE].setsockopt	= tls_setsockopt;
	prot[TLS_BASE][TLS_BASE].getsockopt	= tls_getsockopt;
	prot[TLS_BASE][TLS_BASE].disconnect	= tls_disconnect;
	prot[TLS_BASE][TLS_BASE].close		= tls_sk_proto_close;

	prot[TLS_SW][TLS_BASE] = prot[TLS_BASE][TLS_BASE];
	prot[TLS_SW][TLS_BASE].sendmsg		= tls_sw_sendmsg;
	prot[TLS_SW][TLS_BASE].splice_eof	= tls_sw_splice_eof;

	prot[TLS_BASE][TLS_SW] = prot[TLS_BASE][TLS_BASE];
	prot[TLS_BASE][TLS_SW].recvmsg		  = tls_sw_recvmsg;
	prot[TLS_BASE][TLS_SW].sock_is_readable   = tls_sw_sock_is_readable;
	prot[TLS_BASE][TLS_SW].close		  = tls_sk_proto_close;

	prot[TLS_SW][TLS_SW] = prot[TLS_SW][TLS_BASE];
	prot[TLS_SW][TLS_SW].recvmsg		= tls_sw_recvmsg;
	prot[TLS_SW][TLS_SW].sock_is_readable   = tls_sw_sock_is_readable;
	prot[TLS_SW][TLS_SW].close		= tls_sk_proto_close;

#ifdef CONFIG_TLS_DEVICE
	prot[TLS_HW][TLS_BASE] = prot[TLS_BASE][TLS_BASE];
	prot[TLS_HW][TLS_BASE].sendmsg		= tls_device_sendmsg;
	prot[TLS_HW][TLS_BASE].splice_eof	= tls_device_splice_eof;

	prot[TLS_HW][TLS_SW] = prot[TLS_BASE][TLS_SW];
	prot[TLS_HW][TLS_SW].sendmsg		= tls_device_sendmsg;
	prot[TLS_HW][TLS_SW].splice_eof		= tls_device_splice_eof;

	prot[TLS_BASE][TLS_HW] = prot[TLS_BASE][TLS_SW];

	prot[TLS_SW][TLS_HW] = prot[TLS_SW][TLS_SW];

	prot[TLS_HW][TLS_HW] = prot[TLS_HW][TLS_SW];
#endif
#ifdef CONFIG_TLS_TOE
	prot[TLS_HW_RECORD][TLS_HW_RECORD] = *base;
	prot[TLS_HW_RECORD][TLS_HW_RECORD].hash		= tls_toe_hash;
	prot[TLS_HW_RECORD][TLS_HW_RECORD].unhash	= tls_toe_unhash;
#endif
}

static int tls_init(struct sock *sk)
{
	struct tls_context *ctx;
	int rc = 0;

	tls_build_proto(sk);

#ifdef CONFIG_TLS_TOE
	if (tls_toe_bypass(sk))
		return 0;
#endif

	/* The TLS ulp is currently supported only for TCP sockets
	 * in ESTABLISHED state.
	 * Supporting sockets in LISTEN state will require us
	 * to modify the accept implementation to clone rather than
	 * share the ulp context.
	 */
	if (sk->sk_state != TCP_ESTABLISHED)
		return -ENOTCONN;

	/* allocate tls context */
	write_lock_bh(&sk->sk_callback_lock);
	ctx = tls_ctx_create(sk);
	if (!ctx) {
		rc = -ENOMEM;
		goto out;
	}

	ctx->tx_conf = TLS_BASE;
	ctx->rx_conf = TLS_BASE;
	ctx->tx_max_payload_len = TLS_MAX_PAYLOAD_SIZE;
	update_sk_prot(sk, ctx);
out:
	write_unlock_bh(&sk->sk_callback_lock);
	return rc;
}

static void tls_update(struct sock *sk, struct proto *p,
		       void (*write_space)(struct sock *sk))
{
	struct tls_context *ctx;

	WARN_ON_ONCE(sk->sk_prot == p);

	ctx = tls_get_ctx(sk);
	if (likely(ctx)) {
		ctx->sk_write_space = write_space;
		ctx->sk_proto = p;
	} else {
		/* Pairs with lockless read in sk_clone_lock(). */
		WRITE_ONCE(sk->sk_prot, p);
		sk->sk_write_space = write_space;
	}
}

static u16 tls_user_config(struct tls_context *ctx, bool tx)
{
	u16 config = tx ? ctx->tx_conf : ctx->rx_conf;

	switch (config) {
	case TLS_BASE:
		return TLS_CONF_BASE;
	case TLS_SW:
		return TLS_CONF_SW;
	case TLS_HW:
		return TLS_CONF_HW;
	case TLS_HW_RECORD:
		return TLS_CONF_HW_RECORD;
	}
	return 0;
}

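/* Fill the INET_ULP_INFO_TLS nest consumed by inet_diag, so socket
 * diagnostics (e.g. the ss utility) can report the TLS version, cipher and
 * TX/RX configuration of the socket.
 */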
1121 {
1122 	u16 version, cipher_type;
1123 	struct tls_context *ctx;
1124 	struct nlattr *start;
1125 	int err;
1126 
1127 	start = nla_nest_start_noflag(skb, INET_ULP_INFO_TLS);
1128 	if (!start)
1129 		return -EMSGSIZE;
1130 
1131 	rcu_read_lock();
1132 	ctx = rcu_dereference(inet_csk(sk)->icsk_ulp_data);
1133 	if (!ctx) {
1134 		err = 0;
1135 		goto nla_failure;
1136 	}
1137 	version = ctx->prot_info.version;
1138 	if (version) {
1139 		err = nla_put_u16(skb, TLS_INFO_VERSION, version);
1140 		if (err)
1141 			goto nla_failure;
1142 	}
1143 	cipher_type = ctx->prot_info.cipher_type;
1144 	if (cipher_type) {
1145 		err = nla_put_u16(skb, TLS_INFO_CIPHER, cipher_type);
1146 		if (err)
1147 			goto nla_failure;
1148 	}
1149 	err = nla_put_u16(skb, TLS_INFO_TXCONF, tls_user_config(ctx, true));
1150 	if (err)
1151 		goto nla_failure;
1152 
1153 	err = nla_put_u16(skb, TLS_INFO_RXCONF, tls_user_config(ctx, false));
1154 	if (err)
1155 		goto nla_failure;
1156 
1157 	if (ctx->tx_conf == TLS_HW && ctx->zerocopy_sendfile) {
1158 		err = nla_put_flag(skb, TLS_INFO_ZC_RO_TX);
1159 		if (err)
1160 			goto nla_failure;
1161 	}
1162 	if (ctx->rx_no_pad) {
1163 		err = nla_put_flag(skb, TLS_INFO_RX_NO_PAD);
1164 		if (err)
1165 			goto nla_failure;
1166 	}
1167 
1168 	err = nla_put_u16(skb, TLS_INFO_TX_MAX_PAYLOAD_LEN,
1169 			  ctx->tx_max_payload_len);
1170 
1171 	if (err)
1172 		goto nla_failure;
1173 
1174 	rcu_read_unlock();
1175 	nla_nest_end(skb, start);
1176 	return 0;
1177 
1178 nla_failure:
1179 	rcu_read_unlock();
1180 	nla_nest_cancel(skb, start);
1181 	return err;
1182 }
1183 
1184 static size_t tls_get_info_size(const struct sock *sk, bool net_admin)
1185 {
1186 	size_t size = 0;
1187 
1188 	size += nla_total_size(0) +		/* INET_ULP_INFO_TLS */
1189 		nla_total_size(sizeof(u16)) +	/* TLS_INFO_VERSION */
1190 		nla_total_size(sizeof(u16)) +	/* TLS_INFO_CIPHER */
1191 		nla_total_size(sizeof(u16)) +	/* TLS_INFO_RXCONF */
1192 		nla_total_size(sizeof(u16)) +	/* TLS_INFO_TXCONF */
1193 		nla_total_size(0) +		/* TLS_INFO_ZC_RO_TX */
1194 		nla_total_size(0) +		/* TLS_INFO_RX_NO_PAD */
1195 		nla_total_size(sizeof(u16)) +   /* TLS_INFO_TX_MAX_PAYLOAD_LEN */
1196 		0;
1197 
1198 	return size;
1199 }
1200 
1201 static int __net_init tls_init_net(struct net *net)
1202 {
1203 	int err;
1204 
1205 	net->mib.tls_statistics = alloc_percpu(struct linux_tls_mib);
1206 	if (!net->mib.tls_statistics)
1207 		return -ENOMEM;
1208 
1209 	err = tls_proc_init(net);
1210 	if (err)
1211 		goto err_free_stats;
1212 
1213 	return 0;
1214 err_free_stats:
1215 	free_percpu(net->mib.tls_statistics);
1216 	return err;
1217 }
1218 
1219 static void __net_exit tls_exit_net(struct net *net)
1220 {
1221 	tls_proc_fini(net);
1222 	free_percpu(net->mib.tls_statistics);
1223 }
1224 
1225 static struct pernet_operations tls_proc_ops = {
1226 	.init = tls_init_net,
1227 	.exit = tls_exit_net,
1228 };
1229 
1230 static struct tcp_ulp_ops tcp_tls_ulp_ops __read_mostly = {
1231 	.name			= "tls",
1232 	.owner			= THIS_MODULE,
1233 	.init			= tls_init,
1234 	.update			= tls_update,
1235 	.get_info		= tls_get_info,
1236 	.get_info_size		= tls_get_info_size,
1237 };
1238 
1239 static int __init tls_register(void)
1240 {
1241 	int err;
1242 
1243 	err = register_pernet_subsys(&tls_proc_ops);
1244 	if (err)
1245 		return err;
1246 
1247 	err = tls_strp_dev_init();
1248 	if (err)
1249 		goto err_pernet;
1250 
1251 	err = tls_device_init();
1252 	if (err)
1253 		goto err_strp;
1254 
1255 	tcp_register_ulp(&tcp_tls_ulp_ops);
1256 
1257 	return 0;
1258 err_strp:
1259 	tls_strp_dev_exit();
1260 err_pernet:
1261 	unregister_pernet_subsys(&tls_proc_ops);
1262 	return err;
1263 }
1264 
1265 static void __exit tls_unregister(void)
1266 {
1267 	tcp_unregister_ulp(&tcp_tls_ulp_ops);
1268 	tls_strp_dev_exit();
1269 	tls_device_cleanup();
1270 	unregister_pernet_subsys(&tls_proc_ops);
1271 }
1272 
1273 module_init(tls_register);
1274 module_exit(tls_unregister);
1275