最近在弄视频分类的项目
其中包括了暴恐识别,我用了https://github.com/swathikirans/violence-recognition-pytorch这个项目代码,将其中的代码修改为符合Pytorch1.4版本的风格,就愉快的训练了,但是发现训练完成后并没有测试代码,于是自己就写了一个,当然借鉴了作者的训练代码,很简单,在此分享出来,如果能帮助大家就更happy了。
import torch
import glob
import os
from createModel import *
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from spatial_transforms import (Compose, ToTensor, FiveCrops, Scale, Normalize, MultiScaleCornerCrop,
RandomHorizontalFlip, TenCrops, FlippedImagesTest, CenterCrop)
from makeDataset import *
import torch.nn.functional as F
def make_split(fights_dir):
imagesF = []
for target in sorted(os.listdir(fights_dir)):
d = os.path.join(fights_dir, target)
if not os.path.isdir(d):
continue
imagesF.append(d)
imagesNoF = []
# for target in sorted(os.listdir(noFights_dir)):
# d = os.path.join(noFights_dir, target)
# if not os.path.isdir(d):
# continue
# imagesNoF.append(d)
Dataset = imagesF + imagesNoF
Labels = list([1] * len(imagesF)) + list([0] * len(imagesNoF))
NumFrames = [len(glob.glob1(Dataset[i], "*.jpg")) for i in range(len(Dataset))]
return Dataset, Labels, NumFrames
img_dir="../test/test2/" #3个文件夹 单个文件夹 分别是0 0 1 中间视频分类模糊.
testDataset, testLabels, testNumFrames = make_split(img_dir)
mean=[0.485, 0.456, 0.406]
std=[0.229, 0.224, 0.225]
test_spatial_transform = Compose([Scale(256), CenterCrop(224), FlippedImagesTest(mean=mean, std=std)])
testBatchSize = 1
vidSeqTest = makeDataset(testDataset, testLabels, testNumFrames, seqLen=20,spatial_transform=test_spatial_transform)
testLoader = torch.utils.data.DataLoader(vidSeqTest, batch_size=testBatchSize,shuffle=False, num_workers=int(4/2), pin_memory=True)
outputs=[]
predicts=[]
model=torch.load("../experiments_violence/bestModel.pth")
for j, (inputs, targets) in enumerate(testLoader):
model.eval()
inputVariable1 = Variable(inputs[0].cuda())
outputLabel = model(inputVariable1)
outputLabel_mean = torch.mean(outputLabel, 0, True)
_, predicted = torch.max(outputLabel_mean.data, 1) #每行的最大值 predicted包含最大值以及索引. _是最大值,predicted是索引,分类问题中,打印视频所属类别就可以了。
predicts.append(predicted.cpu().data.numpy())
print(predicts) #[0,0,1] 分别是3个视频序列帧的所属分类.
我测试了3个视频,分别是非暴力,非暴力(本来是暴力的后面打斗的帧没抽),暴力。这个项目来自论文https://arxiv.org/abs/1709.06531 Learning to Detect Violent Videos using Convolutional Long Short-Term Memory。我训练的识别率达到了%94,效果还行。。数据集来自https://www.pianshen.com/article/91361368135/ 中的
训练集:测试集=4:1
有什么不懂的评论区吧。