# Source code for pyproximal.ProxOperator

import functools
from typing import TYPE_CHECKING, Any, Callable, Optional, Union

import numpy as np
from pylops.utils.typing import NDArray

from pyproximal.utils.backend import cp_dtype

if TYPE_CHECKING:
    from pylops.linearoperator import LinearOperator


def _check_tau(func: Callable[..., NDArray]) -> Callable[..., NDArray]:
    """Check that tau>0

    This utility function is used to decorate every prox and dualprox method
    to check that tau is positive before performing any computation

    """

    def wrapper(*args: Any, **kwargs: Any) -> Any:
        if np.any(args[2] <= 0):
            raise ValueError("tau must be positive")
        return func(*args, **kwargs)

    return wrapper


class ProxOperator(object):
    r"""Common interface for proximal operators of a function.

    This class defines the overarching structure of any proximal operator. It
    contains two main methods, ``prox`` and ``proxdual``, which are both
    implemented by means of the Moreau decomposition assuming explicit
    knowledge of the other method. For this reason any proximal operator
    that subclasses the ``ProxOperator`` class needs at least one of these
    two methods to be implemented directly.

    Moreover, the method ``grad`` is also defined to compute the gradient of
    the Moreau envelope of the function. This function is only called if the
    user does not provide a gradient function when creating the proximal
    operator. The variable ``hasgrad`` is used to indicate if the function
    has a gradient or not (and thus if the ``grad`` method computes the
    gradient of the actual function or of its Moreau envelope).

    .. note:: End users of PyProximal should not use this class directly but
       simply use operators that are already implemented. This class is meant
       for developers and it has to be used as the parent class of any new
       operator developed within PyProximal. Find more details regarding
       implementation of new operators at :ref:`addingoperator`.

    Parameters
    ----------
    Op : :obj:`pylops.LinearOperator`, optional
        Linear operator used by the Proximal operator
    hasgrad : :obj:`bool`, optional
        Flag to indicate if the function is differentiable, i.e., has a
        uniquely defined gradient (``True``) or not (``False``).
    sigmame : :obj:`float`, optional
        Relaxation parameter of the Moreau envelope (when ``sigmame`` tends
        to zero the Moreau envelope tends to the function itself, and so does
        its gradient for differentiable functions). Refer to the docstring of
        the ``grad`` method for more details.

    Notes
    -----
    The proximal operator of a function ``f`` is defined as:

    .. math::

        prox_{\tau f} (\mathbf{x}) = \argmin_{\mathbf{y}} f(\mathbf{y}) +
        \frac{1}{2 \tau}||\mathbf{y} - \mathbf{x}||^2_2

    """

    def __init__(
        self,
        Op: Optional["LinearOperator"] = None,
        hasgrad: bool = False,
        sigmame: float = 1.0,
    ) -> None:
        self.Op = Op
        self.hasgrad = hasgrad
        self.sigmame = sigmame

    def __call__(self, x: NDArray) -> bool | float | int:
        """Functional evaluation of the operator.

        Subclasses should implement this. Returns the value of the function.
        """
        raise NotImplementedError(
            "This ProxOperator's __call__ method "
            "must be implemented by subclasses to return a float."
        )

    @_check_tau
    def _prox_moreau(self, x: NDArray, tau: float, **kwargs: Any) -> NDArray:
        r"""Proximal operator applied to a vector via Moreau decomposition

        Computes :math:`\mathbf{x} - \tau\, prox_{f^*/\tau}(\mathbf{x}/\tau)`
        by calling ``self.proxdual``.
        """
        p = x - tau * self.proxdual(x / tau, 1.0 / tau, **kwargs)
        return p

    @_check_tau
    def _proxdual_moreau(self, x: NDArray, tau: float, **kwargs: Any) -> NDArray:
        r"""Dual proximal operator applied to a vector via Moreau decomposition

        Computes :math:`\mathbf{x} - \tau\, prox_{f/\tau}(\mathbf{x}/\tau)`
        by calling ``self.prox``.
        """
        pdual = x - tau * self.prox(x / tau, 1.0 / tau, **kwargs)
        return pdual

    @_check_tau
    def prox(self, x: NDArray, tau: float, **kwargs: Any) -> NDArray:
        """Proximal operator applied to a vector

        The proximal operator can always be computed given its dual
        proximal operator using the Moreau decomposition as defined in
        :func:`pyproximal.moreau`. For this reason we can easily create a
        common method for all proximal operators that can be evaluated
        provided the dual proximal is implemented.

        However, direct implementations are generally available. This can be
        done by simply implementing ``prox`` for a specific proximal operator,
        which will overwrite the general method.

        Parameters
        ----------
        x : :obj:`numpy.ndarray`
            Vector
        tau : :obj:`float`
            Positive scalar weight

        """
        return self._prox_moreau(x, tau, **kwargs)

    @_check_tau
    def proxdual(self, x: NDArray, tau: float, **kwargs: Any) -> NDArray:
        """Dual proximal operator applied to a vector

        The dual of a proximal operator can always be computed given its
        proximal operator using the Moreau decomposition as defined in
        :func:`pyproximal.moreau`. For this reason we can easily create a
        common method for all dual proximal operators that can be evaluated
        provided the proximal is implemented.

        However, since the dual of a proximal operator of a function is
        equivalent to the proximal operator of the conjugate function, smarter
        and faster implementation may be available in special cases. This can
        be done by simply implementing ``proxdual`` for a specific proximal
        operator, which will overwrite the general method.

        Parameters
        ----------
        x : :obj:`numpy.ndarray`
            Vector
        tau : :obj:`float`
            Positive scalar weight

        """
        return self._proxdual_moreau(x, tau, **kwargs)

    def grad(self, x: NDArray) -> NDArray:
        r"""Gradient of the Moreau envelope of the function.

        This method is only called if the user does not provide a gradient
        because the function is not differentiable. In this case, the
        gradient of the Moreau envelope of the function is computed instead,
        with :math:`\sigma` equal to ``self.sigmame``:

        .. math::

            \nabla_\mathbf{x} M_{\sigma f}(\mathbf{x}) = \frac{1}{\sigma}
            (\mathbf{x} - prox_{\sigma f}(\mathbf{x}))

        Parameters
        ----------
        x : :obj:`numpy.ndarray`
            Vector

        Returns
        -------
        g : :obj:`numpy.ndarray`
            Gradient vector

        """
        g = (x - self.prox(x, self.sigmame)) / self.sigmame
        return g

    def affine_addition(self, v: NDArray) -> "ProxOperator":
        r"""Affine addition

        Adds the dot-product of vector ``v`` and vector ``x`` (which is passed
        to ``prox`` or ``proxdual``) to the current function.

        This method can also be accessed via the ``+`` operator.

        Parameters
        ----------
        v : :obj:`numpy.ndarray`
            Vector

        Raises
        ------
        NotImplementedError
            If ``v`` is not a numpy or cupy array

        Notes
        -----
        The proximal operator of a function
        :math:`g=f(\mathbf{x}) + \mathbf{v}^T \mathbf{x}` is defined as:

        .. math::

            prox_{\tau g} (\mathbf{x}) =
            prox_{\tau f} (\mathbf{x} - \tau \mathbf{v})

        """
        if isinstance(v, (np.ndarray, cp_dtype)):
            return _SumOperator(self, v)
        else:
            raise NotImplementedError("v must be a numpy.ndarray or cupy.ndarray")

    def postcomposition(self, sigma: float) -> "ProxOperator":
        r"""Postcomposition

        Multiplies a scalar ``sigma`` to the current function.

        This method can also be accessed via the ``*`` operator (as
        ``sigma * f``).

        Parameters
        ----------
        sigma : :obj:`float`
            Scalar. Integers are also accepted and converted to float.

        Raises
        ------
        NotImplementedError
            If ``sigma`` is not a float or an int

        Notes
        -----
        The proximal operator of a function :math:`g= \sigma f(\mathbf{x})`
        is defined as:

        .. math::

            prox_{\tau g} (\mathbf{x}) = prox_{\sigma \tau f} (\mathbf{x})

        """
        # bool is a subclass of int but is not a meaningful scale factor
        if isinstance(sigma, (int, float)) and not isinstance(sigma, bool):
            return _PostcompositionOperator(self, float(sigma))
        else:
            raise NotImplementedError("sigma must be of type float")

    def precomposition(self, a: float, b: float | NDArray) -> "ProxOperator":
        r"""Precomposition

        Multiplies and add scalars ``a`` and ``b`` to ``x`` when evaluating
        the proximal function

        Parameters
        ----------
        a : :obj:`float`
            Multiplicative scalar
        b : :obj:`float` or :obj:`numpy.ndarray` or :obj:`cupy.ndarray`
            Additive scalar (or vector)

        Raises
        ------
        NotImplementedError
            If ``a`` is not a float or ``b`` is not a float or array

        Notes
        -----
        The proximal operator of a function :math:`g= f(a \mathbf{x} + b)`
        is defined as:

        .. math::

            prox_{\tau g} (\mathbf{x}) = \frac{1}{a} (
            prox_{a^2 \tau f} (a \mathbf{x} + b) - b)

        """
        if isinstance(a, float) and isinstance(b, (float, np.ndarray, cp_dtype)):  # type: ignore[redundant-expr]
            return _PrecompositionOperator(self, a, b)
        else:
            raise NotImplementedError(
                "a must be of type float and b "
                "must be of type float or "
                "numpy.ndarray"
            )

    def chain(self, g: "ProxOperator") -> "ProxOperator":
        r"""Chain

        Chains two proximal operators. This must be used with care only when
        aware that the combination of two proximal operators can be simply
        obtained by chaining them.

        Parameters
        ----------
        g : :obj:`pyproximal.proximal.ProxOperator`
            Right operator

        Notes
        -----
        The proximal operator of the chain of two operators is defined as:

        .. math::

            prox_{\tau f g} (\mathbf{x}) =
            prox_{\tau g}(prox_{\tau f}(\mathbf{x}))

        """
        return _ChainOperator(self, g)

    def __add__(self, v: NDArray) -> "ProxOperator":
        return self.affine_addition(v)

    def __sub__(self, v: NDArray) -> "ProxOperator":
        return self.__add__(-v)

    def __rmul__(self, sigma: Union[float, int, "ProxOperator"]) -> "ProxOperator":
        # ``sigma * self``: scalars scale the function, while another
        # ProxOperator is chained. There is deliberately no ``__mul__``,
        # because for ``g * f`` between two operators Python falls back to
        # ``f.__rmul__(g)``, i.e. ``f.chain(g)``.
        if isinstance(sigma, (int, float)):
            return self.postcomposition(sigma)
        else:
            return self.chain(sigma)

    def _adjoint(self) -> "_AdjointOperator":
        """Adjoint operator - swaps prox and proxdual"""
        return _AdjointOperator(self)

    H = property(_adjoint)
class _AdjointOperator(ProxOperator):
    """Operator whose ``prox`` and ``proxdual`` are swapped with respect to
    the wrapped operator ``f``.

    Parameters
    ----------
    f : :obj:`pyproximal.proximal.ProxOperator`
        Wrapped operator

    """

    def __init__(self, f: "ProxOperator") -> None:
        self.f = f
        super().__init__(None, f.hasgrad)

    def __call__(self, x: NDArray) -> bool | float | int:
        # NOTE(review): this evaluates the wrapped function ``f`` itself, not
        # its convex conjugate, even though prox/proxdual are swapped —
        # confirm this is intended.
        return self.f(x)

    @_check_tau
    def prox(self, x: NDArray, tau: float, **kwargs: Any) -> NDArray:
        return self.f.proxdual(x, tau, **kwargs)

    @_check_tau
    def proxdual(self, x: NDArray, tau: float, **kwargs: Any) -> NDArray:
        return self.f.prox(x, tau, **kwargs)


class _SumOperator(ProxOperator):
    r"""Affine addition :math:`g(\mathbf{x}) = f(\mathbf{x}) +
    \mathbf{v}^T \mathbf{x}`.

    Parameters
    ----------
    f : :obj:`pyproximal.proximal.ProxOperator`
        Wrapped operator
    v : :obj:`numpy.ndarray` or :obj:`cupy.ndarray`
        Vector of the linear term

    Raises
    ------
    ValueError
        If ``v`` is not a numpy or cupy array

    """

    def __init__(self, f: ProxOperator, v: NDArray) -> None:
        if not isinstance(v, (np.ndarray, cp_dtype)):
            raise ValueError("Second input must be a numpy.ndarray or cupy.ndarray")
        self.f, self.v = f, v
        super().__init__(None, f.hasgrad)

    def __call__(self, x: NDArray) -> float:
        # Inner product over all entries. Ravel both operands so that
        # N-dimensional arrays give v^T x rather than a matrix product
        # (for 1-D inputs this is identical to np.dot(v, x)).
        f: float = self.f(x) + np.dot(self.v.ravel(), x.ravel())
        return f

    @_check_tau
    def prox(self, x: NDArray, tau: float, **kwargs: Any) -> NDArray:
        # prox_{tau g}(x) = prox_{tau f}(x - tau v)
        return self.f.prox(x - tau * self.v, tau, **kwargs)

    def grad(self, x: NDArray) -> NDArray:
        return self.f.grad(x) + self.v


class _ChainOperator(ProxOperator):
    """Chain of two operators, whose proximal is
    ``g.prox(f.prox(x, tau), tau)``.

    Parameters
    ----------
    f : :obj:`pyproximal.proximal.ProxOperator`
        Operator applied first
    g : :obj:`pyproximal.proximal.ProxOperator`
        Operator applied second

    """

    def __init__(self, f: ProxOperator, g: ProxOperator) -> None:
        self.f, self.g = f, g
        super().__init__(None, f.hasgrad and g.hasgrad)

    def __call__(self, x: NDArray) -> bool | float | int:
        return self.g(self.f(x))

    @_check_tau
    def prox(self, x: NDArray, tau: float, **kwargs: Any) -> NDArray:
        # Extra keyword arguments are not forwarded: it is ambiguous which of
        # the two chained operators they would be meant for.
        return self.g.prox(self.f.prox(x, tau), tau)

    def grad(self, x: NDArray) -> NDArray:
        # No closed-form gradient is available for a chain of proximal
        # operators, so fall back to the gradient of the Moreau envelope
        # computed by ProxOperator.grad from self.prox.
        return super().grad(x)


class _PostcompositionOperator(ProxOperator):
    r"""Postcomposition :math:`g(\mathbf{x}) = \sigma f(\mathbf{x})`.

    Parameters
    ----------
    f : :obj:`pyproximal.proximal.ProxOperator`
        Wrapped operator
    sigma : :obj:`float`
        Scalar multiplier (ints are accepted and converted to float)

    Raises
    ------
    ValueError
        If ``sigma`` is not a float or an int

    """

    def __init__(self, f: ProxOperator, sigma: float) -> None:
        # bool is a subclass of int but is not a meaningful scale factor
        if isinstance(sigma, bool) or not isinstance(sigma, (int, float)):
            raise ValueError("Second input must be a float")
        self.f, self.sigma = f, float(sigma)
        super().__init__(None, f.hasgrad)

    def __call__(self, x: NDArray) -> bool | float | int:
        return self.sigma * self.f(x)

    @_check_tau
    def prox(self, x: NDArray, tau: float, **kwargs: Any) -> NDArray:
        # prox_{tau g}(x) = prox_{sigma tau f}(x)
        return self.f.prox(x, self.sigma * tau, **kwargs)

    def grad(self, x: NDArray) -> NDArray:
        return self.sigma * self.f.grad(x)


class _PrecompositionOperator(ProxOperator):
    r"""Precomposition :math:`g(\mathbf{x}) = f(a \mathbf{x} + b)`.

    Parameters
    ----------
    f : :obj:`pyproximal.proximal.ProxOperator`
        Wrapped operator
    a : :obj:`float`
        Non-zero multiplicative scalar
    b : :obj:`float` or :obj:`numpy.ndarray` or :obj:`cupy.ndarray`
        Additive scalar (or vector)

    Raises
    ------
    ValueError
        If ``a`` is not a non-zero float or ``b`` has an unsupported type

    """

    def __init__(self, f: ProxOperator, a: float, b: float | NDArray) -> None:
        if not isinstance(a, float):
            raise ValueError("Second input must be a float")
        if a == 0.0:
            # prox divides by ``a``
            raise ValueError("Second input must be non-zero")
        if not isinstance(b, (float, np.ndarray, cp_dtype)):
            raise ValueError(
                "Third input must be a float, numpy.ndarray, or cupy.ndarray"
            )
        self.f, self.a, self.b = f, a, b
        super().__init__(None, f.hasgrad)

    def __call__(self, x: NDArray) -> bool | float | int:
        return self.f(self.a * x + self.b)

    @_check_tau
    def prox(self, x: NDArray, tau: float, **kwargs: Any) -> NDArray:
        # prox_{tau g}(x) = (prox_{a^2 tau f}(a x + b) - b) / a
        return (
            self.f.prox(self.a * x + self.b, (self.a**2) * tau, **kwargs) - self.b
        ) / self.a

    def grad(self, x: NDArray) -> NDArray:
        return self.a * self.f.grad(self.a * x + self.b)