SLIC学习笔记

超像素分割——SLIC学习

最新看论文的时候发现“超像素分割”概念被多次提及,作为图像预处理的一部分,“超像素分割”可以在保持图像特征不变的情况下,减少后续图像处理的计算量。
这里,将简单介绍一下SLIC(Simple linear iterative clustering)算法,先贴出相关论文和源代码供大家参考。

SLIC算法描述

算法流程:
SLIC
对照上述算法流程图,SLIC算法可以分为下面几步:
1. 颜色空间的转换。作者在论文中提到使用CIELAB空间,而不是RGB空间。原因在于在CIELAB空间分割效果要比RGB空间规则很多。两空之间的转换可以参考http://www.cnblogs.com/Imageshop/archive/2013/02/02/2889897.html
2. 初始化聚类中心Ck,步长为S,Ck为包含五个数据的向量:l,a,b为像素的LAB颜色空间值,x,y为像素的坐标。作者为了防止聚类中心落在图像边缘位置,在聚类中心3X3的邻域内求各像素梯度值,将聚类中心移动到梯度值最小的像素点处。
3. 最重要的一步:迭代寻聚类中心。这里有点类似K-means算法,对于每一个初始化的聚类中心,在其2SX2S的邻域内计算各像素和聚类中心的距离D,当D小于d(i)时,更新像素i的标签,遍历完聚类中心后,重新计算新的聚类中心。这里,因为SLIC的聚类数据来自LAB颜色空间和XY位置坐标,两者的取值范围不一样,作者提出了如下的距离计算公式:
这里写图片描述
这里写图片描述
参数S表示XY空间的最大的可能值,这个通过用户输入的超像素大小可以自动计算出来(详见论文),而参数M为LAB空间的距离可能最大值,论文提出其可取的范围建议为[1,40]。

python代码实现

作者源代码给的是C++版本的,关于C++版本的优化可以参考这篇博客:
https://www.cnblogs.com/Imageshop/p/6193433.html
python版本的代码,可以参考这里:
https://www.kawabangga.com/posts/1923
不过,直接从Github上clone下来的代码运行后,会出现如下的错误:
only integers, slices (:), ellipsis (…), numpy.newaxis (None) and integer or boolean arrays are valid indices
错误原因是程序中的有些h和w不是int型数据,所以我稍微修改了一下,强制转换成int型。
代码如下:

import math
from skimage import io, color
import numpy as np
from tqdm import trange

class Cluster(object):
    cluster_index = 1

    def __init__(self, h, w, l=0, a=0, b=0):
        self.update(h, w, l, a, b)
        self.pixels = []
        self.no = self.cluster_index
        Cluster.cluster_index += 1

    def update(self, h, w, l, a, b):
        self.h = h
        self.w = w
        self.l = l
        self.a = a
        self.b = b

    def __str__(self):
        return "{},{}:{} {} {} ".format(self.h, self.w, self.l, self.a, self.b)

    def __repr__(self):
        return self.__str__()


class SLICProcessor(object):
    @staticmethod
    def open_image(path):
        """
        Return:
            3D array, row col [LAB]
        """
        rgb = io.imread(path)
        lab_arr = color.rgb2lab(rgb)
        return lab_arr

    @staticmethod
    def save_lab_image(path, lab_arr):
        """
        Convert the array to RBG, then save the image
        :param path:
        :param lab_arr:
        :return:
        """
        rgb_arr = color.lab2rgb(lab_arr)
        io.imsave(path, rgb_arr)

    def make_cluster(self, h, w):
        return Cluster(h, w,
                       self.data[h][w][0],
                       self.data[h][w][1],
                       self.data[h][w][2])

    def __init__(self, filename, K, M):
        self.K = K
        self.M = M

        self.data = self.open_image(filename)
        self.image_height = self.data.shape[0]
        self.image_width = self.data.shape[1]
        self.N = self.image_height * self.image_width
        self.S = int(math.sqrt(self.N / self.K))

        self.clusters = []
        self.label = {}
        self.dis = np.full((self.image_height, self.image_width), np.inf)

    def init_clusters(self):
        h = int(self.S / 2)
        w = int(self.S / 2)
        while h < self.image_height:
            while w < self.image_width:
                self.clusters.append(self.make_cluster(h, w))
                w += int(self.S)
            w = int(self.S / 2)
            h += int(self.S)

    def get_gradient(self, h, w):
        if w + 1 >= self.image_width:
            w = self.image_width - 2
        if h + 1 >= self.image_height:
            h = self.image_height - 2

        gradient = self.data[w + 1][h + 1][0] - self.data[w][h][0] + \
                   self.data[w + 1][h + 1][1] - self.data[w][h][1] + \
                   self.data[w + 1][h + 1][2] - self.data[w][h][2]
        return gradient

    def move_clusters(self):
        for cluster in self.clusters:
            cluster_gradient = self.get_gradient(cluster.h, cluster.w)
            for dh in range(-1, 2):
                for dw in range(-1, 2):
                    _h = cluster.h + dh
                    _w = cluster.w + dw
                    new_gradient = self.get_gradient(_h, _w)
                    if new_gradient < cluster_gradient:
                        cluster.update(_h, _w, self.data[_h][_w][0], self.data[_h][_w][1], self.data[_h][_w][2])
                        cluster_gradient = new_gradient

    def assignment(self):
        for cluster in self.clusters:
            for h in range(cluster.h - 2 * self.S, cluster.h + 2 * self.S):
                if h < 0 or h >= self.image_height: continue
                for w in range(cluster.w - 2 * self.S, cluster.w + 2 * self.S):
                    if w < 0 or w >= self.image_width: continue
                    L, A, B = self.data[h][w]
                    Dc = math.sqrt(
                        math.pow(L - cluster.l, 2) +
                        math.pow(A - cluster.a, 2) +
                        math.pow(B - cluster.b, 2))
                    Ds = math.sqrt(
                        math.pow(h - cluster.h, 2) +
                        math.pow(w - cluster.w, 2))
                    D = math.sqrt(math.pow(Dc / self.M, 2) + math.pow(Ds / self.S, 2))
                    if D < self.dis[h][w]:
                        if (h, w) not in self.label:
                            self.label[(h, w)] = cluster
                            cluster.pixels.append((h, w))
                        else:
                            self.label[(h, w)].pixels.remove((h, w))
                            self.label[(h, w)] = cluster
                            cluster.pixels.append((h, w))
                        self.dis[h][w] = D

    def update_cluster(self):
        for cluster in self.clusters:
            sum_h = sum_w = number = 0
            for p in cluster.pixels:
                sum_h += p[0]
                sum_w += p[1]
                number += 1
                _h = int(sum_h / number)
                _w = int(sum_w / number)
                cluster.update(_h, _w, self.data[_h][_w][0], self.data[_h][_w][1], self.data[_h][_w][2])

    def save_current_image(self, name):
        image_arr = np.copy(self.data)
        for cluster in self.clusters:
            for p in cluster.pixels:
                image_arr[p[0]][p[1]][0] = cluster.l
                image_arr[p[0]][p[1]][1] = cluster.a
                image_arr[p[0]][p[1]][2] = cluster.b
            image_arr[cluster.h][cluster.w][0] = 0
            image_arr[cluster.h][cluster.w][1] = 0
            image_arr[cluster.h][cluster.w][2] = 0
        self.save_lab_image(name, image_arr)

    def iterate_10times(self):
        self.init_clusters()
        self.move_clusters()
        for i in trange(10):
            self.assignment()
            self.update_cluster()
            name = 'lenna_M{m}_K{k}_loop{loop}.png'.format(loop=i, m=self.M, k=self.K)
            self.save_current_image(name)


if __name__ == '__main__':
    p = SLICProcessor('Lenna.png', 200, 40)
    p.iterate_10times()

代码运行比较慢,花了我大概一分钟左右,如果是用于工程项目中的话,还需要好好优化一下。

  • 3
    点赞
  • 24
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值