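Video classification with PyTorch's pretrained models
Pretrained models provided by PyTorch
PyTorch (through torchvision) provides pretrained models for many vision tasks, including image classification, quantized models, semantic segmentation, object detection, and video classification; see the official site for the full list. This post uses the pretrained video classification models to run predictions. Before predicting, we need the Kinetics-400 validation set.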
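Downloading the Kinetics-400 validation set
The full Kinetics-400 dataset is large (about 144.17 GB), so we download only the validation set (13.14 GB). The dataset is distributed via a magnet link.
After extracting the downloaded archive, the videos sit in one folder per class, as in the path val_256/abseiling/0wR5jVB-WPk.mp4 used in the code in this post.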
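A quick way to check that the archive extracted correctly is to count the videos in each class folder. The sketch below assumes the one-folder-per-class layout suggested by the paths in this post; the helper name `count_videos_per_class` is made up for illustration and is not part of torchvision.

```python
from collections import Counter
from pathlib import Path


def count_videos_per_class(root):
    """Count the .mp4 files in each class subfolder of `root`."""
    counts = Counter()
    for class_dir in sorted(Path(root).iterdir()):
        if class_dir.is_dir():
            counts[class_dir.name] = sum(1 for _ in class_dir.glob("*.mp4"))
    return counts


# Path taken from the examples in this post; adjust it to your own location.
root = Path("E:/Kinetics400/val_256")
if root.is_dir():
    counts = count_videos_per_class(root)
    print(f"{len(counts)} classes, {sum(counts.values())} videos")
```

If the class count is far from 400, the download or extraction was probably incomplete.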
Classifying videos with PyTorch's pretrained models
- Classifying with the r3d_18 model
from torchvision.io.video import read_video
from torchvision.models.video import r3d_18, R3D_18_Weights
# Load the video data
vid, _, _ = read_video("E:/Kinetics400/val_256/abseiling/0wR5jVB-WPk.mp4", output_format="TCHW")
# vid, _, _ = read_video("E:/Kinetics400/val_256/doing_nails/1DitEQ5S190.mp4", output_format="TCHW")
# Keep the first 32 frames
vid = vid[:32]
# Select the pretrained model
weights = R3D_18_Weights.DEFAULT
model = r3d_18(weights=weights)
model.eval()
# Initialize the preprocessing transforms
preprocess = weights.transforms()
# Preprocess the frames and add a batch dimension
batch = preprocess(vid).unsqueeze(0)
# Run the model on the batch
prediction = model(batch).squeeze(0).softmax(0)
label = prediction.argmax().item()
score = prediction[label].item()
category_name = weights.meta["categories"][label]
# Print the category and confidence
print(f"{category_name}: {100 * score}%")
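The last few lines turn the model's raw scores into probabilities with softmax and then keep only the single best class. To make that step concrete, here is a minimal plain-Python sketch of softmax followed by a top-k ranking. The logits and the three-label list are toy values for illustration, not real model output, and the helper names are hypothetical.

```python
import math


def softmax(logits):
    """Convert raw scores into probabilities that sum to 1."""
    peak = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_k(probs, categories, k=5):
    """Return the k most probable (category, probability) pairs, best first."""
    ranked = sorted(zip(categories, probs), key=lambda pair: pair[1], reverse=True)
    return ranked[:k]


# Toy values only; a Kinetics-400 model emits 400 scores per clip.
logits = [2.0, 0.5, 1.0]
categories = ["abseiling", "doing nails", "archery"]
for name, prob in top_k(softmax(logits), categories, k=2):
    print(f"{name}: {100 * prob:.1f}%")
```

On the actual `prediction` tensor, `prediction.topk(5)` should give the five highest probabilities and their indices directly, which is often more informative than the top-1 label alone.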
- Classifying with the swin3d_b model
from torchvision.io.video import read_video
from torchvision.models.video import swin3d_b, Swin3D_B_Weights
# vid, _, _ = read_video("E:/Kinetics400/val_256/abseiling/0wR5jVB-WPk.mp4", output_format="TCHW")
vid, _, _ = read_video("E:/Kinetics400/val_256/doing_nails/1DitEQ5S190.mp4", output_format="TCHW")
vid = vid[:32] # optionally shorten duration
# Step 1: Initialize model with the best available weights
weights = Swin3D_B_Weights.DEFAULT
model = swin3d_b(weights=weights)
model.eval()
# Step 2: Initialize the inference transforms
preprocess = weights.transforms()
# Step 3: Apply inference preprocessing transforms
batch = preprocess(vid).unsqueeze(0)
# Step 4: Use the model and print the predicted category
prediction = model(batch).squeeze(0).softmax(0)
label = prediction.argmax().item()
score = prediction[label].item()
category_name = weights.meta["categories"][label]
print(f"{category_name}: {100 * score}%")
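Both examples keep only the first 32 frames via `vid[:32]`, so the model sees just the opening moments of each clip. One alternative worth trying is to spread the sampled frames across the whole video. The sketch below computes evenly spaced frame indices; the helper name is hypothetical, and whether this sampling helps depends on how the pretrained weights were trained, so treat it as an experiment and not as the recommended preprocessing.

```python
def evenly_spaced_indices(num_frames, num_samples):
    """Pick `num_samples` frame indices spread evenly across a clip."""
    if num_frames <= 0 or num_samples <= 0:
        return []
    if num_frames <= num_samples:
        # Clip is shorter than requested: keep every frame.
        return list(range(num_frames))
    step = num_frames / num_samples
    # Take the frame at the centre of each of the equal-length segments.
    return [int(step * i + step / 2) for i in range(num_samples)]


indices = evenly_spaced_indices(300, 32)
print(len(indices), indices[0], indices[-1])
```

With the tensor returned by `read_video`, something like `vid = vid[evenly_spaced_indices(vid.shape[0], 32)]` could replace `vid = vid[:32]`.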
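Evaluating the pretrained models
Judging from a few sample predictions, swin3d_b is more accurate than r3d_18. The official site publishes accuracy figures for each pretrained model.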
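A handful of videos is a thin basis for comparing two models, so it can be worth measuring top-1 accuracy over a sample of the validation folder. The sketch below is one way to do that. `classify` is a hypothetical callable that you would write yourself by wrapping the read/preprocess/predict steps above so that it returns `category_name` for a video path. The underscore-to-space conversion is an assumption about how folder names such as `doing_nails` map to the model's label strings; check it against `weights.meta["categories"]`.

```python
from pathlib import Path


def top1_accuracy(root, classify, limit_per_class=5):
    """Fraction of sampled videos whose predicted label matches the folder name."""
    correct = 0
    total = 0
    for class_dir in sorted(Path(root).iterdir()):
        if not class_dir.is_dir():
            continue
        # Assumption: folders use underscores where the model's labels use spaces.
        expected = class_dir.name.replace("_", " ")
        for video in sorted(class_dir.glob("*.mp4"))[:limit_per_class]:
            total += 1
            if classify(str(video)) == expected:
                correct += 1
    return correct / total if total else 0.0
```

Running this once with an r3d_18-based `classify` and once with a swin3d_b-based one gives two numbers that can be compared directly, on the same videos.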
Closing
You are welcome to join the group chat to learn and discuss technology together!
Bilibili account: Silver__Wolf_
QQ: 130856474