import numpy as np

from pyics import Model
import random


def decimal_to_base_k(n, k):
    """Convert a non-negative decimal (base-10) integer to its base-k digits.

    The digits are returned most significant first. For example, for n=34
    and k=3 this function returns [1, 0, 2, 1]. Zero is returned as [0].

    Raises ValueError if k < 2, if n is negative, or if n is not integral
    (integral floats such as 34.0 are accepted).
    """
    if k < 2:
        raise ValueError("base k must be at least 2, got %r" % (k,))
    value = int(n)
    if value != n:
        raise ValueError("n must be an integer, got %r" % (n,))
    if value < 0:
        raise ValueError("n must be non-negative, got %r" % (n,))

    digits = []
    # Peel off the least significant digit each iteration; the do-while shape
    # makes n == 0 produce the single digit [0].
    while True:
        value, digit = divmod(value, k)
        digits.append(digit)
        if value == 0:
            break
    digits.reverse()
    return digits


def base_k_to_decimal(n, k):
    """Convert a sequence of base-k digits (most significant first) to an int.

    For example, for n=[1, 0, 2, 1] and k=3 this returns 34; an empty
    sequence gives 0. Digits may be ints or integral floats (such as a row of
    a float numpy array). Each digit is converted to int before use so the
    result is computed with exact integer arithmetic; summing floats loses
    precision once k**len(n) exceeds 2**53, which would make distinct digit
    sequences map to the same number.
    """
    num = 0
    # Horner's scheme: shift the accumulated value one base-k place, add digit.
    for digit in n:
        num = num * k + int(digit)
    return int(num)

def randint_neq(x, y, n):
    """Return a random int between x and y (inclusive) that is not equal to n.

    Values are drawn with random.randint and rejected while they equal n.
    Raises ValueError when x == y == n, because then no valid value exists
    and rejection sampling would never terminate. (random.randint itself
    raises ValueError when x > y.)
    """
    if x == y == n:
        raise ValueError(
            "no integer in [%r, %r] differs from %r" % (x, y, n))
    while True:
        r = random.randint(x, y)
        if r != n:
            return r


class CASim(Model):
    """One-dimensional cellular automaton whose rule table is generated from
    Langton's lambda parameter (spelled `labda` here).

    `config` holds one row per time step: row 0 is the initial state and row
    t is computed from row t - 1 by looking up each cell's neighbourhood of
    2r + 1 cells (wrapping around at the row edges) in `rule_set`.
    """

    def __init__(self):
        Model.__init__(self)

        self.t = 0
        self.rule_set = []
        self.config = None

        # Encoded row -> value of t when that row was recorded (see
        # check_state); used to detect repeated configurations.
        self.prev_states = {}

        self.make_param('r', 1)
        self.make_param('k', 2)
        self.make_param('width', 50)
        self.make_param('height', 50)

        # The quiescent state is drawn at random (and again on every reset).
        self.state_q = random.randint(0, self.k - 1)
        # Number of distinct neighbourhoods, i.e. entries in the rule table.
        self.rule_length = self.k**(2*self.r + 1)

        self.make_param('labda', 0.0, setter=self.set_labda)
        self.make_param('build_random_table', 1)
        #self.make_param('rule', 30, setter=self.setter_rule)

        # Make sure n_state_q is defined even if the default labda value is
        # never passed through the setter.
        self.set_labda(self.labda)

    def set_labda(self, man_l):
        """Setter for the labda parameter.

        Rounds `man_l` to the nearest achievable value i / L, where L is the
        number of rule-table entries and i = 0..L, stores the matching number
        of quiescent entries (L - i) in `self.n_state_q` and returns the
        rounded value.
        """
        # Derive L from the current k and r: self.rule_length is only
        # refreshed in reset(), so it may be stale if k or r just changed.
        rule_length = self.k ** (2 * self.r + 1)
        pos_l = np.arange(rule_length + 1) / rule_length
        index = int(np.abs(pos_l - man_l).argmin())
        # Keep an exact integer count: rule_length * (1 - l) in floating
        # point can land just below an integer and be truncated by int().
        self.n_state_q = rule_length - index
        return pos_l[index]

    def setter_rule(self, val):
        """Setter for the rule parameter, clipping its value between 0 and the
        maximum possible rule number."""
        rule_set_size = self.k ** (2 * self.r + 1)
        max_rule_number = self.k ** rule_set_size
        return max(0, min(val, max_rule_number - 1))

    def build_rule_set(self):
        """Set `self.rule_set` for the current labda.

        A rule set is a list with the new state for every old configuration.
        With `build_random_table` set, each entry is the quiescent state with
        probability 1 - labda, with at most `n_state_q` quiescent entries in
        total; every other entry is a random non-quiescent state. Otherwise
        the table starts fully quiescent and exactly `rule_length - n_state_q`
        randomly chosen entries are replaced by random non-quiescent states.
        """
        if self.build_random_table:
            rule_set = []
            # Track how many quiescent entries may still be placed instead of
            # counting them in a pre-filled table, which would also count the
            # not-yet-filled slots whenever state_q == 0.
            quiescent_left = self.n_state_q
            for _ in range(self.rule_length):
                if random.random() > self.labda and quiescent_left > 0:
                    rule_set.append(self.state_q)
                    quiescent_left -= 1
                else:
                    rule_set.append(randint_neq(0, self.k - 1, self.state_q))
            self.rule_set = rule_set
        else:
            self.rule_set = [self.state_q] * self.rule_length
            n_active = self.rule_length - int(round(self.n_state_q))
            # Pick the entries to make non-quiescent as a uniform random
            # subset of distinct indices.
            for index in random.sample(range(self.rule_length), n_active):
                self.rule_set[index] = randint_neq(0, self.k - 1, self.state_q)

    def check_rule(self, inp):
        """Returns the new state based on the input states.

        The input state will be a sequence of 2r+1 items between 0 and k-1,
        the neighbourhood which the state of the new cell depends on. The
        neighbourhood with base-k value v selects rule_set[-v - 1], i.e. the
        table is indexed from the end."""
        return self.rule_set[-base_k_to_decimal(inp, self.k) - 1]

    def setup_initial_row(self, setup_random=True):
        """Return a list of length `width` with the initial state of each cell
        in the first row; values lie between 0 and k-1.

        With `setup_random` (the default) every cell gets a uniformly random
        state. Otherwise every cell is 0 except the middle one, which is set
        to k-1.
        """
        if setup_random:
            return [random.randint(0, self.k - 1) for _ in range(self.width)]
        row = [0] * self.width
        if self.width:
            row[self.width // 2] = self.k - 1
        return row

    def reset(self):
        """Initialise the configuration of the cells and build a fresh rule
        set for the current parameters."""
        self.t = 0
        self.prev_states = {}
        self.state_q = random.randint(0, self.k - 1)
        self.rule_length = self.k**(2*self.r + 1)
        # k or r may have changed since labda was last set; recompute the
        # number of quiescent entries for the new rule-table length.
        self.set_labda(self.labda)
        self.config = np.zeros([self.height, self.width])
        self.config[0, :] = self.setup_initial_row()
        self.build_rule_set()

    def draw(self):
        """Draws the current state of the grid."""
        import matplotlib
        import matplotlib.pyplot as plt

        plt.cla()
        if not plt.gca().yaxis_inverted():
            plt.gca().invert_yaxis()
        plt.imshow(self.config, interpolation='none', vmin=0, vmax=self.k - 1,
                cmap=matplotlib.cm.binary)
        plt.axis('image')
        plt.title('t = %d' % self.t)

    def step(self):
        """Advance one time step (one row) by applying the rule table to the
        previous row.

        Returns True once t reaches `height`. Otherwise returns the result of
        check_state() for the previous row: the row index of an earlier
        identical row, or 0 if none was recorded.
        """
        self.t += 1
        if self.t >= self.height:
            return True

        for patch in range(self.width):
            # We want the items r to the left and to the right of this patch,
            # while wrapping around (e.g. index -1 is the last item on the row).
            # Since slices do not support this, we create an array with the
            # indices we want and use that to index our grid.
            indices = [i % self.width
                    for i in range(patch - self.r, patch + self.r + 1)]
            values = self.config[self.t - 1, indices]
            self.config[self.t, patch] = self.check_rule(values)

        return self.check_state()

    def check_state(self):
        """Check whether row t - 1 has already occurred.

        If an identical row was recorded before, return that earlier row's
        index (the stored t minus 1). Otherwise record the row under the
        current t and return 0.
        NOTE(review): a repeat of row 0 also returns 0, so callers that treat
        the result as a boolean will not stop on cycles through row 0.
        """
        # Convert the float row to Python ints so the encoding is computed
        # with exact integer arithmetic; floats lose precision once
        # k**width exceeds 2**53, making distinct rows collide.
        state = self.config[self.t - 1].astype(int).tolist()
        state_num = base_k_to_decimal(state, self.k)

        if state_num in self.prev_states:
            return self.prev_states[state_num] - 1
        self.prev_states[state_num] = self.t
        return 0

if __name__ == '__main__':
    # Script entry point: wrap a fresh simulation in the pyics GUI and run it.
    from pyics import GUI

    sim = CASim()
    cx = GUI(sim)
    cx.start()
