对抗神经网络Gan实现样本(mnist数据集)

Gan对抗性神经网络

概念:

1.判别器:判别生成物体的可识别度

2.生成器:生成物体

思路:将判别与生成放在一个网络体系中,判别网络与生成网络一起训练。

(固定生成器)判别器训练思路:真样本,真标签(1),生成样本,假标签(0) 促使其可以识别真假

(固定判别器)生成器训练思路:生成样本,真标签(1)

训练好后生成器生成样本接近0.5左右

注意事项:

不能分开训练(交替训练)

判别器最后一层去掉sigmod

不用信息丢失快的函数

更新后权重强制截断到一定范围比如[-0.01,0.01]

如果是条件gan需要控制正负样本,只有指定条件的是正样本,其他全是负样本

判别器

class  dnet:
    def __init__(self):
        # 用命名空间返回变量 逐个返回太慢
        with tf.variable_scope("d_param"):
            self.in_w = tf.Variable(tf.truncated_normal(shape=[784,512],stddev=0.01))
            self.in_b = tf.Variable(tf.zeros([512]))
            self.in_w1 = tf.Variable(tf.truncated_normal(shape=[512,256],stddev=0.01))
            self.in_b1 = tf.Variable(tf.zeros([256]))
            self.w=tf.Variable(tf.truncated_normal(shape=[256,1],stddev=0.01))
    def forward(self,x):
            # leake_relu不容易流失信息
            y=tf.nn.leaky_relu(tf.matmul(x,self.in_w)+self.in_b)
            y=tf.nn.leaky_relu(tf.matmul(y,self.in_w1)+self.in_b1)
            # x=y 不会饱和(但需要看情况)
            return  tf.matmul(y,self.w)
    def params(self):
        # 返回所有变量
        return  tf.get_collection(tf.GraphKeys.VARIABLES,scope="d_param")

 

生成器

class  gnet:
    def __init__(self):
        with tf.variable_scope("g_param"):
            # 用命名空间返回变量 逐个返回太慢
            self.in_x=tf.Variable(tf.truncated_normal(shape=[100,256],stddev=0.01))
            self.in_b=tf.Variable(tf.zeros([256]))
            self.in_x1=tf.Variable(tf.truncated_normal(shape=[256,512],stddev=0.01))
            self.in_b1=tf.Variable(tf.zeros([512]))
            self.ouput=tf.Variable(tf.truncated_normal(shape=[512,784],stddev=0.01))
    def forward(self,x):
        # leake_relu不容易流失信息
        y=tf.nn.leaky_relu(tf.matmul(x,self.in_x)+self.in_b)
        y=tf.nn.leaky_relu(tf.matmul(y,self.in_x1)+self.in_b1)
        # x=y 不会饱和(但需要看情况)
        return tf.matmul(y,self.ouput)
    def params(self):
        # 返回所有变量
        return tf.get_collection(tf.GraphKeys.VARIABLES,scope="g_param")

主网络 

class  net:
    def __init__(self):
        self.gnet=gnet()
        self.dnet=dnet()
        # 实际物体
        self.real = tf.placeholder(dtype=tf.float32,shape=[None,784])
        # 随意定的外部物体
        self.feature=tf.placeholder(dtype=tf.float32,shape=[None,100])
        # 占位符
        self.pt=tf.placeholder(dtype=tf.float32,shape=[None,1])
        self.nt=tf.placeholder(dtype=tf.float32,shape=[None,1])
        # 初始化前后项
        self.forward()
        self.backward()
    def forward(self):
        # 判别器
        # 判别图片
        self.real_output=self.dnet.forward(self.real)
        # 生成器
        # 生成图片
        self.feature_key=self.gnet.forward(self.feature)
        # 判别图片
        self.feature_out = self.dnet.forward(self.feature_key)
      def backward(self):
        # 判别器
        # 正样本loss

        # 负样本loss
     
        # 判别器训练
      
        # 生成器

        # 判别器生成器训练
        

训练

if __name__ == '__main__':
    net =net()
    init=tf.global_variables_initializer()
    with tf.Session() as  sess:
        sess.run(init)
        plt.ion()
        for epoch in range(1000000000):
            # 判别器
            # 样本
            real_xs,_=mnist.train.next_batch(512)
            # 正标签
            real_ys = np.ones(shape=[512,1],dtype=np.float32)
            # 生成
            # 随机生成0-1 实数
            gen_xs = np.random.uniform(0,1,size=[512,100])
            gen_ys = np.zeros(shape=[512,1],dtype=np.float32)
            _dloss,_=sess.run([net.loss_d,net.opt_d],feed_dict={
                net.real:real_xs,
                net.pt:real_ys,
                net.feature:gen_xs,
                net.nt:gen_ys
            })
            # 生成器
            gen_xs = np.random.uniform(0,1,size=[512,100])
            real_ys = np.ones(shape=[512,1],dtype=np.float32)
            _gloss,_=sess.run([net.loss_g,net.opt_g],feed_dict={
                net.feature:gen_xs,
                net.pt:real_ys
            })
            print(_dloss," ",_gloss)

gan部分代码实现仅供参考,如果网络结构复杂需要重新设计判别器和生成器中的网络

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值