#include "pch.h"
#include "engine.h"

#include "hash.h"

/* Elliptic curve cryptography based on NIST DSS prime curves. */

/* Parses the hexadecimal string s (most significant character first, no
 * prefix, upper or lower case) into 16-bit digits stored least significant
 * digit first.
 *
 * digits: output buffer; the caller must provide room for
 *         ceil(strlen(s)/4) digits - no bounds checking is done here.
 * s:      NUL-terminated hex string.
 * Returns the number of digits written (0 for an empty string). Leading
 * zero characters in s are NOT stripped, so the top digit may be zero.
 *
 * Characters that are not hex digits contribute a zero nibble.
 */
static int readdigits(ushort *digits, const char *s)
{
    const int perdigit = 2*sizeof(ushort); /* hex characters per output digit */
    int slen = strlen(s), len = (slen+perdigit-1)/perdigit;
    memset(digits, 0, len*sizeof(ushort));
    for(int i = 0; i < slen; i++)
    {
        /* walk the string backwards so i is the nibble index from the least significant end */
        int c = (unsigned char)s[slen-i-1], val = 0;
        if(c >= '0' && c <= '9') val = c - '0';
        else if(c >= 'a' && c <= 'f') val = c - 'a' + 10;
        else if(c >= 'A' && c <= 'F') val = c - 'A' + 10;
        /* val is always 0..15 here, so it can never spill into a neighbouring nibble */
        digits[i/perdigit] |= val<<(4*(i%perdigit));
    }
    return len;
}

/* Prints len 16-bit digits to out as hexadecimal, most significant digit
 * first. Every digit is written as exactly four lowercase hex characters,
 * and nothing is written when len is 0.
 */
static void writedigits(const ushort *digits, int len, FILE *out)
{
    for(int idx = len-1; idx >= 0; idx--)
    {
        fprintf(out, "%.4x", digits[idx]);
    }
}

/* Width in bits of one bigint digit, and a mask with exactly those low bits set. */
#define BI_DIGIT_BITS 16
#define BI_DIGIT_MASK ((1<<BI_DIGIT_BITS)-1)

/* Fixed-capacity unsigned big integer.
 *
 * The value is stored as len 16-bit digits, least significant first, in a
 * buffer of BI_DIGITS digits. len == 0 represents zero. Operations that
 * produce a result call shrink() so that the top digit is non-zero; the
 * comparison operators rely on that normalisation because they compare len
 * before looking at any digit.
 *
 * No operation checks that its result fits into BI_DIGITS digits; choosing
 * a large enough capacity is the caller's responsibility.
 */
template<int BI_DIGITS> struct bigint
{
    typedef ushort digit;
    typedef uint dbldigit; /* wide enough for the product of two digits plus carries */

    int len;                /* number of digits currently in use */
    digit digits[BI_DIGITS];

    /* Constructs zero. The digit buffer itself is left uninitialised. */
    bigint() : len(0) {}
    /* Constructs the single-digit value n. */
    bigint(digit n) { if(n) { len = 1; digits[0] = n; } else len = 0; }
    /* Constructs from a hexadecimal string; see readdigits() for the accepted format. */
    bigint(const char *s) { len = readdigits(digits, s); }
    /* Converts from a bigint of a different capacity. */
    template<int Y_SIZE> bigint(const bigint<Y_SIZE> &y) { *this = y; }

    /* Writes the value to out as hexadecimal, four characters per digit. */
    void write(FILE *out) const { writedigits(digits, len, out); }

    /* Copies the used digits of y; y.len must not exceed BI_DIGITS. */
    template<int Y_SIZE> bigint &operator=(const bigint<Y_SIZE> &y)
    {
        len = y.len;
        memcpy(digits, y.digits, len*sizeof(digit));
        return *this;
    }

    bool iszero() const { return !len; }

    /* Returns the position of the highest set bit plus one, or 0 for zero
     * (and also 0 if the top digit happens to be zero).
     */
    int numbits() const
    {
        if(!len) return 0;
        int bits = len*BI_DIGIT_BITS;
        digit last = digits[len-1];
        for(digit mask = 1<<(BI_DIGIT_BITS-1); mask; mask >>= 1)
        {
            if(last&mask) return bits;
            bits--;
        }
        return 0;
    }

    /* Tests bit n (bit 0 is the least significant); bits at or above len*16 read as 0. */
    bool hasbit(int n) const { return n/BI_DIGIT_BITS < len && ((digits[n/BI_DIGIT_BITS]>>(n%BI_DIGIT_BITS))&1); }

    /* this = x + y. The destination may alias x and/or y. */
    template<int X_DIGITS, int Y_DIGITS> void add(const bigint<X_DIGITS> &x, const bigint<Y_DIGITS> &y)
    {
        dbldigit carry = 0;
        int maxlen = x.len > y.len ? x.len : y.len, i;
        for(i = 0; i < y.len || carry; i++)
        {
            carry += (i < x.len ? (dbldigit)x.digits[i] : 0) + (i < y.len ? (dbldigit)y.digits[i] : 0);
            digits[i] = (digit)carry;
            carry >>= BI_DIGIT_BITS;
        }
        /* The loop stops once y and the carry are exhausted. Any remaining
         * high digits of x are unchanged by the addition, but they still
         * have to be copied over when the destination is not x itself.
         */
        if(i < x.len && &digits[0] != &x.digits[0]) memcpy(&digits[i], &x.digits[i], (x.len-i)*sizeof(digit));
        /* a carry out of the top digit makes the result one digit longer */
        len = i > maxlen ? i : maxlen;
    }
    template<int Y_DIGITS> void add(const bigint<Y_DIGITS> &y) { add(*this, y); }

    /* this = x - y. Requires x >= y. The destination may alias x and/or y. */
    template<int X_DIGITS, int Y_DIGITS> void sub(const bigint<X_DIGITS> &x, const bigint<Y_DIGITS> &y)
    {
        ASSERT(x >= y);
        dbldigit borrow = 0;
        int i;
        /* i < x.len only matters if the precondition is violated; it keeps the loop inside x */
        for(i = 0; i < x.len && (i < y.len || borrow); i++)
        {
            /* adding 2^16 keeps the intermediate non-negative; bit 16 is cleared exactly when a borrow occurred */
            borrow = (1<<BI_DIGIT_BITS) + (dbldigit)x.digits[i] - (i<y.len ? (dbldigit)y.digits[i] : 0) - borrow;
            digits[i] = (digit)borrow;
            borrow = (borrow>>BI_DIGIT_BITS)^1;
        }
        /* copy the high digits of x that the subtraction did not reach, unless operating in place */
        if(i < x.len && &digits[0] != &x.digits[0]) memcpy(&digits[i], &x.digits[i], (x.len-i)*sizeof(digit));
        len = x.len;
        shrink();
    }
    template<int Y_DIGITS> void sub(const bigint<Y_DIGITS> &y) { sub(*this, y); }

    /* Drops leading zero digits so that the top digit (if any) is non-zero. */
    void shrink() { while(len && !digits[len-1]) len--; }

    /* this = x * y by schoolbook multiplication.
     * The destination must NOT alias x or y (its low digits are cleared
     * before the operands are read) and needs room for x.len+y.len digits.
     */
    template<int X_DIGITS, int Y_DIGITS> void mul(const bigint<X_DIGITS> &x, const bigint<Y_DIGITS> &y)
    {
        if(!x.len || !y.len) { len = 0; return; }
        /* only the first row's accumulator needs clearing; each row initialises the one new digit above it */
        memset(digits, 0, y.len*sizeof(digit));
        for(int i = 0; i < x.len; i++)
        {
            dbldigit carry = 0;
            for(int j = 0; j < y.len; j++)
            {
                carry += (dbldigit)x.digits[i] * (dbldigit)y.digits[j] + (dbldigit)digits[i+j];
                digits[i+j] = (digit)carry;
                carry >>= BI_DIGIT_BITS;
            }
            digits[i+y.len] = (digit)carry;
        }
        len = x.len + y.len;
        shrink();
    }

    /* Shifts the value right by n bits in place. */
    void rshift(uint n)
    {
        if(!len || !n) return;
        /* shifting out every bit gives zero; without this check the digit index below would leave the buffer */
        if(n >= uint(len)*BI_DIGIT_BITS) { len = 0; return; }
        /* split n into dig whole digits plus a remaining shift of 1..16 bits */
        uint dig = (n-1)/BI_DIGIT_BITS;
        n = ((n-1) % BI_DIGIT_BITS)+1;
        digit carry = digits[dig]>>n;
        for(int i = dig+1; i < len; i++)
        {
            digit tmp = digits[i];
            digits[i-dig-1] = (tmp<<(BI_DIGIT_BITS-n)) | carry;
            carry = tmp>>n;
        }
        digits[len-dig-1] = carry;
        /* when the remaining shift is a full 16 bits the top digit written above is 0 and shrink() removes it */
        len -= dig;
        shrink();
    }

    /* Shifts the value left by n bits in place. The buffer must have room for the longer result. */
    void lshift(uint n)
    {
        if(!len || !n) return;
        uint dig = n/BI_DIGIT_BITS;
        n %= BI_DIGIT_BITS;
        /* bits pushed out of the current top digit form a new top digit */
        digit top = n ? digit(digits[len-1]>>(BI_DIGIT_BITS-n)) : 0;
        /* Go from the top down so that an in-place move by dig digits never
         * overwrites a digit that still has to be read. Each output digit
         * is digit i shifted up, combined with the bits spilling over from
         * digit i-1 below it.
         */
        for(int i = len-1; i >= 0; i--)
        {
            digit spill = (i && n) ? digit(digits[i-1]>>(BI_DIGIT_BITS-n)) : 0;
            digits[i+dig] = digit((digits[i]<<n) | spill);
        }
        len += dig;
        if(top) digits[len++] = top;
        if(dig) memset(digits, 0, dig*sizeof(digit));
    }

    /* Comparisons assume both operands are normalised (no leading zero digits). */
    template<int Y_DIGITS> bool operator==(const bigint<Y_DIGITS> &y) const
    {
        if(len!=y.len) return false;
        for(int i = len-1; i>=0; i--) if(digits[i]!=y.digits[i]) return false;
        return true;
    }
    template<int Y_DIGITS> bool operator!=(const bigint<Y_DIGITS> &y) const { return !(*this==y); }
    template<int Y_DIGITS> bool operator<(const bigint<Y_DIGITS> &y) const
    {
        if(len<y.len) return true;
        if(len>y.len) return false;
        /* equal length: the most significant differing digit decides */
        for(int i = len-1; i>=0; i--)
        {
            if(digits[i]<y.digits[i]) return true;
            if(digits[i]>y.digits[i]) return false;
        }
        return false;
    }
    template<int Y_DIGITS> bool operator>(const bigint<Y_DIGITS> &y) const { return y<*this; }
    template<int Y_DIGITS> bool operator<=(const bigint<Y_DIGITS> &y) const { return !(y<*this); }
    template<int Y_DIGITS> bool operator>=(const bigint<Y_DIGITS> &y) const { return !(*this<y); }
};

/* Size of the prime field in bits; only 192 is accepted by gfield::reduce(). */
#define GF_BITS         192
/* Width in bits of one field digit, and a mask with exactly those low bits set. */
#define GF_DIGIT_BITS   16
#define GF_DIGIT_MASK   ((1<<GF_DIGIT_BITS)-1)
/* Number of digits needed to hold GF_BITS bits, rounded up. */
#define GF_DIGITS       ((GF_BITS+GF_DIGIT_BITS-1)/GF_DIGIT_BITS)

/* Integer type underlying field elements: one digit more than GF_DIGITS. */
typedef bigint<GF_DIGITS+1> gfint;

/* NIST prime Galois fields.
 * Currently only supports NIST P-192, where P=2^192-2^64-1.
 */
struct gfield : gfint
{
    typedef ushort digit;
    typedef uint dbldigit;

    gfield() {};
    gfield(digit n) : gfint(n) {};
    gfield(const char *s) : gfint(s) {};

    static gfield P;

    void add(const gfield &x, const gfield &y)
    {
        gfint::add(x, y);
        if(*this >= P) gfint::sub(*this, P);
    };
    void add(const gfield &y) { add(*this, y); };

    void sub(const gfield &x, const gfield &y)
    {
        if(x < y)
        {
            gfint::add(x, P);
            gfint::sub(*this, y);
        }
        else gfint::sub(x, y);
    };
    void sub(const gfield &y) { sub(*this, y); };

    void square(const gfield &x) { mul(x, x); };
    void square() { square(*this); };

    void mul(const gfield &x, const gfield &y)
    {
        bigint<2*GF_DIGITS> result;
        result.mul(x, y);
        reduce(result);
    };
    void mul(const gfield &y) { mul(*this, y); };

    template<int RESULT_DIGITS>
    void reduce(const bigint<RESULT_DIGITS> &result)
    {
#if GF_BITS==192
        len = min(result.len, GF_DIGITS);
        memcpy(digits, result.digits, len*sizeof(digit));
        shrink();

        if(result.len > 192/GF_DIGIT_BITS)
        {
            gfield s;
            memcpy(s.digits, &result.digits[192/GF_DIGIT_BITS], min(result.len-192/GF_DIGIT_BITS, 64/GF_DIGIT_BITS)*sizeof(digit));
            if(result.len < 256/GF_DIGIT_BITS) memset(&s.digits[result.len-192/GF_DIGIT_BITS], 0, (256/GF_DIGIT_BITS-result.len)*sizeof(digit));
            memcpy(&s.digits[64/GF_DIGIT_BITS], s.digits, 64/GF_DIGIT_BITS*sizeof(digit));
            s.len = 128/GF_DIGIT_BITS;
            s.shrink();
            add(s);

            if(result.len > 256/GF_DIGIT_BITS)
            {
                memset(s.digits, 0, 64/GF_DIGIT_BITS*sizeof(digit));
                memcpy(&s.digits[64/GF_DIGIT_BITS], &result.digits[256/GF_DIGIT_BITS], min(result.len-256/GF_DIGIT_BITS, 64/GF_DIGIT_BITS)*sizeof(digit));
                if(result.len < 320/GF_DIGIT_BITS) memset(&s.digits[result.len+(64-256)/GF_DIGIT_BITS], 0, (320/GF_DIGIT_BITS-result.len)*sizeof(digit));
                memcpy(&s.digits[128/GF_DIGIT_BITS], &s.digits[64/GF_DIGIT_BITS], 64/GF_DIGIT_BITS*sizeof(digit));
                s.len = GF_DIGITS;
                s.shrink();
                add(s);

                if(result.len > 320/GF_DIGIT_BITS)
                {
                    memcpy(s.digits, &result.digits[320/GF_DIGIT_BITS], min(result.len-320/GF_DIGIT_BITS, 64/GF_DIGIT_BITS)*sizeof(digit));
                    if(result.len < 384/GF_DIGIT_BITS) memset(&s.digits[result.len-320/GF_DIGIT_BITS], 0, (384/GF_DIGIT_BITS-result.len)*sizeof(digit));
                    memcpy(&s.digits[64/GF_DIGIT_BITS], s.digits, 64/GF_DIGIT_BITS*sizeof(digit));
                    memcpy(&s.digits[128/GF_DIGIT_BITS], s.digits, 64/GF_DIGIT_BITS*sizeof(digit));
                    s.len = GF_DIGITS;
                    s.shrink();
                    add(s);
                };
            };
        }
        else if(*this >= P) gfint::sub(*this, P);
#else
#error Unsupported GF
#endif
    };

    bool invert(const gfield &x)
    {
        if(!x.len) return false;
        gfint u(x), v(P), A((gfint::digit)1), C((gfint::digit)0);
        while(!u.iszero())
        {
            int ushift = 0, ashift = 0;
            while(!u.hasbit(ushift))
            {
                ushift++;
                if(A.hasbit(ashift))
                { 
                    if(ashift) { A.rshift(ashift); ashift = 0; }; 
                    A.add(P); 
                };
                ashift++;
            };
            if(ushift) u.rshift(ushift);
            if(ashift) A.rshift(ashift);
            int vshift = 0, cshift = 0;
            while(!v.hasbit(vshift))
            {
                vshift++;
                if(C.hasbit(cshift))
                { 
                    if(cshift) { C.rshift(cshift); cshift = 0; }; 
                    C.add(P); 
                };
                cshift++;
            };
            if(vshift) v.rshift(vshift);
            if(cshift) C.rshift(cshift);
            if(u >= v)
            {
                u.sub(v);
                if(A < C) A.add(P);
                A.sub(C);
            }
            else
            {
                v.sub(v, u);
                if(C < A) C.add(P);
                C.sub(A);
            };    
        };
        if(C >= P) gfint::sub(C, P);
        else { len = C.len; memcpy(digits, C.digits, len*sizeof(digit)); };
        ASSERT(*this < P);
        return true;
    };    
    void invert() { invert(*this); };
};

/* Definition of the field prime gfield::P as a hex literal, selected by GF_BITS.
 * Literals are listed for several field sizes, but gfield::reduce() above only
 * compiles for GF_BITS==192, so the other branches are not usable as the file stands.
 */
#if GF_BITS==192
gfield gfield::P("fffffffffffffffffffffffffffffffeffffffffffffffff");
#elif GF_BITS==224
gfield gfield::P("ffffffffffffffffffffffffffffffff000000000000000000000001");
#elif GF_BITS==256
gfield gfield::P("ffffffff00000001000000000000000000000000ffffffffffffffffffffffff");
#elif GF_BITS==384
gfield gfield::P("fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffeffffffff0000000000000000ffffffff");
#elif GF_BITS==521
gfield gfield::P("1ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff");
#else
#error Unsupported GF
#endif

/* A pair of field elements (x, y) naming a curve point.
 * The arithmetic below is still a skeleton: mul2() has no body and add()
 * only handles the case where both points are equal.
 */
struct ecpoint
{
    gfield x, y;

    /* Two points are equal when both coordinates match. */
    bool operator==(const ecpoint &p) const
    {
        if(x != p.x) return false;
        return y == p.y;
    }

    /* Not implemented yet: currently does nothing. */
    void mul2(const ecpoint &p)
    {
    }

    /* Forwards to mul2() when p equals this point; every other case is not handled yet. */
    void add(const ecpoint &p)
    {
        if(*this == p)
        {
            mul2(p);
            return;
        }
    }
};

void testcrypt(char *s)
{
    bigint<64> p(gfield::P);
    printf("p: "); p.write(stdout); putchar('\n');
    p.rshift(3);
    printf("p/2^3: "); p.write(stdout); putchar('\n');
    p.rshift(32);
    printf("p/2^35: "); p.write(stdout); putchar('\n');
    p.rshift(64);
    printf("p/2^99: "); p.write(stdout); putchar('\n');

    bigint<64> one(1), t32(1), t64(1), t96(1), t128(1), t192(1), t224(1), t256(1), t384(1), t521(1), p192, p224, p256, p384, p521;
    t32.lshift(32); printf("2^%d: ", t32.numbits()-1); t32.write(stdout); putchar('\n');
    t64.lshift(64); printf("2^%d: ", t64.numbits()-1); t64.write(stdout); putchar('\n');
    t96.lshift(96); printf("2^%d: ", t96.numbits()-1); t96.write(stdout); putchar('\n');
    t128.lshift(128); printf("2^%d: ", t128.numbits()-1); t128.write(stdout); putchar('\n');
    t192.lshift(192); printf("2^%d: ", t192.numbits()-1); t192.write(stdout); putchar('\n');
    t224.lshift(224); printf("2^%d: ", t224.numbits()-1); t224.write(stdout); putchar('\n');
    t256.lshift(256); printf("2^%d: ", t256.numbits()-1); t256.write(stdout); putchar('\n');
    t384.lshift(384); printf("2^%d: ", t384.numbits()-1); t384.write(stdout); putchar('\n');
    t521.lshift(521); printf("2^%d: ", t521.numbits()-1); t521.write(stdout); putchar('\n');

    p192.sub(t192, t64); p192.sub(one);
    printf("p192: "); p192.write(stdout); putchar('\n');
    p224.sub(t224, t96); p224.add(one);
    printf("p224: "); p224.write(stdout); putchar('\n');
    p256.sub(t256, t224); p256.add(t192); p256.add(t96); p256.sub(one);
    printf("p256: "); p256.write(stdout); putchar('\n');
    p384.sub(t384, t128); p384.sub(t96); p384.add(t32); p384.sub(one);
    printf("p384: "); p384.write(stdout); putchar('\n');
    p521.sub(t521, one);
    printf("p521: "); p521.write(stdout); putchar('\n');

    gfield x(s);
    printf("x: "); x.write(stdout); putchar('\n');
    x.invert();
    printf("1/x: "); x.write(stdout); putchar('\n');
    x.mul(gfield(s));
    printf("(1/x)*x: "); x.write(stdout); putchar('\n');
    x.sub(gfield(42));
    printf("(1/x)*x-42: "); x.write(stdout); putchar('\n');
};

/* NOTE(review): COMMAND is defined outside this file; presumably it registers testcrypt
 * as a console command taking one string argument ("s") - confirm against the engine headers.
 */
COMMAND(testcrypt, "s");

    

