对pytorch optimizer中state_dict、state、param_groups的简要理解

先说结论:

  • state_dict():一个dict,里面有两个key(stateparam_groups),

    • state这个key对应的value是各个权重对应的优化器状态。具体来说,一个model有很多权重,model.parameters()会打印出该模型的各层的权重,比如使用Adam,每层权重都有一个momentum和variance,形状与权重相同,还有该层当前更新到的步数。state_dict()['state']是一个dict,每个key-value item结构如下:
      该权重在model.parameters()中的位置 : {
      	'step': tensor, 
      	'exp_avg': tensor, # exp_avg: exponential moving average of gradient values
      	'exp_avg_sq: tensor # exp_avg_sq: exponential moving average of squared gradient values
      
    • param_groups这个key对应的value是一个list,其中每个元素都是超参数组成的一个dict,因为不同的权重可以使用不同的超参数,所以需要使用list来表示,而且dict中params表示该超参数配置作用于哪些权重。state_dict()['param_groups']是一个list,每个元素结构如下
      {'lr': 0.01, 'weight_decay': 0,  ...  , 'params', [该超参数配置作用于的权重的位置]}
      
  • state:是一个defaultdict,包含的信息类似于state_dict()['state']+model.parameters(),具体来说,每个key-value item结构如下:

    param_tensor :{
    	'step': tensor, 
    	'exp_avg': tensor, 
    	'exp_avg_sq': tensor,	
    }
    
  • param_groups:是一个list,包含的信息类似于state_dict()['param_groups']+model.parameters(),具体来说,每个元素结构如下:

    {
    	'params': [param1, param2, ...]
    	'lr': 0.01, 
    	'weight_decay': 0, 
    	...
    	# 注意相较于state_dict()['param_groups'],原来'params'这个key对应的是param的索引位置,现在直接就是tensor了
    }
    

示例代码:

import torch
from torch.nn import Module
from torch.optim import Adam


class MyModel(Module):
    def __init__(self, in_dim, hidden_dim):
        super(MyModel, self).__init__()
        self.linear = torch.nn.Linear(in_features=in_dim, out_features=hidden_dim, bias=True)
        self.linear2 = torch.nn.Linear(in_features=hidden_dim, out_features=in_dim, bias=False)
    def forward(self, x):
        y = self.linear(x)
        out = self.linear2(y)
        return out


in_dim = 5
hidden_dim = 2
model = MyModel(in_dim=in_dim, hidden_dim=hidden_dim)

optimier = Adam([
    {
   'params': model.linear.parameters(), 'lr': 0.05},
    {
   'params': model.linear2.parameters()}
], lr=0.01)


x = torch.randn((in_dim))
out = model(x)
loss = torch.sum(out, dim=-1)
optimier.zero_grad()
loss.backward()
optimier.step()

print('#' * 100)
print(optimier.state_dict())

print('#' * 100)
print(optimier.state)

print
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值