import tensorflow as tf
import tensorflow.examples.tutorials.mnist.input_data as input_data
import numpy as np
import matplotlib.pyplot as plt
# Download (if needed) and load MNIST with one-hot labels.
# NOTE: uses the deprecated TF1 tutorials loader; removed in TF2.
mnist = input_data.read_data_sets('MNIST_data',one_hot=True)
class Net:
    """GAN wrapper that wires a discriminator (Dnet) and a generator (Gnet)
    together and builds the training ops for both."""

    def __init__(self):
        # Real MNIST images, flattened to 784 floats per sample.
        self.real_x = tf.placeholder(dtype=tf.float32, shape=[None, 784])
        # External noise/feature vector fed to the generator.
        self.feature_x = tf.placeholder(dtype=tf.float32, shape=[None, 128])
        # Label meaning "real" (fed as ones).
        self.pos_y = tf.placeholder(dtype=tf.float32, shape=[None, 1])
        # Label meaning "fake" (fed as zeros).  The name "nage" is a typo
        # for "negative" but is kept because callers reference net.nage_y.
        self.nage_y = tf.placeholder(dtype=tf.float32, shape=[None, 1])
        self.dnet = Dnet()  # discriminator
        self.gnet = Gnet()  # generator

    def forward(self):
        # Discriminator score for the real samples.
        self.real_d_out = self.dnet.forward(self.real_x)
        # Generator output for the random feature input (the fake samples).
        self.feature_g_out = self.gnet.forward(self.feature_x)
        # Discriminator score for the generated (fake) samples.
        self.g_d_out = self.dnet.forward(self.feature_g_out)

    def backward(self):
        # Discriminator should score real samples as real.
        # BUG FIX: the original compared the raw input image (self.real_x)
        # against the label; the discriminator *output* must be used here,
        # otherwise the discriminator is never trained on real data.
        real_loss = tf.reduce_mean((self.real_d_out - self.pos_y) ** 2)
        # Discriminator should score generated samples as fake.
        g_d_loss = tf.reduce_mean((self.g_d_out - self.nage_y) ** 2)
        # Total discriminator loss; optimize only discriminator variables
        # so the generator is not dragged along.
        self.d_loss = real_loss + g_d_loss
        self.d_opt = tf.train.AdamOptimizer().minimize(
            self.d_loss, var_list=self.dnet.getParam())
        # Generator tries to make the discriminator score fakes as real;
        # optimize only generator variables.
        self.g_loss = tf.reduce_mean((self.g_d_out - self.pos_y) ** 2)
        self.g_opt = tf.train.AdamOptimizer().minimize(
            self.g_loss, var_list=self.gnet.getParam())
class Dnet:
    """Discriminator: a 784 -> 512 -> 256 -> 1 MLP with leaky-ReLU hidden
    layers, producing a raw realness score."""

    def __init__(self):
        # Namespace the variables so getParam() can select them by scope
        # and train the discriminator independently of the generator.
        with tf.variable_scope('Dnet'):
            self.w1 = tf.Variable(tf.truncated_normal(shape=[784, 512], stddev=0.1))
            self.b1 = tf.Variable(tf.zeros([512]))
            self.w2 = tf.Variable(tf.truncated_normal(shape=[512, 256], stddev=0.1))
            self.b2 = tf.Variable(tf.zeros([256]))
            # Output-layer weight (the final logit carries no bias).
            self.w = tf.Variable(tf.truncated_normal(shape=[256, 1], stddev=0.1))

    def forward(self, x):
        """Return the raw (unbounded) realness score for a [N, 784] batch."""
        y1 = tf.nn.leaky_relu(tf.matmul(x, self.w1) + self.b1)
        y2 = tf.nn.leaky_relu(tf.matmul(y1, self.w2) + self.b2)
        return tf.matmul(y2, self.w)

    def getParam(self):
        """Return every variable created under the 'Dnet' scope.

        Uses GLOBAL_VARIABLES: tf.GraphKeys.VARIABLES is a deprecated
        alias for the same collection.
        """
        return tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='Dnet')
class Gnet:
    """Generator: a 128 -> 256 -> 512 -> 784 MLP with leaky-ReLU hidden
    layers, mapping a noise vector to a flattened 28x28 image."""

    def __init__(self):
        # Namespace the variables so getParam() can select them by scope
        # and train the generator independently of the discriminator.
        with tf.variable_scope('Gnet'):
            self.w1 = tf.Variable(tf.truncated_normal(shape=[128, 256], stddev=0.1))
            self.b1 = tf.Variable(tf.zeros([256]))
            self.w2 = tf.Variable(tf.truncated_normal(shape=[256, 512], stddev=0.1))
            self.b2 = tf.Variable(tf.zeros([512]))
            # Output-layer weight (the final layer carries no bias).
            self.w = tf.Variable(tf.truncated_normal(shape=[512, 784], stddev=0.1))

    def forward(self, x):
        """Return a [N, 784] generated image batch for noise x of shape [N, 128]."""
        y1 = tf.nn.leaky_relu(tf.matmul(x, self.w1) + self.b1)
        y2 = tf.nn.leaky_relu(tf.matmul(y1, self.w2) + self.b2)
        return tf.matmul(y2, self.w)

    def getParam(self):
        """Return every variable created under the 'Gnet' scope.

        Uses GLOBAL_VARIABLES: tf.GraphKeys.VARIABLES is a deprecated
        alias for the same collection.
        """
        return tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='Gnet')
if __name__ == '__main__':
    # Build the graph: placeholders, forward passes, and both optimizers.
    net = Net()
    net.forward()
    net.backward()
    init = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init)
        for i in range(10000):
            x, _ = mnist.train.next_batch(100)  # real images; labels unused
            pos_y = np.ones([100, 1])    # "real" labels
            nage_y = np.zeros([100, 1])  # "fake" labels
            # Uniformly-distributed noise fed to the generator.
            feature_x = np.random.uniform(0, 1, size=[100, 128])
            # Train the discriminator on real + generated batches.
            d_loss, _ = sess.run(
                [net.d_loss, net.d_opt],
                feed_dict={net.real_x: x, net.pos_y: pos_y,
                           net.nage_y: nage_y, net.feature_x: feature_x})
            # Train the generator to fool the discriminator.
            g_loss, _ = sess.run(
                [net.g_loss, net.g_opt],
                feed_dict={net.feature_x: feature_x, net.pos_y: pos_y})
            if i % 100 == 0:
                # Report progress: the losses were previously fetched on
                # every step but never shown anywhere.
                print('step %d: d_loss=%.4f g_loss=%.4f' % (i, d_loss, g_loss))
                # Sample one image from the generator and display it.
                test_feature_x = np.random.uniform(0, 1, size=[1, 128])
                test_data = sess.run(net.feature_g_out,
                                     feed_dict={net.feature_x: test_feature_x})
                test_img = np.reshape(test_data, [28, 28])
                plt.imshow(test_img)
                plt.pause(0.1)