import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import imageio
def Kmeans(center, x=None, df=None):
    """Run one k-means iteration: assign rows to the nearest centre, then
    recompute every centre as the mean of the rows assigned to it.

    Side effects: writes one Euclidean-distance column per centre
    ('类0', '类1', ...) and the cluster-label column 'custer' into ``df``;
    ``plot_image`` reads the 'custer' column afterwards.

    Args:
        center: array-like of shape (k, n_features), one row per centre.
        x: feature DataFrame (rows = samples, columns = features). Defaults
            to the module-level ``x_train`` created in the ``__main__`` block.
        df: DataFrame that receives the distance/label columns; it must share
            ``x``'s index and contain ``x``'s columns. Defaults to the
            module-level ``data``.

    Returns:
        Float ndarray of shape (k, n_features); row i is the new centre of
        cluster '类{i}'. A centre with no assigned rows keeps its previous
        position, so the shape never shrinks between iterations.
    """
    if x is None:
        x = x_train
    if df is None:
        df = data
    center = np.asarray(center, dtype=float)
    label_cols = ['类{}'.format(i) for i in range(len(center))]

    # Euclidean distance from every row to each centre.
    for col, c in zip(label_cols, center):
        df[col] = np.sqrt(((x - c) ** 2).sum(axis=1))
    df['custer'] = df.loc[:, label_cols].idxmin(axis=1)

    feature_cols = list(x.columns)
    means = df.loc[:, feature_cols + ['custer']].groupby(by='custer').mean()

    # Look centres up by label rather than trusting groupby's row order:
    # an empty cluster would otherwise drop a row (breaking the next call),
    # and lexicographic key sorting would put '类10' before '类2' for k > 10.
    new_center = center.copy()
    for i, col in enumerate(label_cols):
        if col in means.index:
            new_center[i] = means.loc[col, feature_cols].to_numpy()
    return new_center
def plot_image(times, data, labels=('类0', '类1', '类2')):
    """Scatter-plot the current cluster assignment and save it as a PNG.

    Each cluster is drawn with its own ``plt.scatter`` call, so clusters get
    successive colours from matplotlib's colour cycle in ``labels`` order.

    Args:
        times: iteration number; used only to build the output file name.
        data: DataFrame holding the columns '平均消费周期(天)' and
            '平均每次消费金额' plus the 'custer' label column.
        labels: cluster labels to draw, in colour order. Defaults to the
            three clusters the script uses.

    Returns:
        The file name the image was saved under.
    """
    fig = plt.figure()
    try:
        for label in labels:
            members = data['custer'] == label
            plt.scatter(data.loc[members, '平均消费周期(天)'],
                        data.loc[members, '平均每次消费金额'])
        image_name = '第{}次聚类结果.png'.format(times)
        fig.savefig(image_name)
    finally:
        # Release only the figure created here, even if plotting/saving fails.
        plt.close(fig)
    return image_name
if __name__ == '__main__':
    # The CSV is read as GBK-encoded text.
    data = pd.read_csv('company.csv', encoding='gbk')
    image_list = []
    # Kmeans() reads `x_train` and `data` as module-level globals, so both
    # must stay defined at module scope here.
    x_train = data.loc[:, ['平均消费周期(天)', '平均每次消费金额']]
    # Hand-picked initial centres, one row per cluster.
    center = np.array([[10, 100], [20, 200], [30, 300]])
    max_iter = 100  # safety cap in case the centres never settle exactly

    new_center = Kmeans(center)
    print(data)
    times = 1
    image_list.append(plot_image(times, data))

    # Iterate until the centres stop moving; array_equal also copes with
    # mismatched shapes instead of raising/warning like `(a == b).all()`.
    while not np.array_equal(center, new_center):
        if times >= max_iter:
            print('k-means did not converge within {} iterations'.format(max_iter))
            break
        times += 1
        center = new_center.copy()
        new_center = Kmeans(center)
        image_list.append(plot_image(times, data))
    print(times)

    # NOTE(review): imageio v3 deprecates top-level `imageio.imread` in
    # favour of `imageio.v2.imread`; confirm against the installed version.
    frame_list = [imageio.imread(image_name) for image_name in image_list]
    # NOTE(review): newer imageio releases write GIFs through Pillow, which
    # may interpret `duration` in milliseconds rather than seconds — verify
    # the resulting frame timing with the installed imageio version.
    duration = 0.7
    imageio.mimsave('聚类结果.gif', frame_list, 'GIF', duration=duration)
