Note: the inner (intra-option) policy is trained with plain vanilla policy gradient (PG), so training is very unstable. If needed, try tuning the hyperparameters yourself, or replace the inner-policy training algorithm with something more stable such as PPO; a sketch of what such a replacement could look like follows the full listing.
import numpy as np
import torch
import torch.nn as nn
import gym
import torch.nn.functional as F
from torch.distributions.categorical import Categorical
class NN(nn.Module):
    """Option-critic network: one intra-option actor, termination head and
    critic (Q over primitive actions) per option."""
    def __init__(self, state_size, action_size, hidden_size, num_options):
        super().__init__()
        # Intra-option policies pi(a | s, option): one softmax head per option.
        self.actors = nn.ModuleList([
            nn.Sequential(
                nn.Linear(state_size, hidden_size),
                nn.ReLU(),
                nn.Linear(hidden_size, action_size),
                nn.Softmax(dim=-1)
            ) for _ in range(num_options)
        ])
        # Termination functions beta(s, option) in (0, 1): one sigmoid head per option.
        self.terminations = nn.ModuleList([
            nn.Sequential(
                nn.Linear(state_size, hidden_size),
                nn.ReLU(),
                nn.Linear(hidden_size, 1),
                nn.Sigmoid()
            ) for _ in range(num_options)
        ])
        # Option critics Q(s, a | option): one head per option.
        self.critics = nn.ModuleList([
            nn.Sequential(
                nn.Linear(state_size, hidden_size),
                nn.ReLU(),
                nn.Linear(hidden_size, action_size),
            ) for _ in range(num_options)
        ])
    def select_option(self, state, epsilon):
        # Epsilon-greedy over options: with probability epsilon pick a random option,
        # otherwise pick the option with the largest expected value
        # V(s, option) = sum_a pi(a | s, option) * Q(s, a | option).
        if np.random.rand() >= epsilon:
            max_value = -np.inf
            option_id = -1
            with torch.no_grad():
                for i, (a, c) in enumerate(zip(self.actors, self.critics)):
                    q = c(state)
                    p = a(state)
                    v = (q * p).sum(-1).item()
                    if v >= max_value:
                        option_id = i
                        max_value = v
        else:
            option_id = np.random.randint(0, len(self.actors))
        return self.actors[option_id], self.terminations[option_id], option_id
if __name__ == '__main__':
    np.random.seed(0)
    torch.manual_seed(0)
    episodes = 5000
    epsilon = 1.0            # epsilon-greedy exploration over options
    discount = 0.9           # discount factor gamma
    epsilon_decay = 0.995
    epsilon_min = 0.05
    training_epochs = 1      # update the networks every `training_epochs` episodes
    env = gym.make('CartPole-v1')
    # CartPole-v1: 4-dim state, 2 actions; here 6 options with hidden size 128.
    # The instance is named `model` so it does not shadow the `torch.nn` module imported as `nn`.
    model = NN(4, 2, 128, 6)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
    max_score = 0.0
    trajectory = []
    for e in range(1, episodes + 1):
        if e % training_epochs == 0:
            trajectory = []              # start collecting a fresh batch of transitions
        score = 0.0
        state, _ = env.reset()
        # Pick an initial option (epsilon-greedy over option values).
        option = model.select_option(torch.tensor(state), epsilon)
        while True:
            # Sample a primitive action from the current intra-option policy.
            with torch.no_grad():
                policy = option[0](torch.tensor(state))
            action = Categorical(policy).sample().item()
            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated   # treat truncation as terminal for simplicity
            score += reward
            # Probability that the current option terminates in next_state.
            with torch.no_grad():
                beta = option[1](torch.tensor(next_state)).item()
            option_done = np.random.rand() <= beta
            trajectory.append(
                (state, action, reward, next_state, done, option[2], beta, option_done)
            )
            if option_done:
                # The option terminated: choose a new option from next_state.
                option = model.select_option(torch.tensor(next_state), epsilon)
            state = next_state
            if done:
                break
        if e % training_epochs == 0:
            optimizer.zero_grad()
            q_targets = []
            option_states = []
            option_advs = []
            option_next_states = []
            for state, action, reward, next_state, done, option_id, beta, option_terminal in trajectory:
                # Targets and advantages are constants for the update, so no gradients are needed here.
                with torch.no_grad():
                    # One-step option-critic target:
                    #   q = r + gamma * [(1 - beta) * V(s', current option) + beta * max_w V(s', w)]
                    # with V(s', w) = sum_a pi(a | s', w) * Q(s', a | w).
                    q = reward + (1 - done) * discount * (
                        (1 - beta) * (
                            model.critics[option_id](torch.tensor(next_state)) *
                            model.actors[option_id](torch.tensor(next_state))
                        ).sum(-1).item() +
                        beta * max([
                            (
                                model.critics[i](torch.tensor(next_state)) *
                                model.actors[i](torch.tensor(next_state))
                            ).sum(-1).item()
                            for i in range(len(model.critics))
                        ])
                    )
                    # Only the taken action's Q-value is moved towards the target.
                    q_target = model.critics[option_id](torch.tensor(state)).numpy()
                    q_target[action] = q
                    q_targets.append(q_target)
                    option_states.append(state)
                    # Termination advantage: V(s', current option) - max_w V(s', w)  (always <= 0).
                    inner_next_value = (
                        model.critics[option_id](torch.tensor(next_state)) *
                        model.actors[option_id](torch.tensor(next_state))
                    ).sum(-1).item()
                    next_value = max([(
                        model.critics[i](torch.tensor(next_state)) *
                        model.actors[i](torch.tensor(next_state))
                    ).sum(-1).item() for i in range(len(model.critics))])
                    option_adv = inner_next_value - next_value
                    option_advs.append(option_adv)
                    option_next_states.append(next_state)
                # When the option segment ends, update that option's critic, actor and
                # termination head on the transitions collected for this segment.
                if option_terminal:
                    option_states = torch.tensor(np.array(option_states))
                    q_targets = torch.tensor(np.array(q_targets))
                    option_advs = torch.tensor(np.array(option_advs), dtype=torch.float32).view(-1, 1)
                    option_next_states = torch.tensor(np.array(option_next_states))
                    # Critic: regress Q(s, a | option) towards the one-step targets.
                    option_critic_loss = F.mse_loss(
                        model.critics[option_id](option_states),
                        q_targets
                    )
                    # Actor: vanilla policy gradient. The advantage is zero for every action
                    # except the one actually taken, so only that action's log-prob is pushed.
                    actor_advs = q_targets - model.critics[option_id](option_states).detach()
                    option_actor_loss = -(torch.log(model.actors[option_id](option_states)) * actor_advs).mean()
                    # Termination: raise beta(s') in proportion to how far the current option's
                    # value falls short of the best option's value at s'.
                    option_terminal_loss = (model.terminations[option_id](option_next_states) * option_advs).mean()
                    # Accumulate gradients; optimizer.step() is called once per batch below.
                    option_critic_loss.backward()
                    option_actor_loss.backward()
                    option_terminal_loss.backward()
                    q_targets = []
                    option_states = []
                    option_advs = []
                    option_next_states = []
            optimizer.step()
        # Decay exploration and keep a checkpoint of the best-scoring episode so far.
        if epsilon > epsilon_min:
            epsilon *= epsilon_decay
        if score > max_score:
            max_score = score
            torch.save(model, 'NN.pt')
        print("Episode: {}/{}, Epsilon: {:.3f}, Score: {}, Max score: {}".format(
            e, episodes, epsilon, score, max_score))
    env.close()
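As the note at the top says, the intra-option actor here is trained with plain PG. Below is a minimal sketch, not part of the program above, of what a PPO-style clipped surrogate for one intra-option actor could look like. It assumes the rollout additionally records the log-probability of each taken action under the then-current actor (the `old_log_probs` argument); the function name and its arguments are hypothetical, not something the code above already provides.

import torch


def ppo_actor_loss(actor, states, actions, old_log_probs, advantages, clip_eps=0.2):
    """PPO clipped surrogate loss for one intra-option actor (a sketch).

    actor:         an nn.Sequential ending in Softmax, like NN.actors[i] above
    states:        (N, state_size) float tensor
    actions:       (N,) long tensor of taken actions
    old_log_probs: (N,) float tensor, log pi_old(a | s) recorded at rollout time
    advantages:    (N,) float tensor, e.g. the taken action's q_target minus Q(s, a).detach()
    """
    probs = actor(states)                                      # (N, action_size)
    new_log_probs = torch.log(probs.gather(1, actions.view(-1, 1)).squeeze(1))
    ratio = torch.exp(new_log_probs - old_log_probs)           # pi_new / pi_old
    clipped = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps)
    # PPO maximizes the clipped surrogate, so the loss is its negation.
    return -torch.min(ratio * advantages, clipped * advantages).mean()

To use something like this, the rollout would also store `Categorical(policy).log_prob(action)` for each transition, and the update would typically run a few epochs over the same batch (recomputing the ratio each time) instead of a single backward pass per option segment.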