import lensfunpy
import cv2
import sys
from exif import Image
import os

# Command line: exactly one positional argument, the path of the image to fix.
if len(sys.argv) < 2:
    # Passing a message to sys.exit prints it to stderr and exits with
    # status 1, so shells/scripts can detect the usage error (a bare
    # sys.exit() reported success with status 0).
    sys.exit("Usage: main.py <image>")

filename = sys.argv[1]

# The corrected image is written next to the input with a "lenscor_" prefix,
# e.g. photos/DSC_001.jpg -> photos/lenscor_DSC_001.jpg.
path, name = os.path.split(filename)
new_filename = os.path.join(path, f"lenscor_{name}")

# Identify the camera body and lens from the image's EXIF metadata.
# NOTE: `img` is also read after this `with` block (focal length / f-number),
# which relies on exif.Image loading the file contents at construction time.
with open(filename, 'rb') as image_file:
    img = Image(image_file)
    if 'lens_model' not in img.list_all():
        # No lens information in EXIF: fall back to the single hard-coded
        # body/lens combination. Files named "IMG_*" presumably come from a
        # different camera, so the fallback must not be used for them.
        # An explicit exit replaces the original assert (asserts are stripped
        # under `python -O`), and only the base name is checked so a parent
        # directory containing "IMG_" does not trigger it.
        if 'IMG_' in os.path.basename(filename):
            sys.exit(
                f"{filename}: no lens_model in EXIF and the Nikon fallback "
                "profile does not apply to IMG_ files"
            )
        cam_make = 'NIKON CORPORATION'
        cam_model = 'NIKON D5200'
        lens_make = 'Nikon'
        lens_model = 'Nikon AF-S DX Nikkor 18-55mm f/3.5-5.6G VR II'
    else:
        cam_make = img.make
        cam_model = img.model
        lens_make = img.make  # true for our images
        lens_model = img.lens_model

# Resolve the camera and lens profiles in the lensfun database. Exactly one
# match is required for each so that an ambiguous or missing profile is never
# applied silently. Explicit exits replace the original asserts, which are
# stripped under `python -O` (zero matches would then surface as IndexError).
db = lensfunpy.Database()

cams = db.find_cameras(cam_make, cam_model)
if len(cams) != 1:
    sys.exit(
        f"expected exactly one camera match for {cam_make!r} {cam_model!r}, "
        f"found {len(cams)}"
    )
cam = cams[0]

lenses = db.find_lenses(cam, lens_make, lens_model)
if len(lenses) != 1:
    sys.exit(
        f"expected exactly one lens match for {lens_make!r} {lens_model!r}, "
        f"found {len(lenses)}"
    )
lens = lenses[0]

print(cam, lens)

# Load the pixels. cv2.imread returns None (it does not raise) when the file
# is missing or cannot be decoded, so check explicitly instead of failing
# later with an opaque AttributeError on `.shape`.
im = cv2.imread(filename)
if im is None:
    sys.exit(f"unable to read image: {filename}")
height, width, *_ = im.shape

distance = 0.5  # approximate distance to subject (meters)

mod = lensfunpy.Modifier(lens, cam.crop_factor, width, height)
mod.initialize(img.focal_length, img.f_number, distance)

# Vignetting correction modifies `im` in place; the return value reports
# whether lensfun had data to apply.
did_apply = mod.apply_color_modification(im)
if not did_apply:
    print("unable to correct vignetting")

# Geometry (distortion) correction: lensfun supplies, for each output pixel,
# the source coordinates to sample, and cv2.remap resamples from them.
# If no distortion data is available there is nothing to remap, so keep the
# (possibly vignetting-corrected) image instead of crashing in cv2.remap.
undist_coords = mod.apply_geometry_distortion()
if undist_coords is None:
    print("unable to correct distortion")
    im_undistorted = im
else:
    im_undistorted = cv2.remap(im, undist_coords, None, cv2.INTER_LANCZOS4)

# cv2.imwrite reports failure by returning False rather than raising.
if not cv2.imwrite(new_filename, im_undistorted):
    sys.exit(f"failed to write corrected image: {new_filename}")
