插值函数的缩放
主要问题是插值函数的标准差选择不当:stddev = 100
你的函数(它的峰)的大小约为1。所以,使用
^{pr2}$
X值的顺序
之所以出现红线混乱,是因为matplotlib中的plt按给定的顺序连接连续的数据点。由于X值是随机顺序的,这将导致混乱的左右移动。使用排序的X:X = np.sort(low_x + (high_x - low_x) * np.random.rand(N,1), axis=0)
效率问题
您的get_labels_improved方法效率低下,在X的元素上循环。请使用Y = f(X),将循环留给低级numy内部。在
另外,超定系统的最小二乘解的计算应该用lstsq来完成,而不是计算伪逆(计算量大)并乘以它。在
这是经过清理的代码;使用30个中心进行匹配。在
import numpy as np
from sklearn.metrics.pairwise import euclidean_distances
import matplotlib.pyplot as plt
N = 5000
low_x =-2*np.pi
high_x=2*np.pi
X = np.sort(low_x + (high_x - low_x) * np.random.rand(N,1), axis=0)
f = lambda x: 2*np.power( 2*np.power( np.cos(x) ,2) - 1, 2) - 1
Y = f(X)
K = 30 # number of centers for RBF
indices=np.random.choice(a=N,size=K) # choose numbers from 0 to D^(1)
subsampled_data_points=X[indices,:] # M_sub x D
stddev = 1
beta = 0.5*np.power(1.0/stddev,2)
Kern = np.exp(-beta*euclidean_distances(X=X, Y=subsampled_data_points,squared=True))
C = np.linalg.lstsq(Kern, Y)[0]
Y_pred = np.dot(Kern, C)
plt.plot(X, Y, 'o', label='Original data', markersize=1)
plt.plot(X, Y_pred, 'r', label='Fitted line', markersize=1)
plt.legend()
plt.show()