python 绘图实例2

29 篇文章 2 订阅

转载:https://moonbooks.org/Articles/How-to-calculate-a-root-mean-square-using-python-/
如何利用Python计算线性模型拟合结果的均方根:
完整代码:

import matplotlib.pyplot as plt
import numpy as np

X = 4 * np.random.rand(1000,1)
X_b = np.c_[np.ones((1000,1)), X]

Y = 2 + 3 * X + np.random.randn(1000,1)

plt.plot(X,Y,'.')

plt.xlim(0,4)
plt.ylim(0,15)

plt.xlabel(r'x',fontsize=8)
plt.ylabel(r'y',fontsize=8)

plt.title('How to caclulate the mean squared error in  python ?',fontsize=8)

plt.savefig("mean_squared_error_01.png", bbox_inches='tight')

#----- Let's take one random linear model

theta = np.array([[-1.4],[5.0]])

X_new = np.array([[0],[4]])
X_new_b = np.c_[np.ones((2,1)), X_new]

plt.plot(X_new, X_new_b.dot( theta ), '-')

plt.xlim(0,4)
plt.ylim(0,15)

plt.xlabel(r'x',fontsize=8)
plt.ylabel(r'y',fontsize=8)

plt.title('How to caclulate the mean squared error in  python ?',fontsize=8)

plt.savefig("mean_squared_error_02.png", bbox_inches='tight')

plt.close()

#----- using python

Y_predict = X_b.dot( theta )

print(Y_predict.shape, X_b.shape, theta.shape)

mse = np.sum( (Y_predict-Y)**2 ) / 1000.0

print('mse: ', mse)

#----- using sklearn

from sklearn.metrics import mean_squared_error

print('mse (sklearn): ', mean_squared_error(Y,Y_predict))

#----- Calculate the mse using a grid search

theta_0, theta_1 = np.meshgrid(np.arange(0, 10, 0.1), np.arange(0, 10, 0.1))

theta = np.vstack((theta_0.ravel(), theta_1.ravel()))

Y_predict = X_b @ theta

mse = np.sum( (Y_predict-Y)**2, axis=0 ) / 1000.0

mse = mse.reshape(100,100)

from matplotlib.colors import LogNorm
from pylab import figure, cm

plt.imshow(mse, origin='lower', norm=LogNorm(), extent=[0,10,0,10], cmap=cm.jet)

plt.title('How to caclulate the mean squared error in  python ?',fontsize=8)

plt.xlabel(r'$\theta_0$',fontsize=8)
plt.ylabel(r'$\theta_1$',fontsize=8)

plt.savefig("mean_squared_error_03.png", bbox_inches='tight')

#plt.show()

plt.close()

#----- plot theta_1 for a given theta_0

plt.plot(mse[:,20])

plt.title('How to caclulate the mean squared error in  python ?',fontsize=8)

plt.xlabel(r'$\theta_1$',fontsize=8)
plt.ylabel(r'mean square error',fontsize=8)

positions = [i*10 for i in range(10)]
labels = [i for i in range(10)]

plt.xticks(positions, labels)

plt.grid(linestyle='--')

plt.savefig("mean_squared_error_04.png", bbox_inches='tight')

#plt.show()
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值