# Source code for agasc_gaia.utils

import tables
import functools
import pickle
import numpy as np

from pathlib import Path

from astropy.table import Table, MaskedColumn


class TableCache:
    """Cache backend that reads and writes files through ``astropy`` ``Table`` I/O."""

    @staticmethod
    def save(table, filename, force=False):
        """Write ``table`` to ``filename`` unless that file is already there.

        An existing file is silently left alone (nothing is raised) unless
        ``force`` is true, in which case it is overwritten. Missing parent
        directories are created. Anything that is not a ``Table`` is
        converted with ``Table(table)`` first.

        Note: a ``"description"`` entry in the table's ``meta`` is renamed to
        ``"descript"`` before writing. When ``table`` is already a ``Table``
        this rename is applied to the caller's object.
        """
        already_there = filename.exists()
        if already_there and not force:
            return None

        parent = filename.parent
        if not parent.exists():
            parent.mkdir(parents=True, exist_ok=True)

        out = table if isinstance(table, Table) else Table(table)

        meta = out.meta
        if "description" in meta:
            # NOTE(review): presumably shortened because FITS header keywords
            # are limited to 8 characters -- confirm.
            meta["descript"] = meta.pop("description")

        out.write(filename, overwrite=force)

    @staticmethod
    def load(filename):
        """Return the table stored in ``filename``, or ``None`` if it is missing.

        Bytestring columns are converted to unicode before returning.
        """
        if not filename.exists():
            return None
        loaded = Table.read(filename)
        loaded.convert_bytestring_to_unicode()
        return loaded


class HDFCache:
    """Cache backend that stores a table in an HDF5 file through PyTables.

    The file holds two tables under the root group: ``data`` with the values
    and ``mask`` with the masks of the columns that expose one.
    """

    @staticmethod
    def save(table, filename, force=False):
        """Write ``table`` (and its column masks) to the HDF5 file ``filename``.

        Parameters
        ----------
        table
            A ``Table`` (converted with ``as_array()``) or an object that
            already offers ``dtype.names`` and indexing by column name.
        filename
            Path-like object with ``exists()`` and ``parent``.
        force
            Overwrite an existing file instead of raising.

        Raises
        ------
        FileExistsError
            If ``filename`` exists and ``force`` is false.
        """
        if filename.exists() and not force:
            # FileExistsError is still caught by ``except Exception`` handlers.
            raise FileExistsError(f"File {filename} already exists.")
        # exist_ok makes a separate existence check unnecessary (and race-free).
        filename.parent.mkdir(parents=True, exist_ok=True)
        if isinstance(table, Table):
            table = table.as_array()

        # Collect the masks of the columns that have one; they are stored
        # next to the data so that load() can rebuild masked columns.
        mask = Table()
        for col in table.dtype.names:
            column = table[col]
            if hasattr(column, "mask"):
                mask[col] = column.mask
        with tables.open_file(filename, "w") as h5:
            h5.create_table("/", "data", table)
            h5.create_table("/", "mask", mask.as_array())

    @staticmethod
    def load(filename):
        """Return the table stored in ``filename``, or ``None`` if it is missing.

        Every column listed in the stored ``mask`` table is returned as a
        ``MaskedColumn`` carrying that mask.
        """
        if not filename.exists():
            return None
        with tables.open_file(filename) as h5:
            table = Table(h5.root.data[:])
            mask = h5.root.mask[:]
        for col in mask.dtype.names:
            table[col] = MaskedColumn(table[col], mask=mask[col])
        return table


class PickleCache:
    """Cache backend that stores an object in a pickle file."""

    @staticmethod
    def save(table, filename, force=False):
        """Pickle ``table`` into ``filename``.

        A ``Table`` is converted with ``as_array()`` before pickling, so
        ``load`` gives back that array rather than a ``Table``. Any other
        object is pickled as is. Missing parent directories are created.

        Raises
        ------
        FileExistsError
            If ``filename`` exists and ``force`` is false.
        """
        if filename.exists() and not force:
            # FileExistsError is still caught by ``except Exception`` handlers.
            raise FileExistsError(f"File {filename} already exists.")
        # exist_ok makes a separate existence check unnecessary (and race-free).
        filename.parent.mkdir(parents=True, exist_ok=True)
        if isinstance(table, Table):
            table = table.as_array()
        with open(filename, "wb") as fh:
            pickle.dump(table, fh)

    @staticmethod
    def load(filename):
        """Return the unpickled content of ``filename``, or ``None`` if missing."""
        if not filename.exists():
            return None
        with open(filename, "rb") as fh:
            # NOTE(review): unpickling can execute arbitrary code; only point
            # this at cache files written by this code base.
            return pickle.load(fh)


# Registry mapping a file name's suffix (all suffixes joined, e.g. ".fits.gz")
# to the cache class used to save and load files of that format.
CACHES = {
    **dict.fromkeys((".fits.gz", ".fits", ".csv"), TableCache),
    ".h5": HDFCache,
    ".pkl": PickleCache,
}


class Cache:
    """Front end that dispatches save/load to the backend registered in ``CACHES``."""

    @staticmethod
    def _as_paths(filenames):
        """Turn one file name, or a list/tuple of them, into a list of ``Path``."""
        if not isinstance(filenames, (list, tuple)):
            filenames = (filenames,)
        return [Path(filename) for filename in filenames]

    @staticmethod
    def _backend_for(filename):
        """Return the ``CACHES`` entry matching the suffix of ``filename``.

        The full joined suffix is tried first; if it is not registered,
        leading suffixes are dropped one at a time, so that a name with extra
        dots such as ``catalog_1.8.h5`` still resolves to the ``.h5`` backend.

        Raises
        ------
        KeyError
            If no tail of the suffix is registered.
        """
        suffixes = filename.suffixes
        for start in range(len(suffixes)):
            key = "".join(suffixes[start:])
            if key in CACHES:
                return CACHES[key]
        raise KeyError("".join(suffixes))

    @staticmethod
    def save(table, filenames, force=False):
        """Save one result per file name using the backend for each suffix.

        Parameters
        ----------
        table
            A single result, or a list/tuple of results when several file
            names are given. Anything that is not a list or tuple (a
            ``Table``, an array, ...) counts as one single result.
        filenames
            One file name or a list/tuple of file names.
        force
            Passed on to each backend's ``save``.

        Raises
        ------
        AssertionError
            If the number of results differs from the number of file names.
        KeyError
            If a file suffix has no registered backend.
        """
        paths = Cache._as_paths(filenames)
        backends = [Cache._backend_for(path) for path in paths]
        if not isinstance(table, (list, tuple)):
            table = [table]
        if len(table) != len(paths):
            # Raised explicitly so the check survives ``python -O``, which
            # strips assert statements.
            raise AssertionError(f"Expected {len(paths)} tables, got {len(table)}.")
        for res, backend, path in zip(table, backends, paths):
            backend.save(res, path, force)

    @staticmethod
    def load(filenames):
        """Load the cached result(s) stored in ``filenames``.

        Returns ``None`` if any of the files could not be loaded, the single
        result when exactly one file name was given, and otherwise the list
        of results in file-name order.
        """
        paths = Cache._as_paths(filenames)
        backends = [Cache._backend_for(path) for path in paths]
        result = [backend.load(path) for backend, path in zip(backends, paths)]
        if any(res is None for res in result):
            # one invalid entry invalidates all the others, since they come from the same function
            return None
        if len(result) == 1:
            return result[0]
        return result


def cached(name=None, force=False):
    """Decorator to cache the result of a function in a [list of] file[s].

    Parameters
    ----------
    name
        Key into ``agasc_gaia.config.FILES`` giving the cache file(s) to use
        when a call does not pass ``filenames``.
    force
        Default for the wrapper's ``force`` keyword.

    The decorated function accepts two extra keyword arguments that are
    consumed by the wrapper and not forwarded:

    ``filenames``
        File name, or list/tuple of file names, overriding ``FILES[name]``.
    ``force``
        If true, the cache is not read: the function is always called and
        its result is saved with ``force`` passed on to ``Cache.save``.

    A cached result is returned when ``Cache.load`` yields something other
    than ``None``; otherwise the function is called and its result saved.
    """

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, filenames=None, force=force, **kwargs):
            # NOTE(review): imported at call time rather than at module level,
            # presumably to avoid a circular import -- confirm.
            from agasc_gaia.config import FILES

            if filenames is None:
                filenames = FILES[name]
            result = None
            if not force:
                result = Cache.load(filenames)
            if result is None:
                result = func(*args, **kwargs)
                Cache.save(result, filenames, force)
            return result

        return wrapper

    return decorator
class Files:
    """Mapping-like view of file names that prepends a base directory on access.

    Parameters
    ----------
    data_dir
        Base directory; it must support the ``/`` operator (e.g. a ``Path``).
    files
        Mapping from key to a file name, or to a list/tuple of file names.
    """

    def __init__(self, data_dir, files):
        self.data_dir = data_dir
        self.files = files

    def _prepend_path_(self, value):
        """Return ``value`` with ``data_dir`` prepended.

        A list or tuple is returned as a list of prefixed entries.
        """
        if isinstance(value, (list, tuple)):
            return [self.data_dir / file for file in value]
        return self.data_dir / value

    def keys(self):
        """Return the keys of the underlying mapping."""
        return self.files.keys()

    def values(self):
        """Yield every value with ``data_dir`` prepended."""
        for value in self.files.values():
            yield self._prepend_path_(value)

    def items(self):
        """Yield ``(key, value)`` pairs with ``data_dir`` prepended to the value."""
        for key, value in self.files.items():
            yield key, self._prepend_path_(value)

    def __len__(self):
        return len(self.files)

    def __iter__(self):
        return iter(self.files)

    def __repr__(self):
        # dict(self) goes through keys() and __getitem__, so paths are prefixed.
        return repr(dict(self))

    def __str__(self):
        return str(dict(self))

    def __getitem__(self, key):
        return self._prepend_path_(self.files[key])