# 
#  Copyright (C) 2010-2019,2021,2023  Smithsonian Astrophysical Observatory
#
#
#  This program is free software; you can redistribute it and/or modify
#  it under the terms of the GNU General Public License as published by
#  the Free Software Foundation; either version 3 of the License, or
#  (at your option) any later version.
#
#  This program is distributed in the hope that it will be useful,
#  but WITHOUT ANY WARRANTY; without even the implied warranty of
#  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#  GNU General Public License for more details.
#
#  You should have received a copy of the GNU General Public License along
#  with this program; if not, write to the Free Software Foundation, Inc.,
#  51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
#

from pycrates.crateutils import convert_2_str
from pytransform import *
from collections import OrderedDict
from numpy import array
from numpy import array_str
from numpy import ascontiguousarray
from numpy import unpackbits
from numpy import packbits
from numpy import zeros
from numpy import uint8
from numpy import concatenate
from numpy import empty
import sys
import hashlib
import warnings

# Element-type 'constants' stored in a CrateData object's eltype field.
# REGULAR is the default set by CrateData.__clear__(); VIRTUAL is assigned
# when a transform is attached (see CrateData.set_transform()).
REGULAR = 0
VIRTUAL = 1

class CrateData(object):

    def _get_values(self):
        vals = self._values

        # get VIRTUAL data
        if self.is_virtual() and self.parent is None:
            if self.source is not None:
                vals = self.__transform.apply( self.source.values )
                return vals
            else:
                return None

        # get VECTOR COMPONENT values
        if self.parent is not None:
            ii = 0
            cpts = self.parent.get_cptslist()

            while self.name != cpts[ii] and ii < self.vdim:
                ii += 1

            if self.parent._varlen_itemsize:
                nsets = self.parent._values.shape[0]
                vals = self.parent._values
                tmparr = empty(shape=nsets, dtype=object)

                for jj in range (0, nsets):
                    tmparr[jj] = vals[jj][ii]

                vals = tmparr
            else:

                vals = self.parent.values[::, ii::self.parent.vdim]
                vals_shape = list(vals.shape)

                if len(vals_shape) > 2:
                    if vals_shape[1] == 1:
                        vals_shape.pop(1)
                        vals = vals.reshape( vals_shape )
                else:
                    vals = vals.reshape( vals_shape[0] )

        return vals


    def _set_values(self, invals):
        if self.is_virtual():
            return
        else: 
            # convert image values to array
            if self._image:
                if "array" not in str(type(invals)):
                    invals = array(invals)

            # convert table values to an array
            elif "array" not in str(type(invals)):
                invals = array(invals).flatten()

            # table values already in an array
            elif "array" in str(type(invals)) and (invals is not None and len(invals) > 0):
                # for single row data for all datatypes except bit/byte
                if invals.size == 1 and "uint8" not in str(invals.dtype.type):
                    invals = invals.flatten()

                # for bit/byte values
                if "uint8" in str(invals.dtype.type) and invals.ndim < 2:
                    invals = invals.reshape( [invals.shape[0], 1] )

                # for vector columns without data
                if self.is_vector() and invals.ndim < 2 and len(invals) == 0:
                    invals = invals.reshape( [invals.shape[0], self.vdim] )

                # for variable-length arrays, get longest row array
                if "object" in str(invals.dtype.type):
                    nrows = invals.shape[0]
                    for ii in range(0, nrows):
                        arrlen = invals[ii].size

                        if self.vdim >= 2:
                            arrlen = arrlen/self.vdim

                        if arrlen > self._varlen_itemsize:
                            self._varlen_itemsize = arrlen

            if invals is not None:
                # for python3 array of bytes, convert to array of str/unicode
                if sys.version_info[0] >= 3 and "bytes" in invals.dtype.name:
                    invals = invals.astype('U')

                # for python 2, convert unicode array to str array
                if sys.version_info[0] < 3 and "unicode" in invals.dtype.name:
                    invals = invals.astype('S')

            # If component of parent, set value for just that component
            if self.parent:
                cpt_no = self.parent.get_cptslist().index(self.name)
                self.parent._values[: , cpt_no] = ascontiguousarray(invals)
            else:
                 self._values = ascontiguousarray(invals)

    values = property(_get_values, _set_values)


    def get_name(self):
        """
        Returns the name of the CrateData object.
        """
        return self._name

    def _set_name(self, in_name):
        """
        Sets the name of the CrateData object.

        The input is passed through convert_2_str().  When it contains a
        '(' only the text before the first '(' is kept, and a warning
        reports the change.
        """
        in_name = convert_2_str(in_name)

        if '(' not in in_name:
            self._name = in_name
            return

        # keep only the part in front of the first parenthesis
        colname = in_name.partition('(')[0]
        warnings.warn("Invalid column name.  '"+ in_name + "' has been changed to '" + colname + "'.")
        self._name = colname
 
    name = property(get_name, _set_name)

    def get_unit(self):
        """
        Returns the unit of the CrateData object.
        """
        return self._unit

    def _set_unit(self, in_unit):
        """
        Sets the unit of the CrateData object.

        None is stored as an empty string; any other input is passed
        through convert_2_str() first.
        """
        unit_text = "" if in_unit is None else in_unit
        self._unit = convert_2_str(unit_text)
 
    unit = property(get_unit, _set_unit)


    def get_desc(self):
        """
        Returns the description of the CrateData object.
        """
        return self._desc

    def _set_desc(self, in_desc):
        """
        Sets the description of the CrateData object.

        None is stored as an empty string; any other input is passed
        through convert_2_str() first.
        """
        desc_text = "" if in_desc is None else in_desc
        self._desc = convert_2_str(desc_text)

    desc = property(get_desc, _set_desc)


    def __init__(self):
        """
        Initializes the CrateData object.

        Resets every field through __clear__() and then installs an empty
        OrderedDict as the vector components list.
        """
        self.__clear__()
        self.__cptslist = OrderedDict()


    def __clear__(self):
        """
        Clears the CrateData object.

        Empties an existing components list in place and resets every
        other field to its initial value, including the reference to the
        owning Crate used by _set_crate_ref() / _get_crate_ref().
        """
        # The components list only exists once __init__ has run.
        if hasattr(self, "_CrateData__cptslist") and len(self.__cptslist) != 0:
            self.__cptslist.clear()

        self._name = ""
        self._desc = ""
        self._unit = ""
        self._values = None
        self.vdim = 0
        self.source = None       # for virtual columns; points to its source CrateData object
        self.parent = None       # for vector columns; points to its parent CrateData object
        self.__eltype = REGULAR
        self.__tlmin = None
        self.__tlmax = None
        self.__nullval = None
        self.__transform = None
        # refers to the Crate this object belongs to; spelled exactly like
        # the attribute the crate-reference accessors read and write, so a
        # fresh or cleared object reports None instead of failing
        self.__crate_ref = None
        self.__signature = ""
        self._image = False
        self.__bit_array_flag = False
        self._bit_array_itemsize = 0
        self._varlen_itemsize = 0


    def __del__(self):
        """
        Clears the CrateData object when it is destroyed.

        Delegates to __clear__(), which also empties the components list
        in place.
        """
        self.__clear__()


    def __str__(self):
        """
        Returns a formatted string representation of the CrateData object.

        Identical to the output of __repr__().
        """
        return self.__repr__()


    def __repr__(self):
        """
        Returns a formatted string representation of the CrateData object.
        """
        dtype = ""

        retstr =  "  Name:     " + self.name  + "\n"

        if not self._image:
            retstr += "  Shape:    " + str(self.get_shape()) + "\n"

        if not self.is_virtual():
            if self._values is not None:
                dtype = str(self._values.dtype.name)
                if "bytes" in dtype:
                    dtype = dtype.replace("bytes", "string")

                if "uint8" in dtype:
                    if self.is_bit_array():
                        retstr += "  Datatype: " + dtype + " | Bit[" + str(self._bit_array_itemsize) + "]\n"
                    else:
                        retstr += "  Datatype: " + dtype + " | Byte\n"
                elif "object" in dtype:
                    retstr += "  Datatype: " + self._values[0].dtype.name + " | Object\n"
                else: 
                    retstr += "  Datatype: " + dtype + "\n"
                if not self._image:
                    retstr += "  Nsets:    " + str(self.get_nsets()) + "\n"
          

        retstr += "  Unit:     " + self.unit + "\n"
        retstr += "  Desc:     " + self.desc + "\n"

        ndim = self.get_ndim()        
        retstr += "  Eltype:   " + self._get_eltype_str() + "\n"

        if self.is_vector():
            retstr += "     NumCpts:   " + str(self.vdim) + "\n"
            retstr += "     Cpts:      " + str(list(self.__cptslist.keys())) + "\n"

        if ndim > 0 and (self.is_virtual() == False or (self.is_virtual() and self.source is not None)):
            retstr += "     Ndim:     " + str(ndim) + "\n"
            retstr += "     Dimarr:   " + str(self.get_dimarr())  + "\n"

        if not self.is_virtual() and "str" not in dtype:
            retstr += "  Range:    "  + "\n"
            retstr += "     Min:   " + str(self.__tlmin) + "\n"
            retstr += "     Max:   " + str(self.__tlmax) + "\n"

        return retstr


    def _get_eltype_str(self):
        """
        Returns a string definition of the element type, which can be 
        one of the following:
           Scalar
           Array
           Vector
           Virtual
           Vector Array
           Virtual Array
           Virtual Vector Array
        """
        retstr = ""
        
        ndim = self.get_ndim()

        if self.is_vector():
            if self.is_virtual():
                retstr += "Virtual Vector " 
            else: 
                retstr += "Vector " 
                      
            if ndim > 0:
                if self.is_varlen():
                    retstr += "Variable-Length "
                retstr += "Array"

        elif self.is_virtual():
            retstr += "Virtual " 

            if ndim > 0:
                if self.is_varlen():
                    retstr += "Variable-Length "
                retstr += "Array"

        else:
            if ndim > 0:
                if self.is_varlen():
                    retstr += "Variable-Length "
                retstr += "Array"
            else:
                retstr += "Scalar"

        return retstr


    def get_eltype(self):
        """
        Returns the element type; either REGULAR or VIRTUAL.
        """
        return self.__eltype


    def _set_eltype(self, intype):
        """
        Assigns value (REGULAR or VIRTUAL) to the eltype field.

        Any other input falls back to REGULAR.
        """
        recognised = (intype == REGULAR) or (intype == VIRTUAL)
        self.__eltype = intype if recognised else REGULAR


    def is_varlen(self):
        if self._values is not None and "object" in self._values.dtype.name:
            return True
        if self.parent and self.parent.is_varlen():
            return True
        return False


    def get_fixed_length_array(self):
        if self.is_varlen():
            # make new array
            nrows = self._values.shape[0]

            if self.vdim == 2:
                newarr = zeros(shape=[nrows, self.vdim, self._varlen_itemsize], dtype=self.values[0].dtype)

                for jj in range(0, self.vdim):
                    for ii in range(0, nrows):
                        arrlen = len(self.values[ii][jj])
                        for kk in range(0, arrlen):
                            newarr[ii][jj][kk] = self.values[ii][jj][kk]

            else:
                newarr = zeros(shape=[nrows, self._varlen_itemsize], dtype=self.values[0].dtype)

                for ii in range(0, nrows):
                    newarr[ii][0:len(self.values[ii]):] = self.values[ii][::]
 
        return newarr


    def convert_to_fixed_length_array(self):
        newarr = self.get_fixed_length_array()
        self.values = newarr
        self._varlen_itemsize = 0


    def get_shape(self):
        """
        Retrieves the shape of the data values in the format: (rows, [vdim], [dimarr]).
        """
        vals = None

        if self.is_virtual() and (self.source is not None) and (self.parent is None):
            vals = self.source.values
        elif (self._values is not None) or ((self.parent is not None) and (self.parent._values is not None)):
            vals = self.values
            
        if vals is not None:
            return vals.shape

        return None


    def get_size(self):
        """
        Returns the number of elements per row.
        """
        shape = self.get_shape()
        if shape is None:
            return 0

        vals = None
        if self.is_virtual() and (self.source is not None) and (self.parent is None):
            vals = self.source.values
        elif (self._values is not None) or ((self.parent is not None) and (self.parent._values is not None)):
            vals = self.values
            
        if vals is not None:
            valsize = vals.size
            return valsize/shape[0]

        return 0


    def get_nsets(self):
        """
        Returns the number of rows of data.
        """
        shape = self.get_shape()

        if shape is not None:
            return shape[0]

        return 0


    def get_ndim(self):
        """
        Returns the number of dimensions.
        """
        if self._image == True and (self.values is not None) or self.is_varlen():
            return self.values.ndim
        
        if self.is_virtual() and (self.source is not None) and (self.parent is None):
            if self.source.values is not None:
                if self.is_vector():
                    return self.source.values.ndim - 2
                else:
                    return self.source.values.ndim - 1
 
        if (self._values is not None) or ((self.parent is not None) and (self.parent._values is not None)):
            # is a vector
            if self.is_vector():
                if (self.parent is not None):
                    return self.parent.values.ndim - 2
                else:
                    return self.values.ndim - 2
            else:
                return self.values.ndim - 1
            
        return 0


    def get_dimarr(self):
        """
        Returns an array containing size of each dimension.
        """
        
        shape = self.get_shape()

        if shape is None:
            return None

        if self._image == True:
            return shape

        if self.is_varlen():
            return array(self._varlen_itemsize)

        dimarr = None
        length = len(shape)
 
        if length > 1:
            # is a vector
            if self.is_vector():
                dimarr = shape[2: length]
            else:
                dimarr = shape[1: length]

        return dimarr


    def is_virtual(self):
        """
        Returns whether the CrateData object is Virtual, i.e. its element
        type equals VIRTUAL.
        """
        return True if self.__eltype == VIRTUAL else False


    def is_vector(self):
        """
        Returns whether the CrateData object contains vector data, i.e.
        it has more than one component (vdim > 1).
        """
        return True if self.vdim > 1 else False


    def _get_full_name(self):
        """
        Returns the 'full' name of a column.

        For non-vector columns, this is just the name.

        For vectors, combines the vector name with the name of its 
        components in the following format:  vector(cpt1, cpt2, etc.)
        """
        retstr = self.name 

        if self.is_vector():
            retstr += "("
            
            cptslist = list(self.__cptslist.keys())
            for ii in range(self.vdim):
                if ii > 0 and ii < self.vdim:
                    retstr += ", "
                retstr +=  cptslist[ii] 

            retstr += ")"

        return retstr


    def _set_crate_ref(self, crate):
        """
        Assigns the reference to the parent Crate.

        crate  -  the object to remember; it is stored as given, without
                  any type check, and returned by _get_crate_ref()
        """
        self.__crate_ref = crate

    def _get_crate_ref(self):
        """
        Retrieves the reference to the parent Crate.

        Returns None when no reference has been assigned yet.
        """
        try:
            return self.__crate_ref
        except AttributeError:
            # the attribute only exists after it has been assigned
            return None


    def get_vector_component(self, cpt):
        """
        Retrieves vector component specified by name or number (zero-based indexing).

        Names are compared case-insensitively; when several component
        names match, the last one in the components list is used.

        Raises TypeError if this object is not a vector, Exception if a
        name matches no component, IndexError if a number is out of
        range and KeyError if no component name could be determined.
        """
        cpt = convert_2_str(cpt)

        if not self.is_vector():
            raise TypeError("This is not a vector.")

        names = list(self.__cptslist.keys())
        chosen = None

        if isinstance(cpt, str):
            wanted = cpt.upper()
            for candidate in names:
                if candidate.upper() == wanted:
                    chosen = candidate

            if chosen is None:
                raise Exception("Requested component not found.")

        if isinstance(cpt, int):
            if not 0 <= cpt < self.vdim:
                raise IndexError("Index out of range.")

            chosen = names[cpt]

        if not chosen:
            raise KeyError("Input must be a string or a number.")

        return self.__cptslist[ chosen ]


    def get_cptslist(self):
        """
        Returns an array containing the component names if the
        CrateData object is a vector.
        """
        if self.is_vector():
            return list(self.__cptslist.keys())

        return None

    def _set_cptslist(self, in_dict):
        """
        Sets the vector components list.
        """
        if isinstance(in_dict, OrderedDict):
            self.__cptslist = in_dict
            self.vdim = len(self.__cptslist)
        else:
            raise TypeError("Input must be a OrderedDict.")


    def set_transform(self, in_trans):
        """
        Sets the transform.

        An object whose element type is still REGULAR becomes VIRTUAL.

        Raises TypeError if the input is not a Transform.
        """
        if not isinstance(in_trans, Transform):
            raise TypeError("Input must be a Transform.")

        self.__transform = in_trans

        if self.get_eltype() == REGULAR:
            self.__eltype = VIRTUAL


    def get_transform(self):
        """
        Returns the transform, or None when no transform has been set.
        """
        return self.__transform


    def get_transform_matrix(self):
        """
        Returns matrix derived from transform.

        Raises Exception when no transform has been set.
        """
        if self.__transform is None:
            raise Exception("No transform found.  Unable to retrieve transform matrix.")

        return self.__transform.get_transform_matrix()


    def _set_transform_matrix(self, in_matrix):
        """
        Sets the transform matrix by forwarding it to the stored
        transform's set_transform_matrix().

        in_matrix  -  the matrix to install; it is passed on unchanged
        """
        # TODO: check that in_matrix is 3x3 -- no validation is done here.
        # NOTE(review): a transform is assumed to have been set already;
        # there is no guard for a missing transform -- confirm callers.
        self.__transform.set_transform_matrix( in_matrix )


    def is_modified(self):
        current_sig = self.__calculate_signature()

        if current_sig != self.get_signature():
            return True
        return False


    def __calculate_signature(self):
        sigstr = self.__str__()

        if self.values is not None:
            sigstr += str(self.values)
        
        return hashlib.sha256( sigstr.encode('utf-8') ).hexdigest()


    def get_signature(self):
        """
        Retrieves the stored checksum of the CrateData.

        When nothing has been stored yet, the checksum is calculated
        and stored first.
        """
        if not self.__signature:
            self.update_signature()

        return self.__signature


    def update_signature(self):
        """
        Recalculates and stores the new checksum, so that is_modified()
        compares against the object's current state.
        """
        self.__signature = self.__calculate_signature()
        

    def get_tlmin(self):
        """
        Returns the stored range minimum (printed as 'Min' under 'Range'
        by __repr__()); None when unset.
        """
        return self.__tlmin

    def _set_tlmin(self, in_min):
        """
        Stores the range minimum as given, without validation.
        """
        self.__tlmin = in_min


    def get_tlmax(self):
        """
        Returns the stored range maximum (printed as 'Max' under 'Range'
        by __repr__()); None when unset.
        """
        return self.__tlmax

    def _set_tlmax(self, in_max):
        """
        Stores the range maximum as given, without validation.
        """
        self.__tlmax = in_max


    def get_nullval(self):
        """
        Returns the stored null value; None when unset.
        """
        return self.__nullval

    def _set_nullval(self, in_null):
        """
        Stores the null value.

        None clears the null value.  For images, a value that is not an
        int is converted with int(); other data stores the value as
        given.

        Raises TypeError if an image null value cannot be converted to
        an integer.
        """
        if in_null is None:
            self.__nullval = None
            return

        # if input is not an integer, try to convert
        if self._image and not isinstance(in_null, int):
            try:
                in_null = int(in_null)
            except Exception:
                # conversion failures only; interpreter-level signals such
                # as KeyboardInterrupt are deliberately left to propagate
                raise TypeError("Image null value must be an integer.  Unable to convert input value to integer.")

        self.__nullval = in_null


    def is_bit_array(self):
        """
        Returns whether the values are held as a bit array: the bit
        array flag is set and the bit array item size is greater than 0.
        """
        flagged = self.__bit_array_flag == True
        return True if (flagged and self._bit_array_itemsize > 0) else False


    def _set_bit_array_flag(self, inflag):
        """
        Stores the flag that marks the values as a binary-valued bit
        array.  Only the flag changes; the values are not converted.
        """
        self.__bit_array_flag = inflag


    def get_bit_array_flag(self):
        """
        Returns the flag that marks the values as a binary-valued bit
        array.
        """
        return self.__bit_array_flag


    def convert_bits_to_bytes(self):
        """
        Converts binary-valued bit array values to bytes.
        """
        if self.__bit_array_flag == False:
            warnings.warn("Values are already in byte representation.")
            return

        self.__bit_array_flag = False

        if (self._values is not None) and (len(self._values) > 0):
            self.values = packbits(self._values, axis=-1)

        self._bit_array_itemsize = 0


    def convert_bytes_to_bits(self, itemsize=8):
        """
        Converts byte values to binary-valued bit arrays.
        
        itemsize  -  indicates the length of the bit array
                     default=8
        """

        if itemsize < 1 :
            raise ValueError("Bit array itemsize must be greater than 0.")

        if self.__bit_array_flag == True:
            warnings.warn("Values are already in binary format.")
            return

        self.__bit_array_flag = True

        if (self._values is not None) and (len(self._values) > 0):
            self.values = unpackbits(self._values, axis=1)
            
        self.resize_bit_array(itemsize)


            
    def resize_bit_array(self, itemsize=8):
        """
        Increases or decreases the size of the bit array.  Data values must 
        already be in a binary-valued bit array.  When expanding the array, 
        values will be appended to the right end of the array. 

        itemsize  -  indicates the length of the bit array
                     default=8

        """

        if itemsize < 1 :
            raise ValueError("Bit array itemsize must be greater than 0.")

        if self.__bit_array_flag == False:
            raise TypeError("Values are not in binary format.")

        if itemsize != self._bit_array_itemsize:
            if (len(self._values) > 0):
 
                self.values = self.values[:,:itemsize]

                if itemsize > self._bit_array_itemsize:
                    valshape = self.values.shape

                    if len(valshape) == 2 and itemsize != valshape[1]:
                        pos1 = valshape[0]
                        pos2 = itemsize - valshape[1]

                        newvals= zeros(shape=(pos1, pos2), dtype=uint8)
                        self.values = concatenate( (self.values, newvals), axis=1)

            self._bit_array_itemsize = itemsize
