基本概念
这几天在看深度学习这本书,正好看到梯度下降这里,想想好早之前看梯度下降一直不明白,这里就将其总结一下;
函数在某一点的梯度是这样一个向量,它的方向与取得最大方向导数的方向一致(即:变化最快的那个方向就是梯度的方向),而它的模为方向导数的最大值。
这里注意三点:
1)梯度是一个向量,即有方向有大小;
2)梯度的方向是最大方向导数的方向;
3)梯度的值是最大方向导数的值。
现有一函数f(x):
f
(
x
)
=
x
0
2
+
x
1
2
f(x) = x_0^2+x_1^2
f(x)=x02+x12
也可以写成
f
(
x
0
,
x
1
)
=
x
0
2
+
x
1
2
f(x_0,x_1) = x_0^2+x_1^2
f(x0,x1)=x02+x12
那么
(
∂
f
∂
x
0
,
∂
f
∂
x
1
)
(\frac{\partial f}{\partial x_0} ,\frac{\partial f}{\partial x_1} )
(∂x0∂f,∂x1∂f)即为函数f(x)的梯度;
梯度大小的计算
∂
f
∂
x
0
=
f
(
x
0
+
h
,
x
1
)
−
f
(
x
0
−
h
,
x
1
)
2
h
\frac{\partial f}{\partial x_0}=\frac{f(x_0+h,x_1)-f(x_0-h,x_1)}{2h}
∂x0∂f=2hf(x0+h,x1)−f(x0−h,x1)
∂
f
∂
x
1
=
f
(
x
0
,
x
1
+
h
)
−
f
(
x
0
,
x
1
−
h
)
2
h
\frac{\partial f}{\partial x_1}=\frac{f(x_0,x_1+h)-f(x_0,x_1-h)}{2h}
∂x1∂f=2hf(x0,x1+h)−f(x0,x1−h)
代码
import numpy as np
import matplotlib.pylab as plt
# def mean_squared_error(y,t):
# return 0.5*np.sum((y-t)**2)
#
# def cross_entropy_error(y,t):
# delta= 1e-7
# return -np.sum(t*np.log(y + delta))
#
# array1 = np.random.choice(6000,10)
# # print(array1)
#
#
# def f1(x):
# return 0.01*x**2 + 0.1*x
#
#
# x = np.arange(0.0,20.0,0.1)
# # print(x)
# y = f1(x)
# plt.xlabel("x")
# plt.ylabel("f(x)")
# plt.plot(x,y)
# plt.show()
# x = np.arange(-5.0,5.0,0.1) 这里是构造一个等差数列,起始值为-5,终值为-5;等差为0.1
# y = sigmoid(x)
# plt.plot(x,y) #折线图
# plt.ylim(-0.1,1.1) #设置y坐标轴范围
# plt.show() #画图,显示图像
def f2(x):
return x[0]**2 + x[1]**2
def _numerical_gradient_no_batch(f, x):
h = 1e-4 # 0.0001
grad = np.zeros_like(x) #[0,0],生成了和x形状相同的数组
for idx in range(x.size):
tmp_val = x[idx]
x[idx] = float(tmp_val) + h
print(x[idx])
fxh1 = f(x) # f(x+h)
print(fxh1)
x[idx] = tmp_val - h
print(x[idx])
fxh2 = f(x) # f(x-h)
print(fxh2)
grad[idx] = (fxh1 - fxh2) / (2 * h) #这是在求斜率
print(grad[idx])
x[idx] = tmp_val # 还原值
print(x[idx])
print(grad)
return grad
def gradient_descent(f,init_x,lr = 0.01,step_num = 100):
x = init_x
for i in range(step_num):
grad = _numerical_gradient_no_batch(f,x)
x -= lr * grad
print(x)
return x
arr = np.array([3,4],dtype = np.float) ##这里需要加上类型dtype = np.float,不然后面会遇到类型转换错误
# print(arr)
# _numerical_gradient_no_batch(f2,arr)
gradient_descent(f2,arr)
参考
https://blog.csdn.net/wireless_com/article/details/70596155
https://blog.csdn.net/walilk/article/details/50978864
《深度学习入门》