既然是梯度下降法,那么梯度就显得很关键了,首先了解梯度的概念:全部变量的偏导数汇总而成的向量。其中偏导数就是说存在两个或两个以上变量的函数,这个时候需要指定求某个变量的导数,其余的变量看成是常量就好,叫做这个变量的偏导数,实质还是求导数!
比如函数 f(x0,x1)=x0²+x1²,分别求x0和x1在x0=4和x1=5位置的偏导数
使用中心差分求偏导数
求x0的偏导数
def numerical_diff(f,x):
'''
中心差分
'''
h=1e-4#0.0001,如果是太小的值也不可以,计算机舍入之后为0
return (f(x+h)-f(x-h))/(2*h)
def f1(x0):
return x0**2+5**2
numerical_diff(f1,4)
8.000000000016882
结果和公式求的真导数8,基本一样(允许误差)
同理求x1的偏导数
def f2(x1):
return 4**2+x1**2
numerical_diff(f2,5)
9.999999999976694
结果和公式求的真导数10,基本一样
上面是单个分别求偏导数,下面是一次性求出所有变量的偏导数(梯度)
import numpy as np
def numerical_gradient(f,x):
h=1e-4#0.0001
grad=np.zeros_like(x)#生成x.shape一样的数组,用来保存偏导数
for i in range(x.size):
tmp=x[i]
x[i]=float(tmp)+h
fxh1=f(x)#f(x+h)
x[i]=float(tmp)-h
fxh2=f(x)#f(x-h)
grad[i]=(fxh1-fxh2)/(2*h)#中心差分
x[i]=tmp#还原值
return grad
def func(x):
'''
x0²+x1²+x2²+...
'''
return np.sum(x**2)
numerical_gradient(func,np.array([4.0,5.0]))#必须是浮点数数组,很奇怪求导函数里面转化没有起作用?
[ 8. 10. ]
另外如果是输入的是矩阵,也就是大于一维,代码修改如下:
def _numerical_gradient_no_batch(f,x):
h=1e-4#0.0001
grad=np.zeros_like(x)#生成x.shape一样的数组,用来保存偏导数
for i in range(x.size):
tmp=x[i]
x[i]=float(tmp)+h
fxh1=f(x)#f(x+h)
x[i]=float(tmp)-h
fxh2=f(x)#f(x-h)
grad[i]=(fxh1-fxh2)/(2*h)#中心差分
x[i]=tmp#还原值
return grad
def numerical_gradient(f,X):
if X.ndim==1:
return _numerical_gradient_no_batch(f,X)
else:
grad=np.zeros_like(X)
for i,x in enumerate(X):
grad[i]=_numerical_gradient_no_batch(f,x)
return grad
numerical_gradient(func,np.array([[4.0,5.0],[1,3]]))
array([[ 8., 10.],
[ 2., 6.]])
我们将梯度画图,更直观的感受下,也更好的了解梯度的特点
import numpy as np
import matplotlib.pylab as plt
x0 = np.arange(-3, 3, 0.5)
x1 = np.arange(-3, 3, 0.5)
X, Y = np.meshgrid(x0, x1)
X=X.flatten()
Y=Y.flatten()
grad=numerical_gradient(func,np.array([X,Y]))
plt.figure()
plt.quiver(X,Y,-grad[0],-grad[1],angles='xy',color='#FF0000')
plt.grid()
plt.xlabel('x0')
plt.ylabel('x1')
plt.show()
我们通过quiver函数画出的箭头图形可以看出,f(x0,x1)=x0²+x1²的梯度呈现为有向向量(箭头),直观感受就是梯度像指南针一样,所有箭头都指向同一点,其次,我们发现离“最低处”越远,箭头越大。
附带再次说明下quiver画箭头的函数
import numpy as np
import matplotlib.pylab as plt
plt.grid()
#plt.quiver(3,4,3,5,angles='xy',color='#FF0000')#(3,4)坐标指向(3,5)坐标方向
plt.quiver((4,3,3),(5,5,7),(8,30,7),(8,6,7),angles='xy',color='#FF0000')
#(4,5)、(3,5)、(3,7)分别指向(8,8)、(30,6)、(7,7),从箭头又可以看出,离的远的箭头就比较大比较明显
plt.show()