-- --8<--8<--8<--8<--
--
-- Copyright (C) 2009 Smithsonian Astrophysical Observatory
--
-- This file is part of rdb-lua
--
-- rdb_lua is free software; you can redistribute it and/or
-- modify it under the terms of the GNU General Public License
-- as published by the Free Software Foundation; either version 2
-- of the License, or (at your option) any later version.
--
-- rdb_lua is distributed in the hope that it will be useful,
-- but WITHOUT ANY WARRANTY; without even the implied warranty of
-- MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
-- GNU General Public License for more details.
--
-- You should have received a copy of the GNU General Public License
-- along with this program; if not, write to the
--       Free Software Foundation, Inc.
--       51 Franklin Street, Fifth Floor
--       Boston, MA  02110-1301, USA
--
-- -->8-->8-->8-->8--

local oo = require "loop.base"
local table = table

local io = require "io"
local string = require "string"

-- Shared argument validator used by RDB.slurp, RDB.select and RDB:open.
-- NOTE(review): error_on_invalid presumably makes validate() raise on
-- invalid arguments instead of returning a failure status; several callers
-- ignore the returned status flag and rely on that -- confirm against the
-- validate.args documentation.
local vobj = require( "validate.args" ):new()
vobj:setopts{ error_on_invalid = true }

-- The RDB class, built with the loop.base class model.
local RDB = oo.class()

------------------------------------------------
-- private functions

-- _string_split( str, pat )
-- Iterator that splits `str` on the Lua pattern `pat`, for use as
--    for segment, sep in _string_split( str, pat ) do ... end
-- Each step yields the text up to the next match of `pat` together with
-- that separator (or the first capture inside `pat`, if it has one). The
-- final segment is yielded with a nil separator. Empty segments are kept,
-- so a string with no match (including "") yields exactly one segment.
-- `pat` is used as a pattern and is not escaped, so magic characters in it
-- are interpreted (harmless for '\t').
-- Adapted from http://lua-users.org/lists/lua-l/2006-12/msg00414.html

local function _string_split(self, pat)
  -- st: start of the next segment (nil once the final segment was returned)
  -- g:  yields position, separator, captures... for each match of pat
  local st, g = 1, self:gmatch("()("..pat..")")
  local function getter(self, segs, seps, sep, cap1, ...)
    -- no further separator => this is the last segment; stop afterwards
    st = sep and seps + #sep
    -- for the last segment (seps or 0) - 1 == -1, i.e. "to end of string"
    return self:sub(segs, (seps or 0) - 1), cap1 or sep, ...
  end
  local function splitter(self)
    if st then return getter(self, st, g()) end
  end
  return splitter, self
end

-- _split( line )
-- Break a record line into its tab-separated fields and return them as an
-- array. Empty fields are preserved, so "a\t\tb" gives { "a", "", "b" }
-- and a line without any tab gives a one-element array.
local function _split( line )

  local fields, n = {}, 0

  for field in _string_split( line, '\t' ) do
     n = n + 1
     fields[n] = field
  end

  return fields

end

------------------------------------------------
-- private methods

-- _update( self )
-- Rebuild the lookup tables that depend on self.cols and self.defs:
--   self.pos[name]     -> 1-based position of column `name`
--   self.keydefs[name] -> definition string of column `name`
--   self.numeric[i]    -> non-nil when definition i matches '^%d*[Nn]'
--                         (optional width digits followed by N or n)
-- Assumes self.defs has an entry for every entry of self.cols.
local function _update( self  )

  local pos, keydefs, numeric = {}, {}, {}
  self.pos, self.keydefs, self.numeric = pos, keydefs, numeric

  local defs = self.defs
  for idx, name in ipairs( self.cols ) do
     local def = defs[idx]
     pos[name]     = idx
     keydefs[name] = def
     -- find() yields the match start (1) or nil; callers only test truthiness
     numeric[idx]  = def:find( '^%d*[Nn]' )
  end

end


-- _read_hdr( self )
-- Read the RDB header from self.fh: any leading comment lines (first
-- non-blank character '#'), then the column-name line, then the
-- column-definition line. Sets self.comments (comment text after the
-- '#'), self.vars (name/value pairs from comment lines containing
-- ': name = value'), self.cols, self.ncols and self.defs, and bumps
-- self.lines for every line consumed. Raises an error on premature EOF or
-- when the number of definitions differs from the number of columns.
local function _read_hdr( self  )

   local next_line = self.fh:lines()

   local comments, vars = {}, {}
   local header

   -- consume comment lines; the first non-comment line names the columns
   while true do
      local line = next_line()
      if line == nil then break end

      self.lines = self.lines + 1

      local text = line:match( '^%s*#(.*)' )
      if not text then
         header = line
         break
      end

      comments[#comments+1] = text

      local name, value = line:match( ':%s*(%w*)%s*=%s*(.*)' )
      if name then
         vars[name] = value
      end
   end

   self.comments = comments
   self.vars = vars

   if header == nil then
      error( 'unexpected EOF in ' .. self.filename, self.level )
   end

   local cols = _split( header )
   self.cols = cols
   self.ncols = #cols

   -- the line right after the column names holds their definitions
   local defline = next_line()
   if defline == nil then
      error( 'unexpected EOF in ' .. self.filename, self.level )
   end

   self.lines = self.lines + 1

   local defs = _split( defline )
   self.defs = defs

   if #defs ~= self.ncols then
      error(
         string.format( '%s:%d -- # of columns (%d) != # of definitions (%d)\n',
                        self.filename, self.lines, self.ncols, #defs ),
         self.level
      )
   end

   _update( self )

end

-- _write_tsv( self, values )
-- Write the array `values` to self.fh as one tab-separated line. nil
-- entries become empty fields; everything else is converted with
-- tostring(). Nothing is written for a nil or empty array.
-- (Defined as a local ahead of _write_hdr so it is in scope there and does
-- not leak into the global namespace.)
local function _write_tsv( self, values )

   if values == nil or #values == 0 then
      return
   end

   local fields = {}
   for i = 1, #values do
      local v = values[i]
      fields[i] = v == nil and '' or tostring( v )
   end

   self.fh:write( table.concat( fields, "\t" ), "\n" )

end

-- _write_hdr( self )
-- Write the RDB header to self.fh:
--   * every comment line in self.comments; a comment that defines a
--     header variable (': name = value') is rewritten with the current
--     value from self.vars, or dropped if that variable was removed
--   * each remaining variable in self.vars as '#: name = value', sorted
--     by name so the output is deterministic
--   * the column-name line and the column-definition line
-- Marks the header as written (self.hdr_written = true). Raises an error,
-- before anything is written, if the number of column names and
-- definitions differ.
local function _write_hdr( self )

   local fh = self.fh

   -- validate first so a bad layout never leaves a partial header behind
   if #self.cols ~= #self.defs then

      error( string.format( "number of colum names (%d) not equal to number of column definitions (%d)",
                            #self.cols, #self.defs ), self.level+1 )

   end

   local vars = self.vars
   local vwritten = {}

   -- write comments and header variables
   for _, comment in ipairs( self.comments ) do

      local var = comment:match( ':%s*(%w*)%s*=%s*(.*)' )

      if var then

         vwritten[var] = true

         -- only output the comment line if the variable is still there;
         -- keep its original ': name = ' prefix and swap in the current
         -- value. A replacement *function* is used so that '%' characters
         -- in the value are copied literally.
         local value = vars[var]
         if value ~= nil then
            local line = comment:gsub( '(:%s*%w*%s*=%s*)(.*)',
                                       function( prefix )
                                          return prefix .. tostring( value )
                                       end )
            fh:write( '#', line, "\n" )
         end

      else

         fh:write( '#', comment, "\n" )

      end

   end

   -- variables without a comment line of their own, in name order
   local extra = {}
   for var in pairs( vars ) do
      if not vwritten[var] then
         extra[#extra+1] = var
      end
   end
   table.sort( extra, function( a, b ) return tostring( a ) < tostring( b ) end )

   for _, var in ipairs( extra ) do
      fh:write( '#: ', tostring( var ), ' = ', tostring( vars[var] ), "\n" )
   end

   -- ncols is set when a header is read; it may still be unset when the
   -- layout was supplied programmatically
   self.ncols = self.ncols or #self.cols

   _write_tsv( self, self.cols )
   _write_tsv( self, self.defs )

   self.hdr_written = true

end



------------------------------------------------


-- base object

-- RDB:__new( [filename [, mode]] )
-- Construct an RDB object with empty header variables, comments, column
-- names and definitions, then hand the arguments to RDB:open().
function RDB:__new( ... )

   local obj = oo.rawnew( self )

   obj.vars, obj.comments = {}, {}
   obj.cols, obj.defs = {}, {}

   -- error level used by open()/header reading; presumably 3 while inside
   -- the constructor so errors are attributed to the constructor's caller,
   -- and 2 afterwards for direct method calls
   obj.level = 3
   obj:open( ... )
   obj.level = 2

   return obj

end

-- backwards compatibility with an ancient version which doesn't use __new

RDB.__init =  RDB.__new

-- RDB.slurp( filename )
-- Read an entire RDB file. Returns an array of records (as returned by
-- RDB:read) and the table of header variables.
function RDB.slurp ( ... )

   -- vobj is configured with error_on_invalid, so the status is not checked
   local _, filename = vobj:validate( { { type = 'string' } }, ... )

   local rdb = RDB( filename )
   local vars = rdb.vars

   local rows = {}
   local row = rdb:read()
   while row ~= nil do
      rows[#rows+1] = row
      row = rdb:read()
   end

   rdb:close()

   return rows, vars

end

-- RDB.select{ file = ..., match = ... [, mode = ...] [, idx_by_row = ...] [, col = ...] }
-- Select the records of an RDB file that satisfy a match criterion.
--
--   file        name of the RDB file to read
--   match       either a function called as match( rec, row ), returning
--               true for records to keep, or a table of column = value
--               pairs; a record matches when, for every pair, rec[column]
--               equals value (or, if value is a table, any of its elements)
--   col         accepted but not used here (NOTE(review): confirm intent)
--   mode        'function' (default): return a table of matching records
--               'iterator':  return an iterator yielding rec, row per match
--               'coroutine': yield rec, row from the calling coroutine
--   idx_by_row  if true (mode 'function' only) the returned table is keyed
--               by the record's 1-based data-row number, not a dense array
function RDB.select ( ... )

   local coroutine = require('coroutine')
   local yield = coroutine.yield

   local ok,  args =
      vobj:validate( { file  = { type = 'string' },
                       match = { type = { 'function', 'table' } },
                       col   = { type = 'table', optional = true },
                       mode  = { enum = { 'function', 'iterator', 'coroutine' },
                                 default = 'function' },
                       idx_by_row = { type = 'boolean', default = false },
                    }, ... )

   assert( ok, args )

   if ( args.mode == 'iterator' ) then

      -- Lua 5.1 has a global unpack; 5.2+ moved it to table.unpack
      local unpack = table.unpack or unpack

      -- run a coroutine-mode select and hand back each match it yields
      args.mode = 'coroutine'
      local co = coroutine.create( function() RDB.select( args ) end )
      return function()
         local res = { coroutine.resume( co ) }
         if not res[1] then
            -- the message already carries its origin; level 0 avoids
            -- prefixing it with this iterator's position as well
            error( res[2], 0 )
         end

         return unpack( res, 2 )
      end
   end


   local rdb = assert( RDB( args.file ) )

   local mfunc

   if ( type( args.match ) == 'table' ) then
      local criteria = args.match
      mfunc = function( rec )
         for col, want in pairs( criteria ) do

            if type( want ) == 'table' then

               -- any one of the listed alternatives is acceptable
               local found = false
               for _, alt in pairs( want ) do
                  if rec[col] == alt then
                     found = true
                     break
                  end
               end

               if not found then
                  return false
               end

            elseif rec[col] ~= want then
               return false
            end
         end

         return true
      end

   else

      mfunc = args.match

   end

   local row = 0
   local matches = {}

   for rec in function() return rdb:read() end do
      row = row + 1
      if mfunc( rec, row ) then
         if args.mode == 'coroutine' then
            yield( rec, row )
         elseif args.idx_by_row then
            matches[row] = rec
         else
            matches[#matches+1] = rec
         end
      end
   end

   rdb:close()

   if args.mode == 'function' then
      return matches
   end

end



-- RDB:open( [filename [, mode]] )
-- Open `filename` for reading ('r', the default) or writing ('w'). The
-- name '-' selects the current default input/output stream (io.input() or
-- io.output()). In read mode the header is read immediately. Without a
-- filename nothing happens. A failure to open raises an error at
-- self.level.
function RDB:open( ... )

   local ok, filename, mode = vobj:validate(
      {
         { type = 'string', optional = true },
         { type = 'string', default = 'r',
           enum = { 'r', 'w' }
         },
      },
      ... )

   if not filename then
      return
   end

   -- Re-opening: release any previous stream so it is not leaked. This
   -- must run before self.filename is replaced, because close() checks it
   -- to avoid closing the standard streams behind '-'.
   self:close()

   self.filename = filename
   self.lines = 0
   -- a newly opened output stream has not received a header yet
   self.hdr_written = nil

   if filename == '-' then

      if      mode == 'r' then self.fh = io.input()
      elseif  mode == 'w' then self.fh = io.output()
      end

   else

      local fh, err = io.open( filename, mode )
      if not fh then error( err, self.level ) end
      self.fh = fh

   end


   if mode == 'r' then
      _read_hdr( self )
   end

end

-- RDB:close( )
-- Close the underlying stream, if any, and forget it. The standard
-- streams selected by the name '-' are left open.
function RDB:close( )

   local fh = self.fh
   if not fh then
      return
   end

   if self.filename ~= '-' then
      fh:close()
   end

   self.fh = nil

end

-- RDB:init( name1, def1, name2, def2, ... )
-- Define the table layout from alternating column names and column
-- definitions (e.g. 'a', 'N', 'b', 'S'). Replaces self.cols, self.defs and
-- self.ncols and rebuilds the derived lookup tables. Raises an error if an
-- odd number of values is given.
function RDB:init( ... )

   -- the varargs are packed into a single table argument for validation
   local ok, args = vobj:validate( { { type = 'table' } },
                                   { ... } )

   if #args % 2 == 1 then
      error( "mismatched number of columns and definitions", 2 )
   end

   local cols, defs = {}, {}

   for i = 1, #args, 2 do
      cols[#cols+1] = args[i]
      defs[#defs+1] = args[i+1]
   end

   self.cols = cols
   self.defs = defs
   -- keep ncols in step with cols, as the header reader does
   self.ncols = #cols

   _update( self )

end

-- RDB:read( )
-- Read the next data record. Returns a table keyed by column name, or
-- nil at end of file or when no stream is open. Empty fields are left out
-- of the record; fields of numeric columns are converted with tonumber(),
-- so a non-numeric value in such a column is also left out. Raises an
-- error if the line does not have exactly self.ncols fields.
function RDB:read( )

   -- only read if there's been a successful open
   local fh = self.fh
   if not fh then
      return nil
   end

   local line = fh:read()
   if not line then
      return nil
   end

   self.lines = self.lines + 1

   local fields = _split( line )

   if #fields ~= self.ncols then
      local msg = self.filename .. ': ' .. self.lines
         .. ': error in RDB file: expected ' .. self.ncols
         .. ' columns, got ' .. #fields
      error( msg, self.level );
   end

   local record = {}
   local names, numeric = self.cols, self.numeric

   for idx, field in ipairs( fields ) do
      if field ~= '' then
         local value = field
         if numeric[idx] then
            value = tonumber( field )
         end
         record[names[idx]] = value
      end
   end

   return record
end

-- RDB:iwrite( data )
-- Write one data row given as an array of values in column order. The
-- header is written first if that has not happened yet. The row written
-- always has exactly self.ncols fields: missing (nil/false) entries become
-- empty fields and entries beyond self.ncols are dropped.
function RDB:iwrite( data )

   if not self.hdr_written then
      _write_hdr( self )
   end

   -- Always build a dense copy: `#data` is unreliable when data has nil
   -- holes, and the normalized row (not the raw input) is what gets written.
   local row = {}
   for i = 1, self.ncols do
      row[i] = data[i] or ''
   end

   _write_tsv( self, row )

end

-- RDB:kwrite( data )
-- Write one data row given as a table keyed by column name. Values are
-- emitted in column order; missing (nil/false) values become empty fields.
-- The header is written first if that has not happened yet.
function RDB:kwrite( data )

   if not self.hdr_written then _write_hdr( self ) end

   local row = {}
   for idx, name in ipairs( self.cols ) do
      local value = data[name]
      if value == nil or value == false then
         value = ''
      end
      row[idx] = value
   end

   _write_tsv( self, row )

end


-- Return the 1-based position of column `col`, or nil if there is no such column.
function RDB:colpos( col )
  local positions = self.pos
  return positions[col]
end

-- Return the definition string of column `col`, or nil if there is no such column.
function RDB:coldef( col )
  local defs_by_name = self.keydefs
  return defs_by_name[col]
end

-- Return the value of header variable `varname` (from self.vars), or nil
-- if it is not defined.
function RDB:var( varname )

  local vars = self.vars
  return vars[varname]

end

-- RDB:chk_cols( cols )
-- Check that every column name among the values of `cols` is defined.
-- Returns the first unknown name encountered (pairs() order), or nil if
-- all of them are known.
function RDB:chk_cols( cols )

   local known = self.keydefs

   for _, name in pairs( cols ) do
      if not known[name] then
         return name
      end
   end

   -- every requested column exists
   return nil

end

-- the module's value is the RDB class itself
return RDB
