对着手写数字识别实例讲讲pytorch模型保存的格式。
首先讲讲保存模型或权重参数的后缀格式,权重参数和模型参数的后缀格式一样,pytorch中最常见的模型保存使用 .pt 或者是 .pth 作为模型文件扩展名。还有其他的保存数据的格式为.t7或者.pkl格式。t7文件是沿用torch7中读取模型权重的方式,而pth文件是python中存储文件的常用格式,而在keras中则是使用.h5文件 。
1、pytorch保存和加载模型以及权重参数(强烈推荐使用这种)
1.1 首先新建model.py模块
将模型单独新建一个模块
from
1.2 新建一个train.py模块
保存模型和权重参数的格式为:
torch
训练的实例如下所示:
import
1.3 新建一个test.py模块
加载模型权重的格式为:
the_model
测试实例如下所示:
import
2、pytorch保存和加载整个模型(不推荐)
2.1 model.py模块同以上1.1
2.2 新建一个train.py模块
保存整个模型的格式为:
torch.save(the_model, PATH)
训练实例如下所示:
import
2.3 新建一个test.py模块
加载整个模型的格式为:
model
测试实例如下所示:
import