#!/usr/bin/env python
# 
# Copyright (C) 2013,2016-2018, 2020, 2021, 2022 
# Smithsonian Astrophysical Observatory
# 
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 3 of the License, or
# (at your option) any later version.
# 
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
# 
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
# 

# Name of this tool (passed to the logger setup, the parameter reader and the
# history writer below) and the revision string reported alongside it.
toolname = "mktgresp"
__revision__ = "15 August 2022"


import os

from ciao_contrib.runtool import make_tool
from ciao_contrib.runtool import add_tool_history

import ciao_contrib._tools.fileio as fileio
import sys
import os

# Set up the logger for this tool and keep short aliases for the
# per-verbosity-level output functions used throughout the file.
import ciao_contrib.logger_wrapper as lw
lw.initialize_logger(toolname)
lgr = lw.get_logger(toolname)
verb0 = lgr.verbose0
verb1 = lgr.verbose1
verb2 = lgr.verbose2
verb3 = lgr.verbose3
verb5 = lgr.verbose5

from ciao_contrib.param_soaker import *

from collections import namedtuple
# One response to build.  The fields are filled, in this order, from the
# tg_part, tg_m, tg_srcid, x and y values of the input PHA file
# (see get_order_and_arm).
names = ("part", "order", "tgid", "xx", "yy" )
Spectrum =  namedtuple( 'Spectrum', names )

class ObservationFiles():
    """
    Holder for the event file and the ancillary files that go with it.

    Attributes set by the constructor:
      evt      - event file name, as given
      asp      - aspect solution
      bpx      - bad pixel file
      msk      - mask file
      dtf      - dead time factor file (HRC only); "" when not available
      keywords - header keywords read from the event file
    """

    @staticmethod
    def _is_blank( value ):
        """True when the caller did not supply a value (None or '')"""
        return value is None or '' == value

    def _require_key( self, key, message ):
        """Raise ValueError(message) unless the event file has keyword key"""
        if key not in self.keywords:
            raise ValueError(message)

    def __init__(self, evt, asp=None, bpx=None, msk=None, dtf=None ):
        """
        Store the file names; any that is None or '' is looked up from
        the event file.

        A ValueError is raised when the aspect, bad pixel or mask file
        is not given and the matching header keyword (ASOLFILE,
        BPIXFILE, MASKFILE) is missing from the event file.  A missing
        DTFFILE keyword is not an error: dtf is set to "".
        """
        import ciao_contrib._tools.obsinfo as obsinfo_module

        self.evt = evt
        self.asp = asp
        self.bpx = bpx
        self.msk = msk
        self.dtf = dtf

        self.keywords = fileio.get_keys_from_file(self.evt)
        located = obsinfo_module.ObsInfo( evt )

        if self._is_blank( self.asp ):
            self._require_key( "ASOLFILE", "ERROR: Missing ASOL file information" )
            self.asp = located.get_asol()

        if self._is_blank( self.bpx ):
            self._require_key( "BPIXFILE", "ERROR: Missing Badpixel file information" )
            self.bpx = located.get_ancillary("bpix")

        if self._is_blank( self.msk ):
            self._require_key( "MASKFILE", "ERROR: Missing Mask file information" )
            self.msk = located.get_ancillary("mask")

        # HRC Only
        if self._is_blank( self.dtf ):
            if "DTFFILE" in self.keywords:
                self.dtf = located.get_ancillary("dtf")
            else:
                self.dtf = ""
         


def map_part_to_name( part ):
    """
    Translate an integer TG_PART value into the name of the grating part.

    0, 1, 2 and 3 give "ZERO", "HEG", "MEG" and "LEG"; any other value
    raises a ValueError.
    """
    name = { 0 : "ZERO", 1 : "HEG", 2 : "MEG", 3 : "LEG" }.get( part )

    if name is None:
        raise ValueError("Invalid TG_PART={}".format(part))

    return name
    

def parse_orderlist( orderlist, parts, ids, xx, yy ):
    """
    Expand a user supplied stack of orders into the responses to make.

    orderlist is a stack expression of integer orders.  parts, ids, xx
    and yy are the tg_part, tg_srcid, x and y values read from the
    input pha file.

    Returns None, meaning "use the orders in the input pha file", when
    orderlist is None or empty, when it is the single value "INDEF" or
    "default", when the stack cannot be built, or when its values are
    not integers (the last two cases are reported at verbose 0).

    Otherwise returns a list of (part, order, id, x, y) tuples: every
    requested order for every unique part, source id and (x,y) pair.
    """
    if orderlist is None or 0 == len(orderlist):
        return None

    # Only needed once there really is a stack to parse.
    import stk

    # Do not use a bare except: it would also swallow KeyboardInterrupt
    # and SystemExit.
    try:
        ss = stk.build( orderlist )
    except Exception:
        verb0("ERROR building order list stack from '{}'.  Using orders in input pha file".format(orderlist))
        return None

    if len(ss) == 1 and ss[0] in ["INDEF", "default"]:
        return None

    try:
        ol = [ int(s) for s in ss ]
    except (TypeError, ValueError):
        verb0("Cannot convert orderlist values to integers.  Using orders in input pha file")
        return None

    upart = set( parts )  # unique parts
    uids = set(ids)       # unique src ids
    uxy = set( zip(xx,yy)) # unique src coords (need to test as a tuple/pair)

    zz = [ (p,o,i,xy[0],xy[1]) for p in upart for i in uids for o in ol for xy in uxy]

    return zz



def get_order_and_arm( phafile, orderlist ):
    """
    Work out which responses to make for the spectra in phafile.

    The tg_part, tg_m, tg_srcid, x and y values are read from columns
    of phafile; if that fails they are read from header keywords of
    the same names instead, giving a single spectrum.

    When orderlist yields an explicit set of orders (see
    parse_orderlist) those orders replace the ones found in the file.

    Returns a list of Spectrum named tuples.
    """
    import pycrates as pyc

    myf = pyc.read_file(phafile)

    # Not a bare except: KeyboardInterrupt and SystemExit must not
    # trigger the keyword fallback.
    try:
        parts = myf.get_column("tg_part").values.copy()
        orders = myf.get_column("tg_m").values.copy()
        ids = myf.get_column("tg_srcid").values.copy()
        xx = myf.get_column("x").values.copy()
        yy = myf.get_column("y").values.copy()
    except Exception:
        parts = [myf.get_key_value("tg_part")]
        orders = [myf.get_key_value("tg_m")]
        ids = [myf.get_key_value("tg_srcid")]
        xx = [myf.get_key_value("x")]
        yy = [myf.get_key_value("y")]

    zz = parse_orderlist( orderlist, parts, ids, xx, yy )

    if zz is None:
        # No usable order list: one response per spectrum in the file
        zz = zip( parts, orders, ids, xx, yy )

    # Save the data in a named tuple
    return [ Spectrum(*z) for z in zz ]


def determine_src_chip( x, y, obsinfo ):
    """
    Determine which chip 0th order falls on. Needed by mkgrmf.

    dmcoords is run on the sky position (x, y) using the event file
    and aspect solution held in obsinfo, and its chip_id is turned
    into a detector name:

      HRC,  chip_id 0        -> "HRC-I"
      HRC,  chip_id 1 and up -> "HRC-S2"
      otherwise              -> "ACIS-<chip_id>"

    A ValueError is raised when the INSTRUME keyword is neither ACIS
    nor HRC.
    """

    dmcoords  = make_tool("dmcoords")
    inst = obsinfo.keywords["INSTRUME"]
    if inst not in ["ACIS", "HRC"]:
        raise ValueError("Invalid INSTRUME keyword in {}".format(obsinfo.evt))

    dmcoords.x        = x
    dmcoords.y        = y
    dmcoords.opt      = "sky"
    dmcoords.asolfile = obsinfo.asp
    dmcoords.infile   = obsinfo.evt
    dmcoords()

    chip = dmcoords.chip_id
    if "HRC" == inst:
        if chip == 0:
            return "HRC-I"
        # This used to test chip > 1, so an HRC observation with
        # chip_id 1 fell through and was labelled "ACIS-1".
        # NOTE(review): assumes HRC-S chip_ids are 1-3 - confirm.
        if chip >= 1:
            return "HRC-S2"

    return "ACIS-{}".format(chip)


def get_outfile_name( outroot, spectrum, filetype ):
    """
    Build the name of an output file:

        <outroot><part>_<p|m><|order|>.<filetype>

    where part is the lower-case grating part name and the order is
    prefixed with "m" when it is negative and "p" otherwise.
    """

    arm = map_part_to_name( spectrum.part ).lower()

    # plus (+) or minus (-)
    sign = "m" if spectrum.order < 0 else "p"

    return outroot + "{}".format( arm ) + "_{}{}.{}".format( sign, abs(spectrum.order), filetype)
    
    
def make_tg_rmf( obsinfo, pha, spectrum, outroot, wvgrid_arf, wvgrid_chan):
    """
    Run mkgrmf for one spectrum.  This happens first so that the ARF
    can use its grid.

    The detector is taken from where the zeroth order falls.  When
    that is HRC-I a warning is printed at verbose 0 and, for HEG/MEG
    or when the CALDB lsfparm query finds no files, a diagonal RMF is
    requested.

    Returns the name of the RMF file.
    """
    verb1("  making RMF")

    rmftool = make_tool("mkgrmf")
    detector = determine_src_chip( spectrum.xx, spectrum.yy, obsinfo )
    arm = map_part_to_name( spectrum.part )

    if "HRC-I" == detector and arm.lower() in ["meg", "heg"]:
        # TODO: If/when HRC-I + HETG lsfparams are released then
        # this should be merged with the LEG section below.
        verb0("WARNING: There are no grating RMF calibration products for HRC-I+HEG/MEG.  Will create a diagonal RMF")
        rmftool.diagonalrmf = True

    elif "HRC-I" == detector:    # leg
        # Need to check if lsfparm file is available in CALDB
        from caldb4 import Caldb
        query = Caldb("chandra", "hrc", "lsfparm")
        query.detnam = detector
        query.grating = obsinfo.keywords["GRATING"]
        query.grattype = arm
        query.order = spectrum.order

        if 0 == len(query.search):
            # CALDB not upgraded
            rmftool.diagonalrmf = True
            warning="""WARNING: Please consider upgrading CALDB versions. 
                There are no grating RMF calibration products for HRC-I.  
                Will create a diagonal RMF"""
        else:
            warning="""WARNING: There are no grating RMF calibration 
                products for HRC-I. An RMF will be created using 
                LETG+HRC-S calibration products. This should be 
                accurate for wavelengths < 60 A. For longer LETG+HRC-I 
                wavelengths, the observed line widths can be increasingly 
                broad compared with the RMF generated here owing to the 
                flat HRC-I detector not following the geometry of the 
                LETG Rowland torus."""

        # Collapse the multi-line text onto one line
        verb0(" ".join( line.strip() for line in warning.split("\n") ))

    rmftool.outfile     = get_outfile_name( outroot, spectrum, "rmf" )
    rmftool.obsfile     = obsinfo.evt
    rmftool.regionfile  = pha
    rmftool.wvgrid_arf  = wvgrid_arf
    rmftool.wvgrid_chan = wvgrid_chan
    rmftool.order       = spectrum.order
    rmftool.srcid       = spectrum.tgid
    rmftool.detsubsys   = detector
    rmftool.grating_arm = arm
    rmftool.clobber     = True

    screen = rmftool()
    if screen:
        verb2(screen)

    return rmftool.outfile


def make_asphist( obsinfo, outroot ):
    """
    Create the aspect histograms with asphist.

    For ACIS one histogram is made per chip listed in the DETNAM
    keyword, named <outroot><ccd>.asphist.  Otherwise (HRC) a single
    histogram named <outroot>asphist is made using the dead time
    factor file in obsinfo.

    Any screen output from asphist is reported at verbose 0.

    Returns the list of files created.
    """
    asphist = make_tool("asphist")
    asphist.infile  = obsinfo.asp

    # Set once for both instruments.  Previously only the ACIS branch
    # set it, so a left-over <outroot>asphist file made the HRC run fail.
    asphist.clobber = True

    retval = []

    if obsinfo.keywords["INSTRUME"].startswith("ACIS"):
        # The part of DETNAM after the "-" is one character per chip
        ccds = obsinfo.keywords["DETNAM"].split("-")[1]
        for ccd in ccds:
            asphist.evtfile = obsinfo.evt+"[ccd_id={}]".format(ccd)
            asphist.dtffile = ""
            asphist.outfile = outroot+"{}.asphist".format(ccd)
            oo = asphist()
            if oo:
                verb0(oo)
            retval.append( asphist.outfile )
    else:  # HRC
        asphist.evtfile = obsinfo.evt
        asphist.dtffile = obsinfo.dtf
        asphist.outfile = outroot+"asphist"
        oo = asphist()
        if oo:
            verb0(oo)
        retval.append( asphist.outfile )

    return retval


def _filter_mkgarf_verbose( msg ):
    """Drop the 'CYCLE not found' warning lines from mkgarf screen output"""
    if not msg:
        return None
    kept = [ line for line in msg.split("\n")
             if 'Warning: CYCLE not found or invalid in' not in line ]
    return "\n".join(kept)


def _run_mkgarf( mkgarf ):
    """Run mkgarf and report what is left of its screen output at verbose 1"""
    screen = _filter_mkgarf_verbose( mkgarf() )
    if screen:
        verb1(screen)


def _combine_chip_arfs( dmarfadd, chip_arfs, outfile ):
    """Add the per-chip ARFs into outfile and then delete them"""
    dmarfadd( chip_arfs, outfile, clobber=True)
    for chip_arf in chip_arfs:
        os.remove( chip_arf )


def run_fullgarf( obsinfo, spectrum, outroot, dafile, osipfile ):
    """
    Make the grating ARF for one spectrum.

    mkgarf is run once per detector element, on the energy grid of the
    RMF already made for this spectrum, and the pieces are combined
    with dmarfadd:

      ACIS  - one ARF per chip in DETNAM that is kept for the sign of
              the order
      HRC-S - S2+S3 for negative orders, S2+S1 for positive orders
      HRC-I - a single mkgarf run, nothing to combine

    The aspect histograms made by make_asphist must already exist.

    Returns the name of the ARF.  A ValueError is raised for a zero
    order (ACIS and HRC-S) and for an unrecognised INSTRUME/DETNAM
    combination.
    """

    mkgarf = make_tool("mkgarf")
    dmarfadd = make_tool("dmarfadd")

    verb1("  making ARF")

    rmffile = get_outfile_name( outroot, spectrum, "rmf" )
    arffile = get_outfile_name( outroot, spectrum, "arf" )

    mkgarf.order = spectrum.order
    mkgarf.sourcepixelx = spectrum.xx
    mkgarf.sourcepixely = spectrum.yy
    mkgarf.engrid = "grid({}[cols ENERG_LO,ENERG_HI])".format(rmffile)
    mkgarf.obsfile = obsinfo.evt
    mkgarf.osipfile = osipfile
    mkgarf.maskfile = obsinfo.msk
    mkgarf.grating_arm = map_part_to_name( spectrum.part )
    mkgarf.dafile = dafile
    mkgarf.clobber = True

    if obsinfo.keywords["INSTRUME"].startswith("ACIS"):
        # positive and negative orders only hit certain chips so we
        # only make ARF for chips that are on (in detnam) and
        # based on sign of the order
        ccds = list( obsinfo.keywords["DETNAM"].split("-")[1] )

        # Due to offsets, need to extend +/- around acis-7
        if spectrum.order < 0:
            wanted = (0,1,2,3,4,5,6,7,8)
        elif spectrum.order > 0:
            wanted = (0,1,2,3,6,7,8,9)
        else:
            raise ValueError("Zero order should not be here!")
        ccds = [ x for x in ccds if int(x) in wanted ]

        #
        # ACIS-I does not have any OSIP calibrations.  The upshot for
        # users is that the ARF will be 0.
        #
        # Note that the OSIP file is switched off for every chip
        # processed below, not only the ACIS-I ones.
        if any( str(x) in ccds for x in (0,1,2,3) ):
            mkgarf.osipfile="NONE"
            verb1("Grating ARFS for ACIS-I CCDs are not calibrated")

        toadd = []
        for ccd in ccds:
            verb2("    ccd={}".format(ccd))
            mkgarf.asphistfile = outroot+"{}.asphist".format(ccd)
            mkgarf.outfile = get_outfile_name( outroot, spectrum, "{}.arf".format(ccd))
            mkgarf.detsubsys = "ACIS-{}".format(ccd)
            _run_mkgarf( mkgarf )
            toadd.append( mkgarf.outfile)

        _combine_chip_arfs( dmarfadd, toadd, arffile )

    elif obsinfo.keywords["DETNAM"] == "HRC-I":
        mkgarf.asphistfile = outroot+"asphist"
        mkgarf.outfile = arffile
        mkgarf.detsubsys = "HRC-I"
        _run_mkgarf( mkgarf )

    elif obsinfo.keywords["DETNAM"] == "HRC-S":
        mkgarf.asphistfile = outroot+"asphist"

        if spectrum.order < 0:
            ccds = [ "S2", "S3" ]
        elif spectrum.order > 0:
            ccds = [ "S2", "S1" ]
        else:
            raise ValueError("No zeroth order please")

        toadd = []
        for ccd in ccds:
            mkgarf.outfile = get_outfile_name( outroot, spectrum, "{}.arf".format(ccd))
            mkgarf.detsubsys = "HRC-{}".format(ccd)
            _run_mkgarf( mkgarf )
            toadd.append( mkgarf.outfile)

        _combine_chip_arfs( dmarfadd, toadd, arffile )

    else:
        raise ValueError("Unknown combination of instrument={} and detname={}".format( obsinfo.keywords["INSTRUME"], obsinfo.keywords["DETNAM"]))

    return arffile
        

def set_nproc( pars ):
    """
    Replace pars["nproc"], in place, with the number of processes to use.

    parallel="no" gives 1.  Otherwise nproc is converted to an integer,
    with "INDEF" replaced by 999.
    """
    if pars["parallel"] == "no":
        pars["nproc"] = 1
        return

    requested = pars["nproc"]
    if requested == "INDEF":
        pars["nproc"] = 999 #Hack, else history gets stuck with a pset
    else:
        pars["nproc"] = int(requested)


def setup_ardlib_bpix( obsinfo ):
    """
    Point ardlib.par at the bad pixel file held in obsinfo.

    For HRC the bad pixel parameter matching DETNAM (HRC-I or HRC-S)
    is set directly; for anything else acis_set_ardlib is run.

    A ValueError is raised when INSTRUME is missing or, for HRC, when
    DETNAM is missing or is not HRC-I / HRC-S.
    """
    import paramio as pio

    keys = obsinfo.keywords

    if "INSTRUME" not in keys:
        raise ValueError("Missing 'INSTRUME' keyword in {}".format(obsinfo.evt))

    if "HRC" != keys["INSTRUME"]:
        tool = make_tool("acis_set_ardlib")
        tool.badpixfile = obsinfo.bpx
        tool()
        return

    if "DETNAM" not in keys:
        raise ValueError("Missing 'DETNAM' keyword in {}".format(obsinfo.evt))

    # ardlib parameter to set for each HRC detector
    badpix_param = {
        "HRC-I" : "AXAF_HRC-I_BADPIX_FILE",
        "HRC-S" : "AXAF_HRC-S_BADPIX_FILE",
    }

    detnam = keys["DETNAM"]
    if detnam not in badpix_param:
        raise ValueError("Unknown DETNAM value '{}' in {}".format( obsinfo.keywords["DETNAM"], obsinfo.evt))

    pio.pset( "ardlib", badpix_param[detnam], obsinfo.bpx )

def main_task( obsinfo, pars, dd ):
    """
    Make the RMF and then the ARF for one spectrum, dd.

    This is the routine handed to the task runner in main().  The work
    is done inside a new_pfiles_environment context whose temporary
    directory is $ASCDS_WORK_PATH, with ardlib set up for the bad pixel
    file first.  Nothing is returned.
    """

    import os
    from ciao_contrib.runtool import new_pfiles_environment as newpf

    workdir = os.environ["ASCDS_WORK_PATH"]

    with newpf(tmpdir=workdir, copyuser=False, ardlib=False):
        verb1("Making responses for {} {} order".format( map_part_to_name(dd.part), dd.order))

        # Setup ardlib
        setup_ardlib_bpix( obsinfo )

        # The RMF has to come first: the ARF is made on its grid
        make_tg_rmf( obsinfo, pars["infile"], dd, pars["outroot"], pars["wvgrid_arf"], pars["wvgrid_chan"] )
        run_fullgarf( obsinfo, dd, pars["outroot"], pars["dafile"], pars["osipfile"] )




#
# Main Routine
#
@lw.handle_ciao_errors( toolname, __revision__)
def main():
    """
    Read the parameters, then make an RMF and an ARF for every
    part/order found in the input PHA file (or asked for with the
    orders parameter), and finally record the tool history in each
    output file.
    """
    from ciao_contrib._tools.taskrunner import TaskRunner

    # Load the parameters
    pars = get_params(toolname, "rw", sys.argv,
        verbose={"set":lw.set_verbosity, "cmd":verb1} )

    # Turn outroot into a prefix: "<root>_" for a file root, "<dir>/"
    # for a directory.  strip_outroot remembers whether a character
    # was added so that it can be taken off again for the history.
    user_root = pars["outroot"]
    if not os.path.isdir( user_root ):
        pars["outroot"] = user_root+"_"
        strip_outroot = True
    elif not user_root.endswith("/"):
        pars["outroot"] = user_root+"/"
        strip_outroot = True
    else:
        strip_outroot = False

    # Locate aux files
    obsinfo = ObservationFiles( pars["evtfile"], asp=pars["asolfile"],
        msk=pars["mskfile"], dtf=pars["dtffile"], bpx=pars["bpixfile"] )

    # Determine order, and grating arm pairs
    doit = get_order_and_arm( pars["infile"],pars["orders"] )

    # Making the responses can be slow, especially for many orders,
    # so do the clobber checks/removes up front.
    for dd in doit:
        for ff in ("rmf", "arf"):
            existing = get_outfile_name( pars["outroot"], dd, ff )
            fileio.outfile_clobber_checks(pars["clobber"], existing )
            if os.path.exists( existing ):
                os.remove( existing )

    # Create aspect histograms for each chip
    asphist_list = make_asphist( obsinfo, pars["outroot"] )

    try:
        # Create responses
        taskrunner = TaskRunner()
        for dd in doit:
            label = "part={} order={}".format(map_part_to_name(dd.part), dd.order)
            taskrunner.add_task( label, "", main_task, obsinfo, pars, dd )
        set_nproc(pars)
        taskrunner.run_tasks(processes=pars["nproc"])

        ## For some reason, the add_tool_history doesn't work well when
        ## run in parallel (problem with punlearn()).  Since this is quick
        ## we just add a loop at the end here.

        # File names need the prefix; the history gets outroot without
        # the character added above.
        prefix = pars["outroot"]
        if strip_outroot:
            pars["outroot"] = prefix[:-1]

        for dd in doit:
            for ff in ('rmf', 'arf'):
                add_tool_history( get_outfile_name( prefix, dd, ff ),
                    toolname, pars, toolversion=__revision__)

    finally:
        # Always clean up asphist files
        for leftover in asphist_list:
            os.remove(leftover)
    

if __name__ == "__main__":
    # Exit status is 1 when main() raises an Exception, 0 otherwise.
    exit_status = 0
    try:
        main()
    except Exception as err:
        print("\n# {} ({}): ERROR {}\n".format(toolname, __revision__, str(err)), file=sys.stderr)
        exit_status = 1
    sys.exit(exit_status)

    
    

