一、 pytorch模型定义的方式
pytorch模型的定义主要包括初始化以及数据流向两个部分,并且均继承于nn.Module类。
pytorch的模型定义共可以分为三种方式:
- Sequential:传入有序字典即可,不用再写forword函数,但是缺乏灵活性;
- ModuleList:接收子模块的列表,需要通过forward函数定义顺序;
- ModuleDict:与ModuleList作用类似,但是可以方便对层添加名称。
二、利用模型块快速搭建复杂网络
关键点:通过简单层构建出具有特定功能的模型块,然后通过模型块则能够构建复杂的网络。
三、Pytorch修改模型
- 修改模型层
①参数修改:直接可以对某些层的参数进行参数的修改,使得输出的结果符合自身的大小
②网络结构的修改:可以对相应的层进行替换。 - 添加外部输入
可以自定义层进行外部输入 - 添加额外输出
可以输出不同层查看结果
四、Pytorch模型的保存与读取
- Pytorch的存储格式:pkl,pt,pth
- 存储内容:模型结构和权重
- 单卡和多卡模型存储的区别在于保存和加载时会存在一定问题
- 不同的硬件在存取时会发生一些问题