#!/usr/bin/env python
#
# Copyright (C) 2022-2023 Smithsonian Astrophysical Observatory
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
#

"Assign overlapping region area based on selected metric"

import os
import sys

from region import CXCRegion
from pycrates import read_file, IMAGECrate
from ciao_contrib.runtool import make_tool
import ciao_contrib.logger_wrapper as lw


# Tool identity: used for the logger name, error reports and the
# history record written to each output file.
toolname = "rank_roi"
__revision__ = "09 March 2023"

# Create the tool's logger once, then alias the per-verbosity printers
# so the rest of the file can call verb0/verb1/verb2 directly.
lw.initialize_logger(toolname)
lgr = lw.get_logger(toolname)
verb0, verb1, verb2 = lgr.verbose0, lgr.verbose1, lgr.verbose2

"""
The way the roi tool works, specifically when using

    group=exclude

is that all overlapping area is ignored.  So if file 1 has:
circleA - EllipseB, then there will be another file with
EllipseB - circleA.  The overlap area between A and B is omitted.

The goal of this script is to assign that overlap area to the
"best" source.  Here we define "best" as the source with the largest
number of counts.

    Note: "best" could also be defined as the faintest, biggest,
    smallest, roundest, ... limitless options

We know that the 1st shape in the SRCREG extension is "the" shape
from which any overlapping shapes are excluded and FOV's intersected.

Thankfully, the way roi works, this 1st shape is identically
replicated in other roi files where it is excluded.  So
we know that circleA in both files is exactly the same (no
string truncation nonsense).
"""


class ROIFile():
    """Hold the contents of a single ROI output file.

    Subclasses implement `compute_metric`, which stores a number in
    `self.metric`.  When two ROIs overlap, `remove_from` gives the
    overlap area to the one with the larger metric.

    Attributes set by `__init__`:
      filename  - name of the input roi file
      region    - CXCRegion loaded from `filename`
      roi       - first shape of `region`; the shape-of-interest
      metric    - ranking value; None until `compute_metric` is run
      del_other - 0-based indices into `region` of the shapes to drop
                  when the output file is written
      dmstat    - dmstat tool object; None until `run_dmstat` is run
      regionid  - integer value of the REGIONID header keyword
      outfile   - `out_template` formatted with `regionid`
    """

    def __init__(self, filename, out_template):
        """init class from a single roi file

        The 1st shape, `self.region[0]`, is the shape-of-interest
        for this ROIFile object.

        `out_template` is a str.format template; it is filled with
        the integer region ID to produce the output file name.
        """
        self.filename = filename
        self.region = CXCRegion(filename)
        self.roi = self.region[0]
        self.metric = None
        self.del_other = []
        self.dmstat = None
        _cr = read_file(filename)

        self.regionid = self._parse_regionid(_cr.get_key_value("REGIONID"))
        self.outfile = out_template.format(self.regionid)

    @staticmethod
    def _parse_regionid(value):
        """Convert the value of the REGIONID keyword to an int.

        Raises ValueError if the keyword value is None or is not an
        integer.
        """
        # NOTE(review): assumes get_key_value returns None when the
        # keyword is absent -- confirm against pycrates.
        if value is None:
            raise ValueError("REGIONID keyword is missing from roi file")

        # int() accepts leading zeros ("0012" -> 12).  Stripping the
        # zeros by hand would turn "0000" into "" and make int() fail.
        return int(value)

    def run_dmstat(self, propfile):
        """Run dmstat on `propfile` filtered by the shape-of-interest

        This is common to most of the metrics.  The tool object is
        kept in `self.dmstat` so callers can read its output
        parameters afterwards."""
        self.dmstat = make_tool("dmstat")
        self.dmstat.infile = f"{propfile}[sky={self.roi}]"
        self.dmstat(centroid=False)

    def compute_metric(self, propfile):
        """Compute some metric for shape of interest"""
        raise NotImplementedError("Implement in the subclasses")

    def remove_from(self, other):
        """Determine if self should get overlap area or
        release it to another shape

        The easiest way to do this is to identify which indices
        the other shape is located at and translate that into
        a DM #row filter.

        If self wins, the indices of the excluded copies of
        `other.roi` in `self.region` are appended to
        `self.del_other`; otherwise nothing is changed.
        """
        # skip self
        if other.roi == self.roi:
            return

        # skip if other is bigger
        if other.metric > self.metric:
            return

        # if metric is equal then only keep one (either): if other has
        # already claimed the overlap with self, leave it with other.
        if other.metric == self.metric:
            myidx = other.region.index(~self.roi)
            if any(x in other.del_other for x in myidx):
                return

        # remove all instances of other from self.  With fov files
        # the roi file may have multiple instances of other.roi, be
        # sure to remove them all.
        idx = self.region.index(~other.roi)
        self.del_other.extend(idx)

    def write(self, pars):
        """Write out a new roi file by copying selected rows from
        the input.

        This turns out to be a lot easier than trying to copy
        header, history, wcs's from input to output.

        `pars` is the tool's parameter dictionary: `pars["clobber"]`
        is handed to dmcopy and the whole dictionary is recorded in
        the output file's history.
        """
        dmcopy = make_tool("dmcopy")

        if self.del_other:
            # +1: array index -> dm row number
            rr = ", ".join([str(x+1) for x in self.del_other])
            ff = "{}[exclude #row={}]"
            ff = ff.format(self.filename, rr)
        else:
            # Nothing to drop: copy the input file unfiltered.
            ff = self.filename

        # Not worried about dmcopy bug here since not copying
        # event files (no GTIs to get messed up).
        dmcopy(ff, self.outfile, clobber=pars["clobber"], opt="all")

        from ciao_contrib.runtool import add_tool_history
        add_tool_history(self.outfile, toolname, pars,
                         toolversion=__revision__)


class MostFlux(ROIFile):
    """Pick region with the largest sum of pixel values.

    The pixel values can be counts (or could be, say, photon flux).
    """

    def compute_metric(self, propfile):
        """Set the metric to the dmstat sum inside the region."""
        self.run_dmstat(propfile)
        total = self.dmstat.out_sum
        self.metric = float(total)


class LeastFlux(ROIFile):
    """Pick region with the smallest sum of pixel values.

    The pixel values can be counts (or could be, say, photon flux).
    The sum is negated so that the smallest sum gives the largest
    metric.
    """

    def compute_metric(self, propfile):
        """Set the metric to minus the dmstat sum inside the region."""
        self.run_dmstat(propfile)
        total = float(self.dmstat.out_sum)
        self.metric = -1.0 * total


class BrightestPixel(ROIFile):
    """Pick region with the brightest pixel"""

    def compute_metric(self, propfile):
        """Set the metric to the dmstat maximum inside the region."""
        self.run_dmstat(propfile)
        peak = self.dmstat.out_max
        self.metric = float(peak)


class FaintestPixel(ROIFile):
    """Pick region whose maximum pixel is the faintest

    The maximum is negated so that the faintest peak gives the
    largest metric.
    """

    def compute_metric(self, propfile):
        """Set the metric to minus the dmstat maximum in the region."""
        self.run_dmstat(propfile)
        peak = float(self.dmstat.out_max)
        self.metric = -1.0 * peak


class LargestArea(ROIFile):
    """Pick region with the most area"""

    def compute_metric(self, propfile):
        """Set the metric to the area of the shape-of-interest.

        `propfile` is not used; it is accepted so that all metrics
        share the same call signature.
        """
        area = self.roi.area()
        self.metric = area


class SmallestArea(ROIFile):
    """Pick region with the least area

    The area is negated so that the smallest shape gives the largest
    metric.
    """

    def compute_metric(self, propfile):
        """Set the metric to minus the area of the shape-of-interest.

        `propfile` is not used; it is accepted so that all metrics
        share the same call signature.
        """
        area = self.roi.area()
        self.metric = -1.0 * area


def map_metric(func):
    """Return the ROIFile subclass selected by the method name `func`.

    Raises KeyError when `func` is not one of the known names.
    """
    by_name = {
        'min': LeastFlux,
        'max': MostFlux,
        'bright': BrightestPixel,
        'faint': FaintestPixel,
        'big': LargestArea,
        'small': SmallestArea,
    }
    return by_name[func]


def load_rois(pars):
    """Create one metric object for each file in the roifiles stack.

    The class used is chosen by pars["method"]; each object is given
    the file name and the pars["outfile"] template.  Returns the
    objects as a list, in stack order.
    """
    verb1("Loading ROI files")
    import stk
    roi_names = stk.build(pars["roifiles"])

    metric_class = map_metric(pars["method"])
    loaded = []
    for roi_name in roi_names:
        loaded.append(metric_class(roi_name, pars["outfile"]))

    verb1("{} ROI files parsed".format(len(loaded)))
    return loaded


def clobber_outfiles(pars, rois_reg):
    """Run the clobber check on the output file of every ROI.

    Side effect: pars["clobber"] is replaced by a boolean that is
    True only when the parameter was the string "yes".
    """
    verb1("Checking clobber")
    from ciao_contrib._tools.fileio import outfile_clobber_checks

    may_clobber = (pars["clobber"] == "yes")
    pars["clobber"] = may_clobber
    for roi_obj in rois_reg:
        outfile_clobber_checks(may_clobber, roi_obj.outfile)


def get_binning(rois_reg):
    """Return a DM binning filter that covers every region.

    The bounding box is the min/max over the extent of each
    shape-of-interest.  Each edge is truncated to an integer and then
    pushed outwards by 1.5; the bin size is 1 on both axes.
    """
    boxes = [item.roi.extent() for item in rois_reg]

    def _edge(pick, key, pad):
        # Truncate the extreme value with int(), then apply the padding.
        return int(pick(box[key] for box in boxes)) + pad

    xlo = _edge(min, 'x0', -1.5)
    ylo = _edge(min, 'y0', -1.5)
    xhi = _edge(max, 'x1', 1.5)
    yhi = _edge(max, 'y1', 1.5)

    return f'[bin x={xlo}:{xhi}:1,y={ylo}:{yhi}:1]'


def compute_metric(pars, rois_reg):
    """Compute the metric of every ROI from pars["infile"].

    When the infile is not an image, a binning filter covering all
    the regions is appended to its name before the metrics are
    computed.
    """
    verb1("Checking infile (image v. table)")
    propfile = pars["infile"]
    crate = read_file(propfile)
    if not isinstance(crate, IMAGECrate):
        verb1("Infile is a table, will bin into an image")
        propfile = propfile + get_binning(rois_reg)

    verb1("Computing metric for all ROIs")
    for roi_obj in rois_reg:
        roi_obj.compute_metric(propfile)
        verb2("region {}\tmetric {}".format(roi_obj.regionid,
                                            roi_obj.metric))


def pick_winner(rois_reg):
    """Decide, for every ordered pair of ROIs, who gets the overlap.

    Each ROI is offered every ROI in the list (itself included; the
    method skips that case) through its `remove_from` method.
    """
    verb1("Determining which ROIs should get overlap area")
    for rival in rois_reg:
        for claimant in rois_reg:
            claimant.remove_from(rival)


def write_outputs(pars, rois_reg):
    """Write the new roi file for every ROI, in list order."""
    verb1("Writing out new roi files")
    for roi_obj in rois_reg:
        roi_obj.write(pars)


@lw.handle_ciao_errors(toolname, __revision__)
def main():
    """Read the parameters, rank the ROIs and write the new files."""

    from ciao_contrib.param_soaker import get_params

    # Hooks handed to get_params for the verbose parameter.
    verbose_hooks = {"set": lw.set_verbosity, "cmd": verb1}
    pars = get_params(toolname, "rw", sys.argv, verbose=verbose_hooks)

    rois_reg = load_rois(pars)
    # Check the outputs before doing any of the expensive work.
    clobber_outfiles(pars, rois_reg)
    compute_metric(pars, rois_reg)
    pick_winner(rois_reg)
    write_outputs(pars, rois_reg)


if __name__ == "__main__":
    # Script entry point: exit 0 on success, or report the error on
    # stderr and exit 1.
    try:
        main()
    except Exception as err:
        report = "\n# "+toolname+" ("+__revision__+"): ERROR "+str(err)+"\n"
        print(report, file=sys.stderr)
        sys.exit(1)
    else:
        sys.exit(0)
