from collections.abc import Callable
from typing import TYPE_CHECKING, Any, Optional
import numpy as np
from pylops import Identity, MatrixMult
from pylops.optimization.basic import cg, lsqr
from pylops.optimization.leastsquares import regularized_inversion
from pylops.utils.backend import (
get_array_module,
get_module_name,
get_normalize_axis_index,
)
from pylops.utils.typing import NDArray, ShapeLike
from scipy.linalg import cho_factor, cho_solve
from scipy.sparse.linalg import lsqr as sp_lsqr
from typing_extensions import Self
from pyproximal.ProxOperator import ProxOperator, _check_tau
if TYPE_CHECKING:
from pylops.linearoperator import LinearOperator
[docs]
class L2(ProxOperator):
    r"""L2 Norm proximal operator.

    The Proximal operator of the :math:`\ell_2` norm is defined as:
    :math:`f(\mathbf{x}) = \frac{\sigma}{2} ||\mathbf{Op}\mathbf{x} - \mathbf{b}||_2^2`
    and :math:`f_\alpha(\mathbf{x}) = f(\mathbf{x}) + \alpha \mathbf{q}^T\mathbf{x}`.

    Parameters
    ----------
    Op : :obj:`pylops.LinearOperator`, optional
        Linear operator
    b : :obj:`numpy.ndarray`, optional
        Data vector
    q : :obj:`numpy.ndarray`, optional
        Dot vector
    sigma : :obj:`float`, optional
        Multiplicative coefficient of L2 norm
    alpha : :obj:`float`, optional
        Multiplicative coefficient of dot product
    qgrad : :obj:`bool`, optional
        Add q term to gradient (``True``) or not (``False``)
    niter : :obj:`int` or :obj:`func`, optional
        Number of iterations of iterative scheme used to compute the proximal.
        This can be a constant number or a function that is called passing a
        counter which keeps track of how many times the ``prox`` method has
        been invoked before and returns the ``niter`` to be used.
    x0 : :obj:`numpy.ndarray`, optional
        Initial vector
    warm : :obj:`bool`, optional
        Warm start (``True``) or not (``False``). Uses estimate from previous
        call of ``prox`` method.
    solver : :obj:`str`, optional
        .. versionadded:: 0.11.0

        Name of solver to use with non-explicit operators:

        - ``legacy``: enforces the legacy behaviour where
          :py:func:`scipy.sparse.linalg.lsqr` is used with numpy data and
          :py:func:`pylops.optimization.basic.lsqr` is used with cupy data
          (both are applied to the normal equations);
        - ``cg`` to use :py:func:`pylops.optimization.basic.cg` on the
          normal equations;
        - ``cgls`` to use :py:func:`pylops.optimization.basic.cgls` on the
          regularized system of equations.

    densesolver : :obj:`str`, optional
        Use ``numpy``, ``scipy``, or ``factorize`` when dealing with explicit
        operators. The former two rely on dense solvers from either library,
        whilst the last computes a (Cholesky) factorization of the matrix to
        invert and avoids recomputing it unless :math:`\tau` or
        :math:`\sigma` have changed. Choose ``densesolver=None`` when using
        PyLops versions earlier than v1.18.1 or v2.0.0
    kwargs_solver : :obj:`dict`, optional
        Dictionary containing extra arguments for the solver selected
        via the ``solver`` parameter.

    Raises
    ------
    ValueError
        If ``solver`` is not one of ``legacy``, ``cg``, ``cgls``.

    Notes
    -----
    The L2 proximal operator is defined as:

    .. math::

        prox_{\tau f_\alpha}(\mathbf{x}) =
        \left(\mathbf{I} + \tau \sigma \mathbf{Op}^T \mathbf{Op} \right)^{-1}
        \left( \mathbf{x} + \tau \sigma \mathbf{Op}^T \mathbf{b} -
        \tau \alpha \mathbf{q}\right)

    when both ``Op`` and ``b`` are provided. This formula shows that the
    proximal operator requires the solution of an inverse problem. If the
    operator ``Op`` is of kind ``explicit=True``, we can solve this problem
    directly. On the other hand if ``Op`` is of kind ``explicit=False``, an
    iterative solver is employed. In this case it is possible to provide a warm
    start via the ``x0`` input parameter.

    Note that alternatively the proximal operator can be computed by solving
    the following augmented system of equations (instead of its normal
    equations as shown above):

    .. math::

        \begin{bmatrix}
            \sqrt{\tau \sigma} \mathbf{Op} \\
            \mathbf{I}
        \end{bmatrix}
        prox_{\tau f_\alpha}(\mathbf{x}) =
        \begin{bmatrix}
            \sqrt{\tau \sigma} \mathbf{b} \\
            \mathbf{x} - \tau \alpha \mathbf{q}
        \end{bmatrix}

    The choice of the parameter ``solver`` determines which of the two
    approaches is used.

    Alternatively, when only ``b`` is provided, ``Op`` is assumed to be an
    Identity operator and the proximal operator reduces to:

    .. math::

        prox_{\tau f_\alpha}(\mathbf{x}) =
        \frac{\mathbf{x} + \tau \sigma \mathbf{b} - \tau \alpha \mathbf{q}}
        {1 + \tau \sigma}

    If ``b`` is not provided (``Op`` is ignored in this case), the proximal
    operator reduces to:

    .. math::

        prox_{\tau f_\alpha}(\mathbf{x}) =
        \frac{\mathbf{x} - \tau \alpha \mathbf{q}}{1 + \tau \sigma}

    In all cases ``prox`` returns a new array and leaves its input untouched.

    Finally, note that the second term in :math:`f_\alpha(\mathbf{x})` is added
    because this combined expression appears in several problems where Bregman
    iterations are used alongside a proximal solver.
    """

    def __init__(
        self,
        Op: Optional["LinearOperator"] = None,
        b: NDArray | None = None,
        q: NDArray | None = None,
        sigma: float = 1.0,
        alpha: float = 1.0,
        qgrad: bool = True,
        niter: int | Callable[[int], int] = 10,
        x0: NDArray | None = None,
        warm: bool = True,
        solver: str | None = "legacy",
        densesolver: str | None = None,
        kwargs_solver: dict[str, Any] | None = None,
    ) -> None:
        super().__init__(Op, True)
        self.b = b
        self.q = q
        self.sigma = sigma
        self.alpha = alpha
        self.qgrad = qgrad
        self.niter = niter
        self.x0 = x0
        self.warm = warm
        self.solver = solver
        self.densesolver = densesolver
        # number of times ``prox`` has been invoked (drives callable ``niter``)
        self.count = 0
        self.kwargs_solver = {} if kwargs_solver is None else kwargs_solver

        # define whether the normal equations or the regularized system
        # of equations are solved
        if self.solver in ("legacy", "cg"):
            self.normaleqs = True
        elif self.solver == "cgls":
            self.normaleqs = False
        else:
            msg = (
                f"Provided solver={self.solver}. "
                "Available options are: 'legacy', 'cg', 'cgls'."
            )
            raise ValueError(msg)

        # when using factorize, store the first tau*sigma=0 so that the
        # first time it will be recomputed (as tau cannot be 0)
        if self.densesolver == "factorize":
            self.tausigma = 0.0

        # precompute sigma * Op^H b (and Op^H Op for explicit operators);
        # only needed by the dense and normal-equation paths
        if (
            self.Op is not None
            and self.b is not None
            and (self.Op.explicit or self.normaleqs)
        ):
            self.OpTb = self.sigma * self.Op.H @ self.b
            # create A^H A upfront for explicit operators
            if self.Op.explicit:
                self.ATA = np.conj(self.Op.A.T) @ self.Op.A

    def __call__(self, x: NDArray) -> float:
        """Evaluate f_alpha(x); ``Op`` is only used when ``b`` is also set."""
        if self.Op is not None and self.b is not None:
            f = (self.sigma / 2.0) * (np.linalg.norm(self.Op @ x - self.b) ** 2)
        elif self.b is not None:
            f = (self.sigma / 2.0) * (np.linalg.norm(x - self.b) ** 2)
        else:
            f = (self.sigma / 2.0) * (np.linalg.norm(x) ** 2)
        if self.q is not None:
            f += self.alpha * np.dot(self.q, x)
        return float(f)

    def _increment_count(func: Callable[..., Any]) -> Callable[..., Any]:
        """Decorator incrementing ``self.count`` before each call of ``func``."""
        from functools import wraps

        @wraps(func)
        def wrapped(self: Self, *args: Any, **kwargs: Any) -> Any:
            self.count += 1
            return func(self, *args, **kwargs)

        return wrapped

    @_increment_count
    @_check_tau
    def prox(self, x: NDArray, tau: float) -> NDArray:
        # number of iterations for this call: a constant, or a schedule
        # evaluated on the number of prox invocations so far (``callable``
        # also accepts NumPy integer types, unlike an ``int`` isinstance check)
        niter = self.niter(self.count) if callable(self.niter) else self.niter

        if self.Op is not None and self.b is not None:
            if self.Op.explicit:
                rhs = self._subtract_q(x + tau * self.OpTb, tau)
                x = self._prox_explicit(rhs, tau)
            elif self.normaleqs:
                rhs = self._subtract_q(x + tau * self.OpTb, tau)
                x = self._prox_normaleqs(rhs, tau, niter)
            else:
                x = self._prox_regularized(x, tau, niter)
            if self.warm:
                self.x0 = x
            return x

        # closed-form solutions (Op is ignored when b is not provided);
        # computed out-of-place so the caller's array is never modified
        num = x if self.b is None else x + tau * self.sigma * self.b
        return self._subtract_q(num, tau) / (1.0 + tau * self.sigma)

    def _subtract_q(self, y: NDArray, tau: float) -> NDArray:
        """Return ``y - tau * alpha * q`` (``y`` itself when ``q`` is None).

        Never modifies ``y`` in place.
        """
        if self.q is None:
            return y
        return y - tau * self.alpha * self.q

    def _prox_explicit(self, y: NDArray, tau: float) -> NDArray:
        """Solve ``(I + tau*sigma*Op^H Op) x = y`` with a dense solver."""
        if self.densesolver == "factorize":
            # refactorize only when tau*sigma changed since the previous call
            if self.tausigma != tau * self.sigma:
                self.tausigma = tau * self.sigma
                ATA = np.eye(self.Op.shape[1]) + self.tausigma * self.ATA
                self.cl = cho_factor(ATA)
            return cho_solve(self.cl, y)
        Op1 = MatrixMult(np.eye(self.Op.shape[1]) + tau * self.sigma * self.ATA)
        if self.densesolver is None:
            # to allow backward compatibility with PyLops versions earlier
            # than v1.18.1 and v2.0.0
            return Op1.div(y)
        return Op1.div(y, densesolver=self.densesolver)

    def _prox_normaleqs(self, y: NDArray, tau: float, niter: int) -> NDArray:
        """Iteratively solve the normal equations (``legacy`` or ``cg``)."""
        Op1 = Identity(self.Op.shape[1], dtype=self.Op.dtype) + float(
            tau * self.sigma
        ) * (self.Op.H * self.Op)
        if self.solver == "legacy":
            if get_module_name(get_array_module(y)) == "numpy":
                return sp_lsqr(
                    Op1, y, iter_lim=niter, x0=self.x0, **self.kwargs_solver
                )[0]
            return lsqr(Op1, y, niter=niter, x0=self.x0, **self.kwargs_solver)[
                0
            ].ravel()
        # solver == "cg" (validated in __init__)
        return cg(Op1, y, niter=niter, x0=self.x0, **self.kwargs_solver)[0].ravel()

    def _prox_regularized(self, x: NDArray, tau: float, niter: int) -> NDArray:
        """Solve the augmented (regularized) system of equations (``cgls``)."""
        scale = np.sqrt(tau * self.sigma)
        return regularized_inversion(
            scale * self.Op,
            scale * self.b,
            [
                Identity(self.Op.shape[1], dtype=self.Op.dtype),
            ],
            x0=self.x0,
            dataregs=[
                self._subtract_q(x, tau),
            ],
            niter=niter,
            engine="pylops",
            **self.kwargs_solver,
        )[0].ravel()

    def grad(self, x: NDArray) -> NDArray:
        """Gradient of f (plus ``alpha * q`` when ``qgrad`` is True)."""
        if self.Op is not None and self.b is not None:
            g = self.sigma * self.Op.H @ (self.Op @ x - self.b)
        elif self.b is not None:
            g = self.sigma * (x - self.b)
        else:
            g = self.sigma * x
        if self.q is not None and self.qgrad:
            g += self.alpha * self.q
        return g
[docs]
class L2Convolve(ProxOperator):
    r"""L2 Norm proximal operator with convolution operator

    Proximal operator for the L2 norm defined as: :math:`f(\mathbf{x}) =
    \frac{\sigma}{2} ||\mathbf{h} * \mathbf{x} - \mathbf{b}||_2^2` where
    :math:`\mathbf{h}` is the kernel of a convolution operator and
    :math:`*` represents convolution

    Parameters
    ----------
    h : :obj:`numpy.ndarray`
        Kernel of convolution operator
    b : :obj:`numpy.ndarray`
        Data vector (either flattened or already shaped as ``dims``)
    nfft : :obj:`int`, optional
        Fourier transform number of samples
    sigma : :obj:`float`, optional
        Multiplicative coefficient of L2 norm
    dims : :obj:`tuple`, optional
        Number of samples for each dimension
        (``None`` if only one dimension is available)
    dir : :obj:`int`, optional
        Direction along which convolution is applied (defaults to the last
        axis of ``dims``).

    Notes
    -----
    The L2Convolve proximal operator is defined as:

    .. math::

        prox_{\tau f}(\mathbf{x}) =
        F^{-1}\left(\frac{\tau\sigma F(\mathbf{h})^* F(\mathbf{b}) + F(\mathbf{x})}
        {1 + \tau\sigma F(\mathbf{h})^* F(\mathbf{h})} \right)

    where :math:`F` is a Fourier transform with ``nfft`` samples. When
    ``h``, ``b`` and the input are all real, ``prox`` and ``grad`` return
    real arrays (the imaginary part left by the inverse FFT is round-off).
    """

    def __init__(
        self,
        h: NDArray,
        b: NDArray,
        nfft: int = 2**10,
        sigma: float = 1.0,
        dims: ShapeLike | None = None,
        dir: int | None = None,
    ) -> None:
        super().__init__(None, True)
        self.nfft = nfft
        self.sigma = sigma
        self.dims = dims
        if dims is None:
            self.dir = -1
        else:
            self.dir = (
                len(dims) - 1
                if dir is None
                else get_normalize_axis_index()(dir, len(dims))
            )
            # accept data flattened (as x is in prox/grad) or already shaped
            # as dims; the FFT below runs along ``dir`` and needs the N-d shape
            b = b.reshape(dims)

        # real kernel and data => prox/grad of a real input are real
        self._isreal = np.isrealobj(h) and np.isrealobj(b)

        # convert data and filter to Fourier domain
        self.bf = np.fft.fft(b, self.nfft, axis=self.dir)
        self.hf = np.fft.fft(h, self.nfft, axis=0 if h.ndim == 1 else self.dir)

        # expand dimensions of 1-d filters so they broadcast along ``dir``
        if self.dims is not None:
            self.dimsf = list(self.dims).copy()
            self.dimsf[self.dir] = nfft
            self.bf = self.bf.reshape(self.dimsf)
            ndims = len(self.dims)
            if self.hf.ndim == 1:
                for _ in range(self.dir):
                    self.hf = np.expand_dims(self.hf, axis=0)
                for _ in range(ndims - self.dir - 1):
                    self.hf = np.expand_dims(self.hf, axis=-1)

        # precompute terms for prox
        self.hbf = np.conj(self.hf) * self.bf
        self.h2f = np.abs(self.hf) ** 2

    def __call__(self, x: NDArray) -> float:
        if self.dims is not None:
            x = x.reshape(self.dims)
        xf = np.fft.fft(x, self.nfft, axis=self.dir)
        f = (self.sigma / 2.0) * np.linalg.norm(
            np.fft.ifft(self.bf - self.hf * xf, axis=self.dir)
        ) ** 2
        return float(f)

    @_check_tau
    def prox(self, x: NDArray, tau: float) -> NDArray:
        if self.dims is not None:
            x = x.reshape(self.dims)
        xf = np.fft.fft(x, self.nfft, axis=self.dir)
        yf = (xf + self.sigma * tau * self.hbf) / (1.0 + self.sigma * tau * self.h2f)
        y = np.fft.ifft(yf, axis=self.dir)
        if self._isreal and np.isrealobj(x):
            # drop the round-off imaginary part for a real problem
            y = y.real
        if self.dims is None:
            y = y[: len(x)]
        else:
            y = np.take(y, range(self.dims[self.dir]), axis=self.dir).ravel()
        return y

    def grad(self, x: NDArray) -> NDArray:
        if self.dims is not None:
            x = x.reshape(self.dims)
        xf = np.fft.fft(x, self.nfft, axis=self.dir)
        g = self.sigma * np.fft.ifft(
            np.conj(self.hf) * (self.hf * xf - self.bf), axis=self.dir
        )
        if self._isreal and np.isrealobj(x):
            # drop the round-off imaginary part for a real problem
            g = g.real
        g = np.take(g, range(x.shape[self.dir]), axis=self.dir).ravel()
        return g