介绍
CycleGAN网络具有很强大的风格迁移功能。能够实现非常深层次的风格转换。比如男性图片女性化或者女性图片男性化。
先上效果图:
下面简单谈一谈实现原理。
网络结构
网络结构如图所示,通过两个循环使用的生成器来进行风格迁移。由此实现了非常神奇的效果。
下面结合代码来详细解释一下网络结构。训练生成对抗网络的深度学习框架为Pytorch。
1. 残差模块定义
class ResidualBlock(nn.Module):
def __init__(self, in_features):
super(ResidualBlock, self).__init__()
# 残差模块不改变shape
conv_block = [ nn.ReflectionPad2d(1), # 构建残差模块的时候使用映射填充的形式
nn.Conv2d(in_features, in_features, 3),
nn.InstanceNorm2d(in_features), # 不使用BatchNorm而是使用InstanceNorm
nn.ReLU(inplace=True),
nn.ReflectionPad2d(1),
nn.Conv2d(in_features, in_features, 3),
nn.InstanceNorm2d(in_features) ]
self.conv_block = nn.Sequential(*conv_block)
def forward(self, x):
return x + self.conv_block(x)
残差模块的定义没有太多需要说明的地方,就是有一点需要注意的是。我们在风格迁移中,不再使用BatchNorm而是使用InstanceNorm。
BN是将每一个batch的每一个通道的每一组图片求mean和var, IN是将单独一个图片的一个通道的数据求mean和var。 区别就是一个是对batch求,一个是对一个图片求。风格迁移中,为了保证风格,通常都对每一个图片单独处理。 CycleGAN网络中,每一个batch只有一张 图片,所以使用InstanceNorm。
2. 定义生成器