Shrinkage Fields (image restoration)

From HandWiki

Shrinkage fields is a random field-based machine learning technique that aims to perform high-quality image restoration (denoising and deblurring) with low computational overhead.

Method

The restored image [math]\displaystyle{ x }[/math] is predicted from a corrupted observation [math]\displaystyle{ y }[/math] after training on a set of sample images [math]\displaystyle{ S }[/math].

A shrinkage (mapping) function [math]\displaystyle{ f_{\pi_i}(v) = \sum_{j=1}^{M} \pi_{i,j} \exp\left(-\frac{\gamma}{2}\left(v - \mu_j\right)^2\right) }[/math] is modeled directly as a linear combination of M Gaussian radial basis function kernels, where [math]\displaystyle{ \pi_{i,j} }[/math] are the weights of the i-th shrinkage function, [math]\displaystyle{ \gamma }[/math] is the shared precision parameter, and [math]\displaystyle{ \mu_j }[/math] are the equidistant kernel positions.
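The RBF expansion can be evaluated element-wise on a whole filter response at once. The following is a minimal NumPy sketch; the function name `rbf_shrinkage` and its argument layout are hypothetical, and the equidistant kernel positions are generated with `np.linspace` purely for illustration.

```python
import numpy as np

def rbf_shrinkage(v, weights, mu, gamma):
    """Evaluate f_pi(v) = sum_j pi_j * exp(-gamma/2 * (v - mu_j)^2) element-wise.

    v       : scalar or array of filter responses
    weights : shape (M,), the weights pi_{i,j} of one shrinkage function
    mu      : shape (M,), equidistant kernel positions mu_j
    gamma   : shared precision of all Gaussian kernels
    """
    v = np.asarray(v, dtype=float)
    diff = v[..., None] - mu                 # compare every value with all M centres
    return np.exp(-0.5 * gamma * diff ** 2) @ weights  # weighted sum over kernels

# Example: 5 equidistant centres on [-1, 1], only the centre kernel switched on.
mu = np.linspace(-1.0, 1.0, 5)
w = np.array([0.0, 0.0, 1.0, 0.0, 0.0])
print(rbf_shrinkage(np.array([0.0, 0.5]), w, mu, gamma=4.0))
```

Because the output keeps the shape of `v`, the same function can be applied directly to a 2D filter response such as [math]\displaystyle{ F_i x }[/math].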

Because the shrinkage function is modeled directly, each iteration of the optimization reduces to a single quadratic minimization. Its solution, called the prediction of a shrinkage field, can be computed in closed form in the Fourier domain:

[math]\displaystyle{ g_{\Theta}(x) = \mathcal{F}^{-1}\left[\frac{\mathcal{F}\left(\lambda K^{T} y + \sum_{i=1}^{N} F_i^{T} f_{\pi_i}\left(F_i x\right)\right)}{\lambda \breve{K}^{*} \circ \breve{K} + \sum_{i=1}^{N} \breve{F}_i^{*} \circ \breve{F}_i}\right] = \Omega^{-1}\eta }[/math]

where [math]\displaystyle{ \mathcal{F} }[/math] denotes the discrete Fourier transform, [math]\displaystyle{ F_i x }[/math] is the 2D convolution [math]\displaystyle{ f_i \otimes x }[/math] of the image with the point spread function (filter) [math]\displaystyle{ f_i }[/math], N is the number of filters, [math]\displaystyle{ \lambda }[/math] weights the data term, and K is the convolution with the blur kernel k (the identity for pure denoising). [math]\displaystyle{ \breve{F}_i }[/math] is the optical transfer function, defined as the discrete Fourier transform of [math]\displaystyle{ f_i }[/math], [math]\displaystyle{ \breve{F}_i^{*} }[/math] is its complex conjugate (and likewise for [math]\displaystyle{ \breve{K} }[/math]), [math]\displaystyle{ \circ }[/math] denotes element-wise multiplication, and the fraction is evaluated element-wise. Equivalently, [math]\displaystyle{ \Omega = \lambda K^{T} K + \sum_{i=1}^{N} F_i^{T} F_i }[/math] and [math]\displaystyle{ \eta = \lambda K^{T} y + \sum_{i=1}^{N} F_i^{T} f_{\pi_i}\left(F_i x\right) }[/math].
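A minimal NumPy sketch of one prediction [math]\displaystyle{ g_{\Theta}(x) }[/math] follows. It assumes circular (periodic) boundary conditions, so that every convolution is diagonalized by the 2D DFT. The names `psf2otf` (borrowed from the MATLAB function of the same purpose) and `csf_step` are hypothetical and not the reference implementation's API. The shrinkage functions are passed in as callables, and the blur kernel defaults to a 1×1 identity kernel (pure denoising, K = I).

```python
import numpy as np

def psf2otf(psf, shape):
    """Zero-pad a small filter to `shape`, circularly shift its centre to the
    origin, and return its 2D DFT (the optical transfer function)."""
    padded = np.zeros(shape)
    kh, kw = psf.shape
    padded[:kh, :kw] = psf
    padded = np.roll(padded, (-(kh // 2), -(kw // 2)), axis=(0, 1))
    return np.fft.fft2(padded)

def csf_step(x, y, filters, shrinkers, lam, blur=None):
    """One shrinkage-field prediction g_Theta(x), solved in the Fourier domain."""
    shape = y.shape
    blur = np.ones((1, 1)) if blur is None else blur     # identity blur = denoising
    K = psf2otf(blur, shape)
    numer = lam * np.conj(K) * np.fft.fft2(y)            # F(lambda K^T y)
    denom = lam * np.abs(K) ** 2                          # lambda conj(K) o K
    X = np.fft.fft2(x)
    for f, shrink in zip(filters, shrinkers):
        Fi = psf2otf(f, shape)
        response = np.real(np.fft.ifft2(Fi * X))          # F_i x (circular convolution)
        numer += np.conj(Fi) * np.fft.fft2(shrink(response))  # F(F_i^T f_pi_i(F_i x))
        denom += np.abs(Fi) ** 2                          # conj(F_i) o F_i
    # Element-wise division, then back to the image domain (drop round-off imaginary part).
    return np.real(np.fft.ifft2(numer / denom))
```

As a sanity check, with identity shrinkage (f(v) = v) and x = y the numerator equals the denominator times [math]\displaystyle{ \mathcal{F}(y) }[/math], so the step returns y unchanged. With shrinkage functions that return zero, the step acts as a low-pass filter that preserves the image mean, since difference filters vanish at the zero frequency.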

The prediction is applied repeatedly with stage-specific parameters: [math]\displaystyle{ \hat{x}_t = g_{\Theta_t}\left(\hat{x}_{t-1}\right) }[/math] for each iteration t, starting from [math]\displaystyle{ \hat{x}_0 = y }[/math]. This forms a cascade of Gaussian conditional random fields, or cascade of shrinkage fields (CSF). The parameters of each stage, [math]\displaystyle{ \Theta_t = \left\lbrace \lambda_t, \pi_{ti}, f_{ti} \right\rbrace_{i=1}^{N} }[/math], are learned by loss minimization.
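The cascade itself is a simple loop over stages. In this minimal Python sketch each stage is an arbitrary callable standing in for [math]\displaystyle{ g_{\Theta_t} }[/math] (for example, a closure over one stage's filters, shrinkage weights and [math]\displaystyle{ \lambda_t }[/math]); the names `Stage` and `run_cascade` are hypothetical.

```python
from typing import Callable, List, Sequence

import numpy as np

# A stage maps the previous estimate x_hat_{t-1} to the new estimate x_hat_t.
Stage = Callable[[np.ndarray], np.ndarray]

def run_cascade(y: np.ndarray, stages: Sequence[Stage]) -> List[np.ndarray]:
    """Return [x_hat_0, ..., x_hat_T] with x_hat_0 = y and x_hat_t = g_t(x_hat_{t-1}).

    Each stage carries its own parameters Theta_t; keeping every intermediate
    estimate is convenient when stages are trained one after another."""
    estimates = [y]
    for g_t in stages:
        estimates.append(g_t(estimates[-1]))
    return estimates
```

The final restored image is the last element of the returned list; the earlier elements are the inputs that each stage sees.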

The learning objective for stage t is [math]\displaystyle{ J\left(\Theta_t\right) = \sum_{s=1}^{S} l\left(\hat{x}_t^{(s)}; x_{gt}^{(s)}\right) }[/math], where l is a differentiable loss function and [math]\displaystyle{ \hat{x}_t^{(s)} }[/math] is the stage-t prediction for the s-th training example. The objective is minimized greedily, one stage at a time, using training data [math]\displaystyle{ \left\lbrace x_{gt}^{(s)}, y^{(s)}, k^{(s)} \right\rbrace_{s=1}^{S} }[/math] consisting of ground-truth images, corrupted observations and blur kernels.
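Greedy training can be sketched as fitting one stage at a time on the outputs of the already-fixed earlier stages. This is a minimal illustration under loud assumptions: the squared-error loss is a hypothetical choice for l, the blur kernels [math]\displaystyle{ k^{(s)} }[/math] are omitted (denoising only), and a grid search over candidate parameters is swapped in for the gradient-based optimization one would normally use with a differentiable loss. All names are hypothetical.

```python
import numpy as np

def squared_error(x_hat, x_gt):
    # Hypothetical choice for the differentiable loss l.
    return float(np.sum((x_hat - x_gt) ** 2))

def stage_objective(stage, prev_estimates, targets, loss=squared_error):
    # J(Theta_t) = sum_s l(x_hat_t^(s); x_gt^(s)), with x_hat_t^(s) = stage(x_hat_{t-1}^(s)).
    return sum(loss(stage(x_prev), x_gt) for x_prev, x_gt in zip(prev_estimates, targets))

def train_greedily(ys, x_gts, make_stage, candidates, n_stages):
    """Choose Theta_t for t = 1..n_stages, keeping earlier stages fixed.
    Grid search over `candidates` stands in for a gradient-based optimizer."""
    estimates = list(ys)                          # x_hat_0^(s) = y^(s)
    chosen = []
    for _ in range(n_stages):
        best = min(candidates,
                   key=lambda th: stage_objective(make_stage(th), estimates, x_gts))
        stage = make_stage(best)
        estimates = [stage(x) for x in estimates]  # inputs for the next stage
        chosen.append(best)
    return chosen

# Toy usage: each "stage" just scales its input by theta.
ys = [np.array([1.0, 2.0]), np.array([3.0])]
x_gts = [2.0 * y for y in ys]
print(train_greedily(ys, x_gts, lambda th: (lambda x: th * x), [0.5, 1.0, 2.0, 3.0], 2))
```

In the toy run the first stage picks the scale that maps each observation to its ground truth, after which the second stage has nothing left to correct and picks the identity scale.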

Performance

Preliminary tests by the author suggest that RTF5[1] obtains slightly better denoising performance than [math]\displaystyle{ \text{CSF}_{7\times 7}^{\left\lbrace 3,4,5 \right\rbrace} }[/math], followed by [math]\displaystyle{ \text{CSF}_{5\times 5}^{5} }[/math], [math]\displaystyle{ \text{CSF}_{7\times 7}^{2} }[/math], [math]\displaystyle{ \text{CSF}_{5\times 5}^{\left\lbrace 3,4 \right\rbrace} }[/math], and BM3D. In this notation, [math]\displaystyle{ \text{CSF}_{m\times m}^{T} }[/math] denotes a cascade of T stages using m×m filters.

The denoising speed of BM3D falls between those of [math]\displaystyle{ \text{CSF}_{5\times 5}^{4} }[/math] and [math]\displaystyle{ \text{CSF}_{7\times 7}^{4} }[/math], while RTF is an order of magnitude slower.

Advantages

  • Results are comparable to those obtained by BM3D (a reference method for state-of-the-art denoising since its introduction in 2007)
  • Low runtime compared with other high-performance methods (potentially suitable for embedded devices)
  • Parallelizable (e.g. amenable to a GPU implementation)
  • Predictability: [math]\displaystyle{ O(D \log D) }[/math] runtime, where [math]\displaystyle{ D }[/math] is the number of pixels
  • Fast training even with CPU

Implementations

  • A reference implementation has been written in MATLAB and released under the BSD 2-Clause license: shrinkage-fields


References

  1. Jancsary, Jeremy; Nowozin, Sebastian; Sharp, Toby; Rother, Carsten (10 April 2012). "Regression Tree Fields – An Efficient, Non-parametric Approach to Image Labeling Problems". IEEE Computer Society Conference on Computer Vision and Pattern Recognition (CVPR). Providence, RI, USA: IEEE Computer Society. doi:10.1109/CVPR.2012.6247950.