import pygame, sys, os
from pygame.locals import *
from PIL import Image, ExifTags, ImageEnhance
from collections import deque
import numpy

# Usage: <script> IMAGE_PATH -- opens one image for interactive correction.
if len(sys.argv) < 2:
    sys.exit(f"usage: {sys.argv[0]} IMAGE_PATH")

name = sys.argv[1]

print(name)

# Skip images that are themselves outputs of this tool. Results are saved as
# "fixed_<original file name>", so only the file name (not the directory part
# of the path) is checked, and only for that prefix.
if os.path.basename(name).startswith('fixed_'):
    sys.exit()

pygame.init()

# Window size in pixels; previews are shrunk to fit inside it.
SCREEN_SZ = (2250, 1500)

screen = pygame.display.set_mode(SCREEN_SZ)

def thumbify(img):
    """Return a preview copy of *img* that fits within SCREEN_SZ.

    The aspect ratio is kept and *img* itself is left untouched.
    """
    preview = img.copy()
    # Image.thumbnail() works in place, which is why a copy is made first.
    preview.thumbnail(size=SCREEN_SZ)
    return preview

def open_img(filename):
    """Load an image, normalise its size and orientation, and build a preview.

    Processing steps:
      * downscale (never upscale) so the height is at most 3000 px, keeping
        the aspect ratio;
      * rotate upright according to the EXIF Orientation tag (values 3, 6 and
        8 are handled; any other value, or no tag at all, leaves it as-is);
      * apply a mild sharpening (factor 1.15).

    Args:
        filename: path of the image file to open.

    Returns:
        (img, thumb): the processed image and a preview copy from thumbify().
    """
    orientation_tag = next(
        tag for tag, tag_name in ExifTags.TAGS.items() if tag_name == 'Orientation'
    )

    # The context manager closes the source file; resize() returns a new,
    # independent image, so `img` remains usable after the block.
    with Image.open(filename) as src:
        # getexif() is the public API and is available for every format
        # (the private _getexif() is not); a missing tag yields None instead
        # of the KeyError that indexing the EXIF dict used to raise.
        orientation = src.getexif().get(orientation_tag)

        ratio = src.size[0] / src.size[1]
        height = min(3000, src.size[1])
        width = int(ratio * height)
        img = src.resize((width, height), Image.Resampling.LANCZOS)

    # EXIF orientation value -> counter-clockwise angle for Image.rotate().
    angle = {3: 180, 6: 270, 8: 90}.get(orientation)
    if angle is not None:
        img = img.rotate(angle, expand=True)

    img = ImageEnhance.Sharpness(img).enhance(1.15)

    return (img, thumbify(img))

def pilImageToSurface(pilImage):
    """Convert a PIL image to a pygame Surface in the display's pixel format.

    pygame.image.fromstring only understands a few raw layouts. PIL images
    in other modes (greyscale "L", "CMYK", ...) would make it raise, and
    "P" would lose its palette. Such images are converted to RGB first.
    The final .convert() requires the display mode to have been set.
    """
    if pilImage.mode not in ("RGB", "RGBA", "RGBX"):
        pilImage = pilImage.convert("RGB")
    return pygame.image.fromstring(pilImage.tobytes(), pilImage.size, pilImage.mode).convert()

def find_coeffs(pa, pb):
    """Solve for the 8 perspective coefficients used by PIL's Image.transform.

    Each point p1 = (x1, y1) in *pa* is mapped to its partner p2 = (x2, y2)
    in *pb* by

        x2 = (a*x1 + b*y1 + c) / (g*x1 + h*y1 + 1)
        y2 = (d*x1 + e*y1 + f) / (g*x1 + h*y1 + 1)

    For Image.PERSPECTIVE, *pa* are output-image points and *pb* the matching
    source-image points.

    Args:
        pa: four (x, y) points.
        pb: the four corresponding (x, y) points.

    Returns:
        1-D float64 array (a, b, c, d, e, f, g, h).

    Raises:
        numpy.linalg.LinAlgError: when the system is singular (e.g. for
            degenerate point sets such as collinear corners).
    """
    rows = []
    for (x1, y1), (x2, y2) in zip(pa, pb):
        rows.append([x1, y1, 1, 0, 0, 0, -x2 * x1, -x2 * y1])
        rows.append([0, 0, 0, x1, y1, 1, -y2 * x1, -y2 * y1])

    # Plain ndarrays: numpy.matrix is discouraged (PendingDeprecationWarning)
    # and made solve() return a (1, 8) matrix that had to be flattened.
    # The right-hand side is pb flattened as x2_0, y2_0, x2_1, y2_1, ...,
    # matching the row order built above.
    A = numpy.asarray(rows, dtype=numpy.float64)
    B = numpy.asarray(pb, dtype=numpy.float64).reshape(8)
    return numpy.linalg.solve(A, B)

def bounding_box(quad_pos):
    """Return (width, height) of the axis-aligned box enclosing the points.

    Despite the name, only the size of the box is returned, not its corners.
    Raises ValueError for an empty point collection.
    """
    xs = [point[0] for point in quad_pos]
    ys = [point[1] for point in quad_pos]
    width = max(xs) - min(xs)
    height = max(ys) - min(ys)
    return (width, height)

def scale_quad_pos(scale, quad_pos):
    """Scale each (x, y) point by *scale*, truncating coordinates to int."""
    def scale_point(point):
        px, py = point
        return int(scale * px), int(scale * py)

    return [scale_point(p) for p in quad_pos]

def do_transform(quad_pos):
    """Perspective-correct the global image using four clicked corners.

    *quad_pos* holds four (x, y) points in preview (screen) coordinates, in
    the order top-left, top-right, bottom-right, bottom-left. They are mapped
    onto the corners (0, 0), (w, 0), (w, h), (0, h) of the output, where
    (w, h) is the size of the points' bounding box at full resolution.

    The warp is computed once, on the full-resolution global `img`. The
    global `thumb` preview is then regenerated from the result, as the rotate
    handlers do, so the preview always matches what gets saved. Selections
    that cannot define a transform are reported and ignored instead of
    crashing the session.
    """
    global thumb, img
    # Convert preview pixels to full-resolution pixels.
    scale = img.size[0] / thumb.size[0]
    img_quad_pos = scale_quad_pos(scale, quad_pos)

    new_size = bounding_box(img_quad_pos)
    if new_size[0] <= 0 or new_size[1] <= 0:
        print("Selection has no area; ignoring it.")
        return

    corners = [(0, 0), (new_size[0], 0), new_size, (0, new_size[1])]
    try:
        coeffs = find_coeffs(corners, img_quad_pos)
    except numpy.linalg.LinAlgError:
        print("Degenerate selection (collinear corners?); ignoring it.")
        return

    img = img.transform(new_size, Image.PERSPECTIVE, coeffs, Image.BICUBIC)
    thumb = thumbify(img)


def do_brightness_contrast(amount):
    """Brighten (positive *amount*) or darken the global image and preview.

    One unit of *amount* changes brightness by a factor of 1/10 and
    contrast by 1/80. The same adjustment is applied to both `img` and
    `thumb`, and both globals are rebound to new images.
    """
    global img, thumb
    brightness_factor = 1 + (amount / 10)
    contrast_factor = 1 + (amount / 80)

    def adjust(picture):
        brighter = ImageEnhance.Brightness(picture).enhance(brightness_factor)
        return ImageEnhance.Contrast(brighter).enhance(contrast_factor)

    img = adjust(img)
    thumb = adjust(thumb)

img, thumb = open_img(name)

# Corner clicks in screen coordinates, oldest first. maxlen=4 drops the
# oldest click automatically when a fifth one is added.
quad_pos = deque(maxlen=4)

clock = pygame.time.Clock()
# Converting a PIL image to a Surface is costly, so it is done only when the
# global `thumb` has been rebound to a different image object (every edit
# in this script produces a new image rather than mutating the old one).
shown_thumb = None
thumb_surface = None

running = True
while running:
    for event in pygame.event.get():
        if event.type == pygame.QUIT:
            running = False

        # Left button only: in pygame 2 the mouse wheel also emits
        # MOUSEBUTTONUP events (buttons 4/5), which must not add corners.
        elif event.type == pygame.MOUSEBUTTONUP and event.button == 1:
            quad_pos.append(event.pos)

        elif event.type == pygame.KEYUP:
            if event.key == pygame.K_RETURN:
                if len(quad_pos) == 4:
                    do_transform(quad_pos)
                    quad_pos.clear()
            elif event.key == pygame.K_KP_PLUS:
                do_brightness_contrast(+1)
            elif event.key == pygame.K_KP_MINUS:
                do_brightness_contrast(-1)
            elif event.key == pygame.K_s:
                split_path = os.path.split(name)
                new_name = os.path.join(split_path[0], f"fixed_{split_path[1]}")
                img.save(new_name, quality=80)
                print("Saved!")
            elif event.key == pygame.K_q:
                img = img.rotate(90, expand=True)
                thumb = thumbify(img)
            elif event.key == pygame.K_e:
                img = img.rotate(-90, expand=True)
                thumb = thumbify(img)
            elif event.key == pygame.K_ESCAPE:
                running = False

    if not running:
        break

    if thumb is not shown_thumb:
        thumb_surface = pilImageToSurface(thumb)
        shown_thumb = thumb

    screen.fill((0, 0, 0))
    screen.blit(thumb_surface, (0, 0))

    for pos in quad_pos:
        pygame.draw.circle(screen, (255, 0, 0), pos, 4, 0)

    # Outline of the selection so far: a single segment for two points and
    # a closed polygon (triangle or quadrilateral) for three or four.
    if len(quad_pos) >= 2:
        pygame.draw.lines(screen, (255, 255, 255), len(quad_pos) > 2, list(quad_pos))

    pygame.display.flip()
    # Cap the redraw rate instead of spinning the CPU at 100%.
    clock.tick(60)

pygame.quit()
