MXNet Network Models (Part 5): Conditional GAN

Overview

The adversarial network from the previous post had several noticeable problems:

  1. The output quality was unsatisfactory
  2. The digits were not generated evenly; 0 appeared more often than the others
  3. There was no way to control which digit was generated

To address these problems, we improve the network into a conditional GAN:
add 10 extra inputs, one for each digit 0-9, to the input layers of both the
discriminator and the generator. Everything else stays unchanged.
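The conditioning idea can be sketched in plain numpy (the `one_hot` helper below is hypothetical, standing in for what `mx.ndarray.one_hot` does):

```python
import numpy as np

# Hypothetical helper mirroring mx.ndarray.one_hot: label -> one-hot row
def one_hot(labels, depth=10):
    out = np.zeros((len(labels), depth), dtype=np.float32)
    out[np.arange(len(labels)), labels] = 1.0
    return out

labels = np.array([3, 7])                                    # we want a "3" and a "7"
noise = np.random.normal(size=(2, 100)).astype(np.float32)   # the usual GAN noise
# Generator input = [noise | one-hot label], 100 + 10 = 110 values per sample
g_input = np.concatenate([noise, one_hot(labels)], axis=1)
```

The label bits tell the generator which digit to draw, and tell the discriminator which digit a real image is supposed to be.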

Code

Load third-party libraries

In [1]: import time
   ...: import gzip
   ...: import numpy as np
   ...: import matplotlib.pyplot as plt
   ...: import mxnet as mx

Define two utility functions

In [2]: def try_gpu():
   ...:     try:
   ...:         ctx = mx.gpu()
   ...:         _ = mx.ndarray.zeros((1,), ctx=ctx)
   ...:     except mx.base.MXNetError:
   ...:         ctx = mx.cpu()
   ...:     return ctx
   ...: 
   ...: def make_seed(batch, num, ctx=None):
   ...:     if ctx is None:
   ...:         ctx = mx.current_context()
   ...:     return mx.ndarray.normal(loc=0, scale=1, shape=(batch, num), ctx=ctx)
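As a sanity check, a plain-numpy stand-in for `make_seed` shows the shape and distribution it produces (`make_seed_np` is a name invented for this sketch):

```python
import numpy as np

# numpy stand-in for make_seed: standard-normal noise of shape (batch, num)
def make_seed_np(batch, num):
    return np.random.normal(loc=0, scale=1, size=(batch, num)).astype(np.float32)

seed = make_seed_np(50, 100)   # 50 noise vectors of 100 values each
```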

Load the training data

In [3]: def load_dataset():
   ...:     transform = mx.gluon.data.vision.transforms.ToTensor()
   ...: 
   ...:     train_img = [ transform(img).asnumpy() for img, lbl in mx.gluon.data.vision.datasets.MNIST(root='mnist', train=True )]
   ...:     train_lbl = [ np.array(lbl)            for img, lbl in mx.gluon.data.vision.datasets.MNIST(root='mnist', train=True )]
   ...:     eval_img  = [ transform(img).asnumpy() for img, lbl in mx.gluon.data.vision.datasets.MNIST(root='mnist', train=False)]
   ...:     eval_lbl  = [ np.array(lbl)            for img, lbl in mx.gluon.data.vision.datasets.MNIST(root='mnist', train=False)]
   ...: 
   ...:     return train_img, train_lbl, eval_img, eval_lbl
   ...: 
   ...: train_img, train_lbl, eval_img, eval_lbl = load_dataset()
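For reference, `ToTensor` turns a uint8 HxWxC image in [0, 255] into a float32 CxHxW tensor in [0, 1]; a minimal numpy sketch of that behaviour:

```python
import numpy as np

# numpy sketch of what transforms.ToTensor does to a single image:
# uint8 HxWxC in [0, 255] -> float32 CxHxW in [0, 1]
def to_tensor(img):
    return img.astype(np.float32).transpose(2, 0, 1) / 255.0

img = np.full((28, 28, 1), 255, dtype=np.uint8)  # dummy all-white image
tensor = to_tensor(img)
```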

Preview images

In [4]: idxs = (25, 47, 74, 88, 92)
   ...: for i in range(5):
   ...:     plt.subplot(1, 5, i + 1)
   ...:     idx = idxs[i]
   ...:     plt.xticks([])
   ...:     plt.yticks([])
   ...:     img = train_img[idx][0].astype( np.float32 )
   ...:     plt.imshow(img, interpolation='none', cmap='Blues')
   ...: plt.show()

(Figure: five sample MNIST training digits)

Define the discriminator

In [5]: class Discriminator():
   ...:     def __init__(self):     
   ...:         self.net = mx.gluon.nn.HybridSequential()
   ...:         # first layer
   ...:         self.net.add(
   ...:             mx.gluon.nn.Dense(in_units=28*28+10, units=200),
   ...:             mx.gluon.nn.LeakyReLU(alpha=0.02),
   ...:             mx.gluon.nn.LayerNorm()
   ...:         )
   ...:         # second layer
   ...:         self.net.add(
   ...:             mx.gluon.nn.Dense(units=1),
   ...:         )
   ...:             
   ...:         self.net.initialize(init=mx.init.Xavier(rnd_type='gaussian'))
   ...:         self.trainer = mx.gluon.Trainer(
   ...:             params=self.net.collect_params(),
   ...:             optimizer=mx.optimizer.Adam(learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-08, lazy_update=True)
   ...:         )
   ...: 
   ...: Discriminator().net.summary(mx.ndarray.zeros(shape=(50, 28*28+10)))
Out[5]:
--------------------------------------------------------------------------------
        Layer (type)                                Output Shape         Param #
================================================================================
               Input                                   (50, 794)               0
             Dense-1                                   (50, 200)          159000
         LeakyReLU-2                                   (50, 200)               0
         LayerNorm-3                                   (50, 200)             400
             Dense-4                                     (50, 1)             201
================================================================================
Parameters in forward computation graph, duplicate included
   Total params: 159601
   Trainable params: 159601
   Non-trainable params: 0
Shared params in forward computation graph: 0
Unique parameters in model: 159601
--------------------------------------------------------------------------------
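The `in_units=28*28+10` above comes from flattening the image and appending the one-hot label; a numpy sketch of how such an input would be assembled:

```python
import numpy as np

# Discriminator input = flattened 28x28 image + one-hot label:
# 784 + 10 = 794 values, matching in_units=28*28+10 in the network above.
batch = 4
images = np.random.rand(batch, 1, 28, 28).astype(np.float32)
labels = np.array([0, 1, 2, 3])
onehot = np.eye(10, dtype=np.float32)[labels]      # one-hot rows for each label
d_input = np.concatenate([images.reshape(batch, -1), onehot], axis=1)
```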

Define the generator

In [6]: class Generator():
   ...:     def __init__(self):
   ...:         self.net = mx.gluon.nn.HybridSequential()
   ...:         # first layer
   ...:         self.net.add(
   ...:             mx.gluon.nn.Dense(in_units=100+10, units=200),
   ...:             mx.gluon.nn.LeakyReLU(alpha=0.02),
   ...:             mx.gluon.nn.LayerNorm(),
   ...:         )
   ...:         # second layer
   ...:         self.net.add(
   ...:             mx.gluon.nn.Dense(units=784),
   ...:             mx.gluon.nn.Activation(activation='sigmoid'),
   ...:             mx.gluon.nn.HybridLambda(lambda F, x: F.reshape(x, shape=(0, -1, 28, 28)))
   ...:         )
   ...:             
   ...:         self.net.initialize(init=mx.init.Xavier(rnd_type='gaussian'))
   ...:         self.trainer = mx.gluon.Trainer(
   ...:             params=self.net.collect_params(),
   ...:             optimizer=mx.optimizer.Adam(learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-08, lazy_update=True)
   ...:         )
   ...: 
   ...: Generator().net.summary(mx.ndarray.zeros(shape=(50, 100+10)))
Out[6]:
--------------------------------------------------------------------------------
        Layer (type)                                Output Shape         Param #
================================================================================
               Input                                   (50, 110)               0
             Dense-1                                   (50, 200)           22200
         LeakyReLU-2                                   (50, 200)               0
         LayerNorm-3                                   (50, 200)             400
             Dense-4                                   (50, 784)          157584
        Activation-5                                   (50, 784)               0
      HybridLambda-6                             (50, 1, 28, 28)               0
================================================================================
Parameters in forward computation graph, duplicate included
   Total params: 180184
   Trainable params: 180184
   Non-trainable params: 0
Shared params in forward computation graph: 0
Unique parameters in model: 180184
--------------------------------------------------------------------------------
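The parameter counts in the summary can be verified by hand:

```python
# Where the 180184 parameters in the summary come from:
dense1 = (100 + 10) * 200 + 200   # Dense-1 weights + bias
norm = 2 * 200                    # LayerNorm gamma + beta
dense2 = 200 * 784 + 784          # Dense-4 weights + bias
total = dense1 + norm + dense2
```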

Define the training function

In [7]: def train_gan(data_set, batch_size, d_net, d_trainer, g_net, g_trainer, loss_fn, epoch):
   ...:     train_data = mx.gluon.data.DataLoader(data_set, batch_size=batch_size, shuffle=True)
   ...:     d_metric = mx.metric.Loss()
   ...:     g_metric = mx.metric.Loss()
   ...: 
   ...:     for epoch in range(1, epoch + 1):
   ...:         d_metric.reset(); g_metric.reset(); tic=time.time()
   ...:         for datas, labels in train_data:
   ...:             # actual batch size
   ...:             actually_batch_size = datas.shape[0]
   ...:             # move from CPU to the target context
   ...:             datas  = mx.gluon.utils.split_and_load(datas, [mx.current_context()])
   ...:             labels = mx.gluon.utils.split_and_load(labels, [mx.current_context()])
   ...:             # build seeds: noise concatenated with a one-hot label
   ...:             seeds = [mx.ndarray.concat(make_seed(label.shape[0], 100), label.one_hot(10)) for label in labels]
   ...: 
   ...:             # train the discriminator
   ...:             for data, label, seed in zip(datas, labels, seeds):
   ...:                 lbl_real = mx.ndarray.ones(shape=(seed.shape[0],1))
   ...:                 lbl_fake = mx.ndarray.zeros(shape=(seed.shape[0],1))
   ...:                 data = mx.ndarray.concat(data.flatten(), label.one_hot(10))
   ...:                 img = mx.ndarray.concat(g_net(seed).flatten(), label.one_hot(10))
   ...: 
   ...:                 with mx.autograd.record():
   ...:                     alpha = loss_fn(d_net(data), lbl_real)
   ...:                     beta = loss_fn(d_net(img), lbl_fake)
   ...:                     d_loss = alpha + beta
   ...:                 d_loss.backward()
   ...:                 d_metric.update(None, preds=d_loss)
   ...:             d_trainer.step(actually_batch_size)
   ...: 
   ...:             # train the generator
   ...:             for label, seed in zip(labels, seeds):
   ...:                 lbl_real = mx.ndarray.ones(shape=(seed.shape[0],1))
   ...: 
   ...:                 with mx.autograd.record():
   ...:                     img = g_net(seed)
   ...:                     data = mx.ndarray.concat(img.flatten(), label.one_hot(10))
   ...:                     g_loss = loss_fn(d_net(data), lbl_real)
   ...:                 g_loss.backward()
   ...:                 g_metric.update(None, preds=g_loss)
   ...:             g_trainer.step(actually_batch_size)
   ...: 
   ...:         print("Epoch {:>2d}: cost:{:.1f}s d_loss:{:.3f} g_loss:{:.3f}".format(epoch, time.time()-tic, d_metric.get()[1], g_metric.get()[1]))
   ...: 
   ...:         # show intermediate results
   ...:         label = mx.ndarray.arange(10)
   ...:         seed = mx.ndarray.concat(make_seed(len(label), 100), label.one_hot(10))
   ...:         output = g_net(seed)
   ...:         img = output[0][0]
   ...:         for i in range(1,10):
   ...:             img = mx.ndarray.concat(img, output[i][0])
   ...:         plt.xticks([]); plt.yticks([])
   ...:         plt.imshow(img.asnumpy(), interpolation='none', cmap='Blues')
   ...:         plt.show()

Training

In [8]: with mx.cpu(0):
   ...:     discriminator = Discriminator()
   ...:     generator = Generator()
   ...:     train_gan(
   ...:         data_set   = mx.gluon.data.ArrayDataset(train_img + eval_img, train_lbl + eval_lbl),
   ...:         batch_size = 10,
   ...:         d_net      = discriminator.net,
   ...:         d_trainer  = discriminator.trainer,
   ...:         g_net      = generator.net,
   ...:         g_trainer  = generator.trainer,
   ...:         loss_fn    = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss(),
   ...:         epoch      = 120,
   ...:     )
Out[8]:

(Figure: a row of generated digits 0-9 shown during training)

Save parameters

In [9]: discriminator.net.save_parameters("CondGan-dis-params")
   ...: generator.net.save_parameters("CondGan-gan-params")

Generate images

In [10]: n = Generator().net
    ...: n.load_parameters("CondGan-gan-params")
    ...: 
    ...: for i in range(5):
    ...:     label = mx.ndarray.arange(10)
    ...:     seed = make_seed(1, 100)
    ...:     seed = mx.ndarray.concat(seed, seed, seed, seed, seed, dim = 0)
    ...:     seed = mx.ndarray.concat(seed, seed, dim = 0)
    ...:     seed = mx.ndarray.concat(seed, label.one_hot(10))
    ...:     output = n(seed)
    ...:     img = output[0][0]
    ...:     for j in range(1, 10):
    ...:         img = mx.ndarray.concat(img, output[j][0])
    ...:     plt.xticks([]); plt.yticks([])
    ...:     plt.imshow(img.asnumpy(), interpolation='none', cmap='Blues')
    ...:     plt.show()
Out[10]:

(Figures: five rows of generated digits 0-9, one row per noise seed)
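The seed construction above repeats one noise vector ten times, so every digit in a row shares the same "style"; a plain-numpy sketch of that layout:

```python
import numpy as np

# One 100-dim noise vector, tiled 10 times and paired with labels 0..9,
# so the generator draws every digit from the same noise.
noise = np.random.normal(size=(1, 100)).astype(np.float32)
tiled = np.tile(noise, (10, 1))                     # 10 identical noise rows
onehot = np.eye(10, dtype=np.float32)               # one-hot labels 0..9
g_input = np.concatenate([tiled, onehot], axis=1)   # shape (10, 110)
```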

Evaluation

The loss curves of the discriminator and the generator are shown below.

The trend is similar to the unconditional GAN:

  1. Early on, both the discriminator and the generator improve quickly
  2. In the middle phase, both learn more slowly, but training is still effective
  3. Late in training, the discriminator slowly pulls ahead

However, there is a clear improvement over the unconditional GAN:

  1. d-loss stays above the ideal value of 0.693 throughout
  2. g-loss stays below 2 throughout

(Figure: discriminator and generator loss curves)
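The 0.693 figure is ln 2: when the discriminator can do no better than outputting 0.5 for every input, each sigmoid cross-entropy term equals -ln(0.5). A quick check:

```python
import math

# At the GAN equilibrium the discriminator outputs 0.5 for everything,
# and sigmoid binary cross-entropy is then -ln(0.5) = ln 2 per term.
ideal_d_loss = -math.log(0.5)   # ~0.693
```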
