#!/usr/bin/env python

#
# Copyright (C) 2015-2016, 2018, 2019
#               Smithsonian Astrophysical Observatory
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
#

import os
import sys

# Development-only convenience: when this script is not located under
# $ASCDS_INSTALL, put the neighbouring ../lib/pythonX.Y/site-packages
# directory near the front of the module search path (or warn when
# that directory does not exist).
try:
    if not __file__.startswith(os.environ['ASCDS_INSTALL']):
        _here = os.path.dirname(__file__)
        _pyver = "python{}.{}".format(sys.version_info.major,
                                      sys.version_info.minor)
        _sitepkg = os.path.normpath(
            os.path.join(_here, '../lib', _pyver, 'site-packages'))

        if not os.path.isdir(_sitepkg):
            print("*** WARNING: no {}".format(_sitepkg))
        else:
            sys.path.insert(1, _sitepkg)

        # Do not leave the scratch names behind in the module namespace.
        del _pyver, _sitepkg, _here

except KeyError:
    # os.environ lookup failed: the CIAO environment is not set up.
    raise IOError('Unable to find ASCDS_INSTALL environment variable.\nHas CIAO been started?')

import ciao_contrib.logger_wrapper as lw


# Tool identity; both values are passed to the error handler and to the
# history record written by main().
toolname = "gti_align"
__revision__ = "11 June 2019"

# Set up the tool's logger and expose one module-level shortcut per
# verbosity level used in this file.
lw.initialize_logger(toolname)
lgr = lw.get_logger(toolname)
verb0, verb1, verb2, verb3, verb5 = (
    lgr.verbose0,
    lgr.verbose1,
    lgr.verbose2,
    lgr.verbose3,
    lgr.verbose5,
)


import numpy as np
from pycrates import read_file


class ExposureStatsFile():
    """
    Exposure time boundaries, per CCD, derived from an exposure
    statistics (stat1) file, plus the logic to snap GTI limits onto
    those boundaries and write the result as a filter file.

    Typical use::

        stats = ExposureStatsFile(statfile)
        stats.get_exposure_time_boundaries()
        stats.align_boundary(gti)
        stats.write_output(outfile)

    Attributes
    ----------
    ccds : list
        Sorted, unique ccd_id values found in the file.  May be
        re-ordered/filtered by align_boundary().
    tlo, thi : dict
        ccd_id -> array of exposure start / stop times.
    per_chip_output : dict
        ccd_id -> (starts, stops) arrays of aligned GTI limits.
    """

    def __init__(self, infile):
        """
        Read the TIME, EXPNO and CCD_ID columns and the TIMEDEL and
        FLSHTIME keywords from `infile`.

        Raises RuntimeError if the READMODE keyword is CONTINUOUS.
        """
        self.infile = infile

        # Result tables are created per instance; as class attributes
        # they would be shared (and silently mixed) between objects.
        self.tlo = {}
        self.thi = {}
        self.per_chip_output = {}

        tab = read_file(infile)
        if 'CONTINUOUS' == tab.get_key_value("READMODE"):
            raise RuntimeError("This script cannot be used with continuous clocking mode data.")

        # The multiplications presumably force fresh arrays rather than
        # references into the crate -- TODO confirm.
        self._times = tab.get_column("time").values*1.0
        self._expno = tab.get_column("expno").values*1
        self._ccds = tab.get_column("ccd_id").values*1

        self.ccds = sorted(set(self._ccds))

        # We know TIMEPIXR = 0.5 , ie times are middle of TIMEDEL
        # length time bin
        self.timedel = tab.get_key_value("timedel")
        self.timedel_d2 = self.timedel/2.0

        self.flushtime = tab.get_key_value("flshtime")

    def get_per_chip_exposure_time_boundaries( self, ccd_id ):
        """
        Fill self.tlo[ccd_id] and self.thi[ccd_id] with the start and
        stop time of every exposure recorded for `ccd_id`.

        The stat1 file has the mid-time for each exposure number that
        is sent.  For dropped exposures there is a gap.

        At the end of a run of exposures (before a gap, or the last
        row) the stop time is time+timedel/2.  At the start of a run
        (after a gap, or the first row) the start time is
        time-timedel/2-flushtime.

        The exposures are not identically timedel seconds apart so
        when we have 2 consecutive exposure numbers we use the
        average/middle time between them as the shared boundary.

        Raises RuntimeError if `ccd_id` has no rows in the file.
        """

        on_chip = np.asarray(self._ccds) == ccd_id
        tt = np.asarray(self._times)[on_chip]
        ee = np.asarray(self._expno)[on_chip]

        if len(ee) == 0:
            raise RuntimeError("CCD_ID={} is not in exposure stats file {}".format(ccd_id, self.infile))

        half = self.timedel_d2
        lo = np.zeros(len(tt), dtype=float)
        hi = np.zeros(len(tt), dtype=float)

        # First deal with the 1st exposure and last exposure
        lo[0] = tt[0] - half - self.flushtime
        hi[-1] = tt[-1] + half

        # Now every adjacent pair of rows: either the exposure numbers
        # are consecutive (shared mid-point boundary) or at least one
        # exposure was dropped between them (independent boundaries).
        consecutive = ee[1:] == (ee[:-1] + 1)
        midpoints = (tt[1:] + tt[:-1])/2.0
        hi[:-1] = np.where(consecutive, midpoints, tt[:-1] + half)
        lo[1:] = np.where(consecutive, midpoints, tt[1:] - half - self.flushtime)

        self.tlo[ccd_id] = lo
        self.thi[ccd_id] = hi

        for row in zip(ee, lo, tt, hi):
            verb5( "{}\t{}\t{}\t{}".format(*row))

    def get_exposure_time_boundaries( self ):
        """
        Compute the exposure boundaries for every chip in self.ccds.
        """
        for ccd_id in self.ccds:
            self.get_per_chip_exposure_time_boundaries(ccd_id)

    def _align_boundary( self, ss, tt, ccd_id ):
        """
        Find the exposure times that bound the start (ss) and stop (tt)
        times for chip `ccd_id`.

        Returns (start, stop) as floats: the start of the first
        exposure that ends at/after ss and the stop of the last
        exposure that begins at/before tt.  A limit lying beyond the
        recorded exposures is therefore clamped to the 1st (start) or
        last (stop) record.

        Returns (None, None) when the interval covers no exposure at
        all: it lies entirely before the first or after the last
        exposure, or entirely inside a gap of dropped exposures (which
        would otherwise yield start > stop).
        """

        # Due to round offs and other approx's, a time right at the
        # boundary between two exposures may get shifted into the wrong
        # bin.  We artificially shrink the input GTI record so that it
        # will end up in the correct frame even if the FPE noise
        # would otherwise shift it into the wrong one.
        delta = 0.001

        lo_edges = np.asarray(self.tlo[ccd_id])
        hi_edges = np.asarray(self.thi[ccd_id])

        first = np.flatnonzero(ss + delta <= hi_edges)
        last = np.flatnonzero(tt - delta >= lo_edges)

        if first.size == 0 or last.size == 0:
            return None, None

        start = float(lo_edges[first[0]])
        stop = float(hi_edges[last[-1]])

        # Interval fell wholly inside a gap: no exposure is covered.
        if start > stop:
            return None, None

        return start, stop

    def align_boundary( self, gti ):
        """
        Loop over ccd_id's found in the stat1 file and match to
        gti subspace components.  If there is no ccd_id subspace, use
        the same limits for all ccd_id's.

        `gti` must provide `cpts` (key -> (starts, stops)), `sentinal`
        (the key used when there is no ccd_id subspace) and
        `ccd_order`.  Results go to self.per_chip_output[ccd_id] as a
        (starts, stops) pair of equal-length arrays; intervals that
        cover no exposure are dropped.  A chip absent from `gti` is
        reported via verb0 and gets no entry.

        If gti.ccd_order is non-empty, self.ccds is replaced by the
        chips of gti.ccd_order that are also in self.ccds, in that
        order.
        """
        for ccd_id in self.ccds:
            if ccd_id in gti.cpts:
                limits = gti.cpts[ccd_id]
            elif gti.sentinal in gti.cpts:
                limits = gti.cpts[gti.sentinal]
            else:
                verb0("WARNING: ccd_id {} in stat1 file is not in gti file".format(ccd_id))
                continue

            starts = limits[0]
            stops = limits[1]

            aligned = [ self._align_boundary( lo, hi, ccd_id )
                        for lo, hi in zip( starts, stops ) ]

            # Keep start/stop paired, and test against None explicitly:
            # 0.0 is a legitimate time but is falsy.
            kept = [ pair for pair in aligned
                     if pair[0] is not None and pair[1] is not None ]

            ns = np.array([pair[0] for pair in kept], dtype=float)
            nt = np.array([pair[1] for pair in kept], dtype=float)
            self.per_chip_output[ccd_id] = ( ns, nt )

        # Preserve the order of GTIs in infile (if any)
        if gti.ccd_order:
            self.ccds = [ x for x in gti.ccd_order if x in self.ccds ]

    def write_output( self, outfile ):
        """
        Write the output file using cxcdm routines.

        Creates a table block named FILTER in `outfile` with empty
        TIME and CCD_ID columns, and one subspace component per chip
        holding that chip's aligned time ranges (table name
        GTI<ccd_id>) and its ccd_id.  Any existing `outfile` is removed
        first.  Chips with no entry in self.per_chip_output (skipped
        by align_boundary) are not written.

        When every chip has exactly one time range, a summary of the
        equivalent filter is reported via verb1.
        """

        from cxcdm import dmTableCreate, dmColumnCreate, dmSubspaceColOpen
        from cxcdm import dmSubspaceColSet, dmTableClose, dmBlockCreateSubspaceCpt, dmSubspaceColSetTableName

        # clobber checking already done
        if os.path.exists( outfile ):
            os.unlink( outfile )

        tab = dmTableCreate( outfile+"[FILTER]" )

        # create the necessary cols, no data/rows
        dmColumnCreate( tab, ["TIME", "CCD_ID"], ["d", "i2"] )

        verbstr = "Filter: "

        ccds = [ c for c in self.ccds if c in self.per_chip_output ]
        last = len(ccds) - 1

        # create/fill subspaces
        for pos, ccd_id in enumerate(ccds):
            starts, stops = self.per_chip_output[ccd_id]

            tdss = dmSubspaceColOpen( tab, "TIME")
            cdss = dmSubspaceColOpen( tab, "CCD_ID")

            dmSubspaceColSet( tdss, starts, stops )
            dmSubspaceColSet( cdss, np.array([ccd_id]), np.array([ccd_id]) )
            dmSubspaceColSetTableName( tdss, "GTI{}".format(ccd_id))

            # The one-line summary is only produced while every chip
            # seen so far has a single time range.
            if 1 == len(starts):
                if verbstr:
                    verbstr += "(ccd_id={},time={}:{})".format(ccd_id, starts[0], stops[0])
            else:
                verbstr = None

            # if not last one, create a new subspace cpt
            if pos != last:
                dmBlockCreateSubspaceCpt(tab)
                if verbstr:
                    verbstr += "||"

        if verbstr:
            verb1(verbstr)
        dmTableClose(tab)


class GTI:
    """
    Common base for the GTI holders in this file.  It only supplies
    the `sentinal` key; subclasses build their own `cpts` and
    `ccd_order` attributes.
    """

    # Key used in `cpts` if the input gti has no ccd_id subspace.
    sentinal = "no_ccd_dss"

    def __init__(self):
        """Nothing to initialise in the base class."""


class GTIString(GTI):
    """
    GTI built from a text time filter: comma separated "lo:hi" pairs.

    An empty lo is taken as 0 and an empty hi as 1.0e308.  All the
    ranges are stored under the `sentinal` key of `cpts` as a
    (lo_values, hi_values) pair of lists; `ccd_order` is left empty.

    Raises ValueError if a pair does not consist of exactly two
    ":"-separated fields, or if a non-empty field is not a number.
    """
    def __init__( self, instr ):
        self.cpts = {}
        self.ccd_order = []

        begin_times = []
        end_times = []
        for token in instr.split(","):
            fields = token.split(":")
            if len(fields) != 2:
                raise ValueError("Invalid time filter string '{}'".format(instr))

            begin, end = fields

            # time had best not be negative!
            begin_times.append(float(begin) if begin else 0)

            # chandra roars into the 3rd millennium
            end_times.append(float(end) if end else 1.0e308)

        # could add checks that
        # hi[i] > lo[i]
        # lo[i+1] > lo[i]
        # hi[i+1] > hi[i]

        self.cpts[self.sentinal] = ( begin_times, end_times )


class GTIFile(GTI):
    """
    Multi-component GTI read from the subspace of a file.

    Components are told apart only by their ccd_id subspace, and a
    given ccd_id is expected to appear in a single component.  Other
    distinctions (different regions, expno subspaces, ...) are not
    handled here.

    After construction `cpts` maps each ccd_id (or `sentinal`, for a
    component without a usable ccd_id subspace) to a (starts, stops)
    pair, and `ccd_order` lists the ccd_ids in the order encountered.
    """

    def __init__( self, infile ):
        self.sentinal = "no_ccd_dss"  # if input gti has no ccd_id subspace
        self.cpts = {}
        self.ccd_order = []

        crate = read_file(infile)

        # No interface to know number of subspace components :(
        # so probe components 1..99 until the lookup fails.
        component = 1
        while component < 100:
            try:
                time_dss = crate.get_subspace_data( component, "time")
            except IndexError:
                break

            limits = ( time_dss.range_min*1.0, time_dss.range_max*1.0 )

            try:
                ccd_dss = crate.get_subspace_data( component, "ccd_id")
                for first_ccd, last_ccd in zip( ccd_dss.range_min, ccd_dss.range_max ):
                    for chip in range( first_ccd, last_ccd+1 ):
                        self.cpts[chip] = limits
                        self.ccd_order.append(chip)
            except Exception:
                # Any failure while handling the ccd_id subspace: file
                # the time ranges under the sentinel key instead.
                self.cpts[self.sentinal] = limits

            component += 1


def load_gti( infile ):
    """
    Build a GTI object from `infile`.

    The value is first treated as a file name (GTIFile); if that
    fails for any reason it is parsed as a time filter string
    (GTIString), whose errors propagate to the caller.
    """
    try:
        return GTIFile( infile )
    except Exception:
        return GTIString( infile )


#
# Main Routine
#
@lw.handle_ciao_errors( toolname, __revision__)
def main():
    """
    Tool entry point: read the parameters, load the GTI and the
    exposure statistics, align the GTI limits to the exposure
    boundaries, write the output filter file and add a history record.
    """
    from ciao_contrib.param_soaker import get_params
    from ciao_contrib.runtool import add_tool_history
    from ciao_contrib._tools.fileio import outfile_clobber_checks

    verbose_hooks = {"set": lw.set_verbosity, "cmd": verb1}
    pars = get_params(toolname, "rw", sys.argv, verbose=verbose_hooks)

    times = pars["times"]
    statfile = pars["statfile"]
    evtfile = pars["evtfile"]
    outfile = pars["outfile"]
    clobber = pars["clobber"] == "yes"

    outfile_clobber_checks( clobber, outfile )

    try:
        gti = load_gti( times )
    except Exception:
        raise ValueError("ERROR: could not open times value as either a file nor parse as a string.")

    # An event file, when given, dictates the order of the chips in
    # the output (restricted to chips the GTI knows about, unless the
    # GTI has no ccd_id subspace at all).
    if evtfile and "none" != evtfile.lower():
        evt_order = GTIFile( evtfile ).ccd_order
        if gti.sentinal in gti.cpts:
            gti.ccd_order = evt_order
        else:
            gti.ccd_order = [ chip for chip in evt_order if chip in gti.cpts ]

    stats = ExposureStatsFile( statfile )
    stats.get_exposure_time_boundaries()
    stats.align_boundary( gti )
    stats.write_output( outfile )

    add_tool_history( outfile, toolname, pars, toolversion=__revision__)


if __name__ == "__main__":
    # Run the tool; report any escaping error on stderr and map the
    # outcome to the process exit status (1 on failure, 0 on success).
    try:
        main()
    except Exception as E:
        message = "\n# {} ({}): ERROR {}\n".format(toolname, __revision__, str(E))
        print(message, file=sys.stderr)
        sys.exit(1)
    else:
        sys.exit(0)
