SRGANdemo代码注释学习

本文介绍了一种使用生成对抗网络(GAN)实现超分辨率图像增强的方法。作者详细展示了如何构建SRGAN(超分辨率生成对抗网络),包括生成器和判别器的结构。SRGAN通过VGG19模型提取图像特征,以提高生成的高分辨率图像的真实感。代码中包含了训练和测试过程,以及训练样本的加载和预处理。
摘要由CSDN通过智能技术生成

文章目录

参看原代码GitHub地址

srgan.py

"""
Super-resolution of CelebA using Generative Adversarial Networks.

The dataset can be downloaded from: https://www.dropbox.com/sh/8oqt9vytwxb3s4r/AADIKlz8PR9zr6Y20qbkunrba/Img/img_align_celeba.zip?dl=0

Instructions for running the script:
1. Download the dataset from the provided link
2. Save the folder 'img_align_celeba' to 'datasets/'
3. Run the script using the command 'python srgan.py'
"""
#  Successfully uninstalled tensorflow-2.2.0
#参考
#https://blog.csdn.net/weixin_44791964/article/details/103825427
#https://blog.csdn.net/weixin_41485242/article/details/105946150


from __future__ import print_function, division

from glob import glob

import imageio
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import *
from tensorflow.keras.layers import LeakyReLU, UpSampling2D, Conv2D
from tensorflow.keras.applications import VGG19
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
import datetime

import matplotlib
import matplotlib.pyplot as plt

from data_loader import DataLoader
from load import Loader
import numpy as np
import os

import cv2
import scipy
from PIL import Image


class SRGAN():
    def __init__(self):
        """Create the data loader and build/compile every network used by SRGAN.

        Sets up, in order: the image shapes, the shared optimizer, the VGG
        feature extractor, the data loader, the discriminator, the generator,
        and finally the combined (generator -> discriminator + VGG) model that
        is used to train the generator.
        """
        # Input shape
        # Number of colour channels per image.
        # Background on channels: https://www.cnblogs.com/Terrypython/p/10310531.html
        self.channels = 3
        # Shape of the low-resolution images fed to the generator.
        self.lr_height = 56  # Low resolution height
        self.lr_width = 56  # Low resolution width
        self.lr_shape = (self.lr_height, self.lr_width, self.channels)
        # Shape of the high-resolution images: 4x the low-resolution size.
        self.hr_height = self.lr_height * 4  # High resolution height
        self.hr_width = self.lr_width * 4  # High resolution width
        self.hr_shape = (self.hr_height, self.hr_width, self.channels)

        # Number of residual blocks in the generator
        self.n_residual_blocks = 16

        # Optimizer instance shared by every model compiled below.
        optimizer = Adam(0.0002, 0.5)

        # We use a pre-trained VGG19 model to extract image features from the high resolution
        # and the generated high resolution images and minimize the mse between them
        # The VGG model is only a feature extractor, so it is marked non-trainable.
        self.vgg = self.build_vgg()
        self.vgg.trainable = False
        self.vgg.summary()
        self.vgg.compile(loss='mse',
                         optimizer=optimizer,
                         metrics=['accuracy'])

        # Configure data loader
        # NOTE(review): the dataset location is a hard-coded absolute Windows path.
        self.dataset_name = 'D:/Easiest-SRGAN-demo-master/datasets/img_align_celeba'
        self.data_loader = DataLoader(dataset_name=self.dataset_name,
                                      img_res=(self.hr_height, self.hr_width))

        # Calculate output shape of D (PatchGAN)
        # build_discriminator applies four stride-2 convolutions, so the
        # discriminator output is hr_height / 2**4 on each spatial side.
        patch = int(self.hr_height / 2 ** 4)
        self.disc_patch = (patch, patch, 1)

        # Number of filters in the first layer of G and D
        # (gf is the filter count of the generator's residual blocks,
        # df is the base filter count of the discriminator blocks).
        self.gf = 64
        self.df = 64

        # Build and compile the discriminator
        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss='mse',
                                   optimizer=optimizer,
                                   metrics=['accuracy'])
        self.discriminator.summary()

        # Build the generator
        self.generator = self.build_generator()

        # High res. and low res. images
        img_hr = Input(shape=self.hr_shape)
        img_lr = Input(shape=self.lr_shape)

        # Generate high res. version from low res.
        # (low-resolution input -> generated "fake" high-resolution image)
        fake_hr = self.generator(img_lr)

        # Extract image features of the generated img
        fake_features = self.vgg(fake_hr)

        # Chain the generator and the discriminator. The discriminator was
        # compiled above; it is frozen here so that training the combined
        # model only updates the generator.
        # For the combined model we will only train the generator
        self.discriminator.trainable = False

        # Discriminator determines validity of generated high res. images
        validity = self.discriminator(fake_hr)

        # Combined model: inputs are the low- and high-resolution images,
        # outputs are the discriminator validity and the VGG features of the
        # generated image. Both outputs are computed from img_lr only; img_hr
        # is declared as an input but no output depends on it.
        self.combined = Model([img_lr, img_hr], [validity, fake_features])
        self.combined.summary()
        # Adversarial loss (weight 1e-3) plus feature-space MSE (weight 1).
        self.combined.compile(loss=['binary_crossentropy', 'mse'],
                              loss_weights=[1e-3, 1],
                              optimizer=optimizer)

	# 建立VGG模型,只使用第9层的特征
    def build_vgg(self):
        """
        Build the VGG19-based feature extractor used for the content loss.

        Loads VGG19 with ImageNet weights and returns a model that maps an
        image of shape ``self.hr_shape`` to the activations of
        ``vgg.layers[9]``.

        Returns:
            A Keras ``Model`` taking a high-resolution image and producing the
            intermediate VGG19 feature map.
        """
        print('start loading trianed weights of vgg...')
        vgg = VGG19(weights="imagenet")
        print('loading completes')

        # Assigning to ``vgg.outputs`` does not rewire an already-built
        # functional model: calling ``vgg(img)`` would still return the final
        # classification output. Build an explicit sub-model that ends at the
        # wanted layer instead.
        # NOTE(review): index 9 is kept from the original code. In the stock
        # Keras VGG19 layout that appears to be block3_conv3 rather than the
        # last conv of block 3 - confirm which layer is intended.
        # See architecture at: https://github.com/keras-team/keras/blob/master/keras/applications/vgg19.py
        feature_extractor = Model(inputs=vgg.input, outputs=vgg.layers[9].output)

        img = Input(shape=self.hr_shape)

        # Extract image features
        img_features = feature_extractor(img)

        return Model(img, img_features)

		
#生成网络:1个卷积 + 16个残差块 +2个反卷积 + 1个卷积
    def build_generator(self):
        """Assemble the generator network.

        Layout: one 9x9 convolution, a chain of residual blocks, one 3x3
        convolution with a long skip connection, two 2x upsampling stages and
        a final 9x9 convolution with tanh activation.

        Returns:
            A Keras ``Model`` mapping a ``self.lr_shape`` image to an image
            with ``self.channels`` channels at 4x the spatial size.
        """
        conv_opts = dict(strides=1, padding='same')

        def res_unit(tensor, n_filters):
            """conv -> ReLU -> BN -> conv -> BN, then add the block input back."""
            branch = Conv2D(n_filters, kernel_size=3, **conv_opts)(tensor)
            branch = Activation('relu')(branch)
            branch = BatchNormalization(momentum=0.8)(branch)
            branch = Conv2D(n_filters, kernel_size=3, **conv_opts)(branch)
            branch = BatchNormalization(momentum=0.8)(branch)
            return Add()([branch, tensor])

        def upsample_stage(tensor):
            """Double the spatial size, then a 256-filter 3x3 conv and ReLU."""
            grown = UpSampling2D(size=2)(tensor)
            grown = Conv2D(256, kernel_size=3, **conv_opts)(grown)
            return Activation('relu')(grown)

        # Low-resolution input image.
        low_res_in = Input(shape=self.lr_shape)

        # Stage 1: wide-kernel convolution followed by ReLU.
        stem = Activation('relu')(
            Conv2D(64, kernel_size=9, **conv_opts)(low_res_in)
        )

        # Stage 2: one residual block, then (n_residual_blocks - 1) more.
        trunk = res_unit(stem, self.gf)
        remaining = self.n_residual_blocks - 1
        while remaining > 0:
            trunk = res_unit(trunk, self.gf)
            remaining -= 1

        # Stage 3: conv + BN, merged with the stem output (long skip).
        merged = Conv2D(64, kernel_size=3, **conv_opts)(trunk)
        merged = BatchNormalization(momentum=0.8)(merged)
        merged = Add()([merged, stem])

        # Stage 4: two 2x upsampling stages -> 4x the input resolution.
        enlarged = upsample_stage(upsample_stage(merged))

        # Stage 5: project back to image channels, tanh output.
        high_res_out = Conv2D(self.channels, kernel_size=9, activation='tanh', **conv_opts)(enlarged)

        return Model(low_res_in, high_res_out)

	# 辨别网络:8个卷积
    def build_discriminator(self):
        """Assemble the discriminator network.

        Eight 3x3 convolution stages (LeakyReLU, batch norm on all but the
        first) alternating stride 1 and stride 2, followed by two Dense
        layers; the last one has a single sigmoid unit.

        Returns:
            A Keras ``Model`` mapping a ``self.hr_shape`` image to the
            validity output.
        """
        # (filters, stride) for each convolution stage, in order.
        stage_plan = [
            (self.df, 1),
            (self.df, 2),
            (self.df * 2, 1),
            (self.df * 2, 2),
            (self.df * 4, 1),
            (self.df * 4, 2),
            (self.df * 8, 1),
            (self.df * 8, 2),
        ]

        image_in = Input(shape=self.hr_shape)

        features = image_in
        for position, (n_filters, stride) in enumerate(stage_plan):
            features = Conv2D(n_filters, kernel_size=3, strides=stride, padding='same')(features)
            features = LeakyReLU(alpha=0.2)(features)
            # Only the very first stage skips batch normalization.
            if position != 0:
                features = BatchNormalization(momentum=0.8)(features)

        # Dense head applied on top of the last convolution stage.
        head = Dense(self.df * 16)(features)
        head = LeakyReLU(alpha=0.2)(head)
        validity = Dense(1, activation='sigmoid')(head)

        return Model(image_in, validity)
		
	#https://www.sciencedirect.com/topics/computer-science/sample-interval#:~:text=Consider%20the%20purpose%20of%20the%20performance%20counter%20data,or%2010%20min%20sample%20interval%20is%20more%20appropriate.
	#sample_interval? 采样间隔,取样间隔,取样时间间隔;
	#https://www.cnblogs.com/XDU-Lakers/p/10607358.html
	#超参数?-epochs batch_size sample_interval
    def train(self, epochs, batch_size=1, sample_interval=5):
        start_time = datetime.datetime.now()

        # ---------------------------------------------oh
        # 30-10
        # 100-50
        for epoch in range(epochs):
            if epoch > 3:
                sample_interval = 1
            if epoch > 1:
                sample_interval = 5

            # ----------------------
            #  Train Discriminator
			#	训练判别网络
            # ----------------------

            # Sample images and their conditioning counterparts
            imgs_hr, imgs_lr = self.data_loader.load_data(batch_size)

            # From low res. image generate high res. version
            fake_hr = self.generator.predict(imgs_lr)



			
			#标注真的HR图像为真;对于数据图像打label
            valid = np.ones((batch_size,) + self.disc_patch)
            fake = np.zeros((batch_size,) + self.disc_patch)
			#真1;假0;在图像输入判别器之前还有打label的过程
			#https://blog.csdn.net/weixin_43624538/article/details/99675296

            # Train the discriminators (original images = real / generated = Fake)
            d_loss_real = self.discriminator.train_on_batch(imgs_hr, valid) #real-valid,更真实
            d_loss_fake = self.discriminator.train_on_batch(fake_hr, fake)	#fake对fake,明辨
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake) #两者结合,构成loss

			#train_on_batch
			#https://blog.csdn.net/weixin_42886817/article/details/99855287
			
            # ------------------
            #  Train Generator
			#	训练生成网络
            # ------------------

            # Sample images and their conditioning counterparts
			# 样本图片和它们调整后的图片
            imgs_hr, imgs_lr = self.data_loader.load_data(batch_size)

            # The generators want the discriminators to label the generated images as real
			#标注真的HR图像为真
            valid = np.ones((batch_size,) + self.disc_patch)
			
			
            # Extract ground truth image features using pre-trained VGG19 model
			#使用预先训练的 VGG19 模型提取基准真实图像的特征
            image_features = self.vgg.predict(imgs_hr)

            # Train the generators
            g_loss = self.combined.train_on_batch([imgs_lr, imgs_hr], [valid, image_features])
			#loss-[低分辨率,高分辨率],[真实度,高分辨率图片特征]

            elapsed_time = datetime.datetime.now() - start_time
            # Plot the progress
            print("%d time: %s" % (epoch, elapsed_time))
#-------------------------------------------------------------------------oh
			# 
            # If at save interval => save generated image samples
            # 如果到保存的区间,保存训练得到的图片样本
			#
            if epoch % sample_interval == 0:
#--------------------------------------------------------for test
                #self.sample_images_new(epoch)
                print(fake_hr)
                self.p_fake(epoch, fake_hr, imgs_hr)
				#self.generator.save_weights('./saved_model/' + str(epoch) + '.h5')
            if epoch % 500 == 0 and epoch > 1:
                self.generator.save_weights('./saved_model/' + str(epoch) + '.h5')

    def test_images(self, batch_size=1):
        """Run the generator on a loaded batch and save the result to ./test1.png.

        Loads the generator weights from ``./saved_model/2000.h5``, predicts
        high-resolution images for a batch from the data loader, draws each
        prediction onto the same axes with ticks and margins removed, and
        saves the figure.

        Args:
            batch_size: Number of images requested from the data loader.
        """
        imgs_hr, imgs_lr = self.data_loader.load_data(batch_size, is_pred=False)
        os.makedirs('saved_model/', exist_ok=True)
        weights_file = './saved_model/' + str(2000) + '.h5'
        self.generator.load_weights(weights_file)

        fake_hr = self.generator.predict(imgs_lr)
        print(fake_hr)
        r = imgs_hr.shape[0]
        print(r)

        # Map every batch from [-1, 1] back to [0, 1].
        imgs_lr, fake_hr, imgs_hr = (
            0.5 * batch + 0.5 for batch in (imgs_lr, fake_hr, imgs_hr)
        )

        # One figure with a single axes; each image is drawn onto it in turn.
        plt.subplots(1)
        row = 0
        while row < r:
            plt.imshow(fake_hr[row])
            plt.axis('off')

            # Drop tick locators and all figure margins to trim the border.
            # Reference: https://www.zhangshengrong.com/p/9Oab7V3GNd/
            current_axes = plt.gca()
            current_axes.xaxis.set_major_locator(plt.NullLocator())
            current_axes.yaxis.set_major_locator(plt.NullLocator())
            plt.subplots_adjust(top=1, bottom=0, left=0, right=1, hspace=0, wspace=0)
            plt.margins(0, 0)
            row += 1

        plt.savefig("./test1.png", bbox_inches='tight', pad_inches=0.0)
        plt.close()

    '''
        for row in range(r):
            for col, image in enumerate([imgs_lr, fake_hr]):
                axs[row, col].imshow(image[row])
				#坐标轴它可以放在图像的任意位置,在一幅图内绘制小图
				#https://www.runoob.com/w3cnote/matplotlib-tutorial.html
                axs[row, col].set_title(titles[col])
                axs[row, col].axis('off')
        fig.savefig("./result.png")
        plt.close()
    '''
            #for image in enumerate([fake_hr]):


#------------------------------		

    def sample_images_new(self, epoch):
        """Save a comparison grid of generated, original and low-res images.

        Loads the prediction images through the data loader
        (``is_pred=True``), runs the generator on them and writes one row per
        image with three columns (generated, original, low resolution) to
        ``images/<epoch>.png``.

        Args:
            epoch: Current training step; used in the title and file name.
        """
        os.makedirs('images/', exist_ok=True)
        imgs_hr, imgs_lr = self.data_loader.load_data(batch_size=1, is_testing=True, is_pred=True)
        fake_hr = self.generator.predict(imgs_lr)

        # Map images from [-1, 1] back to [0, 1] for display.
        imgs_lr = 0.5 * imgs_lr + 0.5
        fake_hr = 0.5 * fake_hr + 0.5
        imgs_hr = 0.5 * imgs_hr + 0.5

        r, c = imgs_hr.shape[0], 3
        titles = ['Generated  epoch: ' + str(epoch), 'Original', 'Low']

        # squeeze=False keeps ``axs`` two-dimensional even when there is only
        # one row; without it ``axs[row, col]`` fails for a single image.
        fig, axs = plt.subplots(r, c, squeeze=False)

        for row in range(r):
            for col, image in enumerate([fake_hr, imgs_hr, imgs_lr]):
                axs[row, col].imshow(image[row])
                axs[row, col].set_title(titles[col])
                axs[row, col].axis('off')

        fig.savefig("images/%d.png" % (epoch))
        plt.close(fig)


    def read_all(self,size):
        """Super-resolve ``size`` random dataset images and save them for comparison.

        Loads generator weights from ``./compare/22000.h5``, picks ``size``
        distinct images from ``./datasets/img_align_celeba2`` and, for each
        one, saves the resized original to ``./tryimghr/`` and the generated
        image to ``./tryoutputdata/``.

        Args:
            size: Number of images to sample (without replacement).

        Raises:
            FileNotFoundError: If the dataset directory contains no files.
        """
        os.makedirs('compare/', exist_ok=True)
        self.generator.load_weights('./compare/' + str(22000) + '.h5')
        dataset_name='./datasets/img_align_celeba2'

        path = glob('%s/*' % (dataset_name))
        if not path:
            raise FileNotFoundError('no images found in %s' % dataset_name)
        batch_images = np.random.choice(path, size=size, replace=False)

        # savefig does not create missing directories.
        os.makedirs('tryimghr', exist_ok=True)
        os.makedirs('tryoutputdata', exist_ok=True)

        def save_borderless(image, out_path):
            """Draw one image without axes, ticks or margins and save it."""
            plt.imshow(image)
            plt.axis('off')
            axes = plt.gca()
            axes.xaxis.set_major_locator(plt.NullLocator())
            axes.yaxis.set_major_locator(plt.NullLocator())
            plt.subplots_adjust(top=1, bottom=0, left=0, right=1, hspace=0, wspace=0)
            plt.margins(0, 0)
            plt.savefig(out_path, bbox_inches='tight', pad_inches=0.0)
            plt.close()

        # PIL's resize expects (width, height). Sizes come from the shapes the
        # networks were built with instead of hard-coded 224 / 56.
        hr_size = (self.hr_width, self.hr_height)
        lr_size = (self.lr_width, self.lr_height)

        # NOTE(review): the original numbered generated images after the
        # originals (original k -> tryimghr/k.png, generated k ->
        # tryoutputdata/(count + k).png). That numbering is preserved here;
        # confirm whether matching file names were intended.
        count = len(batch_images)

        for index, img_path in enumerate(batch_images):
            # The builtin float replaces np.float, which newer NumPy removed.
            img = imageio.imread(img_path, pilmode='RGB').astype(float)
            pil_img = Image.fromarray(np.uint8(img))

            # Single-image batches scaled to [-1, 1], as the generator expects.
            imgs_hr = np.array([np.array(pil_img.resize(hr_size))]) / 127.5 - 1.
            imgs_lr = np.array([np.array(pil_img.resize(lr_size))]) / 127.5 - 1.

            fake_hr = self.generator.predict(imgs_lr)

            # Back to [0, 1] for display.
            fake_hr = 0.5 * fake_hr + 0.5
            imgs_hr = 0.5 * imgs_hr + 0.5

            save_borderless(imgs_hr[0], "./tryimghr/%d.png" % index)
            save_borderless(fake_hr[0], "./tryoutputdata/%d.png" % (count + index))





if __name__ == '__main__':
    # Constructing SRGAN builds and compiles all networks (see SRGAN.__init__).
    gan = SRGAN()

    # Alternative entry points, kept for reference:
    #gan.train(epochs=1, batch_size=1, sample_interval=1)
    # (another setting tried: 3000 10 2)
    #gan.test_images()

    # Generate comparison images for 50 randomly chosen dataset images.
    gan.read_all(50)


data_loader.py

import scipy
from glob import glob
import numpy as np
import matplotlib.pyplot as plt
import os
import imageio
from PIL import Image
import PIL

class DataLoader():
    """Load image batches as paired high- and low-resolution arrays."""

    def __init__(self, dataset_name, img_res=(128, 128)):
        """Store the dataset location and the target resolution.

        Args:
            dataset_name: Directory whose files are used as training images.
            img_res: ``(height, width)`` of the high-resolution images; the
                low-resolution images are a quarter of that on each side.
        """
        self.dataset_name = dataset_name
        self.img_res = img_res

    def load_data(self, batch_size=1, is_testing=False, is_pred=False):
        """Return one batch of ``(imgs_hr, imgs_lr)`` scaled to [-1, 1].

        Args:
            batch_size: Number of distinct images drawn at random from the
                dataset directory. Ignored when ``is_pred`` is true.
            is_testing: When false, each image pair is flipped left-right
                with probability 0.5.
            is_pred: When true, every file in ``test_images/`` is loaded
                instead of sampling from the dataset.

        Returns:
            Tuple ``(imgs_hr, imgs_lr)`` of float arrays with one entry per
            loaded image.
        """
        if is_pred:
            batch_images = ['test_images/' + x for x in os.listdir('test_images/')]
        else:
            path = glob('%s/*' % (self.dataset_name))
            # Honour the requested batch size (previously fixed at 1).
            batch_images = np.random.choice(path, size=batch_size, replace=False)

        h, w = self.img_res
        low_h, low_w = int(h / 4), int(w / 4)

        imgs_hr = []
        imgs_lr = []
        for img_path in batch_images:
            # scipy.misc.imresize no longer exists in current SciPy, so all
            # resizing goes through PIL. PIL's resize expects (width, height).
            pil_img = Image.fromarray(np.uint8(self.imread(img_path)))
            img_hr = np.array(pil_img.resize((w, h)))
            img_lr = np.array(pil_img.resize((low_w, low_h)))

            # If training => do random flip
            if not is_testing and np.random.random() < 0.5:
                img_hr = np.fliplr(img_hr)
                img_lr = np.fliplr(img_lr)

            imgs_hr.append(img_hr)
            imgs_lr.append(img_lr)

        # Scale pixel values from [0, 255] to [-1, 1].
        imgs_hr = np.array(imgs_hr) / 127.5 - 1.
        imgs_lr = np.array(imgs_lr) / 127.5 - 1.

        return imgs_hr, imgs_lr

    def imread(self, path):
        """Read ``path`` as an RGB image and return it as a float array."""
        # The builtin float replaces np.float, which newer NumPy removed.
        return imageio.imread(path, pilmode='RGB').astype(float)

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值