对抗生成网络学习(四)——WGAN+爬虫生成皮卡丘图像(tensorflow实现)

一、背景

WGAN的全称为Wasserstein GAN, 是Martin Arjovsky等人于17年1月份提出的一个模型,该文章可以参考[1]。WGAN针对GAN存在的问题进行了有针对性的改进,但WGAN几乎没有改变GAN的结构,只是改变了激活函数和loss函数,以及截取权重,却得到了非常好的效果[2]。且WGAN的方法同样适用于DCGAN。

本文以python爬虫爬取的皮卡丘(pikachu)数据集为例,利用WGAN生成皮卡丘图像。

[1]文章链接:https://arxiv.org/abs/1701.07875

[2]DCGAN、WGAN、WGAN-GP、LSGAN、BEGAN原理总结及对比

二、WGAN原理

网上对于WGAN的解读文章介绍的非常详细,这里给出两个详细介绍的链接:

[2]DCGAN、WGAN、WGAN-GP、LSGAN、BEGAN原理总结及对比

[3]令人拍案叫绝的Wasserstein GAN

在文章《Wasserstein GAN》中,作者从数学的角度进行了大量的公式推论,指出了GAN的问题所在,并提出了四点改进方法:

(1)判别器的最后一层去掉sigmoid

(2)生成器和判别器的loss不取log

(3)在每一轮梯度更新之后,对权值进行截取,将其值约束到一个范围之内(fixed box)。文章将其截取至[-0.01, 0.01],文章中作者也提到,尽管权值截取对于增强Lipschitz限制而言,是一种非常糟糕的方式(Weight clipping is a clearly terrible way to enforce a Lipschitz constraint),但是作者的数据集方差较小,被迫接受了这种方式,且取得了较好的结果,对于后续的研究,作者也希望有人能改进这一点(We experimented with simple variants with little difference, and we stuck with weight clipping due to its simplicity and already good performance.However, we do leave the topic of enforcing Lipschitz constraints in a neural network setting for further investigation, and we actively encourage interested researchers to improve on this method.)。

(4)作者推荐使用SGD,RMSprop等优化器,不要基于使用动量的优化算法,比如adam。

WGAN算法的伪代码如下:

当然,WGAN对GAN做了改进,其取得的效果也非常明显:

(1)WGAN理论上给出了GAN训练不稳定的原因,即交叉熵(JS散度)不适合衡量具有不相交部分的分布之间的距离,转而使用Wassertein距离去衡量生成数据分布和真实数据分布之间的距离,理论上解决了训练不稳定的问题,不再需要小心平衡生成器和判别器的训练程度。

(2)基本解决了模型崩溃(collapse mode)的问题,确保了生成样本的多样性。

(3)训练过程中终于有一个像交叉熵、准确率这样的数值来指示训练的进程,这个数值越小代表GAN训练得越好,代表生成器产生的图像质量越高。

(4)以上一切好处不需要精心设计的网络架构,最简单的多层全连接网络就可以做到。因为WGAN主要关注的是模型分布问题,而非模型结构。

本实验的网络结构主要参考了之前的DCGAN,和github上的代码,主要的参考代码如下:

[4]对抗神经网络学习(二)——DCGAN生成人脸图像(tensorflow实现)

[5]https://github.com/ConnorJL/WGAN-Tensorflow/blob/master/train_WGAN.py

[6]https://github.com/moxiegushi/pokeGAN/blob/master/pokeGAN.py

本实验结合上述的一些代码,并加上了自己的改进,以完成此次实验。

三、WGAN实现

1.数据准备

此次WGAN的实现采用自己的数据集,所有的数据集是从百度图片上爬取下来的。网上可以找到很多爬取百度图片的教程,这里就不再多说,根据关键词爬取皮卡丘图像的代码为:

# 导入需要的库
import requests
import os
import json

# 爬取百度图片,解析页面的函数
def getManyPages(keyword, pages):
    '''
    参数keyword:要下载的影像关键词
    参数pages:需要下载的页面数
    '''
    params = []

    for i in range(30, 30 * pages + 30, 30):
        params.append({
            'tn': 'resultjson_com',
            'ipn': 'rj',
            'ct': 201326592,
            'is': '',
            'fp': 'result',
            'queryWord': keyword,
            'cl': 2,
            'lm': -1,
            'ie': 'utf-8',
            'oe': 'utf-8',
            'adpicid': '',
            'st': -1,
            'z': '',
            'ic': 0,
            'word': keyword,
            's': '',
            'se': '',
            'tab': '',
            'width': '',
            'height': '',
            'face': 0,
            'istype': 2,
            'qc': '',
            'nc': 1,
            'fr': '',
            'pn': i,
            'rn': 30,
            'gsm': '1e',
            '1488942260214': ''
        })
    url = 'https://image.baidu.com/search/acjson'
    urls = []
    for i in params:
        try:
            urls.append(requests.get(url, params=i).json().get('data'))
        except json.decoder.JSONDecodeError:
            print("解析出错")
    return urls

# 下载图片并保存
def getImg(dataList, localPath):
    '''
    参数datallist:下载图片的地址集
    参数localPath:保存下载图片的路径
    '''
    if not os.path.exists(localPath):  # 判断是否存在保存路径,如果不存在就创建
        os.mkdir(localPath)
    x = 0
    for list in dataList:
        for i in list:
            if i.get('thumbURL') != None:
                print('正在下载:%s' % i.get('thumbURL'))
                ir = requests.get(i.get('thumbURL'))
                open(localPath + '%d.jpg' % x, 'wb').write(ir.content)
                x += 1
            else:
                print('图片链接不存在')

# 根据关键词皮卡丘来下载图片
if __name__ == '__main__':
    dataList = getManyPages('皮卡丘', 20)     # 参数1:关键字,参数2:要下载的页数
    getImg(dataList, './pikachu/')            # 参数2:指定保存的路径

直接运行上述代码,会在代码所在的根路径下,创建一个pikachu的文件夹,里面保存有下载的影像,下载图像的结果为:

一共爬取了400张影像,爬取的图像尺寸大小不一,且图像质量也存在很大偏差,有一些图像甚至不是皮卡丘,但也被下载了下来,需要对这部分影像手动删除处理,处理后的数据集共358张图片。

 

2.定义超参数 hyper parameters

超参数是指模型中需要用到的,而非模型生成的参数,这里定义的参数主要有:

from skimage import io, transform        # 用于读取影像
import tensorflow as tf                  # 构造网络
import numpy as np                
import matplotlib.pyplot as plt          # 绘制结果并保存
import os                                # 创建文件夹

image_width = 128      # 图像宽128像素
image_height = 128     # 图像高128像素
image_channel = 3      # 图像的通道数为3

input_dir = "./pikachu/"
output_dir = "result/"
batch_size = 64
z_dim = 128
lr_gen = 5e-5          # 生成器的学习率
lr_dis = 5e-5          # 判别器的学习率
epoch = 1000            

3.数据预处理

数据预处理部分包括读取数据,并对其resize;定义leaky_relu层。具体的代码为:

# 读取数据的函数,参照之间的DCGAN代码,这里做的改进在于读取数据的库使用的是skimage而非PIL
def process_data():
    '''
    函数功能:读取路径下的所有图像,返回读取的图像数据集train_set和图像个数image_len
    '''
    images = os.listdir(input_dir)
    image_len = len(images)

    data = np.empty((image_len, image_width, image_height, image_channel), dtype="float32")

    for i in range(image_len):
        # 利用skimage.io.image函数读取图像。如果用PIL.Image读取则会报错
        img = io.imread(input_dir + images[i])
        print(img.shape)
        # 将所有图像resize成128*128
        img = transform.resize(img, (image_width, image_height))
        arr = (np.asarray(img, dtype="float32"))
        # 这里暂时不要对图像进行归一化处理,否则结果全是噪声
        data[i, :, :, :] = arr

    with tf.Session() as sess:
        sess.run(tf.initialize_all_variables())
        data = tf.reshape(data, [-1, image_width, image_height, image_channel])
        train_set = sess.run(data)

    return train_set, image_len
def leaky_relu(x, n, leak=0.2):
    return tf.maximum(x, leak * x, name=n)

4.定义生成器

定义生成器generator,代码如下:

def generator(input, random_dim, is_train, reuse=False):
    with tf.variable_scope('generator') as scope:
        if reuse:
            scope.reuse_variables()
        w1 = tf.get_variable('w1', shape=[random_dim, 4 * 4 * 512], dtype=tf.float32,
                             initializer=tf.truncated_normal_initializer(stddev=0.02))
        b1 = tf.get_variable('b1', shape=[512 * 4 * 4], dtype=tf.float32,
                             initializer=tf.constant_initializer(0.0))
        flat_conv1 = tf.add(tf.matmul(input, w1), b1, name='flat_conv1')

        # 4*4*512
        conv1 = tf.reshape(flat_conv1, shape=[-1, 4, 4, 512], name='conv1')
        bn1 = tf.contrib.layers.batch_norm(conv1, is_training=is_train, epsilon=1e-5, decay=0.9,
                                           updates_collections=None, scope='bn1')
        act1 = tf.nn.relu(bn1, name='act1')

        # 8*8*256
        conv2 = tf.layers.conv2d_transpose(act1, 256, kernel_size=[5, 5], strides=[2, 2], padding="SAME",
                                           kernel_initializer=tf.truncated_normal_initializer(stddev=0.02),
                                           name='conv2')
        bn2 = tf.contrib.layers.batch_norm(conv2, is_training=is_train, epsilon=1e-5, decay=0.9,
                                           updates_collections=None, scope='bn2')
        act2 = tf.nn.relu(bn2, name='act2')

        # 16*16*128
        conv3 = tf.layers.conv2d_transpose(act2, 128, kernel_size=[5, 5], strides=[2, 2], padding="SAME",
                                           kernel_initializer=tf.truncated_normal_initializer(stddev=0.02),
                                           name='conv3')
        bn3 = tf.contrib.layers.batch_norm(conv3, is_training=is_train, epsilon=1e-5, decay=0.9,
                                           updates_collections=None, scope='bn3')
        act3 = tf.nn.relu(bn3, name='act3')

        # 32*32*64
        conv4 = tf.layers.conv2d_transpose(act3, 64, kernel_size=[5, 5], strides=[2, 2], padding="SAME",
                                           kernel_initializer=tf.truncated_normal_initializer(stddev=0.02),
                                           name='conv4')
        bn4 = tf.contrib.layers.batch_norm(conv4, is_training=is_train, epsilon=1e-5, decay=0.9,
                                           updates_collections=None, scope='bn4')
        act4 = tf.nn.relu(bn4, name='act4')

        # 64*64*32
        conv5 = tf.layers.conv2d_transpose(act4, 32, kernel_size=[5, 5], strides=[2, 2], padding="SAME",
                                           kernel_initializer=tf.truncated_normal_initializer(stddev=0.02),
                                           name='conv5')
        bn5 = tf.contrib.layers.batch_norm(conv5, is_training=is_train, epsilon=1e-5, decay=0.9,
                                           updates_collections=None, scope='bn5')
        act5 = tf.nn.relu(bn5, name='act5')

        # 128*128*3
        conv6 = tf.layers.conv2d_transpose(act5, image_channel, kernel_size=[5, 5], strides=[2, 2], padding="SAME",
                                           kernel_initializer=tf.truncated_normal_initializer(stddev=0.02),
                                           name='conv6')

        act6 = tf.nn.tanh(conv6, name='act6')

        return act6

5.定义判别器

定义判别器discriminator,代码如下:

def discriminator(input, is_train, reuse=False):
    with tf.variable_scope('discriminator') as scope:
        if reuse:
            scope.reuse_variables()

        # 64*64*64
        conv1 = tf.layers.conv2d(input, 64, kernel_size=[5, 5], strides=[2, 2], padding="SAME",
                                 kernel_initializer=tf.truncated_normal_initializer(stddev=0.02),
                                 name='conv1')
        act1 = leaky_relu(conv1, n='act1')

        # 32*32*128
        conv2 = tf.layers.conv2d(act1, 128, kernel_size=[5, 5], strides=[2, 2], padding="SAME",
                                 kernel_initializer=tf.truncated_normal_initializer(stddev=0.02),
                                 name='conv2')
        bn2 = tf.contrib.layers.batch_norm(conv2, is_training=is_train, epsilon=1e-5, decay=0.9,
                                           updates_collections=None, scope='bn2')
        act2 = leaky_relu(bn2, n='act2')

        # 16*16*256
        conv3 = tf.layers.conv2d(act2, 256, kernel_size=[5, 5], strides=[2, 2], padding="SAME",
                                 kernel_initializer=tf.truncated_normal_initializer(stddev=0.02),
                                 name='conv3')
        bn3 = tf.contrib.layers.batch_norm(conv3, is_training=is_train, epsilon=1e-5, decay=0.9,
                                           updates_collections=None, scope='bn3')
        act3 = leaky_relu(bn3, n='act3')

        # 8*8*512
        conv4 = tf.layers.conv2d(act3, 512, kernel_size=[5, 5], strides=[2, 2], padding="SAME",
                                 kernel_initializer=tf.truncated_normal_initializer(stddev=0.02),
                                 name='conv4')
        bn4 = tf.contrib.layers.batch_norm(conv4, is_training=is_train, epsilon=1e-5, decay=0.9, updates_collections=None,
                                           scope='bn4')
        act4 = leaky_relu(bn4, n='act4')

        # start from act4
        dim = int(np.prod(act4.get_shape()[1:]))
        fc1 = tf.reshape(act4, shape=[-1, dim], name='fc1')
        w2 = tf.get_variable('w2', shape=[fc1.shape[-1], 1], dtype=tf.float32,
                             initializer=tf.truncated_normal_initializer(stddev=0.02))
        b2 = tf.get_variable('b2', shape=[1], dtype=tf.float32,
                             initializer=tf.constant_initializer(0.0))
        # wgan不适用sigmoid
        logits = tf.add(tf.matmul(fc1, w2), b2, name='logits')

        return logits

6.定义保存结果函数

保存结果的函数主要参照之间的DCGAN保存结果函数,其代码为:

def plot_and_save(num, images):
    batch_size = len(images)
    n = np.int(np.sqrt(batch_size))

    image_size = np.shape(images)[2]
    n_channel = np.shape(images)[3]
    images = np.reshape(images, [-1, image_size, image_size, n_channel])
    canvas = np.empty((n * image_size, n * image_size, image_channel))

    for i in range(n):
        for j in range(n):
            canvas[i * image_size:(i + 1) * image_size, j * image_size:(j + 1) * image_size, :] = images[
                n * i + j].reshape(128, 128, 3)

    plt.figure(figsize=(8, 8))
    plt.imshow(canvas, cmap="gray")
    label = "Epoch: {0}".format(num + 1)
    plt.xlabel(label)

    if type(num) is str:
        file_name = num
    else:
        file_name = "pikachu_gen" + str(num)

    plt.savefig(file_name)
    print(output_dir)
    print("Image saved in file: ", file_name)
    plt.close()

7.定义训练器

训练器是整个代码的关键部分,构建训练器的思路为:首先构建函数模型;然后读取数据;之后加载数据进行训练;最后输入随机数到模型中来生成皮卡丘。定义训练器的代码为:

def train():

    # 构建模型---------------------------------------------------------------------

    with tf.variable_scope('input'):
        # 模型中的输入数据部分
        real_image = tf.placeholder(tf.float32, shape=[None, image_height, image_width, image_channel], name='real_image')
        random_input = tf.placeholder(tf.float32, shape=[None, z_dim], name='rand_input')
        is_train = tf.placeholder(tf.bool, name='is_train')

    # 定义WGAN
    fake_image = generator(random_input, z_dim, is_train)
    real_result = discriminator(real_image, is_train)
    fake_result = discriminator(fake_image, is_train, reuse=True)

    # 定义损失函数,这是WGAN的改进所在
    d_loss = tf.reduce_mean(fake_result) - tf.reduce_mean(real_result)  # This optimizes the discriminator.
    g_loss = -tf.reduce_mean(fake_result)  # This optimizes the generator.

    # 定义方差
    t_vars = tf.trainable_variables()
    d_vars = [var for var in t_vars if 'discriminator' in var.name]
    g_vars = [var for var in t_vars if 'generator' in var.name]

    # 定义优化器,这里使用RMSProp
    trainer_d = tf.train.RMSPropOptimizer(learning_rate=0.0002).minimize(d_loss, var_list=d_vars)
    trainer_g = tf.train.RMSPropOptimizer(learning_rate=0.0002).minimize(g_loss, var_list=g_vars)

    # 权值裁剪至[-0.01, 0.01]
    d_clip = [v.assign(tf.clip_by_value(v, -0.01, 0.01)) for v in d_vars]
    # 模型构建完毕------------------------------------------------------------------

    # 读取数据 ---------------------------------------------------------------------
    image_batch, samples_num = process_data()
    # 数据读取完毕------------------------------------------------------------------

    batch_num = int(samples_num / batch_size)
    total_batch = 0

    # 创建会话并初始化
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())

    print('total training sample num:%d' % samples_num)
    print('batch size: %d, batch num per epoch: %d, epoch num: %d' % (batch_size, batch_num, epoch))
    print('start training...')

    # 逐个epoch进行训练
    for i in range(epoch):
        # 逐个batch进行训练
        for j in range(batch_num):
            # 每训练d_iters次判别器,训练g_iters次生成器
            d_iters = 5
            g_iters = 1
            # 随机噪声作为输入数据
            train_noise = np.random.uniform(-1.0, 1.0, size=[batch_size, z_dim]).astype(np.float32)
            # 每次训练判别器
            for k in range(d_iters):
                # 拿出batch_size张图像进行训练
                train_image = image_batch[j*batch_size:j*batch_size + batch_size]

                # 权值截断
                sess.run(d_clip)

                # 更新discriminator
                _, dLoss = sess.run([trainer_d, d_loss],
                                    feed_dict={random_input: train_noise, real_image: train_image, is_train: True})

            # 更新generator
            for k in range(g_iters):
                _, gLoss = sess.run([trainer_g, g_loss],
                                    feed_dict={random_input: train_noise, is_train: True})
            # 打印generator和discriminator的loss值
            print('train:[%d/%d],d_loss:%f,g_loss:%f' % (i, j, dLoss, gLoss))

        # 每训练10个epoch进行一次保存结果
        if i % 10 == 0:
            # 判断保存结果的文件夹是否存在,若不存在,则创建
            if not os.path.exists(output_dir):
                os.makedirs(output_dir)

            # 随机生成噪声作为输入
            sample_noise = np.random.uniform(-1.0, 1.0, size=[batch_size, z_dim]).astype(np.float32)
            # 根据generator生成结果
            imgtest = sess.run(fake_image, feed_dict={random_input: sample_noise, is_train: False})

            # imgtest的格式转换
            imgtest.astype(np.uint8)

            # 保存结果
            plot_and_save(i, imgtest)
            print('train:[%d],d_loss:%f,g_loss:%f' % (i, dLoss, gLoss))

8.训练

编写完上述文件之后,对模型进行训练,代码非常简单:

if __name__ == '__main__':
    train()

四、实验结果

实验设置了1000个epoch,但是总的说来,实验的结果并不理想,可能是由于数据集的质量太差,不像人脸数据集那样,所有皮卡丘的姿势、形态、在图像中的位置、大小差异非常大,因此训练结果较差。

当分别训练1个epoch,10个epoch,以及50个epoch时,训练结果基本都为噪声,不过能够勉强看到一点颜色的变化:

当分别训练了100个epoch,200个epoch,400个epoch时,结果依然不明显,但是能看到WGAN学到了颜色信息:

当分别训练了600个epoch,800个epoch,1000个epoch时,依然没有生成皮卡丘图像,但是能明显看到WGAN似乎学到了皮卡丘的外形信息,epoch=800时似乎学到了皮卡丘的红脸蛋,epoch=1000时似乎学到了皮卡丘的眼睛:


第二次更新:

昨天训练了10000个epoch,想看一下训练效果,不过事实上效果还是很差。。。可能真的是因为数据集太难训练了 ??

当训练2000个epoch,4000个epoch,8000个epoch时的结果:

当训练9991个epoch的最终结果:

最后的训练结果简直是。。。。。。。画风抽象。。。


第三次更新:

由于皮卡丘的训练效果比较差,这次尝试着利用人像数据进行训练,数据集是特朗普的头像,总计376张,每张的大小都是256*256,在我的文章对抗神经网络学习(二)中有介绍这个数据集,这里直接拿来训练模型。

数据集的图像为:

总体而言,特朗普数据集训练的效果比皮卡丘数据集的训练效果稍好一些,但也并不是特别理想。实验暂时训练了400个epoch,从训练完成之后的生成结果来看,WGAN能够勉强生成五官,但生成图像非常模糊。。。。这至少说明了模型是可以运作的,且模型具有进一步提升的空间。这次训练391个epoch的结果为:


五、分析

1.WGAN没有针对网络结构做大的改进,其优化思路也可借鉴于其他GAN模型。

2.WGAN的作者也提出了关于权重截取的问题,应该有更好的方法能替代权值截取。

3.皮卡丘数据集的质量相对较差,在1000个epoch内难以训练好模型。理论上来说该模型应该是能够生成皮卡丘的,目前还在做进一步改进。

4.关于损失函数。原文中作者提到,对于一般的GAN来说,就是最大化Loss函数:

然而,这里的loss函数是使用的JS距离,这里就会导致一些问题:

This quantity clearly correlates poorly the sample quality. Note also that the JS estimate usually stays constant or goes up instead of going down. In fact it often remains very close to log^2 ≈ 0.69 which is the highest value taken by the JS distance. In other words, the JS distance saturates, the discriminator has zero loss, and the generated samples are in some cases meaningful (DCGAN generator, top right plot) and in other cases collapse to a single nonsensical image [4]. This last phenomenon has been theoretically explained in [1] and highlighted in [11]. 

简单的说,就是有时GAN会产生有意义的图片,但有时也会出现模型崩塌,产生单一的无意义图像。因此作者使用了-log变换。

When using the −logD trick [4], the discriminator loss and the generator loss are different. Figure 8 in Appendix E reports the same plots for GAN training, but using the generator loss instead of the discriminator loss. This does not change the conclusions.

为了使得模型稳定,作者构建了这样一个函数:

同时作者在后面进行了解释:

When F is the set of all measurable functions bounded between -1 and 1 (or all continuous functions between -1 and 1), we retrieve dF(Pr,Pθ) = δ(Pr,Pθ) the total variation distance [15]. This already tells us that going from 1-Lipschitz to 1-Bounded functions drastically changes the topology of the space, and the regularity of dF(Pr,Pθ) as a loss function (as by Theorems 1 and 2).

也就是是说,作者把dF(Pr,Pθ)的正则化作为了loss函数。

5.整个实验文件夹的结构为:

-- pikachu        (数据集文件夹)
        |------ image01.jpg
        |------ image02.jpg
        |------ ......

-- scratch.py     (抓取皮卡丘数据集的程序)
        {
        import ...

        def getManyPages(keyword, pages):...

        def getImg(dataList, localPath):...

        if __name__ == '__main__':...
        }


-- WGAN.py        (WGAN的实现代码)
        {
        from ...
        import...

        image_height=128
        ...
        
        def process_data():...

        def leaky_relu(x, n, leak=0.2):...

        def generator(input, random_dim, is_train, reuse=False):...

        def discriminator(input, is_train, reuse=False):...

        def plot_and_save(num, images):...

        def train():...

        if __name__ == '__main__':...
        }

 

  • 12
    点赞
  • 64
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 26
    评论
WGAN (Wasserstein生成对抗网络)是一种生成对抗网络,其原理图如下: WGAN的核心思想是通过定义和最大化Wasserstein距离来训练生成器和判别器模型。Wasserstein距离是用于衡量两个分布之间的差异的一种距离度量。 在WGAN中,生成器模型G接受一个随机噪声向量作为输入,并生成一个与真实数据分布相似的样本。判别器模型D接受生成生成的样本以及真实数据样本作为输入,并尝试区分出哪些是真实样本,哪些是生成的样本。 WGAN的训练过程分为两个阶段:判别器阶段和生成器阶段。在判别器阶段,我们固定生成器的参数,只更新判别器的参数,通过最小化Wasserstein距离来增强判别器的能力。Wasserstein距离的计算是通过将判别器输出对真实样本的评分减去对生成样本的评分,然后取这些差异的最大值。 在生成器阶段,我们固定判别器的参数,只更新生成器的参数,通过最大化Wasserstein距离来改进生成器的性能。在这个阶段,生成器努力生成样本,使得它们能够获得更高的Wasserstein距离评分。 通过交替进行这两个阶段的训练,WGAN可以逐渐提高生成器的生成能力,使其生成的样本与真实数据更加接近。另外,WGAN还引入了一些技巧以解决传统生成对抗网络训练中的一些不稳定性问题,例如使用权重剪切技术来约束判别器的参数。 总结起来,WGAN通过定义Wasserstein距离来衡量生成器和判别器之间的差异,并通过交替训练这两个模型来改进生成器的生成能力,从而使其生成的样本更接近于真实数据分布。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 26
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

全部梭哈迟早暴富

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值