#!/usr/bin/env python
#
# Copyright (C) 2019, 2024 Smithsonian Astrophysical Observatory
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
#


'Convert a map file into a region file'


import os
import sys
from tempfile import NamedTemporaryFile

import numpy as np

from ciao_contrib.runtool import make_tool
import ciao_contrib.logger_wrapper as lw
from region import CXCRegion


# Tool name and revision: used below for the logger, the parameter
# lookup, the error handler and the history record of the output file.
__toolname__ = "map2reg"
__revision__ = "28 June 2024"

# Module logger plus shortcuts to its per-verbosity output routines.
__lgr__ = lw.initialize_logger(__toolname__)
verb0 = __lgr__.verbose0
verb1 = __lgr__.verbose1
verb2 = __lgr__.verbose2
verb3 = __lgr__.verbose3
verb5 = __lgr__.verbose5

# Module-level state read by get_region().  main() assigns both before
# get_region() is run through parallel_map, which avoids having to pass
# the image and file name along with every map value.
IMG = None      # set in main() to the MaskedIMAGECrate of the input map
INFILE = None   # set in main() to the infile parameter value


def set_nproc(pars):
    """Normalize pars["nproc"] to an integer, updating pars in place.

    - parallel is "no"  : nproc is forced to 1
    - nproc is "INDEF"  : nproc becomes multiprocessing.cpu_count()
    - anything else     : nproc is converted with int()

    Nothing is returned.
    """

    if pars["parallel"] == "no":
        pars["nproc"] = 1
    elif pars["nproc"] == "INDEF":
        # parallel_map doesn't limit to number of CPU's like
        # taskRunner does so we have to set the actual number
        import multiprocessing
        pars["nproc"] = multiprocessing.cpu_count()
    else:
        pars["nproc"] = int(pars["nproc"])


def get_region(n_at):
    """
    Create the region for the map value n_at.

    The first pixel of the (global) image whose value equals n_at is
    used as the starting position for dmimglasso, which is run on the
    (global) input file with the accepted pixel values limited to
    n_at - 0.5 through n_at + 0.5.

    The lasso only works for contiguous regions.

    Returns None when no pixel has the value n_at, otherwise the string
    representation of the region produced by dmimglasso.  A warning is
    logged when that region is empty.
    """
    verb2(n_at)

    matches = np.argwhere(IMG.get_image().values == n_at)
    if len(matches) == 0:
        return None

    # argwhere rows are (y, x) array indexes; image coords are +1 from
    # the array index.
    ypos, xpos = matches[0] + 1

    workdir = os.environ["ASCDS_WORK_PATH"]
    with NamedTemporaryFile(dir=workdir, suffix="lasso.reg",
                            delete=True) as scratch:

        lasso = make_tool("dmimglasso")
        lasso.infile = INFILE
        lasso.outfile = scratch.name
        lasso.clobber = True
        lasso.low_value = n_at - 0.5
        lasso.high_value = n_at + 0.5
        lasso.value = "absolute"
        lasso.coord = "logical"
        lasso.xpos = xpos
        lasso.ypos = ypos
        lasso()

        # Must be read back before the temporary file is removed
        lasso_reg = CXCRegion(scratch.name)

    region_is_empty = lasso_reg is None or len(lasso_reg) == 0
    if region_is_empty:
        verb0(f"Warning: {n_at} map is empty")

    # CXCRegion object's can't be pickled so the string representation
    # is what gets passed back.
    return str(lasso_reg)


def write_region_output(regions, outfile, stdhdr=None, clobber=True):
    """Write the regions to outfile as a region table.

    Unfortunately between memory usage and strlen, resort to home
    grown output routine.

    The table has one row per shape with the columns X, Y, SHAPE, R,
    ROTANG and COMPONENT, and is written as a block named REGION.

    Parameters
    ----------
    regions
        Sequence of region strings (it is iterated several times, so
        it must not be a one-shot iterator).  Entries that are None or
        have zero length are skipped.
    outfile
        Name of the file to write.
    stdhdr
        Optional standard header; when set it is applied with
        put_stdhdr before the region keywords are added.
    clobber
        Passed through to the write call of the output table.

    Raises
    ------
    ValueError
        When no shape is found in any of the regions, so there is
        nothing to write.
    """

    def _unpack_regions(col):
        "unpack values from region object"
        retval = []
        for reg in regions:
            if reg is None or len(reg) == 0:
                continue
            for shape in CXCRegion(reg).shapes:
                vals = getattr(shape, col)
                if vals is None:
                    vals = np.nan
                retval.append(vals)
        return retval

    def _make_output_shapes():
        "Set the output shape name based on inclusion flag"
        shapes = _unpack_regions("name")
        incl = _unpack_regions("include")
        return [f"{inc.str}{shape}" for shape, inc in zip(shapes, incl)]

    def _make_output_component():
        "Create component values"
        # Running total of each shape's logic value.
        # NOTE(review): presumably val is non-zero only for a shape that
        # starts a new component - confirm against the region module.
        logic = _unpack_regions("logic")
        retval = []
        val = 0
        for lgc in logic:
            val += lgc.val
            retval.append(val)
        return retval

    def _pad_to_max(col):
        "nan pad arrays to max length"
        if not col:
            # max() on an empty sequence would raise a ValueError with
            # an unhelpful message; say what actually went wrong.
            raise ValueError("No region shapes to write: "
                             "every input region is empty")
        maxlen = max(len(xx) for xx in col)
        retval = []
        for val in col:
            tmp = [np.nan]*maxlen
            tmp[0:len(val)] = val
            retval.append(tmp)
        return retval

    def _update_columns():
        'Add units and comments to columns'
        coldefs = {'X': {'unit': 'pixel', 'desc': 'Position'},
                   'Y': {'unit': 'pixel', 'desc': 'Position'},
                   'R': {'unit': 'pixel', 'desc': 'Radius'},
                   'ROTANG': {'unit': 'deg', 'desc': 'Angle'},
                   'SHAPE': {'unit': '', 'desc': 'Region shape type'},
                   'COMPONENT': {'unit': '', 'desc': 'Component number'},
                   }
        for colname, info in coldefs.items():
            col = tab.get_column(colname)
            col.unit = info['unit']
            col.desc = info['desc']

    def _update_header():
        'Add standard headers'

        def _get_version():
            'Get version number'
            from ciao_version import get_ciao_version
            number = get_ciao_version()
            retval = f"CIAO {number}"
            return retval

        if stdhdr:
            tab.put_stdhdr(stdhdr)

        # CREATOR is the standard FITS keyword naming the software that
        # wrote the file.
        keywords = (("MTYPE1", "POS"), ("MFORM1", "X,Y"),
                    ("CONTENT", "REGION"),
                    ("HDUDOC", "ASC-FITS-REGION-1.2: Rots, McDowell: FITS REGION Binary Table Design"),
                    ("HDUVERS", "1.2.0"), ("HDUCLASS", "ASC"),
                    ("HDUCLAS1", "REGION"), ("HDUCLAS2", "STANDARD"),
                    ("CREATOR", __toolname__), ("ASCDSVER", _get_version()),
                    )

        from pycrates import set_key
        for kwrd, val in keywords:
            set_key(tab, kwrd, val)
        tab.name = "REGION"

    # Each column is unpacked separately; the polygon vertex lists can
    # differ in length so they are nan-padded to a common width.
    xpos = _pad_to_max(_unpack_regions("xpoints"))
    ypos = _pad_to_max(_unpack_regions("ypoints"))
    out_shapes = _make_output_shapes()
    component = _make_output_component()
    radii = _unpack_regions('radii')
    angle = _unpack_regions('angles')

    from crates_contrib.utils import make_table_crate

    # The nested helpers above look up tab when they are called, so it
    # must be assigned before _update_columns/_update_header run.
    col_names = ['X', 'Y', 'SHAPE', 'R', 'ROTANG', 'COMPONENT']
    tab = make_table_crate(np.array(xpos), np.array(ypos), out_shapes,
                           np.array(radii), np.array(angle),
                           np.array(component), colnames=col_names)
    _update_columns()
    _update_header()
    tab.write(outfile, clobber=clobber)


@lw.handle_ciao_errors(__toolname__, __revision__)
def main():
    """
    Tool entry point.

    Reads the parameters, loads the map image, zeroes the pixels the
    masked image reports as not valid, creates one region per positive
    map value and writes them, plus a history record, to the outfile.
    """

    # Parameters (this also sets the verbosity)
    from ciao_contrib.param_soaker import get_params
    pars = get_params(__toolname__, "rw", sys.argv,
                      verbose={"set": lw.set_verbosity, "cmd": verb1})

    from ciao_contrib._tools.fileio import outfile_clobber_checks
    outfile_clobber_checks(pars["clobber"], pars["outfile"])

    set_nproc(pars)

    # The map is stored in the module globals so get_region can see it
    from crates_contrib.masked_image_crate import MaskedIMAGECrate
    global IMG, INFILE       # pylint: disable=global-statement
    INFILE = pars["infile"]
    IMG = MaskedIMAGECrate(INFILE)

    # Pixels that are not valid are set to 0 so they get dropped below
    mapvals = IMG.get_image().values
    nrows, ncols = mapvals.shape
    for row, col in np.ndindex(nrows, ncols):
        if IMG.valid(col, row):
            continue
        mapvals[row][col] = 0
    IMG.get_image().values = mapvals

    # Only the positive map values are turned into regions
    labels = np.unique(mapvals)
    labels = labels[labels > 0]

    from sherpa.utils import parallel_map
    outreg = parallel_map(get_region, labels,
                          numcores=int(pars["nproc"]))

    write_region_output(outreg, pars["outfile"],
                        stdhdr=IMG.get_stdhdr("BASIC", add_missing=False),
                        clobber=pars["clobber"])

    from ciao_contrib.runtool import add_tool_history
    add_tool_history(pars["outfile"], __toolname__, pars,
                     toolversion=__revision__)


# Run the tool only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
