# Source code for pytorch_wavelets.dtcwt.transform2d

import torch
import torch.nn as nn
from numpy import ndarray
from math import sqrt
import numpy as np

from pytorch_wavelets.dtcwt.coeffs import qshift as _qshift, biort as _biort
from pytorch_wavelets.dtcwt.lowlevel import prep_filt
import pytorch_wavelets.dwt.lowlevel as lowlevel
from pytorch_wavelets.dtcwt import transform_funcs as tf


class Scatter(nn.Module):
    """ Performs a single scale scattering with DTCWTs.

    Applies one level of an undecimated 2-D filter bank, converts the three
    highpass subbands to complex pairs and returns their (smoothed)
    magnitudes alongside the lowpass output.

    Args:
        biort: Name of the biorthogonal filter set (looked up via
            ``_biort``) or a ``(h0, h1)`` tuple of filters directly.
        separable (bool): Use the separable row/column implementation of
            the analysis filter bank rather than the non-separable one.
    """
    def __init__(self, biort='near_sym_a', separable=True):
        super().__init__()
        self.biort = biort
        self.separable = separable
        # Fetch the level-1 lowpass/highpass analysis filters.
        if isinstance(biort, str):
            h0, _, h1, _ = _biort(biort)
        else:
            h0, h1 = biort
        # Zero-pad the shorter filter so both have matching length.
        # NOTE(review): assumes the length difference is exactly 2 — confirm.
        if h0.shape[0] < h1.shape[0]:
            h0 = np.pad(h0, ((1,1), (0,0)), 'constant')
        elif h1.shape[0] < h0.shape[0]:
            h1 = np.pad(h1, ((1,1), (0,0)), 'constant')
        # Prepare the filters
        if separable:
            filts = lowlevel.prep_filt_afb2d(h0, h1, h0, h1)
            self.h0_col = nn.Parameter(filts[0], requires_grad=False)
            self.h1_col = nn.Parameter(filts[1], requires_grad=False)
            self.h0_row = nn.Parameter(filts[2], requires_grad=False)
            self.h1_row = nn.Parameter(filts[3], requires_grad=False)
            self.h = (self.h0_col, self.h1_col, self.h0_row, self.h1_row)
        else:
            filts = lowlevel.prep_filt_afb2d_nonsep(h0, h1, h0, h1)
            self.h = nn.Parameter(filts, requires_grad=False)
        # Small bias inside the sqrt keeping the magnitude differentiable
        # at zero; subtracted back out afterwards.
        self.b = 1e-4

    def forward(self, x):
        # Replicate the final row/column when a spatial dim is odd, so the
        # 2x2 polyphase split below lines up.
        batch, ch, rows, cols = x.shape
        if rows % 2 == 1:
            x = torch.cat((x, x[:, :, -1:]), dim=2)
        if cols % 2 == 1:
            x = torch.cat((x, x[:, :, :, -1:]), dim=3)
        batch, ch, rows, cols = x.shape

        # Undecimated analysis filter bank -> 4 subbands stacked in dim 2.
        if self.separable:
            y = lowlevel.afb2d(x, self.h, mode='symmetric', decimate=False)
        else:
            y = lowlevel.afb2d_nonsep(x, self.h, mode='symmetric',
                                      decimate=False)
        y = y.view(batch, ch, 4, rows, cols)
        lo = y[:, :, 0]
        hi = y[:, :, 1:] / sqrt(2)
        # Combine the four polyphase components of each highpass subband
        # into two complex pairs (real r*, imaginary i*).
        ee, oo = hi[..., ::2, ::2], hi[..., 1::2, 1::2]
        eo, oe = hi[..., ::2, 1::2], hi[..., 1::2, ::2]
        r1, r2 = ee - oo, ee + oo
        i1, i2 = eo + oe, eo - oe
        # Smoothed complex magnitudes (bias keeps the gradient finite at 0).
        m1 = torch.sqrt(r1 ** 2 + i1 ** 2 + self.b) - sqrt(self.b)
        m2 = torch.sqrt(r2 ** 2 + i2 ** 2 + self.b) - sqrt(self.b)
        mags = torch.cat((m1, m2), dim=2)
        return lo, mags.view(batch, ch * 6, rows // 2, cols // 2)


class DTCWTForward(nn.Module):
    """ Performs a 2d DTCWT Forward decomposition of an image

    Args:
        biort (str): One of 'antonini', 'legall', 'near_sym_a', 'near_sym_b'.
            Specifies the first level biorthogonal wavelet filters. Can also
            give a two tuple for the low and highpass filters directly.
        qshift (str): One of 'qshift_06', 'qshift_a', 'qshift_b', 'qshift_c',
            'qshift_d'. Specifies the second level quarter shift filters. Can
            also give a 4-tuple for the low tree a, low tree b, high tree a and
            high tree b filters directly.
        J (int): Number of levels of decomposition
        skip_hps (bools): List of bools of length J which specify whether or
            not to calculate the bandpass outputs at the given scale.
            skip_hps[0] is for the first scale. Can be a single bool in which
            case that is applied to all scales.
        include_scale (bool): If true, return the bandpass outputs. Can also
            be a list of length J specifying which lowpasses to return. I.e.
            if [False, True, True], the forward call will return the second
            and third lowpass outputs, but discard the lowpass from the first
            level transform.
        downsample (bool): If true, subsample the output lowpass (to match the
            bandpass output sizes)
        o_dim (int): Which dimension to put the orientations in
        ri_dim (int): which dimension to put the real and imaginary parts
    """
    def __init__(self, biort='near_sym_a', qshift='qshift_a', J=3,
                 skip_hps=False, include_scale=False, downsample=False,
                 o_dim=2, ri_dim=-1):
        super().__init__()
        # The orientation axis and the real/imag axis cannot coincide.
        if o_dim == ri_dim:
            raise ValueError("Orientations and real/imaginary parts must be "
                             "in different dimensions.")
        self.biort = biort
        self.qshift = qshift
        self.J = J
        self.downsample = downsample
        self.o_dim = o_dim
        self.ri_dim = ri_dim
        # Level-1 biorthogonal analysis filters (h0o lowpass, h1o highpass).
        # Filters are stored as non-trainable Parameters so they follow the
        # module across devices.
        if isinstance(biort, str):
            h0o, _, h1o, _ = _biort(biort)
            self.h0o = torch.nn.Parameter(prep_filt(h0o, 1), False)
            self.h1o = torch.nn.Parameter(prep_filt(h1o, 1), False)
        else:
            self.h0o = torch.nn.Parameter(prep_filt(biort[0], 1), False)
            self.h1o = torch.nn.Parameter(prep_filt(biort[1], 1), False)
        # Level >= 2 quarter-shift analysis filters for trees a and b.
        if isinstance(qshift, str):
            h0a, h0b, _, _, h1a, h1b, _, _ = _qshift(qshift)
            self.h0a = torch.nn.Parameter(prep_filt(h0a, 1), False)
            self.h0b = torch.nn.Parameter(prep_filt(h0b, 1), False)
            self.h1a = torch.nn.Parameter(prep_filt(h1a, 1), False)
            self.h1b = torch.nn.Parameter(prep_filt(h1b, 1), False)
        else:
            self.h0a = torch.nn.Parameter(prep_filt(qshift[0], 1), False)
            self.h0b = torch.nn.Parameter(prep_filt(qshift[1], 1), False)
            self.h1a = torch.nn.Parameter(prep_filt(qshift[2], 1), False)
            self.h1b = torch.nn.Parameter(prep_filt(qshift[3], 1), False)

        # Get the function to do the DTCWT.
        # Broadcast scalar skip_hps/include_scale into per-scale lists.
        if isinstance(skip_hps, (list, tuple, ndarray)):
            self.skip_hps = skip_hps
        else:
            self.skip_hps = [skip_hps,] * self.J
        if isinstance(include_scale, (list, tuple, ndarray)):
            self.include_scale = include_scale
        else:
            self.include_scale = [include_scale,] * self.J
        # Select the autograd Function specialised for J scales; the
        # '...scale' variants also return intermediate lowpasses.
        if True in self.include_scale:
            self.dtcwt_func = getattr(tf, 'xfm{J}scale'.format(J=J))
        else:
            self.dtcwt_func = getattr(tf, 'xfm{J}'.format(J=J))

    def forward(self, x):
        """ Forward Dual Tree Complex Wavelet Transform

        Args:
            x (tensor): Input to transform. Should be of shape
                :math:`(N, C_{in}, H_{in}, W_{in})`.

        Returns:
            (yl, yh)
                tuple of lowpass (yl) and bandpass (yh) coefficients.
                If include_scale was true, yl will be a list of lowpass
                coefficients, otherwise will be just the final lowpass
                coefficient of shape :math:`(N, C_{in}, H_{in}', W_{in}')`.
                Yh will be a list of the complex bandpass coefficients of
                shape :math:`list(N, C_{in}, 6, H_{in}'', W_{in}'', 2)`, or
                similar shape depending on o_dim and ri_dim

        Note:
            :math:`H_{in}', W_{in}', H_{in}'', W_{in}''` are the shapes of a
            DTCWT pyramid.
        """
        coeffs = self.dtcwt_func.apply(
            x, self.h0o, self.h1o, self.h0a, self.h0b, self.h1a, self.h1b,
            self.skip_hps, self.include_scale, self.o_dim, self.ri_dim,
            self.downsample)
        # With include_scale, the flat output is (J lowpasses, J bandpasses);
        # otherwise it is (final lowpass, J bandpasses).
        if True in self.include_scale:
            return coeffs[:self.J], coeffs[self.J:]
        else:
            return coeffs[0], coeffs[1:]
class DTCWTInverse(nn.Module):
    """ 2d DTCWT Inverse

    Args:
        biort (str): One of 'antonini', 'legall', 'near_sym_a', 'near_sym_b'.
            Specifies the first level biorthogonal wavelet filters. Can also
            give a two tuple for the low and highpass filters directly.
        qshift (str): One of 'qshift_06', 'qshift_a', 'qshift_b', 'qshift_c',
            'qshift_d'. Specifies the second level quarter shift filters. Can
            also give a 4-tuple for the low tree a, low tree b, high tree a
            and high tree b filters directly.
        J (int): Number of levels of decomposition.
        o_dim (int): which dimension the orientations are in
        ri_dim (int): which dimension to put th real and imaginary parts in
    """
    def __init__(self, biort='near_sym_a', qshift='qshift_a', J=3, o_dim=2,
                 ri_dim=-1):
        super().__init__()
        self.biort = biort
        self.qshift = qshift
        self.o_dim = o_dim
        self.ri_dim = ri_dim
        self.J = J
        # Level-1 biorthogonal synthesis filters (g0o lowpass, g1o highpass),
        # stored as non-trainable Parameters so they follow the module across
        # devices.
        if isinstance(biort, str):
            _, g0o, _, g1o = _biort(biort)
            self.g0o = torch.nn.Parameter(prep_filt(g0o, 1), False)
            self.g1o = torch.nn.Parameter(prep_filt(g1o, 1), False)
        else:
            self.g0o = torch.nn.Parameter(prep_filt(biort[0], 1), False)
            self.g1o = torch.nn.Parameter(prep_filt(biort[1], 1), False)
        # Level >= 2 quarter-shift synthesis filters for trees a and b.
        if isinstance(qshift, str):
            _, _, g0a, g0b, _, _, g1a, g1b = _qshift(qshift)
            self.g0a = torch.nn.Parameter(prep_filt(g0a, 1), False)
            self.g0b = torch.nn.Parameter(prep_filt(g0b, 1), False)
            self.g1a = torch.nn.Parameter(prep_filt(g1a, 1), False)
            self.g1b = torch.nn.Parameter(prep_filt(g1b, 1), False)
        else:
            self.g0a = torch.nn.Parameter(prep_filt(qshift[0], 1), False)
            self.g0b = torch.nn.Parameter(prep_filt(qshift[1], 1), False)
            self.g1a = torch.nn.Parameter(prep_filt(qshift[2], 1), False)
            self.g1b = torch.nn.Parameter(prep_filt(qshift[3], 1), False)
        # Create the function to do the DTCWT: the autograd Function
        # specialised for a J-scale inverse.
        self.dtcwt_func = getattr(tf, 'ifm{J}'.format(J=J))

    def forward(self, coeffs):
        """
        Args:
            coeffs (yl, yh): tuple of lowpass and bandpass coefficients,
                where: yl is a tensor of shape
                :math:`(N, C_{in}, H_{in}', W_{in}')` and yh is a list of the
                complex bandpass coefficients of shape
                :math:`list(N, C_{in}, 6, H_{in}'', W_{in}'', 2)`, or similar
                depending on o_dim and ri_dim

        Returns:
            Reconstructed output

        Note:
            Can accept Nones or an empty tensor (torch.tensor([])) for the
            lowpass or bandpass inputs. In this cases, an array of zeros
            replaces that input.

        Note:
            :math:`H_{in}', W_{in}', H_{in}'', W_{in}''` are the shapes of a
            DTCWT pyramid.

        Note:
            If include_scale was true for the forward pass, you should
            provide only the final lowpass output here, as normal for an
            inverse wavelet transform.

        Note:
            Won't work if the forward transform lowpass was downsampled.
        """
        yl, yh = coeffs
        # Validate each provided bandpass tensor; None / empty tensors are
        # allowed placeholders handled downstream.
        for s in yh:
            if s is not None and s.shape != torch.Size([0]):
                assert s.shape[self.o_dim] == 6, "Inverse transform must " \
                    "have input with 6 orientations"
                assert len(s.shape) == 6, "Bandpass inputs must have " \
                    "6 dimensions"
                assert s.shape[self.ri_dim] == 2, "Inputs must be complex " \
                    "with real and imaginary parts in the ri dimension"
        assert len(yh) == self.J, "The input provided has more scales than J"
        return self.dtcwt_func.apply(
            yl, *yh, self.g0o, self.g1o, self.g0a, self.g0b, self.g1a,
            self.g1b, self.o_dim, self.ri_dim)