GAN生成对抗网络合集(五):LSGan-最小二乘GAN(附代码)

本文介绍了LSGAN的基本原理,它通过使用最小二乘损失函数解决传统GAN中梯度消失的问题,提供了一种更稳定且快速收敛的解决方案。相比WGAN,LSGAN的损失函数更为简单,将真实样本与1、模拟样本与0的差距用平方差表示。文章还给出了如何将现有GAN代码修改为LSGAN的步骤,并提供了生成MNIST数据集的完整代码示例。
摘要由CSDN通过智能技术生成

首先分享一个讲得还不错的博客:

经典论文复现 | LSGAN:最小二乘生成对抗网络

1. LSGan原理

       GAN是以对抗的方式逼近概率分布。但是直接使用该方法,会随着判别器越来越好而生成器无法与其对抗,进而形成梯度消失的问题。所以不论是WGAN,还是本节中的LSGAN,都是试图使用不同的距离度量(loss值),从而构建一个不仅稳定,同时还收敛迅速的生成对抗网络。
       WGAN使用的是Wasserstein理论来构建度量距离。而LSGAN使用了另一个方法,即使用了更加平滑和非饱和梯度的损失函数——最小二乘来代替原来的Sigmoid交叉熵。这是由于L2正则独有的特性,在数据偏离目标时会有一个与其偏离距离成比例的惩罚,再将其拉回来,从而使数据的偏离不会越来越远
       相对于WGAN而言,LSGAN的loss简单很多。直接将传统的GAN中的loss变为平方差即可。

即LSGan的核心就是损失函数变为D的loss是真实样本和1作差的平方+模拟样本和0作差的平方;G的loss是模拟样本和1作差的平方(L2正则化)
而WGan的核心是损失函数变为真实值和虚拟值的差(L1正则化)
原始GAN的损失函数是D的loss都是真实样本和1作交叉熵,模拟样本和0作交叉熵;G的loss是模拟样本和1作交叉熵。(交叉熵)

在这里插入图片描述
在这里插入图片描述
       为什么要除以2?和以前的原理一样,在对平方求导时会得到一个系数2,与事先的1/2运算正好等于1,使公式更加完整。

2 代码

直接修改我们之前的 GAN生成对抗网络合集(三):InfoGAN和ACGAN-指定类别生成模拟样本的GAN(附代码) 代码,将其改成LSGan。

需要改三个地方:

  1. 修改判别器D
    将判别器的最后一层输出disc改成使用Sigmoid的激活函数。代码如下:
        ... ...
        ... ...
        # 生成的数据可以分别连接不同的输出层产生不同的结果
        # 1维的输出层产生判别结果1或是0
        disc = slim.fully_connected(shared_tensor, num_outputs=1, activation_fn=tf.nn.sigmoid)
        disc = tf.squeeze(disc, -1)
        # print ("disc",disc.get_shape()) # 0 or 1

        # 10维的输出层产生分类结果 (样本标签)
        recog_cat = slim.fully_connected(recog_shared, num_outputs=num_classes, activation_fn=None)

        # 2维输出层产生重构造的隐含维度信息
        recog_cont = slim.fully_connected(recog_shared, num_outputs=num_cont, activation_fn=tf.nn.sigmoid)
    return disc, recog_cat, recog_cont
		
  1. 修改loss值
    将原有的loss_d与loss_g改成平方差形式,原有的y_real与y_fake不再需要了,可以删掉,其他代码不用变动。
##################################################################
#  4.定义损失函数和优化器
##################################################################
# 判别器 loss
# loss_d_r = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=disc_real, labels=y_real))  # 1
# loss_d_f = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=disc_fake, labels=y_fake))  # 0
loss_d = tf.reduce_sum(tf.square(disc_real-1) + tf.square(disc_fake-0)) / 2
# print ('loss_d', loss_d.get_shape())

# generator loss
loss_g = tf.reduce_sum(tf.square(disc_fake-1)) / 2

# categorical factor loss 分类因素损失
loss_cf = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=class_fake, labels=y))
loss_cr = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=class_real, labels=y))
loss_c = (loss_cf + loss_cr) / 2

# continuous factor loss 隐含信息变量的损失
loss_con = tf.reduce_mean(tf.square(con_fake - z_con))

-----------------------------------------------------------------------------------------------------------------------------------------

这里附上生成MNIST完整代码:

# !/usr/bin/env python3
# -*- coding: utf-8 -*-
__author__ = '黎明'

##################################################################
#  1.引入头文件并加载mnist数据
##################################################################
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from scipy.stats import norm
import tensorflow.contrib.slim as slim
import time
from timer import Timer
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets("/media/S318080208/py_pictures/minist/")  # ,one_hot=True)

tf.reset_default_graph()  # 用于清除默认图形堆栈并重置全局默认图形


##################################################################
#  2.定义生成器与判别器
##################################################################
def generator(x):  # 生成器函数 : 两个全连接+两个反卷积模拟样本的生成,每一层都有BN(批量归一化)处理
    reuse = len([t for t in tf.global_variables() if t.name.startswith('generator')]) >
  • 6
    点赞
  • 49
    收藏
    觉得还不错? 一键收藏
  • 3
    评论
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值