#!/usr/bin/env python


'Patch "Secondary Science Corruption" (SSC) in HRC data.'


#  SSC for HRC is described in
#    POG \S 7.12;  <https://cxc.harvard.edu/proposer/POG/html/chap7.html#tth_sEc7.12>
#     or  <https://cxc.harvard.edu/proposer/POG/pdf/MPOG.pdf> page 205
#
#   Technical notes:
#     <https://cxc.harvard.edu/contrib/juda/memos/anomaly/sec_sci/index.html>
#     <https://cxc.harvard.edu/contrib/juda/memos/anomaly/sec_sci/byte_shift.html>
#
#   Detailed info on HRC Secondary Science invalid data:
#     <https://sot.github.io/cheta/pseudo_msids.html>
#
#
#  SSC results from a byte-shift anomaly which occasionally causes a
#  portion of the housekeeping data to be corrupted.  The symptom is
#  dropouts in the dead-time-factor (dtf) which can be seen in a plot
#  of the dtf1 file's dtf values vs time, where the dtf is
#  significantly (>10%) below the median value (~1.0).
#
#  In the dtf1 file are several rate columns.  The one that seems to
#  reliably flag SSC is the TOTAL_EVT_COUNT, when it is anomalously
#  high, with values > 4000 (2000 count/s; the default binning is 2s,
#  the dtf file values are integrated over time bins)
#
#  Event data are good during the SSC times.  Standard data
#  processing, however, will create multiple GTI intervals around the
#  low dtf times, rejecting some events, and lowering the
#  dead-time-correction factor (DTCOR).
#
#  Note: telemetry saturation from a bright source can also
#  lower the dtf.  That tends to be a more continuous and sustained
#  lowering of the value.  Hence, inspection of the dtf and
#  total_evt_count vs time is recommended.
#
#  This is a prototype script with a simple positional parameter interface
# (pending conversion to a CIAO version with a standard paramio interface).

# Example:
#
#  patch_hrc_ssc.sh obs_22802/dtf1 obs_22802/mtl1 obs_22802/evt1 evt1_new flt1_new 4000
#
#  dtf1 is the archival level 1 DTF file (input)
#  mtl1 is the archival level 1 MTL file (input)
#  evt1 is the archival Level 1 event file (input)
#
#  evt1_new is the output evt1 (with updated header keywords)
#  flt1_new is the updated FLT file (a GTI table) (output)
#
#  4000  is the threshold for the total_evt_count column, above which
#        we assume SSC is present.  If the value is <1.0, then we
#        assume it is a dtf threshold, below which the dtf is
#        considered bad due to SSC  (input)
#
# The output evt and flt files can then be used in standard
# processing (from hrc_process_events onward)
#

import os
import sys
from tempfile import NamedTemporaryFile
import numpy as np

import ciao_contrib.logger_wrapper as lw
from ciao_contrib.runtool import make_tool
from pycrates import read_file

# Tool name: passed to the logger, the parameter-file loader, the
# error handler and the tool-history record written to the outputs.
__toolname__ = "patch_hrc_ssc"
# Version string handed to the parameter, error and history helpers.
__revision__ = "10 October 2024"

# Set up the logger for this tool before the verb* helpers are created.
lw.initialize_logger(__toolname__)


def log_wrapper(func):
    """Return a logging callable that silently drops empty messages.

    The returned function forwards `msg` to `func` only when `msg` is
    truthy, so ``None`` and ``""`` are never logged.  It always
    returns ``None``.
    """
    def wrapped(msg):
        # Nothing worth logging: skip the call entirely.
        if not msg:
            return
        func(msg)
    return wrapped


# Verbosity-level logging helpers; each one ignores empty/None messages.
_logger = lw.get_logger(__toolname__)
verb0 = log_wrapper(_logger.verbose0)
verb1 = log_wrapper(_logger.verbose1)
verb2 = log_wrapper(_logger.verbose2)


def check_for_bad_values(dtf_in, threshold):
    """Report whether any DTF row exceeds the total_evt_count threshold.

    Returns True when at least one row of `dtf_in` has
    total_evt_count > `threshold`, otherwise False.

    Raises IOError when the file carries an SSCFIX header keyword,
    i.e. it has already been patched.
    """
    flagged = read_file(f"{dtf_in}[total_evt_count>{threshold}]")

    already_patched = flagged.get_key_value("SSCFIX") is not None
    if already_patched:
        raise IOError(f"This file, {dtf_in}, has already been patched and cannot be patched again.")

    return flagged.get_nrows() != 0


def get_median_values(dtf_in, threshold):
    """Return the median DTF and DTF_ERR of the rows considered good.

    Good rows are those of `dtf_in` with total_evt_count <= `threshold`.

    Returns
    -------
    (dtf_median, dtf_err_median)

    Raises
    ------
    ValueError
        If no row satisfies the selection.  The median of an empty
        column is NaN, which would otherwise be written silently into
        the patched files.
    """
    good = read_file(f"{dtf_in}[total_evt_count<={threshold}]")

    if good.get_nrows() == 0:
        raise ValueError(f"No rows in {dtf_in} have total_evt_count<={threshold}; "
                         "cannot determine the median DTF.")

    dtf_median = np.median(good.get_column("DTF").values)
    dtf_err_median = np.median(good.get_column("DTF_ERR").values)

    verb1(f"DTF_MEDIAN = {dtf_median} ({dtf_err_median})")
    return dtf_median, dtf_err_median


def join_dtf_cols_to_mtl(mtl_in, dtf_in, tmpdir):
    """Join TOTAL_EVT_COUNT and DTF_ERR from the DTF file onto the MTL file.

    The join is done on the time column with dmjoin (interpolate="first").
    NOTE: column DTF is already present in the MTL, so only the other
    two columns are taken from the DTF file.

    Parameters
    ----------
    mtl_in : str
        Input MTL file.
    dtf_in : str
        Input DTF file.
    tmpdir : str
        Directory in which the temporary output file is created.

    Returns
    -------
    str
        Name of the temporary joined file; the caller must delete it.
    """
    # Only a unique file name is needed: close the handle right away so
    # no descriptor stays open while dmjoin overwrites the file.
    with NamedTemporaryFile(suffix="_mtl.fits", dir=tmpdir,
                            delete=False) as handle:
        joined = handle.name

    dmjoin = make_tool("dmjoin")
    dmjoin.infile = mtl_in
    dmjoin.joinfile = f"{dtf_in}[cols time,TOTAL_EVT_COUNT,DTF_ERR]"
    dmjoin.outfile = joined
    dmjoin.interpolate = "first"
    dmjoin.join = "time"
    verb2(dmjoin(clobber=True))

    return joined


def patch_dtf(infile, outfile, threshold, smooth_count, median_dtf, median_dtf_err):
    """Replace the DTF and DTF_ERR values of SSC-affected rows with medians.

    The total_evt_count column of `infile` is boxcar-smoothed over
    `smooth_count` rows.  Every row whose smoothed value exceeds
    `threshold` gets dtf=`median_dtf` and dtf_err=`median_dtf_err`.
    The table is then written to `outfile`, clobbering it.

    Parameters
    ----------
    infile : str
        Table with total_evt_count, dtf and dtf_err columns.
    outfile : str
        Output file name.
    threshold : number or numeric str
        Limit on the smoothed total_evt_count.
    smooth_count : int or integer str
        Width, in rows, of the boxcar smoothing kernel (>= 1).
    median_dtf, median_dtf_err : float
        Replacement values for the flagged rows.

    Raises
    ------
    ValueError
        If `smooth_count` is less than 1.
    """

    verb1(f"Patching DTF values in {infile}")

    width = int(smooth_count)
    if width < 1:
        raise ValueError(f"smooth_count must be at least 1, not {smooth_count}")

    tab = read_file(infile)

    total_evt_count = tab.get_column("total_evt_count").values
    dtf = tab.get_column("dtf").values
    dtf_err = tab.get_column("dtf_err").values

    # Normalise by the actual kernel length so the kernel sums to 1.
    kern = np.ones(width) / width
    nrows = len(total_evt_count)

    if width <= nrows:
        tec_smooth = np.convolve(total_evt_count, kern, mode="same")
    else:
        # mode="same" returns max(len(data), len(kernel)) samples, so a
        # kernel longer than the table would give an array that no
        # longer lines up with the rows.  Take the data-centred part of
        # the full convolution instead (same alignment as mode="same"
        # when the kernel is the shorter input).
        full = np.convolve(total_evt_count, kern, mode="full")
        start = (width - 1) // 2
        tec_smooth = full[start:start + nrows]

    idx, = np.where(tec_smooth > float(threshold))

    # The column arrays are edited in place; the code relies on
    # tab.write() picking up these edits.
    dtf[idx] = median_dtf
    dtf_err[idx] = median_dtf_err

    tab.write(outfile, clobber=True)

    # Keep here for reference
    # ~ dmtcalc = make_tool("dmtcalc")
    # ~ dmtcalc.infile = infile
    # ~ dmtcalc.outfile = outfile
    # ~ dmtcalc.expression = f"if(total_evt_count:{smooth_count}>{threshold})then(dtf={median_dtf};dtf_err={median_dtf_err})else(dtf=dtf;dtf_err=dtf_err)"
    # ~ vv = dmtcalc(clobber=True)
    # ~ verb2(vv)


def create_new_limits(mtl_in, tmpdir):
    """Copy the MTL "LIMITS" table, leaving out the HRC "*LV" filter rows.

    Rows are dropped when the value in the first column of the LIMITS
    block contains the substring "LV".  If no such row exists the
    table is copied unchanged.

    Parameters
    ----------
    mtl_in : str
        MTL file containing a LIMITS block.
    tmpdir : str
        Directory in which the temporary output file is created.

    Returns
    -------
    str
        Name of the temporary limits file; the caller must delete it.
    """

    verb1("Creating new GTI limits")

    # Only a unique file name is needed: close the handle right away.
    with NamedTemporaryFile(suffix="_limits.fits", dir=tmpdir,
                            delete=False) as handle:
        limits_out = handle.name

    tab = read_file(f"{mtl_in}[LIMITS]")

    # Row numbers in the filter are 1-based.
    lv_rows = [str(rownum)
               for rownum, name in enumerate(tab.get_column(0).values, start=1)
               if 'LV' in name]

    source = f"{mtl_in}[LIMITS]"
    if lv_rows:
        source += f"[exclude #row={','.join(lv_rows)}]"
    else:
        # An empty "#row=" list would make the filter malformed, so
        # skip the exclude filter altogether.
        verb2("No *LV rows found in LIMITS; copying the table unchanged")

    dmcopy = make_tool("dmcopy")
    dmcopy.infile = source
    dmcopy.outfile = limits_out
    verb2(dmcopy(clobber=True))

    return limits_out


def make_new_gti(mtl_file, gti_out, mod_limits):
    """Run dmgti to write the new flt (GTI) file.

    `mtl_file` is the input, `gti_out` the output, and `mod_limits`
    (the edited "limits" table) is used as the lookup file.
    """

    verb1("Make new flt file...")

    gti_tool = make_tool("dmgti")
    gti_tool.lkupfile = mod_limits
    gti_tool.outfile = gti_out
    gti_tool.infile = mtl_file

    verb2(gti_tool(clobber=True))


def compute_dtcor(dtf_file, gti_file, tmpdir):
    """Re-run hrc_dtfstats and return the new DTCOR value.

    Parameters
    ----------
    dtf_file : str
        DTF file given to hrc_dtfstats as infile.
    gti_file : str
        GTI file given to hrc_dtfstats as gtifile.
    tmpdir : str
        Directory for the temporary statistics file.

    Returns
    -------
    The last value of the "dtcor" column in the hrc_dtfstats output.
    """

    verb1("Recomputing DTF stats")

    # Only a unique file name is needed: close the handle right away.
    with NamedTemporaryFile(suffix="_dtfstats.fits", dir=tmpdir,
                            delete=False) as handle:
        stats_file = handle.name

    try:
        dtfstats = make_tool("hrc_dtfstats")
        dtfstats.infile = dtf_file
        dtfstats.outfile = stats_file
        dtfstats.gtifile = gti_file
        verb2(dtfstats(clobber=True))

        new_dtcor = read_file(stats_file).get_column("dtcor").values[-1]
    finally:
        # Remove the scratch file even when the tool or the read fails.
        if os.path.exists(stats_file):
            os.unlink(stats_file)

    verb1(f"New DTCOR={new_dtcor}")
    return new_dtcor


def update_events_dtcor(evt_in, evt_out, dtcor, tmpdir, fltfile, dtffile):
    """Write `evt_out` as a copy of `evt_in` with updated header keywords.

    The header is edited on a temporary copy, so `evt_in` is never
    modified.  The keywords set are DTCOR (`dtcor`), FLTFILE and
    DTFFILE (the base names of `fltfile` and `dtffile`).

    Parameters
    ----------
    evt_in : str
        Input event file.
    evt_out : str
        Output event file (clobbered).
    dtcor
        New dead-time correction value.
    tmpdir : str
        Directory for the temporary copy.
    fltfile, dtffile : str
        Paths of the new flt and dtf files.
    """

    verb1("Updating DTCOR in event file and recomputing EXPOSURE time")

    # Only a unique file name is needed: close the handle right away.
    with NamedTemporaryFile(suffix="_evt.fits", dir=tmpdir,
                            delete=False) as handle:
        scratch = handle.name

    try:
        dmcopy = make_tool("dmcopy")
        verb2(dmcopy(evt_in, scratch, clobber=True))
        verb2("Copy 1 done")

        dmhedit = make_tool("dmhedit")
        dmhedit(scratch, file="", op="add", key="DTCOR", value=dtcor)
        dmhedit(scratch, file="", op="add", key="FLTFILE",
                value=os.path.basename(fltfile))
        dmhedit(scratch, file="", op="add", key="DTFFILE",
                value=os.path.basename(dtffile))
        verb2("Update done")

        # NOTE(review): the "[time=:]" filter on this second copy is
        # presumably what triggers the EXPOSURE recomputation mentioned
        # in the log message above -- confirm.
        verb2(dmcopy(f"{scratch}[time=:]", evt_out, clobber=True))
        verb2("Copy 2 done")
    finally:
        # Remove the scratch copy even when one of the tools fails.
        if os.path.exists(scratch):
            os.unlink(scratch)


def add_metadata(outfile, pars):
    """Flag `outfile` as SSC-affected and patched, and record the tool run.

    Adds the boolean header keywords SSC=T and SSCFIX=T with dmhedit,
    then calls add_tool_history with this tool's name, the parameters
    in `pars` and the tool revision.

    TBR: a single keyword could be used instead of two, e.g.
    SSC=0 (not present), SSC=1 (present), SSC=2 (present, patched).
    """
    from ciao_contrib.runtool import add_tool_history

    dmhedit = make_tool("dmhedit")
    for keyword in ("SSC", "SSCFIX"):
        dmhedit(outfile, file="", op="add", key=keyword, value="T",
                datatype="boolean")

    add_tool_history(outfile, __toolname__, pars, toolversion=__revision__)


def get_parameters():
    """Read the tool parameters and run the clobber checks on the outputs.

    Parameters are loaded with get_params from the command line
    (sys.argv) in "rw" mode.  Each of the three output files
    (evt_outfile, gti_outfile, dtf_outfile) is then passed to
    outfile_clobber_checks together with the clobber setting.

    Returns the parameter dictionary.
    """
    from ciao_contrib.param_soaker import get_params

    verbose_hooks = {"set": lw.set_verbosity, "cmd": verb1}
    pars = get_params(__toolname__, "rw", sys.argv,
                      verbose=verbose_hooks,
                      revision=__revision__)

    from ciao_contrib._tools.fileio import outfile_clobber_checks
    clobber = pars["clobber"]
    for key in ("evt_outfile", "gti_outfile", "dtf_outfile"):
        outfile_clobber_checks(clobber, pars[key])

    return pars


@lw.handle_ciao_errors(__toolname__, __revision__)
def main():
    """Main routine: detect SSC and, if present, write patched products.

    Steps:
      1. Load the parameters and check the DTF file for rows above the
         total_evt_count threshold; stop if there are none.
      2. Compute the median DTF/DTF_ERR of the good rows and write the
         patched DTF file.
      3. Join the DTF columns onto the MTL file, patch that copy the
         same way, and build the new GTI (flt) file from it using a
         LIMITS table without the "*LV" rows.
      4. Recompute DTCOR from the patched DTF and the new GTI file.
      5. Write the new event file with updated keywords, and tag both
         the DTF and event outputs with the SSC metadata.
    """

    pars = get_parameters()

    if check_for_bad_values(pars["dtf_infile"], pars["threshold"]) is False:
        verb0("SSC not detected, no action required.")
        return

    verb1("SSC detected; patching")

    median_dtf, median_dtf_err = get_median_values(pars["dtf_infile"],
                                                   pars["threshold"])

    patch_dtf(pars["dtf_infile"], pars["dtf_outfile"], pars["threshold"],
              pars["smooth_count"], median_dtf, median_dtf_err)

    # Scratch files made while building the GTI; they are removed in
    # the finally clause so a failing step does not leave them behind.
    scratch = []
    try:
        mod_mtl = join_dtf_cols_to_mtl(pars["mtl_infile"],
                                       pars["dtf_infile"],
                                       pars["tmpdir"])
        scratch.append(mod_mtl)

        # Only a unique file name is needed: close the handle right away.
        with NamedTemporaryFile(suffix="_mtl.fits",
                                dir=pars["tmpdir"],
                                delete=False) as handle:
            patched_mtl = handle.name
        scratch.append(patched_mtl)

        patch_dtf(mod_mtl, patched_mtl, pars["threshold"],
                  pars["smooth_count"], median_dtf, median_dtf_err)

        mod_limits = create_new_limits(pars["mtl_infile"], pars["tmpdir"])
        scratch.append(mod_limits)

        make_new_gti(patched_mtl, pars["gti_outfile"], mod_limits)
    finally:
        for name in scratch:
            if os.path.exists(name):
                os.unlink(name)

    dtcor = compute_dtcor(pars["dtf_outfile"], pars["gti_outfile"],
                          pars["tmpdir"])

    add_metadata(pars["dtf_outfile"], pars)

    update_events_dtcor(pars["evt_infile"], pars["evt_outfile"], dtcor,
                        pars["tmpdir"], pars["gti_outfile"],
                        pars["dtf_outfile"])

    add_metadata(pars["evt_outfile"], pars)


# Run the tool only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
