理论：模仿学习（行为克隆 BC 与生成对抗模仿学习 GAIL）
说明：已修改原代码中会触发警告和报错的部分。
运行环境：
- Debian GNU/Linux 12
- Python 3.9.19
- torch 2.0.1
- gym 0.26.2

运行代码：
#!/usr/bin/env python
import gym
import torch
import torch.nn.functional as F
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import random
import rl_utils
class PolicyNet(torch.nn.Module):
    """Two-layer MLP mapping states to a categorical action distribution.

    Args:
        state_dim: size of the last axis of the state input.
        hidden_dim: width of the single hidden layer.
        action_dim: number of discrete actions (output size).
    """

    def __init__(self, state_dim, hidden_dim, action_dim):
        super(PolicyNet, self).__init__()
        self.fc1 = torch.nn.Linear(state_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, action_dim)

    def forward(self, x):
        """Return action probabilities for ``x``.

        The softmax is taken over the last axis, so both a batch of states
        ``(batch, state_dim)`` and a single unbatched state ``(state_dim,)``
        are accepted; for 2-D input this is identical to ``dim=1``.
        """
        x = F.relu(self.fc1(x))
        return F.softmax(self.fc2(x), dim=-1)
class ValueNet(nn.Module):
    """State-value critic: a one-hidden-layer MLP producing V(s).

    Args:
        state_dim: size of the last axis of the state input.
        hidden_dim: width of the hidden layer.
    """

    def __init__(self, state_dim, hidden_dim):
        super().__init__()
        # Attribute names fc1/fc2 define the state_dict keys; keep them stable.
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, 1)

    def forward(self, x):
        """Return one value estimate per input row, shape ``(..., 1)``."""
        hidden = torch.relu(self.fc1(x))
        return self.fc2(hidden)
class PPO:
    ''' PPO with the clipped surrogate objective (PPO-Clip).

    Args:
        state_dim: length of an observation vector.
        hidden_dim: hidden width shared by the actor and critic MLPs.
        action_dim: number of discrete actions.
        actor_lr: Adam learning rate for the policy network.
        critic_lr: Adam learning rate for the value network.
        lmbda: GAE lambda forwarded to ``rl_utils.compute_advantage``.
        epochs: optimisation passes made over each batch of trajectory data.
        eps: clip range; the probability ratio is clamped to [1 - eps, 1 + eps].
        gamma: discount factor.
        device: torch device holding both networks.
    '''

    def __init__(self, state_dim, hidden_dim, action_dim, actor_lr, critic_lr,
                 lmbda, epochs, eps, gamma, device):
        # Construction order (actor, then critic) fixes which random draws
        # each network's initial weights consume under a global seed.
        self.actor = PolicyNet(state_dim, hidden_dim, action_dim).to(device)
        self.critic = ValueNet(state_dim, hidden_dim).to(device)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(),
                                                lr=actor_lr)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(),
                                                 lr=critic_lr)
        self.gamma, self.lmbda = gamma, lmbda
        self.epochs = epochs  # training passes per collected batch
        self.eps = eps  # PPO clipping range
        self.device = device

    def take_action(self, state):
        """Sample an action index (Python int) for a single observation."""
        obs = torch.tensor(np.array([state]), dtype=torch.float).to(self.device)
        return torch.distributions.Categorical(self.actor(obs)).sample().item()

    def update(self, transition_dict):
        """Perform ``self.epochs`` PPO-Clip steps on one batch of transitions.

        Args:
            transition_dict: mapping with keys 'states', 'actions', 'rewards',
                'next_states' and 'dones', each holding one entry per step.
        """
        dev = self.device
        states = torch.tensor(np.array(transition_dict['states']),
                              dtype=torch.float).to(dev)
        actions = torch.tensor(transition_dict['actions']).view(-1, 1).to(dev)
        rewards = torch.tensor(transition_dict['rewards'],
                               dtype=torch.float).view(-1, 1).to(dev)
        next_states = torch.tensor(np.array(transition_dict['next_states']),
                                   dtype=torch.float).to(dev)
        dones = torch.tensor(transition_dict['dones'],
                             dtype=torch.float).view(-1, 1).to(dev)

        # One-step TD target; the (1 - dones) mask drops the bootstrap term
        # on steps flagged as done.
        next_values = self.critic(next_states)
        td_target = rewards + self.gamma * next_values * (1 - dones)
        td_delta = td_target - self.critic(states)
        advantage = rl_utils.compute_advantage(
            self.gamma, self.lmbda, td_delta.cpu()).to(dev)
        # Log-probabilities under the pre-update policy, held fixed across epochs.
        old_log_probs = torch.log(self.actor(states).gather(1, actions)).detach()

        low, high = 1 - self.eps, 1 + self.eps
        optimizers = (self.actor_optimizer, self.critic_optimizer)
        for _ in range(self.epochs):
            new_log_probs = torch.log(self.actor(states).gather(1, actions))
            ratio = torch.exp(new_log_probs - old_log_probs)
            unclipped = ratio * advantage
            clipped = torch.clamp(ratio, low, high) * advantage  # clipping
            # Negated pessimistic (min) surrogate: the PPO-Clip actor loss.
            actor_loss = torch.mean(-torch.min(unclipped, clipped))
            critic_loss = torch.mean(
                F.mse_loss(self.critic(states), td_target.detach()))
            for opt in optimizers:
                opt.zero_grad()
            actor_loss.backward()
            critic_loss.backward()
            for opt in optimizers:
                opt.step()
# Stage 1: train a PPO expert on CartPole.
# Hyper-parameters.
actor_lr, critic_lr = 1e-3, 1e-2  # Adam learning rates (actor / critic)
num_episodes = 250  # episode budget for rl_utils.train_on_policy_agent
hidden_dim = 128
gamma, lmbda = 0.98, 0.95  # discount factor / GAE lambda
epochs = 10  # PPO passes per update
eps = 0.2  # PPO clip range
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Environment, seeded for reproducibility.
env_name = 'CartPole-v1'
env = gym.make(env_name)
env.reset(seed=0)
torch.manual_seed(0)
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n

ppo_agent = PPO(state_dim, hidden_dim, action_dim, actor_lr, critic_lr, lmbda,
                epochs, eps, gamma, device)
return_list = rl_utils.train_on_policy_agent(env, ppo_agent, num_episodes)
def sample_expert_data(n_episode, expert_env=None, expert_agent=None,
                       max_samples=10000):
    """Roll out the expert policy and record its (state, action) pairs.

    Args:
        n_episode: number of episodes to roll out.
        expert_env: environment following the Gym >= 0.26 API
            (``reset() -> (obs, info)``,
            ``step() -> (obs, reward, terminated, truncated, info)``);
            defaults to the module-level ``env``.
        expert_agent: object with ``take_action(state) -> action``; defaults
            to the module-level ``ppo_agent``.
        max_samples: cap on the total number of recorded steps across all
            episodes (safety net for environments without a time limit).

    Returns:
        ``(states, actions)`` as NumPy arrays, one row/entry per step.
    """
    if expert_env is None:
        expert_env = env
    if expert_agent is None:
        expert_agent = ppo_agent
    states = []
    actions = []
    for _ in range(n_episode):
        if len(states) >= max_samples:
            break
        state = expert_env.reset()[0]
        done = False
        while not done and len(states) < max_samples:
            action = expert_agent.take_action(state)
            states.append(state)
            actions.append(action)
            next_state, _, terminated, truncated, _ = expert_env.step(action)
            # Truncation (e.g. by a time-limit wrapper) ends the episode just
            # like termination does.
            done = terminated or truncated
            state = next_state
    return np.array(states), np.array(actions)
# Stage 2: record one expert episode and keep a random subset of it.
env.reset(seed=0)
torch.manual_seed(0)
random.seed(0)
n_episode = 1
expert_s, expert_a = sample_expert_data(n_episode)
n_samples = 30  # number of expert (state, action) pairs kept
# random.sample draws without replacement and raises ValueError if the
# recorded episode has fewer than n_samples steps.
random_index = random.sample(range(len(expert_s)), n_samples)
expert_s, expert_a = expert_s[random_index], expert_a[random_index]
class BehaviorClone:
    """Behaviour cloning: fit a PolicyNet to expert (state, action) pairs by
    maximum likelihood.

    Args:
        state_dim: size of the state vector.
        hidden_dim: hidden width of the policy network.
        action_dim: number of discrete actions.
        lr: Adam learning rate.
    """

    def __init__(self, state_dim, hidden_dim, action_dim, lr):
        self.policy = PolicyNet(state_dim, hidden_dim, action_dim).to(device)
        self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=lr)

    def learn(self, states, actions):
        """Take one gradient step on the negative log-likelihood of ``actions``.

        Args:
            states: array-like batch of states, one row per sample.
            actions: integer action indices, one per state.

        Returns:
            The scalar BC loss (float) computed before the optimizer step,
            for logging; callers may ignore it.
        """
        states = torch.tensor(states, dtype=torch.float).to(device)
        # gather() requires int64 indices; NumPy integer arrays can be int32
        # on some platforms, so cast explicitly.
        actions = torch.tensor(actions, dtype=torch.long).view(-1, 1).to(device)
        log_probs = torch.log(self.policy(states).gather(1, actions))
        bc_loss = torch.mean(-log_probs)  # maximum-likelihood objective
        self.optimizer.zero_grad()
        bc_loss.backward()
        self.optimizer.step()
        return bc_loss.item()

    def take_action(self, state):
        """Sample an action index (Python int) from the cloned policy."""
        state = torch.tensor(np.array([state]), dtype=torch.float).to(device)
        probs = self.policy(state)
        action_dist = torch.distributions.Categorical(probs)
        action = action_dist.sample()
        return action.item()
def test_agent(agent, env, n_episode):
    """Return the mean undiscounted return of ``agent`` over ``n_episode`` episodes.

    Args:
        agent: object with ``take_action(state) -> action``.
        env: environment following the Gym >= 0.26 API
            (``reset() -> (obs, info)``,
            ``step() -> (obs, reward, terminated, truncated, info)``).
        n_episode: number of evaluation episodes.

    Returns:
        ``np.mean`` of the per-episode returns (NaN, with a NumPy
        RuntimeWarning, when ``n_episode`` is 0).
    """
    return_list = []
    for _ in range(n_episode):
        episode_return = 0
        state = env.reset()[0]
        done = False
        while not done:
            action = agent.take_action(state)
            next_state, reward, terminated, truncated, _ = env.step(action)
            # Truncation must also end the episode; otherwise a policy that
            # never fails would loop forever.
            done = terminated or truncated
            state = next_state
            episode_return += reward
        return_list.append(episode_return)
    return np.mean(return_list)


# Evaluation helper, not a pytest test despite its name.
test_agent.__test__ = False
# Stage 3: behaviour cloning on the expert subset, evaluated after every step.
env.reset(seed=0)
torch.manual_seed(0)
np.random.seed(0)
lr = 1e-3
bc_agent = BehaviorClone(state_dim, hidden_dim, action_dim, lr)
n_iterations, batch_size = 1000, 64
test_returns = []
with tqdm(total=n_iterations, desc="进度条") as pbar:
    for i in range(n_iterations):
        # Mini-batch of expert indices drawn uniformly with replacement.
        sample_indices = np.random.randint(0, expert_s.shape[0], batch_size)
        bc_agent.learn(expert_s[sample_indices], expert_a[sample_indices])
        # Mean return of the cloned policy over 5 evaluation episodes.
        current_return = test_agent(bc_agent, env, 5)
        test_returns.append(current_return)
        if (i + 1) % 10 == 0:
            # Progress-bar postfix: average of the last 10 evaluations.
            recent = np.mean(test_returns[-10:])
            pbar.set_postfix({'return': '%.3f' % recent})
        pbar.update(1)

iteration_list = list(range(len(test_returns)))
plt.plot(iteration_list, test_returns)
plt.xlabel('Iterations')
plt.ylabel('Returns')
plt.title('BC on {}'.format(env_name))
plt.show()
class Discriminator(nn.Module):
    """GAIL discriminator producing a score D(s, a) in (0, 1).

    The state and action encodings are concatenated along dim 1, so both
    arguments are expected as 2-D batches with matching batch size.

    Args:
        state_dim: size of the state vector.
        hidden_dim: width of the hidden layer.
        action_dim: length of the action encoding (GAIL passes one-hot vectors).
    """

    def __init__(self, state_dim, hidden_dim, action_dim):
        super().__init__()
        self.fc1 = nn.Linear(state_dim + action_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, 1)

    def forward(self, x, a):
        """Return sigmoid scores of shape ``(batch, 1)`` for states ``x`` and actions ``a``."""
        joint = torch.cat((x, a), dim=1)
        features = F.relu(self.fc1(joint))
        logit = self.fc2(features)
        return torch.sigmoid(logit)
class GAIL:
    """Generative Adversarial Imitation Learning around an on-policy agent.

    The discriminator is trained towards 1 on agent (state, action) pairs
    and towards 0 on expert pairs; the agent is then updated with reward
    ``-log D(s, a)``, which grows as the discriminator scores the agent's
    pairs closer to 0 (i.e. closer to the expert label).

    Args:
        agent: learner exposing ``update(transition_dict)`` (e.g. PPO).
        state_dim: size of the state vector.
        action_dim: number of discrete actions (width of the one-hot encoding).
        hidden_dim: hidden width of the discriminator.
        lr_d: Adam learning rate of the discriminator.
    """

    def __init__(self, agent, state_dim, action_dim, hidden_dim, lr_d):
        self.discriminator = Discriminator(state_dim, hidden_dim,
                                           action_dim).to(device)
        self.discriminator_optimizer = torch.optim.Adam(
            self.discriminator.parameters(), lr=lr_d)
        self.agent = agent
        # Used for one-hot encoding so it always matches the discriminator input.
        self.action_dim = action_dim

    def learn(self, expert_s, expert_a, agent_s, agent_a, next_s, dones):
        """Run one discriminator step, then one agent update.

        Args:
            expert_s: expert states, one row per sample.
            expert_a: expert integer actions, one per sample.
            agent_s: states from the agent's latest trajectory.
            agent_a: integer actions taken in those states.
            next_s: next states of that trajectory.
            dones: per-step done flags, forwarded to the agent as 'dones'.
        """
        expert_states = torch.tensor(expert_s, dtype=torch.float).to(device)
        agent_states = torch.tensor(np.array(agent_s), dtype=torch.float).to(device)
        # one_hot() requires int64 input; NumPy integer arrays can be int32
        # on some platforms, so cast explicitly.
        expert_actions = torch.tensor(expert_a, dtype=torch.long).to(device)
        agent_actions = torch.tensor(agent_a, dtype=torch.long).to(device)
        expert_actions = F.one_hot(expert_actions,
                                   num_classes=self.action_dim).float()
        agent_actions = F.one_hot(agent_actions,
                                  num_classes=self.action_dim).float()

        expert_prob = self.discriminator(expert_states, expert_actions)
        agent_prob = self.discriminator(agent_states, agent_actions)
        bce = nn.BCELoss()
        discriminator_loss = bce(agent_prob, torch.ones_like(agent_prob)) + \
            bce(expert_prob, torch.zeros_like(expert_prob))
        self.discriminator_optimizer.zero_grad()
        discriminator_loss.backward()
        self.discriminator_optimizer.step()

        # Reward -log D(s, a). The log is clamped at -100 (as BCELoss does
        # internally) so a saturated D == 0 cannot yield an infinite reward.
        rewards = -torch.clamp(torch.log(agent_prob),
                               min=-100.0).detach().cpu().numpy()
        transition_dict = {
            'states': agent_s,
            'actions': agent_a,
            'rewards': rewards,
            'next_states': next_s,
            'dones': dones
        }
        self.agent.update(transition_dict)
# Stage 4: GAIL - train a fresh PPO agent against a learned discriminator.
env.reset(seed=0)
torch.manual_seed(0)
lr_d = 1e-3
agent = PPO(state_dim, hidden_dim, action_dim, actor_lr, critic_lr, lmbda,
            epochs, eps, gamma, device)
gail = GAIL(agent, state_dim, action_dim, hidden_dim, lr_d)
n_episode = 500
return_list = []
with tqdm(total=n_episode, desc="进度条") as pbar:
    for i in range(n_episode):
        episode_return = 0
        state = env.reset()[0]
        done = False
        state_list = []
        action_list = []
        next_state_list = []
        done_list = []
        while not done and len(state_list) < 10000:
            action = agent.take_action(state)
            next_state, reward, terminated, truncated, _ = env.step(action)
            # Stop on termination OR truncation (Gym >= 0.26 API), but record
            # only `terminated` as the agent's done flag: PPO.update masks the
            # bootstrap term with (1 - dones), and a truncated step is not a
            # true terminal state.
            done = terminated or truncated
            state_list.append(state)
            action_list.append(action)
            next_state_list.append(next_state)
            done_list.append(terminated)
            state = next_state
            episode_return += reward
        return_list.append(episode_return)
        gail.learn(expert_s, expert_a, state_list, action_list,
                   next_state_list, done_list)
        if (i + 1) % 10 == 0:
            pbar.set_postfix({'return': '%.3f' % np.mean(return_list[-10:])})
        pbar.update(1)
iteration_list = list(range(len(return_list)))
plt.plot(iteration_list, return_list)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title('GAIL on {}'.format(env_name))
plt.show()
依赖模块：rl_utils.py