#!/usr/bin/env python2
from __future__ import print_function

import os
import re
import shlex
import signal
import subprocess
import sys
import time

import pexpect

# Path of the student's shell binary that every test runs.
STUDENT_SHELL = "./42sh"

# Global state set by one (or more) tests and read later by the penalty
# groups to subtract points: valgrind_failed/valgrind_output are set by
# run_42sh() and read by check_valgrind(); compiler_warnings is set by
# check_cmd_clangwarnings() and read by check_warnings().
valgrind_failed = None
valgrind_output = None
compiler_warnings = None

# Set per test case by test_groups() (from Test.valgrind) and read by
# run_42sh() to decide whether to start the shell under valgrind.
# NOTE(review): implicit coupling through a global; passing it explicitly
# would be cleaner.
run_with_valgrind = True

# Set by every run_cmd (and by the pexpect-based tests) so the SIGTERM
# handler can report which command was running when it fired.
last_command = ""

# C files added by the student; passed to make as ADDITIONAL_SOURCES during
# compilation.
additional_sources = ""


class TestError(Exception):
    """Signals that a single test case failed.

    test_groups() catches this exception, counts the test as failed and
    prints its first argument as the explanation.
    """


# Helper functions for performing common tests
def eq(a, b, name="Objects"):
    """Raise TestError when ``a != b``.

    Strings with more than three newlines are shown verbatim (in double
    quotes) so multi-line output stays readable; anything else is shown via
    repr().

    Args:
      a: actual value.
      b: expected value.
      name: label for the compared objects, used in the error message.
    """
    if a != b:
        long_text = isinstance(a, str) and a.count('\n') > 3
        if long_text:
            template = "%s not equal: \"%s\" and \"%s\""
            shown = (a, b)
        else:
            template = "%s not equal: %s and %s"
            shown = (repr(a), repr(b))
        raise TestError(template % ((name,) + shown))


# Test case definition
class Test():
    """A single named test case.

    Attributes:
      name: label printed in the report.
      func: zero-argument callable; raising TestError marks the test failed.
      valgrind: copied into run_with_valgrind by test_groups() before func
        runs, so run_42sh() starts the shell under valgrind for this test.
    """

    def __init__(self, name, func, valgrind=False):
        self.name = name
        self.func = func
        self.valgrind = valgrind


# Collection of testcases worth n points (i.e. one item in the grading scheme)
class TestGroup():
    """A set of tests that together form one item of the grading scheme.

    Args:
      name: label printed in the report.
      points: weight of the group (converted to float); a negative value
        makes it a penalty group (see test_groups()).
      *tests: the Test instances belonging to this group, kept as a tuple.
      stop_if_fail: keyword-only flag (default False); when set,
        test_groups() stops at the first failing test and skips all later
        groups.
    """

    def __init__(self, name, points, *tests, **kwargs):
        self.name, self.points, self.tests = name, float(points), tests
        self.stop_if_fail = kwargs.get("stop_if_fail", False)


def test_groups(groups):
    """Run each TestGroup in *groups* in order and return the points earned.

    Every test's func is called. A TestError marks the test as failed and its
    message is printed; any other exception propagates to the caller.

    Scoring per group:
      * points > 0: points * (fraction of tests that passed).
      * points < 0 (a penalty): points * (fraction of tests that failed), so
        a fully passing penalty group subtracts nothing.
    A group with stop_if_fail set stops at its first failing test and, if any
    of its tests failed, no further groups are run.

    Returns:
      float: the sum of the per-group scores.
    """
    global run_with_valgrind
    points = 0.0
    for group in groups:
        print('= ' + group.name)
        ntests = len(group.tests)
        succeeded = 0
        for test in group.tests:
            # run_42sh() reads this module-level flag to decide whether the
            # shell is started under valgrind for this test.
            run_with_valgrind = test.valgrind
            print('\t' + test.name, end=': ')
            try:
                test.func()
            except TestError as e:
                print("FAIL")
                # Do not crash with IndexError on a TestError without message.
                print(e.args[0] if e.args else repr(e))
                if group.stop_if_fail:
                    break
            else:
                print("OK")
                succeeded += 1
        # Fraction of tests passed. An empty group has no failures, so it is
        # scored as fully passed (consistent with the stop_if_fail check
        # below) instead of dividing by zero.
        perc = float(succeeded) / ntests if ntests else 1.0
        if group.points < 0:
            perc = 1 - perc
        grouppoints = group.points * perc
        if group.points > 0:
            print(" Passed %d/%d tests, %.2f/%.2f points" %
                  (succeeded, ntests, grouppoints, group.points))
        elif perc > 0:
            print(" Failed, subtracting %.2f points" % abs(grouppoints))
        else:
            print(" Passed")
        points += grouppoints
        if group.stop_if_fail and succeeded != ntests:
            break
    return points


def run():
    """Build the grading scheme, run it and print the final score.

    The basic groups always run. The advanced groups only run when the basic
    groups earned at least ``min_basic_points``. The reported total is the
    sum of the positive group weights that were attempted; penalty groups
    (negative points) never add to it.
    """
    # Score on the basic groups needed before the advanced groups are run.
    min_basic_points = 5.0

    basic_tests = [
        TestGroup(
            "Compilation",
            0.5,
            Test("Clean",
                 check_cmd("make moreclean ADDITIONAL_SOURCES=\"%s\"" %
                           additional_sources)),
            Test("Make", check_cmd_clangwarnings),
            Test("Has valid binary", check_cmd("./42sh -h")),
            stop_if_fail=True),
        TestGroup(
            "Simple commands",
            1,
            Test("Simple", bash_cmp("pwd"), valgrind=True),
            Test("Arguments", bash_cmp("ls -alh ..")),
            Test("Wait for 1 proc", test_wait("sleep 0.2", 0.2)),
            stop_if_fail=True),
        TestGroup(
            "exit builtin",
            1,
            Test("exit", test_exit), ),
        TestGroup(
            "cd builtin",
            1,
            Test("cd", test_cd), ),
        TestGroup(
            "Sequences",
            1,
            Test("Simple sequence", bash_cmp("pwd; uname"), valgrind=True),
            Test("Nested sequences", bash_cmp("pwd; cd /; pwd")),
            Test("Wait for sequence", test_wait("sleep 0.1; sleep 0.2", 0.3)),
        ),
        TestGroup(
            "Pipes",
            1,
            Test("One pipe", bash_cmp("ls -alh / | grep lib"), valgrind=True),
            Test("cd in pipe", bash_cmp("cd / | pwd")),
            Test("exit in pipe", bash_cmp("exit 1 | pwd")),
            Test("Wait for pipes 1", test_wait("sleep 0.2 | sleep 0.1", 0.2)),
            Test("Wait for pipes 2", test_wait("sleep 0.1 | sleep 0.2", 0.2)),
        ),
        TestGroup(
            "Valgrind memcheck (basic)",
            -1,
            Test("valgrind", check_valgrind), ),
        TestGroup(
            "Compiler warnings",
            -1,
            Test("clang", check_warnings), ),
        TestGroup(
            "Errors",
            -1.0,
            Test("Binary not in path", test_errors), ),
    ]

    advanced_tests = [
        TestGroup(
            "Pipes >2p with seqs or pipes",
            1.0,
            Test("Multiple pipes",
                 bash_cmp("ls -alh / | grep lib | grep -v 32 | tac")),
            Test("Seq in pipe", bash_cmp("ls | { grep c ; ls .. ; } | tac")),
            Test("Seq wait 1",
                 test_wait("{ sleep 0.1 | sleep 0.2; }; exit 1", 0.2)),
            Test("Seq wait 2",
                 test_wait("{ sleep 0.2 | sleep 0.1; }; exit 1", 0.2)),
            Test("Seq wait 3",
                 test_wait("{ sleep 0.2 | sleep 0.3 | sleep 0.1; }; exit 1",
                           0.3)),
            # TODO? (also check with ps?)
        ),
        TestGroup(
            "Redirections",
            1.0,
            Test("To/from file", bash_cmp(">a ls /bin; <a wc -l")),
            Test("Overwrite", bash_cmp(">a ls /bin; >a ls; cat a")),
            Test("Append", bash_cmp(">a ls; >>a pwd; cat a")),
            Test("Errors", bash_cmp(">a 2>&1 du /etc/.", check_rv=False)),
            Test("Errors to out",
                 bash_cmp("2>&1 >/dev/null find /etc/.", check_rv=False)), ),
        TestGroup(
            "Detached commands",
            0.5,
            Test("sleep", test_detach),
            Test(
                "Sequence",
                bash_cmp("{ sleep 0.1; echo hello; }& echo world; sleep 0.3")),
        ),
        TestGroup(
            "Subshells",
            0.5,
            Test("exit", bash_cmp("(pwd; exit 2); exit 1")),
            Test("cd", bash_cmp("cd /bin; pwd; (cd /; pwd); pwd")), ),
        TestGroup(
            "Environment variables",
            1.0,
            Test("Simple",
                 manual_cmp(
                     "set hello=world; env | grep hello",
                     out="hello=world\n",
                     err="")),
            Test("Subshell",
                 manual_cmp(
                     "set hello=world; (set hello=bye); env | grep hello",
                     out="hello=world\n",
                     err="")),
            # The first `env | grep` is a marker proving the variable was
            # set (its line is the only expected stdout); the second must
            # print nothing once `unset` has removed the variable.
            # NOTE(review): bash would exit 1 here because the final grep
            # matches nothing; confirm that rv=0 is really intended.
            Test("Unset",
                 manual_cmp(
                     "set hoi=daar; env | grep ^hoi; " +
                     "unset hoi; env | grep ^hoi",
                     out="hoi=daar\n",
                     err="",
                     rv=0)), ),
        TestGroup(
            "Job control",
            2.0,
            Test("Ctrl-z", test_ctrl_z),
            Test("Detach + fg", test_detach_fg), ),
    ]

    points = test_groups(basic_tests)
    totalpoints = sum(g.points for g in basic_tests if g.points > 0)

    if points >= min_basic_points:
        print("Passed basic tests with enough points, doing advanced tests")
        points += test_groups(advanced_tests)
        totalpoints += sum(g.points for g in advanced_tests if g.points > 0)
    else:
        print("Didnt get enough points for the basic tests, aborting.")
        print("Got %.2f points, need at least %.1f" %
              (points, min_basic_points))

    print()
    print("Executed all tests, got %.2f/%.2f points in total" % (points,
                                                                 totalpoints))


def run_cmd(cmd, shell=None, prefix=None):
    """Start ``[*prefix, shell, "-c", cmd]`` with all three streams piped.

    Records *cmd* in the global last_command so the SIGTERM handler can
    report it.

    Args:
      cmd: command line handed to the shell via -c.
      shell: shell binary to use; falls back to STUDENT_SHELL when falsy.
      prefix: list of argv items placed before the shell (e.g. a valgrind
        invocation); defaults to no prefix.

    Returns:
      The started subprocess.Popen object (not yet waited for).
    """
    global last_command
    last_command = cmd
    if not shell:
        shell = STUDENT_SHELL
    argv = (prefix or []) + [shell, "-c", cmd]
    return subprocess.Popen(argv,
                            stdin=subprocess.PIPE,
                            stdout=subprocess.PIPE,
                            stderr=subprocess.PIPE)


def _valgrind_error_total(log):
    """Return the summed error count of every valgrind ERROR SUMMARY in *log*.

    Lines look like ``==PID== ERROR SUMMARY: N errors from M contexts ...``.
    Thousands separators in N are tolerated.
    """
    counts = re.findall(r"ERROR SUMMARY: ([\d,]+) errors", log)
    return sum(int(c.replace(',', '')) for c in counts)


def run_42sh(cmd):
    """Run *cmd* with the student shell and return (returncode, out, err).

    When run_with_valgrind is set and no valgrind failure has been recorded
    yet, the shell runs under valgrind with its log in _valgrind.out. If that
    log reports any errors, the failing command and the full log are stored
    in the globals valgrind_failed / valgrind_output for check_valgrind().
    """
    global valgrind_failed, valgrind_output
    if not run_with_valgrind or valgrind_failed is not None:
        proc = run_cmd(cmd)
        out, err = proc.communicate()
        return proc.returncode, out, err

    proc = run_cmd(cmd, prefix=["valgrind", "--log-file=_valgrind.out"])
    out, err = proc.communicate()

    with open("_valgrind.out") as f:
        log = f.read()
    # The log file name has no %p, so (per the Valgrind manual) forked
    # children that do not exec write into the same file. Each such process
    # emits its own ERROR SUMMARY, and the parent shell's summary comes last.
    # Every summary must be considered, not just the first one.
    if _valgrind_error_total(log):
        valgrind_failed = cmd
        valgrind_output = log

    return proc.returncode, out, err


def check_valgrind():
    """Fail if any valgrind-enabled run reported memory errors.

    If no failure has been recorded yet, run one more command under valgrind
    first. It mixes pipes, subshells, groups, redirections, a background job
    and the cd/exit builtins.
    """
    global run_with_valgrind
    if valgrind_failed is None:
        run_with_valgrind = True
        run_42sh("pwd | (ls -alh /) | (pwd|pwd) | {pwd;pwd} | "
                 "{{pwd|pwd};pwd} | 2>&1 >/dev/null ls . fdsfsa | >a pwd | "
                 ">>b pwd | (pwd &) | cd / | exit 31")
    # Re-read the global: run_42sh() may have just recorded a failure.
    if valgrind_failed is None:
        return
    raise TestError("Valgrind failed for command %s:\n%s" %
                    (valgrind_failed, valgrind_output))


def check_warnings():
    """Fail when the Make step recorded compiler warnings."""
    if compiler_warnings is None:
        return
    raise TestError("Got compiler warnings:\n%s" % compiler_warnings)


def check_cmd_clangwarnings():
    """Build the shell with make and remember any compiler warnings.

    A failing make raises TestError (via check_cmd). Lines starting with
    "make:" are dropped from stderr; if the rest mentions "warning", it is
    stored in the global compiler_warnings for check_warnings().
    """
    global compiler_warnings
    build = check_cmd("make ADDITIONAL_SOURCES=\"%s\"" % additional_sources)
    _, stderr = build()
    kept = [ln for ln in stderr.split("\n") if not ln.startswith("make:")]
    filtered = '\n'.join(kept)
    if "warning" in filtered:
        compiler_warnings = filtered


def check_authors():
    """Check that ./AUTHORS lists at least one name and one student number.

    Each line is split on the standalone words "and"/"en". In every chunk, a
    name is a run of two or more words (words contain no whitespace, digits
    or any of ``&:+,``), and a student number is any standalone run of
    digits. The findings are printed.

    Raises:
      TestError: if the file is missing or no names / numbers were found.
    """
    if not os.path.isfile('AUTHORS'):
        raise TestError("AUTHORS file not found")
    with open("AUTHORS") as handle:
        lines = handle.read().split('\n')

    word = r"[^&:+,\s\d]+"
    name_re = r"(%s\s+(?:%s\s+)*%s)" % (word, word, word)
    names = []
    nums = []
    for line in lines:
        for chunk in re.split(r"\s*\b(?:and|en)\b\s*", line):
            names.extend(re.findall(name_re, chunk, re.U))
            nums.extend(re.findall(r"\b\d+\b", chunk, re.U))

    print()
    print("   Names: %s" % ' + '.join(names))
    print("   Nums: %s" % ' + '.join(nums))
    if not names:
        raise TestError("No names found")
    if not nums:
        raise TestError("No student numbers found")


def check_cmd(cmd):
    """Return a zero-argument callable that runs *cmd* and requires success.

    *cmd* is tokenized with shlex.split (no shell is involved). The callable
    returns (stdout, stderr) on exit status 0 and otherwise raises TestError
    with the command, the exit status and both outputs.
    """
    def check_cmd_inner():
        proc = subprocess.Popen(shlex.split(cmd),
                                stdin=subprocess.PIPE,
                                stdout=subprocess.PIPE,
                                stderr=subprocess.PIPE)
        out, err = proc.communicate()
        if not proc.returncode:
            return out, err
        details = ("Command: %s\nReturn code: %d\nstdout: %s\nstderr: %s" %
                   (cmd, proc.returncode, out, err))
        raise TestError("Command returned non-zero value.\n" + details)

    return check_cmd_inner


def bash_cmp(cmd, check_rv=True):
    """Build a test comparing the student shell against bash for *cmd*.

    The returned callable runs *cmd* with the student shell first (possibly
    under valgrind, via run_42sh) and then with bash. It requires identical
    stdout and stderr and, unless check_rv is False, the same exit status.
    Mismatches raise TestError mentioning the command.
    """
    def bash_cmp_inner():
        mine_rv, mine_out, mine_err = run_42sh(cmd)
        ref = run_cmd(cmd, shell="bash")
        ref_out, ref_err = ref.communicate()

        try:
            eq(mine_out, ref_out, "stdout")
            eq(mine_err, ref_err, "stderr")
            if check_rv:
                eq(mine_rv, ref.returncode, "return value")
        except TestError as mismatch:
            detail = "bash.\nCommand: %s\n%s" % (cmd, mismatch.args[0])
            raise TestError("Error while comparing your shell output to " +
                            detail)

    return bash_cmp_inner


def manual_cmp(cmd, out=None, err=None, rv=None):
    """Build a test checking the student shell's result for *cmd*.

    The returned callable runs *cmd* via run_42sh and compares stdout, stderr
    and the exit status with out, err and rv, in that order. An expectation
    left as None is not checked. Mismatches raise TestError mentioning the
    command.
    """
    def manual_cmp_inner():
        got_rv, got_out, got_err = run_42sh(cmd)
        checks = ((out, got_out, "stdout"),
                  (err, got_err, "stderr"),
                  (rv, got_rv, "return value"))

        try:
            for expected, actual, label in checks:
                if expected is not None:
                    eq(actual, expected, label)
        except TestError as mismatch:
            detail = ("expected output.\nCommand: %s\n%s" %
                      (cmd, mismatch.args[0]))
            raise TestError("Error while comparing your shell output to " +
                            detail)

    return manual_cmp_inner


def test_wait(cmd, timeout, out='', err='', tolerance=0.03):
    """Build a test that checks output and duration of *cmd* in the shell.

    The returned callable runs ``STUDENT_SHELL -c cmd`` through run_cmd
    directly, so valgrind is never used here. It checks stdout/stderr against
    *out*/*err* and requires the wall-clock duration to lie strictly within
    ``timeout +/- tolerance`` seconds. Finishing too early and too late both
    fail the test.

    Args:
      cmd: command line passed to the shell with -c.
      timeout: expected duration in seconds.
      out: expected stdout.
      err: expected stderr.
      tolerance: allowed deviation from *timeout* in seconds (default 0.03,
        the margin that used to be hard-coded).
    """
    def wait():
        start_time = time.time()
        stdout, stderr = run_cmd(cmd).communicate()
        elapsed = time.time() - start_time

        try:
            eq(stdout, out, "stdout")
            eq(stderr, err, "stderr")
        except TestError as e:
            # Like bash_cmp/manual_cmp, always say which command failed.
            raise TestError("Unexpected output from timed command.\n" +
                            "Command: %s\n%s" % (cmd, e.args[0]))

        if not (timeout - tolerance < elapsed < timeout + tolerance):
            raise TestError("Command did not finish in expected time.\n" +
                            "Command: %s\nExpected time: %f\nTime taken: %f" %
                            (cmd, timeout, elapsed))

    return wait


def test_exit():
    """Check that ``exit 42`` terminates the interactive shell with status 42.

    The shell's prompt must contain a '$': it is what pexpect waits for.
    The spawned shell is always closed, including when the test fails.
    """
    global last_command
    last_command = "exit 42 ... make sure your shell prompt contains a '$'"

    p = pexpect.spawn(STUDENT_SHELL)
    try:
        p.expect(r'\$')
        p.sendline("exit 42")
        try:
            p.expect(r'\$')
        except pexpect.EOF:
            # The shell went away instead of printing a new prompt: collect
            # its exit status (wait() may complain if it was already reaped).
            try:
                p.wait()
            except pexpect.ExceptionPexpect:
                pass
            eq(p.exitstatus, 42, "exit status")
            return

        raise TestError("Shell did not exit on 'exit' command.")
    finally:
        # Without this, a shell that ignores 'exit' would be left running.
        p.close()


def test_detach():
    """Check that ``sleep 1 &`` returns to the prompt immediately.

    The prompt must come back within 0.1 s, a following ``ps`` must also
    return within 0.1 s, and its output must list the detached sleep. The
    spawned shell is always closed afterwards.
    """
    global last_command
    last_command = "sleep 1 &"

    p = pexpect.spawn(STUDENT_SHELL)
    try:
        p.expect(r'\$')
        try:
            p.sendline("sleep 1 &")
            stime = time.time()
            p.expect(r'\$')
            if time.time() - stime > 0.1:
                raise TestError("Detach did not return immediately.\n" +
                                "Command: sleep 1 &")
            p.sendline("ps")
            stime = time.time()
            p.expect(r'\$')
            if time.time() - stime > 0.1:
                raise TestError(
                    "Command after detach did not return immediately.\n" +
                    "Command: sleep 1 & \\n ps")
            # p.before holds everything printed before the matched prompt,
            # i.e. the ps output.
            if "sleep" not in p.before:
                raise TestError("Detached command not in ps output.\n" +
                                "Command: sleep 1 &")
        except pexpect.EOF:
            # Report a dying shell as a test failure (as test_ctrl_z does)
            # instead of aborting the whole run.
            raise TestError("Shell exited unexpectedly.\n" +
                            "Command: sleep 1 &")
    finally:
        p.close()


def test_cd():
    """Send ``cd /tmp`` and ``pwd`` on stdin and check that pwd prints /tmp.

    The expected stdout also contains both input lines, because the shell's
    readline echoes the input back.
    """
    global last_command
    last_command = "cd /tmp; pwd"

    script = "cd /tmp\npwd\n"
    shell = subprocess.Popen([STUDENT_SHELL],
                             stdin=subprocess.PIPE,
                             stdout=subprocess.PIPE,
                             stderr=subprocess.PIPE)
    got_out, got_err = shell.communicate(script)
    eq(got_out, "cd /tmp\npwd\n/tmp\n", "stdout")  # readline echo's input back
    eq(got_err, "", "stderr")


def sleep_status(p):
    """Run ``ps t`` in the spawned shell and return the sleep's status field.

    Returns the third whitespace-separated field of the first output line
    containing 'sleep'. This is presumably the STAT column of procps'
    ``ps t``; callers compare it with 'T' and 'S'.

    Raises:
      TestError: if no line of the ps output mentions 'sleep'.
    """
    p.sendline("ps t")
    p.expect(r'\$')

    stats = (row.split()[2] for row in p.before.split('\n') if 'sleep' in row)
    status = next(stats, None)
    if status is None:
        raise TestError("Sleep not found in background.")
    return status


def test_ctrl_z():
    """Check that Ctrl-Z stops a foreground ``sleep 2`` and returns the prompt.

    The prompt must come back within 0.1 s and sleep_status() must report
    'T' (stopped). If the shell itself dies from SIGTSTP, the test fails.
    The spawned shell is always closed, also when the test fails.
    """
    global last_command
    last_command = "sleep 2"

    p = pexpect.spawn(STUDENT_SHELL)
    try:
        p.expect(r'\$')

        try:
            p.sendline("sleep 2")
            stime = time.time()
            p.sendcontrol('z')
            p.expect(r'\$')
            if time.time() - stime > 0.1:
                raise TestError("Sleep was not stopped by SIGTSTP in time.")

            if sleep_status(p) != 'T':
                raise TestError("Sleep found in background, but not stopped.")
        except pexpect.EOF:
            raise TestError("Shell exited due to SIGTSTP (ctrl-z)")
    finally:
        # Previously only reached on success, leaking the shell on failure.
        p.close()


def test_detach_fg():
    """Check that ``fg`` waits for a background job to finish.

    Runs ``sleep 0.5 &``, requires sleep_status() to report exactly 'S',
    then sends ``fg`` and requires the prompt to return no earlier than
    0.45 s after the job was started. The spawned shell is always closed
    afterwards.
    """
    global last_command
    last_command = "sleep 0.5 &"

    p = pexpect.spawn(STUDENT_SHELL)
    try:
        p.expect(r'\$')

        try:
            p.sendline("sleep 0.5 &")
            stime = time.time()
            p.expect(r'\$')
            if sleep_status(p) != 'S':
                raise TestError("Sleep not running in background.")

            p.sendline("fg")
            p.expect(r'\$')
            if time.time() - stime < 0.45:
                raise TestError(
                    "Sleep did not finish properly with a wait after fg.")
        except pexpect.EOF:
            # Report a dying shell as a test failure (as test_ctrl_z does)
            # instead of aborting the whole run.
            raise TestError("Shell exited unexpectedly.\n" +
                            "Command: sleep 0.5 & \\n fg")
    finally:
        p.close()

def test_errors():
    """Check the error report for a command that cannot be found.

    Running an unknown binary must print nothing on stdout and must mention
    "No such file or directory" on stderr. The exit status is not checked.
    """
    _, stdout, stderr = run_42sh("blablabla")
    eq(stdout, "", "stdout")
    needle = "No such file or directory"
    if needle in stderr:
        return
    raise TestError("String \"No such file or directory\" not found in " +
                    "stderr: use perror if execvp fails.")


def do_additional_params(lst, name, suffix=''):
    """Validate file names and extract each one from archive.tar.gz.

    Each entry must be a plain, non-empty file name ending in *suffix*. It
    must not contain quotes, backslashes, whitespace, '/' or '$', must not
    start with '-', and must not name a file that already exists in the
    working directory. Valid names are extracted with
    ``tar xvf archive.tar.gz <name>``.

    Args:
      lst: iterable of file names. Presumably supplied by the student's
        submission; treat as untrusted.
      name: label of the parameter, used in error messages.
      suffix: required file-name suffix; '' accepts any name.

    Raises:
      TestError: on the first invalid name, or if tar fails.
    """
    for f in lst:
        if not f:
            # An empty member name would turn the tar command below into
            # "extract the whole archive".
            raise TestError("Empty file name in %s" % name)
        if not f.endswith(suffix):
            raise TestError("File does not end with %s in %s: '%s'" %
                            (suffix, name, f))
        # check_cmd tokenizes with shlex.split. Quotes, backslashes and
        # whitespace would change how the name is split (or make shlex raise
        # ValueError), so reject them all.
        if '"' in f or "'" in f:
            raise TestError("No quotes allowed in %s: '%s'" % (name, f))
        if '\\' in f:
            raise TestError("No backslashes allowed in %s: '%s'" % (name, f))
        if any(c.isspace() for c in f):
            raise TestError("No whitespace allowed in %s: '%s'" % (name, f))
        if '/' in f:
            raise TestError("No slashes allowed in %s: '%s'" % (name, f))
        if '$' in f:
            raise TestError("No $ allowed in %s: '%s'" % (name, f))
        if f.startswith('-'):
            raise TestError("No flags allowed in %s: '%s'" % (name, f))
        if os.path.isfile(f):
            raise TestError("Trying to overwrite framework file in %s: '%s'" %
                            (name, f))
        check_cmd("tar xvf archive.tar.gz %s" % f)()


def handle_sigterm(signum, frame):
    """SIGTERM handler: abort the run with an exception naming last_command.

    signum and frame are the standard signal-handler arguments; both are
    unused.
    """
    message = "SIGTERM while executing command:\n\"%s\"" % last_command
    raise Exception(message)


if __name__ == '__main__':
    # Turn SIGTERM (presumably sent by an outer timeout) into an exception
    # that names the command that was running.
    signal.signal(signal.SIGTERM, handle_sigterm)
    try:
        run()
    except Exception as e:
        # Not every exception carries a message in args[0]. Fall back to its
        # repr instead of raising IndexError inside this handler.
        detail = e.args[0] if e.args else repr(e)
        print("\n\nTester got exception: %s" % (detail,))
