from . import command, USER_SET, users_map
from abc import ABC, abstractmethod
import sys
import traceback
import re
import logging
from typing import Mapping
from collections import defaultdict
from .tagauth_data import AuthRes


# Logger used by this module (logger name "auth").
logger = logging.getLogger("auth")

# In-memory tag store: user id -> list of Tag instances, consulted in list
# order by try_auth(). A defaultdict so that addtag() and init() can append to
# a user's list without creating it first.
# NOTE(review): the annotation says str keys, but try_auth() looks entries up by
# `user.id`, addtag() by whatever users_map() yields, and init() by the keys of
# the config mapping — confirm all three use the same key type.
tags: Mapping[str, list] = defaultdict(list)

class Tag(ABC):
    """Base class for authorization tags.

    A tag is asked, through :meth:`try_auth`, whether it authorizes a call.
    Concrete tags implement :meth:`_try_auth` and return an ``AuthRes`` to make
    a decision, or ``None`` to defer to the user's next tag.
    """

    def __new__(cls, *args):
        # Record the raw constructor arguments so __repr__ can render the tag
        # back in its textual "name!arg1:arg2" form without every subclass
        # having to store them itself.
        instance = super().__new__(cls)
        instance._args = args
        return instance

    def __repr__(self):
        name = type(self).__name__.lower()
        if self._args:
            return f"{name}!{':'.join(self._args)}"
        return name

    @abstractmethod
    def _try_auth(self, *args):
        """Return an ``AuthRes`` to decide, or ``None`` to defer."""

    def try_auth(self, *args):
        """Run :meth:`_try_auth` and normalize its result.

        Returns ``AuthRes.CONTINUE`` when the tag's result is falsy (e.g. it
        returned ``None``), the tag's own result otherwise, and ``AuthRes.FAIL``
        if the tag raises, so a broken tag denies access rather than aborting
        the whole authorization.
        """
        try:
            ret = self._try_auth(*args)
        except Exception:
            # Catch Exception rather than everything: a bare `except:` would
            # also swallow KeyboardInterrupt and SystemExit.
            logger.exception("tag %r raised during authorization", self)
            return AuthRes.FAIL
        # NOTE(review): relies on every AuthRes member being truthy (true for a
        # plain Enum) — confirm if AuthRes is an IntEnum/Flag with a 0 value.
        if not ret:
            return AuthRes.CONTINUE
        return ret

class All(Tag):
    """Authorize every command. Takes no arguments (textual form: ``all``)."""

    def __init__(self):
        # Tag.__new__ accepts any *args. Without an __init__ of its own,
        # All("x") would be silently accepted (object.__init__ tolerates the
        # extra args when __new__ is overridden) and shown as "all!x".
        super().__init__()

    def _try_auth(self, *args):
        return AuthRes.SUCCESS

class Module(Tag):
    """Authorize any command whose function is defined in ``modules.<name>``.

    Textual form: ``module!<name>``.
    """

    def __init__(self, module):
        # Bare module name, i.e. the part after the "modules." package prefix.
        self.modulename = module

    def _try_auth(self, bea, function, user):
        try:
            defining_module = function.__module__
        except AttributeError:
            # Nothing to attribute the callable to; let the next tag decide.
            return AuthRes.CONTINUE
        expected = "modules." + self.modulename
        if defining_module == expected:
            return AuthRes.SUCCESS
        return None

class Command(Tag):
    """Authorize a single command, identified by its function's ``__name__``.

    Textual form: ``command!<name>``.
    """

    def __init__(self, command):
        self.commandname = command

    def _try_auth(self, bea, function, user):
        identifiable = hasattr(function, "__module__") and hasattr(function, "__name__")
        if not identifiable:
            return AuthRes.CONTINUE
        # NOTE(review): only __name__ is compared, so same-named functions in
        # different modules are all authorized by one command tag.
        return AuthRes.SUCCESS if function.__name__ == self.commandname else None

def try_auth(bea, function, user):
    """Decide whether ``user`` may call ``function`` by consulting their tags.

    Tags are tried in order. The first one that returns ``AuthRes.SUCCESS`` or
    ``AuthRes.FAIL`` decides. If the user has no tags, or every tag defers,
    the result is ``AuthRes.FAIL`` (deny by default).
    """
    uid = user.id
    # Membership test rather than tags[uid]: indexing the defaultdict would
    # insert an empty entry for every unknown user that gets checked.
    if uid not in tags:
        return AuthRes.FAIL
    for tag in tags[uid]:
        # Lazy %-formatting: the message (and repr(tag)) is only built when
        # debug logging is enabled, instead of on every authorization check.
        logger.debug("trying tag: %r", tag)
        res = tag.try_auth(bea, function, user)
        if res in (AuthRes.SUCCESS, AuthRes.FAIL):
            logger.debug("explicit %s", res)
            return res
    logger.debug("implicit fail")
    return AuthRes.FAIL

# Textual tag syntax: a lowercase tag name, optionally followed by "!" and one
# or more ":"-separated arguments, e.g. "all", "module!tagauth", "command!addtag".
# Arguments may contain lowercase letters, digits, "_" and ".", so that module
# and command names that are ordinary Python identifiers (e.g. "show_tags") can
# be expressed. \Z (rather than $) also rejects a trailing newline.
TAG_RE = re.compile(r"^([a-z]+)(?:!((?:[a-z0-9_.]+)(?::[a-z0-9_.]+)*))?\Z")

@command(r'add tag (\S+) to (' + USER_SET + ')( forever)?', pass_groups=True, auth=True)
def addtag(bea, bot, update, groups):
    """ Add a tag to a given user, either permanently or until the bot restarts. For an overview of tags and their syntax, see the showtags command. """
    # groups: (tag text, text matched by USER_SET, " forever" or None).
    tag, at, lifetime = groups
    tag_match = TAG_RE.match(tag)
    if not tag_match:
        return update.message.reply_text('Invalid tag.')

    tag_name, tag_arg = tag_match.groups()

    # Map the tag name ("module") to its class ("Module"). Only concrete Tag
    # subclasses qualify. A plain attribute lookup would let user input reach
    # any module-level name whose title-cased form exists here (e.g.
    # "mapping" -> typing.Mapping, "tag" -> the abstract base class).
    tag_cls = globals().get(tag_name.title())
    if not (isinstance(tag_cls, type) and issubclass(tag_cls, Tag) and tag_cls is not Tag):
        return update.message.reply_text('Invalid tag.')

    tag_args = tag_arg.split(':') if tag_arg else ()

    try:
        tag_instance = tag_cls(*tag_args)
    except (TypeError, ValueError) as e:
        # Wrong number/kind of arguments for this tag class.
        return update.message.reply_text(f'Unable to construct tag: {e!r}.')

    def addtag_single(user_id):
        tags[user_id].append(tag_instance)

        # " forever" tags are also recorded, in their textual form, under the
        # tagauth config (presumably what init() reloads at startup — confirm).
        if lifetime is not None:
            config_tags = bea.config['tagauth']['tags']
            if user_id in config_tags:
                config_tags[user_id].append(tag)
            else:
                config_tags[user_id] = [tag]

    users_map(bea, at, addtag_single)

    update.message.reply_text('Added tag to user(s).')

@command(r'tags ('+ USER_SET + ')', pass_groups=True)
def showtags(bea, bot, update, groups):
    """
    Show a user's tags.

    Tags are Bea's authorization system, loosely inspired by PAM modules. When authorization is required, Bea will work through a user's tags, trying each one sequentially to see whether it authorizes the requested action.

    There are currently three tag classes: `all`, `module` and `command`, authorizing a user to use all commands, just commands in a specified module, or just a specified command.

    Tags have a textual representation. For instance, the tag that allows access to the addtag command can be represented as `command!addtag`. These are the representations shown in the showtags command and required in the addtag command.
    """
    def single_showtags(uid):
        # .get instead of tags[uid]: indexing the defaultdict would insert an
        # empty entry for every user that is merely looked up.
        user_tags = tags.get(uid, [])
        update.message.reply_text(f"tags for {uid}: {','.join(repr(tag) for tag in user_tags)}", quote=False)

    users_map(bea, groups[0], single_showtags)

def init(bea, config):
    """Load persisted tags and return this module's command handlers.

    ``config["tags"]`` maps each user id to a list of textual tags in the
    ``name!arg1:arg2`` form that addtag accepts and stores. Entries with an
    unknown tag name or unusable arguments are logged and skipped instead of
    aborting module initialization.

    Returns the list of command handlers provided by this module.
    """
    tag_classes = {"all": All, "module": Module, "command": Command}
    for uid, taglist in config["tags"].items():
        for tag in taglist:
            name, _, arg_text = tag.partition('!')
            # Split arguments on ':' exactly as addtag does, so a stored tag
            # is rebuilt with the same arguments it was created with.
            args = arg_text.split(':') if arg_text else []
            tag_cls = tag_classes.get(name)
            if tag_cls is None:
                logger.warning("ignoring unknown tag %r for user %r", tag, uid)
                continue
            try:
                instance = tag_cls(*args)
            except (TypeError, ValueError) as e:
                # e.g. "module" with no argument, or too many arguments.
                logger.warning("ignoring malformed tag %r for user %r: %r", tag, uid, e)
                continue
            tags[uid].append(instance)
    return [showtags, addtag]
