一,介绍
算法主要步骤包括:初始化原型向量;迭代优化,更新原型向量。
流程如下:
具体来说,主要是:
1、对原型向量初始化,可以选择满足yj=tj,j∈{1,2,…,m}yj=tj,j∈{1,2,…,m}条件的某个样本 xj=(xj1,xj2,…,xjn)xj=(xj1,xj2,…,xjn)作为 qjqj的初始值;
2、从数据集DD 中任意选择一个样本 xjxj,找到与此样本距离最近的原型向量,假设为qiqi ;
3、如果xjxj的标记yjyj 与qiqi的标记titi相等,即 yj=ti,yj=ti,则令:
否则:
4、更新原型向量:
5、判断是否达到最大迭代次数或者原型向量更新幅度小于某个阈值。如果是,则停止迭代,输出原型向量;否则,转至步骤2。
其中步骤3和4的物理意义是:如果xjxj和最近的原型向量qiqi具有同样的类别标记,则令 qiqi向 xjxj的方向靠拢,且:
否则,qiqi 远离 xjxj,且
二,代码实现
import matplotlib.pyplot as plt import numpy as np import math import random def loadDataSet(filename): fr = open(filename) numberOfLines = len(fr.readlines()) returnMat = np.zeros((numberOfLines, 2)) classLabelVector = [] fr = open(filename) index = 0 for line in fr.readlines(): line = line.strip().split(',') returnMat[index, :] = line[0:2] classLabelVector.append(line[-1]) index += 1 return returnMat, classLabelVector # 欧几里得距离 def edistance(v1, v2): result=0.0 for i in range(len(v1)): result +=(v1[i]-v2[i])**2 return math.sqrt(result) # 学习向量量化算法 def lvq(dataMat, labelMat,alpha=0.1,times=500): classify = set(labelMat) randinfo = [random.randint(0,14),random.randint(15,30)] clusters = [dataMat[randinfo[i]] for i in range(len(randinfo))] # 随机选取k个值作为聚类中心 while times > 0: # 迭代次数 n = random.randint(0,29) d=np.array([edistance(clusters[i], dataMat[n]) for i in range(len(clusters))],dtype='float') # 获取和各个聚类中心距离 index = np.argmin(d) if(labelMat[n]==labelMat[randinfo[index]]): # 同类靠近 clusters[index]=clusters[index]+alpha*(dataMat[n]-clusters[index]) print("同类:",alpha*(dataMat[n]-clusters[index])) else: # 异类远离 clusters[index] = clusters[index] - alpha * (dataMat[n] - clusters[index]) print("异类:", alpha * (dataMat[n] - clusters[index])) times-=1 print("中心点:%s",(clusters)) return clusters def plot(dataMat, labelMat,clusters): xcord = [];ycord = [] sumx1 = 0.0;sumy1 = 0.0;sumx2 = 0.0;sumy2 = 0.0 midx = [];midy = [] for i in range(len(dataMat)): xcord.append(float(dataMat[i][0]));ycord.append(float(dataMat[i][1])) for i in range(len(labelMat)): if(labelMat[i]=="1"): plt.scatter(xcord[i], ycord[i], color='red') else: plt.scatter(xcord[i], ycord[i], color='black') for c in clusters: plt.scatter(c[0], c[1], marker='+', color='blue') for j in range(len(labelMat)): if (labelMat[j] == "1"): sumx1+=xcord[j] sumy1+=ycord[j] else: sumx2 += xcord[j] sumy2 += ycord[j] midx.append(sumx1 / 17) midx.append(sumx2 / 17) midy.append(sumy1 / 13) midy.append(sumy2 / 13) plt.scatter(midx, midy, marker='*',color='green') plt.show() if __name__=='__main__': dataMat, labelMat = loadDataSet('watermelon4.1.txt') clusters = lvq(dataMat, labelMat) plot(dataMat, labelMat,clusters)
结果如下: