Mean-shift data analysis in Python (1)

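A recent project of mine needs the mean-shift [0] algorithm. The first step is to pick a machine learning toolkit that includes mean-shift, preferably an open-source one, so that we can modify it later as needed.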

Here we chose Scikit-learn [1.5], an open-source machine learning toolkit implemented in Python. Its GitHub repository is [2].

We start from the official demo [3].

First, import the modules we need from the corresponding packages:

import numpy as np

from sklearn.cluster import MeanShift, estimate_bandwidth

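from sklearn.datasets import make_blobs

As the names suggest, sklearn.cluster contains clustering algorithms, and sklearn.datasets provides helpers such as make_blobs for generating sample data. (Older scikit-learn releases exposed make_blobs under sklearn.datasets.samples_generator; that module has since been removed, so the import above uses sklearn.datasets directly.)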

Generate the sample data:

centers = [[1, 1], [-1, -1], [1, -1]]

X, _ = make_blobs(n_samples=10000, centers=centers, cluster_std=0.6)
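To make the return value of make_blobs concrete, here is a small sketch. It assumes a fixed random_state (not used in the demo) so that the run is repeatable, and it keeps the second return value, which the demo discards as `_`.

```python
import numpy as np
from sklearn.datasets import make_blobs

centers = [[1, 1], [-1, -1], [1, -1]]

# random_state=0 is an assumption added here for repeatability.
X, y = make_blobs(n_samples=10000, centers=centers,
                  cluster_std=0.6, random_state=0)

print(X.shape)       # (10000, 2): one row per sample, one column per feature
print(np.unique(y))  # [0 1 2]: index of the center each sample was drawn around
```

The array y holds the ground-truth blob index for every sample. Mean-shift is unsupervised, so the demo never uses it, but it is handy for checking the clustering afterwards.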

Then run mean-shift on the sample data:

bandwidth = estimate_bandwidth(X, quantile=0.2, n_samples=500)

ms = MeanShift(bandwidth=bandwidth, bin_seeding=True)

ms.fit(X)

labels = ms.labels_

cluster_centers = ms.cluster_centers_

labels_unique = np.unique(labels)

n_clusters_ = len(labels_unique)

print("number of estimated clusters : %d" % n_clusters_)

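estimate_bandwidth() estimates the size of the mean-shift window. With these arguments it draws 500 samples at random from X and derives the bandwidth from the distances between them. According to the documentation, quantile=0.5 corresponds to the median of all pairwise distances; here the 0.2 quantile is used. The running time is at least quadratic in n_samples, so the call becomes expensive when n_samples is large.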
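The effect of quantile can be checked directly. The sketch below is an illustration rather than part of the official demo: it uses a smaller data set and a fixed random_state, both of which are assumptions made here to keep the run short and repeatable.

```python
import numpy as np
from sklearn.cluster import estimate_bandwidth
from sklearn.datasets import make_blobs

centers = [[1, 1], [-1, -1], [1, -1]]
X, _ = make_blobs(n_samples=2000, centers=centers,
                  cluster_std=0.6, random_state=0)

# Same 500-sample subset each time (random_state=0), only quantile changes.
bandwidths = {}
for q in (0.1, 0.2, 0.5):
    bandwidths[q] = estimate_bandwidth(X, quantile=q, n_samples=500,
                                       random_state=0)
    print("quantile=%.1f -> bandwidth=%.3f" % (q, bandwidths[q]))
```

A larger quantile yields a larger bandwidth, which in turn tends to merge nearby clusters, so quantile is the main knob to adjust when the estimated number of clusters looks wrong.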

np.unique(labels) returns the sorted distinct values that occur in labels; taking len() of that array gives the number of clusters found.
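A tiny example with a hand-written label array (made up for illustration) shows the difference between the distinct values and their count:

```python
import numpy as np

# Hypothetical labels for six points, as MeanShift.labels_ might look.
labels = np.array([2, 0, 1, 0, 2, 2])

values = np.unique(labels)          # sorted distinct labels
n_clusters = len(values)            # how many distinct labels there are

# return_counts=True additionally reports how many points carry each label.
values, counts = np.unique(labels, return_counts=True)
print(values, counts, n_clusters)
```

One caveat: if MeanShift is run with cluster_all=False, the orphan label -1 also shows up among the distinct values, so len(np.unique(labels)) would then count one more than the number of real clusters.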

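The constructor of the MeanShift class is the key part. Its signature is:

MeanShift(bandwidth=None, seeds=None, bin_seeding=False, min_bin_freq=1, cluster_all=True, n_jobs=1)

The parameters are as follows (newer scikit-learn releases also accept max_iter and default n_jobs to None):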

bandwidth: float, Bandwidth used in the RBF (Radial Basis Function) kernel. If not given, the bandwidth is estimated using sklearn.cluster.estimate_bandwidth. (Current scikit-learn documentation describes the kernel as flat rather than RBF.)

seeds: array, shape=[n_samples, n_features], Seeds used to initialize kernels. If not set, the seeds are calculated by clustering.get_bin_seeds with bandwidth as the grid size and default values for other parameters.

bin_seeding: boolean, If true, initial kernel locations are not locations of all points, but rather the location of the discretized version of points, where points are binned onto a grid whose coarseness corresponds to the bandwidth. Setting this option to True will speed up the algorithm because fewer seeds will be initialized. Ignored if seeds argument is not None.

min_bin_freq: int, optional, To speed up the algorithm, accept only those bins with at least min_bin_freq points as seeds, default 1.

cluster_all: If true, then all points are clustered, even those orphans that are not within any kernel. Orphans are assigned to the nearest kernel. If false, then orphans are given cluster label -1.

n_jobs: int, The number of jobs to use for the computation. The mean-shift iterations started from the individual seeds are run in parallel. If -1 all CPUs are used. If 1 is given, no parallel computing code is used at all, which is useful for debugging. For n_jobs below -1, (n_cpus + 1 + n_jobs) are used. Thus for n_jobs = -2, all CPUs but one are used.
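The interplay of bin_seeding, min_bin_freq and cluster_all is easier to see on a toy data set. The sketch below is an assumption-laden illustration, not part of the demo: the data (one tight blob plus a single far-away point) and the fixed bandwidth of 1.0 are made up so that the far-away point falls into a grid bin of its own.

```python
import numpy as np
from sklearn.cluster import MeanShift

rng = np.random.default_rng(0)
blob = rng.normal(loc=0.0, scale=0.1, size=(50, 2))  # tight blob at the origin
outlier = np.array([[100.0, 100.0]])                 # one isolated point
X = np.vstack([blob, outlier])

# min_bin_freq=2 drops the outlier's bin (it holds a single point),
# so no kernel is ever started at the outlier.
common = dict(bandwidth=1.0, bin_seeding=True, min_bin_freq=2)

labels_all = MeanShift(cluster_all=True, **common).fit(X).labels_
labels_orphans = MeanShift(cluster_all=False, **common).fit(X).labels_

print(labels_all[-1])      # outlier assigned to the nearest cluster
print(labels_orphans[-1])  # outlier reported as an orphan
```

With cluster_all=True the isolated point is attached to the only cluster, while with cluster_all=False it receives the label -1, matching the parameter description above.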

Other commonly used methods and attributes of the MeanShift class:

cluster_centers_: array, [n_clusters, n_features]. Coordinates of cluster centers.

labels_: Labels of each point.

fit(X): Perform clustering.
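To give a feel for what fit(X) does for each seed, here is a simplified NumPy sketch of the mean-shift update with a flat kernel. It is not scikit-learn's code: the function name shift_seed, the tolerance, and the test data are all assumptions chosen for illustration, and the real implementation additionally merges nearby centers and assigns labels.

```python
import numpy as np

def shift_seed(X, seed, bandwidth, max_iter=300, tol=1e-3):
    """Move one seed uphill: repeatedly replace it with the mean of the
    points lying within `bandwidth` of it (flat kernel)."""
    center = np.asarray(seed, dtype=float)
    for _ in range(max_iter):
        dist = np.linalg.norm(X - center, axis=1)
        inside = X[dist <= bandwidth]
        if len(inside) == 0:
            break                      # empty window, nothing to average
        new_center = inside.mean(axis=0)
        moved = np.linalg.norm(new_center - center)
        center = new_center
        if moved < tol * bandwidth:
            break                      # the shift has become negligible
    return center

rng = np.random.default_rng(0)
X = rng.normal(loc=[1.0, 1.0], scale=0.3, size=(500, 2))

# Start from an arbitrary data point and climb to the density peak.
mode = shift_seed(X, seed=X[0], bandwidth=1.0)
print(mode)
```

Each step moves the window toward the region where the points are denser, which is why the result ends up near the center of the blob regardless of which data point is used as the seed.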

Finally, plot the clustering result:

# Plot result

import matplotlib.pyplot as plt

from itertools import cycle

plt.figure(1)

plt.clf()

colors = cycle('bgrcmykbgrcmykbgrcmykbgrcmyk')

for k, col in zip(range(n_clusters_), colors):
    my_members = labels == k
    cluster_center = cluster_centers[k]
    plt.plot(X[my_members, 0], X[my_members, 1], col + '.')
    plt.plot(cluster_center[0], cluster_center[1], 'o', markerfacecolor=col,
             markeredgecolor='k', markersize=14)

plt.title('Estimated number of clusters: %d' % n_clusters_)

plt.show()

References

[0] Mean shift: A robust approach toward feature space analysis. D. Comaniciu and P. Meer, IEEE Transactions on Pattern Analysis and Machine Intelligence (2002)

[1.5] Scikit-learn: Machine Learning in Python. Pedregosa et al., JMLR 12, pp. 2825-2830, 2011

[1] tutorial: http://scikit-learn.org/stable/modules/generated/sklearn.cluster.MeanShift.html

[2] source: https://github.com/scikit-learn/scikit-learn

[3] demo: http://scikit-learn.org/stable/auto_examples/cluster/plot_mean_shift.html#sphx-glr-auto-examples-cluster-plot-mean-shift-py
