#include <sys/time.h>
#include <sys/types.h>
#include <unistd.h>
#include <errno.h>
#include <signal.h>

#include "pch.h"
#include "cube.h"
#include "iengine.h"
#include "crypto.h"
#include "hash.h"
#include <enet/time.h>

// Size in bytes of each client's fixed receive buffer (client::input).
#define INPUT_LIMIT 4096
// A client whose queued output exceeds this many bytes is dropped by checkclients().
#define OUTPUT_LIMIT (64*1024)
// Maximum connection age; only referenced by a check in checkclients() that is
// currently compiled out with #if 0. Presumably milliseconds -- TODO confirm.
#define CLIENT_TIME (3*60*1000)
// Age at which purgeauths() expires an unanswered auth request (compared against
// differences of curtime values; presumably milliseconds -- TODO confirm).
#define AUTH_TIME (60*1000)
// Maximum pending auth requests per client; reqauth() fails the oldest one when this is reached.
#define AUTH_LIMIT 100
// Maximum number of simultaneously connected clients accepted by checkclients().
#define CLIENT_LIMIT 4096
// Once this many connections from one host address exist, accepting another drops the oldest.
#define DUP_LIMIT 16

// Stream that fatal() and the conoutf family write to; main() points it at
// the auth.log file, or at stdout when that file cannot be opened.
FILE *logfile = NULL;

// Privilege bit flags. PRIV_MASTER is not referenced anywhere else in the
// visible part of this file.
enum
{
    PRIV_MASTER = 1<<0
};

// One registered account, created by adduser().
struct userinfo
{
    char *name;         // heap copy of the account name; released by clearusers()
    ecjacobian pubkey;  // parsed from the textual key passed to adduser()
};
// All registered accounts, keyed by account name.
hashtable<char *, userinfo> users;

void adduser(char *name, char *pubkey)
{
    name = newstring(name);
    userinfo &u = users[name];
    u.name = name;
    u.pubkey.parse(pubkey);
}

COMMAND(adduser, "ss");

// Releases every stored account name and then empties the user table.
void clearusers()
{
    // presumably evaluates the delete[] expression once per table entry,
    // with u bound to that entry -- confirm against the enumerate macro
    enumerate(users, userinfo, u, delete[] u.name);
    users.clear();
}

// Exposes clearusers through the project's COMMAND macro with an empty argument
// format string.
COMMAND(clearusers, "");

// One challenge that has been sent to a client and not yet confirmed.
struct authreq
{
    enet_uint32 reqtime; // value of curtime when reqauth() created the request
    uint id;             // request id supplied by the client; echoed in replies
    gfint answer;        // expected reply, compared against the value given to confauth()
};

// State kept for one accepted connection.
struct client
{
    ENetAddress address;      // peer address filled in by enet_socket_accept()
    ENetSocket socket;        // connected socket; destroyed by purgeclient()
    char input[INPUT_LIMIT];  // received bytes not yet consumed as complete lines
    vector<char> output;      // bytes queued for sending
    int inputpos, outputpos;  // bytes held in input / bytes of output already sent
    enet_uint32 connecttime;  // value of curtime when the connection was accepted
    vector<authreq> authreqs; // pending challenges; purgeauths() expires them from the front

    // Only the buffer positions are initialised here; address, socket and
    // connecttime are assigned by checkclients() right after construction.
    client() : inputpos(0), outputpos(0) {}
};  
// Every currently connected client; entries are heap-allocated and owned by this list.
vector<client *> clients;

// Listening socket created by setupserver().
ENetSocket serversocket = ENET_SOCKET_NULL;

// Wall-clock time recorded by setupserver(); also used as part of the challenge seed.
time_t starttime;
// Latest enet_time_get() value, refreshed by checkclients() after select() reports activity.
enet_uint32 curtime = 0;

// Writes a printf-style message to logfile, then terminates the process
// with EXIT_FAILURE. Never returns.
void fatal(const char *fmt, ...)
{
    va_list ap;
    va_start(ap, fmt);
    vfprintf(logfile, fmt, ap);
    va_end(ap);

    exit(EXIT_FAILURE);
}

// va_list back end of the conoutf family: formats the message straight into
// logfile. The message type is accepted for interface compatibility but is
// not used to filter or decorate the output.
void conoutfv(int type, const char *fmt, va_list args)
{
    (void)type;
    vfprintf(logfile, fmt, args);
}

// Logs a printf-style message with the default CON_INFO type.
void conoutf(const char *fmt, ...)
{
    va_list ap;
    va_start(ap, fmt);
    conoutfv(CON_INFO, fmt, ap);
    va_end(ap);
}

// Logs a printf-style message with an explicit message type.
void conoutf(int type, const char *fmt, ...)
{
    va_list ap;
    va_start(ap, fmt);
    conoutfv(type, fmt, ap);
    va_end(ap);
}

// Creates the non-blocking listening socket and logs the start-up banner.
//   port - TCP port to listen on
//   ip   - optional host name or address to bind to; when NULL the socket
//          is bound to ENET_HOST_ANY
// Any failure is reported through fatal(), which terminates the process.
void setupserver(int port, const char *ip = NULL)
{
    ENetAddress address;
    address.host = ENET_HOST_ANY;
    address.port = port;

    if(ip)
    {
        if(enet_address_set_host(&address, ip)<0)
            fatal("failed to resolve server address: %s\n", ip);
    }
    serversocket = enet_socket_create(ENET_SOCKET_TYPE_STREAM);
    if(serversocket==ENET_SOCKET_NULL || 
       enet_socket_set_option(serversocket, ENET_SOCKOPT_REUSEADDR, 1) < 0 ||
       enet_socket_bind(serversocket, &address) < 0 ||
       enet_socket_listen(serversocket, -1) < 0)
        fatal("failed to create server socket\n");
    if(enet_socket_set_option(serversocket, ENET_SOCKOPT_NONBLOCK, 1)<0)
        fatal("failed to make server socket non-blocking\n");

    // Restart the ENet clock so curtime values count from server start.
    enet_time_set(0);
    
    starttime = time(NULL);
    // ctime() may return NULL; guard it the same way reqauth() does instead
    // of dereferencing the result unconditionally.
    char *ct = ctime(&starttime);
    if(ct)
    {
        // strip the trailing newline so the banner stays on one line
        char *newline = strchr(ct, '\n');
        if(newline) *newline = '\0';
    }
    conoutf("*** Starting authserver on %s %d at %s ***\n", ip ? ip : "localhost", port, ct ? ct : "-");
}

void purgeclient(int n)
{
    client &c = *clients[n];
    enet_socket_destroy(c.socket);
    delete clients[n];
    clients.remove(n);
}

// Empty placeholder: does nothing and is not called anywhere in the visible
// part of this file. Expiry of pending requests is done by purgeauths().
void purgeauthreqs(client &c)
{
}

// Queues bytes for sending to a client.
//   msg - data to append to the client's output buffer
//   len - number of bytes to append; 0 means "msg is a C string, use its length"
void output(client &c, const char *msg, int len = 0)
{
    const int count = len ? len : (int)strlen(msg);
    c.output.put(msg, count);
}

// printf-style variant of output(): formats into a temporary string buffer
// and queues the resulting text for the client.
void outputf(client &c, const char *fmt, ...)
{
    string text;

    va_list ap;
    va_start(ap, fmt);
    formatstring(text, fmt, ap);
    va_end(ap);

    output(c, text);
}

// Expires pending auth requests that have reached AUTH_TIME. Scans from the
// front of the list and stops at the first request that is still fresh; each
// expired request is answered with "failauth <id>" and then removed.
void purgeauths(client &c)
{
    int stale = 0;
    while(stale < c.authreqs.length() &&
          ENET_TIME_DIFFERENCE(curtime, c.authreqs[stale].reqtime) >= AUTH_TIME)
    {
        outputf(c, "failauth %u\n", c.authreqs[stale].id);
        stale++;
    }

    // drop the whole expired prefix in one go
    if(stale > 0) c.authreqs.remove(0, stale);
}

// Handles a "reqauth <id> <name>" request.
// Logs the attempt, then either replies "failauth <id>" for an unknown user
// or records a pending challenge and replies "chalauth <id> <point>".
//   c    - requesting client
//   id   - client-chosen request id, echoed back in every reply
//   name - account name to authenticate as
void reqauth(client &c, uint id, char *name)
{
    purgeauths(c);
    
    // timestamp for the log line, without ctime()'s trailing newline
    time_t t = time(NULL);
    char *ct = ctime(&t);
    if(ct) 
    { 
        char *newline = strchr(ct, '\n');
        if(newline) *newline = '\0'; 
    }
    string ip;
    if(enet_address_get_host_ip(&c.address, ip, sizeof(ip)) < 0) s_strcpy(ip, "-");
    conoutf("%s: attempting \"%s\" as %u from %s\n", ct ? ct : "-", name, id, ip);

    userinfo *u = users.access(name);
    if(!u)
    {
        outputf(c, "failauth %u\n", id);
        return;
    }

    // at the per-client limit: fail the oldest pending request to make room
    if(c.authreqs.length() >= AUTH_LIMIT)
    {
        outputf(c, "failauth %u\n", c.authreqs[0].id);
        c.authreqs.remove(0);
    }

    authreq &a = c.authreqs.add();
    a.reqtime = curtime;
    a.id = id;

    // The challenge scalar is the hash of start time, current time and a
    // random number. starttime is a time_t, so it is converted explicitly:
    // narrowing inside a braced initialiser is ill-formed since C++11.
    uint seed[3] = { uint(starttime), curtime, randomMT() };
    tiger::hashval hash;
    tiger::hash((uchar *)seed, sizeof(seed), hash);
    gfint challenge;
    memcpy(challenge.digits, hash.bytes, sizeof(challenge.digits));
    challenge.len = 8*sizeof(hash.bytes)/BI_DIGIT_BITS;
    challenge.shrink();

    // expected reply: x coordinate of (user public key * challenge)
    ecjacobian answer(u->pubkey);
    answer.mul(challenge);
    answer.normalize();
    a.answer = answer.x;

    //printf("expecting %u for user %s to be ", id, u->name);
    //a.answer.print(stdout);
    //printf(" given secret ");

    // value sent to the client: (curve base point * challenge)
    ecjacobian secret(ecjacobian::base);
    secret.mul(challenge);
    secret.normalize();

    // reused across calls to avoid reallocating the text buffer
    static vector<char> buf;
    buf.setsizenodelete(0);
    secret.print(buf);
    buf.add('\0');
   
    //printf("%s\n", buf.getbuf());

    outputf(c, "chalauth %u %s\n", id, buf.getbuf());
}

// Handles a "confauth <id> <val>" reply to an earlier challenge.
// The first pending request with a matching id is compared against val,
// answered with "succauth <id>" or "failauth <id>", logged, and removed.
// An id with no pending request is answered with "failauth <id>".
void confauth(client &c, uint id, const char *val)
{
    purgeauths(c);

    // locate the first pending request carrying this id
    int found = -1;
    loopv(c.authreqs) if(c.authreqs[i].id == id) { found = i; break; }

    if(found < 0)
    {
        outputf(c, "failauth %u\n", id);
        return;
    }

    gfint answer(val);
    string ip;
    if(enet_address_get_host_ip(&c.address, ip, sizeof(ip)) < 0) s_strcpy(ip, "-");

    if(answer == c.authreqs[found].answer)
    {
        outputf(c, "succauth %u\n", id);
        conoutf("succeeded %u from %s\n", id, ip);
    }
    else
    {
        outputf(c, "failauth %u\n", id);
        conoutf("failed %u from %s\n", id, ip);
    }

    // a challenge can be answered only once, whatever the outcome
    c.authreqs.remove(found);
}

// Consumes every complete (newline-terminated) line in the client's input
// buffer, dispatching "reqauth" and "confauth" commands and silently ignoring
// anything else. Any trailing partial line is moved to the buffer start.
// Returns false when the buffer is completely full afterwards, i.e. it holds
// a partial line that can no longer grow; true otherwise.
bool checkclientinput(client &c)
{
    if(c.inputpos<0) return true;

    for(;;)
    {
        char *eol = (char *)memchr(c.input, '\n', c.inputpos);
        if(!eol) break;

        // cut the line off so sscanf sees only this command
        *eol = '\0';
        char *rest = eol + 1;

        uint id;
        string arg;
        if(sscanf(c.input, "reqauth %u %100s", &id, arg) == 2) reqauth(c, id, arg);
        else if(sscanf(c.input, "confauth %u %100s", &id, arg) == 2) confauth(c, id, arg);

        // shift whatever follows the consumed line to the front of the buffer
        const int remaining = int(&c.input[c.inputpos] - rest);
        memmove(c.input, rest, remaining);
        c.inputpos = remaining;
    }

    return c.inputpos<(int)sizeof(c.input);
}

// File-scope descriptor sets. checkclients() declares locals with the same
// names that shadow these, so nothing in the visible part of this file uses them.
fd_set readset, writeset;

// One iteration of the server's I/O loop.
// Waits up to one second in select() for activity, accepts at most one new
// connection, then for every client flushes queued output, reads and
// processes new input, and drops clients that errored, disconnected,
// overflowed their input buffer or exceeded OUTPUT_LIMIT.
void checkclients()
{
    fd_set readset, writeset;
    int nfds = serversocket;
    FD_ZERO(&readset);
    FD_ZERO(&writeset);
    FD_SET(serversocket, &readset);
    loopv(clients)
    {
        client &c = *clients[i];
        // a client with unsent output is only polled for writability, so
        // no further input is read from it until the output has drained
        if(c.outputpos < c.output.length()) FD_SET(c.socket, &writeset);
        else FD_SET(c.socket, &readset);
        nfds = max(nfds, c.socket);
    }
    timeval tv;
    tv.tv_sec = 1;
    tv.tv_usec = 0;
    // timeout, error or interruption by a signal: nothing to do this round
    if(select(nfds+1, &readset, &writeset, NULL, &tv)<=0) return;

    curtime = enet_time_get();
    if(FD_ISSET(serversocket, &readset))
    {
        ENetAddress address;
        ENetSocket clientsocket = enet_socket_accept(serversocket, &address);
        if(clientsocket!=ENET_SOCKET_NULL)
        {
            // Refuse the connection when the server is full, or when the
            // descriptor could not be stored in an fd_set: FD_SET on a value
            // of FD_SETSIZE or more is undefined behaviour.
            if(clients.length()>=CLIENT_LIMIT || (int)clientsocket>=(int)FD_SETSIZE)
            {
                enet_socket_destroy(clientsocket);
            }
            else
            {
                // limit connections per host by evicting that host's oldest one
                int dups = 0, oldest = -1;
                loopv(clients) if(clients[i]->address.host == address.host)
                {
                    dups++;
                    if(oldest<0 || clients[i]->connecttime < clients[oldest]->connecttime) oldest = i;
                }
                if(dups >= DUP_LIMIT) purgeclient(oldest);

                client *c = new client;
                c->address = address;
                c->socket = clientsocket;
                c->connecttime = curtime;
                clients.add(c);
            }
        }
    }

    loopv(clients)
    {
        client &c = *clients[i];
        if(c.outputpos < c.output.length() && FD_ISSET(c.socket, &writeset))
        {
            ENetBuffer buf;
            buf.data = (void *)&c.output[c.outputpos];
            buf.dataLength = c.output.length()-c.outputpos;
            int res = enet_socket_send(c.socket, NULL, &buf, 1);
            if(res>=0) 
            {
                c.outputpos += res;
                // everything sent: reset the queue so it can be reused
                if(c.outputpos>=c.output.length())
                {
                    c.output.setsizenodelete(0);
                    c.outputpos = 0;
                }
            }
            else { purgeclient(i--); continue; }
        }
        if(FD_ISSET(c.socket, &readset))
        {
            ENetBuffer buf;
            buf.data = &c.input[c.inputpos];
            buf.dataLength = sizeof(c.input) - c.inputpos;
            int res = enet_socket_receive(c.socket, NULL, &buf, 1);
            if(res>0)
            {
                c.inputpos += res;
                // Terminate the data only when there is room for it. Writing
                // the terminator at the last index of a completely full
                // buffer would overwrite the final received byte, which may
                // be the newline that ends a valid command.
                if(c.inputpos < (int)sizeof(c.input)) c.input[c.inputpos] = '\0';
                if(!checkclientinput(c)) { purgeclient(i--); continue; }
            }
            else { purgeclient(i--); continue; }
        }
        if(c.output.length() > OUTPUT_LIMIT) { purgeclient(i--); continue; }
#if 0
        if(ENET_TIME_DIFFERENCE(curtime, c.connecttime) >= CLIENT_TIME) { purgeclient(i--); continue; }
#endif
    }
}

// Reload request flag, polled by the main loop. It starts out set so that
// the configuration is loaded on the first loop iteration.
// Declared volatile sig_atomic_t because that is the object type the C and
// C++ standards guarantee a signal handler may safely assign to.
volatile sig_atomic_t reloadcfg = 1;

// Signal handler (installed for SIGUSR1 in main): only raises the reload
// flag; the actual reload happens later in the main loop.
void reloadsignal(int signum)
{
    (void)signum;
    reloadcfg = 1;
}

// Directory prefix (including the trailing slash) that main() prepends to
// the "auth.log" and "auth.cfg" file names.
#define CONFIG_DIR "/path/to/config/"

// Entry point.
//   argv[1] - optional listening port (default 28787)
//   argv[2] - optional address to bind to (default: any)
// Opens the log, installs the SIGUSR1 reload handler, starts the listening
// socket and then serves clients forever, re-executing auth.cfg whenever a
// reload has been requested.
int main(int argc, char **argv)
{
    char *ip = NULL;
    int port = 28787;
    if(argc>=2) port = atoi(argv[1]);
    if(argc>=3) ip = argv[2];

    // fall back to stdout when the log file cannot be opened
    logfile = fopen(CONFIG_DIR "auth.log", "a");
    if(!logfile) logfile = stdout;
    setvbuf(logfile, NULL, _IOLBF, 0);
    signal(SIGUSR1, reloadsignal);
    setupserver(port, ip);
    for(;;) 
    {
        if(reloadcfg)
        {
            // Clear the flag before reloading: a SIGUSR1 that arrives while
            // the config is being executed then schedules another reload
            // instead of being wiped out by a late reset.
            reloadcfg = false;
            conoutf("reloading auth.cfg\n");
            exec(CONFIG_DIR "auth.cfg");
        }

        checkclients();
    }

    return EXIT_SUCCESS;
}


