Gan对抗性神经网络
概念:
1.判别器:判别生成物体的可识别度
2.生成器:生成物体
思路:将判别与生成放在一个网络体系中,判别网络与生成网络一起训练。
(固定生成器)判别器训练思路:真样本,真标签(1),生成样本,假标签(0) 促使其可以识别真假
(固定判别器)生成器训练思路:生成样本,真标签(1)
训练好后生成器生成样本接近0.5左右
注意事项:
不能分开训练(交替训练)
判别器最后一层去掉sigmod
不用信息丢失快的函数
更新后权重强制截断到一定范围比如[-0.01,0.01]
如果是条件gan需要控制正负样本,只有指定条件的是正样本,其他全是负样本
判别器
class dnet:
def __init__(self):
# 用命名空间返回变量 逐个返回太慢
with tf.variable_scope("d_param"):
self.in_w = tf.Variable(tf.truncated_normal(shape=[784,512],stddev=0.01))
self.in_b = tf.Variable(tf.zeros([512]))
self.in_w1 = tf.Variable(tf.truncated_normal(shape=[512,256],stddev=0.01))
self.in_b1 = tf.Variable(tf.zeros([256]))
self.w=tf.Variable(tf.truncated_normal(shape=[256,1],stddev=0.01))
def forward(self,x):
# leake_relu不容易流失信息
y=tf.nn.leaky_relu(tf.matmul(x,self.in_w)+self.in_b)
y=tf.nn.leaky_relu(tf.matmul(y,self.in_w1)+self.in_b1)
# x=y 不会饱和(但需要看情况)
return tf.matmul(y,self.w)
def params(self):
# 返回所有变量
return tf.get_collection(tf.GraphKeys.VARIABLES,scope="d_param")
生成器
class gnet:
def __init__(self):
with tf.variable_scope("g_param"):
# 用命名空间返回变量 逐个返回太慢
self.in_x=tf.Variable(tf.truncated_normal(shape=[100,256],stddev=0.01))
self.in_b=tf.Variable(tf.zeros([256]))
self.in_x1=tf.Variable(tf.truncated_normal(shape=[256,512],stddev=0.01))
self.in_b1=tf.Variable(tf.zeros([512]))
self.ouput=tf.Variable(tf.truncated_normal(shape=[512,784],stddev=0.01))
def forward(self,x):
# leake_relu不容易流失信息
y=tf.nn.leaky_relu(tf.matmul(x,self.in_x)+self.in_b)
y=tf.nn.leaky_relu(tf.matmul(y,self.in_x1)+self.in_b1)
# x=y 不会饱和(但需要看情况)
return tf.matmul(y,self.ouput)
def params(self):
# 返回所有变量
return tf.get_collection(tf.GraphKeys.VARIABLES,scope="g_param")
主网络
class net:
def __init__(self):
self.gnet=gnet()
self.dnet=dnet()
# 实际物体
self.real = tf.placeholder(dtype=tf.float32,shape=[None,784])
# 随意定的外部物体
self.feature=tf.placeholder(dtype=tf.float32,shape=[None,100])
# 占位符
self.pt=tf.placeholder(dtype=tf.float32,shape=[None,1])
self.nt=tf.placeholder(dtype=tf.float32,shape=[None,1])
# 初始化前后项
self.forward()
self.backward()
def forward(self):
# 判别器
# 判别图片
self.real_output=self.dnet.forward(self.real)
# 生成器
# 生成图片
self.feature_key=self.gnet.forward(self.feature)
# 判别图片
self.feature_out = self.dnet.forward(self.feature_key)
def backward(self):
# 判别器
# 正样本loss
# 负样本loss
# 判别器训练
# 生成器
# 判别器生成器训练
训练
if __name__ == '__main__':
net =net()
init=tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init)
plt.ion()
for epoch in range(1000000000):
# 判别器
# 样本
real_xs,_=mnist.train.next_batch(512)
# 正标签
real_ys = np.ones(shape=[512,1],dtype=np.float32)
# 生成
# 随机生成0-1 实数
gen_xs = np.random.uniform(0,1,size=[512,100])
gen_ys = np.zeros(shape=[512,1],dtype=np.float32)
_dloss,_=sess.run([net.loss_d,net.opt_d],feed_dict={
net.real:real_xs,
net.pt:real_ys,
net.feature:gen_xs,
net.nt:gen_ys
})
# 生成器
gen_xs = np.random.uniform(0,1,size=[512,100])
real_ys = np.ones(shape=[512,1],dtype=np.float32)
_gloss,_=sess.run([net.loss_g,net.opt_g],feed_dict={
net.feature:gen_xs,
net.pt:real_ys
})
print(_dloss," ",_gloss)
gan部分代码实现仅供参考,如果网络结构复杂需要重新设计判别器和生成器中的网络