#!/usr/bin/env python
# 
# Copyright (C) 2013,2014,2016,2018,2019,2022,2024 
# Smithsonian Astrophysical Observatory
# 
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 3 of the License, or
# (at your option) any later version.
# 
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
# 
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
# 

# Name of this tool: used to set up the logger, to read the parameters,
# and in the history record and error messages.
toolname = "combine_grating_spectra"
# Revision string: recorded as the tool version in the history record and
# shown in error messages.
__revision__ = "17 April 2024"

from pycrates import read_file
import stk
import os
import sys

import ciao_contrib.logger_wrapper as lw
# Set up the logger for this tool and keep short aliases to its
# verbose0/verbose1/verbose2 methods; every message in this script goes
# through one of them (presumably one per verbosity level -- confirm
# against ciao_contrib.logger_wrapper).
lw.initialize_logger(toolname)
lgr = lw.get_logger(toolname)
verb0 = lgr.verbose0
verb1 = lgr.verbose1
verb2 = lgr.verbose2


from collections import namedtuple

# Maps the integer value of the tg_part keyword to the name of the grating
# part; the lower-cased name is used when building output file names.
tgpart = { 0 : 'ZO', 1 : 'HEG', 2 : 'MEG', 3 : 'LEG' }
# The set of files that belong to one spectrum: source PHA, background
# PHA, ARF and RMF.  Entries that are not known are stored as None or "".
SpectrumFiles = namedtuple( "SpectrumFiles", [ 'src', 'bkg', 'arf', 'rmf'] )


def gorm( infile ):
    """
    Remove infile if it exists.

    This is a thin wrapper around os.remove.  When the SAVE_ALL
    environment variable is set (to any value) nothing is removed, so
    that all temporary products are kept for debugging.
    """
    keep_everything = 'SAVE_ALL' in os.environ
    if keep_everything:
        return

    if not os.path.exists( infile ):
        # Nothing to clean up
        return

    os.remove( infile )


class delme( str ):
    """
    A string holding the name of a temporary file that must not
    outlive the script.

    The temporary files made by this script are named after the output
    root.  Normally they are cleaned up by the regular code paths, but
    after an error it is not known which files were actually created
    and so which ones still need to be removed.

    Wrapping a file name in delme() marks it as "always remove": once
    the string object is deleted or goes out of scope, __del__ runs
    and removes the file it names (if it is still there).
    """
    def __del__( self ):
        # The name may carry a DM filter, "file[filter]"; only the part
        # before the first "[" is the file on disk.
        filename, _, _ = str(self).partition("[")
        gorm(filename)


def determine_pha_type( infile ):
    """
    Return "TYPE:I" or "TYPE:II" for the given pha file.

    The HDUCLAS keywords are a bit fragile, so the decision is made
    from the channel column instead: a scalar column (1D values)
    means TYPE:I, an array column (2D values) means TYPE:II.

    An IOError is raised for any other dimensionality.
    """
    channel = read_file(infile).get_column("channel").values
    ndim = len(channel.shape)

    if ndim == 1:
        return "TYPE:I"

    if ndim == 2:
        return "TYPE:II"

    raise IOError( "Bad input file '{}'.  Cannot determine if TYPE:I or TYPE:II".format(infile))



def get_unique_id( infile, use_abs=False ):
    """
    Return the (tg_m, tg_part) identifier of a spectrum or arf.

    The values are read from the file header and returned as a
    namedtuple with fields tg_m and tg_part.

    With use_abs=True the absolute value of each keyword is used.
    That makes no difference for tg_part, but the order (tg_m) can be
    positive or negative, so the plus and minus orders end up with the
    same identifier and can be combined with the same routines.
    """
    # tg_srcid isn't correct in response files, so it can't be part of the id.
    keys = [ 'tg_m', 'tg_part']

    crate = read_file( str(infile) )

    if use_abs:
        convert = abs
    else:
        def convert( value ):
            return value

    Config = namedtuple( 'Config', keys )
    return Config( *[convert(crate.get_key_value(name)) for name in keys] )
 
 
def check_infile( infile ):
    """
    Basic sanity check that infile looks like a grating pha file.

    Each of tg_m, tg_part, counts and channel must be present, either
    as a column or as a keyword (names are compared ignoring case).
    An IOError naming the first missing item is raised otherwise.
    """
    crate = read_file(infile)
    colnames = [name.lower() for name in crate.get_colnames()]
    keynames = [name.lower() for name in crate.get_keynames()]
    available = set(colnames) | set(keynames)

    for required in ("tg_m", "tg_part", "counts", "channel"):
        if required in available:
            continue
        raise IOError("Missing '{}' column or keyword in '{}'".format(required,infile))


def is_none( infile ):
    """
    Return True when infile means "no file": None, the empty string,
    or the word "none" in any mix of upper and lower case.
    """
    return not infile or infile.lower() in ("none", "")


def split_II_to_I( instk, arfs, rmfs, outroot, verbose, clobber ):
    """
    Split TYPE:II pha files into temporary TYPE:I pha files.

    Combine spectra only accepts type I files, ie values
    are stored in rows.  Where as type II files have values
    stored as arrays in each column.

    We need to split typeII files into typeI files, but
    we also need to associate the ARF and RMF files with
    each row; both jobs are done by running tgsplit on each input.

    Parameters
    ----------
    instk : list of TYPE:II pha file names
    arfs, rmfs : passed to tgsplit as its arffile and rmffile parameters
    outroot : prefix for the temporary files; the files for input
        number N are named outroot + "tmpNN" + "_?eg_??[_bkg].pha"
    verbose, clobber : passed to tgsplit

    Returns
    -------
    (src, bkg) : two lists of the same length.  src holds the
        temporary source spectra as delme objects; bkg holds the
        matching background spectrum (a delme) or None when tgsplit
        created no background files for that input.

    Raises
    ------
    IOError : if the source and background files cannot be paired up.
    """
    import glob

    import ciao_contrib.runtool as rt

    def strip_suffix( name, suffix ):
        # str.rstrip() removes a *set of characters*, not a suffix, so
        # it can eat into the root name.  Work on a plain str copy so
        # that no extra delme object (which would remove the file when
        # discarded) is ever created here.
        name = str(name)
        if name.endswith(suffix):
            return name[:-len(suffix)]
        return name

    verb1("Splitting TYPE:II pha files into TYPE:I")

    tgsplit = rt.make_tool("tgsplit")

    retsrc = []
    retbkg = []
    for ii, ss in enumerate(instk):
        tmproot = outroot+"tmp{:02d}".format(ii)

        tgsplit.infile = ss
        tgsplit.arffile = arfs
        tgsplit.rmffile = rmfs
        tgsplit.outroot = tmproot
        tgsplit.clobber = clobber
        tgsplit.verbose = verbose
        v = tgsplit()
        if v:
            verb2(v)

        #
        # The output files names are just going to be temporary
        # files.  We wrap them in a 'delme' object so that they are
        # automatically removed whenever the variable goes out of scope.
        #
        # glob returns names in arbitrary order, so sort both lists to
        # line the source and background files up with each other.
        #
        outs = [delme(x) for x in sorted(glob.glob( tmproot+"_?eg_??.pha"))]
        bkgs = [delme(x) for x in sorted(glob.glob( tmproot+"_?eg_??_bkg.pha"))]

        retsrc.extend(outs)

        if not bkgs:
            retbkg.extend([None]*len(outs))
        elif len(bkgs) == len(outs):
            retbkg.extend(bkgs)
        else:  # should never happen but just in case
            raise IOError("Mismatch in number of source and background PHA files")

    # Double check the pairing: each background must share its root name
    # with the source it sits next to.  This is raised explicitly rather
    # than asserted so that the check also runs under "python -O".
    for sss, bbb in zip(retsrc, retbkg):
        if not bbb:
            continue
        if strip_suffix(sss, ".pha") != strip_suffix(bbb, "_bkg.pha"):
            raise IOError("Something is messed up in sort/bkg matching")

    return retsrc, retbkg

        
def filter_spectra( combos, arm, order, add_pm ):
    """
    Select the (tg_m, tg_part) configurations that match the requested
    grating arm and order.

    arm is a grating part name from tgpart (any case) or "all";
    order is an integer given as a string, or "all".  When both are
    "all" the combos object itself is returned, otherwise a list.

    If add_pm is set, both the +Nth and the -Nth order are kept for
    the requested order N (ie the comparison uses abs(order)).
    """
    selected = combos
    if arm.lower() != "all":
        arm_name = arm.upper()
        part_ids = [pid for pid, label in tgpart.items() if label == arm_name]
        selected = [cfg for cfg in combos if cfg.tg_part == part_ids[0]]

    if order.lower() == "all":
        return selected

    if add_pm:
        def wanted( cfg ):
            return abs(cfg.tg_m) == abs(int(order))
    else:
        def wanted( cfg ):
            return cfg.tg_m == int(order)

    return [cfg for cfg in selected if wanted(cfg)]


def add_to_combo( combo, key, val ):
    """
    Append val to the list stored under key in the combo dictionary,
    starting a new list if the key has not been seen before.
    """
    combo.setdefault(key, []).append(val)


def add_typeII(instk, bkg, arfs, rmfs, outroot, verbose, clobber, useabs=False):
    """
    Split the type:II files into temporary type:I files and group
    them by grating configuration.

    Returns a dictionary keyed on the (tg_m, tg_part) identifier of
    each split spectrum; each value is a list of SpectrumFiles.

    A user supplied background (bkg) is not used for type:II input; a
    warning is printed if one was given.
    """
    if not is_none( bkg ):
        verb0("WARNING: Background files are ignored for TYPE:II pha files")

    split_src, split_bkg = split_II_to_I( instk, arfs, rmfs, outroot, verbose, clobber )

    configs = [get_unique_id(src, useabs) for src in split_src]

    tocombo = {}
    for config, src, bkgfile in zip(configs, split_src, split_bkg):
        #
        # Splitting the typeII files has already identified the ARF
        # and RMF for each row and stored them in the ANCRFILE and
        # RESPFILE keywords of the temporary typeI file, so they are
        # left unset (None) here.
        #
        add_to_combo( tocombo, config, SpectrumFiles( src, bkgfile, None, None ) )

    return tocombo
    

def try_nonstd_files( instk, arfs, rmfs, outroot ):
    """
    Try to extract the background from non-standard typeI files.

    tgcat saves typeI files in a non standard way
    (background columns rather than in a separate file).  If every
    input file has a bg_counts or background_up column, tgsplit is run
    on each of them to create temporary source and background files.

    Parameters
    ----------
    instk : list of typeI pha file names
    arfs, rmfs : passed on to tgsplit
    outroot : prefix for the temporary files (outroot + "_tmpNN")

    Returns
    -------
    (src, bkg) : two equal-length lists of delme objects, paired up by
        position, or (None, None) if the inputs do not have background
        columns or the files could not be split.  In the latter case a
        message is printed and the caller carries on without background.
    """
    import glob

    def chk_bkg( infile ):
        # True if the file has one of the background columns
        tab = read_file( str(infile) )
        cols = [x.lower() for x in tab.get_colnames()]
        return "bg_counts" in cols or "background_up" in cols

    if not all( [chk_bkg(ii) for ii in instk] ):
        #Not tgcat files
        return None, None

    import ciao_contrib.runtool as rt
    tgsplit = rt.make_tool("tgsplit")
    tmproots = [ outroot+"_tmp{:02d}".format(ii) for ii in range( len(instk)) ]

    try:
        for ii,tt in zip(instk, tmproots):
            tgsplit( ii, arfs, rmfs, tt )

        src = []
        bkg = []
        for tt in tmproots:
            # glob returns names in arbitrary order; sort both lists so
            # that the Nth background belongs to the Nth source.
            outs = [ delme(i) for i in sorted(glob.glob(tt+"_?eg_??.pha")) ]
            bkgs = [ delme(i) for i in sorted(glob.glob(tt+"_?eg_??_bkg.pha")) ]
            if len(outs) != len(bkgs):
                # The caller indexes both lists in step, so an uneven
                # split cannot be used; fall back to no background.
                raise IOError("Mismatch in number of source and background PHA files")
            src.extend(outs)
            bkg.extend(bkgs)
        return src,bkg
    except Exception as e:
        # Deliberately best-effort: any failure just means no background
        verb1(str(e))
        verb1("Problem running tgsplit on type:I files, no background will be used.")
        return None, None



def add_typeI(instk, bkg, arfs, rmfs, useabs=False, outroot="./"):
    """
    Group the type:I pha files by grating configuration.

    Returns a dictionary keyed on the (tg_m, tg_part) identifier of
    each spectrum; each value is a list of SpectrumFiles.

    The background, arf and rmf stacks are optional.  When one is
    given it must have one entry per source file, in the same order
    as the source files.  When one is not given, the combine_spectra
    tool will try to locate the files itself.

    If the background stack is blank (and useabs is not set) the
    inputs are first checked for tgcat-style background columns.
    """
    def stkmatch( param, flavor ):
        # Expand a stack parameter to one entry per source file
        if is_none(param):
            return [param] * len(instk)

        expanded = stk.build(param)
        if len(expanded) != len(instk):
            raise RuntimeError("ERROR: If {} are supplied, they must be supplied for all files".format(flavor))
        return expanded

    tmpin = None
    if useabs == False and not bkg.strip():
        verb1("Background stack is blank, will see if files are from tgcat")
        tmpin, tmpbkg = try_nonstd_files( instk, arfs, rmfs, outroot )

    if tmpin:
        # Use the temporary files split out of the tgcat style inputs
        instk = tmpin
        bkgstk = tmpbkg
    else:
        bkgstk = stkmatch( bkg, "backgrounds")

    arfstk = stkmatch( arfs, "ARFs" )
    rmfstk = stkmatch( rmfs, "RMFs")

    configs = [get_unique_id(src, useabs) for src in instk]

    # Any file names the user gave are assumed to be in the same order
    # as the source files; they are paired up by position.
    tocombo = {}
    for ii, src in enumerate(instk):
        files = SpectrumFiles( src, bkgstk[ii], arfstk[ii], rmfstk[ii] )
        add_to_combo( tocombo, configs[ii], files )

    return tocombo


def get_from_list( files, flavor ):
    """
    Pull one field (src, bkg, arf or rmf) out of each SpectrumFiles
    tuple in files.

    Normally the list of values is returned.  If every value is
    unset (see is_none) then just the first value is returned.
    """
    values = [ getattr( entry, flavor) for entry in files ]
    unset = [is_none(value) for value in values ]
    return values[0] if all( unset ) else values

    
#arm, order, add_pm, outroot, clobber, exppref, method=None, sign=None ):
def combine( tocombo, pars, add_pm, method=None, sign=None):
    """
    Run combine_spectra once for each selected grating configuration.

    tocombo maps a (tg_m, tg_part) identifier to a list of
    SpectrumFiles.  The configurations to process are chosen with the
    garm and order values in pars (and add_pm).

    method is passed to combine_spectra as its method parameter.

    The output root is <outroot>combo_<arm>_<sign><order>, where sign
    is the sign argument if given, otherwise 'p' for tg_m > 0 and 'm'
    for anything else.

    Returns the list of output source spectra (outroot + ".pha").
    """
    from ciao_contrib.runtool import make_tool
    combine_spectra = make_tool("combine_spectra")

    created = []
    for config in filter_spectra( tocombo, pars["garm"], pars["order"], add_pm ):
        group = tocombo[config]

        combine_spectra.punlearn()
        combine_spectra.src_spectra = get_from_list(group, "src")
        combine_spectra.bkg_spectra = get_from_list(group, "bkg")
        combine_spectra.src_arfs = get_from_list(group, "arf")
        combine_spectra.src_rmfs = get_from_list(group, "rmf")
        combine_spectra.method = method
        combine_spectra.exp_origin = "arf"  # pars["exp_origin"]
        combine_spectra.bscale_method = pars["bscale_method"]
        combine_spectra.verbose = pars["verbose"]
        combine_spectra.clobber = pars["clobber"]

        # Label for the sign of the order used in the output name
        if sign is not None:
            sgn = sign
        elif config.tg_m > 0:
            sgn = 'p'
        else:
            sgn = 'm'

        arm = tgpart[config.tg_part].lower()
        combine_spectra.outroot = "{}combo_{}_{}{}".format( pars["outroot"], arm, sgn, abs(config.tg_m))

        vv = combine_spectra()
        if vv:
            verb0( vv )

        created.append( combine_spectra.outroot+".pha")

    return created
    

def update_metadata( outfile, pars ):
    """
    Fix up the header keywords of an output file, in place.

    The OBJECT keyword is set to the object parameter, but only if
    the keyword already exists and the parameter is set.  The TG_M
    and SPEC_NUM keywords are removed if present.
    """
    crate = read_file( outfile, mode="rw")

    if crate.key_exists("OBJECT") and not is_none( pars["object"] ):
        crate.get_key("OBJECT").value = pars["object"]

    # These keywords are merged and should be deleted
    stale = [kw for kw in ("TG_M", "SPEC_NUM") if crate.key_exists(kw)]
    for kw in stale:
        crate.delete_key(kw)

    crate.write()


@lw.handle_ciao_errors( toolname, __revision__)
def main():
    """
    Run the tool.

    The steps are:
      1. read the parameters and check the input pha files, which must
         all be TYPE:I or all be TYPE:II
      2. group the spectra by grating part and order (TYPE:II files
         are first split into temporary TYPE:I files)
      3. combine each group across observations
      4. if add_plusminus=yes, combine the plus and minus orders of
         the results from step 3
      5. add the history record and update the keywords of the
         output files

    NOTE: the temporary files are removed by the delme objects, ie when
    the last reference to them is dropped, so the order in which the
    variables below are assigned matters.
    """
    from ciao_contrib.param_soaker import get_params
    from ciao_contrib.runtool import add_tool_history

    pars = get_params(toolname, "rw", sys.argv, 
        verbose={"set":lw.set_verbosity, "cmd":verb2} )

    if pars["exposure_mode"] == "avg":
        verb0("\nEXPOSURE and LIVETIME values will be averaged.  If this is not desired behavior, set exposure_mode=sum\n")


    instk = stk.build(pars["infile"])
    for ii in instk:  
        check_infile(ii)
    flavor = [determine_pha_type(ii) for ii in instk ] # map(determine_pha_type, instk)

    # Check for type:II vs type:I -ness. Should be 1 unique value.
    if len(set(flavor)) != 1:  
        raise IOError("Input appears to a mix of TYPE:I and TYPE:II pha files. This is not supported.")

    # Make outroot ready to have file names appended to it: a root that
    # is not a directory gets a trailing "_", a directory gets a trailing
    # "/" if it does not have one.  strip_outroot records whether a
    # character was added so it can be taken off again further down.
    root_is_dir = os.path.isdir(pars["outroot"]) 
    strip_outroot = True
    if not root_is_dir:
        pars["outroot"] = pars["outroot"]+"_"
    elif not pars["outroot"].endswith("/"):
        pars["outroot"] = pars["outroot"]+"/"
    else:
        strip_outroot = False


    # Group the spectra by (tg_m, tg_part)
    if "TYPE:II" in flavor:
        verb1("Input TYPE:II pha file(s).  The output will be one or more TYPE:I pha files")
        tocombo = add_typeII( instk, pars["bkg_pha"], pars["arf"], pars["rmf"], 
            pars["outroot"], pars["verbose"], pars["clobber"] )
    else:
        tocombo = add_typeI( instk, pars["bkg_pha"], pars["arf"], pars["rmf"], outroot=pars["outroot"] )

    add_pm = True if pars["add_plusminus"] == "yes" else False
    verb1("Combining spectra across observations")
    outstk = combine( tocombo, pars, add_pm, method=pars["exposure_mode"] )

    if add_pm:
        # Note: tmp type:I files now deleted, the +/- files will be deleted
        # when tool exits.
        #
        # The per-order outputs made above are only intermediate products
        # here, so they (and their background, arf and rmf) are wrapped in
        # delme.  bkgstk, arfstk and rmfstk are not read again; they
        # appear to be kept only so these delme objects stay alive until
        # main returns -- NOTE(review): confirm before removing them.
        outstk = [ delme(o) for o in outstk ]
        bkgstk = [ delme(o.replace('.pha','_bkg.pha')) for o in outstk]
        arfstk = [ delme(o.replace('.pha','.arf')) for o in outstk]
        rmfstk = [ delme(o.replace('.pha','.rmf')) for o in outstk]
        verb1("Combining plus and minus orders")

        # useabs=True gives the +N and -N orders the same key.  Rebinding
        # tocombo drops the previous dictionary, which presumably is what
        # releases the temporary type:I files it referred to.
        tocombo = add_typeI( outstk, "", "", "", useabs=True, outroot=pars["outroot"] )
        outstk = combine( tocombo, pars, add_pm, method='avg', sign="abs")


    # Undo the change made to outroot above, before pars is written to
    # the history record.
    if strip_outroot:
        pars["outroot"] = pars["outroot"][:-1]

    # Add the history and fix up the keywords of each output spectrum and
    # of the background and arf files that were created alongside it.
    for oo in outstk:
        add_tool_history( oo, toolname, pars, toolversion=__revision__)
        update_metadata( oo, pars )
        verb1("Created output source spectrum '{}'".format(oo))
        if os.path.exists( oo.replace(".pha", "_bkg.pha")):
            update_metadata( oo.replace(".pha", "_bkg.pha"), pars )
            verb1("    with background spectrum '{}'".format(oo.replace('.pha', '_bkg.pha')))
        if os.path.exists( oo.replace(".pha", ".arf")):
            update_metadata( oo.replace(".pha", ".arf"), pars)
            verb1("    with source ARF '{}'".format(oo.replace('.pha', '.arf')))
        if os.path.exists( oo.replace(".pha", ".rmf")):
            #update_metadata( oo.replace(".pha", ".rmf"), pars )
            # Crates crashes on rmf files
            verb1("    with source RMF '{}'".format(oo.replace('.pha', '.rmf')))
        verb1("")
            
            

            
if __name__ == "__main__":
    # Run the tool; report any error on stderr and exit with status 1,
    # otherwise exit with status 0.
    exit_status = 0
    try:
        main()
    except Exception as err:
        print("\n# "+toolname+" ("+__revision__+"): ERROR "+str(err)+"\n", file=sys.stderr)
        exit_status = 1
    sys.exit(exit_status)

