xref: /linux/arch/arm64/crypto/aes-glue.c (revision e3966940559d52aa1800a008dcfeec218dd31f88)
1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  * linux/arch/arm64/crypto/aes-glue.c - wrapper code for ARMv8 AES
4  *
5  * Copyright (C) 2013 - 2017 Linaro Ltd <ard.biesheuvel@linaro.org>
6  */
7 
8 #include <asm/hwcap.h>
9 #include <asm/neon.h>
10 #include <crypto/aes.h>
11 #include <crypto/ctr.h>
12 #include <crypto/internal/hash.h>
13 #include <crypto/internal/skcipher.h>
14 #include <crypto/scatterwalk.h>
15 #include <crypto/sha2.h>
16 #include <crypto/utils.h>
17 #include <crypto/xts.h>
18 #include <linux/cpufeature.h>
19 #include <linux/kernel.h>
20 #include <linux/module.h>
21 #include <linux/string.h>
22 
23 #include "aes-ce-setkey.h"
24 
/*
 * Select the assembly back end wrapped by this translation unit.  Every
 * generic aes_* name used further down is mapped onto its "ce_" or "neon_"
 * prefixed implementation, MODE becomes the suffix of each cra_driver_name
 * and PRIO the base cra_priority of the algorithms registered below.
 */
#ifdef USE_V8_CRYPTO_EXTENSIONS
#define MODE			"ce"
#define PRIO			300
#define aes_expandkey		ce_aes_expandkey
#define aes_ecb_encrypt		ce_aes_ecb_encrypt
#define aes_ecb_decrypt		ce_aes_ecb_decrypt
#define aes_cbc_encrypt		ce_aes_cbc_encrypt
#define aes_cbc_decrypt		ce_aes_cbc_decrypt
#define aes_cbc_cts_encrypt	ce_aes_cbc_cts_encrypt
#define aes_cbc_cts_decrypt	ce_aes_cbc_cts_decrypt
#define aes_essiv_cbc_encrypt	ce_aes_essiv_cbc_encrypt
#define aes_essiv_cbc_decrypt	ce_aes_essiv_cbc_decrypt
#define aes_ctr_encrypt		ce_aes_ctr_encrypt
#define aes_xctr_encrypt	ce_aes_xctr_encrypt
#define aes_xts_encrypt		ce_aes_xts_encrypt
#define aes_xts_decrypt		ce_aes_xts_decrypt
#define aes_mac_update		ce_aes_mac_update
MODULE_DESCRIPTION("AES-ECB/CBC/CTR/XTS/XCTR using ARMv8 Crypto Extensions");
#else
/*
 * NOTE: aes_expandkey is deliberately not remapped in this branch, so the
 * NEON build calls whichever aes_expandkey the included headers declare.
 */
#define MODE			"neon"
#define PRIO			200
#define aes_ecb_encrypt		neon_aes_ecb_encrypt
#define aes_ecb_decrypt		neon_aes_ecb_decrypt
#define aes_cbc_encrypt		neon_aes_cbc_encrypt
#define aes_cbc_decrypt		neon_aes_cbc_decrypt
#define aes_cbc_cts_encrypt	neon_aes_cbc_cts_encrypt
#define aes_cbc_cts_decrypt	neon_aes_cbc_cts_decrypt
#define aes_essiv_cbc_encrypt	neon_aes_essiv_cbc_encrypt
#define aes_essiv_cbc_decrypt	neon_aes_essiv_cbc_decrypt
#define aes_ctr_encrypt		neon_aes_ctr_encrypt
#define aes_xctr_encrypt	neon_aes_xctr_encrypt
#define aes_xts_encrypt		neon_aes_xts_encrypt
#define aes_xts_decrypt		neon_aes_xts_decrypt
#define aes_mac_update		neon_aes_mac_update
MODULE_DESCRIPTION("AES-ECB/CBC/CTR/XTS/XCTR using ARMv8 NEON");
#endif
/*
 * The ecb/cbc/ctr/xts/xctr aliases are only emitted when this is the Crypto
 * Extensions build or when CONFIG_CRYPTO_AES_ARM64_BS is disabled; the very
 * same condition guards the corresponding entries of aes_algs[].  The
 * remaining aliases are always emitted.
 */
#if defined(USE_V8_CRYPTO_EXTENSIONS) || !IS_ENABLED(CONFIG_CRYPTO_AES_ARM64_BS)
MODULE_ALIAS_CRYPTO("ecb(aes)");
MODULE_ALIAS_CRYPTO("cbc(aes)");
MODULE_ALIAS_CRYPTO("ctr(aes)");
MODULE_ALIAS_CRYPTO("xts(aes)");
MODULE_ALIAS_CRYPTO("xctr(aes)");
#endif
MODULE_ALIAS_CRYPTO("cts(cbc(aes))");
MODULE_ALIAS_CRYPTO("essiv(cbc(aes),sha256)");
MODULE_ALIAS_CRYPTO("cmac(aes)");
MODULE_ALIAS_CRYPTO("xcbc(aes)");
MODULE_ALIAS_CRYPTO("cbcmac(aes)");

MODULE_AUTHOR("Ard Biesheuvel <ard.biesheuvel@linaro.org>");
MODULE_LICENSE("GPL v2");
76 
/*
 * Assembly entry points, defined in aes-modes.S.
 *
 * The aes_* names are the macros defined earlier in this file, so every
 * prototype really declares the ce_ or neon_ prefixed symbol of the back end
 * being built.
 *
 * As used by the callers in this file: rk/rk1/rk2 receive the key_enc or
 * key_dec array of a struct crypto_aes_ctx and rounds is computed as
 * 6 + key_length / 4.  Note that some routines take a number of whole AES
 * blocks ("blocks") while others take a byte count ("bytes").
 */
asmlinkage void aes_ecb_encrypt(u8 out[], u8 const in[], u32 const rk[],
				int rounds, int blocks);
asmlinkage void aes_ecb_decrypt(u8 out[], u8 const in[], u32 const rk[],
				int rounds, int blocks);

asmlinkage void aes_cbc_encrypt(u8 out[], u8 const in[], u32 const rk[],
				int rounds, int blocks, u8 iv[]);
asmlinkage void aes_cbc_decrypt(u8 out[], u8 const in[], u32 const rk[],
				int rounds, int blocks, u8 iv[]);

asmlinkage void aes_cbc_cts_encrypt(u8 out[], u8 const in[], u32 const rk[],
				int rounds, int bytes, u8 const iv[]);
asmlinkage void aes_cbc_cts_decrypt(u8 out[], u8 const in[], u32 const rk[],
				int rounds, int bytes, u8 const iv[]);

asmlinkage void aes_ctr_encrypt(u8 out[], u8 const in[], u32 const rk[],
				int rounds, int bytes, u8 ctr[]);

asmlinkage void aes_xctr_encrypt(u8 out[], u8 const in[], u32 const rk[],
				 int rounds, int bytes, u8 ctr[], int byte_ctr);

/* "first" is non-zero only on the first call made for a given request. */
asmlinkage void aes_xts_encrypt(u8 out[], u8 const in[], u32 const rk1[],
				int rounds, int bytes, u32 const rk2[], u8 iv[],
				int first);
asmlinkage void aes_xts_decrypt(u8 out[], u8 const in[], u32 const rk1[],
				int rounds, int bytes, u32 const rk2[], u8 iv[],
				int first);

asmlinkage void aes_essiv_cbc_encrypt(u8 out[], u8 const in[], u32 const rk1[],
				      int rounds, int blocks, u8 iv[],
				      u32 const rk2[]);
asmlinkage void aes_essiv_cbc_decrypt(u8 out[], u8 const in[], u32 const rk1[],
				      int rounds, int blocks, u8 iv[],
				      u32 const rk2[]);

/*
 * mac_do_update() treats the return value as the number of blocks that still
 * have to be processed and keeps calling until it reaches zero.
 */
asmlinkage int aes_mac_update(u8 const in[], u32 const rk[], int rounds,
			      int blocks, u8 dg[], int enc_before,
			      int enc_after);
116 
/*
 * Transform context of xts(aes).
 *
 * @key1: schedule expanded from the first half of the user key; its key_enc /
 *        key_dec arrays are handed to the assembly as rk1.
 * @key2: schedule expanded from the second half of the user key; only its
 *        key_enc array is used (handed to the assembly as rk2), for both
 *        encryption and decryption.
 */
struct crypto_aes_xts_ctx {
	struct crypto_aes_ctx key1;
	struct crypto_aes_ctx __aligned(8) key2;
};
121 
/*
 * Transform context of essiv(cbc(aes),sha256).
 *
 * @key1: schedule expanded from the user key.  It must stay the first member:
 *        cbc_encrypt_walk()/cbc_decrypt_walk() are also run on this transform
 *        and interpret the start of the context as a struct crypto_aes_ctx.
 * @key2: schedule expanded from the SHA-256 digest of the user key; only its
 *        key_enc array is used (handed to the assembly as rk2).
 */
struct crypto_aes_essiv_cbc_ctx {
	struct crypto_aes_ctx key1;
	struct crypto_aes_ctx __aligned(8) key2;
};
126 
/*
 * Transform context shared by cmac(aes), xcbc(aes) and cbcmac(aes).
 *
 * @key:    expanded AES key.
 * @consts: trailing storage for per-key constants.  cmac(aes) and xcbc(aes)
 *          reserve 2 * AES_BLOCK_SIZE bytes for it through cra_ctxsize;
 *          cbcmac(aes) reserves none and never touches it.
 */
struct mac_tfm_ctx {
	struct crypto_aes_ctx key;
	u8 __aligned(8) consts[];
};
131 
/*
 * Per-request MAC state.
 *
 * @dg: running digest block; zeroed by mac_init(), updated in place by
 *      mac_do_update() and copied out by the finup handlers.
 */
struct mac_desc_ctx {
	u8 dg[AES_BLOCK_SIZE];
};
135 
136 static int skcipher_aes_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
137 			       unsigned int key_len)
138 {
139 	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
140 
141 	return aes_expandkey(ctx, in_key, key_len);
142 }
143 
144 static int __maybe_unused xts_set_key(struct crypto_skcipher *tfm,
145 				      const u8 *in_key, unsigned int key_len)
146 {
147 	struct crypto_aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
148 	int ret;
149 
150 	ret = xts_verify_key(tfm, in_key, key_len);
151 	if (ret)
152 		return ret;
153 
154 	ret = aes_expandkey(&ctx->key1, in_key, key_len / 2);
155 	if (!ret)
156 		ret = aes_expandkey(&ctx->key2, &in_key[key_len / 2],
157 				    key_len / 2);
158 	return ret;
159 }
160 
161 static int __maybe_unused essiv_cbc_set_key(struct crypto_skcipher *tfm,
162 					    const u8 *in_key,
163 					    unsigned int key_len)
164 {
165 	struct crypto_aes_essiv_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
166 	u8 digest[SHA256_DIGEST_SIZE];
167 	int ret;
168 
169 	ret = aes_expandkey(&ctx->key1, in_key, key_len);
170 	if (ret)
171 		return ret;
172 
173 	sha256(in_key, key_len, digest);
174 
175 	return aes_expandkey(&ctx->key2, digest, sizeof(digest));
176 }
177 
178 static int __maybe_unused ecb_encrypt(struct skcipher_request *req)
179 {
180 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
181 	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
182 	int err, rounds = 6 + ctx->key_length / 4;
183 	struct skcipher_walk walk;
184 	unsigned int blocks;
185 
186 	err = skcipher_walk_virt(&walk, req, false);
187 
188 	while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
189 		kernel_neon_begin();
190 		aes_ecb_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
191 				ctx->key_enc, rounds, blocks);
192 		kernel_neon_end();
193 		err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
194 	}
195 	return err;
196 }
197 
198 static int __maybe_unused ecb_decrypt(struct skcipher_request *req)
199 {
200 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
201 	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
202 	int err, rounds = 6 + ctx->key_length / 4;
203 	struct skcipher_walk walk;
204 	unsigned int blocks;
205 
206 	err = skcipher_walk_virt(&walk, req, false);
207 
208 	while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
209 		kernel_neon_begin();
210 		aes_ecb_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
211 				ctx->key_dec, rounds, blocks);
212 		kernel_neon_end();
213 		err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
214 	}
215 	return err;
216 }
217 
218 static int cbc_encrypt_walk(struct skcipher_request *req,
219 			    struct skcipher_walk *walk)
220 {
221 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
222 	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
223 	int err = 0, rounds = 6 + ctx->key_length / 4;
224 	unsigned int blocks;
225 
226 	while ((blocks = (walk->nbytes / AES_BLOCK_SIZE))) {
227 		kernel_neon_begin();
228 		aes_cbc_encrypt(walk->dst.virt.addr, walk->src.virt.addr,
229 				ctx->key_enc, rounds, blocks, walk->iv);
230 		kernel_neon_end();
231 		err = skcipher_walk_done(walk, walk->nbytes % AES_BLOCK_SIZE);
232 	}
233 	return err;
234 }
235 
236 static int __maybe_unused cbc_encrypt(struct skcipher_request *req)
237 {
238 	struct skcipher_walk walk;
239 	int err;
240 
241 	err = skcipher_walk_virt(&walk, req, false);
242 	if (err)
243 		return err;
244 	return cbc_encrypt_walk(req, &walk);
245 }
246 
247 static int cbc_decrypt_walk(struct skcipher_request *req,
248 			    struct skcipher_walk *walk)
249 {
250 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
251 	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
252 	int err = 0, rounds = 6 + ctx->key_length / 4;
253 	unsigned int blocks;
254 
255 	while ((blocks = (walk->nbytes / AES_BLOCK_SIZE))) {
256 		kernel_neon_begin();
257 		aes_cbc_decrypt(walk->dst.virt.addr, walk->src.virt.addr,
258 				ctx->key_dec, rounds, blocks, walk->iv);
259 		kernel_neon_end();
260 		err = skcipher_walk_done(walk, walk->nbytes % AES_BLOCK_SIZE);
261 	}
262 	return err;
263 }
264 
265 static int __maybe_unused cbc_decrypt(struct skcipher_request *req)
266 {
267 	struct skcipher_walk walk;
268 	int err;
269 
270 	err = skcipher_walk_virt(&walk, req, false);
271 	if (err)
272 		return err;
273 	return cbc_decrypt_walk(req, &walk);
274 }
275 
276 static int cts_cbc_encrypt(struct skcipher_request *req)
277 {
278 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
279 	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
280 	int err, rounds = 6 + ctx->key_length / 4;
281 	int cbc_blocks = DIV_ROUND_UP(req->cryptlen, AES_BLOCK_SIZE) - 2;
282 	struct scatterlist *src = req->src, *dst = req->dst;
283 	struct scatterlist sg_src[2], sg_dst[2];
284 	struct skcipher_request subreq;
285 	struct skcipher_walk walk;
286 
287 	skcipher_request_set_tfm(&subreq, tfm);
288 	skcipher_request_set_callback(&subreq, skcipher_request_flags(req),
289 				      NULL, NULL);
290 
291 	if (req->cryptlen <= AES_BLOCK_SIZE) {
292 		if (req->cryptlen < AES_BLOCK_SIZE)
293 			return -EINVAL;
294 		cbc_blocks = 1;
295 	}
296 
297 	if (cbc_blocks > 0) {
298 		skcipher_request_set_crypt(&subreq, req->src, req->dst,
299 					   cbc_blocks * AES_BLOCK_SIZE,
300 					   req->iv);
301 
302 		err = skcipher_walk_virt(&walk, &subreq, false) ?:
303 		      cbc_encrypt_walk(&subreq, &walk);
304 		if (err)
305 			return err;
306 
307 		if (req->cryptlen == AES_BLOCK_SIZE)
308 			return 0;
309 
310 		dst = src = scatterwalk_ffwd(sg_src, req->src, subreq.cryptlen);
311 		if (req->dst != req->src)
312 			dst = scatterwalk_ffwd(sg_dst, req->dst,
313 					       subreq.cryptlen);
314 	}
315 
316 	/* handle ciphertext stealing */
317 	skcipher_request_set_crypt(&subreq, src, dst,
318 				   req->cryptlen - cbc_blocks * AES_BLOCK_SIZE,
319 				   req->iv);
320 
321 	err = skcipher_walk_virt(&walk, &subreq, false);
322 	if (err)
323 		return err;
324 
325 	kernel_neon_begin();
326 	aes_cbc_cts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
327 			    ctx->key_enc, rounds, walk.nbytes, walk.iv);
328 	kernel_neon_end();
329 
330 	return skcipher_walk_done(&walk, 0);
331 }
332 
333 static int cts_cbc_decrypt(struct skcipher_request *req)
334 {
335 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
336 	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
337 	int err, rounds = 6 + ctx->key_length / 4;
338 	int cbc_blocks = DIV_ROUND_UP(req->cryptlen, AES_BLOCK_SIZE) - 2;
339 	struct scatterlist *src = req->src, *dst = req->dst;
340 	struct scatterlist sg_src[2], sg_dst[2];
341 	struct skcipher_request subreq;
342 	struct skcipher_walk walk;
343 
344 	skcipher_request_set_tfm(&subreq, tfm);
345 	skcipher_request_set_callback(&subreq, skcipher_request_flags(req),
346 				      NULL, NULL);
347 
348 	if (req->cryptlen <= AES_BLOCK_SIZE) {
349 		if (req->cryptlen < AES_BLOCK_SIZE)
350 			return -EINVAL;
351 		cbc_blocks = 1;
352 	}
353 
354 	if (cbc_blocks > 0) {
355 		skcipher_request_set_crypt(&subreq, req->src, req->dst,
356 					   cbc_blocks * AES_BLOCK_SIZE,
357 					   req->iv);
358 
359 		err = skcipher_walk_virt(&walk, &subreq, false) ?:
360 		      cbc_decrypt_walk(&subreq, &walk);
361 		if (err)
362 			return err;
363 
364 		if (req->cryptlen == AES_BLOCK_SIZE)
365 			return 0;
366 
367 		dst = src = scatterwalk_ffwd(sg_src, req->src, subreq.cryptlen);
368 		if (req->dst != req->src)
369 			dst = scatterwalk_ffwd(sg_dst, req->dst,
370 					       subreq.cryptlen);
371 	}
372 
373 	/* handle ciphertext stealing */
374 	skcipher_request_set_crypt(&subreq, src, dst,
375 				   req->cryptlen - cbc_blocks * AES_BLOCK_SIZE,
376 				   req->iv);
377 
378 	err = skcipher_walk_virt(&walk, &subreq, false);
379 	if (err)
380 		return err;
381 
382 	kernel_neon_begin();
383 	aes_cbc_cts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
384 			    ctx->key_dec, rounds, walk.nbytes, walk.iv);
385 	kernel_neon_end();
386 
387 	return skcipher_walk_done(&walk, 0);
388 }
389 
390 static int __maybe_unused essiv_cbc_encrypt(struct skcipher_request *req)
391 {
392 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
393 	struct crypto_aes_essiv_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
394 	int err, rounds = 6 + ctx->key1.key_length / 4;
395 	struct skcipher_walk walk;
396 	unsigned int blocks;
397 
398 	err = skcipher_walk_virt(&walk, req, false);
399 
400 	blocks = walk.nbytes / AES_BLOCK_SIZE;
401 	if (blocks) {
402 		kernel_neon_begin();
403 		aes_essiv_cbc_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
404 				      ctx->key1.key_enc, rounds, blocks,
405 				      req->iv, ctx->key2.key_enc);
406 		kernel_neon_end();
407 		err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
408 	}
409 	return err ?: cbc_encrypt_walk(req, &walk);
410 }
411 
412 static int __maybe_unused essiv_cbc_decrypt(struct skcipher_request *req)
413 {
414 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
415 	struct crypto_aes_essiv_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
416 	int err, rounds = 6 + ctx->key1.key_length / 4;
417 	struct skcipher_walk walk;
418 	unsigned int blocks;
419 
420 	err = skcipher_walk_virt(&walk, req, false);
421 
422 	blocks = walk.nbytes / AES_BLOCK_SIZE;
423 	if (blocks) {
424 		kernel_neon_begin();
425 		aes_essiv_cbc_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
426 				      ctx->key1.key_dec, rounds, blocks,
427 				      req->iv, ctx->key2.key_enc);
428 		kernel_neon_end();
429 		err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
430 	}
431 	return err ?: cbc_decrypt_walk(req, &walk);
432 }
433 
434 static int __maybe_unused xctr_encrypt(struct skcipher_request *req)
435 {
436 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
437 	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
438 	int err, rounds = 6 + ctx->key_length / 4;
439 	struct skcipher_walk walk;
440 	unsigned int byte_ctr = 0;
441 
442 	err = skcipher_walk_virt(&walk, req, false);
443 
444 	while (walk.nbytes > 0) {
445 		const u8 *src = walk.src.virt.addr;
446 		unsigned int nbytes = walk.nbytes;
447 		u8 *dst = walk.dst.virt.addr;
448 		u8 buf[AES_BLOCK_SIZE];
449 
450 		/*
451 		 * If given less than 16 bytes, we must copy the partial block
452 		 * into a temporary buffer of 16 bytes to avoid out of bounds
453 		 * reads and writes.  Furthermore, this code is somewhat unusual
454 		 * in that it expects the end of the data to be at the end of
455 		 * the temporary buffer, rather than the start of the data at
456 		 * the start of the temporary buffer.
457 		 */
458 		if (unlikely(nbytes < AES_BLOCK_SIZE))
459 			src = dst = memcpy(buf + sizeof(buf) - nbytes,
460 					   src, nbytes);
461 		else if (nbytes < walk.total)
462 			nbytes &= ~(AES_BLOCK_SIZE - 1);
463 
464 		kernel_neon_begin();
465 		aes_xctr_encrypt(dst, src, ctx->key_enc, rounds, nbytes,
466 						 walk.iv, byte_ctr);
467 		kernel_neon_end();
468 
469 		if (unlikely(nbytes < AES_BLOCK_SIZE))
470 			memcpy(walk.dst.virt.addr,
471 			       buf + sizeof(buf) - nbytes, nbytes);
472 		byte_ctr += nbytes;
473 
474 		err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
475 	}
476 
477 	return err;
478 }
479 
480 static int __maybe_unused ctr_encrypt(struct skcipher_request *req)
481 {
482 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
483 	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
484 	int err, rounds = 6 + ctx->key_length / 4;
485 	struct skcipher_walk walk;
486 
487 	err = skcipher_walk_virt(&walk, req, false);
488 
489 	while (walk.nbytes > 0) {
490 		const u8 *src = walk.src.virt.addr;
491 		unsigned int nbytes = walk.nbytes;
492 		u8 *dst = walk.dst.virt.addr;
493 		u8 buf[AES_BLOCK_SIZE];
494 
495 		/*
496 		 * If given less than 16 bytes, we must copy the partial block
497 		 * into a temporary buffer of 16 bytes to avoid out of bounds
498 		 * reads and writes.  Furthermore, this code is somewhat unusual
499 		 * in that it expects the end of the data to be at the end of
500 		 * the temporary buffer, rather than the start of the data at
501 		 * the start of the temporary buffer.
502 		 */
503 		if (unlikely(nbytes < AES_BLOCK_SIZE))
504 			src = dst = memcpy(buf + sizeof(buf) - nbytes,
505 					   src, nbytes);
506 		else if (nbytes < walk.total)
507 			nbytes &= ~(AES_BLOCK_SIZE - 1);
508 
509 		kernel_neon_begin();
510 		aes_ctr_encrypt(dst, src, ctx->key_enc, rounds, nbytes,
511 				walk.iv);
512 		kernel_neon_end();
513 
514 		if (unlikely(nbytes < AES_BLOCK_SIZE))
515 			memcpy(walk.dst.virt.addr,
516 			       buf + sizeof(buf) - nbytes, nbytes);
517 
518 		err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
519 	}
520 
521 	return err;
522 }
523 
/*
 * XTS encryption, including requests whose length is not a multiple of
 * AES_BLOCK_SIZE.
 *
 * Requests shorter than one block are rejected with -EINVAL.  When the
 * request has a partial trailing block and the first walk step does not
 * cover the whole request, the work is split in two: all but the last two
 * (possibly partial) blocks are processed through an on-stack sub-request,
 * and the final full block plus the tail are then processed in one separate
 * call.  In every other case a single walk handles the complete request.
 *
 * Returns 0 on success or a negative errno.
 */
static int __maybe_unused xts_encrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct crypto_aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
	int err, first, rounds = 6 + ctx->key1.key_length / 4;
	int tail = req->cryptlen % AES_BLOCK_SIZE;
	struct scatterlist sg_src[2], sg_dst[2];
	struct skcipher_request subreq;
	struct scatterlist *src, *dst;
	struct skcipher_walk walk;

	if (req->cryptlen < AES_BLOCK_SIZE)
		return -EINVAL;

	err = skcipher_walk_virt(&walk, req, false);

	if (unlikely(tail > 0 && walk.nbytes < walk.total)) {
		/* Number of blocks that precede the last full block + tail. */
		int xts_blocks = DIV_ROUND_UP(req->cryptlen,
					      AES_BLOCK_SIZE) - 2;

		/* Restart the walk on a sub-request limited to those blocks. */
		skcipher_walk_abort(&walk);

		skcipher_request_set_tfm(&subreq, tfm);
		skcipher_request_set_callback(&subreq,
					      skcipher_request_flags(req),
					      NULL, NULL);
		skcipher_request_set_crypt(&subreq, req->src, req->dst,
					   xts_blocks * AES_BLOCK_SIZE,
					   req->iv);
		/* From here on req refers to the on-stack sub-request. */
		req = &subreq;
		err = skcipher_walk_virt(&walk, req, false);
	} else {
		/* The loop below handles everything; no separate tail pass. */
		tail = 0;
	}

	/* "first" is 1 for the first assembly call only, 0 afterwards. */
	for (first = 1; walk.nbytes >= AES_BLOCK_SIZE; first = 0) {
		int nbytes = walk.nbytes;

		/* Only the final step may pass a partial block. */
		if (walk.nbytes < walk.total)
			nbytes &= ~(AES_BLOCK_SIZE - 1);

		kernel_neon_begin();
		aes_xts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
				ctx->key1.key_enc, rounds, nbytes,
				ctx->key2.key_enc, walk.iv, first);
		kernel_neon_end();
		err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
	}

	if (err || likely(!tail))
		return err;

	/*
	 * tail is only still non-zero if the sub-request branch above was
	 * taken, so req is &subreq here and req->cryptlen is the number of
	 * bytes already processed: skip them in both scatterlists.
	 */
	dst = src = scatterwalk_ffwd(sg_src, req->src, req->cryptlen);
	if (req->dst != req->src)
		dst = scatterwalk_ffwd(sg_dst, req->dst, req->cryptlen);

	/* Last full block plus the partial tail, in one go. */
	skcipher_request_set_crypt(req, src, dst, AES_BLOCK_SIZE + tail,
				   req->iv);

	err = skcipher_walk_virt(&walk, &subreq, false);
	if (err)
		return err;

	kernel_neon_begin();
	aes_xts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
			ctx->key1.key_enc, rounds, walk.nbytes,
			ctx->key2.key_enc, walk.iv, first);
	kernel_neon_end();

	return skcipher_walk_done(&walk, 0);
}
595 
/*
 * XTS decryption, including requests whose length is not a multiple of
 * AES_BLOCK_SIZE.
 *
 * Requests shorter than one block are rejected with -EINVAL.  When the
 * request has a partial trailing block and the first walk step does not
 * cover the whole request, the work is split in two: all but the last two
 * (possibly partial) blocks are processed through an on-stack sub-request,
 * and the final full block plus the tail are then processed in one separate
 * call.  In every other case a single walk handles the complete request.
 *
 * The data blocks use the decryption schedule of key1, while key2 is always
 * passed as its encryption schedule.
 *
 * Returns 0 on success or a negative errno.
 */
static int __maybe_unused xts_decrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct crypto_aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
	int err, first, rounds = 6 + ctx->key1.key_length / 4;
	int tail = req->cryptlen % AES_BLOCK_SIZE;
	struct scatterlist sg_src[2], sg_dst[2];
	struct skcipher_request subreq;
	struct scatterlist *src, *dst;
	struct skcipher_walk walk;

	if (req->cryptlen < AES_BLOCK_SIZE)
		return -EINVAL;

	err = skcipher_walk_virt(&walk, req, false);

	if (unlikely(tail > 0 && walk.nbytes < walk.total)) {
		/* Number of blocks that precede the last full block + tail. */
		int xts_blocks = DIV_ROUND_UP(req->cryptlen,
					      AES_BLOCK_SIZE) - 2;

		/* Restart the walk on a sub-request limited to those blocks. */
		skcipher_walk_abort(&walk);

		skcipher_request_set_tfm(&subreq, tfm);
		skcipher_request_set_callback(&subreq,
					      skcipher_request_flags(req),
					      NULL, NULL);
		skcipher_request_set_crypt(&subreq, req->src, req->dst,
					   xts_blocks * AES_BLOCK_SIZE,
					   req->iv);
		/* From here on req refers to the on-stack sub-request. */
		req = &subreq;
		err = skcipher_walk_virt(&walk, req, false);
	} else {
		/* The loop below handles everything; no separate tail pass. */
		tail = 0;
	}

	/* "first" is 1 for the first assembly call only, 0 afterwards. */
	for (first = 1; walk.nbytes >= AES_BLOCK_SIZE; first = 0) {
		int nbytes = walk.nbytes;

		/* Only the final step may pass a partial block. */
		if (walk.nbytes < walk.total)
			nbytes &= ~(AES_BLOCK_SIZE - 1);

		kernel_neon_begin();
		aes_xts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
				ctx->key1.key_dec, rounds, nbytes,
				ctx->key2.key_enc, walk.iv, first);
		kernel_neon_end();
		err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
	}

	if (err || likely(!tail))
		return err;

	/*
	 * tail is only still non-zero if the sub-request branch above was
	 * taken, so req is &subreq here and req->cryptlen is the number of
	 * bytes already processed: skip them in both scatterlists.
	 */
	dst = src = scatterwalk_ffwd(sg_src, req->src, req->cryptlen);
	if (req->dst != req->src)
		dst = scatterwalk_ffwd(sg_dst, req->dst, req->cryptlen);

	/* Last full block plus the partial tail, in one go. */
	skcipher_request_set_crypt(req, src, dst, AES_BLOCK_SIZE + tail,
				   req->iv);

	err = skcipher_walk_virt(&walk, &subreq, false);
	if (err)
		return err;


	kernel_neon_begin();
	aes_xts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
			ctx->key1.key_dec, rounds, walk.nbytes,
			ctx->key2.key_enc, walk.iv, first);
	kernel_neon_end();

	return skcipher_walk_done(&walk, 0);
}
668 
669 static struct skcipher_alg aes_algs[] = { {
670 #if defined(USE_V8_CRYPTO_EXTENSIONS) || !IS_ENABLED(CONFIG_CRYPTO_AES_ARM64_BS)
671 	.base = {
672 		.cra_name		= "ecb(aes)",
673 		.cra_driver_name	= "ecb-aes-" MODE,
674 		.cra_priority		= PRIO,
675 		.cra_blocksize		= AES_BLOCK_SIZE,
676 		.cra_ctxsize		= sizeof(struct crypto_aes_ctx),
677 		.cra_module		= THIS_MODULE,
678 	},
679 	.min_keysize	= AES_MIN_KEY_SIZE,
680 	.max_keysize	= AES_MAX_KEY_SIZE,
681 	.setkey		= skcipher_aes_setkey,
682 	.encrypt	= ecb_encrypt,
683 	.decrypt	= ecb_decrypt,
684 }, {
685 	.base = {
686 		.cra_name		= "cbc(aes)",
687 		.cra_driver_name	= "cbc-aes-" MODE,
688 		.cra_priority		= PRIO,
689 		.cra_blocksize		= AES_BLOCK_SIZE,
690 		.cra_ctxsize		= sizeof(struct crypto_aes_ctx),
691 		.cra_module		= THIS_MODULE,
692 	},
693 	.min_keysize	= AES_MIN_KEY_SIZE,
694 	.max_keysize	= AES_MAX_KEY_SIZE,
695 	.ivsize		= AES_BLOCK_SIZE,
696 	.setkey		= skcipher_aes_setkey,
697 	.encrypt	= cbc_encrypt,
698 	.decrypt	= cbc_decrypt,
699 }, {
700 	.base = {
701 		.cra_name		= "ctr(aes)",
702 		.cra_driver_name	= "ctr-aes-" MODE,
703 		.cra_priority		= PRIO,
704 		.cra_blocksize		= 1,
705 		.cra_ctxsize		= sizeof(struct crypto_aes_ctx),
706 		.cra_module		= THIS_MODULE,
707 	},
708 	.min_keysize	= AES_MIN_KEY_SIZE,
709 	.max_keysize	= AES_MAX_KEY_SIZE,
710 	.ivsize		= AES_BLOCK_SIZE,
711 	.chunksize	= AES_BLOCK_SIZE,
712 	.setkey		= skcipher_aes_setkey,
713 	.encrypt	= ctr_encrypt,
714 	.decrypt	= ctr_encrypt,
715 }, {
716 	.base = {
717 		.cra_name		= "xctr(aes)",
718 		.cra_driver_name	= "xctr-aes-" MODE,
719 		.cra_priority		= PRIO,
720 		.cra_blocksize		= 1,
721 		.cra_ctxsize		= sizeof(struct crypto_aes_ctx),
722 		.cra_module		= THIS_MODULE,
723 	},
724 	.min_keysize	= AES_MIN_KEY_SIZE,
725 	.max_keysize	= AES_MAX_KEY_SIZE,
726 	.ivsize		= AES_BLOCK_SIZE,
727 	.chunksize	= AES_BLOCK_SIZE,
728 	.setkey		= skcipher_aes_setkey,
729 	.encrypt	= xctr_encrypt,
730 	.decrypt	= xctr_encrypt,
731 }, {
732 	.base = {
733 		.cra_name		= "xts(aes)",
734 		.cra_driver_name	= "xts-aes-" MODE,
735 		.cra_priority		= PRIO,
736 		.cra_blocksize		= AES_BLOCK_SIZE,
737 		.cra_ctxsize		= sizeof(struct crypto_aes_xts_ctx),
738 		.cra_module		= THIS_MODULE,
739 	},
740 	.min_keysize	= 2 * AES_MIN_KEY_SIZE,
741 	.max_keysize	= 2 * AES_MAX_KEY_SIZE,
742 	.ivsize		= AES_BLOCK_SIZE,
743 	.walksize	= 2 * AES_BLOCK_SIZE,
744 	.setkey		= xts_set_key,
745 	.encrypt	= xts_encrypt,
746 	.decrypt	= xts_decrypt,
747 }, {
748 #endif
749 	.base = {
750 		.cra_name		= "cts(cbc(aes))",
751 		.cra_driver_name	= "cts-cbc-aes-" MODE,
752 		.cra_priority		= PRIO,
753 		.cra_blocksize		= AES_BLOCK_SIZE,
754 		.cra_ctxsize		= sizeof(struct crypto_aes_ctx),
755 		.cra_module		= THIS_MODULE,
756 	},
757 	.min_keysize	= AES_MIN_KEY_SIZE,
758 	.max_keysize	= AES_MAX_KEY_SIZE,
759 	.ivsize		= AES_BLOCK_SIZE,
760 	.walksize	= 2 * AES_BLOCK_SIZE,
761 	.setkey		= skcipher_aes_setkey,
762 	.encrypt	= cts_cbc_encrypt,
763 	.decrypt	= cts_cbc_decrypt,
764 }, {
765 	.base = {
766 		.cra_name		= "essiv(cbc(aes),sha256)",
767 		.cra_driver_name	= "essiv-cbc-aes-sha256-" MODE,
768 		.cra_priority		= PRIO + 1,
769 		.cra_blocksize		= AES_BLOCK_SIZE,
770 		.cra_ctxsize		= sizeof(struct crypto_aes_essiv_cbc_ctx),
771 		.cra_module		= THIS_MODULE,
772 	},
773 	.min_keysize	= AES_MIN_KEY_SIZE,
774 	.max_keysize	= AES_MAX_KEY_SIZE,
775 	.ivsize		= AES_BLOCK_SIZE,
776 	.setkey		= essiv_cbc_set_key,
777 	.encrypt	= essiv_cbc_encrypt,
778 	.decrypt	= essiv_cbc_decrypt,
779 } };
780 
781 static int cbcmac_setkey(struct crypto_shash *tfm, const u8 *in_key,
782 			 unsigned int key_len)
783 {
784 	struct mac_tfm_ctx *ctx = crypto_shash_ctx(tfm);
785 
786 	return aes_expandkey(&ctx->key, in_key, key_len);
787 }
788 
789 static void cmac_gf128_mul_by_x(be128 *y, const be128 *x)
790 {
791 	u64 a = be64_to_cpu(x->a);
792 	u64 b = be64_to_cpu(x->b);
793 
794 	y->a = cpu_to_be64((a << 1) | (b >> 63));
795 	y->b = cpu_to_be64((b << 1) ^ ((a >> 63) ? 0x87 : 0));
796 }
797 
798 static int cmac_setkey(struct crypto_shash *tfm, const u8 *in_key,
799 		       unsigned int key_len)
800 {
801 	struct mac_tfm_ctx *ctx = crypto_shash_ctx(tfm);
802 	be128 *consts = (be128 *)ctx->consts;
803 	int rounds = 6 + key_len / 4;
804 	int err;
805 
806 	err = cbcmac_setkey(tfm, in_key, key_len);
807 	if (err)
808 		return err;
809 
810 	/* encrypt the zero vector */
811 	kernel_neon_begin();
812 	aes_ecb_encrypt(ctx->consts, (u8[AES_BLOCK_SIZE]){}, ctx->key.key_enc,
813 			rounds, 1);
814 	kernel_neon_end();
815 
816 	cmac_gf128_mul_by_x(consts, consts);
817 	cmac_gf128_mul_by_x(consts + 1, consts);
818 
819 	return 0;
820 }
821 
822 static int xcbc_setkey(struct crypto_shash *tfm, const u8 *in_key,
823 		       unsigned int key_len)
824 {
825 	static u8 const ks[3][AES_BLOCK_SIZE] = {
826 		{ [0 ... AES_BLOCK_SIZE - 1] = 0x1 },
827 		{ [0 ... AES_BLOCK_SIZE - 1] = 0x2 },
828 		{ [0 ... AES_BLOCK_SIZE - 1] = 0x3 },
829 	};
830 
831 	struct mac_tfm_ctx *ctx = crypto_shash_ctx(tfm);
832 	int rounds = 6 + key_len / 4;
833 	u8 key[AES_BLOCK_SIZE];
834 	int err;
835 
836 	err = cbcmac_setkey(tfm, in_key, key_len);
837 	if (err)
838 		return err;
839 
840 	kernel_neon_begin();
841 	aes_ecb_encrypt(key, ks[0], ctx->key.key_enc, rounds, 1);
842 	aes_ecb_encrypt(ctx->consts, ks[1], ctx->key.key_enc, rounds, 2);
843 	kernel_neon_end();
844 
845 	return cbcmac_setkey(tfm, key, sizeof(key));
846 }
847 
848 static int mac_init(struct shash_desc *desc)
849 {
850 	struct mac_desc_ctx *ctx = shash_desc_ctx(desc);
851 
852 	memset(ctx->dg, 0, AES_BLOCK_SIZE);
853 	return 0;
854 }
855 
856 static void mac_do_update(struct crypto_aes_ctx *ctx, u8 const in[], int blocks,
857 			  u8 dg[], int enc_before)
858 {
859 	int rounds = 6 + ctx->key_length / 4;
860 	int rem;
861 
862 	do {
863 		kernel_neon_begin();
864 		rem = aes_mac_update(in, ctx->key_enc, rounds, blocks,
865 				     dg, enc_before, !enc_before);
866 		kernel_neon_end();
867 		in += (blocks - rem) * AES_BLOCK_SIZE;
868 		blocks = rem;
869 	} while (blocks);
870 }
871 
872 static int mac_update(struct shash_desc *desc, const u8 *p, unsigned int len)
873 {
874 	struct mac_tfm_ctx *tctx = crypto_shash_ctx(desc->tfm);
875 	struct mac_desc_ctx *ctx = shash_desc_ctx(desc);
876 	int blocks = len / AES_BLOCK_SIZE;
877 
878 	len %= AES_BLOCK_SIZE;
879 	mac_do_update(&tctx->key, p, blocks, ctx->dg, 0);
880 	return len;
881 }
882 
883 static int cbcmac_finup(struct shash_desc *desc, const u8 *src,
884 			unsigned int len, u8 *out)
885 {
886 	struct mac_tfm_ctx *tctx = crypto_shash_ctx(desc->tfm);
887 	struct mac_desc_ctx *ctx = shash_desc_ctx(desc);
888 
889 	if (len) {
890 		crypto_xor(ctx->dg, src, len);
891 		mac_do_update(&tctx->key, NULL, 0, ctx->dg, 1);
892 	}
893 	memcpy(out, ctx->dg, AES_BLOCK_SIZE);
894 	return 0;
895 }
896 
897 static int cmac_finup(struct shash_desc *desc, const u8 *src, unsigned int len,
898 		      u8 *out)
899 {
900 	struct mac_tfm_ctx *tctx = crypto_shash_ctx(desc->tfm);
901 	struct mac_desc_ctx *ctx = shash_desc_ctx(desc);
902 	u8 *consts = tctx->consts;
903 
904 	crypto_xor(ctx->dg, src, len);
905 	if (len != AES_BLOCK_SIZE) {
906 		ctx->dg[len] ^= 0x80;
907 		consts += AES_BLOCK_SIZE;
908 	}
909 	mac_do_update(&tctx->key, consts, 1, ctx->dg, 0);
910 	memcpy(out, ctx->dg, AES_BLOCK_SIZE);
911 	return 0;
912 }
913 
914 static struct shash_alg mac_algs[] = { {
915 	.base.cra_name		= "cmac(aes)",
916 	.base.cra_driver_name	= "cmac-aes-" MODE,
917 	.base.cra_priority	= PRIO,
918 	.base.cra_flags		= CRYPTO_AHASH_ALG_BLOCK_ONLY |
919 				  CRYPTO_AHASH_ALG_FINAL_NONZERO,
920 	.base.cra_blocksize	= AES_BLOCK_SIZE,
921 	.base.cra_ctxsize	= sizeof(struct mac_tfm_ctx) +
922 				  2 * AES_BLOCK_SIZE,
923 	.base.cra_module	= THIS_MODULE,
924 
925 	.digestsize		= AES_BLOCK_SIZE,
926 	.init			= mac_init,
927 	.update			= mac_update,
928 	.finup			= cmac_finup,
929 	.setkey			= cmac_setkey,
930 	.descsize		= sizeof(struct mac_desc_ctx),
931 }, {
932 	.base.cra_name		= "xcbc(aes)",
933 	.base.cra_driver_name	= "xcbc-aes-" MODE,
934 	.base.cra_priority	= PRIO,
935 	.base.cra_flags		= CRYPTO_AHASH_ALG_BLOCK_ONLY |
936 				  CRYPTO_AHASH_ALG_FINAL_NONZERO,
937 	.base.cra_blocksize	= AES_BLOCK_SIZE,
938 	.base.cra_ctxsize	= sizeof(struct mac_tfm_ctx) +
939 				  2 * AES_BLOCK_SIZE,
940 	.base.cra_module	= THIS_MODULE,
941 
942 	.digestsize		= AES_BLOCK_SIZE,
943 	.init			= mac_init,
944 	.update			= mac_update,
945 	.finup			= cmac_finup,
946 	.setkey			= xcbc_setkey,
947 	.descsize		= sizeof(struct mac_desc_ctx),
948 }, {
949 	.base.cra_name		= "cbcmac(aes)",
950 	.base.cra_driver_name	= "cbcmac-aes-" MODE,
951 	.base.cra_priority	= PRIO,
952 	.base.cra_flags		= CRYPTO_AHASH_ALG_BLOCK_ONLY,
953 	.base.cra_blocksize	= AES_BLOCK_SIZE,
954 	.base.cra_ctxsize	= sizeof(struct mac_tfm_ctx),
955 	.base.cra_module	= THIS_MODULE,
956 
957 	.digestsize		= AES_BLOCK_SIZE,
958 	.init			= mac_init,
959 	.update			= mac_update,
960 	.finup			= cbcmac_finup,
961 	.setkey			= cbcmac_setkey,
962 	.descsize		= sizeof(struct mac_desc_ctx),
963 } };
964 
/*
 * Module exit: unregister the MAC algorithms first and the skciphers second,
 * i.e. in the reverse of the order in which aes_init() registered them.
 */
static void aes_exit(void)
{
	crypto_unregister_shashes(mac_algs, ARRAY_SIZE(mac_algs));
	crypto_unregister_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
}
970 
971 static int __init aes_init(void)
972 {
973 	int err;
974 
975 	err = crypto_register_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
976 	if (err)
977 		return err;
978 
979 	err = crypto_register_shashes(mac_algs, ARRAY_SIZE(mac_algs));
980 	if (err)
981 		goto unregister_ciphers;
982 
983 	return 0;
984 
985 unregister_ciphers:
986 	crypto_unregister_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
987 	return err;
988 }
989 
/*
 * The Crypto Extensions build hooks aes_init() up through
 * module_cpu_feature_match(AES, ...), the NEON build through plain
 * module_init().  Each build also exports some of its assembly routines for
 * use outside this file (presumably by sibling AES drivers -- TODO confirm
 * the users).
 */
#ifdef USE_V8_CRYPTO_EXTENSIONS
module_cpu_feature_match(AES, aes_init);
EXPORT_SYMBOL_NS(ce_aes_mac_update, "CRYPTO_INTERNAL");
#else
module_init(aes_init);
EXPORT_SYMBOL(neon_aes_ecb_encrypt);
EXPORT_SYMBOL(neon_aes_cbc_encrypt);
EXPORT_SYMBOL(neon_aes_ctr_encrypt);
EXPORT_SYMBOL(neon_aes_xts_encrypt);
EXPORT_SYMBOL(neon_aes_xts_decrypt);
#endif
module_exit(aes_exit);
1002