#!/usr/bin/env python

# Copyright (C) 2022,2025 MIT
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
#

'''
Implementation notes
--------------------

The implementation is done in two classes instead of functions to make it
easier to derive new variants. In principle, the machinery in the
Kashyap et al. (2010) paper is even more general than the implementation
here. It is still possible that, at a later time, we will want to derive
implementations that are more specialized than what is currently in this
code, so it does not hurt to implement this as classes.

This code uses a number of numeric functions and processes, all of which are
available in scipy. Since the ciao_contrib cannot depend on scipy at this time
they are taken from different sources:

- Gamma function and incomplete gamma function as available from Sherpa as a
  thin wrapper around an underlying C implementation.
- bisection is available from Sherpa.
- Numerical integration is implemented in the most naive form here, but while
  not ideal we have enough knowledge on the shape of the integrated function
  to simply evaluate and sum up in Python. An alternative would be to use
  ctypes to connect to the numerical integration in the GSL.
- The PDF of the gamma distribution
  https://en.wikipedia.org/wiki/Gamma_distribution
  (not to be confused with the Gamma function) is implemented here in Python
  following the ideas in scipy's C code.
'''

import os
import sys

import ciao_contrib.logger_wrapper as lw

import numpy as np
from sherpa.utils import igam, lgam, bisection

# Tool name: used below to initialize the logger, to create the verbosity
# helpers, to read the parameters in main() and in the error-handling
# decorator wrapped around main().
TOOLNAME = "aplimits"
# Revision string that is passed to the error-handling decorator of main().
__REVISION__ = "21 August 2025"

# Set up the logger for this tool and create one message function per
# verbosity level. v1, v2 and v3 are called with a single string in this file.
# NOTE(review): presumably a message is only emitted when the verbosity is at
# least the given level - confirm against ciao_contrib.logger_wrapper.
lw.initialize_logger(TOOLNAME)
v1 = lw.make_verbose_level(TOOLNAME, 1)
v2 = lw.make_verbose_level(TOOLNAME, 2)
v3 = lw.make_verbose_level(TOOLNAME, 3)

#
# Numerical routines
#
def xlogy(x, y):
    '''Return x * log(y) in a numerically robust way

    The simple `x * np.log(y)` may fail if x and/or y are
    0 due to the way that NaNs are propagated. This little
    helper function treats those special cases.

    Following the idea of the Cython implementation in scipy.

    Parameters
    ----------
    x : float
        Multiplier (scalar).
    y : float
        Argument of the logarithm (scalar).

    Returns
    -------
    result : float or int
        ``0`` if x is 0 and y is not NaN; ``-inf`` (x > 0) or ``+inf``
        (x < 0) if x is finite and y is 0; ``x * np.log(y)`` otherwise.
    '''
    if x == 0 and not np.isnan(y):
        return 0
    if np.isfinite(x) and y == 0:
        # Numpy correctly evaluates x * log(0), but raises
        # RuntimeWarning: divide by zero encountered in log
        # To avoid this warning, hardcode the answer here.
        # x cannot be 0 in this branch (handled above), and log(0) = -inf,
        # so the sign of the result is the opposite of the sign of x.
        return -np.inf if x > 0 else np.inf
    return x * np.log(y)

def gamma_pdf(x, alpha, beta):
    '''Probability-density function for the Gamma distribution

    (Not to be confused with the related Gamma function.)

    The value is computed as
    ``beta * exp((alpha - 1) * log(beta * x) - beta * x - lgam(alpha))``,
    i.e. `beta` multiplies `x` and therefore acts as a rate (inverse scale).

    Parameters
    ----------
    x : float
        Point at which the density is evaluated (scalar).
    alpha : float
        Shape parameter.
    beta : float
        Rate parameter.

    Returns
    -------
    pdf : float
        Value of the density at `x`.
    '''
    # Work with the dimensionless variable beta * x (the standard form with
    # unit rate); the factor beta is put back in at the very end.
    scaled = x * beta

    # Evaluate the logarithm of the density first and exponentiate last
    # to avoid overflow - following ideas of the scipy implementation.
    log_density = xlogy(alpha - 1.0, scaled) - scaled - lgam(alpha)

    return np.exp(log_density) * beta

class APLimits():
    '''Calculate upper limit for a detection following Kashyap et al. (2010)

    In this class the background intensity `lamb` is a fixed number that is
    stored on the instance.

    Parameters
    ----------
    lamb : float
        background intensity in the source aperture in units of counts / time
    taus : float
        Exposure time in source region
    taub : float
        Exposure time in background region. Since lamb is already normalized to the exposure time,
        this setting has no effect in this class.
    r : float
        Ratio of background to source area. Since lamb is already normalized to the area,
        this setting has no effect in this class.
    maxfev : int
        maximum number of function evaluations in bisection

    Raises
    ------
    ValueError
        If `lamb` is negative or `taus` is not positive.
    '''

    def __init__(self, lamb, taus=1, taub=1, r=1, maxfev=500):
        # Reject inputs for which the detection probability is not defined.
        # With taus <= 0 the expected number of counts never grows with the
        # source intensity and the search for a bracketing interval in
        # __call__ cannot succeed.
        if lamb < 0:
            raise ValueError('Background intensity cannot be negative.')
        if taus <= 0:
            raise ValueError('Exposure time in the source region must be >0.')
        self.lamb = lamb
        self.taus = taus
        self.taub = taub
        self.r = r
        self.maxfev = maxfev

    def find_sstar(self, alpha):
        '''This is essentially evaluating eqn 4 to the value for Sstar

        Parameters
        ----------
        alpha : float
            Probability threshold.

        Returns
        -------
        sstar : int
            Smallest non-negative integer for which
            ``self.beta(lams=0, sstar=sstar) <= alpha``.
        '''
        i = 0
        while self.beta(lams=0, sstar=i) > alpha:
            i += 1
        return i

    def beta(self, lams, sstar, lamb=None):
        '''Probability of a detection, i.e. one minus the probability of a Type II error

        In the derived class, lamb is not a class attribute, but is a
        variable that is integrated over. In that case, we need to pass
        in lamb.

        In this class, lamb is an attribute of the class and not
        passed in.

        In Kashyap et al. (2010) this is eqn 5 implemented for a Poisson
        case. The equation in footnote 13 in that paper (about 2/3 down -
        equations within a footnote are not numbered)
        can be obtained as `beta(0, sstar, lamb)` and the general case is
        equivalent to the formula in footnote 14 (bottom of footnote).

        Parameters
        ----------
        lams : float
            Source intensity
        sstar : int
            Detection threshold value
        lamb: float or None
            Background intensity (if None, `self.lamb` is used)

        Returns
        -------
        beta : float
            ``igam(sstar + 1, (lamb + lams) * self.taus)``
        '''
        if lamb is None:
            lamb = self.lamb
        return igam(sstar + 1, (lamb + lams) * self.taus)

    def rootfunc(self, lams, sstar, betalim=.5):
        '''Return ``self.beta(lams, sstar) - betalim``

        The root of this function in `lams` is the upper limit
        that is searched for in `__call__`.
        '''
        return self.beta(lams, sstar) - betalim

    def __call__(self, alpha, beta):
        '''Calculate the detection threshold and the upper limit

        Parameters
        ----------
        alpha : float
            Probability used to set the detection threshold, see `find_sstar`.
            Must be between 0 and 1 (exclusive).
        beta : float
            Value that `self.beta(lams, sstar)` has to reach at the upper limit
            `lams`. Must be between 0 and 1 (exclusive).

        Returns
        -------
        upper_limit : float
            Upper limit
        sstar : int
            Minimum number of counts needed to claim a detection

        Raises
        ------
        ValueError
            If `alpha` or `beta` is not between 0 and 1 (exclusive).
        '''
        if alpha <= 0 or alpha >= 1:
            raise ValueError('alpha is a probability, which must be between 0 and 1.')
        if beta <= 0 or beta >= 1:
            raise ValueError('beta is a probability, which must be between 0 and 1.')
        sstar = self.find_sstar(alpha)

        # Note on using sstar for upper guess
        # for extreme values (e.g. alpha=.9, beta=.1, lamb=3) it can happen that
        # even a source count rate of 0 is already too large.
        # In practice, that won't happen (alpha is typically 0.1 or 0.01)
        # but we want the algorithm to work for every case.
        # The threshold depends only on alpha and the background, so the
        # value found above is returned and not a hard-coded number.
        if self.rootfunc(0, sstar, beta) > 0:
            return 0, sstar

        # We use a bisect method for root finding,
        # so we first need to establish a bracketing interval.
        # Grow the upper end by factors of 10 until rootfunc is no longer negative.
        upper_bracket = sstar if sstar > 0 else 1
        while self.rootfunc(upper_bracket, sstar, beta) < 0:
            upper_bracket *= 10
        out = bisection(self.rootfunc, 0, upper_bracket, args=(sstar, beta),
                        maxfev=self.maxfev)

        # out[0] is used as (root, function value at the root); warn if the
        # function value is not close to zero.
        if np.abs(out[0][1]) > 1e-6:
            v1('Numerical root finding did not converge.')
        return out[0][0], sstar

class APLimitsMarginalize(APLimits):
    '''Calculate upper limit for a detection following Kashyap et al. (2010)

    In this class the background intensity is not known. The detection
    probability of the parent class is averaged over the posterior
    distribution of the background intensity given `nb` background counts.

    Parameters
    ----------
    nb : int
        number of counts in background region
    taus : float
        Exposure time in source region
    taub : float
        Exposure time in background region.
    r : float
        Ratio of background to source area.
    maxfev : int
        maximum number of function evaluations in bisection

    Raises
    ------
    ValueError
        If `nb` is not a positive integer number or if `taus`, `taub` or `r`
        are not positive.
    '''
    # super().__init__ has lamb instead of "nb"
    # but the rest of the signature is the same.
    # The attributes are set here directly; super().__init__ is not called
    # and `self.lamb` is never set in this class.
    def __init__(self, nb, taus=1, taub=1, r=1, maxfev=500):
        if nb < 0:
            raise ValueError('Number of counts cannot be negative.')
        if nb == 0:
            raise ValueError('nb=0 does not constrain the background count rate. ' +
                             'Thus, the posterior of the count rate cannot be normalized. ' +
                             'At least one background count is required for this algorithm.')
        if not float(nb).is_integer():
            raise ValueError('Number of counts in the background must be an integer number.')
        if (taus <= 0) or (taub <= 0):
            raise ValueError('Exposure time must to be >0.')
        if r <= 0:
            raise ValueError('The ratio of background to source area must be positive.')
        self.nb = nb
        self.taus = taus
        self.taub = taub
        self.r = r
        self.maxfev = maxfev

    def lamb_posterior(self, lamb):
        '''Estimate posterior distribution for lamb_b

        Need to take care to use right scaling for input output for r and
        taub/taus
        Gamma is the conjugate prior for Poisson, so we can get the posterior
        analytically which is nice because it is fast and there is no worry
        about numerical stability.

        We are using an uninformative prior with gamma(1, 0). This represents
        the information from an observation with 0 counts in 0 area
        (following what is done for aprates:
        https://cxc.harvard.edu/csc/memos/files/Kashyap_xraysrc.pdf)
        This prior is "improper" - it cannot be normalized. However, the
        gamma function is the conjugate prior for our problem and we can then
        use the information passed in about the background to analytically
        calculate the posterior - and that is normalizable and useful.

        Parameters
        ----------
        lamb : float
            Background intensity at which the posterior is evaluated.

        Returns
        -------
        pdf : float
            ``gamma_pdf(lamb, 1 + self.nb, self.taub * self.r)``
        '''
        return gamma_pdf(lamb, 1 + self.nb, self.taub * self.r)

    def integrant(self, lamb, lams, sstar_trial):
        '''Function that is integrated over `lamb` in `integrate`

        This is the detection probability of the parent class for a fixed
        background intensity `lamb`, weighted with the posterior probability
        density of that `lamb`.
        '''
        p_of_lamb = self.lamb_posterior(lamb)
        beta = super().beta(lams=lams, sstar=sstar_trial, lamb=lamb)
        return beta * p_of_lamb

    def integrate(self, lams, sstar):
        '''Integrate to marginalize over lamb

        We need to integrate over ``lamb`` from 0 to inf.
        The easiest way to implement that is to pass this to a
        numerical library with adaptive bin size.
        However, we don't have such a library available at the Python level.
        So, choice is to use ctypes to get numerical integration functions from
        the GSL, or to implement a simple integration here by hand.
        Fortunately, we know a lot about the shape of the function to be integrated.
        In particular, we know the location of the peak and know the
        scale (the width) of the Gamma distribution in `lamb_posterior`.

        We also know that no great precision is needed, because this function
        will only be invoked for low-count statistics. Last, we know that
        executing the ``self.integrant`` function is relatively cheap, so we can
        afford to evaluate it at more points instead of programming
        more efficient schemes like Gauss-Kronrod or similar.

        Parameters
        ----------
        lams : float
            Source intensity
        sstar : int
            Detection threshold value

        Returns
        -------
        integral : float
            Trapezoidal-rule estimate of the integral of `integrant` over `lamb`.
        '''
        # The posterior in lamb_posterior is a Gamma distribution with
        # shape = nb + 1 and rate = r * taub. See wikipedia for the formulas:
        #   mode     = (shape - 1) / rate = nb / rate
        #   variance = shape / rate**2    = (nb + 1) / rate**2
        # The width has to be the square root of the variance in the same
        # units as lamb, otherwise the grid below is too narrow or too wide
        # whenever the rate differs from 1.
        rate = self.r * self.taub
        loc_peak = self.nb / rate
        width = np.sqrt(self.nb + 1) / rate
        # Numerically integrate by using a regular grid from
        # width = sqrt(variance)
        # -10 * std to + 20 * std
        # It's asymmetric because the function is for small numbers.
        # choice of 1000 points is made by looking at functions
        # for typical input values and checking how many point we need to get
        # reasonable precision.
        x = np.linspace(max(0, loc_peak - 10 * width), loc_peak + 20 * width, 1000)
        # For small numbers of background counts the peak is close to 0
        # compared to the width. In that case, we make sure that the region
        # between 0 and the peak is well sampled by adding a second, finer grid.
        # Note: Even in an alternative implementation that uses an adaptive
        #       integrator it is still necessary to ensure that that region
        #       is sampled well.
        if loc_peak < 3 * width:
            x_around_peak = np.linspace(0, 20 * loc_peak, 1000, endpoint=False)
            x = np.concatenate([x, x_around_peak])
            # The trapezoidal rule needs the sample points in increasing order.
            x.sort()
        # igam does not accept numpy arrays, so need to call one by one and put in list
        y = [self.integrant(xp, lams, sstar) for xp in x]

        if hasattr(np, 'trapezoid'):
            # Numpy 2.x
            return np.trapezoid(y, x)
        elif hasattr(np, 'trapz'):
            # Numpy 1.x
            return np.trapz(y, x)
        else:
            raise RuntimeError("Unsupported version of Numpy")

    def find_sstar(self, alpha):
        '''Return the smallest non-negative integer sstar with ``self.integrate(0, sstar) <= alpha``'''
        i = 0
        while self.integrate(0, i) > alpha:
            i += 1
        return i

    def beta(self, lams, sstar):
        '''Probability of a detection, marginalized over the background intensity

        Unlike in the parent class, no background intensity can be passed in,
        because it is integrated over in `integrate`.
        '''
        return self.integrate(lams, sstar)


def aplimits(prob_false_detection=.1, prob_missed_detection=.5, T_s=1, A_s=1,
             bkg_rate=None, m=None, A_b=1, T_b=1, max_counts=50,
             maxfev=500,
            ):
    '''Calculate an upper limit on a count rate

    This function provides an interface between the two classes above and the
    functional form that is more suitable for ciao_contrib scripts. It also
    translates from an interface that is good for CIAO and is modeled after
    the similar `aprates` to the definitions used in the Kashyap et al. (2010)
    paper. For example, in this ciao_contrib script, we call the main
    parameters prob_false_detection and prob_missed_detection, while Kashyap et al. (2010) call
    them alpha and beta, which might be confusing because aprates uses
    alpha and beta for different concepts. Similarly, Kashyap et al. (2010)
    only have an area ratio "r", while in CIAO we measure source and
    background areas separately.

    Parameters
    ----------
    prob_false_detection : float
        Probability that a background fluctuation
        causes a count rate above the threshold used for the upper limit.
    prob_missed_detection : float
        Upper limit on the probability that a true
        source of a given intensity is missed because it gives a count rate
        below the detection threshold.
    T_s : float
        Exposure time in source aperture. May be set to 1 if computation of
        source rate is not desired.
        Units: sec
    A_s : float
        Geometric area of source aperture. Either square arcsec or pixels are
        allowed, as long as the units agree with those of A_b.
        Units: pixels or arcsec^2
    bkg_rate : float
        Known background rate. The flux is given per area and time.
        The unit of the area must match the units used for A_s and A_b.
        Units: photons / (pixel sec) or photons / (arcsec^2 sec)
    m : integer
        Number of counts in background aperture
    A_b : float
        Geometric area of background aperture. Either square arcsec or pixels
        are allowed, as long as the units agree with those of A_s.
        Units: pixels or arcsec^2
    T_b : float
        Exposure time in background aperture. May be set to 1 if computation
        of background rate is not desired.
        Units: sec
    max_counts : integer
        Max total counts before switching from Bayesian approach to assuming
        that the background rate is known exactly, even if the input is given
        as background counts and aperture area.

        If the number of background counts is large, the Bayesian posterior
        becomes sharply peaked. In this case the algorithms for the numerical
        integration may run very long or even fail while result is only
        negligibly different from those obtained assuming that the background
        is known exactly.

        Comparison of both methods indicates that a value of max_counts=50
        (the default) is a good choice for a wide range of input parameters.
    maxfev : int
        Maximal number of function evaluations in bisection algorithm to find
        numerical solution

    Returns
    -------
    lambs : float
        Upper limit for the source intensity (in cts/time), integrated over the
        source aperture.
    sstar : int
        Minimum number of counts needed to claim a detection

    Raises
    ------
    ValueError
        If one of the probabilities is not between 0 and 1 (exclusive), if
        neither `bkg_rate` nor `m` is given, if `T_s` or `A_s` is not
        positive, if `bkg_rate` is negative, or if `A_b` or `T_b` is not
        positive when the background is taken from `m`.

    Note
    ----
    The probabilities for Type I and Type II errors are called alpha and beta
    in the Kashyap et al. (2010) paper but the same symbols are used with a
    different meaning in aprates. To avoid confusion, the parameters of this
    function are named prob_false_detection and prob_missed_detection.

    NOTE(review): `prob_missed_detection` is handed unchanged to the `beta`
    argument of the limit finder, which the classes describe as the
    probability of a detection. The two agree for the default of 0.5 -
    confirm the intended convention for other values.
    '''
    if prob_false_detection <= 0 or prob_false_detection >= 1:
        raise ValueError('prob_false_detection is a probability, which must be between 0 and 1.')
    if prob_missed_detection <= 0 or prob_missed_detection >= 1:
        raise ValueError('prob_missed_detection is a probability, which must be between 0 and 1.')
    if bkg_rate is None and m is None:
        raise ValueError('Either bkg_rate or m have to be given.')
    if bkg_rate is not None and m is not None:
        v1('bkg_rate and m are given. Using bkg_rate for the computation and ignoring m.')

    # Reject unphysical geometry and rates up front. Otherwise the failure
    # shows up later as a division by zero (A_s, A_b or T_b equal to 0) or
    # inside the numerical routines (T_s <= 0, negative rates).
    if T_s <= 0:
        raise ValueError('T_s (exposure time in source aperture) must be positive.')
    if A_s <= 0:
        raise ValueError('A_s (area of source aperture) must be positive.')
    if bkg_rate is not None:
        if bkg_rate < 0:
            raise ValueError('bkg_rate cannot be negative.')
    elif A_b <= 0 or T_b <= 0:
        # A_b and T_b only enter the calculation if the background is
        # derived from the number of counts m.
        raise ValueError('A_b and T_b must be positive if the background is given as counts (m).')

    # Many background counts: treat the rate measured in the background
    # aperture as exactly known instead of marginalizing over it.
    if bkg_rate is None and m >= max_counts:
        bkg_rate = m / A_b / T_b
    if bkg_rate is not None:
        v2(f'background rate known: {bkg_rate}')
        # bkg_rate is per area; APLimits expects the rate in the source aperture.
        limitfinder = APLimits(lamb=bkg_rate * A_s, taus=T_s, taub=T_b, r=A_b / A_s,
                               maxfev=maxfev)
    else:
        v2('background unknown. Marginalizing over background rate.')
        v2(f'bkg counts {m} in {T_b} exposure time and {A_b} area.')
        limitfinder = APLimitsMarginalize(nb=m, taus=T_s, taub=T_b, r=A_b / A_s,
                                          maxfev=maxfev)
    ul, s_star = limitfinder(prob_false_detection, prob_missed_detection)
    return ul, s_star


#
# Main Routine
#
@lw.handle_ciao_errors(TOOLNAME, __REVISION__)
def main():
    '''Command-line entry point of the tool

    Reads the tool parameters with `get_params`, runs `aplimits`, reports
    the detection threshold and the upper limit at verbosity level 1 and
    writes both values to the file named by the ``outfile`` parameter.

    Raises
    ------
    IOError
        If ``outfile`` exists and the ``clobber`` parameter is ``"no"``.
    '''
    from ciao_contrib.param_soaker import get_params

    # Read the parameters of the tool
    pars = get_params(TOOLNAME, "rw", sys.argv,
                      verbose={"set": lw.set_verbosity, "cmd": v1})

    def optional(key, convert):
        '''Convert parameter `key`, mapping the string INDEF to None'''
        raw = pars[key]
        return None if raw == 'INDEF' else convert(raw)

    # All parameter values are converted from their raw form here,
    # in the same order in which they are passed on.
    settings = {
        "prob_false_detection": float(pars["prob_false_detection"]),
        "prob_missed_detection": float(pars["prob_missed_detection"]),
        "bkg_rate": optional("bkg_rate", float),
        "T_s": float(pars["T_s"]),
        "A_s": float(pars["A_s"]),
        "m": optional("m", int),
        "T_b": float(pars["T_b"]),
        "A_b": float(pars["A_b"]),
        "max_counts": int(pars["max_counts"]),
        "maxfev": int(pars["maxfev"]),
    }
    ul, mincounts = aplimits(**settings)

    v1(f"Minimum number of counts for detection: {mincounts} counts")
    v1(f"Upper limit: {ul}")

    v3("Clobber check")
    clobber_forbidden = pars["clobber"] == "no"
    if clobber_forbidden and os.path.exists(pars["outfile"]):
        raise IOError(f"outfile={pars['outfile']} exists and clobber=no.")

    # One comma-separated line per result.
    # NOTE(review): the layout looks like lines of a CIAO parameter file
    # (name, type, mode, value, min, max, prompt) - confirm.
    result_lines = (
        f'min_counts_detect,i,h,{mincounts},,,"Detection threshold"\n',
        f'upper_limit,r,h,{ul},,,"Upper limit"\n',
    )
    with open(pars["outfile"], "w") as fp:
        for line in result_lines:
            fp.write(line)

# Run the tool only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
