文章目录
本文主要介绍visual prompt模型DINOv,该模型可输入八张目标示例图作为参考,告诉模型我要找的目标长这样,在新的图片上进行推理,实现实例分割的效果。
但一些复杂的场景,八张的示例图不能让模型完全的学习到目标的特征,因此扩展模型能力,让visual prompt数量不受限制,对实际场景应用是非常有必要的(附改造方法、改造代码)。
一、Closed-Set VS Open-set
Closed-Set模型只需要关注有限数量的已知类别,答案选项是预先定义的,这意味着模型的输出范围是有限的、固定的,并且只限于训练时已知的选项,例如YOLO;Open-Set模型可以识别不属于任何已知类别的样本,即其输出范围不是固定的,具备一定的泛化能力和鲁棒性,以应对这些未知的挑战,例如SAM。
在某些特定的应用场景中,仅仅依赖文本提示(text prompt)来描述目标对象,对于Open-Set大模型来说,可能并不足以实现精准识别。若能够额外提供示例图像(visual prompt),将有助于模型更准确地理解我们的意图,从而提升整体的识别效果。
下图是DINOv作者提供的demo界面,左上角输入油污推理图,左下角输入多张油污示例图,并用画笔进行mask,运行模型可得到右边的推理效果。
二、DINOv
2.1 论文和代码
论文名称:《Visual In-Context Prompting》
code:https://github.com/UX-Decoder/DINOv
demo:http://semantic-sam.xyzou.net:6099/
2.2 内容
上下文提示是一种利用少量示例任务来指导模型完成新任务的技术。在视觉任务中,这种技术可以通过提供一组带有标签的图像作为示例,来引导模型理解和解决新的视觉任务。
模型通过学习少量的带有标签的样本图像,提取出这些图像中的关键特征和模式,然后利用这些特征和模式来生成针对新图像的查询。这个查询可以引导模型在新图像中定位并分割出目标物体。具体来说,模型可能通过学习示例图像中的物体形状、颜色、纹理等特征,以及这些特征与标签之间的关系,来构造出查询。然后,模型将这个查询应用于其他图像,通过匹配和比较查询与图像中的特征,来定位并分割出目标物体。最终,模型会生成一个掩码,标记出分割出的物体区域。
以图片作为提示(visual prompt),在提示图上通过笔画、画mask等方法作为视觉prompt,可推理出侧视图中同类目标,达到zero-shot目标分割的效果。
说明:在降落伞进行mask标注,在新的降落伞场景可分割出降落伞,其他场景同理
2.3 安装部署
系统要求:gcc版本>=4.9
# 1、离线安装detectron2
# 下载https://github.com/MaureenZOU/detectron2-xyz.git
Unzip detectron2-xyz.zip # 解压
Cd detectron2-xyz
Pip install -e .
# 2、离线安装panopticapi
# 下载https://github.com/cocodataset/panopticapi.git
Unzip panopticapi.zip # 解压
Cd panopticapi
Pip install -e .
# 3、启动DINOv
# 下载DINOv,https://github.com/UX-Decoder/DINOv
Unzip DINOv.zip # 解压
cd DINOv
python -m pip install -r requirements.txt
python demo_openset.py --ckpt /path/to/swinL/ckpt
# 终端返回下图链接
注:在浏览器访问public URL,建议使用梯子,local URL直接用即可
2.4 使用效果
通过界面输入八张示例图,在一些大目标、规整目标(如矩形、圆形),效果较好,在复杂场景、小目标、不规则物体,无法达到预期效果,例如墙缝缺陷,无法分割裂缝。
三、多visual prompt 改造
使用八张图片作为示例图,可能无法完全学习到目标。在实际使用中,我们可能采集到一小部分图片,例如50张、100张等;如何让DINOv不受限制,可支持多张输入呢?
3.1 获取示例图mask
使用labelme标注工具,生成json标注文件,使用下面代码将json转化为标注mask图。
import json
import cv2
import os
import matplotlib.pyplot as plt
import numpy as np
def generate_mask(img_path, json_path, save_path):
img = cv2.imread(img_path)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
mask = np.zeros_like(img)
with open(json_path, "r") as f:
tmp = f.read()
tmp = json.loads(tmp)
tmp_shapes = tmp["shapes"]
for shape in tmp_shapes:
points = shape["points"]
points = np.array(points, np.int32)
cv2.fillPoly(mask, [points], (255, 255, 255))
img_add = cv2.addWeighted(mask, 0.3,img,0.7,0)
cv2.imwrite(save_path, mask)
if __name__ == "__main__":
imgs_dir = "./imgs" # 图片目录
jsons_dir = "./jsons" # 标注的json文件存放目录
save_dir = "./masks" # 生成mask图保存目录
img_files = os.listdir(imgs_dir)
for img_name in img_files:
img_path = os.path.join(imgs_dir, img_name)
json_path = os.path.join(jsons_dir, img_name.split('.')[0]+'.json')
if os.path.exists(json_path):
save_path = os.path.join(save_dir, img_name)
generate_mask(img_path, json_path, save_path)
3.2 修改函数参数
修改文件路径:demo/openset_task.py
作用:将原8张图输入修改为列表不限制输入
# 原代码31-37行
def task_openset(model,generic_vp1, generic_vp2, generic_vp3, generic_vp4,
generic_vp5, generic_vp6, generic_vp7, generic_vp8, image_tgt=None, text_size=640,hole_scale=100,island_scale=100):
in_context_examples = [generic_vp1, generic_vp2, generic_vp3, generic_vp4,
generic_vp5, generic_vp6, generic_vp7, generic_vp8]
in_context_examples = [x for x in in_context_examples if x is not None]
t = []
t.append(transforms.Resize(int(text_size), interpolation=Image.BICUBIC))
# 替换代码
def task_openset(model,refer_img_list, image_tgt=None, text_size=640,hole_scale=100,island_scale=100):
# in_context_examples = [generic_vp1, generic_vp2, generic_vp3, generic_vp4,
# generic_vp5, generic_vp6, generic_vp7, generic_vp8]
in_context_examples = refer_img_list
in_context_examples = [x for x in in_context_examples if x is not None]
t = []
t.append(transforms.Resize(int(text_size), interpolation=Image.BICUBIC))
3.3 推理代码
自定义imgs_dir、mask_dir、tgt_dir,执行代码,可在save_dir中找到结果图
import torch
import argparse
from PIL import Image
import cv2
import os
from dinov.BaseModel import BaseModel
from dinov import build_model
from utils.arguments import load_opt_from_config_file
from demo.openset_task import task_openset
def parse_option():
parser = argparse.ArgumentParser('DINOv Demo', add_help=False)
parser.add_argument('--conf_files', default="configs/dinov_sam_coco_swinl_train.yaml", metavar="FILE", help='path to config file', )
parser.add_argument('--ckpt', default="model_swinL.pth", metavar="FILE", help='path to ckpt')
parser.add_argument('--port', default=6099, type=int, help='path to ckpt', )
args = parser.parse_args()
return args
'''
build args
'''
args = parse_option()
'''
build model
'''
sam_cfg=args.conf_files
opt = load_opt_from_config_file(sam_cfg)
model_sam = BaseModel(opt, build_model(opt)).from_pretrained(args.ckpt).eval().cuda()
@torch.no_grad()
def inference(refer_img_list, image2,*args, **kwargs):
with torch.autocast(device_type='cuda', dtype=torch.float16):
model=model_sam
a= task_openset(model, refer_img_list, image2, *args, **kwargs)
return a
"""
读取image和labelme标注的mask图
推理一整个目录的图片
"""
def inference_dir(imgs_dir, mask_dir, tgt_dir, save_dir):
files = os.listdir(tgt_dir)
result_img_list = []
for file in files:
print(f'==={file}==')
image_tgt_path = os.path.join(tgt_dir, file)
image_tgt = Image.open(image_tgt_path).convert('RGB')
refer_img_list = []
img_files = os.listdir(imgs_dir)
for img_name in img_files:
img_path = os.path.join(imgs_dir, img_name)
mask_path = os.path.join(mask_dir, img_name)
if os.path.exists(mask_path):
generic_vp= {"image":"", "mask":""}
generic_vp["image"] = Image.open(img_path).convert('RGB')
generic_vp["mask"] = Image.open(mask_path).convert('RGB')
refer_img_list.append(generic_vp)
# print(len(refer_img_list))
res = inference(refer_img_list, image_tgt)
res = cv2.cvtColor(res, cv2.COLOR_RGB2BGR)
cv2.imwrite(os.path.join(save_dir, os.path.basename(image_tgt_path)), res)
if __name__ == "__main__":
imgs_dir = "./test_img_2/group_50/refer/imgs" # 示例图目录
mask_dir = "./test_img_2/group_50/refer/masks" # 示例mask图目录
tgt_dir = "./test_img_2/tgt" # 推理图目录
save_dir = "results/group_50/" # 结果保存目录
inference_dir(imgs_dir, mask_dir, tgt_dir, save_dir)
3.4 效果的提升!
在验证多visual prompt对结果的影响,采用了对比实验。在光学镜头缺陷场景中,8张visual prompt和50张visual prompt进行对比,50张visual prompt得到的推理效果更优!
四、总结
如果文章对您有所帮助,记得点赞、收藏、评论探讨✌️