1.论文介绍
1.1简介
论文名称《Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial》
Ledig C , Theis L , Huszar F , et al. Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network[J]. 2016.
该文针对传统超分辨方法中存在的结果过于平滑的问题,提出了结合最新的对抗网络的方法,得到了不错的效果。并且针对此网络结构,构建了自己的感知损失函数。
传统的超分辨率模型采用均方差作为损失函数,虽然可以获得很高的峰值信噪比,但是恢复出来的图像会丢失高频细节。
因此,该文提出了一种自己设计的感知损失函数:
1.2 content loss
像素级 MSE Loss 的计算为:
这个是最经常使用的优化目标。但是,这种方式当取得较高的 PSNR的同时,MSE 优化问题导致缺乏 high-frequency content,这就会使得结果太过于平滑(overly smooth solutions)。因此在在 pre-trained 19-layer VGG network 的 ReLU activation layers 的基础上,定义了 VGG loss
其中,Wi,jWi,j and Hi,jHi,j 表示了 VGG network 当中相应的 feature maps 的维度。
1.3 Adversarial Loss
在所有训练样本上,基于判别器的概率定义 generative loss :
此处,D 是重构图像是 natural HR image 的概率。
1.4 Regulatization Loss
我们进一步的采用 基于 total variation 的正则化项来鼓励 spatially coherent solutions。正则化损失的定义为:
效果展示
3. 模型复现
3.1 整体结构
分为两部分,一部分为产生网络,一部分为鉴别网络,接下来分别对其进行复现。
3.2 Generator Network
首先由结构图可以看出,整体大致分为三部分。
第一部分input输入一张低分辨率的图像,并进行一次卷积和ReLU操作
#第一部分,传入低分辨率图像
LR_input = Input(shape=self.LR_shape)
layer1 = Conv2D(64, kernel_size=3, strides=1, padding='same')(LR_input)
layer1 = Activation('relu')(layer1)
第二部分会经过B个残差块,每个残差块的组成为两次卷积,两次BN层,一次ReLU,最后有一个残差边。
先构建一个残差块函数,方便使用
def residual_block(input, filter):
layer = Conv2D(filter, kernel_size=3, strides=1, padding='same')(input)
# 衰减率暂时设0.8,不考虑性能,优先考虑回环检测需求
layer = BatchNormalization(momentum=0.8)(layer)
layer = Activation('relu')(layer)
layer = Conv2D(filter, kernel_size=3, strides=1, padding='same')(layer)
layer = BatchNormalization(momentum=0.8)(layer)
layer = Add()([layer,input])
return layer
经过b个残差结构块:
#第二部分,经过b个残差结构块
layer2 = residual_block(layer1, 64)
for _ in range(self.b_residual_blocks - 1):
layer2 = residual_block(layer2, 64)
第三部分经过两次Deconv上采样和ReLU,再经过一次卷积,将图像扩大为原来的四倍来提高分辨率。
先构建解码函数:
def deConv(input):
layer = UpSampling2D(size=2)(input)
layer = Conv2D(256, kernel_size=3, strides=1, padding='same')(layer)
layer = Activation('relu')(layer)
return layer
#第三部分,上采样图像放大为原来的4倍
layer3 = Conv2D(64, kernel_size=3, strides=1, padding='same')(layer2)
layer3 = BatchNormalization(momentum=0.8)(layer3)
layer3 = Add()([layer3,layer1])
res1 = deConv(layer3)
res2 = deConv(res1)
genernator_HR = Conv2D(self.channels, kernel_size=9, strides=1, padding='same', activation='tanh')(res2)
# 返回原低分辨率图像以及生成的伪高清图像
return Model(LR_input, genernator_HR)
最终输出高分辨率的图像。
3.3 Discriminator Network
整体看还是分为三个大部分:
第一部分对input输入图像进行两次卷积和Leaky ReLU。
# 公用卷积块
def conv_block(input, filter, strides=1, BN=True):
block = Conv2D(filter, kernel_size=3, strides=strides, padding='same')(input)
#参数慢慢调吧
block = LeakyReLU(alpha=0.2)(block)
if BN:
block = BatchNormalization(momentum=0.8)(block)
return block
# 过滤器和步长参考的论文的结构图上的
# 第一部分
input = Input(shape=self.HR_shape)
layer1 = conv_block(input=input, filter=64, BN=False)
layer2 = conv_block(layer1,64,strides=2)
第二部分是经过6次卷积+Leaky ReLU+BN操作。
#第二部分
layer3 = conv_block(layer2, 128)
layer4 = conv_block(layer3, 128, strides=2)
layer5 = conv_block(layer4, 256)
layer6 = conv_block(layer5, 256, strides=2)
layer7 = conv_block(layer6, 512)
layer8 = conv_block(layer7, 512,strides=2)
第三部分是经过一个Dense(1024)+Leaky ReLU+Dense(1)+Sigmoid,然后对图像进行输出。
#第三部分
layer9 = Dense(1024)(layer8)
layer10 = LeakyReLU(alpha=0.2)(layer9)
res = Dense(1,activation='sigmoid')(layer10)
return Model(input, res)
3.4 VGGNet
利用VGG19的第9层特征层,虽然代码是from keras.applications import VGG19导入的VGG19,但是预训练权重最好还是先去官网下载好,否则第一次下载的太慢了。
贴一个百度云:https://pan.baidu.com/s/1cJabHCqKNdIqNOxHZnXvmg,提取码:vltq
#采用VGG19的第九层特征层
def VGG(self):
vgg = VGG19(weights = "imagenet")
vgg.outputs = [vgg.layers[9].output]
img = Input(shape = self.HR_shape)
img_features = vgg(img)
return Model(img, img_features)
至此,论文模型就算复现完成,后续将完成模型训练,预测。
4.训练过程
类初始化:
高清图像的大小为512*512*3,低分辨率图片为128*128*3,四倍关系。
def __init__(self):
#低分辨率初始化
self.LR_height = 128
self.LR_width = 128
self.channels = 3
self.LR_shape = (self.LR_height, self.LR_width, self.channels)
#高分辨率初始化
self.HR_height = self.LR_height * 4
self.HR_width = self.LR_width * 4
self.HR_shape = (self.HR_height, self.HR_width, self.channels)
#b个残差块
self.b_residual_blocks = 16
#优化器,初始学习率0.002
optimizer = Adam(0.0002, 0.9)
# 数据集
self.datasets_name = "DIV"
self.dataProcess = dataProcess(data_name=self.datasets_name, img=(self.HR_height, self.HR_width))
self.vgg = self.VGG()
self.vgg.trainable = False
#建立生成网络
self.generator = self.generatorNet()
self.generator.summary()
# 建立鉴别网络
patch = int(self.HR_height / 2**4)
self.discriminator_patch = (patch, patch, 1)
self.discriminator = self.discriminatorNet()
self.discriminator.compile(loss='binary_crossentropy', optimizer=optimizer,metrics=['accuracy'])
self.discriminator.summary()
LR_img = Input(shape=self.LR_shape)
gen_HR = self.generator(LR_img)
gen_HR_features = self.vgg(gen_HR)
self.discriminator.trainable = False
validity = self.discriminator(gen_HR)
self.combined = Model(LR_img, [validity, gen_HR_features])
self.combined.compile(loss=['binary_crossentropy', 'mse'], loss_weights=[5e-1,1],optimizer=optimizer)
先对数据集进行预处理:
class dataProcess():
def __init__(self, data_name, img=(128, 128)):
self.data_name = data_name
self.img = img
def process(self, batch_size=1, data_type=False):
dataPath = glob('./datasets/%s/train/*' % (self.data_name))
images = np.random.choice(dataPath, size=batch_size)
img_HR = []
img_LR = []
for i in images:
img = self.imread(i)
height, width = self.img
# 缩小4倍
L_height, L_width = int(height / 4), int(width / 4)
img_H = scipy.misc.imresize(img, self.img)
img_L = scipy.misc.imresize(img, (L_height, L_width))
if not data_type and np.random.random() < 0.5:
img_H = np.fliplr(img_H)
img_L = np.fliplr(img_L)
img_HR.append(img_H)
img_LR.append(img_L)
img_HR = np.array(img_HR) / 127.5 - 1
img_LR = np.array(img_LR) / 127.5 - 1
return img_HR, img_LR
def imread(self, path):
return scipy.misc.imread(path, mode='RGB').astype(np.float)
训练函数:
先将图片传入产生网络生成伪高清图片,然后分别计算伪高清图像和原始高清图像的 loss值,求他们的平均值。然后将两个图像传入VGG网络,利用VGG第9层网络提取图片的特征向量,并计算loss值。
def train(self, epochs, init_epoch=0, batch_size=1, example=50):
begin_time = datetime.datetime.now()
if init_epoch != 0:
self.generator.load_weights("weights/%s/gen_epoch%d.h5" % (self.datasets_name, init_epoch), skip_mismatch=True)
self.discriminator.load_weights("weights/%s/dis_epoch%d.h5" % (self.dataset_name, init_epoch), skip_mismatch=True)
for epoch in range(init_epoch, epochs):
self.learining([self.combined, self.discriminator], epoch)
HR_img, LR_img = self.dataProcess.process(batch_size)
gen_HR = self.generator.predict(LR_img)
valid = np.ones((batch_size,) + self.discriminator_patch)
gen = np.zeros((batch_size,) + self.discriminator_patch)
origin_loss = self.discriminator.train_on_batch(HR_img, valid)
gen_loss = self.discriminator.train_on_batch(gen_HR, gen)
d_loss = 0.5 * np.add(origin_loss, gen_loss)
HR_img, LR_img = self.dataProcess.process(batch_size)
valid = np.ones((batch_size,) + self.discriminator_patch)
img_features = self.vgg.predict(HR_img)
g_loss = self.combined.train_on_batch(LR_img, [valid, img_features])
print(d_loss, g_loss)
end_time = datetime.datetime.now()
time = end_time - begin_time
print("[Epoch %d/%d] [D loss: %f, acc: %3d%%] [G loss: %05f, feature loss: %05f] time: %s " \
% (epoch,
epochs,
d_loss[0],
100 * d_loss[1],
g_loss[1],
g_loss[2],
time))
if epoch % example == 0:
self.restore(epoch)
if epoch % 500 == 0 and epoch != init_epoch:
# 500代保存一次
os.makedirs('weights/%s' % self.datasets_name, exist_ok=True)
self.generator.save_weights("weights/%s/gen_epoch%d.h5" % (self.datasets_name, epoch))
self.dicriminator.save_weights("weights/%s/dis_epoch%d.h5" % (self.datasets_name, epoch))
5. 效果展示
低分辨率图为128*128,伪高清图为512*512
低分辨率图:
生成的伪高清图:
项目地址:vanbou/TensorFlow_SRGAN (github.com)