// SPDX-License-Identifier: GPL-2.0-only
/*
 * Bit sliced AES using NEON instructions
 *
 * Copyright (C) 2016 - 2017 Linaro Ltd <ard.biesheuvel@linaro.org>
 */

#include <asm/neon.h>
#include <asm/simd.h>
#include <crypto/aes.h>
#include <crypto/ctr.h>
#include <crypto/internal/simd.h>
#include <crypto/internal/skcipher.h>
#include <crypto/scatterwalk.h>
#include <crypto/xts.h>
#include <linux/module.h>

MODULE_AUTHOR("Ard Biesheuvel <ard.biesheuvel@linaro.org>");
MODULE_DESCRIPTION("Bit sliced AES using NEON instructions");
MODULE_LICENSE("GPL v2");

MODULE_ALIAS_CRYPTO("ecb(aes)");
MODULE_ALIAS_CRYPTO("cbc(aes)");
MODULE_ALIAS_CRYPTO("ctr(aes)");
MODULE_ALIAS_CRYPTO("xts(aes)");

asmlinkage void aesbs_convert_key(u8 out[], u32 const rk[], int rounds);

asmlinkage void aesbs_ecb_encrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks);
asmlinkage void aesbs_ecb_decrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks);

asmlinkage void aesbs_cbc_decrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[]);

asmlinkage void aesbs_ctr_encrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[]);

asmlinkage void aesbs_xts_encrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[]);
asmlinkage void aesbs_xts_decrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[]);

/* borrowed from aes-neon-blk.ko */
asmlinkage void neon_aes_ecb_encrypt(u8 out[], u8 const in[], u32 const rk[],
				     int rounds, int blocks);
asmlinkage void neon_aes_cbc_encrypt(u8 out[], u8 const in[], u32 const rk[],
				     int rounds, int blocks, u8 iv[]);
asmlinkage void neon_aes_ctr_encrypt(u8 out[], u8 const in[], u32 const rk[],
				     int rounds, int bytes, u8 ctr[]);
asmlinkage void neon_aes_xts_encrypt(u8 out[], u8 const in[],
				     u32 const rk1[], int rounds, int bytes,
				     u32 const rk2[], u8 iv[], int first);
asmlinkage void neon_aes_xts_decrypt(u8 out[], u8 const in[],
				     u32 const rk1[], int rounds, int bytes,
				     u32 const rk2[], u8 iv[], int first);

struct aesbs_ctx {
	u8	rk[13 * (8 * AES_BLOCK_SIZE) + 32];
	int	rounds;
} __aligned(AES_BLOCK_SIZE);

struct aesbs_cbc_ctr_ctx {
	struct aesbs_ctx	key;
	u32			enc[AES_MAX_KEYLENGTH_U32];
};

struct aesbs_xts_ctx {
	struct aesbs_ctx	key;
	u32			twkey[AES_MAX_KEYLENGTH_U32];
	struct crypto_aes_ctx	cts;
};

static int aesbs_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
			unsigned int key_len)
{
	struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct crypto_aes_ctx rk;
	int err;

	err = aes_expandkey(&rk, in_key, key_len);
	if (err)
		return err;

	ctx->rounds = 6 + key_len / 4;

	scoped_ksimd()
		aesbs_convert_key(ctx->rk, rk.key_enc, ctx->rounds);

	return 0;
}

static int __ecb_crypt(struct skcipher_request *req,
		       void (*fn)(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks))
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while (walk.nbytes >= AES_BLOCK_SIZE) {
		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;

		if (walk.nbytes < walk.total)
			blocks = round_down(blocks,
					    walk.stride / AES_BLOCK_SIZE);

		scoped_ksimd()
			fn(walk.dst.virt.addr, walk.src.virt.addr, ctx->rk,
			   ctx->rounds, blocks);
		err = skcipher_walk_done(&walk,
					 walk.nbytes - blocks * AES_BLOCK_SIZE);
	}

	return err;
}

static int ecb_encrypt(struct skcipher_request *req)
{
	return __ecb_crypt(req, aesbs_ecb_encrypt);
}

static int ecb_decrypt(struct skcipher_request *req)
{
	return __ecb_crypt(req, aesbs_ecb_decrypt);
}

static int aesbs_cbc_ctr_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
				unsigned int key_len)
{
	struct aesbs_cbc_ctr_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct crypto_aes_ctx rk;
	int err;

	err = aes_expandkey(&rk, in_key, key_len);
	if (err)
		return err;

	ctx->key.rounds = 6 + key_len / 4;

	memcpy(ctx->enc, rk.key_enc, sizeof(ctx->enc));

	scoped_ksimd()
		aesbs_convert_key(ctx->key.rk, rk.key_enc, ctx->key.rounds);
	memzero_explicit(&rk, sizeof(rk));

	return 0;
}

static int cbc_encrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_cbc_ctr_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while (walk.nbytes >= AES_BLOCK_SIZE) {
		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;

		/* fall back to the non-bitsliced NEON implementation */
		scoped_ksimd()
			neon_aes_cbc_encrypt(walk.dst.virt.addr,
					     walk.src.virt.addr,
					     ctx->enc, ctx->key.rounds, blocks,
					     walk.iv);
		err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
	}
	return err;
}

static int cbc_decrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_cbc_ctr_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while (walk.nbytes >= AES_BLOCK_SIZE) {
		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;

		if (walk.nbytes < walk.total)
			blocks = round_down(blocks,
					    walk.stride / AES_BLOCK_SIZE);

		scoped_ksimd()
			aesbs_cbc_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
					  ctx->key.rk, ctx->key.rounds, blocks,
					  walk.iv);
		err = skcipher_walk_done(&walk,
					 walk.nbytes - blocks * AES_BLOCK_SIZE);
	}

	return err;
}

static int ctr_encrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_cbc_ctr_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while (walk.nbytes > 0) {
		int blocks = (walk.nbytes / AES_BLOCK_SIZE) & ~7;
		int nbytes = walk.nbytes % (8 * AES_BLOCK_SIZE);
		const u8 *src = walk.src.virt.addr;
		u8 *dst = walk.dst.virt.addr;

		scoped_ksimd() {
			if (blocks >= 8) {
				aesbs_ctr_encrypt(dst, src, ctx->key.rk,
						  ctx->key.rounds, blocks,
						  walk.iv);
				dst += blocks * AES_BLOCK_SIZE;
				src += blocks * AES_BLOCK_SIZE;
			}
			if (nbytes && walk.nbytes == walk.total) {
				u8 buf[AES_BLOCK_SIZE];
				u8 *d = dst;

				if (unlikely(nbytes < AES_BLOCK_SIZE))
					src = dst = memcpy(buf + sizeof(buf) -
							   nbytes, src, nbytes);

				neon_aes_ctr_encrypt(dst, src, ctx->enc,
						     ctx->key.rounds, nbytes,
						     walk.iv);

				if (unlikely(nbytes < AES_BLOCK_SIZE))
					memcpy(d, dst, nbytes);

				nbytes = 0;
			}
		}
		err = skcipher_walk_done(&walk, nbytes);
	}
	return err;
}

static int aesbs_xts_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
			    unsigned int key_len)
{
	struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct crypto_aes_ctx rk;
	int err;

	err = xts_verify_key(tfm, in_key, key_len);
	if (err)
		return err;

	key_len /= 2;
	err = aes_expandkey(&ctx->cts, in_key, key_len);
	if (err)
		return err;

	err = aes_expandkey(&rk, in_key + key_len, key_len);
	if (err)
		return err;

	memcpy(ctx->twkey, rk.key_enc, sizeof(ctx->twkey));

	return aesbs_setkey(tfm, in_key, key_len);
}

static int __xts_crypt(struct skcipher_request *req, bool encrypt,
		       void (*fn)(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[]))
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
	int tail = req->cryptlen % (8 * AES_BLOCK_SIZE);
	struct scatterlist sg_src[2], sg_dst[2];
	struct skcipher_request subreq;
	struct scatterlist *src, *dst;
	struct skcipher_walk walk;
	int nbytes, err;
	int first = 1;
	const u8 *in;
	u8 *out;

	if (req->cryptlen < AES_BLOCK_SIZE)
		return -EINVAL;

	/* ensure that the cts tail is covered by a single step */
	if (unlikely(tail > 0 && tail < AES_BLOCK_SIZE)) {
		int xts_blocks = DIV_ROUND_UP(req->cryptlen,
					      AES_BLOCK_SIZE) - 2;

		skcipher_request_set_tfm(&subreq, tfm);
		skcipher_request_set_callback(&subreq,
					      skcipher_request_flags(req),
					      NULL, NULL);
		skcipher_request_set_crypt(&subreq, req->src, req->dst,
					   xts_blocks * AES_BLOCK_SIZE,
					   req->iv);
		req = &subreq;
	} else {
		tail = 0;
	}

	err = skcipher_walk_virt(&walk, req, false);
	if (err)
		return err;

	while (walk.nbytes >= AES_BLOCK_SIZE) {
		int blocks = (walk.nbytes / AES_BLOCK_SIZE) & ~7;
		out = walk.dst.virt.addr;
		in = walk.src.virt.addr;
		nbytes = walk.nbytes;

		scoped_ksimd() {
			if (blocks >= 8) {
				if (first == 1)
					neon_aes_ecb_encrypt(walk.iv, walk.iv,
							     ctx->twkey,
							     ctx->key.rounds, 1);
				first = 2;

				fn(out, in, ctx->key.rk, ctx->key.rounds, blocks,
				   walk.iv);

				out += blocks * AES_BLOCK_SIZE;
				in += blocks * AES_BLOCK_SIZE;
				nbytes -= blocks * AES_BLOCK_SIZE;
			}
			if (walk.nbytes == walk.total && nbytes > 0) {
				if (encrypt)
					neon_aes_xts_encrypt(out, in, ctx->cts.key_enc,
							     ctx->key.rounds, nbytes,
							     ctx->twkey, walk.iv, first);
				else
					neon_aes_xts_decrypt(out, in, ctx->cts.key_dec,
							     ctx->key.rounds, nbytes,
							     ctx->twkey, walk.iv, first);
				nbytes = first = 0;
			}
		}
		err = skcipher_walk_done(&walk, nbytes);
	}

	if (err || likely(!tail))
		return err;

	/* handle ciphertext stealing */
	dst = src = scatterwalk_ffwd(sg_src, req->src, req->cryptlen);
	if (req->dst != req->src)
		dst = scatterwalk_ffwd(sg_dst, req->dst, req->cryptlen);

	skcipher_request_set_crypt(req, src, dst, AES_BLOCK_SIZE + tail,
				   req->iv);

	err = skcipher_walk_virt(&walk, req, false);
	if (err)
		return err;

	out = walk.dst.virt.addr;
	in = walk.src.virt.addr;
	nbytes = walk.nbytes;

	scoped_ksimd() {
		if (encrypt)
			neon_aes_xts_encrypt(out, in, ctx->cts.key_enc,
					     ctx->key.rounds, nbytes, ctx->twkey,
					     walk.iv, first);
		else
			neon_aes_xts_decrypt(out, in, ctx->cts.key_dec,
					     ctx->key.rounds, nbytes, ctx->twkey,
					     walk.iv, first);
	}

	return skcipher_walk_done(&walk, 0);
}

static int xts_encrypt(struct skcipher_request *req)
{
	return __xts_crypt(req, true, aesbs_xts_encrypt);
}

static int xts_decrypt(struct skcipher_request *req)
{
	return __xts_crypt(req, false, aesbs_xts_decrypt);
}

static struct skcipher_alg aes_algs[] = { {
	.base.cra_name		= "ecb(aes)",
	.base.cra_driver_name	= "ecb-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= AES_BLOCK_SIZE,
	.base.cra_ctxsize	= sizeof(struct aesbs_ctx),
	.base.cra_module	= THIS_MODULE,

	.min_keysize		= AES_MIN_KEY_SIZE,
	.max_keysize		= AES_MAX_KEY_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.setkey			= aesbs_setkey,
	.encrypt		= ecb_encrypt,
	.decrypt		= ecb_decrypt,
}, {
	.base.cra_name		= "cbc(aes)",
	.base.cra_driver_name	= "cbc-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= AES_BLOCK_SIZE,
	.base.cra_ctxsize	= sizeof(struct aesbs_cbc_ctr_ctx),
	.base.cra_module	= THIS_MODULE,

	.min_keysize		= AES_MIN_KEY_SIZE,
	.max_keysize		= AES_MAX_KEY_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.ivsize			= AES_BLOCK_SIZE,
	.setkey			= aesbs_cbc_ctr_setkey,
	.encrypt		= cbc_encrypt,
	.decrypt		= cbc_decrypt,
}, {
	.base.cra_name		= "ctr(aes)",
	.base.cra_driver_name	= "ctr-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= 1,
	.base.cra_ctxsize	= sizeof(struct aesbs_cbc_ctr_ctx),
	.base.cra_module	= THIS_MODULE,

	.min_keysize		= AES_MIN_KEY_SIZE,
	.max_keysize		= AES_MAX_KEY_SIZE,
	.chunksize		= AES_BLOCK_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.ivsize			= AES_BLOCK_SIZE,
	.setkey			= aesbs_cbc_ctr_setkey,
	.encrypt		= ctr_encrypt,
	.decrypt		= ctr_encrypt,
}, {
	.base.cra_name		= "xts(aes)",
	.base.cra_driver_name	= "xts-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= AES_BLOCK_SIZE,
	.base.cra_ctxsize	= sizeof(struct aesbs_xts_ctx),
	.base.cra_module	= THIS_MODULE,

	.min_keysize		= 2 * AES_MIN_KEY_SIZE,
	.max_keysize		= 2 * AES_MAX_KEY_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.ivsize			= AES_BLOCK_SIZE,
	.setkey			= aesbs_xts_setkey,
	.encrypt		= xts_encrypt,
	.decrypt		= xts_decrypt,
} };

static void aes_exit(void)
{
	crypto_unregister_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
}

static int __init aes_init(void)
{
	if (!cpu_have_named_feature(ASIMD))
		return -ENODEV;

	return crypto_register_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
}

module_init(aes_init);
module_exit(aes_exit);
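
/*
 * Usage sketch (illustrative only, not part of this driver): one way a
 * kernel caller might exercise "ctr(aes)" through the generic skcipher
 * API; depending on priority, the crypto core may back it with the
 * ctr-aes-neonbs implementation registered above. The identifiers key,
 * iv and buf are placeholders, most error handling is omitted, and the
 * data buffer is kmalloc'ed because scatterlists must reference linearly
 * mapped memory rather than a (possibly vmapped) stack.
 *
 *	struct crypto_skcipher *tfm;
 *	struct skcipher_request *req;
 *	struct scatterlist sg;
 *	DECLARE_CRYPTO_WAIT(wait);
 *	u8 key[AES_KEYSIZE_128] = { };
 *	u8 iv[AES_BLOCK_SIZE] = { };
 *	u8 *buf;
 *	int err;
 *
 *	tfm = crypto_alloc_skcipher("ctr(aes)", 0, 0);
 *	if (IS_ERR(tfm))
 *		return PTR_ERR(tfm);
 *
 *	err = crypto_skcipher_setkey(tfm, key, sizeof(key));
 *
 *	buf = kzalloc(PAGE_SIZE, GFP_KERNEL);
 *	sg_init_one(&sg, buf, PAGE_SIZE);
 *
 *	req = skcipher_request_alloc(tfm, GFP_KERNEL);
 *	skcipher_request_set_callback(req, CRYPTO_TFM_REQ_MAY_BACKLOG |
 *				      CRYPTO_TFM_REQ_MAY_SLEEP,
 *				      crypto_req_done, &wait);
 *	skcipher_request_set_crypt(req, &sg, &sg, PAGE_SIZE, iv);
 *
 *	err = crypto_wait_req(crypto_skcipher_encrypt(req), &wait);
 *
 *	skcipher_request_free(req);
 *	kfree(buf);
 *	crypto_free_skcipher(tfm);
 */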