先从简单的q-learning代码练习起步,001的环境也比较简单,训练智能体从左走到右边找到treasure
话不多说,直接写代码最直接。
import numpy as np
import pandas as pd
import time
np.random.seed(2)
N_STATES=6 #状态
N_ACTIONS=['left','right'] #是个列表,里面包含了动作选项
EPISION=0.9 #贪婪策略
ALPHA=0.1 #学习率
GAMMA=0.9 #奖励的折扣率
MAX_EPISODES=13 #迭代的次数
#老师说的没错,要想真正的数量,必须要多练习
FRESH_TIME=0.3 #刷新的时间
#建立Q表
def build_q_table(n_states,actions):
table=pd.DataFrame(np.zeros((n_states,len(actions))),columns=actions) #填0
return table
#动作探索策略
def choose_action(state,q_table):#传入状态和Q表
state_actions=q_table.iloc[state,:]#找出该状态下对应的所有q值
#如果没有对应该状态记录,或者是不执行贪婪策略
if(np.random.uniform()>EPISION) or ((state_actions==0).all):#矩阵全为0也触发非贪心的随机策略
action_name=np.random.choice(N_ACTIONS)
else:
action_name=state_actions.idmax()
return action_name
#执行动作,以及环境反馈
def get_env_feedback(state,action):
if action=='right':#向右移动
if state==N_STATES-2:#结束了
s_next='terminal'
reward=10
else:
s_next=state+1
reward=2
else:
reward=-1
if state==0:
s_next=state
else:
s_next=state-1
return s_next,reward
#更新环境
def update_env(state,episode,step_counter): #可视化更新过程
env_list=['-']*(N_STATES-1)+['T']
if state=='terminal':
interaction='Episode: %s: total_steps=%s'%(episode+1,step_counter)
#\r 表示将光标的位置回退到本行的开头位置
print('\r{}'.format(interaction),end='')
time.sleep(2)
else:#state对应的是某一个状态的位置
env_list[state]='o'
interaction=''.join(env_list)
print('\r{}'.format(interaction),end='')
time.sleep(FRESH_TIME)
#前面都是准备工作,都是方法片段,下面才是训练学习的过程
def rl():#调用方法建立q表
q_table=build_q_table(N_STATES,N_ACTIONS)
#进入迭代
for i in range(MAX_EPISODES):
step_counter=0#这是用来统计每一轮学习迭代的步数
#初始化状态
s=0
is_terminated=False
#第一次利用初始状态更新环境,显示环境
update_env(s,i,step_counter)
#更新环境后判断是否结束
while not is_terminated:#没结束,已经有表了,通过q表选择动作
A=choose_action(s,q_table)
s_next,reward=get_env_feedback(s,A)#获取下一个状态和奖励
#基于当前状态的预测值
q_predict=q_table.loc[s,A]
if s_next!='terminal':#计算目标值
q_target=reward+GAMMA*q_table.loc[s_next,:].max()
else:#否则的话就直接赋予奖励值即可
q_target=reward
is_terminated=True
#更新q表
q_table.loc[s,A]+=ALPHA*(q_target-q_predict)
#环境进入下一个状态
s=s_next
#再次更新环境,记住,更新环境仅和状态有关,
update_env(s,i,step_counter+1)
step_counter+=1
#注意,学习完成后返回的是一个q表,里面包含了各个状态对应的最佳参数
return q_table
#和Java一样,它也是需要主函数的
if __name__ == '__main__':
q_table=rl()
print('\r\nq_table:\n')
print(q_table)