GAN实战

import tensorflow as tf
import tensorflow.examples.tutorials.mnist.input_data as input_data
import numpy as np
import matplotlib.pyplot as plt
mnist = input_data.read_data_sets('MNIST_data',one_hot=True)

class Net:
    def __init__(self):
        self.real_x = tf.placeholder(dtype=tf.float32,shape=[None,784])#训练样本
        self.feature_x = tf.placeholder(dtype=tf.float32,shape=[None,128])#外部的特征输入
        self.pos_y = tf.placeholder(dtype=tf.float32,shape=[None,1])#表示判别正确的标志
        self.nage_y = tf.placeholder(dtype=tf.float32,shape=[None,1])#表示判别错误的标志
        self.dnet = Dnet()#判别器
        self.gnet = Gnet()#生成器
    def forward(self):
        #将真实的样本带进判别器网络进行训练得到输出,这个输出是一个二值量
        self.real_d_out = self.dnet.forward(self.real_x)
        # 将外部的特征输入传进生成器得到输出,也就是将随机生成的数带进去得到模仿的结果
        self.feature_g_out = self.gnet.forward(self.feature_x)
        self.g_d_out = self.dnet.forward(self.feature_g_out)#将生成的结果带进判别网络看看有没有生成对
    def backward(self):
        #将真实样本带入判别器去判断得到的输出与正标签求损失,即先训练判别器,让判别器知道这是真的
        real_loss = tf.reduce_mean((self.real_x-self.pos_y)**2)
        #告诉判别器生成器生成的东西是假的
        g_d_loss = tf.reduce_mean((self.g_d_out-self.nage_y)**2)
        #将两个损失进行叠加得到判别器的总损失
        self.d_loss = real_loss+g_d_loss
        #训练判别器网络
        self.d_opt = tf.train.AdamOptimizer().minimize(self.d_loss,var_list=self.dnet.getParam())

        #判别器训练完毕以后就要将生成器的输出与正标签
        # 做比较,告诉判别器生成器生成的东西是对的
        self.g_loss = tf.reduce_mean((self.g_d_out-self.pos_y)**2)
        #训练生成器网络
        self.g_opt = tf.train.AdamOptimizer().minimize(self.g_loss,var_list=self.gnet.getParam())
class Dnet:
    def __init__(self):
        with tf.variable_scope('Dnet'):#设置命名空间以便在训练网络的时候将所有的参数都一次性训练到
            self.w1 = tf.Variable(tf.truncated_normal(shape=[784,512],stddev=0.1))
            self.b1 = tf.Variable(tf.zeros([512]))
            self.w2 = tf.Variable(tf.truncated_normal(shape=[512,256],stddev=0.1))
            self.b2 = tf.Variable(tf.zeros([256]))
            self.w = tf.Variable(tf.truncated_normal(shape=[256,1],stddev=0.1))
    def forward(self,x):
        y1 = tf.nn.leaky_relu(tf.matmul(x,self.w1)+self.b1)
        y2 = tf.nn.leaky_relu(tf.matmul(y1,self.w2)+self.b2)
        return tf.matmul(y2,self.w)
    def getParam(self):
        return tf.get_collection(tf.GraphKeys.VARIABLES,scope='Dnet')
class Gnet:
    def __init__(self):
        with tf.variable_scope('Gnet'):
            self.w1 = tf.Variable(tf.truncated_normal(shape=[128,256],stddev=0.1))
            self.b1 = tf.Variable(tf.zeros([256]))
            self.w2 = tf.Variable(tf.truncated_normal(shape=[256,512],stddev=0.1))
            self.b2 = tf.Variable(tf.zeros([512]))
            self.w = tf.Variable(tf.truncated_normal(shape=[512,784],stddev=0.1))
    def forward(self,x):
        y1 = tf.nn.leaky_relu(tf.matmul(x,self.w1)+self.b1)
        y2 = tf.nn.leaky_relu(tf.matmul(y1,self.w2)+self.b2)
        return tf.matmul(y2,self.w)
    def getParam(self):
        return tf.get_collection(tf.GraphKeys.VARIABLES,scope='Gnet')
if __name__ == '__main__':
    net = Net()
    net.forward()
    net.backward()
    init = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init)
        for i in range(10000):
            x,_ = mnist.train.next_batch(100)
            pos_y = np.ones([100,1])#正标签
            nage_y = np.zeros([100,1])#负标签
            feature_x = np.random.uniform(0,1,size=[100,128])#送进生成器的满足均匀分布的随机数
            d_loss,_ = sess.run([net.d_loss,net.d_opt],feed_dict={net.real_x:x,net.pos_y:pos_y,net.nage_y:nage_y,net.feature_x:feature_x})
            g_loss,_,out = sess.run([net.g_loss,net.g_opt,net.feature_g_out],feed_dict={net.feature_x:feature_x,net.pos_y:pos_y})
            if i % 100 == 0:#每当训练100次就做一下测试看看训练结果
                test_feature_x = np.random.uniform(0,1,size=[1,128])
                test_data = sess.run([net.feature_g_out],feed_dict={net.feature_x:test_feature_x})
                test_img = np.reshape(test_data,[28,28])
                plt.imshow(test_img)
                plt.pause(0.1)

 

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值