xref: /linux/arch/x86/crypto/sm4_aesni_avx_glue.c (revision 221013afb459e5deb8bd08e29b37050af5586d1c)
1 /* SPDX-License-Identifier: GPL-2.0-or-later */
2 /*
3  * SM4 Cipher Algorithm, AES-NI/AVX optimized.
4  * as specified in
5  * https://tools.ietf.org/id/draft-ribose-cfrg-sm4-10.html
6  *
7  * Copyright (c) 2021, Alibaba Group.
8  * Copyright (c) 2021 Tianjia Zhang <tianjia.zhang@linux.alibaba.com>
9  */
10 
11 #include <linux/module.h>
12 #include <linux/crypto.h>
13 #include <linux/kernel.h>
14 #include <asm/simd.h>
15 #include <crypto/internal/simd.h>
16 #include <crypto/internal/skcipher.h>
17 #include <crypto/sm4.h>
18 #include "sm4-avx.h"
19 
20 #define SM4_CRYPT8_BLOCK_SIZE	(SM4_BLOCK_SIZE * 8)
21 
22 asmlinkage void sm4_aesni_avx_crypt4(const u32 *rk, u8 *dst,
23 				const u8 *src, int nblocks);
24 asmlinkage void sm4_aesni_avx_crypt8(const u32 *rk, u8 *dst,
25 				const u8 *src, int nblocks);
26 asmlinkage void sm4_aesni_avx_ctr_enc_blk8(const u32 *rk, u8 *dst,
27 				const u8 *src, u8 *iv);
28 asmlinkage void sm4_aesni_avx_cbc_dec_blk8(const u32 *rk, u8 *dst,
29 				const u8 *src, u8 *iv);
30 
31 static int sm4_skcipher_setkey(struct crypto_skcipher *tfm, const u8 *key,
32 			unsigned int key_len)
33 {
34 	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
35 
36 	return sm4_expandkey(ctx, key, key_len);
37 }
38 
39 static int ecb_do_crypt(struct skcipher_request *req, const u32 *rkey)
40 {
41 	struct skcipher_walk walk;
42 	unsigned int nbytes;
43 	int err;
44 
45 	err = skcipher_walk_virt(&walk, req, false);
46 
47 	while ((nbytes = walk.nbytes) > 0) {
48 		const u8 *src = walk.src.virt.addr;
49 		u8 *dst = walk.dst.virt.addr;
50 
51 		kernel_fpu_begin();
52 		while (nbytes >= SM4_CRYPT8_BLOCK_SIZE) {
53 			sm4_aesni_avx_crypt8(rkey, dst, src, 8);
54 			dst += SM4_CRYPT8_BLOCK_SIZE;
55 			src += SM4_CRYPT8_BLOCK_SIZE;
56 			nbytes -= SM4_CRYPT8_BLOCK_SIZE;
57 		}
58 		while (nbytes >= SM4_BLOCK_SIZE) {
59 			unsigned int nblocks = min(nbytes >> 4, 4u);
60 			sm4_aesni_avx_crypt4(rkey, dst, src, nblocks);
61 			dst += nblocks * SM4_BLOCK_SIZE;
62 			src += nblocks * SM4_BLOCK_SIZE;
63 			nbytes -= nblocks * SM4_BLOCK_SIZE;
64 		}
65 		kernel_fpu_end();
66 
67 		err = skcipher_walk_done(&walk, nbytes);
68 	}
69 
70 	return err;
71 }
72 
73 int sm4_avx_ecb_encrypt(struct skcipher_request *req)
74 {
75 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
76 	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
77 
78 	return ecb_do_crypt(req, ctx->rkey_enc);
79 }
80 EXPORT_SYMBOL_GPL(sm4_avx_ecb_encrypt);
81 
82 int sm4_avx_ecb_decrypt(struct skcipher_request *req)
83 {
84 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
85 	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
86 
87 	return ecb_do_crypt(req, ctx->rkey_dec);
88 }
89 EXPORT_SYMBOL_GPL(sm4_avx_ecb_decrypt);
90 
91 int sm4_cbc_encrypt(struct skcipher_request *req)
92 {
93 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
94 	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
95 	struct skcipher_walk walk;
96 	unsigned int nbytes;
97 	int err;
98 
99 	err = skcipher_walk_virt(&walk, req, false);
100 
101 	while ((nbytes = walk.nbytes) > 0) {
102 		const u8 *iv = walk.iv;
103 		const u8 *src = walk.src.virt.addr;
104 		u8 *dst = walk.dst.virt.addr;
105 
106 		while (nbytes >= SM4_BLOCK_SIZE) {
107 			crypto_xor_cpy(dst, src, iv, SM4_BLOCK_SIZE);
108 			sm4_crypt_block(ctx->rkey_enc, dst, dst);
109 			iv = dst;
110 			src += SM4_BLOCK_SIZE;
111 			dst += SM4_BLOCK_SIZE;
112 			nbytes -= SM4_BLOCK_SIZE;
113 		}
114 		if (iv != walk.iv)
115 			memcpy(walk.iv, iv, SM4_BLOCK_SIZE);
116 
117 		err = skcipher_walk_done(&walk, nbytes);
118 	}
119 
120 	return err;
121 }
122 EXPORT_SYMBOL_GPL(sm4_cbc_encrypt);
123 
124 int sm4_avx_cbc_decrypt(struct skcipher_request *req,
125 			unsigned int bsize, sm4_crypt_func func)
126 {
127 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
128 	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
129 	struct skcipher_walk walk;
130 	unsigned int nbytes;
131 	int err;
132 
133 	err = skcipher_walk_virt(&walk, req, false);
134 
135 	while ((nbytes = walk.nbytes) > 0) {
136 		const u8 *src = walk.src.virt.addr;
137 		u8 *dst = walk.dst.virt.addr;
138 
139 		kernel_fpu_begin();
140 
141 		while (nbytes >= bsize) {
142 			func(ctx->rkey_dec, dst, src, walk.iv);
143 			dst += bsize;
144 			src += bsize;
145 			nbytes -= bsize;
146 		}
147 
148 		while (nbytes >= SM4_BLOCK_SIZE) {
149 			u8 keystream[SM4_BLOCK_SIZE * 8];
150 			u8 iv[SM4_BLOCK_SIZE];
151 			unsigned int nblocks = min(nbytes >> 4, 8u);
152 			int i;
153 
154 			sm4_aesni_avx_crypt8(ctx->rkey_dec, keystream,
155 						src, nblocks);
156 
157 			src += ((int)nblocks - 2) * SM4_BLOCK_SIZE;
158 			dst += (nblocks - 1) * SM4_BLOCK_SIZE;
159 			memcpy(iv, src + SM4_BLOCK_SIZE, SM4_BLOCK_SIZE);
160 
161 			for (i = nblocks - 1; i > 0; i--) {
162 				crypto_xor_cpy(dst, src,
163 					&keystream[i * SM4_BLOCK_SIZE],
164 					SM4_BLOCK_SIZE);
165 				src -= SM4_BLOCK_SIZE;
166 				dst -= SM4_BLOCK_SIZE;
167 			}
168 			crypto_xor_cpy(dst, walk.iv, keystream, SM4_BLOCK_SIZE);
169 			memcpy(walk.iv, iv, SM4_BLOCK_SIZE);
170 			dst += nblocks * SM4_BLOCK_SIZE;
171 			src += (nblocks + 1) * SM4_BLOCK_SIZE;
172 			nbytes -= nblocks * SM4_BLOCK_SIZE;
173 		}
174 
175 		kernel_fpu_end();
176 		err = skcipher_walk_done(&walk, nbytes);
177 	}
178 
179 	return err;
180 }
181 EXPORT_SYMBOL_GPL(sm4_avx_cbc_decrypt);
182 
183 static int cbc_decrypt(struct skcipher_request *req)
184 {
185 	return sm4_avx_cbc_decrypt(req, SM4_CRYPT8_BLOCK_SIZE,
186 				sm4_aesni_avx_cbc_dec_blk8);
187 }
188 
189 int sm4_avx_ctr_crypt(struct skcipher_request *req,
190 			unsigned int bsize, sm4_crypt_func func)
191 {
192 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
193 	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
194 	struct skcipher_walk walk;
195 	unsigned int nbytes;
196 	int err;
197 
198 	err = skcipher_walk_virt(&walk, req, false);
199 
200 	while ((nbytes = walk.nbytes) > 0) {
201 		const u8 *src = walk.src.virt.addr;
202 		u8 *dst = walk.dst.virt.addr;
203 
204 		kernel_fpu_begin();
205 
206 		while (nbytes >= bsize) {
207 			func(ctx->rkey_enc, dst, src, walk.iv);
208 			dst += bsize;
209 			src += bsize;
210 			nbytes -= bsize;
211 		}
212 
213 		while (nbytes >= SM4_BLOCK_SIZE) {
214 			u8 keystream[SM4_BLOCK_SIZE * 8];
215 			unsigned int nblocks = min(nbytes >> 4, 8u);
216 			int i;
217 
218 			for (i = 0; i < nblocks; i++) {
219 				memcpy(&keystream[i * SM4_BLOCK_SIZE],
220 					walk.iv, SM4_BLOCK_SIZE);
221 				crypto_inc(walk.iv, SM4_BLOCK_SIZE);
222 			}
223 			sm4_aesni_avx_crypt8(ctx->rkey_enc, keystream,
224 					keystream, nblocks);
225 
226 			crypto_xor_cpy(dst, src, keystream,
227 					nblocks * SM4_BLOCK_SIZE);
228 			dst += nblocks * SM4_BLOCK_SIZE;
229 			src += nblocks * SM4_BLOCK_SIZE;
230 			nbytes -= nblocks * SM4_BLOCK_SIZE;
231 		}
232 
233 		kernel_fpu_end();
234 
235 		/* tail */
236 		if (walk.nbytes == walk.total && nbytes > 0) {
237 			u8 keystream[SM4_BLOCK_SIZE];
238 
239 			memcpy(keystream, walk.iv, SM4_BLOCK_SIZE);
240 			crypto_inc(walk.iv, SM4_BLOCK_SIZE);
241 
242 			sm4_crypt_block(ctx->rkey_enc, keystream, keystream);
243 
244 			crypto_xor_cpy(dst, src, keystream, nbytes);
245 			dst += nbytes;
246 			src += nbytes;
247 			nbytes = 0;
248 		}
249 
250 		err = skcipher_walk_done(&walk, nbytes);
251 	}
252 
253 	return err;
254 }
255 EXPORT_SYMBOL_GPL(sm4_avx_ctr_crypt);
256 
257 static int ctr_crypt(struct skcipher_request *req)
258 {
259 	return sm4_avx_ctr_crypt(req, SM4_CRYPT8_BLOCK_SIZE,
260 				sm4_aesni_avx_ctr_enc_blk8);
261 }
262 
263 static struct skcipher_alg sm4_aesni_avx_skciphers[] = {
264 	{
265 		.base = {
266 			.cra_name		= "__ecb(sm4)",
267 			.cra_driver_name	= "__ecb-sm4-aesni-avx",
268 			.cra_priority		= 400,
269 			.cra_flags		= CRYPTO_ALG_INTERNAL,
270 			.cra_blocksize		= SM4_BLOCK_SIZE,
271 			.cra_ctxsize		= sizeof(struct sm4_ctx),
272 			.cra_module		= THIS_MODULE,
273 		},
274 		.min_keysize	= SM4_KEY_SIZE,
275 		.max_keysize	= SM4_KEY_SIZE,
276 		.walksize	= 8 * SM4_BLOCK_SIZE,
277 		.setkey		= sm4_skcipher_setkey,
278 		.encrypt	= sm4_avx_ecb_encrypt,
279 		.decrypt	= sm4_avx_ecb_decrypt,
280 	}, {
281 		.base = {
282 			.cra_name		= "__cbc(sm4)",
283 			.cra_driver_name	= "__cbc-sm4-aesni-avx",
284 			.cra_priority		= 400,
285 			.cra_flags		= CRYPTO_ALG_INTERNAL,
286 			.cra_blocksize		= SM4_BLOCK_SIZE,
287 			.cra_ctxsize		= sizeof(struct sm4_ctx),
288 			.cra_module		= THIS_MODULE,
289 		},
290 		.min_keysize	= SM4_KEY_SIZE,
291 		.max_keysize	= SM4_KEY_SIZE,
292 		.ivsize		= SM4_BLOCK_SIZE,
293 		.walksize	= 8 * SM4_BLOCK_SIZE,
294 		.setkey		= sm4_skcipher_setkey,
295 		.encrypt	= sm4_cbc_encrypt,
296 		.decrypt	= cbc_decrypt,
297 	}, {
298 		.base = {
299 			.cra_name		= "__ctr(sm4)",
300 			.cra_driver_name	= "__ctr-sm4-aesni-avx",
301 			.cra_priority		= 400,
302 			.cra_flags		= CRYPTO_ALG_INTERNAL,
303 			.cra_blocksize		= 1,
304 			.cra_ctxsize		= sizeof(struct sm4_ctx),
305 			.cra_module		= THIS_MODULE,
306 		},
307 		.min_keysize	= SM4_KEY_SIZE,
308 		.max_keysize	= SM4_KEY_SIZE,
309 		.ivsize		= SM4_BLOCK_SIZE,
310 		.chunksize	= SM4_BLOCK_SIZE,
311 		.walksize	= 8 * SM4_BLOCK_SIZE,
312 		.setkey		= sm4_skcipher_setkey,
313 		.encrypt	= ctr_crypt,
314 		.decrypt	= ctr_crypt,
315 	}
316 };
317 
318 static struct simd_skcipher_alg *
319 simd_sm4_aesni_avx_skciphers[ARRAY_SIZE(sm4_aesni_avx_skciphers)];
320 
321 static int __init sm4_init(void)
322 {
323 	const char *feature_name;
324 
325 	if (!boot_cpu_has(X86_FEATURE_AVX) ||
326 	    !boot_cpu_has(X86_FEATURE_AES) ||
327 	    !boot_cpu_has(X86_FEATURE_OSXSAVE)) {
328 		pr_info("AVX or AES-NI instructions are not detected.\n");
329 		return -ENODEV;
330 	}
331 
332 	if (!cpu_has_xfeatures(XFEATURE_MASK_SSE | XFEATURE_MASK_YMM,
333 				&feature_name)) {
334 		pr_info("CPU feature '%s' is not supported.\n", feature_name);
335 		return -ENODEV;
336 	}
337 
338 	return simd_register_skciphers_compat(sm4_aesni_avx_skciphers,
339 					ARRAY_SIZE(sm4_aesni_avx_skciphers),
340 					simd_sm4_aesni_avx_skciphers);
341 }
342 
343 static void __exit sm4_exit(void)
344 {
345 	simd_unregister_skciphers(sm4_aesni_avx_skciphers,
346 					ARRAY_SIZE(sm4_aesni_avx_skciphers),
347 					simd_sm4_aesni_avx_skciphers);
348 }
349 
350 module_init(sm4_init);
351 module_exit(sm4_exit);
352 
353 MODULE_LICENSE("GPL v2");
354 MODULE_AUTHOR("Tianjia Zhang <tianjia.zhang@linux.alibaba.com>");
355 MODULE_DESCRIPTION("SM4 Cipher Algorithm, AES-NI/AVX optimized");
356 MODULE_ALIAS_CRYPTO("sm4");
357 MODULE_ALIAS_CRYPTO("sm4-aesni-avx");
358