from collections.abc import Callable
from typing import Any
from pylops.utils.backend import get_array_module
from pylops.utils.typing import NDArray
from pyproximal.proximal._Dykstra import (
_select_impl_by_arity,
dykstra_two,
parallel_dykstra_prox,
)
from pyproximal.ProxOperator import ProxOperator, _check_tau
[docs]
class Sum(ProxOperator):
r"""Proximal operator of the sum of proximable functions
using Dykstra-like algorithm.
Parameters
----------
ops : :obj:`list`
A list of proximable functions :math:`f_1, \ldots, f_m`.
weights : :obj:`numpy.ndarray` or :obj:`list` or :obj:`None`, optional, default=None
Weights :math:`\sum_{i=1}^m w_i = 1, \ 0 < w_i < 1`,
used when :math:`m > 2`, or :math:`m = 2` and ``use_parallel=True``.
Defaults to None, which means :math:`w_1 = \cdots = w_m = \frac{1}{m}.`
niter : :obj:`int`, optional, default=1000
The maximum number of iterations.
tol : :obj:`float`, optional, default=1e-7
Tolerance on change of the solution (used as stopping criterion).
If ``tol=0``, run until ``niter`` is reached.
use_parallel : :obj:`bool`, optional, default=False
The parallel version is used when :math:`m > 2`,
or :math:`m = 2` and `use_parallel=True`.
use_original_tau : :obj:`bool`, optional, default=False
Use the original value of :math:`\tau` (``True``)
or the scaled version :math:`\tau_i = \tau / w_i` (``False``).
Notes
-----
Given two functions :math:`f` and :math:`g`, or a set of proximable functions
:math:`f_i` and corresponding weights :math:`w_i` for :math:`i=1, \ldots, m`,
this class computes the proximal operator of the sum of two functions
.. math:: \prox_{\tau (f + g)}
using the Dykstra-like algorithm, or of the weighted sum of functions
.. math:: \prox_{\tau \ \sum_{i=1}^m w_i f_i}
using the parallel Dykstra-like algorithm.
For :math:`m=2`, the proximal mapping :math:`\prox_{\tau (f + g)}(\mathbf{x})` of
:math:`\mathbf{x}` is computed by the Dykstra-like algorithm [1]_, [2]_:
* :math:`\mathbf{x}^0 = \mathbf{x}, \mathbf{p}^0 = \mathbf{q}^0 = \mathbf{0}`
* for :math:`k = 1, \ldots`
* :math:`\mathbf{y}^k = \prox_{\tau g}(\mathbf{x}^k + \mathbf{p}^k)`
* :math:`\mathbf{p}^{k+1} = \mathbf{p}^k + \mathbf{x}^k - \mathbf{y}^k`
* :math:`\mathbf{x}^{k+1} = \prox_{\tau f}(\mathbf{y}^k + \mathbf{q}^k)`
* :math:`\mathbf{q}^{k+1} = \mathbf{q}^k + \mathbf{y}^k - \mathbf{x}^{k+1}`
For :math:`m \ge 2`, the proximal mapping :math:`\prox_{\tau \sum_{i=1}^m w_i f_i}(\mathbf{x})`
of :math:`\mathbf{x}` is computed by
the parallel Dykstra-like algorithm [3]_, [4]_, [5]_,
where :math:`\sum_{i=1}^m w_i = 1, \ 0 < w_i < 1`:
* :math:`\mathbf{x}^0 = \mathbf{z}_1^0 = \cdots = \mathbf{z}_m^0 = \mathbf{x}`
* for :math:`k = 1, \ldots`
* :math:`\mathbf{x}^{k+1} = \sum_{i=1}^{m} w_i \prox_{\tau_i f_i} (\mathbf{z}_{i}^k)`
* for :math:`i = 1, \ldots, m`
* :math:`\mathbf{z}_{i}^{k+1} = \mathbf{z}_{i}^k + \mathbf{x}^{k+1} - \prox_{\tau_i f_i} (\mathbf{z}_{i}^k)`
Note that :math:`\tau_i = \tau / w_i` if ``use_original_tau==False`` (default),
otherwise :math:`\tau_i = \tau`.
References
----------
.. [1] Combettes, P.L., Pesquet, J.-C., 2011. Proximal Splitting Methods in
Signal Processing, in Fixed-Point Algorithms for Inverse Problems in
Science and Engineering, Springer, pp. 185-212. Algorithm 10.18.
https://doi.org/10.1007/978-1-4419-9569-8_10
.. [2] Bauschke, H.H., Combettes, P.L., 2008. A Dykstra-like algorithm for two
monotone operators. Pacific Journal of Pitimization 4, 383-391.
Theorem 3.3.
http://www.ybook.co.jp/online-p/PJO/vol4/pjov4n3p383.pdf
.. [3] Combettes, P.L., Pesquet, J.-C., 2011. Proximal Splitting Methods in
Signal Processing, in Fixed-Point Algorithms for Inverse Problems in
Science and Engineering, Springer, pp. 185-212. Algorithm 10.31.
https://doi.org/10.1007/978-1-4419-9569-8_10
.. [4] Combettes, P.L., Dũng, Đ., Vũ, B.C., 2011. Proximity for sums of composite
functions. Journal of Mathematical Analysis and Applications 380, 680-688.
Eq. (2.26)
https://doi.org/10.1016/j.jmaa.2011.02.079
.. [5] Combettes, P.L., 2009. Iterative Construction of the Resolvent of a Sum of
Maximal Monotone Operators. Journal of Convex Analysis 16, 727-748.
Theorem 4.2.
https://www.heldermann.de/JCA/JCA16/JCA163/jca16044.htm
See also
--------
projection.GenericIntersectionProj :
The convex projection to the intersection of convex sets
using Dykstra's algorithm.
"""
def __init__(
self,
ops: list[ProxOperator],
weights: NDArray | list[float] | None = None,
niter: int = 1000,
tol: float = 1e-7,
use_parallel: bool = False,
use_original_tau: bool = False,
) -> None:
super().__init__(None, False)
self.ops = ops
self.niter = niter
self.tol = tol
self.use_original_tau = use_original_tau
if weights is None:
self.w = [1.0 / len(self.ops)] * len(self.ops)
else:
self.w = weights
self._prox = _select_impl_by_arity(
ops,
use_parallel=use_parallel,
single=self._single_prox,
two=self._two_prox,
more=self._more_prox,
)
def __call__(self, x: NDArray) -> bool | float:
"""Evaluate proximable functions
Parameters
----------
x : :obj:`numpy.ndarray`
Vector
Returns
-------
:obj:`bool` or :obj:`float`
- return ``False`` immediately if any boolean-type ops is ``False``
- return the sum of numeric-type ops values if all boolean-type ops are ``True``
- return ``True`` if all ops are boolean-type (no numeric-type ops) and ``True``
"""
# logic inspired by https://github.com/PyLops/pyproximal/issues/116
ncp = get_array_module(x)
def is_bool(v: bool | float) -> bool:
return isinstance(v, (bool, ncp.bool_))
prox_vals = [op(x) for op in self.ops]
bools, vals = [], []
for v in prox_vals:
if is_bool(v):
bools.append(v)
else:
vals.append(float(v))
if bools and not all(bools):
return False
if vals:
return sum(vals)
return True
@_check_tau
def prox(self, x: NDArray, tau: float, **kwargs: Any) -> NDArray:
r"""compute :math:`\prox_{\tau \ f}(\mathbf{x})`` of :math:`\mathbf{x}`."""
return self._prox(x, tau)
def _single_prox(self, x0: NDArray, tau: float) -> NDArray:
r"""Compute :math:`\prox_{\tau \ f}(\mathbf{x})` for :math:`m = 1`."""
if len(self.ops) != 1:
msg = "len(ops) should be 1"
raise ValueError(msg)
return self.ops[0].prox(x0, tau)
def _two_prox(self, x0: NDArray, tau: float) -> NDArray:
r"""Compute :math:`\prox_{\tau \ f + g}(\mathbf{x})` for :math:`m = 2`."""
if len(self.ops) != 2:
msg = "len(ops) should be 2"
raise ValueError(msg)
def bind_tau(
prox: Callable[[NDArray, float], NDArray],
tau: float,
) -> Callable[[NDArray], NDArray]:
return lambda x: prox(x, tau)
step1, step2 = [bind_tau(op.prox, tau) for op in self.ops]
return dykstra_two(
x0,
step1,
step2,
niter=self.niter,
tol=self.tol,
)
def _more_prox(self, x0: NDArray, tau: float) -> NDArray:
r"""Compute :math:`\prox_{\tau \ \sum_{i=1}^m w_i f_i}(\mathbf{x})`
for :math:`m \ge 2`.
"""
def tau_policy(tau: float, w: NDArray | list[float]) -> list[float]:
if self.use_original_tau:
# legacy: all prox_i use the same tau
return [tau] * len(w)
# PPXA-like scaling: tau_i = T / w_i
return [tau / wi for wi in w]
if len(self.ops) < 2:
msg = "len(ops) should be 2 or larger"
raise ValueError(msg)
return parallel_dykstra_prox(
x0,
prox_ops=[op.prox for op in self.ops],
weights=self.w,
taus=tau_policy(tau, self.w),
niter=self.niter,
tol=self.tol,
)