超像素(superpixels)分割 SLIC算法

自编写code实现

import math
from skimage import io, color
import numpy as np
from tqdm import trange

class Cluster(object):
    cluster_index = 1

    def __init__(self, h, w, l=0, a=0, b=0):
        self.update(h, w, l, a, b)
        self.pixels = []
        self.no = self.cluster_index
        Cluster.cluster_index += 1

    def update(self, h, w, l, a, b):
        self.h = h
        self.w = w
        self.l = l
        self.a = a
        self.b = b

    def __str__(self): # 聚类的格式
        return "{},{}:{} {} {} ".format(self.h, self.w, self.l, self.a, self.b)

    def __repr__(self):
        return self.__str__()


class SLICProcessor(object):
    @staticmethod
    def open_image(path):
        """
        Return:
            3D array, row col [LAB]
        """
        rgb = io.imread(path)
        lab_arr = color.rgb2lab(rgb)
        print("lab_arr.shape:",lab_arr.shape) # (512, 512, 3)
        return lab_arr

    @staticmethod
    def save_lab_image(path, lab_arr):
        """
        Convert the array to RBG, then save the image
        :param path:
        :param lab_arr:
        :return:
        """
        rgb_arr = color.lab2rgb(lab_arr)
        io.imsave(path, rgb_arr)


    def make_cluster(self, h, w):
        h = int(h)
        w = int(w)
        return Cluster(h, w,
                       self.data[h][w][0],
                       self.data[h][w][1],
                       self.data[h][w][2])

    def __init__(self, filename, K, M):
        self.K = K
        self.M = M

        # 先转化为LAB 再次转化为X,Y, LAB
        self.data = self.open_image(filename)  # (512, 512, 3) LAB
        self.image_height = self.data.shape[0]
        self.image_width = self.data.shape[1]
        self.N = self.image_height * self.image_width
        self.S = int(math.sqrt(self.N / self.K))
        print("self.S",self.S)

        self.clusters = []
        self.label = {} # 存储像素--->标签
        self.dis = np.full((self.image_height, self.image_width), np.inf)

    def init_clusters(self):
        h = self.S / 2
        w = self.S / 2
        while h < self.image_height:
            while w < self.image_width:
                self.clusters.append(self.make_cluster(h, w))
                w += self.S
            w = self.S / 2
            h += self.S
        print("self.clusters:",self.clusters, len(self.clusters), type(self.clusters))
        print()

    def get_gradient(self, h, w):
        if w + 1 >= self.image_width:
            w = self.image_width - 2
        if h + 1 >= self.image_height:
            h = self.image_height - 2

        gradient = self.data[h + 1][w + 1][0] - self.data[h][w][0] + \
                   self.data[h + 1][w + 1][1] - self.data[h][w][1] + \
                   self.data[h + 1][w + 1][2] - self.data[h][w][2]
        return gradient

# 找到3*3 范围内梯度最小的点,作为初始聚类中心
    def move_clusters(self):
        for cluster in self.clusters:
            cluster_gradient = self.get_gradient(cluster.h, cluster.w)
            for dh in range(-1, 2): # -1,0,1  3*3范围内
                for dw in range(-1, 2):
                    _h = cluster.h + dh
                    _w = cluster.w + dw
                    new_gradient = self.get_gradient(_h, _w)
                    if new_gradient < cluster_gradient:
                        cluster.update(_h, _w, self.data[_h][_w][0], self.data[_h][_w][1], self.data[_h][_w][2])
                        cluster_gradient = new_gradient

    # 在聚类中心 2S*2S 范围内进行计算距离
    def assignment(self):
        for cluster in self.clusters:
            for h in range(cluster.h - 2 * self.S, cluster.h + 2 * self.S):
                if h < 0 or h >= self.image_height: continue
                for w in range(cluster.w - 2 * self.S, cluster.w + 2 * self.S):
                    if w < 0 or w >= self.image_width: continue
                    L, A, B = self.data[h][w]
                    Dc = math.sqrt(
                        math.pow(L - cluster.l, 2) +
                        math.pow(A - cluster.a, 2) +
                        math.pow(B - cluster.b, 2))
                    Ds = math.sqrt(
                        math.pow(h - cluster.h, 2) +
                        math.pow(w - cluster.w, 2))
                    D = math.sqrt(math.pow(Dc / self.M, 2) + math.pow(Ds / self.S, 2))
                    if D < self.dis[h][w]:
                        if (h, w) not in self.label:  # 新标签(该像素点没有分配超像素中心)
                            self.label[(h, w)] = cluster
                            cluster.pixels.append((h, w))# 该聚类中心包含的像素坐标
                        else:
                            self.label[(h, w)].pixels.remove((h, w))# 已经分配超像素中心,把原来的移除,添加新的超像素标签
                            self.label[(h, w)] = cluster
                            cluster.pixels.append((h, w))# 该聚类中心包含的像素坐标
                        self.dis[h][w] = D

    def update_cluster(self):
        for cluster in self.clusters:
            sum_h = sum_w = number = 0
            for p in cluster.pixels:
                print("PPP:",p) # x,y
                sum_h += p[0]
                sum_w += p[1]
                number += 1
            _h = int(sum_h / number)
            _w = int(sum_w / number)
            cluster.update(_h, _w, self.data[_h][_w][0], self.data[_h][_w][1], self.data[_h][_w][2])


    # 保存最终的分类图像结果
    def save_current_image(self, name):
        image_arr = np.copy(self.data) #lab_arr # (512, 512, 3) LAB
        for cluster in self.clusters:
            for p in cluster.pixels: #(p[0],p[1]) 是图像的坐标
                image_arr[p[0]][p[1]][0] = cluster.l
                image_arr[p[0]][p[1]][1] = cluster.a
                image_arr[p[0]][p[1]][2] = cluster.b
            image_arr[cluster.h][cluster.w][0] = 0
            image_arr[cluster.h][cluster.w][1] = 0
            image_arr[cluster.h][cluster.w][2] = 0
        # print("image_arr:", image_arr.shape)  # (512, 512, 3)
        self.save_lab_image(name, image_arr)  # lab_arr 转化为RBG 再进行保存

    def iterate_10times(self):
        self.init_clusters() # 初始化一些参数
        self.move_clusters() # 调节初始化的聚类中心
        # 迭代开始
        for i in range(1):
            self.assignment()
            self.update_cluster()
            name = 'lenna_M{m}_K{k}_loop{loop}.png'.format(loop=i, m=self.M, k=self.K)
            self.save_current_image(name)


if __name__ == '__main__':
    p = SLICProcessor('Lenna.png', 200, 40)
    p.iterate_10times()

    # p = SLICProcessor('Lenna.png', 300, 40)
    # p.iterate_10times()
    # p = SLICProcessor('Lenna.png', 500, 40)
    # p.iterate_10times()
    # p = SLICProcessor('Lenna.png', 1000, 40)
    # p.iterate_10times()
    # p = SLICProcessor('Lenna.png', 200, 5)
    # p.iterate_10times()
    # p = SLICProcessor('Lenna.png', 300, 5)
    # p.iterate_10times()
    # p = SLICProcessor('Lenna.png', 500, 5)
    # p.iterate_10times()
    # p = SLICProcessor('Lenna.png', 1000, 5)
    # p.iterate_10times()

请添加图片描述

参考:https://blog.csdn.net/haoji007/article/details/103432879

官方实现

from skimage import io
import matplotlib.pyplot as plt
from skimage.segmentation import slic
from skimage.util import img_as_float
from skimage.segmentation import mark_boundaries

x1 = "./rock.jpg"
image = img_as_float(io.imread(x1))
segments = slic(image, n_segments=20, sigma=5)
print("image:",image.shape) # image: (500, 560, 3)
print("segment:",segments.dtype, segments.shape) # segment: int64 (500, 560)
print("segment10:",segments[0:,:10]) # segment: (500, 560) 得到分类好的像素

# plt.axis("off")
plt.subplot(131)
plt.title('image',fontsize = 15)
plt.imshow(image)
plt.subplot(132)
plt.title('segments',fontsize = 15)
# plt.imshow(segments, cmap="gray")
plt.imshow(segments)

plt.subplot(133)
plt.title('image and segments',fontsize = 15)
plt.imshow(mark_boundaries(image, segments)) # 联通性标记
plt.show()
  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
以下是SLIC像素分割的MATLAB代码: ``` %% SLIC像素分割 clc; clear all; close all; % 读取图像 img = imread('lena.bmp'); figure; imshow(img); title('原始图像'); % 设置参数 num_superpixels = 1000; % 像素数量 compactness = 10; % 像素紧密度,越大则像素更规则 % 计算步长 [h, w, ~] = size(img); step = sqrt(h*w/num_superpixels); % 初始化像素分割结果 labels = zeros(h, w); % 初始化像素中心 centers = step/2:step:w; centers = repmat(centers, [ceil(h/step), 1]); centers = centers(1:h, :); % 迭代优化 for i = 1:10 % 计算像素中心所在的网格位置 gridx = floor(centers(:)/step)+1; gridy = floor((1:h)'/step)+1; % 扩展图像边界 img_ext = padarray(img, [step, step], 'symmetric', 'both'); gridx_ext = padarray(gridx, [step, step], 'symmetric', 'both'); gridy_ext = padarray(gridy, [step, step], 'symmetric', 'both'); labels_ext = padarray(labels, [step, step], 'symmetric', 'both'); % 计算每个像素中心附近的像素点 for j = 1:num_superpixels % 确定像素中心的位置 cx = centers(j); cy = find(gridy_ext(:, cx) == j, 1, 'first'); cy = cy - step; % 计算像素中心周围的像素点 x1 = max(cx-step, 1); x2 = min(cx+step, w)+step; y1 = max(cy-step, 1); y2 = min(cy+step, h)+step; pixels = img_ext(y1:y2, x1:x2, :); labels_pixels = labels_ext(y1:y2, x1:x2); [yy, xx] = find(labels_pixels == j); pixels = pixels(min(yy):max(yy), min(xx):max(xx), :); labels_pixels = labels_pixels(min(yy):max(yy), min(xx):max(xx)); % 计算每个像素点与像素中心的距离 [h_p, w_p, ~] = size(pixels); dists = zeros(h_p, w_p); for k = 1:h_p for l = 1:w_p dists(k, l) = sqrt((k-yy(1))^2 + (l-xx(1))^2) + sqrt((pixels(k, l, 1)-pixels(yy(1), xx(1), 1))^2 + (pixels(k, l, 2)-pixels(yy(1), xx(1), 2))^2 + (pixels(k, l, 3)-pixels(yy(1), xx(1), 3))^2)/compactness; end end % 更新像素点的标签 labels_pixels_new = labels_pixels; [~, ind] = sort(dists(:)); ind = ind(1:numel(yy)); for k = 1:numel(yy) [y, x] = ind2sub([h_p, w_p], ind(k)); labels_pixels_new(yy(k), xx(k)) = labels_pixels(y, x); end % 更新像素标签 labels_ext(y1:y2, x1:x2) = labels_pixels_new; end % 缩小图像边界 labels = labels_ext(step+1:h+step, step+1:w+step); % 更新像素中心 for j = 1:num_superpixels [yy, xx] = find(labels == j); centers(j, :) = [mean(xx), mean(yy)]; end end % 显示像素分割结果 figure; imshow(labels, []); title('像素分割结果'); ``` 该代码实现了SLIC像素分割算法,包括计算像素中心、计算每个像素周围的像素点、计算像素点与像素中心的距离、更新像素点的标签和更新像素中心等步骤。其中,使用了MATLAB自带的`padarray`函数对图像边界进行了扩展和缩小操作。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值