【Mushroom Book】【A2C】

import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical
import numpy as np
from multiprocessing import Process, Pipe
import argparse
import gym

Build the Actor and Critic networks

class ActorCritic(nn.Module):
    ''' A2C network: a policy head (actor) and a state-value head (critic)
    '''
    def __init__(self, input_dim, output_dim, hidden_dim):
        super(ActorCritic, self).__init__()
        self.critic = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )
        
        self.actor = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
            nn.Softmax(dim=1),
        )
        
    def forward(self, x):
        value = self.critic(x)
        probs = self.actor(x)
        dist  = Categorical(probs)
        return dist, value
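
Before wiring the network into the algorithm, a quick shape check makes the forward pass concrete. This is a small sketch with assumed CartPole-v0 dimensions (4 state features, 2 actions); the batch size of 8 is arbitrary:

# Sanity-check the ActorCritic forward pass (assumed dims: CartPole-v0 has 4 states, 2 actions)
model = ActorCritic(input_dim=4, output_dim=2, hidden_dim=256)
states = torch.randn(8, 4)              # a batch of 8 observations
dist, value = model(states)             # Categorical over 2 actions, plus a state value
action = dist.sample()                  # shape: (8,)
print(action.shape, dist.log_prob(action).shape, value.shape)
# torch.Size([8]) torch.Size([8]) torch.Size([8, 1])
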
class A2C:
    ''' A2C agent: wraps the ActorCritic model, its optimizer, and return computation
    '''
    def __init__(self, n_states, n_actions, cfg) -> None:
        self.gamma = cfg.gamma
        self.device = cfg.device
        self.model = ActorCritic(n_states, n_actions, cfg.hidden_dim).to(self.device)  # cfg defines hidden_dim
        self.optimizer = optim.Adam(self.model.parameters(), lr=cfg.lr)

    def compute_returns(self,next_value, rewards, masks):
        R = next_value
        returns = []
        for step in reversed(range(len(rewards))):
            R = rewards[step] + self.gamma * R * masks[step]
            returns.insert(0, R)
        return returns
def make_envs(env_name):
    ''' Return a thunk that builds one environment instance (one per worker process) '''
    def _thunk():
        env = gym.make(env_name)
        env.seed(2)
        return env
    return _thunk
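
The thunks returned by make_envs are consumed below by SubprocVecEnv from common.multiprocessing_env, whose source is not included in this post. A minimal sketch of such a wrapper (an assumption, modelled on the usual Process/Pipe pattern, which is why those are imported at the top of the script along with numpy) could look like this; the class name MiniSubprocVecEnv is hypothetical:

def _vec_env_worker(remote, parent_remote, env_fn):
    # Runs in a child process: owns one env and serves step/reset commands over a Pipe.
    parent_remote.close()
    env = env_fn()
    while True:
        cmd, data = remote.recv()
        if cmd == 'step':
            obs, reward, done, info = env.step(data)
            if done:
                obs = env.reset()          # auto-reset so rollouts never stall
            remote.send((obs, reward, done, info))
        elif cmd == 'reset':
            remote.send(env.reset())
        elif cmd == 'close':
            remote.close()
            break

class MiniSubprocVecEnv:
    ''' Hypothetical stand-in for common.multiprocessing_env.SubprocVecEnv (sketch only) '''
    def __init__(self, env_fns):
        self.remotes, work_remotes = zip(*[Pipe() for _ in env_fns])
        self.ps = [Process(target=_vec_env_worker, args=(wr, r, fn), daemon=True)
                   for wr, r, fn in zip(work_remotes, self.remotes, env_fns)]
        for p in self.ps:
            p.start()
        for wr in work_remotes:
            wr.close()
        probe = env_fns[0]()               # probe one env for its spaces (simplification)
        self.observation_space = probe.observation_space
        self.action_space = probe.action_space
        probe.close()

    def reset(self):
        for r in self.remotes:
            r.send(('reset', None))
        return np.stack([r.recv() for r in self.remotes])   # (n_envs, obs_dim)

    def step(self, actions):
        for r, a in zip(self.remotes, actions):
            r.send(('step', a))
        obs, rews, dones, infos = zip(*[r.recv() for r in self.remotes])
        return np.stack(obs), np.stack(rews), np.stack(dones), infos
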
def test_env(env, model, vis=False):
    ''' Run one evaluation episode with the current policy (relies on the global cfg for the device) '''
    state = env.reset()
    if vis: env.render()
    done = False
    total_reward = 0
    while not done:
        state = torch.FloatTensor(state).unsqueeze(0).to(cfg.device)
        dist, _ = model(state)
        next_state, reward, done, _ = env.step(dist.sample().cpu().numpy()[0])
        state = next_state
        if vis: env.render()
        total_reward += reward
    return total_reward

def compute_returns(next_value, rewards, masks, gamma=0.99):
    R = next_value
    returns = []
    for step in reversed(range(len(rewards))):
        R = rewards[step] + gamma * R * masks[step]
        returns.insert(0, R)
    return returns
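
Before compute_returns is used inside train, a quick hand-check with made-up numbers shows how the mask cuts off bootstrapping at an episode boundary:

# Toy hand-check of the n-step bootstrapped return (hypothetical numbers):
# rewards = [1, 1, 1], masks = [1, 1, 0] (episode ends on the last step), next_value = 5.0
# R_2 = 1 + 0.99 * 5.0 * 0 = 1.0      (done: the bootstrap value is masked out)
# R_1 = 1 + 0.99 * 1.0     = 1.99
# R_0 = 1 + 0.99 * 1.99    = 2.9701
print([round(r, 4) for r in compute_returns(next_value=5.0, rewards=[1, 1, 1], masks=[1, 1, 0])])
# [2.9701, 1.99, 1.0]
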


def train(cfg,envs):
    print('Start training!')
    print(f'Env:{cfg.env_name}, Algorithm:{cfg.algo_name}, Device:{cfg.device}')
    env = gym.make(cfg.env_name) # a single env
    env.seed(10)
    n_states  = envs.observation_space.shape[0]
    n_actions = envs.action_space.n
    model = ActorCritic(n_states, n_actions, cfg.hidden_dim).to(cfg.device)
    optimizer = optim.Adam(model.parameters(), lr=cfg.lr)
    step_idx    = 0
    test_rewards = []
    test_ma_rewards = []
    state = envs.reset()
    while step_idx < cfg.max_steps:
        log_probs = []
        values    = []
        rewards   = []
        masks     = []
        entropy = 0
        # rollout trajectory
        for _ in range(cfg.n_steps):
            state = torch.FloatTensor(state).to(cfg.device)
            dist, value = model(state)
            action = dist.sample()
            next_state, reward, done, _ = envs.step(action.cpu().numpy())
            log_prob = dist.log_prob(action).unsqueeze(1)  # shape (n_envs, 1), so it lines up with values/returns after torch.cat
            entropy += dist.entropy().mean()
            log_probs.append(log_prob)
            values.append(value)
            rewards.append(torch.FloatTensor(reward).unsqueeze(1).to(cfg.device))
            masks.append(torch.FloatTensor(1 - done).unsqueeze(1).to(cfg.device))
            state = next_state
            step_idx += 1
            if step_idx % 200 == 0:
                test_reward = np.mean([test_env(env,model) for _ in range(10)])
                print(f"step_idx:{step_idx}, test_reward:{test_reward}")
                test_rewards.append(test_reward)
                if test_ma_rewards:
                    test_ma_rewards.append(0.9*test_ma_rewards[-1]+0.1*test_reward)
                else:
                    test_ma_rewards.append(test_reward) 
                # plot(step_idx, test_rewards)   
        next_state = torch.FloatTensor(next_state).to(cfg.device)
        _, next_value = model(next_state)
        returns = compute_returns(next_value, rewards, masks, gamma=cfg.gamma)
        log_probs = torch.cat(log_probs)
        returns   = torch.cat(returns).detach()
        values    = torch.cat(values)
        advantage = returns - values
        actor_loss  = -(log_probs * advantage.detach()).mean()
        critic_loss = advantage.pow(2).mean()
        loss = actor_loss + 0.5 * critic_loss - 0.001 * entropy
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print('Finish training!')
    return test_rewards, test_ma_rewards
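
Two details in the update above are easy to miss: returns is detached so the bootstrapped target acts as a constant for the critic, and advantage.detach() stops the policy-gradient term from also updating the critic. A toy check with hypothetical numbers (same 0.5 / 0.001 loss weights as in train) shows how the three terms combine:

# Toy tensors (hypothetical values) mirroring the loss assembled in train():
log_probs = torch.tensor([[-0.7], [-0.3]])              # log pi(a|s), shape (N, 1)
values    = torch.tensor([[1.0], [2.0]], requires_grad=True)
returns   = torch.tensor([[1.5], [1.8]])                # detached n-step targets
entropy   = torch.tensor(1.2)                           # toy entropy bonus
advantage = returns - values                            # [0.5, -0.2]
actor_loss  = -(log_probs * advantage.detach()).mean()  # -((-0.7*0.5) + (-0.3*-0.2)) / 2 = 0.145
critic_loss = advantage.pow(2).mean()                   # (0.25 + 0.04) / 2 = 0.145
loss = actor_loss + 0.5 * critic_loss - 0.001 * entropy
print(round(loss.item(), 4))                            # ≈ 0.2163
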
import matplotlib.pyplot as plt
import seaborn as sns 
def plot_rewards(rewards, ma_rewards, cfg, tag='train'):
    sns.set()
    plt.figure()  # create a new figure so several plots can be drawn separately
    plt.title("learning curve on {} of {} for {}".format(
        cfg.device, cfg.algo_name, cfg.env_name))
    plt.xlabel('evaluation (every 200 env steps)')
    plt.plot(rewards, label='rewards')
    plt.plot(ma_rewards, label='ma rewards')
    plt.legend()
    plt.show()
import easydict
from common.multiprocessing_env import SubprocVecEnv
cfg = easydict.EasyDict({
        "algo_name": 'A2C',
        "env_name": 'CartPole-v0',
        "n_envs": 8,
        "max_steps": 20000,
        "n_steps":5,
        "gamma":0.99,
        "lr": 1e-3,
        "hidden_dim": 256,
        "device":torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")
})
envs = [make_envs(cfg.env_name) for i in range(cfg.n_envs)]
envs = SubprocVecEnv(envs) 
rewards,ma_rewards = train(cfg,envs)
plot_rewards(rewards, ma_rewards, cfg, tag="train") # 画出结果
Start training!
Env:CartPole-v0, Algorithm:A2C, Device:cuda
step_idx:200, test_reward:18.6
step_idx:400, test_reward:19.7
step_idx:600, test_reward:24.2
step_idx:800, test_reward:19.5
step_idx:1000, test_reward:33.9
step_idx:1200, test_reward:36.1
step_idx:1400, test_reward:32.6
step_idx:1600, test_reward:36.3
step_idx:1800, test_reward:38.9
step_idx:2000, test_reward:60.8
step_idx:2200, test_reward:41.9
step_idx:2400, test_reward:42.2
step_idx:2600, test_reward:71.6
step_idx:2800, test_reward:123.6
step_idx:3000, test_reward:57.5
step_idx:3200, test_reward:155.4
step_idx:3400, test_reward:111.4
step_idx:3600, test_reward:133.8
step_idx:3800, test_reward:133.8
step_idx:4000, test_reward:114.3
step_idx:4200, test_reward:165.5
step_idx:4400, test_reward:119.4
step_idx:4600, test_reward:173.4
step_idx:4800, test_reward:115.4
step_idx:5000, test_reward:159.7
step_idx:5200, test_reward:178.1
step_idx:5400, test_reward:137.8
step_idx:5600, test_reward:146.0
step_idx:5800, test_reward:187.4
step_idx:6000, test_reward:200.0
step_idx:6200, test_reward:169.2
step_idx:6400, test_reward:167.8
step_idx:6600, test_reward:184.3
step_idx:6800, test_reward:162.3
step_idx:7000, test_reward:125.4
step_idx:7200, test_reward:150.6
step_idx:7400, test_reward:152.6
step_idx:7600, test_reward:122.5
step_idx:7800, test_reward:136.3
step_idx:8000, test_reward:131.4
step_idx:8200, test_reward:174.6
step_idx:8400, test_reward:91.7
step_idx:8600, test_reward:170.1
step_idx:8800, test_reward:166.0
step_idx:9000, test_reward:150.2
step_idx:9200, test_reward:104.6
step_idx:9400, test_reward:147.2
step_idx:9600, test_reward:111.8
step_idx:9800, test_reward:118.7
step_idx:10000, test_reward:102.6
step_idx:10200, test_reward:99.0
step_idx:10400, test_reward:64.6
step_idx:10600, test_reward:133.7
step_idx:10800, test_reward:119.7
step_idx:11000, test_reward:112.6
step_idx:11200, test_reward:116.1
step_idx:11400, test_reward:116.3
step_idx:11600, test_reward:116.2
step_idx:11800, test_reward:115.3
step_idx:12000, test_reward:109.7
step_idx:12200, test_reward:110.3
step_idx:12400, test_reward:131.4
step_idx:12600, test_reward:128.3
step_idx:12800, test_reward:128.8
step_idx:13000, test_reward:119.8
step_idx:13200, test_reward:108.6
step_idx:13400, test_reward:128.4
step_idx:13600, test_reward:138.2
step_idx:13800, test_reward:119.1
step_idx:14000, test_reward:140.7
step_idx:14200, test_reward:145.3
step_idx:14400, test_reward:154.1
step_idx:14600, test_reward:165.2
step_idx:14800, test_reward:138.2
step_idx:15000, test_reward:143.5
step_idx:15200, test_reward:125.4
step_idx:15400, test_reward:137.1
step_idx:15600, test_reward:150.1
step_idx:15800, test_reward:132.9
step_idx:16000, test_reward:140.4
step_idx:16200, test_reward:141.3
step_idx:16400, test_reward:135.5
step_idx:16600, test_reward:135.5
step_idx:16800, test_reward:125.6
step_idx:17000, test_reward:126.8
step_idx:17200, test_reward:124.7
step_idx:17400, test_reward:129.6
step_idx:17600, test_reward:114.3
step_idx:17800, test_reward:57.3
step_idx:18000, test_reward:164.7
step_idx:18200, test_reward:165.8
step_idx:18400, test_reward:196.7
step_idx:18600, test_reward:198.8
step_idx:18800, test_reward:200.0
step_idx:19000, test_reward:199.6
step_idx:19200, test_reward:189.5
step_idx:19400, test_reward:177.9
step_idx:19600, test_reward:159.3
step_idx:19800, test_reward:127.7
step_idx:20000, test_reward:143.6
Finish training!
