Train and Test Event-based Yolo Detection Model
Introduction
The goal of this tutorial is to demonstrate how to quickly leverage existing popular frame-based neural networks for event-based vision with minimal modifications. Typically, these frame-based networks take RGB images as input and make predictions in different formats depending on the task. Event frames, such as event histograms, encode visual information similar to that of RGB images and can be used as substitutes for them at the input of the networks.
In this tutorial, we use yolov8 as an example to show how to train and evaluate object detection models with event histograms instead of RGB images. The main differences between the properties of event histograms and RGB images are:
An event histogram has 2 channels (for ON and OFF polarities) while an RGB image has 3 channels.
Event histograms are often temporally continuous and stored together in a single file (the HDF5 tensor files used in this tutorial) when they are converted from an event stream. RGB images, by contrast, are stored separately and are often unrelated to each other.
The modifications to the source code mainly address these two differences.
Prepare the Source Code of Yolov8
Download the source code of yolov8 by:
git clone https://github.com/ultralytics/ultralytics.git
Check out this commit:
git checkout eb976f5ad20d7779e82f733af4ebe592beaa89b5
We will start modifying the code from this commit.
Prepare the Events Dataset for Object Detection
1. Make sure Metavision SDK is properly installed.

2. Make some recordings with an event camera (for example with Metavision Studio). The events are stored in RAW files.

3. Convert the RAW event files into HDF5 tensor files made of event histograms using the Generate HDF5 sample. Those files have the extension .h5.

4. Generate labels for each event histogram. Each label should contain at least a timestamp, the position of the bounding box of the object and the object class ID. The timestamp is used to associate the label with the event histogram. The labels should be saved in a numpy array with a customized structured data type:

dtype=[('ts', '<u8'), ('x', '<f4'), ('y', '<f4'), ('w', '<f4'), ('h', '<f4'), ('class_id', 'u1')]

The unit of the timestamp (ts) is µs. The location of the bounding box (x, y, w, h) is in pixels, without normalization. For example, if the event histograms are generated with a time interval of 10 ms, the label [500000, 127., 165., 117., 156., 0] corresponds to the 50th event histogram. It describes an object of class 0 whose bounding box has its top-left corner at x=127., y=165., a width of 117. and a height of 156. The labels corresponding to the event histograms in a xxx.h5 file should be saved in a xxx_bbox.npy file, i.e. the suffix .h5 is replaced by _bbox.npy, and the xxx.h5 and xxx_bbox.npy files should be placed in the same folder (see the sketch after this list). We provide a labelling tool to facilitate the process. You can get access to it if you are a Prophesee customer by creating your Knowledge Center account.

5. Group your event histogram and label files into train, val and test folders.
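As a concrete illustration, here is a minimal sketch of how such a label file could be written with numpy. The recording name my_recording and the label values are hypothetical, and a 10 ms histogram interval is assumed:

import numpy as np

# Structured dtype expected for the labels (timestamps in us, boxes in pixels)
label_dtype = [('ts', '<u8'), ('x', '<f4'), ('y', '<f4'),
               ('w', '<f4'), ('h', '<f4'), ('class_id', 'u1')]

# Two hypothetical boxes: one in the 50th histogram (ts = 500000 us at 10 ms
# per histogram), one in the 51st
labels = np.array([(500000, 127., 165., 117., 156., 0),
                   (510000, 40., 60., 80., 50., 1)], dtype=label_dtype)

# Saved next to the corresponding event histogram file my_recording.h5
np.save('my_recording_bbox.npy', labels)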
Apply the Modifications to the Source Code
First, download the patch file here:
Then navigate to the root directory of the yolov8 source code in a terminal and execute the following command to apply the changes:
git apply PATH_TO_PATCH_FILE/yolov8_updates_for_eb_data.patch
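Before applying, you can optionally verify that the patch applies cleanly with git's dry-run flag:

git apply --check PATH_TO_PATCH_FILE/yolov8_updates_for_eb_data.patch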
Explanations of the modifications:
In ultralytics/data/base.py:

Add import h5py at the beginning of the file.

Modify the get_img_files function of the BaseDataset. First, change

im_files = sorted(x.replace('/', os.sep) for x in f if x.split('.')[-1].lower() in IMG_FORMATS)

to

im_files = sorted(x.replace('/', os.sep) for x in f if x.split('.')[-1].lower() == "h5")

because our event histograms are saved in .h5 files and do not have endings like .jpg, .png, etc. Then, comment out these two lines, as their usage is no longer valid:

# if self.fraction < 1:
#     im_files = im_files[:round(len(im_files) * self.fraction)]
Modify the __init__ function of the BaseDataset. First, change

self.labels = self.get_labels()

to

self.labels = self.get_events_labels()

as later we will add this get_events_labels function to the YOLODataset class to load the labels. Then, comment out this line, as we do not store individual images in .npy files:

# self.npy_files = [Path(f).with_suffix('.npy') for f in self.im_files]
Modify the load_image function of the BaseDataset. First, comment out this piece of code, as we will not load any RGB images:

# """Loads 1 image from dataset index 'i', returns (im, resized hw)."""
# im, f, fn = self.ims[i], self.im_files[i], self.npy_files[i]
# if im is None:  # not cached in RAM
#     if fn.exists():  # load npy
#         im = np.load(fn)
#     else:  # read image
#         im = cv2.imread(f)  # BGR
#     if im is None:
#         raise FileNotFoundError(f'Image Not Found {f}')

Then, add this piece of code to read event histograms instead:

file_idx, frame_idx = self.labels[i]["im_file"]
ev_frame_identifier = self.im_files[file_idx] + "_frame_" + str(frame_idx)
with h5py.File(self.im_files[file_idx], "r") as h5_file:
    im = h5_file["data"][frame_idx].transpose(1, 2, 0)

We need to move the channel dimension of the event histogram to the last dimension so that it can be processed by OpenCV functions such as cv2.resize. ev_frame_identifier is a string recording the location of the event histogram (the .h5 file path plus the frame index). Next, change the line

return im, (h0, w0), im.shape[:2]

to

return im, (h0, w0), im.shape[:2], ev_frame_identifier

and comment out the line:

# return self.ims[i], self.im_hw0[i], self.im_hw[i]
Modify the set_rectangle function of the BaseDataset. Comment out this line, as it is no longer valid:

# self.im_files = [self.im_files[i] for i in irect]
Modify the get_image_and_label function of the BaseDataset. Change this line

label['img'], label['ori_shape'], label['resized_shape'] = self.load_image(index)

to

label['img'], label['ori_shape'], label['resized_shape'], label['im_file'] = self.load_image(index)

because we now also return ev_frame_identifier from the load_image function.
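To sanity-check this loading logic outside of yolov8, a standalone snippet like the following can be used. The file name my_recording.h5 is a placeholder; as above, it is assumed that the HDF5 tensor file stores the histograms in a dataset named "data" with shape (num_frames, 2, height, width):

import h5py

# Inspect one event histogram the same way the modified load_image does
with h5py.File('my_recording.h5', 'r') as h5_file:
    num_frames, channels, height, width = h5_file['data'].shape
    print(num_frames, channels, height, width)  # channels should be 2
    im = h5_file['data'][0].transpose(1, 2, 0)  # channel-last for OpenCV
    print(im.shape)  # (height, width, 2)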
In ultralytics/data/dataset.py:

Add import h5py at the beginning of the file.

Add a function called get_events_labels to the class YOLODataset:

def get_events_labels(self):
    self.label_files = [x.rsplit('.', 1)[0] + '_bbox.npy' for x in self.im_files]
    all_labels = []
    for file_idx in range(len(self.im_files)):
        with h5py.File(self.im_files[file_idx], "r") as h5_file:
            num_frames, _, height, width = h5_file["data"].shape
        labels = np.load(self.label_files[file_idx])
        x_normalized, y_normalized = labels['x'] / width, labels['y'] / height
        w_normalized, h_normalized = labels['w'] / width, labels['h'] / height
        x_centered_normalized = x_normalized + w_normalized / 2
        y_centered_normalized = y_normalized + h_normalized / 2
        bboxes_normalized = np.stack((x_centered_normalized, y_centered_normalized,
                                      w_normalized, h_normalized), axis=1)
        classes = labels["class_id"]
        timestamps = labels['ts']
        for frame_idx in range(num_frames):
            box_idxs = np.nonzero(timestamps == (frame_idx + 1) * 1e5)[0]
            if len(box_idxs) > 0:
                all_labels.append(
                    dict(
                        im_file=[file_idx, frame_idx],
                        shape=(height, width),
                        cls=classes[box_idxs, np.newaxis],  # n, 1
                        bboxes=bboxes_normalized[box_idxs, :],  # n, 4
                        segments=[],
                        keypoints=None,
                        normalized=True,
                        bbox_format='xywh'))
    return all_labels
The previous section mentioned that the labels corresponding to the event histograms in a xxx.h5 file should be saved in a xxx_bbox.npy file in the same folder, so self.label_files here stores the paths to all the label files.

It also mentioned that the x, y fields of the bounding boxes refer to the top-left corner. In yolov8, however, they refer to the center of the bounding box, so we need to set x = x + w/2 and y = y + h/2. The fields x, y, w, h also need to be normalized with respect to the width and height of the event histograms, in line with the yolov8 convention.

A label is matched with an event histogram by its timestamp, i.e. its ts field. For example, if the timestamp of a label is 500000, it corresponds to the 50th event histogram when the histograms are generated with a time interval of 10 ms. Note that the 1e5 in the code above is the time interval between histograms expressed in µs (i.e. 100 ms); change it to your own time interval in µs (e.g. 1e4 for a 10 ms interval).

Each label is constructed as a dictionary containing the location of the event histogram (file and frame indices), its shape, the class IDs and the locations of the bounding boxes in the event histogram.
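To make the conversion concrete, here is a small worked example using the label values from the dataset section. The 640x480 resolution and the 10 ms interval are assumptions for illustration:

# Label from the earlier example: ts=500000, x=127., y=165., w=117., h=156., class_id=0
width, height = 640, 480      # assumed histogram resolution
interval_us = 1e4             # assumed 10 ms histogram interval, in us

x, y, w, h = 127., 165., 117., 156.
x_center = (x + w / 2) / width    # (127 + 58.5) / 640 ~= 0.290
y_center = (y + h / 2) / height   # (165 + 78.0) / 480 ~= 0.506
w_norm, h_norm = w / width, h / height      # ~0.183, 0.325
frame_idx = int(500000 / interval_us) - 1   # ts 500000 us -> 50th histogram -> index 49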
In ultralytics/data/augment.py:

Comment out these augmentations in the v8_transforms function, since they are no longer valid for event histograms (for instance, RandomHSV and Albumentations assume 3-channel color images):

# Mosaic(dataset, imgsz=imgsz, p=hyp.mosaic),
# CopyPaste(p=hyp.copy_paste),
# MixUp(dataset, pre_transform=pre_transform, p=hyp.mixup),
# Albumentations(p=1.0),
# RandomHSV(hgain=hyp.hsv_h, sgain=hyp.hsv_s, vgain=hyp.hsv_v),
In ultralytics/utils/plotting.py:

Add these lines right after the imports of libraries:

BG_COLOR = np.array([30, 37, 52], dtype=np.uint8)
POS_COLOR = np.array([216, 223, 236], dtype=np.uint8)
NEG_COLOR = np.array([64, 126, 201], dtype=np.uint8)

and add a function called viz_histo_binarized somewhere in the file:

def viz_histo_binarized(im):
    """
    Visualize a binarized histogram of events.

    Args:
        im (np.ndarray): Array of shape (2, H, W)

    Returns:
        output_array (np.ndarray): Array of shape (H, W, 3)
    """
    img = np.full(im.shape[-2:] + (3,), BG_COLOR, dtype=np.uint8)
    y, x = np.where(im[0] > 0)
    img[y, x, :] = POS_COLOR
    y, x = np.where(im[1] > 0)
    img[y, x, :] = NEG_COLOR
    return img

This function can be utilized to visualize the event histograms as 3-channel images: pixels with ON events are drawn in one color, pixels with OFF events in another, over a uniform background.
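For instance, to dump the first histogram of a recording as an image file (file names are placeholders):

import cv2
import h5py

from ultralytics.utils.plotting import viz_histo_binarized

with h5py.File('my_recording.h5', 'r') as h5_file:
    histo = h5_file['data'][0]       # shape (2, H, W)

img = viz_histo_binarized(histo)     # shape (H, W, 3)
cv2.imwrite('histo_0.png', img)      # note: cv2 interprets the channels as BGR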
Modify the plot_images function by changing this line:

im = im.transpose(1, 2, 0)

to

im = viz_histo_binarized(im.copy())

to convert an event histogram to a 3-channel image that can be properly visualized.
In ultralytics/engine/results.py:

Add an import of viz_histo_binarized from ultralytics.utils.plotting:

from ultralytics.utils.plotting import Annotator, colors, save_one_box, viz_histo_binarized

Change the following line in the plot function:

deepcopy(self.orig_img if img is None else img),

to

deepcopy(viz_histo_binarized(self.orig_img.transpose(2, 0, 1)) if img is None else viz_histo_binarized(img.transpose(2, 0, 1))),

The reason is the same as above: converting an event histogram to a 3-channel image that can be properly visualized.
In ultralytics/cfg/models/v8/yolov8.yaml:

Add the line:

ch: 2

to specify that the input is expected to have two channels, since an event histogram has two channels.
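For instance, the new parameter can sit next to the existing nc entry near the top of the yaml file (surrounding content abbreviated):

# Parameters
nc: 80  # number of classes
ch: 2   # number of input channels (ON/OFF event polarities)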
In ultralytics/engine/predictor.py:

Change the line:

self.model.warmup(imgsz=(1 if self.model.pt or self.model.triton else self.dataset.bs, 3, *self.imgsz))

to

self.model.warmup(imgsz=(1 if self.model.pt or self.model.triton else self.dataset.bs, 2, *self.imgsz))

because the network now expects inputs with 2 channels instead of 3.
In ultralytics/engine/validator.py:

Change the line:

model.warmup(imgsz=(1 if pt else self.args.batch, 3, imgsz, imgsz))  # warmup

to

model.warmup(imgsz=(1 if pt else self.args.batch, 2, imgsz, imgsz))  # warmup
Prepare Your Python Environment
We need a Python environment that fulfills the requirements listed by yolov8. Once you have downloaded the source code, the requirements.txt file can be found in the root directory. To install the requirements (it is recommended to do this in a virtual environment):

pip install -r requirements.txt

In addition, the h5py package needs to be installed to read the event histograms saved in the .h5 files:

pip install h5py
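A quick way to confirm that h5py resolves in the active environment:

python -c "import h5py; print(h5py.__version__)"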
Train a Detection Model
Create a .yaml file and specify your dataset path and class names:

path: YOUR_DATASET_PATH  # dataset root dir
train: train  # train images (relative to 'path')
val: val  # val images (relative to 'path')
test:  # test images (optional)

# Classes
names:
  0: xxx
  1: yyy
  2: zzz
Create a .py file; you can then train your network with just three lines of code:

from ultralytics import YOLO

model = YOLO('yolov8n.yaml')
results = model.train(data='YOUR_YAML_FILE.yaml', amp=False, epochs=10)

The network will be trained for only 10 epochs; increase the number if you want to train longer.
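After training, the model can be evaluated on the val split with the same API. This is a minimal sketch assuming the default ultralytics output location (runs/detect/train/weights/best.pt) for the trained weights:

from ultralytics import YOLO

# Load the best checkpoint produced by the training run above
model = YOLO('runs/detect/train/weights/best.pt')

# Evaluate on the 'val' split defined in the dataset yaml
metrics = model.val(data='YOUR_YAML_FILE.yaml')
print(metrics.box.map)  # mAP50-95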
Make Predictions with the Trained Detection Model
The following example shows how to run the detection model over a sequence of event histograms to predict bounding boxes and classes. The predictions are drawn on the frames, which are then saved as a video.
from ultralytics import YOLO
import cv2
import h5py
import numpy as np
model = YOLO('YOUR_TRAINED_MODEL.pt') # load a trained model
events_file = h5py.File("YOUR_TEST_EVENTS_FILE.h5", "r")
num_ev_histograms, _, height, width = events_file['data'].shape
out = cv2.VideoWriter('OUTPUT_FOLDER/output.mp4', cv2.VideoWriter_fourcc(*'MP4V'), 20.0, (width, height))

for idx in range(num_ev_histograms):
    ev_histo = np.transpose(events_file['data'][idx], (1, 2, 0))
    results = model(ev_histo)  # returns a list of Results objects
    annotated_frame = results[0].plot()
    out.write(cv2.cvtColor(annotated_frame, cv2.COLOR_RGB2BGR))

out.release()
events_file.close()