Training SlowFast with mmaction2

Reference: "slowfast训练自定义数据集,识别动物行为" by 盛世芳华 (CSDN blog)

Foreword


I experimented with the training and testing workflow of several SlowFast action-detection setups and distilled two approaches that work well. My own project is a classroom student behaviour detection system; this post covers the first approach, built on mmaction2.



I. Preparing the dataset

1. Preparing the video data

Directory structure:

     mmaction2
        ├── data
        │   ├── ava
        │   │   ├── annotations
        │   │   |   ├── ava_dense_proposals_train.FAIR.recall_93.9.pkl
        │   │   |   ├── ava_dense_proposals_val.FAIR.recall_93.9.pkl
        │   │   |   ├── ava_dense_proposals_test.FAIR.recall_93.9.pkl
        │   │   |   ├── ava_train_v2.1.csv
        │   │   |   ├── ava_val_v2.1.csv
        │   │   |   ├── ava_train_excluded_timestamps_v2.1.csv
        │   │   |   ├── ava_val_excluded_timestamps_v2.1.csv
        │   │   |   ├── ava_action_list_v2.1.pbtxt

Prepare a video of at least one minute; it will be split into two 30-second clips for frame extraction.

Frame extraction for the SlowFast dataset happens twice: once at one frame per second, producing the images to annotate, and once at 30 frames per second, producing the training frames. The dense extraction is needed because the fast pathway samples a clip at a high rate (roughly 15 frames per second here) while the slow pathway samples it sparsely (roughly 2 frames per second).
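As a quick back-of-the-envelope sketch of those numbers (the values clip_len=32, frame_interval=2 and resample_rate=8 come from the config used later in this post), one training sample spans about two seconds of 30 fps video:

# Rough arithmetic for one training clip, using the config values shown later
# (SampleAVAFrames: clip_len=32, frame_interval=2; ResNet3dSlowFast: resample_rate=8).
fps = 30                    # raw-frame extraction rate
clip_len = 32               # frames fed to the fast pathway
frame_interval = 2          # stride between sampled raw frames
resample_rate = 8           # the slow pathway keeps every 8th of those frames

window_frames = clip_len * frame_interval               # 64 raw frames per sample
window_seconds = window_frames / fps                     # about 2.1 s of video
fast_fps = clip_len / window_seconds                     # roughly 15 frames per second
slow_fps = (clip_len / resample_rate) / window_seconds   # roughly 2 frames per second
print(window_seconds, fast_fps, slow_fps)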

This post uses ffmpeg for video cutting and frame extraction, so install ffmpeg first:

conda install x264 ffmpeg -c conda-forge -y

Create a script cut_video.sh:

Note:

If you run the .sh script from PyCharm, you need to install Git yourself; see:

"Git 详细安装教程(详解 Git 安装过程的每一个步骤)" by mukes (CSDN blog)

"实现pycharm运行.sh文件——本地运行和打开服务器终端" by Mecv_清痕 (CSDN blog)

IN_DATA_DIR="./ava/videos"
OUT_DATA_DIR="./ava/video_cut"
ffmpeg -ss 0 -t 30 -y -i "${IN_DATA_DIR}/1.mp4" "${OUT_DATA_DIR}/1.mp4"
ffmpeg -ss 31 -t 30 -y -i "${IN_DATA_DIR}/1.mp4" "${OUT_DATA_DIR}/2.mp4"

IN_DATA_DIR is the folder containing the video you prepared.

OUT_DATA_DIR is where the cut clips are saved.

The value after -ss is the start time of the cut and -t 30 is the length of the cut in seconds; if your video is on the short side, just start cutting from 0 s.

2. Extracting frames from the videos

Create the script video2img.py:

import os
import shutil
from tqdm import tqdm

start = 0      # start offset (seconds) used when cutting the videos
seconds = 30   # length (seconds) of each cut clip

video_path = './ava/videos'
labelframes_path = './ava/labelframes'
rawframes_path = './ava/rawframes'
cut_videos_sh_path = './cut_videos.sh'

# Remove any frames left over from a previous run.
if os.path.exists(labelframes_path):
    shutil.rmtree(labelframes_path)
if os.path.exists(rawframes_path):
    shutil.rmtree(rawframes_path)

fps = 30
raw_frames = seconds*fps

# Rewrite the ffmpeg line in cut_videos.sh so it uses the start/seconds values above.
with open(cut_videos_sh_path, 'r') as f:
    sh = f.read()
sh = sh.replace(sh[sh.find('    ffmpeg'):], f'    ffmpeg -ss {start} -t {seconds} -i "${{video}}" -r 30 -strict experimental "${{out_name}}"\n  fi\ndone\n')
with open(cut_videos_sh_path, 'w') as f:
    f.write(sh)
# (AVA itself uses timestamps 902 to 1798.)
os.system('bash cut_videos.sh')
os.system('bash extract_rgb_frames_ffmpeg.sh')

# Copy one frame per second (skipping the first and last two seconds) into labelframes,
# naming it "<video>_<second>.jpg"; these are the images to annotate in VIA.
os.makedirs(labelframes_path, exist_ok=True)
video_ids = [video_id[:-4] for video_id in os.listdir(video_path)]
for video_id in tqdm(video_ids):
    for img_id in range(2*fps+1, (seconds-2)*30, fps):
        shutil.copyfile(os.path.join(rawframes_path, video_id, 'img_'+format(img_id, '05d')+'.jpg'),
                        os.path.join(labelframes_path, video_id+'_'+format(start+img_id//30, '05d')+'.jpg'))

  extract_rgb_frames_ffmpeg.sh

IN_DATA_DIR="./ava/videos_cut"
OUT_DATA_DIR="./ava/rawframes"
 
if [[ ! -d "${OUT_DATA_DIR}" ]]; then
  echo "${OUT_DATA_DIR} doesn't exist. Creating it.";
  mkdir -p ${OUT_DATA_DIR}
fi
 
for video in $(ls -A1 -U ${IN_DATA_DIR}/*)
do
  video_name=${video##*/}
 
  if [[ $video_name = *".webm" ]]; then
    video_name=${video_name::-5}
  else
    video_name=${video_name::-4}
  fi
 
  out_video_dir=${OUT_DATA_DIR}/${video_name}
  mkdir -p "${out_video_dir}"
 
  out_name="${out_video_dir}/img_%05d.jpg"
 
  ffmpeg -i "${video}" -r 30 -q:v 1 "${out_name}"
done

 cut_videos.sh

IN_DATA_DIR="./ava/videos"
OUT_DATA_DIR="./ava/videos_cut"
 
if [[ ! -d "${OUT_DATA_DIR}" ]]; then
  echo "${OUT_DATA_DIR} doesn't exist. Creating it.";
  mkdir -p ${OUT_DATA_DIR}
fi
 
for video in $(ls -A1 -U ${IN_DATA_DIR}/*)
do
  out_name="${OUT_DATA_DIR}/${video##*/}"
  if [ ! -f "${out_name}" ]; then
    ffmpeg -ss 0 -t 3 -i "${video}" -r 30 -strict experimental "${out_name}"
  fi
done

Place these three scripts in the same directory as the cut videos and fill in your own paths:

video_path = './ava/videos'              directory containing the original videos
labelframes_path = './ava/labelframes'   annotation frames produced by the extraction
rawframes_path = './ava/rawframes'       SlowFast training frames produced by the extraction
cut_videos_sh_path = './cut_videos.sh'   path to the cut_videos.sh script

After running video2img.py, the labelframes and rawframes directories appear next to the videos:

labelframes holds the images that need annotating, and rawframes holds the frames SlowFast trains on.
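An optional sanity check (a minimal sketch, assuming the directory layout produced by the scripts above) confirms that each video got the expected number of frames:

import os

rawframes_path = './ava/rawframes'
labelframes_path = './ava/labelframes'

# Each 30 s clip extracted at 30 fps should yield roughly 900 images per video.
for video_id in sorted(os.listdir(rawframes_path)):
    n = len(os.listdir(os.path.join(rawframes_path, video_id)))
    print(f'{video_id}: {n} raw frames')

# Label frames are flat files named "<video>_<second>.jpg", about one per second of video.
print('label frames:', len(os.listdir(labelframes_path)))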

3. Annotating SlowFast labels with VIA

You can speed this step up by pre-annotating with a trained yolov5 or yolov7 model: let the detector draw the person boxes, export them as a json file and import that json into the VIA annotation page. My annotation set was small, so I used yolov5 boxes for part of it and drew the rest by hand; a rough pre-annotation sketch is shown after the reference below.

A good walkthrough of that pipeline: "自定义ava数据集及训练与测试 完整版 时空动作/行为 视频数据集制作 yolov5, deep sort, VIA MMAction, SlowFast" (CSDN blog)
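If you want to try the pre-annotation route, the sketch below uses the public yolov5 hub model to detect people in the label frames and dumps the boxes to a csv. The model name and the output file pre_annotations.csv are my own choices, not part of the original post, and converting the boxes into VIA's import json is not shown here; follow the tutorial linked above for that step.

import os
import csv
import torch

# Load a pretrained yolov5 model from the ultralytics hub (downloads weights on first run).
model = torch.hub.load('ultralytics/yolov5', 'yolov5s')
labelframes_path = './ava/labelframes'

rows = []
for name in sorted(os.listdir(labelframes_path)):
    results = model(os.path.join(labelframes_path, name))
    # results.xyxy[0]: one row per detection -> [x1, y1, x2, y2, conf, class]
    for x1, y1, x2, y2, conf, cls in results.xyxy[0].tolist():
        if int(cls) == 0:  # class 0 is "person" in the COCO label set
            rows.append([name, x1, y1, x2, y2, conf])

with open('pre_annotations.csv', 'w', newline='') as f:
    csv.writer(f).writerows(rows)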

Download VIA here:

Link: https://pan.baidu.com/s/1oXjcUnbVKC8fjoTiqRLtmw
Extraction code: 7ctu

In VIA, the plus button adds the images you want to annotate, and the white folder icon imports a json file of previously completed annotations.

Click export to export the annotations as a csv file.

 

via2ava.py

"""
Theme:ava format data transformer
author:Hongbo Jiang
time:2022/3/14/1:51:51
description:
    
    这是一个数据格式转换器,根据mmaction2的ava数据格式转换规则将来自网站:
    https://www.robots.ox.ac.uk/~vgg/software/via/app/via_video_annotator.html
    的、标注好的、视频理解类型的csv文件转换为mmaction2指定的数据格式。
    转换规则:
        # AVA Annotation Explained
        In this section, we explain the annotation format of AVA in details:
        ```
        mmaction2
        ├── data
        │   ├── ava
        │   │   ├── annotations
        │   │   |   ├── ava_dense_proposals_train.FAIR.recall_93.9.pkl
        │   │   |   ├── ava_dense_proposals_val.FAIR.recall_93.9.pkl
        │   │   |   ├── ava_dense_proposals_test.FAIR.recall_93.9.pkl
        │   │   |   ├── ava_train_v2.1.csv
        │   │   |   ├── ava_val_v2.1.csv
        │   │   |   ├── ava_train_excluded_timestamps_v2.1.csv
        │   │   |   ├── ava_val_excluded_timestamps_v2.1.csv
        │   │   |   ├── ava_action_list_v2.1.pbtxt
        ```
        ## The proposals generated by human detectors
        In the annotation folder, `ava_dense_proposals_[train/val/test].FAIR.recall_93.9.pkl` are human proposals generated by a human detector. They are used in training, validation and testing respectively. Take `ava_dense_proposals_train.FAIR.recall_93.9.pkl` as an example. It is a dictionary of size 203626. The key consists of the `videoID` and the `timestamp`. For example, the key `-5KQ66BBWC4,0902` means the values are the detection results for the frame at the $$902_{nd}$$ second in the video `-5KQ66BBWC4`. The values in the dictionary are numpy arrays with shape $$N \times 5$$ , $$N$$ is the number of detected human bounding boxes in the corresponding frame. The format of bounding box is $$[x_1, y_1, x_2, y_2, score], 0 \le x_1, y_1, x_2, y_2, score \le 1$$. $$(x_1, y_1)$$ indicates the top-left corner of the bounding box, $$(x_2, y_2)$$ indicates the bottom-right corner of the bounding box; $$(0, 0)$$ indicates the top-left corner of the image, while $$(1, 1)$$ indicates the bottom-right corner of the image.
        ## The ground-truth labels for spatio-temporal action detection
        In the annotation folder, `ava_[train/val]_v[2.1/2.2].csv` are ground-truth labels for spatio-temporal action detection, which are used during training & validation. Take `ava_train_v2.1.csv` as an example, it is a csv file with 837318 lines, each line is the annotation for a human instance in one frame. For example, the first line in `ava_train_v2.1.csv` is `'-5KQ66BBWC4,0902,0.077,0.151,0.283,0.811,80,1'`: the first two items `-5KQ66BBWC4` and `0902` indicate that it corresponds to the $$902_{nd}$$ second in the video `-5KQ66BBWC4`. The next four items ($$[0.077(x_1), 0.151(y_1), 0.283(x_2), 0.811(y_2)]$$) indicates the location of the bounding box, the bbox format is the same as human proposals. The next item `80` is the action label. The last item `1` is the ID of this bounding box.
        ## Excluded timestamps
        `ava_[train/val]_excludes_timestamps_v[2.1/2.2].csv` contains excluded timestamps which are not used during training or validation. The format is `video_id, second_idx` .
        ## Label map
        `ava_action_list_v[2.1/2.2]_for_activitynet_[2018/2019].pbtxt` contains the label map of the AVA dataset, which maps the action name to the label index.
"""
 
import csv
import os
import pickle
import numpy as np
import cv2
 
 
def transformer(origin_csv_path, frame_image_dir,
                train_output_pkl_path, train_output_csv_path,
                valid_output_pkl_path, valid_output_csv_path,
                exclude_train_output_csv_path, exclude_valid_output_csv_path,
                out_action_list, out_labelmap_path, dataset_percent=0.9):
    """
    输入:
    origin_csv_path:从网站导出的csv文件路径。
    frame_image_dir:以"视频名_第n秒.jpg"格式命名的图片,这些图片是通过逐秒读取的。
    output_pkl_path:输出pkl文件路径
    output_csv_path:输出csv文件路径
    out_labelmap_path:输出labelmap.txt文件路径
    dataset_percent:训练集和测试集分割
    
    输出:无
    
    """
 
    # -----------------------------------------------------------------------------------------------
    get_label_map(origin_csv_path, out_action_list, out_labelmap_path)
    # -----------------------------------------------------------------------------------------------
    information_array = [[], [], []]
    # Read the bounding-box / action rows from the exported csv file
    with open(origin_csv_path, 'r') as csvfile:
        count = 0
        content = csv.reader(csvfile)
        for line in content:
            # print(line)
            if count >= 10:  # the first 10 lines of the VIA export are metadata
                frame_image_name = eval(line[1])[0]  # str
                # print(line[-2])
                location_info = eval(line[4])[1:]  # list
                action_list = list(eval(line[5]).values())[0].split(',')
                print(action_list)
                action_list = [int(x) for x in action_list if x!='']  # list; skip empty values
                information_array[0].append(frame_image_name)
                information_array[1].append(location_info)
                information_array[2].append(action_list)
            count += 1
    # Gather frame image name, box location and action classes into one array
    information_array = np.array(information_array, dtype=object).transpose()
    # information_array = np.array(information_array)
    # -----------------------------------------------------------------------------------------------
    num_train = int(dataset_percent * len(information_array))
    train_info_array = information_array[:num_train]
    valid_info_array = information_array[num_train:]
    get_pkl_csv(train_info_array, train_output_pkl_path, train_output_csv_path, exclude_train_output_csv_path, frame_image_dir)
    get_pkl_csv(valid_info_array, valid_output_pkl_path, valid_output_csv_path, exclude_valid_output_csv_path, frame_image_dir)
 
 
def get_label_map(origin_csv_path, out_action_list, out_labelmap_path):
    classes_list = 0
    classes_content = ""
    labelmap_strings = ""
    # Pull the action-class definition out of row 9 of the csv
    with open(origin_csv_path, 'r') as csvfile:
        count = 0
        content = csv.reader(csvfile)
        for line in content:
            if count == 8:
                classes_list = line
                break
            count += 1
    # Cut out the dictionary of classes from that row
    st = 0
    ed = 0
    for i in range(len(classes_list)):
        if classes_list[i].startswith('options'):
            st = i
        if classes_list[i].startswith('default_option_id'):
            ed = i
    for i in range(st, ed):
        if i == st:
            classes_content = classes_content + classes_list[i][len('options:'):] + ','
        else:
            classes_content = classes_content + classes_list[i] + ','
    classes_dict = eval(classes_content)[0]
    # Write the pbtxt action list and the labelmap.txt file
    with open(out_action_list, 'w') as f:  # write the action_list file
        for v, k in classes_dict.items():
            labelmap_strings = labelmap_strings + "label {{\n  name: \"{}\"\n  label_id: {}\n  label_type: PERSON_MOVEMENT\n}}\n".format(k, int(v)+1)
        f.write(labelmap_strings)
    labelmap_strings = ""
    with open(out_labelmap_path, 'w') as f:  # write the label_map file
        for v, k in classes_dict.items():
            labelmap_strings = labelmap_strings + "{}: {}\n".format(int(v)+1, k)
        f.write(labelmap_strings)
 
 
def get_pkl_csv(information_array, output_pkl_path, output_csv_path, exclude_output_csv_path, frame_image_dir):
    # Initialise the containers before iterating
    pkl_data = dict()  # dict of pkl key/value pairs (values are plain lists)
    csv_data = []  # 2-d array that will be written to the csv file
    read_data = {}  # dict of pkl key/value pairs (values converted to numpy arrays)
 
    for i in range(len(information_array)):
        img_name = information_array[i][0]
        # -------------------------------------------------------------------------------------------
        video_name, frame_name = '_'.join(img_name.split('_')[:-1]), format(int(img_name.split('_')[-1][:-4]), '04d')  # my file names are "<video name>_<frame name>"; change this if yours differ
        # -------------------------------------------------------------------------------------------
        pkl_key = video_name + ',' + frame_name
        pkl_data[pkl_key] = []
    # Walk over every image, read its info and fill in the pkl data
    for i in range(len(information_array)):
        img_name = information_array[i][0]
        # -------------------------------------------------------------------------------------------
        video_name, frame_name = '_'.join(img_name.split('_')[:-1]), str(int(img_name.split('_')[-1][:-4]))  # my file names are "<video name>_<frame name>"; change this if yours differ
        # -------------------------------------------------------------------------------------------
        imgpath = frame_image_dir + '/' + img_name
        location_list = information_array[i][1]
        action_info = information_array[i][2]
        image_array = cv2.imread(imgpath)
        h, w = image_array.shape[:2]
        # Normalize to [0, 1] and convert (x, y, w, h) to (x1, y1, x2, y2)
        location_list[0] /= w
        location_list[1] /= h
        location_list[2] /= w
        location_list[3] /= h
        location_list[2] = location_list[2]+location_list[0]
        location_list[3] = location_list[3]+location_list[1]
        # Detection score is fixed to 1
        # Assemble the pkl / csv data
 
        for kind_idx in action_info:
            csv_info = [video_name, frame_name, *location_list, kind_idx+1, 1]
            csv_data.append(csv_info)
 
        location_list = location_list + [1]
        pkl_key = video_name + ',' + format(int(frame_name), '04d')
        pkl_value = location_list
        pkl_data[pkl_key].append(pkl_value)
 
    for k, v in pkl_data.items():
        read_data[k] = np.array(v)
 
    with open(output_pkl_path, 'wb') as f:  # write the pkl proposal file
        pickle.dump(read_data, f)
 
    with open(output_csv_path, 'w', newline='') as f:  # write the csv file; newline='' avoids blank lines between rows
        f_csv = csv.writer(f)
        f_csv.writerows(csv_data)
 
    with open(exclude_output_csv_path, 'w', newline='') as f:  # write the excluded-timestamps csv (left empty here)
        f_csv = csv.writer(f)
        f_csv.writerows([])
 
def showpkl(pkl_path):
    with open(pkl_path, 'rb') as f:
        content = pickle.load(f)
    return content
 
 
def showcsv(csv_path):
    output = []
    with open(csv_path, 'r') as f:
        content = csv.reader(f)
        for line in content:
            output.append(line)
    return output
 
 
def showlabelmap(labelmap_path):
    classes_dict = dict()
    with open(labelmap_path, 'r') as f:
        content = (f.read().split('\n'))[:-1]
        for item in content:
            mid_idx = -1
            for i in range(len(item)):
                if item[i] == ":":
                    mid_idx = i
            classes_dict[item[:mid_idx]] = item[mid_idx + 1:]
    return classes_dict
 
 
os.makedirs('./ava/annotations', exist_ok=True)
transformer("./Unnamed-VIA Project15Nov2023_15h21m11s_export.csv", './ava/labelframes',
            './ava/annotations/ava_dense_proposals_train.FAIR.recall_93.9.pkl', './ava/annotations/ava_train_v2.1.csv',
            './ava/annotations/ava_dense_proposals_val.FAIR.recall_93.9.pkl', './ava/annotations/ava_val_v2.1.csv',
            './ava/annotations/ava_train_excluded_timestamps_v2.1.csv', './ava/annotations/ava_val_excluded_timestamps_v2.1.csv',
            './ava/annotations/ava_action_list_v2.1.pbtxt', './ava/annotations/labelmap.txt', 0.9)
print(showpkl('./ava/annotations/ava_dense_proposals_train.FAIR.recall_93.9.pkl'))
print(showcsv('./ava/annotations/ava_train_v2.1.csv'))
print(showlabelmap('./ava/annotations/labelmap.txt'))

After running the code above, the annotations directory and its configuration files are generated in the working directory.
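Optionally, you can cross-check the generated files. This is a minimal sketch based only on the format described in the script's docstring (proposal keys are "videoID,timestamp", boxes are normalized [x1, y1, x2, y2]); it prints a line for anything that looks off:

import csv
import pickle

with open('./ava/annotations/ava_dense_proposals_train.FAIR.recall_93.9.pkl', 'rb') as f:
    proposals = pickle.load(f)

with open('./ava/annotations/ava_train_v2.1.csv') as f:
    for video_id, sec, x1, y1, x2, y2, label, entity_id in csv.reader(f):
        key = f'{video_id},{int(sec):04d}'
        if key not in proposals:
            print('missing proposal for', key)
        if not all(0.0 <= float(v) <= 1.0 for v in (x1, y1, x2, y2)):
            print('box not normalized for', key)
print('check finished')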

Possible error: the exported annotation file can contain empty "" values, which raises a type-conversion error when they are cast with int().

Fix:

Change the action-list parsing line (line 81 of via2ava.py; the listing above already includes the fix) to: action_list = [int(x) for x in action_list if x!='']

II. Configuring and using mmaction2

1. Setting up the mmaction2 environment

GitHub - open-mmlab/mmaction2: OpenMMLab's Next Generation Video Understanding Toolbox and Benchmark

conda create -n open-mmlab python=3.8 pytorch=1.10 cudatoolkit=11.3 torchvision -c pytorch -y
conda activate open-mmlab
pip3 install openmim
mim install mmcv-full
mim install mmdet  # optional
mim install mmpose  # optional
git clone https://github.com/open-mmlab/mmaction2.git
cd mmaction2
pip3 install -e .

Alternatively, I recommend renting a GPU from AutoDL (AutoDL算力云).

Pick a server in the compute marketplace and choose the mmaction2 image; the required environment is then set up automatically.
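After installation (whether local or on a rented machine), a quick import check confirms the toolboxes are visible. This is a minimal sketch; the version numbers will differ per environment:

# Verify that the core packages import and print their versions.
import torch
import mmcv
import mmaction

print('torch:', torch.__version__, 'cuda available:', torch.cuda.is_available())
print('mmcv:', mmcv.__version__)
print('mmaction2:', mmaction.__version__)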

2. Config file settings

Open the file slowfast_kinetics_pretrained_r50_4x16x1_20e_ava_rgb.py in the mmaction2/configs/detection/ava directory; the config content is as follows:

# model setting
model = dict(
    type='FastRCNN',
    backbone=dict(
        type='ResNet3dSlowFast',
        pretrained=None,
        resample_rate=8,
        speed_ratio=8,
        channel_ratio=8,
        slow_pathway=dict(
            type='resnet3d',
            depth=50,
            pretrained=None,
            lateral=True,
            conv1_kernel=(1, 7, 7),
            dilations=(1, 1, 1, 1),
            conv1_stride_t=1,
            pool1_stride_t=1,
            inflate=(0, 0, 1, 1),
            spatial_strides=(1, 2, 2, 1)),
        fast_pathway=dict(
            type='resnet3d',
            depth=50,
            pretrained=None,
            lateral=False,
            base_channels=8,
            conv1_kernel=(5, 7, 7),
            conv1_stride_t=1,
            pool1_stride_t=1,
            spatial_strides=(1, 2, 2, 1))),
    roi_head=dict(
        type='AVARoIHead',
        bbox_roi_extractor=dict(
            type='SingleRoIExtractor3D',
            roi_layer_type='RoIAlign',
            output_size=8,
            with_temporal_pool=True),
        bbox_head=dict(
            type='BBoxHeadAVA',
            in_channels=2304,
            num_classes=7,
            topk=(1, 6),
            multilabel=True,
            dropout_ratio=0.5)),
    train_cfg=dict(
        rcnn=dict(
            assigner=dict(
                type='MaxIoUAssignerAVA',
                pos_iou_thr=0.9,
                neg_iou_thr=0.9,
                min_pos_iou=0.9),
            sampler=dict(
                type='RandomSampler',
                num=32,
                pos_fraction=1,
                neg_pos_ub=-1,
                add_gt_as_proposals=True),
            pos_weight=1.0,
            debug=False)),
    test_cfg=dict(rcnn=dict(action_thr=0.002)))
 
dataset_type = 'AVADataset'
data_root = '../data/ava/rawframes'
anno_root = '../data/ava/annotations'
 
ann_file_train = f'{anno_root}/ava_train_v2.1.csv'
ann_file_val = f'{anno_root}/ava_val_v2.1.csv'
 
exclude_file_train = f'{anno_root}/ava_train_excluded_timestamps_v2.1.csv'
exclude_file_val = f'{anno_root}/ava_val_excluded_timestamps_v2.1.csv'
 
label_file = f'{anno_root}/ava_action_list_v2.1.pbtxt'
 
proposal_file_train = (f'{anno_root}/ava_dense_proposals_train.FAIR.'
                       'recall_93.9.pkl')
proposal_file_val = f'{anno_root}/ava_dense_proposals_val.FAIR.recall_93.9.pkl'
 
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_bgr=False)
 
train_pipeline = [
    dict(type='SampleAVAFrames', clip_len=32, frame_interval=2),
    dict(type='RawFrameDecode'),
    dict(type='RandomRescale', scale_range=(256, 320)),
    dict(type='RandomCrop', size=256),
    dict(type='Flip', flip_ratio=0.5),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='FormatShape', input_format='NCTHW', collapse=True),
    # Rename is needed to use mmdet detectors
    dict(type='Rename', mapping=dict(imgs='img')),
    dict(type='ToTensor', keys=['img', 'proposals', 'gt_bboxes', 'gt_labels']),
    dict(
        type='ToDataContainer',
        fields=[
            dict(key=['proposals', 'gt_bboxes', 'gt_labels'], stack=False)
        ]),
    dict(
        type='Collect',
        keys=['img', 'proposals', 'gt_bboxes', 'gt_labels'],
        meta_keys=['scores', 'entity_ids'])
]
# The testing is w/o. any cropping / flipping
val_pipeline = [
    dict(
        type='SampleAVAFrames', clip_len=32, frame_interval=2, test_mode=True),
    dict(type='RawFrameDecode'),
    dict(type='Resize', scale=(-1, 256)),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='FormatShape', input_format='NCTHW', collapse=True),
    # Rename is needed to use mmdet detectors
    dict(type='Rename', mapping=dict(imgs='img')),
    dict(type='ToTensor', keys=['img', 'proposals']),
    dict(type='ToDataContainer', fields=[dict(key='proposals', stack=False)]),
    dict(
        type='Collect',
        keys=['img', 'proposals'],
        meta_keys=['scores', 'img_shape'],
        nested=True)
]
 
data = dict(
    videos_per_gpu=5,
    workers_per_gpu=2,
    val_dataloader=dict(videos_per_gpu=1),
    test_dataloader=dict(videos_per_gpu=1),
    train=dict(
        type=dataset_type,
        ann_file=ann_file_train,
        exclude_file=exclude_file_train,
        pipeline=train_pipeline,
        label_file=label_file,
        proposal_file=proposal_file_train,
        person_det_score_thr=0.9,
        num_classes=7,
        start_index=1,
        data_prefix=data_root),
    val=dict(
        type=dataset_type,
        ann_file=ann_file_val,
        exclude_file=exclude_file_val,
        pipeline=val_pipeline,
        label_file=label_file,
        proposal_file=proposal_file_val,
        person_det_score_thr=0.9,
        num_classes=7,
        start_index=1,
        data_prefix=data_root))
data['test'] = data['val']
 
optimizer = dict(type='SGD', lr=0.1125, momentum=0.9, weight_decay=0.00001)
# this lr is used for 8 gpus
 
optimizer_config = dict(grad_clip=dict(max_norm=40, norm_type=2))
# learning policy
 
lr_config = dict(
    policy='step',
    step=[10, 15],
    warmup='linear',
    warmup_by_epoch=True,
    warmup_iters=5,
    warmup_ratio=0.1)
total_epochs = 200
checkpoint_config = dict(interval=1)
workflow = [('train', 1)]
evaluation = dict(interval=1, save_best='mAP@0.5IOU')
log_config = dict(
    interval=20, hooks=[
        dict(type='TextLoggerHook'),
    ])
dist_params = dict(backend='nccl')
log_level = 'INFO'
work_dir = ('./work_dirs/ava/'
            'slowfast_kinetics_pretrained_r50_4x16x1_20e_ava_rgb')
load_from = ('https://download.openmmlab.com/mmaction/recognition/slowfast/'
             'slowfast_r50_4x16x1_256e_kinetics400_rgb/'
             'slowfast_r50_4x16x1_256e_kinetics400_rgb_20200704-bcde7ed7.pth')
resume_from = None
find_unused_parameters = False

Notes:

1. Replace every occurrence of num_classes. I defined 6 behaviours, so num_classes=7, because __background__ has to be counted as well;

2. In topk=(1, 6), keep the 1 as the default and set 6 to the number of behaviours;

3. Check that data_root and anno_root point to your training data;

4. If GPU memory runs out during training, reduce videos_per_gpu;

5. Add start_index=1 to both the train and val dataset dicts;

6. Change total_epochs to set how long to train;

7. load_from can point to a pretrained model (a sketch of applying these overrides programmatically follows below).
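Instead of editing the file by hand, you can also load the config with mmcv and override these fields before training. This is a minimal sketch assuming the mmcv 1.x / mmaction2 0.x code base that this config style belongs to; the output filename is my own choice:

from mmcv import Config

cfg = Config.fromfile('configs/detection/ava/slowfast_kinetics_pretrained_r50_4x16x1_20e_ava_rgb.py')

# 6 behaviours + __background__ -> 7 classes, and topk capped at the number of behaviours.
cfg.model.roi_head.bbox_head.num_classes = 7
cfg.model.roi_head.bbox_head.topk = (1, 6)
cfg.data.train.num_classes = 7
cfg.data.val.num_classes = 7

# Shrink the batch size if GPU memory is tight, and set the number of epochs.
cfg.data.videos_per_gpu = 5
cfg.total_epochs = 200

# Write the modified config out and train from that file instead.
cfg.dump('my_slowfast_ava_config.py')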
 

Then run:

bash tools/dist_train.sh configs/detection/ava/slowfast_kinetics_pretrained_r50_4x16x1_20e_ava_rgb.py 1

The trailing 1 is the number of GPUs; set it to however many you have.

Training results:


Summary

That is the whole process of training SlowFast through mmaction2. A later post will cover the second method: preparing data following the AVA-dataset pipeline and training inside the SlowFast repo itself.
