[Pytorch] 参数保存剖析

一般Pytorch会将训练好的模型保存至 xxx.pth 文件中
常用命令:
torch.load()
torch.save()

详细解剖其内部:
.pth 文件实质上是一个简单的字典文件

module.features.0.weight 
module.features.0.bias
module.features.1.weight 
module.features.1.bias 
module.features.1.running_mean
module.features.1.running_var
module.features.3.weight
module.features.3.bias
module.features.4.weight 
module.features.4.bias 
module.features.4.running_mean 
module.features.4.running_var
....
module.classifier.weight
module.classifier.bias
DataParallel(
  (module): VGG(
    # 这里的 'features' 其实是自定的名称,下面的'classifier' 同理
    (features): Sequential(
      (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))  # 对应上面的 features.0.weight, bias 
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) # 对应上面的 *.1.*
      (2): ReLU(inplace)  # 激活曾没有参数所以直接跳过
      (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)
      (5): ReLU(inplace)
      (6): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1), ceil_mode=False)  # 池化曾也没有参数
...
      (43): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1), ceil_mode=False)
      (44): AvgPool2d(kernel_size=1, stride=1, padding=0, ceil_mode=False, count_include_pad=True)
    )
    (classifier): Linear(in_features=512, out_features=10, bias=True)
  )
)

上面的DataParallel 是因为添加了
net = torch.nn.DataParallel(net)
这个操作使得网络可以在多GPU 上训练

VGG(
....

pytorch 的参数存储十分简单,如果你想自定义载入的话,直接修改net.state_dict()中的参数就可以了,和python的字典处理一样

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值