DWT in Pytorch Wavelets
While pytorch_wavelets was initially built as a repo to do the dual tree wavelet transform efficiently in pytorch, I have also built a thin wrapper over PyWavelets, allowing the calculation of the 2D-DWT in pytorch on a GPU on a batch of images.
Older versions computed the DWT non-separably. As of v1.0.0 there is also code to do it separably. The old non-separable code is still there and is, surprisingly, sometimes faster. You can test which of the two is better for you by changing the separable flag in the DWT/IDWT constructor.
The DWT and IDWT now support most of the padding schemes that PyWavelets uses. In particular:
- symmetric padding
- reflection padding
- zero padding
- periodization
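To make the boundary handling concrete, here is a minimal sketch of the four extensions on a short 1D signal. It uses `numpy.pad` rather than pytorch_wavelets itself, and the pad width of 2 is an arbitrary choice for illustration. Note that `wrap` only shows the periodic extension; periodization in PyWavelets additionally keeps the number of output coefficients as small as possible, which `numpy.pad` alone does not capture.

```python
import numpy as np

x = np.array([1, 2, 3, 4])
pad = 2  # hypothetical pad width, chosen only for illustration

padded = {
    "symmetric": np.pad(x, pad, mode="symmetric"),  # 2 1 | 1 2 3 4 | 4 3
    "reflect": np.pad(x, pad, mode="reflect"),      # 3 2 | 1 2 3 4 | 3 2
    "zero": np.pad(x, pad, mode="constant"),        # 0 0 | 1 2 3 4 | 0 0
    "periodic": np.pad(x, pad, mode="wrap"),        # 3 4 | 1 2 3 4 | 1 2
}

for name, values in padded.items():
    print(f"{name:>10}: {values}")
```

Symmetric padding repeats the edge sample while reflection padding does not, which is the only difference between the first two rows.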
You can see the source here. It is pretty minimal, and it should be clear what is going on.
In particular, the DWT and IDWT classes initialize the filter banks as pytorch tensors (taking care to flip them, as pytorch uses cross-correlation, not convolution). The non-separable code path then performs a 2D convolution on the input, using strided convolution to calculate the LL, LH, HL, and HH subbands. It also takes care of padding to match the PyWavelets implementation.
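The filter flip and strided convolution described above can be sketched for a single level as follows. This is not the library's code: it is a simplified illustration that assumes the Haar wavelet (so no padding is needed for even-sized inputs), and `haar_dwt2` is a hypothetical helper name.

```python
import torch
import torch.nn.functional as F

s = 2 ** -0.5
dec_lo = torch.tensor([s, s])   # Haar low-pass analysis filter
dec_hi = torch.tensor([-s, s])  # Haar high-pass analysis filter

# conv2d computes cross-correlation, so reverse each filter to get a true convolution.
lo = dec_lo.flip(0)
hi = dec_hi.flip(0)

# Non-separable 2D kernels: outer(filter along height, filter along width).
kernels = torch.stack([
    torch.outer(lo, lo),  # LL
    torch.outer(hi, lo),  # LH (horizontal detail)
    torch.outer(lo, hi),  # HL (vertical detail)
    torch.outer(hi, hi),  # HH (diagonal detail)
]).unsqueeze(1)           # shape (4, 1, 2, 2)


def haar_dwt2(x):
    """Single-level Haar DWT of an (N, C, H, W) tensor with even H and W."""
    n, c, h, w = x.shape
    # One copy of the four kernels per channel, applied with a grouped convolution.
    weight = kernels.repeat(c, 1, 1, 1)          # (4 * C, 1, 2, 2)
    y = F.conv2d(x, weight, stride=2, groups=c)  # (N, 4 * C, H / 2, W / 2)
    y = y.reshape(n, c, 4, h // 2, w // 2)
    return y[:, :, 0], y[:, :, 1:]               # lowpass, stacked bandpass


x = torch.tensor([[1.0, 2.0], [3.0, 4.0]]).reshape(1, 1, 2, 2)
yl, yh = haar_dwt2(x)
print(yl.flatten(), yh.flatten())
```

The stride of 2 performs the filtering and the downsampling in one step, so the discarded samples are never computed.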
Differences to PyWavelets
Inputs
The pytorch_wavelets DWT expects the standard pytorch image format of NCHW, i.e., a batch of N images, with C channels, height H and width W. For a single RGB image, you would need to make it a torch tensor of size (1, 3, H, W), or for a batch of 100 grayscale images, you would need to make it a tensor of size (100, 1, H, W).
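As a rough sketch of getting data into that layout, the snippet below reshapes a single RGB image and a batch of grayscale images. The array sizes and the height x width x channel starting layout are assumptions made for illustration; your image loader may already return a different layout.

```python
import numpy as np
import torch

# Hypothetical image data: one RGB image in HWC layout, 100 grayscale images in NHW layout.
rgb = np.random.rand(48, 64, 3).astype(np.float32)
gray_batch = np.random.rand(100, 48, 64).astype(np.float32)

# HWC -> CHW, then add a leading batch dimension: (1, 3, H, W)
x_rgb = torch.from_numpy(rgb).permute(2, 0, 1).unsqueeze(0)

# (N, H, W) -> add a channel dimension: (100, 1, H, W)
x_gray = torch.from_numpy(gray_batch).unsqueeze(1)

print(x_rgb.shape)
print(x_gray.shape)
```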
Returned Coefficients
We deviate slightly from PyWavelets in the format of the returned coefficients. In particular, we return a tuple (yl, yh), where yl is the LL band and yh is a list. The first list entry yh[0] holds the scale 1 bandpass coefficients (finest resolution), and the last list entry yh[-1] holds the coarsest bandpass coefficients. Note that this is the reverse of the PyWavelets format (but fits with the dtcwt standard output). Each entry is a single stacked tensor of the LH (horizontal), HL (vertical), and HH (diagonal) coefficients for that scale (as opposed to the PyWavelets style of returning them as a tuple), with the stack along the third dimension. As the input had 4 dimensions, this output has 5 dimensions, with shape (N, C, 3, H, W), where H and W are here the height and width of the subbands at that scale, not of the input.
This is easily transformed into the PyWavelets style by unstacking the list elements in yh.
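A minimal sketch of that unstacking is shown below. `to_pywt_style` is a hypothetical helper, not part of pytorch_wavelets, and the coefficients are random stand-ins with illustrative shapes so that the snippet runs without the library. The result follows the PyWavelets ordering of coarsest scale first, but the entries are still torch tensors with their leading N and C dimensions.

```python
import torch

# Stand-in coefficients for a 3-level transform (random data, illustrative shapes).
yl = torch.randn(10, 5, 12, 12)
yh = [
    torch.randn(10, 5, 3, 34, 34),  # scale 1 (finest)
    torch.randn(10, 5, 3, 19, 19),  # scale 2
    torch.randn(10, 5, 3, 12, 12),  # scale 3 (coarsest)
]


def to_pywt_style(yl, yh):
    """Return [cA_n, (cH_n, cV_n, cD_n), ..., (cH_1, cV_1, cD_1)]."""
    coeffs = [yl]
    for bands in reversed(yh):  # PyWavelets lists the coarsest scale first
        # Split the stacked LH, HL, HH bands along the third dimension.
        coeffs.append(tuple(torch.unbind(bands, dim=2)))
    return coeffs


coeffs = to_pywt_style(yl, yh)
print(len(coeffs), coeffs[1][0].shape, coeffs[-1][0].shape)
```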
Example
import numpy as np
import torch
from pytorch_wavelets import DWTForward, DWTInverse  # (or import DWT, IDWT)

xfm = DWTForward(J=3, mode='zero', wave='db3')  # Accepts all wave types available to PyWavelets
ifm = DWTInverse(mode='zero', wave='db3')
X = torch.randn(10, 5, 64, 64)
Yl, Yh = xfm(X)
print(Yl.shape)     # torch.Size([10, 5, 12, 12])
print(Yh[0].shape)  # torch.Size([10, 5, 3, 34, 34])
print(Yh[1].shape)  # torch.Size([10, 5, 3, 19, 19])
print(Yh[2].shape)  # torch.Size([10, 5, 3, 12, 12])
Y = ifm((Yl, Yh))
# Check that the inverse transform reconstructs the input
np.testing.assert_array_almost_equal(Y.cpu().numpy(), X.cpu().numpy())
Other Notes
GPU Calculations
As you would expect, you can move the transforms to the GPU by calling xfm.cuda() or ifm.cuda(), where xfm and ifm are instances of pytorch_wavelets.DWTForward and pytorch_wavelets.DWTInverse.