基于Python的DBSCAN算法手动实现

关键参数:

Eps:邻域范围

MinPts:邻域密度阈值(成为核心点所需的最少邻域点数,包含该点自身)

步骤:

1、遍历所有数据点,若已被标记则跳过,否则进行邻域查询

2、若该点的邻近点个数>=MinPts,说明是核心点,生成新的簇并以此为中心进行扩张查询;否则,将其暂时标记为噪声点(之后若落入某个核心点的邻域,会被改标为该簇的边界点)

import numpy as np
from collections import deque

def dbscan(data, Eps, MinPts):
    """Cluster points with DBSCAN using a precomputed pairwise distance matrix.

    Args:
        data: Array-like of shape (n_samples, n_features). Plain nested lists
            are accepted and converted with ``np.asarray``.
        Eps: Neighbourhood radius. Two points are neighbours when their
            Euclidean distance is strictly less than ``Eps``, so a point is
            its own neighbour whenever ``Eps > 0``.
        MinPts: Minimum neighbourhood size (the point itself included) for a
            point to count as a core point.

    Returns:
        A float ndarray with one label per sample: ``-1`` marks noise and
        clusters are numbered 1, 2, ... in the order they are discovered.
    """
    UNVISITED = 0   # internal marker; never present in the returned labels
    NOISE = -1

    def get_dist(points):
        # Broadcasting (n, 1, d) - (n, d) yields every pairwise difference as
        # an (n, n, d) array; the norm over the last axis reduces it to the
        # (n, n) Euclidean distance matrix.
        return np.linalg.norm(points[:, np.newaxis] - points, axis=2)

    def region_query(point_idx, dist):
        # np.where returns a tuple holding one index array per dimension; the
        # distance row is 1-D, so element [0] is the array of neighbour indices.
        return np.where(dist[point_idx] < Eps)[0]

    def expand_cluster(point_idx, neighbours, cluster_id, labels, dist):
        labels[point_idx] = cluster_id
        queue = deque(neighbours)

        while queue:
            idx = queue.popleft()
            if labels[idx] == NOISE:
                # Previously rejected as a core point, but reachable from one:
                # it becomes a border point of this cluster.
                labels[idx] = cluster_id
            elif labels[idx] == UNVISITED:
                labels[idx] = cluster_id
                idx_neighbours = region_query(idx, dist)

                # Only core points propagate the cluster further.
                if len(idx_neighbours) >= MinPts:
                    # Points that already carry a cluster id would be skipped
                    # when dequeued, so leave them out to keep the queue small.
                    pending = idx_neighbours[labels[idx_neighbours] <= UNVISITED]
                    queue.extend(pending)

    data = np.asarray(data)   # accept lists / tuples as well as ndarrays
    n_points = data.shape[0]
    if n_points == 0:
        # Nothing to cluster; also avoids the axis error an empty list causes.
        return np.zeros(0)

    labels = np.zeros(n_points)   # one label per sample, all UNVISITED
    dist = get_dist(data)
    cluster_id = 0

    for point_idx in range(n_points):
        if labels[point_idx] != UNVISITED:
            continue

        neighbours = region_query(point_idx, dist)

        if len(neighbours) < MinPts:
            # Provisional: may be relabelled as a border point later on.
            labels[point_idx] = NOISE
        else:
            cluster_id += 1   # start a new cluster seeded at this core point
            expand_cluster(point_idx, neighbours, cluster_id, labels, dist)

    return labels

if __name__ == "__main__":
    # Demo: cluster a synthetic two-ring dataset and plot the result.
    # matplotlib is imported here because the plotting code below uses `plt`,
    # which was never imported and raised NameError.
    import matplotlib.pyplot as plt
    from sklearn.datasets import make_blobs, make_circles, make_moons

    # Generate the dataset. Alternative datasets are kept for experimentation:
    #     blobs,_ = make_blobs(n_samples=200, centers=2, cluster_std=0.5,random_state=66)
    #     moon,_ = make_moons(n_samples=500, noise=0.1, random_state= 555)
    circles, _ = make_circles(n_samples=600, factor=0.2, noise=0.1, random_state=6)
    X = circles
    #     X = np.concatenate((blobs, circles, moon), axis=0)

    # Key parameters.
    Eps = 0.2
    MinPts = 5
    labels = dbscan(X, Eps, MinPts)

    # Visualise: one colour per distinct label, sorted so the colour
    # assignment does not depend on set iteration order.
    unique_labels = sorted(set(labels))
    colors = [plt.cm.Spectral(i) for i in np.linspace(0, 1, len(unique_labels))]

    for k, col in zip(unique_labels, colors):
        if k == -1:
            col = [0, 0, 0, 1]   # noise points are drawn in black
        xy = X[labels == k]
        # `color` (singular) applies one RGBA colour to the whole group.
        plt.scatter(xy[:, 0], xy[:, 1], color=col, edgecolor='k')

    plt.show()

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值