A Brief Look at StarGAN
Preface
Today let's talk about StarGAN.
GitHub
https://github.com/yanjingke/StarGAN-Keras
Dataset
Link: https://pan.baidu.com/s/1mDKY52cWgvLIk5YrEKtF5w
Extraction code: aef0
Paper
StarGAN
StarGAN was introduced to solve multi-domain image-to-image translation. Earlier models such as CycleGAN only translate between two domains, so covering translations among C domains requires training C*(C-1) separate models. StarGAN needs just a single model, and it achieves strong results.
The overall processing pipeline is as follows:
Training D (and G):
(a) D discriminates real and fake images: real images are judged real, fake images are judged fake, and real images are also classified into their corresponding domain.
Training G:
(b) A real image plus a target-domain label c is fed into G to generate a fake image.
(c) The fake image plus the original label c' is fed into G again to produce a reconstructed image.
(d) D discriminates the fake image; G is trained so that D judges the fake image as real.
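The (a)-(d) data flow above can be sketched in a few lines. This is purely schematic: G and D here are toy stand-ins that only preserve shapes, not real networks.

```python
import numpy as np

# Toy stand-ins (illustrative only): G maps image + label -> image,
# D maps image -> (real/fake score, domain logits).
def G(x, c):
    return np.tanh(x + 0.0 * c.sum())

def D(x):
    return x.mean(), np.zeros(5)

x_real = np.random.rand(1, 128, 128, 3).astype(np.float32)
c_orig = np.eye(5, dtype=np.float32)[0]      # original domain label
c_trg  = np.eye(5, dtype=np.float32)[2]      # target domain label

src_real, cls_real = D(x_real)               # (a) D scores and classifies the real image
x_fake = G(x_real, c_trg)                    # (b) translate x toward the target domain
x_rec  = G(x_fake, c_orig)                   # (c) translate back: reconstruction of x
src_fake, _ = D(x_fake)                      # (d) G wants D to judge x_fake as real
```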
Advantages of StarGAN
1. A single network
The generator G takes an image and domain information as input, where the domain information is represented by a binary or one-hot label vector.
During training, we randomly sample a target-domain label and train the model to translate the input image into an image with the target domain's characteristics.
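As a minimal sketch (NumPy only; the function name and shapes are illustrative, not part of the repo), sampling a random one-hot target-domain label per image in a batch might look like:

```python
import numpy as np

def random_target_labels(batch_size, c_dim, rng=None):
    """Sample a random target domain for each image and one-hot encode it."""
    rng = np.random.default_rng(rng)
    idx = rng.integers(0, c_dim, size=batch_size)      # one random domain index per sample
    labels = np.zeros((batch_size, c_dim), dtype=np.float32)
    labels[np.arange(batch_size), idx] = 1.0           # set the chosen domain bit
    return labels

print(random_target_labels(4, 5).shape)  # (4, 5)
```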
2. Multi-domain training across datasets
A mask vector is added to the domain label to support joint training across the domains of different datasets: unknown labels are ignored and the network focuses only on the explicitly given labels. As in the right part of the official figure, RaFD labels can be used to synthesize CelebA images.
For example, consider the CelebA and RaFD datasets: the former has hair-color and gender annotations, the latter facial-expression annotations. Can we change the facial expression of a person from CelebA?
A naive idea: if the original domain label is a 5-bit one-hot encoding, just widen it to 8 bits. But there is a problem: people in CelebA also have facial expressions, just unannotated, and RaFD images also differ in gender. To the network, an unannotated attribute is simply unknown, so naively widening the label vector cannot work. We want the network to attend only to the part of the annotation for which it has explicit information.
The authors therefore add a mask: when jointly training on multiple datasets, the mask vector is fed to the generator along with the labels.
Here ci denotes the label of the i-th dataset. A known label ci is represented as a binary vector for binary attributes, or a one-hot vector for categorical attributes; the remaining n-1 datasets' label slots are filled with zeros. m is a one-hot vector of length n (the number of datasets). This way the network attends only to the labels that are actually given.
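A minimal sketch of assembling the joint label [c1, ..., cn, m] for n = 2 datasets (dimensions 5 for CelebA attributes and 8 for RaFD expressions; the helper name is hypothetical, not from the repo):

```python
import numpy as np

def joint_label(label, dataset_idx, dims=(5, 8)):
    """Concatenate per-dataset label slots plus a one-hot mask over datasets.
    Slots for datasets other than dataset_idx are zero-filled (unknown)."""
    parts = []
    for i, d in enumerate(dims):
        parts.append(label if i == dataset_idx else np.zeros(d, dtype=np.float32))
    mask = np.zeros(len(dims), dtype=np.float32)
    mask[dataset_idx] = 1.0                     # mark which dataset the label came from
    return np.concatenate(parts + [mask])

celeba = np.array([1, 0, 0, 1, 1], dtype=np.float32)  # e.g. 5 CelebA attributes
print(joint_label(celeba, 0).shape)  # (15,) = 5 + 8 + 2
```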
StarGAN losses
Adversarial loss (adv)
The generator G generates an image G(x, c) conditioned on both the input image x and the target-domain label c, while the discriminator D tries to distinguish real images from generated ones.
Here D_src(x) denotes the probability distribution over sources (real vs. fake) that D assigns to an image.
G tries to minimize this objective, while D tries to maximize it.
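Written out (the formula image is not reproduced here), the adversarial loss from the StarGAN paper is:

```latex
\mathcal{L}_{adv} = \mathbb{E}_{x}\big[\log D_{src}(x)\big]
                  + \mathbb{E}_{x,c}\big[\log\big(1 - D_{src}(G(x,c))\big)\big]
```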
Domain classification loss (cls)
The goal is for the output image to be correctly classified into the target domain c.
To achieve this, we add an auxiliary classifier on top of the discriminator D and impose a domain classification loss when optimizing both D and G.
That is, the domain classification loss has two parts:
A domain classification loss on real images, used to optimize D.
D_cls(c'|x) denotes the probability distribution over domain labels computed by D.
By minimizing this objective, D learns to classify a real image x into its corresponding original domain c'.
We assume the input image-label pairs (x, c') are given by the training data.
A domain classification loss on fake images, used to optimize G.
G tries to minimize this objective so that the generated images are classified as the target domain c.
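The two classification terms, as given in the paper:

```latex
\mathcal{L}_{cls}^{r} = \mathbb{E}_{x,c'}\big[-\log D_{cls}(c' \mid x)\big]
\qquad
\mathcal{L}_{cls}^{f} = \mathbb{E}_{x,c}\big[-\log D_{cls}(c \mid G(x,c))\big]
```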
Reconstruction loss (rec)
Minimizing the adversarial and classification losses alone does not guarantee that only the domain-related part of the input image is changed while its content is preserved.
To address this, we apply a cycle-consistency loss to the generator,
where G takes the fake image G(x, c) and the original domain label c' as input and tries to reconstruct the original image x.
We adopt the L1 norm for the reconstruction loss.
Note that the generator G is used twice: first to translate the original image into the target domain, and then to reconstruct the original image from the translated one.
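In symbols, the cycle-consistency (reconstruction) loss is:

```latex
\mathcal{L}_{rec} = \mathbb{E}_{x,c,c'}\big[\,\lVert x - G(G(x,c),\,c')\rVert_{1}\,\big]
```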
Full objective
Finally, the objective functions for optimizing D and G can be written as follows,
where the hyperparameters λcls and λrec control the relative importance of the domain classification loss and the reconstruction loss compared with the adversarial loss.
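The combined objectives from the paper:

```latex
\mathcal{L}_{D} = -\mathcal{L}_{adv} + \lambda_{cls}\,\mathcal{L}_{cls}^{r}
\qquad
\mathcal{L}_{G} = \mathcal{L}_{adv} + \lambda_{cls}\,\mathcal{L}_{cls}^{f} + \lambda_{rec}\,\mathcal{L}_{rec}
```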
StarGAN implementation
Generator
StarGAN's generator consists of two stride-2 convolutional layers for downsampling, six residual blocks, and two stride-2 transposed convolutional layers for upsampling (this Keras version implements the upsampling with UpSampling2D followed by a stride-1 Conv2D). Note how the generator combines the input image and the target-domain label c into a single input.
def build_generator(self):
    """Generator network."""
    # Input tensors
    # (?, c_dim) domain label, e.g. (?, 5)
    inp_c = Input(shape = (self.c_dim, ))
    # (?, 128, 128, 3) input image
    inp_img = Input(shape = (self.image_size, self.image_size, 3))

    # Replicate the domain label spatially and concatenate it with the image
    # (?, 128*128, c_dim)
    c = Lambda(lambda x: K.repeat(x, self.image_size**2))(inp_c)
    # (?, 128, 128, c_dim)
    c = Reshape((self.image_size, self.image_size, self.c_dim))(c)
    # (?, 128, 128, 3 + c_dim)
    x = Concatenate()([inp_img, c])

    # First Conv2D -> (?, 128, 128, 64)
    x = Conv2D(filters = self.g_conv_dim, kernel_size = 7, strides = 1, padding = 'same', use_bias = False)(x)
    x = InstanceNormalization(axis = -1)(x)
    x = Activation('relu')(x)

    # Down-sampling layers: (?, 64, 64, 128) -> (?, 32, 32, 256)
    curr_dim = self.g_conv_dim
    for i in range(2):
        x = ZeroPadding2D(padding = 1)(x)
        x = Conv2D(filters = curr_dim*2, kernel_size = 4, strides = 2, padding = 'valid', use_bias = False)(x)
        x = InstanceNormalization(axis = -1)(x)
        x = Activation('relu')(x)
        curr_dim = curr_dim * 2

    # Bottleneck: six residual blocks at (?, 32, 32, 256)
    for i in range(self.g_repeat_num):
        x = self.ResidualBlock(x, curr_dim)

    # Up-sampling layers: back to (?, 128, 128, 64)
    for i in range(2):
        x = UpSampling2D(size = 2)(x)
        x = Conv2D(filters = curr_dim // 2, kernel_size = 4, strides = 1, padding = 'same', use_bias = False)(x)
        x = InstanceNormalization(axis = -1)(x)
        x = Activation('relu')(x)
        curr_dim = curr_dim // 2

    # Last Conv2D -> (?, 128, 128, 3), tanh output in [-1, 1]
    x = ZeroPadding2D(padding = 3)(x)
    out = Conv2D(filters = 3, kernel_size = 7, strides = 1, padding = 'valid', activation = 'tanh', use_bias = False)(x)

    return Model(inputs = [inp_img, inp_c], outputs = out)
Discriminator
For the discriminator, the out_cls output gives the predicted domain probabilities, while the out_src output judges whether the image is real. The two heads are computed in parallel.
def build_discriminator(self):
    """Discriminator network with PatchGAN."""
    # (?, 128, 128, 3)
    inp_img = Input(shape = (self.image_size, self.image_size, 3))

    x = ZeroPadding2D(padding = 1)(inp_img)
    # (?, 64, 64, 64)
    x = Conv2D(filters = self.d_conv_dim, kernel_size = 4, strides = 2, padding = 'valid', use_bias = False)(x)
    x = LeakyReLU(0.01)(x)

    # Repeated down-sampling: (?, 32, 32, 128) -> ... -> (?, 2, 2, 2048)
    curr_dim = self.d_conv_dim
    for i in range(1, self.d_repeat_num):
        x = ZeroPadding2D(padding = 1)(x)
        x = Conv2D(filters = curr_dim*2, kernel_size = 4, strides = 2, padding = 'valid')(x)
        x = LeakyReLU(0.01)(x)
        curr_dim = curr_dim * 2

    kernel_size = int(self.image_size / np.power(2, self.d_repeat_num))

    # PatchGAN real/fake head: (?, 2, 2, 1)
    out_src = ZeroPadding2D(padding = 1)(x)
    out_src = Conv2D(filters = 1, kernel_size = 3, strides = 1, padding = 'valid', use_bias = False)(out_src)

    # Domain classification head: (?, 1, 1, c_dim) -> (?, c_dim)
    out_cls = Conv2D(filters = self.c_dim, kernel_size = kernel_size, strides = 1, padding = 'valid', use_bias = False)(x)
    out_cls = Reshape((self.c_dim, ))(out_cls)

    return Model(inp_img, [out_src, out_cls])
Loss computation
def classification_loss(self, Y_true, Y_pred):
    return tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=Y_true, logits=Y_pred))

def wasserstein_loss(self, Y_true, Y_pred):
    return K.mean(Y_true*Y_pred)

def reconstruction_loss(self, Y_true, Y_pred):
    return K.mean(K.abs(Y_true - Y_pred))

def gradient_penalty_loss(self, y_true, y_pred, averaged_samples):
    """
    Computes the gradient penalty based on the prediction and the weighted real / fake samples
    """
    gradients = K.gradients(y_pred, averaged_samples)[0]
    # compute the euclidean norm by squaring ...
    gradients_sqr = K.square(gradients)
    # ... summing over all non-batch axes ...
    gradients_sqr_sum = K.sum(gradients_sqr, axis=np.arange(1, len(gradients_sqr.shape)))
    # ... and taking the square root
    gradient_l2_norm = K.sqrt(gradients_sqr_sum)
    # compute (1 - ||grad||)^2 for each single sample
    gradient_penalty = K.square(1 - gradient_l2_norm)
    # return the mean as loss over all the batch samples
    return K.mean(gradient_penalty)
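The build_model code below calls a RandomWeightedAverage layer that is not shown in this excerpt. Conceptually it draws one epsilon per sample and interpolates x_hat = eps * x_real + (1 - eps) * x_fake, which is where the gradient penalty is evaluated in WGAN-GP. A NumPy sketch of that interpolation (names are illustrative, not the repo's implementation):

```python
import numpy as np

def random_weighted_average(x_real, x_fake, rng=None):
    """WGAN-GP interpolation: one epsilon per sample, broadcast over H, W, C."""
    rng = np.random.default_rng(rng)
    eps = rng.random((x_real.shape[0], 1, 1, 1)).astype(np.float32)
    return eps * x_real + (1.0 - eps) * x_fake

real = np.ones((2, 4, 4, 3), dtype=np.float32)
fake = np.zeros((2, 4, 4, 3), dtype=np.float32)
x_hat = random_weighted_average(real, fake)   # every value lies between fake and real
```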
Training
def build_model(self):
    self.G = self.build_generator()
    self.D = self.build_discriminator()

    # First, don't update the weights of the generator
    self.G.trainable = False

    # Compute output with real images.
    x_real = Input(shape = (self.image_size, self.image_size, 3))
    # (?, 2, 2, 1) and (?, c_dim)
    out_src_real, out_cls_real = self.D(x_real)

    # Compute output with fake images.
    label_trg = Input(shape = (self.c_dim, ))
    # (?, 128, 128, 3)
    x_fake = self.G([x_real, label_trg])
    # (?, 2, 2, 1) and (?, c_dim)
    out_src_fake, out_cls_fake = self.D(x_fake)

    # Compute output for the gradient penalty.
    rd_avg = RandomWeightedAverage()
    rd_avg.define_batch_size(self.batch_size)
    x_hat = rd_avg([x_real, x_fake])
    out_src, _ = self.D(x_hat)

    # Use Python partial to provide the loss function with the additional 'averaged_samples' argument
    partial_gp_loss = partial(self.gradient_penalty_loss, averaged_samples=x_hat)
    partial_gp_loss.__name__ = 'gradient_penalty'  # Keras requires function names

    # Define the training model for D
    self.train_D = Model([x_real, label_trg], [out_src_real, out_cls_real, out_src_fake, out_src])
    # Set up the losses for train_D
    self.train_D.compile(loss = [self.wasserstein_loss, self.classification_loss, self.wasserstein_loss, partial_gp_loss],
                         optimizer = Adam(lr = self.d_lr, beta_1 = self.beta1, beta_2 = self.beta2),
                         loss_weights = [1, self.lambda_cls, 1, self.lambda_gp])

    # Now update G and freeze D
    self.G.trainable = True
    self.D.trainable = False

    # All inputs
    real_x = Input(shape = (self.image_size, self.image_size, 3))
    org_label = Input(shape = (self.c_dim, ))
    trg_label = Input(shape = (self.c_dim, ))

    # Compute output of the fake image
    fake_x = self.G([real_x, trg_label])
    fake_out_src, fake_out_cls = self.D(fake_x)
    # Target-to-original domain (reconstruction).
    x_reconst = self.G([fake_x, org_label])

    # Define the training model for G
    self.train_G = Model([real_x, org_label, trg_label], [fake_out_src, fake_out_cls, x_reconst])
    # Set up the losses for train_G
    self.train_G.compile(loss = [self.wasserstein_loss, self.classification_loss, self.reconstruction_loss],
                         optimizer = Adam(lr = self.g_lr, beta_1 = self.beta1, beta_2 = self.beta2),
                         loss_weights = [1, self.lambda_cls, self.lambda_rec])

    # Input images
    self.Image_data_class = ImageData(data_dir=self.data_dir, selected_attrs=self.selected_attrs)
    self.Image_data_class.preprocess()

def train(self):
    data_iter = get_loader(self.Image_data_class.train_dataset, self.Image_data_class.train_dataset_label,
                           self.Image_data_class.train_dataset_fix_label,
                           image_size=self.image_size, batch_size=self.batch_size, mode=self.mode)

    # Wasserstein targets: -1 for real patches, +1 for fake patches, zeros for the gradient penalty
    valid = -np.ones((self.batch_size, 2, 2, 1))
    fake = np.ones((self.batch_size, 2, 2, 1))
    dummy = np.zeros((self.batch_size, 2, 2, 1))  # Dummy ground truth for the gradient penalty

    for epoch in range(self.num_iters):
        imgs, orig_labels, target_labels, fix_labels, _ = next(data_iter)

        # Set the learning rate (linear decay)
        if epoch > (self.num_iters - self.num_iters_decay):
            K.set_value(self.train_D.optimizer.lr, self.d_lr*(self.num_iters - epoch)/(self.num_iters - self.num_iters_decay))
            K.set_value(self.train_G.optimizer.lr, self.g_lr*(self.num_iters - epoch)/(self.num_iters - self.num_iters_decay))

        # Train the discriminator
        D_loss = self.train_D.train_on_batch(x = [imgs, target_labels], y = [valid, orig_labels, fake, dummy])

        # Train the generator every n_critic steps
        if (epoch + 1) % self.n_critic == 0:
            G_loss = self.train_G.train_on_batch(x = [imgs, orig_labels, target_labels], y = [valid, target_labels, imgs])

        if (epoch + 1) % self.log_step == 0:
            print(f"Iteration: [{epoch + 1}/{self.num_iters}]")
            print(f"\tD/loss_real = [{D_loss[1]:.4f}], D/loss_fake = [{D_loss[3]:.4f}], D/loss_cls = [{D_loss[2]:.4f}], D/loss_gp = [{D_loss[4]:.4f}]")
            print(f"\tG/loss_fake = [{G_loss[1]:.4f}], G/loss_rec = [{G_loss[3]:.4f}], G/loss_cls = [{G_loss[2]:.4f}]")

        if (epoch + 1) % self.model_save_step == 0:
            self.G.save_weights(os.path.join(self.model_save_dir, 'G_weights.hdf5'))
            self.D.save_weights(os.path.join(self.model_save_dir, 'D_weights.hdf5'))
            self.train_D.save_weights(os.path.join(self.model_save_dir, 'train_D_weights.hdf5'))
            self.train_G.save_weights(os.path.join(self.model_save_dir, 'train_G_weights.hdf5'))
Training and prediction commands
Training: python main.py --mode=train --batch_size=8 --data_dir=D:\YouKu\data\celeba
Prediction: python main.py --mode custom --custom_image_name <000001.jpg… name of your image> --custom_image_label <1 0 0 1 1… 5 original labels of your image>