【peft】Tuning a large model with peft: resuming from a checkpoint fails with ValueError: Can't find a valid checkpoint at

Continuing from the previous post: after training bloomz with peft's LoRA for one epoch, the result seemed undertrained, so I tried to resume training from the checkpoint:

trainer.train(resume_from_checkpoint = 'path/to/checkpoint')

This fails with:

raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}")
ValueError: Can't find a valid checkpoint at path/to/checkpoint

See: Peft Model not resuming from Checkpoint · Issue #24252 · huggingface/transformers · GitHub

The root cause is `_load_from_checkpoint`: in the installed transformers version it only probes for full-model weight files (pytorch_model.bin / model.safetensors and their index files) and never considers the adapter files that a PEFT run saves, so the validity check fails.
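For context, here is a sketch of what a checkpoint folder from a PEFT/LoRA run typically contains (assumed layout; exact file names vary across peft/transformers versions):

checkpoint-500/              # folder name is illustrative
    adapter_config.json      # PEFT/LoRA configuration
    adapter_model.bin        # adapter weights (ADAPTER_WEIGHTS_NAME)
    optimizer.pt
    scheduler.pt
    trainer_state.json

There is no pytorch_model.bin or model.safetensors here, which is exactly what the stock check looks for.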

The fix: create a Trainer subclass that overrides the checkpoint-loading methods, and instantiate the trainer from that subclass:

from transformers import Trainer
import os
from peft import PeftModel
from transformers.utils import (
    ADAPTER_SAFE_WEIGHTS_NAME,
    ADAPTER_WEIGHTS_NAME,
    is_sagemaker_mp_enabled,
    is_peft_available,
    logging,
)

logger = logging.get_logger(__name__)

class PeftTrainer(Trainer):
 
    def _load_from_peft_checkpoint(self, resume_from_checkpoint, model):
        adapter_weights_file = os.path.join(resume_from_checkpoint, ADAPTER_WEIGHTS_NAME)
        adapter_safe_weights_file = os.path.join(resume_from_checkpoint, ADAPTER_SAFE_WEIGHTS_NAME)

        if not any(
            os.path.isfile(f) for f in [adapter_weights_file, adapter_safe_weights_file]
        ):
            raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}")

        logger.info(f"Loading model from {resume_from_checkpoint}.")
        # Load adapters following PR #24096
        if is_peft_available() and isinstance(model, PeftModel):
            # When training with PEFT & LoRA, assume the adapter has been saved properly.
            if hasattr(model, "active_adapter") and hasattr(model, "load_adapter"):
                if os.path.exists(resume_from_checkpoint):
                    model.load_adapter(resume_from_checkpoint, model.active_adapter)
                    # load_adapter currently has no return value; adjust this when it does.
                    from torch.nn.modules.module import _IncompatibleKeys

                    load_result = _IncompatibleKeys([], [])
                else:
                    logger.warning(
                        "The intermediate checkpoints of PEFT may not be saved correctly, "
                        f"using `TrainerCallback` to save {ADAPTER_WEIGHTS_NAME} in corresponding folders, "
                        "here are some examples https://github.com/huggingface/peft/issues/96"
                    )
            else:
                logger.warning("Could not load adapter model, make sure to have `peft>=0.3.0` installed")

    def _load_from_checkpoint(self, resume_from_checkpoint, model=None):
        if model is None:
            model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
        if is_peft_available() and isinstance(model, PeftModel):
            # Try to load adapters before falling back to a full torch checkpoint
            try:
                return self._load_from_peft_checkpoint(resume_from_checkpoint, model=model)
            except Exception:
                return super()._load_from_checkpoint(resume_from_checkpoint, model=model)
        else:
            # Not a PeftModel: use the original _load_from_checkpoint
            return super()._load_from_checkpoint(resume_from_checkpoint, model=model)
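Usage is the same as with a plain Trainer; only the class changes. A minimal sketch follows; the model name, output_dir, and train_dataset are placeholders, not taken from the original run:

from transformers import AutoModelForCausalLM, TrainingArguments
from peft import LoraConfig, TaskType, get_peft_model

# Placeholder base model; substitute the bloomz variant actually used.
base_model = AutoModelForCausalLM.from_pretrained("bigscience/bloomz-560m")
peft_model = get_peft_model(base_model, LoraConfig(task_type=TaskType.CAUSAL_LM))

trainer = PeftTrainer(
    model=peft_model,
    args=TrainingArguments(output_dir="output"),
    train_dataset=train_dataset,  # reuse the dataset from the earlier run
)

# The overridden _load_from_checkpoint loads the adapter weights
# instead of probing for a full pytorch_model.bin.
trainer.train(resume_from_checkpoint="output/checkpoint-500")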
