使用pytorch提供的预训练模型做视频分类

pytorch提供的预训练模型

  pytorch提供图片分类、模型量化、语义分割、目标检测、视频分类等多种视觉任务的预训练模型,官网地址。本博客使用pytorch提供的视频分类预训练模型做预测。在进行预测之前我们需要获得Kinetics-400 的验证数据集。

下载Kinetics-400 的验证数据集

  Kinetics-400的总体数据集比较大(大约144.17G),我们主要下载验证集(13.14G)。数据集磁力链
在这里插入图片描述
  下载好的压缩包解压如下:
在这里插入图片描述

调用pytorch预训练模型做分类

  • 调用r3d_18模型做分类
from torchvision.io.video import read_video
from torchvision.models.video import r3d_18, R3D_18_Weights
# 加载视频数据
vid, _, _ = read_video("E:/Kinetics400/val_256/abseiling/0wR5jVB-WPk.mp4", output_format="TCHW")
# vid, _, _ = read_video("E:/Kinetics400/val_256/doing_nails/1DitEQ5S190.mp4", output_format="TCHW")
# 选择帧数据
vid = vid[:32] 

# 选择预训练模型
weights = R3D_18_Weights.DEFAULT
model = r3d_18(weights=weights)
model.eval()

# 初始化数据预处理模块
preprocess = weights.transforms()
# 数据预处理
batch = preprocess(vid).unsqueeze(0)

# 将数据放入模型做预测
prediction = model(batch).squeeze(0).softmax(0)
label = prediction.argmax().item()
score = prediction[label].item()
category_name = weights.meta["categories"][label]
# 打印类别和置信度
print(f"{category_name}: {100 * score}%")
  • 调用swin3d_b模型做分类
from torchvision.io.video import read_video
from torchvision.models.video import swin3d_b, Swin3D_B_Weights

# vid, _, _ = read_video("E:/Kinetics400/val_256/abseiling/0wR5jVB-WPk.mp4", output_format="TCHW")
vid, _, _ = read_video("E:/Kinetics400/val_256/doing_nails/1DitEQ5S190.mp4", output_format="TCHW")
vid = vid[:32]  # optionally shorten duration

# Step 1: Initialize model with the best available weights
weights = Swin3D_B_Weights.DEFAULT
model = swin3d_b(weights=weights)
model.eval()

# Step 2: Initialize the inference transforms
preprocess = weights.transforms()

# Step 3: Apply inference preprocessing transforms
batch = preprocess(vid).unsqueeze(0)

# Step 4: Use the model and print the predicted category
prediction = model(batch).squeeze(0).softmax(0)
label = prediction.argmax().item()
score = prediction[label].item()
category_name = weights.meta["categories"][label]
print(f"{category_name}: {100 * score}%")

预训练模型结果评估

  根据部分可视化的结果得出结论:swin3d_b的预测准确度高于r3d_18。官网关于各类预训练模型的准确度数据如下:
在这里插入图片描述

结尾

欢迎加入群聊一起学习、讨论技术!
B站账号:Silver__Wolf_
Q:130856474

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值