violence-recognition-pytorch的测试代码

最近在弄视频分类的项目
其中包括了暴恐识别,我用了https://github.com/swathikirans/violence-recognition-pytorch这个项目代码,将其中的代码修改为符合Pytorch1.4版本的风格,就愉快的训练了,但是发现训练完成后并没有测试代码,于是自己就写了一个,当然借鉴了作者的训练代码,很简单,在此分享出来,如果能帮助大家就更happy了。

import torch
import glob
import os
from createModel import *

from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from spatial_transforms import (Compose, ToTensor, FiveCrops, Scale, Normalize, MultiScaleCornerCrop,
                                RandomHorizontalFlip, TenCrops, FlippedImagesTest, CenterCrop)
from makeDataset import *
import torch.nn.functional as F

def make_split(fights_dir):
    imagesF = []
    for target in sorted(os.listdir(fights_dir)):
        d = os.path.join(fights_dir, target)
        if not os.path.isdir(d):
            continue
        imagesF.append(d)
    imagesNoF = []
    # for target in sorted(os.listdir(noFights_dir)):
    #     d = os.path.join(noFights_dir, target)
    #     if not os.path.isdir(d):
    #         continue
    #     imagesNoF.append(d)
    Dataset = imagesF + imagesNoF
    Labels = list([1] * len(imagesF)) + list([0] * len(imagesNoF))
    NumFrames = [len(glob.glob1(Dataset[i], "*.jpg")) for i in range(len(Dataset))]
    return Dataset, Labels, NumFrames

img_dir="../test/test2/" #3个文件夹 单个文件夹 分别是0 0 1 中间视频分类模糊.
testDataset, testLabels, testNumFrames = make_split(img_dir)

mean=[0.485, 0.456, 0.406]
std=[0.229, 0.224, 0.225]
test_spatial_transform = Compose([Scale(256), CenterCrop(224), FlippedImagesTest(mean=mean, std=std)])
testBatchSize = 1
vidSeqTest = makeDataset(testDataset, testLabels, testNumFrames, seqLen=20,spatial_transform=test_spatial_transform)
testLoader = torch.utils.data.DataLoader(vidSeqTest, batch_size=testBatchSize,shuffle=False, num_workers=int(4/2), pin_memory=True)

outputs=[]
predicts=[]
model=torch.load("../experiments_violence/bestModel.pth")
for j, (inputs, targets) in enumerate(testLoader):
    model.eval()
    inputVariable1 = Variable(inputs[0].cuda())
    outputLabel = model(inputVariable1)

   
    outputLabel_mean = torch.mean(outputLabel, 0, True) 

    _, predicted = torch.max(outputLabel_mean.data, 1) #每行的最大值 predicted包含最大值以及索引. _是最大值,predicted是索引,分类问题中,打印视频所属类别就可以了。

    predicts.append(predicted.cpu().data.numpy())

print(predicts)  #[0,0,1] 分别是3个视频序列帧的所属分类. 

我测试了3个视频,分别是非暴力,非暴力(本来是暴力的后面打斗的帧没抽),暴力。这个项目来自论文https://arxiv.org/abs/1709.06531 Learning to Detect Violent Videos using Convolutional Long Short-Term Memory。我训练的识别率达到了%94,效果还行。。数据集来自https://www.pianshen.com/article/91361368135/ 中的在这里插入图片描述
训练集:测试集=4:1
有什么不懂的评论区吧。

  • 1
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值