StackGAN Study Notes

Environment: Ubuntu 16.04

Python 2.7

TensorFlow 0.12

Paper: https://arxiv.org/abs/1612.03242
GitHub: https://github.com/hanzhanggit/StackGAN
Let me get straight to the notes. They are rather disorganized; my sincere apologies…

1. Data Preparation

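Since I only prepared the birds data (CUB-200-2011), only that dataset is processed here.
The relevant script in the project is misc/preprocess_birds.py.
This script generates the training data. Training only needs the region of each image that contains the bird, so the script mainly crops out the annotated bounding box (bbox) of each image and saves the crop at several sizes. Originally it only produced the data for 64x64 and 256x256 training; for my experiments I added 32x32 and 128x128 data as well.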

Below is the data-saving function from misc/preprocess_birds.py, with my additions for the 32x32 and 128x128 sizes:

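from __future__ import division, print_function

import pickle

import scipy.misc  # scipy.misc.imresize needs SciPy < 1.3 (and Pillow)

from misc.utils import get_image  # bbox crop + resize helper from the repo

# Downsampling ratios relative to LOAD_SIZE ("RETIO" is the original spelling)
Myself32_RETIO = 8      # 304 / 8 = 38
LR_HR_RETIO = 4         # 304 / 4 = 76
Myself128_RETIO = 2     # 304 / 2 = 152
IMSIZE = 256
LOAD_SIZE = int(IMSIZE * 76 / 64)  # 304: leaves a margin for random cropping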

def save_data_list(inpath, outpath, filenames, filename_bbox):
    Myself32_images = []
    hr_images = []
    Myself128_images = []
    lr_images = []
    Myself32_size = int(LOAD_SIZE/Myself32_RETIO)
    lr_size = int(LOAD_SIZE / LR_HR_RETIO)
    Myself128_size = int(LOAD_SIZE / Myself128_RETIO)
    cnt = 0
    for key in filenames:
        bbox = filename_bbox[key]
        f_name = '%s/CUB_200_2011/images/%s.jpg' % (inpath, key)
        img = get_image(f_name, LOAD_SIZE, is_crop=True, bbox=bbox)
        img = img.astype('uint8')
        hr_images.append(img)
        Myself128_img = scipy.misc.imresize(img, [Myself128_size, Myself128_size], 'bicubic')
        Myself128_images.append(Myself128_img)
        lr_img = scipy.misc.imresize(img, [lr_size, lr_size], 'bicubic')
        lr_images.append(lr_img)
        Myself32_img = scipy.misc.imresize(img, [Myself32_size, Myself32_size], 'bicubic')
        Myself32_images.append(Myself32_img)
        cnt += 1
        if cnt % 100 == 0:
            print('Load %d......' % cnt)
    #
    print('images', len(hr_images), hr_images[0].shape, lr_images[0].shape, Myself128_images[0].shape,Myself32_images[0].shape)
    #
    outfile = outpath + str(LOAD_SIZE) + 'images.pickle'
    with open(outfile, 'wb') as f_out:
        pickle.dump(hr_images, f_out)
        print('save to: ', outfile)
    #
    outfile = outpath + str(Myself128_size) + 'images.pickle'
    with open(outfile, 'wb') as f_out:
        pickle.dump(Myself128_images, f_out)
        print('save to: ', outfile)
    #
    outfile = outpath + str(lr_size) + 'images.pickle'
    with open(outfile, 'wb') as f_out:
        pickle.dump(lr_images, f_out)
        print('save to: ', outfile)
    #
    outfile = outpath + str(Myself32_size) + 'images.pickle'
    with open(outfile, 'wb') as f_out:
        pickle.dump(Myself32_images, f_out)
        print('save to: ', outfile)

After misc/preprocess_birds.py finishes, it produces
38images.pickle,
76images.pickle,
152images.pickle,
304images.pickle. These are the data used for training; misc/datasets.py applies its transform function to turn them into 32x32, 64x64, 128x128 and 256x256 data, respectively.
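The sizes in these file names follow from the constants above: LOAD_SIZE = 256 × 76 / 64 = 304, and dividing by the ratios 2, 4 and 8 gives 152, 76 and 38. Each stored size keeps the same 76:64 margin over the size used in training, and that margin is what transform crops away. A small sketch of the arithmetic (the helper `pickle_sizes` is hypothetical, not part of the repo):

```python
# Hypothetical helper: reproduces the size arithmetic of preprocess_birds.py
IMSIZE = 256
LOAD_SIZE = int(IMSIZE * 76 / 64)  # 304

# name -> downsampling ratio relative to LOAD_SIZE
RATIOS = {"hr": 1, "Myself128": 2, "lr": 4, "Myself32": 8}


def pickle_sizes(load_size=LOAD_SIZE, ratios=RATIOS):
    """Return {name: (stored_size, training_size, pickle_filename)}."""
    out = {}
    for name, ratio in ratios.items():
        stored = int(load_size / ratio)   # size saved to the pickle, e.g. 76
        crop = int(stored * 64 / 76)      # size after random cropping, e.g. 64
        out[name] = (stored, crop, "%dimages.pickle" % stored)
    return out


if __name__ == "__main__":
    for name, (stored, crop, fname) in sorted(pickle_sizes().items()):
        print("%-10s %-18s %3d -> %3d" % (name, fname, stored, crop))
```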

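What the function does (when augmentation is enabled via _aug_flag):
for example, given a 76x76 image, it randomly selects a 64x64 crop inside it and keeps that crop, flipping it horizontally with probability 0.5.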

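    def transform(self, images):
        # Requires `import random` and `import numpy as np` at module level.
        if self._aug_flag:
            # Data augmentation: random crop to self._imsize x self._imsize,
            # then a random horizontal flip
            transformed_images = \
                np.zeros([images.shape[0], self._imsize, self._imsize, 3])
            ori_size = images.shape[1]
            for i in range(images.shape[0]):
                # Random top-left corner of the crop; cast to int because
                # newer NumPy versions reject float slice indices
                h1 = int(np.floor((ori_size - self._imsize) * np.random.random()))
                w1 = int(np.floor((ori_size - self._imsize) * np.random.random()))
                cropped_image = \
                    images[i][h1: h1 + self._imsize, w1: w1 + self._imsize, :]
                # Flip left-right with probability 0.5
                if random.random() > 0.5:
                    transformed_images[i] = np.fliplr(cropped_image)
                else:
                    transformed_images[i] = cropped_image
            return transformed_images
        else:
            # No augmentation: images are returned unchanged
            return images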
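A standalone NumPy sketch of the same augmentation, handy for checking it outside the dataset class. The function name `random_crop_flip` and the `rng` parameter are my own choices, not the repo's; offsets here are drawn uniformly over all valid positions.

```python
import numpy as np


def random_crop_flip(images, imsize, rng=np.random):
    """Randomly crop each (H, W, C) image of an (N, H, W, C) batch to
    imsize x imsize and flip it left-right with probability 0.5."""
    n, ori_h, ori_w, c = images.shape
    out = np.zeros((n, imsize, imsize, c), dtype=images.dtype)
    for i in range(n):
        # Integer offsets (upper bound of randint is exclusive)
        top = rng.randint(0, ori_h - imsize + 1)
        left = rng.randint(0, ori_w - imsize + 1)
        crop = images[i, top:top + imsize, left:left + imsize, :]
        out[i] = np.fliplr(crop) if rng.random_sample() > 0.5 else crop
    return out


if __name__ == "__main__":
    batch = np.random.RandomState(0).rand(4, 76, 76, 3)
    print(random_crop_flip(batch, 64).shape)  # (4, 64, 64, 3)
```

np.fliplr flips axis 1, which is the width axis for a single (H, W, C) image, so the flip is horizontal as intended.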

2. Network Architecture
[Figure: StackGAN network architecture]
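1. Processing the text embedding
StackGAN does not use the text embedding directly as the condition. Instead, the embedding is fed through an FC layer that outputs the mean and log standard deviation of an independent Gaussian distribution, and the latent conditioning variable is sampled at random from that distribution (the paper calls this Conditioning Augmentation). The reason is that the embedding is usually high-dimensional, while the number of text descriptions is small relative to that dimensionality. If the embedding were used directly as the condition, the latent variable would be sparse in the latent space, which is bad for training.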

stageI/model.py

    def generate_condition(self, c_var):
        conditions =\
            (pt.wrap(c_var).
             flatten().
             custom_fully_connected(self.ef_dim * 2).
             apply(leaky_rectify, leakiness=0.2))
        mean = conditions[:, :self.ef_dim]
        log_sigma = conditions[:, self.ef_dim:]
        return [mean, log_sigma]

stageI/trainer.py

    def sample_encoded_context(self, embeddings):
        '''Helper function for init_opt'''
        c_mean_logsigma = self.model.generate_condition(embeddings)
        mean = c_mean_logsigma[0]
        if cfg.TRAIN.COND_AUGMENTATION:
            # epsilon = tf.random_normal(tf.shape(mean))
            epsilon = tf.truncated_normal(tf.shape(mean))
            stddev = tf.exp(c_mean_logsigma[1])
            c = mean + stddev * epsilon

            kl_loss = KL_loss(c_mean_logsigma[0], c_mean_logsigma[1])
        else:
            c = mean
            kl_loss = 0

        return c, cfg.TRAIN.COEFF.KL * kl_loss
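sample_encoded_context is the reparameterization trick: c = mean + exp(log_sigma) * epsilon, so gradients reach mean and log_sigma through a deterministic expression. tf.truncated_normal re-draws any sample more than two standard deviations from the mean; the NumPy sketch below emulates that by rejection. The helper names `truncated_normal` and `sample_condition` are hypothetical.

```python
import numpy as np


def truncated_normal(shape, rng=np.random):
    """Standard normal samples, re-drawn until |x| <= 2
    (mimics tf.truncated_normal with mean 0, stddev 1)."""
    x = rng.standard_normal(shape)
    bad = np.abs(x) > 2.0
    while bad.any():
        x[bad] = rng.standard_normal(int(bad.sum()))
        bad = np.abs(x) > 2.0
    return x


def sample_condition(mean, log_sigma, rng=np.random, cond_augmentation=True):
    """Conditioning augmentation: c = mean + sigma * epsilon."""
    if not cond_augmentation:
        return mean  # same as the `c = mean` branch above
    epsilon = truncated_normal(mean.shape, rng)
    return mean + np.exp(log_sigma) * epsilon
```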

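The code above introduces a KL loss, whose purpose is regularization: to prevent overfitting or an overly large variance, the following regularization term on this distribution is added to the generator loss (weighted by cfg.TRAIN.COEFF.KL):

$$D_{KL}\left(\mathcal{N}(\mu(\varphi_t), \Sigma(\varphi_t)) \,\|\, \mathcal{N}(0, I)\right)$$

where $\varphi_t$ is the text embedding and $\mu(\varphi_t)$, $\Sigma(\varphi_t)$ are the mean and diagonal covariance computed by generate_condition.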
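For a diagonal Gaussian against N(0, I), this KL divergence has a closed form; per dimension it is -log σ + (σ² + μ² - 1)/2. The sketch below computes it from mean and log_sigma as returned by generate_condition. Averaging over all entries is an assumption of this sketch; summing over dimensions instead would only rescale the term by a constant, which cfg.TRAIN.COEFF.KL can absorb.

```python
import numpy as np


def kl_loss(mean, log_sigma):
    """KL( N(mean, diag(sigma^2)) || N(0, I) ), averaged over all entries."""
    per_dim = -log_sigma + 0.5 * (np.exp(2.0 * log_sigma) + np.square(mean) - 1.0)
    return float(np.mean(per_dim))


if __name__ == "__main__":
    print(kl_loss(np.zeros((2, 3)), np.zeros((2, 3))))  # 0.0: already N(0, I)
```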

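2. Stage-I generator
stageI/model.py
s is the size of the training images. Here the training set consists of 64x64 images, i.e. s = 64, s2 = s/2 = 32, s4 = s/4 = 16, s8 = s/8 = 8, s16 = s/16 = 4.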
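Tracing the tensor shapes through generator() makes the s, s2, …, s16 bookkeeping concrete. This is plain Python without TensorFlow; gf_dim = 128 is an assumed value (check your cfg file), and the 3-channel layer at s2 follows the modified code as listed rather than its commented-out gf_dim line. The function name is hypothetical.

```python
def stage1_generator_shapes(s=64, gf_dim=128):
    """(layer, (height, width, channels)) after each stage of generator()."""
    s2, s4, s8, s16 = s // 2, s // 4, s // 8, s // 16
    return [
        ("fc + reshape (node1_0)", (s16, s16, gf_dim * 8)),
        ("residual add (node1)", (s16, s16, gf_dim * 8)),
        ("upsample + conv (node2_0)", (s8, s8, gf_dim * 4)),
        ("residual add (node2)", (s8, s8, gf_dim * 4)),
        ("upsample + conv", (s4, s4, gf_dim * 2)),
        ("upsample + conv to 3 channels", (s2, s2, 3)),
        ("upsample + conv + tanh", (s, s, 3)),
    ]


if __name__ == "__main__":
    for name, shape in stage1_generator_shapes():
        print("%-32s %s" % (name, shape))
```

Each nearest-neighbor resize doubles the spatial size, so four of them take the 4x4 map back to 64x64.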

    def generator(self, z_var):
        node1_0 =\
            (pt.wrap(z_var).
             flatten().
             custom_fully_connected(self.s16 * self.s16 * self.gf_dim * 8).
             fc_batch_norm().
             reshape([-1, self.s16, self.s16, self.gf_dim * 8]))
        node1_1 = \
            (node1_0.
             custom_conv2d(self.gf_dim * 2, k_h=1, k_w=1, d_h=1, d_w=1).
             conv_batch_norm().
             apply(tf.nn.relu).
             custom_conv2d(self.gf_dim * 2, k_h=3, k_w=3, d_h=1, d_w=1).
             conv_batch_norm().
             apply(tf.nn.relu).
             custom_conv2d(self.gf_dim * 8, k_h=3, k_w=3, d_h=1, d_w=1).
             conv_batch_norm())
        node1 = \
            (node1_0.
             apply(tf.add, node1_1).
             apply(tf.nn.relu))

        node2_0 = \
            (node1.
             # custom_deconv2d([0, self.s8, self.s8, self.gf_dim * 4], k_h=4, k_w=4).
             apply(tf.image.resize_nearest_neighbor, [self.s8, self.s8]).
             custom_conv2d(self.gf_dim * 4, k_h=3, k_w=3, d_h=1, d_w=1).
             conv_batch_norm())
        node2_1 = \
            (node2_0.
             custom_conv2d(self.gf_dim * 1, k_h=1, k_w=1, d_h=1, d_w=1).
             conv_batch_norm().
             apply(tf.nn.relu).
             custom_conv2d(self.gf_dim * 1, k_h=3, k_w=3, d_h=1, d_w=1).
             conv_batch_norm().
             apply(tf.nn.relu).
             custom_conv2d(self.gf_dim * 4, k_h=3, k_w=3, d_h=1, d_w=1).
             conv_batch_norm())
        node2 = \
            (node2_0.
             apply(tf.add, node2_1).
             apply(tf.nn.relu))

        output_tensor = \
            (node2.
             # custom_deconv2d([0, self.s4, self.s4, self.gf_dim * 2], k_h=4, k_w=4).
             apply(tf.image.resize_nearest_neighbor, [self.s4, self.s4]).
             custom_conv2d(self.gf_dim * 2, k_h=3, k_w=3, d_h=1, d_w=1).
             conv_batch_norm().
             apply(tf.nn.relu).
             # custom_deconv2d([0, self.s2, self.s2, self.gf_dim], k_h=4, k_w=4).
             apply(tf.image.resize_nearest_neighbor, [self.s2, self.s2]).
             # custom_conv2d(self.gf_dim, k_h=3, k_w=3, d_h=1, d_w=1).
             custom_conv2d(3, k_h=3, k_w=3, d_h=1, d_w=1).
             conv_batch_norm().
             apply(tf.nn.relu).
             # custom_deconv2d([0] + list(self.image_shape), k_h=4, k_w=4).
             apply(tf.image.resize_nearest_neighbor, [self.s, self.s]).
             custom_conv2d(3, k_h=3, k_w=3, d_h=1, d_w=1).
             apply(tf.nn.tanh))
        return output_tensor

    def generator_simple(self, z_var):
        output_tensor =\
            (pt.wrap(z_var).
             flatten().
             custom_fully_connected(self.s16 * self.s16 * self.gf_dim * 8).
             reshape([-1, self.s16, self.s16, self.gf_dim * 8]).
             conv_batch_norm().
             apply(tf.nn.relu).
             custom_deconv2d([0, self.s8, self.s8, self.gf_dim * 4], k_h=4, k_w=4).
             # apply(tf.image.resize_nearest_neighbor, [self.s8, self.s8]).
             # custom_conv2d(self.gf_dim * 4, k_h=3, k_w=3, d_h=1, d_w=1).
             conv_batch_norm().
             apply(tf.nn.relu).
             custom_deconv2d([0, self.s4, self.s4, self.gf_dim * 2], k_h=4, k_w=4).
             # apply(tf.image.resize_nearest_neighbor, [self.s4, self.s4]).
             # custom_conv2d(self.gf_dim * 2, k_h=3, k_w=3, d_h=1, d_w=1).
             conv_batch_norm().
             apply(tf.nn.relu).
             custom_deconv2d([0, self.s2, self.s2, self.gf_dim], k_h=4, k_w=4).
             # apply(tf.image.resize_nearest_neighbor, [self.s2, self.s2]).
             # custom_conv2d(self.gf_dim, k_h=3, k_w=3, d_h=1, d_w=1).
             conv_batch_norm().
             apply(tf.nn.relu).
             custom_deconv2d([0] + list(self.image_shape), k_h=4, k_w=4).
             # apply(tf.image.resize_nearest_neighbor, [self.s, self.s]).
             # custom_conv2d(3, k_h=3, k_w=3, d_h=1, d_w=1).
             apply(tf.nn.tanh))
        return output_tensor

    def get_generator(self, z_var):
        if cfg.GAN.NETWORK_TYPE == "default":
            return self.generator(z_var)
        elif cfg.GAN.NETWORK_TYPE == "simple":
            return self.generator_simple(z_var)
        else:
            raise NotImplementedError

    # sampler() comes from stageI/trainer.py (it calls self.model.get_generator)
    def sampler(self):
        c, _ = self.sample_encoded_context(self.embeddings)
        if cfg.TRAIN.FLAG:
            z = tf.zeros([self.batch_size, cfg.Z_DIM])  # Expect similar BGs
        else:
            z = tf.random_normal([self.batch_size, cfg.Z_DIM])
        self.fake_images = self.model.get_generator(tf.concat(1, [c, z]))

The code above concatenates c and z along the feature axis; the result is the generator's input.
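In NumPy terms the concatenation looks like the sketch below. Note that tf.concat(1, [c, z]) uses the TensorFlow 0.12 argument order; from TensorFlow 1.0 onward it is written tf.concat([c, z], 1). The sizes ef_dim = 128 and Z_DIM = 100 used in the usage line are assumed values, and `generator_input` is a hypothetical helper.

```python
import numpy as np


def generator_input(c, z_dim, train_flag, rng=np.random):
    """Concatenate the condition c (batch, ef_dim) with noise z
    (batch, z_dim), mirroring sampler()."""
    batch_size = c.shape[0]
    if train_flag:
        # all-zero noise while sampling during training ("Expect similar BGs")
        z = np.zeros((batch_size, z_dim))
    else:
        z = rng.standard_normal((batch_size, z_dim))
    # TF 0.12: tf.concat(1, [c, z]); TF >= 1.0: tf.concat([c, z], 1)
    return np.concatenate([c, z], axis=1)


if __name__ == "__main__":
    x = generator_input(np.ones((64, 128)), 100, train_flag=True)
    print(x.shape)  # (64, 228)
```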

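3. Stage-I discriminator
First, the embedding is compressed to 128 dimensions by a fully connected layer and then spatially replicated into a 4x4x128 tensor. Meanwhile, the image passes through a series of down-sampling layers until it is 4x4. The image feature maps and the text tensor are then concatenated along the channel dimension. The resulting tensor goes through a 1x1 convolution that jointly learns features across image and text. Finally, a layer with a single output node produces a score for whether the image is real or fake (in the code this is a convolution whose kernel size and stride both equal s16, which on a 4x4 map is equivalent to a one-node fully connected layer).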
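The expand_dims / tile / concat steps of get_discriminator can be checked with NumPy. This sketch assumes df_dim = 64 (so the image code has 64 × 8 = 512 channels) and ef_dim = 128, which gives 640 channels after concatenation; the function name `join_image_text` is hypothetical.

```python
import numpy as np


def join_image_text(x_code, c_code):
    """Replicate the text code over the spatial grid of the image code and
    concatenate both along the channel axis (NHWC layout)."""
    n, h, w, _ = x_code.shape
    c = c_code[:, np.newaxis, np.newaxis, :]   # (n, 1, 1, ef_dim)
    c = np.tile(c, (1, h, w, 1))               # (n, h, w, ef_dim)
    return np.concatenate([x_code, c], axis=3)


if __name__ == "__main__":
    joint = join_image_text(np.zeros((2, 4, 4, 512)), np.ones((2, 128)))
    print(joint.shape)  # (2, 4, 4, 640)
```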

stageI/model.py

    def context_embedding(self):
        template = (pt.template("input").
                    custom_fully_connected(self.ef_dim).
                    apply(leaky_rectify, leakiness=0.2))
        return template

    def d_encode_image(self):
        node1_0 = \
            (pt.template("input").
             custom_conv2d(self.df_dim, k_h=4, k_w=4).
             apply(leaky_rectify, leakiness=0.2).
             custom_conv2d(self.df_dim * 2, k_h=4, k_w=4).
             conv_batch_norm().
             apply(leaky_rectify, leakiness=0.2).
             custom_conv2d(self.df_dim * 4, k_h=4, k_w=4).
             conv_batch_norm().
             custom_conv2d(self.df_dim * 8, k_h=4, k_w=4).
             conv_batch_norm())
        node1_1 = \
            (node1_0.
             custom_conv2d(self.df_dim * 2, k_h=1, k_w=1, d_h=1, d_w=1).
             conv_batch_norm().
             apply(leaky_rectify, leakiness=0.2).
             custom_conv2d(self.df_dim * 2, k_h=3, k_w=3, d_h=1, d_w=1).
             conv_batch_norm().
             apply(leaky_rectify, leakiness=0.2).
             custom_conv2d(self.df_dim * 8, k_h=3, k_w=3, d_h=1, d_w=1).
             # custom_conv2d(self.df_dim * 4, k_h=3, k_w=3, d_h=1, d_w=1).
             conv_batch_norm())
        node1 = \
            (node1_0.
             apply(tf.add, node1_1).
             apply(leaky_rectify, leakiness=0.2))

        return node1

    def d_encode_image_simple(self):
        template = \
            (pt.template("input").
             custom_conv2d(self.df_dim, k_h=4, k_w=4).
             apply(leaky_rectify, leakiness=0.2).
             custom_conv2d(self.df_dim * 2, k_h=4, k_w=4).
             conv_batch_norm().
             apply(leaky_rectify, leakiness=0.2).
             custom_conv2d(self.df_dim * 4, k_h=4, k_w=4).
             conv_batch_norm().
             apply(leaky_rectify, leakiness=0.2).
             custom_conv2d(self.df_dim * 8, k_h=4, k_w=4).
             conv_batch_norm().
             apply(leaky_rectify, leakiness=0.2))

        return template

    def discriminator(self):
        template = \
            (pt.template("input").  # 128*9*4*4
             custom_conv2d(self.df_dim * 8, k_h=1, k_w=1, d_h=1, d_w=1).  # 128*8*4*4
             conv_batch_norm().
             apply(leaky_rectify, leakiness=0.2).
             # custom_fully_connected(1))
             custom_conv2d(1, k_h=self.s16, k_w=self.s16, d_h=self.s16, d_w=self.s16))

        return template

    def get_discriminator(self, x_var, c_var):
        x_code = self.d_encode_img_template.construct(input=x_var)

        c_code = self.d_context_template.construct(input=c_var)
        c_code = tf.expand_dims(tf.expand_dims(c_code, 1), 1)
        c_code = tf.tile(c_code, [1, self.s16, self.s16, 1])

        x_c_code = tf.concat(3, [x_code, c_code])
        return self.discriminator_template.construct(input=x_c_code)

After 600 epochs of training, the Stage-I network produced the following results:
test598.txt
row 0
this small brown bird has a white speckled belly and a white eye brow.
row 1
this is medium sized bird with black feathers and a skinny body.
row 2
a small brown bird with a yellow belly and a medium sized beak.
row 3
this bird is black in color with green eyes and a black beak and black feet and tarsus and black wings.
test598.jpg
[Figure: Stage-I samples for the test598.txt captions]

test599.txt
row 0
this bird is grey with yellow on its belly and brown on its tail.
row 1
this black bird has ruffled feathers and long reticles.
row 2
this bird is white with brown and has a long, pointy beak.
row 3
a medium sized black bird, with a white throat and a long skinny bill.
test599.jpg
[Figure: Stage-I samples for the test599.txt captions]

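4. Stage-II generator
Most of the Stage-II network is structured like Stage-I, but it adds down-sampling layers:
the 64x64 image produced by Stage-I is encoded down to a 16x16 feature map and concatenated with the text condition. At that point we can regard part of the features as already learned; the result is refined by residual blocks, and finally the generator produces a higher-resolution image. The original source generates 256x256 images.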
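A shape trace of hr_get_generator, in plain Python. gf_dim = ef_dim = 128 are assumed values, and the function name is hypothetical. full_upsampling=True corresponds to re-enabling the two commented-out upsampling stages in hr_generator, which is how a 64x64 input becomes 256x256; False matches the modified code as listed, where the output stays s x s.

```python
def stage2_generator_shapes(s=64, gf_dim=128, ef_dim=128, full_upsampling=True):
    """(layer, (height, width, channels)) through hr_get_generator()."""
    s2, s4 = s // 2, s // 4
    shapes = [
        ("input image", (s, s, 3)),
        ("encode: 3x3 conv", (s, s, gf_dim)),
        ("encode: stride-2 conv", (s2, s2, gf_dim * 2)),
        ("encode: stride-2 conv", (s4, s4, gf_dim * 4)),
        ("concat tiled text code", (s4, s4, ef_dim + gf_dim * 4)),
        ("joint conv + 4 residual blocks", (s4, s4, gf_dim * 4)),
        ("upsample + conv", (s2, s2, gf_dim * 2)),
        ("upsample + conv", (s, s, gf_dim)),
    ]
    if full_upsampling:  # the two stages commented out in hr_generator
        shapes += [("upsample + conv", (2 * s, 2 * s, gf_dim // 2)),
                   ("upsample + conv", (4 * s, 4 * s, gf_dim // 4))]
    size = shapes[-1][1][0]
    shapes.append(("conv to 3 channels + tanh", (size, size, 3)))
    return shapes


if __name__ == "__main__":
    for name, shape in stage2_generator_shapes():
        print("%-32s %s" % (name, shape))
```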
stageII/model.py

    def hr_g_encode_image(self, x_var):
        output_tensor = \
            (pt.wrap(x_var).  # -->s * s * 3
             custom_conv2d(self.gf_dim, k_h=3, k_w=3, d_h=1, d_w=1).  # s * s * gf_dim
             apply(tf.nn.relu).
             custom_conv2d(self.gf_dim * 2, k_h=4, k_w=4).  # s2 * s2 * gf_dim * 2
             conv_batch_norm().
             apply(tf.nn.relu).
             custom_conv2d(self.gf_dim * 4, k_h=4, k_w=4).  # s4 * s4 * gf_dim * 4
             conv_batch_norm().
             apply(tf.nn.relu))
        return output_tensor

    def hr_g_joint_img_text(self, x_c_code):
        output_tensor = \
            (pt.wrap(x_c_code).  # -->s4 * s4 * (ef_dim+gf_dim*4)
             custom_conv2d(self.gf_dim * 4, k_h=3, k_w=3, d_h=1, d_w=1).  # s4 * s4 * gf_dim * 4
             conv_batch_norm().
             apply(tf.nn.relu))
        return output_tensor

    def hr_generator(self, x_c_code):
        output_tensor = \
            (pt.wrap(x_c_code).  # -->s4 * s4 * gf_dim*4
             # custom_deconv2d([0, self.s2, self.s2, self.gf_dim * 2], k_h=4, k_w=4).  # -->s2 * s2 * gf_dim*2
             apply(tf.image.resize_nearest_neighbor, [self.s2, self.s2]).
             custom_conv2d(self.gf_dim * 2, k_h=3, k_w=3, d_h=1, d_w=1).
             conv_batch_norm().
             apply(tf.nn.relu).
             # custom_deconv2d([0, self.s, self.s, self.gf_dim], k_h=4, k_w=4).  # -->s * s * gf_dim
             apply(tf.image.resize_nearest_neighbor, [self.s, self.s]).
             custom_conv2d(self.gf_dim, k_h=3, k_w=3, d_h=1, d_w=1).
             conv_batch_norm().
             apply(tf.nn.relu).
             # custom_deconv2d([0, self.s * 2, self.s * 2, self.gf_dim // 2], k_h=4, k_w=4).  # -->2s * 2s * gf_dim/2
             # apply(tf.image.resize_nearest_neighbor, [self.s * 2, self.s * 2]).
             # custom_conv2d(self.gf_dim // 2, k_h=3, k_w=3, d_h=1, d_w=1).
             # conv_batch_norm().
             # apply(tf.nn.relu).
             # # custom_deconv2d([0, self.s * 4, self.s * 4, self.gf_dim // 4], k_h=4, k_w=4).  # -->4s * 4s * gf_dim//4
             # apply(tf.image.resize_nearest_neighbor, [self.s * 4, self.s * 4]).
             # custom_conv2d(self.gf_dim // 4, k_h=3, k_w=3, d_h=1, d_w=1).
             # conv_batch_norm().
             # apply(tf.nn.relu).
             custom_conv2d(3, k_h=3, k_w=3, d_h=1, d_w=1).  # --> s * s * 3 here (4s * 4s * 3 if the stages above are re-enabled)
             apply(tf.nn.tanh))
        return output_tensor

    def hr_get_generator(self, x_var, c_code):
        if cfg.GAN.NETWORK_TYPE == "default":
            # image x_var: self.s * self.s *3
            x_code = self.hr_g_encode_image(x_var)  # -->s4 * s4 * gf_dim * 4

            # text c_code: ef_dim
            c_code = tf.expand_dims(tf.expand_dims(c_code, 1), 1)
            c_code = tf.tile(c_code, [1, self.s4, self.s4, 1])

            # combine both --> s4 * s4 * (ef_dim+gf_dim*4)
            x_c_code = tf.concat(3, [x_code, c_code])

            # Joint learning from text and image -->s4 * s4 * gf_dim * 4
            node0 = self.hr_g_joint_img_text(x_c_code)
            node1 = self.residual_block(node0)
            node2 = self.residual_block(node1)
            node3 = self.residual_block(node2)
            node4 = self.residual_block(node3)

            # Up-sampling
            return self.hr_generator(node4)  # --> s * s * 3 with the modified hr_generator (4s * 4s * 3 originally)
        else:
            raise NotImplementedError

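5. Stage-II discriminator
Same as in Stage-I, except that the input is larger, so two more down-sampling convolutional layers are added in order to still obtain a 4x4 feature map.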
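Each stride-2 convolution halves the spatial size, so reducing 64x64 to 4x4 takes log2(64/4) = 4 such layers and reducing 256x256 takes log2(256/4) = 6, i.e. two more. A quick check with a hypothetical helper:

```python
def stride2_layers_to(size, target=4):
    """Number of stride-2 convolutions that shrink size x size to target x target."""
    n = 0
    while size > target:
        if size % 2:
            raise ValueError("size must equal target * 2**n")
        size //= 2
        n += 1
    if size != target:
        raise ValueError("size must equal target * 2**n")
    return n


if __name__ == "__main__":
    print(stride2_layers_to(64), stride2_layers_to(256))
```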

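In the end my computer crashed: the computational load was too large and I could not finish training, so I only ran a simplified stand-in experiment, as follows.
The experiments in the paper show that the Stage-II network can recover more image details: Stage-I generates a rough image and Stage-II refines it. Accordingly, I used Stage-I to generate 64x64 images and then had Stage-II regenerate 64x64 images from them. After 200 epochs, the results were as follows: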

test.txt
row 0
a small bird with a short bill and a yellowish crown
row 1
this bird has wings that are black with a bulk beak
row 2
a small brown bird with a very long straight tail, a fluffy head, and a medium sized beak.
row 3
a tall bird with long tarsi, a long black pointed bill, and some jet black wings.
Stage-I output: lr_fake_test.jpg
[Figure: lr_fake_test.jpg]

Stage-II output: hr_fake_test.jpg
[Figure: hr_fake_test.jpg]

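Because each epoch takes a long time and only a few epochs were run, nothing meaningful can really be seen from these results.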

There are some other results as well,
for example:
Stage-I outputting 32x32 images:
test483.txt
row 0
this bird is black with red and has a long, pointy beak.
row 1
gray crowned bird, with black, gray, and white spots scattered throughout the rest of his body.
row 2
this bird has wings that are brown and black with a red crown
row 3
this bird has wings that are black and has a white bill
test483.jpg: images generated at epoch 483
[Figure: test483.jpg]

Stage-I outputting 128x128 images:
test188.txt
row 0
this bird has wings that are brown and has a long neck
row 1
the bird has a tan breast, yellow torso and black back.
row 2
this bird is yellow with black on its neck and has a long, pointy beak.
row 3
this small bird has a black bill and crown with a white breast and dark retrices.
test188.jpg: images generated at epoch 188 (this took a very long time)
[Figure: test188.jpg]

To be continued…
