/* SPDX-License-Identifier: GPL-2.0-or-later */
/*
 * SM4 Cipher Algorithm, AES-NI/AVX optimized.
 * as specified in
 * https://tools.ietf.org/id/draft-ribose-cfrg-sm4-10.html
 *
 * Copyright (c) 2021, Alibaba Group.
 * Copyright (c) 2021 Tianjia Zhang <tianjia.zhang@linux.alibaba.com>
 */

#include <linux/module.h>
#include <linux/crypto.h>
#include <linux/kernel.h>
#include <asm/simd.h>
#include <crypto/internal/simd.h>
#include <crypto/internal/skcipher.h>
#include <crypto/sm4.h>
#include "sm4-avx.h"

#define SM4_CRYPT8_BLOCK_SIZE	(SM4_BLOCK_SIZE * 8)

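/*
 * Bulk SM4 primitives provided by the accompanying AES-NI/AVX assembly:
 * raw 4-/8-block encryption plus dedicated 8-block CTR and CBC-decrypt
 * routines.
 */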
asmlinkage void sm4_aesni_avx_crypt4(const u32 *rk, u8 *dst,
				const u8 *src, int nblocks);
asmlinkage void sm4_aesni_avx_crypt8(const u32 *rk, u8 *dst,
				const u8 *src, int nblocks);
asmlinkage void sm4_aesni_avx_ctr_enc_blk8(const u32 *rk, u8 *dst,
				const u8 *src, u8 *iv);
asmlinkage void sm4_aesni_avx_cbc_dec_blk8(const u32 *rk, u8 *dst,
				const u8 *src, u8 *iv);

static int sm4_skcipher_setkey(struct crypto_skcipher *tfm, const u8 *key,
			       unsigned int key_len)
{
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);

	return sm4_expandkey(ctx, key, key_len);
}

static int ecb_do_crypt(struct skcipher_request *req, const u32 *rkey)
{
	struct skcipher_walk walk;
	unsigned int nbytes;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while ((nbytes = walk.nbytes) > 0) {
		const u8 *src = walk.src.virt.addr;
		u8 *dst = walk.dst.virt.addr;

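		/*
		 * Within one FPU section, consume 8 blocks per call to the
		 * AVX routine, then up to 4 blocks at a time for what is
		 * left over.
		 */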
		kernel_fpu_begin();
		while (nbytes >= SM4_CRYPT8_BLOCK_SIZE) {
			sm4_aesni_avx_crypt8(rkey, dst, src, 8);
			dst += SM4_CRYPT8_BLOCK_SIZE;
			src += SM4_CRYPT8_BLOCK_SIZE;
			nbytes -= SM4_CRYPT8_BLOCK_SIZE;
		}
		while (nbytes >= SM4_BLOCK_SIZE) {
			unsigned int nblocks = min(nbytes >> 4, 4u);
			sm4_aesni_avx_crypt4(rkey, dst, src, nblocks);
			dst += nblocks * SM4_BLOCK_SIZE;
			src += nblocks * SM4_BLOCK_SIZE;
			nbytes -= nblocks * SM4_BLOCK_SIZE;
		}
		kernel_fpu_end();

		err = skcipher_walk_done(&walk, nbytes);
	}

	return err;
}

int sm4_avx_ecb_encrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);

	return ecb_do_crypt(req, ctx->rkey_enc);
}
EXPORT_SYMBOL_GPL(sm4_avx_ecb_encrypt);

int sm4_avx_ecb_decrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);

	return ecb_do_crypt(req, ctx->rkey_dec);
}
EXPORT_SYMBOL_GPL(sm4_avx_ecb_decrypt);

int sm4_cbc_encrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	unsigned int nbytes;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while ((nbytes = walk.nbytes) > 0) {
		const u8 *iv = walk.iv;
		const u8 *src = walk.src.virt.addr;
		u8 *dst = walk.dst.virt.addr;

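		/*
		 * CBC encryption is inherently serial (each block chains on
		 * the previous ciphertext), so encrypt block by block with
		 * the generic sm4_crypt_block() and carry the chaining value
		 * in 'iv'.
		 */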
		while (nbytes >= SM4_BLOCK_SIZE) {
			crypto_xor_cpy(dst, src, iv, SM4_BLOCK_SIZE);
			sm4_crypt_block(ctx->rkey_enc, dst, dst);
			iv = dst;
			src += SM4_BLOCK_SIZE;
			dst += SM4_BLOCK_SIZE;
			nbytes -= SM4_BLOCK_SIZE;
		}
		if (iv != walk.iv)
			memcpy(walk.iv, iv, SM4_BLOCK_SIZE);

		err = skcipher_walk_done(&walk, nbytes);
	}

	return err;
}
EXPORT_SYMBOL_GPL(sm4_cbc_encrypt);

int sm4_avx_cbc_decrypt(struct skcipher_request *req,
			unsigned int bsize, sm4_crypt_func func)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	unsigned int nbytes;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while ((nbytes = walk.nbytes) > 0) {
		const u8 *src = walk.src.virt.addr;
		u8 *dst = walk.dst.virt.addr;

		kernel_fpu_begin();

		while (nbytes >= bsize) {
			func(ctx->rkey_dec, dst, src, walk.iv);
			dst += bsize;
			src += bsize;
			nbytes -= bsize;
		}

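		/*
		 * Remaining full blocks: decrypt up to 8 of them in ECB mode
		 * into keystream[], then walk backwards XORing each result
		 * with the preceding ciphertext block (walk.iv for the very
		 * first block) to undo the chaining in place.  The last
		 * ciphertext block is saved first so it can become the next
		 * IV.
		 */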
		while (nbytes >= SM4_BLOCK_SIZE) {
			u8 keystream[SM4_BLOCK_SIZE * 8];
			u8 iv[SM4_BLOCK_SIZE];
			unsigned int nblocks = min(nbytes >> 4, 8u);
			int i;

			sm4_aesni_avx_crypt8(ctx->rkey_dec, keystream,
					     src, nblocks);

			src += ((int)nblocks - 2) * SM4_BLOCK_SIZE;
			dst += (nblocks - 1) * SM4_BLOCK_SIZE;
			memcpy(iv, src + SM4_BLOCK_SIZE, SM4_BLOCK_SIZE);

			for (i = nblocks - 1; i > 0; i--) {
				crypto_xor_cpy(dst, src,
					       &keystream[i * SM4_BLOCK_SIZE],
					       SM4_BLOCK_SIZE);
				src -= SM4_BLOCK_SIZE;
				dst -= SM4_BLOCK_SIZE;
			}
			crypto_xor_cpy(dst, walk.iv, keystream, SM4_BLOCK_SIZE);
			memcpy(walk.iv, iv, SM4_BLOCK_SIZE);
			dst += nblocks * SM4_BLOCK_SIZE;
			src += (nblocks + 1) * SM4_BLOCK_SIZE;
			nbytes -= nblocks * SM4_BLOCK_SIZE;
		}

		kernel_fpu_end();
		err = skcipher_walk_done(&walk, nbytes);
	}

	return err;
}
EXPORT_SYMBOL_GPL(sm4_avx_cbc_decrypt);

static int cbc_decrypt(struct skcipher_request *req)
{
	return sm4_avx_cbc_decrypt(req, SM4_CRYPT8_BLOCK_SIZE,
				   sm4_aesni_avx_cbc_dec_blk8);
}

int sm4_avx_ctr_crypt(struct skcipher_request *req,
		      unsigned int bsize, sm4_crypt_func func)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct sm4_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	unsigned int nbytes;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while ((nbytes = walk.nbytes) > 0) {
		const u8 *src = walk.src.virt.addr;
		u8 *dst = walk.dst.virt.addr;

		kernel_fpu_begin();

		while (nbytes >= bsize) {
			func(ctx->rkey_enc, dst, src, walk.iv);
			dst += bsize;
			src += bsize;
			nbytes -= bsize;
		}

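		/*
		 * Remaining full blocks: build up to 8 counter blocks from
		 * walk.iv, encrypt them as the keystream, and XOR the
		 * keystream into the data.
		 */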
		while (nbytes >= SM4_BLOCK_SIZE) {
			u8 keystream[SM4_BLOCK_SIZE * 8];
			unsigned int nblocks = min(nbytes >> 4, 8u);
			int i;

			for (i = 0; i < nblocks; i++) {
				memcpy(&keystream[i * SM4_BLOCK_SIZE],
				       walk.iv, SM4_BLOCK_SIZE);
				crypto_inc(walk.iv, SM4_BLOCK_SIZE);
			}
			sm4_aesni_avx_crypt8(ctx->rkey_enc, keystream,
					     keystream, nblocks);

			crypto_xor_cpy(dst, src, keystream,
				       nblocks * SM4_BLOCK_SIZE);
			dst += nblocks * SM4_BLOCK_SIZE;
			src += nblocks * SM4_BLOCK_SIZE;
			nbytes -= nblocks * SM4_BLOCK_SIZE;
		}

		kernel_fpu_end();

		/* tail */
		if (walk.nbytes == walk.total && nbytes > 0) {
			u8 keystream[SM4_BLOCK_SIZE];

			memcpy(keystream, walk.iv, SM4_BLOCK_SIZE);
			crypto_inc(walk.iv, SM4_BLOCK_SIZE);

			sm4_crypt_block(ctx->rkey_enc, keystream, keystream);

			crypto_xor_cpy(dst, src, keystream, nbytes);
			dst += nbytes;
			src += nbytes;
			nbytes = 0;
		}

		err = skcipher_walk_done(&walk, nbytes);
	}

	return err;
}
EXPORT_SYMBOL_GPL(sm4_avx_ctr_crypt);

static int ctr_crypt(struct skcipher_request *req)
{
	return sm4_avx_ctr_crypt(req, SM4_CRYPT8_BLOCK_SIZE,
				 sm4_aesni_avx_ctr_enc_blk8);
}

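/*
 * The algorithms below are registered as internal-only (CRYPTO_ALG_INTERNAL);
 * users reach them through the simd wrappers created in sm4_init(), which
 * fall back to an asynchronous (cryptd) path when the FPU is not usable.
 */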
static struct skcipher_alg sm4_aesni_avx_skciphers[] = {
	{
		.base = {
			.cra_name		= "__ecb(sm4)",
			.cra_driver_name	= "__ecb-sm4-aesni-avx",
			.cra_priority		= 400,
			.cra_flags		= CRYPTO_ALG_INTERNAL,
			.cra_blocksize		= SM4_BLOCK_SIZE,
			.cra_ctxsize		= sizeof(struct sm4_ctx),
			.cra_module		= THIS_MODULE,
		},
		.min_keysize	= SM4_KEY_SIZE,
		.max_keysize	= SM4_KEY_SIZE,
		.walksize	= 8 * SM4_BLOCK_SIZE,
		.setkey		= sm4_skcipher_setkey,
		.encrypt	= sm4_avx_ecb_encrypt,
		.decrypt	= sm4_avx_ecb_decrypt,
	}, {
		.base = {
			.cra_name		= "__cbc(sm4)",
			.cra_driver_name	= "__cbc-sm4-aesni-avx",
			.cra_priority		= 400,
			.cra_flags		= CRYPTO_ALG_INTERNAL,
			.cra_blocksize		= SM4_BLOCK_SIZE,
			.cra_ctxsize		= sizeof(struct sm4_ctx),
			.cra_module		= THIS_MODULE,
		},
		.min_keysize	= SM4_KEY_SIZE,
		.max_keysize	= SM4_KEY_SIZE,
		.ivsize		= SM4_BLOCK_SIZE,
		.walksize	= 8 * SM4_BLOCK_SIZE,
		.setkey		= sm4_skcipher_setkey,
		.encrypt	= sm4_cbc_encrypt,
		.decrypt	= cbc_decrypt,
	}, {
		.base = {
			.cra_name		= "__ctr(sm4)",
			.cra_driver_name	= "__ctr-sm4-aesni-avx",
			.cra_priority		= 400,
			.cra_flags		= CRYPTO_ALG_INTERNAL,
			.cra_blocksize		= 1,
			.cra_ctxsize		= sizeof(struct sm4_ctx),
			.cra_module		= THIS_MODULE,
		},
		.min_keysize	= SM4_KEY_SIZE,
		.max_keysize	= SM4_KEY_SIZE,
		.ivsize		= SM4_BLOCK_SIZE,
		.chunksize	= SM4_BLOCK_SIZE,
		.walksize	= 8 * SM4_BLOCK_SIZE,
		.setkey		= sm4_skcipher_setkey,
		.encrypt	= ctr_crypt,
		.decrypt	= ctr_crypt,
	}
};

static struct simd_skcipher_alg *
simd_sm4_aesni_avx_skciphers[ARRAY_SIZE(sm4_aesni_avx_skciphers)];

static int __init sm4_init(void)
{
	const char *feature_name;

	if (!boot_cpu_has(X86_FEATURE_AVX) ||
	    !boot_cpu_has(X86_FEATURE_AES) ||
	    !boot_cpu_has(X86_FEATURE_OSXSAVE)) {
		pr_info("AVX or AES-NI instructions are not detected.\n");
		return -ENODEV;
	}

	if (!cpu_has_xfeatures(XFEATURE_MASK_SSE | XFEATURE_MASK_YMM,
			       &feature_name)) {
		pr_info("CPU feature '%s' is not supported.\n", feature_name);
		return -ENODEV;
	}

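	/*
	 * Register the internal algorithms and create the simd wrappers
	 * that are actually exposed to the crypto API.
	 */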
	return simd_register_skciphers_compat(sm4_aesni_avx_skciphers,
					ARRAY_SIZE(sm4_aesni_avx_skciphers),
					simd_sm4_aesni_avx_skciphers);
}

static void __exit sm4_exit(void)
{
	simd_unregister_skciphers(sm4_aesni_avx_skciphers,
				  ARRAY_SIZE(sm4_aesni_avx_skciphers),
				  simd_sm4_aesni_avx_skciphers);
}

module_init(sm4_init);
module_exit(sm4_exit);

MODULE_LICENSE("GPL v2");
MODULE_AUTHOR("Tianjia Zhang <tianjia.zhang@linux.alibaba.com>");
MODULE_DESCRIPTION("SM4 Cipher Algorithm, AES-NI/AVX optimized");
MODULE_ALIAS_CRYPTO("sm4");
MODULE_ALIAS_CRYPTO("sm4-aesni-avx");