超分辨率:基于metric learning的无监督blind sr:Metric Learning based Interactive Modulation for Real-World Super

Metric Learning based Interactive Modulation for Real-World Super-Resolution

1. 思路介绍

关于blind sr, 常规的方法是 建立各种 退化方法数据集,然后监督训练得到 退化的表示向量,
在将 退化表征向量 和 图像 一起输入网络 建立blind sr, 如下图左

本文提出的是下图右,主要区别在于 作者建立 退化强度的表征不是依赖监督训练,而是依赖度量学习(metric learning)。通过图像比较的方法建立损失函数进行训练,如右图。

在这里插入图片描述

2. 整体结构

主要包含3部分.

  1. unsupervised degradation estimation module(UDEM) 通过metric learning 的方法学习到 degradation score.
  2. condition network 将 degradation score 转化为向量的形式作为 退化强度的表征,方便输入 base net
  3. base net是一个 RRDB组成的 超分网络,中间会引入 退化强度相关的向量。

因此最终的网络可以估计退化的强度,进而更好的进行超分辨率。
在这里插入图片描述

2.1 unsupervised degradation estimation module(UDEM) 无监督退化估计

在这里插入图片描述

两个退化估计网络分别 估计 blur score 和 noise score.

code如下:
分别是conv layer
body是一系列 residual block
avg_pool 池化
fc_degree:两个linear layer

num_degradation = 2 表示两个同样的网络分别估计 blur score和 noise score

每个网络输入的是 退化的图像,输出的是一个blur score 或 noise score

    self.fc_degree.append(
        nn.Sequential(
            nn.Linear(num_feats[-1], 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 1),  # score
            actv(),
        ))
    def forward(self, x):
            degrees = []
            for i in range(self.num_degradation):
                x_out = self.conv_first[i](x)
                feat = self.body[i](x_out)
                feat = self.avg_pool(feat)
                feat = feat.squeeze(-1).squeeze(-1)
                # for i in range(self.num_degradation):
                degrees.append(self.fc_degree[i](feat).squeeze(-1))

            return degrees

2.2 unsupervised degradation estimation module(UDEM) 无监督退化估计的损失函数

是该paper的创新点

一个 HR 图像经过blur 和 noise 两个大类的退化后,得到5个图像。
如下图,c1 比 c2 的 blur大,c3比 c2的 noise大,这样通过ranking 损失函数来建立约束,这样网络可以学习到 不同强度的 排名分数。
损失函数如下:si,sj是真实的degradation score, si_hat,sj_hat是网络输出的degradation score。这个损失函数是当 si的退化强度大于sj 时使si_hat > sj_hat, 反之亦然。

在这里插入图片描述

为了使 排名分数约束在一定范围,同时引入两个 ancher image. 一个使 没有退化,一个是 noise,blur都最大的退化。
并且是前者分数为0,后者分数为1.
损失函数如下:

在这里插入图片描述

下图,五个图像得到五组分数(每组包括blur score, noise score),通过上面建立损失函数建立一个无监督退化估计网络
在这里插入图片描述

其中 排名的损失函数是 margin ranking loss

class MarginRankingLoss(nn.Module):

    def __init__(self, loss_weight=1.0, margin=0.0, reduction='mean'):
        super(MarginRankingLoss, self).__init__()
        if reduction not in ['none', 'mean', 'sum']:
            raise ValueError(f'Unsupported reduction mode: {reduction}. Supported ones are: {_reduction_modes}')

        self.loss_weight = loss_weight
        self.margin = margin
        self.reduction = reduction

    def forward(self, input1, input2, target):
        return self.loss_weight * F.margin_ranking_loss(
            input1, input2, target, margin=self.margin, reduction=self.reduction)

2.3 condition net

退化估计网络得到两个退化分数,condition net 输入 退化分数,得到 退化的向量表征。
就是通过一个全连接网络将 退化分数转化为向量表示.
在这里插入图片描述

退化向量 alpha(代码中是gamma) 和 beta再与feature进行fusion

class AffineModulate(nn.Module):

    def __init__(self, degradation_dim=128, num_feat=64):
        super(AffineModulate, self).__init__()
        degradation_dim= 512 #256 #64
        self.fc = nn.Sequential(
            nn.Linear(degradation_dim, (degradation_dim + num_feat * 2) // 2),
            nn.ReLU(True),
            nn.Linear((degradation_dim + num_feat * 2) // 2, (degradation_dim + num_feat * 2) // 2),
            nn.ReLU(True),
            nn.Linear((degradation_dim + num_feat * 2) // 2, num_feat * 2),
        )
        default_init_weights([self.fc], 0.1)

    def forward(self, x, d):
        d = self.fc(d)
        d = d.view(d.size(0), d.size(1), 1, 1)
        gamma, beta = torch.chunk(d, chunks=2, dim=1)

        return 

2.4 base net

关键之处在与 每个 RRDB block得到的feature(代码中是x) 通过 (1 + gamma) * x + beta
引入 退化因素

在这里插入图片描述

关键代码:初始化多个 RRDB block 和 AffineModulate

        for _ in range(num_block):
            self.body.append(RRDB(num_feat, num_grow_ch=num_grow_ch))
            self.am_list.append(AffineModulate(degradation_dim=512, num_feat=num_feat))

在 forward函数中依次连接

        for i in range(self.num_block):
            feat = self.body[i](feat)
            feat = self.am_list[i](feat, d_embedding)

2.5 损失函数

除了 无监督退化估计的 度量损失,还有 GAN损失,感知损失 per,重建损失 L1

在这里插入图片描述

3. 结果

几个方法的比较:
在这里插入图片描述

关键就是介绍了一个无监督退化估计的度量方法,如标题。
官方代码是基于basic sr 来做的

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值