1 /* 2 * Copyright (c) 2013, 2014 Kenneth MacKay. All rights reserved. 3 * Copyright (c) 2019 Vitaly Chikunov <vt@altlinux.org> 4 * 5 * Redistribution and use in source and binary forms, with or without 6 * modification, are permitted provided that the following conditions are 7 * met: 8 * * Redistributions of source code must retain the above copyright 9 * notice, this list of conditions and the following disclaimer. 10 * * Redistributions in binary form must reproduce the above copyright 11 * notice, this list of conditions and the following disclaimer in the 12 * documentation and/or other materials provided with the distribution. 13 * 14 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 15 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 16 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 17 * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 18 * HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 19 * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 20 * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 21 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 22 * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 23 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 24 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
25 */ 26 27 #include <crypto/ecc_curve.h> 28 #include <linux/module.h> 29 #include <linux/random.h> 30 #include <linux/slab.h> 31 #include <linux/swab.h> 32 #include <linux/fips.h> 33 #include <crypto/ecdh.h> 34 #include <crypto/rng.h> 35 #include <crypto/internal/ecc.h> 36 #include <asm/unaligned.h> 37 #include <linux/ratelimit.h> 38 39 #include "ecc_curve_defs.h" 40 41 typedef struct { 42 u64 m_low; 43 u64 m_high; 44 } uint128_t; 45 46 /* Returns curv25519 curve param */ 47 const struct ecc_curve *ecc_get_curve25519(void) 48 { 49 return &ecc_25519; 50 } 51 EXPORT_SYMBOL(ecc_get_curve25519); 52 53 const struct ecc_curve *ecc_get_curve(unsigned int curve_id) 54 { 55 switch (curve_id) { 56 /* In FIPS mode only allow P256 and higher */ 57 case ECC_CURVE_NIST_P192: 58 return fips_enabled ? NULL : &nist_p192; 59 case ECC_CURVE_NIST_P256: 60 return &nist_p256; 61 case ECC_CURVE_NIST_P384: 62 return &nist_p384; 63 case ECC_CURVE_NIST_P521: 64 return &nist_p521; 65 default: 66 return NULL; 67 } 68 } 69 EXPORT_SYMBOL(ecc_get_curve); 70 71 static u64 *ecc_alloc_digits_space(unsigned int ndigits) 72 { 73 size_t len = ndigits * sizeof(u64); 74 75 if (!len) 76 return NULL; 77 78 return kmalloc(len, GFP_KERNEL); 79 } 80 81 static void ecc_free_digits_space(u64 *space) 82 { 83 kfree_sensitive(space); 84 } 85 86 struct ecc_point *ecc_alloc_point(unsigned int ndigits) 87 { 88 struct ecc_point *p = kmalloc(sizeof(*p), GFP_KERNEL); 89 90 if (!p) 91 return NULL; 92 93 p->x = ecc_alloc_digits_space(ndigits); 94 if (!p->x) 95 goto err_alloc_x; 96 97 p->y = ecc_alloc_digits_space(ndigits); 98 if (!p->y) 99 goto err_alloc_y; 100 101 p->ndigits = ndigits; 102 103 return p; 104 105 err_alloc_y: 106 ecc_free_digits_space(p->x); 107 err_alloc_x: 108 kfree(p); 109 return NULL; 110 } 111 EXPORT_SYMBOL(ecc_alloc_point); 112 113 void ecc_free_point(struct ecc_point *p) 114 { 115 if (!p) 116 return; 117 118 kfree_sensitive(p->x); 119 kfree_sensitive(p->y); 120 kfree_sensitive(p); 121 } 122 
EXPORT_SYMBOL(ecc_free_point); 123 124 static void vli_clear(u64 *vli, unsigned int ndigits) 125 { 126 int i; 127 128 for (i = 0; i < ndigits; i++) 129 vli[i] = 0; 130 } 131 132 /* Returns true if vli == 0, false otherwise. */ 133 bool vli_is_zero(const u64 *vli, unsigned int ndigits) 134 { 135 int i; 136 137 for (i = 0; i < ndigits; i++) { 138 if (vli[i]) 139 return false; 140 } 141 142 return true; 143 } 144 EXPORT_SYMBOL(vli_is_zero); 145 146 /* Returns nonzero if bit of vli is set. */ 147 static u64 vli_test_bit(const u64 *vli, unsigned int bit) 148 { 149 return (vli[bit / 64] & ((u64)1 << (bit % 64))); 150 } 151 152 static bool vli_is_negative(const u64 *vli, unsigned int ndigits) 153 { 154 return vli_test_bit(vli, ndigits * 64 - 1); 155 } 156 157 /* Counts the number of 64-bit "digits" in vli. */ 158 static unsigned int vli_num_digits(const u64 *vli, unsigned int ndigits) 159 { 160 int i; 161 162 /* Search from the end until we find a non-zero digit. 163 * We do it in reverse because we expect that most digits will 164 * be nonzero. 165 */ 166 for (i = ndigits - 1; i >= 0 && vli[i] == 0; i--); 167 168 return (i + 1); 169 } 170 171 /* Counts the number of bits required for vli. */ 172 unsigned int vli_num_bits(const u64 *vli, unsigned int ndigits) 173 { 174 unsigned int i, num_digits; 175 u64 digit; 176 177 num_digits = vli_num_digits(vli, ndigits); 178 if (num_digits == 0) 179 return 0; 180 181 digit = vli[num_digits - 1]; 182 for (i = 0; digit; i++) 183 digit >>= 1; 184 185 return ((num_digits - 1) * 64 + i); 186 } 187 EXPORT_SYMBOL(vli_num_bits); 188 189 /* Set dest from unaligned bit string src. 
*/ 190 void vli_from_be64(u64 *dest, const void *src, unsigned int ndigits) 191 { 192 int i; 193 const u64 *from = src; 194 195 for (i = 0; i < ndigits; i++) 196 dest[i] = get_unaligned_be64(&from[ndigits - 1 - i]); 197 } 198 EXPORT_SYMBOL(vli_from_be64); 199 200 void vli_from_le64(u64 *dest, const void *src, unsigned int ndigits) 201 { 202 int i; 203 const u64 *from = src; 204 205 for (i = 0; i < ndigits; i++) 206 dest[i] = get_unaligned_le64(&from[i]); 207 } 208 EXPORT_SYMBOL(vli_from_le64); 209 210 /* Sets dest = src. */ 211 static void vli_set(u64 *dest, const u64 *src, unsigned int ndigits) 212 { 213 int i; 214 215 for (i = 0; i < ndigits; i++) 216 dest[i] = src[i]; 217 } 218 219 /* Returns sign of left - right. */ 220 int vli_cmp(const u64 *left, const u64 *right, unsigned int ndigits) 221 { 222 int i; 223 224 for (i = ndigits - 1; i >= 0; i--) { 225 if (left[i] > right[i]) 226 return 1; 227 else if (left[i] < right[i]) 228 return -1; 229 } 230 231 return 0; 232 } 233 EXPORT_SYMBOL(vli_cmp); 234 235 /* Computes result = in << c, returning carry. Can modify in place 236 * (if result == in). 0 < shift < 64. 237 */ 238 static u64 vli_lshift(u64 *result, const u64 *in, unsigned int shift, 239 unsigned int ndigits) 240 { 241 u64 carry = 0; 242 int i; 243 244 for (i = 0; i < ndigits; i++) { 245 u64 temp = in[i]; 246 247 result[i] = (temp << shift) | carry; 248 carry = temp >> (64 - shift); 249 } 250 251 return carry; 252 } 253 254 /* Computes vli = vli >> 1. */ 255 static void vli_rshift1(u64 *vli, unsigned int ndigits) 256 { 257 u64 *end = vli; 258 u64 carry = 0; 259 260 vli += ndigits; 261 262 while (vli-- > end) { 263 u64 temp = *vli; 264 *vli = (temp >> 1) | carry; 265 carry = temp << 63; 266 } 267 } 268 269 /* Computes result = left + right, returning carry. Can modify in place. 
*/ 270 static u64 vli_add(u64 *result, const u64 *left, const u64 *right, 271 unsigned int ndigits) 272 { 273 u64 carry = 0; 274 int i; 275 276 for (i = 0; i < ndigits; i++) { 277 u64 sum; 278 279 sum = left[i] + right[i] + carry; 280 if (sum != left[i]) 281 carry = (sum < left[i]); 282 283 result[i] = sum; 284 } 285 286 return carry; 287 } 288 289 /* Computes result = left + right, returning carry. Can modify in place. */ 290 static u64 vli_uadd(u64 *result, const u64 *left, u64 right, 291 unsigned int ndigits) 292 { 293 u64 carry = right; 294 int i; 295 296 for (i = 0; i < ndigits; i++) { 297 u64 sum; 298 299 sum = left[i] + carry; 300 if (sum != left[i]) 301 carry = (sum < left[i]); 302 else 303 carry = !!carry; 304 305 result[i] = sum; 306 } 307 308 return carry; 309 } 310 311 /* Computes result = left - right, returning borrow. Can modify in place. */ 312 u64 vli_sub(u64 *result, const u64 *left, const u64 *right, 313 unsigned int ndigits) 314 { 315 u64 borrow = 0; 316 int i; 317 318 for (i = 0; i < ndigits; i++) { 319 u64 diff; 320 321 diff = left[i] - right[i] - borrow; 322 if (diff != left[i]) 323 borrow = (diff > left[i]); 324 325 result[i] = diff; 326 } 327 328 return borrow; 329 } 330 EXPORT_SYMBOL(vli_sub); 331 332 /* Computes result = left - right, returning borrow. Can modify in place. 
*/
static u64 vli_usub(u64 *result, const u64 *left, u64 right,
		    unsigned int ndigits)
{
	u64 borrow = right;
	int i;

	for (i = 0; i < ndigits; i++) {
		u64 diff;

		diff = left[i] - borrow;
		/* diff == left[i] only when borrow is 0, which then stays 0. */
		if (diff != left[i])
			borrow = (diff > left[i]);

		result[i] = diff;
	}

	return borrow;
}

/*
 * Full 64 x 64 -> 128-bit multiplication.
 * Uses the compiler's unsigned __int128 when the architecture supports it,
 * otherwise multiplies 32-bit halves and recombines them.
 */
static uint128_t mul_64_64(u64 left, u64 right)
{
	uint128_t result;
#if defined(CONFIG_ARCH_SUPPORTS_INT128)
	unsigned __int128 m = (unsigned __int128)left * right;

	result.m_low = m;
	result.m_high = m >> 64;
#else
	u64 a0 = left & 0xffffffffull;
	u64 a1 = left >> 32;
	u64 b0 = right & 0xffffffffull;
	u64 b1 = right >> 32;
	u64 m0 = a0 * b0;
	u64 m1 = a0 * b1;
	u64 m2 = a1 * b0;
	u64 m3 = a1 * b1;

	/* a1*b0 + (m0 >> 32) is at most 2^64 - 2^32, so this cannot wrap. */
	m2 += (m0 >> 32);
	m2 += m1;

	/*
	 * Overflow: adding m1 wrapped m2. m2 is weighted 2^32, so the lost
	 * 2^64 is worth 2^96, i.e. 2^32 in m3's units.
	 */
	if (m2 < m1)
		m3 += 0x100000000ull;

	result.m_low = (m0 & 0xffffffffull) | (m2 << 32);
	result.m_high = m3 + (m2 >> 32);
#endif
	return result;
}

/* Returns a + b modulo 2^128 (the carry out of the high half is dropped). */
static uint128_t add_128_128(uint128_t a, uint128_t b)
{
	uint128_t result;

	result.m_low = a.m_low + b.m_low;
	result.m_high = a.m_high + b.m_high + (result.m_low < a.m_low);

	return result;
}

/*
 * Computes result = left * right; result receives 2 * ndigits digits.
 * Output digit k is the sum of left[i] * right[k - i] over all valid i,
 * kept in a running sum of r01 (128 bits) plus r2 (overflow count of r01).
 * After each digit is emitted, the running sum shifts down one digit.
 * NOTE: result must not overlap left or right. result[k] is written while
 * later digits still read the low input digits.
 */
static void vli_mult(u64 *result, const u64 *left, const u64 *right,
		     unsigned int ndigits)
{
	uint128_t r01 = { 0, 0 };
	u64 r2 = 0;
	unsigned int i, k;

	/* Compute each digit of result in sequence, maintaining the
	 * carries.
	 */
	for (k = 0; k < ndigits * 2 - 1; k++) {
		unsigned int min;

		/* Only pairs with both indices below ndigits contribute. */
		if (k < ndigits)
			min = 0;
		else
			min = (k + 1) - ndigits;

		for (i = min; i <= k && i < ndigits; i++) {
			uint128_t product;

			product = mul_64_64(left[i], right[k - i]);

			r01 = add_128_128(r01, product);
			/* r01 wrapped past 2^128: count it in r2. */
			r2 += (r01.m_high < product.m_high);
		}

		result[k] = r01.m_low;
		r01.m_low = r01.m_high;
		r01.m_high = r2;
		r2 = 0;
	}

	result[ndigits * 2 - 1] = r01.m_low;
}

/*
 * Compute product = left * right, for a small right value.
 * right is 32 bits wide. The low ndigits + 1 digits of result hold the
 * product, and result[ndigits + 1 .. 2 * ndigits - 1] are zeroed.
 */
static void vli_umult(u64 *result, const u64 *left, u32 right,
		      unsigned int ndigits)
{
	uint128_t r01 = { 0 };
	unsigned int k;

	for (k = 0; k < ndigits; k++) {
		uint128_t product;

		product = mul_64_64(left[k], right);
		r01 = add_128_128(r01, product);
		/* no carry */
		result[k] = r01.m_low;
		r01.m_low = r01.m_high;
		r01.m_high = 0;
	}
	result[k] = r01.m_low;
	for (++k; k < ndigits * 2; k++)
		result[k] = 0;
}

/*
 * Computes result = left^2 (2 * ndigits digits), column by column like
 * vli_mult(). Each cross product left[i] * left[k - i] with i < k - i is
 * added once, doubled: its top bit goes into r2 before the 128-bit value
 * is shifted left by one.
 * NOTE: result must not overlap left (same reason as vli_mult()).
 */
static void vli_square(u64 *result, const u64 *left, unsigned int ndigits)
{
	uint128_t r01 = { 0, 0 };
	u64 r2 = 0;
	int i, k;

	for (k = 0; k < ndigits * 2 - 1; k++) {
		unsigned int min;

		if (k < ndigits)
			min = 0;
		else
			min = (k + 1) - ndigits;

		for (i = min; i <= k && i <= k - i; i++) {
			uint128_t product;

			product = mul_64_64(left[i], left[k - i]);

			if (i < k - i) {
				r2 += product.m_high >> 63;
				product.m_high = (product.m_high << 1) |
						 (product.m_low >> 63);
				product.m_low <<= 1;
			}

			r01 = add_128_128(r01, product);
			r2 += (r01.m_high < product.m_high);
		}

		result[k] = r01.m_low;
		r01.m_low = r01.m_high;
		r01.m_high = r2;
		r2 = 0;
	}

	result[ndigits * 2 - 1] = r01.m_low;
}

/* Computes result = (left + right) % mod.
 * Assumes that left < mod and right < mod; the result is then below mod.
492 */ 493 static void vli_mod_add(u64 *result, const u64 *left, const u64 *right, 494 const u64 *mod, unsigned int ndigits) 495 { 496 u64 carry; 497 498 carry = vli_add(result, left, right, ndigits); 499 500 /* result > mod (result = mod + remainder), so subtract mod to 501 * get remainder. 502 */ 503 if (carry || vli_cmp(result, mod, ndigits) >= 0) 504 vli_sub(result, result, mod, ndigits); 505 } 506 507 /* Computes result = (left - right) % mod. 508 * Assumes that left < mod and right < mod, result != mod. 509 */ 510 static void vli_mod_sub(u64 *result, const u64 *left, const u64 *right, 511 const u64 *mod, unsigned int ndigits) 512 { 513 u64 borrow = vli_sub(result, left, right, ndigits); 514 515 /* In this case, p_result == -diff == (max int) - diff. 516 * Since -x % d == d - x, we can get the correct result from 517 * result + mod (with overflow). 518 */ 519 if (borrow) 520 vli_add(result, result, mod, ndigits); 521 } 522 523 /* 524 * Computes result = product % mod 525 * for special form moduli: p = 2^k-c, for small c (note the minus sign) 526 * 527 * References: 528 * R. Crandall, C. Pomerance. Prime Numbers: A Computational Perspective. 529 * 9 Fast Algorithms for Large-Integer Arithmetic. 9.2.3 Moduli of special form 530 * Algorithm 9.2.13 (Fast mod operation for special-form moduli). 
531 */ 532 static void vli_mmod_special(u64 *result, const u64 *product, 533 const u64 *mod, unsigned int ndigits) 534 { 535 u64 c = -mod[0]; 536 u64 t[ECC_MAX_DIGITS * 2]; 537 u64 r[ECC_MAX_DIGITS * 2]; 538 539 vli_set(r, product, ndigits * 2); 540 while (!vli_is_zero(r + ndigits, ndigits)) { 541 vli_umult(t, r + ndigits, c, ndigits); 542 vli_clear(r + ndigits, ndigits); 543 vli_add(r, r, t, ndigits * 2); 544 } 545 vli_set(t, mod, ndigits); 546 vli_clear(t + ndigits, ndigits); 547 while (vli_cmp(r, t, ndigits * 2) >= 0) 548 vli_sub(r, r, t, ndigits * 2); 549 vli_set(result, r, ndigits); 550 } 551 552 /* 553 * Computes result = product % mod 554 * for special form moduli: p = 2^{k-1}+c, for small c (note the plus sign) 555 * where k-1 does not fit into qword boundary by -1 bit (such as 255). 556 557 * References (loosely based on): 558 * A. Menezes, P. van Oorschot, S. Vanstone. Handbook of Applied Cryptography. 559 * 14.3.4 Reduction methods for moduli of special form. Algorithm 14.47. 560 * URL: http://cacr.uwaterloo.ca/hac/about/chap14.pdf 561 * 562 * H. Cohen, G. Frey, R. Avanzi, C. Doche, T. Lange, K. Nguyen, F. Vercauteren. 563 * Handbook of Elliptic and Hyperelliptic Curve Cryptography. 
564 * Algorithm 10.25 Fast reduction for special form moduli 565 */ 566 static void vli_mmod_special2(u64 *result, const u64 *product, 567 const u64 *mod, unsigned int ndigits) 568 { 569 u64 c2 = mod[0] * 2; 570 u64 q[ECC_MAX_DIGITS]; 571 u64 r[ECC_MAX_DIGITS * 2]; 572 u64 m[ECC_MAX_DIGITS * 2]; /* expanded mod */ 573 int carry; /* last bit that doesn't fit into q */ 574 int i; 575 576 vli_set(m, mod, ndigits); 577 vli_clear(m + ndigits, ndigits); 578 579 vli_set(r, product, ndigits); 580 /* q and carry are top bits */ 581 vli_set(q, product + ndigits, ndigits); 582 vli_clear(r + ndigits, ndigits); 583 carry = vli_is_negative(r, ndigits); 584 if (carry) 585 r[ndigits - 1] &= (1ull << 63) - 1; 586 for (i = 1; carry || !vli_is_zero(q, ndigits); i++) { 587 u64 qc[ECC_MAX_DIGITS * 2]; 588 589 vli_umult(qc, q, c2, ndigits); 590 if (carry) 591 vli_uadd(qc, qc, mod[0], ndigits * 2); 592 vli_set(q, qc + ndigits, ndigits); 593 vli_clear(qc + ndigits, ndigits); 594 carry = vli_is_negative(qc, ndigits); 595 if (carry) 596 qc[ndigits - 1] &= (1ull << 63) - 1; 597 if (i & 1) 598 vli_sub(r, r, qc, ndigits * 2); 599 else 600 vli_add(r, r, qc, ndigits * 2); 601 } 602 while (vli_is_negative(r, ndigits * 2)) 603 vli_add(r, r, m, ndigits * 2); 604 while (vli_cmp(r, m, ndigits * 2) >= 0) 605 vli_sub(r, r, m, ndigits * 2); 606 607 vli_set(result, r, ndigits); 608 } 609 610 /* 611 * Computes result = product % mod, where product is 2N words long. 612 * Reference: Ken MacKay's micro-ecc. 613 * Currently only designed to work for curve_p or curve_n. 614 */ 615 static void vli_mmod_slow(u64 *result, u64 *product, const u64 *mod, 616 unsigned int ndigits) 617 { 618 u64 mod_m[2 * ECC_MAX_DIGITS]; 619 u64 tmp[2 * ECC_MAX_DIGITS]; 620 u64 *v[2] = { tmp, product }; 621 u64 carry = 0; 622 unsigned int i; 623 /* Shift mod so its highest set bit is at the maximum position. 
*/ 624 int shift = (ndigits * 2 * 64) - vli_num_bits(mod, ndigits); 625 int word_shift = shift / 64; 626 int bit_shift = shift % 64; 627 628 vli_clear(mod_m, word_shift); 629 if (bit_shift > 0) { 630 for (i = 0; i < ndigits; ++i) { 631 mod_m[word_shift + i] = (mod[i] << bit_shift) | carry; 632 carry = mod[i] >> (64 - bit_shift); 633 } 634 } else 635 vli_set(mod_m + word_shift, mod, ndigits); 636 637 for (i = 1; shift >= 0; --shift) { 638 u64 borrow = 0; 639 unsigned int j; 640 641 for (j = 0; j < ndigits * 2; ++j) { 642 u64 diff = v[i][j] - mod_m[j] - borrow; 643 644 if (diff != v[i][j]) 645 borrow = (diff > v[i][j]); 646 v[1 - i][j] = diff; 647 } 648 i = !(i ^ borrow); /* Swap the index if there was no borrow */ 649 vli_rshift1(mod_m, ndigits); 650 mod_m[ndigits - 1] |= mod_m[ndigits] << (64 - 1); 651 vli_rshift1(mod_m + ndigits, ndigits); 652 } 653 vli_set(result, v[i], ndigits); 654 } 655 656 /* Computes result = product % mod using Barrett's reduction with precomputed 657 * value mu appended to the mod after ndigits, mu = (2^{2w} / mod) and have 658 * length ndigits + 1, where mu * (2^w - 1) should not overflow ndigits 659 * boundary. 660 * 661 * Reference: 662 * R. Brent, P. Zimmermann. Modern Computer Arithmetic. 2010. 663 * 2.4.1 Barrett's algorithm. Algorithm 2.5. 
664 */ 665 static void vli_mmod_barrett(u64 *result, u64 *product, const u64 *mod, 666 unsigned int ndigits) 667 { 668 u64 q[ECC_MAX_DIGITS * 2]; 669 u64 r[ECC_MAX_DIGITS * 2]; 670 const u64 *mu = mod + ndigits; 671 672 vli_mult(q, product + ndigits, mu, ndigits); 673 if (mu[ndigits]) 674 vli_add(q + ndigits, q + ndigits, product + ndigits, ndigits); 675 vli_mult(r, mod, q + ndigits, ndigits); 676 vli_sub(r, product, r, ndigits * 2); 677 while (!vli_is_zero(r + ndigits, ndigits) || 678 vli_cmp(r, mod, ndigits) != -1) { 679 u64 carry; 680 681 carry = vli_sub(r, r, mod, ndigits); 682 vli_usub(r + ndigits, r + ndigits, carry, ndigits); 683 } 684 vli_set(result, r, ndigits); 685 } 686 687 /* Computes p_result = p_product % curve_p. 688 * See algorithm 5 and 6 from 689 * http://www.isys.uni-klu.ac.at/PDF/2001-0126-MT.pdf 690 */ 691 static void vli_mmod_fast_192(u64 *result, const u64 *product, 692 const u64 *curve_prime, u64 *tmp) 693 { 694 const unsigned int ndigits = ECC_CURVE_NIST_P192_DIGITS; 695 int carry; 696 697 vli_set(result, product, ndigits); 698 699 vli_set(tmp, &product[3], ndigits); 700 carry = vli_add(result, result, tmp, ndigits); 701 702 tmp[0] = 0; 703 tmp[1] = product[3]; 704 tmp[2] = product[4]; 705 carry += vli_add(result, result, tmp, ndigits); 706 707 tmp[0] = tmp[1] = product[5]; 708 tmp[2] = 0; 709 carry += vli_add(result, result, tmp, ndigits); 710 711 while (carry || vli_cmp(curve_prime, result, ndigits) != 1) 712 carry -= vli_sub(result, result, curve_prime, ndigits); 713 } 714 715 /* Computes result = product % curve_prime 716 * from http://www.nsa.gov/ia/_files/nist-routines.pdf 717 */ 718 static void vli_mmod_fast_256(u64 *result, const u64 *product, 719 const u64 *curve_prime, u64 *tmp) 720 { 721 int carry; 722 const unsigned int ndigits = ECC_CURVE_NIST_P256_DIGITS; 723 724 /* t */ 725 vli_set(result, product, ndigits); 726 727 /* s1 */ 728 tmp[0] = 0; 729 tmp[1] = product[5] & 0xffffffff00000000ull; 730 tmp[2] = product[6]; 731 tmp[3] = 
product[7]; 732 carry = vli_lshift(tmp, tmp, 1, ndigits); 733 carry += vli_add(result, result, tmp, ndigits); 734 735 /* s2 */ 736 tmp[1] = product[6] << 32; 737 tmp[2] = (product[6] >> 32) | (product[7] << 32); 738 tmp[3] = product[7] >> 32; 739 carry += vli_lshift(tmp, tmp, 1, ndigits); 740 carry += vli_add(result, result, tmp, ndigits); 741 742 /* s3 */ 743 tmp[0] = product[4]; 744 tmp[1] = product[5] & 0xffffffff; 745 tmp[2] = 0; 746 tmp[3] = product[7]; 747 carry += vli_add(result, result, tmp, ndigits); 748 749 /* s4 */ 750 tmp[0] = (product[4] >> 32) | (product[5] << 32); 751 tmp[1] = (product[5] >> 32) | (product[6] & 0xffffffff00000000ull); 752 tmp[2] = product[7]; 753 tmp[3] = (product[6] >> 32) | (product[4] << 32); 754 carry += vli_add(result, result, tmp, ndigits); 755 756 /* d1 */ 757 tmp[0] = (product[5] >> 32) | (product[6] << 32); 758 tmp[1] = (product[6] >> 32); 759 tmp[2] = 0; 760 tmp[3] = (product[4] & 0xffffffff) | (product[5] << 32); 761 carry -= vli_sub(result, result, tmp, ndigits); 762 763 /* d2 */ 764 tmp[0] = product[6]; 765 tmp[1] = product[7]; 766 tmp[2] = 0; 767 tmp[3] = (product[4] >> 32) | (product[5] & 0xffffffff00000000ull); 768 carry -= vli_sub(result, result, tmp, ndigits); 769 770 /* d3 */ 771 tmp[0] = (product[6] >> 32) | (product[7] << 32); 772 tmp[1] = (product[7] >> 32) | (product[4] << 32); 773 tmp[2] = (product[4] >> 32) | (product[5] << 32); 774 tmp[3] = (product[6] << 32); 775 carry -= vli_sub(result, result, tmp, ndigits); 776 777 /* d4 */ 778 tmp[0] = product[7]; 779 tmp[1] = product[4] & 0xffffffff00000000ull; 780 tmp[2] = product[5]; 781 tmp[3] = product[6] & 0xffffffff00000000ull; 782 carry -= vli_sub(result, result, tmp, ndigits); 783 784 if (carry < 0) { 785 do { 786 carry += vli_add(result, result, curve_prime, ndigits); 787 } while (carry < 0); 788 } else { 789 while (carry || vli_cmp(curve_prime, result, ndigits) != 1) 790 carry -= vli_sub(result, result, curve_prime, ndigits); 791 } 792 } 793 794 #define 
SL32OR32(x32, y32) (((u64)x32 << 32) | y32) 795 #define AND64H(x64) (x64 & 0xffFFffFF00000000ull) 796 #define AND64L(x64) (x64 & 0x00000000ffFFffFFull) 797 798 /* Computes result = product % curve_prime 799 * from "Mathematical routines for the NIST prime elliptic curves" 800 */ 801 static void vli_mmod_fast_384(u64 *result, const u64 *product, 802 const u64 *curve_prime, u64 *tmp) 803 { 804 int carry; 805 const unsigned int ndigits = ECC_CURVE_NIST_P384_DIGITS; 806 807 /* t */ 808 vli_set(result, product, ndigits); 809 810 /* s1 */ 811 tmp[0] = 0; // 0 || 0 812 tmp[1] = 0; // 0 || 0 813 tmp[2] = SL32OR32(product[11], (product[10]>>32)); //a22||a21 814 tmp[3] = product[11]>>32; // 0 ||a23 815 tmp[4] = 0; // 0 || 0 816 tmp[5] = 0; // 0 || 0 817 carry = vli_lshift(tmp, tmp, 1, ndigits); 818 carry += vli_add(result, result, tmp, ndigits); 819 820 /* s2 */ 821 tmp[0] = product[6]; //a13||a12 822 tmp[1] = product[7]; //a15||a14 823 tmp[2] = product[8]; //a17||a16 824 tmp[3] = product[9]; //a19||a18 825 tmp[4] = product[10]; //a21||a20 826 tmp[5] = product[11]; //a23||a22 827 carry += vli_add(result, result, tmp, ndigits); 828 829 /* s3 */ 830 tmp[0] = SL32OR32(product[11], (product[10]>>32)); //a22||a21 831 tmp[1] = SL32OR32(product[6], (product[11]>>32)); //a12||a23 832 tmp[2] = SL32OR32(product[7], (product[6])>>32); //a14||a13 833 tmp[3] = SL32OR32(product[8], (product[7]>>32)); //a16||a15 834 tmp[4] = SL32OR32(product[9], (product[8]>>32)); //a18||a17 835 tmp[5] = SL32OR32(product[10], (product[9]>>32)); //a20||a19 836 carry += vli_add(result, result, tmp, ndigits); 837 838 /* s4 */ 839 tmp[0] = AND64H(product[11]); //a23|| 0 840 tmp[1] = (product[10]<<32); //a20|| 0 841 tmp[2] = product[6]; //a13||a12 842 tmp[3] = product[7]; //a15||a14 843 tmp[4] = product[8]; //a17||a16 844 tmp[5] = product[9]; //a19||a18 845 carry += vli_add(result, result, tmp, ndigits); 846 847 /* s5 */ 848 tmp[0] = 0; // 0|| 0 849 tmp[1] = 0; // 0|| 0 850 tmp[2] = product[10]; //a21||a20 851 
tmp[3] = product[11]; //a23||a22 852 tmp[4] = 0; // 0|| 0 853 tmp[5] = 0; // 0|| 0 854 carry += vli_add(result, result, tmp, ndigits); 855 856 /* s6 */ 857 tmp[0] = AND64L(product[10]); // 0 ||a20 858 tmp[1] = AND64H(product[10]); //a21|| 0 859 tmp[2] = product[11]; //a23||a22 860 tmp[3] = 0; // 0 || 0 861 tmp[4] = 0; // 0 || 0 862 tmp[5] = 0; // 0 || 0 863 carry += vli_add(result, result, tmp, ndigits); 864 865 /* d1 */ 866 tmp[0] = SL32OR32(product[6], (product[11]>>32)); //a12||a23 867 tmp[1] = SL32OR32(product[7], (product[6]>>32)); //a14||a13 868 tmp[2] = SL32OR32(product[8], (product[7]>>32)); //a16||a15 869 tmp[3] = SL32OR32(product[9], (product[8]>>32)); //a18||a17 870 tmp[4] = SL32OR32(product[10], (product[9]>>32)); //a20||a19 871 tmp[5] = SL32OR32(product[11], (product[10]>>32)); //a22||a21 872 carry -= vli_sub(result, result, tmp, ndigits); 873 874 /* d2 */ 875 tmp[0] = (product[10]<<32); //a20|| 0 876 tmp[1] = SL32OR32(product[11], (product[10]>>32)); //a22||a21 877 tmp[2] = (product[11]>>32); // 0 ||a23 878 tmp[3] = 0; // 0 || 0 879 tmp[4] = 0; // 0 || 0 880 tmp[5] = 0; // 0 || 0 881 carry -= vli_sub(result, result, tmp, ndigits); 882 883 /* d3 */ 884 tmp[0] = 0; // 0 || 0 885 tmp[1] = AND64H(product[11]); //a23|| 0 886 tmp[2] = product[11]>>32; // 0 ||a23 887 tmp[3] = 0; // 0 || 0 888 tmp[4] = 0; // 0 || 0 889 tmp[5] = 0; // 0 || 0 890 carry -= vli_sub(result, result, tmp, ndigits); 891 892 if (carry < 0) { 893 do { 894 carry += vli_add(result, result, curve_prime, ndigits); 895 } while (carry < 0); 896 } else { 897 while (carry || vli_cmp(curve_prime, result, ndigits) != 1) 898 carry -= vli_sub(result, result, curve_prime, ndigits); 899 } 900 901 } 902 903 #undef SL32OR32 904 #undef AND64H 905 #undef AND64L 906 907 /* 908 * Computes result = product % curve_prime 909 * from "Recommendations for Discrete Logarithm-Based Cryptography: 910 * Elliptic Curve Domain Parameters" section G.1.4 911 */ 912 static void vli_mmod_fast_521(u64 *result, const u64 
*product, 913 const u64 *curve_prime, u64 *tmp) 914 { 915 const unsigned int ndigits = ECC_CURVE_NIST_P521_DIGITS; 916 size_t i; 917 918 /* Initialize result with lowest 521 bits from product */ 919 vli_set(result, product, ndigits); 920 result[8] &= 0x1ff; 921 922 for (i = 0; i < ndigits; i++) 923 tmp[i] = (product[8 + i] >> 9) | (product[9 + i] << 55); 924 tmp[8] &= 0x1ff; 925 926 vli_mod_add(result, result, tmp, curve_prime, ndigits); 927 } 928 929 /* Computes result = product % curve_prime for different curve_primes. 930 * 931 * Note that curve_primes are distinguished just by heuristic check and 932 * not by complete conformance check. 933 */ 934 static bool vli_mmod_fast(u64 *result, u64 *product, 935 const struct ecc_curve *curve) 936 { 937 u64 tmp[2 * ECC_MAX_DIGITS]; 938 const u64 *curve_prime = curve->p; 939 const unsigned int ndigits = curve->g.ndigits; 940 941 /* All NIST curves have name prefix 'nist_' */ 942 if (strncmp(curve->name, "nist_", 5) != 0) { 943 /* Try to handle Pseudo-Marsenne primes. 
*/ 944 if (curve_prime[ndigits - 1] == -1ull) { 945 vli_mmod_special(result, product, curve_prime, 946 ndigits); 947 return true; 948 } else if (curve_prime[ndigits - 1] == 1ull << 63 && 949 curve_prime[ndigits - 2] == 0) { 950 vli_mmod_special2(result, product, curve_prime, 951 ndigits); 952 return true; 953 } 954 vli_mmod_barrett(result, product, curve_prime, ndigits); 955 return true; 956 } 957 958 switch (ndigits) { 959 case ECC_CURVE_NIST_P192_DIGITS: 960 vli_mmod_fast_192(result, product, curve_prime, tmp); 961 break; 962 case ECC_CURVE_NIST_P256_DIGITS: 963 vli_mmod_fast_256(result, product, curve_prime, tmp); 964 break; 965 case ECC_CURVE_NIST_P384_DIGITS: 966 vli_mmod_fast_384(result, product, curve_prime, tmp); 967 break; 968 case ECC_CURVE_NIST_P521_DIGITS: 969 vli_mmod_fast_521(result, product, curve_prime, tmp); 970 break; 971 default: 972 pr_err_ratelimited("ecc: unsupported digits size!\n"); 973 return false; 974 } 975 976 return true; 977 } 978 979 /* Computes result = (left * right) % mod. 980 * Assumes that mod is big enough curve order. 981 */ 982 void vli_mod_mult_slow(u64 *result, const u64 *left, const u64 *right, 983 const u64 *mod, unsigned int ndigits) 984 { 985 u64 product[ECC_MAX_DIGITS * 2]; 986 987 vli_mult(product, left, right, ndigits); 988 vli_mmod_slow(result, product, mod, ndigits); 989 } 990 EXPORT_SYMBOL(vli_mod_mult_slow); 991 992 /* Computes result = (left * right) % curve_prime. */ 993 static void vli_mod_mult_fast(u64 *result, const u64 *left, const u64 *right, 994 const struct ecc_curve *curve) 995 { 996 u64 product[2 * ECC_MAX_DIGITS]; 997 998 vli_mult(product, left, right, curve->g.ndigits); 999 vli_mmod_fast(result, product, curve); 1000 } 1001 1002 /* Computes result = left^2 % curve_prime. 
*/
static void vli_mod_square_fast(u64 *result, const u64 *left,
				const struct ecc_curve *curve)
{
	u64 product[2 * ECC_MAX_DIGITS];

	vli_square(product, left, curve->g.ndigits);
	vli_mmod_fast(result, product, curve);
}

#define EVEN(vli) (!(vli[0] & 1))
/* Computes result = (1 / input) % mod. All VLIs are the same size.
 * See "From Euclid's GCD to Montgomery Multiplication to the Great Divide"
 * https://labs.oracle.com/techrep/2001/smli_tr-2001-95.pdf
 *
 * A zero input yields a zero result. The halving steps make u/v even by
 * adding mod, which only works for an odd mod (presumably a prime curve
 * p or n; confirm for other callers).
 */
void vli_mod_inv(u64 *result, const u64 *input, const u64 *mod,
			unsigned int ndigits)
{
	u64 a[ECC_MAX_DIGITS], b[ECC_MAX_DIGITS];
	u64 u[ECC_MAX_DIGITS], v[ECC_MAX_DIGITS];
	u64 carry;
	int cmp_result;

	if (vli_is_zero(input, ndigits)) {
		vli_clear(result, ndigits);
		return;
	}

	vli_set(a, input, ndigits);
	vli_set(b, mod, ndigits);
	vli_clear(u, ndigits);
	u[0] = 1;
	vli_clear(v, ndigits);

	/* Binary GCD on (a, b); u and v track the matching coefficients mod mod. */
	while ((cmp_result = vli_cmp(a, b, ndigits)) != 0) {
		carry = 0;

		if (EVEN(a)) {
			vli_rshift1(a, ndigits);

			if (!EVEN(u))
				carry = vli_add(u, u, mod, ndigits);

			/* Halve u, restoring the bit carried out of u + mod. */
			vli_rshift1(u, ndigits);
			if (carry)
				u[ndigits - 1] |= 0x8000000000000000ull;
		} else if (EVEN(b)) {
			vli_rshift1(b, ndigits);

			if (!EVEN(v))
				carry = vli_add(v, v, mod, ndigits);

			vli_rshift1(v, ndigits);
			if (carry)
				v[ndigits - 1] |= 0x8000000000000000ull;
		} else if (cmp_result > 0) {
			vli_sub(a, a, b, ndigits);
			vli_rshift1(a, ndigits);

			if (vli_cmp(u, v, ndigits) < 0)
				vli_add(u, u, mod, ndigits);

			vli_sub(u, u, v, ndigits);
			if (!EVEN(u))
				carry = vli_add(u, u, mod, ndigits);

			vli_rshift1(u, ndigits);
			if (carry)
				u[ndigits - 1] |= 0x8000000000000000ull;
		} else {
			vli_sub(b, b, a, ndigits);
			vli_rshift1(b, ndigits);

			if (vli_cmp(v, u, ndigits) < 0)
				vli_add(v, v, mod, ndigits);

			vli_sub(v, v, u, ndigits);
			if (!EVEN(v))
				carry = vli_add(v, v, mod, ndigits);

			vli_rshift1(v, ndigits);
			if (carry)
				v[ndigits - 1] |= 0x8000000000000000ull;
		}
	}

	vli_set(result, u, ndigits);
}
EXPORT_SYMBOL(vli_mod_inv);

/* ------ Point operations ------ */

/* Returns true if point is the point at infinity (x == 0 and y == 0). */
bool ecc_point_is_zero(const struct ecc_point *point)
{
	return (vli_is_zero(point->x, point->ndigits) &&
		vli_is_zero(point->y, point->ndigits));
}
EXPORT_SYMBOL(ecc_point_is_zero);

/* Point multiplication algorithm using Montgomery's ladder with co-Z
 * coordinates. From https://eprint.iacr.org/2011/338.pdf
 */

/* Double in place.
 * (x1, y1, z1) is a Jacobian point and is overwritten with its double.
 * Returns early, leaving it unchanged, when z1 == 0.
 * The B term is 3/2 * (x1^2 - z1^4), i.e. the slope numerator
 * 3*x^2 + a*z^4 with a = -3, so this formula assumes a curve with a = -3.
 */
static void ecc_point_double_jacobian(u64 *x1, u64 *y1, u64 *z1,
				      const struct ecc_curve *curve)
{
	/* t1 = x, t2 = y, t3 = z */
	u64 t4[ECC_MAX_DIGITS];
	u64 t5[ECC_MAX_DIGITS];
	const u64 *curve_prime = curve->p;
	const unsigned int ndigits = curve->g.ndigits;

	if (vli_is_zero(z1, ndigits))
		return;

	/* t4 = y1^2 */
	vli_mod_square_fast(t4, y1, curve);
	/* t5 = x1*y1^2 = A */
	vli_mod_mult_fast(t5, x1, t4, curve);
	/* t4 = y1^4 */
	vli_mod_square_fast(t4, t4, curve);
	/* t2 = y1*z1 = z3 */
	vli_mod_mult_fast(y1, y1, z1, curve);
	/* t3 = z1^2 */
	vli_mod_square_fast(z1, z1, curve);

	/* t1 = x1 + z1^2 */
	vli_mod_add(x1, x1, z1, curve_prime, ndigits);
	/* t3 = 2*z1^2 */
	vli_mod_add(z1, z1, z1, curve_prime, ndigits);
	/* t3 = x1 - z1^2 */
	vli_mod_sub(z1, x1, z1, curve_prime, ndigits);
	/* t1 = x1^2 - z1^4 */
	vli_mod_mult_fast(x1, x1, z1, curve);

	/* t3 = 2*(x1^2 - z1^4) */
	vli_mod_add(z1, x1, x1, curve_prime, ndigits);
	/* t1 = 3*(x1^2 - z1^4) */
	vli_mod_add(x1, x1, z1, curve_prime, ndigits);
	/* Halve modulo p: for an odd value add p first, keeping its carry bit. */
	if (vli_test_bit(x1, 0)) {
		u64 carry = vli_add(x1, x1, curve_prime, ndigits);

		vli_rshift1(x1, ndigits);
		x1[ndigits - 1] |= carry << 63;
	} else {
		vli_rshift1(x1, ndigits);
	}
	/* t1 = 3/2*(x1^2 - z1^4) = B */

	/* t3 = B^2 */
	vli_mod_square_fast(z1, x1, curve);
	/* t3 = B^2 - A */
	vli_mod_sub(z1, z1, t5, curve_prime, ndigits);
	/* t3 = B^2 - 2A = x3 */
	vli_mod_sub(z1, z1, t5, curve_prime, ndigits);
	/* t5 = A - x3 */
	vli_mod_sub(t5, t5, z1, curve_prime, ndigits);
	/* t1 = B * (A - x3) */
	vli_mod_mult_fast(x1, x1, t5, curve);
	/* t4 = B * (A - x3) - y1^4 = y3 */
	vli_mod_sub(t4, x1, t4, curve_prime, ndigits);

	/* Move (t3, t2, t4) = (x3, z3, y3) into the output coordinates. */
	vli_set(x1, z1, ndigits);
	vli_set(z1, y1, ndigits);
	vli_set(y1, t4, ndigits);
}

/* Modify (x1, y1) => (x1 * z^2, y1 * z^3) */
static void apply_z(u64 *x1, u64 *y1, u64 *z, const struct ecc_curve *curve)
{
	u64 t1[ECC_MAX_DIGITS];

	vli_mod_square_fast(t1, z, curve);    /* z^2 */
	vli_mod_mult_fast(x1, x1, t1, curve); /* x1 * z^2 */
	vli_mod_mult_fast(t1, t1, z, curve);  /* z^3 */
	vli_mod_mult_fast(y1, y1, t1, curve); /* y1 * z^3 */
}

/* P = (x1, y1) => 2P, (x2, y2) => P'
 * P is first scaled by z (p_initial_z if non-NULL, else 1) and doubled.
 * A copy of the original P is then scaled by the doubled point's Z, so both
 * outputs share the same Z coordinate.
 */
static void xycz_initial_double(u64 *x1, u64 *y1, u64 *x2, u64 *y2,
				u64 *p_initial_z, const struct ecc_curve *curve)
{
	u64 z[ECC_MAX_DIGITS];
	const unsigned int ndigits = curve->g.ndigits;

	vli_set(x2, x1, ndigits);
	vli_set(y2, y1, ndigits);

	vli_clear(z, ndigits);
	z[0] = 1;

	if (p_initial_z)
		vli_set(z, p_initial_z, ndigits);

	apply_z(x1, y1, z, curve);

	ecc_point_double_jacobian(x1, y1, z, curve);

	apply_z(x2, y2, z, curve);
}

/* Input P = (x1, y1, Z), Q = (x2, y2, Z)
 * Output P' = (x1', y1', Z3), P + Q = (x3, y3, Z3)
 * or P => P', Q => P + Q
 * All four coordinate buffers are overwritten in place.
 */
static void xycz_add(u64 *x1, u64 *y1, u64 *x2, u64 *y2,
		     const struct ecc_curve *curve)
{
	/* t1 = X1, t2 = Y1, t3 = X2, t4 = Y2 */
	u64 t5[ECC_MAX_DIGITS];
	const u64 *curve_prime = curve->p;
	const unsigned int ndigits = curve->g.ndigits;

	/* t5 = x2 - x1 */
	vli_mod_sub(t5, x2, x1, curve_prime, ndigits);
	/* t5 = (x2 - x1)^2 = A */
	vli_mod_square_fast(t5, t5, curve);
	/* t1 = x1*A = B */
	vli_mod_mult_fast(x1, x1, t5, curve);
	/* t3 = x2*A = C */
	vli_mod_mult_fast(x2, x2, t5, curve);
	/* t4 = y2 - y1 */
	vli_mod_sub(y2, y2, y1, curve_prime, ndigits);
	/* t5 = (y2 - y1)^2 = D */
	vli_mod_square_fast(t5, y2, curve);

	/* t5 = D - B */
	vli_mod_sub(t5, t5, x1, curve_prime, ndigits);
	/* t5 = D - B - C = x3 */
	vli_mod_sub(t5, t5, x2, curve_prime, ndigits);
	/* t3 = C - B */
	vli_mod_sub(x2, x2, x1, curve_prime, ndigits);
	/* t2 = y1*(C - B) */
	vli_mod_mult_fast(y1, y1, x2, curve);
	/* t3 = B - x3 */
	vli_mod_sub(x2, x1, t5, curve_prime, ndigits);
	/* t4 = (y2 - y1)*(B - x3) */
	vli_mod_mult_fast(y2, y2, x2, curve);
	/* t4 = y3 */
	vli_mod_sub(y2, y2, y1, curve_prime, ndigits);

	vli_set(x2, t5, ndigits);
}

/* Input P = (x1, y1, Z), Q = (x2, y2, Z)
 * Output P + Q = (x3, y3, Z3), P - Q = (x3', y3', Z3)
 * or P => P - Q, Q => P + Q
 */
static void xycz_add_c(u64 *x1, u64 *y1, u64 *x2, u64 *y2,
		       const struct ecc_curve *curve)
{
	/* t1 = X1, t2 = Y1, t3 = X2, t4 = Y2 */
	u64 t5[ECC_MAX_DIGITS];
	u64 t6[ECC_MAX_DIGITS];
	u64 t7[ECC_MAX_DIGITS];
	const u64 *curve_prime = curve->p;
	const unsigned int ndigits = curve->g.ndigits;

	/* t5 = x2 - x1 */
	vli_mod_sub(t5, x2, x1, curve_prime, ndigits);
	/* t5 = (x2 - x1)^2 = A */
	vli_mod_square_fast(t5, t5, curve);
	/* t1 = x1*A = B */
	vli_mod_mult_fast(x1, x1, t5, curve);
	/* t3 = x2*A = C */
vli_mod_mult_fast(x2, x2, t5, curve); 1270 /* t4 = y2 + y1 */ 1271 vli_mod_add(t5, y2, y1, curve_prime, ndigits); 1272 /* t4 = y2 - y1 */ 1273 vli_mod_sub(y2, y2, y1, curve_prime, ndigits); 1274 1275 /* t6 = C - B */ 1276 vli_mod_sub(t6, x2, x1, curve_prime, ndigits); 1277 /* t2 = y1 * (C - B) */ 1278 vli_mod_mult_fast(y1, y1, t6, curve); 1279 /* t6 = B + C */ 1280 vli_mod_add(t6, x1, x2, curve_prime, ndigits); 1281 /* t3 = (y2 - y1)^2 */ 1282 vli_mod_square_fast(x2, y2, curve); 1283 /* t3 = x3 */ 1284 vli_mod_sub(x2, x2, t6, curve_prime, ndigits); 1285 1286 /* t7 = B - x3 */ 1287 vli_mod_sub(t7, x1, x2, curve_prime, ndigits); 1288 /* t4 = (y2 - y1)*(B - x3) */ 1289 vli_mod_mult_fast(y2, y2, t7, curve); 1290 /* t4 = y3 */ 1291 vli_mod_sub(y2, y2, y1, curve_prime, ndigits); 1292 1293 /* t7 = (y2 + y1)^2 = F */ 1294 vli_mod_square_fast(t7, t5, curve); 1295 /* t7 = x3' */ 1296 vli_mod_sub(t7, t7, t6, curve_prime, ndigits); 1297 /* t6 = x3' - B */ 1298 vli_mod_sub(t6, t7, x1, curve_prime, ndigits); 1299 /* t6 = (y2 + y1)*(x3' - B) */ 1300 vli_mod_mult_fast(t6, t6, t5, curve); 1301 /* t2 = y3' */ 1302 vli_mod_sub(y1, t6, y1, curve_prime, ndigits); 1303 1304 vli_set(x1, t7, ndigits); 1305 } 1306 1307 static void ecc_point_mult(struct ecc_point *result, 1308 const struct ecc_point *point, const u64 *scalar, 1309 u64 *initial_z, const struct ecc_curve *curve, 1310 unsigned int ndigits) 1311 { 1312 /* R0 and R1 */ 1313 u64 rx[2][ECC_MAX_DIGITS]; 1314 u64 ry[2][ECC_MAX_DIGITS]; 1315 u64 z[ECC_MAX_DIGITS]; 1316 u64 sk[2][ECC_MAX_DIGITS]; 1317 u64 *curve_prime = curve->p; 1318 int i, nb; 1319 int num_bits; 1320 int carry; 1321 1322 carry = vli_add(sk[0], scalar, curve->n, ndigits); 1323 vli_add(sk[1], sk[0], curve->n, ndigits); 1324 scalar = sk[!carry]; 1325 if (curve->nbits == 521) /* NIST P521 */ 1326 num_bits = curve->nbits + 2; 1327 else 1328 num_bits = sizeof(u64) * ndigits * 8 + 1; 1329 1330 vli_set(rx[1], point->x, ndigits); 1331 vli_set(ry[1], point->y, ndigits); 1332 
1333 xycz_initial_double(rx[1], ry[1], rx[0], ry[0], initial_z, curve); 1334 1335 for (i = num_bits - 2; i > 0; i--) { 1336 nb = !vli_test_bit(scalar, i); 1337 xycz_add_c(rx[1 - nb], ry[1 - nb], rx[nb], ry[nb], curve); 1338 xycz_add(rx[nb], ry[nb], rx[1 - nb], ry[1 - nb], curve); 1339 } 1340 1341 nb = !vli_test_bit(scalar, 0); 1342 xycz_add_c(rx[1 - nb], ry[1 - nb], rx[nb], ry[nb], curve); 1343 1344 /* Find final 1/Z value. */ 1345 /* X1 - X0 */ 1346 vli_mod_sub(z, rx[1], rx[0], curve_prime, ndigits); 1347 /* Yb * (X1 - X0) */ 1348 vli_mod_mult_fast(z, z, ry[1 - nb], curve); 1349 /* xP * Yb * (X1 - X0) */ 1350 vli_mod_mult_fast(z, z, point->x, curve); 1351 1352 /* 1 / (xP * Yb * (X1 - X0)) */ 1353 vli_mod_inv(z, z, curve_prime, point->ndigits); 1354 1355 /* yP / (xP * Yb * (X1 - X0)) */ 1356 vli_mod_mult_fast(z, z, point->y, curve); 1357 /* Xb * yP / (xP * Yb * (X1 - X0)) */ 1358 vli_mod_mult_fast(z, z, rx[1 - nb], curve); 1359 /* End 1/Z calculation */ 1360 1361 xycz_add(rx[nb], ry[nb], rx[1 - nb], ry[1 - nb], curve); 1362 1363 apply_z(rx[0], ry[0], z, curve); 1364 1365 vli_set(result->x, rx[0], ndigits); 1366 vli_set(result->y, ry[0], ndigits); 1367 } 1368 1369 /* Computes R = P + Q mod p */ 1370 static void ecc_point_add(const struct ecc_point *result, 1371 const struct ecc_point *p, const struct ecc_point *q, 1372 const struct ecc_curve *curve) 1373 { 1374 u64 z[ECC_MAX_DIGITS]; 1375 u64 px[ECC_MAX_DIGITS]; 1376 u64 py[ECC_MAX_DIGITS]; 1377 unsigned int ndigits = curve->g.ndigits; 1378 1379 vli_set(result->x, q->x, ndigits); 1380 vli_set(result->y, q->y, ndigits); 1381 vli_mod_sub(z, result->x, p->x, curve->p, ndigits); 1382 vli_set(px, p->x, ndigits); 1383 vli_set(py, p->y, ndigits); 1384 xycz_add(px, py, result->x, result->y, curve); 1385 vli_mod_inv(z, z, curve->p, ndigits); 1386 apply_z(result->x, result->y, z, curve); 1387 } 1388 1389 /* Computes R = u1P + u2Q mod p using Shamir's trick. 1390 * Based on: Kenneth MacKay's micro-ecc (2014). 
1391 */ 1392 void ecc_point_mult_shamir(const struct ecc_point *result, 1393 const u64 *u1, const struct ecc_point *p, 1394 const u64 *u2, const struct ecc_point *q, 1395 const struct ecc_curve *curve) 1396 { 1397 u64 z[ECC_MAX_DIGITS]; 1398 u64 sump[2][ECC_MAX_DIGITS]; 1399 u64 *rx = result->x; 1400 u64 *ry = result->y; 1401 unsigned int ndigits = curve->g.ndigits; 1402 unsigned int num_bits; 1403 struct ecc_point sum = ECC_POINT_INIT(sump[0], sump[1], ndigits); 1404 const struct ecc_point *points[4]; 1405 const struct ecc_point *point; 1406 unsigned int idx; 1407 int i; 1408 1409 ecc_point_add(&sum, p, q, curve); 1410 points[0] = NULL; 1411 points[1] = p; 1412 points[2] = q; 1413 points[3] = ∑ 1414 1415 num_bits = max(vli_num_bits(u1, ndigits), vli_num_bits(u2, ndigits)); 1416 i = num_bits - 1; 1417 idx = !!vli_test_bit(u1, i); 1418 idx |= (!!vli_test_bit(u2, i)) << 1; 1419 point = points[idx]; 1420 1421 vli_set(rx, point->x, ndigits); 1422 vli_set(ry, point->y, ndigits); 1423 vli_clear(z + 1, ndigits - 1); 1424 z[0] = 1; 1425 1426 for (--i; i >= 0; i--) { 1427 ecc_point_double_jacobian(rx, ry, z, curve); 1428 idx = !!vli_test_bit(u1, i); 1429 idx |= (!!vli_test_bit(u2, i)) << 1; 1430 point = points[idx]; 1431 if (point) { 1432 u64 tx[ECC_MAX_DIGITS]; 1433 u64 ty[ECC_MAX_DIGITS]; 1434 u64 tz[ECC_MAX_DIGITS]; 1435 1436 vli_set(tx, point->x, ndigits); 1437 vli_set(ty, point->y, ndigits); 1438 apply_z(tx, ty, z, curve); 1439 vli_mod_sub(tz, rx, tx, curve->p, ndigits); 1440 xycz_add(tx, ty, rx, ry, curve); 1441 vli_mod_mult_fast(z, z, tz, curve); 1442 } 1443 } 1444 vli_mod_inv(z, z, curve->p, ndigits); 1445 apply_z(rx, ry, z, curve); 1446 } 1447 EXPORT_SYMBOL(ecc_point_mult_shamir); 1448 1449 /* 1450 * This function performs checks equivalent to Appendix A.4.2 of FIPS 186-5. 1451 * Whereas A.4.2 results in an integer in the interval [1, n-1], this function 1452 * ensures that the integer is in the range of [2, n-3]. 
We are slightly 1453 * stricter because of the currently used scalar multiplication algorithm. 1454 */ 1455 static int __ecc_is_key_valid(const struct ecc_curve *curve, 1456 const u64 *private_key, unsigned int ndigits) 1457 { 1458 u64 one[ECC_MAX_DIGITS] = { 1, }; 1459 u64 res[ECC_MAX_DIGITS]; 1460 1461 if (!private_key) 1462 return -EINVAL; 1463 1464 if (curve->g.ndigits != ndigits) 1465 return -EINVAL; 1466 1467 /* Make sure the private key is in the range [2, n-3]. */ 1468 if (vli_cmp(one, private_key, ndigits) != -1) 1469 return -EINVAL; 1470 vli_sub(res, curve->n, one, ndigits); 1471 vli_sub(res, res, one, ndigits); 1472 if (vli_cmp(res, private_key, ndigits) != 1) 1473 return -EINVAL; 1474 1475 return 0; 1476 } 1477 1478 int ecc_is_key_valid(unsigned int curve_id, unsigned int ndigits, 1479 const u64 *private_key, unsigned int private_key_len) 1480 { 1481 int nbytes; 1482 const struct ecc_curve *curve = ecc_get_curve(curve_id); 1483 1484 nbytes = ndigits << ECC_DIGITS_TO_BYTES_SHIFT; 1485 1486 if (private_key_len != nbytes) 1487 return -EINVAL; 1488 1489 return __ecc_is_key_valid(curve, private_key, ndigits); 1490 } 1491 EXPORT_SYMBOL(ecc_is_key_valid); 1492 1493 /* 1494 * ECC private keys are generated using the method of rejection sampling, 1495 * equivalent to that described in FIPS 186-5, Appendix A.2.2. 1496 * 1497 * This method generates a private key uniformly distributed in the range 1498 * [2, n-3]. 1499 */ 1500 int ecc_gen_privkey(unsigned int curve_id, unsigned int ndigits, 1501 u64 *private_key) 1502 { 1503 const struct ecc_curve *curve = ecc_get_curve(curve_id); 1504 unsigned int nbytes = ndigits << ECC_DIGITS_TO_BYTES_SHIFT; 1505 unsigned int nbits = vli_num_bits(curve->n, ndigits); 1506 int err; 1507 1508 /* 1509 * Step 1 & 2: check that N is included in Table 1 of FIPS 186-5, 1510 * section 6.1.1. 
1511 */ 1512 if (nbits < 224) 1513 return -EINVAL; 1514 1515 /* 1516 * FIPS 186-5 recommends that the private key should be obtained from a 1517 * RBG with a security strength equal to or greater than the security 1518 * strength associated with N. 1519 * 1520 * The maximum security strength identified by NIST SP800-57pt1r4 for 1521 * ECC is 256 (N >= 512). 1522 * 1523 * This condition is met by the default RNG because it selects a favored 1524 * DRBG with a security strength of 256. 1525 */ 1526 if (crypto_get_default_rng()) 1527 return -EFAULT; 1528 1529 /* Step 3: obtain N returned_bits from the DRBG. */ 1530 err = crypto_rng_get_bytes(crypto_default_rng, 1531 (u8 *)private_key, nbytes); 1532 crypto_put_default_rng(); 1533 if (err) 1534 return err; 1535 1536 /* Step 4: make sure the private key is in the valid range. */ 1537 if (__ecc_is_key_valid(curve, private_key, ndigits)) 1538 return -EINVAL; 1539 1540 return 0; 1541 } 1542 EXPORT_SYMBOL(ecc_gen_privkey); 1543 1544 int ecc_make_pub_key(unsigned int curve_id, unsigned int ndigits, 1545 const u64 *private_key, u64 *public_key) 1546 { 1547 int ret = 0; 1548 struct ecc_point *pk; 1549 const struct ecc_curve *curve = ecc_get_curve(curve_id); 1550 1551 if (!private_key) { 1552 ret = -EINVAL; 1553 goto out; 1554 } 1555 1556 pk = ecc_alloc_point(ndigits); 1557 if (!pk) { 1558 ret = -ENOMEM; 1559 goto out; 1560 } 1561 1562 ecc_point_mult(pk, &curve->g, private_key, NULL, curve, ndigits); 1563 1564 /* SP800-56A rev 3 5.6.2.1.3 key check */ 1565 if (ecc_is_pubkey_valid_full(curve, pk)) { 1566 ret = -EAGAIN; 1567 goto err_free_point; 1568 } 1569 1570 ecc_swap_digits(pk->x, public_key, ndigits); 1571 ecc_swap_digits(pk->y, &public_key[ndigits], ndigits); 1572 1573 err_free_point: 1574 ecc_free_point(pk); 1575 out: 1576 return ret; 1577 } 1578 EXPORT_SYMBOL(ecc_make_pub_key); 1579 1580 /* SP800-56A section 5.6.2.3.4 partial verification: ephemeral keys only */ 1581 int ecc_is_pubkey_valid_partial(const struct ecc_curve 
*curve, 1582 struct ecc_point *pk) 1583 { 1584 u64 yy[ECC_MAX_DIGITS], xxx[ECC_MAX_DIGITS], w[ECC_MAX_DIGITS]; 1585 1586 if (WARN_ON(pk->ndigits != curve->g.ndigits)) 1587 return -EINVAL; 1588 1589 /* Check 1: Verify key is not the zero point. */ 1590 if (ecc_point_is_zero(pk)) 1591 return -EINVAL; 1592 1593 /* Check 2: Verify key is in the range [1, p-1]. */ 1594 if (vli_cmp(curve->p, pk->x, pk->ndigits) != 1) 1595 return -EINVAL; 1596 if (vli_cmp(curve->p, pk->y, pk->ndigits) != 1) 1597 return -EINVAL; 1598 1599 /* Check 3: Verify that y^2 == (x^3 + a·x + b) mod p */ 1600 vli_mod_square_fast(yy, pk->y, curve); /* y^2 */ 1601 vli_mod_square_fast(xxx, pk->x, curve); /* x^2 */ 1602 vli_mod_mult_fast(xxx, xxx, pk->x, curve); /* x^3 */ 1603 vli_mod_mult_fast(w, curve->a, pk->x, curve); /* a·x */ 1604 vli_mod_add(w, w, curve->b, curve->p, pk->ndigits); /* a·x + b */ 1605 vli_mod_add(w, w, xxx, curve->p, pk->ndigits); /* x^3 + a·x + b */ 1606 if (vli_cmp(yy, w, pk->ndigits) != 0) /* Equation */ 1607 return -EINVAL; 1608 1609 return 0; 1610 } 1611 EXPORT_SYMBOL(ecc_is_pubkey_valid_partial); 1612 1613 /* SP800-56A section 5.6.2.3.3 full verification */ 1614 int ecc_is_pubkey_valid_full(const struct ecc_curve *curve, 1615 struct ecc_point *pk) 1616 { 1617 struct ecc_point *nQ; 1618 1619 /* Checks 1 through 3 */ 1620 int ret = ecc_is_pubkey_valid_partial(curve, pk); 1621 1622 if (ret) 1623 return ret; 1624 1625 /* Check 4: Verify that nQ is the zero point. 
*/ 1626 nQ = ecc_alloc_point(pk->ndigits); 1627 if (!nQ) 1628 return -ENOMEM; 1629 1630 ecc_point_mult(nQ, pk, curve->n, NULL, curve, pk->ndigits); 1631 if (!ecc_point_is_zero(nQ)) 1632 ret = -EINVAL; 1633 1634 ecc_free_point(nQ); 1635 1636 return ret; 1637 } 1638 EXPORT_SYMBOL(ecc_is_pubkey_valid_full); 1639 1640 int crypto_ecdh_shared_secret(unsigned int curve_id, unsigned int ndigits, 1641 const u64 *private_key, const u64 *public_key, 1642 u64 *secret) 1643 { 1644 int ret = 0; 1645 struct ecc_point *product, *pk; 1646 u64 rand_z[ECC_MAX_DIGITS]; 1647 unsigned int nbytes; 1648 const struct ecc_curve *curve = ecc_get_curve(curve_id); 1649 1650 if (!private_key || !public_key || ndigits > ARRAY_SIZE(rand_z)) { 1651 ret = -EINVAL; 1652 goto out; 1653 } 1654 1655 nbytes = ndigits << ECC_DIGITS_TO_BYTES_SHIFT; 1656 1657 get_random_bytes(rand_z, nbytes); 1658 1659 pk = ecc_alloc_point(ndigits); 1660 if (!pk) { 1661 ret = -ENOMEM; 1662 goto out; 1663 } 1664 1665 ecc_swap_digits(public_key, pk->x, ndigits); 1666 ecc_swap_digits(&public_key[ndigits], pk->y, ndigits); 1667 ret = ecc_is_pubkey_valid_partial(curve, pk); 1668 if (ret) 1669 goto err_alloc_product; 1670 1671 product = ecc_alloc_point(ndigits); 1672 if (!product) { 1673 ret = -ENOMEM; 1674 goto err_alloc_product; 1675 } 1676 1677 ecc_point_mult(product, pk, private_key, rand_z, curve, ndigits); 1678 1679 if (ecc_point_is_zero(product)) { 1680 ret = -EFAULT; 1681 goto err_validity; 1682 } 1683 1684 ecc_swap_digits(product->x, secret, ndigits); 1685 1686 err_validity: 1687 memzero_explicit(rand_z, sizeof(rand_z)); 1688 ecc_free_point(product); 1689 err_alloc_product: 1690 ecc_free_point(pk); 1691 out: 1692 return ret; 1693 } 1694 EXPORT_SYMBOL(crypto_ecdh_shared_secret); 1695 1696 MODULE_LICENSE("Dual BSD/GPL"); 1697