DBPN的PyTorch实现

DBPN的PyTorch实现

源码:https://github.com/alterzero/DBPN-Pytorch

依赖项:

  • Python 3.5
  • PyTorch >= 1.0.0

 

1、测试

1)下载解压 DBPN-Pytorch,在 models 文件夹中包含已经训练好的 9 个模型文件。

2)打开 eval.py 修改 upscale 因子,相应的参数随之修改。

parser.add_argument('--upscale_factor', type=int, default=2, help="super resolution upscale factor")
# Input文件夹中的Set5_LR_x2文件夹为输入的测试图像
parser.add_argument('--test_dataset', type=str, default='Set5_LR_x2')
parser.add_argument('--model_type', type=str, default='DBPN')
parser.add_argument('--model', default='models/DBPN_x2.pth', help='sr pretrained base model')

# parser.add_argument('--upscale_factor', type=int, default=4, help="super resolution upscale factor")
# parser.add_argument('--test_dataset', type=str, default='Set5_LR_x4')
# parser.add_argument('--model_type', type=str, default='DBPN')
# parser.add_argument('--model', default='models/DBPN_x4.pth', help='sr pretrained base model')

# parser.add_argument('--upscale_factor', type=int, default=8, help="super resolution upscale factor")
# parser.add_argument('--test_dataset', type=str, default='Set5_LR_x8')
# parser.add_argument('--model_type', type=str, default='DBPN')
# parser.add_argument('--model', default='models/DBPN_x8.pth', help='sr pretrained base model')

# parser.add_argument('--upscale_factor', type=int, default=8, help="super resolution upscale factor")
# parser.add_argument('--test_dataset', type=str, default='Set5_LR_x8')
# parser.add_argument('--model_type', type=str, default='DBPN-RES-MR64-3')
# parser.add_argument('--model', default='models/DBPN-RES-MR64-3_8x.pth', help='sr pretrained base model')

3)出现如下错误时,打开 eval.py 修改最后一行 “ eval() ”。

RuntimeError: 
        An attempt has been made to start a new process before the
        current process has finished its bootstrapping phase.

        This probably means that you are not using fork to start your
        child processes and you have forgotten to use the proper idiom
        in the main module:

            if __name__ == '__main__':
                freeze_support()
                ...

        The "freeze_support()" line can be omitted if the program
        is not going to be frozen to produce an executable.

# eval()
def main():
    with torch.no_grad():
        eval()


# ##Eval Start!!!!
if __name__ == '__main__':
    main()

4)测试结果。

Namespace(chop_forward=False, gpu_mode=True, gpus=1, input_dir='Input', model='models/DBPN_x2.pth', model_type='DBPN', output='Results/', residual=False, seed=123, self_ensemble=False, testBatchSize=1, test_dataset='Set5_LR_x2', threads=1, upscale_factor=2)
===> Loading datasets
===> Building model
Pre-trained SR model is loaded.
===> Processing: baby_x2.png || Timer: 0.7450 sec.
===> Processing: bird_x2.png || Timer: 0.0180 sec.
===> Processing: butterfly_x2.png || Timer: 0.0170 sec.
===> Processing: head_x2.png || Timer: 0.0170 sec.
===> Processing: woman_x2.png || Timer: 0.0180 sec.

input:          result:

5)问题:使用 DBPN-RES-MR64-3 时,测试结果是黑色图像。

 

2、训练

1)在 DBPN-Pytorch 文件夹中新建 Dataset 文件夹和 weights 文件夹,在 Dataset 文件夹中新建 DIV2K_train_HR 文件夹存放训练图像。

2)打开 main.py 修改 upscale 因子。

parser.add_argument('--upscale_factor', type=int, default=8, help="super resolution upscale factor")
parser.add_argument('--model_type', type=str, default='DBPN')

3)如遇到 1、测试中 “ 3)”这个问题,打开 main.py 进行如下修改。

# 原程序
for epoch in range(opt.start_iter, opt.nEpochs + 1):
    train(epoch)

    # learning rate is decayed by a factor of 10 every half of total epochs
    if (epoch+1) % (opt.nEpochs/2) == 0:
        for param_group in optimizer.param_groups:
            param_group['lr'] /= 10.0
        print('Learning rate decay: lr={}'.format(optimizer.param_groups[0]['lr']))

    if (epoch+1) % (opt.snapshots) == 0:
        checkpoint(epoch)



# 修改如下,放到代码最后
def main():
    for epoch in range(opt.start_iter, opt.nEpochs + 1):
        train(epoch)

        # learning rate is decayed by a factor of 10 every half of total epochs
        if epoch % (opt.nEpochs/2) == 0:
            for param_group in optimizer.param_groups:
                param_group['lr'] /= 10.0
            print('Learning rate decay: lr={}'.format(optimizer.param_groups[0]['lr']))

        if epoch % opt.snapshots == 0:
            checkpoint(epoch)

    # test()


if __name__ == '__main__':
    main()

4)如遇到以下问题,修改 main.py 中的 patch_size。

raise ValueError("empty range for randrange() (%d,%d, %d)" % (istart, istop, width))
ValueError: empty range for randrange() (0,-7, -7)

parser.add_argument('--patch_size', type=int, default=10, help='Size of cropped HR image')  # 原始默认值为40

5)训练结果,训练 epoch 设为 2000,每隔500次保存一个模型文件。

----------------------------------------------
===> Epoch[6](1/18): Loss: 0.0060 || Timer: 0.0170 sec.
===> Epoch[6](2/18): Loss: 0.0292 || Timer: 0.0160 sec.
===> Epoch[6](3/18): Loss: 0.0614 || Timer: 0.0159 sec.
===> Epoch[6](4/18): Loss: 0.0283 || Timer: 0.0150 sec.
===> Epoch[6](5/18): Loss: 0.0444 || Timer: 0.0160 sec.
===> Epoch[6](6/18): Loss: 0.0182 || Timer: 0.0170 sec.
===> Epoch[6](7/18): Loss: 0.1083 || Timer: 0.0159 sec.
===> Epoch[6](8/18): Loss: 0.1186 || Timer: 0.0160 sec.
===> Epoch[6](9/18): Loss: 0.0719 || Timer: 0.0160 sec.
===> Epoch[6](10/18): Loss: 0.0572 || Timer: 0.0160 sec.
===> Epoch[6](11/18): Loss: 0.0184 || Timer: 0.0160 sec.
===> Epoch[6](12/18): Loss: 0.0236 || Timer: 0.0170 sec.
===> Epoch[6](13/18): Loss: 0.0804 || Timer: 0.0159 sec.
===> Epoch[6](14/18): Loss: 0.0251 || Timer: 0.0159 sec.
===> Epoch[6](15/18): Loss: 0.0130 || Timer: 0.0150 sec.
===> Epoch[6](16/18): Loss: 0.1449 || Timer: 0.0159 sec.
===> Epoch[6](17/18): Loss: 0.0234 || Timer: 0.0159 sec.
===> Epoch[6](18/18): Loss: 0.0827 || Timer: 0.0150 sec.
===> Epoch 6 Complete: Avg. Loss: 0.0531

parser.add_argument('--nEpochs', type=int, default=2000, help='number of epochs to train for')  # 2000
parser.add_argument('--snapshots', type=int, default=500, help='Snapshots')  # 50

  • 5
    点赞
  • 20
    收藏
    觉得还不错? 一键收藏
  • 18
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 18
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值