xref: /linux/arch/arm64/crypto/sm4-ce-glue.c (revision fcad9bbf9e1a7de6c53908954ba1b1a1ab11ef1e)
/* SPDX-License-Identifier: GPL-2.0-or-later */
/*
 * SM4 Cipher Algorithm, using ARMv8 Crypto Extensions
 * as specified in
 * https://tools.ietf.org/id/draft-ribose-cfrg-sm4-10.html
 *
 * Copyright (C) 2022, Alibaba Group.
 * Copyright (C) 2022 Tianjia Zhang <tianjia.zhang@linux.alibaba.com>
 */

#include <asm/neon.h>
#include <crypto/b128ops.h>
#include <crypto/internal/hash.h>
#include <crypto/internal/skcipher.h>
#include <crypto/scatterwalk.h>
#include <crypto/sm4.h>
#include <crypto/utils.h>
#include <crypto/xts.h>
#include <linux/cpufeature.h>
#include <linux/kernel.h>
#include <linux/module.h>
#include <linux/string.h>

#define BYTES2BLKS(nbytes)	((nbytes) >> 4)

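/*
 * Low-level SM4 routines implemented with the ARMv8 Crypto Extension
 * instructions in the accompanying assembly file (sm4-ce-core.S).  They all
 * operate on NEON registers, so every call must be bracketed by
 * kernel_neon_begin()/kernel_neon_end().
 */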
asmlinkage void sm4_ce_expand_key(const u8 *key, u32 *rkey_enc, u32 *rkey_dec,
				  const u32 *fk, const u32 *ck);
asmlinkage void sm4_ce_crypt_block(const u32 *rkey, u8 *dst, const u8 *src);
asmlinkage void sm4_ce_crypt(const u32 *rkey, u8 *dst, const u8 *src,
			     unsigned int nblks);
asmlinkage void sm4_ce_cbc_enc(const u32 *rkey, u8 *dst, const u8 *src,
			       u8 *iv, unsigned int nblocks);
asmlinkage void sm4_ce_cbc_dec(const u32 *rkey, u8 *dst, const u8 *src,
			       u8 *iv, unsigned int nblocks);
asmlinkage void sm4_ce_cbc_cts_enc(const u32 *rkey, u8 *dst, const u8 *src,
				   u8 *iv, unsigned int nbytes);
asmlinkage void sm4_ce_cbc_cts_dec(const u32 *rkey, u8 *dst, const u8 *src,
				   u8 *iv, unsigned int nbytes);
asmlinkage void sm4_ce_ctr_enc(const u32 *rkey, u8 *dst, const u8 *src,
			       u8 *iv, unsigned int nblks);
asmlinkage void sm4_ce_xts_enc(const u32 *rkey1, u8 *dst, const u8 *src,
			       u8 *tweak, unsigned int nbytes,
			       const u32 *rkey2_enc);
asmlinkage void sm4_ce_xts_dec(const u32 *rkey1, u8 *dst, const u8 *src,
			       u8 *tweak, unsigned int nbytes,
			       const u32 *rkey2_enc);
asmlinkage void sm4_ce_mac_update(const u32 *rkey_enc, u8 *digest,
				  const u8 *src, unsigned int nblocks,
				  bool enc_before, bool enc_after);

EXPORT_SYMBOL(sm4_ce_expand_key);
EXPORT_SYMBOL(sm4_ce_crypt_block);
EXPORT_SYMBOL(sm4_ce_cbc_enc);

struct sm4_xts_ctx {
	struct sm4_ctx key1;
	struct sm4_ctx key2;
};

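/*
 * MAC transform context: the expanded key followed by space for the derived
 * constants (two blocks for cmac/xcbc, none for cbcmac), which is why the
 * cmac/xcbc algorithms below add SM4_BLOCK_SIZE * 2 to cra_ctxsize.
 */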
struct sm4_mac_tfm_ctx {
	struct sm4_ctx key;
	u8 __aligned(8) consts[];
};

struct sm4_mac_desc_ctx {
	u8 digest[SM4_BLOCK_SIZE];
};

static int sm4_setkey(struct crypto_skcipher *tfm, const u8 *key,
		      unsigned int key_len)
{
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);

	if (key_len != SM4_KEY_SIZE)
		return -EINVAL;

	kernel_neon_begin();
	sm4_ce_expand_key(key, ctx->rkey_enc, ctx->rkey_dec,
			  crypto_sm4_fk, crypto_sm4_ck);
	kernel_neon_end();
	return 0;
}

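/*
 * XTS takes a double-length key: the first half is expanded as the data key
 * (key1), the second half as the tweak key (key2).  xts_verify_key()
 * additionally rejects a key whose two halves are identical when FIPS mode
 * or the tfm's forbid-weak-keys flag requires it.
 */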
static int sm4_xts_setkey(struct crypto_skcipher *tfm, const u8 *key,
			  unsigned int key_len)
{
	struct sm4_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
	int ret;

	if (key_len != SM4_KEY_SIZE * 2)
		return -EINVAL;

	ret = xts_verify_key(tfm, key, key_len);
	if (ret)
		return ret;

	kernel_neon_begin();
	sm4_ce_expand_key(key, ctx->key1.rkey_enc,
			  ctx->key1.rkey_dec, crypto_sm4_fk, crypto_sm4_ck);
	sm4_ce_expand_key(&key[SM4_KEY_SIZE], ctx->key2.rkey_enc,
			  ctx->key2.rkey_dec, crypto_sm4_fk, crypto_sm4_ck);
	kernel_neon_end();

	return 0;
}

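/*
 * ECB helper: process as many whole blocks as each walk step provides;
 * any leftover bytes are handed back to the walk unprocessed.
 */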
static int sm4_ecb_do_crypt(struct skcipher_request *req, const u32 *rkey)
{
	struct skcipher_walk walk;
	unsigned int nbytes;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while ((nbytes = walk.nbytes) > 0) {
		const u8 *src = walk.src.virt.addr;
		u8 *dst = walk.dst.virt.addr;
		unsigned int nblks;

		kernel_neon_begin();

		nblks = BYTES2BLKS(nbytes);
		if (nblks) {
			sm4_ce_crypt(rkey, dst, src, nblks);
			nbytes -= nblks * SM4_BLOCK_SIZE;
		}

		kernel_neon_end();

		err = skcipher_walk_done(&walk, nbytes);
	}

	return err;
}

static int sm4_ecb_encrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);

	return sm4_ecb_do_crypt(req, ctx->rkey_enc);
}

static int sm4_ecb_decrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);

	return sm4_ecb_do_crypt(req, ctx->rkey_dec);
}

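/*
 * CBC: the assembly routines read and update the chaining value through
 * walk.iv, so the IV carries over correctly from one walk step to the next.
 */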
static int sm4_cbc_crypt(struct skcipher_request *req,
			 struct sm4_ctx *ctx, bool encrypt)
{
	struct skcipher_walk walk;
	unsigned int nbytes;
	int err;

	err = skcipher_walk_virt(&walk, req, false);
	if (err)
		return err;

	while ((nbytes = walk.nbytes) > 0) {
		const u8 *src = walk.src.virt.addr;
		u8 *dst = walk.dst.virt.addr;
		unsigned int nblocks;

		nblocks = nbytes / SM4_BLOCK_SIZE;
		if (nblocks) {
			kernel_neon_begin();

			if (encrypt)
				sm4_ce_cbc_enc(ctx->rkey_enc, dst, src,
					       walk.iv, nblocks);
			else
				sm4_ce_cbc_dec(ctx->rkey_dec, dst, src,
					       walk.iv, nblocks);

			kernel_neon_end();
		}

		err = skcipher_walk_done(&walk, nbytes % SM4_BLOCK_SIZE);
	}

	return err;
}

static int sm4_cbc_encrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);

	return sm4_cbc_crypt(req, ctx, true);
}

static int sm4_cbc_decrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);

	return sm4_cbc_crypt(req, ctx, false);
}

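/*
 * CBC with ciphertext stealing (CS3, as implemented by the kernel's cts
 * template): everything up to the last two blocks is handled as plain CBC
 * via a subrequest, then the final (possibly partial) two blocks are passed
 * to the dedicated CTS assembly helpers.
 */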
static int sm4_cbc_cts_crypt(struct skcipher_request *req, bool encrypt)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct scatterlist *src = req->src;
	struct scatterlist *dst = req->dst;
	struct scatterlist sg_src[2], sg_dst[2];
	struct skcipher_request subreq;
	struct skcipher_walk walk;
	int cbc_blocks;
	int err;

	if (req->cryptlen < SM4_BLOCK_SIZE)
		return -EINVAL;

	if (req->cryptlen == SM4_BLOCK_SIZE)
		return sm4_cbc_crypt(req, ctx, encrypt);

	skcipher_request_set_tfm(&subreq, tfm);
	skcipher_request_set_callback(&subreq, skcipher_request_flags(req),
				      NULL, NULL);

	/* handle the bulk CBC part */
	cbc_blocks = DIV_ROUND_UP(req->cryptlen, SM4_BLOCK_SIZE) - 2;
	if (cbc_blocks) {
		skcipher_request_set_crypt(&subreq, src, dst,
					   cbc_blocks * SM4_BLOCK_SIZE,
					   req->iv);

		err = sm4_cbc_crypt(&subreq, ctx, encrypt);
		if (err)
			return err;

		dst = src = scatterwalk_ffwd(sg_src, src, subreq.cryptlen);
		if (req->dst != req->src)
			dst = scatterwalk_ffwd(sg_dst, req->dst,
					       subreq.cryptlen);
	}

	/* handle ciphertext stealing */
	skcipher_request_set_crypt(&subreq, src, dst,
				   req->cryptlen - cbc_blocks * SM4_BLOCK_SIZE,
				   req->iv);

	err = skcipher_walk_virt(&walk, &subreq, false);
	if (err)
		return err;

	kernel_neon_begin();

	if (encrypt)
		sm4_ce_cbc_cts_enc(ctx->rkey_enc, walk.dst.virt.addr,
				   walk.src.virt.addr, walk.iv, walk.nbytes);
	else
		sm4_ce_cbc_cts_dec(ctx->rkey_dec, walk.dst.virt.addr,
				   walk.src.virt.addr, walk.iv, walk.nbytes);

	kernel_neon_end();

	return skcipher_walk_done(&walk, 0);
}

static int sm4_cbc_cts_encrypt(struct skcipher_request *req)
{
	return sm4_cbc_cts_crypt(req, true);
}

static int sm4_cbc_cts_decrypt(struct skcipher_request *req)
{
	return sm4_cbc_cts_crypt(req, false);
}

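/*
 * CTR: full blocks are encrypted by the assembly, which also increments the
 * counter in walk.iv.  A partial final block is handled here by encrypting
 * one more counter block and XORing the keystream into the remaining bytes.
 */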
static int sm4_ctr_crypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	unsigned int nbytes;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while ((nbytes = walk.nbytes) > 0) {
		const u8 *src = walk.src.virt.addr;
		u8 *dst = walk.dst.virt.addr;
		unsigned int nblks;

		kernel_neon_begin();

		nblks = BYTES2BLKS(nbytes);
		if (nblks) {
			sm4_ce_ctr_enc(ctx->rkey_enc, dst, src, walk.iv, nblks);
			dst += nblks * SM4_BLOCK_SIZE;
			src += nblks * SM4_BLOCK_SIZE;
			nbytes -= nblks * SM4_BLOCK_SIZE;
		}

		/* tail */
		if (walk.nbytes == walk.total && nbytes > 0) {
			u8 keystream[SM4_BLOCK_SIZE];

			sm4_ce_crypt_block(ctx->rkey_enc, keystream, walk.iv);
			crypto_inc(walk.iv, SM4_BLOCK_SIZE);
			crypto_xor_cpy(dst, src, keystream, nbytes);
			nbytes = 0;
		}

		kernel_neon_end();

		err = skcipher_walk_done(&walk, nbytes);
	}

	return err;
}

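/*
 * XTS: requests whose length is not a multiple of the block size use
 * ciphertext stealing, which the assembly performs itself when it is handed
 * the trailing partial block together with the last full block.  If such a
 * request spans several walk steps, the bulk is first processed through a
 * shortened subrequest and the last full block plus the tail are handled in
 * a final pass.  rkey2_enc is only needed on the first call, where the IV is
 * turned into the initial tweak; later calls pass NULL and continue from the
 * tweak kept in walk.iv.
 */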
static int sm4_xts_crypt(struct skcipher_request *req, bool encrypt)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct sm4_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
	int tail = req->cryptlen % SM4_BLOCK_SIZE;
	const u32 *rkey2_enc = ctx->key2.rkey_enc;
	struct scatterlist sg_src[2], sg_dst[2];
	struct skcipher_request subreq;
	struct scatterlist *src, *dst;
	struct skcipher_walk walk;
	unsigned int nbytes;
	int err;

	if (req->cryptlen < SM4_BLOCK_SIZE)
		return -EINVAL;

	err = skcipher_walk_virt(&walk, req, false);
	if (err)
		return err;

	if (unlikely(tail > 0 && walk.nbytes < walk.total)) {
		int nblocks = DIV_ROUND_UP(req->cryptlen, SM4_BLOCK_SIZE) - 2;

		skcipher_walk_abort(&walk);

		skcipher_request_set_tfm(&subreq, tfm);
		skcipher_request_set_callback(&subreq,
					      skcipher_request_flags(req),
					      NULL, NULL);
		skcipher_request_set_crypt(&subreq, req->src, req->dst,
					   nblocks * SM4_BLOCK_SIZE, req->iv);

		err = skcipher_walk_virt(&walk, &subreq, false);
		if (err)
			return err;
	} else {
		tail = 0;
	}

	while ((nbytes = walk.nbytes) >= SM4_BLOCK_SIZE) {
		if (nbytes < walk.total)
			nbytes &= ~(SM4_BLOCK_SIZE - 1);

		kernel_neon_begin();

		if (encrypt)
			sm4_ce_xts_enc(ctx->key1.rkey_enc, walk.dst.virt.addr,
				       walk.src.virt.addr, walk.iv, nbytes,
				       rkey2_enc);
		else
			sm4_ce_xts_dec(ctx->key1.rkey_dec, walk.dst.virt.addr,
				       walk.src.virt.addr, walk.iv, nbytes,
				       rkey2_enc);

		kernel_neon_end();

		rkey2_enc = NULL;

		err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
		if (err)
			return err;
	}

	if (likely(tail == 0))
		return 0;

	/* handle ciphertext stealing */

	dst = src = scatterwalk_ffwd(sg_src, req->src, subreq.cryptlen);
	if (req->dst != req->src)
		dst = scatterwalk_ffwd(sg_dst, req->dst, subreq.cryptlen);

	skcipher_request_set_crypt(&subreq, src, dst, SM4_BLOCK_SIZE + tail,
				   req->iv);

	err = skcipher_walk_virt(&walk, &subreq, false);
	if (err)
		return err;

	kernel_neon_begin();

	if (encrypt)
		sm4_ce_xts_enc(ctx->key1.rkey_enc, walk.dst.virt.addr,
			       walk.src.virt.addr, walk.iv, walk.nbytes,
			       rkey2_enc);
	else
		sm4_ce_xts_dec(ctx->key1.rkey_dec, walk.dst.virt.addr,
			       walk.src.virt.addr, walk.iv, walk.nbytes,
			       rkey2_enc);

	kernel_neon_end();

	return skcipher_walk_done(&walk, 0);
}

static int sm4_xts_encrypt(struct skcipher_request *req)
{
	return sm4_xts_crypt(req, true);
}

static int sm4_xts_decrypt(struct skcipher_request *req)
{
	return sm4_xts_crypt(req, false);
}

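/*
 * All algorithms register with priority 400, which ranks them above the
 * generic SM4 implementation (and the plain-NEON arm64 one), so they are
 * selected automatically on CPUs that provide the SM4 instructions.
 */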
static struct skcipher_alg sm4_algs[] = {
	{
		.base = {
			.cra_name		= "ecb(sm4)",
			.cra_driver_name	= "ecb-sm4-ce",
			.cra_priority		= 400,
			.cra_blocksize		= SM4_BLOCK_SIZE,
			.cra_ctxsize		= sizeof(struct sm4_ctx),
			.cra_module		= THIS_MODULE,
		},
		.min_keysize	= SM4_KEY_SIZE,
		.max_keysize	= SM4_KEY_SIZE,
		.setkey		= sm4_setkey,
		.encrypt	= sm4_ecb_encrypt,
		.decrypt	= sm4_ecb_decrypt,
	}, {
		.base = {
			.cra_name		= "cbc(sm4)",
			.cra_driver_name	= "cbc-sm4-ce",
			.cra_priority		= 400,
			.cra_blocksize		= SM4_BLOCK_SIZE,
			.cra_ctxsize		= sizeof(struct sm4_ctx),
			.cra_module		= THIS_MODULE,
		},
		.min_keysize	= SM4_KEY_SIZE,
		.max_keysize	= SM4_KEY_SIZE,
		.ivsize		= SM4_BLOCK_SIZE,
		.setkey		= sm4_setkey,
		.encrypt	= sm4_cbc_encrypt,
		.decrypt	= sm4_cbc_decrypt,
	}, {
		.base = {
			.cra_name		= "ctr(sm4)",
			.cra_driver_name	= "ctr-sm4-ce",
			.cra_priority		= 400,
			.cra_blocksize		= 1,
			.cra_ctxsize		= sizeof(struct sm4_ctx),
			.cra_module		= THIS_MODULE,
		},
		.min_keysize	= SM4_KEY_SIZE,
		.max_keysize	= SM4_KEY_SIZE,
		.ivsize		= SM4_BLOCK_SIZE,
		.chunksize	= SM4_BLOCK_SIZE,
		.setkey		= sm4_setkey,
		.encrypt	= sm4_ctr_crypt,
		.decrypt	= sm4_ctr_crypt,
	}, {
		.base = {
			.cra_name		= "cts(cbc(sm4))",
			.cra_driver_name	= "cts-cbc-sm4-ce",
			.cra_priority		= 400,
			.cra_blocksize		= SM4_BLOCK_SIZE,
			.cra_ctxsize		= sizeof(struct sm4_ctx),
			.cra_module		= THIS_MODULE,
		},
		.min_keysize	= SM4_KEY_SIZE,
		.max_keysize	= SM4_KEY_SIZE,
		.ivsize		= SM4_BLOCK_SIZE,
		.walksize	= SM4_BLOCK_SIZE * 2,
		.setkey		= sm4_setkey,
		.encrypt	= sm4_cbc_cts_encrypt,
		.decrypt	= sm4_cbc_cts_decrypt,
	}, {
		.base = {
			.cra_name		= "xts(sm4)",
			.cra_driver_name	= "xts-sm4-ce",
			.cra_priority		= 400,
			.cra_blocksize		= SM4_BLOCK_SIZE,
			.cra_ctxsize		= sizeof(struct sm4_xts_ctx),
			.cra_module		= THIS_MODULE,
		},
		.min_keysize	= SM4_KEY_SIZE * 2,
		.max_keysize	= SM4_KEY_SIZE * 2,
		.ivsize		= SM4_BLOCK_SIZE,
		.walksize	= SM4_BLOCK_SIZE * 2,
		.setkey		= sm4_xts_setkey,
		.encrypt	= sm4_xts_encrypt,
		.decrypt	= sm4_xts_decrypt,
	}
};

static int sm4_cbcmac_setkey(struct crypto_shash *tfm, const u8 *key,
			     unsigned int key_len)
{
	struct sm4_mac_tfm_ctx *ctx = crypto_shash_ctx(tfm);

	if (key_len != SM4_KEY_SIZE)
		return -EINVAL;

	kernel_neon_begin();
	sm4_ce_expand_key(key, ctx->key.rkey_enc, ctx->key.rkey_dec,
			  crypto_sm4_fk, crypto_sm4_ck);
	kernel_neon_end();

	return 0;
}

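/*
 * CMAC subkey derivation: encrypt an all-zero block, then double it twice in
 * GF(2^128) (shift left by one bit, conditionally XOR in the reduction
 * constant 0x87) to obtain the two constants used for the final block, as in
 * standard CMAC subkey generation.
 */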
static int sm4_cmac_setkey(struct crypto_shash *tfm, const u8 *key,
			   unsigned int key_len)
{
	struct sm4_mac_tfm_ctx *ctx = crypto_shash_ctx(tfm);
	be128 *consts = (be128 *)ctx->consts;
	u64 a, b;

	if (key_len != SM4_KEY_SIZE)
		return -EINVAL;

	memset(consts, 0, SM4_BLOCK_SIZE);

	kernel_neon_begin();

	sm4_ce_expand_key(key, ctx->key.rkey_enc, ctx->key.rkey_dec,
			  crypto_sm4_fk, crypto_sm4_ck);

	/* encrypt the zero block */
	sm4_ce_crypt_block(ctx->key.rkey_enc, (u8 *)consts, (const u8 *)consts);

	kernel_neon_end();

	/* gf(2^128) multiply zero-ciphertext with u and u^2 */
	a = be64_to_cpu(consts[0].a);
	b = be64_to_cpu(consts[0].b);
	consts[0].a = cpu_to_be64((a << 1) | (b >> 63));
	consts[0].b = cpu_to_be64((b << 1) ^ ((a >> 63) ? 0x87 : 0));

	a = be64_to_cpu(consts[0].a);
	b = be64_to_cpu(consts[0].b);
	consts[1].a = cpu_to_be64((a << 1) | (b >> 63));
	consts[1].b = cpu_to_be64((b << 1) ^ ((a >> 63) ? 0x87 : 0));

	return 0;
}

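/*
 * XCBC key derivation (RFC 3566 style): K1, K2 and K3 are the encryptions of
 * the constant blocks 0x01..01, 0x02..02 and 0x03..03 under the user key.
 * K1 replaces the cipher key for the CBC-MAC itself, while K2/K3 are kept in
 * ctx->consts for the final block.
 */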
static int sm4_xcbc_setkey(struct crypto_shash *tfm, const u8 *key,
			   unsigned int key_len)
{
	struct sm4_mac_tfm_ctx *ctx = crypto_shash_ctx(tfm);
	u8 __aligned(8) key2[SM4_BLOCK_SIZE];
	static u8 const ks[3][SM4_BLOCK_SIZE] = {
		{ [0 ... SM4_BLOCK_SIZE - 1] = 0x1},
		{ [0 ... SM4_BLOCK_SIZE - 1] = 0x2},
		{ [0 ... SM4_BLOCK_SIZE - 1] = 0x3},
	};

	if (key_len != SM4_KEY_SIZE)
		return -EINVAL;

	kernel_neon_begin();

	sm4_ce_expand_key(key, ctx->key.rkey_enc, ctx->key.rkey_dec,
			  crypto_sm4_fk, crypto_sm4_ck);

	sm4_ce_crypt_block(ctx->key.rkey_enc, key2, ks[0]);
	sm4_ce_crypt(ctx->key.rkey_enc, ctx->consts, ks[1], 2);

	sm4_ce_expand_key(key2, ctx->key.rkey_enc, ctx->key.rkey_dec,
			  crypto_sm4_fk, crypto_sm4_ck);

	kernel_neon_end();

	return 0;
}

static int sm4_mac_init(struct shash_desc *desc)
{
	struct sm4_mac_desc_ctx *ctx = shash_desc_ctx(desc);

	memset(ctx->digest, 0, SM4_BLOCK_SIZE);
	return 0;
}

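/*
 * Shared update step for all three MACs: absorb every complete block into
 * the running digest and return the number of unprocessed trailing bytes;
 * with CRYPTO_AHASH_ALG_BLOCK_ONLY the crypto API layer carries those over
 * to a later update or finup call.
 */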
static int sm4_mac_update(struct shash_desc *desc, const u8 *p,
			  unsigned int len)
{
	struct sm4_mac_tfm_ctx *tctx = crypto_shash_ctx(desc->tfm);
	struct sm4_mac_desc_ctx *ctx = shash_desc_ctx(desc);
	unsigned int nblocks = len / SM4_BLOCK_SIZE;

	len %= SM4_BLOCK_SIZE;
	kernel_neon_begin();
	sm4_ce_mac_update(tctx->key.rkey_enc, ctx->digest, p,
			  nblocks, false, true);
	kernel_neon_end();
	return len;
}

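/*
 * cmac/xcbc finalisation: XOR in the final data, pad with 0x80 if it is a
 * partial block, pick the matching derived constant, and run one last
 * encryption over the digest.
 */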
static int sm4_cmac_finup(struct shash_desc *desc, const u8 *src,
			  unsigned int len, u8 *out)
{
	struct sm4_mac_tfm_ctx *tctx = crypto_shash_ctx(desc->tfm);
	struct sm4_mac_desc_ctx *ctx = shash_desc_ctx(desc);
	const u8 *consts = tctx->consts;

	crypto_xor(ctx->digest, src, len);
	if (len != SM4_BLOCK_SIZE) {
		ctx->digest[len] ^= 0x80;
		consts += SM4_BLOCK_SIZE;
	}
	kernel_neon_begin();
	sm4_ce_mac_update(tctx->key.rkey_enc, ctx->digest, consts, 1,
			  false, true);
	kernel_neon_end();
	memcpy(out, ctx->digest, SM4_BLOCK_SIZE);
	return 0;
}

static int sm4_cbcmac_finup(struct shash_desc *desc, const u8 *src,
			    unsigned int len, u8 *out)
{
	struct sm4_mac_tfm_ctx *tctx = crypto_shash_ctx(desc->tfm);
	struct sm4_mac_desc_ctx *ctx = shash_desc_ctx(desc);

	if (len) {
		crypto_xor(ctx->digest, src, len);
		kernel_neon_begin();
		sm4_ce_crypt_block(tctx->key.rkey_enc, ctx->digest,
				   ctx->digest);
		kernel_neon_end();
	}
	memcpy(out, ctx->digest, SM4_BLOCK_SIZE);
	return 0;
}

static struct shash_alg sm4_mac_algs[] = {
	{
		.base = {
			.cra_name		= "cmac(sm4)",
			.cra_driver_name	= "cmac-sm4-ce",
			.cra_priority		= 400,
			.cra_flags		= CRYPTO_AHASH_ALG_BLOCK_ONLY |
						  CRYPTO_AHASH_ALG_FINAL_NONZERO,
			.cra_blocksize		= SM4_BLOCK_SIZE,
			.cra_ctxsize		= sizeof(struct sm4_mac_tfm_ctx)
							+ SM4_BLOCK_SIZE * 2,
			.cra_module		= THIS_MODULE,
		},
		.digestsize	= SM4_BLOCK_SIZE,
		.init		= sm4_mac_init,
		.update		= sm4_mac_update,
		.finup		= sm4_cmac_finup,
		.setkey		= sm4_cmac_setkey,
		.descsize	= sizeof(struct sm4_mac_desc_ctx),
	}, {
		.base = {
			.cra_name		= "xcbc(sm4)",
			.cra_driver_name	= "xcbc-sm4-ce",
			.cra_priority		= 400,
			.cra_flags		= CRYPTO_AHASH_ALG_BLOCK_ONLY |
						  CRYPTO_AHASH_ALG_FINAL_NONZERO,
			.cra_blocksize		= SM4_BLOCK_SIZE,
			.cra_ctxsize		= sizeof(struct sm4_mac_tfm_ctx)
							+ SM4_BLOCK_SIZE * 2,
			.cra_module		= THIS_MODULE,
		},
		.digestsize	= SM4_BLOCK_SIZE,
		.init		= sm4_mac_init,
		.update		= sm4_mac_update,
		.finup		= sm4_cmac_finup,
		.setkey		= sm4_xcbc_setkey,
		.descsize	= sizeof(struct sm4_mac_desc_ctx),
	}, {
		.base = {
			.cra_name		= "cbcmac(sm4)",
			.cra_driver_name	= "cbcmac-sm4-ce",
			.cra_priority		= 400,
			.cra_flags		= CRYPTO_AHASH_ALG_BLOCK_ONLY,
			.cra_blocksize		= SM4_BLOCK_SIZE,
			.cra_ctxsize		= sizeof(struct sm4_mac_tfm_ctx),
			.cra_module		= THIS_MODULE,
		},
		.digestsize	= SM4_BLOCK_SIZE,
		.init		= sm4_mac_init,
		.update		= sm4_mac_update,
		.finup		= sm4_cbcmac_finup,
		.setkey		= sm4_cbcmac_setkey,
		.descsize	= sizeof(struct sm4_mac_desc_ctx),
	}
};

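/*
 * Registration: skciphers first, then the MAC shashes; on failure the
 * already-registered skciphers are unregistered again.
 *
 * Illustrative consumer sketch (hypothetical, not part of this driver): a
 * kernel user requests the algorithm by name through the regular crypto API
 * and, with this module loaded, will normally be handed the "-ce"
 * implementation because of its higher priority:
 *
 *	struct crypto_skcipher *tfm;
 *
 *	tfm = crypto_alloc_skcipher("cbc(sm4)", 0, 0);
 *	if (IS_ERR(tfm))
 *		return PTR_ERR(tfm);
 *	crypto_skcipher_setkey(tfm, key, SM4_KEY_SIZE);
 *	...
 *	crypto_free_skcipher(tfm);
 */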
static int __init sm4_init(void)
{
	int err;

	err = crypto_register_skciphers(sm4_algs, ARRAY_SIZE(sm4_algs));
	if (err)
		return err;

	err = crypto_register_shashes(sm4_mac_algs, ARRAY_SIZE(sm4_mac_algs));
	if (err)
		goto out_err;

	return 0;

out_err:
	crypto_unregister_skciphers(sm4_algs, ARRAY_SIZE(sm4_algs));
	return err;
}

static void __exit sm4_exit(void)
{
	crypto_unregister_shashes(sm4_mac_algs, ARRAY_SIZE(sm4_mac_algs));
	crypto_unregister_skciphers(sm4_algs, ARRAY_SIZE(sm4_algs));
}

module_cpu_feature_match(SM4, sm4_init);
module_exit(sm4_exit);

MODULE_DESCRIPTION("SM4 ECB/CBC/CTR/CTS/XTS and CMAC/XCBC/CBCMAC using ARMv8 Crypto Extensions");
MODULE_ALIAS_CRYPTO("sm4-ce");
MODULE_ALIAS_CRYPTO("sm4");
MODULE_ALIAS_CRYPTO("ecb(sm4)");
MODULE_ALIAS_CRYPTO("cbc(sm4)");
MODULE_ALIAS_CRYPTO("ctr(sm4)");
MODULE_ALIAS_CRYPTO("cts(cbc(sm4))");
MODULE_ALIAS_CRYPTO("xts(sm4)");
MODULE_ALIAS_CRYPTO("cmac(sm4)");
MODULE_ALIAS_CRYPTO("xcbc(sm4)");
MODULE_ALIAS_CRYPTO("cbcmac(sm4)");
MODULE_AUTHOR("Tianjia Zhang <tianjia.zhang@linux.alibaba.com>");
MODULE_LICENSE("GPL v2");