PyTorch implementation of DBPN
Source code: https://github.com/alterzero/DBPN-Pytorch
Dependencies:
- Python 3.5
- PyTorch >= 1.0.0
1. Testing
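1) Download and extract DBPN-Pytorch. The models folder already contains 9 pretrained model files.
2) Open eval.py and set the upscale factor, then change the matching parameters (test_dataset, model_type, model) to fit. The commented-out groups below hold the settings for the other pretrained models.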
parser.add_argument('--upscale_factor', type=int, default=2, help="super resolution upscale factor")
# The Set5_LR_x2 folder inside the Input folder holds the input test images
parser.add_argument('--test_dataset', type=str, default='Set5_LR_x2')
parser.add_argument('--model_type', type=str, default='DBPN')
parser.add_argument('--model', default='models/DBPN_x2.pth', help='sr pretrained base model')
# parser.add_argument('--upscale_factor', type=int, default=4, help="super resolution upscale factor")
# parser.add_argument('--test_dataset', type=str, default='Set5_LR_x4')
# parser.add_argument('--model_type', type=str, default='DBPN')
# parser.add_argument('--model', default='models/DBPN_x4.pth', help='sr pretrained base model')
# parser.add_argument('--upscale_factor', type=int, default=8, help="super resolution upscale factor")
# parser.add_argument('--test_dataset', type=str, default='Set5_LR_x8')
# parser.add_argument('--model_type', type=str, default='DBPN')
# parser.add_argument('--model', default='models/DBPN_x8.pth', help='sr pretrained base model')
# parser.add_argument('--upscale_factor', type=int, default=8, help="super resolution upscale factor")
# parser.add_argument('--test_dataset', type=str, default='Set5_LR_x8')
# parser.add_argument('--model_type', type=str, default='DBPN-RES-MR64-3')
# parser.add_argument('--model', default='models/DBPN-RES-MR64-3_8x.pth', help='sr pretrained base model')
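The four options always have to change together. As a sketch of keeping them consistent, the snippet below rebuilds only the four eval.py arguments shown above with argparse and uses a lookup table taken from those commented-out groups. build_parser, PRESETS and argv_for are hypothetical names for illustration; the real eval.py defines more options than these four.

```python
import argparse


def build_parser():
    # Only the four eval.py options shown above; the real script defines more.
    parser = argparse.ArgumentParser(description='DBPN eval options (subset)')
    parser.add_argument('--upscale_factor', type=int, default=2, help="super resolution upscale factor")
    parser.add_argument('--test_dataset', type=str, default='Set5_LR_x2')
    parser.add_argument('--model_type', type=str, default='DBPN')
    parser.add_argument('--model', default='models/DBPN_x2.pth', help='sr pretrained base model')
    return parser


# Hypothetical lookup table built from the settings above,
# keyed by (model_type, upscale_factor).
PRESETS = {
    ('DBPN', 2): ('Set5_LR_x2', 'models/DBPN_x2.pth'),
    ('DBPN', 4): ('Set5_LR_x4', 'models/DBPN_x4.pth'),
    ('DBPN', 8): ('Set5_LR_x8', 'models/DBPN_x8.pth'),
    ('DBPN-RES-MR64-3', 8): ('Set5_LR_x8', 'models/DBPN-RES-MR64-3_8x.pth'),
}


def argv_for(model_type, scale):
    # Build a command-line argument list in which all four options agree.
    test_dataset, model = PRESETS[(model_type, scale)]
    return ['--upscale_factor', str(scale), '--test_dataset', test_dataset,
            '--model_type', model_type, '--model', model]


if __name__ == '__main__':
    opt = build_parser().parse_args(argv_for('DBPN', 4))
    print(opt.upscale_factor, opt.test_dataset, opt.model_type, opt.model)
    # prints: 4 Set5_LR_x4 DBPN models/DBPN_x4.pth
```

Because eval.py reads these values through argparse, the same override should also work on the command line without editing the defaults, e.g. python eval.py --upscale_factor 4 --test_dataset Set5_LR_x4 --model_type DBPN --model models/DBPN_x4.pth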
3) If the following error appears, open eval.py and replace the last line, eval(), as shown below.
RuntimeError:
        An attempt has been made to start a new process before the
        current process has finished its bootstrapping phase.
        This probably means that you are not using fork to start your
        child processes and you have forgotten to use the proper idiom
        in the main module:
            if __name__ == '__main__':
                freeze_support()
                ...
        The "freeze_support()" line can be omitted if the program
        is not going to be frozen to produce an executable.
# Original last line, commented out:
# eval()

def main():
    with torch.no_grad():
        eval()

# ##Eval Start!!!!
if __name__ == '__main__':
    main()
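This error is Python's multiprocessing bootstrap check. On Windows the only start method is spawn, which starts each child by re-importing the main script. eval.py presumably starts worker processes through its DataLoader (the threads option appears to be passed as num_workers), so an unguarded eval() at module level runs again inside every worker. Moving the call into main() behind the __name__ guard prevents that, and torch.no_grad() additionally turns off gradient tracking during inference. The standard-library sketch below reproduces the same pattern outside PyTorch; square, run_pool and main are illustrative names, not part of the repository.

```python
import multiprocessing as mp


def square(x):
    # Work done in the child processes; it must be a top-level (importable) function.
    return x * x


def run_pool(values):
    # Force the 'spawn' start method (the only one available on Windows) so the
    # example behaves the same on every platform. Each spawned child re-imports
    # this module, so anything that starts processes must sit behind the guard.
    ctx = mp.get_context('spawn')
    with ctx.Pool(processes=2) as pool:
        return pool.map(square, values)


def main():
    print(run_pool([1, 2, 3]))  # prints [1, 4, 9]


# Without this guard, each child would call main() again while importing the
# module, which is exactly what raises "An attempt has been made to start a new
# process before the current process has finished its bootstrapping phase."
if __name__ == '__main__':
    main()
```

Running eval.py with --threads 0 should also avoid the error, since the DataLoader would then load data in the main process, but that is an untested assumption.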
4) Test results.
Namespace(chop_forward=False, gpu_mode=True, gpus=1, input_dir='Input', model='models/DBPN_x2.pth', model_type='DBPN', output='Results/', residual=False, seed=123, self_ensemble=False, testBatchSize=1, test_dataset='Set5_LR_x2', threads=1, upscale_factor=2)
===> Loading datasets
===> Building model
Pre-trained SR model is loaded.
===> Processing: baby_x2.png || Timer: 0.7450 sec.
===> Processing: bird_x2.png || Timer: 0.0180 sec.
===> Processing: butterfly_x2.png || Timer: 0.0170 sec.
===> Processing: head_x2.png || Timer: 0.0170 sec.
===> Processing: woman_x2.png || Timer: 0.0180 sec.
(Figure: input | result)
5) Problem: when DBPN-RES-MR64-3 is used, the test output is a black image.
2. Training
1) In the DBPN-Pytorch folder, create a Dataset folder and a weights folder. Inside Dataset, create a DIV2K_train_HR folder to hold the training images.
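The folder layout from this step can be created with pathlib. This is a minimal sketch: prepare_dirs is a hypothetical helper, and the folder names are simply the ones listed above (they appear to match main.py's default data and checkpoint locations, but check the script's arguments if training cannot find the images).

```python
from pathlib import Path


def prepare_dirs(repo_root):
    """Create Dataset/DIV2K_train_HR and weights under repo_root."""
    root = Path(repo_root)
    hr_dir = root / 'Dataset' / 'DIV2K_train_HR'  # high-resolution training images go here
    weights_dir = root / 'weights'                # folder for saved model files
    for d in (hr_dir, weights_dir):
        # parents=True also creates Dataset/; exist_ok=True makes reruns harmless
        d.mkdir(parents=True, exist_ok=True)
    return hr_dir, weights_dir
```

Calling prepare_dirs('DBPN-Pytorch') from the folder that contains the extracted repository creates both folders; the training images are then copied into DIV2K_train_HR by hand.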
2) Open main.py and set the upscale factor.
parser.add_argument('--upscale_factor', type=int, default=8, help="super resolution upscale factor")
parser.add_argument('--model_type', type=str, default='DBPN')
3) If you run into the same error as in step 3) of 1. Testing, modify main.py as follows.
# Original code
for epoch in range(opt.start_iter, opt.nEpochs + 1):
    train(epoch)

    # learning rate is decayed by a factor of 10 every half of total epochs
    if (epoch+1) % (opt.nEpochs/2) == 0:
        for param_group in optimizer.param_groups:
            param_group['lr'] /= 10.0
        print('Learning rate decay: lr={}'.format(optimizer.param_groups[0]['lr']))

    if (epoch+1) % (opt.snapshots) == 0:
        checkpoint(epoch)
# Modified code: wrap the loop in main() and place it at the end of main.py
# (the (epoch+1) tests are also changed to epoch)
def main():
    for epoch in range(opt.start_iter, opt.nEpochs + 1):
        train(epoch)

        # learning rate is decayed by a factor of 10 every half of total epochs
        if epoch % (opt.nEpochs/2) == 0:
            for param_group in optimizer.param_groups:
                param_group['lr'] /= 10.0
            print('Learning rate decay: lr={}'.format(optimizer.param_groups[0]['lr']))

        if epoch % opt.snapshots == 0:
            checkpoint(epoch)
        # test()

if __name__ == '__main__':
    main()
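Switching from (epoch+1) to epoch shifts when the learning rate decays and when checkpoints are saved. The sketch below replays the loop's two conditions without training so both versions can be compared. lr_schedule is a hypothetical helper; start_iter = 1 and the starting learning rate of 1e-4 in the usage line are assumed example values, while nEpochs = 2000 and snapshots = 500 are the settings from step 5).

```python
def lr_schedule(start_iter, n_epochs, snapshots, lr, shifted):
    """Replay the epoch loop above without training.

    shifted=True uses the original '(epoch+1) %' tests, shifted=False the
    modified 'epoch %' tests. Returns (decay_epochs, snapshot_epochs, final_lr).
    """
    decay_epochs, snapshot_epochs = [], []
    for epoch in range(start_iter, n_epochs + 1):
        e = epoch + 1 if shifted else epoch
        # n_epochs / 2 is a float in Python 3, exactly as in main.py
        if e % (n_epochs / 2) == 0:
            lr /= 10.0
            decay_epochs.append(epoch)
        if e % snapshots == 0:
            snapshot_epochs.append(epoch)
    return decay_epochs, snapshot_epochs, lr


if __name__ == '__main__':
    for shifted in (True, False):
        decays, snaps, final_lr = lr_schedule(1, 2000, 500, 1e-4, shifted)
        print('original' if shifted else 'modified', decays, snaps, final_lr)
```

With these values the modified loop decays at epochs 1000 and 2000 and saves at 500, 1000, 1500 and 2000, while the original decays at 999 and 1999 and saves at 499, 999, 1499 and 1999. In the modified loop the second decay happens after the last train() call, so only the first one affects training.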
4) If the following error appears, reduce patch_size in main.py.
raise ValueError("empty range for randrange() (%d,%d, %d)" % (istart, istop, width))
ValueError: empty range for randrange() (0,-7, -7)
parser.add_argument('--patch_size', type=int, default=10, help='Size of cropped HR image') # original default: 40
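A random crop typically draws its start offset with randrange(0, size - patch + 1), which has no valid values once an image side is smaller than the patch. Assuming the crop here follows that pattern, the traceback's stop value of -7 means the side was 8 pixels shorter than the patch. The sketch below reproduces the failure and computes the largest patch that fits a set of image sizes; random_crop_offset and largest_valid_patch are hypothetical helpers, and whether the crop is applied to the HR image or to the LR image downscaled by upscale_factor is an assumption to verify in the repository's dataset code (hence the divisor parameter).

```python
import random


def random_crop_offset(size, patch):
    # Choose a start index so that [start, start + patch) fits inside [0, size).
    # Raises ValueError ("empty range ...") when patch > size.
    return random.randrange(0, size - patch + 1)


def largest_valid_patch(image_sizes, divisor=1):
    # image_sizes: (width, height) of each training image.
    # divisor: set to upscale_factor if the crop turns out to be taken on the
    # downscaled LR image rather than the HR image (an assumption to check).
    return min(min(w, h) // divisor for w, h in image_sizes)


if __name__ == '__main__':
    try:
        random_crop_offset(32, 40)
    except ValueError as exc:
        print('crop failed:', exc)
    print(largest_valid_patch([(32, 48), (64, 40)]))  # prints 32
```

Any patch_size no larger than the returned value avoids the empty range for those images.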
5) Training results. The number of epochs is set to 2000 and a model file is saved every 500 epochs.
----------------------------------------------
===> Epoch[6](1/18): Loss: 0.0060 || Timer: 0.0170 sec.
===> Epoch[6](2/18): Loss: 0.0292 || Timer: 0.0160 sec.
===> Epoch[6](3/18): Loss: 0.0614 || Timer: 0.0159 sec.
===> Epoch[6](4/18): Loss: 0.0283 || Timer: 0.0150 sec.
===> Epoch[6](5/18): Loss: 0.0444 || Timer: 0.0160 sec.
===> Epoch[6](6/18): Loss: 0.0182 || Timer: 0.0170 sec.
===> Epoch[6](7/18): Loss: 0.1083 || Timer: 0.0159 sec.
===> Epoch[6](8/18): Loss: 0.1186 || Timer: 0.0160 sec.
===> Epoch[6](9/18): Loss: 0.0719 || Timer: 0.0160 sec.
===> Epoch[6](10/18): Loss: 0.0572 || Timer: 0.0160 sec.
===> Epoch[6](11/18): Loss: 0.0184 || Timer: 0.0160 sec.
===> Epoch[6](12/18): Loss: 0.0236 || Timer: 0.0170 sec.
===> Epoch[6](13/18): Loss: 0.0804 || Timer: 0.0159 sec.
===> Epoch[6](14/18): Loss: 0.0251 || Timer: 0.0159 sec.
===> Epoch[6](15/18): Loss: 0.0130 || Timer: 0.0150 sec.
===> Epoch[6](16/18): Loss: 0.1449 || Timer: 0.0159 sec.
===> Epoch[6](17/18): Loss: 0.0234 || Timer: 0.0159 sec.
===> Epoch[6](18/18): Loss: 0.0827 || Timer: 0.0150 sec.
===> Epoch 6 Complete: Avg. Loss: 0.0531
parser.add_argument('--nEpochs', type=int, default=2000, help='number of epochs to train for') # 2000
parser.add_argument('--snapshots', type=int, default=500, help='Snapshots') # 50
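As a quick sanity check on the epoch-6 log above, the Avg. Loss line is consistent with the arithmetic mean of the 18 per-iteration losses. This suggests main.py divides the summed batch losses by the number of batches, though that is an inference from the log rather than from the code. The losses below are copied from the log, where they are already rounded to 4 decimals.

```python
# Per-iteration losses copied from the Epoch[6] log above.
losses = [0.0060, 0.0292, 0.0614, 0.0283, 0.0444, 0.0182,
          0.1083, 0.1186, 0.0719, 0.0572, 0.0184, 0.0236,
          0.0804, 0.0251, 0.0130, 0.1449, 0.0234, 0.0827]

# Mean over the 18 batches of the epoch.
avg = sum(losses) / len(losses)
print('Avg. Loss: {:.4f}'.format(avg))  # prints Avg. Loss: 0.0531
```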