Note
Go to the end to download the full example code.
Plug and Play Priors
In this tutorial we will consider a rather atypical proximal algorithm. In their seminal work, Venkatakrishnan et al. [2013], Plug-and-Play Priors for Model Based Reconstruction, showed that the y-update in the ADMM algorithm can be interpreted as a denoising problem. The authors therefore suggested replacing the regularizer of the original problem with any denoising algorithm of choice (even if it does not have a known proximal operator). The proposed algorithm has shown great performance in a variety of inverse problems.
As an example, we will consider a simplified MRI experiment, where the data is created by applying a 2D Fourier Transform to the input model and by randomly sampling 60% of its values. We will use the famous BM3D as the denoiser, but any other denoiser of choice can be used instead!
Finally, whilst in the original paper PnP is associated with the ADMM solver, subsequent
research showed that the same principle can be applied to pretty much any proximal
solver. We will show how to pass a solver of choice to our
pyproximal.optimization.pnp.PlugAndPlay solver.
import bm3d
import matplotlib.pyplot as plt
import numpy as np
import pylops
from pylops.config import set_ndarray_multiplication
from pylops.utils.metrics import snr
import pyproximal
# Reproducible random sampling pattern (used to draw iava below).
np.random.seed(0)
# Start from a clean figure state.
plt.close("all")
# Every operator in this example is applied to flattened vectors (x.ravel()),
# so n-dimensional array multiplication is switched off.
set_ndarray_multiplication(False)
Let's start by loading the famous Shepp-Logan phantom and creating the modelling operator
# Load the Shepp-Logan phantom and scale it so that its maximum is 1.
x = np.load("../testdata/shepp_logan_phantom.npy")
x = x / x.max()
ny, nx = x.shape
npix = ny * nx

# Randomly keep a fraction perc_subsampling of the Fourier coefficients;
# the retained indices are sorted.
perc_subsampling = 0.6
nxsub = int(np.round(npix * perc_subsampling))
iava = np.sort(np.random.permutation(np.arange(npix))[:nxsub])

# Modelling operator building blocks: a 2D FFT and a restriction onto the
# sampled coefficients (complex-valued, like the FFT output).
Fop = pylops.signalprocessing.FFT2D(dims=(ny, nx))
Rop = pylops.Restriction(npix, iava, dtype=np.complex128)
We now create and display the data alongside the model
# Forward-model the data once. The full 2D spectrum of the model is shared by
# the observed (subsampled) data and by the two spectrum panels below, so there
# is no need to apply the FFT three times.
yfull = Fop * x.ravel()
# Observed data: the Fourier coefficients at the indices in iava.
y = Rop * yfull
# Full spectrum, fftshift-ed so that the zero frequency sits in the centre.
yfft = np.fft.fftshift(yfull.reshape(ny, nx))
# Rop.mask returns a masked array (see .data/.mask below), presumably with the
# coefficients not selected by iava masked out. Data and mask are shifted
# separately so that both stay aligned with yfft.
ymask = Rop.mask(yfull).reshape(ny, nx)
ymask.data[:] = np.fft.fftshift(ymask.data)
ymask.mask[:] = np.fft.fftshift(ymask.mask)

fig, axs = plt.subplots(1, 3, figsize=(14, 5))
axs[0].imshow(x, vmin=0, vmax=1, cmap="gray")
axs[0].set_title("Model")
axs[0].axis("tight")
axs[1].imshow(np.abs(yfft), vmin=0, vmax=1, cmap="rainbow")
axs[1].set_title("Full data")
axs[1].axis("tight")
axs[2].imshow(np.abs(ymask), vmin=0, vmax=1, cmap="rainbow")
axs[2].set_title("Sampled data")
axs[2].axis("tight")
plt.tight_layout()

At this point we create a denoiser instance using the BM3D algorithm and use it as the Plug-and-Play prior for the ADMM, PG and HQS algorithms
def callback(x, xtrue, errhist):
    """Append the L2 norm of the error of the current iterate to ``errhist``.

    Passed (through a lambda) as the per-iteration callback of the PnP solvers.
    ``errhist`` is modified in place; nothing is returned.
    """
    error = x - xtrue
    errhist.append(np.linalg.norm(error))
# Combined modelling operator: 2D FFT followed by restriction.
Op = Rop * Fop
# Largest eigenvalue of Op^H Op, i.e. the Lipschitz constant of the gradient of
# the data misfit; tau = 1 / L is the standard step size.
L = np.real((Op.H * Op).eigs(neigs=1, which="LM")[0])
tau = 1.0 / L
# Base noise level handed to BM3D (scaled by tau at every call).
sigma = 0.05


# BM3D denoiser, written as a ``def`` rather than a lambda bound to a name
# (PEP 8, E731) so that it carries a proper name and docstring.
def denoiser(x, tau):
    """Denoise ``x`` with BM3D using a noise level of ``sigma * tau``.

    Only the real part of ``x`` is given to BM3D: the solver iterates may be
    complex because the modelling operator is complex-valued.
    """
    return bm3d.bm3d(
        np.real(x), sigma_psd=sigma * tau, stage_arg=bm3d.BM3DStages.HARD_THRESHOLDING
    )
# ADMM-PnP
# A fresh L2 proximal is created for each solver run: it is built with
# warm=True, which presumably carries inner-solver state between calls.
l2 = pyproximal.proximal.L2(Op=Op, b=y.ravel(), niter=50, warm=True)
errhistadmm = []
admm_kwargs = dict(
    solver=pyproximal.optimization.primal.ADMM,
    tau=tau,
    x0=np.zeros(x.size),
    niter=40,
    show=True,
    callback=lambda xx: callback(xx, x.ravel(), errhistadmm),
)
# PlugAndPlay returns what ADMM returns; the first element is the estimate.
admm_out = pyproximal.optimization.pnp.PlugAndPlay(l2, denoiser, x.shape, **admm_kwargs)
xpnpadmm = np.real(admm_out[0].reshape(x.shape))
# PG-PnP (FISTA-accelerated proximal gradient)
# Fresh L2 proximal so that no warm-start state leaks from the ADMM run.
l2 = pyproximal.proximal.L2(Op=Op, b=y.ravel(), niter=50, warm=True)
errhistpg = []
pg_kwargs = dict(
    solver=pyproximal.optimization.primal.ProximalGradient,
    tau=tau,
    x0=np.zeros(x.size),
    niter=40,
    acceleration="fista",
    show=True,
    callback=lambda xx: callback(xx, x.ravel(), errhistpg),
)
# Unlike the ADMM/HQS runs, the output is used directly as the estimate
# (no [0] indexing).
pg_out = pyproximal.optimization.pnp.PlugAndPlay(l2, denoiser, x.shape, **pg_kwargs)
xpnppg = np.real(pg_out.reshape(x.shape))
# HQS-PnP
# Fresh L2 proximal so that no warm-start state leaks from the previous runs.
l2 = pyproximal.proximal.L2(Op=Op, b=y.ravel(), niter=50, warm=True)
# Continuation: one tau per iteration, shrinking by 1% at each step
# (40 values, matching niter below).
tau_hqs = 1.0 / L * 0.99 ** (np.arange(40))
errhisthqs = []
hqs_kwargs = dict(
    solver=pyproximal.optimization.primal.HQS,
    tau=tau_hqs,
    x0=np.zeros(x.size),
    niter=40,
    show=True,
    callback=lambda xx: callback(xx, x.ravel(), errhisthqs),
)
# PlugAndPlay returns what HQS returns; the first element is the estimate.
hqs_out = pyproximal.optimization.pnp.PlugAndPlay(l2, denoiser, x.shape, **hqs_kwargs)
xpnphqs = np.real(hqs_out[0].reshape(x.shape))
# Side-by-side comparison of the true model and the three PnP reconstructions.
fig, axs = plt.subplots(1, 4, sharey=True, figsize=(15, 5))
panels = [
    (x, "Model"),
    (xpnpadmm, f"ADMM-PnP (SNR={snr(x, xpnpadmm):.2f} dB)"),
    (xpnppg, f"PG-PnP (SNR={snr(x, xpnppg):.2f} dB)"),
    (xpnphqs, f"HQS-PnP (SNR={snr(x, xpnphqs):.2f} dB)"),
]
for ax, (img, title) in zip(axs, panels):
    ax.imshow(img, vmin=0, vmax=1, cmap="gray")
    ax.set_title(title)
    ax.axis("tight")
plt.tight_layout()

ADMM
---------------------------------------------------------
Proximal operator (f): <class 'pyproximal.proximal.L2.L2'>
Proximal operator (g): <class 'pyproximal.optimization.pnp._Denoise'>
tau = 1.000000e+00 niter = 40
Itn x[0] f g J = f + g
1 1.15604e-02 2.216e+02 0.000e+00 2.216e+02
2 -8.56307e-03 1.145e+02 0.000e+00 1.145e+02
3 -1.17574e-02 5.280e+01 0.000e+00 5.280e+01
4 -2.08330e-02 2.375e+01 0.000e+00 2.375e+01
5 -1.84653e-02 1.123e+01 0.000e+00 1.123e+01
6 -1.71549e-02 6.032e+00 0.000e+00 6.032e+00
7 -2.03437e-02 3.873e+00 0.000e+00 3.873e+00
8 -2.14074e-02 2.922e+00 0.000e+00 2.922e+00
9 -2.13936e-02 2.446e+00 0.000e+00 2.446e+00
10 -2.20952e-02 2.195e+00 0.000e+00 2.195e+00
13 -2.01355e-02 1.846e+00 0.000e+00 1.846e+00
17 -1.56626e-02 1.628e+00 0.000e+00 1.628e+00
21 -1.09523e-02 1.528e+00 0.000e+00 1.528e+00
25 -8.01692e-03 1.469e+00 0.000e+00 1.469e+00
29 -6.14493e-03 1.454e+00 0.000e+00 1.454e+00
32 -5.21199e-03 1.440e+00 0.000e+00 1.440e+00
33 -5.15385e-03 1.436e+00 0.000e+00 1.436e+00
34 -4.66368e-03 1.438e+00 0.000e+00 1.438e+00
35 -4.60660e-03 1.437e+00 0.000e+00 1.437e+00
36 -4.48794e-03 1.437e+00 0.000e+00 1.437e+00
37 -4.28470e-03 1.437e+00 0.000e+00 1.437e+00
38 -4.17887e-03 1.428e+00 0.000e+00 1.428e+00
39 -3.96370e-03 1.426e+00 0.000e+00 1.426e+00
40 -4.15882e-03 1.424e+00 0.000e+00 1.424e+00
Total time (s) = 32.90
---------------------------------------------------------
Accelerated Proximal Gradient
---------------------------------------------------------
Proximal operator (f): <class 'pyproximal.proximal.L2.L2'>
Proximal operator (g): <class 'pyproximal.optimization.pnp._Denoise'>
tau = 0.9999999999999989 backtrack = False beta = 5.000000e-01
epsg = 1.0 niter = 40 tol = None
niterback = 100 acceleration = fista
Itn x[0] f g J=f+eps*g tau
1 5.79539e-02 6.093e+01 0.000e+00 6.093e+01 1.000e+00
2 5.11582e-03 1.826e+01 0.000e+00 1.826e+01 1.000e+00
3 -3.01890e-02 5.364e+00 0.000e+00 5.364e+00 1.000e+00
4 -2.53920e-02 2.701e+00 0.000e+00 2.701e+00 1.000e+00
5 -2.29824e-02 2.213e+00 0.000e+00 2.213e+00 1.000e+00
6 -2.71790e-02 1.999e+00 0.000e+00 1.999e+00 1.000e+00
7 -2.83003e-02 1.852e+00 0.000e+00 1.852e+00 1.000e+00
8 -2.48831e-02 1.751e+00 0.000e+00 1.751e+00 1.000e+00
9 -2.05380e-02 1.657e+00 0.000e+00 1.657e+00 1.000e+00
10 -1.46019e-02 1.569e+00 0.000e+00 1.569e+00 1.000e+00
13 -1.43390e-03 1.434e+00 0.000e+00 1.434e+00 1.000e+00
17 2.77639e-04 1.404e+00 0.000e+00 1.404e+00 1.000e+00
21 -2.18475e-03 1.411e+00 0.000e+00 1.411e+00 1.000e+00
25 -4.27971e-03 1.421e+00 0.000e+00 1.421e+00 1.000e+00
29 -3.73545e-03 1.423e+00 0.000e+00 1.423e+00 1.000e+00
32 -2.78804e-03 1.421e+00 0.000e+00 1.421e+00 1.000e+00
33 -2.74846e-03 1.416e+00 0.000e+00 1.416e+00 1.000e+00
34 -2.69139e-03 1.415e+00 0.000e+00 1.415e+00 1.000e+00
35 -2.65875e-03 1.412e+00 0.000e+00 1.412e+00 1.000e+00
36 -2.75697e-03 1.412e+00 0.000e+00 1.412e+00 1.000e+00
37 -2.92092e-03 1.416e+00 0.000e+00 1.416e+00 1.000e+00
38 -2.98991e-03 1.420e+00 0.000e+00 1.420e+00 1.000e+00
39 -3.07033e-03 1.421e+00 0.000e+00 1.421e+00 1.000e+00
40 -3.24239e-03 1.420e+00 0.000e+00 1.420e+00 1.000e+00
Total time (s) = 32.69
---------------------------------------------------------
HQS
---------------------------------------------------------
Proximal operator (f): <class 'pyproximal.proximal.L2.L2'>
Proximal operator (g): <class 'pyproximal.optimization.pnp._Denoise'>
tau = Variable niter = 40
Itn x[0] f g J = f + g
1 1.15604e-02 2.216e+02 0.000e+00 2.216e+02
2 4.33966e-03 8.100e+01 0.000e+00 8.100e+01
3 -3.47471e-03 3.553e+01 0.000e+00 3.553e+01
4 -9.93984e-03 1.844e+01 0.000e+00 1.844e+01
5 -1.27333e-02 1.090e+01 0.000e+00 1.090e+01
6 -1.70283e-02 7.098e+00 0.000e+00 7.098e+00
7 -2.01874e-02 4.968e+00 0.000e+00 4.968e+00
8 -2.29236e-02 3.689e+00 0.000e+00 3.689e+00
9 -2.46765e-02 2.878e+00 0.000e+00 2.878e+00
10 -2.53835e-02 2.353e+00 0.000e+00 2.353e+00
13 -2.48623e-02 1.566e+00 0.000e+00 1.566e+00
17 -2.12349e-02 1.162e+00 0.000e+00 1.162e+00
21 -1.61944e-02 9.502e-01 0.000e+00 9.502e-01
25 -1.18030e-02 8.127e-01 0.000e+00 8.127e-01
29 -8.41310e-03 7.208e-01 0.000e+00 7.208e-01
32 -6.73516e-03 6.670e-01 0.000e+00 6.670e-01
33 -6.14250e-03 6.532e-01 0.000e+00 6.532e-01
34 -5.66455e-03 6.403e-01 0.000e+00 6.403e-01
35 -5.18308e-03 6.288e-01 0.000e+00 6.288e-01
36 -4.82678e-03 6.168e-01 0.000e+00 6.168e-01
37 -4.32120e-03 6.063e-01 0.000e+00 6.063e-01
38 -3.96257e-03 5.948e-01 0.000e+00 5.948e-01
39 -3.57748e-03 5.843e-01 0.000e+00 5.843e-01
40 -3.25554e-03 5.748e-01 0.000e+00 5.748e-01
Total time (s) = 32.76
---------------------------------------------------------
Finally, the attentive reader may have noticed that in the HQS solver a continuation strategy was used for the tau parameter; whilst this is strictly needed for HQS to converge, there is a consensus in the literature that other solvers can also benefit from adopting the same strategy when used with a PnP prior. This can in fact be interpreted as reducing the strength of the denoiser as iterations progress and the estimate comes closer to the true solution.
While our pyproximal.optimization.primal.ADMM solver does not currently
offer relaxation out-of-the-box, this can be achieved pretty easily
by creating an auxiliary Denoiser class with a decay parameter, as
shown below.
class Denoiser:
    """BM3D denoiser whose strength decays with the number of calls.

    Parameters
    ----------
    sigma : float
        Base noise level passed to BM3D (further scaled by ``tau``).
    decay : array-like
        Multiplicative factors applied to ``sigma``; the ``i``-th call to
        :meth:`denoise` uses ``decay[i]``. Once the schedule is exhausted, its
        last value is reused instead of raising ``IndexError``.
    """

    def __init__(self, sigma, decay):
        self.sigma = sigma
        self.decay = decay
        # Number of times ``denoise`` has been called so far.
        self.iiter = 0

    def denoise(self, x, tau):
        """Denoise the real part of ``x`` with BM3D and advance the schedule.

        The noise level given to BM3D is ``decay[i] * sigma * tau``, where
        ``i`` is the number of previous calls (clamped to the last entry).
        """
        # Clamp the index so that running the solver for more iterations than
        # len(decay) keeps the final decay factor rather than crashing.
        factor = self.decay[min(self.iiter, len(self.decay) - 1)]
        xden = bm3d.bm3d(
            np.real(x),
            sigma_psd=factor * self.sigma * tau,
            stage_arg=bm3d.BM3DStages.HARD_THRESHOLDING,
        )
        self.iiter += 1
        return xden
# ADMM-PnP with relaxation: the denoiser noise level is multiplied by 0.99
# at every denoiser call (40-entry schedule, matching niter below).
denoiser = Denoiser(sigma, decay=0.99 ** (np.arange(40)))
# Fresh L2 proximal so that no warm-start state leaks from the previous runs.
l2 = pyproximal.proximal.L2(Op=Op, b=y.ravel(), niter=50, warm=True)
errhistadmm1 = []
admm1_kwargs = dict(
    solver=pyproximal.optimization.primal.ADMM,
    tau=tau,
    x0=np.zeros(x.size),
    niter=40,
    show=True,
    callback=lambda xx: callback(xx, x.ravel(), errhistadmm1),
)
# The bound method carries the decay state; the first output is the estimate.
admm1_out = pyproximal.optimization.pnp.PlugAndPlay(
    l2, denoiser.denoise, x.shape, **admm1_kwargs
)
xpnpadmm1 = np.real(admm1_out[0].reshape(x.shape))
# Compare plain ADMM-PnP against ADMM-PnP with a decaying denoiser.
fig, axs = plt.subplots(1, 3, sharey=True, figsize=(15, 5))
panels = [
    (x, "Model"),
    (xpnpadmm, f"ADMM-PnP (SNR={snr(x, xpnpadmm):.2f} dB)"),
    (xpnpadmm1, f"ADMM-PnP with rel. (SNR={snr(x, xpnpadmm1):.2f} dB)"),
]
for ax, (img, title) in zip(axs, panels):
    ax.imshow(img, vmin=0, vmax=1, cmap="gray")
    ax.set_title(title)
    ax.axis("tight")
plt.tight_layout()

ADMM
---------------------------------------------------------
Proximal operator (f): <class 'pyproximal.proximal.L2.L2'>
Proximal operator (g): <class 'pyproximal.optimization.pnp._Denoise'>
tau = 1.000000e+00 niter = 40
Itn x[0] f g J = f + g
1 1.15604e-02 2.216e+02 0.000e+00 2.216e+02
2 -8.56307e-03 1.145e+02 0.000e+00 1.145e+02
3 -9.34141e-03 5.266e+01 0.000e+00 5.266e+01
4 -2.22792e-02 2.353e+01 0.000e+00 2.353e+01
5 -2.24509e-02 1.096e+01 0.000e+00 1.096e+01
6 -1.70483e-02 5.774e+00 0.000e+00 5.774e+00
7 -1.95942e-02 3.616e+00 0.000e+00 3.616e+00
8 -2.06861e-02 2.642e+00 0.000e+00 2.642e+00
9 -2.17341e-02 2.142e+00 0.000e+00 2.142e+00
10 -2.14387e-02 1.851e+00 0.000e+00 1.851e+00
13 -2.11900e-02 1.460e+00 0.000e+00 1.460e+00
17 -1.66471e-02 1.175e+00 0.000e+00 1.175e+00
21 -1.17995e-02 9.945e-01 0.000e+00 9.945e-01
25 -8.75420e-03 8.533e-01 0.000e+00 8.533e-01
29 -6.17793e-03 7.465e-01 0.000e+00 7.465e-01
32 -4.97724e-03 6.927e-01 0.000e+00 6.927e-01
33 -4.49753e-03 6.773e-01 0.000e+00 6.773e-01
34 -4.11369e-03 6.620e-01 0.000e+00 6.620e-01
35 -4.28981e-03 6.469e-01 0.000e+00 6.469e-01
36 -4.09005e-03 6.334e-01 0.000e+00 6.334e-01
37 -3.53140e-03 6.211e-01 0.000e+00 6.211e-01
38 -3.25751e-03 6.101e-01 0.000e+00 6.101e-01
39 -3.31128e-03 6.010e-01 0.000e+00 6.010e-01
40 -3.05932e-03 5.908e-01 0.000e+00 5.908e-01
Total time (s) = 32.77
---------------------------------------------------------
Let's finally compare the error convergence of the four PnP variants
# Error-norm history of every PnP run, on a logarithmic axis.
curves = (
    (errhistadmm, "k", "ADMM"),
    (errhistpg, "r", "PG"),
    (errhisthqs, "b", "HQS"),
    (errhistadmm1, "--b", "ADMM with rel."),
)
plt.figure(figsize=(12, 3))
for hist, fmt, label in curves:
    plt.semilogy(hist, fmt, lw=2, label=label)
plt.title("Error norm")
plt.legend()
plt.tight_layout()

This final result clearly shows that relaxation is important for ADMM too.
Total running time of the script: (2 minutes 12.534 seconds)