PyTorch RNN API

Events describe changes of luminance in the scene, so running neural networks for pattern recognition on them is much less effective without some form of memory. This is why we tend to use Recurrent Neural Networks to deal with event-based video sequences.

In this tutorial, we explain one of the fundamental model architectures used in our Metavision ML API: the ConvRNN.

import os
import glob
import numpy as np
import torch

from metavision_ml.core import temporal_modules as tm, modules as m

A typical Sequence

In our API, an event-based (EB) video sequence is represented by a 5d tensor of shape (T, B, C, H, W):

  • T: number of time bins

  • B: batch size

  • C: number of channels

  • H: height of the input

  • W: width of the input

In PyTorch, 2D operations only work on 4-dimensional tensors, i.e. the last 4 dimensions of our sequence tensor. To ease the transformation between 5d and 4d tensors, we use the function time_to_batch to “flatten” the 5d tensor into a 4d one, and the function batch_to_time to unflatten the 4d tensor back to 5d. These functions can be imported from metavision_ml.core.temporal_modules.

t, b, c, h, w = 3, 4, 5, 64, 64  # random 5d input
x = torch.randn(t,b,c,h,w)

print("Input shape: ", x.shape)

flatten_x, _ = tm.time_to_batch(x)

print("Flattened shape: ", flatten_x.shape)

unflatten_back_x = tm.batch_to_time(flatten_x, b)

print("Back to original shape: ", unflatten_back_x.shape)
Input shape:  torch.Size([3, 4, 5, 64, 64])
Flattened shape:  torch.Size([12, 5, 64, 64])
Back to original shape:  torch.Size([3, 4, 5, 64, 64])
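
Under the hood, these helpers boil down to reshaping. Below is a minimal sketch of roughly equivalent operations (an illustration only, not the actual implementation; we assume time_to_batch returns the flattened tensor together with the batch size, as used above):

def time_to_batch_sketch(x):
    # (T, B, C, H, W) -> (T*B, C, H, W); also return B so the operation can be undone
    t, b, c, h, w = x.shape
    return x.reshape(t * b, c, h, w), b

def batch_to_time_sketch(x, batch_size):
    # (T*B, C, H, W) -> (T, B, C, H, W)
    tb, c, h, w = x.shape
    return x.reshape(tb // batch_size, batch_size, c, h, w)

print(batch_to_time_sketch(*time_to_batch_sketch(x)).shape)  # torch.Size([3, 4, 5, 64, 64])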

We also provide a wrapper around torch nn modules, called seq_wise, which performs all of the aforementioned transformations for you:

cout = 16
layer = m.ConvLayer(c, cout)

y = tm.seq_wise(layer)(x)

print("Output shape: ", y.shape)
Output shape:  torch.Size([3, 4, 16, 64, 64])

In practice, this wrapper simply flattens the 5d input tensor to 4d before passing it to the 2d operator, then converts the result back to 5d.
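
Conceptually, the wrapper is roughly equivalent to chaining the two helpers above around the operator. A minimal sketch, assuming seq_wise does nothing more than that:

def seq_wise_sketch(op):
    def wrapped(x):
        x4d, batch_size = tm.time_to_batch(x)     # (T, B, C, H, W) -> (T*B, C, H, W)
        y4d = op(x4d)                              # apply the 2d operator
        return tm.batch_to_time(y4d, batch_size)   # back to (T, B, C', H', W')
    return wrapped

print(seq_wise_sketch(layer)(x).shape)  # torch.Size([3, 4, 16, 64, 64])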

ConvRNN

All this is great if we just care about processing time steps in parallel in isolation, but what if we want to propagate information sequentially? This is where the ConvRNN comes to the rescue.

The ConvRNN class can also be imported from metavision_ml.core.temporal_modules.

rnn_layer = tm.ConvRNN(c, cout)

# let's test it on the input x
y = rnn_layer(x)

print('output: ', y.shape)
output:  torch.Size([3, 4, 16, 64, 64])

ConvRNN layers

The ConvRNN class is composed of two parts:

  • a 2d input-to-hidden layer performed in parallel (using the “flattening” operations described above)

  • a for-loop of 2d operations for hidden-to-hidden transformations.

Here, the input-to-hidden layer outputs 4 times the number of output channels, because the recurrent part uses an LSTM unit, which is built from 4 components:

  • input gate

  • cell candidate

  • forget gate

  • output gate

See https://en.wikipedia.org/wiki/Long_short-term_memory for complete information.

Let’s take a look at the ConvRNN layers:

with torch.no_grad():
    x1 = rnn_layer.conv_x2h(x)
    print('first part: ', x1.shape)
    x2 = rnn_layer.timepool(x1)
    print('second part: ', x2.shape)
first part:  torch.Size([3, 4, 64, 64, 64])
second part:  torch.Size([3, 4, 16, 64, 64])
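
To see how the 4 * cout channels produced by conv_x2h turn into cout hidden channels, here is a hedged sketch of a standard convolutional LSTM step on a single time bin. This illustrates the general mechanism only, not necessarily the exact timepool implementation, and the gate ordering is an assumption:

import torch.nn as nn

conv_h2h = nn.Conv2d(cout, 4 * cout, kernel_size=3, padding=1)  # hidden-to-hidden convolution
h_prev = torch.zeros(b, cout, h, w)  # previous hidden state
c_prev = torch.zeros(b, cout, h, w)  # previous cell state

with torch.no_grad():
    gates = x1[0] + conv_h2h(h_prev)    # pre-activations for the first time bin: (B, 4*cout, H, W)
    i, g, f, o = gates.chunk(4, dim=1)  # assumed ordering: input, cell, forget, output
    c_new = torch.sigmoid(f) * c_prev + torch.sigmoid(i) * torch.tanh(g)
    h_new = torch.sigmoid(o) * torch.tanh(c_new)

print(h_new.shape)  # torch.Size([4, 16, 64, 64])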

The RNN Cell

The for-loop is a bit hidden in the previous example, so let’s take a closer look at the RNN component.

import torch.nn as nn

x2h = nn.Conv2d(c, cout, kernel_size=3, stride=1, padding=1)     # input-to-hidden convolution
h2h = nn.Conv2d(cout, cout, kernel_size=3, stride=1, padding=1)  # hidden-to-hidden convolution

x1 = tm.seq_wise(x2h)(x)

ht = torch.zeros((b,cout,h,w), dtype=torch.float32) # we initialize a hidden state ourselves

all_timesteps = []
for xt in x1.unbind(0):  # loop over the time dimension
    ht = torch.sigmoid(xt + h2h(ht)) # basic rnn cell
    all_timesteps.append(ht[None])

print('final output: ', ht.shape)

# Our dummy loss
final_output = torch.cat(all_timesteps)
loss = final_output.sum()

# We can backpropagate through everything using PyTorch's autograd
loss.backward()

print('Beginning of gradient of hidden-to-hidden connection: ', h2h.weight.grad.view(-1)[:10], ' ...')
final output:  torch.Size([4, 16, 64, 64])
Beginning of gradient of hidden-to-hidden connection:  tensor([3066.7192, 3116.1343, 3067.7227, 3124.3948, 3176.4548, 3119.4619,
        3070.6348, 3118.1299, 3064.5928, 3604.1628])  ...

Masking RNN’s hidden state

In the example above, notice that we had to initialize the hidden state ourselves. With the ConvRNN implementation from temporal_modules, however, the hidden state is handled internally and can be reset whenever new recordings arrive.

You can find a small illustration below:

with torch.no_grad():
    y = rnn_layer(x)

mask = torch.rand(b) > 0.7  # True keeps the hidden state; imagine ~70% of the batch has just been replaced with new videos and must be reset

print('mask: ', mask)
print('hidden state first pixel before masking: ', rnn_layer.timepool.prev_h[:,0,0,0])
rnn_layer.reset(mask[:,None,None,None])
print('hidden state first pixel after masking: ', rnn_layer.timepool.prev_h[:,0,0,0])
mask:  tensor([ True, False, False, False])
hidden state first pixel before masking:  tensor([0.1131, 0.1069, 0.0398, 0.1078])
hidden state first pixel after masking:  tensor([0.1131, 0.0000, 0.0000, 0.0000])
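
Intuitively, resetting with a mask amounts to multiplying the stored hidden state by it: entries where the mask is True are kept, the others are zeroed. A minimal sketch of that idea (an illustration, not necessarily how ConvRNN.reset is implemented):

h_state = torch.randn(b, cout, h, w)        # some hidden state
keep = mask[:, None, None, None].float()    # broadcast (B,) -> (B, 1, 1, 1)
h_state = h_state * keep                    # samples with mask == False are zeroed
print(h_state[:, 0, 0, 0])                  # only samples with mask == True keep their value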

That’s it, you can now train a ConvRNN!

A few closing remarks:

  • ConvRNN is very memory intensive, so consider using float16 tensors (see the sketch after this list)

  • You can circumvent some of the memory cost by using PyTorch’s gradient checkpointing (not yet implemented)
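
To illustrate the float16 remark, here is a minimal hedged sketch that simply casts a ConvRNN and its input to half precision on the GPU. This assumes the forward pass runs once parameters and inputs are cast; in a real training setup you would typically rely on mixed-precision tooling such as torch.cuda.amp:

if torch.cuda.is_available():
    rnn_fp16 = tm.ConvRNN(c, cout).cuda().half()  # half-precision parameters
    x_fp16 = x.cuda().half()                      # half-precision input
    y_fp16 = rnn_fp16(x_fp16)
    print(y_fp16.dtype)  # torch.float16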