使用tensorflow实现conditional-gan简单小demo

                                                      图(1)Conditional-GAN结构图

conditional-gan 简称 CGAN,它是GAN的一种变形。它的结构可以由原始的GAN结构做简单的变换得到。

如上图(1)所示,CGAN与GAN的不同之处在于:

从生成器G的角度而言:

GAN的生成器G是直接通过随机噪声来拟合真实图片的分布从而产生伪图片;而CGAN添加上了条件(文字描述t),我们可以使用函数φ将其转化为φ(t),并将条件φ(t)与噪声Z相拼接从而得到生成器的输入,生成器G会通过反卷积的过程将其变为伪图片。

从鉴别器的角度而言:

GAN的鉴别器D是直接输入待鉴别图片并给出打分;而CGAN不仅输入待鉴别图片,还将待鉴别图片与条件相连接得到(x, φ(t)),从而通过鉴别器给出鉴别结果。

 

下面是使用tensorflow进行实现的小demo,该demo并未使用卷积与反卷积,而是通过全连接层来实现!

 

 

import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np

tf.set_random_seed(1)
np.random.seed(1)

BATCH_SIZE = 64
LR_G = 0.0001
LR_D = 0.0001
N_IDEAS = 5
ART_COMPONENTS = 15
PAINT_POINTS = np.vstack([np.linspace(-1,1,ART_COMPONENTS)for _ in range(BATCH_SIZE)])                                                #shape = (64,15)

plt.plot(PAINT_POINTS[0],2*np.power(PAINT_POINTS[0],2)+1,c = '#74BCFF',lw = 3,label='upper bound')
plt.plot(PAINT_POINTS[0],1*np.power(PAINT_POINTS[0],2)+0,c = '#FF9359',lw = 3,label='lower bound')
plt.legend(loc = 'upper right')
plt.show()


def artist_works():
    a = np.random.uniform(1,2,size=BATCH_SIZE)[:,np.newaxis]   #shape = (64,1)
    paintings = a*np.power(PAINT_POINTS,2)+(a-1)               #shape = (64,15)
    labels = (a-1)>0.5
    labels = labels.astype(np.float32)
    return paintings,labels

art_labels = tf.placeholder(tf.float32,[None,1])               #shape = (64,1)
with tf.variable_scope('Generator'):
    G_in = tf.placeholder(tf.float32,[None,N_IDEAS])           #shape = (64,5)
    G_art = tf.concat((G_in,art_labels),1)                    #shape = (64,6)
    G_l1 = tf.layers.dense(G_art,128,tf.nn.relu)                #shape = (64,128)
    G_out = tf.layers.dense(G_l1,ART_COMPONENTS)               #shape = (64,15)


with tf.variable_scope('Discriminator'):
    real_in = tf.placeholder(tf.float32,[None,ART_COMPONENTS],name='real_in')
    real_art = tf.concat((real_in,art_labels),1)               
    D_l0 = tf.layers.dense(real_art,128,tf.nn.relu,name='1')
    prob_artist0 = tf.layers.dense(D_l0,1,tf.nn.sigmoid,name='out')

    #fake art
    G_art = tf.concat((G_out,art_labels),1)
    D_l1 = tf.layers.dense(G_art,128,tf.nn.relu,name='1',reuse=True)
    prob_artist1 = tf.layers.dense(D_l1,1,tf.nn.sigmoid,name='out',reuse=True)

D_loss = -tf.reduce_mean(tf.log(prob_artist0)+tf.log(1-prob_artist1))
G_loss = tf.reduce_mean(tf.log(1-prob_artist1))

train_D = tf.train.AdamOptimizer(LR_D).minimize(
       D_loss,var_list=tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,scope='Discriminator'))

train_G = tf.train.AdamOptimizer(LR_G).minimize(
       G_loss,var_list=tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,scope='Generator'))

sess= tf.Session()
sess.run(tf.global_variables_initializer())

plt.ion()
for step in range(7000):
    artist_paintings,labels = artist_works()
    G_ideas = np.random.randn(BATCH_SIZE,N_IDEAS)
    G_paintings,pa0,D1 = sess.run([G_out,prob_artist0,D_loss,train_D,train_G],
                                  {G_in:G_ideas,real_in:artist_paintings,art_labels:labels})[:3]

    if step%50==0:
        plt.cla()
        plt.plot(PAINT_POINTS[0],G_paintings[0],c='#4AD631',lw=3,label='Generated painting')
        bound = [0,0.5] if labels[0,0] == 0 else [0.5,1]
        plt.plot(PAINT_POINTS[0],2*np.power(PAINT_POINTS[0],2)+bound[1],c='#74BCFF',lw=3,label='upper bound')
        plt.plot(PAINT_POINTS[0],1*np.power(PAINT_POINTS[0],2)+bound[0],c='#FF9359',lw=3,label='lower bound')
        plt.text(-.5,2.3,'D accuracy=%.2f (0.5 for D to converge)'%pa0.mean(),fontdict={'size':15})
        plt.text(-.5,2,'D score=%.2f (-1.38 for G to converge)'%-D1,fontdict={'size':15})
        plt.text(-.5,1.7,'Class = %i' % int(labels[0,0]),fontdict={'size':15})
        plt.ylim((0,3));plt.legend(loc='upper right',fontsize=12);plt.draw();plt.pause(0.01)

plt.ioff()
plt.show()


实验结果如下图所示:

 

  • 3
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值