API Guide

Decimated WT

class pytorch_wavelets.DWTForward(J=1, wave='db1', mode='zero')

Bases: torch.nn.modules.module.Module

Performs a 2D DWT forward decomposition of an image.

Parameters:
  • J (int) – Number of levels of decomposition
  • wave (str or pywt.Wavelet) – Which wavelet to use. Can be a string to pass to the pywt.Wavelet constructor, a pywt.Wavelet object, or a 2-tuple of array-like objects giving the analysis lowpass and highpass filters.
  • mode (str) – The padding scheme: ‘zero’, ‘symmetric’, ‘reflect’ or ‘periodization’.
  • separable (bool) – Whether to do the filtering separably (the naive implementation can be faster on a GPU).
forward(x)

Forward pass of the DWT.

Parameters: x (tensor) – Input of shape \((N, C_{in}, H_{in}, W_{in})\)
Returns:
(yl, yh)
tuple of lowpass (yl) and bandpass (yh) coefficients. yh is a list of length J whose first entry holds the finest-scale coefficients. yl has shape \((N, C_{in}, H_{in}', W_{in}')\) and yh has shape \(list(N, C_{in}, 3, H_{in}'', W_{in}'')\). The extra dimension of size 3 in each yh entry iterates over the LH, HL and HH coefficients.

Note

\(H_{in}', W_{in}', H_{in}'', W_{in}''\) denote the correctly downsampled shapes of the DWT pyramid.
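The sizes H', W', H'' and W'' can be worked out ahead of time. The sketch below assumes the per-level output length follows PyWavelets' rule (pywt.dwt_coeff_len): ceil(n / 2) for 'periodization' and floor((n + L - 1) / 2) for the other modes, where L is the filter length. dwt_coeff_len and dwt_pyramid_sizes are hypothetical helpers, not part of pytorch_wavelets:

```python
def dwt_coeff_len(n, filt_len, mode):
    """Output length of one DWT level along one axis (assumed PyWavelets rule)."""
    if mode == 'periodization':
        return (n + 1) // 2              # ceil(n / 2)
    return (n + filt_len - 1) // 2       # floor((n + L - 1) / 2)


def dwt_pyramid_sizes(h, w, J, filt_len, mode):
    """Hypothetical helper: (H'', W'') for each yh[j]; the last entry is also (H', W') of yl."""
    sizes = []
    for _ in range(J):
        h = dwt_coeff_len(h, filt_len, mode)
        w = dwt_coeff_len(w, filt_len, mode)
        sizes.append((h, w))
    return sizes


# db3 has 6 filter taps.
print(dwt_pyramid_sizes(64, 64, 3, 6, 'zero'))           # [(34, 34), (19, 19), (12, 12)]
print(dwt_pyramid_sizes(64, 64, 3, 6, 'periodization'))  # [(32, 32), (16, 16), (8, 8)]
```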

class pytorch_wavelets.DWTInverse(wave='db1', mode='zero')

Bases: torch.nn.modules.module.Module

Performs a 2D inverse DWT, reconstructing an image from its coefficients.

Parameters:
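  • wave (str or pywt.Wavelet) – Which wavelet to use. Should match the wavelet given to DWTForward.
  • mode (str) – The padding scheme: ‘zero’, ‘symmetric’, ‘reflect’ or ‘periodization’. Should match the mode given to DWTForward.
  • C – Deprecated; will be removed in a future release.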
forward(coeffs)
Parameters: coeffs (yl, yh) – tuple of lowpass and bandpass coefficients, where yl is a lowpass tensor of shape \((N, C_{in}, H_{in}', W_{in}')\) and yh is a list of bandpass tensors of shape \(list(N, C_{in}, 3, H_{in}'', W_{in}'')\), i.e. the same format that DWTForward returns.
Returns: Reconstructed input of shape \((N, C_{in}, H_{in}, W_{in})\)

Note

\(H_{in}', W_{in}', H_{in}'', W_{in}''\) denote the correctly downsampled shapes of the DWT pyramid.

Note

Any of the highpass scales in yh can be None; such scales are treated as all zeros (though not in an efficient way).

Dual Tree Complex WT

class pytorch_wavelets.DTCWTForward(biort='near_sym_a', qshift='qshift_a', J=3, skip_hps=False, include_scale=False, o_dim=2, ri_dim=-1, mode='symmetric')

Bases: torch.nn.modules.module.Module

Performs a 2D DTCWT forward decomposition of an image.

Parameters:
  • biort (str) – One of ‘antonini’, ‘legall’, ‘near_sym_a’, ‘near_sym_b’. Specifies the first-level biorthogonal wavelet filters. Can also be a 2-tuple giving the lowpass and highpass filters directly.
  • qshift (str) – One of ‘qshift_06’, ‘qshift_a’, ‘qshift_b’, ‘qshift_c’, ‘qshift_d’. Specifies the quarter-shift filters used from the second level onwards. Can also be a 4-tuple giving the low tree a, low tree b, high tree a and high tree b filters directly.
  • J (int) – Number of levels of decomposition
  • skip_hps (bool or list of bool) – List of bools of length J specifying whether to skip calculating the bandpass outputs at each scale; skip_hps[0] is for the first scale. Can also be a single bool, which is then applied to all scales.
  • include_scale (bool or list of bool) – If true, also return the lowpass outputs at every scale. Can also be a list of length J specifying which lowpasses to return, e.g. with [False, True, True] the forward call returns the second and third lowpass outputs but discards the lowpass from the first level.
  • o_dim (int) – Which dimension to put the orientations in.
  • ri_dim (int) – Which dimension to put the real and imaginary parts in.
forward(x)

Forward Dual Tree Complex Wavelet Transform

Parameters: x (tensor) – Input to transform. Should be of shape \((N, C_{in}, H_{in}, W_{in})\).
Returns:
(yl, yh)
tuple of lowpass (yl) and bandpass (yh) coefficients. If include_scale was true, yl is a list of lowpass coefficients; otherwise it is just the final lowpass coefficient, of shape \((N, C_{in}, H_{in}', W_{in}')\). yh is a list of the complex bandpass coefficients of shape \(list(N, C_{in}, 6, H_{in}'', W_{in}'', 2)\), or a similar shape depending on o_dim and ri_dim.

Note

\(H_{in}', W_{in}', H_{in}'', W_{in}''\) are the shapes of a DTCWT pyramid.
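The DTCWT sizes differ from the decimated case: the level 1 lowpass is not downsampled, and each bandpass output is half the size of the lowpass produced at the same level. dtcwt_pyramid_sizes below is a hypothetical helper for one spatial axis; its padding rules (an odd input padded by one sample at level 1, and a lowpass not divisible by 4 padded by two samples at later levels) are an assumption about the implementation, not a documented guarantee:

```python
def dtcwt_pyramid_sizes(n, J):
    """Hypothetical helper: sizes along one axis of a J-level 2D DTCWT.

    Returns (size of the final lowpass yl, [size of yh[j] for j in range(J)]).
    Padding rules are assumed, not taken from the documentation.
    """
    if n % 2:
        n += 1                 # assumed: level 1 pads an odd length by one sample
    low = n                    # level 1 lowpass is not decimated
    highs = [n // 2]           # level 1 bandpass is half the lowpass size
    for _ in range(1, J):
        if low % 4:
            low += 2           # assumed: levels 2+ pad the lowpass to a multiple of 4
        low //= 2              # levels 2+ decimate the lowpass by 2
        highs.append(low // 2) # bandpass is half the new lowpass size
    return low, highs


print(dtcwt_pyramid_sizes(64, 3))  # (16, [32, 16, 8])
```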

class pytorch_wavelets.DTCWTInverse(biort='near_sym_a', qshift='qshift_a', o_dim=2, ri_dim=-1, mode='symmetric')

Bases: torch.nn.modules.module.Module

Performs a 2D DTCWT inverse reconstruction of an image.

Parameters:
  • biort (str) – One of ‘antonini’, ‘legall’, ‘near_sym_a’, ‘near_sym_b’. Specifies the first-level biorthogonal wavelet filters. Can also be a 2-tuple giving the lowpass and highpass filters directly.
  • qshift (str) – One of ‘qshift_06’, ‘qshift_a’, ‘qshift_b’, ‘qshift_c’, ‘qshift_d’. Specifies the quarter-shift filters used from the second level onwards. Can also be a 4-tuple giving the low tree a, low tree b, high tree a and high tree b filters directly.
  • J (int) – Number of levels of decomposition.
  • o_dim (int) – Which dimension the orientations are in.
  • ri_dim (int) – Which dimension the real and imaginary parts are in.
forward(coeffs)
Parameters: coeffs (yl, yh) – tuple of lowpass and bandpass coefficients, where yl is a tensor of shape \((N, C_{in}, H_{in}', W_{in}')\) and yh is a list of the complex bandpass coefficients of shape \(list(N, C_{in}, 6, H_{in}'', W_{in}'', 2)\), or a similar shape depending on o_dim and ri_dim.
Returns: Reconstructed output

Note

Can accept None or an empty tensor (torch.tensor([])) for the lowpass or bandpass inputs. In these cases, a tensor of zeros replaces that input.

Note

\(H_{in}', W_{in}', H_{in}'', W_{in}''\) are the shapes of a DTCWT pyramid.

Note

If include_scale was true for the forward pass, you should provide only the final lowpass output here, as normal for an inverse wavelet transform.