均值漂移聚类matlab代码,如何使均值漂移聚类适用于5个以上的聚类?

我在mean-shift聚类上遇到了麻烦。当簇数很小时(2,3,4),它工作得很快,并且输出正确的结果,但是当簇数增加时,它失败了。在

例如,可以检测到3个簇:

f151cde193737fda7c69885cb9d370fb.png

但当数量增加时,它就失败了:

a1b482915a641c79a6fad374c5eac109.png

b40fd2f45726793367182d34509004c0.png

下面是完整的代码列表:#!/usr/bin/env python

import sys

import logging

import numpy as np

import matplotlib

matplotlib.use('Agg')

import matplotlib.pyplot as plot

from sklearn.cluster import estimate_bandwidth, MeanShift, get_bin_seeds

from sklearn.datasets.samples_generator import make_blobs

def test_mean_shift():

logging.debug('Generating mixture')

count = 5000

blocks = 7

std_error = 0.5

mixture, clusters = make_blobs(n_samples=count, centers=blocks, cluster_std=std_error)

logging.debug('Measuring bendwith')

bandwidth = estimate_bandwidth(mixture)

logging.debug('Bandwidth: %r' % bandwidth)

mean_shift = MeanShift(bandwidth=bandwidth)

logging.debug('Clustering')

mean_shift.fit(mixture)

shifted = mean_shift.cluster_centers_

guess = mean_shift.labels_

logging.debug('Centers: %r' % shifted)

def draw_mixture(mixture, clusters, output='mixture.png'):

plot.clf()

plot.scatter(mixture[:, 0], mixture[:, 1],

c=clusters,

cmap=plot.cm.coolwarm)

plot.savefig(output)

def draw_mixture_shifted(mixture, shifted, output='mixture_shifted.png'):

plot.clf()

plot.scatter(mixture[:, 0], mixture[:, 1], c='r')

plot.scatter(shifted[:, 0], shifted[:, 1], c='b')

plot.savefig(output)

logging.debug('Drawing')

draw_mixture_shifted(mixture, shifted)

draw_mixture(mixture, guess)

if __name__ == '__main__':

logging.basicConfig(level=logging.DEBUG)

test_mean_shift()

我做错什么了?在

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值