【权重小技巧(2)】模型权重文件总结: .bin、.safetensors、.pt的保存、加载方法一览

.bin 文件通常是通过 torch.save(state_dict) 来保存的,因此需要使用 torch.load() 进行加载。这是因为 .bin 文件存储的是模型权重的 state_dict(即参数的字典),而不是完整的模型对象。你需要先加载 state_dict,然后将其加载到模型中,类似于 .pt.pth 文件。

常见的权重文件格式(如 .bin.safetensors.pt 等)的保存和加载方法:

文件类型保存方法加载方法描述
.bintorch.save(model.state_dict(), 'model.bin')state_dict = torch.load('model.bin', map_location='cpu')
model.load_state_dict(state_dict)
.bin 文件通常存储 state_dict,仅保存模型的权重。需要先加载 state_dict,然后用 model.load_state_dict 加载到模型。
.pt / .pthtorch.save(model.state_dict(), 'model.pth')state_dict = torch.load('model.pth', map_location='cpu')
model.load_state_dict(state_dict)
.pt / .pth 文件是 PyTorch 中常见的保存权重格式,保存的是模型的 state_dict。加载方式和 .bin 文件相同。
.pt (完整模型)torch.save(model, 'model.pth')model = torch.load('model.pth', map_location='cpu')保存的是整个模型对象(包括结构和权重)。加载时会直接返回模型,无需再用 load_state_dict
.safetensorsfrom safetensors import save_file
save_file(tensors, 'model.safetensors')
from safetensors import load_file
state_dict = load_file('model.safetensors')
model.load_state_dict(state_dict)
.safetensors 是一种新型格式,专门用于安全和高效地保存模型权重。保存和加载都需要 safetensors 库。
Hugging Face .binmodel.save_pretrained('path_to_model')from transformers import AutoModel
model = AutoModel.from_pretrained('path_to_model')
Hugging Face 的 .bin 通常用于保存预训练模型的权重,带有额外的配置文件(如 config.json)。AutoModel 会自动处理加载。
.h5 (Keras/TensorFlow)model.save_weights('model.h5')model.load_weights('model.h5')Keras 和 TensorFlow 使用 .h5 文件格式保存模型权重。适用于 TensorFlow/Keras 环境。

详细解释:

  1. .bin 文件

    • .bin 文件通常存储模型的 state_dict,也就是模型参数的字典。你需要先用 torch.load() 读取这个字典,再用 model.load_state_dict() 将它加载到模型中。
    • 典型保存代码:torch.save(model.state_dict(), 'model.bin')
    • 加载代码:
      state_dict = torch.load('model.bin', map_location='cpu')
      model.load_state_dict(state_dict)
      
  2. .pt / .pth 文件

    • .pt.pth 都是 PyTorch 官方推荐的扩展名,用于保存模型的 state_dict 或完整模型。对于 state_dict,它的加载方式与 .bin 类似。

    • 典型保存代码:torch.save(model.state_dict(), 'model.pth')

    • 加载代码:

      state_dict = torch.load('model.pth', map_location='cpu')
      model.load_state_dict(state_dict)
      
    • 保存完整模型对象的代码:torch.save(model, 'model.pth')

    • 加载完整模型代码:

      model = torch.load('model.pth', map_location='cpu')
      
  3. .safetensors 文件

    • .safetensors 是一种新型格式,设计用于比 PyTorch 的 .pt 文件更加安全、确定性强(防止任意代码执行的漏洞)且内存高效。它适合大型模型的存储。
    • 典型保存代码:
      from safetensors import save_file
      save_file(tensors, 'model.safetensors')
      
    • 加载代码:
      from safetensors import load_file
      state_dict = load_file('model.safetensors')
      model.load_state_dict(state_dict)
      
  4. Hugging Face .bin

    • Hugging Face 通常保存预训练模型的权重为 .bin 文件,同时保存模型配置文件 config.json,并提供 save_pretrainedfrom_pretrained 方法简化保存和加载过程。
    • 典型保存代码:model.save_pretrained('path_to_model')
    • 加载代码:
      from transformers import AutoModel
      model = AutoModel.from_pretrained('path_to_model')
      

总结

  • .bin / .pt / .pth:这些格式大多数情况下保存的是 state_dict,需要通过 model.load_state_dict() 将其加载到模型中。
  • 完整模型 (.pt):保存的是整个模型对象,包含模型架构和权重,加载时直接得到模型实例。
  • .safetensors:专为安全和效率设计,需要 safetensors 库处理加载和保存。
评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值