SAM验证代码详解

SAM验证代码详解

from os import listdir, makedirs
from os.path import join, isfile, basename
from glob import glob
from tqdm import tqdm
from time import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from segment_anything.modeling import MaskDecoder, PromptEncoder, TwoWayTransformer
from tiny_vit_sam import TinyViT
from matplotlib import pyplot as plt
import cv2
import argparse
from collections import OrderedDict
import pandas as pd
from datetime import datetime
import matplotlib.pyplot as plt
import os
# 这里把os.path.join赋值给join变量
join = os.path.join
from segment_anything import sam_model_registry
from skimage import io, transform
import argparse
# 把npz的图片 进行一个分割得到segs 并把分割后的png图片保存在./overlay下

# %% set seeds
torch.set_float32_matmul_precision('high')
torch.manual_seed(2024)
torch.cuda.manual_seed(2024)
np.random.seed(2024)

argparse.ArgumentParser()用于构建命令行参数和选项

parser = argparse.ArgumentParser()
parser.add_argument(
    '-i',
    '--input_dir',
    type=str,
    default='./data/Viral Pneumonia/dd_dd/npz/',
    # required=True,
    help='root directory of the data',
)
parser.add_argument(
    '-o',
    '--output_dir',
    type=str,
    default='./data/Viral Pneumonia/dd_dd/new_sam_segs/',
    help='directory to save the prediction',
)
parser.add_argument(
    "-chk",
    "--checkpoint",
    type=str,
    default="./train/work_dir/medsam_model_best.pth",
    # default = "./train/work_dir/MedSAM-ViT-B-20240410-1922/medsam_model_best.pth",
    help="path to the trained model",
)
parser.add_argument(
    '-lite_medsam_checkpoint_path',
    type=str,
    default="work_dir/LiteMedSAM/lite_medsam.pth",
    help='path to the checkpoint of MedSAM-Lite',
)
parser.add_argument(
    '-device',
    type=str,
    default="cpu",
    help='device to run the inference',
)
parser.add_argument(
    '-num_workers',
    type=int,
    default=4,
    help='number of workers for inference with multiprocessing',
)
parser.add_argument(
    '--save_overlay',
    default=True,
    action='store_true',
    help='whether to save the overlay image'
)
parser.add_argument(
    '-png_save_dir',
    type=str,
    default='./overlay/new_sam_overlay',
    help='directory to save the overlay image'
)

配置程序,从args中取参数的值,并将其赋值给变量。

args = parser.parse_args()

data_root = args.input_dir
pred_save_dir = args.output_dir
save_overlay = args.save_overlay
num_workers = args.num_workers
if save_overlay:
    assert args.png_save_dir is not None, "Please specify the directory to save the overlay image"
    png_save_dir = args.png_save_dir
    makedirs(png_save_dir, exist_ok=True)

# medsam_checkpoint_path = args.medsam_checkpoint_path
makedirs(pred_save_dir, exist_ok=True)
device = torch.device(args.device)
image_size = 256

将给定的图像调整大小,以使图像的最长边等于目标长度,同时保持图像的宽高比不变。

def resize_longest_side(image, target_length=256):
    oldh, oldw = image.shape[0], image.shape[1]
    scale = target_length * 1.0 / max(oldh, oldw)
    newh, neww = oldh * scale, oldw * scale
    neww, newh = int(neww + 0.5), int(newh + 0.5)
    target_size = (neww, newh)
    return cv2.resize(image, target_size, interpolation=cv2.INTER_AREA)

用于将图像或标签补齐到特定的尺寸

def pad_image(image, target_size=256):
    """
    Pad image to target_size
    Expects a numpy array with shape HxWxC in uint8 format.
    """
    # Pad
    #获取了图像的原始高度(h)和宽度(w)。
    h, w = image.shape[0], image.shape[1]
    #这两行代码计算了在高度和宽度方向上需要填充的像素数量。
    padh = target_size - h
    padw = target_size - w
    if len(image.shape) == 3:  ## Pad image
        image_padded = np.pad(image, ((0, padh), (0, padw), (0, 0)))
    else:  ## Pad gt mask
        image_padded = np.pad(image, ((0, padh), (0, padw)))

    return image_padded

用于处理图像和相关信息,以生成低分辨率的掩码

class MedSAM_Lite(nn.Module):
    def __init__(
            self,
            image_encoder,
            mask_decoder,
            prompt_encoder
    ):
        super().__init__()
        self.image_encoder = image_encoder
        self.mask_decoder = mask_decoder
        self.prompt_encoder = prompt_encoder

    def forward(self, image, box_np):
    	#使用图像编码器对输入图像进行编码,生成图像嵌入(特征表示)
        image_embedding = self.image_encoder(image)  # (B, 256, 64, 64)
        # do not compute gradients for prompt encoder
        with torch.no_grad():
   		   #将边界框信息从NumPy数组转换为PyTorch张量
            box_torch = torch.as_tensor(box_np, dtype=torch.float32, device=image.device)
            #是2维的,.增加一个维度
            if len(box_torch.shape) == 2:
                box_torch = box_torch[:, None, :]  # (B, 1, 4)
                
        sparse_embeddings, dense_embeddings = self.prompt_encoder(
            points=None,
            boxes=box_np,
            masks=None,
        )
        low_res_masks, iou_predictions = self.mask_decoder(
            image_embeddings=image_embedding,  # (B, 256, 64, 64)
            image_pe=self.prompt_encoder.get_dense_pe(),  # (1, 256, 64, 64)
            sparse_prompt_embeddings=sparse_embeddings,  # (B, 2, 256)
            dense_prompt_embeddings=dense_embeddings,  # (B, 256, 64, 64)
            multimask_output=False,
        )  # (B, 1, 256, 256)

        return low_res_masks

将模型预测出的掩模调整到原始图像的大小
首先裁剪掩模以匹配一个中间大小,然后使用双线性插值将掩模调整(上采样)回原始大小

    @torch.no_grad()
    def postprocess_masks(self, masks, new_size, original_size):
        # Crop
        masks = masks[..., :new_size[0], :new_size[1]]
        # Resize
        masks = F.interpolate(
            masks,
            size=(original_size[0], original_size[1]),
            mode="bilinear",
            align_corners=False,
        )
        return masks

展示mask

def show_mask(mask, ax, mask_color=None, alpha=0.5):
    if mask_color is not None:
        color = np.concatenate([mask_color, np.array([alpha])], axis=0)
    else:
        color = np.array([251 / 255, 252 / 255, 30 / 255, alpha])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)

展示box

def show_box(box, ax, edgecolor='blue'):
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor=edgecolor, facecolor=(0, 0, 0, 0), lw=2))

分析掩码图像来计算包含所有非零像素的最小边界框,并且允许在边界框坐标上添加一个小的扰动

def get_bbox256(mask_256, bbox_shift=3):
    y_indices, x_indices = np.where(mask_256 > 0)
    x_min, x_max = np.min(x_indices), np.max(x_indices)
    y_min, y_max = np.min(y_indices), np.max(y_indices)
    # add perturbation to bounding box coordinates and test the robustness
    # this can be removed if you do not want to test the robustness
    H, W = mask_256.shape
    x_min = max(0, x_min - bbox_shift)
    x_max = min(W, x_max + bbox_shift)
    y_min = max(0, y_min - bbox_shift)
    y_max = min(H, y_max + bbox_shift)
    bboxes256 = np.array([x_min, y_min, x_max, y_max])
    return bboxes256

从原始图像中得到的边界框坐标重新缩放到256x256尺寸的图像中对应的坐标。

def resize_box_to_256(box, original_size):
	#输入box形状相同且元素全为零的数组new_box
    new_box = np.zeros_like(box)
    #256除以原始图像尺寸中的最大值。这个比例用于保持图像的宽高比不变
    ratio = 256 / max(original_size)
    for i in range(len(box)):
        new_box[i] = int(box[i] * ratio)
    return new_box

它接受MedSAM模型、图像的嵌入表示、边界框坐标、以及目标图像的高度和宽度作为输入,返回分割掩码和原始的预测概率

@torch.no_grad()
def medsam_inference(medsam_model, img_embed, box_1024, H, W):
	#将box_1024转换为PyTorch张量box_torch
    box_torch = torch.as_tensor(box_1024, dtype=torch.float, device=img_embed.device)
    #如果二维的,添加一个维度
    if len(box_torch.shape) == 2:
        box_torch = box_torch[:, None, :]  # (B, 1, 4)

    sparse_embeddings, dense_embeddings = medsam_model.prompt_encoder(
        points=None,
        boxes=box_torch,
        masks=None,
    )
    low_res_logits, _ = medsam_model.mask_decoder(
        image_embeddings=img_embed,  # (B, 256, 64, 64)
        image_pe=medsam_model.prompt_encoder.get_dense_pe(),  # (1, 256, 64, 64)
        sparse_prompt_embeddings=sparse_embeddings,  # (B, 2, 256)
        dense_prompt_embeddings=dense_embeddings,  # (B, 256, 64, 64)
        multimask_output=False,
    )
	#通过torch.sigmoid函数将逻辑斯特回归输出转换为预测概率。
    low_res_pred = torch.sigmoid(low_res_logits)  # (1, 1, 256, 256)
	#预测概率的分辨率调整到与原始图像相匹配的尺寸
    low_res_pred = F.interpolate(
        low_res_pred,
        size=(H, W),
        mode="bilinear",
        align_corners=False,
    )  # (1, 1, gt.shape)
    #预测概率转换为NumPy数组
    low_res_pred = low_res_pred.squeeze().cpu().numpy()  # (256, 256)
    #通过阈值操作(阈值为0.5)生成二值分割掩码medsam_seg。
    medsam_seg = (low_res_pred > 0.5).astype(np.uint8)
    return medsam_seg, low_res_pred
device = args.device
medsam_model = sam_model_registry["vit_b"](checkpoint=args.checkpoint)
medsam_model = medsam_model.to(device)
medsam_model.eval()

MedSAM_infer_npz_2DHA函数封装了很多内容,拆开方便理解

def MedSAM_infer_npz_2D(img_npz_file):

处理了图像数据的加载和调整大小的流程

    npz_name = basename(img_npz_file)
    npz_data = np.load(img_npz_file, 'r', allow_pickle=True)  # (H, W, 3)
    #从加载的数据中提取名为'imgs'的数组
    img_3c = npz_data['imgs']  # (H, W, 3)
    H, W = 256, 256
    img_256 = cv2.resize(
        img_3c,
        (H, W),  # LiteMedSAM (256*256)
        interpolation=cv2.INTER_NEAREST
    ).astype(np.uint8)
    img_1024 = cv2.resize(
        img_3c,
        (1024, 1024), # LiteMedSAM (256*256)
        interpolation=cv2.INTER_NEAREST
    ).astype(np.uint8)
    assert np.max(img_256) < 256, f'input data should be in range [0, 255], but got {np.unique(img_3c)}'

创建box

	#从加载的.npz文件数据中提取地面真实标签(gts)
    gt = npz_data['gts']
    #检查gt的形状是否已经是目标形状(256x256像素)
    if gt.shape != (H, W):
        gt = cv2.resize(
            gt.astype(np.uint8), (W, H),
            interpolation=cv2.INTER_NEAREST
        ).astype(np.uint8)
    gt = pad_image(gt)  # LiteMedSAM (256*256)
    label_ids = np.unique(gt)[1:]
    #从label_ids中随机选择一个标签,并生成一个新的二维数组gt2D
    import random
    gt2D = np.uint8(gt == random.choice(label_ids.tolist()))  # only one label, (256, 256)
    gt2D = np.uint8(gt2D > 0)
    y_indices, x_indices = np.where(gt2D > 0)
    # 通过np.min和np.max函数计算目标的最小和最大x、y坐标,这些坐标定义了围绕目标的最小边界框
    x_min, x_max = np.min(x_indices), np.max(x_indices)
    y_min, y_max = np.min(y_indices), np.max(y_indices)
    # add perturbation to bounding box coordinates
    # H, W = gt2D.shape
    bbox_shift = 5
    x_min = max(0, x_min - random.randint(0, bbox_shift))
    x_max = min(W, x_max + random.randint(0, bbox_shift))
    y_min = max(0, y_min - random.randint(0, bbox_shift))
    y_max = min(H, y_max + random.randint(0, bbox_shift))
    boxes = np.array([[x_min, y_min, x_max, y_max]])

创建一个与img_3c相同高度和宽度的全零数组segs,用于存储每个像素的分割结果。

   segs = np.zeros((H,W), dtype=np.uint8)
   ## preprocessing
   # 它调整图像img_3c的最长边到1024像素
   img_1024 = resize_longest_side(img_1024, 1024) # LiteMedSAM (256)
   # 获取调整尺寸后的图像img_1024的新高度和宽度。
   newh, neww = img_1024.shape[:2]
   # !对图像img_1024进行最小-最大归一化,使所有像素值位于[0, 1]范围内
   img_1024_norm = (img_1024 - img_1024.min()) / np.clip(
       img_1024.max() - img_1024.min(), a_min=1e-8, a_max=None
   )
   img_1024_padded = pad_image(img_1024_norm, 1024) # LiteMedSAM (256)
   img_1024_tensor = torch.tensor(img_1024_padded).float().permute(2, 0, 1).unsqueeze(0).to(device)
   with torch.no_grad():
       # 使用medsammodel的image_encoder部分对预处理后的图像进行编码
       #!可以看出medsam模型的图片输入要1024*1024
       image_embedding = medsam_model.image_encoder(img_1024_tensor)  
   for idx, box in enumerate(boxes, start=1):
       # 它根据原图像尺寸(H, W)调整边界框box的尺寸,使其适应1024*1024的图像尺寸
       box_1024 = box / np.array([W, H, W, H]) * 1024
       box_1024 = box_1024[None, ...]
       medsam_mask, iou_pred = medsam_inference(medsam_model, image_embedding, box_1024, H, W)
       segs[medsam_mask > 0] = idx

将segs数组以压缩格式保存到指定路径的.npz文件中

 np.savez_compressed(
     join(pred_save_dir, npz_name),
     segs=segs,
 )

可视化

    # visualize image, mask and bounding box
    if save_overlay:
        fig, ax = plt.subplots(1, 2, figsize=(10, 5))
        ax[0].imshow(img_256)
        ax[1].imshow(img_256)
        ax[0].set_title("Image")
        ax[1].set_title("SAM Segmentation")
        ax[0].axis('off')
        ax[1].axis('off')
		#遍历并显示所有边界框和掩码
        for i, box in enumerate(boxes):
            color = np.random.rand(3)
            box_viz = box
            show_box(box_viz, ax[1], edgecolor=color)
            show_mask((segs == i + 1).astype(np.uint8), ax[1], mask_color=color)
        plt.tight_layout()
        plt.savefig(join(png_save_dir, npz_name.split(".")[0] + '.png'), dpi=300)
        plt.close()

它根据文件是3D图像还是2D图像来调用不同的推理函数,并记录每个文件处理的时间。

if __name__ == '__main__':
    # 使用glob函数搜索data_root目录下所有的.npz文件,并将它们排序。这样可以确保处理文件的顺序是一致的。
    img_npz_files = sorted(glob(join(data_root, '*.npz'), recursive=True))
    # 初始化一个有序字典efficiency,用于存储case和time。
    efficiency = OrderedDict()
    efficiency['case'] = []
    efficiency['time'] = []
    # 使用for循环遍历前20个.npz文件,tqdm用于显示进度条
    for img_npz_file in tqdm(img_npz_files[:-1]):
        # 这个时间值将用于后续计算处理特定文件所需的时间。
        start_time = time()
        # 看开头是不是‘3D'开头的
        if basename(img_npz_file).startswith('3D'):
            MedSAM_infer_npz_3D(img_npz_file)
        else:
            MedSAM_infer_npz_2D(img_npz_file)
        end_time = time()
        #
        efficiency['case'].append(basename(img_npz_file))
        efficiency['time'].append(end_time - start_time)
        current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        print(current_time, 'file name:', basename(img_npz_file), 'time cost:', np.round(end_time - start_time, 4))
    efficiency_df = pd.DataFrame(efficiency)
    efficiency_df.to_csv(join(pred_save_dir, 'efficiency.csv'), index=False)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

吾在学习路

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值