xref: /freebsd/crypto/libecc/scripts/expand_libecc.py (revision f0865ec9906d5a18fa2a3b61381f22ce16e606ad)
1#/*
2# *  Copyright (C) 2017 - This file is part of libecc project
3# *
4# *  Authors:
5# *      Ryad BENADJILA <ryadbenadjila@gmail.com>
6# *      Arnaud EBALARD <arnaud.ebalard@ssi.gouv.fr>
7# *      Jean-Pierre FLORI <jean-pierre.flori@ssi.gouv.fr>
8# *
9# *  Contributors:
10# *      Nicolas VIVET <nicolas.vivet@ssi.gouv.fr>
11# *      Karim KHALFALLAH <karim.khalfallah@ssi.gouv.fr>
12# *
13# *  This software is licensed under a dual BSD and GPL v2 license.
14# *  See LICENSE file at the root folder of the project.
15# */
16#! /usr/bin/env python
17
18import random, sys, re, math, os, getopt, glob, copy, hashlib, binascii, string, signal, base64
19
# External dependency for SHA-3
21# It is an independent module, since hashlib has no support
22# for SHA-3 functions for now
23import sha3
24
25# Handle Python 2/3 issues
def is_python_2():
    """
    Tell whether the running interpreter is Python 2.

    Returns True when the major version reported by sys.version_info is
    lower than 3, False otherwise.
    """
    major_version = sys.version_info[0]
    return major_version < 3
31
32### Ctrl-C handler
def handler(signal, frame):
    """
    Ctrl-C (SIGINT) handler: print a notice and exit with status 0.

    Both parameters are ignored; they follow the (signum, frame) signature
    expected from signal handlers (presumably installed with signal.signal
    elsewhere in the file -- not visible here). Note that the first
    parameter shadows the 'signal' module inside this function.
    """
    print("\nSIGINT caught: exiting ...")
    # sys.exit instead of the exit() helper injected by the 'site' module,
    # which does not exist when Python runs with -S or in frozen builds.
    sys.exit(0)
36
37# Helper to ask the user for something
def get_user_input(prompt):
    """
    Display prompt and return the line typed by the user.

    Python 2 exposes this builtin as raw_input(), Python 3 as input().
    """
    if is_python_2():
        return raw_input(prompt)
    return input(prompt)
44
45##########################################################
46#### Math helpers
def egcd(b, n):
    """
    Extended Euclidean algorithm.

    Returns (g, x, y) such that b*x + n*y == g, g being the gcd of b and n
    for non-negative inputs.
    """
    x_prev, x_cur = 1, 0
    y_prev, y_cur = 0, 1
    while n:
        quot = b // n
        b, n = n, b % n
        # Update the Bezout coefficients alongside the remainders
        x_prev, x_cur = x_cur, x_prev - quot * x_cur
        y_prev, y_cur = y_cur, y_prev - quot * y_cur
    return b, x_prev, y_prev
54
def modinv(a, m):
    """
    Return the inverse of a modulo m (in [0, m) for a positive m).

    Raises Exception when a and m are not coprime, i.e. when no inverse
    exists.
    """
    gcd, bezout_a, _ = egcd(a, m)
    if gcd == 1:
        return bezout_a % m
    raise Exception("Error: modular inverse does not exist")
61
def compute_monty_coef(prime, pbitlen, wlen):
    """
    Compute the Montgomery coefficients of prime.

    Returns (r, r_square, mpinv) with r = 2^pbitlen mod prime,
    r_square = 2^(2*pbitlen) mod prime and mpinv = 2^wlen - (prime^-1 mod
    2^wlen), i.e. -prime^-1 modulo 2^wlen. pbitlen is the size of p in bits
    and is expected to be a multiple of the word bit size wlen.
    """
    nbits = int(pbitlen)
    word_base = 2**wlen
    r = (1 << nbits) % prime
    r_square = (1 << (2 * nbits)) % prime
    mpinv = word_base - modinv(prime, word_base)
    return r, r_square, mpinv
72
def compute_div_coef(prime, pbitlen, wlen):
    """
    Compute the division coefficients (pshift, primenorm, prec) of prime.

    - pshift: pbitlen minus the bit length of prime,
    - primenorm: prime << pshift, so that its bit length is pbitlen,
    - prec: B^3 // ((primenorm >> (pbitlen - 2*wlen)) + 1) - B, B = 2^wlen.
    """
    # Bit length of prime (0 for prime == 0). The former manual
    # shift-and-count loop never terminated on a negative value.
    cnt = int(prime).bit_length()
    pshift = int(pbitlen - cnt)
    primenorm = prime << pshift
    B = 2**wlen
    prec = B**3 // ((primenorm >> int(pbitlen - 2*wlen)) + 1) - B
    return pshift, primenorm, prec
87
def is_probprime(n):
    """
    Miller-Rabin probabilistic primality test using 5 random bases.

    Returns False when n is certainly composite (or lower than 2), True
    when n is prime or a strong probable prime for every base tried.
    """
    # Small cases: the generic code below needs an odd n >= 5 so that
    # random.randrange(2, n - 1) is a non-empty range. Previously n == 2
    # was reported composite and n == 1 looped forever.
    if n < 2:
        return False
    if n in (2, 3):
        return True
    # Even numbers greater than 2 are composite
    if n % 2 == 0:
        return False
    # Write n-1 as 2**s * d with d odd
    s = 0
    d = n - 1
    while d % 2 == 0:
        d //= 2
        s += 1
    assert(2**s * d == n-1)
    # Test the base a to see whether it is a witness for the compositeness of n
    def try_composite(a):
        x = pow(a, d, n)
        if x == 1 or x == n - 1:
            return False
        # Successive squarings give a^(2^i * d) for i = 1 .. s-1
        for _ in range(s - 1):
            x = pow(x, 2, n)
            if x == n - 1:
                return False
        return True # n is definitely composite
    for _ in range(5):
        # n-1 is never a witness, hence bases are drawn in [2, n-2]
        a = random.randrange(2, n - 1)
        if try_composite(a):
            return False
    return True # no base tested showed n as composite
116
def legendre_symbol(a, p):
    """
    Legendre symbol of a modulo p via Euler's criterion.

    Returns -1 when a^((p-1)/2) mod p equals p - 1, otherwise the value of
    a^((p-1)/2) mod p itself (1 for a residue, 0 when p divides a).
    """
    euler = pow(a, (p - 1) // 2, p)
    if euler == p - 1:
        return -1
    return euler
120
121# Tonelli-Shanks algorithm to find square roots
122# over prime fields
def mod_sqrt(a, p):
    """
    Square root of a modulo p (p assumed prime), Tonelli-Shanks algorithm.

    Returns a value x with x^2 == a (mod p), 0 when a is a multiple of p,
    or None when a is not a quadratic residue modulo p.
    """
    a = a % p
    # Square root of 0 is 0 (this also covers non-zero multiples of p,
    # which used to yield None)
    if a == 0:
        return 0
    # In GF(2), 1 is its own square root. This must be handled before the
    # Legendre symbol, which returns -1 for p == 2 and used to make this
    # case unreachable.
    if p == 2:
        return a
    if legendre_symbol(a, p) != 1:
        # No square residue
        return None
    if p % 4 == 3:
        return pow(a, (p + 1) // 4, p)
    # Write p - 1 = s * 2^e with s odd
    s = p - 1
    e = 0
    while s % 2 == 0:
        s = s // 2
        e += 1
    # Find a quadratic non-residue n
    n = 2
    while legendre_symbol(n, p) != -1:
        n += 1
    x = pow(a, (s + 1) // 2, p)
    b = pow(a, s, p)
    g = pow(n, s, p)
    r = e
    while True:
        # Find the least m in [0, r) with b^(2^m) == 1 (range works on
        # both Python 2 and 3)
        t = b
        m = 0
        for m in range(r):
            if t == 1:
                break
            t = pow(t, 2, p)
        if m == 0:
            return x
        gs = pow(g, 2 ** (r - m - 1), p)
        g = (gs * gs) % p
        x = (x * gs) % p
        b = (b * g) % p
        r = m
167
168##########################################################
169### Math elliptic curves basic blocks
170
171# WARNING: these blocks are only here for testing purpose and
172# are not intended to be used in a security oriented library!
# This explains the usage of naive affine coordinate formulas.
class Curve(object):
    """
    Elliptic curve y^2 = x^3 + a*x + b over GF(prime) and its parameters.

    Attributes: a and b (equation coefficients), p (field prime), q (order),
    c (cofactor), gx and gy (generator coordinates), n (number of points),
    name and oid.
    """
    def __init__(self, a, b, prime, order, cofactor, gx, gy, npoints, name, oid):
        self.a = a
        self.b = b
        self.p = prime
        self.q = order
        self.c = cofactor
        self.gx = gx
        self.gy = gy
        self.n = npoints
        self.name = name
        self.oid = oid
    # Equality testing: two curves are equal when all their parameters are
    def __eq__(self, other):
        if not isinstance(other, Curve):
            # Let Python fall back to its default comparison instead of
            # crashing on other.__dict__
            return NotImplemented
        return self.__dict__ == other.__dict__
    # Python 2 does not derive != from __eq__ (it would compare identities),
    # so define it explicitly: callers test curves with '!='.
    def __ne__(self, other):
        equal = self.__eq__(other)
        if equal is NotImplemented:
            return equal
        return not equal
    # Deep copy is implemented using the ~X operator
    def __invert__(self):
        return copy.deepcopy(self)
192
193
class Point(object):
    """
    Point of an elliptic curve in affine coordinates (x, y).

    The point at infinity is (None, None). Coordinates are reduced modulo
    curve.p and the constructor checks y^2 = x^3 + a*x + b (mod p), raising
    Exception otherwise. Supported operators: P + Q, -P, P - Q, k * P for an
    integer k, P == Q, P != Q and ~P (deep copy).
    """
    # Affine coordinates (x, y), infinity point is (None, None)
    def __init__(self, curve, x, y):
        self.curve = curve
        self.x = None if x is None else (x % curve.p)
        self.y = None if y is None else (y % curve.p)
        # Check that the point is indeed on the curve (the infinity point
        # is not checked)
        if x is not None:
            if (pow(y, 2, curve.p) != ((pow(x, 3, curve.p) + (curve.a * x) + curve.b ) % curve.p)):
                raise Exception("Error: point is not on curve!")
    # Addition
    def __add__(self, Q):
        x1 = self.x
        y1 = self.y
        x2 = Q.x
        y2 = Q.y
        curve = self.curve
        # Check that we are on the same curve
        if Q.curve != curve:
            raise Exception("Point add error: two point don't have the same curve")
        # If Q is infinity point, return ourself
        if Q.x is None:
            return Point(self.curve, self.x, self.y)
        # If we are the infinity point return Q
        if self.x is None:
            return Q
        # Infinity point or Doubling
        if (x1 == x2):
            if (((y1 + y2) % curve.p) == 0):
                # P + (-P): return infinity point
                return Point(self.curve, None, None)
            else:
                # Doubling: tangent slope (3*x1^2 + a) / (2*y1)
                L = ((3*pow(x1, 2, curve.p) + curve.a) * modinv(2*y1, curve.p)) % curve.p
        # Addition: chord slope (y2 - y1) / (x2 - x1)
        else:
            L = ((y2 - y1) * modinv((x2 - x1) % curve.p, curve.p)) % curve.p
        resx = (pow(L, 2, curve.p) - x1 - x2) % curve.p
        resy = ((L * (x1 - resx)) - y1) % curve.p
        # Return the point
        return Point(self.curve, resx, resy)
    # Negation
    def __neg__(self):
        if self.x is None:
            return Point(self.curve, None, None)
        return Point(self.curve, self.x, -self.y)
    # Subtraction
    def __sub__(self, other):
        return self + (-other)
    # Scalar mul
    def __rmul__(self, scalar):
        # The bit scan below only handles non-negative scalars: a negative
        # one used to give a wrong multiple. Use (-k)P = k(-P) instead.
        if scalar is not None and scalar < 0:
            return (-scalar) * (-self)
        # Implement simple double and add algorithm
        P = self
        Q = Point(P.curve, None, None)
        for i in range(getbitlen(scalar), 0, -1):
            Q = Q + Q
            if (scalar >> (i-1)) & 0x1 == 0x1:
                Q = Q + P
        return Q
    # Equality testing: same curve and same coordinates
    def __eq__(self, other):
        if not isinstance(other, Point):
            return NotImplemented
        return self.__dict__ == other.__dict__
    # Python 2 does not derive != from __eq__ (it would compare identities)
    def __ne__(self, other):
        equal = self.__eq__(other)
        if equal is NotImplemented:
            return equal
        return not equal
    # Deep copy is implemented using the ~X operator
    def __invert__(self):
        return copy.deepcopy(self)
    def __str__(self):
        if self.x is None:
            return "Inf"
        return ("(x = %s, y = %s)" % (hex(self.x), hex(self.y)))
271
272##########################################################
273### Private and public keys structures
class PrivKey(object):
    """
    Private key: the secret scalar x bound to its curve.
    """
    def __init__(self, curve, x):
        self.curve, self.x = curve, x
278
class PubKey(object):
    """
    Public key: the point Y bound to its curve.

    Raises Exception when Y.curve != curve.
    """
    def __init__(self, curve, Y):
        # Sanity check: the point must belong to the given curve
        point_curve = Y.curve
        if point_curve != curve:
            raise Exception("Error: curve and point curve differ in public key!")
        self.curve = curve
        self.Y = Y
286
class KeyPair(object):
    """
    Container pairing a public key with its private key.
    """
    def __init__(self, pubkey, privkey):
        self.pubkey, self.privkey = pubkey, privkey
291
292
def fromprivkey(privkey, is_eckcdsa=False):
    """
    Derive the public key matching privkey.

    The public point is x*G, or (x^-1 mod q)*G when is_eckcdsa is set
    (EC-KCDSA convention), G being the curve generator (gx, gy).
    """
    curve = privkey.curve
    order = curve.q
    generator = Point(curve, curve.gx, curve.gy)
    if is_eckcdsa == False:
        scalar = privkey.x
    else:
        scalar = modinv(privkey.x, order)
    return PubKey(curve, scalar * generator)
303
def genKeyPair(curve, is_eckcdsa=False):
    """
    Generate a random key pair on curve.

    The private scalar x is drawn in ]0, q[ (q = curve.q) and the public key
    is derived with fromprivkey (x*G, or x^-1*G when is_eckcdsa is set).
    Returns a KeyPair.
    """
    q = curve.q
    while True:
        x = getrandomint(q)
        # getrandomint(q) returns values in [0, q], bounds included: both 0
        # and q are invalid private scalars (q used to be accepted).
        if 0 < x < q:
            break
    privkey = PrivKey(curve, x)
    pubkey = fromprivkey(privkey, is_eckcdsa)
    return KeyPair(pubkey, privkey)
319
320##########################################################
321### Signature algorithms helpers
def getrandomint(modulo):
    """
    Return a random integer in [0, modulo], both bounds included.

    NOTE: the upper bound is inclusive, so getrandomint(q) may return q
    itself; callers needing a value in [0, q) or ]0, q[ must reject it.
    Relies on the non-cryptographic 'random' module (test-only usage).
    """
    return random.randint(0, modulo)
324
def getbitlen(bint):
    """
    Return the number of bits encoding the integer bint.

    None gives 0 and zero gives 1 (zero still needs one bit); any other value
    gives the bit length of int(bint).
    """
    if bint is None:
        nbits = 0
    elif bint == 0:
        nbits = 1
    else:
        nbits = int(bint).bit_length()
    return nbits
336
def getbytelen(bint):
    """
    Return the number of bytes encoding the integer bint: its bit length
    (as computed by getbitlen) rounded up to whole bytes.
    """
    nbits = getbitlen(bint)
    return (nbits + 7) // 8
346
def stringtoint(bitstring):
    """
    Convert a big-endian character string (one byte per character, read
    with ord) into a non-negative integer. The empty string gives 0.
    """
    value = 0
    for char in bitstring:
        # Horner's scheme: shift what we have by one byte, add the new one
        value = value * 256 + ord(char)
    return value
353
def inttostring(a):
    """
    Convert a non-negative integer into its big-endian character string,
    one character per byte and getbytelen(a) characters long ("\\x00" for
    0, "" for None).
    """
    nbytes = int(getbytelen(a))
    return "".join(chr((a >> (8 * pos)) & 0xFF) for pos in reversed(range(nbytes)))
360
def expand(bitstring, bitlen, direction):
    """
    Pad bitstring with zero bytes up to ceil(bitlen/8) characters.

    direction "LEFT" prepends the padding, "RIGHT" appends it. Strings that
    are already long enough are returned unchanged (direction is then not
    checked); otherwise an unknown direction raises Exception.
    """
    missing = int(math.ceil(bitlen / 8.)) - len(bitstring)
    if missing <= 0:
        return bitstring
    padding = "\x00" * missing
    if direction == "LEFT":
        return padding + bitstring
    if direction == "RIGHT":
        return bitstring + padding
    raise Exception("Error: unknown direction "+direction+" in expand")
372
def truncate(bitstring, bitlen, keep):
    """
    Truncate a big-endian bit string to bitlen bits.

    keep == "LEFT" keeps the leftmost (most significant) bits, keep ==
    "RIGHT" the rightmost (least significant) ones; the result is then
    left-padded with zero bytes to ceil(bitlen/8) characters. Strings that
    already fit in bitlen bits are returned unchanged (keep is then not
    checked); otherwise an unknown keep value raises Exception.
    """
    total_bits = 8 * len(bitstring)
    if total_bits <= bitlen:
        # No need to truncate!
        return bitstring
    if keep == "LEFT":
        value = stringtoint(bitstring) >> int(total_bits - bitlen)
    elif keep == "RIGHT":
        value = stringtoint(bitstring) & ((2**bitlen) - 1)
    else:
        raise Exception("Error: unknown direction "+keep+" in truncate")
    return expand(inttostring(value), bitlen, "LEFT")
391
392##########################################################
393### Hash algorithms
def _hash_digest(ctx, message):
    """
    Feed message to the hash context ctx and return the triple
    (digest, digest_size, block_size) shared by the hash wrappers below.

    Messages and digests are character strings: on Python 3 they are
    converted to/from bytes with the latin-1 codec (one character per byte);
    on Python 2 they are used as is.
    """
    if is_python_2():
        ctx.update(message)
        digest = ctx.digest()
    else:
        ctx.update(message.encode('latin-1'))
        digest = ctx.digest().decode('latin-1')
    return (digest, ctx.digest_size, ctx.block_size)

def sha224(message):
    """SHA-224 of message, returned as (digest, digest_size, block_size)."""
    return _hash_digest(hashlib.sha224(), message)

def sha256(message):
    """SHA-256 of message, returned as (digest, digest_size, block_size)."""
    return _hash_digest(hashlib.sha256(), message)

def sha384(message):
    """SHA-384 of message, returned as (digest, digest_size, block_size)."""
    return _hash_digest(hashlib.sha384(), message)

def sha512(message):
    """SHA-512 of message, returned as (digest, digest_size, block_size)."""
    return _hash_digest(hashlib.sha512(), message)

def sha3_224(message):
    """SHA3-224 of message, returned as (digest, digest_size, block_size)."""
    return _hash_digest(sha3.Sha3_ctx(224), message)

def sha3_256(message):
    """SHA3-256 of message, returned as (digest, digest_size, block_size)."""
    return _hash_digest(sha3.Sha3_ctx(256), message)

def sha3_384(message):
    """SHA3-384 of message, returned as (digest, digest_size, block_size)."""
    return _hash_digest(sha3.Sha3_ctx(384), message)

def sha3_512(message):
    """SHA3-512 of message, returned as (digest, digest_size, block_size)."""
    return _hash_digest(sha3.Sha3_ctx(512), message)
473
474##########################################################
475### Signature algorithms
476
477# *| IUF  - ECDSA signature
478# *|
479# *|  UF  1. Compute h = H(m)
480# *|   F  2. If |h| > bitlen(q), set h to bitlen(q)
481# *|         leftmost (most significant) bits of h
482# *|   F  3. e = OS2I(h) mod q
483# *|   F  4. Get a random value k in ]0,q[
484# *|   F  5. Compute W = (W_x,W_y) = kG
485# *|   F  6. Compute r = W_x mod q
486# *|   F  7. If r is 0, restart the process at step 4.
487# *|   F  8. If e == rx, restart the process at step 4.
488# *|   F  9. Compute s = k^-1 * (xr + e) mod q
489# *|   F 10. If s is 0, restart the process at step 4.
490# *|   F 11. Return (r,s)
def ecdsa_sign(hashfunc, keypair, message, k=None):
    """
    ECDSA signature of message with keypair's private key (steps above).

    hashfunc(msg) must return (digest, digest_size, block_size) like the
    hash helpers of this file. When k is None a fresh random nonce is drawn
    for every attempt; a caller-provided k is used as is and Exception is
    raised if it hits a restart condition (retrying with the same k used to
    loop forever). Returns (r || s, k), r and s being big-endian strings of
    getbytelen(q) characters each.
    """
    privkey = keypair.privkey
    # Get important parameters from the curve
    q = privkey.curve.q
    G = Point(privkey.curve, privkey.curve.gx, privkey.curve.gy)
    q_limit_len = getbitlen(q)
    # Compute the hash
    (h, _, _) = hashfunc(message)
    # Truncate hash value
    h = truncate(h, q_limit_len, "LEFT")
    # Convert the hash value to an int
    e = stringtoint(h) % q
    fixed_k = k
    while True:
        k = getrandomint(q) if fixed_k is None else fixed_k
        # s stays 0 whenever a restart condition is met
        s = 0
        # k must be non-zero modulo q: getrandomint(q) may return q, which
        # used to make W the point at infinity and crash on W.x
        if k % q != 0:
            W = k * G
            r = W.x % q
            # Restart conditions of steps 7 and 8
            if r != 0 and e != r * privkey.x:
                s = (modinv(k, q) * ((privkey.x * r) + e)) % q
        if s != 0:
            break
        if fixed_k is not None:
            raise Exception("ECDSA sign: provided nonce k does not yield a valid signature")
    return ((expand(inttostring(r), 8*getbytelen(q), "LEFT") + expand(inttostring(s), 8*getbytelen(q), "LEFT")), k)
523
524# *| IUF  - ECDSA verification
525# *|
526# *| I    1. Reject the signature if r or s is 0.
527# *|  UF  2. Compute h = H(m)
528# *|   F  3. If |h| > bitlen(q), set h to bitlen(q)
529# *|         leftmost (most significant) bits of h
530# *|   F  4. Compute e = OS2I(h) mod q
531# *|   F  5. Compute u = (s^-1)e mod q
532# *|   F  6. Compute v = (s^-1)r mod q
533# *|   F  7. Compute W' = uG + vY
534# *|   F  8. If W' is the point at infinity, reject the signature.
535# *|   F  9. Compute r' = W'_x mod q
536# *|   F 10. Accept the signature if and only if r equals r'
def ecdsa_verify(hashfunc, keypair, message, sig):
    """
    Verify the ECDSA signature sig = r || s of message with keypair's
    public key (steps listed above).

    Returns True for a valid signature, False otherwise. Raises Exception
    when sig is not exactly 2*getbytelen(q) characters long.
    """
    pubkey = keypair.pubkey
    # Get important parameters from the curve
    q = pubkey.curve.q
    q_limit_len = getbitlen(q)
    G = Point(pubkey.curve, pubkey.curve.gx, pubkey.curve.gy)
    # Extract r and s
    if len(sig) != 2*getbytelen(q):
        raise Exception("ECDSA verify: bad signature length!")
    half = len(sig) // 2
    r = stringtoint(sig[:half])
    s = stringtoint(sig[half:])
    # Step 1, extended to the full ]0,q[ range: ecdsa_sign only emits values
    # reduced mod q. Previously s == q made modinv raise instead of
    # rejecting, and s + q was accepted as a second valid encoding.
    if not (0 < r < q) or not (0 < s < q):
        return False
    # Compute the hash
    (h, _, _) = hashfunc(message)
    # Truncate hash value
    h = truncate(h, q_limit_len, "LEFT")
    # Convert the hash value to an int
    e = stringtoint(h) % q
    s_inv = modinv(s, q)
    u = (s_inv * e) % q
    v = (s_inv * r) % q
    W_ = (u * G) + (v * pubkey.Y)
    if W_.x is None:
        return False
    r_ = W_.x % q
    return r == r_
569
def eckcdsa_genKeyPair(curve):
    """
    Generate an EC-KCDSA key pair on curve, i.e. genKeyPair with the
    EC-KCDSA public key convention (x^-1 * G).
    """
    return genKeyPair(curve, is_eckcdsa=True)
572
573# *| IUF  - ECKCDSA signature
574# *|
575# *| IUF  1. Compute h = H(z||m)
576# *|   F  2. If hsize > bitlen(q), set h to bitlen(q)
577# *|         rightmost (less significant) bits of h.
578# *|   F  3. Get a random value k in ]0,q[
579# *|   F  4. Compute W = (W_x,W_y) = kG
580# *|   F  5. Compute r = h(FE2OS(W_x)).
581# *|   F  6. If hsize > bitlen(q), set r to bitlen(q)
582# *|         rightmost (less significant) bits of r.
583# *|   F  7. Compute e = OS2I(r XOR h) mod q
584# *|   F  8. Compute s = x(k - e) mod q
585# *|   F  9. if s == 0, restart at step 3.
586# *|   F 10. return (r,s)
def eckcdsa_sign(hashfunc, keypair, message, k=None):
    """
    EC-KCDSA signature of message with keypair (steps listed above).

    Both keys are used: the public point Y feeds the certificate data z.
    When k is None a fresh random nonce is drawn for every attempt; a
    caller-provided k is used as is and Exception is raised if it hits a
    restart condition (retrying with the same k used to loop forever).
    Returns (r || s, k), s being getbytelen(q) characters long.
    """
    privkey = keypair.privkey
    # Get important parameters from the curve
    p = privkey.curve.p
    q = privkey.curve.q
    G = Point(privkey.curve, privkey.curve.gx, privkey.curve.gy)
    q_limit_len = getbitlen(q)
    # Bit size of the truncated hash values: bitlen(q) rounded up to whole
    # bytes. The float division matters on Python 2, where q_limit_len / 8
    # is floored, making ceil() a no-op and r shorter than what
    # eckcdsa_verify expects.
    trunc_len = 8 * int(math.ceil(q_limit_len / 8.))
    # Compute the certificate data
    (_, _, hblocksize) = hashfunc("")
    z = expand(inttostring(keypair.pubkey.Y.x), 8*getbytelen(p), "LEFT")
    z = z + expand(inttostring(keypair.pubkey.Y.y), 8*getbytelen(p), "LEFT")
    if len(z) > hblocksize:
        # Truncate
        z = truncate(z, 8*hblocksize, "LEFT")
    else:
        # Expand
        z = expand(z, 8*hblocksize, "RIGHT")
    # Compute the hash
    (h, _, _) = hashfunc(z + message)
    # Truncate hash value
    h = truncate(h, trunc_len, "RIGHT")
    fixed_k = k
    while True:
        k = getrandomint(q) if fixed_k is None else fixed_k
        # s stays 0 whenever a restart condition is met
        s = 0
        # k must be non-zero modulo q (getrandomint(q) may return q)
        if k % q != 0:
            W = k * G
            (r, _, _) = hashfunc(expand(inttostring(W.x), 8*getbytelen(p), "LEFT"))
            r = truncate(r, trunc_len, "RIGHT")
            e = (stringtoint(r) ^ stringtoint(h)) % q
            s = (privkey.x * (k - e)) % q
        if s != 0:
            break
        if fixed_k is not None:
            raise Exception("ECKCDSA sign: provided nonce k does not yield a valid signature")
    return (r + expand(inttostring(s), 8*getbytelen(q), "LEFT"), k)
625
626# *| IUF - ECKCDSA verification
627# *|
628# *| I   1. Check the length of r:
629# *|         - if hsize > bitlen(q), r must be of
630# *|           length bitlen(q)
631# *|         - if hsize <= bitlen(q), r must be of
632# *|           length hsize
633# *| I   2. Check that s is in ]0,q[
634# *| IUF 3. Compute h = H(z||m)
635# *|   F 4. If hsize > bitlen(q), set h to bitlen(q)
636# *|        rightmost (less significant) bits of h.
637# *|   F 5. Compute e = OS2I(r XOR h) mod q
638# *|   F 6. Compute W' = sY + eG, where Y is the public key
639# *|   F 7. Compute r' = h(FE2OS(W'x))
640# *|   F 8. If hsize > bitlen(q), set r' to bitlen(q)
641# *|        rightmost (less significant) bits of r'.
642# *|   F 9. Check if r == r'
def eckcdsa_verify(hashfunc, keypair, message, sig):
    """
    Verify the EC-KCDSA signature sig = r || s of message with keypair's
    public key (steps listed above). Returns True or False.
    """
    pubkey = keypair.pubkey
    # Get important parameters from the curve
    p = pubkey.curve.p
    q = pubkey.curve.q
    G = Point(pubkey.curve, pubkey.curve.gx, pubkey.curve.gy)
    q_limit_len = getbitlen(q)
    # Bit size of the truncated hash values: bitlen(q) rounded up to whole
    # bytes (float division so that ceil() also works on Python 2)
    trunc_len = 8 * int(math.ceil(q_limit_len / 8.))
    (_, hsize, hblocksize) = hashfunc("")
    # Step 1: expected length of r
    if (8*hsize) > q_limit_len:
        r_len = int(math.ceil(q_limit_len / 8.))
    else:
        r_len = hsize
    # The signer encodes s on exactly getbytelen(q) characters
    if len(sig) != r_len + getbytelen(q):
        return False
    # Extract r and s
    r = stringtoint(sig[0:int(r_len)])
    s = stringtoint(sig[int(r_len):])
    # Step 2: s must lie in ]0,q[ (s == 0 used to be accepted)
    if not (0 < s < q):
        return False
    # Compute the certificate data
    z = expand(inttostring(keypair.pubkey.Y.x), 8*getbytelen(p), "LEFT")
    z = z + expand(inttostring(keypair.pubkey.Y.y), 8*getbytelen(p), "LEFT")
    if len(z) > hblocksize:
        # Truncate
        z = truncate(z, 8*hblocksize, "LEFT")
    else:
        # Expand
        z = expand(z, 8*hblocksize, "RIGHT")
    # Compute the hash
    (h, _, _) = hashfunc(z + message)
    # Truncate hash value
    h = truncate(h, trunc_len, "RIGHT")
    e = (r ^ stringtoint(h)) % q
    W_ = (s * pubkey.Y) + (e * G)
    # Reject instead of hashing the empty encoding of the point at infinity
    if W_.x is None:
        return False
    (h, _, _) = hashfunc(expand(inttostring(W_.x), 8*getbytelen(p), "LEFT"))
    r_ = truncate(h, trunc_len, "RIGHT")
    return stringtoint(r_) == r
683
684# *| IUF - ECFSDSA signature
685# *|
686# *| I   1. Get a random value k in ]0,q[
687# *| I   2. Compute W = (W_x,W_y) = kG
688# *| I   3. Compute r = FE2OS(W_x)||FE2OS(W_y)
689# *| I   4. If r is an all zero string, restart the process at step 1.
690# *| IUF 5. Compute h = H(r||m)
691# *|   F 6. Compute e = OS2I(h) mod q
692# *|   F 7. Compute s = (k + ex) mod q
693# *|   F 8. If s is 0, restart the process at step 1 (see c. below)
694# *|   F 9. Return (r,s)
def ecfsdsa_sign(hashfunc, keypair, message, k=None):
    """
    EC-FSDSA signature of message with keypair's private key (steps above).

    When k is None a fresh random nonce is drawn for every attempt; a
    caller-provided k is used as is and Exception is raised if it hits a
    restart condition (retrying with the same k used to loop forever).
    Returns (r || s, k) with r = FE2OS(W_x) || FE2OS(W_y).
    """
    privkey = keypair.privkey
    # Get important parameters from the curve
    p = privkey.curve.p
    q = privkey.curve.q
    G = Point(privkey.curve, privkey.curve.gx, privkey.curve.gy)
    fixed_k = k
    while True:
        k = getrandomint(q) if fixed_k is None else fixed_k
        # s stays 0 whenever a restart condition is met
        s = 0
        # k must be non-zero modulo q (getrandomint(q) may return q)
        if k % q != 0:
            W = k * G
            r = expand(inttostring(W.x), 8*getbytelen(p), "LEFT") + expand(inttostring(W.y), 8*getbytelen(p), "LEFT")
            # Step 4: restart if r is an all-zero string
            if stringtoint(r) != 0:
                (h, _, _) = hashfunc(r + message)
                e = stringtoint(h) % q
                s = (k + e * privkey.x) % q
        if s != 0:
            break
        if fixed_k is not None:
            raise Exception("ECFSDSA sign: provided nonce k does not yield a valid signature")
    return (r + expand(inttostring(s), 8*getbytelen(q), "LEFT"), k)
720
721
722# *| IUF - ECFSDSA verification
723# *|
724# *| I   1. Reject the signature if r is not a valid point on the curve.
725# *| I   2. Reject the signature if s is not in ]0,q[
726# *| IUF 3. Compute h = H(r||m)
727# *|   F 4. Convert h to an integer and then compute e = -h mod q
728# *|   F 5. compute W' = sG + eY, where Y is the public key
729# *|   F 6. Compute r' = FE2OS(W'_x)||FE2OS(W'_y)
730# *|   F 7. Accept the signature if and only if r equals r'
def ecfsdsa_verify(hashfunc, keypair, message, sig):
    """
    Verify the EC-FSDSA signature sig = r || s of message with keypair's
    public key (steps listed above).

    Returns True or False. Raises Exception on a bad signature length, when
    r does not encode a point of the curve, or when s is not in ]0,q[.
    """
    pubkey = keypair.pubkey
    # Get important parameters from the curve
    p = pubkey.curve.p
    q = pubkey.curve.q
    G = Point(pubkey.curve, pubkey.curve.gx, pubkey.curve.gy)
    # Extract coordinates from r and s from signature
    if len(sig) != (2*getbytelen(p)) + getbytelen(q):
        raise Exception("ECFSDSA verify: bad signature length!")
    wx = sig[:int(getbytelen(p))]
    wy = sig[int(getbytelen(p)):int(2*getbytelen(p))]
    r = wx + wy
    s = stringtoint(sig[int(2*getbytelen(p)):int((2*getbytelen(p))+getbytelen(q))])
    # Check r is on the curve: the Point constructor raises otherwise
    Point(pubkey.curve, stringtoint(wx), stringtoint(wy))
    # Check s is in ]0,q[ (s == q used to be accepted)
    if not (0 < s < q):
        raise Exception("ECFSDSA verify: s not in ]0,q[")
    (h, _, _) = hashfunc(r + message)
    e = (-stringtoint(h)) % q
    W_ = s * G + e * pubkey.Y
    # The point at infinity has no encoding matching a valid r
    if W_.x is None:
        return False
    r_ = expand(inttostring(W_.x), 8*getbytelen(p), "LEFT") + expand(inttostring(W_.y), 8*getbytelen(p), "LEFT")
    return r == r_
759
760
761# NOTE: ISO/IEC 14888-3 standard seems to diverge from the existing implementations
762# of ECRDSA when treating the message hash, and from the examples of certificates provided
# in RFC 7091 and draft-deremin-rfc4491-bis. While in ISO/IEC 14888-3 it is explicitly asked
764# to proceed with the hash of the message as big endian, the RFCs derived from the Russian
765# standard expect the hash value to be treated as little endian when importing it as an integer
766# (this discrepancy is exhibited and confirmed by test vectors present in ISO/IEC 14888-3, and
767# by X.509 certificates present in the RFCs). This seems (to be confirmed) to be a discrepancy of
768# ISO/IEC 14888-3 algorithm description that must be fixed there.
769#
770# In order to be conservative, libecc uses the Russian standard behavior as expected to be in line with
# other implementations, but keeps the ISO/IEC 14888-3 behavior if forced/asked by the user using
772# the USE_ISO14888_3_ECRDSA toggle. This allows to keep backward compatibility with previous versions of the
773# library if needed.
774
775# *| IUF - ECRDSA signature
776# *|
777# *|  UF  1. Compute h = H(m)
778# *|   F  2. Get a random value k in ]0,q[
779# *|   F  3. Compute W = (W_x,W_y) = kG
780# *|   F  4. Compute r = W_x mod q
781# *|   F  5. If r is 0, restart the process at step 2.
782# *|   F  6. Compute e = OS2I(h) mod q. If e is 0, set e to 1.
783# *|         NOTE: here, ISO/IEC 14888-3 and RFCs differ in the way e treated.
784# *|         e = OS2I(h) for ISO/IEC 14888-3, or e = OS2I(reversed(h)) when endianness of h
785# *|         is reversed for RFCs.
786# *|   F  7. Compute s = (rx + ke) mod q
787# *|   F  8. If s is 0, restart the process at step 2.
788# *|   F 11. Return (r,s)
def ecrdsa_sign(hashfunc, keypair, message, k=None, use_iso14888_divergence=False):
    """
    EC-RDSA signature of message with keypair's private key (steps above).

    Unless use_iso14888_divergence is set, the hash is read as a
    little-endian integer (RFC behavior, see the note above). When k is None
    a fresh random nonce is drawn for every attempt; a caller-provided k is
    used as is and Exception is raised if it hits a restart condition
    (retrying with the same k used to loop forever). Returns (r || s, k).
    """
    privkey = keypair.privkey
    # Get important parameters from the curve
    q = privkey.curve.q
    G = Point(privkey.curve, privkey.curve.gx, privkey.curve.gy)
    (h, _, _) = hashfunc(message)
    if use_iso14888_divergence == False:
        # Reverse the endianness for Russian standard RFC ECRDSA (contrary to ISO/IEC 14888-3 case)
        h = h[::-1]
    # e only depends on the message: compute it once (step 6)
    e = stringtoint(h) % q
    if e == 0:
        e = 1
    fixed_k = k
    while True:
        k = getrandomint(q) if fixed_k is None else fixed_k
        # s stays 0 whenever a restart condition is met
        s = 0
        # k must be non-zero modulo q (getrandomint(q) may return q)
        if k % q != 0:
            W = k * G
            r = W.x % q
            if r != 0:
                s = ((r * privkey.x) + (k * e)) % q
        if s != 0:
            break
        if fixed_k is not None:
            raise Exception("ECRDSA sign: provided nonce k does not yield a valid signature")
    return (expand(inttostring(r), 8*getbytelen(q), "LEFT") + expand(inttostring(s), 8*getbytelen(q), "LEFT"), k)
819
820# *| IUF - ECRDSA verification
821# *|
822# *|  UF 1. Check that r and s are both in ]0,q[
823# *|   F 2. Compute h = H(m)
824# *|   F 3. Compute e = OS2I(h)^-1 mod q
825# *|         NOTE: here, ISO/IEC 14888-3 and RFCs differ in the way e treated.
826# *|         e = OS2I(h) for ISO/IEC 14888-3, or e = OS2I(reversed(h)) when endianness of h
827# *|         is reversed for RFCs.
828# *|   F 4. Compute u = es mod q
829# *|   F 4. Compute v = -er mod q
830# *|   F 5. Compute W' = uG + vY = (W'_x, W'_y)
831# *|   F 6. Let's now compute r' = W'_x mod q
832# *|   F 7. Check r and r' are the same
def ecrdsa_verify(hashfunc, keypair, message, sig, use_iso14888_divergence=False):
    """
    Verify the EC-RDSA signature sig = r || s of message with keypair's
    public key (steps listed above).

    Returns True or False. Raises Exception on a bad signature length or
    when r or s is not in ]0,q[.
    """
    pubkey = keypair.pubkey
    # Get important parameters from the curve
    q = pubkey.curve.q
    G = Point(pubkey.curve, pubkey.curve.gx, pubkey.curve.gy)
    # Extract coordinates from r and s from signature
    if len(sig) != 2*getbytelen(q):
        raise Exception("ECRDSA verify: bad signature length!")
    r = stringtoint(sig[:int(getbytelen(q))])
    s = stringtoint(sig[int(getbytelen(q)):int(2*getbytelen(q))])
    # r == q and s == q used to be accepted
    if not (0 < r < q):
        raise Exception("ECRDSA verify: r not in ]0,q[")
    if not (0 < s < q):
        raise Exception("ECRDSA verify: s not in ]0,q[")
    (h, _, _) = hashfunc(message)
    if use_iso14888_divergence == False:
        # Reverse the endianness for Russian standard RFC ECRDSA (contrary to ISO/IEC 14888-3 case)
        h = h[::-1]
    # As on the signing side, a zero e is replaced by 1; inverting 0 used to
    # make modinv raise
    e = stringtoint(h) % q
    if e == 0:
        e = 1
    e_inv = modinv(e, q)
    u = (e_inv * s) % q
    v = (-e_inv * r) % q
    W_ = u * G + v * pubkey.Y
    if W_.x is None:
        return False
    r_ = W_.x % q
    return r == r_
863
864
865# *| IUF - ECGDSA signature
866# *|
867# *|  UF 1. Compute h = H(m). If |h| > bitlen(q), set h to bitlen(q)
868# *|         leftmost (most significant) bits of h
869# *|   F 2. Convert e = - OS2I(h) mod q
870# *|   F 3. Get a random value k in ]0,q[
871# *|   F 4. Compute W = (W_x,W_y) = kG
872# *|   F 5. Compute r = W_x mod q
873# *|   F 6. If r is 0, restart the process at step 4.
874# *|   F 7. Compute s = x(kr + e) mod q
875# *|   F 8. If s is 0, restart the process at step 4.
876# *|   F 9. Return (r,s)
def ecgdsa_sign(hashfunc, keypair, message, k=None):
    """
    EC-GDSA signature of message with keypair's private key (steps above).

    When k is None a fresh random nonce is drawn for every attempt; a
    caller-provided k is used as is and Exception is raised if it hits a
    restart condition (retrying with the same k used to loop forever).
    Returns (r || s, k).
    """
    privkey = keypair.privkey
    # Get important parameters from the curve
    q = privkey.curve.q
    G = Point(privkey.curve, privkey.curve.gx, privkey.curve.gy)
    (h, _, _) = hashfunc(message)
    q_limit_len = getbitlen(q)
    # Truncate hash value
    h = truncate(h, q_limit_len, "LEFT")
    e = (-stringtoint(h)) % q
    fixed_k = k
    while True:
        k = getrandomint(q) if fixed_k is None else fixed_k
        # s stays 0 whenever a restart condition is met
        s = 0
        # k must be non-zero modulo q (getrandomint(q) may return q)
        if k % q != 0:
            W = k * G
            r = W.x % q
            if r != 0:
                s = (privkey.x * ((k * r) + e)) % q
        if s != 0:
            break
        if fixed_k is not None:
            raise Exception("ECGDSA sign: provided nonce k does not yield a valid signature")
    return (expand(inttostring(r), 8*getbytelen(q), "LEFT") + expand(inttostring(s), 8*getbytelen(q), "LEFT"), k)
905
906# *| IUF - ECGDSA verification
907# *|
908# *| I   1. Reject the signature if r or s is 0.
909# *|  UF 2. Compute h = H(m). If |h| > bitlen(q), set h to bitlen(q)
910# *|         leftmost (most significant) bits of h
911# *|   F 3. Compute e = OS2I(h) mod q
912# *|   F 4. Compute u = ((r^-1)e mod q)
913# *|   F 5. Compute v = ((r^-1)s mod q)
914# *|   F 6. Compute W' = uG + vY
915# *|   F 7. Compute r' = W'_x mod q
916# *|   F 8. Accept the signature if and only if r equals r'
def ecgdsa_verify(hashfunc, keypair, message, sig):
    """
    Verify an ECGDSA signature sig = r || s of message against keypair's public key.

    Returns True if the recomputed r' equals r, False otherwise.
    Raises Exception if sig has the wrong length or if r or s is not in ]0,q[.
    """
    pubkey = keypair.pubkey
    curve = pubkey.curve
    q = curve.q
    G = Point(curve, curve.gx, curve.gy)
    qlen = getbytelen(q)
    # Extract r and s from the signature
    if len(sig) != 2*qlen:
        raise Exception("ECGDSA verify: bad signature length!")
    r = stringtoint(sig[:int(qlen)])
    s = stringtoint(sig[int(qlen):int(2*qlen)])
    # ]0,q[ is an open interval: q itself (== 0 mod q) must be rejected too
    if r == 0 or r >= q:
        raise Exception("ECGDSA verify: r not in ]0,q[")
    if s == 0 or s >= q:
        raise Exception("ECGDSA verify: s not in ]0,q[")
    (h, _, _) = hashfunc(message)
    # Keep only the bitlen(q) leftmost (most significant) bits of the digest
    h = truncate(h, getbitlen(q), "LEFT")
    e = stringtoint(h) % q
    r_inv = modinv(r, q)
    u = (r_inv * e) % q
    v = (r_inv * s) % q
    W_ = u * G + v * pubkey.Y
    r_ = W_.x % q
    return r == r_
948
949# *| IUF - ECSDSA/ECOSDSA signature
950# *|
951# *| I   1. Get a random value k in ]0, q[
952# *| I   2. Compute W = kG = (Wx, Wy)
953# *| IUF 3. Compute r = H(Wx [|| Wy] || m)
954# *|        - In the normal version (ECSDSA), r = h(Wx || Wy || m).
955# *|        - In the optimized version (ECOSDSA), r = h(Wx || m).
956# *|   F 4. Compute e = OS2I(r) mod q
957# *|   F 5. if e == 0, restart at step 1.
958# *|   F 6. Compute s = (k + ex) mod q.
959# *|   F 7. if s == 0, restart at step 1.
960# *|   F 8. Return (r, s)
def ecsdsa_common_sign(hashfunc, keypair, message, optimized, k=None):
    """
    Shared ECSDSA / ECOSDSA signing routine.

    r = H(Wx || Wy || m) when optimized is false (ECSDSA), r = H(Wx || m)
    when it is true (ECOSDSA), with W = kG and coordinates expanded to
    getbytelen(p) bytes. When k is given it is used as the nonce; otherwise
    a fresh random nonce is drawn on every attempt.

    Returns (sig, k): sig is the digest r followed by s expanded to
    getbytelen(q) bytes, and k is the nonce actually used.

    Raises Exception when a caller-supplied k cannot produce a valid
    signature (k == 0, e == 0 or s == 0): retrying with the same nonce
    would never terminate.
    """
    privkey = keypair.privkey
    curve = privkey.curve
    p_bitlen_padded = 8*getbytelen(curve.p)
    q = curve.q
    G = Point(curve, curve.gx, curve.gy)
    # Fresh nonces are only drawn when the caller did not impose one
    fixed_k = k is not None
    while True:
        if not fixed_k:
            k = getrandomint(q)
        if k != 0:
            W = k * G
            to_hash = expand(inttostring(W.x), p_bitlen_padded, "LEFT")
            if not optimized:
                to_hash += expand(inttostring(W.y), p_bitlen_padded, "LEFT")
            (r, _, _) = hashfunc(to_hash + message)
            e = stringtoint(r) % q
            if e != 0:
                s = (k + (e * privkey.x)) % q
                if s != 0:
                    break
        if fixed_k:
            raise Exception("EC[O]SDSA sign: provided k yields an invalid signature (k, e or s is 0)")
    return (r + expand(inttostring(s), 8*getbytelen(q), "LEFT"), k)
988
def ecsdsa_sign(hashfunc, keypair, message, k=None):
    """ECSDSA signature (r = H(Wx || Wy || m)); returns (sig, k)."""
    return ecsdsa_common_sign(hashfunc, keypair, message, optimized=False, k=k)
991
def ecosdsa_sign(hashfunc, keypair, message, k=None):
    """ECOSDSA (optimized) signature (r = H(Wx || m)); returns (sig, k)."""
    return ecsdsa_common_sign(hashfunc, keypair, message, optimized=True, k=k)
994
995# *| IUF - ECSDSA/ECOSDSA verification
996# *|
# *| I   1. if s is not in ]0,q[, reject the signature.
998# *| I   2. Compute e = -r mod q
999# *| I   3. If e == 0, reject the signature.
1000# *| I   4. Compute W' = sG + eY
1001# *| IUF 5. Compute r' = H(W'x [|| W'y] || m)
# *|        - In the normal version (ECSDSA), r' = h(W'x || W'y || m).
# *|        - In the optimized version (ECOSDSA), r' = h(W'x || m).
1004# *|   F 6. Accept the signature if and only if r and r' are the same
def ecsdsa_common_verify(hashfunc, keypair, message, sig, optimized):
    """
    Shared ECSDSA / ECOSDSA verification routine.

    sig is the digest r (hlen bytes, hlen being the second element returned
    by hashfunc) followed by s (getbytelen(q) bytes). r' is recomputed as
    H(W'x || W'y || m) (ECSDSA) or H(W'x || m) (ECOSDSA, optimized true),
    with W' = sG + eY and e = -r mod q.

    Returns True if r' equals r, False otherwise.
    Raises Exception on a bad signature length, s not in ]0,q[ or e == 0.
    """
    pubkey = keypair.pubkey
    curve = pubkey.curve
    p_bitlen_padded = 8*getbytelen(curve.p)
    q = curve.q
    qlen = getbytelen(q)
    G = Point(curve, curve.gx, curve.gy)
    # Only the digest length is needed here
    (_, hlen, _) = hashfunc("")
    # Extract r and s from the signature
    if len(sig) != hlen + qlen:
        raise Exception("EC[O]SDSA verify: bad signature length!")
    r_str = sig[:int(hlen)]
    r = stringtoint(r_str)
    s = stringtoint(sig[int(hlen):int(hlen+qlen)])
    # ]0,q[ is an open interval: q itself (== 0 mod q) must be rejected too
    if s == 0 or s >= q:
        raise Exception("EC[O]SDSA verify: s not in ]0,q[")
    e = (-r) % q
    if e == 0:
        raise Exception("EC[O]SDSA verify: e is null")
    W_ = s * G + e * pubkey.Y
    to_hash = expand(inttostring(W_.x), p_bitlen_padded, "LEFT")
    if not optimized:
        to_hash += expand(inttostring(W_.y), p_bitlen_padded, "LEFT")
    (r_, _, _) = hashfunc(to_hash + message)
    return r_str == r_
1033
def ecsdsa_verify(hashfunc, keypair, message, sig):
    """ECSDSA verification (r' = H(W'x || W'y || m)); returns a bool."""
    return ecsdsa_common_verify(hashfunc, keypair, message, sig, optimized=False)
1036
def ecosdsa_verify(hashfunc, keypair, message, sig):
    """ECOSDSA (optimized) verification (r' = H(W'x || m)); returns a bool."""
    return ecsdsa_common_verify(hashfunc, keypair, message, sig, optimized=True)
1039
1040
1041##########################################################
1042### Generate self-tests for all the algorithms
1043
# (hash function, hash name) pairs iterated by gen_self_tests; the name is
# used both in the WITH_HASH_* guards and as the C .hash_type value
all_hash_funcs = [
    (sha224, "SHA224"),
    (sha256, "SHA256"),
    (sha384, "SHA384"),
    (sha512, "SHA512"),
    (sha3_224, "SHA3_224"),
    (sha3_256, "SHA3_256"),
    (sha3_384, "SHA3_384"),
    (sha3_512, "SHA3_512"),
]
1045
# (sign, verify, key pair generator, algorithm name) tuples iterated by
# gen_self_tests.
# NOTE(review): ECGDSA verification (W' = uG + vY) only recovers kG when
# Y = x^-1 G, presumably why it reuses eckcdsa_genKeyPair -- confirm.
all_sig_algs = [
    (ecdsa_sign,   ecdsa_verify,   genKeyPair,         "ECDSA"),
    (eckcdsa_sign, eckcdsa_verify, eckcdsa_genKeyPair, "ECKCDSA"),
    (ecfsdsa_sign, ecfsdsa_verify, genKeyPair,         "ECFSDSA"),
    (ecrdsa_sign,  ecrdsa_verify,  genKeyPair,         "ECRDSA"),
    (ecgdsa_sign,  ecgdsa_verify,  eckcdsa_genKeyPair, "ECGDSA"),
    (ecsdsa_sign,  ecsdsa_verify,  genKeyPair,         "ECSDSA"),
    (ecosdsa_sign, ecosdsa_verify, genKeyPair,         "ECOSDSA"),
]
1053
1054
# Running counter of generated self-test batches, shown as progress by
# gen_self_test through pretty_print_curr_test
curr_test = 0
def pretty_print_curr_test(num_test, total_gen_tests):
    """
    Redraw an in-place "NNN/TTT" progress counter on stdout.

    The previous counter is erased with backspaces; a newline is emitted
    once num_test reaches total_gen_tests.
    """
    # Counter width = number of decimal digits of the total. Using the
    # string length avoids float rounding in log10 and a domain error on 0.
    num_decimal = len("%d" % abs(total_gen_tests))
    format_buf = "%0"+str(num_decimal)+"d/%0"+str(num_decimal)+"d"
    # Erase the previously drawn "NNN/TTT" (two numbers plus the slash)
    sys.stdout.write('\b'*((2*num_decimal)+1))
    sys.stdout.flush()
    sys.stdout.write(format_buf % (num_test, total_gen_tests))
    if num_test == total_gen_tests:
        print("")
1065
def _hexlify_for_message(buf):
    """Return the hex dump of buf as text (accepts byte or text strings)."""
    # Text strings are encoded as latin-1 so binascii accepts them on Python 3
    if not isinstance(buf, bytes):
        buf = buf.encode('latin-1')
    return binascii.hexlify(buf).decode('ascii')

def _self_test_error(test_name, message, sig, k, keypair):
    """Build the exception raised when a freshly generated signature does not verify."""
    return Exception("Error during self test generation: sig verify failed! "+test_name+ "   /  msg="+message+"   /   sig="+_hexlify_for_message(sig)+"    /    k="+hex(k)+"   /   privkey.x="+hex(keypair.privkey.x))

def _format_self_test_case(curve, hashfunc_name, sig_alg_name, test_name, keypair, message, sig, k, guard=None):
    """
    Render one known-answer test case as C source.

    guard is None or a (directive, label) pair such as
    ("#ifndef USE_ISO14888_3_ECRDSA", "!USE_ISO14888_3_ECRDSA"); when given,
    both outputs are wrapped in that extra preprocessor conditional.

    Returns (out_name, out_vectors): the test-case table entry and the C
    definitions of the test vectors.
    """
    out_vectors = ""
    if guard is not None:
        out_vectors += guard[0]+"\n"
    # Now generate the test vector
    out_vectors += "#ifdef WITH_HASH_"+hashfunc_name.upper()+"\n"
    out_vectors += "#ifdef WITH_CURVE_"+curve.name.upper()+"\n"
    out_vectors += "#ifdef WITH_SIG_"+sig_alg_name.upper()+"\n"
    out_vectors += "/* "+test_name+" known test vectors */\n"
    out_vectors += "static int "+test_name+"_test_vectors_get_random(nn_t out, nn_src_t q)\n{\n"
    # k_buf MUST be exported padded to the length of q
    out_vectors += "\tconst u8 k_buf[] = "+bigint_to_C_array(k, getbytelen(curve.q))
    out_vectors += "\tint ret, cmp;\n\tret = nn_init_from_buf(out, k_buf, sizeof(k_buf)); EG(ret, err);\n\tret = nn_cmp(out, q, &cmp); EG(ret, err);\n\tret = (cmp >= 0) ? -1 : 0;\nerr:\n\treturn ret;\n}\n"
    out_vectors += "static const u8 "+test_name+"_test_vectors_priv_key[] = \n"+bigint_to_C_array(keypair.privkey.x, getbytelen(keypair.privkey.x))
    out_vectors += "static const u8 "+test_name+"_test_vectors_expected_sig[] = \n"+bigint_to_C_array(stringtoint(sig), len(sig))
    out_vectors += "static const ec_test_case "+test_name+"_test_case = {\n"
    out_vectors += "\t.name = \""+test_name+"\",\n"
    out_vectors += "\t.ec_str_p = &"+curve.name+"_str_params,\n"
    out_vectors += "\t.priv_key = "+test_name+"_test_vectors_priv_key,\n"
    out_vectors += "\t.priv_key_len = sizeof("+test_name+"_test_vectors_priv_key),\n"
    out_vectors += "\t.nn_random = "+test_name+"_test_vectors_get_random,\n"
    out_vectors += "\t.hash_type = "+hashfunc_name+",\n"
    out_vectors += "\t.msg = \""+message+"\",\n"
    out_vectors += "\t.msglen = "+str(len(message))+",\n"
    out_vectors += "\t.sig_type = "+sig_alg_name+",\n"
    out_vectors += "\t.exp_sig = "+test_name+"_test_vectors_expected_sig,\n"
    out_vectors += "\t.exp_siglen = sizeof("+test_name+"_test_vectors_expected_sig),\n};\n"
    out_vectors += "#endif /* WITH_HASH_"+hashfunc_name+" */\n"
    out_vectors += "#endif /* WITH_CURVE_"+curve.name+" */\n"
    out_vectors += "#endif /* WITH_SIG_"+sig_alg_name+" */\n"
    if guard is not None:
        out_vectors += "#endif /* "+guard[1]+" */\n"

    out_name = ""
    if guard is not None:
        out_name += guard[0]+"/* For "+test_name+" */\n"
    out_name += "#ifdef WITH_HASH_"+hashfunc_name.upper()+"/* For "+test_name+" */\n"
    out_name += "#ifdef WITH_CURVE_"+curve.name.upper()+"/* For "+test_name+" */\n"
    out_name += "#ifdef WITH_SIG_"+sig_alg_name.upper()+"/* For "+test_name+" */\n"
    out_name += "\t&"+test_name+"_test_case,\n"
    out_name += "#endif /* WITH_HASH_"+hashfunc_name+" for "+test_name+" */\n"
    out_name += "#endif /* WITH_CURVE_"+curve.name+" for "+test_name+" */\n"
    out_name += "#endif /* WITH_SIG_"+sig_alg_name+" for "+test_name+" */"
    if guard is not None:
        out_name += "\n#endif /* "+guard[1]+" */"+"/* For "+test_name+" */"
    return (out_name, out_vectors)

def gen_self_test(curve, hashfunc, sig_alg_sign, sig_alg_verify, sig_alg_genkeypair, num, hashfunc_name, sig_alg_name, total_gen_tests):
    """
    Generate num random known-answer tests for one (curve, hash, algorithm) triple.

    Each test uses a fresh key pair and a random alphanumeric message. The
    signature is checked with sig_alg_verify before being exported. For
    ECRDSA, an extra ISO/IEC 14888-3 variant of each test is generated, and
    the two variants are guarded by mutually exclusive
    USE_ISO14888_3_ECRDSA conditionals.

    Returns a list of (out_name, out_vectors) C source snippets.
    Raises Exception if a generated signature does not verify.
    """
    global curr_test
    curr_test = curr_test + 1
    if num != 0:
        pretty_print_curr_test(curr_test, total_gen_tests)
    alphabet = string.ascii_letters + string.digits
    output_list = []
    for test_num in range(0, num):
        # Generate a random key pair
        keypair = sig_alg_genkeypair(curve)
        # Generate a random message with a random size
        size = getrandomint(256)
        message = ''.join(random.choice(alphabet) for _ in range(size))
        test_name = sig_alg_name + "_" + hashfunc_name + "_" + curve.name.upper() + "_" + str(test_num)
        # Sign the message, then check that everything is OK with a verify
        (sig, k) = sig_alg_sign(hashfunc, keypair, message)
        if not sig_alg_verify(hashfunc, keypair, message, sig):
            raise _self_test_error(test_name, message, sig, k, keypair)
        guard = None
        if sig_alg_name == "ECRDSA":
            guard = ("#ifndef USE_ISO14888_3_ECRDSA", "!USE_ISO14888_3_ECRDSA")
        output_list.append(_format_self_test_case(curve, hashfunc_name, sig_alg_name, test_name, keypair, message, sig, k, guard))
        # In the specific case of ECRDSA, we also generate an ISO/IEC compatible test vector
        if sig_alg_name == "ECRDSA":
            (sig, k) = sig_alg_sign(hashfunc, keypair, message, use_iso14888_divergence=True)
            if not sig_alg_verify(hashfunc, keypair, message, sig, use_iso14888_divergence=True):
                raise _self_test_error(test_name, message, sig, k, keypair)
            iso_guard = ("#ifdef USE_ISO14888_3_ECRDSA", "USE_ISO14888_3_ECRDSA")
            output_list.append(_format_self_test_case(curve, hashfunc_name, sig_alg_name, test_name, keypair, message, sig, k, iso_guard))

    return output_list
1179
def gen_self_tests(curve, num):
    """
    Generate num self-tests on curve for every (signature algorithm, hash) pair.

    Returns a list (one entry per signature algorithm, in all_sig_algs
    order) of lists (one entry per hash, in all_hash_funcs order) of
    gen_self_test results.
    """
    global curr_test
    # Restart the progress counter for this curve
    curr_test = 0
    total_gen_tests = len(all_hash_funcs) * len(all_sig_algs)
    vectors = []
    for (sign, verify, genkp, sig_alg_name) in all_sig_algs:
        per_alg = []
        for (hashf, hash_name) in all_hash_funcs:
            per_alg.append(gen_self_test(curve, hashf, sign, verify, genkp, num, hash_name, sig_alg_name, total_gen_tests))
        vectors.append(per_alg)
    return vectors
1187
1188##########################################################
1189### ASN.1 stuff
def parse_DER_extract_size(derbuf):
    """
    Parse the DER length field at the head of derbuf (the tag already removed).

    Returns (True, header_len, length): header_len is the number of bytes
    used by the length field itself and length is the announced content
    length. Returns (False, 0, 0) when the length field is malformed, or
    when derbuf is too short to hold it plus the content it announces.
    """
    def _byte(c):
        # Text items (Python 2 str / latin-1 text) vs int items (Python 3 bytes)
        return c if isinstance(c, int) else ord(c)

    if len(derbuf) == 0:
        return (False, 0, 0)
    first = _byte(derbuf[0])
    if first & 0x80:
        # Long form: the low 7 bits give the number of length bytes that follow
        num_len_bytes = first & 0x7f
        if num_len_bytes == 0:
            # 0x80 is the BER indefinite length, which DER forbids
            return (False, 0, 0)
        header_len = 1 + num_len_bytes
        length = 0
        for c in derbuf[1:header_len]:
            length = (length << 8) | _byte(c)
    else:
        # Short form: the byte itself is the length
        header_len = 1
        length = first
    # Covers both a truncated length field and truncated content
    if len(derbuf) < header_len + length:
        return (False, 0, 0)
    return (True, header_len, length)
1207
def extract_DER_object(derbuf, object_tag):
    """
    Extract the DER TLV with tag object_tag from the head of derbuf.

    Returns (True, consumed, value): consumed is the total TLV size (tag +
    length field + content) and value the content. Returns (False, 0, "")
    if derbuf is too short, starts with another tag, or is truncated.
    """
    # A TLV needs at least a tag byte and a length byte
    if len(derbuf) < 2:
        return (False, 0, "")
    tag = derbuf[0]
    if not isinstance(tag, int):
        tag = ord(tag)
    if tag != object_tag:
        # Not the type we expect ...
        return (False, 0, "")
    derbuf = derbuf[1:]
    # Extract the size
    (check, encoding_len, size) = parse_DER_extract_size(derbuf)
    if not check or len(derbuf) < encoding_len + size:
        return (False, 0, "")
    # +1 accounts for the tag byte stripped above
    return (True, size+encoding_len+1, derbuf[encoding_len:encoding_len+size])
1224
# DER universal tags understood by the extraction helpers below
_DER_TAG_INTEGER = 0x02
_DER_TAG_BIT_STRING = 0x03
_DER_TAG_OCTET_STRING = 0x04
_DER_TAG_OID = 0x06
_DER_TAG_SEQUENCE = 0x30

def extract_DER_sequence(derbuf):
    """Extract a DER SEQUENCE from the head of derbuf (see extract_DER_object)."""
    return extract_DER_object(derbuf, _DER_TAG_SEQUENCE)

def extract_DER_integer(derbuf):
    """Extract a DER INTEGER from the head of derbuf (see extract_DER_object)."""
    return extract_DER_object(derbuf, _DER_TAG_INTEGER)

def extract_DER_octetstring(derbuf):
    """Extract a DER OCTET STRING from the head of derbuf (see extract_DER_object)."""
    return extract_DER_object(derbuf, _DER_TAG_OCTET_STRING)

def extract_DER_bitstring(derbuf):
    """Extract a DER BIT STRING from the head of derbuf (see extract_DER_object)."""
    return extract_DER_object(derbuf, _DER_TAG_BIT_STRING)

def extract_DER_oid(derbuf):
    """Extract a DER OBJECT IDENTIFIER from the head of derbuf (see extract_DER_object)."""
    return extract_DER_object(derbuf, _DER_TAG_OID)
1239
1240# See ECParameters sequence in RFC 3279
def parse_DER_ECParameters(derbuf):
    """
    Parse a DER-encoded explicit ECParameters structure over a prime field.

    Returns (True, (a, b, prime, order, cofactor, gx, gy)) on success, or
    (False, (0, 0, 0, 0, 0, 0, 0)) on any parse error, so that callers can
    always unpack the second element as a 7-tuple.
    """
    # XXX: this is a very ugly way of extracting the information
    # regarding an EC curve, but since the ASN.1 structure is quite
    # "static", this might be sufficient without embedding a full
    # ASN.1 parser ...
    # Default return (a, b, prime, order, cofactor, gx, gy)
    default_ret = (0, 0, 0, 0, 0, 0, 0)
    # Get ECParameters wrapping sequence
    (check, _, ECParameters) = extract_DER_sequence(derbuf)
    if not check:
        return (False, default_ret)
    # Get version integer (its value is not checked)
    (check, offset, _) = extract_DER_integer(ECParameters)
    if not check:
        return (False, default_ret)
    # Get FieldID sequence
    (check, size_FieldID, FieldID) = extract_DER_sequence(ECParameters[offset:])
    if not check:
        return (False, default_ret)
    offset += size_FieldID
    # Get field type OID
    (check, size_Oid, Oid) = extract_DER_oid(FieldID)
    if not check:
        return (False, default_ret)
    # Does the OID correspond to a prime field (1.2.840.10045.1.1)?
    if Oid != "\x2A\x86\x48\xCE\x3D\x01\x01":
        print("DER parse error: only prime fields are supported ...")
        return (False, default_ret)
    # Get prime p of prime field
    (check, _, P) = extract_DER_integer(FieldID[size_Oid:])
    if not check:
        return (False, default_ret)
    # Get curve (sequence)
    (check, size_Curve, Curve) = extract_DER_sequence(ECParameters[offset:])
    if not check:
        return (False, default_ret)
    offset += size_Curve
    # Get A in curve
    (check, size_A, A) = extract_DER_octetstring(Curve)
    if not check:
        return (False, default_ret)
    # Get B in curve
    (check, _, B) = extract_DER_octetstring(Curve[size_A:])
    if not check:
        return (False, default_ret)
    # Get base point (ECPoint)
    (check, size_ECPoint, ECPoint) = extract_DER_octetstring(ECParameters[offset:])
    if not check:
        return (False, default_ret)
    offset += size_ECPoint
    # Get Order
    (check, size_Order, Order) = extract_DER_integer(ECParameters[offset:])
    if not check:
        return (False, default_ret)
    offset += size_Order
    # Get Cofactor
    (check, _, Cofactor) = extract_DER_integer(ECParameters[offset:])
    if not check:
        return (False, default_ret)
    # If we end up here, everything is OK, we can extract all our elements
    prime = stringtoint(P)
    a = stringtoint(A)
    b = stringtoint(B)
    order = stringtoint(Order)
    cofactor = stringtoint(Cofactor)
    # Extract Gx and Gy, see X9.62-1998
    if len(ECPoint) < 1:
        return (False, default_ret)
    ECPoint_type = ord(ECPoint[0])
    if ECPoint_type in (0x04, 0x06, 0x07):
        # Uncompressed and hybrid points: X || Y with equal-length halves
        if len(ECPoint[1:]) % 2 != 0:
            return (False, default_ret)
        ECPoint = ECPoint[1:]
        half = len(ECPoint) // 2
        gx = stringtoint(ECPoint[:half])
        gy = stringtoint(ECPoint[half:])
    elif ECPoint_type in (0x02, 0x03):
        # Compressed point: uncompress it, see X9.62-1998 section 4.2.1
        gx = stringtoint(ECPoint[1:])
        alpha = (pow(gx, 3, prime) + (a * gx) + b) % prime
        beta = mod_sqrt(alpha, prime)
        if (beta is None) or ((beta == 0) and (alpha != 0)):
            # Keep the (check, 7-tuple) shape of every other error path
            return (False, default_ret)
        # The low bit of the point type gives the parity of y
        if (beta & 0x1) == (ECPoint_type & 0x1):
            gy = beta
        else:
            gy = prime - beta
    else:
        print("DER parse error: unknown point encoding type!")
        return (False, default_ret)
    return (True, (a, b, prime, order, cofactor, gx, gy))
1329
1330##########################################################
1331### Text and format helpers
def bigint_to_C_array(bint, size):
    """
    Format a Python big int as the body of a C u8 array initializer.

    The big-endian hex digits are left-padded with zeros to size bytes (never
    truncated) and emitted eight bytes per tab-indented line, in the form
    "{\n\t0x.., 0x.., \n};\n".
    """
    digits = format(int(bint), 'x').rjust(int(2 * size), '0')
    # Odd digit count: complete the leading byte
    if len(digits) % 2:
        digits = '0' + digits
    byte_items = ["0x" + digits[pos:pos + 2] + ", " for pos in range(0, len(digits), 2)]
    rows = ["\t" + "".join(byte_items[pos:pos + 8]) for pos in range(0, len(byte_items), 8)]
    return "{\n" + "\n".join(rows) + "\n};\n"
1349
def check_in_file(fname, pat):
    """Return True if regex pat matches at least one line of file fname, else False."""
    with open(fname) as handle:
        return any(re.search(pat, line) for line in handle)
1357
def num_patterns_in_file(fname, pat):
    """Count the lines of file fname matched by regex pat (at most one per line)."""
    with open(fname) as handle:
        return sum(1 for line in handle if re.search(pat, line))
1365
def file_replace_pattern(fname, pat, s_after):
    """
    Replace, line by line, every match of regex pat by s_after in file fname.

    The file is left untouched when pat matches no line. Otherwise the result
    is written to fname + ".tmp", which then replaces fname.
    """
    with open(fname) as f:
        lines = f.readlines()
    # First, see if the pattern is even in the file
    if not any(re.search(pat, line) for line in lines):
        return
    out_fname = fname + ".tmp"
    # 'with' guarantees the temporary file is closed even if a write fails
    with open(out_fname, "w") as out:
        for line in lines:
            out.write(re.sub(pat, s_after, line))
    # The input is closed by now; removing the destination first keeps the
    # rename working where os.rename refuses to overwrite (Windows)
    if os.path.exists(fname):
        os.remove(fname)
    os.rename(out_fname, fname)
1380
def file_remove_pattern(fname, pat):
    """
    Delete from file fname every line matched by regex pat.

    The file is left untouched when pat matches no line. Otherwise the kept
    lines are written to fname + ".tmp", which then replaces fname.
    """
    with open(fname) as f:
        lines = f.readlines()
    # First, see if the pattern is even in the file
    if not any(re.search(pat, line) for line in lines):
        return
    out_fname = fname + ".tmp"
    # 'with' guarantees the temporary file is closed even if a write fails
    with open(out_fname, "w") as out:
        out.writelines(line for line in lines if not re.search(pat, line))
    # Remove the destination first: os.rename does not overwrite on Windows
    if os.path.exists(fname):
        os.remove(fname)
    os.rename(out_fname, fname)
1399
def remove_file(fname):
    """Delete file fname; errors (e.g. a missing file) propagate to the caller."""
    os.unlink(fname)
1403
def remove_files_pattern(fpattern):
    """Remove every file whose path matches the glob pattern fpattern."""
    # Plain loop: remove_file is called only for its side effect
    for path in glob.glob(fpattern):
        remove_file(path)
1406
def buffer_remove_pattern(buff, pat):
    """
    Remove every match of regex pat from buff.

    On Python 3, buff is first decoded from bytes as latin-1. Returns
    (found, result): found tells whether pat matched at all, and result is
    the (decoded) buffer with all matches removed.
    """
    if sys.version_info[0] >= 3:
        buff = buff.decode('latin-1')
    stripped, count = re.subn(pat, "", buff)
    return (count > 0, stripped)
1415
def is_base64(s):
    """
    Tell whether s (possibly split over several lines) is canonical base64.

    Lines are stripped and concatenated. The text is accepted if and only if
    decoding and re-encoding it gives back exactly the same text. Returns
    False, rather than raising, on undecodable input.
    """
    joined = ''.join([line.strip() for line in s.split("\n")])
    try:
        enc = base64.b64encode(base64.b64decode(joined)).strip()
    except (TypeError, ValueError, binascii.Error):
        # Python 2 reports bad padding as TypeError. Python 3 raises
        # binascii.Error (a ValueError subclass), and ValueError for non-ASCII text.
        return False
    if type(enc) is bytes:
        return enc == joined.encode('latin-1')
    return enc == joined
1426
1427### Curve helpers
def export_curve_int(curvename, intname, bigint, size):
    """
    Emit the C u8 array <curvename>_<intname> and its EC string parameter.

    A None value is exported as a one-byte zero array registered with
    TO_EC_STR_PARAM_FIXED_SIZE(..., 0); otherwise bigint is padded to size
    bytes and registered with TO_EC_STR_PARAM.
    """
    symbol = curvename + "_" + intname
    if bigint is None:
        return ("static const u8 " + symbol + "[] = {\n\t0x00,\n};\n" +
                "TO_EC_STR_PARAM_FIXED_SIZE(" + symbol + ", 0);\n\n")
    return ("static const u8 " + symbol + "[] = " + bigint_to_C_array(bigint, size) + "\n" +
            "TO_EC_STR_PARAM(" + symbol + ");\n\n")
1436
def export_curve_string(curvename, stringname, stringvalue):
    """Emit a C string literal array <curvename>_<stringname> and its EC string parameter."""
    symbol = curvename + "_" + stringname
    declaration = "static const u8 " + symbol + "[] = \"" + stringvalue + "\";\n"
    registration = "TO_EC_STR_PARAM(" + symbol + ");\n\n"
    return declaration + registration
1441
def export_curve_struct(curvename, paramname, paramnamestr):
    """Emit one ec_str_params initializer line pointing to <curvename>_<paramnamestr>_str_param."""
    target = curvename + "_" + paramnamestr + "_str_param"
    return "\t." + paramname + " = &" + target + ", \n"
1444
1445def curve_params(name, prime, pbitlen, a, b, gx, gy, order, cofactor, oid, alpha_montgomery, gamma_montgomery, alpha_edwards):
1446    """
1447    Take as input some elliptic curve parameters and generate the
1448    C parameters in a string
1449    """
1450    bytesize = int(pbitlen / 8)
1451    if pbitlen % 8 != 0:
1452        bytesize += 1
1453    # Compute the rounded word size for each word size
1454    if bytesize % 8 != 0:
1455        wordsbitsize64 = 8*((int(bytesize/8)+1)*8)
1456    else:
1457        wordsbitsize64 = 8*bytesize
1458    if bytesize % 4 != 0:
1459        wordsbitsize32 = 8*((int(bytesize/4)+1)*4)
1460    else:
1461        wordsbitsize32 = 8*bytesize
1462    if bytesize % 2 != 0:
1463        wordsbitsize16 = 8*((int(bytesize/2)+1)*2)
1464    else:
1465        wordsbitsize16 = 8*bytesize
1466    # Compute some parameters
1467    (r64, r_square64, mpinv64) = compute_monty_coef(prime, wordsbitsize64, 64)
1468    (r32, r_square32, mpinv32) = compute_monty_coef(prime, wordsbitsize32, 32)
1469    (r16, r_square16, mpinv16) = compute_monty_coef(prime, wordsbitsize16, 16)
1470    # Compute p_reciprocal for each word size
1471    (pshift64, primenorm64, p_reciprocal64) = compute_div_coef(prime, wordsbitsize64, 64)
1472    (pshift32, primenorm32, p_reciprocal32) = compute_div_coef(prime, wordsbitsize32, 32)
1473    (pshift16, primenorm16, p_reciprocal16) = compute_div_coef(prime, wordsbitsize16, 16)
1474    # Compute the number of points on the curve
1475    npoints = order * cofactor
1476
1477    # Now output the parameters
1478    ec_params_string =  "#include <libecc/lib_ecc_config.h>\n"
1479    ec_params_string += "#ifdef WITH_CURVE_"+name.upper()+"\n\n"
1480    ec_params_string += "#ifndef __EC_PARAMS_"+name.upper()+"_H__\n"
1481    ec_params_string += "#define __EC_PARAMS_"+name.upper()+"_H__\n"
1482    ec_params_string += "#include <libecc/curves/known/ec_params_external.h>\n"
1483    ec_params_string += export_curve_int(name, "p", prime, bytesize)
1484
1485    ec_params_string += "#define CURVE_"+name.upper()+"_P_BITLEN "+str(pbitlen)+"\n"
1486    ec_params_string += export_curve_int(name, "p_bitlen", pbitlen, getbytelen(pbitlen))
1487
1488    ec_params_string += "#if (WORD_BYTES == 8)     /* 64-bit words */\n"
1489    ec_params_string += export_curve_int(name, "r", r64, getbytelen(r64))
1490    ec_params_string += export_curve_int(name, "r_square", r_square64, getbytelen(r_square64))
1491    ec_params_string += export_curve_int(name, "mpinv", mpinv64, getbytelen(mpinv64))
1492    ec_params_string += export_curve_int(name, "p_shift", pshift64, getbytelen(pshift64))
1493    ec_params_string += export_curve_int(name, "p_normalized", primenorm64, getbytelen(primenorm64))
1494    ec_params_string += export_curve_int(name, "p_reciprocal", p_reciprocal64, getbytelen(p_reciprocal64))
1495    ec_params_string += "#elif (WORD_BYTES == 4)   /* 32-bit words */\n"
1496    ec_params_string += export_curve_int(name, "r", r32, getbytelen(r32))
1497    ec_params_string += export_curve_int(name, "r_square", r_square32, getbytelen(r_square32))
1498    ec_params_string += export_curve_int(name, "mpinv", mpinv32, getbytelen(mpinv32))
1499    ec_params_string += export_curve_int(name, "p_shift", pshift32, getbytelen(pshift32))
1500    ec_params_string += export_curve_int(name, "p_normalized", primenorm32, getbytelen(primenorm32))
1501    ec_params_string += export_curve_int(name, "p_reciprocal", p_reciprocal32, getbytelen(p_reciprocal32))
1502    ec_params_string += "#elif (WORD_BYTES == 2)   /* 16-bit words */\n"
1503    ec_params_string += export_curve_int(name, "r", r16, getbytelen(r16))
1504    ec_params_string += export_curve_int(name, "r_square", r_square16, getbytelen(r_square16))
1505    ec_params_string += export_curve_int(name, "mpinv", mpinv16, getbytelen(mpinv16))
1506    ec_params_string += export_curve_int(name, "p_shift", pshift16, getbytelen(pshift16))
1507    ec_params_string += export_curve_int(name, "p_normalized", primenorm16, getbytelen(primenorm16))
1508    ec_params_string += export_curve_int(name, "p_reciprocal", p_reciprocal16, getbytelen(p_reciprocal16))
1509    ec_params_string += "#else                     /* unknown word size */\n"
1510    ec_params_string += "#error \"Unsupported word size\"\n"
1511    ec_params_string += "#endif\n\n"
1512
1513    ec_params_string += export_curve_int(name, "a", a, bytesize)
1514    ec_params_string += export_curve_int(name, "b", b, bytesize)
1515
1516    curve_order_bitlen = getbitlen(npoints)
1517    ec_params_string += "#define CURVE_"+name.upper()+"_CURVE_ORDER_BITLEN "+str(curve_order_bitlen)+"\n"
1518    ec_params_string += export_curve_int(name, "curve_order", npoints, getbytelen(npoints))
1519
1520    ec_params_string += export_curve_int(name, "gx", gx, bytesize)
1521    ec_params_string += export_curve_int(name, "gy", gy, bytesize)
1522    ec_params_string += export_curve_int(name, "gz", 0x01, bytesize)
1523
1524    qbitlen = getbitlen(order)
1525
1526    ec_params_string += export_curve_int(name, "gen_order", order, getbytelen(order))
1527    ec_params_string += "#define CURVE_"+name.upper()+"_Q_BITLEN "+str(qbitlen)+"\n"
1528    ec_params_string += export_curve_int(name, "gen_order_bitlen", qbitlen, getbytelen(qbitlen))
1529
1530    ec_params_string += export_curve_int(name, "cofactor", cofactor, getbytelen(cofactor))
1531
1532    ec_params_string += export_curve_int(name, "alpha_montgomery", alpha_montgomery, getbytelen(alpha_montgomery))
1533    ec_params_string += export_curve_int(name, "gamma_montgomery", gamma_montgomery, getbytelen(gamma_montgomery))
1534    ec_params_string += export_curve_int(name, "alpha_edwards", alpha_edwards, getbytelen(alpha_edwards))
1535
1536    ec_params_string += export_curve_string(name, "name", name.upper());
1537
1538    if oid == None:
1539        oid = ""
1540    ec_params_string += export_curve_string(name, "oid", oid);
1541
1542    ec_params_string += "static const ec_str_params "+name+"_str_params = {\n"+\
1543    export_curve_struct(name, "p", "p") +\
1544    export_curve_struct(name, "p_bitlen", "p_bitlen") +\
1545    export_curve_struct(name, "r", "r") +\
1546    export_curve_struct(name, "r_square", "r_square") +\
1547    export_curve_struct(name, "mpinv", "mpinv") +\
1548    export_curve_struct(name, "p_shift", "p_shift") +\
1549    export_curve_struct(name, "p_normalized", "p_normalized") +\
1550    export_curve_struct(name, "p_reciprocal", "p_reciprocal") +\
1551    export_curve_struct(name, "a", "a") +\
1552    export_curve_struct(name, "b", "b") +\
1553    export_curve_struct(name, "curve_order", "curve_order") +\
1554    export_curve_struct(name, "gx", "gx") +\
1555    export_curve_struct(name, "gy", "gy") +\
1556    export_curve_struct(name, "gz", "gz") +\
1557    export_curve_struct(name, "gen_order", "gen_order") +\
1558    export_curve_struct(name, "gen_order_bitlen", "gen_order_bitlen") +\
1559    export_curve_struct(name, "cofactor", "cofactor") +\
1560    export_curve_struct(name, "alpha_montgomery", "alpha_montgomery") +\
1561    export_curve_struct(name, "gamma_montgomery", "gamma_montgomery") +\
1562    export_curve_struct(name, "alpha_edwards", "alpha_edwards") +\
1563    export_curve_struct(name, "oid", "oid") +\
1564    export_curve_struct(name, "name", "name")
1565    ec_params_string += "};\n\n"
1566
1567    ec_params_string += "/*\n"+\
1568    " * Compute max bit length of all curves for p and q\n"+\
1569    " */\n"+\
1570    "#ifndef CURVES_MAX_P_BIT_LEN\n"+\
1571    "#define CURVES_MAX_P_BIT_LEN    0\n"+\
1572    "#endif\n"+\
1573    "#if (CURVES_MAX_P_BIT_LEN < CURVE_"+name.upper()+"_P_BITLEN)\n"+\
1574    "#undef CURVES_MAX_P_BIT_LEN\n"+\
1575    "#define CURVES_MAX_P_BIT_LEN CURVE_"+name.upper()+"_P_BITLEN\n"+\
1576    "#endif\n"+\
1577    "#ifndef CURVES_MAX_Q_BIT_LEN\n"+\
1578    "#define CURVES_MAX_Q_BIT_LEN    0\n"+\
1579    "#endif\n"+\
1580    "#if (CURVES_MAX_Q_BIT_LEN < CURVE_"+name.upper()+"_Q_BITLEN)\n"+\
1581    "#undef CURVES_MAX_Q_BIT_LEN\n"+\
1582    "#define CURVES_MAX_Q_BIT_LEN CURVE_"+name.upper()+"_Q_BITLEN\n"+\
1583    "#endif\n"+\
1584    "#ifndef CURVES_MAX_CURVE_ORDER_BIT_LEN\n"+\
1585    "#define CURVES_MAX_CURVE_ORDER_BIT_LEN    0\n"+\
1586    "#endif\n"+\
1587    "#if (CURVES_MAX_CURVE_ORDER_BIT_LEN < CURVE_"+name.upper()+"_CURVE_ORDER_BITLEN)\n"+\
1588    "#undef CURVES_MAX_CURVE_ORDER_BIT_LEN\n"+\
1589    "#define CURVES_MAX_CURVE_ORDER_BIT_LEN CURVE_"+name.upper()+"_CURVE_ORDER_BITLEN\n"+\
1590    "#endif\n\n"
1591
1592    ec_params_string += "/*\n"+\
1593    " * Compute and adapt max name and oid length\n"+\
1594    " */\n"+\
1595    "#ifndef MAX_CURVE_OID_LEN\n"+\
1596    "#define MAX_CURVE_OID_LEN 0\n"+\
1597    "#endif\n"+\
1598    "#ifndef MAX_CURVE_NAME_LEN\n"+\
1599    "#define MAX_CURVE_NAME_LEN 0\n"+\
1600    "#endif\n"+\
1601    "#if (MAX_CURVE_OID_LEN < "+str(len(oid)+1)+")\n"+\
1602    "#undef MAX_CURVE_OID_LEN\n"+\
1603    "#define MAX_CURVE_OID_LEN "+str(len(oid)+1)+"\n"+\
1604    "#endif\n"+\
1605    "#if (MAX_CURVE_NAME_LEN < "+str(len(name.upper())+1)+")\n"+\
1606    "#undef MAX_CURVE_NAME_LEN\n"+\
1607    "#define MAX_CURVE_NAME_LEN "+str(len(name.upper())+1)+"\n"+\
1608    "#endif\n\n"
1609
1610    ec_params_string += "#endif /* __EC_PARAMS_"+name.upper()+"_H__ */\n\n"+"#endif /* WITH_CURVE_"+name.upper()+" */\n"
1611
1612    return ec_params_string
1613
1614def usage():
1615    print("This script is intented to *statically* expand the ECC library with user defined curves.")
1616    print("By statically we mean that the source code of libecc is expanded with new curves parameters through")
1617    print("automatic code generation filling place holders in the existing code base of the library. Though the")
1618    print("choice of static code generation versus dynamic curves import (such as what OpenSSL does) might be")
1619    print("argued, this choice has been driven by simplicity and security design decisions: we want libecc to have")
1620    print("all its parameters (such as memory consumption) set at compile time and statically adapted to the curves.")
1621    print("Since libecc only supports curves over prime fields, the script can only add this kind of curves.")
1622    print("This script implements elliptic curves and ISO signature algorithms from scratch over Python's multi-precision")
1623    print("big numbers library. Addition and doubling over curves use naive formulas. Please DO NOT use the functions of this")
1624    print("script for production code: they are not securely implemented and are very inefficient. Their only purpose is to expand")
1625    print("libecc and produce test vectors.")
1626    print("")
1627    print("In order to add a curve, there are two ways:")
1628    print("Adding a user defined curve with explicit parameters:")
1629    print("-----------------------------------------------------")
1630    print(sys.argv[0]+" --name=\"YOURCURVENAME\" --prime=... --order=... --a=... --b=... --gx=... --gy=... --cofactor=... --oid=THEOID")
1631    print("\t> name: name of the curve in the form of a string")
1632    print("\t> prime: prime number representing the curve prime field")
1633    print("\t> order: prime number representing the generator order")
1634    print("\t> cofactor: cofactor of the curve")
1635    print("\t> a: 'a' coefficient of the short Weierstrass equation of the curve")
1636    print("\t> b: 'b' coefficient of the short Weierstrass equation of the curve")
1637    print("\t> gx: x coordinate of the generator G")
1638    print("\t> gy: y coordinate of the generator G")
1639    print("\t> oid: optional OID of the curve")
1640    print("  Notes:")
1641    print("  ******")
1642    print("\t1) These elements are verified to indeed satisfy the curve equation.")
1643    print("\t2) All the numbers can be given either in decimal or hexadecimal format with a prepending '0x'.")
1644    print("\t3) The script automatically generates all the necessary files for the curve to be included in the library." )
1645    print("\tYou will find the new curve definition in the usual 'lib_ecc_config.h' file (one can activate it or not at compile time).")
1646    print("")
1647    print("Adding a user defined curve through RFC3279 ASN.1 parameters:")
1648    print("-------------------------------------------------------------")
1649    print(sys.argv[0]+" --name=\"YOURCURVENAME\" --ECfile=... --oid=THEOID")
1650    print("\t> ECfile: the DER or PEM encoded file containing the curve parameters (see RFC3279)")
1651    print("  Notes:")
1652    print("\tCurve parameters encoded in DER or PEM format can be generated with tools like OpenSSL (among others). As an illustrative example,")
1653    print("\tone can list all the supported curves under OpenSSL with:")
1654    print("\t  $ openssl ecparam -list_curves")
1655    print("\tOnly the listed so called \"prime\" curves are supported. Then, one can extract an explicit curve representation in ASN.1")
1656    print("\tas defined in RFC3279, for example for BRAINPOOLP320R1:")
1657    print("\t  $ openssl ecparam -param_enc explicit -outform DER -name brainpoolP320r1 -out brainpoolP320r1.der")
1658    print("")
1659    print("Removing user defined curves:")
1660    print("-----------------------------")
1661    print("\t*All the user defined curves can be removed with the --remove-all toggle.")
1662    print("\t*A specific named user define curve can be removed with the --remove toggle: in this case the --name option is used to ")
1663    print("\tlocate which named curve must be deleted.")
1664    print("")
1665    print("Test vectors:")
1666    print("-------------")
1667    print("\tTest vectors can be automatically generated and added to the library self tests when providing the --add-test-vectors=X toggle.")
1668    print("\tIn this case, X test vectors will be generated for *each* (curve, sign algorithm, hash algorithm) 3-uplet (beware of combinatorial")
1669    print("\tissues when X is big!). These tests are transparently added and compiled with the self tests.")
1670    return
1671
1672def get_int(instring):
1673    if len(instring) == 0:
1674        return 0
1675    if len(instring) >= 2:
1676        if instring[:2] == "0x":
1677            return int(instring, 16)
1678    return int(instring)
1679
1680def parse_cmd_line(args):
1681    """
1682    Get elliptic curve parameters from command line
1683    """
1684    name = oid = prime = a = b = gx = gy = g = order = cofactor = ECfile = remove = remove_all = add_test_vectors = None
1685    alpha_montgomery = gamma_montgomery = alpha_edwards = None
1686    try:
1687        opts, args = getopt.getopt(sys.argv[1:], ":h", ["help", "remove", "remove-all", "name=", "prime=", "a=", "b=", "generator=", "gx=", "gy=", "order=", "cofactor=", "alpha_montgomery=","gamma_montgomery=", "alpha_edwards=", "ECfile=", "oid=", "add-test-vectors="])
1688    except getopt.GetoptError as err:
1689        # print help information and exit:
1690        print(err) # will print something like "option -a not recognized"
1691        usage()
1692        return False
1693    for o, arg in opts:
1694        if o in ("-h", "--help"):
1695            usage()
1696            return True
1697        elif o in ("--name"):
1698            name = arg
1699            # Prepend the custom string before name to avoid any collision
1700            name = "user_defined_"+name
1701            # Replace any unwanted name char
1702            name = re.sub("\-", "_", name)
1703        elif o in ("--oid="):
1704            oid = arg
1705        elif o in ("--prime"):
1706            prime = get_int(arg.replace(' ', ''))
1707        elif o in ("--a"):
1708            a = get_int(arg.replace(' ', ''))
1709        elif o in ("--b"):
1710            b = get_int(arg.replace(' ', ''))
1711        elif o in ("--gx"):
1712            gx = get_int(arg.replace(' ', ''))
1713        elif o in ("--gy"):
1714            gy = get_int(arg.replace(' ', ''))
1715        elif o in ("--generator"):
1716            g = arg.replace(' ', '')
1717        elif o in ("--order"):
1718            order = get_int(arg.replace(' ', ''))
1719        elif o in ("--cofactor"):
1720            cofactor = get_int(arg.replace(' ', ''))
1721        elif o in ("--alpha_montgomery"):
1722            alpha_montgomery = get_int(arg.replace(' ', ''))
1723        elif o in ("--gamma_montgomery"):
1724            gamma_montgomery = get_int(arg.replace(' ', ''))
1725        elif o in ("--alpha_edwards"):
1726            alpha_edwards = get_int(arg.replace(' ', ''))
1727        elif o in ("--remove"):
1728            remove = True
1729        elif o in ("--remove-all"):
1730            remove_all = True
1731        elif o in ("--add-test-vectors"):
1732            add_test_vectors = get_int(arg.replace(' ', ''))
1733        elif o in ("--ECfile"):
1734            ECfile = arg
1735        else:
1736            print("unhandled option")
1737            usage()
1738            return False
1739
1740    # File paths
1741    script_path = os.path.abspath(os.path.dirname(sys.argv[0])) + "/"
1742    ec_params_path = script_path + "../include/libecc/curves/user_defined/"
1743    curves_list_path = script_path + "../include/libecc/curves/"
1744    lib_ecc_types_path = script_path + "../include/libecc/"
1745    lib_ecc_config_path = script_path + "../include/libecc/"
1746    ec_self_tests_path = script_path + "../src/tests/"
1747    meson_options_path = script_path + "../"
1748
1749    # If remove is True, we have been asked to remove already existing user defined curves
1750    if remove == True:
1751        if name == None:
1752            print("--remove option expects a curve name provided with --name")
1753            return False
1754        asked = ""
1755        while asked != "y" and asked != "n":
1756            asked = get_user_input("You asked to remove everything related to user defined "+name.replace("user_defined_", "")+" curve. Enter y to confirm, n to cancel [y/n]. ")
1757        if asked == "n":
1758            print("NOT removing curve "+name.replace("user_defined_", "")+" (cancelled).")
1759            return True
1760        # Remove any user defined stuff with given name
1761        print("Removing user defined curve "+name.replace("user_defined_", "")+" ...")
1762        if name == None:
1763            print("Error: you must provide a curve name with --remove")
1764            return False
1765        file_remove_pattern(curves_list_path + "curves_list.h", ".*"+name+".*")
1766        file_remove_pattern(curves_list_path + "curves_list.h", ".*"+name.upper()+".*")
1767        file_remove_pattern(lib_ecc_types_path + "lib_ecc_types.h", ".*"+name.upper()+".*")
1768        file_remove_pattern(lib_ecc_config_path + "lib_ecc_config.h", ".*"+name.upper()+".*")
1769        file_remove_pattern(ec_self_tests_path + "ec_self_tests_core.h", ".*"+name+".*")
1770        file_remove_pattern(ec_self_tests_path + "ec_self_tests_core.h", ".*"+name.upper()+".*")
1771        file_remove_pattern(meson_options_path + "meson.options", ".*"+name.lower()+".*")
1772        try:
1773            remove_file(ec_params_path + "ec_params_"+name+".h")
1774        except:
1775            print("Error: curve name "+name+" does not seem to be present in the sources!")
1776            return False
1777        try:
1778            remove_file(ec_self_tests_path + "ec_self_tests_core_"+name+".h")
1779        except:
1780            print("Warning: curve name "+name+" self tests do not seem to be present ...")
1781            return True
1782        return True
1783    if remove_all == True:
1784        asked = ""
1785        while asked != "y" and asked != "n":
1786            asked = get_user_input("You asked to remove everything related to ALL user defined curves. Enter y to confirm, n to cancel [y/n]. ")
1787        if asked == "n":
1788            print("NOT removing user defined curves (cancelled).")
1789            return True
1790        # Remove any user defined stuff with given name
1791        print("Removing ALL user defined curves ...")
1792        # Remove any user defined stuff (whatever name)
1793        file_remove_pattern(curves_list_path + "curves_list.h", ".*user_defined.*")
1794        file_remove_pattern(curves_list_path + "curves_list.h", ".*USER_DEFINED.*")
1795        file_remove_pattern(lib_ecc_types_path + "lib_ecc_types.h", ".*USER_DEFINED.*")
1796        file_remove_pattern(lib_ecc_config_path + "lib_ecc_config.h", ".*USER_DEFINED.*")
1797        file_remove_pattern(ec_self_tests_path + "ec_self_tests_core.h", ".*USER_DEFINED.*")
1798        file_remove_pattern(ec_self_tests_path + "ec_self_tests_core.h", ".*user_defined.*")
1799        file_remove_pattern(meson_options_path + "meson.options", ".*user_defined.*")
1800        remove_files_pattern(ec_params_path + "ec_params_user_defined_*.h")
1801        remove_files_pattern(ec_self_tests_path + "ec_self_tests_core_user_defined_*.h")
1802        return True
1803
1804    # If a g is provided, split it in two gx and gy
1805    if g != None:
1806        if (len(g)/2)%2 == 0:
1807            gx = get_int(g[:len(g)/2])
1808            gy = get_int(g[len(g)/2:])
1809        else:
1810            # This is probably a generator encapsulated in a bit string
1811            if g[0:2] != "04":
1812                print("Error: provided generator g is not conforming!")
1813                return False
1814            else:
1815                g = g[2:]
1816                gx = get_int(g[:len(g)/2])
1817                gy = get_int(g[len(g)/2:])
1818    if ECfile != None:
1819        # ASN.1 DER input incompatible with other options
1820        if (prime != None) or (a != None) or (b != None) or (gx != None) or (gy != None) or (order != None) or (cofactor != None):
1821            print("Error: option ECfile incompatible with explicit (prime, a, b, gx, gy, order, cofactor) options!")
1822            return False
1823        # We need at least a name
1824        if (name == None):
1825            print("Error: option ECfile needs a curve name!")
1826            return False
1827        # Open the file
1828        try:
1829            buf = open(ECfile, 'rb').read()
1830        except:
1831            print("Error: cannot open ECfile file "+ECfile)
1832            return False
1833        # Check if we have a PEM or a DER file
1834        (check, derbuf) = buffer_remove_pattern(buf, "-----.*-----")
1835        if (check == True):
1836            # This a PEM file, proceed with base64 decoding
1837            if(is_base64(derbuf) == False):
1838                print("Error: error when decoding ECfile file "+ECfile+" (seems to be PEM, but failed to decode)")
1839                return False
1840            derbuf = base64.b64decode(derbuf)
1841        (check, (a, b, prime, order, cofactor, gx, gy)) = parse_DER_ECParameters(derbuf)
1842        if (check == False):
1843            print("Error: error when parsing ECfile file "+ECfile+" (malformed or unsupported ASN.1)")
1844            return False
1845
1846    else:
1847        if (prime == None) or (a == None) or (b == None) or (gx == None) or (gy == None) or (order == None) or (cofactor == None) or (name == None):
1848            err_string = (prime == None)*"prime "+(a == None)*"a "+(b == None)*"b "+(gx == None)*"gx "+(gy == None)*"gy "+(order == None)*"order "+(cofactor == None)*"cofactor "+(name == None)*"name "
1849            print("Error: missing "+err_string+" in explicit curve definition (name, prime, a, b, gx, gy, order, cofactor)!")
1850            print("See the help with -h or --help")
1851            return False
1852
1853    # Some sanity checks here
1854    # Check that prime is indeed a prime
1855    if is_probprime(prime) == False:
1856        print("Error: given prime is *NOT* prime!")
1857        return False
1858    if is_probprime(order) == False:
1859        print("Error: given order is *NOT* prime!")
1860        return False
1861    if (a > prime) or (b > prime) or (gx > prime) or (gy > prime):
1862        err_string = (a > prime)*"a "+(b > prime)*"b "+(gx > prime)*"gx "+(gy > prime)*"gy "
1863        print("Error: "+err_string+"is > prime")
1864        return False
1865    # Check that the provided generator is on the curve
1866    if pow(gy, 2, prime) != ((pow(gx, 3, prime) + (a*gx) + b) % prime):
1867        print("Error: the given parameters (prime, a, b, gx, gy) do not verify the elliptic curve equation!")
1868        return False
1869
1870    # Check Montgomery and Edwards transfer coefficients
1871    if ((alpha_montgomery != None) and (gamma_montgomery == None)) or ((alpha_montgomery == None) and (gamma_montgomery != None)):
1872        print("Error: alpha_montgomery and gamma_montgomery must be both defined if used!")
1873        return False
1874    if (alpha_edwards != None):
1875        if (alpha_montgomery == None) or (gamma_montgomery == None):
1876            print("Error: alpha_edwards needs alpha_montgomery and gamma_montgomery to be both defined if used!")
1877            return False
1878
1879    # Now that we have our parameters, call the function to get bitlen
1880    pbitlen = getbitlen(prime)
1881    ec_params = curve_params(name, prime, pbitlen, a, b, gx, gy, order, cofactor, oid, alpha_montgomery, gamma_montgomery, alpha_edwards)
1882    # Check if there is a name collision somewhere
1883    if os.path.exists(ec_params_path + "ec_params_"+name+".h") == True :
1884        print("Error: file %s already exists!" % (ec_params_path + "ec_params_"+name+".h"))
1885        return False
1886    if (check_in_file(curves_list_path + "curves_list.h", "ec_params_"+name+"_str_params") == True) or (check_in_file(curves_list_path + "curves_list.h", "WITH_CURVE_"+name.upper()+"\n") == True) or (check_in_file(lib_ecc_types_path + "lib_ecc_types.h", "WITH_CURVE_"+name.upper()+"\n") == True):
1887        print("Error: name %s already exists in files" % ("ec_params_"+name))
1888        return False
1889    # Create a new file with the parameters
1890    if not os.path.exists(ec_params_path):
1891        # Create the "user_defined" folder if it does not exist
1892        os.mkdir(ec_params_path)
1893    f = open(ec_params_path + "ec_params_"+name+".h", 'w')
1894    f.write(ec_params)
1895    f.close()
1896    # Include the file in curves_list.h
1897    magic = "ADD curves header here"
1898    magic_re = "\/\* "+magic+" \*\/"
1899    magic_back = "/* "+magic+" */"
1900    file_replace_pattern(curves_list_path + "curves_list.h", magic_re, "#include <libecc/curves/user_defined/ec_params_"+name+".h>\n"+magic_back)
1901    # Add the curve mapping
1902    magic = "ADD curves mapping here"
1903    magic_re = "\/\* "+magic+" \*\/"
1904    magic_back = "/* "+magic+" */"
1905    file_replace_pattern(curves_list_path + "curves_list.h", magic_re, "#ifdef WITH_CURVE_"+name.upper()+"\n\t{ .type = "+name.upper()+", .params = &"+name+"_str_params },\n#endif /* WITH_CURVE_"+name.upper()+" */\n"+magic_back)
1906    # Add the new curve type in the enum
1907    # First we get the number of already defined curves so that we increment the enum counter
1908    num_with_curve = num_patterns_in_file(lib_ecc_types_path + "lib_ecc_types.h", "#ifdef WITH_CURVE_")
1909    magic = "ADD curves type here"
1910    magic_re = "\/\* "+magic+" \*\/"
1911    magic_back = "/* "+magic+" */"
1912    file_replace_pattern(lib_ecc_types_path + "lib_ecc_types.h", magic_re, "#ifdef WITH_CURVE_"+name.upper()+"\n\t"+name.upper()+" = "+str(num_with_curve+1)+",\n#endif /* WITH_CURVE_"+name.upper()+" */\n"+magic_back)
1913    # Add the new curve define in the config
1914    magic = "ADD curves define here"
1915    magic_re = "\/\* "+magic+" \*\/"
1916    magic_back = "/* "+magic+" */"
1917    file_replace_pattern(lib_ecc_config_path + "lib_ecc_config.h", magic_re, "#define WITH_CURVE_"+name.upper()+"\n"+magic_back)
1918    # Add the new curve meson option in the meson.options file
1919    magic = "ADD curves meson option here"
1920    magic_re = "# " + magic
1921    magic_back = "# " + magic
1922    file_replace_pattern(meson_options_path + "meson.options", magic_re, "\t'"+name.lower()+"',\n"+magic_back)
1923
1924    # Do we need to add some test vectors?
1925    if add_test_vectors != None:
1926        print("Test vectors generation asked: this can take some time! Please wait ...")
1927        # Create curve
1928        c = Curve(a, b, prime, order, cofactor, gx, gy, cofactor * order, name, oid)
1929        # Generate key pair for the algorithm
1930        vectors = gen_self_tests(c, add_test_vectors)
1931        # Iterate through all the tests
1932        f = open(ec_self_tests_path + "ec_self_tests_core_"+name+".h", 'w')
1933        for l in vectors:
1934            for v in l:
1935                for case in v:
1936                    (case_name, case_vector) = case
1937                    # Add the new test case
1938                    magic = "ADD curve test case here"
1939                    magic_re = "\/\* "+magic+" \*\/"
1940                    magic_back = "/* "+magic+" */"
1941                    file_replace_pattern(ec_self_tests_path + "ec_self_tests_core.h", magic_re, case_name+"\n"+magic_back)
1942                    # Create/Increment the header file
1943                    f.write(case_vector)
1944        f.close()
1945        # Add the new test cases header
1946        magic = "ADD curve test vectors header here"
1947        magic_re = "\/\* "+magic+" \*\/"
1948        magic_back = "/* "+magic+" */"
1949        file_replace_pattern(ec_self_tests_path + "ec_self_tests_core.h", magic_re, "#include \"ec_self_tests_core_"+name+".h\"\n"+magic_back)
1950    return True
1951
1952
1953#### Main
1954if __name__ == "__main__":
1955    signal.signal(signal.SIGINT, handler)
1956    parse_cmd_line(sys.argv[1:])
1957