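1. Introduction
DeepSeek recently open-sourced R1, a model whose performance approaches that of OpenAI's o1, and the market response has been enthusiastic: one organization after another has integrated DeepSeek and treats the integration as a badge of honor. Amid the hype, even DeepSeek's founder, Liang Wenfeng, seems to have been caught up in it: despite his schedule, he personally uploaded a recently published paper to arXiv.
Hype aside, DeepSeek-R1 is undeniably excellent. In practice, though, it is still some distance from serving real business workloads and does not solve some practical problems well. Some customers have recently reported that DS (DeepSeek) is not as smart as it looks and that a gap remains.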
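2. Fine-tuning DeepSeek V3/R1
Applications in vertical industries will probably still need fine-tuning on industry data. ColossalAI recently released a fine-tuning solution for DS-R1. It mainly supports LoRA fine-tuning of the 671B DeepSeek V3/R1, a complete reinforcement learning toolchain (PPO, GRPO, DPO, SimPO, and others), and the DeepSeek distilled models. Although DeepSeek V3/R1 is large at 671 billion parameters, the fine-tuning steps are similar to those for any other large model.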
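2.1 Data preparation
The dataset is in JSONL format: each line is one sample and represents one conversation. The sample below is wrapped for readability; in the file it sits on a single line:
[{"role": "user", "content": "Can you tell me a story with a profound moral?"},
{"role": "assistant", "content": "Sure, I will write a story with a meaningful moral as you requested. To make the story more complete, I will first outline the basic plot. Please check whether the following matches your expectations. If you have other ideas, feel free to raise them and I will adjust.\n\n### Worldbuilding\n\n- **Setting**: Inside a near-future Bureau of Spacetime Administration, at an observation station for a micro-universe that loops eternally.\n\n- **Core conflict**: A mechanical observer who has awakened discovers the truth of the world and struggles between eternal duty and free will.\n\n- **Special rules**: The stardust in the cosmic hourglass can reshape life, and the observation station ... an eternal spiritual totem.\n\nI hope this story meets your request for a \"profound moral\". If you would like me to adjust the density of some metaphors or add plot twists, just let me know."}]
Custom system prompts are also supported.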
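As a rough illustration of the format, the sketch below builds one conversation and serializes it onto a single JSONL line. The helper names (`make_sample`, `to_jsonl_line`) are hypothetical, and placing the system prompt as a first message with role "system" is an assumption, not something confirmed by the framework's documentation.

```python
import json

# Assumption: these three roles are the ones the dataset loader understands.
ALLOWED_ROLES = {"system", "user", "assistant"}


def make_sample(user_text, assistant_text, system_prompt=None):
    """Build one conversation as a list of role/content messages."""
    messages = []
    if system_prompt is not None:
        # Assumption: a custom system prompt is passed as the first message.
        messages.append({"role": "system", "content": system_prompt})
    messages.append({"role": "user", "content": user_text})
    messages.append({"role": "assistant", "content": assistant_text})
    return messages


def to_jsonl_line(messages):
    """Validate a conversation and serialize it onto exactly one line."""
    for msg in messages:
        if msg.get("role") not in ALLOWED_ROLES:
            raise ValueError(f"unexpected role: {msg.get('role')!r}")
        if not isinstance(msg.get("content"), str):
            raise ValueError("content must be a string")
    # json.dumps escapes embedded newlines as \n, so the line never breaks.
    return json.dumps(messages, ensure_ascii=False)


sample = make_sample(
    "Tell me a story.",
    "Once upon a time...\nThe end.",
    system_prompt="You are a storyteller.",
)
line = to_jsonl_line(sample)
print(line)
```

Writing one such line per sample to a `.jsonl` file gives a dataset in the shape described above.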
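2.2 Preparing model weights
BF16 (bfloat16) weights are recommended for fine-tuning. If you have already downloaded the DeepSeek V3/R1 weights in FP8, you can convert them to BF16 with the official conversion script.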
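A back-of-the-envelope estimate shows why the choice of format matters for storage and memory planning. The sketch below only multiplies the stated 671B parameter count by the bytes per parameter; it ignores the scale tensors and assumes every weight uses the same dtype, which is a simplification.

```python
PARAMS = 671e9  # parameter count stated in this article


def weight_bytes(num_params, bytes_per_param):
    """Raw storage needed for the weights alone."""
    return num_params * bytes_per_param


fp8_tb = weight_bytes(PARAMS, 1) / 1e12   # FP8: 1 byte per parameter
bf16_tb = weight_bytes(PARAMS, 2) / 1e12  # BF16: 2 bytes per parameter

print(f"FP8:  ~{fp8_tb:.2f} TB")
print(f"BF16: ~{bf16_tb:.2f} TB")
```

Under these assumptions the BF16 copy is roughly twice the size of the FP8 download, so the output directory needs that much free space.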
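The DeepSeek FP8 → BF16 conversion script (requires a GPU) follows. The code is long, but the part to focus on is the dequantization kernel weight_dequant_kernel. FP8 is a low-precision floating-point format used to store quantized weights or activations. Dequantization converts the FP8 values back to a higher-precision format, which is BF16 in this script (it calls torch.set_default_dtype(torch.bfloat16), and the output buffer uses the default dtype). The core formula is y = x⋅s, where x is the quantized FP8 data, s is the scale factor (stored as a `_scale_inv` tensor, one value per 128×128 block by default) that maps the FP8 data back to the BF16 range, and y is the dequantized BF16 output.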
import os
import json
from argparse import ArgumentParser
from glob import glob
from tqdm import tqdm
import torch
from safetensors.torch import load_file, save_file
from kernel import weight_dequant
def main(fp8_path, bf16_path):
    """
    Converts FP8 weights to BF16 and saves the converted weights.

    This function reads FP8 weights from the specified directory, converts them to BF16,
    and saves the converted weights to another specified directory. It also updates the
    model index file to reflect the changes.

    Args:
        fp8_path (str): The path to the directory containing the FP8 weights and model index file.
        bf16_path (str): The path to the directory where the converted BF16 weights will be saved.

    Raises:
        KeyError: If a required scale_inv tensor is missing for a weight.

    Notes:
        - The function assumes that the FP8 weights are stored in safetensor files.
        - The function caches loaded safetensor files to optimize memory usage.
        - The function updates the model index file to remove references to scale_inv tensors.
    """
    torch.set_default_dtype(torch.bfloat16)
    os.makedirs(bf16_path, exist_ok=True)
    model_index_file = os.path.join(fp8_path, "model.safetensors.index.json")
    with open(model_index_file, "r") as f:
        model_index = json.load(f)
    weight_map = model_index["weight_map"]

    # Cache for loaded safetensor files
    loaded_files = {}
    fp8_weight_names = []

    # Helper function to get tensor from the correct file
    def get_tensor(tensor_name):
        """
        Retrieves a tensor from the cached safetensor files or loads it from disk if not cached.

        Args:
            tensor_name (str): The name of the tensor to retrieve.

        Returns:
            torch.Tensor: The retrieved tensor.

        Raises:
            KeyError: If the tensor does not exist in the safetensor file.
        """
        file_name = weight_map[tensor_name]
        if file_name not in loaded_files:
            file_path = os.path.join(fp8_path, file_name)
            loaded_files[file_name] = load_file(file_path, device="cuda")
        return loaded_files[file_name][tensor_name]

    safetensor_files = list(glob(os.path.join(fp8_path, "*.safetensors")))
    safetensor_files.sort()
    for safetensor_file in tqdm(safetensor_files):
        file_name = os.path.basename(safetensor_file)
        current_state_dict = load_file(safetensor_file, device="cuda")
        loaded_files[file_name] = current_state_dict

        new_state_dict = {}
        for weight_name, weight in current_state_dict.items():
            if weight_name.endswith("_scale_inv"):
                continue
            elif weight.element_size() == 1:  # FP8 weight
                scale_inv_name = f"{weight_name}_scale_inv"
                try:
                    # Get scale_inv from the correct file
                    scale_inv = get_tensor(scale_inv_name)
                    fp8_weight_names.append(weight_name)
                    new_state_dict[weight_name] = weight_dequant(weight, scale_inv)
                except KeyError:
                    print(f"Warning: Missing scale_inv tensor for {weight_name}, skipping conversion")
                    new_state_dict[weight_name] = weight
            else:
                new_state_dict[weight_name] = weight

        new_safetensor_file = os.path.join(bf16_path, file_name)
        save_file(new_state_dict, new_safetensor_file)

        # Memory management: keep only the 2 most recently used files
        if len(loaded_files) > 2:
            oldest_file = next(iter(loaded_files))
            del loaded_files[oldest_file]
            torch.cuda.empty_cache()

    # Update model index
    new_model_index_file = os.path.join(bf16_path, "model.safetensors.index.json")
    for weight_name in fp8_weight_names:
        scale_inv_name = f"{weight_name}_scale_inv"
        if scale_inv_name in weight_map:
            weight_map.pop(scale_inv_name)
    with open(new_model_index_file, "w") as f:
        json.dump({"metadata": {}, "weight_map": weight_map}, f, indent=2)


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--input-fp8-hf-path", type=str, required=True)
    parser.add_argument("--output-bf16-hf-path", type=str, required=True)
    args = parser.parse_args()
    main(args.input_fp8_hf_path, args.output_bf16_hf_path)
# kernel.py (excerpt): the dequantization kernel imported by the script above
import torch
import triton
import triton.language as tl
from triton import Config


@triton.jit
def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr):
    """
    Dequantizes weights using the provided scaling factors and stores the result.

    Args:
        x_ptr (tl.pointer): Pointer to the quantized weights.
        s_ptr (tl.pointer): Pointer to the scaling factors.
        y_ptr (tl.pointer): Pointer to the output buffer for dequantized weights.
        M (int): Number of rows in the weight matrix.
        N (int): Number of columns in the weight matrix.
        BLOCK_SIZE (tl.constexpr): Size of the block for tiling.

    Returns:
        None
    """
    # Each program instance handles one BLOCK_SIZE x BLOCK_SIZE tile.
    pid_m = tl.program_id(axis=0)
    pid_n = tl.program_id(axis=1)
    n = tl.cdiv(N, BLOCK_SIZE)  # number of tiles (and scales) per row of tiles
    offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    offs_n = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    offs = offs_m[:, None] * N + offs_n[None, :]
    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    x = tl.load(x_ptr + offs, mask=mask).to(tl.float32)
    s = tl.load(s_ptr + pid_m * n + pid_n)  # one scale per tile
    y = x * s
    tl.store(y_ptr + offs, y, mask=mask)


def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> torch.Tensor:
    """
    Dequantizes the given weight tensor using the provided scale tensor.

    Args:
        x (torch.Tensor): The quantized weight tensor of shape (M, N).
        s (torch.Tensor): The scale tensor with one entry per block, i.e. of shape
            (ceil(M / block_size), ceil(N / block_size)), as indexed by the kernel.
        block_size (int, optional): The block size to use for dequantization. Defaults to 128.

    Returns:
        torch.Tensor: The dequantized weight tensor of the same shape as `x`.

    Raises:
        AssertionError: If `x` or `s` are not contiguous or if their dimensions are not 2.
    """
    assert x.is_contiguous() and s.is_contiguous(), 'Input tensors must be contiguous'
    assert x.dim() == 2 and s.dim() == 2, 'Input tensors must have 2 dimensions'
    M, N = x.size()
    y = torch.empty_like(x, dtype=torch.get_default_dtype())
    grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE']), triton.cdiv(N, meta['BLOCK_SIZE']))
    weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size)
    return y


# Autotuning configurations for the FP8 GEMM kernel in the same file (not used by dequantization).
fp8_gemm_configs = [
    Config({'BLOCK_SIZE_M': block_m, 'BLOCK_SIZE_N': block_n, 'BLOCK_SIZE_K': 128}, num_stages=num_stages, num_warps=8)
    for block_m in [16, 32, 64] for block_n in [32, 64, 128] for num_stages in [3, 4, 5, 6]
]
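To make the block-wise formula y = x⋅s concrete without a GPU, here is a minimal pure-Python reference. It assumes the FP8 values have already been decoded to ordinary floats and uses a tiny block size for readability; `dequant_reference` is a hypothetical name, not part of the official script.

```python
def dequant_reference(x, s, block_size=128):
    """Block-wise dequantization: every element is multiplied by its block's scale.

    x: M x N nested list of floats (FP8 values, assumed already decoded).
    s: nested list of scales with ceil(M/block_size) rows and ceil(N/block_size) columns.
    """
    num_rows, num_cols = len(x), len(x[0])
    y = [[0.0] * num_cols for _ in range(num_rows)]
    for i in range(num_rows):
        for j in range(num_cols):
            # Integer division maps an element to the block that owns it.
            y[i][j] = x[i][j] * s[i // block_size][j // block_size]
    return y


x = [
    [1.0, 2.0, 3.0, 4.0],
    [5.0, 6.0, 7.0, 8.0],
    [1.0, 1.0, 1.0, 1.0],
]
# 3 x 4 matrix with block_size=2 -> a 2 x 2 grid of scales.
s = [
    [0.5, 2.0],
    [10.0, 100.0],
]
y = dequant_reference(x, s, block_size=2)
for row in y:
    print(row)
```

The Triton kernel performs the same multiplication, but each program instance processes one whole tile in parallel and masks out elements that fall outside the matrix.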
2.3 Launching fine-tuning
Colossal-AI provides a LoRA fine-tuning script that is compatible with Hugging Face PEFT. Run the following command to start fine-tuning:
colossalai run --hostfile path-to-host-file --nproc_per_node 8 \
lora_finetune.py --pretrained path-to-DeepSeek-R1-bf16 \
--dataset path-to-dataset.jsonl --plugin moe \
--lr 2e-5 --max_length 256 -g --ep 8 --pp 3 \
--batch_size 24 --lora_rank 8 --lora_alpha 16 \
--num_epochs 2 --warmup_steps 8 --tensorboard_dir logs \
--save_dir DeepSeek-R1-bf16-lora
The lora_finetune.py script is shown below:
"""
Supervised fine-tuning of MoE models like Deepseek V3/R1 on a downstream task.
"""
import argparse
import json
import os
import resource
from contextlib import nullcontext
from types import MethodType
import torch
import torch.distributed as dist
from coati.dataset.loader import RawConversationDataset
from peft import LoraConfig
from tqdm import tqdm
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
import colossalai
from colossalai.accelerator import get_accelerator
from colossalai.booster import Booster
from colossalai.booster.plugin import (
GeminiPlugin,
HybridParallelPlugin,
LowLevelZeroPlugin,
MoeHybridParallelPlugin,
Plugin,
TorchDDPPlugin,
)
from colossalai.cluster import DistCoordinator
from colossalai.lazy import LazyInitContext
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device
def all_reduce_mean(loss: torch.Tensor, plugin: Plugin) -> torch.Tensor:
    loss = loss.data
    group = getattr(plugin, "dp_group", None)
    dist.all_reduce(loss, group=group)
    return loss / dist.get_world_size(group)


def train(args) -> None:
    # ==============================
    # Initialize Distributed Training
    # ==============================
    colossalai.launch_from_torch()
    accelerator = get_accelerator()
    coordinator = DistCoordinator()

    # ==============================
    # Initialize Booster
    # ==============================
    if args.plugin == "ddp":
        plugin = TorchDDPPlugin(find_unused_parameters=True if args.use_grad_checkpoint is False else False)
    elif args.plugin == "gemini":
        plugin = GeminiPlugin(
            precision=args.mixed_precision,
            initial_scale=2**16,
            max_norm=args.grad_clip,
            enable_gradient_accumulation=(args.accumulation_steps > 1),
            enable_fused_normalization=get_accelerator().is_available(),
            enable_flash_attention=args.use_flash_attn,
        )
    elif args.plugin == "gemini_auto":
        plugin = GeminiPlugin(
            precision=args.mixed_precision,
            placement_policy="auto",
            initial_scale=2**16,
            max_norm=args.grad_clip,
            enable_gradient_accumulation=(args.accumulation_steps > 1),
            enable_fused_normalization=get_accelerator().is_available(),
            enable_flash_attention=args.use_flash_attn,
        )
    elif args.plugin == "zero2":
        plugin = LowLevelZeroPlugin(
            stage=2,
            precision=args.mixed_precision,
            initial_scale=2**16,
            max_norm=args.grad_clip,
        )
    elif args.plugin == "zero2_cpu":
        plugin = LowLevelZeroPlugin(
            stage=2,
            precision=args.mixed_precision,
            initial_scale=2**16,
            cpu_offload=True,
            max_norm=args.grad_clip,
        )
    elif args.plugin == "3d":
        plugin = HybridParallelPlugin(
            tp_size=args.tp,
            pp_size=args.pp,
            sp_size=args.sp,
            sequence_parallelism_mode=args.sp_mode,
            zero_stage=args.zero_stage,
            enable_flash_attention=args.use_flash_attn,
            enable_fused_normalization=get_accelerator().is_available(),
            enable_sequence_parallelism=args.enable_sequence_parallelism,
            cpu_offload=True if args.zero_stage >= 1 and args.zero_cpu_offload else False,
            max_norm=args.grad_clip,
            precision=args.mixed_precision,
            microbatch_size=args.microbatch_size,
        )
    elif args.plugin == "moe":
        plugin = MoeHybridParallelPlugin(
            ep_size=args.ep,
            tp_size=args.tp,
            pp_size=args.pp,
            zero_stage=args.zero_stage,
            sp_size=args.sp,
            sequence_parallelism_mode=args.sp_mode,
            enable_sequence_parallelism=args.sp > 1,
            enable_fused_normalization=get_accelerator().is_available(),
            enable_flash_attention=args.use_flash_attn,
            max_norm=args.grad_clip,
            precision=args.mixed_precision,
            microbatch_size=args.microbatch_size,
        )
    else:
        raise ValueError(f"Unknown plugin {args.plugin}")

    booster = Booster(plugin=plugin)

    def is_master():
        # With pipeline parallelism the loss lives on the last stage, so that rank does the logging.
        if isinstance(plugin, HybridParallelPlugin) and plugin.pp_size > 1:
            return coordinator.rank == coordinator.world_size - 1
        return coordinator.is_master()

    # ==============================
    # Initialize Tensorboard and Save Config
    # ==============================
    if is_master():
        if args.tensorboard_dir is not None:
            from torch.utils.tensorboard import SummaryWriter

            os.makedirs(args.tensorboard_dir, exist_ok=True)
            writer = SummaryWriter(args.tensorboard_dir)

        with open(args.config_file, "w") as f:
            json.dump(args.__dict__, f, indent=4)

    # ======================================================
    # Initialize Tokenizer, Dataset, Collator and Dataloader
    # ======================================================
    tokenizer = AutoTokenizer.from_pretrained(args.pretrained, trust_remote_code=True)

    coordinator.print_on_master(
        f"Training Info:\nConfig file: {args.config_file} \nTensorboard logs: {args.tensorboard_dir} \nModel checkpoint: {args.save_dir}"
    )

    coordinator.print_on_master(f"Load dataset: {args.dataset}")
    dataset = RawConversationDataset(
        tokenizer,
        args.dataset,
        args.max_length,
    )

    dataloader = plugin.prepare_dataloader(
        dataset=dataset,
        batch_size=args.batch_size,
        shuffle=True,
        drop_last=True,
    )

    coordinator.print_on_master(
        f"Max device memory after data loader: {accelerator.max_memory_allocated() / 1024 ** 2:.2f} MB"
    )

    # ======================================================
    # Initialize Model, Objective, Optimizer and LR Scheduler
    # ======================================================
    # When training the ChatGLM model, LoRA and gradient checkpointing are incompatible.
    init_ctx = (
        LazyInitContext(default_device=get_current_device())
        if isinstance(plugin, (GeminiPlugin, HybridParallelPlugin))
        else nullcontext()
    )
    attn_impl = "eager" if get_accelerator().name == "npu" else "flash_attention_2"

    config = AutoConfig.from_pretrained(args.pretrained, trust_remote_code=True)

    with init_ctx:
        # from_pretrained is not compatible with LoRA, we load pretrained weights later.
        # model = AutoModelForCausalLM.from_pretrained(
        #     args.pretrained,
        #     torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
        #     trust_remote_code=True,
        #     attn_implementation=attn_impl,
        # )
        model = AutoModelForCausalLM.from_config(
            config,
            trust_remote_code=True,
            attn_implementation=attn_impl,
            torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
        )

    if args.lora_rank > 0:
        if model.__class__.__name__.startswith("DeepseekV3"):
            lora_config = LoraConfig(
                task_type="CAUSAL_LM",
                r=args.lora_rank,
                lora_alpha=args.lora_alpha,
                target_modules=["gate_proj", "up_proj", "down_proj"],
            )
        else:
            lora_config = LoraConfig(task_type="CAUSAL_LM", r=args.lora_rank, lora_alpha=args.lora_alpha)
        model = booster.enable_lora(model, lora_config=lora_config)

    # this is essential, otherwise the grad checkpoint will not work.
    model.train()

    if args.use_grad_checkpoint:
        model.gradient_checkpointing_enable()
        coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
    if model.config.__class__.__name__.startswith("DeepseekV3"):
        model.config.use_cache = False
        model.eval()
        # enable grad for moe layers
        for m in model.modules():
            if m.__class__.__name__ == "DeepseekV3MoE":
                m.moe_infer = MethodType(m.moe_infer.__wrapped__, m)

    model_numel = sum(p.numel() for p in model.parameters())
    coordinator.print_on_master(f"Model params: {model_numel / 1e9:.2f} B")

    optimizer = HybridAdam(
        model_params=model.parameters(),
        lr=args.lr,
        betas=(0.9, 0.95),
        weight_decay=args.weight_decay,
        adamw_mode=True,
    )

    if args.warmup_steps is None:
        args.warmup_steps = int(args.num_epochs * 0.025 * (len(dataloader) // args.accumulation_steps))
        coordinator.print_on_master(f"Warmup steps is set to {args.warmup_steps}")

    lr_scheduler = CosineAnnealingWarmupLR(
        optimizer=optimizer,
        total_steps=args.num_epochs * (len(dataloader) // args.accumulation_steps),
        warmup_steps=args.warmup_steps,
        eta_min=0.1 * args.lr,
    )

    # Flash attention will be disabled because it does NOT support fp32.
    default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16
    torch.set_default_dtype(default_dtype)
    model, optimizer, _, dataloader, lr_scheduler = booster.boost(
        model=model,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        dataloader=dataloader,
    )

    torch.set_default_dtype(torch.float)
    booster.load_model(model, args.pretrained)

    coordinator.print_on_master(
        f"Booster init max device memory: {accelerator.max_memory_allocated() / 1024 ** 2:.2f} MB"
    )
    coordinator.print_on_master(
        f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
    )

    start_epoch = 0
    start_step = 0

    num_steps_per_epoch = len(dataloader) // args.accumulation_steps

    for epoch in range(start_epoch, args.num_epochs):
        dataloader.sampler.set_epoch(epoch=epoch)
        if isinstance(plugin, HybridParallelPlugin) and plugin.pp_size > 1:
            # Pipeline-parallel path: the booster drives forward and backward for each step.
            data_iter = iter(dataloader)
            step_bar = tqdm(
                range(len(dataloader)),
                desc="Step",
                disable=not is_master(),
            )
            for step in step_bar:
                outputs = booster.execute_pipeline(
                    data_iter,
                    model,
                    criterion=lambda outputs, inputs: outputs[0],
                    optimizer=optimizer,
                    return_loss=True,
                )
                loss = outputs["loss"]
                if booster.plugin.stage_manager.is_last_stage():
                    global_loss = all_reduce_mean(loss, plugin)

                optimizer.step()

                if booster.plugin.stage_manager.is_last_stage():
                    grad_norm = optimizer.get_grad_norm()
                    step_bar.set_postfix({"loss": global_loss.item(), "grad_norm": grad_norm})

                if args.tensorboard_dir is not None and is_master():
                    global_step = (epoch * num_steps_per_epoch) + (step + 1) // args.accumulation_steps
                    writer.add_scalar(tag="Loss", scalar_value=global_loss.item(), global_step=global_step)
                    writer.add_scalar(
                        tag="Learning Rate",
                        scalar_value=lr_scheduler.get_last_lr()[0],
                        global_step=global_step,
                    )
                    writer.add_scalar(tag="Grad Norm", scalar_value=grad_norm, global_step=global_step)

                lr_scheduler.step()
                optimizer.zero_grad()
        else:
            # Non-pipeline path: manual forward/backward with gradient accumulation.
            pbar = tqdm(
                dataloader,
                desc=f"Epoch {epoch}",
                disable=not is_master(),
                initial=start_step // args.accumulation_steps,
            )
            total_loss = torch.tensor(0.0, device=get_current_device())
            for step, batch in enumerate(pbar, start=start_step // args.accumulation_steps):
                batch = {k: v.to(get_current_device()) for k, v in batch.items() if isinstance(v, torch.Tensor)}
                batch_output = model(**batch)
                loss = batch_output.loss / args.accumulation_steps
                total_loss.add_(loss.data)
                booster.backward(loss=loss, optimizer=optimizer)

                if (step + 1) % args.accumulation_steps == 0:
                    # all_reduce_mean returns the averaged tensor; copy it back so the
                    # logged loss is the mean across ranks rather than the sum.
                    total_loss.copy_(all_reduce_mean(total_loss, plugin))

                    optimizer.step()

                    grad_norm = optimizer.get_grad_norm()
                    pbar.set_postfix({"loss": total_loss.item(), "grad_norm": grad_norm})
                    if args.tensorboard_dir is not None and is_master():
                        global_step = (epoch * num_steps_per_epoch) + (step + 1) // args.accumulation_steps
                        writer.add_scalar(tag="Loss", scalar_value=total_loss.item(), global_step=global_step)
                        writer.add_scalar(
                            tag="Learning Rate",
                            scalar_value=lr_scheduler.get_last_lr()[0],
                            global_step=global_step,
                        )
                        writer.add_scalar(tag="Grad Norm", scalar_value=grad_norm, global_step=global_step)

                    lr_scheduler.step()
                    optimizer.zero_grad()

                    total_loss.fill_(0.0)

        # Delete cache.
        # del batch, batch_labels, batch_output, loss
        accelerator.empty_cache()

    # Final save.
    coordinator.print_on_master("Start saving final model checkpoint")
    if args.lora_rank > 0:
        booster.save_lora_as_pretrained(model, os.path.join(args.save_dir, "lora"))
    else:
        booster.save_model(model, os.path.join(args.save_dir, "modeling"), shard=True)
    coordinator.print_on_master(f"Saved final model checkpoint at epoch {epoch} at folder {args.save_dir}")

    coordinator.print_on_master(f"Max device memory usage: {accelerator.max_memory_allocated()/1024**2:.2f} MB")
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Basic training information.
    parser.add_argument(
        "-m",
        "--pretrained",
        type=str,
        required=True,
        help="Address of the pre-trained model",
    )
    parser.add_argument("-d", "--dataset", type=str, required=True, help="Raw JSONL dataset for training.")
    parser.add_argument(
        "-p",
        "--plugin",
        type=str,
        default="zero2",
        choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "3d", "ddp", "moe"],
        help="Choose which plugin to use",
    )
    parser.add_argument("--save_dir", type=str, default="checkpoint_dir", help="Checkpoint directory")
    parser.add_argument("--tensorboard_dir", type=str, default=None, help="Tensorboard directory")
    parser.add_argument("--config_file", type=str, default="training_config.json", help="Config file")
    # Training parameters
    parser.add_argument("-n", "--num_epochs", type=int, default=1, help="Number of training epochs")
    parser.add_argument("--accumulation_steps", type=int, default=1, help="Number of accumulation steps")
    parser.add_argument("--batch_size", type=int, default=2, help="Global Batch size of each process")
    parser.add_argument("--lr", type=float, default=3e-4, help="Learning rate")
    parser.add_argument("--max_length", type=int, default=8192, help="Model max length")
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default="bf16",
        choices=["fp16", "bf16"],
        help="Mixed precision",
    )
    parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value")
    parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay")
    parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps")
    parser.add_argument(
        "-g",
        "--use_grad_checkpoint",
        action="store_true",
        default=False,
        help="Use gradient checkpointing",
    )
    parser.add_argument(
        "-f",
        "--use_flash_attn",
        action="store_true",
        default=False,
        help="Use flash-attention",
    )
    # Additional arguments for 3d plugin.
    parser.add_argument("--tp", type=int, default=1, help="TP size, used for 3d plugin.")
    parser.add_argument("--pp", type=int, default=1, help="PP size, used for 3d plugin.")
    parser.add_argument("--sp", type=int, default=1, help="SP size, used for 3d plugin.")
    parser.add_argument("--ep", type=int, default=1, help="EP size, used for moe plugin.")
    parser.add_argument("--zero_stage", type=int, default=1, help="Zero stage, used for 3d plugin.", choices=[0, 1, 2])
    parser.add_argument(
        "--sp_mode",
        type=str,
        default="split_gather",
        choices=["split_gather", "ring", "all_to_all"],
        help="SP mode, used for 3d plugin.",
    )
    parser.add_argument(
        "--enable_sequence_parallelism",
        default=False,
        action="store_true",
        help="Whether to enable SP, used for 3d plugin.",
    )
    parser.add_argument(
        "--zero_cpu_offload", default=False, action="store_true", help="Whether to use offloading, used for 3d plugin."
    )
    parser.add_argument(
        "--microbatch_size", type=int, default=1, help="Batch size for each process in PP, used for 3d plugin."
    )
    parser.add_argument("--lora_rank", type=int, default=0, help="lora rank when using lora to train.")
    parser.add_argument("--lora_alpha", type=int, default=8, help="lora alpha when using lora to train.")
    args = parser.parse_args()

    if args.plugin in ["3d", "moe"] and args.pp > 1 and args.accumulation_steps > 1:
        raise ValueError("Accumulation steps should be 1 when using PP. Please adjust batch size directly.")
    train(args)
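The script pairs a warmup phase with cosine annealing down to eta_min = 0.1 × lr. The sketch below is a generic version of such a schedule written from the textbook formula; it is an assumption that Colossal-AI's CosineAnnealingWarmupLR behaves the same way step for step, and `lr_at` is a hypothetical helper name.

```python
import math


def lr_at(step, base_lr, warmup_steps, total_steps, eta_min):
    """Linear warmup followed by cosine decay from base_lr to eta_min."""
    if step < warmup_steps:
        # Ramp up linearly; the last warmup step reaches base_lr.
        return base_lr * (step + 1) / warmup_steps
    # Fraction of the decay phase that has elapsed, in [0, 1].
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * progress))


base_lr = 2e-5        # --lr from the launch command
warmup_steps = 8      # --warmup_steps from the launch command
total_steps = 100     # illustrative value; the script derives it from the dataloader
eta_min = 0.1 * base_lr

for step in (0, 7, 8, 54, 100):
    print(step, lr_at(step, base_lr, warmup_steps, total_steps, eta_min))
```

With these values the rate climbs to 2e-5 over the first 8 steps and then decays smoothly toward 2e-6, which keeps late updates small without freezing training.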
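2.4 Hardware requirements with LoRA
With LoRA, the minimum hardware requirements for fine-tuning DeepSeek V3/R1 are:
24 H100 GPUs (ep=8, pp=3) are enough to run it.
Enabling --zero_cpu_offload lowers the hardware requirement further, at the cost of slower training.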
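The figure of 24 GPUs is consistent with multiplying the parallel sizes. The sketch below assumes that the smallest job uses ep × pp × tp processes (experts sharded across the ranks of each pipeline stage); that rule is an assumption made for illustration, and both function names are hypothetical.

```python
def min_world_size(ep, pp, tp=1):
    # Assumption: each pipeline stage needs ep * tp ranks to hold its experts.
    return ep * pp * tp


def nodes_needed(world_size, gpus_per_node):
    # Ceiling division: a partially filled node still counts as a node.
    return -(-world_size // gpus_per_node)


gpus = min_world_size(ep=8, pp=3)            # values from the launch command
nodes = nodes_needed(gpus, gpus_per_node=8)  # --nproc_per_node 8
print(f"{gpus} GPUs across {nodes} nodes")
```

Under this assumption the hostfile passed to `colossalai run` would list three 8-GPU machines.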
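Experimental results: while training the 671B DeepSeek V3/R1, the loss decreases steadily and converges stably. Teams on a low budget can combine this with reinforcement learning to build a private model similar to DeepSeek R1 efficiently.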
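2.5 Reinforcement learning fine-tuning (RLHF): DeepSeek distilled models
The framework already implements the GRPO algorithm and a verifiable reward mechanism, and experiments have been completed on Qwen2.5-3B-Base.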
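The core idea of GRPO is to sample several responses per prompt and score each one relative to its own group, which removes the need for a separate value model. The sketch below uses the commonly described normalization (reward minus group mean, divided by group standard deviation); the framework's exact formula, including the choice of population standard deviation and the epsilon, may differ. The reward values are illustrative.

```python
from statistics import mean, pstdev


def group_advantages(rewards, eps=1e-4):
    """Normalize the rewards of one prompt's group of sampled responses."""
    mu = mean(rewards)
    sigma = pstdev(rewards)  # population std; an assumption for this sketch
    # eps avoids division by zero when every response gets the same reward.
    return [(r - mu) / (sigma + eps) for r in rewards]


# One prompt, four sampled responses with illustrative rewards.
rewards = [0, 1, 10, 10]
adv = group_advantages(rewards)
print([round(a, 3) for a in adv])
```

Responses that beat their group average get a positive advantage and are reinforced; a group in which every response earns the same reward yields all-zero advantages and therefore no learning signal.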
Reward function design:
Wrong format: reward = 0
Correct format but wrong result: reward = 1
Correct format and correct result: reward = 10
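The three reward tiers can be sketched as a small rule-based function. This is a simplified illustration, not the framework's reward model: it assumes the response must consist of one `<think>` block followed by one `<answer>` block, with each tag occurring exactly once, and that the answer is compared to the ground truth as a stripped string. `verifiable_reward` is a hypothetical name.

```python
import re

# Whole response: <think>...</think> followed by <answer>...</answer>.
PATTERN = re.compile(
    r"^\s*<think>(.*?)</think>\s*<answer>(.*?)</answer>\s*$", re.DOTALL
)
TAGS = ("<think>", "</think>", "<answer>", "</answer>")


def verifiable_reward(response, gt_answer):
    """Return 0 for a bad format, 1 for a good format with a wrong answer, 10 if both are right."""
    # Assumption: every tag must appear exactly once.
    if any(response.count(tag) != 1 for tag in TAGS):
        return 0.0
    match = PATTERN.match(response)
    if match is None:
        return 0.0
    answer = match.group(2).strip()
    return 10.0 if answer == str(gt_answer).strip() else 1.0


print(verifiable_reward("<think> 1+1=2 </think><answer> 2 </answer>", "2"))
print(verifiable_reward("<think> 1+1=3 </think><answer> 3 </answer>", "2"))
print(verifiable_reward("The answer is 2", "2"))
```

Giving a small reward for the correct format alone lets the model first learn the output structure, while the much larger reward for a correct result keeps the final answer the dominant objective.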
Qwen2.5-3B chat template:
{
    "chat_template": "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
    "system_message": "You are a helpful assistant. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and<answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>. Now the user asks you to solve a math problem that involves reasoning. After thinking, when you finally reach a conclusion, clearly output the final answer without explanation within the <answer> </answer> tags, your final answer should be a integer without unit, currency mark, thousands separator or other text. i.e., <answer> 123 </answer>.\n",
    "stop_ids": [
        151643
    ],
    "end_of_assistant": "<|endoftext|>",
    "response_format_tags": {
        "think_start": {
            "text": "<think>",
            "num_occur": 1
        },
        "think_end": {
            "text": "</think>",
            "num_occur": 1
        },
        "answer_start": {
            "text": "<answer>",
            "num_occur": 1
        },
        "answer_end": {
            "text": "</answer>",
            "num_occur": 1
        }
    }
}
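To see what the chat_template above produces, the following sketch mirrors its Jinja logic in plain Python rather than invoking a template engine. `render_chat` is a hypothetical helper written for illustration; real pipelines would normally let the tokenizer apply the template.

```python
def render_chat(messages, add_generation_prompt=True):
    """Plain-Python mirror of the Jinja chat template shown above."""
    text = ""
    for message in messages:
        # Each turn is wrapped as <|im_start|>role\ncontent<|im_end|>\n
        text += "<|im_start|>" + message["role"] + "\n" + message["content"] + "<|im_end|>" + "\n"
    if add_generation_prompt:
        # Open the assistant turn so the model continues from here.
        text += "<|im_start|>assistant\n"
    return text


prompt = render_chat(
    [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What is 2 + 3?"},
    ]
)
print(prompt)
```

The rendered prompt ends with an open assistant turn, so the model's generation is expected to begin with its `<think>` block.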
GRPO training script:
"""
GRPO trainer
"""
import os
from typing import Dict, List, Optional, Union
import torch
import wandb
from coati.experience_buffer import NaiveExperienceBuffer
from coati.experience_maker import Experience, NaiveExperienceMaker
from coati.models import RewardModel, RLVRRewardModel
from coati.models.loss import GPTLMLoss, PolicyLoss
from coati.models.utils import calc_action_log_probs
from coati.trainer.callbacks import Callback
from coati.trainer.utils import all_reduce_mean
from coati.utils import AccumulativeMeanMeter, save_checkpoint
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader, DistributedSampler
from tqdm import tqdm
from transformers import PreTrainedModel, PreTrainedTokenizerBase
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.cluster import DistCoordinator
from colossalai.utils import get_current_device
from .base import OLTrainer
from .utils import AnnealingScheduler, CycledDataLoader, is_rank_0, to_device
def _set_default_generate_kwargs(actor: PreTrainedModel) -> Dict:
    """
    Set default keyword arguments for generation based on the actor model.

    Args:
        actor (PreTrainedModel): The actor model.

    Returns:
        Dict: A dictionary containing the default keyword arguments for generation.
    """
    unwrapped_model = actor.unwrap()
    new_kwargs = {}
    # use huggingface models method directly
    if hasattr(unwrapped_model, "prepare_inputs_for_generation"):
        new_kwargs["prepare_inputs_fn"] = unwrapped_model.prepare_inputs_for_generation
    if hasattr(unwrapped_model, "_update_model_kwargs_for_generation"):
        new_kwargs["update_model_kwargs_fn"] = unwrapped_model._update_model_kwargs_for_generation
    return new_kwargs
class GRPOTrainer(OLTrainer):
"""
Trainer for GRPO algorithm.
Args:
strategy (Booster): the strategy to use for training
actor (Actor): the actor model in ppo algorithm
reward_model (RewardModel): the reward model in rlhf algorithm to make reward of sentences
initial_model (Actor): the initial model in rlhf algorithm to generate reference logics to limit the update of actor
actor_optim (Optimizer): the optimizer to use for actor model
kl_coef (float, defaults to 0.1): the coefficient of kl divergence loss
train_batch_size (int, defaults to 8): the batch size to use for training
buffer_limit (int, defaults to 0): the max_size limitation of buffer
buffer_cpu_offload (bool, defaults to True): whether to offload buffer to cpu
eps_clip (float, defaults to 0.2): the clip coefficient of policy loss
vf_coef (float, defaults to 1.0): the coefficient of value loss
ptx_coef (float, defaults to 0.9): the coefficient of ptx loss
value_clip (float, defaults to 0.4): the clip coefficient of value loss
sample_buffer (bool, defaults to False): whether to sample from buffer
dataloader_pin_memory (bool, defaults to True): whether to pin memory for data loader
offload_inference_models (bool, defaults to True): whether to offload inference models to cpu during training process
callbacks (List[Callback], defaults to []): the callbacks to call during training process
generate_kwargs (dict, optional): the kwargs to use while model generating
"""
def __init__(
self,
actor_booster: Booster,
actor: PreTrainedModel,
reward_model: Union[RewardModel, RLVRRewardModel],
initial_model: PreTrainedModel,
actor_optim: Optimizer,
actor_lr_scheduler: _LRScheduler,
tokenizer: PreTrainedTokenizerBase,
kl_coef: float = 0.1,
ptx_coef: float = 0.9,
train_batch_size: int = 8,
buffer_limit: int = 0,
buffer_cpu_offload: bool = True,
eps_clip: float = 0.2,
vf_coef: float = 1.0,
value_clip: float = 0.2,
sample_buffer: bool = False,
dataloader_pin_memory: bool = True,
offload_inference_models: bool = True,
apply_loss_mask: bool = True,
accumulation_steps: int = 1,
save_interval: int = 0,
save_dir: str = None,
use_tp: bool = False,
num_generation: int = 8,
inference_batch_size: int = None,
logits_forward_batch_size: int = None,
temperature_annealing_config: Optional[Dict] = None,
coordinator: DistCoordinator = None,
callbacks: List[Callback] = [],
**generate_kwargs,
) -> None:
if isinstance(actor_booster, GeminiPlugin):
assert not offload_inference_models, "GeminiPlugin is not compatible with manual model.to('cpu')"
data_buffer = NaiveExperienceBuffer(train_batch_size, buffer_limit, buffer_cpu_offload)
super().__init__(actor_booster, None, data_buffer, sample_buffer, dataloader_pin_memory, callbacks=callbacks)
self.generate_kwargs = _set_default_generate_kwargs(actor)
self.generate_kwargs.update(generate_kwargs)
self.actor = actor
self.actor_booster = actor_booster
self.actor_scheduler = actor_lr_scheduler
self.tokenizer = tokenizer
self.experience_maker = NaiveExperienceMaker(
self.actor,
None,
reward_model,
initial_model,
self.tokenizer,
kl_coef,
use_grpo=True,
num_generation=num_generation,
inference_batch_size=inference_batch_size,
logits_forward_batch_size=logits_forward_batch_size,
)
if temperature_annealing_config:
# use annealing
self.temperature_annealing_scheduler = AnnealingScheduler(
temperature_annealing_config["start_temperature"],
temperature_annealing_config["end_temperature"],
temperature_annealing_config["annealing_warmup_steps"],
temperature_annealing_config["annealing_steps"],
)
else:
self.temperature_annealing_scheduler = None
self.train_batch_size = train_batch_size
self.actor_loss_fn = PolicyLoss(eps_clip)
self.vf_coef = vf_coef
self.ptx_loss_fn = GPTLMLoss()
self.ptx_coef = ptx_coef
self.actor_optim = actor_optim
self.save_interval = save_interval
self.apply_loss_mask = apply_loss_mask
self.coordinator = coordinator
self.actor_save_dir = os.path.join(save_dir, "actor")
self.num_train_step = 0
self.accumulation_steps = accumulation_steps
self.use_tp = use_tp
self.accumulative_meter = AccumulativeMeanMeter()
self.offload_inference_models = offload_inference_models
self.device = get_current_device()
def _before_fit(
self,
prompt_dataloader: DataLoader,
pretrain_dataloader: Optional[DataLoader] = None,
log_dir: Optional[str] = None,
use_wandb: bool = False,
):
"""
Args:
prompt_dataloader (DataLoader): the dataloader to use for prompt data
pretrain_dataloader (DataLoader): the dataloader to use for pretrain data
"""
self.prompt_dataloader = CycledDataLoader(prompt_dataloader)
self.pretrain_dataloader = CycledDataLoader(pretrain_dataloader) if pretrain_dataloader is not None else None
self.writer = None
if use_wandb and is_rank_0():
assert log_dir is not None, "log_dir must be provided when use_wandb is True"
import wandb
self.wandb_run = wandb.init(project="Coati-grpo", sync_tensorboard=True)
if log_dir is not None and is_rank_0():
import os
import time
from torch.utils.tensorboard import SummaryWriter
log_dir = os.path.join(log_dir, "grpo")
log_dir = os.path.join(log_dir, time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()))
self.writer = SummaryWriter(log_dir=log_dir)
def _setup_update_phrase_dataload(self):
"""
why not use distributed_dataloader?
if tp is used, input on each rank is the same and we use the same dataloader to feed same experience to all ranks
if tp is not used, input on each rank is different and we expect different experiences to be fed to each rank
"""
self.dataloader = DataLoader(
self.data_buffer,
batch_size=self.train_batch_size,
shuffle=True,
drop_last=True,
pin_memory=self.dataloader_pin_memory,
collate_fn=self.data_buffer.collate_fn,
)
def _make_experience(self, collect_step: int) -> Experience:
"""
Make experience
"""
prompts = self.prompt_dataloader.next()
if self.offload_inference_models:
# TODO(ver217): this may be controlled by strategy if they are prepared by strategy
self.experience_maker.initial_model.to(self.device)
self.experience_maker.reward_model.to(self.device)
if self.temperature_annealing_scheduler:
self.generate_kwargs["temperature"] = self.temperature_annealing_scheduler.get_temperature()
return self.experience_maker.make_experience(
input_ids=prompts["input_ids"].to(get_current_device()),
attention_mask=prompts["attention_mask"].to(get_current_device()),
gt_answer=prompts["gt_answer"],
**self.generate_kwargs,
)
    def _training_step(self, experience: Experience):
        """
        Args:
            experience:
                sequences: [batch_size, prompt_length + response_length] --- <PAD>...<PAD><PROMPT>...<PROMPT><RESPONSE>...<RESPONSE><PAD>...<PAD>
        """
        self.actor.train()
        num_actions = experience.action_log_probs.size(1)
        # policy loss
        actor_logits = self.actor(input_ids=experience.sequences, attention_mask=experience.attention_mask)[
            "logits"
        ]  # [batch_size, prompt_length + response_length, vocab_size]
        action_log_probs = calc_action_log_probs(actor_logits, experience.sequences, num_actions)
        actor_loss, to_skip, max_ratio = self.actor_loss_fn(
            action_log_probs,
            experience.action_log_probs,
            experience.advantages.unsqueeze(dim=-1).repeat_interleave(action_log_probs.size(-1), dim=-1),
            action_mask=experience.action_mask if self.apply_loss_mask else None,
        )
        # sequences that do not end properly (no trailing pad token) are not counted in the token cost
        token_cost = torch.sum(
            (experience.sequences[:, -num_actions:] != self.tokenizer.pad_token_id).to(torch.float), axis=-1
        ).to(actor_logits.device)
        end_properly = experience.sequences[:, -1] == self.tokenizer.pad_token_id
        mean_token_cost = torch.sum(token_cost * end_properly) / torch.sum(end_properly)
        actor_loss = (1 - self.ptx_coef) * actor_loss
        if not to_skip:
            self.actor_booster.backward(loss=actor_loss, optimizer=self.actor_optim)
        # ptx loss
        if self.ptx_coef != 0:
            batch = self.pretrain_dataloader.next()
            batch = to_device(batch, self.device)
            outputs = self.actor(batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"])
            ptx_loss = outputs.loss
            ptx_loss = self.ptx_coef * ptx_loss
            self.actor_booster.backward(loss=ptx_loss, optimizer=self.actor_optim)
        # sync
        actor_loss_mean = all_reduce_mean(tensor=actor_loss)
        max_ratio_mean = all_reduce_mean(tensor=max_ratio)
        reward_mean = all_reduce_mean(tensor=experience.reward.mean())
        advantages_mean = all_reduce_mean(tensor=experience.advantages.mean())
        kl_mean = all_reduce_mean(tensor=experience.kl.mean())
        mean_token_cost = all_reduce_mean(tensor=mean_token_cost)
        if self.ptx_coef != 0:
            ptx_loss_mean = all_reduce_mean(tensor=ptx_loss)
        self.accumulative_meter.add("actor_loss", actor_loss_mean.to(torch.float16).mean().item())
        self.accumulative_meter.add("max_ratio", max_ratio_mean.to(torch.float16).item())
        self.accumulative_meter.add("reward", reward_mean.to(torch.float16).mean().item())
        self.accumulative_meter.add("advantages", advantages_mean.to(torch.float16).item())
        self.accumulative_meter.add("skip_ratio", 1.0 if to_skip else 0.0)
        self.accumulative_meter.add("mean_token_cost", mean_token_cost.to(torch.float16).item())
        self.accumulative_meter.add("kl", kl_mean.to(torch.float16).item())
        if self.ptx_coef != 0:
            self.accumulative_meter.add("ptx_loss", ptx_loss_mean.to(torch.float16).mean().item())
        # step the optimizer only on the last micro-step of each gradient-accumulation window
        if self.num_train_step % self.accumulation_steps == self.accumulation_steps - 1:
            self.actor_optim.step()
            self.actor_optim.zero_grad()
            self.actor_scheduler.step()
            if self.temperature_annealing_scheduler:
                self.temperature_annealing_scheduler.step_forward()
            # prepare the model outputs and their rewards for logging
            if self.num_train_step % 10 == 0:
                response_text = self.experience_maker.tokenizer.batch_decode(
                    experience.sequences, skip_special_tokens=True
                )
                for i in range(len(response_text)):
                    response_text[i] = response_text[i] + f"\n\nReward: {experience.reward[i]}"
                if self.writer and is_rank_0() and "wandb_run" in self.__dict__:
                    # log output to wandb
                    my_table = wandb.Table(
                        columns=[f"sample response {i}" for i in range(len(response_text))], data=[response_text]
                    )
                    try:
                        self.wandb_run.log({"sample_response": my_table})
                    except OSError as e:
                        self.coordinator.print_on_master(e)
                elif self.writer and is_rank_0():
                    for line in response_text:
                        self.coordinator.print_on_master(line)
            if self.writer and is_rank_0():
                # integer division keeps global_step an int for TensorBoard
                global_step = (self.num_train_step + 1) // self.accumulation_steps
                self.writer.add_scalar("train/max_ratio", self.accumulative_meter.get("max_ratio"), global_step)
                self.writer.add_scalar("train/skip_ratio", self.accumulative_meter.get("skip_ratio"), global_step)
                self.writer.add_scalar("train/actor_loss", self.accumulative_meter.get("actor_loss"), global_step)
                self.writer.add_scalar("train/lr_actor", self.actor_optim.param_groups[0]["lr"], global_step)
                if self.ptx_coef != 0:
                    self.writer.add_scalar("train/ptx_loss", self.accumulative_meter.get("ptx_loss"), global_step)
                self.writer.add_scalar("reward", self.accumulative_meter.get("reward"), global_step)
                self.writer.add_scalar("token_cost", self.accumulative_meter.get("mean_token_cost"), global_step)
                self.writer.add_scalar("approx_kl", self.accumulative_meter.get("kl"), global_step)
                self.writer.add_scalar("advantages", self.accumulative_meter.get("advantages"), global_step)
            self.accumulative_meter.reset()
        self.num_train_step += 1
    def _learn(self, update_step: int):
        """
        Perform the learning step of the GRPO algorithm.
        Args:
            update_step (int): The current update step.
        Returns:
            None
        """
        if self.offload_inference_models:
            self.experience_maker.initial_model.to("cpu")
            self.experience_maker.reward_model.to("cpu")
        # the buffer may be empty at first, so rebuild it before each training round
        if self.sample_buffer:
            experience = self.data_buffer.sample()
            self._on_learn_batch_start()
            experience.to_device(self.device)
            self._training_step(experience)
            self._on_learn_batch_end(experience)
        else:
            if isinstance(self.dataloader.sampler, DistributedSampler):
                self.dataloader.sampler.set_epoch(update_step)
            pbar = tqdm(self.dataloader, desc=f"Train epoch [{update_step + 1}]", disable=not is_rank_0())
            for experience in pbar:
                self._on_learn_batch_start()
                experience.to_device(self.device)
                self._training_step(experience)
                self._on_learn_batch_end(experience)
    def _save_checkpoint(self, num_train_step: int = 0):
        """
        Save the actor checkpoints with running states.
        Args:
            num_train_step (int): The current num_train_step number.
        Returns:
            None
        """
        self.coordinator.print_on_master("\nStart saving actor checkpoint with running states")
        save_checkpoint(
            save_dir=self.actor_save_dir,
            booster=self.actor_booster,
            model=self.actor,
            optimizer=self.actor_optim,
            lr_scheduler=self.actor_scheduler,
            epoch=0,
            step=num_train_step + 1,
            batch_size=self.train_batch_size,
            coordinator=self.coordinator,
        )
        self.coordinator.print_on_master(
            f"Saved actor checkpoint at episode {(num_train_step + 1)} at folder {self.actor_save_dir}"
        )
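To make the policy-loss call in `_training_step` easier to follow, here is a minimal plain-Python sketch of the two ideas it relies on: one advantage per sequence that is reused for every response token, and a PPO-style clipped surrogate over the probability ratio. This is not ColossalAI's `actor_loss_fn`. The function names, the `clip_eps` default, the population standard deviation, and the omission of any KL term or skip logic are all assumptions made for illustration.

```python
import math
from statistics import mean, pstdev


def group_relative_advantages(rewards, eps=1e-8):
    """Normalize the rewards of the responses sampled for one prompt.

    Assumption: population standard deviation; a real trainer may differ.
    """
    mu = mean(rewards)
    sigma = pstdev(rewards)
    return [(r - mu) / (sigma + eps) for r in rewards]


def clipped_policy_loss(new_log_probs, old_log_probs, advantages, action_mask, clip_eps=0.2):
    """PPO-style clipped surrogate, averaged over unmasked response tokens.

    new_log_probs, old_log_probs, action_mask: one list of per-token values per sequence.
    advantages: one scalar per sequence, reused for each token of that sequence.
    """
    total, count = 0.0, 0
    for seq_new, seq_old, adv, seq_mask in zip(new_log_probs, old_log_probs, advantages, action_mask):
        for lp_new, lp_old, keep in zip(seq_new, seq_old, seq_mask):
            if not keep:
                continue  # masked-out tokens do not contribute to the loss
            ratio = math.exp(lp_new - lp_old)  # pi_new / pi_old for this token
            clipped = min(max(ratio, 1.0 - clip_eps), 1.0 + clip_eps)
            # take the pessimistic (smaller) objective, then negate to get a loss
            total += -min(ratio * adv, clipped * adv)
            count += 1
    return total / max(count, 1)


if __name__ == "__main__":
    advs = group_relative_advantages([1.0, 0.0, 1.0, 0.0])
    print(advs)
    print(clipped_policy_loss([[0.0]], [[-1.0]], [1.0], [[1]]))
```

In the trainer the same broadcast is done on tensors via `unsqueeze` and `repeat_interleave`; the loop form above only spells out the arithmetic.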
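Experimental results: even with a model of only 3B parameters, both the average reward and the response length rise steadily as training proceeds. During training the model even learns to correct its own mistakes, which shows the potential of reinforcement learning.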
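The response length here corresponds to the `token_cost` scalar that the trainer logs from `mean_token_cost`. A minimal plain-Python sketch of that metric follows, assuming each response is a list of token ids right-padded with `pad_token_id`. The function name and the `None` return for the empty case are illustrative assumptions: the tensor version in `_training_step` divides by the number of properly ended sequences and yields NaN when that number is zero.

```python
def mean_token_cost(responses, pad_token_id):
    """Average number of non-pad tokens over responses that ended before the length limit.

    A response counts as "ended properly" when its last position is padding,
    meaning generation stopped on its own instead of being truncated.
    """
    lengths = []
    for tokens in responses:
        if tokens[-1] == pad_token_id:
            lengths.append(sum(1 for t in tokens if t != pad_token_id))
    if not lengths:
        return None  # no response ended properly in this batch
    return sum(lengths) / len(lengths)


if __name__ == "__main__":
    batch = [
        [5, 6, 0, 0],  # ended properly, 2 tokens
        [7, 8, 9, 4],  # hit the length limit, ignored
        [3, 0, 0, 0],  # ended properly, 1 token
    ]
    print(mean_token_cost(batch, pad_token_id=0))
```

Ignoring truncated responses keeps the curve from being dominated by generations that simply ran into the maximum length.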
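Discussion: fine-tuning V3/R1 is feasible, but it still demands substantial compute. Many organizations may not have such resources, so a follow-up post will share a simpler fine-tuning approach for vertical domains.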