前面已经分别写过q-learning和Sarsa的强化学习代码,其实两者差异并不非常大,只是在更新方式上不同,只是一个更加注重探索,一个更加注重应用。
那么学习了这两种强化学习方法后,我们来进一步提升其性能,通过赋予强化学习智能体先验知识使其能够更快达到预想的效果,在前面的代码中,我们我们已经将智能体学习后的q表存储下来,现在我们可以在其启动时赋予智能体这些知识。
为了实现先验知识的赋予,需要的几个函数进行修改,首先是environment
import numpy as np
import time
import tkinter as tk
#定义一些常量
UNIT=40
WIDTH=4
HIGHT=4
#环境也写在一个类中
class Maze0509(tk.Tk,object):
def __init__(self):
super(Maze0509, self).__init__()
#动作空间
self.action_space=['u','d','l','r']
#self.n_action=len(self.action_space)
self.title('maze')
#建立画布
self.geometry('{0}x{1}'.format(HIGHT*UNIT,WIDTH*UNIT))
self.build_maze()
def build_maze(self):
self.canvas=tk.Canvas(self,bg='white',height=HIGHT*UNIT,width=WIDTH*UNIT)
#绘制线框
for i in range(0,WIDTH*UNIT,UNIT):
x0,y0,x1,y1=i,0,i,WIDTH*UNIT
self.canvas.create_line(x0,y0,x1,y1)
for j in range(0,HIGHT*UNIT,UNIT):
x0,y0,x1,y1=0,j,HIGHT*UNIT,j
self.canvas.create_line(x0,y0,x1,y1)
#创建迷宫中的地狱
hell_center1=np.array([100,20])
self.hell1=self.canvas.create_rectangle(hell_center1[0]-15,hell_center1[1]-15,hell_center1[0]+15,hell_center1[1]+15,fill='black')
hell_center2=np.array([20,100])
self.hell2=self.canvas.create_rectangle(hell_center2[0]-15,hell_center2[1]-15,hell_center2[0]+15,hell_center2[1]+15,fill='green')
#创建出口
out_center=np.array([100,100])
self.oval=self.canvas.create_oval(out_center[0]-15,out_center[1]-15,out_center[0]+15,out_center[1]+15,fill='yellow')
#智能体
origin=np.array([20,20])
self.finder=self.canvas.create_rectangle(origin[0]-15,origin[1]-15,origin[0]+15,origin[1]+15,fill='red')
self.canvas.pack()#一定不要忘记加括号
#智能体探索步
def step(self,action):
s=self.canvas.coords(self.finder)#获取智能体当前的位置
#由于移动的函数需要传递移动大小的参数,所以这里需要定义一个移动的基准距离
base_action=np.array([0,0])
#根据action来确定移动方向
if action=='u':
if s[1]>UNIT:
base_action[1]-=UNIT
elif action=='d':
if s[1]<HIGHT*UNIT:
base_action[1]+=UNIT
elif action=='l':
if s[0]>UNIT:
base_action[0]-=UNIT
elif action=='r':
if s[0]<WIDTH*UNIT:
base_action[0]+=UNIT
#移动
self.canvas.move(self.finder,base_action[0],base_action[1])
#移动后记录新位置指标
s_=self.canvas.coords(self.finder)
#反馈奖励,terminal不是自己赋予的,而是判断出来的
if s_==self.canvas.coords(self.oval):
reward=1
done=True
s_='terminal'#结束了
elif s_ in (self.canvas.coords(self.hell2),self.canvas.coords(self.hell1)):
reward=-1
done=True
s_='terminal'
else:
reward=0
done=False
#这个学习函数不但传入的参数多,返回的结果也多
return s_,reward,done
#渲染函数
def render(self):
time.sleep(0.1)
self.update()#这里的update应该是画布里的自动更新函数
#重置函数,当一轮走完后,需要重置画布到最初状态
def resets(self):
#其实,就是删掉原先的Finder,再重新定义新的Finder
self.update()
time.sleep(0.5)
#删掉搜索这
self.canvas.delete(self.finder)
#删掉后,再重新定义最初位置
origin=np.array([20,20])
self.finder=self.canvas.create_rectangle(origin[0]-15,origin[1]-15,origin[0]+15,origin[1]+15,fill='red')
return self.canvas.coords(self.finder)
这里将dataframe中的列索引类型改为字符类型,这样方便step函数中的动作对比,之前是用数字0,1,2,3来表示,但是在代码运行过程中会出现类型的错误,于是乎直接使用u,d,l,r这些动作字符表示吧,简单省事。整体environment里面没有改动太多
然后是agent
import numpy as np
import pandas as pd
import os
class Q_table:
def __init__(self,actions,learning_rate=0.01,reward_decay=0.9,e_greedy=0.9,filename=None):
self.actions=actions
self.learning_rate=learning_rate
self.reward_decay=reward_decay
self.e_greedy=e_greedy
if os.path.exists(filename):
print("使用先验知识赋予智能体认知")
self.q_table=pd.read_csv(filename,index_col=0)
print(self.q_table)
else:
print("直接自己探索学习吧")
self.q_table=pd.DataFrame(columns=actions,dtype=np.float64)
def choose_action(self,s,e):
self.check_state_exist(s)
if np.random.uniform()<e:
s_actions=self.q_table.loc[s,:]
action=np.random.choice(s_actions[s_actions==np.max(s_actions)].index)
else:
action=np.random.choice(self.actions)
return action
def learn(self,s,a,r,s_,a_):
self.check_state_exist(s_)
q_predict=self.q_table.loc[s,a]
if s_!='terminal':
q_target=r+self.reward_decay*self.q_table.loc[s_,a_]
else:
q_target=r
self.q_table.loc[s,a]+=self.learning_rate*(q_target-q_predict)
return self.q_table
def check_state_exist(self,state):
if state not in self.q_table.index:
self.q_table=self.q_table.append(pd.Series([0]*len(self.actions),index=self.q_table.columns,name=state))
agent中直接修改了构造方法,在初始化参数里添加了从外部读入文件的参数,也正是因为从csv文件中读取数据建造dataframe类型数据的时候,0,1,2,3被转换成了字符型,所以才会导致上面所说的数据类型不同。现在进行了修改就好了
最后是run
from environment0509 import Maze0509
from agent0509 import Q_table
def update():
epison=0.9
for i in range(10):
s=env.resets()
a=RL.choose_action(str(s),epison)
while True:
env.render()
s_,r,done=env.step(a)
a_=RL.choose_action(str(s_),epison)
q_table=RL.learn(str(s),a,r,str(s_),a_)
s=s_
a=a_
if done:
break
epison+=0.05
print("game over")
q_table.to_csv('output.csv')
env.destroy()
if __name__ == '__main__':
env=Maze0509()
RL=Q_table(actions=env.action_space,filename='output.csv')#这里面action传入进去的就是数字
env.after(10,update)
env.mainloop()
这里只需要修改main里面的代码即可,在初始化RL时加入存放先验知识的文件,让其读入数据。