最近看完“GRL for Image Restoration”论文后想要尝试用自己的数据训练一下模型,但是实际操作下来发现整个过程还是踩了非常多的坑的,所以记录一下便于需要的朋友使用,项目链接:GRL_github
1.参考README配置环境
参考作者给出的README(个人觉得还是有点简洁),配置环境,准备好数据集,下载模型并成功执行测试代码。在整个过程中我觉得有几个需要注意的点:
(1) 代码里所有需要设置的参数全部在config文件夹里,但是针对不同的实验需要修改的配置不同,以dn实验base模型为例,你可能需要用到的yaml大致为:
(2) 在Ubuntu配置环境并运行,Windwos配置完环境执行代码后会报一个什么os缺少模块的错误,找了半天都说Windows目前不适配。
(3) 测试数据集我选用的是Mcmaster(下载自取),同时还需要下载对应的image_info(下载自取),此时你还要修改grl_p256.yaml中的数据集名称。
(4) 测试过程如果出现与数据集有关的错误,这个时候需要检查restoration_dn.py代码里返回的img_info路径以及对应的图像路径是否正确。
2.准备自己的数据集
还是以dn实验为例,原代码里输入只有一张高清图(target),对高清图随机加噪作为(source),我想用真实的数据对进行训练,所以我的数据集包括source跟target两个部分,source为真实的带噪声图,target为高清图,具体数据集结构如下:
3.修改数据加载函数
原代码里是对高清图自动加噪,所以加载部分只加载一张图作为label,然后自动加噪后作为网络的input,并且需要从json里读取图像位置信息。为了方便训练自己本地的数据对,首先在训练数据加载处(restoration_dn.py)新增一个custom数据的加载模式get_custom_train_file(),同理测试加载函数也要同样修改。
继续修改图像读取方式,保证source跟target图像经过相同的预处理(原始代码里有随机翻转、切割等,这里需要保证一致)
4.开始训练
修改grl_p256.yaml中的数据集名称为custom,调整图像尺寸(我是用128的patch进行训练的,调整patch的同时需要同时调整stripe的尺寸),下图为我训练时的配置
执行以下代码进行训练,j是cpu线程,gpus是显卡数目。
torchx run -- -j 1x2 -- \
-m training=True gpus=2 experiment=dn/grl model=grl/grl_base\
训练注意事项:由于模型用了多层的transformer叠加,训练过程中对于显存的要求比较高,而且训练速度比较慢,在grl_small模型,patch_size=256,stripe_size=128×64,window_size=16时,占用了34GB左右的显存。大家训练的时候可根据自己的GPU显存情况选择对应的模型,或者将上述几个size的尺寸为减半。推荐的是使用grl_base模型,patch_size=128,stripe_size=64×32,window_size=16,大约占用16GB左右的显存
5.总结
写的可能有点乱,有些细节由于时间的原因没写详细,有问题欢迎大家提出。