#!/usr/bin/env python

#
# Copyright (C) 2010, 2016, 2019, 2020, 2023
# Smithsonian Astrophysical Observatory
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
#

"""

"""

import os
import sys
import logging
import numpy as np

from ciao_contrib.logger_wrapper import initialize_logger, \
    make_verbose_level, set_verbosity, handle_ciao_errors
from ciao_contrib.param_wrapper import open_param_file
from ciao_contrib.runtool import add_tool_history
from ciao_contrib._tools.fileio import outfile_clobber_checks
import pycrates
import paramio as pio

from sherpa.data import Data1D
from sherpa.models.basic import Polynom1D
from sherpa.fit import Fit

from sherpa import plot


# Name of the tool: used to find the parameter file, to label the
# logger/verbose output and error handling, and in the history record
# added to the output file.
TOOLNAME = "correct_periscope_drift"

# Version string: reported at verbose level 2, stored in the ADCVER
# header keyword of the output file, and passed to the history record.
VERSION = "26 September 2023"

# Set up the logging/verbose code
initialize_logger(TOOLNAME)

# Use v<n> to display messages at the given verbose level.
# You can pick other names than v0,v1, ... v5 if desired.
#
# NOTE(review): v5 does not appear to be called anywhere in this file.
#
v1 = make_verbose_level(TOOLNAME, 1)
v2 = make_verbose_level(TOOLNAME, 2)
v3 = make_verbose_level(TOOLNAME, 3)
v5 = make_verbose_level(TOOLNAME, 5)



def process_command_line(argv):
    """Read and validate the tool parameters.

    The parameter file is opened (with any command-line settings
    applied by open_param_file), each parameter is queried, and the
    file is closed again. The verbosity is then set, and the clobber
    check is run on the output file and on each of the four plot
    files derived from corr_plot_root.

    :param argv: the command-line arguments, normally ``sys.argv``
    :returns: dictionary of parameter values, keyed by parameter name
    :raises ValueError: if argv is None or an empty list, or if any
        of the string parameters is blank
    """

    if argv is None or argv == []:
        raise ValueError("argv argument is None or empty")

    # Only the paramio handle ("fp") is needed from the dictionary
    # returned by open_param_file.
    handle = open_param_file(argv, toolname=TOOLNAME)["fp"]

    opts = {}

    # String parameters; a blank value is an error.
    for name in ('infile', 'evtfile', 'outfile', 'corr_plot_root'):
        value = pio.pgetstr(handle, name)
        opts[name] = value
        if value.strip() == "":
            raise ValueError("{} parameter is empty".format(name))

    # Floating-point parameters.
    for name in ('x', 'y', 'radius', 'src_min_counts'):
        opts[name] = pio.pgetd(handle, name)

    opts['corr_poly_degree'] = pio.pgeti(handle, "corr_poly_degree")

    clobber_par = pio.pgetstr(handle, "clobber")
    verbose_par = pio.pgeti(handle, "verbose")

    # Nothing is written back to the parameter file, so it can be
    # closed as soon as everything has been read.
    pio.paramclose(handle)

    # Set tool and module verbosity
    set_verbosity(verbose_par)

    # Collect every file the tool will create - the corrected aspect
    # solution followed by the fit and data plot for each axis - and
    # run the clobber check on each of them in turn.
    plot_root = opts['corr_plot_root']
    outputs = [opts['outfile']]
    for ax in ('yag', 'zag'):
        outputs.append("{}_fit_{}.png".format(plot_root, ax))
        outputs.append("{}_data_{}.png".format(plot_root, ax))

    for path in outputs:
        outfile_clobber_checks(clobber_par, path)

    opts['clobber'] = clobber_par
    opts['verbose'] = verbose_par
    return opts


# Display parameter info to the user.
# The format and choice of information to display is up to you;
# I've chosen to include some ancillary information but it
# depends on the tool what should be used.
#
def display_start_info(opts):
    """Report the tool name, version, and main parameters.

    The tool name and the three file parameters are shown at verbose
    level 1; the remaining lines are shown at verbose level 2.

    :param opts: dictionary of parameter values; the infile, evtfile,
        outfile, and verbose entries are read
    """
    v1("Running: {0}".format(TOOLNAME))
    v2("  version = {0}".format(VERSION))
    v2("with parameters:")

    # The input and output files are always worth showing.
    for template, key in (("  infile={0}", "infile"),
                          ("  evtfile={0}", "evtfile"),
                          ("  outfile={0}", "outfile")):
        v1(template.format(opts[key]))

    # probably other values here too
    v2("  verbose={0}".format(opts["verbose"]))

    install_dir = os.environ["ASCDS_INSTALL"]
    v2("  and ASCDS_INSTALL is {0}".format(install_dir))
    v2("-" * 60)


def equatorial2transform(ra, dec, roll):
    """Build the rotation matrix for an attitude given as RA, Dec, Roll.

    The angles may be scalars or equal-length 1-d arrays of length N.

    :param ra: Right Ascension (degrees)
    :param dec: Declination (degrees)
    :param roll: Roll angle (degrees)
    :returns: transform matrix
    :rtype: 3x3 numpy array for scalar input, Nx3x3 for array input
    """
    ra_rad, dec_rad, roll_rad = (np.radians(ang) for ang in (ra, dec, roll))

    cos_ra, sin_ra = np.cos(ra_rad), np.sin(ra_rad)
    cos_dec, sin_dec = np.cos(dec_rad), np.sin(dec_rad)
    cos_roll, sin_roll = np.cos(roll_rad), np.sin(roll_rad)

    # The rows below form the transpose of the transformation matrix
    # (a consequence of how the original perl code was translated), so
    # the stacked rows are transposed before being returned. For array
    # input the transpose reverses all three axes, turning the 3x3xN
    # stack into Nx3x3.
    row_x = [cos_ra * cos_dec,
             sin_ra * cos_dec,
             sin_dec]
    row_y = [-cos_ra * sin_dec * sin_roll - sin_ra * cos_roll,
             -sin_ra * sin_dec * sin_roll + cos_ra * cos_roll,
             cos_dec * sin_roll]
    row_z = [-cos_ra * sin_dec * cos_roll + sin_ra * sin_roll,
             -sin_ra * sin_dec * cos_roll - cos_ra * sin_roll,
             cos_dec * cos_roll]

    return np.array([row_x, row_y, row_z]).T


def radec2eci(ra, dec):
    """
    Convert RA,Dec positions into ECI unit vectors.

    Scalars give a single 3-vector; 1-d arrays of length N give an
    array of shape (3,N), with one column per position.

    Borrowed from Ska.quatutil

    :param ra: Right Ascension (degrees)
    :param dec: Declination (degrees)
    :returns: numpy array ECI (3-vector or 3xN array)
    """
    lon = np.radians(ra)
    lat = np.radians(dec)

    x_eci = np.cos(lon) * np.cos(lat)
    y_eci = np.sin(lon) * np.cos(lat)
    z_eci = np.sin(lat)
    return np.array([x_eci, y_eci, z_eci])


def extract_events(event_file, src_x, src_y, src_radius):
    """
    Read the events that fall inside a circular source region.

    The circle is applied as a sky filter appended to the file name,
    so the filtering is done when the file is read.

    :param event_file: Chandra event 1 or 2 file
    :param src_x: Sky X coordinate of source region to extract
    :param src_y: Sky Y coordinate of source region to extract
    :param src_radius: Source region/circle radius in pixels
    :returns: CRATE of events
    """
    circle = "circle({},{},{})".format(src_x, src_y, src_radius)
    filtered_name = "{}[sky={}]".format(event_file, circle)
    return pycrates.read_file(filtered_name)


def get_event_yag_zag(evt_ra, evt_dec, ra_pnt, dec_pnt, roll_pnt):
    """
    Convert event RA and Dec positions into Y and Z angles.

    The event positions are rotated using a single reference pointing
    (which may be taken from the NOM or PNT values in the event file
    header). The resulting Y and Z angles should relate to the Aspect
    Camera Y and Z angle for the purposes of fitting the Y and Z angle
    periscope drift.

    :param evt_ra: event RA
    :param evt_dec: event Dec
    :param ra_pnt: A single reference value for RA
    :param dec_pnt: A single reference value for Dec
    :param roll_pnt: A single reference value for Roll
    :returns: yag, zag in arcsecs
    :raises ValueError: if evt_ra and evt_dec differ in length
    """
    n_ra = len(evt_ra)
    n_dec = len(evt_dec)
    if n_ra != n_dec:
        raise ValueError("len(evt_ra) != len(evt_dec), {} != {}".format(
                n_ra, n_dec))

    # Event directions as ECI vectors, one row per event (N x 3
    # rather than the 3 x N that radec2eci returns).
    unit_vecs = radec2eci(evt_ra, evt_dec).transpose()

    # Repeat the reference attitude once per event so that a rotation
    # matrix is built for every event (an N x 3 x 3 stack).
    pointing = np.array([ra_pnt, dec_pnt, roll_pnt])
    ra_rep, dec_rep, roll_rep = np.repeat(pointing, n_ra).reshape(3, n_ra)
    rotations = equatorial2transform(ra_rep, dec_rep, roll_rep)

    # Contract each event vector with its rotation matrix to get the
    # event direction in the frame of the reference pointing (N x 3).
    weighted = rotations.transpose(2, 0, 1) * unit_vecs
    d_aca = weighted.sum(axis=-1).transpose()

    # Conversion factor from radians to arcseconds.
    rad2arcsec = 206264.81
    x_aca = d_aca[:, 0]
    yag = np.arctan2(d_aca[:, 1], x_aca) * rad2arcsec
    zag = np.arctan2(d_aca[:, 2], x_aca) * rad2arcsec
    return yag, zag


def time_bins(times, x, nbins=20):
    """
    Bin 'x' into 'nbins' equal-width time bins spanning 'times'.
    This is only used as a plot aid.

    Bin membership follows np.histogram: every bin is half-open apart
    from the last one, which also includes its right edge, so the
    event(s) with the latest time are counted in the final bin.

    :param times: times used for time bins
    :param x: dataset binned in equal time chunks (one value per time)
    :param nbins: number of time bins to use
    :returns: bin time centers, bin data mean, and the standard error
        on the bin mean (std / sqrt(n)); the mean and error are NaN
        for a bin that contains no data
    """
    times = np.asarray(times)
    x = np.asarray(x)

    # Only the bin edges are needed from the histogram.
    _, edges = np.histogram(times, bins=nbins)
    bin_centers = (edges[:-1] + edges[1:]) / 2.0

    # np.digitize treats every bin as half-open, so a time equal to
    # the final edge would be given index nbins and belong to no bin.
    # Fold it back into the last bin, matching np.histogram.
    inds = np.minimum(np.digitize(times, edges) - 1, nbins - 1)

    # Start from NaN so that empty bins need no special arithmetic
    # (and no "mean of empty slice" warnings are generated).
    bin_x = np.full(nbins, np.nan)
    bin_std = np.full(nbins, np.nan)
    for idx in range(nbins):
        data = x[inds == idx]
        if data.size == 0:
            continue

        bin_x[idx] = np.mean(data)
        bin_std[idx] = np.std(data) / np.sqrt(data.size)

    return bin_centers, bin_x, bin_std


def _fit_poly(fit_data, evt_times, degree):
    """
    Fit a polynomial of the requested degree to Y or Z angle data.

    The fit is done in two passes. A straight line is fit first, with
    an error of 1 on every point, and the reduced statistic from that
    fit is used to rescale the errors. The requested polynomial is
    then fit using the rescaled errors. The independent axis is the
    time elapsed since the first event.

    :param fit_data: event y or z angle position data
    :param evt_times: times of event/fit_data
    :param degree: degree of polynomial to use for the fit model

    :returns: (sherpa data set, sherpa model)
    """

    elapsed = evt_times - evt_times[0]

    # Every point starts with the same error of 1.
    unit_errors = np.zeros_like(fit_data) + 1
    dset = Data1D('data', elapsed, fit_data, staterror=unit_errors)

    v2("Fitting a line to the data to get reduced stat errors")

    # Polynom1D only has c0 thawed by default, so c1 has to be thawed
    # to get a straight line.
    #
    line = Polynom1D('line')
    line.c1.thaw()
    line_fit = Fit(dset, line).fit()
    v3("Initial fit: succeeded={} rstat= {}".format(line_fit.succeeded, line_fit.rstat))

    # Rescale the (still uniform) errors using the reduced statistic
    # of the straight-line fit.
    dset.staterror = np.zeros_like(fit_data) + np.sqrt(line_fit.rstat)

    # Now fit the model that was asked for. c0 starts at 0 and the
    # coefficients c1 to c<degree> are thawed.
    #
    v2("Fitting a polynomial of degree {} to the data".format(degree))
    poly = Polynom1D('fitpoly')
    poly.c0 = 0

    for order in range(1, 1 + degree):
        coeff = getattr(poly, "c{}".format(order))
        coeff.thaw()

    v3("Starting fit with:")
    v3(str(poly))
    Fit(dset, poly).fit()

    v3("Ended fit with:")
    v3(str(poly))

    return dset, poly


def write_key(crate, name, value, comment, units=""):
    """
    Add a header keyword to a crate. Borrowed from srcflux

    :param crate: the crate to add the keyword to
    :param name: keyword name
    :param value: keyword value
    :param comment: description stored with the keyword
    :param units: units of the value (empty by default)
    """
    key = pycrates.CrateKey()
    key.name, key.value = name, value
    key.desc, key.unit = comment, units
    crate.add_key(key)


def make_fit_plot(evt_times, fit_data, mp, ax):
    """Create the plot of the binned data and fit.

    The data are averaged in time bins (see time_bins) and drawn with
    error bars, with the model over-plotted. The Y axis is forced to a
    symmetric range of at least +/- 0.3 arcsec that covers every
    binned point together with its error bar (ylimits only applies
    this for the pylab backend).

    :param evt_times: event times
    :param fit_data: y or z angle offsets of the events
    :param mp: prepared model plot; its x and y values are
        over-plotted, with x divided by 1000 to match the ks axis
    :param ax: axis name used in the labels and title
    """

    bin_centers, bin_mean, bin_std = time_bins(evt_times, fit_data)
    dx = (bin_centers - evt_times[0]) / 1000.

    # set minimum limit on fit plot in arcsecs and set this explicitly
    # as a symmetric limit
    #
    # A time bin without any events has a NaN mean and error. Only the
    # finite bins are used here: if a NaN reached max() every
    # comparison against it would be False, and the limit would stay
    # at 0.3 no matter how large the real offsets are.
    #
    fit_ymax = 0.3
    finite = np.isfinite(bin_mean) & np.isfinite(bin_std)
    if np.any(finite):
        lower = bin_mean[finite] - bin_std[finite]
        upper = bin_mean[finite] + bin_std[finite]
        fit_ymax = max(fit_ymax,
                       np.max(np.abs(lower)),
                       np.max(np.abs(upper)))

    # Rather than create a data object to pass to .prepare, set the
    # attributes manually.
    #
    dplot = plot.DataPlot()
    dplot.x = dx
    dplot.y = bin_mean
    dplot.yerr = bin_std
    dplot.xlabel = 'Observation elapsed/delta time (ks)'
    dplot.ylabel = 'Position offset from mean, {} (arcsec)'.format(ax)
    dplot.title = 'Fit of {} data (with time-binned event offsets)'.format(ax)

    # Backend-specific customization: only change the preferences
    # that the current backend already provides.
    #
    if 'errstyle' in dplot.plot_prefs:
        dplot.plot_prefs['errstyle'] = 'cap'

    if 'capsize' in dplot.plot_prefs:
        dplot.plot_prefs['capsize'] = 5

    mplot = plot.ModelPlot()
    mplot.x = mp.x / 1000
    mplot.y = mp.y

    with plot.backend:
        dplot.plot()
        mplot.plot(overplot=True)
        ylimits(-1 * fit_ymax, fit_ymax)


def make_data_plot(dset, model, fit_data, ax):
    """Create the plot of the unbinned data and the fitted model.

    The Y axis is forced to a symmetric range of at least +/- 2 arcsec
    (ylimits only applies this for the pylab backend).

    :param dset: the sherpa data set that was fit
    :param model: the sherpa model fit to dset
    :param fit_data: y or z angle offsets of the events, used to
        pick the Y range
    :param ax: axis name used in the labels and title
    """

    # Symmetric limit in arcsecs: leave a margin of 0.2 beyond the
    # largest offset, but never go below 2.
    largest_offset = np.max(np.abs(fit_data))
    data_ymax = max(2.0, largest_offset + .2)

    data_part = plot.DataPlot()
    data_part.prepare(dset)

    model_part = plot.ModelPlot()
    model_part.prepare(dset, model)

    fplot = plot.FitPlot()
    fplot.prepare(data_part, model_part)

    # Adjust the default look of the data in the combined plot:
    # no error bars, and labels that name the axis.
    #
    shown = fplot.dataplot
    shown.plot_prefs['yerrorbars'] = False
    shown.xlabel = 'Observation elapsed/delta time (s)'
    shown.ylabel = 'Position offset from mean, {} (arcsec)'.format(ax)
    shown.title = 'Raw data and fit in {}'.format(ax)

    with plot.backend:
        fplot.plot()
        ylimits(-1 * data_ymax, data_ymax)


def ylimits(ymin, ymax):
    """Set the Y-axis range of the current plot.

    Nothing is done unless the Sherpa plot backend is 'pylab'.

    Parameters
    ----------
    ymin, ymax : number
        The lower and upper limits for the Y axis.

    Notes
    -----
    Sherpa doesn't provide a back-end agnostic interface for this
    functionality. As of CIAO 4.12 this is not really needed.
    """

    if plot.backend.name == 'pylab':
        # Only import matplotlib when it is the active backend.
        from matplotlib import pyplot
        pyplot.ylim(ymin, ymax)


def hardcopy(filename):
    """Write the current plot to a file.

    Any existing file is overwritten. The "clobber" checks are assumed
    to have been made already, which leaves a potential "race
    condition" (another process could create the file between the
    check and this call), but that possibility is accepted.

    A warning is printed, and no file is created, unless the Sherpa
    plot backend is 'pylab'.

    Parameters
    ----------
    filename : str
        The file to create, including any suffix.

    Notes
    -----
    Sherpa doesn't provide a back-end agnostic interface for this
    functionality. As of CIAO 4.12 this is not really needed.
    """

    if plot.backend.name != 'pylab':
        # The warning is repeated for every plot on purpose: if this
        # happens the user should be told about each missing file.
        #
        print('WARNING: No hardcopy - unrecognized Sherpa setting: '
              f'plot_pkg={plot.backend.name}')
        return

    from matplotlib import pyplot
    pyplot.savefig(filename)


# The '@handle_ciao_errors' decorator will catch any error thrown and
# display it in a format like that of a CIAO tool, then exit with a
# non-zero status. So, if at any point in your code you need to exit just
# raise an error and let handle_ciao_errors bother with the display.
#
# This also means that routines can be used from ChIPS/Sherpa/other scripts
# without having to worry about them calling sys.exit.


def main(opt):
    """Fit the drift seen in the source events and correct the aspect solution.

    The steps are:

    - read the events inside the source circle from opt['evtfile'] and
      the aspect solution from opt['infile'];
    - report (at verbose level 1) if the OBS_ID values of the two files
      differ; processing continues in that case;
    - convert the event sky positions to RA and Dec, and keep only the
      events with times strictly between the first and last aspect
      solution times;
    - convert the positions to Y and Z angles about the RA_PNT, DEC_PNT,
      ROLL_PNT values of the event file;
    - for each axis: subtract the mean, fit a polynomial of degree
      opt['corr_poly_degree'], save the fit and data plots, evaluate
      the fit at the aspect solution times, and add it (divided by 20)
      to the dy (yag) or dz (zag) column;
    - record the fit in header keywords, write the aspect solution to
      opt['outfile'], and add a history record to that file.

    :param opt: dictionary of parameter values, as returned by
        process_command_line
    :raises ValueError: if no event in the region lies within the time
        range of the aspect solution
    """

    events = extract_events(opt['evtfile'],
                            opt['x'], opt['y'], opt['radius'])

    # Reference pointing, taken from the event file header.
    evt_ra_pnt = events.get_key('RA_PNT').value
    evt_dec_pnt = events.get_key('DEC_PNT').value
    evt_roll_pnt = events.get_key('ROLL_PNT').value

    asol = pycrates.read_file(opt['infile'])
    asol_times = asol.get_column('time').values

    # Sanity check the two input files
    # (a mismatch is only reported; it does not stop the tool)
    asol_obsid = asol.get_key('OBS_ID').value
    evt_obsid = events.get_key('OBS_ID').value
    if asol_obsid != evt_obsid:
        v1("Error Aspect solution obsid {} != event file obsid {}".format(asol_obsid, evt_obsid))

    # Extract event RA, Dec, and times from event file
    # Do the WCS transformation directly instead of using the pycrates RA/Dec properties to
    # work around intermittent bug https://icxc.harvard.edu/pipe/ascds_help/2013a/0315.html
    wcs = events.get_transform("eqpos")
    evt_x = events.get_column("x").values
    evt_y = events.get_column("y").values
    rd = wcs.apply(np.column_stack([evt_x, evt_y]))
    evt_ra = rd[:, 0]
    evt_dec = rd[:, 1]
    evt_times = events.get_column('Time').values

    # Limit to only using events contained within the range of the aspect solution
    # (the comparison is strict, so events exactly at the first or last
    # aspect solution time are excluded)
    ok_times = (evt_times > asol_times[0]) & (evt_times < asol_times[-1])
    if not np.any(ok_times):
        raise ValueError("No events in region are contained within time range of aspect solution.")
    # Replace the arrays with their filtered versions
    evt_ra = evt_ra[ok_times]
    evt_dec = evt_dec[ok_times]
    evt_times = evt_times[ok_times]

    # Too few counts is only a warning; the fit is still attempted.
    if len(evt_times) < opt['src_min_counts']:
        v1("Warning only {} counts in src region.  {} minimum suggested 'src_min_counts'".format(
                len(evt_times), opt['src_min_counts']))

    # ax_data holds the per-event angles for each axis; ax_map gives the
    # aspect solution column that each axis corrects.
    ax_data = {}
    ax_map = {'yag': 'dy',
              'zag': 'dz'}

    ax_data['yag'], ax_data['zag'] = get_event_yag_zag(evt_ra, evt_dec,
                                                       evt_ra_pnt, evt_dec_pnt, evt_roll_pnt)

    # With no DISPLAY set, switch matplotlib to the 'Agg' backend so
    # that the plots can still be created and saved.
    if 'DISPLAY' not in os.environ or os.environ['DISPLAY'] == "":
        if plot.backend.name == 'pylab':
            import matplotlib as mpl
            mpl.use('Agg')

    # Store comments to print in block after all of the sherpa fit output
    fit_comments = []
    plot_list = []

    for ax in ['yag', 'zag']:
        # The fit is made to the offsets from the mean angle.
        fit_data = ax_data[ax] - np.mean(ax_data[ax])
        dset, model = _fit_poly(fit_data, evt_times, opt['corr_poly_degree'])

        # Evaluate the model on the data grid; this is used both for the
        # fit plot and for the interpolation onto the aspect solution.
        mplot = plot.ModelPlot()
        mplot.prepare(dset, model)

        # Each plot is saved straight after it is created, since
        # hardcopy saves whatever the current plot is.
        make_fit_plot(evt_times, fit_data, mplot, ax)

        fit_plot = "{}_fit_{}.png".format(opt['corr_plot_root'], ax)
        hardcopy(fit_plot)
        plot_list.append(fit_plot)

        make_data_plot(dset, model, fit_data, ax)

        data_plot = "{}_data_{}.png".format(opt['corr_plot_root'], ax)
        hardcopy(data_plot)
        plot_list.append(data_plot)

        # The data set uses times relative to evt_times[0] (see
        # _fit_poly), so add it back to compare with the aspect
        # solution times.
        # NOTE(review): np.interp expects increasing x values, so this
        # presumably relies on the events being in time order - confirm.
        asol_corr = np.interp(asol_times, mplot.x + evt_times[0], mplot.y)
        asol_col_to_fix = asol.get_column(ax_map[ax])
        fit_comments.append("Events show drift range of {:.2f} arcsec in {} axis".format(
                np.max(asol_corr) - np.min(asol_corr), ax))
        fit_comments.append("Max absolute correction of {:.2f} arcsec for {} axis".format(
                np.max(np.abs(asol_corr)), ax))

        # Convert the correction from arcsecs to mm (divide by 20) and add the correction
        # to the dy and dz columns in the file.
        asol_col_to_fix.values += (asol_corr / 20)

        # Add header keys saving the axis-specific parts of this correction
        # (the mean that was subtracted, and each model coefficient from
        # c0 up to the requested degree)
        write_key(asol, "ADC{}MN".format(ax.upper()), np.mean(ax_data[ax]),
                  "Aspect Drift Corr. Mean of uncorr {} data".format(ax))
        for deg in range(0, 1 + opt['corr_poly_degree']):
            write_key(asol, "ADC{}C{}".format(ax.upper(), deg),
                      getattr(model, 'c{}'.format(deg)).val,
                      "Aspect Drift Corr. {} model c{}".format(ax, deg))

    # Add header keywords about fit
    # (the reference time is the time of the first event that was used)
    write_key(asol, "ADCTIME0", evt_times[0],
              "Aspect Drift Corr. reference time")
    write_key(asol, "ADCSRCX", opt['x'],
              "Aspect Drift Corr. input src x")
    write_key(asol, "ADCSRCY", opt['y'],
              "Aspect Drift Corr. input src y")
    write_key(asol, "ADCSRCR", opt['radius'],
              "Aspect Drift Corr. input src radius", units='pix')
    write_key(asol, "ADCORDR", opt['corr_poly_degree'],
              "Aspect Drift Corr. model poly degree")
    write_key(asol, "ADCVER", VERSION,
              "Aspect Drift Corr. tool version")

    # Summarize the fits and list the plots (verbose level 2).
    v2("-" * 60)
    v2("Fit results")
    for c in fit_comments:
        v2("\t{}".format(c))
    v2("-" * 60)
    v2("Writing out corrected aspect solution file to {}".format(opt['outfile']))
    v2("\nYou *must* review the following plots before using this correction:")
    for p in plot_list:
        v2("\t{}".format(p))

    v2("")

    # Actually write out the new aspect solution file
    # NOTE(review): opt['clobber'] is the string read from the parameter
    # file, not a boolean - confirm that write() treats it as intended.
    # An existing outfile has already been through the clobber check in
    # process_command_line.
    asol.write(opt['outfile'], clobber=opt['clobber'])

    # Add history
    add_tool_history(opt['outfile'], TOOLNAME, params=opt,
                     toolversion=VERSION)


@handle_ciao_errors(TOOLNAME, VERSION)
def doit():
    """Run the tool using the command-line arguments.

    The parameters are read from sys.argv, the start-up information
    is displayed, and the correction is made. Any error raised along
    the way is left for the handle_ciao_errors decorator to report.
    """

    params = process_command_line(sys.argv)
    display_start_info(params)
    main(params)


# Run the tool only when this file is executed as a script, so that
# importing the module does not start any processing.
if __name__ == "__main__":
    doit()
