pytorch 图像恢复问题的代码实现详解(derain,dehaze,deblur,denoise等通用)_深度学习图像修复pytorch代码(1)

def \_\_getitem\_\_(self, index):

    ps = self.ps
    index = index % len(self.targetImages)

    inputImagePath = os.path.join(self.inputPath, self.inputImages[index])  # 图片完整路径
    inputImage = Image.open(inputImagePath).convert('RGB')  # 读取图片

    targetImagePath = os.path.join(self.targetPath, self.targetImages[index])
    targetImage = Image.open(targetImagePath).convert('RGB')

    inputImage = ttf.to_tensor(inputImage)  # 将图片转为张量
    targetImage = ttf.to_tensor(targetImage)

    hh, ww = targetImage.shape[1], targetImage.shape[2]  # 图片的高和宽

    rr = random.randint(0, hh-ps)  # 随机数: patch 左下角的坐标 (rr, cc)
    cc = random.randint(0, ww-ps)
    # aug = random.randint(0, 8) # 随机数,对应对图片进行的操作

    input_ = inputImage[:, rr:rr+ps, cc:cc+ps]  # 裁剪 patch ,输入和目标 patch 要对应相同
    target = targetImage[:, rr:rr+ps, cc:cc+ps]

    return input_, target

##### 评估数据集


在网络训练中,不一定最后一次训练的效果就是最好的。评估数据集是在每一个 epoch 训练结束后对网络训练的性能进行评估,目的在于将最好的一次训练结果保存。



class MyValueDataSet(Dataset): # 评估数据集
def __init__(self, inputPathTrain, targetPathTrain, patch_size=128):
super(MyValueDataSet, self).init()

    self.inputPath = inputPathTrain
    self.inputImages = os.listdir(inputPathTrain)  # 输入图片路径下的所有文件名列表

    self.targetPath = targetPathTrain
    self.targetImages = os.listdir(targetPathTrain)  # 目标图片路径下的所有文件名列表

    self.ps = patch_size

def \_\_len\_\_(self):
    return len(self.targetImages)

def \_\_getitem\_\_(self, index):

    ps = self.ps
    index = index % len(self.targetImages)

    inputImagePath = os.path.join(self.inputPath, self.inputImages[index])  # 图片完整路径
    inputImage = Image.open(inputImagePath).convert('RGB')  # 读取图片,灰度图

    targetImagePath = os.path.join(self.targetPath, self.targetImages[index])
    targetImage = Image.open(targetImagePath).convert('RGB')

    inputImage = ttf.center_crop(inputImage, (ps, ps))
    targetImage = ttf.center_crop(targetImage, (ps, ps))

    input_ = ttf.to_tensor(inputImage)  # 将图片转为张量
    target = ttf.to_tensor(targetImage)

    return input_, target

##### 测试数据集


测试数据集的目的是将输入有雨进行去雨得到去雨后的结果,注意输入一般是原图大小,不进行裁剪。



class MyTestDataSet(Dataset): # 测试数据集
def __init__(self, inputPathTest):
super(MyTestDataSet, self).init()

    self.inputPath = inputPathTest
    self.inputImages = os.listdir(inputPathTest)  # 输入图片路径下的所有文件名列表

def \_\_len\_\_(self):
    return len(self.inputImages)  # 路径里的图片数量

def \_\_getitem\_\_(self, index):
    index = index % len(self.inputImages)

    inputImagePath = os.path.join(self.inputPath, self.inputImages[index])  # 图片完整路径
    inputImage = Image.open(inputImagePath).convert('RGB')  # 读取图片

    input_ = ttf.to_tensor(inputImage)  # 将图片转为张量

    return input_

#### 网络模型


以一个 5 层简单卷积神经网络为例子,具体网络自己设定。  
 `NetModel.py`



class Net(nn.Module):
def __init__(self):
super(Net, self).init()
self.inconv = nn.Sequential( # 输入层网络
nn.Conv2d(3, 32, 3, 1, 1),
nn.ReLU(inplace=True)
)
self.midconv = nn.Sequential( # 中间层网络
nn.Conv2d(3, 32, 3, 1, 1),
nn.ReLU(inplace=True),
nn.Conv2d(3, 32, 3, 1, 1),
nn.ReLU(inplace=True),
nn.Conv2d(3, 32, 3, 1, 1),
nn.ReLU(inplace=True),
)
self.outconv = nn.Sequential( # 输出层网络
nn.Conv2d(3, 32, 3, 1, 1),
)

def forward(self, x):

    x = self.inconv(x)
    x = self.midconv(x)
    x = self.outconv(x)
    
    return x

#### 自定义工具包


自定义工具包主要是一个计算峰值信噪比(PSNR)的方法用来对训练进行评估。


`utils.py`



import torch

def torchPSNR(tar_img, prd_img):
imdff = torch.clamp(prd_img,0,1) - torch.clamp(tar_img,0,1)
rmse = (imdff**2).mean().sqrt()
ps = 20*torch.log10(1/rmse)
return ps


#### 网络训练和测试


`main.py`



import sys
import time
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from tqdm import tqdm, trange # 进度条
from torch.utils.data import DataLoader
from torchvision.utils import save_image
from torch.autograd import Variable
from torch.optim.lr_scheduler import MultiStepLR
import utils
from NetModel import Net
from MyDataset import *

if name == ‘__main__’: # 只有在 main 中才能开多线程
EPOCH = 100 # 训练次数
BATCH_SIZE = 18 # 每批的训练数量
LEARNING_RATE = 1e-3 # 学习率
loss_list = [] # 损失存储数组
best_psnr = 0 # 训练最好的峰值信噪比
best_epoch = 0 # 峰值信噪比最好时的 epoch

inputPathTrain = 'E://Rain100H/inputTrain/'  # 训练输入图片路径
targetPathTrain = 'E://Rain100H/targetTrain/'  # 训练目标图片路径
inputPathTest = 'E://Rain100H/inputTest/'  # 测试输入图片路径
resultPathTest = 'E://Rain100H/resultTest/'  # 测试结果图片路径
targetPathTest = 'E://Rain100H/targetTest/'  # 测试目标图片路径

myNet = Net()  # 实例化网络
myNet = myNet.cuda()  # 网络放入GPU中
criterion = nn.MSELoss().cuda()

optimizer = optim.Adam(myNet.parameters(), lr=LEARNING_RATE)  # 网络参数优化算法

# 训练数据
datasetTrain = MyTrainDataSet(inputPathTrain, targetPathTrain)  # 实例化训练数据集类
# 可迭代数据加载器加载训练数据
trainLoader = DataLoader(dataset=datasetTrain, batch_size=BATCH_SIZE, shuffle=True, drop_last=False, num_workers=6, pin_memory=True)

# 评估数据
datasetValue = MyValueDataSet(inputPathTest, targetPathTest)  # 实例化评估数据集类
valueLoader = DataLoader(dataset=datasetValue, batch_size=16, shuffle=True, drop_last=False, num_workers=6, pin_memory=True)

# 测试数据
datasetTest = MyTestDataSet(inputPathTest)  # 实例化测试数据集类
# 可迭代数据加载器加载测试数据
testLoader = DataLoader(dataset=datasetTest, batch_size=1, shuffle=False, drop_last=False, num_workers=6, pin_memory=True)

# 开始训练
print('-------------------------------------------------------------------------------------------------------')
if os.path.exists('./model\_best.pth'):  # 判断是否预训练
    myNet.load_state_dict(torch.load('./model\_best.pth'))

for epoch in range(EPOCH):
    myNet.train()  # 指定网络模型训练状态
    iters = tqdm(trainLoader, file=sys.stdout)  # 实例化 tqdm,自定义
    epochLoss = 0  # 每次训练的损失
    timeStart = time.time()  # 每次训练开始时间
    for index, (x, y) in enumerate(iters, 0):

        myNet.zero_grad()  # 模型参数梯度置0

惊喜

最后还准备了一套上面资料对应的面试题(有答案哦)和面试时的高频面试算法题(如果面试准备时间不够,那么集中把这些算法题做完即可,命中率高达85%+)

image.png

image.png

merate(iters, 0):

        myNet.zero_grad()  # 模型参数梯度置0

惊喜

最后还准备了一套上面资料对应的面试题(有答案哦)和面试时的高频面试算法题(如果面试准备时间不够,那么集中把这些算法题做完即可,命中率高达85%+)

[外链图片转存中…(img-Qv8TMgpb-1714438332641)]

[外链图片转存中…(img-EF7cnZ3E-1714438332642)]

本文已被CODING开源项目:【一线大厂Java面试题解析+核心总结学习笔记+最新讲解视频+实战项目源码】收录

  • 11
    点赞
  • 22
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值