.bin
文件通常是通过 torch.save(state_dict)
来保存的,因此需要使用 torch.load()
进行加载。这是因为 .bin
文件存储的是模型权重的 state_dict
(即参数的字典),而不是完整的模型对象。你需要先加载 state_dict
,然后将其加载到模型中,类似于 .pt
或 .pth
文件。
常见的权重文件格式(如 .bin
、.safetensors
、.pt
等)的保存和加载方法:
文件类型 | 保存方法 | 加载方法 | 描述 |
---|---|---|---|
.bin | torch.save(model.state_dict(), 'model.bin') | state_dict = torch.load('model.bin', map_location='cpu') model.load_state_dict(state_dict) | .bin 文件通常存储 state_dict ,仅保存模型的权重。需要先加载 state_dict ,然后用 model.load_state_dict 加载到模型。 |
.pt / .pth | torch.save(model.state_dict(), 'model.pth') | state_dict = torch.load('model.pth', map_location='cpu') model.load_state_dict(state_dict) | .pt / .pth 文件是 PyTorch 中常见的保存权重格式,保存的是模型的 state_dict 。加载方式和 .bin 文件相同。 |
.pt (完整模型) | torch.save(model, 'model.pth') | model = torch.load('model.pth', map_location='cpu') | 保存的是整个模型对象(包括结构和权重)。加载时会直接返回模型,无需再用 load_state_dict 。 |
.safetensors | from safetensors import save_file save_file(tensors, 'model.safetensors') | from safetensors import load_file state_dict = load_file('model.safetensors') model.load_state_dict(state_dict) | .safetensors 是一种新型格式,专门用于安全和高效地保存模型权重。保存和加载都需要 safetensors 库。 |
Hugging Face .bin | model.save_pretrained('path_to_model') | from transformers import AutoModel model = AutoModel.from_pretrained('path_to_model') | Hugging Face 的 .bin 通常用于保存预训练模型的权重,带有额外的配置文件(如 config.json )。AutoModel 会自动处理加载。 |
.h5 (Keras/TensorFlow) | model.save_weights('model.h5') | model.load_weights('model.h5') | Keras 和 TensorFlow 使用 .h5 文件格式保存模型权重。适用于 TensorFlow/Keras 环境。 |
详细解释:
-
.bin
文件.bin
文件通常存储模型的state_dict
,也就是模型参数的字典。你需要先用torch.load()
读取这个字典,再用model.load_state_dict()
将它加载到模型中。- 典型保存代码:
torch.save(model.state_dict(), 'model.bin')
- 加载代码:
state_dict = torch.load('model.bin', map_location='cpu') model.load_state_dict(state_dict)
-
.pt
/.pth
文件-
.pt
和.pth
都是 PyTorch 官方推荐的扩展名,用于保存模型的state_dict
或完整模型。对于state_dict
,它的加载方式与.bin
类似。 -
典型保存代码:
torch.save(model.state_dict(), 'model.pth')
-
加载代码:
state_dict = torch.load('model.pth', map_location='cpu') model.load_state_dict(state_dict)
-
保存完整模型对象的代码:
torch.save(model, 'model.pth')
-
加载完整模型代码:
model = torch.load('model.pth', map_location='cpu')
-
-
.safetensors
文件.safetensors
是一种新型格式,设计用于比 PyTorch 的.pt
文件更加安全、确定性强(防止任意代码执行的漏洞)且内存高效。它适合大型模型的存储。- 典型保存代码:
from safetensors import save_file save_file(tensors, 'model.safetensors')
- 加载代码:
from safetensors import load_file state_dict = load_file('model.safetensors') model.load_state_dict(state_dict)
-
Hugging Face
.bin
- Hugging Face 通常保存预训练模型的权重为
.bin
文件,同时保存模型配置文件config.json
,并提供save_pretrained
和from_pretrained
方法简化保存和加载过程。 - 典型保存代码:
model.save_pretrained('path_to_model')
- 加载代码:
from transformers import AutoModel model = AutoModel.from_pretrained('path_to_model')
- Hugging Face 通常保存预训练模型的权重为
总结
.bin
/.pt
/.pth
:这些格式大多数情况下保存的是state_dict
,需要通过model.load_state_dict()
将其加载到模型中。- 完整模型 (
.pt
):保存的是整个模型对象,包含模型架构和权重,加载时直接得到模型实例。 .safetensors
:专为安全和效率设计,需要safetensors
库处理加载和保存。