PyTorch RNN API
Events describe changes of luminance in the scene, so running neural networks for pattern recognition on them is much less effective without some form of memory. This is why we tend to use Recurrent Neural Networks to deal with video sequences.
In this tutorial, we will explain one of the fundamental model architectures used in our Metavision API: the ConvRNN.
import os
import glob
import numpy as np
import torch
from metavision_ml.core import temporal_modules as tm, modules as m
A typical sequence
In our API, an event-based (EB) video sequence is represented by a 5D tensor of shape (T, B, C, H, W):
T: number of time bins
B: batch size
C: number of channels
H: height of the input
W: width of the input
In PyTorch, 2D operations usually only accept 4D tensors. To ease the conversion between 5D and 4D tensors, we use the function time_to_batch to "flatten" the 5D tensor to 4D, and the function batch_to_time to unflatten the 4D tensor back to 5D. These functions can be imported from metavision_ml.core.temporal_modules.
t,b,c,h,w = 3,4,5,64,64 #random 5-d input
x = torch.randn(t,b,c,h,w)
print("Input shape: ", x.shape)
flatten_x, _ = tm.time_to_batch(x)
print("Flattened shape: ", flatten_x.shape)
unflatten_back_x = tm.batch_to_time(flatten_x, b)
print("Back to original shape: ", unflatten_back_x.shape)
Input shape: torch.Size([3, 4, 5, 64, 64])
Flattened shape: torch.Size([12, 5, 64, 64])
Back to original shape: torch.Size([3, 4, 5, 64, 64])
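Judging from the shapes above, the flattening simply folds the time dimension into the batch dimension. As a minimal sketch (an illustration, not the library implementation), the round trip could be written with plain reshapes:
def time_to_batch_sketch(x):
    # fold T into the batch dimension: (T, B, C, H, W) -> (T*B, C, H, W)
    t, b, c, h, w = x.shape
    return x.reshape(t * b, c, h, w), b

def batch_to_time_sketch(x, batch_size):
    # restore the time dimension: (T*B, C, H, W) -> (T, B, C, H, W)
    tb, c, h, w = x.shape
    return x.reshape(tb // batch_size, batch_size, c, h, w)

flat, batch_size = time_to_batch_sketch(x)
assert torch.equal(batch_to_time_sketch(flat, batch_size), x)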
We also provide a wrapper around torch.nn modules, called seq_wise, so that all the aforementioned transformations are performed inside the wrapper:
cout = 16
layer = m.ConvLayer(c, cout)
y = tm.seq_wise(layer)(x)
print("Output shape: ", y.shape)
Output shape: torch.Size([3, 4, 16, 64, 64])
In practice, this wrapper simply flattens the 5D input tensor to 4D before passing it to the 2D operator, then converts the result back to 5D.
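Conceptually, seq_wise could be built on top of the two functions above. The following is a minimal sketch under that assumption, not the actual library code:
def seq_wise_sketch(module):
    def wrapped(x5d):
        batch_size = x5d.shape[1]
        x4d, _ = tm.time_to_batch(x5d)             # (T*B, C, H, W)
        y4d = module(x4d)                          # apply the 2D operator
        return tm.batch_to_time(y4d, batch_size)   # back to (T, B, C', H', W')
    return wrapped

y_sketch = seq_wise_sketch(layer)(x)
print(y_sketch.shape)  # torch.Size([3, 4, 16, 64, 64])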
ConvRNN
All this is great if we only care about processing time steps independently and in parallel, but what if we want to propagate information sequentially? This is where the ConvRNN comes to the rescue.
The ConvRNN class can also be imported from metavision_ml.core.temporal_modules.
rnn_layer = tm.ConvRNN(c, cout)
# let's test it on the input x
y = rnn_layer(x)
print('output: ', y.shape)
output: torch.Size([3, 4, 16, 64, 64])
ConvRNN layers
The ConvRNN class is composed of two parts:
a 2D input-to-hidden layer applied in parallel (using the "flattening" operations described above)
a for-loop of 2D operations for the hidden-to-hidden transformations.
Here the input-to-hidden layer outputs 4 times the number of output channels, because the RNN part uses an LSTM unit, which contains 4 components:
input
cell
forget
output
See https://en.wikipedia.org/wiki/Long_short-term_memory for complete information.
Let's take a look at the ConvRNN layers:
with torch.no_grad():
    x1 = rnn_layer.conv_x2h(x)
    print('first part: ', x1.shape)
    x2 = rnn_layer.timepool(x1)
    print('second part: ', x2.shape)
first part: torch.Size([3, 4, 64, 64, 64])
second part: torch.Size([3, 4, 16, 64, 64])
The RNN Cell
The for-loop is a bit hidden in the previous example, so let's take a closer look at the RNN component:
import torch.nn as nn
x2h = nn.Conv2d(c, cout, kernel_size=3, stride=1, padding=1)
h2h = nn.Conv2d(cout, cout, kernel_size=3, stride=1, padding=1)
x1 = tm.seq_wise(x2h)(x)
ht = torch.zeros((b,cout,h,w), dtype=torch.float32) # we initialize a hidden state ourselves
all_timesteps = []
for t, xt in enumerate(x1.unbind(0)):
    ht = torch.sigmoid(xt + h2h(ht))  # basic rnn cell
    all_timesteps.append(ht[None])
print('final output: ', ht.shape)
# Our dummy loss
final_output = torch.cat(all_timesteps)
loss = final_output.sum()
# We can backpropagate through everything using PyTorch's autograd
loss.backward()
print('Beginning of gradient of hidden-to-hidden connection: ', h2h.weight.grad.view(-1)[:10], ' ...')
final output: torch.Size([4, 16, 64, 64])
Beginning of gradient of hidden-to-hidden connection: tensor([3066.7192, 3116.1343, 3067.7227, 3124.3948, 3176.4548, 3119.4619,
3070.6348, 3118.1299, 3064.5928, 3604.1628]) ...
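The basic cell above uses a single sigmoid; in the ConvRNN, the hidden-to-hidden loop performs an LSTM update instead, operating on the 4 x cout channels produced by conv_x2h. The following is a minimal ConvLSTM-style sketch of those equations (the gate ordering and the exact implementation in the library may differ):
# sketch of an LSTM cell loop over the time dimension (not the library code)
x2h_lstm = nn.Conv2d(c, 4 * cout, kernel_size=3, stride=1, padding=1)
h2h_lstm = nn.Conv2d(cout, 4 * cout, kernel_size=3, stride=1, padding=1)
xg = tm.seq_wise(x2h_lstm)(x)                   # (T, B, 4*cout, H, W)
ht = torch.zeros((b, cout, h, w))               # hidden state
ct = torch.zeros((b, cout, h, w))               # cell state
outputs = []
for xt in xg.unbind(0):
    gates = xt + h2h_lstm(ht)                   # (B, 4*cout, H, W)
    i, f, g, o = torch.chunk(gates, 4, dim=1)   # assumed gate ordering
    i, f, o = torch.sigmoid(i), torch.sigmoid(f), torch.sigmoid(o)
    ct = f * ct + i * torch.tanh(g)             # cell update
    ht = o * torch.tanh(ct)                      # new hidden state
    outputs.append(ht[None])
y_lstm = torch.cat(outputs)
print('lstm output: ', y_lstm.shape)  # torch.Size([3, 4, 16, 64, 64])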