/*
 * Bit sliced AES using NEON instructions
 *
 * Copyright (C) 2016 - 2017 Linaro Ltd <ard.biesheuvel@linaro.org>
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License version 2 as
 * published by the Free Software Foundation.
 */

#include <asm/neon.h>
#include <asm/simd.h>
#include <crypto/aes.h>
#include <crypto/internal/simd.h>
#include <crypto/internal/skcipher.h>
#include <crypto/xts.h>
#include <linux/module.h>

#include "aes-ctr-fallback.h"

MODULE_AUTHOR("Ard Biesheuvel <ard.biesheuvel@linaro.org>");
MODULE_LICENSE("GPL v2");

MODULE_ALIAS_CRYPTO("ecb(aes)");
MODULE_ALIAS_CRYPTO("cbc(aes)");
MODULE_ALIAS_CRYPTO("ctr(aes)");
MODULE_ALIAS_CRYPTO("xts(aes)");

asmlinkage void aesbs_convert_key(u8 out[], u32 const rk[], int rounds);

asmlinkage void aesbs_ecb_encrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks);
asmlinkage void aesbs_ecb_decrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks);

asmlinkage void aesbs_cbc_decrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[]);

asmlinkage void aesbs_ctr_encrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[], u8 final[]);

asmlinkage void aesbs_xts_encrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[]);
asmlinkage void aesbs_xts_decrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[]);

/* borrowed from aes-neon-blk.ko */
asmlinkage void neon_aes_ecb_encrypt(u8 out[], u8 const in[], u32 const rk[],
				     int rounds, int blocks);
asmlinkage void neon_aes_cbc_encrypt(u8 out[], u8 const in[], u32 const rk[],
				     int rounds, int blocks, u8 iv[]);

struct aesbs_ctx {
	u8	rk[13 * (8 * AES_BLOCK_SIZE) + 32];
	int	rounds;
} __aligned(AES_BLOCK_SIZE);

struct aesbs_cbc_ctx {
	struct aesbs_ctx	key;
	u32			enc[AES_MAX_KEYLENGTH_U32];
};

struct aesbs_ctr_ctx {
	struct aesbs_ctx	key;		/* must be first member */
	struct crypto_aes_ctx	fallback;
};

struct aesbs_xts_ctx {
	struct aesbs_ctx	key;
	u32			twkey[AES_MAX_KEYLENGTH_U32];
};

static int aesbs_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
			unsigned int key_len)
{
	struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct crypto_aes_ctx rk;
	int err;

	err = crypto_aes_expand_key(&rk, in_key, key_len);
	if (err)
		return err;

	ctx->rounds = 6 + key_len / 4;

	kernel_neon_begin();
	aesbs_convert_key(ctx->rk, rk.key_enc, ctx->rounds);
	kernel_neon_end();

	return 0;
}

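/*
 * Added note: the bit sliced NEON routines process up to eight AES blocks in
 * parallel (hence the 8 * AES_BLOCK_SIZE walksize advertised below), so while
 * more input is still pending, round the block count down to a multiple of
 * the walk stride and leave the remainder for a later iteration.
 */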
static int __ecb_crypt(struct skcipher_request *req,
		       void (*fn)(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks))
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while (walk.nbytes >= AES_BLOCK_SIZE) {
		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;

		if (walk.nbytes < walk.total)
			blocks = round_down(blocks,
					    walk.stride / AES_BLOCK_SIZE);

		kernel_neon_begin();
		fn(walk.dst.virt.addr, walk.src.virt.addr, ctx->rk,
		   ctx->rounds, blocks);
		kernel_neon_end();
		err = skcipher_walk_done(&walk,
					 walk.nbytes - blocks * AES_BLOCK_SIZE);
	}

	return err;
}

static int ecb_encrypt(struct skcipher_request *req)
{
	return __ecb_crypt(req, aesbs_ecb_encrypt);
}

static int ecb_decrypt(struct skcipher_request *req)
{
	return __ecb_crypt(req, aesbs_ecb_decrypt);
}

static int aesbs_cbc_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
			    unsigned int key_len)
{
	struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct crypto_aes_ctx rk;
	int err;

	err = crypto_aes_expand_key(&rk, in_key, key_len);
	if (err)
		return err;

	ctx->key.rounds = 6 + key_len / 4;

	memcpy(ctx->enc, rk.key_enc, sizeof(ctx->enc));

	kernel_neon_begin();
	aesbs_convert_key(ctx->key.rk, rk.key_enc, ctx->key.rounds);
	kernel_neon_end();

	return 0;
}

static int cbc_encrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while (walk.nbytes >= AES_BLOCK_SIZE) {
		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;

		/* fall back to the non-bitsliced NEON implementation */
		kernel_neon_begin();
		neon_aes_cbc_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
				     ctx->enc, ctx->key.rounds, blocks,
				     walk.iv);
		kernel_neon_end();
		err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
	}
	return err;
}

static int cbc_decrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while (walk.nbytes >= AES_BLOCK_SIZE) {
		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;

		if (walk.nbytes < walk.total)
			blocks = round_down(blocks,
					    walk.stride / AES_BLOCK_SIZE);

		kernel_neon_begin();
		aesbs_cbc_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
				  ctx->key.rk, ctx->key.rounds, blocks,
				  walk.iv);
		kernel_neon_end();
		err = skcipher_walk_done(&walk,
					 walk.nbytes - blocks * AES_BLOCK_SIZE);
	}

	return err;
}

static int aesbs_ctr_setkey_sync(struct crypto_skcipher *tfm, const u8 *in_key,
				 unsigned int key_len)
{
	struct aesbs_ctr_ctx *ctx = crypto_skcipher_ctx(tfm);
	int err;

	err = crypto_aes_expand_key(&ctx->fallback, in_key, key_len);
	if (err)
		return err;

	ctx->key.rounds = 6 + key_len / 4;

	kernel_neon_begin();
	aesbs_convert_key(ctx->key.rk, ctx->fallback.key_enc, ctx->key.rounds);
	kernel_neon_end();

	return 0;
}

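/*
 * Added note: when the request does not end on a block boundary, 'final'
 * points at a stack buffer that the aesbs_ctr_encrypt() asm routine is
 * expected to fill with the keystream for the trailing partial block; the
 * tail is then produced with crypto_xor_cpy() below. 'final' is cleared
 * whenever more data is still to come, so only the last chunk takes that
 * path.
 */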
static int ctr_encrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	u8 buf[AES_BLOCK_SIZE];
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while (walk.nbytes > 0) {
		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;
		u8 *final = (walk.total % AES_BLOCK_SIZE) ? buf : NULL;

		if (walk.nbytes < walk.total) {
			blocks = round_down(blocks,
					    walk.stride / AES_BLOCK_SIZE);
			final = NULL;
		}

		kernel_neon_begin();
		aesbs_ctr_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
				  ctx->rk, ctx->rounds, blocks, walk.iv, final);
		kernel_neon_end();

		if (final) {
			u8 *dst = walk.dst.virt.addr + blocks * AES_BLOCK_SIZE;
			u8 *src = walk.src.virt.addr + blocks * AES_BLOCK_SIZE;

			crypto_xor_cpy(dst, src, final,
				       walk.total % AES_BLOCK_SIZE);

			err = skcipher_walk_done(&walk, 0);
			break;
		}
		err = skcipher_walk_done(&walk,
					 walk.nbytes - blocks * AES_BLOCK_SIZE);
	}
	return err;
}

static int aesbs_xts_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
			    unsigned int key_len)
{
	struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct crypto_aes_ctx rk;
	int err;

	err = xts_verify_key(tfm, in_key, key_len);
	if (err)
		return err;

	key_len /= 2;
	err = crypto_aes_expand_key(&rk, in_key + key_len, key_len);
	if (err)
		return err;

	memcpy(ctx->twkey, rk.key_enc, sizeof(ctx->twkey));

	return aesbs_setkey(tfm, in_key, key_len);
}

static int ctr_encrypt_sync(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_ctr_ctx *ctx = crypto_skcipher_ctx(tfm);

	if (!may_use_simd())
		return aes_ctr_encrypt_fallback(&ctx->fallback, req);

	return ctr_encrypt(req);
}

static int __xts_crypt(struct skcipher_request *req,
		       void (*fn)(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[]))
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	kernel_neon_begin();
	neon_aes_ecb_encrypt(walk.iv, walk.iv, ctx->twkey, ctx->key.rounds, 1);
	kernel_neon_end();

	while (walk.nbytes >= AES_BLOCK_SIZE) {
		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;

		if (walk.nbytes < walk.total)
			blocks = round_down(blocks,
					    walk.stride / AES_BLOCK_SIZE);

		kernel_neon_begin();
		fn(walk.dst.virt.addr, walk.src.virt.addr, ctx->key.rk,
		   ctx->key.rounds, blocks, walk.iv);
		kernel_neon_end();
		err = skcipher_walk_done(&walk,
					 walk.nbytes - blocks * AES_BLOCK_SIZE);
	}
	return err;
}

static int xts_encrypt(struct skcipher_request *req)
{
	return __xts_crypt(req, aesbs_xts_encrypt);
}

static int xts_decrypt(struct skcipher_request *req)
{
	return __xts_crypt(req, aesbs_xts_decrypt);
}

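/*
 * Added note: the "__" prefixed entries are marked CRYPTO_ALG_INTERNAL and
 * are not meant to be requested by name; they are exposed to the rest of the
 * crypto API through the simd wrappers created in aes_init() below, which
 * strip the "__" prefix from cra_name/cra_driver_name. The plain "ctr(aes)"
 * entry is registered directly and uses ctr_encrypt_sync(), which calls
 * aes_ctr_encrypt_fallback() when SIMD is not usable.
 */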
static struct skcipher_alg aes_algs[] = { {
	.base.cra_name		= "__ecb(aes)",
	.base.cra_driver_name	= "__ecb-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= AES_BLOCK_SIZE,
	.base.cra_ctxsize	= sizeof(struct aesbs_ctx),
	.base.cra_module	= THIS_MODULE,
	.base.cra_flags		= CRYPTO_ALG_INTERNAL,

	.min_keysize		= AES_MIN_KEY_SIZE,
	.max_keysize		= AES_MAX_KEY_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.setkey			= aesbs_setkey,
	.encrypt		= ecb_encrypt,
	.decrypt		= ecb_decrypt,
}, {
	.base.cra_name		= "__cbc(aes)",
	.base.cra_driver_name	= "__cbc-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= AES_BLOCK_SIZE,
	.base.cra_ctxsize	= sizeof(struct aesbs_cbc_ctx),
	.base.cra_module	= THIS_MODULE,
	.base.cra_flags		= CRYPTO_ALG_INTERNAL,

	.min_keysize		= AES_MIN_KEY_SIZE,
	.max_keysize		= AES_MAX_KEY_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.ivsize			= AES_BLOCK_SIZE,
	.setkey			= aesbs_cbc_setkey,
	.encrypt		= cbc_encrypt,
	.decrypt		= cbc_decrypt,
}, {
	.base.cra_name		= "__ctr(aes)",
	.base.cra_driver_name	= "__ctr-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= 1,
	.base.cra_ctxsize	= sizeof(struct aesbs_ctx),
	.base.cra_module	= THIS_MODULE,
	.base.cra_flags		= CRYPTO_ALG_INTERNAL,

	.min_keysize		= AES_MIN_KEY_SIZE,
	.max_keysize		= AES_MAX_KEY_SIZE,
	.chunksize		= AES_BLOCK_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.ivsize			= AES_BLOCK_SIZE,
	.setkey			= aesbs_setkey,
	.encrypt		= ctr_encrypt,
	.decrypt		= ctr_encrypt,
}, {
	.base.cra_name		= "ctr(aes)",
	.base.cra_driver_name	= "ctr-aes-neonbs",
	.base.cra_priority	= 250 - 1,
	.base.cra_blocksize	= 1,
	.base.cra_ctxsize	= sizeof(struct aesbs_ctr_ctx),
	.base.cra_module	= THIS_MODULE,

	.min_keysize		= AES_MIN_KEY_SIZE,
	.max_keysize		= AES_MAX_KEY_SIZE,
	.chunksize		= AES_BLOCK_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.ivsize			= AES_BLOCK_SIZE,
	.setkey			= aesbs_ctr_setkey_sync,
	.encrypt		= ctr_encrypt_sync,
	.decrypt		= ctr_encrypt_sync,
}, {
	.base.cra_name		= "__xts(aes)",
	.base.cra_driver_name	= "__xts-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= AES_BLOCK_SIZE,
	.base.cra_ctxsize	= sizeof(struct aesbs_xts_ctx),
	.base.cra_module	= THIS_MODULE,
	.base.cra_flags		= CRYPTO_ALG_INTERNAL,

	.min_keysize		= 2 * AES_MIN_KEY_SIZE,
	.max_keysize		= 2 * AES_MAX_KEY_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.ivsize			= AES_BLOCK_SIZE,
	.setkey			= aesbs_xts_setkey,
	.encrypt		= xts_encrypt,
	.decrypt		= xts_decrypt,
} };

static struct simd_skcipher_alg *aes_simd_algs[ARRAY_SIZE(aes_algs)];

static void aes_exit(void)
{
	int i;

	for (i = 0; i < ARRAY_SIZE(aes_simd_algs); i++)
		if (aes_simd_algs[i])
			simd_skcipher_free(aes_simd_algs[i]);

	crypto_unregister_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
}

static int __init aes_init(void)
{
	struct simd_skcipher_alg *simd;
	const char *basename;
	const char *algname;
	const char *drvname;
	int err;
	int i;

	if (!(elf_hwcap & HWCAP_ASIMD))
		return -ENODEV;

	err = crypto_register_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
	if (err)
		return err;

	for (i = 0; i < ARRAY_SIZE(aes_algs); i++) {
		if (!(aes_algs[i].base.cra_flags & CRYPTO_ALG_INTERNAL))
			continue;

		algname = aes_algs[i].base.cra_name + 2;
		drvname = aes_algs[i].base.cra_driver_name + 2;
		basename = aes_algs[i].base.cra_driver_name;
		simd = simd_skcipher_create_compat(algname, drvname, basename);
		err = PTR_ERR(simd);
		if (IS_ERR(simd))
			goto unregister_simds;

		aes_simd_algs[i] = simd;
	}
	return 0;

unregister_simds:
	aes_exit();
	return err;
}

module_init(aes_init);
module_exit(aes_exit);
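
/*
 * Illustrative usage note (not part of the original driver): kernel users do
 * not call into this module directly but go through the generic skcipher API,
 * e.g.
 *
 *	struct crypto_skcipher *tfm;
 *
 *	tfm = crypto_alloc_skcipher("xts(aes)", 0, 0);
 *
 * The crypto core then selects the highest priority implementation registered
 * under that cra_name, which may or may not be one of the algorithms above.
 */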