from textwrap import wrap as split_chunks
from pprint import pprint
from collections import Counter, defaultdict
from operator import itemgetter
import itertools
import statistics
import string
import sys

# Byte values scored by the frequency analysis: lowercase a-z followed by space.
# (Presumably "ALS" = "alphabet + space"; the list order defines the order of
# ETEXT_FREQUENCY_DATA below.)
ALS_BYTES = [ord(x) for x in string.ascii_lowercase + ' ']

# Relative frequencies of a-z and space in English text, in the same order as
# ALS_BYTES. Stored as a tuple (not a bare `map` iterator) so the constant can
# be read more than once; a `map` would be exhausted by the zip below.
ETEXT_FREQUENCY_DATA = tuple(map(float, (
    "0.0651738 0.0124248 0.0217339 0.0349835 0.1041442 0.0197881 0.0158610 "
    "0.0492888 0.0558094 0.0009033 0.0050529 0.0331490 0.0202124 0.0564513 "
    "0.0596302 0.0137645 0.0008606 0.0497563 0.0515760 0.0729357 0.0225134 "
    "0.0082903 0.0171272 0.0013692 0.0145984 0.0007836 0.1918182"
).split()))

# Byte value -> expected English frequency. Any byte outside ALS_BYTES maps to
# 0.0 via the defaultdict.
ETEXT_FREQUENCY_MAP = defaultdict(float, zip(ALS_BYTES, ETEXT_FREQUENCY_DATA))

def main():
    """
    Break a repeating-key XOR ciphertext named on the command line.

    The first line of the file must contain the ciphertext as a hex string.
    Prints the per-key-length statistics, the guessed key length, each
    guessed key byte, and finally the decrypted text.

    Encrypt.py indicates that key length varies between 1 and 13.
    Plan of attack:
     1. Generate streams for key lengths between 1 and 13,
     2. Do frequency analysis (compute sum(q_i^2) for i = 0..255)
        If key length is wrong, that should be around 1/256,
        if it's right, much larger.
     3. Take largest sum as key length
     4. Generate possible decryptions for each stream
     5. Evaluate those on criteria like valid characters, frequency
     6. Determine key based on evaluation
     7. Decrypt!
    """
    if len(sys.argv) < 2:
        print(f"Usage: {sys.argv[0]} <ciphertext filename>")
        sys.exit(1)

    # Read ciphertext (only the first line of the file is used).
    fname = sys.argv[1]
    with open(fname, 'r') as f:
        data = f.readline().strip()
    if not data:
        print(f"No ciphertext found on the first line of {fname}")
        sys.exit(1)

    # bytes.fromhex rejects odd-length or non-hex input instead of silently
    # decoding a dangling nibble as if it were a whole byte.
    try:
        decoded_data = bytes.fromhex(data)
    except ValueError as err:
        print(f"Ciphertext is not valid hex: {err}")
        sys.exit(1)

    key_length = _guess_key_length(decoded_data)
    print(f"Likely key length: {key_length}")

    key = []
    for stream in streams(decoded_data, key_length):
        probable_key_byte = _guess_key_byte(stream)
        print(probable_key_byte)
        key.append(probable_key_byte)

    print(decrypt(decoded_data, key))


def _guess_key_length(ciphertext, max_key_length=13):
    """
    Return the key length in 1..max_key_length whose streams have the highest
    mean sum of squared byte frequencies. The statistic for every length is
    printed first.
    """
    avg_freqs = [
        statistics.mean(map(analyze_freq, streams(ciphertext, length)))
        for length in range(1, max_key_length + 1)
    ]

    for i, a in enumerate(avg_freqs):
        print(f"Avg square frequency for key length {i+1}: {a}")

    # max() keeps the first maximum, so ties resolve to the shortest length.
    return max(enumerate(avg_freqs), key=itemgetter(1))[0] + 1


def _guess_key_byte(stream):
    """
    Return the key byte whose decryption of `stream` scores best under
    frequency_check, considering only candidates that pass valid_characters
    when at least one does.
    """
    scored = [frequency_check(pd)
              for pd in possible_stream_decryptions(stream)
              if valid_characters(pd)]
    if not scored:
        # The character filter rejected every key byte (for example, the
        # plaintext contains a newline). Score every candidate rather than
        # crashing on max() of an empty list.
        scored = [frequency_check(pd)
                  for pd in possible_stream_decryptions(stream)]
    return max(scored, key=itemgetter(1))[0]

def streams(ciphertext, key_length):
    """
    Split `ciphertext` into `key_length` interleaved streams.

    Stream number i holds every element whose position is congruent to i
    modulo `key_length`. Under a repeating key, those are the elements that
    were combined with key byte i. For example, the string 'kaassaus' split
    with a key length of 2 gives

    ["kasu", "asas"].
    """
    result = []
    for offset in range(key_length):
        result.append(ciphertext[offset::key_length])
    return result

def analyze_freq(stream):
    """
    Return the sum of squared relative frequencies of the elements in `stream`.

    Each distinct element's count is divided by the stream length before
    squaring. An empty stream yields 0.
    """
    size = len(stream)
    squared = [(count / size) ** 2 for count in Counter(stream).values()]
    return sum(squared)

def valid_characters(candidate):
    """
    Return True if a candidate plaintext looks like plain English text.

    `candidate` is a (key byte, plaintext bytes) pair; only the plaintext is
    inspected.

    Plaintexts with non-printable characters can be discarded, as can
    plaintexts with characters that don't usually appear in English text.
    Their presence is not impossible, though, so if decryption fails you
    might want to tweak them.
    """
    plaintext = candidate[1]

    # Printable ASCII is 32..126. 127 is DEL, a control character, so it is
    # excluded too. Note this also rejects newlines and tabs.
    if not all(32 <= x <= 126 for x in plaintext):
        return False

    # Other characters that are very unlikely to be in English text.
    # If the decryption looks like garbage you can try removing or changing
    # some of these.
    bad_bytes = {ord(c) for c in '=}{][><-#'}
    return not any(x in bad_bytes for x in plaintext)

def frequency_check(candidate):
    """
    Score a (key byte, plaintext) candidate against English letter frequencies.

    The score is sum(q_i * p_i) over the bytes in ALS_BYTES. Here p_i is the
    known English frequency from ETEXT_FREQUENCY_MAP, and q_i is byte i's
    share among only those plaintext bytes that belong to ALS_BYTES.
    Returns (key byte, score).
    """
    key_byte = candidate[0]
    plaintext = candidate[1]

    tracked_counts = Counter(b for b in plaintext if b in ALS_BYTES)
    tracked_total = sum(tracked_counts.values())

    # Bytes absent from the plaintext contribute a share of 0.0; this also
    # avoids dividing by zero when no tracked bytes are present at all.
    score = sum(
        (tracked_counts[b] / tracked_total if b in tracked_counts else 0.0)
        * ETEXT_FREQUENCY_MAP[b]
        for b in ALS_BYTES
    )
    return (key_byte, score)


def possible_stream_decryptions(stream):
    """
    Lazily enumerate every single-byte XOR decryption of `stream`.

    Yields 256 (key byte, plaintext) pairs, for key bytes 0..255 in order.
    Each plaintext is a list of ints.
    """
    return (
        (key_byte, [b ^ key_byte for b in stream])
        for key_byte in range(256)
    )

def decrypt(ciphertext, key):
    """
    XOR `ciphertext` with `key`, repeating the key as needed.

    Each resulting byte value is mapped through chr() and the characters are
    joined into a string. An empty key yields an empty string.
    """
    keystream = itertools.cycle(key)
    return ''.join(chr(c ^ k) for c, k in zip(ciphertext, keystream))


# Run the attack only when executed as a script, not when imported.
if __name__ == "__main__":
    main()