强化学习时序差分算法之Sarsa算法——以悬崖漫步环境为例

1.导入必要的库环境,代码如下所示。

import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm

2.本悬崖漫步环境中无需提供奖励函数以及状态转移函数,而需提供一个与智能体进行交互的step()函数,该函数输入为智能体当前状态下的动作,输出为当前状态下的奖励以及智能体的下一状态,代码如下所示。

class CliffWalkingEnv:
    def __init__(self,ncol,nrow,step_reward,cliff_reward):
        self.ncol=ncol
        self.nrow=nrow
        self. x=0#记录当前智能体位置的横坐标
        self.y=self.nrow-1#记录当前智能体位置的纵坐标
        self.step_reward=step_reward
        self.cliff_reward=cliff_reward
    def step(self,action):#外部调用此函数改变当前位置
        change=[[0,-1],[0,1],[-1,0],[1,0]]#定义四个动作,change[0]:上;change[1]:下;change[2]:左;change[3]:右;坐标系原点(0,0)定义在左上角
        self.x=np.clip(self.x+change[action][0],0,self.ncol-1)#也可采用self.x=min(self.ncol-1,max(0,self.x+change[action][0]))
        self.y=np.clip(self.y+change[action][1],0,self.nrow-1)#也可采用self.y=min(self.nrow-1,max(0,self.y+change[action][1]))
        next_state=self.ncol*self.y+self.x#计算下一个状态
        reward=self.step_reward
        done=False
        if self.y==self.nrow-1 and self.x>0:#如果当前位置在悬崖或者终点
            done=True
            if self.x!=self.ncol-1:#如果在悬崖
                reward=self.cliff_reward
        return next_state,reward,done
    def reset(self):#环境重置
        self.x=0
        self.y=self.nrow-1
        return self.y*self.ncol+self.x

3.实现Sarsa算法,维护一个Q_table()表格,本表格主要存储当前策略下所有状态动作对的价值,即所有状态下各个动作的动作价值函数,Sarsa算法与环境进行交互,采用\varepsilon-贪婪策略进行采样,使用时序差分算法进行Sarsa算法更新。本程序默认终止状态时所有动作的价值为0,即终止状态的动作价值函数均为0,其在初始化为0后不会在与环境交互过程中进行更新。具体代码如下所示。

class Sarsa:
    """ Sarsa算法 """
    def __init__(self,ncol,nrow,epsilon,alpha,gamma,n_action=4):
        self.ncol=ncol
        self.nrow=nrow
        self.epsilon=epsilon
        self.alpha=alpha
        self.gamma=gamma
        self.n_action=n_action
        self.Q_table=np.zeros([self.ncol*self.nrow,self.n_action])#初始化Q(s,a)表格
    def take_action(self,state):
        if np.random.random()<self.epsilon:#如果小于epsilon,则随机选择动作
            action=np.random.randint(self.n_action)
        else:
            action=np.argmax(self.Q_table[state])
        return action
    def best_action(self,state):
        Q_max=np.max(self.Q_table[state])
        a=[0]*self.n_action#or a=[0 for _ in range(self.n_action)]
        for i in range(self.n_action):#若两动作价值一样为最大,则会记录下来
            if self.Q_table[state,i]==Q_max:
                a[i]=1
        return a
    def update(self,s0,a0,r,s1,a1):#Sarsa算法核心部分,采用时序差分算法估计动作价值函数Q,Q(s0,a0)~Q(s0,a0)+alpha(r0+gamma*Q(s1,a1)-Q(s0,a0))
        TD_error=r+self.gamma*self.Q_table[s1,a1]-self.Q_table[s0,a0]
        self.Q_table[s0,a0]+=self.alpha*TD_error

4.本案例悬崖漫步环境中相关参数设置如下所示,可以自行修改。

ncol=12#悬崖漫步环境中的网格环境列数
nrow=4#悬崖漫步环境中的网格环境行数
step_reward=-1#每步的即时奖励
cliff_reward=-100#悬崖的即时奖励
epsilon=0.1#epsilon-贪婪算法的探索因子
alpha=0.1 #价值估计更新的步长
gamma=0.9 #回报计算的折扣衰减因子
n_action=4 #动作个数

5.主程序实现部分如下所示。

env=CliffWalkingEnv(ncol,nrow,step_reward,cliff_reward)
agent=Sarsa(ncol,nrow,epsilon,alpha,gamma,n_action)
num_episode=500#智能体在环境中运行的序列的数量
episode_Gt_list=[]#记录每个序列的回报
pbar_num=10#进度条的数量
for i in range(pbar_num):#显示每个进度条
    with tqdm(total=num_episode/pbar_num,desc='Episode %d'%i) as pbar:
        for episode in range(int(num_episode/pbar_num)):#每个进度条的序列数量
            number=1#记录每个序列的步数
            episode_Gt=0
            state=env.reset()
            action=agent.take_action(state)
            done=False
            while not done:
                next_state,reward,done=env.step(action)
                next_action=agent.take_action(next_state)
                episode_Gt+=reward #回报的计算不进行折扣因子衰减,考虑远期最优。
                agent.update(state,action,reward,next_state,next_action)
                state=next_state
                action=next_action
                number+=1
            episode_Gt_list.append(episode_Gt)
            if (episode+1)%10==0:#每10条序列打印一下这十条序列的平均回报
                pbar.set_postfix({'episode':'%d' %((num_episode/pbar_num)*i+episode+1),'return':'%.3f'%np.mean(episode_Gt_list[-10:],axis=None)})
            pbar.update(1)
    print('\n')
episodes_list=list(range(len(episode_Gt_list)))
plt.plot(episodes_list,episode_Gt_list)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title('Sarsa on {}'.format('Cliff Walking'))#or "plt.title('Sarsa on %s'%'Cliff Walking') "
plt.show()

6.结果如图所示。

Episode 0: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50.0 [00:00<00:00, 276.35it/s, episode=50, return=-119.100]


Episode 1: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50.0 [00:00<00:00, 345.86it/s, episode=100, return=-73.600]


Episode 2: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50.0 [00:00<00:00, 421.77it/s, episode=150, return=-33.200]


Episode 3: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50.0 [00:00<00:00, 586.43it/s, episode=200, return=-38.700]


Episode 4: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50.0 [00:00<00:00, 668.64it/s, episode=250, return=-27.600]


Episode 5: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50.0 [00:00<00:00, 787.75it/s, episode=300, return=-19.800]


Episode 6: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50.0 [00:00<00:00, 891.04it/s, episode=350, return=-19.200] 


Episode 7: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50.0 [00:00<00:00, 875.73it/s, episode=400, return=-21.100] 


Episode 8: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50.0 [00:00<00:00, 844.25it/s, episode=450, return=-29.200] 


Episode 9: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50.0 [00:00<00:00, 904.60it/s, episode=500, return=-17.800] 
 

 我们可以发现随着训练的进行,Sarsa算法获得的回报越来越高,在进行500条序列的学习后,可以获得-20左右的回报,此时非常接近最优策略了。

7.如下所示的程序是实现查看Sarsa算法得到的策略在各状态下会使智能体采取何种动作的功能。

def print_agent(agent,env,action_meaning,disaster=[],end=[]):
    for i in range(env.nrow):
        for j in range(env.ncol):
            if (i*env.ncol+j) in disaster:
                print('****',end=' ')
            elif (i*env.ncol+j) in end:
                print('EEEE',end=' ')
            else:
                a=agent.best_action(i*env.ncol+j)
                pi_str=''
                for k in range(len(action_meaning)):
                    pi_str+=action_meaning[k] if a[k]>0 else 'o'
                print(pi_str,end=' ')
        print()
action_meaning=['^','v','<','>']
print('Sarsa算法最终收敛得到的策略为:')
print_agent(agent,env,action_meaning,disaster=[range(37,47)],end=[47])

8.结果如下所示:

ooo> ooo> ooo> ooo> ooo> ooo> ooo> ooo> ooo> ooo> ooo> ovoo 
^ooo ^ooo ooo> ooo> ooo> ooo> ooo> ooo> ooo> ovoo ooo> ovoo
^ooo oo<o oo<o ^ooo ^ooo oo<o ^ooo ooo> oo<o ooo> ooo> ovoo
^ooo **** **** **** **** **** **** **** **** **** **** EEEE

可以发现Sarsa算法会采取比较远离悬崖的策略来抵达目标。 

  • 10
    点赞
  • 11
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值