Source code for pyproximal.proximal.Nonlinear
import numpy as np
from pyproximal.ProxOperator import _check_tau
from pyproximal import ProxOperator


class Nonlinear(ProxOperator):
    r"""Nonlinear function proximal operator.

    Proximal operator for a generic nonlinear function :math:`f`. This is a
    template class which a user must subclass, implementing the following
    methods:

    - ``fun``: a method evaluating the generic function :math:`f`
    - ``grad``: a method evaluating the gradient of the generic function
      :math:`f`
    - ``fungrad``: a method evaluating both the generic function :math:`f`
      and its gradient
    - ``optimize``: a method that solves the optimization problem associated
      with the proximal operator of :math:`f`. Note that the ``_funprox``,
      ``_gradprox``, and ``_fungradprox`` methods must be used (instead of
      ``fun``, ``grad``, and ``fungrad``), as these automatically add the
      regularization term involved in the evaluation of the proximal operator

    Parameters
    ----------
    x0 : :obj:`np.ndarray`
        Initial vector
    niter : :obj:`int`, optional
        Number of iterations of the iterative scheme used to compute the
        proximal operator
    warm : :obj:`bool`, optional
        Warm start (``True``) or not (``False``). When ``True``, the estimate
        from the previous call to the ``prox`` method is used as starting
        guess.

    Notes
    -----
    The proximal operator of a generic function requires solving the
    following optimization problem numerically

    .. math::

        \operatorname{prox}_{\tau f} (\mathbf{x}) =
        \arg \min_{\mathbf{y}} \; f(\mathbf{y}) +
        \frac{1}{2 \tau}\|\mathbf{y} - \mathbf{x}\|^2_2

    which is done via the user-provided ``optimize`` method (an illustrative
    sketch of a subclass is given at the end of this page).

    """
    def __init__(self, x0, niter=10, warm=True):
        super().__init__(None, True)
        self.niter = niter
        self.x0 = x0
        self.warm = warm

    def __call__(self, x):
        return self.fun(x)

    def _funprox(self, x, tau):
        # proximal objective: f(x) + 1 / (2 * tau) * ||x - y||_2^2
        return self.fun(x) + 1. / (2 * tau) * ((x - self.y) ** 2).sum()

    def _gradprox(self, x, tau):
        # gradient of the proximal objective: grad f(x) + (x - y) / tau
        return self.grad(x) + 1. / tau * (x - self.y)

    def _fungradprox(self, x, tau):
        # proximal objective and its gradient in a single call
        f, g = self.fungrad(x)
        f = f + 1. / (2 * tau) * ((x - self.y) ** 2).sum()
        g = g + 1. / tau * (x - self.y)
        return f, g
    def fun(self, x):
        raise NotImplementedError('The method fun has not been implemented. '
                                  'Refer to the documentation for details on '
                                  'how to subclass this operator.')

    def grad(self, x):
        raise NotImplementedError('The method grad has not been implemented. '
                                  'Refer to the documentation for details on '
                                  'how to subclass this operator.')

    def fungrad(self, x):
        raise NotImplementedError('The method fungrad has not been '
                                  'implemented. Refer to the documentation '
                                  'for details on how to subclass this '
                                  'operator.')

    def optimize(self):
        raise NotImplementedError('The method optimize has not been '
                                  'implemented. Refer to the documentation '
                                  'for details on how to subclass this '
                                  'operator.')
    @_check_tau
    def prox(self, x, tau):
        # store the proximal point and scaling so that optimize (through
        # _funprox, _gradprox, or _fungradprox) can access them
        self.y = x
        self.tau = tau
        x = self.optimize()
        if self.warm:
            # warm start: reuse the current solution as the initial guess
            # for the next call
            self.x0 = x
        return x
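

# The block below is not part of pyproximal: it is a minimal, illustrative
# sketch of how the Nonlinear template can be subclassed, assuming
# f(x) = 0.5 * ||Ax - b||_2^2 and a plain gradient-descent loop as the
# ``optimize`` routine. The ``setup`` helper, the step size ``lr``, and all
# numeric values are arbitrary choices made for this example only.
class QuadraticNonlinear(Nonlinear):
    def setup(self, A, b, lr=1e-2):
        # hypothetical helper to attach problem data (not part of the
        # Nonlinear interface)
        self.A, self.b, self.lr = A, b, lr

    def fun(self, x):
        res = self.A @ x - self.b
        return 0.5 * (res ** 2).sum()

    def grad(self, x):
        return self.A.T @ (self.A @ x - self.b)

    def fungrad(self, x):
        res = self.A @ x - self.b
        return 0.5 * (res ** 2).sum(), self.A.T @ res

    def optimize(self):
        # gradient descent on the proximal objective, starting from self.x0;
        # note the use of _gradprox (not grad) so that the
        # 1 / (2 * tau) * ||y - x||_2^2 term is added automatically
        x = self.x0.copy()
        for _ in range(self.niter):
            x = x - self.lr * self._gradprox(x, self.tau)
        return x


# Example usage (illustrative values). For this quadratic f, the result can
# be checked against the closed-form proximal point
# (I + tau * A.T A)^{-1} (x + tau * A.T b).
#
#   A = np.random.normal(0., 1., (20, 10))
#   b = A @ np.ones(10)
#   nl = QuadraticNonlinear(x0=np.zeros(10), niter=500, warm=True)
#   nl.setup(A, b, lr=1e-2)
#   xp = nl.prox(np.zeros(10), 1.)  # proximal point of f at the origin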