GRPO完整实验流程 swift

本文从较为简单的数学任务 Coundown Game 出发,从数据集定义、奖励函数定义和GRPO训练几个步骤介绍完整的GRPO训练流程。任务定义和训练参数等参考了 mini-deepseek-r1

任务与数据集定义

Coundown Game 的任务目标是根据给定的几个数字和加减乘除四种运算,得到目标数字,因此,我们定义数据集如下:

class CoundownTaskPreprocessor(ResponsePreprocessor):

    def preprocess(self, row: Dict[str, Any]) -> Dict[str, Any]:
        numbers = row['nums']
        target = row.pop('response', None)
        query = f"""
        Using the numbers {numbers}, create an equation that equals {target}.
        You can use basic arithmetic operations (+, -, *, /) and each number can only be used once.
        Show your work in <think> </think> tags. And return the final equation and answer in <answer> </answer> tags,
        for example <answer> (1 + 2) / 3 * 4 = 4 </answer>.
        """
        row.update({'target': target, 'query': query})
        return super().preprocess(row)

register_dataset(
    DatasetMeta(
        ms_dataset_id='zouxuhong/Countdown-Tasks-3to4',
        subsets=['default'],
        preprocess_func=CoundownTaskPreprocessor(),
        tags=['math']))

通过 template, 使用 numbers 和 target 完成任务定义,并给到 query 字段供模型采样使用。同时,我们需要保留 nums 和 target 两个字段,用于后续的奖励函数计算。

奖励函数定义:

本任务使用的奖励函数有两个,一个是 Deepseek-R1 中提到的格式奖励函数,另一是 Coundown Game 的准确性奖励函数。前者已经在swift中内置,通过 --reward_funcs format 可以直接使用,而后者需要我们自己定义,在这里我们使用 external_plugin 的方式定义准确性奖励函数,将代码放在swift/examples/train/grpo/plugin/plugin.py中。

在这里,奖励函数的输入包括 completions、target 和 nums 三个字段,分别表示模型生成的文本、目标答案和可用的数字。每个都是list,支持多个 completion 同时计算。注意,在这里,除了 completions 之外的参数都是数据集中定义的字段透传而来,如果有任务上的变动,可以分别对数据集和奖励函数做对应的改变即可。

class CountdownORM(ORM):
    def __call__(self, completions, target, nums,
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值