Python 实现MeanShift算法

原理

大家自行百度吧,我懒得码字了

推荐一下原理原理https://blog.csdn.net/jinshengtao/article/details/30258833

代码

直接上代码了,看不懂,就参照一下原理

# author: wdq
# contact: 1920132572@qq.com
# datetime:2022/3/15 17:40
# software: PyCharm
import random
from collections import Counter
from typing import List

import numpy as np
from numpy import ndarray


class MeanShift:

    def __init__(self, nums: ndarray, band_width: float):
        """

        :param nums: 要划分的ndarray
        :param band_width: 窗口大小
        """
        # 要划分的ndarray
        self.__nums = nums
        # 窗口大小
        self.__band_width = band_width

        # 停止步长
        self.__stop_band_width = 10 ** -4 * self.__band_width
        # 访问数组
        self.__is_visited = [False] * self.__nums.shape[0]
        # 聚类中心
        self.__cluster_centers = []
        # 聚类
        self.__cluster = []

    def mean_shift(self) -> List[List[List[int]]]:
        # 判断是否所有点都被访问过
        while not self.__is_all_visited():
            my_member = []
            # 在没被访问的点随机选一个点
            start_point = random.choice([i for i in range(self.__nums.shape[0]) if not self.__is_visited[i]])
            my_mean = self.__nums[start_point]
            while True:
                # 得到到各点的距离,以及权重
                distance, gaussian = self.__get_shift(my_mean)
                # 找到在窗口的点
                in_the_area = self.__find__points(distance)
                # 保留当前的位置
                old_mean = my_mean.copy()
                # 得到新的位置
                my_mean = self.__get_new_mean(gaussian, in_the_area)
                # 将范围的点划到当次的聚类
                my_member.extend(in_the_area)
                # 更新当前的访问数组
                self.__update_visited(in_the_area)
                # 判断是否小于停止步长
                if self.__get_distance(old_mean, my_mean) < self.__stop_band_width:

                    merge_width = None
                    # 遍历当前聚类
                    for i in range(len(self.__cluster_centers)):
                        # 判断中心点离得太近
                        if self.__get_distance(my_mean, self.__cluster_centers[i]) < self.__band_width / 2:
                            merge_width = i
                            break
                    # 如果太近了就合并这2个聚类
                    if merge_width is not None:
                        # 合并中心点
                        self.__cluster_centers[merge_width] = self.__get_new_center(my_mean,
                                                                                    self.__cluster_centers[merge_width])
                        # 合并聚类中的点
                        self.__cluster[merge_width].extend(my_member)

                    # 否则就添加一个聚类
                    else:
                        self.__cluster_centers.append(my_mean.tolist())
                        self.__cluster.append(my_member)
                    break
        # 返回分好类的结果
        return self.__get_result()

    def __is_all_visited(self) -> bool:
        """

        :return: 是否全部访问
        """
        # 遍历访问数组
        for i in self.__is_visited:
            if not i:
                return False
        return True

    def __get_distance(self, start: any, end: any) -> float:
        """

        :param start: 起始点
        :param end: 终点
        :return: 两点之间的距离
        """
        # 类型转换
        if type(start) != ndarray:
            start = np.array(start)
        if type(end) != ndarray:
            end = np.array(end)
        # 返回欧式距离
        return np.linalg.norm(start - end)

    def __get_shift(self, start: ndarray) -> (ndarray, ndarray):
        """

        :param start: 开始的点
        :return: 计算滑动的距离
        """

        # 距离
        distance = np.zeros((self.__nums.shape[0], 1))
        # 权重
        gaussian = np.zeros((self.__nums.shape[0], 1))

        for i in range(distance.shape[0]):
            temp = self.__get_distance(start, self.__nums[i])
            gaussian[i] = self.__gaussian_kernel(temp, self.__band_width)
            distance[i] = temp
        return distance, gaussian

    def __gaussian_kernel(self, distance: float, bandwidth: float) -> float:
        """
        高斯核函数
        :param distance: 距离
        :param bandwidth: 窗口大小
        :return: 权重
        """
        return (1 / (bandwidth * np.sqrt(2 * np.pi))) * np.exp(-0.5 * (distance / bandwidth) ** 2)

    def __get_new_mean(self, gaussian: ndarray, in_the_area: List[int]) -> ndarray:
        """

        :param gaussian: 权重
        :param in_the_area: 在区域的点
        :return:
        """
        # 权重
        weight = 0.
        # 在范围的点
        new_mean = np.array([self.__nums[i].tolist() for i in in_the_area])

        for i in range(len(in_the_area)):
            new_mean[i] = new_mean[i] * gaussian[in_the_area[i]]
            weight += gaussian[in_the_area[i]]
        # 对范围的点进行加权,并算出漂移到的点
        return np.sum(new_mean, axis=0) / weight if weight != 0 else np.sum(new_mean, axis=0)

    def __find__points(self, distance: ndarray) -> List[int]:
        """

        :param distance: 距离ndarray
        :return: 在窗口大小内的点
        """
        return [i for i, j in enumerate(distance) if j < self.__band_width ** 2]

    def __update_visited(self, in_the_area: List[int]) -> None:
        """
        更新访问过的点
        :param in_the_area: 在窗口大小内的点
        :return:
        """
        for i in in_the_area:
            self.__is_visited[i] = True

    def __get_new_center(self, mymean: ndarray, old_center: List[int]) -> List[int]:
        """
        合并中心点
        :param mymean: 现在的中心点
        :param old_center: 以前的中心点
        :return:
        """
        return [(i + j) / 2 for i, j in zip(mymean.tolist(), old_center)]

    def __get_result(self) -> List[List[List[int]]]:
        """
        将结果分好类并返回
        这段代码比较丑陋,将就看看,不看也行,我自己都不想看
        大致意思就是找这些点应该分到那个类
        :return:
        """
        count = []
        result = [[] for i in range(len(self.__cluster))]
        # 计数,计出每个点到每个聚类的次数
        for i in self.__cluster:
            count.append(dict(Counter(i)))
        belong = []

        # 遍历找出每个点到到那个聚类的最大值,那我们就可以认为它在那个聚类
        for num in range(len(self.__nums)):
            # 最大次数的索引
            index = 0
            for i in range(1, len(count)):
                if count[i].get(num, 0) > count[index].get(num, 0):
                    index = i
            belong.append(index)

        # 分类
        for i in range(len(self.__nums)):
            result[belong[i]].append(self.__nums[i].tolist())

        # 把空的聚类移除
        return [i for i in result if i]

测试代码

import matplotlib
from matplotlib import pyplot as plt
from sklearn import datasets

from MeanShift import MeanShift

matplotlib.rcParams['font.sans-serif'] = ['SimHei']
matplotlib.rcParams['axes.unicode_minus'] = False
iris = datasets.load_iris()  # 引入数据集
# 分的类不好就重新分,多试一哈
mean_shift = MeanShift(nums=iris.data, band_width=1.34)  # 对于iris,窗口大小为1.34,别问为什么,别问,问就是好用
colors = ['red', 'green', 'blue', 'black', 'yellow']
a = mean_shift.mean_shift()
for i in range(len(a)):
    for j in a[i]:
        plt.scatter(j[0], j[1], c=colors[i])
plt.title("Mean-Shift")
plt.xlabel('萼片长度')
plt.ylabel('萼片宽度')
plt.show()

"""

  ___________.__                   __               _____                 _____  .__      ._.
  \__    ___/|  |__ _____    ____ |  | __  ______ _/ ____\___________    /     \ |__| ____| |
    |    |   |  |  \\__  \  /    \|  |/ / /  ___/ \   __\/  _ \_  __ \  /  \ /  \|  |/    \ |
    |    |   |   Y  \/ __ \|   |  \    <  \___ \   |  | (  <_> )  | \/ /    Y    \  |   |  \|
    |____|   |___|  (____  /___|  /__|_ \/____  >  |__|  \____/|__|    \____|__  /__|___|  /_
                  \/     \/     \/     \/     \/                               \/        \/\/      


"""

运行结果

标准答案

MeanShift算法

只用来学习,借鉴,错的话,欢迎批评和指导!

邮箱:cse.dqwu19@gzu.edu.cn

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值