xref: /linux/arch/x86/crypto/sm4_aesni_avx_glue.c (revision 746680ec6696585e30db3e18c93a63df9cbec39c)
1 /* SPDX-License-Identifier: GPL-2.0-or-later */
2 /*
3  * SM4 Cipher Algorithm, AES-NI/AVX optimized.
4  * as specified in
5  * https://tools.ietf.org/id/draft-ribose-cfrg-sm4-10.html
6  *
7  * Copyright (c) 2021, Alibaba Group.
8  * Copyright (c) 2021 Tianjia Zhang <tianjia.zhang@linux.alibaba.com>
9  */
10 
11 #include <asm/fpu/api.h>
12 #include <linux/module.h>
13 #include <linux/crypto.h>
14 #include <linux/export.h>
15 #include <linux/kernel.h>
16 #include <crypto/internal/skcipher.h>
17 #include <crypto/sm4.h>
18 #include "sm4-avx.h"
19 
/* Byte count handled by one 8-block AVX call (8 * 16 = 128 bytes). */
#define SM4_CRYPT8_BLOCK_SIZE	(SM4_BLOCK_SIZE * 8)

/* ECB helpers in sm4-aesni-avx-asm_64.S: process up to 4 or 8 blocks. */
asmlinkage void sm4_aesni_avx_crypt4(const u32 *rk, u8 *dst,
				const u8 *src, int nblocks);
asmlinkage void sm4_aesni_avx_crypt8(const u32 *rk, u8 *dst,
				const u8 *src, int nblocks);
/* CTR/CBC bulk helpers: consume exactly 8 blocks and update *iv. */
asmlinkage void sm4_aesni_avx_ctr_enc_blk8(const u32 *rk, u8 *dst,
				const u8 *src, u8 *iv);
asmlinkage void sm4_aesni_avx_cbc_dec_blk8(const u32 *rk, u8 *dst,
				const u8 *src, u8 *iv);
30 
31 static int sm4_skcipher_setkey(struct crypto_skcipher *tfm, const u8 *key,
32 			unsigned int key_len)
33 {
34 	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
35 
36 	return sm4_expandkey(ctx, key, key_len);
37 }
38 
39 static int ecb_do_crypt(struct skcipher_request *req, const u32 *rkey)
40 {
41 	struct skcipher_walk walk;
42 	unsigned int nbytes;
43 	int err;
44 
45 	err = skcipher_walk_virt(&walk, req, false);
46 
47 	while ((nbytes = walk.nbytes) > 0) {
48 		const u8 *src = walk.src.virt.addr;
49 		u8 *dst = walk.dst.virt.addr;
50 
51 		kernel_fpu_begin();
52 		while (nbytes >= SM4_CRYPT8_BLOCK_SIZE) {
53 			sm4_aesni_avx_crypt8(rkey, dst, src, 8);
54 			dst += SM4_CRYPT8_BLOCK_SIZE;
55 			src += SM4_CRYPT8_BLOCK_SIZE;
56 			nbytes -= SM4_CRYPT8_BLOCK_SIZE;
57 		}
58 		while (nbytes >= SM4_BLOCK_SIZE) {
59 			unsigned int nblocks = min(nbytes >> 4, 4u);
60 			sm4_aesni_avx_crypt4(rkey, dst, src, nblocks);
61 			dst += nblocks * SM4_BLOCK_SIZE;
62 			src += nblocks * SM4_BLOCK_SIZE;
63 			nbytes -= nblocks * SM4_BLOCK_SIZE;
64 		}
65 		kernel_fpu_end();
66 
67 		err = skcipher_walk_done(&walk, nbytes);
68 	}
69 
70 	return err;
71 }
72 
73 int sm4_avx_ecb_encrypt(struct skcipher_request *req)
74 {
75 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
76 	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
77 
78 	return ecb_do_crypt(req, ctx->rkey_enc);
79 }
80 EXPORT_SYMBOL_GPL(sm4_avx_ecb_encrypt);
81 
82 int sm4_avx_ecb_decrypt(struct skcipher_request *req)
83 {
84 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
85 	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
86 
87 	return ecb_do_crypt(req, ctx->rkey_dec);
88 }
89 EXPORT_SYMBOL_GPL(sm4_avx_ecb_decrypt);
90 
91 int sm4_cbc_encrypt(struct skcipher_request *req)
92 {
93 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
94 	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
95 	struct skcipher_walk walk;
96 	unsigned int nbytes;
97 	int err;
98 
99 	err = skcipher_walk_virt(&walk, req, false);
100 
101 	while ((nbytes = walk.nbytes) > 0) {
102 		const u8 *iv = walk.iv;
103 		const u8 *src = walk.src.virt.addr;
104 		u8 *dst = walk.dst.virt.addr;
105 
106 		while (nbytes >= SM4_BLOCK_SIZE) {
107 			crypto_xor_cpy(dst, src, iv, SM4_BLOCK_SIZE);
108 			sm4_crypt_block(ctx->rkey_enc, dst, dst);
109 			iv = dst;
110 			src += SM4_BLOCK_SIZE;
111 			dst += SM4_BLOCK_SIZE;
112 			nbytes -= SM4_BLOCK_SIZE;
113 		}
114 		if (iv != walk.iv)
115 			memcpy(walk.iv, iv, SM4_BLOCK_SIZE);
116 
117 		err = skcipher_walk_done(&walk, nbytes);
118 	}
119 
120 	return err;
121 }
122 EXPORT_SYMBOL_GPL(sm4_cbc_encrypt);
123 
/*
 * CBC decryption. Full 'bsize' chunks go through the supplied assembly
 * helper 'func' (which chains walk.iv itself); the remaining 1..8 full
 * blocks are raw-decrypted in one sm4_aesni_avx_crypt8() call and then
 * XOR-chained by hand, walking backwards.
 *
 * NOTE(review): the backwards walk appears intended to make the tail safe
 * for in-place requests (src == dst): all raw decryptions are held in
 * 'keystream' before any ciphertext byte is overwritten -- confirm against
 * the skcipher in-place contract.
 */
int sm4_avx_cbc_decrypt(struct skcipher_request *req,
			unsigned int bsize, sm4_crypt_func func)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	unsigned int nbytes;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while ((nbytes = walk.nbytes) > 0) {
		const u8 *src = walk.src.virt.addr;
		u8 *dst = walk.dst.virt.addr;

		kernel_fpu_begin();

		/* Bulk path: helper consumes bsize bytes and updates the IV. */
		while (nbytes >= bsize) {
			func(ctx->rkey_dec, dst, src, walk.iv);
			dst += bsize;
			src += bsize;
			nbytes -= bsize;
		}

		while (nbytes >= SM4_BLOCK_SIZE) {
			u8 keystream[SM4_BLOCK_SIZE * 8];
			u8 iv[SM4_BLOCK_SIZE];
			unsigned int nblocks = min(nbytes >> 4, 8u);
			int i;

			/* Raw (unchained) decryption of all nblocks blocks. */
			sm4_aesni_avx_crypt8(ctx->rkey_dec, keystream,
						src, nblocks);

			/*
			 * Position src at ciphertext block nblocks-2 (the
			 * chaining input of the last block) and dst at the
			 * last output block.
			 */
			src += ((int)nblocks - 2) * SM4_BLOCK_SIZE;
			dst += (nblocks - 1) * SM4_BLOCK_SIZE;
			/* Save the last ciphertext block as the next IV. */
			memcpy(iv, src + SM4_BLOCK_SIZE, SM4_BLOCK_SIZE);

			/* P[i] = D(C[i]) ^ C[i-1], for i = nblocks-1 .. 1. */
			for (i = nblocks - 1; i > 0; i--) {
				crypto_xor_cpy(dst, src,
					&keystream[i * SM4_BLOCK_SIZE],
					SM4_BLOCK_SIZE);
				src -= SM4_BLOCK_SIZE;
				dst -= SM4_BLOCK_SIZE;
			}
			/* First block chains off the previous IV. */
			crypto_xor_cpy(dst, walk.iv, keystream, SM4_BLOCK_SIZE);
			memcpy(walk.iv, iv, SM4_BLOCK_SIZE);
			/* src ended one block before the group, hence the +1. */
			dst += nblocks * SM4_BLOCK_SIZE;
			src += (nblocks + 1) * SM4_BLOCK_SIZE;
			nbytes -= nblocks * SM4_BLOCK_SIZE;
		}

		kernel_fpu_end();
		err = skcipher_walk_done(&walk, nbytes);
	}

	return err;
}
EXPORT_SYMBOL_GPL(sm4_avx_cbc_decrypt);
182 
183 static int cbc_decrypt(struct skcipher_request *req)
184 {
185 	return sm4_avx_cbc_decrypt(req, SM4_CRYPT8_BLOCK_SIZE,
186 				sm4_aesni_avx_cbc_dec_blk8);
187 }
188 
189 int sm4_avx_ctr_crypt(struct skcipher_request *req,
190 			unsigned int bsize, sm4_crypt_func func)
191 {
192 	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
193 	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
194 	struct skcipher_walk walk;
195 	unsigned int nbytes;
196 	int err;
197 
198 	err = skcipher_walk_virt(&walk, req, false);
199 
200 	while ((nbytes = walk.nbytes) > 0) {
201 		const u8 *src = walk.src.virt.addr;
202 		u8 *dst = walk.dst.virt.addr;
203 
204 		kernel_fpu_begin();
205 
206 		while (nbytes >= bsize) {
207 			func(ctx->rkey_enc, dst, src, walk.iv);
208 			dst += bsize;
209 			src += bsize;
210 			nbytes -= bsize;
211 		}
212 
213 		while (nbytes >= SM4_BLOCK_SIZE) {
214 			u8 keystream[SM4_BLOCK_SIZE * 8];
215 			unsigned int nblocks = min(nbytes >> 4, 8u);
216 			int i;
217 
218 			for (i = 0; i < nblocks; i++) {
219 				memcpy(&keystream[i * SM4_BLOCK_SIZE],
220 					walk.iv, SM4_BLOCK_SIZE);
221 				crypto_inc(walk.iv, SM4_BLOCK_SIZE);
222 			}
223 			sm4_aesni_avx_crypt8(ctx->rkey_enc, keystream,
224 					keystream, nblocks);
225 
226 			crypto_xor_cpy(dst, src, keystream,
227 					nblocks * SM4_BLOCK_SIZE);
228 			dst += nblocks * SM4_BLOCK_SIZE;
229 			src += nblocks * SM4_BLOCK_SIZE;
230 			nbytes -= nblocks * SM4_BLOCK_SIZE;
231 		}
232 
233 		kernel_fpu_end();
234 
235 		/* tail */
236 		if (walk.nbytes == walk.total && nbytes > 0) {
237 			u8 keystream[SM4_BLOCK_SIZE];
238 
239 			memcpy(keystream, walk.iv, SM4_BLOCK_SIZE);
240 			crypto_inc(walk.iv, SM4_BLOCK_SIZE);
241 
242 			sm4_crypt_block(ctx->rkey_enc, keystream, keystream);
243 
244 			crypto_xor_cpy(dst, src, keystream, nbytes);
245 			dst += nbytes;
246 			src += nbytes;
247 			nbytes = 0;
248 		}
249 
250 		err = skcipher_walk_done(&walk, nbytes);
251 	}
252 
253 	return err;
254 }
255 EXPORT_SYMBOL_GPL(sm4_avx_ctr_crypt);
256 
257 static int ctr_crypt(struct skcipher_request *req)
258 {
259 	return sm4_avx_ctr_crypt(req, SM4_CRYPT8_BLOCK_SIZE,
260 				sm4_aesni_avx_ctr_enc_blk8);
261 }
262 
/* Registered algorithms; walksize 8 blocks matches the asm batch size. */
static struct skcipher_alg sm4_aesni_avx_skciphers[] = {
	{
		/* ECB: no IV, block-aligned input only. */
		.base = {
			.cra_name		= "ecb(sm4)",
			.cra_driver_name	= "ecb-sm4-aesni-avx",
			.cra_priority		= 400,
			.cra_blocksize		= SM4_BLOCK_SIZE,
			.cra_ctxsize		= sizeof(struct sm4_ctx),
			.cra_module		= THIS_MODULE,
		},
		.min_keysize	= SM4_KEY_SIZE,
		.max_keysize	= SM4_KEY_SIZE,
		.walksize	= 8 * SM4_BLOCK_SIZE,
		.setkey		= sm4_skcipher_setkey,
		.encrypt	= sm4_avx_ecb_encrypt,
		.decrypt	= sm4_avx_ecb_decrypt,
	}, {
		/* CBC: sequential encrypt, AVX-parallel decrypt. */
		.base = {
			.cra_name		= "cbc(sm4)",
			.cra_driver_name	= "cbc-sm4-aesni-avx",
			.cra_priority		= 400,
			.cra_blocksize		= SM4_BLOCK_SIZE,
			.cra_ctxsize		= sizeof(struct sm4_ctx),
			.cra_module		= THIS_MODULE,
		},
		.min_keysize	= SM4_KEY_SIZE,
		.max_keysize	= SM4_KEY_SIZE,
		.ivsize		= SM4_BLOCK_SIZE,
		.walksize	= 8 * SM4_BLOCK_SIZE,
		.setkey		= sm4_skcipher_setkey,
		.encrypt	= sm4_cbc_encrypt,
		.decrypt	= cbc_decrypt,
	}, {
		/* CTR: stream cipher, blocksize 1, same op both ways. */
		.base = {
			.cra_name		= "ctr(sm4)",
			.cra_driver_name	= "ctr-sm4-aesni-avx",
			.cra_priority		= 400,
			.cra_blocksize		= 1,
			.cra_ctxsize		= sizeof(struct sm4_ctx),
			.cra_module		= THIS_MODULE,
		},
		.min_keysize	= SM4_KEY_SIZE,
		.max_keysize	= SM4_KEY_SIZE,
		.ivsize		= SM4_BLOCK_SIZE,
		.chunksize	= SM4_BLOCK_SIZE,
		.walksize	= 8 * SM4_BLOCK_SIZE,
		.setkey		= sm4_skcipher_setkey,
		.encrypt	= ctr_crypt,
		.decrypt	= ctr_crypt,
	}
};
314 
315 static int __init sm4_init(void)
316 {
317 	const char *feature_name;
318 
319 	if (!boot_cpu_has(X86_FEATURE_AVX) ||
320 	    !boot_cpu_has(X86_FEATURE_AES) ||
321 	    !boot_cpu_has(X86_FEATURE_OSXSAVE)) {
322 		pr_info("AVX or AES-NI instructions are not detected.\n");
323 		return -ENODEV;
324 	}
325 
326 	if (!cpu_has_xfeatures(XFEATURE_MASK_SSE | XFEATURE_MASK_YMM,
327 				&feature_name)) {
328 		pr_info("CPU feature '%s' is not supported.\n", feature_name);
329 		return -ENODEV;
330 	}
331 
332 	return crypto_register_skciphers(sm4_aesni_avx_skciphers,
333 					 ARRAY_SIZE(sm4_aesni_avx_skciphers));
334 }
335 
/* Module exit: unregister everything sm4_init() registered. */
static void __exit sm4_exit(void)
{
	crypto_unregister_skciphers(sm4_aesni_avx_skciphers,
				    ARRAY_SIZE(sm4_aesni_avx_skciphers));
}
341 
/* Module plumbing and aliases so "sm4" requests can load this driver. */
module_init(sm4_init);
module_exit(sm4_exit);

MODULE_LICENSE("GPL v2");
MODULE_AUTHOR("Tianjia Zhang <tianjia.zhang@linux.alibaba.com>");
MODULE_DESCRIPTION("SM4 Cipher Algorithm, AES-NI/AVX optimized");
MODULE_ALIAS_CRYPTO("sm4");
MODULE_ALIAS_CRYPTO("sm4-aesni-avx");
350