Train and Test Event-based Yolo Detection Model

Introduction

The goal of this tutorial is to demonstrate how to quickly adapt popular existing frame-based neural networks to event-based vision with minimal modifications. Typically, these frame-based networks take RGB images as input and make predictions whose format depends on the task. Event frames, such as event histograms, encode visual information similar to that of RGB images and can be used in their place as network input.

In this tutorial, we use yolov8 as an example to show how to train and evaluate object detection models with event histograms instead of RGB images. The main differences between the properties of event histograms and RGB images are:

  • An event histogram has 2 channels (for ON and OFF polarities) while an RGB image has 3 channels.

  • Event histograms converted from an event stream are temporally consecutive and are stored together in a single .h5 file. RGB images, by contrast, are stored as separate files and are often unrelated to each other.

The modifications to the source code mainly address these two differences.

Prepare the Source Code of Yolov8

Clone the yolov8 source code:

git clone https://github.com/ultralytics/ultralytics.git

Move into the repository and check out this commit:

cd ultralytics
git checkout eb976f5ad20d7779e82f733af4ebe592beaa89b5

We will start modifying the code from this commit.

Prepare the Events Dataset for Object Detection

  1. Make sure the Metavision SDK is properly installed.

  2. Make some recordings with an event camera (for example with Metavision Studio). The events are stored in RAW files.

  3. Convert the RAW event files into HDF5 tensor files made of event histograms using the Generate HDF5 sample. These files have the extension .h5.

  4. Generate labels for the event histograms. Each label must contain at least a timestamp, the position of the object's bounding box and the object class ID. The timestamp is used to associate the label with an event histogram. Save the labels as a NumPy structured array with the following dtype: dtype=[('ts', '<u8'), ('x', '<f4'), ('y', '<f4'), ('w', '<f4'), ('h', '<f4'), ('class_id', 'u1')]. The timestamp (ts) is in microseconds (us). The bounding box (x, y, w, h) is in pixels, without normalization, where (x, y) is the top-left corner of the box.

    For example, the label [500000, 127., 165., 117., 156., 0] corresponds to the 50th event histogram if the histograms are generated with a time interval of 10 ms (500000 us / 10000 us = 50). It describes an object of class 0 whose bounding box has its top-left corner at (x, y) = (127., 165.), a width of 117. and a height of 156.

    The labels corresponding to the event histograms in a file xxx.h5 must be saved in a file xxx_bbox.npy (i.e. the suffix .h5 is replaced by _bbox.npy), and both files must be placed in the same folder. We provide a labelling tool to facilitate the process. You can get access to it if you are a Prophesee customer by creating your Knowledge Center account.

  5. Split your event histogram (.h5) and label (_bbox.npy) files into train, val and test folders, keeping each .h5 file in the same folder as its label file.
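The label format from step 4 can be sketched in NumPy as follows. This is a minimal sketch, not part of the Metavision SDK or the labelling tool: the helper names label_path_for and histogram_index and the file name recording.h5 are hypothetical, and histogram_index assumes the convention used by the dataset code later in this tutorial, where a label with timestamp ts belongs to the histogram with 0-based index ts / delta_t - 1.

```python
import os
import tempfile

import numpy as np

# Structured dtype required for the labels (see step 4)
LABEL_DTYPE = np.dtype([('ts', '<u8'), ('x', '<f4'), ('y', '<f4'),
                        ('w', '<f4'), ('h', '<f4'), ('class_id', 'u1')])


def label_path_for(h5_path):
    """Hypothetical helper: xxx.h5 -> xxx_bbox.npy in the same folder."""
    return h5_path.rsplit('.', 1)[0] + '_bbox.npy'


def histogram_index(ts_us, delta_t_us):
    """Hypothetical helper: 0-based index of the histogram matching a label.

    Assumes ts = (index + 1) * delta_t, both in microseconds. Timestamps that
    are not exact multiples are rejected, since they would match no histogram.
    """
    ts_us, delta_t_us = int(ts_us), int(delta_t_us)
    if ts_us % delta_t_us != 0:
        raise ValueError("label timestamp is not a multiple of the histogram interval")
    return ts_us // delta_t_us - 1


# One box of class 0: top-left corner (x, y) = (127, 165), w = 117, h = 156
labels = np.array([(500000, 127., 165., 117., 156., 0)], dtype=LABEL_DTYPE)

# Save it next to a hypothetical recording.h5, here in a temporary folder
out_dir = tempfile.mkdtemp()
label_file = label_path_for(os.path.join(out_dir, 'recording.h5'))
np.save(label_file, labels)

loaded = np.load(label_file)
print(os.path.basename(label_file))              # recording_bbox.npy
print(histogram_index(loaded['ts'][0], 10_000))  # 49, i.e. the 50th histogram
```

With a 100 ms interval (delta_t_us=100_000), the same timestamp would map to index 4, the 5th histogram.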

Apply the Modifications to the Source Code

First, download the patch file here:

Download the patch file.

Then, in a terminal, navigate to the root directory of the yolov8 source code and run the following command to apply the changes:

git apply PATH_TO_PATCH_FILE/yolov8_updates_for_eb_data.patch

Explanation of the modifications:

In ultralytics/data/base.py:

  1. Add import h5py at the beginning of the file.

  2. Modify the get_img_files function of the BaseDataset:

    First, change

    im_files = sorted(x.replace('/', os.sep) for x in f if x.split('.')[-1].lower() in IMG_FORMATS)
    

    to

    im_files = sorted(x.replace('/', os.sep) for x in f if x.split('.')[-1].lower() == "h5")
    

    because our event histograms are saved in .h5 files, which do not have image extensions such as .jpg or .png.

    Then, comment out these two lines, as they no longer apply:

    # if self.fraction < 1:
    #     im_files = im_files[:round(len(im_files) * self.fraction)]
    
  3. Modify the __init__ function of the BaseDataset:

    First, change

    self.labels = self.get_labels()
    

    to

    self.labels = self.get_events_labels()
    

    because we will later add this get_events_labels function to the YOLODataset class to load the labels.

    Then, comment out this line, since we do not store individual images in .npy files:

    # self.npy_files = [Path(f).with_suffix('.npy') for f in self.im_files]
    
  4. Modify the load_image function of the BaseDataset:

    First, comment out this code, since we will not load any RGB images:

    #"""Loads 1 image from dataset index 'i', returns (im, resized hw)."""
    #im, f, fn = self.ims[i], self.im_files[i], self.npy_files[i]
    #if im is None:  # not cached in RAM
    #    if fn.exists():  # load npy
    #        im = np.load(fn)
    #    else:  # read image
    #        im = cv2.imread(f)  # BGR
    #        if im is None:
    #            raise FileNotFoundError(f'Image Not Found {f}')
    

    Then, add the following code to read the event histogram:

    file_idx, frame_idx = self.labels[i]["im_file"]
    ev_frame_identifier = self.im_files[file_idx] + "_frame_" + str(frame_idx)
    with h5py.File(self.im_files[file_idx], "r") as h5_file:
        im = h5_file["data"][frame_idx].transpose(1, 2, 0)
    
    • The channel dimension of the event histogram is moved to the last axis so that the histogram can be processed by OpenCV functions such as cv2.resize.

    • ev_frame_identifier is a string that records the location of the event histogram (file path and frame index).

    Next, change the line

    return im, (h0, w0), im.shape[:2]
    

    to

    return im, (h0, w0), im.shape[:2], ev_frame_identifier
    

    and comment out the line

    # return self.ims[i], self.im_hw0[i], self.im_hw[i]
    
  5. Modify the set_rectangle function of the BaseDataset:

    Comment out this line, as it no longer applies:

    # self.im_files = [self.im_files[i] for i in irect]
    
  6. Modify the get_image_and_label function of the BaseDataset:

    Change this line

    label['img'], label['ori_shape'], label['resized_shape'] = self.load_image(index)
    

    to

    label['img'], label['ori_shape'], label['resized_shape'], label['im_file'] = self.load_image(index)
    

    because load_image now also returns ev_frame_identifier.
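The file filtering in get_img_files and the histogram reading added to load_image can be tried in isolation with NumPy. This is a minimal sketch under stated assumptions: select_h5_files, read_event_frame and frame_identifier are hypothetical helper names, and a NumPy array of shape (num_frames, 2, H, W) stands in for h5_file["data"], which h5py indexes the same way with an integer.

```python
import numpy as np


def select_h5_files(paths):
    """Keep only .h5 files, mirroring the modified get_img_files filter."""
    return sorted(p for p in paths if p.split('.')[-1].lower() == "h5")


def read_event_frame(data, frame_idx):
    """Return histogram frame_idx in (H, W, 2) layout.

    data is indexed like h5_file["data"], of shape (num_frames, 2, H, W).
    The channel axis is moved last, as in the modified load_image.
    """
    return np.asarray(data[frame_idx]).transpose(1, 2, 0)


def frame_identifier(h5_path, frame_idx):
    """String locating a histogram, built like ev_frame_identifier."""
    return h5_path + "_frame_" + str(frame_idx)


# Stand-in for h5_file["data"]: 3 histograms of 4 x 5 pixels (H x W)
data = np.zeros((3, 2, 4, 5), dtype=np.uint8)
data[1, 0, 2, 3] = 7  # 7 ON events at row 2, column 3 of histogram 1

files = select_h5_files(["val/b.h5", "val/a.H5", "val/a_bbox.npy"])
im = read_event_frame(data, 1)
print(files)                          # ['val/a.H5', 'val/b.h5']
print(im.shape, im[2, 3])             # (4, 5, 2) [7 0]
print(frame_identifier(files[0], 1))  # val/a.H5_frame_1
```

The label files are filtered out because only the .h5 extension is accepted, which is why each _bbox.npy file can sit in the same folder as its event file.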

In ultralytics/data/dataset.py:

  1. Add import h5py at the beginning of the file.

  2. Add a function called get_events_labels to the class YOLODataset:

    def get_events_labels(self):
        self.label_files = [x.rsplit('.', 1)[0] + '_bbox.npy' for x in self.im_files]
        all_labels = []
        for file_idx in range(len(self.im_files)):
            with h5py.File(self.im_files[file_idx], "r") as h5_file:
                num_frames, _, height, width = h5_file["data"].shape
            labels = np.load(self.label_files[file_idx])
            x_normalized, y_normalized, w_normalized, h_normalized = labels['x'] / width, labels['y'] / height, labels['w'] / width, labels['h'] / height
            x_centered_normalized = x_normalized + w_normalized / 2
            y_centered_normalized = y_normalized + h_normalized / 2
            bboxes_normalized = np.stack((x_centered_normalized, y_centered_normalized, w_normalized, h_normalized), axis=1)
            classes = labels["class_id"]
            timestamps = labels['ts']
            for frame_idx in range(num_frames):
                box_idxs = np.nonzero(timestamps == (frame_idx+1) * 1e5)[0]
                if len(box_idxs) > 0:
                    all_labels.append(
                        dict(
                            im_file=[file_idx, frame_idx],
                            shape=(height, width),
                            cls=classes[box_idxs, np.newaxis],  # n, 1
                            bboxes=bboxes_normalized[box_idxs, :],  # n, 4
                            segments=[],
                            keypoints=None,
                            normalized=True,
                            bbox_format='xywh'))
        return all_labels
    
    • As stated in the previous section, the labels corresponding to the event histograms in a file xxx.h5 are saved in a file xxx_bbox.npy in the same folder. Accordingly, self.label_files holds the paths to all the label files.

    • The previous section also states that the x, y fields of a bounding box refer to its top-left corner, whereas in yolov8 they refer to the center of the box. Hence x = x + w/2 and y = y + h/2. In addition, x, y, w and h are normalized by the width and height of the event histograms, following the yolov8 convention.

    • A label is matched with an event histogram by its timestamp, i.e. its ts field: the labels assigned to histogram frame_idx are those whose ts equals (frame_idx + 1) * 1e5. The constant 1e5 is the time interval between event histograms in microseconds, i.e. 100 ms. If your event histograms are generated with a different time interval, replace 1e5 with that interval in microseconds. For example, with a 10 ms interval, use 1e4: a label with timestamp 500000 then corresponds to frame_idx 49, the 50th event histogram.

    • Each label is a dictionary holding the location of the event histogram (file index and frame index), its shape, the class IDs and the normalized bounding boxes. Frames without any matching label are skipped, so they are not used for training or validation.
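The box conversion and timestamp matching performed by get_events_labels can be checked on their own with plain NumPy. This sketch reproduces the same arithmetic under two assumptions: convert_boxes and labels_for_frame are hypothetical helper names, and the 640 x 480 resolution is an example value only; delta_t plays the role of the 1e5 constant in get_events_labels.

```python
import numpy as np

LABEL_DTYPE = np.dtype([('ts', '<u8'), ('x', '<f4'), ('y', '<f4'),
                        ('w', '<f4'), ('h', '<f4'), ('class_id', 'u1')])


def convert_boxes(labels, width, height):
    """Top-left pixel boxes -> normalized center-based (x, y, w, h)."""
    w = labels['w'] / width
    h = labels['h'] / height
    x_center = labels['x'] / width + w / 2
    y_center = labels['y'] / height + h / 2
    return np.stack((x_center, y_center, w, h), axis=1)


def labels_for_frame(labels, frame_idx, delta_t):
    """Indices of the labels with ts == (frame_idx + 1) * delta_t (delta_t in us)."""
    return np.nonzero(labels['ts'] == (frame_idx + 1) * delta_t)[0]


# Two example labels on a hypothetical 640 x 480 sensor
labels = np.array([(500000, 127., 165., 117., 156., 0),
                   (510000, 10., 20., 30., 40., 1)], dtype=LABEL_DTYPE)
boxes = convert_boxes(labels, width=640, height=480)

print(labels_for_frame(labels, frame_idx=49, delta_t=10_000))  # [0]  (10 ms histograms)
print(labels_for_frame(labels, frame_idx=4, delta_t=100_000))  # [0]  (1e5 us = 100 ms)
print(boxes[0])  # normalized (x_center, y_center, w, h) of the first box
```

With 100 ms histograms, the second label (510000 us) matches no frame at all, since it is not a multiple of the interval; such labels are silently ignored by the exact-equality test.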

In ultralytics/data/augment.py:

Comment out these augmentations in the v8_transforms function:

# Mosaic(dataset, imgsz=imgsz, p=hyp.mosaic),
# CopyPaste(p=hyp.copy_paste),

# MixUp(dataset, pre_transform=pre_transform, p=hyp.mixup),
# Albumentations(p=1.0),
# RandomHSV(hgain=hyp.hsv_h, sgain=hyp.hsv_s, vgain=hyp.hsv_v),

since they are not compatible with event histograms.

In ultralytics/utils/plotting.py:

  1. Add these lines right after the import statements:

    BG_COLOR = np.array([30, 37, 52], dtype=np.uint8)
    POS_COLOR = np.array([216, 223, 236], dtype=np.uint8)
    NEG_COLOR = np.array([64, 126, 201], dtype=np.uint8)
    

    Then, add a function called viz_histo_binarized anywhere in the file at module level:

    def viz_histo_binarized(im):
        """
        Visualize binarized histogram of events
    
        Args:
            im (np.ndarray): Array of shape (2,H,W)
    
        Returns:
            output_array (np.ndarray): Array of shape (H,W,3)
        """
        img = np.full(im.shape[-2:] + (3,), BG_COLOR, dtype=np.uint8)
        y, x = np.where(im[0] > 0)
        img[y, x, :] = POS_COLOR
        y, x = np.where(im[1] > 0)
        img[y, x, :] = NEG_COLOR
        return img
    

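    This function renders an event histogram as an RGB image: pixels with ON events are drawn in POS_COLOR, pixels with OFF events in NEG_COLOR (which takes precedence where both polarities are present), and all other pixels in BG_COLOR.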

  2. Modify the plot_images function by changing this line:

    im = im.transpose(1, 2, 0)
    

    to

    im = viz_histo_binarized(im.copy())
    

    to convert an event histogram to a 3-channel image which can be properly visualized.
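To see what viz_histo_binarized produces, the sketch below applies an equivalent boolean-mask formulation to a tiny synthetic 2 x 2 histogram. The mask formulation and the name viz_histo_binarized_masked are illustrative assumptions, not part of the patch; the color constants are the ones defined above.

```python
import numpy as np

BG_COLOR = np.array([30, 37, 52], dtype=np.uint8)
POS_COLOR = np.array([216, 223, 236], dtype=np.uint8)
NEG_COLOR = np.array([64, 126, 201], dtype=np.uint8)


def viz_histo_binarized_masked(im):
    """Same output as viz_histo_binarized, written with boolean masks.

    im: array of shape (2, H, W), channel 0 = ON events, channel 1 = OFF events.
    Returns a (H, W, 3) uint8 image.
    """
    img = np.full(im.shape[-2:] + (3,), BG_COLOR, dtype=np.uint8)
    img[im[0] > 0] = POS_COLOR  # pixels with at least one ON event
    img[im[1] > 0] = NEG_COLOR  # painted last, so OFF overrides ON where both occur
    return img


histo = np.zeros((2, 2, 2), dtype=np.uint8)
histo[0, 0, 0] = 3  # ON events only
histo[1, 0, 1] = 1  # OFF events only
histo[:, 1, 1] = 2  # both ON and OFF events
rgb = viz_histo_binarized_masked(histo)
print(rgb.shape)  # (2, 2, 3)
```

Event counts are binarized: any pixel with at least one event gets the full polarity color, regardless of how many events it received.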

In ultralytics/engine/results.py:

  1. Import viz_histo_binarized from ultralytics.utils.plotting:

    from ultralytics.utils.plotting import Annotator, colors, save_one_box, viz_histo_binarized
    
  2. Change the following line in the function plot:

       deepcopy(self.orig_img if img is None else img),
    
    to
    
       deepcopy(viz_histo_binarized(self.orig_img.transpose(2, 0, 1)) if img is None else viz_histo_binarized(img.transpose(2, 0, 1))),
    
    The reason is the same as above: converting an event histogram to a 3-channel image which can be properly visualized.
    

In ultralytics/cfg/models/v8/yolov8.yaml:

Add the following line at the top level of the file:

ch: 2

to specify that the network expects 2-channel input, since an event histogram has two channels (ON and OFF polarities).

In ultralytics/engine/predictor.py:

Change the line:

self.model.warmup(imgsz=(1 if self.model.pt or self.model.triton else self.dataset.bs, 3, *self.imgsz))

to

self.model.warmup(imgsz=(1 if self.model.pt or self.model.triton else self.dataset.bs, 2, *self.imgsz))

because the network now expects 2-channel input instead of 3-channel input.

In ultralytics/engine/validator.py:

Change the line:

model.warmup(imgsz=(1 if pt else self.args.batch, 3, imgsz, imgsz))  # warmup

to

model.warmup(imgsz=(1 if pt else self.args.batch, 2, imgsz, imgsz))  # warmup

Prepare Your Python Environment

We need a Python environment that meets the yolov8 requirements, which are listed in requirements.txt in the root directory of the source code. To install them (preferably in a virtual environment):

pip install -r requirements.txt

In addition, install the h5py package, which is needed to read the event histograms stored in .h5 files:

pip install h5py

Train a Detection Model

  1. Create a .yaml file and specify your dataset path and class names:

    path: YOUR_DATASET_PATH  # dataset root dir
    train: train  # folder of training .h5 and _bbox.npy files (relative to 'path')
    val: val  # folder of validation .h5 and _bbox.npy files (relative to 'path')
    test:  # test folder (optional)
    
    # Classes
    names:
      0: xxx
      1: yyy
      2: zzz
    
  2. Create a .py file. Training the network takes just three lines of code:

    from ultralytics import YOLO
    
    model = YOLO('yolov8n.yaml')
    results = model.train(data='YOUR_YAML_FILE.yaml', amp=False, epochs=10)
    

    This trains the network for only 10 epochs. Increase epochs for a longer training run.
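Before launching a long training run, it can help to check that every event file in the dataset folders has its label file, since get_events_labels loads xxx_bbox.npy for every xxx.h5 it finds. The sketch below is a hypothetical helper (find_missing_labels is not part of yolov8), assuming the train and val folders from the .yaml file above and files placed directly in those folders.

```python
import os
import tempfile


def find_missing_labels(dataset_root, splits=("train", "val")):
    """Hypothetical helper: list the .h5 files that lack a matching _bbox.npy.

    Only files placed directly in each split folder are checked.
    """
    missing = []
    for split in splits:
        split_dir = os.path.join(dataset_root, split)
        for name in sorted(os.listdir(split_dir)):
            if name.split('.')[-1].lower() != "h5":
                continue  # not an event file (same extension test as get_img_files)
            label_name = name.rsplit('.', 1)[0] + '_bbox.npy'
            if not os.path.exists(os.path.join(split_dir, label_name)):
                missing.append(split + "/" + name)
    return missing


# Example on a temporary dataset: one complete pair and one missing label file
root = tempfile.mkdtemp()
for split in ("train", "val"):
    os.makedirs(os.path.join(root, split))
for split, name in [("train", "rec1.h5"), ("train", "rec1_bbox.npy"), ("val", "rec2.h5")]:
    open(os.path.join(root, split, name), "wb").close()

print(find_missing_labels(root))  # ['val/rec2.h5']
```

In practice, call find_missing_labels("YOUR_DATASET_PATH") and fix any reported file before calling model.train.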

Make Predictions with the Trained Detection Model

The following example shows how to run the detection model over a sequence of event histograms to predict bounding boxes and classes. The predictions are drawn on the visualized event histograms, which are saved as a video.

from ultralytics import YOLO
import cv2
import h5py
import numpy as np

model = YOLO('YOUR_TRAINED_MODEL.pt')  # load a trained model

# Read-only access is enough; the context manager closes the file at the end
with h5py.File("YOUR_TEST_EVENTS_FILE.h5", "r") as events_file:
    num_ev_histograms, _, height, width = events_file['data'].shape

    out = cv2.VideoWriter('OUTPUT_FOLDER/output.mp4', cv2.VideoWriter_fourcc(*'mp4v'), 20.0, (width, height))
    for idx in range(num_ev_histograms):
        # (2, H, W) -> (H, W, 2), the layout used by the modified data loader
        ev_histo = np.transpose(events_file['data'][idx], (1, 2, 0))
        results = model(ev_histo)  # returns a list of Results objects, one per input
        annotated_frame = results[0].plot()  # visualized histogram with boxes and classes
        out.write(cv2.cvtColor(annotated_frame, cv2.COLOR_RGB2BGR))
    out.release()
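To compare predictions with the ground truth in xxx_bbox.npy files, it can be convenient to store the predicted boxes in the same structured format. The sketch below is one possible approach under stated assumptions: predictions_to_labels is a hypothetical helper, it expects pixel-space boxes in (x1, y1, x2, y2) format, such as results[0].boxes.xyxy.cpu().numpy() with class IDs from results[0].boxes.cls.cpu().numpy() in the loop above, and it stamps each box with ts = (idx + 1) * delta_t, following the timestamp convention of get_events_labels.

```python
import numpy as np

LABEL_DTYPE = np.dtype([('ts', '<u8'), ('x', '<f4'), ('y', '<f4'),
                        ('w', '<f4'), ('h', '<f4'), ('class_id', 'u1')])


def predictions_to_labels(xyxy, cls, frame_idx, delta_t):
    """Pixel (x1, y1, x2, y2) boxes -> structured labels with top-left x, y.

    The timestamp follows the convention ts = (frame_idx + 1) * delta_t (in us).
    """
    xyxy = np.asarray(xyxy, dtype=np.float32).reshape(-1, 4)
    out = np.zeros(len(xyxy), dtype=LABEL_DTYPE)
    out['ts'] = (frame_idx + 1) * delta_t
    out['x'] = xyxy[:, 0]
    out['y'] = xyxy[:, 1]
    out['w'] = xyxy[:, 2] - xyxy[:, 0]
    out['h'] = xyxy[:, 3] - xyxy[:, 1]
    out['class_id'] = np.asarray(cls).astype(np.uint8)
    return out


# Example: one predicted box of class 0 in histogram 49 (10 ms histograms)
pred = predictions_to_labels([[127., 165., 244., 321.]], [0.], frame_idx=49, delta_t=10_000)
print(pred['ts'][0], pred['w'][0], pred['h'][0])  # 500000 117.0 156.0
```

Concatenating the arrays returned for every frame and saving them with np.save gives a file with the same layout as the ground-truth labels. Confidence scores are not kept, since the label dtype has no field for them.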