Intro
Meanshift的使用案例~
数据引入
from sklearn.cluster import MeanShift, estimate_bandwidth
import matplotlib.pyplot as plt
from itertools import cycle
import numpy as np
import warnings
warnings.filterwarnings("ignore")
%matplotlib inline
from sklearn.datasets import load_iris
import pandas as pd
# Raise the pandas display limits so long and wide frames are printed in full.
for _display_option in ("display.max_rows", "display.max_columns"):
    pd.set_option(_display_option, 500)
# 检查是否是array格式,如果不是,转换成array
from sklearn.utils import check_array
from sklearn.utils import check_random_state
from sklearn.neighbors import NearestNeighbors
from sklearn.utils import gen_batches
# Load the iris dataset once and reuse the result for both the feature matrix
# and the class labels (previously load_iris() was called twice).
iris = load_iris()
iris_df = pd.DataFrame(
    iris["data"],
    columns=["sepal_length", "sepal_width", "petal_length", "petal_width"],
)
iris_df["target"] = iris["target"]
# Notebook cell output: preview the first rows of the frame.
iris_df.head()
sepal_length | sepal_width | petal_length | petal_width | target | |
---|---|---|---|---|---|
0 | 5.1 | 3.5 | 1.4 | 0.2 | 0 |
1 | 4.9 | 3.0 | 1.4 | 0.2 | 0 |
2 | 4.7 | 3.2 | 1.3 | 0.2 | 0 |
3 | 4.6 | 3.1 | 1.5 | 0.2 | 0 |
4 | 5.0 | 3.6 | 1.4 | 0.2 | 0 |
# Per-class summary statistics (count, mean, std, min, quartiles, max)
# for every feature column.
iris_df.groupby("target").describe()
sepal_length | sepal_width | petal_length | petal_width | |||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
count | mean | std | min | 25% | 50% | 75% | max | count | mean | std | min | 25% | 50% | 75% | max | count | mean | std | min | 25% | 50% | 75% | max | count | mean | std | min | 25% | 50% | 75% | max | |
target | ||||||||||||||||||||||||||||||||
0 | 50.0 | 5.006 | 0.352490 | 4.3 | 4.800 | 5.0 | 5.2 | 5.8 | 50.0 | 3.428 | 0.379064 | 2.3 | 3.200 | 3.4 | 3.675 | 4.4 | 50.0 | 1.462 | 0.173664 | 1.0 | 1.4 | 1.50 | 1.575 | 1.9 | 50.0 | 0.246 | 0.105386 | 0.1 | 0.2 | 0.2 | 0.3 | 0.6 |
1 | 50.0 | 5.936 | 0.516171 | 4.9 | 5.600 | 5.9 | 6.3 | 7.0 | 50.0 | 2.770 | 0.313798 | 2.0 | 2.525 | 2.8 | 3.000 | 3.4 | 50.0 | 4.260 | 0.469911 | 3.0 | 4.0 | 4.35 | 4.600 | 5.1 | 50.0 | 1.326 | 0.197753 | 1.0 | 1.2 | 1.3 | 1.5 | 1.8 |
2 | 50.0 | 6.588 | 0.635880 | 4.9 | 6.225 | 6.5 | 6.9 | 7.9 | 50.0 | 2.974 | 0.322497 | 2.2 | 2.800 | 3.0 | 3.175 | 3.8 | 50.0 | 5.552 | 0.551895 | 4.5 | 5.1 | 5.55 | 5.875 | 6.9 | 50.0 | 2.026 | 0.274650 | 1.4 | 1.8 | 2.0 | 2.3 | 2.5 |
从数据上看,三个种类之间,petal_length和petal_width的差异比较大,用它们来画图。
# Scatter petal_length against petal_width, one face colour per true class
# (target 0 -> red, 1 -> yellow, 2 -> blue).
colors = ["red", "yellow", "blue"]
for k, col in enumerate(colors):
    # Rows belonging to class k; a boolean mask avoids building a query string.
    sub_data = iris_df[iris_df["target"] == k]
    plt.plot(
        sub_data.petal_length,
        sub_data.petal_width,
        "o",
        markerfacecolor=col,
        markeredgecolor='k',
        markersize=5,
    )
# Show all three classes in a single figure.
plt.show()
可以看到红色点和其余点相差很多,蓝色和黄色有部分点交错在一起
默认参数进行聚类
import matplotlib.pyplot as plt
from itertools import cycle

# Cluster on the two petal features with a fixed bandwidth.
# NOTE(review): 0.726 looks like the rounded estimate_bandwidth() result shown
# later in this file -- confirm and re-derive it if the data changes.
bandwidth = 0.726
ms = MeanShift(bandwidth=bandwidth, bin_seeding=True)
ms.fit(iris_df[["petal_length", "petal_width"]])
labels = ms.labels_
cluster_centers = ms.cluster_centers_
labels_unique = np.unique(labels)
n_clusters_ = len(labels_unique)
print("number of estimated clusters : %d" % n_clusters_)

# #############################################################################
# Plot result
plt.figure(1)
plt.clf()
# The palette is cycled so that every cluster is drawn even when more than
# three clusters are found; a plain zip() over three colours would silently
# skip the extra clusters.
colors = ["yellow", "red", "blue"]
for k, col in zip(range(n_clusters_), cycle(colors)):
    my_members = labels == k
    cluster_center = cluster_centers[k]
    # Select the members of cluster k once and reuse them for both axes.
    members = iris_df[my_members]
    plt.plot(
        members.petal_length,
        members.petal_width,
        ".",
        markerfacecolor=col,
        markeredgecolor='k',
        markersize=6,
    )
    # Large marker at the cluster centre.
    plt.plot(
        cluster_center[0],
        cluster_center[1],
        'o',
        markerfacecolor=col,
        markeredgecolor='k',
        markersize=14,
    )
    # Circle of radius `bandwidth` around the centre to visualise the window.
    circle = plt.Circle(
        (cluster_center[0], cluster_center[1]),
        bandwidth,
        color='black',
        fill=False,
    )
    plt.gca().add_artist(circle)
plt.title('Estimated number of clusters: %d' % n_clusters_)
plt.show()
number of estimated clusters : 3
从图上看,红色部分自成一派,聚类效果较好;蓝黄两类互有交叉,以最靠近的类别中心来打label。
estimate_bandwidth方法
根据聚类的原始数据,生成建议的bandwidth,基础逻辑:
- 先抽样,获取部分样本
- 对这些样本中的每个点,计算它到自己最近的一部分邻居(邻居个数 = 样本数 × quantile,quantile默认0.3)中的最大距离
- 对距离求平均
从逻辑上看,更像是找一个较大的距离,使得能涵盖更多的点
# Ask scikit-learn for a suggested bandwidth on the two petal features.
petal_xy = iris_df[["petal_length", "petal_width"]]
estimate_bandwidth(petal_xy)
0.7266371274126329
计算距离,check下
from sklearn.neighbors import NearestNeighbors

# Columns 2 and 3 are petal_length and petal_width.
petal_columns = iris_df.iloc[:, [2, 3]]
# Request as many neighbours as there are rows, so a later kneighbors() call
# returns the distance from each point to every point in the data set.
nbrs = NearestNeighbors(n_neighbors=len(iris_df), n_jobs=-1)
nbrs.fit(petal_columns)
NearestNeighbors(algorithm='auto', leaf_size=30, metric='minkowski',
metric_params=None, n_jobs=-1, n_neighbors=150, p=2,
radius=1.0)
from scipy import stats

# d[i, j] is the distance from point i to its j-th nearest neighbour.
d, index = nbrs.kneighbors(iris_df.iloc[:, [2, 3]], return_distance=True)
# Skip column 0 (the nearest neighbour, at distance 0 -- normally the point
# itself) and flatten row by row. A single NumPy slice replaces the previous
# reduce() over list concatenation, which was quadratic; the values and their
# order are unchanged, and the result is still a plain list.
total_distance = np.asarray(d)[:, 1:].ravel().tolist()
# Notebook cell output: summary statistics of the pooled distances.
stats.describe(total_distance)
DescribeResult(nobs=22350, minmax=(0.0, 6.262587324740471), mean=2.185682454621745, variance=2.6174775533104904, skewness=0.3422940721262964, kurtosis=-1.1637573960810108)
# Summarise the pooled distances (count, mean, std, min, quartiles, max)
# as a one-column table.
pd.Series(total_distance, name="total_distance").to_frame().describe()
total_distance | |
---|---|
count | 22350.000000 |
mean | 2.185682 |
std | 1.617862 |
min | 0.000000 |
25% | 0.640312 |
50% | 1.941649 |
75% | 3.544009 |
max | 6.262587 |
从数据上看,estimate_bandwidth给出的bandwidth(约0.7266)有点接近所有点对距离的25%分位数(约0.640)。
meanshift的简单介绍到此为止,有些业务场景下,这个算法还是很好用的。需要具体问题具体分析。
2021-03-31 于南京市江宁区九龙湖