xref: /linux/arch/arm64/crypto/sm4-ce-glue.c (revision 0e9ab8e4d44ae9d9aaf213bfd2c90bbe7289337b)
1 /* SPDX-License-Identifier: GPL-2.0-or-later */
2 /*
3  * SM4 Cipher Algorithm, using ARMv8 Crypto Extensions
4  * as specified in
5  * https://tools.ietf.org/id/draft-ribose-cfrg-sm4-10.html
6  *
7  * Copyright (C) 2022, Alibaba Group.
8  * Copyright (C) 2022 Tianjia Zhang <tianjia.zhang@linux.alibaba.com>
9  */
10 
#include <linux/module.h>
#include <linux/crypto.h>
#include <linux/kernel.h>
#include <linux/cpufeature.h>
#include <linux/string.h>
#include <asm/neon.h>
#include <asm/simd.h>
#include <crypto/b128ops.h>
#include <crypto/internal/simd.h>
#include <crypto/internal/skcipher.h>
#include <crypto/internal/hash.h>
#include <crypto/scatterwalk.h>
#include <crypto/xts.h>
#include <crypto/sm4.h>
24 
/*
 * Number of whole 16-byte SM4 blocks in @nbytes; a trailing partial block is
 * dropped by the shift.  Only used with unsigned byte counts in this file.
 */
#define BYTES2BLKS(nbytes)	((nbytes) >> 4)
26 
/*
 * Assembly primitives (ARMv8 SM4 Crypto Extensions), implemented outside
 * this file.  Every caller in this file wraps them in
 * kernel_neon_begin()/kernel_neon_end().
 *
 * Round keys are the u32 schedules from struct sm4_ctx.  Lengths named
 * nblks/nblocks count 16-byte blocks.  Lengths named nbytes count bytes and
 * may include a partial final block; the CTS and XTS callers rely on this.
 */

/* Expand @key into encryption and decryption schedules using @fk/@ck. */
asmlinkage void sm4_ce_expand_key(const u8 *key, u32 *rkey_enc, u32 *rkey_dec,
				  const u32 *fk, const u32 *ck);
/* Run a single block through the cipher with @rkey. */
asmlinkage void sm4_ce_crypt_block(const u32 *rkey, u8 *dst, const u8 *src);
/* Run @nblks blocks independently through the cipher (used for ECB). */
asmlinkage void sm4_ce_crypt(const u32 *rkey, u8 *dst, const u8 *src,
			     unsigned int nblks);
/*
 * Chaining modes.  Callers pass walk.iv and reuse it across walk chunks.
 * NOTE(review): presumably the routines write the updated chaining value
 * back to @iv.  Confirm against the assembly.
 */
asmlinkage void sm4_ce_cbc_enc(const u32 *rkey, u8 *dst, const u8 *src,
			       u8 *iv, unsigned int nblocks);
asmlinkage void sm4_ce_cbc_dec(const u32 *rkey, u8 *dst, const u8 *src,
			       u8 *iv, unsigned int nblocks);
/* CBC ciphertext stealing over the final (at most two-block) tail. */
asmlinkage void sm4_ce_cbc_cts_enc(const u32 *rkey, u8 *dst, const u8 *src,
				   u8 *iv, unsigned int nbytes);
asmlinkage void sm4_ce_cbc_cts_dec(const u32 *rkey, u8 *dst, const u8 *src,
				   u8 *iv, unsigned int nbytes);
asmlinkage void sm4_ce_cfb_enc(const u32 *rkey, u8 *dst, const u8 *src,
			       u8 *iv, unsigned int nblks);
asmlinkage void sm4_ce_cfb_dec(const u32 *rkey, u8 *dst, const u8 *src,
			       u8 *iv, unsigned int nblks);
asmlinkage void sm4_ce_ctr_enc(const u32 *rkey, u8 *dst, const u8 *src,
			       u8 *iv, unsigned int nblks);
/*
 * XTS.  @rkey1 is the data-key schedule and @rkey2_enc the tweak key's
 * encryption schedule.  Callers pass @rkey2_enc on the first call of a
 * request and NULL afterwards.
 */
asmlinkage void sm4_ce_xts_enc(const u32 *rkey1, u8 *dst, const u8 *src,
			       u8 *tweak, unsigned int nbytes,
			       const u32 *rkey2_enc);
asmlinkage void sm4_ce_xts_dec(const u32 *rkey1, u8 *dst, const u8 *src,
			       u8 *tweak, unsigned int nbytes,
			       const u32 *rkey2_enc);
/*
 * CBC-MAC style update of @digest over @nblocks blocks of @src.  Callers set
 * @enc_before when @digest holds a buffered block that is not yet encrypted,
 * and @enc_after when more bytes will be XORed in later.
 * NOTE(review): presumably these request an extra encryption before or
 * after the block loop.  Confirm.
 */
asmlinkage void sm4_ce_mac_update(const u32 *rkey_enc, u8 *digest,
				  const u8 *src, unsigned int nblocks,
				  bool enc_before, bool enc_after);
55 
/*
 * Primitives exported for reuse by other modules.  NOTE(review): presumably
 * the SM4 AEAD (CCM/GCM) glue drivers.  Confirm the users before removing.
 */
EXPORT_SYMBOL(sm4_ce_expand_key);
EXPORT_SYMBOL(sm4_ce_crypt_block);
EXPORT_SYMBOL(sm4_ce_cbc_enc);
EXPORT_SYMBOL(sm4_ce_cfb_enc);
60 
/*
 * Per-tfm context for "xts(sm4)".  key1 en/decrypts the data.  Only key2's
 * encryption schedule is used, for the tweak (see sm4_xts_crypt()).
 */
struct sm4_xts_ctx {
	struct sm4_ctx key1;
	struct sm4_ctx key2;
};
65 
/*
 * Per-tfm context for the SM4 MACs.  @consts holds two derived 16-byte
 * constants for CMAC/XCBC; those algorithms size cra_ctxsize for them.
 * cbcmac does not allocate or use @consts.
 */
struct sm4_mac_tfm_ctx {
	struct sm4_ctx key;
	u8 __aligned(8) consts[];
};
70 
/*
 * Per-request MAC state.  @digest is the running chaining value.  @len
 * (0..SM4_BLOCK_SIZE) counts bytes already XORed into @digest that have not
 * been encrypted yet.
 */
struct sm4_mac_desc_ctx {
	unsigned int len;
	u8 digest[SM4_BLOCK_SIZE];
};
75 
76 static int sm4_setkey(struct crypto_skcipher *tfm, const u8 *key,
77 		      unsigned int key_len)
78 {
79 	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
80 
81 	if (key_len != SM4_KEY_SIZE)
82 		return -EINVAL;
83 
84 	kernel_neon_begin();
85 	sm4_ce_expand_key(key, ctx->rkey_enc, ctx->rkey_dec,
86 			  crypto_sm4_fk, crypto_sm4_ck);
87 	kernel_neon_end();
88 	return 0;
89 }
90 
91 static int sm4_xts_setkey(struct crypto_skcipher *tfm, const u8 *key,
92 			  unsigned int key_len)
93 {
94 	struct sm4_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
95 	int ret;
96 
97 	if (key_len != SM4_KEY_SIZE * 2)
98 		return -EINVAL;
99 
100 	ret = xts_verify_key(tfm, key, key_len);
101 	if (ret)
102 		return ret;
103 
104 	kernel_neon_begin();
105 	sm4_ce_expand_key(key, ctx->key1.rkey_enc,
106 			  ctx->key1.rkey_dec, crypto_sm4_fk, crypto_sm4_ck);
107 	sm4_ce_expand_key(&key[SM4_KEY_SIZE], ctx->key2.rkey_enc,
108 			  ctx->key2.rkey_dec, crypto_sm4_fk, crypto_sm4_ck);
109 	kernel_neon_end();
110 
111 	return 0;
112 }
113 
114 static int sm4_ecb_do_crypt(struct skcipher_request *req, const u32 *rkey)
115 {
116 	struct skcipher_walk walk;
117 	unsigned int nbytes;
118 	int err;
119 
120 	err = skcipher_walk_virt(&walk, req, false);
121 
122 	while ((nbytes = walk.nbytes) > 0) {
123 		const u8 *src = walk.src.virt.addr;
124 		u8 *dst = walk.dst.virt.addr;
125 		unsigned int nblks;
126 
127 		kernel_neon_begin();
128 
129 		nblks = BYTES2BLKS(nbytes);
130 		if (nblks) {
131 			sm4_ce_crypt(rkey, dst, src, nblks);
132 			nbytes -= nblks * SM4_BLOCK_SIZE;
133 		}
134 
135 		kernel_neon_end();
136 
137 		err = skcipher_walk_done(&walk, nbytes);
138 	}
139 
140 	return err;
141 }
142 
143 static int sm4_ecb_encrypt(struct skcipher_request *req)
144 {
145 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
146 	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
147 
148 	return sm4_ecb_do_crypt(req, ctx->rkey_enc);
149 }
150 
151 static int sm4_ecb_decrypt(struct skcipher_request *req)
152 {
153 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
154 	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
155 
156 	return sm4_ecb_do_crypt(req, ctx->rkey_dec);
157 }
158 
159 static int sm4_cbc_crypt(struct skcipher_request *req,
160 			 struct sm4_ctx *ctx, bool encrypt)
161 {
162 	struct skcipher_walk walk;
163 	unsigned int nbytes;
164 	int err;
165 
166 	err = skcipher_walk_virt(&walk, req, false);
167 	if (err)
168 		return err;
169 
170 	while ((nbytes = walk.nbytes) > 0) {
171 		const u8 *src = walk.src.virt.addr;
172 		u8 *dst = walk.dst.virt.addr;
173 		unsigned int nblocks;
174 
175 		nblocks = nbytes / SM4_BLOCK_SIZE;
176 		if (nblocks) {
177 			kernel_neon_begin();
178 
179 			if (encrypt)
180 				sm4_ce_cbc_enc(ctx->rkey_enc, dst, src,
181 					       walk.iv, nblocks);
182 			else
183 				sm4_ce_cbc_dec(ctx->rkey_dec, dst, src,
184 					       walk.iv, nblocks);
185 
186 			kernel_neon_end();
187 		}
188 
189 		err = skcipher_walk_done(&walk, nbytes % SM4_BLOCK_SIZE);
190 	}
191 
192 	return err;
193 }
194 
195 static int sm4_cbc_encrypt(struct skcipher_request *req)
196 {
197 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
198 	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
199 
200 	return sm4_cbc_crypt(req, ctx, true);
201 }
202 
203 static int sm4_cbc_decrypt(struct skcipher_request *req)
204 {
205 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
206 	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
207 
208 	return sm4_cbc_crypt(req, ctx, false);
209 }
210 
/*
 * CBC with ciphertext stealing ("cts(cbc(sm4))").
 *
 * A request of exactly one block is plain CBC.  Otherwise an on-stack
 * sub-request splits the work in two:
 *   1. all but the last two (possibly partial) blocks go through
 *      sm4_cbc_crypt() with req->iv as the IV;
 *   2. the remaining 17..32 bytes go to the CE CTS routine in a single call,
 *      again starting from req->iv.
 * NOTE(review): step 2 presumably relies on the walk in step 1 writing the
 * chained IV back to req->iv.
 *
 * @encrypt: true to encrypt with rkey_enc, false to decrypt with rkey_dec.
 *
 * Return: 0 on success, -EINVAL if the request is shorter than one block,
 * or a negative errno from the skcipher walk.
 */
static int sm4_cbc_cts_crypt(struct skcipher_request *req, bool encrypt)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct scatterlist *src = req->src;
	struct scatterlist *dst = req->dst;
	struct scatterlist sg_src[2], sg_dst[2];
	struct skcipher_request subreq;
	struct skcipher_walk walk;
	int cbc_blocks;
	int err;

	if (req->cryptlen < SM4_BLOCK_SIZE)
		return -EINVAL;

	/* Nothing to steal from: a single block is ordinary CBC. */
	if (req->cryptlen == SM4_BLOCK_SIZE)
		return sm4_cbc_crypt(req, ctx, encrypt);

	skcipher_request_set_tfm(&subreq, tfm);
	skcipher_request_set_callback(&subreq, skcipher_request_flags(req),
				      NULL, NULL);

	/* handle the CBC cryption part */
	/* Everything except the final two blocks (the last may be partial). */
	cbc_blocks = DIV_ROUND_UP(req->cryptlen, SM4_BLOCK_SIZE) - 2;
	if (cbc_blocks) {
		skcipher_request_set_crypt(&subreq, src, dst,
					   cbc_blocks * SM4_BLOCK_SIZE,
					   req->iv);

		err = sm4_cbc_crypt(&subreq, ctx, encrypt);
		if (err)
			return err;

		/* Skip the scatterlists past the bytes just processed. */
		dst = src = scatterwalk_ffwd(sg_src, src, subreq.cryptlen);
		if (req->dst != req->src)
			dst = scatterwalk_ffwd(sg_dst, req->dst,
					       subreq.cryptlen);
	}

	/* handle ciphertext stealing */
	skcipher_request_set_crypt(&subreq, src, dst,
				   req->cryptlen - cbc_blocks * SM4_BLOCK_SIZE,
				   req->iv);

	err = skcipher_walk_virt(&walk, &subreq, false);
	if (err)
		return err;

	/*
	 * One CE call covers the whole 17..32 byte tail.  NOTE(review): this
	 * assumes the first walk step maps all of it.  The algorithm's
	 * .walksize of 2 * SM4_BLOCK_SIZE is presumably what guarantees that.
	 */
	kernel_neon_begin();

	if (encrypt)
		sm4_ce_cbc_cts_enc(ctx->rkey_enc, walk.dst.virt.addr,
				   walk.src.virt.addr, walk.iv, walk.nbytes);
	else
		sm4_ce_cbc_cts_dec(ctx->rkey_dec, walk.dst.virt.addr,
				   walk.src.virt.addr, walk.iv, walk.nbytes);

	kernel_neon_end();

	return skcipher_walk_done(&walk, 0);
}
272 
273 static int sm4_cbc_cts_encrypt(struct skcipher_request *req)
274 {
275 	return sm4_cbc_cts_crypt(req, true);
276 }
277 
278 static int sm4_cbc_cts_decrypt(struct skcipher_request *req)
279 {
280 	return sm4_cbc_cts_crypt(req, false);
281 }
282 
283 static int sm4_cfb_encrypt(struct skcipher_request *req)
284 {
285 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
286 	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
287 	struct skcipher_walk walk;
288 	unsigned int nbytes;
289 	int err;
290 
291 	err = skcipher_walk_virt(&walk, req, false);
292 
293 	while ((nbytes = walk.nbytes) > 0) {
294 		const u8 *src = walk.src.virt.addr;
295 		u8 *dst = walk.dst.virt.addr;
296 		unsigned int nblks;
297 
298 		kernel_neon_begin();
299 
300 		nblks = BYTES2BLKS(nbytes);
301 		if (nblks) {
302 			sm4_ce_cfb_enc(ctx->rkey_enc, dst, src, walk.iv, nblks);
303 			dst += nblks * SM4_BLOCK_SIZE;
304 			src += nblks * SM4_BLOCK_SIZE;
305 			nbytes -= nblks * SM4_BLOCK_SIZE;
306 		}
307 
308 		/* tail */
309 		if (walk.nbytes == walk.total && nbytes > 0) {
310 			u8 keystream[SM4_BLOCK_SIZE];
311 
312 			sm4_ce_crypt_block(ctx->rkey_enc, keystream, walk.iv);
313 			crypto_xor_cpy(dst, src, keystream, nbytes);
314 			nbytes = 0;
315 		}
316 
317 		kernel_neon_end();
318 
319 		err = skcipher_walk_done(&walk, nbytes);
320 	}
321 
322 	return err;
323 }
324 
325 static int sm4_cfb_decrypt(struct skcipher_request *req)
326 {
327 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
328 	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
329 	struct skcipher_walk walk;
330 	unsigned int nbytes;
331 	int err;
332 
333 	err = skcipher_walk_virt(&walk, req, false);
334 
335 	while ((nbytes = walk.nbytes) > 0) {
336 		const u8 *src = walk.src.virt.addr;
337 		u8 *dst = walk.dst.virt.addr;
338 		unsigned int nblks;
339 
340 		kernel_neon_begin();
341 
342 		nblks = BYTES2BLKS(nbytes);
343 		if (nblks) {
344 			sm4_ce_cfb_dec(ctx->rkey_enc, dst, src, walk.iv, nblks);
345 			dst += nblks * SM4_BLOCK_SIZE;
346 			src += nblks * SM4_BLOCK_SIZE;
347 			nbytes -= nblks * SM4_BLOCK_SIZE;
348 		}
349 
350 		/* tail */
351 		if (walk.nbytes == walk.total && nbytes > 0) {
352 			u8 keystream[SM4_BLOCK_SIZE];
353 
354 			sm4_ce_crypt_block(ctx->rkey_enc, keystream, walk.iv);
355 			crypto_xor_cpy(dst, src, keystream, nbytes);
356 			nbytes = 0;
357 		}
358 
359 		kernel_neon_end();
360 
361 		err = skcipher_walk_done(&walk, nbytes);
362 	}
363 
364 	return err;
365 }
366 
367 static int sm4_ctr_crypt(struct skcipher_request *req)
368 {
369 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
370 	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
371 	struct skcipher_walk walk;
372 	unsigned int nbytes;
373 	int err;
374 
375 	err = skcipher_walk_virt(&walk, req, false);
376 
377 	while ((nbytes = walk.nbytes) > 0) {
378 		const u8 *src = walk.src.virt.addr;
379 		u8 *dst = walk.dst.virt.addr;
380 		unsigned int nblks;
381 
382 		kernel_neon_begin();
383 
384 		nblks = BYTES2BLKS(nbytes);
385 		if (nblks) {
386 			sm4_ce_ctr_enc(ctx->rkey_enc, dst, src, walk.iv, nblks);
387 			dst += nblks * SM4_BLOCK_SIZE;
388 			src += nblks * SM4_BLOCK_SIZE;
389 			nbytes -= nblks * SM4_BLOCK_SIZE;
390 		}
391 
392 		/* tail */
393 		if (walk.nbytes == walk.total && nbytes > 0) {
394 			u8 keystream[SM4_BLOCK_SIZE];
395 
396 			sm4_ce_crypt_block(ctx->rkey_enc, keystream, walk.iv);
397 			crypto_inc(walk.iv, SM4_BLOCK_SIZE);
398 			crypto_xor_cpy(dst, src, keystream, nbytes);
399 			nbytes = 0;
400 		}
401 
402 		kernel_neon_end();
403 
404 		err = skcipher_walk_done(&walk, nbytes);
405 	}
406 
407 	return err;
408 }
409 
/*
 * XTS worker for "xts(sm4)".  Includes ciphertext stealing for lengths that
 * are not a multiple of the block size.
 *
 * Direct path: taken when the length is block-aligned, or when the first
 * walk step already covers the whole request.  Each chunk goes straight to
 * the CE routine.  Non-final chunks are rounded down to whole blocks, so only
 * the final chunk can carry the partial tail.
 *
 * Split path: taken when there is a tail but the first walk step is short.
 * The walk restarts over all but the last two blocks via a sub-request.  The
 * final SM4_BLOCK_SIZE + tail bytes are then handled by one more CE call.
 *
 * rkey2_enc is passed only on the first CE call of the request and NULL
 * afterwards.  NOTE(review): presumably the tweak is encrypted once and then
 * carried in walk.iv.  Confirm against the assembly.
 *
 * Return: 0 on success, -EINVAL if the request is shorter than one block,
 * or a negative errno from the skcipher walk.
 */
static int sm4_xts_crypt(struct skcipher_request *req, bool encrypt)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct sm4_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
	int tail = req->cryptlen % SM4_BLOCK_SIZE;
	const u32 *rkey2_enc = ctx->key2.rkey_enc;
	struct scatterlist sg_src[2], sg_dst[2];
	/* Only initialised on the split path; every later use is behind tail != 0. */
	struct skcipher_request subreq;
	struct scatterlist *src, *dst;
	struct skcipher_walk walk;
	unsigned int nbytes;
	int err;

	if (req->cryptlen < SM4_BLOCK_SIZE)
		return -EINVAL;

	err = skcipher_walk_virt(&walk, req, false);
	if (err)
		return err;

	if (unlikely(tail > 0 && walk.nbytes < walk.total)) {
		/* Full blocks to process before the two-block stealing tail. */
		int nblocks = DIV_ROUND_UP(req->cryptlen, SM4_BLOCK_SIZE) - 2;

		skcipher_walk_abort(&walk);

		skcipher_request_set_tfm(&subreq, tfm);
		skcipher_request_set_callback(&subreq,
					      skcipher_request_flags(req),
					      NULL, NULL);
		skcipher_request_set_crypt(&subreq, req->src, req->dst,
					   nblocks * SM4_BLOCK_SIZE, req->iv);

		err = skcipher_walk_virt(&walk, &subreq, false);
		if (err)
			return err;
	} else {
		/* The direct path handles any tail itself, so no stealing step. */
		tail = 0;
	}

	while ((nbytes = walk.nbytes) >= SM4_BLOCK_SIZE) {
		/* Only the last chunk may pass a partial block to the CE code. */
		if (nbytes < walk.total)
			nbytes &= ~(SM4_BLOCK_SIZE - 1);

		kernel_neon_begin();

		if (encrypt)
			sm4_ce_xts_enc(ctx->key1.rkey_enc, walk.dst.virt.addr,
				       walk.src.virt.addr, walk.iv, nbytes,
				       rkey2_enc);
		else
			sm4_ce_xts_dec(ctx->key1.rkey_dec, walk.dst.virt.addr,
				       walk.src.virt.addr, walk.iv, nbytes,
				       rkey2_enc);

		kernel_neon_end();

		rkey2_enc = NULL;

		err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
		if (err)
			return err;
	}

	if (likely(tail == 0))
		return 0;

	/* handle ciphertext stealing */
	/*
	 * If the split path had no full-block prefix (nblocks == 0), the loop
	 * above never ran and rkey2_enc is still non-NULL here.
	 */

	dst = src = scatterwalk_ffwd(sg_src, req->src, subreq.cryptlen);
	if (req->dst != req->src)
		dst = scatterwalk_ffwd(sg_dst, req->dst, subreq.cryptlen);

	skcipher_request_set_crypt(&subreq, src, dst, SM4_BLOCK_SIZE + tail,
				   req->iv);

	err = skcipher_walk_virt(&walk, &subreq, false);
	if (err)
		return err;

	kernel_neon_begin();

	if (encrypt)
		sm4_ce_xts_enc(ctx->key1.rkey_enc, walk.dst.virt.addr,
			       walk.src.virt.addr, walk.iv, walk.nbytes,
			       rkey2_enc);
	else
		sm4_ce_xts_dec(ctx->key1.rkey_dec, walk.dst.virt.addr,
			       walk.src.virt.addr, walk.iv, walk.nbytes,
			       rkey2_enc);

	kernel_neon_end();

	return skcipher_walk_done(&walk, 0);
}
504 
505 static int sm4_xts_encrypt(struct skcipher_request *req)
506 {
507 	return sm4_xts_crypt(req, true);
508 }
509 
510 static int sm4_xts_decrypt(struct skcipher_request *req)
511 {
512 	return sm4_xts_crypt(req, false);
513 }
514 
/*
 * Skcipher algorithms registered by sm4_init().  All use cra_priority 400.
 * NOTE(review): presumably chosen to outrank other SM4 implementations.
 * Confirm against the other SM4 drivers.
 */
static struct skcipher_alg sm4_algs[] = {
	{
		/* ECB: block-sized granularity, no IV. */
		.base = {
			.cra_name		= "ecb(sm4)",
			.cra_driver_name	= "ecb-sm4-ce",
			.cra_priority		= 400,
			.cra_blocksize		= SM4_BLOCK_SIZE,
			.cra_ctxsize		= sizeof(struct sm4_ctx),
			.cra_module		= THIS_MODULE,
		},
		.min_keysize	= SM4_KEY_SIZE,
		.max_keysize	= SM4_KEY_SIZE,
		.setkey		= sm4_setkey,
		.encrypt	= sm4_ecb_encrypt,
		.decrypt	= sm4_ecb_decrypt,
	}, {
		.base = {
			.cra_name		= "cbc(sm4)",
			.cra_driver_name	= "cbc-sm4-ce",
			.cra_priority		= 400,
			.cra_blocksize		= SM4_BLOCK_SIZE,
			.cra_ctxsize		= sizeof(struct sm4_ctx),
			.cra_module		= THIS_MODULE,
		},
		.min_keysize	= SM4_KEY_SIZE,
		.max_keysize	= SM4_KEY_SIZE,
		.ivsize		= SM4_BLOCK_SIZE,
		.setkey		= sm4_setkey,
		.encrypt	= sm4_cbc_encrypt,
		.decrypt	= sm4_cbc_decrypt,
	}, {
		/*
		 * CFB and CTR act as stream modes: cra_blocksize 1 accepts any
		 * length, while chunksize reports the 16-byte processing unit.
		 */
		.base = {
			.cra_name		= "cfb(sm4)",
			.cra_driver_name	= "cfb-sm4-ce",
			.cra_priority		= 400,
			.cra_blocksize		= 1,
			.cra_ctxsize		= sizeof(struct sm4_ctx),
			.cra_module		= THIS_MODULE,
		},
		.min_keysize	= SM4_KEY_SIZE,
		.max_keysize	= SM4_KEY_SIZE,
		.ivsize		= SM4_BLOCK_SIZE,
		.chunksize	= SM4_BLOCK_SIZE,
		.setkey		= sm4_setkey,
		.encrypt	= sm4_cfb_encrypt,
		.decrypt	= sm4_cfb_decrypt,
	}, {
		.base = {
			.cra_name		= "ctr(sm4)",
			.cra_driver_name	= "ctr-sm4-ce",
			.cra_priority		= 400,
			.cra_blocksize		= 1,
			.cra_ctxsize		= sizeof(struct sm4_ctx),
			.cra_module		= THIS_MODULE,
		},
		.min_keysize	= SM4_KEY_SIZE,
		.max_keysize	= SM4_KEY_SIZE,
		.ivsize		= SM4_BLOCK_SIZE,
		.chunksize	= SM4_BLOCK_SIZE,
		.setkey		= sm4_setkey,
		.encrypt	= sm4_ctr_crypt,
		.decrypt	= sm4_ctr_crypt,
	}, {
		/*
		 * CTS and XTS: the stealing step handles up to two blocks in
		 * one CE call, hence the two-block walksize.
		 */
		.base = {
			.cra_name		= "cts(cbc(sm4))",
			.cra_driver_name	= "cts-cbc-sm4-ce",
			.cra_priority		= 400,
			.cra_blocksize		= SM4_BLOCK_SIZE,
			.cra_ctxsize		= sizeof(struct sm4_ctx),
			.cra_module		= THIS_MODULE,
		},
		.min_keysize	= SM4_KEY_SIZE,
		.max_keysize	= SM4_KEY_SIZE,
		.ivsize		= SM4_BLOCK_SIZE,
		.walksize	= SM4_BLOCK_SIZE * 2,
		.setkey		= sm4_setkey,
		.encrypt	= sm4_cbc_cts_encrypt,
		.decrypt	= sm4_cbc_cts_decrypt,
	}, {
		.base = {
			.cra_name		= "xts(sm4)",
			.cra_driver_name	= "xts-sm4-ce",
			.cra_priority		= 400,
			.cra_blocksize		= SM4_BLOCK_SIZE,
			.cra_ctxsize		= sizeof(struct sm4_xts_ctx),
			.cra_module		= THIS_MODULE,
		},
		.min_keysize	= SM4_KEY_SIZE * 2,
		.max_keysize	= SM4_KEY_SIZE * 2,
		.ivsize		= SM4_BLOCK_SIZE,
		.walksize	= SM4_BLOCK_SIZE * 2,
		.setkey		= sm4_xts_setkey,
		.encrypt	= sm4_xts_encrypt,
		.decrypt	= sm4_xts_decrypt,
	}
};
611 
612 static int sm4_cbcmac_setkey(struct crypto_shash *tfm, const u8 *key,
613 			     unsigned int key_len)
614 {
615 	struct sm4_mac_tfm_ctx *ctx = crypto_shash_ctx(tfm);
616 
617 	if (key_len != SM4_KEY_SIZE)
618 		return -EINVAL;
619 
620 	kernel_neon_begin();
621 	sm4_ce_expand_key(key, ctx->key.rkey_enc, ctx->key.rkey_dec,
622 			  crypto_sm4_fk, crypto_sm4_ck);
623 	kernel_neon_end();
624 
625 	return 0;
626 }
627 
628 static int sm4_cmac_setkey(struct crypto_shash *tfm, const u8 *key,
629 			   unsigned int key_len)
630 {
631 	struct sm4_mac_tfm_ctx *ctx = crypto_shash_ctx(tfm);
632 	be128 *consts = (be128 *)ctx->consts;
633 	u64 a, b;
634 
635 	if (key_len != SM4_KEY_SIZE)
636 		return -EINVAL;
637 
638 	memset(consts, 0, SM4_BLOCK_SIZE);
639 
640 	kernel_neon_begin();
641 
642 	sm4_ce_expand_key(key, ctx->key.rkey_enc, ctx->key.rkey_dec,
643 			  crypto_sm4_fk, crypto_sm4_ck);
644 
645 	/* encrypt the zero block */
646 	sm4_ce_crypt_block(ctx->key.rkey_enc, (u8 *)consts, (const u8 *)consts);
647 
648 	kernel_neon_end();
649 
650 	/* gf(2^128) multiply zero-ciphertext with u and u^2 */
651 	a = be64_to_cpu(consts[0].a);
652 	b = be64_to_cpu(consts[0].b);
653 	consts[0].a = cpu_to_be64((a << 1) | (b >> 63));
654 	consts[0].b = cpu_to_be64((b << 1) ^ ((a >> 63) ? 0x87 : 0));
655 
656 	a = be64_to_cpu(consts[0].a);
657 	b = be64_to_cpu(consts[0].b);
658 	consts[1].a = cpu_to_be64((a << 1) | (b >> 63));
659 	consts[1].b = cpu_to_be64((b << 1) ^ ((a >> 63) ? 0x87 : 0));
660 
661 	return 0;
662 }
663 
664 static int sm4_xcbc_setkey(struct crypto_shash *tfm, const u8 *key,
665 			   unsigned int key_len)
666 {
667 	struct sm4_mac_tfm_ctx *ctx = crypto_shash_ctx(tfm);
668 	u8 __aligned(8) key2[SM4_BLOCK_SIZE];
669 	static u8 const ks[3][SM4_BLOCK_SIZE] = {
670 		{ [0 ... SM4_BLOCK_SIZE - 1] = 0x1},
671 		{ [0 ... SM4_BLOCK_SIZE - 1] = 0x2},
672 		{ [0 ... SM4_BLOCK_SIZE - 1] = 0x3},
673 	};
674 
675 	if (key_len != SM4_KEY_SIZE)
676 		return -EINVAL;
677 
678 	kernel_neon_begin();
679 
680 	sm4_ce_expand_key(key, ctx->key.rkey_enc, ctx->key.rkey_dec,
681 			  crypto_sm4_fk, crypto_sm4_ck);
682 
683 	sm4_ce_crypt_block(ctx->key.rkey_enc, key2, ks[0]);
684 	sm4_ce_crypt(ctx->key.rkey_enc, ctx->consts, ks[1], 2);
685 
686 	sm4_ce_expand_key(key2, ctx->key.rkey_enc, ctx->key.rkey_dec,
687 			  crypto_sm4_fk, crypto_sm4_ck);
688 
689 	kernel_neon_end();
690 
691 	return 0;
692 }
693 
694 static int sm4_mac_init(struct shash_desc *desc)
695 {
696 	struct sm4_mac_desc_ctx *ctx = shash_desc_ctx(desc);
697 
698 	memset(ctx->digest, 0, SM4_BLOCK_SIZE);
699 	ctx->len = 0;
700 
701 	return 0;
702 }
703 
/*
 * Shared .update for cmac/xcbc/cbcmac(sm4).
 *
 * ctx->len counts bytes already XORed into ctx->digest but not yet
 * encrypted (0..SM4_BLOCK_SIZE).  A complete final block is deliberately
 * left unencrypted (ctx->len == SM4_BLOCK_SIZE) so that the .final hook can
 * still treat the last block specially (see sm4_cmac_final()).
 *
 * Return: always 0.
 */
static int sm4_mac_update(struct shash_desc *desc, const u8 *p,
			  unsigned int len)
{
	struct sm4_mac_tfm_ctx *tctx = crypto_shash_ctx(desc->tfm);
	struct sm4_mac_desc_ctx *ctx = shash_desc_ctx(desc);
	unsigned int l, nblocks;

	if (len == 0)
		return 0;

	/*
	 * Step 1: if bytes are buffered, or all input fits in the current
	 * block, XOR in as much as fits.  Afterwards either the input is used
	 * up, or the buffer is exactly full, or it was empty and untouched.
	 */
	if (ctx->len || ctx->len + len < SM4_BLOCK_SIZE) {
		l = min(len, SM4_BLOCK_SIZE - ctx->len);

		crypto_xor(ctx->digest + ctx->len, p, l);
		ctx->len += l;
		len -= l;
		p += l;
	}

	/* Step 2: input remains and the buffer is either empty or full. */
	if (len && (ctx->len % SM4_BLOCK_SIZE) == 0) {
		kernel_neon_begin();

		if (len < SM4_BLOCK_SIZE && ctx->len == SM4_BLOCK_SIZE) {
			/* Only a partial block follows: flush the buffered one. */
			sm4_ce_crypt_block(tctx->key.rkey_enc,
					   ctx->digest, ctx->digest);
			ctx->len = 0;
		} else {
			nblocks = len / SM4_BLOCK_SIZE;
			len %= SM4_BLOCK_SIZE;

			/*
			 * enc_before: a buffered full block must be encrypted
			 * first.  enc_after: a partial tail will be XORed in
			 * next.  Otherwise the last block stays unencrypted
			 * for .final.
			 */
			sm4_ce_mac_update(tctx->key.rkey_enc, ctx->digest, p,
					  nblocks, (ctx->len == SM4_BLOCK_SIZE),
					  (len != 0));

			p += nblocks * SM4_BLOCK_SIZE;

			if (len == 0)
				ctx->len = SM4_BLOCK_SIZE;
		}

		kernel_neon_end();

		/* Step 3: buffer the remaining partial block. */
		if (len) {
			crypto_xor(ctx->digest, p, len);
			ctx->len = len;
		}
	}

	return 0;
}
754 
755 static int sm4_cmac_final(struct shash_desc *desc, u8 *out)
756 {
757 	struct sm4_mac_tfm_ctx *tctx = crypto_shash_ctx(desc->tfm);
758 	struct sm4_mac_desc_ctx *ctx = shash_desc_ctx(desc);
759 	const u8 *consts = tctx->consts;
760 
761 	if (ctx->len != SM4_BLOCK_SIZE) {
762 		ctx->digest[ctx->len] ^= 0x80;
763 		consts += SM4_BLOCK_SIZE;
764 	}
765 
766 	kernel_neon_begin();
767 	sm4_ce_mac_update(tctx->key.rkey_enc, ctx->digest, consts, 1,
768 			  false, true);
769 	kernel_neon_end();
770 
771 	memcpy(out, ctx->digest, SM4_BLOCK_SIZE);
772 
773 	return 0;
774 }
775 
776 static int sm4_cbcmac_final(struct shash_desc *desc, u8 *out)
777 {
778 	struct sm4_mac_tfm_ctx *tctx = crypto_shash_ctx(desc->tfm);
779 	struct sm4_mac_desc_ctx *ctx = shash_desc_ctx(desc);
780 
781 	if (ctx->len) {
782 		kernel_neon_begin();
783 		sm4_ce_crypt_block(tctx->key.rkey_enc, ctx->digest,
784 				   ctx->digest);
785 		kernel_neon_end();
786 	}
787 
788 	memcpy(out, ctx->digest, SM4_BLOCK_SIZE);
789 
790 	return 0;
791 }
792 
/*
 * MAC (shash) algorithms registered by sm4_init().  All share the same
 * .init/.update.  cmac and xcbc reserve two extra blocks of tfm context for
 * their derived constants (sm4_mac_tfm_ctx.consts) and share
 * sm4_cmac_final().  cbcmac needs no constants.
 */
static struct shash_alg sm4_mac_algs[] = {
	{
		.base = {
			.cra_name		= "cmac(sm4)",
			.cra_driver_name	= "cmac-sm4-ce",
			.cra_priority		= 400,
			.cra_blocksize		= SM4_BLOCK_SIZE,
			.cra_ctxsize		= sizeof(struct sm4_mac_tfm_ctx)
							+ SM4_BLOCK_SIZE * 2,
			.cra_module		= THIS_MODULE,
		},
		.digestsize	= SM4_BLOCK_SIZE,
		.init		= sm4_mac_init,
		.update		= sm4_mac_update,
		.final		= sm4_cmac_final,
		.setkey		= sm4_cmac_setkey,
		.descsize	= sizeof(struct sm4_mac_desc_ctx),
	}, {
		.base = {
			.cra_name		= "xcbc(sm4)",
			.cra_driver_name	= "xcbc-sm4-ce",
			.cra_priority		= 400,
			.cra_blocksize		= SM4_BLOCK_SIZE,
			.cra_ctxsize		= sizeof(struct sm4_mac_tfm_ctx)
							+ SM4_BLOCK_SIZE * 2,
			.cra_module		= THIS_MODULE,
		},
		.digestsize	= SM4_BLOCK_SIZE,
		.init		= sm4_mac_init,
		.update		= sm4_mac_update,
		.final		= sm4_cmac_final,
		.setkey		= sm4_xcbc_setkey,
		.descsize	= sizeof(struct sm4_mac_desc_ctx),
	}, {
		.base = {
			.cra_name		= "cbcmac(sm4)",
			.cra_driver_name	= "cbcmac-sm4-ce",
			.cra_priority		= 400,
			.cra_blocksize		= 1,
			.cra_ctxsize		= sizeof(struct sm4_mac_tfm_ctx),
			.cra_module		= THIS_MODULE,
		},
		.digestsize	= SM4_BLOCK_SIZE,
		.init		= sm4_mac_init,
		.update		= sm4_mac_update,
		.final		= sm4_cbcmac_final,
		.setkey		= sm4_cbcmac_setkey,
		.descsize	= sizeof(struct sm4_mac_desc_ctx),
	}
};
843 
844 static int __init sm4_init(void)
845 {
846 	int err;
847 
848 	err = crypto_register_skciphers(sm4_algs, ARRAY_SIZE(sm4_algs));
849 	if (err)
850 		return err;
851 
852 	err = crypto_register_shashes(sm4_mac_algs, ARRAY_SIZE(sm4_mac_algs));
853 	if (err)
854 		goto out_err;
855 
856 	return 0;
857 
858 out_err:
859 	crypto_unregister_skciphers(sm4_algs, ARRAY_SIZE(sm4_algs));
860 	return err;
861 }
862 
/* Module exit: unregister everything in the reverse order of sm4_init(). */
static void __exit sm4_exit(void)
{
	crypto_unregister_shashes(sm4_mac_algs, ARRAY_SIZE(sm4_mac_algs));
	crypto_unregister_skciphers(sm4_algs, ARRAY_SIZE(sm4_algs));
}
868 
869 module_cpu_feature_match(SM4, sm4_init);
870 module_exit(sm4_exit);
871 
872 MODULE_DESCRIPTION("SM4 ECB/CBC/CFB/CTR/XTS using ARMv8 Crypto Extensions");
873 MODULE_ALIAS_CRYPTO("sm4-ce");
874 MODULE_ALIAS_CRYPTO("sm4");
875 MODULE_ALIAS_CRYPTO("ecb(sm4)");
876 MODULE_ALIAS_CRYPTO("cbc(sm4)");
877 MODULE_ALIAS_CRYPTO("cfb(sm4)");
878 MODULE_ALIAS_CRYPTO("ctr(sm4)");
879 MODULE_ALIAS_CRYPTO("cts(cbc(sm4))");
880 MODULE_ALIAS_CRYPTO("xts(sm4)");
881 MODULE_ALIAS_CRYPTO("cmac(sm4)");
882 MODULE_ALIAS_CRYPTO("xcbc(sm4)");
883 MODULE_ALIAS_CRYPTO("cbcmac(sm4)");
884 MODULE_AUTHOR("Tianjia Zhang <tianjia.zhang@linux.alibaba.com>");
885 MODULE_LICENSE("GPL v2");
886