模型配置项编写

在config中定义了一系列与数据集、模型参数和微调相关的数据类(dataclass),这些基础的数据类在后面用于配置和管理深度学习模型训练和推理过程中的各种选项。下面是对这些类的编写思路、编写目的和作用的详细分析:

编写思路:

  1. 数据类定义:使用dataclass装饰器定义易于管理和存储配置参数的类

    有用于配置训练和评估模型时的数据输入参数的数据类、初始化训练所需的数据集信息的数据类等

  1. 灵活性和参数验证:在每个数据类的__post_init__方法中,进行参数的初始化和验证,确保提供的配置是合法的。

  2. 支持JSON序列化:在关键的配置类中,提供了从JSON文件加载和保存到JSON文件的方法,便于配置的持久化和再利用,如

    @classmethod
    def load_from_json(cls, json_path: str):
        """
        从指定路径加载JSON文件,创建并返回一个新的类实例。
        Creates an instance from the content of `json_path`.
        """
        with open(json_path, "r", encoding="utf-8") as f:
            text = f.read()
        return cls(**json.loads(text))
  3. 参数帮助信息:每个字段都通过metadata提供了帮助信息,有助于其他组员理解组件功能

编写目的:

  • 标准化配置管理:通过定义标准的配置接口,简化模型训练和推理的配置过程。

  • 增强代码的可读性和可维护性:使用数据类来组织相关参数,使代码结构更清晰,更易于维护和扩展。

  • 提高用户体验:通过详细的参数帮助信息,帮助用户正确配置和使用模型。

  • 支持复杂的训练策略:通过详细的微调和生成参数设置,支持从简单的微调到复杂的LoRA和量化微调策略。

定义类:

  1. DatasetAttr:管理数据集属性,如加载来源、名称、校验码等,支持数据加载时的验证和配置。

    @dataclass
    class DatasetAttr:
    ​
        '''
        用于存储数据集属性,包含数据加载来源,名称,SHA1校验码等
        '''
        load_from: str
        dataset_name: Optional[str] = None
        dataset_sha1: Optional[str] = None
        source_prefix: Optional[str] = None
    ​
        def __repr__(self) -> str:
            return self.dataset_name
    ​
        def __post_init__(self):
            self.prompt_column = "instruction"
            self.query_column = "input"
            self.response_column = "output"
            self.history_column = None

  2. ModelArguments:配置模型相关参数,如模型路径、分词器选项、量化设置等,为模型训练和推理提供必要的配置支持

    @dataclass
    class ModelArguments:
        """
        Arguments pertaining to which model/config/tokenizer we are going to fine-tune.
        配置模型路径和缓存目录:指定预训练模型的路径或标识符以及缓存目录。
        配置分词器:选择是否使用快速分词器。
        配置身份验证:选择是否使用身份验证令牌。
        配置模型版本:指定使用的模型版本。
        配置量化选项:选择量化位数、类型以及是否使用双重量化。
        配置检查点目录:指定保存检查点的目录。
        配置奖励模型路径:指定奖励模型的路径。
        配置微调选项:选择是否从上次的LoRA权重恢复训练。
        配置训练损失绘图:选择是否在微调后绘制训练损失图。
        """
        
        .........
        
        def __post_init__(self):
            if self.checkpoint_dir is not None: # support merging multiple lora weights
                self.checkpoint_dir = [cd.strip() for cd in self.checkpoint_dir.split(",")]
    ​
            if self.quantization_bit is not None:
                assert self.quantization_bit in [4, 8], "We only accept 4-bit or 8-bit quantization."

  3. DataTrainingArguments:设置训练和评估时的数据处理参数,如数据集、缓存策略、序列长度等,确保数据按预期被处理和使用。

    @dataclass
    class DataTrainingArguments:
        """
        Arguments pertaining to what data we are going to input our model for training and evaluation.
        用于配置训练和评估模型时的数据输入参数。
        数据集配置:指定使用的数据集名称、目录和数据拆分方式(如训练集或验证集)。
        缓存覆盖:是否覆盖缓存的训练和评估集。
        预处理:设置预处理的工作线程数量。
        序列长度:设置输入和输出序列的最大长度。
        调试:为了调试目的,截断每个数据集的样本数量。
        评估:设置用于评估的 beam search 数量。
        损失计算:是否在损失计算中忽略填充标签。
        源前缀:为每个源文本添加前缀。
        开发集比例:设置开发集的比例。
        提示模板:设置用于构建训练和推理提示的模板。
        """
    ​
        ......
        
        
        def init_for_training(self): # support mixing multiple datasets
            """
            方法用于初始化训练所需的数据集信息,支持混合多个数据集。
            """
            dataset_names = [ds.strip() for ds in self.dataset.split(",")]
            with open(os.path.join(self.dataset_dir, "dataset_info.json"), "r") as f:
                dataset_info = json.load(f)
    ​
            if self.source_prefix is not None:
                prefix_list = self.source_prefix.split("|")
                prefix_list = prefix_list * len(dataset_names) if len(prefix_list) == 1 else prefix_list
                assert len(prefix_list) == len(dataset_names), "The number of prefixes should be either identical with datasets or 1."
            else:
                prefix_list = [None] * len(dataset_names)
    ​
            self.dataset_list: List[DatasetAttr] = []
            for i, name in enumerate(dataset_names):
                if name not in dataset_info:
                    raise ValueError("Undefined dataset {} in dataset_info.json.".format(name))
    ​
                if "hf_hub_url" in dataset_info[name]:
                    dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"])
                elif "script_url" in dataset_info[name]:
                    dataset_attr = DatasetAttr("script", dataset_name=dataset_info[name]["script_url"])
                else:
                    dataset_attr = DatasetAttr(
                        "file",
                        dataset_name=dataset_info[name]["file_name"],
                        dataset_sha1=dataset_info[name].get("file_sha1", None)
                    )
    ​
                dataset_attr.source_prefix = prefix_list[i]
    ​
                if "columns" in dataset_info[name]:
                    dataset_attr.prompt_column = dataset_info[name]["columns"].get("prompt", None)
                    dataset_attr.query_column = dataset_info[name]["columns"].get("query", None)
                    dataset_attr.response_column = dataset_info[name]["columns"].get("response", None)
                    dataset_attr.history_column = dataset_info[name]["columns"].get("history", None)
    ​
                self.dataset_list.append(dataset_attr)

  4. FinetuningArguments:详细配置微调策略,支持不同的微调方法,如冻结某些层的训练或使用LoRA技术进行微调。

    class FinetuningArguments:
        """
        Arguments pertaining to which techniques we are going to fine-tuning with.
    ​
        微调方法:配置不同的微调方法,如 none, freeze, lora, full。
        隐藏层数:指定模型中解码器块的数量。
        可训练层数:指定在 freeze 微调方法中可训练的层数。
        LoRA参数:配置LoRA微调的相关参数,包括 rank, alpha, dropout 等。
        目标模块:指定应用LoRA的目标模块。
        """

  5. GeneratingArguments:详细配置生成文本的解码参数,如采样方式、温度、beam搜索等,以控制文本生成的质量和多样性。

    @dataclass
    class GeneratingArguments:
        """
        Arguments pertaining to specify the decoding parameters.
        配置文本生成时的解码参数
    ​
        采样:是否使用采样方法生成文本,否则使用贪婪解码。
        温度:调整下一个词的概率分布,影响生成的多样性。
        top-p:保留概率和不低于 top_p 的最可能词集。
        top-k:保留前 top_k 个最可能的词。
        beam search:使用 beam search 生成文本的 beam 数量。
        最大长度:生成的最大 token 长度。
        新 token 数:生成的最大新 token 数,不包括提示的 token。
        重复惩罚:调整重复生成的惩罚参数。
        长度惩罚:beam search 中对长度的指数惩罚。
        """
        do_sample: Optional[bool] = field(
            default=False,
            metadata={"help": "Whether or not to use sampling, use greedy decoding otherwise."}
        )
        temperature: Optional[float] = field(
            default=0.05,
            metadata={"help": "The value used to modulate the next token probabilities."}
        )
        top_p: Optional[float] = field(
            default=0.7,
            metadata={"help": "The smallest set of most probable tokens with probabilities that add up to top_p or higher are kept."}
        )
        top_k: Optional[int] = field(
            default=70,
            metadata={"help": "The number of highest probability vocabulary tokens to keep for top-k filtering."}
        )
        num_beams: Optional[int] = field(
            default=4,
            metadata={"help": "Number of beams for beam search. 1 means no beam search."}
        )
        max_length: Optional[int] = field(
            default=4096,
            metadata={"help": "The maximum length the generated tokens can have. It can be overridden by max_new_tokens."}
        )
        max_new_tokens: Optional[int] = field(
            default=2048,
            metadata={"help": "The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt."}
        )
        repetition_penalty: Optional[float] = field(
            default=1.5,
            metadata={"help": "The parameter for repetition penalty. 1.0 means no penalty."}
        )
        length_penalty: Optional[float] = field(
            default=1.5,
            metadata={"help": "Exponential penalty to the length that is used with beam-based generation."}
        )
    ​
        def to_dict(self) -> Dict[str, Any]:
            args = asdict(self)
            if args.get("max_new_tokens", None):
                args.pop("max_length", None)
            return args

上面的工作定义了一系列结构化的配置类,为后面的深度学习模型的训练和推理提供了标准化和灵活的配置管理方案。在多样化的应用场景中快速部署和优化模型。

  • 18
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
数字乡村和智慧农业的数字化转型是当前农业发展的新趋势,旨在通过应用数字技术,实现农业全流程的再造和全生命周期的管理服务。中国政府高度重视这一领域的发展,提出“数字中国”和“乡村振兴”战略,以提升国家治理能力,推动城乡融合发展。 数字乡村的建设面临乡村治理、基础设施、产业链条和公共服务等方面的问题,需要分阶段实施《数字乡村发展战略纲要》来解决。农业数字化转型的需求包括满足市民对优质农产品的需求、解决产销对接问题、形成优质优价机制、提高农业劳动力素质、打破信息孤岛、提高农业政策服务的精准度和有效性,以及解决农业融资难的问题。 数字乡村建设的关键在于构建“1+3+4+1”工程,即以新技术、新要素、新商业、新农民、新文化、新农村为核心,推进数据融合,强化农业大数据的汇集功能。数字农业大数据解决方案以农业数字底图和数据资源为基础,通过可视化监管,实现区域农业的全面数字化管理。 数字农业大数据架构基于大数据、区块链、GIS和物联网技术,构建农业大数据中心、农业物联网平台和农村综合服务指挥决策平台三大基础平台。农业大数据中心汇聚各类涉农信息资源和业务数据,支持大数据应用。信息采集系统覆盖市、县、乡、村多级,形成高效的农业大数据信息采集体系。 农业物联网平台包括环境监测系统、视频监控系统、预警预报系统和智能控制系统,通过收集和监测数据,实现对农业环境和生产过程的智能化管理。综合服务指挥决策平台利用数据分析和GIS技术,为农业决策提供支持。 数字乡村建设包括三大服务平台:治理服务平台、民生服务平台和产业服务平台。治理服务平台通过大数据和AI技术,实现乡村治理的数字化;民生服务平台利用互联网技术,提供各类民生服务;产业服务平台融合政企关系,支持农业产业发展。 数字乡村的应用场景广泛,包括农业生产过程、农产品流通、农业管理和农村社会服务。农业生产管理系统利用AIoT技术,实现农业生产的标准化和智能化。农产品智慧流通管理系统和溯源管理系统提高流通效率和产品追溯能力。智慧农业管理通过互联网+农业,提升农业管理的科学性和效率。农村社会服务则通过数字化手段,提高农村地区的公共服务水平。 总体而言,数字乡村和智慧农业的建设,不仅能够提升农业生产效率和管理水平,还能够促进农村地区的社会经济发展,实现城乡融合发展,是推动中国农业现代化的重要途径。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值