#!/usr/bin/env python
"""
Apply a correction to fid light centroid values to compensate for periscope
offsets determined via periscope gradient telemetry.

Inputs
---------
::

 --acenfile=ACENFILE   Aspect L1 ACEN file
 --fidpropsfile=FIDPROPSFILE
                       Aspect L1 FIDPROPS file
 --nsmooth=NSMOOTH     Gradient telem smoothing window size (n telem samples);
                       default is 304 (about 10000 s of telemetry at 32.8 s per sample).

Outputs
---------
::

 --acenfile=ACENFILE   Aspect L1 ACEN file (ang_y_sm and ang_z_sm columns are updated
                       in place)
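 --plotfile=PLOTFILE   Diagnostic plot of the fid centroid corrections
                       (default corr_plot.png; pass an empty value to disable)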

Description
-----------
::

Prototype Python code to perform the periscope fid correction.  See the ECR for details:
ECR/ECR_pipe_perifidcorr.rst

"""
# Docstring markup format for epydoc/Sphinx-style tools: the docstrings in
# this module are written in reStructuredText.
__docformat__ = "restructuredtext"

import pyfits
import numpy
import numpy as np
import sys
from glob import glob
from Ska.engarchive import fetch


# Periscope-gradient correction coefficients, keyed by gradient MSID.
# apply_cent_corr subtracts (smoothed gradient - mean gradient) * coefficient
# from the ACEN 'ang_y_sm' ('yag') and 'ang_z_sm' ('zag') columns.  Those
# columns are plotted in arcsec after scaling by 3600, so these are presumably
# degrees per unit of gradient telemetry -- TODO confirm against the ECR.
GRADIENTS = {
    'OOBAGRD3': {'yag': 6.98145650e-04, 'zag': 9.51578351e-05},
    'OOBAGRD6': {'yag': -1.67009240e-03, 'zag': -2.79084775e-03},
}

def main():
    """Correct fid light centroids for periscope-gradient offsets.

    Reads the FIDPROPS and ACEN files matched by the command-line glob
    patterns (first match of each).  It then fetches the GRADIENTS MSIDs
    over the ACEN time range and subtracts the gradient correction from the
    ACEN ``ang_y_sm`` / ``ang_z_sm`` columns in place.  If requested, it
    writes a diagnostic plot.  Finally it flushes the modified ACEN file.

    Exits with status 0, without opening the ACEN file for update, when no
    fid has ``id_status == 'GOOD'``.
    """
    opt, _args = get_options()

    # Read the fid properties first so we can bail out before opening the
    # ACEN file in update mode when there is nothing to correct.
    fidprops = pyfits.open(glob(opt.fidpropsfile)[0])[1].data
    # The comparison yields one boolean per fid, so test its values rather
    # than its length to find out whether any fid is GOOD.
    if not np.any(fidprops.field('id_status') == 'GOOD'):
        print('No good fids - skipping corr_fid_cent')
        sys.exit(0)

    acen_hdus = pyfits.open(glob(opt.acenfile)[0], mode='update')
    acen = acen_hdus[1].data

    gradients = fetch_tilt_telem(acen, list(GRADIENTS.keys()))
    # Apply centroid correction to acen values in-place
    apply_cent_corr(gradients, acen, fidprops, opt)
    # For development, optionally make a plot showing results of correction
    if opt.plotfile:
        make_plot(opt.plotfile, gradients, acen, fidprops, opt)
    # Write updated values back to original file, OR do whatever
    # makes sense within pipeline processing.
    acen_hdus.flush()
    

def get_options():
    """Get program options.

    :returns: ``(opt, args)`` from ``OptionParser.parse_args()``.  ``opt``
        has ``acenfile`` and ``fidpropsfile`` (glob patterns), ``nsmooth``
        (int, smoothing window in telemetry samples) and ``plotfile``.
        Both values are also stored as module globals ``opt`` and ``args``.

    NOTE: ``plotfile`` defaults to a file name, so the diagnostic plot is
    produced unless an empty value is given (``--plotfile=``).
    """
    from optparse import OptionParser
    global opt, args
    parser = OptionParser()
    parser.add_option("--acenfile",
                      default='pcadf*_acen1.fits*',
                      help="Aspect L1 ACEN file",
                      )
    parser.add_option("--fidpropsfile",
                      default='pcadf*_fidpr1.fits*',
                      help="Aspect L1 FIDPROPS file",
                      )
    # type='int' so a command-line value is usable as a window length
    # (without it optparse hands back a string).
    parser.add_option("--nsmooth",
                      type='int',
                      default=304,
                      help="n samples for telem smoothing.  Default=10000/32.8s",
                      )
    parser.add_option("--plotfile",
                      default='corr_plot.png',
                      help="Output file for diagnostic plot of fid centroid adjustments",
                      )
    # args = file names
    (opt, args) = parser.parse_args()
    return (opt, args)


def fetch_tilt_telem(acen, msids):
    """Fetch gradient telemetry from the engineering archive for an ACEN interval.

    :param acen: ACEN bintable data; only its ``time`` column is read
    :param msids: MSID names to fetch
    :returns: ``fetch.MSIDset`` spanning ``min(time) - 100`` to
        ``max(time) + 100``
    """
    acen_times = acen.field('time')
    # Widen the fetch window at both ends by an arbitrary 100 s.
    pad = 100
    t_start = min(acen_times) - pad
    t_stop = max(acen_times) + pad
    return fetch.MSIDset(msids, start=t_start, stop=t_stop)
    

def apply_cent_corr(gradients, acen, fidprops, opt=None):
    """
    Subtract periscope-gradient offsets from the fid centroid columns.

    The 'ang_y_sm' and 'ang_z_sm' columns of the 'acen' bintable HDU are
    modified in place.  For every fid in ``fidprops`` (no selection on status
    or alg) and every MSID in ``gradients``, the ACEN rows with the fid's
    ``slot`` are corrected by::

        ang_*_sm -= (smoothed[idx] - mean(raw[idx])) * GRADIENTS[msid][axis]

    ``idx`` is the telemetry sample at or after each ACEN time (clipped to
    the last sample).  ``smoothed`` is the telemetry run through ``smooth``
    with a window of ``opt.nsmooth`` samples.  The window is shortened when
    the telemetry has too few samples for it.

    :param gradients: mapping MSID -> telemetry with ``times`` and ``vals``
    :param acen: ACEN bintable data (modified in place)
    :param fidprops: FIDPROPS bintable data; every row is corrected
    :param opt: options providing ``nsmooth`` and ``plotfile``.  When None,
        the --nsmooth default (304) is used and no plot-prep copy is made.
    :raises ValueError: if an MSID has no telemetry samples
    """
    # 304 matches the --nsmooth default in get_options.
    nsmooth = 304 if opt is None else int(opt.nsmooth)
    copy_for_plot = opt is not None and bool(opt.plotfile)

    ang_y_sm = acen.field('ang_y_sm')
    ang_z_sm = acen.field('ang_z_sm')
    acen_times = acen.field('time')
    slots = acen.field('slot')

    # Smooth each MSID's telemetry once: the result does not depend on the fid.
    smoothed = {}
    for msid in gradients.keys():
        vals = gradients[msid].vals
        if len(vals) == 0:
            raise ValueError("No %s telemetry in the ACEN time range" % msid)
        # smooth() needs more samples than the window length, so shorten the
        # window for short intervals instead of failing.
        window = min(nsmooth, len(vals) - 1)
        smoothed[msid] = smooth(vals, window_len=window)

    # In this routine we do NOT select on status or alg -- all fid centroids get corrected.
    for fid in fidprops:
        ok = np.where(slots == fid['slot'])

        # If making a plot (for TEST ONLY) then copy the smoothed values into the unsmoothed cols
        if copy_for_plot:
            acen.field('ang_y')[ok] = ang_y_sm[ok]
            acen.field('ang_z')[ok] = ang_z_sm[ok]

        for msid in gradients.keys():
            times = gradients[msid].times
            # Temps vary slowly, so take the sample at or after each ACEN time
            # rather than interpolating.  Clip so ACEN times past the last
            # telemetry sample use that sample instead of indexing off the end.
            acen_idx = np.searchsorted(times, acen_times[ok])
            acen_idx = np.clip(acen_idx, 0, len(times) - 1)
            # The calibration is relative to the mean gradient over this fid's samples.
            mean_gradient = np.mean(gradients[msid].vals[acen_idx])
            delta = smoothed[msid][acen_idx] - mean_gradient
            # Make the actual centroid corrections for the y and z axes
            ang_y_sm[ok] -= delta * GRADIENTS[msid]['yag']
            ang_z_sm[ok] -= delta * GRADIENTS[msid]['zag']

    return
        
def make_plot(plotfile, gradients, acen, fidprops, opt):
    """ Make a plot of the before (using the unsmoothed centroids) and after centroids
    for each fid light.

    The layout is (n_fid + 1) rows x 2 columns.  Each fid gets one row:

    * Left panel: Y angle (arcsec).
    * Right panel: Z angle (arcsec), labelled with the fid ``id_string``.
    * In both panels, ``ang_y``/``ang_z`` are drawn in blue and the corrected
      ``ang_y_sm``/``ang_z_sm`` in red.

    Only ACEN rows with ``alg == 8`` are plotted.  The bottom row shows the
    smoothed, mean-subtracted OOBAGRD3 (left) and OOBAGRD6 (right)
    telemetry.  Times are in ksec relative to the first ACEN row, and each
    column shares one x axis.

    :param plotfile: file name passed to ``savefig``
    :param gradients: mapping MSID -> telemetry with ``times`` and ``vals``
    :param acen: ACEN bintable data
    :param fidprops: FIDPROPS bintable data, one row per fid
    :param opt: options; ``opt.nsmooth`` is the smoothing window in samples
    """
    from matplotlib import pylab as pl

    t0 = acen[0].field('time')
    n_fid = fidprops.shape[0]
    n_rows = n_fid + 1
    pl.figure(1, figsize=(7, 8))
    pl.clf()

    # (unsmoothed column, corrected column, title) for the left/right columns.
    columns = (('ang_y', 'ang_y_sm', 'Y angle (arcsec)'),
               ('ang_z', 'ang_z_sm', 'Z angle (arcsec)'))
    # First axes created in each column; later axes share its x axis.
    ref_axes = [None, None]

    for row, fid in enumerate(fidprops):
        # Select this fid's centroids once and use them for both columns.
        ok = np.logical_and(acen.field('alg') == 8,
                            acen.field('slot') == fid.field('slot'))
        cen = acen[ok]
        t = (cen.field('time') - t0) / 1000.
        for col, (raw_name, corr_name, title) in enumerate(columns):
            ax = pl.subplot(n_rows, 2, 2 * row + col + 1, sharex=ref_axes[col])
            if ref_axes[col] is None:
                ref_axes[col] = ax
            pl.setp(ax.get_xticklabels(), visible=False)
            pl.plot(t, cen.field(raw_name) * 3600, 'b')
            pl.plot(t, cen.field(corr_name) * 3600, 'r', alpha=.8)
            if row == 0:
                pl.title(title)
            if col == 1:
                # Set the label after this fid's axes exist so it lands on them.
                pl.ylabel(fid.field('id_string'))
            pl.grid()

    nsmooth = int(opt.nsmooth)
    for col, msid in enumerate(('OOBAGRD3', 'OOBAGRD6')):
        vals = gradients[msid].vals
        # smooth() needs more samples than the window; shorten it for short intervals.
        window = min(nsmooth, len(vals) - 1)
        pl.subplot(n_rows, 2, 2 * n_fid + col + 1, sharex=ref_axes[col])
        pl.plot((gradients[msid].times - t0) / 1000,
                smooth(vals, window_len=window) - np.mean(vals))
        pl.xlabel('Time (ksec)')
        pl.title(msid)
        pl.grid()

    pl.savefig(plotfile)
    pl.show()


def smooth(x, window_len=10, window='hanning'):
    """
    Smooth the data using a window with requested size.

    This method is based on the convolution of a scaled window with the signal.
    The signal is prepared by introducing reflected copies of the signal
    (with the window size) in both ends so that transient parts are minimized
    in the beginning and end part of the output signal.

    Example::

      t = linspace(-2, 2, 50)
      y = sin(t) + randn(len(t)) * 0.1
      ys = smooth(y)
      plot(t, y, t, ys)

    See also::

      numpy.hanning, numpy.hamming, numpy.bartlett, numpy.blackman, numpy.convolve
      scipy.signal.lfilter

    :param x: input signal (1-d array or sequence)
    :param window_len: dimension of the smoothing window.  Values < 3 return
        ``x`` unchanged.  Otherwise ``x`` must have more than ``window_len``
        samples.
    :param window: type of window ('flat', 'hanning', 'hamming', 'bartlett', 'blackman')

    :rtype: smoothed signal, same length as ``x``
    :raises ValueError: if ``x`` is not 1-d, is too short for the window, or
        ``window`` is not a supported name
    """
    x = numpy.asanyarray(x)
    if x.ndim != 1:
        raise ValueError("smooth only accepts 1 dimension arrays.")

    if x.size < window_len:
        raise ValueError("Input vector needs to be bigger than window size.")

    if window_len < 3:
        return x

    # The leading reflection reads x[window_len], so x must be strictly longer
    # than the window.  Otherwise the result is one sample short and shifted.
    if x.size == window_len:
        raise ValueError("Input vector needs to be bigger than window size.")

    if window not in ['flat', 'hanning', 'hamming', 'bartlett', 'blackman']:
        raise ValueError("Window is one of 'flat', 'hanning', 'hamming', 'bartlett', 'blackman'")

    # Pad each end with a point reflection of the signal about its end value.
    s = numpy.r_[2 * x[0] - x[window_len:1:-1], x, 2 * x[-1] - x[-1:-window_len:-1]]

    if window == 'flat':  # moving average
        w = numpy.ones(window_len, 'd')
    else:
        # The name was validated above, so this is one of numpy's window functions.
        w = getattr(numpy, window)(window_len)

    y = numpy.convolve(w / w.sum(), s, mode='same')
    # Strip the window_len - 1 padded samples from each end so the output
    # lines up sample-for-sample with x.
    return y[window_len - 1:-window_len + 1]


if __name__ == '__main__':
    # Script entry point: importing this module does not run the correction.
    main()