1、meshgrid
1、问题描述
今天在看机器学习视频时,看到预测函数中有下面的代码
xx1, xx2 = np.meshgrid(np.arange(x1_min, x1_max, resolution),
np.arange(x2_min, x2_max, resolution))
Z = classifier.predict(np.array([xx1.ravel(), xx2.ravel()]).T)
不太懂其中的意思,索性就进行了下面的分析和测试。
2、分析
该函数可以在二维坐标轴上生成对应的交叉点,每个交叉点对应一个坐标,坐标的X轴取值为:[0. 0.2 0.4 0.6 0.8],Y轴取值为:[1 1.5 2],进行将X/Y轴取值进行组合后,可以拼接成相应的坐标,其测试代码如下:
import numpy as np
XX1, XX2 = np.meshgrid(np.arange(0, 1, 0.2), np.arange(1, 2.1, 0.5))
# [[0. 0.2 0.4 0.6 0.8]
# [0. 0.2 0.4 0.6 0.8]
# [0. 0.2 0.4 0.6 0.8]]
print(XX1)
# <class 'numpy.ndarray'>
print(type(XX1))
# (3, 5)
print(XX1.shape)
# [[1. 1. 1. 1. 1. ]
# [1.5 1.5 1.5 1.5 1.5]
# [2. 2. 2. 2. 2. ]]
print(XX2)
# 将数据进行扁平化
# [0. 0.2 0.4 0.6 0.8 0. 0.2 0.4 0.6 0.8 0. 0.2 0.4 0.6 0.8]
print(XX1.ravel())
# <class 'numpy.ndarray'>
print(type(XX1.ravel()))
# [1. 1. 1. 1. 1. 1.5 1.5 1.5 1.5 1.5 2. 2. 2. 2. 2. ]
print(XX2.ravel())
Z = np.array([XX1.ravel(), XX2.ravel()])
# [[0. 0.2 0.4 0.6 0.8 0. 0.2 0.4 0.6 0.8 0. 0.2 0.4 0.6 0.8]
# [1. 1. 1. 1. 1. 1.5 1.5 1.5 1.5 1.5 2. 2. 2. 2. 2. ]]
print(Z)
# [[0. 1. ]
# [0.2 1. ]
# [0.4 1. ]
print(Z.T)
2、mgrid
1、代码
# 5j个点,步长为复数表示点数,为实数表示步长
# np.mgrid[start:end:step]
x1, x2 = np.mgrid[-5:5:5j, -5:5:5j]
# 输出结果
[[-5. -5. -5. -5. -5. ]
[-2.5 -2.5 -2.5 -2.5 -2.5]
[ 0. 0. 0. 0. 0. ]
[ 2.5 2.5 2.5 2.5 2.5]
[ 5. 5. 5. 5. 5. ]]
[[-5. -2.5 0. 2.5 5. ]
[-5. -2.5 0. 2.5 5. ]
[-5. -2.5 0. 2.5 5. ]
[-5. -2.5 0. 2.5 5. ]
[-5. -2.5 0. 2.5 5. ]]
print(x1, x2)