LSGAN(最小乘二GAN):具有WGAN同样效果的GAN

LSGAN通过使用平方差损失函数替代传统的Sigmod交叉熵,解决了GAN中梯度消失问题,使得训练更稳定且快速收敛。它对数据偏离给予与偏离距离成比例的惩罚,保持数据的稳定性。判别器和生成器的损失函数被简化,并引入了额外的损失项来优化分类和隐含信息的准确性。
摘要由CSDN通过智能技术生成

  前面已经介绍GAN是以对抗的形式逼近概率分布。但直接使用该方法,对随着判别器越来越好而生成器无法与其对抗,进行形成梯度消失的问题。WGAN和LSGAN都是试图使用不同的距离度量,从而构建一个不仅稳定,同时还能快速收敛的生成对抗网络。

1. LSGAN介绍

  WGAN使用的是Wasserstein理论来构建度量距离。而LSGAN使用另一种方法,使用更加平滑和非饱和梯度的损失函数——最小乘二来代替原来的Sigmod交叉熵。这是由于L2正则独有的特性,在数据偏离目标时会有一个与其偏离距离成比例的惩罚,再将其拉回来,从而使数据的偏离不会越来越远。
  LSGAN的loss简单很多,直接将传统GAN中的softmax变成平方差即可。

  • 判别器的loss:D_loss=tf.reduce_sum(tf.square(D(real_X)-1) + tf.square(D(random_Y))/2
  • 生成器的loss:G_loss=tf.reduce_sum(tf.square(D(random_Y)-1)/2
x = tf.placeholder(tf.float32, [None, n_input])
y = tf.placeholder(tf.int32, [None])

z_con = tf.random_normal((batch_size, con_dim))#2列
z_rand = tf.random_normal((batch_size, rand_dim))#38列
z = tf.concat(axis=1, values=[tf.one_hot(y, depth = classes_dim), z_con, z_rand])#50列
# 模拟数据gen
gen = generator(z)
genout= tf.squeeze(gen, -1)

# 判别器
disc_real, class_real, _ = discriminator(x)
disc_fake, class_fake, con_fake = discriminator(gen)
pred_class = tf.argmax(class_fake, dimension=1)

# 判别器loss:真实数据真,模拟数据假
loss_d = tf.reduce_sum(tf.square(disc_real-1) + tf.square(disc_fake))/2
# 生成器loss:模拟数据为假
loss_g = tf.reduce_sum(tf.square(disc_fake-1))/2

# 标签loss
loss_cf = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=class_fake, labels=y))#class ok 图片对不上
loss_cr = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=class_real, labels=y))#生成的图片与class ok 与输入的class对不上
loss_c =(loss_cf + loss_cr) / 2
# 隐含信息loss
loss_con =tf.reduce_mean(tf.square(con_fake-z_con))

# 获得各个网络中各自的训练参数
t_vars = tf.trainable_variables()
d_vars = [var for var in t_vars if 'discriminator' in var.name]
g_vars = [var for var in t_vars if 'generator' in var.name]


disc_global_step = tf.Variable(0, trainable=False)
gen_global_step = tf.Variable(0, trainable=False)

train_disc = tf.train.AdamOptimizer(0.0001).minimize(loss_d + loss_c + loss_con, var_list = d_vars, global_step = disc_global_step)
train_gen = tf.train.AdamOptimizer(0.001).minimize(loss_g + loss_c + loss_con, var_list = g_vars, global_step = gen_global_step)
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值