聚类算法之kmeans算法java版本

      聚类的意思很明确,物以类聚,把类似的事物放在一起。
      聚类算法是web智能中很重要的一步,可运用在社交,新闻,电商等各种应用中,我打算专门开个分类讲解聚类各种算法的java版实现。
     首先介绍kmeans算法。
     kmeans算法的速度很快,性能良好,几乎是应用最广泛的,它需要先指定聚类的个数k,然后根据k值来自动分出k个类别集合。
     举个例子,某某教练在得到全队的数据后,想把这些球员自动分成不同的组别,你得问教练需要分成几个组,他回答你k个,ok可以开始了,在解决这个问题之前有必要详细了解自己需要达到的目的:根据教练给出的k值,呈现出k个组,每个组的队员是相似的。
     首先,我们创建球员类。 

 

package kmeans;

   /**
    * 球员
     * 
    * @author 阿飞哥
    * 
    */
  public class Player {

 private int id;
 private String name;

 private int age;

 /* 得分 */
 @KmeanField
 private double goal;

 /* 助攻 */
 //@KmeanField
 private double assists;

 /* 篮板 */
 //@KmeanField
 private double backboard;

 /* 抢断 */
 //@KmeanField
 private double steals;

 public int getId() {
  return id;
 }

 public void setId(int id) {
  this.id = id;
 }

 public String getName() {
  return name;
 }

 public void setName(String name) {
  this.name = name;
 }

 public int getAge() {
  return age;
 }

 public void setAge(int age) {
  this.age = age;
 }

 public double getGoal() {
  return goal;
 }

 public void setGoal(double goal) {
  this.goal = goal;
 }

 public double getAssists() {
  return assists;
 }

 public void setAssists(double assists) {
  this.assists = assists;
 }

 public double getBackboard() {
  return backboard;
 }

 public void setBackboard(double backboard) {
  this.backboard = backboard;
 }

 public double getSteals() {
  return steals;
 }

 public void setSteals(double steals) {
  this.steals = steals;
 }

 
}

 

        

   @KmeanField这个注解是自定义的,用来标示这个属性是否是算法需要的维度。
代码如下 

package kmeans;

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

/**
 * 在对象的属性上标注此注释,
 * 表示纳入kmeans算法,仅支持数值类属性
 * @author 阿飞哥
 */
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.FIELD)
public @interface KmeanField {
}

 

接下来看看最核心的kmeans算法,具体实现过程如下:
1,初始化k个聚类中心
2,计算出每个对象跟这k个中心的距离(相似度计算,这个下面会提到),假如x这个对象跟y这个中心的距离最小(相似度最大),那么x属于y这个中心。这一步就可以得到初步的k个聚类
3,在第二步得到的每个聚类分别计算出新的聚类中心,和旧的中心比对,假如不相同,则继续第2步,直到新旧两个中心相同,说明聚类不可变,已经成功

实现代码如下:

package kmeans;

import java.lang.annotation.Annotation;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.List;

/**
 * 
 * @author 阿飞哥
 * 
 */
public class Kmeans<T> {

 /**
  * 所有数据列表
  */
 private List<T> players = new ArrayList<T>();

 /**
  * 数据类别
  */
 private Class<T> classT;

 /**
  * 初始化列表
  */
 private List<T> initPlayers;

 /**
  * 需要纳入kmeans算法的属性名称
  */
 private List<String> fieldNames = new ArrayList<String>();

 /**
  * 分类数
  */
 private int k = 1;

 public Kmeans() {

 }

 /**
  * 初始化列表
  * 
  * @param list
  * @param k
  */
 public Kmeans(List<T> list, int k) {
  this.players = list;
  this.k = k;
  T t = list.get(0);
  this.classT = (Class<T>) t.getClass();
  Field[] fields = this.classT.getDeclaredFields();
  for (int i = 0; i < fields.length; i++) {
   Annotation kmeansAnnotation = fields[i]
     .getAnnotation(KmeanField.class);
   if (kmeansAnnotation != null) {
    fieldNames.add(fields[i].getName());
   }

  }

  initPlayers = new ArrayList<T>();
  for (int i = 0; i < k; i++) {
   initPlayers.add(players.get(i));
  }
 }

 public List<T>[] comput() {
  List<T>[] results = new ArrayList[k];

  boolean centerchange = true;
  while (centerchange) {
   centerchange = false;
   for (int i = 0; i < k; i++) {
    results[i] = new ArrayList<T>();
   }
   for (int i = 0; i < players.size(); i++) {
    T p = players.get(i);
    double[] dists = new double[k];
    for (int j = 0; j < initPlayers.size(); j++) {
     T initP = initPlayers.get(j);
     /* 计算距离 */
     double dist = distance(initP, p);
     dists[j] = dist;
    }

    int dist_index = computOrder(dists);
    results[dist_index].add(p);
   }

   for (int i = 0; i < k; i++) {
    T player_new = findNewCenter(results[i]);
    T player_old = initPlayers.get(i);
    if (!IsPlayerEqual(player_new, player_old)) {
     centerchange = true;
     initPlayers.set(i, player_new);
    }

   }

  }

  return results;
 }

 /**
  * 比较是否两个对象是否属性一致
  * 
  * @param p1
  * @param p2
  * @return
  */
 public boolean IsPlayerEqual(T p1, T p2) {
  if (p1 == p2) {
   return true;
  }
  if (p1 == null || p2 == null) {
   return false;
  }

  

  boolean flag = true;
  try {
   for (int i = 0; i < fieldNames.size(); i++) {
    String fieldName=fieldNames.get(i);
    String getName = "get"
      + fieldName.substring(0, 1).toUpperCase()
      + fieldName.substring(1);    
    Object value1 = invokeMethod(p1,getName,null);
    Object value2 = invokeMethod(p2,getName,null);
    if (!value1.equals(value2)) {
     flag = false;
     break;
    }
   }
  } catch (Exception e) {
   e.printStackTrace();
   flag = false;
  }

  return flag;
 }

 /**
  * 得到新聚类中心对象
  * 
  * @param ps
  * @return
  */
 public T findNewCenter(List<T> ps) {
  try {
   T t = classT.newInstance();
   if (ps == null || ps.size() == 0) {
    return t;
   }

   double[] ds = new double[fieldNames.size()];
   for (T vo : ps) {
    for (int i = 0; i < fieldNames.size(); i++) {
     String fieldName=fieldNames.get(i);
     String getName = "get"
       + fieldName.substring(0, 1).toUpperCase()
       + fieldName.substring(1);
     Object obj=invokeMethod(vo,getName,null);
     Double fv=(obj==null?0:Double.parseDouble(obj+""));
     ds[i] += fv;
    }

   }

   for (int i = 0; i < fieldNames.size(); i++) {
    ds[i] = ds[i] / ps.size();
    String fieldName = fieldNames.get(i);
    
    /* 给对象设值 */
    String setName = "set"
      + fieldName.substring(0, 1).toUpperCase()
      + fieldName.substring(1);

    invokeMethod(t,setName,new Class[]{double.class},ds[i]);

   }

   return t;
  } catch (Exception ex) {
   ex.printStackTrace();
  }
  return null;

 }

 /**
  * 得到最短距离,并返回最短距离索引
  * 
  * @param dists
  * @return
  */
 public int computOrder(double[] dists) {
  double min = 0;
  int index = 0;
  for (int i = 0; i < dists.length - 1; i++) {
   double dist0 = dists[i];
   if (i == 0) {
    min = dist0;
    index = 0;
   }
   double dist1 = dists[i + 1];
   if (min > dist1) {
    min = dist1;
    index = i + 1;
   }
  }

  return index;
 }

 /**
  * 计算距离(相似性) 采用欧几里得算法
  * 
  * @param p0
  * @param p1
  * @return
  */
 public double distance(T p0, T p1) {
  double dis = 0;
  try {

   for (int i = 0; i < fieldNames.size(); i++) {
    String fieldName = fieldNames.get(i);
    String getName = "get"
      + fieldName.substring(0, 1).toUpperCase()
      + fieldName.substring(1);
    
    Double field0Value=Double.parseDouble(invokeMethod(p0,getName,null)+"");
    Double field1Value=Double.parseDouble(invokeMethod(p1,getName,null)+"");
    dis += Math.pow(field0Value - field1Value, 2); 
   }
  
  } catch (Exception ex) {
   ex.printStackTrace();
  }
  return Math.sqrt(dis);

 }
 
 /*------公共方法-----*/
 public Object invokeMethod(Object owner, String methodName,Class[] argsClass,
   Object... args) {
  Class ownerClass = owner.getClass();
  try {
   Method method=ownerClass.getDeclaredMethod(methodName,argsClass);
   return method.invoke(owner, args);
  } catch (SecurityException e) {
   e.printStackTrace();
  } catch (NoSuchMethodException e) {
   e.printStackTrace();
  } catch (Exception ex) {
   ex.printStackTrace();
  }

  return null;
 }

}

 

最后咱们测试一下:

package kmeans;

import java.util.ArrayList;
import java.util.List;
import java.util.Random;

public class TestMain {

 public static void main(String[] args) {
       List<Player> listPlayers=new ArrayList<Player>();
        
        for(int i=0;i<15;i++){
         
         Player p1=new Player();
         p1.setName("afei-"+i);
         p1.setAssists(i);
         p1.setBackboard(i);
         
         //p1.setGoal(new Random(100*i).nextDouble());
         p1.setGoal(i*10);
         p1.setSteals(i);
         //listPlayers.add(p1); 
        }
        
        Player p1=new Player();
        p1.setName("afei1");
        p1.setGoal(1);
        p1.setAssists(8);
        listPlayers.add(p1);
       
        Player p2=new Player();
        p2.setName("afei2");
        p2.setGoal(2);
        listPlayers.add(p2);
        
         Player p3=new Player();
        p3.setName("afei3");
        p3.setGoal(3);
        listPlayers.add(p3);
        
         Player p4=new Player();
        p4.setName("afei4");
        p4.setGoal(7);
        listPlayers.add(p4);
        
         Player p5=new Player();
        p5.setName("afei5");
        p5.setGoal(8);
        listPlayers.add(p5);
        
         Player p6=new Player();
        p6.setName("afei6");
        p6.setGoal(25);
        listPlayers.add(p6);
        
         Player p7=new Player();
        p7.setName("afei7");
        p7.setGoal(26);
        listPlayers.add(p7);
        
         Player p8=new Player();
        p8.setName("afei8");
        p8.setGoal(27);
        listPlayers.add(p8);
        
         Player p9=new Player();
        p9.setName("afei9");
        p9.setGoal(28);
        listPlayers.add(p9);
        
        
  Kmeans<Player> kmeans = new Kmeans<Player>(listPlayers,3);
  List<Player>[] results = kmeans.comput();
  for (int i = 0; i < results.length; i++) {
   System.out.println("===========类别" + (i + 1) + "================");
   List<Player> list = results[i];
   for (Player p : list) {
    System.out.println(p.getName() + "--->"
      + p.getGoal() + "," + p.getAssists() + ","
      + p.getSteals() + "," + p.getBackboard());
   }
  }
  
  
  
      
 }

}

 

结果如下

  这个里面涉及到相似度算法,事实证明欧几里得距离算法的实践效果是最优的。
  最后说说kmeans算法的不足:可以看到只能针对数字类型的属性(维),对于其他类型的除非选定合适的数值度量。

 

By 阿飞哥 转载请说明
 


       

 

转载于:https://my.oschina.net/duyunfei/blog/54755

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值