#!/usr/bin/env python
#
# Copyright (C) 2023 Smithsonian Astrophysical Observatory
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
#


'Estimate dither parameters from the aspect solution file'


import sys
import os
from tempfile import NamedTemporaryFile
import numpy as np
import ciao_contrib.logger_wrapper as lw
from pycrates import read_file

# Tool identification; both values are handed to the logger below and to the
# lw.handle_ciao_errors decorator on doit().
__toolname__ = "get_dither_parameters"
__revision__ = "24 February 2023"

# The logger is initialised once at import time, then bound handles for the
# three verbosity levels used in this file are kept as module-level callables.
# NOTE(review): presumably verbN only prints when the tool's verbose
# parameter is >= N -- confirm against ciao_contrib.logger_wrapper.
lw.initialize_logger(__toolname__)
verb0 = lw.get_logger(__toolname__).verbose0
verb1 = lw.get_logger(__toolname__).verbose1
verb2 = lw.get_logger(__toolname__).verbose2


def get_amplitude(vals, lo_thresh=0.01, hi_thresh=0.99):
    """Compute amplitude of dither

    To avoid the dither "tails" we sort the data and omit the values
    below the ``lo_thresh`` and above the ``hi_thresh`` fractions
    (bottom and top 1% by default).  The amplitude is half of the
    remaining peak-to-peak range.

    TODO: Maybe make lo/hi tool parameters?

    Parameters
    ----------
    vals : array_like
        The 1D sequence of values to measure.
    lo_thresh, hi_thresh : float
        Fractions (0 to 1) of the sorted data used to pick the low
        and high values.

    Returns
    -------
    amplitude : float
        Half the difference between the selected high and low values,
        in the same units as ``vals``.

    Raises
    ------
    ValueError
        If ``vals`` is empty.
    """
    sort_vals = np.sort(vals)
    nvals = len(sort_vals)
    if nvals == 0:
        raise ValueError("Cannot compute dither amplitude: " +
                         "no values were supplied")

    # int(frac*nvals) equals nvals when frac is 1.0, which is one past
    # the end of the array, so clamp both indices to the valid range.
    last = nvals - 1
    idx_lo = min(max(int(lo_thresh*nvals), 0), last)
    idx_hi = min(max(int(hi_thresh*nvals), 0), last)

    peak_to_peak = sort_vals[idx_hi] - sort_vals[idx_lo]
    amplitude = peak_to_peak/2.0
    return amplitude


def fold_vals(tmpfile, colname):
    """Use period folding technique

    For every integer trial period from 200 up to (but not including)
    min(TSTOP-TSTART, 3500), the time column is folded on that period,
    the ``colname`` values are averaged in 10 phase bins, and the
    standard deviation of those bin averages is recorded.  The trial
    period with the largest standard deviation is selected.

    Parameters
    ----------
    tmpfile : str
        Name of the table to read; it must contain a "time" column,
        the ``colname`` column, and TSTART/TSTOP keywords.
    colname : str
        Name of the column to fold.

    Returns
    -------
    amplitude, period
        Half of the (max - min) of the phase-bin averages at the
        selected period, and that period.  Both are NaN when the file
        is too short to build a period grid or when no trial period
        has at least two populated phase bins.
    """

    _nphase = 10  # Number of phase bins (kinda arbitrary)
    _min_period = 200    # shortest trial period
    _max_period = 3500   # trial periods stop before this value

    tab = read_file(tmpfile)
    time_vals = tab.get_column("time").values
    detxy_vals = tab.get_column(colname).values

    delta_t = tab.get_key_value("TSTOP") - tab.get_key_value("TSTART")
    delta_t = int(delta_t)

    # Use <= rather than <: when delta_t equals the minimum period the
    # grid of trial periods would be empty and argmax would fail.
    if delta_t <= _min_period:
        verb0(f"Input file only covers {delta_t} seconds and is " +
              "too short for period folding")
        return np.nan, np.nan

    _periods = np.arange(_min_period, min(delta_t, _max_period))

    phase_lo = np.arange(_nphase)/float(_nphase)
    phase_hi = (1.0+np.arange(_nphase))/_nphase

    # NaN marks trial periods that could not be evaluated.
    sig_vals = np.full(len(_periods), np.nan)
    range_vals = np.full(len(_periods), np.nan)
    phase_array = np.zeros(_nphase)

    for idx, period in enumerate(_periods):
        # Fractional part of time/period is the phase
        phase_vals, _ = np.modf(time_vals/period)

        for phase_bin in range(_nphase):
            in_bin = np.logical_and(phase_vals >= phase_lo[phase_bin],
                                    phase_vals < phase_hi[phase_bin])
            if in_bin.any():
                phase_array[phase_bin] = np.mean(detxy_vals[in_bin])
            else:
                # An empty bin has no average; flag it so that it does
                # not turn the statistics for this period into NaN.
                phase_array[phase_bin] = np.nan

        filled = phase_array[np.isfinite(phase_array)]
        if len(filled) < 2:
            # Need at least two bins to measure any variation
            continue

        sig_vals[idx] = np.std(filled)
        range_vals[idx] = np.max(filled) - np.min(filled)

    if np.isnan(sig_vals).all():
        verb0("Unable to fold the data at any of the trial periods")
        return np.nan, np.nan

    # nanargmax skips the periods that could not be evaluated
    max_idx = np.nanargmax(sig_vals)

    return range_vals[max_idx]/2.0, _periods[max_idx]


def fit_vals(tmpfile, colname):
    """Fit the dety|detz values with a cosine+constant

    The "time" column is shifted to start at 0 and divided by 1000
    before fitting, so the fitted period is multiplied by 1000 again
    on return.  The data are loaded into Sherpa dataset 1 and fit with
    the leastsq statistic.  If the fitted amplitude is less than half
    of the half-range of the data the fit is considered bad and is
    retried, first from a different starting period and then with the
    moncar method.

    Parameters
    ----------
    tmpfile : str
        Name of the table to read; it must contain a "time" column
        and the ``colname`` column.
    colname : str
        Name of the column to fit.

    Returns
    -------
    amplitude, period
        The fitted cosine amplitude (same units as the column) and
        the fitted period scaled back to the units of the time column.
    """

    from sherpa.astro.ui import load_arrays, set_staterror, set_stat
    from sherpa.astro.ui import set_source, set_method, fit, freeze
    from sherpa.utils.logging import SherpaVerbosity
    from sherpa import models

    verb1(f"Fitting for {colname}")

    tab = read_file(tmpfile)
    time_vals = tab.get_column("time").values*1.0  # make a copy
    time_vals -= time_vals[0]
    time_vals /= 1000.0   # ksec, cosine hard max of '10'

    detxy = tab.get_column(colname).values

    load_arrays(1, time_vals, detxy)

    errs = np.zeros_like(detxy)+0.001
    set_staterror(1, errs)
    set_stat("leastsq")

    wave = models.Cos('wave')
    offset = models.Const1D('offset')

    # Technically constant term is not needed because we freeze the
    # value at 0, but left here in case we decide we want to
    # thaw it at any point.
    set_source(wave + offset)

    # Sensible defaults for many observations
    offset.c0 = 0
    wave.ampl = 0.002
    wave.period = 1.0

    wave.ampl.min = 0.00001
    wave.ampl.max = 0.01
    wave.period.min = 0.1
    wave.period.max = 3
    wave.offset.min = 0
    wave.offset.max = 3

    freeze(offset.c0)

    # In testing we found when fit goes "bad", the amplitude
    # becomes unrealistically small.  The data do not change between
    # attempts, so work out the smallest acceptable amplitude once,
    # using the vectorized numpy reductions.
    min_ampl = 0.5*(np.max(detxy)-np.min(detxy))/2.0

    # (message, fit method to switch to or None to keep the current
    # one) for each retry, in the order they are attempted.
    retries = (("Trying with different default period", None),
               ("Trying with moncar", "moncar"))

    set_method("neldermead")
    with SherpaVerbosity("WARN"):
        fit()

        for message, method in retries:
            if not wave.ampl.val < min_ampl:
                # Amplitude looks reasonable; accept this fit
                break

            verb1(message)
            wave.period = 0.707
            wave.ampl = 0.002
            if method is not None:
                set_method(method)
            fit()

    return wave.ampl.val, wave.period.val*1000.0


def make_dy_dz(infile, tmpdir):
    """Convert ra/dec into a detector-ish coordinates. Subtract off
    the mean value and rotate by ROLL.

    The time, ra, dec and roll columns are read from ``infile``, the
    pseudo detector coordinates are added as "det_y" and "det_z"
    columns, and the table is written to a new temporary file.

    Parameters
    ----------
    infile : str
        Name of the input table; it must contain time, ra, dec and
        roll columns (angles in degrees) and TSTART/TSTOP keywords.
    tmpdir : str
        Directory in which the temporary file is created.

    Returns
    -------
    tmpfile, det_y_amp, det_z_amp
        The name of the temporary file (the caller is responsible for
        deleting it) and the amplitudes of det_y and det_z as given by
        get_amplitude, in the same units as ra/dec.
    """

    def _rotate(ra_vals, dec_vals, roll_vals):
        'Rotate ra/dec to dety, detz'

        # Subtract off the mean (center around 0)
        ra_off = ra_vals - np.mean(ra_vals)
        dec_off = dec_vals - np.mean(dec_vals)

        # Scale ra by cos(dec)
        cos_dec = np.cos(np.deg2rad(dec_vals))
        ra_off *= cos_dec

        # Rotate ra/dec into dety,detz coordinates
        cos_roll = np.cos(np.deg2rad(90+roll_vals))
        sin_roll = np.sin(np.deg2rad(90+roll_vals))
        det_z = ra_off * cos_roll + dec_off * sin_roll
        det_y = -ra_off * sin_roll + dec_off * cos_roll

        return det_y, det_z

    tab = read_file(infile+"[cols time,ra,dec,roll]")

    ra_vals = tab.get_column("ra").values
    dec_vals = tab.get_column("dec").values
    roll_vals = tab.get_column("roll").values

    ontime = tab.get_key_value("TSTOP") - tab.get_key_value("TSTART")
    if ontime < 2000.0:
        # If less than 2ksec, then FFTs become less useful
        verb0(f"WARNING: ONTIME={ontime/1000.0:.4g} ksec is less than" +
              " 2.0 ksec which may lead to inaccurate dither " +
              "parameter estimates.")

    det_y, det_z = _rotate(ra_vals, dec_vals, roll_vals)

    # We go ahead and get amplitude since we have the dety and detz values
    det_y_amp = get_amplitude(det_y)
    det_z_amp = get_amplitude(det_z)

    if det_y_amp > 1.0 or det_z_amp > 1.0:
        # If amplitude is too big, then we've wrapped in ra so fix it.
        # Build a new array rather than editing ra_vals in place so the
        # array obtained from the table is left untouched.
        ra_unwrapped = np.where(ra_vals < 180.0, ra_vals + 360.0, ra_vals)
        det_y, det_z = _rotate(ra_unwrapped, dec_vals, roll_vals)

        det_y_amp = get_amplitude(det_y)
        det_z_amp = get_amplitude(det_z)

    # Write values
    from crates_contrib.utils import add_colvals
    add_colvals(tab, "det_y", det_y, desc="pseudo dety coordinate")
    add_colvals(tab, "det_z", det_z, desc="pseudo detz coordinate")

    # Only the unique name is wanted; the file itself is overwritten
    # by the write call below.
    tmpfile = NamedTemporaryFile(dir=tmpdir, suffix="_dydz.fits",
                                 delete=False)
    tmpfile.close()
    tmpfile = tmpfile.name

    try:
        tab.write(tmpfile, clobber=True)
    except Exception:
        # delete=False means nothing else will remove the file, and the
        # caller never learns its name when the write fails.
        if os.path.exists(tmpfile):
            os.unlink(tmpfile)
        raise

    return tmpfile, det_y_amp, det_z_amp


def fft_vals(infile, column, tmpdir):
    """Determine the dither period

    We find the frequency by taking the powerspectrum of the
    DET coordinates.  In this rotated reference frame, there should only
    be 1 strong frequency (as opposed to doing in RA|Dec where we would see
    both frequencies (unless ROLL is ~n*90 degrees).

    NOTE(review): no significance test is applied to the peak.  The
    period of the first frequency bin is left at 0, so 0 is returned
    when that bin has the most power.

    Parameters
    ----------
    infile : str
        Name of the table holding a "time" column and ``column``.
    column : str
        Name of the column to take the power spectrum of.
    tmpdir : str
        Directory in which the temporary power spectrum file is
        created.

    Returns
    -------
    period
        1/frequency of the bin with the largest power (0 for the
        first bin).
    """

    from ciao_contrib.runtool import make_tool

    _tmpfile = NamedTemporaryFile(dir=tmpdir, suffix="_power.fits",
                                  delete=False)
    _tmpfile.close()
    _tmpfile = _tmpfile.name

    try:
        apowerspectrum = make_tool("apowerspectrum")
        apowerspectrum(f"{infile}[cols time,{column}]", "", _tmpfile,
                       clobber=True, crop=True)

        tab = read_file(_tmpfile)
        freq = tab.get_column("frequency").values
        power = tab.get_column("data").values
        del tab
    finally:
        # The file was created with delete=False, so remove it here
        # even if the tool or the read failed.
        if os.path.exists(_tmpfile):
            os.unlink(_tmpfile)

    # Skip the first bin when inverting; its period stays at 0.
    period = np.zeros_like(freq)
    period[1:] = 1.0/freq[1:]

    max_idx = np.argmax(power)
    return period[max_idx]


@lw.handle_ciao_errors(__toolname__, __revision__)
def doit():
    """Main routine

    Reads the tool parameters, builds the pseudo det_y/det_z table,
    estimates the dither amplitude and period for each axis with the
    requested method ("fft", "fit" or "fold"), writes the four results
    back to the parameter file and reports them at verbose level 1.

    Raises
    ------
    ValueError
        If the method parameter is not one of the supported values.
    """

    from ciao_contrib.param_soaker import get_params
    pars = get_params(__toolname__, "rw", sys.argv,
                      verbose={"set": lw.set_verbosity, "cmd": verb1})

    tmpfile, det_y_amp, det_z_amp = make_dy_dz(pars["infile"],
                                               pars["tmpdir"])

    try:
        if pars["method"] == "fft":
            # The fft only gives the period; keep the amplitudes that
            # make_dy_dz measured.
            dety_period = fft_vals(tmpfile, "det_y", pars["tmpdir"])
            detz_period = fft_vals(tmpfile, "det_z", pars["tmpdir"])
        elif pars["method"] == "fit":
            det_y_amp, dety_period = fit_vals(tmpfile, "det_y")
            det_z_amp, detz_period = fit_vals(tmpfile, "det_z")
        elif pars["method"] == "fold":
            det_y_amp, dety_period = fold_vals(tmpfile, "det_y")
            det_z_amp, detz_period = fold_vals(tmpfile, "det_z")
        else:
            raise ValueError(f"Unknown value for method={pars['method']}")
    finally:
        # Always remove the intermediate file, including when one of
        # the estimators fails.
        if os.path.exists(tmpfile):
            os.unlink(tmpfile)

    det_y_amp *= 3600.0   # convert to arcsec
    det_z_amp *= 3600.0

    from ciao_contrib.param_wrapper import open_param_file as opf
    import paramio as pio

    pfile = opf(sys.argv, toolname=__toolname__)["fp"]
    try:
        pio.pputd(pfile, "dety_amplitude", det_y_amp)
        pio.pputd(pfile, "detz_amplitude", det_z_amp)
        pio.pputd(pfile, "dety_period", dety_period)
        pio.pputd(pfile, "detz_period", detz_period)
    finally:
        # Close the parameter file even if one of the writes fails
        pio.paramclose(pfile)

    verb1("Results:")
    verb1("\tAmplitude\tPeriod")
    verb1("\t[arcsec]\t[sec]")
    verb1(f"DETY\t{det_y_amp:.3f}\t\t{dety_period:.3f}")
    verb1(f"DETZ\t{det_z_amp:.3f}\t\t{detz_period:.3f}")


# Run the tool only when executed as a script, not when imported.
if __name__ == '__main__':
    doit()
