本次的分享包含2个基于tf2.0的项目,一个是微信聊天机器人,一个是GAN。中间还发了件TF文化衫,大写的赞?。
wechat_bot:
why wechat:
Why Tf serving:
代码:
https://github.com/lemoz/tensorflow-serving_sidecar
protoc --python_out=./ tensorflow-serving_sidecar/object_detection/protos/string_int_label_map.proto
TF 2.0 实现对抗神经网络:
- 判别(Discriminative)模型
- 学习类别之间的差异/分界
- 参数数量较少,易于训练
- 只能用于分类,不能用于数据生成
- 生成(Generative)模型
- 学习类别特征的概率分布
- 学习到特征的概率分布,可以用于生成
- 参数数量很多,需要大量的数据样本
- 生成器(Generator):接受一个随机向量噪音 x 作为输入,生成一个张量 G(x)。
- 判别器(Discriminator):接受一个张量作为输入,输出其真实性。
以图像生成为列,整个 GAN 的训练过程如下:
1. 定义生成器模型。生成器接受随机输入,输出一个生成图像张量。
2. 定义判别器模型。判别器接受一张图像输入,输出一个代表图像真伪的张量。
3. 定义一个对抗模型。对抗模型接受随机输入,输出一个代表图像真伪的张量。对抗模型的网络层由生成器模型层和判别器模型层组成,其中判别器的层需要冻结。
4. 将一批随机噪音输入到生成器模型,生成一批图像。
5. 使用生成的图像和真实图像训练判别器。
6. 再使用新的随机输入训练对抗模型中的生成器,使其造假越来越逼真。
7. 重复步骤 4-7。
环境依赖:
pip install tensorflow==2.0.0a0
程序:
import tqdm
import numpy as np
from tensorflow.python import keras
from tensorflow.python.keras.models import Sequential
from tensorflow.python.keras.optimizers import Adam
L = keras.layers
LATENT_DIM = 100 # 潜在空间维度
IMAGE_SHAPE = (28, 28, 1) # 输出图像尺寸
generator_net = [
L.Input(shape=(LATENT_DIM, )),
L.Dense(256),
L.LeakyReLU(alpha=0.2),
L.BatchNormalization(momentum=0.8),
L.Dense(512),
L.LeakyReLU(alpha=0.2),
L.BatchNormalization(momentum=0.8),
L.Dense(1024),
L.LeakyReLU(alpha=0.2),
L.BatchNormalization(momentum=0.8),
L.Dense(np.prod(IMAGE_SHAPE), activation='tanh'),
L.Reshape(IMAGE_SHAPE),
]
generator = keras.models.Sequential(generator_net)
generator.summary()
discriminator_net = [
L.Input(shape=IMAGE_SHAPE),
L.Flatten(),
L.Dense(512),
L.LeakyReLU(alpha=0.2),
L.Dense(256),
L.LeakyReLU(alpha=0.2),
L.Dense(1, activation='sigmoid'),
]
optimizer = Adam(0.0002, 0.5)
discriminator = Sequential(discriminator_net)
discriminator.compile(loss='binary_crossentropy',
optimizer=optimizer,
metrics=['accuracy'])
discriminator.summary()
# 对抗模型使用生成器模型层和判别器模型层,它们共享权重
adversarial_net = generator_net + discriminator_net
# 冻结判别器的权重
# trainable 属性只有编译后生效,所以之前的判别器模型同样的层还是可以训练的
for layer in discriminator_net:
layer.trainable = False
adversarial = Sequential(adversarial_net)
# 编译对抗模型
optimizer = Adam(0.0002, 0.5)
adversarial.compile(loss='binary_crossentropy',
optimizer=optimizer,
metrics=['accuracy'])
adversarial.summary()
BATCH_COUNT = 30000
BATCH_SIZE = 32
# 读取数据集,我们只需要图像数据,不需要标签和测试数据。
(image_set, _), (_, _) = keras.datasets.mnist.load_data()
# 数据归一化
image_set = image_set / 127.5 - 1.
# 数据格式转换 [count, 28, 28] -> [count, 28, 28, 1]
image_set = image_set.reshape(len(image_set), 28, 28, 1)
# 准备 BATCH_SIZE 大小的真假数据标签
valid = np.ones((BATCH_SIZE))
fake = np.zeros((BATCH_SIZE))
import matplotlib.pyplot as plt
from IPython.display import clear_output, Image
def monitoring(sample_count=20):
plt.figure(figsize=(15,4))
images = image_set[np.random.randint(0, len(image_set), size=20)]
for i, image in enumerate(images):
image = np.reshape(image, [28, 28])
plt.subplot(2, 20/2, i+1)
plt.imshow(image, cmap='gray')
plt.axis('off')
plt.tight_layout()
plt.show()
plt.figure(figsize=(15,4))
noice_data = np.random.uniform(-1.0, 1.0, size=[sample_count, 100])
images = generator.predict(noice_data)
for i, image in enumerate(images):
image = np.reshape(image, [28, 28])
plt.subplot(2, sample_count/2, i+1)
plt.imshow(image, cmap='gray')
plt.axis('off')
plt.tight_layout()
plt.show()
for batch in tqdm.trange(BATCH_COUNT):
# ------ 生成器生成图像 ------
# 随机选择 BATCH_SIZE 数量的数据作为训练数据
idx = np.random.randint(0, image_set.shape[0], BATCH_SIZE)
imgs = image_set[idx]
# 生成噪音数据作为生成器输入
noise = np.random.normal(0, 1, (BATCH_SIZE, LATENT_DIM))
# 使用生成器生成生成图像
gen_imgs = generator.predict(noise)
# ------ 训练判别器 ------
# 使用真实图像和生成图像训练判别器,真实图像标签全部为 1,生成图像标签全部为 0
d_state_real = discriminator.train_on_batch(imgs, valid)
d_state_fake = discriminator.train_on_batch(gen_imgs, fake)
d_state = 0.5 * np.add(d_state_real, d_state_fake)
# ------ 训练生成器 ------
# 训练对抗模型,目标是生成判别器认为真实图像的图像,所以标签为 1
# 由于对抗模型中的判别器的层都冻结了,所以实际上在训练生成器,不断生成更加逼真的图像
adv_state = adversarial.train_on_batch(noise, valid)
# 更新进度条后缀,用于输出训练进度
if batch % 100 == 0:
clear_output(wait=True)
state = f"[D loss: {d_state[0]:.4f} acc: {d_state[1]:.4f}] " \
f"[A loss: {adv_state[0]:.4f} acc: {adv_state[1]:.4f}]"
print(f'batch {batch}')
print(state)
monitoring()
实验结果:
条件生成对抗网络(Conditional GANs):
深度卷积对抗生成网络(DCGAN):
深度卷积条件对抗生成网络(DCCGAN)
References: