Source code for pyproximal.proximal.VStack

from typing import TYPE_CHECKING

import numpy as np
from pylops.utils.typing import NDArray, ShapeLike

from pyproximal.ProxOperator import ProxOperator, _check_tau

if TYPE_CHECKING:
    from pylops.linearoperator import LinearOperator


class VStack(ProxOperator):
    r"""Vertical stacking.

    Stack a set of N proximal operators vertically.

    This operator can be used for separable inputs, where the overall
    proximal operator can be computed as the stack of proximal operators
    on parts of the input vector.

    Parameters
    ----------
    ops : :obj:`list`
        Proximal operators to be stacked
    nn : :obj:`list`, optional
        Size of each portion of the input vector (to be used when
        different portions are in consecutive order). Takes precedence
        over ``restr`` when both are provided.
    restr : :obj:`list`, optional
        List of :class:`pylops.Restriction` operators extracting the
        subset of interest (to be used when different portions are not
        in consecutive order). It is user responsibility to ensure that
        all elements of the input vector are used exactly once.

    Raises
    ------
    ValueError
        If neither ``nn`` nor ``restr`` is provided, or if the one in use
        does not contain exactly one entry per operator in ``ops``.

    Notes
    -----
    Given an input vector :math:`\mathbf{x}` to which a number of :math:`N`
    functions are applied to different portions of the vector as:

    .. math::

        f(\mathbf{x}) = \sum_{i=1}^N f_i(\mathbf{x}_i)

    the related proximal operator becomes:

    .. math::

        \prox_{\tau f}(\mathbf{x}) =
        \left( \prox_{\tau f_1}(\mathbf{x}_1), \ldots,
        \prox_{\tau f_N}(\mathbf{x}_N) \right)

    """

    def __init__(
        self,
        ops: list[ProxOperator],
        nn: list[ShapeLike] | None = None,
        restr: list["LinearOperator"] | None = None,
    ) -> None:
        # NOTE(review): ``any`` advertises a gradient as soon as one stacked
        # operator has it, while ``grad`` below calls ``op.grad`` on every
        # operator - confirm whether ``all`` is the intended semantics.
        super().__init__(None, any(op.hasgrad for op in ops))
        self.ops = ops
        if nn is not None:
            # A mismatch would otherwise either ignore the tail of the
            # input silently or fail later with an IndexError.
            if len(nn) != len(ops):
                msg = (
                    f"nn must have one entry per operator, instead "
                    f"{len(nn)} sizes were provided for {len(ops)} operators"
                )
                raise ValueError(msg)
            self.nn = nn
            # Start/end index of each consecutive portion of the input
            cum_nn = np.cumsum(nn)
            self.xin = np.insert(cum_nn[:-1], 0, 0)
            self.xend = cum_nn
            # store required size of input
            self.nx = cum_nn[-1]
        elif restr is not None:
            # Fail at construction rather than at the first evaluation
            if len(restr) != len(ops):
                msg = (
                    f"restr must have one entry per operator, instead "
                    f"{len(restr)} operators were provided for "
                    f"{len(ops)} operators"
                )
                raise ValueError(msg)
            self.restr = restr
            # store required size of input
            self.nx = np.sum([rop.iava.size for rop in self.restr])
        else:
            msg = "Provide either nn or restr"
            raise ValueError(msg)

    def _check_size(self, x: NDArray) -> None:
        """Raise ``ValueError`` if ``x`` does not have ``self.nx`` elements."""
        if x.size != self.nx:
            msg = (
                f"x must have size {self.nx}, instead the provided x has size {x.size}"
            )
            raise ValueError(msg)

    def __call__(self, x: NDArray) -> float:
        """Evaluate the stacked functional at ``x``.

        Parameters
        ----------
        x : :obj:`numpy.ndarray`
            Input vector of size ``self.nx``

        Returns
        -------
        f : :obj:`float`
            Sum of the values of each operator on its portion of ``x``

        Raises
        ------
        ValueError
            If ``x.size`` differs from ``self.nx``

        """
        self._check_size(x)
        f = 0.0
        # ``nn`` is set only in consecutive mode, ``restr`` only otherwise
        if hasattr(self, "nn"):
            for op, start, end in zip(self.ops, self.xin, self.xend, strict=True):
                f += op(x[start:end])
        else:
            for op, rop in zip(self.ops, self.restr, strict=True):
                f += op(rop.matvec(x))
        return float(f)

    @_check_tau
    def prox(self, x: NDArray, tau: float) -> NDArray:
        """Apply the proximal operator of each function to its portion of ``x``.

        Parameters
        ----------
        x : :obj:`numpy.ndarray`
            Input vector of size ``self.nx``
        tau : :obj:`float`
            Scaling passed to the ``prox`` of every stacked operator

        Returns
        -------
        f : :obj:`numpy.ndarray`
            Per-portion results, concatenated in operator order
            (consecutive mode) or written at the ``iava`` indices of each
            restriction operator (restriction mode)

        Raises
        ------
        ValueError
            If ``x.size`` differs from ``self.nx``

        """
        self._check_size(x)
        if hasattr(self, "nn"):
            f = np.hstack(
                [
                    op.prox(x[start:end], tau)
                    for op, start, end in zip(
                        self.ops, self.xin, self.xend, strict=True
                    )
                ]
            )
        else:
            # Entries not covered by any restriction operator stay zero
            f = np.zeros_like(x)
            for op, rop in zip(self.ops, self.restr, strict=True):
                f[rop.iava] = op.prox(rop.matvec(x), tau)
        return f

    def grad(self, x: NDArray) -> NDArray:
        """Apply the gradient of each function to its portion of ``x``.

        Parameters
        ----------
        x : :obj:`numpy.ndarray`
            Input vector of size ``self.nx``

        Returns
        -------
        f : :obj:`numpy.ndarray`
            Per-portion gradients, concatenated in operator order
            (consecutive mode) or written at the ``iava`` indices of each
            restriction operator (restriction mode)

        Raises
        ------
        ValueError
            If ``x.size`` differs from ``self.nx``

        """
        self._check_size(x)
        if hasattr(self, "nn"):
            f = np.hstack(
                [
                    op.grad(x[start:end])
                    for op, start, end in zip(
                        self.ops, self.xin, self.xend, strict=True
                    )
                ]
            )
        else:
            # Entries not covered by any restriction operator stay zero
            f = np.zeros_like(x)
            for op, rop in zip(self.ops, self.restr, strict=True):
                f[rop.iava] = op.grad(rop.matvec(x))
        return f