#!/usr/bin/env python

#
# Copyright (C) 2012, 2013, 2014, 2015, 2016, 2018, 2020, 2021
# Smithsonian Astrophysical Observatory
#
#
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
# 02110-1301 USA.
#

"""
Reproject observations to a common tangent point, merge the
files, and produce a fluxed image of the whole data set.
"""

import os
import sys

import paramio
import pycrates

import ciao_contrib.logger_wrapper as lw

from ciao_contrib.param_wrapper import open_param_file

import ciao_contrib._tools.bands as bands
import ciao_contrib._tools.fileio as fileio
import ciao_contrib._tools.fluximage as fi
import ciao_contrib._tools.merging as merging
import ciao_contrib._tools.obsinfo as obsinfo
import ciao_contrib._tools.run as run
import ciao_contrib._tools.utils as utils

from ciao_contrib._tools.aspsol import AspectSolution
from ciao_contrib._tools.taskrunner import TaskRunner

# Name of the tool; used below to set up the logger, open the parameter
# file, and label the version and error output.
toolname = 'merge_obs'
__revision__ = '05 November 2021'

# Set up the module-level logger. v1 and v3 are short cuts to the
# logger's verbose1 and verbose3 methods and are used for all screen
# output in this script.
lw.initialize_logger(toolname)
lgr = lw.get_logger(toolname)
v1 = lgr.verbose1
v3 = lgr.verbose3

# Set by handle_background_override to a dictionary mapping event file
# name to HRC-I background file when --backgroundmap=filename is given
# on the command line; it is left as None otherwise.
_background_override = None


def handle_background_override(argv):
    """Strip --backgroundmap=filename from argv and record the mapping.

    Every argument that starts with --backgroundmap= is removed from
    the argument list; when the option is repeated the last file name
    given is the one that is used. An argument that is just the marker,
    with no file name after it, is an error.

    The background map file should be an ascii file with two columns;
    the first is the event file name (no path) and the second the
    full path to the HRC-I background file to use. When a map file is
    given the module-level _background_override is set to a dictionary
    from the first column to the absolute version of the second, with
    rows skipped (and a warning displayed) when the background file can
    not be read.

    Parameters
    ----------
    argv : list of str
        The command-line arguments.

    Returns
    -------
    list of str
        The elements of argv that are not a --backgroundmap= option,
        in their original order.

    Raises
    ------
    IOError
        If the option has no file name or the map file does not
        contain two columns.
    """

    global _background_override

    marker = "--backgroundmap="
    remaining = []
    chosen = None
    for arg in argv:
        if not arg.startswith(marker):
            remaining.append(arg)
            continue

        name = arg[len(marker):]
        if name == "":
            raise IOError(f"The {marker} option must be followed by a "
                          "file name!")

        # later options replace earlier ones
        chosen = name

    if chosen is None:
        return remaining

    # Only the "winning" file is ever read in and checked.
    table = pycrates.read_file(chosen)
    ncols = table.get_ncols()
    if ncols != 2:
        raise IOError(f"Expected 2 columns in {chosen}, found {ncols}")

    # Convert the column values to plain Python strings to avoid later
    # oddities with the types returned by crates.
    evtnames = [str(val) for val in table.get_column(0).values.copy()]
    bgnames = [str(val) for val in table.get_column(1).values.copy()]
    table = None

    overrides = {}
    _background_override = overrides
    for (evtname, bgname) in zip(evtnames, bgnames):
        v3(f"Storing map from {evtname} to {bgname}")

        # Make the file name absolute but retain any DM filter.
        (dmname, dmfilter) = fileio.get_file_filter(bgname)
        fullname = os.path.abspath(dmname) + dmfilter

        # Reading a single row is enough to show the file is usable.
        try:
            pycrates.read_file(fullname + "[#row=1]")
        except IOError:
            v1(f"WARNING: skipping map from {evtname} to {fullname} "
               "as the latter is missing")
        else:
            overrides[evtname] = fullname

    return remaining


def get_par(argv):
    """Read, validate, and process the parameters of the tool.

    Any --backgroundmap= option is first removed from argv (see
    handle_background_override) and the remaining arguments are used
    to open the parameter file, which is closed before returning.

    The parameters are queried in a fixed order, which is not quite
    the order of the parameter file: verbose, infiles, background, and
    tmpdir are needed early on.

    Parameters
    ----------
    argv : list of str
        The command-line arguments.

    Returns
    -------
    (params, pars) : tuple of dict
        pars holds the parameter values as read from the parameter
        file (mostly via pgetstr) whereas params holds the processed
        versions (for example booleans, numbers, the list of
        observations, and the background files) that the rest of the
        script uses.

    Raises
    ------
    ValueError
        If psfecf is not INDEF and does not satisfy 0 < psfecf <= 1,
        or the instrument of the first observation is neither ACIS
        nor HRC.
    """

    # The --backgroundmap option is not a real parameter, so strip it
    # out before the parameter library sees the argument list.
    argv = handle_background_override(argv)
    pfile = open_param_file(argv, toolname=toolname)["fp"]

    def is_set(name):
        "Return True if the boolean parameter name is set."
        return paramio.pgetb(pfile, name) == 1

    # Common parameters:
    params = {}
    pars = {}

    # set verbosity level (intentionally breaking the order of the
    # parameters since we need the verbose level fairly early, and
    # any user who has decided to change verbose to an automatic parameter
    # can handle this break in UI behavior).
    #
    pars["verbose"] = params["verbose"] = paramio.pgeti(pfile, "verbose")

    lw.set_verbosity(pars['verbose'])
    utils.print_version(toolname, __revision__)

    # files with information to reproject
    pars['infiles'] = paramio.pgetstr(pfile, 'infiles')

    # done out of order since need the background setting to filter the
    # input files
    pars['background'] = paramio.pgetstr(pfile, 'background')
    pars['tmpdir'] = paramio.pgetstr(pfile, 'tmpdir')
    params["tmpdir"] = utils.process_tmpdir(pars['tmpdir'])

    # The instrument is taken from the first observation.
    obsinfos = merging.validate_obsinfo(pars['infiles'])
    nobs = len(obsinfos)
    instrume = obsinfos[0].instrument

    # Background files are only looked for with HRC-I data, and only
    # when the user has not turned off the background handling.
    if pars['background'] != "none" and instrume == 'HRC' and \
       obsinfos[0].detector == 'HRC-I':
        bgfiles = merging.find_hrci_backgrounds(obsinfos,
                                                bgndmap=_background_override,
                                                tmpdir=params['tmpdir'])
        if bgfiles is None:
            # No backgrounds at all: switch off the background handling.
            pars['background'] = 'none'
            params['bgfiles'] = [None] * nobs

        else:
            # Drop any observation that has no background file, keeping
            # obsinfos and bgfiles matched up.
            chk = [bgfile is None for bgfile in bgfiles]
            if any(chk):
                v3("Stripping out observations with no HRC-I background file")
                obsinfos = [obsinfo for (obsinfo, bgfile) in
                            zip(obsinfos, bgfiles) if bgfile is not None]
                params['bgfiles'] = [bgfile for bgfile in bgfiles
                                     if bgfile is not None]

            else:
                params['bgfiles'] = bgfiles

    else:
        params['bgfiles'] = [None] * nobs

    params['obsinfos'] = obsinfos
    params['background'] = pars['background']

    pars['bkgparams'] = paramio.pgetstr(pfile, 'bkgparams')
    params['bkgparams'] = fi.validate_bkgparams(pars['background'],
                                                pars['bkgparams'])

    # PSF creation
    #
    # NOTE(review): pars holds the values as read from the parameter
    # file and params the processed versions; the original author noted
    # that the reason for keeping both had been forgotten.
    pars['psfecf'] = paramio.pgetstr(pfile, 'psfecf')
    pars['psfmerge'] = paramio.pgetstr(pfile, 'psfmerge')

    # psfecf=INDEF means that no PSF maps are to be created.
    if pars['psfecf'] == 'INDEF':
        params['psfecf'] = None
        params['psfmerge'] = None
    else:
        # Validate the result in case the .par file has been adjusted
        # (also because then we can be sure about <= vs < in the checks,
        #  since it's not clear to me what the param library uses)
        #
        # Question: can we let psfecf=1?
        #
        params['psfecf'] = paramio.pgetd(pfile, 'psfecf')
        if params['psfecf'] <= 0 or params['psfecf'] > 1:
            raise ValueError(f"psfecf={params['psfecf']} is not valid, it must be > 0 and <= 1")

        params['psfmerge'] = pars['psfmerge']

    # only used (at present) for background subtraction
    params['random'] = paramio.pgeti(pfile, 'random')

    params['instrume'] = instrume

    # Any other warning messages
    params['warn_msgs'] = merging.obsinfo_checks(obsinfos)

    # root of merged output files; "." is treated as no root at all
    pars['outroot'] = params["outroot"] = paramio.pgetstr(pfile, "outroot")
    if params["outroot"] == ".":
        params["outroot"] = ""

    (params["outdir"], params["outhead"]) = utils.split_outroot(params["outroot"])
    if params["outdir"] != "":
        fileio.validate_outdir(params["outdir"])

    # energy bands, defined by a comma separated string in eV;
    #   i.e. "300:2000,2000:5000,5000:8000"
    pars['bands'] = params["bands"] = paramio.pgetstr(pfile, "bands")

    # Process the xygrid parameter, decide what other parameters are needed.
    # NOTE(review): this call is expected to add the xygrid and bin
    # entries (and maxsize or sizes/xrange/yrange) to params, since
    # merge_obs reads them and they are not set here - confirm against
    # merging.handle_xygrid.
    merging.handle_xygrid(pfile, instrume, pars, params)

    pars['asolfiles'] = paramio.pgetstr(pfile, 'asolfiles')
    pars['badpixfiles'] = paramio.pgetstr(pfile, 'badpixfiles')
    pars['maskfiles'] = paramio.pgetstr(pfile, 'maskfiles')

    # As of CIAO 4.6 there is no need for the pbkfiles parameter.
    # The dtffiles parameter is only queried for HRC data.
    if instrume == 'ACIS':
        pars['dtffiles'] = None
    elif instrume == 'HRC':
        pars['dtffiles'] = paramio.pgetstr(pfile, 'dtffiles')
    else:
        raise ValueError(f"Internal error: unrecognized INSTRUME keyword {instrume}")

    pars['refcoord'] = params['refcoord'] = paramio.pgetstr(pfile, 'refcoord')

    pars['units'] = params['units'] = paramio.pgetstr(pfile, 'units')

    # Note the change in name: expmapthresh is stored as params['thresh'].
    pars['expmapthresh'] = params['thresh'] = paramio.pgetstr(pfile,
                                                              'expmapthresh')

    # nproc is only queried when parallel is set; otherwise a single
    # process is used. nproc=INDEF is stored as None.
    pars['parallel'] = paramio.pgetstr(pfile, 'parallel')
    params['parallel'] = is_set('parallel')
    if params['parallel']:
        pars['nproc'] = paramio.pgetstr(pfile, 'nproc')
        if pars['nproc'] == 'INDEF':
            params['nproc'] = None
        else:
            params['nproc'] = paramio.pgeti(pfile, 'nproc')
    else:
        pars['nproc'] = 'INDEF'
        params['nproc'] = 1

    pars['cleanup'] = paramio.pgetstr(pfile, 'cleanup')
    params['cleanup'] = is_set('cleanup')

    pars['clobber'] = paramio.pgetstr(pfile, 'clobber')
    params['clobber'] = is_set('clobber')

    # close parameters block
    paramio.paramclose(pfile)

    return (params, pars)


def align(taskrunner,
          obsinfos,
          ra, dec,
          reprofiles, mergefile,
          tmpdir="/tmp",
          clobber=False,
          verbose=1,
          parallel=False,
          nproc=None):
    """Add the tasks to reproject, and then merge, the event files.

    The event file of each observation is reprojected to the position
    ra, dec - creating reprofiles - and these are then merged to create
    mergefile. The tasks are only added to the task runner here; it is
    up to the caller to run them.

    Parameters
    ----------
    taskrunner
        The object that the tasks are added to.
    obsinfos
        The observations; the event file of each one is found with its
        get_evtfile method.
    ra, dec
        The position to reproject to.
    reprofiles
        The names of the reprojected event files to create.
    mergefile
        The name of the merged event file to create.
    tmpdir, clobber, parallel
        Passed through to the routines in the merging module.
    verbose : int
        The verbose level of the script. Values of 1 to 4 are reduced
        by one before being sent to the tasks; other values are passed
        through unchanged.
    nproc
        This argument is not used by the routine.

    Returns
    -------
    str
        The name of the task that merges the event files.
    """

    # Run the tasks one level quieter than the script, unless the
    # script is silent or is at the highest verbose level.
    if 0 < verbose < 5:
        verbose = verbose - 1

    def same_label(label):
        "There is no need to change the task names."
        return label

    # Options shared by the reprojection and the merging steps.
    common = {'tmpdir': tmpdir, 'clobber': clobber, 'verbose': verbose}

    evtfiles = [obs.get_evtfile() for obs in obsinfos]
    reproject_task = merging.reproject_event_files(taskrunner,
                                                   same_label,
                                                   [],
                                                   evtfiles, reprofiles,
                                                   ra, dec,
                                                   parallel=parallel,
                                                   **common)

    lookup = run.get_lookup_table('obsidmerge', pathfrom=__file__)

    # The merge task is registered with the reprojection task as the
    # one task listed before the function to call.
    merge_task = "merge-event-files"
    taskrunner.add_task(merge_task, [reproject_task],
                        merging.merge_event_files,
                        reprofiles,
                        mergefile,
                        obsinfos=obsinfos,
                        colfilter=True,
                        lookupTab=lookup,
                        **common)

    return merge_task


# CAN THIS BE COMBINED WITH THE VERSION IN flux_obs?
#
def data_products(taskrunner,
                  instrume,
                  process_flags,
                  obsinfos,
                  bkgfiles,
                  chiplists,
                  bands,
                  binval,
                  units,
                  thresh,
                  xygrids,
                  outdir, outhead,
                  ecf=None,
                  background="default",
                  bkgparams="",
                  random=0,
                  cleanup=True,
                  tmpdir="/tmp/",
                  clobber=False,
                  verbose=0,
                  parallel=False):
    """Add the tasks to run fluximage on the individual observations.

    The tasks are added to the task runner but are not run here.

    Parameters
    ----------
    taskrunner
        The object that the tasks are added to.
    instrume
        This argument is not used by the routine.
    process_flags
        One flag per observation; observations with a false flag are
        skipped.
    obsinfos, bkgfiles, chiplists, xygrids
        The observation, its background file (None if there is none),
        its chips, and its grid - a pair of x and y ranges - matched
        up element by element with process_flags.
    bands, units, thresh
        Passed through to fi.run_fluximage_tasks.
    binval
        This argument is not used by the routine; the bin size is
        taken from the x range of each grid.
    outdir, outhead
        The output files of each observation start with
        outdir + outhead + obsid + "_".
    ecf, random, cleanup, tmpdir, clobber, parallel
        Passed through to fi.run_fluximage_tasks.
    background, bkgparams
        Combined with the background file for those observations
        which have one.
    verbose : int
        The verbose level of the script; the tasks are run at one
        level lower (never below 0, with 4 promoted to 5).
    """

    # try running fluximage with one less than the script
    tverbose = max(verbose - 1, 0)
    if tverbose == 4:
        tverbose = 5

    # Strip out observations that are to be ignored
    per_obs = zip(obsinfos, bkgfiles, chiplists, xygrids)
    args = [arg for (flag, arg) in zip(process_flags, per_obs) if flag]
    nargs = len(args)

    if nargs == 1:
        msg = "\nCreating the fluxed image."

    elif parallel:
        msg = f"\nCreating the fluxed images for {nargs} observations in " + \
            "parallel."

    else:
        msg = f"\nCreating the fluxed images for {nargs} observations."

    v1(msg)

    for (obs, bkgfile, chips, xygrid) in args:

        outroot = f"{outdir}{outhead}{obs.obsid}_"

        (xrng, yrng) = xygrid
        grid = f"x={xrng},y={yrng}"
        binsize = xrng.size

        if bkgfile is None:
            bkginfo = None
        else:
            bkginfo = (background, bkgparams, bkgfile)

        asolobj = AspectSolution(obs.get_asol(), tmpdir=tmpdir)

        # The obsid is bound as a default argument so that the labels
        # remain tied to this observation even if the function is
        # called after the loop has moved on (a closure over obs would
        # otherwise pick up whatever observation was processed last).
        fi.run_fluximage_tasks(taskrunner,
                               lambda s, obsid=obs.obsid: f"{s}-obsid{obsid}",
                               obs,
                               bkginfo,
                               chips,
                               asolobj,
                               outroot,
                               grid,
                               binsize,
                               bands,
                               units,
                               thresh,
                               ecf=ecf,
                               random=random,
                               tmpdir=tmpdir,
                               verbose=tverbose,
                               clobber=clobber,
                               cleanup=cleanup,
                               parallel=parallel,
                               pathfrom=__file__)


"""
TODO:

clean up the headers of the combined images by removing some, or all, of

ASOLFILE
THRFILE
MJD_OBS
DS_IDENT
TLMVER
REVISION
SUBPIXFL
RAND_SKY ?
RAND_PI  ?
PIX_ADJ  ?
MJDREF
OBS_MODE
TIMEZERO
TIMEUNIT
DATACLAS
TIMEREF
TASSIGN
CLOCKAPP
SIM_X / Y / Z
FOC_LEN
TIERRELA
TIERABSO
GRATING
DETNAM
RA_TARG
DEC_TARG
DEFOCUS
RA_NOM
DEC_NOM
STARTMJF
STARTMNF
STARTOBT
STOPMJF
STOPMNF
TIMEPIXR ?
TIMEDEL  ?
GAINFILE
CTI_CORR ?
CTI_APP  ?
CTIFILE
MTLFILE
TGAINCOR ?
TGAINFIL
GRD_FILE
CORNERS
GRADESYS
BPIXFILE
READMODE
DATAMODE
RUN_ID
FSW_VERS
STARTBEP
STOPBEP
FEP_ID
CCD_ID
TIMEDELA
TIMEDELB
FIRSTROW
NROWS
FLSHTIMA
FLSHTIMB
CYCLE
OBS_ID
SEQ_NUM
ONTIME0-9   ? if they are correct then possibly useful
LIVTIME0-9  ? ditto
EXPOSUR0-9  ? ditto
DTCOR
ASPTYPE
BIASFIL0-9
AIMPFILE
GEOMFILE
SKYFILE
TDETFILE
SHELLFIL
FLTFILE
MASKFILE
PBKFILE
CALDBVER ?
CLSTBITS

subspace filters?

"""


def setup_output_names(outdir, outhead, obsinfos, enbands,
                       psf=False,
                       threshold=False):
    """Return the names of the files that the code creates.

    The dictionary returned by fi.fluximage_output is extended with
    the following entries:

      reprofiles     - the reprojected event file of each observation
      mergedevtfile  - the merged event file
      out_images and out_expmaps (or, when threshold is set,
      out_thresh_images and out_thresh_expmaps)
                     - the co-added image and exposure map per band
      out_fluxmaps   - the co-added fluxed image per band
      out_psfmaps    - the co-added PSF map per band (only if psf is
                       set)

    Parameters
    ----------
    outdir, outhead
        The directory and head of the output file names.
    obsinfos
        The observations; the obsid attribute of each is used.
    enbands
        The energy bands; the bandlabel attribute of each is used.
    psf : bool, optional
        Are PSF maps being created?
    threshold : bool, optional
        Are the images being thresholded by the exposure map?

    Returns
    -------
    dict
        The file names.
    """

    v3("Setting up output names")

    obsids = [obs.obsid for obs in obsinfos]
    names = fi.fluximage_output(outdir, outhead, obsinfos, enbands,
                                psf=psf,
                                threshold=threshold)

    names['reprofiles'] = [
        merging.obsid_reproj_evt_name(outdir, outhead, obsid)
        for obsid in obsids
    ]
    names['mergedevtfile'] = merging.merged_evt_name(outdir, outhead)

    # The labels, and the routines that create the image and exposure
    # map names, depend on whether the images are thresholded.
    if threshold:
        (img_label, emap_label) = ('out_thresh_images', 'out_thresh_expmaps')
        name_img = merging.name_coadd_thresh_img
        name_emap = merging.name_coadd_thresh_expmap
    else:
        (img_label, emap_label) = ('out_images', 'out_expmaps')
        name_img = merging.name_coadd_img
        name_emap = merging.name_coadd_expmap

    coadd_imgs = []
    coadd_emaps = []
    coadd_fluxes = []
    coadd_psfs = []

    names[img_label] = coadd_imgs
    names[emap_label] = coadd_emaps
    names['out_fluxmaps'] = coadd_fluxes
    if psf:
        names['out_psfmaps'] = coadd_psfs

    # One entry per energy band in each of the lists.
    for enband in enbands:
        label = enband.bandlabel
        coadd_imgs.append(name_img(outdir, outhead, label))
        coadd_emaps.append(name_emap(outdir, outhead, label))
        coadd_fluxes.append(merging.name_coadd_flux(outdir, outhead, label))

        if psf:
            coadd_psfs.append(merging.name_coadd_psfmap(outdir, outhead,
                                                        label,
                                                        thresh=threshold))

    return names


@lw.handle_ciao_errors(toolname, __revision__)
def merge_obs():
    """Run the tool.

    The steps are:

      1. read in and validate the parameters;
      2. reproject the event files to a common position and merge
         them;
      3. run fluximage on each reprojected event file;
      4. combine the per-observation products;
      5. display any warning messages.
    """

    (params, pars) = get_par(lw.preprocess_arglist(sys.argv))

    obsinfos = params['obsinfos']
    instrume = params['instrume']
    merging.validate_obsinfo_params(obsinfos,
                                    pars['asolfiles'],
                                    pars['badpixfiles'],
                                    pars['maskfiles'],
                                    dtffiles=pars['dtffiles'],
                                    tangent=False)

    enbands = bands.validate_bands(instrume, params['bands'])
    params['enbands'] = enbands

    outdir = params['outdir']
    outhead = params['outhead']

    clobber = params['clobber']
    verbose = params['verbose']
    parallel = params['parallel']
    tmpdir = params['tmpdir']

    is_thresh = utils.thresh_is_set(params['thresh'])

    refcoord = params['refcoord']
    xygrid = params['xygrid']
    binsize = params['bin']

    # PSF maps are only created when psfecf has been set.
    psfecf = params['psfecf']
    want_psf = psfecf is not None

    outfiles = setup_output_names(outdir, outhead, obsinfos, enbands,
                                  psf=want_psf,
                                  threshold=is_thresh)
    fi.clobber_checks(outfiles, clobber=clobber)

    # Find out what the new reference location is
    # Note: moving away from using the file name as a reference
    #       if given, so refcoord isn't used anymore.
    #
    (refcoord, ra, dec) = merging.process_reference_position(refcoord,
                                                             obsinfos)
    merge_warnings = merging.list_observations(instrume, ra, dec, obsinfos)

    # Step one: reproject and merge the event files.
    taskrunner = TaskRunner()
    align(taskrunner,
          obsinfos,
          ra, dec,
          outfiles['reprofiles'],
          outfiles['mergedevtfile'],
          tmpdir=tmpdir,
          clobber=clobber,
          verbose=verbose,
          parallel=parallel)

    taskrunner.run_tasks(processes=params['nproc'])
    v1("")

    # The ancillary files to copy over do not depend on the observation.
    atypes = ['bpix', 'mask']
    if instrume == 'HRC':
        atypes.append('dtf')

    # Set up the obsinfo objects for the reprojected files
    robsinfos = []
    for (orig, reprofile) in zip(obsinfos, outfiles['reprofiles']):
        reproj = obsinfo.ObsInfo(reprofile)
        reproj.set_asol(orig.get_asol())

        # Copy over the "multi obi" flag from the observation object
        # sent in, which will have been set to indicate whether
        # the input/non-reprojected version is a "multi obi" case.
        #
        # This is a bit ugly
        reproj.obsid.is_multi_obi = orig.obsid.is_multi_obi

        for atype in atypes:
            reproj.set_ancillary(atype, orig.get_ancillary(atype))

        robsinfos.append(reproj)

    # From now on only the reprojected observations are used.
    obsinfos = None
    params['obsinfos'] = None

    # An empty xygrid means that the grid is calculated from the data.
    if xygrid == "":
        matched = merging.matchup_xygrids_auto(robsinfos,
                                               binsize,
                                               params['maxsize'],
                                               tmpdir=tmpdir)
    else:
        matched = merging.matchup_xygrids_user(xygrid,
                                               robsinfos,
                                               binsize,
                                               params['sizes'],
                                               params['xrange'],
                                               params['yrange'],
                                               tmpdir=tmpdir)

    (xygrids, chipslist, process) = matched

    # Step two: the per-observation images.
    data_products(taskrunner,
                  instrume,
                  process,
                  robsinfos,
                  params['bgfiles'],
                  chipslist,
                  enbands,
                  binsize,
                  params['units'],
                  params['thresh'],
                  xygrids,
                  outdir,
                  outhead,
                  ecf=psfecf,
                  background=params['background'],
                  bkgparams=params['bkgparams'],
                  random=params['random'],
                  cleanup=params["cleanup"],
                  tmpdir=tmpdir,
                  clobber=clobber,
                  verbose=verbose,
                  parallel=parallel)
    taskrunner.run_tasks(processes=params['nproc'], label=False)

    # Step three: combine the per-observation results.
    merging.merge(process,
                  enbands,
                  pars,
                  outfiles,
                  robsinfos,
                  toolname,
                  __revision__,
                  verbose=verbose,
                  psfmerge=params['psfmerge'],
                  threshold=is_thresh,
                  clobber=clobber,
                  pathfrom=__file__,
                  tmpdir=tmpdir)

    merging.display_merging_warnings(merge_warnings,
                                     outfiles['mergedevtfile'],
                                     robsinfos)

    # Finish with the warnings found when the parameters were read in.
    for warn_msg in params['warn_msgs']:
        v1(warn_msg)

# Only run the tool when this file is executed as a script, not when
# it is imported.
if __name__ == "__main__":
    merge_obs()
