# Copyright (C) 2025 the astropix team.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
"""Histogram facilities.
"""
from abc import ABC, abstractmethod
import math
from numbers import Number
import numpy as np
from astropix_analysis.plt_ import plt, setup_axes, matplotlib
[docs]
class RunningStats:

    """Online mean and standard deviation using Welford's algorithm.

    The object keeps track of three quantities only: the sample size ``n``,
    the running mean and the running sum of squared deviations from the
    mean (``M2``). This is enough to update the statistics one value at a
    time, or to merge two independent accumulators, without storing the
    underlying data.

    Arguments
    ---------
    n : int
        The initial sample size.

    mean : float
        The initial mean.

    M2 : float
        The initial sum of squared deviations from the mean.
    """

    def __init__(self, n: int = 0, mean: float = 0., M2: float = 0.) -> None:
        """Constructor.
        """
        # pylint: disable=invalid-name
        self._n = int(n)
        self._mean = float(mean)
        self._M2 = float(M2)

    @classmethod
    def from_sample(cls, x: np.ndarray) -> 'RunningStats':
        """Create a RunningStats object from a data sample.

        This is much more efficient than processing the numbers in the sample
        one by one because it is transforming the input sample into a numpy array
        and doing the calculation of the sample average and variance in C.

        Arguments
        ---------
        x : array_like
            The input sample. This is generally a 1-dimensional numpy array.
            Lists, tuples and iterables in general will be converted to a numpy
            array if possible. Multidimensional arrays will be flattened.

        Returns
        -------
        RunningStats
            A new object summarizing the sample; an empty sample yields an
            empty (default-constructed) object.
        """
        x = np.asarray(x).ravel()
        # The mean and variance of an empty array are nan: storing them would
        # poison any subsequent update() call, so we return a pristine
        # accumulator, instead.
        if x.size == 0:
            return cls()
        # x.var() is the biased (ddof=0) variance, so that multiplying by the
        # sample size gives back the sum of squared deviations from the mean.
        return cls(x.size, x.mean(), x.var() * x.size)

    def merge(self, other: 'RunningStats') -> 'RunningStats':
        """Merge another RunningStat object into the current one.

        Note this happens in place, and no new object is created. A reference
        to the original object is returned, so that multiple calls can be
        conveniently chained.

        Arguments
        ---------
        other : RunningStats
            The object to be merged into the current one.

        Raises
        ------
        TypeError
            If ``other`` is not a RunningStats instance.
        """
        # pylint: disable=protected-access
        # This will only work if other is another RunningStats object,
        if not isinstance(other, RunningStats):
            raise TypeError(f'{other} is not a RunningStats object')
        # If the other running stats is empty, there is nothing to do.
        if len(other) == 0:
            return self
        # If the original object is empty, then we just copy the other over.
        if len(self) == 0:
            self._n, self._mean, self._M2 = other.n, other.mean, other._M2
            return self
        # And if we made it to this point, we need the fully-fledged combination.
        # Note that, from this line on, self._n is the *combined* sample size,
        # and the original one is recovered as self._n - other.n.
        self._n += other.n
        delta = other.mean - self._mean
        self._mean += delta * other.n / self._n
        self._M2 += other._M2 + delta * delta * (self._n - other.n) * other.n / self._n
        return self

    def update(self, x: np.ndarray) -> None:
        """Update the running statistics with more data.

        Arguments
        ---------
        x : number or array_like
            Either a single number, which is processed with the plain Welford
            update, or a sample, which is summarized via ``from_sample()`` and
            merged into the current object.
        """
        if isinstance(x, Number):
            # Good old plain Welford formula. Note this is exactly the same as
            # the merge formula, with other.n = 1 and other.mean = x.
            self._n += 1
            delta = x - self._mean
            self._mean += delta / self._n
            self._M2 += delta * delta * (self._n - 1) / self._n
        else:
            self.merge(RunningStats.from_sample(x))

    def __len__(self) -> int:
        """Return the sample size.
        """
        return self._n

    @property
    def n(self) -> int:
        """Return the sample size.
        """
        return self._n

    @property
    def mean(self) -> float:
        """Return the current mean of the sample (nan for an empty sample).
        """
        return self._mean if self._n > 0 else float('nan')

    @property
    def variance(self) -> float:
        """Return the (unbiased) sample variance (nan for less than two values).
        """
        return self._M2 / (self._n - 1) if self._n > 1 else float('nan')

    @property
    def stdev(self) -> float:
        """Return the standard deviation of the sample (nan for less than two values).
        """
        return math.sqrt(self.variance) if self._n > 1 else float('nan')

    def __repr__(self):
        """String formatting.
        """
        return f'RunningStats(n={self._n}, mean={self.mean}, stdev={self.stdev})'
[docs]
class InvalidShapeError(RuntimeError):

    """Exception signaling that an array does not have the expected shape.

    Arguments
    ---------
    expected
        The shape that was expected.

    actual
        The shape that was actually found.
    """

    def __init__(self, expected, actual):
        """Constructor.
        """
        message = f'Invalid array shape: {expected} expected, got {actual}'
        super().__init__(message)
[docs]
class AbstractHistogram(ABC):

    """Base class for an n-dimensional weighted histogram.

    This interface to histograms is profoundly different for the minimal
    numpy/matplotlib approach, where histogramming methods return bare
    vectors of bin edges and counts.

    The object holds two arrays with the same shape: the bin content (i.e., the
    sum of the weights in each bin) and the sum of the squared weights, from
    which the errors on the bin content are calculated.

    Parameters
    ----------
    bin_edges : n-dimensional tuple of arrays
        the bin edges on the different axes.

    axis_labels : n-dimensional tuple of strings
        the text labels for the different axes.
    """

    # Default keyword arguments for the draw() method, to be overridden by
    # derived classes.
    PLOT_OPTIONS = {}

    def __init__(self, bin_edges: tuple, axis_labels: list) -> None:
        """Constructor.

        Raises
        ------
        RuntimeError
            If the number of axis labels is not the number of axes plus one.
        """
        # Quick check on the bin_edges and label tuples---we need N + 1 axis labels
        # for an N-dimensional histogram.
        if not len(axis_labels) == len(bin_edges) + 1:
            msg = f'Length mismatch between bin edges ({len(bin_edges)}) and '\
                f'axis_labels ({len(axis_labels)})'
            raise RuntimeError(msg)
        # The bin_edges is not supposed to change ever, so we make sure it is a tuple...
        self._bin_edges = tuple(bin_edges)
        # ...while the labels might conceivably be changed after the fact, hence a list.
        self._axis_labels = list(axis_labels)
        # Initialize all the relevant arrays. Note we cache the shape of all the
        # underlying arrays for future use; keep in mind there are N + 1 bin edges
        # for N bins.
        self._shape = tuple(len(edges) - 1 for edges in self._bin_edges)
        self._content = self._zeros()
        self._sumw2 = self._zeros()

    def _zeros(self, dtype: type = float) -> np.ndarray:
        """Return an array of zeros of the proper shape for the underlying
        histograms quantities.
        """
        return np.zeros(shape=self._shape, dtype=dtype)

    def _check_array_shape(self, data: np.array) -> None:
        """Check the shape of a given array used to update the histogram.

        Raises
        ------
        InvalidShapeError
            If the shape of the array differs from that of the histogram.
        """
        if data.shape != self._shape:
            raise InvalidShapeError(self._shape, data.shape)

    def reset(self) -> None:
        """Reset the histogram.

        Both the bin content and the sum of the squared weights are replaced
        with brand new arrays of zeros.
        """
        self._content = self._zeros()
        self._sumw2 = self._zeros()

    def bin_centers(self, axis: int = 0) -> np.array:
        """Return the bin centers for a specific axis.
        """
        return 0.5 * (self._bin_edges[axis][1:] + self._bin_edges[axis][:-1])

    def bin_widths(self, axis: int = 0) -> np.array:
        """Return the bin widths for a specific axis.
        """
        return np.diff(self._bin_edges[axis])

    def errors(self) -> np.array:
        """Return the errors on the bin content, i.e., the square root of the
        sum of the squared weights.
        """
        return np.sqrt(self._sumw2)

    def fill(self, *values, weights=None) -> 'AbstractHistogram':
        """Fill the histogram from unbinned data.

        Note this method is returning the histogram instance, so that the function
        call can be chained.

        Arguments
        ---------
        *values : array_like
            The unbinned data, one positional argument per histogram axis.

        weights : array_like, optional
            The weights for the input values. If None, all the weights are
            assumed to be 1.
        """
        # Stack the per-axis values into the (n_samples, n_axes) layout that
        # np.histogramdd() is expecting.
        values = np.vstack(values).T
        if weights is None:
            content, _ = np.histogramdd(values, bins=self._bin_edges)
            # With unit weights the sum of the squared weights is just the
            # number of entries.
            sumw2 = content
        else:
            # Make sure that the squaring downstream works for plain lists, too.
            weights = np.asarray(weights)
            content, _ = np.histogramdd(values, bins=self._bin_edges, weights=weights)
            sumw2, _ = np.histogramdd(values, bins=self._bin_edges, weights=weights**2.)
        self._content += content
        self._sumw2 += sumw2
        return self

    def set_content(self, content: np.array, errors: np.array = None):
        """Set the bin contents programmatically from binned data.

        Note this method is returning the histogram instance, so that the function
        call can be chained.

        Arguments
        ---------
        content : array
            The bin content. Note the array is stored as it is, with no copy
            being made.

        errors : array, optional
            The errors on the bin content. If None, the sum of the squared
            weights is left untouched.

        Raises
        ------
        InvalidShapeError
            If the shape of either array differs from that of the histogram.
        """
        self._check_array_shape(content)
        self._content = content
        if errors is not None:
            self.set_errors(errors)
        return self

    def set_errors(self, errors: np.array) -> None:
        """Set the proper value for the _sumw2 underlying array, given the
        errors on the bin content.

        Raises
        ------
        InvalidShapeError
            If the shape of the array differs from that of the histogram.
        """
        self._check_array_shape(errors)
        self._sumw2 = errors**2.

    @staticmethod
    def bisect(bin_edges: np.array, values: np.array, side: str = 'left') -> np.array:
        """Return the indices corresponding to a given array of values for a
        given bin_edges.

        This is a thin wrapper around np.searchsorted(). Note that, with the
        default ``side='left'``, a value which is exactly equal to one of the
        bin edges is assigned to the bin below that edge, and values below the
        first edge are mapped to -1.
        """
        return np.searchsorted(bin_edges, values, side) - 1

    def find_bin(self, *coords) -> tuple:
        """Find the bin corresponding to a given set of "physical" coordinates
        on the histogram axes.

        This returns a tuple of integer indices that can be used to address
        the histogram content.
        """
        return tuple(self.bisect(bin_edges, value) for bin_edges, value in
                     zip(self._bin_edges, coords))

    def find_bin_value(self, *coords) -> float:
        """Find the histogram content corresponding to a given set of "physical"
        coordinates on the histogram axes.
        """
        return self._content[self.find_bin(*coords)]

    def normalization(self, axis: int = None):
        """Return the sum of weights in the histogram.

        If ``axis`` is None the sum runs over all the bins, otherwise it is
        performed along the given axis only.
        """
        return self._content.sum(axis)

    def empty_copy(self):
        """Create an empty copy of a histogram.

        Note this assumes that the constructor of the concrete class takes the
        bin edges for all the axes, followed by all the axis labels, as
        positional arguments.
        """
        return self.__class__(*self._bin_edges, *self._axis_labels)

    def copy(self):
        """Create a full copy of a histogram.
        """
        hist = self.empty_copy()
        hist.set_content(self._content.copy(), self.errors())
        return hist

    def __add__(self, other):
        """Histogram addition.

        The bin contents are added, and the errors are summed in quadrature.
        """
        hist = self.empty_copy()
        hist.set_content(self._content + other._content, np.sqrt(self._sumw2 + other._sumw2))
        return hist

    def __sub__(self, other):
        """Histogram subtraction.

        The bin contents are subtracted, and the errors are summed in quadrature.
        """
        hist = self.empty_copy()
        hist.set_content(self._content - other._content, np.sqrt(self._sumw2 + other._sumw2))
        return hist

    def __mul__(self, value):
        """Histogram multiplication by a scalar.

        Both the bin contents and the errors are multiplied by the scalar.
        """
        hist = self.empty_copy()
        hist.set_content(self._content * value, self.errors() * value)
        return hist

    def __rmul__(self, value):
        """Histogram multiplication by a scalar.
        """
        return self.__mul__(value)

    @abstractmethod
    def _draw(self, axes, **kwargs) -> None:
        """No-op method, to be overloaded by derived classes.
        """

    def draw(self, axes=None, **kwargs) -> None:
        """Plot the histogram.

        Arguments
        ---------
        axes : matplotlib axes, optional
            The axes to draw the histogram on. If None, the current axes
            (``plt.gca()``) are used.

        **kwargs
            Keyword arguments passed to the ``_draw()`` hook. The class-level
            PLOT_OPTIONS are used as defaults for all the keys that are not
            explicitly passed.
        """
        if axes is None:
            axes = plt.gca()
        for key, value in self.PLOT_OPTIONS.items():
            kwargs.setdefault(key, value)
        self._draw(axes, **kwargs)
        setup_axes(axes, xlabel=self._axis_labels[0], ylabel=self._axis_labels[1])
[docs]
class Histogram1d(AbstractHistogram):

    """Concrete histogram with a single axis.

    Arguments
    ---------
    xbinning : array
        The bin edges on the x axis.

    xlabel : str
        The text label for the x axis.

    ylabel : str
        The text label for the y axis.
    """

    PLOT_OPTIONS = {'lw': 1.25, 'alpha': 0.4, 'histtype': 'stepfilled'}

    def __init__(self, xbinning: np.array, xlabel: str = '', ylabel: str = 'Entries/bin') -> None:
        """Constructor.
        """
        bin_edges = (xbinning, )
        axis_labels = [xlabel, ylabel]
        super().__init__(bin_edges, axis_labels)

    def _draw(self, axes, **kwargs) -> None:
        """Overloaded method.

        The binned data are drawn by histogramming the bin centers, using the
        bin content as the weights.
        """
        edges = self._bin_edges[0]
        centers = self.bin_centers(0)
        axes.hist(centers, edges, weights=self._content, **kwargs)
[docs]
class Histogram2d(AbstractHistogram):

    """A two-dimensional histogram.

    Arguments
    ---------
    xbinning : array
        The bin edges on the x axis.

    ybinning : array
        The bin edges on the y axis.

    xlabel : str
        The text label for the x axis.

    ylabel : str
        The text label for the y axis.

    zlabel : str
        The text label for the z axis (i.e., the color bar).
    """

    PLOT_OPTIONS = dict(cmap=plt.get_cmap('hot'))

    # pylint: disable=invalid-name
    def __init__(self, xbinning, ybinning, xlabel='', ylabel='', zlabel='Entries/bin'):
        """Constructor.
        """
        # pylint: disable=too-many-arguments
        super().__init__((xbinning, ybinning), [xlabel, ylabel, zlabel])
        # The color bar is created lazily the first time the histogram is drawn.
        self.color_bar = None

    def _update_color_bar(self, axes, image) -> None:
        """Update the color bar after a histogram re-draw.

        This is a little bit tricky, as by default the colorbar gets her own
        axes, and a call to plt.gca().cla() will not delete the color bar.
        This is a small utility function to draw the color bar the first time
        around, and then re-bind to the latest version of the data each time
        the histogram is re-drawn.
        """
        if self.color_bar is None:
            self.color_bar = plt.colorbar(image, ax=axes)
            if self._axis_labels[2] is not None:
                self.color_bar.set_label(self._axis_labels[2])
        else:
            self.color_bar.update_normal(image)

    def _draw(self, axes, logz=False, **kwargs):
        """Overloaded method.

        Arguments
        ---------
        axes : matplotlib axes
            The axes to draw the histogram on.

        logz : bool
            If True, use a logarithmic scale on the z axis.
        """
        # pylint: disable=arguments-differ
        # The binned data are drawn by histogramming the grid of the bin centers,
        # using the bin content as the weights. The content is transposed to
        # match the ordering of the flattened meshgrid.
        x, y = (v.flatten() for v in np.meshgrid(self.bin_centers(0), self.bin_centers(1)))
        bins = self._bin_edges
        w = self._content.T.flatten()
        if logz:
            # Hack for a deprecated functionality in matplotlib 3.3.0
            # Parameters norm and vmin/vmax should not be used simultaneously
            # If logz is requested, we intercept the bounds when creating the norm
            # and refrain from passing vmin/vmax downstream.
            vmin = kwargs.pop('vmin', None)
            vmax = kwargs.pop('vmax', None)
            kwargs.setdefault('norm', matplotlib.colors.LogNorm(vmin, vmax))
        _, _, _, image = axes.hist2d(x, y, bins, weights=w, **kwargs)
        self._update_color_bar(axes, image)

    def slice(self, bin_index: int, axis: int = 0):
        """Return a slice of the two-dimensional histogram along the given axis.

        Arguments
        ---------
        bin_index : int
            The index of the bin, on the axis orthogonal to the slice, at which
            the slice is taken.

        axis : int
            The axis the slice runs along: 0 for a slice along the x axis
            (at a fixed y bin) and 1 for a slice along the y axis (at a fixed
            x bin).

        Returns
        -------
        Histogram1d
            The one-dimensional histogram with the content and the errors of
            the slice.
        """
        # The bin index addresses the axis orthogonal to the one we are
        # slicing along.
        other_axis = (axis + 1) % 2
        hist = Histogram1d(self._bin_edges[axis], self._axis_labels[axis])
        content = np.take(self._content, bin_index, axis=other_axis)
        errors = np.sqrt(np.take(self._sumw2, bin_index, axis=other_axis))
        hist.set_content(content, errors)
        return hist

    def slices(self, axis: int = 0):
        """Return all the slices along a given axis.

        Note there is one slice for each bin on the axis orthogonal to the
        one we are slicing along.
        """
        num_slices = self._shape[(axis + 1) % 2]
        return tuple(self.slice(bin_index, axis) for bin_index in range(num_slices))

    def hslice(self, bin_index: int):
        """Return the horizontal slice for a given (y) bin.
        """
        return self.slice(bin_index, 0)

    def hslices(self):
        """Return a tuple of all the horizontal slices.
        """
        return self.slices(0)

    def hbisect(self, y: float):
        """Return the horizontal slice corresponding to a given y value.
        """
        return self.hslice(self.bisect(self._bin_edges[1], y))

    def vslice(self, bin_index):
        """Return the vertical slice for a given (x) bin.
        """
        return self.slice(bin_index, 1)

    def vslices(self):
        """Return a tuple of all the vertical slices.
        """
        return self.slices(1)

    def vbisect(self, x):
        """Return the vertical slice corresponding to a given x value.
        """
        return self.vslice(self.bisect(self._bin_edges[0], x))
[docs]
class Matrix2d(Histogram2d):

    """Specialized 2-dimensional histogram to display matrix-like data
    (e.g., hitmap in logical space).

    The bin edges are placed at half-integer values, so that each bin is
    centered on an integer column/row identifier.

    Arguments
    ---------
    num_cols : int
        The number of columns in the matrix.

    num_rows : int
        The number of rows in the matrix.

    xlabel : str
        The text label for the x axis.

    ylabel : str
        The text label for the y axis.

    zlabel : str
        The text label for the z axis (i.e., the color bar).
    """

    def __init__(self, num_cols: int, num_rows: int, xlabel='Column', ylabel='Row',
                 zlabel='Entries/bin') -> None:
        """Constructor.
        """
        # np.arange(-0.5, n) gives the n + 1 edges -0.5, 0.5, ..., n - 0.5.
        xedges = np.arange(-0.5, num_cols)
        yedges = np.arange(-0.5, num_rows)
        super().__init__(xedges, yedges, xlabel, ylabel, zlabel)

    def empty_copy(self):
        """Create an empty copy of the matrix.

        This needs to be overloaded because the constructor takes the number
        of columns and rows, rather than the bin edges, as the first arguments.
        Since copy() and the arithmetic operators all go through this hook,
        they would otherwise fail for this class.
        """
        return self.__class__(*self._shape, *self._axis_labels)

    def _draw(self, axes, logz=False, **kwargs):
        """Overloaded method.

        Note we have to transpose the underlying content due to the very
        nature of item addressing in numpy arrays.

        Arguments
        ---------
        axes : matplotlib axes
            The axes to draw the matrix on.

        logz : bool
            If True, use a logarithmic scale on the z axis.

        .. warning::
            This points to the fact that some of the Histogram2d interfaces might
            be broken, and we might better off with a content() method that
            one can overload.
        """
        if logz:
            # Same logic as in Histogram2d: norm and vmin/vmax should not be
            # used simultaneously, so we intercept the bounds when creating
            # the norm and refrain from passing vmin/vmax downstream.
            vmin = kwargs.pop('vmin', None)
            vmax = kwargs.pop('vmax', None)
            kwargs.setdefault('norm', matplotlib.colors.LogNorm(vmin, vmax))
        image = axes.matshow(self._content.T, **kwargs)
        # Put the minor ticks on the bin edges, so that the grid delimits the
        # matrix cells.
        axes.set_xticks(self._bin_edges[0], minor=True)
        axes.set_yticks(self._bin_edges[1], minor=True)
        axes.grid(which='minor', linewidth=1)
        self._update_color_bar(axes, image)