Metric Learning based Interactive Modulation for Real-World Super-Resolution
1. 思路介绍
关于blind sr, 常规的方法是 建立各种 退化方法数据集,然后监督训练得到 退化的表示向量,
在将 退化表征向量 和 图像 一起输入网络 建立blind sr, 如下图左
本文提出的是下图右,主要区别在于 作者建立 退化强度的表征不是依赖监督训练,而是依赖度量学习(metric learning)。通过图像比较的方法建立损失函数进行训练,如右图。
2. 整体结构
主要包含3部分.
- unsupervised degradation estimation module(UDEM) 通过metric learning 的方法学习到 degradation score.
- condition network 将 degradation score 转化为向量的形式作为 退化强度的表征,方便输入 base net
- base net是一个 RRDB组成的 超分网络,中间会引入 退化强度相关的向量。
因此最终的网络可以估计退化的强度,进而更好的进行超分辨率。
2.1 unsupervised degradation estimation module(UDEM) 无监督退化估计
两个退化估计网络分别 估计 blur score 和 noise score.
code如下:
分别是conv layer
body是一系列 residual block
avg_pool 池化
fc_degree:两个linear layer
num_degradation = 2 表示两个同样的网络分别估计 blur score和 noise score
每个网络输入的是 退化的图像,输出的是一个blur score 或 noise score
self.fc_degree.append(
nn.Sequential(
nn.Linear(num_feats[-1], 512),
nn.ReLU(inplace=True),
nn.Linear(512, 1), # score
actv(),
))
def forward(self, x):
degrees = []
for i in range(self.num_degradation):
x_out = self.conv_first[i](x)
feat = self.body[i](x_out)
feat = self.avg_pool(feat)
feat = feat.squeeze(-1).squeeze(-1)
# for i in range(self.num_degradation):
degrees.append(self.fc_degree[i](feat).squeeze(-1))
return degrees
2.2 unsupervised degradation estimation module(UDEM) 无监督退化估计的损失函数
是该paper的创新点
一个 HR 图像经过blur 和 noise 两个大类的退化后,得到5个图像。
如下图,c1 比 c2 的 blur大,c3比 c2的 noise大,这样通过ranking 损失函数来建立约束,这样网络可以学习到 不同强度的 排名分数。
损失函数如下:si,sj是真实的degradation score, si_hat,sj_hat是网络输出的degradation score。这个损失函数是当 si的退化强度大于sj 时使si_hat > sj_hat, 反之亦然。
为了使 排名分数约束在一定范围,同时引入两个 ancher image. 一个使 没有退化,一个是 noise,blur都最大的退化。
并且是前者分数为0,后者分数为1.
损失函数如下:
下图,五个图像得到五组分数(每组包括blur score, noise score),通过上面建立损失函数建立一个无监督退化估计网络
其中 排名的损失函数是 margin ranking loss
class MarginRankingLoss(nn.Module):
def __init__(self, loss_weight=1.0, margin=0.0, reduction='mean'):
super(MarginRankingLoss, self).__init__()
if reduction not in ['none', 'mean', 'sum']:
raise ValueError(f'Unsupported reduction mode: {reduction}. Supported ones are: {_reduction_modes}')
self.loss_weight = loss_weight
self.margin = margin
self.reduction = reduction
def forward(self, input1, input2, target):
return self.loss_weight * F.margin_ranking_loss(
input1, input2, target, margin=self.margin, reduction=self.reduction)
2.3 condition net
退化估计网络得到两个退化分数,condition net 输入 退化分数,得到 退化的向量表征。
就是通过一个全连接网络将 退化分数转化为向量表示.
退化向量 alpha(代码中是gamma) 和 beta再与feature进行fusion
class AffineModulate(nn.Module):
def __init__(self, degradation_dim=128, num_feat=64):
super(AffineModulate, self).__init__()
degradation_dim= 512 #256 #64
self.fc = nn.Sequential(
nn.Linear(degradation_dim, (degradation_dim + num_feat * 2) // 2),
nn.ReLU(True),
nn.Linear((degradation_dim + num_feat * 2) // 2, (degradation_dim + num_feat * 2) // 2),
nn.ReLU(True),
nn.Linear((degradation_dim + num_feat * 2) // 2, num_feat * 2),
)
default_init_weights([self.fc], 0.1)
def forward(self, x, d):
d = self.fc(d)
d = d.view(d.size(0), d.size(1), 1, 1)
gamma, beta = torch.chunk(d, chunks=2, dim=1)
return
2.4 base net
关键之处在与 每个 RRDB block得到的feature(代码中是x) 通过 (1 + gamma) * x + beta
引入 退化因素
关键代码:初始化多个 RRDB block 和 AffineModulate
for _ in range(num_block):
self.body.append(RRDB(num_feat, num_grow_ch=num_grow_ch))
self.am_list.append(AffineModulate(degradation_dim=512, num_feat=num_feat))
在 forward函数中依次连接
for i in range(self.num_block):
feat = self.body[i](feat)
feat = self.am_list[i](feat, d_embedding)
2.5 损失函数
除了 无监督退化估计的 度量损失,还有 GAN损失,感知损失 per,重建损失 L1
3. 结果
几个方法的比较:
关键就是介绍了一个无监督退化估计的度量方法,如标题。
官方代码是基于basic sr 来做的