# Source code for pyproximal.optimization.palm

import time

[docs]def PALM(H, proxf, proxg, x0, y0, gammaf=1., gammag=1.,
niter=10, callback=None, show=False):
r"""Proximal Alternating Linearized Minimization

Solves the following minimization problem using the Proximal Alternating
Linearized Minimization (PALM) algorithm:

.. math::

\mathbf{x}\mathbf{,y} = \argmin_{\mathbf{x}, \mathbf{y}}
f(\mathbf{x}) + g(\mathbf{y}) + H(\mathbf{x}, \mathbf{y})

where :math:f(\mathbf{x}) and :math:g(\mathbf{y}) are any pair of
convex functions that have known proximal operators, and
:math:H(\mathbf{x}, \mathbf{y}) is a smooth function.

Parameters
----------
H : :obj:pyproximal.utils.bilinear.Bilinear
Bilinear function
proxf : :obj:pyproximal.ProxOperator
Proximal operator of f function
proxg : :obj:pyproximal.ProxOperator
Proximal operator of g function
x0 : :obj:numpy.ndarray
Initial x vector
y0 : :obj:numpy.ndarray
Initial y vector
gammaf : :obj:float, optional
Positive scalar weight for f function update
gammag : :obj:float, optional
Positive scalar weight for g function update
niter : :obj:int, optional
Number of iterations of iterative scheme
callback : :obj:callable, optional
Function with signature (callback(x)) to call after each iteration
where x and y are the current model vectors
show : :obj:bool, optional
Display iterations log

Returns
-------
x : :obj:numpy.ndarray
Inverted x vector
y : :obj:numpy.ndarray
Inverted y vector

Notes
-----
PALM [1]_ can be expressed by the following recursion:

.. math::

\mathbf{x}^{k+1} = \prox_{c_k f}(\mathbf{x}^{k} -
\frac{1}{c_k}\nabla_x H(\mathbf{x}^{k}, \mathbf{y}^{k}))\\
\mathbf{y}^{k+1} = \prox_{d_k g}(\mathbf{y}^{k} -
\frac{1}{d_k}\nabla_y H(\mathbf{x}^{k+1}, \mathbf{y}^{k}))\\

Here :math:c_k=\gamma_f L_x and :math:d_k=\gamma_g L_y, where
:math:L_x and :math:L_y are the Lipschitz constant of :math:\nabla_x H
and :math:\nabla_y H, respectively.

.. [1] Bolte, J., Sabach, S., and Teboulle, M. "Proximal alternating
linearized minimization for nonconvex and nonsmooth problems",
Mathematical Programming, vol. 146, pp. 459–494. 2014.

"""
if show:
tstart = time.time()
print('PALM algorithm\n'
'---------------------------------------------------------\n'
'Bilinear operator: %s\n'
'Proximal operator (f): %s\n'
'Proximal operator (g): %s\n'
'gammaf = %10e\tgammaf = %10e\tniter = %d\n' %
(type(H), type(proxf), type(proxg), gammaf, gammag, niter))
head = '   Itn      x[0]       y[0]        f         g         H         ck         dk'

x, y = x0.copy(), y0.copy()
for iiter in range(niter):
ck = gammaf * H.ly(y)
x = x - (1 / ck) * H.gradx(x.ravel())
if proxf is not None:
x = proxf.prox(x, ck)
H.updatex(x.copy())
dk = gammag * H.lx(x)
y = y - (1 / dk) * H.grady(y.ravel())
if proxg is not None:
y = proxg.prox(y, dk)
H.updatey(y.copy())

# run callback
if callback is not None:
callback(x, y)

if show:
pf = proxf(x) if proxf is not None else 0.
pg = proxg(y) if proxg is not None else 0.
if iiter < 10 or niter - iiter < 10 or iiter % (niter // 10) == 0:
msg = '%6g  %5.5e  %5.2e  %5.2e  %5.2e  %5.2e  %5.2e  %5.2e' % \
(iiter + 1, x[0], y[0], pf if pf is not None else 0.,
pg if pg is not None else 0., H(x, y), ck, dk)
print(msg)
if show:
print('\nTotal time (s) = %.2f' % (time.time() - tstart))
print('---------------------------------------------------------\n')
return x, y