#!/usr/bin/python3
# -*- coding:utf-8 -*-

import os
import sys
import argparse
import re
import textwrap
import random
import shutil
from itertools import permutations
# import imageio  # image sizes now come from the bundled imagesize code below; imageio is still used to read mask files

# Filtering of tie points computed with Tapioca (.txt format only).
# For each pair of images, the program divides the 1st image into
# nb_sect_x * nb_sect_y square sectors (nb_sect_y is deduced from the image
# height), then randomly keeps at most max_pts tie points in each sector.
# Copyright (C) 2021 Arthur Delorme - v1.3

# This program is free software: you can redistribute it and/or modify it under
# the terms of the GNU General Public License as published by the Free Software
# Foundation, either version 3 of the License, or (at your option) any later
# version.
#
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along with
# this program. If not, see <http://www.gnu.org/licenses/>.
#
# (contact: delorme@ipgp.fr)


# imagesize by shibukawa (https://github.com/shibukawa/imagesize_py)
import re
import struct
from xml.etree import ElementTree

# Length-unit codes understood by _convertToDPI. For the metric units, code c
# stands for a unit length of 10**(-c) metres (e.g. _UNIT_CM = 2 -> 1e-2 m).
_UNIT_KM = -3
_UNIT_100M = -2
_UNIT_10M = -1
_UNIT_1M = 0
_UNIT_10CM = 1
_UNIT_CM = 2
_UNIT_MM = 3
_UNIT_0_1MM = 4
_UNIT_0_01MM = 5
_UNIT_UM = 6
# Must differ from every metric code: it used to be 6, the same value as
# _UNIT_UM, so an inch density would have been converted as a per-micrometre
# one. _convertToDPI has no branch for this code and returns the density
# unchanged, which is already a value per inch.
_UNIT_INCH = 7

# Size in bytes of one value of each TIFF field type, keyed by type code
_TIFF_TYPE_SIZES = {
    1: 1,   # BYTE
    2: 1,   # ASCII
    3: 2,   # SHORT
    4: 4,   # LONG
    5: 8,   # RATIONAL
    6: 1,   # SBYTE
    7: 1,   # UNDEFINED
    8: 2,   # SSHORT
    9: 4,   # SLONG
    10: 8,  # SRATIONAL
    11: 4,  # FLOAT
    12: 8,  # DOUBLE
}


def _convertToDPI(density, unit):
    """
    Convert a density given in pixels per *unit* length to dots per inch.

    *unit* is one of the metric ``_UNIT_*`` codes; for any other value the
    density is returned unchanged. Results for the units from km to mm are
    rounded with ``int(x + 0.5)``; for the finer units the plain product is
    returned.
    """
    # (unit code, how many such units make one inch, round to int?)
    conversions = (
        (_UNIT_KM, 0.0000254, True),
        (_UNIT_100M, 0.000254, True),
        (_UNIT_10M, 0.00254, True),
        (_UNIT_1M, 0.0254, True),
        (_UNIT_10CM, 0.254, True),
        (_UNIT_CM, 2.54, True),
        (_UNIT_MM, 25.4, True),
        (_UNIT_0_1MM, 254, False),
        (_UNIT_0_01MM, 2540, False),
        (_UNIT_UM, 25400, False),
    )
    for code, units_per_inch, rounded in conversions:
        if unit == code:
            dpi = density * units_per_inch
            return int(dpi + 0.5) if rounded else dpi
    return density


def _convertToPx(value):
    """
    Convert an SVG length attribute (e.g. ``"120"``, ``"2.5cm"``) to pixels.

    Absolute units use the CSS reference resolution of 96 px per inch
    (1pc = 1/6 in, 1pt = 1/72 in). An integer value without a unit or with
    ``px`` is returned as an int; a value with a decimal part is converted
    as a float, fractional digits included.

    :raises ValueError: if *value* is not ``<number><unit>``, or if the unit
        is not one of px, in, cm, mm, pc, pt.
    """
    matched = re.match(r"(\d+(?:\.\d+)?)([a-z]*)$", value)
    if not matched:
        raise ValueError("unknown length value: %s" % value)
    number, unit = matched.groups()
    # Keep plain integers as ints so unit-less sizes stay integral
    length = float(number) if "." in number else int(number)
    if unit in ("", "px"):
        return length
    elif unit == "cm":
        return length * 96 / 2.54
    elif unit == "mm":
        return length * 96 / 2.54 / 10
    elif unit == "in":
        return length * 96
    elif unit == "pc":
        return length * 96 / 6
    elif unit == "pt":
        return length * 96 / 72
    raise ValueError("unknown unit type: %s" % unit)


def get(filepath):
    """
    Return (width, height) in pixels of an image file, read from its header.

    Recognised formats: GIF, PNG, JPEG, JPEG 2000, classic TIFF (both byte
    orders) and SVG files starting with an XML declaration. Any other
    content, BigTIFF included, yields (-1, -1).

    :type filepath: Union[str, pathlib.Path]
    :rtype Tuple[int, int]
    :raises ValueError: when a recognised header is malformed or truncated
    """
    height = -1
    width = -1

    with open(str(filepath), 'rb') as fhandle:
        head = fhandle.read(24)
        size = len(head)
        # GIF: logical screen width/height, unsigned little-endian 16-bit
        if size >= 10 and head[:6] in (b'GIF87a', b'GIF89a'):
            try:
                width, height = struct.unpack("<HH", head[6:10])
            except struct.error:
                raise ValueError("Invalid GIF file")
        # PNG: signature, then the IHDR chunk (length, b'IHDR', width, height)
        elif size >= 24 and head.startswith(b'\211PNG\r\n\032\n') and \
                head[12:16] == b'IHDR':
            try:
                width, height = struct.unpack(">LL", head[16:24])
            except struct.error:
                raise ValueError("Invalid PNG file")
        # Maybe this is for an older PNG version.
        elif size >= 16 and head.startswith(b'\211PNG\r\n\032\n'):
            try:
                width, height = struct.unpack(">LL", head[8:16])
            except struct.error:
                raise ValueError("Invalid PNG file")
        # JPEG: walk the marker segments up to the first SOFn frame header
        elif size >= 2 and head.startswith(b'\377\330'):
            try:
                fhandle.seek(0)
                size = 2  # bytes to skip before the next marker (SOI)
                ftype = 0
                # 0xc4 (DHT), 0xc8 (JPG) and 0xcc (DAC) lie in the SOFn
                # range but are not frame headers
                while not 0xc0 <= ftype <= 0xcf or ftype in (0xc4, 0xc8, 0xcc):
                    fhandle.seek(size, 1)
                    byte = fhandle.read(1)
                    while byte == b'\xff':  # marker prefix / fill bytes
                        byte = fhandle.read(1)
                    if not byte:  # end of file before any frame header
                        raise ValueError("Invalid JPEG file")
                    ftype = ord(byte)
                    # The segment length counts its own two bytes
                    size = struct.unpack('>H', fhandle.read(2))[0] - 2
                    if size < 0:  # would seek backwards and may loop forever
                        raise ValueError("Invalid JPEG file")
                # We are at a SOFn block
                fhandle.seek(1, 1)  # Skip `precision' byte.
                height, width = struct.unpack('>HH', fhandle.read(4))
            except struct.error:
                raise ValueError("Invalid JPEG file")
        # JPEG 2000: assumes the 'ihdr' box data starts at byte 48 (20-byte
        # 'ftyp' box followed by the 'jp2h' superbox)
        elif size >= 12 and head.startswith(b'\x00\x00\x00\x0cjP  \r\n\x87\n'):
            fhandle.seek(48)
            try:
                height, width = struct.unpack('>LL', fhandle.read(8))
            except struct.error:
                raise ValueError("Invalid JPEG2000 file")
        # big endian TIFF: scan the first IFD for ImageWidth/ImageLength
        elif size >= 8 and head.startswith(b"\x4d\x4d\x00\x2a"):
            try:
                offset = struct.unpack('>L', head[4:8])[0]
                fhandle.seek(offset)
                ifdsize = struct.unpack(">H", fhandle.read(2))[0]
                for _ in range(ifdsize):
                    tag, datatype, _count, data = struct.unpack(
                        ">HHLL", fhandle.read(12))
                    if tag == 256:
                        if datatype == 3:
                            # SHORT values sit in the high half of the field
                            width = data >> 16
                        elif datatype == 4:
                            width = data
                        else:
                            raise ValueError("Invalid TIFF file: width column "
                                "data type should be SHORT/LONG.")
                    elif tag == 257:
                        if datatype == 3:
                            height = data >> 16
                        elif datatype == 4:
                            height = data
                        else:
                            raise ValueError("Invalid TIFF file: height column "
                                "data type should be SHORT/LONG.")
                    if width != -1 and height != -1:
                        break
            except struct.error:
                raise ValueError("Invalid TIFF file")
            if width == -1 or height == -1:
                raise ValueError("Invalid TIFF file: width and/or height IDS "
                    "entries are missing.")
        # little endian TIFF
        elif size >= 8 and head.startswith(b"\x49\x49\x2a\x00"):
            try:
                offset = struct.unpack('<L', head[4:8])[0]
                fhandle.seek(offset)
                ifdsize = struct.unpack("<H", fhandle.read(2))[0]
                for _ in range(ifdsize):
                    tag, datatype, _count, data = struct.unpack(
                        "<HHLL", fhandle.read(12))
                    # SHORT values sit in the low half of the field; mask
                    # off the (normally zero) padding bytes
                    value = data & 0xFFFF if datatype == 3 else data
                    if tag == 256:
                        width = value
                    elif tag == 257:
                        height = value
                    if width != -1 and height != -1:
                        break
            except struct.error:
                raise ValueError("Invalid TIFF file")
            if width == -1 or height == -1:
                raise ValueError("Invalid TIFF file: width and/or height IDS "
                    "entries are missing.")
        # SVG: width/height attributes of the root element
        elif size >= 5 and head.startswith(b'<?xml'):
            try:
                fhandle.seek(0)
                root = ElementTree.parse(fhandle).getroot()
                width = _convertToPx(root.attrib["width"])
                height = _convertToPx(root.attrib["height"])
            except Exception:
                raise ValueError("Invalid SVG file")

    return width, height


def _getPngDPI(fhandle, head):
    """
    Return (x DPI, y DPI) read from the pHYs chunk of a PNG file.

    Chunks are visited in file order starting after the signature; the search
    stops at the first IDAT chunk and then returns (-1, -1). A nonzero pHYs
    unit (1 = metre) is converted to DPI; with unit 0 the raw densities are
    returned.

    :param fhandle: the file, opened in binary mode
    :param head: the first 24 bytes of the file
    :raises ValueError: on a truncated chunk
    """
    chunkOffset = 8
    chunk = head[8:]
    while True:
        chunkType = chunk[4:8]
        if chunkType == b'pHYs':
            try:
                xDensity, yDensity, unit = struct.unpack(">LLB", chunk[8:])
            except struct.error:
                raise ValueError("Invalid PNG file")
            if unit:
                return (_convertToDPI(xDensity, _UNIT_1M),
                        _convertToDPI(yDensity, _UNIT_1M))
            return xDensity, yDensity  # no unit
        if chunkType == b'IDAT':
            return -1, -1
        try:
            dataSize, = struct.unpack(">L", chunk[0:4])
        except struct.error:
            raise ValueError("Invalid PNG file")
        # length (4) + type (4) + data + CRC (4)
        chunkOffset += dataSize + 12
        fhandle.seek(chunkOffset)
        chunk = fhandle.read(17)


def _getJpegDPI(fhandle):
    """
    Return (x DPI, y DPI) read from the JFIF APP0 segment of a JPEG file.

    Marker segments are walked until APP0 or a marker in the 0xc0-0xcf range
    is reached. Density units 0 and 1 are returned as they are, unit 2
    (per cm) is converted; otherwise (-1, -1) is returned.

    :raises ValueError: on a truncated or malformed segment list
    """
    try:
        fhandle.seek(0)
        size = 2  # bytes to skip before the next marker (SOI)
        ftype = 0
        while not 0xc0 <= ftype <= 0xcf:
            if ftype == 0xe0:  # APP0 marker
                # skip the 'JFIF\0' identifier (5) and the version (2)
                fhandle.seek(7, 1)
                unit, xDensity, yDensity = struct.unpack(
                    ">BHH", fhandle.read(5))
                if unit == 1 or unit == 0:
                    return xDensity, yDensity
                if unit == 2:
                    return (_convertToDPI(xDensity, _UNIT_CM),
                            _convertToDPI(yDensity, _UNIT_CM))
                return -1, -1
            fhandle.seek(size, 1)
            byte = fhandle.read(1)
            while byte == b'\xff':  # marker prefix / fill bytes
                byte = fhandle.read(1)
            if not byte:  # end of file before APP0 or a frame header
                raise ValueError("Invalid JPEG file")
            ftype = ord(byte)
            # The segment length counts its own two bytes
            size = struct.unpack('>H', fhandle.read(2))[0] - 2
            if size < 0:  # would seek backwards and may loop forever
                raise ValueError("Invalid JPEG file")
    except struct.error:
        raise ValueError("Invalid JPEG file")
    return -1, -1


def _getJpeg2000DPI(fhandle):
    """
    Return (x DPI, y DPI) read from the display-resolution box of a JP2 file.

    Assumes the 'jp2h' header superbox starts at byte 32 (20-byte 'ftyp'
    box). Returns (-1, -1) when there is no 'res ' superbox or no 'resd' box.

    :raises ValueError: on a truncated header or an invalid box size
    """
    try:
        fhandle.seek(32)
        headerSize = struct.unpack('>L', fhandle.read(4))[0] - 8
        fhandle.seek(4, 1)  # skip the 'jp2h' box type
        # Look for the resolution superbox among the header's sub-boxes
        foundResBox = False
        while headerSize > 0:
            boxHeader = fhandle.read(8)
            boxSize, = struct.unpack('>L', boxHeader[:4])
            if boxHeader[4:] == b'res ':
                foundResBox = True
                headerSize -= 8  # its sub-boxes follow directly
                break
            if boxSize < 8:  # 0 ("up to EOF") / 1 (64-bit size) not handled
                raise ValueError("Invalid JPEG2000 file")
            fhandle.seek(boxSize - 8, 1)
            headerSize -= boxSize
        if not foundResBox:
            return -1, -1
        while headerSize > 0:
            boxHeader = fhandle.read(8)
            boxSize, = struct.unpack('>L', boxHeader[:4])
            if boxHeader[4:] == b'resd':
                # VRdN, VRdD, HRdN, HRdD (uint16), VRdE, HRdE (int8):
                # resolution = N / D * 10**E grid points per metre
                yNum, yDen, xNum, xDen, yExp, xExp = struct.unpack(
                    ">HHHHbb", fhandle.read(10))
                if xDen == 0 or yDen == 0:
                    raise ValueError("Invalid JPEG2000 file")
                return (int(xNum / xDen * 10 ** xExp * 0.0254 + 0.5),
                        int(yNum / yDen * 10 ** yExp * 0.0254 + 0.5))
            if boxSize < 8:
                raise ValueError("Invalid JPEG2000 file")
            fhandle.seek(boxSize - 8, 1)
            headerSize -= boxSize
    except struct.error:
        raise ValueError("Invalid JPEG2000 file")
    return -1, -1


def getDPI(filepath):
    """
    Return (x DPI, y DPI) for a given image file, read from its header.

    Supported: PNG (pHYs chunk), JPEG (JFIF APP0 segment) and JPEG 2000
    (display-resolution box). GIFs carry no density. (-1, -1) is returned
    when the format is not recognised or holds no density information.

    :type filepath: Union[str, pathlib.Path]
    :rtype Tuple[int, int]
    :raises ValueError: when a recognised header is malformed or truncated
    """
    with open(str(filepath), 'rb') as fhandle:
        head = fhandle.read(24)
        size = len(head)
        if size >= 10 and head[:6] in (b'GIF87a', b'GIF89a'):
            return -1, -1  # GIFs don't have density
        if size >= 24 and head.startswith(b'\211PNG\r\n\032\n'):
            return _getPngDPI(fhandle, head)
        if size >= 2 and head.startswith(b'\377\330'):
            return _getJpegDPI(fhandle)
        if size >= 12 and head.startswith(b'\x00\x00\x00\x0cjP  \r\n\x87\n'):
            return _getJpeg2000DPI(fhandle)
    return -1, -1




class Image:
    """
    An image file together with its pixel dimensions.

    The size is read from the file header by get() as soon as the object is
    created; the program exits if it cannot be determined.

    Attributes:
        path: image file path, as passed to the constructor
        w: width in pixels
        h: height in pixels
    """

    def __init__(self, path):
        self.path = path
        self.w, self.h = None, None
        self.get_dim()

    def get_dim(self):
        """
        Set self.w and self.h from the header of self.path.

        get() returns -1 for a dimension it could not read; in that case the
        program exits with a hint (the message points at BigTiff input).
        """
        dims = get(self.path)
        self.w, self.h = dims
        if -1 not in (self.w, self.h):
            return
        sys.exit("\nError: {}; wrong image dimensions\nThis could have "
            "something to do with BigTiff.\nUse tiff_info to see if the "
            "image is in BigTiff format.\nIf yes, try to convert it with "
            "ClipIm (full size), for example.".format(self.path))


def _load_mask(image, mask_suffix):
    """
    Load the mask associated with *image*, if there is one.

    The mask file is '<image path without extension>_<mask_suffix>.tif'.
    Returns (mask, ratio), where ratio = mask width / image width, so that a
    downsampled mask can be indexed with full-resolution coordinates (the
    mask is expected to keep the image's line/column ratio). Returns
    (None, None) when the mask file does not exist.
    """
    mask_name = '{}_{}.tif'.format(os.path.splitext(image.path)[0],
                                   mask_suffix)
    if not os.path.isfile(mask_name):
        return None, None
    try:
        import imageio  # only needed when masks are actually used
    except ImportError:
        sys.exit("\nError: the 'imageio' package is needed to read the mask "
                 "file {}".format(mask_name))
    mask = imageio.imread(mask_name)
    return mask, float(mask.shape[1]) / image.w


def _is_masked(mask, ratio, x, y):
    """
    Return True if the mask value under image point (x, y) equals 1.

    Returns False when *mask* is None. Scaled coordinates are clamped to the
    mask bounds so that points on, or just past, the image border do not
    raise IndexError or wrap around to the opposite side.
    """
    if mask is None:
        return False
    row = min(max(int(y * ratio), 0), mask.shape[0] - 1)
    col = min(max(int(x * ratio), 0), mask.shape[1] - 1)
    return mask[row, col] == 1


def _sector_index(coord, sector_size, last_index):
    """
    Return the index of the sector containing coordinate *coord*.

    The index is clamped to [0, last_index] so that tie points lying slightly
    outside the image go to the nearest sector instead of wrapping around
    (negative index) or raising IndexError.
    """
    index = int(int(coord) // sector_size)
    return min(max(index, 0), last_index)


def filterTiePoints(im1_path, im2_path, in_dir, out_dir, nb_sect_x, max_pts,
        mask_suffix, verbose):
    """
    Filter the tie points between image1 and image2.

    Reads '<in_dir>/Pastis<im1_path>/<im2_path>.txt' (one tie point per line,
    starting with x1 y1 x2 y2) and assigns each point to a square sector of
    image1 (nb_sect_x sectors across the width). When masks are used, points
    lying on a mask value of 1 in either image are dropped. At most max_pts
    points per sector are then kept at random and written to
    '<out_dir>/Pastis<im1_path>/<im2_path>.txt'. If the input file does not
    exist, only a message is printed.

    :param mask_suffix: mask file suffix (see _load_mask), or None for no mask
    :param verbose: print the per-sector counts
    """
    siftin_path = '{}/Pastis{}/{}.txt'.format(in_dir, im1_path, im2_path)
    siftout_path = '{}/Pastis{}/{}.txt'.format(out_dir, im1_path, im2_path)

    if not os.path.exists(siftin_path):
        print('\n{}-{}:\nNo tie points'.format(im1_path, im2_path))
        return

    im1 = Image(im1_path)
    im2 = Image(im2_path)

    # Nb of columns and lines per sector, in the image (sectors are square)
    nb_col_p_sect = nb_lig_p_sect = im1.w / nb_sect_x
    # Nb of sectors in y direction, deduced from the sector size
    nb_sect_y = int(im1.h // nb_lig_p_sect)

    # pts_list[line][column]: the tie points of each sector. One extra
    # sector in each direction collects what exceeds the last full sector.
    pts_list = [[[] for _ in range(nb_sect_x + 1)]
                for _ in range(nb_sect_y + 1)]

    mask1 = ratio1 = mask2 = ratio2 = None
    if mask_suffix is not None:
        mask1, ratio1 = _load_mask(im1, mask_suffix)
        mask2, ratio2 = _load_mask(im2, mask_suffix)

    total = 0  # counted explicitly: an empty input file has no lines
    with open(siftin_path) as f:
        for line in f:
            fields = line.split()
            if not fields:
                continue  # tolerate blank lines
            total += 1
            x1, y1, x2, y2 = (float(v) for v in fields[:4])
            if (_is_masked(mask1, ratio1, x1, y1)
                    or _is_masked(mask2, ratio2, x2, y2)):
                continue
            # Assign each point of image1 to a sector
            col = _sector_index(x1, nb_col_p_sect, nb_sect_x)
            lig = _sector_index(y1, nb_lig_p_sect, nb_sect_y)
            pts_list[lig][col].append(line.strip())

    with open(siftout_path, 'w') as f:
        total_after_mask = 0
        sampled = 0
        print('\n{}-{}:'.format(im1.path, im2.path))
        if verbose:
            print('( l , c )\ttotal\t-->\tsample')
        for i, sectors in enumerate(pts_list):
            for j, sector in enumerate(sectors):
                tot = len(sector)
                total_after_mask += tot
                if tot and tot >= max_pts:
                    # Select points randomly
                    sector = random.sample(sector, max_pts)
                nb = len(sector)
                f.writelines('{}\n'.format(pt) for pt in sector)
                if verbose:
                    print('( {} , {})\t{}\t-->\t{}'.format(i, j, tot, nb))
                sampled += nb

    if mask1 is not None or mask2 is not None:
        print('{} remaining point(s) out of {} after masking'.format(
            total_after_mask, total))
    print('{} point(s) sampled out of {}'.format(sampled, total_after_mask))


def _select_images(regexp, directory='.'):
    """
    Return the sorted names of the regular files of *directory* matching
    *regexp* (applied with re.match, i.e. anchored at the start of the name).

    :raises re.error: if *regexp* is not a valid regular expression
    """
    pattern = re.compile(regexp)
    return sorted(
        name for name in os.listdir(directory)
        if os.path.isfile(os.path.join(directory, name)) and pattern.match(name)
    )


def _ask_output_dir(out_dir):
    """
    Return the name of an output directory that does not exist (anymore).

    While *out_dir* exists, the user is asked either to overwrite it (it is
    then deleted) or to type another non-empty name made of a-z A-Z 0-9 _ -.
    """
    while os.path.exists(out_dir):
        overwrite = None
        while overwrite not in ['y', 'n']:
            overwrite = input("Directory '{}' already exists. Overwrite? (y/n) "
                              .format(out_dir))
        if overwrite == 'y':
            shutil.rmtree(out_dir)
            continue
        while True:
            out_dir = input('Enter the name of the output directory '
                '(a-z A-Z 0-9 _ -): ')
            # '+' rather than '*': an empty name would write into the cwd
            if re.match(r"^[\w-]+$", out_dir):
                break
            print('Use only a-z A-Z 0-9 - or _')
    return out_dir


def main():
    """Command-line entry point: filter the tie points of every image pair."""
    print('{} Copyright (C) 2021 Arthur Delorme\n'.format(
        os.path.basename(sys.argv[0])))

    parser = argparse.ArgumentParser(
            formatter_class=argparse.RawDescriptionHelpFormatter,
            description=('Filtering of tie points computed with Tapioca '
                '(.txt format only)'),
            epilog=textwrap.dedent('''
                    This program comes with ABSOLUTELY NO WARRANTY.
                    This is free software, and you are welcome to redistribute
                    it under certain conditions.
                    See the GNU General Public License for more details.
                ''')
        )
    parser.add_argument('regexp', metavar='images', type=str,
                        help='regular expression for image selection (ex: '
                            '"^DSC0976[0-9]\\.JPG$")')
    parser.add_argument('nb_sect_x', type=int,
                        help='nb of sectors in x direction. The sectors will be'
                            ' square.')
    parser.add_argument('max_pts', type=int,
                        help='max nb of points to keep, per sector. The '
                            'selection is random.')
    parser.add_argument('--homol_dir', type=str, metavar='DIR_NAME',
                        default='Homol',
                        help=('name of the tie point directory from Tapioca '
                            "(default: 'Homol'). Tie points must be in txt "
                            'format.'))
    parser.add_argument('--use_masks', metavar='SUFFIX', type=str,
                        dest='mask_suffix',
                        help=('if provided, mask files (TIFF binary image) '
                            'will be used to remove tie points from selected '
                            "areas. For example, the suffix 'Masq' will tell "
                            'the program to use the mask file '
                            'DSC09751_Masq.tif for the image DSC09751.JPG, if '
                            'it exists.'))
    parser.add_argument('-v', '--verbose', dest='verbose', action='store_true',
                        help='verbose mode')
    args = parser.parse_args()

    if args.nb_sect_x < 1:
        parser.error('nb_sect_x must be at least 1')
    if args.max_pts < 0:
        parser.error('max_pts must be positive or zero')

    # Select the images corresponding to the regexp
    try:
        file_list = _select_images(args.regexp)
    except re.error as err:
        parser.error('invalid regular expression {!r}: {}'.format(
            args.regexp, err))

    if not file_list:
        sys.exit('ERROR: no image selected')
    elif len(file_list) == 1:
        sys.exit('ERROR: only one image selected')

    if not os.path.isdir(args.homol_dir):
        sys.exit("ERROR: directory '{}' does not exist".format(args.homol_dir))

    # Manage the output directory: one Pastis<image> subdirectory per image
    out_dir = _ask_output_dir(
        '{}_filtered'.format(args.homol_dir.rstrip('/')))
    for f in file_list:
        os.makedirs(os.path.join(out_dir, 'Pastis{}'.format(f)))

    # For each ordered pair of images
    for im1, im2 in permutations(file_list, 2):
        filterTiePoints(im1, im2, args.homol_dir, out_dir, args.nb_sect_x,
            args.max_pts, args.mask_suffix, args.verbose)


if __name__ == "__main__":
    main()
