Denoising

This tutorial considers the classical problem of denoising images affected by either random noise or salt-and-pepper noise, using proximal algorithms.

The overall cost function to minimize is written in the following form:

\[\argmin_\mathbf{u} \frac{1}{2}\|\mathbf{u}-\mathbf{f}\|_2^2 + \sigma J(\mathbf{u})\]

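where \(\mathbf{f}\) is the noisy image and \(\sigma\) weights the regularization. For salt-and-pepper (outlier-like) noise, the squared L2 norm in the data term can be replaced by an L1 norm, \(\|\mathbf{u}-\mathbf{f}\|_1\).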

For the random-noise example we compare the following choices of regularization, while the salt-and-pepper example uses isotropic TV only:

  • L2 on Gradient \(J(\mathbf{u}) = \|\nabla \mathbf{u}\|_2^2\)

  • Anisotropic TV \(J(\mathbf{u}) = \|\nabla \mathbf{u}\|_1\)

  • Isotropic TV \(J(\mathbf{u}) = \|\nabla \mathbf{u}\|_{2,1}\)
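To make the three functionals concrete, the minimal NumPy sketch below evaluates them on a 2 x 2 image. It assumes forward differences whose last row/column is set to zero; this boundary handling is an illustrative choice and is not guaranteed to match pylops.Gradient exactly. The helper names forward_gradient and regularizers are hypothetical.

```python
import numpy as np

def forward_gradient(u):
    """Forward differences along rows (axis 0) and columns (axis 1).

    The last row/column of each difference is left at zero (assumed boundary).
    """
    gy = np.zeros_like(u, dtype=float)
    gx = np.zeros_like(u, dtype=float)
    gy[:-1, :] = u[1:, :] - u[:-1, :]
    gx[:, :-1] = u[:, 1:] - u[:, :-1]
    return gy, gx

def regularizers(u):
    gy, gx = forward_gradient(u)
    l2_grad = np.sum(gy ** 2 + gx ** 2)              # ||grad u||_2^2
    tv_aniso = np.sum(np.abs(gy) + np.abs(gx))       # ||grad u||_1
    tv_iso = np.sum(np.sqrt(gy ** 2 + gx ** 2))      # ||grad u||_{2,1}
    return l2_grad, tv_aniso, tv_iso

u = np.array([[0., 1.], [1., 2.]])
l2_grad, tv_aniso, tv_iso = regularizers(u)
# l2_grad = 4, tv_aniso = 4, tv_iso = 2 + sqrt(2)
print(l2_grad, tv_aniso, tv_iso)
```

At the top-left pixel both gradient components are nonzero, so the isotropic norm (sqrt(2)) is smaller than the anisotropic one (2); this is why the two TV variants penalize diagonal edges differently.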

import numpy as np
import matplotlib.pyplot as plt
import pylops
from scipy import datasets

import pyproximal

plt.close('all')

Let’s start by loading a sample image and adding some uniformly distributed random noise.

# Load image (scipy.datasets requires the optional pooch package)
img = datasets.ascent()
img = img / np.max(img)
ny, nx = img.shape

# Add noise, drawn uniformly in [-sigman * max|img|, sigman * max|img|]
sigman = .2
n = sigman * np.max(np.abs(img.ravel())) * np.random.uniform(-1, 1, img.shape)
noise_img = img + n
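Since the image is normalized to a maximum of 1, the noise added above is uniform on [-a, a] with a = sigman = 0.2, so its standard deviation is a / sqrt(3), roughly 0.115. The short sketch below checks this formula on a large synthetic sample; the seed and sample size are arbitrary choices.

```python
import numpy as np

sigman = 0.2
a = sigman * 1.0                 # max |img| is 1 after normalization
rng = np.random.default_rng(0)
n = a * rng.uniform(-1, 1, size=10**6)

print(a / np.sqrt(3))   # theoretical standard deviation of U(-a, a)
print(n.std())          # empirical estimate from the sample
```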

We now define a pylops.Gradient operator, which we use for all regularizers, together with an upper bound L on the maximum eigenvalue of Gop^H Gop that sets the step sizes of the solvers.

# Gradient operator
sampling = 1.
Gop = pylops.Gradient(dims=(ny, nx), sampling=sampling, edge=False,
                      kind='forward', dtype='float64')
L = 8. / sampling ** 2  # upper bound on maxeig(Gop^H Gop)
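The bound 8 / sampling**2 can be checked on a small grid. The sketch below builds an explicit gradient matrix from 1D forward differences (last row of each difference matrix left at zero, which we assume mirrors edge=False but have not verified against pylops) for an image raveled in row-major order, and computes the largest eigenvalue of G^T G. The helpers diff_matrix and gradient_matrix are hypothetical.

```python
import numpy as np

def diff_matrix(n, sampling=1.0):
    # n x n forward difference; the last row stays zero (assumed boundary)
    D = np.zeros((n, n))
    i = np.arange(n - 1)
    D[i, i] = -1.0 / sampling
    D[i, i + 1] = 1.0 / sampling
    return D

def gradient_matrix(ny, nx, sampling=1.0):
    # Stack d/dy and d/dx for an image raveled in row-major (C) order
    Gy = np.kron(diff_matrix(ny, sampling), np.eye(nx))
    Gx = np.kron(np.eye(ny), diff_matrix(nx, sampling))
    return np.vstack([Gy, Gx])

G = gradient_matrix(20, 20)
lam_max = np.linalg.eigvalsh(G.T @ G).max()
print(lam_max)  # just below 8, approaching 8 as the grid grows
```

Each direction contributes at most 4 / sampling**2, which is where the factor 8 comes from; halving sampling multiplies the eigenvalues by 4.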

We then consider the first regularization (L2 norm on the gradient). We expect a smooth image in which the noise is suppressed, but the sharp edges of the original image are lost.

# L2 data term
l2 = pyproximal.L2(b=noise_img.ravel())

# L2 regularization
sigma = 2.
thik = pyproximal.L2(sigma=sigma)

# Solve (convergence requires 0 < mu <= tau / maxeig(Gop^H Gop))
tau = 1.
mu = tau / L

iml2 = pyproximal.optimization.primal.LinearizedADMM(l2, thik,
                                                     Gop, tau=tau,
                                                     mu=mu,
                                                     x0=np.zeros_like(img.ravel()),
                                                     niter=100)[0]
iml2 = iml2.reshape(img.shape)
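As a sanity check of the L2-on-gradient case, the sketch below runs the standard linearized ADMM iteration for min_x f(x) + g(Ax) on a small 1D signal and compares the result with the closed-form solution of (I + sigma D^T D) x = b. It assumes, as we understand pyproximal.L2 to define it, that the regularizer is sigma/2 ||D x||_2^2; the test signal and the helper linearized_admm are hypothetical, and the update order is the textbook one rather than a copy of pyproximal's code.

```python
import numpy as np

def linearized_admm(prox_f, prox_g, A, x0, tau, mu, niter):
    # Linearized ADMM for min_x f(x) + g(Ax); needs 0 < mu <= tau / ||A||^2
    x = x0.copy()
    z = A @ x
    u = np.zeros_like(z)
    for _ in range(niter):
        x = prox_f(x - (mu / tau) * A.T @ (A @ x - z + u), mu)
        z = prox_g(A @ x + u, tau)
        u = u + A @ x - z
    return x

n = 50
rng = np.random.default_rng(0)
b = np.sin(np.linspace(0, 3 * np.pi, n)) + 0.2 * rng.standard_normal(n)
D = np.diff(np.eye(n), axis=0)    # (n-1) x n forward differences, ||D||^2 < 4

sigma = 2.0
prox_f = lambda v, t: (v + t * b) / (1 + t)    # f(x) = 1/2 ||x - b||_2^2
prox_g = lambda v, t: v / (1 + t * sigma)      # g(y) = sigma/2 ||y||_2^2
tau = 1.0
mu = tau / 4.0

x = linearized_admm(prox_f, prox_g, D, np.zeros(n), tau, mu, niter=5000)
x_exact = np.linalg.solve(np.eye(n) + sigma * D.T @ D, b)
print(np.max(np.abs(x - x_exact)))
```

Because both terms are quadratic here, the iterates can be compared against an exact linear solve; with TV the same loop applies, only prox_g changes.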

Let’s now try TV regularization, in both its anisotropic and isotropic forms.

# L2 data term
l2 = pyproximal.L2(b=noise_img.ravel())

# Anisotropic TV
sigma = .1
l1 = pyproximal.L1(sigma=sigma)

# Solve
tau = 1.
mu = tau / L

iml1 = pyproximal.optimization.primal.LinearizedADMM(l2, l1, Gop, tau=tau,
                                                     mu=mu, x0=np.zeros_like(img.ravel()),
                                                     niter=100)[0]
iml1 = iml1.reshape(img.shape)


# Isotropic TV with Proximal Gradient
sigma = .1
tv = pyproximal.TV(dims=img.shape, sigma=sigma)

# Solve
tau = 1 / L

imtv = pyproximal.optimization.primal.ProximalGradient(l2, tv, tau=tau, x0=np.zeros_like(img.ravel()),
                                                       niter=100)
imtv = imtv.reshape(img.shape)

# Isotropic TV with Primal Dual
sigma = .1
l1iso = pyproximal.L21(ndim=2, sigma=sigma)

# Solve
tau = 1 / np.sqrt(L)
mu = 1. / (tau*L)

iml12 = pyproximal.optimization.primaldual.PrimalDual(l2, l1iso, Gop,
                                                      tau=tau, mu=mu, theta=1.,
                                                      x0=np.zeros_like(img.ravel()),
                                                      niter=100)
iml12 = iml12.reshape(img.shape)

fig, axs = plt.subplots(1, 5, figsize=(14, 4))
axs[0].imshow(img, cmap='gray', vmin=0, vmax=1)
axs[0].set_title('Original')
axs[0].axis('off')
axs[0].axis('tight')
axs[1].imshow(noise_img, cmap='gray', vmin=0, vmax=1)
axs[1].set_title('Noisy')
axs[1].axis('off')
axs[1].axis('tight')
axs[2].imshow(iml1, cmap='gray', vmin=0, vmax=1)
axs[2].set_title('TVaniso')
axs[2].axis('off')
axs[2].axis('tight')
axs[3].imshow(imtv, cmap='gray', vmin=0, vmax=1)
axs[3].set_title('TViso (with ProxGrad)')
axs[3].axis('off')
axs[3].axis('tight')
axs[4].imshow(iml12, cmap='gray', vmin=0, vmax=1)
axs[4].set_title('TViso (with PD)')
axs[4].axis('off')
axs[4].axis('tight')
plt.tight_layout()
Figure: Original, Noisy, TVaniso, TViso (with ProxGrad), TViso (with PD)
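Anisotropic and isotropic TV differ in how the proximal operator treats the two gradient components at each pixel: the L1 norm soft-thresholds each component separately, while the L2,1 norm shrinks the length of the per-pixel gradient vector and keeps its direction. The sketch below applies both operators to a single pixel with gradient (3, 4) and threshold tau * sigma = 1. The 2 x npixels layout (one column per pixel) is an assumption for illustration, not necessarily the internal layout of pyproximal.L21.

```python
import numpy as np

def prox_l1(v, thresh):
    # Elementwise soft thresholding: prox of thresh * ||.||_1 (anisotropic TV)
    return np.sign(v) * np.maximum(np.abs(v) - thresh, 0.0)

def prox_l21(v, thresh):
    # Group soft thresholding of each column of v (shape 2 x npixels):
    # prox of thresh * sum_i ||v[:, i]||_2 (isotropic TV)
    norms = np.sqrt(np.sum(v ** 2, axis=0))
    scale = np.maximum(1.0 - thresh / np.maximum(norms, 1e-12), 0.0)
    return v * scale

g = np.array([[3.0], [4.0]])     # one pixel, gradient components (3, 4)
print(prox_l1(g, 1.0).ravel())   # -> [2. 3.]   direction changes
print(prox_l21(g, 1.0).ravel())  # -> [2.4 3.2] direction preserved
```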

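Finally, we consider an example where the original image is corrupted by salt-and-pepper noise: a fraction noiseperc of the pixels is set to the maximum value of the image (salt) and another, independently drawn fraction is set to its minimum value (pepper).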

# Add salt and pepper noise
noiseperc = .1

isalt = np.random.permutation(np.arange(ny*nx))[:int(noiseperc*ny*nx)]
ipepper = np.random.permutation(np.arange(ny*nx))[:int(noiseperc*ny*nx)]
noise_img = img.copy().ravel()
noise_img[isalt] = img.max()
noise_img[ipepper] = img.min()
noise_img = noise_img.reshape(ny, nx)
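With noiseperc = 0.1 one might expect 20% of the pixels to be corrupted. Because isalt and ipepper come from two independent permutations, however, some indices are drawn twice (those pixels end up as pepper), so the expected corrupted fraction is 2p - p² = 0.19 for p = noiseperc. The sketch below reproduces the same construction with a hypothetical helper salt_and_pepper_mask on a 10^6-pixel image.

```python
import numpy as np

def salt_and_pepper_mask(npix, noiseperc, rng):
    # Same construction as above: two independent random index draws
    k = int(noiseperc * npix)
    isalt = rng.permutation(npix)[:k]
    ipepper = rng.permutation(npix)[:k]
    corrupted = np.zeros(npix, dtype=bool)
    corrupted[isalt] = True
    corrupted[ipepper] = True
    return corrupted, np.intersect1d(isalt, ipepper).size

rng = np.random.default_rng(0)
corrupted, noverlap = salt_and_pepper_mask(10**6, 0.1, rng)
print(corrupted.mean(), noverlap)  # fraction close to 0.19
```

If exactly 2p of the pixels should be corrupted, one alternative is to draw a single permutation and split its first 2k indices into salt and pepper halves.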

Here we compare the L2 and L1 norms for the data term.

# L2 data term
l2 = pyproximal.L2(b=noise_img.ravel())

# L1 regularization (isotropic TV)
sigma = .2
l1iso = pyproximal.L21(ndim=2, sigma=sigma)

# Solve
tau = .1
mu = 1. / (tau*L)

iml12_l2 = pyproximal.optimization.primaldual.PrimalDual(l2, l1iso, Gop,
                                                         tau=tau, mu=mu, theta=1.,
                                                         x0=np.zeros_like(noise_img).ravel(),
                                                         niter=100, show=True)
iml12_l2 = iml12_l2.reshape(img.shape)


# L1 data term
l1 = pyproximal.L1(g=noise_img.ravel())

# L1 regularization (isotropic TV)
sigma = .7
l1iso = pyproximal.L21(ndim=2, sigma=sigma)

# Solve
tau = 1.
mu = 1. / (tau*L)

iml12_l1 = pyproximal.optimization.primaldual.PrimalDual(l1, l1iso, Gop,
                                                         tau=tau, mu=mu, theta=1.,
                                                         x0=np.zeros_like(noise_img).ravel(),
                                                         niter=100, show=True)
iml12_l1 = iml12_l1.reshape(img.shape)

fig, axs = plt.subplots(2, 2, figsize=(14, 14))
axs[0][0].imshow(img, cmap='gray', vmin=0, vmax=1)
axs[0][0].set_title('Original')
axs[0][0].axis('off')
axs[0][0].axis('tight')
axs[0][1].imshow(noise_img, cmap='gray', vmin=0, vmax=1)
axs[0][1].set_title('Noisy')
axs[0][1].axis('off')
axs[0][1].axis('tight')
axs[1][0].imshow(iml12_l2, cmap='gray', vmin=0, vmax=1)
axs[1][0].set_title('L2data + TViso')
axs[1][0].axis('off')
axs[1][0].axis('tight')
axs[1][1].imshow(iml12_l1, cmap='gray', vmin=0, vmax=1)
axs[1][1].set_title('L1data + TViso')
axs[1][1].axis('off')
axs[1][1].axis('tight')
plt.tight_layout()
Figure: Original, Noisy, L2data + TViso, L1data + TViso
Primal-dual: min_x f(Ax) + x^T z + g(x)
---------------------------------------------------------
Proximal operator (f): <class 'pyproximal.proximal.L2.L2'>
Proximal operator (g): <class 'pyproximal.proximal.L21.L21'>
Linear operator (A): <class 'pylops.basicoperators.gradient.Gradient'>
Additional vector (z): None
tau = 0.1               mu = 1.25
theta = 1.00            niter = 100

   Itn       x[0]          f           g          z^x       J = f + g + z^x
     1   2.95900e-02   2.328e+04   1.474e+03   0.000e+00       2.476e+04
     2   4.97650e-02   2.021e+04   1.654e+03   0.000e+00       2.186e+04
     3   6.61798e-02   1.779e+04   1.606e+03   0.000e+00       1.939e+04
     4   8.11699e-02   1.579e+04   1.556e+03   0.000e+00       1.735e+04
     5   9.56866e-02   1.414e+04   1.539e+03   0.000e+00       1.568e+04
     6   1.10187e-01   1.276e+04   1.542e+03   0.000e+00       1.430e+04
     7   1.24702e-01   1.161e+04   1.556e+03   0.000e+00       1.316e+04
     8   1.38884e-01   1.065e+04   1.578e+03   0.000e+00       1.222e+04
     9   1.52242e-01   9.842e+03   1.603e+03   0.000e+00       1.145e+04
    10   1.64381e-01   9.172e+03   1.630e+03   0.000e+00       1.080e+04
    11   1.75147e-01   8.613e+03   1.656e+03   0.000e+00       1.027e+04
    21   2.50045e-01   6.262e+03   1.826e+03   0.000e+00       8.088e+03
    31   2.95632e-01   5.856e+03   1.892e+03   0.000e+00       7.748e+03
    41   3.05298e-01   5.774e+03   1.916e+03   0.000e+00       7.690e+03
    51   3.10402e-01   5.754e+03   1.925e+03   0.000e+00       7.678e+03
    61   3.12550e-01   5.748e+03   1.927e+03   0.000e+00       7.675e+03
    71   3.12918e-01   5.745e+03   1.928e+03   0.000e+00       7.673e+03
    81   3.13311e-01   5.744e+03   1.928e+03   0.000e+00       7.672e+03
    91   3.13517e-01   5.744e+03   1.928e+03   0.000e+00       7.672e+03
    92   3.13527e-01   5.744e+03   1.928e+03   0.000e+00       7.672e+03
    93   3.13536e-01   5.744e+03   1.928e+03   0.000e+00       7.672e+03
    94   3.13544e-01   5.744e+03   1.928e+03   0.000e+00       7.672e+03
    95   3.13551e-01   5.744e+03   1.928e+03   0.000e+00       7.672e+03
    96   3.13557e-01   5.744e+03   1.928e+03   0.000e+00       7.672e+03
    97   3.13562e-01   5.744e+03   1.927e+03   0.000e+00       7.672e+03
    98   3.13566e-01   5.744e+03   1.927e+03   0.000e+00       7.671e+03
    99   3.13570e-01   5.744e+03   1.927e+03   0.000e+00       7.671e+03
   100   3.13573e-01   5.744e+03   1.927e+03   0.000e+00       7.671e+03

Total time (s) = 2.73
---------------------------------------------------------

Primal-dual: min_x f(Ax) + x^T z + g(x)
---------------------------------------------------------
Proximal operator (f): <class 'pyproximal.proximal.L1.L1'>
Proximal operator (g): <class 'pyproximal.proximal.L21.L21'>
Linear operator (A): <class 'pylops.basicoperators.gradient.Gradient'>
Additional vector (z): None
tau = 1.0               mu = 0.125
theta = 1.00            niter = 100

   Itn       x[0]          f           g          z^x       J = f + g + z^x
     1   3.25490e-01   9.646e+04   5.675e+04   0.000e+00       1.532e+05
     2   3.25490e-01   9.646e+04   5.675e+04   0.000e+00       1.532e+05
     3   3.25490e-01   9.461e+04   5.122e+04   0.000e+00       1.458e+05
     4   3.25490e-01   9.028e+04   3.647e+04   0.000e+00       1.267e+05
     5   3.25490e-01   8.691e+04   2.683e+04   0.000e+00       1.137e+05
     6   3.25490e-01   8.564e+04   2.375e+04   0.000e+00       1.094e+05
     7   3.25490e-01   8.555e+04   2.110e+04   0.000e+00       1.067e+05
     8   3.25490e-01   8.663e+04   1.727e+04   0.000e+00       1.039e+05
     9   3.25490e-01   8.783e+04   1.470e+04   0.000e+00       1.025e+05
    10   3.25490e-01   8.878e+04   1.324e+04   0.000e+00       1.020e+05
    11   3.25490e-01   8.933e+04   1.210e+04   0.000e+00       1.014e+05
    21   3.25490e-01   8.939e+04   7.485e+03   0.000e+00       9.687e+04
    31   3.25490e-01   8.950e+04   6.782e+03   0.000e+00       9.628e+04
    41   3.25490e-01   8.956e+04   6.542e+03   0.000e+00       9.610e+04
    51   3.25490e-01   8.959e+04   6.410e+03   0.000e+00       9.600e+04
    61   3.25490e-01   8.961e+04   6.319e+03   0.000e+00       9.593e+04
    71   3.25490e-01   8.962e+04   6.244e+03   0.000e+00       9.587e+04
    81   3.25490e-01   8.963e+04   6.190e+03   0.000e+00       9.582e+04
    91   3.25490e-01   8.963e+04   6.147e+03   0.000e+00       9.578e+04
    92   3.25490e-01   8.963e+04   6.143e+03   0.000e+00       9.577e+04
    93   3.25490e-01   8.963e+04   6.139e+03   0.000e+00       9.577e+04
    94   3.25490e-01   8.963e+04   6.134e+03   0.000e+00       9.576e+04
    95   3.25490e-01   8.963e+04   6.130e+03   0.000e+00       9.576e+04
    96   3.25490e-01   8.963e+04   6.126e+03   0.000e+00       9.576e+04
    97   3.25490e-01   8.963e+04   6.122e+03   0.000e+00       9.575e+04
    98   3.25490e-01   8.963e+04   6.117e+03   0.000e+00       9.575e+04
    99   3.25490e-01   8.963e+04   6.113e+03   0.000e+00       9.574e+04
   100   3.25490e-01   8.963e+04   6.110e+03   0.000e+00       9.574e+04

Total time (s) = 2.16
---------------------------------------------------------
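The difference between the two data terms can be reproduced on a 1D toy problem. The sketch below implements the standard Chambolle-Pock primal-dual iteration for min_x f(x) + lam * ||D x||_1, with D a 1D forward difference (in 1D the isotropic and anisotropic TV coincide). We assume, without having checked it line by line, that PrimalDual with theta=1 follows the same scheme. The step signal, the two outliers and the lam values (0.2 and 0.7, borrowed from the sigma values above) are illustrative choices.

```python
import numpy as np

def soft(v, t):
    # Soft thresholding: proximal operator of t * ||.||_1
    return np.sign(v) * np.maximum(np.abs(v) - t, 0.0)

def primal_dual(prox_f, K, lam, x0, tau, mu, theta=1.0, niter=10000):
    # Chambolle-Pock for min_x f(x) + lam * ||K x||_1. The prox of the
    # conjugate of lam * ||.||_1 is a projection (clip) onto [-lam, lam].
    x = x0.copy()
    xbar = x.copy()
    y = np.zeros(K.shape[0])
    for _ in range(niter):
        y = np.clip(y + mu * K @ xbar, -lam, lam)
        xnew = prox_f(x - tau * K.T @ y, tau)
        xbar = xnew + theta * (xnew - x)
        x = xnew
    return x

n = 40
clean = np.r_[np.zeros(n // 2), np.ones(n // 2)]
b = clean.copy()
b[8], b[30] = 1.0, 0.0            # one "salt" and one "pepper" outlier
D = np.diff(np.eye(n), axis=0)    # 1D forward differences, ||D||^2 < 4
tau = mu = 0.49                   # tau * mu * ||D||^2 < 1

prox_l2 = lambda v, t: (v + t * b) / (1 + t)    # f(x) = 1/2 ||x - b||_2^2
prox_l1 = lambda v, t: b + soft(v - b, t)       # f(x) = ||x - b||_1

u_l2 = primal_dual(prox_l2, D, 0.2, b, tau, mu)
u_l1 = primal_dual(prox_l1, D, 0.7, b, tau, mu)
print(u_l2[[8, 30]], u_l1[[8, 30]])
```

In this toy case we expect the L1 data term to remove both outliers while keeping the step, whereas the L2 data term only shrinks them (to roughly 0.6 and 0.4), mirroring the image results above.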

Total running time of the script: (0 minutes 55.593 seconds)
