# Source code for pyproximal.proximal.Geman

import numpy as np

from pyproximal.ProxOperator import _check_tau
from pyproximal import ProxOperator

class Geman(ProxOperator):
    r"""Geman penalty.

    The Geman penalty (named after its inventor) is a non-convex penalty [1]_.
    The pyproximal implementation considers a generalized model where

    .. math::

        \mathrm{Geman}_{\sigma,\gamma}(\mathbf{x}) = \sum_i \frac{\sigma |x_i|}{|x_i| + \gamma}

    where :math:`\sigma \geq 0` and :math:`\gamma > 0`.

    Parameters
    ----------
    sigma : :obj:`float`
        Regularization parameter (must be non-negative).
    gamma : :obj:`float`, optional
        Regularization parameter (must be strictly positive). Default is 1.3.

    Notes
    -----
    In order to compute the proximal operator of the Geman penalty one must
    find the roots of a cubic polynomial. Consider the one-dimensional problem

    .. math::
        \prox_{\tau \mathrm{Geman}(\cdot)}(x) = \argmin_{z} \mathrm{Geman}(z) + \frac{1}{2\tau}(x - z)^2

    and assume :math:`x \geq 0`. Either the minimum is obtained when
    :math:`z = 0` or when

    .. math::
        \tau\sigma\gamma + (z-x)(z+\gamma)^2 = 0 .

    The pyproximal implementation uses the closed-form (trigonometric)
    solution of this cubic equation, discards infeasible roots, and compares
    the objective at the remaining candidate with its value at :math:`z = 0`
    to find the minimum. The result is finally given the sign of :math:`x`.

    .. [1] Geman and Yang "Nonlinear image recovery with half-quadratic
        regularization", IEEE Transactions on Image Processing,
        4(7):932–946, 1995.

    """

    def __init__(self, sigma, gamma=1.3):
        super().__init__(None, False)
        if sigma < 0:
            raise ValueError('Variable "sigma" must be non-negative.')
        if gamma <= 0:
            raise ValueError('Variable "gamma" must be strictly positive.')
        self.sigma = sigma
        self.gamma = gamma

    def __call__(self, x):
        """Return the Geman penalty summed over all elements of ``x``."""
        return np.sum(self.elementwise(x))

    def elementwise(self, x):
        """Return ``sigma * |x| / (|x| + gamma)`` evaluated element by element."""
        return self.sigma * np.abs(x) / (np.abs(x) + self.gamma)

    @_check_tau
    def prox(self, x, tau):
        """Proximal operator of ``tau * Geman`` evaluated at ``x``.

        Parameters
        ----------
        x : :obj:`numpy.ndarray`
            Input vector.
        tau : :obj:`float` or :obj:`numpy.ndarray`
            Positive scalar weight, or an array broadcastable to ``x.shape``
            holding one weight per element.

        Returns
        -------
        out : :obj:`numpy.ndarray`
            Proximal point with the same shape as ``x``. Non-floating
            (e.g. integer) inputs produce a ``float64`` result instead of
            being truncated.

        """
        absx = np.abs(x)
        # Coefficients of the monic cubic z^3 + b z^2 + c z + d obtained by
        # expanding tau*sigma*gamma + (z - |x|)(z + gamma)^2 = 0.
        b = 2 * self.gamma - absx
        c = self.gamma ** 2 - 2 * self.gamma * absx
        d = self.gamma * self.sigma * tau - self.gamma ** 2 * absx
        idx, loc_mins = self._find_local_minima(b, c, d)

        # Per-candidate step size: a scalar tau is used as is, an array tau
        # must be reduced to the same entries as loc_mins.
        if np.ndim(tau) > 0:
            tau_loc = np.broadcast_to(tau, x.shape)[idx]
        else:
            tau_loc = tau
        absx_loc = absx[idx]

        # tau * objective at the candidate vs. at z = 0 (which is |x|^2 / 2).
        # Candidates <= 0 never win because (z - |x|)^2 >= |x|^2 there.
        better = (tau_loc * self.elementwise(loc_mins)
                  + (loc_mins - absx_loc) ** 2 / 2) < absx_loc ** 2 / 2

        # Keep only the entries whose candidate strictly beats z = 0.
        mask = idx.copy()
        mask[idx] = better

        # Promote non-floating inputs so the result is not truncated.
        if np.issubdtype(x.dtype, np.inexact):
            out_dtype = x.dtype
        else:
            out_dtype = np.float64
        out = np.zeros_like(x, dtype=out_dtype)
        out[mask] = np.sign(x[mask]) * loc_mins[better]
        return out

    @staticmethod
    def _find_local_minima(b, c, d):
        """Largest real root of ``z^3 + b z^2 + c z + d`` where all roots are real.

        Uses the trigonometric solution of the depressed cubic
        ``t^3 + p t + q = 0`` with ``z = t - b / 3``.

        Returns
        -------
        idx : boolean array
            Mask of the entries treated as having three real roots.
        loc_mins : array
            Largest root for the entries selected by ``idx`` (1-D, length
            ``idx.sum()``).

        """
        p = c - b ** 2.0 / 3.0
        f = -(p ** 3.0) / 27.0  # (-p / 3)^3
        g = (2.0 * b ** 3.0 - 9.0 * b * c + 27.0 * d) / 27.0  # q
        # A zero discriminant (e.g. sigma = 0, where the cubic has a double
        # root) can round to a tiny positive value; a small relative tolerance
        # keeps such entries. prox only accepts a candidate if it strictly
        # lowers the objective below its value at z = 0, so admitting an
        # extra entry here cannot make the result worse than z = 0.
        tol = np.sqrt(np.finfo(f.dtype).eps)
        idx = g ** 2.0 / 4.0 - f <= tol * np.abs(f)
        sqrtf = np.sqrt(f[idx])
        # Clip so that rounding (or the tolerance above) never pushes the
        # argument outside the domain of arccos, which would yield NaN.
        cos_arg = np.clip(-(g[idx] / (2 * sqrtf)), -1.0, 1.0)
        k = np.arccos(cos_arg)
        loc_mins = 2 * sqrtf ** (1 / 3.0) * np.cos(k / 3.0) - b[idx] / 3.0
        return idx, loc_mins