# Source code for pyproximal.proximal.TV

from copy import deepcopy
from functools import wraps
from typing import Any, Callable, Union

import numpy as np
from pylops import FirstDerivative, Gradient
from pylops.utils.typing import NDArray, ShapeLike
from typing_extensions import Self

from pyproximal.ProxOperator import ProxOperator, _check_tau


class TV(ProxOperator):
    r"""TV Norm proximal operator.

    Proximal operator for the TV norm defined as: :math:`f(\mathbf{x}) =
    \sigma ||\mathbf{x}||_{\text{TV}}`.

    Parameters
    ----------
    dims : :obj:`tuple` or :obj:`int`
        Number of samples for each dimension of the model (a single integer
        is interpreted as the size of a one-dimensional model). Stored as a
        tuple.
    sigma : :obj:`float`, optional
        Multiplicative coefficient of TV norm
    niter : :obj:`int` or :obj:`func`, optional
        Number of iterations of iterative scheme used to compute the proximal.
        This can be a constant number or a function that is called passing a
        counter which keeps track of how many times the ``prox`` method has
        been invoked (including the current call) and returns the ``niter``
        to be used.
    rtol : :obj:`float`, optional
        Relative tolerance for stopping criterion.
    **kwargs : :obj:`dict`, optional
        Optional per-axis weights ``wx``, ``wy``, ``wz``, ``wt`` for the
        first four axes (default ``1.0``; any further axis uses ``1.0``).
        In the current implementation they only enter the step size of the
        dual iterations (through their maximum) and the zero initialization
        of the dual variables; ``__call__`` ignores them.

    Notes
    -----
    For a model with :math:`d` dimensions the norm evaluated by ``__call__``
    is :math:`\sum_i \sqrt{\sum_{k=1}^d (D_k \mathbf{x})_i^2}`, with
    :math:`D_k` the forward first derivative along axis :math:`k`
    (for :math:`d=1` this reduces to :math:`\sum_i |(D_1\mathbf{x})_i|`).

    The proximal algorithm is implemented following [1]_: accelerated
    projected-gradient iterations on the dual problem.

    .. [1] Beck, A. and Teboulle, M., "Fast gradient-based algorithms for
       constrained total variation image denoising and deblurring problems",
       2009.

    """

    def __init__(
        self,
        dims: ShapeLike,
        sigma: float = 1.0,
        niter: Union[int, Callable[[int], int]] = 10,
        rtol: float = 1e-4,
        **kwargs: Any,
    ) -> None:
        super().__init__(None, True)
        # Normalise to a tuple: a bare integer means a 1D model, and tuple
        # shape arithmetic such as ``(ndim,) + dims`` then works for lists too.
        self.dims = (dims,) if np.ndim(dims) == 0 else tuple(dims)
        self.ndim = len(self.dims)
        self.sigma = sigma
        self.niter = niter
        self.count = 0  # number of prox calls, bumped by _increment_count
        self.rtol = rtol
        self.kwargs = kwargs

    def _forward_gradient(
        self, shape: tuple, dtype: Any
    ) -> Callable[[NDArray], list]:
        """Build a function returning the per-axis forward differences.

        The returned callable maps an array of shape ``shape`` to a list with
        one derivative array (same shape) per axis, computed with
        ``kind="forward"`` and ``edge=False``.
        """
        if self.ndim == 1:
            deriv_op = FirstDerivative(
                dims=self.dims[0], axis=0, edge=False, dtype=dtype, kind="forward"
            )
            return lambda v: [deriv_op @ v]
        grad_op = Gradient(shape, edge=False, dtype=dtype, kind="forward")
        stacked_shape = (self.ndim,) + tuple(shape)
        return lambda v: list(grad_op.matvec(v.ravel()).reshape(stacked_shape))

    def __call__(self, x: NDArray) -> float:
        """Evaluate ``sigma * TV(x)`` (see class Notes for the definition)."""
        x = x.reshape(self.dims)
        grads = self._forward_gradient(x.shape, x.dtype)(x)
        if self.ndim == 1:
            y = np.abs(grads[0])
        else:
            # pointwise Euclidean norm of the gradient across axes
            y = np.sqrt(sum(np.power(np.abs(g), 2) for g in grads))
        return float(self.sigma * np.sum(y))

    def _increment_count(func: Callable[..., Any]) -> Callable[..., Any]:
        """Decorator incrementing ``self.count`` before each call of ``func``."""

        @wraps(func)
        def wrapped(self: Self, *args: Any, **kwargs: Any) -> Any:
            self.count += 1
            return func(self, *args, **kwargs)

        return wrapped

    @_increment_count
    @_check_tau
    def prox(self, x: NDArray, tau: float) -> NDArray:
        """Proximal operator of ``tau * f`` evaluated at ``x``.

        Parameters
        ----------
        x : :obj:`numpy.ndarray`
            Input vector (reshaped internally to ``dims``)
        tau : :obj:`float`
            Positive scalar weight

        Returns
        -------
        sol : :obj:`numpy.ndarray`
            Flattened result of the proximal operator

        Raises
        ------
        ValueError
            If ``dims`` is empty.

        """
        if self.ndim == 0:
            raise ValueError("Need to input at least one value")

        # current number of iterations: constant, or scheduled from self.count
        niter = self.niter(self.count) if callable(self.niter) else self.niter
        gamma = self.sigma * tau
        rtol = self.rtol
        # TODO implement test_gamma

        # Initialization
        x = x.reshape(self.dims)
        sol = x
        grad = self._forward_gradient(x.shape, x.dtype)

        # Dual variables, one per axis, initialised to zero. The previous
        # projections start out aliased to them; this is harmless because the
        # first momentum coefficient (told - 1) / t is zero.
        duals = grad(x * 0)
        prev_proj = list(duals)
        told, prev_obj = 1.0, 0.0

        # Per-axis weights; their maximum scales the dual step size.
        weights = self._axis_weights()
        mt = weights[0]
        for w in weights[1:]:
            mt = np.maximum(mt, w)
        # NOTE(review): the weights are applied only to these zero-initialised
        # copies; inside the loop the copies are refreshed unweighted, so the
        # weights do not actually weight the TV term.
        duals_w = []
        for v, w in zip(duals, weights):
            vw = deepcopy(v)
            vw *= np.conjugate(w)
            duals_w.append(vw)

        # 4 * ndim reproduces the original per-dimension constants
        # (4, 8, 12, 16) and bounds the squared norm of the gradient operator.
        step = 1.0 / (4 * self.ndim * gamma * mt**2)

        it = 0
        while it <= niter:
            # Current solution
            sol = x - gamma * self._divergence(duals_w)

            # Objective of the prox problem: 0.5||x - sol||^2 + tau*sigma*TV.
            # self(sol) already includes sigma, hence tau (not gamma) here.
            obj = 0.5 * np.linalg.norm(x - sol) ** 2 + tau * self(sol)
            if obj > 1e-10:
                rel_obj = np.abs(obj - prev_obj) / obj
            else:
                rel_obj = 2 * rtol
            prev_obj = obj

            # Stopping criterion
            if rel_obj < rtol:
                break

            # Gradient step on the dual variables (in place)
            for v, g in zip(duals, grad(sol)):
                v -= step * g

            # Projection denominator: max(1, pointwise norm of the duals)
            if self.ndim == 1:
                denom = np.maximum(1, np.abs(duals[0]))
            else:
                denom = np.maximum(
                    1, np.sqrt(sum(np.power(np.abs(v), 2) for v in duals))
                )

            # FISTA-type update.
            # NOTE(review): sqrt(4 * told**2) gives t = told + 0.5, not the
            # classical (1 + sqrt(1 + 4 * told**2)) / 2; kept unchanged.
            t = (1 + np.sqrt(4 * told**2)) / 2.0
            momentum = (told - 1) / t
            projections = [v / denom for v in duals]
            duals = [p + momentum * (p - q) for p, q in zip(projections, prev_proj)]
            prev_proj = projections
            duals_w = [deepcopy(v) for v in duals]
            told = t
            it += 1

        return sol.ravel()

    @staticmethod
    def _divergence(duals: list) -> NDArray:
        """Negative adjoint of the forward-difference gradient.

        Along each axis the component ``v = duals[axis]`` contributes
        ``v[0]``, ``v[i] - v[i-1]`` for the interior samples and ``-v[-2]``
        for the last sample; contributions of all axes are summed.
        """

        def along(v: NDArray, axis: int, start: Union[int, None], stop: int) -> NDArray:
            # slice ``start:stop`` along ``axis``, keeping every other axis
            index = [slice(None)] * v.ndim
            index[axis] = slice(start, stop)
            return v[tuple(index)]

        div = None
        for axis, v in enumerate(duals):
            term = np.concatenate(
                (
                    along(v, axis, 0, 1),
                    along(v, axis, 1, -1) - along(v, axis, None, -2),
                    -along(v, axis, -2, -1),
                ),
                axis=axis,
            )
            if div is None:
                div = term
            else:
                div += term
        return div

    def _axis_weights(self) -> list:
        """Return one weight per axis, read from ``self.kwargs``.

        Keys ``wx``, ``wy``, ``wz``, ``wt`` map to the first four axes; a
        missing key (or non-subscriptable ``kwargs``) yields ``1.0``, and
        axes beyond the fourth always use ``1.0``.
        """
        keys = ("wx", "wy", "wz", "wt")
        weights = []
        for axis in range(self.ndim):
            weight = 1.0
            if axis < len(keys):
                try:
                    weight = self.kwargs[keys[axis]]
                except (KeyError, TypeError):
                    pass
            weights.append(weight)
        return weights