使用 GroundingDINO + SAM 实现商品抠图

在电商场景中,常有主商品审核、商品图背景替换等需求,这就需要对商家提供的商品图进行抠图、替换背景等操作。

当然,已有不少第三方工具可供运营和产品同学使用,但它们都至少需要以下几个步骤:上传图片、选择抠图区域或边缘、等待抠图生成、下载抠图结果。

且先不说商品图片可能存在泄露的风险,单是抠图效果就很难保证。比如下面这张商品图,我们希望抠出手机这个主商品,去掉背景色和背景物。

期望得到的理想效果如下(这也是用我们的算法抠图后得到的效果图):

再比如抠图软件 remove.bg,它的抠图效果如下:基本不能准确识别出商品,而且毛边比较严重,需要人工再点击“恢复”来修正。这样一来,一张图片的抠图成本反而不降反增。

我们做这个项目的初衷是为运营和产品同学降本提效,所以对算法抠图效果的评价比较严格:不仅要看商品识别的准确率,还要求抠图后的毛边、分辨率等不能太差。具体来说,要求商品识别率、抠图准确率达到 90% 以上,运营人工抽检后的返工率不超过 10%。这样才能大幅提高运营的人工效率、降低抠图的人工成本,真正达到降本增效的目的。

这个项目主要用到的是 GroundingDINO + SAM(Segment Anything):

https://github.com/IDEA-Research/Grounded-Segment-Anything

https://github.com/facebookresearch/segment-anything

实现步骤:

1. 加载依赖包并定义基础方法
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time: 2023/10/27 11:17 上午
# @Author: gao
# @Project : image_tagging
# @File: demo_sam

"""
抠图, 从EXCEL中获取商品URL和存放路径
"""
import pandas as pd
import numpy as np
import random
import os
import sys
import torch
from modelscope.utils.constant import Tasks
from modelscope.pipelines import pipeline
from modelscope.preprocessors.image import load_image

sys.path.append('/root/git/Grounded-Segment-Anything')
import argparse
import copy
import torch
# print(torch.__version__)
import urllib.request
from IPython.display import display
from PIL import Image, ImageDraw, ImageFont
from torchvision.ops import box_convert

# Grounding DINO
import GroundingDINO.groundingdino.datasets.transforms as T
from GroundingDINO.groundingdino.models import build_model
from GroundingDINO.groundingdino.util import box_ops
from GroundingDINO.groundingdino.util.slconfig import SLConfig
from GroundingDINO.groundingdino.util.utils import clean_state_dict, \
    get_phrases_from_posmap
from GroundingDINO.groundingdino.util.inference import annotate, load_image, \
    predict

import supervision as sv

# segment anything
from segment_anything import build_sam, SamPredictor
import cv2
import matplotlib.pyplot as plt

# diffusers
import PIL
import requests
import torch
from io import BytesIO
from diffusers import StableDiffusionInpaintPipeline

from huggingface_hub import hf_hub_download


def show_mask(mask, ax, random_color=False):
    """Draw ``mask`` on the matplotlib axes ``ax`` as a translucent color layer.

    Args:
        mask: Array whose last two dimensions are (height, width); it is
            reshaped to ``(h, w, 1)`` and multiplied by an RGBA color.
        ax: Object exposing ``imshow`` (normally a matplotlib ``Axes``).
        random_color: If true, use a random RGB color with alpha 0.6 (drawn
            from numpy's global RNG); otherwise use a fixed blue at alpha 0.6.
    """
    rgba = (np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
            if random_color
            else np.array([30 / 255, 144 / 255, 255 / 255, 0.6]))
    rows, cols = mask.shape[-2:]
    # Broadcast (rows, cols, 1) * (1, 1, 4) -> an RGBA overlay image.
    overlay = mask.reshape(rows, cols, 1) * rgba.reshape(1, 1, -1)
    ax.imshow(overlay)


def show_points(coords, labels, ax, marker_size=375):
    """Scatter prompt points on ``ax``: label 1 as green stars, label 0 as red.

    Args:
        coords: ``(N, 2)`` array of x/y point coordinates.
        labels: Length-``N`` array; points with label 1 are positive, 0 negative.
        ax: Object exposing ``scatter`` (normally a matplotlib ``Axes``).
        marker_size: Marker area passed to ``scatter`` as ``s``.
    """
    marker_style = dict(marker='*', s=marker_size, edgecolor='white',
                        linewidth=1.25)
    # Positive points first, then negative ones, matching the draw order.
    for label_value, colour in ((1, 'green'), (0, 'red')):
        selected = coords[labels == label_value]
        ax.scatter(selected[:, 0], selected[:, 1], color=colour, **marker_style)


def show_box(box, ax):
    """Outline ``box`` on ``ax`` with an unfilled green rectangle.

    Args:
        box: Indexable whose first four entries are the corner coordinates
            ``(x0, y0, x1, y1)``; any further entries are ignored.
        ax: Object exposing ``add_patch`` (normally a matplotlib ``Axes``).
    """
    x_min, y_min, x_max, y_max = box[0], box[1], box[2], box[3]
    outline = plt.Rectangle((x_min, y_min), x_max - x_min, y_max - y_min,
                            edgecolor='green', facecolor=(0, 0, 0, 0), lw=2)
    ax.add_patch(outline)


def load_model_hf(repo_id, filename, ckpt_config_filename, device='cpu'):
    """Build a GroundingDINO model from a local config and checkpoint.

    Despite the name, nothing is downloaded from the Hugging Face Hub: both
    the config and the checkpoint are read from the given local paths.

    Args:
        repo_id: Currently unused; kept so existing call sites keep working.
        filename: Path of the checkpoint file; its ``'model'`` entry holds
            the state dict.
        ckpt_config_filename: Path of the GroundingDINO ``SLConfig`` file.
        device: Device used for the model config and ``torch.load``'s
            ``map_location``.

    Returns:
        The model, switched to eval mode.
    """
    config = SLConfig.fromfile(ckpt_config_filename)
    config.device = device
    net = build_model(config)

    state = torch.load(filename, map_location=device)
    # clean_state_dict presumably strips wrapper prefixes from key names
    # (TODO confirm); strict=False tolerates missing/unexpected keys, which
    # load_state_dict reports in its return value, printed below.
    load_report = net.load_state_dict(clean_state_dict(state['model']),
                                      strict=False)
    print("Model loaded from {} \n => {}".format(filename, load_report))
    net.eval()
    return net


def detect(image, text_prompt, model, box_threshold=0.3, text_threshold=0.25,
           image_source=None):
    """Run GroundingDINO text-prompted detection and draw the detections.

    Args:
        image: The transformed image produced by GroundingDINO's ``load_image``
            (its second return value), fed to ``predict``.
        text_prompt: Detection caption, e.g. ``"cell phone.item.object.entity"``.
        model: A loaded GroundingDINO model (see ``load_model_hf``).
        box_threshold: Box confidence threshold passed to ``predict``.
        text_threshold: Text confidence threshold passed to ``predict``.
        image_source: The original image array (first return value of
            ``load_image``) on which the detections are drawn. When omitted,
            a module-level ``image_source`` variable is used instead, which is
            what earlier versions of this function silently depended on.

    Returns:
        ``(annotated_frame, boxes)``: the annotated image with its channel
        order reversed (``annotate`` output is treated as BGR and turned into
        RGB), and the boxes returned by ``predict`` (normalized cx, cy, w, h).

    Raises:
        ValueError: If no ``image_source`` is passed and none exists at
            module level.
    """
    if image_source is None:
        # Backward compatibility with the original implicit global lookup.
        image_source = globals().get("image_source")
    if image_source is None:
        # Fail before the (expensive) model forward pass rather than with a
        # NameError afterwards.
        raise ValueError("detect() needs the original image via image_source "
                         "to draw the annotations")

    boxes, logits, phrases = predict(
        model=model,
        image=image,
        caption=text_prompt,
        box_threshold=box_threshold,
        text_threshold=text_threshold
    )

    annotated_frame = annotate(image_source=image_source, boxes=boxes,
                               logits=logits, phrases=phrases)
    annotated_frame = annotated_frame[..., ::-1]  # BGR to RGB
    return annotated_frame, boxes


def segment(image, sam_model, boxes):
    """Segment the detected objects with SAM, using the detection boxes as prompts.

    Args:
        image: The original image array of shape ``(H, W, ...)`` (the
            ``image_source`` returned by GroundingDINO's ``load_image``); it is
            passed unchanged to ``sam_model.set_image``.
        sam_model: A ``SamPredictor``-like object exposing ``set_image``,
            ``transform.apply_boxes_torch``, ``predict_torch`` and ``device``.
        boxes: ``(N, 4)`` boxes in normalized ``(cx, cy, w, h)`` format, as
            returned by ``detect``.

    Returns:
        The masks from ``predict_torch`` moved to the CPU, one per box
        (``multimask_output=False``). When ``boxes`` is empty, SAM is not run
        and an empty boolean tensor of shape ``(0, 1, H, W)`` is returned.
    """
    H, W = image.shape[:2]
    if len(boxes) == 0:
        # Nothing detected: skip the expensive SAM pass and return an empty
        # mask stack so callers can iterate over it uniformly.
        return torch.zeros((0, 1, H, W), dtype=torch.bool)

    sam_model.set_image(image)

    # Normalized (cx, cy, w, h) -> absolute-pixel (x0, y0, x1, y1); same
    # arithmetic as GroundingDINO's box_ops.box_cxcywh_to_xyxy.
    boxes = torch.as_tensor(boxes, dtype=torch.float32)
    cx, cy, bw, bh = boxes.unbind(-1)
    boxes_xyxy = torch.stack(
        (cx - 0.5 * bw, cy - 0.5 * bh, cx + 0.5 * bw, cy + 0.5 * bh), dim=-1)
    boxes_xyxy = boxes_xyxy * torch.tensor(
        [W, H, W, H], dtype=boxes_xyxy.dtype, device=boxes_xyxy.device)

    # Use the predictor's own device instead of a script-level global
    # ``device`` (which raised NameError when this module was imported).
    transformed_boxes = sam_model.transform.apply_boxes_torch(
        boxes_xyxy.to(sam_model.device), image.shape[:2])
    masks, _, _ = sam_model.predict_torch(
        point_coords=None,
        point_labels=None,
        boxes=transformed_boxes,
        multimask_output=False,
    )
    return masks.cpu()


def draw_mask(mask, image, random_color=True):
    """Alpha-composite a colored ``mask`` over ``image`` and return the result.

    Args:
        mask: Tensor whose last two dimensions are (height, width); it must
            support ``.cpu()`` (e.g. a torch tensor) because the overlay is
            converted through it.
        image: Array accepted by ``PIL.Image.fromarray``; it is converted to
            RGBA and used as the background layer.
        random_color: If true, use a random RGB color with alpha 0.8 (drawn
            from numpy's global RNG); otherwise a fixed blue at alpha 0.6.

    Returns:
        The composited RGBA image as a numpy array.
    """
    if not random_color:
        rgba = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    else:
        rgba = np.concatenate([np.random.random(3), np.array([0.8])], axis=0)
    height, width = mask.shape[-2:]
    # (height, width, 1) * (1, 1, 4) -> per-pixel RGBA overlay in [0, 1].
    overlay = mask.reshape(height, width, 1) * rgba.reshape(1, 1, -1)
    overlay_u8 = (overlay.cpu().numpy() * 255).astype(np.uint8)

    base_layer = Image.fromarray(image).convert("RGBA")
    top_layer = Image.fromarray(overlay_u8).convert("RGBA")
    return np.array(Image.alpha_composite(base_layer, top_layer))
2. 对商品进行抠图
if __name__ == '__main__':
    # Batch cut-out job: read product rows from an Excel sheet, download each
    # product image, detect the product with GroundingDINO (English prompts
    # derived from keywords in the Chinese title), segment it with SAM, and
    # save the product pixels on a white background.

    # Chinese title keyword -> English GroundingDINO phrase. Dict order
    # matters: it fixes the order (and, past 10 hits, the selection) of the
    # phrases in the prompt.
    # NOTE(review): some translations look questionable for product titles
    # (e.g. "平板" -> "flat" probably means "tablet") — confirm.
    keyword_lookup_dict = {
        "鞋": "shoes", "表": "watch", "手机": "cell phone", "床": "bed",
        "沙发": "sofa", "裤": "pants", "杯": "cup", "笔记本电脑": "laptop",
        "车": "car", "锅": "pot", "奶粉": "milk powder", "键盘": "keyboard",
        "面霜": "cream", "口红": "lipstick", "靴": "boots", "烧水壶": "Kettle",
        "板鞋": "sneakers", "香水": "perfume", "面膜": "Facial mask",
        "碗": "bowl", "内裤": "underwear", "吹风机": "hair dryer",
        "大衣": "coat", "桌子": "table", "剃须刀": "shaver", "音响": "Audio",
        "空调": "air conditioner", "吊坠": "pendant", "电饭煲": "rice cooker",
        "睡衣": "pajamas", "按摩仪": "Massager", "平板": "flat",
        "尿不湿": "Diapers", "沙发床": "sofa bed", "茶几": "coffee table",
        "清洗机": "washing machine", "投影仪": "projector",
        "拓展坞": "expansion dock", "无人机": "drone",
        "净饮机": "Drinking machine", "台灯": "desk lamp", "摄像头": "Camera",
        "抽水器": "water pump", "螺丝刀": "screwdriver", "烤箱": "oven",
        "牛奶": "milk", "筋膜枪": "fascia gun", "电暖器": "electric heater",
        "耳钉": "earrings", "压力锅": "pressure cooker", "足浴盆": "footbath",
        "泡脚桶": "Foot bath bucket",
    }
    images_download_dir = "/root/notebook/gao/sam_dino/images_source"
    result_save_dir = "/root/notebook/gao/sam_dino/images_result"
    # makedirs(exist_ok=True) also creates missing parent directories.
    os.makedirs(images_download_dir, exist_ok=True)
    os.makedirs(result_save_dir, exist_ok=True)

    # 1. Read product URLs and titles from Excel (first row is the header).
    #    Passing the path lets pandas open *and close* the file itself.
    sku_file = pd.read_excel('商品抠图列表.xlsx', sheet_name='Sheet1',
                             header=0, dtype={"sku_id": str})

    # 2. Load both models once, before the loop (SAM used to be rebuilt from
    #    its checkpoint for every single image).
    # NOTE: ``device`` and ``image_source`` are module-level names in this
    # script; keep their names, helpers defined above may read them as globals.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    ckpt_repo_id = "ShilongLiu/GroundingDINO"
    ckpt_filename = "/root/git/Grounded-Segment-Anything/groundingdino_swinb_cogcoor.pth"
    ckpt_config_filename = "/root/git/Grounded-Segment-Anything/GroundingDINO_SwinB.cfg.py"
    groundingdino_model = load_model_hf(ckpt_repo_id, ckpt_filename,
                                        ckpt_config_filename, device)

    # wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
    sam_checkpoint = 'sam_vit_h_4b8939.pth'
    sam_predictor = SamPredictor(build_sam(checkpoint=sam_checkpoint).to(device))

    # 3. Cut out every product; a failing row is logged and skipped.
    for _, line in sku_file.iterrows():
        sku_id = line['sku_id']
        sku_title = str(line['标题'])
        img_save_name = line['输出文件名']
        img_url = line['待抠图文件']
        if pd.isna(img_url) or pd.isna(img_save_name):
            sys.stderr.write("[{}] missing image URL or output file name, "
                             "skipped\n".format(sku_id))
            continue
        img_url = str(img_url)
        img_save_name = str(img_save_name)

        # Download the source image once; later runs reuse the cached file.
        img_download_path = os.path.join(images_download_dir,
                                         img_url.split("/")[-1])
        if not os.path.exists(img_download_path):
            try:
                urllib.request.urlretrieve(img_url, filename=img_download_path)
            except Exception as err:  # best-effort per row: log and move on
                sys.stderr.write("图片URL下载失败 [{}] {}: {!r}\n".format(
                    sku_id, img_url, err))
                # Drop a partially written file so the next run retries it.
                if os.path.exists(img_download_path):
                    os.remove(img_download_path)
                continue
        if os.path.getsize(img_download_path) == 0:
            sys.stderr.write("图片URL下载失败 [{}] {}: empty file\n".format(
                sku_id, img_url))
            continue

        # Build the '.'-separated GroundingDINO caption from title keywords.
        prompt_list = [en for zh, en in keyword_lookup_dict.items()
                       if zh in sku_title]
        if len(prompt_list) < 2:
            # Few specific hits: add generic phrases so something is detected.
            prompt_list.append('item.object.entity')
            prompt_text = '.'.join(prompt_list)
        else:
            # Cap the caption at the first 10 keyword phrases.
            prompt_text = '.'.join(prompt_list[:10])

        # The result keeps an alpha channel; use a format that stores it
        # (e.g. .png) in the "输出文件名" column.
        result_save_path = os.path.join(result_save_dir, img_save_name)

        try:
            image_source, image = load_image(img_download_path)
            _annotated_frame, detected_boxes = detect(
                image, text_prompt=prompt_text, model=groundingdino_model)
            print("detected_boxes:\n", detected_boxes)
            if len(detected_boxes) == 0:
                sys.stderr.write("[{}] nothing detected for prompt '{}', "
                                 "skipped\n".format(sku_id, prompt_text))
                continue

            segmented_frame_masks = segment(image_source, sam_predictor,
                                            boxes=detected_boxes)

            # Known open issue: mask edges are rough (no smoothing applied).
            with Image.open(img_download_path) as pil_image:
                img = np.array(pil_image.convert('RGBA'))
            print("原图shape: ", img.shape)

            # Start from all channels = 255 (opaque white background); only
            # product pixels are copied in below.
            new_img = np.full(img.shape, 255, dtype=np.uint8)
            print("抠图shape:", new_img.shape)

            # Union of the per-box masks -> one (H, W) boolean mask.
            product_mask = segmented_frame_masks[:, 0].any(dim=0).numpy()
            if product_mask.shape != img.shape[:2]:
                raise ValueError("mask shape {} does not match image shape "
                                 "{}".format(product_mask.shape, img.shape[:2]))
            # Vectorized copy of the masked pixels (formerly a per-pixel
            # Python loop), reordering RGBA -> BGRA because cv2.imwrite
            # expects BGR(A) channel order.
            new_img[product_mask] = img[product_mask][:, [2, 1, 0, 3]]
            if not cv2.imwrite(result_save_path, new_img):
                sys.stderr.write("[{}] failed to write {}\n".format(
                    sku_id, result_save_path))
        except Exception as err:  # best-effort per row: log, keep batch running
            sys.stderr.write("[{}] 抠图失败 {}: {!r}\n".format(
                sku_id, img_download_path, err))
抠图效果展示:

商品提示词:手机

遗留问题:

当然,虽然该算法已能对大部分商品准确抠图,但毛边问题依然存在,需要加入平滑处理(例如对掩码边缘做平滑或羽化)来解决。另外,商品识别的准确率依赖于提示词,需要准备大量准确的提示词。

Done

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值