.bin 文件通常是通过 torch.save(state_dict) 来保存的,因此需要使用 torch.load() 进行加载。这是因为 .bin 文件存储的是模型权重的 state_dict(即参数的字典),而不是完整的模型对象。你需要先加载 state_dict,然后将其加载到模型中,类似于 .pt 或 .pth 文件。
常见的权重文件格式(如 .bin、.safetensors、.pt 等)的保存和加载方法:
| 文件类型 | 保存方法 | 加载方法 | 描述 |
|---|---|---|---|
.bin | torch.save(model.state_dict(), 'model.bin') | state_dict = torch.load('model.bin', map_location='cpu') model.load_state_dict(state_dict) | .bin 文件通常存储 state_dict,仅保存模型的权重。需要先加载 state_dict,然后用 model.load_state_dict 加载到模型。 |
.pt / .pth | torch.save(model.state_dict(), 'model.pth') | state_dict = torch.load('model.pth', map_location='cpu') model.load_state_dict(state_dict) | .pt / .pth 文件是 PyTorch 中常见的保存权重格式,保存的是模型的 state_dict。加载方式和 .bin 文件相同。 |
.pt (完整模型) | torch.save(model, 'model.pth') | model = torch.load('model.pth', map_location='cpu') | 保存的是整个模型对象(包括结构和权重)。加载时会直接返回模型,无需再用 load_state_dict。 |
.safetensors | from safetensors import save_file save_file(tensors, 'model.safetensors') | from safetensors import load_file state_dict = load_file('model.safetensors') model.load_state_dict(state_dict) | .safetensors 是一种新型格式,专门用于安全和高效地保存模型权重。保存和加载都需要 safetensors 库。 |
Hugging Face .bin | model.save_pretrained('path_to_model') | from transformers import AutoModel model = AutoModel.from_pretrained('path_to_model') | Hugging Face 的 .bin 通常用于保存预训练模型的权重,带有额外的配置文件(如 config.json)。AutoModel 会自动处理加载。 |
.h5 (Keras/TensorFlow) | model.save_weights('model.h5') | model.load_weights('model.h5') | Keras 和 TensorFlow 使用 .h5 文件格式保存模型权重。适用于 TensorFlow/Keras 环境。 |
详细解释:
-
.bin文件.bin文件通常存储模型的state_dict,也就是模型参数的字典。你需要先用torch.load()读取这个字典,再用model.load_state_dict()将它加载到模型中。- 典型保存代码:
torch.save(model.state_dict(), 'model.bin') - 加载代码:
state_dict = torch.load('model.bin', map_location='cpu') model.load_state_dict(state_dict)
-
.pt/.pth文件-
.pt和.pth都是 PyTorch 官方推荐的扩展名,用于保存模型的state_dict或完整模型。对于state_dict,它的加载方式与.bin类似。 -
典型保存代码:
torch.save(model.state_dict(), 'model.pth') -
加载代码:
state_dict = torch.load('model.pth', map_location='cpu') model.load_state_dict(state_dict) -
保存完整模型对象的代码:
torch.save(model, 'model.pth') -
加载完整模型代码:
model = torch.load('model.pth', map_location='cpu')
-
-
.safetensors文件.safetensors是一种新型格式,设计用于比 PyTorch 的.pt文件更加安全、确定性强(防止任意代码执行的漏洞)且内存高效。它适合大型模型的存储。- 典型保存代码:
from safetensors import save_file save_file(tensors, 'model.safetensors') - 加载代码:
from safetensors import load_file state_dict = load_file('model.safetensors') model.load_state_dict(state_dict)
-
Hugging Face
.bin- Hugging Face 通常保存预训练模型的权重为
.bin文件,同时保存模型配置文件config.json,并提供save_pretrained和from_pretrained方法简化保存和加载过程。 - 典型保存代码:
model.save_pretrained('path_to_model') - 加载代码:
from transformers import AutoModel model = AutoModel.from_pretrained('path_to_model')
- Hugging Face 通常保存预训练模型的权重为
总结
.bin/.pt/.pth:这些格式大多数情况下保存的是state_dict,需要通过model.load_state_dict()将其加载到模型中。- 完整模型 (
.pt):保存的是整个模型对象,包含模型架构和权重,加载时直接得到模型实例。 .safetensors:专为安全和效率设计,需要safetensors库处理加载和保存。

4230

被折叠的 条评论
为什么被折叠?



