#!/usr/bin/env python
#
# Copyright (C) 2014-2023 Smithsonian Astrophysical Observatory
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
#

'Create a map in the form of a hexagonal grid'


import sys
import os
import numpy as np

from region import polygon
from crates_contrib.masked_image_crate import MaskedIMAGECrate
import ciao_contrib.logger_wrapper as lw


# Tool identity: passed to the logger, the parameter reader, the error
# handler decorator and the history writer further down in this file.
__toolname__ = "hexgrid"
__revision__ = "24 August 2023"

# Module-level logging shortcuts.  Each name is bound to the `verboseN`
# attribute of the object returned by lw.initialize_logger() for this
# tool; presumably N is the verbosity level at which the message is
# emitted (confirm against ciao_contrib.logger_wrapper).  In the code
# shown here only verb1 is referenced (as the "cmd" callback handed to
# get_params).
verb0 = lw.initialize_logger(__toolname__).verbose0
verb1 = lw.initialize_logger(__toolname__).verbose1
verb2 = lw.initialize_logger(__toolname__).verbose2
verb3 = lw.initialize_logger(__toolname__).verbose3
verb5 = lw.initialize_logger(__toolname__).verbose5


class HexagonGrid():
    """
    Create a map or grid of hexagons

    Creates an image whose pixel values indicate which hexagon
    a pixel belongs to.  Pixels that no hexagon claims keep the
    value 0.  Hexagons are numbered from 1 in the order they are
    drawn; hexagons that fall completely outside the image still
    consume a number, so the values present in the map need not be
    contiguous.

    Parameters
    ----------
    sidelen
        Length of one hexagon side, in pixels.  Must be > 0.
    xlen, ylen
        Integer size of the output image along X and Y.
    x0, y0
        Reference position of the grid.  Each value is reduced modulo
        the grid period along its axis and then shifted by -1 (the
        shift converts to 0-based indexing).

    Raises
    ------
    ValueError
        If sidelen is not greater than zero.

    >>> hh = HexagonGrid(10, 1024, 1024, 0, 0)
    >>> outvals = hh.make_map()
    """

    def __init__(self, sidelen, xlen, ylen, x0, y0):
        """
        Setup the hexagon properties
        """

        # A zero side length gives a zero step to np.arange and a
        # negative one gives empty ranges (i.e. a silently blank map),
        # so reject both up front.  `not (... > 0)` also rejects NaN.
        if not sidelen > 0:
            raise ValueError(
                f"sidelen must be greater than 0, got {sidelen!r}")

        self.sidelen = sidelen
        self.xlen = xlen
        self.ylen = ylen

        #
        # Hexagon are aligned so that the "long" axis is parallel
        # to the X-axis.
        #
        xdelta, ydelta = self._grid_steps()

        # Fold the reference point into a single grid period.
        self.x0 = np.mod(x0, xdelta) - 1  # -1 => 0 based indexing
        self.y0 = np.mod(y0, ydelta) - 1

        #
        # Create arrays for a polygon centered around 0,0.  It is shifted
        # to each hex point later.  The first vertex is repeated at the
        # end so the outline is closed.
        #
        s60 = np.sin(np.deg2rad(60.0))
        c60 = np.cos(np.deg2rad(60.0))
        self.px = np.array([-1, -c60, c60, 1, c60, -c60, -1]) * sidelen
        self.py = np.array([0, s60, s60, 0, -s60, -s60, 0]) * sidelen

        # Counter for the number of hexagons that are created
        self.counter = 0

        # The output array, indexed [y, x]
        self.stipple = np.zeros([ylen, xlen])

    def _grid_steps(self):
        """
        Return the (x, y) spacing between hexagon centers within one
        stride of the grid.
        """
        xdelta = 3 * self.sidelen
        ydelta = 2 * self.sidelen * np.sin(np.deg2rad(60))
        return xdelta, ydelta

    def _fill_hexagon(self, xx, yy):
        """
        Stamp the next hexagon id onto the pixels inside the hexagon
        centered at (xx, yy).

        Note: we do everything in logical coords
        """
        self.counter += 1

        #
        # Pixel box around the hexagon, clipped to the image once here
        # instead of testing every pixel inside the loops.
        #
        iy_lo = max(0, int(yy - self.sidelen))
        iy_hi = min(self.ylen, int(yy + self.sidelen + 1))
        ix_lo = max(0, int(xx - self.sidelen))
        ix_hi = min(self.xlen, int(xx + self.sidelen + 1))

        # Hexagon is entirely off the image: nothing to paint, so do
        # not bother building the region (the id is still consumed).
        if iy_lo >= iy_hi or ix_lo >= ix_hi:
            return

        #
        # Shift polygon vertices to current location and make
        # a region
        #
        poly = polygon(self.px + xx, self.py + yy)

        for iy in range(iy_lo, iy_hi):
            for ix in range(ix_lo, ix_hi):
                if poly.is_inside(ix, iy):
                    self.stipple[iy, ix] = self.counter

    def make_map(self):
        """
        Create the hexagon map

        The hexagon grid is created by looping over pixels in
        a box around each hexagon.  There are two offset rows
        (strides) that need to be checked.

        Returns the output array (the same object as self.stipple).
        """
        xdelta, ydelta = self._grid_steps()

        # The grid starts one period before the image and ends one
        # period after it so hexagons straddling the edges are drawn.
        # The X centers are the same for every row; compute them once.
        xcenters = np.arange(self.x0 - xdelta, self.xlen + xdelta, xdelta)
        ycenters = np.arange(self.y0 - ydelta, self.ylen + ydelta, ydelta)

        for yy in ycenters:
            for xx in xcenters:
                # 1st stride
                self._fill_hexagon(xx, yy)
                # 2nd stride: offset by half a period along both axes
                self._fill_hexagon(xx + 1.5 * self.sidelen,
                                   yy + ydelta / 2.0)

        return self.stipple


@lw.handle_ciao_errors(__toolname__, __revision__)
def main():
    """
    Main routine

    Reads the tool parameters, builds a hexagon map the same size as
    the input image, zeroes the map pixels the input image reports as
    not valid, writes the map to `outfile`, and (when `binimg` is
    neither empty nor "none") runs dmmaskbin on the input image with
    that map to produce `binimg`.
    """

    # Parameter interface
    from ciao_contrib.param_soaker import get_params
    pars = get_params(__toolname__, "rw", sys.argv,
                      verbose={"set": lw.set_verbosity, "cmd": verb1})

    # Refuse to overwrite existing outputs unless clobber allows it
    from ciao_contrib._tools.fileio import outfile_clobber_checks
    for out_key in ("outfile", "binimg"):
        outfile_clobber_checks(pars["clobber"], pars[out_key])

    outfile = pars["outfile"]

    # Only the dimensions of the input image are needed, not its values
    inimg = MaskedIMAGECrate(pars["infile"], mode="r")
    pixels = inimg.get_image().values
    ylen, xlen = len(pixels), len(pixels[0])

    # Hexagon ids for every pixel of the full frame
    grid = HexagonGrid(float(pars["sidelen"]), xlen, ylen,
                       float(pars["xref"]), float(pars["yref"]))
    stipple = grid.make_map()

    # Blank out the pixels that the input image flags as not valid
    for col in range(xlen):
        for row in range(ylen):
            if inimg.valid(col, row):
                continue
            stipple[row][col] = 0

    # Write the map, re-using the input crate as the container
    if os.path.exists(outfile):
        os.remove(outfile)

    inimg.get_image().values = stipple
    inimg.name = "HEXMAP"
    inimg.write(outfile, clobber=True)

    from ciao_contrib.runtool import add_tool_history
    add_tool_history(outfile, __toolname__, pars,
                     toolversion=__revision__)

    # Optional second product: the input image binned by the map
    binimg = pars["binimg"]
    if not binimg or binimg.lower() == "none":
        return

    # pylint: disable=no-name-in-module
    from ciao_contrib.runtool import dmmaskbin
    dmmaskbin(pars["infile"], outfile+"[opt type=i4]",
              binimg, clobber=True)
    add_tool_history(binimg, __toolname__, pars,
                     toolversion=__revision__)


# Run the tool only when this file is executed as a script, so that
# importing the module (e.g. to use HexagonGrid) has no side effects
# beyond the definitions above.
if __name__ == "__main__":
    main()
