Source code for pyproximal.proximal.L0

import warnings
from collections.abc import Callable
from functools import wraps
from typing import Any

import numpy as np
from pylops.utils.typing import NDArray
from typing_extensions import Self

from pyproximal.projection import L0BallProj, L10BallProj
from pyproximal.proximal.L1 import _current_sigma
from pyproximal.ProxOperator import ProxOperator, _check_tau
from pyproximal.utils.typing import FloatCallableLike, IntCallableLike


def _hardthreshold(x: NDArray, thresh: float) -> NDArray:
    r"""Hard thresholding.

    Applies hard thresholding to vector ``x`` (equal to the proximity
    operator for :math:`\|\mathbf{x}\|_0`) as shown in [1]_.

    .. [1] Chen, F., Shen, L., Suter, B.W., "Computing the proximity
       operator of the Lp norm with 0 < p < 1",
       IET Signal Processing, 10, 2016.

    Parameters
    ----------
    x : :obj:`numpy.ndarray`
        Vector
    thresh : :obj:`float`
        Threshold

    Returns
    -------
    x1 : :obj:`numpy.ndarray`
        Tresholded vector

    """
    x1 = x.copy()
    x1[np.abs(x) <= thresh] = 0
    return x1


def _current_radius(
    radius: IntCallableLike,
    count: int,
) -> int | NDArray:
    if not callable(radius):
        return radius
    else:
        return radius(count)


class L0(ProxOperator):
    r""":math:`L_0` norm proximal operator.

    Proximal operator of the :math:`\ell_0` norm scaled by :math:`\sigma`,
    i.e. of :math:`\sigma\|\mathbf{x}\|_0` where
    :math:`\|\mathbf{x}\|_0 = \text{count}(x_i \ne 0)`.

    Parameters
    ----------
    sigma : :obj:`float` or :obj:`list` or :obj:`numpy.ndarray` or :obj:`func`, optional
        Multiplicative coefficient of L0 norm. This can be a constant number,
        a list of values (for multidimensional inputs, acting on the second
        dimension) or a function that is called passing a counter which keeps
        track of how many times the ``prox`` method has been invoked before
        and returns a scalar (or a list of) ``sigma`` to be used.

    Notes
    -----
    The :math:`\ell_0` proximal operator is defined as:

    .. math::

        \prox_{\tau \sigma \|\cdot\|_0}(\mathbf{x}) =
        \operatorname{hard}(\mathbf{x}, \tau \sigma) =
        \begin{cases}
        x_i, & x_i < -\tau \sigma \\
        0,   & -\tau\sigma \leq x_i \leq \tau\sigma \\
        x_i, & x_i > \tau\sigma\\
        \end{cases}

    where :math:`\operatorname{hard}` is the so-called *hard thresholding*.

    """

    def __init__(
        self,
        sigma: FloatCallableLike = 1.0,
    ) -> None:
        super().__init__(None, False)
        self.sigma = sigma
        # Number of ``prox`` calls so far; passed to ``sigma`` when callable.
        self.count = 0

    def __call__(self, x: NDArray) -> int:
        """Return the number of entries of ``x`` with non-zero magnitude."""
        # NOTE(review): ``sigma`` is not applied here, while ``prox`` uses
        # ``tau * sigma`` as threshold; confirm whether this should evaluate
        # ``sigma * count`` instead (would change the ``int`` return type).
        return int(np.sum(np.abs(x) > 0.0))

    def _increment_count(func: Callable[..., Any]) -> Callable[..., Any]:
        """Increment ``self.count`` before calling the decorated method."""

        @wraps(func)
        def wrapped(self: Self, *args: Any, **kwargs: Any) -> Any:
            self.count += 1
            return func(self, *args, **kwargs)

        return wrapped

    @_increment_count
    @_check_tau
    def prox(self, x: NDArray, tau: float) -> NDArray:
        # ``count`` has already been incremented by ``_increment_count``, so a
        # callable ``sigma`` receives 1 on the first call.
        sigma = _current_sigma(self.sigma, self.count)
        if isinstance(sigma, (list, tuple)):
            # ``tau * sigma`` on a plain sequence either raises (float tau) or
            # repeats the sequence (int tau); scale element-wise instead.
            sigma = np.asarray(sigma)
        return _hardthreshold(x, tau * sigma)
class L0Ball(ProxOperator):
    r""":math:`L_0` ball proximal operator.

    Proximal operator of the L0 ball: :math:`L0_{r} =
    \{ \mathbf{x}: \|\mathbf{x}\|_0 \leq r \}`.

    Parameters
    ----------
    radius : :obj:`int` or :obj:`func`
        Radius. This can be a constant number or a function that is called
        passing a counter which keeps track of how many times the ``prox``
        method has been invoked before and returns a scalar ``radius`` to be
        used.

    Notes
    -----
    As the L0 ball is an indicator function, the proximal operator
    corresponds to its orthogonal projection
    (see :class:`pyproximal.projection.L0BallProj` for details).

    """

    def __init__(self, radius: IntCallableLike) -> None:
        super().__init__(None, False)
        self.radius = radius
        # A callable radius is evaluated with counter 0 to build the
        # projection; ``prox`` re-resolves it on every call.
        self.ball = L0BallProj(_current_radius(radius, 0))
        # Number of ``prox`` calls so far; passed to ``radius`` when callable.
        self.count = 0

    def __call__(self, x: NDArray) -> bool:
        """Return ``True`` if ``x`` has at most ``radius`` non-zero entries."""
        radius = _current_radius(self.radius, self.count)
        # ``np.count_nonzero`` accepts inputs of any dimensionality, whereas
        # ``np.linalg.norm(..., ord=0)`` raises ``ValueError`` for 2-D+ arrays.
        return bool(np.count_nonzero(x) <= radius)

    def _increment_count(func: Callable[..., Any]) -> Callable[..., Any]:
        """Increment ``self.count`` before calling the decorated method."""

        @wraps(func)
        def wrapped(self: Self, *args: Any, **kwargs: Any) -> Any:
            self.count += 1
            return func(self, *args, **kwargs)

        return wrapped

    @_increment_count
    @_check_tau
    def prox(self, x: NDArray, tau: float) -> NDArray:
        # ``tau`` is not used beyond the ``_check_tau`` decorator: the prox of
        # an indicator function is the projection onto the set.
        radius = _current_radius(self.radius, self.count)
        # Update the projection in place so a callable radius takes effect.
        self.ball.radius = radius
        return self.ball(x)
class L10Ball(ProxOperator):
    r""":math:`L_{1,0}` ball proximal operator.

    Proximal operator of the :math:`L_{1,0}` ball:
    :math:`L_{1,0}^{r} = \{ \mathbf{x}: \text{count}([\|\mathbf{x}_1\|_1,
    \|\mathbf{x}_2\|_1, ..., \|\mathbf{x}_{N'_x}\|_1] \ne 0) \leq r \}`,
    where :math:`\mathbf{x}_i` is the :math:`i`-th column of the input
    reshaped into a :math:`N_{dim} \times N'_{x}` matrix.

    Parameters
    ----------
    ndim : :obj:`int`
        Number of dimensions :math:`N_{dim}`. Used to reshape the input array
        in a matrix of size :math:`N_{dim} \times N'_{x}` where
        :math:`N'_x = \frac{N_x}{N_{dim}}`. Note that the input vector ``x``
        should be created by stacking vectors from different dimensions.
    radius : :obj:`int` or :obj:`func`
        Radius. This can be a constant number or a function that is called
        passing a counter which keeps track of how many times the ``prox``
        method has been invoked before and returns a scalar ``radius`` to be
        used.

    Notes
    -----
    As the :math:`L_{1,0}` ball is an indicator function, the proximal
    operator corresponds to its orthogonal projection
    (see :class:`pyproximal.projection.L10BallProj` for details).

    """

    def __init__(self, ndim: int, radius: IntCallableLike) -> None:
        super().__init__(None, False)
        self.ndim = ndim
        self.radius = radius
        # A callable radius is evaluated with counter 0 to build the
        # projection; ``prox`` re-resolves it on every call.
        self.ball = L10BallProj(_current_radius(radius, 0))
        # Number of ``prox`` calls so far; passed to ``radius`` when callable.
        self.count = 0

    def __call__(self, x: NDArray, tol: float = 1e-4) -> bool:
        """Return ``True`` if at most ``radius`` columns of ``x`` are non-zero.

        ``tol`` is accepted for interface compatibility but is not used.
        """
        # ``-1`` infers N'_x from the total size, so this also works when
        # ``x`` is not 1-D (``len(x)`` would only see its first axis).
        x = x.reshape(self.ndim, -1)
        radius = _current_radius(self.radius, self.count)
        column_norms = np.linalg.norm(x, ord=1, axis=0)
        return bool(np.count_nonzero(column_norms) <= radius)

    def _increment_count(func: Callable[..., Any]) -> Callable[..., Any]:
        """Increment ``self.count`` before calling the decorated method."""

        @wraps(func)
        def wrapped(self: Self, *args: Any, **kwargs: Any) -> Any:
            self.count += 1
            return func(self, *args, **kwargs)

        return wrapped

    @_increment_count
    @_check_tau
    def prox(self, x: NDArray, tau: float) -> NDArray:
        # ``tau`` is not used beyond the ``_check_tau`` decorator: the prox of
        # an indicator function is the projection onto the set.
        x = x.reshape(self.ndim, -1)
        radius = _current_radius(self.radius, self.count)
        # Update the projection in place so a callable radius takes effect.
        self.ball.radius = radius
        return self.ball(x).ravel()
class L01Ball(L10Ball):
    """Deprecated alias of :class:`L10Ball`.

    Behaves exactly like :class:`L10Ball` but emits a ``FutureWarning`` on
    construction, because the class was renamed.
    """

    def __init__(self, ndim: int, radius: IntCallableLike) -> None:
        message = (
            "The L01Ball class has been renamed L10Ball due "
            "to a mistake in the original choice of the name. As such "
            "L01Ball will be deprecated in v1.0.0."
        )
        # stacklevel=2 attributes the warning to the code constructing L01Ball.
        warnings.warn(message, FutureWarning, stacklevel=2)
        super().__init__(ndim, radius)