import warnings
from typing import Any, Callable, Union
import numpy as np
from pylops.utils.typing import NDArray
from typing_extensions import Self
from pyproximal.projection import L0BallProj, L10BallProj
from pyproximal.proximal.L1 import _current_sigma
from pyproximal.ProxOperator import ProxOperator, _check_tau
from pyproximal.utils.typing import FloatCallableLike, IntCallableLike
def _hardthreshold(x: NDArray, thresh: float) -> NDArray:
r"""Hard thresholding.
Applies hard thresholding to vector ``x`` (equal to the proximity
operator for :math:`\|\mathbf{x}\|_0`) as shown in [1]_.
.. [1] Chen, F., Shen, L., Suter, B.W., "Computing the proximity
operator of the Lp norm with 0 < p < 1",
IET Signal Processing, 10, 2016.
Parameters
----------
x : :obj:`numpy.ndarray`
Vector
thresh : :obj:`float`
Threshold
Returns
-------
x1 : :obj:`numpy.ndarray`
Tresholded vector
"""
x1 = x.copy()
x1[np.abs(x) <= thresh] = 0
return x1
def _current_radius(
radius: IntCallableLike,
count: int,
) -> Union[int, NDArray]:
if not callable(radius):
return radius
else:
return radius(count)
[docs]class L0(ProxOperator):
r""":math:`L_0` norm proximal operator.
Proximal operator of the :math:`\ell_0` norm:
:math:`\sigma\|\mathbf{x}\|_0 = \text{count}(x_i \ne 0)`.
Parameters
----------
sigma : :obj:`float` or :obj:`np.ndarray` or :obj:`func`, optional
Multiplicative coefficient of L0 norm. This can be a constant number, a list
of values (for multidimensional inputs, acting on the second dimension) or
a function that is called passing a counter which keeps track of how many
times the ``prox`` method has been invoked before and returns a scalar (or a list of)
``sigma`` to be used.
Notes
-----
The :math:`\ell_0` proximal operator is defined as:
.. math::
\prox_{\tau \sigma \|\cdot\|_0}(\mathbf{x}) =
\operatorname{hard}(\mathbf{x}, \tau \sigma) =
\begin{cases}
x_i, & x_i < -\tau \sigma \\
0, & -\tau\sigma \leq x_i \leq \tau\sigma \\
x_i, & x_i > \tau\sigma\\
\end{cases}
where :math:`\operatorname{hard}` is the so-called called *hard thresholding*.
"""
def __init__(
self,
sigma: FloatCallableLike = 1.0,
) -> None:
super().__init__(None, False)
self.sigma = sigma
self.count = 0
def __call__(self, x: NDArray) -> int:
return int(np.sum(np.abs(x) > 0.0))
def _increment_count(func: Callable[..., Any]) -> Callable[..., Any]:
"""Increment counter"""
def wrapped(self: Self, *args: Any, **kwargs: Any) -> Any:
self.count += 1
return func(self, *args, **kwargs)
return wrapped
@_increment_count
@_check_tau
def prox(self, x: NDArray, tau: float) -> NDArray:
sigma = _current_sigma(self.sigma, self.count)
x = _hardthreshold(x, tau * sigma)
return x
[docs]class L0Ball(ProxOperator):
r""":math:`L_0` ball proximal operator.
Proximal operator of the L0 ball: :math:`L0_{r} =
\{ \mathbf{x}: ||\mathbf{x}||_0 \leq r \}`.
Parameters
----------
radius : :obj:`int` or :obj:`func`, optional
Radius. This can be a constant number or a function that is called passing a
counter which keeps track of how many times the ``prox`` method has been
invoked before and returns a scalar ``radius`` to be used.
Notes
-----
As the L0 ball is an indicator function, the proximal operator
corresponds to its orthogonal projection
(see :class:`pyproximal.projection.L0BallProj` for details.
"""
def __init__(self, radius: IntCallableLike) -> None:
super().__init__(None, False)
self.radius = radius
if callable(radius):
radius_resolved = radius(0)
else:
radius_resolved = radius
self.ball = L0BallProj(radius_resolved)
self.count = 0
def __call__(self, x: NDArray) -> bool:
radius = _current_radius(self.radius, self.count)
return bool(np.linalg.norm(np.abs(x), ord=0) <= radius)
def _increment_count(func: Callable[..., Any]) -> Callable[..., Any]:
"""Increment counter"""
def wrapped(self: Self, *args: Any, **kwargs: Any) -> Any:
self.count += 1
return func(self, *args, **kwargs)
return wrapped
@_increment_count
@_check_tau
def prox(self, x: NDArray, tau: float) -> NDArray:
radius = _current_radius(self.radius, self.count)
self.ball.radius = radius
y = self.ball(x)
return y
[docs]class L10Ball(ProxOperator):
r""":math:`L_{1,0}` ball proximal operator.
Proximal operator of the :math:`L_{1,0}` ball: :math:`L_{1,0}^{r} =
\{ \mathbf{x}: \text{count}([||\mathbf{x}_1||_1, ||\mathbf{x}_2||_1, ...,
||\mathbf{x}_1||_1] \ne 0) \leq r \}`
Parameters
----------
ndim : :obj:`int`
Number of dimensions :math:`N_{dim}`. Used to reshape the input array
in a matrix of size :math:`N_{dim} \times N'_{x}` where
:math:`N'_x = \frac{N_x}{N_{dim}}`. Note that the input
vector ``x`` should be created by stacking vectors from different
dimensions.
radius : :obj:`int` or :obj:`func`, optional
Radius. This can be a constant number or a function that is called passing a
counter which keeps track of how many times the ``prox`` method has been
invoked before and returns a scalar ``radius`` to be used.
Notes
-----
As the :math:`L_{1,0}` ball is an indicator function, the proximal operator
corresponds to its orthogonal projection
(see :class:`pyproximal.projection.L10BallProj` for details.
"""
def __init__(self, ndim: int, radius: IntCallableLike) -> None:
super().__init__(None, False)
self.ndim = ndim
self.radius = radius
if callable(radius):
radius_resolved = radius(0)
else:
radius_resolved = radius
self.ball = L10BallProj(radius_resolved)
self.count = 0
def __call__(self, x: NDArray, tol: float = 1e-4) -> bool:
x = x.reshape(self.ndim, len(x) // self.ndim)
radius = _current_radius(self.radius, self.count)
return bool(np.linalg.norm(np.linalg.norm(x, ord=1, axis=0), ord=0) <= radius)
def _increment_count(func: Callable[..., Any]) -> Callable[..., Any]:
"""Increment counter"""
def wrapped(self: Self, *args: Any, **kwargs: Any) -> Any:
self.count += 1
return func(self, *args, **kwargs)
return wrapped
@_increment_count
@_check_tau
def prox(self, x: NDArray, tau: float) -> NDArray:
x = x.reshape(self.ndim, len(x) // self.ndim)
radius = _current_radius(self.radius, self.count)
self.ball.radius = radius
y = self.ball(x)
return y.ravel()
class L01Ball(L10Ball):
def __init__(self, ndim: int, radius: IntCallableLike) -> None:
warnings.warn(
"The L01Ball class has been renamed L10Ball due "
"to a mistake in the original choice of the name. As such "
"L01Ball will be deprecated in v1.0.0.",
FutureWarning,
)
super().__init__(ndim, radius)