压缩近邻法的可视化Python实现

笔者博客链接 压缩近邻法的可视化Python实现

最近学习模式识别的时候看到压缩近邻法,然后对其进行了Python实现。这里分享一下代码。

压缩近邻法是为了降低近邻法的计算复杂度,其通过将数据集进行压缩,然后再进行近邻法的计算。这样可以大大降低计算复杂度。

本程序功能:

  1. 生成一个二维的数据集,以 y = s i n x y=sinx y=sinx 函数作非线性分割。
  2. 使用压缩近邻法构建比原数据集小的分类点集。并可视化了构建过程。
  3. 可以添加测试数据集并观察使用筛选后的分类点集在K-近邻法中对于不同K取值的分类效果。

构建过程如下:

首先设原数据集为 D D D,算法尝试构建分类集 S S S,使得 S S S中的点在最近邻算法下可以对 D D D进行分类。

S S S的构建过程如下:

  1. D D D中随机选取一个点加入 S S S
  2. 使用最近邻算法对 D D D中的点进行分类,将所有分类错误的点归为错误集 E E E
  3. E E E中随机选取一个点加入 S S S。(此步骤也可以使用其他策略,这里使用随机选择)
  4. 重复2-3步骤,直到 E E E为空。这时 S S S就是我们要找的分类集。

以下为代码实现:

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.widgets import Button, Slider
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
import tkinter as tk
from tkinter import simpledialog
from pylab import mpl
mpl.rcParams['font.sans-serif'] = ['Microsoft YaHei'] # 指定默认字体:解决plot不能显示中文问题
mpl.rcParams['axes.unicode_minus'] = False # 解决保存图像是负号'-'显示为方块的问题

class ClassifierVisualizer:
    def __init__(self, n_points=100):
        self.n_points = n_points
        self.generate_data()
        self.S = set()
        self.setup_plot()
        self.test_mode = False
        self.k = 1

    def generate_data(self):
        self.X = np.random.uniform(-5, 5, (self.n_points, 2))
        self.y = (self.X[:, 1] > np.sin(self.X[:, 0])).astype(int)

    def setup_plot(self):
        self.fig = plt.figure(figsize=(15, 6))
        self.ax1 = self.fig.add_axes([0.05, 0.15, 0.4, 0.75])
        self.ax2 = self.fig.add_axes([0.55, 0.15, 0.4, 0.75])
        self.fig.suptitle('分类可视化')
        
        # 原始数据图
        self.scatter2 = self.ax2.scatter(self.X[:, 0], self.X[:, 1], c=self.y, cmap='coolwarm')
        self.ax2.set_title('原始数据')
        x = np.linspace(-6, 6, 100)
        self.sine2, = self.ax2.plot(x, np.sin(x), 'g-', lw=2)
        self.ax2.set_xlim(-6, 6)
        self.ax2.set_ylim(-6, 6)

        # 当前分类图
        self.scatter1 = self.ax1.scatter(self.X[:, 0], self.X[:, 1], c='gray')
        self.s_scatter1 = self.ax1.scatter([], [], c='orange', s=100)
        self.ax1.set_title('当前分类')
        self.ax1.set_xlim(-6, 6)
        self.ax1.set_ylim(-6, 6)
        self.sine1, = self.ax1.plot(x, np.sin(x), 'g-', lw=2)

        # 添加按钮
        self.ax_button = plt.axes([0.81, 0.02, 0.1, 0.075])
        self.button = Button(self.ax_button, '下一步')
        self.button.on_clicked(self.step)

        # 添加K值滑动条
        self.ax_slider = plt.axes([0.55, 0.05, 0.3, 0.03])
        self.k_slider = Slider(self.ax_slider, 'K值', 1, 10, valinit=1, valstep=1)
        self.k_slider.on_changed(self.update_k)

    def update_k(self, val):
        self.k = int(val)
        if self.test_mode:
            self.update_test_classification()

    def nearest_neighbors(self, point, k):
        if not self.S:
            return 0
        distances = np.sum((self.X[list(self.S)] - point)**2, axis=1)
        k = min(k, len(self.S))
        nearest_indices = np.argsort(distances)[:k]
        nearest_labels = [self.y[list(self.S)[i]] for i in nearest_indices]
        return np.mean(nearest_labels) > 0.5

    def step(self, event):
        if not self.test_mode:
            if not self.S:
                self.S.add(np.random.choice(self.n_points))
            else:
                W = [i for i in range(self.n_points) if i not in self.S and 
                     self.nearest_neighbors(self.X[i], self.k) != self.y[i]]
                if W:
                    self.S.add(np.random.choice(W))
                else:
                    self.prompt_test_mode()
                    return

            predictions = np.array([self.nearest_neighbors(x, self.k) for x in self.X])
            colors = np.where(predictions, 'red', 'blue')
            self.scatter1.set_facecolors(colors)
            
            s_points = self.X[list(self.S)]
            self.s_scatter1.set_offsets(s_points)
            
            self.ax1.set_title(f'当前分类 (S集合大小: {len(self.S)})')
        else:
            self.get_test_points()

        self.fig.canvas.draw_idle()

    def prompt_test_mode(self):
        self.test_mode = True
        self.ax2.clear()
        self.ax2.set_title('测试分类')
        self.ax2.set_xlim(-6, 6)
        self.ax2.set_ylim(-6, 6)
        x = np.linspace(-6, 6, 100)
        self.sine2, = self.ax2.plot(x, np.sin(x), 'g-', lw=2)
        self.scatter2 = self.ax2.scatter([], [])
        self.s_scatter2 = self.ax2.scatter([], [], c='orange', s=100)
        
        # 隐藏"下一步"按钮
        self.button.ax.set_visible(False)
        
        self.fig.canvas.draw_idle()
        self.get_test_points()

    def get_test_points(self):
        root = tk.Tk()
        root.withdraw()  # 隐藏主窗口

        n_test = simpledialog.askinteger("输入", "请输入测试点的数量:", parent=root, minvalue=1, maxvalue=1000)
        x_range = simpledialog.askstring("输入", "请输入x坐标范围 (min max):", parent=root)
        y_range = simpledialog.askstring("输入", "请输入y坐标范围 (min max):", parent=root)

        if n_test is None or x_range is None or y_range is None:
            return

        x_min, x_max = map(float, x_range.split())
        y_min, y_max = map(float, y_range.split())

        x_range = max(abs(x_min), abs(x_max), 6)
        y_range = max(abs(y_min), abs(y_max), 6)

        self.ax2.set_xlim(-x_range, x_range)
        self.ax2.set_ylim(-y_range, y_range)

        # 更新sin(x)函数的显示范围
        x = np.linspace(-x_range, x_range, 1000)
        self.sine2.set_data(x, np.sin(x))

        self.X_test = np.random.uniform(low=[x_min, y_min], high=[x_max, y_max], size=(n_test, 2))
        self.update_test_classification()

    def update_test_classification(self):
        # 显示原始数据
        self.scatter2.set_offsets(self.X)
        self.scatter2.set_facecolors(np.where(self.y, 'red', 'blue'))
        
        # 显示S集合的点
        s_points = self.X[list(self.S)]
        self.s_scatter2.set_offsets(s_points)
        
        # 显示并分类测试点
        predictions = np.array([self.nearest_neighbors(x, self.k) for x in self.X_test])
        colors = np.where(predictions, 'red', 'blue')
        test_scatter = self.ax2.scatter(self.X_test[:, 0], self.X_test[:, 1], c=colors, marker='s')
        
        self.ax2.set_title(f'测试分类 (K={self.k})')
        self.fig.canvas.draw_idle()

    def run(self):
        plt.show()

# 创建并运行可视化器
visualizer = ClassifierVisualizer()
visualizer.run()
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值