Build a pix2pix GAN with Python
There are times when we want to transform an image into another style. Let’s say we have a fine collection of sketches, and our daily work is to color these black-and-white images.
It might be interesting if the number of tasks is small, but when it comes to hundreds of sketches a day, hmmm… maybe we need some help. This is where GAN comes to the rescue. A Generative Adversarial Network, or GAN, is a machine learning framework that aims to generate new data with the same distribution as that of the training dataset. In this article, we will build a pix2pix GAN that takes an image as input and outputs another image.
To break things down, we will go through these steps:
- Prepare our data
- Build the network
- Train the network
- Test and see the results
Prepare our data
In image transformation, we need an original image and its expected transformed result. It is recommended to have at least a few thousand of these before-and-after pairs. (Yes, GAN needs a lot of images 😅) In this post, we will use data from this kaggle dataset.
The image pairs can be saved as merged images like those in our dataset. They can also be separated into two folders; just make sure the order matches later when we process them 😉
Since the image pairs are merged into single files, we first need to split them into sketch images and colored pictures:
from os import listdir
from numpy import asarray
from numpy import savez_compressed
from keras.preprocessing.image import img_to_array
from keras.preprocessing.image import load_img

def load_images(path, size=(256,512)):
    src_list = list()
    tar_list = list()
    for filename in listdir(path):
        # load and resize the image (images are returned in PIL format)
        pixels = load_img(path + filename, target_size=size)
        # convert to numpy array
        pixels = img_to_array(pixels)
        # split into colored and sketch halves. 256 comes from 512/2.
        # The left half is the colored picture, the right half is the sketch.
        color_img, bw_img = pixels[:, :256], pixels[:, 256:]
        src_list.append(bw_img)
        tar_list.append(color_img)
    return [asarray(src_list), asarray(tar_list)]
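As a quick sanity check of the split logic, we can run the same slicing on a dummy array shaped like a merged image (the values are made up, just for illustration):

```python
import numpy as np

# a fake 256x512 "merged" image: left half all ones (colored), right half all zeros (sketch)
merged = np.concatenate([np.ones((256, 256, 3)), np.zeros((256, 256, 3))], axis=1)

# same slicing as in load_images: left half is colored, right half is the sketch
color_img, bw_img = merged[:, :256], merged[:, 256:]
print(merged.shape, color_img.shape, bw_img.shape)  # (256, 512, 3) (256, 256, 3) (256, 256, 3)
```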
Having our splitting function, we can process the training dataset with the following code:
path = "data/train/"
# load dataset
[src_images, tar_images] = load_images(path)
print('Loaded: ', src_images.shape, tar_images.shape)
# save as compressed numpy array
filename = 'gan_img_train.npz'
savez_compressed(filename, src_images, tar_images)
print('Saved dataset: ', filename)
As my machine is not powerful enough, I only kept 1508 images in the training set 🤫 Here is the output:
Loaded: (1508, 256, 256, 3) (1508, 256, 256, 3)
Saved dataset: gan_img_train.npz
There should be a new file named “gan_img_train.npz” in your current directory. We can check our data before moving on 🔍
from numpy import load
from matplotlib import pyplot

# load the dataset
data = load(filename)
src_images, tar_images = data['arr_0'], data['arr_1']
print('Loaded: ', src_images.shape, tar_images.shape)
# plot source images
n_samples = 3
for i in range(n_samples):
    pyplot.subplot(2, n_samples, 1 + i)
    pyplot.axis('off')
    pyplot.imshow(src_images[i].astype('uint8'))
# plot target images
for i in range(n_samples):
    pyplot.subplot(2, n_samples, 1 + n_samples + i)
    pyplot.axis('off')
    pyplot.imshow(tar_images[i].astype('uint8'))
pyplot.show()
Nice! The source is our sketches, and the target is those colored ones. Ready to rock ’n’ roll 👊
Build the network
A GAN network is composed of a generator and a discriminator. The generator is like a student, trying to mimic a masterpiece as his homework. The discriminator then serves as a teacher, giving feedback like Good ✅ (‘It looks real!’) or Bad ❌ (‘Nah, it’s so fake.’) on the student’s work. The student does his homework again and again, while the teacher tells him each time whether he is doing well. Once the teacher cannot distinguish the student’s work from the actual masterpiece, we consider the student able to create images that are good enough (Poor student 😿)
Q. If the discriminator knows so well whether an image is good or bad, why doesn’t it generate the image on its own?
A. To be honest, yes, it can generate images. However, since the discriminator excels at seeing the big picture, creating images pixel by pixel becomes arduous work for it. To generate an image using a discriminator, we would have to solve argmax D(x), i.e., find the input that maximizes the discriminator’s “real” score. This problem turns out to be too complicated to solve unless we impose some limitation, and since that limitation would also restrict the model’s capacity, people find it easier to replace solving argmax D(x) with a separate generator.
👨🏼🏫 Here is a lecture about it.
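For reference, the tug-of-war described above is usually written as a minimax objective. In pix2pix (following the original paper), the conditional GAN loss is combined with an L1 term; the λ = 100 matches the loss_weights=[1,100] we will use when compiling the combined model later:

```latex
\mathcal{L}_{cGAN}(G, D) = \mathbb{E}_{x,y}\big[\log D(x, y)\big]
                         + \mathbb{E}_{x}\big[\log\big(1 - D(x, G(x))\big)\big]

G^{*} = \arg\min_{G}\max_{D}\; \mathcal{L}_{cGAN}(G, D) + \lambda\,\mathcal{L}_{L1}(G),
\qquad \lambda = 100
```

Here x is the sketch, y is the real colored image, and L_L1(G) is the mean absolute error between y and G(x).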
1. Define discriminator
from keras.initializers import RandomNormal
from keras.models import Model
from keras.layers import Input
from keras.layers import Concatenate
from keras.layers import Conv2D
from keras.layers import BatchNormalization
from keras.layers import Activation
from keras.layers import LeakyReLU
from keras.optimizers import Adam

def define_discriminator(image_shape):
    # weight initialization
    init = RandomNormal(stddev=0.02)
    # source image input
    in_src_image = Input(shape=image_shape)
    # target image input
    in_target_image = Input(shape=image_shape)
    # concatenate images channel-wise
    merged = Concatenate()([in_src_image, in_target_image])
    # C64
    d = Conv2D(64, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(merged)
    d = LeakyReLU(alpha=0.2)(d)
    # C128
    d = Conv2D(128, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(d)
    d = BatchNormalization()(d)
    d = LeakyReLU(alpha=0.2)(d)
    # C256
    d = Conv2D(256, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(d)
    d = BatchNormalization()(d)
    d = LeakyReLU(alpha=0.2)(d)
    # C512
    d = Conv2D(512, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(d)
    d = BatchNormalization()(d)
    d = LeakyReLU(alpha=0.2)(d)
    # second last output layer
    d = Conv2D(512, (4,4), padding='same', kernel_initializer=init)(d)
    d = BatchNormalization()(d)
    d = LeakyReLU(alpha=0.2)(d)
    # patch output
    d = Conv2D(1, (4,4), padding='same', kernel_initializer=init)(d)
    patch_out = Activation('sigmoid')(d)
    # define model
    model = Model([in_src_image, in_target_image], patch_out)
    # compile model
    opt = Adam(lr=0.0002, beta_1=0.5)
    model.compile(loss='binary_crossentropy', optimizer=opt, loss_weights=[0.5])
    return model
If you are interested in the difference between ReLU and Leaky ReLU, here is a well-explained article 🌈 In short, Leaky ReLU changes the slope of ReLU where x < 0, introducing a “leak” that extends ReLU’s range to negative values.
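To make the difference concrete, here is a minimal NumPy sketch of both activations (the alpha=0.2 matches the LeakyReLU layers above):

```python
import numpy as np

def relu(x):
    # negative inputs are clipped to zero
    return np.maximum(0.0, x)

def leaky_relu(x, alpha=0.2):
    # negative inputs keep a small slope instead of a hard zero
    return np.where(x > 0, x, alpha * x)

x = np.array([-2.0, -0.5, 0.0, 1.0])
print(relu(x))        # gives 0, 0, 0, 1
print(leaky_relu(x))  # gives -0.4, -0.1, 0, 1
```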
From the code above, it’s not hard to tell that the discriminator is basically a classifier: it takes a pair of images (source and target) and categorizes the pair as “Real” or “Fake”.
2. Define generator
As for the generator, its structure is more complicated 😣 Its implementation combines an encoder and a decoder (it’s a U-Net structure here). The encoder breaks the input image down into smaller and smaller pieces; from these pieces, the decoder then scales back up and generates a new image in the end.
from keras.layers import Conv2DTranspose
from keras.layers import Dropout

def define_encoder_block(layer_in, n_filters, batchnorm=True):
    # weight initialization
    init = RandomNormal(stddev=0.02)
    # add downsampling layer
    g = Conv2D(n_filters, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(layer_in)
    # conditionally add batch normalization
    if batchnorm:
        g = BatchNormalization()(g, training=True)
    # leaky relu activation
    g = LeakyReLU(alpha=0.2)(g)
    return g

def decoder_block(layer_in, skip_in, n_filters, dropout=True):
    # weight initialization
    init = RandomNormal(stddev=0.02)
    # add upsampling layer
    g = Conv2DTranspose(n_filters, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(layer_in)
    # add batch normalization
    g = BatchNormalization()(g, training=True)
    # conditionally add dropout
    if dropout:
        g = Dropout(0.5)(g, training=True)
    # merge with skip connection
    g = Concatenate()([g, skip_in])
    # relu activation
    g = Activation('relu')(g)
    return g
def define_generator(image_shape=(256,256,3)):
    # weight initialization
    init = RandomNormal(stddev=0.02)
    # image input
    in_image = Input(shape=image_shape)
    # encoder model
    e1 = define_encoder_block(in_image, 64, batchnorm=False)
    e2 = define_encoder_block(e1, 128)
    e3 = define_encoder_block(e2, 256)
    e4 = define_encoder_block(e3, 512)
    e5 = define_encoder_block(e4, 512)
    e6 = define_encoder_block(e5, 512)
    e7 = define_encoder_block(e6, 512)
    # bottleneck: no batch norm, relu activation
    b = Conv2D(512, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(e7)
    b = Activation('relu')(b)
    # decoder model
    d1 = decoder_block(b, e7, 512)
    d2 = decoder_block(d1, e6, 512)
    d3 = decoder_block(d2, e5, 512)
    d4 = decoder_block(d3, e4, 512, dropout=False)
    d5 = decoder_block(d4, e3, 256, dropout=False)
    d6 = decoder_block(d5, e2, 128, dropout=False)
    d7 = decoder_block(d6, e1, 64, dropout=False)
    # output
    g = Conv2DTranspose(3, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(d7)
    out_image = Activation('tanh')(g)
    # define model
    model = Model(in_image, out_image)
    return model
You may also see UpSampling2D instead of Conv2DTranspose in other generators. The key difference between the two is that UpSampling2D is just a simple scaling up of the image using nearest-neighbour or bilinear upsampling, while Conv2DTranspose will not only upsample its input but also learn the best upsampling kernel. (from this Stack Overflow answer)
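A tiny NumPy sketch of the parameter-free variant: nearest-neighbour upsampling (roughly what UpSampling2D does by default) just repeats each pixel, whereas Conv2DTranspose would additionally learn a kernel:

```python
import numpy as np

# a 2x2 single-channel "image"
img = np.array([[1, 2],
                [3, 4]])

# nearest-neighbour 2x upsampling: repeat every row, then every column
up = np.repeat(np.repeat(img, 2, axis=0), 2, axis=1)
print(up)
# [[1 1 2 2]
#  [1 1 2 2]
#  [3 3 4 4]
#  [3 3 4 4]]
```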
3. Define GAN
Putting the discriminator and generator together, we now have our GAN. An interesting note here: we have to set the discriminator as “not trainable”, because if it were trainable, the generator would adjust the discriminator’s weights and make it easier to fool 😱 No, we don’t want that 🤪🤪🤪
def define_gan(g_model, d_model, image_shape):
    # make weights in the discriminator not trainable
    d_model.trainable = False
    # define the source image
    in_src = Input(shape=image_shape)
    # connect the source image to the generator input
    gen_out = g_model(in_src)
    # connect the source input and generator output to the discriminator input
    dis_out = d_model([in_src, gen_out])
    # src image as input; classification and generated image as outputs
    model = Model(in_src, [dis_out, gen_out])
    # compile model
    opt = Adam(lr=0.0002, beta_1=0.5)
    model.compile(loss=['binary_crossentropy', 'mae'], optimizer=opt, loss_weights=[1,100])
    return model
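The loss_weights=[1,100] means the generator is trained on a weighted sum: the adversarial binary cross-entropy plus 100 times the mean absolute error between the generated image and the real target. A toy NumPy calculation of that weighting (all numbers are made up, just to show how the L1 term dominates):

```python
import numpy as np

# made-up discriminator scores on generated pairs; the generator wants them close to 1
d_out = np.array([0.9, 0.8, 0.7, 0.95])
bce = -np.mean(np.log(d_out))  # binary cross-entropy against all-ones labels

# made-up target and generated pixel values in [-1, 1]
target = np.array([0.2, -0.5, 0.8])
generated = np.array([0.1, -0.4, 0.9])
mae = np.mean(np.abs(target - generated))  # mean absolute error = 0.1

total = 1 * bce + 100 * mae  # the weighted sum used to update the generator
```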
Train the network
We first need to load and process our training images:
def load_real_samples(filename):
    # load compressed arrays
    data = load(filename)
    # unpack arrays: arr_0 is the source array, arr_1 is the target array
    X1, X2 = data['arr_0'], data['arr_1']
    # scale from [0,255] to [-1,1]
    X1 = (X1 - 127.5) / 127.5
    X2 = (X2 - 127.5) / 127.5
    return [X1, X2]
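This scaling maps pixel values into [-1, 1], matching the tanh output of the generator; the inverse (used later for plotting) maps back into [0, 1]. A quick check on the boundary values:

```python
import numpy as np

pixels = np.array([0.0, 127.5, 255.0])
scaled = (pixels - 127.5) / 127.5   # -> [-1, 0, 1], the range tanh produces
plotted = (scaled + 1) / 2.0        # -> [0, 0.5, 1], the range imshow expects
```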
To train the discriminator, we need a lot of real and fake images 😬 In the functions below, generate_real_samples gives us random samples together with their expected transformed results, while generate_fake_samples uses our generator network to create fake images from its input. We label the expected results as 1 and the generator’s predictions as 0, to show the discriminator that the latter are fake.
from numpy import ones
from numpy import zeros
from numpy.random import randint

# select a batch of random samples, returns images and target
def generate_real_samples(dataset, n_samples, patch_shape):
    # unpack dataset (source, target)
    trainA, trainB = dataset
    # choose random instances
    ix = randint(0, trainA.shape[0], n_samples)
    # retrieve selected images
    X1, X2 = trainA[ix], trainB[ix]
    # generate 'real' class labels (1)
    y = ones((n_samples, patch_shape, patch_shape, 1))
    return [X1, X2], y

# generate a batch of images, returns images and targets
def generate_fake_samples(g_model, samples, patch_shape):
    # generate fake instances
    X = g_model.predict(samples)
    # create 'fake' class labels (0)
    y = zeros((len(X), patch_shape, patch_shape, 1))
    return X, y
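The patch_shape argument matches the spatial size of the discriminator's output map: each of the four stride-2 convolutions (C64 through C512) halves the 256-pixel input, so the label arrays y describe a grid of per-patch classifications rather than a single real/fake score:

```python
# four stride-2 convolutions halve the input resolution four times
input_size = 256
n_downsamples = 4  # C64, C128, C256, C512
patch_shape = input_size // 2 ** n_downsamples
print(patch_shape)  # 16, so y has shape (n_samples, 16, 16, 1)
```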
Almost there! It’s always helpful to know how our network performs during the training process, so here we write a function that saves the model’s weights into an h5 file. It also creates a plot comparing the real input, the generated image and the real output:
def summarize_performance(step, g_model, dataset, n_samples=3):
    # select a sample of input images
    [X_realA, X_realB], _ = generate_real_samples(dataset, n_samples, 1)
    # generate a batch of fake samples
    X_fakeB, _ = generate_fake_samples(g_model, X_realA, 1)
    # scale all pixels from [-1,1] to [0,1]
    X_realA = (X_realA + 1) / 2.0
    X_realB = (X_realB + 1) / 2.0
    X_fakeB = (X_fakeB + 1) / 2.0
    # plot real source images
    for i in range(n_samples):
        pyplot.subplot(3, n_samples, 1 + i)
        pyplot.axis('off')
        pyplot.imshow(X_realA[i])
    # plot generated target images
    for i in range(n_samples):
        pyplot.subplot(3, n_samples, 1 + n_samples + i)
        pyplot.axis('off')
        pyplot.imshow(X_fakeB[i])
    # plot real target images
    for i in range(n_samples):
        pyplot.subplot(3, n_samples, 1 + n_samples*2 + i)
        pyplot.axis('off')
        pyplot.imshow(X_realB[i])
    # save plot to file
    plot_name = 'plot_%06d.png' % (step+1)
    pyplot.savefig(plot_name)
    pyplot.close()
    # save the generator model
    model_name = 'model_%06d.h5' % (step+1)
    g_model.save(model_name)
    print('>Saved: %s and %s' % (plot_name, model_name))
And yay, let’s train it 🏋️♀️
def train(d_model, g_model, gan_model, dataset, n_epochs=100, n_batch=1):
    # determine the output square shape of the discriminator
    n_patch = d_model.output_shape[1]
    # unpack dataset
    trainA, trainB = dataset
    # calculate the number of batches per training epoch
    bat_per_epo = int(len(trainA) / n_batch)
    # calculate the number of training iterations
    n_steps = bat_per_epo * n_epochs
    # manually enumerate training steps
    for i in range(n_steps):
        # select a batch of real samples
        [X_realA, X_realB], y_real = generate_real_samples(dataset, n_batch, n_patch)
        # generate a batch of fake samples
        X_fakeB, y_fake = generate_fake_samples(g_model, X_realA, n_patch)
        # update discriminator for real samples
        d_loss1 = d_model.train_on_batch([X_realA, X_realB], y_real)
        # update discriminator for generated samples
        d_loss2 = d_model.train_on_batch([X_realA, X_fakeB], y_fake)
        # update the generator
        g_loss, _, _ = gan_model.train_on_batch(X_realA, [y_real, X_realB])
        # summarize performance
        print('>%d, d1[%.3f] d2[%.3f] g[%.3f]' % (i+1, d_loss1, d_loss2, g_loss))
        # summarize model performance every 10 epochs
        if (i+1) % (bat_per_epo * 10) == 0:
            summarize_performance(i, g_model, dataset)
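With our 1508 training images and a batch size of 1, the step counts work out as follows (which also explains the 150800 in the checkpoint file name used later):

```python
n_images = 1508  # size of our training set
n_batch = 1
n_epochs = 100

bat_per_epo = n_images // n_batch  # 1508 batches per epoch
n_steps = bat_per_epo * n_epochs   # 150800 training iterations in total
save_every = bat_per_epo * 10      # a checkpoint every 15080 steps
print(n_steps // save_every)       # 10 model files and 10 plots
```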
In the end, there will be 10 model files and 10 plots indicating how the training is going. Judging from our last plot, the GAN network seems to work well 🥰
Test and see the results
Having a trained model, let’s test it with some images it has never seen before. We’ll also use images from this kaggle dataset so the preprocessing functions defined above can be reused. (In this case, the test images are processed and saved in “gan_img_test.npz”.) The function below creates plots to help us compare the results with the expected output.
from keras.models import load_model
from numpy import load
from numpy import vstack
from matplotlib import pyplot

def plot_images(src_img, gen_img, tar_img):
    images = vstack((src_img, gen_img, tar_img))
    # scale from [-1,1] to [0,1]
    images = (images + 1) / 2.0
    titles = ['Source', 'Generated', 'Expected']
    # plot images row by row
    for i in range(len(images)):
        # define subplot
        pyplot.subplot(1, 3, 1 + i)
        # turn off axis
        pyplot.axis('off')
        # plot raw pixel data
        pyplot.imshow(images[i])
        # show title
        pyplot.title(titles[i])
    pyplot.show()
Testing gogogo 🤸♀️ ⛹️♀️
# load dataset
[X1, X2] = load_real_samples("gan_img_test.npz")
print('Loaded', X1.shape, X2.shape)
# load model (put the name of your last trained model here)
model = load_model('model_150800.h5')
for i in range(len(X1)):
    src_image, tar_image = X1[[i]], X2[[i]]
    # generate image from source
    gen_image = model.predict(src_image)
    # plot all three images
    plot_images(src_image, gen_image, tar_image)
Haha, some of them look quite messy 🤪 But since I only used 1508 images here, putting more images in the training dataset should surely produce more promising results.
I mostly followed this post to reproduce the implementations above, so feel free to go back to the original work for a more detailed explanation ☘️
Translated from: https://medium.com/@wendeehsu/build-a-pix2pix-gan-with-python-6db841b302c7