"""
Class that can be used to evaluate and interpolate functions.
"""
from abc import ABC, abstractmethod
from enum import Enum, auto
from typing import Callable, Tuple
import logging
import numpy as np
from scipy.interpolate import make_interp_spline, BSpline
from . import helpers
# Accepted input points x: a list of floats or a numpy array.
inputType = list[float] | np.ndarray
# Function values f(x): a list of floats/arrays or a numpy array.
outputType = list[float | np.ndarray] | np.ndarray
[docs]
class InterpolatableFunction(ABC):
r"""
This is a totally-not-overengineered base class for defining optimized functions
:math:`f(x)` that, in addition to normal evaluation, support the following:
- Producing and using interpolation tables in favor of direct evaluation, where
applicable.
- Automatic adaptive updating of the interpolation table.
- Reading interpolation tables from a file.
- Producing said file for some range of inputs.
- Validating that what was read from a file makes sense, i.e. matches the result
given by :py:meth:`_functionImplementation`.
WallGo uses this class for evaluating the free energy as function of the
temperature. It can also be used for the thermal :math:`J_b, J_f` integrals.
This also works for functions returning many numbers, ie. vector functions
:math:`V(x) = [V1, V2, ...]`. In this case each component gets its own interpolation table.
Works with numpy array input and applying the function element-wise, but it is the
user's responsibility to ensure that the implementation of :py:meth:`_functionImplementation`
is compatible with this behavior. The logic is such that if x is an array and idx is
an index tuple for an element in `x`, then `fx[idx]` is the value of `f(x)` at `x[idx]`. Note
that the shapes of `fx` and `x` will NOT match IF `f(x)` is vector valued.
Special care is needed if the function evaluation fails for some input `x`, eg. if the
function is evaluated only on some interval. In this case it is the user's
responsibility to return np.nan from _functionImplementation() for these input
values; this will mark these points as invalid and they will not be included in
interpolations. Failure to return np.nan for bad input will likely break the
interpolation.
Limitations.
- If the initial interpolation is bad, then it will remain bad: no functionality to
improve existing interpolations, only increase of the range is possible.
- Currently makes sense only for functions of one variable. However, you CAN call
this with numpy arrays of any shape (see above).
- Does NOT support piecewise functions as interpolations would break for those.
"""
[docs]
def __init__(
    self,
    bUseAdaptiveInterpolation: bool = True,
    initialInterpolationPointCount: int = 1000,
    returnValueCount: int = 1,
) -> None:
    """
    Create the function object. No interpolation table is built here.

    Parameters
    ----------
    bUseAdaptiveInterpolation : bool, optional
        If True, direct evaluations are recorded so that the interpolation
        table can be extended automatically. The default is True.
    initialInterpolationPointCount : int, optional
        Number of points calculated the first time an interpolation table is
        constructed. The default is 1000.
    returnValueCount : int, optional
        Number of values the function returns per input point. Must be set by
        the user for list-valued functions. The default is 1.
    """
    # A vector-like function returns several values for one input; the caller
    # has to tell us how many when constructing the object.
    assert returnValueCount >= 1
    self._RETURN_VALUE_COUNT = returnValueCount  # pylint: disable=invalid-name

    # Annotation-only declarations: these attributes are assigned once an
    # interpolation table is actually built.
    self._interpolatedFunction: BSpline
    # Splines for the 1st and 2nd derivatives only.
    self._interpolatedDerivatives: list[Callable]
    self._rangeMin: float
    self._rangeMax: float

    # Out-of-bounds behaviour at the lower and upper end of the table.
    self.extrapolationTypeLower = EExtrapolationType.NONE
    self.extrapolationTypeUpper = EExtrapolationType.NONE

    toggleAdaptive = (
        self.enableAdaptiveInterpolation
        if bUseAdaptiveInterpolation
        else self.disableAdaptiveInterpolation
    )
    toggleAdaptive()

    # --- Work variables for adaptive interpolation ---
    # Threshold of recorded direct evaluations that triggers a table update.
    # This can safely be changed at runtime and tuned per function.
    self._evaluationsUntilAdaptiveUpdate = 500
    # Inputs at which the function had to be evaluated without interpolation;
    # this is what allows smart updating of the table range.
    self._directEvaluateCount = 0
    self._directlyEvaluatedAt: inputType = []

    # Precalculated table: the x values and the matching f(x).
    self._interpolationPoints: inputType = []
    self._interpolationValues: outputType = []

    # How many points are calculated the first time a table is constructed.
    # If the range is changed later (adaptive interpolation), more points are
    # added outside the initial table. Point spacing is NOT guaranteed to be
    # uniform after adaptive updating.
    self._initialInterpolationPointCount = initialInterpolationPointCount
@abstractmethod
def _functionImplementation(self, x: inputType | float) -> outputType:
    """
    Compute f(x). Subclasses must override this.

    Not meant to be called by users directly; go through ``__call__`` (or
    ``evaluate``) instead.

    Contract for implementations:

    - Return np.nan wherever the function value is invalid. Such points are
      then left out of interpolations.
    - The result may be a scalar, or a list when the function is vector
      valued. Numpy array input should be handled element-wise.
    - The number of returned elements must equal self._RETURN_VALUE_COUNT;
      for numpy array input that means a list of that length for each x.
    - A returned list containing np.nan anywhere counts as a failed
      evaluation, and that x is excluded from the interpolation.
    """
""" Non abstracts """
[docs]
def interpolationRangeMin(self) -> float:
    """Return the smallest x covered by the current interpolation table."""
    lowerEdge = self._rangeMin
    return lowerEdge
[docs]
def interpolationRangeMax(self) -> float:
    """Return the largest x covered by the current interpolation table."""
    upperEdge = self._rangeMax
    return upperEdge
[docs]
def numPoints(self) -> int:
    """Return the number of x values stored in the interpolation table."""
    storedPoints = self._interpolationPoints
    return len(storedPoints)
[docs]
def hasInterpolation(self) -> bool:
    """Tell whether an interpolated function has been set on this object."""
    # Presence of the attribute is the signal; its value is not inspected.
    notSet = object()
    return getattr(self, "_interpolatedFunction", notSet) is not notSet
[docs]
def enableAdaptiveInterpolation(self) -> None:
    """
    Switch adaptive interpolation on.

    Also empties the internal bookkeeping of direct evaluations.
    """
    self._directlyEvaluatedAt = []
    self._directEvaluateCount = 0
    self._bUseAdaptiveInterpolation = True
[docs]
def disableAdaptiveInterpolation(self) -> None:
    """Switch adaptive interpolation off; recorded work data is left untouched."""
    bAdaptive = False
    self._bUseAdaptiveInterpolation = bAdaptive
[docs]
def newInterpolationTable(
    self,
    xMin: float,
    xMax: float,
    numberOfPoints: int,
) -> None:
    """
    Build a fresh interpolation table on an evenly spaced grid.

    The function is evaluated directly on the grid and the result is handed
    to the interpolation routine, replacing any table that existed before.

    Parameters
    ----------
    xMin : float
        First (smallest) grid point.
    xMax : float
        Last (largest) grid point.
    numberOfPoints : int
        Number of grid points, end points included.
    """
    grid = np.linspace(start=xMin, stop=xMax, num=numberOfPoints)
    self._interpolate(grid, self._functionImplementation(grid))
[docs]
def newInterpolationTableFromValues(
    self,
    x: inputType,
    fx: outputType,
    derivatives: list[outputType] | None = None,
    splineDegree: int = 3
) -> None:
    """
    Build an interpolation table from function values computed elsewhere.

    Works like newInterpolationTable(), except that nothing is evaluated
    here: the given data is passed straight to the interpolation routine.

    Parameters
    ----------
    x : list[float] or np.ndarray
        Points where the function was evaluated.
    fx : list[float | np.ndarray] or np.ndarray
        Function values at x.
    derivatives : list[outputType] | None
        Values of each derivative of the function at x. If None, derivatives
        are obtained from the interpolated spline.
    splineDegree : int, optional
        Degree of the interpolating spline. The default is 3.
    """
    tableArguments = (x, fx, derivatives, splineDegree)
    self._interpolate(*tableArguments)
[docs]
def scheduleForInterpolation(self, x: inputType, fx: outputType) -> None:
    """
    Add x, f(x) pairs to our pending interpolation table update.

    Only inputs whose function value is finite are recorded. Once the number
    of recorded inputs reaches ``self._evaluationsUntilAdaptiveUpdate``, an
    adaptive table update is triggered.

    Parameters
    ----------
    x : list[float] or np.ndarray
        Points where the function was evaluated.
    fx : list[float | np.ndarray] or np.ndarray
        Value of the function at x.
    """
    x = np.asanyarray(x)
    fx = np.asanyarray(fx)

    if np.ndim(x) == 0:
        # Just got 1 input x: it is valid only if every returned value is finite
        bValidResult = np.all(np.isfinite(fx))
        # put x in array format for consistency with array input
        xValid = np.array([x]) if bValidResult else np.array([])
    else:
        ## Got many input x, keep only x values where f(x) is finite.
        ## For vector-valued f(x), keep x where ALL return values are finite
        if self._RETURN_VALUE_COUNT > 1:
            assert fx.shape == x.shape + (self._RETURN_VALUE_COUNT,), (
                ""
                "Incompatable array shapes in scheduleForInterpolation(), "
                "should not happen!"
            )
            validIndices = np.all(np.isfinite(fx), axis=-1)
        else:
            assert fx.shape == x.shape, (
                ""
                "Incompatable array shapes in scheduleForInterpolation(), "
                "should not happen!"
            )
            # Element-wise mask. Reducing with np.all() here would throw away
            # the whole batch as soon as one value is non-finite.
            validIndices = np.isfinite(fx)

        # Boolean indexing already flattens; ravel keeps this explicit since
        # we only store plain x values for later.
        xValid = np.ravel(x[validIndices])

    # add x to our internal work list
    if np.size(xValid) > 0:
        xValid = np.unique(xValid)

        self._directEvaluateCount += len(xValid)
        self._directlyEvaluatedAt = np.concatenate(
            (self._directlyEvaluatedAt, xValid)
        )

        if self._directEvaluateCount >= self._evaluationsUntilAdaptiveUpdate:
            self._adaptiveInterpolationUpdate()
[docs]
def evaluateInterpolation(self, x: inputType | float) -> np.ndarray:
    """Return the interpolated function at x, converted to a numpy array."""
    splineValues = self._interpolatedFunction(x)
    return np.asarray(splineValues)
def _evaluateOutOfBounds(self, x: inputType) -> outputType:
"""
This gets called when the function is called outside the range of its
interpolation table. We either extrapolate (different extrapolations are
possible) or evaluate the function directly based on _functionImplementation().
"""
x = np.asanyarray(x)
bErrorExtrapolation = (
self.extrapolationTypeLower == EExtrapolationType.ERROR
and self.extrapolationTypeUpper == EExtrapolationType.ERROR
)
bNoExtrapolation = (
self.extrapolationTypeLower == EExtrapolationType.NONE
and self.extrapolationTypeUpper == EExtrapolationType.NONE
)
if bErrorExtrapolation:
raise ValueError(
f"Out of bounds: {x} outside [{self._rangeMin}, {self._rangeMax}]"
)
if not self.hasInterpolation() or bNoExtrapolation:
res = self._evaluateDirectly(x)
else:
## Now we have something to extrapolate
xLower = x <= self._rangeMin
xUpper = x >= self._rangeMax
# Figure out shape of the result. If we are vector valued, need an extra axis
if self._RETURN_VALUE_COUNT > 1:
resShape = x.shape + (self._RETURN_VALUE_COUNT,)
else:
resShape = x.shape
res = np.empty(resShape)
## Lower range
if np.any(xLower):
match self.extrapolationTypeLower:
case EExtrapolationType.ERROR:
# TODO better error message, this is nonsensible if x is array or list
raise ValueError(f"Out of bounds: {x} < {self._rangeMin}")
case EExtrapolationType.NONE:
res[xLower, :] = self._evaluateDirectly(x[xLower])
case EExtrapolationType.CONSTANT:
res[xLower, :] = self.evaluateInterpolation(self._rangeMin)
case EExtrapolationType.FUNCTION:
res[xLower, :] = self.evaluateInterpolation(x[xLower])
## Upper range
if np.any(xUpper):
match self.extrapolationTypeUpper:
case EExtrapolationType.ERROR:
# TODO better error message, this is nonsensible if x is array or list
raise ValueError(f"Out of bounds: {x} > {self._rangeMax}")
case EExtrapolationType.NONE:
res[xUpper, :] = self._evaluateDirectly(x[xUpper])
case EExtrapolationType.CONSTANT:
res[xUpper, :] = self.evaluateInterpolation(self._rangeMax)
case EExtrapolationType.FUNCTION:
res[xUpper, :] = self.evaluateInterpolation(x[xUpper])
return res
def __call__(self, x: inputType, bUseInterpolatedValues: bool = True) -> outputType:
    """
    Make the object callable; forwards to evaluate().

    Parameters
    ----------
    x : list[float] or np.ndarray
        Points where the function will be evaluated.
    bUseInterpolatedValues : bool, optional
        Whether or not to use interpolation to evaluate the function.
        The default is True.

    Returns
    -------
    list[float | np.ndarray] or np.ndarray
        Value of the function at x.
    """
    values = self.evaluate(x, bUseInterpolatedValues)
    return values
[docs]
def evaluate(self, x: inputType, bUseInterpolatedValues: bool = True) -> outputType:
    """
    Evaluate the function at x.

    Points inside the interpolation table are read from the interpolation;
    the remaining points go through the out-of-bounds handling. Without a
    table, or when interpolation is declined, everything is evaluated
    directly.

    Parameters
    ----------
    x : list[float] or np.ndarray
        Points where the function will be evaluated.
    bUseInterpolatedValues : bool, optional
        Whether or not to use interpolation to evaluate the function.
        The default is True.

    Returns
    -------
    list[float | np.ndarray] or np.ndarray
        Value of the function at x.
    """
    x = np.asanyarray(x)

    if not (bUseInterpolatedValues and self.hasInterpolation()):
        return self._evaluateDirectly(x)

    # Split the input into points covered by the table and all the rest
    insideTable, fxShape = self._findInterpolatablePoints(x)
    outsideTable = ~insideTable

    results = np.empty(fxShape)
    results[insideTable] = self.evaluateInterpolation(x[insideTable])

    xOutside = x[outsideTable]
    if xOutside.size > 0:
        results[outsideTable] = self._evaluateOutOfBounds(xOutside)

    return results
def _evaluateDirectly(
    self,
    x: inputType,
    bScheduleForInterpolation: bool = True,
) -> outputType:
    """
    Evaluate the function through _functionImplementation(), bypassing any
    interpolation.

    When adaptive interpolation is enabled and scheduling is requested, the
    evaluated points are also recorded for a later table update. That
    bookkeeping lives here so the abstract _functionImplementation() stays
    free of it.
    """
    fx = self._functionImplementation(x)

    if not (self._bUseAdaptiveInterpolation and bScheduleForInterpolation):
        return fx

    self.scheduleForInterpolation(x, fx)
    return fx
[docs]
def derivative(
    self,
    x: inputType,
    order: int = 1,
    bUseInterpolation: bool = True,
    epsilon: float = 1e-16,
    scale: float = 1.0,
) -> outputType:
    """
    Take the derivative of the function at points x.

    With bUseInterpolation=True and an existing interpolation table, points
    inside the table use the stored derivative splines. Points outside the
    table are differentiated numerically with helpers.derivative(), applied
    to the out-of-bounds evaluation so that the configured extrapolation
    type is respected. If interpolation is declined, no table exists, or the
    requested order has no stored spline, the whole input is differentiated
    numerically from direct evaluations.

    Parameters
    ----------
    x : list[float] or np.ndarray
        Points where the derivative will be evaluated.
    order : int, optional
        Order of the derivative to take. The default is 1.
    bUseInterpolation : bool, optional
        Whether or not to use interpolation to evaluate the function.
        The default is True.
    epsilon : float, optional
        Relative accuracy at which the function is evaluated; forwarded to
        helpers.derivative(). The default is 1e-16.
    scale : float, optional
        Scale at which the function changes by O(1); forwarded to
        helpers.derivative(). The default is 1.0.

    Returns
    -------
    list[float | np.ndarray] or np.ndarray
        Value of the derivative at x.
    """
    x = np.asanyarray(x)

    # order < 1 must not reach the spline list: index order - 1 would wrap
    # around and silently return a different derivative.
    bCanUseSplines = (
        bUseInterpolation
        and self.hasInterpolation()
        and 1 <= order <= len(self._interpolatedDerivatives)
    )
    if not bCanUseSplines:
        return helpers.derivative(
            self._evaluateDirectly, x, n=order, epsilon=epsilon, scale=scale
        )

    # Use interpolated values whenever possible
    canInterpolateCondition, fxShape = self._findInterpolatablePoints(x)
    needsEvaluationCondition = ~canInterpolateCondition
    xEvaluateRegion = x[needsEvaluationCondition]

    results = np.empty(fxShape)
    results[canInterpolateCondition] = self._interpolatedDerivatives[order - 1](
        x[canInterpolateCondition]
    )

    ## Outside the interpolation region use whatever extrapolation
    ## type the function uses
    if xEvaluateRegion.size > 0:
        # Differentiate only at the out-of-range points, so the result has
        # exactly as many entries as the masked slots being assigned.
        results[needsEvaluationCondition] = helpers.derivative(
            self._evaluateOutOfBounds,
            xEvaluateRegion,
            n=order,
            epsilon=epsilon,
            scale=scale,
        )

    return results
def _findInterpolatablePoints(
    self,
    x: np.ndarray,
) -> Tuple[np.ndarray, Tuple]:
    """
    Locate the x values that lie inside the interpolation table.

    Returns the pair (canInterpolateCondition, fxShape): a boolean numpy
    array marking the points within [rangeMin, rangeMax], and the shape that
    f(x) will have for this input.
    """
    insideTable = (x >= self._rangeMin) & (x <= self._rangeMax)

    # For an N-dimensional x and a tuple index idx into it, fx[idx] has to be
    # the function value at x[idx]. A vector-valued f(x) therefore needs one
    # extra trailing axis holding the returned values, so fx cannot have the
    # same shape as x in that case.
    valueAxis = (
        (self._RETURN_VALUE_COUNT,) if self._RETURN_VALUE_COUNT > 1 else ()
    )
    return insideTable, x.shape + valueAxis
def _interpolate(
    self,
    x: inputType,
    fx: outputType,
    derivatives: list[outputType] | None = None,
    splineDegree: int = 3,
) -> None:
    """
    Do the actual interpolation and set the internal table state.

    Non-finite points are dropped and the remaining points are sorted by x
    before the spline is built. Afterwards the spline, the table range, the
    table points/values and the derivative splines are stored on the object.

    Parameters
    ----------
    x : list[float] or np.ndarray
        Abscissas; must be 1D.
    fx : list[float | np.ndarray] or np.ndarray
        Function values at x; at most 2D (one row per x for vector-valued
        functions).
    derivatives : list[outputType] | None
        Values of each derivative at x. If None or empty, the derivative
        splines are obtained by differentiating the function spline.
    splineDegree : int, optional
        Degree of the interpolating splines. The default is 3.
    """
    x = np.asanyarray(x)
    fx = np.asanyarray(fx)

    assert x.ndim == 1 and fx.ndim <= 2, (
        "Shape error in _interpolate(), " "this should not happen!"
    )

    ## Can't specify different extrapolation methods for x > xmax, x < xmin in
    ## Spline! This logic is handled manually in __call__()
    bShouldExtrapolate = EExtrapolationType.FUNCTION in (
        self.extrapolationTypeLower,
        self.extrapolationTypeUpper,
    )

    ## Explicitly drop non-numerics
    xFiltered, fxFiltered, derivativesFiltered = self._dropBadPoints(
        x, fx, derivatives
    )

    # The spline construction needs increasing abscissas. A stable sort leaves
    # already-ordered input untouched and makes unordered tables usable.
    sortOrder = np.argsort(xFiltered, kind="stable")
    xFiltered = xFiltered[sortOrder]
    fxFiltered = fxFiltered[sortOrder]
    if derivativesFiltered is not None:
        derivativesFiltered = [
            np.asanyarray(d)[sortOrder] for d in derivativesFiltered
        ]

    ## This works even if f(x) is vector valued
    self._interpolatedFunction = make_interp_spline(
        xFiltered, fxFiltered, k=splineDegree, axis=0
    )
    self._interpolatedFunction.extrapolate = bShouldExtrapolate

    self._rangeMin = np.min(xFiltered)
    self._rangeMax = np.max(xFiltered)
    self._interpolationPoints = xFiltered
    self._interpolationValues = fxFiltered

    # Store splines for the 1st and 2nd derivatives in a list. Higher
    # derivatives are not splined as they are not guaranteed to be continuous.
    if not derivativesFiltered:
        # A spline cannot be differentiated more often than its degree, so
        # cap the number of stored derivatives for low-degree splines.
        highestOrder = min(2, splineDegree)
        self._interpolatedDerivatives = [
            self._interpolatedFunction.derivative(n)
            for n in range(1, highestOrder + 1)
        ]
    else:
        self._interpolatedDerivatives = []
        for d in derivativesFiltered:
            derivativeSpline = make_interp_spline(
                xFiltered, d, k=splineDegree, axis=0
            )
            derivativeSpline.extrapolate = bShouldExtrapolate
            self._interpolatedDerivatives.append(derivativeSpline)

        # Only the 1st derivative was supplied: obtain the 2nd one by
        # differentiating its spline (needs a spline of degree >= 1).
        if len(self._interpolatedDerivatives) == 1 and splineDegree >= 1:
            self._interpolatedDerivatives.append(
                self._interpolatedDerivatives[0].derivative(1)
            )
@staticmethod
def _dropBadPoints(
x: np.ndarray,
fx: np.ndarray,
derivatives: list[outputType] | None = None,
) -> tuple[np.ndarray, np.ndarray, list[outputType] | None]:
"""
Removes non-numerical (x, fx) pairs. For 2D fx the check is applied row-wise.
Input x needs to be 1D, and input fx needs to be at most 2D.
Output is same shape as input.
"""
if derivatives is None:
derivativesValid = None
else:
derivativesValid = []
if fx.ndim > 1:
validIndices = np.all(np.isfinite(fx), axis=1)
fxValid = fx[validIndices]
if derivatives is not None:
for d in derivatives:
derivativesValid.append(d[validIndices])
else:
## fx is 1D array
validIndices = np.all(np.isfinite(fx))
fxValid = np.ravel(fx[validIndices])
if derivatives is not None:
for d in derivatives:
derivativesValid.append(np.ravel(d[validIndices]))
xValid = np.ravel(x[validIndices])
return xValid, fxValid, derivativesValid
def _adaptiveInterpolationUpdate(self) -> None:
    """
    Update the interpolation table from the recorded direct evaluations.

    The table is extended to cover the smallest and largest recorded input,
    and the record of direct evaluations is cleared.
    """
    # Range spanned by the evaluations recorded since the last update
    recordedInputs = self._directlyEvaluatedAt
    lowestInput = np.min(recordedInputs)
    highestInput = np.max(recordedInputs)

    # Clear the work variables before extending the table
    self._directlyEvaluatedAt = []
    self._directEvaluateCount = 0

    # Fewer points are appended per side when a table already exists
    appendPointCount = (
        int(0.2 * self._initialInterpolationPointCount)
        if self.hasInterpolation()
        else int(self._initialInterpolationPointCount / 2)
    )

    self.extendInterpolationTable(
        lowestInput, highestInput, appendPointCount, appendPointCount
    )
[docs]
def extendInterpolationTable(
    self,
    newMin: float,
    newMax: float,
    pointsMin: int,
    pointsMax: int,
) -> None:
    """
    Extend our interpolation table.

    New points are added below and/or above the current table, the function
    is evaluated there, and the interpolation is rebuilt from the combined
    data. If no table exists yet, a new one is created on [newMin, newMax]
    with pointsMin + pointsMax points instead.

    NB: This will reset internally accumulated data of adaptive interpolation.

    Parameters
    ----------
    newMin : float
        New lower end of the interpolation range. Ignored if it is not below
        the current lower end.
    newMax : float
        New upper end of the interpolation range. Ignored if it is not above
        the current upper end.
    pointsMin : int
        Number of points to add below the current table.
    pointsMax : int
        Number of points to add above the current table.
    """
    if not self.hasInterpolation():
        newPoints = int(pointsMin + pointsMax)
        logging.warning(
            "Warning: %s.extendInterpolationTable() "
            "called without existing interpolation. "
            "Creating new table in range [%s, %s] with %s "
            "points",
            self.__class__.__name__,
            newMin,
            newMax,
            newPoints,
        )
        self.newInterpolationTable(newMin, newMax, newPoints)
        return

    # linspace gives an exact point count; arange with a float step can gain
    # or lose a point through rounding, which may place a point practically
    # on top of an existing table edge.

    # what to append to lower end: starts at newMin, stops one spacing short
    # of the current lower end
    if newMin < self._rangeMin and pointsMin > 0:
        appendPointsMin = np.linspace(
            newMin, self._rangeMin, int(pointsMin), endpoint=False
        )
    else:
        appendPointsMin = np.array([])

    # what to append to upper end: starts one spacing above the current upper
    # end, stops exactly at newMax
    if newMax > self._rangeMax and pointsMax > 0:
        appendPointsMax = np.linspace(
            self._rangeMax, newMax, int(pointsMax) + 1
        )[1:]
    else:
        appendPointsMax = np.array([])

    tableValues = np.asarray(self._interpolationValues)

    def evaluateNewPoints(points: np.ndarray) -> np.ndarray:
        """Evaluate the function at points; no call is made for empty input."""
        if points.size == 0:
            # Empty block shaped so that it concatenates with the table values
            return np.empty((0,) + tableValues.shape[1:], dtype=tableValues.dtype)
        return np.asarray(self._functionImplementation(points))

    appendValuesMin = evaluateNewPoints(appendPointsMin)
    appendValuesMax = evaluateNewPoints(appendPointsMax)

    # Ordering is important since interpolation needs the x values to be ordered.
    xRange = np.concatenate(
        (appendPointsMin, self._interpolationPoints, appendPointsMax)
    )
    fxRange: np.ndarray = np.concatenate(
        (appendValuesMin, tableValues, appendValuesMax)
    )

    self.newInterpolationTableFromValues(xRange, fxRange)

    ## Hacky reset of adaptive routines
    if self._bUseAdaptiveInterpolation:
        self.disableAdaptiveInterpolation()
        self.enableAdaptiveInterpolation()
[docs]
def readInterpolationTable(self, fileToRead: str) -> None:
    """
    Read precalculated values from a file and interpolate them.

    Each line in the file must be of form x f(x).
    For vector valued functions: x f1(x) f2(x)

    After interpolating, the table is compared against the function
    definition at both ends of the range and at one interior point; a
    mismatch is reported by the validation routine. An IOError while reading
    is logged as a warning and leaves the current table unchanged.

    Parameters
    ----------
    fileToRead : str
        Path of the file where the interpolation table is stored.
    """
    # for logging
    selfName = self.__class__.__name__

    try:
        ## Each line should be of form x f(x).
        ## For vector valued functions, x f1(x) f2(x) ...
        data = np.genfromtxt(
            fileToRead, delimiter=" ", dtype=float, encoding=None
        )

        columns = data.shape[1]

        # First column is x, all remaining columns are f(x)
        x = data[:, 0]
        fx = data[:, 1:]

        ## If f(x) is 1D, the slice above gives it as [[fx1], [fx2], ...];
        ## flatten it to a plain 1D array
        if columns == 2:
            fx = np.ravel(fx)

        self._interpolate(x, fx)

        ## check that what we read matches our function definition (just evaluate
        ## and compare at a few values)
        self._validateInterpolationTable(self._rangeMin)
        self._validateInterpolationTable(self._rangeMax)
        # Interior test point: measured from the lower end of the range so
        # that it always lies inside the table.
        self._validateInterpolationTable(
            self._rangeMin + (self._rangeMax - self._rangeMin) / 2.55
        )

        logging.debug(
            "%s: Succesfully read interpolation table from file. "
            "Range [%g, %g]",
            selfName,
            self._rangeMin,
            self._rangeMax,
        )

    except IOError as ioError:
        logging.warning(
            "IOError! %s attempted to read interpolation table from "
            "file, but got error:",
            selfName,
        )
        logging.warning(ioError)
        logging.warning("This is non-fatal. Interpolation table will not be updated.\n")
[docs]
def writeInterpolationTable(self, outputFileName: str) -> None:
    """
    Store the current interpolation table in a text file.

    Line i of the file reads ``x[i] fx[i]``; for a vector-valued function it
    reads ``x[i] fx1[i] fx2[i] ...``. Any error while writing is logged as a
    warning instead of being raised.

    Parameters
    ----------
    outputFileName : str
        Name of the file where the interpolation table will be written.
    """
    try:
        # First column holds x, the remaining column(s) hold f(x)
        xColumn = np.asarray(self._interpolationPoints)
        fxColumns = np.asarray(self._interpolationValues)
        table = np.column_stack((xColumn, fxColumns))

        np.savetxt(outputFileName, table, fmt="%.15g", delimiter=" ")

        logging.debug(
            "Stored interpolation table for function "
            f"{self.__class__.__name__}, output file {outputFileName}."
        )

    except Exception as e:
        logging.warning(
            f"Error from {self.__class__.__name__}, function "
            f"writeInterpolationTable(): {e}"
        )
def _validateInterpolationTable(
    self,
    x: float,
    absoluteTolerance: float = 1e-6,
) -> bool:
    """
    Test the interpolation table at one input.

    The interpolated value is compared against _functionImplementation(x).

    Parameters
    ----------
    x : float
        Point at which to compare; must lie inside the table range.
    absoluteTolerance : float, optional
        Largest accepted absolute difference. The default is 1e-6.

    Returns
    -------
    bool
        True if the values agree within the tolerance. False if they do not,
        if the difference is not a number, if no interpolation exists, or if
        x is outside the table range. A warning is logged whenever False is
        returned.
    """
    # getattr() with a default also covers the attribute never having been
    # assigned; reading it directly would raise AttributeError in that case.
    bMissingTable = getattr(self, "_interpolatedFunction", None) is None

    if bMissingTable or not self._rangeMin <= x <= self._rangeMax:
        logging.warning(
            f"{self.__class__.__name__}: _validateInterpolationTable called, "
            "but no valid interpolation table was found."
        )
        return False

    diff = self.evaluateInterpolation(x) - self._functionImplementation(x)

    # Written as "not all within tolerance" so that a NaN difference counts
    # as a failure; "abs(diff) > tolerance" is False for NaN.
    if not np.all(np.abs(diff) <= absoluteTolerance):
        logging.warning(
            f"{self.__class__.__name__}: Could not validate interpolation table!"
            f" Value discrepancy was {diff}"
        )
        return False

    return True