DVS-SR---test.py

1、导入包

import os
import argparse

import models
import datasets

import torch.nn
import torch.backends.cudnn as cudnn
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

from tqdm import tqdm

cudnn.benchmark = True

2、设置参数:设置检查点路径、设置测试数据路径、设置输出图像保存路径、设置线程数

parser = argparse.ArgumentParser()
parser.add_argument('--checkpoint_path', type=str, help='Location to save checkpoint models')
parser.add_argument('--data_dir', type=str, help='Testing data directory')
parser.add_argument('--save_dir', type=str, help='Location to save output images')
parser.add_argument('--num_workers', type=int, default=0, help='Number of workers for pytorch DataLoader')

args = parser.parse_args()

3、加载checkpoint

checkpoint = torch.load(args.checkpoint_path)

4、加载配置文件

cfg = checkpoint['config']

5、加载数据集 

test_set = datasets.SingleFolderDataset(root=args.data_dir, cfg=cfg, is_train=False)
test_loader = DataLoader(dataset=test_set, num_workers=args.num_workers,
                         batch_size=cfg.TEST.BATCH_SIZE, shuffle=False)

6、构建模型

model = models.SRNet(cfg)
model = torch.nn.DataParallel(model)
model.load_state_dict(checkpoint['model'])
model = model.cuda()

 7、测试

with torch.no_grad():
    model.eval()
    for batch in tqdm(test_loader):
        central_stack = batch['central_stack'].cuda()
        side_stack = [data.cuda() for data in batch['side_stack']]
        flow = [data.cuda() for data in batch['flow']]
        video_name = batch['video_name']
        image_name = batch['image_name']

        prediction = model(central_stack, side_stack, flow)

        for idx in range(prediction.size(0)):
            cur_pred = prediction[idx].clamp(0.0, 1.0)
            cur_video_name = video_name[idx]
            cur_image_name = image_name[idx]

            generated_image = transforms.ToPILImage()(cur_pred.cpu()).convert('L')
            os.makedirs(os.path.join(args.save_dir, cur_video_name), exist_ok=True)
            generated_image.save(os.path.join(args.save_dir, cur_video_name, '%s.png' % cur_image_name))

为何train.py与test.py的最后部分代码相同?

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值