OpenSTL(未来帧预测)从零到掌握教程-1

OpenSTL Tutorial-1

Introduction:

OpenSTL是一个全面的时空预测学习基准,将常见的方法分为recurrent-based和recurrent-free的模型两大类。OpenSTL提供了一个模块化和可扩展的框架,实现了各种最先进的方法。主要的特点有:

  • 灵活的代码设计。OpenSTL将STL算法分解为方法(训练和预测)、模型(网络架构)和模块,并提供统一的实验API。用户可以根据不同的STL任务使用灵活的训练策略和网络开发自己的STL算法。
  • 标准化基准。OpenSTL将支持STL算法的标准化基准,包括训练和评估,类似于许多开源项目(例如MMDetection和USB等)。
  • 支持多种模型和任务。OpenSTL包含了十四种有代表性的时空预测学习算法和二十四种模型,涵盖了从合成移动物体轨迹到真实世界的人体动作、驾驶场景、交通流量和天气预测等六类任务和十余个数据集。

介绍来自:https://zhuanlan.zhihu.com/p/640271275

Pipeline

tutorial.ipynb

​ Openstl库在example文件夹下实现了一个基本的训练、测试以及可视化的pipeline。总体上分四个部分:1. 数据预处理;2. Dataset定义;3.训练+测试;4.可视化

数据预处理

​ tutorial给出的预处理部分实际上是对视频进行采样,之后将采样得到的视频帧转成pkl格式,通过pickle.load()直接读取输入帧序列(train_x) 和 输出帧序列(train_y)。

​ 通过tutorial的预处理模块,train_x和train_y的shape被定义为(B,T,C,H,W),B代表样本数,T代表帧数,C被定义为图像的通道,H、W为图像的长宽。

​ 实际上该功能也可以从dataset直接进行采样。将在Dataset定义一节中详细说明。

Dataset定义

​ tutorial中通过预处理,将输入输出变成了(B,T,C,H,W),所以在采样(__getitem__)的时候对第一个维度进行索引即可。同时在重建任务中,一般会对数据采用归一化/标准化。给出的代码使用的是标准化。

import torch
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    def __init__(self, X, Y, normalize=False, data_name='custom'):
        super(CustomDataset, self).__init__()
        # 读取输入输出
        self.X = X #(B,T,C,H,W)
        self.Y = Y #(B,T,C,H,W)
        self.mean = None
        self.std = None
        self.data_name = data_name

        if normalize:
            # get the mean/std values along the channel dimension
            mean = data.mean(axis=(0, 1, 2, 3)).reshape(1, 1, -1, 1, 1)
            std 
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Eason_12138

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值