Source code for pyproximal.proximal.TV

from collections.abc import Callable
from copy import deepcopy
from functools import wraps
from typing import Any

import numpy as np
from pylops import FirstDerivative, Gradient
from pylops.utils.typing import NDArray, ShapeLike
from typing_extensions import Self

from pyproximal.ProxOperator import ProxOperator, _check_tau


class TV(ProxOperator):
    r"""TV Norm proximal operator.

    Proximal operator for the TV norm defined as:
    :math:`f(\mathbf{x}) = \sigma ||\mathbf{x}||_{\text{TV}}`.

    Parameters
    ----------
    dims : :obj:`tuple`, optional
        Number of samples for each dimension (``None`` if only one
        dimension is available, in which case the number of samples is
        taken from the size of the vector passed to ``__call__``/``prox``)
    sigma : :obj:`float`, optional
        Multiplicative coefficient of TV norm
    niter : :obj:`int` or :obj:`func`, optional
        Number of iterations of iterative scheme used to compute the
        proximal. This can be a constant number or a function that is
        called passing a counter which keeps track of how many times the
        ``prox`` method has been invoked (the counter is incremented before
        the function is called, so it equals 1 on the first invocation)
        and returns the ``niter`` to be used.
    rtol : :obj:`float`, optional
        Relative tolerance for stopping criterion.
    **kwargs : :obj:`dict`, optional
        Optional weights ``wx``, ``wy``, ``wz`` and ``wt`` associated with
        the first, second, third and fourth axis. A missing weight defaults
        to 1. The element-wise maximum of the weights scales the step of
        the dual update in ``prox``.

    Notes
    -----
    The proximal algorithm is implemented following [1]_.

    ``prox`` handles any number of dimensions; named weights are only
    available for the first four axes.

    .. [1] Beck, A. and Teboulle, M., "Fast gradient-based algorithms for
       constrained total variation image denoising and deblurring problems",
       2009.

    """

    # Keyword names of the optional weights of the first four axes.
    _WEIGHT_KEYS = ("wx", "wy", "wz", "wt")

    def __init__(
        self,
        dims: ShapeLike | None,
        sigma: float = 1.0,
        niter: int | Callable[[int], int] = 10,
        rtol: float = 1e-4,
        **kwargs: Any,
    ) -> None:
        super().__init__(None, True)
        # ``(0,)`` is a placeholder meaning "1-D, length taken from the input"
        self.dims = (0,) if dims is None else dims
        self.ndim = 1 if dims is None else len(dims)
        self.sigma = sigma
        self.niter = niter
        self.count = 0
        self.rtol = rtol
        self.kwargs = kwargs

    def _tv_dims(self, x: NDArray) -> tuple[int, ...]:
        """Return the shape ``x`` must be reshaped to.

        The placeholder ``(0,)`` stored when ``dims=None`` is replaced by
        ``(x.size,)``; any other ``dims`` is returned as a tuple so that it
        can be concatenated with other tuples even when a list was given.
        """
        dims = tuple(self.dims)
        if dims == (0,):
            dims = (x.size,)
        return dims

    def _tv_weight(self, key: str) -> Any:
        """Return the weight stored under ``key`` in ``kwargs`` (1.0 if absent)."""
        try:
            return self.kwargs[key]
        except (KeyError, TypeError):
            return 1.0

    @staticmethod
    def _tv_divergence(field: NDArray) -> NDArray:
        """Sum over the axes of the backward-difference terms of ``field``.

        ``field`` stacks one component per axis along its first dimension.
        For the component of axis ``a`` the term is, along that axis, the
        first sample, then the differences ``c[1:-1] - c[:-2]``, then minus
        the second-to-last sample.
        """
        div = None
        for axis, comp in enumerate(field):
            lead = (slice(None),) * axis
            term = np.concatenate(
                (
                    np.expand_dims(comp[lead + (0,)], axis=axis),
                    comp[lead + (slice(1, -1),)] - comp[lead + (slice(None, -2),)],
                    -np.expand_dims(comp[lead + (-2,)], axis=axis),
                ),
                axis=axis,
            )
            if div is None:
                div = term
            else:
                div += term
        return div

    def __call__(self, x: NDArray) -> float:
        """Evaluate ``sigma`` times the TV norm of ``x``.

        In 1-D this is the sum of the absolute forward differences; in
        higher dimensions it is the sum over all samples of the Euclidean
        norm of the forward-difference gradient.

        Parameters
        ----------
        x : :obj:`numpy.ndarray`
            Input vector (reshaped internally to ``dims``)

        Returns
        -------
        :obj:`float`
            Value of the functional

        """
        dims = self._tv_dims(x)
        x = x.reshape(dims)
        if self.ndim == 1:
            derivOp = FirstDerivative(
                dims=dims[0], axis=0, edge=False, dtype=x.dtype, kind="forward"
            )
            dx = derivOp @ x
            y = np.sum(np.abs(dx), axis=0)
        elif self.ndim >= 2:
            y = 0
            gradOp = Gradient(dims, edge=False, dtype=x.dtype, kind="forward")
            grads = gradOp.matvec(x.ravel())
            grads = grads.reshape((self.ndim,) + dims)
            for g in grads:
                y += np.power(abs(g), 2)
            y = np.sqrt(y)
        return float(self.sigma * np.sum(y))

    def _increment_count(func: Callable[..., Any]) -> Callable[..., Any]:
        """Increment counter"""

        @wraps(func)
        def wrapped(self: Self, *args: Any, **kwargs: Any) -> Any:
            self.count += 1
            return func(self, *args, **kwargs)

        return wrapped

    @_increment_count
    @_check_tau
    def prox(self, x: NDArray, tau: float) -> NDArray:
        """Compute the proximal operator of the TV norm.

        An iterative scheme on a dual field (one component per axis) is run
        while the iteration counter is ``<= niter`` or until the relative
        change of the objective falls below ``rtol``.

        Parameters
        ----------
        x : :obj:`numpy.ndarray`
            Input vector (reshaped internally to ``dims``)
        tau : :obj:`float`
            Positive scalar weight

        Returns
        -------
        :obj:`numpy.ndarray`
            Flattened solution

        """
        # define current number of iterations
        niter = self.niter(self.count) if callable(self.niter) else self.niter

        gamma = self.sigma * tau
        rtol = self.rtol
        # TODO implement test_gamma

        # Initialization
        dims = self._tv_dims(x)
        x = x.reshape(dims)
        sol = x
        shape = x.shape
        ndim = self.ndim

        # Forward-difference gradient returning an array of shape
        # (ndim, *shape), whatever the number of dimensions
        if ndim == 1:
            derivOp = FirstDerivative(
                dims=dims[0], axis=0, edge=False, dtype=x.dtype, kind="forward"
            )

            def gradient(arr: NDArray) -> NDArray:
                return (derivOp @ arr)[np.newaxis]

        else:
            gradOp = Gradient(shape, edge=False, dtype=x.dtype, kind="forward")

            def gradient(arr: NDArray) -> NDArray:
                return gradOp.matvec(arr.ravel()).reshape((ndim,) + shape)

        # Dual field and the copy fed to the divergence
        dual = gradient(x * 0)
        dual_prev = dual
        weighted = deepcopy(dual)
        told, prev_obj = 1.0, 0.0

        # Initialization for weights (axes beyond the fourth have weight 1)
        axis_weights = [self._tv_weight(key) for key in self._WEIGHT_KEYS[:ndim]]
        axis_weights += [1.0] * (ndim - len(axis_weights))
        mt = axis_weights[0]
        for weight in axis_weights[1:]:
            mt = np.maximum(mt, weight)

        # NOTE(review): the weights are applied only to this initial copy of
        # the dual field (presumably all zeros, being the gradient of
        # ``x * 0``); later iterations use the unweighted field - confirm
        # this is intended.
        for comp, weight in zip(weighted, axis_weights):
            comp *= np.conjugate(weight)

        # Step of the dual update: 1 / (4 * ndim * gamma * mt^2), i.e. the
        # constants 4, 8, 12 and 16 for 1 to 4 dimensions
        step = 1.0 / (4.0 * ndim * gamma * mt**2)

        it = 0
        while it <= niter:
            # Current Solution
            div = self._tv_divergence(weighted)
            sol = x - gamma * div

            # Objective function value
            obj = 0.5 * np.power(np.linalg.norm(x[:] - sol[:]), 2) + gamma * np.sum(
                self.__call__(sol), axis=0
            )
            if obj > 1e-10:
                rel_obj = np.abs(obj - prev_obj) / obj
            else:
                rel_obj = 2 * rtol
            prev_obj = obj

            # Stopping criterion
            if rel_obj < rtol:
                break

            # Update divergence vectors and project
            dual -= step * gradient(sol)
            if ndim == 1:
                norms = np.abs(dual)
            else:
                squares = np.power(np.abs(dual[0]), 2)
                for comp in dual[1:]:
                    squares = squares + np.power(np.abs(comp), 2)
                norms = np.sqrt(squares)
            weights = np.maximum(1, norms)

            # FISTA update
            # NOTE(review): the classical FISTA rule has an additional
            # ``1 +`` under the square root - confirm this variant is
            # intended.
            t = (1 + np.sqrt(4 * told**2)) / 2.0
            projected = dual / weights
            dual = projected + (told - 1) / t * (projected - dual_prev)
            dual_prev = projected
            # ``dual`` is only modified in place after the divergence of the
            # next iteration has been computed, so no copy is needed here
            weighted = dual
            told = t
            it += 1

        return sol.ravel()