1.准备数据
类似ucf101数据集格式
1.1 生成标注文件 classInd.txt、trainlist.txt 和 testlist.txt
#makelabel.py
import os
# Root folder of the videos: one sub-folder per class, video files inside.
baseDir = r"D:\Downloads\PaddleVideo-develop\posvideo0/videos"
# Output folder for the generated annotation files.
targetDir = r"D:\Downloads\PaddleVideo-develop\posvideo0/annotations"


def make_labels(base_dir=baseDir, target_dir=targetDir):
    """Write classInd.txt, trainlist.txt and testlist.txt for a UCF101-style tree.

    Every sub-directory of ``base_dir`` is one class. Classes are numbered from
    1 in sorted name order, so the numbering is reproducible across machines
    (raw ``os.listdir`` order is arbitrary). Plain files next to the class
    folders (e.g. ``desktop.ini``) are ignored.

    Line formats:
        classInd.txt   "<index> <class>"
        trainlist.txt  "<class>/<video> <index>"
        testlist.txt   "<class>/<video>"

    Note: trainlist.txt and testlist.txt both list *every* video, so there is
    no held-out validation data. Split them by hand if you need one.

    Args:
        base_dir: folder containing one sub-folder per class.
        target_dir: folder the three text files are written to (created if
            missing).

    Returns:
        dict mapping class name -> 1-based class index.
    """
    os.makedirs(target_dir, exist_ok=True)

    class_names = sorted(
        name for name in os.listdir(base_dir)
        if os.path.isdir(os.path.join(base_dir, name))
    )
    labels = {name: idx for idx, name in enumerate(class_names, start=1)}

    with open(os.path.join(target_dir, "classInd.txt"), "w") as f:
        for name in class_names:
            f.write("{} {}\n".format(labels[name], name))
    print(labels)

    # List each class's videos once and reuse the listing for both files.
    videos = {name: sorted(os.listdir(os.path.join(base_dir, name)))
              for name in class_names}

    with open(os.path.join(target_dir, "trainlist.txt"), "w") as f:
        for name in class_names:
            for video in videos[name]:
                f.write("{}/{} {}\n".format(name, video, labels[name]))

    with open(os.path.join(target_dir, "testlist.txt"), "w") as f:
        for name in class_names:
            for video in videos[name]:
                f.write("{}/{}\n".format(name, video))

    return labels


if __name__ == "__main__":
    make_labels()
下载模型代码:https://github.com/mit-han-lab/temporal-shift-module
1.2 对视频进行抽帧:修改 vid2img_ucf101.py 如下
from __future__ import print_function, division
import os
import sys
import subprocess
def class_process(dir_path, dst_dir_path, class_name, exts=('.avi',)):
    """Extract JPEG frames with ffmpeg for every video of one class.

    Frames of ``<dir_path>/<class_name>/<video><ext>`` are written to
    ``<dst_dir_path>/<class_name>/<video>/image_00001.jpg, ...``, scaled to a
    height of 480 px.

    A video is skipped when its output folder already holds
    ``image_00001.jpg``. An output folder without that first frame is treated
    as a failed/partial extraction: it is deleted and extracted again.

    Args:
        dir_path: root folder holding one sub-folder per class.
        dst_dir_path: root folder for the extracted frames (created if missing).
        class_name: name of the class sub-folder to process.
        exts: video file extensions to process, matched case-insensitively
            (e.g. ``('.avi', '.mp4')``).

    Returns:
        None. Returns immediately if ``class_name`` is not a directory.
    """
    import shutil

    class_path = os.path.join(dir_path, class_name)
    if not os.path.isdir(class_path):
        return

    dst_class_path = os.path.join(dst_dir_path, class_name)
    # makedirs also creates dst_dir_path itself if it does not exist yet.
    os.makedirs(dst_class_path, exist_ok=True)

    wanted_exts = tuple(e.lower() for e in exts)
    for file_name in os.listdir(class_path):
        name, ext = os.path.splitext(file_name)
        if ext.lower() not in wanted_exts:
            continue
        dst_directory_path = os.path.join(dst_class_path, name)
        video_file_path = os.path.join(class_path, file_name)

        try:
            if os.path.exists(dst_directory_path):
                if os.path.exists(os.path.join(dst_directory_path, 'image_00001.jpg')):
                    continue  # already extracted
                # shutil.rmtree works on every OS; the former `rm -r` shell
                # call does not exist on Windows.
                shutil.rmtree(dst_directory_path)
                print('remove {}'.format(dst_directory_path))
            os.mkdir(dst_directory_path)
        except OSError as err:
            print(dst_directory_path, err)
            continue

        # Argument list instead of a shell string: no quoting/injection issues
        # with spaces or special characters in file names.
        cmd = ['ffmpeg', '-i', video_file_path, '-vf', 'scale=-1:480',
               os.path.join(dst_directory_path, 'image_%05d.jpg')]
        print(' '.join(cmd))
        subprocess.call(cmd)
        print('\n')
if __name__ == "__main__":
    # Root folder of the source videos (one sub-folder per class).
    dir_path = r'D:\learning\ActionSeqmentation/temporal-shift-module-master\data3/videos'
    # Folder the extracted frames are written to.
    dst_dir_path = r'D:\learning\ActionSeqmentation/temporal-shift-module-master\data3/rawframes'
    # Create the output root up front; os.mkdir of a class folder below a
    # missing root would otherwise fail.
    os.makedirs(dst_dir_path, exist_ok=True)
    # Sorted so classes are processed (and logged) in a stable order.
    for class_name in sorted(os.listdir(dir_path)):
        class_process(dir_path, dst_dir_path, class_name)
如果视频是 mp4 格式,请将代码中的 .avi 改为 .mp4。
结果:
1.3 生成帧列表文件,用于训练(据说直接用抽好的帧训练比读取视频更快)
打开gen_label_ucf101.py文件,修改路径
import os
import glob
import fnmatch
import random
import sys
# Prefix for the extracted-frame folders. It is written into the generated
# list files, which are read when the training dataset is built.
root = r"data4/rawframes"


def parse_ucf_splits(
        class_ind_file=r'D:\learning\ActionSeqmentation/temporal-shift-module-master\data4/annotations\classInd.txt',
        train_list_file=r'D:\learning\ActionSeqmentation\temporal-shift-module-master\data4\annotations\trainlist.txt',
        test_list_file=r'D:\learning\ActionSeqmentation\temporal-shift-module-master\data4\annotations\testlist.txt',
        num_splits=3):
    """Parse the class index and train/test list files into split records.

    Args:
        class_ind_file: lines of "<1-based index> <class name>".
        train_list_file: lines of "<class>/<video file> <index>".
        test_list_file: lines of "<class>/<video file>".
        num_splits: number of (train, test) pairs to return. The list files
            carry no split number, so every pair holds the same data; three
            are returned by default to mirror UCF101's three splits.

    Returns:
        list of ``(train_records, test_records)`` tuples. Each record is
        ``(video_name_without_extension, 0-based class index)``.
    """
    with open(class_ind_file) as f:
        # maxsplit=1 keeps class names that contain spaces intact.
        class_ind = [line.strip().split(None, 1) for line in f if line.strip()]
    class_mapping = {x[1]: int(x[0]) - 1 for x in class_ind}

    def line2rec(line):
        items = line.strip().split('/')
        label = class_mapping[items[0]]
        # splitext only strips the last extension, so names like "a.b.avi"
        # keep their inner dot; for "x.avi 1" the ".avi 1" tail is dropped.
        vid = os.path.splitext(items[1])[0]
        return vid, label

    def read_records(path):
        with open(path) as f:
            return [line2rec(line) for line in f if line.strip()]

    train_list = read_records(train_list_file)
    test_list = read_records(test_list_file)
    return [(list(train_list), list(test_list)) for _ in range(num_splits)]
# Registry of dataset name -> parsed splits. Parsing happens eagerly, at
# import time.
split_parsers = {'ucf101': parse_ucf_splits()}


def parse_split_file(dataset):
    """Return the parsed ``(train, test)`` splits registered for ``dataset``.

    Raises:
        KeyError: if ``dataset`` is not in ``split_parsers``.
    """
    return tuple(split_parsers[dataset])
def parse_directory(path, rgb_prefix='image_', flow_x_prefix='flow_x_', flow_y_prefix='flow_y_'):
    """Parse directories holding extracted frames from standard benchmarks.

    Expects the layout ``<path>/<class>/<video>/<frame files>``. Files that are
    not directories at the video level are ignored.

    Args:
        path: root folder of the extracted frames.
        rgb_prefix: filename prefix of RGB frames.
        flow_x_prefix: filename prefix of horizontal optical-flow frames.
        flow_y_prefix: filename prefix of vertical optical-flow frames.

    Returns:
        ``(dir_dict, rgb_counts, flow_counts)``: three dicts keyed by video
        folder name, giving its full path, its RGB frame count and its
        flow-x frame count.

    Raises:
        ValueError: if a video has different numbers of flow-x and flow-y frames.
    """
    print('parse frames under folder {}'.format(path))
    frame_folders = [
        video_dir
        for class_dir in glob.glob(os.path.join(path, '*'))
        for video_dir in glob.glob(os.path.join(class_dir, '*'))
        if os.path.isdir(video_dir)  # os.listdir below needs a directory
    ]

    def count_files(directory, prefix_list):
        lst = os.listdir(directory)
        return [len(fnmatch.filter(lst, x + '*')) for x in prefix_list]

    rgb_counts = {}
    flow_counts = {}
    dir_dict = {}
    for i, f in enumerate(frame_folders):
        rgb_cnt, x_cnt, y_cnt = count_files(f, (rgb_prefix, flow_x_prefix, flow_y_prefix))
        # basename handles the platform's separators. The old split('\\')
        # only worked on Windows and keyed by the full path elsewhere.
        k = os.path.basename(os.path.normpath(f))
        rgb_counts[k] = rgb_cnt
        dir_dict[k] = f
        if x_cnt != y_cnt:
            raise ValueError('x and y direction have different number of flow images. video: ' + f)
        flow_counts[k] = x_cnt
        if i % 200 == 0:
            print('{} videos parsed'.format(i))

    print('frame folder analysis done')
    return dir_dict, rgb_counts, flow_counts
def build_split_list(split_tuple, frame_info, split_idx, shuffle=False, root_dir=None):
    """Build the text lines of the RGB and flow list files for one split.

    Each line is ``"<root>/<class>/<video> <frame count> <label>\\n"``.

    Args:
        split_tuple: sequence of ``(train_records, test_records)`` pairs, with
            records of the form ``(video_name, label)``.
        frame_info: ``(dir_dict, rgb_counts, flow_counts)`` as returned by
            ``parse_directory``.
        split_idx: which pair of ``split_tuple`` to use.
        shuffle: shuffle each generated list in place.
        root_dir: prefix written in front of ``<class>/<video>``. Defaults to
            the module-level ``root``.

    Returns:
        ``((train_rgb, val_rgb), (train_flow, val_flow))`` lists of lines.

    Raises:
        KeyError: if a listed video has no extracted-frame folder.
    """
    prefix = root if root_dir is None else root_dir
    split = split_tuple[split_idx]
    dir_dict, rgb_counts, flow_counts = frame_info

    def build_set_list(set_list):
        rgb_list, flow_list = [], []
        for vid, label in set_list:
            # Take the last two path components portably; split('\\') only
            # worked for Windows paths.
            frame_dir = os.path.normpath(dir_dict[vid])
            class_name = os.path.basename(os.path.dirname(frame_dir))
            video_name = os.path.basename(frame_dir)
            rel_dir = prefix + '/' + class_name + '/' + video_name
            rgb_list.append('{} {} {}\n'.format(rel_dir, rgb_counts[vid], label))
            flow_list.append('{} {} {}\n'.format(rel_dir, flow_counts[vid], label))
        if shuffle:
            random.shuffle(rgb_list)
            random.shuffle(flow_list)
        return rgb_list, flow_list

    train_rgb_list, train_flow_list = build_set_list(split[0])
    test_rgb_list, test_flow_list = build_set_list(split[1])
    return (train_rgb_list, test_rgb_list), (train_flow_list, test_flow_list)
# Parse the annotation splits and the extracted-frame folders, then write one
# pair of RGB list files (train / val) per split into out_path.
spl = parse_split_file('ucf101')
f_info = parse_directory(r"D:\learning\ActionSeqmentation\temporal-shift-module-master\data4\rawframes")  # extracted frames
out_path = r"D:\learning\ActionSeqmentation\temporal-shift-module-master\data4"  # where the list files go
dataset = "ucf101"
# Iterate over exactly the splits that exist; the former max(3, len(spl))
# would index past the end if fewer than three splits were parsed.
for i in range(len(spl)):
    rgb_lists, _flow_lists = build_split_list(spl, f_info, i)
    with open(os.path.join(out_path, '{}_rgb_train_split_{}.txt'.format(dataset, i + 1)), 'w') as f:
        f.writelines(rgb_lists[0])
    with open(os.path.join(out_path, '{}_rgb_val_split_{}.txt'.format(dataset, i + 1)), 'w') as f:
        f.writelines(rgb_lists[1])
    # Flow lists are not written because only RGB frames are extracted here.
    # To enable them, write _flow_lists[0] / _flow_lists[1] to
    # '{}_flow_train_split_{}.txt' / '{}_flow_val_split_{}.txt' the same way.
2.修改配置文件
下载预训练权重(我是用迅雷下载的):https://hanlab.mit.edu/projects/tsm/models/TSM_kinetics_RGB_resnet50_shift8_blockres_avg_segment8_e100_dense.pth
修改 ops/dataset_config.py 文件:
# Code for "TSM: Temporal Shift Module for Efficient Video Understanding"
# arXiv:1811.08383
# Ji Lin*, Chuang Gan, Song Han
# {jilin, songhan}@mit.edu, ganchuang@csail.mit.edu
import os
# Project root. return_dataset() joins the relative list/category paths
# returned below onto it.
ROOT_DATASET = r'D:\learning\ActionSeqmentation/temporal-shift-module-master'  # upstream: '/data/jilin/'


def return_ucf101(modality):
    """Return the UCF101-style dataset configuration for ``modality``.

    Returns:
        ``(categories_file, train_list, val_list, frame_root, filename_template)``.
        The first three are relative to ``ROOT_DATASET``.

    Raises:
        NotImplementedError: for a modality other than 'RGB' or 'Flow'.
    """
    filename_categories = 'data4/annotations/classInd.txt'
    if modality == 'RGB':
        # The list files written by gen_label_ucf101.py already start with
        # 'data4/rawframes/<class>/<video>', so frames resolve from the root.
        root_data = ROOT_DATASET + '/'
        filename_imglist_train = 'data4/ucf101_rgb_train_split_1.txt'
        filename_imglist_val = 'data4/ucf101_rgb_val_split_1.txt'
        prefix = 'image_{:05d}.jpg'
    elif modality == 'Flow':
        # os.path.join inserts the separator ROOT_DATASET lacks; plain "+"
        # produced ".../temporal-shift-module-masterUCF101/jpg".
        root_data = os.path.join(ROOT_DATASET, 'UCF101/jpg')
        filename_imglist_train = 'UCF101/file_list/ucf101_flow_train_split_1.txt'
        filename_imglist_val = 'UCF101/file_list/ucf101_flow_val_split_1.txt'
        prefix = 'flow_{}_{:05d}.jpg'
    else:
        raise NotImplementedError('no such modality:' + modality)
    return filename_categories, filename_imglist_train, filename_imglist_val, root_data, prefix
def return_hmdb51(modality):
    """Return the HMDB51 dataset configuration for ``modality``.

    The class count (51) is returned in place of a category file.

    Raises:
        NotImplementedError: for a modality other than 'RGB' or 'Flow'.
    """
    filename_categories = 51
    if modality == 'RGB':
        # os.path.join works whether or not ROOT_DATASET ends with a separator.
        root_data = os.path.join(ROOT_DATASET, 'HMDB51/images')
        filename_imglist_train = 'HMDB51/splits/hmdb51_rgb_train_split_1.txt'
        filename_imglist_val = 'HMDB51/splits/hmdb51_rgb_val_split_1.txt'
        prefix = 'img_{:05d}.jpg'
    elif modality == 'Flow':
        root_data = os.path.join(ROOT_DATASET, 'HMDB51/images')
        filename_imglist_train = 'HMDB51/splits/hmdb51_flow_train_split_1.txt'
        filename_imglist_val = 'HMDB51/splits/hmdb51_flow_val_split_1.txt'
        prefix = 'flow_{}_{:05d}.jpg'
    else:
        raise NotImplementedError('no such modality:' + modality)
    return filename_categories, filename_imglist_train, filename_imglist_val, root_data, prefix
def return_something(modality):
    """Return the Something-Something v1 dataset configuration for ``modality``.

    Raises:
        NotImplementedError: for a modality other than 'RGB' or 'Flow'.
    """
    filename_categories = 'something/v1/category.txt'
    if modality == 'RGB':
        # os.path.join works whether or not ROOT_DATASET ends with a separator.
        root_data = os.path.join(ROOT_DATASET, 'something/v1/20bn-something-something-v1')
        filename_imglist_train = 'something/v1/train_videofolder.txt'
        filename_imglist_val = 'something/v1/val_videofolder.txt'
        prefix = '{:05d}.jpg'
    elif modality == 'Flow':
        root_data = os.path.join(ROOT_DATASET, 'something/v1/20bn-something-something-v1-flow')
        filename_imglist_train = 'something/v1/train_videofolder_flow.txt'
        filename_imglist_val = 'something/v1/val_videofolder_flow.txt'
        prefix = '{:06d}-{}_{:05d}.jpg'
    else:
        # Carry the message in the exception, like the other return_* helpers.
        raise NotImplementedError('no such modality:' + modality)
    return filename_categories, filename_imglist_train, filename_imglist_val, root_data, prefix
def return_somethingv2(modality):
    """Return the Something-Something v2 dataset configuration for ``modality``.

    Raises:
        NotImplementedError: for a modality other than 'RGB' or 'Flow'.
    """
    filename_categories = 'something/v2/category.txt'
    if modality == 'RGB':
        # os.path.join works whether or not ROOT_DATASET ends with a separator.
        root_data = os.path.join(ROOT_DATASET, 'something/v2/20bn-something-something-v2-frames')
        filename_imglist_train = 'something/v2/train_videofolder.txt'
        filename_imglist_val = 'something/v2/val_videofolder.txt'
        prefix = '{:06d}.jpg'
    elif modality == 'Flow':
        root_data = os.path.join(ROOT_DATASET, 'something/v2/20bn-something-something-v2-flow')
        filename_imglist_train = 'something/v2/train_videofolder_flow.txt'
        filename_imglist_val = 'something/v2/val_videofolder_flow.txt'
        prefix = '{:06d}.jpg'
    else:
        raise NotImplementedError('no such modality:' + modality)
    return filename_categories, filename_imglist_train, filename_imglist_val, root_data, prefix
def return_jester(modality):
    """Return the Jester dataset configuration (RGB only).

    Raises:
        NotImplementedError: for any modality other than 'RGB'.
    """
    filename_categories = 'jester/category.txt'
    if modality == 'RGB':
        prefix = '{:05d}.jpg'
        # os.path.join works whether or not ROOT_DATASET ends with a separator.
        root_data = os.path.join(ROOT_DATASET, 'jester/20bn-jester-v1')
        filename_imglist_train = 'jester/train_videofolder.txt'
        filename_imglist_val = 'jester/val_videofolder.txt'
    else:
        raise NotImplementedError('no such modality:' + modality)
    return filename_categories, filename_imglist_train, filename_imglist_val, root_data, prefix
def return_kinetics(modality):
    """Return the Kinetics dataset configuration (RGB only).

    The class count (400) is returned in place of a category file.

    Raises:
        NotImplementedError: for any modality other than 'RGB'.
    """
    filename_categories = 400
    if modality == 'RGB':
        # os.path.join works whether or not ROOT_DATASET ends with a separator.
        root_data = os.path.join(ROOT_DATASET, 'kinetics/images')
        filename_imglist_train = 'kinetics/labels/train_videofolder.txt'
        filename_imglist_val = 'kinetics/labels/val_videofolder.txt'
        prefix = 'img_{:05d}.jpg'
    else:
        raise NotImplementedError('no such modality:' + modality)
    return filename_categories, filename_imglist_train, filename_imglist_val, root_data, prefix
def return_dataset(dataset, modality):
    """Resolve the class count, list files, frame root and filename template.

    Args:
        dataset: one of 'jester', 'something', 'somethingv2', 'ucf101',
            'hmdb51', 'kinetics'.
        modality: passed through to the dataset-specific ``return_*`` helper.

    Returns:
        ``(n_class, train_list_path, val_list_path, root_data, prefix)``. The
        list paths are joined onto ``ROOT_DATASET``.

    Raises:
        ValueError: if ``dataset`` is unknown.
    """
    dict_single = {'jester': return_jester, 'something': return_something, 'somethingv2': return_somethingv2,
                   'ucf101': return_ucf101, 'hmdb51': return_hmdb51,
                   'kinetics': return_kinetics}
    if dataset not in dict_single:
        raise ValueError('Unknown dataset ' + dataset)
    file_categories, file_imglist_train, file_imglist_val, root_data, prefix = dict_single[dataset](modality)

    file_imglist_train = os.path.join(ROOT_DATASET, file_imglist_train)
    file_imglist_val = os.path.join(ROOT_DATASET, file_imglist_val)
    if isinstance(file_categories, str):
        file_categories = os.path.join(ROOT_DATASET, file_categories)
        with open(file_categories) as f:
            # Skip blank lines (e.g. a trailing empty line) so they are not
            # counted as extra classes.
            categories = [line.rstrip() for line in f if line.strip()]
    else:  # an int: number of categories
        categories = [None] * file_categories
    n_class = len(categories)
    print('{}: {} classes'.format(dataset, n_class))
    return n_class, file_imglist_train, file_imglist_val, root_data, prefix
更改opts.py
# Code for "TSM: Temporal Shift Module for Efficient Video Understanding"
# arXiv:1811.08383
# Ji Lin*, Chuang Gan, Song Han
# {jilin, songhan}@mit.edu, ganchuang@csail.mit.edu
import argparse
parser = argparse.ArgumentParser(description="PyTorch implementation of Temporal Segment Networks")
parser.add_argument('--dataset',
                    default='ucf101',
                    type=str)
parser.add_argument('--modality', type=str,
                    default='RGB',
                    choices=['RGB', 'Flow'])
# NOTE: main.py overwrites train_list / val_list / root_path with the values
# returned by ops/dataset_config.return_dataset, so the paths there win.
parser.add_argument('--train_list', type=str, default="data4/ucf101_rgb_train_split_1.txt")
parser.add_argument('--val_list', type=str, default="data4/ucf101_rgb_val_split_1.txt")
parser.add_argument('--root_path', type=str, default=r"D:\learning\ActionSeqmentation/temporal-shift-module-master")
# ========================= Model Configs ==========================
parser.add_argument('--arch', type=str, default="resnet50")
parser.add_argument('--num_segments', type=int, default=8)
parser.add_argument('--consensus_type', type=str, default='avg')
parser.add_argument('--k', type=int, default=3)

parser.add_argument('--dropout', '--do', default=0.8, type=float,
                    metavar='DO', help='dropout ratio (default: 0.8)')
parser.add_argument('--loss_type', type=str, default="nll",
                    choices=['nll'])
parser.add_argument('--img_feature_dim', default=256, type=int, help="the feature dimension for each frame")
parser.add_argument('--suffix', type=str, default=None)
parser.add_argument('--pretrain', type=str, default='imagenet')
parser.add_argument('--tune_from', type=str,
                    default=r'D:\learning\ActionSeqmentation/temporal-shift-module-master\pretrain\TSM_kinetics_RGB_resnet50_shift8_blockres_avg_segment8_e100_dense.pth',
                    help='fine-tune from checkpoint')

# ========================= Learning Configs ==========================
parser.add_argument('--epochs', default=20, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('-b', '--batch-size', default=4, type=int,
                    metavar='N', help='mini-batch size (default: 4)')
parser.add_argument('--lr', '--learning-rate', default=0.001, type=float,
                    metavar='LR', help='initial learning rate')
parser.add_argument('--lr_type', default='step', type=str,
                    metavar='LRtype', help='learning rate type')
parser.add_argument('--lr_steps', default=[10, 20], type=float, nargs="+",
                    metavar='LRSteps', help='epochs to decay learning rate by 10')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                    help='momentum')
parser.add_argument('--weight-decay', '--wd', default=5e-4, type=float,
                    metavar='W', help='weight decay (default: 5e-4)')
parser.add_argument('--clip-gradient', '--gd', default=None, type=float,
                    metavar='W', help='gradient norm clipping (default: disabled)')
parser.add_argument('--no_partialbn', '--npb', default=False, action="store_true")

# ========================= Monitor Configs ==========================
parser.add_argument('--print-freq', '-p', default=20, type=int,
                    metavar='N', help='print frequency (default: 20)')
parser.add_argument('--eval-freq', '-ef', default=1, type=int,
                    metavar='N', help='evaluation frequency (default: 1)')

# ========================= Runtime Configs ==========================
parser.add_argument('-j', '--workers', default=16, type=int, metavar='N',
                    help='number of data loading workers (default: 16)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                    help='evaluate model on validation set')
parser.add_argument('--snapshot_pref', type=str, default="")
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('--gpus', nargs='+', type=int, default=None)
parser.add_argument('--flow_prefix', default="", type=str)
parser.add_argument('--root_log', type=str, default='log')
parser.add_argument('--root_model', type=str, default='checkpoint')

# --shift defaults to True, so the store_true flag alone can never turn it
# off. --no_shift writes False into the same destination.
parser.add_argument('--shift', default=True, action="store_true", help='use shift for models')
parser.add_argument('--no_shift', dest='shift', action='store_false', help='disable temporal shift')
parser.add_argument('--shift_div', default=8, type=int, help='number of div for shift (default: 8)')
parser.add_argument('--shift_place', default='blockres', type=str, help='place for shift (default: blockres)')

parser.add_argument('--temporal_pool', default=False, action="store_true", help='add temporal pooling')
parser.add_argument('--non_local', default=False, action="store_true", help='add non local block')

parser.add_argument('--dense_sample', default=False, action="store_true", help='use dense sample for video dataset')
3.之后开始训练:main.py修改如下:
# Code for "TSM: Temporal Shift Module for Efficient Video Understanding"
# arXiv:1811.08383
# Ji Lin*, Chuang Gan, Song Han
# {jilin, songhan}@mit.edu, ganchuang@csail.mit.edu
import os
import time
import shutil
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
from torch.nn.utils import clip_grad_norm_
from ops.dataset import TSNDataSet
from ops.models import TSN
from ops.transforms import *
from opts import parser
from ops import dataset_config
from ops.utils import AverageMeter, accuracy
from ops.temporal_shift import make_temporal_pool
from tensorboardX import SummaryWriter
# Enumerate GPUs in PCI bus order and expose only GPU 0 to this process.
# Presumably these must be set before CUDA is first initialised to take effect.
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
# NOTE(review): `device` appears unused in this file (the model and criterion
# are moved with .cuda() instead) — confirm before removing.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Best validation top-1 accuracy so far; updated as a global by main().
best_prec1 = 0
def main():
    """Train (or, with --evaluate, only evaluate) a TSM model.

    Options come from ``opts.parser``. The class count, list files, frame root
    and filename template come from ``dataset_config.return_dataset`` and
    overwrite the matching CLI values.

    The model can be resumed from ``--resume`` (weights, optimizer state and
    epoch) and/or initialised from ``--tune_from`` (weights only). In the
    latter case, fc weights are dropped when the checkpoint path does not
    contain the dataset name.

    Logs go to ``<root_log>/<store_name>`` and checkpoints to
    ``<root_model>/<store_name>``.
    """
    global args, best_prec1
    args = parser.parse_args()

    num_class, args.train_list, args.val_list, args.root_path, prefix = dataset_config.return_dataset(
        args.dataset, args.modality)
    full_arch_name = args.arch
    if args.shift:
        full_arch_name += '_shift{}_{}'.format(args.shift_div, args.shift_place)
    if args.temporal_pool:
        full_arch_name += '_tpool'
    args.store_name = '_'.join(
        ['TSM', args.dataset, args.modality, full_arch_name, args.consensus_type, 'segment%d' % args.num_segments,
         'e{}'.format(args.epochs)])
    if args.pretrain != 'imagenet':
        args.store_name += '_{}'.format(args.pretrain)
    if args.lr_type != 'step':
        args.store_name += '_{}'.format(args.lr_type)
    if args.dense_sample:
        args.store_name += '_dense'
    if args.non_local > 0:
        args.store_name += '_nl'
    if args.suffix is not None:
        args.store_name += '_{}'.format(args.suffix)
    print('storing name: ' + args.store_name)

    check_rootfolders()

    model = TSN(num_class, args.num_segments, args.modality,
                base_model=args.arch,
                consensus_type=args.consensus_type,
                dropout=args.dropout,
                img_feature_dim=args.img_feature_dim,
                partial_bn=not args.no_partialbn,
                pretrain=args.pretrain,
                is_shift=args.shift, shift_div=args.shift_div, shift_place=args.shift_place,
                fc_lr5=not (args.tune_from and args.dataset in args.tune_from),
                temporal_pool=args.temporal_pool,
                non_local=args.non_local)

    crop_size = model.crop_size
    scale_size = model.scale_size
    input_mean = model.input_mean
    input_std = model.input_std
    policies = model.get_optim_policies()
    train_augmentation = model.get_augmentation(
        flip=False if 'something' in args.dataset or 'jester' in args.dataset else True)

    model = torch.nn.DataParallel(model, device_ids=args.gpus).cuda()

    optimizer = torch.optim.SGD(policies,
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    if args.resume:
        if args.temporal_pool:  # early temporal pool so that we can load the state_dict
            make_temporal_pool(model.module.base_model, args.num_segments)
        if os.path.isfile(args.resume):
            print(("=> loading checkpoint '{}'".format(args.resume)))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            # Report the checkpoint path (previously printed args.evaluate).
            print(("=> loaded checkpoint '{}' (epoch {})"
                   .format(args.resume, checkpoint['epoch'])))
        else:
            print(("=> no checkpoint found at '{}'".format(args.resume)))

    if args.tune_from:
        print(("=> fine-tuning from '{}'".format(args.tune_from)))
        sd = torch.load(args.tune_from)
        sd = sd['state_dict']
        model_dict = model.state_dict()
        # Reconcile key names that differ only by a '.net' component between
        # the checkpoint and the current model.
        replace_dict = []
        for k in sd:
            if k not in model_dict and k.replace('.net', '') in model_dict:
                print('=> Load after remove .net: ', k)
                replace_dict.append((k, k.replace('.net', '')))
        for k in model_dict:
            if k not in sd and k.replace('.net', '') in sd:
                print('=> Load after adding .net: ', k)
                replace_dict.append((k.replace('.net', ''), k))

        for k, k_new in replace_dict:
            sd[k_new] = sd.pop(k)
        keys1 = set(sd.keys())
        keys2 = set(model_dict.keys())
        set_diff = (keys1 - keys2) | (keys2 - keys1)
        print('#### Notice: keys that failed to load: {}'.format(set_diff))
        if args.dataset not in args.tune_from:  # new dataset
            print('=> New dataset, do not load fc weights')
            sd = {k: v for k, v in sd.items() if 'fc' not in k}
        if args.modality == 'Flow' and 'Flow' not in args.tune_from:
            sd = {k: v for k, v in sd.items() if 'conv1.weight' not in k}
        model_dict.update(sd)
        model.load_state_dict(model_dict)

    if args.temporal_pool and not args.resume:
        make_temporal_pool(model.module.base_model, args.num_segments)

    cudnn.benchmark = True

    # Data loading code
    if args.modality != 'RGBDiff':
        normalize = GroupNormalize(input_mean, input_std)
    else:
        normalize = IdentityTransform()

    if args.modality == 'RGB':
        data_length = 1
    elif args.modality in ['Flow', 'RGBDiff']:
        data_length = 5

    train_loader = torch.utils.data.DataLoader(
        TSNDataSet(args.root_path, args.train_list, num_segments=args.num_segments,
                   new_length=data_length,
                   modality=args.modality,
                   image_tmpl=prefix,
                   transform=torchvision.transforms.Compose([
                       train_augmentation,
                       Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
                       ToTorchFormatTensor(div=(args.arch not in ['BNInception', 'InceptionV3'])),
                       normalize,
                   ]), dense_sample=args.dense_sample),
        batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True,
        drop_last=True)  # prevent something not % n_GPU

    val_loader = torch.utils.data.DataLoader(
        TSNDataSet(args.root_path, args.val_list, num_segments=args.num_segments,
                   new_length=data_length,
                   modality=args.modality,
                   image_tmpl=prefix,
                   random_shift=False,
                   transform=torchvision.transforms.Compose([
                       GroupScale(int(scale_size)),
                       GroupCenterCrop(crop_size),
                       Stack(roll=(args.arch in ['BNInception', 'InceptionV3'])),
                       ToTorchFormatTensor(div=(args.arch not in ['BNInception', 'InceptionV3'])),
                       normalize,
                   ]), dense_sample=args.dense_sample),
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    # define loss function (criterion) and optimizer
    if args.loss_type == 'nll':
        criterion = torch.nn.CrossEntropyLoss().cuda()
    else:
        raise ValueError("Unknown loss type")

    for group in policies:
        print(('group: {} has {} params, lr_mult: {}, decay_mult: {}'.format(
            group['name'], len(group['params']), group['lr_mult'], group['decay_mult'])))

    if args.evaluate:
        validate(val_loader, model, criterion, 0)
        return

    with open(os.path.join(args.root_log, args.store_name, 'args.txt'), 'w') as f:
        f.write(str(args))

    # The log file and the TensorBoard writer are closed even if training
    # raises (previously neither was ever closed).
    with open(os.path.join(args.root_log, args.store_name, 'log.csv'), 'w') as log_training:
        tf_writer = SummaryWriter(log_dir=os.path.join(args.root_log, args.store_name))
        try:
            for epoch in range(args.start_epoch, args.epochs):
                adjust_learning_rate(optimizer, epoch, args.lr_type, args.lr_steps)

                # train for one epoch
                train(train_loader, model, criterion, optimizer, epoch, log_training, tf_writer)

                # evaluate on validation set
                if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1:
                    prec1 = validate(val_loader, model, criterion, epoch, log_training, tf_writer)

                    # remember best prec@1 and save checkpoint
                    is_best = prec1 > best_prec1
                    best_prec1 = max(prec1, best_prec1)
                    tf_writer.add_scalar('acc/test_top1_best', best_prec1, epoch)

                    output_best = 'Best Prec@1: %.3f\n' % (best_prec1)
                    print(output_best)
                    log_training.write(output_best + '\n')
                    log_training.flush()

                    save_checkpoint({
                        'epoch': epoch + 1,
                        'arch': args.arch,
                        'state_dict': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'best_prec1': best_prec1,
                    }, is_best)
        finally:
            tf_writer.close()
def train(train_loader, model, criterion, optimizer, epoch, log, tf_writer):
    """Train ``model`` for one epoch over ``train_loader``.

    Every ``args.print_freq`` iterations a progress line is printed and
    appended to ``log``. At the end, epoch averages of loss, top-1 and top-k
    accuracy, plus the last param group's learning rate, are written to
    ``tf_writer``.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # partialBN is on unless --no_partialbn was given.
    model.module.partialBN(not args.no_partialbn)

    # switch to train mode
    model.train()

    end = time.time()
    for i, (inputs, target) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        target = target.cuda()

        # Tensors are passed directly: torch.autograd.Variable has been a
        # no-op wrapper since PyTorch 0.4.
        output = model(inputs)
        loss = criterion(output, target)

        # Logged as "Prec@5", but k is capped at the number of classes: asking
        # for top-5 with fewer classes would request more classes than exist.
        # For the 2-class setup this is top-2, as before.
        maxk = min(5, output.size(1))
        prec1, prec5 = accuracy(output.data, target, topk=(1, maxk))
        losses.update(loss.item(), inputs.size(0))
        top1.update(prec1.item(), inputs.size(0))
        top5.update(prec5.item(), inputs.size(0))

        # compute gradient and do SGD step
        loss.backward()
        if args.clip_gradient is not None:
            clip_grad_norm_(model.parameters(), args.clip_gradient)
        optimizer.step()
        optimizer.zero_grad()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            # NOTE(review): the "* 0.1" presumably undoes an lr multiplier on
            # the last param group — confirm against get_optim_policies().
            msg = ('Epoch: [{0}][{1}/{2}], lr: {lr:.5f}\t'
                   'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                   'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                   'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                   'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                   'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                       epoch, i, len(train_loader), batch_time=batch_time,
                       data_time=data_time, loss=losses, top1=top1, top5=top5,
                       lr=optimizer.param_groups[-1]['lr'] * 0.1))  # TODO
            print(msg)
            log.write(msg + '\n')
            log.flush()

    tf_writer.add_scalar('loss/train', losses.avg, epoch)
    tf_writer.add_scalar('acc/train_top1', top1.avg, epoch)
    tf_writer.add_scalar('acc/train_top5', top5.avg, epoch)
    tf_writer.add_scalar('lr', optimizer.param_groups[-1]['lr'], epoch)
def validate(val_loader, model, criterion, epoch, log=None, tf_writer=None):
    """Evaluate ``model`` on ``val_loader`` without gradients.

    Progress lines are printed every ``args.print_freq`` batches and, when
    ``log`` is given, appended to it. Epoch averages are written to
    ``tf_writer`` if given.

    Returns:
        Average top-1 accuracy over the validation set.
    """
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    end = time.time()
    with torch.no_grad():
        for i, (inputs, target) in enumerate(val_loader):
            target = target.cuda()

            # compute output
            output = model(inputs)
            loss = criterion(output, target)

            # Logged as "Prec@5", but k is capped at the number of classes
            # (top-2 for the 2-class setup, as before).
            maxk = min(5, output.size(1))
            prec1, prec5 = accuracy(output.data, target, topk=(1, maxk))

            losses.update(loss.item(), inputs.size(0))
            top1.update(prec1.item(), inputs.size(0))
            top5.update(prec5.item(), inputs.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                msg = ('Test: [{0}/{1}]\t'
                       'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                       'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                       'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                       'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                           i, len(val_loader), batch_time=batch_time, loss=losses,
                           top1=top1, top5=top5))
                print(msg)
                if log is not None:
                    log.write(msg + '\n')
                    log.flush()

    msg = ('Testing Results: Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f} Loss {loss.avg:.5f}'
           .format(top1=top1, top5=top5, loss=losses))
    print(msg)
    if log is not None:
        log.write(msg + '\n')
        log.flush()

    if tf_writer is not None:
        tf_writer.add_scalar('loss/test', losses.avg, epoch)
        tf_writer.add_scalar('acc/test_top1', top1.avg, epoch)
        tf_writer.add_scalar('acc/test_top5', top5.avg, epoch)

    return top1.avg
def save_checkpoint(state, is_best):
    """Save ``state`` as the run's latest checkpoint.

    The file is ``<root_model>/<store_name>/ckpt.pth.tar``. When ``is_best``
    is true, it is also copied to ``ckpt.best.pth.tar`` alongside it.
    """
    ckpt_path = '%s/%s/ckpt.pth.tar' % (args.root_model, args.store_name)
    torch.save(state, ckpt_path)
    if not is_best:
        return
    best_path = ckpt_path.replace('pth.tar', 'best.pth.tar')
    shutil.copyfile(ckpt_path, best_path)
def adjust_learning_rate(optimizer, epoch, lr_type, lr_steps):
    """Set every param group's learning rate and weight decay for ``epoch``.

    - ``'step'``: ``args.lr`` multiplied by 0.1 once for each milestone in
      ``lr_steps`` that ``epoch`` has reached.
    - ``'cos'``: cosine annealing from ``args.lr`` towards 0 over
      ``args.epochs``.

    Each group's values are scaled by its ``lr_mult`` / ``decay_mult``.

    Raises:
        NotImplementedError: for any other ``lr_type``.
    """
    import math

    if lr_type == 'step':
        # Count reached milestones directly; no numpy needed.
        reached = sum(1 for step in lr_steps if epoch >= step)
        lr = args.lr * 0.1 ** reached
    elif lr_type == 'cos':
        lr = 0.5 * args.lr * (1 + math.cos(math.pi * epoch / args.epochs))
    else:
        raise NotImplementedError
    decay = args.weight_decay
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr * param_group['lr_mult']
        param_group['weight_decay'] = decay * param_group['decay_mult']
def check_rootfolders():
    """Create the log and checkpoint folders for the current run if missing.

    Creates ``args.root_log``, ``args.root_model`` and the ``args.store_name``
    sub-folder inside each.
    """
    folders_util = [args.root_log, args.root_model,
                    os.path.join(args.root_log, args.store_name),
                    os.path.join(args.root_model, args.store_name)]
    for folder in folders_util:
        if not os.path.exists(folder):
            print('creating folder ' + folder)
            # makedirs also creates missing parents (e.g. --root_log a/b),
            # where os.mkdir would raise FileNotFoundError.
            os.makedirs(folder, exist_ok=True)


if __name__ == '__main__':
    main()
以上是训练结果
4.demo测试:
新建demo.py文件:
import os
import time
from ops.models import TSN
from ops.transforms import *
import cv2
from PIL import Image
# ---- model settings; must match the values used for training ----
arch = 'resnet50'
num_class = 2
num_segments = 8
modality = 'RGB'
base_model = 'resnet50'
consensus_type = 'avg'
dataset = 'ucf101'
dropout = 0.1
img_feature_dim = 256
no_partialbn = True
pretrain = 'imagenet'
shift = True
shift_div = 8
shift_place = 'blockres'
temporal_pool = False
non_local = False
tune_from = None

# Checkpoint to load (best weights written by training).
resume = r'D:\learning\ActionSeqmentation/temporal-shift-module-master\checkpoint\TSM_ucf101_RGB_resnet50_shift8_blockres_avg_segment8_e20\ckpt.best.pth.tar'
# Video to run the demo on.
video_path = r'D:\learning\ActionSeqmentation/temporal-shift-module-master\data\posvideo\sketch/videos\YoYo/v_YoYo_g08_c01.avi'
# Display names / BGR colours, indexed by predicted class id.
cls_text = ['Rowing', 'YoYo']
cls_color = [(0, 255, 0), (0, 0, 255)]
training_fps = 30
# Seconds of video spread over the num_segments sampled frames.
training_time = 2.5

input_mean = [0.485, 0.456, 0.406]
input_std = [0.229, 0.224, 0.225]


def sample_interval(fps, clip_seconds=training_time, segments=num_segments):
    """Return how many frames to advance between two sampled frames.

    ``segments`` frames are spread over ``clip_seconds`` of video. The result
    is clamped to at least 1: for low frame rates the raw value rounds down to
    0, which would make ``frame_number % interval`` divide by zero.
    """
    return max(1, int(fps * clip_seconds / segments))


def load_model():
    """Build the TSM network, load the ``resume`` weights and switch to eval."""
    model = TSN(num_class, num_segments, modality,
                base_model=arch,
                consensus_type=consensus_type,
                dropout=dropout,
                img_feature_dim=img_feature_dim,
                partial_bn=not no_partialbn,
                pretrain=pretrain,
                is_shift=shift, shift_div=shift_div, shift_place=shift_place,
                fc_lr5=not (tune_from and dataset in tune_from),
                temporal_pool=temporal_pool,
                non_local=non_local)
    model = torch.nn.DataParallel(model, device_ids=None).cuda()
    checkpoint = torch.load(resume)
    model.load_state_dict(checkpoint['state_dict'])
    model.eval()
    return model


def build_transform():
    """Return the preprocessing pipeline applied to a window of PIL frames."""
    normalize = GroupNormalize(input_mean, input_std)
    return torchvision.transforms.Compose([
        # NOTE(review): GroupScale_hyj is not defined in this file; presumably
        # a custom transform added to ops/transforms.py — confirm.
        GroupScale_hyj(input_size=320),
        Stack(roll=(arch in ['BNInception', 'InceptionV3'])),
        ToTorchFormatTensor(div=(arch not in ['BNInception', 'InceptionV3'])),
        normalize,
    ])


def main():
    """Play ``video_path`` with the predicted class and display FPS overlaid.

    Every ``sample_interval(fps)``-th frame is pushed into a sliding window of
    ``num_segments`` frames. Whenever the window is full, the model is re-run
    on it. Space pauses playback; ``q`` quits.
    """
    from collections import deque

    model = load_model()
    transform_hyj = build_transform()

    cap = cv2.VideoCapture(video_path)
    fps = cap.get(cv2.CAP_PROP_FPS)  # average frame rate of the video
    if fps < 1:
        fps = 30  # fall back when no usable frame rate is reported
    duaring = sample_interval(fps)
    print(duaring)

    # deque(maxlen=...) drops the oldest frame automatically. The old
    # two-branch list logic appended the same frame twice when the window
    # first filled up.
    window = deque(maxlen=num_segments)
    state = 0
    frame_numbers = 0
    counter = 0
    start_time = time.time()
    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            frame_numbers += 1
            print(frame_numbers)

            if frame_numbers % duaring == 0:
                window.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
                if len(window) == num_segments:
                    clip = transform_hyj(list(window)).unsqueeze(0).cuda()
                    with torch.no_grad():  # inference only: no autograd graph
                        out = model(clip)
                    print(out)
                    state = int(torch.argmax(out).cpu())

            # Space pauses, q quits.
            key = cv2.waitKey(1) & 0xff
            if key == ord(" "):
                cv2.waitKey(0)
            if key == ord("q"):
                break

            counter += 1  # frames since the last FPS update
            elapsed = time.time() - start_time
            if elapsed != 0:  # show the class label and the live FPS
                cv2.putText(frame, "{0} {1}".format(cls_text[state], float('%.1f' % (counter / elapsed))),
                            (50, 50), cv2.FONT_HERSHEY_SIMPLEX, 2, cls_color[state], 3)
                cv2.imshow('frame', frame)
                counter = 0
                start_time = time.time()
            time.sleep(1 / fps)  # roughly play back at the source frame rate
    finally:
        cap.release()
        cv2.destroyAllWindows()


if __name__ == "__main__":
    main()