Source code for pyproximal.proximal.Geman

from typing import Tuple

import numpy as np
from pylops.utils.typing import NDArray

from pyproximal.ProxOperator import ProxOperator, _check_tau


class Geman(ProxOperator):
    r"""Geman penalty.

    The Geman penalty (named after its inventor) is a non-convex penalty [1]_.
    The pyproximal implementation considers a generalized model where

    .. math::

        \mathrm{Geman}_{\sigma,\gamma}(\mathbf{x}) =
        \sum_i \frac{\sigma |x_i|}{|x_i| + \gamma}

    where :math:`{\sigma\geq 0}`, :math:`{\gamma>0}`.

    Parameters
    ----------
    sigma : :obj:`float`
        Regularization parameter.
    gamma : :obj:`float`, optional
        Regularization parameter. Default is 1.3.

    Raises
    ------
    ValueError
        If ``sigma`` is negative or if ``gamma`` is not strictly positive.

    Notes
    -----
    In order to compute the proximal operator of the Geman penalty one must
    find the roots of a cubic polynomial. Consider the one-dimensional problem

    .. math::

        \prox_{\tau \mathrm{Geman}(\cdot)}(x) = \argmin_{z}
        \mathrm{Geman}(z) + \frac{1}{2\tau}(x - z)^2

    and assume :math:`{x\geq 0}`. Either the minimum is obtained when
    :math:`z=0` or when

    .. math::

        \tau\sigma\gamma + (z-x)(z+\gamma)^2 = 0 .

    The pyproximal implementation uses the closed-form solution for a cubic
    equation, and discards infeasible roots, to find the minimum.

    .. [1] Geman and Yang "Nonlinear image recovery with half-quadratic
        regularization", IEEE Transactions on Image Processing,
        4(7):932 – 946, 1995.

    """

    def __init__(self, sigma: float, gamma: float = 1.3) -> None:
        super().__init__(None, False)
        if sigma < 0:
            raise ValueError('Variable "sigma" must be positive.')
        if gamma <= 0:
            raise ValueError('Variable "gamma" must be strictly positive.')
        self.sigma = sigma
        self.gamma = gamma

    def __call__(self, x: NDArray) -> float:
        """Evaluate the penalty: sum of the element-wise values over ``x``."""
        return float(np.sum(self.elementwise(x)))

    def elementwise(self, x: NDArray) -> NDArray:
        """Return ``sigma * |x_i| / (|x_i| + gamma)`` for every entry of ``x``."""
        return self.sigma * np.abs(x) / (np.abs(x) + self.gamma)

    # NOTE(review): ``_check_tau`` is imported from pyproximal.ProxOperator;
    # presumably it validates ``tau`` before the call - confirm there.
    @_check_tau
    def prox(self, x: NDArray, tau: float) -> NDArray:
        """Compute the proximal operator of ``tau * Geman`` element-wise.

        Parameters
        ----------
        x : :obj:`numpy.ndarray`
            Input array.
        tau : :obj:`float`
            Scaling of the penalty in the proximal problem.

        Returns
        -------
        out : :obj:`numpy.ndarray`
            Array of the same shape as ``x``. Entries for which no candidate
            beats ``z = 0`` are zero. Integer and boolean inputs are promoted
            to floating point so that results are not truncated.

        """
        # ``out`` inherits the dtype of ``x``; with an integer/boolean input
        # the floating-point roots would be silently truncated on assignment.
        if x.dtype.kind in "biu":
            x = x.astype(float)
        abs_x = np.abs(x)
        out = np.zeros_like(x)

        # Coefficients of the monic cubic z^3 + b z^2 + c z + d obtained by
        # expanding tau*sigma*gamma + (z - |x|)(z + gamma)^2 = 0.
        b = 2 * self.gamma - abs_x
        c = self.gamma**2 - 2 * self.gamma * abs_x
        d = self.gamma * self.sigma * tau - self.gamma**2 * abs_x

        idx, loc_mins = self._find_local_minima(b, c, d)

        # A candidate is kept only if its objective value is strictly lower
        # than the objective at z = 0 (which equals |x|^2 / 2).
        global_min_idx = (
            tau * self.elementwise(loc_mins)
            + (loc_mins - abs_x[idx]) ** 2 / 2
            < abs_x[idx] ** 2 / 2
        )

        # Narrow the mask in place: among the entries that had a candidate,
        # keep only those where the candidate wins over zero.
        idx[idx] = global_min_idx
        # The problem was solved for |x|; restore the sign of the input.
        out[idx] = np.sign(x[idx]) * loc_mins[global_min_idx]
        return out

    @staticmethod
    def _find_local_minima(
        b: NDArray, c: NDArray, d: NDArray
    ) -> Tuple[NDArray, NDArray]:
        """Find a candidate root of ``z^3 + b z^2 + c z + d = 0`` per entry.

        Uses the trigonometric closed form for cubics with three real roots.

        Parameters
        ----------
        b, c, d : :obj:`numpy.ndarray`
            Coefficients of the monic cubic, one polynomial per entry.

        Returns
        -------
        idx : :obj:`numpy.ndarray`
            Boolean mask of the entries for which ``g**2 / 4 - f <= 0``
            (the three-real-roots case); other entries get no candidate.
        loc_mins : :obj:`numpy.ndarray`
            One root for every entry selected by ``idx`` (the largest real
            root, as given by the principal branch of the formula).

        """
        # With p = c - b^2/3 and q = g the depressed cubic is t^3 + p t + q,
        # z = t - b/3; here f = -p^3 / 27.
        f = -((c - b**2.0 / 3.0) ** 3.0) / 27.0
        g = (2.0 * b**3.0 - 9.0 * b * c + 27.0 * d) / 27.0
        idx = g**2.0 / 4.0 - f <= 0

        sqrtf = np.sqrt(f[idx])
        # Mathematically the ratio lies in [-1, 1] on the selected entries;
        # clip so that floating-point rounding cannot push it just outside
        # the domain of arccos and produce NaN.
        cos_arg = np.clip(-(g[idx] / (2 * sqrtf)), -1.0, 1.0)
        k = np.arccos(cos_arg)
        loc_mins = 2 * sqrtf ** (1 / 3.0) * np.cos(k / 3.0) - b[idx] / 3.0
        return idx, loc_mins