from itertools import chain, combinations
import requests
from lxml import etree
import re
import json

IPSOS_URL = "http://www.ipsos-nederland.nl/ipsos-politieke-barometer/barometer-van-deze-week/"

def powerset(x):
    """Return an iterator over every subset of the items in *x*.

    Subsets are yielded as tuples, ordered by increasing size (the empty
    tuple first, the full collection last). Within one size the order is
    the one produced by ``itertools.combinations``.

    Args:
        x: Any iterable. It is materialised once, so iterators and
            generators (which have no ``len``) are accepted as well as
            sized containers.

    Returns:
        An ``itertools.chain`` iterator of tuples.
    """
    # Snapshot the input: it is needed once for its length and once per
    # subset size, which a one-shot iterator could not provide.
    items = list(x)
    return chain.from_iterable(
        combinations(items, size) for size in range(len(items) + 1)
    )

def fetch_poll(poll_url, timeout=30):
    """Fetch poll figures from an IPSOS barometer page.

    The page is expected to contain a ``<script>`` element inside
    ``<div class="barometer_wrapper">`` whose text holds a ``categories:``
    JSON list and one or more ``"data":`` JSON arrays. The categories are
    paired positionally with the *last* ``"data"`` array found in the
    script (presumably the most recent poll -- confirm against the page).

    Args:
        poll_url: URL of the poll page.
        timeout: Seconds to wait for the server before giving up.

    Returns:
        A dict mapping each category to the value at the same position
        in the last data array. If the two lists differ in length the
        longer one is truncated by ``zip``.

    Raises:
        requests.RequestException: On connection failure, timeout, or a
            4xx/5xx response.
        IndexError: If the wrapper script or a ``"data"`` array is missing.
        AttributeError: If no ``categories:`` entry is found in the script.
        ValueError: If an extracted fragment is not valid JSON.
    """
    # Without a timeout requests can block forever on an unresponsive host.
    response = requests.get(poll_url, timeout=timeout)
    # Fail on HTTP errors here rather than while parsing an error page.
    response.raise_for_status()

    tree = etree.fromstring(response.text, etree.HTMLParser())
    script_tag = tree.xpath("//div[@class='barometer_wrapper']/script")[0]
    script_text = script_tag.text

    # Raw strings keep the regex escapes from being read as (invalid)
    # Python string escapes.
    categories_match = re.search(r'categories:(.+?)}', script_text)
    parties = json.loads(categories_match.group(1))

    data_series = re.findall(r'"data":(.+?\])', script_text)
    data = json.loads(data_series[-1])

    return dict(zip(parties, data))

class Coalities:
    """Enumerate the majority coalitions of a poll and narrow them down.

    A coalition is represented as a tuple of ``(party, seats)`` pairs.
    ``self.powerset`` holds the coalitions that currently survive all
    applied filters; every filter method updates it in place and also
    returns it.
    """

    # A coalition qualifies only if its combined total is strictly greater
    # than this value (presumably half of a 150-seat chamber -- confirm).
    MAJORITY_THRESHOLD = 75

    def __init__(self, poll=None):
        """Build the initial list of majority coalitions.

        Args:
            poll: Mapping of party name to seat count. When ``None`` the
                poll is downloaded from ``IPSOS_URL``.
        """
        # Test against None explicitly: ``poll or ...`` would also discard
        # an explicitly supplied empty poll and hit the network instead.
        self.poll = fetch_poll(IPSOS_URL) if poll is None else poll
        self.powerset = self._majority_coalitions()

    @staticmethod
    def _seat_total(coalition):
        """Return the combined seat count of one coalition."""
        return sum(seats for _, seats in coalition)

    def _majority_coalitions(self):
        """Return every party combination whose total exceeds the threshold."""
        return [
            coalition
            for coalition in powerset(list(self.poll.items()))
            if self._seat_total(coalition) > self.MAJORITY_THRESHOLD
        ]

    def exclude_combination(self, party_names):
        """Drop every coalition that contains all of *party_names*.

        Use this for parties that refuse to govern together: a coalition
        is kept as long as at least one of the named parties is absent.
        An empty *party_names* is contained in every coalition, so it
        removes all of them.

        Args:
            party_names: Iterable of party names.

        Returns:
            The remaining coalitions (``self.powerset``).
        """
        incompatible = set(party_names)
        self.powerset = [
            coalition
            for coalition in self.powerset
            if not incompatible.issubset(name for name, _ in coalition)
        ]
        return self.powerset

    def exclude_party(self, party_name):
        """Drop every coalition that includes *party_name*.

        Returns:
            The remaining coalitions (``self.powerset``).
        """
        return self.exclude_combination([party_name])

    def max_parties(self, n):
        """Keep only coalitions made up of at most *n* parties.

        Args:
            n: Largest number of parties a coalition may contain
                (inclusive).

        Returns:
            The remaining coalitions (``self.powerset``).
        """
        # Inclusive bound: a coalition of exactly n parties is still allowed.
        self.powerset = [
            coalition for coalition in self.powerset if len(coalition) <= n
        ]
        return self.powerset

    def print_combinations(self):
        """Print the remaining coalitions, smallest seat total first.

        Sorts ``self.powerset`` in place, then prints one line per
        coalition: the party names joined by ``+`` in a 50-character
        left-aligned column, followed by the combined seat count.
        """
        self.powerset.sort(key=self._seat_total)
        for coalition in self.powerset:
            label = '+'.join(name for name, _ in coalition)
            print("{:<50}{}".format(label, self._seat_total(coalition)))

    def reset(self):
        """Discard all filters and recompute the coalitions from the poll.

        Returns:
            The full list of majority coalitions (``self.powerset``).
        """
        self.powerset = self._majority_coalitions()
        return self.powerset
