参看原代码GitHub地址
srgan.py
"""
Super-resolution of CelebA using Generative Adversarial Networks.
The dataset can be downloaded from: https://www.dropbox.com/sh/8oqt9vytwxb3s4r/AADIKlz8PR9zr6Y20qbkunrba/Img/img_align_celeba.zip?dl=0
Instructions for running the script:
1. Download the dataset from the provided link.
2. Save the folder 'img_align_celeba' to 'datasets/' (or point SRGAN's dataset path at it).
3. Run the script using the command 'python srgan.py'.
"""
# Environment note (pip log): "Successfully uninstalled tensorflow-2.2.0"
# References (参考):
# https://blog.csdn.net/weixin_44791964/article/details/103825427
# https://blog.csdn.net/weixin_41485242/article/details/105946150
from __future__ import print_function, division
from glob import glob
import imageio
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import *
from tensorflow.keras.layers import LeakyReLU, UpSampling2D, Conv2D
from tensorflow.keras.applications import VGG19
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
import datetime
import matplotlib
import matplotlib.pyplot as plt
from data_loader import DataLoader
from load import Loader
import numpy as np
import os
import cv2
import scipy
from PIL import Image
class SRGAN():
    """SRGAN for 4x super-resolution of CelebA face images.

    Builds and compiles four Keras models:
      * self.vgg           - frozen VGG19 feature extractor (HR image -> layer-9 features).
      * self.generator     - low-res image (lr_shape) -> high-res image (hr_shape), tanh output.
      * self.discriminator - high-res image -> PatchGAN validity map of shape self.disc_patch.
      * self.combined      - generator followed by the frozen discriminator and VGG; used to
                             train the generator.
    Images are handled in the [-1, 1] range (pixel / 127.5 - 1) throughout.
    """

    def __init__(self, dataset_name='D:/Easiest-SRGAN-demo-master/datasets/img_align_celeba'):
        """Build and compile all models and create the data loader.

        Args:
            dataset_name: directory of training images passed to DataLoader.
                Defaults to the author's original (Windows) location.
        """
        # Input shapes: 3 colour channels, 4x upscaling from 56x56 to 224x224.
        self.channels = 3
        self.lr_height = 56  # Low resolution height
        self.lr_width = 56  # Low resolution width
        self.lr_shape = (self.lr_height, self.lr_width, self.channels)
        self.hr_height = self.lr_height * 4  # High resolution height
        self.hr_width = self.lr_width * 4  # High resolution width
        self.hr_shape = (self.hr_height, self.hr_width, self.channels)

        # Number of residual blocks in the generator.
        self.n_residual_blocks = 16

        # Give each compiled model its own Adam instance. Recent tf.keras versions
        # refuse to update two different variable sets with one optimizer, and
        # sharing would also share Adam's step counter between D and G updates.
        def make_optimizer():
            return Adam(0.0002, 0.5)

        # Pre-trained VGG19 used to extract features from real and generated
        # high-resolution images; the generator minimises the MSE between them.
        self.vgg = self.build_vgg()
        self.vgg.trainable = False
        self.vgg.summary()
        self.vgg.compile(loss='mse',
                         optimizer=make_optimizer(),
                         metrics=['accuracy'])

        # Configure data loader.
        self.dataset_name = dataset_name
        self.data_loader = DataLoader(dataset_name=self.dataset_name,
                                      img_res=(self.hr_height, self.hr_width))

        # Output shape of D (PatchGAN). The discriminator has four stride-2
        # convolutions, so each spatial dimension shrinks by 2 ** 4.
        patch = int(self.hr_height / 2 ** 4)
        self.disc_patch = (patch, patch, 1)

        # Number of filters in the first layer of G and D.
        self.gf = 64
        self.df = 64

        # Build and compile the discriminator.
        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss='mse',
                                   optimizer=make_optimizer(),
                                   metrics=['accuracy'])
        self.discriminator.summary()

        # Build the generator.
        self.generator = self.build_generator()

        # High-res and low-res image inputs.
        img_hr = Input(shape=self.hr_shape)
        img_lr = Input(shape=self.lr_shape)

        # Generate a fake high-res image from the low-res one and extract its VGG features.
        fake_hr = self.generator(img_lr)
        fake_features = self.vgg(fake_hr)

        # In the combined model only the generator is trained. This relies on the
        # usual Keras GAN pattern where `trainable` is applied at compile time, so
        # the discriminator compiled above still learns through its own train_on_batch.
        self.discriminator.trainable = False
        validity = self.discriminator(fake_hr)

        # Inputs: [low-res, high-res]; outputs: [validity map, VGG features of the fake].
        # img_hr feeds no output; it is kept so callers pass [imgs_lr, imgs_hr].
        self.combined = Model([img_lr, img_hr], [validity, fake_features])
        self.combined.summary()
        self.combined.compile(loss=['binary_crossentropy', 'mse'],
                              loss_weights=[1e-3, 1],
                              optimizer=make_optimizer())

    def build_vgg(self):
        """Return a model mapping an HR image to ImageNet-VGG19 activations at layers[9].

        In the Keras VGG19 layer ordering, layers[9] is block3_conv3 (the third
        convolution of block 3, not the last one).
        """
        print('start loading trianed weights of vgg...')
        # Only the convolutional part is needed. include_top=False also avoids
        # loading the classifier weights and lets hr_shape differ from 224x224.
        vgg = VGG19(weights="imagenet", include_top=False, input_shape=self.hr_shape)
        print('loading completes')
        # Build a new graph ending at layer 9. Reassigning `vgg.outputs` does not
        # rewire a tf.keras functional model, so the old hack silently kept the
        # full network's output.
        # NOTE(review): inputs are in [-1, 1], not VGG's ImageNet preprocessing;
        # real and fake images are treated identically, so the loss stays consistent.
        return Model(inputs=vgg.input, outputs=vgg.layers[9].output)

    def build_generator(self):
        """Build G: conv + ReLU, residual blocks, conv + BN skip, two 2x upsamples, 9x9 tanh conv."""

        def residual_block(layer_input, filters):
            """Conv-ReLU-BN-Conv-BN with an identity skip connection."""
            d = Conv2D(filters, kernel_size=3, strides=1, padding='same')(layer_input)
            d = Activation('relu')(d)
            d = BatchNormalization(momentum=0.8)(d)
            d = Conv2D(filters, kernel_size=3, strides=1, padding='same')(d)
            d = BatchNormalization(momentum=0.8)(d)
            d = Add()([d, layer_input])
            return d

        def deconv2d(layer_input):
            """Upsample 2x with UpSampling2D, then a 256-filter conv + ReLU (not a transposed conv)."""
            u = UpSampling2D(size=2)(layer_input)
            u = Conv2D(256, kernel_size=3, strides=1, padding='same')(u)
            u = Activation('relu')(u)
            return u

        # Low resolution image input.
        img_lr = Input(shape=self.lr_shape)

        # Part 1: a 9x9 convolution + ReLU on the low-res input.
        c1 = Conv2D(64, kernel_size=9, strides=1, padding='same')(img_lr)
        c1 = Activation('relu')(c1)

        # Part 2: a chain of n_residual_blocks residual blocks.
        r = residual_block(c1, self.gf)
        for _ in range(self.n_residual_blocks - 1):
            r = residual_block(r, self.gf)

        # Post-residual conv + BN, with a long skip connection back to c1.
        c2 = Conv2D(64, kernel_size=3, strides=1, padding='same')(r)
        c2 = BatchNormalization(momentum=0.8)(c2)
        c2 = Add()([c2, c1])

        # Part 3: two 2x upsampling stages give the 4x resolution increase.
        u1 = deconv2d(c2)
        u2 = deconv2d(u1)

        # tanh keeps the output in [-1, 1], matching the training data range.
        gen_hr = Conv2D(self.channels, kernel_size=9, strides=1, padding='same', activation='tanh')(u2)
        return Model(img_lr, gen_hr)

    def build_discriminator(self):
        """Build D: eight conv blocks (four stride-2), two Dense layers, sigmoid PatchGAN map."""

        def d_block(layer_input, filters, strides=1, bn=True):
            """Conv2D -> LeakyReLU(0.2) -> optional BatchNormalization."""
            d = Conv2D(filters, kernel_size=3, strides=strides, padding='same')(layer_input)
            d = LeakyReLU(alpha=0.2)(d)
            if bn:
                d = BatchNormalization(momentum=0.8)(d)
            return d

        d0 = Input(shape=self.hr_shape)
        d1 = d_block(d0, self.df, bn=False)
        d2 = d_block(d1, self.df, strides=2)
        d3 = d_block(d2, self.df * 2)
        d4 = d_block(d3, self.df * 2, strides=2)
        d5 = d_block(d4, self.df * 4)
        d6 = d_block(d5, self.df * 4, strides=2)
        d7 = d_block(d6, self.df * 8)
        d8 = d_block(d7, self.df * 8, strides=2)
        # Dense acts on the channel axis of the 4-D tensor, so the output keeps its
        # spatial grid: one validity score per patch (see self.disc_patch).
        d9 = Dense(self.df * 16)(d8)
        d10 = LeakyReLU(alpha=0.2)(d9)
        validity = Dense(1, activation='sigmoid')(d10)
        return Model(d0, validity)

    def train(self, epochs, batch_size=1, sample_interval=5, save_interval=500):
        """Alternate discriminator and generator updates.

        Each "epoch" is one training step: one batch for D, then one batch for G.

        Args:
            epochs: number of training steps.
            batch_size: number of images requested from the data loader per step.
            sample_interval: every this many steps, save generated vs. original HR
                images to images/<epoch>.png. Must be >= 1.
            save_interval: every this many steps (after step 0), save the generator
                weights to ./saved_model/<epoch>.h5. Must be >= 1.

        Raises:
            ValueError: if sample_interval or save_interval is < 1.
        """
        if sample_interval < 1 or save_interval < 1:
            raise ValueError('sample_interval and save_interval must be >= 1')
        os.makedirs('./saved_model/', exist_ok=True)
        start_time = datetime.datetime.now()

        for epoch in range(epochs):
            # ---------------------- Train discriminator ----------------------
            d_imgs_hr, d_imgs_lr = self.data_loader.load_data(batch_size)
            fake_hr = self.generator.predict(d_imgs_lr)
            # Size labels from the batch actually returned so they always match it.
            n_d = d_imgs_hr.shape[0]
            valid = np.ones((n_d,) + self.disc_patch)
            fake = np.zeros((n_d,) + self.disc_patch)
            # Real images are labelled 1, generated ones 0.
            d_loss_real = self.discriminator.train_on_batch(d_imgs_hr, valid)
            d_loss_fake = self.discriminator.train_on_batch(fake_hr, fake)
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            # ------------------------ Train generator ------------------------
            imgs_hr, imgs_lr = self.data_loader.load_data(batch_size)
            # The generator wants D to label its outputs as real.
            valid = np.ones((imgs_hr.shape[0],) + self.disc_patch)
            # Ground-truth VGG features of the real HR images.
            image_features = self.vgg.predict(imgs_hr)
            g_loss = self.combined.train_on_batch([imgs_lr, imgs_hr], [valid, image_features])

            elapsed_time = datetime.datetime.now() - start_time
            print("%d time: %s [D loss: %f] [G loss: %f]"
                  % (epoch, elapsed_time, np.ravel(d_loss)[0], np.ravel(g_loss)[0]))

            if epoch % sample_interval == 0:
                # fake_hr was generated from d_imgs_lr, so d_imgs_hr is its true pair.
                self.p_fake(epoch, fake_hr, d_imgs_hr)
            if epoch % save_interval == 0 and epoch > 0:
                self.generator.save_weights('./saved_model/' + str(epoch) + '.h5')

    @staticmethod
    def _denormalize(img):
        """Map an image from the [-1, 1] model range to [0, 1] for display/saving."""
        # Clip so values slightly outside [-1, 1] never exceed the [0, 1] range
        # matplotlib accepts for float RGB images.
        return np.clip(0.5 * np.asarray(img) + 0.5, 0.0, 1.0)

    def _save_grid(self, filename, columns, titles):
        """Save a rows x len(columns) titled image grid to `filename`.

        Each entry of `columns` is a batch of images in [-1, 1]; row k shows image k
        of every batch.
        """
        rows = columns[0].shape[0]
        # squeeze=False keeps `axs` 2-D even for a single row, so axs[row, col] works.
        fig, axs = plt.subplots(rows, len(columns), squeeze=False)
        for row in range(rows):
            for col, images in enumerate(columns):
                axs[row, col].imshow(self._denormalize(images[row]))
                axs[row, col].set_title(titles[col])
                axs[row, col].axis('off')
        fig.savefig(filename)
        plt.close(fig)

    def p_fake(self, epoch, fake_hr, imgs_hr):
        """Save generated HR images next to their originals to images/<epoch>.png."""
        os.makedirs('images/', exist_ok=True)
        titles = ['Generated epoch: ' + str(epoch), 'Original']
        self._save_grid("images/%d.png" % epoch, [fake_hr, imgs_hr], titles)

    def test_images(self, batch_size=1):
        """Super-resolve random dataset images using ./saved_model/2000.h5.

        Each generated image is written at its native resolution (no figure
        margins) to ./test<k>.png for k = 1..N, so a single image goes to ./test1.png.
        """
        imgs_hr, imgs_lr = self.data_loader.load_data(batch_size, is_pred=False)
        os.makedirs('saved_model/', exist_ok=True)
        self.generator.load_weights('./saved_model/' + str(2000) + '.h5')
        fake_hr = self.generator.predict(imgs_lr)
        # plt.imsave writes the raw pixel array: no white borders, no resampling.
        for k, image in enumerate(fake_hr, start=1):
            plt.imsave("./test%d.png" % k, self._denormalize(image))

    def sample_images_new(self, epoch):
        """Save generated / original / low-res images of the loader's prediction set.

        Images come from `load_data(..., is_pred=True)` and the grid is written to
        images/<epoch>.png.
        """
        os.makedirs('images/', exist_ok=True)
        imgs_hr, imgs_lr = self.data_loader.load_data(batch_size=1, is_testing=True, is_pred=True)
        fake_hr = self.generator.predict(imgs_lr)
        titles = ['Generated epoch: ' + str(epoch), 'Original', 'Low']
        self._save_grid("images/%d.png" % (epoch), [fake_hr, imgs_hr, imgs_lr], titles)

    def _read_pair(self, img_path):
        """Read one image file and return (hr, lr) float arrays scaled to [-1, 1]."""
        img = Image.fromarray(np.uint8(imageio.imread(img_path, pilmode='RGB')))
        # PIL's resize() takes (width, height), not (height, width).
        img_hr = np.array(img.resize((self.hr_width, self.hr_height)))
        img_lr = np.array(img.resize((self.lr_width, self.lr_height)))
        return img_hr / 127.5 - 1., img_lr / 127.5 - 1.

    def read_all(self, size):
        """Super-resolve `size` random images from ./datasets/img_align_celeba2.

        Uses the weights in ./compare/22000.h5. For k in range(size), writes the
        resized HR reference to ./tryimghr/k.png and the generator output for the
        same image to ./tryoutputdata/k.png, both at native resolution.

        Raises:
            ValueError: if the dataset directory is empty, or size exceeds its file count.
        """
        os.makedirs('compare/', exist_ok=True)
        self.generator.load_weights('./compare/' + str(22000) + '.h5')
        dataset_name = './datasets/img_align_celeba2'
        path = glob('%s/*' % (dataset_name))
        if not path:
            raise ValueError('no images found in %s' % dataset_name)
        batch_images = np.random.choice(path, size=size, replace=False)

        pairs = [self._read_pair(img_path) for img_path in batch_images]
        if not pairs:
            return
        imgs_hr = np.array([hr for hr, _ in pairs])
        imgs_lr = np.array([lr for _, lr in pairs])
        # One batched prediction instead of one predict call per image.
        fake_hr = self.generator.predict(imgs_lr)

        os.makedirs('./tryimghr', exist_ok=True)
        os.makedirs('./tryoutputdata', exist_ok=True)
        # The same index k is used for the HR reference and its generated output.
        for k, (hr, fake) in enumerate(zip(imgs_hr, fake_hr)):
            plt.imsave("./tryimghr/%d.png" % k, self._denormalize(hr))
            plt.imsave("./tryoutputdata/%d.png" % k, self._denormalize(fake))
def _main():
    """Script entry point: build the SRGAN and super-resolve 50 random images."""
    model = SRGAN()
    # Alternative entry points:
    #   model.train(epochs=1, batch_size=1, sample_interval=1)
    #   (the author also noted the values 3000 / 10 / 2 for these arguments)
    #   model.test_images()
    model.read_all(50)


if __name__ == '__main__':
    _main()
data_loader.py
import scipy
from glob import glob
import numpy as np
import matplotlib.pyplot as plt
import os
import imageio
from PIL import Image
import PIL
class DataLoader():
    """Load (high-res, low-res) image pairs as float arrays scaled to [-1, 1]."""

    def __init__(self, dataset_name, img_res=(128, 128)):
        """
        Args:
            dataset_name: directory whose files (glob '<dataset_name>/*') form the
                training image pool.
            img_res: (height, width) of the high-resolution images; low-res images
                are a quarter of that size in each dimension.
        """
        self.dataset_name = dataset_name
        self.img_res = img_res

    def load_data(self, batch_size=1, is_testing=False, is_pred=False):
        """Return (imgs_hr, imgs_lr) arrays with pixel values mapped to [-1, 1].

        Args:
            batch_size: number of images drawn at random, without replacement,
                from the dataset directory. Ignored when `is_pred` is True.
            is_testing: when False, each pair is mirrored left-right with
                probability 0.5 (training augmentation).
            is_pred: load every file in 'test_images/' instead of sampling the dataset.

        Returns:
            imgs_hr of shape (N, h, w, 3) and imgs_lr of shape (N, h/4, w/4, 3),
            where (h, w) = img_res.

        Raises:
            ValueError: if the dataset directory has no files, or batch_size exceeds
                the number of files.
        """
        if is_pred:
            # Sorted so the prediction set is always shown in the same order.
            batch_images = ['test_images/' + x for x in sorted(os.listdir('test_images/'))]
        else:
            path = glob('%s/*' % (self.dataset_name))
            if not path:
                raise ValueError('no images found in %s' % self.dataset_name)
            # Honour the requested batch size (previously hard-coded to 1).
            batch_images = np.random.choice(path, size=batch_size, replace=False)

        h, w = self.img_res
        low_h, low_w = int(h / 4), int(w / 4)
        imgs_hr = []
        imgs_lr = []
        for img_path in batch_images:
            img = Image.fromarray(np.uint8(self.imread(img_path)))
            # PIL's resize() takes (width, height), not (height, width).
            # (scipy.misc.imresize no longer exists in SciPy, so PIL is used.)
            img_hr = np.array(img.resize((w, h)))
            img_lr = np.array(img.resize((low_w, low_h)))

            # If training => do random flip
            if not is_testing and np.random.random() < 0.5:
                img_hr = np.fliplr(img_hr)
                img_lr = np.fliplr(img_lr)

            imgs_hr.append(img_hr)
            imgs_lr.append(img_lr)

        imgs_hr = np.array(imgs_hr) / 127.5 - 1.
        imgs_lr = np.array(imgs_lr) / 127.5 - 1.
        return imgs_hr, imgs_lr

    def imread(self, path):
        """Read `path` as an RGB image and return it as a float64 array."""
        # np.float was removed in NumPy 1.24; np.float64 is the type it aliased.
        return imageio.imread(path, pilmode='RGB').astype(np.float64)