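This article focuses on reproducing the ChatGPT training pipeline (SFT, reward model, RL) with Open-Assistant.
Source: 无数据不智能 ("No Data, No Intelligence")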
Environment setup
Create a virtual environment
conda create -n chatgpt python=3.10
conda activate chatgpt
Install dependencies
git clone https://github.com/LAION-AI/Open-Assistant.git
cd Open-Assistant/model
pip install -r model_training/requirements.txt
pip install -r reward/instructor/requirements.txt
Install trlx
git clone https://github.com/CarperAI/trlx.git
cd trlx
pip install torch --extra-index-url https://download.pytorch.org/whl/cu116
pip install -e .
From the Open-Assistant root directory, install oasst-shared:
cd oasst-shared/
pip install -e .
SFT (supervised fine-tuning)
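Taking translation as an example, the prompt templates are:
"zh": [
"翻译成中文: {}",  # "Translate into Chinese: {}"
"{} 这句中文翻译怎麽写?",  # "How do you write {} in Chinese?"
"我需要这句话的中文翻译: {}",  # "I need the Chinese translation of this sentence: {}"
]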
Data sample
[
"<human>" + <a randomly chosen prompt>.format(<source sentence>) + "<bot>",
"<translated sentence>"
]
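A minimal Python sketch of how one SFT sample could be assembled from the templates above, following the `<human>`/`<bot>` format shown. The helper `build_sft_sample` and its `rng` parameter are hypothetical illustrations, not Open-Assistant's actual data-loading code.

```python
import random

# Chinese translation prompt templates, copied from the example above
ZH_PROMPTS = [
    "翻译成中文: {}",
    "{} 这句中文翻译怎麽写?",
    "我需要这句话的中文翻译: {}",
]


def build_sft_sample(source: str, target: str, rng: random.Random) -> list[str]:
    """Return a [prompt, answer] pair in the <human>...<bot> format (hypothetical helper)."""
    template = rng.choice(ZH_PROMPTS)  # pick one template at random for each sample
    prompt = "<human>" + template.format(source) + "<bot>"
    return [prompt, target]


if __name__ == "__main__":
    rng = random.Random(0)  # fixed seed so the chosen template is reproducible
    print(build_sft_sample("How are you?", "你好吗?", rng))
```

Choosing the template at random per sample keeps the model from overfitting to a single phrasing of the instruction.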
Training script (run inside model/model_training)
mkdir cache
mkdir sft_model
python trainer_sft.py --configs defaults pythia --cache_dir ./cache --output_dir ./sft_model
Configuration file
defaults:
  learning_rate: 1e-5
  gradient_checkpointing: false
  gradient_accumulation_steps: 32
  per_device_train_batch_size: 2
  per_device_eval_batch_size: 2
  weight_decay: 0.00
  warmup_steps: 600
  eval_steps: 500
  save_steps: 500
  max_length: 512
  num_train_epochs: 3
  logging_steps: 10
  max_grad_norm: 2.0
  save_total_limit: 4
  fp16: false
  eval_accumulation_steps:
  freeze_layer:
  datasets:
    - webgpt
    - squad_v2
  cache_dir: .cache
  loss_fn: CrossEntropyLoss
  eval_size:
  log_dir: "base"
  quantization: false
  seq2seqmodel: false
  poly_eps: 1.0
  fuse_gelu: true
  log_wandb: true
  samples_mixing: false # uses collator that mixes samples in the batch to create a single sample with possible multiple tasks within
  verbose: false
  output_dir: saved_model

pythia:
  learning_rate: 8e-6
  model_name: EleutherAI/pythia-70m-deduped
  weight_decay: 0.01
  max_length: 520
  warmup_steps: 1000
  gradient_checkpointing: false
  gradient_accumulation_steps: 9
  per_device_train_batch_size: 2
  per_device_eval_batch_size: 4
  output_dir: pythia_model
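The `--configs defaults pythia` flag names two sections of this file. A plausible reading, assumed here rather than taken from the trainer source, is that later sections override earlier ones key by key. The sketch below uses plain dicts and a hypothetical `merge_configs` helper to show the resulting values and the effective batch size per GPU.

```python
def merge_configs(*sections: dict) -> dict:
    """Shallow-merge config sections; later sections win (assumed behaviour)."""
    merged: dict = {}
    for section in sections:
        merged.update(section)
    return merged


# Subsets of the two sections shown above
defaults = {
    "learning_rate": 1e-5,
    "gradient_accumulation_steps": 32,
    "per_device_train_batch_size": 2,
    "max_length": 512,
    "warmup_steps": 600,
}
pythia = {
    "learning_rate": 8e-6,
    "model_name": "EleutherAI/pythia-70m-deduped",
    "max_length": 520,
    "warmup_steps": 1000,
    "gradient_accumulation_steps": 9,
    "per_device_train_batch_size": 2,
}

cfg = merge_configs(defaults, pythia)
# Samples that contribute to one optimizer step on a single GPU
effective_batch = cfg["per_device_train_batch_size"] * cfg["gradient_accumulation_steps"]
print(cfg["learning_rate"], cfg["max_length"], effective_batch)
```

Under that assumption, training uses learning rate 8e-6, max_length 520, and 2 × 9 = 18 samples per optimizer step on each GPU; with data-parallel training, multiply by the number of GPUs.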
RM (reward model)
Data sample
{
"full question text": ["answer 1", "answer 2"] # ranked by score
}
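Since the answers are ranked by score, a common way to feed such a ranking to a reward model is to expand it into (preferred, rejected) pairs, where each answer is preferred over every answer ranked below it. The sketch assumes a best-first ordering; `ranking_to_pairs` is a hypothetical helper, not Open-Assistant's code.

```python
from itertools import combinations


def ranking_to_pairs(question: str, ranked_answers: list[str]) -> list[tuple[str, str, str]]:
    """Expand a best-first ranking into (question, chosen, rejected) triples."""
    # combinations() preserves input order, so the first element of each
    # pair is always the higher-ranked answer
    return [(question, better, worse) for better, worse in combinations(ranked_answers, 2)]


if __name__ == "__main__":
    sample = {"full question text": ["answer 1", "answer 2", "answer 3"]}
    for q, answers in sample.items():
        for triple in ranking_to_pairs(q, answers):
            print(triple)
```

A ranking of n answers yields n(n-1)/2 pairs, so three answers give three comparisons.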
Training script
cd ../reward/instructor
mkdir reward_model
python trainer.py configs/deberta-v3-base.yml --output_dir ./reward_model
Configuration file
model_name: microsoft/deberta-v3-base
learning_rate: 1e-5
scheduler: cosine
gradient_checkpointing: false
gradient_accumulation_steps: 16
per_device_train_batch_size: 2
warmup_steps: 600
eval_steps: 200
save_steps: 500
max_length: 512
num_train_epochs: 2
datasets:
- webgpt
- hfsummary
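Reward models trained on such pairs are commonly optimized with a pairwise ranking loss, loss = -log σ(r_chosen - r_rejected), which is small when the preferred answer gets the higher score. This is a sketch of that standard loss on scalar scores; whether `trainer.py` uses exactly this form is an assumption, and `pairwise_ranking_loss` is a hypothetical name.

```python
import math


def pairwise_ranking_loss(r_chosen: float, r_rejected: float) -> float:
    """-log(sigmoid(r_chosen - r_rejected)), computed stably as softplus(-diff)."""
    x = -(r_chosen - r_rejected)
    # softplus(x) = log(1 + exp(x)); split on the sign of x to avoid overflow
    if x > 0:
        return x + math.log1p(math.exp(-x))
    return math.log1p(math.exp(x))


def mean_pairwise_loss(scores: list[tuple[float, float]]) -> float:
    """Average the loss over a batch of (chosen_score, rejected_score) pairs."""
    return sum(pairwise_ranking_loss(c, r) for c, r in scores) / len(scores)
```

Equal scores give log 2 ≈ 0.693; the loss falls toward 0 as the preferred answer's score pulls ahead and grows roughly linearly when the ranking is inverted.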
RL (reinforcement learning from human feedback)
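Data sample
"<human>" + <a randomly chosen prompt>.format(<source sentence>) + "<bot>"
Only the prompt is given; the policy model generates the answer, which the reward model then scores.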
Training script
cd ../../model_training
python trainer_rl.py --configs defaults_rlhf --cache_dir ./cache --rank_model ../reward/instructor/reward_model --sft_model ../model_training/sft_model --output_dir ./rl_model
Configuration file
defaults_rlhf:
  dataset:
  rank_model: TODO
  sft_model: TODO
  eval_prompts:
  batch_size: 64
  epochs: 10
  datasets:
    - oa_private:
        split: rl
        val_split: 0.0
        fraction: 1
        file: 2023-02-10_oasst_prod.jsonl
  cache_dir: .cache
  quantization: false
  seq2seqmodel: false
  output_dir: output
  reward_model_batch_size: 32

debug_rlhf:
  rank_model: /local/home/sanagnos/general/Open-Assistant/model/reward/instructor/facebook/galactica-125m-finetuned/checkpoint-500/
  sft_model: /local/home/sanagnos/general/Open-Assistant/model/model_training/EleutherAI/pythia-70m-deduped-base-finetuned/checkpoint-20/
  batch_size: 2
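In the RL stage the reward model scores each generated answer. In the usual PPO-based RLHF setup (as described for InstructGPT), that score is combined with a per-token KL penalty against the frozen SFT model so the policy does not drift too far from it. trlx's PPO trainer also applies a KL penalty, but its exact implementation may differ from this sketch; `shaped_rewards` and the default `kl_coef=0.1` are hypothetical.

```python
def shaped_rewards(
    score: float,
    policy_logprobs: list[float],
    ref_logprobs: list[float],
    kl_coef: float = 0.1,  # hypothetical penalty weight
) -> list[float]:
    """Per-token rewards: -kl_coef * (log pi - log pi_ref), plus the RM score on the last token."""
    if not policy_logprobs or len(policy_logprobs) != len(ref_logprobs):
        raise ValueError("need equal-length, non-empty log-prob sequences")
    # Per-token KL estimate between the policy and the frozen SFT (reference) model
    rewards = [-kl_coef * (lp - ref) for lp, ref in zip(policy_logprobs, ref_logprobs)]
    # The reward-model score is a sequence-level signal, credited to the final token
    rewards[-1] += score
    return rewards
```

When the policy matches the SFT model exactly, the KL terms vanish and only the reward-model score remains on the last token.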
Related links
CarperAI/trlx: A repo for distributed training of language models with Reinforcement Learning via Human Feedback (RLHF) (github.com)
microsoft/DeepSpeed: DeepSpeed is a deep learning optimization library that makes distributed training and inference easy, efficient, and effective. (github.com)
TimDettmers/bitsandbytes: 8-bit CUDA functions for PyTorch (github.com)
huggingface/evaluate: 🤗 Evaluate: A library for easily evaluating machine learning models and datasets. (github.com)
wkentaro/gdown: Download a large file from Google Drive (curl/wget fails because of the security notice). (github.com)
wandb/wandb: 🔥 A tool for visualizing and tracking your machine learning experiments. This repo contains the CLI and Python API. (github.com)
huggingface/transformers: 🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX. (github.com)
pytorch/pytorch: Tensors and Dynamic neural networks in Python with strong GPU acceleration (github.com)