# -*- coding: utf-8 -*-
"""
Created on Tue Jul 26 15:16:42 2016
@author: brian
"""
import numpy as np
import matplotlib.pyplot as plt
def loadData(file="testSet.txt"):
    """Load a whitespace-delimited numeric text file as a 2-D array.

    Parameters
    ----------
    file : str, path-like or file-like
        Anything accepted by ``numpy.loadtxt``. Defaults to
        ``"testSet.txt"``.

    Returns
    -------
    numpy.ndarray
        The parsed values, always two-dimensional (rows are samples).
    """
    # ndmin=2 keeps a single-row or single-column file two-dimensional, so
    # callers can index columns (data[:, j]) without special-casing.
    return np.loadtxt(file, ndmin=2)
def showData(data):
    """Plot the first two columns of ``data`` as blue circle markers.

    Column 0 supplies the x coordinates and column 1 the y coordinates.
    The points are drawn with ``plt.plot``; nothing is returned.
    """
    xs = data[:, 0]
    ys = data[:, 1]
    plt.plot(xs, ys, "bo")
def classify(dataSet, centers):
    """Assign every sample to its nearest center (Euclidean distance).

    Parameters
    ----------
    dataSet : array-like, shape (n_samples, n_dims)
        The samples to label.
    centers : array-like, shape (k, n_dims)
        The current cluster centers.

    Returns
    -------
    numpy.ndarray, shape (n_samples,)
        For each sample, the row index of the closest center. Ties are
        resolved in favour of the lowest index (``numpy.argmin`` semantics).
    """
    samples = np.asarray(dataSet, dtype=float)
    cents = np.asarray(centers, dtype=float)

    # Broadcast to (n_samples, k, n_dims) so the squared differences can be
    # summed over *all* feature dimensions, not just the first two.
    offsets = samples[:, np.newaxis, :] - cents[np.newaxis, :, :]
    distances = np.sqrt((offsets ** 2).sum(axis=2))
    return np.argmin(distances, axis=1)
def randCent(dataSet, k):
    """Pick ``k`` random centers inside the bounding box of the samples.

    Each coordinate of each center is drawn uniformly between the minimum
    and maximum of the corresponding column of ``dataSet``.

    Parameters
    ----------
    dataSet : array-like, shape (n_samples, n_dims)
        The samples whose per-dimension range bounds the centers.
    k : int
        Number of centers to generate.

    Returns
    -------
    numpy.ndarray, shape (k, n_dims)
        The randomly initialised centers.
    """
    samples = np.asarray(dataSet, dtype=float)
    n = samples.shape[1]

    lows = samples.min(axis=0)
    spans = samples.max(axis=0) - lows

    # Plain ndarrays are used throughout: the np.mat alias no longer exists
    # in NumPy 2.x. The uniform numbers are drawn as (n, k) and transposed
    # so that they are consumed from the global RNG dimension by dimension,
    # the same order a per-column loop drawing k values at a time would use.
    unit = np.random.rand(n, k).T
    return lows + spans * unit
def updataCents(datas):
    """Return the centroid of ``datas``: its mean taken along axis 0."""
    centroid = datas.mean(axis=0)
    return centroid
def Kmeans(data, k=4, iteration=10):
    """Run k-means on ``data`` and plot the clustering at every iteration.

    Each iteration assigns the samples to their nearest center, draws the
    clusters (circles) and their current centers (crosses) in a separate
    figure using the first two columns of ``data``, and then moves every
    center to the mean of its members.

    Parameters
    ----------
    data : numpy.ndarray, shape (n_samples, n_dims)
        The samples to cluster; at least two columns are needed for the
        plots.
    k : int, optional
        Number of clusters.
    iteration : int, optional
        Number of assign/update rounds to perform (one figure per round).

    Returns
    -------
    tuple
        ``(centers, labels)``: the final centers and, for each sample, the
        index of its nearest final center.
    """
    # One matplotlib colour code per cluster; cycled when k exceeds the
    # palette so that every cluster is drawn.
    palette = "bgrcmyk"

    cents = randCent(data, k)
    for j in range(iteration):
        clasIndex = classify(data, cents)
        plt.figure(j)
        for i in range(k):
            members = data[clasIndex == i]
            color = palette[i % len(palette)]
            plt.plot(members[:, 0], members[:, 1], color + "o")
            plt.plot(cents[i, 0], cents[i, 1], color + "x",
                     markersize=12, markeredgewidth=3)
            # An empty cluster has no mean (it would become NaN and then
            # capture every sample on the next assignment), so its center
            # is left where it is.
            if len(members) > 0:
                cents[i, :] = updataCents(members)
    plt.show()
    return cents, classify(data, cents)
# Demo run: load the default data file and cluster it into k groups,
# showing one figure per iteration.
k = 4
data = loadData()
# Pass k explicitly so that editing the value above actually takes effect.
Kmeans(data, k)