基于pytorch的SRGAN实现
前言
SRGAN是发表在顶会CVPR2017的文章, 利用GAN进行超分实现了不错的效果!
值得一提的是, SRGAN 是 SR 领域引用量最高的论文之一.
链接地址: https://arxiv.org/abs/1609.04802v5
代码参考链接(第三方 PyTorch 实现, 并非论文作者的官方代码): https://github.com/leftthomas/SRGAN
SRGAN论文概要(贡献)
- SRResNet: 一个针对 MSE 优化的深度残差网络(16 个残差块), 在高放大倍数(4×)下以 PSNR 和结构相似度(SSIM)衡量, 达到了当时图像超分的最优水平
- SRGAN: 一种基于 GAN 的网络, 针对一种新的感知损失进行优化。它用在 VGG 网络特征图上计算的损失代替基于 MSE 的像素内容损失; VGG 特征对像素空间的微小变化更具不变性, 因此相较于只用像素损失, 超分后的图像具有更多纹理等高频细节.
- 在三个公开基准数据集上进行了大规模的平均意见得分(MOS)测试, 证实 SRGAN 在高放大倍数(4×)的照片级真实感超分辨率上大幅领先于当时的其他方法, 即超分后的图像更接近自然图像.
网络结构和损失函数
对应的详解在相应的代码实现处
网络模型:
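Perceptual loss function(感知损失函数或总损失): 论文中为 $l^{SR} = l^{SR}_{X} + 10^{-3}\, l^{SR}_{Gen}$, 即内容损失加上权重为 $10^{-3}$ 的对抗损失。下文 loss.py 的实现为 $\text{MSE}_{pixel} + 0.001\, l_{adv} + 0.006\, l_{VGG} + 2\times10^{-8}\, l_{TV}$, 额外加入了像素 MSE 和 TV 损失.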
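Content loss(内容损失): $l^{SR}_{VGG/i,j} = \frac{1}{W_{i,j}H_{i,j}}\sum_{x=1}^{W_{i,j}}\sum_{y=1}^{H_{i,j}}\left(\phi_{i,j}(I^{HR})_{x,y} - \phi_{i,j}(G_{\theta_G}(I^{LR}))_{x,y}\right)^2$, 其中 $\phi_{i,j}$ 为 VGG19 中第 $i$ 个最大池化层之前第 $j$ 个卷积(激活后)输出的特征图; 下文代码使用 torchvision 预训练 VGG16 的 features 作为特征提取网络.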
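Adversarial loss(对抗损失): $l^{SR}_{Gen} = \sum_{n=1}^{N} -\log D_{\theta_D}\left(G_{\theta_G}(I^{LR}_n)\right)$; 下文代码中用 $1 - D(G(z))$ 的均值代替 $-\log D(G(z))$.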
pytorch代码实现
1. 准备工作
1.1 数据下载并放到合适位置
train 和 val 数据集是从VOC2012中采样得到的
VOC2012:链接地址 提取码: 5tzp
测试图像数据集来自Set5 Set14 BSD100 Urban100 SunHays80 链接地址
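下载图像数据集, 然后将其解压到 data 目录中。代码默认的目录结构为: 训练集 data/VOC2012/train, 验证集 data/VOC2012/val, 测试集 data/test/SRF_<放大倍数>/data(低分辨率图像) 与 data/test/SRF_<放大倍数>/target(高分辨率图像)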
如图所示:
注意: 如需训练自己的数据集, 训练集和验证集只需准备高分辨率原图(低分辨率输入由代码通过双三次插值自动生成); 只有测试集需要事先准备好原图和对应插值缩放后的低分辨率图片
2. 开始训练和测试
训练: (1) 打开终端,进入当前文件目录
(2) 选择指定的参数, 未指定的情况下按代码中的默认值处理
也可以直接运行README文件中相应代码, 例如: python train.py --crop_size 88 --upscale_factor 4 --num_epochs 100
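训练过程中, 每个 epoch 的验证对比图保存到 training_results/SRF_<放大倍数>/ 文件夹, 模型参数保存到 epochs/ 文件夹, 每 10 个 epoch 的损失与指标统计保存到 statistics/ 文件夹(benchmark_results 文件夹是测试脚本 test_benchmark.py 的输出位置)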
测试过程与训练类似: 运行 test_benchmark.py 在基准数据集上测试(结果保存到 benchmark_results 文件夹), 或运行 test_image.py 测试单张图片。其它参数细节论文里面均有说明
源码详解
1. 数据集的加载: data_utils.py
from os import listdir
from os.path import join
from PIL import Image
from torch.utils.data.dataset import Dataset
from torchvision.transforms import Compose, RandomCrop, ToTensor, ToPILImage, CenterCrop, Resize
def is_image_file(filename):
    """Return True when *filename* ends with a supported image extension.

    The comparison is case-sensitive: only the all-lowercase and all-uppercase
    spellings of .png / .jpg / .jpeg are accepted (e.g. 'x.Png' is rejected).
    """
    accepted_suffixes = ('.png', '.jpg', '.jpeg', '.PNG', '.JPG', '.JPEG')
    # str.endswith accepts a tuple and tests every suffix in one call.
    return filename.endswith(accepted_suffixes)
def calculate_valid_crop_size(crop_size, upscale_factor):
    """Round *crop_size* down to the nearest multiple of *upscale_factor*.

    This guarantees that an HR crop of the returned size downsamples to an
    integer LR size (crop_size // upscale_factor) with no leftover pixels.
    """
    leftover = crop_size % upscale_factor
    return crop_size - leftover
def train_hr_transform(crop_size):
    """Build the high-resolution training transform.

    A random crop of size *crop_size* is taken from the PIL image and then
    converted to a tensor with ToTensor().
    """
    pipeline = [RandomCrop(crop_size), ToTensor()]
    return Compose(pipeline)
def train_lr_transform(crop_size, upscale_factor):
    """Build the low-resolution training transform.

    It takes the HR crop *tensor*, turns it back into a PIL image, resizes it
    with bicubic interpolation to crop_size // upscale_factor, and converts it
    to a tensor again.
    """
    lr_size = crop_size // upscale_factor
    return Compose([
        ToPILImage(),
        Resize(lr_size, interpolation=Image.BICUBIC),
        ToTensor(),
    ])
def display_transform():
    """Return a transform that normalises a tensor image to 400x400 for result grids.

    Resize(400) scales the *shorter* side to 400 while keeping the aspect
    ratio. CenterCrop(400) then trims the longer side, so the output tensor is
    always 400x400.
    """
    to_pil = ToPILImage()
    shorter_side_to_400 = Resize(400)
    square_400 = CenterCrop(400)
    to_tensor = ToTensor()
    return Compose([to_pil, shorter_side_to_400, square_400, to_tensor])
# 加载训练集中的图像数据
class TrainDatasetFromFolder(Dataset):
    """Training pairs generated on the fly from a folder of high-resolution images.

    Each item is ``(lr_image, hr_image)``:
    - hr_image is a random square crop of the HR image, as a tensor.
    - lr_image is hr_image bicubically downscaled by ``upscale_factor``.

    Args:
        dataset_dir: folder containing the HR images; non-image files are skipped.
        crop_size: requested HR crop size, rounded down to a multiple of upscale_factor.
        upscale_factor: downscaling factor used to create the LR input.
    """

    def __init__(self, dataset_dir, crop_size, upscale_factor):
        super(TrainDatasetFromFolder, self).__init__()
        # listdir() returns entries in arbitrary order. Sorting makes dataset
        # indices reproducible across runs and platforms.
        self.image_filenames = [join(dataset_dir, x) for x in sorted(listdir(dataset_dir)) if is_image_file(x)]
        crop_size = calculate_valid_crop_size(crop_size, upscale_factor)
        self.hr_transform = train_hr_transform(crop_size)
        self.lr_transform = train_lr_transform(crop_size, upscale_factor)

    def __getitem__(self, index):
        # convert('RGB') matters because grayscale, palette or RGBA files would
        # otherwise become 1- or 4-channel tensors. Those cannot be batched with
        # RGB ones and do not fit the networks' 3-channel input convolutions.
        # The `with` block closes the file handle that Image.open keeps open lazily.
        with Image.open(self.image_filenames[index]) as img:
            hr_image = self.hr_transform(img.convert('RGB'))
        lr_image = self.lr_transform(hr_image)
        return lr_image, hr_image

    def __len__(self):
        """Number of image files found in dataset_dir."""
        return len(self.image_filenames)
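# Validation set: unlike training, uses a deterministic center crop and also returns a bicubic-upscaled baseline image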
class ValDatasetFromFolder(Dataset):
    """Validation triples built from a folder of high-resolution images.

    Each item is ``(lr_image, hr_restore_img, hr_image)``, all tensors:
    - hr_image: the largest centered square whose side is a multiple of upscale_factor.
    - lr_image: hr_image bicubically downscaled by upscale_factor.
    - hr_restore_img: lr_image bicubically upscaled back to the HR size (a baseline).

    Args:
        dataset_dir: folder containing the HR images; non-image files are skipped.
        upscale_factor: downscaling factor used to create the LR input.
    """

    def __init__(self, dataset_dir, upscale_factor):
        super(ValDatasetFromFolder, self).__init__()
        self.upscale_factor = upscale_factor
        # Sorted so that validation order (and the saved comparison grids) is
        # stable across runs and platforms.
        self.image_filenames = [join(dataset_dir, x) for x in sorted(listdir(dataset_dir)) if is_image_file(x)]

    def __getitem__(self, index):
        # Force 3 channels and close the file handle (see TrainDatasetFromFolder).
        with Image.open(self.image_filenames[index]) as img:
            hr_image = img.convert('RGB')
        w, h = hr_image.size
        crop_size = calculate_valid_crop_size(min(w, h), self.upscale_factor)
        lr_scale = Resize(crop_size // self.upscale_factor, interpolation=Image.BICUBIC)
        hr_scale = Resize(crop_size, interpolation=Image.BICUBIC)
        hr_image = CenterCrop(crop_size)(hr_image)
        lr_image = lr_scale(hr_image)
        hr_restore_img = hr_scale(lr_image)
        return ToTensor()(lr_image), ToTensor()(hr_restore_img), ToTensor()(hr_image)

    def __len__(self):
        """Number of image files found in dataset_dir."""
        return len(self.image_filenames)
# Test set: unlike training, loads pre-made LR/HR image pairs from SRF_<factor>/data and SRF_<factor>/target
class TestDatasetFromFolder(Dataset):
    """Pre-made LR/HR benchmark pairs.

    Expected layout:
    - ``<dataset_dir>/SRF_<upscale_factor>/data/`` holds the LR images.
    - ``<dataset_dir>/SRF_<upscale_factor>/target/`` holds the HR images.

    Each item is ``(image_name, lr_image, hr_restore_img, hr_image)``.
    hr_restore_img is the LR image bicubically upscaled by upscale_factor.

    Args:
        dataset_dir: root folder of the benchmark data.
        upscale_factor: selects the SRF_<factor> subfolder and the bicubic upscale.
    """

    def __init__(self, dataset_dir, upscale_factor):
        super(TestDatasetFromFolder, self).__init__()
        self.lr_path = dataset_dir + '/SRF_' + str(upscale_factor) + '/data/'
        self.hr_path = dataset_dir + '/SRF_' + str(upscale_factor) + '/target/'
        self.upscale_factor = upscale_factor
        # LR and HR files are paired purely by list position, but listdir()
        # order is arbitrary. Both lists are therefore sorted so that index i
        # refers to the same picture. This assumes data/ and target/ use
        # matching file names.
        self.lr_filenames = [join(self.lr_path, x) for x in sorted(listdir(self.lr_path)) if is_image_file(x)]
        self.hr_filenames = [join(self.hr_path, x) for x in sorted(listdir(self.hr_path)) if is_image_file(x)]

    def __getitem__(self, index):
        image_name = self.lr_filenames[index].split('/')[-1]
        # Force 3 channels (the generator's first conv takes 3) and close the files.
        with Image.open(self.lr_filenames[index]) as img:
            lr_image = img.convert('RGB')
        with Image.open(self.hr_filenames[index]) as img:
            hr_image = img.convert('RGB')
        w, h = lr_image.size
        # Resize takes (height, width).
        hr_scale = Resize((self.upscale_factor * h, self.upscale_factor * w), interpolation=Image.BICUBIC)
        hr_restore_img = hr_scale(lr_image)
        return image_name, ToTensor()(lr_image), ToTensor()(hr_restore_img), ToTensor()(hr_image)

    def __len__(self):
        """Number of LR images found."""
        return len(self.lr_filenames)
2. 网络模型: model.py
2.1 生成器: Generator
import math
import torch
from torch import nn
# 生成器模型
class Generator(nn.Module):
    """SRResNet-style generator (G in SRGAN).

    Pipeline:
    1. 9x9 conv + PReLU.
    2. Five residual blocks.
    3. 3x3 conv + BN, whose output is added to the step-1 output (long skip connection).
    4. log2(scale_factor) x2 PixelShuffle upsampling blocks.
    5. 9x9 conv to 3 channels.
    6. (tanh + 1) / 2, which maps the output into the range 0..1.

    Args:
        scale_factor: total upscaling factor. Must be a power of two
            (1, 2, 4, 8, ...), because every UpsampleBLock doubles H and W.

    Raises:
        ValueError: if scale_factor is not a positive power of two.
    """

    def __init__(self, scale_factor):
        # Reject factors the x2 upsampling stack cannot reach exactly. For
        # example, 3 used to silently build a x2 model (int(log2(3)) == 1).
        if scale_factor < 1 or 2 ** int(math.log2(scale_factor)) != scale_factor:
            raise ValueError('scale_factor must be a power of 2, got %r' % (scale_factor,))
        # math.log2 is exact for powers of two. math.log(x, 2) instead divides
        # two separately rounded logarithms.
        upsample_block_num = int(math.log2(scale_factor))
        super(Generator, self).__init__()
        # The attribute names block1..block8 are kept unchanged because they
        # are the keys of saved state_dicts.
        self.block1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=9, padding=4),
            nn.PReLU()
        )
        self.block2 = ResidualBlock(64)
        self.block3 = ResidualBlock(64)
        self.block4 = ResidualBlock(64)
        self.block5 = ResidualBlock(64)
        self.block6 = ResidualBlock(64)
        self.block7 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64)
        )
        # One x2 upsampling block per factor of two, then project back to 3 channels.
        block8 = [UpsampleBLock(64, 2) for _ in range(upsample_block_num)]
        block8.append(nn.Conv2d(64, 3, kernel_size=9, padding=4))
        self.block8 = nn.Sequential(*block8)

    def forward(self, x):
        """Upscale a batch of 3-channel images by scale_factor."""
        block1 = self.block1(x)
        block2 = self.block2(block1)
        block3 = self.block3(block2)
        block4 = self.block4(block3)
        block5 = self.block5(block4)
        block6 = self.block6(block5)
        block7 = self.block7(block6)
        # Long skip: add the first block's features before upsampling.
        block8 = self.block8(block1 + block7)
        # tanh lies in [-1, 1]; shifting and halving maps it to [0, 1].
        return (torch.tanh(block8) + 1) / 2
2.2 判别器:Discriminator
class Discriminator(nn.Module):
    """SRGAN discriminator: a VGG-style conv stack ending in a per-image real/fake probability.

    Layer order:
    - Conv(3->64) + LeakyReLU(0.2).
    - Seven Conv-BN-LeakyReLU stages; channels grow 64 -> 512 and stride 2
      halves the resolution on every other stage.
    - Global average pooling, then 1x1 convs 512 -> 1024 -> 1.
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        layers = [nn.Conv2d(3, 64, kernel_size=3, padding=1), nn.LeakyReLU(0.2)]
        # (in_channels, out_channels, stride) for each Conv-BN-LeakyReLU stage.
        stages = (
            (64, 64, 2),
            (64, 128, 1),
            (128, 128, 2),
            (128, 256, 1),
            (256, 256, 2),
            (256, 512, 1),
            (512, 512, 2),
        )
        for c_in, c_out, step in stages:
            layers.extend([
                nn.Conv2d(c_in, c_out, kernel_size=3, stride=step, padding=1),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2),
            ])
        # Global pooling makes the head independent of the input's spatial size.
        layers.extend([
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(512, 1024, kernel_size=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(1024, 1, kernel_size=1),
        ])
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return a 1-D tensor with one sigmoid probability per image in the batch."""
        logits = self.net(x).view(x.size(0))
        return torch.sigmoid(logits)
2.3 残差块: ResidualBlock
class ResidualBlock(nn.Module):
    """SRResNet residual unit: conv-BN-PReLU-conv-BN plus an identity shortcut.

    Args:
        channels: number of input and output channels; the spatial size is
            preserved (3x3 kernels with padding 1).
    """

    def __init__(self, channels):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.prelu = nn.PReLU()
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        """Return x plus the learned residual."""
        hidden = self.prelu(self.bn1(self.conv1(x)))
        return x + self.bn2(self.conv2(hidden))
2.4 上采样块: UpsampleBLock
# 上采样块
class UpsampleBLock(nn.Module):
    """Sub-pixel upsampling block: conv -> PixelShuffle(up_scale) -> PReLU.

    The conv expands the channels to in_channels * up_scale**2.
    PixelShuffle then rearranges them into an output with in_channels
    channels and H and W multiplied by up_scale.
    """

    def __init__(self, in_channels, up_scale):
        super(UpsampleBLock, self).__init__()
        expanded_channels = in_channels * up_scale ** 2
        self.conv = nn.Conv2d(in_channels, expanded_channels, kernel_size=3, padding=1)
        self.pixel_shuffle = nn.PixelShuffle(up_scale)
        self.prelu = nn.PReLU()

    def forward(self, x):
        """Upsample x spatially by up_scale."""
        return self.prelu(self.pixel_shuffle(self.conv(x)))
3. 损失函数: loss.py
import torch
from torch import nn
from torchvision.models.vgg import vgg16
class GeneratorLoss(nn.Module):
    """Weighted sum of the losses used to train the SRGAN generator.

    total = pixel MSE
            + 0.001 * adversarial term (mean of 1 - D(G(z)))
            + 0.006 * MSE between VGG16 feature maps (perceptual term)
            + 2e-8  * total-variation penalty on the generated image
    """

    def __init__(self):
        super(GeneratorLoss, self).__init__()
        # torchvision's vgg16().features has 31 modules, so the [:31] slice
        # keeps the whole convolutional feature stack. It serves only as a
        # fixed feature extractor: it is switched to eval mode and frozen.
        feature_layers = list(vgg16(pretrained=True).features)[:31]
        frozen_vgg = nn.Sequential(*feature_layers).eval()
        for weight in frozen_vgg.parameters():
            weight.requires_grad = False
        self.loss_network = frozen_vgg
        self.mse_loss = nn.MSELoss()
        self.tv_loss = TVLoss()

    def forward(self, out_labels, out_images, target_images):
        """Return the scalar generator loss.

        Args:
            out_labels: discriminator output(s) for the generated images. The
                training script passes the batch mean of D(G(z)).
            out_images: generated SR images.
            target_images: ground-truth HR images, same shape as out_images.
        """
        adversarial_term = (1 - out_labels).mean()
        perceptual_term = self.mse_loss(self.loss_network(out_images), self.loss_network(target_images))
        pixel_term = self.mse_loss(out_images, target_images)
        smoothness_term = self.tv_loss(out_images)
        return pixel_term + 0.001 * adversarial_term + 0.006 * perceptual_term + 2e-8 * smoothness_term
class TVLoss(nn.Module):
    """Total-variation penalty: mean squared difference between neighbouring pixels.

    For an input of shape (batch, channels, height, width) it returns
        weight * 2 * (vertical_sq_diff / count_h + horizontal_sq_diff / count_w) / batch
    where each count is channels * (number of neighbour pairs per channel).

    Args:
        tv_loss_weight: scalar multiplier applied to the result (default 1).
    """

    def __init__(self, tv_loss_weight=1):
        super(TVLoss, self).__init__()
        self.tv_loss_weight = tv_loss_weight

    def forward(self, x):
        """Compute the TV loss of x. The indexing on dims 2 and 3 requires a 4-D tensor."""
        batch_size = x.size()[0]
        h_x = x.size()[2]
        w_x = x.size()[3]
        count_h = self.tensor_size(x[:, :, 1:, :])
        count_w = self.tensor_size(x[:, :, :, 1:])
        # Differences between vertically adjacent pixels (along the height axis).
        h_tv = torch.pow((x[:, :, 1:, :] - x[:, :, :h_x - 1, :]), 2).sum()
        # Differences between horizontally adjacent pixels (along the width axis).
        w_tv = torch.pow((x[:, :, :, 1:] - x[:, :, :, :w_x - 1]), 2).sum()
        # With H == 1 or W == 1 there are no neighbour pairs in that direction:
        # the sum is 0 and so is the count. max(count, 1) turns 0/0 (NaN) into 0.
        return self.tv_loss_weight * 2 * (h_tv / max(count_h, 1) + w_tv / max(count_w, 1)) / batch_size

    @staticmethod
    def tensor_size(t):
        """Number of elements per sample: channels * height * width."""
        return t.size()[1] * t.size()[2] * t.size()[3]
if __name__ == "__main__":
    # Smoke test: build the composite generator loss and print its module tree.
    # Building it needs the pretrained VGG16 weights (torchvision may download them on first use).
    g_loss = GeneratorLoss()
    print(g_loss)
4.训练:train.py
import argparse
import os
from math import log10
import pandas as pd
import torch.optim as optim
import torch.utils.data
import torchvision.utils as utils
from torch.autograd import Variable
from torch.utils.data import DataLoader
from tqdm import tqdm
import pytorch_ssim
from data_utils import TrainDatasetFromFolder, ValDatasetFromFolder, display_transform
from loss import GeneratorLoss
from model import Generator, Discriminator
# Command-line options for training; any option that is not given keeps its default.
parser = argparse.ArgumentParser(description='Train Super Resolution Models')
# Side length of the HR patch cropped from each training image (default 88).
parser.add_argument('--crop_size', type=int, default=88,
                    help='training images crop size')
# Super-resolution factor, restricted to 2, 4 or 8 (default 4).
parser.add_argument('--upscale_factor', type=int, default=4, choices=[2, 4, 8],
                    help='super resolution upscale factor')
# Number of training epochs (default 100).
parser.add_argument('--num_epochs', type=int, default=100,
                    help='train epoch number')
if __name__ == '__main__':
    # Parse the options declared on `parser`.
    opt = parser.parse_args()
    CROP_SIZE = opt.crop_size
    UPSCALE_FACTOR = opt.upscale_factor
    NUM_EPOCHS = opt.num_epochs

    # Training pairs are random crops; validation uses deterministic center crops.
    train_set = TrainDatasetFromFolder('data/VOC2012/train', crop_size=CROP_SIZE, upscale_factor=UPSCALE_FACTOR)
    val_set = ValDatasetFromFolder('data/VOC2012/val', upscale_factor=UPSCALE_FACTOR)
    train_loader = DataLoader(dataset=train_set, num_workers=4, batch_size=64, shuffle=True)
    val_loader = DataLoader(dataset=val_set, num_workers=4, batch_size=1, shuffle=False)

    netG = Generator(UPSCALE_FACTOR)
    print('# generator parameters:', sum(param.numel() for param in netG.parameters()))
    netD = Discriminator()
    print('# discriminator parameters:', sum(param.numel() for param in netD.parameters()))
    generator_criterion = GeneratorLoss()

    use_cuda = torch.cuda.is_available()
    if use_cuda:
        netG.cuda()
        netD.cuda()
        generator_criterion.cuda()

    optimizerG = optim.Adam(netG.parameters())
    optimizerD = optim.Adam(netD.parameters())

    # torch.save() and DataFrame.to_csv() do not create missing folders.
    # Create every output directory up front instead of crashing after the
    # first epoch.
    image_dir = 'training_results/SRF_' + str(UPSCALE_FACTOR) + '/'
    stats_dir = 'statistics/'
    for directory in (image_dir, 'epochs', stats_dir):
        os.makedirs(directory, exist_ok=True)

    # Per-epoch history of losses, discriminator scores and validation metrics.
    results = {'d_loss': [], 'g_loss': [], 'd_score': [], 'g_score': [], 'psnr': [], 'ssim': []}

    for epoch in range(1, NUM_EPOCHS + 1):
        train_bar = tqdm(train_loader)
        running_results = {'batch_sizes': 0, 'd_loss': 0, 'g_loss': 0, 'd_score': 0, 'g_score': 0}

        netG.train()
        netD.train()
        for data, target in train_bar:
            batch_size = data.size(0)
            running_results['batch_sizes'] += batch_size

            real_img = target
            z = data
            if use_cuda:
                real_img = real_img.cuda()
                z = z.cuda()

            # (1) Update D network: maximize D(x)-1-D(G(z))
            fake_img = netG(z)
            netD.zero_grad()
            real_out = netD(real_img).mean()
            # detach(): the discriminator loss must not backpropagate into G.
            # This also leaves G's graph intact, so retain_graph=True is not needed.
            fake_out = netD(fake_img.detach()).mean()
            d_loss = 1 - real_out + fake_out
            d_loss.backward()
            optimizerD.step()

            # (2) Update G network: minimize 1-D(G(z)) + Perception Loss + Image Loss + TV Loss
            netG.zero_grad()
            # Re-score the same fake batch with the just-updated D. Running
            # netG(z) again is unnecessary: G's parameters have not changed and
            # its graph was not freed.
            fake_out = netD(fake_img).mean()
            g_loss = generator_criterion(fake_out, fake_img, real_img)
            g_loss.backward()
            optimizerG.step()

            # Batch-size-weighted running sums. The values were computed before
            # G's optimizer step.
            running_results['g_loss'] += g_loss.item() * batch_size
            running_results['d_loss'] += d_loss.item() * batch_size
            running_results['d_score'] += real_out.item() * batch_size
            running_results['g_score'] += fake_out.item() * batch_size

            train_bar.set_description(desc='[%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f' % (
                epoch, NUM_EPOCHS, running_results['d_loss'] / running_results['batch_sizes'],
                running_results['g_loss'] / running_results['batch_sizes'],
                running_results['d_score'] / running_results['batch_sizes'],
                running_results['g_score'] / running_results['batch_sizes']))

        netG.eval()
        with torch.no_grad():
            val_bar = tqdm(val_loader)
            valing_results = {'mse': 0, 'ssims': 0, 'psnr': 0, 'ssim': 0, 'batch_sizes': 0}
            val_images = []
            # Each item is (LR input, bicubic-upscaled LR, HR ground truth).
            for val_lr, val_hr_restore, val_hr in val_bar:
                batch_size = val_lr.size(0)
                valing_results['batch_sizes'] += batch_size
                lr = val_lr
                hr = val_hr
                if use_cuda:
                    lr = lr.cuda()
                    hr = hr.cuda()
                sr = netG(lr)

                # .item() keeps the accumulators as Python floats rather than (GPU) tensors.
                batch_mse = ((sr - hr) ** 2).mean().item()
                valing_results['mse'] += batch_mse * batch_size
                batch_ssim = pytorch_ssim.ssim(sr, hr).item()
                valing_results['ssims'] += batch_ssim * batch_size
                # The images come from ToTensor() and lie in [0, 1], so the PSNR
                # peak is the fixed data range 1.0. The original used hr.max() of
                # the current batch.
                mean_mse = valing_results['mse'] / valing_results['batch_sizes']
                valing_results['psnr'] = 10 * log10(1.0 / mean_mse) if mean_mse > 0 else float('inf')
                valing_results['ssim'] = valing_results['ssims'] / valing_results['batch_sizes']
                val_bar.set_description(
                    desc='[converting LR images to SR images] PSNR: %.4f dB SSIM: %.4f' % (
                        valing_results['psnr'], valing_results['ssim']))

                val_images.extend(
                    [display_transform()(val_hr_restore.squeeze(0)), display_transform()(hr.cpu().squeeze(0)),
                     display_transform()(sr.cpu().squeeze(0))])

            if val_images:
                val_images = torch.stack(val_images)
                # Each saved file holds up to 15 images, i.e. 5 (bicubic, HR, SR) triplets.
                # torch.split returns groups of exactly 15; the last group may be
                # shorter but is still a multiple of 3. The original torch.chunk(n // 15)
                # raised for fewer than 15 images and could break a triplet across rows.
                val_save_bar = tqdm(torch.split(val_images, 15), desc='[saving training results]')
                for index, image in enumerate(val_save_bar, start=1):
                    # nrow=3: one (bicubic, HR, SR) triplet per grid row.
                    image = utils.make_grid(image, nrow=3, padding=5)
                    utils.save_image(image, image_dir + 'epoch_%d_index_%d.png' % (epoch, index), padding=5)

        # Save model parameters.
        torch.save(netG.state_dict(), 'epochs/netG_epoch_%d_%d.pth' % (UPSCALE_FACTOR, epoch))
        torch.save(netD.state_dict(), 'epochs/netD_epoch_%d_%d.pth' % (UPSCALE_FACTOR, epoch))

        # Record epoch averages of loss/scores plus the validation PSNR/SSIM.
        seen = running_results['batch_sizes']
        results['d_loss'].append(running_results['d_loss'] / seen)
        results['g_loss'].append(running_results['g_loss'] / seen)
        results['d_score'].append(running_results['d_score'] / seen)
        results['g_score'].append(running_results['g_score'] / seen)
        results['psnr'].append(valing_results['psnr'])
        results['ssim'].append(valing_results['ssim'])

        # Every 10 epochs, rewrite the CSV with the full history so far.
        if epoch % 10 == 0:
            data_frame = pd.DataFrame(
                data={'Loss_D': results['d_loss'], 'Loss_G': results['g_loss'], 'Score_D': results['d_score'],
                      'Score_G': results['g_score'], 'PSNR': results['psnr'], 'SSIM': results['ssim']},
                index=range(1, epoch + 1))
            data_frame.to_csv(stats_dir + 'srf_' + str(UPSCALE_FACTOR) + '_train_results.csv', index_label='Epoch')
5. 测试基准数据集: test_benchmark.py
import argparse
import os
from math import log10
import numpy as np
import pandas as pd
import torch
import torchvision.utils as utils
from torch.autograd import Variable
from torch.utils.data import DataLoader
from tqdm import tqdm
import pytorch_ssim
from data_utils import TestDatasetFromFolder, display_transform
from model import Generator
parser = argparse.ArgumentParser(description='Test Benchmark Datasets')
parser.add_argument('--upscale_factor', default=4, type=int, help='super resolution upscale factor')
parser.add_argument('--model_name', default='netG_epoch_4_150.pth', type=str, help='generator model epoch name')
opt = parser.parse_args()

UPSCALE_FACTOR = opt.upscale_factor
MODEL_NAME = opt.model_name

# Per-dataset PSNR/SSIM lists. Keys are the file-name prefix before the first
# '_' (e.g. 'Set5_001.png' -> 'Set5').
results = {'Set5': {'psnr': [], 'ssim': []}, 'Set14': {'psnr': [], 'ssim': []}, 'BSD100': {'psnr': [], 'ssim': []},
           'Urban100': {'psnr': [], 'ssim': []}, 'SunHays80': {'psnr': [], 'ssim': []}}

model = Generator(UPSCALE_FACTOR).eval()
if torch.cuda.is_available():
    model = model.cuda()
# Load the checkpoint onto the CPU first, so that GPU-saved checkpoints also
# work on CPU-only machines. load_state_dict then copies the weights to the
# model's device.
model.load_state_dict(torch.load('epochs/' + MODEL_NAME, map_location=lambda storage, loc: storage))

test_set = TestDatasetFromFolder('data/test', upscale_factor=UPSCALE_FACTOR)
test_loader = DataLoader(dataset=test_set, num_workers=4, batch_size=1, shuffle=False)
test_bar = tqdm(test_loader, desc='[testing benchmark datasets]')

out_path = 'benchmark_results/SRF_' + str(UPSCALE_FACTOR) + '/'
if not os.path.exists(out_path):
    os.makedirs(out_path)

# Inference only. `Variable(..., volatile=True)` has been a no-op since
# PyTorch 0.4, so no_grad() is what actually stops graph construction.
with torch.no_grad():
    for image_name, lr_image, hr_restore_img, hr_image in test_bar:
        # batch_size=1, so the collated name list holds a single entry.
        image_name = image_name[0]
        if torch.cuda.is_available():
            lr_image = lr_image.cuda()
            hr_image = hr_image.cuda()

        sr_image = model(lr_image)
        mse = ((hr_image - sr_image) ** 2).mean().item()
        # ToTensor() scales pixels to [0, 1], so the PSNR peak is 1, not 255
        # (255 inflated every PSNR by about 48 dB).
        psnr = 10 * log10(1 / mse) if mse > 0 else float('inf')
        # ssim() returns a 0-dim tensor. `.data[0]` raises IndexError on it,
        # so .item() is used instead.
        ssim = pytorch_ssim.ssim(sr_image, hr_image).item()

        # Grid row: bicubic upscale | ground-truth HR | generated SR.
        test_images = torch.stack(
            [display_transform()(hr_restore_img.squeeze(0)), display_transform()(hr_image.cpu().squeeze(0)),
             display_transform()(sr_image.cpu().squeeze(0))])
        image = utils.make_grid(test_images, nrow=3, padding=5)
        stem, ext = os.path.splitext(image_name)
        utils.save_image(image, out_path + stem + '_psnr_%.4f_ssim_%.4f' % (psnr, ssim) + ext, padding=5)

        # Images from datasets not pre-listed get their own entry instead of raising KeyError.
        dataset_results = results.setdefault(image_name.split('_')[0], {'psnr': [], 'ssim': []})
        dataset_results['psnr'].append(psnr)
        dataset_results['ssim'].append(ssim)

out_path = 'statistics/'
os.makedirs(out_path, exist_ok=True)
saved_results = {'psnr': [], 'ssim': []}
for item in results.values():
    psnr = np.array(item['psnr'])
    ssim = np.array(item['ssim'])
    # Datasets with no test images are reported as 'No data' instead of a mean.
    if (len(psnr) == 0) or (len(ssim) == 0):
        psnr = 'No data'
        ssim = 'No data'
    else:
        psnr = psnr.mean()
        ssim = ssim.mean()
    saved_results['psnr'].append(psnr)
    saved_results['ssim'].append(ssim)

# One row per dataset, in the order of `results`.
data_frame = pd.DataFrame(saved_results, index=list(results.keys()))
data_frame.to_csv(out_path + 'srf_' + str(UPSCALE_FACTOR) + '_test_results.csv', index_label='DataSet')
6. 测试单张图片: test_image.py(原理和基准测试集相同, 只需通过 --image_name 指定需要测试的图片路径即可)
import argparse
import time
import torch
from PIL import Image
from torch.autograd import Variable
from torchvision.transforms import ToTensor, ToPILImage
from model import Generator
import os

parser = argparse.ArgumentParser(description='Test Single Image')
parser.add_argument('--upscale_factor', default=4, type=int, help='super resolution upscale factor')
parser.add_argument('--test_mode', default='CPU', type=str, choices=['GPU', 'CPU'], help='using GPU or CPU')
parser.add_argument('--image_name', default='data/test/SRF_4/data/Set5_003.png', type=str, help='test low resolution image name')
parser.add_argument('--model_name', default='netG_epoch_4_150.pth', type=str, help='generator model epoch name')
opt = parser.parse_args()

UPSCALE_FACTOR = opt.upscale_factor
TEST_MODE = opt.test_mode == 'GPU'
IMAGE_NAME = opt.image_name
MODEL_NAME = opt.model_name

model = Generator(UPSCALE_FACTOR).eval()
if TEST_MODE:
    model.cuda()
    model.load_state_dict(torch.load('epochs/' + MODEL_NAME))
else:
    # Map every stored tensor to the CPU so GPU-saved checkpoints load without CUDA.
    model.load_state_dict(torch.load('epochs/' + MODEL_NAME, map_location=lambda storage, loc: storage))

# convert('RGB') is needed because the generator's first conv expects exactly 3
# channels; grayscale or RGBA inputs would fail. The `with` closes the file.
with Image.open(IMAGE_NAME) as img:
    image = ToTensor()(img.convert('RGB')).unsqueeze(0)
if TEST_MODE:
    image = image.cuda()

# no_grad() disables graph building. `volatile=True` has been a no-op since PyTorch 0.4.
with torch.no_grad():
    if TEST_MODE:
        torch.cuda.synchronize()
    start = time.time()
    out = model(image)
    if TEST_MODE:
        # CUDA kernels run asynchronously; wait for them so the timing is real.
        torch.cuda.synchronize()
    elapsed = (time.time() - start)
print('cost: ' + str(elapsed) + 's')

out_img = ToPILImage()(out[0].cpu())
out_img.show()

# Name the result after the input file. The original wrote the literal
# 'image_name.jpg' on every run, into a folder that might not exist.
save_dir = 'single_test'
os.makedirs(save_dir, exist_ok=True)
save_path = os.path.join(save_dir, 'out_srf_' + str(UPSCALE_FACTOR) + '_' + os.path.basename(IMAGE_NAME))
out_img.save(save_path)
print("图像已保存到文件夹中。")
本人水平有限,文中发现错误敬请指正,如有相同研究方向的同学可以互相学习,共同进步。