// SPDX-License-Identifier: GPL-2.0-only
/*
 * Bit sliced AES using NEON instructions
 *
 * Copyright (C) 2016 - 2017 Linaro Ltd <ard.biesheuvel@linaro.org>
 */

#include <asm/neon.h>
#include <asm/simd.h>
#include <crypto/aes.h>
#include <crypto/ctr.h>
#include <crypto/internal/simd.h>
#include <crypto/internal/skcipher.h>
#include <crypto/scatterwalk.h>
#include <crypto/xts.h>
#include <linux/module.h>

MODULE_AUTHOR("Ard Biesheuvel <ard.biesheuvel@linaro.org>");
MODULE_LICENSE("GPL v2");

MODULE_ALIAS_CRYPTO("ecb(aes)");
MODULE_ALIAS_CRYPTO("cbc(aes)");
MODULE_ALIAS_CRYPTO("ctr(aes)");
MODULE_ALIAS_CRYPTO("xts(aes)");

asmlinkage void aesbs_convert_key(u8 out[], u32 const rk[], int rounds);

asmlinkage void aesbs_ecb_encrypt(u8 out[], u8 const in[], u8 const rk[],
                                  int rounds, int blocks);
asmlinkage void aesbs_ecb_decrypt(u8 out[], u8 const in[], u8 const rk[],
                                  int rounds, int blocks);

asmlinkage void aesbs_cbc_decrypt(u8 out[], u8 const in[], u8 const rk[],
                                  int rounds, int blocks, u8 iv[]);

asmlinkage void aesbs_ctr_encrypt(u8 out[], u8 const in[], u8 const rk[],
                                  int rounds, int blocks, u8 iv[]);

asmlinkage void aesbs_xts_encrypt(u8 out[], u8 const in[], u8 const rk[],
                                  int rounds, int blocks, u8 iv[]);
asmlinkage void aesbs_xts_decrypt(u8 out[], u8 const in[], u8 const rk[],
                                  int rounds, int blocks, u8 iv[]);

/* borrowed from aes-neon-blk.ko */
asmlinkage void neon_aes_ecb_encrypt(u8 out[], u8 const in[], u32 const rk[],
                                     int rounds, int blocks);
asmlinkage void neon_aes_cbc_encrypt(u8 out[], u8 const in[], u32 const rk[],
                                     int rounds, int blocks, u8 iv[]);
asmlinkage void neon_aes_ctr_encrypt(u8 out[], u8 const in[], u32 const rk[],
                                     int rounds, int bytes, u8 ctr[]);
asmlinkage void neon_aes_xts_encrypt(u8 out[], u8 const in[],
                                     u32 const rk1[], int rounds, int bytes,
                                     u32 const rk2[], u8 iv[], int first);
asmlinkage void neon_aes_xts_decrypt(u8 out[], u8 const in[],
                                     u32 const rk1[], int rounds, int bytes,
                                     u32 const rk2[], u8 iv[], int first);

struct aesbs_ctx {
        u8      rk[13 * (8 * AES_BLOCK_SIZE) + 32];
        int     rounds;
} __aligned(AES_BLOCK_SIZE);

struct aesbs_cbc_ctr_ctx {
        struct aesbs_ctx        key;
        u32                     enc[AES_MAX_KEYLENGTH_U32];
};

struct aesbs_xts_ctx {
        struct aesbs_ctx        key;
        u32                     twkey[AES_MAX_KEYLENGTH_U32];
        struct crypto_aes_ctx   cts;
};

static int aesbs_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
                        unsigned int key_len)
{
        struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
        struct crypto_aes_ctx rk;
        int err;

        err = aes_expandkey(&rk, in_key, key_len);
        if (err)
                return err;

        ctx->rounds = 6 + key_len / 4;

        kernel_neon_begin();
        aesbs_convert_key(ctx->rk, rk.key_enc, ctx->rounds);
        kernel_neon_end();

        return 0;
}

static int __ecb_crypt(struct skcipher_request *req,
                       void (*fn)(u8 out[], u8 const in[], u8 const rk[],
                                  int rounds, int blocks))
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
        struct skcipher_walk walk;
        int err;

        err = skcipher_walk_virt(&walk, req, false);

        while (walk.nbytes >= AES_BLOCK_SIZE) {
                unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;

                if (walk.nbytes < walk.total)
                        blocks = round_down(blocks,
                                            walk.stride / AES_BLOCK_SIZE);

                kernel_neon_begin();
                fn(walk.dst.virt.addr, walk.src.virt.addr, ctx->rk,
                   ctx->rounds, blocks);
                kernel_neon_end();
                err = skcipher_walk_done(&walk,
                                         walk.nbytes - blocks * AES_BLOCK_SIZE);
        }

        return err;
}

static int ecb_encrypt(struct skcipher_request *req)
{
        return __ecb_crypt(req, aesbs_ecb_encrypt);
}

static int ecb_decrypt(struct skcipher_request *req)
{
        return __ecb_crypt(req, aesbs_ecb_decrypt);
}

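/*
 * CBC and CTR share a context: the bit-sliced round keys in ->key drive the
 * eight-block parallel paths (CBC decryption and the bulk of CTR), while the
 * plain AES key schedule kept in ->enc feeds the non-bit-sliced NEON helpers
 * used for CBC encryption and for CTR tails of fewer than eight full blocks.
 */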
static int aesbs_cbc_ctr_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
                                unsigned int key_len)
{
        struct aesbs_cbc_ctr_ctx *ctx = crypto_skcipher_ctx(tfm);
        struct crypto_aes_ctx rk;
        int err;

        err = aes_expandkey(&rk, in_key, key_len);
        if (err)
                return err;

        ctx->key.rounds = 6 + key_len / 4;

        memcpy(ctx->enc, rk.key_enc, sizeof(ctx->enc));

        kernel_neon_begin();
        aesbs_convert_key(ctx->key.rk, rk.key_enc, ctx->key.rounds);
        kernel_neon_end();
        memzero_explicit(&rk, sizeof(rk));

        return 0;
}

static int cbc_encrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct aesbs_cbc_ctr_ctx *ctx = crypto_skcipher_ctx(tfm);
        struct skcipher_walk walk;
        int err;

        err = skcipher_walk_virt(&walk, req, false);

        while (walk.nbytes >= AES_BLOCK_SIZE) {
                unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;

                /* fall back to the non-bitsliced NEON implementation */
                kernel_neon_begin();
                neon_aes_cbc_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
                                     ctx->enc, ctx->key.rounds, blocks,
                                     walk.iv);
                kernel_neon_end();
                err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
        }
        return err;
}

static int cbc_decrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct aesbs_cbc_ctr_ctx *ctx = crypto_skcipher_ctx(tfm);
        struct skcipher_walk walk;
        int err;

        err = skcipher_walk_virt(&walk, req, false);

        while (walk.nbytes >= AES_BLOCK_SIZE) {
                unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;

                if (walk.nbytes < walk.total)
                        blocks = round_down(blocks,
                                            walk.stride / AES_BLOCK_SIZE);

                kernel_neon_begin();
                aesbs_cbc_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
                                  ctx->key.rk, ctx->key.rounds, blocks,
                                  walk.iv);
                kernel_neon_end();
                err = skcipher_walk_done(&walk,
                                         walk.nbytes - blocks * AES_BLOCK_SIZE);
        }

        return err;
}

static int ctr_encrypt(struct skcipher_request *req)
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct aesbs_cbc_ctr_ctx *ctx = crypto_skcipher_ctx(tfm);
        struct skcipher_walk walk;
        int err;

        err = skcipher_walk_virt(&walk, req, false);

        while (walk.nbytes > 0) {
                int blocks = (walk.nbytes / AES_BLOCK_SIZE) & ~7;
                int nbytes = walk.nbytes % (8 * AES_BLOCK_SIZE);
                const u8 *src = walk.src.virt.addr;
                u8 *dst = walk.dst.virt.addr;

                kernel_neon_begin();
                if (blocks >= 8) {
                        aesbs_ctr_encrypt(dst, src, ctx->key.rk, ctx->key.rounds,
                                          blocks, walk.iv);
                        dst += blocks * AES_BLOCK_SIZE;
                        src += blocks * AES_BLOCK_SIZE;
                }
                if (nbytes && walk.nbytes == walk.total) {
                        u8 buf[AES_BLOCK_SIZE];
                        u8 *d = dst;

                        if (unlikely(nbytes < AES_BLOCK_SIZE))
                                src = dst = memcpy(buf + sizeof(buf) - nbytes,
                                                   src, nbytes);

                        neon_aes_ctr_encrypt(dst, src, ctx->enc, ctx->key.rounds,
                                             nbytes, walk.iv);

                        if (unlikely(nbytes < AES_BLOCK_SIZE))
                                memcpy(d, dst, nbytes);

                        nbytes = 0;
                }
                kernel_neon_end();
                err = skcipher_walk_done(&walk, nbytes);
        }
        return err;
}

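/*
 * XTS uses both halves of the supplied key: the first half is expanded twice,
 * once into the bit-sliced form for the eight-block bulk path and once into
 * ->cts for the non-bit-sliced tail and ciphertext-stealing path, while the
 * second half becomes the tweak key schedule in ->twkey.
 */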
static int aesbs_xts_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
                            unsigned int key_len)
{
        struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
        struct crypto_aes_ctx rk;
        int err;

        err = xts_verify_key(tfm, in_key, key_len);
        if (err)
                return err;

        key_len /= 2;
        err = aes_expandkey(&ctx->cts, in_key, key_len);
        if (err)
                return err;

        err = aes_expandkey(&rk, in_key + key_len, key_len);
        if (err)
                return err;

        memcpy(ctx->twkey, rk.key_enc, sizeof(ctx->twkey));

        return aesbs_setkey(tfm, in_key, key_len);
}

static int __xts_crypt(struct skcipher_request *req, bool encrypt,
                       void (*fn)(u8 out[], u8 const in[], u8 const rk[],
                                  int rounds, int blocks, u8 iv[]))
{
        struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
        struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
        int tail = req->cryptlen % (8 * AES_BLOCK_SIZE);
        struct scatterlist sg_src[2], sg_dst[2];
        struct skcipher_request subreq;
        struct scatterlist *src, *dst;
        struct skcipher_walk walk;
        int nbytes, err;
        int first = 1;
        u8 *out, *in;

        if (req->cryptlen < AES_BLOCK_SIZE)
                return -EINVAL;

        /* ensure that the cts tail is covered by a single step */
        if (unlikely(tail > 0 && tail < AES_BLOCK_SIZE)) {
                int xts_blocks = DIV_ROUND_UP(req->cryptlen,
                                              AES_BLOCK_SIZE) - 2;

                skcipher_request_set_tfm(&subreq, tfm);
                skcipher_request_set_callback(&subreq,
                                              skcipher_request_flags(req),
                                              NULL, NULL);
                skcipher_request_set_crypt(&subreq, req->src, req->dst,
                                           xts_blocks * AES_BLOCK_SIZE,
                                           req->iv);
                req = &subreq;
        } else {
                tail = 0;
        }

        err = skcipher_walk_virt(&walk, req, false);
        if (err)
                return err;

        while (walk.nbytes >= AES_BLOCK_SIZE) {
                int blocks = (walk.nbytes / AES_BLOCK_SIZE) & ~7;
                out = walk.dst.virt.addr;
                in = walk.src.virt.addr;
                nbytes = walk.nbytes;

                kernel_neon_begin();
                if (blocks >= 8) {
                        if (first == 1)
                                neon_aes_ecb_encrypt(walk.iv, walk.iv,
                                                     ctx->twkey,
                                                     ctx->key.rounds, 1);
                        first = 2;

                        fn(out, in, ctx->key.rk, ctx->key.rounds, blocks,
                           walk.iv);

                        out += blocks * AES_BLOCK_SIZE;
                        in += blocks * AES_BLOCK_SIZE;
                        nbytes -= blocks * AES_BLOCK_SIZE;
                }
                if (walk.nbytes == walk.total && nbytes > 0) {
                        if (encrypt)
                                neon_aes_xts_encrypt(out, in, ctx->cts.key_enc,
                                                     ctx->key.rounds, nbytes,
                                                     ctx->twkey, walk.iv, first);
                        else
                                neon_aes_xts_decrypt(out, in, ctx->cts.key_dec,
                                                     ctx->key.rounds, nbytes,
                                                     ctx->twkey, walk.iv, first);
                        nbytes = first = 0;
                }
                kernel_neon_end();
                err = skcipher_walk_done(&walk, nbytes);
        }

        if (err || likely(!tail))
                return err;

        /* handle ciphertext stealing */
        dst = src = scatterwalk_ffwd(sg_src, req->src, req->cryptlen);
        if (req->dst != req->src)
                dst = scatterwalk_ffwd(sg_dst, req->dst, req->cryptlen);

        skcipher_request_set_crypt(req, src, dst, AES_BLOCK_SIZE + tail,
                                   req->iv);

        err = skcipher_walk_virt(&walk, req, false);
        if (err)
                return err;

        out = walk.dst.virt.addr;
        in = walk.src.virt.addr;
        nbytes = walk.nbytes;

        kernel_neon_begin();
        if (encrypt)
                neon_aes_xts_encrypt(out, in, ctx->cts.key_enc, ctx->key.rounds,
                                     nbytes, ctx->twkey, walk.iv, first);
        else
                neon_aes_xts_decrypt(out, in, ctx->cts.key_dec, ctx->key.rounds,
                                     nbytes, ctx->twkey, walk.iv, first);
        kernel_neon_end();

        return skcipher_walk_done(&walk, 0);
}

static int xts_encrypt(struct skcipher_request *req)
{
        return __xts_crypt(req, true, aesbs_xts_encrypt);
}

static int xts_decrypt(struct skcipher_request *req)
{
        return __xts_crypt(req, false, aesbs_xts_decrypt);
}

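/*
 * The walksize of eight blocks asks the skcipher walk layer to present the
 * bit-sliced code with multiples of eight blocks whenever possible; CTR and
 * XTS pass any remaining tail to the plain NEON helpers declared above.
 */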
static struct skcipher_alg aes_algs[] = { {
        .base.cra_name          = "ecb(aes)",
        .base.cra_driver_name   = "ecb-aes-neonbs",
        .base.cra_priority      = 250,
        .base.cra_blocksize     = AES_BLOCK_SIZE,
        .base.cra_ctxsize       = sizeof(struct aesbs_ctx),
        .base.cra_module        = THIS_MODULE,

        .min_keysize            = AES_MIN_KEY_SIZE,
        .max_keysize            = AES_MAX_KEY_SIZE,
        .walksize               = 8 * AES_BLOCK_SIZE,
        .setkey                 = aesbs_setkey,
        .encrypt                = ecb_encrypt,
        .decrypt                = ecb_decrypt,
}, {
        .base.cra_name          = "cbc(aes)",
        .base.cra_driver_name   = "cbc-aes-neonbs",
        .base.cra_priority      = 250,
        .base.cra_blocksize     = AES_BLOCK_SIZE,
        .base.cra_ctxsize       = sizeof(struct aesbs_cbc_ctr_ctx),
        .base.cra_module        = THIS_MODULE,

        .min_keysize            = AES_MIN_KEY_SIZE,
        .max_keysize            = AES_MAX_KEY_SIZE,
        .walksize               = 8 * AES_BLOCK_SIZE,
        .ivsize                 = AES_BLOCK_SIZE,
        .setkey                 = aesbs_cbc_ctr_setkey,
        .encrypt                = cbc_encrypt,
        .decrypt                = cbc_decrypt,
}, {
        .base.cra_name          = "ctr(aes)",
        .base.cra_driver_name   = "ctr-aes-neonbs",
        .base.cra_priority      = 250,
        .base.cra_blocksize     = 1,
        .base.cra_ctxsize       = sizeof(struct aesbs_cbc_ctr_ctx),
        .base.cra_module        = THIS_MODULE,

        .min_keysize            = AES_MIN_KEY_SIZE,
        .max_keysize            = AES_MAX_KEY_SIZE,
        .chunksize              = AES_BLOCK_SIZE,
        .walksize               = 8 * AES_BLOCK_SIZE,
        .ivsize                 = AES_BLOCK_SIZE,
        .setkey                 = aesbs_cbc_ctr_setkey,
        .encrypt                = ctr_encrypt,
        .decrypt                = ctr_encrypt,
}, {
        .base.cra_name          = "xts(aes)",
        .base.cra_driver_name   = "xts-aes-neonbs",
        .base.cra_priority      = 250,
        .base.cra_blocksize     = AES_BLOCK_SIZE,
        .base.cra_ctxsize       = sizeof(struct aesbs_xts_ctx),
        .base.cra_module        = THIS_MODULE,

        .min_keysize            = 2 * AES_MIN_KEY_SIZE,
        .max_keysize            = 2 * AES_MAX_KEY_SIZE,
        .walksize               = 8 * AES_BLOCK_SIZE,
        .ivsize                 = AES_BLOCK_SIZE,
        .setkey                 = aesbs_xts_setkey,
        .encrypt                = xts_encrypt,
        .decrypt                = xts_decrypt,
} };

static void aes_exit(void)
{
        crypto_unregister_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
}

static int __init aes_init(void)
{
        if (!cpu_have_named_feature(ASIMD))
                return -ENODEV;

        return crypto_register_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
}

module_init(aes_init);
module_exit(aes_exit);