# Source code for pytorch_wavelets.dwt.transform2d

import torch.nn as nn
import pywt
import pytorch_wavelets.dwt.lowlevel as lowlevel
import torch


class DWTForward(nn.Module):
    """ Performs a 2d DWT Forward decomposition of an image

    Args:
        J (int): Number of levels of decomposition
        wave (str or pywt.Wavelet or tuple): Which wavelet to use. Can be a
            string to pass to the pywt.Wavelet constructor, a pywt.Wavelet
            instance, or a tuple of array-like analysis filters: either
            ``(h0, h1)`` (the same lowpass/highpass pair is used for columns
            and rows) or ``(h0_col, h1_col, h0_row, h1_row)``.
        mode (str): 'zero', 'symmetric', 'reflect' or 'periodization'. The
            padding scheme

    Raises:
        ValueError: if ``wave`` is a sequence whose length is not 2 or 4.
    """
    def __init__(self, J=1, wave='db1', mode='zero'):
        super().__init__()
        if isinstance(wave, str):
            wave = pywt.Wavelet(wave)
        if isinstance(wave, pywt.Wavelet):
            h0_col, h1_col = wave.dec_lo, wave.dec_hi
            h0_row, h1_row = h0_col, h1_col
        else:
            h0_col, h1_col, h0_row, h1_row = self._unpack_filters(wave)

        # Prepare the filters
        filts = lowlevel.prep_filt_afb2d(h0_col, h1_col, h0_row, h1_row)
        self.register_buffer('h0_col', filts[0])
        self.register_buffer('h1_col', filts[1])
        self.register_buffer('h0_row', filts[2])
        self.register_buffer('h1_row', filts[3])
        self.J = J
        self.mode = mode

    @staticmethod
    def _unpack_filters(wave):
        """Split a user-supplied filter sequence into (col_lo, col_hi, row_lo, row_hi).

        A 2-sequence reuses the same pair for columns and rows; a 4-sequence
        gives them separately. Any other length is rejected here rather than
        failing later with an UnboundLocalError.
        """
        if len(wave) == 2:
            return wave[0], wave[1], wave[0], wave[1]
        if len(wave) == 4:
            return wave[0], wave[1], wave[2], wave[3]
        raise ValueError(
            "wave must be a str, a pywt.Wavelet, or a sequence of 2 or 4 "
            "filters; got a sequence of length {}".format(len(wave)))

    def forward(self, x):
        """ Forward pass of the DWT.

        Args:
            x (tensor): Input of shape :math:`(N, C_{in}, H_{in}, W_{in})`

        Returns:
            (yl, yh)
                tuple of lowpass (yl) and bandpass (yh) coefficients.
                yh is a list of length J with the first entry
                being the finest scale coefficients. yl has shape
                :math:`(N, C_{in}, H_{in}', W_{in}')` and yh has shape
                :math:`list(N, C_{in}, 3, H_{in}'', W_{in}'')`. The new
                dimension in yh iterates over the LH, HL and HH coefficients.

        Note:
            :math:`H_{in}', W_{in}', H_{in}'', W_{in}''` denote the correctly
            downsampled shapes of the DWT pyramid.
        """
        yh = []
        ll = x
        mode = lowlevel.mode_to_int(self.mode)

        # Do a multilevel transform: each level re-splits the current lowpass
        # band, so yh is filled from the finest to the coarsest scale.
        for _ in range(self.J):
            ll, high = lowlevel.AFB2D.apply(
                ll, self.h0_col, self.h1_col, self.h0_row, self.h1_row, mode)
            yh.append(high)

        return ll, yh
class DWTInverse(nn.Module):
    """ Performs a 2d DWT Inverse reconstruction of an image

    Args:
        wave (str or pywt.Wavelet or tuple): Which wavelet to use. Can be a
            string to pass to the pywt.Wavelet constructor, a pywt.Wavelet
            instance, or a tuple of array-like synthesis filters: either
            ``(g0, g1)`` (the same pair is used for columns and rows) or
            ``(g0_col, g1_col, g0_row, g1_row)``.
        mode (str): The padding scheme name, converted with
            lowlevel.mode_to_int exactly as in DWTForward.

    Raises:
        ValueError: if ``wave`` is a sequence whose length is not 2 or 4.
    """
    def __init__(self, wave='db1', mode='zero'):
        super().__init__()
        if isinstance(wave, str):
            wave = pywt.Wavelet(wave)
        if isinstance(wave, pywt.Wavelet):
            g0_col, g1_col = wave.rec_lo, wave.rec_hi
            g0_row, g1_row = g0_col, g1_col
        else:
            g0_col, g1_col, g0_row, g1_row = self._unpack_filters(wave)

        # Prepare the filters
        filts = lowlevel.prep_filt_sfb2d(g0_col, g1_col, g0_row, g1_row)
        self.register_buffer('g0_col', filts[0])
        self.register_buffer('g1_col', filts[1])
        self.register_buffer('g0_row', filts[2])
        self.register_buffer('g1_row', filts[3])
        self.mode = mode

    @staticmethod
    def _unpack_filters(wave):
        """Split a user-supplied filter sequence into (col_lo, col_hi, row_lo, row_hi).

        A 2-sequence reuses the same pair for columns and rows; a 4-sequence
        gives them separately. Any other length is rejected here rather than
        failing later with an UnboundLocalError.
        """
        if len(wave) == 2:
            return wave[0], wave[1], wave[0], wave[1]
        if len(wave) == 4:
            return wave[0], wave[1], wave[2], wave[3]
        raise ValueError(
            "wave must be a str, a pywt.Wavelet, or a sequence of 2 or 4 "
            "filters; got a sequence of length {}".format(len(wave)))

    @staticmethod
    def _zero_highpass(ll):
        """Return an all-zero bandpass tensor matching the lowpass ``ll``.

        The shape is (N, C, 3, H, W), taken from ``ll``. The dtype and device
        also come from ``ll``. Otherwise torch.zeros would use the default
        dtype and mismatch e.g. float64 or float16 inputs.
        """
        n, c = ll.shape[0], ll.shape[1]
        height, width = ll.shape[-2], ll.shape[-1]
        return torch.zeros(n, c, 3, height, width,
                           dtype=ll.dtype, device=ll.device)

    def forward(self, coeffs):
        """
        Args:
            coeffs (yl, yh): tuple of lowpass and bandpass coefficients, where:
              yl is a lowpass tensor of shape :math:`(N, C_{in}, H_{in}',
              W_{in}')` and yh is a list of bandpass tensors of shape
              :math:`list(N, C_{in}, 3, H_{in}'', W_{in}'')`. I.e. should match
              the format returned by DWTForward

        Returns:
            Reconstructed input of shape :math:`(N, C_{in}, H_{in}, W_{in})`

        Note:
            :math:`H_{in}', W_{in}', H_{in}'', W_{in}''` denote the correctly
            downsampled shapes of the DWT pyramid.

        Note:
            Can have None for any of the highpass scales and will treat the
            values as zeros (not in an efficient way though).
        """
        yl, yh = coeffs
        ll = yl
        mode = lowlevel.mode_to_int(self.mode)

        # Do a multilevel inverse transform. yh is ordered finest-first, so
        # walk it in reverse to start from the coarsest scale.
        for h in yh[::-1]:
            if h is None:
                h = self._zero_highpass(ll)

            # 'Unpad' added dimensions
            if ll.shape[-2] > h.shape[-2]:
                ll = ll[..., :-1, :]
            if ll.shape[-1] > h.shape[-1]:
                ll = ll[..., :-1]
            ll = lowlevel.SFB2D.apply(
                ll, h, self.g0_col, self.g1_col, self.g0_row, self.g1_row, mode)
        return ll
class SWTForward(nn.Module):
    """ Performs a 2d Stationary wavelet transform (or undecimated wavelet
    transform) of an image

    Args:
        J (int): Number of levels of decomposition
        wave (str or pywt.Wavelet or tuple): Which wavelet to use. Can be a
            string to pass to the pywt.Wavelet constructor, a pywt.Wavelet
            instance, or a tuple of array-like analysis filters: either
            ``(h0, h1)`` (the same pair is used for columns and rows) or
            ``(h0_col, h1_col, h0_row, h1_row)``.
        mode (str): 'zero', 'symmetric', 'reflect' or 'periodization'. The
            padding scheme. PyWavelets uses only periodization so we use this
            as our default scheme.

    Raises:
        ValueError: if ``wave`` is a sequence whose length is not 2 or 4.
    """
    def __init__(self, J=1, wave='db1', mode='periodization'):
        super().__init__()
        if isinstance(wave, str):
            wave = pywt.Wavelet(wave)
        if isinstance(wave, pywt.Wavelet):
            h0_col, h1_col = wave.dec_lo, wave.dec_hi
            h0_row, h1_row = h0_col, h1_col
        else:
            h0_col, h1_col, h0_row, h1_row = self._unpack_filters(wave)

        # Prepare the filters
        filts = lowlevel.prep_filt_afb2d(h0_col, h1_col, h0_row, h1_row)
        self.register_buffer('h0_col', filts[0])
        self.register_buffer('h1_col', filts[1])
        self.register_buffer('h0_row', filts[2])
        self.register_buffer('h1_row', filts[3])
        self.J = J
        self.mode = mode

    @staticmethod
    def _unpack_filters(wave):
        """Split a user-supplied filter sequence into (col_lo, col_hi, row_lo, row_hi).

        A 2-sequence reuses the same pair for columns and rows; a 4-sequence
        gives them separately. Any other length is rejected here rather than
        failing later with an UnboundLocalError.
        """
        if len(wave) == 2:
            return wave[0], wave[1], wave[0], wave[1]
        if len(wave) == 4:
            return wave[0], wave[1], wave[2], wave[3]
        raise ValueError(
            "wave must be a str, a pywt.Wavelet, or a sequence of 2 or 4 "
            "filters; got a sequence of length {}".format(len(wave)))

    def forward(self, x):
        """ Forward pass of the SWT.

        Args:
            x (tensor): Input of shape :math:`(N, C_{in}, H_{in}, W_{in})`

        Returns:
            List of coefficients for each scale. Each coefficient has
            shape :math:`(N, C_{in}, 4, H_{in}, W_{in})` where the extra
            dimension stores the 4 subbands for each scale. The ordering in
            these 4 coefficients is: (A, H, V, D) or (ll, lh, hl, hh).
        """
        ll = x
        coeffs = []
        filts = (self.h0_col, self.h1_col, self.h0_row, self.h1_row)

        # Do a multilevel transform. Level j passes a dilation of 2**j to
        # afb2d_atrous instead of downsampling, and the next level continues
        # from subband 0 (the lowpass) of this level's output.
        for j in range(self.J):
            y = lowlevel.afb2d_atrous(ll, filts, self.mode, 2**j)
            coeffs.append(y)
            ll = y[:, :, 0]

        return coeffs