A3C
Core idea: the global network and every worker network share the same architecture. Each worker trains on its own local copy; after a training step it pushes its own gradients to the global network, which uses them to update its parameters. The global network then copies its updated parameters back to each worker.
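To make the cycle concrete, here is a minimal sketch of one worker's update step, in my own words rather than the repo's code; compute_loss is a hypothetical placeholder for "run an episode with the local net and build the actor/critic loss":

# Sketch only: compute_loss is a hypothetical helper, not part of the repo.
def worker_cycle(local_net, global_net, global_opt):
    loss = compute_loss(local_net)                        # gradients will live on the local copy
    global_opt.zero_grad()
    loss.backward()
    for lp, gp in zip(local_net.parameters(), global_net.parameters()):
        gp._grad = lp.grad                                # push local gradients onto the global net
    global_opt.step()                                     # update the shared global parameters
    local_net.load_state_dict(global_net.state_dict())    # pull the new parameters back

The Worker.update_global() and Worker.sync_global() methods below implement exactly this pattern.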
Code implementation
The code below is adapted from a blogger's GitHub repo.
a3c_main.py
import gym
import numpy as np
import matplotlib.pyplot as plt
from Actor_Critic.A3C.agent_a3c import A3C
def get_env_prop(env_name, continuous):
    env = gym.make(env_name)
    state_dim = env.observation_space.shape[0]
    if continuous:
        action_dim = env.action_space.shape[0]
    else:
        action_dim = env.action_space.n
    return env, state_dim, action_dim

def train_a3c(env_name, continuous):
    env, state_size, action_size = get_env_prop(env_name, continuous)
    agent = A3C(env, continuous, state_size, action_size)
    scores = agent.train_worker()
    return scores

def train_agent_for_env(env_name, continuous):
    env = gym.make(env_name)
    state_dim = env.observation_space.shape[0]
    if continuous:
        action_dim = env.action_space.shape[0]
    else:
        action_dim = env.action_space.n
    agent = A3C(env, continuous, state_dim, action_dim)
    scores = agent.train_worker()
    return agent, scores

def plot_scores(scores, filename):
    fig = plt.figure()
    ax = fig.add_subplot(111)
    plt.plot(np.arange(1, len(scores) + 1), scores)
    plt.ylabel('Score')
    plt.xlabel('Episode #')
    plt.savefig(filename)
    plt.show()

if __name__ == "__main__":
    # train A3C on a continuous env : Pendulum
    # train_scores = train_a3c("Pendulum-v0", True)

    # train A3C on a discrete env : CartPole
    # train_agent_for_env returns (agent, scores), so unpack both values
    agent_cartPole, scores_cartPole = train_agent_for_env("CartPole-v0", False)
    plot_scores(scores_cartPole, "cartPole_trainPlot.png")

    # train A3C on a continuous env : MountainCarContinuous
    # a3c_mCar, scores_mCar = train_agent_for_env("MountainCarContinuous-v0", True)
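A small caveat from me, not in the original notes: the script above assumes the classic Gym API (env.seed(), reset() returning only the observation, step() returning a 4-tuple, "CartPole-v0"/"Pendulum-v0" ids). Under Gymnasium / Gym >= 0.26 the interaction loop would need roughly this shape:

# Sketch of the newer API, for reference only.
import gymnasium as gym
env = gym.make("CartPole-v1")
state, info = env.reset(seed=0)                    # reset now returns (obs, info)
next_state, reward, terminated, truncated, info = env.step(env.action_space.sample())
done = terminated or truncated                     # "done" is split into two flags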
agent_a3c.py
import random
import torch
import torch.optim as optim
import multiprocessing as mp
from multiprocessing import Process
from Actor_Critic.A3C.untils import ValueNetwork,ActorDiscrete,ActorContinous
from Actor_Critic.A3C.worker import Worker
GAMMA = 0.9
LR = 1e-4
GLOBAL_MAX_EPISODE = 5000
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
class A3C():
    def __init__(self, env, continuous, state_size, action_size):
        self.max_episode = GLOBAL_MAX_EPISODE
        # for PyTorch multiprocessing, see https://zhuanlan.zhihu.com/p/328271397
        self.global_episode = mp.Value('i', 0)   # counter shared across processes
        # usage example: self.global_episode.value += 1
        self.global_epi_rew = mp.Value('d', 0)
        # rew_queue behaves like an ordinary queue: values can be put in and taken out,
        # but unlike a shared Value its entries cannot be modified in place
        self.rew_queue = mp.Queue()
        self.worker_num = mp.cpu_count()

        # define the global networks
        self.global_valueNet = ValueNetwork(state_size, 1).to(device)
        # put the global network parameters into shared memory so they can be
        # copied to (and updated by) the worker networks in every process
        self.global_valueNet.share_memory()
        if continuous:
            self.global_policyNet = ActorContinous(state_size, action_size).to(device)
        else:
            self.global_policyNet = ActorDiscrete(state_size, action_size).to(device)
        self.global_policyNet.share_memory()

        # global optimizers
        self.global_optimizer_policy = optim.Adam(self.global_policyNet.parameters(), lr=LR)
        self.global_optimizer_value = optim.Adam(self.global_valueNet.parameters(), lr=LR)

        # define the workers
        self.workers = [Worker(env, continuous, state_size, action_size, i,
                               self.global_valueNet, self.global_optimizer_value,
                               self.global_policyNet, self.global_optimizer_policy,
                               self.global_episode, self.global_epi_rew, self.rew_queue,
                               self.max_episode, GAMMA)
                        for i in range(self.worker_num)]

    def train_worker(self):
        scores = []
        [w.start() for w in self.workers]
        while True:
            r = self.rew_queue.get()
            if r is not None:
                scores.append(r)
            else:
                break
        [w.join() for w in self.workers]
        return scores

    def save_model(self):
        torch.save(self.global_valueNet.state_dict(), "a3c_value_model.pth")
        torch.save(self.global_policyNet.state_dict(), "a3c_policy_model.pth")
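A quick aside on share_memory(): this minimal sketch (mine, not from the repo) shows that parameters placed in shared memory are the same storage in every child process, which is what lets a worker's optimizer step on the global net be visible to all the other workers:

import torch
import torch.nn as nn
import torch.multiprocessing as mp

def child(shared_net):
    with torch.no_grad():
        shared_net.weight.add_(1.0)   # in-place update on the shared parameters

if __name__ == "__main__":
    net = nn.Linear(2, 2, bias=False)
    net.share_memory()                # move parameters to shared memory
    p = mp.Process(target=child, args=(net,))
    p.start(); p.join()
    print(net.weight)                 # reflects the child's in-place update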
untils.py
import torch
from collections import namedtuple
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical
from torch.distributions import Normal
class ValueNetwork(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(ValueNetwork, self).__init__()
        self.fc1 = nn.Linear(input_dim, 256)
        self.fc2 = nn.Linear(256, output_dim)

    def forward(self, state):
        value = F.relu(self.fc1(state))
        value = self.fc2(value)
        return value

class ActorDiscrete(nn.Module):
    """
    Policy network for discrete action spaces.
    """
    def __init__(self, state_size, action_size):
        super(ActorDiscrete, self).__init__()
        self.seed = torch.manual_seed(0)
        self.fc1 = nn.Linear(state_size, 128)
        # self.fc2 = nn.Linear(64, 128)
        self.fc2 = nn.Linear(128, action_size)

    def forward(self, x):
        """
        Build a network that maps state -> action probs.
        """
        x = F.relu(self.fc1(x))
        out = F.softmax(self.fc2(x), dim=1)
        return out

    def act(self, state):
        """
        Return the sampled action and its probability.
        """
        # probs for each action (2d tensor)
        probs = self.forward(state)
        m = Categorical(probs)
        action = m.sample()
        # return action for current state, and the corresponding probability
        return action.item(), probs[:, action.item()].item()

class ActorContinous(nn.Module):
    """
    Policy network for continuous action spaces.
    """
    def __init__(self, state_size, action_size):
        super(ActorContinous, self).__init__()
        self.fc1 = nn.Linear(state_size, 128)
        self.fc2 = nn.Linear(128, 128)
        self.mu_head = nn.Linear(128, action_size)
        self.sigma_head = nn.Linear(128, action_size)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        mu = 2.0 * torch.tanh(self.mu_head(x))
        sigma = F.softplus(self.sigma_head(x))
        return (mu, sigma)

    def act(self, state):
        """
        Return the sampled action and its log probability.
        """
        with torch.no_grad():
            # the original called self.policy(state), which does not exist here
            (mu, sigma) = self.forward(state)  # 2d tensors
            dist = Normal(mu, sigma)
            action = dist.sample()
            action_log_prob = dist.log_prob(action)
        return action.numpy()[0], action_log_prob.numpy()[0]
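One thing worth noting about these actors (my own illustration, assuming untils.py is importable as in the imports above): both expect a batched, 2-D state such as shape (1, state_size), because ActorDiscrete.forward applies softmax(dim=1) and act() indexes probs[:, action]:

import torch
from Actor_Critic.A3C.untils import ActorDiscrete

actor = ActorDiscrete(state_size=4, action_size=2)   # e.g. CartPole dimensions
state = torch.rand(1, 4)                             # the batch dimension is required
action, prob = actor.act(state)
print(action, prob)                                  # an action index and its probability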
worker.py
import math
import torch.multiprocessing as mp
import torch
import torch.nn.functional as F
from torch.distributions import Normal
from Actor_Critic.A3C.untils import ValueNetwork,ActorDiscrete,ActorContinous
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
class Worker(mp.Process):
    def __init__(self, env, continuous, state_size, action_size, id,
                 global_valueNet, global_value_optimizer,
                 global_policyNet, global_policy_optimizer,
                 global_epi, global_epi_rew, rew_queue,
                 max_epi, gamma):
        super(Worker, self).__init__()
        # define env for individual worker
        self.env = env
        self.continuous = continuous
        self.name = str(id)
        self.env.seed(id)
        self.state_size = state_size
        self.action_size = action_size
        self.memory = []
        # passing global settings to worker
        self.global_valueNet, self.global_value_optimizer = global_valueNet, global_value_optimizer
        self.global_policyNet, self.global_policy_optimizer = global_policyNet, global_policy_optimizer
        self.global_epi, self.global_epi_rew = global_epi, global_epi_rew
        self.rew_queue = rew_queue
        self.max_epi = max_epi
        # self.batch_size = batch_size
        self.gamma = gamma
        # define local nets for individual worker
        self.local_policyNet = ActorDiscrete(self.state_size, self.action_size).to(device)
        if self.continuous:
            self.local_policyNet = ActorContinous(self.state_size, self.action_size).to(device)
        self.local_valueNet = ValueNetwork(self.state_size, 1).to(device)

    def sync_global(self):
        self.local_valueNet.load_state_dict(self.global_valueNet.state_dict())
        self.local_policyNet.load_state_dict(self.global_policyNet.state_dict())

    def calculate_loss(self):
        # get experiences from the current trajectory;
        # states were stored as (1, state_size) tensors in run(), so stack them into a batch
        states = torch.cat([t[0] for t in self.memory])
        # NOTE (from the original notes): memory used to store
        # [state, action, reward, next_state, done], so t[1] was the action rather than
        # a probability; run() below now stores the probability at index 1 instead
        log_probs = torch.tensor([t[1] for t in self.memory], dtype=torch.float)

        # -- calculate discounted future rewards for every time step
        rewards = [t[2] for t in self.memory]
        fur_Rewards = []
        for i in range(len(rewards)):
            discount = [self.gamma ** k for k in range(len(rewards) - i)]
            f_rewards = rewards[i:]
            fur_Rewards.append(sum(d * f for d, f in zip(discount, f_rewards)))
        fur_Rewards = torch.tensor(fur_Rewards, dtype=torch.float).view(-1, 1)

        # calculate loss for critic
        V = self.local_valueNet(states)
        value_loss = F.mse_loss(fur_Rewards, V)

        # compute entropy for the policy loss (exploration bonus);
        # note: this assumes the continuous Gaussian actor, a Categorical entropy
        # would be needed for the discrete actor
        (mu, sigma) = self.local_policyNet(states)
        dist = Normal(mu, sigma)
        entropy = 0.5 + 0.5 * math.log(2 * math.pi) + torch.log(dist.scale)

        # calculate loss for actor
        advantage = (fur_Rewards - V).detach()
        policy_loss = -advantage * log_probs
        policy_loss = (policy_loss - 0.005 * entropy).mean()
        return value_loss, policy_loss

    def update_global(self):
        value_loss, policy_loss = self.calculate_loss()

        self.global_value_optimizer.zero_grad()
        value_loss.backward()
        # propagate local gradients to global parameters
        for local_params, global_params in zip(self.local_valueNet.parameters(), self.global_valueNet.parameters()):
            global_params._grad = local_params._grad
        self.global_value_optimizer.step()

        self.global_policy_optimizer.zero_grad()
        # compute gradients on the local policy net
        policy_loss.backward()
        # propagate local gradients to global parameters
        for local_params, global_params in zip(self.local_policyNet.parameters(), self.global_policyNet.parameters()):
            # copy the gradients
            global_params._grad = local_params._grad
        # update the global parameters (weights and biases)
        self.global_policy_optimizer.step()

        self.memory = []  # clear trajectory

    def run(self):
        while self.global_epi.value < self.max_epi:
            state = self.env.reset()
            total_reward = 0
            while True:
                state = torch.from_numpy(state).float().unsqueeze(0).to(device)
                # the discrete actor returns the raw probability, the continuous actor the log prob
                action, prob = self.local_policyNet.act(state)
                next_state, reward, done, _ = self.env.step(action)
                # store prob (not the action) so that calculate_loss reads the right value
                self.memory.append([state, prob, reward, next_state, done])
                total_reward += reward
                state = next_state
                if done:
                    # record the global episode count and running episode reward
                    with self.global_epi.get_lock():
                        self.global_epi.value += 1
                    with self.global_epi_rew.get_lock():
                        if self.global_epi_rew.value == 0.:
                            self.global_epi_rew.value = total_reward
                        else:
                            # moving-average episode reward
                            self.global_epi_rew.value = self.global_epi_rew.value * 0.99 + total_reward * 0.01
                    self.rew_queue.put(self.global_epi_rew.value)
                    print("w{} | episode: {}\t , episode reward:{:.4} \t "
                          .format(self.name, self.global_epi.value, self.global_epi_rew.value))
                    break
            # update and sync with the global net when finishing an episode
            self.update_global()
            self.sync_global()
        self.rew_queue.put(None)
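For intuition about the discounted-return loop in Worker.calculate_loss(), here is a standalone sketch (mine, not part of the repo) that computes the same quantity with a single backward pass over the rewards:

# For rewards [1, 1, 1] and gamma = 0.9 this yields
# [1 + 0.9 + 0.81, 1 + 0.9, 1] ≈ [2.71, 1.9, 1.0].
def discounted_returns(rewards, gamma):
    returns = []
    running = 0.0
    for r in reversed(rewards):        # accumulate from the end of the episode
        running = r + gamma * running
        returns.append(running)
    return list(reversed(returns))

print(discounted_returns([1, 1, 1], 0.9))   # ≈ [2.71, 1.9, 1.0]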