SAnet-任意风格迁移代码详解

这里是引用

一、解码器

第一个vgg编码器,看图因该是选取了vgg网络的前Reule4_1层,第二个编码器就是vgg的前Reule5_1。这块的代码和AdaIN的编码器几乎是一样的用的,一点不同就是比AdaIN的多了一层。

二、注意力机制融合模块

模块图如下:
在这里插入图片描述
网络的输出图如下:
在这里插入图片描述

逐步解释以上内容的实现,这是文章的一大主要的创新点

1、Fc(rule_4_1)、Fs(rule_4_1)经过style-Attentional变成了Fcs(rule_4_1)

这一步的输入是内容和风格图像经过编码器前四层所输出的特征图。

(1)Fc和Fs都进行了归一化处理
def calc_mean_std(feat, eps=1e-5):
    # eps is a small value added to the variance to avoid divide-by-zero.
    size = feat.size()
    assert (len(size) == 4)
    N, C = size[:2]
    feat_var = feat.view(N, C, -1).var(dim=2) + eps
    feat_std = feat_var.sqrt().view(N, C, 1, 1)
    feat_mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1)
    return feat_mean, feat_std

def mean_variance_norm(feat):
    size = feat.size()
    mean, std = calc_mean_std(feat)
    normalized_feat = (feat - mean.expand(size)) / std.expand(size)
    return normalized_feat
class SANet(nn.Module):
    def __init__(self, in_planes):
        super(SANet, self).__init__()
        self.f = nn.Conv2d(in_planes, in_planes, (1, 1))  # 1x1卷积层,用于提取content的特征
        self.g = nn.Conv2d(in_planes, in_planes, (1, 1))  # 1x1卷积层,用于提取style的特征
        self.h = nn.Conv2d(in_planes, in_planes, (1, 1))  # 1x1卷积层,用于生成style的特征
        self.sm = nn.Softmax(dim=-1)  # softmax层,用于将注意力图进行归一化
        self.out_conv = nn.Conv2d(in_planes, in_planes, (1, 1))  # 1x1卷积层,用于生成最终的输出特征图

    def forward(self, content, style):
        F = self.f(mean_variance_norm(content))  # 对content进行均值方差归一化后,通过f卷积层得到特征图F
        G = self.g(mean_variance_norm(style))  # 对style进行均值方差归一化后,通过g卷积层得到特征图G
        H = self.h(style)  # 通过h卷积层得到生成style特征的特征图H
        b, c, h, w = F.size()
        F = F.view(b, -1, w * h).permute(0, 2, 1)  # 调整F的形状,用于计算注意力图S
        b, c, h, w = G.size()
        G = G.view(b, -1, w * h)  # 调整G的形状,用于计算注意力图S
        S = torch.bmm(F, G)  # 计算注意力图S,S的形状为[b, h*w, h*w],F和G矩阵相乘
        S = self.sm(S)  # 对S进行归一化,使得每个位置的注意力权重之和为1
        b, c, h, w = H.size()
        H = H.view(b, -1, w * h)  # 调整H的形状,用于计算生成特征图O
        O = torch.bmm(H, S.permute(0, 2, 1))  # 计算生成特征图O,O的形状为[b, c, h*w]
        b, c, h, w = content.size()
        O = O.view(b, c, h, w)  # 调整O的形状,使其与content的形状相同
        O = self.out_conv(O)  # 通过out_conv卷积层生成最终的输出特征图
        O += content  # 将输出特征图与content相加
        return O

其过程如图所示:
在这里插入图片描述

这个风格融合模块和自适应归一化还是很不一样的,他并没有用到方差对齐、均值对齐的操作,但是达到了任意风格迁移的效果,这是为啥呢?
这个是通过优化模型,动态调整权重来实现风格风格融合的。

三、编码器

解码器的内容和AdaIN中的一摸一样

decoder = nn.Sequential(
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 256, (3, 3)),
    nn.ReLU(),
    nn.Upsample(scale_factor=2, mode='nearest'),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 128, (3, 3)),
    nn.ReLU(),
    nn.Upsample(scale_factor=2, mode='nearest'),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(128, 128, (3, 3)),
    nn.ReLU(),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(128, 64, (3, 3)),
    nn.ReLU(),
    nn.Upsample(scale_factor=2, mode='nearest'),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(64, 64, (3, 3)),
    nn.ReLU(),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(64, 3, (3, 3)),
)

四、训练过程

1、损失函数

(1)内容损失函数
    def calc_content_loss(self, input, target, norm=False):
        if (norm == False):
            return self.mse_loss(input, target)
        else:
            return self.mse_loss(mean_variance_norm(input), mean_variance_norm(target))
(2)风格损失函数
    def calc_style_loss(self, input, target):
        input_mean, input_std = calc_mean_std(input)
        target_mean, target_std = calc_mean_std(target)
        return self.mse_loss(input_mean, target_mean) + \
            self.mse_loss(input_std, target_std)
(3)身份损失函数

l_identity1,优化的是decoder、transform。
l_identity2,优化的是编码器的前一层。

Icc = self.decoder(self.transform(content_feats[3], content_feats[3], content_feats[4], content_feats[4]))
Iss = self.decoder(self.transform(style_feats[3], style_feats[3], style_feats[4], style_feats[4]))
l_identity1 = self.calc_content_loss(Icc, content) + self.calc_content_loss(Iss, style)
Fcc = self.encode_with_intermediate(Icc)
Fss = self.encode_with_intermediate(Iss)
l_identity2 = self.calc_content_loss(Fcc[0], content_feats[0]) + self.calc_content_loss(Fss[0], style_feats[0])
        for i in range(1, 5):
            l_identity2 += self.calc_content_loss(Fcc[i], content_feats[i]) + self.calc_content_loss(Fss[i],                                                                                          style_feats[i])
(4)训练
for i in tqdm(range(args.start_iter, args.max_iter)):
    adjust_learning_rate(optimizer, iteration_count=i)
    content_images = next(content_iter).to(device)
    style_images = next(style_iter).to(device)
    loss_c, loss_s, l_identity1, l_identity2 = network(content_images, style_images)
    loss_c = args.content_weight * loss_c
    loss_s = args.style_weight * loss_s
    loss = loss_c + loss_s + l_identity1 * 50 + l_identity2 * 1

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if (i + 1) % args.save_model_interval == 0 or (i + 1) == args.max_iter:
        state_dict = decoder.state_dict()
        for key in state_dict.keys():
            state_dict[key] = state_dict[key].to(torch.device('cpu'))
        torch.save(state_dict,
                   '{:s}/decoder_iter_{:d}.pth'.format(args.save_dir,
                                                       i + 1))
        state_dict = network.transform.state_dict()
        for key in state_dict.keys():
            state_dict[key] = state_dict[key].to(torch.device('cpu'))
        torch.save(state_dict,
                   '{:s}/transformer_iter_{:d}.pth'.format(args.save_dir,
                                                           i + 1))
        state_dict = optimizer.state_dict()
        torch.save(state_dict,
                   '{:s}/optimizer_iter_{:d}.pth'.format(args.save_dir,
                                                         i + 1))

五、总结
论文详细解读,这篇论文就是一个基于AdaIN的论文,代码都差不太多,比较大的区别就是在本篇论文中,是通过训练权重比例来融合风格图像和内容图像的。并且提出了一个损失函数身份损失函数。这个身份损失函数和cyclegan。

  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值