Momentum 参照小球在碗中滚动的物理规则来更新参数。
AdaGrad 为每个参数自适应地调整更新步长。
Adam 是 2015 年提出的新方法，它是 Momentum 和 AdaGrad 的结合体，融合了两种方法的优势（并对一阶、二阶矩估计做了偏差修正）。
代码：
# Adam.py
import numpy as np
import matplotlib.pyplot as plt
class Adam:
    """Adam optimizer: Momentum-style first moment + AdaGrad/RMSProp-style second moment.

    Keeps running estimates ``m`` (mean of gradients) and ``v`` (mean of squared
    gradients) for every key of ``params`` and applies a bias-corrected step on
    each call to :meth:`update`.
    """

    def __init__(self, lr=0.01, beta1=0.9, beta2=0.999):
        self.lr = lr
        self.beta1 = beta1
        self.beta2 = beta2
        self.iter = 0  # number of update() calls so far; drives bias correction
        self.m = None  # per-key first-moment estimates, created on first update()
        self.v = None  # per-key second-moment estimates, created on first update()

    def update(self, params, grads):
        """Apply one Adam step to ``params`` in place and return it.

        Args:
            params: dict mapping key -> current parameter value (scalar or array).
            grads: either a dict holding a gradient under each key of ``params``,
                or a sequence/array indexed by ``int(key)`` (the layout used when
                keys are '0', '1', ...).

        Returns:
            The same ``params`` dict object, after being modified in place.
        """
        if self.m is None:
            self.m, self.v = {}, {}
            for key, val in params.items():
                self.m[key] = np.zeros_like(val)
                self.v[key] = np.zeros_like(val)

        self.iter += 1
        # Both bias corrections folded into one effective learning rate.
        lr_t = self.lr * np.sqrt(1.0 - self.beta2**self.iter) / (1.0 - self.beta1**self.iter)

        for key in params:
            # Prefer a gradient stored under the same key; otherwise fall back to
            # positional indexing by int(key), as the original code did.
            if isinstance(grads, dict) and key in grads:
                g = grads[key]
            else:
                g = grads[int(key)]
            # Exponential moving averages, written as m += (1 - beta) * (g - m).
            self.m[key] += (1 - self.beta1) * (g - self.m[key])
            self.v[key] += (1 - self.beta2) * (g**2 - self.v[key])
            # 1e-7 keeps the division finite when v is zero.
            params[key] -= lr_t * self.m[key] / (np.sqrt(self.v[key]) + 1e-7)
        return params
def numerical_gradient(f, x):
    """Estimate the gradient of ``f`` at ``x`` with central differences.

    Args:
        f: callable taking a float ndarray of coordinates and returning a scalar.
        x: the point to differentiate at. Either a dict whose values are the
            coordinates (taken in insertion order) or an array-like.

    Returns:
        A float ndarray shaped like the coordinate array, where entry i is
        (f(x + h*e_i) - f(x - h*e_i)) / (2h).
    """
    h = 1e-4
    # Use the point actually passed in (the previous version read the global
    # ``init_x`` instead). np.array always copies, so the caller's data is never
    # perturbed.
    if isinstance(x, dict):
        x = np.array(list(x.values()), dtype=float)
    else:
        x = np.array(x, dtype=float)
    grad = np.zeros_like(x)
    for idx in np.ndindex(x.shape):
        temp = x[idx]
        x[idx] = temp + h
        fxh1 = f(x)
        x[idx] = temp - h
        fxh2 = f(x)
        grad[idx] = (fxh1 - fxh2) / (2 * h)
        x[idx] = temp  # restore before perturbing the next coordinate
    return grad
def func2(x):
    """Return x[0]**2 / 20 + x[1]**2, i.e. 0.05*x^2 + y^2.

    Only x[0] and x[1] are read. The arithmetic is elementwise, so passing a
    stacked grid such as ``np.array([X, Y])`` evaluates the function on the
    whole grid.
    """
    first, second = x[0], x[1]
    return first ** 2 / 20 + second ** 2
def adam_update(init_x, stepnum, optimizer=None):
    """Run ``stepnum`` optimizer steps on ``func2`` starting from ``init_x``.

    Args:
        init_x: dict of coordinates (e.g. '0' -> x, '1' -> y). It is passed
            straight to ``optimizer.update``; an optimizer that updates in place
            (such as ``Adam``) leaves the final point in this dict.
        stepnum: number of update steps to run.
        optimizer: object with an ``update(params, grads)`` method that returns
            the new params. Defaults to the module-level optimizer ``m``, so
            existing two-argument calls keep working.

    Returns:
        (x, x_history): the final params dict, and an array with one row per
        step holding the coordinates recorded *before* that step.
    """
    if optimizer is None:
        optimizer = m  # backward-compatible fallback to the script-level optimizer
    x = init_x
    x_history = []
    for _ in range(stepnum):
        # np.array builds a fresh array, so later in-place updates to x don't
        # alter the recorded rows.
        x_history.append(np.array(list(x.values())))
        grad = numerical_gradient(func2, x)
        x = optimizer.update(x, grad)
    return x, np.array(x_history)
init_x = {}  # starting point; keys '0' and '1' hold the x and y coordinates
init_x['0'] = -7.0
init_x['1'] = 2.0
learning_rate = 0.25
m = Adam(lr=learning_rate)
stepnum = 40
x, x_history = adam_update(init_x=init_x, stepnum=stepnum)

# Grid for the contour plot (this rebinds x; the optimizer result above is not used further).
axis_range = 10
x = np.arange(-axis_range, axis_range, 0.05)
y = np.arange(-axis_range, axis_range, 0.05)
X, Y = np.meshgrid(x, y)
z = np.array([X, Y])  # stacked so func2(z) sees z[0] = X and z[1] = Y over the grid

# Contour lines of func2 at levels 0, 2, 4, 6, 8.
# The 3D-only ``zdir`` keyword is not passed: a 2D contour plot does not use it.
plt.figure()
plt.contour(x, y, func2(z), np.arange(0, 10, 2), cmap='binary')
# Every point visited by the optimizer.
plt.plot(x_history[:, 0], x_history[:, 1], '+', color='blue')
# Connect consecutive points. One polyline draws all len(x_history) - 1 segments,
# including the last one.
plt.plot(x_history[:, 0], x_history[:, 1], color='blue')
# Mark the minimum of func2, which is at the origin.
plt.plot(0, 0, 'o', color='r')
plt.xlabel('x')
plt.ylabel('y')
plt.title('Adam 0.05x^2 + y^2 ')
plt.show()