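📕 About the author: 贝贝 (Beibei), a blogger who loves programming, works across C/C++, Java, Python and other languages, and enjoys running, fitness, and music.
📗 This article is part of Beibei's daily report series; have a look if you are interested.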
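Preface
Deep-learning-based image enhancement uses techniques such as convolutional neural networks (CNNs) and generative adversarial networks (GANs) to improve image quality, enhance detail, and repair damaged images. These models learn to extract features from the original image and generate an improved version, so the result is more visually appealing.
Some deep-learning-based image enhancement methods: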
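- Super-resolution: a trained deep model upscales a low-resolution image to a high-resolution one, improving detail and sharpness. It is widely used in image upscaling, video compression, and medical image analysis.
- Denoising: a deep model learns to recover a cleaner version of a noisy image. This is especially useful for images taken in low light or with heavy sensor noise.
- Image restoration: for damaged images, a deep model learns to fill in missing parts or repair damaged regions, bringing the image closer to its original state.
- Color restoration and enhancement: deep learning can repair images affected by color degradation or loss, or boost color saturation and contrast.
- Style transfer: a neural network applies the style of one image to another, producing images in a different style.
- Image enhancement filters: some deep models are trained to implement image filters such as oil-painting or watercolor effects, adding an artistic look to an image.
- GANs in image enhancement: a GAN can generate images that resemble the input in quality and appearance but look more realistic. GANs are widely used for image restoration, super-resolution, and style transfer.
- Data augmentation: when training a deep model, techniques such as rotation, scaling, translation, and cropping generate additional training samples and improve the model's robustness and generalization.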
These methods apply to a wide range of fields, from medical imaging to artistic creation. Which one to choose depends on the specific problem and dataset.
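One detail of the data-augmentation item matters for paired tasks such as low-light enhancement: a geometric transform has to be applied identically to the input and to its reference image, otherwise the pair no longer lines up. Here is a minimal sketch in plain Python, with images represented as nested lists. `hflip` and `paired_random_hflip` are hypothetical helpers for illustration, not part of this project's code.

```python
import random


def hflip(image):
    """Mirror an image (a list of pixel rows) left to right."""
    return [list(reversed(row)) for row in image]


def paired_random_hflip(input_image, label_image, rng=random):
    """Flip both images together with probability 0.5, or leave both untouched."""
    if rng.random() < 0.5:
        return hflip(input_image), hflip(label_image)
    return input_image, label_image


# Toy 2x3 "images": a low-light input and its reference.
low = [[1, 2, 3], [4, 5, 6]]
ref = [[10, 20, 30], [40, 50, 60]]
aug_low, aug_ref = paired_random_hflip(low, ref, random.Random(0))
print(aug_low, aug_ref)
```

The single random draw is shared by both images, which is the whole point: drawing separately for input and label would break the pairing about half the time.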
Network architecture: PReNet
Neural network framework: torch
Programming language: python
Result images
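I. Import the required packages
The hardest library to install here is torch. All the others can be installed with the following command:
pip install XXX (package name) -i https://pypi.tuna.tsinghua.edu.cn/simple/
For installing torch, see the blog post https://mp.csdn.net/mp_blog/creation/editor/129744112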
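Before installing anything, it can help to see which packages are already present. The sketch below uses only the standard library. `missing_packages` is a hypothetical helper name, and the list holds the import names used in this article (note that the pip package `opencv-python` is imported as `cv2`, and `Pillow` as `PIL`).

```python
import importlib.util


def missing_packages(module_names):
    """Return the names in module_names that cannot be imported in this environment."""
    return [name for name in module_names
            if importlib.util.find_spec(name) is None]


# Top-level import names used by the code in this article.
required = ["torch", "torchvision", "numpy", "cv2", "tqdm", "matplotlib", "PIL"]
print("Missing:", missing_packages(required))
```

Anything printed in the list still needs a `pip install`.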
# Import the required libraries
# PyTorch
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.utils.data as Data
import torch.nn.functional as F
import torchvision
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import CosineAnnealingLR
# Utilities
import numpy as np
import cv2
import random
import time
import os
import re
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
from math import exp
from PIL import Image
II. Prepare the dataset
1. Dataset classes
To make training easy to run, the training images are cropped small, and the dataset class is overridden accordingly.
The code is as follows (example):
class MyTrainDataset(Dataset):
    """Paired training set: low-light inputs and their reference images."""

    def __init__(self, input_path, label_path):
        self.input_path = input_path
        # Sort both listings so that index i refers to the same scene in both
        # folders; os.listdir() returns names in arbitrary order.
        self.input_files = sorted(os.listdir(input_path))
        self.label_path = label_path
        self.label_files = sorted(os.listdir(label_path))
        self.transforms = transforms.Compose([
            transforms.CenterCrop([64, 64]),  # small crop keeps training cheap
            transforms.ToTensor(),
        ])

    def __len__(self):
        return len(self.input_files)

    def __getitem__(self, index):
        label_image_path = os.path.join(self.label_path, self.label_files[index])
        label_image = Image.open(label_image_path).convert('RGB')
        # Input and label must form a pair. If the input names carry a suffix,
        # derive them from the label name instead:
        # temp = self.label_files[index][:-4]
        # self.input_files[index] = temp + 'x2.png'
        input_image_path = os.path.join(self.input_path, self.input_files[index])
        input_image = Image.open(input_image_path).convert('RGB')
        input_tensor = self.transforms(input_image)
        label_tensor = self.transforms(label_image)
        return input_tensor, label_tensor


class MyValidDataset(Dataset):
    """Paired dataset for validation and testing."""

    def __init__(self, input_path, label_path):
        self.input_path = input_path
        self.input_files = sorted(os.listdir(input_path))
        self.label_path = label_path
        self.label_files = sorted(os.listdir(label_path))
        self.transforms = transforms.Compose([
            transforms.Resize([512, 512]),
            transforms.ToTensor(),
        ])

    def __len__(self):
        return len(self.input_files)

    def __getitem__(self, index):
        label_image_path = os.path.join(self.label_path, self.label_files[index])
        label_image = Image.open(label_image_path).convert('RGB')
        # temp = self.label_files[index][:-4]
        # self.input_files[index] = temp + 'x2.png'
        input_image_path = os.path.join(self.input_path, self.input_files[index])
        input_image = Image.open(input_image_path).convert('RGB')
        input_tensor = self.transforms(input_image)
        label_tensor = self.transforms(label_image)
        return input_tensor, label_tensor
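The training set takes a centered 64×64 patch from every image. To make it concrete which pixels that keeps, here is a small sketch of the crop arithmetic in plain Python. `center_crop_box` is a hypothetical helper, and it assumes the image is at least as large as the crop, so treat it as an illustration rather than a copy of what torchvision does internally.

```python
def center_crop_box(width, height, crop_size):
    """Return (left, top, right, bottom) of a centered square crop.

    Assumes the image is at least crop_size pixels in both dimensions.
    """
    if width < crop_size or height < crop_size:
        raise ValueError("image is smaller than the crop")
    left = int(round((width - crop_size) / 2.0))
    top = int(round((height - crop_size) / 2.0))
    return left, top, left + crop_size, top + crop_size


box = center_crop_box(600, 400, 64)
print(box)  # (268, 168, 332, 232)
```

For a 600×400 photo only a small central window is seen during training, so scenes whose interesting content sits near the border contribute little. A box in this (left, top, right, bottom) form can be passed to PIL's `Image.crop` if you want to inspect the patch.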
2. Read the data
Change input_path, label_path, valid_input_path, and valid_label_path to your own image folders.
The code is as follows (example):
input_path = "./low_light_images"
label_path = "./reference_images"
valid_input_path = './test/test_low'
valid_label_path = './test/test_high'
batch_size = 2  # must be defined before the loaders; same value as in the hyperparameter section
dataset_train = MyTrainDataset(input_path, label_path)
dataset_valid = MyValidDataset(valid_input_path, valid_label_path)
train_loader = DataLoader(dataset_train, batch_size=batch_size, shuffle=True, pin_memory=True)
# Validation does not need shuffling.
valid_loader = DataLoader(dataset_valid, batch_size=batch_size, shuffle=False, pin_memory=True)
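3. Dataset layout
The images in the four folders are organized as follows. Matching images share the same file name across the paired folders, and in the example paths above test_low and test_high sit under ./test:
——low_light_images
    ——train1.jpg
    ——train2.jpg
    ...
——reference_images
    ——train1.jpg
    ——train2.jpg
    ...
——test_low
    ——test1.jpg
    ——test2.jpg
    ...
——test_high
    ——test1.jpg
    ——test2.jpg
    ...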
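Since the datasets pair images purely by position in the folder listing, it is worth checking that the two folders really contain matching names before training. A minimal sketch using only the standard library follows. `pair_by_stem` is a hypothetical helper, not part of the project code.

```python
import os


def pair_by_stem(input_names, label_names):
    """Match files whose names agree once the extension is dropped.

    Returns (pairs, unmatched): pairs is a list of (input, label) names sorted
    by stem, unmatched is a sorted list of stems found on one side only.
    """
    inputs = {os.path.splitext(name)[0]: name for name in input_names}
    labels = {os.path.splitext(name)[0]: name for name in label_names}
    common = sorted(inputs.keys() & labels.keys())
    unmatched = sorted(inputs.keys() ^ labels.keys())
    return [(inputs[stem], labels[stem]) for stem in common], unmatched


pairs, unmatched = pair_by_stem(
    ["train2.jpg", "train1.jpg", "extra.jpg"],
    ["train1.jpg", "train2.jpg"],
)
print(pairs)      # [('train1.jpg', 'train1.jpg'), ('train2.jpg', 'train2.jpg')]
print(unmatched)  # ['extra']
```

On real data you would call it as `pair_by_stem(os.listdir(input_path), os.listdir(label_path))` and fix the folders if `unmatched` is not empty.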
III. Build the model
This tutorial uses PReNet, a model originally designed for image deraining.
For other tasks you can swap in a different model. PReNet can still be used for them, but the results may not be as good.
# Network architecture
class PReNet_r(nn.Module):
    def __init__(self, recurrent_iter=6, use_GPU=True):
        super(PReNet_r, self).__init__()
        self.iteration = recurrent_iter
        # Kept for compatibility with existing calls; forward() follows the
        # device of its input instead of relying on this flag.
        self.use_GPU = use_GPU
        # Input is the original image concatenated with the current estimate (3 + 3 channels).
        self.conv0 = nn.Sequential(
            nn.Conv2d(6, 32, 3, 1, 1),
            nn.ReLU()
        )
        # One residual block, reused five times per stage.
        self.res_conv1 = nn.Sequential(
            nn.Conv2d(32, 32, 3, 1, 1),
            nn.ReLU(),
            nn.Conv2d(32, 32, 3, 1, 1),
            nn.ReLU()
        )
        # Convolutional LSTM gates: input, forget, candidate, output.
        self.conv_i = nn.Sequential(
            nn.Conv2d(32 + 32, 32, 3, 1, 1),
            nn.Sigmoid()
        )
        self.conv_f = nn.Sequential(
            nn.Conv2d(32 + 32, 32, 3, 1, 1),
            nn.Sigmoid()
        )
        self.conv_g = nn.Sequential(
            nn.Conv2d(32 + 32, 32, 3, 1, 1),
            nn.Tanh()
        )
        self.conv_o = nn.Sequential(
            nn.Conv2d(32 + 32, 32, 3, 1, 1),
            nn.Sigmoid()
        )
        self.conv = nn.Sequential(
            nn.Conv2d(32, 3, 3, 1, 1),
        )

    def forward(self, input):
        batch_size, row, col = input.size(0), input.size(2), input.size(3)
        x = input
        # LSTM hidden and cell states, created on the same device as the input.
        h = torch.zeros(batch_size, 32, row, col, device=input.device, dtype=input.dtype)
        c = torch.zeros(batch_size, 32, row, col, device=input.device, dtype=input.dtype)
        x_list = []
        # The loop variable is unused, so the gate named i below no longer shadows it.
        for _ in range(self.iteration):
            x = torch.cat((input, x), 1)
            x = self.conv0(x)
            x = torch.cat((x, h), 1)
            i = self.conv_i(x)
            f = self.conv_f(x)
            g = self.conv_g(x)
            o = self.conv_o(x)
            c = f * c + i * g
            h = o * torch.tanh(c)
            x = h
            for j in range(5):
                resx = x
                x = F.relu(self.res_conv1(x) + resx)
            x = self.conv(x)
            # The network predicts a residual that is added to the input.
            x = input + x
            x_list.append(x)
        # Final estimate plus the estimate after every stage.
        return x, x_list
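Because every stage reuses the same layers, and the residual block `res_conv1` is itself reused five times per stage, the number of trainable parameters does not depend on `recurrent_iter`. The sketch below counts them by hand from the layer shapes in the class above. `conv2d_params` is a hypothetical helper, and the total should match `sum(p.numel() for p in PReNet_r().parameters())` if you want to cross-check it in PyTorch.

```python
def conv2d_params(in_channels, out_channels, kernel_size):
    """Weights plus biases of one nn.Conv2d layer (bias enabled, the default)."""
    return in_channels * out_channels * kernel_size * kernel_size + out_channels


layers = {
    "conv0": conv2d_params(6, 32, 3),            # input + current estimate -> features
    "res_conv1": 2 * conv2d_params(32, 32, 3),   # two convs in the shared residual block
    "lstm_gates": 4 * conv2d_params(64, 32, 3),  # conv_i, conv_f, conv_g, conv_o
    "conv": conv2d_params(32, 3, 3),             # features -> RGB residual
}
total = sum(layers.values())
print(layers)
print("total parameters:", total)  # 94979
```

Under 100k parameters is small, which is part of why this model is a comfortable starting point for beginners. More stages cost compute and memory, not parameters.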
IV. Loss function
Image reconstruction is usually evaluated with SSIM, so SSIM is also a common choice for the loss function. This part normally needs no changes.
def gaussian(window_size, sigma):
    # 1-D Gaussian centered on the window, normalized to sum to 1.
    gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
    return gauss/gauss.sum()


def create_window(window_size, channel):
    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
    # One copy of the 2-D Gaussian per channel, for a grouped convolution.
    window = _2D_window.expand(channel, 1, window_size, window_size).contiguous()
    return window


def _ssim(img1, img2, window, window_size, channel, size_average = True):
    # Local means
    mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel)
    mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel)
    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1*mu2
    # Local variances and covariance
    sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq
    sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq
    sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2
    # Stabilizing constants for images in [0, 1]
    C1 = 0.01**2
    C2 = 0.03**2
    ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))
    if size_average:
        return ssim_map.mean()
    else:
        return ssim_map.mean(1).mean(1).mean(1)


class SSIM(torch.nn.Module):
    def __init__(self, window_size = 11, size_average = True):
        super(SSIM, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.channel = 1
        self.window = create_window(window_size, self.channel)

    def forward(self, img1, img2):
        (_, channel, _, _) = img1.size()
        # Reuse the cached window if channel count and tensor type still match.
        if channel == self.channel and self.window.data.type() == img1.data.type():
            window = self.window
        else:
            window = create_window(self.window_size, channel)
            if img1.is_cuda:
                window = window.cuda(img1.get_device())
            window = window.type_as(img1)
            self.window = window
            self.channel = channel
        return _ssim(img1, img2, window, self.window_size, channel, self.size_average)


def ssim(img1, img2, window_size = 11, size_average = True):
    (_, channel, _, _) = img1.size()
    window = create_window(window_size, channel)
    if img1.is_cuda:
        window = window.cuda(img1.get_device())
    window = window.type_as(img1)
    return _ssim(img1, img2, window, window_size, channel, size_average)
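To see what the SSIM formula measures without any convolutions, here is a simplified sketch that computes it once from whole-image statistics instead of per Gaussian window. It is not equivalent to the windowed version above, only the same formula applied globally. `global_ssim` is a hypothetical helper that takes flat lists of pixel values in [0, 1].

```python
def global_ssim(x, y, c1=0.01 ** 2, c2=0.03 ** 2):
    """SSIM computed from global statistics of two equal-length pixel lists."""
    n = len(x)
    mu_x = sum(x) / n
    mu_y = sum(y) / n
    var_x = sum(v * v for v in x) / n - mu_x * mu_x
    var_y = sum(v * v for v in y) / n - mu_y * mu_y
    cov_xy = sum(a * b for a, b in zip(x, y)) / n - mu_x * mu_y
    numerator = (2 * mu_x * mu_y + c1) * (2 * cov_xy + c2)
    denominator = (mu_x * mu_x + mu_y * mu_y + c1) * (var_x + var_y + c2)
    return numerator / denominator


reference = [0.1, 0.4, 0.6, 0.9]
darker = [v * 0.5 for v in reference]  # same structure, half the brightness
print(global_ssim(reference, reference))  # identical images give the maximum score
print(global_ssim(reference, darker))     # a darker copy scores lower
```

Identical images reach the maximum score of 1, and a darkened copy scores lower. Since a higher score means a better match, a training loop that uses SSIM as a loss has to minimize its negative.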
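V. Optimizer and hyperparameters
Set the learning rate, batch size, and number of epochs, then create the model, the loss, the optimizer, and the learning-rate scheduler.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
learning_rate = 1e-3
batch_size = 2
epoch = 60
# The model and the loss must exist before the optimizer and the training loop use them.
net = PReNet_r(use_GPU=(device == 'cuda')).to(device)
loss_f = SSIM()
optimizer = optim.SGD(net.parameters(), lr=learning_rate)
scheduler = CosineAnnealingLR(optimizer, T_max=epoch)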
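With `T_max=epoch`, the scheduler lowers the learning rate along half a cosine wave, from the initial value at epoch 0 down to the minimum (0 by default) at the last epoch. The sketch below evaluates the closed-form schedule in plain Python so you can see the numbers without running any training. `cosine_annealing_lr` is a hypothetical helper and assumes the scheduler is stepped once per epoch, as in this article.

```python
import math


def cosine_annealing_lr(step, base_lr=1e-3, t_max=60, eta_min=0.0):
    """Closed-form cosine annealing: base_lr at step 0, eta_min at step t_max."""
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * step / t_max))


for step in (0, 15, 30, 45, 60):
    print(step, f"{cosine_annealing_lr(step):.6f}")
```

The rate is exactly half the initial value at the midpoint (epoch 30 here) and changes most slowly near the start and the end.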
VI. Training and validation
# Bookkeeping for the curves, best-model tracking, and early stopping
Loss_list = []
Valid_Loss_list = []
learning_rate_list = []
best_valid_loss = float('inf')
stale = 0
patience = 10  # assumed value; the original article does not state it
for i in range(epoch):
    # ---------------Train----------------
    net.train()
    train_losses = []
    # tqdm shows a progress bar.
    for batch in tqdm(train_loader):
        inputs, labels = batch
        outputs, _ = net(inputs.to(device))
        # SSIM is a similarity (higher is better), so minimize its negative.
        loss = -loss_f(labels.to(device), outputs)
        optimizer.zero_grad()
        loss.backward()
        # Clip the gradient norm so updates do not blow up.
        grad_norm = nn.utils.clip_grad_norm_(net.parameters(), max_norm=10)
        optimizer.step()
        # loss.item() turns the tensor into a Python float; plt cannot plot tensors.
        train_losses.append(loss.item())
    train_loss = sum(train_losses) / len(train_losses)
    Loss_list.append(train_loss)
    print(f"[ Train | {i + 1:03d}/{epoch:03d} ] SSIM_loss = {train_loss:.5f}")
    scheduler.step()
    for param_group in optimizer.param_groups:
        learning_rate_list.append(param_group["lr"])
        print('learning rate %f' % param_group["lr"])
    # -------------Validation-------------
    # Validation confirms that training is working and shows whether the
    # network is overfitting. net.eval() puts the layers in inference mode;
    # gradients are disabled separately with torch.no_grad().
    net.eval()
    valid_losses = []
    for batch in tqdm(valid_loader):
        inputs, labels = batch
        with torch.no_grad():
            outputs, _ = net(inputs.to(device))
            loss = -loss_f(labels.to(device), outputs)
        valid_losses.append(loss.item())
    valid_loss = sum(valid_losses) / len(valid_losses)
    Valid_Loss_list.append(valid_loss)
    break_point = i + 1
    # Log the result, save the best model, and check patience.
    if valid_loss < best_valid_loss:
        print(f"[ Valid | {i + 1:03d}/{epoch:03d} ] SSIM_loss = {valid_loss:.5f} -> best")
        print(f'Best model found at epoch {i+1}, saving model')
        torch.save(net.state_dict(), 'model_best.ckpt')
        best_valid_loss = valid_loss
        stale = 0
    else:
        print(f"[ Valid | {i + 1:03d}/{epoch:03d} ] SSIM_loss = {valid_loss:.5f}")
        stale += 1
        if stale > patience:
            print(f'No improvement {patience} consecutive epochs, early stopping.')
            break
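The best-model and patience logic is easier to check on a plain list of numbers than inside a real training run. The sketch below replays the same rules (a new best resets the counter, and training stops once `stale > patience`) over a made-up sequence of validation losses. `early_stopping_trace` and the loss values are hypothetical and only illustrate the control flow.

```python
def early_stopping_trace(valid_losses, patience):
    """Replay best-model tracking and patience over per-epoch validation losses.

    Returns (best_epoch, last_epoch), both counted from 1.
    """
    best_loss = float("inf")
    best_epoch = 0
    stale = 0
    last_epoch = 0
    for epoch_index, loss in enumerate(valid_losses, start=1):
        last_epoch = epoch_index
        if loss < best_loss:
            best_loss, best_epoch, stale = loss, epoch_index, 0
        else:
            stale += 1
            if stale > patience:
                break
    return best_epoch, last_epoch


# Negative SSIM values, so lower is better.
losses = [-0.50, -0.60, -0.55, -0.54, -0.53, -0.70]
print(early_stopping_trace(losses, patience=2))  # (2, 5)
print(early_stopping_trace(losses, patience=5))  # (6, 6)
```

With `patience=2` the run stops at epoch 5 and never reaches the better epoch 6. Note that the condition `stale > patience` only fires after patience + 1 epochs without improvement, one more than the printed message suggests.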
VII. Plot the results
# Plot the loss curves and the learning rate with matplotlib.
plt.figure(dpi=500)
plt.subplot(211)
x = range(break_point)
y = Loss_list
plt.plot(x, y, 'ro-', label='Train Loss')
plt.plot(range(break_point), Valid_Loss_list, 'bs-', label='Valid Loss')
plt.ylabel('Loss')
plt.xlabel('epochs')
plt.legend()  # without this call the labels of the first panel are never shown
plt.subplot(212)
plt.plot(x, learning_rate_list, 'ro-', label='Learning rate')
plt.ylabel('Learning rate')
plt.xlabel('epochs')
plt.legend()
plt.show()
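VIII. Predict on an image
Change img_path to the path of your own image to run the prediction on it.
from torchvision.utils import save_image  # used below to write the result

# Named predict_transforms so the torchvision.transforms module is not shadowed.
predict_transforms = transforms.Compose([
    transforms.Resize([512, 512]),
    transforms.ToTensor(),
])
img_path = "test/test_low/5.png"
net = PReNet_r(use_GPU=False).to('cpu')  # or .cuda() on a GPU
# Load the trained weights; map_location lets a checkpoint saved on a GPU load on the CPU.
net.load_state_dict(torch.load('./model_best.ckpt', map_location='cpu'))
net.eval()
input_image = Image.open(img_path).convert('RGB')
input = predict_transforms(input_image)
input = input.to('cpu')  # or .cuda() on a GPU
input = input.unsqueeze(0)  # add the batch dimension: (1, 3, 512, 512)
print(input.size())
with torch.no_grad():
    # The network returns (final output, list of per-stage outputs).
    output_image, _ = net(input)
# Save the tensor directly as an image; the conversion is automatic.
save_image(output_image, './' + str(1).zfill(4) + '.jpg')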
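The last step of the network adds a predicted residual to the input, so output values can land slightly outside [0, 1]. As far as I know, `save_image` scales values by 255, adds 0.5, and clamps to [0, 255] before casting to 8 bits, so out-of-range values saturate rather than wrap around. The plain-Python sketch below shows that per-pixel mapping. `to_uint8` is a hypothetical helper and an approximation of that behaviour, not torchvision's actual code.

```python
def to_uint8(value):
    """Map a float pixel to 0..255: scale, round to nearest, then clamp."""
    return int(min(max(value * 255 + 0.5, 0), 255))


samples = [-0.2, 0.0, 0.5, 1.0, 1.3]
print([to_uint8(v) for v in samples])  # [0, 0, 128, 255, 255]
```

If many pixels saturate at 0 or 255 in your results, the saved image loses detail in shadows or highlights even though the tensor still contains it.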
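Summary
This article is a beginner's introduction to image reconstruction tasks such as image denoising, image deraining, contrast adjustment, and image compression. Each of them can be handled by swapping in a different model. Because the input and output at validation and prediction time are both (3, 512, 512), such changes are relatively easy to make.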