OpenSTL Tutorial-1
Introduction:
OpenSTL是一个全面的时空预测学习基准,将常见的方法分为recurrent-based和recurrent-free的模型两大类。OpenSTL提供了一个模块化和可扩展的框架,实现了各种最先进的方法。主要的特点有:
- 灵活的代码设计。OpenSTL将STL算法分解为方法(训练和预测)、模型(网络架构)和模块,并提供统一的实验API。用户可以根据不同的STL任务使用灵活的训练策略和网络开发自己的STL算法。
- 标准化基准。OpenSTL将支持STL算法的标准化基准,包括训练和评估,类似于许多开源项目(例如MMDetection和USB等)。
- 支持多种模型和任务。OpenSTL包含了十四种有代表性的时空预测学习算法和二十四种模型,涵盖了从合成移动物体轨迹到真实世界的人体动作、驾驶场景、交通流量和天气预测等六类任务和十余个数据集。
介绍来自:https://zhuanlan.zhihu.com/p/640271275
Pipeline
tutorial.ipynb
Openstl库在example文件夹下实现了一个基本的训练、测试以及可视化的pipeline。总体上分四个部分:1. 数据预处理;2. Dataset定义;3.训练+测试;4.可视化
数据预处理
tutorial给出的预处理部分实际上是对视频进行采样,之后将采样得到的视频帧转成pkl格式,通过pickle.load()直接读取输入帧序列(train_x) 和 输出帧序列(train_y)。
通过tutorial的预处理模块,train_x和train_y的shape被定义为(B,T,C,H,W),B代表样本数,T代表帧数,C被定义为图像的通道,H、W为图像的长宽。
实际上该功能也可以从dataset直接进行采样。将在Dataset定义一节中详细说明。
Dataset定义
tutorial中通过预处理,将输入输出变成了(B,T,C,H,W),所以在采样(__getitem__
)的时候对第一个维度进行索引即可。同时在重建任务中,一般会对数据采用归一化/标准化。给出的代码使用的是标准化。
import torch
from torch.utils.data import Dataset
class CustomDataset(Dataset):
def __init__(self, X, Y, normalize=False, data_name='custom'):
super(CustomDataset, self).__init__()
# 读取输入输出
self.X = X #(B,T,C,H,W)
self.Y = Y #(B,T,C,H,W)
self.mean = None
self.std = None
self.data_name = data_name
if normalize:
# get the mean/std values along the channel dimension
mean = data.mean(axis=(0, 1, 2, 3)).reshape(1, 1, -1, 1, 1)
std