Reusing a Torchjit Model in Python

In the detection and tracking tutorial, we have seen how we can use the Metavision ML module for detection and tracking.

The model we provide is a TorchScript (torch.jit) model, which means that you can also load and use it directly with the PyTorch library.

In this tutorial, we will learn how to use our pre-trained model with Torch and how to extract the features generated from the model.

First, we need a pre-trained TorchScript model with a JSON file of hyperparameters. Check our pre-trained models page to find out how to download the object detection TorchScript models. Move the folder red_event_cube_05_2020 to your local directory or update the path in the code that follows.

Now let’s load the required libraries and the model:

%matplotlib inline
import os
import numpy as np
from matplotlib import pyplot as plt
plt.rcParams['figure.figsize'] = [12, 8]
import torch

from metavision_ml.data import CDProcessorIterator
# Location of the TorchScript file; update it if the model is stored elsewhere.
jit_path = os.path.join(os.getcwd(), "model.ptjit")

# Load onto the CPU regardless of the device the model was exported from:
# the input tensors used below are created on the CPU and the extracted
# feature maps are converted to NumPy arrays.
model = torch.jit.load(jit_path, map_location=torch.device('cpu'))
# The model is only used for inference here, so switch layers whose behavior
# depends on the training flag (e.g. BatchNorm2d) to evaluation mode.
model = model.eval()

The loaded model is a torchjit model:

# Printing the loaded module lists its nested submodules and their original class names.
print(model)
RecursiveScriptModule(
  original_name=SSDPipeline
  (feature_extractor): RecursiveScriptModule(
    original_name=Vanilla
    (conv1): RecursiveScriptModule(
      original_name=SequenceWise
      (module): RecursiveScriptModule(
        original_name=Sequential
        (0): RecursiveScriptModule(
          original_name=ConvLayer
          (0): RecursiveScriptModule(original_name=BatchNorm2d)
          (1): RecursiveScriptModule(original_name=Conv2d)
          (2): RecursiveScriptModule(original_name=ReLU)
        )
        (1): RecursiveScriptModule(
          original_name=PreActBlock
          (conv1): RecursiveScriptModule(
            original_name=ConvLayer
            (0): RecursiveScriptModule(original_name=BatchNorm2d)
            (1): RecursiveScriptModule(original_name=Conv2d)
            (2): RecursiveScriptModule(original_name=ReLU)
          )
          (conv2): RecursiveScriptModule(
            original_name=ConvLayer
            (0): RecursiveScriptModule(original_name=BatchNorm2d)
            (1): RecursiveScriptModule(original_name=Conv2d)
            (2): RecursiveScriptModule(original_name=ReLU)
          )
          (downsample): RecursiveScriptModule(
            original_name=Sequential
            (0): RecursiveScriptModule(original_name=Conv2d)
          )
          (fc1): RecursiveScriptModule(original_name=Conv2d)
          (fc2): RecursiveScriptModule(original_name=Conv2d)
        )
        (2): RecursiveScriptModule(
          original_name=PreActBlock
          (conv1): RecursiveScriptModule(
            original_name=ConvLayer
            (0): RecursiveScriptModule(original_name=BatchNorm2d)
            (1): RecursiveScriptModule(original_name=Conv2d)
            (2): RecursiveScriptModule(original_name=ReLU)
          )
          (conv2): RecursiveScriptModule(
            original_name=ConvLayer
            (0): RecursiveScriptModule(original_name=BatchNorm2d)
            (1): RecursiveScriptModule(original_name=Conv2d)
            (2): RecursiveScriptModule(original_name=ReLU)
          )
          (downsample): RecursiveScriptModule(original_name=Sequential)
          (fc1): RecursiveScriptModule(original_name=Conv2d)
          (fc2): RecursiveScriptModule(original_name=Conv2d)
        )
        (3): RecursiveScriptModule(
          original_name=PreActBlock
          (conv1): RecursiveScriptModule(
            original_name=ConvLayer
            (0): RecursiveScriptModule(original_name=BatchNorm2d)
            (1): RecursiveScriptModule(original_name=Conv2d)
            (2): RecursiveScriptModule(original_name=ReLU)
          )
          (conv2): RecursiveScriptModule(
            original_name=ConvLayer
            (0): RecursiveScriptModule(original_name=BatchNorm2d)
            (1): RecursiveScriptModule(original_name=Conv2d)
            (2): RecursiveScriptModule(original_name=ReLU)
          )
          (downsample): RecursiveScriptModule(
            original_name=Sequential
            (0): RecursiveScriptModule(original_name=Conv2d)
          )
          (fc1): RecursiveScriptModule(original_name=Conv2d)
          (fc2): RecursiveScriptModule(original_name=Conv2d)
        )
      )
    )
    (conv2): RecursiveScriptModule(
      original_name=ModuleList
      (0): RecursiveScriptModule(
        original_name=ConvRNN
        (timepool): RecursiveScriptModule(
          original_name=ConvLSTMCell
          (conv_h2h): RecursiveScriptModule(original_name=Conv2d)
        )
        (conv_x2h): RecursiveScriptModule(
          original_name=SequenceWise
          (module): RecursiveScriptModule(
            original_name=ConvLayer
            (0): RecursiveScriptModule(original_name=BatchNorm2d)
            (1): RecursiveScriptModule(original_name=Conv2d)
            (2): RecursiveScriptModule(original_name=Identity)
          )
        )
      )
      (1): RecursiveScriptModule(
        original_name=ConvRNN
        (timepool): RecursiveScriptModule(
          original_name=ConvLSTMCell
          (conv_h2h): RecursiveScriptModule(original_name=Conv2d)
        )
        (conv_x2h): RecursiveScriptModule(
          original_name=SequenceWise
          (module): RecursiveScriptModule(
            original_name=ConvLayer
            (0): RecursiveScriptModule(original_name=BatchNorm2d)
            (1): RecursiveScriptModule(original_name=Conv2d)
            (2): RecursiveScriptModule(original_name=Identity)
          )
        )
      )
      (2): RecursiveScriptModule(
        original_name=ConvRNN
        (timepool): RecursiveScriptModule(
          original_name=ConvLSTMCell
          (conv_h2h): RecursiveScriptModule(original_name=Conv2d)
        )
        (conv_x2h): RecursiveScriptModule(
          original_name=SequenceWise
          (module): RecursiveScriptModule(
            original_name=ConvLayer
            (0): RecursiveScriptModule(original_name=BatchNorm2d)
            (1): RecursiveScriptModule(original_name=Conv2d)
            (2): RecursiveScriptModule(original_name=Identity)
          )
        )
      )
      (3): RecursiveScriptModule(
        original_name=ConvRNN
        (timepool): RecursiveScriptModule(
          original_name=ConvLSTMCell
          (conv_h2h): RecursiveScriptModule(original_name=Conv2d)
        )
        (conv_x2h): RecursiveScriptModule(
          original_name=SequenceWise
          (module): RecursiveScriptModule(
            original_name=ConvLayer
            (0): RecursiveScriptModule(original_name=BatchNorm2d)
            (1): RecursiveScriptModule(original_name=Conv2d)
            (2): RecursiveScriptModule(original_name=Identity)
          )
        )
      )
      (4): RecursiveScriptModule(
        original_name=ConvRNN
        (timepool): RecursiveScriptModule(
          original_name=ConvLSTMCell
          (conv_h2h): RecursiveScriptModule(original_name=Conv2d)
        )
        (conv_x2h): RecursiveScriptModule(
          original_name=SequenceWise
          (module): RecursiveScriptModule(
            original_name=ConvLayer
            (0): RecursiveScriptModule(original_name=BatchNorm2d)
            (1): RecursiveScriptModule(original_name=Conv2d)
            (2): RecursiveScriptModule(original_name=Identity)
          )
        )
      )
    )
  )
  (rpn): RecursiveScriptModule(
    original_name=BoxHead
    (box_head): RecursiveScriptModule(
      original_name=Sequential
      (0): RecursiveScriptModule(
        original_name=ConvLayer
        (0): RecursiveScriptModule(original_name=Identity)
        (1): RecursiveScriptModule(original_name=Conv2d)
        (2): RecursiveScriptModule(original_name=ReLU)
      )
      (1): RecursiveScriptModule(original_name=Conv2d)
    )
    (cls_head): RecursiveScriptModule(
      original_name=Sequential
      (0): RecursiveScriptModule(
        original_name=ConvLayer
        (0): RecursiveScriptModule(original_name=Identity)
        (1): RecursiveScriptModule(original_name=Conv2d)
        (2): RecursiveScriptModule(original_name=ReLU)
      )
      (1): RecursiveScriptModule(original_name=Conv2d)
    )
  )
  (anchor_generator): RecursiveScriptModule(
    original_name=Anchors
    (anchor_generators): RecursiveScriptModule(
      original_name=ModuleList
      (0): RecursiveScriptModule(original_name=AnchorLayer)
      (1): RecursiveScriptModule(original_name=AnchorLayer)
      (2): RecursiveScriptModule(original_name=AnchorLayer)
      (3): RecursiveScriptModule(original_name=AnchorLayer)
      (4): RecursiveScriptModule(original_name=AnchorLayer)
    )
  )
  (box_decoder): RecursiveScriptModule(original_name=BoxDecoder)
)

Now that we have loaded the model, we can reuse its features. They can be used for different purposes, such as training a linear model on top of them. In this tutorial, we will use them for visualization.

Our detection model expects as input a histogram of events, which we can generate with our preprocessing tools.

Here is the link to download the RAW file used in this sample: driving_sample.raw

# Path to the RAW recording used in this sample; adjust it if the file is stored elsewhere.
input_path = "driving_sample.raw"

Let’s load the data using an iterator and create a histogram.

# Duration of the slice of events accumulated into each histogram
# (presumably expressed in microseconds, i.e. 50 ms -- confirm against the Metavision ML docs).
delta_t = 50000

# CDProcessorIterator chains the events iterator with the "histo" preprocessing
# and yields tensors placed on the requested device (the CPU here).
proc_iterator = CDProcessorIterator(
    input_path,
    "histo",
    delta_t=delta_t,
    num_tbins=1,
    preprocess_kwargs={"max_incr_per_pixel": 5},
    device=torch.device('cpu'),
    height=None,
    width=None,
)

# Take the first preprocessed tensor produced by the iterator.
input_tensor = next(iter(proc_iterator))

We can now extract the feature map:

# Inference only: disabling autograd avoids recording a computation graph
# that is never back-propagated through.
with torch.no_grad():
    # input_tensor[None, ...] prepends an axis of size 1 before the call.
    feature_maps = model.feature_extractor(input_tensor[None, ...])
for feature_map in feature_maps:
    print(feature_map.shape)
# Convert every feature map to a NumPy array for plotting.
feature_maps = [feature_map.detach().numpy() for feature_map in feature_maps]
torch.Size([1, 256, 90, 160])
torch.Size([1, 256, 45, 80])
torch.Size([1, 256, 23, 40])
torch.Size([1, 256, 12, 20])
torch.Size([1, 256, 6, 10])

Our detection network produces feature maps at different resolutions. Each individual feature map corresponds to one channel in the final convolutional layer of the feature extractor network. These feature maps are the features that our network “considers” the best for the detection task. By analogy, a human looking for cars might search for headlights or wheels; our network relies on these feature maps instead.

We can now visualize some of these feature maps: negative values (features suggesting that the object is not a car) are shown in blue, and positive values (features suggesting that the object is a car) are shown in red.

def remove_outliers(array, num_std=3):
    """Zero out extreme values so they do not dominate a visualization.

    An element is treated as an outlier when its absolute value exceeds
    ``mean(|array|) + num_std * std(|array|)``, both statistics being computed
    over all elements. The input itself is never modified.

    Args:
        array: numeric NumPy array or array-like (e.g. a nested list).
        num_std: number of standard deviations above the mean magnitude
            beyond which a value is set to 0. Defaults to 3.

    Returns:
        A new NumPy array of the same shape with the outliers replaced by 0.
    """
    # asanyarray lets plain sequences through; copy() protects the caller's data.
    filtered_array = np.asanyarray(array).copy()
    if filtered_array.size == 0:
        # Nothing to filter; also avoids NumPy's "mean of empty slice" warning.
        return filtered_array
    magnitude = np.abs(filtered_array)
    threshold = magnitude.mean() + num_std * magnitude.std()
    filtered_array[magnitude > threshold] = 0
    return filtered_array

plt.rcParams['figure.figsize'] = [6, 4]

# Start with a reminder of what the neural network receives as input.
input_image = proc_iterator.show(time_bin=0)
plt.imshow(input_image)
plt.title("Neural network input histogram")
plt.show()

# First 14 channels of the first set of feature maps, for the single batch element.
first_level_channels = feature_maps[0][0, :14]
for index, feature_map in enumerate(first_level_channels):
    # Outliers are zeroed so that they do not saturate the color scale.
    plt.imshow(remove_outliers(feature_map), cmap="coolwarm")
    plt.title("feature map number {}".format(index))
    plt.show()
../../../_images/using_torchjit_model_16_0.png ../../../_images/using_torchjit_model_16_1.png ../../../_images/using_torchjit_model_16_2.png ../../../_images/using_torchjit_model_16_3.png ../../../_images/using_torchjit_model_16_4.png ../../../_images/using_torchjit_model_16_5.png ../../../_images/using_torchjit_model_16_6.png ../../../_images/using_torchjit_model_16_7.png ../../../_images/using_torchjit_model_16_8.png ../../../_images/using_torchjit_model_16_9.png ../../../_images/using_torchjit_model_16_10.png ../../../_images/using_torchjit_model_16_11.png ../../../_images/using_torchjit_model_16_12.png ../../../_images/using_torchjit_model_16_13.png ../../../_images/using_torchjit_model_16_14.png

Among the first 14 feature maps, feature map number 0 looks interesting, as it seems to be positively correlated with the car in this example. Let’s take a closer look:

from itertools import islice
plt.rcParams['figure.figsize'] = [12, 8]

# Only the first four tensors produced by the iterator are displayed.
for input_tensor in islice(proc_iterator, 4):

    # Run the feature extractor and keep one map: first output, first batch element, channel 0.
    extracted = model.feature_extractor(input_tensor[None, ...])
    single_feature_map = extracted[0][0, 0].detach()
    feature_map = remove_outliers(single_feature_map.numpy())

    # Feature map on the left, corresponding input events on the right.
    _, (ax1, ax2) = plt.subplots(1, 2)
    ax1.imshow(feature_map, cmap='coolwarm')
    ax2.imshow(proc_iterator.show(time_bin=0))
    plt.show()
../../../_images/using_torchjit_model_18_0.png ../../../_images/using_torchjit_model_18_1.png ../../../_images/using_torchjit_model_18_2.png ../../../_images/using_torchjit_model_18_3.png