Abstract
This article implements the SGD, Momentum, RMSprop, and Adam gradient-descent algorithms in pure Python and compares the results against PyTorch.
For the relevant theory and detailed explanations, please refer to the companion article 《常用梯度下降算法SGD, Momentum, RMSprop, Adam详解》 (Common Gradient-Descent Algorithms Explained: SGD, Momentum, RMSprop, Adam).
Article index:
https://blog.csdn.net/oBrightLamp/article/details/85067981
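As a quick reference, the update rules implemented below can be summarised as follows. This is a recap of the standard textbook formulas, not a quotation from the companion article; the notation is conventional ($\eta$: learning rate, $g_t$: gradient at step $t$; $\mu, \alpha, \beta_1, \beta_2$: decay rates; $\epsilon$: a small stabiliser). Note that the Momentum rule is written the way the code below implements it, with no dampening factor on $g_t$:

SGD:      $\theta_t = \theta_{t-1} - \eta \, g_t$
Momentum: $v_t = \mu \, v_{t-1} + g_t, \quad \theta_t = \theta_{t-1} - \eta \, v_t$
RMSprop:  $s_t = \alpha \, s_{t-1} + (1-\alpha) \, g_t^2, \quad \theta_t = \theta_{t-1} - \eta \, g_t / (\sqrt{s_t} + \epsilon)$
Adam:     $m_t = \beta_1 m_{t-1} + (1-\beta_1) \, g_t, \quad v_t = \beta_2 v_{t-1} + (1-\beta_2) \, g_t^2,$
          $\hat{m}_t = m_t / (1-\beta_1^t), \quad \hat{v}_t = v_t / (1-\beta_2^t), \quad \theta_t = \theta_{t-1} - \eta \, \hat{m}_t / (\sqrt{\hat{v}_t} + \epsilon)$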
Main Text
1. Algorithm Classes
File path: vanilla_nn/optim.py
import numpy as np


class SGD:
    def __init__(self, lr=0.01):
        self.lr = lr

    def __call__(self, params, grads):
        # Vanilla gradient descent: theta <- theta - lr * g, updated in place.
        params -= self.lr * grads


class Momentum:
    def __init__(self, lr=0.01, momentum=0.9):
        self.lr = lr
        self.momentum = momentum
        self.v = None  # velocity buffer, lazily sized to match params

    def __call__(self, params, grads):
        if self.v is None:
            self.v = np.zeros_like(params)
        # v <- momentum * v + g, then step along the accumulated velocity.
        self.v = self.momentum * self.v + grads
        params -= self.lr * self.v


class RMSProp:
    def __init__(self, lr=0.01, alpha=0.9, eps=1e-08):
        self.lr = lr
        self.alpha = alpha  # decay rate of the squared-gradient average
        self.eps = eps  # guards against division by zero
        self.v = None

    def __call__(self, params, grads):
        if self.v is None:
            self.v = np.zeros_like(params)
        # EMA of squared gradients; scale each step by its square root.
        self.v = self.alpha * self.v + (1 - self.alpha) * grads ** 2
        params -= self.lr * grads / (np.sqrt(self.v) + self.eps)
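All of the classes above share one calling convention: optimizer(params, grads) mutates the NumPy array params in place. The sketch below is not part of the article's code; it is a minimal sanity check under assumed toy data (the quadratic objective, the seed, and the hyperparameter values are arbitrary choices) that first minimises f(w) = 0.5 * ||w||^2 with the pure-Python SGD class, then compares a single Momentum step against torch.optim.SGD:

import numpy as np
import torch

# Part 1: minimise f(w) = 0.5 * ||w||^2; the gradient of f at w is w itself.
w = np.array([1.0, -2.0, 3.0])
sgd = SGD(lr=0.1)
for _ in range(100):
    sgd(w, w.copy())  # grad of f at the current w (copied for safety)
print(w)  # every entry should now be close to 0

# Part 2: one Momentum step, pure-Python class vs torch.optim.SGD.
rng = np.random.default_rng(0)
params_np = rng.standard_normal(5)
grads_np = rng.standard_normal(5)
params_pt = torch.tensor(params_np, requires_grad=True)  # same starting point

mom = Momentum(lr=0.01, momentum=0.9)
mom(params_np, grads_np)

opt = torch.optim.SGD([params_pt], lr=0.01, momentum=0.9)
params_pt.grad = torch.tensor(grads_np)
opt.step()

print(np.allclose(params_np, params_pt.detach().numpy()))  # expected: True

The agreement holds because torch.optim.SGD initialises its momentum buffer with the raw gradient on the first step, exactly like the lazily created v buffer in the Momentum class above.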