diff options
Diffstat (limited to 'scripts/expand_libecc.py')
| -rw-r--r-- | scripts/expand_libecc.py | 1956 |
1 files changed, 1956 insertions, 0 deletions
diff --git a/scripts/expand_libecc.py b/scripts/expand_libecc.py new file mode 100644 index 000000000000..2c4d5b5f3d9d --- /dev/null +++ b/scripts/expand_libecc.py @@ -0,0 +1,1956 @@ +#/* +# * Copyright (C) 2017 - This file is part of libecc project +# * +# * Authors: +# * Ryad BENADJILA <ryadbenadjila@gmail.com> +# * Arnaud EBALARD <arnaud.ebalard@ssi.gouv.fr> +# * Jean-Pierre FLORI <jean-pierre.flori@ssi.gouv.fr> +# * +# * Contributors: +# * Nicolas VIVET <nicolas.vivet@ssi.gouv.fr> +# * Karim KHALFALLAH <karim.khalfallah@ssi.gouv.fr> +# * +# * This software is licensed under a dual BSD and GPL v2 license. +# * See LICENSE file at the root folder of the project. +# */ +#! /usr/bin/env python + +import random, sys, re, math, os, getopt, glob, copy, hashlib, binascii, string, signal, base64 + +# External dependecy for SHA-3 +# It is an independent module, since hashlib has no support +# for SHA-3 functions for now +import sha3 + +# Handle Python 2/3 issues +def is_python_2(): + if sys.version_info[0] < 3: + return True + else: + return False + +### Ctrl-C handler +def handler(signal, frame): + print("\nSIGINT caught: exiting ...") + exit(0) + +# Helper to ask the user for something +def get_user_input(prompt): + # Handle the Python 2/3 issue + if is_python_2() == False: + return input(prompt) + else: + return raw_input(prompt) + +########################################################## +#### Math helpers +def egcd(b, n): + x0, x1, y0, y1 = 1, 0, 0, 1 + while n != 0: + q, b, n = b // n, n, b % n + x0, x1 = x1, x0 - q * x1 + y0, y1 = y1, y0 - q * y1 + return b, x0, y0 + +def modinv(a, m): + g, x, y = egcd(a, m) + if g != 1: + raise Exception("Error: modular inverse does not exist") + else: + return x % m + +def compute_monty_coef(prime, pbitlen, wlen): + """ + Compute montgomery coeff r, r^2 and mpinv. pbitlen is the size + of p in bits. It is expected to be a multiple of word + bit size. + """ + r = (1 << int(pbitlen)) % prime + r_square = (1 << (2 * int(pbitlen))) % prime + mpinv = 2**wlen - (modinv(prime, 2**wlen)) + return r, r_square, mpinv + +def compute_div_coef(prime, pbitlen, wlen): + """ + Compute division coeffs p_normalized, p_shift and p_reciprocal. + """ + tmp = prime + cnt = 0 + while tmp != 0: + tmp = tmp >> 1 + cnt += 1 + pshift = int(pbitlen - cnt) + primenorm = prime << pshift + B = 2**wlen + prec = B**3 // ((primenorm >> int(pbitlen - 2*wlen)) + 1) - B + return pshift, primenorm, prec + +def is_probprime(n): + # ensure n is odd + if n % 2 == 0: + return False + # write n-1 as 2**s * d + # repeatedly try to divide n-1 by 2 + s = 0 + d = n-1 + while True: + quotient, remainder = divmod(d, 2) + if remainder == 1: + break + s += 1 + d = quotient + assert(2**s * d == n-1) + # test the base a to see whether it is a witness for the compositeness of n + def try_composite(a): + if pow(a, d, n) == 1: + return False + for i in range(s): + if pow(a, 2**i * d, n) == n-1: + return False + return True # n is definitely composite + for i in range(5): + a = random.randrange(2, n) + if try_composite(a): + return False + return True # no base tested showed n as composite + +def legendre_symbol(a, p): + ls = pow(a, (p - 1) // 2, p) + return -1 if ls == p - 1 else ls + +# Tonelli-Shanks algorithm to find square roots +# over prime fields +def mod_sqrt(a, p): + # Square root of 0 is 0 + if a == 0: + return 0 + # Simple cases + if legendre_symbol(a, p) != 1: + # No square residue + return None + elif p == 2: + return a + elif p % 4 == 3: + return pow(a, (p + 1) // 4, p) + s = p - 1 + e = 0 + while s % 2 == 0: + s = s // 2 + e += 1 + n = 2 + while legendre_symbol(n, p) != -1: + n += 1 + x = pow(a, (s + 1) // 2, p) + b = pow(a, s, p) + g = pow(n, s, p) + r = e + while True: + t = b + m = 0 + if is_python_2(): + for m in xrange(r): + if t == 1: + break + t = pow(t, 2, p) + else: + for m in range(r): + if t == 1: + break + t = pow(t, 2, p) + if m == 0: + return x + gs = pow(g, 2 ** (r - m - 1), p) + g = (gs * gs) % p + x = (x * gs) % p + b = (b * g) % p + r = m + +########################################################## +### Math elliptic curves basic blocks + +# WARNING: these blocks are only here for testing purpose and +# are not intended to be used in a security oriented library! +# This explains the usage of naive affine coordinates fomulas +class Curve(object): + def __init__(self, a, b, prime, order, cofactor, gx, gy, npoints, name, oid): + self.a = a + self.b = b + self.p = prime + self.q = order + self.c = cofactor + self.gx = gx + self.gy = gy + self.n = npoints + self.name = name + self.oid = oid + # Equality testing + def __eq__(self, other): + return self.__dict__ == other.__dict__ + # Deep copy is implemented using the ~X operator + def __invert__(self): + return copy.deepcopy(self) + + +class Point(object): + # Affine coordinates (x, y), infinity point is (None, None) + def __init__(self, curve, x, y): + self.curve = curve + if x != None: + self.x = (x % curve.p) + else: + self.x = None + if y != None: + self.y = (y % curve.p) + else: + self.y = None + # Check that the point is indeed on the curve + if (x != None): + if (pow(y, 2, curve.p) != ((pow(x, 3, curve.p) + (curve.a * x) + curve.b ) % curve.p)): + raise Exception("Error: point is not on curve!") + # Addition + def __add__(self, Q): + x1 = self.x + y1 = self.y + x2 = Q.x + y2 = Q.y + curve = self.curve + # Check that we are on the same curve + if Q.curve != curve: + raise Exception("Point add error: two point don't have the same curve") + # If Q is infinity point, return ourself + if Q.x == None: + return Point(self.curve, self.x, self.y) + # If we are the infinity point return Q + if self.x == None: + return Q + # Infinity point or Doubling + if (x1 == x2): + if (((y1 + y2) % curve.p) == 0): + # Return infinity point + return Point(self.curve, None, None) + else: + # Doubling + L = ((3*pow(x1, 2, curve.p) + curve.a) * modinv(2*y1, curve.p)) % curve.p + # Addition + else: + L = ((y2 - y1) * modinv((x2 - x1) % curve.p, curve.p)) % curve.p + resx = (pow(L, 2, curve.p) - x1 - x2) % curve.p + resy = ((L * (x1 - resx)) - y1) % curve.p + # Return the point + return Point(self.curve, resx, resy) + # Negation + def __neg__(self): + if (self.x == None): + return Point(self.curve, None, None) + else: + return Point(self.curve, self.x, -self.y) + # Subtraction + def __sub__(self, other): + return self + (-other) + # Scalar mul + def __rmul__(self, scalar): + # Implement simple double and add algorithm + P = self + Q = Point(P.curve, None, None) + for i in range(getbitlen(scalar), 0, -1): + Q = Q + Q + if (scalar >> (i-1)) & 0x1 == 0x1: + Q = Q + P + return Q + # Equality testing + def __eq__(self, other): + return self.__dict__ == other.__dict__ + # Deep copy is implemented using the ~X operator + def __invert__(self): + return copy.deepcopy(self) + def __str__(self): + if self.x == None: + return "Inf" + else: + return ("(x = %s, y = %s)" % (hex(self.x), hex(self.y))) + +########################################################## +### Private and public keys structures +class PrivKey(object): + def __init__(self, curve, x): + self.curve = curve + self.x = x + +class PubKey(object): + def __init__(self, curve, Y): + # Sanity check + if Y.curve != curve: + raise Exception("Error: curve and point curve differ in public key!") + self.curve = curve + self.Y = Y + +class KeyPair(object): + def __init__(self, pubkey, privkey): + self.pubkey = pubkey + self.privkey = privkey + + +def fromprivkey(privkey, is_eckcdsa=False): + curve = privkey.curve + q = curve.q + gx = curve.gx + gy = curve.gy + G = Point(curve, gx, gy) + if is_eckcdsa == False: + return PubKey(curve, privkey.x * G) + else: + return PubKey(curve, modinv(privkey.x, q) * G) + +def genKeyPair(curve, is_eckcdsa=False): + p = curve.p + q = curve.q + gx = curve.gx + gy = curve.gy + G = Point(curve, gx, gy) + OK = False + while OK == False: + x = getrandomint(q) + if x == 0: + continue + OK = True + privkey = PrivKey(curve, x) + pubkey = fromprivkey(privkey, is_eckcdsa) + return KeyPair(pubkey, privkey) + +########################################################## +### Signature algorithms helpers +def getrandomint(modulo): + return random.randrange(0, modulo+1) + +def getbitlen(bint): + """ + Returns the number of bits encoding an integer + """ + if bint == None: + return 0 + if bint == 0: + # Zero is encoded on one bit + return 1 + else: + return int(bint).bit_length() + +def getbytelen(bint): + """ + Returns the number of bytes encoding an integer + """ + bitsize = getbitlen(bint) + bytesize = int(bitsize // 8) + if bitsize % 8 != 0: + bytesize += 1 + return bytesize + +def stringtoint(bitstring): + acc = 0 + size = len(bitstring) + for i in range(0, size): + acc = acc + (ord(bitstring[i]) * (2**(8*(size - 1 - i)))) + return acc + +def inttostring(a): + size = int(getbytelen(a)) + outstr = "" + for i in range(0, size): + outstr = outstr + chr((a >> (8*(size - 1 - i))) & 0xFF) + return outstr + +def expand(bitstring, bitlen, direction): + bytelen = int(math.ceil(bitlen / 8.)) + if len(bitstring) >= bytelen: + return bitstring + else: + if direction == "LEFT": + return ((bytelen-len(bitstring))*"\x00") + bitstring + elif direction == "RIGHT": + return bitstring + ((bytelen-len(bitstring))*"\x00") + else: + raise Exception("Error: unknown direction "+direction+" in expand") + +def truncate(bitstring, bitlen, keep): + """ + Takes a bit string and truncates it to keep the left + most or the right most bits + """ + strbitlen = 8*len(bitstring) + # Check if truncation is needed + if strbitlen > bitlen: + if keep == "LEFT": + return expand(inttostring(stringtoint(bitstring) >> int(strbitlen - bitlen)), bitlen, "LEFT") + elif keep == "RIGHT": + mask = (2**bitlen)-1 + return expand(inttostring(stringtoint(bitstring) & mask), bitlen, "LEFT") + else: + raise Exception("Error: unknown direction "+keep+" in truncate") + else: + # No need to truncate! + return bitstring + +########################################################## +### Hash algorithms +def sha224(message): + ctx = hashlib.sha224() + if(is_python_2() == True): + ctx.update(message) + digest = ctx.digest() + else: + ctx.update(message.encode('latin-1')) + digest = ctx.digest().decode('latin-1') + return (digest, ctx.digest_size, ctx.block_size) + +def sha256(message): + ctx = hashlib.sha256() + if(is_python_2() == True): + ctx.update(message) + digest = ctx.digest() + else: + ctx.update(message.encode('latin-1')) + digest = ctx.digest().decode('latin-1') + return (digest, ctx.digest_size, ctx.block_size) + +def sha384(message): + ctx = hashlib.sha384() + if(is_python_2() == True): + ctx.update(message) + digest = ctx.digest() + else: + ctx.update(message.encode('latin-1')) + digest = ctx.digest().decode('latin-1') + return (digest, ctx.digest_size, ctx.block_size) + +def sha512(message): + ctx = hashlib.sha512() + if(is_python_2() == True): + ctx.update(message) + digest = ctx.digest() + else: + ctx.update(message.encode('latin-1')) + digest = ctx.digest().decode('latin-1') + return (digest, ctx.digest_size, ctx.block_size) + +def sha3_224(message): + ctx = sha3.Sha3_ctx(224) + if(is_python_2() == True): + ctx.update(message) + digest = ctx.digest() + else: + ctx.update(message.encode('latin-1')) + digest = ctx.digest().decode('latin-1') + return (digest, ctx.digest_size, ctx.block_size) + +def sha3_256(message): + ctx = sha3.Sha3_ctx(256) + if(is_python_2() == True): + ctx.update(message) + digest = ctx.digest() + else: + ctx.update(message.encode('latin-1')) + digest = ctx.digest().decode('latin-1') + return (digest, ctx.digest_size, ctx.block_size) + +def sha3_384(message): + ctx = sha3.Sha3_ctx(384) + if(is_python_2() == True): + ctx.update(message) + digest = ctx.digest() + else: + ctx.update(message.encode('latin-1')) + digest = ctx.digest().decode('latin-1') + return (digest, ctx.digest_size, ctx.block_size) + +def sha3_512(message): + ctx = sha3.Sha3_ctx(512) + if(is_python_2() == True): + ctx.update(message) + digest = ctx.digest() + else: + ctx.update(message.encode('latin-1')) + digest = ctx.digest().decode('latin-1') + return (digest, ctx.digest_size, ctx.block_size) + +########################################################## +### Signature algorithms + +# *| IUF - ECDSA signature +# *| +# *| UF 1. Compute h = H(m) +# *| F 2. If |h| > bitlen(q), set h to bitlen(q) +# *| leftmost (most significant) bits of h +# *| F 3. e = OS2I(h) mod q +# *| F 4. Get a random value k in ]0,q[ +# *| F 5. Compute W = (W_x,W_y) = kG +# *| F 6. Compute r = W_x mod q +# *| F 7. If r is 0, restart the process at step 4. +# *| F 8. If e == rx, restart the process at step 4. +# *| F 9. Compute s = k^-1 * (xr + e) mod q +# *| F 10. If s is 0, restart the process at step 4. +# *| F 11. Return (r,s) +def ecdsa_sign(hashfunc, keypair, message, k=None): + privkey = keypair.privkey + # Get important parameters from the curve + p = privkey.curve.p + q = privkey.curve.q + gx = privkey.curve.gx + gy = privkey.curve.gy + G = Point(privkey.curve, gx, gy) + q_limit_len = getbitlen(q) + # Compute the hash + (h, _, _) = hashfunc(message) + # Truncate hash value + h = truncate(h, q_limit_len, "LEFT") + # Convert the hash value to an int + e = stringtoint(h) % q + OK = False + while OK == False: + if k == None: + k = getrandomint(q) + if k == 0: + continue + W = k * G + r = W.x % q + if r == 0: + continue + if e == r * privkey.x: + continue + s = (modinv(k, q) * ((privkey.x * r) + e)) % q + if s == 0: + continue + OK = True + return ((expand(inttostring(r), 8*getbytelen(q), "LEFT") + expand(inttostring(s), 8*getbytelen(q), "LEFT")), k) + +# *| IUF - ECDSA verification +# *| +# *| I 1. Reject the signature if r or s is 0. +# *| UF 2. Compute h = H(m) +# *| F 3. If |h| > bitlen(q), set h to bitlen(q) +# *| leftmost (most significant) bits of h +# *| F 4. Compute e = OS2I(h) mod q +# *| F 5. Compute u = (s^-1)e mod q +# *| F 6. Compute v = (s^-1)r mod q +# *| F 7. Compute W' = uG + vY +# *| F 8. If W' is the point at infinity, reject the signature. +# *| F 9. Compute r' = W'_x mod q +# *| F 10. Accept the signature if and only if r equals r' +def ecdsa_verify(hashfunc, keypair, message, sig): + pubkey = keypair.pubkey + # Get important parameters from the curve + p = pubkey.curve.p + q = pubkey.curve.q + gx = pubkey.curve.gx + gy = pubkey.curve.gy + q_limit_len = getbitlen(q) + G = Point(pubkey.curve, gx, gy) + # Extract r and s + if len(sig) != 2*getbytelen(q): + raise Exception("ECDSA verify: bad signature length!") + r = stringtoint(sig[0:int(len(sig)/2)]) + s = stringtoint(sig[int(len(sig)/2):]) + if r == 0 or s == 0: + return False + # Compute the hash + (h, _, _) = hashfunc(message) + # Truncate hash value + h = truncate(h, q_limit_len, "LEFT") + # Convert the hash value to an int + e = stringtoint(h) % q + u = (modinv(s, q) * e) % q + v = (modinv(s, q) * r) % q + W_ = (u * G) + (v * pubkey.Y) + if W_.x == None: + return False + r_ = W_.x % q + if r == r_: + return True + else: + return False + +def eckcdsa_genKeyPair(curve): + return genKeyPair(curve, True) + +# *| IUF - ECKCDSA signature +# *| +# *| IUF 1. Compute h = H(z||m) +# *| F 2. If hsize > bitlen(q), set h to bitlen(q) +# *| rightmost (less significant) bits of h. +# *| F 3. Get a random value k in ]0,q[ +# *| F 4. Compute W = (W_x,W_y) = kG +# *| F 5. Compute r = h(FE2OS(W_x)). +# *| F 6. If hsize > bitlen(q), set r to bitlen(q) +# *| rightmost (less significant) bits of r. +# *| F 7. Compute e = OS2I(r XOR h) mod q +# *| F 8. Compute s = x(k - e) mod q +# *| F 9. if s == 0, restart at step 3. +# *| F 10. return (r,s) +def eckcdsa_sign(hashfunc, keypair, message, k=None): + privkey = keypair.privkey + # Get important parameters from the curve + p = privkey.curve.p + q = privkey.curve.q + gx = privkey.curve.gx + gy = privkey.curve.gy + G = Point(privkey.curve, gx, gy) + q_limit_len = getbitlen(q) + # Compute the certificate data + (_, _, hblocksize) = hashfunc("") + z = expand(inttostring(keypair.pubkey.Y.x), 8*getbytelen(p), "LEFT") + z = z + expand(inttostring(keypair.pubkey.Y.y), 8*getbytelen(p), "LEFT") + if len(z) > hblocksize: + # Truncate + z = truncate(z, 8*hblocksize, "LEFT") + else: + # Expand + z = expand(z, 8*hblocksize, "RIGHT") + # Compute the hash + (h, _, _) = hashfunc(z + message) + # Truncate hash value + h = truncate(h, 8 * int(math.ceil(q_limit_len / 8)), "RIGHT") + OK = False + while OK == False: + if k == None: + k = getrandomint(q) + if k == 0: + continue + W = k * G + (r, _, _) = hashfunc(expand(inttostring(W.x), 8*getbytelen(p), "LEFT")) + r = truncate(r, 8 * int(math.ceil(q_limit_len / 8)), "RIGHT") + e = (stringtoint(r) ^ stringtoint(h)) % q + s = (privkey.x * (k - e)) % q + if s == 0: + continue + OK = True + return (r + expand(inttostring(s), 8*getbytelen(q), "LEFT"), k) + +# *| IUF - ECKCDSA verification +# *| +# *| I 1. Check the length of r: +# *| - if hsize > bitlen(q), r must be of +# *| length bitlen(q) +# *| - if hsize <= bitlen(q), r must be of +# *| length hsize +# *| I 2. Check that s is in ]0,q[ +# *| IUF 3. Compute h = H(z||m) +# *| F 4. If hsize > bitlen(q), set h to bitlen(q) +# *| rightmost (less significant) bits of h. +# *| F 5. Compute e = OS2I(r XOR h) mod q +# *| F 6. Compute W' = sY + eG, where Y is the public key +# *| F 7. Compute r' = h(FE2OS(W'x)) +# *| F 8. If hsize > bitlen(q), set r' to bitlen(q) +# *| rightmost (less significant) bits of r'. +# *| F 9. Check if r == r' +def eckcdsa_verify(hashfunc, keypair, message, sig): + pubkey = keypair.pubkey + # Get important parameters from the curve + p = pubkey.curve.p + q = pubkey.curve.q + gx = pubkey.curve.gx + gy = pubkey.curve.gy + G = Point(pubkey.curve, gx, gy) + q_limit_len = getbitlen(q) + (_, hsize, hblocksize) = hashfunc("") + # Extract r and s + if (8*hsize) > q_limit_len: + r_len = int(math.ceil(q_limit_len / 8.)) + else: + r_len = hsize + r = stringtoint(sig[0:int(r_len)]) + s = stringtoint(sig[int(r_len):]) + if (s >= q) or (s < 0): + return False + # Compute the certificate data + z = expand(inttostring(keypair.pubkey.Y.x), 8*getbytelen(p), "LEFT") + z = z + expand(inttostring(keypair.pubkey.Y.y), 8*getbytelen(p), "LEFT") + if len(z) > hblocksize: + # Truncate + z = truncate(z, 8*hblocksize, "LEFT") + else: + # Expand + z = expand(z, 8*hblocksize, "RIGHT") + # Compute the hash + (h, _, _) = hashfunc(z + message) + # Truncate hash value + h = truncate(h, 8 * int(math.ceil(q_limit_len / 8)), "RIGHT") + e = (r ^ stringtoint(h)) % q + W_ = (s * pubkey.Y) + (e * G) + (h, _, _) = hashfunc(expand(inttostring(W_.x), 8*getbytelen(p), "LEFT")) + r_ = truncate(h, 8 * int(math.ceil(q_limit_len / 8)), "RIGHT") + if stringtoint(r_) == r: + return True + else: + return False + +# *| IUF - ECFSDSA signature +# *| +# *| I 1. Get a random value k in ]0,q[ +# *| I 2. Compute W = (W_x,W_y) = kG +# *| I 3. Compute r = FE2OS(W_x)||FE2OS(W_y) +# *| I 4. If r is an all zero string, restart the process at step 1. +# *| IUF 5. Compute h = H(r||m) +# *| F 6. Compute e = OS2I(h) mod q +# *| F 7. Compute s = (k + ex) mod q +# *| F 8. If s is 0, restart the process at step 1 (see c. below) +# *| F 9. Return (r,s) +def ecfsdsa_sign(hashfunc, keypair, message, k=None): + privkey = keypair.privkey + # Get important parameters from the curve + p = privkey.curve.p + q = privkey.curve.q + gx = privkey.curve.gx + gy = privkey.curve.gy + G = Point(privkey.curve, gx, gy) + OK = False + while OK == False: + if k == None: + k = getrandomint(q) + if k == 0: + continue + W = k * G + r = expand(inttostring(W.x), 8*getbytelen(p), "LEFT") + expand(inttostring(W.y), 8*getbytelen(p), "LEFT") + if stringtoint(r) == 0: + continue + (h, _, _) = hashfunc(r + message) + e = stringtoint(h) % q + s = (k + e * privkey.x) % q + if s == 0: + continue + OK = True + return (r + expand(inttostring(s), 8*getbytelen(q), "LEFT"), k) + + +# *| IUF - ECFSDSA verification +# *| +# *| I 1. Reject the signature if r is not a valid point on the curve. +# *| I 2. Reject the signature if s is not in ]0,q[ +# *| IUF 3. Compute h = H(r||m) +# *| F 4. Convert h to an integer and then compute e = -h mod q +# *| F 5. compute W' = sG + eY, where Y is the public key +# *| F 6. Compute r' = FE2OS(W'_x)||FE2OS(W'_y) +# *| F 7. Accept the signature if and only if r equals r' +def ecfsdsa_verify(hashfunc, keypair, message, sig): + pubkey = keypair.pubkey + # Get important parameters from the curve + p = pubkey.curve.p + q = pubkey.curve.q + gx = pubkey.curve.gx + gy = pubkey.curve.gy + G = Point(pubkey.curve, gx, gy) + # Extract coordinates from r and s from signature + if len(sig) != (2*getbytelen(p)) + getbytelen(q): + raise Exception("ECFSDSA verify: bad signature length!") + wx = sig[:int(getbytelen(p))] + wy = sig[int(getbytelen(p)):int(2*getbytelen(p))] + r = wx + wy + s = stringtoint(sig[int(2*getbytelen(p)):int((2*getbytelen(p))+getbytelen(q))]) + # Check r is on the curve + W = Point(pubkey.curve, stringtoint(wx), stringtoint(wy)) + # Check s is in ]0,q[ + if s == 0 or s > q: + raise Exception("ECFSDSA verify: s not in ]0,q[") + (h, _, _) = hashfunc(r + message) + e = (-stringtoint(h)) % q + W_ = s * G + e * pubkey.Y + r_ = expand(inttostring(W_.x), 8*getbytelen(p), "LEFT") + expand(inttostring(W_.y), 8*getbytelen(p), "LEFT") + if r == r_: + return True + else: + return False + + +# NOTE: ISO/IEC 14888-3 standard seems to diverge from the existing implementations +# of ECRDSA when treating the message hash, and from the examples of certificates provided +# in RFC 7091 and draft-deremin-rfc4491-bis. While in ISO/IEC 14888-3 it is explicitely asked +# to proceed with the hash of the message as big endian, the RFCs derived from the Russian +# standard expect the hash value to be treated as little endian when importing it as an integer +# (this discrepancy is exhibited and confirmed by test vectors present in ISO/IEC 14888-3, and +# by X.509 certificates present in the RFCs). This seems (to be confirmed) to be a discrepancy of +# ISO/IEC 14888-3 algorithm description that must be fixed there. +# +# In order to be conservative, libecc uses the Russian standard behavior as expected to be in line with +# other implemetations, but keeps the ISO/IEC 14888-3 behavior if forced/asked by the user using +# the USE_ISO14888_3_ECRDSA toggle. This allows to keep backward compatibility with previous versions of the +# library if needed. + +# *| IUF - ECRDSA signature +# *| +# *| UF 1. Compute h = H(m) +# *| F 2. Get a random value k in ]0,q[ +# *| F 3. Compute W = (W_x,W_y) = kG +# *| F 4. Compute r = W_x mod q +# *| F 5. If r is 0, restart the process at step 2. +# *| F 6. Compute e = OS2I(h) mod q. If e is 0, set e to 1. +# *| NOTE: here, ISO/IEC 14888-3 and RFCs differ in the way e treated. +# *| e = OS2I(h) for ISO/IEC 14888-3, or e = OS2I(reversed(h)) when endianness of h +# *| is reversed for RFCs. +# *| F 7. Compute s = (rx + ke) mod q +# *| F 8. If s is 0, restart the process at step 2. +# *| F 11. Return (r,s) +def ecrdsa_sign(hashfunc, keypair, message, k=None, use_iso14888_divergence=False): + privkey = keypair.privkey + # Get important parameters from the curve + p = privkey.curve.p + q = privkey.curve.q + gx = privkey.curve.gx + gy = privkey.curve.gy + G = Point(privkey.curve, gx, gy) + (h, _, _) = hashfunc(message) + if use_iso14888_divergence == False: + # Reverse the endianness for Russian standard RFC ECRDSA (contrary to ISO/IEC 14888-3 case) + h = h[::-1] + OK = False + while OK == False: + if k == None: + k = getrandomint(q) + if k == 0: + continue + W = k * G + r = W.x % q + if r == 0: + continue + e = stringtoint(h) % q + if e == 0: + e = 1 + s = ((r * privkey.x) + (k * e)) % q + if s == 0: + continue + OK = True + return (expand(inttostring(r), 8*getbytelen(q), "LEFT") + expand(inttostring(s), 8*getbytelen(q), "LEFT"), k) + +# *| IUF - ECRDSA verification +# *| +# *| UF 1. Check that r and s are both in ]0,q[ +# *| F 2. Compute h = H(m) +# *| F 3. Compute e = OS2I(h)^-1 mod q +# *| NOTE: here, ISO/IEC 14888-3 and RFCs differ in the way e treated. +# *| e = OS2I(h) for ISO/IEC 14888-3, or e = OS2I(reversed(h)) when endianness of h +# *| is reversed for RFCs. +# *| F 4. Compute u = es mod q +# *| F 4. Compute v = -er mod q +# *| F 5. Compute W' = uG + vY = (W'_x, W'_y) +# *| F 6. Let's now compute r' = W'_x mod q +# *| F 7. Check r and r' are the same +def ecrdsa_verify(hashfunc, keypair, message, sig, use_iso14888_divergence=False): + pubkey = keypair.pubkey + # Get important parameters from the curve + p = pubkey.curve.p + q = pubkey.curve.q + gx = pubkey.curve.gx + gy = pubkey.curve.gy + G = Point(pubkey.curve, gx, gy) + # Extract coordinates from r and s from signature + if len(sig) != 2*getbytelen(q): + raise Exception("ECRDSA verify: bad signature length!") + r = stringtoint(sig[:int(getbytelen(q))]) + s = stringtoint(sig[int(getbytelen(q)):int(2*getbytelen(q))]) + if r == 0 or r > q: + raise Exception("ECRDSA verify: r not in ]0,q[") + if s == 0 or s > q: + raise Exception("ECRDSA verify: s not in ]0,q[") + (h, _, _) = hashfunc(message) + if use_iso14888_divergence == False: + # Reverse the endianness for Russian standard RFC ECRDSA (contrary to ISO/IEC 14888-3 case) + h = h[::-1] + e = modinv(stringtoint(h) % q, q) + u = (e * s) % q + v = (-e * r) % q + W_ = u * G + v * pubkey.Y + r_ = W_.x % q + if r == r_: + return True + else: + return False + + +# *| IUF - ECGDSA signature +# *| +# *| UF 1. Compute h = H(m). If |h| > bitlen(q), set h to bitlen(q) +# *| leftmost (most significant) bits of h +# *| F 2. Convert e = - OS2I(h) mod q +# *| F 3. Get a random value k in ]0,q[ +# *| F 4. Compute W = (W_x,W_y) = kG +# *| F 5. Compute r = W_x mod q +# *| F 6. If r is 0, restart the process at step 4. +# *| F 7. Compute s = x(kr + e) mod q +# *| F 8. If s is 0, restart the process at step 4. +# *| F 9. Return (r,s) +def ecgdsa_sign(hashfunc, keypair, message, k=None): + privkey = keypair.privkey + # Get important parameters from the curve + p = privkey.curve.p + q = privkey.curve.q + gx = privkey.curve.gx + gy = privkey.curve.gy + G = Point(privkey.curve, gx, gy) + (h, _, _) = hashfunc(message) + q_limit_len = getbitlen(q) + # Truncate hash value + h = truncate(h, q_limit_len, "LEFT") + e = (-stringtoint(h)) % q + OK = False + while OK == False: + if k == None: + k = getrandomint(q) + if k == 0: + continue + W = k * G + r = W.x % q + if r == 0: + continue + s = (privkey.x * ((k * r) + e)) % q + if s == 0: + continue + OK = True + return (expand(inttostring(r), 8*getbytelen(q), "LEFT") + expand(inttostring(s), 8*getbytelen(q), "LEFT"), k) + +# *| IUF - ECGDSA verification +# *| +# *| I 1. Reject the signature if r or s is 0. +# *| UF 2. Compute h = H(m). If |h| > bitlen(q), set h to bitlen(q) +# *| leftmost (most significant) bits of h +# *| F 3. Compute e = OS2I(h) mod q +# *| F 4. Compute u = ((r^-1)e mod q) +# *| F 5. Compute v = ((r^-1)s mod q) +# *| F 6. Compute W' = uG + vY +# *| F 7. Compute r' = W'_x mod q +# *| F 8. Accept the signature if and only if r equals r' +def ecgdsa_verify(hashfunc, keypair, message, sig): + pubkey = keypair.pubkey + # Get important parameters from the curve + p = pubkey.curve.p + q = pubkey.curve.q + gx = pubkey.curve.gx + gy = pubkey.curve.gy + G = Point(pubkey.curve, gx, gy) + # Extract coordinates from r and s from signature + if len(sig) != 2*getbytelen(q): + raise Exception("ECGDSA verify: bad signature length!") + r = stringtoint(sig[:int(getbytelen(q))]) + s = stringtoint(sig[int(getbytelen(q)):int(2*getbytelen(q))]) + if r == 0 or r > q: + raise Exception("ECGDSA verify: r not in ]0,q[") + if s == 0 or s > q: + raise Exception("ECGDSA verify: s not in ]0,q[") + (h, _, _) = hashfunc(message) + q_limit_len = getbitlen(q) + # Truncate hash value + h = truncate(h, q_limit_len, "LEFT") + e = stringtoint(h) % q + r_inv = modinv(r, q) + u = (r_inv * e) % q + v = (r_inv * s) % q + W_ = u * G + v * pubkey.Y + r_ = W_.x % q + if r == r_: + return True + else: + return False + +# *| IUF - ECSDSA/ECOSDSA signature +# *| +# *| I 1. Get a random value k in ]0, q[ +# *| I 2. Compute W = kG = (Wx, Wy) +# *| IUF 3. Compute r = H(Wx [|| Wy] || m) +# *| - In the normal version (ECSDSA), r = h(Wx || Wy || m). +# *| - In the optimized version (ECOSDSA), r = h(Wx || m). +# *| F 4. Compute e = OS2I(r) mod q +# *| F 5. if e == 0, restart at step 1. +# *| F 6. Compute s = (k + ex) mod q. +# *| F 7. if s == 0, restart at step 1. +# *| F 8. Return (r, s) +def ecsdsa_common_sign(hashfunc, keypair, message, optimized, k=None): + privkey = keypair.privkey + # Get important parameters from the curve + p = privkey.curve.p + q = privkey.curve.q + gx = privkey.curve.gx + gy = privkey.curve.gy + G = Point(privkey.curve, gx, gy) + OK = False + while OK == False: + if k == None: + k = getrandomint(q) + if k == 0: + continue + W = k * G + if optimized == False: + (r, _, _) = hashfunc(expand(inttostring(W.x), 8*getbytelen(p), "LEFT") + expand(inttostring(W.y), 8*getbytelen(p), "LEFT") + message) + else: + (r, _, _) = hashfunc(expand(inttostring(W.x), 8*getbytelen(p), "LEFT") + message) + e = stringtoint(r) % q + if e == 0: + continue + s = (k + (e * privkey.x)) % q + if s == 0: + continue + OK = True + return (r + expand(inttostring(s), 8*getbytelen(q), "LEFT"), k) + +def ecsdsa_sign(hashfunc, keypair, message, k=None): + return ecsdsa_common_sign(hashfunc, keypair, message, False, k) + +def ecosdsa_sign(hashfunc, keypair, message, k=None): + return ecsdsa_common_sign(hashfunc, keypair, message, True, k) + +# *| IUF - ECSDSA/ECOSDSA verification +# *| +# *| I 1. if s is not in ]0,q[, reject the signature.x +# *| I 2. Compute e = -r mod q +# *| I 3. If e == 0, reject the signature. +# *| I 4. Compute W' = sG + eY +# *| IUF 5. Compute r' = H(W'x [|| W'y] || m) +# *| - In the normal version (ECSDSA), r = h(W'x || W'y || m). +# *| - In the optimized version (ECOSDSA), r = h(W'x || m). +# *| F 6. Accept the signature if and only if r and r' are the same +def ecsdsa_common_verify(hashfunc, keypair, message, sig, optimized): + pubkey = keypair.pubkey + # Get important parameters from the curve + p = pubkey.curve.p + q = pubkey.curve.q + gx = pubkey.curve.gx + gy = pubkey.curve.gy + G = Point(pubkey.curve, gx, gy) + (_, hlen, _) = hashfunc("") + # Extract coordinates from r and s from signature + if len(sig) != hlen + getbytelen(q): + raise Exception("EC[O]SDSA verify: bad signature length!") + r = stringtoint(sig[:int(hlen)]) + s = stringtoint(sig[int(hlen):int(hlen+getbytelen(q))]) + if s == 0 or s > q: + raise Exception("EC[O]DSA verify: s not in ]0,q[") + e = (-r) % q + if e == 0: + raise Exception("EC[O]DSA verify: e is null") + W_ = s * G + e * pubkey.Y + if optimized == False: + (r_, _, _) = hashfunc(expand(inttostring(W_.x), 8*getbytelen(p), "LEFT") + expand(inttostring(W_.y), 8*getbytelen(p), "LEFT") + message) + else: + (r_, _, _) = hashfunc(expand(inttostring(W_.x), 8*getbytelen(p), "LEFT") + message) + if sig[:int(hlen)] == r_: + return True + else: + return False + +def ecsdsa_verify(hashfunc, keypair, message, sig): + return ecsdsa_common_verify(hashfunc, keypair, message, sig, False) + +def ecosdsa_verify(hashfunc, keypair, message, sig): + return ecsdsa_common_verify(hashfunc, keypair, message, sig, True) + + +########################################################## +### Generate self-tests for all the algorithms + +all_hash_funcs = [ (sha224, "SHA224"), (sha256, "SHA256"), (sha384, "SHA384"), (sha512, "SHA512"), (sha3_224, "SHA3_224"), (sha3_256, "SHA3_256"), (sha3_384, "SHA3_384"), (sha3_512, "SHA3_512") ] + +all_sig_algs = [ (ecdsa_sign, ecdsa_verify, genKeyPair, "ECDSA"), + (eckcdsa_sign, eckcdsa_verify, eckcdsa_genKeyPair, "ECKCDSA"), + (ecfsdsa_sign, ecfsdsa_verify, genKeyPair, "ECFSDSA"), + (ecrdsa_sign, ecrdsa_verify, genKeyPair, "ECRDSA"), + (ecgdsa_sign, ecgdsa_verify, eckcdsa_genKeyPair, "ECGDSA"), + (ecsdsa_sign, ecsdsa_verify, genKeyPair, "ECSDSA"), + (ecosdsa_sign, ecosdsa_verify, genKeyPair, "ECOSDSA"), ] + + +curr_test = 0 +def pretty_print_curr_test(num_test, total_gen_tests): + num_decimal = int(math.log10(total_gen_tests))+1 + format_buf = "%0"+str(num_decimal)+"d/%0"+str(num_decimal)+"d" + sys.stdout.write('\b'*((2*num_decimal)+1)) + sys.stdout.flush() + sys.stdout.write(format_buf % (num_test, total_gen_tests)) + if num_test == total_gen_tests: + print("") + return + +def gen_self_test(curve, hashfunc, sig_alg_sign, sig_alg_verify, sig_alg_genkeypair, num, hashfunc_name, sig_alg_name, total_gen_tests): + global curr_test + curr_test = curr_test + 1 + if num != 0: + pretty_print_curr_test(curr_test, total_gen_tests) + output_list = [] + for test_num in range(0, num): + out_vectors = "" + # Generate a random key pair + keypair = sig_alg_genkeypair(curve) + # Generate a random message with a random size + size = getrandomint(256) + if is_python_2(): + message = ''.join([random.choice(string.ascii_letters + string.digits) for n in xrange(size)]) + else: + message = ''.join([random.choice(string.ascii_letters + string.digits) for n in range(size)]) + test_name = sig_alg_name + "_" + hashfunc_name + "_" + curve.name.upper() + "_" + str(test_num) + # Sign the message + (sig, k) = sig_alg_sign(hashfunc, keypair, message) + # Check that everything is OK with a verify + if sig_alg_verify(hashfunc, keypair, message, sig) != True: + raise Exception("Error during self test generation: sig verify failed! "+test_name+ " / msg="+message+" / sig="+binascii.hexlify(sig)+" / k="+hex(k)+" / privkey.x="+hex(keypair.privkey.x)) + if sig_alg_name == "ECRDSA": + out_vectors += "#ifndef USE_ISO14888_3_ECRDSA\n" + # Now generate the test vector + out_vectors += "#ifdef WITH_HASH_"+hashfunc_name.upper()+"\n" + out_vectors += "#ifdef WITH_CURVE_"+curve.name.upper()+"\n" + out_vectors += "#ifdef WITH_SIG_"+sig_alg_name.upper()+"\n" + out_vectors += "/* "+test_name+" known test vectors */\n" + out_vectors += "static int "+test_name+"_test_vectors_get_random(nn_t out, nn_src_t q)\n{\n" + # k_buf MUST be exported padded to the length of q + out_vectors += "\tconst u8 k_buf[] = "+bigint_to_C_array(k, getbytelen(curve.q)) + out_vectors += "\tint ret, cmp;\n\tret = nn_init_from_buf(out, k_buf, sizeof(k_buf)); EG(ret, err);\n\tret = nn_cmp(out, q, &cmp); EG(ret, err);\n\tret = (cmp >= 0) ? -1 : 0;\nerr:\n\treturn ret;\n}\n" + out_vectors += "static const u8 "+test_name+"_test_vectors_priv_key[] = \n"+bigint_to_C_array(keypair.privkey.x, getbytelen(keypair.privkey.x)) + out_vectors += "static const u8 "+test_name+"_test_vectors_expected_sig[] = \n"+bigint_to_C_array(stringtoint(sig), len(sig)) + out_vectors += "static const ec_test_case "+test_name+"_test_case = {\n" + out_vectors += "\t.name = \""+test_name+"\",\n" + out_vectors += "\t.ec_str_p = &"+curve.name+"_str_params,\n" + out_vectors += "\t.priv_key = "+test_name+"_test_vectors_priv_key,\n" + out_vectors += "\t.priv_key_len = sizeof("+test_name+"_test_vectors_priv_key),\n" + out_vectors += "\t.nn_random = "+test_name+"_test_vectors_get_random,\n" + out_vectors += "\t.hash_type = "+hashfunc_name+",\n" + out_vectors += "\t.msg = \""+message+"\",\n" + out_vectors += "\t.msglen = "+str(len(message))+",\n" + out_vectors += "\t.sig_type = "+sig_alg_name+",\n" + out_vectors += "\t.exp_sig = "+test_name+"_test_vectors_expected_sig,\n" + out_vectors += "\t.exp_siglen = sizeof("+test_name+"_test_vectors_expected_sig),\n};\n" + out_vectors += "#endif /* WITH_HASH_"+hashfunc_name+" */\n" + out_vectors += "#endif /* WITH_CURVE_"+curve.name+" */\n" + out_vectors += "#endif /* WITH_SIG_"+sig_alg_name+" */\n" + if sig_alg_name == "ECRDSA": + out_vectors += "#endif /* !USE_ISO14888_3_ECRDSA */\n" + out_name = "" + if sig_alg_name == "ECRDSA": + out_name += "#ifndef USE_ISO14888_3_ECRDSA"+"/* For "+test_name+" */\n" + out_name += "#ifdef WITH_HASH_"+hashfunc_name.upper()+"/* For "+test_name+" */\n" + out_name += "#ifdef WITH_CURVE_"+curve.name.upper()+"/* For "+test_name+" */\n" + out_name += "#ifdef WITH_SIG_"+sig_alg_name.upper()+"/* For "+test_name+" */\n" + out_name += "\t&"+test_name+"_test_case,\n" + out_name += "#endif /* WITH_HASH_"+hashfunc_name+" for "+test_name+" */\n" + out_name += "#endif /* WITH_CURVE_"+curve.name+" for "+test_name+" */\n" + out_name += "#endif /* WITH_SIG_"+sig_alg_name+" for "+test_name+" */" + if sig_alg_name == "ECRDSA": + out_name += "\n#endif /* !USE_ISO14888_3_ECRDSA */"+"/* For "+test_name+" */" + output_list.append((out_name, out_vectors)) + # In the specific case of ECRDSA, we also generate an ISO/IEC compatible test vector + if sig_alg_name == "ECRDSA": + out_vectors = "" + (sig, k) = sig_alg_sign(hashfunc, keypair, message, use_iso14888_divergence=True) + # Check that everything is OK with a verify + if sig_alg_verify(hashfunc, keypair, message, sig, use_iso14888_divergence=True) != True: + raise Exception("Error during self test generation: sig verify failed! "+test_name+ " / msg="+message+" / sig="+binascii.hexlify(sig)+" / k="+hex(k)+" / privkey.x="+hex(keypair.privkey.x)) + out_vectors += "#ifdef USE_ISO14888_3_ECRDSA\n" + # Now generate the test vector + out_vectors += "#ifdef WITH_HASH_"+hashfunc_name.upper()+"\n" + out_vectors += "#ifdef WITH_CURVE_"+curve.name.upper()+"\n" + out_vectors += "#ifdef WITH_SIG_"+sig_alg_name.upper()+"\n" + out_vectors += "/* "+test_name+" known test vectors */\n" + out_vectors += "static int "+test_name+"_test_vectors_get_random(nn_t out, nn_src_t q)\n{\n" + # k_buf MUST be exported padded to the length of q + out_vectors += "\tconst u8 k_buf[] = "+bigint_to_C_array(k, getbytelen(curve.q)) + out_vectors += "\tint ret, cmp;\n\tret = nn_init_from_buf(out, k_buf, sizeof(k_buf)); EG(ret, err);\n\tret = nn_cmp(out, q, &cmp); EG(ret, err);\n\tret = (cmp >= 0) ? -1 : 0;\nerr:\n\treturn ret;\n}\n" + out_vectors += "static const u8 "+test_name+"_test_vectors_priv_key[] = \n"+bigint_to_C_array(keypair.privkey.x, getbytelen(keypair.privkey.x)) + out_vectors += "static const u8 "+test_name+"_test_vectors_expected_sig[] = \n"+bigint_to_C_array(stringtoint(sig), len(sig)) + out_vectors += "static const ec_test_case "+test_name+"_test_case = {\n" + out_vectors += "\t.name = \""+test_name+"\",\n" + out_vectors += "\t.ec_str_p = &"+curve.name+"_str_params,\n" + out_vectors += "\t.priv_key = "+test_name+"_test_vectors_priv_key,\n" + out_vectors += "\t.priv_key_len = sizeof("+test_name+"_test_vectors_priv_key),\n" + out_vectors += "\t.nn_random = "+test_name+"_test_vectors_get_random,\n" + out_vectors += "\t.hash_type = "+hashfunc_name+",\n" + out_vectors += "\t.msg = \""+message+"\",\n" + out_vectors += "\t.msglen = "+str(len(message))+",\n" + out_vectors += "\t.sig_type = "+sig_alg_name+",\n" + out_vectors += "\t.exp_sig = "+test_name+"_test_vectors_expected_sig,\n" + out_vectors += "\t.exp_siglen = sizeof("+test_name+"_test_vectors_expected_sig),\n};\n" + out_vectors += "#endif /* WITH_HASH_"+hashfunc_name+" */\n" + out_vectors += "#endif /* WITH_CURVE_"+curve.name+" */\n" + out_vectors += "#endif /* WITH_SIG_"+sig_alg_name+" */\n" + out_vectors += "#endif /* USE_ISO14888_3_ECRDSA */\n" + out_name = "" + out_name += "#ifdef USE_ISO14888_3_ECRDSA"+"/* For "+test_name+" */\n" + out_name += "#ifdef WITH_HASH_"+hashfunc_name.upper()+"/* For "+test_name+" */\n" + out_name += "#ifdef WITH_CURVE_"+curve.name.upper()+"/* For "+test_name+" */\n" + out_name += "#ifdef WITH_SIG_"+sig_alg_name.upper()+"/* For "+test_name+" */\n" + out_name += "\t&"+test_name+"_test_case,\n" + out_name += "#endif /* WITH_HASH_"+hashfunc_name+" for "+test_name+" */\n" + out_name += "#endif /* WITH_CURVE_"+curve.name+" for "+test_name+" */\n" + out_name += "#endif /* WITH_SIG_"+sig_alg_name+" for "+test_name+" */\n" + out_name += "#endif /* USE_ISO14888_3_ECRDSA */"+"/* For "+test_name+" */" + output_list.append((out_name, out_vectors)) + + return output_list + +def gen_self_tests(curve, num): + global curr_test + curr_test = 0 + total_gen_tests = len(all_hash_funcs) * len(all_sig_algs) + vectors = [[ gen_self_test(curve, hashf, sign, verify, genkp, num, hash_name, sig_alg_name, total_gen_tests) + for (hashf, hash_name) in all_hash_funcs ] for (sign, verify, genkp, sig_alg_name) in all_sig_algs ] + return vectors + +########################################################## +### ASN.1 stuff +def parse_DER_extract_size(derbuf): + # Extract the size + if ord(derbuf[0]) & 0x80 != 0: + encoding_len_bytes = ord(derbuf[0]) & ~0x80 + # Skip + base = 1 + else: + encoding_len_bytes = 1 + base = 0 + if len(derbuf) < encoding_len_bytes+1: + return (False, 0, 0) + else: + length = stringtoint(derbuf[base:base+encoding_len_bytes]) + if len(derbuf) < length+encoding_len_bytes: + return (False, 0, 0) + else: + return (True, encoding_len_bytes+base, length) + +def extract_DER_object(derbuf, object_tag): + # Check type + if ord(derbuf[0]) != object_tag: + # Not the type we expect ... + return (False, 0, "") + else: + derbuf = derbuf[1:] + # Extract the size + (check, encoding_len, size) = parse_DER_extract_size(derbuf) + if check == False: + return (False, 0, "") + else: + if len(derbuf) < encoding_len + size: + return (False, 0, "") + else: + return (True, size+encoding_len+1, derbuf[encoding_len:encoding_len+size]) + +def extract_DER_sequence(derbuf): + return extract_DER_object(derbuf, 0x30) + +def extract_DER_integer(derbuf): + return extract_DER_object(derbuf, 0x02) + +def extract_DER_octetstring(derbuf): + return extract_DER_object(derbuf, 0x04) + +def extract_DER_bitstring(derbuf): + return extract_DER_object(derbuf, 0x03) + +def extract_DER_oid(derbuf): + return extract_DER_object(derbuf, 0x06) + +# See ECParameters sequence in RFC 3279 +def parse_DER_ECParameters(derbuf): + # XXX: this is a very ugly way of extracting the information + # regarding an EC curve, but since the ASN.1 structure is quite + # "static", this might be sufficient without embedding a full + # ASN.1 parser ... + # Default return (a, b, prime, order, cofactor, gx, gy) + default_ret = (0, 0, 0, 0, 0, 0, 0) + # Get ECParameters wrapping sequence + (check, size_ECParameters, ECParameters) = extract_DER_sequence(derbuf) + if check == False: + return (False, default_ret) + # Get integer + (check, size_ECPVer, ECPVer) = extract_DER_integer(ECParameters) + if check == False: + return (False, default_ret) + # Get sequence + (check, size_FieldID, FieldID) = extract_DER_sequence(ECParameters[size_ECPVer:]) + if check == False: + return (False, default_ret) + # Get OID + (check, size_Oid, Oid) = extract_DER_oid(FieldID) + if check == False: + return (False, default_ret) + # Does the OID correspond to a prime field? + if(Oid != "\x2A\x86\x48\xCE\x3D\x01\x01"): + print("DER parse error: only prime fields are supported ...") + return (False, default_ret) + # Get prime p of prime field + (check, size_P, P) = extract_DER_integer(FieldID[size_Oid:]) + if check == False: + return (False, default_ret) + # Get curve (sequence) + (check, size_Curve, Curve) = extract_DER_sequence(ECParameters[size_ECPVer+size_FieldID:]) + if check == False: + return (False, default_ret) + # Get A in curve + (check, size_A, A) = extract_DER_octetstring(Curve) + if check == False: + return (False, default_ret) + # Get B in curve + (check, size_B, B) = extract_DER_octetstring(Curve[size_A:]) + if check == False: + return (False, default_ret) + # Get ECPoint + (check, size_ECPoint, ECPoint) = extract_DER_octetstring(ECParameters[size_ECPVer+size_FieldID+size_Curve:]) + if check == False: + return (False, default_ret) + # Get Order + (check, size_Order, Order) = extract_DER_integer(ECParameters[size_ECPVer+size_FieldID+size_Curve+size_ECPoint:]) + if check == False: + return (False, default_ret) + # Get Cofactor + (check, size_Cofactor, Cofactor) = extract_DER_integer(ECParameters[size_ECPVer+size_FieldID+size_Curve+size_ECPoint+size_Order:]) + if check == False: + return (False, default_ret) + # If we end up here, everything is OK, we can extract all our elements + prime = stringtoint(P) + a = stringtoint(A) + b = stringtoint(B) + order = stringtoint(Order) + cofactor = stringtoint(Cofactor) + # Extract Gx and Gy, see X9.62-1998 + if len(ECPoint) < 1: + return (False, default_ret) + ECPoint_type = ord(ECPoint[0]) + if (ECPoint_type == 0x04) or (ECPoint_type == 0x06) or (ECPoint_type == 0x07): + # Uncompressed and hybrid points + if len(ECPoint[1:]) % 2 != 0: + return (False, default_ret) + ECPoint = ECPoint[1:] + gx = stringtoint(ECPoint[:int(len(ECPoint)/2)]) + gy = stringtoint(ECPoint[int(len(ECPoint)/2):]) + elif (ECPoint_type == 0x02) or (ECPoint_type == 0x03): + # Compressed point: uncompress it, see X9.62-1998 section 4.2.1 + ECPoint = ECPoint[1:] + gx = stringtoint(ECPoint) + alpha = (pow(gx, 3, prime) + (a * gx) + b) % prime + beta = mod_sqrt(alpha, prime) + if (beta == None) or ((beta == 0) and (alpha != 0)): + return (False, 0) + if (beta & 0x1) == (ECPoint_type & 0x1): + gy = beta + else: + gy = prime - beta + else: + print("DER parse error: hybrid points are unsupported!") + return (False, default_ret) + return (True, (a, b, prime, order, cofactor, gx, gy)) + +########################################################## +### Text and format helpers +def bigint_to_C_array(bint, size): + """ + Format a python big int to a C hex array + """ + hexstr = format(int(bint), 'x') + # Left pad to the size! + hexstr = ("0"*int((2*size)-len(hexstr)))+hexstr + hexstr = ("0"*(len(hexstr) % 2))+hexstr + out_str = "{\n" + for i in range(0, len(hexstr) - 1, 2): + if (i%16 == 0): + if(i!=0): + out_str += "\n" + out_str += "\t" + out_str += "0x"+hexstr[i:i+2]+", " + out_str += "\n};\n" + return out_str + +def check_in_file(fname, pat): + # See if the pattern is in the file. + with open(fname) as f: + if not any(re.search(pat, line) for line in f): + return False # pattern does not occur in file so we are done. + else: + return True + +def num_patterns_in_file(fname, pat): + num_pat = 0 + with open(fname) as f: + for line in f: + if re.search(pat, line): + num_pat = num_pat+1 + return num_pat + +def file_replace_pattern(fname, pat, s_after): + # first, see if the pattern is even in the file. + with open(fname) as f: + if not any(re.search(pat, line) for line in f): + return # pattern does not occur in file so we are done. + + # pattern is in the file, so perform replace operation. + with open(fname) as f: + out_fname = fname + ".tmp" + out = open(out_fname, "w") + for line in f: + out.write(re.sub(pat, s_after, line)) + out.close() + os.rename(out_fname, fname) + +def file_remove_pattern(fname, pat): + # first, see if the pattern is even in the file. + with open(fname) as f: + if not any(re.search(pat, line) for line in f): + return # pattern does not occur in file so we are done. + + # pattern is in the file, so perform remove operation. + with open(fname) as f: + out_fname = fname + ".tmp" + out = open(out_fname, "w") + for line in f: + if not re.search(pat, line): + out.write(line) + out.close() + + if os.path.exists(fname): + remove_file(fname) + os.rename(out_fname, fname) + +def remove_file(fname): + # Remove file + os.remove(fname) + +def remove_files_pattern(fpattern): + [remove_file(x) for x in glob.glob(fpattern)] + +def buffer_remove_pattern(buff, pat): + if is_python_2() == False: + buff = buff.decode('latin-1') + if re.search(pat, buff) == None: + return (False, buff) # pattern does not occur in file so we are done. + # Remove the pattern + buff = re.sub(pat, "", buff) + return (True, buff) + +def is_base64(s): + s = ''.join([s.strip() for s in s.split("\n")]) + try: + enc = base64.b64encode(base64.b64decode(s)).strip() + if type(enc) is bytes: + return enc == s.encode('latin-1') + else: + return enc == s + except TypeError: + return False + +### Curve helpers +def export_curve_int(curvename, intname, bigint, size): + if bigint == None: + out = "static const u8 "+curvename+"_"+intname+"[] = {\n\t0x00,\n};\n" + out += "TO_EC_STR_PARAM_FIXED_SIZE("+curvename+"_"+intname+", 0);\n\n" + else: + out = "static const u8 "+curvename+"_"+intname+"[] = "+bigint_to_C_array(bigint, size)+"\n" + out += "TO_EC_STR_PARAM("+curvename+"_"+intname+");\n\n" + return out + +def export_curve_string(curvename, stringname, stringvalue): + out = "static const u8 "+curvename+"_"+stringname+"[] = \""+stringvalue+"\";\n" + out += "TO_EC_STR_PARAM("+curvename+"_"+stringname+");\n\n" + return out + +def export_curve_struct(curvename, paramname, paramnamestr): + return "\t."+paramname+" = &"+curvename+"_"+paramnamestr+"_str_param, \n" + +def curve_params(name, prime, pbitlen, a, b, gx, gy, order, cofactor, oid, alpha_montgomery, gamma_montgomery, alpha_edwards): + """ + Take as input some elliptic curve parameters and generate the + C parameters in a string + """ + bytesize = int(pbitlen / 8) + if pbitlen % 8 != 0: + bytesize += 1 + # Compute the rounded word size for each word size + if bytesize % 8 != 0: + wordsbitsize64 = 8*((int(bytesize/8)+1)*8) + else: + wordsbitsize64 = 8*bytesize + if bytesize % 4 != 0: + wordsbitsize32 = 8*((int(bytesize/4)+1)*4) + else: + wordsbitsize32 = 8*bytesize + if bytesize % 2 != 0: + wordsbitsize16 = 8*((int(bytesize/2)+1)*2) + else: + wordsbitsize16 = 8*bytesize + # Compute some parameters + (r64, r_square64, mpinv64) = compute_monty_coef(prime, wordsbitsize64, 64) + (r32, r_square32, mpinv32) = compute_monty_coef(prime, wordsbitsize32, 32) + (r16, r_square16, mpinv16) = compute_monty_coef(prime, wordsbitsize16, 16) + # Compute p_reciprocal for each word size + (pshift64, primenorm64, p_reciprocal64) = compute_div_coef(prime, wordsbitsize64, 64) + (pshift32, primenorm32, p_reciprocal32) = compute_div_coef(prime, wordsbitsize32, 32) + (pshift16, primenorm16, p_reciprocal16) = compute_div_coef(prime, wordsbitsize16, 16) + # Compute the number of points on the curve + npoints = order * cofactor + + # Now output the parameters + ec_params_string = "#include <libecc/lib_ecc_config.h>\n" + ec_params_string += "#ifdef WITH_CURVE_"+name.upper()+"\n\n" + ec_params_string += "#ifndef __EC_PARAMS_"+name.upper()+"_H__\n" + ec_params_string += "#define __EC_PARAMS_"+name.upper()+"_H__\n" + ec_params_string += "#include <libecc/curves/known/ec_params_external.h>\n" + ec_params_string += export_curve_int(name, "p", prime, bytesize) + + ec_params_string += "#define CURVE_"+name.upper()+"_P_BITLEN "+str(pbitlen)+"\n" + ec_params_string += export_curve_int(name, "p_bitlen", pbitlen, getbytelen(pbitlen)) + + ec_params_string += "#if (WORD_BYTES == 8) /* 64-bit words */\n" + ec_params_string += export_curve_int(name, "r", r64, getbytelen(r64)) + ec_params_string += export_curve_int(name, "r_square", r_square64, getbytelen(r_square64)) + ec_params_string += export_curve_int(name, "mpinv", mpinv64, getbytelen(mpinv64)) + ec_params_string += export_curve_int(name, "p_shift", pshift64, getbytelen(pshift64)) + ec_params_string += export_curve_int(name, "p_normalized", primenorm64, getbytelen(primenorm64)) + ec_params_string += export_curve_int(name, "p_reciprocal", p_reciprocal64, getbytelen(p_reciprocal64)) + ec_params_string += "#elif (WORD_BYTES == 4) /* 32-bit words */\n" + ec_params_string += export_curve_int(name, "r", r32, getbytelen(r32)) + ec_params_string += export_curve_int(name, "r_square", r_square32, getbytelen(r_square32)) + ec_params_string += export_curve_int(name, "mpinv", mpinv32, getbytelen(mpinv32)) + ec_params_string += export_curve_int(name, "p_shift", pshift32, getbytelen(pshift32)) + ec_params_string += export_curve_int(name, "p_normalized", primenorm32, getbytelen(primenorm32)) + ec_params_string += export_curve_int(name, "p_reciprocal", p_reciprocal32, getbytelen(p_reciprocal32)) + ec_params_string += "#elif (WORD_BYTES == 2) /* 16-bit words */\n" + ec_params_string += export_curve_int(name, "r", r16, getbytelen(r16)) + ec_params_string += export_curve_int(name, "r_square", r_square16, getbytelen(r_square16)) + ec_params_string += export_curve_int(name, "mpinv", mpinv16, getbytelen(mpinv16)) + ec_params_string += export_curve_int(name, "p_shift", pshift16, getbytelen(pshift16)) + ec_params_string += export_curve_int(name, "p_normalized", primenorm16, getbytelen(primenorm16)) + ec_params_string += export_curve_int(name, "p_reciprocal", p_reciprocal16, getbytelen(p_reciprocal16)) + ec_params_string += "#else /* unknown word size */\n" + ec_params_string += "#error \"Unsupported word size\"\n" + ec_params_string += "#endif\n\n" + + ec_params_string += export_curve_int(name, "a", a, bytesize) + ec_params_string += export_curve_int(name, "b", b, bytesize) + + curve_order_bitlen = getbitlen(npoints) + ec_params_string += "#define CURVE_"+name.upper()+"_CURVE_ORDER_BITLEN "+str(curve_order_bitlen)+"\n" + ec_params_string += export_curve_int(name, "curve_order", npoints, getbytelen(npoints)) + + ec_params_string += export_curve_int(name, "gx", gx, bytesize) + ec_params_string += export_curve_int(name, "gy", gy, bytesize) + ec_params_string += export_curve_int(name, "gz", 0x01, bytesize) + + qbitlen = getbitlen(order) + + ec_params_string += export_curve_int(name, "gen_order", order, getbytelen(order)) + ec_params_string += "#define CURVE_"+name.upper()+"_Q_BITLEN "+str(qbitlen)+"\n" + ec_params_string += export_curve_int(name, "gen_order_bitlen", qbitlen, getbytelen(qbitlen)) + + ec_params_string += export_curve_int(name, "cofactor", cofactor, getbytelen(cofactor)) + + ec_params_string += export_curve_int(name, "alpha_montgomery", alpha_montgomery, getbytelen(alpha_montgomery)) + ec_params_string += export_curve_int(name, "gamma_montgomery", gamma_montgomery, getbytelen(gamma_montgomery)) + ec_params_string += export_curve_int(name, "alpha_edwards", alpha_edwards, getbytelen(alpha_edwards)) + + ec_params_string += export_curve_string(name, "name", name.upper()); + + if oid == None: + oid = "" + ec_params_string += export_curve_string(name, "oid", oid); + + ec_params_string += "static const ec_str_params "+name+"_str_params = {\n"+\ + export_curve_struct(name, "p", "p") +\ + export_curve_struct(name, "p_bitlen", "p_bitlen") +\ + export_curve_struct(name, "r", "r") +\ + export_curve_struct(name, "r_square", "r_square") +\ + export_curve_struct(name, "mpinv", "mpinv") +\ + export_curve_struct(name, "p_shift", "p_shift") +\ + export_curve_struct(name, "p_normalized", "p_normalized") +\ + export_curve_struct(name, "p_reciprocal", "p_reciprocal") +\ + export_curve_struct(name, "a", "a") +\ + export_curve_struct(name, "b", "b") +\ + export_curve_struct(name, "curve_order", "curve_order") +\ + export_curve_struct(name, "gx", "gx") +\ + export_curve_struct(name, "gy", "gy") +\ + export_curve_struct(name, "gz", "gz") +\ + export_curve_struct(name, "gen_order", "gen_order") +\ + export_curve_struct(name, "gen_order_bitlen", "gen_order_bitlen") +\ + export_curve_struct(name, "cofactor", "cofactor") +\ + export_curve_struct(name, "alpha_montgomery", "alpha_montgomery") +\ + export_curve_struct(name, "gamma_montgomery", "gamma_montgomery") +\ + export_curve_struct(name, "alpha_edwards", "alpha_edwards") +\ + export_curve_struct(name, "oid", "oid") +\ + export_curve_struct(name, "name", "name") + ec_params_string += "};\n\n" + + ec_params_string += "/*\n"+\ + " * Compute max bit length of all curves for p and q\n"+\ + " */\n"+\ + "#ifndef CURVES_MAX_P_BIT_LEN\n"+\ + "#define CURVES_MAX_P_BIT_LEN 0\n"+\ + "#endif\n"+\ + "#if (CURVES_MAX_P_BIT_LEN < CURVE_"+name.upper()+"_P_BITLEN)\n"+\ + "#undef CURVES_MAX_P_BIT_LEN\n"+\ + "#define CURVES_MAX_P_BIT_LEN CURVE_"+name.upper()+"_P_BITLEN\n"+\ + "#endif\n"+\ + "#ifndef CURVES_MAX_Q_BIT_LEN\n"+\ + "#define CURVES_MAX_Q_BIT_LEN 0\n"+\ + "#endif\n"+\ + "#if (CURVES_MAX_Q_BIT_LEN < CURVE_"+name.upper()+"_Q_BITLEN)\n"+\ + "#undef CURVES_MAX_Q_BIT_LEN\n"+\ + "#define CURVES_MAX_Q_BIT_LEN CURVE_"+name.upper()+"_Q_BITLEN\n"+\ + "#endif\n"+\ + "#ifndef CURVES_MAX_CURVE_ORDER_BIT_LEN\n"+\ + "#define CURVES_MAX_CURVE_ORDER_BIT_LEN 0\n"+\ + "#endif\n"+\ + "#if (CURVES_MAX_CURVE_ORDER_BIT_LEN < CURVE_"+name.upper()+"_CURVE_ORDER_BITLEN)\n"+\ + "#undef CURVES_MAX_CURVE_ORDER_BIT_LEN\n"+\ + "#define CURVES_MAX_CURVE_ORDER_BIT_LEN CURVE_"+name.upper()+"_CURVE_ORDER_BITLEN\n"+\ + "#endif\n\n" + + ec_params_string += "/*\n"+\ + " * Compute and adapt max name and oid length\n"+\ + " */\n"+\ + "#ifndef MAX_CURVE_OID_LEN\n"+\ + "#define MAX_CURVE_OID_LEN 0\n"+\ + "#endif\n"+\ + "#ifndef MAX_CURVE_NAME_LEN\n"+\ + "#define MAX_CURVE_NAME_LEN 0\n"+\ + "#endif\n"+\ + "#if (MAX_CURVE_OID_LEN < "+str(len(oid)+1)+")\n"+\ + "#undef MAX_CURVE_OID_LEN\n"+\ + "#define MAX_CURVE_OID_LEN "+str(len(oid)+1)+"\n"+\ + "#endif\n"+\ + "#if (MAX_CURVE_NAME_LEN < "+str(len(name.upper())+1)+")\n"+\ + "#undef MAX_CURVE_NAME_LEN\n"+\ + "#define MAX_CURVE_NAME_LEN "+str(len(name.upper())+1)+"\n"+\ + "#endif\n\n" + + ec_params_string += "#endif /* __EC_PARAMS_"+name.upper()+"_H__ */\n\n"+"#endif /* WITH_CURVE_"+name.upper()+" */\n" + + return ec_params_string + +def usage(): + print("This script is intented to *statically* expand the ECC library with user defined curves.") + print("By statically we mean that the source code of libecc is expanded with new curves parameters through") + print("automatic code generation filling place holders in the existing code base of the library. Though the") + print("choice of static code generation versus dynamic curves import (such as what OpenSSL does) might be") + print("argued, this choice has been driven by simplicity and security design decisions: we want libecc to have") + print("all its parameters (such as memory consumption) set at compile time and statically adapted to the curves.") + print("Since libecc only supports curves over prime fields, the script can only add this kind of curves.") + print("This script implements elliptic curves and ISO signature algorithms from scratch over Python's multi-precision") + print("big numbers library. Addition and doubling over curves use naive formulas. Please DO NOT use the functions of this") + print("script for production code: they are not securely implemented and are very inefficient. Their only purpose is to expand") + print("libecc and produce test vectors.") + print("") + print("In order to add a curve, there are two ways:") + print("Adding a user defined curve with explicit parameters:") + print("-----------------------------------------------------") + print(sys.argv[0]+" --name=\"YOURCURVENAME\" --prime=... --order=... --a=... --b=... --gx=... --gy=... --cofactor=... --oid=THEOID") + print("\t> name: name of the curve in the form of a string") + print("\t> prime: prime number representing the curve prime field") + print("\t> order: prime number representing the generator order") + print("\t> cofactor: cofactor of the curve") + print("\t> a: 'a' coefficient of the short Weierstrass equation of the curve") + print("\t> b: 'b' coefficient of the short Weierstrass equation of the curve") + print("\t> gx: x coordinate of the generator G") + print("\t> gy: y coordinate of the generator G") + print("\t> oid: optional OID of the curve") + print(" Notes:") + print(" ******") + print("\t1) These elements are verified to indeed satisfy the curve equation.") + print("\t2) All the numbers can be given either in decimal or hexadecimal format with a prepending '0x'.") + print("\t3) The script automatically generates all the necessary files for the curve to be included in the library." ) + print("\tYou will find the new curve definition in the usual 'lib_ecc_config.h' file (one can activate it or not at compile time).") + print("") + print("Adding a user defined curve through RFC3279 ASN.1 parameters:") + print("-------------------------------------------------------------") + print(sys.argv[0]+" --name=\"YOURCURVENAME\" --ECfile=... --oid=THEOID") + print("\t> ECfile: the DER or PEM encoded file containing the curve parameters (see RFC3279)") + print(" Notes:") + print("\tCurve parameters encoded in DER or PEM format can be generated with tools like OpenSSL (among others). As an illustrative example,") + print("\tone can list all the supported curves under OpenSSL with:") + print("\t $ openssl ecparam -list_curves") + print("\tOnly the listed so called \"prime\" curves are supported. Then, one can extract an explicit curve representation in ASN.1") + print("\tas defined in RFC3279, for example for BRAINPOOLP320R1:") + print("\t $ openssl ecparam -param_enc explicit -outform DER -name brainpoolP320r1 -out brainpoolP320r1.der") + print("") + print("Removing user defined curves:") + print("-----------------------------") + print("\t*All the user defined curves can be removed with the --remove-all toggle.") + print("\t*A specific named user define curve can be removed with the --remove toggle: in this case the --name option is used to ") + print("\tlocate which named curve must be deleted.") + print("") + print("Test vectors:") + print("-------------") + print("\tTest vectors can be automatically generated and added to the library self tests when providing the --add-test-vectors=X toggle.") + print("\tIn this case, X test vectors will be generated for *each* (curve, sign algorithm, hash algorithm) 3-uplet (beware of combinatorial") + print("\tissues when X is big!). These tests are transparently added and compiled with the self tests.") + return + +def get_int(instring): + if len(instring) == 0: + return 0 + if len(instring) >= 2: + if instring[:2] == "0x": + return int(instring, 16) + return int(instring) + +def parse_cmd_line(args): + """ + Get elliptic curve parameters from command line + """ + name = oid = prime = a = b = gx = gy = g = order = cofactor = ECfile = remove = remove_all = add_test_vectors = None + alpha_montgomery = gamma_montgomery = alpha_edwards = None + try: + opts, args = getopt.getopt(sys.argv[1:], ":h", ["help", "remove", "remove-all", "name=", "prime=", "a=", "b=", "generator=", "gx=", "gy=", "order=", "cofactor=", "alpha_montgomery=","gamma_montgomery=", "alpha_edwards=", "ECfile=", "oid=", "add-test-vectors="]) + except getopt.GetoptError as err: + # print help information and exit: + print(err) # will print something like "option -a not recognized" + usage() + return False + for o, arg in opts: + if o in ("-h", "--help"): + usage() + return True + elif o in ("--name"): + name = arg + # Prepend the custom string before name to avoid any collision + name = "user_defined_"+name + # Replace any unwanted name char + name = re.sub("\-", "_", name) + elif o in ("--oid="): + oid = arg + elif o in ("--prime"): + prime = get_int(arg.replace(' ', '')) + elif o in ("--a"): + a = get_int(arg.replace(' ', '')) + elif o in ("--b"): + b = get_int(arg.replace(' ', '')) + elif o in ("--gx"): + gx = get_int(arg.replace(' ', '')) + elif o in ("--gy"): + gy = get_int(arg.replace(' ', '')) + elif o in ("--generator"): + g = arg.replace(' ', '') + elif o in ("--order"): + order = get_int(arg.replace(' ', '')) + elif o in ("--cofactor"): + cofactor = get_int(arg.replace(' ', '')) + elif o in ("--alpha_montgomery"): + alpha_montgomery = get_int(arg.replace(' ', '')) + elif o in ("--gamma_montgomery"): + gamma_montgomery = get_int(arg.replace(' ', '')) + elif o in ("--alpha_edwards"): + alpha_edwards = get_int(arg.replace(' ', '')) + elif o in ("--remove"): + remove = True + elif o in ("--remove-all"): + remove_all = True + elif o in ("--add-test-vectors"): + add_test_vectors = get_int(arg.replace(' ', '')) + elif o in ("--ECfile"): + ECfile = arg + else: + print("unhandled option") + usage() + return False + + # File paths + script_path = os.path.abspath(os.path.dirname(sys.argv[0])) + "/" + ec_params_path = script_path + "../include/libecc/curves/user_defined/" + curves_list_path = script_path + "../include/libecc/curves/" + lib_ecc_types_path = script_path + "../include/libecc/" + lib_ecc_config_path = script_path + "../include/libecc/" + ec_self_tests_path = script_path + "../src/tests/" + meson_options_path = script_path + "../" + + # If remove is True, we have been asked to remove already existing user defined curves + if remove == True: + if name == None: + print("--remove option expects a curve name provided with --name") + return False + asked = "" + while asked != "y" and asked != "n": + asked = get_user_input("You asked to remove everything related to user defined "+name.replace("user_defined_", "")+" curve. Enter y to confirm, n to cancel [y/n]. ") + if asked == "n": + print("NOT removing curve "+name.replace("user_defined_", "")+" (cancelled).") + return True + # Remove any user defined stuff with given name + print("Removing user defined curve "+name.replace("user_defined_", "")+" ...") + if name == None: + print("Error: you must provide a curve name with --remove") + return False + file_remove_pattern(curves_list_path + "curves_list.h", ".*"+name+".*") + file_remove_pattern(curves_list_path + "curves_list.h", ".*"+name.upper()+".*") + file_remove_pattern(lib_ecc_types_path + "lib_ecc_types.h", ".*"+name.upper()+".*") + file_remove_pattern(lib_ecc_config_path + "lib_ecc_config.h", ".*"+name.upper()+".*") + file_remove_pattern(ec_self_tests_path + "ec_self_tests_core.h", ".*"+name+".*") + file_remove_pattern(ec_self_tests_path + "ec_self_tests_core.h", ".*"+name.upper()+".*") + file_remove_pattern(meson_options_path + "meson.options", ".*"+name.lower()+".*") + try: + remove_file(ec_params_path + "ec_params_"+name+".h") + except: + print("Error: curve name "+name+" does not seem to be present in the sources!") + return False + try: + remove_file(ec_self_tests_path + "ec_self_tests_core_"+name+".h") + except: + print("Warning: curve name "+name+" self tests do not seem to be present ...") + return True + return True + if remove_all == True: + asked = "" + while asked != "y" and asked != "n": + asked = get_user_input("You asked to remove everything related to ALL user defined curves. Enter y to confirm, n to cancel [y/n]. ") + if asked == "n": + print("NOT removing user defined curves (cancelled).") + return True + # Remove any user defined stuff with given name + print("Removing ALL user defined curves ...") + # Remove any user defined stuff (whatever name) + file_remove_pattern(curves_list_path + "curves_list.h", ".*user_defined.*") + file_remove_pattern(curves_list_path + "curves_list.h", ".*USER_DEFINED.*") + file_remove_pattern(lib_ecc_types_path + "lib_ecc_types.h", ".*USER_DEFINED.*") + file_remove_pattern(lib_ecc_config_path + "lib_ecc_config.h", ".*USER_DEFINED.*") + file_remove_pattern(ec_self_tests_path + "ec_self_tests_core.h", ".*USER_DEFINED.*") + file_remove_pattern(ec_self_tests_path + "ec_self_tests_core.h", ".*user_defined.*") + file_remove_pattern(meson_options_path + "meson.options", ".*user_defined.*") + remove_files_pattern(ec_params_path + "ec_params_user_defined_*.h") + remove_files_pattern(ec_self_tests_path + "ec_self_tests_core_user_defined_*.h") + return True + + # If a g is provided, split it in two gx and gy + if g != None: + if (len(g)/2)%2 == 0: + gx = get_int(g[:len(g)/2]) + gy = get_int(g[len(g)/2:]) + else: + # This is probably a generator encapsulated in a bit string + if g[0:2] != "04": + print("Error: provided generator g is not conforming!") + return False + else: + g = g[2:] + gx = get_int(g[:len(g)/2]) + gy = get_int(g[len(g)/2:]) + if ECfile != None: + # ASN.1 DER input incompatible with other options + if (prime != None) or (a != None) or (b != None) or (gx != None) or (gy != None) or (order != None) or (cofactor != None): + print("Error: option ECfile incompatible with explicit (prime, a, b, gx, gy, order, cofactor) options!") + return False + # We need at least a name + if (name == None): + print("Error: option ECfile needs a curve name!") + return False + # Open the file + try: + buf = open(ECfile, 'rb').read() + except: + print("Error: cannot open ECfile file "+ECfile) + return False + # Check if we have a PEM or a DER file + (check, derbuf) = buffer_remove_pattern(buf, "-----.*-----") + if (check == True): + # This a PEM file, proceed with base64 decoding + if(is_base64(derbuf) == False): + print("Error: error when decoding ECfile file "+ECfile+" (seems to be PEM, but failed to decode)") + return False + derbuf = base64.b64decode(derbuf) + (check, (a, b, prime, order, cofactor, gx, gy)) = parse_DER_ECParameters(derbuf) + if (check == False): + print("Error: error when parsing ECfile file "+ECfile+" (malformed or unsupported ASN.1)") + return False + + else: + if (prime == None) or (a == None) or (b == None) or (gx == None) or (gy == None) or (order == None) or (cofactor == None) or (name == None): + err_string = (prime == None)*"prime "+(a == None)*"a "+(b == None)*"b "+(gx == None)*"gx "+(gy == None)*"gy "+(order == None)*"order "+(cofactor == None)*"cofactor "+(name == None)*"name " + print("Error: missing "+err_string+" in explicit curve definition (name, prime, a, b, gx, gy, order, cofactor)!") + print("See the help with -h or --help") + return False + + # Some sanity checks here + # Check that prime is indeed a prime + if is_probprime(prime) == False: + print("Error: given prime is *NOT* prime!") + return False + if is_probprime(order) == False: + print("Error: given order is *NOT* prime!") + return False + if (a > prime) or (b > prime) or (gx > prime) or (gy > prime): + err_string = (a > prime)*"a "+(b > prime)*"b "+(gx > prime)*"gx "+(gy > prime)*"gy " + print("Error: "+err_string+"is > prime") + return False + # Check that the provided generator is on the curve + if pow(gy, 2, prime) != ((pow(gx, 3, prime) + (a*gx) + b) % prime): + print("Error: the given parameters (prime, a, b, gx, gy) do not verify the elliptic curve equation!") + return False + + # Check Montgomery and Edwards transfer coefficients + if ((alpha_montgomery != None) and (gamma_montgomery == None)) or ((alpha_montgomery == None) and (gamma_montgomery != None)): + print("Error: alpha_montgomery and gamma_montgomery must be both defined if used!") + return False + if (alpha_edwards != None): + if (alpha_montgomery == None) or (gamma_montgomery == None): + print("Error: alpha_edwards needs alpha_montgomery and gamma_montgomery to be both defined if used!") + return False + + # Now that we have our parameters, call the function to get bitlen + pbitlen = getbitlen(prime) + ec_params = curve_params(name, prime, pbitlen, a, b, gx, gy, order, cofactor, oid, alpha_montgomery, gamma_montgomery, alpha_edwards) + # Check if there is a name collision somewhere + if os.path.exists(ec_params_path + "ec_params_"+name+".h") == True : + print("Error: file %s already exists!" % (ec_params_path + "ec_params_"+name+".h")) + return False + if (check_in_file(curves_list_path + "curves_list.h", "ec_params_"+name+"_str_params") == True) or (check_in_file(curves_list_path + "curves_list.h", "WITH_CURVE_"+name.upper()+"\n") == True) or (check_in_file(lib_ecc_types_path + "lib_ecc_types.h", "WITH_CURVE_"+name.upper()+"\n") == True): + print("Error: name %s already exists in files" % ("ec_params_"+name)) + return False + # Create a new file with the parameters + if not os.path.exists(ec_params_path): + # Create the "user_defined" folder if it does not exist + os.mkdir(ec_params_path) + f = open(ec_params_path + "ec_params_"+name+".h", 'w') + f.write(ec_params) + f.close() + # Include the file in curves_list.h + magic = "ADD curves header here" + magic_re = "\/\* "+magic+" \*\/" + magic_back = "/* "+magic+" */" + file_replace_pattern(curves_list_path + "curves_list.h", magic_re, "#include <libecc/curves/user_defined/ec_params_"+name+".h>\n"+magic_back) + # Add the curve mapping + magic = "ADD curves mapping here" + magic_re = "\/\* "+magic+" \*\/" + magic_back = "/* "+magic+" */" + file_replace_pattern(curves_list_path + "curves_list.h", magic_re, "#ifdef WITH_CURVE_"+name.upper()+"\n\t{ .type = "+name.upper()+", .params = &"+name+"_str_params },\n#endif /* WITH_CURVE_"+name.upper()+" */\n"+magic_back) + # Add the new curve type in the enum + # First we get the number of already defined curves so that we increment the enum counter + num_with_curve = num_patterns_in_file(lib_ecc_types_path + "lib_ecc_types.h", "#ifdef WITH_CURVE_") + magic = "ADD curves type here" + magic_re = "\/\* "+magic+" \*\/" + magic_back = "/* "+magic+" */" + file_replace_pattern(lib_ecc_types_path + "lib_ecc_types.h", magic_re, "#ifdef WITH_CURVE_"+name.upper()+"\n\t"+name.upper()+" = "+str(num_with_curve+1)+",\n#endif /* WITH_CURVE_"+name.upper()+" */\n"+magic_back) + # Add the new curve define in the config + magic = "ADD curves define here" + magic_re = "\/\* "+magic+" \*\/" + magic_back = "/* "+magic+" */" + file_replace_pattern(lib_ecc_config_path + "lib_ecc_config.h", magic_re, "#define WITH_CURVE_"+name.upper()+"\n"+magic_back) + # Add the new curve meson option in the meson.options file + magic = "ADD curves meson option here" + magic_re = "# " + magic + magic_back = "# " + magic + file_replace_pattern(meson_options_path + "meson.options", magic_re, "\t'"+name.lower()+"',\n"+magic_back) + + # Do we need to add some test vectors? + if add_test_vectors != None: + print("Test vectors generation asked: this can take some time! Please wait ...") + # Create curve + c = Curve(a, b, prime, order, cofactor, gx, gy, cofactor * order, name, oid) + # Generate key pair for the algorithm + vectors = gen_self_tests(c, add_test_vectors) + # Iterate through all the tests + f = open(ec_self_tests_path + "ec_self_tests_core_"+name+".h", 'w') + for l in vectors: + for v in l: + for case in v: + (case_name, case_vector) = case + # Add the new test case + magic = "ADD curve test case here" + magic_re = "\/\* "+magic+" \*\/" + magic_back = "/* "+magic+" */" + file_replace_pattern(ec_self_tests_path + "ec_self_tests_core.h", magic_re, case_name+"\n"+magic_back) + # Create/Increment the header file + f.write(case_vector) + f.close() + # Add the new test cases header + magic = "ADD curve test vectors header here" + magic_re = "\/\* "+magic+" \*\/" + magic_back = "/* "+magic+" */" + file_replace_pattern(ec_self_tests_path + "ec_self_tests_core.h", magic_re, "#include \"ec_self_tests_core_"+name+".h\"\n"+magic_back) + return True + + +#### Main +if __name__ == "__main__": + signal.signal(signal.SIGINT, handler) + parse_cmd_line(sys.argv[1:]) |
