pytorch笔记:构建模型流程,nn.Module,state_dict,parameters,modules

本文介绍了PyTorch中Dataset的初始化,transform的调用,以及如何构建神经网络模型。重点讨论了model_state_dict和optimizer_state_dict在保存模型时的作用,提到了torchsummaryX用于查看模型参数。此外,解释了buffers与TensorFlow的槽变量相似性,以及如何注册和使用。还涵盖了model的参数管理,如train/eval模式切换,以及如何保存和加载模型状态。最后,建议使用Docker来创建训练环境以保持一致性。
摘要由CSDN通过智能技术生成

来自B站视频官网教程API查阅详细信息

  • Dataset 的__init__中定义 transform 一般通过 __getitem__来调用
  • BUILD THE NEURAL NETWORK 中是 pytorch 构建模型的简单流程,PYTORCH RECIPES 是更相关详细内容,训练是要保存 checkpoint,包括 model_state_dict 和 optimizer_state_dict 等
  • 类似 tf 中 summary 模型(查看参数数目和分布)的方法在官方 pytorch 中没有,可以通过torchsummaryX 来实现
  • 简单看模型的参数数目可以
sum(p.numel() for p in model.parameters())
  • module 和 API 介绍
  • pytorch 的 buffers 和 tf 中的(槽变量?)类似
  • 关于 buffer
[docs]    def register_buffer(self, name: str, tensor: Optional[Tensor], persistent: bool = True) -> None:
        r"""Adds a buffer to the module.

        This is typically used to register a buffer that should not to be
        considered a model parameter. For example, BatchNorm's ``running_mean``
        is not a parameter, but is part of the module's state. Buffers, by
        default, are persistent and will be saved alongside parameters. This
        behavior can be changed by setting :attr:`persistent` to ``False``. The
        only difference between a persistent buffer and a non-persistent buffer
        is that the latter will not be a part of this module's
        :attr:`state_dict`.

        Buffers can be accessed as attributes using given names.
  • module源码 介绍
  • parameter是tensor的一个子类,作为模型参数必须设为 parameter 类型,通过register_parameter 注册
  • 改变数据类型 model.to(torch.float32)
  • 类的基本属性有 model._module,model._parameters,model._buffers,返回当前 module 的内容,不会返回子module的内容,不是为了公有访问设计的
  • 类的基本方法有 model.parameters() 或 model.named_parameters(),返回的是迭代器,既包含当前 module 的内容,也包括子 module 的内容
  • model.train()和 model.eval() 改变 self.training,在 dropout 和 batchnorm 中会不同
  • 保存model
# Specify a path
PATH = "state_dict_model.pt"

# Save
torch.save(net.state_dict(), PATH)

# Load
model = Net()
model.load_state_dict(torch.load(PATH))
model.eval()
  • checkpoint 是训练中断后继续训练使用的,state_dict 是模型的参数,但是不包括架构,如果要保存架构,可以使用下面的示例。一般推荐按照
# Specify a path
PATH = "entire_model.pt"

# Save
torch.save(net, PATH)

# Load
model = torch.load(PATH)
model.eval()
  • 尽量使用 docker 创建训练环境请添加图片描述
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

_森罗万象

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值