#!/usr/bin/env python
#
# Copyright (C) 2016-2024
#           Smithsonian Astrophysical Observatory
#
#
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
# 02110-1301 USA.
#

# Name of this tool: used for the logger, the parameter file lookup and the
# history records written to the output files.
__toolname__ = "blanksky_image"
# Revision date passed to the version banner and the error handler.
__revision__  = "18 March 2024"


import os
import sys
import tempfile
import numpy

import paramio
import pycrates as pcr

from ciao_contrib.logger_wrapper import initialize_logger, make_verbose_level, set_verbosity, handle_ciao_errors
from ciao_contrib.param_wrapper import open_param_file

from ciao_contrib._tools import utils, fileio
from ciao_contrib.runtool import make_tool, add_tool_history, new_pfiles_environment

from ciao_contrib.parallel_wrapper import parallel_pool_futures
#from sherpa.utils import parallel_map

#############################################################################
#############################################################################

# Initialise the logger for this tool.
initialize_logger(__toolname__)

# v0 .. v5 display a message at verbose level 0 .. 5 respectively.
v0, v1, v2, v3, v4, v5 = (
    make_verbose_level(__toolname__, level) for level in range(6)
)


class ScriptError(RuntimeError):
    """Failure detected by the script itself.

    A dedicated exception type makes it possible to catch problems
    reported by this script separately from errors raised by the
    underlying code.
    """


###############################################


def combine_chips(args):
    """
    Create the scaled, energy-filtered background image of a single chip.

    The arguments arrive packed in one tuple so that the routine can be
    handed to a parallel map, which passes a single item per call.

    args unpacks, in order, to:
        bkgfile        - name of the image file to create
        chip           - chip number to select from the background events
        component_info - sequence whose items 1, 2 and 3 hold, position by
                         position, the chip numbers and the lower and
                         upper energy limits to apply
        bkg_kw         - keyword mapping providing BKGSCALn / BKGSCALE
        bkg            - background event file name
        dettype        - chip column name; "ccd_id" selects the per-chip
                         BKGSCALn scaling and an energy filter, any other
                         value uses BKGSCALE and a pi filter
        dmfilter       - extra filter text appended to the event file name
        xylimits       - binning specification for the image
        crtype         - data type requested for the image
        verbose        - verbose level handed to dmimgcalc
    """

    (bkgfile, chip, component_info, bkg_kw, bkg,
     dettype, dmfilter, xylimits, crtype, verbose) = args

    # find the energy limits recorded for this chip
    pos = component_info[1].index(chip)
    elo = component_info[2][pos]
    ehi = component_info[3][pos]

    dmimgcalc = make_tool("dmimgcalc")

    if dettype == "ccd_id":
        scale_key, energy_col = f"BKGSCAL{chip}", "energy"
    else:
        scale_key, energy_col = "BKGSCALE", "pi"

    scale = bkg_kw[scale_key]
    energy = f"{energy_col}={elo}:{ehi}"

    # pieces of the virtual-file expression used to bin the events
    chip_filter = f"[{dettype}={chip},{energy}]"
    binning = f"[bin {xylimits}][opt type={crtype}]"

    # run with a private parameter-file environment
    with new_pfiles_environment(ardlib=False, copyuser=False):

        dmimgcalc.punlearn()
        dmimgcalc.outfile = bkgfile
        dmimgcalc.infile = f"{bkg}{chip_filter}{dmfilter}{binning}"
        dmimgcalc.operation = f"imgout={scale}*img1"
        dmimgcalc.verbose = verbose
        dmimgcalc.clobber = True

        dmimgcalc()



def bkg_img(imgfile,bkg,outfile,dmfilter,dettype,chips,tmpdir,verbose):
    """
    Produce the scaled background image that matches a reference image.

    Each chip common to the background event file and the reference
    image is binned onto the reference image grid, filtered with the
    energy (or pi) range recorded in the reference image subspace and
    scaled (see combine_chips); the per-chip images are then summed.

    Parameters
    ----------
    imgfile  : reference image; provides the grid, subspace and data type
    bkg      : background event file
    outfile  : prefix of the output name; the image is written to
               "{outfile}{bkgmethod}_bgnd.img"
    dmfilter : extra filter text applied to the background event file
    dettype  : chip column name, "ccd_id" or "chip_id"
    chips    : chip numbers available in the background event file
    tmpdir   : directory in which the per-chip temporary images are made
    verbose  : verbose level handed to dmimgcalc

    Returns
    -------
    (xylimits, outname, bkgmethod)
        the binning specification reported by get_sky_limits, the name of
        the background image that was written, and the lower-cased
        BKGMETH keyword of the background file.

    Raises
    ------
    ScriptError
        if the background file and the reference image have no chip in
        common, or if a common chip has no matching subspace component
        from which to take the energy range.
    """

    dmimgcalc = make_tool("dmimgcalc")
    gsl = make_tool("get_sky_limits")

    # ask get_sky_limits for the binning that reproduces the reference grid
    gsl.punlearn()
    gsl.image = imgfile
    gsl.verbose = "0"
    gsl()

    xylimits = gsl.dmfilter

    bkg_kw = fileio.get_keys_from_file(bkg)
    bkgmethod = bkg_kw["BKGMETH"].lower()

    cr = pcr.read_file(imgfile)

    # match chips available in blanksky events file and the reference image file
    # NOTE(review): the number of chips is used as the number of subspace
    # components to inspect; this looks like it assumes one component per
    # chip - confirm against files with a different subspace layout.
    chips_img = []

    for i,_ in enumerate(chips,start=1):
        try:
            chips_img.extend(cr.get_subspace_data(i,dettype).range_min.tolist())
        except IndexError:
            continue

    chips = set(chips_img).intersection(chips)

    # without this check the failure only shows up later, as an empty
    # dmimgcalc operation acting on an empty stack of files
    if not chips:
        raise ScriptError(f"The background file '{bkg}' and the reference image "
                          f"'{imgfile}' have no {dettype} values in common.")

    # get energy filter that matches reference image file by using the file's subspace
    component_info = []

    if dettype == "ccd_id":
        subspace_arg = "energy"
    else:
        subspace_arg = "pi"

    for i,_ in enumerate(chips,start=1):
        try:
            chip_range = cr.get_subspace_data(i,dettype)
            energy_range = cr.get_subspace_data(i,subspace_arg)

            component_info.append([i,
                                   chip_range.range_min,
                                   energy_range.range_min[0],
                                   energy_range.range_max[0]])

        except IndexError:
            break

    # restructure list to remove lists; necessary for subspace components with
    # more than one chip listed
    component_info = [[n[0],chip,n[2],n[3]] for n in component_info for chip in n[1]]
    component_info = list(zip(*component_info)) # the * in the zip takes the transpose

    # combine_chips looks each chip up in component_info[1]; report a chip
    # without an entry here rather than from inside a parallel worker
    known_chips = component_info[1] if component_info else ()
    missing = [chip for chip in chips if chip not in known_chips]

    if missing:
        missing_txt = ", ".join(str(chip) for chip in missing)
        raise ScriptError(f"Unable to find the {subspace_arg} range in the subspace of "
                          f"'{imgfile}' for {dettype}={missing_txt}.")

    # get image data type to set appropriate image type
    # NOTE(review): int64 is mapped to "i4", and int32 (like any other
    # type) takes the "i4" default - confirm the int64 entry is intended.
    crtype = cr.get_image()._values.dtype

    if crtype == numpy.dtype('int16'):
        crtype = "i2"
    elif crtype == numpy.dtype('int64'):
        crtype = "i4"
    elif crtype == numpy.dtype('float32'):
        crtype = "r4"
    elif crtype == numpy.dtype('float64'):
        crtype = "r8"
    else:
        crtype = "i4"

    # drop the reference to the crate now that it is no longer needed
    cr = None

    # scale each chip before mosaicing them using the BKGSCALn keywords
    # in the background event file
    bkg_scaled = []

    # img1+img2+... : one term per per-chip image
    dmimgcalc_op = "+".join(f"img{n}" for n in range(1,len(chips)+1))

    outname = f"{outfile}{bkgmethod}_bgnd.img"

    try:
        # add the files one at a time so that, should a later creation
        # fail, the ones already made are still closed in the finally block
        for chip in chips:
            tmpfile = tempfile.NamedTemporaryFile(suffix=f".bkg{chip}",dir=tmpdir)
            bkg_scaled.append((tmpfile,chip))

        args = [(bg.name,
                 chip,
                 component_info,
                 bkg_kw,
                 bkg,
                 dettype,
                 dmfilter,
                 xylimits,
                 crtype,
                 verbose) for bg,chip in bkg_scaled]

        parallel_pool_futures(combine_chips,args)

        # sum the per-chip images into the final background image
        dmimgcalc.punlearn()
        dmimgcalc.infile = [bg.name for bg,_ in bkg_scaled]
        dmimgcalc.outfile = outname
        dmimgcalc.operation = f"imgout={dmimgcalc_op}"
        dmimgcalc.clobber = True
        dmimgcalc.verbose = verbose

        dmimgcalc()

    finally:
        # closing the temporary files releases the per-chip images
        for fn,_ in bkg_scaled:
            fn.close()

    return xylimits,outname,bkgmethod



def get_par(argv):
    """
    Read and validate the blanksky_image parameters.

    Parameters
    ----------
    argv : argument list handed to open_param_file (the caller passes
           sys.argv).

    Returns
    -------
    (params, pars)
        params holds the values used by the script: bkgfile and imgfile
        are reduced with fileio.get_file, the filter part of bkgfile is
        stored as "infile_filter", and outroot is split into "outdir"
        and "outroot".  pars holds bkgfile, outroot and imgfile exactly
        as read from the parameter file, plus tmpdir, clobber, verbose
        and mode.

    Raises
    ------
    ScriptError
        if bkgfile, outroot or imgfile is empty, or if bkgfile or
        imgfile is given as a stack ("@...").
    IOError
        if the absolute path of the background file, output root or
        reference image contains a space.
    """

    pfile = open_param_file(argv,toolname=__toolname__)["fp"]

    # Common parameters:
    params: dict = {}
    pars: dict = {}

    # input blanksky background event file
    pars["bkgfile"] = params["bkgfile"] = paramio.pgetstr(pfile,"bkgfile")
    if params["bkgfile"] == "":
        raise ScriptError("Input background event file must be specified.")
    if params["bkgfile"].startswith("@"):
        raise ScriptError("Input event stacks not supported.")

    # separate the filter from the file name
    params["infile_filter"] = fileio.get_filter(params["bkgfile"])
    params["bkgfile"] = fileio.get_file(params["bkgfile"])

    # output file name
    pars["outroot"] = params["outroot"] = paramio.pgetstr(pfile,"outroot")
    if params["outroot"] == "":
        raise ScriptError("Please specify an output file name.")

    # The original code stripped trailing underscores only from a value
    # that did not end in one, which never changes it; the root returned
    # by split_outroot is therefore used as is.
    params["outdir"],params["outroot"] = utils.split_outroot(params["outroot"])

    # check if output directory is writable
    fileio.validate_outdir(params["outdir"])

    # reference image file to get xygrid and binning
    pars["imgfile"] = params["imgfile"] = paramio.pgetstr(pfile,"imgfile")
    if params["imgfile"] == "":
        raise ScriptError("Reference image file must be specified.")
    if params["imgfile"].startswith("@"):
        raise ScriptError("Input image file stacks not supported.")

    params["imgfile"] = fileio.get_file(params["imgfile"])

    # remaining parameters are used as read
    pars["tmpdir"] = params["tmpdir"] = paramio.pgetstr(pfile,"tmpdir")
    pars["clobber"] = params["clobber"] = paramio.pgetstr(pfile, "clobber")
    pars["verbose"] = params["verbose"] = paramio.pgeti(pfile,"verbose")
    pars["mode"] = params["mode"] = paramio.pgetstr(pfile, "mode")

    ## error out if there are spaces in absolute paths of the various parameters
    # the label is inserted into the error message, so it has to name the
    # right parameter: imgfile is the reference image, not an output
    checklist_spaces = [(pars["bkgfile"], "background"),
                        (pars["outroot"], "output"),
                        (params["imgfile"], "reference image")]

    for fn,filetype in checklist_spaces:
        _check_no_spaces(fn,filetype)

    # close parameters block after validation
    paramio.paramclose(pfile)

    return params,pars



def _check_no_spaces(fn,filetype):
    """
    Raise IOError when the absolute form of the path fn contains a space.

    filetype is a short description of the file that is inserted into
    the error message.
    """

    full_path = os.path.abspath(fn)

    if " " not in full_path:
        return

    raise IOError(f"The absolute path for the {filetype} file, '{full_path}', cannot contain any spaces")



@handle_ciao_errors(__toolname__,__revision__)
def doit():
    """
    Run the tool: create the scaled background image for the reference
    image, subtract it from the reference image, and record the tool
    history in both output files.
    """

    params, pars = get_par(sys.argv)

    verbose = params["verbose"]
    set_verbosity(verbose)
    utils.print_version(__toolname__, __revision__)

    evtfile = params["bkgfile"]
    evt_filter = params["infile_filter"]
    out_prefix = params["outdir"] + params["outroot"]
    # the reference image is taken as given in the parameter file
    refimg = pars["imgfile"]
    tmpdir = params["tmpdir"]
    clobber = params["clobber"]

    header = fileio.get_keys_from_file(evtfile)

    # determine chips to be used for the image
    if header["INSTRUME"] == "ACIS":
        det, list_chips = "ccd_id", fileio.get_ccds
    else:
        det, list_chips = "chip_id", fileio.get_chips

    chips = list_chips(evtfile + evt_filter)

    # get XY-grid; produce scaled image of the background file
    _, bkgimg, bkgmethod = bkg_img(refimg, evtfile, out_prefix,
                                   evt_filter, det, chips, tmpdir, verbose)

    # background-subtracted image: reference image minus scaled background
    outname = f"{out_prefix}{bkgmethod}_bkgsub.img"

    dmimgcalc = make_tool("dmimgcalc")
    dmimgcalc.punlearn()

    dmimgcalc.infile = refimg
    dmimgcalc.infile2 = bkgimg
    dmimgcalc.outfile = outname
    dmimgcalc.operation = "sub"
    dmimgcalc.weight = "1"
    dmimgcalc.weight2 = "1"
    dmimgcalc.verbose = verbose
    dmimgcalc.clobber = clobber

    dmimgcalc()

    # record how the products were made
    for product in (bkgimg, outname):
        add_tool_history(product, __toolname__, pars)



# Run the tool only when this file is executed as a script.
if __name__ == "__main__":
    doit()
