GRL(CVPR2023图像修复)训练自己的数据集

文章讲述了作者在尝试GRL论文中ImageRestoration模型训练时遇到的问题及解决方案。包括环境配置、数据集准备、代码修改,特别是Windows环境下不兼容问题和GPU内存管理。作者提供了训练自己的数据集的方法,并分享了训练过程中的显存优化策略。
摘要由CSDN通过智能技术生成

最近看完“GRL for Image Restoration”论文后想要尝试用自己的数据训练一下模型,但是实际操作下来发现整个过程还是踩了非常多的坑的,所以记录一下便于需要的朋友使用,项目链接:GRL_github

1.参考README配置环境

在这里插入图片描述
参考作者给出的README(个人觉得还是有点简洁),配置环境,准备好数据集,下载模型并成功执行测试代码。在整个过程中我觉得有几个需要注意的点:
(1) 代码里所有需要设置的参数全部在config文件夹里,但是针对不同的实验需要修改的配置不同,以dn实验base模型为例,你可能需要用到的yaml大致为:
在这里插入图片描述
(2) 在Ubuntu配置环境并运行,Windwos配置完环境执行代码后会报一个什么os缺少模块的错误,找了半天都说Windows目前不适配
(3) 测试数据集我选用的是Mcmaster(下载自取),同时还需要下载对应的image_info(下载自取),此时你还要修改grl_p256.yaml中的数据集名称。
在这里插入图片描述
(4) 测试过程如果出现与数据集有关的错误,这个时候需要检查restoration_dn.py代码里返回的img_info路径以及对应的图像路径是否正确。
在这里插入图片描述

2.准备自己的数据集

还是以dn实验为例,原代码里输入只有一张高清图(target),对高清图随机加噪作为(source),我想用真实的数据对进行训练,所以我的数据集包括source跟target两个部分,source为真实的带噪声图,target为高清图,具体数据集结构如下:
在这里插入图片描述

3.修改数据加载函数

原代码里是对高清图自动加噪,所以加载部分只加载一张图作为label,然后自动加噪后作为网络的input,并且需要从json里读取图像位置信息。为了方便训练自己本地的数据对,首先在训练数据加载处(restoration_dn.py)新增一个custom数据的加载模式get_custom_train_file(),同理测试加载函数也要同样修改。
在这里插入图片描述
在这里插入图片描述
继续修改图像读取方式,保证source跟target图像经过相同的预处理(原始代码里有随机翻转、切割等,这里需要保证一致)
在这里插入图片描述

4.开始训练

修改grl_p256.yaml中的数据集名称为custom,调整图像尺寸(我是用128的patch进行训练的,调整patch的同时需要同时调整stripe的尺寸),下图为我训练时的配置
在这里插入图片描述
执行以下代码进行训练,j是cpu线程,gpus是显卡数目。

 torchx run -- -j 1x2 -- \
     -m training=True gpus=2 experiment=dn/grl model=grl/grl_base\

训练注意事项:由于模型用了多层的transformer叠加,训练过程中对于显存的要求比较高,而且训练速度比较慢,在grl_small模型,patch_size=256,stripe_size=128×64,window_size=16时,占用了34GB左右的显存。大家训练的时候可根据自己的GPU显存情况选择对应的模型,或者将上述几个size的尺寸为减半。推荐的是使用grl_base模型,patch_size=128,stripe_size=64×32,window_size=16,大约占用16GB左右的显存

5.总结

写的可能有点乱,有些细节由于时间的原因没写详细,有问题欢迎大家提出。

参考资源链接:[DANN迁移训练实战:MNIST与MNIST-M数据集应用](https://wenku.csdn.net/doc/1jca3czt2g?utm_source=wenku_answer2doc_content) 深度域适配(DANN)中的梯度反转层(GRL)是实现无监督领域适应的关键技术之一。它通过在反向传播过程中对源域数据施加一个负权重,强制模型在学习源域特征的同时,增加对目标域的适应性。这种机制让模型在保留源域有用知识的同时,也能够学习到目标域的特定特征,从而减少两个域之间的分布差异。 具体来说,在DANN架构中,GRL通常位于源域特征提取器和领域分类器之间。在前向传播过程中,数据正常传递,而到了反向传播阶段,GRL会对来自源域的梯度施加负号,反转梯度的方向。这样,当梯度更新参数时,会朝着减少源域分类器判别能力的方向进行,这有助于模型提取更加领域无关的特征。 在MNIST与MNIST-M数据集迁移学习的上下文中,MNIST-M数据集是通过将MNIST数据集的手写数字图像与BSDS500数据集中的真实背景图像融合得到的。这样的数据集转换增加了域间迁移的难度。DANN通过GRL可以有效地利用源域MNIST的数据来训练模型,同时确保模型不会过度拟合源域的特性,而是学习到更加泛化的特征表示,最终提高在MNIST-M上的分类准确率。 为了深入理解这一机制,并掌握如何在MNIST和MNIST-M数据集上实现DANN,可以参考《DANN迁移训练实战:MNIST与MNIST-M数据集应用》。这份资料通过具体的实战案例和代码示例,详细讲解了如何构建DANN模型,包括GRL的集成和超参数调整,帮助你更好地将这些理论应用到实际的迁移学习任务中。 参考资源链接:[DANN迁移训练实战:MNIST与MNIST-M数据集应用](https://wenku.csdn.net/doc/1jca3czt2g?utm_source=wenku_answer2doc_content)
评论 40
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

ffffffdl

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值