一、编码器
第一个 VGG 编码器,看图应该是选取了 VGG 网络的前 relu4_1 层;第二个编码器则是 VGG 的前 relu5_1 层。这部分代码和 AdaIN 的编码器几乎是一样的,唯一的不同是比 AdaIN 的多了一层。
二、注意力机制融合模块
模块图如下:
网络的输出图如下:
逐步解释以上内容的实现,这是文章的一大主要的创新点
1、Fc(relu_4_1)、Fs(relu_4_1)经过 Style-Attentional 模块变成了 Fcs(relu_4_1)
这一步的输入是内容和风格图像经过编码器前四层所输出的特征图。
(1)Fc和Fs都进行了归一化处理
def calc_mean_std(feat, eps=1e-5):
    """Return per-sample, per-channel mean and std of a 4-D feature map.

    Args:
        feat: tensor of shape (N, C, H, W).
        eps: small constant added to the variance to avoid divide-by-zero.

    Returns:
        (mean, std): two tensors, each of shape (N, C, 1, 1), suitable for
        broadcasting against the original feature map.
    """
    assert feat.dim() == 4
    n, c = feat.shape[0], feat.shape[1]
    flat = feat.view(n, c, -1)
    # Spatial variance per channel, stabilized by eps before the sqrt.
    std = (flat.var(dim=2) + eps).sqrt().view(n, c, 1, 1)
    mean = flat.mean(dim=2).view(n, c, 1, 1)
    return mean, std
def mean_variance_norm(feat):
    """Normalize *feat* channel-wise to zero mean and unit variance.

    Uses the per-channel statistics from ``calc_mean_std``; the output has
    the same shape as the input.
    """
    mean, std = calc_mean_std(feat)
    shape = feat.size()
    return (feat - mean.expand(shape)) / std.expand(shape)
class SANet(nn.Module):
    """Style-attentional network (SANet) fusion module.

    Computes an attention map between a normalized content feature map and a
    normalized style feature map, re-assembles the (un-normalized) style
    values according to that attention, and adds the result back onto the
    content features as a residual.
    """

    def __init__(self, in_planes):
        super(SANet, self).__init__()
        # 1x1 projections: f -> content query, g -> style key, h -> style value.
        self.f = nn.Conv2d(in_planes, in_planes, (1, 1))
        self.g = nn.Conv2d(in_planes, in_planes, (1, 1))
        self.h = nn.Conv2d(in_planes, in_planes, (1, 1))
        # Normalizes each row of the attention map so the weights sum to 1.
        self.sm = nn.Softmax(dim=-1)
        # Final 1x1 projection applied to the attended features.
        self.out_conv = nn.Conv2d(in_planes, in_planes, (1, 1))

    def forward(self, content, style):
        # Query comes from normalized content, key from normalized style;
        # the value keeps the un-normalized style statistics.
        query = self.f(mean_variance_norm(content))
        key = self.g(mean_variance_norm(style))
        value = self.h(style)

        b, _, qh, qw = query.size()
        query = query.view(b, -1, qw * qh).permute(0, 2, 1)  # (b, Nc, C)
        b, _, kh, kw = key.size()
        key = key.view(b, -1, kw * kh)                       # (b, C, Ns)

        # Each content position attends over every style position.
        attn = self.sm(torch.bmm(query, key))                # (b, Nc, Ns)

        b, _, vh, vw = value.size()
        value = value.view(b, -1, vw * vh)                   # (b, C, Ns)
        fused = torch.bmm(value, attn.permute(0, 2, 1))      # (b, C, Nc)

        b, c, h, w = content.size()
        fused = fused.view(b, c, h, w)
        fused = self.out_conv(fused)
        # Residual connection back onto the content features.
        fused += content
        return fused
其过程如图所示:
这个风格融合模块和自适应实例归一化(AdaIN)还是很不一样的,它并没有用到方差对齐、均值对齐的操作,但是达到了任意风格迁移的效果,这是为什么呢?
这是通过优化模型、动态调整注意力权重来实现风格融合的。
三、解码器
解码器的内容和 AdaIN 中的一模一样。
def _conv_block(in_ch, out_ch, upsample=False, relu=True):
    """Build one decoder stage: optional 2x nearest upsample, then a
    reflection-padded 3x3 conv, then an optional ReLU. Returns a list of
    modules so the caller can splat them into an nn.Sequential."""
    layers = []
    if upsample:
        layers.append(nn.Upsample(scale_factor=2, mode='nearest'))
    layers.append(nn.ReflectionPad2d((1, 1, 1, 1)))
    layers.append(nn.Conv2d(in_ch, out_ch, (3, 3)))
    if relu:
        layers.append(nn.ReLU())
    return layers


# Mirror of the VGG encoder: channels shrink 512 -> 3 while spatial size is
# upsampled 8x in total. The final conv has no activation (raw RGB output).
decoder = nn.Sequential(
    *_conv_block(512, 256),
    *_conv_block(256, 256, upsample=True),
    *_conv_block(256, 256),
    *_conv_block(256, 256),
    *_conv_block(256, 128),
    *_conv_block(128, 128, upsample=True),
    *_conv_block(128, 64),
    *_conv_block(64, 64, upsample=True),
    *_conv_block(64, 3, relu=False),
)
四、训练过程
1、损失函数
(1)内容损失函数
def calc_content_loss(self, input, target, norm=False):
    """Content loss: MSE between two feature maps.

    Args:
        input: feature map of the generated image.
        target: feature map of the content image.
        norm: if True, both features are mean/variance-normalized first so
            the loss compares spatial structure rather than channel
            statistics.

    Returns:
        Scalar MSE loss tensor.
    """
    # `if not norm` replaces the non-idiomatic `if (norm == False)`.
    if not norm:
        return self.mse_loss(input, target)
    return self.mse_loss(mean_variance_norm(input), mean_variance_norm(target))
(2)风格损失函数
def calc_style_loss(self, input, target):
    """Style loss: MSE between per-channel mean and std statistics.

    AdaIN-style statistic matching: the generated feature's channel-wise
    mean/std should agree with the style feature's.
    """
    in_mean, in_std = calc_mean_std(input)
    tg_mean, tg_std = calc_mean_std(target)
    loss = self.mse_loss(in_mean, tg_mean)
    loss = loss + self.mse_loss(in_std, tg_std)
    return loss
(3)身份损失函数
l_identity1,优化的是decoder、transform。
l_identity2,约束的是重建图像经过编码器各层所得到的特征与原图各层特征的一致性。
# Identity reconstruction: feed (content, content) and (style, style) pairs
# through transform + decoder; a perfect network should reproduce the inputs.
# NOTE(review): content_feats/style_feats look like lists of encoder
# activations; indices 3 and 4 presumably correspond to relu4_1 / relu5_1 --
# confirm against the encoder definition.
Icc = self.decoder(self.transform(content_feats[3], content_feats[3], content_feats[4], content_feats[4]))
Iss = self.decoder(self.transform(style_feats[3], style_feats[3], style_feats[4], style_feats[4]))
# Identity loss 1: reconstructed images should match the original images
# pixel-wise.
l_identity1 = self.calc_content_loss(Icc, content) + self.calc_content_loss(Iss, style)
# Re-encode the reconstructions to compare encoder features at every level.
Fcc = self.encode_with_intermediate(Icc)
Fss = self.encode_with_intermediate(Iss)
# Identity loss 2: the reconstructions' encoder features should match the
# originals' features at all five levels (level 0 here, 1..4 in the loop).
l_identity2 = self.calc_content_loss(Fcc[0], content_feats[0]) + self.calc_content_loss(Fss[0], style_feats[0])
for i in range(1, 5):
    l_identity2 += self.calc_content_loss(Fcc[i], content_feats[i]) + self.calc_content_loss(Fss[i], style_feats[i])
(4)训练
# Main optimization loop over iterations (not epochs): each step draws one
# batch of content images and one batch of style images.
for i in tqdm(range(args.start_iter, args.max_iter)):
    adjust_learning_rate(optimizer, iteration_count=i)
    content_images = next(content_iter).to(device)
    style_images = next(style_iter).to(device)
    # The network's forward pass returns all four loss terms at once.
    loss_c, loss_s, l_identity1, l_identity2 = network(content_images, style_images)
    loss_c = args.content_weight * loss_c
    loss_s = args.style_weight * loss_s
    # Fixed identity-loss weights: 50 for the image-space term, 1 for the
    # feature-space term.
    loss = loss_c + loss_s + l_identity1 * 50 + l_identity2 * 1
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # Periodically checkpoint the decoder, the transform module and the
    # optimizer state; tensors are moved to CPU before serialization.
    if (i + 1) % args.save_model_interval == 0 or (i + 1) == args.max_iter:
        state_dict = decoder.state_dict()
        for key in state_dict.keys():
            state_dict[key] = state_dict[key].to(torch.device('cpu'))
        torch.save(state_dict,
                   '{:s}/decoder_iter_{:d}.pth'.format(args.save_dir,
                                                       i + 1))
        state_dict = network.transform.state_dict()
        for key in state_dict.keys():
            state_dict[key] = state_dict[key].to(torch.device('cpu'))
        torch.save(state_dict,
                   '{:s}/transformer_iter_{:d}.pth'.format(args.save_dir,
                                                           i + 1))
        state_dict = optimizer.state_dict()
        torch.save(state_dict,
                   '{:s}/optimizer_iter_{:d}.pth'.format(args.save_dir,
                                                         i + 1))
五、总结
论文详细解读:这篇论文是一个基于 AdaIN 的工作,代码都差不太多,比较大的区别是本篇论文通过训练得到的注意力权重来融合风格图像和内容图像,并且提出了一个新的损失函数——身份损失函数。这个身份损失函数和 CycleGAN 中的循环一致性损失思想类似。