How to Extract Per-Frame CNN Features from Videos with PyTorch + SVM Classification (Non-End-to-End)

1. Extracting network features

import os
os.environ['CUDA_VISIBLE_DEVICES'] = '3'
os.system('echo $CUDA_VISIBLE_DEVICES')

import torch
import torch.nn as nn
from torchvision import transforms, models
from torch.autograd import Variable

import numpy as np
from PIL import Image

def pre_image(image_path):
    # Standard ImageNet preprocessing: resize, center-crop, normalize with the
    # statistics the torchvision pretrained ResNet expects.
    trans = transforms.Compose([transforms.Resize(256),
                                transforms.CenterCrop(224),
                                transforms.ToTensor(),
                                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                     std=[0.229, 0.224, 0.225])])

    img = Image.open(image_path).convert('RGB')
    img = trans(img)
    x = Variable(torch.unsqueeze(img, dim=0).float(), requires_grad=False).cuda()
    return x

# Pretrained ResNet-152; dropping the final fc layer leaves the 2048-d
# global-average-pooled feature as the network output.
model = models.resnet152(pretrained=True).cuda()
extractor = nn.Sequential(*list(model.children())[:-1])
extractor.eval()  # eval mode so BatchNorm uses its running statistics

feature_path = '/data/FrameFeature/Penn/'
video_path = '/home/UPenn_RGB/frames/'
for video in os.listdir(video_path):
    for frame in os.listdir(os.path.join(video_path, video)):
        image_path = os.path.join(video_path, video, frame)
        x = pre_image(image_path)
        y = extractor(x)
        y = y.data.cpu().numpy().reshape(1, 2048)  # (1, 2048, 1, 1) -> (1, 2048)
        if not os.path.exists(feature_path + video):
            os.mkdir(feature_path + video)
        np.savetxt(feature_path + video + '/' + frame.split('.')[0] + '.txt', y, delimiter=',')
    print(video)
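
The pooling/SVM script in step 2 reads split files Penn_train.txt and Penn_test.txt, whose contents are not shown in the post. Judging from the parsing code (line.strip().split(' ')), each line appears to contain a video directory name and an integer action label separated by a single space; the entries below are only hypothetical placeholders:

0001 1
0002 1
0003 5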

2. Temporal pooling and SVM classifier training

import os
import numpy as np
import h5py
from sklearn import svm

def load_feature(video_name):
    frames = os.listdir(video_name)
    feature = []
    for frame in frames:
        frame_path = os.path.join(video_name,frame)
        feature.append(np.loadtxt(frame_path,delimiter=','))
    feature = np.asarray(feature)
    return feature

def mean_pool(feature):
    return np.mean(feature,axis=0)

def max_pool(feature):
    return np.max(feature,axis=0)

def min_pool(feature):
    return np.min(feature,axis=0)
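
# Note: the commented-out 'Save Data' block below also calls sum_diff_pool()
# and dynamic_pool(), which are not defined in the original post. The two
# functions below are only a sketch of what they might look like:
# sum_diff_pool is assumed to accumulate absolute frame-to-frame differences,
# and dynamic_pool is assumed to be a rank-pooling-style weighted temporal
# average. Swap in the author's actual definitions if you have them.
def sum_diff_pool(feature):
    # Sum of absolute differences between consecutive frame features (T x 2048).
    if feature.shape[0] < 2:
        return np.zeros(feature.shape[1])
    return np.sum(np.abs(np.diff(feature, axis=0)), axis=0)

def dynamic_pool(feature):
    # Weighted temporal average where later frames receive larger weights
    # (weights normalized to sum to 1).
    T = feature.shape[0]
    weights = np.arange(1, T + 1, dtype=np.float64)
    weights /= weights.sum()
    return np.sum(feature * weights[:, None], axis=0)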

if __name__ == '__main__':
    ''' Save Data '''
    # with open('/data/FrameFeature/Penn_train.txt','r') as fp:
    #     mean_feat,max_feat,min_feat,diff_feat,dyna_feat,label=[],[],[],[],[],[]
    #     for line in fp.readlines():
    #         video_name = line.strip().split(' ')[0]
    #         video_label= int(line.strip().split(' ')[1])
    #         video_name = '/data/FrameFeature/Penn/'+video_name
    #         print(video_name + '\ttrain')
    #         feature = load_feature(video_name)
    #         mean_feat.append(mean_pool(feature))
    #         max_feat.append(max_pool(feature))
    #         min_feat.append(min_pool(feature))
    #         diff_feat.append(sum_diff_pool(feature))
    #         dyna_feat.append(dynamic_pool(feature))
    #         label.append(video_label)
    # train_mean = np.asarray(mean_feat); del mean_feat
    # train_max  = np.asarray(max_feat);  del max_feat
    # train_min  = np.asarray(min_feat);  del min_feat
    # train_diff = np.asarray(diff_feat); del diff_feat
    # train_dyna = np.asarray(dyna_feat); del dyna_feat
    # train_label= np.asarray(label);     del label
    # h5file = h5py.File('/data/FrameFeature/Penn_train.h5','w')
    # h5file.create_dataset('train_mean',data=train_mean)
    # h5file.create_dataset('train_max',data=train_max)
    # h5file.create_dataset('train_min',data=train_min)
    # h5file.create_dataset('train_diff',data=train_diff)
    # h5file.create_dataset('train_dyna',data=train_dyna)
    # h5file.create_dataset('train_label',data=train_label)
    # h5file.close()
    #
    #
    # with open('/data/FrameFeature/Penn_test.txt','r') as fp:
    #     mean_feat,max_feat,min_feat,diff_feat,dyna_feat,label=[],[],[],[],[],[]
    #     for line in fp.readlines():
    #         video_name = line.strip().split(' ')[0]
    #         video_label= int(line.strip().split(' ')[1])
    #         video_name = '/data/FrameFeature/Penn/'+video_name
    #         print(video_name + '\ttest')
    #         feature = load_feature(video_name)
    #         mean_feat.append(mean_pool(feature))
    #         max_feat.append(max_pool(feature))
    #         min_feat.append(min_pool(feature))
    #         diff_feat.append(sum_diff_pool(feature))
    #         dyna_feat.append(dynamic_pool(feature))
    #         label.append(video_label)
    # test_mean = np.asarray(mean_feat); del mean_feat
    # test_max  = np.asarray(max_feat);  del max_feat
    # test_min  = np.asarray(min_feat);  del min_feat
    # test_diff = np.asarray(diff_feat); del diff_feat
    # test_dyna = np.asarray(dyna_feat); del dyna_feat
    # test_label= np.asarray(label);     del label
    # h5file = h5py.File('/data/FrameFeature/Penn_test.h5','w')
    # h5file.create_dataset('test_mean',data=test_mean)
    # h5file.create_dataset('test_max',data=test_max)
    # h5file.create_dataset('test_min',data=test_min)
    # h5file.create_dataset('test_diff',data=test_diff)
    # h5file.create_dataset('test_dyna',data=test_dyna)
    # h5file.create_dataset('test_label',data=test_label)
    # h5file.close()
    ''' Read Data '''
    h5file = h5py.File('/data/FrameFeature/Penn_train.h5','r')
    train_mean = h5file['train_mean'][:]
    train_max  = h5file['train_max'][:]
    train_min  = h5file['train_min'][:]
    train_diff = h5file['train_diff'][:]
    train_dyna = h5file['train_dyna'][:]
    train_label= h5file['train_label'][:]
    h5file.close()

    h5file = h5py.File('/data/FrameFeature/Penn_test.h5','r')
    test_mean = h5file['test_mean'][:]
    test_max  = h5file['test_max'][:]
    test_min  = h5file['test_min'][:]
    test_diff = h5file['test_diff'][:]
    test_dyna = h5file['test_dyna'][:]
    test_label= h5file['test_label'][:]
    h5file.close()

    ''' Train SVM '''
    SVM = svm.SVC(kernel='linear')
    # Mean pooling
    SVM.fit(train_mean, train_label)
    print('Mean: ' + str(SVM.score(test_mean, test_label)))
    # Max pooling
    SVM.fit(train_max, train_label)
    print('Max: ' + str(SVM.score(test_max, test_label)))
    # Min pooling
    SVM.fit(train_min, train_label)
    print('Min: ' + str(SVM.score(test_min, test_label)))
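
The diff and dyna features are loaded above but never scored. Assuming the commented 'Save Data' block (together with the pooling sketches from step 2) was used to produce them, they can be evaluated the same way:

    # Sum-of-differences pooling
    SVM.fit(train_diff, train_label)
    print('Diff: ' + str(SVM.score(test_diff, test_label)))
    # Dynamic (rank-pooling-style) pooling
    SVM.fit(train_dyna, train_label)
    print('Dyna: ' + str(SVM.score(test_dyna, test_label)))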
