前言: 想要实现照片风格转和油画风格互转吗?想要实现斑马野马互转吗?想要实现苹果橘子互转等这些任务吗?没错,CycleGAN网络就能够帮你满足这一目标!
CycleGAN详解
我们先来用一些简单的话语描述GAN。
所谓GAN,按照我的理解就是:
1.一个具有生成器和判别器的网络结构
2.生成器主要负责从随机样本空间高斯采样随机点,然后生成假图片,类似于造假货
3.而判别器是一个二分类的神经网络,主要负责判别喂给它的图像是来自真实世界的图像还是生成器生成的假图像,类似于博学的专家。
直到生成器生成出的图像,连博学的专家都分辨不出来了,那么就达到了所谓的纳什均衡了!
但是呢,GAN虽然很好,要是我如果想让普通照片和油画互转,阁下该怎么应对呢?
那么不得不提到我们今天的主角,CycleGAN
2.1 模型结构
如上图(a)所示,CycleGAN的结构是两个GAN组成的,也就是说,它有两副GAN的结构,即两个生成器和两个判别器,我们尝试做如下定义:
X 代表的是莫奈的油画 , Y 代表的是普通照片 D x 代表的是 Y → X 的判别器 , D y 代表的是 X → Y 的判别器 G 代表的是莫奈转普通照片的生成器,而 F 代表普通照片转莫奈油画 X代表的是莫奈的油画 , Y 代表的是普通照片 \\ D_x代表的是Y \to X的判别器,D_y代表的是X \to Y的判别器 \\ G代表的是莫奈转普通照片的生成器,而F代表普通照片转莫奈油画 X代表的是莫奈的油画,Y代表的是普通照片Dx代表的是Y→X的判别器,Dy代表的是X→Y的判别器G代表的是莫奈转普通照片的生成器,而F代表普通照片转莫奈油画
清楚了上面的定义,我们参照原始代码中生成器的结构:
在这里生成器是一个UNet形状的结构,输入一个256x256的图像,然后通过一系列块:
- 其中CLI是由卷积、InstanceNorm和Leaky RELU组成的。
- ReflectionPad(*)是一种图像增强方式,使得图像沿着边缘上下左右进行对称,增大图像分辨率的方式。
- Residual block模块负责将数据进行恢复增强。
那么我们由如上的定义,可以写出如下的代码了!
# 分别定义两个判别器和两个生成器
# 生成器的定义
class GeneratorResNet(nn.Module):
def __init__(self, input_shape, num_residual_blocks):
super(GeneratorResNet, self).__init__()
channels = input_shape[0]
# Initial convolution block
out_features = 64
model = [
nn.ReflectionPad2d(channels),
nn.Conv2d(channels, out_features, 7),
nn.InstanceNorm2d(out_features),
nn.ReLU(inplace=True),
]
in_features = out_features
# Downsampling
for _ in range(2):
out_features *= 2
model += [
nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
nn.InstanceNorm2d(out_features),
nn.ReLU(inplace=True),
]
in_features = out_features
# Residual blocks
for _ in range(num_residual_blocks):
model += [ResidualBlock(out_features)]
# Upsampling
for _ in range(2):
out_features //= 2
model += [
nn.Upsample(scale_factor=2),
nn.Conv2d(in_features, out_features, 3, stride=1, padding=1),
nn.InstanceNorm2d(out_features),
nn.ReLU(inplace=True),
]
in_features = out_features
# Output layer
model += [nn.ReflectionPad2d(channels), nn.Conv2d(out_features, channels, 7), nn.Tanh()]
self.model = nn.Sequential(*model)
def forward(self, x):
return self.model(x)
判别器特别简单,可以归纳为一个线性结构,直上直下,最后展平成一个(b,1)的维度。
所以我们再定义判别器:
# 判别器的定义
class Discriminator(nn.Module):
def __init__(self, input_shape):
super(Discriminator, self).__init__()
channels, height, width = input_shape
# Calculate output shape of image discriminator (PatchGAN)
self.output_shape = (1, height // 2 ** 4, width // 2 ** 4)
def discriminator_block(in_filters, out_filters, normalize=True):
"""Returns downsampling layers of each discriminator block"""
layers = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)]
if normalize:
layers.append(nn.InstanceNorm2d(out_filters))
layers.append(nn.LeakyReLU(0.2, inplace=True))
return layers
self.model = nn.Sequential(
*discriminator_block(channels, 64, normalize=False),
*discriminator_block(64, 128),
*discriminator_block(128, 256),
*discriminator_block(256, 512),
nn.ZeroPad2d((1, 0, 1, 0)),
nn.Conv2d(512, 1, 4, padding=1)
)
def forward(self, img):
return self.model(img)
声明好结构之后,我们来定义生成器和判别器:
G_AB = GeneratorResNet(input_shape, opt.n_residual_blocks)
G_BA = GeneratorResNet(input_shape, opt.n_residual_blocks)
D_A = Discriminator(input_shape)
D_B = Discriminator(input_shape)
2.2 一致性损失
难道这就是全部了吗?当然不是!因为你想想,如果这种方式直接训练,那么会出现这种问题:
莫奈油画转普通照片,只是让生成的图像很像普通照片的风格,但是一点儿莫奈油画里画的元素都没有!
也就是说,莫奈油画里画了一只狗(的油画),但是我生成的照片是一只其他的小动物(的照片),牛头不对马嘴了~
那该怎么办呢?所以作者提出了一致性损失!
简而言之,就是A -> B之后,B 再次转回到 A时,生成的图片要和A初始的图像对得上(也就是最小化损失,保持图片的一致性)
公式如下:
上面的损失什么意思呢?就是:
x
→
G
(
x
)
:
x
被随机采样,然后过
G
(
∗
)
生成器,得到
G
(
x
)
G
(
x
)
→
F
(
G
(
x
)
)
:
G
(
x
)
送到
F
(
∗
)
生成器
,
也就是上面说的再生成回来
最后得到结果和原始
x
的
L
1
损失
x \to G(x) : x被随机采样,然后过G(*)生成器,得到G(x) \\ G(x) \to F(G(x)) : G(x) 送到 F(*)生成器,也就是上面说的再生成回来 \\ 最后得到结果和原始x的L1 损失
x→G(x):x被随机采样,然后过G(∗)生成器,得到G(x)G(x)→F(G(x)):G(x)送到F(∗)生成器,也就是上面说的再生成回来最后得到结果和原始x的L1损失
那么总损失可以概括为:
因此可以定义出文中的所有损失:
# GAN损失
criterion_GAN = torch.nn.MSELoss()
# A to B 循环损失
criterion_cycle = torch.nn.L1Loss()
# B to A 循环损失
criterion_identity = torch.nn.L1Loss()
最后是训练的代码:
for epoch in range(opt.epoch, opt.n_epochs):
for i, batch in enumerate(dataloader):
# Set model input
# A 、B都是真实数据
real_A = Variable(batch["A"].type(Tensor))
real_B = Variable(batch["B"].type(Tensor))
# Adversarial ground truths
# 定义标签
valid = Variable(Tensor(np.ones((real_A.size(0), *D_A.output_shape))), requires_grad=False)
fake = Variable(Tensor(np.zeros((real_A.size(0), *D_A.output_shape))), requires_grad=False)
# ------------------
# Train Generators
# ------------------
G_AB.train()
G_BA.train()
optimizer_G.zero_grad()
# Identity loss
loss_id_A = criterion_identity(G_BA(real_A), real_A)
loss_id_B = criterion_identity(G_AB(real_B), real_B)
loss_identity = (loss_id_A + loss_id_B) / 2
# GAN loss
fake_B = G_AB(real_A)
loss_GAN_AB = criterion_GAN(D_B(fake_B), valid)
fake_A = G_BA(real_B)
loss_GAN_BA = criterion_GAN(D_A(fake_A), valid)
loss_GAN = (loss_GAN_AB + loss_GAN_BA) / 2
# Cycle loss
recov_A = G_BA(fake_B)
loss_cycle_A = criterion_cycle(recov_A, real_A)
recov_B = G_AB(fake_A)
loss_cycle_B = criterion_cycle(recov_B, real_B)
loss_cycle = (loss_cycle_A + loss_cycle_B) / 2
# Total loss
loss_G = loss_GAN + opt.lambda_cyc * loss_cycle + opt.lambda_id * loss_identity
loss_G.backward()
optimizer_G.step()
# -----------------------
# Train Discriminator A
# -----------------------
optimizer_D_A.zero_grad()
# Real loss
loss_real = criterion_GAN(D_A(real_A), valid)
# Fake loss (on batch of previously generated samples)
fake_A_ = fake_A_buffer.push_and_pop(fake_A)
loss_fake = criterion_GAN(D_A(fake_A_.detach()), fake)
# Total loss
loss_D_A = (loss_real + loss_fake) / 2
loss_D_A.backward()
optimizer_D_A.step()
# -----------------------
# Train Discriminator B
# -----------------------
optimizer_D_B.zero_grad()
# Real loss
loss_real = criterion_GAN(D_B(real_B), valid)
# Fake loss (on batch of previously generated samples)
fake_B_ = fake_B_buffer.push_and_pop(fake_B)
loss_fake = criterion_GAN(D_B(fake_B_.detach()), fake)
# Total loss
loss_D_B = (loss_real + loss_fake) / 2
loss_D_B.backward()
optimizer_D_B.step()
loss_D = (loss_D_A + loss_D_B) / 2
# --------------
# Log Progress
# --------------
# Determine approximate time left
batches_done = epoch * len(dataloader) + i
batches_left = opt.n_epochs * len(dataloader) - batches_done
time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
prev_time = time.time()
# Print log
sys.stdout.write(
"\r[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f, adv: %f, cycle: %f, identity: %f] ETA: %s"
% (
epoch,
opt.n_epochs,
i,
len(dataloader),
loss_D.item(),
loss_G.item(),
loss_GAN.item(),
loss_cycle.item(),
loss_identity.item(),
time_left,
)
)
# If at sample interval save image
if batches_done % opt.sample_interval == 0:
sample_images(batches_done)
# Update learning rates
lr_scheduler_G.step()
lr_scheduler_D_A.step()
lr_scheduler_D_B.step()
The End