整体结构
整个项目可以分为三个部分
- 数据准备(datasets)
1.1 rgb、flow提取,实现于build_of.py
1.2 帧列表文件生成,将视频包含rgb数量和flow数量与视频所属类别一起分别写进文件,实现于build_file_lists。
1.3 构建网络的输入,实现于ucf.py。(图片预处理缩放、裁剪、翻转video_transforms) - 模型搭建(modles)
2.1 搭建时间网络,flow_resnet.py
2.2 搭建空间网络,rgb_resnet.py
2.3 搭建融合CBAM的空间网络,rgb_resnet_cbam.py
2.4 搭建融合CBAM的时间网络,flow_resnet_cbam.py - 训练、测试模型
main.py导入在ImageNet上预训练好的网络,定义损失函数和优化器,导入数据集,多GPU训练,每训练25次在验证集上测试准确率,记录准确高的模型并将参数保存在best_modle.tar.gz中。我的项目是4块GPU,在空间网络上训练250次,大概10个小时完成。在时间流网络上训练350次,大概需要22个小时。
文件目录结构
│ extract.py #提取一个视频中的rgb帧和flow光流,返回rgb帧数量(rgb和flow数量相同)
│ flask_action.ipynb #基于flask的web框架的wsgi服务器,和客户端进行信息传递
│ main.py #主函数,进行模型的训练和测试
│ README.md #说明文件
│ video_transforms.py #图像的各种预处理方法
│
├─checkpoint #存放训练好的双流网络模型参数
│ flow_resnet152_cbam_model_best.pth.tar(融合注意力机制)
│ flow_resnet152_model_best.pth.tar
│ rgb_resnet152_cbam_model_best.pth.tar(融合注意力机制)
│ rgb_resnet152_model_best.pth.tar
│
├─datasets #对UCF-101数据集的各种处理
│ │ build_file_list.py #根据原始的测试训练数据列表文件生成对应的rgb和flow测试训练数据列表文件,格式视频名称 rgb/flow数量 视频类别
│ │ build_of.py #利用编译好的dense flow光流提取工具对UCF-101中每个视频进行光流和rgb帧提取,存放在ucf101_frames
│ │ ucf101.py #定义ucf101数据集,根据输入的训练、测试数据列表文件构建训练集和测试集,训练集为从每个视频中随机挑选一个rgb帧和10个连续的光流,测试集为每次取中间的rgb帧和中间的10张连续的flow光流。
│ │
│ ├─settings #根据ucf101_splits build_file_list.py生成的rgb、flow测试和训练列表文件
│ │ └─ucf101
│ │ train_flow_split1.txt
│ │ train_flow_split2.txt
│ │ train_flow_split3.txt
│ │ train_rgb_split1.txt
│ │ train_rgb_split2.txt
│ │ train_rgb_split3.txt
│ │ val_flow_split1.txt
│ │ val_flow_split2.txt
│ │ val_flow_split3.txt
│ │ val_rgb_split1.txt
│ │ val_rgb_split2.txt
│ │ val_rgb_split3.txt
│ │
│ └─ucf101_splits #官方给的三种不同的测试训练集的划分方案
│ │ classInd.txt #将index和class类别一一对应,index从1开始
│ │ testlist01.txt
│ │ testlist02.txt
│ │ testlist03.txt
│ │ trainlist01.txt
│ │ trainlist02.txt
│ │ trainlist03.txt
│
├─ucf101_frames
│
├─eval_ucf101 #训练好two-stream后在测试集上测试
│ demo.py #融合两种网络,最后对softmax分数进行融合(平均法)
│ SpatialPrediction.py #空间网络预测
│ spatial_demo.py
│ TemporalPrediction.py #时间网络预测
│ temporal_demo.py
│ __init__.py
│
├─models #搭建two-stream网络结果
│ flow_resnet.py #时间流网络
│ flow_resnet_cbam.py #融合注意力机制的时间流网络
│ rgb_resnet.py #空间流网络
│ rgb_resnet_cbam.py #融合注意力记住的空间流网络
│ __init__.py
│
├─static
│ ├─generated #用来存放客户端上传的视频中的rgb帧和flow光流
│ └─js
│ jquery-3.3.1.min.js
│
├─templates #用来存放传给浏览器的内容
│ .DS_Store
│ index.html
│
└─video #用来存放客户端上传的视频