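Environment: Ubuntu 16.04
Python 2.7
TensorFlow 0.12
Paper: https://arxiv.org/abs/1612.03242
GitHub: https://github.com/hanzhanggit/StackGAN
Let's go straight to the notes. They are rather messy, sorry about that…
1. Data preparation
I only prepared the birds dataset, so only that dataset is processed here.
In the project, see misc/preprocess_birds.py
This script generates the training data. Training only needs the region of each image that contains the bird, so the script crops each image to its annotated bbox and saves the result at several sizes. The original script only produced the data for 64x64 and 256x256; for my experiments I added the data for 32x32 and 128x128.
The code below is the function in misc/preprocess_birds.py that saves the data: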
# Downscale ratios relative to the full-size (hr) image
Myself32_RETIO = 8
LR_HR_RETIO = 4
Myself128_RETIO = 2
IMSIZE = 256
LOAD_SIZE = int(IMSIZE * 76 / 64)


def save_data_list(inpath, outpath, filenames, filename_bbox):
    Myself32_images = []
    hr_images = []
    Myself128_images = []
    lr_images = []
    Myself32_size = int(LOAD_SIZE / Myself32_RETIO)
    lr_size = int(LOAD_SIZE / LR_HR_RETIO)
    Myself128_size = int(LOAD_SIZE / Myself128_RETIO)
    cnt = 0
    for key in filenames:
        bbox = filename_bbox[key]
        f_name = '%s/CUB_200_2011/images/%s.jpg' % (inpath, key)
        # Crop to the bbox and resize to LOAD_SIZE x LOAD_SIZE
        img = get_image(f_name, LOAD_SIZE, is_crop=True, bbox=bbox)
        img = img.astype('uint8')
        hr_images.append(img)
        Myself128_img = scipy.misc.imresize(img, [Myself128_size, Myself128_size], 'bicubic')
        Myself128_images.append(Myself128_img)
        lr_img = scipy.misc.imresize(img, [lr_size, lr_size], 'bicubic')
        lr_images.append(lr_img)
        Myself32_img = scipy.misc.imresize(img, [Myself32_size, Myself32_size], 'bicubic')
        Myself32_images.append(Myself32_img)
        cnt += 1
        if cnt % 100 == 0:
            print('Load %d......' % cnt)
    #
    print('images', len(hr_images), hr_images[0].shape, lr_images[0].shape,
          Myself128_images[0].shape, Myself32_images[0].shape)
    #
    # One pickle per size; the file name starts with the image size
    outfile = outpath + str(LOAD_SIZE) + 'images.pickle'
    with open(outfile, 'wb') as f_out:
        pickle.dump(hr_images, f_out)
        print('save to: ', outfile)
    #
    outfile = outpath + str(Myself128_size) + 'images.pickle'
    with open(outfile, 'wb') as f_out:
        pickle.dump(Myself128_images, f_out)
        print('save to: ', outfile)
    #
    outfile = outpath + str(lr_size) + 'images.pickle'
    with open(outfile, 'wb') as f_out:
        pickle.dump(lr_images, f_out)
        print('save to: ', outfile)
    #
    outfile = outpath + str(Myself32_size) + 'images.pickle'
    with open(outfile, 'wb') as f_out:
        pickle.dump(Myself32_images, f_out)
        print('save to: ', outfile)
Running misc/preprocess_birds.py produces
38images.pickle,
76images.pickle,
152images.pickle,
304images.pickle. These are all the data used for training. misc/datasets.py then applies the transform function to turn them into 32x32, 64x64, 128x128, and 256x256 images.
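As a quick check of where those file names come from, the sketch below reproduces the size arithmetic using the constants shown in save_data_list. The dictionary names (ratios, stored, crop_target) are only for illustration and do not appear in the repository.

```python
# Constants copied from the save_data_list excerpt.
IMSIZE = 256
LOAD_SIZE = int(IMSIZE * 76 / 64)  # 304: leaves a margin for random cropping

# Downscale ratio of each scale relative to the full-size image.
ratios = {"Myself32": 8, "lr": 4, "Myself128": 2, "hr": 1}

# Side length of the images stored in each pickle file.
stored = {name: LOAD_SIZE // r for name, r in ratios.items()}

# Side length after the random crop applied later by transform().
crop_target = {name: IMSIZE // r for name, r in ratios.items()}

for name in ratios:
    print(name, stored[name], "->", crop_target[name])
```

Every scale keeps the same 76/64 ratio between stored size and crop size, which is what gives the random crop some room to move.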
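What this function does:
For example, given a 76x76 image, it randomly picks a 64x64 crop from that image and keeps it (flipping it horizontally about half of the time).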
def transform(self, images):
    if self._aug_flag:
        transformed_images =\
            np.zeros([images.shape[0], self._imsize, self._imsize, 3])
        ori_size = images.shape[1]
        for i in range(images.shape[0]):
            # Random top-left corner of the crop window.
            # Cast to int: newer NumPy versions reject float slice indices.
            h1 = int(np.floor((ori_size - self._imsize) * np.random.random()))
            w1 = int(np.floor((ori_size - self._imsize) * np.random.random()))
            cropped_image =\
                images[i][w1: w1 + self._imsize, h1: h1 + self._imsize, :]
            # Random horizontal flip
            if random.random() > 0.5:
                transformed_images[i] = np.fliplr(cropped_image)
            else:
                transformed_images[i] = cropped_image
        return transformed_images
    else:
        return images
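To try the crop-and-flip idea outside the dataset class, here is a minimal standalone sketch in NumPy. The function name random_crop_flip and the random test batch are my own assumptions, not part of the repository, and the offsets are drawn with a NumPy Generator instead of np.random.random.

```python
import numpy as np


def random_crop_flip(images, out_size, rng):
    """Randomly crop each image to out_size x out_size and flip about half."""
    n, in_size = images.shape[0], images.shape[1]
    out = np.zeros((n, out_size, out_size, images.shape[3]), dtype=images.dtype)
    for i in range(n):
        # Top-left corner, drawn so the window stays inside the image.
        top = int(rng.integers(0, in_size - out_size + 1))
        left = int(rng.integers(0, in_size - out_size + 1))
        patch = images[i, top:top + out_size, left:left + out_size, :]
        # Horizontal flip with probability 0.5.
        out[i] = np.fliplr(patch) if rng.random() > 0.5 else patch
    return out


rng = np.random.default_rng(0)
# Fake batch of four 76x76 RGB images, standing in for 76images.pickle.
batch = rng.integers(0, 256, size=(4, 76, 76, 3), dtype=np.uint8)
cropped = random_crop_flip(batch, 64, rng)
print(cropped.shape)
```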
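2. Network structure
1. Handling the embedding
StackGAN does not use the embedding directly as the condition. Instead, the embedding is fed to an FC layer that outputs the parameters of an independent Gaussian distribution, and the latent conditioning variable is sampled at random from that distribution. The reason is that the embedding is usually high-dimensional, while the number of texts is small relative to that dimension. If the embedding were used directly as the condition, the latent variable would be sparse in the latent space, which hurts training.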
StageI/model.py
def generate_condition(self, c_var):
    conditions =\
        (pt.wrap(c_var).
         flatten().
         custom_fully_connected(self.ef_dim * 2).
         apply(leaky_rectify, leakiness=0.2))
    # First half of the FC output is the mean, second half is log(sigma)
    mean = conditions[:, :self.ef_dim]
    log_sigma = conditions[:, self.ef_dim:]
    return [mean, log_sigma]
StageI/trainer.py
def sample_encoded_context(self, embeddings):
    '''Helper function for init_opt'''
    c_mean_logsigma = self.model.generate_condition(embeddings)
    mean = c_mean_logsigma[0]
    if cfg.TRAIN.COND_AUGMENTATION:
        # epsilon = tf.random_normal(tf.shape(mean))
        epsilon = tf.truncated_normal(tf.shape(mean))
        stddev = tf.exp(c_mean_logsigma[1])
        # Reparameterization: c = mean + sigma * epsilon
        c = mean + stddev * epsilon
        kl_loss = KL_loss(c_mean_logsigma[0], c_mean_logsigma[1])
    else:
        c = mean
        kl_loss = 0

    return c, cfg.TRAIN.COEFF.KL * kl_loss
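The code above introduces a KL loss, whose purpose is regularization. To prevent overfitting or an excessively large variance, the generator loss includes a regularization term on this distribution: the KL divergence between the conditioning Gaussian N(mu(phi_t), Sigma(phi_t)) and the standard Gaussian N(0, I).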
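The KL_loss helper itself is not shown in these notes. Assuming it is the usual closed-form KL divergence between a diagonal Gaussian and the standard normal, averaged over dimensions, it can be sketched in plain Python as follows. The function name kl_to_standard_normal is hypothetical; per dimension the term is -log(sigma) + (sigma^2 + mu^2 - 1) / 2.

```python
import math


def kl_to_standard_normal(mean, log_sigma):
    """Mean over dimensions of KL(N(mean, sigma^2) || N(0, 1))."""
    terms = [
        -ls + 0.5 * (math.exp(2.0 * ls) + m * m - 1.0)
        for m, ls in zip(mean, log_sigma)
    ]
    return sum(terms) / len(terms)


# Zero mean and unit variance: the distribution already matches N(0, 1).
print(kl_to_standard_normal([0.0, 0.0], [0.0, 0.0]))
# Shifting the mean away from zero is penalized.
print(kl_to_standard_normal([1.0, -1.0], [0.0, 0.0]))
```

The term is zero only when the predicted distribution equals N(0, I), so it pulls the conditioning distribution toward the standard normal.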
2. stageI generator
StageI/model.py
s is the size of the training images. Here the training set consists of 64x64 images, so s=64, s2=s/2, s4=s/4, s8=s/8, s16=s/16.
def generator(self, z_var):
    node1_0 =\
        (pt.wrap(z_var).
         flatten().
         custom_fully_connected(self.s16 * self.s16 * self.gf_dim * 8).
         fc_batch_norm().
         reshape([-1, self.s16, self.s16, self.gf_dim * 8]))
    # Residual branch at s16 x s16
    node1_1 = \
        (node1_0.
         custom_conv2d(self.gf_dim * 2, k_h=1, k_w=1, d_h=1, d_w=1).
         conv_batch_norm().
         apply(tf.nn.relu).
         custom_conv2d(self.gf_dim * 2, k_h=3, k_w=3, d_h=1, d_w=1).
         conv_batch_norm().
         apply(tf.nn.relu).
         custom_conv2d(self.gf_dim * 8, k_h=3, k_w=3, d_h=1, d_w=1).
         conv_batch_norm())
    node1 = \
        (node1_0.
         apply(tf.add, node1_1).
         apply(tf.nn.relu))

    # Upsample to s8 x s8
    node2_0 = \
        (node1.
         # custom_deconv2d([0, self.s8, self.s8, self.gf_dim * 4], k_h=4, k_w=4).
         apply(tf.image.resize_nearest_neighbor, [self.s8, self.s8]).
         custom_conv2d(self.gf_dim * 4, k_h=3, k_w=3, d_h=1, d_w=1).
         conv_batch_norm())
    # Residual branch at s8 x s8
    node2_1 = \
        (node2_0.
         custom_conv2d(self.gf_dim * 1, k_h=1, k_w=1, d_h=1, d_w=1).
         conv_batch_norm().
         apply(tf.nn.relu).
         custom_conv2d(self.gf_dim * 1, k_h=3, k_w=3, d_h=1, d_w=1).
         conv_batch_norm().
         apply(tf.nn.relu).
         custom_conv2d(self.gf_dim * 4, k_h=3, k_w=3, d_h=1, d_w=1).
         conv_batch_norm())
    node2 = \
        (node2_0.
         apply(tf.add, node2_1).
         apply(tf.nn.relu))

    # Upsample s8 -> s4 -> s2 -> s, then map to the image with tanh
    output_tensor = \
        (node2.
         # custom_deconv2d([0, self.s4, self.s4, self.gf_dim * 2], k_h=4, k_w=4).
         apply(tf.image.resize_nearest_neighbor, [self.s4, self.s4]).
         custom_conv2d(self.gf_dim * 2, k_h=3, k_w=3, d_h=1, d_w=1).
         conv_batch_norm().
         apply(tf.nn.relu).
         # custom_deconv2d([0, self.s2, self.s2, self.gf_dim], k_h=4, k_w=4).
         apply(tf.image.resize_nearest_neighbor, [self.s2, self.s2]).
         # custom_conv2d(self.gf_dim, k_h=3, k_w=3, d_h=1, d_w=1).
         custom_conv2d(3, k_h=3, k_w=3, d_h=1, d_w=1).
         conv_batch_norm().
         apply(tf.nn.relu).
         # custom_deconv2d([0] + list(self.image_shape), k_h=4, k_w=4).
         apply(tf.image.resize_nearest_neighbor, [self.s, self.s]).
         custom_conv2d(3, k_h=3, k_w=3, d_h=1, d_w=1).
         apply(tf.nn.tanh))
    return output_tensor

def generator_simple(self, z_var):
    output_tensor =\
        (pt.wrap(z_var).
         flatten().
         custom_fully_connected(self.s16 * self.s16 * self.gf_dim * 8).
         reshape([-1, self.s16, self.s16, self.gf_dim * 8]).
         conv_batch_norm().
         apply(tf.nn.relu).
         custom_deconv2d([0, self.s8, self.s8, self.gf_dim * 4], k_h=4, k_w=4).
         # apply(tf.image.resize_nearest_neighbor, [self.s8, self.s8]).
         # custom_conv2d(self.gf_dim * 4, k_h=3, k_w=3, d_h=1, d_w=1).
         conv_batch_norm().
         apply(tf.nn.relu).
         custom_deconv2d([0, self.s4, self.s4, self.gf_dim * 2], k_h=4, k_w=4).
         # apply(tf.image.resize_nearest_neighbor, [self.s4, self.s4]).
         # custom_conv2d(self.gf_dim * 2, k_h=3, k_w=3, d_h=1, d_w=1).
         conv_batch_norm().
         apply(tf.nn.relu).
         custom_deconv2d([0, self.s2, self.s2, self.gf_dim], k_h=4, k_w=4).
         # apply(tf.image.resize_nearest_neighbor, [self.s2, self.s2]).
         # custom_conv2d(self.gf_dim, k_h=3, k_w=3, d_h=1, d_w=1).
         conv_batch_norm().
         apply(tf.nn.relu).
         custom_deconv2d([0] + list(self.image_shape), k_h=4, k_w=4).
         # apply(tf.image.resize_nearest_neighbor, [self.s, self.s]).
         # custom_conv2d(3, k_h=3, k_w=3, d_h=1, d_w=1).
         apply(tf.nn.tanh))
    return output_tensor

def get_generator(self, z_var):
    if cfg.GAN.NETWORK_TYPE == "default":
        return self.generator(z_var)
    elif cfg.GAN.NETWORK_TYPE == "simple":
        return self.generator_simple(z_var)
    else:
        raise NotImplementedError
StageI/trainer.py
def sampler(self):
    c, _ = self.sample_encoded_context(self.embeddings)
    if cfg.TRAIN.FLAG:
        z = tf.zeros([self.batch_size, cfg.Z_DIM])  # Expect similar BGs
    else:
        z = tf.random_normal([self.batch_size, cfg.Z_DIM])
    # Pre-1.0 TensorFlow signature: tf.concat(axis, values)
    self.fake_images = self.model.get_generator(tf.concat(1, [c, z]))
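The code above concatenates c and z along axis 1 to form the input of the generator. Note that tf.concat(1, [c, z]) uses the argument order of TensorFlow versions before 1.0, with the axis first.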
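To make the shapes concrete, the sketch below traces the tensor sizes through the default stageI generator as listed in these notes. The values of EF_DIM, Z_DIM and GF_DIM are assumptions picked for illustration, since the config file is not shown here; only s=64 comes from the text.

```python
# Hypothetical config values, chosen only for illustration.
EF_DIM, Z_DIM, GF_DIM = 128, 100, 128

s = 64
s2, s4, s8, s16 = s // 2, s // 4, s // 8, s // 16

# Concatenating c (EF_DIM wide) and z (Z_DIM wide) along axis 1
# gives the width of the generator input.
input_dim = EF_DIM + Z_DIM

# (height, width, channels) after each stage of generator().
stages = [
    ("fc + reshape", (s16, s16, GF_DIM * 8)),
    ("residual block 1", (s16, s16, GF_DIM * 8)),
    ("upsample + conv", (s8, s8, GF_DIM * 4)),
    ("residual block 2", (s8, s8, GF_DIM * 4)),
    ("upsample + conv", (s4, s4, GF_DIM * 2)),
    ("upsample + conv", (s2, s2, 3)),  # 3 channels in the version shown here
    ("upsample + conv + tanh", (s, s, 3)),
]

print("generator input width:", input_dim)
for name, shape in stages:
    print(name, shape)
```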
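3. stageI discriminator
First, the embedding goes through a fully connected layer that compresses it to 128 dimensions, and is then spatially replicated into a 4x4x128 tensor. Meanwhile, the image goes through a series of downsampling layers until it reaches 4x4. The image feature maps are then concatenated with the text tensor along the channel dimension. The resulting tensor goes through a 1x1 convolution layer to jointly learn features across the text and the image. Finally, a fully connected layer with a single node produces the probability that the image is real. (In the code below this last layer is written as a convolution with an s16 x s16 kernel, with the fully connected version commented out.)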
StageI/model.py
def context_embedding(self):
    template = (pt.template("input").
                custom_fully_connected(self.ef_dim).
                apply(leaky_rectify, leakiness=0.2))
    return template

def d_encode_image(self):
    # Four downsampling convolutions: s -> s16
    node1_0 = \
        (pt.template("input").
         custom_conv2d(self.df_dim, k_h=4, k_w=4).
         apply(leaky_rectify, leakiness=0.2).
         custom_conv2d(self.df_dim * 2, k_h=4, k_w=4).
         conv_batch_norm().
         apply(leaky_rectify, leakiness=0.2).
         custom_conv2d(self.df_dim * 4, k_h=4, k_w=4).
         conv_batch_norm().
         custom_conv2d(self.df_dim * 8, k_h=4, k_w=4).
         conv_batch_norm())
    # Residual branch at s16 x s16
    node1_1 = \
        (node1_0.
         custom_conv2d(self.df_dim * 2, k_h=1, k_w=1, d_h=1, d_w=1).
         conv_batch_norm().
         apply(leaky_rectify, leakiness=0.2).
         custom_conv2d(self.df_dim * 2, k_h=3, k_w=3, d_h=1, d_w=1).
         conv_batch_norm().
         apply(leaky_rectify, leakiness=0.2).
         custom_conv2d(self.df_dim * 8, k_h=3, k_w=3, d_h=1, d_w=1).
         # custom_conv2d(self.df_dim * 4, k_h=3, k_w=3, d_h=1, d_w=1).
         conv_batch_norm())
    node1 = \
        (node1_0.
         apply(tf.add, node1_1).
         apply(leaky_rectify, leakiness=0.2))

    return node1

def d_encode_image_simple(self):
    template = \
        (pt.template("input").
         custom_conv2d(self.df_dim, k_h=4, k_w=4).
         apply(leaky_rectify, leakiness=0.2).
         custom_conv2d(self.df_dim * 2, k_h=4, k_w=4).
         conv_batch_norm().
         apply(leaky_rectify, leakiness=0.2).
         custom_conv2d(self.df_dim * 4, k_h=4, k_w=4).
         conv_batch_norm().
         apply(leaky_rectify, leakiness=0.2).
         custom_conv2d(self.df_dim * 8, k_h=4, k_w=4).
         conv_batch_norm().
         apply(leaky_rectify, leakiness=0.2))

    return template

def discriminator(self):
    template = \
        (pt.template("input").  # 128*9*4*4
         custom_conv2d(self.df_dim * 8, k_h=1, k_w=1, d_h=1, d_w=1).  # 128*8*4*4
         conv_batch_norm().
         apply(leaky_rectify, leakiness=0.2).
         # custom_fully_connected(1))
         custom_conv2d(1, k_h=self.s16, k_w=self.s16, d_h=self.s16, d_w=self.s16))

    return template

def get_discriminator(self, x_var, c_var):
    x_code = self.d_encode_img_template.construct(input=x_var)

    # Compress the text embedding, then replicate it over the s16 x s16 grid
    c_code = self.d_context_template.construct(input=c_var)
    c_code = tf.expand_dims(tf.expand_dims(c_code, 1), 1)
    c_code = tf.tile(c_code, [1, self.s16, self.s16, 1])

    # Concatenate image and text features along the channel axis
    x_c_code = tf.concat(3, [x_code, c_code])
    return self.discriminator_template.construct(input=x_c_code)
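The spatial replication and channel concatenation done in get_discriminator can be reproduced in NumPy to check the resulting shape. This is only a sketch: the batch size and the values of df_dim and ef_dim are assumptions for illustration, and np.tile / np.concatenate stand in for the TensorFlow calls.

```python
import numpy as np

# Hypothetical sizes for illustration: batch 2, df_dim 64, ef_dim 128, s16 = 4.
batch, df_dim, ef_dim, s16 = 2, 64, 128, 4

x_code = np.zeros((batch, s16, s16, df_dim * 8))  # encoded image features
c_code = np.ones((batch, ef_dim))                 # compressed text embedding

# (batch, ef_dim) -> (batch, 1, 1, ef_dim) -> (batch, s16, s16, ef_dim)
c_tiled = np.tile(c_code[:, None, None, :], (1, s16, s16, 1))

# Stack image and text features along the channel axis.
x_c_code = np.concatenate([x_code, c_tiled], axis=3)
print(x_c_code.shape)
```

Every spatial position of the 4x4 grid sees the same text vector, so the following 1x1 convolution can mix image and text features position by position.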
After 600 epochs of training, the stageI network produces the following results:
test598.txt
row 0
this small brown bird has a white speckled belly and a white eye brow.
row 1
this is medium sized bird with black feathers and a skinny body.
row 2
a small brown bird with a yellow belly and a medium sized beak.
row 3
this bird is black in color with green eyes and a black beak and black feet and tarsus and black wings.
test598.jpg
test599.txt
row 0
this bird is grey with yellow on its belly and brown on its tail.
row 1
this black bird has ruffled feathers and long reticles.
row 2
this bird is white with brown and has a long, pointy beak.
row 3
a medium sized black bird, with a white throat and a long skinny bill.
test599.jpg
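4. stageII generator
Most of the stageII network structure is similar to stageI, with some extra downsampling.
The 64x64 image produced by stageI is downsampled to 16x16 feature maps, which we can regard as already encoding part of the features. These then go through residual blocks, and finally the generator upsamples them into a higher-resolution image. The original source code generates 256x256 images.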
stageII/model.py
def hr_g_encode_image(self, x_var):
    output_tensor = \
        (pt.wrap(x_var).  # -->s * s * 3
         custom_conv2d(self.gf_dim, k_h=3, k_w=3, d_h=1, d_w=1).  # s * s * gf_dim
         apply(tf.nn.relu).
         custom_conv2d(self.gf_dim * 2, k_h=4, k_w=4).  # s2 * s2 * gf_dim * 2
         conv_batch_norm().
         apply(tf.nn.relu).
         custom_conv2d(self.gf_dim * 4, k_h=4, k_w=4).  # s4 * s4 * gf_dim * 4
         conv_batch_norm().
         apply(tf.nn.relu))
    return output_tensor

def hr_g_joint_img_text(self, x_c_code):
    output_tensor = \
        (pt.wrap(x_c_code).  # -->s4 * s4 * (ef_dim+gf_dim*4)
         custom_conv2d(self.gf_dim * 4, k_h=3, k_w=3, d_h=1, d_w=1).  # s4 * s4 * gf_dim * 4
         conv_batch_norm().
         apply(tf.nn.relu))
    return output_tensor

def hr_generator(self, x_c_code):
    output_tensor = \
        (pt.wrap(x_c_code).  # -->s4 * s4 * gf_dim*4
         # custom_deconv2d([0, self.s2, self.s2, self.gf_dim * 2], k_h=4, k_w=4).  # -->s2 * s2 * gf_dim*2
         apply(tf.image.resize_nearest_neighbor, [self.s2, self.s2]).
         custom_conv2d(self.gf_dim * 2, k_h=3, k_w=3, d_h=1, d_w=1).
         conv_batch_norm().
         apply(tf.nn.relu).
         # custom_deconv2d([0, self.s, self.s, self.gf_dim], k_h=4, k_w=4).  # -->s * s * gf_dim
         apply(tf.image.resize_nearest_neighbor, [self.s, self.s]).
         custom_conv2d(self.gf_dim, k_h=3, k_w=3, d_h=1, d_w=1).
         conv_batch_norm().
         apply(tf.nn.relu).
         # The 2s and 4s upsampling stages are commented out in this experiment
         # custom_deconv2d([0, self.s * 2, self.s * 2, self.gf_dim // 2], k_h=4, k_w=4).  # -->2s * 2s * gf_dim/2
         # apply(tf.image.resize_nearest_neighbor, [self.s * 2, self.s * 2]).
         # custom_conv2d(self.gf_dim // 2, k_h=3, k_w=3, d_h=1, d_w=1).
         # conv_batch_norm().
         # apply(tf.nn.relu).
         # # custom_deconv2d([0, self.s * 4, self.s * 4, self.gf_dim // 4], k_h=4, k_w=4).  # -->4s * 4s * gf_dim//4
         # apply(tf.image.resize_nearest_neighbor, [self.s * 4, self.s * 4]).
         # custom_conv2d(self.gf_dim // 4, k_h=3, k_w=3, d_h=1, d_w=1).
         # conv_batch_norm().
         # apply(tf.nn.relu).
         custom_conv2d(3, k_h=3, k_w=3, d_h=1, d_w=1).  # -->s * s * 3 here (4s * 4s * 3 with the stages above enabled)
         apply(tf.nn.tanh))
    return output_tensor

def hr_get_generator(self, x_var, c_code):
    if cfg.GAN.NETWORK_TYPE == "default":
        # image x_var: self.s * self.s *3
        x_code = self.hr_g_encode_image(x_var)  # -->s4 * s4 * gf_dim * 4

        # text c_code: ef_dim
        c_code = tf.expand_dims(tf.expand_dims(c_code, 1), 1)
        c_code = tf.tile(c_code, [1, self.s4, self.s4, 1])

        # combine both --> s4 * s4 * (ef_dim+gf_dim*4)
        x_c_code = tf.concat(3, [x_code, c_code])

        # Joint learning from text and image -->s4 * s4 * gf_dim * 4
        node0 = self.hr_g_joint_img_text(x_c_code)
        node1 = self.residual_block(node0)
        node2 = self.residual_block(node1)
        node3 = self.residual_block(node2)
        node4 = self.residual_block(node3)

        # Up-sampling
        return self.hr_generator(node4)  # -->s * s * 3 here (4s * 4s * 3 in the original)
    else:
        raise NotImplementedError
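The data flow of hr_get_generator can be summarized as a shape trace. This is a sketch under assumptions: GF_DIM and EF_DIM are illustrative values, s=64 is the stageI output size, and the last entry reflects the version listed above, where the two largest upsampling stages are commented out.

```python
# Hypothetical config values for illustration.
GF_DIM, EF_DIM = 128, 128

s = 64
s2, s4 = s // 2, s // 4

# (height, width, channels) after each step of hr_get_generator().
trace = [
    ("stageI image", (s, s, 3)),
    ("hr_g_encode_image", (s4, s4, GF_DIM * 4)),
    ("concat tiled text code", (s4, s4, GF_DIM * 4 + EF_DIM)),
    ("hr_g_joint_img_text", (s4, s4, GF_DIM * 4)),
    ("4 residual blocks", (s4, s4, GF_DIM * 4)),
    ("hr_generator as listed above", (s, s, 3)),
]

# With the two commented-out upsampling stages enabled,
# the output would be 4s x 4s instead.
full_output = (4 * s, 4 * s, 3)

for name, shape in trace:
    print(name, shape)
print("original output size:", full_output)
```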
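stageII discriminator
It is the same as in stageI, except that because the input is larger, two more convolution layers are added in order to still end up with a 4x4 feature map.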
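The "two more layers" follows from simple arithmetic, assuming each downsampling convolution halves the spatial size (stride 2). The helper below is a hypothetical illustration, not code from the repository.

```python
def num_stride2_convs(input_size, target_size=4):
    """Count the stride-2 layers needed to shrink input_size to target_size."""
    count = 0
    size = input_size
    while size > target_size:
        size //= 2
        count += 1
    return count


stage1_layers = num_stride2_convs(64)    # stageI discriminator input
stage2_layers = num_stride2_convs(256)   # stageII discriminator input
print(stage1_layers, stage2_layers, stage2_layers - stage1_layers)
```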
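In the end my machine crashed: the amount of computation was too large to finish the training, so I only ran a hypothetical experiment, described below.
The experiments in the paper show that the stageII network can capture more image detail, so stageI is used to generate a rough image and stageII to refine it. Here I used stageI to generate 64x64 images and then regenerated 64x64 images with stageII. The results after 200 epochs are as follows: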
test.txt
row 0
a small bird with a short bill and a yellowish crown
row 1
this bird has wings that are black with a bulk beak
row 2
a small brown bird with a very long straight tail, a fluffy head, and a medium sized beak.
row 3
a tall bird with long tarsi, a long black pointed bill, and some jet black wings.
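Image from stageI: lr_fake_test.jpg
Image from stageII: hr_fake_test.jpg
Training is slow and I ran too few iterations, so honestly nothing can be seen yet.
There are other results as well,
for example:
stageI with 32x32 output,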
test483.txt
row 0
this bird is black with red and has a long, pointy beak.
row 1
gray crowned bird, with black, gray, and white spots scattered throughout the rest of his body.
row 2
this bird has wings that are brown and black with a red crown
row 3
this bird has wings that are black and has a white bill
test483.jpg: image generated at epoch 483
test188.txt
row 0
this bird has wings that are brown and has a long neck
row 1
the bird has a tan breast, yellow torso and black back.
row 2
this bird is yellow with black on its neck and has a long, pointy beak.
row 3
this small bird has a black bill and crown with a white breast and dark retrices.
stageI with 128x128 output
test188.jpg: image generated at epoch 188 (this took a very long time)
To be continued…