#!/usr/bin/env python
#
# Copyright (C) 2014, 2016, 2019, 2020, 2025
# Smithsonian Astrophysical Observatory
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
#

'Convert type:II pha into type:I files and associate response files'

import os
import sys
import numpy as np

from ciao_contrib.runtool import make_tool
from pycrates import read_file, write_file
import stk

import ciao_contrib.logger_wrapper as lw


# Tool name and revision string; both are handed to the CIAO error handler
# that wraps main() and to the tool-history records written to the outputs.
__toolname__ = "tgsplit"
__revision__ = "06 February 2025"

# Initialise the tool's logger and bind its verbose0/1/2 methods to short
# names that are used for all messages in this file.
lw.initialize_logger(__toolname__)
verb0 = lw.get_logger(__toolname__).verbose0
verb1 = lw.get_logger(__toolname__).verbose1
verb2 = lw.get_logger(__toolname__).verbose2

# Maps the numeric tg_part value to the label used (lower-cased) in output
# file names and (as is) in log messages.
tgpart = {0: 'ZO', 1: 'HEG', 2: 'MEG', 3: 'LEG'}


def get_unique_id(infile):
    """
    Build the identifier used to pair spectra with their ARFs.

    The identifier is a named tuple holding the obs_id, obi_num, cycle,
    tg_m and tg_part header keywords of `infile`.
    """
    from collections import namedtuple

    # tg_srcid is not reliable in response files, so it is left out.
    fields = ['obs_id', 'obi_num', 'cycle', 'tg_m', 'tg_part']

    crate = read_file(infile, "r")
    values = {name: crate.get_key_value(name) for name in fields}

    # Tool products wrongly set CYCLE='P' for CC mode; treat it as unset.
    if crate.get_key_value("READMODE") == 'CONTINUOUS':
        values['cycle'] = None

    Config = namedtuple('Config', fields)
    return Config(**values)


def update_ontime(pha):
    """
    Replace ONTIME/LIVETIME/EXPOSURE in a PHA file with per-arm averages.

    The ONTIME/EXPOSURE/LIVETIME keywords in the PHA file are inherited
    from the pha2 file, which takes them from the aim CCD.

    In some rare cases that value may be very different from the ONTIME
    of the other CCDs. The gratings team has decided that the average
    ONTIME *for the chips the grating arm covers* should be used.

    This only applies to ACIS, and it assumes that the aimpoint is on
    ACIS-7.

    The file is edited in place. A RuntimeError is raised if the TG_M
    keyword is 0.
    """

    crate = read_file(pha, mode="rw")

    # Nothing to do for non-ACIS data; the file is left untouched.
    if crate.get_key_value("INSTRUME") != "ACIS":
        return

    tg_m = crate.get_key_value("tg_m")
    deadtime = crate.get_key_value("dtcor")

    # Chips covered by the positive and negative order arms.
    if tg_m > 0:
        chips = [4, 5, 6, 7]
    elif tg_m < 0:
        chips = [7, 8, 9]
    else:
        raise RuntimeError("The order should not be 0")

    # Keep only the chips with a usable (set and non-zero) ONTIME value.
    good = []
    for chip in chips:
        chip_ontime = crate.get_key_value(f"ONTIME{chip}")
        if chip_ontime:
            good.append(chip_ontime)

    if good:
        mean_ontime = np.average(good)
        crate.get_key("ontime").value = mean_ontime
        crate.get_key("livetime").value = mean_ontime * deadtime
        crate.get_key("exposure").value = mean_ontime * deadtime

    # The dataset is written back even when no keyword was changed.
    crate.get_dataset().write()

    return


def copy_infile(infile, tgm, tgp, root, clobber):
    """
    Copy a type:I pha file to the tgsplit output name using dmcopy.

    The output is named <root><arm>_<p|m><|order|>.pha; the name of the
    file that was created (without the block specifier) is returned.
    """
    copier = make_tool("dmcopy")

    # 'p' for positive orders, 'm' for everything else.
    if tgm > 0:
        sign_tag = 'p'
    else:
        sign_tag = 'm'
    grating = tgpart[tgp].lower()

    copier.infile = infile
    copier.outfile = f"{root}{grating}_{sign_tag}{abs(tgm)}.pha[SPECTRUM]"
    copier.clobber = clobber

    output = copier()
    if output:
        verb2(output)

    created = copier.outfile
    verb1(f"Created source spectrum: {created}")
    return created.replace("[SPECTRUM]", "")


def split_pha2(tgm, tgp, infile, root, clobber):
    """
    Extract one order/part from the pha2 file into its own source spectrum.

    dmtype2split writes the spectrum and dmappend then adds the matching
    rows of the REGION block. The name of the new file (without the block
    specifier) is returned.
    """
    splitter = make_tool("dmtype2split")

    # 'p' for positive orders, 'm' for everything else.
    if tgm > 0:
        sign_tag = 'p'
    else:
        sign_tag = 'm'
    grating = tgpart[tgp].lower()
    row_filter = f"[tg_m={tgm}, tg_part={tgp}]"

    splitter.infile = f"{infile}{row_filter}"
    splitter.outfile = f"{root}{grating}_{sign_tag}{abs(tgm)}.pha[SPECTRUM]"
    splitter.clobber = clobber
    output = splitter()
    if output:
        verb2(output)

    created = splitter.outfile
    spectrum = created.replace("[SPECTRUM]", "")

    # Append the region rows for this order/part to the new file.
    appender = make_tool("dmappend")
    appender.infile = f"{infile}[REGION][subspace -time]{row_filter}"
    output = appender(outfile=spectrum)
    if output:
        verb0(output)

    verb1(f"Created source spectrum: {created}")
    return spectrum


def split_bkg_tge(infile, clobber):
    """
    Split the background from tgextract output.

    The background_up and background_down columns are summed into the
    counts column, and the BACKSCUP and BACKSCDN keywords are summed into
    BACKSCAL, of a new "<name>_bkg<ext>" file. The background columns and
    keywords are then removed from the source spectrum, which is rewritten
    in place.

    In DPH's 'tg_bkg' he deletes the stat_err column from the background
    and sets the POISSERR = True, so we do the same here.

    Parameters
    ----------
    infile : str
        Name of the type:I source spectrum; it is modified in place.
    clobber : bool
        Passed to write_file when creating the background file.

    Returns
    -------
    str
        Name of the background spectrum that was written.
    """

    update_ontime(infile)

    tab = read_file(infile, "r")

    # sum counts, sum backscale values
    bgcts = (tab.get_column("background_up").values
             + tab.get_column("background_down").values)
    bgscl = tab.get_key_value('BACKSCUP') + tab.get_key_value('BACKSCDN')

    # save values
    tab.get_column("counts").values = bgcts
    tab.get_key("backscal").value = bgscl
    tab.get_key("POISSERR").value = True

    # delete bkg columns/keys
    for col in ("background_up", "background_down", "stat_err"):
        tab.delete_column(col)
    for key in ("backscup", "backscdn"):
        tab.delete_key(key)

    # Insert "_bkg" before the final extension only. A plain
    # str.replace(".pha", ...) would also rewrite any ".pha" found in the
    # directory part or in the output root.
    stem, ext = os.path.splitext(infile)
    outfile = f"{stem}_bkg{ext}"
    write_file(tab, outfile, clobber=clobber)

    verb1(f"Created background spectrum: {outfile}")

    # Hack to delete the background columns from the source file: re-read
    # it, drop the columns/keys, write a temporary copy and move that over
    # the original.
    tab = read_file(infile, "r")
    for col in ("background_up", "background_down"):
        tab.delete_column(col)
    for key in ("backscup", "backscdn"):
        tab.delete_key(key)
    tab.get_dataset().write(infile + ".tmp", clobber=True)

    os.rename(infile + ".tmp", infile)
    return outfile


def split_bkg_tge2(infile, clobber):
    """
    Split the background from tgextract2 output.

    The bg_counts and bg_area columns are copied into the counts and
    backscal columns of a new "<name>_bkg<ext>" file. The background
    columns are then removed from the source spectrum, which is rewritten
    in place.

    In DPH's 'tg_bkg' he deletes the stat_err column from the background
    and sets the POISSERR = True, so we do the same here.

    Parameters
    ----------
    infile : str
        Name of the type:I source spectrum; it is modified in place.
    clobber : bool
        Passed to write_file when creating the background file.

    Returns
    -------
    str
        Name of the background spectrum that was written.
    """
    bkg_columns = ("bg_counts", "bg_area", "bg_err")

    update_ontime(infile)

    tab = read_file(infile, "r")

    # background counts and area are taken as is (no summing is needed)
    bgcts = tab.get_column("bg_counts").values
    bgscl = tab.get_column("bg_area").values

    # save values
    tab.get_column("counts").values = bgcts
    tab.get_column("backscal").values = bgscl
    tab.get_key("POISSERR").value = True

    # delete bkg columns, skipping any that are not present
    for cc in bkg_columns:
        if tab.column_exists(cc):
            tab.delete_column(cc)

    # Insert "_bkg" before the final extension only. A plain
    # str.replace(".pha", ...) would also rewrite any ".pha" found in the
    # directory part or in the output root.
    stem, ext = os.path.splitext(infile)
    outfile = f"{stem}_bkg{ext}"
    write_file(tab, outfile, clobber=clobber)

    verb1(f"Created background spectrum: {outfile}")

    # Hack to delete columns from the source spectrum.
    tab = read_file(infile, "r")
    for cc in bkg_columns:
        if tab.column_exists(cc):
            tab.delete_column(cc)

    # use dataset write to copy region block
    tab.get_dataset().write(infile + ".tmp", clobber=True)
    os.rename(infile + ".tmp", infile)

    return outfile


def get_respfile_from_arf(infile):
    """
    Try to locate the RMF using the information in the ARF.

    The grating ARFs have a 'RESPFILE' keyword that has
    the grid used to construct the arf.

    It looks like

        grid(/maybe/some/path/awesome.rmf[ebounds][cols bin_lo, bin_hi])

    so need to strip out just the file name part of it.

    Assume to be in same dir as the ARF.

    Parameters
    ----------
    infile : str
        Name of the ARF file.

    Returns
    -------
    str
        The RMF file name taken from RESPFILE, placed in the directory
        of `infile`.

    Raises
    ------
    ValueError
        If the RESPFILE keyword is missing or is not of the form
        'grid(...)'.
    """
    tab = read_file(infile, "r")
    grid = tab.get_key_value("respfile")

    # Without this check a missing keyword fails on the split() below
    # with an AttributeError that does not say which file is at fault.
    if not isinstance(grid, str):
        raise ValueError(f"Missing RESPFILE keyword in {infile}")

    isgrid = grid.split("(")
    if 2 != len(isgrid) or isgrid[0].lower() != 'grid':
        raise ValueError(f"Unexpected RESPFILE found in {infile}")

    # strip trailing ")", remove DM filter,
    ff = isgrid[1].split(")")[0].split("[")[0]

    return os.path.join(os.path.dirname(infile), os.path.basename(ff))


def update_src_headers(src, bkg, arf, rmf):
    """
    Record the background and response file names in a spectrum file.

    The BACKFILE, ANCRFILE and RESPFILE keywords of `src` are set from
    `bkg`, `arf` and `rmf`. A name in the same directory as `src` is
    stored without its path, and a blank name leaves the keyword alone.
    """
    src_dir = os.path.dirname(src)
    crate = read_file(src, "rw")

    keywords = (("backfile", bkg), ("ancrfile", arf), ("respfile", rmf))
    for keyword, filename in keywords:
        # Drop the directory when the file sits next to the spectrum.
        if os.path.dirname(filename) == src_dir:
            stored = os.path.basename(filename)
        else:
            stored = filename

        if stored.strip():
            crate.get_key(keyword).value = stored

    crate.get_dataset().write()


def check_infile(tab, infile):
    """
    Check that the input looks like a PHA file and report its type.

    The names tg_m, tg_part, counts and channel must each exist as a
    column or a keyword, otherwise an IOError is raised. Returns
    "TYPE: II" when the counts column is an array column and "TYPE: I"
    when it is not.
    """

    # Column and keyword names are compared case-insensitively.
    available = {name.lower() for name in tab.get_colnames()}
    available.update(name.lower() for name in tab.get_keynames())

    for required in ("tg_m", "tg_part", "counts", "channel"):
        if required not in available:
            raise IOError(f"Missing '{required}' column in '{infile}'")

    counts = tab.get_column("counts").values
    if len(counts.shape) != 1:
        verb2("Valid TYPE: II input")
        return "TYPE: II"

    verb1(f"The file '{infile}' does not appear to be a TYPE: II spectrum.  "
          "The 'COUNTS' column is not an array column.")
    return "TYPE: I"


def split_pha2_and_background(infile, root, clobber):
    """
    Create the source and background PHA files for every order/part.

    Returns two dictionaries, both keyed by the identifier from
    get_unique_id: one for the source spectra and one for the background
    spectra. The background dictionary is empty when the input has no
    recognised background columns.
    """

    crate = read_file(infile, "r")
    ftype = check_infile(crate, infile)

    if ftype == "TYPE: II":
        # One output spectrum per row of the pha2 file.
        orders = crate.get_column("tg_m").values
        parts = crate.get_column("tg_part").values
        src = []
        for order, part in zip(orders, parts):
            src.append(split_pha2(order, part, infile, root, clobber))
    else:
        # Already a single spectrum: order and part come from the header.
        order = crate.get_key("tg_m").value
        part = crate.get_key("tg_part").value
        src = [copy_infile(infile, order, part, root, clobber)]

    srcmap = [get_unique_id(name) for name in src]
    srcdict = dict(zip(srcmap, src))

    # Within a single pha2 file each order/part pair should appear once.
    if len(set(srcmap)) != len(srcmap):
        raise IOError("ERROR: Duplicate row found in infile!")

    # The column names tell us whether this is tgextract or tgextract2
    # output, and so which background splitter to use.
    colnames = [name.lower() for name in crate.get_colnames()]
    if 'background_up' in colnames and 'background_down' in colnames:
        splitter = split_bkg_tge
    elif 'bg_counts' in colnames and 'bg_area' in colnames:
        splitter = split_bkg_tge2
    else:
        verb0("WARNING: Cannot determine background from PHA file. Background will not be included.")
        return srcdict, {}

    bkg = [splitter(name, clobber) for name in src]
    bkgmap = [get_unique_id(name) for name in bkg]

    return srcdict, dict(zip(bkgmap, bkg))


def identify_response_files(arffile, rmffile):
    """
    Try to identify response files.

    ARF files have the necessary meta data (obsid, obi, cycle, tg_part, tg_m).
    but the RMFs do not.

    The ARF file has a 'RESPFILE' keyword that should point to the
    RMF filename, that needs to be parsed.

    Parameters
    ----------
    arffile : str
        Stack of ARF file names.
    rmffile : str or None
        Stack of RMF file names; may be unset (see is_none).

    Returns
    -------
    (dict, dict)
        ARF and RMF file names, both keyed by the identifier built from
        the ARF header. Either dictionary is empty when the files could
        not be read or parsed; a warning is logged in that case.

    Raises
    ------
    IOError
        If an RMF named in an ARF is not in the rmffile stack.
    """

    # First deal with ARFs
    try:
        arf = stk.build(arffile)
        arfmap = [get_unique_id(a) for a in arf]
    except Exception as e:
        # Best effort: carry on without responses, but say what failed
        # through the tool logger rather than a bare print.
        verb0(f"WARNING: Unable to identify the ARF files in '{arffile}': {e}")
        return {}, {}
    arfdict = dict(zip(arfmap, arf))

    # Now try to match RMFs
    if is_none(rmffile):
        return arfdict, {}

    try:
        _rmf = [get_respfile_from_arf(a) for a in arf]
        rmf = set(stk.build(rmffile))
    except Exception as e:
        verb0(f"WARNING: Unable to identify the RMF files in '{rmffile}': {e}")
        return arfdict, {}

    # Check that the file name in ARF file matches one of those
    # in the input rmffile stack (optionally gzip-ed).
    miss = [x for x in _rmf if x not in rmf and x + ".gz" not in rmf]
    if miss:
        raise IOError(f"Cannot locate these RMF files: {', '.join(miss)}")

    rmfdict = dict(zip(arfmap, _rmf))  # yes arfmap, rmf missing hdr info

    return arfdict, rmfdict


def is_none(infile):
    """
    Return True when `infile` is unset: None, "", or "none" in any case.
    """
    # Any falsy value (None, "") counts as unset; otherwise compare the
    # lower-cased text against the accepted spellings.
    return not infile or infile.lower() in ("none", "")


def match_config(config, files, flavor, infile):
    """
    Look up the file that goes with a spectrum configuration.

    Returns "" when `infile` is blank, the entry of `files` stored under
    `config` when there is one, and "NONE" otherwise. A message is logged
    when `files` is not empty but has no entry for `config`.
    """

    if not infile.strip():
        return ""

    if config in files:
        return files[config]

    # Nothing was supplied at all, so there is nothing to report.
    if not files:
        return "NONE"

    order = config.tg_m
    part = config.tg_part
    sign = '+' if order > 0 else '-'
    label = tgpart[part]
    verb1(f"Cannot find {flavor} file for {label} {sign}{abs(order)}")
    return "NONE"


@lw.handle_ciao_errors(__toolname__, __revision__)
def main():
    """
    Main routine.

    Reads the tool parameters, splits the input into source and
    background spectra, adds the tool history to each output, and then
    records the matching background/ARF/RMF file names in the headers.

    Raises ValueError when rmffile is set but arffile is not.
    """
    from ciao_contrib.param_soaker import get_params
    from ciao_contrib.runtool import add_tool_history

    pars = get_params(__toolname__, "rw", sys.argv,
                      verbose={"set": lw.set_verbosity, "cmd": verb2})

    infile = pars["infile"]
    arffile = pars["arffile"]
    rmffile = pars["rmffile"]
    root = pars["outroot"]
    clobber = pars["clobber"]

    # Reject an inconsistent response setup before any file is written.
    no_arf = is_none(arffile)
    no_rmf = is_none(rmffile)
    if no_arf and not no_rmf:
        raise ValueError("arffile must be set if rmffile is set")

    if os.path.isdir(root):
        # Joining with "" adds a trailing separator only when it is
        # missing, so "outdir" and "." place the files inside the
        # directory instead of producing "outdirheg_p1.pha" or
        # ".heg_p1.pha".
        root = os.path.join(root, "")
    else:
        root = root + "_"

    srcdict, bkgdict = split_pha2_and_background(infile, root, clobber)
    for outfile in srcdict.values():
        add_tool_history(outfile, __toolname__, pars, toolversion=__revision__)
    for outfile in bkgdict.values():
        add_tool_history(outfile, __toolname__, pars, toolversion=__revision__)

    if no_arf and no_rmf:
        arfdict = {}
        rmfdict = {}
    else:
        arfdict, rmfdict = identify_response_files(arffile, rmffile)

    #
    # Save arf/rmf info in src file.
    #
    for ss in srcdict:
        aa = match_config(ss, arfdict, "ARF", arffile)
        rr = match_config(ss, rmfdict, "RMF", rmffile)
        bb = match_config(ss, bkgdict, "background", "does_not_apply")
        update_src_headers(srcdict[ss], bb, aa, rr)

        # background file has no responses nor background
        if not is_none(bb):
            update_src_headers(bb, "NONE", "NONE", "NONE")


# Run the tool only when this file is executed as a script.
if __name__ == "__main__":
    main()
