# Source code for pyproximal.proximal.RED

from collections.abc import Callable
from functools import wraps
from typing import Any

from pylops.utils.backend import get_array_module
from pylops.utils.typing import NDArray, ShapeLike
from typing_extensions import Self

from pyproximal.proximal.L1 import _current_sigma
from pyproximal.ProxOperator import ProxOperator, _check_tau
from pyproximal.utils.typing import FloatCallableLike


class _Denoise:
    r"""Adapt a user-supplied denoiser to flat (1-D) vectors.

    The input vector is reshaped to ``dims`` before being handed to the
    wrapped ``denoiser``. The denoiser's output is flattened again, so
    callers can keep working with flat vectors.

    Parameters
    ----------
    denoiser : :obj:`func`
        Denoiser taking two positional inputs: the signal to be denoised
        (already reshaped to ``dims``) and the strength of the denoiser
        ``sigma``
    dims : :obj:`tuple`
        Shape that the flat vector ``x`` is reshaped to before calling
        ``denoiser``

    """

    def __init__(
        self,
        denoiser: Callable[[NDArray, float], NDArray],
        dims: ShapeLike,
    ) -> None:
        self.dims = dims
        self.denoiser = denoiser

    def __call__(self, x: NDArray, tau: float) -> NDArray:
        """Denoise ``x`` with strength ``tau`` and return the result flattened."""
        return self.denoiser(x.reshape(self.dims), tau).ravel()


class RED(ProxOperator):
    r"""Regularization by Denoising (RED)

    Regularization by Denoising loss [1]_:
    :math:`RED(\mathbf{x}) = \sigma\mathbf{x}^T (\mathbf{x} - f_{\sigma_d}(\mathbf{x}))`

    Parameters
    ----------
    denoiser : :obj:`func`
        Denoiser (must be a function with two inputs: the first is the signal
        to be denoised, already reshaped to ``dims``; the second is the
        strength of the denoiser ``sigmad``)
    dims : :obj:`tuple`
        Dimensions used to reshape the vector ``x`` prior to applying the
        denoiser
    sigma : :obj:`float`, optional
        Multiplicative coefficient of RED term
    sigmad : :obj:`float` or :obj:`func`, optional
        Strength of the denoiser. This can be a constant number or a function
        that is called passing a counter which keeps track of how many times
        the ``grad`` or ``prox`` methods have been invoked (the counter is
        incremented before the current call is evaluated) and returns a
        scalar (or a list of) ``sigma`` to be used
    x0 : :obj:`numpy.ndarray`, optional
        Initial vector of iterative scheme used to compute the proximal
    niter : :obj:`int`, optional
        Number of iterations of iterative scheme used to compute the proximal
    warm : :obj:`bool`, optional
        Warm start (``True``) or not (``False``). Uses estimate from previous
        call of ``prox`` method.
    call : :obj:`bool`, optional
        Evaluate call method (``True``) or not (``False``)

    Notes
    -----
    The gradient of the RED loss is defined as:

    .. math::
        \nabla_\mathbf{x} RED(\mathbf{x}) =
        \sigma (\mathbf{x} - f_{\sigma_d}(\mathbf{x}))

    whilst the proximal operator is obtained by solving the minimization
    problem

    .. math::
        prox_{\tau RED} (\mathbf{x}) = \argmin_{\mathbf{y}} RED(\mathbf{y}) +
        \frac{1}{2 \tau}||\mathbf{y} - \mathbf{x}||^2_2

    via the following fixed-point iteration:

    .. math::
        \mathbf{y}^k = \frac{1}{\beta + \sigma} (\sigma f_{\sigma_d}(\mathbf{y}^{k-1})
        + \beta \mathbf{x})

    where :math:`\beta=1/\tau`.

    References
    ----------
    .. [1] Romano, Y., Elad, M., and Milanfar, P. "The Little Engine that
        Could Regularization by Denoising (RED)", SIAM Journal on Imaging
        Science. 2017.

    """

    def __init__(
        self,
        denoiser: Callable[[NDArray, float], NDArray],
        dims: ShapeLike,
        sigma: float = 1.0,
        sigmad: FloatCallableLike = 1.0,
        x0: NDArray | None = None,
        niter: int = 10,
        warm: bool = True,
        call: bool = True,
    ) -> None:
        super().__init__(None, False)
        # Wrap the user denoiser so it accepts/returns flat vectors
        self.denoiser = _Denoise(denoiser, dims=dims)
        self.sigma = sigma
        self.sigmad = sigmad
        self.x0 = x0
        self.niter = niter
        self.warm = warm
        self.call = call
        # Number of prox/grad invocations so far; passed to _current_sigma
        self.count = 0

    def __call__(self, x: NDArray) -> float:
        """Evaluate RED loss

        Parameters
        ----------
        x : :obj:`numpy.ndarray`
            Vector

        Returns
        -------
        :obj:`float`
            - return ``0.0`` immediately if ``call=False``
            - return dot-product of the input and residual if ``call=True``

        """
        if not self.call:
            return 0.0
        ncp = get_array_module(x)
        # Evaluating the loss does not advance the counter (no _increment_count)
        sigmad = _current_sigma(self.sigmad, self.count)
        res = self.sigma * (x - self.denoiser(x, sigmad))
        # NOTE(review): ``grad`` returns sigma * (x - f(x)). Under RED's
        # local-homogeneity / symmetric-Jacobian assumptions that is the
        # gradient of 0.5 * sigma * x^T (x - f(x)), i.e. half of the value
        # returned here. Confirm which scaling is intended.
        return float(ncp.dot(x, res))

    def _increment_count(func: Callable[..., Any]) -> Callable[..., Any]:
        """Decorator: increment ``self.count`` before delegating to ``func``."""

        # wraps keeps the decorated method's name/docstring (otherwise they
        # would all be reported as ``wrapped``)
        @wraps(func)
        def wrapped(self: Self, *args: Any, **kwargs: Any) -> Any:
            self.count += 1
            return func(self, *args, **kwargs)

        return wrapped

    @_increment_count
    @_check_tau
    def prox(self, x: NDArray, tau: float, **kwargs: Any) -> NDArray:
        """Proximal operator of ``tau * RED`` evaluated at ``x``.

        Approximated with ``niter`` fixed-point iterations
        ``y <- (sigma * f(y) + beta * x) / (sigma + beta)``, where
        ``beta = 1 / tau``.
        """
        ncp = get_array_module(x)
        beta = 1.0 / tau
        sigmad = _current_sigma(self.sigmad, self.count)

        # Starting guess: zeros, or the user-supplied / warm-start estimate
        sol = ncp.zeros_like(x) if self.x0 is None else self.x0

        # Fixed-point iterations
        for _ in range(self.niter):
            den = self.denoiser(sol, sigmad)
            sol = (self.sigma * den + beta * x) / (self.sigma + beta)

        if self.warm:
            # Reuse this estimate as the starting guess of the next call
            self.x0 = sol
        return sol

    @_increment_count
    def grad(self, x: NDArray) -> NDArray:
        """Gradient of the RED term: ``sigma * (x - f_sigmad(x))``."""
        sigmad = _current_sigma(self.sigmad, self.count)
        res = x - self.denoiser(x, sigmad)
        return self.sigma * res