Segment Anything(SAM)全图分割做mask

项目的源码和模型下载以及环境配置等可参考我的上一篇文章,这里不再赘述。 

文章链接:https://blog.csdn.net/m0_63604019/article/details/130221434

在项目中创建一个名为segment-everything.py的文件,文件中写入如下代码:

import torchvision
import sys
import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor

def show_anns(anns):
    if len(anns) == 0:
        return
    sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
    ax = plt.gca()
    ax.set_autoscale_on(False)
    polygons = []
    color = []
    for ann in sorted_anns:
        m = ann['segmentation']
        img = np.ones((m.shape[0], m.shape[1], 3))
        color_mask = np.random.random((1, 3)).tolist()[0]
        for i in range(3):
            img[:,:,i] = color_mask[i]
        ax.imshow(np.dstack((img, m*0.35)))


image = cv2.imread('B.jpg')    #将B.jpg改为自己的输入图片的路径
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

sam_checkpoint = "sam_vit_h_4b8939.pth"     #改为已下载的模型的存放路径

device = "cuda"     #默认是cuda,如果是用cpu的话就改为cpu
model_type = "default"      #default默认代表的是vit_h模型,可将其改为自己下载的模型名称(vit_h/vit_l/vit_b)

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)

mask_generator = SamAutomaticMaskGenerator(sam)
masks = mask_generator.generate(image)

plt.figure(figsize=(20,20))
plt.imshow(image)
show_anns(masks)
plt.axis('off')
plt.savefig('B_out.jpg')     #此处需填写输出结果的存放路径,B_out代表输出结果的文件名,.jpg表示将以jpg形式存放
plt.show()

然后右键点击【Run 'segment-everyting'】运行segment-everyting.py文件,运行过程可能需要几分钟,请耐心等待。

如果运行时出现如下报错:

OMP: Error #15: Initializing libiomp5md.dll, but found libiomp5md.dll already initialized.
OMP: Hint This means that multiple copies of the OpenMP runtime have been linked into the program. That is dangerous, since it can degrade performance or cause inc
orrect results. The best thing to do is to ensure that only a single OpenMP runtime is linked into the process, e.g. by avoiding static linking of the OpenMP runti
me in any library. As an unsafe, unsupported, undocumented workaround you can set the environment variable KMP_DUPLICATE_LIB_OK=TRUE to allow the program to contin
ue to execute, but that may cause crashes or silently produce incorrect results. For more information, please see http://www.intel.com/software/products/support/.

则在segment-everyting.py文件的顶部加入两行代码:

import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"

我的输入图片是我养的加菲猫(顺便夸一句我的小猪咪真可爱!):

分割后的输出结果:

 感觉效果不是我想要的那种,一只完整的小猫咪都被分得四分五裂了......可能还需要添加一些提示还是啥的,我还没搞懂。

 下面这个是在segment anything的demo(Segment Anything | Meta AI (segment-anything.com))中呈现的效果,这才是我想要的。

### 使用 Segment Anything 模型进行遥感图像分割 #### 准备工作 为了使用 Segment Anything Model (SAM) 进行遥感图像的分割,需先安装必要的库并加载预训练模型。这通常涉及配置环境变量和导入所需的Python包。 ```python import torch from segment_anything import sam_model_registry, SamPredictor import cv2 import numpy as np ``` #### 加载模型与初始化预测器 指定 `image_path` 和 `sam_checkpoint` 参数来指明输入图片的位置以及之前下载好的 SAM 权重文件位置;同时定义使用的设备(CPU 或 GPU)及所选模型版本[^2]。 ```python def load_sam(model_type="vit_b", device='cuda'): checkpoint_url = "path/to/sam_vit_b_01ec64.pth" model = sam_model_registry[model_type](checkpoint=checkpoint_url).to(device=device) predictor = SamPredictor(model) return predictor ``` #### 图像读取与处理 针对特定的应用场景——即地理遥感中的卫星影像,需要确保能够正确解析多光谱或高分辨率的数据集格式,比如GeoTIFF文件。这里假设已经有一个合适的函数可以从给定路径中读入此类图像,并将其转换成适合喂给神经网络的形式。 ```python def read_image(image_path): img = cv2.imread(image_path) img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) return img_rgb ``` #### 执行分割操作 利用上述准备好的工具链,在实际执行阶段只需调用简单的API接口即可完成目标区域的选择与掩码生成过程。对于遥感领域而言,这意味着可以根据兴趣点自动勾勒出边界轮廓,从而辅助后续的空间分析任务[^1]。 ```python predictor = load_sam() img = read_image('sentinel2.tif') predictor.set_image(img) input_point = np.array([[500, 375]]) # 用户交互式点击得到的兴趣点坐标 input_label = np.array([1]) # 对应标签(前景/背景),此处设为前景 masks, scores, logits = predictor.predict( point_coords=input_point, point_labels=input_label, multimask_output=True, ) ``` #### 输出结果保存 最后一步是将获得的结果导出至外部存储介质以便进一步查看或与其他GIS软件集成。此部分涉及到创建Shapefile或其他矢量图形文件的操作。 ```python import shapely.geometry import geopandas as gpd mask = masks[np.argmax(scores)] # 获取最佳匹配度下的二值化蒙版图层 polygons = [] for i in range(mask.shape[-1]): contours, _ = cv2.findContours((mask[:, :, i]*255).astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) polygon = [shapely.geometry.Polygon(contour.reshape(-1, 2)) for contour in contours if len(contour)>2] polygons.extend(polygon) gdf = gpd.GeoDataFrame({'geometry': polygons}) gdf.to_file(filename='output.shp', driver='ESRI Shapefile') ```
评论 7
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

AYu~

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值