Pytorch 之 TSM(Time Shift Module)测试部分源码详解

本文致力于将文中的一些细节给大家解释清楚,如果有照顾不到的细节,还请见谅,欢迎留言讨论

1.参数部分:

parser = argparse.ArgumentParser(description="TSM testing on the full validation set")
parser.add_argument('dataset', type=str)

# may contain splits
parser.add_argument('--weights', type=str, default=None)
parser.add_argument('--test_segments', type=str, default=25)
parser.add_argument('--dense_sample', default=False, action="store_true", help='use dense sample as I3D')
parser.add_argument('--twice_sample', default=False, action="store_true", help='use twice sample for ensemble')
parser.add_argument('--full_res', default=False, action="store_true",
                    help='use full resolution 256x256 for test as in Non-local I3D')

parser.add_argument('--test_crops', type=int, default=1)
parser.add_argument('--coeff', type=str, default=None)
parser.add_argument('--batch_size', type=int, default=1)
parser.add_argument('-j', '--workers', default=8, type=int, metavar='N',
                    help='number of data loading workers (default: 8)')

# for true test
parser.add_argument('--test_list', type=str, default=None)
parser.add_argument('--csv_file', type=str, default=None)

parser.add_argument('--softmax', default=False, action="store_true", help='use softmax')

parser.add_argument('--max_num', type=int, default=-1)
parser.add_argument('--input_size', type=int, default=224)
parser.add_argument('--crop_fusion_type', type=str, default='avg')
parser.add_argument('--gpus', nargs='+', type=int, default=None)
parser.add_argument('--img_feature_dim',type=int, default=256)
parser.add_argument('--num_set_segments',type=int, default=1,help='TODO: select multiply set of n-frames from a video')
parser.add_argument('--pretrain', type=str, default='imagenet')

args = parser.parse_args()

这里我们主要关注的参数应该是 test_segments,由于文章才用的是跨步采样的方式。因此这里的test_segments表示将视频等分成的份儿数,从每份中随机抽取一帧。注意dense-sample,twice-sample以及test_crops参数,接下来我们还会介绍,其他的不影响阅读代码的参数我们不再介绍。

2.数据处理

读者部分之前,你必须知道的两个函数是zip() 和 enmunate(). 不然循环会让你晕头转向;

1.for this_weights, this_test_segments, test_file in zip(weights_list, test_segments_list, test_file_list):

#这表示从zip中的三个list中同时返回三个元素

2.enumerate(data_loader),返回一个list,元素为索引和值本身组成的tuple

weights_list = args.weights.split(',')
test_segments_list = [int(s) for s in args.test_segments.split(',')]
assert len(weights_list) == len(test_segments_list) #均为1
if args.coeff is None:
    coeff_list = [1] * len(weights_list)
else:
    coeff_list = [float(c) for c in args.coeff.split(',')]

if args.test_list is not None:
    test_file_list = args.test_list.split(',')
else:
    test_file_list = [None] * len(weights_list)


data_iter_list = []
net_list = []
modality_list = []

total_num = None
for this_weights, this_test_segments, test_file in zip(weights_list, test_segments_list, test_file_list):
    is_shift, shift_div, shift_place = parse_shift_option_from_log_name(this_weights) #设置模型参数
    if 'RGB' in this_weights: 
        modality = 'RGB'
    else:
        modality = 'Flow'
    this_arch = this_weights.split('TSM_')[1].split('_')[2] #resnet50
    modality_list.append(modality)
    num_class, args.train_list, val_list, root_path, prefix = dataset_config.return_dataset(args.dataset,
                                                                                            modality) #获得类别总数,训练集,验证集等
    print('=> shift: {}, shift_div: {}, shift_place: {}'.format(is_shift, shift_div, shift_place))
    '''
    Created on 2020年5月10日
          定义net,并加载参数
    @author: DELL
    '''
    net = TSN(num_class, this_test_segments if is_shift else 1, modality,
              base_model=this_arch,
              consensus_type=args.crop_fusion_type,
              img_feature_dim=args.img_feature_dim,
              pretrain=args.pretrain,
              is_shift=is_shift, shift_div=shift_div, shift_place=shift_place,
              non_local='_nl' in this_weights,
              )
    print(net)
    if 'tpool' in this_weights:
        from first.ops.temporal_shift import make_temporal_pool
        make_temporal_pool(net.base_model, this_test_segments)  # since DataParallel

    checkpoint = torch.load(this_weights)
    checkpoint = checkpoint['state_dict']

    # base_dict = {('base_model.' + k).replace('base_model.fc', 'new_fc'): v for k, v in list(checkpoint.items())}
    base_dict = {'.'.join(k.split('.')[1:]): v for k, v in list(checkpoint.items())}
    replace_dict = {'base_model.classifier.weight': 'new_fc.weight',
                    'base_model.classifier.bias': 'new_fc.bias',
                    }
    for k, v in replace_dict.items():
        if k in base_dict:
            base_dict[v] = base_dict.pop(k)

    net.load_state_dict(base_dict)
    '''
           模型加载结束,这部分有不懂得可以参见我的另一篇博客
    '''
    input_size = net.scale_size if args.full_res else net.input_size
    '''
    Created on 2020年5月10日
          选择数据的随机处理模式
          这里给大家简单介绍,有个初步的概念     
    @author: DELL
    '''
    if args.test_crops == 1: #为1,则先放缩,再裁剪只留中间的符合尺寸的部分
        cropping = torchvision.transforms.Compose([
            GroupScale(net.scale_size),
            GroupCenterCrop(input_size),
        ])
    elif args.test_crops == 3:  # do not flip, so only 3 crops #为3,先放缩,后裁剪,然后留下左边,右边,中间三个裁剪数据,不翻转,故一一张图片扩充为了3张
        cropping = torchvision.transforms.Compose([
            GroupFullResSample(input_size, net.scale_size, flip=False)
        ])
    elif args.test_crops == 5:  # do not flip, so only 5 crops #为5,先放缩,后裁剪,然后留下左上,左下,右上,右下,中间5个裁剪数据,不翻转,故一一张图片扩充为了5张
        cropping = torchvision.transforms.Compose([
            GroupOverSample(input_size, net.scale_size, flip=False)
        ])
    elif args.test_crops == 10:#为10,先放缩,后裁剪,然后留下左上,左下,右上,右下,中间5个裁剪数据,翻转翻倍,故一一张图片扩充为了10张
        cropping = torchvision.transforms.Compose([
            GroupOverSample(input_size, net.scale_size)
        ])
    else:
        raise ValueError("Only 1, 5, 10 crops are supported while we got {}".format(args.test_crops))
    '''
          定义data_loader,data:[K,Batch_size,test_crops,test_segments,224,224,3],可以暂时这样理解,主要能理解K控制底下的循环即可。这里不再详细解释,不影响阅读代码
     dense_sample,表示密集采样,在一个视频的每一段中随机取10帧,然后对每一帧进行上述的crop等处理,则10帧扩展成了10*test_crops帧
     twice_sample,表示采样两次  ,然后对每一帧进行上述的crop等处理,则2帧扩展成了2*test_crops帧   
    '''
    data_loader = torch.utils.data.DataLoader(
            TSNDataSet(root_path, test_file if test_file is not None else val_list, num_segments=this_test_segments,
                       new_length=1 if modality == "RGB" else 5,
                       modality=modality,
                       image_tmpl=prefix,
                       test_mode=True,
                       remove_missing=len(weights_list) == 1,
                       transform=torchvision.transforms.Compose([
                           cropping,
                           Stack(roll=(this_arch in ['BNInception', 'InceptionV3'])),
                           ToTorchFormatTensor(div=(this_arch not in ['BNInception', 'InceptionV3'])),
                           GroupNormalize(net.input_mean, net.input_std),
                       ]), dense_sample=args.dense_sample, twice_sample=args.twice_sample),
            batch_size=args.batch_size, shuffle=False,
            num_workers=args.workers, pin_memory=True,
    )
    '''
          设置gpu
    '''
    if args.gpus is not None:
        devices = [args.gpus[i] for i in range(args.workers)]
    else:
        devices = list(range(args.workers))
    '''
          设置数据并行
    '''
    net = torch.nn.DataParallel(net.cuda())
    net.eval()
    '''
         返回一个list,元素形式为tuple:(index,data,label)
    '''
    data_gen = enumerate(data_loader)

    if total_num is None:
        total_num = len(data_loader.dataset)
    else:
        assert total_num == len(data_loader.dataset)

    data_iter_list.append(data_gen) #data部分[1,K,batch_size,test_crops,test_segments,224,224,3]
    net_list.append(net)

3.测试:

def eval_video(video_data, net, this_test_segments, modality):
    net.eval()
    with torch.no_grad():
        i, data, label = video_data
        batch_size = label.numel() #返回数组中的元素个数
        num_crop = args.test_crops #这里我们用到了test_crops的值
        if args.dense_sample:
            num_crop *= 10  # 10 clips for testing when using dense sample #这里为什么这样操作大家也应该明白了

        if args.twice_sample:
            num_crop *= 2

        if modality == 'RGB':
            length = 3
        elif modality == 'Flow':
            length = 10
        elif modality == 'RGBDiff':
            length = 18
        else:
            raise ValueError("Unknown modality "+ modality)

        data_in = data.view(-1, length, data.size(2), data.size(3))
        if is_shift: #如果有时间位移模块,则调整输入为下列格式
            data_in = data_in.view(batch_size * num_crop, this_test_segments, length, data_in.size(2), data_in.size(3))
        rst = net(data_in) #[batch_size * num_crop,174]
        rst = rst.reshape(batch_size, num_crop, -1).mean(1) #[batch_size, 174]

        if args.softmax:
            # take the softmax to normalize the output to probability
            rst = F.softmax(rst, dim=1) #按行进行softmax

        rst = rst.data.cpu().numpy().copy()

        if net.module.is_shift:
            rst = rst.reshape(batch_size, num_class) 
        else:
            rst = rst.reshape((batch_size, -1, num_class)).mean(axis=1).reshape((batch_size, num_class))

        return i, rst, label


proc_start_time = time.time()
max_num = args.max_num if args.max_num > 0 else total_num 

top1 = AverageMeter()
top5 = AverageMeter()

for i, data_label_pairs in enumerate(zip(*data_iter_list)):#*表示降维,K控制大循环次数!!K = total / batch_size
    with torch.no_grad():
        if i >= max_num:
            break
        this_rst_list = []
        this_label = None
        for n_seg, (_, (data, label)), net, modality in zip(test_segments_list, data_label_pairs, net_list, modality_list):
            rst = eval_video((i, data, label), net, n_seg, modality)#turple返回一个turple
            this_rst_list.append(rst[1])# rst[1]表示预测的类别部分,这里我当时看了很久,这里的rst与eval_video函数中返回的rst名字一致,但其实不是一个东西
            this_label = label
        assert len(this_rst_list) == len(coeff_list) #1 = 1
        for i_coeff in range(len(this_rst_list)):
            this_rst_list[i_coeff] *= coeff_list[i_coeff]
        ensembled_predict = sum(this_rst_list) / len(this_rst_list) #sum(表示,沿最高维相加) #[batch_size,174]

        for p, g in zip(ensembled_predict, this_label.cpu().numpy()):
            output.append([p[None, ...], g]) #[[data[0],label],[[data[1]],label],...] 共total_num个元素,每个元素的尺寸[[1,174],1],注意这里每次循环添加batch_size个元素
        cnt_time = time.time() - proc_start_time
        prec1, prec5 = accuracy(torch.from_numpy(ensembled_predict), this_label, topk=(1, 5)) #详见accuracy函数
        top1.update(prec1.item(), this_label.numel())#详见AverageMeter
        top5.update(prec5.item(), this_label.numel())
        if i % 20 == 0:
            print('video {} done, total {}/{}, average {:.3f} sec/video, '
                  'moving Prec@1 {:.3f} Prec@5 {:.3f}'.format(i * args.batch_size, i * args.batch_size, total_num,
                                                              float(cnt_time) / (i+1) / args.batch_size, top1.avg, top5.avg))

video_pred = [np.argmax(x[0]) for x in output] #详见上面output的出处! 
video_pred_top5 = [np.argsort(np.mean(x[0], axis=0).reshape(-1))[::-1][:5] for x in output]

video_labels = [x[1] for x in output]

 

4.相关函数和类:

这里你必须知道的函数为pytorch中的topk,非常好用的函数

a,b = data.topk(maxk,dims,True,True),这里返回的a是data中前maxk大的元素,b是其索引,dims = 1,按列,= 0,按行!

class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1): #计算平均值
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k"""
    maxk = max(topk)
    batch_size = target.size(0) #72
    _, pred = output.topk(maxk, 1, True, True)#取指定维度上的几个最大值 第一个返回值为值,第二个为值得位置 
    pred = pred.t()#转置[5,72]
    correct = pred.eq(target.view(1, -1).expand_as(pred)) #[5,72] 由1,0构成
    res = []
    for k in topk:
         correct_k = correct[:k].view(-1).float().sum(0) #.view(-1)转换为行向量
         res.append(correct_k.mul_(100.0 / batch_size))
    return res


def parse_shift_option_from_log_name(log_name):
    if 'shift' in log_name:
        strings = log_name.split('_')
        for i, s in enumerate(strings):
            if 'shift' in s:
                break
        return True, int(strings[i].replace('shift', '')), strings[i + 1]
    else:
        return False, None, None

#仅供大家参考,有笔误或什么欢迎指正,大家交流提高 

 

  • 0
    点赞
  • 16
    收藏
    觉得还不错? 一键收藏
  • 3
    评论
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值