机器学习翻译任务中的constrain decoding 实现流程方法详解

前言

在翻译任务中(从一种格式的语言翻译到另一种格式的语言,并不局限于国家自然语言的互相转换,从自然语言翻译到程序代码也算一种,比如Text2SQL任务),通常会涉及到格式化语言的输出,这时候需要constrain decoding算法来避免模型输出词汇/语法错误的sequence,按照我们给出的一系列准则,使模型输出我们希望的结果。

本文主要介绍项目中本人调研并使用过的constrain decoding算法,介绍其实现流程以及优缺点,为后来者提供参考和借鉴。

一、(Vectorized) Lexically constrained decoding with dynamic beam allocation

官方Readme文档参考链接戳这里
这一方法目前已经集成到了fairseq源码中,可以直接调用,它的使用原理很简单,给定一组候选词汇list, 模型generate时必将输出包含list里所有内容的sequence。如果传入ordered参数,意味着模型在最后输出的sequence里出现的候选词的顺序和list给出的顺序是一致的,在Readme里体现为:
在这里插入图片描述
但是MD文档里只介绍了fairseq-interactive的使用方法,并没有给出python代码的调用接口,在此给出相关实现方式供参考:

from fairseq.models.bart import BARTModel
from fairseq.token_generation_constraints import pack_constraints

checkpoint_dir="xxxx/xxxx/xxxx" #模型和词典的存放目录
model = BARTModel.from_pretrained(model_name_or_path=checkpoint_dir, checkpoint_file="checkpoint_best.pt").eval()
model.cuda()

constrain_words = ["aaa","bbb","ccc",...]

result = model.translate(
        sentences=sequences,
        beam=5,
        constraints = "unordered",
        inference_step_args = {
            "constraints":pack_constraints([
                    [model.encode("{}".format(w))[1:-1] for w in constrain_words] #这里的[1:-1]主要针对BART预加载模型的处理,去掉<s>和</s>起止符号,也就是只有token本身需要encode进去,其他的预加载模型需要根据其特点灵活处理
                ])
        }
    )

二、GENRE constrain decoding

官方文档链接戳这里
⚠️ 首先需要注意fairseq版本需要用他给出的:
https://github.com/facebookresearch/GENRE/tree/main/examples_genre

git clone --branch fixing_prefix_allowed_tokens_fn https://github.com/nicola-decao/fairseq
cd fairseq
pip install --editable ./

他在constrain decoding时的主要思路是,在模型做generate的每一个step时,根据需求传入可选的token list,当前step可选的token list主要取决于之前生成的prev output tokens,也就是下面的参数sent,根据我们提供的候选词list,构造前缀字典树Trie。
每一次step选择合法的候选词的关键函数是prefix_allowed_tokens_fn(batch_id, sent),我们需要根据具体的需求修改内部逻辑。
其主要在apply_mask中调用:
在这里插入图片描述
我们需要根据提供的候选词构造出它所需要的前缀树Trie,参考代码:

from genre.trie import Trie
from genre.entity_linking import get_end_to_end_prefix_allowed_tokens_fn_fairseq
from genre.fairseq_model import GENRE

model = GENRE.from_pretrained(model_name_or_path=checkpoint_dir, checkpoint_file=ckpt_file).eval()
model.cuda(f'cuda:{0}')

sentences = ['Hello world']
constrain_words = ["aaa","bbb","ccc",...]

example_trie=Trie([
        model.encode("{}".format(w))[:].tolist() for w in constrain_words
    ])

prefix_allowed_tokens_fn = get_end_to_end_prefix_allowed_tokens_fn_fairseq(
    model,
    sentences,
    mention_trie=example_trie,
)

results = model.translate(
        sentences=sentences,
        beam=5,
        prefix_allowed_tokens_fn=prefix_allowed_tokens_fn
)

如有疑问可在公众号后台留言,期待您的关注!回复【关键词】可免费领取本人上传至CSDN的所有资料~

在这里插入图片描述


<think>好的,我现在要处理用户关于Arduino中constrain用法的询问。首先,我需要回忆一下自己对constrain函数的理解,确保信息准确。记得constrain是用来限制变量的范围,防止其超出设定的最小值和最大值。不过,我需要确认它的具体参数和返回值,避免记忆错误。 接下来,用户可能想知道函数的基本语法,应该先给出函数原型,比如constrain(x, a, b),然后解释每个参数的意义。这里要注意参数的数据类型,通常x、a、b都是整数或浮点数,但可能需要说明是否支持其他类型。另外,要明确返回值是怎样的,比如当x在a和b之间时返回x,否则返回a或b中的一个。需要特别指出a和b的大小顺序是否影响结果,比如不管a和b谁大谁小,函数都能正确处理,这可能是一个容易误解的地方。 然后,用户可能想了解具体的应用场景,比如在传感器读数时限制数值范围,防止执行器过载。这时候举几个例子会比较有帮助。比如,当传感器读数超过预期时,用constrain将其限制在有效范围内。另外,PWM输出也是一个常见用例,确保输出值在0-255之间。 还需要考虑用户可能的常见错误,比如错误地使用a和b的顺序,或者误解函数的行为。这时候需要强调函数会自动处理a和b的大小,无论输入顺序如何,都会正确限制到[min, max]的范围。同时,提醒用户如果a等于b的情况,函数会如何处理,比如返回a(或b)的值。 另外,可能需要对比其他方法,比如手动使用if-else语句来实现相同的功能,说明使用constrain的好处是代码简洁,减少出错的可能性。同时,提到性能方面,虽然constrain可能内部也是简单的比较,但使用函数可以提高代码可读性。 用户还可能想知道这个函数属于哪个库,是否需要包含头文件。通常,constrain是Arduino核心库的一部分,不需要额外包含,这点需要说明清楚。 在结构上,按照用户的要求,回答要分步骤,结构清晰。可能需要分点说明函数的作用、语法、参数、返回值、示例以及注意事项。每个部分要简明扼要,用例子辅助解释,确保用户容易理解。 最后,检查是否有遗漏的信息,比如是否适用于不同的数据类型(int, float),或者是否有其他相关函数,比如map函数,但可能不需要深入,除非用户提到。同时,确保所有LaTeX公式正确,但根据用户的问题,可能不需要太多数学表达式,除非在示例中涉及计算,这时候要注意使用正确的格式。 总结一下,我需要分步骤解释constrain的用法,包括语法、参数、返回值、示例和注意事项,确保信息准确且易于理解。同时,注意回答的中文表达和结构清晰,符合用户的要求。</think>### Arduino中`constrain()`函数的用法详解 #### 1. **作用** `constrain()`函数用于将变量的值限制在指定的最小值和最大值之间。 - 若输入值在范围内,则返回原值; - 若超出范围,则返回边界值(最小值或最大值)。 #### 2. **语法** ```cpp constrain(value, min, max); ``` - **参数说明**: - `value`:待限制的变量(支持`int`、`float`等数值类型)。 - `min`:允许的最小值。 - `max`:允许的最大值。 - **返回值**: - 当 `value` $\in [\text{min}, \text{max}]$ 时,返回 `value`; - 当 `value < min` 时,返回 `min`; - 当 `value > max` 时,返回 `max`。 #### 3. **关键特性** - **自动处理顺序**:无论 `min` 和 `max` 的大小顺序如何,函数会自动识别并约束到正确区间。 例如:`constrain(5, 10, 0)` 等效于 `constrain(5, 0, 10)`,最终返回 `5`。 - **数据类型兼容性**:适用于整数、浮点数等。 #### 4. **示例** **示例1:限制传感器读数范围** ```cpp int sensorValue = analogRead(A0); int safeValue = constrain(sensorValue, 0, 1023); // 确保数值在0~1023之间 ``` **示例2:限制PWM输出范围** ```cpp int pwmOutput = map(analogRead(A1), 0, 1023, 0, 300); pwmOutput = constrain(pwmOutput, 0, 255); // 确保PWM值不超过255 analogWrite(9, pwmOutput); ``` #### 5. **手动实现对比** 若不使用`constrain()`,需通过条件语句实现: ```cpp int manualConstrain(int value, int min, int max) { if (value < min) return min; else if (value > max) return max; else return value; } ``` 使用内置函数可简化代码并提升可读性。 #### 6. **注意事项** - **边界相等情况**:若 `min == max`,函数始终返回该值。 - **性能优化**:内部实现高效,适合实时控制场景。 #### 7. **适用场景** - 传感器数据校准(如过滤异常值)。 - 执行器保护(如电机速度限制)。 - 数据归一化处理(如映射后范围约束)。 通过合理使用`constrain()`,可增强代码的健壮性,避免硬件或逻辑错误。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值