元强化学习 PEARL 项目解读

PEARL 项目解读

因为自己的工作需要,需要跑一下元强化学习的 PEARL (Efficient Off-policy Meta-learning via Probabilistic Context Variables) 代码做一些对比实验。

最原始的代码(https://github.com/katerakelly/oyster)。但是这个代码下载下来,环境配置比较困难。

搜索 Github 上看到这个项目(链接:https://github.com/dongminlee94/meta-learning-for-everyone),里面介绍了 MAML、RL2 和 PEARL。阅读项目介绍,感觉比较科普,因此就以这个项目做解读。理解里面的代码细节,方便做复现~

1. 文件总概览

进入到 PEARL 文件夹下,输入指令 tree 即可大致了解里面的文件包含情况~

.
├── algorithm
│   ├── buffers.py
│   ├── meta_learner.py
│   ├── networks.py
│   ├── sac.py
│   └── sampler.py
├── configs
│   ├── dir_target_config.yaml
│   ├── experiment_config.yaml
│   └── vel_target_config.yaml
└── pearl_trainer.py

algorithm 文件夹下面是基本的组建:经验池 buffers.py,元学习外环框架 meta_learner.py,模型网络 networks.py,内环框架 sac.py,以及最后的批采样器 sampler.py

configs 下面是一些实验的配置文件:experiment_config.yaml 总的配置文件,dir_target_config.yaml 控制方向实验配置文件,vel_target_config.yaml 控制速度配置文件。

pearl_trainer.py 是整个训练脚本,运行时候直接运行这个文件即可~

2. 代码流程 pearl_trainer.py

导入 experiment_config.yaml 文件。

with open(os.path.join("configs", "experiment_config.yaml"), "r", encoding="utf-8") as file:
    experiment_config: Dict[str, Any] = yaml.load(file, Loader=yaml.FullLoader)

导入具体某个实验的配置数据,experiment_config.yaml 文件中:env_name: "dir",因此导入的是方向实验文件,也就是 dir_target_config.yaml

with open(
    os.path.join("configs", experiment_config["env_name"] + "_target_config.yaml"),
    "r",
    encoding="utf-8",
) as file:
    env_target_config: Dict[str, Any] = yaml.load(file, Loader=yaml.FullLoader)

下面是 dir_target_config.yaml 的具体细节。具体变量含义稍后在提及~

train_tasks: 2
test_tasks: 2
latent_dim: 5
hidden_dim: 300

pearl_params:
    num_iterations: 1000
    num_sample_tasks: 5
    num_init_samples: 2000
    num_prior_samples: 1000
    num_posterior_samples: 1000
    num_meta_grads: 1500
    meta_batch_size: 4
    batch_size: 256
    max_step: 200
    max_buffer_size: 1000000
    num_stop_conditions: 3
    stop_goal: 1900
    
sac_params:
    gamma: 0.99
    kl_lambda: 0.1
    batch_size: 256
    qf_lr: 0.0003
    encoder_lr: 0.0003
    policy_lr: 0.0003

env 实例化环境,里面初始化了元训练任务和元测试任务的目标。["cheetah-" + experiment_config["env_name"]] 里面承载的是一个字符串,用于索引环境的种类。后面 () 里面的 num_tasks 承接着训练任务和测试任务的数量,这里用加号说明是实例化两者的总和。

tasks 返回的是一个列表,get_all_task_idx() 输出的是刚才实例化的任务的编号。

env: HalfCheetahEnv = ENVS["cheetah-" + experiment_config["env_name"]](
    num_tasks=env_target_config["train_tasks"] + env_target_config["test_tasks"],
)
tasks: List[int] = env.get_all_task_idx()

读取配置文件的随机数信息,给必要的库设置随机数。

env.reset(seed=experiment_config["seed"])
np.random.seed(experiment_config["seed"])
torch.manual_seed(experiment_config["seed"])

根据实例化的环境 env,指出环境中的状态空间、动作空间和隐藏层维度。

隐藏层维度指的是生成上下文变量的神经网络的中间层神经元的数量,也就是 dir_target_config.yaml 文件的 hidden_dim 变量,数值是300~

observ_dim: int = env.observation_space.shape[0]
action_dim: int = env.action_space.shape[0]
hidden_dim: int = env_target_config["hidden_dim"]

根据配置文件指定显卡。

device: torch.device = (
        torch.device("cuda", index=experiment_config["gpu_index"])
        if torch.cuda.is_available()
        else torch.device("cpu"))

实例化内环智能体 SAC。实例化参数里面,根据配置信息读取数据。上下文变量的维度:latent_dim 是5,隐藏层神经元的维度:hidden_dim 是300。编码器的输入是状态转移信息,观测维度 + 动作维度 + 奖励值(奖励值维度是1)。==输出是两个上下文变量的维度,实际是一个均值变量,一个是方差变量,生成一个正态分布。==最后再从配置文件中传入SAC的其他参数~

agent = SAC(
        observ_dim=observ_dim,
        action_dim=action_dim,
        latent_dim=env_target_config["latent_dim"],
        hidden_dim=hidden_dim,
        encoder_input_dim=observ_dim + action_dim + 1,
        encoder_output_dim=env_target_config["latent_dim"] * 2,
        device=device,
        **env_target_config["sac_params"],
    )

实例化外环训练器 MetaLearner。输入的是元训练集 train_tasks 的任务和元测试集 test_tasks 的任务。还有就是保存和载入的模型文件的断点这些~

meta_learner = MetaLearner(
        env=env,
        env_name=experiment_config["env_name"],
        agent=agent,
        observ_dim=observ_dim,
        action_dim=action_dim,
        train_tasks=tasks[: env_target_config["train_tasks"]],
        test_tasks=tasks[-env_target_config["test_tasks"] :],
        save_exp_name=experiment_config["save_exp_name"],
        save_file_name=experiment_config["save_file_name"],
        load_exp_name=experiment_config["load_exp_name"],
        load_file_name=experiment_config["load_file_name"],
        load_ckpt_num=experiment_config["load_ckpt_num"],
        device=device,
        **env_target_config["pearl_params"],
    )

最后开始进行元训练~

meta_learner.meta_train()

3. 代码流程 meta_learner.meta_train()

元训练的主要代码就在这里了。

total_start_timestart_time 记录的是总的训练时间和每次迭代开始的时间。

self.num_iterations 就是配置文件的 num_iterations ,含义是元训练过程的迭代次数,数值是1000。

条件块 if iteration == 0 :代码在第0次迭代搜集用于训练和验证的状态转移数据,代码在所有的训练集任务上做了数据的收集。用循环变量 index 代表训练集任务 self.train_tasks 的下标,用循环变量 index 指引每个任务,对每个任务做初始化操作,然后开始搜集状态转移数据。self.train_tasks 表示训练集内部的总任务数量,对于方向实验就只有2个(两个反方向),对于速度实验有300个。

self.collect_train_data() 的作用是为每个任务收集若干完整的轨迹作为样本,加入到强化学习经验池子编码器经验池子当中,这里不进行后验推断,这里只是纯粹的搜集。

搜集的方法 self.collect_train_data 在下一节~

代码首先遍历 self.num_sample_tasks 次,从训练任务集合中抽取一个任务环境,重置并清空里面的编码器经验池子。如果 self.num_prior_samples 大于0,则用先验分布(也就是标准正态分布)抽样得到的隐藏层变量然后获得一批 self.num_prior_samples 大小的数据,放入经验池子,此时不进行后验推断。随后,如果 self.num_posterior_samples 大于0,用后验分布抽样得到的隐藏层变量然后获得一批大小为 self.num_posterior_samples 的数据,不放入经验池子,此时进行后验推断。

之后开始进行元梯度更新。在元梯度迭代次数 self.num_meta_grads 内,在 self.train_tasks 内采样self.meta_batch_size 大小的下标,用于指代这么多的任务

我的理解是:比如说,我的 self.train_tasks 是0~299,那么我就抽样 self.meta_batch_size 大小体量的下标作为numpy数组。

清除这些任务的隐藏层变量z,赋值标准正态分布的采样,随后采样上下文 context_batch 和状态转移 transition_batch 数据。

随后执行 self.agent.train_model() 方法进行元训练,然后再调用 self.meta_test() 做元测试。对于早结束的情况做了一些异常的提示。

元训练过程 self.agent.train_model()7. 代码流程 def train_model()

元测试过程 self.meta_test()8. 代码流程 def meta_test()

def meta_train(self) -> None:
    total_start_time: float = time.time()
    
    for iteration in range(self.num_iterations):
        start_time: float = time.time()

        if iteration == 0:
            print("Collecting initial pool of data for train and eval")
            for index in tqdm(self.train_tasks):
                self.env.reset_task(index)
                self.collect_train_data(
                    task_index=index,
                    max_samples=self.num_init_samples,
                    update_posterior=False,
                    add_to_enc_buffer=True,
                )

        print(f"\n=============== Iteration {iteration} ===============")
        for i in range(self.num_sample_tasks):
            index = np.random.randint(len(self.train_tasks))
            self.env.reset_task(index)
            self.encoder_replay_buffer.task_buffers[index].clear()

            if self.num_prior_samples > 0:
                print(f"[{i + 1}/{self.num_sample_tasks}] collecting samples with prior")
                self.collect_train_data(
                    task_index=index,
                    max_samples=self.num_prior_samples,
                    update_posterior=False,
                    add_to_enc_buffer=True,
                )

             if self.num_posterior_samples > 0:
                print(f"[{i + 1}/{self.num_sample_tasks}] collecting samples with posterior")
                self.collect_train_data(
                    task_index=index,
                    max_samples=self.num_posterior_samples,
                    update_posterior=True,
                    add_to_enc_buffer=False,
                )

             print(f"Start meta-gradient updates of iteration {iteration}")
             for i in range(self.num_meta_grads):
                 indices: np.ndarray = np.random.choice(self.train_tasks, self.meta_batch_size)

                self.agent.encoder.clear_z(num_tasks=len(indices))
                context_batch: torch.Tensor = self.sample_context(indices)
                transition_batch: List[torch.Tensor] = self.sample_transition(indices)

                log_values: Dict[str, float] = self.agent.train_model(
                    meta_batch_size=self.meta_batch_size,
                    batch_size=self.batch_size,
                    context_batch=context_batch,
                    transition_batch=transition_batch,
                )

                self.agent.encoder.task_z.detach()
                self.meta_test(iteration, total_start_time, start_time, log_values)

                if self.is_early_stopping:
                    print(
                        f"\n==================================================\n"
                        f"The last {self.num_stop_conditions} meta-testing results are {self.dq}.\n"
                        f"And early stopping condition is {self.is_early_stopping}.\n"
                        f"Therefore, meta-training is terminated.",
                    )
                    break

4. 代码流程 def collect_train_data()

这个函数是用来搜集状态转移数据,正如上一节提到的。

输入变量

  • task_index :任务索引下标,指带具体任务的编号;
  • max_samples :采样的最大状态转移数量;
  • update_posterior:是否进行后验推断;
  • add_to_enc_buffer:是否放进经验池子;

代码里 self.agent.encoder.clear_z() 先把隐藏层z变量赋值标准正态分布的抽样信息。

编码器类在下一节 5. 编码器类 class MLPEncoder(FlattenMLP) 介绍~

self.agent.policy.is_deterministic = False 表示这部分代码的决策,是通过计算得到分布并基于分布采样得到的,而不是直接输出固定确切值。

采样器用于采样 max_samples 数量的若干条完整的轨迹。

采样器类在 6. 采样器类 class Sampler 介绍~

获得完整的轨迹后,将数据存储在第 task_index 任务的经验池子当中。

如果 add_to_enc_buffer 标记为真,那么就将轨迹放置到编码器经验池子 encoder_replay_buffer 当中;

如果 update_posterior 标记为真,那么就从编码器经验池子 encoder_replay_buffer 当中采样数据进行后验推断。

def collect_train_data(
    self,
    task_index: int,
    max_samples: int,
    update_posterior: bool,
    add_to_enc_buffer: bool,
) -> None:
    self.agent.encoder.clear_z()
    self.agent.policy.is_deterministic = False

    cur_samples = 0
    while cur_samples < max_samples:
        trajs, num_samples = self.sampler.obtain_samples(
            max_samples=max_samples - cur_samples,
            update_posterior=update_posterior,
            accum_context=False,
        )
        cur_samples += num_samples

        self.rl_replay_buffer.add_trajs(task_index, trajs)
        if add_to_enc_buffer:
            self.encoder_replay_buffer.add_trajs(task_index, trajs)

            if update_posterior:
                context_batch = self.sample_context(np.array([task_index]))
                self.agent.encoder.infer_posterior(context_batch)

因为 self.sample_context 这个方法和 self.collect_train_data 都在同一个类下,因此就在这里展开~

代码先初始化了批上下文列表:context_batch,然后对抽样得到的任务下标,依次对下标对应的任务抽取 self.batch_size 大小的历史,然后用 np.concatenate() 方法结合起来,最后返回~

def sample_context(self, indices: np.ndarray) -> torch.Tensor:
    context_batch = []
    for index in indices:
        batch = self.encoder_replay_buffer.sample_batch(task=index, batch_size=self.batch_size)
        context_batch.append(
            np.concatenate((batch["cur_obs"], batch["actions"], batch["rewards"]), axis=-1),
        )
        return torch.Tensor(context_batch).to(self.device)

接下来介绍采样状态转移的方法~

实际就是随机采样方法,在此基础上做一些分类归纳,并输出出来~

def sample_transition(self, indices: np.ndarray) -> List[torch.Tensor]:
    cur_obs, actions, rewards, next_obs, dones = [], [], [], [], []
    for index in indices:
        batch = self.rl_replay_buffer.sample_batch(task=index, batch_size=self.batch_size)
        cur_obs.append(batch["cur_obs"])
        actions.append(batch["actions"])
        rewards.append(batch["rewards"])
        next_obs.append(batch["next_obs"])
        dones.append(batch["dones"])

        cur_obs = torch.Tensor(cur_obs).view(len(indices), self.batch_size, -1).to(self.device)
        actions = torch.Tensor(actions).view(len(indices), self.batch_size, -1).to(self.device)
        rewards = torch.Tensor(rewards).view(len(indices), self.batch_size, -1).to(self.device)
        next_obs = torch.Tensor(next_obs).view(len(indices), self.batch_size, -1).to(self.device)
        dones = torch.Tensor(dones).view(len(indices), self.batch_size, -1).to(self.device)
        return [cur_obs, actions, rewards, next_obs, dones]

5. 编码器类 class MLPEncoder(FlattenMLP)

这个是编码器类,在内环训练过程中,需要将状态转移信息用编码器得到隐藏层变量z,用的就是这个类的实例~

初始化输入信息维度 input_dim、输出信息的维度 output_dim、隐藏层变量维度 latent_dim 和计算隐藏层变量的中间层神经元数量 hidden_dim,最后再声明一下显卡配置 device

这个类继承了 FlattenMLP 这个类,在计算时候会调用 FlattenMLP 这个类的 forward 方法,实际上就是加了一层 torch.cat() 操作,把几个张量合并在一起。而 FlattenMLP 这个类继承了 MLP 这个类。 MLP 这个类定义了神经网路,通过python的继承使用语法,网络的输入层维度就是初始化输入信息维度 input_dim、输出层维度就是输出信息的维度 output_dim,中间层的神经元数就是中间层维度 hidden_dim,默认3个中间层。使用 torch.functional.F.relu 激活函数。 MLP 这个类还定义了基本的前向计算 forward(),与我们一般的神经网络差不多。

回到 MLPEncoder 这个类。初始化了隐藏层变量的均值 self.z_mean 、方差 self.z_varself.task_z 均为 None。执行了一下 self.clear_z() 操作。

接下来细说 self.clear_z() 操作:先初始化先验分布为标准正态分布,标准正态分布的维度 = 上下文变量的维度 × 任务数(默认是1),为每个任务抽样上下文变量做准备。

接下来执行 self.sample_z() 操作:首先采用 torch.unbind() 方法对均值张量和方差张量对第0维度做了切片,然后用 zip 捆绑起来,并给 meanvar 赋值,也就是分别为每个任务附加先验的标准正态分布。依次构建标准正态分布并填充于一个列表 dists 。再对列表里面的每个任务的标准正态分布重采样上下文变量,并堆叠起来成为一个张量。

class MLPEncoder(FlattenMLP):
    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        latent_dim: int,
        hidden_dim: int,
        device: torch.device,
    ) -> None:
        super().__init__(input_dim=input_dim, output_dim=output_dim, hidden_dim=hidden_dim)

        self.output_dim = output_dim
        self.latent_dim = latent_dim
        self.device = device

        self.z_mean = None
        self.z_var = None
        self.task_z = None
        self.clear_z()

    def clear_z(self, num_tasks: int = 1) -> None:
        self.z_mean = torch.zeros(num_tasks, self.latent_dim).to(self.device)
        self.z_var = torch.ones(num_tasks, self.latent_dim).to(self.device)

        self.sample_z()

        self.context = None

    def sample_z(self) -> None:
        dists = []
        for mean, var in zip(torch.unbind(self.z_mean), torch.unbind(self.z_var)):
            dist = torch.distributions.Normal(mean, torch.sqrt(var))
            dists.append(dist)
        sampled_z = [dist.rsample() for dist in dists]
        self.task_z = torch.stack(sampled_z).to(self.device)

介绍 infer_posterior() 方法。顾名思义就是进行后验推断。这个方法首先进行了 self.forward(context) 函数,这个函数调用的是父类 FlattenMLP 的方法,也就是根据上下文计算出一组隐藏层变量 params 并转成 “特定行 x 输出维度10”的形式。将一组隐藏层变量 params 拆解出分布均值 z_mean 和分布方差 z_var(之前提到神经网络的输出是隐藏层变量维度的2倍,用在这里了~),并执行 self.product_of_gaussians() 计算出由原本多个均值和方差融合的整个分布的参数 z_params。最后再拆解参数 z_params 得到一组全是均值 self.z_mean 和全是方差 self.z_var 的变量。有了均值 self.z_mean 和方差 self.z_var 变量即可生成分布,因此最后代码执行了 self.sample_z() 抽样隐藏层变量的分布得到隐藏层变量。

def infer_posterior(self, context: torch.Tensor) -> None:
    params = self.forward(context)
    params = params.view(context.size(0), -1, self.output_dim).to(self.device)

    z_mean = torch.unbind(params[..., : self.latent_dim])
    z_var = torch.unbind(F.softplus(params[..., self.latent_dim :]))
    z_params = [self.product_of_gaussians(mu, var) for mu, var in zip(z_mean, z_var)]

    self.z_mean = torch.stack([z_param[0] for z_param in z_params]).to(self.device)
    self.z_var = torch.stack([z_param[1] for z_param in z_params]).to(self.device)
    self.sample_z()

接下来介绍 def product_of_gaussians() 方法。这个函数的意义是将两个高斯分布通过乘积的方式合成一个高斯分布。具体的过程如链接所示:https://blog.csdn.net/qq_41035283/article/details/121015712 。

看完上面的文章,下面代码也就很清晰了。首先对方差变量 vartorch.clamp 方法确保它是正值。然后求出合成的方差 pog_var 。最后合成的均值 pog_mean 就是合成的方差 pog_var 再乘以原本的均值~

@classmethod
def product_of_gaussians(
    cls,
    mean: torch.Tensor,
    var: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    var = torch.clamp(var, min=1e-7)
    pog_var = 1.0 / torch.sum(torch.reciprocal(var), dim=0)
    pog_mean = pog_var * torch.sum(mean / var, dim=0)
    return pog_mean, pog_var

最后介绍 def compute_kl_div() 方法。顾名思义就是计算 KL 散度。

KL散度的详细计算在这里~链接:https://blog.csdn.net/weixin_50752408/article/details/129562144

代码构造了隐藏层变量相同维度的标准正态分布 prior。通过一个列表 posteriors 记录许多后验分布。采用 torch.unbind() 方法对均值张量 self.z_mean 和方差张量 self.z_var 解构,然后用 zip 捆绑起来,并给 meanvar 赋值,也就是分别为每个任务附加先验的后验正态分布。然后依次对后验正态分布先验标准正态分布计算KL散度数值,最后求和并返回。

def compute_kl_div(self) -> torch.Tensor:
    prior = torch.distributions.Normal(
        torch.zeros(self.latent_dim).to(self.device),
        torch.ones(self.latent_dim).to(self.device),
    )

    posteriors = []
    for mean, var in zip(torch.unbind(self.z_mean), torch.unbind(self.z_var)):
        dist = torch.distributions.Normal(mean, torch.sqrt(var))
        posteriors.append(dist)

        kl_div = [torch.distributions.kl.kl_divergence(posterior, prior) for posterior in posteriors]
        kl_div = torch.stack(kl_div).sum().to(self.device)
        return kl_div

6. 采样器类 class Sampler

这个类是用来采样轨迹数据的。

首先来看 def rollout 方法,这个方法的作用是获得一条智能体在环境中运动的数据(only one)。初始化了一系列列表,用于记录智能体与环境交互的各个数据。obs 获得环境的状态信息,done 是是否完成的标志,默认是 False。智能体根据状态信息,获得动作信息;然后作用在环境中获得下一个状态、奖励和是否完成的标志。如果 accum_contextTrue,那么就实时更新上下文信息,否则就没有。然后将状态转移数据记录到一个个列表中,更新 obs 进入下一状态。最后返回的是一个字典,记录了这些信息。

接下来看 def update_context 方法。先把状态信息 obs 、动作信息 action 和奖励 reward 改成高维度的浮点型GPU张量, 然后对这些信息做连接得到 transition 上下文变量。如果编码器的上下文信息 self.agent.encoder.context 是空 None 的话,那么得到 transition 上下文变量赋值给self.agent.encoder.context ;否则就是在原来 self.agent.encoder.context 基础上再增加当前状态转移的上下文。

最后看 def obtain_samples 方法。输入的 max_samples 表示最大采样数量、update_posterior 表示更新后验标记和 accum_context 表示是否累积上下文。这个方法的作用是:获得大小为 max_samples 的状态转移数据,这些数据由若干完整的轨迹变量 trajs 组成。然后还采样了隐藏层变量z,最后进行输出。

class Sampler:
    def __init__(
        self,
        env: HalfCheetahEnv,
        agent: SAC,
        max_step: int,
        device: torch.device,
    ) -> None:

        self.env = env
        self.agent = agent
        self.max_step = max_step
        self.device = device

    def obtain_samples(
        self,
        max_samples: int,
        update_posterior: bool,
        accum_context: bool = True,
    ) -> Tuple[List[Dict[str, np.ndarray]], int]:
        trajs = []
        cur_samples = 0

        while cur_samples < max_samples:
            traj = self.rollout(accum_context=accum_context)
            trajs.append(traj)
            cur_samples += len(traj["cur_obs"])
            self.agent.encoder.sample_z()

            if update_posterior:
                break
        return trajs, cur_samples

    def rollout(self, accum_context: bool = True) -> Dict[str, np.ndarray]:
        _cur_obs = []
        _actions = []
        _rewards = []
        _next_obs = []
        _dones = []
        _infos = []

        obs = self.env.reset()
        done = False
        cur_step = 0

        while not (done or cur_step == self.max_step):
            action = self.agent.get_action(obs)
            next_obs, reward, done, info = self.env.step(action)

            if accum_context:
                self.update_context(obs=obs, action=action, reward=np.array([reward]))

            _cur_obs.append(obs)
            _actions.append(action)
            _rewards.append(reward)
            _next_obs.append(next_obs)
            _dones.append(done)
            _infos.append(info["run_cost"])

            cur_step += 1
            obs = next_obs
        return dict(
            cur_obs=np.array(_cur_obs),
            actions=np.array(_actions),
            rewards=np.array(_rewards).reshape(-1, 1),
            next_obs=np.array(_next_obs),
            dones=np.array(_dones).reshape(-1, 1),
            infos=np.array(_infos),
        )

    def update_context(self, obs: np.ndarray, action: np.ndarray, reward: np.ndarray) -> None:
        obs = obs.reshape((1, 1, *obs.shape))
        action = action.reshape((1, 1, *action.shape))
        reward = reward.reshape((1, 1, *reward.shape))

        obs = torch.from_numpy(obs).float().to(self.device)
        action = torch.from_numpy(action).float().to(self.device)
        reward = torch.from_numpy(reward).float().to(self.device)
        transition = torch.cat([obs, action, reward], dim=-1).to(self.device)

        if self.agent.encoder.context is None:
            self.agent.encoder.context = transition
        else:
            self.agent.encoder.context = torch.cat([self.agent.encoder.context, transition], dim=1).to(
                self.device,
            )
           

7. 代码流程 def train_model()

这个代码就是承接 3. 代码流程 meta_learner.meta_train()train_model() 方法,比较长的方法了~

首先将状态转移变量 transition_batch 拆解出来,得到强化学习的组件信息。然后对这些信息按照一定规格整理以便后续运算。随后进行了后验推断 self.encoder.infer_posterior(context_batch) 得到后验隐藏层变量,由于存在多个任务因此用 torch.cat() 整合成一个张量。

执行 self.encoder.compute_kl_div() 计算所有分布和标准正态分布之间的KL散度,然后再乘以 self.kl_lambda 得到编码器损失 encoder_loss 。最后通过优化器的 zero_grad()backward() 方法进行优化预备。

def train_model(
    self,
    meta_batch_size: int,
    batch_size: int,
    context_batch: torch.Tensor,
    transition_batch: List[torch.Tensor],
) -> Dict[str, float]:
    cur_obs, actions, rewards, next_obs, dones = transition_batch

    cur_obs = cur_obs.view(meta_batch_size * batch_size, -1)
    actions = actions.view(meta_batch_size * batch_size, -1)
    rewards = rewards.view(meta_batch_size * batch_size, -1)
    next_obs = next_obs.view(meta_batch_size * batch_size, -1)
    dones = dones.view(meta_batch_size * batch_size, -1)

    self.encoder.infer_posterior(context_batch)
    task_z = self.encoder.task_z

    task_z = [z.repeat(batch_size, 1) for z in task_z]
    task_z = torch.cat(task_z, dim=0)

    kl_div = self.encoder.compute_kl_div()
    encoder_loss = self.kl_lambda * kl_div
    self.encoder_optimizer.zero_grad()
    encoder_loss.backward(retain_graph=True)

做完预备工作后,代码接下来对Critic网络做了优化。因为内环结构本身是SAC风格的,因此很多部分有相似之处~输入下一观测信息 next_obs 和隐藏层变量 task_z 进入策略中,输出动作 next_policy 和动作的对数概率分布 next_log_policymin_target_q 取动作状态对 (next_obs, next_policy, task_z) 的评估Q值的最小值。计算当前状态的状态价值 target_v ,这与经典SAC代码是一样的。最后计算目标Q值 target_q 。做后通过 loss -> zero_grad() -> backward() -> step() 经典过程优化两个Q网络 self.qf1()self.qf2()

到最后对编码器和Q值网络都做了参数更新~

with torch.no_grad():
    next_inputs = torch.cat([next_obs, task_z], dim=-1)
    next_policy, next_log_policy = self.policy(next_inputs)
    min_target_q = torch.min(
        self.target_qf1(next_obs, next_policy, task_z),
        self.target_qf2(next_obs, next_policy, task_z),
    )
    target_v = min_target_q - self.alpha * next_log_policy
    target_q = rewards + self.gamma * (1 - dones) * target_v

    pred_q1 = self.qf1(cur_obs, actions, task_z)
    pred_q2 = self.qf2(cur_obs, actions, task_z)
    qf1_loss = F.mse_loss(pred_q1, target_q)
    qf2_loss = F.mse_loss(pred_q2, target_q)
    qf_loss = qf1_loss + qf2_loss
    self.qf_optimizer.zero_grad()
    qf_loss.backward()

    self.qf_optimizer.step()
    self.encoder_optimizer.step()

最后一部分主要是对策略 policy 和熵超参数 alpha 做损失。与SAC都是类似的,只是在原来的基础之上增加了后验分布采样 task_z.detach() 。同样,最后都要执行软更新。这个函数的最后部分输出了更新后的策略;Q值网络、编码器网络和熵超参数的损失。以及后验分布的均值和方差。

inputs = torch.cat([cur_obs, task_z.detach()], dim=-1)
policy, log_policy = self.policy(inputs)
min_q = torch.min(
    self.qf1(cur_obs, policy, task_z.detach()),
    self.qf2(cur_obs, policy, task_z.detach()),
)
policy_loss = (self.alpha * log_policy - min_q).mean()
self.policy_optimizer.zero_grad()
policy_loss.backward()
self.policy_optimizer.step()

alpha_loss = -(self.log_alpha * (log_policy + self.target_entropy).detach()).mean()
self.alpha_optimizer.zero_grad()
alpha_loss.backward()
self.alpha_optimizer.step()
self.alpha = self.log_alpha.exp()

self.soft_target_update(self.qf1, self.target_qf1)
self.soft_target_update(self.qf2, self.target_qf2)
return dict(
    policy_loss=policy_loss.item(),
    qf1_loss=qf1_loss.item(),
    qf2_loss=qf2_loss.item(),
    encoder_loss=encoder_loss.item(),
    alpha_loss=alpha_loss.item(),
    alpha=self.alpha.item(),
    z_mean=self.encoder.z_mean.detach().cpu().numpy().mean().item(),
    z_var=self.encoder.z_var.detach().cpu().numpy().mean().item(),
)

8. 代码流程 def meta_test()

这部分承接了 3. 代码流程 meta_learner.meta_train() ,对已经元训练好的模型进行元测试~

def meta_test(
    self,
    iteration: int,
    total_start_time: float,
    start_time: float,
    log_values: Dict[str, float],
) -> None:

    test_results = {}
    return_before_infer = 0
    return_after_infer = 0
    run_cost_before_infer = np.zeros(self.max_step)
    run_cost_after_infer = np.zeros(self.max_step)

对于元测试集里面的任务,依次获得下标 index 。对 index 索引的任务重置,并获得最大步骤 self.max_step 两倍数量左右的完整轨迹,并设置 update_posteriorTrue 表示进行后验推断。获得推断前和推断后的累计奖励:return_before_inferreturn_after_infer

for index in self.test_tasks:
    self.env.reset_task(index)
    trajs: List[List[Dict[str, np.ndarray]]] = self.collect_test_data(
    max_samples=self.max_step * 2,
    update_posterior=True,
    )

    return_before_infer += np.sum(trajs[0][0]["rewards"])
    return_after_infer += np.sum(trajs[1][0]["rewards"])

承接上一部分的循环体。如果环境名称是速度 "vel" 的话,那么对获得的完整轨迹计算累计的运行损失。接着再对累计奖励做平均值回报。

如果环境名称是速度 "vel" 的话,记录以下信息。

itemrun_cost_before_inferrun_cost_after_infersum_run_cost_before_infersum_run_cost_after_infer
含义推断前的运行损失(info)推断后的运行损失(info)推断前的运行损失(info)的累计总和推断后的运行损失(info)的累计总和
itempolicy_lossqf1_loss,qf2_lossencoder_lossalpha_loss
含义策略损失两个Q值网路的损失编码器的损失熵系数变化的损失
itemz_meanz_vartotal_timetime_per_iter
含义隐藏层变量的均值隐藏层变量的方差总体运行时间每次迭代的运行时间
    if self.env_name == "vel":
        for i in range(self.max_step):
            run_cost_before_infer[i] += trajs[0][0]["infos"][i]
            run_cost_after_infer[i] += trajs[1][0]["infos"][i]

test_results["return_before_infer"] = return_before_infer / len(self.test_tasks)
test_results["return_after_infer"] = return_after_infer / len(self.test_tasks)

if self.env_name == "vel":
    test_results["run_cost_before_infer"] = run_cost_before_infer / len(self.test_tasks)
    test_results["run_cost_after_infer"] = run_cost_after_infer / len(self.test_tasks)
    test_results["sum_run_cost_before_infer"] = sum(
        abs(run_cost_before_infer / len(self.test_tasks)),
    )
    test_results["sum_run_cost_after_infer"] = sum(
        abs(run_cost_after_infer / len(self.test_tasks)),
    )
    test_results["policy_loss"] = log_values["policy_loss"]
    test_results["qf1_loss"] = log_values["qf1_loss"]
    test_results["qf2_loss"] = log_values["qf2_loss"]
    test_results["encoder_loss"] = log_values["encoder_loss"]
    test_results["alpha_loss"] = log_values["alpha_loss"]
    test_results["alpha"] = log_values["alpha"]
    test_results["z_mean"] = log_values["z_mean"]
    test_results["z_var"] = log_values["z_var"]
    test_results["total_time"] = time.time() - total_start_time
    test_results["time_per_iter"] = time.time() - start_time

    self.visualize_within_tensorboard(test_results, iteration)

    if self.env_name == "dir":
        self.dq.append(test_results["return_after_infer"])
        if all(list(map((lambda x: x >= self.stop_goal), self.dq))):
        self.is_early_stopping = True
    elif self.env_name == "vel":
        self.dq.append(test_results["sum_run_cost_after_infer"])
        if all(list(map((lambda x: x <= self.stop_goal), self.dq))):
        self.is_early_stopping = True

最后一部分跟上面的一样,提前退出元测试阶段就报出断点~

if self.is_early_stopping:
    ckpt_path = os.path.join(self.result_path, "checkpoint_" + str(iteration) + ".pt")
    torch.save(
        {
            "policy": self.agent.policy.state_dict(),
            "encoder": self.agent.encoder.state_dict(),
            "qf1": self.agent.qf1.state_dict(),
            "qf2": self.agent.qf2.state_dict(),
            "target_qf1": self.agent.target_qf1.state_dict(),
            "target_qf2": self.agent.target_qf2.state_dict(),
            "log_alpha": self.agent.log_alpha,
            "alpha": self.agent.alpha,
            "rl_replay_buffer": self.rl_replay_buffer,
            "encoder_replay_buffer": self.encoder_replay_buffer,
        },
        ckpt_path,
    )

项目分析

  1. 代码上计算KL的散度,是计算得到的后验分布和标准正态分布的KL散度,而不是这一次迭代的分布和上一次迭代的分布的KL散度。代码将KL散度乘以一个具体常数值作为策略损失。
  2. 作者在元训练时用的是随机抽样,也就是具体抽样的是状态转移历史,而不是轨迹(有顺序的状态转移)
  3. 不理解的问题是:为什么在元训练期间要经历两次相似的采样?翻译了一下代码的韩文,他的意思是强化学习要在隐藏层变量由标准正态分布采样隐藏层变量由后验分布采样的共同轨迹(融合成经验池)上进行学习~所以现在标准正态分布采样下搜集数据,然后在后验分布上采样得到隐藏层变量再搜集数据。

进展节点

2023-03-22-23-00:第一次编辑(1-6小节)

2023-03-23-15-27:第二次编辑(7-8小节)

2023-03-27-15-12:第三次编辑(修改表述、高亮内容和项目分析)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Ctrl+Alt+L

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值