机器学习算法与Python实践(13) - 均值漂移聚类(Mean-Shift Clustering)

机器学习算法与Python实践(13) - 均值漂移聚类 Mean-Shift Clustering

其实相信很多人多少都已经接触过这种聚类的方法,这篇文章也是参考别人的做的总结,也算是加深自己印象的一个笔记。

一、算法概述

Mean Shift算法,又称为均值漂移算法,Mean Shift的概念最早是由Fukunage在1975年提出的,在后来由Yizong Cheng对其进行扩充,主要提出了两点的改进:

  • 定义了核函数;
  • 增加了权重系数。

核函数的定义使得偏移值对偏移向量的贡献随之样本与被偏移点的距离的不同而不同。权重系数使得不同样本的权重不同。Mean Shift算法在聚类,图像平滑、分割以及视频跟踪等方面有广泛的应用。

二、算法核心原理

2.1 核函数

在Mean Shift算法中引入核函数的目的是使得随着样本与被偏移点的距离不同,其偏移量对均值偏移向量的贡献也不同。核函数是机器学习中常用的一种方式。核函数的定义如下所示:

X X X表示一个 d d d维的欧式空间, x x x是该空间中的一个点 x = x 1 , x 2 , x 3 ⋯ , x d x={x_1,x_2,x_3⋯,x_d} x=x1,x2,x3,xd其中, x x x的模 ‖ x ‖ 2 = x x T ‖x‖^2=xx^T x2=xxT R R R实数域,如果一个函数 K : X → R K:X→R K:XR存在一个剖面函数 k : [ 0 , ∞ ] → R k:[0,∞]→R k:[0,]R,即
K ( x ) = k ( ‖ x ‖ 2 ) K(x)=k(‖x‖^2) K(x)=k(x2)
并且满足:
(1) k k k是非负的
(2) k k k是非增的
(3) k k k是分段连续的
那么,函数 K ( x ) K(x) K(x)就称为核函数。

常用的核函数有高斯核函数。高斯核函数如下:

N ( x ) = 1 2 π h e − x 2 2 h 2 N(x)=\frac{1}{\sqrt{2\pi}h}e^{-\frac{x^2}{2h^2}} N(x)=2π h1e2h2x2

其中, h h h称为带宽(bandwidth),不同带宽的核函数如下图所示:

在这里插入图片描述

上图的画图脚本如下所示:

import matplotlib.pyplot as plt
import math

def cal_Gaussian(x, h=1):
    molecule = x * x
    denominator = 2 * h * h
    left = 1 / (math.sqrt(2 * math.pi) * h)
    return left * math.exp(-molecule / denominator)

x = []

for i in xrange(-40,40):
    x.append(i * 0.5);

score_1 = []
score_2 = []
score_3 = []
score_4 = []

for i in x:
    score_1.append(cal_Gaussian(i,1))
    score_2.append(cal_Gaussian(i,2))
    score_3.append(cal_Gaussian(i,3))
    score_4.append(cal_Gaussian(i,4))

plt.plot(x, score_1, 'b--', label="h=1")
plt.plot(x, score_2, 'k--', label="h=2")
plt.plot(x, score_3, 'g--', label="h=3")
plt.plot(x, score_4, 'r--', label="h=4")

plt.legend(loc="upper right")
plt.xlabel("x")
plt.ylabel("N")
plt.show()
2.2 Mean Shift 算法核心思想
2.21 基本原理

对于Mean Shift算法,是一个迭代的步骤,即先算出当前点的偏移均值,将该点移动到此偏移均值,然后以此为新的起始点,继续移动,直到满足最终的条件。此过程可由下图的过程进行说明(图片来自参考文献3):

  • 步骤1:在指定的区域内计算偏移均值(如下图的黄色的圈)

在这里插入图片描述

  • 步骤2:移动该点到偏移均值点处
    在这里插入图片描述

  • 步骤3: 重复上述的过程(计算新的偏移均值,移动)
    在这里插入图片描述


在这里插入图片描述


在这里插入图片描述


在这里插入图片描述

  • 步骤4:满足了最终的条件,即退出
    在这里插入图片描述

从上述过程可以看出,在Mean Shift算法中,最关键的就是计算每个点的偏移均值,然后根据新计算的偏移均值更新点的位置。

2.22 基本Mean Shift向量形式

对于给定的 d d d维空间 R d R^d Rd中的 n n n个样本点 x i x_i xi i = 1 , 2 , . . . , n i=1,2,...,n i=1,2,...,n,则对于 x x x点,其mean shift向量的基本形式为:

M h ( x ) = 1 k ∑ x i ∈ S h ( x i − x ) M_h(x)=\frac{1}{k}\sum_{x_i\in{S_h}}(x_i-x) Mh(x)=k1xiSh(xix)

其中 S h S_h Sh指的是一个半径为 h h h的高维球区域,如上图中的蓝色圆形区域。 S h S_h Sh定义为:

S h ( x ) = ( y ∣ ( y − x ) ( y − x ) T ⩽ h 2 ) S_h(x)=(y∣(y−x)(y−x)^T⩽h^2) Sh(x)=(y(yx)(yx)Th2)

这样的一种基本的Mean Shift形式存在一个问题:在 S h S_h Sh的区域内,每一个点对x的贡献是一样的。而实际上,这种贡献与x到每一个点之间的距离是相关的。同时,对于每一个样本,其重要程度也是不一样的。

官网上给的例子:

import numpy as np
from sklearn.cluster import MeanShift, estimate_bandwidth
from sklearn.datasets.samples_generator import make_blobs

# #############################################################################
# Generate sample data  造用于聚类的数据
centers = [[1, 1], [-1, -1], [1, -1]]
X, _ = make_blobs(n_samples=10000, centers=centers, cluster_std=0.6)

# #############################################################################
# Compute clustering with MeanShift

# The following bandwidth can be automatically detected using
bandwidth = estimate_bandwidth(X, quantile=0.2, n_samples=500)

ms = MeanShift(bandwidth=bandwidth, bin_seeding=True)
ms.fit(X)  # 训练模型
labels = ms.labels_  # 所有点的的labels
cluster_centers = ms.cluster_centers_  # 聚类得到的中心点

labels_unique = np.unique(labels)
n_clusters_ = len(labels_unique)

print("number of estimated clusters : %d" % n_clusters_)

# #############################################################################
# Plot result
import matplotlib.pyplot as plt
from itertools import cycle

plt.figure(1)
plt.clf()

colors = cycle('bgrcmykbgrcmykbgrcmykbgrcmyk')
for k, col in zip(range(n_clusters_), colors):
    my_members = labels == k
    cluster_center = cluster_centers[k]
    print(cluster_center)
    plt.plot(X[my_members, 0], X[my_members, 1], col + '.')
    plt.plot(cluster_center[0], cluster_center[1], 'o', markerfacecolor=col,
             markeredgecolor='k', markersize=14)
plt.title('Estimated number of clusters: %d' % n_clusters_)
plt.show()

未完待续 …
更多详情请看:
https://blog.csdn.net/google19890102/article/details/51030884

  • 9
    点赞
  • 62
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值