基于MODnet无绿幕抠图

本文介绍了MODNet,一个由香港城市大学和商汤科技联合提出的轻量级、实时、无需额外背景输入的高精度单模型抠图算法。该模型在保持高准确度的同时,具有强大的泛化能力。通过ONNX推理代码实现抠图,并展示了丝发级别的抠图效果,对于清晰图片的抠图表现优异,未来可扩展到视频抠图应用。
摘要由CSDN通过智能技术生成

0.前言

MODNet由香港城市大学和商汤科技于2020年11月首次提出,用于实时抠图任务

MODNet特性:

  • 轻量级(light-weight )
  • 实时性高(real-time)
  • 预测时不需要额外的背景输入(trimap-free)
  • 准确度高(hight performance)
  • 单模型(single model instead of a complex pipeline)
  • 泛化能力强(better generalization ability)

论文地址 : https://arxiv.org/pdf/2011.11961.pdf
git地址: https://github.com/ZHKKKe/MODNet

1.复现代码

基于onnx推理代码

官方给出了基于torch和onnx推理代码,这里用的是关于onnx模型的推理代码.


import os
import cv2
import argparse
import numpy as np
from PIL import Image

import onnx
import onnxruntime


if __name__ == '__main__':
    # define cmd arguments
    parser = argparse.ArgumentParser()
    parser.add_argument('--image-path', default= 'test.jpeg',type=str, help='path of the input image (a file)')
    parser.add_argument('--output-path',default= 'result.png', type=str, help='paht for saving the predicted alpha matte (a file)')
    parser.add_argument('--model-path', default='hrnet.onnx', type=str, help='path of the ONNX model')
    args = parser.parse_args()

    # check input arguments
    if not os.path.exists(args.image_path):
        print('Cannot find the input image: {0}'.format(args.image_path))
        exit()
    if not os.path.exists(args.model_path):
        print('Cannot find the ONXX model: {0}'.format(args.model_path))
        exit()

    ref_size = 512

    # Get x_scale_factor & y_scale_factor to resize image
    def get_scale_factor(im_h, im_w, ref_size):

        if max(im_h, im_w) < ref_size or min(im_h, im_w) > ref_size:
            if im_w >= im_h:
                im_rh = ref_size
                im_rw = int(im_w / im_h * ref_size)
            elif im_w < im_h:
                im_rw = ref_size
                im_rh = int(im_h / im_w * ref_size)
        else:
            im_rh = im_h
            im_rw = im_w

        im_rw = im_rw - im_rw % 32
        im_rh = im_rh - im_rh % 32

        x_scale_factor = im_rw / im_w
        y_scale_factor = im_rh / im_h

        return x_scale_factor, y_scale_factor

    ##############################################
    #  Main Inference part
    ##############################################

    # read image
    im = cv2.imread(args.image_path)
    img = im.copy()
    im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)

    # unify image channels to 3
    if len(im.shape) == 2:
        im = im[:, :, None]
    if im.shape[2] == 1:
        im = np.repeat(im, 3, axis=2)
    elif im.shape[2] == 4:
        im = im[:, :, 0:3]

    # normalize values to scale it between -1 to 1
    im = (im - 127.5) / 127.5   

    im_h, im_w, im_c = im.shape
    x, y = get_scale_factor(im_h, im_w, ref_size) 

    # resize image
    im = cv2.resize(im, None, fx = x, fy = y, interpolation = cv2.INTER_AREA)

    # prepare input shape
    im = np.transpose(im)
    im = np.swapaxes(im, 1, 2)
    im = np.expand_dims(im, axis = 0).astype('float32')

    # Initialize session and get prediction
    session = onnxruntime.InferenceSession(args.model_path, None)
    input_name = session.get_inputs()[0].name
    output_name = session.get_outputs()[0].name
    result = session.run([output_name], {input_name: im})

    # refine matte
    matte = (np.squeeze(result[0]) * 255).astype('uint8')
    matte = cv2.resize(matte, dsize=(im_w, im_h), interpolation = cv2.INTER_AREA)

    cv2.imwrite(args.output_path, matte)


    # 保存彩色图片
    # b,g,r = cv2.split(img)
    # rbga_img = cv2.merge((b, g, r, matte))
    rbga_img = cv2.merge((img, matte))
    cv2.imwrite('rbga_result.png',rbga_img)


代码比较简单,给出的是以长边512等比例缩放,最后我添加了一下保存成RGBA的彩色图片.

2.抠图效果

测试图片

在这里插入图片描述

测试结果

在这里插入图片描述
可以发现抠图已经达到了丝发级别,对于清晰的图片抠图还是很准确的.

后期可以补充一下对视频的抠图.

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

五小白

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值