In RL training we run into a bewildering variety of batches. In this post we carefully work through what each of these batches represents and how they relate to one another. The post mainly involves code from the following files:
verl/verl/trainer/ppo/ray_trainer.py
verl/verl/workers/fsdp_workers.py
1. Background
We train GRPO with FSDP. GRPO is an efficient variant of PPO proposed by DeepSeek. Compared with PPO:
- GRPO drops the reward model and computes rewards directly with rule-based functions.
- GRPO drops the critic model, so it no longer computes a separate value estimate $V_i$.
- The advantage is computed directly from the reward $R_i$, whereas vanilla PPO (and related algorithms) computes it from $V_i$. In other words, GRPO uses a rollout sample's reward directly as the measure of its value (see the short formula sketch below).
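To make the last point concrete: the GRPO advantage of the $i$-th rollout within a group of $G$ samples generated from the same prompt is the group-normalized reward. The formula below is standard GRPO notation, not something copied from the verl source:
$A_i = \dfrac{R_i - \operatorname{mean}(R_1, \dots, R_G)}{\operatorname{std}(R_1, \dots, R_G)}$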
2. Batch-size-related parameters
In verl/verl/trainer/config/ppo_trainer.yaml
we encounter several parameters related to batch size; let us unpack them one by one.
Warning: the parameters in verl/verl/trainer/config/ppo_trainer.yaml
may be overridden by the parameters passed in the launch script.
2.1 General
data.train_batch_size=60
trainer.n_gpus_per_node=6
trainer.nnodes=1
Here we have 1 node with 6 GPUs in total.
Note:
- data.train_batch_size must be divisible by trainer.n_gpus_per_node, otherwise an error is raised.
2.2 actor_rollout_ref
## actor
actor_rollout_ref.actor.ppo_mini_batch_size=60
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=8 # this one does not seem to matter much here
actor_rollout_ref.actor.ulysses_sequence_parallel_size=1
### actor.fsdp_config
actor_rollout_ref.actor.fsdp_config.param_offload=False
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False
## fsdp_config
actor_rollout_ref.fsdp_config.fsdp_size=-1
## rollout
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8
actor_rollout_ref.rollout.n=12
actor_rollout_ref.rollout.tensor_model_parallel_size=2
## ref
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=8
We can get a rough picture of these batches in the fit() function of ray_trainer.py:
with _timer('step', timing_raw):  # time the whole step
    # whole step = compute old policy log_prob + compute ref policy log_prob + compute adv (value + reward)
    # + update critic model + update actor model + sometimes validation + sometimes ckpt saving
    with _timer('gen', timing_raw):  # time sequence generation
        # use the actor model to generate sequences
        # gen_batch: contains input_ids and everything else needed for generation
        # gen_batch_output: the generated sequences, their log-probs, etc.
        print('gen_batch shape: ', gen_batch.batch['input_ids'].shape)
        '''before the actor does rollout,
        gen_batch shape: torch.Size([60, 8192]),
        data.train_batch_size = 60
        -> gen_batch.batch['input_ids'].shape[0] = 60
        '''
        gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)
        print("gen_batch_output.batch['prompts'].shape: ", gen_batch_output.batch['prompts'].shape)
        '''after the actor does rollout,
        gen_batch_output.batch['prompts'].shape: torch.Size([720, 8192])
        data.train_batch_size = 60, actor_rollout_ref.rollout.n = 12
        -> gen_batch_output.batch['prompts'].shape[0] = 60 * 12 = 720
        '''
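To see where the factor of 12 comes from, here is a tiny standalone sketch (plain torch with dummy tensors, not verl code) of the shape change that actor_rollout_ref.rollout.n induces:

import torch

# dummy prompt batch: 60 prompts, each padded to length 8192
gen_input_ids = torch.zeros(60, 8192, dtype=torch.long)

# rollout.n = 12: every prompt is sampled 12 times, so the prompt rows appear 12x
# in the rollout output; here we only mimic the shape arithmetic with repeat_interleave
n = 12
rolled_out = gen_input_ids.repeat_interleave(n, dim=0)
print(rolled_out.shape)  # torch.Size([720, 8192]) = [60 * 12, 8192]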
TL;DR:
In class ActorRolloutRefWorker(Worker)
in verl/verl/workers/fsdp_workers.py
we interact with the following parameters:
# For the actor
actor_rollout_ref.actor.ppo_mini_batch_size and the number of GPUs jointly determine the normalized
per-GPU mini batch, which is then split into micro batches of size actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu
# When computing the log_prob of each rollout sample during rollout, the number of samples processed per GPU comes directly from
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8
# When computing the reference model's log_prob, the number of samples processed per GPU comes directly from
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=8
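The normalization that __init__ (shown in section 3.1 below) performs on ppo_mini_batch_size boils down to a few lines of arithmetic. A standalone reproduction with the values from the config above (my own sketch, not verl code):

# values from the config above
rollout_n = 12            # actor_rollout_ref.rollout.n
n_gpus = 6                # trainer.n_gpus_per_node * trainer.nnodes
ulysses_sp = 1            # actor_rollout_ref.actor.ulysses_sequence_parallel_size
ppo_mini_batch_size = 60  # actor_rollout_ref.actor.ppo_mini_batch_size

# step 1: count the mini batch in rollout samples rather than prompts
ppo_mini_batch_size *= rollout_n               # 60 * 12 = 720
# step 2: divide by the data-parallel size to get the per-GPU (per-rank) mini batch
ppo_mini_batch_size //= n_gpus // ulysses_sp   # 720 // 6 = 120
print(ppo_mini_batch_size)  # 120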
3. fsdp_workers and the configs above
3.1 ActorRolloutRefWorker.__init__
class ActorRolloutRefWorker(Worker):
"""
This worker can be instantiated as a standalone actor or a standalone rollout or a standalone reference policy
or a hybrid engine based on the config.rollout
"""
def __init__(self, config: DictConfig, role: str):
'''
config: the actor_rollout_ref part of the configuration
role: the role of this worker; in the usual case role = 'actor_rollout'
'''
super().__init__()
self.config = config
import torch.distributed
if not torch.distributed.is_initialized():
torch.distributed.init_process_group()
# build device mesh for FSDP
world_size = torch.distributed.get_world_size()
# TODO(sgm): support FSDP hybrid shard for larger model
'''
trainer.n_gpus_per_node=6 -> world_size = 6
actor_rollout_ref.fsdp_config.fsdp_size=-1 -> fsdp_size = -1
'''
self.device_mesh = create_device_mesh(world_size=world_size, fsdp_size=self.config.actor.fsdp_config.fsdp_size)
# build device mesh for Ulysses Sequence Parallel
self.ulysses_device_mesh = None
self.ulysses_sequence_parallel_size = self.config.actor.get('ulysses_sequence_parallel_size', 1)
dp = world_size // self.ulysses_sequence_parallel_size
if self.ulysses_sequence_parallel_size > 1:
# create a 2-D GPU mesh that supports data parallelism (DP) and sequence parallelism (SP) at the same time
print('self.ulysses_sequence_parallel_size: ', self.ulysses_sequence_parallel_size)
self.ulysses_device_mesh = init_device_mesh('cuda',
mesh_shape=(dp, self.ulysses_sequence_parallel_size),
mesh_dim_names=['dp', 'sp'])
self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh)
self.role = role
'''
As noted above, self.role = 'actor_rollout'
'''
assert self.role in ['actor', 'rollout', 'ref', 'actor_rollout', 'actor_rollout_ref']
self._is_actor = self.role in ['actor', 'actor_rollout', 'actor_rollout_ref']
self._is_rollout = self.role in ['rollout', 'actor_rollout', 'actor_rollout_ref']
self._is_ref = self.role in ['ref', 'actor_rollout_ref']
'''
Therefore
self._is_actor = True,
self._is_rollout = True,
self._is_ref = False
'''
self._is_offload_param = False
self._is_offload_optimizer = False
if self._is_actor:
self._is_offload_param = self.config.actor.fsdp_config.get('param_offload', False)
self._is_offload_optimizer = self.config.actor.fsdp_config.get('optimizer_offload', False)
'''
Per the actor_rollout_ref.actor config, the actor does not offload parameters or optimizer state when using FSDP:
actor_rollout_ref.actor.fsdp_config.param_offload=False
-> self._is_offload_param = False
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False
-> self._is_offload_optimizer = False
'''
elif self._is_ref:
# TODO: it seems that manual offload is slowly than FSDP offload
self._is_offload_param = self.config.ref.fsdp_config.get('param_offload', False)
# normalize config
if self._is_actor:
'''
Read from the config:
actor_rollout_ref.actor.ppo_mini_batch_size=60
-> self.config.actor.ppo_mini_batch_size = 60
actor_rollout_ref.rollout.n=12
-> self.config.rollout.n = 12
Update self.config.actor.ppo_mini_batch_size (60 -> 720):
-> 60 * 12 = 720
'''
self.config.actor.ppo_mini_batch_size *= self.config.rollout.n
'''
Judging from printed output, self.device_mesh.size() equals the number of GPUs. This run is single-node, so I am not sure whether that still holds with multiple nodes.
Thus self.device_mesh.size() = 6
actor_rollout_ref.actor.ulysses_sequence_parallel_size=1
-> self.ulysses_sequence_parallel_size = 1
TO-BE-VERIFIED:
ulysses_sequence_parallel_size = n presumably means n GPUs cooperate on one sequence?
With n = 1, this should mean sequence parallelism is effectively disabled for now.
'''
self.config.actor.ppo_mini_batch_size //= (self.device_mesh.size() // self.ulysses_sequence_parallel_size)
'''
Update self.config.actor.ppo_mini_batch_size again (720 -> 120).
self.config.actor.ppo_mini_batch_size = 720
First obtain the number of data-parallel shards:
shard = self.device_mesh.size() // self.ulysses_sequence_parallel_size
      = 6 // 1
      = 6
That is, the 720-sample mini batch is spread over 6 data-parallel ranks,
so each rank is assigned a per-GPU ppo_mini_batch_size of
self.config.actor.ppo_mini_batch_size = 720 // 6 = 120
'''
# this updating of self.config.actor.ppo_mini_batch_size is called normalization; after normalization it must be greater than 0
assert self.config.actor.ppo_mini_batch_size > 0, f'ppo_mini_batch_size {self.config.actor.ppo_mini_batch_size} should be larger than 0 after normalization'
# micro bsz
# self.config.actor.ppo_micro_batch_size is deprecated, so this branch can be ignored
if self.config.actor.ppo_micro_batch_size is not None:
self.config.actor.ppo_micro_batch_size //= (self.device_mesh.size() //
self.ulysses_sequence_parallel_size)
self.config.actor.ppo_micro_batch_size_per_gpu = self.config.actor.ppo_micro_batch_size
assert self.config.actor.ppo_mini_batch_size % self.config.actor.ppo_micro_batch_size_per_gpu == 0, \
f'normalized ppo_mini_batch_size {self.config.actor.ppo_mini_batch_size} should be divisible by ppo_micro_batch_size_per_gpu {self.config.actor.ppo_micro_batch_size_per_gpu}'
assert self.config.actor.ppo_mini_batch_size // self.config.actor.ppo_micro_batch_size_per_gpu > 0, \
f'normalized ppo_mini_batch_size {self.config.actor.ppo_mini_batch_size} should be larger than ppo_micro_batch_size_per_gpu {self.config.actor.ppo_micro_batch_size_per_gpu}'
# normalize rollout config
# the following only runs when self.config.rollout.log_prob_micro_batch_size is not None; in most cases (including ours, which sets the *_per_gpu variant instead) it is None, so this branch is skipped
if self._is_rollout and self.config.rollout.log_prob_micro_batch_size is not None:
'''
If log_prob_micro_batch_size were set instead of the *_per_gpu variant, e.g.
actor_rollout_ref.rollout.log_prob_micro_batch_size=8
-> self.config.rollout.log_prob_micro_batch_size = 8
then with
shard = self.device_mesh.size() // self.ulysses_sequence_parallel_size = 6 // 1 = 6
each shard (here one GPU per shard) would get self.config.rollout.log_prob_micro_batch_size = 8 // 6 = 1 sequence,
i.e. self.config.rollout.log_prob_micro_batch_size_per_gpu = 1
'''
self.config.rollout.log_prob_micro_batch_size //= (self.device_mesh.size() //
self.ulysses_sequence_parallel_size)
self.config.rollout.log_prob_micro_batch_size_per_gpu = self.config.rollout.log_prob_micro_batch_size
# normalize ref config
# the logic mirrors the computation of self.config.rollout.log_prob_micro_batch_size_per_gpu above
# it only runs when self.config.ref.log_prob_micro_batch_size is not None; in most cases it is None, so this branch is skipped
if self._is_ref and self.config.ref.log_prob_micro_batch_size is not None:
self.config.ref.log_prob_micro_batch_size //= (self.device_mesh.size() //
self.ulysses_sequence_parallel_size)
self.config.ref.log_prob_micro_batch_size_per_gpu = self.config.ref.log_prob_micro_batch_size
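One consequence of the two per-GPU numbers: each data-parallel rank optimizes over a mini batch of 120 rollout samples, split into micro batches of 8 for gradient accumulation. A minimal sketch of that arithmetic, mirroring the asserts above (assumption: verl's actor update loop accumulates gradients over these micro batches):

per_gpu_mini_batch = 120   # normalized ppo_mini_batch_size (see the walkthrough above)
micro_batch_per_gpu = 8    # actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu

assert per_gpu_mini_batch % micro_batch_per_gpu == 0
grad_accum_steps = per_gpu_mini_batch // micro_batch_per_gpu
print(grad_accum_steps)    # 15 forward/backward passes per optimizer step on each GPU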
3.2 ActorRolloutRefWorker._build_rollout
def _build_rollout(self):
from torch.distributed.device_mesh import init_device_mesh
# TODO(sgm): support FSDP hybrid shard for larger model
infer_tp = self.config.rollout.tensor_model_parallel_size
dp = self.world_size // infer_tp
assert self.world_size % infer_tp == 0, f'rollout world_size: {self.world_size} is not divisible by infer_tp: {infer_tp}'
'''
dp determines how many shards the raw batch is split into.
Since self.world_size = 6 (6 GPUs) and actor_rollout_ref.rollout.tensor_model_parallel_size=2,
dp = 6 // 2 = 3, i.e. the 60 prompts of the raw batch (data.train_batch_size=60) are split across three shards; each shard spans two GPUs and handles 60 // 3 = 20 prompts.
'''
rollout_device_mesh = init_device_mesh('cuda', mesh_shape=(dp, infer_tp), mesh_dim_names=['dp', 'infer_tp'])
print('rollout_device_mesh in ActorRolloutRefWorker._build_rollout: ', rollout_device_mesh)
'''
rollout_device_mesh stores the GPU sharding layout and is passed into vLLMRollout:
DeviceMesh('cuda', [[0, 1], [2, 3], [4, 5]], mesh_dim_names=('dp', 'infer_tp'))
'''
rollout_name = self.config.rollout.name
if rollout_name == 'hf':
from verl.workers.rollout import HFRollout
from verl.workers.sharding_manager import BaseShardingManager
rollout = HFRollout(module=self.actor_module_fsdp, config=self.config.rollout)
rollout_sharding_manager = BaseShardingManager()
# TODO: a sharding manager that do nothing?
elif 'vllm' in rollout_name or rollout_name == 'vllm':
from verl.workers.rollout.vllm_rollout import vLLMRollout, vllm_mode
from verl.workers.sharding_manager import FSDPVLLMShardingManager
log_gpu_memory_usage(f'Before building {rollout_name} rollout', logger=None)
local_path = copy_to_local(self.config.model.path)
if vllm_mode == 'customized':
rollout = vLLMRollout(actor_module=self.actor_module_fsdp,
config=self.config.rollout,
tokenizer=self.tokenizer,
model_hf_config=self.actor_model_config)
# note: I have lightly customized this part, so it differs slightly from the upstream verl source
elif vllm_mode == 'spmd':
if '...' in rollout_name:
...
else:
rollout_cls = vLLMRollout
rollout = rollout_cls(model_path=local_path,
config=self.config.rollout,
tokenizer=self.tokenizer,
model_hf_config=self.actor_model_config,
device_mesh=rollout_device_mesh)
else:
raise NotImplementedError("vllm_mode must be 'customized' or 'spmd'")
log_gpu_memory_usage(f'After building {rollout_name} rollout', logger=None)
if torch.distributed.get_world_size() == 1:
self.config.rollout.load_format = 'dummy_hf'
rollout_sharding_manager = FSDPVLLMShardingManager(module=self.actor_module_fsdp,
inference_engine=rollout.inference_engine,
model_config=self.actor_model_config,
full_params='hf' in self.config.rollout.load_format,
device_mesh=rollout_device_mesh)
log_gpu_memory_usage('After building sharding manager', logger=None)
elif rollout_name == 'sglang':
from verl.workers.rollout.sglang_rollout import SGLangRollout
# NOTE(linjunrong): Due to recent fp8 support in SGLang. Now importing any symbol relate to SGLang's model_runner would check CUDA device capability.
# However, due to veRL's setting, the main process of ray can not find any CUDA device, which would potentially lead to:
# "RuntimeError: No CUDA GPUs are available".
# For this reason, sharding_manager.__init__ should not import FSDPSGLangShardingManager and we import it here use the abs path.
# check: https://github.com/sgl-project/sglang/blob/00f42707eaddfc2c0528e5b1e0094025c640b7a0/python/sglang/srt/layers/quantization/fp8_utils.py#L76
from verl.workers.sharding_manager.fsdp_sglang import FSDPSGLangShardingManager
log_gpu_memory_usage(f'Before building {rollout_name} rollout', logger=None)
rollout = SGLangRollout(actor_module=self.config.model.path,
config=self.config.rollout,
tokenizer=self.tokenizer,
model_hf_config=self.actor_model_config)
log_gpu_memory_usage(f'After building {rollout_name} rollout', logger=None)
if torch.distributed.get_world_size() == 1:
self.config.rollout.load_format = 'dummy_hf'
rollout_sharding_manager = FSDPSGLangShardingManager(module=self.actor_module_fsdp,
inference_engine=rollout.inference_engine,
model_config=self.actor_model_config,
full_params='hf' in self.config.rollout.load_format,
device_mesh=rollout_device_mesh)
log_gpu_memory_usage('After building sharding manager', logger=None)
return rollout, rollout_sharding_manager
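To make the (dp, infer_tp) layout concrete, here is a tiny standalone sketch that reproduces the grouping with plain Python (no torch.distributed required); the real mesh is of course built by init_device_mesh:

world_size = 6   # trainer.n_gpus_per_node * trainer.nnodes
infer_tp = 2     # actor_rollout_ref.rollout.tensor_model_parallel_size
dp = world_size // infer_tp  # 3 rollout (vLLM) workers

# arrange GPU ranks into a (dp, infer_tp) grid, mirroring
# DeviceMesh('cuda', [[0, 1], [2, 3], [4, 5]], mesh_dim_names=('dp', 'infer_tp'))
mesh = [list(range(i * infer_tp, (i + 1) * infer_tp)) for i in range(dp)]
print(mesh)      # [[0, 1], [2, 3], [4, 5]]
print(60 // dp)  # each dp row (one vLLM engine) serves 60 // 3 = 20 prompts per step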
3.3 ActorRolloutRefWorker.generate_sequences
generate_sequences is decorated with @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO), which takes care of scattering the data across the GPU shards and gathering their results. _build_rollout above splits the 60 prompts of the raw batch across three shards; the vLLM engine on each shard handles 20 prompts and performs n=12 rollouts, yielding 20*12 = 240 sequences. generate_sequences then aggregates the rollout results of the 3 shards, giving 3*240 = 720 sequences in total. This matches what we saw in ray_trainer.py, where
gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)
returns 720 sequences. (A minimal sketch of this scatter/gather pattern follows the code below.)
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def generate_sequences(self, prompts: DataProto):
# Support all hardwares
prompts = prompts.to(torch.cuda.current_device())
assert self._is_rollout
if self._is_offload_param:
load_fsdp_model_to_gpu(self.actor_module_fsdp)
meta_info = {
'eos_token_id':
self.generation_config.eos_token_id
if self.generation_config is not None else self.tokenizer.eos_token_id,
'pad_token_id':
self.generation_config.pad_token_id
if self.generation_config is not None else self.tokenizer.pad_token_id,
}
prompts.meta_info.update(meta_info)
with self.rollout_sharding_manager:
# after parameters sync with rollout, offload actor model to CPU
if self._is_offload_param:
offload_fsdp_model_to_cpu(self.actor_module_fsdp)
if self._is_offload_optimizer:
offload_fsdp_optimizer(optimizer=self.actor_optimizer)
log_gpu_memory_usage('After entering rollout sharding manager', logger=logger)
prompts = self.rollout_sharding_manager.preprocess_data(prompts)
output = self.rollout.generate_sequences(prompts=prompts)
log_gpu_memory_usage('After rollout generation', logger=logger)
output = self.rollout_sharding_manager.postprocess_data(output)
output = output.to('cpu')
# clear kv cache
log_gpu_memory_usage('After recompute log prob', logger=logger)
return output
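The scatter/gather behaviour of Dispatch.DP_COMPUTE_PROTO, as described above, can be mimicked with a few lines of plain torch. fake_rollout_worker below is a hypothetical stand-in for a vLLM engine; the real dispatch/collect logic lives in verl's decorator machinery:

import torch

def fake_rollout_worker(prompt_ids: torch.Tensor, n: int = 12) -> torch.Tensor:
    # pretend every prompt on this shard is sampled n times
    return prompt_ids.repeat_interleave(n, dim=0)

gen_batch = torch.zeros(60, 8192, dtype=torch.long)   # 60 prompts

# dispatch: chunk the batch across dp = 3 rollout workers (20 prompts each)
shards = gen_batch.chunk(3, dim=0)

# each worker rolls out its shard independently (20 * 12 = 240 sequences each) ...
outputs = [fake_rollout_worker(s) for s in shards]

# ... and collect: concatenate the per-worker outputs back into one batch
gen_batch_output = torch.cat(outputs, dim=0)
print(gen_batch_output.shape)  # torch.Size([720, 8192])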
4. Rollout and the configs above
As above, the parameters relevant to actor rollout are:
data.train_batch_size=60
trainer.n_gpus_per_node=6
trainer.nnodes=1
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8
actor_rollout_ref.rollout.n=12
actor_rollout_ref.rollout.tensor_model_parallel_size=2
Q: How are workers assigned during rollout, and how many sequences does each worker roll out?
- data.train_batch_size=60 means each step processes 60 sequences drawn from the training data.
- trainer.n_gpus_per_node=6, trainer.nnodes=1 means 6 GPUs in total.
- actor_rollout_ref.rollout.tensor_model_parallel_size=2 means every two GPUs form one actor (rollout) worker, which uses vLLM to handle its slice of the batch. So there are 6/2 = 3 actor workers in total.
For load balancing we want every worker to handle the same number of sequences, so each worker rolls out 20 prompts, computed as follows.
$\frac{\text{data.train\_batch\_size}}{\text{worker\_num}} = \frac{\text{data.train\_batch\_size} \times \text{actor\_rollout\_ref.rollout.tensor\_model\_parallel\_size}}{\text{trainer.n\_gpus\_per\_node} \times \text{trainer.nnodes}} = \frac{60 \times 2}{6 \times 1} = 20$
- actor_rollout_ref.rollout.n=12 means that for each sequence the actor draws 12 samples from the current state. So the vLLM engine of each actor worker has to generate
$20 \times \text{actor\_rollout\_ref.rollout.n} = 20 \times 12 = 240$ sequences.
Therefore, after all workers have finished their rollouts, there are
$240 \times \text{worker\_num} = 240 \times 3 = 720$
sequences in total, which then flow into the pipeline of verl/verl/trainer/ppo/ray_trainer.py:
- Compute $\pi_{\text{old}}\left(\tau_{i,(t)} \mid \tau_{i,<t}\right)$
# use actor_rollout_wg to compute the old-policy log_prob of every token in every rollout sample
with _timer('old_log_prob', timing_raw):
old_log_prob = self.actor_rollout_wg.compute_log_prob(batch)
batch = batch.union(old_log_prob)
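Conceptually, compute_log_prob returns one log-probability per generated token. A minimal torch sketch of that per-token gather (dummy logits; not verl's implementation, which additionally handles padding, attention masks, and micro-batching):

import torch
import torch.nn.functional as F

batch_size, resp_len, vocab = 4, 16, 32000
logits = torch.randn(batch_size, resp_len, vocab)            # actor logits over the response tokens
response_ids = torch.randint(vocab, (batch_size, resp_len))  # the sampled response tokens

log_probs = F.log_softmax(logits, dim=-1)
# pick out log pi_old(tau_{i,(t)} | tau_{i,<t}) for each generated token
old_log_prob = torch.gather(log_probs, dim=-1, index=response_ids.unsqueeze(-1)).squeeze(-1)
print(old_log_prob.shape)  # torch.Size([4, 16]): one log-prob per token per sample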
- Compute $\pi_{\text{ref}}\left(\tau_{i,(t)} \mid \tau_{i,<t}\right)$
# use ref_policy_wg to compute the ref-policy log_prob of every rollout sample
if self.use_reference_policy: #
# compute reference log_prob
with _timer('ref', timing_raw):
ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
batch = batch.union(ref_log_prob)
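The ref log-probs exist so the actor can be regularized towards the reference policy with a KL term, either as a reward penalty (apply_kl_penalty) or as a loss term (use_kl_loss). As a hedged illustration, GRPO typically uses the low-variance "k3" estimator; verl exposes several kl_penalty variants, so treat this as one possible choice rather than the default:

import torch

# per-token log-probs under the current policy and the reference policy (dummy values)
log_prob = torch.randn(4, 16) * 0.1
ref_log_prob = log_prob + torch.randn(4, 16) * 0.05

# k3 estimator of KL(pi || pi_ref), computed per token; always non-negative:
#   kl ~= exp(ref - logp) - (ref - logp) - 1
log_ratio = ref_log_prob - log_prob
kl = torch.exp(log_ratio) - log_ratio - 1.0
print(kl.mean())  # a small non-negative number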
- Compute the advantage $A_i$
# use critic_wg to compute the value of every rollout sample
# each rollout sample's advantage is computed from its value
# however, GRPO drops the critic model: the reward a rollout sample receives is its value!
if self.use_critic:
# self.use_critic = False in GRPO
with _timer('values', timing_raw):
values = self.critic_wg.compute_values(batch)
batch = batch.union(values)
with _timer('adv', timing_raw):
# in PPO, a separate reward model (rm_wg) is needed to score the rollouts
if self.use_rm:
# self.use_rm = False in GRPO
reward_tensor = self.rm_wg.compute_rm_score(batch)
batch = batch.union(reward_tensor)
# in GRPO, a rule-based reward is used directly
reward_tensor = self.reward_fn(batch)
# here we obtain, for every rollout sample, the score of every token
batch.batch['token_level_scores'] = reward_tensor
# compute rewards: apply_kl_penalty folds the KL penalty into the reward if applicable
# note the negation below: this branch only runs when actor_rollout_ref.actor.use_kl_loss is False (KL applied as a reward penalty instead of a loss term)
if not self.config.actor_rollout_ref.actor.get('use_kl_loss', False):
batch, kl_metrics = apply_kl_penalty(batch,
kl_ctrl=self.kl_ctrl,
kl_penalty=self.config.algorithm.kl_penalty)
metrics.update(kl_metrics)
else:
batch.batch['token_level_rewards'] = batch.batch['token_level_scores']
# compute the advantage for every rollout sample
batch = compute_advantage(
    batch,  # batch.batch['token_level_rewards'] holds the per-step reward (with the KL penalty already folded in, if applicable)
    adv_estimator=self.config.algorithm.adv_estimator,  # which advantage estimator to use (e.g. GAE for PPO, group-relative for GRPO)
    gamma=self.config.algorithm.gamma,  # discount factor (gamma): how quickly future rewards decay
    lam=self.config.algorithm.lam,  # GAE lambda: bias-variance trade-off (higher = rely more on long-horizon estimates)
    num_repeat=self.config.actor_rollout_ref.rollout.n
)
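For the GRPO estimator, num_repeat=rollout.n is what lets compute_advantage group the 720 rollout samples back into 60 groups of 12 that share a prompt and normalize their rewards within each group. A hedged, minimal version of that grouping (outcome-level scalar rewards; verl groups by sample uid rather than by reshaping):

import torch

n_prompts, n = 60, 12                   # data.train_batch_size, actor_rollout_ref.rollout.n
rewards = torch.randn(n_prompts * n)    # one scalar (outcome) reward per rollout sample

grouped = rewards.view(n_prompts, n)    # samples of the same prompt sit in one row
mean = grouped.mean(dim=1, keepdim=True)
std = grouped.std(dim=1, keepdim=True)

advantages = ((grouped - mean) / (std + 1e-6)).view(-1)  # back to shape [720]
print(advantages.shape)  # torch.Size([720])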
After computing $\pi_{\text{old}}\left(\tau_{i,(t)} \mid \tau_{i,<t}\right)$, $\pi_{\text{ref}}\left(\tau_{i,(t)} \mid \tau_{i,<t}\right)$, and $A_i$ in turn, the individual models are updated:
- Critic model update (not needed for GRPO)
if self.use_critic:
with _timer('update_critic', timing_raw):
critic_output = self.critic_wg.update_critic(batch)
critic_output_metrics = reduce_metrics(critic_output.meta_info['metrics'])
metrics.update(critic_output_metrics)
- Actor model update
# trainer.critic_warmup is usually set to 0
# global_steps: which batch of train_dataloader we are currently processing
# while self.global_steps < self.config.trainer.critic_warmup, the actor model is not updated
if self.config.trainer.critic_warmup <= self.global_steps:
# once critic warmup is over, update the actor model
with _timer('update_actor', timing_raw):
actor_output = self.actor_rollout_wg.update_actor(batch)
actor_output_metrics = reduce_metrics(actor_output.meta_info['metrics'])
# update the actor-related metrics
metrics.update(actor_output_metrics)
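Inside update_actor, the advantages and the old/new per-token log-probs meet in the PPO-style clipped surrogate objective, which GRPO keeps. A hedged, self-contained sketch of that loss (token masking, the optional KL loss term, and the configured clip ratio are omitted; clipped_pg_loss is my own helper name, not a verl function):

import torch

def clipped_pg_loss(new_log_prob, old_log_prob, advantages, clip_ratio=0.2):
    # probability ratio pi_theta / pi_old, per token
    ratio = torch.exp(new_log_prob - old_log_prob)
    # clipped surrogate: take the pessimistic of the unclipped and clipped objectives
    loss_unclipped = -advantages * ratio
    loss_clipped = -advantages * torch.clamp(ratio, 1.0 - clip_ratio, 1.0 + clip_ratio)
    return torch.max(loss_unclipped, loss_clipped).mean()

new_lp = torch.randn(4, 16) * 0.1
old_lp = new_lp.detach() + torch.randn(4, 16) * 0.05
adv = torch.randn(4, 16)
print(clipped_pg_loss(new_lp, old_lp, adv))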
- Validation under the specified conditions
# validate with self._validate()
if self.val_reward_fn is not None and self.config.trainer.test_freq > 0 and \
(is_last_step or self.global_steps % self.config.trainer.test_freq == 0):
with _timer('testing', timing_raw):
val_metrics: dict = self._validate()
if is_last_step:
last_val_metrics = val_metrics
metrics.update(val_metrics)
- Checkpoint saving under the specified conditions
# save a checkpoint with self._save_checkpoint()
if self.config.trainer.save_freq > 0 and ( is_last_step or \
self.global_steps % self.config.trainer.save_freq == 0):
with _timer('save_checkpoint', timing_raw):
self._save_checkpoint()