PyTorch中通过训练图像去雾数据集 建立基于SFNet图像去雾算法的完整系统 深度学习模型用于图像恢复(去雾、超分辨率等)解析图像去雾架构及模型。

PyTorch中通过训练图像去雾数据集 建立基于SFNet图像去雾算法的完整系统


以下文字及代码仅供参考。
在这里插入图片描述
SFNet图像去雾算法 PyTorch 附图像去雾数据集
在这里插入图片描述
基于SFNet图像去雾算法的完整系统,包括环境配置、数据集准备、模型训练、优化以及界面代码
在这里插入图片描述
深度学习模型用于图像恢复(如去雾、超分辨率等)的详细设计。让我们深入解析这个架构的各个部分。

(a) 整体架构

整体架构展示了模型如何处理输入的降质图像并输出恢复后的图像。流程如下:

  1. 输入层:接收降质图像。
  2. 浅层特征提取:通过一个Conv 3x3卷积层提取浅层特征。
  3. ResBlock堆叠:多个残差块(ResBlocks)被串联起来,每个ResBlock内部包含复杂的特征学习机制(见© ResBlock)。这些ResBlocks负责学习更深层次的特征表示。
  4. 上采样与下采样:在某些ResBlocks之间,使用Conv 1x1进行通道调整,并通过箭头指示的上采样或下采样操作来改变特征图的空间尺寸。
  5. 最终恢复:经过一系列特征学习后,通过Conv 3x3层生成最终的恢复图像。

(b) 浅层特征提取

浅层特征提取模块主要由几个基础的卷积操作组成:

  • Conv 3x3:标准的3x3卷积核,用于提取局部特征。
  • Conv 1x1:1x1卷积用于调整通道数,不改变空间维度。
  • MCBFMDSF:可能是特定的多尺度融合模块,用于结合不同尺度的信息。

© ResBlock

ResBlock是整个网络的核心组件,它包括:

  • 多个Conv 3x3层,用于逐层提取特征。
  • DecouplerModulator模块(见(d)和(e)),用于解耦和调制特征,增强模型的表达能力。
  • 残差连接(用⊕符号表示),将输入直接加到输出上,有助于缓解梯度消失问题。

(d) Decoupler

Decoupler模块的作用是将输入特征分解为两部分:

  • GAP(全局平均池化):获取全局信息。
  • Split:将特征分为两部分,分别进行不同的处理。
  • Invert:可能是一个逆变换操作,用于恢复或转换特征。
  • Concat:将处理后的特征重新拼接在一起。

(e) Modulator

Modulator模块对特征进行调制:

  • SumGAPFC(全连接层)、ConcatSoftmaxSplit等操作共同作用,实现对特征的非线性变换和选择性增强。
  • 这些操作有助于模型关注更重要的特征,抑制不重要的信息。

总结

该模型通过多层次的特征提取和复杂的特征调制机制,能够有效地从降质图像中恢复出高质量的图像。其设计考虑了特征的多尺度融合、深度残差学习以及特征的动态调制,体现了现代深度学习模型在图像恢复任务中的先进性和复杂性。

1. 环境配置

首先确保你的环境中安装了必要的库:

pip install torch torchvision opencv-python pillow PyQt5

2. 数据集准备

假设你已经有了RSHAZE或其他图像去雾数据集,并且已经按照以下结构组织好:

data/
    train/
        hazy/
        gt/
    test/
        hazy/
        gt/

3. SFNet模型定义

这里我们简化地展示一个基础的SFNet模型定义(实际应用中请参考官方或相关论文中的具体实现):

import torch
import torch.nn as nn

class SFNet(nn.Module):
    def __init__(self):
        super(SFNet, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU()
        )
        self.decoder = nn.Sequential(
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 3, kernel_size=3, padding=1)
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x

4. 数据加载与预处理

from torch.utils.data import Dataset, DataLoader
from PIL import Image
import os
from torchvision import transforms

class DehazeDataset(Dataset):
    def __init__(self, hazy_dir, gt_dir, transform=None):
        self.hazy_images = sorted([os.path.join(hazy_dir, img) for img in os.listdir(hazy_dir)])
        self.gt_images = sorted([os.path.join(gt_dir, img) for img in os.listdir(gt_dir)])
        self.transform = transform

    def __len__(self):
        return len(self.hazy_images)

    def __getitem__(self, idx):
        hazy_image = Image.open(self.hazy_images[idx]).convert('RGB')
        gt_image = Image.open(self.gt_images[idx]).convert('RGB')

        if self.transform:
            hazy_image = self.transform(hazy_image)
            gt_image = self.transform(gt_image)

        return hazy_image, gt_image

transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor()
])

train_dataset = DehazeDataset('data/train/hazy', 'data/train/gt', transform=transform)
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)

5. 模型训练

model = SFNet()
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

num_epochs = 10
for epoch in range(num_epochs):
    for i, (hazy, gt) in enumerate(train_loader):
        optimizer.zero_grad()
        outputs = model(hazy)
        loss = criterion(outputs, gt)
        loss.backward()
        optimizer.step()

        print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], Loss: {loss.item()}')

6. 界面代码

在这里插入图片描述

SFNet图像去雾系统,包括训练、测试和推理(GUI界面),我们需要编写多个Python脚本文件。以下是详细的代码示例:

1. main.py - 训练和测试脚本

import argparse
import os
import torch
import torch.nn as nn
from torchvision import transforms
from torch.utils.data import DataLoader
from PIL import Image
import numpy as np
from SFNet_model import SFNet  # 假设SFNet模型定义在SFNet_model.py中

class DehazeDataset(Dataset):
    def __init__(self, hazy_dir, gt_dir, transform=None):
        self.hazy_images = sorted([os.path.join(hazy_dir, img) for img in os.listdir(hazy_dir)])
        self.gt_images = sorted([os.path.join(gt_dir, img) for img in os.listdir(gt_dir)])
        self.transform = transform

    def __len__(self):
        return len(self.hazy_images)

    def __getitem__(self, idx):
        hazy_image = Image.open(self.hazy_images[idx]).convert('RGB')
        gt_image = Image.open(self.gt_images[idx]).convert('RGB')

        if self.transform:
            hazy_image = self.transform(hazy_image)
            gt_image = self.transform(gt_image)

        return hazy_image, gt_image

def train(args):
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor()
    ])

    dataset = DehazeDataset(os.path.join(args.data_dir, 'train', 'hazy'), 
                            os.path.join(args.data_dir, 'train', 'gt'), 
                            transform=transform)
    dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)

    model = SFNet().cuda()
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

    for epoch in range(args.num_epoch):
        for i, (hazy, gt) in enumerate(dataloader):
            hazy, gt = hazy.cuda(), gt.cuda()

            optimizer.zero_grad()
            outputs = model(hazy)
            loss = criterion(outputs, gt)
            loss.backward()
            optimizer.step()

            print(f'Epoch [{epoch+1}/{args.num_epoch}], Step [{i+1}/{len(dataloader)}], Loss: {loss.item()}')

        torch.save(model.state_dict(), f'results/SFNet/{args.data}/Training-Results/Epoch_{epoch+1}.pkl')

def test(args):
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor()
    ])

    dataset = DehazeDataset(os.path.join(args.data_dir, 'test', 'hazy'), 
                            os.path.join(args.data_dir, 'test', 'gt'), 
                            transform=transform)
    dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False)

    model = SFNet().cuda()
    model.load_state_dict(torch.load(args.test_model))
    model.eval()

    with torch.no_grad():
        for i, (hazy, gt) in enumerate(dataloader):
            hazy, gt = hazy.cuda(), gt.cuda()
            outputs = model(hazy)

            if args.save_image:
                for j in range(outputs.size(0)):
                    output_img = transforms.ToPILImage()(outputs[j].cpu())
                    output_img.save(f'results/SFNet/{args.data}/Test-Results/image_{i*args.batch_size+j}.png')

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='SFNet Image Dehazing')
    parser.add_argument('--data_dir', type=str, required=True, help='directory of the dataset')
    parser.add_argument('--data', type=str, required=True, help='dataset name')
    parser.add_argument('--mode', type=str, required=True, choices=['train', 'test'], help='train or test mode')
    parser.add_argument('--batch_size', type=int, default=4, help='batch size')
    parser.add_argument('--learning_rate', type=float, default=2e-5, help='learning rate')
    parser.add_argument('--num_epoch', type=int, default=300, help='number of epochs')
    parser.add_argument('--test_model', type=str, default='', help='path to the trained model for testing')
    parser.add_argument('--save_image', type=bool, default=False, help='whether to save dehazed images')

    args = parser.parse_args()

    if args.mode == 'train':
        train(args)
    elif args.mode == 'test':
        test(args)

2. SFNet_model.py - SFNet模型定义

import torch
import torch.nn as nn

class SFNet(nn.Module):
    def __init__(self):
        super(SFNet, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU()
        )
        self.decoder = nn.Sequential(
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 3, kernel_size=3, padding=1)
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x

3. GUI.py - GUI界面代码

import sys
from PyQt5.QtWidgets import QApplication, QWidget, QPushButton, QVBoxLayout, QLabel, QFileDialog
from PyQt5.QtGui import QPixmap
import cv2
import numpy as np
import torch
from torchvision import transforms
from SFNet_model import SFNet

class DehazeApp(QWidget):
    def __init__(self):
        super().__init__()
        self.initUI()

    def initUI(self):
        self.setWindowTitle('图像去雾')
        self.setGeometry(100, 100, 800, 400)

        layout = QVBoxLayout()

        self.btn_select = QPushButton('选择图像', self)
        self.btn_select.clicked.connect(self.select_image)
        layout.addWidget(self.btn_select)

        self.btn_dehaze = QPushButton('SFNet去雾', self)
        self.btn_dehaze.clicked.connect(self.dehaze_image)
        layout.addWidget(self.btn_dehaze)

        self.image_label = QLabel(self)
        layout.addWidget(self.image_label)

        self.setLayout(layout)

    def select_image(self):
        options = QFileDialog.Options()
        fileName, _ = QFileDialog.getOpenFileName(self, "选择图像", "", "Images (*.png *.xpm *.jpg *.bmp);;All Files (*)", options=options)
        if fileName:
            self.image_path = fileName
            pixmap = QPixmap(fileName)
            self.image_label.setPixmap(pixmap.scaled(400, 400))

    def dehaze_image(self):
        if hasattr(self, 'image_path'):
            # Load and preprocess image
            image = cv2.imread(self.image_path)
            image = cv2.resize(image, (256, 256))
            image = image / 255.0
            image = np.transpose(image, (2, 0, 1))
            image = torch.tensor(image, dtype=torch.float32).unsqueeze(0).cuda()

            # Load pre-trained model
            model = SFNet().cuda()
            model.load_state_dict(torch.load('results/SFNet/Outdoor/Training-Results/Best.pkl'))
            model.eval()

            # Perform dehazing
            with torch.no_grad():
                output = model(image).squeeze().cpu().numpy()
                output = np.transpose(output, (1, 2, 0))
                output = (output * 255).astype(np.uint8)

            # Display result
            cv2.imwrite('dehazed.jpg', output)
            pixmap = QPixmap('dehazed.jpg')
            self.image_label.setPixmap(pixmap.scaled(400, 400))

if __name__ == '__main__':
    app = QApplication(sys.argv)
    ex = DehazeApp()
    ex.show()
    sys.exit(app.exec_())

运行步骤

  1. 训练模型

    python main.py --data_dir dehaze --data Outdoor --mode train --batch_size 4 --learning_rate 2e-5 --num_epoch 300
    
  2. 测试模型

    python main.py --data_dir dehaze --data Outdoor --mode test --batch_size 4 --test_model results/SFNet/Outdoor/Training-Results/Best.pkl --save_image True
    
  3. 运行GUI界面

    python GUI.py
    

确保所有路径正确,并根据实际情况调整参数和文件路径。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值