xref: /linux/crypto/algif_skcipher.c (revision a8fe58cec351c25e09c393bf46117c0c47b5a17c)
1 /*
2  * algif_skcipher: User-space interface for skcipher algorithms
3  *
4  * This file provides the user-space API for symmetric key ciphers.
5  *
6  * Copyright (c) 2010 Herbert Xu <herbert@gondor.apana.org.au>
7  *
8  * This program is free software; you can redistribute it and/or modify it
9  * under the terms of the GNU General Public License as published by the Free
10  * Software Foundation; either version 2 of the License, or (at your option)
11  * any later version.
12  *
13  */
14 
15 #include <crypto/scatterwalk.h>
16 #include <crypto/skcipher.h>
17 #include <crypto/if_alg.h>
18 #include <linux/init.h>
19 #include <linux/list.h>
20 #include <linux/kernel.h>
21 #include <linux/mm.h>
22 #include <linux/module.h>
23 #include <linux/net.h>
24 #include <net/sock.h>
25 
26 struct skcipher_sg_list {
27 	struct list_head list;
28 
29 	int cur;
30 
31 	struct scatterlist sg[0];
32 };
33 
34 struct skcipher_tfm {
35 	struct crypto_skcipher *skcipher;
36 	bool has_key;
37 };
38 
39 struct skcipher_ctx {
40 	struct list_head tsgl;
41 	struct af_alg_sgl rsgl;
42 
43 	void *iv;
44 
45 	struct af_alg_completion completion;
46 
47 	atomic_t inflight;
48 	size_t used;
49 
50 	unsigned int len;
51 	bool more;
52 	bool merge;
53 	bool enc;
54 
55 	struct skcipher_request req;
56 };
57 
58 struct skcipher_async_rsgl {
59 	struct af_alg_sgl sgl;
60 	struct list_head list;
61 };
62 
63 struct skcipher_async_req {
64 	struct kiocb *iocb;
65 	struct skcipher_async_rsgl first_sgl;
66 	struct list_head list;
67 	struct scatterlist *tsg;
68 	char iv[];
69 };
70 
71 #define GET_SREQ(areq, ctx) (struct skcipher_async_req *)((char *)areq + \
72 	crypto_skcipher_reqsize(crypto_skcipher_reqtfm(&ctx->req)))
73 
74 #define GET_REQ_SIZE(ctx) \
75 	crypto_skcipher_reqsize(crypto_skcipher_reqtfm(&ctx->req))
76 
77 #define GET_IV_SIZE(ctx) \
78 	crypto_skcipher_ivsize(crypto_skcipher_reqtfm(&ctx->req))
79 
80 #define MAX_SGL_ENTS ((4096 - sizeof(struct skcipher_sg_list)) / \
81 		      sizeof(struct scatterlist) - 1)
82 
83 static void skcipher_free_async_sgls(struct skcipher_async_req *sreq)
84 {
85 	struct skcipher_async_rsgl *rsgl, *tmp;
86 	struct scatterlist *sgl;
87 	struct scatterlist *sg;
88 	int i, n;
89 
90 	list_for_each_entry_safe(rsgl, tmp, &sreq->list, list) {
91 		af_alg_free_sg(&rsgl->sgl);
92 		if (rsgl != &sreq->first_sgl)
93 			kfree(rsgl);
94 	}
95 	sgl = sreq->tsg;
96 	n = sg_nents(sgl);
97 	for_each_sg(sgl, sg, n, i)
98 		put_page(sg_page(sg));
99 
100 	kfree(sreq->tsg);
101 }
102 
103 static void skcipher_async_cb(struct crypto_async_request *req, int err)
104 {
105 	struct sock *sk = req->data;
106 	struct alg_sock *ask = alg_sk(sk);
107 	struct skcipher_ctx *ctx = ask->private;
108 	struct skcipher_async_req *sreq = GET_SREQ(req, ctx);
109 	struct kiocb *iocb = sreq->iocb;
110 
111 	atomic_dec(&ctx->inflight);
112 	skcipher_free_async_sgls(sreq);
113 	kfree(req);
114 	iocb->ki_complete(iocb, err, err);
115 }
116 
117 static inline int skcipher_sndbuf(struct sock *sk)
118 {
119 	struct alg_sock *ask = alg_sk(sk);
120 	struct skcipher_ctx *ctx = ask->private;
121 
122 	return max_t(int, max_t(int, sk->sk_sndbuf & PAGE_MASK, PAGE_SIZE) -
123 			  ctx->used, 0);
124 }
125 
126 static inline bool skcipher_writable(struct sock *sk)
127 {
128 	return PAGE_SIZE <= skcipher_sndbuf(sk);
129 }
130 
131 static int skcipher_alloc_sgl(struct sock *sk)
132 {
133 	struct alg_sock *ask = alg_sk(sk);
134 	struct skcipher_ctx *ctx = ask->private;
135 	struct skcipher_sg_list *sgl;
136 	struct scatterlist *sg = NULL;
137 
138 	sgl = list_entry(ctx->tsgl.prev, struct skcipher_sg_list, list);
139 	if (!list_empty(&ctx->tsgl))
140 		sg = sgl->sg;
141 
142 	if (!sg || sgl->cur >= MAX_SGL_ENTS) {
143 		sgl = sock_kmalloc(sk, sizeof(*sgl) +
144 				       sizeof(sgl->sg[0]) * (MAX_SGL_ENTS + 1),
145 				   GFP_KERNEL);
146 		if (!sgl)
147 			return -ENOMEM;
148 
149 		sg_init_table(sgl->sg, MAX_SGL_ENTS + 1);
150 		sgl->cur = 0;
151 
152 		if (sg)
153 			sg_chain(sg, MAX_SGL_ENTS + 1, sgl->sg);
154 
155 		list_add_tail(&sgl->list, &ctx->tsgl);
156 	}
157 
158 	return 0;
159 }
160 
161 static void skcipher_pull_sgl(struct sock *sk, size_t used, int put)
162 {
163 	struct alg_sock *ask = alg_sk(sk);
164 	struct skcipher_ctx *ctx = ask->private;
165 	struct skcipher_sg_list *sgl;
166 	struct scatterlist *sg;
167 	int i;
168 
169 	while (!list_empty(&ctx->tsgl)) {
170 		sgl = list_first_entry(&ctx->tsgl, struct skcipher_sg_list,
171 				       list);
172 		sg = sgl->sg;
173 
174 		for (i = 0; i < sgl->cur; i++) {
175 			size_t plen = min_t(size_t, used, sg[i].length);
176 
177 			if (!sg_page(sg + i))
178 				continue;
179 
180 			sg[i].length -= plen;
181 			sg[i].offset += plen;
182 
183 			used -= plen;
184 			ctx->used -= plen;
185 
186 			if (sg[i].length)
187 				return;
188 			if (put)
189 				put_page(sg_page(sg + i));
190 			sg_assign_page(sg + i, NULL);
191 		}
192 
193 		list_del(&sgl->list);
194 		sock_kfree_s(sk, sgl,
195 			     sizeof(*sgl) + sizeof(sgl->sg[0]) *
196 					    (MAX_SGL_ENTS + 1));
197 	}
198 
199 	if (!ctx->used)
200 		ctx->merge = 0;
201 }
202 
203 static void skcipher_free_sgl(struct sock *sk)
204 {
205 	struct alg_sock *ask = alg_sk(sk);
206 	struct skcipher_ctx *ctx = ask->private;
207 
208 	skcipher_pull_sgl(sk, ctx->used, 1);
209 }
210 
211 static int skcipher_wait_for_wmem(struct sock *sk, unsigned flags)
212 {
213 	long timeout;
214 	DEFINE_WAIT(wait);
215 	int err = -ERESTARTSYS;
216 
217 	if (flags & MSG_DONTWAIT)
218 		return -EAGAIN;
219 
220 	sk_set_bit(SOCKWQ_ASYNC_NOSPACE, sk);
221 
222 	for (;;) {
223 		if (signal_pending(current))
224 			break;
225 		prepare_to_wait(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE);
226 		timeout = MAX_SCHEDULE_TIMEOUT;
227 		if (sk_wait_event(sk, &timeout, skcipher_writable(sk))) {
228 			err = 0;
229 			break;
230 		}
231 	}
232 	finish_wait(sk_sleep(sk), &wait);
233 
234 	return err;
235 }
236 
237 static void skcipher_wmem_wakeup(struct sock *sk)
238 {
239 	struct socket_wq *wq;
240 
241 	if (!skcipher_writable(sk))
242 		return;
243 
244 	rcu_read_lock();
245 	wq = rcu_dereference(sk->sk_wq);
246 	if (skwq_has_sleeper(wq))
247 		wake_up_interruptible_sync_poll(&wq->wait, POLLIN |
248 							   POLLRDNORM |
249 							   POLLRDBAND);
250 	sk_wake_async(sk, SOCK_WAKE_WAITD, POLL_IN);
251 	rcu_read_unlock();
252 }
253 
254 static int skcipher_wait_for_data(struct sock *sk, unsigned flags)
255 {
256 	struct alg_sock *ask = alg_sk(sk);
257 	struct skcipher_ctx *ctx = ask->private;
258 	long timeout;
259 	DEFINE_WAIT(wait);
260 	int err = -ERESTARTSYS;
261 
262 	if (flags & MSG_DONTWAIT) {
263 		return -EAGAIN;
264 	}
265 
266 	sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
267 
268 	for (;;) {
269 		if (signal_pending(current))
270 			break;
271 		prepare_to_wait(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE);
272 		timeout = MAX_SCHEDULE_TIMEOUT;
273 		if (sk_wait_event(sk, &timeout, ctx->used)) {
274 			err = 0;
275 			break;
276 		}
277 	}
278 	finish_wait(sk_sleep(sk), &wait);
279 
280 	sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
281 
282 	return err;
283 }
284 
285 static void skcipher_data_wakeup(struct sock *sk)
286 {
287 	struct alg_sock *ask = alg_sk(sk);
288 	struct skcipher_ctx *ctx = ask->private;
289 	struct socket_wq *wq;
290 
291 	if (!ctx->used)
292 		return;
293 
294 	rcu_read_lock();
295 	wq = rcu_dereference(sk->sk_wq);
296 	if (skwq_has_sleeper(wq))
297 		wake_up_interruptible_sync_poll(&wq->wait, POLLOUT |
298 							   POLLRDNORM |
299 							   POLLRDBAND);
300 	sk_wake_async(sk, SOCK_WAKE_SPACE, POLL_OUT);
301 	rcu_read_unlock();
302 }
303 
304 static int skcipher_sendmsg(struct socket *sock, struct msghdr *msg,
305 			    size_t size)
306 {
307 	struct sock *sk = sock->sk;
308 	struct alg_sock *ask = alg_sk(sk);
309 	struct skcipher_ctx *ctx = ask->private;
310 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(&ctx->req);
311 	unsigned ivsize = crypto_skcipher_ivsize(tfm);
312 	struct skcipher_sg_list *sgl;
313 	struct af_alg_control con = {};
314 	long copied = 0;
315 	bool enc = 0;
316 	bool init = 0;
317 	int err;
318 	int i;
319 
320 	if (msg->msg_controllen) {
321 		err = af_alg_cmsg_send(msg, &con);
322 		if (err)
323 			return err;
324 
325 		init = 1;
326 		switch (con.op) {
327 		case ALG_OP_ENCRYPT:
328 			enc = 1;
329 			break;
330 		case ALG_OP_DECRYPT:
331 			enc = 0;
332 			break;
333 		default:
334 			return -EINVAL;
335 		}
336 
337 		if (con.iv && con.iv->ivlen != ivsize)
338 			return -EINVAL;
339 	}
340 
341 	err = -EINVAL;
342 
343 	lock_sock(sk);
344 	if (!ctx->more && ctx->used)
345 		goto unlock;
346 
347 	if (init) {
348 		ctx->enc = enc;
349 		if (con.iv)
350 			memcpy(ctx->iv, con.iv->iv, ivsize);
351 	}
352 
353 	while (size) {
354 		struct scatterlist *sg;
355 		unsigned long len = size;
356 		size_t plen;
357 
358 		if (ctx->merge) {
359 			sgl = list_entry(ctx->tsgl.prev,
360 					 struct skcipher_sg_list, list);
361 			sg = sgl->sg + sgl->cur - 1;
362 			len = min_t(unsigned long, len,
363 				    PAGE_SIZE - sg->offset - sg->length);
364 
365 			err = memcpy_from_msg(page_address(sg_page(sg)) +
366 					      sg->offset + sg->length,
367 					      msg, len);
368 			if (err)
369 				goto unlock;
370 
371 			sg->length += len;
372 			ctx->merge = (sg->offset + sg->length) &
373 				     (PAGE_SIZE - 1);
374 
375 			ctx->used += len;
376 			copied += len;
377 			size -= len;
378 			continue;
379 		}
380 
381 		if (!skcipher_writable(sk)) {
382 			err = skcipher_wait_for_wmem(sk, msg->msg_flags);
383 			if (err)
384 				goto unlock;
385 		}
386 
387 		len = min_t(unsigned long, len, skcipher_sndbuf(sk));
388 
389 		err = skcipher_alloc_sgl(sk);
390 		if (err)
391 			goto unlock;
392 
393 		sgl = list_entry(ctx->tsgl.prev, struct skcipher_sg_list, list);
394 		sg = sgl->sg;
395 		if (sgl->cur)
396 			sg_unmark_end(sg + sgl->cur - 1);
397 		do {
398 			i = sgl->cur;
399 			plen = min_t(size_t, len, PAGE_SIZE);
400 
401 			sg_assign_page(sg + i, alloc_page(GFP_KERNEL));
402 			err = -ENOMEM;
403 			if (!sg_page(sg + i))
404 				goto unlock;
405 
406 			err = memcpy_from_msg(page_address(sg_page(sg + i)),
407 					      msg, plen);
408 			if (err) {
409 				__free_page(sg_page(sg + i));
410 				sg_assign_page(sg + i, NULL);
411 				goto unlock;
412 			}
413 
414 			sg[i].length = plen;
415 			len -= plen;
416 			ctx->used += plen;
417 			copied += plen;
418 			size -= plen;
419 			sgl->cur++;
420 		} while (len && sgl->cur < MAX_SGL_ENTS);
421 
422 		if (!size)
423 			sg_mark_end(sg + sgl->cur - 1);
424 
425 		ctx->merge = plen & (PAGE_SIZE - 1);
426 	}
427 
428 	err = 0;
429 
430 	ctx->more = msg->msg_flags & MSG_MORE;
431 
432 unlock:
433 	skcipher_data_wakeup(sk);
434 	release_sock(sk);
435 
436 	return copied ?: err;
437 }
438 
439 static ssize_t skcipher_sendpage(struct socket *sock, struct page *page,
440 				 int offset, size_t size, int flags)
441 {
442 	struct sock *sk = sock->sk;
443 	struct alg_sock *ask = alg_sk(sk);
444 	struct skcipher_ctx *ctx = ask->private;
445 	struct skcipher_sg_list *sgl;
446 	int err = -EINVAL;
447 
448 	if (flags & MSG_SENDPAGE_NOTLAST)
449 		flags |= MSG_MORE;
450 
451 	lock_sock(sk);
452 	if (!ctx->more && ctx->used)
453 		goto unlock;
454 
455 	if (!size)
456 		goto done;
457 
458 	if (!skcipher_writable(sk)) {
459 		err = skcipher_wait_for_wmem(sk, flags);
460 		if (err)
461 			goto unlock;
462 	}
463 
464 	err = skcipher_alloc_sgl(sk);
465 	if (err)
466 		goto unlock;
467 
468 	ctx->merge = 0;
469 	sgl = list_entry(ctx->tsgl.prev, struct skcipher_sg_list, list);
470 
471 	if (sgl->cur)
472 		sg_unmark_end(sgl->sg + sgl->cur - 1);
473 
474 	sg_mark_end(sgl->sg + sgl->cur);
475 	get_page(page);
476 	sg_set_page(sgl->sg + sgl->cur, page, size, offset);
477 	sgl->cur++;
478 	ctx->used += size;
479 
480 done:
481 	ctx->more = flags & MSG_MORE;
482 
483 unlock:
484 	skcipher_data_wakeup(sk);
485 	release_sock(sk);
486 
487 	return err ?: size;
488 }
489 
490 static int skcipher_all_sg_nents(struct skcipher_ctx *ctx)
491 {
492 	struct skcipher_sg_list *sgl;
493 	struct scatterlist *sg;
494 	int nents = 0;
495 
496 	list_for_each_entry(sgl, &ctx->tsgl, list) {
497 		sg = sgl->sg;
498 
499 		while (!sg->length)
500 			sg++;
501 
502 		nents += sg_nents(sg);
503 	}
504 	return nents;
505 }
506 
507 static int skcipher_recvmsg_async(struct socket *sock, struct msghdr *msg,
508 				  int flags)
509 {
510 	struct sock *sk = sock->sk;
511 	struct alg_sock *ask = alg_sk(sk);
512 	struct skcipher_ctx *ctx = ask->private;
513 	struct skcipher_sg_list *sgl;
514 	struct scatterlist *sg;
515 	struct skcipher_async_req *sreq;
516 	struct skcipher_request *req;
517 	struct skcipher_async_rsgl *last_rsgl = NULL;
518 	unsigned int txbufs = 0, len = 0, tx_nents = skcipher_all_sg_nents(ctx);
519 	unsigned int reqlen = sizeof(struct skcipher_async_req) +
520 				GET_REQ_SIZE(ctx) + GET_IV_SIZE(ctx);
521 	int err = -ENOMEM;
522 	bool mark = false;
523 
524 	lock_sock(sk);
525 	req = kmalloc(reqlen, GFP_KERNEL);
526 	if (unlikely(!req))
527 		goto unlock;
528 
529 	sreq = GET_SREQ(req, ctx);
530 	sreq->iocb = msg->msg_iocb;
531 	memset(&sreq->first_sgl, '\0', sizeof(struct skcipher_async_rsgl));
532 	INIT_LIST_HEAD(&sreq->list);
533 	sreq->tsg = kcalloc(tx_nents, sizeof(*sg), GFP_KERNEL);
534 	if (unlikely(!sreq->tsg)) {
535 		kfree(req);
536 		goto unlock;
537 	}
538 	sg_init_table(sreq->tsg, tx_nents);
539 	memcpy(sreq->iv, ctx->iv, GET_IV_SIZE(ctx));
540 	skcipher_request_set_tfm(req, crypto_skcipher_reqtfm(&ctx->req));
541 	skcipher_request_set_callback(req, CRYPTO_TFM_REQ_MAY_BACKLOG,
542 				      skcipher_async_cb, sk);
543 
544 	while (iov_iter_count(&msg->msg_iter)) {
545 		struct skcipher_async_rsgl *rsgl;
546 		int used;
547 
548 		if (!ctx->used) {
549 			err = skcipher_wait_for_data(sk, flags);
550 			if (err)
551 				goto free;
552 		}
553 		sgl = list_first_entry(&ctx->tsgl,
554 				       struct skcipher_sg_list, list);
555 		sg = sgl->sg;
556 
557 		while (!sg->length)
558 			sg++;
559 
560 		used = min_t(unsigned long, ctx->used,
561 			     iov_iter_count(&msg->msg_iter));
562 		used = min_t(unsigned long, used, sg->length);
563 
564 		if (txbufs == tx_nents) {
565 			struct scatterlist *tmp;
566 			int x;
567 			/* Ran out of tx slots in async request
568 			 * need to expand */
569 			tmp = kcalloc(tx_nents * 2, sizeof(*tmp),
570 				      GFP_KERNEL);
571 			if (!tmp)
572 				goto free;
573 
574 			sg_init_table(tmp, tx_nents * 2);
575 			for (x = 0; x < tx_nents; x++)
576 				sg_set_page(&tmp[x], sg_page(&sreq->tsg[x]),
577 					    sreq->tsg[x].length,
578 					    sreq->tsg[x].offset);
579 			kfree(sreq->tsg);
580 			sreq->tsg = tmp;
581 			tx_nents *= 2;
582 			mark = true;
583 		}
584 		/* Need to take over the tx sgl from ctx
585 		 * to the asynch req - these sgls will be freed later */
586 		sg_set_page(sreq->tsg + txbufs++, sg_page(sg), sg->length,
587 			    sg->offset);
588 
589 		if (list_empty(&sreq->list)) {
590 			rsgl = &sreq->first_sgl;
591 			list_add_tail(&rsgl->list, &sreq->list);
592 		} else {
593 			rsgl = kmalloc(sizeof(*rsgl), GFP_KERNEL);
594 			if (!rsgl) {
595 				err = -ENOMEM;
596 				goto free;
597 			}
598 			list_add_tail(&rsgl->list, &sreq->list);
599 		}
600 
601 		used = af_alg_make_sg(&rsgl->sgl, &msg->msg_iter, used);
602 		err = used;
603 		if (used < 0)
604 			goto free;
605 		if (last_rsgl)
606 			af_alg_link_sg(&last_rsgl->sgl, &rsgl->sgl);
607 
608 		last_rsgl = rsgl;
609 		len += used;
610 		skcipher_pull_sgl(sk, used, 0);
611 		iov_iter_advance(&msg->msg_iter, used);
612 	}
613 
614 	if (mark)
615 		sg_mark_end(sreq->tsg + txbufs - 1);
616 
617 	skcipher_request_set_crypt(req, sreq->tsg, sreq->first_sgl.sgl.sg,
618 				   len, sreq->iv);
619 	err = ctx->enc ? crypto_skcipher_encrypt(req) :
620 			 crypto_skcipher_decrypt(req);
621 	if (err == -EINPROGRESS) {
622 		atomic_inc(&ctx->inflight);
623 		err = -EIOCBQUEUED;
624 		goto unlock;
625 	}
626 free:
627 	skcipher_free_async_sgls(sreq);
628 	kfree(req);
629 unlock:
630 	skcipher_wmem_wakeup(sk);
631 	release_sock(sk);
632 	return err;
633 }
634 
635 static int skcipher_recvmsg_sync(struct socket *sock, struct msghdr *msg,
636 				 int flags)
637 {
638 	struct sock *sk = sock->sk;
639 	struct alg_sock *ask = alg_sk(sk);
640 	struct skcipher_ctx *ctx = ask->private;
641 	unsigned bs = crypto_skcipher_blocksize(crypto_skcipher_reqtfm(
642 		&ctx->req));
643 	struct skcipher_sg_list *sgl;
644 	struct scatterlist *sg;
645 	int err = -EAGAIN;
646 	int used;
647 	long copied = 0;
648 
649 	lock_sock(sk);
650 	while (msg_data_left(msg)) {
651 		if (!ctx->used) {
652 			err = skcipher_wait_for_data(sk, flags);
653 			if (err)
654 				goto unlock;
655 		}
656 
657 		used = min_t(unsigned long, ctx->used, msg_data_left(msg));
658 
659 		used = af_alg_make_sg(&ctx->rsgl, &msg->msg_iter, used);
660 		err = used;
661 		if (err < 0)
662 			goto unlock;
663 
664 		if (ctx->more || used < ctx->used)
665 			used -= used % bs;
666 
667 		err = -EINVAL;
668 		if (!used)
669 			goto free;
670 
671 		sgl = list_first_entry(&ctx->tsgl,
672 				       struct skcipher_sg_list, list);
673 		sg = sgl->sg;
674 
675 		while (!sg->length)
676 			sg++;
677 
678 		skcipher_request_set_crypt(&ctx->req, sg, ctx->rsgl.sg, used,
679 					   ctx->iv);
680 
681 		err = af_alg_wait_for_completion(
682 				ctx->enc ?
683 					crypto_skcipher_encrypt(&ctx->req) :
684 					crypto_skcipher_decrypt(&ctx->req),
685 				&ctx->completion);
686 
687 free:
688 		af_alg_free_sg(&ctx->rsgl);
689 
690 		if (err)
691 			goto unlock;
692 
693 		copied += used;
694 		skcipher_pull_sgl(sk, used, 1);
695 		iov_iter_advance(&msg->msg_iter, used);
696 	}
697 
698 	err = 0;
699 
700 unlock:
701 	skcipher_wmem_wakeup(sk);
702 	release_sock(sk);
703 
704 	return copied ?: err;
705 }
706 
707 static int skcipher_recvmsg(struct socket *sock, struct msghdr *msg,
708 			    size_t ignored, int flags)
709 {
710 	return (msg->msg_iocb && !is_sync_kiocb(msg->msg_iocb)) ?
711 		skcipher_recvmsg_async(sock, msg, flags) :
712 		skcipher_recvmsg_sync(sock, msg, flags);
713 }
714 
715 static unsigned int skcipher_poll(struct file *file, struct socket *sock,
716 				  poll_table *wait)
717 {
718 	struct sock *sk = sock->sk;
719 	struct alg_sock *ask = alg_sk(sk);
720 	struct skcipher_ctx *ctx = ask->private;
721 	unsigned int mask;
722 
723 	sock_poll_wait(file, sk_sleep(sk), wait);
724 	mask = 0;
725 
726 	if (ctx->used)
727 		mask |= POLLIN | POLLRDNORM;
728 
729 	if (skcipher_writable(sk))
730 		mask |= POLLOUT | POLLWRNORM | POLLWRBAND;
731 
732 	return mask;
733 }
734 
735 static struct proto_ops algif_skcipher_ops = {
736 	.family		=	PF_ALG,
737 
738 	.connect	=	sock_no_connect,
739 	.socketpair	=	sock_no_socketpair,
740 	.getname	=	sock_no_getname,
741 	.ioctl		=	sock_no_ioctl,
742 	.listen		=	sock_no_listen,
743 	.shutdown	=	sock_no_shutdown,
744 	.getsockopt	=	sock_no_getsockopt,
745 	.mmap		=	sock_no_mmap,
746 	.bind		=	sock_no_bind,
747 	.accept		=	sock_no_accept,
748 	.setsockopt	=	sock_no_setsockopt,
749 
750 	.release	=	af_alg_release,
751 	.sendmsg	=	skcipher_sendmsg,
752 	.sendpage	=	skcipher_sendpage,
753 	.recvmsg	=	skcipher_recvmsg,
754 	.poll		=	skcipher_poll,
755 };
756 
757 static int skcipher_check_key(struct socket *sock)
758 {
759 	int err = 0;
760 	struct sock *psk;
761 	struct alg_sock *pask;
762 	struct skcipher_tfm *tfm;
763 	struct sock *sk = sock->sk;
764 	struct alg_sock *ask = alg_sk(sk);
765 
766 	lock_sock(sk);
767 	if (ask->refcnt)
768 		goto unlock_child;
769 
770 	psk = ask->parent;
771 	pask = alg_sk(ask->parent);
772 	tfm = pask->private;
773 
774 	err = -ENOKEY;
775 	lock_sock_nested(psk, SINGLE_DEPTH_NESTING);
776 	if (!tfm->has_key)
777 		goto unlock;
778 
779 	if (!pask->refcnt++)
780 		sock_hold(psk);
781 
782 	ask->refcnt = 1;
783 	sock_put(psk);
784 
785 	err = 0;
786 
787 unlock:
788 	release_sock(psk);
789 unlock_child:
790 	release_sock(sk);
791 
792 	return err;
793 }
794 
795 static int skcipher_sendmsg_nokey(struct socket *sock, struct msghdr *msg,
796 				  size_t size)
797 {
798 	int err;
799 
800 	err = skcipher_check_key(sock);
801 	if (err)
802 		return err;
803 
804 	return skcipher_sendmsg(sock, msg, size);
805 }
806 
807 static ssize_t skcipher_sendpage_nokey(struct socket *sock, struct page *page,
808 				       int offset, size_t size, int flags)
809 {
810 	int err;
811 
812 	err = skcipher_check_key(sock);
813 	if (err)
814 		return err;
815 
816 	return skcipher_sendpage(sock, page, offset, size, flags);
817 }
818 
819 static int skcipher_recvmsg_nokey(struct socket *sock, struct msghdr *msg,
820 				  size_t ignored, int flags)
821 {
822 	int err;
823 
824 	err = skcipher_check_key(sock);
825 	if (err)
826 		return err;
827 
828 	return skcipher_recvmsg(sock, msg, ignored, flags);
829 }
830 
831 static struct proto_ops algif_skcipher_ops_nokey = {
832 	.family		=	PF_ALG,
833 
834 	.connect	=	sock_no_connect,
835 	.socketpair	=	sock_no_socketpair,
836 	.getname	=	sock_no_getname,
837 	.ioctl		=	sock_no_ioctl,
838 	.listen		=	sock_no_listen,
839 	.shutdown	=	sock_no_shutdown,
840 	.getsockopt	=	sock_no_getsockopt,
841 	.mmap		=	sock_no_mmap,
842 	.bind		=	sock_no_bind,
843 	.accept		=	sock_no_accept,
844 	.setsockopt	=	sock_no_setsockopt,
845 
846 	.release	=	af_alg_release,
847 	.sendmsg	=	skcipher_sendmsg_nokey,
848 	.sendpage	=	skcipher_sendpage_nokey,
849 	.recvmsg	=	skcipher_recvmsg_nokey,
850 	.poll		=	skcipher_poll,
851 };
852 
853 static void *skcipher_bind(const char *name, u32 type, u32 mask)
854 {
855 	struct skcipher_tfm *tfm;
856 	struct crypto_skcipher *skcipher;
857 
858 	tfm = kzalloc(sizeof(*tfm), GFP_KERNEL);
859 	if (!tfm)
860 		return ERR_PTR(-ENOMEM);
861 
862 	skcipher = crypto_alloc_skcipher(name, type, mask);
863 	if (IS_ERR(skcipher)) {
864 		kfree(tfm);
865 		return ERR_CAST(skcipher);
866 	}
867 
868 	tfm->skcipher = skcipher;
869 
870 	return tfm;
871 }
872 
873 static void skcipher_release(void *private)
874 {
875 	struct skcipher_tfm *tfm = private;
876 
877 	crypto_free_skcipher(tfm->skcipher);
878 	kfree(tfm);
879 }
880 
881 static int skcipher_setkey(void *private, const u8 *key, unsigned int keylen)
882 {
883 	struct skcipher_tfm *tfm = private;
884 	int err;
885 
886 	err = crypto_skcipher_setkey(tfm->skcipher, key, keylen);
887 	tfm->has_key = !err;
888 
889 	return err;
890 }
891 
892 static void skcipher_wait(struct sock *sk)
893 {
894 	struct alg_sock *ask = alg_sk(sk);
895 	struct skcipher_ctx *ctx = ask->private;
896 	int ctr = 0;
897 
898 	while (atomic_read(&ctx->inflight) && ctr++ < 100)
899 		msleep(100);
900 }
901 
902 static void skcipher_sock_destruct(struct sock *sk)
903 {
904 	struct alg_sock *ask = alg_sk(sk);
905 	struct skcipher_ctx *ctx = ask->private;
906 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(&ctx->req);
907 
908 	if (atomic_read(&ctx->inflight))
909 		skcipher_wait(sk);
910 
911 	skcipher_free_sgl(sk);
912 	sock_kzfree_s(sk, ctx->iv, crypto_skcipher_ivsize(tfm));
913 	sock_kfree_s(sk, ctx, ctx->len);
914 	af_alg_release_parent(sk);
915 }
916 
917 static int skcipher_accept_parent_nokey(void *private, struct sock *sk)
918 {
919 	struct skcipher_ctx *ctx;
920 	struct alg_sock *ask = alg_sk(sk);
921 	struct skcipher_tfm *tfm = private;
922 	struct crypto_skcipher *skcipher = tfm->skcipher;
923 	unsigned int len = sizeof(*ctx) + crypto_skcipher_reqsize(skcipher);
924 
925 	ctx = sock_kmalloc(sk, len, GFP_KERNEL);
926 	if (!ctx)
927 		return -ENOMEM;
928 
929 	ctx->iv = sock_kmalloc(sk, crypto_skcipher_ivsize(skcipher),
930 			       GFP_KERNEL);
931 	if (!ctx->iv) {
932 		sock_kfree_s(sk, ctx, len);
933 		return -ENOMEM;
934 	}
935 
936 	memset(ctx->iv, 0, crypto_skcipher_ivsize(skcipher));
937 
938 	INIT_LIST_HEAD(&ctx->tsgl);
939 	ctx->len = len;
940 	ctx->used = 0;
941 	ctx->more = 0;
942 	ctx->merge = 0;
943 	ctx->enc = 0;
944 	atomic_set(&ctx->inflight, 0);
945 	af_alg_init_completion(&ctx->completion);
946 
947 	ask->private = ctx;
948 
949 	skcipher_request_set_tfm(&ctx->req, skcipher);
950 	skcipher_request_set_callback(&ctx->req, CRYPTO_TFM_REQ_MAY_BACKLOG,
951 				      af_alg_complete, &ctx->completion);
952 
953 	sk->sk_destruct = skcipher_sock_destruct;
954 
955 	return 0;
956 }
957 
958 static int skcipher_accept_parent(void *private, struct sock *sk)
959 {
960 	struct skcipher_tfm *tfm = private;
961 
962 	if (!tfm->has_key && crypto_skcipher_has_setkey(tfm->skcipher))
963 		return -ENOKEY;
964 
965 	return skcipher_accept_parent_nokey(private, sk);
966 }
967 
968 static const struct af_alg_type algif_type_skcipher = {
969 	.bind		=	skcipher_bind,
970 	.release	=	skcipher_release,
971 	.setkey		=	skcipher_setkey,
972 	.accept		=	skcipher_accept_parent,
973 	.accept_nokey	=	skcipher_accept_parent_nokey,
974 	.ops		=	&algif_skcipher_ops,
975 	.ops_nokey	=	&algif_skcipher_ops_nokey,
976 	.name		=	"skcipher",
977 	.owner		=	THIS_MODULE
978 };
979 
980 static int __init algif_skcipher_init(void)
981 {
982 	return af_alg_register_type(&algif_type_skcipher);
983 }
984 
985 static void __exit algif_skcipher_exit(void)
986 {
987 	int err = af_alg_unregister_type(&algif_type_skcipher);
988 	BUG_ON(err);
989 }
990 
991 module_init(algif_skcipher_init);
992 module_exit(algif_skcipher_exit);
993 MODULE_LICENSE("GPL");
994