1 李宏毅生成对抗网络学习———GAN的直观了解

https://www.bilibili.com/video/av24011528/?p=1
ppt地址:http://speech.ee.ntu.edu.tw/~tlkagk/courses_MLDS17.html
csdn下载:
https://download.csdn.net/download/qq_35608277/10773478

1 定义

GAN 主要包括了两个部分,即生成器 generator 与判别器 discriminator。生成器主要用来学习真实图像分布从而让自身生成的图像更加真实,以骗过判别器。判别器则需要对接收的图片进行真假判别。

形象解释:

造假币的团伙相当于生成器,他们想通过伪造金钱来骗过银行,使得假币能够正常交易,而银行相当于判别器,需要判断进来的钱是真钱还是假币。因此假币团伙的目的是要造出银行识别不出的假币而骗过银行,银行则是要想办法准确地识别出假币。

love and peace 版本:
小学生学画画,从一年级到五年级,学生画画水平越来越高,老师要求也越来越高。
宿敌关系------佐助和鸣人,进藤光和塔矢亮…

总之是不断相互促进的效果。

回到GAN

对于给定的真实图片(real image),判别器要为其打上标签 1;
对于给定的生成图片(fake image),判别器要为其打上标签 0;(训练D的阶段,D的判别更强,对上一代G能够区分)

对于生成器传给辨别器的生成图片,生成器希望辨别器打上标签 1。(训练G的阶段,对上一代D判别器可以做到欺骗)

那么GAN是如何来做的呢?首先,我们有一个第一代的Generator,然后他产生一些图片,然后我们把这些图片和一些真实的图片丢到第一代的Discriminator里面去学习,让第一代的Discriminator能够真实的分辨生成的图片和真实的图片,然后我们又有了第二代的Generator,第二代的Generator产生的图片,能够骗过第一代的Discriminator,此时,我们在训练第二代的Discriminator,依次类推。

2 对应训练算法

在这里插入图片描述

3 问题及解答

  1. 为什么不能用生成器(NN)或者AE(自编码)直接生成,反而需要借助判别器?

生成器属于bottom-up架构,注重由部件到组成整体,缺乏整体的“大局观”。
以生成图片为例,相对独立的生成像素,对于相邻像素之间的关联性考虑不足。通过增加网络层数可以改善(增加关联),但同等规模的网络,GAN效果更好。

  1. 可不可以用判别器直接生成图像,而不使用生成器?(如某些名嘴,可以指点江山擅长批评,却无法提出创造性建议)

原则上是可以的。通过枚举的方式,找到打分最高的那张图,就是判别器生成的图。
判别器更擅长评估整体,是top-down的架构,比如生成图片中孤立的像素点更为敏感,能够指出其不合理。

实际上,问题的关键就在于如何产生真实的假图。如果只有数据集的真图,训练出来的判别器只会打高分。因为他遇到的所有图都是高分,所以形成所遇即高分的判断。

因此,通过生成器,来去逼近真实的假图分布,近似求解argmaxD(X)问题。

4 code

MNIST数据集比较好找。
结构
在这里插入图片描述

# -*- coding: utf-8 -*-
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os

#读入数据
mnist = input_data.read_data_sets('./mnist', one_hot=True)#代码和数据集文件夹放在同一目录下

#从正态分布输出随机值
def xavier_init(size):
    in_dim = size[0]
    xavier_stddev = 1. / tf.sqrt(in_dim / 2.)
    return tf.random_normal(shape=size, stddev=xavier_stddev)
	 

#判别模型的输入和参数初始化
X = tf.placeholder(tf.float32, shape=[None, 784])

D_W1 = tf.Variable(xavier_init([784, 128]))
D_b1 = tf.Variable(tf.zeros(shape=[128]))

D_W2 = tf.Variable(xavier_init([128, 1]))
D_b2 = tf.Variable(tf.zeros(shape=[1]))

theta_D = [D_W1, D_W2, D_b1, D_b2]

#生成模型的输入和参数初始化
Z = tf.placeholder(tf.float32, shape=[None, 100])

G_W1 = tf.Variable(xavier_init([100, 128]))
G_b1 = tf.Variable(tf.zeros(shape=[128]))

G_W2 = tf.Variable(xavier_init([128, 784]))
G_b2 = tf.Variable(tf.zeros(shape=[784]))

theta_G = [G_W1, G_W2, G_b1, G_b2]

#随机噪声采样函数
def sample_Z(m, n):
    return np.random.uniform(-1., 1., size=[m, n])

#生成模型
def generator(z):
    G_h1 = tf.nn.relu(tf.matmul(z, G_W1) + G_b1)
    G_log_prob = tf.matmul(G_h1, G_W2) + G_b2
    G_prob = tf.nn.sigmoid(G_log_prob)

    return G_prob

#判别模型
def discriminator(x):
    D_h1 = tf.nn.relu(tf.matmul(x, D_W1) + D_b1)
    D_logit = tf.matmul(D_h1, D_W2) + D_b2
    D_prob = tf.nn.sigmoid(D_logit)

    return D_prob, D_logit

#画图函数
def plot(samples):
    fig = plt.figure(figsize=(4, 4))
    gs = gridspec.GridSpec(4, 4)
    gs.update(wspace=0.05, hspace=0.05)

    for i, sample in enumerate(samples):
        ax = plt.subplot(gs[i])
        plt.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        plt.imshow(sample.reshape(28, 28), cmap='Greys_r')

    return fig

#喂入数据
G_sample = generator(Z)
D_real, D_logit_real = discriminator(X)
D_fake, D_logit_fake = discriminator(G_sample)

# 计算losses:
D_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_real, labels=tf.ones_like(D_logit_real))) 
D_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_fake, labels=tf.zeros_like(D_logit_fake))) 
D_loss = D_loss_real + D_loss_fake

#label是真,意思是新的G可以骗过上一代D,最小化
G_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_logit_fake, labels=tf.ones_like(D_logit_fake))) 

#为什么最小化,不应该是最大化吗???解释:使用的交叉熵做评价,就是比较labels和logits的相近程度,越小越好
D_solver = tf.train.AdamOptimizer().minimize(D_loss, var_list=theta_D)

G_solver = tf.train.AdamOptimizer().minimize(G_loss, var_list=theta_G)

mb_size = 128
Z_dim = 100


sess = tf.Session()
sess.run(tf.global_variables_initializer())

if not os.path.exists('out/'):
    os.makedirs('out/')

i = 0

#开始训练
axis_x=[]
axis_D_loss_curr=[]
axis_G_loss_curr=[]

for it in range(10000):#1000 000
    if it % 1000 == 0:
        samples = sess.run(G_sample, feed_dict={Z: sample_Z(16, Z_dim)})

        fig = plot(samples)
        plt.savefig('out/{}.png'.format(str(i).zfill(3)), bbox_inches='tight')
        i += 1
        plt.close(fig)

    X_mb, _ = mnist.train.next_batch(mb_size)

    _, D_loss_curr = sess.run([D_solver, D_loss], feed_dict={X: X_mb, Z: sample_Z(mb_size, Z_dim)})##feed 真实图和噪声。
    _, G_loss_curr = sess.run([G_solver, G_loss], feed_dict={Z: sample_Z(mb_size, Z_dim)})###feed噪声,通过G产生生成图

    if it % 1000 == 0:
        axis_x.append(it)
        axis_D_loss_curr.append(D_loss_curr)
        axis_G_loss_curr.append(G_loss_curr)
        print('Iter: {}'.format(it))
        print('D loss: {:.4}'. format(D_loss_curr))
        print('G_loss: {:.4}'.format(G_loss_curr))
       

plt.plot(axis_x, axis_D_loss_curr, label='D_loss_curr')
plt.plot(axis_x, axis_G_loss_curr, label='G_loss_curr')

在这里插入图片描述
在这里插入图片描述

在这里插入图片描述

代码学习可以看:
https://study.163.com/course/courseLearn.htm?courseId=1005703030#/learn/video?lessonId=1052988014&courseId=1005703030

ref
https://www.jianshu.com/p/40feb1aa642a

  • 2
    点赞
  • 11
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值