【Hierarchical Reinforcement Learning】A Simple Option-Critic Example on CartPole-v1

Note: the intra-option (inner) policy is trained here with plain policy gradient (PG), so training is quite unstable. If needed, tune the hyperparameters yourself, or swap the inner-policy update for a more stable algorithm such as PPO (a rough sketch is given after the full listing).

import gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.categorical import Categorical

class NN(nn.Module):

    def __init__(self, state_size, action_size, hidden_size, num_options):
        super().__init__()

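        # Intra-option policies: one actor per option, mapping a state to action probabilities.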
        self.actors = nn.ModuleList([
            nn.Sequential(
                nn.Linear(state_size, hidden_size),
                nn.ReLU(),
                nn.Linear(hidden_size, action_size),
                nn.Softmax(dim=-1)
            ) for _ in range(num_options)
        ])

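        # Termination heads: one per option, giving the probability of ending that option in a state.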
        self.terminations = nn.ModuleList([
            nn.Sequential(
                nn.Linear(state_size, hidden_size),
                nn.ReLU(),
                nn.Linear(hidden_size, 1),
                nn.Sigmoid()
            ) for _ in range(num_options)
        ])

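        # Option critics: one per option, estimating Q(s, omega, a) for every action.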
        self.critics = nn.ModuleList([
            nn.Sequential(
                nn.Linear(state_size, hidden_size),
                nn.ReLU(),
                nn.Linear(hidden_size, action_size),
            ) for _ in range(num_options)
        ])

    def select_option(self, state, epsilon):
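        # Epsilon-greedy over options: with probability 1 - epsilon pick the option whose
        # expected value sum_a pi(a|s, omega) * Q(s, omega, a) is highest, otherwise pick
        # a random option.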
        # print("change option")
        if np.random.rand() >= epsilon:

            max_value = - np.inf
            option_id = -1

            for i, (a, c) in enumerate(zip(self.actors, self.critics)):

                q = c(state)
                p = a(state)

                v = (q * p).sum(-1).item()

                if v >= max_value:
                    option_id = i
                    max_value = v

        else:
            option_id = np.random.randint(0, len(self.actors))

        return self.actors[option_id], self.terminations[option_id], option_id

if __name__ == '__main__':

    np.random.seed(0)
    torch.manual_seed(0)

    episodes = 5000

    epsilon = 1.0
    discount = 0.9
    epsilon_decay = 0.995
    epsilon_min = 0.05

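    # update interval: run one gradient update every training_epochs episodes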
    training_epochs = 1

    env = gym.make('CartPole-v1')

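    # CartPole-v1: 4-dim observation, 2 discrete actions; 6 options with hidden size 128.
    # Note: the next line rebinds the name `nn` and shadows the torch.nn alias; it still works
    # because the NN class above is already defined, but a different variable name would be cleaner.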
    nn = NN(4, 2, 128, 6)
    # nn = torch.load("NN.pt")
    optimizer = torch.optim.Adam(nn.parameters(), lr=1e-2)

    max_score = 0.0

    trajectory = []

    for e in range(1, episodes + 1):

        # clear the buffer at the start of each collection window (not on the training
        # episode itself, so no collected episodes are dropped when training_epochs > 1)
        if (e - 1) % training_epochs == 0:

            trajectory = []

        score = 0.0

        state, _ = env.reset()

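        # Pick the initial option for this episode (epsilon-greedy over options).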
        option = nn.select_option(torch.tensor(state), epsilon)

        while True:

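            # Sample an action from the current option's intra-option policy.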
            policy = option[0](torch.tensor(state))
            action = Categorical(policy).sample()

            next_state, reward, terminated, truncated, _ = env.step(action.item())
            done = terminated or truncated

            score += reward

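            # Probability that the current option terminates in the next state.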
            beta = option[1](torch.tensor(next_state)).item()

            # The current option terminates at the next state with probability beta.
            option_terminal = np.random.rand() < beta

            trajectory.append(
                (state, action, reward, next_state, done, option[2], beta, option_terminal)
            )

            if option_terminal:
                option = nn.select_option(torch.tensor(next_state), epsilon)

            state = next_state

            if done: break

        # start training
        if e % training_epochs == 0:
            optimizer.zero_grad()

            q_targets = []
            option_states = []
            option_advs = []
            option_next_states = []

            for state, action, reward, next_state, done, option_id, beta, option_terminal in trajectory:

                next_state_t = torch.tensor(next_state)

                # Expected value of every option at the next state:
                # V(s', omega) = sum_a pi(a | s', omega) * Q(s', omega, a)
                option_values = [
                    (nn.critics[i](next_state_t) * nn.actors[i](next_state_t)).sum(-1).item()
                    for i in range(len(nn.critics))
                ]

                # Option-critic TD target:
                # q = r + gamma * [(1 - beta) * V(s', omega) + beta * max_omega' V(s', omega')]
                q = reward + (1 - done) * discount * (
                    (1 - beta) * option_values[option_id] +
                    beta * max(option_values)
                )

                # Regress only the taken action's Q value towards the TD target.
                q_target = nn.critics[option_id](torch.tensor(state)).detach().numpy()
                q_target[action] = q

                q_targets.append(q_target)
                option_states.append(state)

                # Termination advantage: how much worse the current option is than the
                # best available option at the next state (non-positive by construction).
                option_adv = option_values[option_id] - max(option_values)

                option_advs.append(option_adv)
                option_next_states.append(next_state)

                if option_terminal:

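                    # The current option's segment just ended: turn the transitions gathered
                    # for it into tensors and compute its losses.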
                    option_states = torch.tensor(np.array(option_states))
                    q_targets = torch.tensor(np.array(q_targets))
                    option_advs = torch.tensor(np.array(option_advs)).view(-1, 1)
                    option_next_states = torch.tensor(np.array(option_next_states))

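                    # Three per-option losses: TD regression for the critic, a policy-gradient
                    # loss for the intra-option policy, and the termination loss beta(s') * advantage.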
                    option_critic_loss = F.mse_loss(
                        nn.critics[option_id](option_states),
                        q_targets
                    )

                    actor_advs = q_targets - nn.critics[option_id](option_states).detach()

                    option_actor_loss = - (torch.log(nn.actors[option_id](option_states)) * actor_advs).mean()

                    option_terminal_loss = (nn.terminations[option_id](option_next_states) * option_advs).mean()

                    option_critic_loss.backward()
                    option_actor_loss.backward()
                    option_terminal_loss.backward()

                    q_targets = []
                    option_states = []
                    option_advs = []
                    option_next_states = []

            optimizer.step()

        if epsilon > epsilon_min:

            epsilon *= epsilon_decay

        if score > max_score:

            max_score = score

            torch.save(nn, 'NN.pt')

        print("Episode: {}/{}, Epsilon: {}, Score: {}, Max score: {}".format(e, episodes, epsilon, score, max_score))
