#! /usr/bin/env python
#/*
# * Copyright (C) 2017 - This file is part of libecc project
# *
# * Authors:
# *     Ryad BENADJILA <ryadbenadjila@gmail.com>
# *     Arnaud EBALARD <arnaud.ebalard@ssi.gouv.fr>
# *     Jean-Pierre FLORI <jean-pierre.flori@ssi.gouv.fr>
# *
# * Contributors:
# *     Nicolas VIVET <nicolas.vivet@ssi.gouv.fr>
# *     Karim KHALFALLAH <karim.khalfallah@ssi.gouv.fr>
# *
# * This software is licensed under a dual BSD and GPL v2 license.
# * See LICENSE file at the root folder of the project.
# */

import random, sys, re, math, os, getopt, glob, copy, hashlib, binascii, string, signal, base64

# External dependency for SHA-3: the standalone sha3 module (providing
# Sha3_ctx) is used whenever it can be imported. When it is missing, the
# SHA-3 functions below fall back on hashlib's SHA-3 constructors (present in
# recent Python 3 versions) instead of making the whole script fail at import.
try:
    import sha3
except ImportError:
    sha3 = None

# Handle Python 2/3 issues
def is_python_2():
    """Return True when running under Python 2, False otherwise."""
    return sys.version_info[0] < 3

### Ctrl-C handler
def handler(signal, frame):
    """SIGINT handler: print a message and exit with status 0."""
    print("\nSIGINT caught: exiting ...")
    # sys.exit is always available, unlike the 'exit' helper that the
    # site module may or may not inject into builtins
    sys.exit(0)

# Helper to ask the user for something
def get_user_input(prompt):
    """Prompt the user and return the typed line as a string (Python 2 and 3)."""
    if is_python_2():
        return raw_input(prompt)
    return input(prompt)

##########################################################
#### Math helpers
def egcd(b, n):
    """
    Extended Euclidean algorithm: return (g, x, y) such that
    b*x + n*y == g, g being gcd(b, n).
    """
    x0, x1, y0, y1 = 1, 0, 0, 1
    while n != 0:
        q, b, n = b // n, n, b % n
        x0, x1 = x1, x0 - q * x1
        y0, y1 = y1, y0 - q * y1
    return b, x0, y0

def modinv(a, m):
    """
    Return the inverse of a modulo m, reduced in [0, m[.

    Raises an Exception when a is not invertible modulo m.
    """
    g, x, _ = egcd(a, m)
    if g != 1:
        raise Exception("Error: modular inverse does not exist")
    return x % m

def compute_monty_coef(prime, pbitlen, wlen):
    """
    Compute montgomery coeff r, r^2 and mpinv. pbitlen is the size
    of p in bits. It is expected to be a multiple of word
    bit size.

    r = 2^pbitlen mod p, r_square = 2^(2*pbitlen) mod p and
    mpinv = 2^wlen - (p^-1 mod 2^wlen).
    """
    r = (1 << int(pbitlen)) % prime
    r_square = (1 << (2 * int(pbitlen))) % prime
    mpinv = 2**wlen - (modinv(prime, 2**wlen))
    return r, r_square, mpinv

def compute_div_coef(prime, pbitlen, wlen):
    """
    Compute division coeffs p_normalized, p_shift and p_reciprocal.

    pshift is the left shift bringing prime on pbitlen bits, primenorm the
    shifted prime, and prec = floor(B^3 / (top + 1)) - B with B = 2^wlen and
    top the 2*wlen most significant bits of primenorm.
    """
    # bit_length() gives the same count as shifting until zero (and also
    # terminates for negative inputs)
    pshift = int(pbitlen - int(prime).bit_length())
    primenorm = prime << pshift
    B = 2**wlen
    prec = B**3 // ((primenorm >> int(pbitlen - 2*wlen)) + 1) - B
    return pshift, primenorm, prec

def is_probprime(n):
    """
    Miller-Rabin probabilistic primality test using 5 random bases.

    Returns False when n is certainly composite (or < 2), True when n is
    probably prime. 2 and 3 are handled explicitly.
    """
    if n < 2:
        return False
    if n < 4:
        # 2 and 3 are prime (and too small for the random base selection)
        return True
    if n % 2 == 0:
        return False
    # Write n-1 as 2**s * d with d odd
    s, d = 0, n - 1
    while d % 2 == 0:
        d //= 2
        s += 1
    assert(2**s * d == n-1)
    # Test whether base a is a witness for the compositeness of n
    def try_composite(a):
        x = pow(a, d, n)
        if x == 1 or x == n - 1:
            return False
        for _ in range(s - 1):
            x = pow(x, 2, n)
            if x == n - 1:
                return False
        return True # n is definitely composite
    for _ in range(5):
        # Bases 1 and n-1 are useless witnesses: pick a in [2, n-2]
        if try_composite(random.randrange(2, n - 1)):
            return False
    return True # no base tested showed n as composite

def legendre_symbol(a, p):
    """
    Legendre symbol (a/p) through Euler's criterion, p being expected to be
    prime: returns 1, -1 or 0.
    """
    ls = pow(a, (p - 1) // 2, p)
    return -1 if ls == p - 1 else ls

# Tonelli-Shanks algorithm to find square roots
# over prime fields
def mod_sqrt(a, p):
    """
    Return a square root of a modulo the prime p, or None when a is not a
    quadratic residue. Multiples of p (including 0) have square root 0.
    """
    a = a % p
    # Square root of 0 is 0
    if a == 0:
        return 0
    # Simple cases
    if legendre_symbol(a, p) != 1:
        # No square residue
        return None
    if p == 2:
        return a
    if p % 4 == 3:
        return pow(a, (p + 1) // 4, p)
    # Write p - 1 = s * 2^e with s odd
    s, e = p - 1, 0
    while s % 2 == 0:
        s //= 2
        e += 1
    # Find a quadratic non-residue n
    n = 2
    while legendre_symbol(n, p) != -1:
        n += 1
    x = pow(a, (s + 1) // 2, p)
    b = pow(a, s, p)
    g = pow(n, s, p)
    r = e
    while True:
        # Find the least m such that b^(2^m) == 1
        t = b
        m = 0
        for m in range(r):
            if t == 1:
                break
            t = pow(t, 2, p)
        if m == 0:
            return x
        gs = pow(g, 2 ** (r - m - 1), p)
        g = (gs * gs) % p
        x = (x * gs) % p
        b = (b * g) % p
        r = m

##########################################################
### Math elliptic curves basic blocks

# WARNING: these blocks are only here for testing purpose and
# are not intended to be used in a security oriented library!
# This explains the usage of naive affine coordinates formulas
class Curve(object):
    """
    Short Weierstrass curve y^2 = x^3 + a*x + b over GF(prime).

    'order' (stored as q) is used by the signature code as the order of the
    base point G = (gx, gy).
    """
    def __init__(self, a, b, prime, order, cofactor, gx, gy, npoints, name, oid):
        self.a = a
        self.b = b
        self.p = prime
        self.q = order
        self.c = cofactor
        self.gx = gx
        self.gy = gy
        self.n = npoints
        self.name = name
        self.oid = oid
    # Equality testing: two curves are equal when all their attributes are
    def __eq__(self, other):
        return isinstance(other, Curve) and self.__dict__ == other.__dict__
    # Python 2 does not derive != from __eq__: define it explicitly so that
    # curve checks such as 'Q.curve != curve' compare values, not identities
    def __ne__(self, other):
        return not self.__eq__(other)
    # Deep copy is implemented using the ~X operator
    def __invert__(self):
        return copy.deepcopy(self)


class Point(object):
    """
    Affine point (x, y) on 'curve'; the point at infinity is (None, None).

    Coordinates are reduced modulo curve.p, and the constructor raises an
    Exception when the point does not satisfy the curve equation.
    """
    def __init__(self, curve, x, y):
        self.curve = curve
        self.x = None if x is None else x % curve.p
        self.y = None if y is None else y % curve.p
        if self.x is None and self.y is None:
            # Point at infinity: nothing to check
            return
        # A point with a single None coordinate is not a valid point either
        if (self.x is None) or (self.y is None) or \
           (pow(self.y, 2, curve.p) != ((pow(self.x, 3, curve.p) + (curve.a * self.x) + curve.b) % curve.p)):
            raise Exception("Error: point is not on curve!")
    # Addition
    def __add__(self, Q):
        curve = self.curve
        # Check that we are on the same curve
        if Q.curve != curve:
            raise Exception("Point add error: two point don't have the same curve")
        # If Q is infinity point, return (a copy of) ourself
        if Q.x is None:
            return Point(curve, self.x, self.y)
        # If we are the infinity point return Q
        if self.x is None:
            return Q
        p = curve.p
        (x1, y1, x2, y2) = (self.x, self.y, Q.x, Q.y)
        if x1 == x2:
            if ((y1 + y2) % p) == 0:
                # P + (-P): return infinity point
                return Point(curve, None, None)
            # Doubling: tangent slope (3*x^2 + a) / (2*y)
            L = ((3*pow(x1, 2, p) + curve.a) * modinv(2*y1, p)) % p
        else:
            # Addition: chord slope (y2 - y1) / (x2 - x1)
            L = ((y2 - y1) * modinv((x2 - x1) % p, p)) % p
        resx = (pow(L, 2, p) - x1 - x2) % p
        resy = ((L * (x1 - resx)) - y1) % p
        # Return the point
        return Point(curve, resx, resy)
    # Negation
    def __neg__(self):
        if self.x is None:
            return Point(self.curve, None, None)
        return Point(self.curve, self.x, -self.y)
    # Subtraction
    def __sub__(self, other):
        return self + (-other)
    # Scalar mul (scalar * P), simple double and add algorithm
    def __rmul__(self, scalar):
        # (-k)P = k(-P): handle negative scalars through the point negation
        if scalar < 0:
            return (-scalar) * (-self)
        Q = Point(self.curve, None, None)
        for i in range(getbitlen(scalar) - 1, -1, -1):
            Q = Q + Q
            if (scalar >> i) & 0x1 == 0x1:
                Q = Q + self
        return Q
    # Equality testing
    def __eq__(self, other):
        return isinstance(other, Point) and self.__dict__ == other.__dict__
    # Python 2 does not derive != from __eq__
    def __ne__(self, other):
        return not self.__eq__(other)
    # Deep copy is implemented using the ~X operator
    def __invert__(self):
        return copy.deepcopy(self)
    def __str__(self):
        if self.x is None:
            return "Inf"
        return ("(x = %s, y = %s)" % (hex(self.x), hex(self.y)))

##########################################################
### Private and public keys structures
class PrivKey(object):
    """Private key: the secret scalar x on 'curve'."""
    def __init__(self, curve, x):
        self.curve = curve
        self.x = x

class PubKey(object):
    """Public key: the point Y on 'curve' (Y must belong to that curve)."""
    def __init__(self, curve, Y):
        # Sanity check
        if Y.curve != curve:
            raise Exception("Error: curve and point curve differ in public key!")
        self.curve = curve
        self.Y = Y

class KeyPair(object):
    """Association of a public key and its private key."""
    def __init__(self, pubkey, privkey):
        self.pubkey = pubkey
        self.privkey = privkey


def fromprivkey(privkey, is_eckcdsa=False):
    """
    Derive the public key of 'privkey': Y = xG, or Y = (x^-1 mod q)G when
    is_eckcdsa is set.
    """
    curve = privkey.curve
    G = Point(curve, curve.gx, curve.gy)
    if is_eckcdsa:
        return PubKey(curve, modinv(privkey.x, curve.q) * G)
    return PubKey(curve, privkey.x * G)

def genKeyPair(curve, is_eckcdsa=False):
    """Generate a random key pair whose private scalar x lies in ]0,q[."""
    x = 0
    while x == 0:
        x = getrandomint(curve.q)
    privkey = PrivKey(curve, x)
    return KeyPair(fromprivkey(privkey, is_eckcdsa), privkey)

##########################################################
### Signature algorithms helpers
def getrandomint(modulo):
    """
    Return a random integer in [0, modulo[.

    The upper bound is excluded so that nonces and private keys drawn with
    modulo = q can never be equal to q. NOTE: the random module is not a
    cryptographically secure generator; it is only used to produce test
    vectors here.
    """
    return random.randrange(0, modulo)

def getbitlen(bint):
    """
    Returns the number of bits encoding an integer: 0 for None, 1 for 0
    (zero is encoded on one bit), bit_length() otherwise.
    """
    if bint is None:
        return 0
    if bint == 0:
        return 1
    return int(bint).bit_length()

def getbytelen(bint):
    """
    Returns the number of bytes encoding an integer, i.e.
    ceil(getbitlen(bint) / 8).
    """
    return (getbitlen(bint) + 7) // 8

def stringtoint(bitstring):
    """Big endian conversion of a string (one byte per character) to an integer."""
    acc = 0
    for c in bitstring:
        acc = (acc << 8) + ord(c)
    return acc

def inttostring(a):
    """
    Big endian conversion of an integer to a string of getbytelen(a)
    characters (one byte per character). None gives the empty string.
    """
    size = getbytelen(a)
    return "".join(chr((a >> (8*(size - 1 - i))) & 0xFF) for i in range(size))

def expand(bitstring, bitlen, direction):
    """
    Pad 'bitstring' with zero bytes up to ceil(bitlen / 8) bytes, on the left
    (direction "LEFT") or on the right ("RIGHT"). Strings that are already
    long enough are returned unchanged.
    """
    bytelen = int(math.ceil(bitlen / 8.))
    missing = bytelen - len(bitstring)
    if missing <= 0:
        return bitstring
    if direction == "LEFT":
        return (missing*"\x00") + bitstring
    if direction == "RIGHT":
        return bitstring + (missing*"\x00")
    raise Exception("Error: unknown direction "+direction+" in expand")

def truncate(bitstring, bitlen, keep):
    """
    Takes a bit string and truncates it to keep the 'bitlen' left most
    (keep="LEFT") or right most (keep="RIGHT") bits; the result is left
    padded with zero bytes to ceil(bitlen / 8) bytes. Strings of at most
    'bitlen' bits are returned unchanged.
    """
    strbitlen = 8*len(bitstring)
    # Check if truncation is needed
    if strbitlen <= bitlen:
        return bitstring
    if keep == "LEFT":
        return expand(inttostring(stringtoint(bitstring) >> int(strbitlen - bitlen)), bitlen, "LEFT")
    if keep == "RIGHT":
        mask = (2**bitlen)-1
        return expand(inttostring(stringtoint(bitstring) & mask), bitlen, "LEFT")
    raise Exception("Error: unknown direction "+keep+" in truncate")

##########################################################
### Hash algorithms
def _hash_digest(ctx, message):
    """
    Feed 'message' (a str holding one byte per character) to the hash
    context 'ctx' and return (digest, digest_size, block_size), the digest
    being a str holding one byte per character as well.
    """
    if is_python_2():
        ctx.update(message)
        digest = ctx.digest()
    else:
        ctx.update(message.encode('latin-1'))
        digest = ctx.digest().decode('latin-1')
    return (digest, ctx.digest_size, ctx.block_size)

def _sha3_context(bitlen):
    """
    Return a SHA-3 context for a 'bitlen' bits digest: the sha3 module's
    Sha3_ctx when available, hashlib's implementation otherwise.
    """
    if sha3 is not None:
        return sha3.Sha3_ctx(bitlen)
    constructor = getattr(hashlib, "sha3_%d" % bitlen, None)
    if constructor is None:
        raise Exception("Error: no SHA-3 implementation available (sha3 module not found and hashlib lacks SHA-3)")
    return constructor()

def sha224(message):
    return _hash_digest(hashlib.sha224(), message)

def sha256(message):
    return _hash_digest(hashlib.sha256(), message)

def sha384(message):
    return _hash_digest(hashlib.sha384(), message)

def sha512(message):
    return _hash_digest(hashlib.sha512(), message)

def sha3_224(message):
    return _hash_digest(_sha3_context(224), message)

def sha3_256(message):
    return _hash_digest(_sha3_context(256), message)

def sha3_384(message):
    return _hash_digest(_sha3_context(384), message)

def sha3_512(message):
    return _hash_digest(_sha3_context(512), message)

##########################################################
### Signature algorithms
# *| IUF - ECDSA signature
# *|
# *| UF 1. Compute h = H(m)
# *| F 2. If |h| > bitlen(q), set h to bitlen(q)
# *| leftmost (most significant) bits of h
# *| F 3. e = OS2I(h) mod q
# *| F 4. Get a random value k in ]0,q[
# *| F 5. Compute W = (W_x,W_y) = kG
# *| F 6. Compute r = W_x mod q
# *| F 7. If r is 0, restart the process at step 4.
# *| F 8. If e == rx, restart the process at step 4.
# *| F 9. Compute s = k^-1 * (xr + e) mod q
# *| F 10. If s is 0, restart the process at step 4.
# *| F 11. Return (r,s)
def ecdsa_sign(hashfunc, keypair, message, k=None):
    """
    Sign 'message' with ECDSA using 'hashfunc' and the private key of 'keypair'.

    If 'k' is None, a fresh random nonce is drawn on every attempt. A caller
    provided 'k' (e.g. to reproduce a known test vector) is used as is, and an
    Exception is raised if it cannot yield a valid signature: retrying with
    the same nonce could never succeed and would loop forever.

    Returns (sig, k) where sig is r || s, each big endian encoded on the
    byte length of q, and k is the nonce actually used.
    """
    privkey = keypair.privkey
    curve = privkey.curve
    q = curve.q
    qbytelen = getbytelen(q)
    G = Point(curve, curve.gx, curve.gy)
    # Steps 1 to 3: hash, keep the bitlen(q) leftmost bits, convert mod q
    (h, _, _) = hashfunc(message)
    h = truncate(h, getbitlen(q), "LEFT")
    e = stringtoint(h) % q
    user_k = k
    while True:
        # Step 4: draw a new nonce on each attempt unless one was imposed
        k = getrandomint(q) if user_k is None else user_k
        s = 0
        # A multiple of q (0 or q) would make W the point at infinity
        if k % q != 0:
            # Steps 5 and 6
            r = (k * G).x % q
            # Steps 7 and 8 (e is compared to the unreduced product r*x, as
            # in the original implementation)
            if r != 0 and e != r * privkey.x:
                # Step 9
                s = (modinv(k, q) * ((privkey.x * r) + e)) % q
        # Step 10
        if s != 0:
            break
        if user_k is not None:
            raise Exception("ECDSA sign: provided nonce k does not yield a valid signature")
    # Step 11
    return ((expand(inttostring(r), 8*qbytelen, "LEFT") + expand(inttostring(s), 8*qbytelen, "LEFT")), k)
# *| IUF - ECDSA verification
# *|
# *| I 1. Reject the signature if r or s is not in ]0,q[.
# *| UF 2. Compute h = H(m)
# *| F 3. If |h| > bitlen(q), set h to bitlen(q)
# *| leftmost (most significant) bits of h
# *| F 4. Compute e = OS2I(h) mod q
# *| F 5. Compute u = (s^-1)e mod q
# *| F 6. Compute v = (s^-1)r mod q
# *| F 7. Compute W' = uG + vY
# *| F 8. If W' is the point at infinity, reject the signature.
# *| F 9. Compute r' = W'_x mod q
# *| F 10. Accept the signature if and only if r equals r'
def ecdsa_verify(hashfunc, keypair, message, sig):
    """
    Verify the ECDSA signature 'sig' (r || s) of 'message' with the public
    key of 'keypair'.

    Returns True for a valid signature and False otherwise. Raises an
    Exception when 'sig' is not exactly twice the byte length of q.
    """
    pubkey = keypair.pubkey
    curve = pubkey.curve
    q = curve.q
    qbytelen = getbytelen(q)
    G = Point(curve, curve.gx, curve.gy)
    # Extract r and s
    if len(sig) != 2*qbytelen:
        raise Exception("ECDSA verify: bad signature length!")
    r = stringtoint(sig[:qbytelen])
    s = stringtoint(sig[qbytelen:])
    # Step 1: r and s must lie in ]0,q[. Values >= q are not canonical, and
    # s = q would make s^-1 undefined (modinv would raise) below.
    if not (0 < r < q) or not (0 < s < q):
        return False
    # Steps 2 to 4
    (h, _, _) = hashfunc(message)
    h = truncate(h, getbitlen(q), "LEFT")
    e = stringtoint(h) % q
    # Steps 5 to 7
    s_inv = modinv(s, q)
    u = (s_inv * e) % q
    v = (s_inv * r) % q
    W_ = (u * G) + (v * pubkey.Y)
    # Step 8
    if W_.x is None:
        return False
    # Steps 9 and 10
    return (W_.x % q) == r

def eckcdsa_genKeyPair(curve):
    """Generate a key pair whose public key is Y = (x^-1 mod q)G."""
    return genKeyPair(curve, True)
# *| IUF - ECKCDSA signature
# *|
# *| IUF 1. Compute h = H(z||m)
# *| F 2. If hsize > bitlen(q), set h to bitlen(q)
# *| rightmost (less significant) bits of h.
# *| F 3. Get a random value k in ]0,q[
# *| F 4. Compute W = (W_x,W_y) = kG
# *| F 5. Compute r = h(FE2OS(W_x)).
# *| F 6. If hsize > bitlen(q), set r to bitlen(q)
# *| rightmost (less significant) bits of r.
# *| F 7. Compute e = OS2I(r XOR h) mod q
# *| F 8. Compute s = x(k - e) mod q
# *| F 9. if s == 0, restart at step 3.
# *| F 10. return (r,s)
def eckcdsa_sign(hashfunc, keypair, message, k=None):
    """
    Sign 'message' with ECKCDSA using 'hashfunc' and 'keypair' (the public
    key is expected to be Y = x^-1 G, as produced by eckcdsa_genKeyPair).

    If 'k' is None, a fresh random nonce is drawn on every attempt. A caller
    provided 'k' is used as is, and an Exception is raised if it cannot
    yield a valid signature, instead of looping forever.

    Returns (sig, k) where sig is r || s, s being big endian encoded on the
    byte length of q.
    """
    privkey = keypair.privkey
    curve = privkey.curve
    q = curve.q
    G = Point(curve, curve.gx, curve.gy)
    pbytelen = getbytelen(curve.p)
    qbytelen = getbytelen(q)
    # h and r are truncated to 8 * ceil(bitlen(q) / 8) bits, i.e. 8 * qbytelen.
    # Integer arithmetic is used on purpose: math.ceil(bitlen / 8) silently
    # floors under Python 2 integer division.
    trunc_bitlen = 8 * qbytelen
    # Certificate data z = FE2OS(Y_x) || FE2OS(Y_y), cut or right padded with
    # zeros to the hash block size
    (_, _, hblocksize) = hashfunc("")
    Y = keypair.pubkey.Y
    z = expand(inttostring(Y.x), 8*pbytelen, "LEFT") + expand(inttostring(Y.y), 8*pbytelen, "LEFT")
    if len(z) > hblocksize:
        z = truncate(z, 8*hblocksize, "LEFT")
    else:
        z = expand(z, 8*hblocksize, "RIGHT")
    # Steps 1 and 2
    (h, _, _) = hashfunc(z + message)
    h = truncate(h, trunc_bitlen, "RIGHT")
    user_k = k
    while True:
        # Step 3: draw a new nonce on each attempt unless one was imposed
        k = getrandomint(q) if user_k is None else user_k
        s = 0
        # A multiple of q (0 or q) would make W the point at infinity
        if k % q != 0:
            # Steps 4 to 6
            W = k * G
            (r, _, _) = hashfunc(expand(inttostring(W.x), 8*pbytelen, "LEFT"))
            r = truncate(r, trunc_bitlen, "RIGHT")
            # Steps 7 and 8
            e = (stringtoint(r) ^ stringtoint(h)) % q
            s = (privkey.x * (k - e)) % q
        # Step 9
        if s != 0:
            break
        if user_k is not None:
            raise Exception("ECKCDSA sign: provided nonce k does not yield a valid signature")
    # Step 10
    return (r + expand(inttostring(s), 8*qbytelen, "LEFT"), k)
# *| IUF - ECKCDSA verification
# *|
# *| I 1. Check the length of r:
# *| - if hsize > bitlen(q), r must be of
# *| length bitlen(q)
# *| - if hsize <= bitlen(q), r must be of
# *| length hsize
# *| I 2. Check that s is in ]0,q[
# *| IUF 3. Compute h = H(z||m)
# *| F 4. If hsize > bitlen(q), set h to bitlen(q)
# *| rightmost (less significant) bits of h.
# *| F 5. Compute e = OS2I(r XOR h) mod q
# *| F 6. Compute W' = sY + eG, where Y is the public key
# *| F 7. Compute r' = h(FE2OS(W'x))
# *| F 8. If hsize > bitlen(q), set r' to bitlen(q)
# *| rightmost (less significant) bits of r'.
# *| F 9. Check if r == r'
def eckcdsa_verify(hashfunc, keypair, message, sig):
    """
    Verify the ECKCDSA signature 'sig' (r || s) of 'message' with the public
    key of 'keypair'.

    Returns True for a valid signature and False otherwise. Raises an
    Exception when 'sig' does not have the length produced by eckcdsa_sign
    (length of r plus the byte length of q).
    """
    pubkey = keypair.pubkey
    curve = pubkey.curve
    q = curve.q
    G = Point(curve, curve.gx, curve.gy)
    pbytelen = getbytelen(curve.p)
    qbytelen = getbytelen(q)
    # Byte aligned truncation length, computed with integer arithmetic so
    # that it is also correct under Python 2 (see eckcdsa_sign)
    trunc_bitlen = 8 * qbytelen
    (_, hsize, hblocksize) = hashfunc("")
    # Step 1: length of r
    if (8*hsize) > getbitlen(q):
        r_len = qbytelen
    else:
        r_len = int(hsize)
    if len(sig) != r_len + qbytelen:
        raise Exception("ECKCDSA verify: bad signature length!")
    r = stringtoint(sig[:r_len])
    s = stringtoint(sig[r_len:])
    # Step 2: s must lie in ]0,q[ (s = 0 is rejected as well)
    if not (0 < s < q):
        return False
    # Compute the certificate data
    Y = pubkey.Y
    z = expand(inttostring(Y.x), 8*pbytelen, "LEFT") + expand(inttostring(Y.y), 8*pbytelen, "LEFT")
    if len(z) > hblocksize:
        z = truncate(z, 8*hblocksize, "LEFT")
    else:
        z = expand(z, 8*hblocksize, "RIGHT")
    # Steps 3 and 4
    (h, _, _) = hashfunc(z + message)
    h = truncate(h, trunc_bitlen, "RIGHT")
    # Steps 5 and 6
    e = (r ^ stringtoint(h)) % q
    W_ = (s * Y) + (e * G)
    # A valid signature always gives W' = kG with k in ]0,q[, never infinity
    if W_.x is None:
        return False
    # Steps 7 to 9
    (r_, _, _) = hashfunc(expand(inttostring(W_.x), 8*pbytelen, "LEFT"))
    r_ = truncate(r_, trunc_bitlen, "RIGHT")
    return stringtoint(r_) == r
# *| IUF - ECFSDSA signature
# *|
# *| I 1. Get a random value k in ]0,q[
# *| I 2. Compute W = (W_x,W_y) = kG
# *| I 3. Compute r = FE2OS(W_x)||FE2OS(W_y)
# *| I 4. If r is an all zero string, restart the process at step 1.
# *| IUF 5. Compute h = H(r||m)
# *| F 6. Compute e = OS2I(h) mod q
# *| F 7. Compute s = (k + ex) mod q
# *| F 8. If s is 0, restart the process at step 1 (see c. below)
# *| F 9. Return (r,s)
def ecfsdsa_sign(hashfunc, keypair, message, k=None):
    """
    Sign 'message' with ECFSDSA using 'hashfunc' and the private key of
    'keypair'.

    If 'k' is None, a fresh random nonce is drawn on every attempt. A caller
    provided 'k' is used as is, and an Exception is raised if it cannot
    yield a valid signature, instead of looping forever.

    Returns (sig, k) where sig is r || s, r being FE2OS(W_x) || FE2OS(W_y)
    and s being big endian encoded on the byte length of q.
    """
    privkey = keypair.privkey
    curve = privkey.curve
    q = curve.q
    G = Point(curve, curve.gx, curve.gy)
    pbytelen = getbytelen(curve.p)
    user_k = k
    while True:
        # Step 1: draw a new nonce on each attempt unless one was imposed
        k = getrandomint(q) if user_k is None else user_k
        s = 0
        # A multiple of q (0 or q) would make W the point at infinity
        if k % q != 0:
            # Steps 2 and 3
            W = k * G
            r = expand(inttostring(W.x), 8*pbytelen, "LEFT") + expand(inttostring(W.y), 8*pbytelen, "LEFT")
            # Step 4
            if stringtoint(r) != 0:
                # Steps 5 to 7
                (h, _, _) = hashfunc(r + message)
                e = stringtoint(h) % q
                s = (k + e * privkey.x) % q
        # Step 8
        if s != 0:
            break
        if user_k is not None:
            raise Exception("ECFSDSA sign: provided nonce k does not yield a valid signature")
    # Step 9
    return (r + expand(inttostring(s), 8*getbytelen(q), "LEFT"), k)
# *| IUF - ECFSDSA verification
# *|
# *| I 1. Reject the signature if r is not a valid point on the curve.
# *| I 2. Reject the signature if s is not in ]0,q[
# *| IUF 3. Compute h = H(r||m)
# *| F 4. Convert h to an integer and then compute e = -h mod q
# *| F 5. compute W' = sG + eY, where Y is the public key
# *| F 6. Compute r' = FE2OS(W'_x)||FE2OS(W'_y)
# *| F 7. Accept the signature if and only if r equals r'
def ecfsdsa_verify(hashfunc, keypair, message, sig):
    """
    Verify the ECFSDSA signature 'sig' (FE2OS(W_x) || FE2OS(W_y) || s) of
    'message' with the public key of 'keypair'.

    Returns True for a valid signature and False otherwise. Raises an
    Exception on a bad signature length, when r is not a point on the curve
    (raised by Point) or when s is not in ]0,q[.
    """
    pubkey = keypair.pubkey
    curve = pubkey.curve
    q = curve.q
    G = Point(curve, curve.gx, curve.gy)
    pbytelen = getbytelen(curve.p)
    qbytelen = getbytelen(q)
    # Extract coordinates from r and s from signature
    if len(sig) != (2*pbytelen) + qbytelen:
        raise Exception("ECFSDSA verify: bad signature length!")
    wx = sig[:pbytelen]
    wy = sig[pbytelen:2*pbytelen]
    r = wx + wy
    s = stringtoint(sig[2*pbytelen:])
    # Step 1: the Point constructor raises if r is not on the curve
    Point(curve, stringtoint(wx), stringtoint(wy))
    # Step 2: s must lie in ]0,q[ (s = q is rejected as well)
    if s == 0 or s >= q:
        raise Exception("ECFSDSA verify: s not in ]0,q[")
    # Steps 3 to 5
    (h, _, _) = hashfunc(r + message)
    e = (-stringtoint(h)) % q
    W_ = s * G + e * pubkey.Y
    # A valid signature gives W' = kG with k in ]0,q[, never infinity
    if W_.x is None:
        return False
    # Steps 6 and 7
    r_ = expand(inttostring(W_.x), 8*pbytelen, "LEFT") + expand(inttostring(W_.y), 8*pbytelen, "LEFT")
    return r == r_
# NOTE: ISO/IEC 14888-3 standard seems to diverge from the existing implementations
# of ECRDSA when treating the message hash, and from the examples of certificates provided
# in RFC 7091 and draft-deremin-rfc4491-bis. While in ISO/IEC 14888-3 it is explicitly asked
# to proceed with the hash of the message as big endian, the RFCs derived from the Russian
# standard expect the hash value to be treated as little endian when importing it as an integer
# (this discrepancy is exhibited and confirmed by test vectors present in ISO/IEC 14888-3, and
# by X.509 certificates present in the RFCs). This seems (to be confirmed) to be a discrepancy of
# ISO/IEC 14888-3 algorithm description that must be fixed there.
#
# In order to be conservative, libecc uses the Russian standard behavior as expected to be in line with
# other implementations, but keeps the ISO/IEC 14888-3 behavior if forced/asked by the user using
# the USE_ISO14888_3_ECRDSA toggle. This allows to keep backward compatibility with previous versions of the
# library if needed.

# *| IUF - ECRDSA signature
# *|
# *| UF 1. Compute h = H(m)
# *| F 2. Get a random value k in ]0,q[
# *| F 3. Compute W = (W_x,W_y) = kG
# *| F 4. Compute r = W_x mod q
# *| F 5. If r is 0, restart the process at step 2.
# *| F 6. Compute e = OS2I(h) mod q. If e is 0, set e to 1.
# *| NOTE: here, ISO/IEC 14888-3 and RFCs differ in the way e treated.
# *| e = OS2I(h) for ISO/IEC 14888-3, or e = OS2I(reversed(h)) when endianness of h
# *| is reversed for RFCs.
# *| F 7. Compute s = (rx + ke) mod q
# *| F 8. If s is 0, restart the process at step 2.
# *| F 9. Return (r,s)
def ecrdsa_sign(hashfunc, keypair, message, k=None, use_iso14888_divergence=False):
    """
    Sign 'message' with ECRDSA using 'hashfunc' and the private key of
    'keypair'. The hash is byte reversed (RFC behavior) unless
    'use_iso14888_divergence' is set (ISO/IEC 14888-3 behavior).

    If 'k' is None, a fresh random nonce is drawn on every attempt. A caller
    provided 'k' is used as is, and an Exception is raised if it cannot
    yield a valid signature, instead of looping forever.

    Returns (sig, k) where sig is r || s, each big endian encoded on the
    byte length of q.
    """
    privkey = keypair.privkey
    curve = privkey.curve
    q = curve.q
    qbytelen = getbytelen(q)
    G = Point(curve, curve.gx, curve.gy)
    # Step 1
    (h, _, _) = hashfunc(message)
    if not use_iso14888_divergence:
        # Reverse the endianness for Russian standard RFC ECRDSA (contrary to ISO/IEC 14888-3 case)
        h = h[::-1]
    # Step 6, hoisted out of the loop since e does not depend on k
    e = stringtoint(h) % q
    if e == 0:
        e = 1
    user_k = k
    while True:
        # Step 2: draw a new nonce on each attempt unless one was imposed
        k = getrandomint(q) if user_k is None else user_k
        s = 0
        # A multiple of q (0 or q) would make W the point at infinity
        if k % q != 0:
            # Steps 3 to 5
            r = (k * G).x % q
            if r != 0:
                # Step 7
                s = ((r * privkey.x) + (k * e)) % q
        # Step 8
        if s != 0:
            break
        if user_k is not None:
            raise Exception("ECRDSA sign: provided nonce k does not yield a valid signature")
    # Step 9
    return (expand(inttostring(r), 8*qbytelen, "LEFT") + expand(inttostring(s), 8*qbytelen, "LEFT"), k)
# *| IUF - ECRDSA verification
# *|
# *| UF 1. Check that r and s are both in ]0,q[
# *| F 2. Compute h = H(m)
# *| F 3. Compute e = OS2I(h)^-1 mod q (with OS2I(h) mod q replaced by 1
# *| when it is 0, as in the signature)
# *| NOTE: here, ISO/IEC 14888-3 and RFCs differ in the way e treated.
# *| e = OS2I(h) for ISO/IEC 14888-3, or e = OS2I(reversed(h)) when endianness of h
# *| is reversed for RFCs.
# *| F 4. Compute u = es mod q
# *| F 5. Compute v = -er mod q
# *| F 6. Compute W' = uG + vY = (W'_x, W'_y)
# *| F 7. Let's now compute r' = W'_x mod q
# *| F 8. Check r and r' are the same
def ecrdsa_verify(hashfunc, keypair, message, sig, use_iso14888_divergence=False):
    """
    Verify the ECRDSA signature 'sig' (r || s) of 'message' with the public
    key of 'keypair', using the same hash endianness convention as
    ecrdsa_sign for 'use_iso14888_divergence'.

    Returns True for a valid signature and False otherwise. Raises an
    Exception on a bad signature length or when r or s is not in ]0,q[.
    """
    pubkey = keypair.pubkey
    curve = pubkey.curve
    q = curve.q
    qbytelen = getbytelen(q)
    G = Point(curve, curve.gx, curve.gy)
    # Extract r and s from signature
    if len(sig) != 2*qbytelen:
        raise Exception("ECRDSA verify: bad signature length!")
    r = stringtoint(sig[:qbytelen])
    s = stringtoint(sig[qbytelen:])
    # Step 1 (q itself is not in ]0,q[)
    if r == 0 or r >= q:
        raise Exception("ECRDSA verify: r not in ]0,q[")
    if s == 0 or s >= q:
        raise Exception("ECRDSA verify: s not in ]0,q[")
    # Step 2
    (h, _, _) = hashfunc(message)
    if not use_iso14888_divergence:
        # Reverse the endianness for Russian standard RFC ECRDSA (contrary to ISO/IEC 14888-3 case)
        h = h[::-1]
    # Step 3: mirror the signature, which replaces a null e by 1
    e = stringtoint(h) % q
    if e == 0:
        e = 1
    e_inv = modinv(e, q)
    # Steps 4 to 6
    u = (e_inv * s) % q
    v = (-e_inv * r) % q
    W_ = u * G + v * pubkey.Y
    # A valid signature never leads to the point at infinity
    if W_.x is None:
        return False
    # Steps 7 and 8
    return (W_.x % q) == r
# *| IUF - ECGDSA signature
# *|
# *| UF 1. Compute h = H(m). If |h| > bitlen(q), set h to bitlen(q)
# *| leftmost (most significant) bits of h
# *| F 2. Convert e = - OS2I(h) mod q
# *| F 3. Get a random value k in ]0,q[
# *| F 4. Compute W = (W_x,W_y) = kG
# *| F 5. Compute r = W_x mod q
# *| F 6. If r is 0, restart the process at step 3.
# *| F 7. Compute s = x(kr + e) mod q
# *| F 8. If s is 0, restart the process at step 3.
# *| F 9. Return (r,s)
def ecgdsa_sign(hashfunc, keypair, message, k=None):
    """
    Sign 'message' with ECGDSA using 'hashfunc' and the private key of
    'keypair'.

    If 'k' is None, a fresh random nonce is drawn on every attempt. A caller
    provided 'k' is used as is, and an Exception is raised if it cannot
    yield a valid signature, instead of looping forever.

    Returns (sig, k) where sig is r || s, each big endian encoded on the
    byte length of q.
    """
    privkey = keypair.privkey
    curve = privkey.curve
    q = curve.q
    qbytelen = getbytelen(q)
    G = Point(curve, curve.gx, curve.gy)
    # Steps 1 and 2 (independent of k, computed once)
    (h, _, _) = hashfunc(message)
    h = truncate(h, getbitlen(q), "LEFT")
    e = (-stringtoint(h)) % q
    user_k = k
    while True:
        # Step 3: draw a new nonce on each attempt unless one was imposed
        k = getrandomint(q) if user_k is None else user_k
        s = 0
        # A multiple of q (0 or q) would make W the point at infinity
        if k % q != 0:
            # Steps 4 to 6
            r = (k * G).x % q
            if r != 0:
                # Step 7
                s = (privkey.x * ((k * r) + e)) % q
        # Step 8
        if s != 0:
            break
        if user_k is not None:
            raise Exception("ECGDSA sign: provided nonce k does not yield a valid signature")
    # Step 9
    return (expand(inttostring(r), 8*qbytelen, "LEFT") + expand(inttostring(s), 8*qbytelen, "LEFT"), k)
# *| IUF - ECGDSA verification
# *|
# *| I 1. Reject the signature if r or s is not in ]0,q[.
# *| UF 2. Compute h = H(m). If |h| > bitlen(q), set h to bitlen(q)
# *| leftmost (most significant) bits of h
# *| F 3. Compute e = OS2I(h) mod q
# *| F 4. Compute u = ((r^-1)e mod q)
# *| F 5. Compute v = ((r^-1)s mod q)
# *| F 6. Compute W' = uG + vY
# *| F 7. Compute r' = W'_x mod q
# *| F 8. Accept the signature if and only if r equals r'
def ecgdsa_verify(hashfunc, keypair, message, sig):
    """
    Verify the ECGDSA signature 'sig' (r || s) of 'message' with the public
    key of 'keypair'.

    Returns True for a valid signature and False otherwise. Raises an
    Exception on a bad signature length or when r or s is not in ]0,q[.
    """
    pubkey = keypair.pubkey
    curve = pubkey.curve
    q = curve.q
    qbytelen = getbytelen(q)
    G = Point(curve, curve.gx, curve.gy)
    # Extract r and s from signature
    if len(sig) != 2*qbytelen:
        raise Exception("ECGDSA verify: bad signature length!")
    r = stringtoint(sig[:qbytelen])
    s = stringtoint(sig[qbytelen:])
    # Step 1 (q itself is not in ]0,q[, and r = q is not invertible)
    if r == 0 or r >= q:
        raise Exception("ECGDSA verify: r not in ]0,q[")
    if s == 0 or s >= q:
        raise Exception("ECGDSA verify: s not in ]0,q[")
    # Steps 2 and 3
    (h, _, _) = hashfunc(message)
    h = truncate(h, getbitlen(q), "LEFT")
    e = stringtoint(h) % q
    # Steps 4 to 6
    r_inv = modinv(r, q)
    u = (r_inv * e) % q
    v = (r_inv * s) % q
    W_ = u * G + v * pubkey.Y
    # A valid signature never leads to the point at infinity
    if W_.x is None:
        return False
    # Steps 7 and 8
    return (W_.x % q) == r
# *| IUF - ECSDSA/ECOSDSA signature
# *|
# *| I 1. Get a random value k in ]0, q[
# *| I 2. Compute W = kG = (Wx, Wy)
# *| IUF 3. Compute r = H(Wx [|| Wy] || m)
# *| - In the normal version (ECSDSA), r = h(Wx || Wy || m).
# *| - In the optimized version (ECOSDSA), r = h(Wx || m).
# *| F 4. Compute e = OS2I(r) mod q
# *| F 5. if e == 0, restart at step 1.
# *| F 6. Compute s = (k + ex) mod q.
# *| F 7. if s == 0, restart at step 1.
# *| F 8. Return (r, s)
def ecsdsa_common_sign(hashfunc, keypair, message, optimized, k=None):
    """
    Common ECSDSA (optimized false: r = H(Wx || Wy || m)) and ECOSDSA
    (optimized true: r = H(Wx || m)) signature of 'message'.

    If 'k' is None, a fresh random nonce is drawn on every attempt. A caller
    provided 'k' is used as is, and an Exception is raised if it cannot
    yield a valid signature, instead of looping forever.

    Returns (sig, k) where sig is r || s, r being the raw hash value and s
    being big endian encoded on the byte length of q.
    """
    privkey = keypair.privkey
    curve = privkey.curve
    q = curve.q
    G = Point(curve, curve.gx, curve.gy)
    pbytelen = getbytelen(curve.p)
    user_k = k
    while True:
        # Step 1: draw a new nonce on each attempt unless one was imposed
        k = getrandomint(q) if user_k is None else user_k
        s = 0
        # A multiple of q (0 or q) would make W the point at infinity
        if k % q != 0:
            # Steps 2 and 3
            W = k * G
            to_hash = expand(inttostring(W.x), 8*pbytelen, "LEFT")
            if not optimized:
                to_hash += expand(inttostring(W.y), 8*pbytelen, "LEFT")
            (r, _, _) = hashfunc(to_hash + message)
            # Steps 4 and 5
            e = stringtoint(r) % q
            if e != 0:
                # Step 6
                s = (k + (e * privkey.x)) % q
        # Step 7
        if s != 0:
            break
        if user_k is not None:
            raise Exception("EC[O]SDSA sign: provided nonce k does not yield a valid signature")
    # Step 8
    return (r + expand(inttostring(s), 8*getbytelen(q), "LEFT"), k)

def ecsdsa_sign(hashfunc, keypair, message, k=None):
    """ECSDSA signature, r = H(Wx || Wy || m): see ecsdsa_common_sign."""
    return ecsdsa_common_sign(hashfunc, keypair, message, False, k)

def ecosdsa_sign(hashfunc, keypair, message, k=None):
    """Optimized ECSDSA (ECOSDSA) signature, r = H(Wx || m): see ecsdsa_common_sign."""
    return ecsdsa_common_sign(hashfunc, keypair, message, True, k)

# *| IUF - ECSDSA/ECOSDSA verification
# *|
# *| I 1. if s is not in ]0,q[, reject the signature.
# *| I 2. Compute e = -r mod q
# *| I 3. If e == 0, reject the signature.
# *| I 4. Compute W' = sG + eY
# *| IUF 5. Compute r' = H(W'x [|| W'y] || m)
# *| - In the normal version (ECSDSA), r = h(W'x || W'y || m).
# *| - In the optimized version (ECOSDSA), r = h(W'x || m).
# *| F 6.
Accept the signature if and only if r and r' are the same 1005def ecsdsa_common_verify(hashfunc, keypair, message, sig, optimized): 1006 pubkey = keypair.pubkey 1007 # Get important parameters from the curve 1008 p = pubkey.curve.p 1009 q = pubkey.curve.q 1010 gx = pubkey.curve.gx 1011 gy = pubkey.curve.gy 1012 G = Point(pubkey.curve, gx, gy) 1013 (_, hlen, _) = hashfunc("") 1014 # Extract coordinates from r and s from signature 1015 if len(sig) != hlen + getbytelen(q): 1016 raise Exception("EC[O]SDSA verify: bad signature length!") 1017 r = stringtoint(sig[:int(hlen)]) 1018 s = stringtoint(sig[int(hlen):int(hlen+getbytelen(q))]) 1019 if s == 0 or s > q: 1020 raise Exception("EC[O]DSA verify: s not in ]0,q[") 1021 e = (-r) % q 1022 if e == 0: 1023 raise Exception("EC[O]DSA verify: e is null") 1024 W_ = s * G + e * pubkey.Y 1025 if optimized == False: 1026 (r_, _, _) = hashfunc(expand(inttostring(W_.x), 8*getbytelen(p), "LEFT") + expand(inttostring(W_.y), 8*getbytelen(p), "LEFT") + message) 1027 else: 1028 (r_, _, _) = hashfunc(expand(inttostring(W_.x), 8*getbytelen(p), "LEFT") + message) 1029 if sig[:int(hlen)] == r_: 1030 return True 1031 else: 1032 return False 1033 1034def ecsdsa_verify(hashfunc, keypair, message, sig): 1035 return ecsdsa_common_verify(hashfunc, keypair, message, sig, False) 1036 1037def ecosdsa_verify(hashfunc, keypair, message, sig): 1038 return ecsdsa_common_verify(hashfunc, keypair, message, sig, True) 1039 1040 1041########################################################## 1042### Generate self-tests for all the algorithms 1043 1044all_hash_funcs = [ (sha224, "SHA224"), (sha256, "SHA256"), (sha384, "SHA384"), (sha512, "SHA512"), (sha3_224, "SHA3_224"), (sha3_256, "SHA3_256"), (sha3_384, "SHA3_384"), (sha3_512, "SHA3_512") ] 1045 1046all_sig_algs = [ (ecdsa_sign, ecdsa_verify, genKeyPair, "ECDSA"), 1047 (eckcdsa_sign, eckcdsa_verify, eckcdsa_genKeyPair, "ECKCDSA"), 1048 (ecfsdsa_sign, ecfsdsa_verify, genKeyPair, "ECFSDSA"), 1049 
(ecrdsa_sign, ecrdsa_verify, genKeyPair, "ECRDSA"), 1050 (ecgdsa_sign, ecgdsa_verify, eckcdsa_genKeyPair, "ECGDSA"), 1051 (ecsdsa_sign, ecsdsa_verify, genKeyPair, "ECSDSA"), 1052 (ecosdsa_sign, ecosdsa_verify, genKeyPair, "ECOSDSA"), ] 1053 1054 1055curr_test = 0 1056def pretty_print_curr_test(num_test, total_gen_tests): 1057 num_decimal = int(math.log10(total_gen_tests))+1 1058 format_buf = "%0"+str(num_decimal)+"d/%0"+str(num_decimal)+"d" 1059 sys.stdout.write('\b'*((2*num_decimal)+1)) 1060 sys.stdout.flush() 1061 sys.stdout.write(format_buf % (num_test, total_gen_tests)) 1062 if num_test == total_gen_tests: 1063 print("") 1064 return 1065 1066def gen_self_test(curve, hashfunc, sig_alg_sign, sig_alg_verify, sig_alg_genkeypair, num, hashfunc_name, sig_alg_name, total_gen_tests): 1067 global curr_test 1068 curr_test = curr_test + 1 1069 if num != 0: 1070 pretty_print_curr_test(curr_test, total_gen_tests) 1071 output_list = [] 1072 for test_num in range(0, num): 1073 out_vectors = "" 1074 # Generate a random key pair 1075 keypair = sig_alg_genkeypair(curve) 1076 # Generate a random message with a random size 1077 size = getrandomint(256) 1078 if is_python_2(): 1079 message = ''.join([random.choice(string.ascii_letters + string.digits) for n in xrange(size)]) 1080 else: 1081 message = ''.join([random.choice(string.ascii_letters + string.digits) for n in range(size)]) 1082 test_name = sig_alg_name + "_" + hashfunc_name + "_" + curve.name.upper() + "_" + str(test_num) 1083 # Sign the message 1084 (sig, k) = sig_alg_sign(hashfunc, keypair, message) 1085 # Check that everything is OK with a verify 1086 if sig_alg_verify(hashfunc, keypair, message, sig) != True: 1087 raise Exception("Error during self test generation: sig verify failed! 
"+test_name+ " / msg="+message+" / sig="+binascii.hexlify(sig)+" / k="+hex(k)+" / privkey.x="+hex(keypair.privkey.x)) 1088 if sig_alg_name == "ECRDSA": 1089 out_vectors += "#ifndef USE_ISO14888_3_ECRDSA\n" 1090 # Now generate the test vector 1091 out_vectors += "#ifdef WITH_HASH_"+hashfunc_name.upper()+"\n" 1092 out_vectors += "#ifdef WITH_CURVE_"+curve.name.upper()+"\n" 1093 out_vectors += "#ifdef WITH_SIG_"+sig_alg_name.upper()+"\n" 1094 out_vectors += "/* "+test_name+" known test vectors */\n" 1095 out_vectors += "static int "+test_name+"_test_vectors_get_random(nn_t out, nn_src_t q)\n{\n" 1096 # k_buf MUST be exported padded to the length of q 1097 out_vectors += "\tconst u8 k_buf[] = "+bigint_to_C_array(k, getbytelen(curve.q)) 1098 out_vectors += "\tint ret, cmp;\n\tret = nn_init_from_buf(out, k_buf, sizeof(k_buf)); EG(ret, err);\n\tret = nn_cmp(out, q, &cmp); EG(ret, err);\n\tret = (cmp >= 0) ? -1 : 0;\nerr:\n\treturn ret;\n}\n" 1099 out_vectors += "static const u8 "+test_name+"_test_vectors_priv_key[] = \n"+bigint_to_C_array(keypair.privkey.x, getbytelen(keypair.privkey.x)) 1100 out_vectors += "static const u8 "+test_name+"_test_vectors_expected_sig[] = \n"+bigint_to_C_array(stringtoint(sig), len(sig)) 1101 out_vectors += "static const ec_test_case "+test_name+"_test_case = {\n" 1102 out_vectors += "\t.name = \""+test_name+"\",\n" 1103 out_vectors += "\t.ec_str_p = &"+curve.name+"_str_params,\n" 1104 out_vectors += "\t.priv_key = "+test_name+"_test_vectors_priv_key,\n" 1105 out_vectors += "\t.priv_key_len = sizeof("+test_name+"_test_vectors_priv_key),\n" 1106 out_vectors += "\t.nn_random = "+test_name+"_test_vectors_get_random,\n" 1107 out_vectors += "\t.hash_type = "+hashfunc_name+",\n" 1108 out_vectors += "\t.msg = \""+message+"\",\n" 1109 out_vectors += "\t.msglen = "+str(len(message))+",\n" 1110 out_vectors += "\t.sig_type = "+sig_alg_name+",\n" 1111 out_vectors += "\t.exp_sig = "+test_name+"_test_vectors_expected_sig,\n" 1112 out_vectors += 
"\t.exp_siglen = sizeof("+test_name+"_test_vectors_expected_sig),\n};\n" 1113 out_vectors += "#endif /* WITH_HASH_"+hashfunc_name+" */\n" 1114 out_vectors += "#endif /* WITH_CURVE_"+curve.name+" */\n" 1115 out_vectors += "#endif /* WITH_SIG_"+sig_alg_name+" */\n" 1116 if sig_alg_name == "ECRDSA": 1117 out_vectors += "#endif /* !USE_ISO14888_3_ECRDSA */\n" 1118 out_name = "" 1119 if sig_alg_name == "ECRDSA": 1120 out_name += "#ifndef USE_ISO14888_3_ECRDSA"+"/* For "+test_name+" */\n" 1121 out_name += "#ifdef WITH_HASH_"+hashfunc_name.upper()+"/* For "+test_name+" */\n" 1122 out_name += "#ifdef WITH_CURVE_"+curve.name.upper()+"/* For "+test_name+" */\n" 1123 out_name += "#ifdef WITH_SIG_"+sig_alg_name.upper()+"/* For "+test_name+" */\n" 1124 out_name += "\t&"+test_name+"_test_case,\n" 1125 out_name += "#endif /* WITH_HASH_"+hashfunc_name+" for "+test_name+" */\n" 1126 out_name += "#endif /* WITH_CURVE_"+curve.name+" for "+test_name+" */\n" 1127 out_name += "#endif /* WITH_SIG_"+sig_alg_name+" for "+test_name+" */" 1128 if sig_alg_name == "ECRDSA": 1129 out_name += "\n#endif /* !USE_ISO14888_3_ECRDSA */"+"/* For "+test_name+" */" 1130 output_list.append((out_name, out_vectors)) 1131 # In the specific case of ECRDSA, we also generate an ISO/IEC compatible test vector 1132 if sig_alg_name == "ECRDSA": 1133 out_vectors = "" 1134 (sig, k) = sig_alg_sign(hashfunc, keypair, message, use_iso14888_divergence=True) 1135 # Check that everything is OK with a verify 1136 if sig_alg_verify(hashfunc, keypair, message, sig, use_iso14888_divergence=True) != True: 1137 raise Exception("Error during self test generation: sig verify failed! 
"+test_name+ " / msg="+message+" / sig="+binascii.hexlify(sig)+" / k="+hex(k)+" / privkey.x="+hex(keypair.privkey.x)) 1138 out_vectors += "#ifdef USE_ISO14888_3_ECRDSA\n" 1139 # Now generate the test vector 1140 out_vectors += "#ifdef WITH_HASH_"+hashfunc_name.upper()+"\n" 1141 out_vectors += "#ifdef WITH_CURVE_"+curve.name.upper()+"\n" 1142 out_vectors += "#ifdef WITH_SIG_"+sig_alg_name.upper()+"\n" 1143 out_vectors += "/* "+test_name+" known test vectors */\n" 1144 out_vectors += "static int "+test_name+"_test_vectors_get_random(nn_t out, nn_src_t q)\n{\n" 1145 # k_buf MUST be exported padded to the length of q 1146 out_vectors += "\tconst u8 k_buf[] = "+bigint_to_C_array(k, getbytelen(curve.q)) 1147 out_vectors += "\tint ret, cmp;\n\tret = nn_init_from_buf(out, k_buf, sizeof(k_buf)); EG(ret, err);\n\tret = nn_cmp(out, q, &cmp); EG(ret, err);\n\tret = (cmp >= 0) ? -1 : 0;\nerr:\n\treturn ret;\n}\n" 1148 out_vectors += "static const u8 "+test_name+"_test_vectors_priv_key[] = \n"+bigint_to_C_array(keypair.privkey.x, getbytelen(keypair.privkey.x)) 1149 out_vectors += "static const u8 "+test_name+"_test_vectors_expected_sig[] = \n"+bigint_to_C_array(stringtoint(sig), len(sig)) 1150 out_vectors += "static const ec_test_case "+test_name+"_test_case = {\n" 1151 out_vectors += "\t.name = \""+test_name+"\",\n" 1152 out_vectors += "\t.ec_str_p = &"+curve.name+"_str_params,\n" 1153 out_vectors += "\t.priv_key = "+test_name+"_test_vectors_priv_key,\n" 1154 out_vectors += "\t.priv_key_len = sizeof("+test_name+"_test_vectors_priv_key),\n" 1155 out_vectors += "\t.nn_random = "+test_name+"_test_vectors_get_random,\n" 1156 out_vectors += "\t.hash_type = "+hashfunc_name+",\n" 1157 out_vectors += "\t.msg = \""+message+"\",\n" 1158 out_vectors += "\t.msglen = "+str(len(message))+",\n" 1159 out_vectors += "\t.sig_type = "+sig_alg_name+",\n" 1160 out_vectors += "\t.exp_sig = "+test_name+"_test_vectors_expected_sig,\n" 1161 out_vectors += "\t.exp_siglen = 
sizeof("+test_name+"_test_vectors_expected_sig),\n};\n" 1162 out_vectors += "#endif /* WITH_HASH_"+hashfunc_name+" */\n" 1163 out_vectors += "#endif /* WITH_CURVE_"+curve.name+" */\n" 1164 out_vectors += "#endif /* WITH_SIG_"+sig_alg_name+" */\n" 1165 out_vectors += "#endif /* USE_ISO14888_3_ECRDSA */\n" 1166 out_name = "" 1167 out_name += "#ifdef USE_ISO14888_3_ECRDSA"+"/* For "+test_name+" */\n" 1168 out_name += "#ifdef WITH_HASH_"+hashfunc_name.upper()+"/* For "+test_name+" */\n" 1169 out_name += "#ifdef WITH_CURVE_"+curve.name.upper()+"/* For "+test_name+" */\n" 1170 out_name += "#ifdef WITH_SIG_"+sig_alg_name.upper()+"/* For "+test_name+" */\n" 1171 out_name += "\t&"+test_name+"_test_case,\n" 1172 out_name += "#endif /* WITH_HASH_"+hashfunc_name+" for "+test_name+" */\n" 1173 out_name += "#endif /* WITH_CURVE_"+curve.name+" for "+test_name+" */\n" 1174 out_name += "#endif /* WITH_SIG_"+sig_alg_name+" for "+test_name+" */\n" 1175 out_name += "#endif /* USE_ISO14888_3_ECRDSA */"+"/* For "+test_name+" */" 1176 output_list.append((out_name, out_vectors)) 1177 1178 return output_list 1179 1180def gen_self_tests(curve, num): 1181 global curr_test 1182 curr_test = 0 1183 total_gen_tests = len(all_hash_funcs) * len(all_sig_algs) 1184 vectors = [[ gen_self_test(curve, hashf, sign, verify, genkp, num, hash_name, sig_alg_name, total_gen_tests) 1185 for (hashf, hash_name) in all_hash_funcs ] for (sign, verify, genkp, sig_alg_name) in all_sig_algs ] 1186 return vectors 1187 1188########################################################## 1189### ASN.1 stuff 1190def parse_DER_extract_size(derbuf): 1191 # Extract the size 1192 if ord(derbuf[0]) & 0x80 != 0: 1193 encoding_len_bytes = ord(derbuf[0]) & ~0x80 1194 # Skip 1195 base = 1 1196 else: 1197 encoding_len_bytes = 1 1198 base = 0 1199 if len(derbuf) < encoding_len_bytes+1: 1200 return (False, 0, 0) 1201 else: 1202 length = stringtoint(derbuf[base:base+encoding_len_bytes]) 1203 if len(derbuf) < length+encoding_len_bytes: 
1204 return (False, 0, 0) 1205 else: 1206 return (True, encoding_len_bytes+base, length) 1207 1208def extract_DER_object(derbuf, object_tag): 1209 # Check type 1210 if ord(derbuf[0]) != object_tag: 1211 # Not the type we expect ... 1212 return (False, 0, "") 1213 else: 1214 derbuf = derbuf[1:] 1215 # Extract the size 1216 (check, encoding_len, size) = parse_DER_extract_size(derbuf) 1217 if check == False: 1218 return (False, 0, "") 1219 else: 1220 if len(derbuf) < encoding_len + size: 1221 return (False, 0, "") 1222 else: 1223 return (True, size+encoding_len+1, derbuf[encoding_len:encoding_len+size]) 1224 1225def extract_DER_sequence(derbuf): 1226 return extract_DER_object(derbuf, 0x30) 1227 1228def extract_DER_integer(derbuf): 1229 return extract_DER_object(derbuf, 0x02) 1230 1231def extract_DER_octetstring(derbuf): 1232 return extract_DER_object(derbuf, 0x04) 1233 1234def extract_DER_bitstring(derbuf): 1235 return extract_DER_object(derbuf, 0x03) 1236 1237def extract_DER_oid(derbuf): 1238 return extract_DER_object(derbuf, 0x06) 1239 1240# See ECParameters sequence in RFC 3279 1241def parse_DER_ECParameters(derbuf): 1242 # XXX: this is a very ugly way of extracting the information 1243 # regarding an EC curve, but since the ASN.1 structure is quite 1244 # "static", this might be sufficient without embedding a full 1245 # ASN.1 parser ... 
1246 # Default return (a, b, prime, order, cofactor, gx, gy) 1247 default_ret = (0, 0, 0, 0, 0, 0, 0) 1248 # Get ECParameters wrapping sequence 1249 (check, size_ECParameters, ECParameters) = extract_DER_sequence(derbuf) 1250 if check == False: 1251 return (False, default_ret) 1252 # Get integer 1253 (check, size_ECPVer, ECPVer) = extract_DER_integer(ECParameters) 1254 if check == False: 1255 return (False, default_ret) 1256 # Get sequence 1257 (check, size_FieldID, FieldID) = extract_DER_sequence(ECParameters[size_ECPVer:]) 1258 if check == False: 1259 return (False, default_ret) 1260 # Get OID 1261 (check, size_Oid, Oid) = extract_DER_oid(FieldID) 1262 if check == False: 1263 return (False, default_ret) 1264 # Does the OID correspond to a prime field? 1265 if(Oid != "\x2A\x86\x48\xCE\x3D\x01\x01"): 1266 print("DER parse error: only prime fields are supported ...") 1267 return (False, default_ret) 1268 # Get prime p of prime field 1269 (check, size_P, P) = extract_DER_integer(FieldID[size_Oid:]) 1270 if check == False: 1271 return (False, default_ret) 1272 # Get curve (sequence) 1273 (check, size_Curve, Curve) = extract_DER_sequence(ECParameters[size_ECPVer+size_FieldID:]) 1274 if check == False: 1275 return (False, default_ret) 1276 # Get A in curve 1277 (check, size_A, A) = extract_DER_octetstring(Curve) 1278 if check == False: 1279 return (False, default_ret) 1280 # Get B in curve 1281 (check, size_B, B) = extract_DER_octetstring(Curve[size_A:]) 1282 if check == False: 1283 return (False, default_ret) 1284 # Get ECPoint 1285 (check, size_ECPoint, ECPoint) = extract_DER_octetstring(ECParameters[size_ECPVer+size_FieldID+size_Curve:]) 1286 if check == False: 1287 return (False, default_ret) 1288 # Get Order 1289 (check, size_Order, Order) = extract_DER_integer(ECParameters[size_ECPVer+size_FieldID+size_Curve+size_ECPoint:]) 1290 if check == False: 1291 return (False, default_ret) 1292 # Get Cofactor 1293 (check, size_Cofactor, Cofactor) = 
extract_DER_integer(ECParameters[size_ECPVer+size_FieldID+size_Curve+size_ECPoint+size_Order:]) 1294 if check == False: 1295 return (False, default_ret) 1296 # If we end up here, everything is OK, we can extract all our elements 1297 prime = stringtoint(P) 1298 a = stringtoint(A) 1299 b = stringtoint(B) 1300 order = stringtoint(Order) 1301 cofactor = stringtoint(Cofactor) 1302 # Extract Gx and Gy, see X9.62-1998 1303 if len(ECPoint) < 1: 1304 return (False, default_ret) 1305 ECPoint_type = ord(ECPoint[0]) 1306 if (ECPoint_type == 0x04) or (ECPoint_type == 0x06) or (ECPoint_type == 0x07): 1307 # Uncompressed and hybrid points 1308 if len(ECPoint[1:]) % 2 != 0: 1309 return (False, default_ret) 1310 ECPoint = ECPoint[1:] 1311 gx = stringtoint(ECPoint[:int(len(ECPoint)/2)]) 1312 gy = stringtoint(ECPoint[int(len(ECPoint)/2):]) 1313 elif (ECPoint_type == 0x02) or (ECPoint_type == 0x03): 1314 # Compressed point: uncompress it, see X9.62-1998 section 4.2.1 1315 ECPoint = ECPoint[1:] 1316 gx = stringtoint(ECPoint) 1317 alpha = (pow(gx, 3, prime) + (a * gx) + b) % prime 1318 beta = mod_sqrt(alpha, prime) 1319 if (beta == None) or ((beta == 0) and (alpha != 0)): 1320 return (False, 0) 1321 if (beta & 0x1) == (ECPoint_type & 0x1): 1322 gy = beta 1323 else: 1324 gy = prime - beta 1325 else: 1326 print("DER parse error: hybrid points are unsupported!") 1327 return (False, default_ret) 1328 return (True, (a, b, prime, order, cofactor, gx, gy)) 1329 1330########################################################## 1331### Text and format helpers 1332def bigint_to_C_array(bint, size): 1333 """ 1334 Format a python big int to a C hex array 1335 """ 1336 hexstr = format(int(bint), 'x') 1337 # Left pad to the size! 
1338 hexstr = ("0"*int((2*size)-len(hexstr)))+hexstr 1339 hexstr = ("0"*(len(hexstr) % 2))+hexstr 1340 out_str = "{\n" 1341 for i in range(0, len(hexstr) - 1, 2): 1342 if (i%16 == 0): 1343 if(i!=0): 1344 out_str += "\n" 1345 out_str += "\t" 1346 out_str += "0x"+hexstr[i:i+2]+", " 1347 out_str += "\n};\n" 1348 return out_str 1349 1350def check_in_file(fname, pat): 1351 # See if the pattern is in the file. 1352 with open(fname) as f: 1353 if not any(re.search(pat, line) for line in f): 1354 return False # pattern does not occur in file so we are done. 1355 else: 1356 return True 1357 1358def num_patterns_in_file(fname, pat): 1359 num_pat = 0 1360 with open(fname) as f: 1361 for line in f: 1362 if re.search(pat, line): 1363 num_pat = num_pat+1 1364 return num_pat 1365 1366def file_replace_pattern(fname, pat, s_after): 1367 # first, see if the pattern is even in the file. 1368 with open(fname) as f: 1369 if not any(re.search(pat, line) for line in f): 1370 return # pattern does not occur in file so we are done. 1371 1372 # pattern is in the file, so perform replace operation. 1373 with open(fname) as f: 1374 out_fname = fname + ".tmp" 1375 out = open(out_fname, "w") 1376 for line in f: 1377 out.write(re.sub(pat, s_after, line)) 1378 out.close() 1379 os.rename(out_fname, fname) 1380 1381def file_remove_pattern(fname, pat): 1382 # first, see if the pattern is even in the file. 1383 with open(fname) as f: 1384 if not any(re.search(pat, line) for line in f): 1385 return # pattern does not occur in file so we are done. 1386 1387 # pattern is in the file, so perform remove operation. 
1388 with open(fname) as f: 1389 out_fname = fname + ".tmp" 1390 out = open(out_fname, "w") 1391 for line in f: 1392 if not re.search(pat, line): 1393 out.write(line) 1394 out.close() 1395 1396 if os.path.exists(fname): 1397 remove_file(fname) 1398 os.rename(out_fname, fname) 1399 1400def remove_file(fname): 1401 # Remove file 1402 os.remove(fname) 1403 1404def remove_files_pattern(fpattern): 1405 [remove_file(x) for x in glob.glob(fpattern)] 1406 1407def buffer_remove_pattern(buff, pat): 1408 if is_python_2() == False: 1409 buff = buff.decode('latin-1') 1410 if re.search(pat, buff) == None: 1411 return (False, buff) # pattern does not occur in file so we are done. 1412 # Remove the pattern 1413 buff = re.sub(pat, "", buff) 1414 return (True, buff) 1415 1416def is_base64(s): 1417 s = ''.join([s.strip() for s in s.split("\n")]) 1418 try: 1419 enc = base64.b64encode(base64.b64decode(s)).strip() 1420 if type(enc) is bytes: 1421 return enc == s.encode('latin-1') 1422 else: 1423 return enc == s 1424 except TypeError: 1425 return False 1426 1427### Curve helpers 1428def export_curve_int(curvename, intname, bigint, size): 1429 if bigint == None: 1430 out = "static const u8 "+curvename+"_"+intname+"[] = {\n\t0x00,\n};\n" 1431 out += "TO_EC_STR_PARAM_FIXED_SIZE("+curvename+"_"+intname+", 0);\n\n" 1432 else: 1433 out = "static const u8 "+curvename+"_"+intname+"[] = "+bigint_to_C_array(bigint, size)+"\n" 1434 out += "TO_EC_STR_PARAM("+curvename+"_"+intname+");\n\n" 1435 return out 1436 1437def export_curve_string(curvename, stringname, stringvalue): 1438 out = "static const u8 "+curvename+"_"+stringname+"[] = \""+stringvalue+"\";\n" 1439 out += "TO_EC_STR_PARAM("+curvename+"_"+stringname+");\n\n" 1440 return out 1441 1442def export_curve_struct(curvename, paramname, paramnamestr): 1443 return "\t."+paramname+" = &"+curvename+"_"+paramnamestr+"_str_param, \n" 1444 1445def curve_params(name, prime, pbitlen, a, b, gx, gy, order, cofactor, oid, alpha_montgomery, gamma_montgomery, 
alpha_edwards): 1446 """ 1447 Take as input some elliptic curve parameters and generate the 1448 C parameters in a string 1449 """ 1450 bytesize = int(pbitlen / 8) 1451 if pbitlen % 8 != 0: 1452 bytesize += 1 1453 # Compute the rounded word size for each word size 1454 if bytesize % 8 != 0: 1455 wordsbitsize64 = 8*((int(bytesize/8)+1)*8) 1456 else: 1457 wordsbitsize64 = 8*bytesize 1458 if bytesize % 4 != 0: 1459 wordsbitsize32 = 8*((int(bytesize/4)+1)*4) 1460 else: 1461 wordsbitsize32 = 8*bytesize 1462 if bytesize % 2 != 0: 1463 wordsbitsize16 = 8*((int(bytesize/2)+1)*2) 1464 else: 1465 wordsbitsize16 = 8*bytesize 1466 # Compute some parameters 1467 (r64, r_square64, mpinv64) = compute_monty_coef(prime, wordsbitsize64, 64) 1468 (r32, r_square32, mpinv32) = compute_monty_coef(prime, wordsbitsize32, 32) 1469 (r16, r_square16, mpinv16) = compute_monty_coef(prime, wordsbitsize16, 16) 1470 # Compute p_reciprocal for each word size 1471 (pshift64, primenorm64, p_reciprocal64) = compute_div_coef(prime, wordsbitsize64, 64) 1472 (pshift32, primenorm32, p_reciprocal32) = compute_div_coef(prime, wordsbitsize32, 32) 1473 (pshift16, primenorm16, p_reciprocal16) = compute_div_coef(prime, wordsbitsize16, 16) 1474 # Compute the number of points on the curve 1475 npoints = order * cofactor 1476 1477 # Now output the parameters 1478 ec_params_string = "#include <libecc/lib_ecc_config.h>\n" 1479 ec_params_string += "#ifdef WITH_CURVE_"+name.upper()+"\n\n" 1480 ec_params_string += "#ifndef __EC_PARAMS_"+name.upper()+"_H__\n" 1481 ec_params_string += "#define __EC_PARAMS_"+name.upper()+"_H__\n" 1482 ec_params_string += "#include <libecc/curves/known/ec_params_external.h>\n" 1483 ec_params_string += export_curve_int(name, "p", prime, bytesize) 1484 1485 ec_params_string += "#define CURVE_"+name.upper()+"_P_BITLEN "+str(pbitlen)+"\n" 1486 ec_params_string += export_curve_int(name, "p_bitlen", pbitlen, getbytelen(pbitlen)) 1487 1488 ec_params_string += "#if (WORD_BYTES == 8) /* 64-bit 
words */\n" 1489 ec_params_string += export_curve_int(name, "r", r64, getbytelen(r64)) 1490 ec_params_string += export_curve_int(name, "r_square", r_square64, getbytelen(r_square64)) 1491 ec_params_string += export_curve_int(name, "mpinv", mpinv64, getbytelen(mpinv64)) 1492 ec_params_string += export_curve_int(name, "p_shift", pshift64, getbytelen(pshift64)) 1493 ec_params_string += export_curve_int(name, "p_normalized", primenorm64, getbytelen(primenorm64)) 1494 ec_params_string += export_curve_int(name, "p_reciprocal", p_reciprocal64, getbytelen(p_reciprocal64)) 1495 ec_params_string += "#elif (WORD_BYTES == 4) /* 32-bit words */\n" 1496 ec_params_string += export_curve_int(name, "r", r32, getbytelen(r32)) 1497 ec_params_string += export_curve_int(name, "r_square", r_square32, getbytelen(r_square32)) 1498 ec_params_string += export_curve_int(name, "mpinv", mpinv32, getbytelen(mpinv32)) 1499 ec_params_string += export_curve_int(name, "p_shift", pshift32, getbytelen(pshift32)) 1500 ec_params_string += export_curve_int(name, "p_normalized", primenorm32, getbytelen(primenorm32)) 1501 ec_params_string += export_curve_int(name, "p_reciprocal", p_reciprocal32, getbytelen(p_reciprocal32)) 1502 ec_params_string += "#elif (WORD_BYTES == 2) /* 16-bit words */\n" 1503 ec_params_string += export_curve_int(name, "r", r16, getbytelen(r16)) 1504 ec_params_string += export_curve_int(name, "r_square", r_square16, getbytelen(r_square16)) 1505 ec_params_string += export_curve_int(name, "mpinv", mpinv16, getbytelen(mpinv16)) 1506 ec_params_string += export_curve_int(name, "p_shift", pshift16, getbytelen(pshift16)) 1507 ec_params_string += export_curve_int(name, "p_normalized", primenorm16, getbytelen(primenorm16)) 1508 ec_params_string += export_curve_int(name, "p_reciprocal", p_reciprocal16, getbytelen(p_reciprocal16)) 1509 ec_params_string += "#else /* unknown word size */\n" 1510 ec_params_string += "#error \"Unsupported word size\"\n" 1511 ec_params_string += "#endif\n\n" 1512 
1513 ec_params_string += export_curve_int(name, "a", a, bytesize) 1514 ec_params_string += export_curve_int(name, "b", b, bytesize) 1515 1516 curve_order_bitlen = getbitlen(npoints) 1517 ec_params_string += "#define CURVE_"+name.upper()+"_CURVE_ORDER_BITLEN "+str(curve_order_bitlen)+"\n" 1518 ec_params_string += export_curve_int(name, "curve_order", npoints, getbytelen(npoints)) 1519 1520 ec_params_string += export_curve_int(name, "gx", gx, bytesize) 1521 ec_params_string += export_curve_int(name, "gy", gy, bytesize) 1522 ec_params_string += export_curve_int(name, "gz", 0x01, bytesize) 1523 1524 qbitlen = getbitlen(order) 1525 1526 ec_params_string += export_curve_int(name, "gen_order", order, getbytelen(order)) 1527 ec_params_string += "#define CURVE_"+name.upper()+"_Q_BITLEN "+str(qbitlen)+"\n" 1528 ec_params_string += export_curve_int(name, "gen_order_bitlen", qbitlen, getbytelen(qbitlen)) 1529 1530 ec_params_string += export_curve_int(name, "cofactor", cofactor, getbytelen(cofactor)) 1531 1532 ec_params_string += export_curve_int(name, "alpha_montgomery", alpha_montgomery, getbytelen(alpha_montgomery)) 1533 ec_params_string += export_curve_int(name, "gamma_montgomery", gamma_montgomery, getbytelen(gamma_montgomery)) 1534 ec_params_string += export_curve_int(name, "alpha_edwards", alpha_edwards, getbytelen(alpha_edwards)) 1535 1536 ec_params_string += export_curve_string(name, "name", name.upper()); 1537 1538 if oid == None: 1539 oid = "" 1540 ec_params_string += export_curve_string(name, "oid", oid); 1541 1542 ec_params_string += "static const ec_str_params "+name+"_str_params = {\n"+\ 1543 export_curve_struct(name, "p", "p") +\ 1544 export_curve_struct(name, "p_bitlen", "p_bitlen") +\ 1545 export_curve_struct(name, "r", "r") +\ 1546 export_curve_struct(name, "r_square", "r_square") +\ 1547 export_curve_struct(name, "mpinv", "mpinv") +\ 1548 export_curve_struct(name, "p_shift", "p_shift") +\ 1549 export_curve_struct(name, "p_normalized", "p_normalized") +\ 1550 
export_curve_struct(name, "p_reciprocal", "p_reciprocal") +\ 1551 export_curve_struct(name, "a", "a") +\ 1552 export_curve_struct(name, "b", "b") +\ 1553 export_curve_struct(name, "curve_order", "curve_order") +\ 1554 export_curve_struct(name, "gx", "gx") +\ 1555 export_curve_struct(name, "gy", "gy") +\ 1556 export_curve_struct(name, "gz", "gz") +\ 1557 export_curve_struct(name, "gen_order", "gen_order") +\ 1558 export_curve_struct(name, "gen_order_bitlen", "gen_order_bitlen") +\ 1559 export_curve_struct(name, "cofactor", "cofactor") +\ 1560 export_curve_struct(name, "alpha_montgomery", "alpha_montgomery") +\ 1561 export_curve_struct(name, "gamma_montgomery", "gamma_montgomery") +\ 1562 export_curve_struct(name, "alpha_edwards", "alpha_edwards") +\ 1563 export_curve_struct(name, "oid", "oid") +\ 1564 export_curve_struct(name, "name", "name") 1565 ec_params_string += "};\n\n" 1566 1567 ec_params_string += "/*\n"+\ 1568 " * Compute max bit length of all curves for p and q\n"+\ 1569 " */\n"+\ 1570 "#ifndef CURVES_MAX_P_BIT_LEN\n"+\ 1571 "#define CURVES_MAX_P_BIT_LEN 0\n"+\ 1572 "#endif\n"+\ 1573 "#if (CURVES_MAX_P_BIT_LEN < CURVE_"+name.upper()+"_P_BITLEN)\n"+\ 1574 "#undef CURVES_MAX_P_BIT_LEN\n"+\ 1575 "#define CURVES_MAX_P_BIT_LEN CURVE_"+name.upper()+"_P_BITLEN\n"+\ 1576 "#endif\n"+\ 1577 "#ifndef CURVES_MAX_Q_BIT_LEN\n"+\ 1578 "#define CURVES_MAX_Q_BIT_LEN 0\n"+\ 1579 "#endif\n"+\ 1580 "#if (CURVES_MAX_Q_BIT_LEN < CURVE_"+name.upper()+"_Q_BITLEN)\n"+\ 1581 "#undef CURVES_MAX_Q_BIT_LEN\n"+\ 1582 "#define CURVES_MAX_Q_BIT_LEN CURVE_"+name.upper()+"_Q_BITLEN\n"+\ 1583 "#endif\n"+\ 1584 "#ifndef CURVES_MAX_CURVE_ORDER_BIT_LEN\n"+\ 1585 "#define CURVES_MAX_CURVE_ORDER_BIT_LEN 0\n"+\ 1586 "#endif\n"+\ 1587 "#if (CURVES_MAX_CURVE_ORDER_BIT_LEN < CURVE_"+name.upper()+"_CURVE_ORDER_BITLEN)\n"+\ 1588 "#undef CURVES_MAX_CURVE_ORDER_BIT_LEN\n"+\ 1589 "#define CURVES_MAX_CURVE_ORDER_BIT_LEN CURVE_"+name.upper()+"_CURVE_ORDER_BITLEN\n"+\ 1590 "#endif\n\n" 1591 1592 
ec_params_string += "/*\n"+\ 1593 " * Compute and adapt max name and oid length\n"+\ 1594 " */\n"+\ 1595 "#ifndef MAX_CURVE_OID_LEN\n"+\ 1596 "#define MAX_CURVE_OID_LEN 0\n"+\ 1597 "#endif\n"+\ 1598 "#ifndef MAX_CURVE_NAME_LEN\n"+\ 1599 "#define MAX_CURVE_NAME_LEN 0\n"+\ 1600 "#endif\n"+\ 1601 "#if (MAX_CURVE_OID_LEN < "+str(len(oid)+1)+")\n"+\ 1602 "#undef MAX_CURVE_OID_LEN\n"+\ 1603 "#define MAX_CURVE_OID_LEN "+str(len(oid)+1)+"\n"+\ 1604 "#endif\n"+\ 1605 "#if (MAX_CURVE_NAME_LEN < "+str(len(name.upper())+1)+")\n"+\ 1606 "#undef MAX_CURVE_NAME_LEN\n"+\ 1607 "#define MAX_CURVE_NAME_LEN "+str(len(name.upper())+1)+"\n"+\ 1608 "#endif\n\n" 1609 1610 ec_params_string += "#endif /* __EC_PARAMS_"+name.upper()+"_H__ */\n\n"+"#endif /* WITH_CURVE_"+name.upper()+" */\n" 1611 1612 return ec_params_string 1613 1614def usage(): 1615 print("This script is intented to *statically* expand the ECC library with user defined curves.") 1616 print("By statically we mean that the source code of libecc is expanded with new curves parameters through") 1617 print("automatic code generation filling place holders in the existing code base of the library. Though the") 1618 print("choice of static code generation versus dynamic curves import (such as what OpenSSL does) might be") 1619 print("argued, this choice has been driven by simplicity and security design decisions: we want libecc to have") 1620 print("all its parameters (such as memory consumption) set at compile time and statically adapted to the curves.") 1621 print("Since libecc only supports curves over prime fields, the script can only add this kind of curves.") 1622 print("This script implements elliptic curves and ISO signature algorithms from scratch over Python's multi-precision") 1623 print("big numbers library. Addition and doubling over curves use naive formulas. Please DO NOT use the functions of this") 1624 print("script for production code: they are not securely implemented and are very inefficient. 
Their only purpose is to expand") 1625 print("libecc and produce test vectors.") 1626 print("") 1627 print("In order to add a curve, there are two ways:") 1628 print("Adding a user defined curve with explicit parameters:") 1629 print("-----------------------------------------------------") 1630 print(sys.argv[0]+" --name=\"YOURCURVENAME\" --prime=... --order=... --a=... --b=... --gx=... --gy=... --cofactor=... --oid=THEOID") 1631 print("\t> name: name of the curve in the form of a string") 1632 print("\t> prime: prime number representing the curve prime field") 1633 print("\t> order: prime number representing the generator order") 1634 print("\t> cofactor: cofactor of the curve") 1635 print("\t> a: 'a' coefficient of the short Weierstrass equation of the curve") 1636 print("\t> b: 'b' coefficient of the short Weierstrass equation of the curve") 1637 print("\t> gx: x coordinate of the generator G") 1638 print("\t> gy: y coordinate of the generator G") 1639 print("\t> oid: optional OID of the curve") 1640 print(" Notes:") 1641 print(" ******") 1642 print("\t1) These elements are verified to indeed satisfy the curve equation.") 1643 print("\t2) All the numbers can be given either in decimal or hexadecimal format with a prepending '0x'.") 1644 print("\t3) The script automatically generates all the necessary files for the curve to be included in the library." ) 1645 print("\tYou will find the new curve definition in the usual 'lib_ecc_config.h' file (one can activate it or not at compile time).") 1646 print("") 1647 print("Adding a user defined curve through RFC3279 ASN.1 parameters:") 1648 print("-------------------------------------------------------------") 1649 print(sys.argv[0]+" --name=\"YOURCURVENAME\" --ECfile=... 
--oid=THEOID") 1650 print("\t> ECfile: the DER or PEM encoded file containing the curve parameters (see RFC3279)") 1651 print(" Notes:") 1652 print("\tCurve parameters encoded in DER or PEM format can be generated with tools like OpenSSL (among others). As an illustrative example,") 1653 print("\tone can list all the supported curves under OpenSSL with:") 1654 print("\t $ openssl ecparam -list_curves") 1655 print("\tOnly the listed so called \"prime\" curves are supported. Then, one can extract an explicit curve representation in ASN.1") 1656 print("\tas defined in RFC3279, for example for BRAINPOOLP320R1:") 1657 print("\t $ openssl ecparam -param_enc explicit -outform DER -name brainpoolP320r1 -out brainpoolP320r1.der") 1658 print("") 1659 print("Removing user defined curves:") 1660 print("-----------------------------") 1661 print("\t*All the user defined curves can be removed with the --remove-all toggle.") 1662 print("\t*A specific named user define curve can be removed with the --remove toggle: in this case the --name option is used to ") 1663 print("\tlocate which named curve must be deleted.") 1664 print("") 1665 print("Test vectors:") 1666 print("-------------") 1667 print("\tTest vectors can be automatically generated and added to the library self tests when providing the --add-test-vectors=X toggle.") 1668 print("\tIn this case, X test vectors will be generated for *each* (curve, sign algorithm, hash algorithm) 3-uplet (beware of combinatorial") 1669 print("\tissues when X is big!). 
These tests are transparently added and compiled with the self tests.") 1670 return 1671 1672def get_int(instring): 1673 if len(instring) == 0: 1674 return 0 1675 if len(instring) >= 2: 1676 if instring[:2] == "0x": 1677 return int(instring, 16) 1678 return int(instring) 1679 1680def parse_cmd_line(args): 1681 """ 1682 Get elliptic curve parameters from command line 1683 """ 1684 name = oid = prime = a = b = gx = gy = g = order = cofactor = ECfile = remove = remove_all = add_test_vectors = None 1685 alpha_montgomery = gamma_montgomery = alpha_edwards = None 1686 try: 1687 opts, args = getopt.getopt(sys.argv[1:], ":h", ["help", "remove", "remove-all", "name=", "prime=", "a=", "b=", "generator=", "gx=", "gy=", "order=", "cofactor=", "alpha_montgomery=","gamma_montgomery=", "alpha_edwards=", "ECfile=", "oid=", "add-test-vectors="]) 1688 except getopt.GetoptError as err: 1689 # print help information and exit: 1690 print(err) # will print something like "option -a not recognized" 1691 usage() 1692 return False 1693 for o, arg in opts: 1694 if o in ("-h", "--help"): 1695 usage() 1696 return True 1697 elif o in ("--name"): 1698 name = arg 1699 # Prepend the custom string before name to avoid any collision 1700 name = "user_defined_"+name 1701 # Replace any unwanted name char 1702 name = re.sub("\-", "_", name) 1703 elif o in ("--oid="): 1704 oid = arg 1705 elif o in ("--prime"): 1706 prime = get_int(arg.replace(' ', '')) 1707 elif o in ("--a"): 1708 a = get_int(arg.replace(' ', '')) 1709 elif o in ("--b"): 1710 b = get_int(arg.replace(' ', '')) 1711 elif o in ("--gx"): 1712 gx = get_int(arg.replace(' ', '')) 1713 elif o in ("--gy"): 1714 gy = get_int(arg.replace(' ', '')) 1715 elif o in ("--generator"): 1716 g = arg.replace(' ', '') 1717 elif o in ("--order"): 1718 order = get_int(arg.replace(' ', '')) 1719 elif o in ("--cofactor"): 1720 cofactor = get_int(arg.replace(' ', '')) 1721 elif o in ("--alpha_montgomery"): 1722 alpha_montgomery = get_int(arg.replace(' ', '')) 
1723 elif o in ("--gamma_montgomery"): 1724 gamma_montgomery = get_int(arg.replace(' ', '')) 1725 elif o in ("--alpha_edwards"): 1726 alpha_edwards = get_int(arg.replace(' ', '')) 1727 elif o in ("--remove"): 1728 remove = True 1729 elif o in ("--remove-all"): 1730 remove_all = True 1731 elif o in ("--add-test-vectors"): 1732 add_test_vectors = get_int(arg.replace(' ', '')) 1733 elif o in ("--ECfile"): 1734 ECfile = arg 1735 else: 1736 print("unhandled option") 1737 usage() 1738 return False 1739 1740 # File paths 1741 script_path = os.path.abspath(os.path.dirname(sys.argv[0])) + "/" 1742 ec_params_path = script_path + "../include/libecc/curves/user_defined/" 1743 curves_list_path = script_path + "../include/libecc/curves/" 1744 lib_ecc_types_path = script_path + "../include/libecc/" 1745 lib_ecc_config_path = script_path + "../include/libecc/" 1746 ec_self_tests_path = script_path + "../src/tests/" 1747 meson_options_path = script_path + "../" 1748 1749 # If remove is True, we have been asked to remove already existing user defined curves 1750 if remove == True: 1751 if name == None: 1752 print("--remove option expects a curve name provided with --name") 1753 return False 1754 asked = "" 1755 while asked != "y" and asked != "n": 1756 asked = get_user_input("You asked to remove everything related to user defined "+name.replace("user_defined_", "")+" curve. Enter y to confirm, n to cancel [y/n]. 
") 1757 if asked == "n": 1758 print("NOT removing curve "+name.replace("user_defined_", "")+" (cancelled).") 1759 return True 1760 # Remove any user defined stuff with given name 1761 print("Removing user defined curve "+name.replace("user_defined_", "")+" ...") 1762 if name == None: 1763 print("Error: you must provide a curve name with --remove") 1764 return False 1765 file_remove_pattern(curves_list_path + "curves_list.h", ".*"+name+".*") 1766 file_remove_pattern(curves_list_path + "curves_list.h", ".*"+name.upper()+".*") 1767 file_remove_pattern(lib_ecc_types_path + "lib_ecc_types.h", ".*"+name.upper()+".*") 1768 file_remove_pattern(lib_ecc_config_path + "lib_ecc_config.h", ".*"+name.upper()+".*") 1769 file_remove_pattern(ec_self_tests_path + "ec_self_tests_core.h", ".*"+name+".*") 1770 file_remove_pattern(ec_self_tests_path + "ec_self_tests_core.h", ".*"+name.upper()+".*") 1771 file_remove_pattern(meson_options_path + "meson.options", ".*"+name.lower()+".*") 1772 try: 1773 remove_file(ec_params_path + "ec_params_"+name+".h") 1774 except: 1775 print("Error: curve name "+name+" does not seem to be present in the sources!") 1776 return False 1777 try: 1778 remove_file(ec_self_tests_path + "ec_self_tests_core_"+name+".h") 1779 except: 1780 print("Warning: curve name "+name+" self tests do not seem to be present ...") 1781 return True 1782 return True 1783 if remove_all == True: 1784 asked = "" 1785 while asked != "y" and asked != "n": 1786 asked = get_user_input("You asked to remove everything related to ALL user defined curves. Enter y to confirm, n to cancel [y/n]. 
") 1787 if asked == "n": 1788 print("NOT removing user defined curves (cancelled).") 1789 return True 1790 # Remove any user defined stuff with given name 1791 print("Removing ALL user defined curves ...") 1792 # Remove any user defined stuff (whatever name) 1793 file_remove_pattern(curves_list_path + "curves_list.h", ".*user_defined.*") 1794 file_remove_pattern(curves_list_path + "curves_list.h", ".*USER_DEFINED.*") 1795 file_remove_pattern(lib_ecc_types_path + "lib_ecc_types.h", ".*USER_DEFINED.*") 1796 file_remove_pattern(lib_ecc_config_path + "lib_ecc_config.h", ".*USER_DEFINED.*") 1797 file_remove_pattern(ec_self_tests_path + "ec_self_tests_core.h", ".*USER_DEFINED.*") 1798 file_remove_pattern(ec_self_tests_path + "ec_self_tests_core.h", ".*user_defined.*") 1799 file_remove_pattern(meson_options_path + "meson.options", ".*user_defined.*") 1800 remove_files_pattern(ec_params_path + "ec_params_user_defined_*.h") 1801 remove_files_pattern(ec_self_tests_path + "ec_self_tests_core_user_defined_*.h") 1802 return True 1803 1804 # If a g is provided, split it in two gx and gy 1805 if g != None: 1806 if (len(g)/2)%2 == 0: 1807 gx = get_int(g[:len(g)/2]) 1808 gy = get_int(g[len(g)/2:]) 1809 else: 1810 # This is probably a generator encapsulated in a bit string 1811 if g[0:2] != "04": 1812 print("Error: provided generator g is not conforming!") 1813 return False 1814 else: 1815 g = g[2:] 1816 gx = get_int(g[:len(g)/2]) 1817 gy = get_int(g[len(g)/2:]) 1818 if ECfile != None: 1819 # ASN.1 DER input incompatible with other options 1820 if (prime != None) or (a != None) or (b != None) or (gx != None) or (gy != None) or (order != None) or (cofactor != None): 1821 print("Error: option ECfile incompatible with explicit (prime, a, b, gx, gy, order, cofactor) options!") 1822 return False 1823 # We need at least a name 1824 if (name == None): 1825 print("Error: option ECfile needs a curve name!") 1826 return False 1827 # Open the file 1828 try: 1829 buf = open(ECfile, 'rb').read() 
1830 except: 1831 print("Error: cannot open ECfile file "+ECfile) 1832 return False 1833 # Check if we have a PEM or a DER file 1834 (check, derbuf) = buffer_remove_pattern(buf, "-----.*-----") 1835 if (check == True): 1836 # This a PEM file, proceed with base64 decoding 1837 if(is_base64(derbuf) == False): 1838 print("Error: error when decoding ECfile file "+ECfile+" (seems to be PEM, but failed to decode)") 1839 return False 1840 derbuf = base64.b64decode(derbuf) 1841 (check, (a, b, prime, order, cofactor, gx, gy)) = parse_DER_ECParameters(derbuf) 1842 if (check == False): 1843 print("Error: error when parsing ECfile file "+ECfile+" (malformed or unsupported ASN.1)") 1844 return False 1845 1846 else: 1847 if (prime == None) or (a == None) or (b == None) or (gx == None) or (gy == None) or (order == None) or (cofactor == None) or (name == None): 1848 err_string = (prime == None)*"prime "+(a == None)*"a "+(b == None)*"b "+(gx == None)*"gx "+(gy == None)*"gy "+(order == None)*"order "+(cofactor == None)*"cofactor "+(name == None)*"name " 1849 print("Error: missing "+err_string+" in explicit curve definition (name, prime, a, b, gx, gy, order, cofactor)!") 1850 print("See the help with -h or --help") 1851 return False 1852 1853 # Some sanity checks here 1854 # Check that prime is indeed a prime 1855 if is_probprime(prime) == False: 1856 print("Error: given prime is *NOT* prime!") 1857 return False 1858 if is_probprime(order) == False: 1859 print("Error: given order is *NOT* prime!") 1860 return False 1861 if (a > prime) or (b > prime) or (gx > prime) or (gy > prime): 1862 err_string = (a > prime)*"a "+(b > prime)*"b "+(gx > prime)*"gx "+(gy > prime)*"gy " 1863 print("Error: "+err_string+"is > prime") 1864 return False 1865 # Check that the provided generator is on the curve 1866 if pow(gy, 2, prime) != ((pow(gx, 3, prime) + (a*gx) + b) % prime): 1867 print("Error: the given parameters (prime, a, b, gx, gy) do not verify the elliptic curve equation!") 1868 return False 
1869 1870 # Check Montgomery and Edwards transfer coefficients 1871 if ((alpha_montgomery != None) and (gamma_montgomery == None)) or ((alpha_montgomery == None) and (gamma_montgomery != None)): 1872 print("Error: alpha_montgomery and gamma_montgomery must be both defined if used!") 1873 return False 1874 if (alpha_edwards != None): 1875 if (alpha_montgomery == None) or (gamma_montgomery == None): 1876 print("Error: alpha_edwards needs alpha_montgomery and gamma_montgomery to be both defined if used!") 1877 return False 1878 1879 # Now that we have our parameters, call the function to get bitlen 1880 pbitlen = getbitlen(prime) 1881 ec_params = curve_params(name, prime, pbitlen, a, b, gx, gy, order, cofactor, oid, alpha_montgomery, gamma_montgomery, alpha_edwards) 1882 # Check if there is a name collision somewhere 1883 if os.path.exists(ec_params_path + "ec_params_"+name+".h") == True : 1884 print("Error: file %s already exists!" % (ec_params_path + "ec_params_"+name+".h")) 1885 return False 1886 if (check_in_file(curves_list_path + "curves_list.h", "ec_params_"+name+"_str_params") == True) or (check_in_file(curves_list_path + "curves_list.h", "WITH_CURVE_"+name.upper()+"\n") == True) or (check_in_file(lib_ecc_types_path + "lib_ecc_types.h", "WITH_CURVE_"+name.upper()+"\n") == True): 1887 print("Error: name %s already exists in files" % ("ec_params_"+name)) 1888 return False 1889 # Create a new file with the parameters 1890 if not os.path.exists(ec_params_path): 1891 # Create the "user_defined" folder if it does not exist 1892 os.mkdir(ec_params_path) 1893 f = open(ec_params_path + "ec_params_"+name+".h", 'w') 1894 f.write(ec_params) 1895 f.close() 1896 # Include the file in curves_list.h 1897 magic = "ADD curves header here" 1898 magic_re = "\/\* "+magic+" \*\/" 1899 magic_back = "/* "+magic+" */" 1900 file_replace_pattern(curves_list_path + "curves_list.h", magic_re, "#include <libecc/curves/user_defined/ec_params_"+name+".h>\n"+magic_back) 1901 # Add the curve 
mapping 1902 magic = "ADD curves mapping here" 1903 magic_re = "\/\* "+magic+" \*\/" 1904 magic_back = "/* "+magic+" */" 1905 file_replace_pattern(curves_list_path + "curves_list.h", magic_re, "#ifdef WITH_CURVE_"+name.upper()+"\n\t{ .type = "+name.upper()+", .params = &"+name+"_str_params },\n#endif /* WITH_CURVE_"+name.upper()+" */\n"+magic_back) 1906 # Add the new curve type in the enum 1907 # First we get the number of already defined curves so that we increment the enum counter 1908 num_with_curve = num_patterns_in_file(lib_ecc_types_path + "lib_ecc_types.h", "#ifdef WITH_CURVE_") 1909 magic = "ADD curves type here" 1910 magic_re = "\/\* "+magic+" \*\/" 1911 magic_back = "/* "+magic+" */" 1912 file_replace_pattern(lib_ecc_types_path + "lib_ecc_types.h", magic_re, "#ifdef WITH_CURVE_"+name.upper()+"\n\t"+name.upper()+" = "+str(num_with_curve+1)+",\n#endif /* WITH_CURVE_"+name.upper()+" */\n"+magic_back) 1913 # Add the new curve define in the config 1914 magic = "ADD curves define here" 1915 magic_re = "\/\* "+magic+" \*\/" 1916 magic_back = "/* "+magic+" */" 1917 file_replace_pattern(lib_ecc_config_path + "lib_ecc_config.h", magic_re, "#define WITH_CURVE_"+name.upper()+"\n"+magic_back) 1918 # Add the new curve meson option in the meson.options file 1919 magic = "ADD curves meson option here" 1920 magic_re = "# " + magic 1921 magic_back = "# " + magic 1922 file_replace_pattern(meson_options_path + "meson.options", magic_re, "\t'"+name.lower()+"',\n"+magic_back) 1923 1924 # Do we need to add some test vectors? 1925 if add_test_vectors != None: 1926 print("Test vectors generation asked: this can take some time! 
Please wait ...") 1927 # Create curve 1928 c = Curve(a, b, prime, order, cofactor, gx, gy, cofactor * order, name, oid) 1929 # Generate key pair for the algorithm 1930 vectors = gen_self_tests(c, add_test_vectors) 1931 # Iterate through all the tests 1932 f = open(ec_self_tests_path + "ec_self_tests_core_"+name+".h", 'w') 1933 for l in vectors: 1934 for v in l: 1935 for case in v: 1936 (case_name, case_vector) = case 1937 # Add the new test case 1938 magic = "ADD curve test case here" 1939 magic_re = "\/\* "+magic+" \*\/" 1940 magic_back = "/* "+magic+" */" 1941 file_replace_pattern(ec_self_tests_path + "ec_self_tests_core.h", magic_re, case_name+"\n"+magic_back) 1942 # Create/Increment the header file 1943 f.write(case_vector) 1944 f.close() 1945 # Add the new test cases header 1946 magic = "ADD curve test vectors header here" 1947 magic_re = "\/\* "+magic+" \*\/" 1948 magic_back = "/* "+magic+" */" 1949 file_replace_pattern(ec_self_tests_path + "ec_self_tests_core.h", magic_re, "#include \"ec_self_tests_core_"+name+".h\"\n"+magic_back) 1950 return True 1951 1952 1953#### Main 1954if __name__ == "__main__": 1955 signal.signal(signal.SIGINT, handler) 1956 parse_cmd_line(sys.argv[1:]) 1957