1 /* 2 * Copyright (c) 2013, 2014 Kenneth MacKay. All rights reserved. 3 * Copyright (c) 2019 Vitaly Chikunov <vt@altlinux.org> 4 * 5 * Redistribution and use in source and binary forms, with or without 6 * modification, are permitted provided that the following conditions are 7 * met: 8 * * Redistributions of source code must retain the above copyright 9 * notice, this list of conditions and the following disclaimer. 10 * * Redistributions in binary form must reproduce the above copyright 11 * notice, this list of conditions and the following disclaimer in the 12 * documentation and/or other materials provided with the distribution. 13 * 14 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 15 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 16 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 17 * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 18 * HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 19 * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 20 * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 21 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 22 * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 23 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 24 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
25 */ 26 27 #include <crypto/ecc_curve.h> 28 #include <linux/module.h> 29 #include <linux/random.h> 30 #include <linux/slab.h> 31 #include <linux/swab.h> 32 #include <linux/fips.h> 33 #include <crypto/ecdh.h> 34 #include <crypto/rng.h> 35 #include <crypto/internal/ecc.h> 36 #include <linux/unaligned.h> 37 #include <linux/ratelimit.h> 38 39 #include "ecc_curve_defs.h" 40 41 typedef struct { 42 u64 m_low; 43 u64 m_high; 44 } uint128_t; 45 46 /* Returns curv25519 curve param */ 47 const struct ecc_curve *ecc_get_curve25519(void) 48 { 49 return &ecc_25519; 50 } 51 EXPORT_SYMBOL(ecc_get_curve25519); 52 53 const struct ecc_curve *ecc_get_curve(unsigned int curve_id) 54 { 55 switch (curve_id) { 56 /* In FIPS mode only allow P256 and higher */ 57 case ECC_CURVE_NIST_P192: 58 return fips_enabled ? NULL : &nist_p192; 59 case ECC_CURVE_NIST_P256: 60 return &nist_p256; 61 case ECC_CURVE_NIST_P384: 62 return &nist_p384; 63 case ECC_CURVE_NIST_P521: 64 return &nist_p521; 65 default: 66 return NULL; 67 } 68 } 69 EXPORT_SYMBOL(ecc_get_curve); 70 71 void ecc_digits_from_bytes(const u8 *in, unsigned int nbytes, 72 u64 *out, unsigned int ndigits) 73 { 74 int diff = ndigits - DIV_ROUND_UP_POW2(nbytes, sizeof(u64)); 75 unsigned int o = nbytes & 7; 76 __be64 msd = 0; 77 78 /* diff > 0: not enough input bytes: set most significant digits to 0 */ 79 if (diff > 0) { 80 ndigits -= diff; 81 memset(&out[ndigits], 0, diff * sizeof(u64)); 82 } 83 84 if (o) { 85 memcpy((u8 *)&msd + sizeof(msd) - o, in, o); 86 out[--ndigits] = be64_to_cpu(msd); 87 in += o; 88 } 89 ecc_swap_digits(in, out, ndigits); 90 } 91 EXPORT_SYMBOL(ecc_digits_from_bytes); 92 93 struct ecc_point *ecc_alloc_point(unsigned int ndigits) 94 { 95 struct ecc_point *p; 96 size_t ndigits_sz; 97 98 if (!ndigits) 99 return NULL; 100 101 p = kmalloc(sizeof(*p), GFP_KERNEL); 102 if (!p) 103 return NULL; 104 105 ndigits_sz = ndigits * sizeof(u64); 106 p->x = kmalloc(ndigits_sz, GFP_KERNEL); 107 if (!p->x) 108 goto err_alloc_x; 109 110 
p->y = kmalloc(ndigits_sz, GFP_KERNEL); 111 if (!p->y) 112 goto err_alloc_y; 113 114 p->ndigits = ndigits; 115 116 return p; 117 118 err_alloc_y: 119 kfree(p->x); 120 err_alloc_x: 121 kfree(p); 122 return NULL; 123 } 124 EXPORT_SYMBOL(ecc_alloc_point); 125 126 void ecc_free_point(struct ecc_point *p) 127 { 128 if (!p) 129 return; 130 131 kfree_sensitive(p->x); 132 kfree_sensitive(p->y); 133 kfree_sensitive(p); 134 } 135 EXPORT_SYMBOL(ecc_free_point); 136 137 static void vli_clear(u64 *vli, unsigned int ndigits) 138 { 139 int i; 140 141 for (i = 0; i < ndigits; i++) 142 vli[i] = 0; 143 } 144 145 /* Returns true if vli == 0, false otherwise. */ 146 bool vli_is_zero(const u64 *vli, unsigned int ndigits) 147 { 148 int i; 149 150 for (i = 0; i < ndigits; i++) { 151 if (vli[i]) 152 return false; 153 } 154 155 return true; 156 } 157 EXPORT_SYMBOL(vli_is_zero); 158 159 /* Returns nonzero if bit of vli is set. */ 160 static u64 vli_test_bit(const u64 *vli, unsigned int bit) 161 { 162 return (vli[bit / 64] & ((u64)1 << (bit % 64))); 163 } 164 165 static bool vli_is_negative(const u64 *vli, unsigned int ndigits) 166 { 167 return vli_test_bit(vli, ndigits * 64 - 1); 168 } 169 170 /* Counts the number of 64-bit "digits" in vli. */ 171 static unsigned int vli_num_digits(const u64 *vli, unsigned int ndigits) 172 { 173 int i; 174 175 /* Search from the end until we find a non-zero digit. 176 * We do it in reverse because we expect that most digits will 177 * be nonzero. 178 */ 179 for (i = ndigits - 1; i >= 0 && vli[i] == 0; i--); 180 181 return (i + 1); 182 } 183 184 /* Counts the number of bits required for vli. 
*/ 185 unsigned int vli_num_bits(const u64 *vli, unsigned int ndigits) 186 { 187 unsigned int i, num_digits; 188 u64 digit; 189 190 num_digits = vli_num_digits(vli, ndigits); 191 if (num_digits == 0) 192 return 0; 193 194 digit = vli[num_digits - 1]; 195 for (i = 0; digit; i++) 196 digit >>= 1; 197 198 return ((num_digits - 1) * 64 + i); 199 } 200 EXPORT_SYMBOL(vli_num_bits); 201 202 /* Set dest from unaligned bit string src. */ 203 void vli_from_be64(u64 *dest, const void *src, unsigned int ndigits) 204 { 205 int i; 206 const u64 *from = src; 207 208 for (i = 0; i < ndigits; i++) 209 dest[i] = get_unaligned_be64(&from[ndigits - 1 - i]); 210 } 211 EXPORT_SYMBOL(vli_from_be64); 212 213 void vli_from_le64(u64 *dest, const void *src, unsigned int ndigits) 214 { 215 int i; 216 const u64 *from = src; 217 218 for (i = 0; i < ndigits; i++) 219 dest[i] = get_unaligned_le64(&from[i]); 220 } 221 EXPORT_SYMBOL(vli_from_le64); 222 223 /* Sets dest = src. */ 224 static void vli_set(u64 *dest, const u64 *src, unsigned int ndigits) 225 { 226 int i; 227 228 for (i = 0; i < ndigits; i++) 229 dest[i] = src[i]; 230 } 231 232 /* Returns sign of left - right. */ 233 int vli_cmp(const u64 *left, const u64 *right, unsigned int ndigits) 234 { 235 int i; 236 237 for (i = ndigits - 1; i >= 0; i--) { 238 if (left[i] > right[i]) 239 return 1; 240 else if (left[i] < right[i]) 241 return -1; 242 } 243 244 return 0; 245 } 246 EXPORT_SYMBOL(vli_cmp); 247 248 /* Computes result = in << c, returning carry. Can modify in place 249 * (if result == in). 0 < shift < 64. 250 */ 251 static u64 vli_lshift(u64 *result, const u64 *in, unsigned int shift, 252 unsigned int ndigits) 253 { 254 u64 carry = 0; 255 int i; 256 257 for (i = 0; i < ndigits; i++) { 258 u64 temp = in[i]; 259 260 result[i] = (temp << shift) | carry; 261 carry = temp >> (64 - shift); 262 } 263 264 return carry; 265 } 266 267 /* Computes vli = vli >> 1. 
*/ 268 static void vli_rshift1(u64 *vli, unsigned int ndigits) 269 { 270 u64 *end = vli; 271 u64 carry = 0; 272 273 vli += ndigits; 274 275 while (vli-- > end) { 276 u64 temp = *vli; 277 *vli = (temp >> 1) | carry; 278 carry = temp << 63; 279 } 280 } 281 282 /* Computes result = left + right, returning carry. Can modify in place. */ 283 static u64 vli_add(u64 *result, const u64 *left, const u64 *right, 284 unsigned int ndigits) 285 { 286 u64 carry = 0; 287 int i; 288 289 for (i = 0; i < ndigits; i++) { 290 u64 sum; 291 292 sum = left[i] + right[i] + carry; 293 if (sum != left[i]) 294 carry = (sum < left[i]); 295 296 result[i] = sum; 297 } 298 299 return carry; 300 } 301 302 /* Computes result = left + right, returning carry. Can modify in place. */ 303 static u64 vli_uadd(u64 *result, const u64 *left, u64 right, 304 unsigned int ndigits) 305 { 306 u64 carry = right; 307 int i; 308 309 for (i = 0; i < ndigits; i++) { 310 u64 sum; 311 312 sum = left[i] + carry; 313 if (sum != left[i]) 314 carry = (sum < left[i]); 315 else 316 carry = !!carry; 317 318 result[i] = sum; 319 } 320 321 return carry; 322 } 323 324 /* Computes result = left - right, returning borrow. Can modify in place. */ 325 u64 vli_sub(u64 *result, const u64 *left, const u64 *right, 326 unsigned int ndigits) 327 { 328 u64 borrow = 0; 329 int i; 330 331 for (i = 0; i < ndigits; i++) { 332 u64 diff; 333 334 diff = left[i] - right[i] - borrow; 335 if (diff != left[i]) 336 borrow = (diff > left[i]); 337 338 result[i] = diff; 339 } 340 341 return borrow; 342 } 343 EXPORT_SYMBOL(vli_sub); 344 345 /* Computes result = left - right, returning borrow. Can modify in place. 
*/ 346 static u64 vli_usub(u64 *result, const u64 *left, u64 right, 347 unsigned int ndigits) 348 { 349 u64 borrow = right; 350 int i; 351 352 for (i = 0; i < ndigits; i++) { 353 u64 diff; 354 355 diff = left[i] - borrow; 356 if (diff != left[i]) 357 borrow = (diff > left[i]); 358 359 result[i] = diff; 360 } 361 362 return borrow; 363 } 364 365 static uint128_t mul_64_64(u64 left, u64 right) 366 { 367 uint128_t result; 368 #if defined(CONFIG_ARCH_SUPPORTS_INT128) 369 unsigned __int128 m = (unsigned __int128)left * right; 370 371 result.m_low = m; 372 result.m_high = m >> 64; 373 #else 374 u64 a0 = left & 0xffffffffull; 375 u64 a1 = left >> 32; 376 u64 b0 = right & 0xffffffffull; 377 u64 b1 = right >> 32; 378 u64 m0 = a0 * b0; 379 u64 m1 = a0 * b1; 380 u64 m2 = a1 * b0; 381 u64 m3 = a1 * b1; 382 383 m2 += (m0 >> 32); 384 m2 += m1; 385 386 /* Overflow */ 387 if (m2 < m1) 388 m3 += 0x100000000ull; 389 390 result.m_low = (m0 & 0xffffffffull) | (m2 << 32); 391 result.m_high = m3 + (m2 >> 32); 392 #endif 393 return result; 394 } 395 396 static uint128_t add_128_128(uint128_t a, uint128_t b) 397 { 398 uint128_t result; 399 400 result.m_low = a.m_low + b.m_low; 401 result.m_high = a.m_high + b.m_high + (result.m_low < a.m_low); 402 403 return result; 404 } 405 406 static void vli_mult(u64 *result, const u64 *left, const u64 *right, 407 unsigned int ndigits) 408 { 409 uint128_t r01 = { 0, 0 }; 410 u64 r2 = 0; 411 unsigned int i, k; 412 413 /* Compute each digit of result in sequence, maintaining the 414 * carries. 
415 */ 416 for (k = 0; k < ndigits * 2 - 1; k++) { 417 unsigned int min; 418 419 if (k < ndigits) 420 min = 0; 421 else 422 min = (k + 1) - ndigits; 423 424 for (i = min; i <= k && i < ndigits; i++) { 425 uint128_t product; 426 427 product = mul_64_64(left[i], right[k - i]); 428 429 r01 = add_128_128(r01, product); 430 r2 += (r01.m_high < product.m_high); 431 } 432 433 result[k] = r01.m_low; 434 r01.m_low = r01.m_high; 435 r01.m_high = r2; 436 r2 = 0; 437 } 438 439 result[ndigits * 2 - 1] = r01.m_low; 440 } 441 442 /* Compute product = left * right, for a small right value. */ 443 static void vli_umult(u64 *result, const u64 *left, u32 right, 444 unsigned int ndigits) 445 { 446 uint128_t r01 = { 0 }; 447 unsigned int k; 448 449 for (k = 0; k < ndigits; k++) { 450 uint128_t product; 451 452 product = mul_64_64(left[k], right); 453 r01 = add_128_128(r01, product); 454 /* no carry */ 455 result[k] = r01.m_low; 456 r01.m_low = r01.m_high; 457 r01.m_high = 0; 458 } 459 result[k] = r01.m_low; 460 for (++k; k < ndigits * 2; k++) 461 result[k] = 0; 462 } 463 464 static void vli_square(u64 *result, const u64 *left, unsigned int ndigits) 465 { 466 uint128_t r01 = { 0, 0 }; 467 u64 r2 = 0; 468 int i, k; 469 470 for (k = 0; k < ndigits * 2 - 1; k++) { 471 unsigned int min; 472 473 if (k < ndigits) 474 min = 0; 475 else 476 min = (k + 1) - ndigits; 477 478 for (i = min; i <= k && i <= k - i; i++) { 479 uint128_t product; 480 481 product = mul_64_64(left[i], left[k - i]); 482 483 if (i < k - i) { 484 r2 += product.m_high >> 63; 485 product.m_high = (product.m_high << 1) | 486 (product.m_low >> 63); 487 product.m_low <<= 1; 488 } 489 490 r01 = add_128_128(r01, product); 491 r2 += (r01.m_high < product.m_high); 492 } 493 494 result[k] = r01.m_low; 495 r01.m_low = r01.m_high; 496 r01.m_high = r2; 497 r2 = 0; 498 } 499 500 result[ndigits * 2 - 1] = r01.m_low; 501 } 502 503 /* Computes result = (left + right) % mod. 504 * Assumes that left < mod and right < mod, result != mod. 
505 */ 506 static void vli_mod_add(u64 *result, const u64 *left, const u64 *right, 507 const u64 *mod, unsigned int ndigits) 508 { 509 u64 carry; 510 511 carry = vli_add(result, left, right, ndigits); 512 513 /* result > mod (result = mod + remainder), so subtract mod to 514 * get remainder. 515 */ 516 if (carry || vli_cmp(result, mod, ndigits) >= 0) 517 vli_sub(result, result, mod, ndigits); 518 } 519 520 /* Computes result = (left - right) % mod. 521 * Assumes that left < mod and right < mod, result != mod. 522 */ 523 static void vli_mod_sub(u64 *result, const u64 *left, const u64 *right, 524 const u64 *mod, unsigned int ndigits) 525 { 526 u64 borrow = vli_sub(result, left, right, ndigits); 527 528 /* In this case, p_result == -diff == (max int) - diff. 529 * Since -x % d == d - x, we can get the correct result from 530 * result + mod (with overflow). 531 */ 532 if (borrow) 533 vli_add(result, result, mod, ndigits); 534 } 535 536 /* 537 * Computes result = product % mod 538 * for special form moduli: p = 2^k-c, for small c (note the minus sign) 539 * 540 * References: 541 * R. Crandall, C. Pomerance. Prime Numbers: A Computational Perspective. 542 * 9 Fast Algorithms for Large-Integer Arithmetic. 9.2.3 Moduli of special form 543 * Algorithm 9.2.13 (Fast mod operation for special-form moduli). 
544 */ 545 static void vli_mmod_special(u64 *result, const u64 *product, 546 const u64 *mod, unsigned int ndigits) 547 { 548 u64 c = -mod[0]; 549 u64 t[ECC_MAX_DIGITS * 2]; 550 u64 r[ECC_MAX_DIGITS * 2]; 551 552 vli_set(r, product, ndigits * 2); 553 while (!vli_is_zero(r + ndigits, ndigits)) { 554 vli_umult(t, r + ndigits, c, ndigits); 555 vli_clear(r + ndigits, ndigits); 556 vli_add(r, r, t, ndigits * 2); 557 } 558 vli_set(t, mod, ndigits); 559 vli_clear(t + ndigits, ndigits); 560 while (vli_cmp(r, t, ndigits * 2) >= 0) 561 vli_sub(r, r, t, ndigits * 2); 562 vli_set(result, r, ndigits); 563 } 564 565 /* 566 * Computes result = product % mod 567 * for special form moduli: p = 2^{k-1}+c, for small c (note the plus sign) 568 * where k-1 does not fit into qword boundary by -1 bit (such as 255). 569 570 * References (loosely based on): 571 * A. Menezes, P. van Oorschot, S. Vanstone. Handbook of Applied Cryptography. 572 * 14.3.4 Reduction methods for moduli of special form. Algorithm 14.47. 573 * URL: http://cacr.uwaterloo.ca/hac/about/chap14.pdf 574 * 575 * H. Cohen, G. Frey, R. Avanzi, C. Doche, T. Lange, K. Nguyen, F. Vercauteren. 576 * Handbook of Elliptic and Hyperelliptic Curve Cryptography. 
* Algorithm 10.25 Fast reduction for special form moduli
 *
 * With w = 64 * ndigits and c = mod[0], 2^(w-1) == -c and 2^w == -2c
 * (mod p). Each round therefore folds the bits at and above w-1 back as
 * q * 2c + carry * c, with the sign alternating from round to round.
 * NOTE(review): the digits of mod between the lowest and the highest are
 * presumably all zero (the caller checks digit ndigits - 2) - confirm.
 */
static void vli_mmod_special2(u64 *result, const u64 *product,
			      const u64 *mod, unsigned int ndigits)
{
	u64 c2 = mod[0] * 2;
	u64 q[ECC_MAX_DIGITS];
	u64 r[ECC_MAX_DIGITS * 2];
	u64 m[ECC_MAX_DIGITS * 2]; /* expanded mod */
	int carry; /* last bit that doesn't fit into q */
	int i;

	vli_set(m, mod, ndigits);
	vli_clear(m + ndigits, ndigits);

	vli_set(r, product, ndigits);
	/* q and carry are top bits */
	vli_set(q, product + ndigits, ndigits);
	vli_clear(r + ndigits, ndigits);
	/* carry = bit w-1 of product; r keeps only the low w-1 bits */
	carry = vli_is_negative(r, ndigits);
	if (carry)
		r[ndigits - 1] &= (1ull << 63) - 1;
	for (i = 1; carry || !vli_is_zero(q, ndigits); i++) {
		u64 qc[ECC_MAX_DIGITS * 2];

		/* qc = q * 2c + carry * c: the value being folded back */
		vli_umult(qc, q, c2, ndigits);
		if (carry)
			vli_uadd(qc, qc, mod[0], ndigits * 2);
		/* Split qc again into the next q/carry and its low w-1 bits */
		vli_set(q, qc + ndigits, ndigits);
		vli_clear(qc + ndigits, ndigits);
		carry = vli_is_negative(qc, ndigits);
		if (carry)
			qc[ndigits - 1] &= (1ull << 63) - 1;
		/* Odd rounds contribute negatively, even rounds positively */
		if (i & 1)
			vli_sub(r, r, qc, ndigits * 2);
		else
			vli_add(r, r, qc, ndigits * 2);
	}
	/* r is a signed (two's-complement) 2*ndigits value: normalise to [0, mod) */
	while (vli_is_negative(r, ndigits * 2))
		vli_add(r, r, m, ndigits * 2);
	while (vli_cmp(r, m, ndigits * 2) >= 0)
		vli_sub(r, r, m, ndigits * 2);

	vli_set(result, r, ndigits);
}

/*
 * Computes result = product % mod, where product is 2N words long.
 * Reference: Ken MacKay's micro-ecc.
 * Currently only designed to work for curve_p or curve_n.
 *
 * Shift-and-subtract long division: mod is aligned to the top of a
 * 2N-word buffer and shifted right one bit per step, subtracting it
 * whenever that does not borrow.
 * NOTE: product is one of the two working buffers and may be overwritten.
 */
static void vli_mmod_slow(u64 *result, u64 *product, const u64 *mod,
			  unsigned int ndigits)
{
	u64 mod_m[2 * ECC_MAX_DIGITS];
	u64 tmp[2 * ECC_MAX_DIGITS];
	u64 *v[2] = { tmp, product };
	u64 carry = 0;
	unsigned int i;
	/* Shift mod so its highest set bit is at the maximum position. */
	int shift = (ndigits * 2 * 64) - vli_num_bits(mod, ndigits);
	int word_shift = shift / 64;
	int bit_shift = shift % 64;

	vli_clear(mod_m, word_shift);
	if (bit_shift > 0) {
		for (i = 0; i < ndigits; ++i) {
			mod_m[word_shift + i] = (mod[i] << bit_shift) | carry;
			carry = mod[i] >> (64 - bit_shift);
		}
	} else
		vli_set(mod_m + word_shift, mod, ndigits);

	/* v[i] holds the current remainder; v[1 - i] receives the trial difference */
	for (i = 1; shift >= 0; --shift) {
		u64 borrow = 0;
		unsigned int j;

		for (j = 0; j < ndigits * 2; ++j) {
			u64 diff = v[i][j] - mod_m[j] - borrow;

			if (diff != v[i][j])
				borrow = (diff > v[i][j]);
			v[1 - i][j] = diff;
		}
		i = !(i ^ borrow); /* Swap the index if there was no borrow */
		/* Shift the 2N-word mod_m right by one bit, across the halves */
		vli_rshift1(mod_m, ndigits);
		mod_m[ndigits - 1] |= mod_m[ndigits] << (64 - 1);
		vli_rshift1(mod_m + ndigits, ndigits);
	}
	vli_set(result, v[i], ndigits);
}

/* Computes result = product % mod using Barrett's reduction with precomputed
 * value mu appended to the mod after ndigits, mu = (2^{2w} / mod) and have
 * length ndigits + 1, where mu * (2^w - 1) should not overflow ndigits
 * boundary.
 *
 * Reference:
 * R. Brent, P. Zimmermann. Modern Computer Arithmetic. 2010.
 * 2.4.1 Barrett's algorithm. Algorithm 2.5.
*/
static void vli_mmod_barrett(u64 *result, u64 *product, const u64 *mod,
			     unsigned int ndigits)
{
	u64 q[ECC_MAX_DIGITS * 2];
	u64 r[ECC_MAX_DIGITS * 2];
	const u64 *mu = mod + ndigits;

	/* q[ndigits..] ~= high(product) * mu / 2^w (quotient estimate).
	 * Only the low ndigits digits of mu are multiplied; when the top digit
	 * mu[ndigits] is non-zero, high(product) * 2^w is added once.
	 * NOTE(review): that is only correct if mu[ndigits] is 0 or 1 -
	 * presumably guaranteed by the curve definitions; confirm.
	 */
	vli_mult(q, product + ndigits, mu, ndigits);
	if (mu[ndigits])
		vli_add(q + ndigits, q + ndigits, product + ndigits, ndigits);
	/* r = product - q * mod; the estimate may leave r a few mods too big */
	vli_mult(r, mod, q + ndigits, ndigits);
	vli_sub(r, product, r, ndigits * 2);
	while (!vli_is_zero(r + ndigits, ndigits) ||
	       vli_cmp(r, mod, ndigits) != -1) {
		u64 carry;

		carry = vli_sub(r, r, mod, ndigits);
		vli_usub(r + ndigits, r + ndigits, carry, ndigits);
	}
	vli_set(result, r, ndigits);
}

/* Computes p_result = p_product % curve_p.
 * See algorithm 5 and 6 from
 * http://www.isys.uni-klu.ac.at/PDF/2001-0126-MT.pdf
 *
 * For the NIST P-192 prime p = 2^192 - 2^64 - 1, 2^192 == 2^64 + 1
 * (mod p), so the high digits are added back in shifted copies. carry
 * counts overflow out of 192 bits and is removed by subtracting p.
 */
static void vli_mmod_fast_192(u64 *result, const u64 *product,
			      const u64 *curve_prime, u64 *tmp)
{
	const unsigned int ndigits = ECC_CURVE_NIST_P192_DIGITS;
	int carry;

	vli_set(result, product, ndigits);

	vli_set(tmp, &product[3], ndigits);
	carry = vli_add(result, result, tmp, ndigits);

	tmp[0] = 0;
	tmp[1] = product[3];
	tmp[2] = product[4];
	carry += vli_add(result, result, tmp, ndigits);

	tmp[0] = tmp[1] = product[5];
	tmp[2] = 0;
	carry += vli_add(result, result, tmp, ndigits);

	while (carry || vli_cmp(curve_prime, result, ndigits) != 1)
		carry -= vli_sub(result, result, curve_prime, ndigits);
}

/* Computes result = product % curve_prime
 * from http://www.nsa.gov/ia/_files/nist-routines.pdf
 *
 * result = t + 2*s1 + 2*s2 + s3 + s4 - d1 - d2 - d3 - d4, where each term
 * is a rearrangement of the 32-bit words of product. carry tracks the
 * signed overflow beyond 256 bits. The final loop adds or subtracts
 * curve_prime until carry is 0 and result < curve_prime.
 */
static void vli_mmod_fast_256(u64 *result, const u64 *product,
			      const u64 *curve_prime, u64 *tmp)
{
	int carry;
	const unsigned int ndigits = ECC_CURVE_NIST_P256_DIGITS;

	/* t */
	vli_set(result, product, ndigits);

	/* s1 */
	tmp[0] = 0;
	tmp[1] = product[5] & 0xffffffff00000000ull;
	tmp[2] = product[6];
	tmp[3] = product[7];
	carry = vli_lshift(tmp, tmp, 1, ndigits);
	carry += vli_add(result, result, tmp, ndigits);

	/* s2 (tmp[0] is still 0 from s1) */
	tmp[1] = product[6] << 32;
	tmp[2] = (product[6] >> 32) | (product[7] << 32);
	tmp[3] = product[7] >> 32;
	carry += vli_lshift(tmp, tmp, 1, ndigits);
	carry += vli_add(result, result, tmp, ndigits);

	/* s3 */
	tmp[0] = product[4];
	tmp[1] = product[5] & 0xffffffff;
	tmp[2] = 0;
	tmp[3] = product[7];
	carry += vli_add(result, result, tmp, ndigits);

	/* s4 */
	tmp[0] = (product[4] >> 32) | (product[5] << 32);
	tmp[1] = (product[5] >> 32) | (product[6] & 0xffffffff00000000ull);
	tmp[2] = product[7];
	tmp[3] = (product[6] >> 32) | (product[4] << 32);
	carry += vli_add(result, result, tmp, ndigits);

	/* d1 */
	tmp[0] = (product[5] >> 32) | (product[6] << 32);
	tmp[1] = (product[6] >> 32);
	tmp[2] = 0;
	tmp[3] = (product[4] & 0xffffffff) | (product[5] << 32);
	carry -= vli_sub(result, result, tmp, ndigits);

	/* d2 */
	tmp[0] = product[6];
	tmp[1] = product[7];
	tmp[2] = 0;
	tmp[3] = (product[4] >> 32) | (product[5] & 0xffffffff00000000ull);
	carry -= vli_sub(result, result, tmp, ndigits);

	/* d3 */
	tmp[0] = (product[6] >> 32) | (product[7] << 32);
	tmp[1] = (product[7] >> 32) | (product[4] << 32);
	tmp[2] = (product[4] >> 32) | (product[5] << 32);
	tmp[3] = (product[6] << 32);
	carry -= vli_sub(result, result, tmp, ndigits);

	/* d4 */
	tmp[0] = product[7];
	tmp[1] = product[4] & 0xffffffff00000000ull;
	tmp[2] = product[5];
	tmp[3] = product[6] & 0xffffffff00000000ull;
	carry -= vli_sub(result, result, tmp, ndigits);

	if (carry < 0) {
		do {
			carry += vli_add(result, result, curve_prime, ndigits);
		} while (carry < 0);
	} else {
		while (carry || vli_cmp(curve_prime, result, ndigits) != 1)
			carry -= vli_sub(result, result, curve_prime, ndigits);
	}
}

/* Helpers for assembling 64-bit digits from 32-bit words of product. */
#define SL32OR32(x32, y32) (((u64)x32 << 32) | y32)
#define AND64H(x64)  (x64 & 0xffFFffFF00000000ull)
#define AND64L(x64)  (x64 & 0x00000000ffFFffFFull)

/* Computes result = product % curve_prime
 * from "Mathematical routines for the NIST prime elliptic curves"
 *
 * Same scheme as vli_mmod_fast_256: result = t + 2*s1 + s2 + s3 + s4 + s5
 * + s6 - d1 - d2 - d3 over 32-bit words a0..a23 of product (noted per line
 * as high||low), followed by a signed carry correction.
 */
static void vli_mmod_fast_384(u64 *result, const u64 *product,
				const u64 *curve_prime, u64 *tmp)
{
	int carry;
	const unsigned int ndigits = ECC_CURVE_NIST_P384_DIGITS;

	/* t */
	vli_set(result, product, ndigits);

	/* s1 */
	tmp[0] = 0;		// 0 || 0
	tmp[1] = 0;		// 0 || 0
	tmp[2] = SL32OR32(product[11], (product[10]>>32));	//a22||a21
	tmp[3] = product[11]>>32;	// 0 ||a23
	tmp[4] = 0;		// 0 || 0
	tmp[5] = 0;		// 0 || 0
	carry = vli_lshift(tmp, tmp, 1, ndigits);
	carry += vli_add(result, result, tmp, ndigits);

	/* s2 */
	tmp[0] = product[6];	//a13||a12
	tmp[1] = product[7];	//a15||a14
	tmp[2] = product[8];	//a17||a16
	tmp[3] = product[9];	//a19||a18
	tmp[4] = product[10];	//a21||a20
	tmp[5] = product[11];	//a23||a22
	carry += vli_add(result, result, tmp, ndigits);

	/* s3 */
	tmp[0] = SL32OR32(product[11], (product[10]>>32));	//a22||a21
	tmp[1] = SL32OR32(product[6], (product[11]>>32));	//a12||a23
	tmp[2] = SL32OR32(product[7], (product[6])>>32);	//a14||a13
	tmp[3] = SL32OR32(product[8], (product[7]>>32));	//a16||a15
	tmp[4] = SL32OR32(product[9], (product[8]>>32));	//a18||a17
	tmp[5] = SL32OR32(product[10], (product[9]>>32));	//a20||a19
	carry += vli_add(result, result, tmp, ndigits);

	/* s4 */
	tmp[0] = AND64H(product[11]);	//a23|| 0
	tmp[1] = (product[10]<<32);	//a20|| 0
	tmp[2] = product[6];	//a13||a12
	tmp[3] = product[7];	//a15||a14
	tmp[4] = product[8];	//a17||a16
	tmp[5] = product[9];	//a19||a18
	carry += vli_add(result, result, tmp, ndigits);

	/* s5 */
	tmp[0] = 0;		// 0|| 0
	tmp[1] = 0;		// 0|| 0
	tmp[2] = product[10];	//a21||a20
	tmp[3] = product[11];	//a23||a22
	tmp[4] = 0;		// 0|| 0
	tmp[5] = 0;		// 0|| 0
	carry += vli_add(result, result, tmp, ndigits);

	/* s6 */
	tmp[0] = AND64L(product[10]);	// 0 ||a20
	tmp[1] = AND64H(product[10]);	//a21|| 0
	tmp[2] = product[11];	//a23||a22
	tmp[3] = 0;		// 0 || 0
	tmp[4] = 0;		// 0 || 0
	tmp[5] = 0;		// 0 || 0
	carry += vli_add(result, result, tmp, ndigits);

	/* d1 */
	tmp[0] = SL32OR32(product[6], (product[11]>>32));	//a12||a23
	tmp[1] = SL32OR32(product[7], (product[6]>>32));	//a14||a13
	tmp[2] = SL32OR32(product[8], (product[7]>>32));	//a16||a15
	tmp[3] = SL32OR32(product[9], (product[8]>>32));	//a18||a17
	tmp[4] = SL32OR32(product[10], (product[9]>>32));	//a20||a19
	tmp[5] = SL32OR32(product[11], (product[10]>>32));	//a22||a21
	carry -= vli_sub(result, result, tmp, ndigits);

	/* d2 */
	tmp[0] = (product[10]<<32);	//a20|| 0
	tmp[1] = SL32OR32(product[11], (product[10]>>32));	//a22||a21
	tmp[2] = (product[11]>>32);	// 0 ||a23
	tmp[3] = 0;		// 0 || 0
	tmp[4] = 0;		// 0 || 0
	tmp[5] = 0;		// 0 || 0
	carry -= vli_sub(result, result, tmp, ndigits);

	/* d3 */
	tmp[0] = 0;		// 0 || 0
	tmp[1] = AND64H(product[11]);	//a23|| 0
	tmp[2] = product[11]>>32;	// 0 ||a23
	tmp[3] = 0;		// 0 || 0
	tmp[4] = 0;		// 0 || 0
	tmp[5] = 0;		// 0 || 0
	carry -= vli_sub(result, result, tmp, ndigits);

	if (carry < 0) {
		do {
			carry += vli_add(result, result, curve_prime, ndigits);
		} while (carry < 0);
	} else {
		while (carry || vli_cmp(curve_prime, result, ndigits) != 1)
			carry -= vli_sub(result, result, curve_prime, ndigits);
	}

}

#undef SL32OR32
#undef AND64H
#undef AND64L

/*
 * Computes result = product % curve_prime
 * from "Recommendations for Discrete Logarithm-Based Cryptography:
 *       Elliptic Curve Domain Parameters" section G.1.4
 *
 * For p = 2^521 - 1, 2^521 == 1 (mod p): the bits of product above bit
 * 520 are simply added to the low 521 bits.
 */
static void vli_mmod_fast_521(u64 *result, const u64 *product,
			      const u64 *curve_prime, u64 *tmp)
{
	const unsigned int ndigits = ECC_CURVE_NIST_P521_DIGITS;
	size_t i;

	/* Initialize result with lowest 521 bits from product */
	vli_set(result, product, ndigits);
	result[8] &= 0x1ff;

	/* tmp = product >> 521, truncated to 521 bits */
	for (i = 0; i < ndigits; i++)
		tmp[i] = (product[8 + i] >> 9) | (product[9 + i] << 55);
	tmp[8] &= 0x1ff;

	vli_mod_add(result, result, tmp, curve_prime, ndigits);
}

/* Computes result = product % curve_prime for different curve_primes.
 *
 * Note that curve_primes are distinguished just by heuristic check and
 * not by complete conformance check.
 *
 * Returns false (leaving result unwritten) only for a "nist_" curve whose
 * digit count has no fast reduction; true otherwise.
 */
static bool vli_mmod_fast(u64 *result, u64 *product,
			  const struct ecc_curve *curve)
{
	u64 tmp[2 * ECC_MAX_DIGITS];
	const u64 *curve_prime = curve->p;
	const unsigned int ndigits = curve->g.ndigits;

	/* All NIST curves have name prefix 'nist_' */
	if (strncmp(curve->name, "nist_", 5) != 0) {
		/* Try to handle Pseudo-Mersenne primes. */
		if (curve_prime[ndigits - 1] == -1ull) {
			/* p = 2^k - c */
			vli_mmod_special(result, product, curve_prime,
					 ndigits);
			return true;
		} else if (curve_prime[ndigits - 1] == 1ull << 63 &&
			   curve_prime[ndigits - 2] == 0) {
			/* p = 2^(k-1) + c */
			vli_mmod_special2(result, product, curve_prime,
					  ndigits);
			return true;
		}
		/* Generic fallback: expects mu stored after the prime digits */
		vli_mmod_barrett(result, product, curve_prime, ndigits);
		return true;
	}

	switch (ndigits) {
	case ECC_CURVE_NIST_P192_DIGITS:
		vli_mmod_fast_192(result, product, curve_prime, tmp);
		break;
	case ECC_CURVE_NIST_P256_DIGITS:
		vli_mmod_fast_256(result, product, curve_prime, tmp);
		break;
	case ECC_CURVE_NIST_P384_DIGITS:
		vli_mmod_fast_384(result, product, curve_prime, tmp);
		break;
	case ECC_CURVE_NIST_P521_DIGITS:
		vli_mmod_fast_521(result, product, curve_prime, tmp);
		break;
	default:
		pr_err_ratelimited("ecc: unsupported digits size!\n");
		return false;
	}

	return true;
}

/* Computes result = (left * right) % mod.
 * Assumes that mod is big enough curve order.
 */
void vli_mod_mult_slow(u64 *result, const u64 *left, const u64 *right,
		       const u64 *mod, unsigned int ndigits)
{
	u64 product[ECC_MAX_DIGITS * 2];

	vli_mult(product, left, right, ndigits);
	vli_mmod_slow(result, product, mod, ndigits);
}
EXPORT_SYMBOL(vli_mod_mult_slow);

/* Computes result = (left * right) % curve_prime.
 * NOTE(review): the bool returned by vli_mmod_fast() is ignored; for an
 * unsupported NIST digit count result is left unwritten.
 */
static void vli_mod_mult_fast(u64 *result, const u64 *left, const u64 *right,
			      const struct ecc_curve *curve)
{
	u64 product[2 * ECC_MAX_DIGITS];

	vli_mult(product, left, right, curve->g.ndigits);
	vli_mmod_fast(result, product, curve);
}

/* Computes result = left^2 % curve_prime.
*/
static void vli_mod_square_fast(u64 *result, const u64 *left,
				const struct ecc_curve *curve)
{
	u64 product[2 * ECC_MAX_DIGITS];

	vli_square(product, left, curve->g.ndigits);
	vli_mmod_fast(result, product, curve);
}

#define EVEN(vli) (!(vli[0] & 1))
/* Computes result = (1 / input) % mod. All VLIs are the same size.
 * See "From Euclid's GCD to Montgomery Multiplication to the Great Divide"
 * https://labs.oracle.com/techrep/2001/smli_tr-2001-95.pdf
 *
 * A zero input yields a zero result. The loop keeps the invariant
 * a == u * input and b == v * input (mod mod).
 * NOTE(review): assumes mod is odd (halving adds mod to odd values) and
 * gcd(input, mod) == 1, e.g. mod prime - confirm at call sites.
 */
void vli_mod_inv(u64 *result, const u64 *input, const u64 *mod,
			unsigned int ndigits)
{
	u64 a[ECC_MAX_DIGITS], b[ECC_MAX_DIGITS];
	u64 u[ECC_MAX_DIGITS], v[ECC_MAX_DIGITS];
	u64 carry;
	int cmp_result;

	if (vli_is_zero(input, ndigits)) {
		vli_clear(result, ndigits);
		return;
	}

	vli_set(a, input, ndigits);
	vli_set(b, mod, ndigits);
	vli_clear(u, ndigits);
	u[0] = 1;
	vli_clear(v, ndigits);

	while ((cmp_result = vli_cmp(a, b, ndigits)) != 0) {
		carry = 0;

		/* Halving u (or v) mod p: make it even by adding mod if it is
		 * odd, then shift right, restoring the carry as the top bit.
		 */
		if (EVEN(a)) {
			vli_rshift1(a, ndigits);

			if (!EVEN(u))
				carry = vli_add(u, u, mod, ndigits);

			vli_rshift1(u, ndigits);
			if (carry)
				u[ndigits - 1] |= 0x8000000000000000ull;
		} else if (EVEN(b)) {
			vli_rshift1(b, ndigits);

			if (!EVEN(v))
				carry = vli_add(v, v, mod, ndigits);

			vli_rshift1(v, ndigits);
			if (carry)
				v[ndigits - 1] |= 0x8000000000000000ull;
		} else if (cmp_result > 0) {
			/* a, b both odd and a > b: a = (a - b) / 2 */
			vli_sub(a, a, b, ndigits);
			vli_rshift1(a, ndigits);

			if (vli_cmp(u, v, ndigits) < 0)
				vli_add(u, u, mod, ndigits);

			vli_sub(u, u, v, ndigits);
			if (!EVEN(u))
				carry = vli_add(u, u, mod, ndigits);

			vli_rshift1(u, ndigits);
			if (carry)
				u[ndigits - 1] |= 0x8000000000000000ull;
		} else {
			/* a, b both odd and b > a: b = (b - a) / 2 */
			vli_sub(b, b, a, ndigits);
			vli_rshift1(b, ndigits);

			if (vli_cmp(v, u, ndigits) < 0)
				vli_add(v, v, mod, ndigits);

			vli_sub(v, v, u, ndigits);
			if (!EVEN(v))
				carry = vli_add(v, v, mod, ndigits);

			vli_rshift1(v, ndigits);
			if (carry)
				v[ndigits - 1] |= 0x8000000000000000ull;
		}
	}

	vli_set(result, u, ndigits);
}
EXPORT_SYMBOL(vli_mod_inv);

/* ------ Point operations ------ */

/* Returns true if point is the point at infinity (x == 0 and y == 0),
 * false otherwise.
 */
bool ecc_point_is_zero(const struct ecc_point *point)
{
	return (vli_is_zero(point->x, point->ndigits) &&
		vli_is_zero(point->y, point->ndigits));
}
EXPORT_SYMBOL(ecc_point_is_zero);

/* Point multiplication algorithm using Montgomery's ladder with co-Z
 * coordinates. From https://eprint.iacr.org/2011/338.pdf
 */

/* Double in place: (x1, y1, z1) => 2 * (x1, y1, z1) in Jacobian coordinates.
 * Returns without change when z1 == 0 (the point at infinity).
 */
static void ecc_point_double_jacobian(u64 *x1, u64 *y1, u64 *z1,
				      const struct ecc_curve *curve)
{
	/* t1 = x, t2 = y, t3 = z */
	u64 t4[ECC_MAX_DIGITS];
	u64 t5[ECC_MAX_DIGITS];
	const u64 *curve_prime = curve->p;
	const unsigned int ndigits = curve->g.ndigits;

	if (vli_is_zero(z1, ndigits))
		return;

	/* t4 = y1^2 */
	vli_mod_square_fast(t4, y1, curve);
	/* t5 = x1*y1^2 = A */
	vli_mod_mult_fast(t5, x1, t4, curve);
	/* t4 = y1^4 */
	vli_mod_square_fast(t4, t4, curve);
	/* t2 = y1*z1 = z3 */
	vli_mod_mult_fast(y1, y1, z1, curve);
	/* t3 = z1^2 */
	vli_mod_square_fast(z1, z1, curve);

	/* t1 = x1 + z1^2 */
	vli_mod_add(x1, x1, z1, curve_prime, ndigits);
	/* t3 = 2*z1^2 */
	vli_mod_add(z1, z1, z1, curve_prime, ndigits);
	/* t3 = x1 - z1^2 */
	vli_mod_sub(z1, x1, z1, curve_prime, ndigits);
	/* t1 = x1^2 - z1^4 */
	vli_mod_mult_fast(x1, x1, z1, curve);

	/* t3 = 2*(x1^2 - z1^4) */
	vli_mod_add(z1, x1, x1, curve_prime, ndigits);
	/* t1 = 3*(x1^2 - z1^4) */
	vli_mod_add(x1, x1, z1, curve_prime, ndigits);
	/* Halve t1 modulo p: if odd, add p first and keep its carry bit */
	if (vli_test_bit(x1, 0)) {
		u64 carry = vli_add(x1, x1, curve_prime, ndigits);

		vli_rshift1(x1, ndigits);
		x1[ndigits - 1] |= carry << 63;
	} else {
		vli_rshift1(x1, ndigits);
	}
	/* t1 = 3/2*(x1^2 - z1^4) = B */

	/* t3 = B^2 */
	vli_mod_square_fast(z1, x1, curve);
	/* t3 = B^2 - A */
	vli_mod_sub(z1, z1, t5, curve_prime, ndigits);
	/* t3 = B^2 - 2A = x3 */
	vli_mod_sub(z1, z1, t5, curve_prime, ndigits);
	/* t5 = A - x3 */
	vli_mod_sub(t5, t5, z1, curve_prime, ndigits);
	/* t1 = B * (A - x3) */
	vli_mod_mult_fast(x1, x1, t5, curve);
	/* t4 = B * (A - x3) - y1^4 = y3 */
	vli_mod_sub(t4, x1, t4, curve_prime, ndigits);

	/* Move results into place: x1 = x3, z1 = z3, y1 = y3 */
	vli_set(x1, z1, ndigits);
	vli_set(z1, y1, ndigits);
	vli_set(y1, t4, ndigits);
}

/* Modify (x1, y1) => (x1 * z^2, y1 * z^3) */
static void apply_z(u64 *x1, u64 *y1, u64 *z, const struct ecc_curve *curve)
{
	u64 t1[ECC_MAX_DIGITS];

	vli_mod_square_fast(t1, z, curve);		/* z^2 */
	vli_mod_mult_fast(x1, x1, t1, curve);	/* x1 * z^2 */
	vli_mod_mult_fast(t1, t1, z, curve);	/* z^3 */
	vli_mod_mult_fast(y1, y1, t1, curve);	/* y1 * z^3 */
}

/* P = (x1, y1) => 2P, (x2, y2) => P'
 * Both outputs share one Z coordinate. It starts as p_initial_z, or 1 when
 * p_initial_z is NULL.
 */
static void xycz_initial_double(u64 *x1, u64 *y1, u64 *x2, u64 *y2,
				u64 *p_initial_z, const struct ecc_curve *curve)
{
	u64 z[ECC_MAX_DIGITS];
	const unsigned int ndigits = curve->g.ndigits;

	vli_set(x2, x1, ndigits);
	vli_set(y2, y1, ndigits);

	vli_clear(z, ndigits);
	z[0] = 1;

	if (p_initial_z)
		vli_set(z, p_initial_z, ndigits);

	apply_z(x1, y1, z, curve);

	ecc_point_double_jacobian(x1, y1, z, curve);

	/* Rescale the copy of P to the Z produced by the doubling */
	apply_z(x2, y2, z, curve);
}

/* Input P = (x1, y1, Z), Q = (x2, y2, Z)
 * Output P' = (x1', y1', Z3), P + Q = (x3, y3, Z3)
 * or P => P', Q => P + Q
 */
static void xycz_add(u64 *x1,
u64 *y1, u64 *x2, u64 *y2, 1223 const struct ecc_curve *curve) 1224 { 1225 /* t1 = X1, t2 = Y1, t3 = X2, t4 = Y2 */ 1226 u64 t5[ECC_MAX_DIGITS]; 1227 const u64 *curve_prime = curve->p; 1228 const unsigned int ndigits = curve->g.ndigits; 1229 1230 /* t5 = x2 - x1 */ 1231 vli_mod_sub(t5, x2, x1, curve_prime, ndigits); 1232 /* t5 = (x2 - x1)^2 = A */ 1233 vli_mod_square_fast(t5, t5, curve); 1234 /* t1 = x1*A = B */ 1235 vli_mod_mult_fast(x1, x1, t5, curve); 1236 /* t3 = x2*A = C */ 1237 vli_mod_mult_fast(x2, x2, t5, curve); 1238 /* t4 = y2 - y1 */ 1239 vli_mod_sub(y2, y2, y1, curve_prime, ndigits); 1240 /* t5 = (y2 - y1)^2 = D */ 1241 vli_mod_square_fast(t5, y2, curve); 1242 1243 /* t5 = D - B */ 1244 vli_mod_sub(t5, t5, x1, curve_prime, ndigits); 1245 /* t5 = D - B - C = x3 */ 1246 vli_mod_sub(t5, t5, x2, curve_prime, ndigits); 1247 /* t3 = C - B */ 1248 vli_mod_sub(x2, x2, x1, curve_prime, ndigits); 1249 /* t2 = y1*(C - B) */ 1250 vli_mod_mult_fast(y1, y1, x2, curve); 1251 /* t3 = B - x3 */ 1252 vli_mod_sub(x2, x1, t5, curve_prime, ndigits); 1253 /* t4 = (y2 - y1)*(B - x3) */ 1254 vli_mod_mult_fast(y2, y2, x2, curve); 1255 /* t4 = y3 */ 1256 vli_mod_sub(y2, y2, y1, curve_prime, ndigits); 1257 1258 vli_set(x2, t5, ndigits); 1259 } 1260 1261 /* Input P = (x1, y1, Z), Q = (x2, y2, Z) 1262 * Output P + Q = (x3, y3, Z3), P - Q = (x3', y3', Z3) 1263 * or P => P - Q, Q => P + Q 1264 */ 1265 static void xycz_add_c(u64 *x1, u64 *y1, u64 *x2, u64 *y2, 1266 const struct ecc_curve *curve) 1267 { 1268 /* t1 = X1, t2 = Y1, t3 = X2, t4 = Y2 */ 1269 u64 t5[ECC_MAX_DIGITS]; 1270 u64 t6[ECC_MAX_DIGITS]; 1271 u64 t7[ECC_MAX_DIGITS]; 1272 const u64 *curve_prime = curve->p; 1273 const unsigned int ndigits = curve->g.ndigits; 1274 1275 /* t5 = x2 - x1 */ 1276 vli_mod_sub(t5, x2, x1, curve_prime, ndigits); 1277 /* t5 = (x2 - x1)^2 = A */ 1278 vli_mod_square_fast(t5, t5, curve); 1279 /* t1 = x1*A = B */ 1280 vli_mod_mult_fast(x1, x1, t5, curve); 1281 /* t3 = x2*A = C */ 1282 
vli_mod_mult_fast(x2, x2, t5, curve); 1283 /* t4 = y2 + y1 */ 1284 vli_mod_add(t5, y2, y1, curve_prime, ndigits); 1285 /* t4 = y2 - y1 */ 1286 vli_mod_sub(y2, y2, y1, curve_prime, ndigits); 1287 1288 /* t6 = C - B */ 1289 vli_mod_sub(t6, x2, x1, curve_prime, ndigits); 1290 /* t2 = y1 * (C - B) */ 1291 vli_mod_mult_fast(y1, y1, t6, curve); 1292 /* t6 = B + C */ 1293 vli_mod_add(t6, x1, x2, curve_prime, ndigits); 1294 /* t3 = (y2 - y1)^2 */ 1295 vli_mod_square_fast(x2, y2, curve); 1296 /* t3 = x3 */ 1297 vli_mod_sub(x2, x2, t6, curve_prime, ndigits); 1298 1299 /* t7 = B - x3 */ 1300 vli_mod_sub(t7, x1, x2, curve_prime, ndigits); 1301 /* t4 = (y2 - y1)*(B - x3) */ 1302 vli_mod_mult_fast(y2, y2, t7, curve); 1303 /* t4 = y3 */ 1304 vli_mod_sub(y2, y2, y1, curve_prime, ndigits); 1305 1306 /* t7 = (y2 + y1)^2 = F */ 1307 vli_mod_square_fast(t7, t5, curve); 1308 /* t7 = x3' */ 1309 vli_mod_sub(t7, t7, t6, curve_prime, ndigits); 1310 /* t6 = x3' - B */ 1311 vli_mod_sub(t6, t7, x1, curve_prime, ndigits); 1312 /* t6 = (y2 + y1)*(x3' - B) */ 1313 vli_mod_mult_fast(t6, t6, t5, curve); 1314 /* t2 = y3' */ 1315 vli_mod_sub(y1, t6, y1, curve_prime, ndigits); 1316 1317 vli_set(x1, t7, ndigits); 1318 } 1319 1320 static void ecc_point_mult(struct ecc_point *result, 1321 const struct ecc_point *point, const u64 *scalar, 1322 u64 *initial_z, const struct ecc_curve *curve, 1323 unsigned int ndigits) 1324 { 1325 /* R0 and R1 */ 1326 u64 rx[2][ECC_MAX_DIGITS]; 1327 u64 ry[2][ECC_MAX_DIGITS]; 1328 u64 z[ECC_MAX_DIGITS]; 1329 u64 sk[2][ECC_MAX_DIGITS]; 1330 u64 *curve_prime = curve->p; 1331 int i, nb; 1332 int num_bits; 1333 int carry; 1334 1335 carry = vli_add(sk[0], scalar, curve->n, ndigits); 1336 vli_add(sk[1], sk[0], curve->n, ndigits); 1337 scalar = sk[!carry]; 1338 if (curve->nbits == 521) /* NIST P521 */ 1339 num_bits = curve->nbits + 2; 1340 else 1341 num_bits = sizeof(u64) * ndigits * 8 + 1; 1342 1343 vli_set(rx[1], point->x, ndigits); 1344 vli_set(ry[1], point->y, ndigits); 1345 
1346 xycz_initial_double(rx[1], ry[1], rx[0], ry[0], initial_z, curve); 1347 1348 for (i = num_bits - 2; i > 0; i--) { 1349 nb = !vli_test_bit(scalar, i); 1350 xycz_add_c(rx[1 - nb], ry[1 - nb], rx[nb], ry[nb], curve); 1351 xycz_add(rx[nb], ry[nb], rx[1 - nb], ry[1 - nb], curve); 1352 } 1353 1354 nb = !vli_test_bit(scalar, 0); 1355 xycz_add_c(rx[1 - nb], ry[1 - nb], rx[nb], ry[nb], curve); 1356 1357 /* Find final 1/Z value. */ 1358 /* X1 - X0 */ 1359 vli_mod_sub(z, rx[1], rx[0], curve_prime, ndigits); 1360 /* Yb * (X1 - X0) */ 1361 vli_mod_mult_fast(z, z, ry[1 - nb], curve); 1362 /* xP * Yb * (X1 - X0) */ 1363 vli_mod_mult_fast(z, z, point->x, curve); 1364 1365 /* 1 / (xP * Yb * (X1 - X0)) */ 1366 vli_mod_inv(z, z, curve_prime, point->ndigits); 1367 1368 /* yP / (xP * Yb * (X1 - X0)) */ 1369 vli_mod_mult_fast(z, z, point->y, curve); 1370 /* Xb * yP / (xP * Yb * (X1 - X0)) */ 1371 vli_mod_mult_fast(z, z, rx[1 - nb], curve); 1372 /* End 1/Z calculation */ 1373 1374 xycz_add(rx[nb], ry[nb], rx[1 - nb], ry[1 - nb], curve); 1375 1376 apply_z(rx[0], ry[0], z, curve); 1377 1378 vli_set(result->x, rx[0], ndigits); 1379 vli_set(result->y, ry[0], ndigits); 1380 } 1381 1382 /* Computes R = P + Q mod p */ 1383 static void ecc_point_add(const struct ecc_point *result, 1384 const struct ecc_point *p, const struct ecc_point *q, 1385 const struct ecc_curve *curve) 1386 { 1387 u64 z[ECC_MAX_DIGITS]; 1388 u64 px[ECC_MAX_DIGITS]; 1389 u64 py[ECC_MAX_DIGITS]; 1390 unsigned int ndigits = curve->g.ndigits; 1391 1392 vli_set(result->x, q->x, ndigits); 1393 vli_set(result->y, q->y, ndigits); 1394 vli_mod_sub(z, result->x, p->x, curve->p, ndigits); 1395 vli_set(px, p->x, ndigits); 1396 vli_set(py, p->y, ndigits); 1397 xycz_add(px, py, result->x, result->y, curve); 1398 vli_mod_inv(z, z, curve->p, ndigits); 1399 apply_z(result->x, result->y, z, curve); 1400 } 1401 1402 /* Computes R = u1P + u2Q mod p using Shamir's trick. 1403 * Based on: Kenneth MacKay's micro-ecc (2014). 
1404 */ 1405 void ecc_point_mult_shamir(const struct ecc_point *result, 1406 const u64 *u1, const struct ecc_point *p, 1407 const u64 *u2, const struct ecc_point *q, 1408 const struct ecc_curve *curve) 1409 { 1410 u64 z[ECC_MAX_DIGITS]; 1411 u64 sump[2][ECC_MAX_DIGITS]; 1412 u64 *rx = result->x; 1413 u64 *ry = result->y; 1414 unsigned int ndigits = curve->g.ndigits; 1415 unsigned int num_bits; 1416 struct ecc_point sum = ECC_POINT_INIT(sump[0], sump[1], ndigits); 1417 const struct ecc_point *points[4]; 1418 const struct ecc_point *point; 1419 unsigned int idx; 1420 int i; 1421 1422 ecc_point_add(&sum, p, q, curve); 1423 points[0] = NULL; 1424 points[1] = p; 1425 points[2] = q; 1426 points[3] = ∑ 1427 1428 num_bits = max(vli_num_bits(u1, ndigits), vli_num_bits(u2, ndigits)); 1429 i = num_bits - 1; 1430 idx = !!vli_test_bit(u1, i); 1431 idx |= (!!vli_test_bit(u2, i)) << 1; 1432 point = points[idx]; 1433 1434 vli_set(rx, point->x, ndigits); 1435 vli_set(ry, point->y, ndigits); 1436 vli_clear(z + 1, ndigits - 1); 1437 z[0] = 1; 1438 1439 for (--i; i >= 0; i--) { 1440 ecc_point_double_jacobian(rx, ry, z, curve); 1441 idx = !!vli_test_bit(u1, i); 1442 idx |= (!!vli_test_bit(u2, i)) << 1; 1443 point = points[idx]; 1444 if (point) { 1445 u64 tx[ECC_MAX_DIGITS]; 1446 u64 ty[ECC_MAX_DIGITS]; 1447 u64 tz[ECC_MAX_DIGITS]; 1448 1449 vli_set(tx, point->x, ndigits); 1450 vli_set(ty, point->y, ndigits); 1451 apply_z(tx, ty, z, curve); 1452 vli_mod_sub(tz, rx, tx, curve->p, ndigits); 1453 xycz_add(tx, ty, rx, ry, curve); 1454 vli_mod_mult_fast(z, z, tz, curve); 1455 } 1456 } 1457 vli_mod_inv(z, z, curve->p, ndigits); 1458 apply_z(rx, ry, z, curve); 1459 } 1460 EXPORT_SYMBOL(ecc_point_mult_shamir); 1461 1462 /* 1463 * This function performs checks equivalent to Appendix A.4.2 of FIPS 186-5. 1464 * Whereas A.4.2 results in an integer in the interval [1, n-1], this function 1465 * ensures that the integer is in the range of [2, n-3]. 
We are slightly 1466 * stricter because of the currently used scalar multiplication algorithm. 1467 */ 1468 static int __ecc_is_key_valid(const struct ecc_curve *curve, 1469 const u64 *private_key, unsigned int ndigits) 1470 { 1471 u64 one[ECC_MAX_DIGITS] = { 1, }; 1472 u64 res[ECC_MAX_DIGITS]; 1473 1474 if (!private_key) 1475 return -EINVAL; 1476 1477 if (curve->g.ndigits != ndigits) 1478 return -EINVAL; 1479 1480 /* Make sure the private key is in the range [2, n-3]. */ 1481 if (vli_cmp(one, private_key, ndigits) != -1) 1482 return -EINVAL; 1483 vli_sub(res, curve->n, one, ndigits); 1484 vli_sub(res, res, one, ndigits); 1485 if (vli_cmp(res, private_key, ndigits) != 1) 1486 return -EINVAL; 1487 1488 return 0; 1489 } 1490 1491 int ecc_is_key_valid(unsigned int curve_id, unsigned int ndigits, 1492 const u64 *private_key, unsigned int private_key_len) 1493 { 1494 int nbytes; 1495 const struct ecc_curve *curve = ecc_get_curve(curve_id); 1496 1497 nbytes = ndigits << ECC_DIGITS_TO_BYTES_SHIFT; 1498 1499 if (private_key_len != nbytes) 1500 return -EINVAL; 1501 1502 return __ecc_is_key_valid(curve, private_key, ndigits); 1503 } 1504 EXPORT_SYMBOL(ecc_is_key_valid); 1505 1506 /* 1507 * ECC private keys are generated using the method of rejection sampling, 1508 * equivalent to that described in FIPS 186-5, Appendix A.2.2. 1509 * 1510 * This method generates a private key uniformly distributed in the range 1511 * [2, n-3]. 1512 */ 1513 int ecc_gen_privkey(unsigned int curve_id, unsigned int ndigits, 1514 u64 *private_key) 1515 { 1516 const struct ecc_curve *curve = ecc_get_curve(curve_id); 1517 unsigned int nbytes = ndigits << ECC_DIGITS_TO_BYTES_SHIFT; 1518 unsigned int nbits = vli_num_bits(curve->n, ndigits); 1519 int err; 1520 1521 /* 1522 * Step 1 & 2: check that N is included in Table 1 of FIPS 186-5, 1523 * section 6.1.1. 
1524 */ 1525 if (nbits < 224) 1526 return -EINVAL; 1527 1528 /* 1529 * FIPS 186-5 recommends that the private key should be obtained from a 1530 * RBG with a security strength equal to or greater than the security 1531 * strength associated with N. 1532 * 1533 * The maximum security strength identified by NIST SP800-57pt1r4 for 1534 * ECC is 256 (N >= 512). 1535 * 1536 * This condition is met by the default RNG because it selects a favored 1537 * DRBG with a security strength of 256. 1538 */ 1539 if (crypto_get_default_rng()) 1540 return -EFAULT; 1541 1542 /* Step 3: obtain N returned_bits from the DRBG. */ 1543 err = crypto_rng_get_bytes(crypto_default_rng, 1544 (u8 *)private_key, nbytes); 1545 crypto_put_default_rng(); 1546 if (err) 1547 return err; 1548 1549 /* Step 4: make sure the private key is in the valid range. */ 1550 if (__ecc_is_key_valid(curve, private_key, ndigits)) 1551 return -EINVAL; 1552 1553 return 0; 1554 } 1555 EXPORT_SYMBOL(ecc_gen_privkey); 1556 1557 int ecc_make_pub_key(unsigned int curve_id, unsigned int ndigits, 1558 const u64 *private_key, u64 *public_key) 1559 { 1560 int ret = 0; 1561 struct ecc_point *pk; 1562 const struct ecc_curve *curve = ecc_get_curve(curve_id); 1563 1564 if (!private_key) { 1565 ret = -EINVAL; 1566 goto out; 1567 } 1568 1569 pk = ecc_alloc_point(ndigits); 1570 if (!pk) { 1571 ret = -ENOMEM; 1572 goto out; 1573 } 1574 1575 ecc_point_mult(pk, &curve->g, private_key, NULL, curve, ndigits); 1576 1577 /* SP800-56A rev 3 5.6.2.1.3 key check */ 1578 if (ecc_is_pubkey_valid_full(curve, pk)) { 1579 ret = -EAGAIN; 1580 goto err_free_point; 1581 } 1582 1583 ecc_swap_digits(pk->x, public_key, ndigits); 1584 ecc_swap_digits(pk->y, &public_key[ndigits], ndigits); 1585 1586 err_free_point: 1587 ecc_free_point(pk); 1588 out: 1589 return ret; 1590 } 1591 EXPORT_SYMBOL(ecc_make_pub_key); 1592 1593 /* SP800-56A section 5.6.2.3.4 partial verification: ephemeral keys only */ 1594 int ecc_is_pubkey_valid_partial(const struct ecc_curve 
*curve, 1595 struct ecc_point *pk) 1596 { 1597 u64 yy[ECC_MAX_DIGITS], xxx[ECC_MAX_DIGITS], w[ECC_MAX_DIGITS]; 1598 1599 if (WARN_ON(pk->ndigits != curve->g.ndigits)) 1600 return -EINVAL; 1601 1602 /* Check 1: Verify key is not the zero point. */ 1603 if (ecc_point_is_zero(pk)) 1604 return -EINVAL; 1605 1606 /* Check 2: Verify key is in the range [1, p-1]. */ 1607 if (vli_cmp(curve->p, pk->x, pk->ndigits) != 1) 1608 return -EINVAL; 1609 if (vli_cmp(curve->p, pk->y, pk->ndigits) != 1) 1610 return -EINVAL; 1611 1612 /* Check 3: Verify that y^2 == (x^3 + a·x + b) mod p */ 1613 vli_mod_square_fast(yy, pk->y, curve); /* y^2 */ 1614 vli_mod_square_fast(xxx, pk->x, curve); /* x^2 */ 1615 vli_mod_mult_fast(xxx, xxx, pk->x, curve); /* x^3 */ 1616 vli_mod_mult_fast(w, curve->a, pk->x, curve); /* a·x */ 1617 vli_mod_add(w, w, curve->b, curve->p, pk->ndigits); /* a·x + b */ 1618 vli_mod_add(w, w, xxx, curve->p, pk->ndigits); /* x^3 + a·x + b */ 1619 if (vli_cmp(yy, w, pk->ndigits) != 0) /* Equation */ 1620 return -EINVAL; 1621 1622 return 0; 1623 } 1624 EXPORT_SYMBOL(ecc_is_pubkey_valid_partial); 1625 1626 /* SP800-56A section 5.6.2.3.3 full verification */ 1627 int ecc_is_pubkey_valid_full(const struct ecc_curve *curve, 1628 struct ecc_point *pk) 1629 { 1630 struct ecc_point *nQ; 1631 1632 /* Checks 1 through 3 */ 1633 int ret = ecc_is_pubkey_valid_partial(curve, pk); 1634 1635 if (ret) 1636 return ret; 1637 1638 /* Check 4: Verify that nQ is the zero point. 
*/ 1639 nQ = ecc_alloc_point(pk->ndigits); 1640 if (!nQ) 1641 return -ENOMEM; 1642 1643 ecc_point_mult(nQ, pk, curve->n, NULL, curve, pk->ndigits); 1644 if (!ecc_point_is_zero(nQ)) 1645 ret = -EINVAL; 1646 1647 ecc_free_point(nQ); 1648 1649 return ret; 1650 } 1651 EXPORT_SYMBOL(ecc_is_pubkey_valid_full); 1652 1653 int crypto_ecdh_shared_secret(unsigned int curve_id, unsigned int ndigits, 1654 const u64 *private_key, const u64 *public_key, 1655 u64 *secret) 1656 { 1657 int ret = 0; 1658 struct ecc_point *product, *pk; 1659 u64 rand_z[ECC_MAX_DIGITS]; 1660 unsigned int nbytes; 1661 const struct ecc_curve *curve = ecc_get_curve(curve_id); 1662 1663 if (!private_key || !public_key || ndigits > ARRAY_SIZE(rand_z)) { 1664 ret = -EINVAL; 1665 goto out; 1666 } 1667 1668 nbytes = ndigits << ECC_DIGITS_TO_BYTES_SHIFT; 1669 1670 get_random_bytes(rand_z, nbytes); 1671 1672 pk = ecc_alloc_point(ndigits); 1673 if (!pk) { 1674 ret = -ENOMEM; 1675 goto out; 1676 } 1677 1678 ecc_swap_digits(public_key, pk->x, ndigits); 1679 ecc_swap_digits(&public_key[ndigits], pk->y, ndigits); 1680 ret = ecc_is_pubkey_valid_partial(curve, pk); 1681 if (ret) 1682 goto err_alloc_product; 1683 1684 product = ecc_alloc_point(ndigits); 1685 if (!product) { 1686 ret = -ENOMEM; 1687 goto err_alloc_product; 1688 } 1689 1690 ecc_point_mult(product, pk, private_key, rand_z, curve, ndigits); 1691 1692 if (ecc_point_is_zero(product)) { 1693 ret = -EFAULT; 1694 goto err_validity; 1695 } 1696 1697 ecc_swap_digits(product->x, secret, ndigits); 1698 1699 err_validity: 1700 memzero_explicit(rand_z, sizeof(rand_z)); 1701 ecc_free_point(product); 1702 err_alloc_product: 1703 ecc_free_point(pk); 1704 out: 1705 return ret; 1706 } 1707 EXPORT_SYMBOL(crypto_ecdh_shared_secret); 1708 1709 MODULE_DESCRIPTION("core elliptic curve module"); 1710 MODULE_LICENSE("Dual BSD/GPL"); 1711