# Hand-written implementation (自己编写代码实现)
import math
from skimage import io, color
import numpy as np
from tqdm import trange
class Cluster(object):
    """One cluster centre together with the pixels currently assigned to it.

    Attributes:
        h, w: row / column of the centre.
        l, a, b: the three colour components stored for the centre.
        pixels: list of (h, w) tuples belonging to this cluster.
        no: sequence number taken from the class-wide counter at creation.
    """

    # Class-wide counter; each new instance takes the current value, then bumps it.
    cluster_index = 1

    def __init__(self, h, w, l=0, a=0, b=0):
        """Create a cluster centred at (h, w) with colour (l, a, b)."""
        self.pixels = []
        self.no = self.cluster_index
        Cluster.cluster_index += 1
        self.update(h, w, l, a, b)

    def update(self, h, w, l, a, b):
        """Overwrite the centre position and colour of this cluster."""
        self.h, self.w = h, w
        self.l, self.a, self.b = l, a, b

    def __str__(self):
        """Format the cluster as "h,w:l a b " (note the trailing space)."""
        fields = (self.h, self.w, self.l, self.a, self.b)
        return "{},{}:{} {} {} ".format(*fields)

    def __repr__(self):
        """Use the same text as ``__str__`` so lists of clusters print readably."""
        return str(self)
class SLICProcessor(object):
    """SLIC-style superpixel clustering of an image in LAB colour space.

    Typical use: construct with an image path, ``K`` and ``M``, then call
    :meth:`iterate_10times`, which seeds the cluster centres, refines them
    and writes one PNG per iteration.

    Attributes:
        K: value used to derive the grid step ``S = int(sqrt(N / K))``.
        M: divisor applied to the colour distance in :meth:`assignment`.
        data: LAB image array indexed as ``[row][col][channel]``.
        image_height, image_width: first two dimensions of ``data``.
        N: number of pixels (``image_height * image_width``).
        S: grid step between initial cluster centres (at least 1).
        clusters: list of ``Cluster`` objects.
        label: dict mapping ``(h, w)`` to the ``Cluster`` owning that pixel.
        dis: per-pixel best distance found so far (initialised to ``inf``).
    """

    @staticmethod
    def open_image(path):
        """
        Read an image file and convert it from RGB to LAB.

        :param path: path of the image to read
        :return: 3D array indexed as [row][col][LAB channel]
        """
        rgb = io.imread(path)
        lab_arr = color.rgb2lab(rgb)
        print("lab_arr.shape:", lab_arr.shape)
        return lab_arr

    @staticmethod
    def save_lab_image(path, lab_arr):
        """
        Convert a LAB array to RGB and save it as an image.

        :param path: output file path
        :param lab_arr: LAB image array
        :return: None
        """
        rgb_arr = color.lab2rgb(lab_arr)
        # Write 8-bit data explicitly: handing a float array to the image
        # writer relies on plugin-specific float handling (lossy-conversion
        # warnings or outright errors, depending on the installed backend).
        rgb_u8 = np.round(np.clip(rgb_arr, 0.0, 1.0) * 255).astype(np.uint8)
        io.imsave(path, rgb_u8)

    def make_cluster(self, h, w):
        """
        Build a ``Cluster`` centred on pixel (h, w) with that pixel's colour.

        :param h: row (truncated to int)
        :param w: column (truncated to int)
        :return: the new ``Cluster``
        """
        h = int(h)
        w = int(w)
        pixel = self.data[h][w]
        return Cluster(h, w, pixel[0], pixel[1], pixel[2])

    def __init__(self, filename, K, M):
        """
        Load the image and prepare the clustering state.

        :param filename: path of the image to segment
        :param K: value used to derive the grid step ``S``
        :param M: divisor for the colour distance in the combined distance
        """
        self.K = K
        self.M = M
        # The image is converted to LAB once; positions are the array indices.
        self.data = self.open_image(filename)
        self.image_height = self.data.shape[0]
        self.image_width = self.data.shape[1]
        self.N = self.image_height * self.image_width
        # Clamp to 1: with K > N the truncated step would be 0, which makes
        # the seeding grid degenerate and divides by zero in assignment().
        self.S = max(1, int(math.sqrt(self.N / self.K)))
        print("self.S", self.S)
        self.clusters = []
        self.label = {}  # pixel (h, w) -> owning Cluster
        self.dis = np.full((self.image_height, self.image_width), np.inf)

    def init_clusters(self):
        """Seed cluster centres on a regular grid with step ``S``, offset by S/2."""
        start = self.S // 2
        for h in range(start, self.image_height, self.S):
            for w in range(start, self.image_width, self.S):
                self.clusters.append(self.make_cluster(h, w))
        print("self.clusters:", self.clusters, len(self.clusters), type(self.clusters))
        print()

    def get_gradient(self, h, w):
        """
        Return a gradient magnitude estimate at pixel (h, w).

        The estimate is the sum of absolute per-channel differences between
        the pixel and its lower-right diagonal neighbour. Coordinates are
        clamped so that both pixels lie inside the image.

        :param h: row
        :param w: column
        :return: non-negative gradient value
        """
        # Clamp on both sides: the neighbour at (h + 1, w + 1) must exist and
        # negative indices must not wrap around to the opposite border.
        h = min(max(h, 0), self.image_height - 2)
        w = min(max(w, 0), self.image_width - 2)
        diff = self.data[h + 1][w + 1][:3] - self.data[h][w][:3]
        # Absolute values: a signed sum would rank a strong negative edge as
        # the "smallest" gradient.
        return np.abs(diff).sum()

    def move_clusters(self):
        """Move each centre to the lowest-gradient pixel of its 3x3 neighbourhood."""
        for cluster in self.clusters:
            # Scan around the ORIGINAL centre so the window does not drift
            # while better candidates are being found.
            center_h, center_w = cluster.h, cluster.w
            best_h, best_w = center_h, center_w
            best_gradient = self.get_gradient(center_h, center_w)
            for dh in (-1, 0, 1):
                for dw in (-1, 0, 1):
                    _h = center_h + dh
                    _w = center_w + dw
                    # Skip candidates outside the image; they cannot be centres.
                    if not (0 <= _h < self.image_height and 0 <= _w < self.image_width):
                        continue
                    new_gradient = self.get_gradient(_h, _w)
                    if new_gradient < best_gradient:
                        best_gradient = new_gradient
                        best_h, best_w = _h, _w
            if (best_h, best_w) != (center_h, center_w):
                pixel = self.data[best_h][best_w]
                cluster.update(best_h, best_w, pixel[0], pixel[1], pixel[2])

    def assignment(self):
        """
        Assign pixels to the nearest cluster centre.

        For every cluster, each pixel within ``2 * S`` rows/columns of the
        centre is tested; the pixel is (re)labelled when the combined
        colour/spatial distance beats the best distance recorded in ``dis``.
        Afterwards every cluster's ``pixels`` list is rebuilt from ``label``.
        """
        reach = 2 * self.S
        for cluster in self.clusters:
            h_lo = max(cluster.h - reach, 0)
            h_hi = min(cluster.h + reach, self.image_height)
            w_lo = max(cluster.w - reach, 0)
            w_hi = min(cluster.w + reach, self.image_width)
            for h in range(h_lo, h_hi):
                for w in range(w_lo, w_hi):
                    L, A, B = self.data[h][w]
                    Dc = math.sqrt(
                        (L - cluster.l) ** 2 +
                        (A - cluster.a) ** 2 +
                        (B - cluster.b) ** 2)
                    Ds = math.sqrt(
                        (h - cluster.h) ** 2 +
                        (w - cluster.w) ** 2)
                    D = math.sqrt((Dc / self.M) ** 2 + (Ds / self.S) ** 2)
                    if D < self.dis[h][w]:
                        self.label[(h, w)] = cluster
                        self.dis[h][w] = D
        # Rebuild membership in one pass instead of list.remove() per
        # reassigned pixel, which is linear in the size of the old cluster.
        for cluster in self.clusters:
            cluster.pixels = []
        for pixel, owner in self.label.items():
            owner.pixels.append(pixel)

    def update_cluster(self):
        """Move each centre to the mean position of its pixels and refresh its colour."""
        for cluster in self.clusters:
            number = len(cluster.pixels)
            if number == 0:
                # Nothing assigned: keep the previous centre rather than
                # dividing by zero.
                continue
            _h = sum(p[0] for p in cluster.pixels) // number
            _w = sum(p[1] for p in cluster.pixels) // number
            pixel = self.data[_h][_w]
            cluster.update(_h, _w, pixel[0], pixel[1], pixel[2])

    def save_current_image(self, name):
        """
        Save the current segmentation as an image.

        Every pixel is painted with its cluster's colour and each cluster
        centre is marked by setting its three channels to 0.

        :param name: output file path
        """
        image_arr = np.copy(self.data)
        for cluster in self.clusters:
            for p in cluster.pixels:  # p is an (h, w) image coordinate
                image_arr[p[0]][p[1]][0] = cluster.l
                image_arr[p[0]][p[1]][1] = cluster.a
                image_arr[p[0]][p[1]][2] = cluster.b
            image_arr[cluster.h][cluster.w][0] = 0
            image_arr[cluster.h][cluster.w][1] = 0
            image_arr[cluster.h][cluster.w][2] = 0
        self.save_lab_image(name, image_arr)  # converted to RGB before saving

    def iterate_10times(self, iterations=1):
        """
        Seed the clusters, then run assignment/update rounds, saving each result.

        NOTE: despite the method name, the default is a single round, which
        is what this method has always executed; pass ``iterations=10`` for
        ten rounds.

        :param iterations: number of assignment/update rounds to run
        """
        self.init_clusters()  # place centres on the regular grid
        self.move_clusters()  # nudge centres to low-gradient positions
        for i in range(iterations):
            self.assignment()
            self.update_cluster()
            name = 'lenna_M{m}_K{k}_loop{loop}.png'.format(loop=i, m=self.M, k=self.K)
            self.save_current_image(name)
if __name__ == '__main__':
    # Run the segmentation on Lenna.png with K=200 and M=40.
    # Other settings that were tried here and left disabled:
    # K in (300, 500, 1000) with M=40, and K in (200, 300, 500, 1000) with M=5.
    p = SLICProcessor('Lenna.png', K=200, M=40)
    p.iterate_10times()
# Reference: https://blog.csdn.net/haoji007/article/details/103432879
# Official implementation (skimage.segmentation.slic)
from skimage import io
import matplotlib.pyplot as plt
from skimage.segmentation import slic
from skimage.util import img_as_float
from skimage.segmentation import mark_boundaries
# Load the input image as floating point and segment it with skimage's SLIC.
x1 = "./rock.jpg"
image = img_as_float(io.imread(x1))
segments = slic(image, n_segments=20, sigma=5)

# Debug output; the original author noted e.g. image (500, 560, 3) and an
# int64 label map of shape (500, 560) for their test picture.
print("image:",image.shape)
print("segment:",segments.dtype, segments.shape)
print("segment10:",segments[0:,:10])  # first 10 columns of the per-pixel labels

# One row of three panels: the input, the raw label map (default colormap;
# cmap="gray" is an alternative), and the image with segment boundaries drawn.
panels = (
    (131, 'image', image),
    (132, 'segments', segments),
    (133, 'image and segments', mark_boundaries(image, segments)),
)
for position, caption, picture in panels:
    plt.subplot(position)
    plt.title(caption, fontsize=15)
    plt.imshow(picture)
plt.show()