Asif Rahman

Separable temporal convolutions

Posted on 2022-01-22

Consider a multivariate time series \(x \in \mathbb{R}^{B \times D \times T}\) with \(D=3\) channels, \(T=4\) timesteps and batch size \(B=1\):

import torch
import torch.nn as nn

x = [
    [1,5,10,20],
    [100,150,200,250],
    [1000,1500,2000,2500],
]

# [batch_size=1, in_channels=3, timesteps=4]
xt = torch.tensor(x, dtype=torch.float32).unsqueeze(0)

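A separable convolution learns its own group of filters for each input channel, with no interaction across channels. In nn.Conv1d this corresponds to groups=in_channels, which the PyTorch documentation calls a depthwise convolution. Below, I define a 1D convolutional layer with kernel size 2 that learns one filter per channel (num_channels=1).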
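To make "no interaction across channels" concrete, here is a small check of my own (not part of the original setup): a grouped nn.Conv1d with groups=in_channels should give the same output as running one independent single-channel nn.Conv1d per input channel. The random input, num_channels=2 and the weight-copying loop are illustrative assumptions.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
in_channels, num_channels, kernel_size = 3, 2, 2
x = torch.randn(1, in_channels, 4)

# Grouped (depthwise) convolution: groups=in_channels gives each input channel its own filters.
grouped = nn.Conv1d(in_channels, in_channels * num_channels, kernel_size,
                    padding=kernel_size - 1, groups=in_channels)

# Rebuild the same computation as in_channels separate single-channel convolutions.
# Output channels i*num_channels .. (i+1)*num_channels - 1 read only input channel i.
outputs = []
for i in range(in_channels):
    conv_i = nn.Conv1d(1, num_channels, kernel_size, padding=kernel_size - 1)
    rows = slice(i * num_channels, (i + 1) * num_channels)
    with torch.no_grad():
        conv_i.weight.copy_(grouped.weight[rows])  # shape [num_channels, 1, kernel_size]
        conv_i.bias.copy_(grouped.bias[rows])
        outputs.append(conv_i(x[:, i:i + 1, :]))

with torch.no_grad():
    per_channel = torch.cat(outputs, dim=1)
    reference = grouped(x)

print(torch.allclose(reference, per_channel, atol=1e-6))  # True
```

The small `atol` only guards against floating-point differences between the grouped and ungrouped code paths; mathematically the two are identical.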

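separable = True
in_channels = xt.shape[1]  # D = 3
num_channels = 1  # filters learned per input channel
kernel_size = 2
stride = 1
layer_i = 0  # index of this layer in a stack of dilated layers
dilation_size = 2 ** layer_i  # dilation doubles with each layer
padding = (kernel_size - 1) * dilation_size  # nn.Conv1d pads both ends by this amount
groups = in_channels if separable else 1  # groups=in_channels -> one group per channel
out_channels = in_channels * num_channels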
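The padding rule above is easier to see with the output-length formula from the nn.Conv1d documentation. The sketch below uses a hypothetical helper `conv1d_out_len` and assumes stride 1 and a stack of layers indexed by `layer_i`. It shows that padding both ends leaves `padding` extra steps at the end of the output, which is why the printed outputs keep only the first T steps.

```python
def conv1d_out_len(length, kernel_size, stride=1, padding=0, dilation=1):
    """Output length of nn.Conv1d, per the formula in the PyTorch docs."""
    return (length + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1

T, kernel_size = 4, 2
for layer_i in range(4):
    dilation_size = 2 ** layer_i
    padding = (kernel_size - 1) * dilation_size
    out_len = conv1d_out_len(T, kernel_size, padding=padding, dilation=dilation_size)
    # With stride 1 the output is T + padding long; the last `padding` steps
    # are the ones that read zero padding on the right.
    print(f"layer {layer_i}: dilation={dilation_size}, padding={padding}, output length={out_len}")

# layer 0: dilation=1, padding=1, output length=5
# layer 1: dilation=2, padding=2, output length=6
# layer 2: dilation=4, padding=4, output length=8
# layer 3: dilation=8, padding=8, output length=12
```

After trimming the last `padding` steps, output t depends only on x[t] and x[t - dilation_size], so each layer stays causal while its receptive field grows with depth.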

For illustrative purposes, the weights are initialized to 1 and the bias to 0.

conv1 = nn.Conv1d(
    in_channels,
    out_channels,
    kernel_size,
    stride=stride,
    padding=padding,
    dilation=dilation_size,
    groups=groups,
)
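torch.nn.init.constant_(conv1.weight, 1)
torch.nn.init.constant_(conv1.bias, 0)

# nn.Conv1d pads both ends, so the output is T + padding steps long; dropping
# the last `padding` steps keeps the convolution causal with length T.
with torch.no_grad():
    out = conv1(xt)[:, :, :-padding]
print(out.long())  # cast to integers for compact display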

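With kernel_size=2 and num_channels=1, the separable convolution is simply a weighted sum of each timestep and the one before it, computed within each channel. With both weights equal to 1 and zero padding at the start, the first channel becomes [0+1, 1+5, 5+10, 10+20] = [1, 6, 15, 30].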

tensor([[[   1,    6,   15,   30],
         [ 100,  250,  350,  450],
         [1000, 2500, 3500, 4500]]])
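As a sanity check on these numbers, the weighted sum can be written out by hand in plain Python. This is a sketch for kernel_size=2 only; `causal_depthwise` is a hypothetical helper, and it assumes the trailing padded step has been dropped so each output depends only on the current and previous timestep. Because nn.Conv1d computes a cross-correlation, `weight[..., 0]` multiplies the earlier timestep.

```python
x = [
    [1, 5, 10, 20],
    [100, 150, 200, 250],
    [1000, 1500, 2000, 2500],
]
weights = [1, 1]  # [w_prev, w_curr], matching the constant initialization
bias = 0

def causal_depthwise(x, w, b):
    # y[c][t] = w[0] * x[c][t-1] + w[1] * x[c][t] + b, with x[c][-1] treated as 0.
    # Each channel is processed on its own: no values cross between channels.
    out = []
    for channel in x:
        padded = [0] + channel  # one step of left zero padding
        out.append([w[0] * padded[t] + w[1] * padded[t + 1] + b
                    for t in range(len(channel))])
    return out

print(causal_depthwise(x, weights, bias))
# [[1, 6, 15, 30], [100, 250, 350, 450], [1000, 2500, 3500, 4500]]
```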

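Increasing num_channels learns a set of independent filters for each channel. For example, num_channels=3 gives in_channels * num_channels = 9 output channels. Output channels 0-2 are computed from input channel 0, 3-5 from input channel 1 and 6-8 from input channel 2; because all weights are 1, each channel's result simply appears three times.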

tensor([[[   1,    6,   15,   30],
         [   1,    6,   15,   30],
         [   1,    6,   15,   30],
         [ 100,  250,  350,  450],
         [ 100,  250,  350,  450],
         [ 100,  250,  350,  450],
         [1000, 2500, 3500, 4500],
         [1000, 2500, 3500, 4500],
         [1000, 2500, 3500, 4500]]])
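One way to confirm which input channel feeds which output channel is to perturb a single input channel and see which outputs change. The sketch below is a check, not part of the original code; it assumes the default random initialization (rather than the constant weights used above) and a zero input.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
in_channels, num_channels, kernel_size = 3, 3, 2

# Depthwise layer with default (random) initialization.
conv = nn.Conv1d(in_channels, in_channels * num_channels, kernel_size,
                 padding=kernel_size - 1, groups=in_channels)
print(conv.weight.shape)  # torch.Size([9, 1, 2]): each filter sees one input channel

x = torch.zeros(1, in_channels, 4)
x_perturbed = x.clone()
x_perturbed[0, 1] += 1.0  # change only input channel 1

with torch.no_grad():
    # Total absolute change per output channel, shape [9].
    diff = (conv(x_perturbed) - conv(x)).abs().sum(dim=-1).squeeze(0)
changed = (diff > 0).nonzero().flatten().tolist()
print(changed)  # [3, 4, 5]
```

Only the three output channels belonging to input channel 1 react, which matches the 0-2 / 3-5 / 6-8 layout of the printed tensor.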

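When separable=False (groups=1) and num_channels=1, every output channel is a weighted sum over all input channels, so the channels are mixed. With all-ones weights the three output channels are identical, e.g. 1 + 100 + 1000 = 1101 at the first timestep: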

tensor([[[1101, 2756, 3865, 4980],
         [1101, 2756, 3865, 4980],
         [1101, 2756, 3865, 4980]]])
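The difference between the two settings also shows up in the weight shapes and parameter counts. This sketch uses a hypothetical `make_conv` helper with the same all-ones initialization and drops the trailing padded step, as in the outputs above; it compares groups=3 against groups=1 with three output channels.

```python
import torch
import torch.nn as nn

x = torch.tensor([[1, 5, 10, 20],
                  [100, 150, 200, 250],
                  [1000, 1500, 2000, 2500]], dtype=torch.float32).unsqueeze(0)

def make_conv(groups, kernel_size=2):
    # Hypothetical helper: 3 -> 3 channels, causal-style padding, all-ones weights.
    conv = nn.Conv1d(3, 3, kernel_size, padding=kernel_size - 1, groups=groups)
    nn.init.constant_(conv.weight, 1)
    nn.init.constant_(conv.bias, 0)
    return conv

separable = make_conv(groups=3)
full = make_conv(groups=1)

print(separable.weight.shape, full.weight.shape)  # torch.Size([3, 1, 2]) torch.Size([3, 3, 2])
print(sum(p.numel() for p in separable.parameters()),
      sum(p.numel() for p in full.parameters()))  # 9 21

with torch.no_grad():
    sep_out = separable(x)[:, :, :-1]  # drop the trailing padded step
    full_out = full(x)[:, :, :-1]

# With all-ones weights, each full-conv output channel is the sum of the separable outputs.
print(torch.equal(full_out[0, 0], sep_out[0].sum(dim=0)))  # True
print(full_out[0, 0].long().tolist())  # [1101, 2756, 3865, 4980]
```

The exact equality check is safe here because every intermediate value is a small integer, which float32 represents exactly.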