Efficientsam验证代码详解

首先要确定模型输入的维度,不然就会发现性能不是很好。Efficientsam的输入维度是1024 * 1024,然后输出的时候要改为256 * 256.

from efficient_sam.build_efficient_sam import build_efficient_sam_vitt, build_efficient_sam_vits
from os.path import join, isfile, basename
from PIL import Image
from torchvision import transforms
import torch
import numpy as np
import zipfile
import matplotlib.pyplot as plt
import os
import argparse
from torchvision.transforms import ToTensor
import cv2
import glob
from os.path import join, basename
from collections import OrderedDict
from time import time
from datetime import datetime
import pandas as pd
from tqdm import tqdm
import random
np.random.seed(2023)
parser = argparse.ArgumentParser()
parser.add_argument(
    '-i',
    '--input_dir',
    type=str,
    default='./COVID-19-Dataset/Viral Pneumonia/images/',
    # required=True,
    help='root directory of the data',
)
parser.add_argument(
    '-o',
    '--output_dir',
    type=str,
    default='./data/Viral Pneumonia/dd_dd/new_efficientSAM_segs(samll)1/',
    help='directory to save the prediction',
)
parser.add_argument(
    '-png_save_dir',
    type=str,
    default='./overlay/new_efficientsam(samll)_overlay1',
    help='directory to save the overlay image'
)
parser.add_argument(
    '-gt_dir',
    type=str,
    default='./COVID-19-Dataset/Viral Pneumonia/masks/',
    help='directory to save the overlay image'
)
args = parser.parse_args()
data_root = args.input_dir
gt_root = args.gt_dir
pred_save_dir = args.output_dir
png_save_dir = args.png_save_dir
save_folder = "overlay/new_efficientsam(samll)_overlay1"
if not os.path.exists(save_folder):
    os.makedirs(save_folder)
def show_mask(mask, ax, random_color=False):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30 / 255, 144 / 255, 255 / 255, 0.8])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)
def show_points(coords, labels, ax, marker_size=375):
    pos_points = coords[labels == 1]
    neg_points = coords[labels == 0]
    ax.scatter(
        pos_points[:, 0],
        pos_points[:, 1],
        color="green",
        marker="*",
        s=marker_size,
        edgecolor="white",
        linewidth=1.25,
    )
    ax.scatter(
        neg_points[:, 0],
        neg_points[:, 1],
        color="red",
        marker="*",
        s=marker_size,
        edgecolor="white",
        linewidth=1.25,
    )
def show_box(box,ax,color):
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(
        plt.Rectangle((x0, y0), w, h, edgecolor=color, facecolor='none', lw=2)
    )
def show_anns_ours(mask,ax,color):
    ax.set_autoscale_on(False)

    # Create a color mask image with the same dimensions as the mask and 4 color channels (RGBA)
    color_image = np.zeros((*mask.shape, 4))

    # Apply the mask color and alpha to the masked regions
    color_image[..., :3] = color  # Set color
    color_image[..., 3] = mask.astype(float) * 0.7  # Set alpha for masked areas

    ax.imshow(color_image)

random_color = [random.random() for _ in range(3)]

这个模块很重要,img_tensor的维度在输入进model的时候要是1024*1024,原来输入进去的图片是299 * 299,所以image_np用resize变成1024 * 1024

def run_ours_box_or_points(img_path, pts_sampled, pts_labels, model):
    image_np = np.array(Image.open(img_path).convert('L'))
     ***image_np = cv2.resize(
        image_np,
        (1024, 1024),
        interpolation=cv2.INTER_NEAREST
    )*** 
    img_tensor = ToTensor()(image_np)
    #将点样本转换为PyTorch张量,并重塑形状以符合模型的输入要求。
    pts_sampled = torch.reshape(torch.tensor(pts_sampled, dtype=torch.float32), [1, 1, -1, 2])  # 确保输入是float32
    #将点样本的标签转换为PyTorch张量
    pts_labels = torch.reshape(torch.tensor(pts_labels, dtype=torch.float32), [1, 1, -1])  # 确保标签也是float32
    #获取预测的逻辑值(logits)和预测的IoU值。
    predicted_logits, predicted_iou = model(
        img_tensor[None, ...],
        pts_sampled,
        pts_labels,
    )
    sorted_ids = torch.argsort(predicted_iou, dim=-1, descending=True)
    predicted_iou = torch.take_along_dim(predicted_iou, sorted_ids, dim=2)
    predicted_logits = torch.take_along_dim(
        predicted_logits, sorted_ids[..., None, None], dim=2
    )
    # 使用astype(int)将True/False转换为1/0
    # return (torch.ge(predicted_logits[0, 0, 0, :, :], 0).cpu().detach().numpy())
    return (predicted_logits[0, 0, 0, :, :] > 0.5).cpu().detach().numpy()
def process_efficientsam_mask(boolean_array, original_size):
 # 初始化一个全零的数组,用于存储分割掩码
 efficientsam_mask = np.zeros(original_size, dtype=np.uint16)

 # 遍历布尔数组,根据True/False为分割掩码上色
 for idx, (row, col) in enumerate(np.argwhere(boolean_array)):
     efficientsam_mask[row, col] = idx + 1  # 将True的位置标记为对应的索引+1

 return efficientsam_mask
def pad_image(image, target_size=256):
    """
    Pad image to target_size
    Expects a numpy array with shape HxWxC in uint8 format.
    """
    # Pad
    h, w = image.shape[0], image.shape[1]
    padh = target_size - h
    padw = target_size - w
    if len(image.shape) == 3: ## Pad image
        image_padded = np.pad(image, ((0, padh), (0, padw), (0, 0)))
    else: ## Pad gt mask
        image_padded = np.pad(image, ((0, padh), (0, padw)))
    return image_padded
# efficient_sam_vitt_model = build_efficient_sam_vitt()
# efficient_sam_vitt_model.eval()
efficient_sam_vits_model= build_efficient_sam_vits()
efficient_sam_vits_model.eval()
# 使用glob查找所有.png图像和掩码
def process_image_2D(image_path, gt_path):
# for image_path, gt_path in zip(image_paths, gt_paths):
    gt_img = Image.open(gt_path).convert('L')
    gt = np.array(gt_img)
    # 转换为灰度图,确保结果是二维的
    gt = pad_image(gt)  # (256, 256)
    label_ids = np.unique(gt)[1:]
    import random
    gt2D = np.uint8(gt == random.choice(label_ids.tolist()))  # only one label, (256, 256)
    indices = np.where(gt2D > 0)
    if len(indices) == 2:  # 确保只有两个维度
        y_indices, x_indices = indices
    else:
        raise ValueError("gt2D不是一个二维数组,检查pad_image函数和gt2D的生成过程")
    # y_indices, x_indices = np.where(gt2D > 0)
    # 通过np.min和np.max函数计算目标的最小和最大x、y坐标,这些坐标定义了围绕目标的最小边界框
    x1, x2 = np.min(x_indices), np.max(x_indices)
    y1, y2 = np.min(y_indices), np.max(y_indices)
    # add perturbation to bounding box coordinates
    H, W = gt2D.shape
    bbox_shift = 5
    x1 = max(0, x1 - random.randint(0, bbox_shift))
    x2 = min(W, x2 + random.randint(0, bbox_shift))
    y1 = max(0, y1 - random.randint(0, bbox_shift))
    y2 = min(H, y2 + random.randint(0, bbox_shift))
    boxes = np.array([[x1, y1, x2, y2]])

    fig, ax = plt.subplots(1, 2, figsize=(10, 5))
    input_point = np.array([[x1, y1], [x2, y2]])
    input_label = np.array([2, 3])
    img = Image.open(image_path).convert('L')
    #img_1024和img_256为了适应不同的输入和输出
    img_1024 = img.resize((1024, 1024), Image.NEAREST)
    img_256 = img.resize((256, 256), Image.NEAREST)
    image = np.array(img_256)
    show_points(input_point, input_label, ax[0])
    show_box([x1, y1, x2, y2], ax[1],random_color)

    segs = np.zeros(image.shape[:2], dtype=np.uint8)
    name = os.path.basename(image_path)
    if name.endswith('.png'):
        name = name[:-4]

    ax[1].imshow(image,cmap='gray')
    # mask_efficienet_sam_vitt = run_ours_box_or_points(image_path, input_point, input_label, efficient_sam_vitt_model)
    mask_efficient_sam_vits = run_ours_box_or_points(image_path, input_point, input_label, efficient_sam_vits_model)
    image_np = np.array(Image.open(image_path))
    # 获取原始图像的尺寸
    #original_size = image_np.shape[:2]
    original_size = image_1024.shape[:2]
    #process_efficientsam_mask要保证mask_efficient_sam_vits,和original_size的维度要一致
    efficientsam_mask = process_efficientsam_mask(mask_efficient_sam_vits, original_size)
    #因为最后的输出要是256*256,所以最后还要resize
    efficientsam_mask_resized = cv2.resize(efficientsam_mask, (256, 256), interpolation=cv2.INTER_NEAREST)
    segs[efficientsam_mask_resized > 0] = 1
    np.savez_compressed(
        join(pred_save_dir, name),
        segs=segs,
    )
    #这里注意输出的要是efficientsam_mask_resized
   # show_anns_ours(mask_efficient_sam_vits, ax[1],random_color)
   show_anns_ours(efficientsam_mask_resized.astype("uint8"), ax[1],random_color)
    ax[0].imshow(image,cmap='gray')
    ax[0].set_title("Image")
    ax[1].title.set_text("EfficientSAM (VIT-small)")
    ax[1].axis('off')
    ax[0].axis('off')
    plt.tight_layout()
    plt.savefig(join(png_save_dir, name.split(".")[0] + '.png'), dpi=300)
    # plt.show()
    plt.close()
if __name__ == '__main__':
    N=40
    # 使用glob函数搜索data_root和gt_root目录下所有的.png文件,并将它们排序
    img_png_files = sorted(glob.glob(join(data_root, '*.png')))
    gt_png_files = sorted(glob.glob(join(gt_root, '*.png')))
    # 选取前N个文件
    img_png_files = img_png_files[:N]
    gt_png_files = gt_png_files[:N]
    # assert len(img_png_files) == len(gt_png_files), "Images and masks count mismatch"

    # 初始化一个有序字典efficiency,用于存储case和time
    efficiency = OrderedDict([('case', []), ('time', [])])

    # 使用for循环遍历所有.png文件,tqdm用于显示进度条
    for img_png_file, gt_png_file in tqdm(zip(img_png_files, gt_png_files), total=len(img_png_files)):
        start_time = time()  # 开始时间

        # 处理2D图像及其对应的标签
        process_image_2D(img_png_file, gt_png_file)

        end_time = time()  # 结束时间

        # 记录文件处理信息
        efficiency['case'].append(basename(img_png_file))
        efficiency['time'].append(end_time - start_time)

        # 打印当前处理文件的信息
        current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        print(f"{current_time}, file name: {basename(img_png_file)}, time cost: {np.round(end_time - start_time, 4)}")
    # 将效率数据保存到CSV文件
    efficiency_df = pd.DataFrame(efficiency)
    efficiency_df.to_csv(join(pred_save_dir, 'efficiency.csv'), index=False)

  • 4
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

吾在学习路

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值