目录结构:
配置文件,里面包含了一些超参数和模型设置的选项。
- "batch_size": 32,代表训练时每个batch使用的样本数。
- "patch_size": 256,代表将原始图像切成多少大小为256x256的小块进行训练。
- "valid_mode": "test",代表验证集使用的数据集,这里是test。
- "edge_decay": 0,代表去除图像边缘的程度,这里为0。
- "only_h_flip": false,代表是否只使用水平翻转数据扩增,这里为false。
- "optimizer": "adamw",代表使用的优化器,这里是AdamW。
- "lr": 4e-4/,代表学习率,这里是4e-4或2e-4。
- "epochs":300,代表训练的总epoch数。
- "eval_freq": 1,代表每多少个epoch输出一次验证集的结果,输出psnr和ssim的值。
自己训练的话,创建一个配置文件来训练。
loader.py实现了一个数据加载器。其中包括两个类:PairLoader
和 SingleLoader。
PairLoader
类用于对图像进行配对加载,即同时加载原始图像和目标图像,原始图像存放在 hazy
目录下,目标图像存放在 GT
目录下。其中 data_dir
是数据集的根目录,sub_dir
是数据集子目录(比如训练集、验证集等),mode
是数据加载模式,支持三种模式:训练模式、验证模式和测试模式。size
是图像的尺寸,edge_decay
是边缘衰减系数,only_h_flip
表示是否只进行水平翻转。
SingleLoader
类用于单独加载图像,其中 root_dir
是图像所在目录。
在 PairLoader
类中,__len__
方法返回数据集中样本数目,__getitem__
方法根据索引 idx
返回对应的数据样本。在此方法中,首先使用 cv2
库读取原始图像和目标图像,并将像素值的范围从 [0, 1]
转换为 [-1, 1]
。然后,如果是训练模式,调用 augment
函数进行数据增强;如果是验证模式,调用 align
函数对图像进行裁剪居中。最后,将处理后的图像和文件名返回。
augment
函数用于进行数据增强,包括对图像进行随机裁剪、水平翻转和旋转等操作。其中 size
是目标图像的尺寸,edge_decay
是边缘衰减系数,only_h_flip
表示是否只进行水平翻转。
align
函数用于将图像进行居中裁剪,以适应模型输入要求。
hwc_to_chw
函数用于将图像从 HWC
(高度、宽度、通道)格式转换为 CHW
(通道、高度、宽度)格式。
read_img
函数用于读取图像文件,并将像素值的范围从 [0, 255]
转换为 [0, 1]
。
arch.png网络结构图
dehazefoemer网络的定义:
-
PatchEmbed 和 PatchUnEmbed 类:用于将输入图像分割成非重叠的图像块(patch),并将这些图像块重新组合成输出图像。
-
SKFusion 类:实现了一种特征融合机制,通过计算特征的均值和通过MLP学习到的权重来融合输入的多个特征图。
-
RLN 类:Revised LayerNorm,是一个自定义的归一化层,用于对输入进行归一化处理。
-
Mlp 类:多层感知机(MLP)模块,用于在Transformer中对输入特征进行非线性变换和映射。
-
WindowAttention 类:基于窗口的注意力机制,用于在Transformer中计算注意力权重。
-
Attention 和 TransformerBlock 类:实现了Transformer中的自注意力机制和前馈网络部分。
-
BasicLayer 类:包含多个 TransformerBlock,用于构建整个Transformer的层。
-
DehazeFormer 类:整个图像去雾模型的定义,包括了上述的各种自定义层和模块,以及多个 BasicLayer 组成的主干部分。
-
最后是一些工厂函数 dehazeformer_t、dehazeformer_s 等,用于创建不同配置的 DehazeFormer 模型实例。
common.py:定义了一些在图像处理中常用的辅助函数和类。
-
AverageMeter
类是一个用于计算平均值的计量器。它具有以下方法:reset()
:重置计量器的值update(val, n=1)
:更新计量器的值,val
表示要更新的数值,n
表示该数值的权重,默认为1
-
ListAverageMeter
类是一个用于计算列表中元素的平均值的计量器。它具有以下方法:reset()
:重置计量器的值set_len(n)
:设置列表的长度为n
,并重置计量器的值update(vals, n=1)
:更新计量器的值,vals
是一个长度为n
的列表,表示要更新的数值,n
表示每个数值的权重,默认为1
-
read_img(filename)
函数用于读取图像文件,并返回一个经过归一化处理的图像数组。它使用OpenCV库的imread()
函数读取图像,并将图像通道顺序从BGR转换为RGB,并进行归一化处理。 -
write_img(filename, img)
函数用于将图像数组保存为图像文件。它将图像数组重新转换为BGR通道顺序,并使用OpenCV库的imwrite()
函数将图像保存为指定文件名的图像文件。 -
hwc_to_chw(img)
函数用于将图像数组的通道顺序从HWC(高度、宽度、通道)转换为CHW(通道、高度、宽度)。 -
chw_to_hwc(img)
函数用于将图像数组的通道顺序从CHW转换为HWC。
这些函数和类可以在图像处理任务中使用,例如计算图像的平均值、读取和保存图像文件,以及在深度学习模型中转换图像的通道顺序。
data_parallel.py:这段代码是一个自定义的BalancedDataParallel类,继承自torch.nn.DataParallel类。它主要用于在多个GPU上平衡地并行处理模型的输入数据。
代码中定义了scatter函数和scatter_kwargs函数,这两个函数用于将输入数据切分成多个块,并将这些块分布到指定的GPU上。scatter函数用于普通的输入数据,而scatter_kwargs函数用于带有关键字参数的输入数据。
BalancedDataParallel类重写了DataParallel类的forward()方法,实现了在多个GPU上平衡地并行执行模型的前向传播操作。具体来说,代码中的forward()方法首先根据GPU数量和指定的gpu0_bsz参数,决定使用哪些GPU进行计算。然后,根据选择的GPU数量,将输入数据切分成多个块,并复制模型的副本到相应的GPU上。接下来,使用parallel_apply函数在各个GPU上并行地应用模型的副本,得到每个GPU上的输出结果。最后,使用gather函数将各个GPU上的输出结果收集起来,并返回最终的输出结果。
总的来说,这段代码实现了一个平衡地在多个GPU上进行并行计算的功能,可以提高模型训练和推理的效率。
run.sh 测试脚本
test.py用于测试图像去雾模型的 Python 程序,其主要功能如下:
-
定义了命令行参数,包括模型名称、数据集路径、结果保存路径等。用户可以通过命令行修改这些参数。
-
定义了一个函数
single()
,用于处理从保存的模型中加载的状态字典。在 PyTorch 中使用多 GPU 训练模型时,保存的状态字典中包含了每个 GPU 的参数以及一些额外信息,这可能会导致在单 GPU 上加载模型时出现问题。该函数的作用就是将保存的状态字典中前缀为module.
的字符串删除。 -
定义了一个函数
test()
,用于测试模型并保存结果。该函数的输入参数包括数据加载器、模型、结果保存路径等。该函数首先将模型设置为评估模式,并使用测试数据集逐个样本地进行测试。对于每个样本,该函数首先将输入图像输入模型并得到输出图像,然后计算输出图像与目标图像的 PSNR 和 SSIM,最后将结果保存到文件中。 -
在
__main__
函数中,创建模型实例并加载已训练的模型,然后创建数据集并调用test()
函数进行测试。在测试过程中,程序会打印出当前测试进度和每个测试样本的 PSNR 和 SSIM。 -
其他辅助函数包括
write_img()
和chw_to_hwc()
,用于保存图像和从 CHW 格式转换为 HWC 格式。
train.py 是一个训练脚本。主要包括以下几个部分:
- 导入所需的库和模块:包括os、argparse、json、torch等。
- 解析命令行参数:使用argparse模块解析用户输入的命令行参数,包括模型名称、数据集路径、日志保存路径等。
- 设置GPU可见性:通过设置
os.environ['CUDA_VISIBLE_DEVICES']
来指定使用的GPU设备。 - 定义训练函数:
train()
函数用于执行一个训练步骤,包括前向传播、计算损失、反向传播更新参数等操作。 - 定义验证函数:
valid()
函数用于在验证集上评估模型的性能,计算平均峰值信噪比(PSNR)。 - 主函数:加载配置文件,创建模型、优化器和学习率调度器,加载数据集,进行训练和验证,并保存最佳模型。
总体来说,这段代码实现了一个基于PyTorch的模型训练框架,用于训练图像去雾模型。
运行指令python test.py --model dehazeformer-b --dataset RESIDE-IN --exp indoor
输出格式为
不懂的点
test[ ]和epoch的不同在哪里,为什么输出test有500次,prne和ssim括号外的数值和括号里的数值分别表示的是:当次测试的值和当前测试的平均值吗?