# Licensed under a 3-clause BSD style license - see LICENSE.rst
from functools import cached_property
import numpy as np
import six
try:
from Ska.Matplotlib import cxctime2plotdate, plot_cxctime
except ImportError:
pass
from xija import tmal
[docs]
class Param(dict):
    """Model component parameter.

    Inherits from dict but adds attribute access for convenience, so that
    ``par.val`` and ``par["val"]`` refer to the same stored item.  Every
    attribute assignment is stored as a dict item (nothing is written to the
    instance ``__dict__``).

    Parameters
    ----------
    comp_name : str
        Name of the owning component; used as the prefix of ``full_name``.
    name : str
        Parameter name.
    val :
        Parameter value.
    min :
        Stored under the ``min`` key (default -1e38).
    max :
        Stored under the ``max`` key (default 1e38).
    fmt : str
        Stored under the ``fmt`` key (default ``"{:.4g}"``).
    frozen : bool
        Stored under the ``frozen`` key (default False).

    Notes
    -----
    ``full_name`` is set to ``comp_name + "__" + name``.
    """

    def __init__(
        self, comp_name, name, val, min=-1e38, max=1e38, fmt="{:.4g}", frozen=False
    ):
        dict.__init__(self)
        self.comp_name = comp_name
        self.name = name
        self.val = val
        self.min = min
        self.max = max
        self.fmt = fmt
        self.frozen = frozen
        self.full_name = comp_name + "__" + name

    def __setattr__(self, attr, val):
        # Attribute writes are redirected into the dict storage.
        dict.__setitem__(self, attr, val)

    def __getattr__(self, attr):
        # Only called when normal attribute lookup fails.  The attribute
        # protocol requires AttributeError (not KeyError) for a missing name,
        # otherwise hasattr(), getattr(obj, name, default) and the copy/pickle
        # machinery propagate an unexpected KeyError.
        try:
            return dict.__getitem__(self, attr)
        except KeyError:
            raise AttributeError(
                "{!r} object has no attribute {!r}".format(type(self).__name__, attr)
            ) from None
[docs]
class ModelComponent:
    """Model component base class.

    A component keeps a reference to its parent ``model`` plus a list
    (``pars``) and a name-keyed dict (``pars_dict``) of parameter objects.
    Parameter values are exposed as plain attributes of the component: reading
    ``comp.<par_name>`` returns the parameter ``val`` and assigning to it
    updates that ``val``.

    Parameters
    ----------
    model :
        Parent model object.
    """

    def __init__(self, model):
        self.model = model
        self.n_mvals = 0
        self.predict = False  # Predict values for this model component
        self.data = None
        self.data_times = None

    @cached_property
    def pars(self):
        """List of parameter objects (created empty on first access)."""
        return []

    @cached_property
    def pars_dict(self):
        """Dict of parameter objects keyed by parameter name."""
        return {}

    @property
    def n_parvals(self):
        """Number of parameter values."""
        return len(self.parvals)

    @property
    def times(self):
        """Times of the parent model (``model.times``)."""
        return self.model.times

    @property
    def model_plotdate(self):
        """``model.times`` converted by ``cxctime2plotdate`` (cached)."""
        if not hasattr(self, "_model_plotdate"):
            self._model_plotdate = cxctime2plotdate(self.model.times)
        return self._model_plotdate

    def add_par(self, name, val=None, min=-1e38, max=1e38, fmt="{:.4g}", frozen=False):
        """Create a `Param` for this component and register it.

        The new parameter is appended to ``pars`` and stored in ``pars_dict``
        under ``name``.
        """
        param = Param(self.name, name, val, min=min, max=max, fmt=fmt, frozen=frozen)
        self.pars_dict[name] = param
        self.pars.append(param)

    def _getAttributeNames(self):
        """Add dynamic attribute names for IPython completer."""
        return [par.name for par in self.pars]

    def __getattr__(self, attr):
        # The following is needed for the IPython completer
        if attr == "trait_names":
            return []

        if attr in self.pars_dict:
            return self.pars_dict[attr].val

        # This will raise the expected AttributeError exception
        return super().__getattribute__(attr)

    def __setattr__(self, attr, val):
        # Assigning to a parameter name updates the parameter value instead of
        # creating an instance attribute.
        if attr in self.pars_dict:
            self.pars_dict[attr].val = val
        else:
            super().__setattr__(attr, val)

    def _set_mvals(self, vals):
        self.model.mvals[self.mvals_i, :] = vals

    def _get_mvals(self):
        return self.model.mvals[self.mvals_i, :]

    # Row ``mvals_i`` of the parent model ``mvals`` array.
    mvals = property(_get_mvals, _set_mvals)

    def get_par(self, name):
        """Return the parameter object called ``name``.

        Raises
        ------
        ValueError
            If this component has no parameter with that name.
        """
        for par in self.pars:
            if par.name == name:
                return par
        # The message template must actually be formatted; passing the
        # template and class name as separate exception args produced an
        # unreadable message that never mentioned the requested name.
        raise ValueError(
            'No par named "{}" in {}'.format(name, self.__class__.__name__)
        )

    @property
    def name(self):
        """Component name, taken from ``__str__``."""
        return self.__str__()

    @property
    def parvals(self):
        """Array of the current parameter values, in ``pars`` order."""
        return np.array([par.val for par in self.pars])

    @property
    def parnames(self):
        """List of parameter names, in ``pars`` order."""
        return [par.name for par in self.pars]

    def update(self):
        """Hook for subclasses; the base implementation does nothing."""
        pass

    def set_data(self, data, times=None):
        """Set component ``data`` and, if given, the matching ``data_times``."""
        self.data = data
        if times is not None:
            self.data_times = times

    def get_dvals_tlm(self):
        """Default data values: zeros shaped like ``model.times``."""
        return np.zeros_like(self.model.times)

    @property
    def dvals(self):
        """Data values for this component (computed once, then cached).

        - ``data`` is None: use ``get_dvals_tlm()``.
        - ``data`` is an ndarray: use ``model.interpolate_data`` with
          ``data_times``.
        - ``data`` is a scalar (int, float, bool, str or numpy scalar): an
          array of length ``model.n_times`` filled with that value.

        Raises
        ------
        ValueError
            If ``data`` has any other type.
        """
        if not hasattr(self, "_dvals"):
            if self.data is None:
                dvals = self.get_dvals_tlm()
            elif isinstance(self.data, np.ndarray):
                dvals = self.model.interpolate_data(
                    self.data, self.data_times, str(self)
                )
            elif isinstance(
                self.data,
                (int, float, np.integer, np.floating, bool, str),
            ):
                if isinstance(self.data, str):
                    # Fixed-width bytes dtype sized to the string
                    dtype = "S{}".format(len(self.data))
                else:
                    dtype = type(self.data)
                dvals = np.empty(self.model.n_times, dtype=dtype)
                dvals[:] = self.data
            else:
                raise ValueError(
                    "Data value '{}' and type '{}' for '{}' component "
                    "not allowed ".format(self.data, type(self.data).__name__, self)
                )
            self._dvals = dvals
        return self._dvals
[docs]
class TelemData(ModelComponent):
    """Model component whose data values are fetched for a single ``msid``.

    Parameters
    ----------
    model :
        Parent model.
    msid :
        MSID name; also used as the component name.
    mval :
        If true the component gets one model-value row (``n_mvals = 1``),
        otherwise none.
    data :
        Optional explicit data, stored as ``self.data``.
    fetch_attr :
        Passed as ``attr`` to ``model.fetch`` (default ``"vals"``).
    units :
        Optional units string appended to the plot y-label.
    """

    times = property(lambda self: self.model.times)

    def __init__(
        self, model, msid, mval=True, data=None, fetch_attr="vals", units=None
    ):
        super().__init__(model)
        self.msid = msid
        self.n_mvals = int(bool(mval))
        self.predict = False
        self.data = data
        self.data_times = None
        self.fetch_attr = fetch_attr
        self.units = units

    def get_dvals_tlm(self):
        """Fetch the data values for ``msid`` through the parent model."""
        return self.model.fetch(self.msid, attr=self.fetch_attr)

    def plot_data__time(self, fig, ax):
        """Plot data vs. time, or refresh the existing line on ``ax``."""
        existing = ax.get_lines()
        if existing:
            # Axes already populated: just update the first line in place.
            existing[0].set_data(self.model_plotdate, self.dvals)
            return

        plot_cxctime(
            self.model.times, self.dvals, ls="-", color="#386cb0", fig=fig, ax=ax
        )
        ax.grid()
        ax.set_title("{}: data".format(self.name))
        units_suffix = "" if self.units is None else " (%s)" % self.units
        ax.set_ylabel(("%s" % self.name) + units_suffix)
        ax.margins(0.05)

    def __str__(self):
        return self.msid
[docs]
class CmdStatesData(TelemData):
    """Component whose data values are read from ``model.cmd_states``."""

    def get_dvals_tlm(self):
        """Return the ``msid`` entry of the model ``cmd_states``."""
        states = self.model.cmd_states
        return states[self.msid]
[docs]
class Node(TelemData):
    """Time-series dataset for prediction.

    If the ``sigma`` value is negative then sigma is computed from the
    node data values as the specified percent of the data standard
    deviation.  The default ``sigma`` value is -10, so this implies
    using a sigma of 10% of the data standard deviation.  If ``sigma``
    is set to 0 then the fit statistic is set to 0.0 for this node.

    Parameters
    ----------
    model :
        parent model
    msid :
        MSID for telemetry data
    sigma :
        sigma value used in chi^2 fit statistic
    quant :
        use quantized stats (not currently implemented)
    predict :
        compute prediction for this node (default=True)
    mask :
        Mask component for masking values from fit statistic
    name :
        component name (default=``msid``)
    data :
        Node data (None or a single value)
    fetch_attr :
        attribute passed to ``model.fetch`` (default ``"vals"``)
    units :
        units string used in plot labels (default ``"degC"``)
    """

    def __init__(
        self,
        model,
        msid,
        sigma=-10,
        quant=None,
        predict=True,
        mask=None,
        name=None,
        data=None,
        fetch_attr="vals",
        units="degC",
    ):
        TelemData.__init__(
            self, model, msid, data=data, fetch_attr=fetch_attr, units=units
        )
        self._sigma = sigma
        self.quant = quant
        self.predict = predict
        self.mask = model.get_comp(mask)
        self._name = name or msid

    def __str__(self):
        return self._name

    @property
    def randx(self):
        """Random X-offset for plotting which is a uniform distribution
        with width = self.quant or 1.0.

        The offsets are drawn once (one per model time) and then cached.
        """
        if not hasattr(self, "_randx"):
            dx = self.quant or 1.0
            self._randx = np.random.uniform(
                low=-dx / 2.0, high=dx / 2.0, size=self.model.n_times
            )
        return self._randx

    @property
    def sigma(self):
        """Sigma used in the fit statistic.

        A negative stored value is interpreted as a percentage of the data
        standard deviation; it is resolved on first access and the resolved
        value replaces the stored one.
        """
        if self._sigma < 0:
            self._sigma = self.dvals.std() * (-self._sigma / 100.0)
        return self._sigma

    @property
    def resids(self):
        """Residuals (data - model) as a new array.

        Residuals inside ``model.mask_times_indices`` are set to zero.
        """
        resid = self.dvals - self.mvals
        # Zero out residuals for any masked times
        for i0, i1 in self.model.mask_times_indices:
            resid[i0:i1] = 0.0
        return resid

    def calc_stat(self):
        """Return the fit statistic ``sum(resids**2 / sigma**2)``.

        Returns 0.0 when ``sigma`` is 0.  If a mask component is set, only
        residuals selected by ``mask.mask`` contribute.
        """
        if self.sigma == 0:
            return 0.0
        resids = self.resids
        if self.mask is not None:
            resids = resids[self.mask.mask]
        return np.sum(resids**2 / self.sigma**2)

    def _masked_resids_for_plot(self):
        """Return residuals with masked samples replaced by NaN for plotting."""
        resids = self.resids
        if self.mask is not None:
            resids[~self.mask.mask] = np.nan
        for i0, i1 in self.model.mask_times_indices:
            resids[i0:i1] = np.nan
        return resids

    def plot_data__time(self, fig, ax):
        """Plot model (red) and data (blue) vs. time, or refresh the model line."""
        lines = ax.get_lines()
        if not lines:
            plot_cxctime(
                self.model.times, self.mvals, ls="-", color="#d92121", fig=fig, ax=ax
            )
            plot_cxctime(
                self.model.times, self.dvals, ls="-", color="#386cb0", fig=fig, ax=ax
            )
            # Overplot bad time regions in cyan
            for i0, i1 in self.model.bad_times_indices:
                plot_cxctime(
                    self.model.times[i0:i1],
                    self.dvals[i0:i1],
                    "-c",
                    fig=fig,
                    ax=ax,
                    linewidth=5,
                    alpha=0.5,
                )
            ax.grid()
            ax.set_title("{}: model (red) and data (blue)".format(self.name))
            ax.set_ylabel("Temperature (%s)" % self.units)
        else:
            # First line is the model curve; only its y-values change.
            lines[0].set_ydata(self.mvals)

    def plot_resid__time(self, fig, ax):
        """Plot residuals (data - model) vs. time, or refresh the existing line."""
        lines = ax.get_lines()
        resids = self._masked_resids_for_plot()

        if not lines:
            plot_cxctime(
                self.model.times, resids, ls="-", color="#386cb0", fig=fig, ax=ax
            )
            # Overplot bad time regions in cyan
            for i0, i1 in self.model.bad_times_indices:
                plot_cxctime(
                    self.model.times[i0:i1],
                    resids[i0:i1],
                    "-c",
                    fig=fig,
                    ax=ax,
                    linewidth=5,
                    alpha=0.5,
                )
            ax.grid()
            ax.set_title("{}: residuals (data - model)".format(self.name))
            ax.set_ylabel("Temperature (%s)" % self.units)
        else:
            lines[0].set_ydata(resids)
            ax.relim()
            ax.autoscale()

    def plot_resid__data(self, fig, ax):
        """Plot residuals (data - model) vs. data, or refresh the existing line.

        The x-values are the data values plus the ``randx`` offsets.
        """
        lines = ax.get_lines()
        resids = self._masked_resids_for_plot()

        if not lines:
            ax.plot(
                self.dvals + self.randx,
                resids,
                "o",
                markersize=0.25,
                color="#386cb0",
                markeredgecolor="#386cb0",
            )
            ax.grid()
            ax.set_title("{}: residuals (data - model) vs data".format(self.name))
            ax.set_ylabel("Residuals (%s)" % self.units)
            # The x axis carries the data values; previously this label was
            # applied to the y axis and overwrote the "Residuals" label.
            ax.set_xlabel("Temperature (%s)" % self.units)
        else:
            lines[0].set_ydata(resids)
            ax.relim()
            ax.autoscale()
[docs]
class Coupling(ModelComponent):
    """First-order coupling between Nodes `node1` and `node2`

    ::

      dy1/dt = -(y1 - y2) / tau

    Parameters
    ----------
    model :
        parent model
    node1, node2 :
        components (resolved through ``model.get_comp``) being coupled
    tau :
        initial value of the ``tau`` parameter (registered with limits 2 to 200)
    """

    def __init__(self, model, node1, node2, tau):
        super().__init__(model)
        self.node1 = self.model.get_comp(node1)
        self.node2 = self.model.get_comp(node2)
        self.add_par("tau", tau, min=2.0, max=200.0)

    def update(self):
        """Refresh the ``tmal_ints`` / ``tmal_floats`` tuples for this coupling."""
        opcode = tmal.OPCODES["coupling"]
        y1_index = self.node1.mvals_i
        y2_index = self.node2.mvals_i
        self.tmal_ints = (opcode, y1_index, y2_index)
        self.tmal_floats = (self.tau,)

    def __str__(self):
        return "coupling__{0}__{1}".format(self.node1, self.node2)
[docs]
class Delay(ModelComponent):
    """Delay mval from ``node`` by ``delay`` ksec

    See the example in examples/delay/.  With a positive delay the computed
    model value (``node.mval``) stays constant at the initial value for the
    first ``delay`` ksec.  With a negative delay it is instead the values at
    the end that stay constant for ``delay`` ksec.
    """

    def __init__(self, model, node, delay=0):
        ModelComponent.__init__(self, model)
        target = self.model.get_comp(node)
        self.node = target
        # Register the single ``delay`` parameter with limits -40 to 40
        self.add_par("delay", delay, min=-40, max=40)

    def __str__(self):
        return f"delay__{self.node}"
[docs]
class HeatSink(ModelComponent):
    """Fixed temperature external heat bath

    Registers two parameters: ``T`` (limits -100 to 100) and ``tau``
    (limits 2 to 200).
    """

    def __init__(self, model, node, T=0.0, tau=20.0):
        super().__init__(model)
        self.node = self.model.get_comp(node)
        self.add_par("T", T, min=-100.0, max=100.0)
        self.add_par("tau", tau, min=2.0, max=200.0)

    def update(self):
        """Refresh the ``tmal_ints`` / ``tmal_floats`` tuples for this heat sink."""
        opcode = tmal.OPCODES["heatsink"]
        node_index = self.node.mvals_i  # dy1/dt index
        self.tmal_ints = (opcode, node_index)
        self.tmal_floats = (self.T, self.tau)

    def __str__(self):
        return "heatsink__{0}".format(self.node)
[docs]
class HeatSinkRef(ModelComponent):
    """Fixed temperature external heat bath, reparameterized so that varying
    tau does not affect the mean model temperature.

    This requires an extra non-fitted parameter T_ref which corresponds to a
    reference temperature for the node.::

      dT/dt = U * (Te - T)
            = P + U* (T_ref - T)   # reparameterization

      P = U * (Te - T_ref)
      Te = P / U + T_ref

    In the code below, "T" corresponds to "Te" above.  The "T" above is
    node.dvals.

    Parameters
    ----------
    model :
        parent model
    node :
        component (resolved through ``model.get_comp``) tied to the heat bath
    T, tau, T_ref :
        used to initialize the ``P`` (= ``(T - T_ref) / tau``), ``tau`` and
        ``T_ref`` parameters
    """

    def __init__(self, model, node, T=0.0, tau=20.0, T_ref=20.0):
        super().__init__(model)
        self.node = self.model.get_comp(node)
        initial_p = (T - T_ref) / tau
        self.add_par("P", initial_p, min=-10.0, max=10.0)
        self.add_par("tau", tau, min=2.0, max=200.0)
        self.add_par("T_ref", T_ref, min=-100, max=100)

    def update(self):
        """Refresh the ``tmal_ints`` / ``tmal_floats`` tuples for this heat sink."""
        opcode = tmal.OPCODES["heatsink"]
        node_index = self.node.mvals_i  # dy1/dt index
        self.tmal_ints = (opcode, node_index)
        # Convert (P, tau, T_ref) back to the external temperature Te
        external_temp = self.P * self.tau + self.T_ref
        self.tmal_floats = (external_temp, self.tau)

    def __str__(self):
        return "heatsink__{0}".format(self.node)
[docs]
class Pitch(TelemData):
    """Telemetry component for the ``pitch`` MSID (units ``deg``)."""

    def __init__(self, model):
        super().__init__(model, "pitch", units="deg")

    def get_dvals_tlm(self):
        """Fetch pitch and force every value into the 45 to 180 degree range.

        The fetched array is modified in place and returned.
        """
        pitch = self.model.fetch(self.msid, attr=self.fetch_attr)

        # Pitch values outside of 45 to 180 are not possible.  Normally this
        # is genuine bad data that gets sent down in safe mode when the
        # spacecraft is at normal sun, so replace such values with 90.
        out_of_range = (pitch <= 45.0) | (pitch >= 180.0)
        pitch[out_of_range] = 90.0

        # Spacecraft must operate between 45 and 180 degrees pitch, so clip
        # values to that range.
        pitch.clip(45.001, 179.999, out=pitch)
        return pitch

    def __str__(self):
        return "pitch"
[docs]
class AcisFPtemp(Node):
    """Wrapper node for MSID FPTEMP_11.

    Needed because that MSID currently comes from the eng_archive in K
    instead of C.

    Parameters
    ----------
    model :
        parent model
    mask :
        optional mask component, forwarded to `Node`
    """

    def __init__(self, model, mask=None):
        super().__init__(model, "fptemp_11", mask=mask)

    def get_dvals_tlm(self):
        """Fetch the focal plane temperature and subtract 273.15."""
        kelvin = self.model.fetch(self.msid, "vals", "nearest")
        return kelvin - 273.15

    def __str__(self):
        return "fptemp"
[docs]
class Eclipse(TelemData):
    """Telemetry component built on the ``aoeclips`` MSID.

    The data values are booleans that are True where the fetched value
    equals ``"ECL "``.
    """

    def __init__(self, model):
        super().__init__(model, "aoeclips")
        self.n_mvals = 1
        self.fetch_attr = "midvals"
        self.fetch_method = "nearest"

    def get_dvals_tlm(self):
        """Return a boolean array marking samples equal to ``"ECL "``."""
        states = self.model.fetch(self.msid, "vals", "nearest")
        in_eclipse = states == "ECL "
        return in_eclipse

    def update(self):
        """Set the model values to 1 where the data values are true, else 0."""
        self.mvals = np.where(self.dvals, 1, 0)

    def __str__(self):
        return "eclipse"
[docs]
class SimZ(TelemData):
    """Telemetry component for the ``sim_z`` MSID."""

    def __init__(self, model):
        super().__init__(model, "sim_z")

    def get_dvals_tlm(self):
        """Fetch ``sim_z`` and return it scaled and rounded to whole numbers."""
        # NOTE(review): the factor presumably converts mm to motor steps -
        # confirm against the SIM documentation.
        scale = -397.7225924607
        sim_z_mm = self.model.fetch(self.msid)
        return np.rint(sim_z_mm * scale)
[docs]
class Roll(TelemData):
    """Telemetry component for the ``roll`` MSID (units ``deg``)."""

    def __init__(self, model):
        super().__init__(model, "roll", units="deg")