【大数据分析常用算法】6.K均值

简介

1、K-均值距离函数

1.1、欧式距离

欧式距离的计算公式 $$ d(x,y) = \sqrt{(x_1 - y_1)^2 + (x_2 - y_2)^2 + ... + (x_n - y_n)^2} $$

其中,x,y分别代表两个点,同时,两个点具有相同的维度:n。$x_1,x_2,...,x_n$代表点x的每个维度的值,$y_1,y_2,...,y_n$代表点y的各个维度的值。

1.2、欧氏距离的性质

假设有$p_1,p_2,p_{k}$3个点。

  • $d(p_1,p_2) \ge 0$

  • $d(p_i,p_i) = 0$

  • $d(p_i,p_j) = d(p_j,p_i)$

  • $d(p_i,p_j) \le d(p_i,p_k) + d(p_k,p_j)$

最后一个性质也说明了一个很常见的现象:两点间的距离,线段最短。

1.3、源码实现

import java.util.List;
/**
 * 欧式距离计算
 */
public class EuclideanDistance {
    public static double caculate(List<Double> p1, List<Double> p2){
        double sum = 0.0;
        int length = p1.size();
        for (int i = 0; i < length; i++) {
            sum += Math.pow(p1.get(i) - p2.get(i),2.0);
        }
        return Math.sqrt(sum);
    }
}

2、形式化描述

K-均值算法是一个完成聚类分析的简单学习算法。K-均值聚类算法的目标是找出n项的最佳划分,也就是将n个对象划分到K个组中,是的一个组中的成员语气相应的质心(表示这个组)之间的总距离最小。采用形式化表示,目标就是将n项划分到K个集合$$ {S_i,i=1,2,...,K} $$ 中,使得簇内平方和或组内平方和(within-cluster sum of squares,WCSS)最小,WCSS定义为 $$ \min \sum_{j=1}^k \sum_{i=1}^n ||x_{i}^j - c_j|| $$

这里的$||x_i^j - c_j||$表示实体点质心之间的距离。

3、MapReduce实现

3.1、数据集

如下所示,我们选用的二位数据集。

1.0,2.0
1.0,3.0
1.0,4.0
2.0,5.0
2.0,3.0
2.0,7.0
2.0,8.0
3.0,100.0
3.0,101.0
3.0,102.0
3.0,103.0
3.0,104.0

3.2、Mapper

package mapreduce;

import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Mapper;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

public class KMeansMapper extends Mapper<LongWritable, Text, IntWritable, Text> {

    private List<List<Double>> centers = null;

    // K
    private int k = 0;

    /**
     * map 开始时调用一次。
     * @param context
     * @throws IOException
     * @throws InterruptedException
     */
    @Override
    protected void setup(Context context) throws IOException, InterruptedException {
        // config
        String centerPath = context.getConfiguration().get("centerPath");
        // 读取质心点信息
        this.centers = KMeansUtil.getCenterFromFileSystem(centerPath);
        // 获取K值(中心点个数)
        k = centers.size();
        System.out.println("当前的质心数据为:" + centers);
    }

    /**
     * 1.每次读取一条要分类的条记录与中心做对比,归类到对应的中心
     * 2.以中心ID为key,中心包含的记录为value输出(例如: 1 0.2---->1为聚类中心的ID,0.2为靠近聚类中心的某个值)
     */
    @Override
    protected void map(LongWritable key, Text value, Context context) throws IOException, InterruptedException {
        // 读取一行数据
        List<Double> fields = KMeansUtil.textToList(value);
        // 点维度
        int dimension = fields.size();

        double minDistance = Double.MAX_VALUE;

        int centerIndex = 0;

        // 依次取出K个中心点与当前读取的记录做计算
        for (int i = 0; i < k; i++) {
            double currentDistance  = 0.0;
            // 之所以跳过0,是因为1代表的是该点的ID,不纳入计算的范畴
            for (int j = 1; j < dimension; j++) {
                // 获取中心点
                double centerPoint = Math.abs(centers.get(i).get(j));
                // 当前需要计算的点
                double field = Math.abs(fields.get(j));
                // 计算欧氏距离
                currentDistance += Math.pow((centerPoint - field) / (centerPoint + field), 2);
            }

            // 找出距离该记录最近的中心点的ID,记录最小值、该点的索引
            if(currentDistance < minDistance){
                minDistance = currentDistance;
                centerIndex = i;
            }
        }

        // 以中心点为key,原样输出,这样以该中心点为key的点都会作为一个簇在reducer端汇聚
        context.write(new IntWritable(centerIndex),value);
    }
}

3.3、Reuder

package mapreduce;

import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.NullWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Reducer;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
 * 利用reduce归并功能以中心为key将记录归并在一起
 */
public class KMeansReducer extends Reducer<IntWritable, Text, NullWritable, Text>{
    /**
     * 1.K-V: Key为聚类中心的ID;value为该中心的记录集合;
     * 2.计数所有记录元素的平均值,求出新的中心;KMeans算法的最终结果选取的质心点一般不是原数据集中的点
     */
    @Override
    protected void reduce(IntWritable key, Iterable<Text> values, Context context) throws IOException, InterruptedException {
        List<List<Double>> result = new ArrayList<List<Double>>();
        // 依次读取记录集,每行转化为一个List<Double>
        for (Text value : values) {
            result.add(KMeansUtil.textToList(value));
        }

        // 计算新的质心点:通过各个维的平均值
        int dimension = result.get(0).size();
        double[] averages = new double[dimension];

        for (int i = 0; i < dimension; i++) {
            double sum = 0.0;
            int size = result.size();

            for (int j = 0; j < size; j++) {
                sum += result.get(j).get(i);
            }

            averages[i] = sum / size;
        }
        context.write(NullWritable.get(),new Text(Arrays.toString(averages).replace("[","").replace("]","")));
    }
}

3.4、Driver

package mapreduce;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.NullWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;

import java.io.IOException;
import java.util.List;

public class KMeansDriver {

    public static void main(String[] args) throws Exception{
        String dfs = "hdfs://192.168.35.128:9000";
        // 存放中心点坐标值
        String centerPath = dfs + "/kmeans/center/";
        // 存放待处理数据
        String dataPath = dfs + "/kmeans/kmeans_input_file.txt";
        // 新中心点存放目录
        String newCenterPath = dfs + "/kmeans/newCenter/";
        // delta
        double delta = 0.1D;

        int count = 0;

        final int K = 3;

        // 选取初始的K个质心点
        List<List<Double>> pick = KMeansUtil.pick(K, dfs + "/kmeans/kmeans_input_file.txt");

        // 存储到结果集
        KMeansUtil.writeCurrentKClusterToCenter(centerPath + "center.data",pick);

        while(true){
            ++ count;
            System.out.println(" 第 " + count + " 次计算 ");
            run(dataPath, centerPath, newCenterPath);
            System.out.println("计算迭代变化值");
            // 比较新旧质点变化幅度
            if(KMeansUtil.compareCenters(centerPath, newCenterPath,delta)){
                System.out.println("迭代结束");
                break;
            }
        }
        /**
         * 第 1 次计算
         * 当前的质心数据为:[[1.0, 1.0], [1.0, 2.0], [1.0, 3.0]]
         * task running status is : 1
         * 计算迭代变化值
         * 当前的质心点迭代变化值: 2125.9917355371904
         *  第 2 次计算
         * 当前的质心数据为:[[1.0, 1.0], [1.0, 2.0], [2.272727272727273, 49.09090909090909]]
         * task running status is : 1
         * 计算迭代变化值
         * 当前的质心点迭代变化值: 2806.839601956485
         *  第 3 次计算
         * 当前的质心数据为:[[1.0, 1.0], [1.5714285714285714, 4.571428571428571], [3.0, 102.0]]
         * task running status is : 1
         * 计算迭代变化值
         * 当前的质心点迭代变化值: 0.44274376417233585
         *  第 4 次计算
         * 当前的质心数据为:[[1.0, 1.5], [1.6666666666666667, 5.0], [3.0, 102.0]]
         * task running status is : 1
         * 计算迭代变化值
         * 当前的质心点迭代变化值: 0.0
         * 迭代结束
         */
    }

    public static void run(String dataPath, String centerPath, String newCenterPath) throws IOException, ClassNotFoundException, InterruptedException {
        Configuration configuration = new Configuration();
        configuration.set("centerPath", centerPath);

        Job job = Job.getInstance(configuration);

        job.setJarByClass(KMeansDriver.class);
        job.setMapperClass(KMeansMapper.class);
        job.setMapOutputKeyClass(IntWritable.class);
        job.setMapOutputValueClass(Text.class);

        job.setReducerClass(KMeansReducer.class);
        job.setOutputKeyClass(NullWritable.class);
        job.setOutputValueClass(Text.class);


        FileInputFormat.setInputPaths(job,new Path(dataPath));
        FileOutputFormat.setOutputPath(job,new Path(newCenterPath) );
        System.out.println("task running status is : " + (job.waitForCompletion(true)? 1:0));
    }
}

我们还可以写一个Combiner优化网络传输的流量,不过此处由于测试的缘故,就不写不是本章节主题的代码了。

另外,这几个类还使用了一个辅助工具类

package mapreduce;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.*;
import org.apache.hadoop.io.IOUtils;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.util.LineReader;

import java.io.IOException;
import java.net.URI;
import java.net.URISyntaxException;
import java.util.ArrayList;
import java.util.List;

/**
 * KMeans工具
 */
public class KMeansUtil {

    public static FileSystem getFileSystem() throws URISyntaxException, IOException, InterruptedException {
        // 获取一个具体的文件系统对象
        return FileSystem.get(new URI("hdfs://192.168.35.128:9000"),new Configuration(),"root");
    }


    /**
     * 在数据集中选取前K个点作为质心
     * @param k
     * @param filePath
     * @return
     */
    public static List<List<Double>> pick(int k, String filePath) throws Exception {
        List<List<Double>> result = new ArrayList<List<Double>>();
        Path path = new Path(filePath);
        FileSystem fileSystem = getFileSystem();
        FSDataInputStream open = fileSystem.open(path);
        LineReader lineReader = new LineReader(open);
        Text line = new Text();
        // 读取每一行信息
        while(lineReader.readLine(line) > 0 && k > 0){
            List<Double> doubles = textToList(line);
            result.add(doubles);
            k = k - 1;
        }
        lineReader.close();
        return result;
    }

    /**
     * 将当前的结果写入数据中心
     */
    public static void writeCurrentKClusterToCenter(String centerPath,List<List<Double>> data) throws Exception {
        FSDataOutputStream out = getFileSystem().create(new Path(centerPath));

        for (List<Double> d : data) {
            String str = d.toString();
            out.write(str.replace("[","").replace("]","\n").getBytes());
        }
        out.close();
    }


    /**
     * 从数据中心获取质心点数据
     * @param filePath 路径
     * @return 质心数据
     */
    public static List<List<Double>> getCenterFromFileSystem(String filePath) throws IOException {

        List<List<Double>> result = new ArrayList<List<Double>>();

        Path path = new Path(filePath);
        Configuration configuration = new Configuration();
        FileSystem fileSystem = null;
        try {
            fileSystem = getFileSystem();
        } catch (Exception e) {
            e.printStackTrace();
        }

        FileStatus[] listFiles = fileSystem.listStatus(path);
        for (FileStatus file : listFiles) {
            FSDataInputStream open = fileSystem.open(file.getPath());
            LineReader lineReader = new LineReader(open, configuration);
            Text line = new Text();
            // 读取每一行信息
            while(lineReader.readLine(line) > 0){
                List<Double> doubles = textToList(line);
                result.add(doubles);
            }
        }
        return result;
    }

    /**
     * 将Text转化为数组
     * @param text
     * @return
     */
    public static List<Double> textToList(Text text){
        List<Double> list = new ArrayList<Double>();

        String[] split = text.toString().split(",");

        for (int i = 0; i < split.length; i++) {
            list.add(Double.parseDouble(split[i]));
        }
        return list;
    }

    /**
     * 比较新旧数据点的变化情况
     * @return
     * @throws Exception
     */
    public static boolean compareCenters(String center, String newCenter, double delta) throws Exception{
        List<List<Double>> oldCenters = getCenterFromFileSystem(center);
        List<List<Double>> newCenters = getCenterFromFileSystem(newCenter);

        // 质心点数
        int size = oldCenters.size();
        // 维度
        int fieldSize = oldCenters.get(0).size();

        double distance = 0.0;

        for (int i = 0; i < size; i++) {
            for (int j = 0; j < fieldSize; j++) {
                double p1 = Math.abs(oldCenters.get(i).get(j));
                double p2 = Math.abs(newCenters.get(i).get(j));
                // this is used euclidean distance.
                distance += Math.pow(p1 - p2, 2);
            }
        }

        System.out.println("当前的质心点迭代变化值: " + distance);
        // 在区间内
        if(distance <= delta){
            return true;
        }else{
            Path centerPath = new Path(center);
            Path newCenterPath = new Path(newCenter);
            FileSystem fs = getFileSystem();

            // 删除当前质点文件
            fs.delete(centerPath,true );

            // 将新质点文件结果移动到当前质点文件
            fs.rename(newCenterPath,centerPath);
        }
        return false;
    }
}

可以看到,我们的K=3,并且选择的是数据集中的前三个点作为初始迭代的质心点。当然,更好的算法应该是从数据集中随机选取3个点或者以贴合业务的选取方式选取初始点,从算法中我们可以了解到,初始点的选择在一定迭代次数内是对结果有很大的影响的。

3.5、绘图

最终,我们得到的结果如下,其中的红点即为质心点

转载于:https://my.oschina.net/u/3091870/blog/3023599

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值