什么是GAN网络?
GAN(Generative Adversarial Networks)的初衷就是生成不存在于真实世界的数据,类似于使得 AI具有创造力或者想象力。应用场景如下:
- AI作家,AI画家等需要创造力的AI体;
- 将模糊图变清晰(去雨,去雾,去抖动,去马赛克等),这需要AI具有所谓的“想象力”,能脑补情节;
- 进行数据增强,根据已有数据生成更多新数据供以feed,可以减缓模型过拟合现象。
GAN网络有很多变形,下面主要介绍其中一种常用变形,ACGAN
那什么是ACGAN呢?
在计算机视觉里面,我们用的比较多的就是分类了,而训练分类的前提是收集足够多的各种分类的数据用来训练,这也是我们比较头疼的一个步骤,数据来源少无法训练怎么办?
ACGAN的一个用途就是用来生成多分类增强数据,只要你有每种分类数据大概2000张以上就能进行训练并生成指定分类的数据,下面是它的一个原理图:
如上图所示,ACGAN相对于GAN不同点在于:
- GAN只有Z即噪声作为输入变量,ACGAN多了一个分类变量
- GAN输出只有该图片真假判断,而ACGAN除了真假外增加了类别判断
所以表面上看还是挺简单的,不过深究到细节就有几点注意了:
1、针对输入多出的类别要怎样和噪声进行融合呢?我们首先想到的应该就是把类别和噪声进行连接成为新数组对吧,这种自己试过之后效果并不是特别好,因为该情况下类别无法深入影响到每个噪声变量。看一下下面这种方式:
def build_generator(self):
# generator负责生成图片,所以卷积过程是从小到大
model = Sequential()
row_shape = int(self.img_cols / 8)
col_shape = int(self.img_rows / 8)
# model.add(Dense(self.latent_dim * 8 * row_shape * col_shape, activation="relu", input_dim=self.latent_dim))
# model.add(Reshape((row_shape, col_shape, self.latent_dim * 8)))
# model.add(LeakyReLU(alpha=0.2))
# model.add(Conv2DTranspose(self.latent_dim * 4, 3, strides=2, padding='same'))
# model.add(LeakyReLU(alpha=0.2))
# model.add(Conv2DTranspose(self.latent_dim * 2, 3, strides=2, padding='same'))
# model.add(LeakyReLU(alpha=0.2))
model.add(Dense(256 * row_shape * col_shape, activation="relu", input_dim=self.latent_dim))
model.add(Reshape((row_shape, col_shape, 256)))
model.add(BatchNormalization(momentum=0.8))
model.add(Conv2DTranspose(128, kernel_size=3, strides=2, padding="same"))
model.add(Activation("relu"))
model.add(BatchNormalization(momentum=0.8))
model.add(Conv2DTranspose(64, kernel_size=3, strides=2, padding="same"))
model.add(Activation("relu"))
model.add(BatchNormalization(momentum=0.8))
model.add(Conv2DTranspose(self.channels, kernel_size=3, strides=2, padding='same'))
model.add(Activation("tanh"))
model.summary()
noise = Input(shape=(self.latent_dim,))
label = Input(shape=(1,), dtype='int32')
label_embedding = Flatten()(Embedding(self.num_classes, 100)(label))
model_input = multiply([noise, label_embedding])
img = model(model_input)
return Model([noise, label], img)
上面这种方式是通过keras提供的Embedding层进行融合,Embedding层可以看做和one-hot类似的形式,但是它比one-hot优势体现在Embedding生成的变量并不是指定位置为1,其他为0的形式,而是每个位置的值都是一个浮点数,简单来说Embedding其实相当于一个神经网络层,把输入映射到多维空间,这样做的好处是空间特征更加丰富
回到上面的例子,利用Embedding输出和噪声进行相乘能更好将类别信息融合到噪声里面,经过测试相同的代码和数据,采用Embedding结构的网络生成的类别会更准确一些。
2、针对输出增加了类别的判断,网络结构上面也改变了一些,首先是D网络的输出:
def build_discriminator(self):
model = Sequential()
model.add(Conv2D(16, kernel_size=3, strides=2, input_shape=self.img_shape, padding="same"))
model.add(LeakyReLU(alpha=0.2))
model.add(Dropout(0.25))
model.add(Conv2D(32, kernel_size=3, strides=2, padding="same"))
model.add(ZeroPadding2D(padding=((0, 1), (0, 1))))
model.add(LeakyReLU(alpha=0.2))
model.add(Dropout(0.25))
model.add(BatchNormalization(momentum=0.8))
model.add(Conv2D(64, kernel_size=3, strides=2, padding="same"))
model.add(LeakyReLU(alpha=0.2))
model.add(Dropout(0.25))
model.add(BatchNormalization(momentum=0.8))
model.add(Conv2D(128, kernel_size=3, strides=1, padding="same"))
model.add(LeakyReLU(alpha=0.2))
model.add(Dropout(0.25))
model.add(Flatten())
model.summary()
img = Input(shape=self.img_shape)
# Extract feature representation
features = model(img)
# Determine validity and label of the image
validity = Dense(1, activation="sigmoid")(features)
label = Dense(self.num_classes, activation="softmax")(features)
return Model(img, [validity, label])
可以看到原先GAN网络输出只有validity,现在多了一个label,不过也不是特别复杂,只是将原先的最后一层网络分别映射成真假和类别两个输出。
3、增加了输出,损失函数的更改更为重要:
optimizer = Adam(0.001, 0.5)
losses = ['binary_crossentropy', 'sparse_categorical_crossentropy']
# Build and compile the discriminator
self.discriminator = self.build_discriminator()
self.discriminator.compile(loss=losses, optimizer=optimizer, metrics=['accuracy'])
原先GAN网络只要判断真假,所以选择binary这种多标签损失函数就行了,ACGAN增加了类别,所以还得加入categorial这种多类别损失函数,两个损失函数分别对应之前的两个输出,两个加起来的结果就是总的损失函数
介绍完ACGAN基本结构,我们来看一下它实际的效果吧,如下所示,我们训练下面四种枪支:
- AKM(2500张训练数据):
- Kar98K(1800张训练数据):
- AWM(600张训练数据):
- M416(5000张训练数据):
经过一番数据收集,准备好了上面标注的数据量,使用前面提到的ACGAN网络进行训练,下面是训练结果:
1)epoch0:
2) epoch2:
3) epoch23:
可以看到逐步有效果出来了,注意训练数据的图片宽高比尽量保持一致,这样输出的比较符合真实比例,还有一点就是没必要增加太多的数据增强,增加一点训练数据效果还来的更好一些,再来看最后一张,Epoch45:
不知道大家注意到了没有,不管训练多少回,第二种类别即Kar98K这一列偶尔还是会出现不对的类别,第三种即AWM更糟糕,基本很少是对的类别,但是第一种和最后一种却非常稳定,几乎没有出现错的类别,这是什么原因呢?
带着疑问大家可以再回顾一下训练前准备的训练数据情况,这时候应该就能发现有个规律,训练数据量越多的生成的效果也就越准确,而且有个分界点,就是数据量低于2000的明显效果不太好,高于2000的会正常一些。
正如刚开始介绍ACGAN作用的时候也说过,ACGAN的一个作用就是弥补数据的不足,所以正常情况下也不会有太大的数据作为ACGAN的训练数据,不然就失去意义了,但是一个底线是起码每个分类的训练数据不低于2000张,否则
- 没有足够的数据D网络无法捕获更多的特征
- D网络没有足够特征反馈给G网络,G网络训练时针对该种类别的权重也就比较小了,所以可以看到第二第三种类别出错的时候基本都是第四种枪支的样子
总结
总的来说,ACGAN还是相当有意思的一种网络,它让计算机拥有了能够模仿现有数据从而生成独特数据的能力,用途也相当广泛,如果用来做数据增强的话,建议每种类别训练数据量尽量和上面的AKM保持一致,这样训练效果才不会太差,而且2000多张数据正常还是可以收集到的,当然个人比较看好的是它的“想象力”,希望后面能开发出变形以及用途。