pip install safetensors
2.2.2 保存模型权重
使用safetensors保存模型权重,而不是直接使用PyTorch的.save()方法。
import torch
from safetensors.torch import save_file
# 假设model是你的模型实例
model_state_dict = model.state_dict()
# 保存模型到safetensors格式
save_file(model_state_dict, "model.safetensors")
对应的pytorch保存模型的方法
# 保存模型状态字典
torch.save(model.state_dict(), 'model.pth')
# 加载模型状态字典
model = YourModelClass() # 初始化模型实例
model.load_state_dict(torch.load('model.pth')) # 加载权重
model.eval() # 如果是预训练模型,通常设置为评估模式
2.2.3 加载模型权重
加载时,同样使用safetensors的专用函数。
from safetensors.torch import load_file
# 加载模型权重
loaded_state_dict = load_file("model.safetensors")
# 加载到模型中
model.load_state_dict(loaded_state_dict)
原文链接:https://blog.csdn.net/weixin_48007632/article/details/139992354