倒立摆_Q-Learning算法_边做边学深度强化学习:PyTorch程序设计实践(4)

倒立摆_Q-Learning算法_边做边学深度强化学习:PyTorch程序设计实践(4)

0、相关系列文章

迷宫_随机实验_边做边学深度强化学习:PyTorch程序设计实践(1)
迷宫_Sarsa算法_边做边学深度强化学习:PyTorch程序设计实践(2)
迷宫_Q-Learning算法_边做边学深度强化学习:PyTorch程序设计实践(3)
倒立摆_DQN算法_边做边学深度强化学习:PyTorch程序设计实践(5)

1、Agent.py

import numpy as np
import Brain

# 倒立摆小推车对象
class Agent:
    def __init__(self, num_states, num_actions):
        # 为智能体创建大脑以作出决策
        self.brain = Brain.Brain(num_states, num_actions)
    
    # 更新Q函数
    def update_Q_function(self, observation, action, reward, observation_next):
        self.brain.update_Q_table(observation, action, reward, observation_next)
    
    # 确定下一个动作
    def get_action(self, observation, step):
        action = self.brain.decide_action(observation, step)
        return action

    def print_Q(self):
        self.brain.print_Q()

2、Brain.py

import numpy as np
import Brain

# 倒立摆小推车对象
class Agent:
    def __init__(self, num_states, num_actions):
        # 为智能体创建大脑以作出决策
        self.brain = Brain.Brain(num_states, num_actions)
    
    # 更新Q函数
    def update_Q_function(self, observation, action, reward, observation_next):
        self.brain.update_Q_table(observation, action, reward, observation_next)
    
    # 确定下一个动作
    def get_action(self, observation, step):
        action = self.brain.decide_action(observation, step)
        return action

    def print_Q(self):
        self.brain.print_Q()

3、Environment.py

import numpy as np
import matplotlib.pyplot as plt
import datetime
import gym
import Agent
import Val


# 参考URL http://nbviewer.jupyter.org/github/patrickmineault/xcorr-notebooks/blob/master/Render%20OpenAI%20gym%20as%20GIF.ipynb
from JSAnimation.IPython_display import display_animation
from matplotlib import animation
from IPython.display import display
from IPython.display import HTML


# 定义Environment类,如果连续10次站立195步或更多,则说明强化学习成功,然后再运行一次以保持成功后的动画
class Environment:
    def __init__(self):
        self.env = gym.make(Val.get_value('ENV')) # 设置要执行的任务
        num_states = self.env.observation_space.shape[0] # 获取任务状态的个数
        num_actions = self.env.action_space.n # 获取CartPole的动作数为2
        self.agent = Agent.Agent(num_states,num_actions) # 创建在环境中行动的Agent


    # 将运行状态保存为动画
    def display_frames_as_gif(self,frames):
        """
        Displays a list of frames as a gif, with controls
        """
        plt.figure(figsize=(frames[0].shape[1]/72.0, frames[0].shape[0]/72.0),
                dpi=72)
        patch = plt.imshow(frames[0])
        plt.axis('off')

        def animate(i):
            patch.set_data(frames[i])

        anim = animation.FuncAnimation(plt.gcf(), animate, frames=len(frames),
                                    interval=50)

        #anim.save('result/cartpole_QLearning.mp4')  # 保存动画
        anim.save('result/cartpole_QLearning'+ datetime.datetime.now().strftime('-%m-%d-%H-%M-%S') +'.gif',writer='pillow')
        #display(display_animation(anim, default_mode='loop'))


    '''
    observation, reward, done, info = env.step(action)是将游戏推进一步的指令  
    observation表示小车和杆的状态,包含小车位置、小车速度、杆的角度、杆的角速度  
    reward,是即时奖励  
    done, 在结束状态时为True  
    info,包含调试信息  
    '''
    def run(self):
        complete_episodes = 0 # 持续超过195步的实验次数
        is_episode_final = False # 最终试验的标志
        frames = [] # 用于存储视频图像的变量

        for episode in range(Val.get_value('NUM_EPISODES')): # 试验的最大重复次数
            observation = self.env.reset() # 环境初始化

            for step in range(Val.get_value('MAX_STEPS')): # 每个回合的循环
                if is_episode_final is True:
                    frames.append(self.env.render(mode='rgb_array'))

                # 求取动作
                action = self.agent.get_action(observation,episode)
                # 通过执行动作a_t 找到 s_{t+1},r_{t+1}
                observation_next,_,done,_ = self.env.step(action)

                # 给予奖励
                if  done:
                    # 如果步数超过200,或者如果倾斜超过某个角度,则done为True
                    if step < Val.get_value('NUM_KEEP_TIMES'):
                        reward = -1 # 如果半途摔倒,给予奖励 -1 作为惩罚
                        complete_episodes = 0 # 站立超过195步,重置试验次数
                    else:
                        reward = 1 # 一直站立到结束时给予奖励 1
                        complete_episodes +=1 # 更新连续记录
                else:
                    reward = 0 # 途中奖励为 0
                
                # 使用 step_1 的状态 observation_next 更新Q函数
                self.agent.update_Q_function(observation,action,reward,observation_next)

                observation = observation_next

                if done:
                    print('{0} Episode:Finished after {1} time steps'.format(episode,step + 1))
                    break

            # 在最后一次试验中保存并绘制动画
            if is_episode_final is True:
                self.agent.print_Q()
                self.display_frames_as_gif(frames)
                break

            if complete_episodes >= 10:
                print('10回合连续成功')
                is_episode_final = True

4、Val.py

实现全局变量

# _*_ coding:utf-8 _*_

'''
在main中,
import Val

#使用一下命令初始化
Val._init()

'''

def _init():
    global _global_dict
    _global_dict = {}

def set_value(key,value):
    _global_dict[key] = value

def get_value(key,defValue=None):
    try:
        return _global_dict[key]
    except KeyError:
        return -1

5、main.py

# 导入所使用的包

import Environment
import Val

if __name__ == '__main__':
    Val._init()

    # 定义常量
    Val.set_value('ENV','CartPole-v0')# 要使用的任务名称
    Val.set_value('NUM_DIZITIZED',6)# 将每个状态划分为离散值的个数
    Val.set_value('GAMMA',0.99) # 时间折扣率
    Val.set_value('ETA',0.5)# 学习系数
    Val.set_value('NUM_KEEP_TIMES',195)# 站立保持次数,超过即为成功
    Val.set_value('MAX_STEPS',200)# 一次试验的步数
    Val.set_value('NUM_EPISODES',1000)# 最大试验次数

    cartpole_env = Environment.Environment()
    cartpole_env.run()

6、最终结果

在这里插入图片描述

7、代码下载

跳转到下载地址

8、参考资料

[1]边做边学深度强化学习:PyTorch程序设计实践

  • 0
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

sethnieTech

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值