医疗大模型实战——MedicalGPT项目记录

(logo由DALL·E生成)

        本文是一个医疗大模型项目具体实现的分享,代码使用的是由shibing624提供的开源项目(特别感谢) 。选取这个项目的原因在于其支持的模型结构较多,数据集清晰,大模型全生命周期完整,项目作者活跃度也很高,回答问题速度很快。另外,还有微信交流群以供交流讨论。

        本文选择阿里通义千问模型Qwen-7B-Chat,其被证实拥有不错的中文能力,同时还是比较典型的因果语言模型 (Casual LM),这一点主要区分质谱清言推出的ChatGLM系列。

       在数据集的选取来自1)中文医疗对话数据集Toyhom/Chinese-medical-dialogue-data的六个科室医疗问诊数据, 有79万条;2)在线医疗百科 huatuo_encyclopedia_qa ,有36万条;3)医疗知识图谱 huatuo_knowledge_graph_qa,有79万条。三部分合并,共195万条。本文随机抽取其中十分之一的数据量,即195k。

        实验开始前需要注意的是,本文的内容是想通过训练医疗模型作为背景来完成大模型全生命周期的实现。实际在应用场景,完整走过三个流程训练出的模型并不是第一选择。一方面原因是成本因素,另一方面原因是尤其在增量预训练阶段,往往需要将领域数据+通用数据以避免大模型的灾难遗忘现象。这对于数据集的质量要求很高,同时需要做到完全shuffle。这对于个人玩家甚至是普通企业玩家要求较高。另外一个重要的原因是经过增量预训练大模型的对齐会被破坏需要重新进行。这给模型的训练带来了更多不必要的成本,而领域知识的注入往往只需要通过SFT阶段即可注入,通过SFT可以避免这个现象,通常用企业发布的chat版本的模型进行微调即可,而不需要再完成偏好对齐的步骤。最后,大模型的增量预训需要拥有海量的领域数据,将这些海量数据和通用领域数据混合后的训练数据集极大,这对于普通大模型定制是没有必要的。

       之后计划公布通过SFT训练的真正可用的模型。下面让我们开始吧!


目录

0 硬件配置 (Hardware Configuration)

1 增量预训练 (Incremental Pre-traing)

1.1 实验准备

1.2 开始实验

1.3 参数合并

2 监督微调 (Supervised fine-tuning)

2.1 实验准备

2.2 开始实验

2.3 参数合并

3 偏好对齐 (Human Preferences Alignment)

3.1 实验准备

3.2 开始实验

3.2.1 基于人类反馈的强化学习 (RLHF)

3.2.2 直接偏好优化 (DPO) 

4 参考内容 (References)       


0 硬件配置 (Hardware Configuration)

        本文的硬件环境是一台8卡服务器,拥有8张 NVIDIA GeForce RTX 4090 (24G)显存显卡,由于是共用,故本文大多数情境下并没有完全利用上所有显卡,一般使用4到5块。需要注意的是,由于本文使用的是数据并行,并没有利用到模型并行的技术,故更多的显卡只是令训练更快,理论上,只用一张显卡仍然能够完成本文的所有实验。


1 增量预训练 (Incremental Pre-traing)

        首先需要明确的是,即使使用的模型是经过微调对齐后的模型,增量预训练过程仍会破坏这个“距离”,即经过增量预训练之后的模型将会成为“base model”。没有回答能力,而只有续写能力,没有部署的必要 (这也是为什么选择chat和base模型在这一步区别不大的原因,后续仍需要微调和强化学习来对齐) 。

        现今一个具有竞争力的大模型动辄上百亿的参数量让从头预训练和全参数微调对于普通玩家而言几乎是不可能完成的任务。但是我们如果可以在开源的大模型基础上添加医疗数据进行增量预训练以向模型注入医疗方面的垂直知识是一个不错的选择。shibing624也提出:

        可以看出,增量预训练并不是必选项,且对无监督预训练数据集的质量要求较高。如果不进行增量预训练而只是在STF (supervised fine-tuning) 阶段注入知识也可以让模型学习到相关的领域知识。再次,本文展现的是一个完整的大模型训练周期,故添加了这部分。从训练结果来看,这部分的训练是有效的。

1.1 实验准备

export CUDA_VISIBLE_DEVICES=0,3,4,5,6
accelerate launch pretraining.py \
    --model_type auto \
    --model_name_or_path /data2/dataset/wzx/MedicalGPT/model/Qwen-7B \
    --train_file_dir ./data/pretrain \
    --validation_file_dir ./data/pretrain \
    --per_device_train_batch_size 4 \
    --per_device_eval_batch_size 4 \
    --do_train \
    --do_eval \
    --use_peft True \
    --seed 42 \
    --max_train_samples 10000 \
    --max_eval_samples 10 \
    --num_train_epochs 2 \
    --learning_rate 2e-4 \
    --warmup_ratio 0.05 \
    --weight_decay 0.01 \
    --logging_strategy steps \
    --logging_steps 10 \
    --eval_steps 50 \
    --evaluation_strategy steps \
    --save_steps 500 \
    --save_strategy steps \
    --save_total_limit 13 \
    --gradient_accumulation_steps 4 \
    --preprocessing_num_workers 10 \
    --block_size 512 \
    --output_dir outputs-pt-qwen-7b \
    --overwrite_output_dir \
    --ddp_timeout 30000 \
    --logging_first_step True \
    --target_modules all \
    --lora_rank 8 \
    --lora_alpha 16 \
    --lora_dropout 0.05 \
    --torch_dtype bfloat16 \
    --bf16 \
    --device_map auto \
    --report_to tensorboard \
    --ddp_find_unused_parameters False \
    --gradient_checkpointing True \
    --cache_dir ./cache

        更改的地方包括gradient_accumulation_steps,num_train_epochs,per_device_eval_batch_size,per_device_train_batch_size从实验过程中的显存占用情况来看,设置是合理的。另外使用 RTX 3090 或 4000 系列 GPU 时,PyTorch Accelerate 库存在问题。这些 GPU 不支持通过 P2P (点对点) 或 IB (InfiniBand) 的更快通信宽带。故改用了accelerate launch方法,代码也相应更改。本次实验用到了5张显卡。

1.2 开始实验

        显存占用的比例很高都在20G左右,(其中cuda:6由于同时还有其他例程运行占用的显存较高一些) 。证明超参数的设置比较合理。训练过程大约25分钟,模型的loss从2.9下降到了2.2左右并震荡,训练后的权重文件:             

1.3 参数合并

        合并后的模型如图:

        至此,第一阶段增量预训练阶段结束。


2 监督微调 (Supervised fine-tuning)

2.1 实验准备

        数据集的格式是alpaca (Instruct-Input-Response) 格式,而代码所需的格式是sharegpt格式的。需要转换,同时,代码需要的是jsonl后缀的数据集文件,故还应该转换成jsonl形式,即数据集的每一行都是一个json文件。数据集包含约195K轮对话数据。

        处理后的文件格式:

2.2 开始实验

        首先需要注意的是,代码的模型template是vicuna。而我们选用的qwen模型的template是chatml,具体如图:   

        在代码中,不同template对应的special token处理不同,如果不正确设置的话,代码会报错。例如,如果在本实验中,如果不设置template而使用默认值的话,模型会将每个input_id后接上一个None值,使代码无法继续正常运行:

        Qwen模型的template如图:

        sh设置:

#!/bin/bash

# 设置CUDA可见设备
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6

# 使用accelerate启动训练
accelerate launch --main_process_port 28500 supervised_finetuning.py \
    --model_type auto \
    --model_name_or_path /data2/dataset/wzx/MedicalGPT/merged-gpt-chat/ \
    --train_file_dir ./data/finetune \
    --validation_file_dir ./data/finetune \
    --template_name chatml \
    --per_device_train_batch_size 2 \
    --per_device_eval_batch_size 2 \
    --do_train \
    --use_peft True \
    --model_max_length 1024 \
    --num_train_epochs 2 \
    --learning_rate 2e-5 \
    --warmup_ratio 0.05 \
    --weight_decay 0.05 \
    --logging_strategy steps \
    --logging_steps 10 \
    --save_steps 500 \
    --save_strategy steps \
    --save_total_limit 13 \
    --gradient_accumulation_steps 4 \
    --preprocessing_num_workers 4 \
    --output_dir outputs-sft-qwen-chat-v1 \
    --overwrite_output_dir \
    --ddp_timeout 30000 \
    --logging_first_step True \
    --target_modules all \
    --lora_rank 8 \
    --lora_alpha 16 \
    --lora_dropout 0.05 \
    --torch_dtype float16 \
    --fp16 \
    --device_map auto \
    --report_to tensorboard \
    --ddp_find_unused_parameters False \
    --gradient_checkpointing True \
    --cache_dir ./cache

        将merge后的模型的config设置为fp16设为True,bf16为False,开始训练。本文此阶段使用了7张显卡。具体的显存占用情况如图:

        显存占用情况还是很优秀的,代码运行了4epoch,用时约6小时,0~0.14epoch的loss变化如图:

        可以看出收敛速度还是很快的,代码在第二个epoch开始loss基本稳定在2左右震荡,这里跑了2个epoch:

        最后的loss是2.01左右。取最后一次的checkpoint进行合并。

2.3 参数合并

        这一部分的内容和1.3一致,不再赘述。至此,第二阶段监督微调结束。


3 偏好对齐 (Human Preferences Alignment)

3.1 实验准备

        本文测试了两种人类偏好对齐的方法,分别是基于人类反馈的强化学习 (RLHF) 和 直接偏好优化 (DPO):

        本文成功运行了DPO方法。此阶段的数据集包含两个output,分别是需要被拒绝的低质量回复和应该被接受的高质量回复。

       需要注意的是,在预训练 (Pretrain) 和微调 (Fine-tuning) 的过程中,采用 bfloat16 对模型的性能可能影响不大。这是因为在这些阶段,计算损失函数和梯度时通常涉及到大量的样本和较大批量,这有助于减少因采用低精度数值表示而带来的精确度损失。另一方面,在强化学习和某些模型训练场景中,计算奖励 (reward) 模型时往往只涉及较少的样本和较小的批量。在这些情况下,使用 bfloat16 可能会引起数值的不稳定和精度的损失,进而影响模型的性能。这可能是使用 bfloat16 时遇到损失降为零的原因。实际应用中,是否选择使用 bfloat16 取决于特定的任务和模型。有些任务和模型对数值精度的需求较高,在这种情况下,使用 bfloat16 可能会导致性能降低。然而,对于其他任务和模型,使用 bfloat16 可能不会对性能造成显著影响。因此,建议根据任务和模型的具体需求来确定使用哪种数值精度。总的来说,尽管在预训练和微调阶段可以使用 bfloat16,但在某些特定情况下,如在强化学习中计算奖励模型时,使用 bfloat16 可能会引起数值不稳定和精度损失。为了保证模型性能,建议根据任务和模型的需求选择合适的数值精度。

3.2 开始实验

3.2.1 基于人类反馈的强化学习 (RLHF)

        RLHF(Reinforcement Learning from Human Feedback)是一种结合了人类反馈的强化学习方法。由论文《Training language models to follow instructions with human feedback》引入大模型训练领域。李沐老师对这一方法有详细地讲解,本文不做过多介绍。论文给出的流程图如下:

        分两个阶段:奖励模型建模和强化学习训练。奖励模型建模,构造人类偏好排序数据集,训练奖励模型,以对齐人类偏好,主要是"HHH"原则,具体是"helpful, honest, harmless"  。

        需要注意的是,这个阶段的训练对显存的要求比之前任务更高。此阶段用了6张显卡,将--device_map设为auto开启模型并行,用时约7分钟,显存占用情况如图:

        结果如图:    

         第二阶段强化学习阶段由于模型需要在显卡上加载两次,对显卡显存要求极高,这里没有复现。

3.2.2 直接偏好优化 (DPO) 

        此阶段使用了6张显卡,用时约7分钟,显存占用情况如图:

       最终结果: 


4 参考内容 (References)

       MedicalGPT: Training Medical GPT Model

       通义千问——一个不断进化的AI大模型

        Lora: Low-rank adaptation of large language models

        240万条中文医疗数据集(包括预训练、指令微调和奖励数据集)

        Training language models to follow instructions with human feedback

        Chinese-LLaMA-Alpaca

        LLaMA-Factory

        

  • 45
    点赞
  • 50
    收藏
    觉得还不错? 一键收藏
  • 6
    评论
评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值