# Plug and Play Priors

In this tutorial, we consider a rather atypical proximal algorithm. In their seminal work, *Plug-and-Play Priors for Model Based Reconstruction*, Venkatakrishnan et al. [2013] showed that the y-update in the ADMM algorithm can be interpreted as a denoising problem. They therefore suggested replacing the proximal operator of the regularizer in the original problem with any denoising algorithm of choice, even one that has no known proximal form. The resulting algorithm has shown strong performance across a variety of inverse problems.
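To make the idea concrete, here is a minimal NumPy sketch of PnP-ADMM in scaled form. It assumes a quadratic data term `f(x) = 0.5 ||A x - b||^2`. The first update is the proximal operator of `f`. The second update would normally apply the proximal operator of the regularizer, and here it is simply replaced by a call to a denoiser. The third update adjusts the scaled dual variable. The toy 1D problem, the names `pnp_admm`, `prox_l2` and `smooth_denoiser`, and the 3-tap smoothing filter standing in for BM3D are all hypothetical choices made for illustration. This is not PyProximal's implementation.

```python
import numpy as np


def pnp_admm(prox_f, denoiser, x0, tau=1.0, niter=100):
    """Scaled-form ADMM in which the prox of the regularizer is replaced by a denoiser."""
    x = x0.copy()
    z = x0.copy()
    u = np.zeros_like(x0)
    for _ in range(niter):
        x = prox_f(z - u, tau)    # data-consistency step: prox of f
        z = denoiser(x + u, tau)  # regularization step: a denoising problem
        u = u + x - z             # scaled dual variable update
    return x


# Hypothetical toy problem: recover a 1D step from 60% of its samples
rng = np.random.default_rng(0)
n = 100
xtrue = np.zeros(n)
xtrue[n // 2:] = 1.0
iava = np.sort(rng.permutation(n)[: int(0.6 * n)])
A = np.eye(n)[iava]  # restriction operator as a dense matrix
b = A @ xtrue


def prox_l2(v, tau):
    # prox of f(x) = 0.5 ||A x - b||^2: solve (A^T A + I / tau) x = A^T b + v / tau
    return np.linalg.solve(A.T @ A + np.eye(n) / tau, A.T @ b + v / tau)


def smooth_denoiser(v, tau):
    # hypothetical stand-in for BM3D: a 3-tap smoothing filter (tau is ignored)
    return np.convolve(v, [0.25, 0.5, 0.25], mode="same")


xhat = pnp_admm(prox_l2, smooth_denoiser, np.zeros(n), tau=1.0, niter=200)
print(np.linalg.norm(xhat - xtrue), np.linalg.norm(xtrue))
```

This particular filter is symmetric with eigenvalues in (0, 1), so it happens to be the proximal operator of a convex quadratic and the iterations converge. A black-box denoiser such as BM3D offers no such general guarantee.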

As an example, we consider a simplified MRI experiment. The data are created by applying a 2D Fourier transform to the input model and randomly sampling 60% of its values. We use the well-known BM3D algorithm as the denoiser, but any other denoiser of choice can be used instead!

Finally, whilst the original paper associates PnP with the ADMM solver, subsequent research showed that the same principle applies to many other proximal solvers. We will show how to pass a solver of choice to the `pyproximal.optimization.pnp.PlugAndPlay` solver.
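The denoiser is not tied to ADMM. The sketch below plugs the same kind of denoiser into an accelerated proximal gradient (FISTA) loop. The gradient step on the data term is unchanged, and the proximal step is replaced by a denoiser call. It mirrors the idea behind the PG-PnP run later in this tutorial. However, the toy 1D problem and the names `pnp_fista`, `grad_l2` and `smooth_denoiser` are hypothetical, and the code is a sketch rather than PyProximal's `ProximalGradient`.

```python
import numpy as np


def pnp_fista(grad_f, denoiser, x0, tau, niter=100):
    """Accelerated proximal gradient (FISTA) where the prox of g is a denoiser."""
    x = x0.copy()
    y = x0.copy()
    t = 1.0
    for _ in range(niter):
        xold = x
        # gradient step on the data term, then denoise instead of applying a prox
        x = denoiser(y - tau * grad_f(y), tau)
        # FISTA momentum (Nesterov extrapolation)
        tnew = (1.0 + np.sqrt(1.0 + 4.0 * t ** 2)) / 2.0
        y = x + ((t - 1.0) / tnew) * (x - xold)
        t = tnew
    return x


# Hypothetical toy problem: recover a 1D step from 60% of its samples
rng = np.random.default_rng(0)
n = 100
xtrue = np.zeros(n)
xtrue[n // 2:] = 1.0
iava = np.sort(rng.permutation(n)[: int(0.6 * n)])
A = np.eye(n)[iava]  # restriction operator as a dense matrix
b = A @ xtrue


def grad_l2(x):
    # gradient of f(x) = 0.5 ||A x - b||^2
    return A.T @ (A @ x - b)


def smooth_denoiser(v, tau):
    # hypothetical stand-in for BM3D: 3-tap smoothing filter (tau is ignored)
    return np.convolve(v, [0.25, 0.5, 0.25], mode="same")


L = np.linalg.norm(A.T @ A, 2)  # Lipschitz constant of grad_l2 (1 for a restriction)
xpg = pnp_fista(grad_l2, smooth_denoiser, np.zeros(n), tau=1.0 / L, niter=200)
print(np.linalg.norm(xpg - xtrue), np.linalg.norm(xtrue))
```

Passing `tau` to the denoiser keeps the same signature as the BM3D lambda used later in this tutorial, where the noise level is scaled by `tau`. The toy filter simply ignores it.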

```python
import numpy as np
import matplotlib.pyplot as plt
import pylops

import pyproximal
import bm3d

from pylops.config import set_ndarray_multiplication

plt.close('all')
np.random.seed(0)
set_ndarray_multiplication(False)
```

Let’s start by loading the famous Shepp-Logan phantom and creating the modelling operator. This operator is a 2D FFT followed by a restriction that keeps 60% of the Fourier coefficients:

```python
x = np.load("../testdata/shepp_logan_phantom.npy")
x = x / x.max()
ny, nx = x.shape

perc_subsampling = 0.6
nxsub = int(np.round(ny * nx * perc_subsampling))
iava = np.sort(np.random.permutation(np.arange(ny * nx))[:nxsub])
Rop = pylops.Restriction(ny * nx, iava, dtype=np.complex128)
Fop = pylops.signalprocessing.FFT2D(dims=(ny, nx))
```
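The following NumPy sketch shows what `Rop * Fop` does without PyLops. It implements the forward operator and its adjoint, which zero-fills the missing coefficients and then applies an inverse FFT, and it checks the pair with a dot-product test. It assumes the orthonormal (`norm="ortho"`) FFT scaling, which is consistent with the step size of about 1 reported in the solver logs. The image size and the functions `forward` and `adjoint` are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
ny, nx = 64, 64                       # hypothetical image size
nxsub = int(np.round(ny * nx * 0.6))  # keep 60% of the Fourier coefficients
iava = np.sort(rng.permutation(ny * nx)[:nxsub])


def forward(x):
    # y = R F x: orthonormal 2D FFT, then keep only the entries listed in iava
    return np.fft.fft2(x.reshape(ny, nx), norm="ortho").ravel()[iava]


def adjoint(y):
    # x = F^H R^T y: zero-fill the missing coefficients, then inverse orthonormal FFT
    yfull = np.zeros(ny * nx, dtype=np.complex128)
    yfull[iava] = y
    return np.fft.ifft2(yfull.reshape(ny, nx), norm="ortho").ravel()


# Dot-product test: <R F x, y> must equal <x, F^H R^T y> up to round-off
xr = rng.standard_normal(ny * nx) + 1j * rng.standard_normal(ny * nx)
yr = rng.standard_normal(nxsub) + 1j * rng.standard_normal(nxsub)
lhs = np.vdot(forward(xr), yr)
rhs = np.vdot(xr, adjoint(yr))
print(lhs, rhs)
```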

We now create the data and display them alongside the model:

```python
y = Rop * Fop * x.ravel()
yfft = Fop * x.ravel()
yfft = np.fft.fftshift(yfft.reshape(ny, nx))

# zero-filled sampled data (adjoint of the restriction), shifted for display
ymask = Rop.H * y
ymask = np.fft.fftshift(ymask.reshape(ny, nx))

fig, axs = plt.subplots(1, 3, figsize=(14, 5))
axs[0].imshow(x, vmin=0, vmax=1, cmap="gray")
axs[0].set_title("Model")
axs[0].axis("tight")
axs[1].imshow(np.abs(yfft), vmin=0, vmax=1, cmap="rainbow")
axs[1].set_title("Full data")
axs[1].axis("tight")
axs[2].imshow(np.abs(ymask), vmin=0, vmax=1, cmap="rainbow")
axs[2].set_title("Sampled data")
axs[2].axis("tight")
plt.tight_layout()
```

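At this point, we create a denoiser based on the BM3D algorithm and use it as a Plug-and-Play prior within the Proximal Gradient (PG) and ADMM algorithms. The step size `tau` is the inverse of the largest eigenvalue of the normal operator, and the BM3D noise level is scaled by `tau`: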

```python
def callback(x, xtrue, errhist):
    # store the error between the current estimate and the true model
    errhist.append(np.linalg.norm(x - xtrue))

Op = Rop * Fop
# step size from the largest eigenvalue of Op^H Op
L = np.real((Op.H*Op).eigs(neigs=1, which='LM')[0])
tau = 1./L
sigma = 0.05

# data-fidelity term 0.5 ||Op x - y||^2
l2 = pyproximal.proximal.L2(Op=Op, b=y.ravel(), niter=50, warm=True)

# BM3D denoiser (applied to the real part of the current estimate)
denoiser = lambda x, tau: bm3d.bm3d(np.real(x), sigma_psd=sigma * tau,
                                    stage_arg=bm3d.BM3DStages.HARD_THRESHOLDING)

# PG-PnP
errhistpg = []
xpnppg = pyproximal.optimization.pnp.PlugAndPlay(l2, denoiser, x.shape,
                                                 solver=pyproximal.optimization.primal.ProximalGradient,
                                                 tau=tau, x0=np.zeros(x.size),
                                                 niter=40,
                                                 acceleration='fista',
                                                 show=True,
                                                 callback=lambda xx: callback(xx, x.ravel(),
                                                                              errhistpg))
xpnppg = np.real(xpnppg.reshape(x.shape))

# ADMM-PnP (ADMM returns both x and z, keep x)
errhistadmm = []
xpnpadmm = pyproximal.optimization.pnp.PlugAndPlay(l2, denoiser, x.shape,
                                                   solver=pyproximal.optimization.primal.ADMM,
                                                   tau=tau, x0=np.zeros(x.size),
                                                   niter=40, show=True,
                                                   callback=lambda xx: callback(xx, x.ravel(),
                                                                                errhistadmm))[0]
xpnpadmm = np.real(xpnpadmm.reshape(x.shape))

fig, axs = plt.subplots(1, 3, figsize=(14, 5))
axs[0].imshow(x, vmin=0, vmax=1, cmap="gray")
axs[0].set_title("Model")
axs[0].axis("tight")
axs[1].imshow(xpnppg, vmin=0, vmax=1, cmap="gray")
axs[1].set_title("PG-PnP Inversion")
axs[1].axis("tight")
axs[2].imshow(xpnpadmm, vmin=0, vmax=1, cmap="gray")
axs[2].set_title("ADMM-PnP Inversion")
axs[2].axis("tight")
plt.tight_layout()
```
```
Accelerated Proximal Gradient
---------------------------------------------------------
Proximal operator (f): <class 'pyproximal.proximal.L2.L2'>
Proximal operator (g): <class 'pyproximal.optimization.pnp._Denoise'>
tau = 0.9999999999999989        beta=5.000000e-01
epsg = 1.0      niter = 40
niterback = 100 acceleration = fista

Itn       x[0]          f           g       J=f+eps*g       tau
1   5.79539e-02   6.093e+01   0.000e+00   6.093e+01   1.000e+00
2   5.11582e-03   1.826e+01   0.000e+00   1.826e+01   1.000e+00
3  -3.01890e-02   5.364e+00   0.000e+00   5.364e+00   1.000e+00
4  -2.53920e-02   2.701e+00   0.000e+00   2.701e+00   1.000e+00
5  -2.29824e-02   2.213e+00   0.000e+00   2.213e+00   1.000e+00
6  -2.71790e-02   1.999e+00   0.000e+00   1.999e+00   1.000e+00
7  -2.83003e-02   1.852e+00   0.000e+00   1.852e+00   1.000e+00
8  -2.48831e-02   1.751e+00   0.000e+00   1.751e+00   1.000e+00
9  -2.05380e-02   1.657e+00   0.000e+00   1.657e+00   1.000e+00
10  -1.46019e-02   1.569e+00   0.000e+00   1.569e+00   1.000e+00
13  -1.43390e-03   1.434e+00   0.000e+00   1.434e+00   1.000e+00
17   2.77639e-04   1.404e+00   0.000e+00   1.404e+00   1.000e+00
21  -2.18475e-03   1.411e+00   0.000e+00   1.411e+00   1.000e+00
25  -4.27971e-03   1.421e+00   0.000e+00   1.421e+00   1.000e+00
29  -3.73545e-03   1.423e+00   0.000e+00   1.423e+00   1.000e+00
32  -2.78804e-03   1.421e+00   0.000e+00   1.421e+00   1.000e+00
33  -2.74846e-03   1.416e+00   0.000e+00   1.416e+00   1.000e+00
34  -2.69139e-03   1.415e+00   0.000e+00   1.415e+00   1.000e+00
35  -2.65875e-03   1.412e+00   0.000e+00   1.412e+00   1.000e+00
36  -2.75697e-03   1.412e+00   0.000e+00   1.412e+00   1.000e+00
37  -2.92092e-03   1.416e+00   0.000e+00   1.416e+00   1.000e+00
38  -2.98991e-03   1.420e+00   0.000e+00   1.420e+00   1.000e+00
39  -3.07033e-03   1.421e+00   0.000e+00   1.421e+00   1.000e+00
40  -3.24239e-03   1.420e+00   0.000e+00   1.420e+00   1.000e+00

Total time (s) = 52.72
---------------------------------------------------------

---------------------------------------------------------
Proximal operator (f): <class 'pyproximal.proximal.L2.L2'>
Proximal operator (g): <class 'pyproximal.optimization.pnp._Denoise'>
tau = 1.000000e+00      niter = 40

Itn       x[0]          f           g       J = f + g
1   1.15604e-02   2.216e+02   0.000e+00   2.216e+02
2  -8.56307e-03   1.145e+02   0.000e+00   1.145e+02
3  -1.17574e-02   5.280e+01   0.000e+00   5.280e+01
4  -2.08330e-02   2.375e+01   0.000e+00   2.375e+01
5  -1.84653e-02   1.123e+01   0.000e+00   1.123e+01
6  -1.71549e-02   6.032e+00   0.000e+00   6.032e+00
7  -2.03437e-02   3.873e+00   0.000e+00   3.873e+00
8  -2.14074e-02   2.922e+00   0.000e+00   2.922e+00
9  -2.13936e-02   2.446e+00   0.000e+00   2.446e+00
10  -2.20952e-02   2.195e+00   0.000e+00   2.195e+00
13  -2.01355e-02   1.846e+00   0.000e+00   1.846e+00
17  -1.56626e-02   1.628e+00   0.000e+00   1.628e+00
21  -1.09523e-02   1.528e+00   0.000e+00   1.528e+00
25  -8.01692e-03   1.469e+00   0.000e+00   1.469e+00
29  -6.14493e-03   1.454e+00   0.000e+00   1.454e+00
32  -5.21199e-03   1.440e+00   0.000e+00   1.440e+00
33  -5.15385e-03   1.436e+00   0.000e+00   1.436e+00
34  -4.66368e-03   1.438e+00   0.000e+00   1.438e+00
35  -4.60660e-03   1.437e+00   0.000e+00   1.437e+00
36  -4.48794e-03   1.437e+00   0.000e+00   1.437e+00
37  -4.28470e-03   1.437e+00   0.000e+00   1.437e+00
38  -4.17887e-03   1.428e+00   0.000e+00   1.428e+00
39  -3.96370e-03   1.426e+00   0.000e+00   1.426e+00
40  -4.15882e-03   1.424e+00   0.000e+00   1.424e+00

Total time (s) = 53.33
---------------------------------------------------------
```
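The step size `tau = 1/L` used above relies on `L`, the largest eigenvalue of `Op^H Op`. The FFT is orthonormal and the restriction simply discards coefficients, so `Op^H Op` is an orthogonal projector and `L` equals 1. This matches the `tau` of about 1 printed in the logs. The NumPy power iteration below is a hypothetical stand-in for the `eigs` call. It also assumes orthonormal FFT scaling and a made-up image size, and it confirms the value numerically.

```python
import numpy as np

rng = np.random.default_rng(0)
ny, nx = 64, 64  # hypothetical image size
iava = np.sort(rng.permutation(ny * nx)[: int(np.round(0.6 * ny * nx))])


def normal_op(x):
    # Op^H Op x = F^H R^T R F x: keep the sampled Fourier coefficients, zero the rest
    X = np.fft.fft2(x.reshape(ny, nx), norm="ortho").ravel()
    Xmasked = np.zeros_like(X)
    Xmasked[iava] = X[iava]
    return np.fft.ifft2(Xmasked.reshape(ny, nx), norm="ortho").ravel()


def power_iteration(op, n, niter=50):
    # estimate the largest eigenvalue of a Hermitian positive semi-definite operator
    v = rng.standard_normal(n) + 1j * rng.standard_normal(n)
    v /= np.linalg.norm(v)
    for _ in range(niter):
        w = op(v)
        v = w / np.linalg.norm(w)
    return np.real(np.vdot(v, op(v)))  # Rayleigh quotient


L = power_iteration(normal_op, ny * nx)
tau = 1.0 / L
# L is 1 up to round-off, because Op^H Op is an orthogonal projector
print(L, tau)
```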

Finally, let’s compare how the reconstruction error evolves over the iterations for the two PnP variants:

```python
plt.figure(figsize=(12, 3))
plt.plot(errhistpg, 'k', lw=2, label='PG')