手写 K-Means 的 Python 实现
K-Means 算法流程图
Python 代码实现
kmeans_tool.py
# -*- coding:utf-8 -*-
'''
description: kmeans tool kit
time: 2020/12/10
'''
import random
import math
class Cluster(object):
    """A group of samples together with their center (per-dimension mean).

    Instance attributes set by ``__init__``:
        n_dim: ``n_dim`` of the first sample passed in.
        samples: the samples currently assigned to this cluster.
        center: a ``Sample`` holding the mean coordinate of ``samples``.
    """

    def __init__(self, samples):
        """Create a cluster from a non-empty sequence of samples.

        Raises:
            ValueError: if ``samples`` is empty (a center cannot be
                computed without at least one sample).
        """
        if len(samples) == 0:
            # ValueError is still an Exception, so existing handlers keep working.
            raise ValueError("sample is empty")
        self.n_dim = samples[0].n_dim
        self.samples = samples
        # calculate center sample
        self.center = self.calculate_center()

    def __repr__(self):
        """Return the string form of the list of samples in this cluster."""
        return str(self.samples)

    def update(self, samples):
        """Replace the samples of this cluster and recompute its center.

        If ``samples`` is empty there is nothing to average, so the previous
        center is kept and the reported shift is 0.0. Without this guard the
        recomputed center would have zero dimensions and the distance
        computation against the old center would fail.

        Returns:
            The distance between the old center and the new center.
        """
        old_center = self.center
        self.samples = samples
        if len(samples) == 0:
            return 0.0
        self.center = self.calculate_center()
        return get_distance(old_center, self.center)

    def calculate_center(self):
        """Return a ``Sample`` whose coordinates are the per-dimension mean
        of the coordinates of all samples in this cluster.
        """
        coords = [sample.coords for sample in self.samples]
        # zip(*coords) groups the i-th coordinate of every sample together.
        return Sample([sum(item) / len(item) for item in zip(*coords)])
class Sample(object):
    """A single data point described by a sequence of coordinates."""

    def __init__(self, coords):
        """Keep a reference to ``coords`` and record its length as ``n_dim``."""
        self.n_dim = len(coords)
        self.coords = coords

    def __repr__(self):
        """Render the sample as the string form of its coordinates."""
        text = str(self.coords)
        return text
def get_distance(a, b):
    """Return the Euclidean distance between two samples.

    Args:
        a: object exposing ``coords`` (a sequence of numbers) and ``n_dim``.
        b: object exposing ``coords`` and ``n_dim``.

    Returns:
        The square root of the sum of squared coordinate differences.

    Raises:
        ValueError: if the two samples do not have the same ``n_dim``.
            Previously a mismatch either raised an IndexError or silently
            ignored the extra coordinates, depending on argument order.
    """
    if a.n_dim != b.n_dim:
        raise ValueError(
            "dimension mismatch: {} vs {}".format(a.n_dim, b.n_dim)
        )
    # Accumulate under a name that does not shadow the builtin ``sum``.
    squared_total = 0.0
    for x, y in zip(a.coords, b.coords):
        squared_total += pow(x - y, 2)
    return math.sqrt(squared_total)
def generate_sample(n_dim, low, high):
    """Build a random ``Sample`` with ``n_dim`` coordinates.

    Each coordinate is drawn independently with ``random.uniform(low, high)``.
    """
    coords = []
    for _ in range(n_dim):
        coords.append(random.uniform(low, high))
    return Sample(coords)
kmean_process.py
# -*- coding:utf-8 -*-
'''
description: kmeans process
time: 2020/12/10
'''
import random
from kmeans_tool import Cluster, Sample, get_distance, generate_sample
import matplotlib.pyplot as plt
def kmeans(samples, k, cut_off):
    """Group ``samples`` into ``k`` clusters with the K-Means iteration.

    Starting from ``k`` randomly chosen samples as initial centers, every
    sample is assigned to its nearest center and the centers are then
    recomputed. This repeats until the largest center movement in one
    round is smaller than ``cut_off``.

    Args:
        samples: sequence of samples to cluster.
        k: number of clusters; passed to ``random.sample``, which rejects
            a value larger than ``len(samples)`` or a negative one.
        cut_off: positive convergence threshold on the center shift.

    Returns:
        The list of ``k`` clusters.

    Raises:
        ValueError: if ``cut_off`` is not positive. The shift is never
            negative, so such a threshold could never be satisfied and the
            loop would not terminate.
    """
    if cut_off <= 0:
        raise ValueError("cut_off must be positive")
    # randomly select k samples as the initial cluster centers
    init_samples = random.sample(samples, k)
    clusters = [Cluster([sample]) for sample in init_samples]
    loop = 0
    while True:
        loop += 1
        # one bucket of assigned samples per cluster for this round
        assigned = [[] for _ in range(k)]
        for sample in samples:
            # find the cluster whose center is closest to this sample;
            # the strict comparison keeps the first cluster on ties
            min_distance = float('inf')
            cluster_index = 0
            for i, cluster in enumerate(clusters):
                distance = get_distance(sample, cluster.center)
                if distance < min_distance:
                    min_distance = distance
                    cluster_index = i
            assigned[cluster_index].append(sample)
        # calculate the biggest center shift of this round
        biggest_shift = 0.0
        for cluster, members in zip(clusters, assigned):
            if not members:
                # A cluster can end up with no samples (for example when two
                # initial centers have identical coordinates). Averaging an
                # empty list would break the center update, so keep the old
                # center, record the empty membership and count no shift.
                cluster.samples = members
                continue
            biggest_shift = max(biggest_shift, cluster.update(members))
        if biggest_shift < cut_off:
            print("{} iteration ".format(loop))
            break
    return clusters
def main():
    """
    Demo entry point: generate random samples, cluster them with kmeans
    and print every sample together with the index of its cluster.
    """
    sample_count = 1000      # number of samples to generate
    dimensions = 2           # coordinates per sample
    value_range = (0, 200)   # bounds passed to generate_sample
    cluster_count = 5        # number of clusters
    threshold = 0.2          # convergence threshold for kmeans

    points = []
    for _ in range(sample_count):
        points.append(generate_sample(dimensions, *value_range))

    result = kmeans(points, cluster_count, threshold)

    # enumerate yields (index, cluster) pairs for the report below
    for index, cluster in enumerate(result):
        for member in cluster.samples:
            print('cluster--{},samples--{}'.format(index, member))
# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()