StoppingCriteria代码分析

StoppingCriteria 基类

定义

StoppingCriteria 是一个抽象基类(Abstract Base Class),用于在文本生成过程中定义停止条件。所有自定义的停止条件类都应继承自该基类,并实现其 __call__ 方法。

主要方法

  • __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor

输入参数

  • input_idstorch.LongTensor,形状为 (batch_size, sequence_length)

    • 含义:当前生成的序列的 token IDs。
    • 获取方式:可以使用 AutoTokenizer 获取这些 ID,例如通过 PreTrainedTokenizer.encode 或直接调用 PreTrainedTokenizer.__call__ 方法。
  • scorestorch.FloatTensor,形状为 (batch_size, vocab_size)

    • 含义:模型在当前时间步的预测得分。
    • 注意:这些得分可以是在 SoftMax 之前或之后的分数。如果停止条件依赖于 scores,需要在 generate 方法中设置 return_dict_in_generate=Trueoutput_scores=True
  • **kwargs:其他可选的关键字参数。

    • 含义:特定停止条件所需的其他信息。

处理过程

  • 这是一个抽象方法,需要子类实现具体的逻辑。
  • 方法的主要目的是根据输入的 input_idsscores 等信息,判断是否满足停止条件。
  • 在实现过程中,可以利用 input_idsscores 进行自定义的判断逻辑。

返回值

  • 返回一个形状为 (batch_size,)torch.BoolTensor(布尔张量)。
    • 内容:每个元素对应一个样本,表示该样本是否应停止生成。
      • True:表示应停止生成。
      • False:表示应继续生成。

示例

下面是一个简单的自定义停止条件示例,假设我们想在生成的序列长度达到 10 时停止:

class MyStoppingCriteria(StoppingCriteria):
    def __init__(self, max_length):
        self.max_length = max_length

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
        current_length = input_ids.shape[1]  # 获取当前序列长度
        is_done = torch.full(
            (input_ids.shape[0],),  # batch_size
            current_length >= self.max_length,  # 判断是否达到最大长度
            dtype=torch.bool,
            device=input_ids.device
        )
        return is_done

在这个示例中:

  • 输入input_idsscores,以及可选的 **kwargs
  • 处理过程:计算当前序列长度,并判断是否达到预设的最大长度。
  • 返回值:一个形状为 (batch_size,) 的布尔张量,指示哪些样本应停止生成。

名称类型简要说明
StoppingCriteria所有可在生成过程中应用的停止准则的抽象基类。如果停止准则依赖于scores输入,需要确保在generate中传入return_dict_in_generate=True, output_scores=True
MaxLengthCriteria当生成的序列长度超过指定的max_length时停止生成。适用于仅解码类型的Transformer模型,这将包括初始提示的长度。
MaxTimeCriteria当生成过程超过指定的max_time(以秒为单位)时停止生成。可以指定initial_timestamp来覆盖生成开始计时的时间。
StopStringCriteria当生成的序列包含指定的停止字符串时停止生成。它预处理字符串和tokenizer的词汇表,以找到tokens可以有效完成停止字符串的位置。当生成的序列在末尾包含任何停止字符串时,停止生成。
EosTokenCriteria当生成的序列包含end-of-sequence(EOS)标记时停止生成。默认情况下,使用model.generation_config.eos_token_id
ConfidenceCriteria当模型对当前预测的置信度低于指定的阈值assistant_confidence_threshold时停止生成,即使尚未达到定义的token数量(num_assistant_tokens)。
StoppingCriteriaList停止准则的列表,用于在生成过程中同时应用多个停止条件。包含一个max_length属性,用于获取最大长度停止准则的max_length值。
validate_stopping_criteria函数验证并更新停止准则列表,确保包含最大长度停止准则。如果停止准则列表中没有MaxLengthCriteria,则添加一个新的MaxLengthCriteria

以下是 StoppingCriteria 的派生类的输入、返回值,以及它们对数据的处理方式的详细说明:


1. MaxLengthCriteria

输入

  • input_idstorch.LongTensor,形状为 (batch_size, sequence_length)。表示当前生成的序列的 token IDs。
  • scorestorch.FloatTensor,形状为 (batch_size, vocab_size)。模型在当前时间步的预测得分。
  • **kwargs:其他可选的关键字参数。

处理过程

  • 计算当前生成序列的长度 cur_len = input_ids.shape[-1]
  • 判断当前长度是否达到或超过了预设的最大长度 max_length
    • 如果 cur_len >= self.max_length,表示已经达到或超过最大长度,需要停止生成。
  • 如果提供了 max_position_embeddings(模型的最大位置嵌入大小),并且当前长度超过了它,则发出警告提醒可能会超过模型的最大长度。

返回值

  • 返回一个形状为 (batch_size,)torch.BoolTensor,其中每个元素表示对应样本是否应停止生成:
    • 如果当前序列长度达到或超过了 max_length,对应的位置为 True,表示应停止生成。
    • 否则为 False,表示继续生成。

2. MaxTimeCriteria

输入

  • input_idstorch.LongTensor,形状为 (batch_size, sequence_length)
  • scorestorch.FloatTensor,形状为 (batch_size, vocab_size)
  • **kwargs:其他可选的关键字参数。

处理过程

  • 计算从生成开始到当前的时间差 elapsed_time = time.time() - self.initial_timestamp
  • 判断是否超过了预设的最大生成时间 max_time
    • 如果 elapsed_time > self.max_time,表示已经超过了最大生成时间,需要停止生成。

返回值

  • 返回一个形状为 (batch_size,)torch.BoolTensor,其中每个元素表示对应样本是否应停止生成:
    • 如果已超过最大生成时间,对应的位置为 True,表示应停止生成。
    • 否则为 False,表示继续生成。

3. StopStringCriteria

输入

  • input_idstorch.LongTensor,形状为 (batch_size, sequence_length)
  • scorestorch.FloatTensor,形状为 (batch_size, vocab_size)
  • **kwargs:其他可选的关键字参数。

处理过程

  • 预处理(在初始化时完成):

    • 接收一个或多个停止字符串 stop_strings,以及对应的 tokenizer
    • 清理并获取 tokenizer 的词汇表,以得到实际的 token 字符串和对应的索引。
    • 预先计算每个 token 在停止字符串中的可能匹配位置,以及可能的结束重叠长度。这部分计算比较复杂,主要目的是为了在生成过程中高效地进行匹配检查。
  • 运行时处理

    • 在每次生成新的 token 后,获取最新的 input_ids,只关注序列末尾可能与停止字符串匹配的部分(取决于停止字符串的最大长度)。
    • 反转 input_ids,方便从序列末尾开始匹配。
    • 使用预先计算的嵌入向量 embedding_vec,通过张量操作检查序列末尾是否与任何一个停止字符串匹配。
      • 计算 cumulative sum(累积和)以跟踪匹配的字符数。
      • 使用掩码(mask)跳过不可能匹配的位置。
      • 检查是否有任何一个停止字符串完全匹配,如果是,则需要停止生成。

返回值

  • 返回一个形状为 (batch_size,)torch.BoolTensor,其中每个元素表示对应样本是否应停止生成:
    • 如果在序列末尾匹配到了任何一个停止字符串,对应的位置为 True,表示应停止生成。
    • 否则为 False,表示继续生成。

4. EosTokenCriteria

输入

  • input_idstorch.LongTensor,形状为 (batch_size, sequence_length)
  • scorestorch.FloatTensor,形状为 (batch_size, vocab_size)
  • **kwargs:其他可选的关键字参数。

处理过程

  • 检查 input_ids 中最后一个 token 是否为结束标记(EOS token)。
  • eos_token_id 可以是一个整数、整数列表或 torch.Tensor,表示一个或多个结束标记的 ID。
  • 使用 isin_mps_friendly 函数检查每个样本的最后一个 token 是否在 eos_token_id 中。

返回值

  • 返回一个形状为 (batch_size,)torch.BoolTensor,其中每个元素表示对应样本是否应停止生成:
    • 如果最后一个 token 是 EOS 标记,对应的位置为 True,表示应停止生成。
    • 否则为 False,表示继续生成。

5. ConfidenceCriteria

输入

  • input_idstorch.LongTensor,形状为 (batch_size, sequence_length)
  • scoresList[torch.FloatTensor],模型在每个时间步的预测得分列表。
  • **kwargs:其他可选的关键字参数。

处理过程

  • 获取当前时间步(最新)的预测得分 scores[-1],并对其应用 softmax 以获得概率分布 probs = scores[-1].softmax(-1)
  • 获取生成的最后一个 token 的概率 p = probs[0, input_ids[0, -1]].item()(假设 batch_size 为 1)。
  • 判断模型对当前预测的置信度是否低于预设的阈值 assistant_confidence_threshold
    • 如果 p < self.assistant_confidence_threshold,表示模型对当前预测的置信度不够高,需要停止生成。

返回值

  • 返回一个布尔值 TrueFalse,表示是否应停止生成:
    • True 表示置信度低于阈值,应停止生成。
    • False 表示置信度足够高,继续生成。

注意

  • 与其他派生类不同,ConfidenceCriteria__call__ 方法返回的是一个布尔值,而不是 torch.BoolTensor。在使用时需要确保与其他停止准则的返回值类型一致。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值