![fb0e0e201b474fead5d845a4d696c44d.png](https://img-blog.csdnimg.cn/img_convert/fb0e0e201b474fead5d845a4d696c44d.png)
Principle Analysis:
ESRGAN is a paper by undergraduates from The Chinese University of Hong Kong (Shenzhen), published at the ECCV 2018 workshops; the method won first place in the PIRM2018-SR challenge. The paper observes that while SRGAN can generate rich texture details, the textures are often unnatural and accompanied by noise. The authors therefore study and improve three key components of SRGAN: the network architecture, the adversarial loss, and the perceptual loss. Concretely, they introduce a new architectural unit, the RRDB (Residual-in-Residual Dense Block); borrow from the relativistic GAN so that the discriminator predicts relative realness rather than an absolute value; and constrain the perceptual loss with pre-activation features, which carry stronger supervisory information.
ESRGAN makes four improvements over SRGAN:
- Changes to the generator architecture (switching from residual blocks to RRDBs and removing batch normalization).
- For the adversarial loss, it adopts the relativistic GAN [2] so that the discriminator estimates relative realness instead of an absolute value.
- For the perceptual loss, features before activation are used for the computation (previously, post-activation features were used).
- The network is first pre-trained to optimize for PSNR and then fine-tuned with the GAN objective.
Architecture:
![c313fe4174242084f6f7c0bca438eb72.png](https://img-blog.csdnimg.cn/img_convert/c313fe4174242084f6f7c0bca438eb72.png)
ESRGAN also benefits from dense connections (as proposed by the authors of DenseNet). These not only increase the depth of the network but also enable more complex structures, allowing the network to learn finer details.
ESRGAN does not use batch normalization. Normalizing the data distribution between layers is common practice in many deep neural networks: a BN layer normalizes features with the mean and variance of the current batch during training, and at test time normalizes with the mean and variance estimated over the whole training set. When the statistics of the training and test sets differ substantially, BN layers can limit the model's ability to generalize; removing batch normalization improves stability and reduces computational cost (fewer learned parameters).
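For intuition, here is a minimal sketch (not the papers' exact configurations) contrasting an SRGAN-style residual block, which interleaves BN layers, with an ESRGAN-style block that simply drops them:

```python
import torch.nn as nn

# SRGAN-style basic block: Conv-BN-activation-Conv-BN
srgan_block = nn.Sequential(
    nn.Conv2d(64, 64, 3, 1, 1),
    nn.BatchNorm2d(64),
    nn.PReLU(),
    nn.Conv2d(64, 64, 3, 1, 1),
    nn.BatchNorm2d(64),
)

# ESRGAN-style block: the BN layers are removed
esrgan_block = nn.Sequential(
    nn.Conv2d(64, 64, 3, 1, 1),
    nn.LeakyReLU(0.2),
    nn.Conv2d(64, 64, 3, 1, 1),
)
```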
Adversarial loss:
![60762197bd0cbd4d86d489d38fb9fe12.png](https://img-blog.csdnimg.cn/img_convert/60762197bd0cbd4d86d489d38fb9fe12.png)
The authors propose a new way of thinking about the discriminator: it should estimate the probability that a real image is relatively more realistic than a fake one. Specifically, they replace the standard discriminator with the Relativistic average Discriminator (RaD), so the discriminator loss is defined as:
![d5445d67ae4b1415b4006ba947da7a73.png](https://img-blog.csdnimg.cn/img_convert/d5445d67ae4b1415b4006ba947da7a73.png)
The corresponding adversarial loss for the generator is:
![40ca2e145eb723daba38b9cb546a2cbd.png](https://img-blog.csdnimg.cn/img_convert/40ca2e145eb723daba38b9cb546a2cbd.png)
The expectation operator is computed by averaging over all samples in the mini-batch; x_f denotes the image produced by the generator from the original low-resolution input. Because the adversarial loss involves both x_r and x_f, the generator benefits from gradients of both generated and real data during adversarial training, which pushes the network to learn sharper edges and more detailed textures.
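In code, the two relativistic average losses can be sketched as follows (a minimal sketch, assuming d_real and d_fake are the raw discriminator logits for a mini-batch; the single-term generator loss matches the training script later in this post):

```python
import torch
import torch.nn as nn

bce = nn.BCEWithLogitsLoss()

def ra_d_loss(d_real, d_fake):
    # Real images should look more realistic than the average fake, and vice versa
    valid = torch.ones_like(d_real)
    fake = torch.zeros_like(d_fake)
    loss_real = bce(d_real - d_fake.mean(0, keepdim=True), valid)
    loss_fake = bce(d_fake - d_real.mean(0, keepdim=True), fake)
    return (loss_real + loss_fake) / 2

def ra_g_loss(d_real, d_fake):
    # The generator wants its fakes to look more realistic than the average real
    valid = torch.ones_like(d_fake)
    return bce(d_fake - d_real.mean(0, keepdim=True), valid)
```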
Perceptual loss:
A pre-trained VGG19 network supplies the features needed for super-resolution reconstruction. By studying the loss in feature space, the authors found that using features before activation overcomes two drawbacks:
- Post-activation features are very sparse, especially in very deep networks. Such sparse activations provide weak supervision and lead to inferior performance;
- Using post-activation features causes the brightness of the reconstructed image to be inconsistent with the ground truth.
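In torchvision's vgg19, index 34 of the features module is the conv5_4 convolution and index 35 is its ReLU, so the two choices differ only in where the network is sliced. A minimal sketch:

```python
import torch.nn as nn
from torchvision.models import vgg19

vgg = vgg19(pretrained=True).features
before_act = nn.Sequential(*list(vgg.children())[:35])  # ends at conv5_4, before its ReLU
after_act = nn.Sequential(*list(vgg.children())[:36])   # ends after the ReLU (sparse features)
```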
The authors also include a content loss, so the final loss function consists of three parts:
![bb868bc5f3881ea803170c94970702cb.png](https://img-blog.csdnimg.cn/img_convert/bb868bc5f3881ea803170c94970702cb.png)
where:
![f14b020844b646ad1b4dca54bdf2af70.png](https://img-blog.csdnimg.cn/img_convert/f14b020844b646ad1b4dca54bdf2af70.png)
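In plain text, the total generator loss is L_G = L_percep + λ·L_G^Ra + η·L_1, where λ and η weight the adversarial and pixel-wise (L1) terms; they correspond to lambda_adv and lambda_pixel in the training script below.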
Network Interpolation:
To remove the noise artifacts of GAN-based results while retaining good perceptual quality, the paper proposes an efficient strategy called network interpolation. Concretely: first train a PSNR-oriented network, then fine-tune it to obtain a GAN-based network, and finally interpolate all corresponding parameters of the two networks:
![731aa36602f2616dc485a91beced5ba2.png](https://img-blog.csdnimg.cn/img_convert/731aa36602f2616dc485a91beced5ba2.png)
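A minimal sketch of this parameter-space interpolation between the two trained generators (the checkpoint paths are hypothetical; alpha trades PSNR-oriented smoothness against GAN-style texture):

```python
import torch

alpha = 0.8  # 0.0 = pure PSNR-oriented model, 1.0 = pure GAN model

psnr_state = torch.load("generator_psnr.pth")
gan_state = torch.load("generator_gan.pth")

# Interpolate every corresponding parameter of the two networks
interp_state = {k: (1 - alpha) * psnr_state[k] + alpha * gan_state[k] for k in psnr_state}
torch.save(interp_state, "generator_interp.pth")
```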
Code Walkthrough:
1. Defining the dataset:
```python
import glob
import random
import os
import numpy as np
import torch
from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms as transforms

# Normalization parameters for pre-trained PyTorch models (ImageNet statistics)
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])


def denormalize(tensors):
    """Denormalizes image tensors using mean and std"""
    for c in range(3):
        tensors[:, c].mul_(std[c]).add_(mean[c])
    return torch.clamp(tensors, 0, 255)


class ImageDataset(Dataset):
    def __init__(self, root, hr_shape):
        hr_height, hr_width = hr_shape
        # Transforms for low resolution images and high resolution images
        self.lr_transform = transforms.Compose(
            [
                transforms.Resize((hr_height // 4, hr_width // 4), Image.BICUBIC),
                transforms.ToTensor(),
                transforms.Normalize(mean, std),
            ]
        )
        self.hr_transform = transforms.Compose(
            [
                transforms.Resize((hr_height, hr_width), Image.BICUBIC),
                transforms.ToTensor(),
                transforms.Normalize(mean, std),
            ]
        )
        self.files = sorted(glob.glob(root + "/*.*"))

    def __getitem__(self, index):
        # Convert to RGB so grayscale images do not break the 3-channel Normalize
        img = Image.open(self.files[index % len(self.files)]).convert("RGB")
        img_lr = self.lr_transform(img)
        img_hr = self.hr_transform(img)
        return {"lr": img_lr, "hr": img_hr}

    def __len__(self):
        return len(self.files)
```
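As a quick sanity check (the data path is hypothetical), each item yields a normalized LR/HR pair at 1/4 and full resolution:

```python
dataset = ImageDataset("../../data/img_align_celeba", hr_shape=(256, 256))
sample = dataset[0]
print(sample["lr"].shape)  # torch.Size([3, 64, 64])
print(sample["hr"].shape)  # torch.Size([3, 256, 256])
```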
2. Model definition:
Feature extraction:
```python
import torch.nn as nn
import torch.nn.functional as F
import torch
from torchvision.models import vgg19
import math


class FeatureExtractor(nn.Module):
    def __init__(self):
        super(FeatureExtractor, self).__init__()
        vgg19_model = vgg19(pretrained=True)
        # Keep layers up to and including conv5_4 (index 34); stopping before
        # its ReLU yields the pre-activation features used for the perceptual loss
        self.vgg19_54 = nn.Sequential(*list(vgg19_model.features.children())[:35])

    def forward(self, img):
        return self.vgg19_54(img)
```
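The name vgg19_54 refers to the feature map of the 4th convolution in the 5th block: slicing features at index 35 stops right before that convolution's ReLU, which is exactly the "before activation" choice discussed above.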
DenseResidualBlock:
![46ba6e99df8596a276cc729d3417dd62.png](https://img-blog.csdnimg.cn/img_convert/46ba6e99df8596a276cc729d3417dd62.png)
![edce1f15ef2863297dc786bc8e1657b8.png](https://img-blog.csdnimg.cn/img_convert/edce1f15ef2863297dc786bc8e1657b8.png)
```python
class DenseResidualBlock(nn.Module):
    """
    The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18)
    """

    def __init__(self, filters, res_scale=0.2):
        super(DenseResidualBlock, self).__init__()
        self.res_scale = res_scale

        def block(in_features, non_linearity=True):
            layers = [nn.Conv2d(in_features, filters, 3, 1, 1, bias=True)]
            if non_linearity:
                layers += [nn.LeakyReLU()]
            return nn.Sequential(*layers)

        # Each block sees the concatenation of the original input and all
        # previous block outputs, hence the growing channel counts
        self.b1 = block(in_features=1 * filters)
        self.b2 = block(in_features=2 * filters)
        self.b3 = block(in_features=3 * filters)
        self.b4 = block(in_features=4 * filters)
        self.b5 = block(in_features=5 * filters, non_linearity=False)
        self.blocks = [self.b1, self.b2, self.b3, self.b4, self.b5]

    def forward(self, x):
        inputs = x
        for block in self.blocks:
            out = block(inputs)
            inputs = torch.cat([inputs, out], 1)
        # Residual scaling stabilizes the training of very deep networks
        return out.mul(self.res_scale) + x
```
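A quick shape check (assuming the imports above): the block is shape-preserving, so it can be stacked freely.

```python
block = DenseResidualBlock(filters=64)
x = torch.randn(1, 64, 32, 32)
print(block(x).shape)  # torch.Size([1, 64, 32, 32])
```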
The overall ResidualInResidualDenseBlock:
![bfc610f8052abfe84efcfecca95f0417.png](https://img-blog.csdnimg.cn/img_convert/bfc610f8052abfe84efcfecca95f0417.png)
```python
class ResidualInResidualDenseBlock(nn.Module):
    def __init__(self, filters, res_scale=0.2):
        super(ResidualInResidualDenseBlock, self).__init__()
        self.res_scale = res_scale
        # Three dense blocks chained together, wrapped in an outer residual
        self.dense_blocks = nn.Sequential(
            DenseResidualBlock(filters), DenseResidualBlock(filters), DenseResidualBlock(filters)
        )

    def forward(self, x):
        return self.dense_blocks(x).mul(self.res_scale) + x
```
Defining the Generator:
```python
class GeneratorRRDB(nn.Module):
    def __init__(self, channels, filters=64, num_res_blocks=16, num_upsample=2):
        super(GeneratorRRDB, self).__init__()
        # First layer
        self.conv1 = nn.Conv2d(channels, filters, kernel_size=3, stride=1, padding=1)
        # Residual blocks
        self.res_blocks = nn.Sequential(*[ResidualInResidualDenseBlock(filters) for _ in range(num_res_blocks)])
        # Second conv layer post residual blocks
        self.conv2 = nn.Conv2d(filters, filters, kernel_size=3, stride=1, padding=1)
        # Upsampling layers (each PixelShuffle stage doubles the resolution)
        upsample_layers = []
        for _ in range(num_upsample):
            upsample_layers += [
                nn.Conv2d(filters, filters * 4, kernel_size=3, stride=1, padding=1),
                nn.LeakyReLU(),
                nn.PixelShuffle(upscale_factor=2),
            ]
        self.upsampling = nn.Sequential(*upsample_layers)
        # Final output block
        self.conv3 = nn.Sequential(
            nn.Conv2d(filters, filters, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(),
            nn.Conv2d(filters, channels, kernel_size=3, stride=1, padding=1),
        )

    def forward(self, x):
        out1 = self.conv1(x)
        out = self.res_blocks(out1)
        out2 = self.conv2(out)
        # Long skip connection around all residual blocks
        out = torch.add(out1, out2)
        out = self.upsampling(out)
        out = self.conv3(out)
        return out
```
nn.PixelShuffle(upscale_factor=2) performs the upsampling: it rearranges a tensor of shape (N, C·r², H, W) into (N, C, H·r, W·r) with r = 2, so each conv + PixelShuffle stage doubles the spatial resolution, and the default num_upsample=2 gives the overall 4x factor.
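A quick sketch of the shape transformation:

```python
ps = nn.PixelShuffle(upscale_factor=2)
x = torch.randn(1, 256, 64, 64)  # 256 channels = 64 * 2**2
print(ps(x).shape)               # torch.Size([1, 64, 128, 128])
```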
Defining the Discriminator:
```python
class Discriminator(nn.Module):
    def __init__(self, input_shape):
        super(Discriminator, self).__init__()
        self.input_shape = input_shape
        in_channels, in_height, in_width = self.input_shape
        # Four stride-2 stages downsample by 2^4, giving a patch-level output
        patch_h, patch_w = int(in_height / 2 ** 4), int(in_width / 2 ** 4)
        self.output_shape = (1, patch_h, patch_w)

        def discriminator_block(in_filters, out_filters, first_block=False):
            layers = []
            layers.append(nn.Conv2d(in_filters, out_filters, kernel_size=3, stride=1, padding=1))
            if not first_block:
                layers.append(nn.BatchNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            layers.append(nn.Conv2d(out_filters, out_filters, kernel_size=3, stride=2, padding=1))
            layers.append(nn.BatchNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        layers = []
        in_filters = in_channels
        for i, out_filters in enumerate([64, 128, 256, 512]):
            layers.extend(discriminator_block(in_filters, out_filters, first_block=(i == 0)))
            in_filters = out_filters
        layers.append(nn.Conv2d(out_filters, 1, kernel_size=3, stride=1, padding=1))
        self.model = nn.Sequential(*layers)

    def forward(self, img):
        return self.model(img)
```
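Note that batch normalization is only removed from the generator; the discriminator keeps its BN layers. Also, the discriminator outputs one logit per image patch (a PatchGAN-style map) rather than a single scalar, as a quick shape check shows (a sketch, assuming the imports above):

```python
disc = Discriminator(input_shape=(3, 256, 256))
imgs = torch.randn(4, 3, 256, 256)
print(disc(imgs).shape)   # torch.Size([4, 1, 16, 16])
print(disc.output_shape)  # (1, 16, 16)
```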
3. Training:
Import the necessary libraries:
```python
import argparse
import os
import numpy as np
import math
import itertools
import sys
import torchvision.transforms as transforms
from torchvision.utils import save_image, make_grid
from torch.utils.data import DataLoader
from torch.autograd import Variable
from models import *
from datasets import *
import torch.nn as nn
import torch.nn.functional as F
import torch

os.makedirs("images/training", exist_ok=True)
os.makedirs("saved_models", exist_ok=True)

parser = argparse.ArgumentParser()
parser.add_argument("--epoch", type=int, default=0, help="epoch to start training from")
parser.add_argument("--n_epochs", type=int, default=200, help="number of epochs of training")
parser.add_argument("--dataset_name", type=str, default="img_align_celeba", help="name of the dataset")
parser.add_argument("--batch_size", type=int, default=4, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.9, help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of second order momentum of gradient")
parser.add_argument("--decay_epoch", type=int, default=100, help="epoch from which to start lr decay")
parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
parser.add_argument("--hr_height", type=int, default=256, help="high res. image height")
parser.add_argument("--hr_width", type=int, default=256, help="high res. image width")
parser.add_argument("--channels", type=int, default=3, help="number of image channels")
parser.add_argument("--sample_interval", type=int, default=100, help="interval between saving image samples")
parser.add_argument("--checkpoint_interval", type=int, default=5000, help="batch interval between model checkpoints")
parser.add_argument("--residual_blocks", type=int, default=23, help="number of residual blocks in the generator")
parser.add_argument("--warmup_batches", type=int, default=500, help="number of batches with pixel-wise loss only")
parser.add_argument("--lambda_adv", type=float, default=5e-3, help="adversarial loss weight")
parser.add_argument("--lambda_pixel", type=float, default=1e-2, help="pixel-wise loss weight")
opt = parser.parse_args()
print(opt)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
```
Instantiation:
```python
hr_shape = (opt.hr_height, opt.hr_width)

# Initialize generator and discriminator
generator = GeneratorRRDB(opt.channels, filters=64, num_res_blocks=opt.residual_blocks).to(device)
discriminator = Discriminator(input_shape=(opt.channels, *hr_shape)).to(device)
feature_extractor = FeatureExtractor().to(device)

# Set feature extractor to inference mode
feature_extractor.eval()

# Losses
criterion_GAN = torch.nn.BCEWithLogitsLoss().to(device)
criterion_content = torch.nn.L1Loss().to(device)
criterion_pixel = torch.nn.L1Loss().to(device)

if opt.epoch != 0:
    # Load pretrained models
    generator.load_state_dict(torch.load("saved_models/generator_%d.pth" % opt.epoch))
    discriminator.load_state_dict(torch.load("saved_models/discriminator_%d.pth" % opt.epoch))

# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

Tensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.Tensor

dataloader = DataLoader(
    ImageDataset("../../data/%s" % opt.dataset_name, hr_shape=hr_shape),
    batch_size=opt.batch_size,
    shuffle=True,
    num_workers=opt.n_cpu,
)
```
Training and saving the model:
```python
for epoch in range(opt.epoch, opt.n_epochs):
    for i, imgs in enumerate(dataloader):
        batches_done = epoch * len(dataloader) + i

        # Configure model input
        imgs_lr = Variable(imgs["lr"].type(Tensor))
        imgs_hr = Variable(imgs["hr"].type(Tensor))

        # Adversarial ground truths
        valid = Variable(Tensor(np.ones((imgs_lr.size(0), *discriminator.output_shape))), requires_grad=False)
        fake = Variable(Tensor(np.zeros((imgs_lr.size(0), *discriminator.output_shape))), requires_grad=False)

        # ------------------
        #  Train Generators
        # ------------------

        optimizer_G.zero_grad()

        # Generate a high resolution image from low resolution input
        gen_hr = generator(imgs_lr)

        # Measure pixel-wise loss against ground truth
        loss_pixel = criterion_pixel(gen_hr, imgs_hr)

        if batches_done < opt.warmup_batches:
            # Warm-up (pixel-wise loss only)
            loss_pixel.backward()
            optimizer_G.step()
            print(
                "[Epoch %d/%d] [Batch %d/%d] [G pixel: %f]"
                % (epoch, opt.n_epochs, i, len(dataloader), loss_pixel.item())
            )
            continue

        # Extract validity predictions from discriminator
        pred_real = discriminator(imgs_hr).detach()
        pred_fake = discriminator(gen_hr)

        # Adversarial loss (relativistic average GAN)
        loss_GAN = criterion_GAN(pred_fake - pred_real.mean(0, keepdim=True), valid)

        # Content loss
        gen_features = feature_extractor(gen_hr)
        real_features = feature_extractor(imgs_hr).detach()
        loss_content = criterion_content(gen_features, real_features)

        # Total generator loss
        loss_G = loss_content + opt.lambda_adv * loss_GAN + opt.lambda_pixel * loss_pixel

        loss_G.backward()
        optimizer_G.step()

        # ---------------------
        #  Train Discriminator
        # ---------------------

        optimizer_D.zero_grad()

        pred_real = discriminator(imgs_hr)
        pred_fake = discriminator(gen_hr.detach())

        # Adversarial loss for real and fake images (relativistic average GAN)
        loss_real = criterion_GAN(pred_real - pred_fake.mean(0, keepdim=True), valid)
        loss_fake = criterion_GAN(pred_fake - pred_real.mean(0, keepdim=True), fake)

        # Total loss
        loss_D = (loss_real + loss_fake) / 2

        loss_D.backward()
        optimizer_D.step()

        # --------------
        #  Log Progress
        # --------------

        print(
            "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f, content: %f, adv: %f, pixel: %f]"
            % (
                epoch,
                opt.n_epochs,
                i,
                len(dataloader),
                loss_D.item(),
                loss_G.item(),
                loss_content.item(),
                loss_GAN.item(),
                loss_pixel.item(),
            )
        )

        if batches_done % opt.sample_interval == 0:
            # Save image grid with upsampled inputs and ESRGAN outputs
            imgs_lr = nn.functional.interpolate(imgs_lr, scale_factor=4)
            img_grid = denormalize(torch.cat((imgs_lr, gen_hr), -1))
            save_image(img_grid, "images/training/%d.png" % batches_done, nrow=1, normalize=False)

        if batches_done % opt.checkpoint_interval == 0:
            # Save model checkpoints
            torch.save(generator.state_dict(), "saved_models/generator_%d.pth" % epoch)
            torch.save(discriminator.state_dict(), "saved_models/discriminator_%d.pth" % epoch)
```
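Note that the warm-up phase (the first warmup_batches batches use only the pixel-wise L1 loss) implements the fourth improvement listed above: the generator is first optimized toward a PSNR-oriented objective before the adversarial and perceptual losses are switched on.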