每一个点和直线产生的代价为当x的取值相同时,点对应的y值减去直线对应的y值的平方,变成符号表示就是
(
y
−
h
)
2
(y - h) ^ 2
(y−h)2
而代价函数就是求得一个平均代价,考虑到在梯度下降法的时候需要求导,为了能够消除求导多出来的乘2,所以需要对原函数除2,因此最终代价函数的表示为
J
=
1
2
m
Σ
(
y
−
h
)
2
J = \frac{1}{2m}\Sigma(y - h) ^ 2
J=2m1Σ(y−h)2,其中m为样本个数
我们想要的就是让这个
J
J
J最小
梯度下降
梯度下降法就是用来解决让
J
J
J最小的问题的,什么是梯度下降呢?
假设给定一个一元二次函数
y
=
x
2
y = x ^2
y=x2,问
y
y
y最小的时候,
x
x
x的取值,那肯定是在
x
=
0
x = 0
x=0的时候,但是机器并不知道是
x
=
0
x = 0
x=0的时候,除非给出求解公式,不过一般情况下不存在求解公式,这时候就需要对函数进行求导,得到所谓的极值点,这个时候可能就是最小的时候。
而梯度下降法则是利用了梯度的特性,一开始随便取一个值
x
x
x,然后减去一个学习率乘以在这个时候导数的值,即
x
−
l
r
∗
2
x
x - lr * 2x
x−lr∗2x,最终总能够无限逼近于
x
=
0
x = 0
x=0,虽然最后得到的结果可能是
x
=
0.000001
x = 0.000001
x=0.000001,但是对我们来说足够了
对于代价函数来说,梯度下降法同样适用,由此可以求得一个
J
J
J的值无限逼近于最小值时的
w
w
w,
w
w
w是直线的权值,因为我们已经知道了
x
x
x的取值范围,所以只要求出直线的
k
k
k和
b
b
b就行了
具体分析
通过上述两条概念得出,为了能够得到我们想要的直线,其实只要初始化一个
b
b
b,一个
k
k
k,然后用梯度下降法对
b
b
b和
k
k
k不断更新就行了,即
b
=
b
−
l
r
∗
1
m
Σ
(
h
−
y
)
k
=
k
−
l
r
∗
1
m
Σ
(
h
−
y
)
∗
x
i
b = b - lr * \frac{1}{m}\Sigma(h - y)\\ k = k - lr * \frac{1}{m}\Sigma(h - y) * x_{i}
b=b−lr∗m1Σ(h−y)k=k−lr∗m1Σ(h−y)∗xi
为什么上述两个公式会不同,因为对
J
J
J求
b
b
b和
k
k
k的偏导就是这个结果
自定义python代码实现一元线性回归
import matplotlib.pyplot as plt
# 代价函数deflose_function(b, k, x_data, y_data):'''
求代价必须传入b,k,x_data,y_data
'''# 求代价,最后的除2是为了1/2m设计的
total_error =0# m是样本个数
m =len(x_data)for i inrange(m):
total_error +=(y_data[i]-(k * x_data[i]+ b))**2return total_error / m /2# 梯度下降函数defgradient_descent(b, k, x_data, y_data, epochs, lr):# m是样本个数
m =len(x_data)# epochs是迭代次数for i inrange(epochs):
b_grad =0
k_grad =0for j inrange(len(x_data)):
b_grad +=(k * x_data[j]+ b - y_data[j])/ m
k_grad +=(k * x_data[j]+ b - y_data[j])* x_data[j]/ m
b = b - lr * b_grad
k = k - lr * k_grad
# 记录迭代次数和此时的代价print("第{0}次迭代,lose={1}".format(i+1, lose_function(b, k, x_data, y_data)))return b, k
if __name__ =='__main__':
x_data =[1,2,3,4,5,6,7]
y_data =[3,4.2,5,5.8,7.5,8.3,9]
plt.figure()
plt.scatter(x_data, y_data)# 设置初始b和k的值
k =1
b =2# 通过梯度下降法不断迭代更新得到一个更小的代价
b, k = gradient_descent(b, k, x_data, y_data,50,0.0001)print(b, k)# 得到b和k之后,画出拟合直线
plt.plot(x_data,[k * x + b for x in x_data], c='red')
plt.show()