#!/usr/bin/env python

# Copyright (C) 2024,2025 Smithsonian Astrophysical Observatory
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
#

'Compute and apply fine astrometric correction to stack of observations'

import os
import sys
from tempfile import TemporaryDirectory
from dataclasses import dataclass
import shutil

import ciao_contrib.logger_wrapper as lw
from ciao_contrib.runtool import make_tool
from pycrates import read_file


# Tool name and revision string; used below for the logger, parameter
# loading, error handling, and the history records added to output files.
__toolname__ = "fine_astro"
__revision__ = "02 June 2025"

# Thresholds, in pixels, above which summarize_xmatch prints a warning:
# WARN_LARGE_SHIFT is compared against the X and Y shifts separately,
# WARN_LARGE_OFFSET against the combined offset hypot(x, y).
WARN_LARGE_SHIFT = 3.0
WARN_LARGE_OFFSET = 4.0

# Register the logger that the verb0/verb1/verb2 helpers write to.
lw.initialize_logger(__toolname__)

# TODO: CC-mode
# TODO: tg_zo
# TODO: clobber


def log_wrapper(func):
    """Wrap a logging callable so that empty messages are dropped.

    The returned function forwards ``msg`` to ``func`` only when ``msg``
    is truthy; ``None`` and the empty string are silently ignored.
    """
    def wrapped(msg):
        if not msg:
            return
        func(msg)
    return wrapped


# Module-level logging shortcuts, one per verbosity level.  Each is passed
# through log_wrapper so that None or empty messages are not logged.
verb0 = log_wrapper(lw.get_logger(__toolname__).verbose0)
verb1 = log_wrapper(lw.get_logger(__toolname__).verbose1)
verb2 = log_wrapper(lw.get_logger(__toolname__).verbose2)


@dataclass
class OutputDirs:
    """Names of the directories and file roots the tool writes to.

    The first five fields are sub-directory names that get joined onto
    ``root_dir``.  ``root_dir`` and ``root_base`` hold the directory and
    file-name parts of the user's ``outroot`` parameter; they are filled
    in by ``main`` and default to the empty string.
    """
    # Sub-directory for the pre-detect merge_obs products
    pre_detect: str = "pre"
    # Sub-directory for the wavdetect source lists
    detect: str = "detect"
    # Sub-directory for the wcs_match transform files
    xmatch: str = "xmatch"
    # Sub-directory for the final merge_obs products
    final: str = "post"
    # Sub-directory for the corrected event/asol files ("" = root_dir itself)
    fine_astro: str = ""
    # Directory part of outroot
    root_dir: str = ""
    # File-name part of outroot
    root_base: str = ""


def run_merge_obs(evtlist, out_root, pars, binsize=1):
    """Run merge_obs on a set of event files.

    Parameters
    ----------
    evtlist : event files (optionally with filters) given to ``infile``
    out_root : value for the merge_obs ``outroot`` parameter
    pars : tool parameters; ``parallel``, ``nproc``, ``clobber``,
        ``cleanup`` and ``tmpdir`` are passed straight through
    binsize : image binning factor

    Returns
    -------
    The ``outroot`` value held by the tool after it has run.
    """
    tool = make_tool("merge_obs")

    # Inputs and outputs
    tool.infile = evtlist
    tool.outroot = out_root

    # Fixed imaging choices
    tool.bin = binsize
    tool.psfecf = 0.5
    tool.band = 'default'

    # Settings copied unchanged from the user's parameters
    for name in ("parallel", "nproc", "clobber", "cleanup", "tmpdir"):
        setattr(tool, name, pars[name])

    verb2(tool())
    return tool.outroot


def run_wavdetect(pars):
    """Run wavdetect on the merge_obs products of one observation.

    ``pars`` is a single 6-tuple
    ``(inroot, outroot, obsid, scales, src_filter, tmpdir)`` so the
    function can be used with a one-argument map (see detect_sources).

    Returns the name of the source list: the raw wavdetect output when
    ``src_filter`` is empty, otherwise a dmcopy-filtered copy of it.

    Raises ValueError if the instrument is neither ACIS nor HRC, or if
    an expected image, exposure map or PSF map is missing.
    """

    inroot, outroot, obsid, scales, src_filter, tmpdir = pars

    def require(path):
        'Return path unchanged, or raise if it does not exist'
        if not os.path.exists(path):
            raise ValueError(f"Error: cannot locate file '{path}'")
        return path

    obi = str(obsid.obsid)
    verb1(f"Running wavdetect on {obi}")

    # Energy-band label used in the per-observation file names
    band_for = {"ACIS": "broad", "HRC": "wide"}
    if obsid.instrument not in band_for:
        raise ValueError("Bad instrument")

    stem = f"{inroot}_{obi}_{band_for[obsid.instrument]}_thresh"
    img, exp, psf = (require(f"{stem}.{ext}")
                     for ext in ("img", "expmap", "psfmap"))

    import ciao_contrib.runtool as rt
    with rt.new_pfiles_environment():

        detect = make_tool("wavdetect")
        detect.infile = img
        detect.expfile = exp
        detect.psffile = psf
        detect.outfile = f"{outroot}_{obi}.src"
        detect.scales = scales
        detect.clobber = True

        # The intermediate products all go into a directory that is
        # removed as soon as wavdetect finishes.
        with TemporaryDirectory(suffix="_wav", dir=tmpdir) as workdir:
            detect.interdir = workdir
            detect.scellfile = f"{workdir}/{obi}.cell"
            detect.defnbkg = f"{workdir}/{obi}.bkg"
            detect.image = f"{workdir}/{obi}.recon"
            verb2(detect())

    raw_list = detect.outfile

    # NOTE(review): the unfiltered path checks against obsid.obsid.obsid
    # while the filtered path checks against the string form; confirm
    # both compare as intended with the OBS_ID keyword.
    if not src_filter:
        check_srclist(raw_list, obsid.obsid.obsid)
        return raw_list

    dmcopy = make_tool("dmcopy")
    dmcopy(f"{raw_list}{src_filter}", f"{outroot}_{obi}_filtered.src", clobber=True)

    filtered = dmcopy.outfile
    check_srclist(filtered, obi)
    return filtered


def check_srclist(srclist, obsid=None):
    """Validate a source list file.

    The file must be readable and have ``ra`` and ``dec`` columns.  When
    ``obsid`` is given and the file has an OBS_ID keyword, the two must
    be equal.

    Raises ValueError for an empty name, missing columns, or an OBS_ID
    mismatch; errors from read_file are passed through.
    """
    if not srclist:
        raise ValueError("No source list found")

    # read_file also verifies the file exists and can be opened
    tab = read_file(srclist)

    try:
        for colname in ("ra", "dec"):
            tab.get_column(colname)
    except ValueError as nocol:
        raise ValueError(f"Source list {srclist} is missing columns RA or Dec") from nocol

    if obsid is None:
        return

    # A file without an OBS_ID keyword is accepted as-is
    src_obi = tab.get_key_value("OBS_ID")
    if src_obi is not None and src_obi != obsid:
        raise ValueError(f"OBS_ID in {srclist}, {src_obi}, does not match OBS_ID in event file: {obsid}")


def get_ref_srclist(obis, reflist, srcs):
    """Choose the reference source list for the cross match.

    Parameters
    ----------
    obis : observation objects providing ``get_keyword("EXPOSURE")``
    reflist : user-supplied reference source list; ``None``, ``""``,
        ``"none"`` and ``"indef"`` (any case) all mean "not set"
    srcs : source lists, paired one-to-one with ``obis``

    Returns
    -------
    ``reflist`` itself when it is set (after check_srclist has accepted
    it), otherwise the entry of ``srcs`` whose observation has the
    largest EXPOSURE.

    Raises
    ------
    ValueError if no reference can be chosen (no observations, or none
    with a positive exposure), or if check_srclist rejects ``reflist``.
    """
    # Normalise the "not set" spellings; guard against None before .lower()
    if reflist is None or reflist.lower() in ("indef", "none"):
        reflist = ""

    if reflist:
        check_srclist(reflist)
        return reflist

    # No user choice: take the source list from the longest observation.
    longest = 0
    chosen = None
    for obi, src in zip(obis, srcs):
        exposure = obi.get_keyword("EXPOSURE")
        if exposure > longest:
            longest = exposure
            chosen = src

    if chosen is None:
        raise ValueError("Cannot identify reference source list")

    verb1(f"Using {chosen} for the reference source list")
    return chosen


def _warning(msg):
    """Log a warning at verbose level 1.

    When stdout is a terminal and the NO_COLOR environment variable is
    not set, the message is prefixed with "Warning: " and wrapped in
    ANSI colour codes; otherwise it is logged unchanged.
    """
    # NOTE(review): the "Warning: " prefix is only added on the coloured
    # path - confirm that is intended.
    use_colour = sys.stdout.isatty() and os.getenv('NO_COLOR') is None
    verb1(f"\033[1;33mWarning: {msg}\033[0;0m" if use_colour else msg)


def summarize_xmatch(verbose, infile, outfile):
    """Log a one-line summary of a wcs_match run and warn on large shifts.

    Parameters
    ----------
    verbose : screen output of wcs_match; its last line containing
        'sources remain' provides the number of matched sources
    infile : source list that was matched (only its base name is shown)
    outfile : transform file written by wcs_match; the first row of its
        ``t1`` and ``t2`` columns gives the X and Y shifts

    Raises
    ------
    ValueError if the source count cannot be extracted from ``verbose``.
    """
    from math import hypot

    def parse_verbose():
        'Extract the number of sources used from the tool output'
        # Explicit checks rather than assert, so they survive python -O
        matches = [line for line in (verbose or "").split("\n")
                   if 'sources remain' in line]
        if not matches:
            raise ValueError("Problem parsing verbose output")

        # Expected form: "After deleting poor matches, N sources remain"
        tokens = matches[-1].split(" ")
        if len(tokens) != 7:
            raise ValueError("Did verbose output format change?")
        return int(tokens[4])

    def get_shifts():
        'Read the X and Y shifts from the transform file'
        tab = read_file(outfile)
        return (tab.get_column("t1").values[0],
                tab.get_column("t2").values[0])

    def warn_large_offsets(xoff, yoff):
        'Warn when a shift exceeds the module-level limits'
        if abs(xoff) > WARN_LARGE_SHIFT:
            _warning(f"X shift = {xoff:6.3f} exceeds warning limit of {WARN_LARGE_SHIFT} pixels")

        if abs(yoff) > WARN_LARGE_SHIFT:
            _warning(f"Y shift = {yoff:6.3f} exceeds warning limit of {WARN_LARGE_SHIFT} pixels")

        total = hypot(xoff, yoff)
        if total > WARN_LARGE_OFFSET:
            _warning(f"Total offset = {total:6.3f} exceeds warning limit of {WARN_LARGE_OFFSET} pixels")

    fname = os.path.basename(infile)

    num = parse_verbose()
    xx, yy = get_shifts()
    verb1(f"{fname:<35s} nsrc={num}\txoff={xx:6.3f}[pix]\tyoff={yy:6.3f}[pix]")
    warn_large_offsets(xx, yy)


def _make_unit_xform(outfile, refevt):
    """Write a transform file that describes "no change".

    The file is a TEXT/SIMPLE table with a unit matrix and zero
    translation.  The reference values are taken from the CRVAL, CRPIX
    and CDELT parameters of the ``eqpos`` transform in ``refevt``.
    """

    from pycrates import get_transform

    eqpos = get_transform(read_file(refevt), "eqpos")  # Chandra only, OK
    crpix, crval, cdelt = (eqpos.get_parameter_value(name)
                           for name in ("CRPIX", "CRVAL", "CDELT"))

    with open(outfile, "w", encoding="ascii") as fp:
        columns = ("# a11 a12 a21 a22 t1 t2 ra_ref dec_ref roll_ref "
                   "xpix_ref ypix_ref x_scale y_scale")
        values = (f"1.0 0.0 0.0 1.0 0.0 0.0 {crval[0]} {crval[1]} 0.0 "
                  f"{crpix[0]} {crpix[1]} {cdelt[0]} {cdelt[1]}")
        fp.write("\n".join(["#TEXT/SIMPLE", columns, values]) + "\n")


def cross_match(ref_srclist, refevt, srclists, obis, outroot):
    """Run wcs_match for each observation against the reference list.

    Each source list in ``srclists`` (paired with ``obis``) is matched
    to ``ref_srclist`` using a translation-only fit, writing
    ``{outroot}_{obsid}.xmatch``.  If wcs_match reports that it cannot
    find any matching source, a warning is logged and a unit transform
    is written instead; any other IOError is re-raised.

    Returns the list of transform files, in the order of ``obis``.
    """

    xmatches = []

    for oo, src in zip(obis, srclists):
        obi = str(oo.obsid)

        matcher = make_tool("wcs_match")
        matcher.infile = src
        matcher.refsrcfile = ref_srclist
        matcher.wcsfile = refevt
        matcher.outfile = f"{outroot}_{obi}.xmatch"
        matcher.radius = 4
        matcher.method = "trans"
        matcher.verbose = 1

        try:
            report = matcher(clobber=True)
        except IOError as failed:
            if 'Cannot find at least 1 source match' not in str(failed):
                raise
            _warning(f"Failed to identify matching sources for OBS_ID {obi}. Proceeding with no astrometric corrections applied.")
            _make_unit_xform(matcher.outfile, refevt)
            # Stand-in output so that the summary reports zero sources
            report = "After deleting poor matches, 0 sources remain\n"

        summarize_xmatch(report, src, matcher.outfile)
        xmatches.append(matcher.outfile)

    return xmatches


def copy_auxfiles(obsid, out_dirs):
    """Copy an observation's ancillary files next to the corrected files.

    The bad-pixel and mask files (plus the dtf file when the detector
    name starts with "HRC") are copied into
    ``out_dirs.root_dir/out_dirs.fine_astro``.  Each file is looked up
    under the name returned by ``obsid.get_ancillary`` and, failing
    that, under the same name with ".gz" appended.

    Raises IOError, naming the file as originally requested, when
    neither form exists.
    """

    verb2("Copying ancillary files.")
    aux_list = ["bpix", "mask"]
    if obsid.detector.startswith("HRC"):
        aux_list.append("dtf")

    dest_dir = os.path.join(out_dirs.root_dir, out_dirs.fine_astro)

    for aux in aux_list:
        aux_file = obsid.get_ancillary(aux)

        # Try the plain name first, then the gzip-compressed one
        for candidate in (aux_file, aux_file + ".gz"):
            if os.path.exists(candidate):
                break
        else:
            raise IOError(f"Cannot locate auxiliary file: {aux_file}")

        shutil.copyfile(candidate,
                        os.path.join(dest_dir, os.path.basename(candidate)))


def apply_fine_astro(obis, xmatches, refevt, out_dirs, pars):
    """Apply the fine astrometric corrections to event and asol files.

    For every observation in ``obis`` (paired with its transform file in
    ``xmatches``) the original event file is copied to
    ``<root_dir>/<fine_astro>/<root_base>_<obsid>_fa_evt.fits`` and
    updated with wcs_update, and each aspect solution file is written
    to the same directory with "asol1" in its name replaced by
    "fa_asol".  The ASOLFILE keyword of the new event file is set to the
    comma-separated base names of the new aspect solutions.

    Returns ``(updated_evts, updated_asol)``: the new event file names
    and, per observation, the comma-separated asol base names.

    The nested helpers take no arguments; they read ``in_evt``,
    ``out_evt``, ``xform``, ``oo`` and ``wcs_update`` from the loop at
    the bottom of this function.
    """
    from ciao_contrib.runtool import add_tool_history

    def _update_evt():
        'Copy the event file and update the copy with wcs_update'

        verb2(f"Applying fine astro correction to {in_evt}")
        dmcopy = make_tool("dmcopy")
        vv = dmcopy(in_evt, out_evt, clobber=True)
        verb2(vv)

        # The copy is the wcs_update input and outfile is left blank;
        # presumably this makes wcs_update modify the file in place -
        # confirm against the wcs_update documentation.
        wcs_update.infile = out_evt
        wcs_update.outfile = ""
        wcs_update.transformfile = xform
        wcs_update.wcsfile = refevt
        vv = wcs_update()
        verb2(vv)
        add_tool_history(out_evt, __toolname__,
                         pars, toolversion=__revision__)

    def _update_asol():
        'Write corrected aspect solutions and record them in the event file'
        asols = oo.get_asol()
        new_asol = []
        for asol in asols:
            verb2(f"Applying fine astro correction to {asol}")
            outfile = os.path.basename(asol).replace("asol1", "fa_asol")
            # Only infile and outfile are set here: transformfile and
            # wcsfile are still those assigned in _update_evt, so this
            # helper must run after _update_evt on the same tool object.
            wcs_update.infile = asol
            wcs_update.outfile = os.path.join(out_dirs.root_dir,
                                              out_dirs.fine_astro, outfile)
            vv = wcs_update(clobber=True)
            verb2(vv)
            add_tool_history(wcs_update.outfile, __toolname__,
                             pars, toolversion=__revision__)

            # Only the base name is recorded, not the full path
            new_asol.append(os.path.basename(wcs_update.outfile))

        new_asol = ",".join(new_asol)

        verb2(f"Updating ASOLFILE keyword in {out_evt}")
        dmhedit = make_tool("dmhedit")
        dmhedit.infile = out_evt
        dmhedit.filelist = ""
        dmhedit.operation = "add"
        dmhedit.key = "ASOLFILE"
        dmhedit.value = new_asol
        vv = dmhedit()
        verb2(vv)
        return new_asol

    # Common prefix of every corrected event file name
    out_root = os.path.join(out_dirs.root_dir, out_dirs.fine_astro,
                            out_dirs.root_base)

    updated_evts = []
    updated_asol = []
    for oo, xform in zip(obis, xmatches):
        obi = str(oo.obsid)

        # I want the original event file, might have used smaller
        # event file for detect
        in_evt = oo.get_evtfile()
        out_evt = f"{out_root}_{obi}_fa_evt.fits"
        # A fresh tool object per observation, shared by both helpers
        wcs_update = make_tool("wcs_update")

        _update_evt()
        new_asol = _update_asol()
        updated_evts.append(out_evt)
        updated_asol.append(new_asol)

    return updated_evts, updated_asol


def detect_sources(obis, pars, out_dirs):
    """Return one source list per observation, ordered like ``obis``.

    When the ``src_lists`` parameter is blank ("", "none" or "indef", in
    any case) the observations are merged with merge_obs and wavdetect
    is run on each observation's products.  Otherwise ``src_lists`` is
    expanded as a stack, sorted by the string form of each file's
    obsid, and every file is validated with check_srclist.

    Raises ValueError when ACIS and HRC observations are mixed, or when
    the number of user-supplied source lists differs from the number of
    observations.
    """

    def set_nproc():
        'Number of processes to hand to parallel_map'

        if "no" == pars["parallel"]:
            return 1

        if pars["nproc"] == "INDEF":
            # parallel_map doesn't limit to number of CPU's like
            # taskRunner does so we have to set the actual number
            import multiprocessing
            return multiprocessing.cpu_count()

        return int(pars["nproc"])

    if all(oo.instrument == 'ACIS' for oo in obis):
        binsize = 1
    elif all(oo.instrument == 'HRC' for oo in obis):
        binsize = 2
    else:
        raise ValueError("Cannot merge HRC and ACIS observations")

    # Same set of "blank" values that check_pars uses for src_lists
    if pars["src_lists"].lower() in ("", "none", "indef"):

        verb1("Running pre-detect merge_obs")
        evtlist = [f"{oo.get_evtfile()}{pars['det_filter']}" for oo in obis]

        merge_dir = os.path.join(out_dirs.root_dir, out_dirs.pre_detect)
        if merge_dir:
            os.makedirs(merge_dir, exist_ok=True)

        merge_root = os.path.join(merge_dir, out_dirs.root_base)
        pre_merge_root = run_merge_obs(evtlist, merge_root, pars, binsize=binsize)

        verb1("Running wavdetect")
        det_out = os.path.join(out_dirs.root_dir, out_dirs.detect)
        if det_out:
            os.makedirs(det_out, exist_ok=True)

        # One argument tuple per observation, in the layout that
        # run_wavdetect unpacks
        det_root = os.path.join(det_out, out_dirs.root_base)
        det_pars = [(pre_merge_root, det_root, oo, pars["det_scales"],
                     pars["src_filter"], pars["tmpdir"])
                    for oo in obis]

        from sherpa.utils import parallel_map
        src_lists = parallel_map(run_wavdetect, det_pars,
                                 numcores=set_nproc())

    else:   # pars["src_lists"] is not blank/none
        import stk
        src_lists = stk.build(pars["src_lists"])

        if len(src_lists) != len(obis):
            raise ValueError("Mismatched number of observations and source lists.")

        # We have sorted the infile by obsid so that means we need to also
        # sort the input src list or else they may not match any more.
        from ciao_contrib._tools.fileio import get_obsid_object
        src_list_obi = [(str(get_obsid_object(src).obsid), src) for src in src_lists]
        src_lists = [x[1] for x in sorted(src_list_obi)]

        for src, oo in zip(src_lists, obis):
            check_srclist(src, oo.obsid.obsid)

    return src_lists


def summarize_astro(obis, updated_evts, updated_asol):
    """Log the event and aspect solution files that were created.

    ``updated_asol`` holds one comma-separated string of file names per
    observation; each name is logged on its own line.  ``obis`` is
    accepted but not used.
    """

    verb1("\nThe following event files were created:")
    for evt in updated_evts:
        verb1(f"{evt}")

    verb1("\nThe following aspect solution files were created:")
    asol_names = (name for group in updated_asol
                  for name in group.split(","))
    for name in asol_names:
        verb1(name)

    verb1("\n")


def sort_obis(obis):
    """Return the observations sorted by the string form of their obsid.

    The ordering is a plain string comparison of ``str(oo.obsid)``.  A
    key function is used so that observations with equal keys keep
    their input order instead of being compared with each other.
    """
    return sorted(obis, key=lambda oo: str(oo.obsid))


def bracketize(filt_str):
    """Return the filter wrapped in square brackets.

    An empty string is returned unchanged; otherwise a leading '[' and a
    trailing ']' are added only where they are missing.
    """
    if filt_str == "":
        return filt_str

    opener = "" if filt_str.startswith('[') else '['
    closer = "" if filt_str.endswith(']') else ']'
    return opener + filt_str + closer


def check_pars(pars):
    """Normalise and validate the tool parameters in place.

    The ``det_filter`` and ``src_filter`` values are wrapped in
    brackets.  A non-blank ``ref_src_list`` is checked with
    check_srclist, and supplying ``src_lists`` without a
    ``ref_src_list`` raises ValueError.
    """
    blank = ("", "none", "indef")

    for name in ('det_filter', 'src_filter'):
        pars[name] = bracketize(pars[name])

    have_ref = pars["ref_src_list"].lower() not in blank
    if have_ref:
        check_srclist(pars["ref_src_list"])

    have_srcs = pars["src_lists"].lower() not in blank
    if have_srcs and not have_ref:
        raise ValueError("User must specify ref_src_list if supplying src_lists")


@lw.handle_ciao_errors(__toolname__, __revision__)
def main():
    """Run the tool: detect, cross match, correct, and optionally merge.

    Steps:
      1. load and check the parameters, and split ``outroot`` into its
         directory and file-name parts;
      2. validate the input observations and sort them by obsid;
      3. obtain a source list per observation (detect_sources);
      4. cross match each list against the reference list;
      5. apply the resulting transforms to the event and asol files;
      6. depending on ``stop``, return or run a final merge_obs.

    Raises IOError for an unusable ``outroot``, ValueError when there
    are no valid event files or ``stop`` is not recognised, and
    NotImplementedError for ``stop=reproject``.
    """

    # Load parameters
    from ciao_contrib.param_soaker import get_params
    pars = get_params(__toolname__, "rw", sys.argv,
                      verbose={"set": lw.set_verbosity, "cmd": verb1},
                      revision=__revision__)

    check_pars(pars)

    # ~ from ciao_contrib._tools.fileio import outfile_clobber_checks
    # ~ outfile_clobber_checks(pars["clobber"], pars["outfile"])

    if pars["outroot"] == "":
        raise IOError("ERROR: outroot is blank")

    out_dirs = OutputDirs()
    out_dirs.root_dir = os.path.dirname(pars["outroot"])
    out_dirs.root_base = os.path.basename(pars["outroot"])
    if out_dirs.root_base == "":
        raise IOError("ERROR: outroot must have an output root file name with an optional path name.")

    from ciao_contrib._tools import merging
    obis = merging.validate_obsinfo(pars["infile"])

    # Check before sorting: sort_obis cannot handle None
    if obis is None or len(obis) == 0:
        raise ValueError("No valid event files")

    obis = sort_obis(obis)

    # ------------ Run wavdetect ----------------------

    verb1("Gathering source lists")
    src_lists = detect_sources(obis, pars, out_dirs)

    # ------------ Cross match -------------
    verb1("Running cross matches using wcs_match")

    refevt = obis[0].get_evtfile()  # Arbitrary, doesn't matter
    ref_srclist = get_ref_srclist(obis, pars["ref_src_list"], src_lists)

    xmatch_out = os.path.join(out_dirs.root_dir, out_dirs.xmatch)
    if xmatch_out:
        os.makedirs(xmatch_out, exist_ok=True)

    xmatches = cross_match(ref_srclist, refevt, src_lists, obis,
                           os.path.join(xmatch_out, out_dirs.root_base))

    # ------------ Update ----------
    verb1("Updating astrometry")
    # With the default fine_astro of "" and no directory in outroot this
    # is empty, in which case there is nothing to create.
    updated_out = os.path.join(out_dirs.root_dir, out_dirs.fine_astro)
    if updated_out:
        os.makedirs(updated_out, exist_ok=True)

    updated_evts, updated_asol = apply_fine_astro(obis, xmatches, refevt, out_dirs, pars)
    summarize_astro(obis, updated_evts, updated_asol)

    if pars["stop"] == "fineastro":
        return

    if pars["stop"] == "mergeobs":
        # ------------ Reproject -----------
        verb1("Running final merge_obs")
        verb2("Copying auxiliary files so they can be used by merge_obs")
        for oo in obis:
            copy_auxfiles(oo, out_dirs)

        merge_root = os.path.join(out_dirs.root_dir, out_dirs.final,
                                  out_dirs.root_base)
        merge_out = os.path.dirname(merge_root)
        if merge_out:
            os.makedirs(merge_out, exist_ok=True)
        run_merge_obs(updated_evts, merge_root, pars)

    elif pars["stop"] == "reproject":
        raise NotImplementedError("Rerproject TBD")
    else:
        raise ValueError("Bad value for stop parameter")


# Run the tool only when this file is executed as a script.
if __name__ == "__main__":
    main()
