python代码实现的特效,最小python代码实现Kmeans算法并图形化展示

Kmeans算法简介:

Kmeans算法基本思想是初始随机给定K个簇中心,按照距离最近的原则把待分类的样本点分到各个簇,然后根据平均值计算新的簇的质心。一直迭代直到两次簇心之间的迭代距离小于要求的值。

基本步骤

未知簇心的数据集

初始化簇心

随机选取K个数据做为簇心

计算各个点到簇心的距离,并聚类到离该点最近的簇心上去

计算每一个簇类距离平均,并将这个平均值做为新的簇心

重复4

重复5

数据:

数据来源是CSIE,里面的adult数据,主要分类有:

age: continuous.

workclass: Private, Self-emp-not-inc, Self-emp-inc, Federal-gov, Local-gov, State-gov, Without-pay, Never-worked.

fnlwgt: continuous.

education: Bachelors, Some-college, 11th, HS-grad, Prof-school, Assoc-acdm, Assoc-voc, 9th, 7th-8th, 12th, Masters, 1st-4th, 10th, Doctorate, 5th-6th, Preschool.

education-num: continuous.

marital-status: Married-civ-spouse, Divorced, Never-married, Separated, Widowed, Married-spouse-absent, Married-AF-spouse.

occupation: Tech-support, Craft-repair, Other-service, Sales, Exec-managerial, Prof-specialty, Handlers-cleaners, Machine-op-inspct, Adm-clerical, Farming-fishing, Transport-moving, Priv-house-serv, Protective-serv, Armed-Forces.

relationship: Wife, Own-child, Husband, Not-in-family, Other-relative, Unmarried.

race: White, Asian-Pac-Islander, Amer-Indian-Eskimo, Other, Black.

sex: Female, Male.

capital-gain: continuous.

capital-loss: continuous.

hours-per-week: continuous.

native-country: United-States, Cambodia, England, Puerto-Rico, Canada, Germany, Outlying-US(Guam-USVI-etc), India, Japan, Greece, South, China, Cuba, Iran, Honduras, Philippines, Italy, Poland, Jamaica, Vietnam, Mexico, Portugal, Ireland, France, Dominican-Republic, Laos, Ecuador, Taiwan, Haiti, Columbia, Hungary, Guatemala, Nicaragua, Scotland, Thailand, Yugoslavia, El-Salvador, Trinadad&Tobago, Peru, Hong, Holand-Netherlands.

需要对数据做mapping,这样才能计算距离:

convert_work= {

b'workclass':0.0,

b'Private':1.0,

b'Self-emp-not-inc':2.0,

b'Self-emp-inc':3.0,

b'Federal-gov':4.0,

b'Local-gov':5.0,

b'State-gov':6.0,

b'Without-pay':7.0,

b'Never-worked':8.0,

}

marry_convert = {

b"marital-status":1.0,

b"Married-civ-spouse":2.0,

b"Divorced":3.0,

b"Never-married":4.0,

b"Separated":5.0,

b"Widowed":6.0,

b"Married-spouse-absent":7.0,

b"Married-AF-spouse":8.0,

}

occupation_convert={

b"Tech-support":1.0,

b"Craft-repair":2.0,

b"Other-service":3.0,

b"Sales":4.0,

b"Exec-managerial":5.0,

b"Prof-specialty":6.0,

b"Handlers-cleaners":7.0,

b"Machine-op-inspct":8.0,

b"Adm-clerical":9.0,

b"Farming-fishing":10.0,

b"Transport-moving":11.0,

b"Priv-house-serv":12.0,

b"Protective-serv":13.0,

b"Armed-Forces":14.0,

}

relationship_convert= {

b"Wife":1.0,

b"Own-child":2.0,

b"Husband":3.0,

b"Not-in-family":4.0,

b"Other-relative":5.0,

b"Unmarried" :6.0,

}

race_convert = {

b"White":1.0,

b"Asian-Pac-Islander":2.0,

b"Amer-Indian-Eskimo":3.0,

b"Other":4.0,

b"Black":5.0,

}

sex_convert = {

b"Female":0,

b"Male" :1,

}

country_convert = {

b"United-States":1.0,

b"Cambodia":2.0,

b"England":3.0,

b"Puerto-Rico":4.0,

b"Canada":5.0,

b"Germany":6.0,

b"Outlying-US(Guam-USVI-etc)":7.0,

b"India":8.0,

b"Japan":9.0,

b"Greece":10.0,

b"South":11.0,

b"China":12.0,

b"Cuba":13.0,

b"Iran":14.0,

b"Honduras":15.0,

b"Philippines":16.0,

b"Italy":17.0,

b"Poland":18.0,

b"Jamaica":19.0,

b"Vietnam":20.0,

b"Mexico":21.0,

b"Portugal":22.0,

b"Ireland":23.0,

b"France":24.0,

b"Dominican-Republic":25.0,

b"Laos":26.0,

b"Ecuador":27.0,

b"Taiwan":28.0,

b"Haiti":29.0,

b"Columbia":30.0,

b"Hungary":31.0,

b"Guatemala":32.0,

b"Nicaragua":33.0,

b"Scotland":34.0,

b"Thailand":35.0,

b"Yugoslavia":36.0,

b"El-Salvador":37.0,

b"Trinadad&Tobago":38.0,

b"Peru":39.0,

b"Hong":40.0,

b"Holand-Netherlands.":41.0,

b"":1.0

}

class_convert = {

b">50K":0.0,

b"<=50K":1.0,

b"":1.0

}

convert_education={

b"Bachelors":0.0,

b"Some-college":1.0,

b"11th":2.0,

b"HS-grad":3.0,

b"Prof-school":4.0,

b"Assoc-acdm":5.0,

b"Assoc-voc":6.0,

b"9th":7.0,

b"7th-8th":8.0,

b"12th":9.0,

b"Masters":10.0,

b"1st-4th":11.0,

b"10th":12.0,

b"Doctorate":13.0,

b"5th-6th":14.0,

b"Preschool":15.0,

b"":1.0

}

读取文件:

def read_file(file,one_hot=False):

np_file = np.genfromtxt(fname=file,delimiter=',',\

replace_space='',

converters={

1:lambda x:convert_work.get(x,1.0), \

3:lambda x:convert_education.get(x,7.0), \

5:lambda x:marry_convert.get(x,4.0), \

6:lambda x:occupation_convert.get(x,3.0), \

7:lambda x:relationship_convert.get(x,5.0), \

8:lambda x:race_convert.get(x,4.0), \

9:lambda x:sex_convert.get(x,1.0), \

13:lambda x:country_convert.get(x,1.0), \

14:lambda x:class_convert.get(x,1.0)}

)

if one_hot:

pass

else:

return np_file

算法代码如下所示:

def computeDis(x,y):

sub_ = x -y

dist = np.sqrt(np.sum(np.power(sub_,2),axis=1))

return dist

def initClusterPoint(data,k):

##随机选取K个点做为初始聚类点

centers = data[np.random.randint(0,data.shape[0],k)]

return centers

def getClusterPoint(data,k,centers):

rows = data.shape[0]

all_distance = np.empty([rows,k])

for index,center in enumerate(centers):

distanc = computeDis(data,center)

all_distance[:,index] = distanc

print(centers)

small_dis = all_distance.min(axis=1)

small_dis_index = all_distance.argmin(axis=1)

##更新centers

new_centers = np.empty(centers.shape)

for index,center in enumerate(centers):

index_num = (small_dis_index == index).sum()

data_index_small = data[small_dis_index==index]

new_centers[index] = np.sum(data_index_small,axis=0) / index_num

return small_dis_index,new_centers

if __name__ == "__main__":

data = read_file('./adult.txt')

# data = read_small_file('./sample_kmeans_data.txt')

k= 5

tempDist = 1.0

convergeDist = 0.01

init_centor = initClusterPoint(data=data,k=k)

data_cluster = np.empty(data.shape[0])

while tempDist > convergeDist:

data_cluster,centers = getClusterPoint(data=data,k=k,centers=init_centor)

# tempDist = np.sqrt(np.dot(np.power(init_centor-centers,2)))

tempDist = computeDis(init_centor,centers).max()

init_centor = centers

draw(Data=data,centers=init_centor,index=data_cluster)

绘图:

def draw(Data,centers,index=None):

"""

特征数是15维时,可以降维展示分类效果,

不代表实际数据的分布

"""

plt.title('Kmeans classifier of Adult')

fig,axes = plt.subplots(1,2)

pca = decomposition.PCA(n_components=2)

new_data = pca.fit_transform(Data)

print (new_data)

axes[0].scatter(new_data[:,0], new_data[:,1], marker='o',alpha=0.5)

for center_index,center in enumerate(centers):

data_index = new_data[index==center_index]

axes[1].scatter(data_index[:,0], data_index[:,1], marker='o',color=next(palette),alpha=0.5)

plt.show()

效果如下:

rMB3Ub.png kmeans.png

完整代码在:

作者:marvinxu

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值