// SPDX-License-Identifier: GPL-2.0-only
/*
 * Bit sliced AES using NEON instructions
 *
 * Copyright (C) 2017 Linaro Ltd <ard.biesheuvel@linaro.org>
 */

#include <asm/neon.h>
#include <asm/simd.h>
#include <crypto/aes.h>
#include <crypto/internal/skcipher.h>
#include <crypto/scatterwalk.h>
#include <crypto/xts.h>
#include <linux/module.h>
#include "aes-cipher.h"

MODULE_AUTHOR("Ard Biesheuvel <ard.biesheuvel@linaro.org>");
MODULE_DESCRIPTION("Bit sliced AES using NEON instructions");
MODULE_LICENSE("GPL v2");

MODULE_ALIAS_CRYPTO("ecb(aes)");
MODULE_ALIAS_CRYPTO("cbc(aes)");
MODULE_ALIAS_CRYPTO("ctr(aes)");
MODULE_ALIAS_CRYPTO("xts(aes)");

asmlinkage void aesbs_convert_key(u8 out[], u32 const rk[], int rounds);

asmlinkage void aesbs_ecb_encrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks);
asmlinkage void aesbs_ecb_decrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks);

asmlinkage void aesbs_cbc_decrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[]);

asmlinkage void aesbs_ctr_encrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 ctr[]);

asmlinkage void aesbs_xts_encrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[], int);
asmlinkage void aesbs_xts_decrypt(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[], int);

struct aesbs_ctx {
	int	rounds;
	u8	rk[13 * (8 * AES_BLOCK_SIZE) + 32] __aligned(AES_BLOCK_SIZE);
};

struct aesbs_cbc_ctx {
	struct aesbs_ctx	key;
	struct crypto_aes_ctx	fallback;
};

struct aesbs_xts_ctx {
	struct aesbs_ctx	key;
	struct crypto_aes_ctx	fallback;
	struct crypto_aes_ctx	tweak_key;
};

static int aesbs_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
			unsigned int key_len)
{
	struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct crypto_aes_ctx rk;
	int err;

	err = aes_expandkey(&rk, in_key, key_len);
	if (err)
		return err;

	ctx->rounds = 6 + key_len / 4;

	kernel_neon_begin();
	aesbs_convert_key(ctx->rk, rk.key_enc, ctx->rounds);
	kernel_neon_end();

	return 0;
}

static int __ecb_crypt(struct skcipher_request *req,
		       void (*fn)(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks))
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while (walk.nbytes >= AES_BLOCK_SIZE) {
		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;

		if (walk.nbytes < walk.total)
			blocks = round_down(blocks,
					    walk.stride / AES_BLOCK_SIZE);

		kernel_neon_begin();
		fn(walk.dst.virt.addr, walk.src.virt.addr, ctx->rk,
		   ctx->rounds, blocks);
		kernel_neon_end();
		err = skcipher_walk_done(&walk,
					 walk.nbytes - blocks * AES_BLOCK_SIZE);
	}

	return err;
}

static int ecb_encrypt(struct skcipher_request *req)
{
	return __ecb_crypt(req, aesbs_ecb_encrypt);
}

static int ecb_decrypt(struct skcipher_request *req)
{
	return __ecb_crypt(req, aesbs_ecb_decrypt);
}

static int aesbs_cbc_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
			    unsigned int key_len)
{
	struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
	int err;

	err = aes_expandkey(&ctx->fallback, in_key, key_len);
	if (err)
		return err;

	ctx->key.rounds = 6 + key_len / 4;

	kernel_neon_begin();
	aesbs_convert_key(ctx->key.rk, ctx->fallback.key_enc, ctx->key.rounds);
	kernel_neon_end();

	return 0;
}

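/*
 * CBC encryption is inherently sequential: each block is XORed with the
 * previous ciphertext block before it is encrypted, so the 8-way bit sliced
 * NEON code cannot be used here. Encrypt one block at a time using the
 * fallback key schedule instead.
 */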
static int cbc_encrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	const struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	unsigned int nbytes;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while ((nbytes = walk.nbytes) >= AES_BLOCK_SIZE) {
		const u8 *src = walk.src.virt.addr;
		u8 *dst = walk.dst.virt.addr;
		u8 *prev = walk.iv;

		do {
			crypto_xor_cpy(dst, src, prev, AES_BLOCK_SIZE);
			__aes_arm_encrypt(ctx->fallback.key_enc,
					  ctx->key.rounds, dst, dst);
			prev = dst;
			src += AES_BLOCK_SIZE;
			dst += AES_BLOCK_SIZE;
			nbytes -= AES_BLOCK_SIZE;
		} while (nbytes >= AES_BLOCK_SIZE);
		memcpy(walk.iv, prev, AES_BLOCK_SIZE);
		err = skcipher_walk_done(&walk, nbytes);
	}
	return err;
}

static int cbc_decrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_cbc_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while (walk.nbytes >= AES_BLOCK_SIZE) {
		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;

		if (walk.nbytes < walk.total)
			blocks = round_down(blocks,
					    walk.stride / AES_BLOCK_SIZE);

		kernel_neon_begin();
		aesbs_cbc_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
				  ctx->key.rk, ctx->key.rounds, blocks,
				  walk.iv);
		kernel_neon_end();
		err = skcipher_walk_done(&walk,
					 walk.nbytes - blocks * AES_BLOCK_SIZE);
	}

	return err;
}

static int ctr_encrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct skcipher_walk walk;
	u8 buf[AES_BLOCK_SIZE];
	int err;

	err = skcipher_walk_virt(&walk, req, false);

	while (walk.nbytes > 0) {
		const u8 *src = walk.src.virt.addr;
		u8 *dst = walk.dst.virt.addr;
		unsigned int bytes = walk.nbytes;

		if (unlikely(bytes < AES_BLOCK_SIZE))
			src = dst = memcpy(buf + sizeof(buf) - bytes,
					   src, bytes);
		else if (walk.nbytes < walk.total)
			bytes &= ~(8 * AES_BLOCK_SIZE - 1);

		kernel_neon_begin();
		aesbs_ctr_encrypt(dst, src, ctx->rk, ctx->rounds, bytes, walk.iv);
		kernel_neon_end();

		if (unlikely(bytes < AES_BLOCK_SIZE))
			memcpy(walk.dst.virt.addr,
			       buf + sizeof(buf) - bytes, bytes);

		err = skcipher_walk_done(&walk, walk.nbytes - bytes);
	}

	return err;
}

static int aesbs_xts_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
			    unsigned int key_len)
{
	struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
	int err;

	err = xts_verify_key(tfm, in_key, key_len);
	if (err)
		return err;

	key_len /= 2;
	err = aes_expandkey(&ctx->fallback, in_key, key_len);
	if (err)
		return err;
	err = aes_expandkey(&ctx->tweak_key, in_key + key_len, key_len);
	if (err)
		return err;

	return aesbs_setkey(tfm, in_key, key_len);
}

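/*
 * __xts_crypt() handles the bulk of the data with the bit sliced NEON code
 * and, when the request length is not a multiple of AES_BLOCK_SIZE, finishes
 * the final partial block with ciphertext stealing using the fallback key
 * schedule.
 */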
static int __xts_crypt(struct skcipher_request *req, bool encrypt,
		       void (*fn)(u8 out[], u8 const in[], u8 const rk[],
				  int rounds, int blocks, u8 iv[], int))
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct aesbs_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
	const int rounds = ctx->key.rounds;
	int tail = req->cryptlen % AES_BLOCK_SIZE;
	struct skcipher_request subreq;
	u8 buf[2 * AES_BLOCK_SIZE];
	struct skcipher_walk walk;
	int err;

	if (req->cryptlen < AES_BLOCK_SIZE)
		return -EINVAL;

	if (unlikely(tail)) {
		skcipher_request_set_tfm(&subreq, tfm);
		skcipher_request_set_callback(&subreq,
					      skcipher_request_flags(req),
					      NULL, NULL);
		skcipher_request_set_crypt(&subreq, req->src, req->dst,
					   req->cryptlen - tail, req->iv);
		req = &subreq;
	}

	err = skcipher_walk_virt(&walk, req, true);
	if (err)
		return err;

	__aes_arm_encrypt(ctx->tweak_key.key_enc, rounds, walk.iv, walk.iv);

	while (walk.nbytes >= AES_BLOCK_SIZE) {
		unsigned int blocks = walk.nbytes / AES_BLOCK_SIZE;
		int reorder_last_tweak = !encrypt && tail > 0;

		if (walk.nbytes < walk.total) {
			blocks = round_down(blocks,
					    walk.stride / AES_BLOCK_SIZE);
			reorder_last_tweak = 0;
		}

		kernel_neon_begin();
		fn(walk.dst.virt.addr, walk.src.virt.addr, ctx->key.rk,
		   rounds, blocks, walk.iv, reorder_last_tweak);
		kernel_neon_end();
		err = skcipher_walk_done(&walk,
					 walk.nbytes - blocks * AES_BLOCK_SIZE);
	}

	if (err || likely(!tail))
		return err;

	/* handle ciphertext stealing */
	scatterwalk_map_and_copy(buf, req->dst, req->cryptlen - AES_BLOCK_SIZE,
				 AES_BLOCK_SIZE, 0);
	memcpy(buf + AES_BLOCK_SIZE, buf, tail);
	scatterwalk_map_and_copy(buf, req->src, req->cryptlen, tail, 0);

	crypto_xor(buf, req->iv, AES_BLOCK_SIZE);

	if (encrypt)
		__aes_arm_encrypt(ctx->fallback.key_enc, rounds, buf, buf);
	else
		__aes_arm_decrypt(ctx->fallback.key_dec, rounds, buf, buf);

	crypto_xor(buf, req->iv, AES_BLOCK_SIZE);

	scatterwalk_map_and_copy(buf, req->dst, req->cryptlen - AES_BLOCK_SIZE,
				 AES_BLOCK_SIZE + tail, 1);
	return 0;
}

static int xts_encrypt(struct skcipher_request *req)
{
	return __xts_crypt(req, true, aesbs_xts_encrypt);
}

static int xts_decrypt(struct skcipher_request *req)
{
	return __xts_crypt(req, false, aesbs_xts_decrypt);
}

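/*
 * All modes use a walksize of 8 * AES_BLOCK_SIZE, matching the number of
 * blocks the bit sliced NEON code processes per invocation.
 */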
static struct skcipher_alg aes_algs[] = { {
	.base.cra_name		= "ecb(aes)",
	.base.cra_driver_name	= "ecb-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= AES_BLOCK_SIZE,
	.base.cra_ctxsize	= sizeof(struct aesbs_ctx),
	.base.cra_module	= THIS_MODULE,

	.min_keysize		= AES_MIN_KEY_SIZE,
	.max_keysize		= AES_MAX_KEY_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.setkey			= aesbs_setkey,
	.encrypt		= ecb_encrypt,
	.decrypt		= ecb_decrypt,
}, {
	.base.cra_name		= "cbc(aes)",
	.base.cra_driver_name	= "cbc-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= AES_BLOCK_SIZE,
	.base.cra_ctxsize	= sizeof(struct aesbs_cbc_ctx),
	.base.cra_module	= THIS_MODULE,

	.min_keysize		= AES_MIN_KEY_SIZE,
	.max_keysize		= AES_MAX_KEY_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.ivsize			= AES_BLOCK_SIZE,
	.setkey			= aesbs_cbc_setkey,
	.encrypt		= cbc_encrypt,
	.decrypt		= cbc_decrypt,
}, {
	.base.cra_name		= "ctr(aes)",
	.base.cra_driver_name	= "ctr-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= 1,
	.base.cra_ctxsize	= sizeof(struct aesbs_ctx),
	.base.cra_module	= THIS_MODULE,

	.min_keysize		= AES_MIN_KEY_SIZE,
	.max_keysize		= AES_MAX_KEY_SIZE,
	.chunksize		= AES_BLOCK_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.ivsize			= AES_BLOCK_SIZE,
	.setkey			= aesbs_setkey,
	.encrypt		= ctr_encrypt,
	.decrypt		= ctr_encrypt,
}, {
	.base.cra_name		= "xts(aes)",
	.base.cra_driver_name	= "xts-aes-neonbs",
	.base.cra_priority	= 250,
	.base.cra_blocksize	= AES_BLOCK_SIZE,
	.base.cra_ctxsize	= sizeof(struct aesbs_xts_ctx),
	.base.cra_module	= THIS_MODULE,

	.min_keysize		= 2 * AES_MIN_KEY_SIZE,
	.max_keysize		= 2 * AES_MAX_KEY_SIZE,
	.walksize		= 8 * AES_BLOCK_SIZE,
	.ivsize			= AES_BLOCK_SIZE,
	.setkey			= aesbs_xts_setkey,
	.encrypt		= xts_encrypt,
	.decrypt		= xts_decrypt,
} };

static void aes_exit(void)
{
	crypto_unregister_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
}

static int __init aes_init(void)
{
	if (!(elf_hwcap & HWCAP_NEON))
		return -ENODEV;

	return crypto_register_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
}

module_init(aes_init);
module_exit(aes_exit);