verl - A Deep Dive into verl's Core Code (and Those Confusing Batch Sizes)

In RL training we run into all kinds of batches, and it is easy to get lost. In this post we walk through what each of these batch sizes means and how they relate to one another. The post mainly covers code from the following files:
verl/verl/trainer/ppo/ray_trainer.py
verl/verl/workers/fsdp_workers.py

1. Background

We train GRPO on top of FSDP. GRPO is an efficient variant of PPO proposed by DeepSeek. Compared with PPO:

  • GRPO drops the reward model and computes rewards directly with a rule-based function.
  • GRPO drops the critic model, so no extra value estimate $V_i$ is computed.
  • The advantage is computed directly from the reward $R_i$, whereas the original PPO (and related algorithms) relies on $V_i$. In other words, GRPO uses a rollout sample's reward directly as the measure of its value; see the sketch below.
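
As a concrete illustration of the last bullet, here is a minimal sketch of a group-relative (GRPO-style) advantage: the n rollouts of one prompt form a group, and a sample's advantage is its reward standardized by the group's mean and std. This sketches the idea only; it is not verl's actual compute_advantage code.

import torch

def grpo_advantage(rewards: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """rewards: shape (num_prompts, n) -- one scalar reward per rollout sample."""
    mean = rewards.mean(dim=-1, keepdim=True)
    std = rewards.std(dim=-1, keepdim=True)
    return (rewards - mean) / (std + eps)

# example: 2 prompts, n=4 rollouts each
rewards = torch.tensor([[1.0, 0.0, 1.0, 0.0],
                        [0.5, 0.5, 1.0, 0.0]])
print(grpo_advantage(rewards))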

2. Batch-size-related parameters

In verl/verl/trainer/config/ppo_trainer.yaml we encounter several parameters related to batch size. Let's break them down one by one.
Warning: the values in verl/verl/trainer/config/ppo_trainer.yaml may be overridden by the parameters passed in the launch script.

2.1 General

data.train_batch_size=60
trainer.n_gpus_per_node=6 
trainer.nnodes=1 

Here we have a single node with 6 GPUs in total.
Note:

  • data.train_batch_size must be divisible by trainer.n_gpus_per_node, otherwise an error is raised (a minimal sketch of this check follows below).
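
A minimal sketch of this divisibility constraint, plugged with the numbers above; this only illustrates the check and is not verl's actual validation code.

train_batch_size = 60      # data.train_batch_size
n_gpus = 6 * 1             # trainer.n_gpus_per_node * trainer.nnodes
assert train_batch_size % n_gpus == 0, (
    f"data.train_batch_size ({train_batch_size}) must be divisible by "
    f"the total number of GPUs ({n_gpus})"
)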

2.2 actor_rollout_ref

## actor
actor_rollout_ref.actor.ppo_mini_batch_size=60
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=8  # this one does not seem to matter much
actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 
### actor.fsdp_config
actor_rollout_ref.actor.fsdp_config.param_offload=False 
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False 
## fsdp_config
actor_rollout_ref.fsdp_config.fsdp_size=-1 
## rollout
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8 
actor_rollout_ref.rollout.n=12
actor_rollout_ref.rollout.tensor_model_parallel_size=2
## ref
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=8 

We can get a rough feel for these batches in the fit() function of ray_trainer.py:

with _timer('step', timing_raw): # time the whole training step
	# whole step = compute old policy + compute ref policy + compute adv (incl. value + reward)
	#			+ update critic model + update actor model + sometimes validation + sometimes ckpt saving
     with _timer('gen', timing_raw): # time sequence generation
         # use the actor model to generate sequences
         # gen_batch: input_ids and everything else needed for generation
         # gen_batch_output: the generated sequences, their log-probs, etc.
         print('gen_batch shape: ', gen_batch.batch['input_ids'].shape)
         '''before the actor rollout:
         gen_batch shape:  torch.Size([60, 8192]),
         data.train_batch_size = 60
         	->  gen_batch.batch['input_ids'].shape[0] = 60
         '''
         gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)
         print("gen_batch_output.batch['prompt_token_ids'].shape: ", gen_batch_output.batch['prompts'].shape)
         '''actor 进行rollout之后,
         gen_batch_output.batch['prompt_token_ids'].shape:  torch.Size([720, 8192])
         data.train_batch_size = 60, actor_rollout_ref.rollout.n=12
         	->  gen_batch_output.batch['prompt_token_ids'].shape.shape[0] = 60 * 12 =720
         '''

TL;DR:
In class ActorRolloutRefWorker(Worker) of verl/verl/workers/fsdp_workers.py, we deal with the following parameters:

# For the actor:
# actor_rollout_ref.actor.ppo_mini_batch_size, actor_rollout_ref.rollout.n, and the number of GPUs
# together determine the normalized (per data-parallel rank) ppo_mini_batch_size

# when computing the log_prob of rollout samples during rollout, the number of samples processed per GPU is set directly by
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8 

# when computing the reference model's log_prob, the number of samples processed per GPU is set directly by
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=8 
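
To make these comments concrete, here is a back-of-the-envelope sketch, using this config's numbers, of how the actor's ppo_mini_batch_size is rescaled; the actual normalization happens in ActorRolloutRefWorker.__init__ and is walked through in Section 3.1.

ppo_mini_batch_size = 60   # actor_rollout_ref.actor.ppo_mini_batch_size
rollout_n = 12             # actor_rollout_ref.rollout.n
world_size = 6             # trainer.n_gpus_per_node * trainer.nnodes
sp_size = 1                # actor_rollout_ref.actor.ulysses_sequence_parallel_size

ppo_mini_batch_size *= rollout_n                  # 60 * 12 = 720 rollout samples per step
ppo_mini_batch_size //= world_size // sp_size     # 720 // 6 = 120 samples per data-parallel rank
print(ppo_mini_batch_size)                        # 120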

3. fsdp_workers and the configuration above

3.1 The ActorRolloutRefWorker constructor

class ActorRolloutRefWorker(Worker):
    """
    This worker can be instantiated as a standalone actor or a standalone rollout or a standalone reference policy
    or a hybrid engine based on the config.rollout
    """

    def __init__(self, config: DictConfig, role: str):
    	'''
    		config: the actor_rollout_ref-related configuration (DictConfig)
    		role: the role of this worker; in the common case role = 'actor_rollout'
    	'''
        super().__init__()
        self.config = config
        import torch.distributed
        if not torch.distributed.is_initialized():
            torch.distributed.init_process_group()

        # build device mesh for FSDP
        world_size = torch.distributed.get_world_size()
        # TODO(sgm): support FSDP hybrid shard for larger model
        '''
        	trainer.n_gpus_per_node=6  -> world_size = 6
        	actor_rollout_ref.fsdp_config.fsdp_size=-1 -> fsdp_size = -1
        '''

        self.device_mesh = create_device_mesh(world_size=world_size, fsdp_size=self.config.actor.fsdp_config.fsdp_size)

        # build device mesh for Ulysses Sequence Parallel
        self.ulysses_device_mesh = None
        self.ulysses_sequence_parallel_size = self.config.actor.get('ulysses_sequence_parallel_size', 1)
        dp = world_size // self.ulysses_sequence_parallel_size
        if self.ulysses_sequence_parallel_size > 1:
            # build a 2-D GPU mesh that supports both data parallelism (DP) and sequence parallelism (SP)
            print('self.ulysses_sequence_parallel_size: ', self.ulysses_sequence_parallel_size)
            self.ulysses_device_mesh = init_device_mesh('cuda',
                                                        mesh_shape=(dp, self.ulysses_sequence_parallel_size),
                                                        mesh_dim_names=['dp', 'sp'])

        self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh)

        self.role = role
        '''
        	as noted in the docstring above, self.role = 'actor_rollout'
        '''
        assert self.role in ['actor', 'rollout', 'ref', 'actor_rollout', 'actor_rollout_ref']

        self._is_actor = self.role in ['actor', 'actor_rollout', 'actor_rollout_ref']
        self._is_rollout = self.role in ['rollout', 'actor_rollout', 'actor_rollout_ref']
        self._is_ref = self.role in ['ref', 'actor_rollout_ref']
		'''
			so
			self._is_actor = True,
			self._is_rollout = True,
			self._is_ref = False
		'''
        self._is_offload_param = False
        self._is_offload_optimizer = False
        if self._is_actor:
            self._is_offload_param = self.config.actor.fsdp_config.get('param_offload', False)
            self._is_offload_optimizer = self.config.actor.fsdp_config.get('optimizer_offload', False)
            '''
            	per the actor_rollout_ref.actor config, the actor does not offload params or optimizer state under FSDP:
            	actor_rollout_ref.actor.fsdp_config.param_offload=False
            		-> self._is_offload_param = False
            	actor_rollout_ref.actor.fsdp_config.optimizer_offload=False
            		-> self._is_offload_optimizer = False
            '''
        elif self._is_ref:
            # TODO: it seems that manual offload is slowly than FSDP offload
            self._is_offload_param = self.config.ref.fsdp_config.get('param_offload', False)

        # normalize config
        if self._is_actor:
            '''
            	read from the config:
            	actor_rollout_ref.actor.ppo_mini_batch_size=60
            		-> self.config.actor.ppo_mini_batch_size = 60
				actor_rollout_ref.rollout.n=12
					-> self.config.rollout.n = 12
				update self.config.actor.ppo_mini_batch_size (60 -> 720):
					-> 60 * 12 = 720
            '''
            self.config.actor.ppo_mini_batch_size *= self.config.rollout.n
            '''
            	judging from printed output, self.device_mesh.size() equals the number of GPUs (single node here; not verified on multi-node).
            	So self.device_mesh.size() = 6
            	actor_rollout_ref.actor.ulysses_sequence_parallel_size=1
            	 -> self.ulysses_sequence_parallel_size = 1
            	TO-BE-VERIFIED:
            		ulysses_sequence_parallel_size = n means n GPUs cooperate on a single sequence?
            		With n = 1, sequence parallelism is effectively disabled.
            '''
            self.config.actor.ppo_mini_batch_size //= (self.device_mesh.size() // self.ulysses_sequence_parallel_size)
            '''
            	update self.config.actor.ppo_mini_batch_size again (720 -> 120)
            	self.config.actor.ppo_mini_batch_size = 720
            	first compute the number of data-parallel shards:
            		shard = self.device_mesh.size() // self.ulysses_sequence_parallel_size
            			  = 6 // 1
            			  = 6
            		i.e. the 720 samples of ppo_mini_batch_size are spread over 6 data-parallel ranks.
            	each rank is then assigned a ppo_mini_batch_size of
            		self.config.actor.ppo_mini_batch_size = 720 // 6 = 120
            '''
            # this rescaling of self.config.actor.ppo_mini_batch_size is called normalization; after normalization it must be greater than 0
            assert self.config.actor.ppo_mini_batch_size > 0, f'ppo_mini_batch_size {self.config.actor.ppo_mini_batch_size} should be larger than 0 after normalization'
            # micro bsz
			# self.config.actor.ppo_micro_batch_size is deprecated; this branch can be ignored
            if self.config.actor.ppo_micro_batch_size is not None:
                self.config.actor.ppo_micro_batch_size //= (self.device_mesh.size() //
                                                            self.ulysses_sequence_parallel_size)
                self.config.actor.ppo_micro_batch_size_per_gpu = self.config.actor.ppo_micro_batch_size
                assert self.config.actor.ppo_mini_batch_size % self.config.actor.ppo_micro_batch_size_per_gpu == 0, \
                    f'normalized ppo_mini_batch_size {self.config.actor.ppo_mini_batch_size} should be divisible by ppo_micro_batch_size_per_gpu {self.config.actor.ppo_micro_batch_size_per_gpu}'
                assert self.config.actor.ppo_mini_batch_size // self.config.actor.ppo_micro_batch_size_per_gpu > 0, \
                    f'normalized ppo_mini_batch_size {self.config.actor.ppo_mini_batch_size} should be larger than ppo_micro_batch_size_per_gpu {self.config.actor.ppo_micro_batch_size_per_gpu}'

        # normalize rollout config,
        # the following only runs when self.config.rollout.log_prob_micro_batch_size is not None; in most cases it is None (so this is skipped)
        if self._is_rollout and self.config.rollout.log_prob_micro_batch_size is not None:
            '''
            	read from the config:
            	actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8
            		-> self.config.rollout.log_prob_micro_batch_size = 8
            	shard = self.device_mesh.size() // self.ulysses_sequence_parallel_size
            	      = 6 // 1
            	      = 6
            	each shard (here, one shard per GPU) gets self.config.rollout.log_prob_micro_batch_size = 8 // 6 = 1 sequence
            	self.config.rollout.log_prob_micro_batch_size_per_gpu = 1
            '''
            self.config.rollout.log_prob_micro_batch_size //= (self.device_mesh.size() //
                                                               self.ulysses_sequence_parallel_size)
            self.config.rollout.log_prob_micro_batch_size_per_gpu = self.config.rollout.log_prob_micro_batch_size
        # normalize ref config
        # the logic mirrors the computation of self.config.rollout.log_prob_micro_batch_size_per_gpu above.
        # the following only runs when self.config.ref.log_prob_micro_batch_size is not None; in most cases it is None (so this is skipped)
        if self._is_ref and self.config.ref.log_prob_micro_batch_size is not None:
            self.config.ref.log_prob_micro_batch_size //= (self.device_mesh.size() //
                                                           self.ulysses_sequence_parallel_size)
            self.config.ref.log_prob_micro_batch_size_per_gpu = self.config.ref.log_prob_micro_batch_size

3.2 ActorRolloutRefWorker._build_rollout

    def _build_rollout(self):
        from torch.distributed.device_mesh import init_device_mesh
        # TODO(sgm): support FSDP hybrid shard for larger model
        infer_tp = self.config.rollout.tensor_model_parallel_size
        dp = self.world_size // infer_tp 
        assert self.world_size % infer_tp == 0, f'rollout world_size: {self.world_size} is not divisible by infer_tp: {infer_tp}'
        '''
          dp determines how many shards the raw batch is split into.
          since self.world_size = 6 (6 GPUs) and actor_rollout_ref.rollout.tensor_model_parallel_size=2,
          dp = 6 // 2 = 3, i.e. the 60 prompts of the raw batch (data.train_batch_size=60) are split across three shards.
          each shard owns two GPUs and handles 60 // 3 = 20 prompts.
        '''
        rollout_device_mesh = init_device_mesh('cuda', mesh_shape=(dp, infer_tp), mesh_dim_names=['dp', 'infer_tp'])
        print('rollout_device_mesh in ActorRolloutRefWorker._build_rollout: ', rollout_device_mesh)
        '''
        	rollout_device_mesh stores the GPU sharding layout and is passed into vLLMRollout.
        	DeviceMesh('cuda', [[0, 1], [2, 3], [4, 5]], mesh_dim_names=('dp', 'infer_tp'))
        '''
        rollout_name = self.config.rollout.name
        if rollout_name == 'hf':
            from verl.workers.rollout import HFRollout
            from verl.workers.sharding_manager import BaseShardingManager
            rollout = HFRollout(module=self.actor_module_fsdp, config=self.config.rollout)
            rollout_sharding_manager = BaseShardingManager()
            # TODO: a sharding manager that do nothing?

        elif 'vllm' in rollout_name or rollout_name == 'vllm':
            from verl.workers.rollout.vllm_rollout import vLLMRollout, vllm_mode
            from verl.workers.sharding_manager import FSDPVLLMShardingManager
            log_gpu_memory_usage(f'Before building {rollout_name} rollout', logger=None)
            local_path = copy_to_local(self.config.model.path)
            if vllm_mode == 'customized':
                rollout = vLLMRollout(actor_module=self.actor_module_fsdp,
                                      config=self.config.rollout,
                                      tokenizer=self.tokenizer,
                                      model_hf_config=self.actor_model_config)
            # note: this part carries some of my own modifications and differs slightly from the upstream verl source
            elif vllm_mode == 'spmd':
                if '...' in rollout_name:
                    ...
                else:
                    rollout_cls = vLLMRollout

                rollout = rollout_cls(model_path=local_path,
                                      config=self.config.rollout,
                                      tokenizer=self.tokenizer,
                                      model_hf_config=self.actor_model_config,
                                      device_mesh=rollout_device_mesh)
            else:
                raise NotImplementedError("vllm_mode must be 'customized' or 'spmd'")
            log_gpu_memory_usage(f'After building {rollout_name} rollout', logger=None)
            if torch.distributed.get_world_size() == 1:
                self.config.rollout.load_format = 'dummy_hf'
            rollout_sharding_manager = FSDPVLLMShardingManager(module=self.actor_module_fsdp,
                                                               inference_engine=rollout.inference_engine,
                                                               model_config=self.actor_model_config,
                                                               full_params='hf' in self.config.rollout.load_format,
                                                               device_mesh=rollout_device_mesh)
            log_gpu_memory_usage('After building sharding manager', logger=None)

        elif rollout_name == 'sglang':
            from verl.workers.rollout.sglang_rollout import SGLangRollout
            # NOTE(linjunrong): Due to recent fp8 support in SGLang. Now importing any symbol relate to SGLang's model_runner would check CUDA device capability.
            # However, due to veRL's setting, the main process of ray can not find any CUDA device, which would potentially lead to:
            # "RuntimeError: No CUDA GPUs are available".
            # For this reason, sharding_manager.__init__ should not import FSDPSGLangShardingManager and we import it here use the abs path.
            # check: https://github.com/sgl-project/sglang/blob/00f42707eaddfc2c0528e5b1e0094025c640b7a0/python/sglang/srt/layers/quantization/fp8_utils.py#L76
            from verl.workers.sharding_manager.fsdp_sglang import FSDPSGLangShardingManager
            log_gpu_memory_usage(f'Before building {rollout_name} rollout', logger=None)
            rollout = SGLangRollout(actor_module=self.config.model.path,
                                    config=self.config.rollout,
                                    tokenizer=self.tokenizer,
                                    model_hf_config=self.actor_model_config)
            log_gpu_memory_usage(f'After building {rollout_name} rollout', logger=None)

            if torch.distributed.get_world_size() == 1:
                self.config.rollout.load_format = 'dummy_hf'
            rollout_sharding_manager = FSDPSGLangShardingManager(module=self.actor_module_fsdp,
                                                                 inference_engine=rollout.inference_engine,
                                                                 model_config=self.actor_model_config,
                                                                 full_params='hf' in self.config.rollout.load_format,
                                                                 device_mesh=rollout_device_mesh)
            log_gpu_memory_usage('After building sharding manager', logger=None)

        return rollout, rollout_sharding_manager

3.3 ActorRolloutRefWorker.generate_sequences

generate_sequences is decorated with @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO),
which distributes the data across the GPU shards and gathers the results back. As shown above, _build_rollout splits the 60 prompts of the raw batch across three shards; each shard's vLLM engine processes 20 prompts and performs n=12 rollouts, yielding 20*12 = 240 sequences.

generate_sequences then aggregates the rollout results of the 3 shards, giving 3*240 = 720 sequences in total. This matches what we saw earlier in ray_trainer.py: after executing

gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)

we obtain 720 sequences (a toy sketch of this chunk-and-gather flow follows after the code below).

@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def generate_sequences(self, prompts: DataProto):
     # Support all hardwares
     prompts = prompts.to(torch.cuda.current_device())

     assert self._is_rollout
     if self._is_offload_param:
         load_fsdp_model_to_gpu(self.actor_module_fsdp)

     meta_info = {
         'eos_token_id':
             self.generation_config.eos_token_id
             if self.generation_config is not None else self.tokenizer.eos_token_id,
         'pad_token_id':
             self.generation_config.pad_token_id
             if self.generation_config is not None else self.tokenizer.pad_token_id,
     }
     prompts.meta_info.update(meta_info)
     with self.rollout_sharding_manager:

         # after parameters sync with rollout, offload actor model to CPU
         if self._is_offload_param:
             offload_fsdp_model_to_cpu(self.actor_module_fsdp)
         if self._is_offload_optimizer:
             offload_fsdp_optimizer(optimizer=self.actor_optimizer)

         log_gpu_memory_usage('After entering rollout sharding manager', logger=logger)

         prompts = self.rollout_sharding_manager.preprocess_data(prompts)
         output = self.rollout.generate_sequences(prompts=prompts)
         log_gpu_memory_usage('After rollout generation', logger=logger)

         output = self.rollout_sharding_manager.postprocess_data(output)

     output = output.to('cpu')

     # clear kv cache
     log_gpu_memory_usage('After recompute log prob', logger=logger)
     return output
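
To tie the numbers together, here is a toy sketch of the chunk-then-gather bookkeeping described above. It only mimics what a DP_COMPUTE_PROTO dispatch does conceptually; it is not the real Dispatch implementation.

train_batch_size = 60      # data.train_batch_size
dp = 3                     # 6 GPUs // tensor_model_parallel_size (2)
n = 12                     # actor_rollout_ref.rollout.n
chunk = train_batch_size // dp                       # 20 prompts per rollout worker

prompts = list(range(train_batch_size))
shards = [prompts[i * chunk:(i + 1) * chunk] for i in range(dp)]

# each worker's vLLM engine samples n completions per prompt -> 20 * 12 = 240
per_worker_outputs = [[(p, k) for p in shard for k in range(n)] for shard in shards]

# gather: concatenate the 3 shards back together -> 3 * 240 = 720 sequences
gathered = [seq for out in per_worker_outputs for seq in out]
print(len(gathered))  # 720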

4. Rollout and the configuration above

As listed above, the parameters relevant to the actor's rollout are:

data.train_batch_size=60
trainer.n_gpus_per_node=6 
trainer.nnodes=1 
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8 
actor_rollout_ref.rollout.n=12
actor_rollout_ref.rollout.tensor_model_parallel_size=2

Q: How are workers assigned during rollout, and how many sequences is each worker responsible for?

  • data.train_batch_size=60 means each step processes 60 prompt sequences drawn from the training data.
  • trainer.n_gpus_per_node=6 and trainer.nnodes=1 mean there are 6 GPUs in total.
  • actor_rollout_ref.rollout.tensor_model_parallel_size=2 means every two GPUs form one actor (rollout) worker, which runs vLLM over its slice of the batch. In total there are 6/2 = 3 actor workers.

For load balancing we want every worker to handle the same number of sequences, so each worker is responsible for the rollout of 20 sequences, computed as follows.

$$\frac{\text{data.train\_batch\_size}}{\text{worker\_num}} = \frac{\text{data.train\_batch\_size} \times \text{tensor\_model\_parallel\_size}}{\text{trainer.n\_gpus\_per\_node} \times \text{trainer.nnodes}} = \frac{60 \times 2}{6 \times 1} = 20$$

  • actor_rollout_ref.rollout.n=12 means that for every prompt the actor samples 12 rollouts from the current state. The number of sequences each actor worker's vLLM engine has to generate is therefore
    $20 \times \text{actor\_rollout\_ref.rollout.n} = 20 \times 12 = 240$

Therefore, once all workers have finished their rollouts, there are $240 \times \text{worker\_num} = 240 \times 3 = 720$ sequences in total, which then enter the pipeline of verl/verl/trainer/ppo/ray_trainer.py (a minimal vLLM-level illustration of this fan-out follows below):
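
For intuition, here is what the n=12 fan-out looks like at the vLLM level in isolation. This is a standalone sketch using vLLM's public LLM/SamplingParams API; the model name is only a placeholder, and verl's vLLMRollout wrapper handles this (plus sharding and weight syncing) internally.

from vllm import LLM, SamplingParams

# each rollout worker holds 20 prompts and samples 12 completions per prompt
llm = LLM(model="Qwen/Qwen2.5-0.5B-Instruct")        # placeholder model name
sampling_params = SamplingParams(n=12, temperature=1.0, max_tokens=256)

prompts = ["What is 7 * 8?"] * 20                    # stand-in for this worker's 20 prompts
outputs = llm.generate(prompts, sampling_params)

num_sequences = sum(len(o.outputs) for o in outputs)  # 20 prompts * 12 samples each
print(num_sequences)  # 240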

  • Compute $\pi_{old}(\tau_{i,(t)} \mid \tau_{i,<t})$
# use actor_rollout_wg to compute the old-policy log_prob of every token in every rollout sample
with _timer('old_log_prob', timing_raw):
  old_log_prob = self.actor_rollout_wg.compute_log_prob(batch) 
  batch = batch.union(old_log_prob)
  • Compute $\pi_{ref}(\tau_{i,(t)} \mid \tau_{i,<t})$
# use ref_policy_wg to compute the ref-policy log_prob of every rollout sample
if self.use_reference_policy:
    # compute reference log_prob
	with _timer('ref', timing_raw):
	   ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
	   batch = batch.union(ref_log_prob)
  • Compute the advantage $A_i$
# use critic_wg to compute the value of every rollout sample.
# each sample's advantage is computed from its value.
# GRPO, however, drops the critic model: the reward a rollout sample receives is taken directly as its value!
if self.use_critic:
	# self.use_critic = False in GRPO
   with _timer('values', timing_raw):
        values = self.critic_wg.compute_values(batch) 
        batch = batch.union(values)


 with _timer('adv', timing_raw):
 	 # in PPO, a separate reward model (rm_wg) is used to score the samples
     if self.use_rm: 
     	# self.use_rm = False in GRPO
         reward_tensor = self.rm_wg.compute_rm_score(batch) 
         batch = batch.union(reward_tensor)

     # in GRPO, a rule-based reward is used directly.
     reward_tensor = self.reward_fn(batch)
     # from this we get the score of every token in every rollout sample
     batch.batch['token_level_scores'] = reward_tensor

     # compute rewards; apply a KL-divergence penalty (apply_kl_penalty) when applicable
     # this branch runs only when actor_rollout_ref.actor.use_kl_loss is False, i.e. the KL penalty is folded into the token-level reward instead of the actor loss
     if not self.config.actor_rollout_ref.actor.get('use_kl_loss', False):
         batch, kl_metrics = apply_kl_penalty(batch,
                                              kl_ctrl=self.kl_ctrl,
                                              kl_penalty=self.config.algorithm.kl_penalty)
         metrics.update(kl_metrics)
     else:
         batch.batch['token_level_rewards'] = batch.batch['token_level_scores']

     # compute the advantage of every rollout sample
     batch = compute_advantage(
     	   batch, # batch.batch['token_level_rewards'] holds the per-token rewards (with the KL penalty already applied when enabled)
           adv_estimator=self.config.algorithm.adv_estimator, # which advantage estimator to use (GAE for PPO; GRPO uses its group-relative estimator)
           gamma=self.config.algorithm.gamma, # discount factor (gamma): controls how quickly future rewards decay
           lam=self.config.algorithm.lam, # GAE lambda: trades bias against variance (higher values rely more on long-horizon estimates)
           num_repeat=self.config.actor_rollout_ref.rollout.n
         )
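
For reference, here is a minimal sketch of the branch above. The kl_coef value and the k1-style KL estimate (a plain log-ratio between the old policy and the reference policy) are illustrative assumptions, not verl's exact apply_kl_penalty implementation.

import torch

def token_level_rewards(scores: torch.Tensor,
                        old_log_prob: torch.Tensor,
                        ref_log_prob: torch.Tensor,
                        kl_coef: float = 0.001,
                        use_kl_loss: bool = False) -> torch.Tensor:
    if use_kl_loss:
        # KL is handled inside the actor loss instead -> the rewards are just the scores
        return scores
    kld = old_log_prob - ref_log_prob          # k1-style per-token KL estimate
    return scores - kl_coef * kld

# toy example: 1 sample, 3 tokens, rule-based score on the last token only
scores = torch.tensor([[0.0, 0.0, 1.0]])
old_lp = torch.tensor([[-1.2, -0.7, -0.3]])
ref_lp = torch.tensor([[-1.0, -0.8, -0.4]])
print(token_level_rewards(scores, old_lp, ref_lp))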

Once $\pi_{old}(\tau_{i,(t)} \mid \tau_{i,<t})$, $\pi_{ref}(\tau_{i,(t)} \mid \tau_{i,<t})$, and $A_i$ have all been computed, the trainer moves on to updating the individual models:

  • Critic model update (not needed in GRPO)
if self.use_critic:
	 with _timer('update_critic', timing_raw):
	   critic_output = self.critic_wg.update_critic(batch)
	   critic_output_metrics = reduce_metrics(critic_output.meta_info['metrics'])
	   metrics.update(critic_output_metrics)
  • Actor model update
# trainer.critic_warmup is usually set to 0.
# global_steps: the index of the current batch from train_dataloader.
# while self.global_steps < self.config.trainer.critic_warmup, the actor model is not updated.
if self.config.trainer.critic_warmup <= self.global_steps:
    # once the critic warmup is finished, update the actor model
    with _timer('update_actor', timing_raw):
        actor_output = self.actor_rollout_wg.update_actor(batch)
    actor_output_metrics = reduce_metrics(actor_output.meta_info['metrics'])
    # update the actor-related metrics
    metrics.update(actor_output_metrics)
  • Validation, when the configured conditions are met
# run validation via self._validate()
if self.val_reward_fn is not None and self.config.trainer.test_freq > 0 and \
       (is_last_step or  self.global_steps % self.config.trainer.test_freq == 0):
       with _timer('testing', timing_raw):
           val_metrics: dict = self._validate()
           if is_last_step:
               last_val_metrics = val_metrics
       metrics.update(val_metrics)
  • Checkpoint saving, when the configured conditions are met
# save a checkpoint via self._save_checkpoint()
 if self.config.trainer.save_freq > 0 and ( is_last_step or \
         self.global_steps % self.config.trainer.save_freq == 0):
     with _timer('save_checkpoint', timing_raw):
         self._save_checkpoint()
