保存没有压缩的原始模型和及其模型状态、保存压缩后的模型和及其模型状态、加载没有压缩的原始模型文件和及其模型状态、加载压缩后的模型和及其模型状态

日萌社

 

人工智能AI:Keras PyTorch MXNet TensorFlow PaddlePaddle 深度学习实战(不定时更新)


保存没有压缩的原始模型和及其模型状态、保存压缩后的模型和及其模型状态、加载没有压缩的原始模型文件和及其模型状态、加载压缩后的模型和及其模型状态

Pytorch:模型保存与加载方式


"""
pytorch_model.bin模型状态(字典数据)文件、xx.pt、xx.pkl 的区别
1.xx.pt、xx.pkl均是torch.save(model.state_dict(), "./xx.pt" 或 "./xx.pkl" ) 保存的 模型状态(字典数据)文件
2.pytorch_model.bin实际也为模型状态(字典数据)文件,保存的权重数据(状态字典数据) 实际和 xx.pt/xx.pkl 这样的模型状态(字典数据)文件 的作用是一样的。
3.通过save_pretrained("保存路径")所保存的默认文件pytorch_model.bin 的文件大小 实际和 
  使用 torch.save(model.state_dict(), "./xx.pt" 或 "./xx.pkl" )所保存的模型状态(字典数据)文件的文件大小 是一样的,
  也就是 pytorch_model.bin模型状态(字典数据)文件 和 xx.pt/xx.pkl 等不同类型的模型状态(字典数据)文件 都可以保存模型的状态字典数据,
  并且不管使用哪种类型的模型状态(字典数据)文件,他们的文件大小都是一致的。 

1.BertTokenizer.from_pretrained:实际自动加载的是vocab.txt
  BertForSequenceClassification.from_pretrained:实际自动加载的是pytorch_model.bin模型状态文件

2.保存没有压缩的/压缩后的模型的模型状态(字典数据)、加载没有压缩的/压缩后的模型状态(字典数据)
    1.保存没有压缩的原始模型的模型状态(字典数据)
        torch.save(model.state_dict(), "./xx.pt" 或 "./xx.pkl")
        
    2.加载没有压缩的原始模型的模型状态(字典数据)
        model = 模型类Model()
        model.load_state_dict(torch.load("./xx.pt" 或 "./xx.pkl"))
        
    3.对原始模型应用动态量化技术进行模型压缩,使用save_pretrained把压缩后的模型的模型状态(字典数据)存储到本地默认的pytorch_model.bin文件
        #应用动态量化技术:
        #   使用torch.quantization.quantize_dynamic获得动态量化的模型
        #   量化的网络层为所有的nn.Linear的权重,使其成为int8
        #quantize_dynamic输入的为没有压缩的原始模型,quantize_dynamic输出的为压缩后的模型
        quantized_model = torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)
        # 使用save_pretrained保存压缩后的模型的模型状态到本地文件,文件名默认为pytorch_model.bin
        quantized_model.save_pretrained("./bert_finetuning_test/quantized")
        
    4.加载压缩后的模型的模型状态(字典数据)文件
            # 实际加载的是"./bert_finetuning_test/quantized"目录下的vocab.txt
            tokenizer_load_model = BertTokenizer.from_pretrained("./bert_finetuning_test/quantized", do_lower_case=configs_load_model.do_lower_case)
            # 加载带有文本分类头的BERT模型:实际加载的是"./bert_finetuning_test/quantized"目录下的pytorch_model.bin模型状态(字典数据)文件
            quantized_load_model = BertForSequenceClassification.from_pretrained("./bert_finetuning_test/quantized")
            # 因为加载的为压缩后的模型的模型状态(字典数据)文件,因此此处需要使用动态量化技术
            quantized_load_model = torch.quantization.quantize_dynamic(quantized_load_model, {torch.nn.Linear}, dtype=torch.qint8)
            # 可以选择加载pytorch_model.bin模型状态(字典数据)文件,或者选择加载pt/pkl文件也可以
            state = torch.load("./pytorch_model.bin" 或 "./xx.pt或./xx.pkl")
            # 把模型状态(字典数据) 加载到模型中
            quantized_load_model.load_state_dict(state)
"""


##############
# 加载量化模型
##############
# 加载BERT预训练模型的数值映射器
# 实际加载的是"./bert_finetuning_test/quantized"目录下的vocab.txt
tokenizer_load_model = BertTokenizer.from_pretrained("./bert_finetuning_test/quantized", do_lower_case=configs.do_lower_case)
# 加载带有文本分类头的BERT模型
# 加载带有文本分类头的BERT模型:实际加载的是"./bert_finetuning_test/quantized"目录下的pytorch_model.bin模型状态(字典数据)文件
quantized_load_model = BertForSequenceClassification.from_pretrained("./bert_finetuning_test/quantized")
# 应用动态量化技术:
#   使用torch.quantization.quantize_dynamic获得动态量化的模型
#   量化的网络层为所有的nn.Linear的权重,使其成为int8
quantized_load_model = torch.quantization.quantize_dynamic(quantized_load_model, {torch.nn.Linear}, dtype=torch.qint8)
#实际加载的pytorch_model.bin该文件仍然是模型的状态字典的文件,并不是包含模型结构的文件
quantized_load_model.load_state_dict(torch.load("./bert_finetuning_test/quantized/pytorch_model.bin"))
# quantized_load_model.load_state_dict(torch.load("./quantized_model.pt"))
# print(quantized_load_model)
# 将模型传到制定设备上
quantized_load_model.to(configs_load_model.device)
#对比压缩后的模型的推理准确性和耗时
#获得模型评估结果和运行时间
time_model_evaluation(quantized_load_model, configs_load_model, tokenizer_load_model)

  • 2
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

あずにゃん

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值