# 1、导入所用的库
import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
from sklearn.metrics import pairwise_distances_argmin
# 对两个序列中的点进行距离匹配的函数
from sklearn.datasets import load_sample_image
# 导入图片数据所用的类
from sklearn.utils import shuffle # 打乱一个有序的序列的类
# 2. Load and explore the data
# load_sample_image returns the bundled picture as a NumPy array.
china = load_sample_image("china.jpg")
china
china.dtype
china.shape
china[0][0]

# Flatten the image to one row per pixel and count the distinct colors.
# Using -1 lets NumPy infer the pixel count instead of hard-coding 427 * 640.
import pandas as pd
newimage = china.reshape((-1, china.shape[-1]))
pd.DataFrame(newimage).drop_duplicates().shape
# Visualize the images
# imshow accepts the 3-dimensional array directly; each picture gets its own
# large figure.
flower = load_sample_image("flower.jpg")
for picture in (china, flower):
    plt.figure(figsize=(15, 15))
    plt.imshow(picture)
# 3. Choose the hyperparameter and preprocess the data
n_clusters = 64

# plt.imshow works well with floating-point data, so convert the pixel values
# to floats and scale them into [0, 1].
china = np.array(china, dtype=np.float64) / china.max()

# Remember the original layout so the flattened pixels can be reshaped back.
w, h, d = original_shape = tuple(china.shape)

# Every pixel must carry exactly three features; the message is shown if not.
assert d == 3, "一个格子中的特征数目不等于3种"

# Change the structure: one row per pixel, one column per feature.
image_array = np.reshape(china, (w * h, d))
image_array
Out[32]:
array([[0.68235294, 0.78823529, 0.90588235],
[0.68235294, 0.78823529, 0.90588235],
[0.68235294, 0.78823529, 0.90588235],
...,
[0.16862745, 0.19215686, 0.15294118],
[0.05098039, 0.08235294, 0.02352941],
[0.05882353, 0.09411765, 0.02745098]])
image_array.shape
Out[33]: (273280, 3)
# Small demonstration: a (2, 4) array of random values viewed as (4, 2).
a = np.random.random((2, 4))
np.reshape(a, (4, 2))
Out[36]:
array([[0.85653672, 0.60094485],
[0.01856518, 0.30548777],
[0.51642109, 0.9919073 ],
[0.97541273, 0.70377145]])
# The same eight values can also be viewed as a (2, 2, 2) array.
np.reshape(a, (2, 2, 2))
Out[37]:
array([[[0.85653672, 0.60094485],
[0.01856518, 0.30548777]],
[[0.51642109, 0.9919073 ],
[0.97541273, 0.70377145]]])
# 4. Vector quantization of the data with K-Means
# First, use 1000 shuffled pixels to find the centroids.
image_array_sample = shuffle(image_array, random_state=0)[:1000]
kmeans = KMeans(n_clusters=n_clusters, random_state=0)
kmeans.fit(image_array_sample)
kmeans.cluster_centers_
Out[38]:
array([[0.11798806, 0.11884058, 0.07007673],
[0.80874811, 0.82262443, 0.85671192],
[0.47614379, 0.46895425, 0.27124183],
[0.92831097, 0.95803234, 0.99566563],
[0.52 , 0.5254902 , 0.39529412],
[0.61895425, 0.67712418, 0.70816993],
[0.31198257, 0.34030501, 0.18954248],
[0.82923351, 0.90641711, 0.98743316],
[0.80392157, 0.53006536, 0.3751634 ],
[0.25202614, 0.23764706, 0.20104575],
[0.03328773, 0.02836297, 0.01732786],
[0.3454902 , 0.1854902 , 0.12470588],
[0.52156863, 0.49150327, 0.52592593],
[0.74457516, 0.83934641, 0.95045752],
[0.72941176, 0.35764706, 0.23137255],
[0.41470588, 0.44656863, 0.40980392],
[0.96176471, 0.77058824, 0.63039216],
[0.69019608, 0.74705882, 0.7605042 ],
[0.57019608, 0.41098039, 0.34588235],
[0.93630422, 0.93594771, 0.94913844],
[0.3875817 , 0.45620915, 0.10588235],
[0.56684492, 0.62388592, 0.64171123],
[0.30795848, 0.29434833, 0.24844291],
[0.23157895, 0.16594427, 0.1376677 ],
[0.28403361, 0.4140056 , 0.37591036],
[0.60566449, 0.59389978, 0.53943355],
[0.32009804, 0.32230392, 0.11397059],
[0.48948307, 0.32442068, 0.28163993],
[0.89411765, 0.63764706, 0.43529412],
[0.12296919, 0.04033613, 0.032493 ],
[0.71973856, 0.73777778, 0.68444444],
[0.15496138, 0.15995247, 0.12477718],
[0.05568627, 0.16 , 0.19843137],
[0.37254902, 0.36302521, 0.30308123],
[0.83529412, 0.86349206, 0.89505135],
[0.07029412, 0.07941176, 0.04921569],
[0.20588235, 0.3379085 , 0.33202614],
[0.71328976, 0.41960784, 0.31851852],
[0.97279933, 0.97688778, 0.99382562],
[0.7827451 , 0.81098039, 0.80366013],
[0.19839572, 0.07950089, 0.0798574 ],
[0.88225239, 0.93584716, 0.99034691],
[0.36862745, 0.11960784, 0.03986928],
[0.59176471, 0.55215686, 0.43137255],
[0.15462185, 0.24201681, 0.24481793],
[0.40962567, 0.40427807, 0.1486631 ],
[0.22941176, 0.28169935, 0.1372549 ],
[0.24615385, 0.24012066, 0.07812971],
[0.77571644, 0.86998492, 0.96515837],
[0.51693405, 0.55080214, 0.49340463],
[0.89847495, 0.90544662, 0.91503268],
[0.5827451 , 0.55098039, 0.32078431],
[0.40751634, 0.41143791, 0.23366013],
[0.51204482, 0.32661064, 0.17983193],
[0.69313725, 0.60490196, 0.47156863],
[0.33893557, 0.47507003, 0.45490196],
[0.38126362, 0.28235294, 0.23616558],
[0.46928105, 0.43093682, 0.36949891],
[0.73333333, 0.7822376 , 0.7928489 ],
[0.43431373, 0.19313725, 0.16568627],
[0.51421569, 0.50735294, 0.23872549],
[0.47745098, 0.63235294, 0.59117647],
[0.83529412, 0.83235294, 0.74411765],
[0.73557423, 0.8162465 , 0.91134454]])
# With the centroids found, cluster all of the data according to those
# existing centroids.
labels = kmeans.predict(image_array)
labels.shape
Out[39]: (273280,)
# Replace every sample with a centroid; start from a copy of the pixel array.
image_kmeans = image_array.copy()
image_kmeans # about 270k sample points, more than 90k distinct colors (pixels)
Out[40]:
array([[0.68235294, 0.78823529, 0.90588235],
[0.68235294, 0.78823529, 0.90588235],
[0.68235294, 0.78823529, 0.90588235],
...,
[0.16862745, 0.19215686, 0.15294118],
[0.05098039, 0.08235294, 0.02352941],
[0.05882353, 0.09411765, 0.02745098]])
# labels holds, for each of the ~270k sample points, the index of the centroid
# of its cluster (e.g. kmeans.cluster_centers_[labels[1]] is the centroid for
# sample 1). Indexing the centroid table with the whole label array replaces
# every sample in one vectorized step instead of a per-sample Python loop.
image_kmeans = kmeans.cluster_centers_[labels]
image_kmeans
Out[41]:
array([[0.73557423, 0.8162465 , 0.91134454],
[0.73557423, 0.8162465 , 0.91134454],
[0.73557423, 0.8162465 , 0.91134454],
...,
[0.15496138, 0.15995247, 0.12477718],
[0.07029412, 0.07941176, 0.04921569],
[0.07029412, 0.07941176, 0.04921569]])
# Count the distinct rows (colors) that remain after the replacement.
pd.DataFrame(image_kmeans).drop_duplicates().shape
Out[42]: (64, 3)
# Restore the structure of the picture.
image_kmeans = np.reshape(image_kmeans, (w, h, d))
image_kmeans.shape
Out[43]: (427, 640, 3)
# 5. Vector quantization of the data with random centroids
# Pick n_clusters shuffled pixels to act as the "centroids".
centroid_random = shuffle(image_array, random_state=0)[:n_clusters]

# pairwise_distances_argmin(X1, X2, axis=0) computes the distance from every
# sample in X2 to every sample in X1 and returns, for each sample in X2, the
# index of its nearest sample in X1.
labels_random = pairwise_distances_argmin(centroid_random, image_array, axis=0)
labels_random.shape
len(set(labels_random))

# Replace every sample with its nearest random centroid in one vectorized
# step (indexing with the label array yields a new (w * h, d) array).
image_random = centroid_random[labels_random]

# Restore the structure of the picture.
image_random = image_random.reshape(w, h, d)
image_random.shape
Out[44]: (427, 640, 3)
# 6. Draw the original image, the K-Means quantized image and the randomly
# quantized image, each in its own figure.
panels = (
    ('Original image (96,615 colors)', china),
    ('Quantized image (64 colors, K-Means)', image_kmeans),
    ('Quantized image (64 colors, Random)', image_random),
)
for caption, picture in panels:
    plt.figure(figsize=(10, 10))
    plt.axis('off')  # hide the axes
    plt.title(caption)
    plt.imshow(picture)
plt.show()
# 结论:通过对比可以看出,使用K-Means聚类选出64种颜色得到的图片与原图区别不大;而使用随机方法选取64种颜色得到的图片与原图差别明显。由此可见,K-Means在矢量量化上的优势很大。