MindSpore实现生成对抗网络(1)

本文介绍了如何使用MindSpore实现生成对抗网络(GAN),专注于生成一维高斯分布。内容涵盖GAN的基本概念,包括判别器和生成器的设计,损失函数的选择,以及WithLossCell和TrainOneStepCell的自定义。文章展示了从数据生成、网络构建到训练的完整过程,并通过可视化展示训练进展,为后续的DCGAN和CGAN实现奠定了基础。
摘要由CSDN通过智能技术生成

MindSpore实现生成对抗网络-GAN (1)

生成对抗网络(GAN)问世已经有好几年的时间了,其属于生成模型的一种,是现在比较热门的一个研究方向。GAN由两个部分组成——判别器和生成器。生成器用于生成样本,判别器判断其输入样本是真实的还是生成的。Mindspore是一个新生的AI框架,相关的资源较少,所以考虑用MindSpore实现一些简单的GAN来作为一个MindSpore相关的教程,算是给MindSpore社区贡献一点点代码。
暂时考虑写3篇相关的教程:
1.使用简单的GAN生成高斯分布
2.使用DCGAN生成MNIST手写数字
3.使用CGAN生成MNSIT手写数字

使用简单的GAN生成高斯分布

先完成一个简单的,用GAN生成服从均值为0,方差为1的一维高斯分布的数据。

关于GAN的基本理论,网上能够找到很多的讲解,所以,这里不再对原理做相关的介绍,而是侧重于怎么用Mindspore框架实现GAN的相关算法。总的来说,Gan是一种通过对抗的方式去学习训练数据分布的生成模型。首先,它包含两个网络,一个叫做判别器,另一个叫生成器。顾名思义,生成器的作用是生成我们想要的数据,判别器的作用是判断数据是生成的还是真实的。训练时,生成器尽可能生成逼真的样本去欺骗判别器,判别器网络则尽可能去判断输入的样本是真实样本还是生成样本。二者就在这样的相互对抗过程当中训练提升自己的性能,如下所示(latent code是生成器的输入,通常为随机噪声)。

在这里插入图片描述

Gan工作流程

各个模块的设计

所使用的mindspore版本为GPU-1.0。这里建议使用conda安装或者使用docker。

先导入会用到的包

from mindspore import nn
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore import context
import mindspore.ops.operations as P
import mindspore.ops.functional as F
import mindspore.ops.composite as C
from mindspore.parallel._utils import (_get_device_num, _get_gradients_mean,
                                       _get_parallel_mode)
from mindspore.context import ParallelMode
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
import os
import numpy as np
import matplotlib.pyplot as plt

1.判别器和生成器

这一步没有太多要说的,使用简单的全连接网络就可以获得不错的效果。判别器判断一个样本的真假,可以视作一个二分类任务,所以其输出只有一个单值,表示这个样本是真的概率。因为接下来定义的损失函数会对输出做sigmoid运算,所以判别器输出层不使用Sigmoid激活。

class Discriminator(nn.Cell):
    def __init__(self, hidden_dim, auto_prefix=True):
        super().__init__(auto_prefix=auto_prefix)
        self.model = nn.SequentialCell([
             nn.Dense(1, hidden_dim),
             nn.LeakyReLU(),
             nn.Dense(hidden_dim, 1)
        ])

    def construct(self, x):
        return self.model(x)
    
class Generator(nn.Cell):
    def __init__(self, input_dim, hidden_dim, auto_prefix=True):
        super().__init__(auto_prefix=auto_prefix)
        self.model = nn.SequentialCell([
            nn.Dense(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dense(hidden_dim, 1)
        ])

    def construct(self, x):
        return self.model(x)

2.损失函数

根据判别器的结构,我们使用Sigmoid交叉熵损失作为损失函数。但是mindspore的官方实现里没有这个Cell,所以可以自定义一个。

class SigmoidCrossEntropyWithLogits(nn.loss.loss._Loss):
    def __init__(self):
        super<
  • 13
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值