在实际的深度学习应用中,一个常见的问题是,一些类在训练集中的实例数量明显高于其他类。这种类不平衡数据集在不同的领域(如健康、银行、安全等)中很常见。对于这样的机器学习数据集,学习算法往往偏向于多数类,因此少数类实例的误分类率较高。
为了解决这一问题,需要采取过采样、过采样、两阶段训练和成本敏感学习等不同的策略。为少数类生成人工数据的方法构成了更通用的方法。这篇文章是关于使用深度卷积生成对抗网络(DC-GAN)来减少机器学习数据集中的这种不平衡,以提高分类性能。
在本文中,我们将讨论以下主题:
- GAN的一些提示和技巧。
- 如何定义GAN。
- GAN用例。
GAN的一些提示和技巧
开发用于生成图像的GAN需要使用判别器卷积神经网络模型来对给定图像是真实图像还是已生成的图像进行分类,并使用反卷积层将输入转换为像素值为二维的图像。
这些生成器和判别器模型在零和游戏中竞争。这意味着对一个模型的改进是以降低另一个模型的性能为代价的。结果是非常不稳定的训练过程,经常会导致失败。
一些技巧:
- 使用跨步卷积>>请勿使用最大池化层,而应使用卷积层中的stride在判别器模型中执行下采样。使用Conv2DTranspose和stride进行上采样。
- 删除全连接层>>判别器中不使用全连接层,而是将卷积层flattened并直接传递到输出层。
- 使用“批归一化” >>在鉴别器和生成器模型中,除了生成器的输出和鉴别器的输入外,都推荐使用Batch norm层。
- 使用ReLU, Leaky ReLU和Tanh >>ReLU只推荐用于生成器,但对于允许值小于零的ReLU的判别器变体,Leaky ReLU是首选。另外,生成器使用Tanh,判别器在输出层使用Sigmoid激活函数。
- 归一化输入>>归一化在-1到1之间的输入图像。为真实和伪造构造不同的mini-batches ,即每个mini-batch只需要包含所有真实图像或所有生成的图像。
- 学习率>>对判别器(1e-3)和生成器(1e-4)使用不同的学习率。两者都使用Adam优化器。
- 性能技巧>>训练判别器两次,生成器一次。在生成器中使用50%的dropout。
- 尽早跟踪故障>>判别器损失0.0是一种故障模式。如果生成器的损失稳步减少,则很可能用垃圾图像欺骗判别器。当训练顺利进行时,判别器损失的方差很小,并且随着时间的推移而下降。
如何定义GAN?
我们将使用DC-GAN为“Diabetic Retinopathy Detection(https://www.kaggle.com/c/diabetic-retinopathy-detection/overview)”机器学习数据集的第4类创建人工样本,该数据集有4类,其中类1有13000个样本,而类4只有600个样本。
导入所有必要的Python库。
import osimport tensorflow as tffrom keras.utils import plot_modelimport pydotimport graphvizimport numpy as np # linear algebrafrom sklearn.model_selection import train_test_splitimport pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)from tqdm import tqdmfrom numpy import expand_dims, zeros, ones, vstackfrom numpy.random import randn, randintfrom keras.optimizers import Adam, SGDfrom keras.models import Sequential from keras.layers import Dense, Reshape, Flatten, Conv2D, Conv2DTranspose, LeakyReLU, Dropout, BatchNormalization from matplotlib import pyplotfrom keras.preprocessing import imagefrom cv2 import cv2from PIL import Imageconfig = tf.ConfigProto()config.gpu_options.allow_growth = Truesess = tf.Session(config=config)
该机器学习数据集有几个压缩文件,我们需要将它们解压缩到包含相应图像的训练/测试文件夹中。训练图像的所有标签都在单独的csv文件中提供。
在下面的Python代码中,我们将读取一个包含标签和图像名称的csv文件。在继续前进之前,我们需要进行一些完整性检查(添加.jpeg扩展名,删除大小为0 KB的所有图像)。
在此机器学习数据集中,类3和4是少数类。我们将训练GAN为第4类生成图像。
#Read all the labels from the CSV filedf_csv = pd.read_csv('/storage/trainLabels.csv')df_csv['image'] = df_csv['image'].astype(str) + '.jpeg'## Delete all the images of size zero (0 KB), No need to do this step while rerunning the program #### There are multiple reasons for size zero dat