Source code for pyproximal.proximal.L2

from collections.abc import Callable
from typing import TYPE_CHECKING, Any, Optional

import numpy as np
from pylops import Identity, MatrixMult
from pylops.optimization.basic import cg, lsqr
from pylops.optimization.leastsquares import regularized_inversion
from pylops.utils.backend import (
    get_array_module,
    get_module_name,
    get_normalize_axis_index,
)
from pylops.utils.typing import NDArray, ShapeLike
from scipy.linalg import cho_factor, cho_solve
from scipy.sparse.linalg import lsqr as sp_lsqr
from typing_extensions import Self

from pyproximal.ProxOperator import ProxOperator, _check_tau

if TYPE_CHECKING:
    from pylops.linearoperator import LinearOperator


class L2(ProxOperator):
    r"""L2 Norm proximal operator.

    The Proximal operator of the :math:`\ell_2` norm is defined as:
    :math:`f(\mathbf{x}) = \frac{\sigma}{2} ||\mathbf{Op}\mathbf{x} - \mathbf{b}||_2^2`
    and :math:`f_\alpha(\mathbf{x}) = f(\mathbf{x}) + \alpha \mathbf{q}^T\mathbf{x}`.

    Parameters
    ----------
    Op : :obj:`pylops.LinearOperator`, optional
        Linear operator
    b : :obj:`numpy.ndarray`, optional
        Data vector
    q : :obj:`numpy.ndarray`, optional
        Dot vector
    sigma : :obj:`float`, optional
        Multiplicative coefficient of L2 norm
    alpha : :obj:`float`, optional
        Multiplicative coefficient of dot product
    qgrad : :obj:`bool`, optional
        Add q term to gradient (``True``) or not (``False``)
    niter : :obj:`int` or :obj:`func`, optional
        Number of iterations of iterative scheme used to compute the proximal.
        This can be a constant number or a function that is called passing a
        counter which keeps track of how many times the ``prox`` method has
        been invoked before and returns the ``niter`` to be used.
    x0 : :obj:`numpy.ndarray`, optional
        Initial vector
    warm : :obj:`bool`, optional
        Warm start (``True``) or not (``False``). Uses estimate from previous
        call of ``prox`` method.
    solver : :obj:`str`, optional
        .. versionadded:: 0.11.0

        Name of solver to use with non-explicit operators:

        - ``legacy``: enforces the legacy behaviour where
          :py:func:`scipy.sparse.linalg.lsqr` is used with numpy data and
          :py:func:`pylops.optimization.basic.lsqr` is used with cupy data
          (both are applied to the normal equations);
        - ``cg`` to use :py:func:`pylops.optimization.basic.cg` on the
          normal equations;
        - ``cgls`` to use :py:func:`pylops.optimization.basic.cgls` on the
          regularized system of equations;

    densesolver : :obj:`str`, optional
        Use ``numpy``, ``scipy``, or ``factorize`` when dealing with explicit
        operators. The former two rely on dense solvers from either library,
        whilst the last computes a factorization of the matrix to invert and
        only recomputes it when the :math:`\tau` or :math:`\sigma` parameters
        have changed. Choose ``densesolver=None`` when using PyLops versions
        earlier than v1.18.1 or v2.0.0
    kwargs_solver : :obj:`dict`, optional
        Dictionary containing extra arguments for the solver selected via
        the ``solver`` parameter.

    Raises
    ------
    ValueError
        If ``solver`` is not one of ``legacy``, ``cg``, or ``cgls``.

    Notes
    -----
    The L2 proximal operator is defined as:

    .. math::

        prox_{\tau f_\alpha}(\mathbf{x}) =
        \left(\mathbf{I} + \tau \sigma \mathbf{Op}^T \mathbf{Op} \right)^{-1}
        \left( \mathbf{x} + \tau \sigma \mathbf{Op}^T \mathbf{b} -
        \tau \alpha \mathbf{q}\right)

    when both ``Op`` and ``b`` are provided. This formula shows that the
    proximal operator requires the solution of an inverse problem. If the
    operator ``Op`` is of kind ``explicit=True``, we can solve this problem
    directly. On the other hand if ``Op`` is of kind ``explicit=False``, an
    iterative solver is employed. In this case it is possible to provide a warm
    start via the ``x0`` input parameter.

    Note that alternatively the proximal operator can be computed solving the
    following augmented system of equations (instead of its normal equations
    as shown above):

    .. math::

        \begin{bmatrix}
            \sqrt{\tau \sigma} \mathbf{Op} \\
            \mathbf{I}
        \end{bmatrix} prox_{\tau f_\alpha}(\mathbf{x}) =
        \begin{bmatrix}
            \sqrt{\tau \sigma} \mathbf{b} \\
            \mathbf{x} - \tau \alpha \mathbf{q}
        \end{bmatrix}

    The choice of the parameter ``solver`` determines which of the two
    approaches is used.

    Alternatively, when only ``b`` is provided, ``Op`` is assumed to be an
    Identity operator and the proximal operator reduces to:

    .. math::

        \prox_{\tau f_\alpha}(\mathbf{x}) =
        \frac{\mathbf{x} + \tau \sigma \mathbf{b} - \tau \alpha \mathbf{q}}
        {1 + \tau \sigma}

    If ``b`` is not provided, the proximal operator reduces to:

    .. math::

        \prox_{\tau f_\alpha}(\mathbf{x}) =
        \frac{\mathbf{x} - \tau \alpha \mathbf{q}}{1 + \tau \sigma}

    Finally, note that the second term in :math:`f_\alpha(\mathbf{x})` is added
    because this combined expression appears in several problems where Bregman
    iterations are used alongside a proximal solver.

    """

    def __init__(
        self,
        Op: Optional["LinearOperator"] = None,
        b: NDArray | None = None,
        q: NDArray | None = None,
        sigma: float = 1.0,
        alpha: float = 1.0,
        qgrad: bool = True,
        niter: int | Callable[[int], int] = 10,
        x0: NDArray | None = None,
        warm: bool = True,
        solver: str | None = "legacy",
        densesolver: str | None = None,
        kwargs_solver: dict[str, Any] | None = None,
    ) -> None:
        super().__init__(Op, True)
        self.b = b
        self.q = q
        self.sigma = sigma
        self.alpha = alpha
        self.qgrad = qgrad
        self.niter = niter
        self.x0 = x0
        self.warm = warm
        self.solver = solver
        self.densesolver = densesolver
        # number of times ``prox`` has been invoked (see _increment_count)
        self.count = 0
        self.kwargs_solver = {} if kwargs_solver is None else kwargs_solver

        # define whether the normal equations or the regularized system
        # of equations are solved
        if self.solver in ("legacy", "cg"):
            self.normaleqs = True
        elif self.solver == "cgls":
            self.normaleqs = False
        else:
            msg = (
                f"Provided solver={self.solver}. "
                "Available options are: 'legacy', 'cg', 'cgls'."
            )
            raise ValueError(msg)

        # when using factorize, store the first tau*sigma=0 so that the
        # first time it will be recomputed (as tau cannot be 0)
        if self.densesolver == "factorize":
            self.tausigma = 0.0

        # create data term
        if (
            self.Op is not None
            and self.b is not None
            and (self.Op.explicit or self.normaleqs)
        ):
            self.OpTb = self.sigma * self.Op.H @ self.b
            # create A.T A upfront for explicit operators
            if self.Op.explicit:
                self.ATA = np.conj(self.Op.A.T) @ self.Op.A

    def __call__(self, x: NDArray) -> float:
        """Evaluate the functional at ``x`` and return it as a float."""
        if self.Op is not None and self.b is not None:
            f = (self.sigma / 2.0) * (np.linalg.norm(self.Op * x - self.b) ** 2)
        elif self.b is not None:
            f = (self.sigma / 2.0) * (np.linalg.norm(x - self.b) ** 2)
        else:
            f = (self.sigma / 2.0) * (np.linalg.norm(x) ** 2)
        if self.q is not None:
            f += self.alpha * np.dot(self.q, x)
        return float(f)

    def _increment_count(func: Callable[..., Any]) -> Callable[..., Any]:
        """Increment counter"""

        def wrapped(self: Self, *args: Any, **kwargs: Any) -> Any:
            # bump the counter before running the wrapped method, so the
            # first invocation already sees ``self.count == 1``
            self.count += 1
            return func(self, *args, **kwargs)

        return wrapped

    @_increment_count
    @_check_tau
    def prox(self, x: NDArray, tau: float) -> NDArray:
        """Compute the proximal operator of the functional at ``x``.

        Parameters
        ----------
        x : :obj:`numpy.ndarray`
            Input vector. It is not updated in place by the arithmetic
            performed in this method.
        tau : :obj:`float`
            Positive scalar weight

        Returns
        -------
        x : :obj:`numpy.ndarray`
            Proximal point

        """
        # define current number of iterations: call ``niter`` when it is a
        # function of the invocation counter, otherwise use it as a fixed
        # count (this also accepts integer scalars that are not ``int``)
        niter = self.niter(self.count) if callable(self.niter) else self.niter

        # solve proximal optimization
        if self.Op is not None and self.b is not None:
            if self.normaleqs or self.Op.explicit:
                # right-hand side of the normal equations (new array)
                y = x + tau * self.OpTb
                if self.q is not None:
                    y -= tau * self.alpha * self.q
            if self.Op.explicit:
                if self.densesolver != "factorize":
                    Op1 = MatrixMult(
                        np.eye(self.Op.shape[1]) + tau * self.sigma * self.ATA
                    )
                    if self.densesolver is None:
                        # to allow backward compatibility with
                        # PyLops versions earlier than v1.18.1
                        # and v2.0.0
                        x = Op1.div(y)
                    else:
                        x = Op1.div(y, densesolver=self.densesolver)
                else:
                    if self.tausigma != tau * self.sigma:
                        # recompute factorization
                        self.tausigma = tau * self.sigma
                        ATA = np.eye(self.Op.shape[1]) + self.tausigma * self.ATA
                        self.cl = cho_factor(ATA)
                    x = cho_solve(self.cl, y)
            elif self.normaleqs:
                Op1 = Identity(self.Op.shape[1], dtype=self.Op.dtype) + float(
                    tau * self.sigma
                ) * (self.Op.H * self.Op)
                if self.solver == "legacy":
                    if get_module_name(get_array_module(x)) == "numpy":
                        x = sp_lsqr(
                            Op1, y, iter_lim=niter, x0=self.x0, **self.kwargs_solver
                        )[0]
                    else:
                        x = lsqr(
                            Op1, y, niter=niter, x0=self.x0, **self.kwargs_solver
                        )[0].ravel()
                elif self.solver == "cg":
                    x = cg(Op1, y, niter=niter, x0=self.x0, **self.kwargs_solver)[
                        0
                    ].ravel()
            else:
                y = x
                if self.q is not None:
                    # work on a copy: subtracting in place from ``x`` itself
                    # would overwrite the caller's array
                    y = x.copy()
                    y -= tau * self.alpha * self.q
                x = regularized_inversion(
                    np.sqrt(tau * self.sigma) * self.Op,
                    np.sqrt(tau * self.sigma) * self.b,
                    [
                        Identity(self.Op.shape[1], dtype=self.Op.dtype),
                    ],
                    x0=self.x0,
                    dataregs=[
                        y,
                    ],
                    niter=niter,
                    engine="pylops",
                    **self.kwargs_solver,
                )[0].ravel()
            if self.warm:
                # reuse this estimate as starting guess of the next call
                self.x0 = x
        elif self.b is not None:
            num = x + tau * self.sigma * self.b
            if self.q is not None:
                num -= tau * self.alpha * self.q
            x = num / (1.0 + tau * self.sigma)
        else:
            num = x
            if self.q is not None:
                # work on a copy: subtracting in place from ``x`` itself
                # would overwrite the caller's array
                num = x.copy()
                num -= tau * self.alpha * self.q
            x = num / (1.0 + tau * self.sigma)
        return x

    def grad(self, x: NDArray) -> NDArray:
        """Compute the gradient of the functional at ``x``.

        The ``q`` term is included only when ``qgrad=True``.
        """
        if self.Op is not None and self.b is not None:
            g = self.sigma * self.Op.H @ (self.Op @ x - self.b)
        elif self.b is not None:
            g = self.sigma * (x - self.b)
        else:
            g = self.sigma * x
        if self.q is not None and self.qgrad:
            g += self.alpha * self.q
        return g
class L2Convolve(ProxOperator):
    r"""L2 Norm proximal operator with convolution operator

    Proximal operator for the L2 norm defined as:
    :math:`f(\mathbf{x}) = \frac{\sigma}{2} ||\mathbf{h} * \mathbf{x} -
    \mathbf{b}||_2^2` where :math:`\mathbf{h}` is the kernel of a convolution
    operator and :math:`*` represents convolution

    Parameters
    ----------
    h : :obj:`numpy.ndarray`
        Kernel of convolution operator
    b : :obj:`numpy.ndarray`
        Data vector
    nfft : :obj:`int`, optional
        Fourier transform number of samples
    sigma : :obj:`float`, optional
        Multiplicative coefficient of L2 norm
    dims : :obj:`tuple`, optional
        Number of samples for each dimension
        (``None`` if only one dimension is available)
    dir : :obj:`int`, optional
        Direction along which smoothing is applied.

    Notes
    -----
    The L2Convolve proximal operator is defined as:

    .. math::

        prox_{\tau f}(\mathbf{x}) =
        F^{-1}\left(\frac{\tau\sigma F(\mathbf{h})^* F(\mathbf{b}) +
        F(\mathbf{x})}
        {1 + \tau\sigma F(\mathbf{h})^* F(\mathbf{h})} \right)

    """

    def __init__(
        self,
        h: NDArray,
        b: NDArray,
        nfft: int = 2**10,
        sigma: float = 1.0,
        dims: ShapeLike | None = None,
        dir: int | None = None,
    ) -> None:
        super().__init__(None, True)
        self.nfft = nfft
        self.sigma = sigma
        self.dims = dims

        # pick the axis the transform acts on: last axis by default,
        # otherwise the normalized user-provided axis
        if dims is None:
            self.dir = -1
        elif dir is None:
            self.dir = len(dims) - 1
        else:
            self.dir = get_normalize_axis_index()(dir, len(dims))

        # move data and kernel to the Fourier domain
        self.bf = np.fft.fft(b, self.nfft, axis=self.dir)
        kernel_axis = 0 if h.ndim == 1 else self.dir
        self.hf = np.fft.fft(h, self.nfft, axis=kernel_axis)

        if self.dims is not None:
            # shape of the transformed data: ``dims`` with nfft along dir
            self.dimsf = list(self.dims)
            self.dimsf[self.dir] = nfft
            self.bf = self.bf.reshape(self.dimsf)

            ndims = len(self.dims)
            if self.hf.ndim == 1:
                # pad a 1d kernel with singleton axes before and after
                # ``dir`` so that it lines up with the data
                hf = self.hf
                for axis, count in ((0, self.dir), (-1, ndims - self.dir - 1)):
                    for _ in range(count):
                        hf = np.expand_dims(hf, axis=axis)
                self.hf = hf

        # terms reused at every prox evaluation
        self.hbf = np.conj(self.hf) * self.bf
        self.h2f = np.abs(self.hf) ** 2

    def _reshape_and_fft(self, x: NDArray) -> tuple[NDArray, NDArray]:
        """Reshape ``x`` to ``dims`` (when set) and transform it along ``dir``.

        Returns the (possibly reshaped) input together with its transform.
        """
        if self.dims is not None:
            x = x.reshape(self.dims)
        return x, np.fft.fft(x, self.nfft, axis=self.dir)

    def __call__(self, x: NDArray) -> float:
        """Evaluate the functional at ``x`` and return it as a float."""
        _, xf = self._reshape_and_fft(x)
        residual = np.fft.ifft(self.bf - self.hf * xf, axis=self.dir)
        f = (self.sigma / 2.0) * np.linalg.norm(residual) ** 2
        return float(f)

    @_check_tau
    def prox(self, x: NDArray, tau: float) -> NDArray:
        """Compute the proximal operator at ``x`` for step ``tau``."""
        x, xf = self._reshape_and_fft(x)
        weight = self.sigma * tau
        yf = (xf + weight * self.hbf) / (1.0 + weight * self.h2f)
        y = np.fft.ifft(yf, axis=self.dir)
        # crop the nfft-long result back to the input length along dir
        if self.dims is not None:
            nkeep = self.dims[self.dir]
            return np.take(y, range(nkeep), axis=self.dir).ravel()
        return y[: len(x)]

    def grad(self, x: NDArray) -> NDArray:
        """Compute the gradient of the functional at ``x``."""
        x, xf = self._reshape_and_fft(x)
        misfit_f = self.hf * xf - self.bf
        g = self.sigma * np.fft.ifft(np.conj(self.hf) * misfit_f, axis=self.dir)
        # crop the nfft-long result back to the input length along dir
        nkeep = x.shape[self.dir]
        return np.take(g, range(nkeep), axis=self.dir).ravel()