转载:https://moonbooks.org/Articles/How-to-calculate-a-root-mean-square-using-python-/
如何利用Python计算线性模型拟合结果的均方误差(MSE):
完整代码:
import matplotlib.pyplot as plt
import numpy as np
#----- Synthetic data set: y = 2 + 3 x + noise, with 1000 samples
# Draw order matters for the global NumPy random stream: the uniform draw
# for X comes first, the standard-normal draw for the noise second.
X = np.random.rand(1000, 1) * 4
noise = np.random.randn(1000, 1)
Y = 3 * X + 2 + noise
# Design matrix: a leading column of ones (intercept term) next to X.
X_b = np.hstack((np.ones_like(X), X))

# Scatter plot of the raw samples, saved as the first figure.
ax = plt.gca()
ax.plot(X, Y, '.')
ax.set_xlim(0, 4)
ax.set_ylim(0, 15)
ax.set_xlabel(r'x', fontsize=8)
ax.set_ylabel(r'y', fontsize=8)
ax.set_title('How to caclulate the mean squared error in python ?', fontsize=8)
plt.savefig("mean_squared_error_01.png", bbox_inches='tight')
#----- Pick an arbitrary (not fitted) linear model and draw it
# theta is a column vector [intercept, slope], matching a design matrix
# whose first column is all ones.
theta = np.array([[-1.4], [5.0]])
# Two x positions are enough to draw a straight line.
X_new = np.array([[0], [4]])
X_new_b = np.hstack((np.ones((2, 1)), X_new))
line_y = X_new_b @ theta

# Drawn on the current figure; nothing has closed it since the previous
# savefig, so the line is overlaid on whatever was already plotted there.
ax = plt.gca()
ax.plot(X_new, line_y, '-')
ax.set_xlim(0, 4)
ax.set_ylim(0, 15)
ax.set_xlabel(r'x', fontsize=8)
ax.set_ylabel(r'y', fontsize=8)
ax.set_title('How to caclulate the mean squared error in python ?', fontsize=8)
plt.savefig("mean_squared_error_02.png", bbox_inches='tight')
plt.close()
#----- Mean squared error of the model, computed directly with NumPy
# Predictions for every sample: design matrix times the parameter vector.
Y_predict = X_b.dot(theta)
print(Y_predict.shape, X_b.shape, theta.shape)
# np.mean divides by the actual number of elements, so the result stays
# correct if the sample count changes (no hand-maintained divisor).
mse = np.mean((Y_predict - Y) ** 2)
print('mse: ', mse)
#----- Cross-check the value against scikit-learn
from sklearn.metrics import mean_squared_error
print('mse (sklearn): ', mean_squared_error(Y, Y_predict))
#----- Calculate the mse on a regular (theta_0, theta_1) grid
# meshgrid uses its default 'xy' indexing: theta_0 varies along columns
# and theta_1 along rows of the returned 2-D arrays.
theta_0, theta_1 = np.meshgrid(np.arange(0, 10, 0.1), np.arange(0, 10, 0.1))
# Stack the flattened grids so each column of theta is one candidate
# (intercept, slope) pair; this rebinds the name `theta`.
theta = np.vstack((theta_0.ravel(), theta_1.ravel()))
# One column of predictions per grid point; Y broadcasts across columns.
Y_predict = X_b @ theta
# Average over samples (axis 0) rather than dividing by a literal count.
mse = np.mean((Y_predict - Y) ** 2, axis=0)
# Restore the grid layout from the meshgrid itself, so the reshape cannot
# drift out of sync with the arange bounds/step above.
mse = mse.reshape(theta_0.shape)
from matplotlib.colors import LogNorm
# NOTE(review): `figure` is not used in this section, and pylab is a legacy
# import path; both names are kept in case code elsewhere relies on them.
from pylab import figure, cm
# Rows are theta_1 and columns are theta_0, so with origin='lower' the
# x axis is theta_0 and the y axis is theta_1. Log colour scale.
plt.imshow(mse, origin='lower', norm=LogNorm(), extent=[0,10,0,10], cmap=cm.jet)
plt.title('How to caclulate the mean squared error in python ?',fontsize=8)
plt.xlabel(r'$\theta_0$',fontsize=8)
plt.ylabel(r'$\theta_1$',fontsize=8)
plt.savefig("mean_squared_error_03.png", bbox_inches='tight')
plt.close()
#----- MSE as a function of theta_1 for one fixed theta_0
# Column 20 of the grid; rows of `mse` run over theta_1.
mse_slice = mse[:, 20]
ax = plt.gca()
ax.plot(mse_slice)
ax.set_title('How to caclulate the mean squared error in python ?', fontsize=8)
ax.set_xlabel(r'$\theta_1$', fontsize=8)
ax.set_ylabel(r'mean square error', fontsize=8)
# The x axis is a grid index; put a tick every 10 indices and label the
# ticks 0..9 instead of showing the raw index.
positions = list(range(0, 100, 10))
labels = list(range(10))
plt.xticks(positions, labels)
ax.grid(linestyle='--')
plt.savefig("mean_squared_error_04.png", bbox_inches='tight')