finetune llama2 error1

transformers版本更新,没有top_k_top_p_filtering()函数了,可以安装合适版本的transformers(<=4.38.2)
我这里选择偷懒,直接将相关函数添加到/***/***/lib/python3.10/site-packages/trl/core.py中

class LogitsWarper:
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        raise NotImplementedError(
            f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
        )

class TopPLogitsWarper(LogitsWarper):
    def __init__(self, top_p: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
        top_p = float(top_p)
        if top_p < 0 or top_p > 1.0:
            raise ValueError(f"`top_p` has to be a float > 0 and < 1, but is {top_p}")
        if not isinstance(min_tokens_to_keep, int) or (min_tokens_to_keep < 1):
            raise ValueError(f"`min_tokens_to_keep` has to be a positive integer, but is {min_tokens_to_keep}")

        self.top_p = top_p
        self.filter_value = filter_value
        self.min_tokens_to_keep = min_tokens_to_keep

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        sorted_logits, sorted_indices = torch.sort(scores, descending=False)
        cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)

        # Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
        sorted_indices_to_remove = cumulative_probs <= (1 - self.top_p)
        # Keep at least min_tokens_to_keep
        sorted_indices_to_remove[..., -self.min_tokens_to_keep :] = 0

        # scatter sorted tensors to original indexing
        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
        scores = scores.masked_fill(indices_to_remove, self.filter_value)
        return scores

class TopKLogitsWarper(LogitsWarper):
    def __init__(self, top_k: int, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
        if not isinstance(top_k, int) or top_k <= 0:
            raise ValueError(f"`top_k` has to be a strictly positive integer, but is {top_k}")

        self.top_k = max(top_k, min_tokens_to_keep)
        self.filter_value = filter_value

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        top_k = min(self.top_k, scores.size(-1))  # Safety check
        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = scores < torch.topk(scores, top_k)[0][..., -1, None]
        scores = scores.masked_fill(indices_to_remove, self.filter_value)
        return scores

def top_k_top_p_filtering(
    logits: torch.FloatTensor,
    top_k: int = 0,
    top_p: float = 1.0,
    filter_value: float = -float("Inf"),
    min_tokens_to_keep: int = 1,
) -> torch.FloatTensor:
    if top_k > 0:
        logits = TopKLogitsWarper(top_k=top_k, filter_value=filter_value, min_tokens_to_keep=min_tokens_to_keep)(
            None, logits
        )

    if 0 <= top_p <= 1.0:
        logits = TopPLogitsWarper(top_p=top_p, filter_value=filter_value, min_tokens_to_keep=min_tokens_to_keep)(
            None, logits
        )

    return logits

  • 4
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
为了将Llama模型进行fine-tuning以理解中文,你需要按照以下步骤进行操作: 1. 准备数据集:首先,你需要准备一个中文的对话数据集,其中包含了问题和对应的回答。数据集应该是已经清洗和预处理过的,并且按照一定的格式进行组织,比如每行一个问题和对应的回答。 2. 安装依赖:确保你已经安装了相关的依赖库,包括transformers和torch。 3. 加载预训练模型:使用transformers库加载Llama模型的预训练权重。你可以从Hugging Face的模型库中下载预训练权重。 4. 数据处理:将你准备好的中文对话数据集转换为适合fine-tuning的格式。你需要将文本转换为token IDs,并根据需要进行截断或填充。 5. 创建模型:根据预训练模型的配置和你的任务需求,创建一个新的Llama模型。你可以根据自己的需求添加额外的层或修改模型结构。 6. 设置训练参数:根据你的需求,设置fine-tuning的训练参数,例如学习率、batch size等。 7. 训练模型:使用准备好的数据集和设置好的训练参数,开始训练模型。在每个epoch结束后,评估模型的性能并保存最佳的模型权重。 8. 测试模型:使用测试集对训练好的模型进行评估,并根据需要进行调整和改进。 9. 部署模型:将训练好的模型部署到你的应用程序或服务中,以便进行中文对话理解的任务。 需要注意的是,fine-tuning Llama模型可能需要大量的计算资源和时间。确保你有足够的计算资源和合理的时间规划来完成训练过程。另外,还可以尝试使用更大的数据集或其他技术(如数据增强)来提高模型性能。
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值