通常,用户想在磁盘上保存并加载经过训练的模型。这就是使用
AllenNLP
的配置文件非常有用的地方,因为加载模型所需的所有内容,包括权重、配置和词汇表,都可以存储在单个tar文件中。在本章中,将介绍三种对模型进行保存与加载的方式。
手动保存与加载
为了正确地保存和加载AllenNLP
模型,我们一般需要有如下文件:
- 模型配置(用于训练模型的规范)
- 模型权重(模型的训练参数)
- 词汇表
在AllenNLP
中,模型配置由Params
类管理,可以使用to_file()
方法保存到磁盘。用户可以使用model.state_dict()
检索模型权重,并使用PyTorch
的torch.save()
将其保存到磁盘中。 Vocabulary.save_to_files()
方法将Vocabulary
对象序列化到目录。
为了从文件加载模型,可以使用Model.load()
类方法。 它需要一个Params
对象,该对象包含模型配置以及模型权重和词汇序列化的目录路径。 该方法还加载和还原词汇表。
示例代码如下:
# 存储模型
serialization_dir = 'model'
config_file = os.path.join(serialization_dir, 'config.json')
vocabulary_dir = os.path.join(serialization_dir, 'vocabulary')
weights_file = os.path.join(serialization_dir, 'weights.th')
os.makedirs(serialization_dir, exist_ok=True)
params.to_file(config_file)
vocab.save_to_files(vocabulary_dir)
torch.save(model.state_dict(), weights_file)
# 加载模型
loaded_params = Params.from_file(config_file)
loaded_model = Model.load(loaded_params, serialization_dir, weights_file)
loaded_vocab = loaded_model.vocab # 在上一步已经加载进去了
archive保存与加载
因为每次需要保存、加载和移动模型时都要处理这三个文件,所以AllenNLP
提供了用于归档和取消归档模型文件的实用功能。用户可以使用archive_model()
方法将模型配置、权重和词汇表打包成一个tar.gz
文件,以及任何附加的补充文件。此方法假设用户使用 training loop
训练模型,并打包 training loop
运行时保存的文件。 training loop
也调用此函数,以便在训练结束时打包最佳模型权重,因此用户不太可能需要自己调用此方法。
另外,用户可以简单地使用load_archive()
从存档文件中还原模型。 这将返回一个Archive
对象,其中包含配置和模型。
# 建立archive文件
archive_model(serialization_dir, weights='weights.th')
# 加载archive文件
archive = load_archive(os.path.join(serialization_dir, 'model.tar.gz'))
AllenNLP
命令保存与加载
实际上,如果用户使用AllenNLP
命令(例如allennlp train
),就会自动处理模型的保存。 训练结束或中断训练后,该命令会自动将最佳模型保存到model.tar.gz文件中。 用户还可以从序列化目录恢复训练。训练命令样例如下:
# my_text_classifier.jsonnet指模型配置文件,-s 参数是保存模型的文件夹名,--include-package 是放自己写的自定义脚本所在的文件夹名
allennlp train my_text_classifier.jsonnet -s model --include-package my_text_classifier
Model.from_archive
除了load_archive()
之外,AllenNLP
还提供了一种便捷方法Model.from_archive()
。 这基本上只是在后台调用load_archive()
。 其主要目的是将其注册为from_archive
类型的Model
构造函数,以便用户可以从存档文件加载保存的模型,并继续使用allennlp train
命令对其进行训练。 为此,请将以下代码片段放入训练配置文件中:
{
...
"model": {
"type": "from_archive",
"archive_file": "/path/to/saved/archive/file.tar.gz"
}
...
}
参考资料