PyTorch中通过训练图像去雾数据集 建立基于SFNet图像去雾算法的完整系统
文章目录
以下文字及代码仅供参考。

SFNet图像去雾算法 PyTorch 附图像去雾数据集

基于SFNet图像去雾算法的完整系统,包括环境配置、数据集准备、模型训练、优化以及界面代码

深度学习模型用于图像恢复(如去雾、超分辨率等)的详细设计。让我们深入解析这个架构的各个部分。
(a) 整体架构
整体架构展示了模型如何处理输入的降质图像并输出恢复后的图像。流程如下:
- 输入层:接收降质图像。
- 浅层特征提取:通过一个
Conv 3x3
卷积层提取浅层特征。 - ResBlock堆叠:多个残差块(ResBlocks)被串联起来,每个ResBlock内部包含复杂的特征学习机制(见© ResBlock)。这些ResBlocks负责学习更深层次的特征表示。
- 上采样与下采样:在某些ResBlocks之间,使用
Conv 1x1
进行通道调整,并通过箭头指示的上采样或下采样操作来改变特征图的空间尺寸。 - 最终恢复:经过一系列特征学习后,通过
Conv 3x3
层生成最终的恢复图像。
(b) 浅层特征提取
浅层特征提取模块主要由几个基础的卷积操作组成:
Conv 3x3
:标准的3x3卷积核,用于提取局部特征。Conv 1x1
:1x1卷积用于调整通道数,不改变空间维度。MCBF
和MDSF
:可能是特定的多尺度融合模块,用于结合不同尺度的信息。
© ResBlock
ResBlock是整个网络的核心组件,它包括:
- 多个
Conv 3x3
层,用于逐层提取特征。 Decoupler
和Modulator
模块(见(d)和(e)),用于解耦和调制特征,增强模型的表达能力。- 残差连接(用⊕符号表示),将输入直接加到输出上,有助于缓解梯度消失问题。
(d) Decoupler
Decoupler模块的作用是将输入特征分解为两部分:
GAP
(全局平均池化):获取全局信息。Split
:将特征分为两部分,分别进行不同的处理。Invert
:可能是一个逆变换操作,用于恢复或转换特征。Concat
:将处理后的特征重新拼接在一起。
(e) Modulator
Modulator模块对特征进行调制:
Sum
、GAP
、FC
(全连接层)、Concat
、Softmax
、Split
等操作共同作用,实现对特征的非线性变换和选择性增强。- 这些操作有助于模型关注更重要的特征,抑制不重要的信息。
总结
该模型通过多层次的特征提取和复杂的特征调制机制,能够有效地从降质图像中恢复出高质量的图像。其设计考虑了特征的多尺度融合、深度残差学习以及特征的动态调制,体现了现代深度学习模型在图像恢复任务中的先进性和复杂性。
1. 环境配置
首先确保你的环境中安装了必要的库:
pip install torch torchvision opencv-python pillow PyQt5
2. 数据集准备
假设你已经有了RSHAZE或其他图像去雾数据集,并且已经按照以下结构组织好:
data/
train/
hazy/
gt/
test/
hazy/
gt/
3. SFNet模型定义
这里我们简化地展示一个基础的SFNet模型定义(实际应用中请参考官方或相关论文中的具体实现):
import torch
import torch.nn as nn
class SFNet(nn.Module):
def __init__(self):
super(SFNet, self).__init__()
self.encoder = nn.Sequential(
nn.Conv2d(3, 64, kernel_size=3, padding=1),
nn.ReLU(),
nn.Conv2d(64, 128, kernel_size=3, padding=1),
nn.ReLU()
)
self.decoder = nn.Sequential(
nn.Conv2d(128, 64, kernel_size=3, padding=1),
nn.ReLU(),
nn.Conv2d(64, 3, kernel_size=3, padding=1)
)
def forward(self, x):
x = self.encoder(x)
x = self.decoder(x)
return x
4. 数据加载与预处理
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import os
from torchvision import transforms
class DehazeDataset(Dataset):
def __init__(self, hazy_dir, gt_dir, transform=None):
self.hazy_images = sorted([os.path.join(hazy_dir, img) for img in os.listdir(hazy_dir)])
self.gt_images = sorted([os.path.join(gt_dir, img) for img in os.listdir(gt_dir)])
self.transform = transform
def __len__(self):
return len(self.hazy_images)
def __getitem__(self, idx):
hazy_image = Image.open(self.hazy_images[idx]).convert('RGB')
gt_image = Image.open(self.gt_images[idx]).convert('RGB')
if self.transform:
hazy_image = self.transform(hazy_image)
gt_image = self.transform(gt_image)
return hazy_image, gt_image
transform = transforms.Compose([
transforms.Resize((256, 256)),
transforms.ToTensor()
])
train_dataset = DehazeDataset('data/train/hazy', 'data/train/gt', transform=transform)
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
5. 模型训练
model = SFNet()
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
num_epochs = 10
for epoch in range(num_epochs):
for i, (hazy, gt) in enumerate(train_loader):
optimizer.zero_grad()
outputs = model(hazy)
loss = criterion(outputs, gt)
loss.backward()
optimizer.step()
print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], Loss: {loss.item()}')
6. 界面代码
SFNet图像去雾系统,包括训练、测试和推理(GUI界面),我们需要编写多个Python脚本文件。以下是详细的代码示例:
1. main.py
- 训练和测试脚本
import argparse
import os
import torch
import torch.nn as nn
from torchvision import transforms
from torch.utils.data import DataLoader
from PIL import Image
import numpy as np
from SFNet_model import SFNet # 假设SFNet模型定义在SFNet_model.py中
class DehazeDataset(Dataset):
def __init__(self, hazy_dir, gt_dir, transform=None):
self.hazy_images = sorted([os.path.join(hazy_dir, img) for img in os.listdir(hazy_dir)])
self.gt_images = sorted([os.path.join(gt_dir, img) for img in os.listdir(gt_dir)])
self.transform = transform
def __len__(self):
return len(self.hazy_images)
def __getitem__(self, idx):
hazy_image = Image.open(self.hazy_images[idx]).convert('RGB')
gt_image = Image.open(self.gt_images[idx]).convert('RGB')
if self.transform:
hazy_image = self.transform(hazy_image)
gt_image = self.transform(gt_image)
return hazy_image, gt_image
def train(args):
transform = transforms.Compose([
transforms.Resize((256, 256)),
transforms.ToTensor()
])
dataset = DehazeDataset(os.path.join(args.data_dir, 'train', 'hazy'),
os.path.join(args.data_dir, 'train', 'gt'),
transform=transform)
dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)
model = SFNet().cuda()
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
for epoch in range(args.num_epoch):
for i, (hazy, gt) in enumerate(dataloader):
hazy, gt = hazy.cuda(), gt.cuda()
optimizer.zero_grad()
outputs = model(hazy)
loss = criterion(outputs, gt)
loss.backward()
optimizer.step()
print(f'Epoch [{epoch+1}/{args.num_epoch}], Step [{i+1}/{len(dataloader)}], Loss: {loss.item()}')
torch.save(model.state_dict(), f'results/SFNet/{args.data}/Training-Results/Epoch_{epoch+1}.pkl')
def test(args):
transform = transforms.Compose([
transforms.Resize((256, 256)),
transforms.ToTensor()
])
dataset = DehazeDataset(os.path.join(args.data_dir, 'test', 'hazy'),
os.path.join(args.data_dir, 'test', 'gt'),
transform=transform)
dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False)
model = SFNet().cuda()
model.load_state_dict(torch.load(args.test_model))
model.eval()
with torch.no_grad():
for i, (hazy, gt) in enumerate(dataloader):
hazy, gt = hazy.cuda(), gt.cuda()
outputs = model(hazy)
if args.save_image:
for j in range(outputs.size(0)):
output_img = transforms.ToPILImage()(outputs[j].cpu())
output_img.save(f'results/SFNet/{args.data}/Test-Results/image_{i*args.batch_size+j}.png')
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='SFNet Image Dehazing')
parser.add_argument('--data_dir', type=str, required=True, help='directory of the dataset')
parser.add_argument('--data', type=str, required=True, help='dataset name')
parser.add_argument('--mode', type=str, required=True, choices=['train', 'test'], help='train or test mode')
parser.add_argument('--batch_size', type=int, default=4, help='batch size')
parser.add_argument('--learning_rate', type=float, default=2e-5, help='learning rate')
parser.add_argument('--num_epoch', type=int, default=300, help='number of epochs')
parser.add_argument('--test_model', type=str, default='', help='path to the trained model for testing')
parser.add_argument('--save_image', type=bool, default=False, help='whether to save dehazed images')
args = parser.parse_args()
if args.mode == 'train':
train(args)
elif args.mode == 'test':
test(args)
2. SFNet_model.py
- SFNet模型定义
import torch
import torch.nn as nn
class SFNet(nn.Module):
def __init__(self):
super(SFNet, self).__init__()
self.encoder = nn.Sequential(
nn.Conv2d(3, 64, kernel_size=3, padding=1),
nn.ReLU(),
nn.Conv2d(64, 128, kernel_size=3, padding=1),
nn.ReLU()
)
self.decoder = nn.Sequential(
nn.Conv2d(128, 64, kernel_size=3, padding=1),
nn.ReLU(),
nn.Conv2d(64, 3, kernel_size=3, padding=1)
)
def forward(self, x):
x = self.encoder(x)
x = self.decoder(x)
return x
3. GUI.py
- GUI界面代码
import sys
from PyQt5.QtWidgets import QApplication, QWidget, QPushButton, QVBoxLayout, QLabel, QFileDialog
from PyQt5.QtGui import QPixmap
import cv2
import numpy as np
import torch
from torchvision import transforms
from SFNet_model import SFNet
class DehazeApp(QWidget):
def __init__(self):
super().__init__()
self.initUI()
def initUI(self):
self.setWindowTitle('图像去雾')
self.setGeometry(100, 100, 800, 400)
layout = QVBoxLayout()
self.btn_select = QPushButton('选择图像', self)
self.btn_select.clicked.connect(self.select_image)
layout.addWidget(self.btn_select)
self.btn_dehaze = QPushButton('SFNet去雾', self)
self.btn_dehaze.clicked.connect(self.dehaze_image)
layout.addWidget(self.btn_dehaze)
self.image_label = QLabel(self)
layout.addWidget(self.image_label)
self.setLayout(layout)
def select_image(self):
options = QFileDialog.Options()
fileName, _ = QFileDialog.getOpenFileName(self, "选择图像", "", "Images (*.png *.xpm *.jpg *.bmp);;All Files (*)", options=options)
if fileName:
self.image_path = fileName
pixmap = QPixmap(fileName)
self.image_label.setPixmap(pixmap.scaled(400, 400))
def dehaze_image(self):
if hasattr(self, 'image_path'):
# Load and preprocess image
image = cv2.imread(self.image_path)
image = cv2.resize(image, (256, 256))
image = image / 255.0
image = np.transpose(image, (2, 0, 1))
image = torch.tensor(image, dtype=torch.float32).unsqueeze(0).cuda()
# Load pre-trained model
model = SFNet().cuda()
model.load_state_dict(torch.load('results/SFNet/Outdoor/Training-Results/Best.pkl'))
model.eval()
# Perform dehazing
with torch.no_grad():
output = model(image).squeeze().cpu().numpy()
output = np.transpose(output, (1, 2, 0))
output = (output * 255).astype(np.uint8)
# Display result
cv2.imwrite('dehazed.jpg', output)
pixmap = QPixmap('dehazed.jpg')
self.image_label.setPixmap(pixmap.scaled(400, 400))
if __name__ == '__main__':
app = QApplication(sys.argv)
ex = DehazeApp()
ex.show()
sys.exit(app.exec_())
运行步骤
-
训练模型:
python main.py --data_dir dehaze --data Outdoor --mode train --batch_size 4 --learning_rate 2e-5 --num_epoch 300
-
测试模型:
python main.py --data_dir dehaze --data Outdoor --mode test --batch_size 4 --test_model results/SFNet/Outdoor/Training-Results/Best.pkl --save_image True
-
运行GUI界面:
python GUI.py
确保所有路径正确,并根据实际情况调整参数和文件路径。