Source code for pyproximal.projection.L0
import numpy as np
from pyproximal.projection import SimplexProj
[docs]class L0BallProj():
r"""L0 ball projection.
Parameters
----------
radius : :obj:`int`
Radius
Notes
-----
Given an L0 ball defined as:
.. math::
L0_{r} = \{ \mathbf{x}: ||\mathbf{x}||_0 \leq r \}
its orthogonal projection is computed by finding the :math:`r` highest
largest entries of :math:`\mathbf{x}` (in absolute value), keeping those
and zero-ing all the other entries.
Note that this is the proximal operator of the corresponding
indicator function :math:`\mathcal{I}_{L0_{r}}`.
"""
def __init__(self, radius):
self.radius = int(radius)
def __call__(self, x):
xshape = x.shape
xf = x.copy().flatten()
xf[np.argsort(np.abs(xf))[:-self.radius]] = 0
return xf.reshape(xshape)