Reference blog: https://blog.csdn.net/huahuazhu/article/details/73385362
The sample code in the referenced blog fits a straight line; after working through it, I adapted it to fit a curve (a quadratic).
#!/usr/bin/python
#coding=utf-8
import numpy as np
import matplotlib.pyplot as plt
# Construct the training data
x = np.arange(0., 10., 0.2)
# Number of training samples
m = len(x)
# Print the number of training samples
print(m)
# Bias column: a length-m vector with every element equal to 1.0
x0 = np.full(m, 1.0)
# Quadratic feature column
x1 = x ** 2
# Stack into an m*3 design matrix: the bias b is the first component of the weight
# vector, w1 (linear term) the second, and w2 (quadratic term) the third
input_data = np.vstack([x0, x, x1]).T
# Targets: y = x^2 + 2x + 5 plus Gaussian noise
target_data = x ** 2 + 2 * x + 5 + np.random.randn(m)
# Two stopping conditions
loop_max = 10000  # maximum number of iterations (guards against an endless loop)
epsilon = 1e-3    # convergence threshold

# Initialize the weights
np.random.seed(0)
theta = np.random.randn(3)  # random initial theta
alpha = 0.00001  # learning rate (too large causes oscillation/divergence, too small slows convergence)
diff = 0.
# Holds the previous theta for the convergence test
error = np.zeros(3)
count = 0   # iteration counter
finish = 0  # termination flag
while count < loop_max:
    count += 1
    # Standard (batch) gradient descent sums the error over all samples before each
    # weight update, while stochastic gradient descent updates the weights from a single
    # training sample; the batch update therefore needs more computation per step.
    sum_m = np.zeros(3)
    for i in range(m):
        dif = (np.dot(theta, input_data[i]) - target_data[i]) * input_data[i]
        sum_m = sum_m + dif  # if alpha is too large, sum_m can overflow during the iterations
    theta = theta - alpha * sum_m  # mind the step size alpha: too large causes oscillation
    # Check for convergence
    if np.linalg.norm(theta - error) < epsilon:
        finish = 1
        break
    else:
        error = theta
    print('loop count = %d' % count, '\tw:', theta)
print('loop count = %d' % count, '\tw:', theta)
# Plot the sample points and the fitted curve
plt.plot(x, target_data, 'g.')
plt.plot(x, theta[2] * x ** 2 + theta[1] * x + theta[0], 'r')
plt.show()
Experimental results:
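As a rough numeric check (not part of the original blog), the theta found by gradient descent can be compared with the closed-form least-squares solution for the same design matrix; a minimal sketch, assuming it is run right after the script above so that input_data, target_data and theta are still defined:

# Closed-form least-squares fit on the same design matrix, for comparison only.
# Assumes input_data, target_data and theta from the script above are in scope.
theta_ls, residuals, rank, sv = np.linalg.lstsq(input_data, target_data, rcond=None)
print('gradient descent theta:', theta)
print('least-squares theta   :', theta_ls)  # should land near [5, 2, 1], the true coefficients plus noise

With the small learning rate (alpha = 0.00001) and the loose stopping threshold (epsilon = 1e-3), gradient descent may stop some distance away from the least-squares optimum; tightening epsilon or raising loop_max brings the two vectors closer together.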
Some additional notes on the functions used:
1. scipy.stats.linregress
scipy.stats.linregress(x, y=None)
Calculate a linear least-squares regression for two sets of measurements.

Parameters:
x, y : array_like
    Two sets of measurements. Both arrays should have the same length. If only x is given (and y=None), then it must be a two-dimensional array where one dimension has length 2. The two sets of measurements are then found by splitting the array along the length-2 dimension.

Returns:
slope : float
    Slope of the regression line.
intercept : float
    Intercept of the regression line.
rvalue : float
    Correlation coefficient.
pvalue : float
    Two-sided p-value for a hypothesis test whose null hypothesis is that the slope is zero, using the Wald test with the t-distribution of the test statistic.
stderr : float
    Standard error of the estimated slope (gradient).
Example:
>>> import matplotlib.pyplot as plt
>>> from scipy import stats
>>> import numpy as np
>>> np.random.seed(100)
>>> x = np.random.random(10)
>>> print(x)
[0.54340494 0.27836939 0.42451759 0.84477613 0.00471886 0.12156912
0.67074908 0.82585276 0.13670659 0.57509333]
>>> y = np.random.random(10)
>>> print(y)
[0.89132195 0.20920212 0.18532822 0.10837689 0.21969749 0.97862378
0.81168315 0.17194101 0.81622475 0.27407375]
>>> slope, intercept, r_value, p_value, std_err = stats.linregress(x, y)
>>> print('slope=%s, intercept=%s'%(slope, intercept))
slope=-0.3442654203664301, intercept=0.6190108485958797
>>> print('r_value=%.4f, p_value=%.4f, std_err=%.4f' % (r_value, p_value, std_err))
r_value=-0.2888, p_value=0.4183, std_err=0.4035
>>> plt.plot(x, y, 'o', label='original data')
[<matplotlib.lines.Line2D object at 0x000001AC5B23B198>]
>>> plt.plot(x, intercept + slope*x, 'r', label='fitted line')
[<matplotlib.lines.Line2D object at 0x000001AC534240B8>]
>>> plt.legend()
<matplotlib.legend.Legend object at 0x000001AC5B23B8D0>
>>> plt.show()
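A common follow-up is to report the coefficient of determination, which for simple linear regression is just the square of the returned correlation coefficient; continuing the session above:

>>> print('R-squared: %.4f' % r_value**2)
R-squared: 0.0834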
2. np.random.seed(0)
Purpose: makes the generated random numbers reproducible (predictable).
Without setting a seed, each call produces different numbers:
>>> import numpy as np
>>> np.random.random(5)
array([0.75967512, 0.981992 , 0.15152915, 0.49788026, 0.21685212])
>>> np.random.random(5)
array([0.14585632, 0.6247151 , 0.7466104 , 0.44686791, 0.19680322])
If the same seed is set before each call, the same numbers are generated every time:
>>> np.random.seed(5)
>>> np.random.random(5)
array([0.22199317, 0.87073231, 0.20671916, 0.91861091, 0.48841119])
>>> np.random.seed(5)
>>> np.random.random(5)
array([0.22199317, 0.87073231, 0.20671916, 0.91861091, 0.48841119])
3. np.full((1,2), 1)
Creates a 1*2 array with every element equal to 1:
>>> x = np.full((1,2),1)
>>> print(x)
[[1 1]]
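For the one-dimensional bias column in the fitting script, np.full(m, 1.0) and np.ones(m) are equivalent; a quick check (the exact spacing of the printed array may differ between NumPy versions):

>>> print(np.full(3, 1.0))
[1. 1. 1.]
>>> print(np.ones(3))
[1. 1. 1.]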
4. np.vstack([x,y]), np.stack([x,y]), np.hstack([x,y])
>>> x=(1,2,3)
>>> y=(4,5,6)
>>> z = np.vstack([x,y])
>>> print(z)
[[1 2 3]
[4 5 6]]
>>> z = np.hstack([x,y])
>>> print(z)
[1 2 3 4 5 6]
>>>
>>> z = np.stack([x,y])
>>> print(z)
[[1 2 3]
[4 5 6]]
>>> z = np.stack([x,y],1)
>>> print(z)
[[1 4]
[2 5]
[3 6]]
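This is exactly how the m*3 design matrix in the fitting script is built: np.vstack([x0, x, x1]) stacks the three length-m vectors as rows, and .T turns them into columns. np.stack with axis 1 produces the same matrix in one step; a small sketch reusing the x and y tuples from above (x0 here is a hypothetical bias tuple added for illustration):

>>> x0 = (1, 1, 1)
>>> print(np.vstack([x0, x, y]).T)
[[1 1 4]
 [1 2 5]
 [1 3 6]]
>>> print(np.stack([x0, x, y], 1))
[[1 1 4]
 [1 2 5]
 [1 3 6]]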