#!/usr/bin/env python
#
# Copyright (C) 2023
# Smithsonian Astrophysical Observatory
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
#

"Script to combine energy map w/ counts image to create true color image"

__toolname__ = "energy_hue_map"
__revision__ = "02 June 2023"

import sys
from colorsys import hls_to_rgb, hsv_to_rgb

import numpy as np
from pycrates import read_file
import ciao_contrib.logger_wrapper as lw


# Set up the tool's logger through the ciao_contrib logger wrapper and keep
# module-level shortcuts to its verbose0/verbose1/verbose2 methods so the
# rest of the script can call them as plain functions.
lw.initialize_logger(__toolname__)
verb0 = lw.get_logger(__toolname__).verbose0
verb1 = lw.get_logger(__toolname__).verbose1
verb2 = lw.get_logger(__toolname__).verbose2

# Tell numpy to take no action on floating-point errors (divide by zero,
# invalid operation, overflow, underflow) for the whole process.
np.seterr(all='ignore')


@np.vectorize
def vec_sys_to_rgb(hue, sat, value, func):
    """Convert one (hue, sat, value) triple to RGB multiplied by 255.

    The function is wrapped in np.vectorize so it can be applied
    element-wise to whole arrays; called that way it returns a tuple
    of three arrays (red, green, blue).

    hue, sat, value : the three color components, handed to func in
        this order.
    func : conversion routine called as func(hue, sat, value) and
        returning a (red, green, blue) triple, e.g. hsv_to_rgb.

    If any of the three components is NaN the pixel is set to black.
    """

    if any(np.isnan([hue, sat, value])):
        # These must be floats.  np.vectorize picks the output dtype
        # from the result for the first element, so returning integer
        # zeros here would make the whole output integer (truncating
        # every other pixel) whenever the first pixel is NaN.
        return (0.0, 0.0, 0.0)

    red, grn, blu = func(hue, sat, value)

    # Multiply by a float so the result is floating point even if func
    # hands back integer components.
    return (red*255.0, grn*255.0, blu*255.0)


def hisv_to_rgb(hue, saturation, value):
    """HSV to RGB with the saturation axis flipped.

    The saturation handed to hsv_to_rgb is 1.0 - saturation, so an
    input saturation of 0 gives a fully saturated color and an input
    of 1 gives a gray.
    """
    flipped = 1.0 - saturation
    rgb = hsv_to_rgb(hue, flipped, value)
    return rgb


def my_hls_to_rgb(hue, saturation, lightness):
    """HLS to RGB taking the arguments as (hue, saturation, lightness).

    hls_to_rgb wants lightness before saturation; this wrapper accepts
    them the other way around so it can be called with the same
    argument order as the hsv-style converters.
    """
    h, l, s = hue, lightness, saturation
    return hls_to_rgb(h, l, s)


def scale_values(map_in, map_min, map_max, map_func):
    """Scale the values onto the range 0 to 1 using map_func.

    map_in : array of values to scale.
    map_min, map_max : lower/upper limits; when None the smallest /
        largest non-NaN value of map_in is used instead.
    map_func : scaling function applied to (value - minimum) and to
        (maximum - minimum); the ratio of the two is the result.

    Returns an array clipped to [0, 1].  NaN inputs stay NaN.
    """

    min_val = np.nanmin(map_in) if map_min is None else map_min
    max_val = np.nanmax(map_in) if map_max is None else map_max

    span = max_val - min_val
    shifted = map_in - min_val

    if span > 0:
        # Clamp values below the minimum to the minimum *before* applying
        # the scaling function.  Otherwise a negative offset is fed to
        # map_func, which turns it into NaN (sqrt, log) or into a large
        # positive number (square) rather than into 0.  np.maximum
        # keeps NaN pixels as NaN.  An inverted or degenerate range
        # (span <= 0, or NaN) is left alone.
        shifted = np.maximum(shifted, 0)

    map_out = map_func(shifted) / map_func(span)
    return np.clip(map_out, 0, 1)


def map_scale_func(func_name):
    """Return the scaling function that goes with a parameter value.

    Recognised names are 'linear', 'asinh', 'log', 'sqrt' and 'square';
    anything else raises ValueError.
    """

    def _linear(x):
        return x

    def _log(x):
        # offset by one so an input of 0 maps to 0
        return np.log10(x+1)

    def _square(x):
        return x*x

    lookup = dict(linear=_linear,
                  asinh=np.arcsinh,
                  log=_log,
                  sqrt=np.sqrt,
                  square=_square)

    chosen = lookup.get(func_name)
    if chosen is None:
        raise ValueError(f"Unknown scale function: {func_name}")
    return chosen


def map_colorsys_func(func_name):
    """Return the color conversion function for a parameter value.

    Recognised names are 'hsv' (hsv_to_rgb), 'hls' (my_hls_to_rgb) and
    'hisv' (hisv_to_rgb).  Each returned function is called as
    func(hue, saturation, value-or-lightness).

    Raises ValueError for any other name.
    """
    valid_funcs = {'hsv': hsv_to_rgb,
                   'hls': my_hls_to_rgb,
                   'hisv': hisv_to_rgb}
    if func_name not in valid_funcs:
        # This is a color system lookup, not a scale function lookup;
        # say so, so the user knows which parameter is wrong.
        raise ValueError(f"Unknown color system: {func_name}")
    return valid_funcs[func_name]


def show(red_map, grn_map, blu_map):
    """Display the three color channels as one RGB image.

    red_map, grn_map, blu_map : 2D arrays of equal shape holding the
        channel intensities on a 0 to 255 scale.

    Blocks until the plot window is closed (plt.show()).
    """
    import matplotlib.pyplot as plt

    rgb = np.stack([red_map, grn_map, blu_map], axis=2)

    # imshow treats floating point RGB data as 0-1 and integer RGB data
    # as 0-255.  The channels here are on the 0-255 scale, so hand them
    # over as unsigned bytes; left as floats nearly every pixel would
    # be clipped to full intensity.
    rgb = np.clip(np.rint(np.nan_to_num(rgb)), 0, 255).astype(np.uint8)

    plt.imshow(rgb, origin='lower')

    # Let the image fill the whole figure: no margins, axes or padding
    plt.margins(x=0, y=0)
    plt.axis('off')
    fig = plt.gcf()
    fig.subplots_adjust(0, 0, 1, 1)
    plt.show()


def write_jpg(pars):
    """Create an output JPEG true color image.

    Runs dmimg2jpg on the [red], [green] and [blue] blocks of
    <outroot>.fits and writes <outroot>.jpg.

    pars : parameter dictionary; 'outroot' and 'clobber' are used.
        'clobber' may be a bool or a parameter-file string such as
        "yes"/"no".
    """

    from ciao_contrib._tools import fileio
    from ciao_contrib.runtool import dmimg2jpg

    outfile = f"{pars['outroot']}.jpg"

    # The raw parameter value is a string and any non-empty string,
    # including "no", is true in a boolean test.  Convert it the same
    # way process_pars does before checking.
    # NOTE(review): assumes outfile_clobber_checks tests the truth value
    # of its first argument - confirm against ciao_contrib.
    clobber = pars["clobber"]
    if isinstance(clobber, str):
        clobber = clobber.lower().startswith('y')

    fileio.outfile_clobber_checks(clobber, outfile)

    dmimg2jpg.infile = f"{pars['outroot']}.fits[red]"
    dmimg2jpg.greenfile = f"{pars['outroot']}.fits[green]"
    dmimg2jpg.bluefile = f"{pars['outroot']}.fits[blue]"
    dmimg2jpg.outfile = outfile
    # The clobber check has already been done above
    dmimg2jpg.clobber = True
    # The channels are already on a 0 to 255 scale, so use fixed limits
    # and a linear stretch for all three colors
    dmimg2jpg.minred = 0
    dmimg2jpg.mingreen = 0
    dmimg2jpg.minblue = 0
    dmimg2jpg.maxred = 255
    dmimg2jpg.maxgreen = 255
    dmimg2jpg.maxblue = 255
    dmimg2jpg.scalefunction = 'linear'
    dmimg2jpg.lutfile = ''
    dmimg2jpg.regioncolor = ''
    dmimg2jpg.gridcolor = ''
    dmimg2jpg()


def process_pars(pars):
    """Return a copy of the parameters converted to usable data types.

    The input dictionary is left untouched so the original values can
    still be used to set the HISTORY at the end.  In the copy:

      - colorsys, energy_scale and counts_scale are replaced by the
        functions they name,
      - the min_*/max_* limits become floats, or None when "INDEF",
      - clobber and show_plot become booleans.
    """

    limit_names = ('min_energy', 'max_energy', 'min_counts',
                   'max_counts', 'min_hue', 'max_hue', 'min_sat',
                   'max_sat')

    def is_yes(text):
        'True for any value starting with y or Y'
        return text.lower().startswith('y')

    out_pars = pars.copy()

    out_pars["colorsys"] = map_colorsys_func(pars["colorsys"])

    for name in limit_names:
        raw = pars[name]
        out_pars[name] = None if raw == "INDEF" else float(raw)

    for name in ("energy_scale", "counts_scale"):
        out_pars[name] = map_scale_func(pars[name])

    for name in ("clobber", "show_plot"):
        out_pars[name] = is_yes(pars[name])

    return out_pars


def write_output(pars, crate, red_map, grn_map, blu_map):
    """Write output FITS file w/ 3 extensions, one for each color channel.

    pars : parameter dictionary; 'outroot' and 'clobber' are used here
        and the whole dictionary is recorded by add_tool_history.
        'clobber' may be a bool or a parameter-file string such as
        "yes"/"no".
    crate : image crate used as the template for every block.  It is
        modified in place: its name and pixel values are replaced.
    red_map, grn_map, blu_map : pixel arrays for the RED, GREEN and
        BLUE blocks.

    The file is <outroot>.fits.
    """

    from ciao_contrib.runtool import add_tool_history
    from ciao_contrib._tools import fileio
    from pycrates import set_piximgvals

    outfile = f"{pars['outroot']}.fits"

    # The raw parameter value is a string and any non-empty string,
    # including "no", is true in a boolean test.  Convert it the same
    # way process_pars does before checking.
    # NOTE(review): assumes outfile_clobber_checks tests the truth value
    # of its first argument - confirm against ciao_contrib.
    clobber = pars["clobber"]
    if isinstance(clobber, str):
        clobber = clobber.lower().startswith('y')

    fileio.outfile_clobber_checks(clobber, outfile)

    # The first block creates the file and gets the tool history
    crate.name = "RED"
    set_piximgvals(crate, red_map)
    crate.write(outfile, clobber=True)
    add_tool_history(outfile, __toolname__, pars,
                     toolversion=__revision__)

    # Okay, there is no way to copy a crate and trying to copy
    # the individual elements (headers, wcs, history, etc) is a PITA
    # So as a hack, re-open the file just written, then for each
    # remaining channel rename/refill the same crate, add it to the
    # dataset and write() it out straight away.

    out_ds = read_file(outfile, mode="rw").get_dataset()

    for block_name, pixels in (("GREEN", grn_map), ("BLUE", blu_map)):
        crate.name = block_name
        set_piximgvals(crate, pixels)
        out_ds.add_crate(crate)
        out_ds.write()

    verb1("To display the data correctly in ds9 use the following command:")
    verb1("")
    verb1(f"ds9 -rgb -red '{outfile}' -linear -scale limits 0 255 \\")
    verb1(f"    -green '{outfile}[GREEN]' -linear -scale limits 0 255 \\")
    verb1(f"    -blue '{outfile}[BLUE]' -linear -scale limits 0 255")


@lw.handle_ciao_errors(__toolname__, __revision__)
def main():
    """Tool entry point.

    Reads the parameters, turns the energy map into a hue image and
    the counts image into saturation and value images, converts them
    to red/green/blue, writes the FITS and JPEG outputs and optionally
    displays the result.
    """

    from ciao_contrib.param_soaker import get_params

    def stretch(unit_map, low, high):
        'Move values from the 0 to 1 range onto the low to high range'
        return unit_map * (high-low) + low

    # Load parameters.  The raw values are kept for the output routines;
    # the converted copy is what the calculations below use.
    raw_pars = get_params(__toolname__, "rw", sys.argv,
                          verbose={"set": lw.set_verbosity, "cmd": verb1})
    cfg = process_pars(raw_pars)

    # Pixel values of the energy map image
    energy_img = read_file(cfg['energymap']).get_image().values

    # The counts crate doubles as the 'donor' from which the output
    # files are created, so hold on to the crate as well as its pixels
    donor_crate = read_file(cfg['infile'])
    counts_img = donor_crate.get_image().values

    if energy_img.shape != counts_img.shape:
        raise ValueError("ERROR: Images must have the same axis lengths")

    # Hue: scaled energies placed between min_hue and max_hue
    hue_img = scale_values(energy_img, cfg['min_energy'],
                           cfg['max_energy'], cfg['energy_scale'])
    hue_img = stretch(hue_img, cfg['min_hue'], cfg['max_hue'])

    # Saturation: scaled counts placed between min_sat and max_sat
    sat_img = scale_values(counts_img, cfg['min_counts'],
                           cfg['max_counts'], cfg['counts_scale'])
    sat_img = stretch(sat_img, cfg['min_sat'], cfg['max_sat'])

    # The lightness|value image is the saturation image with the
    # contrast and bias applied, which adjusts the brightness of the
    # pixels while keeping the saturation.
    # The contrast and bias are same sense and range as ds9 uses.
    contrast = float(cfg["contrast"])
    bias = float(cfg["bias"])
    val_img = (sat_img - 0.5) / contrast + (1.0 - bias)
    val_img = np.clip(val_img, 0, 1)

    # hsv|hls -> rgb, pixel by pixel
    rgb_maps = vec_sys_to_rgb(hue_img, sat_img, val_img, cfg['colorsys'])
    red_map, grn_map, blu_map = rgb_maps

    # Outputs: FITS file first, the JPEG is made from it
    write_output(raw_pars, donor_crate, red_map, grn_map, blu_map)
    write_jpg(raw_pars)

    # Optional display
    if cfg["show_plot"]:
        verb0("Close plot to exit ....")
        show(red_map, grn_map, blu_map)


# Run the tool only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == '__main__':
    main()
