在config
中定义了一系列与数据集、模型参数和微调相关的数据类(dataclass),这些基础的数据类在后面用于配置和管理深度学习模型训练和推理过程中的各种选项。下面是对这些类的编写思路、编写目的和作用的详细分析:
编写思路:
-
数据类定义:使用
dataclass
装饰器定义易于管理和存储配置参数的类有用于配置训练和评估模型时的数据输入参数的数据类、初始化训练所需的数据集信息的数据类等
-
灵活性和参数验证:在每个数据类的
__post_init__
方法中,进行参数的初始化和验证,确保提供的配置是合法的。 -
支持JSON序列化:在关键的配置类中,提供了从JSON文件加载和保存到JSON文件的方法,便于配置的持久化和再利用,如
@classmethod def load_from_json(cls, json_path: str): """ 从指定路径加载JSON文件,创建并返回一个新的类实例。 Creates an instance from the content of `json_path`. """ with open(json_path, "r", encoding="utf-8") as f: text = f.read() return cls(**json.loads(text))
-
参数帮助信息:每个字段都通过
metadata
提供了帮助信息,有助于其他组员理解组件功能
编写目的:
-
标准化配置管理:通过定义标准的配置接口,简化模型训练和推理的配置过程。
-
增强代码的可读性和可维护性:使用数据类来组织相关参数,使代码结构更清晰,更易于维护和扩展。
-
提高用户体验:通过详细的参数帮助信息,帮助用户正确配置和使用模型。
-
支持复杂的训练策略:通过详细的微调和生成参数设置,支持从简单的微调到复杂的LoRA和量化微调策略。
定义类:
-
DatasetAttr
类:管理数据集属性,如加载来源、名称、校验码等,支持数据加载时的验证和配置。@dataclass class DatasetAttr: ''' 用于存储数据集属性,包含数据加载来源,名称,SHA1校验码等 ''' load_from: str dataset_name: Optional[str] = None dataset_sha1: Optional[str] = None source_prefix: Optional[str] = None def __repr__(self) -> str: return self.dataset_name def __post_init__(self): self.prompt_column = "instruction" self.query_column = "input" self.response_column = "output" self.history_column = None
-
ModelArguments
类:配置模型相关参数,如模型路径、分词器选项、量化设置等,为模型训练和推理提供必要的配置支持@dataclass class ModelArguments: """ Arguments pertaining to which model/config/tokenizer we are going to fine-tune. 配置模型路径和缓存目录:指定预训练模型的路径或标识符以及缓存目录。 配置分词器:选择是否使用快速分词器。 配置身份验证:选择是否使用身份验证令牌。 配置模型版本:指定使用的模型版本。 配置量化选项:选择量化位数、类型以及是否使用双重量化。 配置检查点目录:指定保存检查点的目录。 配置奖励模型路径:指定奖励模型的路径。 配置微调选项:选择是否从上次的LoRA权重恢复训练。 配置训练损失绘图:选择是否在微调后绘制训练损失图。 """ ......... def __post_init__(self): if self.checkpoint_dir is not None: # support merging multiple lora weights self.checkpoint_dir = [cd.strip() for cd in self.checkpoint_dir.split(",")] if self.quantization_bit is not None: assert self.quantization_bit in [4, 8], "We only accept 4-bit or 8-bit quantization."
-
DataTrainingArguments
类:设置训练和评估时的数据处理参数,如数据集、缓存策略、序列长度等,确保数据按预期被处理和使用。@dataclass class DataTrainingArguments: """ Arguments pertaining to what data we are going to input our model for training and evaluation. 用于配置训练和评估模型时的数据输入参数。 数据集配置:指定使用的数据集名称、目录和数据拆分方式(如训练集或验证集)。 缓存覆盖:是否覆盖缓存的训练和评估集。 预处理:设置预处理的工作线程数量。 序列长度:设置输入和输出序列的最大长度。 调试:为了调试目的,截断每个数据集的样本数量。 评估:设置用于评估的 beam search 数量。 损失计算:是否在损失计算中忽略填充标签。 源前缀:为每个源文本添加前缀。 开发集比例:设置开发集的比例。 提示模板:设置用于构建训练和推理提示的模板。 """ ...... def init_for_training(self): # support mixing multiple datasets """ 方法用于初始化训练所需的数据集信息,支持混合多个数据集。 """ dataset_names = [ds.strip() for ds in self.dataset.split(",")] with open(os.path.join(self.dataset_dir, "dataset_info.json"), "r") as f: dataset_info = json.load(f) if self.source_prefix is not None: prefix_list = self.source_prefix.split("|") prefix_list = prefix_list * len(dataset_names) if len(prefix_list) == 1 else prefix_list assert len(prefix_list) == len(dataset_names), "The number of prefixes should be either identical with datasets or 1." else: prefix_list = [None] * len(dataset_names) self.dataset_list: List[DatasetAttr] = [] for i, name in enumerate(dataset_names): if name not in dataset_info: raise ValueError("Undefined dataset {} in dataset_info.json.".format(name)) if "hf_hub_url" in dataset_info[name]: dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"]) elif "script_url" in dataset_info[name]: dataset_attr = DatasetAttr("script", dataset_name=dataset_info[name]["script_url"]) else: dataset_attr = DatasetAttr( "file", dataset_name=dataset_info[name]["file_name"], dataset_sha1=dataset_info[name].get("file_sha1", None) ) dataset_attr.source_prefix = prefix_list[i] if "columns" in dataset_info[name]: dataset_attr.prompt_column = dataset_info[name]["columns"].get("prompt", None) dataset_attr.query_column = dataset_info[name]["columns"].get("query", None) dataset_attr.response_column = dataset_info[name]["columns"].get("response", None) dataset_attr.history_column = dataset_info[name]["columns"].get("history", None) self.dataset_list.append(dataset_attr)
-
FinetuningArguments
类:详细配置微调策略,支持不同的微调方法,如冻结某些层的训练或使用LoRA技术进行微调。class FinetuningArguments: """ Arguments pertaining to which techniques we are going to fine-tuning with. 微调方法:配置不同的微调方法,如 none, freeze, lora, full。 隐藏层数:指定模型中解码器块的数量。 可训练层数:指定在 freeze 微调方法中可训练的层数。 LoRA参数:配置LoRA微调的相关参数,包括 rank, alpha, dropout 等。 目标模块:指定应用LoRA的目标模块。 """
-
GeneratingArguments
类:详细配置生成文本的解码参数,如采样方式、温度、beam搜索等,以控制文本生成的质量和多样性。@dataclass class GeneratingArguments: """ Arguments pertaining to specify the decoding parameters. 配置文本生成时的解码参数 采样:是否使用采样方法生成文本,否则使用贪婪解码。 温度:调整下一个词的概率分布,影响生成的多样性。 top-p:保留概率和不低于 top_p 的最可能词集。 top-k:保留前 top_k 个最可能的词。 beam search:使用 beam search 生成文本的 beam 数量。 最大长度:生成的最大 token 长度。 新 token 数:生成的最大新 token 数,不包括提示的 token。 重复惩罚:调整重复生成的惩罚参数。 长度惩罚:beam search 中对长度的指数惩罚。 """ do_sample: Optional[bool] = field( default=False, metadata={"help": "Whether or not to use sampling, use greedy decoding otherwise."} ) temperature: Optional[float] = field( default=0.05, metadata={"help": "The value used to modulate the next token probabilities."} ) top_p: Optional[float] = field( default=0.7, metadata={"help": "The smallest set of most probable tokens with probabilities that add up to top_p or higher are kept."} ) top_k: Optional[int] = field( default=70, metadata={"help": "The number of highest probability vocabulary tokens to keep for top-k filtering."} ) num_beams: Optional[int] = field( default=4, metadata={"help": "Number of beams for beam search. 1 means no beam search."} ) max_length: Optional[int] = field( default=4096, metadata={"help": "The maximum length the generated tokens can have. It can be overridden by max_new_tokens."} ) max_new_tokens: Optional[int] = field( default=2048, metadata={"help": "The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt."} ) repetition_penalty: Optional[float] = field( default=1.5, metadata={"help": "The parameter for repetition penalty. 1.0 means no penalty."} ) length_penalty: Optional[float] = field( default=1.5, metadata={"help": "Exponential penalty to the length that is used with beam-based generation."} ) def to_dict(self) -> Dict[str, Any]: args = asdict(self) if args.get("max_new_tokens", None): args.pop("max_length", None) return args
上面的工作定义了一系列结构化的配置类,为后面的深度学习模型的训练和推理提供了标准化和灵活的配置管理方案。在多样化的应用场景中快速部署和优化模型。