#!/usr/bin/env python

# Copyright (C) 2024 Smithsonian Astrophysical Observatory
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
#

'Compute fraction of PSF in neighboring regions'

import os
import sys
from tempfile import NamedTemporaryFile, TemporaryDirectory

import numpy as np

import ciao_contrib.logger_wrapper as lw
from ciao_contrib.runtool import make_tool
from pycrates import IMAGECrate, CrateData
import paramio as pio


# Tool identity: passed to the logger, the parameter loader, the error
# handler and the history record written to the output file.
__toolname__ = "mkrprm"
__revision__ = "28 August 2024"

# Initialize logging once, then expose one emitter per verbosity level.
lw.initialize_logger(__toolname__)
_LOGGER = lw.get_logger(__toolname__)
VERB0 = _LOGGER.verbose0
VERB1 = _LOGGER.verbose1
VERB2 = _LOGGER.verbose2


class Convolution():
    '''Base class for two types of convolution

    Derived classes implement setup() to prepare whatever their method
    needs and __call__() to convolve one single-region map crate.
    '''

    def __init__(self, pars):
        'Constructor: keep the tool parameters in self.pars'
        self.pars = pars

    def setup(self):
        'Setup routines'
        raise NotImplementedError("Implement in derived class")

    def __call__(self, mapcrate):
        'Perform the convolution'
        raise NotImplementedError("Implement in derived class")

    def __del__(self):
        '''Cleanup when done

        Deliberately a no-op in the base class.  An exception raised
        from a finalizer cannot propagate (Python only prints an
        "Exception ignored" message), so raising here would just add
        noise for any subclass that has nothing to clean up.
        '''


class PSFMapConvolution(Convolution):
    'Convolve with dmimgadapt using PSFMap for function size'

    def __init__(self, pars):
        '''Store the parameters and build the PSF map

        psfmap is set to None before setup() runs so that __del__ can
        tell whether the map was actually created.
        '''
        super().__init__(pars)
        self.psfmap = None
        self.setup()

    def setup(self):
        '''Create the PSFMap

        Create the PSF map. Must be in units of logical pixel size.
        The map is written to a temporary file in pars["tmpdir"] and
        kept, as a crate, in self.psfmap.
        '''
        VERB2("Started make_psfmap")

        # Only the unique name is needed, so close the handle right
        # away.  delete=False leaves the (empty) file in place for
        # mkpsfmap to clobber.
        with NamedTemporaryFile(dir=self.pars["tmpdir"],
                                suffix="_psf.map", delete=False) as tmpfile:
            psfname = tmpfile.name

        mkpsfmap = make_tool("mkpsfmap")

        mkpsfmap.infile = self.pars["infile"]
        mkpsfmap.outfile = psfname
        mkpsfmap.energy = self.pars["energy"]
        mkpsfmap.spectrum = ""
        mkpsfmap.ecf = self.pars["ecf"]
        mkpsfmap.units = "logical"
        verb = mkpsfmap(clobber=True)
        if verb:
            VERB2(verb)

        psffile = IMAGECrate(mkpsfmap.outfile)

        VERB2("Done mkpsfmap")

        # Clean up the pixel values: non-finite values become 0, sizes
        # are truncated to a multiple of 1/20 pixel (the int32
        # conversion drops the fractional part) and negative values
        # become 0.
        vals = psffile.get_image().values

        vals[~np.isfinite(vals)] = 0
        vals = vals * 20.0
        vals = np.int32(vals)/20.0
        vals[vals < 0] = 0
        psffile.get_image().values = vals

        # Save the cleaned map back over the mkpsfmap output
        psffile.write(psffile.get_filename(), clobber=True)

        VERB2("Done make_psfmap")

        self.psfmap = psffile

    def __call__(self, mapcrate):
        '''Convolve the mapID crate w/ PSF map

        Smooth map file with PSF map using dmimgadapt.

        The image in mapcrate is first normalized to sum to 1; this
        modifies mapcrate and rewrites its file.  The smoothed image is
        written to a new temporary file in pars["tmpdir"] and returned
        as an IMAGECrate.
        '''

        # Only the name is needed; dmimgadapt clobbers the file.
        with NamedTemporaryFile(dir=self.pars["tmpdir"],
                                suffix="_out.map", delete=False) as tmpout:
            outname = tmpout.name

        # Normalize to sum=1.0
        vals = mapcrate.get_image().values
        total = np.nansum(vals)
        vals = vals / total
        mapcrate.get_image().values = vals
        mapcrate.write(mapcrate.get_filename(), clobber=True)

        abin = make_tool("dmimgadapt")
        abin.infile = mapcrate.get_filename()
        abin.outfile = outname
        abin.function = self.pars["function"]
        abin.inradfile = self.psfmap.get_filename()
        abin.counts = 1   # Counts don't matter with inradfile specified
        vv = abin(clobber=True)

        if vv:
            VERB2(vv)

        retvals = IMAGECrate(outname)

        return retvals

    def __del__(self):
        'Remove the temporary PSF map file, if one was created'
        # psfmap is still None (or missing) when construction failed
        # part way through; there is nothing to remove in that case.
        psfmap = getattr(self, "psfmap", None)
        if psfmap is not None:
            delete_crate(psfmap)


class MarxConvolution(Convolution):
    'Convolve with marx; ie simulate w/ marx'

    def __init__(self, pars):
        '''Store the parameters and create the marx parameter file

        marxpar is set to None before setup() runs so that __del__ can
        tell whether the parameter file was actually created.
        '''
        super().__init__(pars)
        self.marxpar = None
        self.setup()

    def setup(self):
        '''Generate marx.par file

        Run marx with a point source model, low flux, just to get
        marx.par.  The parameter file is copied to a temporary file in
        pars["tmpdir"] whose name is kept in self.marxpar.

        Raises RuntimeError if simulate_psf does not leave exactly one
        "*_marx.par" file in its output directory.
        '''
        from glob import glob
        import shutil as shu

        # Only the unique name is needed, so close the handle right
        # away; delete=False keeps the placeholder file on disk.
        with NamedTemporaryFile(dir=self.pars["tmpdir"],
                                suffix="_marx.par", delete=False) as tmpfile:
            parname = tmpfile.name

        # Record the name now so __del__ removes the placeholder even
        # if one of the steps below fails.
        self.marxpar = parname

        simulate_psf = make_tool("simulate_psf")
        simulate_psf.infile = self.pars["infile"]

        img = IMAGECrate(self.pars["infile"], "r")
        simulate_psf.ra = img.get_key_value("RA_PNT")
        simulate_psf.dec = img.get_key_value("DEC_PNT")
        simulate_psf.spectrumfile = ""
        simulate_psf.monoenergy = self.pars["energy"]
        simulate_psf.flux = 1.0e-6  # any small value will do
        simulate_psf.keep = True
        simulate_psf.marx_root = self.pars["marx_root"]
        simulate_psf.random_seed = self.pars["random_seed"]

        with TemporaryDirectory(suffix="_marx",
                                dir=self.pars["tmpdir"]) as marxdir:
            simulate_psf.outroot = os.path.join(marxdir, "sim")
            simulate_psf()
            outfile = glob(os.path.join(marxdir, "*_marx.par"))
            # Explicit check rather than assert: asserts are removed
            # when Python runs with -O.
            if len(outfile) != 1:
                raise RuntimeError(
                    "There should only be one marx parameter file")
            os.unlink(parname)
            shu.copy(outfile[0], parname)

    def _pad_input_to_center_image(self, mapcrate, img_grid):
        '''MARX requires the center of the image to be the aimpoint,
        need to pad image to make that happen

        img_grid is a binning specification of the form
        "x=lo:hi:#bin,y=lo:hi:#bin".  Each axis is widened symmetrically
        about the source position stored in the marx parameter file,
        keeping the original bin size.  The re-binned image is written
        to a temporary file in pars["tmpdir"] and returned as an
        IMAGECrate.
        '''

        dmcoords = make_tool("dmcoords")
        dmcoords.infile = mapcrate.get_filename()
        dmcoords.option = "cel"
        dmcoords.celfmt = "deg"

        # These values are pset in the setup to be equal to PNT values
        dmcoords.ra = pio.pget(self.marxpar, "SourceRA")
        dmcoords.dec = pio.pget(self.marxpar, "SourceDEC")
        dmcoords()

        xcenter = float(dmcoords.x)
        ycenter = float(dmcoords.y)

        def split_grid(center, axis):
            'Split the grid into parts'
            # x=lo:hi:#bin,y=lo:hi:#bin
            vals = img_grid.split(",")[axis].split("=")[1].split(":")

            lo = float(vals[0])
            hi = float(vals[1])
            no = int(vals[2].lstrip("#"))
            # Half-width that reaches the farther of the two edges, so
            # the whole original axis range stays inside the new grid.
            axlen = max(abs(center-lo), abs(center-hi))
            axbin = (hi - lo) / no
            ngrid = f"{center-axlen}:{center+axlen}:{axbin}"
            return ngrid

        xgrid = split_grid(xcenter, 0)
        ygrid = split_grid(ycenter, 1)

        newgrid = f"x={xgrid},y={ygrid}"

        # Only the name is needed; dmcopy clobbers the file.
        with NamedTemporaryFile(dir=self.pars["tmpdir"],
                                suffix="_padded.img", delete=False) as tmpimg:
            padname = tmpimg.name

        dmcopy = make_tool("dmcopy")
        dmcopy(f"{mapcrate.get_filename()}[bin {newgrid}]", padname,
               clobber=True)
        retval = IMAGECrate(padname)
        return retval

    def __call__(self, mapcrate):
        '''Simulate, aka convolve

        Run marx using the (padded) map image as an IMAGE source, turn
        the result into an event file with marx2fits, bin the events
        back onto the grid of the input map and normalize the image to
        sum to 1.  The image is written to a temporary file in
        pars["tmpdir"] and returned as an IMAGECrate.
        '''

        import subprocess as sp

        # delete=True: the handle is kept for the rest of this method
        # so the event file stays on disk until the image has been made.
        tmpevt = NamedTemporaryFile(dir=self.pars["tmpdir"],
                                    suffix="_marx.fits", delete=True)

        # Only the name is needed; dmcopy clobbers the file.
        with NamedTemporaryFile(dir=self.pars["tmpdir"],
                                suffix="_marx.img", delete=False) as tmpimg:
            imgname = tmpimg.name

        gsl = make_tool("get_sky_limits")
        gsl(mapcrate.get_filename())
        grid = gsl.dmfilter

        padded_image = self._pad_input_to_center_image(mapcrate, grid)

        try:
            pio.pset(self.marxpar, "SourceType", "IMAGE")
            pio.pset(self.marxpar, "S-ImageFile", padded_image.get_filename())
            pio.pset(self.marxpar, "SourceFlux", "1.0e-2")

            with TemporaryDirectory(suffix="_marx",
                                    dir=self.pars["tmpdir"]) as marxdir:
                pio.pset(self.marxpar, "OutputDir", marxdir)

                marx = os.path.join(self.pars["marx_root"], "bin", "marx")
                cmd = [marx, f"@@{self.marxpar}"]
                sp.check_output(cmd)

                m2f = os.path.join(self.pars["marx_root"], "bin", "marx2fits")
                cmd = [m2f, marxdir, tmpevt.name]
                sp.check_output(cmd)
        finally:
            # Remove the padded image even when marx/marx2fits fails
            delete_crate(padded_image)

        dmcopy = make_tool("dmcopy")
        dmcopy(f"{tmpevt.name}[bin {grid}][opt type=r4]", imgname, clobber=True)

        # Now need to normalize to 1

        retval = IMAGECrate(imgname, "rw")
        vals = retval.get_image().values*1.0
        vals = vals / np.nansum(vals)
        retval.get_image().values = vals
        retval.write()  # Update so vals are normalized

        return retval

    def __del__(self):
        'Cleanup: remove the temporary marx parameter file, if any'
        # marxpar is still None (or missing) when construction failed
        # before the temporary file was created.
        marxpar = getattr(self, "marxpar", None)
        if marxpar is not None and os.path.exists(marxpar):
            os.unlink(marxpar)


def make_id_map(pars):
    '''Convert input stack of regions to map file

    Runs mkregmap on pars["infile"] with pars["regions"], writing the
    map to a new temporary file in pars["tmpdir"].  The map is returned
    as an IMAGECrate; the temporary file is left on disk for the caller
    to remove.
    '''

    VERB2("Started make_id_map")

    # Only the unique name is needed, so close the handle right away.
    # delete=False leaves the (empty) file in place for mkregmap to
    # clobber.
    with NamedTemporaryFile(dir=pars["tmpdir"], suffix="_reg.map",
                            delete=False) as tmpfile:
        mapname = tmpfile.name

    mkregmap = make_tool("mkregmap")
    mkregmap.infile = pars["infile"]
    mkregmap.regions = pars["regions"]
    mkregmap.outfile = mapname
    verb = mkregmap(clobber=True)

    if verb:
        VERB2(verb)

    mapfile = IMAGECrate(mkregmap.outfile)

    VERB2("Done make_id_map")
    return mapfile


def get_crate_from_map_with_id(mapfile, mapidx, pars):
    '''Create a crate with image w/ single mapid value

    Pixels of mapfile equal to mapidx become 1 and all other pixels
    become 0.  The result is written to a new temporary file in
    pars["tmpdir"] and returned as an IMAGECrate read back from that
    file; mapfile itself is not modified.
    '''

    # Only the unique name is needed, so close the handle right away
    # (delete=False keeps the file; the crate write clobbers it).
    with NamedTemporaryFile(dir=pars["tmpdir"], suffix="_id.map",
                            delete=False) as tmpout:
        outname = tmpout.name

    # Multiply by 1 to get a copy, so the map crate stays untouched
    vals = mapfile.get_image().values * 1

    # Values not equal to ID are set to 0
    vals[vals != mapidx] = 0

    # Values equal to ID are set to 1
    vals[vals > 0] = 1

    # Re-read the map file so the output is based on the original
    # crate, then swap in the mask and save to the new file
    crate = IMAGECrate(mapfile.get_filename(), "r")
    crate.get_image().values = vals
    crate.write(outname, clobber=True)

    return IMAGECrate(outname)   # mapidx_file


def delete_crate(crate):
    '''Helper to delete temp crates

    Removes the file that backs the crate.  Nothing happens when crate
    is None, when it has no file name, or when the file is already
    gone, so the helper is safe to call from cleanup code.
    '''
    if crate is None:
        return

    filename = crate.get_filename()
    if not filename:
        return

    # Try the removal directly instead of checking first: the file may
    # disappear between an existence check and the unlink.
    try:
        os.unlink(filename)
    except FileNotFoundError:
        pass


def run_inner_loop(args):
    '''Inner loop run for each ID

    args is the tuple (mapfile, outer_psf, inner, pars): the region ID
    map crate, the convolved single-region crate, the (positive) region
    ID to measure, and the tool parameters (unused here, kept so the
    argument tuple stays the same).

    Returns the NaN-ignoring sum of the convolved image over the pixels
    whose map value equals inner.
    '''

    _mapfile, _outer_psf, inner, _ = args

    # Build the 0/1 mask for this ID in memory; writing it to a
    # temporary file and reading it back for every pair of regions
    # gives the same numbers at a much higher cost.
    mask = np.where(_mapfile.get_image().values == inner, 1, 0)
    frac = np.nansum(_outer_psf.get_image().values * mask)
    return frac


def build_matrix(mapfile, response, pars):
    '''Main routine to compute PSF overlap in neighboring regions

    The basic algorithm is simple:

      Get a map with a single ID value
      Smooth it with PSF map
      For each mapID, compute the overlap, which is literally just
         the smoothed map multiplied by the ID map
      Save values

    mapfile is the region ID map crate, response is the convolution
    object called on each single-ID crate, and pars holds the tool
    parameters ("verbose" > 0 turns on the progress display).

    Returns a square array of size max(ID); element [i-1, j-1] is the
    fraction of region i found in region j after convolution.  Rows and
    columns of IDs that have no pixels in the map are left at 0.

    Raises ValueError if the map contains no positive ID.
    '''

    # Get list of ID's in mapfile (np.unique returns them sorted)
    idx_vals = [int(x) for x in np.unique(mapfile.get_image().values)
                if x > 0]
    if not idx_vals:
        raise ValueError("No regions with a positive ID were found in the map")

    # Setup output matrix
    max_id = max(idx_vals)
    matrix = np.zeros([max_id, max_id])

    # IDs need not be contiguous (an ID can be absent from the map), so
    # each row is filled by explicit column index rather than as a whole.
    columns = np.array(idx_vals) - 1

    show_progress = int(pars["verbose"]) > 0
    n_ids = len(idx_vals)

    for count, outer in enumerate(idx_vals, start=1):

        # Smooth ID with PSF
        outer_crate = get_crate_from_map_with_id(mapfile, outer, pars)
        outer_psf = None
        try:
            outer_psf = response(outer_crate)

            # Compute frac in neighbors in serial.
            # NOTE: a parallel version of this loop (using
            # sherpa.utils.parallel_map) was tried but some race
            # condition yields bad results.
            frac = [run_inner_loop((mapfile, outer_psf, inner, pars))
                    for inner in idx_vals]
        finally:
            # Remove the per-ID temporary files even on failure
            delete_crate(outer_crate)
            if outer_psf is not None:
                delete_crate(outer_psf)

        matrix[outer-1, columns] = frac

        if show_progress:
            percent = (100.0*count)/n_ids
            sys.stdout.write(f"\rPercent Complete: {percent:5.1f}%")
            # No newline is written, so flush to make the update visible
            sys.stdout.flush()

    if show_progress:
        sys.stdout.write("\n")

    return matrix


def save_matrix(matrix, pars):
    '''Save to outfile

    The matrix is stored as an image named MATRIX in pars["outfile"],
    and a history record for this tool is then added to that file.
    '''

    from ciao_contrib.runtool import add_tool_history

    outname = pars["outfile"]

    image = CrateData()
    image.name = "MATRIX"
    image.values = matrix

    outcrate = IMAGECrate()
    outcrate.add_image(image)

    # clobber is checked at the beginning of script
    outcrate.write(outname, clobber=pars["clobber"])

    add_tool_history(outname, __toolname__, pars,
                     toolversion=__revision__)


@lw.handle_ciao_errors(__toolname__, __revision__)
def main():
    '''Main routine

    Load the parameters, build the region ID map, compute the overlap
    matrix with the requested PSF method and write it to the output
    file.

    Raises ValueError if the psfmethod parameter is neither "map" nor
    "marx".
    '''

    from ciao_contrib.param_soaker import get_params
    from ciao_contrib._tools.fileio import outfile_clobber_checks

    # Load parameters
    pars = get_params(__toolname__, "rw", sys.argv,
                      verbose={"set": lw.set_verbosity, "cmd": VERB1},
                      revision=__revision__)

    outfile_clobber_checks(pars["clobber"], pars["outfile"])

    # Pick the convolution class up front so that a bad value is
    # reported before any temporary file has been created.
    methods = {"map": PSFMapConvolution, "marx": MarxConvolution}
    if pars["psfmethod"] not in methods:
        raise ValueError("Unknown value for 'psfmethod' parameter")

    mapfile = make_id_map(pars)
    try:
        response = methods[pars["psfmethod"]](pars)
        matrix = build_matrix(mapfile, response, pars)
        save_matrix(matrix, pars)
    finally:
        # Remove the temporary region map even if a step above failed
        delete_crate(mapfile)


# Run the tool only when this file is executed as a script, not when
# it is imported.
if __name__ == "__main__":
    main()
