#!/usr/bin/env python

#
# Copyright (C) 2012, 2013, 2014, 2015, 2016, 2017, 2019, 2020, 2021, 2023
#           Smithsonian Astrophysical Observatory
#
#
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
#

r"""
Reproject observations to a common tangent point and (optionally)
merge the event files.

So,

  merge_obs @in.lis out/

can be broken down into

  reproject_obs @in.lis out/
  flux_obs out/\*_reproj_evt.fits out/

"""

import os
import sys
import shutil
import tempfile

import paramio

import ciao_contrib.logger_wrapper as lw

from ciao_contrib.param_wrapper import open_param_file
from ciao_contrib.runtool import add_tool_history

import ciao_contrib._tools.fileio as fileio
import ciao_contrib._tools.merging as merging
import ciao_contrib._tools.run as run
import ciao_contrib._tools.utils as utils

from ciao_contrib._tools.taskrunner import TaskRunner

# Name of the tool: used to open the parameter file, to label the logger,
# and when recording the tool history in the output files.
toolname = 'reproject_obs'
# Version string reported to the user and stored in the tool history.
__revision__ = '20 October 2023'

# Logger for this tool, with short aliases for the verbose-level specific
# output routines used in the rest of the file.
lgr = lw.initialize_logger(toolname)
v1 = lgr.verbose1
v3 = lgr.verbose3
v4 = lgr.verbose4


def get_par(argv):
    """Get parameters from parameter file.

    argv is the argument list handed to open_param_file. The return
    value is the tuple (pars, params) of dictionaries: pars is keyed
    by parameter name (it is what reproject_obsids passes on to
    add_tool_history) and params contains the processed values used
    by the rest of the tool (obsinfos, instrument, warn_msgs, outdir,
    outhead, evtmerge, refcoord, parallel, nproc, linkfiles, tmpdir,
    clobber, and verbose).

    As a side effect the verbosity of the logger is set and the tool
    version is printed.

    A ValueError is raised if infiles is blank, outroot is the empty
    string, or asolfiles is "none" (compared case-insensitively).
    The parameter file is only closed when all the checks pass.
    """

    pfile = open_param_file(argv, toolname=toolname)["fp"]

    # Common parameters:
    params = {}
    pars = {}

    # set verbosity level (intentionally breaking the order of the
    # parameters since we need the verbose level fairly early, and
    # any user who has decided to change verbose to an automatic parameter
    # can handle this break in UI behavior).
    #
    pars["verbose"] = paramio.pgeti(pfile, "verbose")
    params["verbose"] = pars["verbose"]

    lw.set_verbosity(pars["verbose"])
    utils.print_version(toolname, __revision__)

    # files with wcs_information to reproject
    pars['infiles'] = paramio.pgetstr(pfile, 'infiles')
    if pars['infiles'].strip() == '':
        raise ValueError("The infiles parameter is empty.")

    # The instrument is taken from the first observation only.
    obsinfos = merging.validate_obsinfo(pars['infiles'], colcheck=False)
    params['obsinfos'] = obsinfos
    params['instrument'] = obsinfos[0].instrument

    # Any other warning messages
    params['warn_msgs'] = merging.obsinfo_checks(obsinfos)

    # output root
    pars['outroot'] = paramio.pgetstr(pfile, 'outroot')
    if pars["outroot"] == "":
        raise ValueError("The outroot parameter can not be empty.")

    # The output directory is only validated when the root includes one.
    (params["outdir"], params["outhead"]) = utils.split_outroot(pars["outroot"])
    if params["outdir"] != "":
        fileio.validate_outdir(params["outdir"])

    # aspect solution files
    pars['asolfiles'] = paramio.pgetstr(pfile, 'asolfiles')
    if pars['asolfiles'].lower() == "none":
        raise ValueError("The asolfiles argument can not be set to " +
                         f"{pars['asolfiles']}")

    merging.setup_obsid_asol(obsinfos, pars['asolfiles'])

    # Boolean parameters are converted to Python booleans.
    pars['merge'] = paramio.pgetb(pfile, 'merge') == 1
    params['evtmerge'] = pars['merge']

    pars['refcoord'] = paramio.pgetstr(pfile, 'refcoord')
    params['refcoord'] = pars['refcoord']

    # parallelize?
    pars['parallel'] = paramio.pgetb(pfile, 'parallel') == 1
    params['parallel'] = pars['parallel']
    if pars['parallel']:
        # nproc is read as a string first so that INDEF can be recognized
        # (it is mapped to None); otherwise it is re-read as an integer.
        pars['nproc'] = paramio.pgetstr(pfile, 'nproc')
        if pars['nproc'] == 'INDEF':
            params['nproc'] = None
        else:
            params['nproc'] = paramio.pgeti(pfile, 'nproc')
    else:
        # nproc is not queried when running serially; a single process
        # is used.
        pars['nproc'] = 'INDEF'
        params['nproc'] = 1

    pars['linkfiles'] = paramio.pgetb(pfile, 'linkfiles') == 1
    params['linkfiles'] = pars['linkfiles']

    pars['tmpdir'] = paramio.pgetstr(pfile, 'tmpdir')
    params["tmpdir"] = utils.process_tmpdir(pars['tmpdir'])

    # pars stores the string form of clobber and params the boolean.
    pars['clobber'] = paramio.pgetstr(pfile, 'clobber')
    params['clobber'] = paramio.pgetb(pfile, 'clobber') == 1

    paramio.paramclose(pfile)

    return (pars, params)


def link_or_copy_file(infile, outfile, linkfile=True):
    """Make infile available as outfile, by soft link or by copy.

    The names are tried as given and then, if infile does not exist,
    with '.gz' appended to both infile and outfile (so neither
    argument should include the .gz suffix). When linkfile is True a
    soft link is attempted, falling back to a copy if the link can
    not be made; when it is False the file is copied.

    An IOError is raised if neither infile nor infile + '.gz' exists.
    """

    # The .gz version is only tried when the plain name was not found.
    handled = (_link_or_copy_file(infile, outfile, linkfile=linkfile) or
               _link_or_copy_file(infile + '.gz', outfile + '.gz',
                                  linkfile=linkfile))
    if not handled:
        raise IOError(f"Input file {infile} not found (also tried .gz)")


def _link_or_copy_file(infile, outfile, linkfile=True):
    """Link or copy infile to outfile. See link_or_copy_file.

    Returns False if infile does not exist, otherwise True (which
    includes the case where outfile already exists, when nothing is
    done).
    """

    if not os.path.exists(infile):
        return False

    if os.path.exists(outfile):
        v3(f"File already exists at {outfile}")
        return True

    # Both names are anchored to the current working directory
    # (os.path.join leaves an absolute name unchanged).
    here = os.getcwd()
    source = os.path.join(here, infile)
    target = os.path.join(here, outfile)

    if not linkfile:
        v3(f"Copying {infile} to {outfile}")
        shutil.copyfile(source, target)
        return True

    v3(f"Soft linking {outfile} to {infile}")
    try:
        os.symlink(source, target)
    except Exception:
        # Best effort: any failure to link falls back to a copy.
        v3(f"Soft link failed; copying file to {target}")
        shutil.copyfile(source, target)

    return True


def dmhedit_line(key, value):
    """Return a string that can be used as line in a
    file sent to dmhedit to change key to value.

    The line has the form "<key> = <value>" followed by a new-line
    character, where the value is formatted as follows:

      - booleans are written as T (True) or F (False);
      - strings containing a '/' are enclosed in single quotes,
        other strings are written unchanged;
      - anything else is converted with str().
    """

    # The bool check must come first and must test the value itself
    # (testing the bool type is always true, which turned False into T).
    if isinstance(value, bool):
        text = "T" if value else "F"
    elif isinstance(value, str):
        text = f"'{value}'" if '/' in value else value
    else:
        text = str(value)

    return f"{key} = {text}\n"


# QUS: do we not have this routine elsewhere?
def update_ancillary_header(infile, adict, tmpdir="/tmp/"):
    """Set the ancillary-file header keywords of infile.

    adict is a dictionary whose keys are expected to be asol, bpix,
    mask, and dtf (at least one of these should exist); for each
    key the keyword <key>FILE is set to the corresponding value.
    Nothing is done if adict is empty.

    The edits are written to a temporary file, created in tmpdir,
    which is passed to run.dmhedit_file and removed afterwards.
    """

    if len(adict) == 0:
        return

    v3(f"Updating ancillary file keywords of {infile}")
    with tempfile.NamedTemporaryFile(dir=tmpdir, mode='w+',
                                     suffix=".dmhedit") as editfile:
        editfile.write("#add\n")
        for (name, value) in adict.items():
            editfile.write(dmhedit_line(f"{name}FILE", value))

        # Make sure the contents are on disk before the file is read.
        editfile.flush()
        run.dmhedit_file(infile, editfile.name)


def link_ancillary_files(taskrunner,
                         labelconv,
                         preconditions,
                         instrume,
                         obsinfos,
                         reprofiles,
                         outdir,
                         outhead,
                         linkfiles=True,
                         tmpdir="/tmp/"):
    """Create a soft link to (or copy of) the ancillary files
    to <outroot><obsid>.<type>.

    Since interleaved-mode observations have different ancillary
    files for some, we include the cycle name in the output
    for BPIX and MASK files.

    Multi-obi observations have different ASOL, BPIX, and MASK files,
    when multiple OBIs are being used for the ObsId.

    Multiple asol files for an obsid are distinguished by a,b,...
    appended to the obsid. Type values are asol, bpix, mask, and dtf.

    The headers for the reprojected files are updated to point to
    these files (the base name of each file is stored).

    The return value is the name of the barrier indicating these tasks have
    finished.

    If an ancillary file can not be found then it is not copied/linked to,
    nor is the header updated. A warning is displayed to the user.

    Support for PBK files was removed in CIAO 4.6 and for multi-obi
    observations was added during CIAO 4.7. Creating of FOV files
    was added in CIAO 4.14.

    The arguments are:

      taskrunner    - object the tasks and barriers are added to
                      (nothing is run by this routine)
      labelconv     - function converting a task name to the label
                      given to the task runner
      preconditions - the tasks that must finish before any linking
      instrume      - the instrument name; DTF files are only handled
                      when this is 'hrc' (case insensitive)
      obsinfos      - the observations to process
      reprofiles    - the reprojected event files, matching obsinfos
      outdir, outhead - combined to form the output root
      linkfiles     - if False the files are copied rather than linked
      tmpdir        - directory used when updating the headers

    An IOError is raised if an observation has no aspect solution,
    more than 26 aspect solutions, or a file type is given multiple
    times for an observation.
    """

    outroot = outdir + outhead

    stask = labelconv("link-ancillary-start")
    taskrunner.add_barrier(stask, preconditions)

    def get_key_label(obsid):
        """Useful when we care about ObsId and maybe OBI_NUM
        but not CYCLE.
        """

        # I would like to just set
        #    label = str(obsid)
        # but although that works for multi-obi cases, which
        # is what we care about here, it does not for interleaved
        # data, as label would end in "e1" or "e2". So for now
        # we recreate some - but not all - of the __str__ method
        # of ciao_contrib._tools.utils.Obsid.
        #
        key = obsid.obsid
        if obsid.is_multi_obi:
            label = f"{obsid.obsid}_{obsid.obi:03d}"
        else:
            label = f"{key}"

        return (key, label)

    def note_missing(obsid, filetype):
        "Record that there is no filetype file for obsid."
        missing.setdefault(filetype, []).append(obsid)

    # Aspect solution files; complicated by support for interleaved mode
    # and multi-obi (which behave differently; interleaved mode
    # uses the same asol file for different observations, but OBI
    # uses different asol files; however, only need to worry about
    # naming the asol files differently if multiple OBIs for the same
    # ObsId are present).
    #
    # Should interleaved mode just have separate copies of the asol files,
    # even though they are the same?
    #
    # NOTE(review): unlike the other tasks, the link-asol task names are
    # not passed through labelconv - confirm whether this is intended.
    #
    tasks = []
    store = {}
    asolcache = {}   # label -> output name(s) with outroot (for make_fov)
    asolheader = {}  # label -> base name(s), the value for the header
    asoltasks = {}   # label -> names of the tasks linking the asol files
    for obs in obsinfos:

        # The key we use ignores the cycle but may include the OBI
        obsid = obs.obsid
        (key, label) = get_key_label(obsid)

        if label not in asolcache:
            afiles = obs.get_asol()
            nfiles = len(afiles)
            if nfiles == 0:
                raise IOError(f"*Internal error* ObsId {obsid} has no aspect solution.")

            if nfiles > 26:
                # SHOULD NOT HAPPEN. Only the suffixes a to z are
                # available, and skipping the observation would just
                # lead to a KeyError later on, so error out here.
                #
                raise IOError("Too many aspect solutions " +
                              f"({nfiles}) for obsid {obsid}")

            if nfiles == 1:
                newname = f"{outroot}{label}.asol"
                taskname = f"link-asol-{label}"
                taskrunner.add_task(taskname, [stask],
                                    link_or_copy_file,
                                    afiles[0],
                                    newname,
                                    linkfile=linkfiles)
                tasks.append(taskname)
                asoltasks[label] = [taskname]
                asolcache[label] = newname
                asolheader[label] = os.path.basename(newname)

            else:
                newnames = []
                asoltasks[label] = []
                for (idx, afile) in enumerate(afiles):
                    # a, b, c, ...
                    idval = chr(97 + idx)
                    newname = f"{outroot}{label}{idval}.asol"
                    newnames.append(newname)
                    taskname = f"link-asol-{label}-{idval}"
                    taskrunner.add_task(taskname, [stask],
                                        link_or_copy_file,
                                        afile,
                                        newname,
                                        linkfile=linkfiles)
                    tasks.append(taskname)
                    asoltasks[label].append(taskname)

                asolcache[label] = ','.join(newnames)
                asolheader[label] = ','.join([os.path.basename(n)
                                              for n in newnames])

        # this stores by the tuple (OBS_ID, Cycle, OBI_NUM)
        #
        # Always use the base name, whether or not the label has been
        # seen before, so that all the headers are consistent.
        store[obsid] = {'asol': asolheader[label]}

    if instrume.lower() == 'hrc':
        itype = 'dtf'
    else:
        itype = None

    # BPIX and MASK depend on the CYCLE and OBI (if in use);
    # HRC/DTF do not depend on CYCLE (it's an ACIS-only mode)
    # but can on OBI.
    #
    instcache = {}
    missing = {}

    def link_file_inst(obsid, filetype, filename):
        "This is now only used for HRC/DTF files"

        if filename is None:
            note_missing(obsid, filetype)
            return

        if filetype in store[obsid]:
            raise IOError("*Internal error* multiple " +
                          f"filetype={filetype} for obsid={obsid}")

        (key, label) = get_key_label(obsid)
        try:
            newname = instcache[label]

        except KeyError:
            newname = f"{outroot}{label}.{filetype}"
            taskname = labelconv(f"link-{filetype}-{label}")
            taskrunner.add_task(taskname, [stask],
                                link_or_copy_file,
                                filename,
                                newname,
                                linkfile=linkfiles)
            tasks.append(taskname)
            newname = os.path.basename(newname)
            instcache[label] = newname

        store[obsid][filetype] = os.path.basename(newname)

    def link_file(obsid, filetype, filename):
        """Add the task to link filename for this observation.

        Returns None if filename is None (the file is recorded as
        missing), otherwise the tuple (task name, output name).
        """

        if filename is None:
            note_missing(obsid, filetype)
            return None

        if filetype in store[obsid]:
            raise IOError("*Internal error* multiple " +
                          f"filetype={filetype} for obsid={obsid}")

        newname = f"{outroot}{obsid}.{filetype}"
        taskname = labelconv(f"link-{filetype}-{obsid}")
        taskrunner.add_task(taskname, [stask],
                            link_or_copy_file,
                            filename,
                            newname,
                            linkfile=linkfiles)
        tasks.append(taskname)
        store[obsid][filetype] = os.path.basename(newname)
        return (taskname, newname)

    fovtasks = []
    fovfiles = []
    for obs in obsinfos:
        obsid = obs.obsid
        link_file(obsid, 'bpix', obs.get_ancillary_('bpix'))

        # Check explicitly for a missing mask file rather than relying
        # on a TypeError from unpacking None, as that would also hide
        # any TypeError raised while adding the task.
        maskinfo = link_file(obsid, 'mask', obs.get_ancillary_('mask'))
        if maskinfo is None:
            (mtask, mfile) = (None, None)
        else:
            (mtask, mfile) = maskinfo

        if itype is not None:
            link_file_inst(obsid, itype, obs.get_ancillary_(itype))

        # make the FOV file
        #
        (key, label) = get_key_label(obsid)
        taskname = labelconv(f"fov-{obsid}")

        # Copy the list, as the label (and hence the list) can be shared
        # by several observations and it must not pick up the mask
        # tasks of the other observations.
        pretasks = list(asoltasks[label])
        if mfile is not None:
            pretasks.append(mtask)

        fovname = f"{outroot}{obsid}.fov"
        evtfile = f"{outroot}{obsid}_reproj_evt.fits"  # assume this is correct
        taskrunner.add_task(taskname, pretasks, run.make_fov,
                            evtfile, asolcache[label], mfile, fovname)
        tasks.append(taskname)
        fovtasks.append(taskname)
        fovfiles.append(fovname)

    # Create the combined FOV file
    #
    taskname = labelconv("combined-fov")
    outfov = f"{outroot}merged.fov"
    taskrunner.add_task(taskname, fovtasks, run.combined_fovs,
                        fovfiles, outfov)
    tasks.append(taskname)

    etask = labelconv("link-ancillary-end")
    taskrunner.add_barrier(etask, tasks)

    # Let the user know about missing files
    nmiss = 0
    for (filetype, mobsids) in missing.items():
        nmiss += len(mobsids)
        if len(mobsids) == 1:
            v1(f"Unable to find {filetype} file for ObsId " +
               f"{mobsids[0]}.")
        else:
            v1(f"Unable to find {filetype} file for ObsIds " +
               "{}.".format(" ".join([str(mobsid) for mobsid in mobsids])))

    if nmiss > 0:
        if nmiss == 1:
            mlabel = "file"
        else:
            mlabel = "files"

        v1("")
        v1(f"NOTE: The missing ancillary {mlabel} may affect the " +
           "reliability of further")
        v1("      analysis, such as the output of flux_obs.\n")

    # Now to update the headers of the reprojected event files; these
    # tasks wait for all the link and FOV tasks to finish.
    tasks = []
    for (obs, reprofile) in zip(obsinfos, reprofiles):
        obsid = obs.obsid
        try:
            adict = store[obsid]
        except KeyError:
            # At a minimum expect to update the asol header
            raise IOError("*Internal error* no header updates for " +
                          f"ObsId {obsid} ({reprofile})")

        taskname = labelconv(f"update-header-{obsid}")
        taskrunner.add_task(taskname, [etask],
                            update_ancillary_header, reprofile, adict,
                            tmpdir=tmpdir)
        tasks.append(taskname)

    etask = labelconv("update-header-end")
    taskrunner.add_barrier(etask, tasks)
    return etask


def get_ref(taskrunner, labelconv,
            instrume,
            obsinfos,
            refcoord,
            outdir, outhead,
            linkfiles=True,
            tmpdir="/tmp/",
            verbose=0,
            parallel=True):
    """Set up the reprojection of the event files to a common
    tangent point, followed by the linking of the ancillary files.

    The tasks are added to taskrunner but are not run here. The
    reprojection is always set up with clobber=True.

    The return is a tuple:
      array of the output (reprojected or copied) event files,
      name of the barrier task that indicates the reprojections and
        the ancillary-file handling have finished
      warnings array from merging.list_observations
    """

    # Note: only ra/dec are used; the first element is no longer needed
    refpos = merging.process_reference_position(refcoord, obsinfos)
    (_match, ra, dec) = refpos
    warnings = merging.list_observations(instrume, ra, dec, obsinfos)

    evt_in = []
    evt_out = []
    asol_in = []  # collected but not used below
    for obsinfo in obsinfos:
        evt_in.append(obsinfo.get_evtfile())
        evt_out.append(merging.obsid_reproj_evt_name(outdir, outhead,
                                                     obsinfo.obsid))
        asol_in.append(obsinfo.get_asol())

    # The reprojection has no preconditions.
    reproj_opts = {'clobber': True,
                   'verbose': verbose,
                   'tmpdir': tmpdir,
                   'parallel': parallel}
    reproj_task = merging.reproject_event_files(taskrunner, labelconv, [],
                                                evt_in, evt_out, ra, dec,
                                                **reproj_opts)

    # The ancillary files are handled once the reprojection is done.
    link_task = link_ancillary_files(taskrunner, labelconv, [reproj_task],
                                     instrume, obsinfos, evt_out,
                                     outdir, outhead,
                                     linkfiles=linkfiles,
                                     tmpdir=tmpdir)

    return (evt_out, link_task, warnings)


def file_checks(obsinfos,
                outdir, outhead,
                clobber=False):
    """Run the clobber checks on the output event files.

    For each observation the name of its reprojected event file is
    passed, along with clobber, to fileio.outfile_clobber_checks.
    """

    for obsinfo in obsinfos:
        fileio.outfile_clobber_checks(
            clobber,
            merging.obsid_reproj_evt_name(outdir, outhead, obsinfo.obsid))


def merge_evt_files(taskrunner, preconditions,
                    infiles,
                    outfile,
                    obsinfos,
                    tmpdir="/tmp/",
                    clobber=False,
                    verbose=0):
    """Add the task to merge the reprojected event files.

    The task, which runs merging.merge_event_files once the
    preconditions have finished, is named after the output file and
    this name is returned. Nothing is run here.
    """

    merge_opts = {'obsinfos': obsinfos,
                  'colfilter': True,
                  'lookupTab': run.get_lookup_table('obsidmerge',
                                                    pathfrom=__file__),
                  'tmpdir': tmpdir,
                  'clobber': clobber,
                  'verbose': verbose}

    taskrunner.add_task(outfile, preconditions,
                        merging.merge_event_files,
                        infiles, outfile,
                        **merge_opts)
    return outfile


@lw.handle_ciao_errors(toolname, __revision__)
def reproject_obsids():
    """Run the tool.

    Read the parameters, check the output files, set up and run the
    reprojection, ancillary-file, and (optional) merge tasks, add
    the history records, and then report the files created.
    """

    (pars, params) = get_par(lw.preprocess_arglist(sys.argv))

    # The CIAO tools are run at one level less than the verbose
    # setting of the script (levels 0 and 5 and above are unchanged).
    if 0 < params['verbose'] < 5:
        params['verbose'] -= 1

    obsinfos = params['obsinfos']
    outdir = params['outdir']
    outhead = params['outhead']
    tmpdir = params['tmpdir']
    verbose = params['verbose']
    clobber = params['clobber']

    file_checks(obsinfos, outdir, outhead, clobber=clobber)

    # Set up all the tasks before running any of them.
    taskrunner = TaskRunner()
    (ref_files, rtask, warnings) = get_ref(taskrunner, lambda s: s,
                                           params['instrument'],
                                           obsinfos,
                                           params['refcoord'],
                                           outdir,
                                           outhead,
                                           linkfiles=params['linkfiles'],
                                           tmpdir=tmpdir,
                                           verbose=verbose,
                                           parallel=params['parallel'])

    if params['evtmerge']:
        mevtfile = merging.merged_evt_name(outdir, outhead)
        merge_evt_files(taskrunner, [rtask], ref_files, mevtfile, obsinfos,
                        tmpdir=tmpdir,
                        clobber=clobber,
                        verbose=verbose)

    taskrunner.run_tasks(processes=params['nproc'])

    # Record the tool history in the output files.
    history_files = list(ref_files)
    if params['evtmerge']:
        history_files.append(mevtfile)

    nobs = len(obsinfos)
    merged_fov = f"{outdir}{outhead}merged.fov"

    # This should be picked up automatically
    if nobs > 1:
        history_files.append(merged_fov)

    for history_file in history_files:
        add_tool_history(history_file, toolname, pars,
                         toolversion=__revision__)

    # Tell the user what has been created.
    spacer = '     '
    v1("\nThe following files were created:\n")

    v1("The reprojected FOV file:" if nobs == 1
       else "The reprojected FOV files:")
    for obsinfo in obsinfos:
        v1(f"{spacer}{outdir}{outhead}{obsinfo.obsid}.fov")

    v1("")
    if nobs > 1:
        v1("The combined FOV file:")
        v1(f"{spacer}{merged_fov}")
        v1("")

    v1("The reprojected event file:" if len(ref_files) == 1
       else "The reprojected event files:")
    v1("{}{}\n".format(spacer, ('\n' + spacer).join(ref_files)))

    if params['evtmerge']:
        v1("The merged event file:")
        v1(f"{spacer}{mevtfile}\n")
        merging.display_merging_warnings(warnings,
                                         mevtfile,
                                         params['obsinfos'])

    # Finish with any warnings found when the inputs were checked.
    for warn_msg in params['warn_msgs']:
        v1(warn_msg)


# Only run the tool when this file is executed as a script.
if __name__ == "__main__":
    reproject_obsids()
