[Deep Learning in Practice] An Introductory CycleGAN Example with TensorFlow 2.1 and Keras

Series Contents

Deep Learning GAN (1): A Brief Introduction
Deep Learning GAN (2): A DCGAN Example on the CIFAR-10 Dataset
Deep Learning GAN (3): A DCGAN Example on the MNIST Handwritten-Digit Dataset
Deep Learning GAN (4): A cGAN (Conditional GAN) Example
Deep Learning GAN (5): A Pix2Pix GAN Example
Deep Learning GAN (6): A CycleGAN Example


1. What Is CycleGAN

The CycleGAN model was introduced in the 2017 paper Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks.

The benefit of the CycleGAN model is that it can be trained without paired examples. That is, it does not require examples of photos before and after the translation, such as photos of the same cityscape taken during the day and at night. Instead, the model is able to use a collection of photos from each domain and extract and harness the underlying style of the images in each collection in order to perform the translation.

The model architecture is comprised of two generator models: one generator (Generator-A) for generating images of the first domain (Domain-A), and a second generator (Generator-B) for generating images of the second domain (Domain-B).

  • Generator-A -> Domain-A
  • Generator-B -> Domain-B

The generator models perform image translation, meaning that the image generation process is conditional on an input image, specifically an image from the other domain. Generator-A takes an image from Domain-B as input, and Generator-B takes an image from Domain-A as input.

  • Domain-B -> Generator-A -> Domain-A
  • Domain-A -> Generator-B -> Domain-B

Each generator has a corresponding discriminator model. The first discriminator model (Discriminator-A) takes real images from Domain-A and generated images from Generator-A and predicts whether they are real or fake. The second discriminator model (Discriminator-B) takes real images from Domain-B and generated images from Generator-B and predicts whether they are real or fake.

  • Domain-A -> Discriminator-A -> [Real/Fake]
  • Domain-B -> Generator-A -> Discriminator-A -> [Real/Fake]
  • Domain-B -> Discriminator-B -> [Real/Fake]
  • Domain-A -> Generator-B -> Discriminator-B -> [Real/Fake]

Like a regular GAN, the discriminator and generator models are trained in an adversarial, zero-sum process. The generators learn to better fool the discriminators, and the discriminators learn to better detect fake images. Together, the models find an equilibrium during training.

Additionally, the generator models are regularized so that they do not merely create new images in the target domain, but instead translate more reconstructed versions of the input images from the source domain. This is achieved by using the generated image as input to the corresponding generator model and comparing the output image to the original image. Passing an image through both generators is called a cycle, and each pair of generator models is trained together to better reproduce the original source image, referred to as cycle consistency.

  • Domain-B -> Generator-A -> Domain-A -> Generator-B -> Domain-B
  • Domain-A -> Generator-B -> Domain-B -> Generator-A -> Domain-A

There is one further element to the architecture, referred to as identity mapping. Here, a generator is provided with an image from the target domain as input and is expected to generate the same image without change. This addition to the architecture is optional, although it results in a better matching of the color profile of the input image.

  • Domain-A -> Generator-A -> Domain-A
  • Domain-B -> Generator-B -> Domain-B

2. Dataset Preparation

We will use the "horse2zebra" dataset. The zip file for this dataset is about 111 MB and can be downloaded from the CycleGAN webpage:
Download Horses to Zebras Dataset (111 megabytes)
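If you prefer to download and unpack it from a script, the following is a minimal sketch; the URL shown is the location used by the official CycleGAN project and is an assumption here, so adjust it if the file has moved.

import urllib.request
import zipfile

# assumed download location of the horse2zebra archive (about 111 MB); verify before use
url = 'https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/horse2zebra.zip'
urllib.request.urlretrieve(url, 'horse2zebra.zip')
# unzip into the current directory, creating the horse2zebra/ folder
with zipfile.ZipFile('horse2zebra.zip') as zf:
    zf.extractall('.')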

After unzipping it, you will see the following directory structure:

horse2zebra
├── testA
├── testB
├── trainA
└── trainB

Open the testA folder; it contains photos of horses.
[Image: sample horse photos from testA]
Open the testB folder; it contains photos of zebras.
[Image: sample zebra photos from testB]
Each image is 256x256 pixels.

The code below loads all photos from the train and test folders and creates one array of images for category A and another for category B.

Both arrays are then saved to a new file, horse2zebra_256.npz, in compressed NumPy array format.

from os import listdir
import numpy as np
from numpy import asarray
from tensorflow.keras.preprocessing.image import img_to_array
from tensorflow.keras.preprocessing.image import load_img

# load all images in a directory into memory
def load_images(path, size=(256, 256)):
    data_list = list()
    # enumerate filenames in directory, assume all are images
    for filename in listdir(path):
        # load and resize the image
        pixels = load_img(path + filename, target_size=size)
        # convert to numpy array
        pixels = img_to_array(pixels)
        # store
        data_list.append(pixels)
    return asarray(data_list)


# dataset path
path = 'D:/ML/datasets/horse2zebra/'
# load dataset A
dataA1 = load_images(path + 'trainA/')
dataA2 = load_images(path + 'testA/')
dataA = np.vstack((dataA1, dataA2))
print('Loaded dataA: ', dataA.shape)
# load dataset B
dataB1 = load_images(path + 'trainB/')
dataB2 = load_images(path + 'testB/')
dataB = np.vstack((dataB1, dataB2))
print('Loaded dataB: ', dataB.shape)
# save as compressed numpy array
filename = 'horse2zebra_256.npz'
np.savez_compressed(filename, dataA, dataB)
print('Saved dataset: ', filename)
Loaded dataA:  (1187, 256, 256, 3)
Loaded dataB:  (1474, 256, 256, 3)
Saved dataset:  horse2zebra_256.npz

The following code loads the horse2zebra_256.npz dataset and displays some of the images with matplotlib.

# load and plot the prepared dataset
from numpy import load
from matplotlib import pyplot
# load the dataset
data = load('horse2zebra_256.npz')
dataA, dataB = data['arr_0'], data['arr_1']
print('Loaded: ', dataA.shape, dataB.shape)
# plot source images
n_samples = 3
for i in range(n_samples):
	pyplot.subplot(2, n_samples, 1 + i)
	pyplot.axis('off')
	pyplot.imshow(dataA[i].astype('uint8'))
# plot target image
for i in range(n_samples):
	pyplot.subplot(2, n_samples, 1 + n_samples + i)
	pyplot.axis('off')
	pyplot.imshow(dataB[i].astype('uint8'))
pyplot.show()

[Image: plot of three horse photos (top row) and three zebra photos (bottom row) from the prepared dataset]

3. Environment Setup

I am using TensorFlow 2.1 together with tensorflow_addons 0.9:

pip install tensorflow-gpu==2.1.0
pip install tensorflow_addons==0.9.1
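A quick sanity check that the installed versions match (a minimal sketch):

import tensorflow as tf
import tensorflow_addons as tfa

print(tf.__version__)   # expected: 2.1.0
print(tfa.__version__)  # expected: 0.9.1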

4. How to Build a CycleGAN to Translate Horses to Zebras

The full architecture is made up of four models: two discriminator models and two generator models.

The discriminator is a deep convolutional neural network that performs image classification. It takes a source image as input and predicts the likelihood of whether it is a real or fake image of the target domain. Two discriminator models are used, one for Domain-A (horses) and one for Domain-B (zebras).

The discriminator design is based on the effective receptive field of the model, which defines the relationship between one output of the model and the number of pixels in the input image. This is called a PatchGAN model and is carefully designed so that each output prediction of the model maps to a 70x70 square, or patch, of the input image. The benefit of this approach is that the same model can be applied to input images of different sizes, e.g. larger or smaller than 256x256 pixels.

The output of the model depends on the size of the input image, but it may be a single value or a square activation map of values. Each value is the probability that a patch of the input image is real. If needed, these values can be averaged to give an overall likelihood or classification score.
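For example, the 16x16x1 patch output produced by the discriminator defined below can be collapsed into a single realness score by averaging; a minimal NumPy sketch with placeholder values:

import numpy as np

# placeholder discriminator output for one image: one realness value per patch
patch_out = np.random.rand(16, 16, 1)
overall_score = float(patch_out.mean())  # averaged likelihood that the whole image is real
print(overall_score)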

A pattern of Convolution-Normalization-LeakyReLU layers is used in the model, as is common for deep convolutional discriminator models. Unlike other models, the CycleGAN discriminator uses InstanceNormalization instead of BatchNormalization. This is a very simple type of normalization that standardizes (e.g. scales to a standard Gaussian) the values on each output feature map, rather than across the feature maps in a batch.
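A minimal NumPy sketch of the difference (ignoring the learnable scale and offset that the layer also applies):

import numpy as np

x = np.random.rand(8, 64, 64, 128)  # (batch, height, width, channels)
# instance normalization: standardize each feature map of each sample independently
inst = (x - x.mean(axis=(1, 2), keepdims=True)) / (x.std(axis=(1, 2), keepdims=True) + 1e-5)
# batch normalization would instead pool the statistics across the whole batch
batch = (x - x.mean(axis=(0, 1, 2), keepdims=True)) / (x.std(axis=(0, 1, 2), keepdims=True) + 1e-5)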

4.1. Defining the Discriminator

import tensorflow as tf

from tensorflow.keras.optimizers import Adam
from tensorflow.keras.initializers import RandomNormal
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input
from tensorflow.keras.layers import Conv2D
from tensorflow.keras.layers import Conv2DTranspose

from tensorflow.keras.layers import Activation
from tensorflow.keras.layers import Concatenate
from tensorflow.keras.layers import Dropout

from matplotlib import pyplot
from tensorflow.keras.layers import LeakyReLU
import tensorflow_addons as tfa
import numpy as np
from random import random

def define_discriminator(image_shape):
	# weight initialization
	init = RandomNormal(stddev=0.02)
	# source image input
	in_image = Input(shape=image_shape)
	# C64
	d = Conv2D(64, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(in_image)
	d = LeakyReLU(alpha=0.2)(d)
	# C128
	d = Conv2D(128, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(d)
	d = tfa.layers.InstanceNormalization(axis=-1)(d)
	d = LeakyReLU(alpha=0.2)(d)
	# C256
	d = Conv2D(256, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(d)
	d = tfa.layers.InstanceNormalization(axis=-1)(d)
	d = LeakyReLU(alpha=0.2)(d)
	# C512
	d = Conv2D(512, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(d)
	d = tfa.layers.InstanceNormalization(axis=-1)(d)
	d = LeakyReLU(alpha=0.2)(d)
	# second last output layer
	d = Conv2D(512, (4,4), padding='same', kernel_initializer=init)(d)
	d = tfa.layers.InstanceNormalization(axis=-1)(d)
	d = LeakyReLU(alpha=0.2)(d)
	# patch output
	patch_out = Conv2D(1, (4,4), padding='same', kernel_initializer=init)(d)
	# define model
	model = Model(in_image, patch_out)
	# compile model
	model.compile(loss='mse', optimizer=Adam(lr=0.0002, beta_1=0.5), loss_weights=[0.5])
	return model

if __name__ == '__main__':
	module = define_discriminator((256,256,3))
	print(module.summary())
Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         [(None, 256, 256, 3)]     0         
_________________________________________________________________
conv2d (Conv2D)              (None, 128, 128, 64)      3136      
_________________________________________________________________
leaky_re_lu (LeakyReLU)      (None, 128, 128, 64)      0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 64, 64, 128)       131200    
_________________________________________________________________
instance_normalization (Inst (None, 64, 64, 128)       256       
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU)    (None, 64, 64, 128)       0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 32, 32, 256)       524544    
_________________________________________________________________
instance_normalization_1 (In (None, 32, 32, 256)       512       
_________________________________________________________________
leaky_re_lu_2 (LeakyReLU)    (None, 32, 32, 256)       0         
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 16, 16, 512)       2097664   
_________________________________________________________________
instance_normalization_2 (In (None, 16, 16, 512)       1024      
_________________________________________________________________
leaky_re_lu_3 (LeakyReLU)    (None, 16, 16, 512)       0         
_________________________________________________________________
conv2d_4 (Conv2D)            (None, 16, 16, 512)       4194816   
_________________________________________________________________
instance_normalization_3 (In (None, 16, 16, 512)       1024      
_________________________________________________________________
leaky_re_lu_4 (LeakyReLU)    (None, 16, 16, 512)       0         
_________________________________________________________________
conv2d_5 (Conv2D)            (None, 16, 16, 1)         8193      
=================================================================
Total params: 6,962,369
Trainable params: 6,962,369
Non-trainable params: 0

4.2. Defining the Generator

The generator model is more complex than the discriminator model.

The generator is an encoder-decoder model architecture. The model takes a source image (e.g. a photo of a horse) and generates a target image (e.g. a photo of a zebra). It does this by first downsampling or encoding the input image down to a bottleneck layer, then interpreting the encoding with a number of ResNet layers, and finally decoding or upsampling the representation through a series of layers up to the size of the output image.

First, we need a function to define the ResNet blocks. These are blocks made up of two 3x3 convolutional layers where the input to the block is concatenated channel-wise to the output of the block.

This is implemented in the resnet_block() function below, which creates two Convolution-InstanceNorm blocks with 3x3 filters and 1x1 stride, without a ReLU activation after the second block, matching the build_conv_block() function in the official Torch implementation. For simplicity, same padding is used instead of the reflection padding recommended in the paper.

# generate a resnet block
def resnet_block(n_filters, input_layer):
	# weight initialization
	init = RandomNormal(stddev=0.02)
	# first layer convolutional layer
	g = Conv2D(n_filters, (3,3), padding='same', kernel_initializer=init)(input_layer)
	g = tfa.layers.InstanceNormalization(axis=-1)(g)
	g = Activation('relu')(g)
	# second convolutional layer
	g = Conv2D(n_filters, (3,3), padding='same', kernel_initializer=init)(g)
	g = tfa.layers.InstanceNormalization(axis=-1)(g)
	# concatenate merge channel-wise with input layer
	g = Concatenate()([g, input_layer])
	return g

Next, we can define a function that creates the 9-resnet-block version of the generator for 256x256 input images. This can easily be changed to the 6-resnet-block version by setting image_shape to (128,128,3) and the n_resnet function argument to 6.

Importantly, the model outputs images with the same shape as the input and with pixel values in the range [-1,1], as is typical for GAN generator models.

# define the standalone generator model
def define_generator(image_shape, n_resnet=9):
	# weight initialization
	init = RandomNormal(stddev=0.02)
	# image input
	in_image = Input(shape=image_shape)
	# c7s1-64
	g = Conv2D(64, (7,7), padding='same', kernel_initializer=init)(in_image)
	g = tfa.layers.InstanceNormalization(axis=-1)(g)
	g = Activation('relu')(g)
	# d128
	g = Conv2D(128, (3,3), strides=(2,2), padding='same', kernel_initializer=init)(g)
	g = tfa.layers.InstanceNormalization(axis=-1)(g)
	g = Activation('relu')(g)
	# d256
	g = Conv2D(256, (3,3), strides=(2,2), padding='same', kernel_initializer=init)(g)
	g = tfa.layers.InstanceNormalization(axis=-1)(g)
	g = Activation('relu')(g)
	# R256
	for _ in range(n_resnet):
		g = resnet_block(256, g)
	# u128
	g = Conv2DTranspose(128, (3,3), strides=(2,2), padding='same', kernel_initializer=init)(g)
	g = tfa.layers.InstanceNormalization(axis=-1)(g)
	g = Activation('relu')(g)
	# u64
	g = Conv2DTranspose(64, (3,3), strides=(2,2), padding='same', kernel_initializer=init)(g)
	g = tfa.layers.InstanceNormalization(axis=-1)(g)
	g = Activation('relu')(g)
	# c7s1-3
	g = Conv2D(3, (7,7), padding='same', kernel_initializer=init)(g)
	g = tfa.layers.InstanceNormalization(axis=-1)(g)
	out_image = Activation('tanh')(g)
	# define model
	model = Model(in_image, out_image)
	return model
if __name__ == '__main__':
	module = define_generator((256,256,3))
	print(module.summary())
Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            [(None, 256, 256, 3) 0                                            
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, 256, 256, 64) 9472        input_1[0][0]                    
__________________________________________________________________________________________________
instance_normalization (Instanc (None, 256, 256, 64) 128         conv2d[0][0]                     
__________________________________________________________________________________________________
activation (Activation)         (None, 256, 256, 64) 0           instance_normalization[0][0]     
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 128, 128, 128 73856       activation[0][0]                 
__________________________________________________________________________________________________
instance_normalization_1 (Insta (None, 128, 128, 128 256         conv2d_1[0][0]                   
__________________________________________________________________________________________________
activation_1 (Activation)       (None, 128, 128, 128 0           instance_normalization_1[0][0]   
__________________________________________________________________________________________________
conv2d_2 (Conv2D)               (None, 64, 64, 256)  295168      activation_1[0][0]               
__________________________________________________________________________________________________
instance_normalization_2 (Insta (None, 64, 64, 256)  512         conv2d_2[0][0]                   
__________________________________________________________________________________________________
activation_2 (Activation)       (None, 64, 64, 256)  0           instance_normalization_2[0][0]   
__________________________________________________________________________________________________
conv2d_3 (Conv2D)               (None, 64, 64, 256)  590080      activation_2[0][0]               
__________________________________________________________________________________________________
instance_normalization_3 (Insta (None, 64, 64, 256)  512         conv2d_3[0][0]                   
__________________________________________________________________________________________________
activation_3 (Activation)       (None, 64, 64, 256)  0           instance_normalization_3[0][0]   
__________________________________________________________________________________________________
conv2d_4 (Conv2D)               (None, 64, 64, 256)  590080      activation_3[0][0]               
__________________________________________________________________________________________________
instance_normalization_4 (Insta (None, 64, 64, 256)  512         conv2d_4[0][0]                   
__________________________________________________________________________________________________
concatenate (Concatenate)       (None, 64, 64, 512)  0           instance_normalization_4[0][0]   
                                                                 activation_2[0][0]               
__________________________________________________________________________________________________
conv2d_5 (Conv2D)               (None, 64, 64, 256)  1179904     concatenate[0][0]                
__________________________________________________________________________________________________
instance_normalization_5 (Insta (None, 64, 64, 256)  512         conv2d_5[0][0]                   
__________________________________________________________________________________________________
activation_4 (Activation)       (None, 64, 64, 256)  0           instance_normalization_5[0][0]   
__________________________________________________________________________________________________
conv2d_6 (Conv2D)               (None, 64, 64, 256)  590080      activation_4[0][0]               
__________________________________________________________________________________________________
instance_normalization_6 (Insta (None, 64, 64, 256)  512         conv2d_6[0][0]                   
__________________________________________________________________________________________________
concatenate_1 (Concatenate)     (None, 64, 64, 768)  0           instance_normalization_6[0][0]   
                                                                 concatenate[0][0]                
__________________________________________________________________________________________________
conv2d_7 (Conv2D)               (None, 64, 64, 256)  1769728     concatenate_1[0][0]              
__________________________________________________________________________________________________
instance_normalization_7 (Insta (None, 64, 64, 256)  512         conv2d_7[0][0]                   
__________________________________________________________________________________________________
activation_5 (Activation)       (None, 64, 64, 256)  0           instance_normalization_7[0][0]   
__________________________________________________________________________________________________
conv2d_8 (Conv2D)               (None, 64, 64, 256)  590080      activation_5[0][0]               
__________________________________________________________________________________________________
instance_normalization_8 (Insta (None, 64, 64, 256)  512         conv2d_8[0][0]                   
__________________________________________________________________________________________________
concatenate_2 (Concatenate)     (None, 64, 64, 1024) 0           instance_normalization_8[0][0]   
                                                                 concatenate_1[0][0]              
__________________________________________________________________________________________________
conv2d_9 (Conv2D)               (None, 64, 64, 256)  2359552     concatenate_2[0][0]              
__________________________________________________________________________________________________
instance_normalization_9 (Insta (None, 64, 64, 256)  512         conv2d_9[0][0]                   
__________________________________________________________________________________________________
activation_6 (Activation)       (None, 64, 64, 256)  0           instance_normalization_9[0][0]   
__________________________________________________________________________________________________
conv2d_10 (Conv2D)              (None, 64, 64, 256)  590080      activation_6[0][0]               
__________________________________________________________________________________________________
instance_normalization_10 (Inst (None, 64, 64, 256)  512         conv2d_10[0][0]                  
__________________________________________________________________________________________________
concatenate_3 (Concatenate)     (None, 64, 64, 1280) 0           instance_normalization_10[0][0]  
                                                                 concatenate_2[0][0]              
__________________________________________________________________________________________________
conv2d_11 (Conv2D)              (None, 64, 64, 256)  2949376     concatenate_3[0][0]              
__________________________________________________________________________________________________
instance_normalization_11 (Inst (None, 64, 64, 256)  512         conv2d_11[0][0]                  
__________________________________________________________________________________________________
activation_7 (Activation)       (None, 64, 64, 256)  0           instance_normalization_11[0][0]  
__________________________________________________________________________________________________
conv2d_12 (Conv2D)              (None, 64, 64, 256)  590080      activation_7[0][0]               
__________________________________________________________________________________________________
instance_normalization_12 (Inst (None, 64, 64, 256)  512         conv2d_12[0][0]                  
__________________________________________________________________________________________________
concatenate_4 (Concatenate)     (None, 64, 64, 1536) 0           instance_normalization_12[0][0]  
                                                                 concatenate_3[0][0]              
__________________________________________________________________________________________________
conv2d_13 (Conv2D)              (None, 64, 64, 256)  3539200     concatenate_4[0][0]              
__________________________________________________________________________________________________
instance_normalization_13 (Inst (None, 64, 64, 256)  512         conv2d_13[0][0]                  
__________________________________________________________________________________________________
activation_8 (Activation)       (None, 64, 64, 256)  0           instance_normalization_13[0][0]  
__________________________________________________________________________________________________
conv2d_14 (Conv2D)              (None, 64, 64, 256)  590080      activation_8[0][0]               
__________________________________________________________________________________________________
instance_normalization_14 (Inst (None, 64, 64, 256)  512         conv2d_14[0][0]                  
__________________________________________________________________________________________________
concatenate_5 (Concatenate)     (None, 64, 64, 1792) 0           instance_normalization_14[0][0]  
                                                                 concatenate_4[0][0]              
__________________________________________________________________________________________________
conv2d_15 (Conv2D)              (None, 64, 64, 256)  4129024     concatenate_5[0][0]              
__________________________________________________________________________________________________
instance_normalization_15 (Inst (None, 64, 64, 256)  512         conv2d_15[0][0]                  
__________________________________________________________________________________________________
activation_9 (Activation)       (None, 64, 64, 256)  0           instance_normalization_15[0][0]  
__________________________________________________________________________________________________
conv2d_16 (Conv2D)              (None, 64, 64, 256)  590080      activation_9[0][0]               
__________________________________________________________________________________________________
instance_normalization_16 (Inst (None, 64, 64, 256)  512         conv2d_16[0][0]                  
__________________________________________________________________________________________________
concatenate_6 (Concatenate)     (None, 64, 64, 2048) 0           instance_normalization_16[0][0]  
                                                                 concatenate_5[0][0]              
__________________________________________________________________________________________________
conv2d_17 (Conv2D)              (None, 64, 64, 256)  4718848     concatenate_6[0][0]              
__________________________________________________________________________________________________
instance_normalization_17 (Inst (None, 64, 64, 256)  512         conv2d_17[0][0]                  
__________________________________________________________________________________________________
activation_10 (Activation)      (None, 64, 64, 256)  0           instance_normalization_17[0][0]  
__________________________________________________________________________________________________
conv2d_18 (Conv2D)              (None, 64, 64, 256)  590080      activation_10[0][0]              
__________________________________________________________________________________________________
instance_normalization_18 (Inst (None, 64, 64, 256)  512         conv2d_18[0][0]                  
__________________________________________________________________________________________________
concatenate_7 (Concatenate)     (None, 64, 64, 2304) 0           instance_normalization_18[0][0]  
                                                                 concatenate_6[0][0]              
__________________________________________________________________________________________________
conv2d_19 (Conv2D)              (None, 64, 64, 256)  5308672     concatenate_7[0][0]              
__________________________________________________________________________________________________
instance_normalization_19 (Inst (None, 64, 64, 256)  512         conv2d_19[0][0]                  
__________________________________________________________________________________________________
activation_11 (Activation)      (None, 64, 64, 256)  0           instance_normalization_19[0][0]  
__________________________________________________________________________________________________
conv2d_20 (Conv2D)              (None, 64, 64, 256)  590080      activation_11[0][0]              
__________________________________________________________________________________________________
instance_normalization_20 (Inst (None, 64, 64, 256)  512         conv2d_20[0][0]                  
__________________________________________________________________________________________________
concatenate_8 (Concatenate)     (None, 64, 64, 2560) 0           instance_normalization_20[0][0]  
                                                                 concatenate_7[0][0]              
__________________________________________________________________________________________________
conv2d_transpose (Conv2DTranspo (None, 128, 128, 128 2949248     concatenate_8[0][0]              
__________________________________________________________________________________________________
instance_normalization_21 (Inst (None, 128, 128, 128 256         conv2d_transpose[0][0]           
__________________________________________________________________________________________________
activation_12 (Activation)      (None, 128, 128, 128 0           instance_normalization_21[0][0]  
__________________________________________________________________________________________________
conv2d_transpose_1 (Conv2DTrans (None, 256, 256, 64) 73792       activation_12[0][0]              
__________________________________________________________________________________________________
instance_normalization_22 (Inst (None, 256, 256, 64) 128         conv2d_transpose_1[0][0]         
__________________________________________________________________________________________________
activation_13 (Activation)      (None, 256, 256, 64) 0           instance_normalization_22[0][0]  
__________________________________________________________________________________________________
conv2d_21 (Conv2D)              (None, 256, 256, 3)  9411        activation_13[0][0]              
__________________________________________________________________________________________________
instance_normalization_23 (Inst (None, 256, 256, 3)  6           conv2d_21[0][0]                  
__________________________________________________________________________________________________
activation_14 (Activation)      (None, 256, 256, 3)  0           instance_normalization_23[0][0]  
==================================================================================================
Total params: 35,276,553
Trainable params: 35,276,553
Non-trainable params: 0
__________________________________________________________________________________________________

The discriminator models are trained directly on real and generated images, whereas the generator models are not.

Instead, the generator models are trained via their related discriminator models. Specifically, they are updated to minimize the loss predicted by the discriminator for generated images marked as "real", called the adversarial loss. As such, they are encouraged to generate images that better fit the target domain.

The generator models are also updated based on how effective they are at regenerating the source image when used with the other generator model, called the cycle loss. Finally, a generator model is expected to output an image without translation when provided with an example from the target domain, called the identity loss.

Altogether, each generator model is optimized via the combination of the following four loss functions:

  • Adversarial loss (L2 or mean squared error).
  • Identity loss (L1 or mean absolute error).
  • Forward cycle loss (L1 or mean absolute error).
  • Backward cycle loss (L1 or mean absolute error).

MSE is the mean squared error (used for the adversarial loss); MAE is the mean absolute error (used for the identity and cycle losses).
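Written out for targets y_i and predictions ŷ_i over n values, these are:

\mathrm{MSE} = \frac{1}{n}\sum_{i=1}^{n}\left(y_i - \hat{y}_i\right)^2,
\qquad
\mathrm{MAE} = \frac{1}{n}\sum_{i=1}^{n}\left|y_i - \hat{y}_i\right|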

This can be achieved by defining a composite model for training each generator model that, although it shares weights with the related discriminator model and the other generator model, is responsible for updating the weights of that generator model only.

This is implemented in the define_composite_model() function below, which takes a defined generator model (g_model_1) as well as the defined discriminator model for the generator's output (d_model) and the other generator model (g_model_2). The weights of the other models are marked as not trainable, since we are only interested in updating the first generator model, the focus of this composite model.

The discriminator is connected to the output of the generator in order to classify generated images as real or fake. A second input to the composite model is defined as an image from the target domain (instead of the source domain), which the generator is expected to output without translation for the identity mapping. Next, for the forward cycle loss, the output of the generator is connected to the other generator, which reconstructs the source image. Finally, for the backward cycle loss, the image from the target domain used for identity mapping is also passed through the other generator, whose output is connected to our main generator as input, producing a reconstructed version of that image from the target domain.

To summarize, each composite model has two inputs for real photos from Domain-A and Domain-B, and four outputs for the discriminator output, the identity-generated image, the forward-cycle-generated image, and the backward-cycle-generated image.

4.3. Defining the Composite Model

Only the weights of the first, or main, generator model are updated for the composite model, and this is done via the weighted sum of all loss functions. The cycle loss is given more weight (10x) than the adversarial loss, as described in the paper, and the identity loss is always used with a weighting half that of the cycle loss (5x), matching the official implementation source code.

# define a composite model for updating generators by adversarial and cycle loss
def define_composite_model(g_model_1, d_model, g_model_2, image_shape):
	# ensure the model we're updating is trainable
	g_model_1.trainable = True
	# mark discriminator as not trainable
	d_model.trainable = False
	# mark other generator model as not trainable
	g_model_2.trainable = False
	# discriminator element
	input_gen = Input(shape=image_shape)
	gen1_out = g_model_1(input_gen)
	output_d = d_model(gen1_out)
	# identity element
	input_id = Input(shape=image_shape)
	output_id = g_model_1(input_id)
	# forward cycle
	output_f = g_model_2(gen1_out)
	# backward cycle
	gen2_out = g_model_2(input_id)
	output_b = g_model_1(gen2_out)
	# define model graph
	model = Model([input_gen, input_id], [output_d, output_id, output_f, output_b])
	# define optimization algorithm configuration
	opt = Adam(lr=0.0002, beta_1=0.5)
	# compile model with weighting of least squares loss and L1 loss
	model.compile(loss=['mse', 'mae', 'mae', 'mae'], loss_weights=[1, 5, 10, 10], optimizer=opt)
	return model

We need to create one composite model for each generator model, e.g. Generator-A (B -> A) for zebra-to-horse translation and Generator-B (A -> B) for horse-to-zebra translation; a sketch showing how both are wired up follows the listings below.

The loss weights are therefore in the following ratio:
Adversarial loss : Identity loss : Forward cycle loss : Backward cycle loss = 1 : 5 : 10 : 10

All of this forward and backward across two domains gets confusing. Below is a complete listing of all the inputs and outputs for each of the composite models. Identity and cycle losses are calculated as the L1 distance between the input and output images of each translation sequence. The adversarial loss is calculated as the L2 distance between the model output and the target values of 1.0 for real and 0.0 for fake.

1. Generator-A Composite Model (B -> A, or Zebra to Horse)

The inputs, transformations, and outputs of the model are as follows:

  • Adversarial Loss: Domain-B -> Generator-A -> Domain-A -> Discriminator-A -> [real/fake]
  • Identity Loss: Domain-A -> Generator-A -> Domain-A
  • Forward Cycle Loss: Domain-B -> Generator-A -> Domain-A -> Generator-B -> Domain-B
  • Backward Cycle Loss: Domain-A -> Generator-B -> Domain-B -> Generator-A -> Domain-A

The inputs and outputs are summarized as:

Inputs: Domain-B, Domain-A
Outputs: Real, Domain-A, Domain-B, Domain-A

2. Generator-B Composite Model (A -> B, or Horse to Zebra)

The inputs, transformations, and outputs of the model are as follows:

  • Adversarial Loss: Domain-A -> Generator-B -> Domain-B -> Discriminator-B -> [real/fake]
  • Identity Loss: Domain-B -> Generator-B -> Domain-B
  • Forward Cycle Loss: Domain-A -> Generator-B -> Domain-B -> Generator-A -> Domain-A
  • Backward Cycle Loss: Domain-B -> Generator-A -> Domain-A -> Generator-B -> Domain-B

The inputs and outputs are summarized as:
Inputs: Domain-A, Domain-B
Outputs: Real, Domain-B, Domain-A, Domain-B
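To make this concrete, here is how the two composite models are wired up in code; this mirrors the main block of the complete code in Section 6, with image_shape taken from the loaded dataset (e.g. (256, 256, 3)).

# generators and discriminators for both domains
g_model_AtoB = define_generator(image_shape)    # Generator-B: A -> B (horse -> zebra)
g_model_BtoA = define_generator(image_shape)    # Generator-A: B -> A (zebra -> horse)
d_model_A = define_discriminator(image_shape)   # Discriminator-A judges Domain-A (horses)
d_model_B = define_discriminator(image_shape)   # Discriminator-B judges Domain-B (zebras)
# composite model for updating g_model_AtoB: its output is judged by d_model_B
c_model_AtoB = define_composite_model(g_model_AtoB, d_model_B, g_model_BtoA, image_shape)
# composite model for updating g_model_BtoA: its output is judged by d_model_A
c_model_BtoA = define_composite_model(g_model_BtoA, d_model_A, g_model_AtoB, image_shape)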

Defining the models is the hard part of CycleGAN; what follows is standard GAN training.

Next, we can load the image dataset in compressed NumPy array format. This returns a list of two NumPy arrays: the first for the Domain-A images and the second for the Domain-B images.

4.4 Loading Real Images and Generating Fake Images

load_real_samples() loads the real images from the .npz file and scales the pixel values to [-1,1].
generate_real_samples() selects a random batch of real images; each label is an array of ones with shape (patch_shape, patch_shape, 1), here (16, 16, 1).
generate_fake_samples() uses a generator model to produce fake images; each label is an array of zeros with the same shape.
Note that the labels here are not the usual scalar class values but 3-D arrays of shape (16, 16, 1), matching the PatchGAN output of the discriminator.

# load and prepare training images
def load_real_samples(filename):
	# load the dataset
	data = np.load(filename)
	# unpack arrays
	X1, X2 = data['arr_0'], data['arr_1']
	# scale from [0,255] to [-1,1]
	X1 = (X1 - 127.5) / 127.5
	X2 = (X2 - 127.5) / 127.5
	return [X1, X2]
# select a batch of random samples, returns images and target
def generate_real_samples(dataset, n_samples, patch_shape):
	# choose random instances
	ix = np.random.randint(0, dataset.shape[0], n_samples)
	# retrieve selected images
	X = dataset[ix]
	# generate 'real' class labels (1)
	y = np.ones((n_samples, patch_shape, patch_shape, 1))
	return X, y

# generate a batch of images, returns images and targets
def generate_fake_samples(g_model, dataset, patch_shape):
	# generate fake instance
	X = g_model.predict(dataset)
	# create 'fake' class labels (0)
	y = np.zeros((len(X), patch_shape, patch_shape, 1))
	return X, y

4.5 Saving the Models

Typically, GAN models do not converge; instead, an equilibrium is found between the generator and discriminator models. As such, we cannot easily judge whether training should stop. Therefore, we can save the models and use them to generate sample image-to-image translations periodically during training, such as every one or five training epochs.

We can then review the generated images at the end of training and use image quality to choose a final model.

The save_models() function below saves each generator model to the current directory in H5 format, including the training iteration number in the filename. This requires the h5py library to be installed.

# save the generator models to file
def save_models(step, g_model_AtoB, g_model_BtoA):
	# save the first generator model
	filename1 = 'g_model_AtoB_%06d.h5' % (step+1)
	g_model_AtoB.save(filename1)
	# save the second generator model
	filename2 = 'g_model_BtoA_%06d.h5' % (step+1)
	g_model_BtoA.save(filename2)
	print('>Saved: %s and %s' % (filename1, filename2))

4.6 Generating Images with the Generators

The summarize_performance() function below uses a given generator model to generate translated versions of a few randomly selected source photos and saves the plot to file.

The source images are plotted on the first row and the generated images on the second row. Again, the plot filename includes the training iteration number.

# generate samples and save as a plot and save the model
def summarize_performance(step, g_model, trainX, name, n_samples=5):
	# select a sample of input images
	X_in, _ = generate_real_samples(trainX, n_samples, 0)
	# generate translated images
	X_out, _ = generate_fake_samples(g_model, X_in, 0)
	# scale all pixels from [-1,1] to [0,1]
	X_in = (X_in + 1) / 2.0
	X_out = (X_out + 1) / 2.0
	# plot real images
	for i in range(n_samples):
		pyplot.subplot(2, n_samples, 1 + i)
		pyplot.axis('off')
		pyplot.imshow(X_in[i])
	# plot translated image
	for i in range(n_samples):
		pyplot.subplot(2, n_samples, 1 + n_samples + i)
		pyplot.axis('off')
		pyplot.imshow(X_out[i])
	# save plot to file
	filename1 = '%s_generated_plot_%06d.png' % (name, (step+1))
	pyplot.savefig(filename1)
	pyplot.close()

4.7 Updating the Image Pool

The discriminator models are updated directly on real and generated images, although in an effort to further manage how quickly the discriminator models learn, a pool of fake images is maintained.

The paper defines an image pool of 50 generated images for each discriminator model. The pool is first populated; afterwards, new images either replace an existing image in the pool with some probability or are used directly. We can implement this as a Python list of images for each discriminator and maintain each pool with the update_image_pool() function below.

# update image pool for fake images
def update_image_pool(pool, images, max_size=50):
	selected = list()
	for image in images:
		if len(pool) < max_size:
			# stock the pool
			pool.append(image)
			selected.append(image)
		elif random() < 0.5:
			# use image, but don't add it to the pool
			selected.append(image)
		else:
			# replace an existing image and use replaced image
			ix = np.random.randint(0, len(pool))
			selected.append(pool[ix])
			pool[ix] = image
	return np.asarray(selected)

4.8 Training the Models

The train() function below takes all six models (two discriminators, two generators, and two composite models) as arguments along with the dataset and trains them.

The batch size is fixed at one image to match the description in the paper, and the models are fit for 100 epochs. Given that the horse dataset has 1,187 images, one epoch is defined as 1,187 batches and the same number of training iterations. Images are generated with both generators each epoch, and the models are saved every five epochs, i.e. every (1,187 x 5) = 5,935 training iterations.

The order of model updates matches the official Torch implementation. First, a batch of real images is selected from each domain, then a batch of fake images is generated for each domain. The fake images are then used to update each discriminator's pool of fake images.

Next, the Generator-A model (zebra to horse) is updated via its composite model, followed by the Discriminator-A model (horses). Then the Generator-B (horse to zebra) composite model and the Discriminator-B (zebras) model are updated.

The loss of each updated model is then reported at the end of the training iteration. Importantly, only the weighted average loss used to update each generator is reported.

# train cyclegan models
def train(d_model_A, d_model_B, g_model_AtoB, g_model_BtoA, c_model_AtoB, c_model_BtoA, dataset):
	# define properties of the training run
	n_epochs, n_batch, = 100, 1
	# determine the output square shape of the discriminator
	n_patch = d_model_A.output_shape[1] # 16
	# unpack dataset
	trainA, trainB = dataset
	# prepare image pool for fakes
	poolA, poolB = list(), list()
	# calculate the number of batches per training epoch
	bat_per_epo = int(len(trainA) / n_batch) # 1187
	# calculate the number of training iterations
	n_steps = bat_per_epo * n_epochs # 1187 * 100
	# manually enumerate epochs
	for i in range(n_steps):
		# select a batch of real samples
		X_realA, y_realA = generate_real_samples(trainA, n_batch, n_patch) # 1187, 1, 16
		X_realB, y_realB = generate_real_samples(trainB, n_batch, n_patch)
		# generate a batch of fake samples
		X_fakeA, y_fakeA = generate_fake_samples(g_model_BtoA, X_realB, n_patch) # B>A
		X_fakeB, y_fakeB = generate_fake_samples(g_model_AtoB, X_realA, n_patch) # A>B
		# update fakes from pool
		X_fakeA = update_image_pool(poolA, X_fakeA)
		X_fakeB = update_image_pool(poolB, X_fakeB)
		# update generator B->A via adversarial and cycle loss
		g_loss2, _, _, _, _ = c_model_BtoA.train_on_batch([X_realB, X_realA], [y_realA, X_realA, X_realB, X_realA])
		# update discriminator for A -> [real/fake]
		dA_loss1 = d_model_A.train_on_batch(X_realA, y_realA)
		dA_loss2 = d_model_A.train_on_batch(X_fakeA, y_fakeA)
		# update generator A->B via adversarial and cycle loss
		g_loss1, _, _, _, _ = c_model_AtoB.train_on_batch([X_realA, X_realB], [y_realB, X_realB, X_realA, X_realB])
		# update discriminator for B -> [real/fake]
		dB_loss1 = d_model_B.train_on_batch(X_realB, y_realB)
		dB_loss2 = d_model_B.train_on_batch(X_fakeB, y_fakeB)
		# summarize performance
		print('>%d, dA[%.3f,%.3f] dB[%.3f,%.3f] g[%.3f,%.3f]' % (i+1, dA_loss1,dA_loss2, dB_loss1,dB_loss2, g_loss1,g_loss2))
		# evaluate the model performance every so often
		if (i+1) % (bat_per_epo * 1) == 0:
			# plot A->B translation
			summarize_performance(i, g_model_AtoB, trainA, 'AtoB')
			# plot B->A translation
			summarize_performance(i, g_model_BtoA, trainB, 'BtoA')
		if (i+1) % (bat_per_epo * 5) == 0:
			# save the models
			save_models(i, g_model_AtoB, g_model_BtoA)

5. Training Results

The loss is printed each training iteration, including the Discriminator-A loss on real and fake examples (dA), the Discriminator-B loss on real and fake examples (dB), and the Generator-AtoB and Generator-BtoA losses, each of which is a weighted average of the adversarial, identity, forward-cycle, and backward-cycle losses (g).

If the discriminator loss goes to zero and stays there for a long time, consider restarting the training run, as this is an example of a training failure.

>1, dA[2.284,0.678] dB[1.422,0.918] g[18.747,18.452]
>2, dA[2.129,1.226] dB[1.039,1.331] g[19.469,22.831]
>3, dA[1.644,3.909] dB[1.097,1.680] g[19.192,23.757]
>4, dA[1.427,1.757] dB[1.236,3.493] g[20.240,18.390]
>5, dA[1.737,0.808] dB[1.662,2.312] g[16.941,14.915]
...
>118696, dA[0.004,0.016] dB[0.001,0.001] g[2.623,2.359]
>118697, dA[0.001,0.028] dB[0.003,0.002] g[3.045,3.194]
>118698, dA[0.002,0.008] dB[0.001,0.002] g[2.685,2.071]
>118699, dA[0.010,0.010] dB[0.001,0.001] g[2.430,2.345]
>118700, dA[0.002,0.008] dB[0.000,0.004] g[2.487,2.169]
>Saved: g_model_AtoB_118700.h5 and g_model_BtoA_118700.h5

After roughly 9 epochs the results look like the following. You can see a slight change, but it is not very noticeable.
[Image: horse-to-zebra translations after about 9 epochs]

After roughly 50 epochs the results look like this:
[Image: horse-to-zebra translations after about 50 epochs]
The translation from zebra to horse appears to be more challenging for the model to learn, although somewhat plausible translations also begin to appear after 50 to 60 epochs.

As used in the paper, another 100 training epochs with weight decay may lead to better-quality results, as might a data generator that works through each dataset systematically rather than sampling at random.

6. Complete Code

import tensorflow as tf

from tensorflow.keras.optimizers import Adam
from tensorflow.keras.initializers import RandomNormal
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input
from tensorflow.keras.layers import Conv2D
from tensorflow.keras.layers import Conv2DTranspose

from tensorflow.keras.layers import Activation
from tensorflow.keras.layers import Concatenate
from tensorflow.keras.layers import Dropout

from matplotlib import pyplot
from tensorflow.keras.layers import LeakyReLU
import tensorflow_addons as tfa
import numpy as np
from random import random

def define_discriminator(image_shape):
	# weight initialization
	init = RandomNormal(stddev=0.02)
	# source image input
	in_image = Input(shape=image_shape)
	# C64
	d = Conv2D(64, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(in_image)
	d = LeakyReLU(alpha=0.2)(d)
	# C128
	d = Conv2D(128, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(d)
	d = tfa.layers.InstanceNormalization(axis=-1)(d)
	d = LeakyReLU(alpha=0.2)(d)
	# C256
	d = Conv2D(256, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(d)
	d = tfa.layers.InstanceNormalization(axis=-1)(d)
	d = LeakyReLU(alpha=0.2)(d)
	# C512
	d = Conv2D(512, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(d)
	d = tfa.layers.InstanceNormalization(axis=-1)(d)
	d = LeakyReLU(alpha=0.2)(d)
	# second last output layer
	d = Conv2D(512, (4,4), padding='same', kernel_initializer=init)(d)
	d = tfa.layers.InstanceNormalization(axis=-1)(d)
	d = LeakyReLU(alpha=0.2)(d)
	# patch output
	patch_out = Conv2D(1, (4,4), padding='same', kernel_initializer=init)(d)
	# define model
	model = Model(in_image, patch_out)
	# compile model
	model.compile(loss='mse', optimizer=Adam(lr=0.0002, beta_1=0.5), loss_weights=[0.5])
	return model

# generate a resnet block
def resnet_block(n_filters, input_layer):
	# weight initialization
	init = RandomNormal(stddev=0.02)
	# first layer convolutional layer
	g = Conv2D(n_filters, (3,3), padding='same', kernel_initializer=init)(input_layer)
	g = tfa.layers.InstanceNormalization(axis=-1)(g)
	g = Activation('relu')(g)
	# second convolutional layer
	g = Conv2D(n_filters, (3,3), padding='same', kernel_initializer=init)(g)
	g = tfa.layers.InstanceNormalization(axis=-1)(g)
	# concatenate merge channel-wise with input layer
	g = Concatenate()([g, input_layer])
	return g

# define the standalone generator model
def define_generator(image_shape, n_resnet=9):
	# weight initialization
	init = RandomNormal(stddev=0.02)
	# image input
	in_image = Input(shape=image_shape)
	# c7s1-64
	g = Conv2D(64, (7,7), padding='same', kernel_initializer=init)(in_image)
	g = tfa.layers.InstanceNormalization(axis=-1)(g)
	g = Activation('relu')(g)
	# d128
	g = Conv2D(128, (3,3), strides=(2,2), padding='same', kernel_initializer=init)(g)
	g = tfa.layers.InstanceNormalization(axis=-1)(g)
	g = Activation('relu')(g)
	# d256
	g = Conv2D(256, (3,3), strides=(2,2), padding='same', kernel_initializer=init)(g)
	g = tfa.layers.InstanceNormalization(axis=-1)(g)
	g = Activation('relu')(g)
	# R256
	for _ in range(n_resnet):
		g = resnet_block(256, g)
	# u128
	g = Conv2DTranspose(128, (3,3), strides=(2,2), padding='same', kernel_initializer=init)(g)
	g = tfa.layers.InstanceNormalization(axis=-1)(g)
	g = Activation('relu')(g)
	# u64
	g = Conv2DTranspose(64, (3,3), strides=(2,2), padding='same', kernel_initializer=init)(g)
	g = tfa.layers.InstanceNormalization(axis=-1)(g)
	g = Activation('relu')(g)
	# c7s1-3
	g = Conv2D(3, (7,7), padding='same', kernel_initializer=init)(g)
	g = tfa.layers.InstanceNormalization(axis=-1)(g)
	out_image = Activation('tanh')(g)
	# define model
	model = Model(in_image, out_image)
	return model

# define a composite model for updating generators by adversarial and cycle loss
def define_composite_model(g_model_1, d_model, g_model_2, image_shape):
	# ensure the model we're updating is trainable
	g_model_1.trainable = True
	# mark discriminator as not trainable
	d_model.trainable = False
	# mark other generator model as not trainable
	g_model_2.trainable = False
	# discriminator element
	input_gen = Input(shape=image_shape)
	gen1_out = g_model_1(input_gen)  # translate source -> target (e.g. A -> B)
	output_d = d_model(gen1_out)  # discriminate the translated target image
	# identity element
	input_id = Input(shape=image_shape)
	output_id = g_model_1(input_id)  # identity mapping of a target-domain image
	# forward cycle
	output_f = g_model_2(gen1_out)  # source -> target -> source
	# backward cycle
	gen2_out = g_model_2(input_id)  # target -> source
	output_b = g_model_1(gen2_out)  # target -> source -> target
	# define model graph
	model = Model([input_gen, input_id], [output_d, output_id, output_f, output_b])
	# define optimization algorithm configuration
	opt = Adam(lr=0.0002, beta_1=0.5)
	# compile model with weighting of least squares loss and L1 loss
	model.compile(loss=['mse', 'mae', 'mae', 'mae'], loss_weights=[1, 5, 10, 10], optimizer=opt)
	return model

# load and prepare training images
def load_real_samples(filename):
	# load the dataset
	data = np.load(filename)
	# unpack arrays
	X1, X2 = data['arr_0'], data['arr_1']
	# scale from [0,255] to [-1,1]
	X1 = (X1 - 127.5) / 127.5
	X2 = (X2 - 127.5) / 127.5
	return [X1, X2]

# select a batch of random samples, returns images and target
def generate_real_samples(dataset, n_samples, patch_shape):
	# choose random instances
	ix = np.random.randint(0, dataset.shape[0], n_samples)
	# retrieve selected images
	X = dataset[ix]
	# generate 'real' class labels (1)
	y = np.ones((n_samples, patch_shape, patch_shape, 1)) #(1,16,16,1)
	return X, y

# generate a batch of images, returns images and targets
def generate_fake_samples(g_model, dataset, patch_shape):
	# generate fake instance
	X = g_model.predict(dataset)
	# create 'fake' class labels (0)
	y = np.zeros((len(X), patch_shape, patch_shape, 1)) # (1, 16,16,1)
	return X, y

# save the generator models to file
def save_models(step, g_model_AtoB, g_model_BtoA):
	# save the first generator model
	filename1 = 'g_model_AtoB_%06d.h5' % (step+1)
	g_model_AtoB.save(filename1)
	# save the second generator model
	filename2 = 'g_model_BtoA_%06d.h5' % (step+1)
	g_model_BtoA.save(filename2)
	print('>Saved: %s and %s' % (filename1, filename2))

# generate samples and save as a plot and save the model
def summarize_performance(step, g_model, trainX, name, n_samples=5):
	# select a sample of input images
	X_in, _ = generate_real_samples(trainX, n_samples, 0)
	# generate translated images
	X_out, _ = generate_fake_samples(g_model, X_in, 0)
	# scale all pixels from [-1,1] to [0,1]
	X_in = (X_in + 1) / 2.0
	X_out = (X_out + 1) / 2.0
	# plot real images
	for i in range(n_samples):
		pyplot.subplot(2, n_samples, 1 + i)
		pyplot.axis('off')
		pyplot.imshow(X_in[i])
	# plot translated image
	for i in range(n_samples):
		pyplot.subplot(2, n_samples, 1 + n_samples + i)
		pyplot.axis('off')
		pyplot.imshow(X_out[i])
	# save plot to file
	filename1 = '%s_generated_plot_%06d.png' % (name, (step+1))
	pyplot.savefig(filename1)
	pyplot.close()

# update image pool for fake images
def update_image_pool(pool, images, max_size=50):
	selected = list()
	for image in images:
		if len(pool) < max_size:
			# stock the pool
			pool.append(image)
			selected.append(image)
		elif random() < 0.5:
			# use image, but don't add it to the pool
			selected.append(image)
		else:
			# replace an existing image and use replaced image
			ix = np.random.randint(0, len(pool))
			selected.append(pool[ix])
			pool[ix] = image
	return np.asarray(selected)

# train cyclegan models
def train(d_model_A, d_model_B, g_model_AtoB, g_model_BtoA, c_model_AtoB, c_model_BtoA, dataset):
	# define properties of the training run
	n_epochs, n_batch, = 100, 1
	# determine the output square shape of the discriminator
	n_patch = d_model_A.output_shape[1] # 16
	# unpack dataset
	trainA, trainB = dataset
	# prepare image pool for fakes
	poolA, poolB = list(), list()
	# calculate the number of batches per training epoch
	bat_per_epo = int(len(trainA) / n_batch) # 1187
	# calculate the number of training iterations
	n_steps = bat_per_epo * n_epochs # 1187 * 100
	# manually enumerate epochs
	for i in range(n_steps):
		# select a batch of real samples
		X_realA, y_realA = generate_real_samples(trainA, n_batch, n_patch) # 1187, 1, 16
		X_realB, y_realB = generate_real_samples(trainB, n_batch, n_patch)
		# generate a batch of fake samples
		X_fakeA, y_fakeA = generate_fake_samples(g_model_BtoA, X_realB, n_patch) # B>A
		X_fakeB, y_fakeB = generate_fake_samples(g_model_AtoB, X_realA, n_patch) # A>B
		# update fakes from pool
		X_fakeA = update_image_pool(poolA, X_fakeA)
		X_fakeB = update_image_pool(poolB, X_fakeB)
		# update generator B->A via adversarial and cycle loss
		g_loss2, _, _, _, _ = c_model_BtoA.train_on_batch([X_realB, X_realA], [y_realA, X_realA, X_realB, X_realA])
		# update discriminator for A -> [real/fake]
		dA_loss1 = d_model_A.train_on_batch(X_realA, y_realA)
		dA_loss2 = d_model_A.train_on_batch(X_fakeA, y_fakeA)
		# update generator A->B via adversarial and cycle loss
		g_loss1, _, _, _, _ = c_model_AtoB.train_on_batch([X_realA, X_realB], [y_realB, X_realB, X_realA, X_realB])
		# update discriminator for B -> [real/fake]
		dB_loss1 = d_model_B.train_on_batch(X_realB, y_realB)
		dB_loss2 = d_model_B.train_on_batch(X_fakeB, y_fakeB)
		# summarize performance
		print('>%d, dA[%.3f,%.3f] dB[%.3f,%.3f] g[%.3f,%.3f]' % (i+1, dA_loss1,dA_loss2, dB_loss1,dB_loss2, g_loss1,g_loss2))
		# evaluate the model performance every so often
		if (i+1) % (bat_per_epo * 1) == 0:
			# plot A->B translation
			summarize_performance(i, g_model_AtoB, trainA, 'AtoB')
			# plot B->A translation
			summarize_performance(i, g_model_BtoA, trainB, 'BtoA')
		if (i+1) % (bat_per_epo * 5) == 0:
			# save the models
			save_models(i, g_model_AtoB, g_model_BtoA)


if __name__ == '__main__':
	# load image data
	dataset = load_real_samples('horse2zebra_256.npz')
	print('Loaded', dataset[0].shape, dataset[1].shape)
	# define input shape based on the loaded dataset
	image_shape = dataset[0].shape[1:]
	# generator: A -> B
	g_model_AtoB = define_generator(image_shape)
	# generator: B -> A
	g_model_BtoA = define_generator(image_shape)
	# discriminator: A -> [real/fake]
	d_model_A = define_discriminator(image_shape)
	# discriminator: B -> [real/fake]
	d_model_B = define_discriminator(image_shape)
	# composite: A -> B -> [real/fake, A]
	c_model_AtoB = define_composite_model(g_model_AtoB, d_model_B, g_model_BtoA, image_shape)
	# composite: B -> A -> [real/fake, B]
	c_model_BtoA = define_composite_model(g_model_BtoA, d_model_A, g_model_AtoB, image_shape)
	# train models
	train(d_model_A, d_model_B, g_model_AtoB, g_model_BtoA, c_model_AtoB, c_model_BtoA, dataset)

7. How to Perform Image Translation with the CycleGAN Generators

# example of using saved cyclegan models for image translation
from tensorflow.keras.models import load_model
from numpy import load
from numpy import vstack
from matplotlib import pyplot
from numpy.random import randint
# the models were built and saved with the tensorflow_addons InstanceNormalization layer
import tensorflow_addons as tfa
# load and prepare training images
def load_real_samples(filename):
    # load the dataset
    data = load(filename)
    # unpack arrays
    X1, X2 = data['arr_0'], data['arr_1']
    # scale from [0,255] to [-1,1]
    X1 = (X1 - 127.5) / 127.5
    X2 = (X2 - 127.5) / 127.5
    return [X1, X2]


# select a random sample of images from the dataset
def select_sample(dataset, n_samples):
    # choose random instances
    ix = randint(0, dataset.shape[0], n_samples)
    # retrieve selected images
    X = dataset[ix]
    return X


# plot the image, the translation, and the reconstruction
def show_plot(imagesX, imagesY1, imagesY2):
    images = vstack((imagesX, imagesY1, imagesY2))
    titles = ['Real', 'Generated', 'Reconstructed']
    # scale from [-1,1] to [0,1]
    images = (images + 1) / 2.0
    # plot images row by row
    for i in range(len(images)):
        # define subplot
        pyplot.subplot(1, len(images), 1 + i)
        # turn off axis
        pyplot.axis('off')
        # plot raw pixel data
        pyplot.imshow(images[i])
        # title
        pyplot.title(titles[i])
    pyplot.show()


# load dataset
A_data, B_data = load_real_samples('horse2zebra_256.npz')
print('Loaded', A_data.shape, B_data.shape)
# load the models
cust = {'InstanceNormalization': tfa.layers.InstanceNormalization}
model_AtoB = load_model('g_model_AtoB_089025.h5', cust)
model_BtoA = load_model('g_model_BtoA_089025.h5', cust)
# plot A->B->A
A_real = select_sample(A_data, 1)
B_generated = model_AtoB.predict(A_real)
A_reconstructed = model_BtoA.predict(B_generated)
show_plot(A_real, B_generated, A_reconstructed)
# plot B->A->B
B_real = select_sample(B_data, 1)
A_generated = model_BtoA.predict(B_real)
B_reconstructed = model_AtoB.predict(A_generated)
show_plot(B_real, A_generated, B_reconstructed)
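To translate your own photo rather than a sample from the dataset, something like the following works; this is a sketch, and 'my_horse.jpg' is a hypothetical file path.

from tensorflow.keras.preprocessing.image import load_img, img_to_array
import numpy as np

# load a single photo, resize it to the model's input size, and scale pixels to [-1,1]
pixels = img_to_array(load_img('my_horse.jpg', target_size=(256, 256)))  # hypothetical file
pixels = (pixels - 127.5) / 127.5
# add a batch dimension and translate horse -> zebra with the loaded generator
translated = model_AtoB.predict(np.expand_dims(pixels, 0))
# scale back to [0,1] for display
pyplot.imshow((translated[0] + 1) / 2.0)
pyplot.axis('off')
pyplot.show()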

8. Personal Summary

The saved weights for each of the two generators are about 137 MB, which is quite large and helps explain why training is so slow. Running on a GPU, it took me roughly an hour and a half to complete 9 epochs.
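This size is consistent with the parameter count in the generator summary above: 35,276,553 float32 parameters x 4 bytes is roughly 141 million bytes (about 135 MiB) of raw weights, so an .h5 file of around 137 MB per generator is to be expected.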

