非参数估计：Parzen 窗估计的 Python 实现
# -*- coding: utf-8 -*-
"""
@Time : 2022/3/14 9:50
@Auth : yusuen
@File :main.py
@IDE :PyCharm
"""
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import animation
from mpl_toolkits.mplot3d import Axes3D
def generate_data(num_samples):
    """Draw a 2-D sample set made of three normally distributed clusters.

    Every coordinate is assembled from three independent ``np.random.randn``
    draws of ``num_samples`` values each, so the result contains
    ``3 * num_samples`` points.

    Args:
        num_samples: Number of points drawn for each of the three clusters.

    Returns:
        A tuple ``(x_, y_, pos_)``: the x coordinates, the y coordinates, and
        both stacked row-wise into an array of shape ``(2, 3 * num_samples)``.
    """
    # (scale, shift) applied to the standard-normal draws of each cluster.
    x_clusters = ((2, -5), (2, 3), (3, -5))
    y_clusters = ((6, 5), (4, -6), (4, 3))

    # All x draws are made before any y draw.
    x_ = np.hstack(
        [np.random.randn(num_samples) * scale + shift for scale, shift in x_clusters]
    )
    y_ = np.hstack(
        [np.random.randn(num_samples) * scale + shift for scale, shift in y_clusters]
    )
    pos_ = np.vstack([x_, y_])
    return x_, y_, pos_
def parzen_window(x, h, samples=None):
    """Estimate the probability density at ``x`` with a hypercube Parzen window.

    A sample counts as inside the window when every coordinate lies within
    ``h / 2`` of ``x``. The estimate is ``k / (N * h**d)``, where ``k`` is the
    number of samples inside the window, ``N`` the total number of samples and
    ``d`` their dimensionality.

    Args:
        x: Query point, array-like with ``d`` values.
        h: Side length of the window; must be positive.
        samples: Optional sample matrix of shape ``(d, N)`` (one column per
            sample). When omitted, the module-level ``pos`` array is used so
            existing two-argument calls keep working.

    Returns:
        The density estimate at ``x`` as a float.

    Raises:
        ValueError: If ``h`` is not positive or the sample matrix is empty.
    """
    if h <= 0:
        raise ValueError("window size h must be positive")

    data = np.asarray(pos if samples is None else samples, dtype=float)
    dim, total = data.shape
    if total == 0:
        raise ValueError("at least one sample is required")

    # Offsets from the query point, scaled so the window spans [-0.5, 0.5].
    u = (data - np.asarray(x, dtype=float).reshape(-1, 1)) / h
    inside = (np.abs(u) <= 0.5).all(axis=0)
    k = int(np.count_nonzero(inside))

    # Normalise by the real number of samples (columns of the matrix), not by
    # an external per-cluster count, so the estimate integrates to one.
    return k / ((h ** dim) * total)
def rotate(angle):
    """Set the azimuth of the module-level 3-D axes ``ax``.

    Passed as the per-frame callback to ``animation.FuncAnimation``.

    Args:
        angle: Value forwarded as the ``azim`` argument of ``ax.view_init``.
    """
    ax.view_init(azim=angle)
if __name__ == '__main__':
    # Window side lengths to compare.
    windowsize = [0.2, 4, 10]
    # `n`, `pos` and (further down) `ax` stay module-level names: the helper
    # functions in this file may look them up as globals.
    n = 20000
    xv, yv, pos = generate_data(n)

    # Scatter plots: for each window size, highlight the samples that fall
    # inside the window centred on the query point `xi`.
    plt.figure(1)
    xi = np.array([1, 4])
    for col, h in enumerate(windowsize, start=1):
        plt.subplot(1, len(windowsize), col)
        u = (pos - xi.reshape(-1, 1)) / h
        ix, iy = pos[:, (np.abs(u) <= 0.5).all(axis=0)]
        plt.title("h=" + str(h))
        plt.scatter(xv, yv, s=0.01)
        plt.scatter(ix, iy)
        plt.scatter(xi[0], xi[1], c='r')
    plt.show()

    # 3-D density surfaces evaluated on a w x w grid spanning the samples.
    w = 50
    gx = np.linspace(np.min(xv), np.max(xv), w)
    gy = np.linspace(np.min(yv), np.max(yv), w)
    # Grid coordinate matrices.
    gxv, gyv = np.meshgrid(gx, gy)
    # One (x, y) row per grid node; identical for every window size.
    grid_points = np.vstack([gxv.ravel(), gyv.ravel()]).T

    for i, h in enumerate(windowsize):
        fpx = np.array([parzen_window(x, h) for x in grid_points]).reshape(w, w)

        fig = plt.figure(num=i + 1)
        # Create the 3-D axes through the figure: recent Matplotlib releases
        # no longer attach an axes built with `Axes3D(fig)` to the figure,
        # which leaves the plot (and the saved GIF) empty.
        ax = fig.add_subplot(111, projection='3d')
        surf = ax.plot_surface(gxv, gyv, fpx, rstride=1, cstride=1, cmap='GnBu_r')
        ax.set_xlabel('x')
        ax.set_ylabel('y')
        ax.set_zlabel('z')
        ax.set_title('h=' + str(h))
        # Filled contours drawn on the plane at the surface's maximum height.
        ax.contourf(gxv, gyv, fpx, zdir='z', offset=fpx.max(), cmap='GnBu_r')
        ax.set_zlim3d(0, fpx.max())
        # Colour bar for the surface.
        fig.colorbar(surf, shrink=0.5, aspect=5)

        # Rotate the view through a full turn and write it out as a GIF.
        rot_animation = animation.FuncAnimation(
            fig, rotate, frames=np.arange(0, 362, 2), interval=100
        )
        rot_animation.save('./rotation_{}.gif'.format('h=' + str(h)), dpi=80, writer='pillow')
    plt.show()
可视化结果
散点图
三维概率密度图
h=0.2
h=4
h=10