Overview
The GAN from the previous post had a few notable problems:
- The output quality was unsatisfactory
- The digits were not generated with equal probability; 0 showed up more often than the rest
- There was no way to control which digit gets generated
To fix these problems, I'll improve the network by turning it into a conditional GAN.
Add 10 extra inputs to the input layer of both the discriminator and the generator, one for each digit 0-9 (the label as a one-hot vector); a minimal sketch of this trick follows below.
Everything else stays unchanged.
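The conditioning trick is just concatenation. A minimal sketch (my own variable names, not from the code below):

    import mxnet as mx

    label = mx.ndarray.array([3])              # say we want a "3"
    noise = mx.ndarray.normal(shape=(1, 100))  # the usual random seed
    # generator input = 100 noise dims + 10 one-hot label dims
    cond = mx.ndarray.concat(noise, label.one_hot(10), dim=1)
    print(cond.shape)                          # (1, 110)

The discriminator gets the same treatment: its input is the 784 flattened pixels plus the same 10 one-hot values, so both networks always know which digit they are supposed to be handling.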
Code
Load the third-party libraries
In [1]: import time
...: import gzip
...: import numpy as np
...: import matplotlib.pyplot as plt
...: import mxnet as mx
Define two helper functions
In [2]: def try_gpu():
   ...:     # probe the GPU; fall back to the CPU if MXNet can't allocate on it
   ...:     try:
   ...:         ctx = mx.gpu()
   ...:         _ = mx.ndarray.zeros((1,), ctx=ctx)
   ...:     except mx.base.MXNetError:
   ...:         ctx = mx.cpu()
   ...:     return ctx
   ...:
   ...: def make_seed(batch, num, ctx=None):
   ...:     # batch rows of num-dimensional standard-normal noise
   ...:     if ctx is None:
   ...:         ctx = mx.current_context()
   ...:     return mx.ndarray.normal(loc=0, scale=1, shape=(batch, num), ctx=ctx)
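A quick sanity check of the helpers (my own snippet; the shapes are what the rest of the code relies on):

    ctx = try_gpu()                           # mx.gpu() if usable, else mx.cpu()
    print(make_seed(4, 100, ctx=ctx).shape)   # (4, 100)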
Load the training data
In [3]: def load_dataset():
   ...:     transform = mx.gluon.data.vision.transforms.ToTensor()
   ...:
   ...:     train_set = mx.gluon.data.vision.datasets.MNIST(root='mnist', train=True )
   ...:     eval_set  = mx.gluon.data.vision.datasets.MNIST(root='mnist', train=False)
   ...:
   ...:     # ToTensor yields float32 arrays of shape (1, 28, 28) scaled to [0, 1]
   ...:     train_img = [ transform(img).asnumpy() for img, lbl in train_set ]
   ...:     train_lbl = [ np.array(lbl)            for img, lbl in train_set ]
   ...:     eval_img  = [ transform(img).asnumpy() for img, lbl in eval_set  ]
   ...:     eval_lbl  = [ np.array(lbl)            for img, lbl in eval_set  ]
   ...:
   ...:     return train_img, train_lbl, eval_img, eval_lbl
   ...:
   ...: train_img, train_lbl, eval_img, eval_lbl = load_dataset()
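What comes back (my check, relying on MNIST's standard split sizes):

    print(len(train_img), len(eval_img))   # 60000 10000
    print(train_img[0].shape)              # (1, 28, 28)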
Preview a few images
In [4]: idxs = (25, 47, 74, 88, 92)
   ...: for i in range(5):
   ...:     plt.subplot(1, 5, i + 1)
   ...:     idx = idxs[i]
   ...:     plt.xticks([])
   ...:     plt.yticks([])
   ...:     img = train_img[idx][0].astype( np.float32 )
   ...:     plt.imshow(img, interpolation='none', cmap='Blues')
   ...: plt.show()
Define the discriminator
In [5]: class Discriminator():
   ...:     def __init__(self):
   ...:         self.net = mx.gluon.nn.HybridSequential()
   ...:         # layer 1: 784 flattened pixels + 10 one-hot label inputs
   ...:         self.net.add(
   ...:             mx.gluon.nn.Dense(in_units=28*28+10, units=200),
   ...:             mx.gluon.nn.LeakyReLU(alpha=0.02),
   ...:             mx.gluon.nn.LayerNorm()
   ...:         )
   ...:         # layer 2: a single real/fake logit (the loss applies the sigmoid)
   ...:         self.net.add(
   ...:             mx.gluon.nn.Dense(units=1),
   ...:         )
   ...:
   ...:         self.net.initialize(init=mx.init.Xavier(rnd_type='gaussian'))
   ...:         self.trainer = mx.gluon.Trainer(
   ...:             params=self.net.collect_params(),
   ...:             optimizer=mx.optimizer.Adam(learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-08, lazy_update=True)
   ...:         )
   ...:
   ...: Discriminator().net.summary(mx.ndarray.zeros(shape=(50, 28*28+10)))
Out[5]:
--------------------------------------------------------------------------------
        Layer (type)                            Output Shape          Param #
================================================================================
               Input                               (50, 794)                0
             Dense-1                               (50, 200)           159000
         LeakyReLU-2                               (50, 200)                0
         LayerNorm-3                               (50, 200)              400
             Dense-4                                 (50, 1)              201
================================================================================
Parameters in forward computation graph, duplicate included
   Total params: 159601
   Trainable params: 159601
   Non-trainable params: 0
Shared params in forward computation graph: 0
Unique parameters in model: 159601
--------------------------------------------------------------------------------
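A hedged smoke test of the untrained discriminator (my snippet): one flattened image plus its one-hot label in, one logit out.

    d = Discriminator()
    img = mx.ndarray.array(train_img[0]).flatten()            # (1, 784)
    lbl = mx.ndarray.array([int(train_lbl[0])]).one_hot(10)   # (1, 10)
    print(d.net(mx.ndarray.concat(img, lbl, dim=1)).shape)    # (1, 1)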
Define the generator
In [6]: class Generator():
   ...:     def __init__(self):
   ...:         self.net = mx.gluon.nn.HybridSequential()
   ...:         # layer 1: 100 noise dims + 10 one-hot label inputs
   ...:         self.net.add(
   ...:             mx.gluon.nn.Dense(in_units=100+10, units=200),
   ...:             mx.gluon.nn.LeakyReLU(alpha=0.02),
   ...:             mx.gluon.nn.LayerNorm(),
   ...:         )
   ...:         # layer 2: 784 pixels in [0, 1], reshaped to (batch, 1, 28, 28)
   ...:         self.net.add(
   ...:             mx.gluon.nn.Dense(units=784),
   ...:             mx.gluon.nn.Activation(activation='sigmoid'),
   ...:             mx.gluon.nn.HybridLambda(lambda F, x: F.reshape(x, shape=(0, -1, 28, 28)))
   ...:         )
   ...:
   ...:         self.net.initialize(init=mx.init.Xavier(rnd_type='gaussian'))
   ...:         self.trainer = mx.gluon.Trainer(
   ...:             params=self.net.collect_params(),
   ...:             optimizer=mx.optimizer.Adam(learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-08, lazy_update=True)
   ...:         )
   ...:
   ...: Generator().net.summary(mx.ndarray.zeros(shape=(50, 100+10)))
Out[6]:
--------------------------------------------------------------------------------
        Layer (type)                            Output Shape          Param #
================================================================================
               Input                               (50, 110)                0
             Dense-1                               (50, 200)            22200
         LeakyReLU-2                               (50, 200)                0
         LayerNorm-3                               (50, 200)              400
             Dense-4                               (50, 784)           157584
        Activation-5                               (50, 784)                0
      HybridLambda-6                          (50, 1, 28, 28)                0
================================================================================
Parameters in forward computation graph, duplicate included
   Total params: 180184
   Trainable params: 180184
   Non-trainable params: 0
Shared params in forward computation graph: 0
Unique parameters in model: 180184
--------------------------------------------------------------------------------
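And the matching generator smoke test (my snippet): a 110-dim conditioned seed in, one 28x28 image out.

    g = Generator()
    seed = mx.ndarray.concat(make_seed(1, 100), mx.ndarray.array([7]).one_hot(10), dim=1)
    print(g.net(seed).shape)   # (1, 1, 28, 28)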
Define the training function
In [7]: def train_gan(data_set, batch_size, d_net, d_trainer, g_net, g_trainer, loss_fn, epoch):
   ...:     train_data = mx.gluon.data.DataLoader(data_set, batch_size=batch_size, shuffle=True)
   ...:     d_metric = mx.metric.Loss()
   ...:     g_metric = mx.metric.Loss()
   ...:
   ...:     for epoch in range(1, epoch + 1):
   ...:         d_metric.reset(); g_metric.reset(); tic=time.time()
   ...:         for datas, labels in train_data:
   ...:             # actual batch size (the last batch may be smaller)
   ...:             actually_batch_size = datas.shape[0]
   ...:             # move the batch from CPU memory to the compute context
   ...:             datas = mx.gluon.utils.split_and_load(datas, [mx.current_context()])
   ...:             labels = mx.gluon.utils.split_and_load(labels, [mx.current_context()])
   ...:             # build the seeds: 100 random dims + 10 one-hot label dims
   ...:             seeds = [mx.ndarray.concat(make_seed(label.shape[0], 100), label.one_hot(10)) for label in labels]
   ...:
   ...:             # train the discriminator: real images -> 1, generated images -> 0
   ...:             for data, label, seed in zip(datas, labels, seeds):
   ...:                 lbl_real = mx.ndarray.ones(shape=(seed.shape[0],1))
   ...:                 lbl_fake = mx.ndarray.zeros(shape=(seed.shape[0],1))
   ...:                 data = mx.ndarray.concat(data.flatten(), label.one_hot(10))
   ...:                 img = mx.ndarray.concat(g_net(seed).flatten(), label.one_hot(10))
   ...:
   ...:                 with mx.autograd.record():
   ...:                     alpha = loss_fn(d_net(data), lbl_real)
   ...:                     beta = loss_fn(d_net(img), lbl_fake)
   ...:                     d_loss = alpha + beta
   ...:                 d_loss.backward()
   ...:                 d_metric.update(None, preds=d_loss)
   ...:             d_trainer.step(actually_batch_size)
   ...:
   ...:             # train the generator: make the discriminator call its output real
   ...:             for label, seed in zip(labels, seeds):
   ...:                 lbl_real = mx.ndarray.ones(shape=(seed.shape[0],1))
   ...:
   ...:                 with mx.autograd.record():
   ...:                     img = g_net(seed)
   ...:                     data = mx.ndarray.concat(img.flatten(), label.one_hot(10))
   ...:                     g_loss = loss_fn(d_net(data), lbl_real)
   ...:                 g_loss.backward()
   ...:                 g_metric.update(None, preds=g_loss)
   ...:             g_trainer.step(actually_batch_size)
   ...:
   ...:         print("Epoch {:>2d}: cost:{:.1f}s d_loss:{:.3f} g_loss:{:.3f}".format(epoch, time.time()-tic, d_metric.get()[1], g_metric.get()[1]))
   ...:
   ...:     # show the result: one generated sample for each digit 0-9
   ...:     label = mx.ndarray.arange(10)
   ...:     seed = mx.ndarray.concat(make_seed(len(label), 100), label.one_hot(10))
   ...:     output = g_net(seed)
   ...:     img = output[0][0]
   ...:     for i in range(1,10):
   ...:         img = mx.ndarray.concat(img, output[i][0])
   ...:     plt.xticks([]); plt.yticks([])
   ...:     plt.imshow(img.asnumpy(), interpolation='none', cmap='Blues')
   ...:     plt.show()
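One detail worth calling out: the discriminator's final Dense layer emits a raw logit with no sigmoid, because SigmoidBinaryCrossEntropyLoss applies the sigmoid itself (its from_sigmoid flag defaults to False). A minimal check (my snippet):

    loss_fn = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss()
    logit = mx.ndarray.zeros((1, 1))      # sigmoid(0) = 0.5
    target = mx.ndarray.ones((1, 1))
    print(loss_fn(logit, target))         # ~0.693, i.e. -log(0.5)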
Train
In [8]: with mx.cpu(0):
   ...:     discriminator = Discriminator()
   ...:     generator = Generator()
   ...:     train_gan(
   ...:         data_set = mx.gluon.data.ArrayDataset(train_img + eval_img, train_lbl + eval_lbl),
   ...:         batch_size = 10,
   ...:         d_net = discriminator.net,
   ...:         d_trainer = discriminator.trainer,
   ...:         g_net = generator.net,
   ...:         g_trainer = generator.trainer,
   ...:         loss_fn = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss(),
   ...:         epoch = 120,
   ...:     )
Out[8]: (per-epoch loss log for 120 epochs, then a strip of generated digits 0-9; output omitted)
Save the parameters
In [9]: discriminator.net.save_parameters("CondGan-dis-params")
...: generator.net.save_parameters("CondGan-gan-params")
Generate images
In [10]: n = Generator().net
    ...: n.load_parameters("CondGan-gan-params")
    ...:
    ...: for i in range(5):
    ...:     label = mx.ndarray.arange(10)
    ...:     seed = make_seed(1, 100)
    ...:     # repeat the same 100-dim seed 10 times, once per digit
    ...:     seed = mx.ndarray.concat(seed, seed, seed, seed, seed, dim = 0)
    ...:     seed = mx.ndarray.concat(seed, seed, dim = 0)
    ...:     seed = mx.ndarray.concat(seed, label.one_hot(10))
    ...:     output = n(seed)
    ...:     img = output[0][0]
    ...:     for j in range(1, 10):
    ...:         img = mx.ndarray.concat(img, output[j][0])
    ...:     plt.xticks([]); plt.yticks([])
    ...:     plt.imshow(img.asnumpy(), interpolation='none', cmap='Blues')
    ...:     plt.show()
Out[10]: (five strips of digits 0-9, each strip generated from one shared random seed; images omitted)
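The payoff of conditioning is control. A hedged sketch (my own code, reusing the loaded generator n from above): generate one specific digit by swapping in its one-hot vector.

    digit = 7                                  # any digit 0-9
    seed = mx.ndarray.concat(
        make_seed(1, 100),
        mx.ndarray.array([digit]).one_hot(10),
        dim=1,
    )
    img = n(seed)[0][0]                        # (28, 28)
    plt.imshow(img.asnumpy(), interpolation='none', cmap='Blues')
    plt.show()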
Evaluation
The discriminator's and generator's loss curves are shown in the figure.
The trend is much the same as the unconditional GAN's:
- Early on, both the discriminator and the generator improve quickly
- In the middle phase, both learn more slowly, but learning still works
- Late in training, the discriminator slowly starts to pull ahead
Still, this is a clear improvement over the unconditional GAN:
- d-loss stays above the ideal value of 0.693 (ln 2) throughout
- g-loss stays below 2 throughout