模式识别课作业-Parzen窗估计实现

非参数估计Parzen窗估计-python实现

# -*- coding: utf-8 -*-
"""
@Time : 2022/3/14 9:50
@Auth : yusuen
@File :main.py
@IDE :PyCharm
"""
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import animation
from mpl_toolkits.mplot3d import Axes3D

def generate_data(num_samples):
    
    # 生成样本
    datax = np.hstack([np.random.randn(num_samples) * 2 - 5, np.random.randn(num_samples) * 2 + 3, np.random.randn(num_samples) * 3 - 5])
    datay = np.hstack([np.random.randn(num_samples) * 6 + 5, np.random.randn(num_samples) * 4 - 6, np.random.randn(num_samples) * 4 + 3])

    # print("x", datax)
    # print("y", datay)

    x_, y_ = datax, datay
    pos_ = np.vstack([datax, datay])
    return x_, y_, pos_

def parzen_window(x, h):

    u = (pos - x.reshape(-1,1))/ h
    ix, iy = pos[:, (abs(u) <= 0.5).all(axis=0)]
    k = len(ix)

    return k / ((h**2) * n)


def rotate(angle):
    ax.view_init(azim=angle)

if __name__ == '__main__':

    windowsize = [0.2, 4, 10]
    n = 20000
    xv, yv, pos = generate_data(n)
    # 散点图
    plt.figure(1)
    plot_pos = 131
    xi = np.array([1,4])
    for h in windowsize:
        plt.subplot(plot_pos)
        plot_pos += 1
        u = (pos - xi.reshape(-1, 1)) / h
        ix, iy = pos[:, (abs(u) <= 0.5).all(axis=0)]
        plt.title("h=" + str(h))
        plt.scatter(xv, yv, s=0.01)
        plt.scatter(ix, iy)
        plt.scatter(xi[0], xi[1], c='r')
    plt.show()

    # 三维
    w = 50
    gx = np.linspace(np.min(xv), np.max(xv), w)
    gy = np.linspace(np.min(yv), np.max(yv), w)
    # 获得网格坐标矩阵
    gxv, gyv = np.meshgrid(gx, gy)
    fgxv = gxv.ravel()
    fgyv = gyv.ravel()

    for i, h in enumerate(windowsize):
        fpx = np.array([parzen_window(x, h) for x in np.vstack([fgxv,fgyv]).T])
        fpx = fpx.reshape(w, w)
        fig = plt.figure(num = i+1)
        ax = Axes3D(fig)
        surf = ax.plot_surface(gxv,gyv,fpx, rstride=1, cstride=1, cmap='GnBu_r')
        ax.set_xlabel('x')
        ax.set_ylabel('y')
        ax.set_zlabel('z')
        ax.set_title('h='+str(h))

        # 添加等高线
        ax.contourf(gxv, gyv, fpx, zdir='z', offset= fpx.max(), cmap='GnBu_r')
        ax.set_zlim3d(0, fpx.max())
        # 添加图例bar
        fig.colorbar(surf, shrink=0.5, aspect=5)
        rot_animation = animation.FuncAnimation(fig, rotate, frames=np.arange(0, 362, 2), interval=100)
        rot_animation.save('./rotation_{}.gif'.format('h='+str(h)), dpi=80, writer='pillow')
        plt.show()

可视化结果

散点图
在这里插入图片描述
三维概率密度图

h=0.2
在这里插入图片描述

h=4
在这里插入图片描述
h=10
在这里插入图片描述

  • 1
    点赞
  • 18
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值