# Source code for pyproximal.proximal.RelaxedMS
from typing import Any, Callable, Union
import numpy as np
from pylops.utils.typing import NDArray
from typing_extensions import Self
from pyproximal.proximal.L1 import _current_sigma
from pyproximal.ProxOperator import ProxOperator, _check_tau
from pyproximal.utils.typing import FloatCallableLike
def _l2(x: NDArray, alpha: float) -> NDArray:
r"""Scaling operation.
Applies the proximal of ``alpha||y - x||_2^2`` which is essentially a scaling operation.
Parameters
----------
x : :obj:`numpy.ndarray`
Vector
alpha : :obj:`float`
Scaling parameter
Returns
-------
y : :obj:`numpy.ndarray`
Scaled vector
"""
y = 1 / (1 + 2 * alpha) * x
return y
def _current_kappa(
kappa: FloatCallableLike,
count: int,
) -> Union[float, NDArray]:
if not callable(kappa):
return kappa
else:
return kappa(count)
class RelaxedMumfordShah(ProxOperator):
    r"""Relaxed Mumford-Shah norm proximal operator.

    Proximal operator of the relaxed Mumford-Shah norm, a truncated quadratic
    applied element-wise:
    :math:`\text{rMS}(x) = \sum_i \min (\sigma x_i^2, \kappa)`.

    Parameters
    ----------
    sigma : :obj:`float` or :obj:`np.ndarray` or :obj:`func`, optional
        Multiplicative coefficient of the L2 norm that controls the smoothness
        of the solution. This can be a constant number, a list of values (for
        multidimensional inputs, acting on the second dimension) or a function
        that is called passing a counter which keeps track of how many times
        the ``prox`` method has been invoked before and returns a scalar (or a
        list of) ``sigma`` to be used.
    kappa : :obj:`float` or :obj:`np.ndarray` or :obj:`func`, optional
        Constant value in the rMS norm which essentially controls when the norm
        allows a jump. This can be a constant number, a list of values (for
        multidimensional inputs, acting on the second dimension) or a function
        that is called passing a counter which keeps track of how many times
        the ``prox`` method has been invoked before and returns a scalar (or a
        list of) ``kappa`` to be used.

    Notes
    -----
    The :math:`rMS` proximal operator is defined element-wise as [1]_:

    .. math::

        \text{prox}_{\tau \text{rMS}}(x_i) =
        \begin{cases}
        \frac{1}{1+2\tau\sigma}x_i & \text{if }
            \vert x_i\vert \leq \sqrt{\frac{\kappa}{\sigma}(1 + 2\tau\sigma)} \\
        x_i & \text{otherwise}
        \end{cases}.

    .. [1] Strekalovskiy, E., and D. Cremers, 2014, Real-time minimization of
       the piecewise smooth Mumford-Shah functional: European Conference on
       Computer Vision, 127–141.
    """

    def __init__(
        self,
        sigma: FloatCallableLike = 1.0,
        kappa: FloatCallableLike = 1.0,
    ) -> None:
        super().__init__(None, False)
        self.sigma = sigma
        self.kappa = kappa
        # Number of ``prox`` invocations so far; passed to callable
        # ``sigma``/``kappa`` to obtain iteration-dependent values.
        self.count = 0

    def __call__(self, x: NDArray) -> float:
        """Evaluate :math:`\\sum_i \\min(\\sigma x_i^2, \\kappa)` at ``x``."""
        sigma = _current_sigma(self.sigma, self.count)
        kappa = _current_kappa(self.kappa, self.count)
        # ``prox`` acts element-wise, so the matching function value is the
        # sum of per-element truncated quadratics. Summing after the
        # element-wise minimum also keeps the result scalar when
        # sigma/kappa are arrays.
        return float(np.sum(np.minimum(sigma * np.abs(x) ** 2, kappa)))

    def _increment_count(func: Callable[..., Any]) -> Callable[..., Any]:
        """Decorator incrementing ``self.count`` before calling ``func``."""
        import functools

        @functools.wraps(func)
        def wrapped(self: Self, *args: Any, **kwargs: Any) -> Any:
            self.count += 1
            return func(self, *args, **kwargs)

        return wrapped

    @_increment_count
    @_check_tau
    def prox(self, x: NDArray, tau: float) -> NDArray:
        """Element-wise proximal operator of ``tau * rMS`` evaluated at ``x``.

        Parameters
        ----------
        x : :obj:`numpy.ndarray`
            Input vector
        tau : :obj:`float`
            Positive scalar weight

        Returns
        -------
        x : :obj:`numpy.ndarray`
            Output vector
        """
        sigma = _current_sigma(self.sigma, self.count)
        kappa = _current_kappa(self.kappa, self.count)
        # The quadratic branch wins where sigma*|x|^2 <= kappa*(1 + 2*tau*sigma).
        # For sigma > 0 this is the squared form of
        # |x| <= sqrt(kappa/sigma*(1 + 2*tau*sigma)), but it avoids dividing
        # by sigma. With sigma = 0, both branches return x.
        shrink = sigma * np.abs(x) ** 2 <= kappa * (1 + 2 * tau * sigma)
        return np.where(shrink, _l2(x, tau * sigma), x)