xref: /linux/arch/arm/crypto/aes-neonbs-glue.c (revision 746680ec6696585e30db3e18c93a63df9cbec39c)
// SPDX-License-Identifier: GPL-2.0-only
/*
 * Bit sliced AES using NEON instructions
 *
 * Copyright (C) 2017 Linaro Ltd <ard.biesheuvel@linaro.org>
 */

#include <asm/neon.h>
#include <asm/simd.h>
#include <crypto/aes.h>
#include <crypto/internal/skcipher.h>
#include <crypto/scatterwalk.h>
#include <crypto/xts.h>
#include <linux/module.h>
#include "aes-cipher.h"

MODULE_AUTHOR("Ard Biesheuvel <ard.biesheuvel@linaro.org>");
MODULE_DESCRIPTION("Bit sliced AES using NEON instructions");
MODULE_LICENSE("GPL v2");

MODULE_ALIAS_CRYPTO("ecb(aes)");
MODULE_ALIAS_CRYPTO("cbc(aes)");
MODULE_ALIAS_CRYPTO("ctr(aes)");
MODULE_ALIAS_CRYPTO("xts(aes)");

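/*
 * The core transforms below are implemented in the accompanying NEON
 * assembly (aes-neonbs-core.S).
 */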
asmlinkage void aesbs_convert_key(u8 out[], u32 const rk[], int rounds);

asmlinkage void aesbs_ecb_encrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks);
asmlinkage void aesbs_ecb_decrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks);

asmlinkage void aesbs_cbc_decrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[]);

asmlinkage void aesbs_ctr_encrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int bytes, u8 ctr[]);

asmlinkage void aesbs_xts_encrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[], int);
asmlinkage void aesbs_xts_decrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[], int);

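/*
 * rk holds the key schedule in the bit sliced representation produced by
 * aesbs_convert_key().
 */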
struct aesbs_ctx {
	int	rounds;
	u8	rk[13 * (8 * AES_BLOCK_SIZE) + 32] __aligned(AES_BLOCK_SIZE);
};

struct aesbs_cbc_ctx {
	struct aesbs_ctx	key;
	struct crypto_aes_ctx	fallback;
};

struct aesbs_xts_ctx {
	struct aesbs_ctx	key;
	struct crypto_aes_ctx	fallback;
	struct crypto_aes_ctx	tweak_key;
};

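/*
 * Expand the AES key and convert it to the bit sliced format expected by
 * the NEON routines. The conversion runs under kernel_neon_begin/end.
 */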
static int aesbs_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
			unsigned int key_len)
{
	struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct crypto_aes_ctx rk;
	int err;

	err = aes_expandkey(&rk, in_key, key_len);
	if (err)
		return err;

	ctx->rounds = 6 + key_len / 4;

	kernel_neon_begin();
	aesbs_convert_key(ctx->rk, rk.key_enc, ctx->rounds);
	kernel_neon_end();

	return 0;
}

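/*
 * Walk the request and pass full blocks to the NEON routine. For all but
 * the final step, the block count is rounded down to a multiple of the
 * walk stride (8 blocks) so the bit sliced code operates on full chunks.
 */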
static int __ecb_crypt(struct skcipher_request *req,
		       void (*fn)(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks))
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while (walk.nbytes >= AES_BLOCK_SIZE) {
		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;

		if (walk.nbytes < walk.total)
			blocks = round_down(blocks,
					    walk.stride / AES_BLOCK_SIZE);

		kernel_neon_begin();
		fn(walk.dst.virt.addr, walk.src.virt.addr, ctx->rk,
		   ctx->rounds, blocks);
		kernel_neon_end();
		err = skcipher_walk_done(&walk,
					 walk.nbytes - blocks * AES_BLOCK_SIZE);
	}

	return err;
}

static int ecb_encrypt(struct skcipher_request *req)
{
	return __ecb_crypt(req, aesbs_ecb_encrypt);
}

static int ecb_decrypt(struct skcipher_request *req)
{
	return __ecb_crypt(req, aesbs_ecb_decrypt);
}

static int aesbs_cbc_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
			    unsigned int key_len)
{
	struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
	int err;

	err = aes_expandkey(&ctx->fallback, in_key, key_len);
	if (err)
		return err;

	ctx->key.rounds = 6 + key_len / 4;

	kernel_neon_begin();
	aesbs_convert_key(ctx->key.rk, ctx->fallback.key_enc, ctx->key.rounds);
	kernel_neon_end();

	return 0;
}

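/*
 * CBC encryption is inherently sequential, so it cannot benefit from the
 * bit sliced code. Encrypt one block at a time using the scalar cipher
 * instead.
 */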
static int cbc_encrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	const struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	unsigned int nbytes;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while ((nbytes = walk.nbytes) >= AES_BLOCK_SIZE) {
		const u8 *src = walk.src.virt.addr;
		u8 *dst = walk.dst.virt.addr;
		u8 *prev = walk.iv;

		do {
			crypto_xor_cpy(dst, src, prev, AES_BLOCK_SIZE);
			__aes_arm_encrypt(ctx->fallback.key_enc,
					  ctx->key.rounds, dst, dst);
			prev = dst;
			src += AES_BLOCK_SIZE;
			dst += AES_BLOCK_SIZE;
			nbytes -= AES_BLOCK_SIZE;
		} while (nbytes >= AES_BLOCK_SIZE);
		memcpy(walk.iv, prev, AES_BLOCK_SIZE);
		err = skcipher_walk_done(&walk, nbytes);
	}
	return err;
}

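/*
 * CBC decryption can be parallelized, so hand full blocks to the bit
 * sliced NEON routine.
 */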
static int cbc_decrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while (walk.nbytes >= AES_BLOCK_SIZE) {
		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;

		if (walk.nbytes < walk.total)
			blocks = round_down(blocks,
					    walk.stride / AES_BLOCK_SIZE);

		kernel_neon_begin();
		aesbs_cbc_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
				  ctx->key.rk, ctx->key.rounds, blocks,
				  walk.iv);
		kernel_neon_end();
		err = skcipher_walk_done(&walk,
					 walk.nbytes - blocks * AES_BLOCK_SIZE);
	}

	return err;
}

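/*
 * The CTR routine takes a byte count. A final partial block is bounced
 * through a stack buffer so the NEON code can load and store a full
 * block; intermediate steps are truncated to multiples of 8 blocks.
 */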
static int ctr_encrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	u8 buf[AES_BLOCK_SIZE];
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while (walk.nbytes > 0) {
		const u8 *src = walk.src.virt.addr;
		u8 *dst = walk.dst.virt.addr;
		unsigned int bytes = walk.nbytes;

		if (unlikely(bytes < AES_BLOCK_SIZE))
			src = dst = memcpy(buf + sizeof(buf) - bytes,
					   src, bytes);
		else if (walk.nbytes < walk.total)
			bytes &= ~(8 * AES_BLOCK_SIZE - 1);

		kernel_neon_begin();
		aesbs_ctr_encrypt(dst, src, ctx->rk, ctx->rounds, bytes, walk.iv);
		kernel_neon_end();

		if (unlikely(bytes < AES_BLOCK_SIZE))
			memcpy(walk.dst.virt.addr,
			       buf + sizeof(buf) - bytes, bytes);

		err = skcipher_walk_done(&walk, walk.nbytes - bytes);
	}

	return err;
}

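/*
 * XTS uses two AES keys: the second half of the key material is expanded
 * into tweak_key, while the first half is expanded both for the scalar
 * fallback and, via aesbs_setkey(), into the bit sliced key schedule.
 */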
static int aesbs_xts_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
			    unsigned int key_len)
{
	struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
	int err;

	err = xts_verify_key(tfm, in_key, key_len);
	if (err)
		return err;

	key_len /= 2;
	err = aes_expandkey(&ctx->fallback, in_key, key_len);
	if (err)
		return err;
	err = aes_expandkey(&ctx->tweak_key, in_key + key_len, key_len);
	if (err)
		return err;

	return aesbs_setkey(tfm, in_key, key_len);
}

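/*
 * Encrypt the IV with the tweak key to produce the initial tweak, then
 * process full blocks with the bit sliced NEON routine. Requests that are
 * not a multiple of the block size are handled with ciphertext stealing:
 * the last full block and the tail are processed with the scalar cipher
 * after the NEON pass.
 */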
static int __xts_crypt(struct skcipher_request *req, bool encrypt,
		       void (*fn)(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[], int))
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
	const int rounds = ctx->key.rounds;
	int tail = req->cryptlen % AES_BLOCK_SIZE;
	struct skcipher_request subreq;
	u8 buf[2 * AES_BLOCK_SIZE];
	struct skcipher_walk walk;
	int err;

	if (req->cryptlen < AES_BLOCK_SIZE)
		return -EINVAL;

	if (unlikely(tail)) {
		skcipher_request_set_tfm(&subreq, tfm);
		skcipher_request_set_callback(&subreq,
					      skcipher_request_flags(req),
					      NULL, NULL);
		skcipher_request_set_crypt(&subreq, req->src, req->dst,
					   req->cryptlen - tail, req->iv);
		req = &subreq;
	}

	err = skcipher_walk_virt(&walk, req, true);
	if (err)
		return err;

	__aes_arm_encrypt(ctx->tweak_key.key_enc, rounds, walk.iv, walk.iv);

	while (walk.nbytes >= AES_BLOCK_SIZE) {
		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;
		int reorder_last_tweak = !encrypt && tail > 0;

		if (walk.nbytes < walk.total) {
			blocks = round_down(blocks,
					    walk.stride / AES_BLOCK_SIZE);
			reorder_last_tweak = 0;
		}

		kernel_neon_begin();
		fn(walk.dst.virt.addr, walk.src.virt.addr, ctx->key.rk,
		   rounds, blocks, walk.iv, reorder_last_tweak);
		kernel_neon_end();
		err = skcipher_walk_done(&walk,
					 walk.nbytes - blocks * AES_BLOCK_SIZE);
	}

	if (err || likely(!tail))
		return err;

	/* handle ciphertext stealing */
	scatterwalk_map_and_copy(buf, req->dst, req->cryptlen - AES_BLOCK_SIZE,
				 AES_BLOCK_SIZE, 0);
	memcpy(buf + AES_BLOCK_SIZE, buf, tail);
	scatterwalk_map_and_copy(buf, req->src, req->cryptlen, tail, 0);

	crypto_xor(buf, req->iv, AES_BLOCK_SIZE);

	if (encrypt)
		__aes_arm_encrypt(ctx->fallback.key_enc, rounds, buf, buf);
	else
		__aes_arm_decrypt(ctx->fallback.key_dec, rounds, buf, buf);

	crypto_xor(buf, req->iv, AES_BLOCK_SIZE);

	scatterwalk_map_and_copy(buf, req->dst, req->cryptlen - AES_BLOCK_SIZE,
				 AES_BLOCK_SIZE + tail, 1);
	return 0;
}

static int xts_encrypt(struct skcipher_request *req)
{
	return __xts_crypt(req, true, aesbs_xts_encrypt);
}

static int xts_decrypt(struct skcipher_request *req)
{
	return __xts_crypt(req, false, aesbs_xts_decrypt);
}

static struct skcipher_alg aes_algs[] = { {
	.base.cra_name		= "ecb(aes)",
	.base.cra_driver_name	= "ecb-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= AES_BLOCK_SIZE,
	.base.cra_ctxsize	= sizeof(struct aesbs_ctx),
	.base.cra_module	= THIS_MODULE,

	.min_keysize		= AES_MIN_KEY_SIZE,
	.max_keysize		= AES_MAX_KEY_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.setkey			= aesbs_setkey,
	.encrypt		= ecb_encrypt,
	.decrypt		= ecb_decrypt,
}, {
	.base.cra_name		= "cbc(aes)",
	.base.cra_driver_name	= "cbc-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= AES_BLOCK_SIZE,
	.base.cra_ctxsize	= sizeof(struct aesbs_cbc_ctx),
	.base.cra_module	= THIS_MODULE,

	.min_keysize		= AES_MIN_KEY_SIZE,
	.max_keysize		= AES_MAX_KEY_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.ivsize			= AES_BLOCK_SIZE,
	.setkey			= aesbs_cbc_setkey,
	.encrypt		= cbc_encrypt,
	.decrypt		= cbc_decrypt,
}, {
	.base.cra_name		= "ctr(aes)",
	.base.cra_driver_name	= "ctr-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= 1,
	.base.cra_ctxsize	= sizeof(struct aesbs_ctx),
	.base.cra_module	= THIS_MODULE,

	.min_keysize		= AES_MIN_KEY_SIZE,
	.max_keysize		= AES_MAX_KEY_SIZE,
	.chunksize		= AES_BLOCK_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.ivsize			= AES_BLOCK_SIZE,
	.setkey			= aesbs_setkey,
	.encrypt		= ctr_encrypt,
	.decrypt		= ctr_encrypt,
}, {
	.base.cra_name		= "xts(aes)",
	.base.cra_driver_name	= "xts-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= AES_BLOCK_SIZE,
	.base.cra_ctxsize	= sizeof(struct aesbs_xts_ctx),
	.base.cra_module	= THIS_MODULE,

	.min_keysize		= 2 * AES_MIN_KEY_SIZE,
	.max_keysize		= 2 * AES_MAX_KEY_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.ivsize			= AES_BLOCK_SIZE,
	.setkey			= aesbs_xts_setkey,
	.encrypt		= xts_encrypt,
	.decrypt		= xts_decrypt,
} };

static void aes_exit(void)
{
	crypto_unregister_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
}

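/* Register the skcipher algorithms only if the CPU supports NEON. */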
static int __init aes_init(void)
{
	if (!(elf_hwcap & HWCAP_NEON))
		return -ENODEV;

	return crypto_register_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
}

module_init(aes_init);
module_exit(aes_exit);