# tensorflow-005-GAN01




import numpy as np 
import tensorflow as tf 
from scipy.stats import norm
import matplotlib.pyplot as plt

batch_size = 10
hidden_size = 4 

class DataDistribution():
    """One-dimensional normal distribution used as the real-data source.

    Args:
        mu: Mean of the distribution. Defaults to 4.
        sigma: Standard deviation of the distribution. Defaults to 0.5.
    """

    def __init__(self, mu=4, sigma=0.5):
        self.mu = mu
        self.sigma = sigma

    def sample(self, N):
        """Draw ``N`` values from N(mu, sigma) and return them sorted ascending.

        Args:
            N: Number of samples to draw.

        Returns:
            A 1-D numpy array of length ``N`` in ascending order.
        """
        samples = np.random.normal(self.mu, self.sigma, N)
        # In-place sort: callers receive the batch ordered from low to high.
        samples.sort()
        return samples

class GeneratorDistribution:
    """Noise source: evenly spaced points over a symmetric interval plus jitter."""

    def __init__(self, rangex):
        # Half-width of the interval [-rangex, rangex] that sample() spans.
        self.range = rangex

    def sample(self, N):
        """Return ``N`` points spanning [-range, range] with small random offsets.

        Each point of the even grid is shifted by a uniform random amount
        scaled by 0.01.
        """
        grid = np.linspace(-self.range, self.range, N)
        jitter = np.random.random(N) * 0.01
        return grid + jitter

#def minibatch(inputx, num_kernel=5, kernel_dim=3):
#    x = linear(inputx, num_kernel*kernel_dim, scope='minibatch', stddev=0.02)
#    activation = tf.reshape(x, (-1, num_kernel, kernel_dim))
#    diffs = tf.expand_dims(activation, 3) - tf.expand_dims(tf.transpose(activation, [1,2,0]),0)
#    eps = tf.expand_dims( np.eye(batch_size, dtype=np.float32), 1)
#    abs_diffs = tf.reduce_sum(tf.abs(diffs), 2) + eps
#    minibatch_fea = tf.reduce_sum(tf.exp(-abs_diffs), 2)
#    return tf.concat(1, [inputx, minibatch_fea])

def linear(inputx, output_dim, scope=None, stddev=1.0):
    """Build an affine layer ``matmul(inputx, w) + b`` inside a variable scope.

    Args:
        inputx: Input tensor; its second dimension gives the row count of ``w``
            (the code reads ``get_shape()[1]`` and passes it to ``tf.matmul``).
        output_dim: Number of output columns.
        scope: Variable-scope name; ``'linear'`` is used when this is falsy.
        stddev: Standard deviation of the normal initializer for ``w``.

    Returns:
        The tensor ``tf.matmul(inputx, w) + b`` where ``b`` starts at zero.
    """
    scope_name = scope or 'linear'
    with tf.variable_scope(scope_name):
        in_dim = inputx.get_shape()[1]
        weights = tf.get_variable(
            'w', [in_dim, output_dim],
            initializer=tf.random_normal_initializer(stddev=stddev))
        bias = tf.get_variable(
            'b', [output_dim],
            initializer=tf.constant_initializer(0.0))
        return tf.matmul(inputx, weights) + bias

def generator(inputx, hidden_size):
    """Build the generator network.

    A softplus-activated layer of ``hidden_size`` units (scope ``'g0'``)
    followed by a single linear output unit (scope ``'g1'``).

    Args:
        inputx: Input tensor fed to the first layer.
        hidden_size: Width of the hidden layer.

    Returns:
        The un-activated output of the ``'g1'`` layer.
    """
    hidden = linear(inputx, hidden_size, 'g0')
    hidden = tf.nn.softplus(hidden)
    return linear(hidden, 1, 'g1')

def discriminator(inputx, hidden_size):
    """Build the discriminator network.

    Three tanh layers of width ``hidden_size * 2`` (scopes ``'d0'``, ``'d1'``,
    ``'d2'``) followed by a single sigmoid unit (scope ``'d3'``).

    Args:
        inputx: Input tensor fed to the first layer.
        hidden_size: Half the width of each hidden layer.

    Returns:
        The sigmoid output of the ``'d3'`` layer.
    """
    width = hidden_size * 2
    hidden = inputx
    # The third tanh layer sits where the commented-out minibatch layer
    # (h2 = minibatch(h1)) would otherwise go.
    for layer_scope in ('d0', 'd1', 'd2'):
        hidden = tf.tanh(linear(hidden, width, layer_scope))
    return tf.sigmoid(linear(hidden, 1, 'd3'))

def optimizer(loss, var_list):
    """Create a gradient-descent training op with a staircase-decayed rate.

    The learning rate starts at 0.005 and is multiplied by 0.95 every
    200 steps. The step counter is a fresh non-trainable variable that the
    returned op increments.

    Args:
        loss: Tensor to minimize.
        var_list: Variables the op is allowed to update.

    Returns:
        The op produced by ``GradientDescentOptimizer.minimize``.
    """
    step_counter = tf.Variable(0, trainable=False)
    decayed_rate = tf.train.exponential_decay(
        0.005,         # initial learning rate
        step_counter,
        200,           # number of steps between decays
        0.95,          # decay factor
        staircase=True)
    sgd = tf.train.GradientDescentOptimizer(decayed_rate)
    return sgd.minimize(loss, global_step=step_counter, var_list=var_list)


# Pre-training graph: a separate discriminator (variables under 'D_pre/')
# fitted to supplied labels with a mean-squared-error loss.
with tf.variable_scope('D_pre'):
    pre_input = tf.placeholder(tf.float32, shape=[None, 1])
    pre_labels = tf.placeholder(tf.float32, shape=[None, 1])
    D_pre = discriminator(pre_input, hidden_size)
    pre_loss = tf.reduce_mean(tf.square(D_pre-pre_labels))
    # The pre-training op is created further down as opt_d_pre, restricted
    # to the 'D_pre/' variables:
    #pre_opt = optimizer(pre_loss, None)

# Generator: maps the noise placeholder z to generated values G.
with tf.variable_scope('G'):
    z = tf.placeholder(tf.float32, shape=[None, 1])
    G = generator(z, hidden_size)

# Discriminator applied twice with shared weights: D1 scores the data
# placeholder x, D2 scores the generator output.
with tf.variable_scope('D') as scope:
    x = tf.placeholder(tf.float32, shape=[None, 1])
    D1 = discriminator(x, hidden_size)
    # Reuse the variables created for D1 so D2 does not create a second set.
    scope.reuse_variables()
    D2 = discriminator(G, hidden_size)

# Discriminator loss: mean of -log(D1) - log(1 - D2).
# Generator loss: mean of -log(D2).
loss_d = tf.reduce_mean(-tf.log(D1)-tf.log(1-D2))
loss_g = tf.reduce_mean(-tf.log(D2))

# Split the trainable variables by the scope prefix of their names so each
# optimizer only updates its own network.
var = tf.trainable_variables()
d_pre_params = [v for v in var if v.name.startswith('D_pre/')]
d_params = [v for v in var if v.name.startswith('D/')]
g_params = [v for v in var if v.name.startswith('G/')]

# One training op per loss; each call to optimizer() creates its own step
# counter and decayed learning rate.
opt_d = optimizer(loss_d, d_params)
opt_g = optimizer(loss_g, g_params)
opt_d_pre = optimizer(pre_loss, d_pre_params)

# Run pre-training, adversarial training and the final plot in one session.
# The print calls use a single pre-formatted string and the loops use range(),
# so the script produces the same output on Python 2 and Python 3.
with tf.Session() as sess:

    sess.run(tf.initialize_all_variables())
    data = DataDistribution()
    gen = GeneratorDistribution(8)

    anim_frames = []

    # Stage 1: pre-train D_pre to reproduce the normal pdf of the data
    # distribution at points drawn uniformly from [-5, 5).
    num_pretrain_steps = 1000
    for step in range(num_pretrain_steps):
        d = (np.random.random(batch_size)-0.5)*10.0
        labels = norm.pdf(d, data.mu, data.sigma)
        feed_dict = {pre_input: np.reshape(d, [batch_size, 1]),
                     pre_labels: np.reshape(labels, [batch_size, 1])}

        pretrain_loss, _ = sess.run([pre_loss, opt_d_pre], feed_dict=feed_dict)

        if step % 100 == 0:
            print('%s %s' % (step, pretrain_loss))

    weightsD = sess.run(d_pre_params)
    print(len(weightsD))
    print(weightsD[0])

    # Copy the pre-trained weights into the adversarial discriminator.
    # Pairs variables positionally; assumes d_pre_params and d_params list
    # the layers in the same order.
    for target, value in zip(d_params, weightsD):
        sess.run(target.assign(value))

    print('done')

    # Stage 2: alternate one discriminator step and one generator step.
    for step in range(2500):

        xx = data.sample(batch_size)
        zz = gen.sample(batch_size)
        feed_dictx = {x: np.reshape(xx, [batch_size, 1]),
                      z: np.reshape(zz, [batch_size, 1])}
        lossd, _ = sess.run([loss_d, opt_d], feed_dict=feed_dictx)

        # A fresh noise batch is drawn for the generator update.
        zz = gen.sample(batch_size)
        lossg, _ = sess.run([loss_g, opt_g], feed_dict={z: np.reshape(zz, [batch_size, 1])})

        if step % 100 == 0:
            print('%s %s %s' % (step, lossd, lossg))

    # PLOT
    num_points = 10000
    num_bins = 1000
    xs = np.linspace(-8, 8, num_points)
    bins = np.linspace(-8, 8, num_bins)

    # Discriminator output over [-8, 8], evaluated batch_size points at a time.
    db = np.zeros([num_points, 1])
    for i in range(num_points // batch_size):
        feed_dict = {x: np.reshape(xs[batch_size*i: batch_size*(i+1)], [batch_size, 1])}
        db[batch_size*i: batch_size*(i+1)] = sess.run(D1, feed_dict=feed_dict)

    # Histogram (density) of samples from the real data distribution.
    d = data.sample(num_points)
    pd, _ = np.histogram(d, bins=bins, density=True)

    # Histogram (density) of generator outputs for evenly spaced noise inputs.
    zs = np.linspace(-8, 8, num_points)
    g = np.zeros((num_points, 1))
    for i in range(num_points // batch_size):
        feed_dict = {z: np.reshape(zs[batch_size*i: batch_size*(i+1)], [batch_size, 1])}
        g[batch_size*i: batch_size*(i+1)] = sess.run(G, feed_dict=feed_dict)

    pg, _ = np.histogram(g, bins=bins, density=True)

    db_x = np.linspace(-8, 8, len(db))
    p_x = np.linspace(-8, 8, len(pd))

    anim_frames.append(([db_x, db], [p_x, pd], [p_x, pg]))

    plt.plot(db_x, db, label='decision boundary')
    plt.plot(p_x, pd, label='real data')
    plt.plot(p_x, pg, label='generated data')
    plt.title('GAN')
    plt.xlabel('data value')
    plt.ylabel('prob density')
    plt.legend()
    plt.show()


  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值