xref: /linux/net/tls/tls_sw.c (revision cd11d11286cba88aab5b1da1c83ee36e5b5cefb7)
1 /*
2  * Copyright (c) 2016-2017, Mellanox Technologies. All rights reserved.
3  * Copyright (c) 2016-2017, Dave Watson <davejwatson@fb.com>. All rights reserved.
4  * Copyright (c) 2016-2017, Lance Chao <lancerchao@fb.com>. All rights reserved.
5  * Copyright (c) 2016, Fridolin Pokorny <fridolin.pokorny@gmail.com>. All rights reserved.
6  * Copyright (c) 2016, Nikos Mavrogiannopoulos <nmav@gnutls.org>. All rights reserved.
7  *
8  * This software is available to you under a choice of one of two
9  * licenses.  You may choose to be licensed under the terms of the GNU
10  * General Public License (GPL) Version 2, available from the file
11  * COPYING in the main directory of this source tree, or the
12  * OpenIB.org BSD license below:
13  *
14  *     Redistribution and use in source and binary forms, with or
15  *     without modification, are permitted provided that the following
16  *     conditions are met:
17  *
18  *      - Redistributions of source code must retain the above
19  *        copyright notice, this list of conditions and the following
20  *        disclaimer.
21  *
22  *      - Redistributions in binary form must reproduce the above
23  *        copyright notice, this list of conditions and the following
24  *        disclaimer in the documentation and/or other materials
25  *        provided with the distribution.
26  *
27  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
28  * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
29  * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
30  * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
31  * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
32  * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
33  * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
34  * SOFTWARE.
35  */
36 
37 #include <linux/sched/signal.h>
38 #include <linux/module.h>
39 #include <crypto/aead.h>
40 
41 #include <net/strparser.h>
42 #include <net/tls.h>
43 
44 #define MAX_IV_SIZE	TLS_CIPHER_AES_GCM_128_IV_SIZE
45 
/* Count the scatterlist elements needed to map @len bytes of @skb
 * starting at @offset, descending into frag lists.  The walk mirrors
 * skb_to_sgvec() so the count matches what that function will emit.
 *
 * Returns the element count, or -EMSGSIZE once frag-list nesting
 * reaches a depth of 24 (same recursion limit as __skb_to_sgvec()).
 */
static int __skb_nsg(struct sk_buff *skb, int offset, int len,
                     unsigned int recursion_level)
{
        int start = skb_headlen(skb);
        int i, chunk = start - offset;
        struct sk_buff *frag_iter;
        int elt = 0;

        if (unlikely(recursion_level >= 24))
                return -EMSGSIZE;

        /* Linear (head) area: one element if any of it overlaps the range. */
        if (chunk > 0) {
                if (chunk > len)
                        chunk = len;
                elt++;
                len -= chunk;
                if (len == 0)
                        return elt;
                offset += chunk;
        }

        /* Paged frags: each overlapping frag costs one element. */
        for (i = 0; i < skb_shinfo(skb)->nr_frags; i++) {
                int end;

                WARN_ON(start > offset + len);

                end = start + skb_frag_size(&skb_shinfo(skb)->frags[i]);
                chunk = end - offset;
                if (chunk > 0) {
                        if (chunk > len)
                                chunk = len;
                        elt++;
                        len -= chunk;
                        if (len == 0)
                                return elt;
                        offset += chunk;
                }
                start = end;
        }

        /* Frag list: recurse into each child skb that overlaps the range. */
        if (unlikely(skb_has_frag_list(skb))) {
                skb_walk_frags(skb, frag_iter) {
                        int end, ret;

                        WARN_ON(start > offset + len);

                        end = start + frag_iter->len;
                        chunk = end - offset;
                        if (chunk > 0) {
                                if (chunk > len)
                                        chunk = len;
                                ret = __skb_nsg(frag_iter, offset - start, chunk,
                                                recursion_level + 1);
                                if (unlikely(ret < 0))
                                        return ret;
                                elt += ret;
                                len -= chunk;
                                if (len == 0)
                                        return elt;
                                offset += chunk;
                        }
                        start = end;
                }
        }
        /* Caller asked for more bytes than the skb actually holds. */
        BUG_ON(len);
        return elt;
}
113 
/* Number of scatterlist entries required to map @len bytes of @skb at
 * @offset, or -EMSGSIZE when the frag-list nesting is too deep.
 * Thin entry point for the recursive walker above.
 */
static int skb_nsg(struct sk_buff *skb, int offset, int len)
{
        int nsg = __skb_nsg(skb, offset, len, /* recursion_level */ 0);

        return nsg;
}
121 
/* Completion callback for an asynchronous RX decrypt request.
 *
 * Runs in crypto-completion context.  Propagates any crypto error to
 * the socket, releases the skb and the zero-copy destination pages,
 * frees the request memory, and wakes up a waiter that has armed
 * ctx->async_notify once the last pending decrypt finishes.
 */
static void tls_decrypt_done(struct crypto_async_request *req, int err)
{
	struct aead_request *aead_req = (struct aead_request *)req;
	struct scatterlist *sgout = aead_req->dst;
	struct tls_sw_context_rx *ctx;
	struct tls_context *tls_ctx;
	struct scatterlist *sg;
	struct sk_buff *skb;
	unsigned int pages;
	int pending;

	/* req->data carries the skb; skb->sk was set by tls_do_decryption()
	 * purely to let us reach the socket from here.
	 */
	skb = (struct sk_buff *)req->data;
	tls_ctx = tls_get_ctx(skb->sk);
	ctx = tls_sw_ctx_rx(tls_ctx);
	pending = atomic_dec_return(&ctx->decrypt_pending);

	/* Propagate if there was an err */
	if (err) {
		ctx->async_wait.err = err;
		/* NOTE(review): err from crypto is negative while other
		 * tls_err_abort() callers in this file pass positive errno
		 * constants (e.g. EBADMSG) — sign convention looks
		 * inconsistent; confirm against tls_err_abort().
		 */
		tls_err_abort(skb->sk, err);
	}

	/* After using skb->sk to propagate sk through crypto async callback
	 * we need to NULL it again.
	 */
	skb->sk = NULL;

	/* Release the skb, pages and memory allocated for crypto req */
	kfree_skb(skb);

	/* Skip the first S/G entry as it points to AAD */
	for_each_sg(sg_next(sgout), sg, UINT_MAX, pages) {
		if (!sg)
			break;
		put_page(sg_page(sg));
	}

	kfree(aead_req);

	if (!pending && READ_ONCE(ctx->async_notify))
		complete(&ctx->async_wait.completion);
}
164 
/* Submit one record for AEAD decryption.
 *
 * @sgin/@sgout: source/destination scatterlists (first entry is AAD).
 * @data_len:    payload length excluding the authentication tag.
 * @async:       when true and the cipher goes asynchronous, return
 *               -EINPROGRESS immediately and let tls_decrypt_done()
 *               finish the bookkeeping; otherwise wait for completion.
 *
 * Returns 0 on success, -EINPROGRESS for an in-flight async request,
 * or a negative crypto error.
 */
static int tls_do_decryption(struct sock *sk,
			     struct sk_buff *skb,
			     struct scatterlist *sgin,
			     struct scatterlist *sgout,
			     char *iv_recv,
			     size_t data_len,
			     struct aead_request *aead_req,
			     bool async)
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
	int ret;

	aead_request_set_tfm(aead_req, ctx->aead_recv);
	aead_request_set_ad(aead_req, TLS_AAD_SPACE_SIZE);
	/* Ciphertext length includes the trailing auth tag. */
	aead_request_set_crypt(aead_req, sgin, sgout,
			       data_len + tls_ctx->rx.tag_size,
			       (u8 *)iv_recv);

	if (async) {
		/* Using skb->sk to push sk through to crypto async callback
		 * handler. This allows propagating errors up to the socket
		 * if needed. It _must_ be cleared in the async handler
		 * before kfree_skb is called. We _know_ skb->sk is NULL
		 * because it is a clone from strparser.
		 */
		skb->sk = sk;
		aead_request_set_callback(aead_req,
					  CRYPTO_TFM_REQ_MAY_BACKLOG,
					  tls_decrypt_done, skb);
		atomic_inc(&ctx->decrypt_pending);
	} else {
		aead_request_set_callback(aead_req,
					  CRYPTO_TFM_REQ_MAY_BACKLOG,
					  crypto_req_done, &ctx->async_wait);
	}

	ret = crypto_aead_decrypt(aead_req);
	if (ret == -EINPROGRESS) {
		if (async)
			return ret;

		/* Sync caller: block until the request completes. */
		ret = crypto_wait_req(ret, &ctx->async_wait);
	}

	/* Async request completed synchronously (or failed before being
	 * queued): undo the pending count taken above.
	 */
	if (async)
		atomic_dec(&ctx->decrypt_pending);

	return ret;
}
215 
216 static void trim_sg(struct sock *sk, struct scatterlist *sg,
217 		    int *sg_num_elem, unsigned int *sg_size, int target_size)
218 {
219 	int i = *sg_num_elem - 1;
220 	int trim = *sg_size - target_size;
221 
222 	if (trim <= 0) {
223 		WARN_ON(trim < 0);
224 		return;
225 	}
226 
227 	*sg_size = target_size;
228 	while (trim >= sg[i].length) {
229 		trim -= sg[i].length;
230 		sk_mem_uncharge(sk, sg[i].length);
231 		put_page(sg_page(&sg[i]));
232 		i--;
233 
234 		if (i < 0)
235 			goto out;
236 	}
237 
238 	sg[i].length -= trim;
239 	sk_mem_uncharge(sk, trim);
240 
241 out:
242 	*sg_num_elem = i + 1;
243 }
244 
245 static void trim_both_sgl(struct sock *sk, int target_size)
246 {
247 	struct tls_context *tls_ctx = tls_get_ctx(sk);
248 	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
249 	struct tls_rec *rec = ctx->open_rec;
250 
251 	trim_sg(sk, rec->sg_plaintext_data,
252 		&rec->sg_plaintext_num_elem,
253 		&rec->sg_plaintext_size,
254 		target_size);
255 
256 	if (target_size > 0)
257 		target_size += tls_ctx->tx.overhead_size;
258 
259 	trim_sg(sk, rec->sg_encrypted_data,
260 		&rec->sg_encrypted_num_elem,
261 		&rec->sg_encrypted_size,
262 		target_size);
263 }
264 
265 static int alloc_encrypted_sg(struct sock *sk, int len)
266 {
267 	struct tls_context *tls_ctx = tls_get_ctx(sk);
268 	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
269 	struct tls_rec *rec = ctx->open_rec;
270 	int rc = 0;
271 
272 	rc = sk_alloc_sg(sk, len,
273 			 rec->sg_encrypted_data, 0,
274 			 &rec->sg_encrypted_num_elem,
275 			 &rec->sg_encrypted_size, 0);
276 
277 	if (rc == -ENOSPC)
278 		rec->sg_encrypted_num_elem = ARRAY_SIZE(rec->sg_encrypted_data);
279 
280 	return rc;
281 }
282 
283 static int alloc_plaintext_sg(struct sock *sk, int len)
284 {
285 	struct tls_context *tls_ctx = tls_get_ctx(sk);
286 	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
287 	struct tls_rec *rec = ctx->open_rec;
288 	int rc = 0;
289 
290 	rc = sk_alloc_sg(sk, len, rec->sg_plaintext_data, 0,
291 			 &rec->sg_plaintext_num_elem, &rec->sg_plaintext_size,
292 			 tls_ctx->pending_open_record_frags);
293 
294 	if (rc == -ENOSPC)
295 		rec->sg_plaintext_num_elem = ARRAY_SIZE(rec->sg_plaintext_data);
296 
297 	return rc;
298 }
299 
300 static void free_sg(struct sock *sk, struct scatterlist *sg,
301 		    int *sg_num_elem, unsigned int *sg_size)
302 {
303 	int i, n = *sg_num_elem;
304 
305 	for (i = 0; i < n; ++i) {
306 		sk_mem_uncharge(sk, sg[i].length);
307 		put_page(sg_page(&sg[i]));
308 	}
309 	*sg_num_elem = 0;
310 	*sg_size = 0;
311 }
312 
313 static void tls_free_open_rec(struct sock *sk)
314 {
315 	struct tls_context *tls_ctx = tls_get_ctx(sk);
316 	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
317 	struct tls_rec *rec = ctx->open_rec;
318 
319 	/* Return if there is no open record */
320 	if (!rec)
321 		return;
322 
323 	free_sg(sk, rec->sg_encrypted_data,
324 		&rec->sg_encrypted_num_elem,
325 		&rec->sg_encrypted_size);
326 
327 	free_sg(sk, rec->sg_plaintext_data,
328 		&rec->sg_plaintext_num_elem,
329 		&rec->sg_plaintext_size);
330 
331 	kfree(rec);
332 }
333 
334 int tls_tx_records(struct sock *sk, int flags)
335 {
336 	struct tls_context *tls_ctx = tls_get_ctx(sk);
337 	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
338 	struct tls_rec *rec, *tmp;
339 	int tx_flags, rc = 0;
340 
341 	if (tls_is_partially_sent_record(tls_ctx)) {
342 		rec = list_first_entry(&ctx->tx_list,
343 				       struct tls_rec, list);
344 
345 		if (flags == -1)
346 			tx_flags = rec->tx_flags;
347 		else
348 			tx_flags = flags;
349 
350 		rc = tls_push_partial_record(sk, tls_ctx, tx_flags);
351 		if (rc)
352 			goto tx_err;
353 
354 		/* Full record has been transmitted.
355 		 * Remove the head of tx_list
356 		 */
357 		list_del(&rec->list);
358 		free_sg(sk, rec->sg_plaintext_data,
359 			&rec->sg_plaintext_num_elem, &rec->sg_plaintext_size);
360 
361 		kfree(rec);
362 	}
363 
364 	/* Tx all ready records */
365 	list_for_each_entry_safe(rec, tmp, &ctx->tx_list, list) {
366 		if (READ_ONCE(rec->tx_ready)) {
367 			if (flags == -1)
368 				tx_flags = rec->tx_flags;
369 			else
370 				tx_flags = flags;
371 
372 			rc = tls_push_sg(sk, tls_ctx,
373 					 &rec->sg_encrypted_data[0],
374 					 0, tx_flags);
375 			if (rc)
376 				goto tx_err;
377 
378 			list_del(&rec->list);
379 			free_sg(sk, rec->sg_plaintext_data,
380 				&rec->sg_plaintext_num_elem,
381 				&rec->sg_plaintext_size);
382 
383 			kfree(rec);
384 		} else {
385 			break;
386 		}
387 	}
388 
389 tx_err:
390 	if (rc < 0 && rc != -EAGAIN)
391 		tls_err_abort(sk, EBADMSG);
392 
393 	return rc;
394 }
395 
/* Completion callback for an asynchronous TX encrypt request.
 *
 * Runs in crypto-completion context.  Restores the record's prepend
 * space, handles errors, marks the record transmit-ready, wakes any
 * waiter once the last pending encryption finishes, and schedules the
 * TX worker when the completed record sits at the head of tx_list.
 */
static void tls_encrypt_done(struct crypto_async_request *req, int err)
{
	struct aead_request *aead_req = (struct aead_request *)req;
	struct sock *sk = req->data;
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
	struct tls_rec *rec;
	bool ready = false;
	int pending;

	/* aead_req is embedded in the record (see struct tls_rec). */
	rec = container_of(aead_req, struct tls_rec, aead_req);

	/* Undo the prepend-size skip applied by tls_do_encryption() so the
	 * first sg entry again covers the TLS header.
	 */
	rec->sg_encrypted_data[0].offset -= tls_ctx->tx.prepend_size;
	rec->sg_encrypted_data[0].length += tls_ctx->tx.prepend_size;


	/* Free the record if error is previously set on socket */
	if (err || sk->sk_err) {
		free_sg(sk, rec->sg_encrypted_data,
			&rec->sg_encrypted_num_elem, &rec->sg_encrypted_size);

		kfree(rec);
		rec = NULL;

		/* If err is already set on socket, return the same code */
		if (sk->sk_err) {
			ctx->async_wait.err = sk->sk_err;
		} else {
			ctx->async_wait.err = err;
			tls_err_abort(sk, err);
		}
	}

	if (rec) {
		struct tls_rec *first_rec;

		/* Mark the record as ready for transmission */
		smp_store_mb(rec->tx_ready, true);

		/* If received record is at head of tx_list, schedule tx */
		first_rec = list_first_entry(&ctx->tx_list,
					     struct tls_rec, list);
		if (rec == first_rec)
			ready = true;
	}

	pending = atomic_dec_return(&ctx->encrypt_pending);

	if (!pending && READ_ONCE(ctx->async_notify))
		complete(&ctx->async_wait.completion);

	if (!ready)
		return;

	/* Schedule the transmission */
	if (!test_and_set_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask))
		schedule_delayed_work(&ctx->tx_work.work, 1);
}
454 
455 static int tls_do_encryption(struct sock *sk,
456 			     struct tls_context *tls_ctx,
457 			     struct tls_sw_context_tx *ctx,
458 			     struct aead_request *aead_req,
459 			     size_t data_len)
460 {
461 	struct tls_rec *rec = ctx->open_rec;
462 	int rc;
463 
464 	rec->sg_encrypted_data[0].offset += tls_ctx->tx.prepend_size;
465 	rec->sg_encrypted_data[0].length -= tls_ctx->tx.prepend_size;
466 
467 	aead_request_set_tfm(aead_req, ctx->aead_send);
468 	aead_request_set_ad(aead_req, TLS_AAD_SPACE_SIZE);
469 	aead_request_set_crypt(aead_req, rec->sg_aead_in,
470 			       rec->sg_aead_out,
471 			       data_len, tls_ctx->tx.iv);
472 
473 	aead_request_set_callback(aead_req, CRYPTO_TFM_REQ_MAY_BACKLOG,
474 				  tls_encrypt_done, sk);
475 
476 	/* Add the record in tx_list */
477 	list_add_tail((struct list_head *)&rec->list, &ctx->tx_list);
478 	atomic_inc(&ctx->encrypt_pending);
479 
480 	rc = crypto_aead_encrypt(aead_req);
481 	if (!rc || rc != -EINPROGRESS) {
482 		atomic_dec(&ctx->encrypt_pending);
483 		rec->sg_encrypted_data[0].offset -= tls_ctx->tx.prepend_size;
484 		rec->sg_encrypted_data[0].length += tls_ctx->tx.prepend_size;
485 	}
486 
487 	if (!rc) {
488 		WRITE_ONCE(rec->tx_ready, true);
489 	} else if (rc != -EINPROGRESS) {
490 		list_del(&rec->list);
491 		return rc;
492 	}
493 
494 	/* Unhook the record from context if encryption is not failure */
495 	ctx->open_rec = NULL;
496 	tls_advance_record_sn(sk, &tls_ctx->tx);
497 	return rc;
498 }
499 
/* Close the currently open record and hand it to the crypto layer,
 * then transmit whatever records are ready.
 *
 * Builds the AAD and the TLS record header (prepend) in place,
 * terminates both scatterlists, and resets the pending-frags counter
 * before submitting.  Returns 0 on success, -EINPROGRESS when the
 * encryption went asynchronous, or a negative error (the connection is
 * aborted with EBADMSG on a crypto failure).  A call with no open
 * record is a no-op returning 0.
 */
static int tls_push_record(struct sock *sk, int flags,
			   unsigned char record_type)
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
	struct tls_rec *rec = ctx->open_rec;
	struct aead_request *req;
	int rc;

	if (!rec)
		return 0;

	/* Remember flags for async completion / retransmit paths. */
	rec->tx_flags = flags;
	req = &rec->aead_req;

	/* Terminate both lists at their last populated entry. */
	sg_mark_end(rec->sg_plaintext_data + rec->sg_plaintext_num_elem - 1);
	sg_mark_end(rec->sg_encrypted_data + rec->sg_encrypted_num_elem - 1);

	tls_make_aad(rec->aad_space, rec->sg_plaintext_size,
		     tls_ctx->tx.rec_seq, tls_ctx->tx.rec_seq_size,
		     record_type);

	/* Write the TLS record header into the start of the first
	 * encrypted-data page.
	 */
	tls_fill_prepend(tls_ctx,
			 page_address(sg_page(&rec->sg_encrypted_data[0])) +
			 rec->sg_encrypted_data[0].offset,
			 rec->sg_plaintext_size, record_type);

	tls_ctx->pending_open_record_frags = 0;

	rc = tls_do_encryption(sk, tls_ctx, ctx, req, rec->sg_plaintext_size);
	if (rc == -EINPROGRESS)
		return -EINPROGRESS;

	if (rc < 0) {
		tls_err_abort(sk, EBADMSG);
		return rc;
	}

	return tls_tx_records(sk, flags);
}
540 
541 static int tls_sw_push_pending_record(struct sock *sk, int flags)
542 {
543 	return tls_push_record(sk, flags, TLS_RECORD_TYPE_DATA);
544 }
545 
/* Map up to @length bytes of user memory from @from directly into the
 * scatterlist @to (zero copy), pinning the pages with
 * iov_iter_get_pages().
 *
 * @pages_used / @size_used are read as the current fill level and
 * updated on return — including on failure, so the caller can release
 * whatever pages were already pinned (e.g. via trim_sg()).  On failure
 * the iterator itself is rewound to its original position.  @charge
 * selects whether the mapped bytes are charged to the socket.
 *
 * Returns 0 on success or -EFAULT when the scatterlist is full or the
 * iterator yields no pages.
 */
static int zerocopy_from_iter(struct sock *sk, struct iov_iter *from,
			      int length, int *pages_used,
			      unsigned int *size_used,
			      struct scatterlist *to, int to_max_pages,
			      bool charge)
{
	struct page *pages[MAX_SKB_FRAGS];

	size_t offset;
	ssize_t copied, use;
	int i = 0;
	unsigned int size = *size_used;
	int num_elem = *pages_used;
	int rc = 0;
	int maxpages;

	while (length > 0) {
		i = 0;
		maxpages = to_max_pages - num_elem;
		if (maxpages == 0) {
			rc = -EFAULT;
			goto out;
		}
		/* Pin up to maxpages pages covering the next chunk;
		 * offset is the byte offset into the first page.
		 */
		copied = iov_iter_get_pages(from, pages,
					    length,
					    maxpages, &offset);
		if (copied <= 0) {
			rc = -EFAULT;
			goto out;
		}

		iov_iter_advance(from, copied);

		length -= copied;
		size += copied;
		/* One sg entry per pinned page (first may be partial). */
		while (copied) {
			use = min_t(int, copied, PAGE_SIZE - offset);

			sg_set_page(&to[num_elem],
				    pages[i], use, offset);
			sg_unmark_end(&to[num_elem]);
			if (charge)
				sk_mem_charge(sk, use);

			offset = 0;
			copied -= use;

			++i;
			++num_elem;
		}
	}

	/* Mark the end in the last sg entry if newly added */
	if (num_elem > *pages_used)
		sg_mark_end(&to[num_elem - 1]);
out:
	/* On error, put the iterator back where it started; the pages
	 * already placed in @to stay accounted in the returned counts.
	 */
	if (rc)
		iov_iter_revert(from, size - *size_used);
	*size_used = size;
	*pages_used = num_elem;

	return rc;
}
609 
610 static int memcopy_from_iter(struct sock *sk, struct iov_iter *from,
611 			     int bytes)
612 {
613 	struct tls_context *tls_ctx = tls_get_ctx(sk);
614 	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
615 	struct tls_rec *rec = ctx->open_rec;
616 	struct scatterlist *sg = rec->sg_plaintext_data;
617 	int copy, i, rc = 0;
618 
619 	for (i = tls_ctx->pending_open_record_frags;
620 	     i < rec->sg_plaintext_num_elem; ++i) {
621 		copy = sg[i].length;
622 		if (copy_from_iter(
623 				page_address(sg_page(&sg[i])) + sg[i].offset,
624 				copy, from) != copy) {
625 			rc = -EFAULT;
626 			goto out;
627 		}
628 		bytes -= copy;
629 
630 		++tls_ctx->pending_open_record_frags;
631 
632 		if (!bytes)
633 			break;
634 	}
635 
636 out:
637 	return rc;
638 }
639 
640 struct tls_rec *get_rec(struct sock *sk)
641 {
642 	struct tls_context *tls_ctx = tls_get_ctx(sk);
643 	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
644 	struct tls_rec *rec;
645 	int mem_size;
646 
647 	/* Return if we already have an open record */
648 	if (ctx->open_rec)
649 		return ctx->open_rec;
650 
651 	mem_size = sizeof(struct tls_rec) + crypto_aead_reqsize(ctx->aead_send);
652 
653 	rec = kzalloc(mem_size, sk->sk_allocation);
654 	if (!rec)
655 		return NULL;
656 
657 	sg_init_table(&rec->sg_plaintext_data[0],
658 		      ARRAY_SIZE(rec->sg_plaintext_data));
659 	sg_init_table(&rec->sg_encrypted_data[0],
660 		      ARRAY_SIZE(rec->sg_encrypted_data));
661 
662 	sg_init_table(rec->sg_aead_in, 2);
663 	sg_set_buf(&rec->sg_aead_in[0], rec->aad_space,
664 		   sizeof(rec->aad_space));
665 	sg_unmark_end(&rec->sg_aead_in[1]);
666 	sg_chain(rec->sg_aead_in, 2, rec->sg_plaintext_data);
667 
668 	sg_init_table(rec->sg_aead_out, 2);
669 	sg_set_buf(&rec->sg_aead_out[0], rec->aad_space,
670 		   sizeof(rec->aad_space));
671 	sg_unmark_end(&rec->sg_aead_out[1]);
672 	sg_chain(rec->sg_aead_out, 2, rec->sg_encrypted_data);
673 
674 	ctx->open_rec = rec;
675 
676 	return rec;
677 }
678 
/* sendmsg() entry point for a TLS software-crypto socket.
 *
 * Splits the message into records of at most TLS_MAX_PAYLOAD_SIZE,
 * filling the open record either by mapping user pages directly into
 * the plaintext list (zero copy, only for non-kvec iterators on a
 * synchronous cipher when the record will be pushed immediately) or by
 * copying into pre-allocated pages.  Asynchronous encryptions are
 * tracked via num_async/num_zc and drained before return when zero
 * copy was used (the pinned user pages must not be touched afterwards).
 *
 * Returns the number of bytes accepted, or a negative error when
 * nothing was consumed.
 */
int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
{
	long timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
	struct crypto_tfm *tfm = crypto_aead_tfm(ctx->aead_send);
	bool async_capable = tfm->__crt_alg->cra_flags & CRYPTO_ALG_ASYNC;
	unsigned char record_type = TLS_RECORD_TYPE_DATA;
	bool is_kvec = msg->msg_iter.type & ITER_KVEC;
	bool eor = !(msg->msg_flags & MSG_MORE);
	size_t try_to_copy, copied = 0;
	struct tls_rec *rec;
	int required_size;
	int num_async = 0;
	bool full_record;
	int record_room;
	int num_zc = 0;
	int orig_size;
	int ret = 0;

	if (msg->msg_flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL))
		return -ENOTSUPP;

	lock_sock(sk);

	/* Wait till there is any pending write on socket */
	if (unlikely(sk->sk_write_pending)) {
		ret = wait_on_pending_writer(sk, &timeo);
		if (unlikely(ret))
			goto send_end;
	}

	/* A cmsg may change the record type (e.g. alert/handshake). */
	if (unlikely(msg->msg_controllen)) {
		ret = tls_proccess_cmsg(sk, msg, &record_type);
		if (ret) {
			if (ret == -EINPROGRESS)
				num_async++;
			else if (ret != -EAGAIN)
				goto send_end;
		}
	}

	while (msg_data_left(msg)) {
		if (sk->sk_err) {
			ret = -sk->sk_err;
			goto send_end;
		}

		rec = get_rec(sk);
		if (!rec) {
			ret = -ENOMEM;
			goto send_end;
		}

		/* orig_size: fill level to restore on failure (trim_sgl). */
		orig_size = rec->sg_plaintext_size;
		full_record = false;
		try_to_copy = msg_data_left(msg);
		record_room = TLS_MAX_PAYLOAD_SIZE - rec->sg_plaintext_size;
		if (try_to_copy >= record_room) {
			try_to_copy = record_room;
			full_record = true;
		}

		required_size = rec->sg_plaintext_size + try_to_copy +
				tls_ctx->tx.overhead_size;

		if (!sk_stream_memory_free(sk))
			goto wait_for_sndbuf;

alloc_encrypted:
		ret = alloc_encrypted_sg(sk, required_size);
		if (ret) {
			if (ret != -ENOSPC)
				goto wait_for_memory;

			/* Adjust try_to_copy according to the amount that was
			 * actually allocated. The difference is due
			 * to max sg elements limit
			 */
			try_to_copy -= required_size - rec->sg_encrypted_size;
			full_record = true;
		}

		/* Zero-copy path: map user pages straight into the
		 * plaintext list.  Only safe when the record goes out now
		 * and the cipher completes synchronously.
		 */
		if (!is_kvec && (full_record || eor) && !async_capable) {
			ret = zerocopy_from_iter(sk, &msg->msg_iter,
				try_to_copy, &rec->sg_plaintext_num_elem,
				&rec->sg_plaintext_size,
				rec->sg_plaintext_data,
				ARRAY_SIZE(rec->sg_plaintext_data),
				true);
			if (ret)
				goto fallback_to_reg_send;

			num_zc++;
			copied += try_to_copy;
			ret = tls_push_record(sk, msg->msg_flags, record_type);
			if (ret) {
				if (ret == -EINPROGRESS)
					num_async++;
				else if (ret != -EAGAIN)
					goto send_end;
			}
			continue;

fallback_to_reg_send:
			/* Undo any partial zero-copy mapping and fall back
			 * to the copying path below.
			 */
			trim_sg(sk, rec->sg_plaintext_data,
				&rec->sg_plaintext_num_elem,
				&rec->sg_plaintext_size,
				orig_size);
		}

		required_size = rec->sg_plaintext_size + try_to_copy;
alloc_plaintext:
		ret = alloc_plaintext_sg(sk, required_size);
		if (ret) {
			if (ret != -ENOSPC)
				goto wait_for_memory;

			/* Adjust try_to_copy according to the amount that was
			 * actually allocated. The difference is due
			 * to max sg elements limit
			 */
			try_to_copy -= required_size - rec->sg_plaintext_size;
			full_record = true;

			/* Keep the encrypted list in step with the smaller
			 * plaintext allocation.
			 */
			trim_sg(sk, rec->sg_encrypted_data,
				&rec->sg_encrypted_num_elem,
				&rec->sg_encrypted_size,
				rec->sg_plaintext_size +
				tls_ctx->tx.overhead_size);
		}

		ret = memcopy_from_iter(sk, &msg->msg_iter, try_to_copy);
		if (ret)
			goto trim_sgl;

		copied += try_to_copy;
		if (full_record || eor) {
			ret = tls_push_record(sk, msg->msg_flags, record_type);
			if (ret) {
				if (ret == -EINPROGRESS)
					num_async++;
				else if (ret != -EAGAIN)
					goto send_end;
			}
		}

		continue;

wait_for_sndbuf:
		set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
wait_for_memory:
		ret = sk_stream_wait_memory(sk, &timeo);
		if (ret) {
trim_sgl:
			trim_both_sgl(sk, orig_size);
			goto send_end;
		}

		if (rec->sg_encrypted_size < required_size)
			goto alloc_encrypted;

		goto alloc_plaintext;
	}

	if (!num_async) {
		goto send_end;
	} else if (num_zc) {
		/* Wait for pending encryptions to get completed */
		smp_store_mb(ctx->async_notify, true);

		if (atomic_read(&ctx->encrypt_pending))
			crypto_wait_req(-EINPROGRESS, &ctx->async_wait);
		else
			reinit_completion(&ctx->async_wait.completion);

		WRITE_ONCE(ctx->async_notify, false);

		if (ctx->async_wait.err) {
			ret = ctx->async_wait.err;
			copied = 0;
		}
	}

	/* Transmit if any encryptions have completed */
	if (test_and_clear_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask)) {
		cancel_delayed_work(&ctx->tx_work.work);
		tls_tx_records(sk, msg->msg_flags);
	}

send_end:
	ret = sk_stream_error(sk, msg->msg_flags, ret);

	release_sock(sk);
	return copied ? copied : ret;
}
875 
/* sendpage() entry point for a TLS software-crypto socket.
 *
 * Appends references to @page (no data copy) to the open record's
 * plaintext list, pushing a record whenever it is full, the caller
 * signalled end-of-record, or the scatterlist array fills up.
 *
 * Returns the number of bytes consumed, or a negative error when
 * nothing was consumed.
 */
int tls_sw_sendpage(struct sock *sk, struct page *page,
		    int offset, size_t size, int flags)
{
	long timeo = sock_sndtimeo(sk, flags & MSG_DONTWAIT);
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
	unsigned char record_type = TLS_RECORD_TYPE_DATA;
	size_t orig_size = size;
	struct scatterlist *sg;
	struct tls_rec *rec;
	int num_async = 0;
	bool full_record;
	int record_room;
	int ret = 0;
	bool eor;

	if (flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL |
		      MSG_SENDPAGE_NOTLAST))
		return -ENOTSUPP;

	/* No MSG_EOR from splice, only look at MSG_MORE */
	eor = !(flags & (MSG_MORE | MSG_SENDPAGE_NOTLAST));

	lock_sock(sk);

	sk_clear_bit(SOCKWQ_ASYNC_NOSPACE, sk);

	/* Wait till there is any pending write on socket */
	if (unlikely(sk->sk_write_pending)) {
		ret = wait_on_pending_writer(sk, &timeo);
		if (unlikely(ret))
			goto sendpage_end;
	}

	/* Call the sk_stream functions to manage the sndbuf mem. */
	while (size > 0) {
		size_t copy, required_size;

		if (sk->sk_err) {
			ret = -sk->sk_err;
			goto sendpage_end;
		}

		rec = get_rec(sk);
		if (!rec) {
			ret = -ENOMEM;
			goto sendpage_end;
		}

		full_record = false;
		record_room = TLS_MAX_PAYLOAD_SIZE - rec->sg_plaintext_size;
		copy = size;
		if (copy >= record_room) {
			copy = record_room;
			full_record = true;
		}
		required_size = rec->sg_plaintext_size + copy +
			      tls_ctx->tx.overhead_size;

		if (!sk_stream_memory_free(sk))
			goto wait_for_sndbuf;
alloc_payload:
		ret = alloc_encrypted_sg(sk, required_size);
		if (ret) {
			if (ret != -ENOSPC)
				goto wait_for_memory;

			/* Adjust copy according to the amount that was
			 * actually allocated. The difference is due
			 * to max sg elements limit
			 */
			copy -= required_size - rec->sg_plaintext_size;
			full_record = true;
		}

		/* Reference the caller's page directly from the plaintext
		 * list — no copy is performed.
		 */
		get_page(page);
		sg = rec->sg_plaintext_data + rec->sg_plaintext_num_elem;
		sg_set_page(sg, page, copy, offset);
		sg_unmark_end(sg);

		rec->sg_plaintext_num_elem++;

		sk_mem_charge(sk, copy);
		offset += copy;
		size -= copy;
		rec->sg_plaintext_size += copy;
		tls_ctx->pending_open_record_frags = rec->sg_plaintext_num_elem;

		/* Push when full, end-of-record, or out of sg slots. */
		if (full_record || eor ||
		    rec->sg_plaintext_num_elem ==
		    ARRAY_SIZE(rec->sg_plaintext_data)) {
			ret = tls_push_record(sk, flags, record_type);
			if (ret) {
				if (ret == -EINPROGRESS)
					num_async++;
				else if (ret != -EAGAIN)
					goto sendpage_end;
			}
		}
		continue;
wait_for_sndbuf:
		set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
wait_for_memory:
		ret = sk_stream_wait_memory(sk, &timeo);
		if (ret) {
			trim_both_sgl(sk, rec->sg_plaintext_size);
			goto sendpage_end;
		}

		goto alloc_payload;
	}

	if (num_async) {
		/* Transmit if any encryptions have completed */
		if (test_and_clear_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask)) {
			cancel_delayed_work(&ctx->tx_work.work);
			tls_tx_records(sk, flags);
		}
	}
sendpage_end:
	/* Partial progress wins over the error code. */
	if (orig_size > size)
		ret = orig_size - size;
	else
		ret = sk_stream_error(sk, flags, ret);

	release_sock(sk);
	return ret;
}
1004 
/* Wait (under the socket lock) until the strparser has delivered a
 * record into ctx->recv_pkt.
 *
 * Returns the skb, or NULL with *err set for socket errors, timeouts
 * (-EAGAIN) and signals; NULL with *err untouched on RCV_SHUTDOWN or
 * SOCK_DONE (clean end of stream).
 */
static struct sk_buff *tls_wait_data(struct sock *sk, int flags,
				     long timeo, int *err)
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
	struct sk_buff *skb;
	DEFINE_WAIT_FUNC(wait, woken_wake_function);

	while (!(skb = ctx->recv_pkt)) {
		if (sk->sk_err) {
			*err = sock_error(sk);
			return NULL;
		}

		if (sk->sk_shutdown & RCV_SHUTDOWN)
			return NULL;

		if (sock_flag(sk, SOCK_DONE))
			return NULL;

		if ((flags & MSG_DONTWAIT) || !timeo) {
			*err = -EAGAIN;
			return NULL;
		}

		/* Sleep until recv_pkt changes (skb is NULL here, so the
		 * condition fires on any new packet), then re-check all of
		 * the exit conditions above.
		 */
		add_wait_queue(sk_sleep(sk), &wait);
		sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
		sk_wait_event(sk, &timeo, ctx->recv_pkt != skb, &wait);
		sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
		remove_wait_queue(sk_sleep(sk), &wait);

		/* Handle signals */
		if (signal_pending(current)) {
			*err = sock_intr_errno(timeo);
			return NULL;
		}
	}

	return skb;
}
1045 
1046 /* This function decrypts the input skb into either out_iov or in out_sg
1047  * or in skb buffers itself. The input parameter 'zc' indicates if
1048  * zero-copy mode needs to be tried or not. With zero-copy mode, either
1049  * out_iov or out_sg must be non-NULL. In case both out_iov and out_sg are
1050  * NULL, then the decryption happens inside skb buffers itself, i.e.
1051  * zero-copy gets disabled and 'zc' is updated.
1052  */
1053 
/* Returns 0 on success, -EINPROGRESS when the AEAD request was submitted
 * asynchronously (the completion callback then owns the single 'mem'
 * allocation made below), or a negative error code otherwise.
 */
static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
			    struct iov_iter *out_iov,
			    struct scatterlist *out_sg,
			    int *chunk, bool *zc)
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
	struct strp_msg *rxm = strp_msg(skb);
	int n_sgin, n_sgout, nsg, mem_size, aead_size, err, pages = 0;
	struct aead_request *aead_req;
	struct sk_buff *unused;
	u8 *aad, *iv, *mem = NULL;
	struct scatterlist *sgin = NULL;
	struct scatterlist *sgout = NULL;
	const int data_len = rxm->full_len - tls_ctx->rx.overhead_size;

	/* Count the scatterlist entries needed for ciphertext input and
	 * plaintext output.
	 */
	if (*zc && (out_iov || out_sg)) {
		if (out_iov)
			/* +1 for the AAD entry placed in sgout[0] below */
			n_sgout = iov_iter_npages(out_iov, INT_MAX) + 1;
		else
			n_sgout = sg_nents(out_sg);
		n_sgin = skb_nsg(skb, rxm->offset + tls_ctx->rx.prepend_size,
				 rxm->full_len - tls_ctx->rx.prepend_size);
	} else {
		/* In-place decryption: no separate output list */
		n_sgout = 0;
		*zc = false;
		n_sgin = skb_cow_data(skb, 0, &unused);
	}

	if (n_sgin < 1)
		return -EBADMSG;

	/* Increment to accommodate AAD */
	n_sgin = n_sgin + 1;

	nsg = n_sgin + n_sgout;

	aead_size = sizeof(*aead_req) + crypto_aead_reqsize(ctx->aead_recv);
	mem_size = aead_size + (nsg * sizeof(struct scatterlist));
	mem_size = mem_size + TLS_AAD_SPACE_SIZE;
	mem_size = mem_size + crypto_aead_ivsize(ctx->aead_recv);

	/* Allocate a single block of memory which contains
	 * aead_req || sgin[] || sgout[] || aad || iv.
	 * This order achieves correct alignment for aead_req, sgin, sgout.
	 */
	mem = kmalloc(mem_size, sk->sk_allocation);
	if (!mem)
		return -ENOMEM;

	/* Segment the allocated memory */
	aead_req = (struct aead_request *)mem;
	sgin = (struct scatterlist *)(mem + aead_size);
	sgout = sgin + n_sgin;
	aad = (u8 *)(sgout + n_sgout);
	iv = aad + TLS_AAD_SPACE_SIZE;

	/* Prepare IV: implicit salt from the handshake followed by the
	 * explicit per-record IV taken from the record header.
	 */
	err = skb_copy_bits(skb, rxm->offset + TLS_HEADER_SIZE,
			    iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
			    tls_ctx->rx.iv_size);
	if (err < 0) {
		kfree(mem);
		return err;
	}
	memcpy(iv, tls_ctx->rx.iv, TLS_CIPHER_AES_GCM_128_SALT_SIZE);

	/* Prepare AAD */
	tls_make_aad(aad, rxm->full_len - tls_ctx->rx.overhead_size,
		     tls_ctx->rx.rec_seq, tls_ctx->rx.rec_seq_size,
		     ctx->control);

	/* Prepare sgin: AAD entry first, then the skb's ciphertext */
	sg_init_table(sgin, n_sgin);
	sg_set_buf(&sgin[0], aad, TLS_AAD_SPACE_SIZE);
	err = skb_to_sgvec(skb, &sgin[1],
			   rxm->offset + tls_ctx->rx.prepend_size,
			   rxm->full_len - tls_ctx->rx.prepend_size);
	if (err < 0) {
		kfree(mem);
		return err;
	}

	if (n_sgout) {
		if (out_iov) {
			sg_init_table(sgout, n_sgout);
			sg_set_buf(&sgout[0], aad, TLS_AAD_SPACE_SIZE);

			*chunk = 0;
			err = zerocopy_from_iter(sk, out_iov, data_len, &pages,
						 chunk, &sgout[1],
						 (n_sgout - 1), false);
			if (err < 0)
				goto fallback_to_reg_recv;
		} else if (out_sg) {
			memcpy(sgout, out_sg, n_sgout * sizeof(*sgout));
		} else {
			goto fallback_to_reg_recv;
		}
	} else {
		/* Decrypt in place: output list aliases the input list */
fallback_to_reg_recv:
		sgout = sgin;
		pages = 0;
		*chunk = 0;
		*zc = false;
	}

	/* Prepare and submit AEAD request */
	err = tls_do_decryption(sk, skb, sgin, sgout, iv,
				data_len, aead_req, *zc);
	if (err == -EINPROGRESS)
		return err;

	/* Release the pages in case iov was mapped to pages; sgout[0] is
	 * the AAD buffer, so mapped pages start at index 1.
	 */
	for (; pages > 0; pages--)
		put_page(sg_page(&sgout[pages]));

	kfree(mem);
	return err;
}
1174 
1175 static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb,
1176 			      struct iov_iter *dest, int *chunk, bool *zc)
1177 {
1178 	struct tls_context *tls_ctx = tls_get_ctx(sk);
1179 	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
1180 	struct strp_msg *rxm = strp_msg(skb);
1181 	int err = 0;
1182 
1183 #ifdef CONFIG_TLS_DEVICE
1184 	err = tls_device_decrypted(sk, skb);
1185 	if (err < 0)
1186 		return err;
1187 #endif
1188 	if (!ctx->decrypted) {
1189 		err = decrypt_internal(sk, skb, dest, NULL, chunk, zc);
1190 		if (err < 0) {
1191 			if (err == -EINPROGRESS)
1192 				tls_advance_record_sn(sk, &tls_ctx->rx);
1193 
1194 			return err;
1195 		}
1196 	} else {
1197 		*zc = false;
1198 	}
1199 
1200 	rxm->offset += tls_ctx->rx.prepend_size;
1201 	rxm->full_len -= tls_ctx->rx.overhead_size;
1202 	tls_advance_record_sn(sk, &tls_ctx->rx);
1203 	ctx->decrypted = true;
1204 	ctx->saved_data_ready(sk);
1205 
1206 	return err;
1207 }
1208 
1209 int decrypt_skb(struct sock *sk, struct sk_buff *skb,
1210 		struct scatterlist *sgout)
1211 {
1212 	bool zc = true;
1213 	int chunk;
1214 
1215 	return decrypt_internal(sk, skb, NULL, sgout, &chunk, &zc);
1216 }
1217 
1218 static bool tls_sw_advance_skb(struct sock *sk, struct sk_buff *skb,
1219 			       unsigned int len)
1220 {
1221 	struct tls_context *tls_ctx = tls_get_ctx(sk);
1222 	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
1223 
1224 	if (skb) {
1225 		struct strp_msg *rxm = strp_msg(skb);
1226 
1227 		if (len < rxm->full_len) {
1228 			rxm->offset += len;
1229 			rxm->full_len -= len;
1230 			return false;
1231 		}
1232 		kfree_skb(skb);
1233 	}
1234 
1235 	/* Finished with message */
1236 	ctx->recv_pkt = NULL;
1237 	__strp_unpause(&ctx->strp);
1238 
1239 	return true;
1240 }
1241 
/* Receive decrypted TLS record payload into @msg.
 *
 * Returns the number of bytes copied, or a negative error when nothing was
 * copied. Decryption may be submitted asynchronously; all outstanding
 * requests are waited for at recv_end before returning.
 */
int tls_sw_recvmsg(struct sock *sk,
		   struct msghdr *msg,
		   size_t len,
		   int nonblock,
		   int flags,
		   int *addr_len)
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
	unsigned char control;
	struct strp_msg *rxm;
	struct sk_buff *skb;
	ssize_t copied = 0;
	bool cmsg = false;
	int target, err = 0;
	long timeo;
	bool is_kvec = msg->msg_iter.type & ITER_KVEC;
	int num_async = 0;

	flags |= nonblock;

	if (unlikely(flags & MSG_ERRQUEUE))
		return sock_recv_errqueue(sk, msg, len, SOL_IP, IP_RECVERR);

	lock_sock(sk);

	target = sock_rcvlowat(sk, flags & MSG_WAITALL, len);
	timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
	do {
		bool zc = false;
		bool async = false;
		int chunk = 0;

		skb = tls_wait_data(sk, flags, timeo, &err);
		if (!skb)
			goto recv_end;

		rxm = strp_msg(skb);

		/* Report the record type to userspace via cmsg, once per
		 * call.
		 */
		if (!cmsg) {
			int cerr;

			cerr = put_cmsg(msg, SOL_TLS, TLS_GET_RECORD_TYPE,
					sizeof(ctx->control), &ctx->control);
			cmsg = true;
			control = ctx->control;
			if (ctx->control != TLS_RECORD_TYPE_DATA) {
				/* Non-data records must be delivered with a
				 * usable cmsg; otherwise fail the read.
				 */
				if (cerr || msg->msg_flags & MSG_CTRUNC) {
					err = -EIO;
					goto recv_end;
				}
			}
		} else if (control != ctx->control) {
			/* Never mix record types within one recvmsg() */
			goto recv_end;
		}

		if (!ctx->decrypted) {
			int to_copy = rxm->full_len - tls_ctx->rx.overhead_size;

			/* Zero-copy decrypt into the user buffer is only
			 * attempted for iovec destinations that can hold the
			 * whole record and won't be re-read via MSG_PEEK.
			 */
			if (!is_kvec && to_copy <= len &&
			    likely(!(flags & MSG_PEEK)))
				zc = true;

			err = decrypt_skb_update(sk, skb, &msg->msg_iter,
						 &chunk, &zc);
			if (err < 0 && err != -EINPROGRESS) {
				tls_err_abort(sk, EBADMSG);
				goto recv_end;
			}

			/* Async in flight: account the chunk now and wait
			 * for completion at recv_end.
			 */
			if (err == -EINPROGRESS) {
				async = true;
				num_async++;
				goto pick_next_record;
			}

			ctx->decrypted = true;
		}

		if (!zc) {
			/* Plaintext still sits in the skb; copy it out */
			chunk = min_t(unsigned int, rxm->full_len, len);

			err = skb_copy_datagram_msg(skb, rxm->offset, msg,
						    chunk);
			if (err < 0)
				goto recv_end;
		}

pick_next_record:
		copied += chunk;
		len -= chunk;
		if (likely(!(flags & MSG_PEEK))) {
			u8 control = ctx->control;

			/* For async, drop current skb reference */
			if (async)
				skb = NULL;

			if (tls_sw_advance_skb(sk, skb, chunk)) {
				/* Return full control message to
				 * userspace before trying to parse
				 * another message type
				 */
				msg->msg_flags |= MSG_EOR;
				if (control != TLS_RECORD_TYPE_DATA)
					goto recv_end;
			} else {
				break;
			}
		} else {
			/* MSG_PEEK right now cannot look beyond current skb
			 * from strparser, meaning we cannot advance skb here
			 * and thus unpause strparser since we'd lose original
			 * one.
			 */
			break;
		}

		/* If we have a new message from strparser, continue now. */
		if (copied >= target && !ctx->recv_pkt)
			break;
	} while (len);

recv_end:
	if (num_async) {
		/* Wait for all previously submitted records to be decrypted */
		smp_store_mb(ctx->async_notify, true);
		if (atomic_read(&ctx->decrypt_pending)) {
			err = crypto_wait_req(-EINPROGRESS, &ctx->async_wait);
			if (err) {
				/* one of async decrypt failed */
				tls_err_abort(sk, err);
				copied = 0;
			}
		} else {
			reinit_completion(&ctx->async_wait.completion);
		}
		WRITE_ONCE(ctx->async_notify, false);
	}

	release_sock(sk);
	return copied ? : err;
}
1385 
1386 ssize_t tls_sw_splice_read(struct socket *sock,  loff_t *ppos,
1387 			   struct pipe_inode_info *pipe,
1388 			   size_t len, unsigned int flags)
1389 {
1390 	struct tls_context *tls_ctx = tls_get_ctx(sock->sk);
1391 	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
1392 	struct strp_msg *rxm = NULL;
1393 	struct sock *sk = sock->sk;
1394 	struct sk_buff *skb;
1395 	ssize_t copied = 0;
1396 	int err = 0;
1397 	long timeo;
1398 	int chunk;
1399 	bool zc = false;
1400 
1401 	lock_sock(sk);
1402 
1403 	timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
1404 
1405 	skb = tls_wait_data(sk, flags, timeo, &err);
1406 	if (!skb)
1407 		goto splice_read_end;
1408 
1409 	/* splice does not support reading control messages */
1410 	if (ctx->control != TLS_RECORD_TYPE_DATA) {
1411 		err = -ENOTSUPP;
1412 		goto splice_read_end;
1413 	}
1414 
1415 	if (!ctx->decrypted) {
1416 		err = decrypt_skb_update(sk, skb, NULL, &chunk, &zc);
1417 
1418 		if (err < 0) {
1419 			tls_err_abort(sk, EBADMSG);
1420 			goto splice_read_end;
1421 		}
1422 		ctx->decrypted = true;
1423 	}
1424 	rxm = strp_msg(skb);
1425 
1426 	chunk = min_t(unsigned int, rxm->full_len, len);
1427 	copied = skb_splice_bits(skb, sk, rxm->offset, pipe, chunk, flags);
1428 	if (copied < 0)
1429 		goto splice_read_end;
1430 
1431 	if (likely(!(flags & MSG_PEEK)))
1432 		tls_sw_advance_skb(sk, skb, copied);
1433 
1434 splice_read_end:
1435 	release_sock(sk);
1436 	return copied ? : err;
1437 }
1438 
1439 unsigned int tls_sw_poll(struct file *file, struct socket *sock,
1440 			 struct poll_table_struct *wait)
1441 {
1442 	unsigned int ret;
1443 	struct sock *sk = sock->sk;
1444 	struct tls_context *tls_ctx = tls_get_ctx(sk);
1445 	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
1446 
1447 	/* Grab POLLOUT and POLLHUP from the underlying socket */
1448 	ret = ctx->sk_poll(file, sock, wait);
1449 
1450 	/* Clear POLLIN bits, and set based on recv_pkt */
1451 	ret &= ~(POLLIN | POLLRDNORM);
1452 	if (ctx->recv_pkt)
1453 		ret |= POLLIN | POLLRDNORM;
1454 
1455 	return ret;
1456 }
1457 
1458 static int tls_read_size(struct strparser *strp, struct sk_buff *skb)
1459 {
1460 	struct tls_context *tls_ctx = tls_get_ctx(strp->sk);
1461 	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
1462 	char header[TLS_HEADER_SIZE + MAX_IV_SIZE];
1463 	struct strp_msg *rxm = strp_msg(skb);
1464 	size_t cipher_overhead;
1465 	size_t data_len = 0;
1466 	int ret;
1467 
1468 	/* Verify that we have a full TLS header, or wait for more data */
1469 	if (rxm->offset + tls_ctx->rx.prepend_size > skb->len)
1470 		return 0;
1471 
1472 	/* Sanity-check size of on-stack buffer. */
1473 	if (WARN_ON(tls_ctx->rx.prepend_size > sizeof(header))) {
1474 		ret = -EINVAL;
1475 		goto read_failure;
1476 	}
1477 
1478 	/* Linearize header to local buffer */
1479 	ret = skb_copy_bits(skb, rxm->offset, header, tls_ctx->rx.prepend_size);
1480 
1481 	if (ret < 0)
1482 		goto read_failure;
1483 
1484 	ctx->control = header[0];
1485 
1486 	data_len = ((header[4] & 0xFF) | (header[3] << 8));
1487 
1488 	cipher_overhead = tls_ctx->rx.tag_size + tls_ctx->rx.iv_size;
1489 
1490 	if (data_len > TLS_MAX_PAYLOAD_SIZE + cipher_overhead) {
1491 		ret = -EMSGSIZE;
1492 		goto read_failure;
1493 	}
1494 	if (data_len < cipher_overhead) {
1495 		ret = -EBADMSG;
1496 		goto read_failure;
1497 	}
1498 
1499 	if (header[1] != TLS_VERSION_MINOR(tls_ctx->crypto_recv.info.version) ||
1500 	    header[2] != TLS_VERSION_MAJOR(tls_ctx->crypto_recv.info.version)) {
1501 		ret = -EINVAL;
1502 		goto read_failure;
1503 	}
1504 
1505 #ifdef CONFIG_TLS_DEVICE
1506 	handle_device_resync(strp->sk, TCP_SKB_CB(skb)->seq + rxm->offset,
1507 			     *(u64*)tls_ctx->rx.rec_seq);
1508 #endif
1509 	return data_len + TLS_HEADER_SIZE;
1510 
1511 read_failure:
1512 	tls_err_abort(strp->sk, ret);
1513 
1514 	return ret;
1515 }
1516 
1517 static void tls_queue(struct strparser *strp, struct sk_buff *skb)
1518 {
1519 	struct tls_context *tls_ctx = tls_get_ctx(strp->sk);
1520 	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
1521 
1522 	ctx->decrypted = false;
1523 
1524 	ctx->recv_pkt = skb;
1525 	strp_pause(strp);
1526 
1527 	ctx->saved_data_ready(strp->sk);
1528 }
1529 
1530 static void tls_data_ready(struct sock *sk)
1531 {
1532 	struct tls_context *tls_ctx = tls_get_ctx(sk);
1533 	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
1534 
1535 	strp_data_ready(&ctx->strp);
1536 }
1537 
/* Tear down the TX half of the SW TLS context: drain pending async
 * encryption, transmit what can still be sent, then free all queued
 * records and the context itself.
 */
void tls_sw_free_resources_tx(struct sock *sk)
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
	struct tls_rec *rec, *tmp;

	/* Wait for any pending async encryptions to complete */
	smp_store_mb(ctx->async_notify, true);
	if (atomic_read(&ctx->encrypt_pending))
		crypto_wait_req(-EINPROGRESS, &ctx->async_wait);

	cancel_delayed_work_sync(&ctx->tx_work.work);

	/* Tx whatever records we can transmit and abandon the rest */
	tls_tx_records(sk, -1);

	/* Free up un-sent records in tx_list. First, free
	 * the partially sent record if any at head of tx_list.
	 */
	if (tls_ctx->partially_sent_record) {
		struct scatterlist *sg = tls_ctx->partially_sent_record;

		/* Drop page references and uncharge each sg entry of the
		 * partially sent data.
		 */
		while (1) {
			put_page(sg_page(sg));
			sk_mem_uncharge(sk, sg->length);

			if (sg_is_last(sg))
				break;
			sg++;
		}

		tls_ctx->partially_sent_record = NULL;

		/* The partially sent record sits at the head of tx_list */
		rec = list_first_entry(&ctx->tx_list,
				       struct tls_rec, list);

		free_sg(sk, rec->sg_plaintext_data,
			&rec->sg_plaintext_num_elem,
			&rec->sg_plaintext_size);

		list_del(&rec->list);
		kfree(rec);
	}

	/* Free the remaining, never-sent records */
	list_for_each_entry_safe(rec, tmp, &ctx->tx_list, list) {
		free_sg(sk, rec->sg_encrypted_data,
			&rec->sg_encrypted_num_elem,
			&rec->sg_encrypted_size);

		free_sg(sk, rec->sg_plaintext_data,
			&rec->sg_plaintext_num_elem,
			&rec->sg_plaintext_size);

		list_del(&rec->list);
		kfree(rec);
	}

	crypto_free_aead(ctx->aead_send);
	tls_free_open_rec(sk);

	kfree(ctx);
}
1600 
/* Release RX-side crypto and strparser state; called with the socket
 * lock held.  Does not free the context struct itself (see
 * tls_sw_free_resources_rx()).
 */
void tls_sw_release_resources_rx(struct sock *sk)
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);

	/* aead_recv is only set once RX offload was fully initialized */
	if (ctx->aead_recv) {
		kfree_skb(ctx->recv_pkt);
		ctx->recv_pkt = NULL;
		crypto_free_aead(ctx->aead_recv);
		strp_stop(&ctx->strp);
		/* Restore the original data_ready callback */
		write_lock_bh(&sk->sk_callback_lock);
		sk->sk_data_ready = ctx->saved_data_ready;
		write_unlock_bh(&sk->sk_callback_lock);
		/* NOTE(review): the socket lock is dropped around
		 * strp_done(), presumably because strp_done() waits for
		 * parser work that may itself need the lock — confirm
		 * against strparser documentation.
		 */
		release_sock(sk);
		strp_done(&ctx->strp);
		lock_sock(sk);
	}
}
1619 
/* Tear down RX state and free the RX software context. */
void tls_sw_free_resources_rx(struct sock *sk)
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	struct tls_sw_context_rx *rx_ctx = tls_sw_ctx_rx(tls_ctx);

	tls_sw_release_resources_rx(sk);

	kfree(rx_ctx);
}
1629 
/* The work handler to transmit the encrypted records in tx_list */
1631 static void tx_work_handler(struct work_struct *work)
1632 {
1633 	struct delayed_work *delayed_work = to_delayed_work(work);
1634 	struct tx_work *tx_work = container_of(delayed_work,
1635 					       struct tx_work, work);
1636 	struct sock *sk = tx_work->sk;
1637 	struct tls_context *tls_ctx = tls_get_ctx(sk);
1638 	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
1639 
1640 	if (!test_and_clear_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask))
1641 		return;
1642 
1643 	lock_sock(sk);
1644 	tls_tx_records(sk, -1);
1645 	release_sock(sk);
1646 }
1647 
/* Set up software TLS crypto state for one direction of @sk.
 *
 * @ctx: the socket's TLS context
 * @tx:  non-zero to configure the transmit side, zero for receive
 *
 * Allocates the per-direction software context (if not already present),
 * copies key material from the crypto_info supplied by userspace, sets up
 * the AEAD transform, and — for RX — installs the strparser and replaces
 * sk_data_ready. Returns 0 on success or a negative errno; on failure all
 * partially allocated state is unwound via the goto ladder at the end.
 */
int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
{
	struct tls_crypto_info *crypto_info;
	struct tls12_crypto_info_aes_gcm_128 *gcm_128_info;
	struct tls_sw_context_tx *sw_ctx_tx = NULL;
	struct tls_sw_context_rx *sw_ctx_rx = NULL;
	struct cipher_context *cctx;
	struct crypto_aead **aead;
	struct strp_callbacks cb;
	u16 nonce_size, tag_size, iv_size, rec_seq_size;
	char *iv, *rec_seq;
	int rc = 0;

	if (!ctx) {
		rc = -EINVAL;
		goto out;
	}

	/* Allocate (or reuse) the per-direction software context */
	if (tx) {
		if (!ctx->priv_ctx_tx) {
			sw_ctx_tx = kzalloc(sizeof(*sw_ctx_tx), GFP_KERNEL);
			if (!sw_ctx_tx) {
				rc = -ENOMEM;
				goto out;
			}
			ctx->priv_ctx_tx = sw_ctx_tx;
		} else {
			sw_ctx_tx =
				(struct tls_sw_context_tx *)ctx->priv_ctx_tx;
		}
	} else {
		if (!ctx->priv_ctx_rx) {
			sw_ctx_rx = kzalloc(sizeof(*sw_ctx_rx), GFP_KERNEL);
			if (!sw_ctx_rx) {
				rc = -ENOMEM;
				goto out;
			}
			ctx->priv_ctx_rx = sw_ctx_rx;
		} else {
			sw_ctx_rx =
				(struct tls_sw_context_rx *)ctx->priv_ctx_rx;
		}
	}

	/* Pick the direction-specific state to initialize */
	if (tx) {
		crypto_init_wait(&sw_ctx_tx->async_wait);
		crypto_info = &ctx->crypto_send.info;
		cctx = &ctx->tx;
		aead = &sw_ctx_tx->aead_send;
		INIT_LIST_HEAD(&sw_ctx_tx->tx_list);
		INIT_DELAYED_WORK(&sw_ctx_tx->tx_work.work, tx_work_handler);
		sw_ctx_tx->tx_work.sk = sk;
	} else {
		crypto_init_wait(&sw_ctx_rx->async_wait);
		crypto_info = &ctx->crypto_recv.info;
		cctx = &ctx->rx;
		aead = &sw_ctx_rx->aead_recv;
	}

	/* Extract cipher parameters; only AES-GCM-128 is supported here */
	switch (crypto_info->cipher_type) {
	case TLS_CIPHER_AES_GCM_128: {
		nonce_size = TLS_CIPHER_AES_GCM_128_IV_SIZE;
		tag_size = TLS_CIPHER_AES_GCM_128_TAG_SIZE;
		iv_size = TLS_CIPHER_AES_GCM_128_IV_SIZE;
		iv = ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->iv;
		rec_seq_size = TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE;
		rec_seq =
		 ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->rec_seq;
		gcm_128_info =
			(struct tls12_crypto_info_aes_gcm_128 *)crypto_info;
		break;
	}
	default:
		rc = -EINVAL;
		goto free_priv;
	}

	/* Sanity-check the IV size for stack allocations. */
	if (iv_size > MAX_IV_SIZE || nonce_size > MAX_IV_SIZE) {
		rc = -EINVAL;
		goto free_priv;
	}

	cctx->prepend_size = TLS_HEADER_SIZE + nonce_size;
	cctx->tag_size = tag_size;
	cctx->overhead_size = cctx->prepend_size + cctx->tag_size;
	cctx->iv_size = iv_size;
	/* IV buffer holds the implicit salt followed by the explicit IV */
	cctx->iv = kmalloc(iv_size + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
			   GFP_KERNEL);
	if (!cctx->iv) {
		rc = -ENOMEM;
		goto free_priv;
	}
	memcpy(cctx->iv, gcm_128_info->salt, TLS_CIPHER_AES_GCM_128_SALT_SIZE);
	memcpy(cctx->iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, iv, iv_size);
	cctx->rec_seq_size = rec_seq_size;
	cctx->rec_seq = kmemdup(rec_seq, rec_seq_size, GFP_KERNEL);
	if (!cctx->rec_seq) {
		rc = -ENOMEM;
		goto free_iv;
	}

	if (!*aead) {
		*aead = crypto_alloc_aead("gcm(aes)", 0, 0);
		if (IS_ERR(*aead)) {
			rc = PTR_ERR(*aead);
			*aead = NULL;
			goto free_rec_seq;
		}
	}

	ctx->push_pending_record = tls_sw_push_pending_record;

	rc = crypto_aead_setkey(*aead, gcm_128_info->key,
				TLS_CIPHER_AES_GCM_128_KEY_SIZE);
	if (rc)
		goto free_aead;

	rc = crypto_aead_setauthsize(*aead, cctx->tag_size);
	if (rc)
		goto free_aead;

	if (sw_ctx_rx) {
		/* Set up strparser */
		memset(&cb, 0, sizeof(cb));
		cb.rcv_msg = tls_queue;
		cb.parse_msg = tls_read_size;

		strp_init(&sw_ctx_rx->strp, sk, &cb);

		/* Divert data_ready into the strparser */
		write_lock_bh(&sk->sk_callback_lock);
		sw_ctx_rx->saved_data_ready = sk->sk_data_ready;
		sk->sk_data_ready = tls_data_ready;
		write_unlock_bh(&sk->sk_callback_lock);

		sw_ctx_rx->sk_poll = sk->sk_socket->ops->poll;

		strp_check_rcv(&sw_ctx_rx->strp);
	}

	goto out;

free_aead:
	crypto_free_aead(*aead);
	*aead = NULL;
free_rec_seq:
	kfree(cctx->rec_seq);
	cctx->rec_seq = NULL;
free_iv:
	kfree(cctx->iv);
	cctx->iv = NULL;
free_priv:
	if (tx) {
		kfree(ctx->priv_ctx_tx);
		ctx->priv_ctx_tx = NULL;
	} else {
		kfree(ctx->priv_ctx_rx);
		ctx->priv_ctx_rx = NULL;
	}
out:
	return rc;
}
1810