/*
 * Bit sliced AES using NEON instructions
 *
 * Copyright (C) 2016 - 2017 Linaro Ltd <ard.biesheuvel@linaro.org>
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License version 2 as
 * published by the Free Software Foundation.
 */

#include <asm/neon.h>
#include <asm/simd.h>
#include <crypto/aes.h>
#include <crypto/internal/simd.h>
#include <crypto/internal/skcipher.h>
#include <crypto/xts.h>
#include <linux/module.h>

#include "aes-ctr-fallback.h"

MODULE_AUTHOR("Ard Biesheuvel <ard.biesheuvel@linaro.org>");
MODULE_LICENSE("GPL v2");

MODULE_ALIAS_CRYPTO("ecb(aes)");
MODULE_ALIAS_CRYPTO("cbc(aes)");
MODULE_ALIAS_CRYPTO("ctr(aes)");
MODULE_ALIAS_CRYPTO("xts(aes)");

asmlinkage void aesbs_convert_key(u8 out[], u32 const rk[], int rounds);

asmlinkage void aesbs_ecb_encrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks);
asmlinkage void aesbs_ecb_decrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks);

asmlinkage void aesbs_cbc_decrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[]);

asmlinkage void aesbs_ctr_encrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[], u8 final[]);

asmlinkage void aesbs_xts_encrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[]);
asmlinkage void aesbs_xts_decrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[]);

/* borrowed from aes-neon-blk.ko */
asmlinkage void neon_aes_ecb_encrypt(u8 out[], u8 const in[], u32 const rk[],
				     int rounds, int blocks, int first);
asmlinkage void neon_aes_cbc_encrypt(u8 out[], u8 const in[], u32 const rk[],
				     int rounds, int blocks, u8 iv[],
				     int first);

struct aesbs_ctx {
	u8	rk[13 * (8 * AES_BLOCK_SIZE) + 32];
	int	rounds;
} __aligned(AES_BLOCK_SIZE);

struct aesbs_cbc_ctx {
	struct aesbs_ctx	key;
	u32			enc[AES_MAX_KEYLENGTH_U32];
};

struct aesbs_ctr_ctx {
	struct aesbs_ctx	key;		/* must be first member */
	struct crypto_aes_ctx	fallback;
};

struct aesbs_xts_ctx {
	struct aesbs_ctx	key;
	u32			twkey[AES_MAX_KEYLENGTH_U32];
};

static int aesbs_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
			unsigned int key_len)
{
	struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct crypto_aes_ctx rk;
	int err;

	err = crypto_aes_expand_key(&rk, in_key, key_len);
	if (err)
		return err;

	ctx->rounds = 6 + key_len / 4;

	kernel_neon_begin();
	aesbs_convert_key(ctx->rk, rk.key_enc, ctx->rounds);
	kernel_neon_end();

	return 0;
}

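/*
 * Note: the bit-sliced core routines process up to eight AES blocks in
 * parallel (matching the 8 * AES_BLOCK_SIZE walksize declared below), so
 * intermediate walk steps are rounded down to a multiple of the stride and
 * any remainder is left for the final iteration.
 */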
static int __ecb_crypt(struct skcipher_request *req,
		       void (*fn)(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks))
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	int err;

	err = skcipher_walk_virt(&walk, req, true);

	kernel_neon_begin();
	while (walk.nbytes >= AES_BLOCK_SIZE) {
		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;

		if (walk.nbytes < walk.total)
			blocks = round_down(blocks,
					    walk.stride / AES_BLOCK_SIZE);

		fn(walk.dst.virt.addr, walk.src.virt.addr, ctx->rk,
		   ctx->rounds, blocks);
		err = skcipher_walk_done(&walk,
					 walk.nbytes - blocks * AES_BLOCK_SIZE);
	}
	kernel_neon_end();

	return err;
}

static int ecb_encrypt(struct skcipher_request *req)
{
	return __ecb_crypt(req, aesbs_ecb_encrypt);
}

static int ecb_decrypt(struct skcipher_request *req)
{
	return __ecb_crypt(req, aesbs_ecb_decrypt);
}

static int aesbs_cbc_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
			    unsigned int key_len)
{
	struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct crypto_aes_ctx rk;
	int err;

	err = crypto_aes_expand_key(&rk, in_key, key_len);
	if (err)
		return err;

	ctx->key.rounds = 6 + key_len / 4;

	memcpy(ctx->enc, rk.key_enc, sizeof(ctx->enc));

	kernel_neon_begin();
	aesbs_convert_key(ctx->key.rk, rk.key_enc, ctx->key.rounds);
	kernel_neon_end();

	return 0;
}

static int cbc_encrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	int err, first = 1;

	err = skcipher_walk_virt(&walk, req, true);

	kernel_neon_begin();
	while (walk.nbytes >= AES_BLOCK_SIZE) {
		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;

		/* fall back to the non-bitsliced NEON implementation */
		neon_aes_cbc_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
				     ctx->enc, ctx->key.rounds, blocks, walk.iv,
				     first);
		err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
		first = 0;
	}
	kernel_neon_end();
	return err;
}

static int cbc_decrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	int err;

	err = skcipher_walk_virt(&walk, req, true);

	kernel_neon_begin();
	while (walk.nbytes >= AES_BLOCK_SIZE) {
		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;

		if (walk.nbytes < walk.total)
			blocks = round_down(blocks,
					    walk.stride / AES_BLOCK_SIZE);

		aesbs_cbc_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
				  ctx->key.rk, ctx->key.rounds, blocks,
				  walk.iv);
		err = skcipher_walk_done(&walk,
					 walk.nbytes - blocks * AES_BLOCK_SIZE);
	}
	kernel_neon_end();

	return err;
}

static int aesbs_ctr_setkey_sync(struct crypto_skcipher *tfm, const u8 *in_key,
				 unsigned int key_len)
{
	struct aesbs_ctr_ctx *ctx = crypto_skcipher_ctx(tfm);
	int err;

	err = crypto_aes_expand_key(&ctx->fallback, in_key, key_len);
	if (err)
		return err;

	ctx->key.rounds = 6 + key_len / 4;

	kernel_neon_begin();
	aesbs_convert_key(ctx->key.rk, ctx->fallback.key_enc, ctx->key.rounds);
	kernel_neon_end();

	return 0;
}

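/*
 * CTR: full blocks are handled by the bit-sliced core. If the request ends
 * in a partial block, the core is passed the 'final' buffer on the last walk
 * step so it produces one extra block of keystream there, which is then
 * XORed into the remaining tail bytes with crypto_xor_cpy().
 */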
static int ctr_encrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	u8 buf[AES_BLOCK_SIZE];
	int err;

	err = skcipher_walk_virt(&walk, req, true);

	kernel_neon_begin();
	while (walk.nbytes > 0) {
		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;
		u8 *final = (walk.total % AES_BLOCK_SIZE) ? buf : NULL;

		if (walk.nbytes < walk.total) {
			blocks = round_down(blocks,
					    walk.stride / AES_BLOCK_SIZE);
			final = NULL;
		}

		aesbs_ctr_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
				  ctx->rk, ctx->rounds, blocks, walk.iv, final);

		if (final) {
			u8 *dst = walk.dst.virt.addr + blocks * AES_BLOCK_SIZE;
			u8 *src = walk.src.virt.addr + blocks * AES_BLOCK_SIZE;

			crypto_xor_cpy(dst, src, final,
				       walk.total % AES_BLOCK_SIZE);

			err = skcipher_walk_done(&walk, 0);
			break;
		}
		err = skcipher_walk_done(&walk,
					 walk.nbytes - blocks * AES_BLOCK_SIZE);
	}
	kernel_neon_end();

	return err;
}

static int aesbs_xts_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
			    unsigned int key_len)
{
	struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct crypto_aes_ctx rk;
	int err;

	err = xts_verify_key(tfm, in_key, key_len);
	if (err)
		return err;

	key_len /= 2;
	err = crypto_aes_expand_key(&rk, in_key + key_len, key_len);
	if (err)
		return err;

	memcpy(ctx->twkey, rk.key_enc, sizeof(ctx->twkey));

	return aesbs_setkey(tfm, in_key, key_len);
}

static int ctr_encrypt_sync(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_ctr_ctx *ctx = crypto_skcipher_ctx(tfm);

	if (!may_use_simd())
		return aes_ctr_encrypt_fallback(&ctx->fallback, req);

	return ctr_encrypt(req);
}

static int __xts_crypt(struct skcipher_request *req,
		       void (*fn)(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[]))
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	int err;

	err = skcipher_walk_virt(&walk, req, true);

	kernel_neon_begin();

	neon_aes_ecb_encrypt(walk.iv, walk.iv, ctx->twkey,
			     ctx->key.rounds, 1, 1);

	while (walk.nbytes >= AES_BLOCK_SIZE) {
		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;

		if (walk.nbytes < walk.total)
			blocks = round_down(blocks,
					    walk.stride / AES_BLOCK_SIZE);

		fn(walk.dst.virt.addr, walk.src.virt.addr, ctx->key.rk,
		   ctx->key.rounds, blocks, walk.iv);
		err = skcipher_walk_done(&walk,
					 walk.nbytes - blocks * AES_BLOCK_SIZE);
	}
	kernel_neon_end();

	return err;
}

static int xts_encrypt(struct skcipher_request *req)
{
	return __xts_crypt(req, aesbs_xts_encrypt);
}

static int xts_decrypt(struct skcipher_request *req)
{
	return __xts_crypt(req, aesbs_xts_decrypt);
}

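/*
 * The "__" prefixed algorithms are marked CRYPTO_ALG_INTERNAL: they may only
 * run when the NEON unit is usable and are exposed to users through the
 * crypto_simd wrappers created in aes_init(). The plain ctr(aes) entry is
 * registered directly and uses ctr_encrypt_sync(), which switches to a
 * non-NEON fallback whenever may_use_simd() reports that SIMD cannot be
 * used in the current context.
 */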
static struct skcipher_alg aes_algs[] = { {
	.base.cra_name		= "__ecb(aes)",
	.base.cra_driver_name	= "__ecb-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= AES_BLOCK_SIZE,
	.base.cra_ctxsize	= sizeof(struct aesbs_ctx),
	.base.cra_module	= THIS_MODULE,
	.base.cra_flags		= CRYPTO_ALG_INTERNAL,

	.min_keysize		= AES_MIN_KEY_SIZE,
	.max_keysize		= AES_MAX_KEY_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.setkey			= aesbs_setkey,
	.encrypt		= ecb_encrypt,
	.decrypt		= ecb_decrypt,
}, {
	.base.cra_name		= "__cbc(aes)",
	.base.cra_driver_name	= "__cbc-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= AES_BLOCK_SIZE,
	.base.cra_ctxsize	= sizeof(struct aesbs_cbc_ctx),
	.base.cra_module	= THIS_MODULE,
	.base.cra_flags		= CRYPTO_ALG_INTERNAL,

	.min_keysize		= AES_MIN_KEY_SIZE,
	.max_keysize		= AES_MAX_KEY_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.ivsize			= AES_BLOCK_SIZE,
	.setkey			= aesbs_cbc_setkey,
	.encrypt		= cbc_encrypt,
	.decrypt		= cbc_decrypt,
}, {
	.base.cra_name		= "__ctr(aes)",
	.base.cra_driver_name	= "__ctr-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= 1,
	.base.cra_ctxsize	= sizeof(struct aesbs_ctx),
	.base.cra_module	= THIS_MODULE,
	.base.cra_flags		= CRYPTO_ALG_INTERNAL,

	.min_keysize		= AES_MIN_KEY_SIZE,
	.max_keysize		= AES_MAX_KEY_SIZE,
	.chunksize		= AES_BLOCK_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.ivsize			= AES_BLOCK_SIZE,
	.setkey			= aesbs_setkey,
	.encrypt		= ctr_encrypt,
	.decrypt		= ctr_encrypt,
}, {
	.base.cra_name		= "ctr(aes)",
	.base.cra_driver_name	= "ctr-aes-neonbs",
	.base.cra_priority	= 250 - 1,
	.base.cra_blocksize	= 1,
	.base.cra_ctxsize	= sizeof(struct aesbs_ctr_ctx),
	.base.cra_module	= THIS_MODULE,

	.min_keysize		= AES_MIN_KEY_SIZE,
	.max_keysize		= AES_MAX_KEY_SIZE,
	.chunksize		= AES_BLOCK_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.ivsize			= AES_BLOCK_SIZE,
	.setkey			= aesbs_ctr_setkey_sync,
	.encrypt		= ctr_encrypt_sync,
	.decrypt		= ctr_encrypt_sync,
}, {
	.base.cra_name		= "__xts(aes)",
	.base.cra_driver_name	= "__xts-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= AES_BLOCK_SIZE,
	.base.cra_ctxsize	= sizeof(struct aesbs_xts_ctx),
	.base.cra_module	= THIS_MODULE,
	.base.cra_flags		= CRYPTO_ALG_INTERNAL,

	.min_keysize		= 2 * AES_MIN_KEY_SIZE,
	.max_keysize		= 2 * AES_MAX_KEY_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.ivsize			= AES_BLOCK_SIZE,
	.setkey			= aesbs_xts_setkey,
	.encrypt		= xts_encrypt,
	.decrypt		= xts_decrypt,
} };

static struct simd_skcipher_alg *aes_simd_algs[ARRAY_SIZE(aes_algs)];

static void aes_exit(void)
{
	int i;

	for (i = 0; i < ARRAY_SIZE(aes_simd_algs); i++)
		if (aes_simd_algs[i])
			simd_skcipher_free(aes_simd_algs[i]);

	crypto_unregister_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
}

static int __init aes_init(void)
{
	struct simd_skcipher_alg *simd;
	const char *basename;
	const char *algname;
	const char *drvname;
	int err;
	int i;

	if (!(elf_hwcap & HWCAP_ASIMD))
		return -ENODEV;

	err = crypto_register_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
	if (err)
		return err;

	for (i = 0; i < ARRAY_SIZE(aes_algs); i++) {
		if (!(aes_algs[i].base.cra_flags & CRYPTO_ALG_INTERNAL))
			continue;

		algname = aes_algs[i].base.cra_name + 2;
		drvname = aes_algs[i].base.cra_driver_name + 2;
		basename = aes_algs[i].base.cra_driver_name;
		simd = simd_skcipher_create_compat(algname, drvname, basename);
		err = PTR_ERR(simd);
		if (IS_ERR(simd))
			goto unregister_simds;

		aes_simd_algs[i] = simd;
	}
	return 0;

unregister_simds:
	aes_exit();
	return err;
}

module_init(aes_init);
module_exit(aes_exit);