Sarsa

Sarsa

sarsa是强化学习中的一种Model-free的on-policy控制方法,更新公式如下:
在这里插入图片描述

agent.py

import numpy as np

class SarsaAgent(object):
    def __init__(self,obs_n,act_n,learning_rate=0.01,gamma=0.9,e_greed=0.1):
        self.act_n=act_n #动作维度,有几个动作选择
        self.lr=learning_rate #学习率
        self.gamma=gamma #reward的衰减率
        self.epsilon=e_greed #按一定的概率随机选择动作
        self.Q=np.zeros((obs_n,act_n)) #创建obs_nXact_n的Q表
    # 根据输入观察值,采样输出的动作值,带探索
    def sample(self,obs):
        if np.random.uniform(0,1)<(1.0-self.epsilon):
            action=self.predict(obs)
        else:
            action=np.random.choice(self.act_n)
        return action
    #根据输入观察值,预测输出的动作值
    def predict(self,obs):
        Q_list=self.Q[obs,:]
        maxQ=np.max(Q_list)
        action_list=np.where(Q_list==maxQ)[0]#maxQ可能对应多个action
        action=np.random.choice(action_list)
        return action

    def learn(self,obs,action,reward,next_obs,next_action,done):
        """ on-policy
                   obs: 交互前的obs, s_t
                   action: 本次交互选择的action, a_t
                   reward: 本次动作获得的奖励r
                   next_obs: 本次交互后的obs, s_t+1
                   next_action: 根据当前Q表格, 针对next_obs会选择的动作, a_t+1
                   done: episode是否结束
        """
        predict_Q=self.Q[obs,action]
        if done:
            target_Q=reward
        else:
            target_Q=reward+self.gamma*self.Q[next_obs,next_action]#sarsa

        self.Q[obs,action]+=self.lr*(target_Q-predict_Q) #修正q

    # 保存Q表格数据到文件
    def save(self):
        npy_file = './q_table.npy'
        np.save(npy_file, self.Q)
        print(npy_file + ' saved.')

    # 从文件中读取Q值到Q表格中
    def restore(self, npy_file='./q_table.npy'):
        self.Q = np.load(npy_file)
        print(npy_file + ' loaded.')



train.py

import gym
from gridworld import CliffWalkingWapper,FrozenLakeWapper
from agent import SarsaAgent
import time

def run_episode(env,agent,render=False):
    total_steps=0 #记录每个episode走了多少step
    total_reward=0

    obs=env.reset() #重置环境, 重新开一局(即开始新的一个episode)
    action=agent.sample(obs) #根据算法选择一个动作

    while True:
       next_obs,reward,done,_=env.step(action) #与环境进行一个交互
       next_action=agent.sample(next_obs) #根据算法选择一个动作
       #训练Sarsa算法
       agent.learn(obs,action,reward,next_obs,next_action,done)

       action=next_action
       obs=next_obs #储存上一个观察值
       total_reward+=reward
       total_steps+=1
       if render:
           env.render()
       if done:
           break
    return total_reward,total_steps

def test_episode(env,agent):
    total_reward=0
    obs=env.reset()
    while True:
        action=agent.predict(obs) #greedy
        next_obs,reward,done,_=env.step(action)
        total_reward+=reward
        obs=next_obs
        time.sleep(0.5)
        env.render()
        if done:
            print('test reward=%.1f'%(total_reward))
            break

def main():

    env=env = gym.make("CliffWalking-v0")  # 0 up, 1 right, 2 down, 3 left
    env=CliffWalkingWapper(env)

    agent=SarsaAgent(obs_n=env.observation_space.n,act_n=env.action_space.n,
                     learning_rate=0.1,gamma=0.9,e_greed=0.1)
    is_render=False

    for episode in range(500):
        ep_reward,ep_steps=run_episode(env,agent,is_render)
        print('Episode %s: steps = %s , reward = %.1f' % (episode, ep_steps,ep_reward))

        #每隔20个episode渲染一下看看效果
        if episode%20==0:
            is_render=True
        else:
            is_render=False

    test_episode(env,agent)
if __name__=="__main__":
    main()
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值