机器学习 - DBSCAN算法

DBSCAN算法是一种基于密度的聚类算法,它的核心思想是定义一个密度,然后把密度相近的且靠的较近的点(密度可达)逐渐合并成一个簇;也可以认为是使用较低密度的区域来分割密度较高的区域。


1. 算法概述

1)找出数据中所有的核心对象C
2)找出每一个核心对象所有的密度可达点;注意在为当前的一个核心点Ci寻找密度可达点时,多半都会找到核心对象也是密度可达点,此时C将不断缩小。因此最终簇数一般都会小于C中元素的个数。

详细算法介绍在这里:http://download.csdn.net/download/zk_j1994/9932804


2. 算法实现

DBSCAN核心算法类:

class DBSCAN:
    def __init__(self, radius, minPoints):
        """
        radius:     邻域半径
        minPoints:  邻域内最少点数目
        """
        self.radius = radius
        self.minPoints = minPoints
        
    def _all_coreObject(self, train_x):
        """ 寻找数据集中所有的核心对象 
        train_x:    全部数据, 一行代表样本, 列代表特征
        coreObject: 数据集中所有的核心对象, 其索引构成列表
        pointCnt:   邻域内点数统计
        """
        coreObject = []
        for i in range(len(train_x)):
            pointsCnt = 0
            for j in range(len(train_x)):
                distance = np.sqrt(sum((train_x[i, :] - train_x[j, :])**2))
                if distance <= self.radius:
                    pointsCnt += 1
            if pointsCnt >= self.minPoints:
                coreObject.append(i)
        return coreObject
    
    def _cal_neighborPoints(self, train_x, currentPointIndex):
        """ 计算currentPoint的邻域点, 不包含自身 
        train_x:            全部数据, 一行代表样本, 列代表特征
        currentPointIndex:  currentPoint在train_x所在的行
        currentPoint:       当前点
        distance:           当前点与train_x中所有点的距离
        index:              currentPoint的邻域点的索引
        """
        currentPoint = train_x[currentPointIndex, :]
        distance = np.sqrt(np.sum((train_x - currentPoint) ** 2, axis=1))
        index = set(np.argwhere(distance <= self.radius).flatten())
        index.remove(currentPointIndex)
        return index
    
    def clustering(self, train_x):
        """
        train_x:
            全部数据, 一行代表样本, 列代表特征
        k:
            聚类簇数
        noAccess_sample:
            未访问样本列表
        coreObject:     
            所有核心对象
        noAccess_sample_old:
            未访问样本列表的备份
        Q:
            队列
        neighborPoints:
            当前样本的邻域样本
        clusters:
            簇, 存放的是在train_x的索引
        """
        k = 0
        clusters = []
        noAccess_sample = list(range(len(train_x)))
        coreObject = self._all_coreObject(train_x)
        
        while len(coreObject) != 0:
            # 未访问样本备份
            noAccess_sample_old = noAccess_sample.copy()
            
            # 随机取出一核心对象放入Q队列
            Q = []
            Q.append(coreObject[np.random.randint(0, len(coreObject))])
            
            # 未访问样本除去随机取出的核心对象
            noAccess_sample.remove(Q[0])

            # 找出随机核心对象Q[0]的所有可达样本
            while len(Q) != 0:
                if Q[0] in coreObject:
                    # 计算Q[0]的邻域
                    neighborPoints = self._cal_neighborPoints(train_x, Q[0])
                    
                    # 邻域中未访问过的样本
                    neighbor_noAccess = list(set(noAccess_sample) & neighborPoints)
                    Q.extend(neighbor_noAccess)
                    
                    # 更新未访问样本列表
                    noAccess_sample = [ i for i in noAccess_sample if i not in neighbor_noAccess ]
                # Q[0]出队列
                Q.remove(Q[0])
            
            # 增加簇
            clusters.append(list(set(noAccess_sample_old) - set(noAccess_sample)))
            
            # 更新核心对象列表
            coreObject = list(set(coreObject) - set(clusters[k]))
            k += 1
        return clusters
    
    def find_noise(self):
        """ 寻找噪音点, 即train_x中存在, clusters中不存在的点 """
        pass


工具函数:

# -*- coding: utf-8 -*-
""" 基于密度的聚类算法DBSCAN """
import numpy as np
import matplotlib.pyplot as plt

def load_data():
    with open("../DBSCAN/data/Spiral.txt", "r") as f:
        data = []
        for line in f.readlines():
            data.append(line.strip().split(","))
        data = np.array(data, np.float)[:, 0:2]
    return data

def draw_result(train_x, clusters):
    plt.figure("DBSCAN algorithm")
    for i in range(len(clusters)):
        plt.scatter(train_x[clusters[i], 0], train_x[clusters[i], 1], s=20)
    plt.show()

测试程序:

if __name__ == "__main__":
    train_x = load_data()
    
    """
    Spiral:
        radius = 2,     minPoints = 5; 
    D31:
        radius = 0.4,   minPoints = 6; 
    """
    obj = DBSCAN(radius=2, minPoints=6)
    
    clusters = obj.clustering(train_x)

    draw_result(train_x, clusters)



Kmeans跑出来的结果:


数据集下载:http://download.csdn.net/download/zk_j1994/9932797


4. 算法优缺点
1)不需要提前确定聚类的簇数;
2)密度依赖于半径,半径需要调节;
3)基于密度定义,能找出噪音,能处理任意形状和大小的簇;
4)簇密度变化太大时,比较麻烦,选取半径radius和minPoints比较困难;
5)时间复杂度O(N2),使用高级数据结构KD树等,可以达到O(NlogN);


参考文献

https://wenku.baidu.com/view/43d179f202d276a201292e50.html

http://blog.csdn.net/itplus/article/details/10088625

http://blog.csdn.net/google19890102/article/details/37656733


评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值