#!/usr/bin/env python
#
# Copyright (C) 2019-2020, 2023
# Smithsonian Astrophysical Observatory
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
#

"""Adaptive bin by iterative centroid mapping"""

import sys
import os
import numpy as np
import ciao_contrib.logger_wrapper as lw
from pycrates import read_file
from crates_contrib.masked_image_crate import MaskedIMAGECrate

# Tool name: passed to the logger, the parameter loader, the error-handling
# decorator and the history writer below.
__TOOLNAME__ = "centroid_map"
# Revision string: passed to the error-handling decorator and recorded as
# the tool version when the history is written.
__REVISION__ = "13 January 2024"


# Logger for this tool, plus short aliases for its per-level methods.
# NOTE(review): presumably verboseN only emits when the tool verbosity is
# at least N -- confirm against ciao_contrib.logger_wrapper.
LGR = lw.initialize_logger(__TOOLNAME__)
VERB0 = LGR.verbose0
VERB1 = LGR.verbose1
VERB2 = LGR.verbose2
VERB3 = LGR.verbose3
VERB5 = LGR.verbose5


class CIAOTemporaryFile():
    """
    Named temporary file that is removed when the object is destroyed.

    The file is created in the directory given by the ``ASCDS_WORK_PATH``
    environment variable with ``delete=False``, so it is not removed
    when the handle is closed; it is removed explicitly in ``__del__``.

    Attributes
    ----------
    tmpfile
        The open ``NamedTemporaryFile`` handle (``None`` if creation failed).
    name
        Path of the temporary file (``None`` if creation failed).
    """

    def __init__(self, *args, **kwargs):
        """Create the temporary file.

        Extra positional and keyword arguments are forwarded to
        ``tempfile.NamedTemporaryFile``.

        Raises
        ------
        KeyError
            If ``ASCDS_WORK_PATH`` is not set in the environment.
        """
        from tempfile import NamedTemporaryFile

        # Define both attributes before anything can fail, so that __del__
        # is safe to run on a half-constructed object.
        self.tmpfile = None
        self.name = None

        self.tmpfile = NamedTemporaryFile(*args,
                                          dir=os.environ["ASCDS_WORK_PATH"],
                                          delete=False, **kwargs)
        self.name = self.tmpfile.name

    def __del__(self):
        """Close the handle and remove the file, if they exist."""
        # getattr guards against objects created without running __init__.
        handle = getattr(self, "tmpfile", None)
        if handle is not None:
            handle.close()

        path = getattr(self, "name", None)
        if path is None:
            return

        # Just try the removal: the file may already be gone, and an
        # exists()-then-remove() pair is open to a race.
        try:
            os.remove(path)
        except FileNotFoundError:
            pass

class InputImage():
    """
    Hold the input image together with per-pixel coordinate grids.

    Attributes
    ----------
    input_image
        The ``MaskedIMAGECrate`` the pixel values were read from; it is
        reused to write the output.
    imgvals
        Absolute pixel values as floats, with invalid pixels set to NaN,
        after the scaling function has been applied.
    xlen, ylen
        Number of columns and rows of ``imgvals``.
    xx, yy
        Integer arrays giving the column index and the row index of
        every pixel.
    wx, wy
        ``xx`` and ``yy`` multiplied by ``imgvals``.
    """

    def __init__(self, infile, scale):
        """Load ``infile``, mask invalid pixels and build the grids.

        Parameters
        ----------
        infile
            Image to load with ``MaskedIMAGECrate``.
        scale
            One of ``linear``, ``sqrt``, ``squared`` or ``asinh``.

        Raises
        ------
        RuntimeError
            If ``scale`` is not one of the supported names.
        """
        self.input_image = MaskedIMAGECrate(infile)
        self.imgvals = self.input_image.get_image().values.astype(float)
        self.imgvals = np.abs(self.imgvals)

        # Pixels the crate reports as not valid become NaN so that the
        # NaN-aware sums used on these values ignore them.
        for xx in range(self.imgvals.shape[1]):
            for yy in range(self.imgvals.shape[0]):
                if not self.input_image.valid(xx, yy):
                    self.imgvals[yy, xx] = np.nan

        func = self._map_scale_function(scale)
        self.imgvals = func(self.imgvals)

        self.xlen = self.imgvals.shape[1]
        self.ylen = self.imgvals.shape[0]

        # Column index (xx) and row index (yy) of every pixel
        self.xx, self.yy = self._coordinate_grids((self.ylen, self.xlen))

        # Weight imagevals by X and Y
        self.wx = self.xx * self.imgvals
        self.wy = self.yy * self.imgvals

    @staticmethod
    def _coordinate_grids(shape):
        """Return ``(xx, yy)`` index grids for a 2-D ``(ylen, xlen)`` shape.

        ``xx[j, i] == i`` and ``yy[j, i] == j``.  np.indices builds both
        arrays directly, without going through Python lists.
        """
        rows, cols = np.indices(shape)
        return cols, rows

    @staticmethod
    def _map_scale_function(scale):
        """Return the callable that implements the named scaling.

        Raises
        ------
        RuntimeError
            If ``scale`` is not a supported name.
        """

        def square(x):
            return x*x

        scale_functions = {
            'linear': np.abs,
            'sqrt': np.sqrt,
            'squared': square,
            'asinh': np.arcsinh,
        }

        try:
            return scale_functions[scale]
        except (KeyError, TypeError):
            # TypeError covers an unhashable value; report it the same way.
            raise RuntimeError(f"Unsupported scale value: {scale}") from None

    def write_new_sites(self, outvals, outfile):
        """Write ``outvals`` to ``outfile`` using the input image crate.

        The crate is renamed to ``centroid_map`` and its pixel values are
        replaced by ``outvals`` before it is written.
        """
        self.input_image.name = "centroid_map"
        self.input_image.get_image().values = outvals
        self.input_image.write(outfile, clobber="yes")


def centroid_map(mapfile, img, outfile):
    """Write a site image holding one labelled pixel per map cell.

    For every non-zero value in the map image, the centroid of the
    pixels carrying that value is computed, weighted by ``img.imgvals``
    (NaN values are ignored).  When the weights sum to zero the
    unweighted mean position is used instead.  The pixel containing the
    centroid (coordinates truncated with ``int``) is set to the cell
    value; all other output pixels are 0.

    Parameters
    ----------
    mapfile
        Image file whose pixel values label the cells; 0 is skipped.
    img
        ``InputImage`` providing ``imgvals``, ``xx``, ``yy``, ``wx``,
        ``wy`` and ``write_new_sites``.
    outfile
        Name of the site image to write.

    Raises
    ------
    ValueError
        If the map image and the input image differ in shape.
    """

    mapvals = read_file(mapfile).get_image().values

    # Raise explicitly rather than assert: assertions are removed when
    # Python runs with -O, and there is no point allocating the output
    # for mismatched inputs.
    if mapvals.shape != img.imgvals.shape:
        raise ValueError("Image sizes must match")

    outvals = np.zeros_like(mapvals)

    # Operate over map values
    for uu in np.unique(mapvals):
        if uu == 0:
            continue

        in_cell = (mapvals == uu)
        w = np.nansum(img.imgvals[in_cell])
        if w == 0:
            # If sum is 0, use unweighted value
            cx = np.average(img.xx[in_cell])
            cy = np.average(img.yy[in_cell])
        else:
            cx = np.nansum(img.wx[in_cell])/w
            cy = np.nansum(img.wy[in_cell])/w

        # int() truncates the centroid to the pixel index
        outvals[int(cy), int(cx)] = uu

    img.write_new_sites(outvals, outfile)


@lw.handle_ciao_errors(__TOOLNAME__, __REVISION__)
def main():
    """Tool entry point.

    Reads the parameters, runs vtbin once to get the starting map, then
    alternates centroid_map and vtbin for up to ``numiter`` iterations,
    stopping early when the map no longer changes.  The last map is
    copied to ``outfile`` and the tool history is added to it.
    """
    # Parameters
    from ciao_contrib.param_soaker import get_params
    pars = get_params(__TOOLNAME__, "rw", sys.argv,
                      verbose={"set": lw.set_verbosity, "cmd": VERB1})

    infile = pars["infile"]
    outfile = pars["outfile"]
    sitefile = pars["sitefile"]
    # An empty value or "none" (any case) means no site file was given
    if len(sitefile) == 0 or sitefile.lower() == "none":
        sitefile = None
    numiter = int(pars["numiter"])

    # Check the output against the clobber setting
    from ciao_contrib._tools.fileio import outfile_clobber_checks
    outfile_clobber_checks(pars["clobber"], outfile)

    # Input image and its coordinate grids
    img = InputImage(infile, scale=pars["scale"])

    from ciao_contrib.runtool import dmcopy
    from ciao_contrib.runtool import vtbin

    def map_values(path):
        'Return a copy of the pixel values stored in path'
        return read_file(path).get_image().values.copy()

    # Starting tessellation
    mapfile = CIAOTemporaryFile()
    vtbin(infile=infile, outfile=mapfile.name, site=sitefile, clobber=True)
    previous = map_values(mapfile.name)

    for step in range(numiter):
        VERB1(f"Working iteration {step}")

        # Centroid of every cell in the current map becomes a new site
        sitefile = CIAOTemporaryFile()
        centroid_map(mapfile.name, img, sitefile.name)

        # Tessellate again around the new sites
        mapfile = CIAOTemporaryFile()
        vtbin(infile=infile, outfile=mapfile.name, site=sitefile.name,
              clobber=True)

        # Count the pixels whose value changed since the previous map
        current = map_values(mapfile.name)
        ndiff = int(np.count_nonzero(np.not_equal(previous, current)))
        previous = current
        VERB2(f"Number of pixels different: {ndiff}")
        if ndiff == 0:
            VERB0(f"Converged at step {step}. Done.")
            break

        # At verbose 2 and above keep the site image of this iteration
        if int(pars["verbose"]) >= 2:
            dmcopy(sitefile.name, f"{outfile}.i{step:03d}", clobber=True)

    # Copy the last map to the requested output file
    dmcopy(f"{mapfile.name}[1][CENTROID_MAP]", outfile, clobber=True)

    # Record the tool and its parameters in the output
    from ciao_contrib.runtool import add_tool_history
    add_tool_history(outfile, __TOOLNAME__, pars, toolversion=__REVISION__)


# Run the tool only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
