import numpy as np #数值计算库
import matplotlib.pyplot as plt #绘图库
# 这里相当于是随机X维度X1,rand是随机[0,1)均匀分布 形状是100行1列
X = 2 * np.random.rand(100, 1)
# 人为的设置真实的Y一列,np.random.randn(100, 1)是设置error,randn是标准正态分布
y = 4 + 3 * X + np.random.randn(100, 1)
# 整合X0和X1构造出X矩阵 [100,2],这是由于权重参数w0恒为1即认为w0为偏置项
X_b = np.c_[np.ones((100, 1)), X]
# 常规等式求解theta
#np.linalg.inv()用于矩阵求逆
#a.T表示矩阵转置
#a.dot(b)为矩阵a,b进行点乘
theta_best = np.linalg.inv(X_b.T.dot(X_b)).dot(X_b.T).dot(y)
print(theta_best)
# 创建一个测试集,大小为两行一列
X_new = np.array([[0], [2]])
#把x0=1添加到测试矩阵中构造出两行两列的测试矩阵
X_new_b = np.c_[(np.ones((2, 1))), X_new]
#测试矩阵与权重矩阵相乘求出预测值
y_predict = X_new_b.dot(theta_best)
print(y_predict)
plt.plot(X_new, y_predict, 'r-')#绘制最优解的图线 r表示红色,-表示直线
plt.plot(X, y, 'b.') #绘制真实样本分布,b表示蓝色,.表示散点
plt.axis([0, 2, 0, 15]) #指定x,y轴的区域范围
plt.show()#显示图形
绘图结果如下:
可以看出我们利用解析解一步求出最优权重,该一次图线尽可能的穿过真实样本点。
解析解求解公式如下图: