#!/usr/bin/env python

#
# Copyright (C) 2014-2015, 2020, 2023
# Smithsonian Astrophysical Observatory
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
#

'Create a map from a stack of regions'


import sys

import numpy as np
from region import CXCRegion

import ciao_contrib.logger_wrapper as lw


# Tool name and revision string; main() hands these to the CIAO error
# handler and records them in the history of the files it writes.
__toolname__ = "mkregmap"
__revision__ = "24 August 2023"

# Logger for this tool, plus module-level shortcuts to its per-level
# verbose methods (verb1 is passed to get_params in main()).
__lgr__ = lw.initialize_logger(__toolname__)
verb0 = __lgr__.verbose0
verb1 = __lgr__.verbose1
verb2 = __lgr__.verbose2
verb3 = __lgr__.verbose3
verb5 = __lgr__.verbose5


def parse_region(region_string):
    """
    Expand a stack specification into a list of region objects.

    region_string is handed to stk.build(); every entry of the resulting
    stack is converted to a CXCRegion.  The list is returned in stack
    order.
    """
    import stk

    # CXCRegion objects cannot be pickled, so they are built one at a
    # time here rather than being farmed out to worker processes.
    parsed = []
    for entry in stk.build(region_string):  # pylint: disable=no-member
        parsed.append(CXCRegion(entry))
    return parsed


class Image():
    """
    Input image wrapper: load it, map regions onto its pixels, save the map.

    Instance attributes set by __init__:
      infile      name of the input file
      img         the MaskedIMAGECrate opened from infile
      sky         transform returned by img.get_transform(coord)
      xlen, ylen  number of pixels along the x and y axes
    """

    def __init__(self, infile, coord="sky"):
        """
        Read the image file and look up the requested coordinate transform.

        infile : name of the image file to open
        coord  : name of the coordinate axis to use; matched
                 case-insensitively against the axis names in the file

        Raises ValueError if coord is not one of the file's axis names.
        """
        from crates_contrib.masked_image_crate import MaskedIMAGECrate
        img = MaskedIMAGECrate(infile)

        # Load the pixel values -- really just need the image size
        img_data = img.get_image().values

        # Rows are y, columns are x
        xlen = len(img_data[0])
        ylen = len(img_data)

        # Get mapping to physical X, Y
        if coord.lower() not in [x.lower() for x in img.get_axisnames()]:
            raise ValueError(f"Coordinate '{coord}' is not in file")

        self.sky = img.get_transform(coord)
        self.infile = infile
        self.xlen = xlen
        self.ylen = ylen
        self.img = img

    def save_map(self, outvals, outfile):
        """
        Write outvals to outfile, reusing the crate read from the input.

        The input image's pixel values are replaced by outvals, the
        block is renamed to REGMAP, and any existing outfile is
        overwritten.
        """

        self.img.get_image().values = outvals
        self.img.name = "REGMAP"
        self.img.write(outfile, clobber=True)

    def map_regions(self, regions):
        """
        Determine which region each pixel belongs to.

        regions : sequence of region objects providing extent() and
                  is_inside(x, y)

        Returns a (ylen, xlen) array of floats.  Each element holds the
        1-based position in `regions` of the region that contains that
        pixel, or 0 if no region does.

        If there are multiple regions covering the same pixel, then the
        first one in the list wins: a pixel that has already been
        assigned is never overwritten.
        """

        od = np.zeros([self.ylen, self.xlen])

        for reg_no, rr in enumerate(regions, start=1):

            # Compute bounds of region
            bnds = rr.extent()
            xy = np.array([(bnds['x0'], bnds['y0']),
                           (bnds['x1'], bnds['y1'])])
            # Get bounds in image coordinates
            ij = self.sky.invert(xy)

            # The transform may flip an axis, so the corners are not
            # guaranteed to come back in increasing order; sort them
            # before rounding outwards.
            i_lo = np.minimum(ij[0][0], ij[1][0])
            i_hi = np.maximum(ij[0][0], ij[1][0])
            j_lo = np.minimum(ij[0][1], ij[1][1])
            j_hi = np.maximum(ij[0][1], ij[1][1])

            # Clip bounds 1:axis-length (image coordinates are 1-based)
            i0 = int(np.floor(np.clip(i_lo, 1, self.xlen)).astype('i4'))
            j0 = int(np.floor(np.clip(j_lo, 1, self.ylen)).astype('i4'))
            i1 = int(np.ceil(np.clip(i_hi, 1, self.xlen)).astype('i4'))
            j1 = int(np.ceil(np.clip(j_hi, 1, self.ylen)).astype('i4'))

            # Pixels of the bounding box that the crate reports as
            # valid; valid() takes 0-based indices.  x varies fastest.
            pixels = [(i, j)
                      for j in range(j0, j1 + 1)
                      for i in range(i0, i1 + 1)
                      if self.img.valid(i - 1, j - 1)]
            if not pixels:
                # no valid pixels, move on
                continue

            # Compute sky coords of every candidate pixel in one call
            rxry = self.sky.apply(np.array(pixels, dtype=float))

            # Now check pixels in bounding box around region
            for (_i, _j), (_x, _y) in zip(pixels, rxry):

                # If pixel already assigned, skip it
                if od[_j-1, _i-1] != 0:
                    continue

                # If pixel is inside, tag it with region number.
                if rr.is_inside(_x, _y):
                    od[_j-1, _i-1] = reg_no

        return od


@lw.handle_ciao_errors(__toolname__, __revision__)
def main():
    """
    Tool entry point.

    Reads the tool parameters, runs the clobber checks on both output
    names, builds the region map for the input image and writes it to
    `outfile`.  When `binimg` is neither empty nor "none", dmmaskbin is
    then run on the input image using the new map, writing `binimg`.
    Tool history is added to every file written.
    """

    # Load parameters
    from ciao_contrib.param_soaker import get_params
    pars = get_params(__toolname__, "rw", sys.argv,
                      verbose={"set": lw.set_verbosity, "cmd": verb1})

    # Check both outputs against the clobber setting before doing any work
    from ciao_contrib._tools.fileio import outfile_clobber_checks
    for key in ("outfile", "binimg"):
        outfile_clobber_checks(pars["clobber"], pars[key])

    image = Image(pars["infile"], coord=pars["coord"])
    region_map = image.map_regions(parse_region(pars["regions"]))

    # Save the map and record how it was made
    outfile = pars["outfile"]
    image.save_map(region_map, outfile)
    from ciao_contrib.runtool import add_tool_history
    add_tool_history(outfile, __toolname__, pars,
                     toolversion=__revision__)

    # The binned image is optional
    binimg = pars["binimg"]
    if len(binimg) == 0 or binimg.lower() == "none":
        return

    from ciao_contrib.runtool import dmmaskbin  # pylint: disable=no-name-in-module
    dmmaskbin(pars["infile"], outfile + "[opt type=i4]",
              binimg, clobber=True)
    add_tool_history(binimg, __toolname__, pars,
                     toolversion=__revision__)


# Run the tool only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
