本文致力于将文中的一些细节给大家解释清楚,如果有照顾不到的细节,还请见谅,欢迎留言讨论
1.参数部分:
# Command-line options for evaluating one or more TSM checkpoints.
parser = argparse.ArgumentParser(description="TSM testing on the full validation set")
parser.add_argument('dataset', type=str)

# may contain splits
# --weights, --test_segments, --coeff and --test_list are comma-separated lists
# (they are split on ',' further down), one entry per checkpoint.
parser.add_argument('--weights', type=str, default=None)
# The default must be a string: argparse does not apply `type` to non-string
# defaults, so an int 25 would reach `args.test_segments.split(',')` unchanged
# and raise AttributeError.
parser.add_argument('--test_segments', type=str, default='25')
parser.add_argument('--dense_sample', default=False, action="store_true", help='use dense sample as I3D')
parser.add_argument('--twice_sample', default=False, action="store_true", help='use twice sample for ensemble')
parser.add_argument('--full_res', default=False, action="store_true",
                    help='use full resolution 256x256 for test as in Non-local I3D')

parser.add_argument('--test_crops', type=int, default=1)
parser.add_argument('--coeff', type=str, default=None)
parser.add_argument('--batch_size', type=int, default=1)
parser.add_argument('-j', '--workers', default=8, type=int, metavar='N',
                    help='number of data loading workers (default: 8)')

# for true test
parser.add_argument('--test_list', type=str, default=None)
parser.add_argument('--csv_file', type=str, default=None)

parser.add_argument('--softmax', default=False, action="store_true", help='use softmax')

parser.add_argument('--max_num', type=int, default=-1)
parser.add_argument('--input_size', type=int, default=224)
parser.add_argument('--crop_fusion_type', type=str, default='avg')
parser.add_argument('--gpus', nargs='+', type=int, default=None)
parser.add_argument('--img_feature_dim', type=int, default=256)
parser.add_argument('--num_set_segments', type=int, default=1,
                    help='TODO: select multiply set of n-frames from a video')
parser.add_argument('--pretrain', type=str, default='imagenet')

args = parser.parse_args()
这里我们主要关注的参数应该是 test_segments。由于文章采用的是跨步采样的方式,因此这里的 test_segments 表示将视频等分成的份数,从每份中抽取一帧。注意 dense_sample、twice_sample 以及 test_crops 参数,接下来我们还会介绍;其他不影响阅读代码的参数我们不再介绍。
2.数据处理
阅读这部分之前,你必须了解的两个函数是 zip() 和 enumerate(),不然循环会让你晕头转向:
1.for this_weights, this_test_segments, test_file in zip(weights_list, test_segments_list, test_file_list):
#这表示从zip中的三个list中同时返回三个元素
2.enumerate(data_loader),返回一个迭代器(不是list),每次迭代产生一个由索引和元素本身组成的tuple,即 (index, (data, label))
# Build one (network, data iterator, modality) triple per checkpoint to ensemble.
# --weights / --test_segments / --coeff / --test_list are comma-separated lists
# with one entry per checkpoint.
weights_list = args.weights.split(',')
# str() keeps this working even when the argparse default arrives as an int.
test_segments_list = [int(s) for s in str(args.test_segments).split(',')]
assert len(weights_list) == len(test_segments_list)  # both have length 1 for a single checkpoint
if args.coeff is None:
    coeff_list = [1] * len(weights_list)
else:
    coeff_list = [float(c) for c in args.coeff.split(',')]

if args.test_list is not None:
    test_file_list = args.test_list.split(',')
else:
    test_file_list = [None] * len(weights_list)

data_iter_list = []
net_list = []
modality_list = []

total_num = None
for this_weights, this_test_segments, test_file in zip(weights_list, test_segments_list, test_file_list):
    # Model options are encoded in the checkpoint file name.
    is_shift, shift_div, shift_place = parse_shift_option_from_log_name(this_weights)
    if 'RGB' in this_weights:
        modality = 'RGB'
    else:
        modality = 'Flow'
    # Third '_'-separated field after 'TSM_' in the checkpoint name, e.g. resnet50.
    this_arch = this_weights.split('TSM_')[1].split('_')[2]
    modality_list.append(modality)
    # Number of classes, train/val list files, data root and file-name template.
    num_class, args.train_list, val_list, root_path, prefix = dataset_config.return_dataset(args.dataset,
                                                                                            modality)
    print('=> shift: {}, shift_div: {}, shift_place: {}'.format(is_shift, shift_div, shift_place))

    # ---- Define the network and load its parameters ----
    # Non-shift models are built with a single segment; eval_video averages the
    # per-frame outputs afterwards.
    net = TSN(num_class, this_test_segments if is_shift else 1, modality,
              base_model=this_arch,
              consensus_type=args.crop_fusion_type,
              img_feature_dim=args.img_feature_dim,
              pretrain=args.pretrain,
              is_shift=is_shift, shift_div=shift_div, shift_place=shift_place,
              non_local='_nl' in this_weights,
              )
    print(net)

    if 'tpool' in this_weights:
        from first.ops.temporal_shift import make_temporal_pool
        make_temporal_pool(net.base_model, this_test_segments)  # since DataParallel

    # NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
    checkpoint = torch.load(this_weights)
    checkpoint = checkpoint['state_dict']

    # base_dict = {('base_model.' + k).replace('base_model.fc', 'new_fc'): v for k, v in list(checkpoint.items())}
    # Drop the first dotted component of every key (presumably the 'module.'
    # prefix added by DataParallel when the checkpoint was saved -- confirm).
    base_dict = {'.'.join(k.split('.')[1:]): v for k, v in list(checkpoint.items())}
    # Rename classifier parameters to the names the TSN model expects.
    replace_dict = {'base_model.classifier.weight': 'new_fc.weight',
                    'base_model.classifier.bias': 'new_fc.bias',
                    }
    for k, v in replace_dict.items():
        if k in base_dict:
            base_dict[v] = base_dict.pop(k)

    net.load_state_dict(base_dict)
    # ---- Model loading finished ----

    input_size = net.scale_size if args.full_res else net.input_size

    # ---- Choose the test-time cropping scheme ----
    # The per-branch descriptions follow the original author's notes on the
    # Group* transforms, which are defined outside this file.
    if args.test_crops == 1:
        # Rescale, then keep only the centre crop.
        cropping = torchvision.transforms.Compose([
            GroupScale(net.scale_size),
            GroupCenterCrop(input_size),
        ])
    elif args.test_crops == 3:  # do not flip, so only 3 crops
        # Author's note: left / centre / right crops, one image becomes 3.
        cropping = torchvision.transforms.Compose([
            GroupFullResSample(input_size, net.scale_size, flip=False)
        ])
    elif args.test_crops == 5:  # do not flip, so only 5 crops
        # Author's note: four corners plus centre, one image becomes 5.
        cropping = torchvision.transforms.Compose([
            GroupOverSample(input_size, net.scale_size, flip=False)
        ])
    elif args.test_crops == 10:
        # Author's note: the 5 crops above plus their flips, one image becomes 10.
        cropping = torchvision.transforms.Compose([
            GroupOverSample(input_size, net.scale_size)
        ])
    else:
        # The message lists every value accepted above (the original omitted 3).
        raise ValueError("Only 1, 3, 5, 10 crops are supported while we got {}".format(args.test_crops))

    # ---- Data loader ----
    # dense_sample / twice_sample are forwarded to TSNDataSet; eval_video later
    # treats them as 10x / 2x the number of crops per video.
    data_loader = torch.utils.data.DataLoader(
        TSNDataSet(root_path, test_file if test_file is not None else val_list, num_segments=this_test_segments,
                   new_length=1 if modality == "RGB" else 5,
                   modality=modality,
                   image_tmpl=prefix,
                   test_mode=True,
                   remove_missing=len(weights_list) == 1,  # only when a single checkpoint is evaluated
                   transform=torchvision.transforms.Compose([
                       cropping,
                       Stack(roll=(this_arch in ['BNInception', 'InceptionV3'])),
                       ToTorchFormatTensor(div=(this_arch not in ['BNInception', 'InceptionV3'])),
                       GroupNormalize(net.input_mean, net.input_std),
                   ]), dense_sample=args.dense_sample, twice_sample=args.twice_sample),
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True,
    )

    # ---- Device placement / data parallelism ----
    # The original built an unused `devices` list by indexing args.gpus with
    # range(args.workers), which raised IndexError whenever the worker count
    # exceeded the number of GPUs given. It never reached DataParallel, so it is
    # removed. --gpus is therefore still not applied here; DataParallel is
    # created with its default device selection.
    net = torch.nn.DataParallel(net.cuda())
    net.eval()

    # enumerate() yields (batch_index, (data, label)) lazily; the evaluation
    # loop below unpacks it in exactly that form.
    data_gen = enumerate(data_loader)

    # Every checkpoint must be evaluated on the same number of videos.
    if total_num is None:
        total_num = len(data_loader.dataset)
    else:
        assert total_num == len(data_loader.dataset)

    data_iter_list.append(data_gen)
    net_list.append(net)
3.测试:
def eval_video(video_data, net, this_test_segments, modality):
    """Run one network on one batch and return per-video class scores.

    Args:
        video_data: tuple ``(i, data, label)`` -- batch index, input tensor and
            label tensor as produced by the data loader.
        net: DataParallel-wrapped model; ``net.module.is_shift`` is read to
            decide how the input is laid out.
        this_test_segments: number of segments sampled per video for this model.
        modality: ``'RGB'``, ``'Flow'`` or ``'RGBDiff'``.

    Returns:
        ``(i, rst, label)`` where ``rst`` is a NumPy array reshaped to
        ``(batch_size, num_class)``.

    Raises:
        ValueError: if ``modality`` is not one of the supported names.

    Reads the module-level ``args`` and ``num_class``.
    """
    net.eval()
    with torch.no_grad():
        i, data, label = video_data
        batch_size = label.numel()  # number of videos in this batch
        # Views per video: crops, times 10 for dense sampling, times 2 for twice sampling.
        num_crop = args.test_crops
        if args.dense_sample:
            num_crop *= 10  # 10 clips for testing when using dense sample

        if args.twice_sample:
            num_crop *= 2

        # Channels per sampled frame group for each modality.
        if modality == 'RGB':
            length = 3
        elif modality == 'Flow':
            length = 10
        elif modality == 'RGBDiff':
            length = 18
        else:
            raise ValueError("Unknown modality "+ modality)

        data_in = data.view(-1, length, data.size(2), data.size(3))
        # Use this network's own flag, not the global `is_shift` left over from
        # the last iteration of the model-loading loop; the two differ when shift
        # and non-shift checkpoints are ensembled.
        if net.module.is_shift:
            # Shift models take all segments of a view together.
            data_in = data_in.view(batch_size * num_crop, this_test_segments, length, data_in.size(2), data_in.size(3))
        rst = net(data_in)
        # Average over the crops/clips of each video.
        rst = rst.reshape(batch_size, num_crop, -1).mean(1)
        if args.softmax:
            # take the softmax to normalize the output to probability
            rst = F.softmax(rst, dim=1)

        rst = rst.data.cpu().numpy().copy()

        if net.module.is_shift:
            rst = rst.reshape(batch_size, num_class)
        else:
            # Non-shift models emit one prediction per frame; average them per video.
            rst = rst.reshape((batch_size, -1, num_class)).mean(axis=1).reshape((batch_size, num_class))

        return i, rst, label
# ---- Evaluation loop: ensemble all networks batch by batch ----
proc_start_time = time.time()
# NOTE: max_num is compared against the batch index below, so with
# batch_size > 1 it limits the number of batches rather than videos.
max_num = args.max_num if args.max_num > 0 else total_num

top1 = AverageMeter()
top5 = AverageMeter()

# One [scores, label] pair per evaluated video. The loop below appends to this
# list, so it has to exist before the first iteration.
output = []

# zip(*data_iter_list) advances every model's data iterator in lockstep; each
# element of data_label_pairs is (batch_index, (data, label)) for one model.
for i, data_label_pairs in enumerate(zip(*data_iter_list)):
    with torch.no_grad():
        if i >= max_num:
            break
        this_rst_list = []
        this_label = None
        for n_seg, (_, (data, label)), net, modality in zip(test_segments_list, data_label_pairs, net_list, modality_list):
            # eval_video returns (index, scores, label); only the scores are kept.
            rst = eval_video((i, data, label), net, n_seg, modality)
            this_rst_list.append(rst[1])
            this_label = label
        assert len(this_rst_list) == len(coeff_list)
        # Weight each model's scores by its ensemble coefficient, then average.
        this_rst_list = [scores * coeff for scores, coeff in zip(this_rst_list, coeff_list)]
        ensembled_predict = sum(this_rst_list) / len(this_rst_list)

        # Store each video's scores with a leading axis of size 1, next to its label.
        for p, g in zip(ensembled_predict, this_label.cpu().numpy()):
            output.append([p[None, ...], g])
        cnt_time = time.time() - proc_start_time
        prec1, prec5 = accuracy(torch.from_numpy(ensembled_predict), this_label, topk=(1, 5))
        top1.update(prec1.item(), this_label.numel())
        top5.update(prec5.item(), this_label.numel())
        if i % 20 == 0:
            print('video {} done, total {}/{}, average {:.3f} sec/video, '
                  'moving Prec@1 {:.3f} Prec@5 {:.3f}'.format(i * args.batch_size, i * args.batch_size, total_num,
                                                              float(cnt_time) / (i+1) / args.batch_size, top1.avg, top5.avg))

# Per-video predictions derived from the collected scores.
video_pred = [np.argmax(x[0]) for x in output]
video_pred_top5 = [np.argsort(np.mean(x[0], axis=0).reshape(-1))[::-1][:5] for x in output]

video_labels = [x[1] for x in output]
4.相关函数和类:
这里你必须知道的函数为pytorch中的topk,非常好用的函数
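a,b = data.topk(maxk, dim, True, True),这里返回的a是data在指定维度上前maxk大的元素,b是这些元素的索引;后两个参数分别表示取最大值(largest=True)和结果按从大到小排序(sorted=True)。dim = 1 表示在每一行内(沿列方向)取前maxk个,dim = 0 表示在每一列内(沿行方向)取前maxk个。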
class AverageMeter(object):
    """Computes and stores the average and current value.

    Attributes set by ``reset``/``update``: ``val`` (last value), ``sum``
    (weighted sum of values), ``count`` (total weight) and ``avg``
    (``sum / count``).
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all statistics back to zero."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the mean of ``n`` new samples and refresh the average."""
        self.val = val
        self.sum += val * n
        self.count += n
        # Guard against a zero total weight (e.g. update(x, 0) on a fresh meter),
        # which used to raise ZeroDivisionError.
        self.avg = self.sum / self.count if self.count else 0
def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k.

    Args:
        output: 2-D score tensor; top-k is taken along dimension 1.
        target: 1-D tensor of ground-truth class indices, one per row of ``output``.
        topk: iterable of k values to evaluate.

    Returns:
        A list with one 0-dim tensor per k, holding the percentage of rows whose
        target is among the k highest-scoring classes.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    # Indices of the maxk largest scores per row, sorted from largest to smallest.
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()  # transpose to (maxk, batch) so row r holds the rank-r predictions
    correct = pred.eq(target.view(1, -1).expand_as(pred))  # boolean hits, (maxk, batch)

    res = []
    for k in topk:
        # `correct` inherits the transposed (non-contiguous) layout, so .view(-1)
        # fails on current PyTorch for k > 1; reshape copies when it has to.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res
def parse_shift_option_from_log_name(log_name):
    """Extract the temporal-shift settings encoded in a checkpoint/log name.

    The name is split on ``'_'`` and the first token of the exact form
    ``shift<N>`` (e.g. ``shift8``) is used; the token after it is the shift
    placement.

    Args:
        log_name: checkpoint path or log name.

    Returns:
        ``(True, N, place)`` when a ``shift<N>`` token is found, otherwise
        ``(False, None, None)``.

    Raises:
        IndexError: if the ``shift<N>`` token is the last ``'_'``-separated token.
    """
    import re  # local import: this file has no visible top-level import block

    tokens = log_name.split('_')
    for idx, token in enumerate(tokens):
        # Require the whole token to be 'shift<digits>'. The original parsed the
        # first token merely *containing* 'shift', so a path such as
        # '.../temporal-shift-module/TSM_..._shift8_blockres_...' raised ValueError.
        match = re.fullmatch(r'shift(\d+)', token)
        if match:
            return True, int(match.group(1)), tokens[idx + 1]
    return False, None, None
#仅供大家参考,有笔误或什么欢迎指正,大家交流提高