今天我们来讨论一个新的聚类算法-Affinity Propagation,(我喜欢叫它“亲和信息传播算法”,这是我的个人叫法,可能不专业也不准确)。顾名思义,亲和信息传播涉及到两个方面,一个是亲和信息,一个是传播,下边我们就围绕这两个方面来介绍这个算法,并且附上Java代码。
AP(Affinity Propagation)算法是一个基于图的聚类算法,图中的点就是待聚类的数据点,点与点之间的连线表示的是它们之间的关系(关系一般是指点之间的相似度,相似度呢又可以用欧氏距离等等来计算,这是我工作中用到的实际情况。但是在许多的其他情况中点之间的关系并不一定是相似度,也不一定是欧氏距离来计算,要根据具体的情况来决定)。跟其他聚类算法的不同之处是,AP聚类算法在开始时是不需要指定聚类的数目的,而是将图中所有节点都看成潜在的聚类中心,然后通过节点之间的通信(也就是我们上边的说的亲和信息传播),去找出最合适的聚类中心(通过亲和度找出),并将其他节点划分到这些中心下去。所以我们可以认为,AP算法所要做的事情就是去发现这些聚类中心,并且将其他节点划分到这些聚类中心中去。
上边我们也谈到了AP聚类算法的任务是寻找聚类中心,然后再把其他的点分配到相应的聚类中心去,那么问题来了,我们应该怎样去发现聚类中心呢?或者说聚类中心的判断标准是什么呢?要回答这个问题就会用到我们一开始提到的亲和信息。其实在AP聚类算法中,各数据点之间传播着两种信息:
吸引度信息(responsibility):r(i,k)描述了由i点发送到k点的信息,表示了k适合作为i点的聚类中心的程度。
归属度信息(availability):a(i,k)描述了k点发送到i点的信息,表示了i选择k作为聚类中心的合适程度。
由上面的两种信息我们可以看出来,一个点若想成为聚类中心,是需要一个双方认可的:一是该点本身相较于其他的点适合作为聚类中心(吸引度),另外一个是其他点选择该点作为聚类中心比较合适(归属度信息)。所以某个点是否是一个聚类中心是由r(i,k)+a(i,k)的值来决定的。
AP聚类算法的输入是一个节点间的相似度矩阵S,S(i,j)表示节点i和节点j之间的相似度,也表明了,j作为i的聚类中心的合适程度(相似度值取负值)。其中相似度矩阵主对角线上的相似度S(k,k)表示节点k成为聚类中心合适度,我们也成为参考度p,其值越大越能够成为聚类中心。在最开始时,这个点的值是使用者给定的,p的大小影响簇中心的数目,若认为每个数据对象都有可能作为簇中心,那么p就应该取相同的值(此时S对角线的值都为p),当然可以根据不同点成为簇中心的可能性大小,取不同的p值(此时S对角线上的值就会不同)。如果p等于S矩阵中所有元素的均值,那么得到的簇中心数目是中等的;如果取最小相似度,那么得到较少的聚类。
其实吸引度是一个相对的概念,相似度矩阵记录了k成为i的聚类中心的合适程度,那么这里我们只需要证明k比其他节点更合适了就可以了,其他节点是否合适这个如何进行衡量呢?是否合适其实就是看这两个节点是否相互认可,对于其他节点k'我们有s(i,k')表示节点k'作为节点i的聚类中心的合适度,再定义一个a(i,k')表示i对节点k'的认可程度(归属度),这两个值相加,a(i,k') + s(i,k'),就可以计算出节点k'作为节点i的聚类中心的合适程度了,这里,在所有其他节点k'中,找出最大的a(i,k') + s(i,k'),即max{a(i,k’)+s(i,k')},再使用s(i,k) - max{a(i,k’)+s(i,k')} 就可以得出k对i的吸引度了,也就是第一个公式:
r(i,k) = s(i,k) - max{a(i,k’)+s(i,k')} 其中k != k'
接下来就是要计算归属度a(i,k),在这里我们有一个假设:如果节点k作为其他节点i'的聚类中心的合适度很大,那么节点k作为节点i的聚类中心的合适度也可能会较大。由此就可以先计算节点k对其他节点的吸引度,r(i',k),然后做一个累加和表示节点k对其他节点的吸引度。然后再加上r(k,k),这里为什么要加上r(k,k),根据吸引度公式,我们可以看出,其实r(k,k),反应的是节点k有多不适合被划分到其他聚类中心下去,这里的公式中,将k有多适合成为其他节点的聚类中心加上它有多不适合被划分到其他聚类中心下去,我们得到公式:
a(i,k)=min{0,r(k,k)+∑max{0,r(i',k)}} 其中k != k'并且k != i
和
a(k,k)=∑max{0,r(i',k)}
注意:∑max{0,r(i',k)}}在r(i',k)跟0之间取一个大的原因是因为s(i',k)一般会初始化成负值,导致r(i',k)计算出来也有可能是负值,这样的好处是,最后可以方便找出合适的聚类中心在完成所有计算后。
AP算法为了避免震荡(即迭代过程中产生的聚类数目不断发生变化而导致不能收敛的情况),更新信息时引入了衰减系数λ 。每条信息被设置为它前次迭代更新值的 λ 倍加上本次信息更新值的1- λ 倍。其中,衰减系数 λ 是介于0到1之间的实数。即第t+1次 r(i,k) , a(i,k) 的迭代值:
整个AP算法的过程是先迭代R(i,k),利用迭代后的R再迭代A(i,k)。一次迭代包括R和A的迭代,每次迭代后,将R(k,k)+A(k,k)大于0的数据对象k作为簇中心。当迭代次数超过设置阈值时(如1000次)或者当聚类中心连续多少次迭代不发生改变时终止迭代(如50次)。
AP算法的迭代次数和聚类数目主要受到两个参数的影响。其中聚类数目主要受参考度p(该值为负值)的影响,该值越大,聚类数目越多。参数lamd称为阻尼系数,由公式可以看出,该值越小,那么R和A相比上一次迭代的R和A会发生较大的变化,迭代次数会减少。阻尼系数一般取值为(0,1)。
算法步骤:1)先计算数据对象之间的相似度,得到相似矩阵;2)不断迭代R和A,得到簇中心;3)根据簇中心划分数据对象。
优点:1)不需要像k-menas、k-medoids(如PAM、CLABA等)事先给出聚类数目;
缺点:1)AP算法需要事先计算每对数据对象之间的相似度,如果数据对象太多的话,内存放不下,若存在数据库,频繁访问数据库也需要时间。如果计算过程中实时计算相似度,那么计算量就上去了;2)AP算法的时间复杂度较高,一次迭代大概O(N3);3)聚类的好坏受到参考度和阻尼系数的影响。
好了,下边该上代码了:
package com.zc.cluster;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.OutputStreamWriter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
/**
* 利用ap算法将数据进行聚类(csdn)
*
* @author zhouchao
*/
public class ApCluster {
public static String i_file = "E:/csdnworkplace/data/test_30.txt";// 输入数据
public static String clusetr_file = "E:/csdnworkplace/data/rel_30.txt";// 聚类结果文件
// public static String exemplar_file = "";// 聚类中心文件
private static int maxIterNum;// 最大迭代次数(自己设置)
private double lambda;// 衰减系数,主要是起收敛作用的
private int dataNum;// 数据个数的计数器
private static Point[] dataset;// 读入的数据集
private double similar[][];// 相似度矩阵,数据点i和点j的相似度记为s(i, j),是指点j作为点i的聚类中心的相似度
private double r[][];// 吸引信息矩阵,r(i,k)用来描述点k适合作为数据点i的聚类中心的程度
private double a[][];// 归属信息矩阵,a(i,k)用来描述点i选择点k作为其聚类中心的适合程度
private double oldr[][];// 前一次的吸引信息矩阵,r(i,k)用来描述点k适合作为数据点i的聚类中心的程度
private double olda[][];// 前一次的归属信息矩阵,a(i,k)用来描述点i选择点k作为其聚类中心的适合程度
private static List<Integer> exemplar;// 聚类中心
private List<Integer> oldExemplar;// 旧的聚类中心
private int changedCount;// 改变的次数
private int unchangeNum;// 未改变的次数
@SuppressWarnings("static-access")
public ApCluster(int maxIterNum, double lambda) {
this.maxIterNum = maxIterNum;
this.lambda = lambda;
}
public static void main(String[] args) {
// 初始化迭代次数和衰减系数
ApCluster ac = new ApCluster(1000, 0.8);
// 读入数据(输入地址,分隔符,是否含有标签)
ac.readData(i_file, " ", true);
// 初始化Availability、Responsibility和参考度p
ac.init();
// 开始聚类
ac.clustering();
System.out.println("迭代次数:" + maxIterNum);
System.out.println("聚类数目为:" + exemplar.size());
// 保存聚类结果
ac.writeCluster();
}
public void writeCluster() {
File Output = new File(clusetr_file);
if (Output.exists()) {
Output.delete();
}
FileOutputStream fos = null;
OutputStreamWriter osw = null;
BufferedWriter bw = null;
try {
fos = new FileOutputStream(Output);
osw = new OutputStreamWriter(fos, "utf-8");
bw = new BufferedWriter(osw);
Map<String, List<String>> relMap = new HashMap<String, List<String>>();
for (Point p : dataset) {
if (relMap.get(p.predictLabel) != null) {
relMap.get(p.predictLabel).add(p.label);
} else {
List<String> list = new LinkedList<String>();
list.add(p.label);
relMap.put(p.predictLabel, list);
}
}
for (String s : relMap.keySet()) {
bw.write(s);
bw.write("\r\n");
for (String str : relMap.get(s)) {
if (!str.equals(s)) {
bw.write(str);
bw.write("\r\n");
}
}
bw.write("----------");
bw.write("\r\n");
}
} catch (Exception e) {
e.printStackTrace();
} finally {
try {
bw.flush();
osw.flush();
fos.flush();
fos.close();
} catch (IOException e) {
e.printStackTrace();
}
}
}
public void clustering() {
for (int i = 0; i < maxIterNum; i++) {
// 更新吸引度
updateResponsible();
// 更新归属度
updateAvailable();
oldExemplar.clear();
if (!exemplar.isEmpty()) {
for (Integer v : exemplar) {
oldExemplar.add(v);
}
}
exemplar.clear();
changedCount = 0;
// 获取聚类中心
for (int k = 0; k < dataNum; k++) {
if (r[k][k] + a[k][k] > 0) {
exemplar.add(k);
}
}
// 分配聚类中心
assignCluster();
if (changedCount == 0) {
unchangeNum++;
if (unchangeNum > 50) {
maxIterNum = i;
break;
}
} else {
unchangeNum = 0;
}
}
// 生成预测标签
setPredictLabel();
}
// 生成预测标签
private void setPredictLabel() {
Map<Integer, String> labelMap = new HashMap<Integer, String>();
for (int cid : exemplar) {
Map<String, Integer> tempMap = new HashMap<String, Integer>();
for (Point p : dataset) {
if (cid == p.cid) {
if (tempMap.get(p.label) == null) {
tempMap.put(p.label, 1);
} else {
tempMap.put(p.label, tempMap.get(p.label) + 1);
}
}
}
String maxLabel = null;
int temp = 0;
for (Entry<String, Integer> iter : tempMap.entrySet()) {
if (temp < iter.getValue()) {
temp = iter.getValue();
maxLabel = iter.getKey();
}
}
labelMap.put(cid, maxLabel);
}
for (Point p : dataset) {
p.predictLabel = labelMap.get(p.cid);
}
}
// 分配聚类中心
private void assignCluster() {
for (int i = 0; i < dataNum; i++) {
double max = -Double.MAX_VALUE;
int index = 0;
for (Integer k : exemplar) {
if (max < similar[i][k]) {
max = similar[i][k];
index = k;
}
}
if (dataset[i].cid != index) {
dataset[i].cid = index;
changedCount++;
}
}
}
// 更新吸引度
private void updateAvailable() {
for (int i = 0; i < dataNum; i++) {
for (int k = 0; k < dataNum; k++) {
olda[i][k] = a[i][k];
if (i == k) {
double sum = 0;
for (int j = 0; j < dataNum; j++) {
if (j != k) {
if (r[j][k] > 0) {
sum += r[j][k];
}
}
}
a[i][k] = sum;
} else {
double sum = 0;
for (int j = 0; j < dataNum; j++) {
if (j != i && j != k) {
if (r[j][k] > 0) {
sum += r[j][k];
}
}
}
if (r[k][k] + sum > 0) {
a[i][k] = 0;
} else {
a[i][k] = r[k][k] + sum;
}
}
a[i][k] = (1 - lambda) * a[i][k] + lambda * olda[i][k];
}
}
}
// 更新归属度
private void updateResponsible() {
for (int i = 0; i < dataNum; i++) {
for (int k = 0; k < dataNum; k++) {
oldr[i][k] = r[i][k];
double max = -Double.MAX_VALUE;
for (int j = 0; j < dataNum; j++) {
if (j != k) {
if (a[i][j] + similar[i][j] > max) {
max = a[i][j] + similar[i][j];
}
}
}
r[i][k] = similar[i][k] - max;
r[i][k] = (1 - lambda) * r[i][k] + lambda * oldr[i][k];
}
}
}
// 初始化聚类中心
public void init() {
oldExemplar = new ArrayList<Integer>();
exemplar = new ArrayList<Integer>();
similar = new double[dataNum][dataNum];
r = new double[dataNum][dataNum];
a = new double[dataNum][dataNum];
oldr = new double[dataNum][dataNum];
olda = new double[dataNum][dataNum];
// 利用欧氏距离计算每个点之间的相似度
for (int i = 0; i < dataset.length; i++) {
for (int j = i + 1; j < dataset.length; j++) {
similar[i][j] = -distance(dataset[i].dimensioin, dataset[j].dimensioin);
similar[j][i] = similar[i][j];
}
}
setPreference(1);// 设置参考度的取值,根据自己的需要设定
}
/**
* 获取数据点i的参考度<br>
* 称为p(i)或s(i,i) 是指点i作为聚类中心的参考度。一般取s相似度值的中值
*/
private void setPreference(int prefType) {
List<Double> list = new ArrayList<Double>();
// find the median
for (int i = 0; i < dataNum; i++) {
for (int j = i + 1; j < dataNum; j++) {
list.add(similar[i][j]);
}
}
Collections.sort(list);
double pref = 0;
if (prefType == 1) {// 取相似度的中位数作为参考度
if (list.size() % 2 == 0) {
pref = (list.get(list.size() / 2) + list.get(list.size() / 2 - 1)) / 2;
} else {
pref = list.get((list.size()) / 2);
}
} else if (prefType == 2) {// 取相似度的最小值作为参考度
pref = list.get(0);
} else if (prefType == 3) {// 取 0.5 * (min + max)作为参考度
pref = (list.get(list.size() - 1) + list.get(0)) * 0.5;
} else if (prefType == 4) {// 取相似度的最大值作为参考度
pref = list.get(list.size() - 1);
} else if (prefType == 5) {// 取相似度的平均值作为参考度
double temSum = 0;
for (double b : list) {
temSum += b;
}
pref = temSum / list.size();
} else {// 当没有选择合适的参考度设置类型,报错
System.out.println("prefType error");
System.exit(-1);
}
System.out.println(pref);
for (int i = 0; i < dataNum; i++) {
similar[i][i] = pref;
}
}
public double distance(double[] a, double[] b) {
if (a.length != b.length) {
throw new IllegalArgumentException("Arrry a not equal array b!");
}
double sum = 0;
for (int i = 0; i < a.length; i++) {
double dp = a[i] - b[i];
sum += dp * dp;
}
return (double) Math.sqrt(sum);
}
public void readData(String fileName, String split, boolean havelabel) {
List<Point> dataList = new ArrayList<Point>();
File file = new File(fileName);
if (!file.exists()) {
System.out.println("输入文件不存在!");
System.exit(1);
}
FileInputStream fis = null;
InputStreamReader isr = null;
BufferedReader br = null;
try {
fis = new FileInputStream(file);
isr = new InputStreamReader(fis, "utf-8");
br = new BufferedReader(isr);
String line = br.readLine();
String str[] = null;
double[] temp = null;
String label = "";
while (line != null) {
str = line.split(split);
if (havelabel) {
label = str[0];
temp = new double[str.length - 1];
for (int i = 0; i < str.length - 1; i++) {
temp[i] = Double.parseDouble(str[i + 1]);
}
} else {
label = line;
temp = new double[str.length];
for (int i = 0; i < str.length; i++) {
temp[i] = Double.parseDouble(str[i]);
}
}
dataList.add(new Point(temp, label));
dataNum++;
line = br.readLine();
}
} catch (Exception e) {
e.printStackTrace();
} finally {
try {
br.close();
isr.close();
fis.close();
} catch (IOException e) {
e.printStackTrace();
}
}
Collections.shuffle(dataList);// 打乱次序
dataset = new Point[dataList.size()];
dataList.toArray(dataset);
}
static class Point {
// 数据标签
private String label;
// 聚类预测的标签
private String predictLabel;
// 数据点所属簇id
private int cid;
// 数据点的维度
private double dimensioin[];
public Point(double dimensioin[], String label) {
this.label = label;
init(dimensioin);
}
public Point(double dimensioin[]) {
init(dimensioin);
}
public void init(double value[]) {
dimensioin = new double[value.length];
for (int i = 0; i < value.length; i++) {
dimensioin[i] = value[i];
}
}
@Override
public String toString() {
return "Point [label=" + label + ", predictLabel=" + predictLabel + ", cid=" + cid + ", dimensioin="
+ Arrays.toString(dimensioin) + "]";
}
}
}