camvid数据集使用方法_使用GAN来处理类不平衡机器学习数据集问题

本文探讨了使用深度卷积生成对抗网络(DC-GAN)解决类不平衡机器学习数据集问题,特别是在camvid数据集上的应用。通过调整GAN的训练技巧,如避免最大池化层,使用批量归一化,调整学习率等,可以生成少数类的高质量图像,从而改善分类性能。此外,还介绍了如何定义和训练GAN模型,并展示了在糖尿病视网膜病变检测数据集上的应用示例。
摘要由CSDN通过智能技术生成
3e3844d2ab20bcad3305ba123140b03d.png

在实际的深度学习应用中,一个常见的问题是,一些类在训练集中的实例数量明显高于其他类。这种类不平衡数据集在不同的领域(如健康、银行、安全等)中很常见。对于这样的机器学习数据集,学习算法往往偏向于多数类,因此少数类实例的误分类率较高。

为了解决这一问题,需要采取过采样、过采样、两阶段训练和成本敏感学习等不同的策略。为少数类生成人工数据的方法构成了更通用的方法。这篇文章是关于使用深度卷积生成对抗网络(DC-GAN)来减少机器学习数据集中的这种不平衡,以提高分类性能。

在本文中,我们将讨论以下主题:

  1. GAN的一些提示和技巧。
  2. 如何定义GAN。
  3. GAN用例。

GAN的一些提示和技巧

开发用于生成图像的GAN需要使用判别器卷积神经网络模型来对给定图像是真实图像还是已生成的图像进行分类,并使用反卷积层将输入转换为像素值为二维的图像。

这些生成器和判别器模型在零和游戏中竞争。这意味着对一个模型的改进是以降低另一个模型的性能为代价的。结果是非常不稳定的训练过程,经常会导致失败。

一些技巧:

  1. 使用跨步卷积>>请勿使用最大池化层,而应使用卷积层中的stride在判别器模型中执行下采样。使用Conv2DTranspose和stride进行上采样。
  2. 删除全连接层>>判别器中不使用全连接层,而是将卷积层flattened并直接传递到输出层。
  3. 使用“批归一化” >>在鉴别器和生成器模型中,除了生成器的输出和鉴别器的输入外,都推荐使用Batch norm层。
  4. 使用ReLU, Leaky ReLU和Tanh >>ReLU只推荐用于生成器,但对于允许值小于零的ReLU的判别器变体,Leaky ReLU是首选。另外,生成器使用Tanh,判别器在输出层使用Sigmoid激活函数。
  5. 归一化输入>>归一化在-1到1之间的输入图像。为真实和伪造构造不同的mini-batches ,即每个mini-batch只需要包含所有真实图像或所有生成的图像。
  6. 学习率>>对判别器(1e-3)和生成器(1e-4)使用不同的学习率。两者都使用Adam优化器。
  7. 性能技巧>>训练判别器两次,生成器一次。在生成器中使用50%的dropout。
  8. 尽早跟踪故障>>判别器损失0.0是一种故障模式。如果生成器的损失稳步减少,则很可能用垃圾图像欺骗判别器。当训练顺利进行时,判别器损失的方差很小,并且随着时间的推移而下降。

如何定义GAN?

我们将使用DC-GAN为“Diabetic Retinopathy Detection(https://www.kaggle.com/c/diabetic-retinopathy-detection/overview)”机器学习数据集的第4类创建人工样本,该数据集有4类,其中类1有13000个样本,而类4只有600个样本。

导入所有必要的Python库。

import osimport tensorflow as tffrom keras.utils import plot_modelimport pydotimport graphvizimport numpy as np # linear algebrafrom sklearn.model_selection import train_test_splitimport pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)from tqdm import tqdmfrom numpy import expand_dims, zeros, ones, vstackfrom numpy.random import randn, randintfrom keras.optimizers import Adam, SGDfrom keras.models import Sequential from keras.layers import Dense, Reshape, Flatten, Conv2D, Conv2DTranspose, LeakyReLU, Dropout, BatchNormalization from matplotlib import pyplotfrom keras.preprocessing import imagefrom cv2 import cv2from PIL import Imageconfig = tf.ConfigProto()config.gpu_options.allow_growth = Truesess = tf.Session(config=config)
bcef381863685b23e0c99913f0007031.png

该机器学习数据集有几个压缩文件,我们需要将它们解压缩到包含相应图像的训练/测试文件夹中。训练图像的所有标签都在单独的csv文件中提供。

在下面的Python代码中,我们将读取一个包含标签和图像名称的csv文件。在继续前进之前,我们需要进行一些完整性检查(添加.jpeg扩展名,删除大小为0 KB的所有图像)。

在此机器学习数据集中,类3和4是少数类。我们将训练GAN为第4类生成图像。

#Read all the labels from the CSV filedf_csv = pd.read_csv('/storage/trainLabels.csv')df_csv['image'] = df_csv['image'].astype(str) + '.jpeg'## Delete all the images of size zero (0 KB), No need to do this step while rerunning the program #### There are multiple reasons for size zero dat
GAN(生成对抗网络)算法是一种用于生成数据的深度学习模型。它是由Generator(生成器)和Discriminator(判别器)两个子网络构成的。 当处理序列多平衡数据集时,GAN算法可以通过生成新的样本来平衡数据集。首先,生成器接收来自原始数据集中不平衡的样本作为输入。生成器学习生成新的样本,这些样本与原始数据集中的样本别相同,但具有更多的多样性。在这个过程中,生成器试图模仿原始数据集中的数据分布。 然后,判别器负责区分生成器生成的样本和原始数据集中的真实样本。判别器通过与生成器互动和学习来提高自己的性能。生成器和判别器通过不断迭代的对抗训练来提高彼此的能力。 在处理平衡数据集时,GAN算法可以生成更多的少数别样本,从而增加数据集中各别的数量平衡。通过生成样本,原始数据集的数量不再是严重不平衡的,这有助于提高分模型的性能和泛化能力。 然而,GAN算法也存在一些挑战。例如,生成的样本可能与真实样本之间存在明显的差距。此外,生成样本的质量和多样性可能受到生成器和判别器之间的平衡问题的限制。此外,GAN算法的训练可能需要更长的时间和更大的计算资源。 总而言之,GAN算法可以用于处理序列多平衡数据集。通过生成新的样本,它可以帮助平衡数据集中各别的数量,并提高分模型的性能。然而,这个算法仍然需要进一步的研究和改进,以解决其存在的挑战和限制。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值