# Source code for pyproximal.optimization.red

import time
from collections.abc import Callable
from typing import Any, Literal, cast

import numpy as np
from pylops import LinearOperator
from pylops.basicoperators import Identity
from pylops.optimization.basic import lsqr
from pylops.utils.backend import to_numpy
from pylops.utils.typing import NDArray

from pyproximal.optimization.primal import ADMM
from pyproximal.proximal import L2
from pyproximal.proximal import RED as pRED
from pyproximal.ProxOperator import ProxOperator


def _GradientDescent(
    f: ProxOperator,
    g: ProxOperator,
    x0: NDArray,
    alpha: float = 1.0,
    niter: int = 100,
    callback: Callable[[NDArray], None] | None = None,
    show: bool = False,
) -> NDArray:
    """Gradient descent solver for composite objective.

    Parameters
    ----------
    f: :obj:`pyproximal.ProxOperator`
        First proximal operator (must implement ``grad`` and ``call``)
    g: :obj:`pyproximal.ProxOperator`
        Second proximal operator (must implement ``grad`` and ``call``)
    x0 : :obj:`numpy.ndarray`
        Initial vector
    alpha : :obj:`float`, optional
        Step size
    niter : :obj:`int`, optional
        Number of iterations
    callback : :obj:`callable`, optional
        Function with signature (``callback(x)``) to call after each iteration
        where ``x`` is the current model vector
    show: :obj:`bool`, optional
        Display iterations log

    Returns
    -------
    x : :obj:`numpy.ndarray`
        Inverted model

    """
    if show:
        tstart = time.time()
        print(
            "Gradient descent algorithm\n"
            "---------------------------------------------------------\n"
            "Proximal operator (f): %s\n"
            "Proximal operator (g): %s\n"
            "alpha = %10e\tniter = %d\n" % (type(f), type(g), alpha, niter)
        )
        head = "   Itn       x[0]          f"
        print(head)

    x = x0.copy()
    for iiter in range(niter):
        grad = f.grad(x).real + g.grad(x)
        x -= alpha * grad
        if callback is not None:
            callback(x)
        if show:
            if iiter < 10 or niter - iiter < 10 or iiter % (niter // 10) == 0:
                msg = "%6g  %12.5e  %10.3e" % (
                    iiter + 1,
                    np.real(to_numpy(x[0])),
                    f(x),
                )
                print(msg)
    if show:
        print("\nTotal time (s) = %.2f" % (time.time() - tstart))
        print("---------------------------------------------------------\n")
    return x


def _FixedPoint(
    Op: LinearOperator,
    y: NDArray,
    denoiser: Callable[[NDArray, float], NDArray],
    x0: NDArray,
    sigmad: float | Callable[[int], Any],
    sigmaOp: float,
    sigma: float,
    niter: int = 100,
    niter_inner: int = 10,
    callback: Callable[[NDArray], None] | None = None,
    show: bool = False,
) -> NDArray:
    """Fixed-point solver for L2 data-misfit term and RED term.

    At every outer iteration the denoiser is applied to the current estimate
    (``xden``), and the linear system
    ``(sigma * I + sigmaOp * Op^H Op) x = sigmaOp * Op^H y + sigma * xden``
    is solved approximately with ``niter_inner`` iterations of LSQR,
    warm-started from the current estimate.

    Parameters
    ----------
    Op : :obj:`pylops.LinearOperator`
        Linear operator of L2 data-misfit term
    y : :obj:`numpy.ndarray`
        Data vector
    denoiser : :obj:`callable`
        Denoising function with signature ``denoiser(x, sigmad)``
    x0 : :obj:`numpy.ndarray`
        Initial vector (copied, never modified in place)
    sigmad : :obj:`float` or :obj:`func`
        Strength of the denoiser. If a function is provided, it is called
        with the current (0-based) outer iteration index and its output is
        passed to the denoiser
    sigmaOp : :obj:`float`
        Multiplicative coefficient of L2 data-misfit term
    sigma : :obj:`float`
        Multiplicative coefficient of RED term
    niter : :obj:`int`, optional
        Number of iterations
    niter_inner : :obj:`int`, optional
        Number of iterations for the inner solver
    callback : :obj:`callable`, optional
        Function with signature (``callback(x)``) to call after each iteration
        where ``x`` is the current model vector
    show : :obj:`bool`, optional
        Display iterations log (the ``f`` column reports the L2 data-misfit
        term only)

    Returns
    -------
    x : :obj:`numpy.ndarray`
        Inverted model

    """
    if show:
        tstart = time.time()
        print(
            "Fixed point algorithm\n"
            "---------------------------------------------------------\n"
            "Linear Operator: %s\n"
            "Denoiser: %s\n"
            "sigmad = %s\tsigmaOp = %10e\tsigma = %10e\n"
            "niter = %d\tniter_inner = %d\n"
            % (
                type(Op),
                type(denoiser),
                "multi" if callable(sigmad) else str(sigmad),
                sigmaOp,
                sigma,
                niter,
                niter_inner,
            )
        )
        head = "   Itn       x[0]          f"
        print(head)

    x = x0.copy()
    # Iteration-independent pieces of the normal equations, built once
    yy = sigmaOp * Op.H @ y
    Iop = Identity(Op.shape[1], dtype=Op.dtype)
    Op1 = sigma * Iop + sigmaOp * Op.H * Op
    for iiter in range(niter):
        sigmad_iter = sigmad(iiter) if callable(sigmad) else sigmad
        xden = denoiser(x, sigmad_iter)
        y1 = yy + sigma * xden
        # Inexact inner solve, warm-started from the previous estimate
        x = lsqr(Op1, y1, niter=niter_inner, x0=x)[0]
        if callback is not None:
            callback(x)
        # Logging is independent of the callback (it used to be nested in it)
        if show:
            if iiter < 10 or niter - iiter < 10 or iiter % (niter // 10) == 0:
                msg = "%6g  %12.5e  %10.3e" % (
                    iiter + 1,
                    np.real(to_numpy(x[0])),
                    sigmaOp / 2.0 * np.linalg.norm(Op @ x - y) ** 2,
                )
                print(msg)
    if show:
        print("\nTotal time (s) = %.2f" % (time.time() - tstart))
        print("---------------------------------------------------------\n")

    return x


def RED(
    proxf: ProxOperator,
    red: pRED,
    x0: NDArray,
    solver: (
        Callable[..., NDArray] | Literal["gradientdescent", "fixedpoint"]
    ) = ADMM,
    **kwargs_solver: Any,
) -> NDArray | tuple[NDArray, ...]:
    r"""Regularization by Denoising (RED) solver

    Solves the following minimization problem:

    .. math::

        \mathbf{x} = \argmin_{\mathbf{x}} f(\mathbf{x}) +
        \frac{\sigma}{2} \mathbf{x}^T (\mathbf{x} - g_{\sigma_d}(\mathbf{x}))

    using either any of the proximal solvers in
    :mod:`pyproximal.optimization.primal` or
    :mod:`pyproximal.optimization.primaldual`, or the gradient descent or
    fixed-point solvers. Here :math:`f(\mathbf{x})` is a function that has a
    known gradient or proximal operator, whilst
    :math:`g_{\sigma_d}(\mathbf{x})` is a denoiser acting on a noisy signal
    with noise strength equal to :math:`\sigma_d`.

    Parameters
    ----------
    proxf : :obj:`pyproximal.ProxOperator`
        Proximal operator of f function. In the case of
        ``solver="gradientdescent"``, it must implement the ``grad`` method;
        in the case of ``solver="fixedpoint"``, it must be of type
        :class:`pyproximal.proximal.L2`
    red : :obj:`pyproximal.proximal.RED`
        RED operator
    x0 : :obj:`numpy.ndarray`
        Initial vector
    solver : :func:`pyproximal.optimization.primal` or :func:`pyproximal.optimization.primaldual` or :obj:`str`
        Solver of choice or ``"gradientdescent"`` for gradient-descent solver
        or ``"fixedpoint"`` for fixed-point solver
    kwargs_solver : :obj:`dict`
        Additional parameters required by the selected solver. For
        ``solver="gradientdescent"``, the parameter ``alpha`` corresponds to
        the step size; for ``solver="fixedpoint"``, the parameter
        ``niter_inner`` corresponds to the number of iterations of the inner
        loop

    Returns
    -------
    out : :obj:`numpy.ndarray` or :obj:`tuple`
        Output of the solver of choice. For ``solver="gradientdescent"`` and
        ``solver="fixedpoint"``, this is the estimated model vector; for a
        solver function, it is whatever that function returns

    Raises
    ------
    ValueError
        If ``solver="gradientdescent"`` and ``proxf`` does not implement the
        ``grad`` method, if ``solver="fixedpoint"`` and ``proxf`` is not of
        :class:`pyproximal.proximal.L2` type, or if ``solver`` is a string
        other than ``"gradientdescent"`` or ``"fixedpoint"``.

    Notes
    -----
    The Regularization by Denoising (RED) term can be used with any proximal
    solvers, since its proximal operator can be easily evaluated via
    fixed-point iterations (see Notes of :class:`pyproximal.proximal.RED`
    for more details). However, [1]_ presented two additional solvers,
    namely:

    - gradient descent (for differentiable :math:`f`):

      .. math::

          \mathbf{x}^{k+1} = \mathbf{x}^k - \alpha (\nabla f(\mathbf{x}^k) +
          \sigma (\mathbf{x}^k - g_{\sigma_d}(\mathbf{x}^k)))

    - fixed point (for
      :math:`f = \frac{\sigma_{\mathbf{Op}}}{2} ||\mathbf{Op}\mathbf{x} - \mathbf{b}||_2^2`):

      .. math::

          \mathbf{y}^k = g_{\sigma_d}(\mathbf{x}^k)\\
          \mathbf{x}^{k+1} = (\sigma_{\mathbf{Op}} \mathbf{Op}^H\mathbf{Op} +
          \sigma \mathbf{I})^{-1} (\sigma_{\mathbf{Op}} \mathbf{Op}^H
          \mathbf{b} + \sigma \mathbf{y}^k)

    References
    ----------
    .. [1] Romano, Y., Elad, M., and Milanfar, P. "The Little Engine That
       Could: Regularization by Denoising (RED)", SIAM Journal on Imaging
       Sciences. 2017.

    """
    if not isinstance(solver, str):
        # A solver function is called directly with both terms
        return solver(proxf, red, x0=x0, **kwargs_solver)

    if solver == "gradientdescent":
        if hasattr(proxf, "grad") and callable(proxf.grad):
            return _GradientDescent(proxf, red, x0=x0, **kwargs_solver)
        msg = (
            f"The provided proximal operator ({proxf}) does not implement "
            "the grad method..."
        )
        raise ValueError(msg)

    if solver == "fixedpoint":
        if isinstance(proxf, L2):
            # The fixed-point solver needs the operator, data and weight
            # of the L2 term, plus the denoiser settings of the RED term
            return _FixedPoint(
                proxf.Op,
                proxf.b,
                red.denoiser,
                x0,
                red.sigmad,
                proxf.sigma,
                red.sigma,
                **kwargs_solver,
            )
        msg = (
            "The proximal operator must be of type L2, "
            f"({type(proxf).__name__}) provided instead..."
        )
        raise ValueError(msg)

    # Previously any unrecognized string silently ran the fixed-point solver
    msg = (
        f"Unknown solver {solver!r}: valid strings are "
        '"gradientdescent" and "fixedpoint"'
    )
    raise ValueError(msg)