Source code for pyproximal.proximal.Simplex

import logging
from typing import Optional, Union

import numpy as np
from pylops.utils.backend import get_array_module, to_cupy_conditional
from pylops.utils.typing import NDArray, ShapeLike

from pyproximal.projection.Simplex import SimplexProj
from pyproximal.ProxOperator import ProxOperator, _check_tau

# Optional acceleration backends. If any of the imports below fails, ``jit``
# is set to None (which the rest of this module uses as the "accelerated
# engines unavailable" flag) and ``jit_message`` records the reason.
# ``jit_message`` is only defined on the failure paths.
try:
    from numba import jit

    from ._Simplex_cuda import simplex_jit_cuda
    from ._Simplex_numba import bisect_jit, fun_jit, simplex_jit
except ModuleNotFoundError:
    # One of the modules imported above could not be found.
    jit = None
    jit_message = "Numba not available, reverting to numpy."
except Exception as e:
    # Any other failure raised while importing; keep the error text so it can
    # be reported later.
    jit = None
    jit_message = "Failed to import numba (error:%s), use numpy." % e

logger = logging.getLogger(__name__)


class _Simplex(ProxOperator):
    """Simplex operator (numpy version)

    Membership test and projection for the set
    ``{x : sum(x) = radius, x >= 0}``, applied either to a single vector or
    independently to every row (or column) of a 2-dimensional reshaping of
    the input.

    Parameters
    ----------
    n : :obj:`int`
        Number of elements of the input vector (used when ``dims`` is None)
    radius : :obj:`float`
        Value the elements must sum to
    dims : :obj:`tuple`, optional
        Two-element shape onto which the input vector is reshaped
    axis : :obj:`int`, optional
        Axis of ``dims`` along which the simplex is applied. Only ``0`` is
        handled as column-wise (the matrix is transposed); any other value
        is handled row-wise.
    maxiter : :obj:`int`, optional
        Maximum number of iterations forwarded to ``SimplexProj``
    xtol : :obj:`float`, optional
        Solution tolerance forwarded to ``SimplexProj``
    call : :obj:`bool`, optional
        Evaluate ``__call__`` (``True``) or always return ``False``

    Raises
    ------
    ValueError
        If ``dims`` is provided and does not have exactly 2 elements

    """

    def __init__(
        self,
        n: int,
        radius: float,
        dims: Optional[ShapeLike] = None,
        axis: int = -1,
        maxiter: int = 100,
        xtol: float = 1e-8,
        call: bool = True,
    ) -> None:
        super().__init__(None, False)
        if dims is not None and len(dims) != 2:
            raise ValueError("provide only 2 dimensions, or None")
        self.n = n
        self.radius = radius
        self.dims = dims
        self.axis = axis
        # index into ``dims`` giving the number of vectors to process
        self.otheraxis = 1 if axis == 0 else 0
        self.maxiter = maxiter
        self.xtol = xtol
        self.call = call

        # a single projector is reused for every vector
        self.simplex = SimplexProj(
            self.n if dims is None else dims[axis],
            self.radius,
            maxiter=self.maxiter,
            xtol=self.xtol,
        )

    def __call__(self, x: NDArray, tol: float = 1e-8) -> bool:
        """Check whether ``x`` lies in the simplex.

        Parameters
        ----------
        x : :obj:`numpy.ndarray`
            Flattened input vector
        tol : :obj:`float`, optional
            Absolute tolerance on the ``sum == radius`` constraint

        Returns
        -------
        c : :obj:`bool`
            ``True`` when every vector sums to ``radius`` within ``tol`` and
            has no negative entry; always ``False`` when ``call=False``.

        """
        if not self.call:
            return False
        if self.dims is None:
            offset = np.sum(x) - self.radius
            off_radius = offset > tol or offset < -tol
            return not (off_radius or np.any(x < 0))

        x = x.reshape(self.dims)
        if self.axis == 0:
            x = x.T
        # Apply the same test as the 1-D case to each row: the sum of that
        # row (not of the whole matrix) must be within ``tol`` of ``radius``.
        offsets = np.sum(x, axis=-1) - self.radius
        off_radius = np.any((offsets > tol) | (offsets < -tol))
        return not (bool(off_radius) or bool(np.any(x < 0)))

    @_check_tau
    def prox(self, x: NDArray, tau: float) -> NDArray:
        """Project ``x`` using ``SimplexProj``.

        Parameters
        ----------
        x : :obj:`numpy.ndarray`
            Flattened input vector
        tau : :obj:`float`
            Not used in the computation

        Returns
        -------
        y : :obj:`numpy.ndarray`
            Flattened result, in the same layout as ``x``

        """
        if self.dims is None:
            y = self.simplex(x)
        else:
            x = x.reshape(self.dims)
            if self.axis == 0:
                # work row-wise on the transposed matrix
                x = x.T
            y = np.zeros_like(x)
            for i in range(self.dims[self.otheraxis]):
                y[i] = self.simplex(x[i])
            if self.axis == 0:
                # restore the original layout before flattening
                y = y.T
        return y.ravel()


class _Simplex_numba(_Simplex):
    """Simplex operator (numba version)

    Same interface as :class:`_Simplex`, with the projection carried out by
    the jitted helpers ``fun_jit``, ``bisect_jit`` and ``simplex_jit``.

    Parameters
    ----------
    n : :obj:`int`
        Number of elements of the input vector (used when ``dims`` is None)
    radius : :obj:`float`
        Radius
    dims : :obj:`tuple`, optional
        Two-element shape onto which the input vector is reshaped
    axis : :obj:`int`, optional
        Axis of ``dims`` along which the simplex is applied
    maxiter : :obj:`int`, optional
        Maximum number of iterations passed to the bisection
    ftol : :obj:`float`, optional
        Function tolerance passed to the bisection
    xtol : :obj:`float`, optional
        Solution tolerance passed to the bisection
    call : :obj:`bool`, optional
        Evaluate ``__call__`` (``True``) or not (``False``)

    """

    def __init__(
        self,
        n: int,
        radius: float,
        dims: Optional[ShapeLike] = None,
        axis: int = -1,
        maxiter: int = 100,
        ftol: float = 1e-8,
        xtol: float = 1e-8,
        call: bool = False,
    ) -> None:
        super().__init__(n, radius, dims, axis, maxiter, xtol, call)
        # one unit coefficient per element of each processed vector
        ncoeffs = self.n if dims is None else dims[axis]
        self.coeffs = np.ones(ncoeffs)
        self.ftol = ftol

    @_check_tau
    def prox(self, x: NDArray, tau: float) -> NDArray:
        """Project ``x`` with the numba kernels.

        Parameters
        ----------
        x : :obj:`numpy.ndarray`
            Flattened input vector
        tau : :obj:`float`
            Not used in the computation

        Returns
        -------
        y : :obj:`numpy.ndarray`
            Flattened result

        """
        # box bounds handed to the kernels and used for the final clipping
        lower, upper = 0, 10000000000

        if self.dims is not None:
            # 2-D case: the kernel processes the whole matrix in one call
            transposed = self.axis == 0
            xmat = x.reshape(self.dims)
            if transposed:
                xmat = xmat.T
            ymat = simplex_jit(
                xmat,
                self.coeffs,
                self.radius,
                lower,
                upper,
                self.maxiter,
                self.ftol,
                self.xtol,
            )
            if transposed:
                ymat = ymat.T
            return ymat.ravel()

        # 1-D case: grow a bracket by doubling each end until ``fun_jit``
        # stops being negative on the left and positive on the right
        left = -1
        while fun_jit(left, x, self.coeffs, self.radius, lower, upper) < 0:
            left *= 2
        right = 1
        while fun_jit(right, x, self.coeffs, self.radius, lower, upper) > 0:
            right *= 2

        shift = bisect_jit(
            x,
            self.coeffs,
            self.radius,
            lower,
            upper,
            left,
            right,
            self.maxiter,
            self.ftol,
            self.xtol,
        )
        shifted = x - shift * self.coeffs
        y = np.minimum(np.maximum(shifted, lower), upper)
        return y.ravel()


class _Simplex_cuda(_Simplex):
    """Simplex operator (cuda version)

    This implementation is adapted from https://github.com/DIG-Kaust/HPC_Hackathon_DIG.

    Parameters
    ----------
    n : :obj:`int`
        Number of elements of the input vector (used when ``dims`` is None)
    radius : :obj:`float`
        Radius
    dims : :obj:`tuple`, optional
        Two-element shape onto which the input vector is reshaped
    axis : :obj:`int`, optional
        Axis of ``dims`` along which the simplex is applied
    maxiter : :obj:`int`, optional
        Maximum number of iterations passed to the kernel
    ftol : :obj:`float`, optional
        Function tolerance passed to the kernel
    xtol : :obj:`float`, optional
        Solution tolerance passed to the kernel
    call : :obj:`bool`, optional
        Evaluate ``__call__`` (``True``) or not (``False``)
    num_threads_per_blocks : :obj:`int`, optional
        Second launch parameter of the kernel; the first one is derived from
        it and from the number of rows

    """

    def __init__(
        self,
        n: int,
        radius: float,
        dims: Optional[ShapeLike] = None,
        axis: int = -1,
        maxiter: int = 100,
        ftol: float = 1e-8,
        xtol: float = 1e-8,
        call: bool = False,
        num_threads_per_blocks: int = 32,
    ) -> None:
        super().__init__(n, radius, dims, axis, maxiter, xtol, call)
        self.num_threads_per_blocks = num_threads_per_blocks
        # one unit coefficient per element of each processed vector
        ncoeffs = self.n if dims is None else dims[axis]
        self.coeffs = np.ones(ncoeffs)
        self.ftol = ftol

    @_check_tau
    def prox(self, x: NDArray, tau: float) -> NDArray:
        """Project ``x`` with the cuda kernel.

        Parameters
        ----------
        x : :obj:`numpy.ndarray` or :obj:`cupy.ndarray`
            Flattened input vector
        tau : :obj:`float`
            Not used in the computation

        Returns
        -------
        y : array
            Flattened result, allocated with the array module of ``x``

        """
        # NOTE(review): ``self.dims`` is used for the reshape unconditionally,
        # so ``dims=None`` is forwarded as-is - confirm whether the 1-D case
        # is meant to be supported by this engine.
        xp = get_array_module(x)
        transposed = self.axis == 0
        xmat = x.reshape(self.dims)
        if transposed:
            xmat = xmat.T

        # convert the coefficients once so that they match the input's type;
        # the converted array is kept for subsequent calls
        if type(self.coeffs) is not type(xmat):
            self.coeffs = to_cupy_conditional(xmat, self.coeffs)

        out = xp.empty_like(xmat)
        threads = self.num_threads_per_blocks
        # ceiling division of the number of rows by ``threads``
        blocks = (xmat.shape[0] + threads - 1) // threads
        simplex_jit_cuda[blocks, threads](
            xmat,
            self.coeffs,
            self.radius,
            0,
            10000000000,
            self.maxiter,
            self.ftol,
            self.xtol,
            out,
        )
        if transposed:
            out = out.T
        return out.ravel()


def Simplex(
    n: int,
    radius: float,
    dims: Optional[ShapeLike] = None,
    axis: int = -1,
    maxiter: int = 100,
    ftol: float = 1e-8,
    xtol: float = 1e-8,
    call: bool = True,
    engine: str = "numpy",
) -> ProxOperator:
    r"""Simplex proximal operator.

    Proximal operator of a Simplex: :math:`\Delta_n(r) = \{ \mathbf{x}:
    \sum_i x_i = r,\; x_i \geq 0 \}`. This operator can be applied to a
    single vector as well as repeatedly to a set of vectors which are
    defined as the rows (or columns) of a matrix obtained by reshaping the
    input vector as defined by the ``dims`` and ``axis`` parameters.

    Parameters
    ----------
    n : :obj:`int`
        Number of elements of input vector
    radius : :obj:`float`
        Radius
    dims : :obj:`tuple`, optional
        Dimensions of the matrix onto which the input vector is reshaped
    axis : :obj:`int`, optional
        Axis along which simplex is repeatedly applied when ``dims`` is
        provided
    maxiter : :obj:`int`, optional
        Maximum number of iterations used by bisection
    ftol : :obj:`float`, optional
        Function tolerance in bisection (only with ``engine='numba'`` or
        ``engine='cuda'``)
    xtol : :obj:`float`, optional
        Solution absolute tolerance in bisection
    call : :obj:`bool`, optional
        Evaluate call method (``True``) or not (``False``)
    engine : :obj:`str`, optional
        Engine used for simplex computation (``numpy``, ``numba`` or
        ``cuda``).

    Returns
    -------
    s : :obj:`pyproximal.ProxOperator`
        Simplex operator for the requested engine, or the numpy one when the
        accelerated engines could not be imported

    Raises
    ------
    KeyError
        If ``engine`` is neither ``numpy`` nor ``numba`` nor ``cuda``
    ValueError
        If ``dims`` is provided as a list (or tuple) with more or less than
        2 elements

    Notes
    -----
    As the Simplex is an indicator function, the proximal operator
    corresponds to its orthogonal projection
    (see :class:`pyproximal.projection.SimplexProj` for details).

    Note that ``tau`` does not have effect for this proximal operator, any
    positive number can be provided.

    When ``engine`` is ``numba`` or ``cuda`` but the accelerated kernels
    could not be imported, a warning is logged and the numpy implementation
    is returned instead.

    """
    if engine not in ["numpy", "numba", "cuda"]:
        raise KeyError("engine must be numpy or numba or cuda")

    s: Union[_Simplex, _Simplex_numba, _Simplex_cuda]
    if engine == "numba" and jit is not None:
        s = _Simplex_numba(
            n,
            radius,
            dims=dims,
            axis=axis,
            maxiter=maxiter,
            ftol=ftol,
            xtol=xtol,
            call=call,
        )
    elif engine == "cuda" and jit is not None:
        s = _Simplex_cuda(
            n,
            radius,
            dims=dims,
            axis=axis,
            maxiter=maxiter,
            ftol=ftol,
            xtol=xtol,
            call=call,
        )
    else:
        # Reaching this branch with a non-numpy engine means the accelerated
        # import failed (``jit`` is None); report it for both ``numba`` and
        # ``cuda`` instead of falling back silently.
        if engine != "numpy":
            logger.warning(jit_message)
        s = _Simplex(
            n, radius, dims=dims, axis=axis, maxiter=maxiter, xtol=xtol, call=call
        )
    return s