#!/usr/bin/env python
#
# Copyright (C) 2013,2016, 2018 Smithsonian Astrophysical Observatory
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
#

toolname = "reg2tgmask"
__revision__ = "28 November 2018"


# This is only needed for development.
import os
import sys

# A missing ASCDS_INSTALL means the CIAO environment has not been set up,
# so stop straight away with a message that says so.
try:
    _is_installed = __file__.startswith(os.environ['ASCDS_INSTALL'])
except KeyError:
    raise IOError('Unable to find ASCDS_INSTALL environment variable.\nHas CIAO been started?')

if not _is_installed:
    # This file does not live under ASCDS_INSTALL: look for a
    # ../lib/pythonX.Y/site-packages directory next to it and, when it
    # exists, add it to the module search path (after the first entry).
    _here = os.path.dirname(__file__)
    _pyver = "python{}.{}".format(*sys.version_info[:2])
    _sitepkg = os.path.normpath(
        os.path.join(_here, '../lib', _pyver, 'site-packages'))

    if not os.path.isdir(_sitepkg):
        print("*** WARNING: no {}".format(_sitepkg))
    else:
        sys.path.insert(1, _sitepkg)

    # Do not leave the scratch names in the module namespace.
    del _here, _pyver, _sitepkg

del _is_installed



import sys as sys


import ciao_contrib.logger_wrapper as lw
# Initialise the logger for this tool, then bind the logger's verboseN
# attributes to short module-level aliases.
lw.initialize_logger(toolname)
lgr = lw.get_logger(toolname)
verb0, verb1, verb2, verb3, verb5 = (
    lgr.verbose0,
    lgr.verbose1,
    lgr.verbose2,
    lgr.verbose3,
    lgr.verbose5,
)


from cxcdm import *

# Tag names that must be present on every region row.  Each name is paired
# with the callable that converts the tag's string value to the datatype of
# the matching output column.
required_tags = dict([
    ("rowid", str),
    ("tg_part", int),
    ("tg_srcid", int),
    ("tg_m", int),
])

def get_list_of_tags( mm ):
    """
    Pull the tag={...} tokens out of one ds9 region line.

    A region line has the form

    shape(values) # stuff stuff tag={stuff} stuff tag={stuff}

    and only the tag={stuff} tokens are of interest.

    mm is a single line (string) of the region file.

    Returns the list of space-separated tokens found after the first '#'
    that start with "tag={", in the order they appear on the line.

    Raises IOError if any name in required_tags does not appear in the
    line, or if the line has no '#' separator at all.
    """

    # Every required tag name has to appear somewhere in the row.
    for tt in required_tags:
        if tt not in mm:
            raise IOError("Missing tagging information in row {}".format(mm))

    # Everything after the FIRST '#' holds the properties.  Splitting only
    # once means a further '#' later on the line (for instance in a hex
    # colour value) does not cut the tag list short.
    _, found_hash, tagstr = mm.partition("#")
    if not found_hash:
        raise IOError("Missing tagging information in row {}".format(mm))

    # Tags are space separated.  Technically a tag value can contain
    # spaces, but none that we are interested in, so this should be ok.
    # Other properties may follow the '#', so keep only the tag tokens.
    return [ x for x in tagstr.split(" ") if x.startswith("tag={") ]


def add_tag_values_to_keep( mytags, keep_tag_values ):
    """
    Record the value of one tag token if it is a required tag.

    The tokens we want look like

      tag={name=value}

    mytags is a single token starting with "tag={" and keep_tag_values is
    the dictionary of lists (keyed by required tag name) that is appended
    to in place.  Tokens that are not of the name=value form, or whose name
    is not in required_tags, are ignored.

    The converter stored in required_tags may raise (e.g. ValueError from
    int) if the value cannot be converted.
    """
    # separate name and value ([5:-1] drops the leading 'tag={' and the
    # trailing '}')
    tagval = mytags[5:-1].split("=")

    # A tag with no '=' (e.g. a user's own tag={mygroup}) cannot be one of
    # ours; skip it rather than fail with an IndexError.
    if len(tagval) < 2:
        return

    name = tagval[0]
    val  = tagval[1]

    # if tag is one we want, then keep it.
    if name in required_tags:

        # The value for each element in the required_tags dictionary is a
        # function that will convert the string into the required datatype.
        val = required_tags[name](val)
        keep_tag_values[ name ].append( val )


def validate_ds9_region_format(myreg):
    """
    Check that the lines of a region file have the layout this tool needs.

    myreg is the list of lines of the region file (no trailing newlines).
    The checks are: the first line identifies a DS9 version 4 region file,
    there is a line starting with 'global', the line after it is exactly
    "physical", and every line after that starts with "polygon".

    Returns the index into myreg of the first shape line.

    Raises IOError if any of the checks fail, including when the file is
    empty or is cut short before the coordinate-system line.
    """
    # "not myreg" covers an empty file, which has no first line to inspect.
    if not myreg or "# Region file format: DS9 version 4" not in myreg[0]:
        raise IOError("Needs ds9 region format")

    # find the global properties line, it starts with 'global'
    global_line = next(
        ( ii for ii, line in enumerate(myreg) if line.startswith('global') ),
        None )
    if global_line is None:
        raise IOError("DS9 region file is missing the 'global' properties line")

    # the coordinate system is next line (which may be missing altogether)
    coord_line = global_line + 1
    if coord_line >= len(myreg) or myreg[coord_line] != "physical":
        raise IOError("DS9 region must be saved in physical coords")

    # next line after coords is start of shapes; all must be polygons
    shapes_ln = coord_line + 1
    if not all( xx.startswith("polygon") for xx in myreg[shapes_ln:] ):
        raise IOError("Only polygon regions are allowed")

    return shapes_ln


def validate_regions( infile ):
    """
    Read a ds9 region file, check its layout, and collect the values of
    the required tags.

    Each shape line looks like

      polygon( stuff ) # tag={} tag={} tag={} other stuff

    and the tag values are gathered column-wise so they can be added to
    the output file.

    Returns a dictionary keyed by the names in required_tags; each entry
    is the list of converted tag values in the order they were read.
    """
    with open( infile, "r") as fp:
        lines = [ row.rstrip("\n") for row in fp ]

    first_shape = validate_ds9_region_format( lines )

    # one (initially empty) output column per required tag
    keep_tag_values = { name : [] for name in required_tags }

    # Walk the shape rows, and within each row its tag tokens
    for row in lines[first_shape:]:
        for tag in get_list_of_tags( row ):
            add_tag_values_to_keep( tag, keep_tag_values )

    return keep_tag_values


def setup_fits_output( infile, origfile, outfile ):
    """
    Create the output dataset and write the region from infile into a
    block named "REGION" using the dm module, which takes care of the
    array sizes etc.

    origfile is only opened to obtain a "tg_r" column descriptor.

    Returns the tuple (dataset, region block); neither is closed here.
    """
    from region import regParse

    region_expr = "region({})".format( infile )
    parsed_region = regParse( region_expr )

    out_ds = dmDatasetCreate( outfile )

    # HACK: dmTableWriteRegion should accept None/NULL for the descriptor
    # but doesn't, so borrow one from the existing file.
    # NOTE(review): the table opened on origfile is never explicitly
    # closed -- confirm whether that matters.
    borrowed_col = dmTableOpenColumn( dmTableOpen( origfile ), "tg_r")
    region_block = dmTableWriteRegion( out_ds, "REGION", borrowed_col, parsed_region )

    # HACK: neither cxcdm nor pycrates can rename columns, so the
    # TTYPEn keywords are written directly.
    for keyname, colname in ( ("TTYPE1", "tg_r"), ("TTYPE2", "tg_d") ):
        dmKeyWrite( region_block, keyname, colname )

    return out_ds, region_block


def write_required_cols( my_block, keep_tag_values):
    """
    Add one column per required tag to my_block and fill it with the
    values collected from the ds9 region tags.

    keep_tag_values maps each required tag name to its list of values.
    """
    for colname, coltype in required_tags.items():
        # String columns are created with itemsize set to the longest
        # value; every other column passes 0.
        width = max( len(v) for v in keep_tag_values[colname] ) if coltype == str else 0

        created = dmColumnCreate( my_block, colname, coltype, itemsize=width)
        dmSetData( created[0], keep_tag_values[colname], 1 )

def workaround_tgextract2_nan_padding_bug( outfile):
    """
    Replace the NaN values in the tg_r and tg_d columns of outfile with 0.

    tgextract2 wants the arrays padded with 0's rather than with NaNs
    (which is now the norm), so the file is re-opened with pycrates,
    patched, and written back.
    """
    import numpy as np
    import pycrates as pyc

    crate = pyc.read_file(outfile, "rw")

    for colname in ("tg_r", "tg_d"):
        vals = crate.get_column(colname).values
        vals[ np.isnan(vals) ] = 0

    crate.write()


def locate_caldb_file():
    """
    Look up the TGMASK2 product for chandra/hrc in the CALDB.

    Returns the single matching entry with everything from the first "["
    onwards removed.  Raises IOError unless exactly one entry is found.
    """
    import caldb4 as cal

    query = cal.Caldb("chandra", "hrc", product="TGMASK2")
    matches = query.search

    if len(matches) != 1:
        raise IOError("Unexpected number of CALDB files found")

    # Keep only the part before any "[" (presumably a block specifier
    # appended to the file name -- confirm against caldb4).
    return matches[0].split("[", 1)[0]



@lw.handle_ciao_errors(toolname, __revision__)
def main():
    """
    Tool entry point: read the parameters, parse the ds9 region file with
    its tgmask tags, and write the FITS mask (region) file.
    """

    import ciao_contrib.param_soaker as ps
    from ciao_contrib._tools.fileio import outfile_clobber_checks
    import os

    pars = ps.get_params( toolname, "rw", sys.argv,
        verbose={"set":lw.set_verbosity, "cmd":verb1} )

    region_file = pars["infile"]
    template_file = pars["srcfile"]
    mask_file = pars["outfile"]
    may_clobber = pars["clobber"] == "yes"

    # Run the clobber check first; any output file that is still there
    # afterwards is removed before the new one is created.
    outfile_clobber_checks( may_clobber, mask_file )
    if os.path.exists( mask_file ):
        os.remove( mask_file )

    # The special value CALDB means: look the template up in the CALDB
    if template_file == 'CALDB':
        template_file = locate_caldb_file()

    tag_columns = validate_regions( region_file )
    out_ds, region_block = setup_fits_output( region_file, template_file, mask_file )
    write_required_cols( region_block, tag_columns )
    dmBlockClose( region_block )
    dmDatasetClose( out_ds )

    # Must run after the file is closed: it re-opens it to patch the NaNs
    workaround_tgextract2_nan_padding_bug( mask_file )



if __name__ == "__main__":
    # Run the tool; report any exception on stderr and exit with status 1,
    # otherwise exit with status 0.
    try:
        main()
    except Exception as err:
        message = "\n# "+toolname+" ("+__revision__+"): ERROR "+str(err)+"\n"
        print(message, file=sys.stderr)
        sys.exit(1)
    else:
        sys.exit(0)
