Sentence Transformer
库升级到了V3,其中对模型训练部分做了优化,使得模型训练和微调更加简单了,跟着官方教程走了一遍,顺利完成向量模型的微调,以下是对官方教程的精炼和总结。
一 所需组件
使用Sentence Transformer
库进行向量模型的微调需要如下的组件:
- 数据数据: 用于训练和评估的数据。
- 损失函数 : 一个量化模型性能并指导优化过程的函数。
- 训练参数 (可选): 影响训练性能和跟踪/调试的参数。
- 评估器 (可选): 一个在训练前、中或后评估模型的工具。
- 训练器 : 将模型、数据集、损失函数和其他组件整合在一起进行训练。
二 数据集
大部分微调用到的数据都是本地的数据集,因此这里只提供本地数据的处理方法。如用其他在线数据可参考相对应的API。
1 数据类型
常见的数据类型为json、csv、parquet,可以使用load_dataset
进行加载:
from datasets import load_dataset
csv_dataset = load_dataset("csv", data_files="my_file.csv")
json_dataset = load_dataset("json", data_files="my_file.json")
parquet_dataset = load_dataset("parquet", data_files="my_file.parquet")
2 数据格式
数据格式需要与损失函数相匹配。如果损失函数需要计算三元组,则数据集的格式为['anchor', 'positive', 'negative']
,且顺序不能颠倒。如果损失函数计算的是句子对的相似度或者标签类别,则数据集中需要包含['label']
或者['score']
,其余列都会作为损失函数的输入。常见的数据格式和损失函数选择见表1。
三 损失函数
从链接整理了一些常见的数据格式和匹配的损失函数
Inputs | Labels | Appropriate Loss Functions |
---|---|---|
(sentence_A, sentence_B) pairs | class | SoftmaxLoss |
(anchor, positive) pairs | none | MultipleNegativesRankingLoss |
(anchor, positive/negative) pairs | 1 if positive, 0 if negative | ContrastiveLoss / OnlineContrastiveLoss |
(sentence_A, sentence_B) pairs | float similarity score | CoSENTLoss / AnglELoss / CosineSimilarityLoss |
(anchor, positive, negative) triplets | none | MultipleNegativesRankingLoss / TripletLoss |
表1 常见的数据格式和损失函数
四 训练参数
配置训练参数主要是用于提升模型的训练效果,同时可以显示训练过程的进度或者其他参数信息,方便调试。