# Source code for pyproximal.proximal.RelaxedMS

from collections.abc import Callable
from functools import wraps
from typing import Any

import numpy as np
from pylops.utils.typing import NDArray
from typing_extensions import Self

from pyproximal.proximal.L1 import _current_sigma
from pyproximal.ProxOperator import ProxOperator, _check_tau
from pyproximal.utils.typing import FloatCallableLike


def _l2(x: NDArray, alpha: float) -> NDArray:
    r"""Scaling operation.

    Applies the proximal of ``alpha||y - x||_2^2`` which is essentially a scaling operation.

    Parameters
    ----------
    x : :obj:`numpy.ndarray`
        Vector
    alpha : :obj:`float`
        Scaling parameter

    Returns
    -------
    y : :obj:`numpy.ndarray`
        Scaled vector

    """
    y = 1 / (1 + 2 * alpha) * x
    return y


class RelaxedMumfordShah(ProxOperator):
    r"""Relaxed Mumford-Shah norm proximal operator.

    Proximal operator of the relaxed Mumford-Shah norm, applied
    element-wise: :math:`\text{rMS}(\mathbf{x}) = \sum_i \min (\sigma
    \vert x_i\vert^2, \kappa)`.

    Parameters
    ----------
    sigma : :obj:`float` or :obj:`numpy.ndarray` or :obj:`func`, optional
        Multiplicative coefficient of the squared L2 term that controls the
        smoothness of the solution. This can be a constant number, a list of
        values (for multidimensional inputs, acting on the second dimension)
        or a function that is called passing a counter and returns a scalar
        (or a list of) ``sigma`` to be used. Inside ``prox`` the counter
        equals the number of ``prox`` invocations including the current one;
        inside ``__call__`` it equals the number of completed ``prox`` calls.
    kappa : :obj:`float` or :obj:`numpy.ndarray` or :obj:`func`, optional
        Constant value in the rMS norm which essentially controls when the
        norm allows a jump. This can be a constant number, a list of values
        (for multidimensional inputs, acting on the second dimension) or a
        function that is called passing the same counter as for ``sigma`` and
        returns a scalar (or a list of) ``kappa`` to be used.

    Notes
    -----
    The :math:`rMS` proximal operator is separable and is defined
    element-wise as [1]_:

    .. math::
        \text{prox}_{\tau \text{rMS}}(\mathbf{x})_i =
        \begin{cases}
        \frac{1}{1+2\tau\sigma}x_i & \text{ if } & \vert x_i\vert \leq
        \sqrt{\frac{\kappa}{\sigma}(1 + 2\tau\sigma)} \\
        x_i & \text{ else }
        \end{cases}.

    .. [1] Strekalovskiy, E., and D. Cremers, 2014, Real-time minimization of
       the piecewise smooth Mumford-Shah functional: European Conference on
       Computer Vision, 127–141.

    """

    def __init__(
        self,
        sigma: FloatCallableLike = 1.0,
        kappa: FloatCallableLike = 1.0,
    ) -> None:
        super().__init__(None, False)
        self.sigma = sigma
        self.kappa = kappa
        # Number of prox invocations so far; passed to callable sigma/kappa.
        self.count = 0

    def __call__(self, x: NDArray) -> float:
        """Evaluate the rMS functional at ``x``.

        Each element contributes ``min(sigma * |x_i|**2, kappa)`` and the
        contributions are summed. This separable form is the functional whose
        proximal operator :meth:`prox` computes.
        """
        sigma = _current_sigma(self.sigma, self.count)
        kappa = _current_sigma(self.kappa, self.count)
        return float(np.sum(np.minimum(sigma * np.abs(x) ** 2, kappa)))

    def _increment_count(func: Callable[..., Any]) -> Callable[..., Any]:
        """Decorator: increment ``self.count`` before calling ``func``.

        Defined in the class body so it can decorate ``prox`` below; it is
        not meant to be called on instances.
        """

        @wraps(func)
        def wrapped(self: Self, *args: Any, **kwargs: Any) -> Any:
            self.count += 1
            return func(self, *args, **kwargs)

        return wrapped

    @_increment_count
    @_check_tau
    def prox(self, x: NDArray, tau: float) -> NDArray:
        """Element-wise proximal operator of ``tau * rMS`` at ``x``."""
        sigma = _current_sigma(self.sigma, self.count)
        kappa = _current_sigma(self.kappa, self.count)
        # Quadratic branch wins where |x_i| <= sqrt(kappa/sigma*(1+2*tau*sigma)).
        # Squared and multiplied through by sigma so that sigma = 0 does not
        # divide by zero; for sigma > 0 the two comparisons are equivalent.
        use_quadratic = sigma * np.abs(x) ** 2 <= kappa * (1 + 2 * tau * sigma)
        return np.where(use_quadratic, _l2(x, tau * sigma), x)