// SPDX-License-Identifier: GPL-2.0-only
/*
 * Bit sliced AES using NEON instructions
 *
 * Copyright (C) 2016 - 2017 Linaro Ltd <ard.biesheuvel@linaro.org>
 */

#include <asm/neon.h>
#include <asm/simd.h>
#include <crypto/aes.h>
#include <crypto/ctr.h>
#include <crypto/internal/simd.h>
#include <crypto/internal/skcipher.h>
#include <crypto/scatterwalk.h>
#include <crypto/xts.h>
#include <linux/module.h>

MODULE_AUTHOR("Ard Biesheuvel <ard.biesheuvel@linaro.org>");
MODULE_DESCRIPTION("Bit sliced AES using NEON instructions");
MODULE_IMPORT_NS("CRYPTO_INTERNAL");
MODULE_LICENSE("GPL v2");

MODULE_ALIAS_CRYPTO("ecb(aes)");
MODULE_ALIAS_CRYPTO("cbc(aes)");
MODULE_ALIAS_CRYPTO("ctr(aes)");
MODULE_ALIAS_CRYPTO("xts(aes)");

asmlinkage void aesbs_convert_key(u8 out[], u32 const rk[], int rounds);

asmlinkage void aesbs_ecb_encrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks);
asmlinkage void aesbs_ecb_decrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks);

asmlinkage void aesbs_cbc_decrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[]);

asmlinkage void aesbs_ctr_encrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[]);

asmlinkage void aesbs_xts_encrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[]);
asmlinkage void aesbs_xts_decrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[]);

/* borrowed from aes-neon-blk.ko */
asmlinkage void neon_aes_ecb_encrypt(u8 out[], u8 const in[], u32 const rk[],
				     int rounds, int blocks);
asmlinkage void neon_aes_cbc_encrypt(u8 out[], u8 const in[], u32 const rk[],
				     int rounds, int blocks, u8 iv[]);
asmlinkage void neon_aes_ctr_encrypt(u8 out[], u8 const in[], u32 const rk[],
				     int rounds, int bytes, u8 ctr[]);
asmlinkage void neon_aes_xts_encrypt(u8 out[], u8 const in[],
				     u32 const rk1[], int rounds, int bytes,
				     u32 const rk2[], u8 iv[], int first);
asmlinkage void neon_aes_xts_decrypt(u8 out[], u8 const in[],
				     u32 const rk1[], int rounds, int bytes,
				     u32 const rk2[], u8 iv[], int first);

struct aesbs_ctx {
	u8	rk[13 * (8 * AES_BLOCK_SIZE) + 32];
	int	rounds;
} __aligned(AES_BLOCK_SIZE);

struct aesbs_cbc_ctr_ctx {
	struct aesbs_ctx	key;
	u32			enc[AES_MAX_KEYLENGTH_U32];
};

struct aesbs_xts_ctx {
	struct aesbs_ctx	key;
	u32			twkey[AES_MAX_KEYLENGTH_U32];
	struct crypto_aes_ctx	cts;
};

static int aesbs_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
			unsigned int key_len)
{
	struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct crypto_aes_ctx *rk;
	int err;

	rk = kmalloc(sizeof(*rk), GFP_KERNEL);
	if (!rk)
		return -ENOMEM;

	err = aes_expandkey(rk, in_key, key_len);
	if (err)
		goto out;

	ctx->rounds = 6 + key_len / 4;

	scoped_ksimd()
		aesbs_convert_key(ctx->rk, rk->key_enc, ctx->rounds);
out:
	kfree_sensitive(rk);
	return err;
}

static int __ecb_crypt(struct skcipher_request *req,
		       void (*fn)(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks))
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while (walk.nbytes >= AES_BLOCK_SIZE) {
		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;

		if (walk.nbytes < walk.total)
			blocks = round_down(blocks,
					    walk.stride / AES_BLOCK_SIZE);

		scoped_ksimd()
			fn(walk.dst.virt.addr, walk.src.virt.addr, ctx->rk,
			   ctx->rounds, blocks);
		err = skcipher_walk_done(&walk,
					 walk.nbytes - blocks * AES_BLOCK_SIZE);
	}

	return err;
}

static int ecb_encrypt(struct skcipher_request *req)
{
	return __ecb_crypt(req, aesbs_ecb_encrypt);
}

static int ecb_decrypt(struct skcipher_request *req)
{
	return __ecb_crypt(req, aesbs_ecb_decrypt);
}
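/*
 * CBC encryption and sub-8-block CTR tails cannot use the 8-way bitsliced
 * code, so the setkey routine below keeps two copies of the key schedule:
 * the bitsliced form in ctx->key.rk for the aesbs_* routines, and the
 * regular AES encryption schedule in ctx->enc for the plain NEON fallbacks.
 */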
static int aesbs_cbc_ctr_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
				unsigned int key_len)
{
	struct aesbs_cbc_ctr_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct crypto_aes_ctx *rk;
	int err;

	rk = kmalloc(sizeof(*rk), GFP_KERNEL);
	if (!rk)
		return -ENOMEM;

	err = aes_expandkey(rk, in_key, key_len);
	if (err)
		goto out;

	ctx->key.rounds = 6 + key_len / 4;

	memcpy(ctx->enc, rk->key_enc, sizeof(ctx->enc));

	scoped_ksimd()
		aesbs_convert_key(ctx->key.rk, rk->key_enc, ctx->key.rounds);
out:
	kfree_sensitive(rk);
	return err;
}

static int cbc_encrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_cbc_ctr_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while (walk.nbytes >= AES_BLOCK_SIZE) {
		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;

		/* fall back to the non-bitsliced NEON implementation */
		scoped_ksimd()
			neon_aes_cbc_encrypt(walk.dst.virt.addr,
					     walk.src.virt.addr,
					     ctx->enc, ctx->key.rounds, blocks,
					     walk.iv);
		err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
	}
	return err;
}

static int cbc_decrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_cbc_ctr_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while (walk.nbytes >= AES_BLOCK_SIZE) {
		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;

		if (walk.nbytes < walk.total)
			blocks = round_down(blocks,
					    walk.stride / AES_BLOCK_SIZE);

		scoped_ksimd()
			aesbs_cbc_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
					  ctx->key.rk, ctx->key.rounds, blocks,
					  walk.iv);
		err = skcipher_walk_done(&walk,
					 walk.nbytes - blocks * AES_BLOCK_SIZE);
	}

	return err;
}

static int ctr_encrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_cbc_ctr_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while (walk.nbytes > 0) {
		int blocks = (walk.nbytes / AES_BLOCK_SIZE) & ~7;
		int nbytes = walk.nbytes % (8 * AES_BLOCK_SIZE);
		const u8 *src = walk.src.virt.addr;
		u8 *dst = walk.dst.virt.addr;

		scoped_ksimd() {
			if (blocks >= 8) {
				aesbs_ctr_encrypt(dst, src, ctx->key.rk,
						  ctx->key.rounds, blocks,
						  walk.iv);
				dst += blocks * AES_BLOCK_SIZE;
				src += blocks * AES_BLOCK_SIZE;
			}
			if (nbytes && walk.nbytes == walk.total) {
				u8 buf[AES_BLOCK_SIZE];
				u8 *d = dst;

				if (unlikely(nbytes < AES_BLOCK_SIZE))
					src = dst = memcpy(buf + sizeof(buf) -
							   nbytes, src, nbytes);

				neon_aes_ctr_encrypt(dst, src, ctx->enc,
						     ctx->key.rounds, nbytes,
						     walk.iv);

				if (unlikely(nbytes < AES_BLOCK_SIZE))
					memcpy(d, dst, nbytes);

				nbytes = 0;
			}
		}
		err = skcipher_walk_done(&walk, nbytes);
	}
	return err;
}
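/*
 * XTS keys are twice the nominal AES key size: the first half is the data
 * key and the second half the tweak key. The data key is expanded both into
 * a regular schedule in ctx->cts (used for the ciphertext stealing tail)
 * and into bitsliced form via aesbs_setkey(); only the encryption schedule
 * of the tweak key is retained.
 */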
static int aesbs_xts_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
			    unsigned int key_len)
{
	struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct crypto_aes_ctx rk;
	int err;

	err = xts_verify_key(tfm, in_key, key_len);
	if (err)
		return err;

	key_len /= 2;
	err = aes_expandkey(&ctx->cts, in_key, key_len);
	if (err)
		return err;

	err = aes_expandkey(&rk, in_key + key_len, key_len);
	if (err)
		return err;

	memcpy(ctx->twkey, rk.key_enc, sizeof(ctx->twkey));

	return aesbs_setkey(tfm, in_key, key_len);
}

static int __xts_crypt(struct skcipher_request *req, bool encrypt,
		       void (*fn)(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[]))
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
	int tail = req->cryptlen % (8 * AES_BLOCK_SIZE);
	struct scatterlist sg_src[2], sg_dst[2];
	struct skcipher_request subreq;
	struct scatterlist *src, *dst;
	struct skcipher_walk walk;
	int nbytes, err;
	int first = 1;
	const u8 *in;
	u8 *out;

	if (req->cryptlen < AES_BLOCK_SIZE)
		return -EINVAL;

	/* ensure that the cts tail is covered by a single step */
	if (unlikely(tail > 0 && tail < AES_BLOCK_SIZE)) {
		int xts_blocks = DIV_ROUND_UP(req->cryptlen,
					      AES_BLOCK_SIZE) - 2;

		skcipher_request_set_tfm(&subreq, tfm);
		skcipher_request_set_callback(&subreq,
					      skcipher_request_flags(req),
					      NULL, NULL);
		skcipher_request_set_crypt(&subreq, req->src, req->dst,
					   xts_blocks * AES_BLOCK_SIZE,
					   req->iv);
		req = &subreq;
	} else {
		tail = 0;
	}

	err = skcipher_walk_virt(&walk, req, false);
	if (err)
		return err;

	scoped_ksimd() {
		while (walk.nbytes >= AES_BLOCK_SIZE) {
			int blocks = (walk.nbytes / AES_BLOCK_SIZE) & ~7;

			out = walk.dst.virt.addr;
			in = walk.src.virt.addr;
			nbytes = walk.nbytes;

			if (blocks >= 8) {
				if (first == 1)
					neon_aes_ecb_encrypt(walk.iv, walk.iv,
							     ctx->twkey,
							     ctx->key.rounds, 1);
				first = 2;

				fn(out, in, ctx->key.rk, ctx->key.rounds, blocks,
				   walk.iv);

				out += blocks * AES_BLOCK_SIZE;
				in += blocks * AES_BLOCK_SIZE;
				nbytes -= blocks * AES_BLOCK_SIZE;
			}
			if (walk.nbytes == walk.total && nbytes > 0) {
				if (encrypt)
					neon_aes_xts_encrypt(out, in, ctx->cts.key_enc,
							     ctx->key.rounds, nbytes,
							     ctx->twkey, walk.iv, first);
				else
					neon_aes_xts_decrypt(out, in, ctx->cts.key_dec,
							     ctx->key.rounds, nbytes,
							     ctx->twkey, walk.iv, first);
				nbytes = first = 0;
			}
			err = skcipher_walk_done(&walk, nbytes);
		}

		if (err || likely(!tail))
			return err;

		/* handle ciphertext stealing */
		dst = src = scatterwalk_ffwd(sg_src, req->src, req->cryptlen);
		if (req->dst != req->src)
			dst = scatterwalk_ffwd(sg_dst, req->dst, req->cryptlen);

		skcipher_request_set_crypt(req, src, dst, AES_BLOCK_SIZE + tail,
					   req->iv);

		err = skcipher_walk_virt(&walk, req, false);
		if (err)
			return err;

		out = walk.dst.virt.addr;
		in = walk.src.virt.addr;
		nbytes = walk.nbytes;

		if (encrypt)
			neon_aes_xts_encrypt(out, in, ctx->cts.key_enc,
					     ctx->key.rounds, nbytes, ctx->twkey,
					     walk.iv, first);
		else
			neon_aes_xts_decrypt(out, in, ctx->cts.key_dec,
					     ctx->key.rounds, nbytes, ctx->twkey,
					     walk.iv, first);
	}

	return skcipher_walk_done(&walk, 0);
}

static int xts_encrypt(struct skcipher_request *req)
{
	return __xts_crypt(req, true, aesbs_xts_encrypt);
}

static int xts_decrypt(struct skcipher_request *req)
{
	return __xts_crypt(req, false, aesbs_xts_decrypt);
}
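/*
 * Priority 250 places these implementations above the generic C code but
 * below drivers with higher priority, such as the ARMv8 Crypto Extensions
 * one, which is preferred on CPUs that implement the AES instructions.
 */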
static struct skcipher_alg aes_algs[] = { {
	.base.cra_name		= "ecb(aes)",
	.base.cra_driver_name	= "ecb-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= AES_BLOCK_SIZE,
	.base.cra_ctxsize	= sizeof(struct aesbs_ctx),
	.base.cra_module	= THIS_MODULE,

	.min_keysize		= AES_MIN_KEY_SIZE,
	.max_keysize		= AES_MAX_KEY_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.setkey			= aesbs_setkey,
	.encrypt		= ecb_encrypt,
	.decrypt		= ecb_decrypt,
}, {
	.base.cra_name		= "cbc(aes)",
	.base.cra_driver_name	= "cbc-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= AES_BLOCK_SIZE,
	.base.cra_ctxsize	= sizeof(struct aesbs_cbc_ctr_ctx),
	.base.cra_module	= THIS_MODULE,

	.min_keysize		= AES_MIN_KEY_SIZE,
	.max_keysize		= AES_MAX_KEY_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.ivsize			= AES_BLOCK_SIZE,
	.setkey			= aesbs_cbc_ctr_setkey,
	.encrypt		= cbc_encrypt,
	.decrypt		= cbc_decrypt,
}, {
	.base.cra_name		= "ctr(aes)",
	.base.cra_driver_name	= "ctr-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= 1,
	.base.cra_ctxsize	= sizeof(struct aesbs_cbc_ctr_ctx),
	.base.cra_module	= THIS_MODULE,

	.min_keysize		= AES_MIN_KEY_SIZE,
	.max_keysize		= AES_MAX_KEY_SIZE,
	.chunksize		= AES_BLOCK_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.ivsize			= AES_BLOCK_SIZE,
	.setkey			= aesbs_cbc_ctr_setkey,
	.encrypt		= ctr_encrypt,
	.decrypt		= ctr_encrypt,
}, {
	.base.cra_name		= "xts(aes)",
	.base.cra_driver_name	= "xts-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= AES_BLOCK_SIZE,
	.base.cra_ctxsize	= sizeof(struct aesbs_xts_ctx),
	.base.cra_module	= THIS_MODULE,

	.min_keysize		= 2 * AES_MIN_KEY_SIZE,
	.max_keysize		= 2 * AES_MAX_KEY_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.ivsize			= AES_BLOCK_SIZE,
	.setkey			= aesbs_xts_setkey,
	.encrypt		= xts_encrypt,
	.decrypt		= xts_decrypt,
} };

static void aes_exit(void)
{
	crypto_unregister_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
}

static int __init aes_init(void)
{
	if (!cpu_have_named_feature(ASIMD))
		return -ENODEV;

	return crypto_register_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
}

module_init(aes_init);
module_exit(aes_exit);