Huggingface Trainer:Removed shared tensor while saving问题的解决

bug概述:【踩坑记录📝】Removed shared tensor while saving.
简单来说,这个bug的危害是trainer.save()无法正确存储权重。这篇博文的作者也给出了两种处理方法,但要么要改transformers版本,要么要包裹Trainer类,太麻烦了。

我使用transformers版本4.38.0,deepspeed zero 0/1/2复现了这篇博文所述bug,研究得知bug的具体原因后,给出以下解决思路:

第一步:获得正确的state_dict

trainer调用_save方法时,默认入参state_dict=None,然后_save方法通过self.model.state_dict()获得state_dict,导致报错。原因是此时self.model是通过accelerator加载的,state_dict需要用self.accelerator.get_state_dict(self.model)的方式获得。

解决思路是提前在调用_save之前提前通过self.accelerator.get_state_dict(self.model)把state_dict取出,然后直接带state_dict入参来调用_save方法,实现权重存储。

我在这里借鉴了QWen-VL仓库finetune.py的写法,参考代码如下:

def safe_save_model_for_hf_trainer(trainer: Trainer, output_dir: str, bias="none"):
    """Collects the state dict and dump to disk."""
    # check if zero3 mode enabled
    if trainer.args.hf_deepspeed_config.config['zero_optimization']['stage'] == 3:
        state_dict = trainer.model_wrapped._zero3_consolidated_16bit_state_dict()
    else:
        state_dict = trainer.accelerator.get_state_dict(trainer.model)
    if trainer.args.should_save and trainer.args.local_rank == 0:
        trainer._save(output_dir, state_dict=state_dict)

第二步:顺利把state_dict存在文件里

trainer在初始化时,args.save_safetensors要设置为false,来让_save方法内部调用model.save_pretrained时不存model.safetensors文件,避免这处调用报错。

第三步:顺利加载权重文件

以上两步做好后,可以保证正确的state_dict在pytorch_model.bin里存下来,而不至于丢失。
再次使用这份权重时,为了从pytorch_model.bin而不是model.safetensors加载权重,可以在from_pretrained的时候要设置use_safetensors=False;也可以把加载目录的model.safetensors直接删掉。

做好以上几步就可以完美解决这个bug.

  • 5
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值