Robust Video Matting是目前基于视频抠图的最优方案。它的网络和训练方法优势在于几点:
1. 充分利用现有二值分割数据提取语义信息,结合高质量Matting数据集做到发丝级分割
2. GRU提取帧间连续特征,稳定分割效果
3. 同时支持图像和视频数据
4. 支持任意分辨率输入
还有几个小的trick,比如最后一层输出直接用conv+clamp不做激活。视频有视频的优势,基于视频的算法依赖于连续帧间信息,运用在只有单张图片的抠图上效果并不能达到最好。基于此,我们依然可以将RVM的1、3、4几个优势利用在图像Matting上。实验下来,最后再接几个全卷积层组成的SharpNet优化一下网络边缘,效果最好。网络结构代码如下
import torch
from torch import Tensor
from torch import nn
from torch.nn import functional as F
from typing import Optional, List
import segmentation_models_pytorch as smp
from .lraspp import LRASPP
from .fast_guided_filter import FastGuidedFilterRefiner
from .deep_guided_filter import DeepGuidedFilterRefiner
from .onnx_helper import CustomOnnxResizeByFactorOp
class MattingNetwork(nn.Module):
    """Single-image matting network adapted from RVM's design.

    Pipeline:
        1. ``t_net``  — a U-Net predicting a 3-channel coarse semantic map
           (trimap-like) from the RGB image.
        2. ``m_net``  — a U-Net taking the RGB image concatenated with the
           sigmoid of ``t_net``'s output (6 channels total) and predicting a
           3-channel foreground residual plus a 1-channel alpha matte.
        3. An optional guided-filter ``refiner`` upsamples the low-resolution
           prediction back to full resolution when ``downsample_ratio != 1``.
        4. An optional ``sharpnet`` — a small fully-convolutional head — that
           re-predicts alpha only in the uncertain edge band (0 < alpha <= 0.99).

    Args:
        variant: encoder backbone name passed to ``smp.Unet``.
        refiner: ``'deep_guided_filter'`` or ``'fast_guided_filter'``.
    """

    def __init__(self,
                 variant: str = 'densenet169',
                 refiner: str = 'deep_guided_filter',
                 ):
        super().__init__()
        assert refiner in ['fast_guided_filter', 'deep_guided_filter']
        decoder_channels = (256, 128, 64, 32, 16)
        # Coarse semantic branch: RGB -> 3-channel trimap-like logits.
        self.t_net = nn.Sequential(
            smp.Unet(variant, decoder_channels=decoder_channels, activation=None),
            nn.Conv2d(decoder_channels[-1], 3, 3, 1, 1),
        )
        # Matting branch: 6-channel input (RGB + sigmoid(t_net output)),
        # 4-channel output (3 fgr residual + 1 alpha). The large 7x7 head
        # with padding 4 can pad the output slightly larger than the input;
        # forward() crops back to the input size.
        self.m_net = nn.Sequential(
            smp.Unet(variant, decoder_channels=decoder_channels, in_channels=6, activation=None),
            nn.Conv2d(decoder_channels[-1], 4, 7, 1, 4),
        )
        # Edge-refinement head: input is RGB + masked alpha (4 channels),
        # output is a 1-channel alpha correction for the edge band.
        self.sharpnet = nn.Sequential(
            nn.Conv2d(4, 64, 3, 1, 1),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, 1, 1),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, 1, 1),
            nn.ReLU(),
            nn.Conv2d(64, 1, 3, 1, 1),
        )
        self.initialize_module(self.sharpnet)
        if refiner == 'deep_guided_filter':
            self.refiner = DeepGuidedFilterRefiner(decoder_channels[-1])
        else:
            self.refiner = FastGuidedFilterRefiner(decoder_channels[-1])
        self.initialize_module(self.refiner)

    def initialize_module(self, module):
        """Initialize conv layers with Kaiming-uniform (fan-in, ReLU gain),
        batch-norm layers to identity, and linear layers with Xavier-uniform.
        Biases are zeroed everywhere."""
        for m in module.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.kaiming_uniform_(m.weight, mode="fan_in", nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self,
                src: Tensor,
                fullnet: bool = False,
                sharpnet: bool = False,
                downsample_ratio: float = 1,
                segmentation_pass: bool = False):
        """Run the matting (or segmentation-only) pass.

        Args:
            src: input image batch, NCHW. The first 3 channels are RGB.
                NOTE(review): when ``fullnet`` is False and
                ``segmentation_pass`` is False, ``src`` is fed to ``m_net``
                directly and therefore must already have 6 channels
                (e.g. RGB + an externally supplied guidance map) — confirm
                against the caller.
            fullnet: if True, run ``t_net`` first and build the 6-channel
                ``m_net`` input internally from ``src``'s RGB channels.
            sharpnet: if True, re-predict alpha in the soft edge band
                (0 < alpha <= 0.99) with the SharpNet head.
            downsample_ratio: process at a reduced scale, then restore full
                resolution with the guided-filter refiner.
            segmentation_pass: if True, return only the ``t_net`` output.

        Returns:
            ``[fgr, pha, seg]`` for a matting pass (``seg`` is None when
            ``fullnet`` is False), or ``[seg]`` for a segmentation pass.
        """
        _, _, H, W = src.shape
        # Optionally downsample; ONNX export needs a custom resize op so the
        # scale factor stays dynamic in the exported graph.
        if torch.onnx.is_in_onnx_export():
            src_sm = CustomOnnxResizeByFactorOp.apply(src, downsample_ratio)
        elif downsample_ratio != 1:
            src_sm = self._interpolate(src, scale_factor=downsample_ratio)
        else:
            src_sm = src
        _, _, smH, smW = src_sm.shape
        if not segmentation_pass:
            seg = None
            if fullnet:
                seg = self.t_net(src_sm[:, :3, :, :])
                src_sm = torch.cat((src_sm[:, :3, :, :], seg.sigmoid()), 1)
            hid = self.m_net[0](src_sm)
            x = self.m_net[1](hid)
            fgr_residual, pha = x.split([3, 1], dim=-3)
            # Crop away any extra rows/cols introduced by the 7x7/pad-4 head.
            fgr_residual = fgr_residual[:, :, :smH, :smW]
            pha = pha[:, :, :smH, :smW]
            if downsample_ratio != 1:
                fgr_residual, pha = self.refiner(
                    src[:, :3, :, :], src_sm[:, :3, :, :], fgr_residual, pha, hid)
                fgr_residual = fgr_residual[:, :, :H, :W]
            pha = pha.clamp(0., 1.)
            if sharpnet:
                # Keep confident foreground (alpha > 0.99) as-is; re-predict
                # only the soft edge band, gated by the original soft alpha.
                edgepha = (pha > 0.) * (pha <= 0.99) * pha
                pha = pha * (pha > 0.99) + \
                    self.sharpnet(torch.cat((src[:, :3, :, :], edgepha), 1)) * edgepha
                pha = pha.clamp(0., 1.)
            # Foreground is predicted as a residual over the input RGB.
            fgr = fgr_residual + src[:, :3, :, :]
            fgr = fgr.clamp(0., 1.)
            return [fgr, pha, seg]
        else:
            seg = self.t_net(src_sm)
            # NOTE(review): when downsample_ratio != 1 this crop does not
            # upsample seg back to (H, W); callers appear to resize
            # externally — confirm.
            seg = seg[:, :, :H, :W]
            return [seg]

    def _interpolate(self, x: Tensor, scale_factor: float):
        """Bilinear resize by ``scale_factor`` (no corner alignment)."""
        x = F.interpolate(x, scale_factor=scale_factor,
                          mode='bilinear', align_corners=False,
                          recompute_scale_factor=False)
        return x
效果: