读源码学算法之MeanShift Clustring

最近学了一个新的算法:MeahShift,中文翻译为均值漂移,可用于目标跟踪,聚类,图像分割等任务。本文主要介绍如何通过MeahShift实现聚类。说到聚类,大家第一时间会想到Kmeans,但是它需要我们提前给定K,然后不停的迭代直至收敛到K个中心,然而很多时候我们并不知道K取多少比较合适。那么,基于MeanShift的聚类算法就不需要提前给定类别的数目,它可以自动决定有多少类。

一、MeanShift原理

我们以二维为例,如下图所示,给定二维空间的 n n n个点,随机的从任意点 p i p_i pi开始,以 r r r为半径做一个圆,将圆内的每个点 p j p_j pj p i p_i pi连接,每个连接构成一个向量 ( p i , p j ) (p_i,p_j) (pi,pj). 计算所有这些向量的均值,得到一个漂移量 m m m:
m = 1 n ∑ j = 1 n ( p j − p i ) m = \frac{1}{n}\sum_{j=1}^{n}(p_j-p_i) m=n1j=1n(pjpi)
其中, n n n表示落到以 p i p_i pi为圆心, r r r为半径的圆内的点数。以中学力学的观点看,将每个向量 ( p j − p i ) (p_j-p_i) (pjpi)看作一个力的话,那么 m m m可以认为是这些力的合力。然后,再将 p i p_i pi沿着合力方向平移到新的位置: p i 1 = p i + m p_i^1 = p_i + m pi1=pi+m。显然,合力将 p i p_i pi拉向更稠密的位置。实际上, p i 1 p_i^1 pi1也就是落在圆内那些点(除了 p i p_i pi)的几何中心。因为:
p i 1 = p i + m = p i + 1 n ∑ j = 1 n ( p j − p i ) = 1 n ∑ j = 1 n p j p_i^1 = p_i + m = p_i + \frac{1}{n}\sum_{j=1}^{n}(p_j-p_i) = \frac{1}{n}\sum_{j=1}^{n}p_j pi1=pi+m=pi+n1j=1n(pjpi)=n1j=1npj
以此类推,我们从新的中心点 p i 1 p_i^1 pi1继续出发到到 p i 2 p_i^2 pi2,到达 p i 3 p_i^3 pi3,最后从 p i 3 p_i^3 pi3继续算的时候,发现下一个中心跟 p i 3 p_i^3 pi3很接近了,那么就可以认为这一轮shift收敛了。如下图(4)所示,我们发现 p i p_i pi最终shift到 p i 3 p_i^3 pi3,到达了这些点最密集的区域。
在这里插入图片描述

二、MeanShift 聚类

那么,MeanShift如何与聚类关联起来呢?假设上图中从所有点出发进行MeanShift ,都到达 p i 3 p_i^3 pi3或者非常接近 p i 3 p_i^3 pi3我们可以认为这些都点都属于同一类。比如是下面这样的点集,很显然应该是两个类。我们从左边的任意点出发,meanshift会收敛到左侧某个点,而从右侧的任意点出发,meanshift会收敛到右侧的某个点,那么这两个点就是两个聚类中心。我们记录每个原始数据点最终收敛的位置,根据收敛结果对原始数据点进行分类。比如: p 1 , p 3 , p 5 p_1,p_3,p_5 p1,p3,p5最终收敛到 x x x, p 2 , p 4 , p 6 p_2,p_4,p_6 p2,p4,p6最终收敛到 y y y, 那么我们认为 p 1 , p 3 , p 5 p_1,p_3,p_5 p1,p3,p5属于以 x x x为中心的一类, p 2 , p 4 , p 6 p_2,p_4,p_6 p2,p4,p6属于以 y y y为中心的另一类。
在这里插入图片描述

三、 Talk is cheap,Show me the code

代码来自于:https://github.com/mattnedrich/MeanShift_cpp
代码稍微有改动,下面我写了详细的comments, 因此就不再解释了。PS:这份代码还有很大改进空间以降低时间复杂度。

MeanShift.h

struct Cluster {
	vector<double> mode;  //聚类中心坐标(x,y)
    vector<vector<double> > original_points; //当前类包含的原始顶点
    vector<vector<double> > shifted_points;  //当前类包含的shifted顶点
};
typedef vector<double> Point;

class MeanShift {
public:
	//对所有点进行meanshift
    vector<Point> meanshift(const vector<Point> & points, double radius,double EPSILON = 0.00001);
	//根据meanshift结果聚类
    vector<Cluster> cluster(const vector<Point> &, double);
private:
	//对单个点进行meanshift
    void shift_single_point(const Point&, const vector<Point> &, double, Point&);
    vector<Cluster> mean_shift_cluster(const vector<Point> &, const vector<Point> &);
};

MeanShift.cpp

#include <math.h>
#include "MeanShift.h"
using namespace std;
#define CLUSTER_EPSILON 0.5

//两点欧式距离
double euclidean_distance(const vector<double> &point_a, const vector<double> &point_b){
	double total = (point_a[0] - point_b[0])*(point_a[0] - point_b[0]);
	total += (point_a[1] - point_b[1])*(point_a[1] - point_b[1]);
    return sqrt(total);
}
//两点欧式距离平方
double euclidean_distance_sqr(const vector<double> &point_a, const vector<double> &point_b){
    double total = euclidean_distance(point_a, point_b);
	return total*total;
}
//高斯核函数,
double gaussian_kernel(double distance, double radius){
    double temp =  exp(-1.0/2.0 * (distance*distance) / (radius*radius));
    return temp;
}

//顶点point,meanshift一步的结果,这里略有不同,不是只根据半径为radius的圆内那些顶点计算新的中心。
//这里遍历了所有顶点,每个顶点根据到point的距离给了一个权值,如上所示的gaussian_kernel,可以认为离
//point较远的顶点几乎不产生影响,近一些的顶点影响较大。shifted_point = (\sum w_i*p_i)/(\sum_w_i)
void MeanShift::shift_single_point(const Point &point,const vector<Point> &points,double radius,Point &shifted_point) {
    shifted_point = vector<double>(2,0);
    double total_weight = 0;

    for(int i=0; i<points.size(); i++){
        const Point& temp_point = points[i];
        double distance = euclidean_distance(point, temp_point);
        double weight = gaussian_kernel(distance, radius);
        
        shifted_point[0] += temp_point[0] * weight;
		shifted_point[1] += temp_point[1] * weight;
        
        total_weight += weight;
    }
    shifted_point[0] /= total_weight;
	shifted_point[1] /= total_weight;
}

vector<Point> MeanShift::meanshift(const vector<Point> &points, double radius,double EPSILON){
    const double EPSILON_SQR = EPSILON*EPSILON;
    vector<bool> stop_moving(points.size(), false);
    vector<Point> shifted_points = points;
    double max_shift_distance;
    Point point_new;

	//多所有顶点进行多轮meanshift,每一轮对所有顶点只是meanshift一个step
    do {
        max_shift_distance = 0;
		//每次将只所有顶点meanshfit一步
        for(int i=0; i<points.size(); i++){
            if (!stop_moving[i]) {
                shift_single_point(shifted_points[i], points, radius, point_new);

				//meanshift距离,并记录这一轮(所有顶点移动)的最大距离
                double shift_distance_sqr = euclidean_distance_sqr(point_new, shifted_points[i]);
                if(shift_distance_sqr > max_shift_distance)
                    max_shift_distance = shift_distance_sqr;

				//如果meanshift距离太少,认为已经收敛,不必再移动
                if(shift_distance_sqr <= EPSILON_SQR) 
                    stop_moving[i] = true;

				//记录第i个顶点meanshift后的位置
                shifted_points[i] = point_new;
            }
        }
        printf("max_shift_distance: %f\n", sqrt(max_shift_distance));

	//如果这一轮所有顶点的最大移动距离已经很小,则全部收敛,停止循环
    } while (max_shift_distance > EPSILON_SQR); 
    return shifted_points;
}

vector<Cluster> MeanShift::mean_shift_cluster(const vector<Point> &points,const vector<Point> &shifted_points){
    vector<Cluster> clusters;
	//遍历所有meanshift后的顶点位置(聚类中心)
    for (int i = 0; i < shifted_points.size(); i++) {
		//计算当前聚类中心shifted_points[i]跟已确定的聚类中心clusters[c].mode的距离
		//如果shifted_points[i]跟clusters[c]足够接近,那么point[i]属于clusters[c]
		//如果shifted_points[i]跟所有clusters[c]距离都很大,那她应该是单独一类
        int c = 0;
        for (; c < clusters.size(); c++) {
            if (euclidean_distance(shifted_points[i], clusters[c].mode) <= CLUSTER_EPSILON) {
                break;
            }
        }

		//如果shifted_points[i]跟所有clusters[c]距离都很大,那她应该是单独一类
        if (c == clusters.size()) {
            Cluster clus;
            clus.mode = shifted_points[i];
            clusters.push_back(clus);
        }

		//将points[i]加入到clusters[c]
        clusters[c].original_points.push_back(points[i]);
        clusters[c].shifted_points.push_back(shifted_points[i]);
    }

    return clusters;
}

//程序入口,提供输入的点集以及半径
vector<Cluster> MeanShift::cluster(const vector<Point> &points, double radius){
	//对所有点集进行meanshift
    vector<Point> shifted_points = meanshift(points, radius);
	//根据meanshift结果聚类
    return mean_shift_cluster(points, shifted_points);
}
  • 2
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Researcher-Du

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值