生成对抗网络用于分类_生成对抗网络

本文介绍了如何将生成对抗网络(GANs)应用于分类任务,详细探讨了这一机器学习技术在该领域的应用。
摘要由CSDN通过智能技术生成

生成对抗网络用于分类

GANS系列 (GANS SERIES)

This article is a part of the Gans-Series published by me on TowardsDataScience Publication on Medium. If you do not know what GANs are or if you have an idea about it but wish to quickly go over it again, I highly recommend you read the previous article which is just a 7 minutes long read and provides a simple understanding of GANs for people who are new to this amazing domain of Deep Learning.

本文是我在TowardsDataScience出版的Medium上发布的Gans系列的一部分。 如果您不知道GAN是什么,或者您对GAN有一个想法,但希望快速浏览一下,我强烈建议您阅读上一篇文章 本书只有7分钟的阅读时间,它为那些在深度学习这个令人惊奇的领域中不熟悉的人们提供了对GAN的简单理解。

As you can tell from the gif shown above, this article is going to be all about learning how to create a Conditional GAN to predict colorful images from the given black and white sketch inputs without knowing the actual ground truth.

从上面显示的gif可以看出,本文将全部学习如何创建条件GAN来从给定的黑白草图输入中预测彩色图像,而又不知道实际的地面真实情况。

在进入编码模式之前,需要了解一些知识…… (A little bit of need-to-know stuff before bringing on the coding mode…)

Sketch to Color Image generation is an image-to-image translation model using Conditional Generative Adversarial Networks as described in the original paper by Phillip Isola, Jun-Yan Zhu, Tinghui Zhou, Alexei A. Efros 2016, Image-to-Image Translation with Conditional Adversarial Networks.

素描到彩色图像生成是使用条件生成对抗网络的图像到图像转换模型,如Phillip Isola,Jun-Yan Zhu,Tinghui Zhou,Alexei A.Efros 2016, 条件对抗网络的图像到图像翻译

When I first came across this paper, it was amazing to see such great results shown by the authors and the fundamental idea was amazing on its own too.

当我初次接触这篇论文时,看到作者展示如此出色的结果真是太神奇了,而这个基本思想本身也很令人惊奇。

应用领域 (APPLICATIONS)

There are a lot of application scenarios of Conditional GANs that are depicted in the original paper by the authors. Some of which are listed below.

作者在原始论文中描述了很多条件GAN的应用场景。 下面列出了其中一些。

  • Map to Aerial Photos and vice versa

    映射到航空照片,反之亦然

  • Cityscapes to Photos

    城市景观到照片

  • Building Facades Labels to Photos

    建筑立面标签的照片

  • Daylight Photos to Night

    白天照片到晚上

  • Photo Inpainting

    照片修补

  • Sketch to Color Images (the one which we are going to build in this article)

    草绘为彩色图像 (本文将要构建的 图像 )

We are going to build a Conditional Generative Adversarial Network which accepts a 256x256 px black and white sketch image and predicts the colored version of the image without knowing the ground truth. The model will be trained on the Anime Sketch-Colorization Pair Dataset available on Kaggle which contains 14.2k pairs of Sketch-Color Anime Images.

我们将建立一个条件生成对抗网络,该网络接受256x256 px的黑白素描图像,并在不知道地面真实性的情况下预测图像的彩色版本。 该模型将在Kaggle上可用的“ 动漫素描-着色对”数据集上进行训练,该数据集包含14.2k对素描色彩的动漫图像。

When I trained the model on my system, I ran the model for 150 epochs which took approximately 23 hours on a single GeForce GTX 1060 6GB Graphic Card and 16 GB RAM. After all that hard work and patience, the results were totally worth it!

当我在系统上训练模型时,我在单个GeForce GTX 1060 6GB图形卡和16 GB RAM上运行了150个历时,耗时约23个小时。 经过所有的努力和耐心,结果完全值得!

Image for post
Outputs of the Trained Generator Model (Image by Author)
训练有素的发电机模型的输出(作者提供的图像)

现在让我们进入有趣的部分…… (Let’s get to the interesting part now…)

To build this model I have used TensorFlow 2.x and most of the code is based on their awesome tutorial on Pix2Pix for CMP Facade Dataset which predicts building photos from facade labels. TensorFlow tutorials are a good way to understand the framework and work on some well-known projects. I highly recommend you to go through all the tutorials on the website — https://www.tensorflow.org/tutorials.

为了构建该模型,我使用了TensorFlow 2.x,并且大多数代码都基于他们在CMP Facade Dataset上关于Pix2Pix的出色教程 从立面标签预测建筑物的照片。 TensorFlow教程是了解框架并在一些知名项目上工作的好方法。 我强烈建议您阅读网站上的所有教程-https: //www.tensorflow.org/tutorials

要求 (REQUIREMENTS)

To build this model, there are some basic requirements that you need to install on your system in order for it to work properly.

要构建此模型,您需要在系统上安装一些基本要求,以使其正常工作。

If you are planning on using any cloud environments like Google Colab, you need to keep in mind that the training is going to take a lot of time as GANs are computationally quite heavy to run. Google Colab has an absolute timeout of 12 hours which means that the notebook kernel is reset so you’ll need to consider some points like mounting the Google Drive and saving checkpoints after regular intervals so that you can continue training from where it left off before the timeout.

如果您打算使用Google Colab之类的任何云环境,则需要记住,由于GAN在计算上非常繁琐,因此培训将花费大量时间。 Google Colab的绝对超时时间为12小时,这意味着将重置笔记本内核,因此您需要考虑一些注意事项,例如安装Google云端硬盘以及定期间隔后保存检查点,以便您可以从停止之前的位置继续进行培训。超时。

资料下载 (DOWNLOADING THE DATASET)

Download the Anime Sketch-Colorization Pair Dataset available on Kaggle and save it to a folder directory. The root folder will contain folders colorgram , train , and val . For everyone’s convenience let us call the path to the root folder as path/to/dataset/ .

下载Kaggle上可用的“ Animation Sketch-Colorization Pair”数据集并将其保存到文件夹目录中。 根文件夹将包含文件夹colorgramtrainval 。 为了大家的方便,让我们将根文件夹的path/to/dataset/称为path/to/dataset/

Once the basic requirements are checked and the dataset is downloaded to your machine, it’s time for you to get into coding your very own Conditional GAN.

一旦检查了基本要求并将数据集下载到您的机器上,您就该开始编写自己的条件GAN了。

Before we jump right into it, note that the code which I’m going to provide shouldn't just be copied and pasted from here if you wish to understand the basic working behind it. And do not hesitate to ask your queries because that’s how things are learned — by asking.

在我们继续之前,请注意,如果您希望了解其背后的基本工作,则不应仅从此处复制并粘贴我将提供的代码。 并且不要犹豫地询问您的问题,因为这就是通过询问来学习事物的方式。

最后,代码! (FINALLY, THE CODE!)

First, let’s initialize the parameters to configure the training of the model. As stated earlier, we will be using the TensorFlow framework so we’ll need to import it by using import tensorflow as tf .

首先,让我们初始化参数以配置模型的训练。 如前所述,我们将使用TensorFlow框架,因此我们需要通过使用import tensorflow as tf来导入它。

The os module is used to interact with the Operating System. We are going to use this for accessing and modifying the path variables to save checkpoints during training. The time module lets us display relative time and hence, we can check how much time each epoch took during the training.

os模块用于与操作系统进行交互。 我们将使用它来访问和修改路径变量,以在训练期间保存检查点。 time模块使我们可以显示相对时间,因此,我们可以检查每个时期在训练中花费了多少时间。

matplotlib is another cool python library which we will be using to plot and show images.

matplotlib是另一个很酷的python库,我们将使用它来绘制和显示图像。

BUFFER_SIZE is used when we shuffle the data samples while training. Higher the value of this more will be the degree of shuffling, and hence, higher will be the accuracy of the model. But with large data, it takes a lot of processing power to shuffle the images. For my system with Intel(R) Core(TM) i7–8750H CPU and 16 GB of RAM, it was possible to set it equal to the size of the train dataset samples i.e. 14,224.

当我们在训练期间对数据样本进行混洗时,将使用BUFFER_SIZE 。 此值越高,混洗程度越多,因此模型的精度也越高。 但是,对于大数据,需要大量处理能力才能对图像进行混洗。 对于带有Intel®Core™i7–8750H CPU和16 GB RAM的系统,可以将其设置为等于火车数据集样本的大小,即14,224。

Nshuffle() is when you set buffer_size equal to the size of data samples. This way it takes all the samples [in this case 14,224] in the primary memory and chooses a random one from those. If you set it to 10, it’ll take the 10 samples in the memory and choose a random one from those 10 samples and then repeat it for other remaining examples. So, check you machine capabilities and find out the sweet spot.

N shuffle()是将buffer_size设置为等于数据样本的大小时。 这样,它将所有样本[在这种情况下为14,224]都存储在主内存中,并从这些样本中选择一个随机样本。 如果将其设置为10,它将在内存中获取10个样本,并从这10个样本中选择一个随机样本,然后对其他剩余示例重复此步骤。 因此,请检查您的机器功能并找出最有效的方法。

BATCH_SIZE is used to divide the dataset into mini-batches for training. The higher this value is, the faster will be the process of training. But as you might have guessed already, higher batch size means a higher load on the machine.

BATCH_SIZE用于将数据集分为迷你批进行训练。 该值越高,训练过程就越快。 但是,正如您可能已经猜到的那样,较大的批处理大小意味着机器上的负载较高。

Now, if you take a look at the dataset, you have a single image of size 1024x512 px for one entry which has a colored image of size 512x512 px in the left and a black and white sketch image of size 512x512 px in the right.

现在,如果您查看数据集,则一个条目的单个图像尺寸为1024x512 px,左侧为彩色图像,尺寸为512x512 px,右侧为黑白草图图像,尺寸为512x512 px。

Image for post
Anime Sketch-Colorization Pair DatasetAnime Sketch-Colorization Pair数据集的图像

We will define a function load() that takes the image path as a parameter and returns an input_image which is the black and white sketch that we’ll give as an input to the model, and real_image which is the colored image that we want.

我们将定义一个函数load() ,该函数将图像路径作为参数,并返回一个input_image它是我们将作为模型输入的黑白草图)和real_image它是我们想要的彩色图像)。

前处理 (Preprocessing)

Now that we have the data loaded, we need to do some preprocessing in order to prepare the data for the model.

现在已经加载了数据,我们需要进行一些预处理,以便为模型准备数据。

Given below are a few easy functions used for this purpose.

下面给出了一些用于此目的的简单函数。

resize() function is used to return the images as 286x286 px. This is done in order to have a uniform image size if by chance there is a differently sized image in the dataset. And decreasing size from 512x512 px to half of it also helps in speeding up the model training as it is computationally less heavy.

resize()函数用于将图像返回为286x286 px。 如果偶然在数据集中有不同大小的图像,则这样做是为了获得统一的图像大小。 尺寸从512x512 px减小到一半也有助于加快模型训练的速度,因为它的计算量较小。

random_crop() function returns the cropped input and real images which have the desired size of 256x256 px.

random_crop()函数返回裁剪后的输入图像和真实图像,这些图像的期望大小为256x256 px。

normalize() function, as the name suggests, normalizes images to [-1, 1].

顾名思义, normalize()函数将图像标准化为[-1,1]。

In the random_jitter() function shown above, all the previous preprocessing functions are put together and random images are flipped horizontally. You can see what the preprocessing of data returns from images given below.

在上面显示的random_jitter()函数中,所有先前的预处理函数放在一起,并且随机翻转图像。 您可以从下面给出的图像中看到对数据的预处理返回的内容。

Image for post
Preprocessed Images (Image by Author)
预处理图像(作者提供的图像)

加载火车和测试数据 (Loading the Train & Test Data)

load_image_train() function is used to put together all the previously seen functions and output the final preprocessed image.

load_image_train()函数用于将所有先前看到的函数放在一起,并输出最终的预处理图像。

tf.data.Dataset.list_files() collects the path to all the png files available in the train/ folder of the dataset. Then the collection of these paths is mapped through and every path is sent individually as an argument to the load_image_train() function which returns the final preprocessed image and adds it to the train_dataset .

tf.data.Dataset.list_files()收集数据集的train/文件夹中所有可用png文件的路径。 然后映射这些路径的集合,并将每个路径作为参数单独发送到load_image_train()函数,该函数返回最终的预处理图像并将其添加到train_dataset

Finally, this train_dataset is shuffled using the BUFFER_SIZE and then divided into mini-batches as discussed earlier.

最后,使用BUFFER_SIZE对该train_dataset进行混洗,然后将其分成迷你批处理,如前所述。

To load the test dataset, we will use a similar process except for a small change. Here we will omit the random_crop() and random_jitter() functions as there is no need to do this for testing the results. Also, we can omit to shuffle the dataset for the same reason.

要加载测试数据集,我们将使用类似的过程,只是有一点点变化。 这里我们将省略random_crop()random_jitter()函数,因为不需要这样做来测试结果。 同样,出于相同的原因,我们可以省略对数据集进行混洗。

建立发电机模型 (Building the Generator Model)

Let us build the generator model now which takes an input black and white sketch image of 256x256 px and outputs an image that hopefully resembles the colored ground truth image in the training dataset.

现在,让我们构建生成器模型,该模型将输入256x256 px的黑白草图图像,并输出希望与训练数据集中的彩色地面真实图像相似的图像。

The Generator model is a UNet Architecture Model and has skip connections to other layers than the intermediate one. Take note that it becomes complex to design such an architecture as the output and input shapes need to match to the connected layers, so design this carefully.

生成器模型是UNet体系结构模型,具有到中间层以外的其他层的跳过连接。 请注意,由于需要将输出和输入形状与连接的层相匹配,因此设计这种体系结构会变得很复杂,因此请谨慎设计。

The downsampling stack of layers has Convolutional layers which result in a decrease in the size of the input image. And once the decreased image goes through the upsampling stack of layers which has kind of “reverse” Convolutional layers, the size is restored back to 256x256 px. Hence, the output of the Generator Model is a 256x256 px image with 3 output channels.

下采样堆栈层具有卷积层,这会导致输入图像的大小减小。 并且,一旦减小的图像经过具有“反向”卷积层的向上采样层堆栈,其大小将恢复为256x256 px。 因此,生成器模型的输出是具有3个输出通道的256x256 px图像。

You can take a look at the model summary which is given below.

您可以看一下下面给出的模型摘要。

Image for post
Model Summary for Generator Model (Image by Author)
生成器模型的模型摘要(作者提供的图像)

建立鉴别模型 (Building the Discriminator Model)

The primary purpose of the discriminator model is to find out which image is from the actual training dataset and which is an output from the generator model.

鉴别器模型的主要目的是找出实际训练数据集中的图像,以及生成器模型的输出。

You can take a look at the model summary of the Discriminator given below. This is not as complex as the Generator model as it’s fundamental task is just to classify real and fake images.

您可以查看下面给出的鉴别器的模型摘要。 这并不像Generator模型那样复杂,因为它的基本任务只是对真实和伪造图像进行分类。

Image for post
Discriminator Model Summary (Image by Author)
鉴别器模型摘要(作者提供)

模型的损失函数 (Loss Functions for the Models)

As we have two models with us, we are going to require two different loss functions to calculate their loss independently.

由于我们拥有两个模型,因此我们将需要两个不同的损失函数来独立计算其损失。

The loss for the generator is calculated by finding the sigmoid cross-entropy loss of the output of the generator and an array of ones. This means that we are training it to trick the discriminator in outputting the value as 1, which means that it is a real image. Also, for the output to be structurally similar to the target image, we take L1 loss along with it. The value of LAMBDA is suggested to be kept 100 by authors of the original paper.

通过找到发电机输出和一组S形信号的S型交叉熵损耗,可以计算出发电机的损耗。 这意味着我们正在训练它欺骗鉴别器,以将值输出为1,这意味着它是真实图像。 同样,为了使输出在结构上与目标图像相似,我们将L1损失与损失一起考虑。 原始论文的作者建议将LAMBDA的值保持100。

For discriminator loss, we take the same sigmoid cross-entropy loss of the real images and an array of ones and add it with the cross-entropy loss of the output images of the generator model and array of zeros.

对于鉴别器损耗,我们对真实图像和一个数组进行相同的S形交叉熵损耗,并将其与生成器模型和零数组的输出图像的交叉熵损耗相加。

优化器 (Optimizers)

Optimizers are algorithms or methods used to change the attributes of your neural network such as weights and learning rates in order to reduce the losses. Adam Optimizer is one of the best ones to use, in most of the use cases.

优化器是用于更改神经网络属性(例如权重和学习率)以减少损失的算法或方法。 在大多数用例中,Adam Optimizer是最适合使用的软件之一。

创建检查点 (Creating Checkpoints)

As discussed earlier, cloud environments have a specific timeout which can interrupt the training process. Also, if you are using your local system, there may arise some cases where the training might be interrupted due to some reasons.

如前所述,云环境具有特定的超时时间,可能会中断训练过程。 另外,如果您使用的是本地系统,则在某些情况下,由于某些原因,培训可能会中断。

GANs take a very long time to train and are computationally expensive. So, it is best to keep saving checkpoints at regular intervals so that you can restore to the latest checkpoint and continue from there without losing the previously done hard work by your machines.

GAN训练时间很长,而且计算量很大。 因此,最好保持定期保存检查点,以便您可以还原到最新的检查点并从那里继续操作,而不会丢失机器之前完成的辛苦工作。

显示输出图像 (Displaying Output Images)

The above-given block of code is a basic python function which uses the pyplot module from matplotlib library to display the predicted images by the generator model.

上面给出的代码块是一个基本的python函数,该函数使用matplotlib库中的pyplot模块通过生成器模型显示预测的图像。

Image for post
Displaying the predicted image by untrained Generator Model (Image by Author)
通过未经训练的生成器模型显示预测图像(作者提供的图像)

记录损失 (Logging the Losses)

You can log the important metrics like losses in a file so that you can analyze it as the training progresses on tools like Tensorboard.

您可以在文件中记录重要指标(例如损失),以便在培训过程中使用Tensorboard等工具进行分析。

训练步骤 (Train Step)

A basic train step will consist of the following processes:

基本训练步骤将包括以下过程:

  • The generator outputs a prediction

    生成器输出预测
  • The discriminator model is designed to have 2 inputs at a time. For the first time, it is given an input sketch image and the generated image. The next time it is given the real target image and the generated image.

    鉴别器模型设计为一次具有2个输入。 首次为它提供输入草图图像和生成的图像。 下次给它真实的目标图像和生成的图像。
  • Now the generator loss and discriminator loss are calculated.

    现在计算出发电机损耗和鉴别器损耗。
  • Then, the gradients are calculated from the losses and applied to the optimizers to help the generator produce a better image and also to help discriminator detect the real and generated image with better insights.

    然后,从损耗中计算出梯度并将其应用于优化器,以帮助生成器生成更好的图像,并帮助鉴别器更好地洞察真实和生成的图像。
  • All the losses are logged using summary_writer defined previously using tf.summary .

    使用先前使用tf.summary定义的summary_writer记录所有损失。

Model.fit() (Model.fit())

TensorFlow is an awesome, easy to use framework for training models. And one small command like model.fit() does the magic for us.

TensorFlow是一个很棒的,易于使用的训练模型框架。 一个像model.fit()这样的小命令对我们来说是神奇的。

Unfortunately, it will not directly work over here as we have created two models that work together. But it is pretty easy to do this too.

不幸的是,由于我们已经创建了两个可以协同工作的模型,因此无法直接在这里工作。 但是也很容易做到这一点。

Here we iterate over for every epoch and assign the relative time to start variable. Then we display an example of the generated image by the generator model. This example helps us visualize how the generator gets better at generating better-colored images with every epoch. Then we call the train_step function for the model to learn from the calculated losses and gradients. And finally, we check if the epoch number is divisible by 5 to save a checkpoint. This means that we are saving a checkpoint after every 5 epochs of training are completed. After this entire epoch is completed, the start time is subtracted from the final relative time to count the time taken for that particular epoch.

在这里,我们针对每个时期进行迭代,并分配start变量的相对时间。 然后,我们将显示由生成器模型生成的图像的示例。 该示例帮助我们直观地了解生成器如何在每个时期生成更好的彩色图像时变得更好。 然后,我们为模型调用train_step函数,以从计算出的损耗和梯度中学习。 最后,我们检查时期数是否可以被5整除以保存检查点。 这意味着每完成5个培训阶段,我们将保存一个检查点。 在整个时期完成之后,从最终相对时间中减去开始时间,以计算该特定时期所花费的时间。

啊! 最后,我们在这里... (Aah! Finally, here we are…)

All we have to do now is run this one line of code and wait for the Model to do its magic on its own. Well let’s not give the entire credit to the model, we have done a lot of hard work and it’s time to see the results.

现在,我们要做的就是运行这一行代码,然后等待模型自行发挥作用。 好吧,我们不要完全相信该模型,我们已经做了很多艰苦的工作,现在该看看结果了。

恢复最新的检查点 (Restoring the Latest Checkpoint)

Before moving forward, we must restore the latest checkpoint available in order to load the latest version of the trained model before testing it on the images.

在继续之前,我们必须还原可用的最新检查点,以便在对图像进行测试之前加载经过训练的模型的最新版本。

测试输出 (Testing Outputs)

This randomly selects 5 images from the test_dataset and inputs them individually to the Generator Model. Now the model is trained well enough and predicts near-perfect colored versions of the input sketch images.

这将从test_dataset随机选择5张图像,并将它们分别输入到Generator模型。 现在,该模型已经过充分训练,并可以预测输入草图图像的近乎完美的彩色版本。

Image for post
Outputs of Black and White Sketches after going through the trained Generator Model (Image by Author)
经过训练有素的生成器模型后,黑白草图的输出(图片由作者提供)

保存模型 (Saving the Model)

Let’s not kill the model right after doing so much work, right?

我们不要在完成大量工作后立即终止模型,对吧?

A model shouldn’t end its life in a Jupyter Notebook!- Rightly said by Daniel Bourke

模型不应该在Jupyter笔记本电脑中终结!- 丹尼尔·伯克 ( Daniel Bourke)正确地说

It takes only a line of code to save the entire model as a .H5 file which is supported by Keras models.

仅需一行代码即可将整个模型另存为Keras模型支持的.H5文件。

结论 (CONCLUSION)

So, that is it!

就是这样!

We have not only seen how a Conditional GAN works but also have successfully implemented it to predict colored images from the given black and white input sketch images.

我们不仅看到了条件GAN的工作原理,而且还成功地实现了它,以根据给定的黑白输入草图图像来预测彩色图像。

You can go through the entire code and download it to see how it works on your system from my GitHub Repository.

您可以浏览整个代码,然后从GitHub Repository下载该代码以查看其在系统上的工作方式

If you face any problems, want to suggest some enhancements, or just want to leave a quick feedback, do not hesitate in contacting me through any medium that best suits you.

如果您遇到任何问题,想提出一些改进建议,或者只是想留下快速的反馈意见,请随时通过任何最适合您的媒介与我联系。

我的联系方式: (My Contact Information:)

LinkedIn: https://www.linkedin.com/in/tejasmorkar/GitHub: https://github.com/tejasmorkarTwitter: https://twitter.com/TejasMorkar

LinkedInhttps : //www.linkedin.com/in/tejasmorkar/ GitHubhttps : //github.com/tejasmorkar Twitterhttps : //twitter.com/TejasMorkar

翻译自: https://towardsdatascience.com/generative-adversarial-networks-gans-89ef35a60b69

生成对抗网络用于分类

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值