import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
data = []  # raw records read from the ';'-separated text file, one list of fields per line
cty_data = []  # row indices into `data`, grouped so consecutive rows of one country share a list
cdata = []  # the three-column samples extracted from `data` for clustering


def data_p(fliename):
    """Load a ';'-separated text file into the module-level ``data`` / ``cty_data``.

    Every non-blank line is stripped, split on ';' and appended to ``data``.
    The index of that row is appended to the last group in ``cty_data`` when
    its first field (the country name) equals the first field of the row that
    opened the group; otherwise a new group is started.  Only *consecutive*
    rows with the same name are grouped together.

    Args:
        fliename: path of the text file to read.

    Returns:
        None.  ``data`` and ``cty_data`` are extended in place.
    """
    # `with` guarantees the handle is closed, even if a line fails to parse.
    with open(fliename) as fr:
        for line in fr:
            stripped = line.strip()
            if not stripped:
                # A blank line would otherwise be stored as the bogus record
                # [''] and start a spurious "country" group.
                continue
            fields = stripped.split(';')
            # Use the real position in `data` as the row index so indices stay
            # valid even if the lists already contain rows.
            row = len(data)
            data.append(fields)
            if cty_data and fields[0] == data[cty_data[-1][0]][0]:
                cty_data[-1].append(row)
            else:
                cty_data.append([row])
# Fill the module-level `data` / `cty_data` lists from the hard-coded input file.
data_p('C:\\Users\\imac\\Desktop\\2018\\bigdata\\py\\wealth1951.txt')
# Convert the per-country index lists to a NumPy array.
# NOTE(review): np.asarray only yields a regular 2-D array when every country
# has the same number of rows -- confirm against the data file.
cty_data = np.asarray(cty_data)
#print(cty_data)
# Convert the raw records (lists of strings) to an array so they can be sliced by column.
data = np.asarray(data)
#print(data[cty_data[:,:],0])
def getdata():
    """Append columns 3-5 of each country's last row to ``cdata``.

    Reads the module-level ``cty_data`` (groups of row indices) and ``data``
    (2-D array of records).  For every group, the row addressed by the
    group's last index is looked up in ``data`` and its columns 3, 4 and 5
    are appended to the module-level ``cdata`` list.

    Returns:
        None.  ``cdata`` is extended in place.
    """
    # Iterating the groups and taking rows[-1] works for a 2-D index array as
    # well as for a list of index lists of unequal length, whereas the 2-D
    # index cty_data[i, -1] requires a rectangular array.
    cdata.extend(data[rows[-1], 3:6] for rows in cty_data)
# Collect one three-column sample per country into `cdata` ...
getdata()
# ... and convert the collected string fields to a float32 array for clustering.
cdata = np.asarray(cdata,dtype=np.float32)
#print(cdata)
def distEclud(vecA, vecB):
    """Return the Euclidean distance between ``vecA`` and ``vecB``.

    Args:
        vecA: coordinate vector (NumPy array or any array-like sequence).
        vecB: coordinate vector broadcast-compatible with ``vecA``.

    Returns:
        The square root of the sum of squared differences.  The sum is taken
        along the first axis, so 1-D inputs give a scalar.
    """
    # np.asarray lets plain lists/tuples be passed; arrays are not copied.
    diff = np.asarray(vecA) - np.asarray(vecB)
    # np.sum runs in native code; axis=0 keeps the first-axis reduction that
    # the built-in sum() performed on an array.
    return np.sqrt(np.sum(np.square(diff), axis=0))
def randCent(dataSet, k):
    """Draw ``k`` random initial centroids inside the bounding box of ``dataSet``.

    Args:
        dataSet: 2-D array with one point per row.
        k: number of centroids to create.

    Returns:
        A float32 array of shape ``(k, dataSet.shape[1])``; each column is
        sampled uniformly between that column's minimum and maximum.
    """
    n_cols = dataSet.shape[1]
    centers = np.zeros([k, n_cols], dtype=np.float32)
    # One draw of k values per coordinate, column by column.
    for col in range(n_cols):
        column = dataSet[:, col]
        low, high = column.min(), column.max()
        centers[:, col] = np.random.uniform(low, high, k)
    return centers
def kMeans(dataSet, k, distMeas=distEclud, createCent=randCent):
    """Cluster the rows of ``dataSet`` into ``k`` groups with k-means.

    The loop alternates between assigning every point to its nearest centroid
    and moving every centroid to the mean of its assigned points, and stops
    once a full assignment pass changes no label.

    Args:
        dataSet: 2-D array with one point per row.
        k: number of clusters.
        distMeas: callable ``(point, centroid) -> distance`` used to rank
            centroids for each point.
        createCent: callable ``(dataSet, k) -> array`` returning the initial
            ``(k, dim)`` centroid array.

    Returns:
        ``(ctmp, cent)``: ``ctmp`` is an int32 array of shape ``(n, 1)``
        holding the centroid index assigned to each row; ``cent`` is the
        ``(k, dim)`` array of final centroids.
    """
    n_points = dataSet.shape[0]
    # Labels start at -1 (no cluster) so the first pass always registers a
    # change; starting at 0 would let the loop exit before convergence when
    # every point initially falls into cluster 0.
    ctmp = np.full([n_points, 1], -1, dtype=np.int32)
    cent = createCent(dataSet, k)
    change = True
    while change:
        change = False
        # Assignment step: the first centroid with the strictly smallest
        # distance wins.
        for i in range(n_points):
            mindis = float("inf")
            mcent = -1
            for j in range(k):
                # Honour the caller-supplied metric instead of a fixed one.
                tdis = distMeas(dataSet[i], cent[j])
                if tdis < mindis:
                    mindis = tdis
                    mcent = j
            if mcent != ctmp[i, 0]:
                ctmp[i, 0] = mcent
                change = True
        # Update step: move each centroid to the mean of its members.
        labels = ctmp[:, 0]
        for j in range(k):
            members = dataSet[labels == j]
            # An empty cluster keeps its previous centroid; averaging zero
            # points would turn it into NaN.
            if len(members):
                cent[j] = members.mean(axis=0)
    return ctmp, cent
# Cluster the samples into four groups and draw them in a 3-D scatter plot.
ctmp, cent = kMeans(cdata, k=4)
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
# Base layer: every sample in red; the per-cluster layers are drawn on top.
ax.scatter(cdata[:, 0], cdata[:, 1], cdata[:, 2], c='r')
print(len(cent))

# Marker settings for the members of clusters 0-3; any other cluster index
# only gets its centroid marker.
member_styles = {
    0: {'s': 30, 'marker': '^', 'c': 'red'},
    1: {'c': 'black'},
    2: {'c': 'green'},
    3: {'c': 'blue'},
}
for i, center in enumerate(cent):
    # Row indices of the samples assigned to cluster i.
    sett = [row for row, label in enumerate(ctmp) if label == i]
    ax.scatter(center[0], center[1], center[2], s=50, marker='+')
    style = member_styles.get(i)
    if style is not None:
        ax.scatter(cdata[sett, 0], cdata[sett, 1], cdata[sett, 2], **style)
    print(center)
    print(sett)

# Axis labels.
ax.set_zlabel('Z')
ax.set_ylabel('Y')
ax.set_xlabel('X')
plt.show()