TimeSformer代码复现


一、TimeSformer

论文:Is Space-Time Attention All You Need for Video Understanding?
代码:TimeSformer

1.创建环境

# 创建环境
conda create -n TimeSformer python=3.7 -y
# 激活环境
conda activate TimeSformer
# 按照官方步骤安装包
pip install torchvisionconda install torchvision -c pytorch
pip install 'git+https://github.com/facebookresearch/fvcore'
pip install simplejson
pip install einops
pip install timm
conda install av -c conda-forge
pip install psutil
pip install scikit-learn
pip install opencv-python
pip install tensorboard

最后,通过运行以下命令来构建 TimeSformer 代码库:

git clone https://github.com/facebookresearch/TimeSformer
cd TimeSformer
python setup.py build develop

2.数据集

(1)Kinetics-400是视频领域benchmark常用数据集,详细介绍可以参考其官方网站Kinetics。整个数据集包含400个类别,全部文件大概需要135G左右的存储空间,下载起来比较困难。
Tiny-Kinetics-400同样包含400个类别,每个类别下仅有两条视频数据,分为train与val,可用于调试一些视频理解模型。
目前提供了百度网盘的下载方式:
百度云:Baidu
Google:google

(2)UCF101是一个现实动作视频的动作识别数据集,收集自YouTube,提供了来自101个动作类别的13320个视频。
本人采用的是UCF50数据集:UCF50
处理数据集,生成.csv文件的代码如下:

import os
import csv
import shutil
from tqdm import tqdm
from sklearn.model_selection import train_test_split

out_dir = "D:/ALLUsers/hyy/TimeSformer-main/dataUCF"  # 输出路径
video_path = "D:/ALLUsers/hyy/TimeSformer-main/UCF50" # 数据集路径
file_name = ".csv"
video_name = ".avi"
name_list = ["train","test","val"]
if not os.path.exists(out_dir):
    os.mkdir(out_dir)
    os.mkdir(os.path.join(out_dir, 'train'))
    os.mkdir(os.path.join(out_dir, 'val'))
    os.mkdir(os.path.join(out_dir, 'test'))
for file in os.listdir(video_path):
        file_path = os.path.join(video_path, file)
        video_files = [name for name in os.listdir(file_path)]
        train_and_valid, test = train_test_split(video_files, test_size=0.2, random_state=42)
        train, val = train_test_split(train_and_valid, test_size=0.2, random_state=42)
        train_dir = os.path.join(out_dir, 'train', file)
        val_dir = os.path.join(out_dir, 'val', file)
        test_dir = os.path.join(out_dir, 'test', file)
        if not os.path.exists(train_dir):
            os.mkdir(train_dir)
        if not os.path.exists(val_dir):
            os.mkdir(val_dir)
        if not os.path.exists(test_dir):
            os.mkdir(test_dir)
        for video in tqdm(train):
           shutil.copy(os.path.join(video_path,file,video),os.path.join(train_dir,video))
        for video in tqdm(test):
            shutil.copy(os.path.join(video_path,file,video),os.path.join(test_dir,video))
        for video in tqdm(val):
            shutil.copy(os.path.join(video_path,file,video),os.path.join(val_dir,video))
if not os.path.exists(os.path.join(out_dir,"csv")):
    os.mkdir(os.path.join(out_dir,"csv"))
    for name in name_list:
        with open(os.path.join(out_dir,"csv",name+file_name),'wb') as f:
            print("创建"+os.path.join(out_dir,"csv",name+file_name))
csv_path = os.path.join(out_dir,"csv")
print(csv_path)
for ii in os.listdir(path=csv_path):
    if ii.split(".")[0] in name_list:
        path1 = os.path.join(csv_path,ii)
        with open(path1, 'w', newline='') as f:
            for dd in os.listdir(out_dir):
                if dd==ii.split(".")[0]:
                    for zz in os.listdir(os.path.join(out_dir,dd)):
                        for mm in os.listdir(os.path.join(out_dir,dd,zz)):
                            writer = csv.writer(f)
                            writer.writerow([os.path.join(out_dir,dd,zz,mm),zz])

## 创建类别label标号文件
labels= []
for label in sorted(os.listdir(video_path)):
    labels.append(label)
label2index = {label: index for index, label in enumerate(sorted(set(labels)))}
label_file = os.path.join(out_dir, str(len(os.listdir(video_path))) + 'class_labels.txt')
with open(label_file, 'w') as f:
    for id, label in enumerate(sorted(label2index)):
        f.writelines(str(id) + ' ' + label +'\n')
#替换csv文件中类别名为数字
csv_file = os.path.join(out_dir,"csv")
def txt_read(files):
    txt_dict = {}
    fopen = open(files)
    for line in fopen.readlines():
        line = str(line).replace('\n','')
        txt_dict[line.split(' ',1)[1]] = line.split(' ',1)[0]      
    fopen.close()
    return txt_dict
txt_dict = txt_read(label_file)
print(txt_dict)

for ii in os.listdir(csv_file):
    path1 = os.path.join(csv_file,ii)
    r = csv.reader(open(path1))
    lines = [l for l in r]
    for i in range(len(lines)):
        cs = lines[i][1]
        value = txt_dict[cs]
        lines[i][1] = value
    writer = csv.writer(open(path1, 'w'))
    writer.writerows(lines)
    
# 由于生成的.csv文件中有多余空行,导致读入数据失败,这里我添加了删除空行的代码
def remove_blank_lines_from_csv(folder_path):
    # 遍历文件夹中的文件
    for file_name in os.listdir(folder_path):
        file_path = os.path.join(folder_path, file_name)
        # 检查文件是否为CSV文件
        if file_name.endswith('.csv'):
            # 读取CSV文件并删除空白行
            with open(file_path, 'r', newline='') as csv_file:
                csv_reader = csv.reader(csv_file)
                lines = [line for line in csv_reader if line]
            # 写入CSV文件(覆盖原文件)
            with open(file_path, 'w', newline='') as csv_file:
                csv_writer = csv.writer(csv_file)
                csv_writer.writerows(lines)
# 指定要遍历的文件夹路径
folder_path = csv_file
# 调用函数删除每个CSV文件中的空白行
remove_blank_lines_from_csv(folder_path)

3.修改参数

TimeSformermain\configs\Kinetics\TimeSformer_divST_8x32_224.yaml

TRAIN:
  ENABLE: True
  DATASET: kinetics
  BATCH_SIZE: 4  # 修改batch size
  EVAL_PERIOD: 5
  CHECKPOINT_PERIOD: 5
  AUTO_RESUME: True
DATA:
  PATH_TO_DATA_DIR: D:/ALLUsers/hyy/TimeSformer-main/dataUCF/csv/  # 修改数据集csv文件的路径
  NUM_FRAMES: 8
  SAMPLING_RATE: 32
  TRAIN_JITTER_SCALES: [256, 320]
  TRAIN_CROP_SIZE: 224
  TEST_CROP_SIZE: 224
  INPUT_CHANNEL_NUM: [3]
TIMESFORMER:
  ATTENTION_TYPE: 'divided_space_time'
SOLVER:
  BASE_LR: 0.005
  LR_POLICY: steps_with_relative_lrs
  STEPS: [0, 11, 14]
  LRS: [1, 0.1, 0.01]
  MAX_EPOCH: 15
  MOMENTUM: 0.9
  WEIGHT_DECAY: 1e-4
  OPTIMIZING_METHOD: sgd
MODEL:
  MODEL_NAME: vit_base_patch16_224
  NUM_CLASSES: 50  #根据使用的数据集修改类别数 NUM_CLASSES
  ARCH: vit
  LOSS_FUNC: cross_entropy
  DROPOUT_RATE: 0.5
TEST:
  ENABLE: False
  DATASET: kinetics
  BATCH_SIZE: 8
  NUM_ENSEMBLE_VIEWS: 1
  NUM_SPATIAL_CROPS: 3
DATA_LOADER:
  NUM_WORKERS: 8
  PIN_MEMORY: True
NUM_GPUS: 1  # 根据自己GPU选择
NUM_SHARDS: 1
RNG_SEED: 0
OUTPUT_DIR: ./output/ # 输出路径

4.运行代码

TimeSformer-main\timesformer\utils\parser.py line 48 修改路径

 parser.add_argument(
        "--cfg",
        dest="cfg_file",
        help="Path to the config file",
        default="D:\ALLUsers\hyy\TimeSformer-main\configs\Kinetics\TimeSformer_divST_8x32_224.yaml",
        type=str,
    )

TimeSformer-main\timesformer\config\defaults.py line63 修改路径
.pyth文件下载自TimeSformer

_C.TRAIN.CHECKPOINT_FILE_PATH = "TimeSformer-main/tools/checkpoints/TimeSformer_divST_8x32_224_K400.pyth"

运行代码

python tools/run_net.py 

问题

1.TimeSformer-main/timesformer/models/resnet_helper.py", line 15,
cannot import name ‘_LinearWithBias’ from 'torch.nn.modules.linear
是pytorch版本问题,在torch>1.10之后就没有_LinearWithBias了

将import的代码替换为
# from torch.nn.modules.linear import _LinearWithBias
if float(torch.__version__.split('.')[0]) == 0 or (float(torch.__version__.split('.')[0]) == 1 and float(torch.__version__.split('.')[1])) < 9:
    from torch.nn.modules.linear import _LinearWithBias
else:
    from torch.nn.modules.linear import NonDynamicallyQuantizableLinear as _LinearWithBias

2.TimeSformer-main/timesformer/models/vit_utils.py", line 14
from torch._six import container_abcs
ImportError: cannot import name ‘container_abcs’ from ‘torch._six’

将import的代码替换为
# from torch._six import container_abcs
    import collections.abc as container_abcs
    int_classes = int
    string_classes = str

3.TimeSformer-main/timesformer/datasets/multigrid_helper.py", line 6
ImportError: cannot import name ‘int_classes’ from ‘torch._six’

修改代码
# from torch._six import int_classes as _int_classes
int_classes = int

not isinstance(batch_size, _int_classes)修改为:not isinstance(batch_size, int_classes)
  1. 根据生成的.csv文件中,video_path 与 label之间的分隔符决定,这里是 , 参数默认为空格,需要修改
    在这里插入图片描述
_C.DATA.PATH_LABEL_SEPARATOR = ","

参考:TimesFormer Ubuntun环境搭建及训练自己数据集

  • 23
    点赞
  • 24
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
GitHub代码复现是指在GitHub上找到一个感兴趣或有用的开源代码项目,并通过阅读代码、运行代码并进行修改来重新实现或再次创建整个项目。 首先,需要在GitHub上搜索并找到目标项目。可以通过GitHub的搜索功能,输入关键词、项目名称、描述等来筛选出符合条件的项目。选择一个代码质量好、维护活跃的项目会更有保障。 一旦找到了目标项目,就可以clone(克隆)该项目到本地。可以使用git命令行或者通过GitHub Desktop等工具进行操作。克隆项目后,就可以在本地对代码进行修改、调试、定制等。 接下来,对项目进行配置和安装依赖。一般来说,项目中会有一个readme文件或者其他文档来指导配置环境和安装所需的依赖包。根据项目要求进行配置和安装。 然后,就可以运行项目了。根据项目的要求,可能需要提供一些参数或者数据集。根据项目的文档,在终端或者IDE中运行相应的命令或者程序。 当项目运行成功后,就可以根据自己的需求对代码进行修改和优化。可以根据项目的架构和实现逻辑进行更改,添加新的功能,或者提升代码的性能等。 最后,如果对项目的改进比较显著,可以考虑提交自己的贡献给项目的维护者。可以通过Fork项目、修改代码、提交Pull Request等方式向项目提交自己的改动。项目维护者会进行代码审查,并决定是否接受你的改动。 总之,GitHub代码复现是一个学习和交流的过程。通过复现别人的代码,可以提升自己的编程能力,了解项目的实现细节,还可以与其他开发者交流、合作,共同提高。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值