Before CycleGAN, translating images between two domains (style transfer, for example) required paired training images. CycleGAN removes the pairing requirement, so the generator's learning process is less like image2image's supervised mapping and more like "translation" between two image domains.
- The figures below show examples of paired and unpaired training sets; paired training requires a one-to-one correspondence between images.
- CycleGAN
CycleGAN's network design is not complicated in itself. It contains two generators, one mapping image domain A to domain B and the other mapping domain B to domain A, plus two discriminators, one per domain. During training, a random image from domain A is translated to domain B by the first generator, then mapped back to domain A by the second, reconstructing the input and closing one "cycle". Images from domain B go through the same process in the opposite direction. For each generator, the error is the reconstruction error (when the identity loss is not used).
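The cycle constraint can be illustrated with a toy sketch. Here `g_a2b` and `f_b2a` are hand-picked inverse functions standing in for the two generators (hypothetical scalar "images", not the real networks); real generators must *learn* to be approximate inverses of each other.

```python
def g_a2b(x):
    return 2 * x + 1      # toy "translation" from domain A to domain B

def f_b2a(y):
    return (y - 1) / 2    # toy "translation" from domain B to domain A

image_a = 3.0
fake_b = g_a2b(image_a)       # A -> B
reconstr_a = f_b2a(fake_b)    # B -> A: should recover the input

cycle_loss = abs(reconstr_a - image_a)   # L1 reconstruction error
print(cycle_loss)  # 0.0, because these toy mappings are exact inverses
```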
The CycleGAN loss is the sum of an adversarial loss and a cycle-consistency loss (the authors additionally add an identity loss).

Adversarial loss (discriminator), in the least-squares form that the code below implements with `mse`:

$$\mathcal{L}_{GAN}(G, D_B, A, B) = \mathbb{E}_{b \sim B}\big[(D_B(b) - 1)^2\big] + \mathbb{E}_{a \sim A}\big[D_B(G(a))^2\big]$$

Cycle-consistency loss (generator reconstruction error):

$$\mathcal{L}_{cyc}(G, F) = \mathbb{E}_{a \sim A}\big[\lVert F(G(a)) - a \rVert_1\big] + \mathbb{E}_{b \sim B}\big[\lVert G(F(b)) - b \rVert_1\big]$$

Total loss:

$$\mathcal{L} = \mathcal{L}_{GAN}(G, D_B, A, B) + \mathcal{L}_{GAN}(F, D_A, B, A) + \lambda\,\mathcal{L}_{cyc}(G, F)$$
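The weighting of the terms can be made concrete with a quick arithmetic sketch, using the same weights as the compiled model below (λ_cycle = 10, λ_id = 0.1·λ_cycle) and made-up per-term loss values:

```python
# Hypothetical per-term losses for one batch (made-up numbers), in the
# order [adv_a, adv_b, cyc_a, cyc_b, id_a, id_b]:
losses = [0.3, 0.4, 0.05, 0.06, 0.02, 0.03]

lambda_cycle = 10.0
lambda_id = 0.1 * lambda_cycle   # = 1.0
weights = [1, 1, lambda_cycle, lambda_cycle, lambda_id, lambda_id]

total = sum(w * l for w, l in zip(weights, losses))
print(round(total, 2))  # 1.85: the cycle terms dominate despite being small
```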
- Network implementation. The generator is built on VGG16 pretrained on ImageNet.
```python
from keras.applications import VGG16
from keras.layers import Concatenate, Conv2D, Input, LeakyReLU, UpSampling2D
from keras.models import Model
from keras.optimizers import Adam
# InstanceNormalization ships with the keras-contrib package
from keras_contrib.layers import InstanceNormalization


class CycleGan:
    def __init__(self, height, weight, channels=3):
        # `weight` here is the image width
        self.height = height
        self.weight = weight
        self.channels = channels
        self.img_shape = (self.height, self.weight, self.channels)
```
```python
    def build_generator(self):
        # U-Net-like generator: VGG16 encoder with skip connections
        input_img = Input(name='input_img',
                          shape=(self.height, self.weight, self.channels),
                          dtype='float32')
        vgg16 = VGG16(input_tensor=input_img,
                      weights='imagenet',
                      include_top=False)
        # Outputs of block1_pool ... block5_pool, used as skip connections
        vgg_pools = [vgg16.get_layer('block%d_pool' % i).output
                     for i in range(1, 6)]

        def decoder(layer_input, skip_input, channel, last_block=False):
            # Upsampled decoder features are concatenated with the matching
            # encoder features, except in the last block (full resolution,
            # where no pool output remains to skip from)
            if not last_block:
                concat = Concatenate(axis=-1)([layer_input, skip_input])
                bn1 = InstanceNormalization()(concat)
            else:
                bn1 = InstanceNormalization()(layer_input)
            conv_1 = Conv2D(channel, 1, activation='relu', padding='same')(bn1)
            bn2 = InstanceNormalization()(conv_1)
            conv_2 = Conv2D(channel, 3, activation='relu', padding='same')(bn2)
            return conv_2

        d1 = decoder(UpSampling2D((2, 2))(vgg_pools[4]), vgg_pools[3], 256)
        d2 = decoder(UpSampling2D((2, 2))(d1), vgg_pools[2], 128)
        d3 = decoder(UpSampling2D((2, 2))(d2), vgg_pools[1], 64)
        d4 = decoder(UpSampling2D((2, 2))(d3), vgg_pools[0], 32)
        d5 = decoder(UpSampling2D((2, 2))(d4), None, 32, True)
        # tanh output matches images normalized to [-1, 1]
        output = Conv2D(3, 3, activation='tanh', padding='same')(d5)
        model = Model(inputs=input_img, outputs=output)
        # model.summary()
        return model
```
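A sanity check on the skip connections: each VGG16 pooling stage halves the spatial size and each decoder upsampling doubles it, so the upsampled features always match the pool output they are concatenated with. A quick arithmetic sketch, assuming a 256×256 input:

```python
size = 256
# Spatial sizes of block1_pool ... block5_pool: each pool halves the input
pool_sizes = [size // 2 ** i for i in range(1, 6)]
print(pool_sizes)  # [128, 64, 32, 16, 8]

# Decoder: start from block5_pool (8x8) and upsample five times
s = pool_sizes[-1]
decoder_sizes = []
for _ in range(5):
    s *= 2
    decoder_sizes.append(s)
print(decoder_sizes)  # [16, 32, 64, 128, 256]

# d1..d4 concatenate with block4_pool..block1_pool; sizes must match,
# and the last stage restores the original resolution
assert decoder_sizes[:4] == pool_sizes[:4][::-1]
assert decoder_sizes[-1] == size
```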
```python
    def build_discriminator(self):
        # PatchGAN discriminator: outputs a grid of real/fake scores
        def d_layer(layer_input, filters, f_size=4, normalization=True):
            d = Conv2D(filters, kernel_size=f_size, strides=2,
                       padding='same')(layer_input)
            d = LeakyReLU(alpha=0.2)(d)
            if normalization:
                d = InstanceNormalization()(d)
            return d

        image = Input(shape=self.img_shape)
        d1 = d_layer(image, 64, normalization=False)
        d2 = d_layer(d1, 128)
        d3 = d_layer(d2, 256)
        d4 = d_layer(d3, 512)
        patch_out = Conv2D(1, kernel_size=4, strides=1, padding='same')(d4)
        discriminator = Model(image, patch_out)
        optimizer = Adam(0.0002, 0.5)
        # LSGAN objective: mean squared error instead of cross-entropy
        discriminator.compile(loss='mse', optimizer=optimizer,
                              metrics=['accuracy'])
        # discriminator.summary()
        return discriminator
```
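Each value in `patch_out` scores one overlapping patch of the input rather than the whole image. How large that patch is can be computed by walking backwards through the conv stack with the standard receptive-field recurrence r_in = (r_out − 1)·stride + kernel (an arithmetic sketch, not part of the model code):

```python
# Walk backwards from one output unit: the final 4x4 stride-1 conv,
# then the four 4x4 stride-2 convs of d_layer.
layers = [(4, 1)] + [(4, 2)] * 4   # (kernel, stride), output side first

rf = 1
for kernel, stride in layers:
    rf = (rf - 1) * stride + kernel
print(rf)  # 94: each patch score sees a 94x94 region of the input
```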
```python
    def cycle_gan(self, gen_a2b, gen_b2a, dis_a, dis_b):
        image_a = Input(shape=self.img_shape)
        image_b = Input(shape=self.img_shape)
        # Translate to the other domain, then back again (the "cycle")
        fake_b = gen_a2b(image_a)
        fake_a = gen_b2a(image_b)
        reconstr_a = gen_b2a(fake_b)
        reconstr_b = gen_a2b(fake_a)
        # Identity mapping: a generator fed an image already from its
        # target domain should return it unchanged
        img_a_identity = gen_b2a(image_a)
        img_b_identity = gen_a2b(image_b)
        # Freeze the discriminators while training the generators
        dis_a.trainable = False
        dis_b.trainable = False
        patch_out_a = dis_a(fake_a)
        patch_out_b = dis_b(fake_b)
        cycle_model = Model(inputs=[image_a, image_b],
                            outputs=[patch_out_a, patch_out_b,
                                     reconstr_a, reconstr_b,
                                     img_a_identity, img_b_identity])
        optimizer = Adam(0.0002, 0.5)
        lambda_cycle = 10.0             # cycle-consistency loss weight
        lambda_id = 0.1 * lambda_cycle  # identity loss weight
        cycle_model.compile(loss=['mse', 'mse',    # adversarial (LSGAN)
                                  'mae', 'mae',    # cycle reconstruction
                                  'mae', 'mae'],   # identity
                            loss_weights=[1, 1,
                                          lambda_cycle, lambda_cycle,
                                          lambda_id, lambda_id],
                            optimizer=optimizer)
        # cycle_model.summary()
        return cycle_model
```
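At training time the compiled models need target tensors; with a PatchGAN the adversarial targets are grids of ones and zeros matching the discriminator's patch output, not single scalars. A minimal numpy sketch of the per-batch targets, assuming 256×256 inputs (the `imgs_a`/`imgs_b` batches and the `train_on_batch` call are hypothetical usage, not part of the class above):

```python
import numpy as np

batch_size = 4
patch = 256 // 2 ** 4            # four stride-2 convs: 256 -> 16
disc_patch = (patch, patch, 1)

valid = np.ones((batch_size,) + disc_patch)    # targets for "real"
fake = np.zeros((batch_size,) + disc_patch)    # targets for "fake"

print(valid.shape)  # (4, 16, 16, 1)

# One generator update would then look like (hypothetical batches):
# cycle_model.train_on_batch(
#     [imgs_a, imgs_b],
#     [valid, valid,        # adversarial targets
#      imgs_a, imgs_b,      # cycle-reconstruction targets
#      imgs_a, imgs_b])     # identity targets
```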
- Reproduction results: only the maps and monet2photo datasets were reproduced.