Sarsa
Sarsa is a model-free, on-policy control method in reinforcement learning. Its update formula is as follows:
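The update rule that the code below implements is the standard Sarsa target (written here in LaTeX):

$$Q(S_t, A_t) \leftarrow Q(S_t, A_t) + \alpha \left[ R_{t+1} + \gamma\, Q(S_{t+1}, A_{t+1}) - Q(S_t, A_t) \right]$$

where $\alpha$ is the learning rate and $\gamma$ is the reward discount factor. In the learn() method below, target_Q corresponds to $R_{t+1} + \gamma Q(S_{t+1}, A_{t+1})$ and predict_Q to $Q(S_t, A_t)$.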
agent.py
import numpy as np


class SarsaAgent(object):
    def __init__(self, obs_n, act_n, learning_rate=0.01, gamma=0.9, e_greed=0.1):
        self.act_n = act_n  # number of available actions
        self.lr = learning_rate  # learning rate
        self.gamma = gamma  # discount factor for rewards
        self.epsilon = e_greed  # probability of picking a random action
        self.Q = np.zeros((obs_n, act_n))  # Q-table of shape obs_n x act_n

    # Sample an action for the given observation, with exploration
    def sample(self, obs):
        if np.random.uniform(0, 1) < (1.0 - self.epsilon):
            action = self.predict(obs)  # exploit: take the greedy action
        else:
            action = np.random.choice(self.act_n)  # explore: take a random action
        return action

    # Predict the greedy action for the given observation
    def predict(self, obs):
        Q_list = self.Q[obs, :]
        maxQ = np.max(Q_list)
        action_list = np.where(Q_list == maxQ)[0]  # maxQ may correspond to several actions
        action = np.random.choice(action_list)  # break ties randomly
        return action

    def learn(self, obs, action, reward, next_obs, next_action, done):
        """ on-policy
        obs: observation before the interaction, s_t
        action: action chosen for this interaction, a_t
        reward: reward received for this action, r
        next_obs: observation after the interaction, s_{t+1}
        next_action: action that will be taken at next_obs under the current Q-table, a_{t+1}
        done: whether the episode has ended
        """
        predict_Q = self.Q[obs, action]
        if done:
            target_Q = reward  # no bootstrapping at a terminal state
        else:
            target_Q = reward + self.gamma * self.Q[next_obs, next_action]  # Sarsa target
        self.Q[obs, action] += self.lr * (target_Q - predict_Q)  # move Q toward the target

    # Save the Q-table to a file
    def save(self):
        npy_file = './q_table.npy'
        np.save(npy_file, self.Q)
        print(npy_file + ' saved.')

    # Load the Q-table back from a file
    def restore(self, npy_file='./q_table.npy'):
        self.Q = np.load(npy_file)
        print(npy_file + ' loaded.')
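As a quick check that learn() matches the formula above, here is a minimal, self-contained example (not part of the original tutorial; the numbers are arbitrary):

import numpy as np
from agent import SarsaAgent

agent = SarsaAgent(obs_n=4, act_n=2, learning_rate=0.5, gamma=0.9)
agent.Q[1, 0] = 2.0  # pretend Q(s', a') has already been learned
agent.learn(obs=0, action=1, reward=1.0, next_obs=1, next_action=0, done=False)
# target_Q = 1.0 + 0.9 * 2.0 = 2.8, so Q(0, 1) moves from 0.0 by lr * (2.8 - 0.0)
assert np.isclose(agent.Q[0, 1], 1.4)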
train.py
import gym
from gridworld import CliffWalkingWapper, FrozenLakeWapper
from agent import SarsaAgent
import time


def run_episode(env, agent, render=False):
    total_steps = 0  # count how many steps this episode takes
    total_reward = 0
    obs = env.reset()  # reset the environment, i.e. start a new episode
    action = agent.sample(obs)  # choose an action according to the algorithm
    while True:
        next_obs, reward, done, _ = env.step(action)  # one interaction with the environment
        next_action = agent.sample(next_obs)  # choose the next action according to the algorithm
        # Sarsa update
        agent.learn(obs, action, reward, next_obs, next_action, done)
        action = next_action
        obs = next_obs  # carry the observation forward to the next step
        total_reward += reward
        total_steps += 1
        if render:
            env.render()
        if done:
            break
    return total_reward, total_steps


def test_episode(env, agent):
    total_reward = 0
    obs = env.reset()
    while True:
        action = agent.predict(obs)  # greedy
        next_obs, reward, done, _ = env.step(action)
        total_reward += reward
        obs = next_obs
        time.sleep(0.5)
        env.render()
        if done:
            print('test reward = %.1f' % total_reward)
            break


def main():
    env = gym.make("CliffWalking-v0")  # actions: 0 up, 1 right, 2 down, 3 left
    env = CliffWalkingWapper(env)
    agent = SarsaAgent(obs_n=env.observation_space.n, act_n=env.action_space.n,
                       learning_rate=0.1, gamma=0.9, e_greed=0.1)
    is_render = False
    for episode in range(500):
        ep_reward, ep_steps = run_episode(env, agent, is_render)
        print('Episode %s: steps = %s, reward = %.1f' % (episode, ep_steps, ep_reward))
        # render every 20 episodes to check progress
        is_render = (episode % 20 == 0)
    test_episode(env, agent)


if __name__ == "__main__":
    main()
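Note that env.reset() and env.step() above follow the classic Gym API (gym < 0.26). On newer gym or gymnasium releases, reset() returns (obs, info) and step() returns five values, so the interaction loop needs a small adaptation. A sketch, assuming the Gymnasium API (run_episode_gymnasium is a hypothetical helper, not part of the original code):

def run_episode_gymnasium(env, agent):
    total_reward, total_steps = 0, 0
    obs, _ = env.reset()  # reset() now returns (obs, info)
    action = agent.sample(obs)
    while True:
        next_obs, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated  # collapse both end-of-episode flags for learn()
        next_action = agent.sample(next_obs)
        agent.learn(obs, action, reward, next_obs, next_action, done)
        obs, action = next_obs, next_action
        total_reward += reward
        total_steps += 1
        if done:
            return total_reward, total_steps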