11. Scipy Tutorial-多维插值griddata
scipy.interpolate模块下的griddata函数可以处理多元(维)函数的插值,以二元函数$f(x, y)$为例说明一下griddata的使用。与之前的一元函数插值interp1d相区别,interp1d是通过已知的点集$P = {(x_i, y_i)|x_i \in R, y_i \in R }$通过interp1d可以找到一个函数$f(x_i) = y_i$,那么任何一个$x_j$通过插值函数就能求得其$y_j = f(x_j)$,$y_i$即插值,这里的$x_j$可能是点集P里的一个数据,也可以不是,这是一元插值的思想。可以看出插值需要$f(x)$算出来,而griddata函数可以用于多元的插值,其返回值不是一个函数,而是插值本身,可以通过下面的代码验证一下这个说法。
下面的代码看上去很长,实际内容并不多,大致有三部分:第一部分从import numpy as np语句开始,到import matplotlib.pyplot as plt,这部分是本例子的核心,即求多元数据的插值,使用了griddata函数。第二部分是数据的可视化,从语句import matplotlib.pyplot as plt开始到第一个plt.show()即第一次数据可视化输出,这部分的作用是绘制已知点集和插值的数据的可视化。第三部分 从print "*" * 20语句开始一直到程序结束,这部分主要是验证griddata函数返回的是插值数据本身,无需像一元interp1d插值那样用点去计算插值了,返回值本身就是插值数据。
import numpy as np
def func(x, y):
return x*(1-x)*np.cos(4*np.pi*x) * np.sin(4*np.pi*y**2)**2
grid_x, grid_y = np.mgrid[0:1:100j, 0:1:200j]
points = np.random.rand(1000, 2)
values = func(points[:,0], points[:,1])
from scipy.interpolate import griddata
grid_z0 = griddata(points, values, (grid_x, grid_y), method='nearest')
grid_z1 = griddata(points, values, (grid_x, grid_y), method='linear')
grid_z2 = griddata(points, values, (grid_x, grid_y), method='cubic')
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d.axes3d import Axes3D
plt.figure()
ax1 = plt.subplot2grid((2,2), (0,0), projection='3d')
ax1.plot_surface(grid_x, grid_y, grid_z0, color = "c")
ax1.set_xlim3d(0, 1)
ax1.set_ylim3d(0, 1)
ax1.set_zlim3d(-0.25, 0.25)
ax1.set_title('nearest')
ax2 = plt.subplot2grid((2,2), (0,1), projection='3d')
ax2.plot_surface(grid_x, grid_y, grid_z1, color = "c")
ax2.set_xlim3d(0, 1)
ax2.set_ylim3d(0, 1)
ax2.set_zlim3d(-0.25, 0.25)
ax2.set_title('linear')
ax3 = plt.subplot2grid((2,2), (1,0), projection='3d')
ax3.plot_surface(grid_x, grid_y, grid_z2, color = "r")
ax3.set_xlim3d(0, 1)
ax3.set_ylim3d(0, 1)
ax3.set_zlim3d(-0.25, 0.25)
ax3.set_title('cubic')
ax4 = plt.subplot2grid((2,2), (1,1), projection='3d')
ax4.scatter(points[:,0], points[:,1], values, c= "b")
ax4.set_xlim3d(0, 1)
ax4.set_ylim3d(0, 1)
ax4.set_zlim3d(-0.25, 0.25)
ax4.set_title('org_points')
plt.tight_layout()
plt.show()