Reinforcement Learning: A2C Code for Solving the Cart-Pole (Inverted Pendulum) Problem

1. Problem Background

The background of the cart-pole (inverted pendulum) problem is not repeated here. The implementation uses Python's gym library, relying on the classic Gym API in which reset() returns the state and step() returns a 4-tuple. The code for loading the environment is as follows:

# Create the cart-pole environment
env = gym.make("CartPole-v0")
env.reset()
print("env_state:{}".format(env.state))
print("env_step(0):{}".format(env.step(0)))

The libraries that need to be imported beforehand are:

import numpy as np
import torch
# torch submodules used below
import torch.nn as nn
from torch.nn import functional as F
from torch.distributions import Categorical
import gym
import warnings
warnings.filterwarnings('ignore')
import random
from collections import deque
import matplotlib.pyplot as plt

2. A2C Training Principle
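
Before the code, here is a brief summary of the one-step actor-critic (A2C) update that the class below implements. After taking action $a$ in state $s$ and observing reward $r$ and next state $s'$, the TD target and advantage are

$$U = r + \gamma\,V_w(s')\,(1-\text{done}), \qquad A = U - V_w(s),$$

and the two loss terms optimized jointly by Adam are

$$L_\theta = -A\,\ln\pi_\theta(a\mid s), \qquad L_w = \tfrac{1}{2}A^2,$$

where $\pi_\theta$ is the policy (actor) head and $V_w$ the state-value (critic) head of a shared network; $A$ is treated as a constant in $L_\theta$ so that the policy gradient does not flow into the critic.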

class policyNet(nn.Module):
    # Shared-body actor-critic network: one hidden layer feeding
    # a policy (actor) head and a state-value (critic) head.
    def __init__(self, state_num, action_num, hidden_num=40, lr=0.01):
        super(policyNet, self).__init__()
        self.fc1 = nn.Linear(state_num, hidden_num)
        self.action_fc = nn.Linear(hidden_num, action_num)  # policy head
        self.state_fc = nn.Linear(hidden_num, 1)             # value head
        self.optimizer = torch.optim.Adam(self.parameters(), lr=lr)

    # Forward pass: returns pi(.|s) and V(s)
    def forward(self, state):
        hidden_output = self.fc1(state)
        x = F.relu(hidden_output)
        action_prob = self.action_fc(x)
        action_prob = F.softmax(action_prob, dim=-1)
        state_values = self.state_fc(x)
        return action_prob, state_values

    # Sample an action from the current policy
    def select_action(self, state):
        # state comes in as a numpy array, so convert it to a tensor first
        # returns the sampled action and ln pi(a|s)
        state = torch.from_numpy(state).float()
        probs, state_value = self.forward(state)
        m = Categorical(probs)
        action = m.sample()
        # log-probability of the sampled action, kept for the policy-gradient term
        log_pi_as = m.log_prob(action)
        return action.item(), log_pi_as

    # One environment step followed by one actor-critic update
    def update_policy(self, env, state, gamma=0.9):
        # state / next_state are numpy arrays
        action, ln_pi_a_s = self.select_action(state)
        next_state, reward, done, _ = env.step(action)
        # bootstrap value of the next state; no gradient flows through the target
        with torch.no_grad():
            _, next_state_values = self.forward(torch.from_numpy(next_state).float())
        _, state_values = self.forward(torch.from_numpy(state).float())
        # TD target U = r + gamma * V(s'), zeroed when the episode terminates
        U = reward + gamma * next_state_values * (1 - done)
        advantage = U - state_values
        # actor loss: -advantage * ln pi(a|s); the advantage is detached so the
        # policy gradient does not back-propagate into the critic
        loss_theta = -advantage.detach() * ln_pi_a_s
        # critic loss: squared TD error (semi-gradient, U is a fixed target)
        loss_w = 0.5 * advantage ** 2
        self.optimizer.zero_grad()
        loss = loss_theta + loss_w
        loss.backward()
        self.optimizer.step()
        return next_state, reward, done

    # Run the given number of episodes, updating after every environment step
    def train_network(self, env, episodes=100, gamma=0.9):
        episode_array = []
        mean_array = []
        for episode in range(episodes):
            episode_reward = 0.0
            state = env.reset()
            while True:
                next_state, reward, done = self.update_policy(env, state, gamma=gamma)
                episode_reward += reward
                if done:
                    break
                state = next_state
            episode_array.append(episode_reward)
            mean_array.append(np.mean(episode_array))
            print("Episode {}: reward = {}, mean reward = {}".format(
                episode, episode_reward, int(mean_array[-1])))
        return episode_array, mean_array

3. Training the Network in the Main Function

policynet = policyNet(state_num=4, action_num=2)
episode_array, mean_array = policynet.train_network(env, episodes=3000, gamma=0.9)
plt.plot(episode_array)
plt.plot(mean_array)

The results:
With episodes=3000, the network already reaches the maximum per-episode reward after a little over 2000 episodes:
[Figure: per-episode reward, episodes=3000]
For comparison, the rewards in the early stage of training look like this:
[Figure: per-episode reward during early training]
The convergence curve (running mean reward) is as follows:
[Figure: convergence curve of the mean reward]
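
As a final check (not part of the original post), the trained policy can be rolled out greedily for a few episodes to confirm that it balances the pole. A minimal sketch, reusing the policynet and env objects defined above and the classic Gym step API:

# Roll out the trained policy greedily for a few evaluation episodes.
# Assumption: policynet and env are the objects created in the sections above.
for ep in range(5):
    state = env.reset()
    total_reward = 0.0
    done = False
    while not done:
        with torch.no_grad():
            probs, _ = policynet(torch.from_numpy(state).float())
        action = probs.argmax().item()  # pick the most probable action instead of sampling
        state, reward, done, _ = env.step(action)
        total_reward += reward
    print("evaluation episode {}: reward = {}".format(ep, total_reward))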
