一,优化器的基本知识
1.概念
- 导数:函数在指定坐标轴上的变化率
- 方向导数:指定方向上的变化率
- 梯度:一个向量,方向为方向导数取得最大值的方向
2.基本属性
- defaults:优化器超参数
- state:参数的缓存,如momentum
- param_groups:输入要优化的参数组
- _step_count:记录更新次数,学习率调整中使用
3.基本方法
- zero_grad():清空所管理的参数的梯度
- step():梯度一步更新
- add_param_group():添加参数组
4.用于断点续训练的方法
- state_dict():获取优化器当前状态信息字典
- load_state_dict():加载状态信息字典—》模型加载参数时使用
import os
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
import torch
import torch.optim as optim
from tools.common_tools import set_seed
set_seed(1)
weight = torch.randn((2, 2), requires_grad=True)
weight.grad = torch.ones((2, 2))
optimizer = optim.SGD([weight], lr=0.1)
flag = 0
if flag:
print("weight before step:{}".format(weight.data))
optimizer.step()
print("weight after step:{}".format(weight.data))
flag = 0
if flag:
print("weight before step:{}".format(weight.data))
optimizer.step()
print("weight after step:{}".format(weight.data))
print("weight in optimizer:{}\nweight in weight:{}\n".format(id(optimizer.param_groups[0]['params'][0]), id(weight)))
print("weight.grad is {}\n".format(weight.grad))
optimizer.zero_grad()
print("after optimizer.zero_grad(), weight.grad is\n{}".format(weight.grad))
flag = 0
if flag:
print("optimizer.param_groups is\n{}".format(optimizer.param_groups))
w2 = torch.randn((3, 3), requires_grad=True)
optimizer.add_param_group({"params": w2, 'lr': 0.0001})
print("optimizer.param_groups is\n{}".format(optimizer.param_groups))
flag = 0
if flag:
optimizer = optim.SGD([weight], lr=0.1, momentum=0.9)
opt_state_dict = optimizer.state_dict()
print("state_dict before step:\n", opt_state_dict)
for i in range(10):
optimizer.step()
print("state_dict after step:\n", optimizer.state_dict())
torch.save(optimizer.state_dict(), os.path.join(BASE_DIR, "optimizer_state_dict.pkl"))
flag = 1
if flag:
optimizer = optim.SGD([weight], lr=0.1, momentum=0.9)
state_dict = torch.load(os.path.join(BASE_DIR, "optimizer_state_dict.pkl"))
print("state_dict before load state:\n", optimizer.state_dict())
optimizer.load_state_dict(state_dict)
print("state_dict after load state:\n", optimizer.state_dict())
二,动量(Momentam):结合当前梯度与上次更新信息,用于当前更新
- 类惯性定律,总向一个方向更新时,更新的步子会变大
- 值越接近1, 对先前数据的记忆周期越长
![在这里插入图片描述](https://i-blog.csdnimg.cn/blog_migrate/14bad39bd8d7d7e20505d91365ab11c0.png)
三,常用优化器
![在这里插入图片描述](https://i-blog.csdnimg.cn/blog_migrate/4caa4a443828925a81fdd851d3775396.png)
![在这里插入图片描述](https://i-blog.csdnimg.cn/blog_migrate/99b53848a0f7e0cf37afc066d4b04cd3.png)
![在这里插入图片描述](https://i-blog.csdnimg.cn/blog_migrate/abbcab890061b4c845a23a46bb3c6765.png)