import logging
from typing import Optional, Union
import numpy as np
from pylops.utils.backend import get_array_module, to_cupy_conditional
from pylops.utils.typing import NDArray, ShapeLike
from pyproximal.projection.Simplex import SimplexProj
from pyproximal.ProxOperator import ProxOperator, _check_tau
# Optional accelerated backends. ``jit`` doubles as the availability flag:
# it is the numba decorator when numba and the compiled simplex helpers
# (numba and CUDA variants) import cleanly, and ``None`` otherwise.
# ``jit_message`` is only defined on the failure paths and describes why
# the accelerated engines are unavailable.
try:
    from numba import jit
    from ._Simplex_cuda import simplex_jit_cuda
    from ._Simplex_numba import bisect_jit, fun_jit, simplex_jit
except ModuleNotFoundError:
    # numba (or one of the helper modules) is not installed
    jit = None
    jit_message = "Numba not available, reverting to numpy."
except Exception as e:
    # any other exception raised while importing also disables these engines;
    # the original error text is kept in the message
    jit = None
    jit_message = "Failed to import numba (error:%s), use numpy." % e
logger = logging.getLogger(__name__)
class _Simplex(ProxOperator):
    """Simplex operator (numpy version).

    Indicator of the set ``{x : sum(x) = radius, x >= 0}``. It is applied
    either to a single vector or independently to every row/column of the
    matrix obtained by reshaping the input to ``dims``.

    Parameters
    ----------
    n : int
        Number of elements of the input vector.
    radius : float
        Radius of the simplex.
    dims : tuple, optional
        Two-element shape the input is reshaped to. When provided, the
        simplex is applied independently along ``axis``.
    axis : int, optional
        Axis along which the simplex is applied (only used with ``dims``).
    maxiter : int, optional
        Maximum number of bisection iterations (forwarded to SimplexProj).
    xtol : float, optional
        Absolute solution tolerance of the bisection (forwarded to
        SimplexProj).
    call : bool, optional
        Whether ``__call__`` evaluates the indicator (``True``) or always
        returns ``False``.

    Raises
    ------
    ValueError
        If ``dims`` is provided and does not have exactly two elements.
    """

    def __init__(
        self,
        n: int,
        radius: float,
        dims: Optional[ShapeLike] = None,
        axis: int = -1,
        maxiter: int = 100,
        xtol: float = 1e-8,
        call: bool = True,
    ) -> None:
        super().__init__(None, False)
        if dims is not None and len(dims) != 2:
            raise ValueError("provide only 2 dimensions, or None")
        self.n = n
        self.radius = radius
        self.dims = dims
        self.axis = axis
        # axis over which the independent projections are enumerated
        self.otheraxis = 1 if axis == 0 else 0
        self.maxiter = maxiter
        self.xtol = xtol
        self.call = call
        # projector for one vector: length n, or the length along ``axis``
        self.simplex = SimplexProj(
            self.n if dims is None else dims[axis],
            self.radius,
            maxiter=self.maxiter,
            xtol=self.xtol,
        )

    def __call__(self, x: NDArray, tol: float = 1e-8) -> bool:
        """Return ``True`` if ``x`` (or each of its reshaped vectors) lies on
        the simplex: entries non-negative and sum within ``tol`` of
        ``radius``. Always ``False`` when ``call`` is disabled."""
        if not self.call:
            return False
        if self.dims is None:
            radcheck = abs(np.sum(x) - self.radius) > tol
            return not (radcheck or np.any(x < 0))
        x = x.reshape(self.dims)
        if self.axis == 0:
            x = x.T
        # after the optional transpose every row is one vector to check;
        # the radius test is per row (not on the total sum of the matrix)
        radius_off = np.abs(np.sum(x, axis=1) - self.radius) > tol
        has_negative = np.any(x < 0, axis=1)
        return bool(np.all(~(radius_off | has_negative)))

    @_check_tau
    def prox(self, x: NDArray, tau: float) -> NDArray:
        """Project ``x`` (or each reshaped vector) onto the simplex; ``tau``
        has no effect. Returns a flattened array."""
        if self.dims is None:
            y = self.simplex(x)
        else:
            x = x.reshape(self.dims)
            if self.axis == 0:
                x = x.T
            # an integer output buffer would silently truncate the projection
            dtype = x.dtype if np.issubdtype(x.dtype, np.inexact) else np.float64
            y = np.zeros_like(x, dtype=dtype)
            for i in range(self.dims[self.otheraxis]):
                y[i] = self.simplex(x[i])
            if self.axis == 0:
                y = y.T
        return y.ravel()
class _Simplex_numba(_Simplex):
    """Simplex operator (numba version).

    Same contract as the numpy version, but the projection uses the
    numba-compiled helpers ``fun_jit``, ``bisect_jit`` and ``simplex_jit``
    instead of ``SimplexProj``. ``ftol`` is the function tolerance handed to
    those helpers.
    """

    def __init__(
        self,
        n: int,
        radius: float,
        dims: Optional[ShapeLike] = None,
        axis: int = -1,
        maxiter: int = 100,
        ftol: float = 1e-8,
        xtol: float = 1e-8,
        call: bool = False,
    ) -> None:
        super().__init__(n, radius, dims, axis, maxiter, xtol, call)
        self.ftol = ftol
        # unit weight for every entry of a single projected vector
        vector_length = self.n if dims is None else dims[axis]
        self.coeffs = np.ones(vector_length)

    @_check_tau
    def prox(self, x: NDArray, tau: float) -> NDArray:
        # NOTE(review): this pair is passed to the jitted helpers and also used
        # to clip the single-vector result below, so it presumably acts as
        # lower/upper bounds on the entries — confirm in ._Simplex_numba.
        lo_clip, hi_clip = 0, 10000000000

        if self.dims is not None:
            # matrix case: the jitted routine projects every row at once
            mat = x.reshape(self.dims)
            flip = self.axis == 0
            if flip:
                mat = mat.T
            out = simplex_jit(
                mat,
                self.coeffs,
                self.radius,
                lo_clip,
                hi_clip,
                self.maxiter,
                self.ftol,
                self.xtol,
            )
            if flip:
                out = out.T
            return out.ravel()

        def objective(mu):
            return fun_jit(mu, x, self.coeffs, self.radius, lo_clip, hi_clip)

        # widen the bracket geometrically until the objective changes sign
        low = -1
        while objective(low) < 0:
            low *= 2
        high = 1
        while objective(high) > 0:
            high *= 2

        shift = bisect_jit(
            x,
            self.coeffs,
            self.radius,
            lo_clip,
            hi_clip,
            low,
            high,
            self.maxiter,
            self.ftol,
            self.xtol,
        )
        projected = np.minimum(np.maximum(x - shift * self.coeffs, 0), hi_clip)
        return projected.ravel()
class _Simplex_cuda(_Simplex):
    """Simplex operator (cuda version)

    This implementation is adapted from https://github.com/DIG-Kaust/HPC_Hackathon_DIG.

    The projection runs in the ``simplex_jit_cuda`` kernel. The kernel is
    launched with ``num_threads_per_blocks`` threads per block and
    ``ceil(rows / num_threads_per_blocks)`` blocks, where ``rows`` is the
    number of rows of the reshaped (and possibly transposed) input. When
    ``dims`` is ``None`` the input vector is viewed as a single row.
    """

    def __init__(
        self,
        n: int,
        radius: float,
        dims: Optional[ShapeLike] = None,
        axis: int = -1,
        maxiter: int = 100,
        ftol: float = 1e-8,
        xtol: float = 1e-8,
        call: bool = False,
        num_threads_per_blocks: int = 32,
    ) -> None:
        super().__init__(n, radius, dims, axis, maxiter, xtol, call)
        self.ftol = ftol
        # unit weight for every entry of a single projected vector
        self.coeffs = np.ones(self.n if dims is None else dims[axis])
        self.num_threads_per_blocks = num_threads_per_blocks

    @_check_tau
    def prox(self, x: NDArray, tau: float) -> NDArray:
        ncp = get_array_module(x)
        if self.dims is None:
            # the kernel is launched over the rows of a 2D array, and
            # reshape(None) cannot produce one: view the vector as one row
            x = x.reshape(1, self.n)
            transpose = False
        else:
            x = x.reshape(self.dims)
            transpose = self.axis == 0
        if transpose:
            x = x.T
        # move the weights to the same array library as x (e.g. cupy)
        if type(self.coeffs) is not type(x):
            self.coeffs = to_cupy_conditional(x, self.coeffs)
        y = ncp.empty_like(x)
        # enough blocks to cover every row (presumably one thread per row —
        # confirm against ._Simplex_cuda)
        num_blocks = (
            x.shape[0] + self.num_threads_per_blocks - 1
        ) // self.num_threads_per_blocks
        simplex_jit_cuda[num_blocks, self.num_threads_per_blocks](
            x,
            self.coeffs,
            self.radius,
            0,
            10000000000,
            self.maxiter,
            self.ftol,
            self.xtol,
            y,
        )
        if transpose:
            y = y.T
        return y.ravel()
def Simplex(
    n: int,
    radius: float,
    dims: Optional[ShapeLike] = None,
    axis: int = -1,
    maxiter: int = 100,
    ftol: float = 1e-8,
    xtol: float = 1e-8,
    call: bool = True,
    engine: str = "numpy",
) -> ProxOperator:
    r"""Simplex proximal operator.

    Proximal operator of a Simplex: :math:`\Delta_n(r) = \{ \mathbf{x}:
    \sum_i x_i = r,\; x_i \geq 0 \}`. This operator can be applied to a
    single vector as well as repeatedly to a set of vectors which are
    defined as the rows (or columns) of a matrix obtained by reshaping the
    input vector as defined by the ``dims`` and ``axis`` parameters.

    Parameters
    ----------
    n : :obj:`int`
        Number of elements of input vector
    radius : :obj:`float`
        Radius
    dims : :obj:`tuple`, optional
        Dimensions of the matrix onto which the input vector is reshaped
    axis : :obj:`int`, optional
        Axis along which simplex is repeatedly applied when ``dims`` is
        provided
    maxiter : :obj:`int`, optional
        Maximum number of iterations used by bisection
    ftol : :obj:`float`, optional
        Function tolerance in bisection (only with ``engine='numba'`` or
        ``engine='cuda'``)
    xtol : :obj:`float`, optional
        Solution absolute tolerance in bisection
    call : :obj:`bool`, optional
        Evaluate call method (``True``) or not (``False``)
    engine : :obj:`str`, optional
        Engine used for simplex computation (``numpy``, ``numba`` or
        ``cuda``). If numba cannot be imported, ``numba`` and ``cuda``
        fall back to ``numpy`` and a warning is logged.

    Raises
    ------
    KeyError
        If ``engine`` is neither ``numpy`` nor ``numba`` nor ``cuda``
    ValueError
        If ``dims`` is provided as a list (or tuple) with more or less than
        2 elements

    Notes
    -----
    As the Simplex is an indicator function, the proximal operator corresponds
    to its orthogonal projection (see :class:`pyproximal.projection.SimplexProj`
    for details).

    Note that ``tau`` does not have effect for this proximal operator, any
    positive number can be provided.

    """
    if engine not in ["numpy", "numba", "cuda"]:
        raise KeyError("engine must be numpy or numba or cuda")
    s: Union[_Simplex, _Simplex_numba, _Simplex_cuda]
    if engine in ("numba", "cuda") and jit is None:
        # both accelerated engines depend on the numba import at module load;
        # warn instead of silently switching to numpy
        logger.warning(jit_message)
        engine = "numpy"
    if engine == "numba":
        s = _Simplex_numba(
            n,
            radius,
            dims=dims,
            axis=axis,
            maxiter=maxiter,
            ftol=ftol,
            xtol=xtol,
            call=call,
        )
    elif engine == "cuda":
        s = _Simplex_cuda(
            n,
            radius,
            dims=dims,
            axis=axis,
            maxiter=maxiter,
            ftol=ftol,
            xtol=xtol,
            call=call,
        )
    else:
        s = _Simplex(
            n, radius, dims=dims, axis=axis, maxiter=maxiter, xtol=xtol, call=call
        )
    return s