MedSAM:输入 xml + img,根据检测框在原图上显示分割效果,并加上轮廓点的减少处理

1、输入每张图片的多个检测框,得到这张图片的 SAM 分割结果

import numpy as np
import matplotlib.pyplot as plt
import os

join = os.path.join
import torch
from segment_anything import sam_model_registry
from skimage import io, transform
import torch.nn.functional as F
import argparse


@torch.no_grad()
def medsam_inference(medsam_model, img_embed, box_1024, H, W, threshold=0.5):
    """Decode binary MedSAM masks for box prompts on a precomputed embedding.

    Args:
        medsam_model: SAM/MedSAM model exposing ``prompt_encoder`` (callable,
            with ``get_dense_pe()``) and ``mask_decoder``.
        img_embed: image embedding from ``medsam_model.image_encoder``; the
            caller in this script passes a (1, 256, 64, 64) tensor.
        box_1024: box prompt(s) already rescaled to the 1024x1024 frame,
            shape (B, 4) or (B, 1, 4).
        H, W: size of the returned mask (the original image height/width).
        threshold: probability cut-off applied after the sigmoid.

    Returns:
        numpy uint8 array of 0/1 values, shape (H, W) when a single box is
        given, otherwise (B, H, W).
    """
    box_torch = torch.as_tensor(box_1024, dtype=torch.float, device=img_embed.device)
    if box_torch.dim() == 2:
        box_torch = box_torch[:, None, :]  # (B, 1, 4): one box per prompt

    sparse_embeddings, dense_embeddings = medsam_model.prompt_encoder(
        points=None,
        boxes=box_torch,
        masks=None,
    )
    low_res_logits, _ = medsam_model.mask_decoder(
        image_embeddings=img_embed,
        image_pe=medsam_model.prompt_encoder.get_dense_pe(),
        sparse_prompt_embeddings=sparse_embeddings,
        dense_prompt_embeddings=dense_embeddings,
        multimask_output=False,
    )  # (B, 1, 256, 256) low-resolution logits

    probs = torch.sigmoid(low_res_logits)
    probs = F.interpolate(
        probs,
        size=(H, W),
        mode="bilinear",
        align_corners=False,
    )  # (B, 1, H, W)

    # Remove only the channel axis, and the batch axis when B == 1. A blanket
    # squeeze() would also collapse H or W whenever either of them equals 1.
    probs = probs[:, 0]
    if probs.shape[0] == 1:
        probs = probs[0]
    return (probs.cpu().numpy() > threshold).astype(np.uint8)


# %% load model and image
# Command-line options. In this script only --device and --checkpoint are
# read; --data_path, --seg_path and --box are parsed but unused here.
parser = argparse.ArgumentParser(
    description="run inference on testing set based on MedSAM"
)
parser.add_argument(
    "-i",
    "--data_path",
    type=str,
    default="assets/img_demo.png",
    help="path to the data folder",
)
parser.add_argument(
    "-o",
    "--seg_path",
    type=str,
    default="assets/",
    help="path to the segmentation folder",
)
# Four integers x_min y_min x_max y_max. The previous `type=list` turned a
# command-line string into a list of single characters.
parser.add_argument(
    "--box",
    type=int,
    nargs=4,
    default=[95, 255, 190, 350],
    help="bounding box of the segmentation target",
)
parser.add_argument("--device", type=str, default="cuda:0", help="device")
parser.add_argument(
    "-chk",
    "--checkpoint",
    type=str,
    default="work_dir/MedSAM/medsam_vit_b.pth",
    # default="/home/syy/code/sam/MedSAM-LiteMedSAM/carotid_MedSAM-Lite-Box-20240508-1808/medsam_lite_best1.pth",
    help="path to the trained model",
)
args = parser.parse_args()

# Build the ViT-B MedSAM model from the checkpoint and move it to the device.
# `device` and `medsam_model` are module-level globals used by prompt_box_pred.
device = args.device
medsam_model = sam_model_registry["vit_b"](checkpoint=args.checkpoint)
medsam_model = medsam_model.to(device)
medsam_model.eval()
print("=====================================> 模型加载完毕")


import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2
import sys
import os
import random 


import os
import xml.etree.ElementTree as ET
import cv2



def _parse_coord(bbox, tag):
    """Return the <tag> child of a <bndbox> element as an integer pixel value."""
    # Annotation tools may write coordinates as floats ("95.0"); plain int()
    # rejects those, so parse as float and round to the nearest pixel.
    return int(round(float(bbox.find(tag).text)))


def parse_xml(xml_path):
    """Read a Pascal VOC-style annotation file.

    Args:
        xml_path: path to the .xml annotation.

    Returns:
        (image_name, boxes, labels):
        - ``image_name`` is the text of <filename>.
        - ``boxes`` is a list of (xmin, ymin, xmax, ymax) int tuples, one per <object>.
        - ``labels`` holds the matching <object>/<name> strings.
        Both lists are in file order.
    """
    root = ET.parse(xml_path).getroot()
    image_name = root.find('filename').text

    boxes = []
    labels = []
    for obj in root.findall('object'):
        bbox = obj.find('bndbox')
        boxes.append(tuple(_parse_coord(bbox, tag) for tag in ('xmin', 'ymin', 'xmax', 'ymax')))
        labels.append(obj.find('name').text)

    return image_name, boxes, labels

def process_xmls(xmls_dir, max_files=200):
    """Parse the .xml annotation files in ``xmls_dir`` in sorted filename order.

    Args:
        xmls_dir: directory containing the annotation files.
        max_files: parse at most this many XML files. Defaults to 200, the
            limit the script previously hard-coded. ``None`` parses all of them.

    Returns:
        list of ``parse_xml`` results, i.e. (image_name, boxes, labels) tuples.
    """
    # Filter before limiting, so that non-XML entries in the directory do not
    # use up the file budget.
    xml_files = sorted(name for name in os.listdir(xmls_dir) if name.endswith('.xml'))
    if max_files is not None:
        xml_files = xml_files[:max_files]
    return [parse_xml(os.path.join(xmls_dir, name)) for name in xml_files]



def _plot_closed_contour(ax, contour, linewidth):
    """Plot an OpenCV (N, 1, 2) contour on ``ax`` as a closed red polyline."""
    # OpenCV contours do not repeat the first point. Append it so the
    # last-to-first edge is drawn and the outline is closed.
    pts = np.concatenate([contour[:, 0, :], contour[:1, 0, :]], axis=0)
    ax.plot(pts[:, 0], pts[:, 1], color='red', linewidth=linewidth)


def show_mask(mask, ax, random_color=False, reduction_factor=0.002):
    """Overlay a 0/1 mask on ``ax`` and outline its largest region in red.

    Args:
        mask: predicted segmentation with 0 (background) and 1 (target).
            The contour step passes it to cv2.findContours, which expects a
            2-D (H, W) array.
        ax: matplotlib Axes to draw on.
        random_color: fill with a random RGB colour at alpha 0.6 instead of
            the default light blue at alpha 0.1.
        reduction_factor: tolerance for cv2.approxPolyDP, as a fraction of the
            contour perimeter. Larger values give fewer outline points; values
            <= 1e-6 draw the raw contour instead.
    """
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30/255, 144/255, 255/255, 0.1])  # light blue, alpha 0.1
    h, w = mask.shape[-2:]
    # Multiply the mask by the RGBA colour to get a coloured (h, w, 4) overlay.
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)

    # Outer contours of the mask; an empty list means nothing was segmented.
    contours, _ = cv2.findContours((mask * 255).astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        return
    largest_contour = max(contours, key=cv2.contourArea)

    if reduction_factor > 0.000001:
        # Simplify the largest contour to reduce its number of points.
        epsilon = reduction_factor * cv2.arcLength(largest_contour, True)
        approx = cv2.approxPolyDP(largest_contour, epsilon, True)
        print("点有没有减少,len(approx),len(contours)", len(approx), len(largest_contour))
        _plot_closed_contour(ax, approx, linewidth=1)
    else:
        _plot_closed_contour(ax, largest_contour, linewidth=0.3)


def show_points(coords, labels, ax, marker_size=375):
    """Draw point prompts as large stars on ``ax``.

    Points labelled 1 (positive) are drawn green, and points labelled 0
    (negative) are drawn red. ``coords`` rows are (x, y) and ``labels`` is a
    parallel array.
    """
    for wanted_label, colour in ((1, 'green'), (0, 'red')):
        selected = coords[labels == wanted_label]
        ax.scatter(
            selected[:, 0],
            selected[:, 1],
            color=colour,
            marker='*',
            s=marker_size,
            edgecolor='white',
            linewidth=1.25,
        )

def show_box(box, ax):
    """Outline ``box`` = (x0, y0, x1, y1) on ``ax`` as an unfilled yellow rectangle."""
    left, top, right, bottom = box[0], box[1], box[2], box[3]
    outline = plt.Rectangle(
        (left, top),
        right - left,
        bottom - top,
        edgecolor='yellow',
        facecolor=(0, 0, 0, 0),
        lw=1,
    )
    ax.add_patch(outline)



def prompt_box_pred(xmls_dir, imgs_dir, save_dir):
    """Segment every annotated box with MedSAM and save one overlay per image.

    For each annotation parsed from ``xmls_dir``:
    - the matching image is read from ``imgs_dir``;
    - each box is used as a MedSAM prompt;
    - the image, with every predicted mask/contour and box drawn on it, is
      saved into ``save_dir`` under the image's file name.

    Uses the module-level ``medsam_model`` and ``device`` set up at start-up.
    """
    results = process_xmls(xmls_dir)
    for ind, (image_name, boxes, labels) in enumerate(results):
        print(ind, ': Image:', image_name)

        img_path = os.path.join(imgs_dir, image_name)
        img_np = io.imread(img_path)
        if img_np.ndim == 2:
            img_np = img_np[:, :, None]
        if img_np.shape[-1] == 1:
            img_3c = np.repeat(img_np, 3, axis=-1)  # grayscale -> 3 identical channels
        else:
            img_3c = img_np[:, :, :3]  # drop an alpha channel (RGBA) if present
        H, W, _ = img_3c.shape

        # %% image preprocessing: resize to 1024x1024, min-max normalise to [0, 1]
        img_1024 = transform.resize(
            img_3c, (1024, 1024), order=3, preserve_range=True, anti_aliasing=True
        ).astype(np.uint8)
        img_1024 = (img_1024 - img_1024.min()) / np.clip(
            img_1024.max() - img_1024.min(), a_min=1e-8, a_max=None
        )
        img_1024_tensor = (
            torch.tensor(img_1024).float().permute(2, 0, 1).unsqueeze(0).to(device)
        )  # (1, 3, 1024, 1024)

        # The embedding depends only on the image, not on the box, so run the
        # encoder once per image instead of once per box.
        with torch.no_grad():
            image_embedding = medsam_model.image_encoder(img_1024_tensor)  # (1, 256, 64, 64)

        fig = plt.figure(figsize=(10, 10))
        plt.imshow(img_3c)
        ax = plt.gca()

        for box, label in zip(boxes, labels):
            x1, y1, x2, y2 = box
            print('  Label:', label)
            print('  Box:', x1, y1, x2, y2)

            # Rescale the box from original-image pixels to the 1024x1024 frame.
            box_1024 = np.array([box]) / np.array([W, H, W, H]) * 1024
            # The mask comes back at the original image size (H, W).
            medsam_seg = medsam_inference(medsam_model, image_embedding, box_1024, H, W)

            show_mask(medsam_seg, ax)
            show_box(np.array(box), ax)

        plt.axis('off')
        # bbox_inches='tight' with pad_inches=0 crops the saved image to the axes.
        plt.savefig(os.path.join(save_dir, image_name), bbox_inches='tight', pad_inches=0)
        # Close the figure; otherwise one figure stays open per processed image.
        plt.close(fig)
            
if __name__ == "__main__":
    # Hard-coded dataset locations for this experiment.
    xmls_dir, imgs_dir, save_dir = (
        '/home/syy/data/甲乳/breast/image2/xmls',
        '/home/syy/data/甲乳/breast/image2/images',
        "/home/syy/data/甲乳/breast/image2/medsam/",
    )
    # Create the output folder before prompt_box_pred saves figures into it.
    os.makedirs(save_dir, exist_ok=True)
    prompt_box_pred(xmls_dir, imgs_dir, save_dir)

在这里插入图片描述

  • 2
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值