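Data Science [7]: Clustering (Part 3)
The data used in this article is the CIFAR-10 dataset.
Loading the dataset
Open the file "data_batch_1" and display a randomly chosen image: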
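import pickle
import random

import matplotlib.pyplot as plt
import numpy as np

# Each data_batch_* file is a pickled dict; encoding='bytes' keeps its keys as bytes (b'data', b'labels', ...)
with open("data_batch_1", 'rb') as f_batch:
    data_dict = pickle.load(f_batch, encoding='bytes')
image_datas = data_dict[b'data']  # uint8 array of shape (10000, 3072), one image per row
random_idx = random.randint(0, len(image_datas) - 1)
# A row is channel-major (1024 R, 1024 G, 1024 B); imshow expects (height, width, channel)
image_array = np.reshape(image_datas[random_idx], (3, 32, 32)).transpose(1, 2, 0)
plt.imshow(image_array)
plt.show()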
It's a horse 🐎.
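Two details of the batch format are worth checking. Each row of `data_dict[b'data']` holds 3072 `uint8` values: the 1024 red values of the 32×32 image, then the 1024 green, then the 1024 blue, each in row-major order, which is why the code reshapes to `(3, 32, 32)` and then moves the channel axis last for `imshow`. The batch dict also carries a `b'labels'` list with one class index (0 to 9) per image, so the class of the random image can be read rather than guessed. The sketch below uses hypothetical helper names (`row_to_image`, `label_of`) and hard-codes the class names in the order listed on the CIFAR-10 page (the `batches.meta` file stores the same list under `b'label_names'`); it runs on a synthetic row, so the dataset file is not needed.

```python
import numpy as np

# CIFAR-10 class names in label order 0-9 (hard-coded here; batches.meta has them too)
CIFAR10_CLASSES = ["airplane", "automobile", "bird", "cat", "deer",
                   "dog", "frog", "horse", "ship", "truck"]

def row_to_image(row):
    """Turn one channel-major CIFAR-10 row (3072 values) into a 32x32x3 image."""
    # (3, 32, 32) = (channel, row, column) -> (32, 32, 3) = (row, column, channel)
    return np.reshape(row, (3, 32, 32)).transpose(1, 2, 0)

def label_of(data_dict, idx):
    """Class name of image idx in a batch dict loaded with encoding='bytes'."""
    return CIFAR10_CLASSES[data_dict[b'labels'][idx]]

# Synthetic row: red plane 0..1023, green plane 1024..2047, blue plane 2048..3071
row = np.arange(3072)
img = row_to_image(row)
print(img.shape)            # (32, 32, 3)
print(img[0, 1].tolist())   # [1, 1025, 2049]: R, G, B of the pixel at row 0, column 1
print(label_of({b'labels': [7]}, 0))  # horse
```

With the real batch, `label_of(data_dict, random_idx)` names the class of the image displayed above.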
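Clustering-based image compression
Each pixel can be treated as one data point, namely its (R, G, B) value. We cluster all the pixels of the image above into four clusters, then replace every pixel with the center of the cluster it belongs to: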
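from sklearn.cluster import KMeans

# Treat each pixel as a 3-D point (R, G, B): 32*32 = 1024 points
pixels = np.reshape(image_array, (32*32, 3))
kmeans_d = KMeans(n_clusters=4)
kmeans_d.fit(pixels)
pixel_labels = kmeans_d.labels_            # cluster index (0-3) of each pixel
center_colors = kmeans_d.cluster_centers_  # the 4 cluster centers, i.e. the representative colors
pixel_labels_2d = np.reshape(pixel_labels, (32, 32))
# Copy first: image_array is a view into image_datas, so writing to it would overwrite the dataset
image_avg_array = image_array.copy()
for i in range(32):
    for j in range(32):
        # Float center colors are truncated to uint8 on assignment
        image_avg_array[i][j] = center_colors[pixel_labels_2d[i][j]]
plt.imshow(image_avg_array)
plt.show()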
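The compression comes from how the result can be stored: a palette of K colors (3 bytes each) plus one palette index per pixel, where an index needs only ceil(log2(K)) bits. For K = 4 that is 2 bits per pixel instead of 24. The sketch below (hypothetical helpers `compressed_bits` and `decompress`) counts raw bits while ignoring file headers and any further entropy coding, and rebuilds the image with NumPy fancy indexing `palette[labels_2d]`, which performs the same lookup as the double loop above in one step (it rounds the center colors instead of truncating them).

```python
import math
import numpy as np

def compressed_bits(n_pixels, K, bits_per_channel=8, channels=3):
    """Raw bits for a K-color palette plus one palette index per pixel."""
    index_bits = math.ceil(math.log2(K)) if K > 1 else 0
    palette_bits = K * channels * bits_per_channel
    return n_pixels * index_bits + palette_bits

def decompress(labels_2d, palette):
    """Rebuild an H x W x 3 uint8 image from per-pixel labels and a palette."""
    # palette[labels_2d] has shape (H, W, 3): each label replaced by its color
    return np.rint(palette[labels_2d]).clip(0, 255).astype(np.uint8)

original_bits = 32 * 32 * 3 * 8            # 24576 bits for the raw image
packed_bits = compressed_bits(32 * 32, 4)  # 1024 * 2 + 4 * 3 * 8 = 2144 bits
print(round(original_bits / packed_bits, 1))  # 11.5
```

With the variables above, `decompress(pixel_labels_2d, center_colors)` yields the 4-color image; the roughly 11.5x figure is an upper bound on what a real file format would achieve for a 32×32 image.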
Batch processing
We can write a function that applies this transformation to every image in the dataset:
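from sklearn.cluster import KMeans

def trans(idx, K, show=True):
    """Compress image idx of image_datas to K colors with K-means and return it."""
    image_array = np.reshape(image_datas[idx], (3, 32, 32)).transpose(1, 2, 0)
    pixels = np.reshape(image_array, (32*32, 3))
    kmeans_e = KMeans(n_clusters=K)
    kmeans_e.fit(pixels)
    pixel_labels = kmeans_e.labels_
    center_colors = kmeans_e.cluster_centers_
    pixel_labels_2d = np.reshape(pixel_labels, (32, 32))
    # Copy so that image_datas itself is not modified in place
    image_avg_array = image_array.copy()
    for i in range(32):
        for j in range(32):
            image_avg_array[i][j] = center_colors[pixel_labels_2d[i][j]]
    if show:
        plt.imshow(image_avg_array)
        plt.show()
    return image_avg_array

def all_trans(K):
    # One K-means fit per image; plotting is off because a batch holds 10000 images
    return [trans(i, K, show=False) for i in range(len(image_datas))]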
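`trans` relies on scikit-learn's `KMeans` for every image. To make the clustering step concrete, here is a minimal NumPy-only version of Lloyd's algorithm: alternately assign each pixel to its nearest center, then move each center to the mean of its pixels, until the centers stop moving. This is a simplified stand-in, not scikit-learn's implementation (which adds k-means++ initialization and multiple restarts); the names `kmeans_lloyd` and `quantize`, the initialization from K distinct random colors, and the 20-iteration cap are all assumptions of this sketch.

```python
import numpy as np

def kmeans_lloyd(points, K, n_iter=20, seed=0):
    """Plain Lloyd's k-means on an (N, 3) array; returns (labels, centers)."""
    points = np.asarray(points, dtype=float)
    rng = np.random.default_rng(seed)
    # Initialize with K distinct colors chosen at random (fewer if the image has fewer)
    uniq = np.unique(points, axis=0)
    centers = uniq[rng.choice(len(uniq), size=min(K, len(uniq)), replace=False)]

    def assign(centers):
        # Squared distance from every point to every center, shape (N, K)
        d2 = ((points[:, None, :] - centers[None, :, :]) ** 2).sum(axis=2)
        return d2.argmin(axis=1)

    for _ in range(n_iter):
        labels = assign(centers)
        # Update step: move each center to the mean of the points assigned to it
        new_centers = np.array([points[labels == k].mean(axis=0) if np.any(labels == k)
                                else centers[k] for k in range(len(centers))])
        if np.allclose(new_centers, centers):
            break
        centers = new_centers
    return assign(centers), centers

def quantize(image, K):
    """Replace each pixel of an H x W x 3 image with its cluster center color."""
    h, w, c = image.shape
    labels, centers = kmeans_lloyd(image.reshape(h * w, c), K)
    return np.rint(centers[labels]).clip(0, 255).astype(np.uint8).reshape(h, w, c)
```

For example, `quantize(np.reshape(image_datas[0], (3, 32, 32)).transpose(1, 2, 0), 4)` would produce a 4-color version of the first image, comparable to `trans(0, 4)` up to the different initialization.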