基于PyTorch搭建CNN实现视频动作分类任务

基于PyTorch搭建CNN实现视频动作分类任务

2022年10月6日

项目介绍

简单实训项目:datafountain.cn

里面有具体的项目说明,并且可以下载数据集。

项目路径

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-3aPi7hsD-1665063823285)(https://cdn.jsdelivr.net/gh/cxy-sky/jat-blog-img/image-20221006200919784.png)]

  • datasets是在官网下载的数据包,
    - [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-0UTjIQ42-1665063823286)(https://cdn.jsdelivr.net/gh/cxy-sky/jat-blog-img/image-20221006201008001.png)]

代码

BasketballDataset.py

继承dataset,适配pytorch的dataloader

import torch.nn as nn
from PIL import Image
import os
from torch.utils.data import DataLoader, sampler, Dataset


# 继承Module类,使用pytorch中设计好的Dataloader作为我们的数据的加载器。
#
# 该加载器能够根据设定在每一次请求时自动加载一批训练数据,
#
# 能够自主实现多线程加载,能够在快速加载的同时尽可能的节省内存开销。

# 而Dataloader类所加载的数据必须是pytorch中定义好的Dataset类,
#
# 所以我们的第一步,就是将我们的数据封装成一个Dataset类。


class BasketballDataset(Dataset):
    def __init__(self, root_dir, labels=[], transform=None):
        super(BasketballDataset, self).__init__()
        self.root_dir = root_dir
        self.labels = labels
        self.transform = transform
        self.length = len(os.listdir(self.root_dir))

    def __len__(self):
        return self.length * 3  # 因为此数据集一个视频有三张图片

    def __getitem__(self, idx):
        folder = idx // 3 + 1  # 得到文件夹名非0部分
        imidex = idx % 3 + 1  # 得到每个文件夹下图片的编号1,2,3
        # 文件夹,命名的文件夹名为5为数字,如00001,所以要进行格式转变
        folder = format(folder, '05d')
        imgname = str(imidex) + '.jpg'
        img_path = os.path.join(self.root_dir, folder, imgname)
        image = Image.open(img_path)

        Label = 0
        # test集没有标签
        if len(self.labels) != 0:
            Label = self.labels[idx // 3][0] - 1
        if self.transform:
            image = self.transform(image)
        if len(self.labels) != 0:
            sample = {'image': image, 'img_path': img_path, "Label": Label}
        else:
            sample = {'image': image, 'img_path': img_path}
        return sample

pre.py

​ 数据读取

from torch.utils.data import DataLoader
import torchvision.transforms as T
import scipy.io

from basketball.basketballDataset import BasketballDataset


def get_dataloader(name):

    # 保存的label的标签
    label_mat = scipy.io.loadmat('../datasets/q3_2_data.mat')
    label_train = label_mat['trLb']
    print('train_len:', len(label_train))
    label_val = label_mat['valLb']
    print('val_len:', len(label_val))

    if name == 'train':
        dataset_train = BasketballDataset(root_dir='../datasets/trainClips',
                                          labels=label_train,
                                          transform=T.ToTensor())

        dataloader_train = DataLoader(dataset_train,
                                      batch_size=32,
                                      shuffle=True,
                                      num_workers=4)
        return dataloader_train

    if name == 'val':
        dataset_val = BasketballDataset(root_dir='../datasets/valClips',
                                        labels=label_val,
                                        transform=T.ToTensor())

        dataloader_val = DataLoader(dataset_val,
                                    batch_size=32,
                                    shuffle=True,
                                    num_workers=4)
        return dataloader_val
    if name == 'test':
        dataset_test = BasketballDataset(root_dir='../datasets/testClips',
                                         transform=T.ToTensor())

        dataloader_test = DataLoader(dataset_test,
                                     batch_size=32,
                                     shuffle=True,
                                     num_workers=4)
        return dataloader_test


if __name__ == '__main__':

    # 保存的label的标签
    label_mat = scipy.io.loadmat('../datasets/q3_2_data.mat')
    label_train = label_mat['trLb']
    print('train_len:', len(label_train))
    label_val = label_mat['valLb']
    print('val_len:', len(label_val))

    # Dataloader类所加载的数据必须是pytorch中定义好的Dataset类,
    # 所以我们的第一步,就是将我们的数据封装成一个Dataset类。
    dataset = BasketballDataset(root_dir='../datasets/trainClips',
                                labels=label_train,
                                transform=T.ToTensor())

    # Dataloader
    dataloader = DataLoader(dataset,
                            batch_size=4,
                            shuffle=True,
                            num_workers=4)

    for i in range(3):
        sample = dataset[i]
        print(sample['image'].shape)
        print(sample['Label'])
        print(sample['img_path'])

    for i, sample in enumerate(dataloader):
        print(i, sample['image'].shape, sample['img_path'], sample['Label'])
        if i > 5:
            break

model.py

​ 继承nn.Module,编写网络

import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import DataLoader, sampler, Dataset
import torchvision.datasets as dset
import torchvision.transforms as T
import timeit
from PIL import Image
import os
import numpy as np
import scipy.io
import torchvision.models.inception as inception


class Cnn(nn.Module):
    def __init__(self, channel=3):
        super(Cnn, self).__init__()
        self.sequential = nn.Sequential(
            nn.Conv2d(channel, 8, kernel_size=7, stride=1),
            nn.BatchNorm2d(8),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, stride=2),
            nn.Conv2d(8, 16, 7, 1),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, stride=2),
            nn.Flatten(),
            nn.ReLU(inplace=True),
            nn.Linear(16 * 11 * 11, 10)
        )

    def forward(self, x):
        return self.sequential(x)


if __name__ == '__main__':
    model = Cnn()
    model = model
    x = torch.randn(32, 3, 64, 64)
    x_var = Variable(x)  # 需要将其封装为Variable类型。
    outputs = model(x_var)
    print(np.array(outputs.size()))  # 检查模型输出。
    np.array_equal(np.array(outputs.size()), np.array([32, 10]))

train.py

​ 模型训练

import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable

from basketball.pre import get_dataloader
from basketball.model import Cnn

dtype = torch.FloatTensor  # 这是pytorch所支持的cpu数据类型中的浮点数类型。

print_every = 100  # 这个参数用于控制loss的打印频率,因为我们需要在训练过程中不断的对loss进行检测。

model = Cnn()
loss_fun = nn.CrossEntropyLoss()
opt_fun = optim.RMSprop(model.parameters(), lr=0.0001)


def train(epoch):
    dataloader_train = get_dataloader('train')
    dataloader_val = get_dataloader('val')
    model.train()
    for i in range(epoch):
        for idx, sample in enumerate(dataloader_train):
            img = Variable(sample['image'])
            labels = Variable(sample['Label'].long())

            output = model(img)
            loss = loss_fun(output, labels)

            opt_fun.zero_grad()
            loss.backward()
            opt_fun.step()

            if idx % 100 == 0:
                model.eval()
                acc_count = 0.0
                all_count = 0
                with torch.no_grad():
                    for idx, sample in enumerate(dataloader_val):
                        img = Variable(sample['image'])
                        labels = Variable(sample['Label'].long())
                        output = model(img)

                        _, pre = torch.max(output, dim=1)
                        acc_count += (pre == labels).sum().item()
                        all_count += labels.size()[0]

                print('第{}轮,第{}次,loss为: {},验证集准确率为:{}'.format(i + 1, idx + 1, loss.item(), acc_count / all_count))


if __name__ == '__main__':
    train(1)
    torch.save(model.state_dict(), '../model/basketball.pth')


predict.py

​ 模型预测test数据集,并保存到result.csv文件中

import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import DataLoader, sampler, Dataset
import torchvision.datasets as dset
import torchvision.transforms as T
import timeit
from PIL import Image
import os
import numpy as np
import scipy.io
import torchvision.models.inception as inception

from basketball.model import Cnn
from basketball.pre import get_dataloader


def predict():
    dataloader_test = get_dataloader('test')
    model = Cnn()
    model.load_state_dict(torch.load('../model/basketball.pth'))

    write_csv = open('../result.csv', 'w')
    count = 0
    write_csv.write('Id' + ',' + 'Class' + '\n')

    model.eval()
    for t, sample, in enumerate(dataloader_test):
        img = Variable(sample['image'])
        output = model(img)
        _, pre = torch.max(output, dim=1)
        for i in range(len(pre)):
            write_csv.write(str(count) + ',' + str(pre[i].item()) + '\n')
            count += 1

    write_csv.close()
    return count


if __name__ == '__main__':
    count = predict()
    print(count)

预测结果

在这里插入图片描述

  • 2
    点赞
  • 11
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值