改Robust Video Matting为Robust Image Matting

Robust Video Matting是目前基于视频抠图的最优方案。它的网络和训练方法优势在于几点:

1. 充分利用现有二值分割数据提取语义信息,结合高质量Matting数据集做到发丝级分割

2. GRU提取帧间连续特征,稳定分割效果

3. 同时支持图像和视频数据

4. 支持任意分辨率输入

还有几个小的trick,比如最后一层输出直接用conv+clamp,不做激活。视频有视频的优势:基于视频的算法依赖连续帧间信息,运用在只有单张图片的抠图场景时效果并不能达到最好。基于此,我们依然可以将RVM的1、3、4几个优势利用在图像Matting上。实验下来,最后再接几个全卷积层组成的SharpNet优化网络边缘,效果最好。网络结构代码如下

import torch
from torch import Tensor
from torch import nn
from torch.nn import functional as F
from typing import Optional, List

import segmentation_models_pytorch as smp
from .lraspp import LRASPP
from .fast_guided_filter import FastGuidedFilterRefiner
from .deep_guided_filter import DeepGuidedFilterRefiner

from .onnx_helper import CustomOnnxResizeByFactorOp

class MattingNetwork(nn.Module):
    """Single-image matting network adapted from Robust Video Matting (RVM).

    Pipeline (see :meth:`forward`):
      1. ``t_net``   - U-Net predicting a coarse 3-channel trimap/segmentation.
      2. ``m_net``   - U-Net taking RGB + sigmoid(trimap) (6 input channels),
         predicting a 3-channel foreground residual and a 1-channel alpha.
      3. ``refiner`` - guided-filter refiner, applied only when the input was
         downsampled, to recover full-resolution detail.
      4. ``sharpnet`` - small fully-convolutional head that re-estimates alpha
         inside the uncertain edge band (0 < alpha <= 0.99).

    Args:
        variant: encoder backbone name forwarded to ``smp.Unet``.
        refiner: ``'fast_guided_filter'`` or ``'deep_guided_filter'``.
    """

    def __init__(self,
                 variant: str = 'densenet169',
                 refiner: str = 'deep_guided_filter',
                 ):
        super().__init__()
        assert refiner in ['fast_guided_filter', 'deep_guided_filter']

        decoder_channels = (256, 128, 64, 32, 16)

        # Trimap/segmentation branch: RGB in, 3-channel coarse map out.
        self.t_net = nn.Sequential(
            smp.Unet(variant, decoder_channels=decoder_channels, activation=None),
            nn.Conv2d(decoder_channels[-1], 3, 3, 1, 1)
            )

        # Matting branch: RGB + trimap (6 ch) in; fgr residual (3) + alpha (1) out.
        # NOTE(review): kernel 7 with padding 4 ("same" would be padding 3)
        # enlarges each spatial dim by 2; forward() crops back to (smH, smW),
        # which shifts receptive-field centers by one pixel. Trained weights
        # may depend on this exact geometry — confirm before changing to
        # padding=3.
        self.m_net = nn.Sequential(
            smp.Unet(variant, decoder_channels=decoder_channels, in_channels=6, activation=None),
            nn.Conv2d(decoder_channels[-1], 4, 7, 1, 4)
            )

        # Edge-sharpening head: RGB + edge-band alpha (4 ch) -> refined alpha.
        # Final layer is a bare conv (no activation); forward() clamps instead.
        self.sharpnet = nn.Sequential(
            nn.Conv2d(4, 64, 3, 1, 1),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, 1, 1),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, 1, 1),
            nn.ReLU(),
            nn.Conv2d(64, 1, 3, 1, 1)
            )
        self.initialize_module(self.sharpnet)

        if refiner == 'deep_guided_filter':
            self.refiner = DeepGuidedFilterRefiner(decoder_channels[-1])
        else:
            self.refiner = FastGuidedFilterRefiner(decoder_channels[-1])

        self.initialize_module(self.refiner)

    def initialize_module(self, module):
        """Initialize *module* in place.

        Conv / deconv weights get Kaiming-uniform (fan_in, relu), BatchNorm
        gets (weight=1, bias=0), Linear gets Xavier-uniform; all biases zero.
        """
        for m in module.modules():

            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.kaiming_uniform_(m.weight, mode="fan_in", nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

            elif isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self,
                src: Tensor,
                fullnet: bool = False,
                sharpnet: bool = False,
                downsample_ratio: float = 1,
                segmentation_pass: bool = False):
        """Run matting (or pure segmentation) on a batch of images.

        Args:
            src: input batch, NCHW; first 3 channels are RGB in [0, 1]
                 (presumably — TODO confirm against the training pipeline).
            fullnet: if True, run ``t_net`` first and feed its sigmoid output
                 as the trimap channels of ``m_net``.
            sharpnet: if True, re-estimate alpha in the soft edge band with
                 ``sharpnet``.
            downsample_ratio: scale factor applied before the networks;
                 when != 1 the ``refiner`` upsamples the result back.
            segmentation_pass: if True, skip matting and return only the
                 ``t_net`` output.

        Returns:
            ``[fgr, pha, seg]`` for a matting pass (``seg`` is None unless
            ``fullnet``), or ``[seg]`` for a segmentation pass.
        """
        _, _, H, W = src.shape

        # Optional downsampling; ONNX export goes through a custom op so the
        # resize-by-factor survives tracing.
        if torch.onnx.is_in_onnx_export():
            src_sm = CustomOnnxResizeByFactorOp.apply(src, downsample_ratio)
        elif downsample_ratio != 1:
            src_sm = self._interpolate(src, scale_factor=downsample_ratio)
        else:
            src_sm = src

        _, _, smH, smW = src_sm.shape

        if not segmentation_pass:
            seg = None

            if fullnet:
                # Predict trimap from RGB and append it (sigmoid-squashed)
                # as the extra 3 input channels for m_net.
                seg = self.t_net(src_sm[:, :3, :, :])
                src_sm = torch.cat((src_sm[:, :3, :, :], seg.sigmoid()), 1)

            # Run the two sub-modules separately so the decoder features
            # (hid) are available to the refiner.
            hid = self.m_net[0](src_sm)
            x = self.m_net[1](hid)

            fgr_residual, pha = x.split([3, 1], dim=-3)

            # Crop away the extra rows/cols produced by the padding-4 conv.
            fgr_residual = fgr_residual[:, :, :smH, :smW]
            pha = pha[:, :, :smH, :smW]

            if downsample_ratio != 1:
                fgr_residual, pha = self.refiner(src[:, :3, :, :], src_sm[:, :3, :, :], fgr_residual, pha, hid)
                fgr_residual = fgr_residual[:, :, :H, :W]

            pha = pha.clamp(0., 1.)

            if sharpnet:
                # Alpha in (0, 0.99] is the uncertain edge band; keep the
                # confident interior (> 0.99) and let sharpnet re-weight the
                # band, modulated by the original soft alpha.
                edgepha = (pha > 0.) * (pha <= 0.99) * pha

                pha = pha * (pha > 0.99) + \
                    self.sharpnet(torch.cat((src[:, :3, :, :], edgepha), 1)) * edgepha
                pha = pha.clamp(0., 1.)

            # m_net predicts a residual on top of the source RGB.
            fgr = fgr_residual + src[:, :3, :, :]
            fgr = fgr.clamp(0., 1.)
            return [fgr, pha, seg]
        else:
            seg = self.t_net(src_sm)
            seg = seg[:, :, :H, :W]
            return [seg]

    def _interpolate(self, x: Tensor, scale_factor: float):
        """Bilinear resize of NCHW tensor *x* by *scale_factor*."""
        x = F.interpolate(x, scale_factor=scale_factor,
            mode='bilinear', align_corners=False, recompute_scale_factor=False)
        return x

效果:

  • 2
    点赞
  • 16
    收藏
    觉得还不错? 一键收藏
  • 13
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 13
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值