Source code for pyproximal.proximal.Intersection
import numpy as np
from pyproximal.ProxOperator import _check_tau
from pyproximal import ProxOperator
from pyproximal.projection import IntersectionProj
class Intersection(ProxOperator):
    r"""Intersection of multiple convex sets operator.

    Indicator function of the set of inputs whose entries, taken pairwise
    along the first axis, differ by no more than the corresponding entry of
    ``sigma``.

    Parameters
    ----------
    k : :obj:`int`
        Size of vector to be projected
    n : :obj:`int`
        Number of vectors to be projected simultaneously
    sigma : :obj:`np.ndarray` or :obj:`int`
        Matrix of distances of size :math:`k \times k` (or single value in the
        case of constant matrix)
    niter : :obj:`int`, optional
        Number of iterations
    tol : :obj:`float`, optional
        Tolerance of update
    call : :obj:`bool`, optional
        Evaluate call method (``True``) or not (``False``)

    Notes
    -----
    As the Intersection is an indicator function, the proximal operator
    corresponds to its orthogonal projection (see
    :class:`pyproximal.projection.IntersectionProj` for details).

    """
    def __init__(self, k, n, sigma, niter=100, tol=1e-5, call=True):
        super().__init__(None, False)
        self.k, self.n = k, n
        # A scalar sigma is expanded to a constant k-by-k matrix so that
        # __call__ can always index it as sigma[i1, i2].
        self.sigma = sigma if isinstance(sigma, np.ndarray) \
            else sigma * np.ones((k, k))
        self.call = call
        # Projection operator used by prox; niter and tol are forwarded to it.
        self.ic = IntersectionProj(k, n, sigma, niter=niter, tol=tol)

    def __call__(self, x, tol=1e-8):
        """Check whether ``x`` satisfies all pairwise distance constraints.

        Parameters
        ----------
        x : :obj:`np.ndarray`
            Input, reshaped internally to ``(k, n)``; each column is treated
            as one vector of size ``k``
        tol : :obj:`float`, optional
            Slack added to every entry of ``sigma`` before comparison

        Returns
        -------
        inside : :obj:`bool`
            ``True`` if ``|x[i1, i] - x[i2, i]| <= sigma[i1, i2] + tol`` for
            every column ``i`` and every pair ``i1 < i2``; ``False`` otherwise.
            Always ``False`` when the operator was created with
            ``call=False``.

        """
        if not self.call:
            return False
        x = x.reshape(self.k, self.n)
        for i1 in range(self.k - 1):
            # Compare row i1 against all later rows (i2 > i1) and across all
            # n columns in a single broadcast operation.
            diff = np.abs(x[i1] - x[i1 + 1:])
            bound = self.sigma[i1, i1 + 1:, np.newaxis] + tol
            if np.any(diff > bound):
                return False
        return True

    @_check_tau
    def prox(self, x, tau):
        """Apply the projection operator ``self.ic`` to ``x``.

        ``tau`` is accepted for interface compatibility and checked by the
        ``_check_tau`` decorator, but it is not used in the computation.
        """
        return self.ic(x)