Implementing K-means
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
# Find the nearest centroid for each sample.
def find_closest_centroids(X, centroids):
    """Assign every sample to its nearest centroid.

    Parameters
    ----------
    X : array-like, shape (m, n)
        Samples, one per row.
    centroids : array-like, shape (k, n)
        Current cluster centres, one per row.

    Returns
    -------
    ndarray, shape (m,)
        Entry ``i`` is the index (0..k-1) of the centroid closest to
        ``X[i]`` in squared Euclidean distance; ties go to the lowest index.
        Returned as a float array, as before, so existing callers that use
        ``idx == i`` or ``idx.astype(int)`` keep working.

    Raises
    ------
    ValueError
        If ``centroids`` has no rows.
    """
    X = np.asarray(X, dtype=float)
    centroids = np.asarray(centroids, dtype=float)
    if centroids.shape[0] == 0:
        raise ValueError("find_closest_centroids: need at least one centroid")
    # All pairwise squared distances, shape (m, k), in one broadcast. No
    # arbitrary "large" starting distance is needed, so samples far from
    # every centroid are still assigned correctly (a fixed 1e6 sentinel
    # silently left such samples at cluster 0).
    diff = X[:, np.newaxis, :] - centroids[np.newaxis, :, :]
    dists = np.sum(diff ** 2, axis=2)
    # A centroid containing NaN (e.g. from an empty cluster) must never win;
    # np.argmin would otherwise report the NaN position as the minimum.
    dists = np.where(np.isnan(dists), np.inf, dists)
    return np.argmin(dists, axis=1).astype(float)
# Recompute each centroid as the mean of its assigned samples.
def compute_centroids(X, idx, k):
    """Move each centroid to the mean of the samples assigned to it.

    Parameters
    ----------
    X : array-like, shape (m, n)
        Samples, one per row.
    idx : array-like, shape (m,)
        Cluster index per sample, as returned by find_closest_centroids
        (float or int values are both accepted).
    k : int
        Number of clusters.

    Returns
    -------
    ndarray, shape (k, n)
        Row ``i`` is the mean of ``X[idx == i]``. A cluster with no samples
        gets a row of NaN. The previous implementation also yielded NaN
        there, but implicitly via a 0/0 division that raised a RuntimeWarning.
    """
    X = np.asarray(X, dtype=float)
    idx = np.asarray(idx)
    n = X.shape[1]
    centroids = np.full((k, n), np.nan)
    for i in range(k):
        members = X[idx == i]  # boolean mask: rows assigned to cluster i
        if members.shape[0] > 0:
            centroids[i] = members.mean(axis=0)
    return centroids
# Run K-means from given initial centroids.
def run_k_means(X, initial_centroids, max_iters):
    """Run K-means from ``initial_centroids`` for at most ``max_iters`` rounds.

    Each round assigns samples to their nearest centroid
    (find_closest_centroids), then moves each centroid to the mean of its
    samples (compute_centroids).

    Returns
    -------
    idx : ndarray, shape (m,)
        Assignments from the last assignment step; all zeros when
        ``max_iters`` is 0.
    centroids : ndarray, shape (k, n)
        Centroids after the last update (``initial_centroids`` when
        ``max_iters`` is 0).
    """
    m = X.shape[0]
    k = initial_centroids.shape[0]
    idx = np.zeros(m)  # returned unchanged when max_iters == 0
    centroids = initial_centroids
    previous = None
    for _ in range(max_iters):
        idx = find_closest_centroids(X, centroids)
        if previous is not None and np.array_equal(idx, previous):
            # The assignment is unchanged. The current centroids were computed
            # from exactly this assignment, so any further round would
            # reproduce them: stop early with an identical result.
            break
        centroids = compute_centroids(X, idx, k)
        previous = idx
    return idx, centroids
def init_centroids(X, k):
    """Pick ``k`` different rows of ``X`` at random as initial centroids.

    Rows are sampled without replacement. Sampling with replacement could
    pick the same row twice, giving duplicate centroids and therefore an
    empty cluster whose mean is NaN. NumPy's global random state is used,
    so ``np.random.seed`` makes the choice reproducible.

    Parameters
    ----------
    X : array-like, shape (m, n)
        Samples, one per row.
    k : int
        Number of centroids to draw.

    Returns
    -------
    ndarray of float, shape (k, n)

    Raises
    ------
    ValueError
        If ``k`` is larger than the number of rows of ``X``.
    """
    m = X.shape[0]
    rows = np.random.choice(m, size=k, replace=False)
    # Fancy indexing returns a copy; cast to float like the previous
    # np.zeros-based construction did.
    return np.asarray(X, dtype=float)[rows]
K-means on example dataset
# Load the 2-D example dataset used for K-means and echo its contents.
from scipy.io import loadmat

data = loadmat('ex7data2.mat')
data
{'__header__': b'MATLAB 5.0 MAT-file, Platform: GLNXA64, Created on: Wed Nov 16 00:48:22 2011',
'__version__': '1.0',
'__globals__': [],
'X': array([[ 1.84207953, 4.6075716 ],
[ 5.65858312, 4.79996405],
[ 6.35257892, 3.2908545 ],
[ 2.90401653, 4.61220411],
[ 3.23197916, 4.93989405],
[ 1.24792268, 4.93267846],
[ 1.97619886, 4.43489674],
[ 2.23454135, 5.05547168],
[ 2.98340757, 4.84046406],
[ 2.97970391, 4.80671081],
[ 2.11496411, 5.37373587],
[ 2.12169543, 5.20854212],
[ 1.5143529 , 4.77003303],
[ 2.16979227, 5.27435354],
[ 0.41852373, 4.88312522],
[ 2.47053695, 4.80418944],
[ 4.06069132, 4.99503862],
[ 3.00708934, 4.67897758],
[ 0.66632346, 4.87187949],
[ 3.1621865 , 4.83658301],
[ 0.51155258, 4.91052923],
[ 3.1342801 , 4.96178114],
[ 2.04974595, 5.6241395 ],
[ 0.66582785, 5.24399257],
[ 1.01732013, 4.84473647],
[ 2.17893568, 5.29758701],
[ 2.85962615, 5.26041997],
[ 1.30882588, 5.30158701],
[ 0.99253246, 5.01567424],
[ 1.40372638, 4.57527684],
[ 2.66046572, 5.19623848],
[ 2.79995882, 5.11526323],
[ 2.06995345, 4.6846713 ],
[ 3.29765181, 5.59205535],
[ 1.8929766 , 4.89043209],
[ 2.55983064, 5.26397756],
[ 1.15354031, 4.67866717],
[ 2.25150754, 5.4450031 ],
[ 2.20960296, 4.91469264],
[ 1.59141937, 4.83212573],
[ 1.67838038, 5.26903822],
[ 2.59148642, 4.92593394],
[ 2.80996442, 5.53849899],
[ 0.95311627, 5.58037108],
[ 1.51775276, 5.03836638],
[ 3.23114248, 5.78429665],
[ 2.54180011, 4.81098738],
[ 3.81422865, 4.73526796],
[ 1.68495829, 4.59643553],
[ 2.17777173, 4.86154019],
[ 1.8173328 , 5.13333907],
[ 1.85776553, 4.86962414],
[ 3.03084301, 5.24057582],
[ 2.92658295, 5.09667923],
[ 3.43493543, 5.34080741],
[ 3.20367116, 4.85924759],
[ 0.10511804, 4.72916344],
[ 1.40597916, 5.06636822],
[ 2.24185052, 4.9244617 ],
[ 1.36678395, 5.26161095],
[ 1.70725482, 4.04231479],
[ 1.91909566, 5.57848447],
[ 1.60156731, 4.64453012],
[ 0.37963437, 5.26194729],
[ 2.02134502, 4.41267445],
[ 1.12036737, 5.20880747],
[ 2.26901428, 4.61818883],
[-0.24512713, 5.74019237],
[ 2.12857843, 5.01149793],
[ 1.84419981, 5.03153948],
[ 2.32558253, 4.74867962],
[ 1.52334113, 4.87916159],
[ 1.02285128, 5.0105065 ],
[ 1.85382737, 5.00752482],
[ 2.20321658, 4.94516379],
[ 1.20099981, 4.57829763],
[ 1.02062703, 4.62991119],
[ 1.60493227, 5.13663139],
[ 0.47647355, 5.13535977],
[ 0.3639172 , 4.73332823],
[ 0.31319845, 5.54694644],
[ 2.28664839, 5.0076699 ],
[ 2.15460139, 5.46282959],
[ 2.05288518, 4.77958559],
[ 4.88804332, 5.50670795],
[ 2.40304747, 5.08147326],
[ 2.56869453, 5.20687886],
[ 1.82975993, 4.59657288],
[ 0.54845223, 5.0267298 ],
[ 3.17109619, 5.5946452 ],
[ 3.04202069, 5.00758373],
[ 2.40427775, 5.0258707 ],
[ 0.17783466, 5.29765032],
[ 2.61428678, 5.22287414],
[ 2.30097798, 4.97235844],
[ 3.90779317, 5.09464676],
[ 2.05670542, 5.23391326],
[ 1.38133497, 5.00194962],
[ 1.16074178, 4.67727927],
[ 1.72818199, 5.36028437],
[ 3.20360621, 0.7222149 ],
[ 3.06192918, 1.5719211 ],
[ 4.01714917, 1.16070647],
[ 1.40260822, 1.08726536],
[ 4.08164951, 0.87200343],
[ 3.15273081, 0.98155871],
[ 3.45186351, 0.42784083],
[ 3.85384314, 0.7920479 ],
[ 1.57449255, 1.34811126],
[ 4.72372078, 0.62044136],
[ 2.87961084, 0.75413741],
[ 0.96791348, 1.16166819],
[ 1.53178107, 1.10054852],
[ 4.13835915, 1.24780979],
[ 3.16109021, 1.29422893],
[ 2.95177039, 0.89583143],
[ 3.27844295, 1.75043926],
[ 2.1270185 , 0.95672042],
[ 3.32648885, 1.28019066],
[ 2.54371489, 0.95732716],
[ 3.233947 , 1.08202324],
[ 4.43152976, 0.54041 ],
[ 3.56478625, 1.11764714],
[ 4.25588482, 0.90643957],
[ 4.05386581, 0.53291862],
[ 3.08970176, 1.08814448],
[ 2.84734459, 0.26759253],
[ 3.63586049, 1.12160194],
[ 1.95538864, 1.32156857],
[ 2.88384005, 0.80454506],
[ 3.48444387, 1.13551448],
[ 3.49798412, 1.10046402],
[ 2.45575934, 0.78904654],
[ 3.2038001 , 1.02728075],
[ 3.00677254, 0.62519128],
[ 1.96547974, 1.2173076 ],
[ 2.17989333, 1.30879831],
[ 2.61207029, 0.99076856],
[ 3.95549912, 0.83269299],
[ 3.64846482, 1.62849697],
[ 4.18450011, 0.45356203],
[ 3.7875723 , 1.45442904],
[ 3.30063655, 1.28107588],
[ 3.02836363, 1.35635189],
[ 3.18412176, 1.41410799],
[ 4.16911897, 0.20581038],
[ 3.24024211, 1.14876237],
[ 3.91596068, 1.01225774],
[ 2.96979716, 1.01210306],
[ 1.12993856, 0.77085284],
[ 2.71730799, 0.48697555],
[ 3.1189017 , 0.69438336],
[ 2.4051802 , 1.11778123],
[ 2.95818429, 1.01887096],
[ 1.65456309, 1.18631175],
[ 2.39775807, 1.24721387],
[ 2.28409305, 0.64865469],
[ 2.79588724, 0.99526664],
[ 3.41156277, 1.1596363 ],
[ 3.50663521, 0.73878104],
[ 3.93616029, 1.46202934],
[ 3.90206657, 1.27778751],
[ 2.61036396, 0.88027602],
[ 4.37271861, 1.02914092],
[ 3.08349136, 1.19632644],
[ 2.1159935 , 0.7930365 ],
[ 2.15653404, 0.40358861],
[ 2.14491101, 1.13582399],
[ 1.84935524, 1.02232644],
[ 4.1590816 , 0.61720733],
[ 2.76494499, 1.43148951],
[ 3.90561153, 1.16575315],
[ 2.54071672, 0.98392516],
[ 4.27783068, 1.1801368 ],
[ 3.31058167, 1.03124461],
[ 2.15520661, 0.80696562],
[ 3.71363659, 0.45813208],
[ 3.54010186, 0.86446135],
[ 1.60519991, 1.1098053 ],
[ 1.75164337, 0.68853536],
[ 3.12405123, 0.67821757],
[ 2.37198785, 1.42789607],
[ 2.53446019, 1.21562081],
[ 3.6834465 , 1.22834538],
[ 3.2670134 , 0.32056676],
[ 3.94159139, 0.82577438],
[ 3.2645514 , 1.3836869 ],
[ 4.30471138, 1.10725995],
[ 2.68499376, 0.35344943],
[ 3.12635184, 1.2806893 ],
[ 2.94294356, 1.02825076],
[ 3.11876541, 1.33285459],
[ 2.02358978, 0.44771614],
[ 3.62202931, 1.28643763],
[ 2.42865879, 0.86499285],
[ 2.09517296, 1.14010491],
[ 5.29239452, 0.36873298],
[ 2.07291709, 1.16763851],
[ 0.94623208, 0.24522253],
[ 2.73911908, 1.10072284],
[ 6.00506534, 2.72784171],
[ 6.05696411, 2.94970433],
[ 6.77012767, 3.21411422],
[ 5.64034678, 2.69385282],
[ 5.63325403, 2.99002339],
[ 6.17443157, 3.29026488],
[ 7.24694794, 2.96877424],
[ 5.58162906, 3.33510375],
[ 5.3627205 , 3.14681192],
[ 4.70775773, 2.78710869],
[ 7.42892098, 3.4667949 ],
[ 6.64107248, 3.05998738],
[ 6.37473652, 2.56253059],
[ 7.28780324, 2.75179885],
[ 6.20295231, 2.67856179],
[ 5.38736041, 2.26737346],
[ 5.6673103 , 2.96477867],
[ 6.59702155, 3.07082376],
[ 7.75660559, 3.15604465],
[ 6.63262745, 3.14799183],
[ 5.76634959, 3.14271707],
[ 5.99423154, 2.75707858],
[ 6.37870407, 2.65022321],
[ 5.74036233, 3.10391306],
[ 4.61652442, 2.79320715],
[ 5.33533999, 3.03928694],
[ 5.37293912, 2.81684776],
[ 5.03611162, 2.92486087],
[ 5.52908677, 3.33681576],
[ 6.05086942, 2.80702594],
[ 5.132009 , 2.19812195],
[ 5.73284945, 2.87738132],
[ 6.78110732, 3.05676866],
[ 6.44834449, 3.35299225],
[ 6.39941482, 2.89756948],
[ 5.86067925, 2.99577129],
[ 6.44765183, 3.16560945],
[ 5.36708111, 3.19502552],
[ 5.88735565, 3.34615566],
[ 3.96162465, 2.72025046],
[ 6.28438193, 3.17360643],
[ 4.20584789, 2.81647368],
[ 5.32615581, 3.03314047],
[ 7.17135204, 3.4122727 ],
[ 7.4949275 , 2.84018754],
[ 7.39807241, 3.48487031],
[ 5.02432984, 2.98683179],
[ 5.31712478, 2.81741356],
[ 5.87655237, 3.21661109],
[ 6.03762833, 2.68303512],
[ 5.91280273, 2.85631938],
[ 6.69451358, 2.89056083],
[ 6.01017978, 2.72401338],
[ 6.92721968, 3.19960026],
[ 6.33559522, 3.30864291],
[ 6.24257071, 2.79179269],
[ 5.57812294, 3.24766016],
[ 6.40773863, 2.67554951],
[ 6.80029526, 3.17579578],
[ 7.21684033, 2.72896575],
[ 6.5110074 , 2.72731907],
[ 4.60630534, 3.329458 ],
[ 7.65503226, 2.87095628],
[ 5.50295759, 2.62924634],
[ 6.63060699, 3.01502301],
[ 3.45928006, 2.68478445],
[ 8.20339815, 2.41693495],
[ 4.95679428, 2.89776297],
[ 5.37052667, 2.44954813],
[ 5.69797866, 2.94977132],
[ 6.27376271, 2.24256036],
[ 5.05274526, 2.75692163],
[ 6.88575584, 2.88845269],
[ 4.1877442 , 2.89283463],
[ 5.97510328, 3.0259191 ],
[ 6.09457129, 2.61867975],
[ 5.72395697, 3.04454219],
[ 4.37249767, 3.05488217],
[ 6.29206262, 2.77573856],
[ 5.14533035, 4.13225692],
[ 6.5870565 , 3.37508345],
[ 5.78769095, 3.29255127],
[ 6.72798098, 3.0043983 ],
[ 6.64078939, 2.41068839],
[ 6.23228878, 2.72850902],
[ 6.21772724, 2.80994633],
[ 5.78116301, 3.07987787],
[ 6.62447253, 2.74453743],
[ 5.19590823, 3.06972937],
[ 5.87177181, 3.2551773 ],
[ 5.89562099, 2.89843977],
[ 5.6175432 , 2.5975071 ],
[ 5.63176103, 3.04758747],
[ 5.50258659, 3.11869075],
[ 6.48212628, 2.5508514 ],
[ 7.30278708, 3.38015979],
[ 6.99198434, 2.98706729],
[ 4.8255341 , 2.77961664],
[ 6.11768055, 2.85475655],
[ 0.94048944, 5.71556802]])}
# Cluster the example dataset and plot each cluster in its own colour,
# with the final centroids marked.
X = data['X']
plt.scatter(X[:, 0], X[:, 1])
plt.show()

k = 3
idx, centroids = run_k_means(X, init_centroids(X, k), 10)

# Rows of X belonging to each cluster (boolean mask per cluster index).
clusters = [X[idx == i] for i in range(k)]
cluster1, cluster2, cluster3 = clusters  # names kept for any later cells

colours = ['r', 'g', 'b']
fig, ax = plt.subplots(figsize=(12, 8))
for i, cluster in enumerate(clusters):
    ax.scatter(cluster[:, 0], cluster[:, 1], s=30,
               color=colours[i % len(colours)], label=f'Cluster {i + 1}')
ax.scatter(centroids[:, 0], centroids[:, 1], s=150, marker='x',
           color='k', label='Centroids')
ax.legend()
plt.show()
Image compression with K-means
# Show the original bird image, then load its pixel data from the .mat file.
from IPython.display import Image

Image(filename='bird_small.png')
image = loadmat('bird_small.mat')
image
{'__header__': b'MATLAB 5.0 MAT-file, Platform: GLNXA64, Created on: Tue Jun 5 04:06:24 2012',
'__version__': '1.0',
'__globals__': [],
'A': array([[[219, 180, 103],
[230, 185, 116],
[226, 186, 110],
...,
[ 14, 15, 13],
[ 13, 15, 12],
[ 12, 14, 12]],
[[230, 193, 119],
[224, 192, 120],
[226, 192, 124],
...,
[ 16, 16, 13],
[ 14, 15, 10],
[ 11, 14, 9]],
[[228, 191, 123],
[228, 191, 121],
[220, 185, 118],
...,
[ 14, 16, 13],
[ 13, 13, 11],
[ 11, 15, 10]],
...,
[[ 15, 18, 16],
[ 18, 21, 18],
[ 18, 19, 16],
...,
[ 81, 45, 45],
[ 70, 43, 35],
[ 72, 51, 43]],
[[ 16, 17, 17],
[ 17, 18, 19],
[ 20, 19, 20],
...,
[ 80, 38, 40],
[ 68, 39, 40],
[ 59, 43, 42]],
[[ 15, 19, 19],
[ 20, 20, 18],
[ 18, 19, 17],
...,
[ 65, 43, 39],
[ 58, 37, 38],
[ 52, 39, 34]]], dtype=uint8)}
# Raw pixel array stored under key 'A' in the loaded .mat file.
A = image['A']
A.shape
(128, 128, 3)
# Preprocess the image: scale the 0-255 intensities into [0, 1].
A = A / 255
# Flatten to one row per pixel and one column per colour channel.
X = A.reshape(-1, A.shape[2])
X.shape
(16384, 3)
# Compress the image to a 16-colour palette by clustering its pixels.
initial_centroids = init_centroids(X, 16)
idx, centroids = run_k_means(X, initial_centroids, 10)
# Reassign every pixel against the final centroids: the idx returned by
# run_k_means may have been computed before its last centroid update.
idx = find_closest_centroids(X, centroids)
# Replace each pixel by the colour of its cluster centre.
X_recovered = centroids[idx.astype(int)]
X_recovered.shape
(16384, 3)
# Restore the original (height, width, channels) layout of the image.
X_recovered = X_recovered.reshape(A.shape)
X_recovered.shape
(128, 128, 3)
# Display the colour-quantized (compressed) image.
plt.imshow(X_recovered)
plt.show()
Principal component analysis
def pca(X):
    """Principal component analysis of ``X`` via SVD of its covariance matrix.

    Each feature (column) is standardized to zero mean and unit variance
    before the covariance matrix is formed.

    Parameters
    ----------
    X : array-like, shape (m, n)
        Samples in rows, features in columns.

    Returns
    -------
    U : ndarray, shape (n, n)
        Principal directions, one per column, in the order of decreasing
        singular value (as returned by np.linalg.svd).
    S : ndarray, shape (n,)
        Singular values of the covariance matrix, in descending order.
    V : ndarray, shape (n, n)
        The third SVD factor (``Vh`` in np.linalg.svd's notation).
    """
    X = np.asarray(X, dtype=float)
    # Standardize each feature separately (axis=0). A single global mean/std
    # leaves the features on their original relative scales, which distorts
    # the principal directions.
    mu = X.mean(axis=0)
    sigma = X.std(axis=0)
    # A constant feature has zero spread: centre it but do not divide by zero.
    sigma[sigma == 0] = 1.0
    X_norm = (X - mu) / sigma
    # Covariance matrix of the standardized data (plain ndarray, not the
    # discouraged np.matrix class).
    cov = X_norm.T @ X_norm / X_norm.shape[0]
    U, S, V = np.linalg.svd(cov)
    return U, S, V
用主成分U将数据X投影到较低维空间
def project_data(X, U, k):
    """Project ``X`` onto the first ``k`` columns of ``U``.

    Returns ``np.dot(X, U[:, :k])``; for ``X`` of shape (m, n) the result
    has shape (m, k).
    """
    top_k_directions = U[:, :k]  # keep only the first k principal directions
    return np.dot(X, top_k_directions)
# Load the 2-D example dataset used for PCA and echo its contents.
data = loadmat('ex7data1.mat')
data
{'__header__': b'MATLAB 5.0 MAT-file, Platform: PCWIN64, Created on: Mon Nov 14 22:41:44 2011',
'__version__': '1.0',
'__globals__': [],
'X': array([[3.38156267, 3.38911268],
[4.52787538, 5.8541781 ],
[2.65568187, 4.41199472],
[2.76523467, 3.71541365],
[2.84656011, 4.17550645],
[3.89067196, 6.48838087],
[3.47580524, 3.63284876],
[5.91129845, 6.68076853],
[3.92889397, 5.09844661],
[4.56183537, 5.62329929],
[4.57407171, 5.39765069],
[4.37173356, 5.46116549],
[4.19169388, 4.95469359],
[5.24408518, 4.66148767],
[2.8358402 , 3.76801716],
[5.63526969, 6.31211438],
[4.68632968, 5.6652411 ],
[2.85051337, 4.62645627],
[5.1101573 , 7.36319662],
[5.18256377, 4.64650909],
[5.70732809, 6.68103995],
[3.57968458, 4.80278074],
[5.63937773, 6.12043594],
[4.26346851, 4.68942896],
[2.53651693, 3.88449078],
[3.22382902, 4.94255585],
[4.92948801, 5.95501971],
[5.79295774, 5.10839305],
[2.81684824, 4.81895769],
[3.88882414, 5.10036564],
[3.34323419, 5.89301345],
[5.87973414, 5.52141664],
[3.10391912, 3.85710242],
[5.33150572, 4.68074235],
[3.37542687, 4.56537852],
[4.77667888, 6.25435039],
[2.6757463 , 3.73096988],
[5.50027665, 5.67948113],
[1.79709714, 3.24753885],
[4.3225147 , 5.11110472],
[4.42100445, 6.02563978],
[3.17929886, 4.43686032],
[3.03354125, 3.97879278],
[4.6093482 , 5.879792 ],
[2.96378859, 3.30024835],
[3.97176248, 5.40773735],
[1.18023321, 2.87869409],
[1.91895045, 5.07107848],
[3.95524687, 4.5053271 ],
[5.11795499, 6.08507386]])}
# Plot the raw PCA dataset, then compute its principal components.
X = data['X']
plt.scatter(X[:, 0], X[:, 1])
plt.show()
U, S, V = pca(X)
U
matrix([[-0.79241747, -0.60997914],
[-0.60997914, 0.79241747]])
# Project the data onto the first principal component.
# NOTE(review): pca() standardizes its input internally, but the raw X is
# projected here, so Z mixes the original feature scales with directions
# found on standardized data -- consider projecting the standardized X.
Z = project_data(X, U, k=1)
Z
matrix([[-4.74689738],
[-7.15889408],
[-4.79563345],
[-4.45754509],
[-4.80263579],
[-7.04081342],
[-4.97025076],
[-8.75934561],
[-6.2232703 ],
[-7.04497331],
[-6.91702866],
[-6.79543508],
[-6.3438312 ],
[-6.99891495],
[-4.54558119],
[-8.31574426],
[-7.16920841],
[-5.08083842],
[-8.54077427],
[-6.94102769],
[-8.5978815 ],
[-5.76620067],
[-8.2020797 ],
[-6.23890078],
[-4.37943868],
[-5.56947441],
[-7.53865023],
[-7.70645413],
[-5.17158343],
[-6.19268884],
[-6.24385246],
[-8.02715303],
[-4.81235176],
[-7.07993347],
[-5.45953289],
[-7.60014707],
[-4.39612191],
[-7.82288033],
[-3.40498213],
[-6.54290343],
[-7.17879573],
[-5.22572421],
[-4.83081168],
[-7.23907851],
[-4.36164051],
[-6.44590096],
[-2.69118076],
[-4.61386195],
[-5.88236227],
[-7.76732508]])
# Recover the data from its projection.
def recover_data(Z, U, k):
    """Map projections ``Z`` back into the original feature space.

    Uses the same first ``k`` columns of ``U`` as project_data and returns
    ``np.dot(Z, U[:, :k].T)``; for ``Z`` of shape (m, k) the result has
    shape (m, n).
    """
    top_k_directions = U[:, :k]
    return np.dot(Z, top_k_directions.T)
# Reconstruct the data from its 1-D projection.
X_recovered = recover_data(Z, U, 1)
X_recovered
matrix([[3.76152442, 2.89550838],
[5.67283275, 4.36677606],
[3.80014373, 2.92523637],
[3.53223661, 2.71900952],
[3.80569251, 2.92950765],
[5.57926356, 4.29474931],
[3.93851354, 3.03174929],
[6.94105849, 5.3430181 ],
[4.93142811, 3.79606507],
[5.58255993, 4.29728676],
[5.48117436, 4.21924319],
[5.38482148, 4.14507365],
[5.02696267, 3.8696047 ],
[5.54606249, 4.26919213],
[3.60199795, 2.77270971],
[6.58954104, 5.07243054],
[5.681006 , 4.37306758],
[4.02614513, 3.09920545],
[6.76785875, 5.20969415],
[5.50019161, 4.2338821 ],
[6.81311151, 5.24452836],
[4.56923815, 3.51726213],
[6.49947125, 5.00309752],
[4.94381398, 3.80559934],
[3.47034372, 2.67136624],
[4.41334883, 3.39726321],
[5.97375815, 4.59841938],
[6.10672889, 4.70077626],
[4.09805306, 3.15455801],
[4.90719483, 3.77741101],
[4.94773778, 3.80861976],
[6.36085631, 4.8963959 ],
[3.81339161, 2.93543419],
[5.61026298, 4.31861173],
[4.32622924, 3.33020118],
[6.02248932, 4.63593118],
[3.48356381, 2.68154267],
[6.19898705, 4.77179382],
[2.69816733, 2.07696807],
[5.18471099, 3.99103461],
[5.68860316, 4.37891565],
[4.14095516, 3.18758276],
[3.82801958, 2.94669436],
[5.73637229, 4.41568689],
[3.45624014, 2.66050973],
[5.10784454, 3.93186513],
[2.13253865, 1.64156413],
[3.65610482, 2.81435955],
[4.66128664, 3.58811828],
[6.1549641 , 4.73790627]])
# Plot the reconstructed points; np.array() turns each column into a plain
# ndarray in case X_recovered is an np.matrix.
plt.scatter(np.array(X_recovered[:, 0]), np.array(X_recovered[:, 1]))
plt.show()