StoppingCriteria
基类
定义:
StoppingCriteria
是一个抽象基类(Abstract Base Class),用于在文本生成过程中定义停止条件。所有自定义的停止条件类都应继承自该基类,并实现其 __call__
方法。
主要方法:
__call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor
输入参数:
-
input_ids
:torch.LongTensor
,形状为(batch_size, sequence_length)
。- 含义:当前生成的序列的 token IDs。
- 获取方式:可以使用
AutoTokenizer
获取这些 ID,例如通过PreTrainedTokenizer.encode
或直接调用PreTrainedTokenizer.__call__
方法。
-
scores
:torch.FloatTensor
,形状为(batch_size, vocab_size)
。- 含义:模型在当前时间步的预测得分。
- 注意:这些得分可以是在 SoftMax 之前或之后的分数。如果停止条件依赖于
scores
,需要在generate
方法中设置return_dict_in_generate=True
和output_scores=True
。
-
**kwargs
:其他可选的关键字参数。- 含义:特定停止条件所需的其他信息。
处理过程:
- 这是一个抽象方法,需要子类实现具体的逻辑。
- 方法的主要目的是根据输入的
input_ids
、scores
等信息,判断是否满足停止条件。 - 在实现过程中,可以利用
input_ids
和scores
进行自定义的判断逻辑。
返回值:
- 返回一个形状为
(batch_size,)
的torch.BoolTensor
(布尔张量)。- 内容:每个元素对应一个样本,表示该样本是否应停止生成。
True
:表示应停止生成。False
:表示应继续生成。
- 内容:每个元素对应一个样本,表示该样本是否应停止生成。
示例:
下面是一个简单的自定义停止条件示例,假设我们想在生成的序列长度达到 10 时停止:
class MyStoppingCriteria(StoppingCriteria):
def __init__(self, max_length):
self.max_length = max_length
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
current_length = input_ids.shape[1] # 获取当前序列长度
is_done = torch.full(
(input_ids.shape[0],), # batch_size
current_length >= self.max_length, # 判断是否达到最大长度
dtype=torch.bool,
device=input_ids.device
)
return is_done
在这个示例中:
- 输入:
input_ids
和scores
,以及可选的**kwargs
。 - 处理过程:计算当前序列长度,并判断是否达到预设的最大长度。
- 返回值:一个形状为
(batch_size,)
的布尔张量,指示哪些样本应停止生成。
名称 | 类型 | 简要说明 |
---|---|---|
StoppingCriteria | 类 | 所有可在生成过程中应用的停止准则的抽象基类。如果停止准则依赖于scores 输入,需要确保在generate 中传入return_dict_in_generate=True, output_scores=True 。 |
MaxLengthCriteria | 类 | 当生成的序列长度超过指定的max_length 时停止生成。适用于仅解码类型的Transformer模型,这将包括初始提示的长度。 |
MaxTimeCriteria | 类 | 当生成过程超过指定的max_time (以秒为单位)时停止生成。可以指定initial_timestamp 来覆盖生成开始计时的时间。 |
StopStringCriteria | 类 | 当生成的序列包含指定的停止字符串时停止生成。它预处理字符串和tokenizer的词汇表,以找到tokens可以有效完成停止字符串的位置。当生成的序列在末尾包含任何停止字符串时,停止生成。 |
EosTokenCriteria | 类 | 当生成的序列包含end-of-sequence(EOS)标记时停止生成。默认情况下,使用model.generation_config.eos_token_id 。 |
ConfidenceCriteria | 类 | 当模型对当前预测的置信度低于指定的阈值assistant_confidence_threshold 时停止生成,即使尚未达到定义的token数量(num_assistant_tokens )。 |
StoppingCriteriaList | 类 | 停止准则的列表,用于在生成过程中同时应用多个停止条件。包含一个max_length 属性,用于获取最大长度停止准则的max_length 值。 |
validate_stopping_criteria | 函数 | 验证并更新停止准则列表,确保包含最大长度停止准则。如果停止准则列表中没有MaxLengthCriteria ,则添加一个新的MaxLengthCriteria 。 |
以下是 StoppingCriteria
的派生类的输入、返回值,以及它们对数据的处理方式的详细说明:
1. MaxLengthCriteria
输入:
input_ids
:torch.LongTensor
,形状为(batch_size, sequence_length)
。表示当前生成的序列的 token IDs。scores
:torch.FloatTensor
,形状为(batch_size, vocab_size)
。模型在当前时间步的预测得分。**kwargs
:其他可选的关键字参数。
处理过程:
- 计算当前生成序列的长度
cur_len = input_ids.shape[-1]
。 - 判断当前长度是否达到或超过了预设的最大长度
max_length
:- 如果
cur_len >= self.max_length
,表示已经达到或超过最大长度,需要停止生成。
- 如果
- 如果提供了
max_position_embeddings
(模型的最大位置嵌入大小),并且当前长度超过了它,则发出警告提醒可能会超过模型的最大长度。
返回值:
- 返回一个形状为
(batch_size,)
的torch.BoolTensor
,其中每个元素表示对应样本是否应停止生成:- 如果当前序列长度达到或超过了
max_length
,对应的位置为True
,表示应停止生成。 - 否则为
False
,表示继续生成。
- 如果当前序列长度达到或超过了
2. MaxTimeCriteria
输入:
input_ids
:torch.LongTensor
,形状为(batch_size, sequence_length)
。scores
:torch.FloatTensor
,形状为(batch_size, vocab_size)
。**kwargs
:其他可选的关键字参数。
处理过程:
- 计算从生成开始到当前的时间差
elapsed_time = time.time() - self.initial_timestamp
。 - 判断是否超过了预设的最大生成时间
max_time
:- 如果
elapsed_time > self.max_time
,表示已经超过了最大生成时间,需要停止生成。
- 如果
返回值:
- 返回一个形状为
(batch_size,)
的torch.BoolTensor
,其中每个元素表示对应样本是否应停止生成:- 如果已超过最大生成时间,对应的位置为
True
,表示应停止生成。 - 否则为
False
,表示继续生成。
- 如果已超过最大生成时间,对应的位置为
3. StopStringCriteria
输入:
input_ids
:torch.LongTensor
,形状为(batch_size, sequence_length)
。scores
:torch.FloatTensor
,形状为(batch_size, vocab_size)
。**kwargs
:其他可选的关键字参数。
处理过程:
-
预处理(在初始化时完成):
- 接收一个或多个停止字符串
stop_strings
,以及对应的tokenizer
。 - 清理并获取 tokenizer 的词汇表,以得到实际的 token 字符串和对应的索引。
- 预先计算每个 token 在停止字符串中的可能匹配位置,以及可能的结束重叠长度。这部分计算比较复杂,主要目的是为了在生成过程中高效地进行匹配检查。
- 接收一个或多个停止字符串
-
运行时处理:
- 在每次生成新的 token 后,获取最新的
input_ids
,只关注序列末尾可能与停止字符串匹配的部分(取决于停止字符串的最大长度)。 - 反转
input_ids
,方便从序列末尾开始匹配。 - 使用预先计算的嵌入向量
embedding_vec
,通过张量操作检查序列末尾是否与任何一个停止字符串匹配。- 计算 cumulative sum(累积和)以跟踪匹配的字符数。
- 使用掩码(mask)跳过不可能匹配的位置。
- 检查是否有任何一个停止字符串完全匹配,如果是,则需要停止生成。
- 在每次生成新的 token 后,获取最新的
返回值:
- 返回一个形状为
(batch_size,)
的torch.BoolTensor
,其中每个元素表示对应样本是否应停止生成:- 如果在序列末尾匹配到了任何一个停止字符串,对应的位置为
True
,表示应停止生成。 - 否则为
False
,表示继续生成。
- 如果在序列末尾匹配到了任何一个停止字符串,对应的位置为
4. EosTokenCriteria
输入:
input_ids
:torch.LongTensor
,形状为(batch_size, sequence_length)
。scores
:torch.FloatTensor
,形状为(batch_size, vocab_size)
。**kwargs
:其他可选的关键字参数。
处理过程:
- 检查
input_ids
中最后一个 token 是否为结束标记(EOS token)。 eos_token_id
可以是一个整数、整数列表或torch.Tensor
,表示一个或多个结束标记的 ID。- 使用
isin_mps_friendly
函数检查每个样本的最后一个 token 是否在eos_token_id
中。
返回值:
- 返回一个形状为
(batch_size,)
的torch.BoolTensor
,其中每个元素表示对应样本是否应停止生成:- 如果最后一个 token 是 EOS 标记,对应的位置为
True
,表示应停止生成。 - 否则为
False
,表示继续生成。
- 如果最后一个 token 是 EOS 标记,对应的位置为
5. ConfidenceCriteria
输入:
input_ids
:torch.LongTensor
,形状为(batch_size, sequence_length)
。scores
:List[torch.FloatTensor]
,模型在每个时间步的预测得分列表。**kwargs
:其他可选的关键字参数。
处理过程:
- 获取当前时间步(最新)的预测得分
scores[-1]
,并对其应用 softmax 以获得概率分布probs = scores[-1].softmax(-1)
。 - 获取生成的最后一个 token 的概率
p = probs[0, input_ids[0, -1]].item()
(假设 batch_size 为 1)。 - 判断模型对当前预测的置信度是否低于预设的阈值
assistant_confidence_threshold
:- 如果
p < self.assistant_confidence_threshold
,表示模型对当前预测的置信度不够高,需要停止生成。
- 如果
返回值:
- 返回一个布尔值
True
或False
,表示是否应停止生成:True
表示置信度低于阈值,应停止生成。False
表示置信度足够高,继续生成。
注意:
- 与其他派生类不同,
ConfidenceCriteria
的__call__
方法返回的是一个布尔值,而不是torch.BoolTensor
。在使用时需要确保与其他停止准则的返回值类型一致。