# Source code for pyproximal.proximal.Geman
import numpy as np
from pyproximal.ProxOperator import _check_tau
from pyproximal import ProxOperator
class Geman(ProxOperator):
    r"""Geman penalty.

    The Geman penalty (named after its inventor) is a non-convex penalty [1]_.
    The pyproximal implementation considers a generalized model where

    .. math::

        \mathrm{Geman}_{\sigma,\gamma}(\mathbf{x}) = \sum_i \frac{\sigma |x_i|}{|x_i| + \gamma}

    where :math:`{\sigma\geq 0}`, :math:`{\gamma>0}`.

    Parameters
    ----------
    sigma : :obj:`float`
        Regularization parameter.
    gamma : :obj:`float`, optional
        Regularization parameter. Default is 1.3.

    Notes
    -----
    In order to compute the proximal operator of the Geman penalty one must find the
    roots of a cubic polynomial. Consider the one-dimensional problem

    .. math::

        \prox_{\tau \mathrm{Geman}(\cdot)}(x) = \argmin_{z} \mathrm{Geman}(z) + \frac{1}{2\tau}(x - z)^2

    and assume :math:`{x\geq 0}`. Either the minimum is obtained when :math:`z=0` or
    when

    .. math::

        \tau\sigma\gamma + (z-x)(z+\gamma)^2 = 0 ,

    which, expanded in monomial form, reads
    :math:`z^3 + (2\gamma - x) z^2 + (\gamma^2 - 2\gamma x) z + (\tau\sigma\gamma - \gamma^2 x) = 0`.
    The pyproximal implementation uses the closed-form (trigonometric) solution
    for a cubic equation, and discards infeasible roots, to find the minimum.

    .. [1] Geman and Yang "Nonlinear image recovery with half-quadratic regularization",
        IEEE Transactions on Image Processing, 4(7):932 – 946, 1995.

    """
    def __init__(self, sigma, gamma=1.3):
        super().__init__(None, False)
        if sigma < 0:
            raise ValueError('Variable "sigma" must be positive.')
        if gamma <= 0:
            raise ValueError('Variable "gamma" must be strictly positive.')
        self.sigma = sigma
        self.gamma = gamma

    def __call__(self, x):
        """Return the Geman penalty of ``x`` summed over all of its entries."""
        return np.sum(self.elementwise(x))

    def elementwise(self, x):
        """Return the Geman penalty ``sigma*|x_i| / (|x_i| + gamma)`` per entry."""
        return self.sigma * np.abs(x) / (np.abs(x) + self.gamma)

    @_check_tau
    def prox(self, x, tau):
        """Proximal operator of ``tau`` times the Geman penalty.

        Parameters
        ----------
        x : :obj:`numpy.ndarray`
            Input vector.
        tau : :obj:`float`
            Step size (validated by the ``_check_tau`` decorator).

        Returns
        -------
        out : :obj:`numpy.ndarray`
            Array with the shape of ``x``. Each entry is either ``0`` or
            ``sign(x_i) * z_i``, where ``z_i`` is the largest root of the cubic
            in the class Notes, whichever gives the lower objective.
            Integer/bool inputs yield a floating-point result.

        """
        absx = np.abs(x)
        # Allocate a floating-point output: for integer/bool inputs
        # ``np.zeros_like(x)`` would silently truncate the shrunk values.
        out = np.zeros_like(x, dtype=np.result_type(x, 1.0))

        # Coefficients of the monic cubic z**3 + b z**2 + c z + d = 0 obtained
        # by expanding tau*sigma*gamma + (z - |x|)(z + gamma)**2 = 0.
        b = 2 * self.gamma - absx
        c = self.gamma ** 2 - 2 * self.gamma * absx
        d = self.gamma * self.sigma * tau - self.gamma ** 2 * absx
        idx, loc_mins = self._find_local_minima(b, c, d)

        # Keep the stationary point only where its (tau-scaled) objective beats
        # the one of z = 0, i.e. |x|**2 / 2. A negative root r can never win:
        # for r < 0 <= |x| we have (r - |x|)**2 >= |x|**2 and the penalty is
        # non-negative, so infeasible roots are discarded by this comparison.
        absx_sel = absx[idx]
        global_min_idx = (
            tau * self.elementwise(loc_mins) + (loc_mins - absx_sel) ** 2 / 2
            < absx_sel ** 2 / 2
        )
        # Narrow the mask to the entries whose global minimizer is non-zero.
        idx[idx] = global_min_idx
        out[idx] = np.sign(x[idx]) * loc_mins[global_min_idx]
        return out

    @staticmethod
    def _find_local_minima(b, c, d):
        """Largest real root of ``z**3 + b z**2 + c z + d = 0`` where all roots are real.

        Uses the substitution ``z = t - b/3`` to obtain the depressed cubic
        ``t**3 + p t + q = 0`` and its trigonometric solution.

        Parameters
        ----------
        b, c, d : :obj:`numpy.ndarray`
            Coefficients of the monic cubic, all of the same shape.

        Returns
        -------
        idx : :obj:`numpy.ndarray`
            Boolean mask of the entries whose cubic has three real roots
            (counting multiplicity).
        loc_mins : :obj:`numpy.ndarray`
            Largest root for each entry selected by ``idx``.

        """
        # f = -p**3 / 27 = (-p/3)**3 and g = q for the depressed cubic.
        f = -(c - b ** 2.0 / 3.0) ** 3.0 / 27.0
        g = (2.0 * b ** 3.0 - 9.0 * b * c + 27.0 * d) / 27.0
        # q**2/4 + p**3/27 <= 0  <=>  three real roots.
        idx = g ** 2.0 / 4.0 - f <= 0
        sqrtf = np.sqrt(f[idx])  # (-p/3)**(3/2)
        # The argument lies in [-1, 1] mathematically; clip so that round-off
        # near the discriminant boundary cannot make arccos return NaN.
        k = np.arccos(np.clip(-(g[idx] / (2 * sqrtf)), -1.0, 1.0))
        # k/3 selects the largest of the three trigonometric roots.
        loc_mins = 2 * sqrtf ** (1 / 3.0) * np.cos(k / 3.0) - b[idx] / 3.0
        return idx, loc_mins