LLaMA-Factory/scripts/length_cdf.py 源码解析

这段代码定义了一个函数 length_cdf,用来计算和打印数据集样本长度累积分布函数(CDF),并在脚本直接运行时通过 fire 库将该函数暴露为命令行接口。我们逐行解释这段代码:

python

复制

from llmtuner.data import get_dataset
from llmtuner.hparams import get_train_args
from llmtuner.model import load_tokenizer
  • 从 llmtuner 模块中导入 get_datasetget_train_args  load_tokenizer 函数

python

复制

def length_cdf(
    model_name_or_path: str,
    dataset: Optional[str] = "alpaca_en",
    dataset_dir: Optional[str] = "data",
    template: Optional[str] = "default",
    interval: Optional[int] = 1000,
):
  • 定义了一个名为 length_cdf 的函数,接受以下参数:
    • model_name_or_path: 模型的名称或路径(字符串类型)
    • dataset: 数据集的名称,默认为 "alpaca_en"(可选)。
    • dataset_dir: 数据集的目录,默认为 "data"(可选)。
    • template: 模板名称,默认为 "default"(可选)
    • interval: 计算长度分布的区间,默认为 1000(可选)。

python

复制

    model_args, data_args, training_args, _, _ = get_train_args(
        dict(
            stage="sft",
            model_name_or_path=model_name_or_path,
            dataset=dataset,
            dataset_dir=dataset_dir,
            template=template,
            cutoff_len=1_000_000,
            output_dir="dummy_dir",
            overwrite_cache=True,
        )
    )
  • 调用 get_train_args 函数获取模型参数、数据参数和训练参数。
  • 传递一个字典作为参数,其中包括:
    • stage: 训练阶段,这里是 "sft"(假设是某种微调)。
    • model_name_or_path: 模型的名称或路径。
    • dataset: 数据集名称。
    • dataset_dir: 数据集目录。
    • template: 模板名称。
    • cutoff_len: 截断长度,设置为 1,000,000
    • output_dir: 输出目录,这里是 "dummy_dir"
    • overwrite_cache: 是否覆盖缓存,设置为 True

python

复制

    tokenizer = load_tokenizer(model_args)
  • 调用 load_tokenizer 函数加载分词器,传入 model_args

python

复制

    trainset = get_dataset(tokenizer, model_args, data_args, training_args, stage="sft")
  • 调用 get_dataset 函数获取训练数据集,传入 tokenizermodel_argsdata_args 和 training_args,指定 stage 为 "sft"

python

复制

    total_num = len(trainset)
  • 计算训练数据集的总样本数,存储在 total_num 中。

python

复制

    length_dict = defaultdict(int)
  • 创建一个默认值为 0 的字典 length_dict,用来统计样本长度的分布。   
    for sample in tqdm(trainset["input_ids"]):
        length_dict[len(sample) // interval * interval] += 1
  • 遍历训练数据集中的每个样本(input_ids)。
  • 计算每个样本长度所在的区间(len(sample) // interval * interval),并在 length_dict 中进行计数。

python

复制

    length_tuples = list(length_dict.items())
    length_tuples.sort()
  • 将 length_dict 转换为一个包含长度区间和计数的元组列表 length_tuples
  • 对列表进行排序。

python

复制

    count_accu, prob_accu = 0, 0
  • 初始化累积计数 count_accu 和累积概率 prob_accu

python

复制

    for length, count in length_tuples:
        count_accu += count
        prob_accu += count / total_num * 100
        print("{:d} ({:.2f}%) samples have length < {}.".format(count_accu, prob_accu, length + interval))
  • 遍历 length_tuples 中的每个长度区间和计数。
  • 更新累积计数 count_accu 和累积概率 prob_accu

python

复制

        print("{:d} ({:.2f}%) samples have length < {}.".format(count_accu, prob_accu, length + interval))
  • 打印当前长度区间的累积计数和累积概率。输出格式为:count_accu (prob_accu%) samples have length < length + interval.

python

复制

if __name__ == "__main__":
    fire.Fire(length_cdf)
  • 检查脚本是否作为主程序运行。
  • 如果是,则使用 fire.Fire 将 length_cdf 函数暴露为命令行接口,使得可以通过命令行参数直接调用该函数。

总结

这段代码的主要功能是计算数据集样本长度的累积分布函数(CDF),并打印每个长度区间的累积计数和累积概率。具体步骤包括:

  1. 从命令行参数获取模型和数据集的相关参数。
  2. 加载分词器和数据集。
  3. 统计样本长度的分布。
  4. 计算并打印累积分布。

示例用法

假设你将这段代码保存为 length_cdf.py,可以通过以下命令运行并查看结果:

sh

复制

python length_cdf.py --model_name_or_path "path/to/model" --dataset "custom_dataset" --dataset_dir "path/to/data"

这将根据提供的模型和数据集路径,计算并打印数据集中样本长度的累积分布情况。

为什么要用fire库,有什么好处,用和不用有什么区别

GPT-4o

Poe

使用 fire 库的主要目的是简化命令行接口的创建,使得 Python 脚本可以更容易地通过命令行参数进行调用和控制。下面我们详细解释一下 fire 库的好处,以及使用和不使用 fire 库的区别。

fire 库的好处

  1. 自动生成命令行接口

    • fire 可以自动将 Python 对象(函数、类、模块等)转换为命令行接口。开发者无需手动解析命令行参数fire 会自动处理这些工作。
  2. 简化代码

    • 使用 fire 只需一行代码就能将函数或类暴露为命令行接口,极大地简化了代码量。
  3. 易用性

    • fire 提供了一致且直观的方式来调用 Python 脚本,降低了使用命令行参数的复杂性。
  4. 灵活性

    • fire 支持多种调用方式,包括函数、类和模块,使得开发者可以根据需要灵活选择如何组织代码。

使用和不使用 fire 的区别

使用 fire 库

使用 fire 库的代码示例如下:

python

复制

import fire

def greet(name="World"):
    print(f"Hello, {name}!")

if __name__ == "__main__":
    fire.Fire(greet)
  • 当你运行这个脚本时,可以通过命令行参数直接调用 greet 函数并传递参数:

sh

复制

python script.py --name John
  • 输出:

    复制

    Hello, John!
    
不使用 fire 库

如果不使用 fire 库,你需要手动解析命令行参数,例如使用 argparse 库

import argparse

def greet(name="World"):
    print(f"Hello, {name}!")

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Greet someone.")
    parser.add_argument("--name", type=str, default="World", help="Name of the person to greet.")
    args = parser.parse_args()
    greet(args.name)
  • 当你运行这个脚本时,同样可以通过命令行参数传递参数:

sh

复制

python script.py --name John
  • 输出:

    复制

    Hello, John!
    

对比分析

  • 代码简洁性

    • 使用 fire 库,代码更加简洁,只需一行代码即可实现命令行接口。
    • 不使用 fire 库,需要更多的代码来设置和解析命令行参数。
  • 功能和灵活性

    • fire 提供了自动化和简化的功能,适合快速开发和测试。
    • argparse 等库提供了更多的自定义选项和复杂功能,适合复杂的命令行接口需求。

结论

  • 如果你需要快速创建一个简单的命令行接口,fire 是一个很好的选择,因为它可以极大地简化开发过程。
  • 如果你需要更多的控制和自定义选项,比如处理复杂的参数解析逻辑,argparse 或其他专门的命令行解析库可能更适合。

综上所述,选择是否使用 fire 取决于你的具体需求和项目的复杂性。

  • 29
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

AI生成曾小健

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值