EDVR 算法在 CVPR 2019 Workshop NTIRE 2019 视频恢复比赛的四个赛道中均获得冠军,且大幅领先其他选手。相关链接:paper、code
算法框架:
PreDeblur module:
class Predeblur_ResNet_Pyramid(nn.Module):
    """Pre-deblur feature extractor organised as a three-level pyramid.

    Features are computed at three scales (each coarser level is produced
    from the previous one by a stride-2 convolution), refined by residual
    blocks, and merged coarse-to-fine through 2x bilinear upsampling followed
    by element-wise addition.

    NOTE(review): ``arch_util.ResidualBlock_noBN`` is defined outside this
    section. It is assumed to return a tensor with the same shape as its
    input, which the additions in ``forward`` rely on -- confirm.
    """

    def __init__(self, nf=128, HR_in=False):
        """Create all layers.

        Args:
            nf: Number of feature channels used by every layer.
            HR_in: Truthy if the inputs have high spatial size. The stem
                then uses three convolutions, two of them with stride 2,
                instead of a single stride-1 convolution.
        """
        super(Predeblur_ResNet_Pyramid, self).__init__()
        self.HR_in = bool(HR_in)

        def residual_block():
            return arch_util.ResidualBlock_noBN(nf=nf)

        # Stem: maps the 3-channel input to nf channels.
        if self.HR_in:
            self.conv_first_1 = nn.Conv2d(3, nf, 3, 1, 1, bias=True)
            self.conv_first_2 = nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
            self.conv_first_3 = nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
        else:
            self.conv_first = nn.Conv2d(3, nf, 3, 1, 1, bias=True)

        # Residual blocks: five on level 1, two on level 2, one on level 3.
        # Attribute names and creation order are kept stable because they
        # determine the state_dict keys and the parameter ordering.
        self.RB_L1_1 = residual_block()
        self.RB_L1_2 = residual_block()
        self.RB_L1_3 = residual_block()
        self.RB_L1_4 = residual_block()
        self.RB_L1_5 = residual_block()
        self.RB_L2_1 = residual_block()
        self.RB_L2_2 = residual_block()
        self.RB_L3_1 = residual_block()

        # Stride-2 convolutions producing level 2 from level 1 and level 3
        # from level 2.
        self.deblur_L2_conv = nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
        self.deblur_L3_conv = nn.Conv2d(nf, nf, 3, 2, 1, bias=True)

        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

    def forward(self, x):
        """Extract pre-deblurred level-1 features from ``x``.

        Args:
            x: Input tensor with 3 channels (the stem convolutions take
                3 input channels).

        Returns:
            The output of the last level-1 residual block.
        """
        act = self.lrelu
        upsample_x2 = functools.partial(
            F.interpolate, scale_factor=2, mode='bilinear', align_corners=False)

        # Stem.
        if self.HR_in:
            feat_l1 = x
            for stem_conv in (self.conv_first_1, self.conv_first_2, self.conv_first_3):
                feat_l1 = act(stem_conv(feat_l1))
        else:
            feat_l1 = act(self.conv_first(x))

        # Build the two coarser pyramid levels.
        feat_l2 = act(self.deblur_L2_conv(feat_l1))
        feat_l3 = act(self.deblur_L3_conv(feat_l2))

        # Coarse-to-fine merge: refine, upsample by 2, add to the finer level.
        feat_l3 = upsample_x2(self.RB_L3_1(feat_l3))
        feat_l2 = self.RB_L2_1(feat_l2) + feat_l3
        feat_l2 = upsample_x2(self.RB_L2_2(feat_l2))
        feat_l1 = self.RB_L1_2(self.RB_L1_1(feat_l1)) + feat_l2

        # Final refinement on level 1.
        for block in (self.RB_L1_3, self.RB_L1_4, self.RB_L1_5):
            feat_l1 = block(feat_l1)
        return feat_l1
对齐模块:
PCD代码:
class PCD_Align(nn.Module):
    """Alignment module using Pyramid, Cascading and Deformable convolution.

    Works on three pyramid levels. Offset features are estimated at the
    coarsest level first and propagated to the finer levels (upsampled by 2
    and multiplied by 2); a final cascading deformable convolution refines
    the level-1 result.

    NOTE(review): ``DCN`` is defined outside this section. From its use here
    it is called with a two-element list ``[features, offset_features]`` and
    its result is treated as a feature map with ``nf`` channels -- confirm
    against its definition.
    """

    def __init__(self, nf=64, groups=8):
        """Create all layers.

        Args:
            nf: Number of feature channels.
            groups: Passed to every ``DCN`` as ``deformable_groups``.
        """
        super(PCD_Align, self).__init__()

        def conv3x3(in_channels):
            return nn.Conv2d(in_channels, nf, 3, 1, 1, bias=True)

        def deform_conv():
            return DCN(nf, nf, 3, stride=1, padding=1, dilation=1,
                       deformable_groups=groups, extra_offset_mask=True)

        # Attribute names and creation order are kept stable because they
        # determine the state_dict keys and the parameter ordering.

        # L3: level 3, 1/4 spatial size
        self.L3_offset_conv1 = conv3x3(nf * 2)  # input: concat(neighbor, reference)
        self.L3_offset_conv2 = conv3x3(nf)
        self.L3_dcnpack = deform_conv()

        # L2: level 2, 1/2 spatial size
        self.L2_offset_conv1 = conv3x3(nf * 2)  # input: concat(neighbor, reference)
        self.L2_offset_conv2 = conv3x3(nf * 2)  # input: concat(L2 offset, upsampled L3 offset)
        self.L2_offset_conv3 = conv3x3(nf)
        self.L2_dcnpack = deform_conv()
        self.L2_fea_conv = conv3x3(nf * 2)  # input: concat(L2 features, upsampled L3 features)

        # L1: level 1, original spatial size
        self.L1_offset_conv1 = conv3x3(nf * 2)  # input: concat(neighbor, reference)
        self.L1_offset_conv2 = conv3x3(nf * 2)  # input: concat(L1 offset, upsampled L2 offset)
        self.L1_offset_conv3 = conv3x3(nf)
        self.L1_dcnpack = deform_conv()
        self.L1_fea_conv = conv3x3(nf * 2)  # input: concat(L1 features, upsampled L2 features)

        # Cascading DCN
        self.cas_offset_conv1 = conv3x3(nf * 2)  # input: concat(aligned L1, reference L1)
        self.cas_offset_conv2 = conv3x3(nf)
        self.cas_dcnpack = deform_conv()

        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

    @staticmethod
    def _upsample_x2(tensor):
        """Bilinearly upsample ``tensor`` by a factor of 2."""
        return F.interpolate(tensor, scale_factor=2, mode='bilinear', align_corners=False)

    def forward(self, nbr_fea_l, ref_fea_l):
        """Align neighboring-frame features to the reference frame.

        Args:
            nbr_fea_l: Neighboring-frame features as ``[L1, L2, L3]``, each
                a ``[B, C, H, W]`` tensor.
            ref_fea_l: Reference-frame features in the same layout.

        Returns:
            The aligned level-1 features.
        """
        act = self.lrelu
        up = self._upsample_x2

        # ---- Level 3 (coarsest) ----
        off_l3 = torch.cat([nbr_fea_l[2], ref_fea_l[2]], dim=1)
        off_l3 = act(self.L3_offset_conv1(off_l3))
        off_l3 = act(self.L3_offset_conv2(off_l3))
        fea_l3 = act(self.L3_dcnpack([nbr_fea_l[2], off_l3]))

        # ---- Level 2 ----
        off_l2 = torch.cat([nbr_fea_l[1], ref_fea_l[1]], dim=1)
        off_l2 = act(self.L2_offset_conv1(off_l2))
        off_l3 = up(off_l3)
        # The coarser offset features are doubled after the 2x upsampling.
        off_l2 = act(self.L2_offset_conv2(torch.cat([off_l2, off_l3 * 2], dim=1)))
        off_l2 = act(self.L2_offset_conv3(off_l2))
        fea_l2 = self.L2_dcnpack([nbr_fea_l[1], off_l2])
        fea_l3 = up(fea_l3)
        fea_l2 = act(self.L2_fea_conv(torch.cat([fea_l2, fea_l3], dim=1)))

        # ---- Level 1 (finest) ----
        off_l1 = torch.cat([nbr_fea_l[0], ref_fea_l[0]], dim=1)
        off_l1 = act(self.L1_offset_conv1(off_l1))
        off_l2 = up(off_l2)
        off_l1 = act(self.L1_offset_conv2(torch.cat([off_l1, off_l2 * 2], dim=1)))
        off_l1 = act(self.L1_offset_conv3(off_l1))
        fea_l1 = self.L1_dcnpack([nbr_fea_l[0], off_l1])
        fea_l2 = up(fea_l2)
        # No activation here: the cascading stage consumes the raw conv output.
        fea_l1 = self.L1_fea_conv(torch.cat([fea_l1, fea_l2], dim=1))

        # ---- Cascading refinement against the level-1 reference ----
        cas_off = torch.cat([fea_l1, ref_fea_l[0]], dim=1)
        cas_off = act(self.cas_offset_conv1(cas_off))
        cas_off = act(self.cas_offset_conv2(cas_off))
        return act(self.cas_dcnpack([fea_l1, cas_off]))
时空注意力模块:
TSA代码:
class TSA_Fusion(nn.Module):
    """Temporal Spatial Attention (TSA) fusion module.

    Temporal attention: every frame is weighted per pixel by the sigmoid of
    its embedding correlation (channel-wise dot product) with the reference
    frame.
    Spatial attention: a 3-level pyramid built with max/avg pooling produces
    a multiplicative mask and an additive term applied to the fused features.

    Args:
        nf: Number of feature channels per frame.
        nframes: Number of frames stacked along dim 1 of the input.
        center: Index of the reference frame along dim 1.
    """

    def __init__(self, nf=64, nframes=5, center=2):
        super(TSA_Fusion, self).__init__()
        self.center = center

        # Attribute names and creation order are kept stable because they
        # determine the state_dict keys and the parameter ordering.

        # temporal attention (before fusion conv)
        self.tAtt_1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.tAtt_2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)

        # fusion conv: using 1x1 to save parameters and computation
        self.fea_fusion = nn.Conv2d(nframes * nf, nf, 1, 1, bias=True)

        # spatial attention (after fusion conv)
        self.sAtt_1 = nn.Conv2d(nframes * nf, nf, 1, 1, bias=True)
        self.maxpool = nn.MaxPool2d(3, stride=2, padding=1)
        self.avgpool = nn.AvgPool2d(3, stride=2, padding=1)
        self.sAtt_2 = nn.Conv2d(nf * 2, nf, 1, 1, bias=True)
        self.sAtt_3 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.sAtt_4 = nn.Conv2d(nf, nf, 1, 1, bias=True)
        self.sAtt_5 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.sAtt_L1 = nn.Conv2d(nf, nf, 1, 1, bias=True)
        self.sAtt_L2 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)
        self.sAtt_L3 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.sAtt_add_1 = nn.Conv2d(nf, nf, 1, 1, bias=True)
        self.sAtt_add_2 = nn.Conv2d(nf, nf, 1, 1, bias=True)

        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

    def forward(self, aligned_fea):
        """Fuse aligned per-frame features into a single feature map.

        Args:
            aligned_fea: Tensor of shape ``[B, N, C, H, W]`` holding N
                aligned frames. It may be non-contiguous, and H and W may be
                any positive size (they need not be multiples of 4).

        Returns:
            Fused features of shape ``[B, nf, H, W]``.
        """
        B, N, C, H, W = aligned_fea.size()  # N video frames

        #### temporal attention
        emb_ref = self.tAtt_2(aligned_fea[:, self.center, :, :, :].clone())
        # reshape (not view) so non-contiguous inputs are accepted as well.
        emb = self.tAtt_1(aligned_fea.reshape(-1, C, H, W)).reshape(B, N, -1, H, W)
        # Per-frame correlation with the reference embedding, computed frame
        # by frame to avoid a [B, N, C, H, W] temporary.
        cor_prob = torch.sigmoid(torch.stack(
            [torch.sum(emb[:, i] * emb_ref, dim=1) for i in range(N)], dim=1))  # B, N, H, W
        # Broadcast the per-frame weight over channels instead of
        # materialising a repeated [B, N*C, H, W] copy; values are identical.
        aligned_fea = (aligned_fea * cor_prob.unsqueeze(2)).reshape(B, -1, H, W)

        #### fusion
        fea = self.lrelu(self.fea_fusion(aligned_fea))

        #### spatial attention
        att = self.lrelu(self.sAtt_1(aligned_fea))
        att = self.lrelu(self.sAtt_2(
            torch.cat([self.maxpool(att), self.avgpool(att)], dim=1)))

        # pyramid levels
        att_L = self.lrelu(self.sAtt_L1(att))
        att_L = self.lrelu(self.sAtt_L2(
            torch.cat([self.maxpool(att_L), self.avgpool(att_L)], dim=1)))
        att_L = self.lrelu(self.sAtt_L3(att_L))
        # Upsample to the exact size of the finer level. The stride-2 pools
        # round odd sizes up, so a fixed 2x factor would overshoot for sizes
        # that are not multiples of 4; for multiples of 4 this is the same
        # 2x bilinear upsampling.
        att_L = F.interpolate(att_L, size=att.shape[-2:], mode='bilinear',
                              align_corners=False)

        att = self.lrelu(self.sAtt_3(att))
        att = att + att_L
        att = self.lrelu(self.sAtt_4(att))
        att = F.interpolate(att, size=(H, W), mode='bilinear', align_corners=False)
        att = self.sAtt_5(att)
        att_add = self.sAtt_add_2(self.lrelu(self.sAtt_add_1(att)))
        att = torch.sigmoid(att)

        # The sigmoid mask lies in (0, 1); the factor 2 rescales it to (0, 2).
        fea = fea * att * 2 + att_add
        return fea
参考文献:
- https://zhuanlan.zhihu.com/p/86001304
- https://blog.csdn.net/liyu0611/article/details/90404328
- https://www.tfzx.net/index.php/article/662917.html
- http://www.tuan18.org/thread-19483-1-1.html