/* SPDX-License-Identifier: GPL-2.0-or-later */
/*
 * SM4 Cipher Algorithm, AES-NI/AVX optimized.
 * as specified in
 * https://tools.ietf.org/id/draft-ribose-cfrg-sm4-10.html
 *
 * Copyright (c) 2021, Alibaba Group.
 * Copyright (c) 2021 Tianjia Zhang <tianjia.zhang@linux.alibaba.com>
 */

#include <asm/fpu/api.h>
#include <linux/module.h>
#include <linux/crypto.h>
#include <linux/kernel.h>
#include <crypto/internal/skcipher.h>
#include <crypto/sm4.h>
#include "sm4-avx.h"

#define SM4_CRYPT8_BLOCK_SIZE	(SM4_BLOCK_SIZE * 8)

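/*
 * Assembly helpers, implemented in the AES-NI/AVX assembly file that ships
 * with this glue code: crypt4/crypt8 run the block cipher over up to 4 or
 * 8 blocks in parallel, while the CTR/CBC helpers consume exactly 8 blocks
 * per call and update the IV/counter in place.
 */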
asmlinkage void sm4_aesni_avx_crypt4(const u32 *rk, u8 *dst,
				const u8 *src, int nblocks);
asmlinkage void sm4_aesni_avx_crypt8(const u32 *rk, u8 *dst,
				const u8 *src, int nblocks);
asmlinkage void sm4_aesni_avx_ctr_enc_blk8(const u32 *rk, u8 *dst,
				const u8 *src, u8 *iv);
asmlinkage void sm4_aesni_avx_cbc_dec_blk8(const u32 *rk, u8 *dst,
				const u8 *src, u8 *iv);

static int sm4_skcipher_setkey(struct crypto_skcipher *tfm, const u8 *key,
			unsigned int key_len)
{
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);

	return sm4_expandkey(ctx, key, key_len);
}

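/*
 * ECB: walk the request and, inside a single kernel_fpu_begin()/end()
 * section per walk step, process 8-block chunks with the 8-way routine and
 * the remaining full blocks (up to 4 at a time) with the 4-way routine.
 * Whatever is left over is handed back to skcipher_walk_done().
 */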
static int ecb_do_crypt(struct skcipher_request *req, const u32 *rkey)
{
	struct skcipher_walk walk;
	unsigned int nbytes;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while ((nbytes = walk.nbytes) > 0) {
		const u8 *src = walk.src.virt.addr;
		u8 *dst = walk.dst.virt.addr;

		kernel_fpu_begin();
		while (nbytes >= SM4_CRYPT8_BLOCK_SIZE) {
			sm4_aesni_avx_crypt8(rkey, dst, src, 8);
			dst += SM4_CRYPT8_BLOCK_SIZE;
			src += SM4_CRYPT8_BLOCK_SIZE;
			nbytes -= SM4_CRYPT8_BLOCK_SIZE;
		}
		while (nbytes >= SM4_BLOCK_SIZE) {
			unsigned int nblocks = min(nbytes >> 4, 4u);

			sm4_aesni_avx_crypt4(rkey, dst, src, nblocks);
			dst += nblocks * SM4_BLOCK_SIZE;
			src += nblocks * SM4_BLOCK_SIZE;
			nbytes -= nblocks * SM4_BLOCK_SIZE;
		}
		kernel_fpu_end();

		err = skcipher_walk_done(&walk, nbytes);
	}

	return err;
}

int sm4_avx_ecb_encrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);

	return ecb_do_crypt(req, ctx->rkey_enc);
}
EXPORT_SYMBOL_GPL(sm4_avx_ecb_encrypt);

int sm4_avx_ecb_decrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);

	return ecb_do_crypt(req, ctx->rkey_dec);
}
EXPORT_SYMBOL_GPL(sm4_avx_ecb_decrypt);

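/*
 * CBC encryption is inherently serial (each block is chained to the previous
 * ciphertext), so it is done one block at a time with the generic
 * sm4_crypt_block() and does not need the SIMD unit at all.
 */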
int sm4_cbc_encrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	unsigned int nbytes;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while ((nbytes = walk.nbytes) > 0) {
		const u8 *iv = walk.iv;
		const u8 *src = walk.src.virt.addr;
		u8 *dst = walk.dst.virt.addr;

		while (nbytes >= SM4_BLOCK_SIZE) {
			crypto_xor_cpy(dst, src, iv, SM4_BLOCK_SIZE);
			sm4_crypt_block(ctx->rkey_enc, dst, dst);
			iv = dst;
			src += SM4_BLOCK_SIZE;
			dst += SM4_BLOCK_SIZE;
			nbytes -= SM4_BLOCK_SIZE;
		}
		if (iv != walk.iv)
			memcpy(walk.iv, iv, SM4_BLOCK_SIZE);

		err = skcipher_walk_done(&walk, nbytes);
	}

	return err;
}
EXPORT_SYMBOL_GPL(sm4_cbc_encrypt);

int sm4_avx_cbc_decrypt(struct skcipher_request *req,
			unsigned int bsize, sm4_crypt_func func)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	unsigned int nbytes;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while ((nbytes = walk.nbytes) > 0) {
		const u8 *src = walk.src.virt.addr;
		u8 *dst = walk.dst.virt.addr;

		kernel_fpu_begin();

		while (nbytes >= bsize) {
			func(ctx->rkey_dec, dst, src, walk.iv);
			dst += bsize;
			src += bsize;
			nbytes -= bsize;
		}

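		/*
		 * Decrypt the remaining 1..7 blocks: run the block cipher
		 * over them in one call, then XOR in the chaining values
		 * from the last block backwards so that in-place requests
		 * (dst == src) do not clobber a ciphertext block before it
		 * has been used as the next block's chaining value.
		 */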
		while (nbytes >= SM4_BLOCK_SIZE) {
			u8 keystream[SM4_BLOCK_SIZE * 8];
			u8 iv[SM4_BLOCK_SIZE];
			unsigned int nblocks = min(nbytes >> 4, 8u);
			int i;

			sm4_aesni_avx_crypt8(ctx->rkey_dec, keystream,
						src, nblocks);

			src += ((int)nblocks - 2) * SM4_BLOCK_SIZE;
			dst += (nblocks - 1) * SM4_BLOCK_SIZE;
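
			/* Save the last ciphertext block: it is the IV for the next chunk. */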
			memcpy(iv, src + SM4_BLOCK_SIZE, SM4_BLOCK_SIZE);

			for (i = nblocks - 1; i > 0; i--) {
				crypto_xor_cpy(dst, src,
					&keystream[i * SM4_BLOCK_SIZE],
					SM4_BLOCK_SIZE);
				src -= SM4_BLOCK_SIZE;
				dst -= SM4_BLOCK_SIZE;
			}
			crypto_xor_cpy(dst, walk.iv, keystream, SM4_BLOCK_SIZE);
			memcpy(walk.iv, iv, SM4_BLOCK_SIZE);
			dst += nblocks * SM4_BLOCK_SIZE;
			src += (nblocks + 1) * SM4_BLOCK_SIZE;
			nbytes -= nblocks * SM4_BLOCK_SIZE;
		}

		kernel_fpu_end();
		err = skcipher_walk_done(&walk, nbytes);
	}

	return err;
}
EXPORT_SYMBOL_GPL(sm4_avx_cbc_decrypt);

static int cbc_decrypt(struct skcipher_request *req)
{
	return sm4_avx_cbc_decrypt(req, SM4_CRYPT8_BLOCK_SIZE,
				sm4_aesni_avx_cbc_dec_blk8);
}

int sm4_avx_ctr_crypt(struct skcipher_request *req,
			unsigned int bsize, sm4_crypt_func func)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	unsigned int nbytes;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while ((nbytes = walk.nbytes) > 0) {
		const u8 *src = walk.src.virt.addr;
		u8 *dst = walk.dst.virt.addr;

		kernel_fpu_begin();

		while (nbytes >= bsize) {
			func(ctx->rkey_enc, dst, src, walk.iv);
			dst += bsize;
			src += bsize;
			nbytes -= bsize;
		}

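		/*
		 * CTR for the remaining 1..7 blocks: build the counter
		 * blocks on the stack, encrypt them to get the keystream,
		 * and XOR it into the data.
		 */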
		while (nbytes >= SM4_BLOCK_SIZE) {
			u8 keystream[SM4_BLOCK_SIZE * 8];
			unsigned int nblocks = min(nbytes >> 4, 8u);
			int i;

			for (i = 0; i < nblocks; i++) {
				memcpy(&keystream[i * SM4_BLOCK_SIZE],
					walk.iv, SM4_BLOCK_SIZE);
				crypto_inc(walk.iv, SM4_BLOCK_SIZE);
			}
			sm4_aesni_avx_crypt8(ctx->rkey_enc, keystream,
					keystream, nblocks);

			crypto_xor_cpy(dst, src, keystream,
					nblocks * SM4_BLOCK_SIZE);
			dst += nblocks * SM4_BLOCK_SIZE;
			src += nblocks * SM4_BLOCK_SIZE;
			nbytes -= nblocks * SM4_BLOCK_SIZE;
		}

		kernel_fpu_end();

		/* tail: handle a final partial block (only on the last walk step) */
		if (walk.nbytes == walk.total && nbytes > 0) {
			u8 keystream[SM4_BLOCK_SIZE];

			memcpy(keystream, walk.iv, SM4_BLOCK_SIZE);
			crypto_inc(walk.iv, SM4_BLOCK_SIZE);

			sm4_crypt_block(ctx->rkey_enc, keystream, keystream);

			crypto_xor_cpy(dst, src, keystream, nbytes);
			dst += nbytes;
			src += nbytes;
			nbytes = 0;
		}

		err = skcipher_walk_done(&walk, nbytes);
	}

	return err;
}
EXPORT_SYMBOL_GPL(sm4_avx_ctr_crypt);

static int ctr_crypt(struct skcipher_request *req)
{
	return sm4_avx_ctr_crypt(req, SM4_CRYPT8_BLOCK_SIZE,
				sm4_aesni_avx_ctr_enc_blk8);
}

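/*
 * The registered algorithms: a higher cra_priority than the generic C
 * implementation makes these the preferred sm4 providers on capable CPUs,
 * and a walksize of eight blocks lets the walker hand out chunks sized
 * for the 8-way assembly routines.
 */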
static struct skcipher_alg sm4_aesni_avx_skciphers[] = {
	{
		.base = {
			.cra_name		= "ecb(sm4)",
			.cra_driver_name	= "ecb-sm4-aesni-avx",
			.cra_priority		= 400,
			.cra_blocksize		= SM4_BLOCK_SIZE,
			.cra_ctxsize		= sizeof(struct sm4_ctx),
			.cra_module		= THIS_MODULE,
		},
		.min_keysize	= SM4_KEY_SIZE,
		.max_keysize	= SM4_KEY_SIZE,
		.walksize	= 8 * SM4_BLOCK_SIZE,
		.setkey		= sm4_skcipher_setkey,
		.encrypt	= sm4_avx_ecb_encrypt,
		.decrypt	= sm4_avx_ecb_decrypt,
	}, {
		.base = {
			.cra_name		= "cbc(sm4)",
			.cra_driver_name	= "cbc-sm4-aesni-avx",
			.cra_priority		= 400,
			.cra_blocksize		= SM4_BLOCK_SIZE,
			.cra_ctxsize		= sizeof(struct sm4_ctx),
			.cra_module		= THIS_MODULE,
		},
		.min_keysize	= SM4_KEY_SIZE,
		.max_keysize	= SM4_KEY_SIZE,
		.ivsize		= SM4_BLOCK_SIZE,
		.walksize	= 8 * SM4_BLOCK_SIZE,
		.setkey		= sm4_skcipher_setkey,
		.encrypt	= sm4_cbc_encrypt,
		.decrypt	= cbc_decrypt,
	}, {
		.base = {
			.cra_name		= "ctr(sm4)",
			.cra_driver_name	= "ctr-sm4-aesni-avx",
			.cra_priority		= 400,
			.cra_blocksize		= 1,
			.cra_ctxsize		= sizeof(struct sm4_ctx),
			.cra_module		= THIS_MODULE,
		},
		.min_keysize	= SM4_KEY_SIZE,
		.max_keysize	= SM4_KEY_SIZE,
		.ivsize		= SM4_BLOCK_SIZE,
		.chunksize	= SM4_BLOCK_SIZE,
		.walksize	= 8 * SM4_BLOCK_SIZE,
		.setkey		= sm4_skcipher_setkey,
		.encrypt	= ctr_crypt,
		.decrypt	= ctr_crypt,
	}
};
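
/*
 * Rough usage sketch (illustrative only, not part of this driver): callers
 * reach these algorithms through the generic skcipher API rather than by
 * calling the functions above directly. The variables key, sg_in, sg_out,
 * iv and len below are assumed to be set up by the caller; error handling
 * is omitted.
 *
 *	struct crypto_skcipher *tfm = crypto_alloc_skcipher("ctr(sm4)", 0, 0);
 *	struct skcipher_request *req = skcipher_request_alloc(tfm, GFP_KERNEL);
 *	DECLARE_CRYPTO_WAIT(wait);
 *
 *	crypto_skcipher_setkey(tfm, key, SM4_KEY_SIZE);
 *	skcipher_request_set_callback(req, CRYPTO_TFM_REQ_MAY_SLEEP,
 *				      crypto_req_done, &wait);
 *	skcipher_request_set_crypt(req, sg_in, sg_out, len, iv);
 *	crypto_wait_req(crypto_skcipher_encrypt(req), &wait);
 *
 *	skcipher_request_free(req);
 *	crypto_free_skcipher(tfm);
 */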

static int __init sm4_init(void)
{
	const char *feature_name;

	if (!boot_cpu_has(X86_FEATURE_AVX) ||
	    !boot_cpu_has(X86_FEATURE_AES) ||
	    !boot_cpu_has(X86_FEATURE_OSXSAVE)) {
		pr_info("AVX or AES-NI instructions are not detected.\n");
		return -ENODEV;
	}

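	/* The AVX code relies on the kernel saving/restoring SSE and YMM state. */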
	if (!cpu_has_xfeatures(XFEATURE_MASK_SSE | XFEATURE_MASK_YMM,
				&feature_name)) {
		pr_info("CPU feature '%s' is not supported.\n", feature_name);
		return -ENODEV;
	}

	return crypto_register_skciphers(sm4_aesni_avx_skciphers,
					 ARRAY_SIZE(sm4_aesni_avx_skciphers));
}

static void __exit sm4_exit(void)
{
	crypto_unregister_skciphers(sm4_aesni_avx_skciphers,
				    ARRAY_SIZE(sm4_aesni_avx_skciphers));
}

module_init(sm4_init);
module_exit(sm4_exit);

MODULE_LICENSE("GPL v2");
MODULE_AUTHOR("Tianjia Zhang <tianjia.zhang@linux.alibaba.com>");
MODULE_DESCRIPTION("SM4 Cipher Algorithm, AES-NI/AVX optimized");
MODULE_ALIAS_CRYPTO("sm4");
MODULE_ALIAS_CRYPTO("sm4-aesni-avx");