import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical
import numpy as np
from multiprocessing import Process, Pipe
import argparse
import gym
# Build the Actor and Critic networks
class ActorCritic(nn.Module):
    """A2C network with separate actor (policy) and critic (value) heads.

    The two heads are independent MLPs that share no parameters: the critic
    maps a state to a scalar value estimate, the actor maps a state to a
    categorical distribution over discrete actions.
    """

    def __init__(self, input_dim, output_dim, hidden_dim):
        super(ActorCritic, self).__init__()
        # Value head: state -> scalar V(s).
        self.critic = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )
        # Policy head: state -> action probabilities. Softmax over dim=1
        # means inputs are expected batched with shape (batch, input_dim).
        self.actor = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
            nn.Softmax(dim=1),
        )

    def forward(self, x):
        """Return (action distribution, state value) for a batch of states."""
        dist = Categorical(self.actor(x))
        return dist, self.critic(x)
class A2C:
    """Advantage Actor-Critic agent.

    Wraps an ActorCritic model, its optimizer, and the discounted-return
    computation used to build training targets.
    """
    def __init__(self, n_states, n_actions, cfg) -> None:
        self.gamma = cfg.gamma  # discount factor
        self.device = cfg.device
        # Fix: the config elsewhere in this file declares `hidden_dim`, but this
        # class originally read `hidden_size` and would raise AttributeError with
        # that config. Accept the legacy key first for backward compatibility.
        hidden = getattr(cfg, 'hidden_size', None)
        if hidden is None:
            hidden = cfg.hidden_dim
        self.model = ActorCritic(n_states, n_actions, hidden).to(self.device)
        # Fix: cfg.lr was previously ignored. The fallback 1e-3 matches Adam's
        # built-in default, so behavior is unchanged when cfg has no `lr`.
        self.optimizer = optim.Adam(self.model.parameters(),
                                    lr=getattr(cfg, 'lr', 1e-3))

    def compute_returns(self, next_value, rewards, masks):
        """Compute bootstrapped discounted returns for one rollout.

        masks[t] is 0 where the episode terminated at step t, which cuts the
        bootstrap so value does not leak across episode boundaries.
        """
        R = next_value
        returns = []
        for step in reversed(range(len(rewards))):
            R = rewards[step] + self.gamma * R * masks[step]
            returns.insert(0, R)
        return returns
def make_envs(env_name):
    """Return a zero-argument factory (thunk) that builds a seeded env.

    SubprocVecEnv expects callables rather than env instances so that each
    worker process constructs its own environment.
    """
    def _thunk():
        environment = gym.make(env_name)
        environment.seed(2)  # fixed seed for reproducibility
        return environment

    return _thunk
def test_env(env,model,vis=False):
state = env.reset()
if vis: env.render()
done = False
total_reward = 0
while not done:
state = torch.FloatTensor(state).unsqueeze(0).to(cfg.device)
dist, _ = model(state)
next_state, reward, done, _ = env.step(dist.sample().cpu().numpy()[0])
state = next_state
if vis: env.render()
total_reward += reward
return total_reward
def compute_returns(next_value, rewards, masks, gamma=0.99):
    """Compute bootstrapped discounted returns for one rollout.

    Walks the trajectory backwards, accumulating reward + gamma * R. masks[t]
    is 0 at terminal steps, which stops the bootstrap from leaking value
    across episode boundaries.
    """
    running = next_value
    returns = []
    for reward, mask in zip(reversed(rewards), reversed(masks)):
        running = reward + gamma * running * mask
        returns.append(running)
    returns.reverse()
    return returns
def train(cfg, envs):
    """Train an A2C model on a vector of parallel environments.

    Alternates short n-step rollouts (cfg.n_steps) across all envs with a
    combined actor/critic gradient update, and every 200 steps evaluates the
    current policy on a separate single env by sampling actions.

    Args:
        cfg: config with env_name, algo_name, device, hidden_dim, lr,
            max_steps, n_steps (gamma is used via compute_returns' default).
        envs: vectorized environments (e.g. SubprocVecEnv) with batched
            reset()/step().

    Returns:
        (test_rewards, test_ma_rewards): raw and exponentially-smoothed
        evaluation rewards recorded every 200 steps.
    """
    print('Start training!')
    print(f'Env:{cfg.env_name}, Algorithm:{cfg.algo_name}, Device:{cfg.device}')
    env = gym.make(cfg.env_name)  # a single env, used only for evaluation
    env.seed(10)
    n_states = envs.observation_space.shape[0]
    n_actions = envs.action_space.n
    model = ActorCritic(n_states, n_actions, cfg.hidden_dim).to(cfg.device)
    # Fix: cfg.lr was previously ignored (Adam used its default). The fallback
    # 1e-3 equals Adam's default, so behavior is unchanged for configs
    # lacking `lr`; this file's cfg.lr is also 1e-3.
    optimizer = optim.Adam(model.parameters(), lr=getattr(cfg, 'lr', 1e-3))
    step_idx = 0
    test_rewards = []
    test_ma_rewards = []
    state = envs.reset()
    while step_idx < cfg.max_steps:
        log_probs = []
        values = []
        rewards = []
        masks = []
        entropy = 0
        # Roll out cfg.n_steps transitions in every parallel env.
        for _ in range(cfg.n_steps):
            state = torch.FloatTensor(state).to(cfg.device)
            dist, value = model(state)
            action = dist.sample()
            next_state, reward, done, _ = envs.step(action.cpu().numpy())
            log_prob = dist.log_prob(action)
            entropy += dist.entropy().mean()
            log_probs.append(log_prob)
            values.append(value)
            rewards.append(torch.FloatTensor(reward).unsqueeze(1).to(cfg.device))
            # mask = 0 where an env finished, so returns stop bootstrapping there.
            masks.append(torch.FloatTensor(1 - done).unsqueeze(1).to(cfg.device))
            state = next_state
            step_idx += 1
            if step_idx % 200 == 0:
                # Periodic evaluation: mean total reward over 10 sampled episodes.
                test_reward = np.mean([test_env(env, model) for _ in range(10)])
                print(f"step_idx:{step_idx}, test_reward:{test_reward}")
                test_rewards.append(test_reward)
                # Exponential moving average for a smoother learning curve.
                if test_ma_rewards:
                    test_ma_rewards.append(0.9*test_ma_rewards[-1]+0.1*test_reward)
                else:
                    test_ma_rewards.append(test_reward)
                # plot(step_idx, test_rewards)
        # Bootstrap the value of the state after the rollout.
        next_state = torch.FloatTensor(next_state).to(cfg.device)
        _, next_value = model(next_state)
        returns = compute_returns(next_value, rewards, masks)
        log_probs = torch.cat(log_probs)
        returns = torch.cat(returns).detach()  # targets: no gradient through returns
        values = torch.cat(values)
        advantage = returns - values
        # Policy gradient on detached advantage; critic regresses to returns.
        actor_loss = -(log_probs * advantage.detach()).mean()
        critic_loss = advantage.pow(2).mean()
        # Entropy bonus (weight 0.001) encourages exploration.
        loss = actor_loss + 0.5 * critic_loss - 0.001 * entropy
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print('Finish training!')
    return test_rewards, test_ma_rewards
import matplotlib.pyplot as plt
import seaborn as sns
def plot_rewards(rewards, ma_rewards, cfg, tag='train'):
    """Plot raw and moving-average reward curves on a single figure.

    Note: `tag` is accepted for API compatibility but is not used here.
    """
    sns.set()
    plt.figure()  # fresh figure so repeated calls don't draw on top of each other
    plt.title("learning curve on {} of {} for {}".format(
        cfg.device, cfg.algo_name, cfg.env_name))
    plt.xlabel('epsiodes')
    for series, label in ((rewards, 'rewards'), (ma_rewards, 'ma rewards')):
        plt.plot(series, label=label)
    plt.legend()
    plt.show()
import easydict
from common.multiprocessing_env import SubprocVecEnv
# Hyper-parameters and runtime settings for the A2C run.
# NOTE(review): cfg.lr is declared but the optimizers above use Adam's
# default (which happens to also be 1e-3) — confirm whether lr is meant
# to be wired in.
cfg = easydict.EasyDict({
        "algo_name": 'A2C',
        "env_name": 'CartPole-v0',
        "n_envs": 8,            # number of parallel worker environments
        "max_steps": 20000,     # total environment steps to train for
        "n_steps":5,            # rollout length per update
        "gamma":0.99,
        "lr": 1e-3,
        "hidden_dim": 256,
        "device":torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")
})
# Build one env factory per worker and run them in subprocesses.
envs = [make_envs(cfg.env_name) for i in range(cfg.n_envs)]
envs = SubprocVecEnv(envs)
rewards,ma_rewards = train(cfg,envs)
plot_rewards(rewards, ma_rewards, cfg, tag="train")  # plot the results
Start training!
Env:CartPole-v0, Algorithm:A2C, Device:cuda
step_idx:200, test_reward:18.6
step_idx:400, test_reward:19.7
step_idx:600, test_reward:24.2
step_idx:800, test_reward:19.5
step_idx:1000, test_reward:33.9
step_idx:1200, test_reward:36.1
step_idx:1400, test_reward:32.6
step_idx:1600, test_reward:36.3
step_idx:1800, test_reward:38.9
step_idx:2000, test_reward:60.8
step_idx:2200, test_reward:41.9
step_idx:2400, test_reward:42.2
step_idx:2600, test_reward:71.6
step_idx:2800, test_reward:123.6
step_idx:3000, test_reward:57.5
step_idx:3200, test_reward:155.4
step_idx:3400, test_reward:111.4
step_idx:3600, test_reward:133.8
step_idx:3800, test_reward:133.8
step_idx:4000, test_reward:114.3
step_idx:4200, test_reward:165.5
step_idx:4400, test_reward:119.4
step_idx:4600, test_reward:173.4
step_idx:4800, test_reward:115.4
step_idx:5000, test_reward:159.7
step_idx:5200, test_reward:178.1
step_idx:5400, test_reward:137.8
step_idx:5600, test_reward:146.0
step_idx:5800, test_reward:187.4
step_idx:6000, test_reward:200.0
step_idx:6200, test_reward:169.2
step_idx:6400, test_reward:167.8
step_idx:6600, test_reward:184.3
step_idx:6800, test_reward:162.3
step_idx:7000, test_reward:125.4
step_idx:7200, test_reward:150.6
step_idx:7400, test_reward:152.6
step_idx:7600, test_reward:122.5
step_idx:7800, test_reward:136.3
step_idx:8000, test_reward:131.4
step_idx:8200, test_reward:174.6
step_idx:8400, test_reward:91.7
step_idx:8600, test_reward:170.1
step_idx:8800, test_reward:166.0
step_idx:9000, test_reward:150.2
step_idx:9200, test_reward:104.6
step_idx:9400, test_reward:147.2
step_idx:9600, test_reward:111.8
step_idx:9800, test_reward:118.7
step_idx:10000, test_reward:102.6
step_idx:10200, test_reward:99.0
step_idx:10400, test_reward:64.6
step_idx:10600, test_reward:133.7
step_idx:10800, test_reward:119.7
step_idx:11000, test_reward:112.6
step_idx:11200, test_reward:116.1
step_idx:11400, test_reward:116.3
step_idx:11600, test_reward:116.2
step_idx:11800, test_reward:115.3
step_idx:12000, test_reward:109.7
step_idx:12200, test_reward:110.3
step_idx:12400, test_reward:131.4
step_idx:12600, test_reward:128.3
step_idx:12800, test_reward:128.8
step_idx:13000, test_reward:119.8
step_idx:13200, test_reward:108.6
step_idx:13400, test_reward:128.4
step_idx:13600, test_reward:138.2
step_idx:13800, test_reward:119.1
step_idx:14000, test_reward:140.7
step_idx:14200, test_reward:145.3
step_idx:14400, test_reward:154.1
step_idx:14600, test_reward:165.2
step_idx:14800, test_reward:138.2
step_idx:15000, test_reward:143.5
step_idx:15200, test_reward:125.4
step_idx:15400, test_reward:137.1
step_idx:15600, test_reward:150.1
step_idx:15800, test_reward:132.9
step_idx:16000, test_reward:140.4
step_idx:16200, test_reward:141.3
step_idx:16400, test_reward:135.5
step_idx:16600, test_reward:135.5
step_idx:16800, test_reward:125.6
step_idx:17000, test_reward:126.8
step_idx:17200, test_reward:124.7
step_idx:17400, test_reward:129.6
step_idx:17600, test_reward:114.3
step_idx:17800, test_reward:57.3
step_idx:18000, test_reward:164.7
step_idx:18200, test_reward:165.8
step_idx:18400, test_reward:196.7
step_idx:18600, test_reward:198.8
step_idx:18800, test_reward:200.0
step_idx:19000, test_reward:199.6
step_idx:19200, test_reward:189.5
step_idx:19400, test_reward:177.9
step_idx:19600, test_reward:159.3
step_idx:19800, test_reward:127.7
step_idx:20000, test_reward:143.6
Finish training!