In multi-node, multi-GPU training, the saved file cannot be read back: the error says the file is corrupted

This article describes a problem in multi-node, multi-GPU training where the optimizer.pt file was corrupted because multiple processes wrote it simultaneously. Through a minimal reproduction and step-by-step narrowing, the problem was traced to a custom checkpoint-saving function. The fix adds a process check to that function so that only the process with global rank 0 performs the save, avoiding the write conflict. After the fix, the file loads normally.

Problem description: a multi-node, multi-GPU training run saved an optimizer.pt file, but reading the file back fails with an error saying it is corrupted.
The original error:

Traceback (most recent call last):
  File "/mnt/petrelfs/tongjingqi/train-moe/smoe/entrypoint/cpt_fpt.py", line 280, in <module>
    main()
  File "/mnt/petrelfs/tongjingqi/train-moe/smoe/utils/notification.py", line 146, in wrapper_sender
    raise ex
  File "/mnt/petrelfs/tongjingqi/train-moe/smoe/utils/notification.py", line 95, in wrapper_sender
    value = func(*args, **kwargs)
            ^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/petrelfs/tongjingqi/train-moe/smoe/entrypoint/cpt_fpt.py", line 265, in main
    train_result = trainer.train(resume_from_checkpoint=checkpoint)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/petrelfs/tongjingqi/anaconda3/envs/moeenv/lib/python3.11/site-packages/transformers/trainer.py", line 1539, in train
    return inner_training_loop(
           ^^^^^^^^^^^^^^^^^^^^
  File "/mnt/petrelfs/tongjingqi/anaconda3/envs/moeenv/lib/python3.11/site-packages/transformers/trainer.py", line 1752, in _inner_training_loop
    self.control = self.callback_handler.on_train_begin(args, self.state, self.control)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/petrelfs/tongjingqi/anaconda3/envs/moeenv/lib/python3.11/site-packages/transformers/trainer_callback.py", line 353, in on_train_begin
    return self.call_event("on_train_begin", args, state, control)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/petrelfs/tongjingqi/anaconda3/envs/moeenv/lib/python3.11/site-packages/transformers/trainer_callback.py", line 397, in call_event
    result = getattr(callback, event)(
             ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/petrelfs/tongjingqi/train-moe/smoe/callbacks/save_model.py", line 134, in on_train_begin
  File "/mnt/petrelfs/tongjingqi/anaconda3/envs/moeenv/lib/python3.11/site-packages/torch/serialization.py", line 798, in load
    with _open_zipfile_reader(opened_file) as opened_zipfile:
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/petrelfs/tongjingqi/anaconda3/envs/moeenv/lib/python3.11/site-packages/torch/serialization.py", line 283, in __init__
    super().__init__(torch._C.PyTorchFileReader(name_or_buffer))
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: PytorchStreamReader failed reading zip archive: invalid header or archive is corrupted

A minimal reproduction of the bug:

>>> import torch
>>> torch.load("outputs/cpt-moe-fpt-test_lr_change-1811148/optimizer.pt")
outputs/cpt-moe-fpt-test_lr_change-1811148/optimizer.pt _is_zipfile: True
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/mnt/petrelfs/tongjingqi/anaconda3/envs/moeenv/lib/python3.11/site-packages/torch/serialization.py", line 798, in load
    with _open_zipfile_reader(opened_file) as opened_zipfile:
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/petrelfs/tongjingqi/anaconda3/envs/moeenv/lib/python3.11/site-packages/torch/serialization.py", line 283, in __init__
    super().__init__(torch._C.PyTorchFileReader(name_or_buffer))
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: PytorchStreamReader failed reading zip archive: invalid header or archive is corrupted
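
Independent of torch, the damaged file can be sanity-checked with Python's standard zipfile module, since torch.save writes its output in zip format. A minimal sketch, reusing the path from the reproduction above:

import zipfile

path = "outputs/cpt-moe-fpt-test_lr_change-1811148/optimizer.pt"
# is_zipfile() only looks for a zip signature, so it can still return
# True for a partially overwritten file -- consistent with the
# "_is_zipfile: True" line in the output above.
print("has zip signature:", zipfile.is_zipfile(path))
try:
    with zipfile.ZipFile(path) as zf:
        # testzip() reads every member and returns the name of the first
        # corrupt one, or None if the archive is internally consistent.
        print("first bad member:", zf.testzip())
except zipfile.BadZipFile as exc:
    print("archive is corrupted:", exc)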

The save code that caused the error:

    def on_save(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        # BUG: the Trainer invokes this callback in every process, so in a
        # multi-node, multi-GPU run all ranks write the same file at once.
        optimizer = kwargs['optimizer']
        optimizer_state = optimizer.state_dict()
        save_path = os.path.join(args.output_dir, OPTIMIZER_NAME)
        torch.save(optimizer_state, save_path)

Cause: in multi-node, multi-GPU training, the process on every GPU saves this file. When multiple processes write the same file at the same time, the file gets corrupted.
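
This failure mode can be demonstrated outside of training. The following toy script (standalone, with a hypothetical output filename; not the project's code) has several processes write the same path with no coordination, just as every rank ran on_save above:

import multiprocessing as mp

def writer(rank: int, path: str) -> None:
    # Each "rank" opens the same file in write mode (truncating it) and
    # streams its own payload. With no coordination, the processes
    # overwrite and interleave one another's bytes.
    with open(path, "wb") as f:
        for _ in range(1000):
            f.write(bytes([rank]) * 4096)

if __name__ == "__main__":
    path = "shared_output.bin"  # hypothetical path, demo only
    procs = [mp.Process(target=writer, args=(r, path)) for r in range(4)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
    data = open(path, "rb").read()
    # The file usually ends up containing bytes from several writers,
    # i.e. it is no longer any single process's intact output.
    print("distinct writers seen in the file:", sorted(set(data)))

Since torch.save writes its checkpoint as a zip archive, any such interleaving destroys the archive's internal structure, which is exactly the "invalid header or archive is corrupted" error shown above.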

How the bug was found: first, try to reproduce it in the simplest possible way. The error originally occurred inside a custom file-reading function, so to rule that function out, the file was loaded directly from the command line with the torch.load library function, and the bug reproduced. Next, to rule out the saving function torch.save(), a minimal torch.save() test case was written; it saved and loaded correctly, showing this was not an environment problem such as a library version mismatch. That narrowed the problem down to the custom file-saving function.
Finally, a print statement inserted into the save function executed 16 times, showing the file was being written repeatedly. This pointed to a multi-process write conflict corrupting the file.
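
A quick way to make that diagnosis visible is to log, at the top of the save callback, which process is doing the writing. A sketch of such a diagnostic (hypothetical helper name, assuming a torchrun-style launcher that exports RANK and LOCAL_RANK; not the original logging code):

import os

def log_saver_identity() -> None:
    # Diagnostic helper: launchers such as torchrun export RANK (global
    # rank) and LOCAL_RANK for every process they spawn. Calling this at
    # the top of the save callback shows which processes reach the save.
    print(
        f"save reached by RANK={os.environ.get('RANK', '?')} "
        f"LOCAL_RANK={os.environ.get('LOCAL_RANK', '?')} pid={os.getpid()}"
    )

If the job runs 16 processes, this line prints 16 times, matching the 16 executions of the inserted print observed above.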

The fix:
Add a check at the start of the function so that only the process with global rank 0 performs the save. Only one copy is written, so multiple processes can no longer write the file simultaneously and corrupt it.

    def on_save(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        # Only the process with global rank 0 writes the file; RANK is the
        # global rank exported by the launcher, so exactly one copy is saved.
        if os.environ.get("RANK", "0") == "0":
            optimizer = kwargs['optimizer']
            optimizer_state = optimizer.state_dict()
            save_path = os.path.join(args.output_dir, OPTIMIZER_NAME)
            torch.save(optimizer_state, save_path)
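
The same check can also go through the process-group API instead of the environment variable. A sketch of a helper (hypothetical name is_global_rank_zero; not from the original code) that falls back to RANK when no process group is initialized:

import os
import torch.distributed as dist

def is_global_rank_zero() -> bool:
    # Prefer the process-group API when a group has been initialized; it
    # reports the actual global rank regardless of how the job was launched.
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank() == 0
    # Otherwise fall back to the RANK variable that torchrun-style
    # launchers export for each process.
    return os.environ.get("RANK", "0") == "0"

Either way, the point is the same: exactly one process owns the write.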

After the fix, the file can indeed be read normally.
