/* net/tls/tls_sw.c (revision f6f3bac08ff9855d803081a353a1fafaa8845739) */
1 /*
2  * Copyright (c) 2016-2017, Mellanox Technologies. All rights reserved.
3  * Copyright (c) 2016-2017, Dave Watson <davejwatson@fb.com>. All rights reserved.
4  * Copyright (c) 2016-2017, Lance Chao <lancerchao@fb.com>. All rights reserved.
5  * Copyright (c) 2016, Fridolin Pokorny <fridolin.pokorny@gmail.com>. All rights reserved.
6  * Copyright (c) 2016, Nikos Mavrogiannopoulos <nmav@gnutls.org>. All rights reserved.
7  *
8  * This software is available to you under a choice of one of two
9  * licenses.  You may choose to be licensed under the terms of the GNU
10  * General Public License (GPL) Version 2, available from the file
11  * COPYING in the main directory of this source tree, or the
12  * OpenIB.org BSD license below:
13  *
14  *     Redistribution and use in source and binary forms, with or
15  *     without modification, are permitted provided that the following
16  *     conditions are met:
17  *
18  *      - Redistributions of source code must retain the above
19  *        copyright notice, this list of conditions and the following
20  *        disclaimer.
21  *
22  *      - Redistributions in binary form must reproduce the above
23  *        copyright notice, this list of conditions and the following
24  *        disclaimer in the documentation and/or other materials
25  *        provided with the distribution.
26  *
27  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
28  * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
29  * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
30  * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
31  * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
32  * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
33  * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
34  * SOFTWARE.
35  */
36 
37 #include <linux/sched/signal.h>
38 #include <linux/module.h>
39 #include <crypto/aead.h>
40 
41 #include <net/strparser.h>
42 #include <net/tls.h>
43 
44 #define MAX_IV_SIZE	TLS_CIPHER_AES_GCM_128_IV_SIZE
45 
46 static int __skb_nsg(struct sk_buff *skb, int offset, int len,
47 		     unsigned int recursion_level)
48 {
49 	int start = skb_headlen(skb);
50 	int i, chunk = start - offset;
51 	struct sk_buff *frag_iter;
52 	int elt = 0;
53 
54 	if (unlikely(recursion_level >= 24))
55 		return -EMSGSIZE;
56 
57 	if (chunk > 0) {
58 		if (chunk > len)
59 			chunk = len;
60 		elt++;
61 		len -= chunk;
62 		if (len == 0)
63 			return elt;
64 		offset += chunk;
65 	}
66 
67 	for (i = 0; i < skb_shinfo(skb)->nr_frags; i++) {
68 		int end;
69 
70 		WARN_ON(start > offset + len);
71 
72 		end = start + skb_frag_size(&skb_shinfo(skb)->frags[i]);
73 		chunk = end - offset;
74 		if (chunk > 0) {
75 			if (chunk > len)
76 				chunk = len;
77 			elt++;
78 			len -= chunk;
79 			if (len == 0)
80 				return elt;
81 			offset += chunk;
82 		}
83 		start = end;
84 	}
85 
86 	if (unlikely(skb_has_frag_list(skb))) {
87 		skb_walk_frags(skb, frag_iter) {
88 			int end, ret;
89 
90 			WARN_ON(start > offset + len);
91 
92 			end = start + frag_iter->len;
93 			chunk = end - offset;
94 			if (chunk > 0) {
95 				if (chunk > len)
96 					chunk = len;
97 				ret = __skb_nsg(frag_iter, offset - start, chunk,
98 						recursion_level + 1);
99 				if (unlikely(ret < 0))
100 					return ret;
101 				elt += ret;
102 				len -= chunk;
103 				if (len == 0)
104 					return elt;
105 				offset += chunk;
106 			}
107 			start = end;
108 		}
109 	}
110 	BUG_ON(len);
111 	return elt;
112 }
113 
114 /* Return the number of scatterlist elements required to completely map the
115  * skb, or -EMSGSIZE if the recursion depth is exceeded.
116  */
117 static int skb_nsg(struct sk_buff *skb, int offset, int len)
118 {
119 	return __skb_nsg(skb, offset, len, 0);
120 }
121 
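/* Completion callback for an asynchronous AEAD decrypt request.
 * Propagates any crypto error to the socket, frees the skb attached as
 * callback data, drops the page references taken for the zero-copy
 * destination (skipping the first S/G entry, which holds the AAD),
 * frees the request, and wakes a reader waiting in tls_sw_recvmsg()
 * once the last pending decryption has finished.
 */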
122 static void tls_decrypt_done(struct crypto_async_request *req, int err)
123 {
124 	struct aead_request *aead_req = (struct aead_request *)req;
125 	struct decrypt_req_ctx *req_ctx =
126 			(struct decrypt_req_ctx *)(aead_req + 1);
127 
128 	struct scatterlist *sgout = aead_req->dst;
129 
130 	struct tls_context *tls_ctx = tls_get_ctx(req_ctx->sk);
131 	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
132 	int pending = atomic_dec_return(&ctx->decrypt_pending);
133 	struct scatterlist *sg;
134 	unsigned int pages;
135 
136 	/* Propagate the error, if there was one */
137 	if (err) {
138 		ctx->async_wait.err = err;
139 		tls_err_abort(req_ctx->sk, err);
140 	}
141 
142 	/* Release the skb, pages and memory allocated for crypto req */
143 	kfree_skb(req->data);
144 
145 	/* Skip the first S/G entry as it points to AAD */
146 	for_each_sg(sg_next(sgout), sg, UINT_MAX, pages) {
147 		if (!sg)
148 			break;
149 		put_page(sg_page(sg));
150 	}
151 
152 	kfree(aead_req);
153 
154 	if (!pending && READ_ONCE(ctx->async_notify))
155 		complete(&ctx->async_wait.completion);
156 }
157 
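/* Set up and submit the AEAD decrypt request for one TLS record.
 * In async mode the request completes in tls_decrypt_done() and this
 * function returns -EINPROGRESS; otherwise it waits synchronously for
 * the result.
 */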
158 static int tls_do_decryption(struct sock *sk,
159 			     struct sk_buff *skb,
160 			     struct scatterlist *sgin,
161 			     struct scatterlist *sgout,
162 			     char *iv_recv,
163 			     size_t data_len,
164 			     struct aead_request *aead_req,
165 			     bool async)
166 {
167 	struct tls_context *tls_ctx = tls_get_ctx(sk);
168 	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
169 	int ret;
170 
171 	aead_request_set_tfm(aead_req, ctx->aead_recv);
172 	aead_request_set_ad(aead_req, TLS_AAD_SPACE_SIZE);
173 	aead_request_set_crypt(aead_req, sgin, sgout,
174 			       data_len + tls_ctx->rx.tag_size,
175 			       (u8 *)iv_recv);
176 
177 	if (async) {
178 		struct decrypt_req_ctx *req_ctx;
179 
180 		req_ctx = (struct decrypt_req_ctx *)(aead_req + 1);
181 		req_ctx->sk = sk;
182 
183 		aead_request_set_callback(aead_req,
184 					  CRYPTO_TFM_REQ_MAY_BACKLOG,
185 					  tls_decrypt_done, skb);
186 		atomic_inc(&ctx->decrypt_pending);
187 	} else {
188 		aead_request_set_callback(aead_req,
189 					  CRYPTO_TFM_REQ_MAY_BACKLOG,
190 					  crypto_req_done, &ctx->async_wait);
191 	}
192 
193 	ret = crypto_aead_decrypt(aead_req);
194 	if (ret == -EINPROGRESS) {
195 		if (async)
196 			return ret;
197 
198 		ret = crypto_wait_req(ret, &ctx->async_wait);
199 	}
200 
201 	if (async)
202 		atomic_dec(&ctx->decrypt_pending);
203 
204 	return ret;
205 }
206 
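/* Shrink a scatterlist so that it describes exactly target_size bytes,
 * releasing the pages and uncharging the socket memory of the trimmed
 * tail.
 */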
207 static void trim_sg(struct sock *sk, struct scatterlist *sg,
208 		    int *sg_num_elem, unsigned int *sg_size, int target_size)
209 {
210 	int i = *sg_num_elem - 1;
211 	int trim = *sg_size - target_size;
212 
213 	if (trim <= 0) {
214 		WARN_ON(trim < 0);
215 		return;
216 	}
217 
218 	*sg_size = target_size;
219 	while (trim >= sg[i].length) {
220 		trim -= sg[i].length;
221 		sk_mem_uncharge(sk, sg[i].length);
222 		put_page(sg_page(&sg[i]));
223 		i--;
224 
225 		if (i < 0)
226 			goto out;
227 	}
228 
229 	sg[i].length -= trim;
230 	sk_mem_uncharge(sk, trim);
231 
232 out:
233 	*sg_num_elem = i + 1;
234 }
235 
236 static void trim_both_sgl(struct sock *sk, int target_size)
237 {
238 	struct tls_context *tls_ctx = tls_get_ctx(sk);
239 	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
240 
241 	trim_sg(sk, ctx->sg_plaintext_data,
242 		&ctx->sg_plaintext_num_elem,
243 		&ctx->sg_plaintext_size,
244 		target_size);
245 
246 	if (target_size > 0)
247 		target_size += tls_ctx->tx.overhead_size;
248 
249 	trim_sg(sk, ctx->sg_encrypted_data,
250 		&ctx->sg_encrypted_num_elem,
251 		&ctx->sg_encrypted_size,
252 		target_size);
253 }
254 
255 static int alloc_encrypted_sg(struct sock *sk, int len)
256 {
257 	struct tls_context *tls_ctx = tls_get_ctx(sk);
258 	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
259 	int rc = 0;
260 
261 	rc = sk_alloc_sg(sk, len,
262 			 ctx->sg_encrypted_data, 0,
263 			 &ctx->sg_encrypted_num_elem,
264 			 &ctx->sg_encrypted_size, 0);
265 
266 	return rc;
267 }
268 
269 static int alloc_plaintext_sg(struct sock *sk, int len)
270 {
271 	struct tls_context *tls_ctx = tls_get_ctx(sk);
272 	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
273 	int rc = 0;
274 
275 	rc = sk_alloc_sg(sk, len, ctx->sg_plaintext_data, 0,
276 			 &ctx->sg_plaintext_num_elem, &ctx->sg_plaintext_size,
277 			 tls_ctx->pending_open_record_frags);
278 
279 	return rc;
280 }
281 
282 static void free_sg(struct sock *sk, struct scatterlist *sg,
283 		    int *sg_num_elem, unsigned int *sg_size)
284 {
285 	int i, n = *sg_num_elem;
286 
287 	for (i = 0; i < n; ++i) {
288 		sk_mem_uncharge(sk, sg[i].length);
289 		put_page(sg_page(&sg[i]));
290 	}
291 	*sg_num_elem = 0;
292 	*sg_size = 0;
293 }
294 
295 static void tls_free_both_sg(struct sock *sk)
296 {
297 	struct tls_context *tls_ctx = tls_get_ctx(sk);
298 	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
299 
300 	free_sg(sk, ctx->sg_encrypted_data, &ctx->sg_encrypted_num_elem,
301 		&ctx->sg_encrypted_size);
302 
303 	free_sg(sk, ctx->sg_plaintext_data, &ctx->sg_plaintext_num_elem,
304 		&ctx->sg_plaintext_size);
305 }
306 
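/* Encrypt the current record synchronously.  The first encrypted-data
 * entry is temporarily advanced past the space reserved for the TLS
 * record header so that the ciphertext and tag land after the prepend,
 * and is restored before returning.
 */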
307 static int tls_do_encryption(struct tls_context *tls_ctx,
308 			     struct tls_sw_context_tx *ctx,
309 			     struct aead_request *aead_req,
310 			     size_t data_len)
311 {
312 	int rc;
313 
314 	ctx->sg_encrypted_data[0].offset += tls_ctx->tx.prepend_size;
315 	ctx->sg_encrypted_data[0].length -= tls_ctx->tx.prepend_size;
316 
317 	aead_request_set_tfm(aead_req, ctx->aead_send);
318 	aead_request_set_ad(aead_req, TLS_AAD_SPACE_SIZE);
319 	aead_request_set_crypt(aead_req, ctx->sg_aead_in, ctx->sg_aead_out,
320 			       data_len, tls_ctx->tx.iv);
321 
322 	aead_request_set_callback(aead_req, CRYPTO_TFM_REQ_MAY_BACKLOG,
323 				  crypto_req_done, &ctx->async_wait);
324 
325 	rc = crypto_wait_req(crypto_aead_encrypt(aead_req), &ctx->async_wait);
326 
327 	ctx->sg_encrypted_data[0].offset -= tls_ctx->tx.prepend_size;
328 	ctx->sg_encrypted_data[0].length += tls_ctx->tx.prepend_size;
329 
330 	return rc;
331 }
332 
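/* Build and transmit one TLS record from the pending plaintext:
 * generate the AAD and the record header, encrypt the payload, hand the
 * resulting scatterlist to tls_push_sg() and advance the TX record
 * sequence number.
 */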
333 static int tls_push_record(struct sock *sk, int flags,
334 			   unsigned char record_type)
335 {
336 	struct tls_context *tls_ctx = tls_get_ctx(sk);
337 	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
338 	struct aead_request *req;
339 	int rc;
340 
341 	req = aead_request_alloc(ctx->aead_send, sk->sk_allocation);
342 	if (!req)
343 		return -ENOMEM;
344 
345 	sg_mark_end(ctx->sg_plaintext_data + ctx->sg_plaintext_num_elem - 1);
346 	sg_mark_end(ctx->sg_encrypted_data + ctx->sg_encrypted_num_elem - 1);
347 
348 	tls_make_aad(ctx->aad_space, ctx->sg_plaintext_size,
349 		     tls_ctx->tx.rec_seq, tls_ctx->tx.rec_seq_size,
350 		     record_type);
351 
352 	tls_fill_prepend(tls_ctx,
353 			 page_address(sg_page(&ctx->sg_encrypted_data[0])) +
354 			 ctx->sg_encrypted_data[0].offset,
355 			 ctx->sg_plaintext_size, record_type);
356 
357 	tls_ctx->pending_open_record_frags = 0;
358 	set_bit(TLS_PENDING_CLOSED_RECORD, &tls_ctx->flags);
359 
360 	rc = tls_do_encryption(tls_ctx, ctx, req, ctx->sg_plaintext_size);
361 	if (rc < 0) {
362 		/* If we are called from write_space and we fail,
363 		 * we need to set SOCK_NOSPACE here so that another
364 		 * write_space is triggered in the future.
365 		 */
366 		set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
367 		goto out_req;
368 	}
369 
370 	free_sg(sk, ctx->sg_plaintext_data, &ctx->sg_plaintext_num_elem,
371 		&ctx->sg_plaintext_size);
372 
373 	ctx->sg_encrypted_num_elem = 0;
374 	ctx->sg_encrypted_size = 0;
375 
376 	/* Only pass through MSG_DONTWAIT and MSG_NOSIGNAL flags */
377 	rc = tls_push_sg(sk, tls_ctx, ctx->sg_encrypted_data, 0, flags);
378 	if (rc < 0 && rc != -EAGAIN)
379 		tls_err_abort(sk, EBADMSG);
380 
381 	tls_advance_record_sn(sk, &tls_ctx->tx);
382 out_req:
383 	aead_request_free(req);
384 	return rc;
385 }
386 
387 static int tls_sw_push_pending_record(struct sock *sk, int flags)
388 {
389 	return tls_push_record(sk, flags, TLS_RECORD_TYPE_DATA);
390 }
391 
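/* Map user pages from the iterator directly into the 'to' scatterlist
 * without copying the data.  Updates *pages_used and *size_used; when
 * 'charge' is set the mapped bytes are charged to the socket.  On
 * failure the iterator is reverted to where it was on entry.
 */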
392 static int zerocopy_from_iter(struct sock *sk, struct iov_iter *from,
393 			      int length, int *pages_used,
394 			      unsigned int *size_used,
395 			      struct scatterlist *to, int to_max_pages,
396 			      bool charge)
397 {
398 	struct page *pages[MAX_SKB_FRAGS];
399 
400 	size_t offset;
401 	ssize_t copied, use;
402 	int i = 0;
403 	unsigned int size = *size_used;
404 	int num_elem = *pages_used;
405 	int rc = 0;
406 	int maxpages;
407 
408 	while (length > 0) {
409 		i = 0;
410 		maxpages = to_max_pages - num_elem;
411 		if (maxpages == 0) {
412 			rc = -EFAULT;
413 			goto out;
414 		}
415 		copied = iov_iter_get_pages(from, pages,
416 					    length,
417 					    maxpages, &offset);
418 		if (copied <= 0) {
419 			rc = -EFAULT;
420 			goto out;
421 		}
422 
423 		iov_iter_advance(from, copied);
424 
425 		length -= copied;
426 		size += copied;
427 		while (copied) {
428 			use = min_t(int, copied, PAGE_SIZE - offset);
429 
430 			sg_set_page(&to[num_elem],
431 				    pages[i], use, offset);
432 			sg_unmark_end(&to[num_elem]);
433 			if (charge)
434 				sk_mem_charge(sk, use);
435 
436 			offset = 0;
437 			copied -= use;
438 
439 			++i;
440 			++num_elem;
441 		}
442 	}
443 
444 	/* Mark the end in the last sg entry if newly added */
445 	if (num_elem > *pages_used)
446 		sg_mark_end(&to[num_elem - 1]);
447 out:
448 	if (rc)
449 		iov_iter_revert(from, size - *size_used);
450 	*size_used = size;
451 	*pages_used = num_elem;
452 
453 	return rc;
454 }
455 
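/* Copy data from the iterator into the pages already allocated in the
 * plaintext scatterlist, starting at the first open (unfilled)
 * fragment, and account the filled fragments as pending record frags.
 */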
456 static int memcopy_from_iter(struct sock *sk, struct iov_iter *from,
457 			     int bytes)
458 {
459 	struct tls_context *tls_ctx = tls_get_ctx(sk);
460 	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
461 	struct scatterlist *sg = ctx->sg_plaintext_data;
462 	int copy, i, rc = 0;
463 
464 	for (i = tls_ctx->pending_open_record_frags;
465 	     i < ctx->sg_plaintext_num_elem; ++i) {
466 		copy = sg[i].length;
467 		if (copy_from_iter(
468 				page_address(sg_page(&sg[i])) + sg[i].offset,
469 				copy, from) != copy) {
470 			rc = -EFAULT;
471 			goto out;
472 		}
473 		bytes -= copy;
474 
475 		++tls_ctx->pending_open_record_frags;
476 
477 		if (!bytes)
478 			break;
479 	}
480 
481 out:
482 	return rc;
483 }
484 
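/* sendmsg() for a TLS software TX socket.  Data is either mapped
 * zero-copy from the user iovec (when a full record or the end of the
 * message is reached) or copied into pre-allocated plaintext pages, and
 * complete records are encrypted and pushed out as they fill up.
 */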
485 int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
486 {
487 	struct tls_context *tls_ctx = tls_get_ctx(sk);
488 	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
489 	int ret = 0;
490 	int required_size;
491 	long timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
492 	bool eor = !(msg->msg_flags & MSG_MORE);
493 	size_t try_to_copy, copied = 0;
494 	unsigned char record_type = TLS_RECORD_TYPE_DATA;
495 	int record_room;
496 	bool full_record;
497 	int orig_size;
498 	bool is_kvec = msg->msg_iter.type & ITER_KVEC;
499 
500 	if (msg->msg_flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL))
501 		return -ENOTSUPP;
502 
503 	lock_sock(sk);
504 
505 	if (tls_complete_pending_work(sk, tls_ctx, msg->msg_flags, &timeo))
506 		goto send_end;
507 
508 	if (unlikely(msg->msg_controllen)) {
509 		ret = tls_proccess_cmsg(sk, msg, &record_type);
510 		if (ret)
511 			goto send_end;
512 	}
513 
514 	while (msg_data_left(msg)) {
515 		if (sk->sk_err) {
516 			ret = -sk->sk_err;
517 			goto send_end;
518 		}
519 
520 		orig_size = ctx->sg_plaintext_size;
521 		full_record = false;
522 		try_to_copy = msg_data_left(msg);
523 		record_room = TLS_MAX_PAYLOAD_SIZE - ctx->sg_plaintext_size;
524 		if (try_to_copy >= record_room) {
525 			try_to_copy = record_room;
526 			full_record = true;
527 		}
528 
529 		required_size = ctx->sg_plaintext_size + try_to_copy +
530 				tls_ctx->tx.overhead_size;
531 
532 		if (!sk_stream_memory_free(sk))
533 			goto wait_for_sndbuf;
534 alloc_encrypted:
535 		ret = alloc_encrypted_sg(sk, required_size);
536 		if (ret) {
537 			if (ret != -ENOSPC)
538 				goto wait_for_memory;
539 
540 			/* Adjust try_to_copy according to the amount that was
541 			 * actually allocated. The difference is due to
542 			 * the limit on the number of sg elements.
543 			 */
544 			try_to_copy -= required_size - ctx->sg_encrypted_size;
545 			full_record = true;
546 		}
547 		if (!is_kvec && (full_record || eor)) {
548 			ret = zerocopy_from_iter(sk, &msg->msg_iter,
549 				try_to_copy, &ctx->sg_plaintext_num_elem,
550 				&ctx->sg_plaintext_size,
551 				ctx->sg_plaintext_data,
552 				ARRAY_SIZE(ctx->sg_plaintext_data),
553 				true);
554 			if (ret)
555 				goto fallback_to_reg_send;
556 
557 			copied += try_to_copy;
558 			ret = tls_push_record(sk, msg->msg_flags, record_type);
559 			if (ret)
560 				goto send_end;
561 			continue;
562 
563 fallback_to_reg_send:
564 			trim_sg(sk, ctx->sg_plaintext_data,
565 				&ctx->sg_plaintext_num_elem,
566 				&ctx->sg_plaintext_size,
567 				orig_size);
568 		}
569 
570 		required_size = ctx->sg_plaintext_size + try_to_copy;
571 alloc_plaintext:
572 		ret = alloc_plaintext_sg(sk, required_size);
573 		if (ret) {
574 			if (ret != -ENOSPC)
575 				goto wait_for_memory;
576 
577 			/* Adjust try_to_copy according to the amount that was
578 			 * actually allocated. The difference is due to
579 			 * the limit on the number of sg elements.
580 			 */
581 			try_to_copy -= required_size - ctx->sg_plaintext_size;
582 			full_record = true;
583 
584 			trim_sg(sk, ctx->sg_encrypted_data,
585 				&ctx->sg_encrypted_num_elem,
586 				&ctx->sg_encrypted_size,
587 				ctx->sg_plaintext_size +
588 				tls_ctx->tx.overhead_size);
589 		}
590 
591 		ret = memcopy_from_iter(sk, &msg->msg_iter, try_to_copy);
592 		if (ret)
593 			goto trim_sgl;
594 
595 		copied += try_to_copy;
596 		if (full_record || eor) {
597 push_record:
598 			ret = tls_push_record(sk, msg->msg_flags, record_type);
599 			if (ret) {
600 				if (ret == -ENOMEM)
601 					goto wait_for_memory;
602 
603 				goto send_end;
604 			}
605 		}
606 
607 		continue;
608 
609 wait_for_sndbuf:
610 		set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
611 wait_for_memory:
612 		ret = sk_stream_wait_memory(sk, &timeo);
613 		if (ret) {
614 trim_sgl:
615 			trim_both_sgl(sk, orig_size);
616 			goto send_end;
617 		}
618 
619 		if (tls_is_pending_closed_record(tls_ctx))
620 			goto push_record;
621 
622 		if (ctx->sg_encrypted_size < required_size)
623 			goto alloc_encrypted;
624 
625 		goto alloc_plaintext;
626 	}
627 
628 send_end:
629 	ret = sk_stream_error(sk, msg->msg_flags, ret);
630 
631 	release_sock(sk);
632 	return copied ? copied : ret;
633 }
634 
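/* sendpage() for a TLS software TX socket.  The page is referenced and
 * linked into the plaintext scatterlist; a record is pushed once it is
 * full, the scatterlist runs out of slots, or no more data is expected
 * (neither MSG_MORE nor MSG_SENDPAGE_NOTLAST is set).
 */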
635 int tls_sw_sendpage(struct sock *sk, struct page *page,
636 		    int offset, size_t size, int flags)
637 {
638 	struct tls_context *tls_ctx = tls_get_ctx(sk);
639 	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
640 	int ret = 0;
641 	long timeo = sock_sndtimeo(sk, flags & MSG_DONTWAIT);
642 	bool eor;
643 	size_t orig_size = size;
644 	unsigned char record_type = TLS_RECORD_TYPE_DATA;
645 	struct scatterlist *sg;
646 	bool full_record;
647 	int record_room;
648 
649 	if (flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL |
650 		      MSG_SENDPAGE_NOTLAST))
651 		return -ENOTSUPP;
652 
653 	/* No MSG_EOR from splice, only look at MSG_MORE */
654 	eor = !(flags & (MSG_MORE | MSG_SENDPAGE_NOTLAST));
655 
656 	lock_sock(sk);
657 
658 	sk_clear_bit(SOCKWQ_ASYNC_NOSPACE, sk);
659 
660 	if (tls_complete_pending_work(sk, tls_ctx, flags, &timeo))
661 		goto sendpage_end;
662 
663 	/* Call the sk_stream functions to manage the sndbuf mem. */
664 	while (size > 0) {
665 		size_t copy, required_size;
666 
667 		if (sk->sk_err) {
668 			ret = -sk->sk_err;
669 			goto sendpage_end;
670 		}
671 
672 		full_record = false;
673 		record_room = TLS_MAX_PAYLOAD_SIZE - ctx->sg_plaintext_size;
674 		copy = size;
675 		if (copy >= record_room) {
676 			copy = record_room;
677 			full_record = true;
678 		}
679 		required_size = ctx->sg_plaintext_size + copy +
680 			      tls_ctx->tx.overhead_size;
681 
682 		if (!sk_stream_memory_free(sk))
683 			goto wait_for_sndbuf;
684 alloc_payload:
685 		ret = alloc_encrypted_sg(sk, required_size);
686 		if (ret) {
687 			if (ret != -ENOSPC)
688 				goto wait_for_memory;
689 
690 			/* Adjust copy according to the amount that was
691 			 * actually allocated. The difference is due to
692 			 * the limit on the number of sg elements.
693 			 */
694 			copy -= required_size - ctx->sg_plaintext_size;
695 			full_record = true;
696 		}
697 
698 		get_page(page);
699 		sg = ctx->sg_plaintext_data + ctx->sg_plaintext_num_elem;
700 		sg_set_page(sg, page, copy, offset);
701 		sg_unmark_end(sg);
702 
703 		ctx->sg_plaintext_num_elem++;
704 
705 		sk_mem_charge(sk, copy);
706 		offset += copy;
707 		size -= copy;
708 		ctx->sg_plaintext_size += copy;
709 		tls_ctx->pending_open_record_frags = ctx->sg_plaintext_num_elem;
710 
711 		if (full_record || eor ||
712 		    ctx->sg_plaintext_num_elem ==
713 		    ARRAY_SIZE(ctx->sg_plaintext_data)) {
714 push_record:
715 			ret = tls_push_record(sk, flags, record_type);
716 			if (ret) {
717 				if (ret == -ENOMEM)
718 					goto wait_for_memory;
719 
720 				goto sendpage_end;
721 			}
722 		}
723 		continue;
724 wait_for_sndbuf:
725 		set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
726 wait_for_memory:
727 		ret = sk_stream_wait_memory(sk, &timeo);
728 		if (ret) {
729 			trim_both_sgl(sk, ctx->sg_plaintext_size);
730 			goto sendpage_end;
731 		}
732 
733 		if (tls_is_pending_closed_record(tls_ctx))
734 			goto push_record;
735 
736 		goto alloc_payload;
737 	}
738 
739 sendpage_end:
740 	if (orig_size > size)
741 		ret = orig_size - size;
742 	else
743 		ret = sk_stream_error(sk, flags, ret);
744 
745 	release_sock(sk);
746 	return ret;
747 }
748 
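/* Wait until the strparser has produced a complete record
 * (ctx->recv_pkt).  Returns NULL, setting *err where applicable, if the
 * socket has an error, is shut down or done, the timeout expires, or a
 * signal is pending.
 */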
749 static struct sk_buff *tls_wait_data(struct sock *sk, int flags,
750 				     long timeo, int *err)
751 {
752 	struct tls_context *tls_ctx = tls_get_ctx(sk);
753 	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
754 	struct sk_buff *skb;
755 	DEFINE_WAIT_FUNC(wait, woken_wake_function);
756 
757 	while (!(skb = ctx->recv_pkt)) {
758 		if (sk->sk_err) {
759 			*err = sock_error(sk);
760 			return NULL;
761 		}
762 
763 		if (sk->sk_shutdown & RCV_SHUTDOWN)
764 			return NULL;
765 
766 		if (sock_flag(sk, SOCK_DONE))
767 			return NULL;
768 
769 		if ((flags & MSG_DONTWAIT) || !timeo) {
770 			*err = -EAGAIN;
771 			return NULL;
772 		}
773 
774 		add_wait_queue(sk_sleep(sk), &wait);
775 		sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
776 		sk_wait_event(sk, &timeo, ctx->recv_pkt != skb, &wait);
777 		sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
778 		remove_wait_queue(sk_sleep(sk), &wait);
779 
780 		/* Handle signals */
781 		if (signal_pending(current)) {
782 			*err = sock_intr_errno(timeo);
783 			return NULL;
784 		}
785 	}
786 
787 	return skb;
788 }
789 
790 /* Decrypt the input skb into either out_iov, out_sg or the skb's own
791  * buffers. The parameter 'zc' indicates whether zero-copy mode should be
792  * tried. In zero-copy mode, either out_iov or out_sg must be non-NULL.
793  * If both out_iov and out_sg are NULL, the decryption happens in place in
794  * the skb buffers, i.e. zero-copy gets disabled and '*zc' is updated
795  * accordingly.
796  */
797 
798 static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
799 			    struct iov_iter *out_iov,
800 			    struct scatterlist *out_sg,
801 			    int *chunk, bool *zc)
802 {
803 	struct tls_context *tls_ctx = tls_get_ctx(sk);
804 	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
805 	struct strp_msg *rxm = strp_msg(skb);
806 	int n_sgin, n_sgout, nsg, mem_size, aead_size, err, pages = 0;
807 	struct aead_request *aead_req;
808 	struct sk_buff *unused;
809 	u8 *aad, *iv, *mem = NULL;
810 	struct scatterlist *sgin = NULL;
811 	struct scatterlist *sgout = NULL;
812 	const int data_len = rxm->full_len - tls_ctx->rx.overhead_size;
813 
814 	if (*zc && (out_iov || out_sg)) {
815 		if (out_iov)
816 			n_sgout = iov_iter_npages(out_iov, INT_MAX) + 1;
817 		else
818 			n_sgout = sg_nents(out_sg);
819 		n_sgin = skb_nsg(skb, rxm->offset + tls_ctx->rx.prepend_size,
820 				 rxm->full_len - tls_ctx->rx.prepend_size);
821 	} else {
822 		n_sgout = 0;
823 		*zc = false;
824 		n_sgin = skb_cow_data(skb, 0, &unused);
825 	}
826 
827 	if (n_sgin < 1)
828 		return -EBADMSG;
829 
830 	/* Increment to accommodate AAD */
831 	n_sgin = n_sgin + 1;
832 
833 	nsg = n_sgin + n_sgout;
834 
835 	aead_size = sizeof(*aead_req) + crypto_aead_reqsize(ctx->aead_recv);
836 	mem_size = aead_size + (nsg * sizeof(struct scatterlist));
837 	mem_size = mem_size + TLS_AAD_SPACE_SIZE;
838 	mem_size = mem_size + crypto_aead_ivsize(ctx->aead_recv);
839 
840 	/* Allocate a single block of memory which contains
841 	 * aead_req || sgin[] || sgout[] || aad || iv.
842 	 * This order achieves correct alignment for aead_req, sgin, sgout.
843 	 */
844 	mem = kmalloc(mem_size, sk->sk_allocation);
845 	if (!mem)
846 		return -ENOMEM;
847 
848 	/* Segment the allocated memory */
849 	aead_req = (struct aead_request *)mem;
850 	sgin = (struct scatterlist *)(mem + aead_size);
851 	sgout = sgin + n_sgin;
852 	aad = (u8 *)(sgout + n_sgout);
853 	iv = aad + TLS_AAD_SPACE_SIZE;
854 
855 	/* Prepare IV */
856 	err = skb_copy_bits(skb, rxm->offset + TLS_HEADER_SIZE,
857 			    iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
858 			    tls_ctx->rx.iv_size);
859 	if (err < 0) {
860 		kfree(mem);
861 		return err;
862 	}
863 	memcpy(iv, tls_ctx->rx.iv, TLS_CIPHER_AES_GCM_128_SALT_SIZE);
864 
865 	/* Prepare AAD */
866 	tls_make_aad(aad, rxm->full_len - tls_ctx->rx.overhead_size,
867 		     tls_ctx->rx.rec_seq, tls_ctx->rx.rec_seq_size,
868 		     ctx->control);
869 
870 	/* Prepare sgin */
871 	sg_init_table(sgin, n_sgin);
872 	sg_set_buf(&sgin[0], aad, TLS_AAD_SPACE_SIZE);
873 	err = skb_to_sgvec(skb, &sgin[1],
874 			   rxm->offset + tls_ctx->rx.prepend_size,
875 			   rxm->full_len - tls_ctx->rx.prepend_size);
876 	if (err < 0) {
877 		kfree(mem);
878 		return err;
879 	}
880 
881 	if (n_sgout) {
882 		if (out_iov) {
883 			sg_init_table(sgout, n_sgout);
884 			sg_set_buf(&sgout[0], aad, TLS_AAD_SPACE_SIZE);
885 
886 			*chunk = 0;
887 			err = zerocopy_from_iter(sk, out_iov, data_len, &pages,
888 						 chunk, &sgout[1],
889 						 (n_sgout - 1), false);
890 			if (err < 0)
891 				goto fallback_to_reg_recv;
892 		} else if (out_sg) {
893 			memcpy(sgout, out_sg, n_sgout * sizeof(*sgout));
894 		} else {
895 			goto fallback_to_reg_recv;
896 		}
897 	} else {
898 fallback_to_reg_recv:
899 		sgout = sgin;
900 		pages = 0;
901 		*chunk = 0;
902 		*zc = false;
903 	}
904 
905 	/* Prepare and submit AEAD request */
906 	err = tls_do_decryption(sk, skb, sgin, sgout, iv,
907 				data_len, aead_req, *zc);
908 	if (err == -EINPROGRESS)
909 		return err;
910 
911 	/* Release the pages in case iov was mapped to pages */
912 	for (; pages > 0; pages--)
913 		put_page(sg_page(&sgout[pages]));
914 
915 	kfree(mem);
916 	return err;
917 }
918 
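/* Decrypt the record carried by 'skb' unless it has already been
 * decrypted (e.g. by the device offload), then strip the TLS header and
 * trailer from the strparser message bounds and advance the RX record
 * sequence number.
 */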
919 static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb,
920 			      struct iov_iter *dest, int *chunk, bool *zc)
921 {
922 	struct tls_context *tls_ctx = tls_get_ctx(sk);
923 	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
924 	struct strp_msg *rxm = strp_msg(skb);
925 	int err = 0;
926 
927 #ifdef CONFIG_TLS_DEVICE
928 	err = tls_device_decrypted(sk, skb);
929 	if (err < 0)
930 		return err;
931 #endif
932 	if (!ctx->decrypted) {
933 		err = decrypt_internal(sk, skb, dest, NULL, chunk, zc);
934 		if (err < 0) {
935 			if (err == -EINPROGRESS)
936 				tls_advance_record_sn(sk, &tls_ctx->rx);
937 
938 			return err;
939 		}
940 	} else {
941 		*zc = false;
942 	}
943 
944 	rxm->offset += tls_ctx->rx.prepend_size;
945 	rxm->full_len -= tls_ctx->rx.overhead_size;
946 	tls_advance_record_sn(sk, &tls_ctx->rx);
947 	ctx->decrypted = true;
948 	ctx->saved_data_ready(sk);
949 
950 	return err;
951 }
952 
953 int decrypt_skb(struct sock *sk, struct sk_buff *skb,
954 		struct scatterlist *sgout)
955 {
956 	bool zc = true;
957 	int chunk;
958 
959 	return decrypt_internal(sk, skb, NULL, sgout, &chunk, &zc);
960 }
961 
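/* Consume 'len' bytes of the current record.  Returns false if part of
 * the record is still unread; otherwise frees the skb, clears
 * ctx->recv_pkt and unpauses the strparser so the next record can be
 * delivered.
 */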
962 static bool tls_sw_advance_skb(struct sock *sk, struct sk_buff *skb,
963 			       unsigned int len)
964 {
965 	struct tls_context *tls_ctx = tls_get_ctx(sk);
966 	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
967 
968 	if (skb) {
969 		struct strp_msg *rxm = strp_msg(skb);
970 
971 		if (len < rxm->full_len) {
972 			rxm->offset += len;
973 			rxm->full_len -= len;
974 			return false;
975 		}
976 		kfree_skb(skb);
977 	}
978 
979 	/* Finished with message */
980 	ctx->recv_pkt = NULL;
981 	__strp_unpause(&ctx->strp);
982 
983 	return true;
984 }
985 
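/* recvmsg() for a TLS software RX socket.  Records are decrypted either
 * directly into the user iovec (zero-copy) or in place and then copied
 * out; a record type other than application data is reported through a
 * TLS_GET_RECORD_TYPE cmsg.  With async decryption the function waits
 * for all submitted records before returning.
 */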
986 int tls_sw_recvmsg(struct sock *sk,
987 		   struct msghdr *msg,
988 		   size_t len,
989 		   int nonblock,
990 		   int flags,
991 		   int *addr_len)
992 {
993 	struct tls_context *tls_ctx = tls_get_ctx(sk);
994 	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
995 	unsigned char control;
996 	struct strp_msg *rxm;
997 	struct sk_buff *skb;
998 	ssize_t copied = 0;
999 	bool cmsg = false;
1000 	int target, err = 0;
1001 	long timeo;
1002 	bool is_kvec = msg->msg_iter.type & ITER_KVEC;
1003 	int num_async = 0;
1004 
1005 	flags |= nonblock;
1006 
1007 	if (unlikely(flags & MSG_ERRQUEUE))
1008 		return sock_recv_errqueue(sk, msg, len, SOL_IP, IP_RECVERR);
1009 
1010 	lock_sock(sk);
1011 
1012 	target = sock_rcvlowat(sk, flags & MSG_WAITALL, len);
1013 	timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
1014 	do {
1015 		bool zc = false;
1016 		bool async = false;
1017 		int chunk = 0;
1018 
1019 		skb = tls_wait_data(sk, flags, timeo, &err);
1020 		if (!skb)
1021 			goto recv_end;
1022 
1023 		rxm = strp_msg(skb);
1024 
1025 		if (!cmsg) {
1026 			int cerr;
1027 
1028 			cerr = put_cmsg(msg, SOL_TLS, TLS_GET_RECORD_TYPE,
1029 					sizeof(ctx->control), &ctx->control);
1030 			cmsg = true;
1031 			control = ctx->control;
1032 			if (ctx->control != TLS_RECORD_TYPE_DATA) {
1033 				if (cerr || msg->msg_flags & MSG_CTRUNC) {
1034 					err = -EIO;
1035 					goto recv_end;
1036 				}
1037 			}
1038 		} else if (control != ctx->control) {
1039 			goto recv_end;
1040 		}
1041 
1042 		if (!ctx->decrypted) {
1043 			int to_copy = rxm->full_len - tls_ctx->rx.overhead_size;
1044 
1045 			if (!is_kvec && to_copy <= len &&
1046 			    likely(!(flags & MSG_PEEK)))
1047 				zc = true;
1048 
1049 			err = decrypt_skb_update(sk, skb, &msg->msg_iter,
1050 						 &chunk, &zc);
1051 			if (err < 0 && err != -EINPROGRESS) {
1052 				tls_err_abort(sk, EBADMSG);
1053 				goto recv_end;
1054 			}
1055 
1056 			if (err == -EINPROGRESS) {
1057 				async = true;
1058 				num_async++;
1059 				goto pick_next_record;
1060 			}
1061 
1062 			ctx->decrypted = true;
1063 		}
1064 
1065 		if (!zc) {
1066 			chunk = min_t(unsigned int, rxm->full_len, len);
1067 
1068 			err = skb_copy_datagram_msg(skb, rxm->offset, msg,
1069 						    chunk);
1070 			if (err < 0)
1071 				goto recv_end;
1072 		}
1073 
1074 pick_next_record:
1075 		copied += chunk;
1076 		len -= chunk;
1077 		if (likely(!(flags & MSG_PEEK))) {
1078 			u8 control = ctx->control;
1079 
1080 			/* For async, drop current skb reference */
1081 			if (async)
1082 				skb = NULL;
1083 
1084 			if (tls_sw_advance_skb(sk, skb, chunk)) {
1085 				/* Return full control message to
1086 				 * userspace before trying to parse
1087 				 * another message type
1088 				 */
1089 				msg->msg_flags |= MSG_EOR;
1090 				if (control != TLS_RECORD_TYPE_DATA)
1091 					goto recv_end;
1092 			} else {
1093 				break;
1094 			}
1095 		}
1096 
1097 		/* If we have a new message from strparser, continue now. */
1098 		if (copied >= target && !ctx->recv_pkt)
1099 			break;
1100 	} while (len);
1101 
1102 recv_end:
1103 	if (num_async) {
1104 		/* Wait for all previously submitted records to be decrypted */
1105 		smp_store_mb(ctx->async_notify, true);
1106 		if (atomic_read(&ctx->decrypt_pending)) {
1107 			err = crypto_wait_req(-EINPROGRESS, &ctx->async_wait);
1108 			if (err) {
1109 				/* one of the async decrypts failed */
1110 				tls_err_abort(sk, err);
1111 				copied = 0;
1112 			}
1113 		} else {
1114 			reinit_completion(&ctx->async_wait.completion);
1115 		}
1116 		WRITE_ONCE(ctx->async_notify, false);
1117 	}
1118 
1119 	release_sock(sk);
1120 	return copied ? : err;
1121 }
1122 
1123 ssize_t tls_sw_splice_read(struct socket *sock,  loff_t *ppos,
1124 			   struct pipe_inode_info *pipe,
1125 			   size_t len, unsigned int flags)
1126 {
1127 	struct tls_context *tls_ctx = tls_get_ctx(sock->sk);
1128 	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
1129 	struct strp_msg *rxm = NULL;
1130 	struct sock *sk = sock->sk;
1131 	struct sk_buff *skb;
1132 	ssize_t copied = 0;
1133 	int err = 0;
1134 	long timeo;
1135 	int chunk;
1136 	bool zc = false;
1137 
1138 	lock_sock(sk);
1139 
1140 	timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
1141 
1142 	skb = tls_wait_data(sk, flags, timeo, &err);
1143 	if (!skb)
1144 		goto splice_read_end;
1145 
1146 	/* splice does not support reading control messages */
1147 	if (ctx->control != TLS_RECORD_TYPE_DATA) {
1148 		err = -ENOTSUPP;
1149 		goto splice_read_end;
1150 	}
1151 
1152 	if (!ctx->decrypted) {
1153 		err = decrypt_skb_update(sk, skb, NULL, &chunk, &zc);
1154 
1155 		if (err < 0) {
1156 			tls_err_abort(sk, EBADMSG);
1157 			goto splice_read_end;
1158 		}
1159 		ctx->decrypted = true;
1160 	}
1161 	rxm = strp_msg(skb);
1162 
1163 	chunk = min_t(unsigned int, rxm->full_len, len);
1164 	copied = skb_splice_bits(skb, sk, rxm->offset, pipe, chunk, flags);
1165 	if (copied < 0)
1166 		goto splice_read_end;
1167 
1168 	if (likely(!(flags & MSG_PEEK)))
1169 		tls_sw_advance_skb(sk, skb, copied);
1170 
1171 splice_read_end:
1172 	release_sock(sk);
1173 	return copied ? : err;
1174 }
1175 
1176 unsigned int tls_sw_poll(struct file *file, struct socket *sock,
1177 			 struct poll_table_struct *wait)
1178 {
1179 	unsigned int ret;
1180 	struct sock *sk = sock->sk;
1181 	struct tls_context *tls_ctx = tls_get_ctx(sk);
1182 	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
1183 
1184 	/* Grab POLLOUT and POLLHUP from the underlying socket */
1185 	ret = ctx->sk_poll(file, sock, wait);
1186 
1187 	/* Clear POLLIN bits, and set based on recv_pkt */
1188 	ret &= ~(POLLIN | POLLRDNORM);
1189 	if (ctx->recv_pkt)
1190 		ret |= POLLIN | POLLRDNORM;
1191 
1192 	return ret;
1193 }
1194 
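/* strparser parse_msg callback: parse the TLS record header and return
 * the full record length (header plus payload), 0 if the header is not
 * yet complete, or a negative error, which also aborts the connection.
 */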
1195 static int tls_read_size(struct strparser *strp, struct sk_buff *skb)
1196 {
1197 	struct tls_context *tls_ctx = tls_get_ctx(strp->sk);
1198 	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
1199 	char header[TLS_HEADER_SIZE + MAX_IV_SIZE];
1200 	struct strp_msg *rxm = strp_msg(skb);
1201 	size_t cipher_overhead;
1202 	size_t data_len = 0;
1203 	int ret;
1204 
1205 	/* Verify that we have a full TLS header, or wait for more data */
1206 	if (rxm->offset + tls_ctx->rx.prepend_size > skb->len)
1207 		return 0;
1208 
1209 	/* Sanity-check size of on-stack buffer. */
1210 	if (WARN_ON(tls_ctx->rx.prepend_size > sizeof(header))) {
1211 		ret = -EINVAL;
1212 		goto read_failure;
1213 	}
1214 
1215 	/* Linearize header to local buffer */
1216 	ret = skb_copy_bits(skb, rxm->offset, header, tls_ctx->rx.prepend_size);
1217 
1218 	if (ret < 0)
1219 		goto read_failure;
1220 
1221 	ctx->control = header[0];
1222 
1223 	data_len = ((header[4] & 0xFF) | (header[3] << 8));
1224 
1225 	cipher_overhead = tls_ctx->rx.tag_size + tls_ctx->rx.iv_size;
1226 
1227 	if (data_len > TLS_MAX_PAYLOAD_SIZE + cipher_overhead) {
1228 		ret = -EMSGSIZE;
1229 		goto read_failure;
1230 	}
1231 	if (data_len < cipher_overhead) {
1232 		ret = -EBADMSG;
1233 		goto read_failure;
1234 	}
1235 
1236 	if (header[1] != TLS_VERSION_MINOR(tls_ctx->crypto_recv.version) ||
1237 	    header[2] != TLS_VERSION_MAJOR(tls_ctx->crypto_recv.version)) {
1238 		ret = -EINVAL;
1239 		goto read_failure;
1240 	}
1241 
1242 #ifdef CONFIG_TLS_DEVICE
1243 	handle_device_resync(strp->sk, TCP_SKB_CB(skb)->seq + rxm->offset,
1244 			     *(u64*)tls_ctx->rx.rec_seq);
1245 #endif
1246 	return data_len + TLS_HEADER_SIZE;
1247 
1248 read_failure:
1249 	tls_err_abort(strp->sk, ret);
1250 
1251 	return ret;
1252 }
1253 
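/* strparser rcv_msg callback: stash the completed record in
 * ctx->recv_pkt, pause the parser until the record has been consumed
 * and notify the reader via the saved data_ready callback.
 */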
1254 static void tls_queue(struct strparser *strp, struct sk_buff *skb)
1255 {
1256 	struct tls_context *tls_ctx = tls_get_ctx(strp->sk);
1257 	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
1258 
1259 	ctx->decrypted = false;
1260 
1261 	ctx->recv_pkt = skb;
1262 	strp_pause(strp);
1263 
1264 	ctx->saved_data_ready(strp->sk);
1265 }
1266 
1267 static void tls_data_ready(struct sock *sk)
1268 {
1269 	struct tls_context *tls_ctx = tls_get_ctx(sk);
1270 	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
1271 
1272 	strp_data_ready(&ctx->strp);
1273 }
1274 
1275 void tls_sw_free_resources_tx(struct sock *sk)
1276 {
1277 	struct tls_context *tls_ctx = tls_get_ctx(sk);
1278 	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
1279 
1280 	crypto_free_aead(ctx->aead_send);
1281 	tls_free_both_sg(sk);
1282 
1283 	kfree(ctx);
1284 }
1285 
1286 void tls_sw_release_resources_rx(struct sock *sk)
1287 {
1288 	struct tls_context *tls_ctx = tls_get_ctx(sk);
1289 	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
1290 
1291 	if (ctx->aead_recv) {
1292 		kfree_skb(ctx->recv_pkt);
1293 		ctx->recv_pkt = NULL;
1294 		crypto_free_aead(ctx->aead_recv);
1295 		strp_stop(&ctx->strp);
1296 		write_lock_bh(&sk->sk_callback_lock);
1297 		sk->sk_data_ready = ctx->saved_data_ready;
1298 		write_unlock_bh(&sk->sk_callback_lock);
1299 		release_sock(sk);
1300 		strp_done(&ctx->strp);
1301 		lock_sock(sk);
1302 	}
1303 }
1304 
1305 void tls_sw_free_resources_rx(struct sock *sk)
1306 {
1307 	struct tls_context *tls_ctx = tls_get_ctx(sk);
1308 	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
1309 
1310 	tls_sw_release_resources_rx(sk);
1311 
1312 	kfree(ctx);
1313 }
1314 
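/* Initialize the software TX or RX state for a kTLS socket: allocate
 * the per-direction context, take the cipher parameters supplied via
 * the TLS setsockopt path, set up the AES-GCM AEAD transform and, for
 * RX, attach the strparser and replace sk_data_ready.
 *
 * Illustrative userspace sequence for reaching this point (a sketch,
 * not part of this file; error handling omitted):
 *
 *	setsockopt(fd, SOL_TCP, TCP_ULP, "tls", sizeof("tls"));
 *
 *	struct tls12_crypto_info_aes_gcm_128 ci = { 0 };
 *
 *	ci.info.version = TLS_1_2_VERSION;
 *	ci.info.cipher_type = TLS_CIPHER_AES_GCM_128;
 *	(fill ci.key, ci.iv, ci.salt, ci.rec_seq from the TLS handshake)
 *	setsockopt(fd, SOL_TLS, TLS_TX, &ci, sizeof(ci));
 */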
1315 int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
1316 {
1317 	char keyval[TLS_CIPHER_AES_GCM_128_KEY_SIZE];
1318 	struct tls_crypto_info *crypto_info;
1319 	struct tls12_crypto_info_aes_gcm_128 *gcm_128_info;
1320 	struct tls_sw_context_tx *sw_ctx_tx = NULL;
1321 	struct tls_sw_context_rx *sw_ctx_rx = NULL;
1322 	struct cipher_context *cctx;
1323 	struct crypto_aead **aead;
1324 	struct strp_callbacks cb;
1325 	u16 nonce_size, tag_size, iv_size, rec_seq_size;
1326 	char *iv, *rec_seq;
1327 	int rc = 0;
1328 
1329 	if (!ctx) {
1330 		rc = -EINVAL;
1331 		goto out;
1332 	}
1333 
1334 	if (tx) {
1335 		if (!ctx->priv_ctx_tx) {
1336 			sw_ctx_tx = kzalloc(sizeof(*sw_ctx_tx), GFP_KERNEL);
1337 			if (!sw_ctx_tx) {
1338 				rc = -ENOMEM;
1339 				goto out;
1340 			}
1341 			ctx->priv_ctx_tx = sw_ctx_tx;
1342 		} else {
1343 			sw_ctx_tx =
1344 				(struct tls_sw_context_tx *)ctx->priv_ctx_tx;
1345 		}
1346 	} else {
1347 		if (!ctx->priv_ctx_rx) {
1348 			sw_ctx_rx = kzalloc(sizeof(*sw_ctx_rx), GFP_KERNEL);
1349 			if (!sw_ctx_rx) {
1350 				rc = -ENOMEM;
1351 				goto out;
1352 			}
1353 			ctx->priv_ctx_rx = sw_ctx_rx;
1354 		} else {
1355 			sw_ctx_rx =
1356 				(struct tls_sw_context_rx *)ctx->priv_ctx_rx;
1357 		}
1358 	}
1359 
1360 	if (tx) {
1361 		crypto_init_wait(&sw_ctx_tx->async_wait);
1362 		crypto_info = &ctx->crypto_send;
1363 		cctx = &ctx->tx;
1364 		aead = &sw_ctx_tx->aead_send;
1365 	} else {
1366 		crypto_init_wait(&sw_ctx_rx->async_wait);
1367 		crypto_info = &ctx->crypto_recv;
1368 		cctx = &ctx->rx;
1369 		aead = &sw_ctx_rx->aead_recv;
1370 	}
1371 
1372 	switch (crypto_info->cipher_type) {
1373 	case TLS_CIPHER_AES_GCM_128: {
1374 		nonce_size = TLS_CIPHER_AES_GCM_128_IV_SIZE;
1375 		tag_size = TLS_CIPHER_AES_GCM_128_TAG_SIZE;
1376 		iv_size = TLS_CIPHER_AES_GCM_128_IV_SIZE;
1377 		iv = ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->iv;
1378 		rec_seq_size = TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE;
1379 		rec_seq =
1380 		 ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->rec_seq;
1381 		gcm_128_info =
1382 			(struct tls12_crypto_info_aes_gcm_128 *)crypto_info;
1383 		break;
1384 	}
1385 	default:
1386 		rc = -EINVAL;
1387 		goto free_priv;
1388 	}
1389 
1390 	/* Sanity-check the IV size for stack allocations. */
1391 	if (iv_size > MAX_IV_SIZE || nonce_size > MAX_IV_SIZE) {
1392 		rc = -EINVAL;
1393 		goto free_priv;
1394 	}
1395 
1396 	cctx->prepend_size = TLS_HEADER_SIZE + nonce_size;
1397 	cctx->tag_size = tag_size;
1398 	cctx->overhead_size = cctx->prepend_size + cctx->tag_size;
1399 	cctx->iv_size = iv_size;
1400 	cctx->iv = kmalloc(iv_size + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
1401 			   GFP_KERNEL);
1402 	if (!cctx->iv) {
1403 		rc = -ENOMEM;
1404 		goto free_priv;
1405 	}
1406 	memcpy(cctx->iv, gcm_128_info->salt, TLS_CIPHER_AES_GCM_128_SALT_SIZE);
1407 	memcpy(cctx->iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, iv, iv_size);
1408 	cctx->rec_seq_size = rec_seq_size;
1409 	cctx->rec_seq = kmemdup(rec_seq, rec_seq_size, GFP_KERNEL);
1410 	if (!cctx->rec_seq) {
1411 		rc = -ENOMEM;
1412 		goto free_iv;
1413 	}
1414 
1415 	if (sw_ctx_tx) {
1416 		sg_init_table(sw_ctx_tx->sg_encrypted_data,
1417 			      ARRAY_SIZE(sw_ctx_tx->sg_encrypted_data));
1418 		sg_init_table(sw_ctx_tx->sg_plaintext_data,
1419 			      ARRAY_SIZE(sw_ctx_tx->sg_plaintext_data));
1420 
1421 		sg_init_table(sw_ctx_tx->sg_aead_in, 2);
1422 		sg_set_buf(&sw_ctx_tx->sg_aead_in[0], sw_ctx_tx->aad_space,
1423 			   sizeof(sw_ctx_tx->aad_space));
1424 		sg_unmark_end(&sw_ctx_tx->sg_aead_in[1]);
1425 		sg_chain(sw_ctx_tx->sg_aead_in, 2,
1426 			 sw_ctx_tx->sg_plaintext_data);
1427 		sg_init_table(sw_ctx_tx->sg_aead_out, 2);
1428 		sg_set_buf(&sw_ctx_tx->sg_aead_out[0], sw_ctx_tx->aad_space,
1429 			   sizeof(sw_ctx_tx->aad_space));
1430 		sg_unmark_end(&sw_ctx_tx->sg_aead_out[1]);
1431 		sg_chain(sw_ctx_tx->sg_aead_out, 2,
1432 			 sw_ctx_tx->sg_encrypted_data);
1433 	}
1434 
1435 	if (!*aead) {
1436 		*aead = crypto_alloc_aead("gcm(aes)", 0, 0);
1437 		if (IS_ERR(*aead)) {
1438 			rc = PTR_ERR(*aead);
1439 			*aead = NULL;
1440 			goto free_rec_seq;
1441 		}
1442 	}
1443 
1444 	ctx->push_pending_record = tls_sw_push_pending_record;
1445 
1446 	memcpy(keyval, gcm_128_info->key, TLS_CIPHER_AES_GCM_128_KEY_SIZE);
1447 
1448 	rc = crypto_aead_setkey(*aead, keyval,
1449 				TLS_CIPHER_AES_GCM_128_KEY_SIZE);
1450 	if (rc)
1451 		goto free_aead;
1452 
1453 	rc = crypto_aead_setauthsize(*aead, cctx->tag_size);
1454 	if (rc)
1455 		goto free_aead;
1456 
1457 	if (sw_ctx_rx) {
1458 		(*aead)->reqsize = sizeof(struct decrypt_req_ctx);
1459 
1460 		/* Set up strparser */
1461 		memset(&cb, 0, sizeof(cb));
1462 		cb.rcv_msg = tls_queue;
1463 		cb.parse_msg = tls_read_size;
1464 
1465 		strp_init(&sw_ctx_rx->strp, sk, &cb);
1466 
1467 		write_lock_bh(&sk->sk_callback_lock);
1468 		sw_ctx_rx->saved_data_ready = sk->sk_data_ready;
1469 		sk->sk_data_ready = tls_data_ready;
1470 		write_unlock_bh(&sk->sk_callback_lock);
1471 
1472 		sw_ctx_rx->sk_poll = sk->sk_socket->ops->poll;
1473 
1474 		strp_check_rcv(&sw_ctx_rx->strp);
1475 	}
1476 
1477 	goto out;
1478 
1479 free_aead:
1480 	crypto_free_aead(*aead);
1481 	*aead = NULL;
1482 free_rec_seq:
1483 	kfree(cctx->rec_seq);
1484 	cctx->rec_seq = NULL;
1485 free_iv:
1486 	kfree(cctx->iv);
1487 	cctx->iv = NULL;
1488 free_priv:
1489 	if (tx) {
1490 		kfree(ctx->priv_ctx_tx);
1491 		ctx->priv_ctx_tx = NULL;
1492 	} else {
1493 		kfree(ctx->priv_ctx_rx);
1494 		ctx->priv_ctx_rx = NULL;
1495 	}
1496 out:
1497 	return rc;
1498 }
1499