Source code for pyproximal.proximal.GenericIntersection

from typing import Any, Callable, List

import numpy as np
from pylops.utils.typing import NDArray

from pyproximal.projection import GenericIntersectionProj
from pyproximal.ProxOperator import ProxOperator, _check_tau


class GenericIntersectionProx(ProxOperator):
    r"""The proximal operator corresponding to the convex projection to the
    intersection of convex sets using Dykstra's algorithm.

    Parameters
    ----------
    projections : :obj:`list`
        A list of projection functions :math:`P_1, \ldots, P_m`.
    niter : :obj:`int`, optional, default=1000
        The maximum number of iterations.
    tol : :obj:`float`, optional, default=1e-6
        Tolerance on change of the solution (used as stopping criterion).
        If ``tol=0``, run until ``niter`` is reached.
    use_parallel : :obj:`bool`, optional, default=False
        If True, use the parallel version when :math:`m=2`.

    Notes
    -----
    As the intersection of convex sets is an indicator function, the proximal
    operator corresponds to its convex projection (see
    :class:`pyproximal.projection.GenericIntersectionProj` for details).

    Calling the operator on ``x`` performs a membership test: it returns
    ``True`` when every projection leaves ``x`` numerically unchanged, i.e.
    :math:`\max |x - P_i(x)| \le 10 \cdot \mathrm{tol}` for all :math:`i`.

    See also
    --------
    pyproximal.projection.GenericIntersectionProj : The corresponding convex
        projection.
    """

    def __init__(
        self,
        projections: List[Callable[[NDArray], NDArray]],
        niter: int = 1000,
        tol: float = 1e-6,
        use_parallel: bool = False,
    ) -> None:
        """Store the projections and build the Dykstra projection helper.

        Parameters
        ----------
        projections : :obj:`list`
            Projection functions, one per convex set.
        niter : :obj:`int`, optional
            Maximum number of Dykstra iterations.
        tol : :obj:`float`, optional
            Stopping tolerance passed to the Dykstra projection.
        use_parallel : :obj:`bool`, optional
            Forwarded to :class:`GenericIntersectionProj`.
        """
        super().__init__(None, False)
        self.projections = projections
        # The tolerance for the indicator function is set to 10 times larger
        # than the tolerance used in Dykstra's projection. This is because
        # using the same tolerance does not guarantee that the condition
        # will hold even after the convergence of Dykstra's algorithm.
        self.tol = tol * 10
        self.generic_intersection = GenericIntersectionProj(
            projections=self.projections,
            niter=niter,
            tol=tol,
            use_parallel=use_parallel,
        )

    def __call__(self, x: NDArray) -> bool:
        """Evaluate the indicator of the intersection as a membership test.

        Parameters
        ----------
        x : :obj:`numpy.ndarray`
            Point to test.

        Returns
        -------
        inside : :obj:`bool`
            ``True`` if, for every projection ``proj``, all entries of
            ``|x - proj(x)|`` are at most ``self.tol``; ``False`` otherwise
            (including when any residual is NaN).
        """
        # ``<=`` rather than ``<``: with ``tol=0`` (a documented setting),
        # ``self.tol`` is 0 and a strict comparison would reject even points
        # lying exactly in every set. ``np.all`` over the element-wise test
        # also avoids the ValueError that ``.max()`` raises on empty arrays.
        return all(
            bool(np.all(np.abs(x - proj(x)) <= self.tol))
            for proj in self.projections
        )

    @_check_tau
    def prox(self, x: NDArray, tau: float, **kwargs: Any) -> NDArray:
        """Project ``x`` onto the intersection of the convex sets.

        Parameters
        ----------
        x : :obj:`numpy.ndarray`
            Input vector.
        tau : :obj:`float`
            Step size; accepted for interface compatibility but not used,
            since the proximal operator of an indicator function is the
            projection onto its set regardless of ``tau``.

        Returns
        -------
        x : :obj:`numpy.ndarray`
            Result of the Dykstra projection of ``x``.
        """
        return self.generic_intersection(x)