1 // SPDX-License-Identifier: GPL-2.0-only 2 /* 3 * Bit sliced AES using NEON instructions 4 * 5 * Copyright (C) 2017 Linaro Ltd <ard.biesheuvel@linaro.org> 6 */ 7 8 #include <asm/neon.h> 9 #include <asm/simd.h> 10 #include <crypto/aes.h> 11 #include <crypto/ctr.h> 12 #include <crypto/internal/simd.h> 13 #include <crypto/internal/skcipher.h> 14 #include <crypto/scatterwalk.h> 15 #include <crypto/xts.h> 16 #include <linux/module.h> 17 #include "aes-cipher.h" 18 19 MODULE_AUTHOR("Ard Biesheuvel <ard.biesheuvel@linaro.org>"); 20 MODULE_DESCRIPTION("Bit sliced AES using NEON instructions"); 21 MODULE_LICENSE("GPL v2"); 22 23 MODULE_ALIAS_CRYPTO("ecb(aes)"); 24 MODULE_ALIAS_CRYPTO("cbc(aes)"); 25 MODULE_ALIAS_CRYPTO("ctr(aes)"); 26 MODULE_ALIAS_CRYPTO("xts(aes)"); 27 28 asmlinkage void aesbs_convert_key(u8 out[], u32 const rk[], int rounds); 29 30 asmlinkage void aesbs_ecb_encrypt(u8 out[], u8 const in[], u8 const rk[], 31 int rounds, int blocks); 32 asmlinkage void aesbs_ecb_decrypt(u8 out[], u8 const in[], u8 const rk[], 33 int rounds, int blocks); 34 35 asmlinkage void aesbs_cbc_decrypt(u8 out[], u8 const in[], u8 const rk[], 36 int rounds, int blocks, u8 iv[]); 37 38 asmlinkage void aesbs_ctr_encrypt(u8 out[], u8 const in[], u8 const rk[], 39 int rounds, int blocks, u8 ctr[]); 40 41 asmlinkage void aesbs_xts_encrypt(u8 out[], u8 const in[], u8 const rk[], 42 int rounds, int blocks, u8 iv[], int); 43 asmlinkage void aesbs_xts_decrypt(u8 out[], u8 const in[], u8 const rk[], 44 int rounds, int blocks, u8 iv[], int); 45 46 struct aesbs_ctx { 47 int rounds; 48 u8 rk[13 * (8 * AES_BLOCK_SIZE) + 32] __aligned(AES_BLOCK_SIZE); 49 }; 50 51 struct aesbs_cbc_ctx { 52 struct aesbs_ctx key; 53 struct crypto_aes_ctx fallback; 54 }; 55 56 struct aesbs_xts_ctx { 57 struct aesbs_ctx key; 58 struct crypto_aes_ctx fallback; 59 struct crypto_aes_ctx tweak_key; 60 }; 61 62 struct aesbs_ctr_ctx { 63 struct aesbs_ctx key; /* must be first member */ 64 struct crypto_aes_ctx fallback; 65 }; 66 67 static int aesbs_setkey(struct crypto_skcipher *tfm, const u8 *in_key, 68 unsigned int key_len) 69 { 70 struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm); 71 struct crypto_aes_ctx rk; 72 int err; 73 74 err = aes_expandkey(&rk, in_key, key_len); 75 if (err) 76 return err; 77 78 ctx->rounds = 6 + key_len / 4; 79 80 kernel_neon_begin(); 81 aesbs_convert_key(ctx->rk, rk.key_enc, ctx->rounds); 82 kernel_neon_end(); 83 84 return 0; 85 } 86 87 static int __ecb_crypt(struct skcipher_request *req, 88 void (*fn)(u8 out[], u8 const in[], u8 const rk[], 89 int rounds, int blocks)) 90 { 91 struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req); 92 struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm); 93 struct skcipher_walk walk; 94 int err; 95 96 err = skcipher_walk_virt(&walk, req, false); 97 98 while (walk.nbytes >= AES_BLOCK_SIZE) { 99 unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE; 100 101 if (walk.nbytes < walk.total) 102 blocks = round_down(blocks, 103 walk.stride / AES_BLOCK_SIZE); 104 105 kernel_neon_begin(); 106 fn(walk.dst.virt.addr, walk.src.virt.addr, ctx->rk, 107 ctx->rounds, blocks); 108 kernel_neon_end(); 109 err = skcipher_walk_done(&walk, 110 walk.nbytes - blocks * AES_BLOCK_SIZE); 111 } 112 113 return err; 114 } 115 116 static int ecb_encrypt(struct skcipher_request *req) 117 { 118 return __ecb_crypt(req, aesbs_ecb_encrypt); 119 } 120 121 static int ecb_decrypt(struct skcipher_request *req) 122 { 123 return __ecb_crypt(req, aesbs_ecb_decrypt); 124 } 125 126 static int aesbs_cbc_setkey(struct crypto_skcipher *tfm, const u8 *in_key, 127 unsigned int key_len) 128 { 129 struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm); 130 int err; 131 132 err = aes_expandkey(&ctx->fallback, in_key, key_len); 133 if (err) 134 return err; 135 136 ctx->key.rounds = 6 + key_len / 4; 137 138 kernel_neon_begin(); 139 aesbs_convert_key(ctx->key.rk, ctx->fallback.key_enc, ctx->key.rounds); 140 kernel_neon_end(); 141 142 return 0; 143 } 144 145 static int cbc_encrypt(struct skcipher_request *req) 146 { 147 struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req); 148 const struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm); 149 struct skcipher_walk walk; 150 unsigned int nbytes; 151 int err; 152 153 err = skcipher_walk_virt(&walk, req, false); 154 155 while ((nbytes = walk.nbytes) >= AES_BLOCK_SIZE) { 156 const u8 *src = walk.src.virt.addr; 157 u8 *dst = walk.dst.virt.addr; 158 u8 *prev = walk.iv; 159 160 do { 161 crypto_xor_cpy(dst, src, prev, AES_BLOCK_SIZE); 162 __aes_arm_encrypt(ctx->fallback.key_enc, 163 ctx->key.rounds, dst, dst); 164 prev = dst; 165 src += AES_BLOCK_SIZE; 166 dst += AES_BLOCK_SIZE; 167 nbytes -= AES_BLOCK_SIZE; 168 } while (nbytes >= AES_BLOCK_SIZE); 169 memcpy(walk.iv, prev, AES_BLOCK_SIZE); 170 err = skcipher_walk_done(&walk, nbytes); 171 } 172 return err; 173 } 174 175 static int cbc_decrypt(struct skcipher_request *req) 176 { 177 struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req); 178 struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm); 179 struct skcipher_walk walk; 180 int err; 181 182 err = skcipher_walk_virt(&walk, req, false); 183 184 while (walk.nbytes >= AES_BLOCK_SIZE) { 185 unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE; 186 187 if (walk.nbytes < walk.total) 188 blocks = round_down(blocks, 189 walk.stride / AES_BLOCK_SIZE); 190 191 kernel_neon_begin(); 192 aesbs_cbc_decrypt(walk.dst.virt.addr, walk.src.virt.addr, 193 ctx->key.rk, ctx->key.rounds, blocks, 194 walk.iv); 195 kernel_neon_end(); 196 err = skcipher_walk_done(&walk, 197 walk.nbytes - blocks * AES_BLOCK_SIZE); 198 } 199 200 return err; 201 } 202 203 static int aesbs_ctr_setkey_sync(struct crypto_skcipher *tfm, const u8 *in_key, 204 unsigned int key_len) 205 { 206 struct aesbs_ctr_ctx *ctx = crypto_skcipher_ctx(tfm); 207 int err; 208 209 err = aes_expandkey(&ctx->fallback, in_key, key_len); 210 if (err) 211 return err; 212 213 ctx->key.rounds = 6 + key_len / 4; 214 215 kernel_neon_begin(); 216 aesbs_convert_key(ctx->key.rk, ctx->fallback.key_enc, ctx->key.rounds); 217 kernel_neon_end(); 218 219 return 0; 220 } 221 222 static int ctr_encrypt(struct skcipher_request *req) 223 { 224 struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req); 225 struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm); 226 struct skcipher_walk walk; 227 u8 buf[AES_BLOCK_SIZE]; 228 int err; 229 230 err = skcipher_walk_virt(&walk, req, false); 231 232 while (walk.nbytes > 0) { 233 const u8 *src = walk.src.virt.addr; 234 u8 *dst = walk.dst.virt.addr; 235 int bytes = walk.nbytes; 236 237 if (unlikely(bytes < AES_BLOCK_SIZE)) 238 src = dst = memcpy(buf + sizeof(buf) - bytes, 239 src, bytes); 240 else if (walk.nbytes < walk.total) 241 bytes &= ~(8 * AES_BLOCK_SIZE - 1); 242 243 kernel_neon_begin(); 244 aesbs_ctr_encrypt(dst, src, ctx->rk, ctx->rounds, bytes, walk.iv); 245 kernel_neon_end(); 246 247 if (unlikely(bytes < AES_BLOCK_SIZE)) 248 memcpy(walk.dst.virt.addr, 249 buf + sizeof(buf) - bytes, bytes); 250 251 err = skcipher_walk_done(&walk, walk.nbytes - bytes); 252 } 253 254 return err; 255 } 256 257 static void ctr_encrypt_one(struct crypto_skcipher *tfm, const u8 *src, u8 *dst) 258 { 259 struct aesbs_ctr_ctx *ctx = crypto_skcipher_ctx(tfm); 260 261 __aes_arm_encrypt(ctx->fallback.key_enc, ctx->key.rounds, src, dst); 262 } 263 264 static int ctr_encrypt_sync(struct skcipher_request *req) 265 { 266 if (!crypto_simd_usable()) 267 return crypto_ctr_encrypt_walk(req, ctr_encrypt_one); 268 269 return ctr_encrypt(req); 270 } 271 272 static int aesbs_xts_setkey(struct crypto_skcipher *tfm, const u8 *in_key, 273 unsigned int key_len) 274 { 275 struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm); 276 int err; 277 278 err = xts_verify_key(tfm, in_key, key_len); 279 if (err) 280 return err; 281 282 key_len /= 2; 283 err = aes_expandkey(&ctx->fallback, in_key, key_len); 284 if (err) 285 return err; 286 err = aes_expandkey(&ctx->tweak_key, in_key + key_len, key_len); 287 if (err) 288 return err; 289 290 return aesbs_setkey(tfm, in_key, key_len); 291 } 292 293 static int __xts_crypt(struct skcipher_request *req, bool encrypt, 294 void (*fn)(u8 out[], u8 const in[], u8 const rk[], 295 int rounds, int blocks, u8 iv[], int)) 296 { 297 struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req); 298 struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm); 299 const int rounds = ctx->key.rounds; 300 int tail = req->cryptlen % AES_BLOCK_SIZE; 301 struct skcipher_request subreq; 302 u8 buf[2 * AES_BLOCK_SIZE]; 303 struct skcipher_walk walk; 304 int err; 305 306 if (req->cryptlen < AES_BLOCK_SIZE) 307 return -EINVAL; 308 309 if (unlikely(tail)) { 310 skcipher_request_set_tfm(&subreq, tfm); 311 skcipher_request_set_callback(&subreq, 312 skcipher_request_flags(req), 313 NULL, NULL); 314 skcipher_request_set_crypt(&subreq, req->src, req->dst, 315 req->cryptlen - tail, req->iv); 316 req = &subreq; 317 } 318 319 err = skcipher_walk_virt(&walk, req, true); 320 if (err) 321 return err; 322 323 __aes_arm_encrypt(ctx->tweak_key.key_enc, rounds, walk.iv, walk.iv); 324 325 while (walk.nbytes >= AES_BLOCK_SIZE) { 326 unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE; 327 int reorder_last_tweak = !encrypt && tail > 0; 328 329 if (walk.nbytes < walk.total) { 330 blocks = round_down(blocks, 331 walk.stride / AES_BLOCK_SIZE); 332 reorder_last_tweak = 0; 333 } 334 335 kernel_neon_begin(); 336 fn(walk.dst.virt.addr, walk.src.virt.addr, ctx->key.rk, 337 rounds, blocks, walk.iv, reorder_last_tweak); 338 kernel_neon_end(); 339 err = skcipher_walk_done(&walk, 340 walk.nbytes - blocks * AES_BLOCK_SIZE); 341 } 342 343 if (err || likely(!tail)) 344 return err; 345 346 /* handle ciphertext stealing */ 347 scatterwalk_map_and_copy(buf, req->dst, req->cryptlen - AES_BLOCK_SIZE, 348 AES_BLOCK_SIZE, 0); 349 memcpy(buf + AES_BLOCK_SIZE, buf, tail); 350 scatterwalk_map_and_copy(buf, req->src, req->cryptlen, tail, 0); 351 352 crypto_xor(buf, req->iv, AES_BLOCK_SIZE); 353 354 if (encrypt) 355 __aes_arm_encrypt(ctx->fallback.key_enc, rounds, buf, buf); 356 else 357 __aes_arm_decrypt(ctx->fallback.key_dec, rounds, buf, buf); 358 359 crypto_xor(buf, req->iv, AES_BLOCK_SIZE); 360 361 scatterwalk_map_and_copy(buf, req->dst, req->cryptlen - AES_BLOCK_SIZE, 362 AES_BLOCK_SIZE + tail, 1); 363 return 0; 364 } 365 366 static int xts_encrypt(struct skcipher_request *req) 367 { 368 return __xts_crypt(req, true, aesbs_xts_encrypt); 369 } 370 371 static int xts_decrypt(struct skcipher_request *req) 372 { 373 return __xts_crypt(req, false, aesbs_xts_decrypt); 374 } 375 376 static struct skcipher_alg aes_algs[] = { { 377 .base.cra_name = "__ecb(aes)", 378 .base.cra_driver_name = "__ecb-aes-neonbs", 379 .base.cra_priority = 250, 380 .base.cra_blocksize = AES_BLOCK_SIZE, 381 .base.cra_ctxsize = sizeof(struct aesbs_ctx), 382 .base.cra_module = THIS_MODULE, 383 .base.cra_flags = CRYPTO_ALG_INTERNAL, 384 385 .min_keysize = AES_MIN_KEY_SIZE, 386 .max_keysize = AES_MAX_KEY_SIZE, 387 .walksize = 8 * AES_BLOCK_SIZE, 388 .setkey = aesbs_setkey, 389 .encrypt = ecb_encrypt, 390 .decrypt = ecb_decrypt, 391 }, { 392 .base.cra_name = "__cbc(aes)", 393 .base.cra_driver_name = "__cbc-aes-neonbs", 394 .base.cra_priority = 250, 395 .base.cra_blocksize = AES_BLOCK_SIZE, 396 .base.cra_ctxsize = sizeof(struct aesbs_cbc_ctx), 397 .base.cra_module = THIS_MODULE, 398 .base.cra_flags = CRYPTO_ALG_INTERNAL, 399 400 .min_keysize = AES_MIN_KEY_SIZE, 401 .max_keysize = AES_MAX_KEY_SIZE, 402 .walksize = 8 * AES_BLOCK_SIZE, 403 .ivsize = AES_BLOCK_SIZE, 404 .setkey = aesbs_cbc_setkey, 405 .encrypt = cbc_encrypt, 406 .decrypt = cbc_decrypt, 407 }, { 408 .base.cra_name = "__ctr(aes)", 409 .base.cra_driver_name = "__ctr-aes-neonbs", 410 .base.cra_priority = 250, 411 .base.cra_blocksize = 1, 412 .base.cra_ctxsize = sizeof(struct aesbs_ctx), 413 .base.cra_module = THIS_MODULE, 414 .base.cra_flags = CRYPTO_ALG_INTERNAL, 415 416 .min_keysize = AES_MIN_KEY_SIZE, 417 .max_keysize = AES_MAX_KEY_SIZE, 418 .chunksize = AES_BLOCK_SIZE, 419 .walksize = 8 * AES_BLOCK_SIZE, 420 .ivsize = AES_BLOCK_SIZE, 421 .setkey = aesbs_setkey, 422 .encrypt = ctr_encrypt, 423 .decrypt = ctr_encrypt, 424 }, { 425 .base.cra_name = "ctr(aes)", 426 .base.cra_driver_name = "ctr-aes-neonbs-sync", 427 .base.cra_priority = 250 - 1, 428 .base.cra_blocksize = 1, 429 .base.cra_ctxsize = sizeof(struct aesbs_ctr_ctx), 430 .base.cra_module = THIS_MODULE, 431 432 .min_keysize = AES_MIN_KEY_SIZE, 433 .max_keysize = AES_MAX_KEY_SIZE, 434 .chunksize = AES_BLOCK_SIZE, 435 .walksize = 8 * AES_BLOCK_SIZE, 436 .ivsize = AES_BLOCK_SIZE, 437 .setkey = aesbs_ctr_setkey_sync, 438 .encrypt = ctr_encrypt_sync, 439 .decrypt = ctr_encrypt_sync, 440 }, { 441 .base.cra_name = "__xts(aes)", 442 .base.cra_driver_name = "__xts-aes-neonbs", 443 .base.cra_priority = 250, 444 .base.cra_blocksize = AES_BLOCK_SIZE, 445 .base.cra_ctxsize = sizeof(struct aesbs_xts_ctx), 446 .base.cra_module = THIS_MODULE, 447 .base.cra_flags = CRYPTO_ALG_INTERNAL, 448 449 .min_keysize = 2 * AES_MIN_KEY_SIZE, 450 .max_keysize = 2 * AES_MAX_KEY_SIZE, 451 .walksize = 8 * AES_BLOCK_SIZE, 452 .ivsize = AES_BLOCK_SIZE, 453 .setkey = aesbs_xts_setkey, 454 .encrypt = xts_encrypt, 455 .decrypt = xts_decrypt, 456 } }; 457 458 static struct simd_skcipher_alg *aes_simd_algs[ARRAY_SIZE(aes_algs)]; 459 460 static void aes_exit(void) 461 { 462 int i; 463 464 for (i = 0; i < ARRAY_SIZE(aes_simd_algs); i++) 465 if (aes_simd_algs[i]) 466 simd_skcipher_free(aes_simd_algs[i]); 467 468 crypto_unregister_skciphers(aes_algs, ARRAY_SIZE(aes_algs)); 469 } 470 471 static int __init aes_init(void) 472 { 473 struct simd_skcipher_alg *simd; 474 const char *basename; 475 const char *algname; 476 const char *drvname; 477 int err; 478 int i; 479 480 if (!(elf_hwcap & HWCAP_NEON)) 481 return -ENODEV; 482 483 err = crypto_register_skciphers(aes_algs, ARRAY_SIZE(aes_algs)); 484 if (err) 485 return err; 486 487 for (i = 0; i < ARRAY_SIZE(aes_algs); i++) { 488 if (!(aes_algs[i].base.cra_flags & CRYPTO_ALG_INTERNAL)) 489 continue; 490 491 algname = aes_algs[i].base.cra_name + 2; 492 drvname = aes_algs[i].base.cra_driver_name + 2; 493 basename = aes_algs[i].base.cra_driver_name; 494 simd = simd_skcipher_create_compat(aes_algs + i, algname, drvname, basename); 495 err = PTR_ERR(simd); 496 if (IS_ERR(simd)) 497 goto unregister_simds; 498 499 aes_simd_algs[i] = simd; 500 } 501 return 0; 502 503 unregister_simds: 504 aes_exit(); 505 return err; 506 } 507 508 late_initcall(aes_init); 509 module_exit(aes_exit); 510