Vision Transformers for Single Image Dehazing代码解读

目录结构:

配置文件,里面包含了一些超参数和模型设置的选项。

  • "batch_size": 32,代表训练时每个batch使用的样本数。
  • "patch_size": 256,代表将原始图像切成多少大小为256x256的小块进行训练。
  • "valid_mode": "test",代表验证集使用的数据集,这里是test。
  • "edge_decay": 0,代表去除图像边缘的程度,这里为0。
  • "only_h_flip": false,代表是否只使用水平翻转数据扩增,这里为false。
  • "optimizer": "adamw",代表使用的优化器,这里是AdamW。
  • "lr": 4e-4/,代表学习率,这里是4e-4或2e-4。
  • "epochs":300,代表训练的总epoch数。
  • "eval_freq": 1,代表每多少个epoch输出一次验证集的结果,输出psnr和ssim的值。

自己训练的话,创建一个配置文件来训练。

loader.py实现了一个数据加载器。其中包括两个类:PairLoaderSingleLoader。

PairLoader用于对图像进行配对加载,即同时加载原始图像和目标图像,原始图像存放在 hazy 目录下,目标图像存放在 GT 目录下。其中 data_dir 是数据集的根目录,sub_dir 是数据集子目录(比如训练集、验证集等),mode 是数据加载模式,支持三种模式:训练模式、验证模式和测试模式。size 是图像的尺寸,edge_decay 是边缘衰减系数,only_h_flip 表示是否只进行水平翻转。

SingleLoader用于单独加载图像,其中 root_dir 是图像所在目录。

PairLoader 类中,__len__ 方法返回数据集中样本数目,__getitem__ 方法根据索引 idx 返回对应的数据样本。在此方法中,首先使用 cv2 库读取原始图像和目标图像,并将像素值的范围从 [0, 1] 转换为 [-1, 1]。然后,如果是训练模式,调用 augment 函数进行数据增强;如果是验证模式,调用 align 函数对图像进行裁剪居中。最后,将处理后的图像和文件名返回。

augment 函数用于进行数据增强,包括对图像进行随机裁剪、水平翻转和旋转等操作。其中 size 是目标图像的尺寸,edge_decay 是边缘衰减系数,only_h_flip 表示是否只进行水平翻转。

align 函数用于将图像进行居中裁剪,以适应模型输入要求。

hwc_to_chw 函数用于将图像从 HWC(高度、宽度、通道)格式转换为 CHW(通道、高度、宽度)格式。

read_img 函数用于读取图像文件,并将像素值的范围从 [0, 255] 转换为 [0, 1]

arch.png网络结构图

dehazefoemer网络的定义:

  1. PatchEmbed 和 PatchUnEmbed 类:用于将输入图像分割成非重叠的图像块(patch),并将这些图像块重新组合成输出图像。

  2. SKFusion 类:实现了一种特征融合机制,通过计算特征的均值和通过MLP学习到的权重来融合输入的多个特征图。

  3. RLN 类:Revised LayerNorm,是一个自定义的归一化层,用于对输入进行归一化处理。

  4. Mlp 类:多层感知机(MLP)模块,用于在Transformer中对输入特征进行非线性变换和映射。

  5. WindowAttention 类:基于窗口的注意力机制,用于在Transformer中计算注意力权重。

  6. Attention 和 TransformerBlock 类:实现了Transformer中的自注意力机制和前馈网络部分。

  7. BasicLayer 类:包含多个 TransformerBlock,用于构建整个Transformer的层。

  8. DehazeFormer 类:整个图像去雾模型的定义,包括了上述的各种自定义层和模块,以及多个 BasicLayer 组成的主干部分。

  9. 最后是一些工厂函数 dehazeformer_t、dehazeformer_s 等,用于创建不同配置的 DehazeFormer 模型实例。

common.py:定义了一些在图像处理中常用的辅助函数和类。

  1. AverageMeter类是一个用于计算平均值的计量器。它具有以下方法:

    • reset():重置计量器的值
    • update(val, n=1):更新计量器的值,val表示要更新的数值,n表示该数值的权重,默认为1
  2. ListAverageMeter类是一个用于计算列表中元素的平均值的计量器。它具有以下方法:

    • reset():重置计量器的值
    • set_len(n):设置列表的长度为n,并重置计量器的值
    • update(vals, n=1):更新计量器的值,vals是一个长度为n的列表,表示要更新的数值,n表示每个数值的权重,默认为1
  3. read_img(filename)函数用于读取图像文件,并返回一个经过归一化处理的图像数组。它使用OpenCV库的imread()函数读取图像,并将图像通道顺序从BGR转换为RGB,并进行归一化处理。

  4. write_img(filename, img)函数用于将图像数组保存为图像文件。它将图像数组重新转换为BGR通道顺序,并使用OpenCV库的imwrite()函数将图像保存为指定文件名的图像文件。

  5. hwc_to_chw(img)函数用于将图像数组的通道顺序从HWC(高度、宽度、通道)转换为CHW(通道、高度、宽度)。

  6. chw_to_hwc(img)函数用于将图像数组的通道顺序从CHW转换为HWC。

这些函数和类可以在图像处理任务中使用,例如计算图像的平均值、读取和保存图像文件,以及在深度学习模型中转换图像的通道顺序。

data_parallel.py:这段代码是一个自定义的BalancedDataParallel类,继承自torch.nn.DataParallel类。它主要用于在多个GPU上平衡地并行处理模型的输入数据。

代码中定义了scatter函数和scatter_kwargs函数,这两个函数用于将输入数据切分成多个块,并将这些块分布到指定的GPU上。scatter函数用于普通的输入数据,而scatter_kwargs函数用于带有关键字参数的输入数据。

BalancedDataParallel类重写了DataParallel类的forward()方法,实现了在多个GPU上平衡地并行执行模型的前向传播操作。具体来说,代码中的forward()方法首先根据GPU数量和指定的gpu0_bsz参数,决定使用哪些GPU进行计算。然后,根据选择的GPU数量,将输入数据切分成多个块,并复制模型的副本到相应的GPU上。接下来,使用parallel_apply函数在各个GPU上并行地应用模型的副本,得到每个GPU上的输出结果。最后,使用gather函数将各个GPU上的输出结果收集起来,并返回最终的输出结果。

总的来说,这段代码实现了一个平衡地在多个GPU上进行并行计算的功能,可以提高模型训练和推理的效率。

run.sh 测试脚本

test.py用于测试图像去雾模型的 Python 程序,其主要功能如下:

  1. 定义了命令行参数,包括模型名称、数据集路径、结果保存路径等。用户可以通过命令行修改这些参数。

  2. 定义了一个函数 single(),用于处理从保存的模型中加载的状态字典。在 PyTorch 中使用多 GPU 训练模型时,保存的状态字典中包含了每个 GPU 的参数以及一些额外信息,这可能会导致在单 GPU 上加载模型时出现问题。该函数的作用就是将保存的状态字典中前缀为 module. 的字符串删除。

  3. 定义了一个函数 test(),用于测试模型并保存结果。该函数的输入参数包括数据加载器、模型、结果保存路径等。该函数首先将模型设置为评估模式,并使用测试数据集逐个样本地进行测试。对于每个样本,该函数首先将输入图像输入模型并得到输出图像,然后计算输出图像与目标图像的 PSNR 和 SSIM,最后将结果保存到文件中。

  4. __main__ 函数中,创建模型实例并加载已训练的模型,然后创建数据集并调用 test() 函数进行测试。在测试过程中,程序会打印出当前测试进度和每个测试样本的 PSNR 和 SSIM。

  5. 其他辅助函数包括 write_img()chw_to_hwc(),用于保存图像和从 CHW 格式转换为 HWC 格式。

train.py 是一个训练脚本。主要包括以下几个部分:

  1. 导入所需的库和模块:包括os、argparse、json、torch等。
  2. 解析命令行参数:使用argparse模块解析用户输入的命令行参数,包括模型名称、数据集路径、日志保存路径等。
  3. 设置GPU可见性:通过设置os.environ['CUDA_VISIBLE_DEVICES']来指定使用的GPU设备。
  4. 定义训练函数:train()函数用于执行一个训练步骤,包括前向传播、计算损失、反向传播更新参数等操作。
  5. 定义验证函数:valid()函数用于在验证集上评估模型的性能,计算平均峰值信噪比(PSNR)。
  6. 主函数:加载配置文件,创建模型、优化器和学习率调度器,加载数据集,进行训练和验证,并保存最佳模型。

总体来说,这段代码实现了一个基于PyTorch的模型训练框架,用于训练图像去雾模型。

运行指令python test.py --model dehazeformer-b --dataset RESIDE-IN --exp indoor

输出格式为

不懂的点

test[ ]和epoch的不同在哪里,为什么输出test有500次,prne和ssim括号外的数值和括号里的数值分别表示的是:当次测试的值和当前测试的平均值吗?

  • 21
    点赞
  • 27
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
回答: 本文提出了一种名为EfficientFormerV2的高效网络,旨在重新思考Vision Transformers以实现与MobileNet相当的模型大小和速度。作者结合了细粒度联合搜索策略,通过一系列的设计和优化,使EfficientFormerV2在相同参数量和延迟下比MobileNetV2在ImageNet验证集上的性能高出4个百分点。\[1\]该网络的设计考虑了资源受限型硬件的需求,特别关注模型的参数量和延迟,以适应端侧部署的场景。\[2\]如果您对EfficientFormerV2感兴趣,可以通过扫描二维码或添加微信号CVer222来获取论文和代码,并申请加入CVer-Transformer微信交流群。此外,CVer学术交流群也提供了其他垂直方向的讨论,包括目标检测、图像分割、目标跟踪、人脸检测和识别等多个领域。\[3\] #### 引用[.reference_title] - *1* *3* [更快更强!EfficientFormerV2来了!一种新的轻量级视觉Transformer](https://blog.csdn.net/amusi1994/article/details/128379490)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v91^control_2,239^v3^insert_chatgpt"}} ] [.reference_item] - *2* [EfficientFormerV2: Transformer家族中的MobileNet](https://blog.csdn.net/CVHub/article/details/129739986)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v91^control_2,239^v3^insert_chatgpt"}} ] [.reference_item] [ .reference_list ]

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值