DRAW: DeepMind's RNN-Based Model for Image Generation

Original source: https://github.com/shugert/DRAW

This is a simple Python 3 port of that code, provided here for reference.

Having read a couple of write-ups, here is a brief summary of my understanding of the algorithm. To keep things digestible, the discussion is organized around a few questions.

What is the dataset?

Both the training and test sets in the paper are based on MNIST, but the model can also be trained on other datasets, for example for face generation or image-style generation.

What is the algorithm?

The DRAW network combines a novel spatial attention mechanism with a sequential variational auto-encoding framework, allowing complex images to be constructed iteratively. At its core is a pair of RNNs: an encoder that compresses the real images seen during training, and a decoder that reconstructs images after receiving the codes. The system is trained end-to-end with stochastic gradient descent, where the loss function is a variational upper bound on the negative log-likelihood of the data.
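For reference, the per-timestep recurrence from the DRAW paper is the following (read and write are the attention operations, and [·,·] denotes concatenation):

\hat{x}_t = x - \sigma(c_{t-1})
r_t = read(x, \hat{x}_t, h^{dec}_{t-1})
h^{enc}_t = RNN^{enc}(h^{enc}_{t-1}, [r_t, h^{dec}_{t-1}])
z_t \sim Q(Z_t | h^{enc}_t)
h^{dec}_t = RNN^{dec}(h^{dec}_{t-1}, z_t)
c_t = c_{t-1} + write(h^{dec}_t)

These six steps are exactly what the unrolled loop in the code below implements.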

The difference between a conventional variational autoencoder and DRAW:

[Figure: a conventional variational autoencoder compared with DRAW]

In DRAW, both the encoder and the decoder are RNNs, and the decoder's outputs are successively added to the distribution that ultimately generates the data, rather than emitting that distribution in a single step. A dynamically updated attention mechanism restricts both the input region observed by the encoder and the output region updated by the decoder. In other words, at every time step the network decides "where to read", "where to write", and "what to write".

The network's total loss is the expectation of the sum of the reconstruction loss and the latent loss: L = E_{z \sim Q}[L^x + L^z]
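Here L^x is the reconstruction loss, the negative log likelihood of the image under the final canvas, and L^z is the latent loss, the KL divergence from the prior accumulated over all timesteps. For a standard-normal prior the paper gives:

L^x = -\log D(x | c_T)
L^z = \frac{1}{2} \sum_{t=1}^{T} (\mu_t^2 + \sigma_t^2 - \log \sigma_t^2) - T/2

These two terms correspond directly to generation_loss and latent_loss in the code below.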

To determine which part of the image matters most, the network needs to make observations and decide based on them. In DRAW, the decoder hidden state from the previous time step is passed through a simple fully connected layer that maps it to three parameters defining a square crop: center X, center Y, and scale.

Although the attention mechanism can be intuitively described as a kind of cropping, in practice a different approach is used. While the model structure described above remains accurate, instead of cropping, DRAW applies matrices of Gaussian filters whose centers are evenly spaced.
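As a concrete illustration, here is a minimal NumPy sketch of how one horizontal filterbank matrix Fx is built (Fy is analogous). The grid size N, image width A, and attention parameters gx, delta, sigma2 are made-up values here, not learned ones:

import numpy as np

N, A = 3, 8                          # 3x3 glimpse over an 8-pixel-wide image
gx, delta, sigma2 = 4.0, 2.0, 1.0    # illustrative center, stride, variance

i = np.arange(N)                              # filter index 0..N-1
mu_x = gx + (i - N / 2 - 0.5) * delta         # evenly spaced filter centers
a = np.arange(A)                              # pixel positions 0..A-1
Fx = np.exp(-((a[None, :] - mu_x[:, None]) ** 2) / (2 * sigma2))
Fx /= np.maximum(Fx.sum(axis=1, keepdims=True), 1e-8)  # rows sum to 1

print(Fx.shape)        # (3, 8): one Gaussian row per grid point

A glimpse is then extracted as Fy · x · Fx^T, which is exactly what read_attention in the code below does with batched matrix multiplies.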

What problem does the algorithm solve?

The central idea of the DRAW network is to optimize the image step by step, continually refining the canvas.

How does the algorithm differ from similar algorithms?

In the DRAW model, P(C_t|C_{t-1}) is the same distribution for every t, so we can express it as a recurrence relation (if it were not, we would have a Markov chain rather than a recurrent network). If the latent code in a variational autoencoder (VAE) is viewed as a vector describing the entire image, then the latent codes in DRAW can be viewed as vectors describing strokes. Finally, the sequence of these vectors reproduces the original image.
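To make this contrast concrete, here is a toy, self-contained sketch; the shapes and the random stand-in for the write weights are hypothetical and not taken from the model code below:

import numpy as np

rng = np.random.default_rng(0)
T, latent_dim, image_dim = 10, 10, 784                    # hypothetical sizes
W = rng.normal(scale=0.01, size=(latent_dim, image_dim))  # stand-in write weights

# VAE-style: one latent vector decoded in a single shot
z = rng.normal(size=latent_dim)
x_vae = 1.0 / (1.0 + np.exp(-(z @ W)))

# DRAW-style: a sequence of latent "stroke" vectors accumulated on a canvas
canvas = np.zeros(image_dim)
for t in range(T):
    z_t = rng.normal(size=latent_dim)
    canvas += z_t @ W                       # additive canvas update
x_draw = 1.0 / (1.0 + np.exp(-canvas))      # sigmoid applied once at the end

print(x_vae.shape, x_draw.shape)            # (784,) (784,)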

 

# import libraries
import tensorflow as tf
from tensorflow.examples.tutorials import mnist
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np
from PIL import Image
import os

# workaround for a duplicated OpenMP runtime on macOS
os.environ['KMP_DUPLICATE_LIB_OK']='True'

# aliases for ops renamed in newer TensorFlow 1.x releases
tf.pack = tf.stack
tf.select = tf.where
tf.batch_matmul = tf.matmul

# define fully-connected layer
def dense(x, inputFeatures, outputFeatures, scope=None, with_w=False):
    with tf.variable_scope(scope or "Linear"):
        matrix = tf.get_variable("Matrix", [inputFeatures, outputFeatures], tf.float32, tf.random_normal_initializer(stddev=0.02))
        bias = tf.get_variable("bias", [outputFeatures], initializer=tf.constant_initializer(0.0))
        if with_w:
            return tf.matmul(x, matrix) + bias, matrix, bias
        else:
            return tf.matmul(x, matrix) + bias

# merge image
def merge(images, size):
    h, w = images.shape[1], images.shape[2]
    img = np.zeros((h * size[0], w * size[1]))
    for idx, image in enumerate(images):
        i = int(idx % size[1])
        j = int(idx / size[1])
        img[j*h:j*h+h, i*w:i*w+w] = image
    return img

# save image
def save(filename, image_array):
    Image.fromarray((image_array*255).astype('uint8'), mode='L').convert('RGB').save(filename)

# DRAW implementation
class draw_model():
    def __init__(self):
        # 1. download MNIST dataset to our folder
        self.mnist = input_data.read_data_sets("data/", one_hot=True)
        self.n_samples = self.mnist.train.num_examples

        # 2. set model parameters
        # image width and height
        self.image_size = 28
        # read glimpse grid width/height
        self.attention_n = 5
        # number of hidden units / output size in LSTM
        self.n_hidden = 256
        # QSampler output size (dimension of latent z)
        self.n_output = 10
        # MNIST generation sequence length
        self.sequence_length = 10
        # training minibatch size
        self.batch_size = 64
        # workaround for variable_scope(reuse=True)
        self.share_parameters = False

        # 3. build model
        # input (batch_size x image_size**2 = 784)
        self.images = tf.placeholder(tf.float32, [None, 784])
        # Qsample noise
        self.noise = tf.random_normal((self.batch_size, self.n_output), mean=0, stddev=1)
        # encoder LSTM cell
        self.lstm_enc = tf.nn.rnn_cell.LSTMCell(self.n_hidden, state_is_tuple=True)
        # decoder LSTM cell
        self.lstm_dec = tf.nn.rnn_cell.LSTMCell(self.n_hidden, state_is_tuple=True)

        # 4. define state variables
        # sequence of canvases
        self.seq_canvas = [0] * self.sequence_length
        self.mu, self.logsigma, self.sigma = [0] * self.sequence_length, [0] * self.sequence_length, [0] * self.sequence_length

        # 5. initialize states
        h_dec_prev = tf.zeros((self.batch_size, self.n_hidden))
        enc_state = self.lstm_enc.zero_state(self.batch_size, tf.float32)
        dec_state = self.lstm_dec.zero_state(self.batch_size, tf.float32)

        # 6. construct the unrolled computational graph
        x = self.images
        for t in range(self.sequence_length):
            # previous canvas, and the error image x_hat = x - sigmoid(c_prev)
            c_prev = tf.zeros((self.batch_size, self.image_size**2)) if t == 0 else self.seq_canvas[t-1]
            x_hat = x - tf.sigmoid(c_prev)
            # read the image
            r = self.read_attention(x, x_hat, h_dec_prev)
            # sanity check
            print(r.get_shape())
            # encode to Gaussian distribution parameters (mu, log sigma, sigma)
            self.mu[t], self.logsigma[t], self.sigma[t], enc_state = self.encode(enc_state, tf.concat([r, h_dec_prev], 1))
            # sample from the distribution to get z
            z = self.sampleQ(self.mu[t], self.sigma[t])
            # sanity check
            print(z.get_shape())
            # retrieve the hidden layer of RNN
            h_dec, dec_state = self.decode_layer(dec_state, z)
            # sanity check
            print(h_dec.get_shape())
            # map from hidden layer
            self.seq_canvas[t] = c_prev + self.write_attention(h_dec)
            h_dec_prev = h_dec
            # from now on, share variables
            self.share_parameters = True

        # 7. loss function
        self.generated_images = tf.nn.sigmoid(self.seq_canvas[-1])
        self.generation_loss = tf.reduce_mean(-tf.reduce_sum(self.images * tf.log(1e-10 + self.generated_images) + (1 - self.images) * tf.log(1e-10 + 1 - self.generated_images), 1))

        kl_terms = [0] * self.sequence_length
        for t in range(self.sequence_length):
            mu2 = tf.square(self.mu[t])
            sigma2 = tf.square(self.sigma[t])
            logsigma = self.logsigma[t]
            # each KL term is (1 x minibatch); the -T/2 term is a constant offset
            # and does not affect gradients
            kl_terms[t] = 0.5 * tf.reduce_sum(mu2 + sigma2 - 2 * logsigma, 1) - self.sequence_length * 0.5
        self.latent_loss = tf.reduce_mean(tf.add_n(kl_terms))
        self.cost = self.generation_loss + self.latent_loss

        # 8. optimization
        optimizer = tf.train.AdamOptimizer(1e-3, beta1=0.5)
        grads = optimizer.compute_gradients(self.cost)
        for i, (g, v) in enumerate(grads):
            if g is not None:
                grads[i] = (tf.clip_by_norm(g, 5), v)
        self.train_op = optimizer.apply_gradients(grads)

        # 9. initialize session
        self.sess = tf.Session()
        self.sess.run(tf.global_variables_initializer())

    # training function
    def train(self):
        # make sure the snapshot output directory exists
        os.makedirs("result", exist_ok=True)
        for i in range(20000):
            xtrain, _ = self.mnist.train.next_batch(self.batch_size)
            seq_canvas, gen_loss, lat_loss, _ = self.sess.run([self.seq_canvas, self.generation_loss, self.latent_loss, self.train_op], feed_dict={self.images: xtrain})
            print("iter: %d genloss: %f latloss: %f" % (i, gen_loss, lat_loss))
            if i % 500 == 0:
                seq_canvas = 1.0 / (1.0 + np.exp(-np.array(seq_canvas)))
                for sc_iter in range(10):
                    results = seq_canvas[sc_iter]
                    results_square = np.reshape(results, [-1, 28, 28])
                    print(results_square.shape)
                    save("result/" + str(i) + "-step--" + str(sc_iter) + ".jpg", merge(results_square, [8,8]))

    # Eric Jang's main functions
    # ------------------------------
    # locate where to put attention filters on hidden layers
    def attn_window(self, scope, h_dec):
        with tf.variable_scope(scope, reuse=self.share_parameters):
            parameters = dense(h_dec, self.n_hidden, 5)
        # center of 2d gaussian on a scale of -1 to 1
        gx_, gy_, log_sigma2, log_delta, log_gamma = tf.split(parameters, 5, 1)
        # move gx/gy to be a scale of -imagesize to +imagesize
        gx = (self.image_size + 1) / 2 * (gx_ + 1)
        gy = (self.image_size + 1) / 2 * (gy_ + 1)

        sigma2 = tf.exp(log_sigma2)
        # stride between filter centers on the attention grid
        delta = (self.image_size - 1) / (self.attention_n - 1) * tf.exp(log_delta)
        # returns [Fx, Fy, gamma]
        return self.filterbank(gx, gy, sigma2, delta) + (tf.exp(log_gamma),)

    # construct patches of gaussian filters
    def filterbank(self, gx, gy, sigma2, delta):
        # 1 x N, looks like [[0, 1, 2, 3, 4]]
        grid_i = tf.reshape(tf.cast(tf.range(self.attention_n), tf.float32), [1, -1])
        # individual patch centers
        mu_x = gx + (grid_i - self.attention_n / 2 - 0.5) * delta
        mu_y = gy + (grid_i - self.attention_n / 2 - 0.5) * delta
        mu_x = tf.reshape(mu_x, [-1, self.attention_n, 1])
        mu_y = tf.reshape(mu_y, [-1, self.attention_n, 1])
        # 1 * 1 * imagesize, looks like [[[0, 1, 2, 3, 4, ..., 27]]]
        im = tf.reshape(tf.cast(tf.range(self.image_size), tf.float32), [1, 1, -1])
        # list of gaussian curves for x and y
        sigma2 = tf.reshape(sigma2, [-1, 1, 1])
        Fx = tf.exp(-tf.square((im - mu_x) / (2 * sigma2)))
        Fy = tf.exp(-tf.square((im - mu_y) / (2 * sigma2)))
        # normalize area-under-curve
        Fx = Fx / tf.maximum(tf.reduce_sum(Fx, 2, keep_dims=True), 1e-8)
        Fy = Fy / tf.maximum(tf.reduce_sum(Fy, 2, keep_dims=True), 1e-8)
        return Fx, Fy

    # read operation without attention
    def read_basic(self, x, x_hat, h_dec_prev):
        return tf.concat([x_hat, x], 1)

    # read operation with attention
    def read_attention(self, x, x_hat, h_dec_prev):
        Fx, Fy, gamma = self.attn_window("read", h_dec_prev)
        # apply parameters for patch of gaussian filters
        def filter_img(img, Fx, Fy, gamma):
            Fxt = tf.transpose(Fx, perm=[0, 2, 1])
            img = tf.reshape(img, [-1, self.image_size, self.image_size])
            # apply the gaussian patches
            glimpse = tf.batch_matmul(Fy, tf.batch_matmul(img, Fxt))
            glimpse = tf.reshape(glimpse, [-1, self.attention_n**2])
            # scale using the gamma parameter
            return glimpse * tf.reshape(gamma, [-1, 1])

        x = filter_img(x, Fx, Fy, gamma)
        x_hat = filter_img(x_hat, Fx, Fy, gamma)
        return tf.concat([x, x_hat], 1)

    # encoder function for attention patch
    def encode(self, prev_state, image):
        # update the RNN with our image
        with tf.variable_scope("encoder", reuse=self.share_parameters):
            hidden_layer, next_state = self.lstm_enc(image, prev_state)

        # map the RNN hidden state to latent variables
        with tf.variable_scope("mu", reuse=self.share_parameters):
            mu = dense(hidden_layer, self.n_hidden, self.n_output)
        with tf.variable_scope("sigma", reuse=self.share_parameters):
            logsigma = dense(hidden_layer, self.n_hidden, self.n_output)
            sigma = tf.exp(logsigma)
        return mu, logsigma, sigma, next_state

    # sample z via the reparameterization trick (mu + sigma * noise)
    def sampleQ(self, mu, sigma):
        return mu + sigma*self.noise

    # decoder function
    def decode_layer(self, prev_state, latent):
        # update decoder RNN using our latent variable
        with tf.variable_scope("decoder", reuse=self.share_parameters):
            hidden_layer, next_state = self.lstm_dec(latent, prev_state)
        return hidden_layer, next_state
    
    # write operation without attention
    def write_basic(self, hidden_layer):
        # map RNN hidden state to image
        with tf.variable_scope("write", reuse=self.share_parameters):
            decoded_image_portion = dense(hidden_layer, self.n_hidden, self.image_size**2)
        return decoded_image_portion

    # write operation with attention
    def write_attention(self, hidden_layer):
        with tf.variable_scope("writefW", reuse=self.share_parameters):
            w = dense(hidden_layer, self.n_hidden, self.attention_n**2)

        w = tf.reshape(w, [self.batch_size, self.attention_n, self.attention_n])
        Fx, Fy, gamma = self.attn_window("write", hidden_layer)
        Fyt = tf.transpose(Fy, perm=[0, 2, 1])
        wr = tf.batch_matmul(Fyt, tf.batch_matmul(w, Fx))
        wr = tf.reshape(wr, [self.batch_size, self.image_size**2])
        return wr * tf.reshape(1.0/gamma, [-1, 1])
    
model = draw_model()
model.train()