前言
风格迁移(Style Transfer)是一个很有意思的任务,通过风格迁移可以使一张图片保持本身内容大致不变的情况下呈现出另外一张图片的风格。风格迁移三步曲将绍以下三种风格迁移方式以及对应的代码实现
固定风格固定内容的普通风格迁移(A Neural Algorithm of Artistic Style)
固定风格任意内容的快速风格迁移(Perceptual Losses for Real-Time Style Transfer and Super-Resolution)
任意风格任意内容的极速风格迁移(Meta Networks for Neural Style Transfer)
本文会介绍固定风格固定内容的普通风格迁移(A Neural Algorithm of Artistic Style)方式以及对应的代码实现。本文所使用的环境是 pytorch 0.4.0,如果你使用了其他的版本,稍作修改即可正确运行。
1 固定风格固定内容的普通风格迁移
最早的风格迁移就是在固定风格、固定内容的情况下做的风格迁移,这是最慢的方法,也是最经典的方法。
最原始的风格迁移的思路很简单,把图片当做可以训练的变量,通过优化图片来降低与内容图片的内容差异以及降低与风格图片的风格差异,迭代训练多次以后,生成的图片就会与内容图片的内容一致,同时也会与风格图片的风格一致。
VGG16
VGG16 是一个很经典的模型,它通过堆叠 3x3 的卷积层和池化层,在 ImageNet 上获得了不错的成绩。我们使用在 ImageNet 上经过预训练的 VGG16 模型可以对图像提取出有用的特征,这些特征可以帮助我们去衡量两个图像的内容差异和风格差异。
在进行风格迁移任务时,我们只需要提取其中几个比较重要的层,所以我们对 pytorch 自带的预训练 VGG16 模型稍作了一些修改:
1class VGG(nn.Module):
2
3 def __init__(self, features):
4 super(VGG, self).__init__()
5 self.features = features
6 self.layer_name_mapping = {
7 '3': "relu1_2",
8 '8': "relu2_2",
9 '15': "relu3_3",
10 '22': "relu4_3"
11 }
12 for p in self.parameters():
13 p.requires_grad = False
14
15 def forward(self, x):
16 outs = []
17 for name, module in self.features._modules.items():
18 x = module(x)
19 if name in self.layer_name_mapping:
20 outs.append(x)
21 return outs
22
23vgg16 = models.vgg16(pretrained=True)
24vgg16 = VGG(vgg16.features[:23]).to(device).eval()
经过修改的 VGG16 可以输出 relu1_2、relu2_2、relu3_3、relu4_3 这几个特定层的特征图。下面这两句代码就是它的用法:
1features = vgg16(input_img)
2content_features = vgg16(content_img)
举个例子,当我们使用 vgg16 对 input_img 计算特征时,它会返回四个矩阵给 features,假设 input_img 的尺寸是 [1, 3, 512, 512](四个维度分别代表 batch, channels, height, width),那么它返回的四个矩阵的尺寸就是这样的:
relu1_2 [1, 64, 512, 512]
relu2_2 [1, 128, 256, 256]
relu3_3 [1, 256, 128, 128]
relu4_3 [1, 512, 64, 64]
内容
我们进行风格迁移的时候,必须保证生成的图像与内容图像的内容一致性,不然风格迁移就变成艺术创作了。那么如何衡量两张图片的内容差异呢?很简单,通过 VGG16 输出的特征图来衡量图片的内容差异。
提示:在本方法中没有 Image Transform Net,为了表述方便,我们使用了第二篇论文中的图。
这里使用的损失函数是:
其中:
根据生成图像和内容图像在 relu3_3 输出的特征图的均方误差(MeanSquaredError)来优化生成的图像与内容图像之间的内容一致性。
那么写成代码就是这样的:
1content_loss = F.mse_loss(features[2], content_features[2]) * content_weight
因为我们这里使用的是经过在 ImageNet 预训练过的 VGG16 提取的特征图,所以它能提取出图像的高级特征,通过优化生成图像和内容图像特征图的 mse,可以迫使生成图像的内容与内容图像在 VGG16 的 relu3_3 上输出相似的结果,因此生成图像和内容图像在内容上是一致的。
风格
Gram 矩阵
那么如何衡量输入图像与风格图像之间的内容差异呢?这里就需要提出一个新的公式,Gram 矩阵:
其中:
具体到代码,我们可以写出下面的函数:
1def gram_matrix(y):
2 (b, ch, h, w) = y.size()
3 features = y.view(b, ch, w * h)
4 features_t = features.transpose(1, 2)
5 gram = features.bmm(features_t) / (ch * h * w)
6 return gram
参考链接:
假设我们输入了一个 [1, 3, 512, 512] 的图像,下面就是各个矩阵的尺寸:
relu1_2 [1, 64, 512, 512],gram [1, 64, 64]
relu2_2 [1, 128, 256, 256],gram [1, 128, 128]
relu3_3 [1, 256, 128, 128],gram [1, 256, 256]
relu4_3 [1, 512, 64, 64],gram [1, 512, 512]
风格损失
根据生成图像和风格图像在relu1_2、relu2_2、relu3_3、relu4_3 输出的特征图的 Gram 矩阵之间的均方误差(MeanSquaredError)来优化生成的图像与风格图像之间的风格差异:
其中:
那么写成代码就是下面这样:
1style_grams = [gram_matrix(x) for x in style_features]
2
3style_loss = 0
4grams = [gram_matrix(x) for x in features]
5for a, b in zip(grams, style_grams):
6 style_loss += F.mse_loss(a, b) * style_weight
训练
那么风格迁移的目标就很简单了,直接将两个 loss 按权值加起来,然后对图片优化 loss,即可优化出既有内容图像的内容,也有风格图像的风格的图片。代码如下:
1input_img = content_img.clone()
2optimizer = optim.LBFGS([input_img.requires_grad_()])
3style_weight = 1e6
4content_weight = 1
5
6run = [0]
7while run[0] <= 300:
8 def f():
9 optimizer.zero_grad()
10 features = vgg16(input_img)
11
12 content_loss = F.mse_loss(features[2], content_features[2]) * content_weight
13 style_loss = 0
14 grams = [gram_matrix(x) for x in features]
15 for a, b in zip(grams, style_grams):
16 style_loss += F.mse_loss(a, b) * style_weight
17
18 loss = style_loss + content_loss
19
20 if run[0] % 50 == 0:
21 print('Step {}: Style Loss: {:4f} Content Loss: {:4f}'.format(
22 run[0], style_loss.item(), content_loss.item()))
23 run[0] += 1
24
25 loss.backward()
26 return loss
27
28 optimizer.step(f)
此处使用了 LBFGS,所以 loss 需要包装在一个函数里,代码参考了:
效果
最终效果如图所示:
可以看到生成的图像既有风格图像的风格,也有内容图像的内容,很完美。不过生成一幅256x256 的图像在 1080ti 上需要18.6s,这个时间挺长的,谈不上实时性。
预告
下一篇风格迁移三部曲(二)会介绍固定风格任意内容的快速风格迁移(Perceptual Losses for Real-Time Style Transfer and Super-Resolution)方式以及对应的代码实现。敬请期待