MeanShift Clustering 02: A Python Example

Intro

  A worked example of using MeanShift.

Loading the data

from sklearn.cluster import MeanShift, estimate_bandwidth
import matplotlib.pyplot as plt
from itertools import cycle
import numpy as np
import warnings
warnings.filterwarnings("ignore")
%matplotlib inline
from sklearn.datasets import load_iris
import pandas as pd
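pd.set_option('display.max_rows', 500) # maximum number of rows to print
pd.set_option('display.max_columns', 500) # maximum number of columns to print
# check_array validates the input and converts it to an array if it is not one already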
from sklearn.utils import check_array
from sklearn.utils import check_random_state
from sklearn.neighbors import NearestNeighbors
from sklearn.utils import gen_batches
iris_df = pd.DataFrame(
    load_iris()["data"],
    columns=["sepal_length", "sepal_width", "petal_length", "petal_width"])
iris_df["target"] = load_iris()["target"]
iris_df.head()
   sepal_length  sepal_width  petal_length  petal_width  target
0           5.1          3.5           1.4          0.2       0
1           4.9          3.0           1.4          0.2       0
2           4.7          3.2           1.3          0.2       0
3           4.6          3.1           1.5          0.2       0
4           5.0          3.6           1.4          0.2       0
iris_df.groupby(by="target").describe()
feature       target  count   mean       std  min    25%   50%    75%  max
sepal_length       0   50.0  5.006  0.352490  4.3  4.800   5.0    5.2  5.8
sepal_length       1   50.0  5.936  0.516171  4.9  5.600   5.9    6.3  7.0
sepal_length       2   50.0  6.588  0.635880  4.9  6.225   6.5    6.9  7.9
sepal_width        0   50.0  3.428  0.379064  2.3  3.200   3.4  3.675  4.4
sepal_width        1   50.0  2.770  0.313798  2.0  2.525   2.8  3.000  3.4
sepal_width        2   50.0  2.974  0.322497  2.2  2.800   3.0  3.175  3.8
petal_length       0   50.0  1.462  0.173664  1.0    1.4  1.50  1.575  1.9
petal_length       1   50.0  4.260  0.469911  3.0    4.0  4.35  4.600  5.1
petal_length       2   50.0  5.552  0.551895  4.5    5.1  5.55  5.875  6.9
petal_width        0   50.0  0.246  0.105386  0.1    0.2   0.2    0.3  0.6
petal_width        1   50.0  1.326  0.197753  1.0    1.2   1.3    1.5  1.8
petal_width        2   50.0  2.026  0.274650  1.4    1.8   2.0    2.3  2.5

The summary shows that petal_length and petal_width differ the most across the three species, so these two features are used for plotting.
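One rough way to put a number on that observation is sketched below. The score (range of the class means divided by the average within-class standard deviation) is an ad hoc heuristic chosen for illustration, not something the original analysis used, and the names `class_stats` and `separation` are hypothetical. The means and standard deviations are copied from the describe() output above.

```python
# Per-class mean and std for each feature, copied from the describe() output
# (order: target 0, 1, 2).
class_stats = {
    "sepal_length": ([5.006, 5.936, 6.588], [0.352490, 0.516171, 0.635880]),
    "sepal_width":  ([3.428, 2.770, 2.974], [0.379064, 0.313798, 0.322497]),
    "petal_length": ([1.462, 4.260, 5.552], [0.173664, 0.469911, 0.551895]),
    "petal_width":  ([0.246, 1.326, 2.026], [0.105386, 0.197753, 0.274650]),
}


def separation(means, stds):
    """Ad hoc score: range of the class means / average within-class std."""
    return (max(means) - min(means)) / (sum(stds) / len(stds))


scores = {name: separation(m, s) for name, (m, s) in class_stats.items()}
# Features sorted from most to least separated under this heuristic
ranked = sorted(scores, key=scores.get, reverse=True)
for name in ranked:
    print(f"{name:13s} {scores[name]:.2f}")
```

Under this heuristic the two petal features rank well ahead of the two sepal features, which matches the choice of plotting axes.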

# One color per species: target 0 = red, 1 = yellow, 2 = blue
colors = ["red", "yellow", "blue"]
for k, col in zip(range(3), colors):
    sub_data = iris_df.query("target==%s" % k)
    plt.plot(sub_data.petal_length, sub_data.petal_width, "o", markerfacecolor=col,
             markeredgecolor='k', markersize=5)
plt.show()

[Figure: scatter plot of petal_length vs. petal_width, one color per species]

The red points (target 0) lie far from the rest, while some of the blue and yellow points overlap.

Clustering with the default bandwidth

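# ms = MeanShift( bin_seeding=True,cluster_all=False)
# 0.726 is roughly what estimate_bandwidth returns for these two columns,
# i.e. the value MeanShift would pick on its own when bandwidth is left unset
bandwidth = 0.726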
ms = MeanShift(bandwidth=bandwidth, bin_seeding=True)
ms.fit(iris_df[["petal_length", "petal_width"]])
labels = ms.labels_
cluster_centers = ms.cluster_centers_
labels_unique = np.unique(labels)
n_clusters_ = len(labels_unique)

print("number of estimated clusters : %d" % n_clusters_)

# #############################################################################
# Plot result
plt.figure(1)
plt.clf()

# One color per cluster label, in label order
colors = ["yellow", "red", "blue"]
for k, col in zip(range(n_clusters_), colors):
    my_members = labels == k
    cluster_center = cluster_centers[k]
    plt.plot(iris_df[my_members].petal_length,
             iris_df[my_members].petal_width,
             ".",
             markerfacecolor=col,
             markeredgecolor='k',
             markersize=6)
    plt.plot(cluster_center[0],
             cluster_center[1],
             'o',
             markerfacecolor=col,
             markeredgecolor='k',
             markersize=14)
    circle = plt.Circle((cluster_center[0], cluster_center[1]),
                        bandwidth,
                        color='black',
                        fill=False)
    plt.gcf().gca().add_artist(circle)
plt.title('Estimated number of clusters: %d' % n_clusters_)
plt.show()
number of estimated clusters : 3

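[Figure: MeanShift result, points colored by cluster, each cluster center drawn with a circle of radius bandwidth around it]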

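In the plot, the red cluster stands on its own and is separated cleanly. The blue and yellow clusters overlap, and each point there is labeled by its nearest cluster center.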
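The nearest-center labeling rule can be illustrated with a minimal, dependency-free sketch. The centers below are stand-ins: they are the per-class (petal_length, petal_width) means from the describe() output, not the fitted `ms.cluster_centers_`. The helper name `nearest_center` is hypothetical, and two of the three test points are made up to sit in the overlap region.

```python
import math

# Stand-in centers: per-class (petal_length, petal_width) means from the
# describe() output, NOT the fitted ms.cluster_centers_.
centers = [(1.462, 0.246), (4.260, 1.326), (5.552, 2.026)]


def nearest_center(point, centers):
    """Return the index of the center closest to `point` (Euclidean distance)."""
    return min(range(len(centers)), key=lambda k: math.dist(point, centers[k]))


# The first point is the first iris row; the other two are made-up points
# placed where the two larger species overlap.
points = [(1.4, 0.2), (4.9, 1.6), (5.0, 1.9)]
labels = [nearest_center(p, centers) for p in points]
print(labels)
```

The two overlap points are only about 0.3 apart, yet they receive different labels because each is slightly closer to a different center.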

The estimate_bandwidth method

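It produces a suggested bandwidth from the raw data to be clustered. The basic logic:

  • Optionally subsample the data first (only when n_samples is given)
  • For each of these points, find its n_neighbors nearest neighbors, where n_neighbors = int(number of points × quantile) and quantile defaults to 0.3, and take the largest of those distances
  • Average these distances over all points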

Logically, this amounts to picking a fairly large distance so that each window covers a good share of the points.
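The logic can be sketched without scikit-learn, under the assumption that no subsampling takes place. For each point the sketch takes the distance to the farthest of its k nearest neighbors (the point itself counts as the first neighbor, at distance 0), with k = int(n × quantile), and then averages over all points. The function name `sketch_bandwidth` and the toy data are hypothetical. This is a brute-force illustration, not the library code.

```python
import math


def sketch_bandwidth(points, quantile=0.3):
    """Average, over all points, of the distance to the k-th nearest neighbor."""
    n = len(points)
    k = max(1, int(n * quantile))  # number of neighbors, at least 1
    total = 0.0
    for p in points:
        # Sorted distances from p to every point; index 0 is p itself (0.0)
        dists = sorted(math.dist(p, q) for q in points)
        total += dists[k - 1]  # farthest of the k nearest neighbors
    return total / n


# Toy one-dimensional data: four points on a line
toy = [(0.0,), (1.0,), (3.0,), (6.0,)]
print(sketch_bandwidth(toy, quantile=0.5))  # 1.75
```

With quantile=0.5 each point looks at its two nearest neighbors (itself and one other), so the per-point distances are 1, 1, 2 and 3, giving an average of 1.75. A larger quantile widens the neighborhood and therefore the bandwidth.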

estimate_bandwidth(iris_df[["petal_length", "petal_width"]])
0.7266371274126329

As a check, compute the pairwise distances directly:

from sklearn.neighbors import NearestNeighbors
nbrs = NearestNeighbors(n_neighbors=len(iris_df), n_jobs=-1)
nbrs.fit(iris_df.iloc[:,[2,3]])
NearestNeighbors(algorithm='auto', leaf_size=30, metric='minkowski',
                 metric_params=None, n_jobs=-1, n_neighbors=150, p=2,
                 radius=1.0)
d, index = nbrs.kneighbors(iris_df.iloc[:,[2,3]],return_distance=True)
# Drop column 0 (the zero distance from each point to itself) and flatten the rest
total_distance = d[:, 1:].ravel()
from scipy import stats
stats.describe(total_distance)
DescribeResult(nobs=22350, minmax=(0.0, 6.262587324740471), mean=2.185682454621745, variance=2.6174775533104904, skewness=0.3422940721262964, kurtosis=-1.1637573960810108)
pd.DataFrame({"total_distance":total_distance}).describe()
       total_distance
count    22350.000000
mean         2.185682
std          1.617862
min          0.000000
25%          0.640312
50%          1.941649
75%          3.544009
max          6.262587

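The estimated bandwidth (about 0.727) is fairly close to the 25th percentile of the pairwise distances (about 0.640), which fits with estimate_bandwidth using quantile=0.3 by default.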

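That wraps up this brief introduction to MeanShift. The algorithm works well in some business scenarios, but whether it fits has to be judged case by case.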

                                2021-03-31, Jiulonghu, Jiangning District, Nanjing
