class PolicyIteration:
""" 策略迭代算法 """
def __init__(self,env,theta,gamma):
self.env=env
self.theta=theta#策略评估收敛阈值
self.gamma=gamma#折扣因子
self.v=[0]*(self.env.ncol*self.env.nrow)#初始化价值为0
self.pi=[[0.25,0.25,0.25,0.25]for i in range(self.env.ncol*self.env.nrow)]
def policy_evaluation(self):
count=1
while 1:
max_diff=0
new_v=[0]*self.env.ncol*self.env.nrow
for s in range(self.env.ncol*self.env.nrow):
Qsa_list=[]#开始计算状态s下面的所有Q(s,a)价值
for a in range(4):
Qsa=0
for res in self.env.P[s][a]:
p,next_state,reward,done=res
Qsa+=p*(reward+self.gamma*self.v[next_state]*(1-done))#本环境特殊,奖励和下一个状态有关,所以需要和状态转移概率相乘
Qsa_list.append(self.pi[s][a]*Qsa)
new_v[s]=sum(Qsa_list)#状态价值函数与动作价值函数之间的关系
max_diff=max(max_diff,abs(new_v[s]-self.v[s]))
self.v=new_v
if max_diff<self.theta:break#满足收敛条件,退出评估迭代
count+=1
print("策略评估进行%d轮后完成"%count)
def policy_improvement(self):#策略提升
for s in range(self.env.ncol*self.env.nrow):
Qsa_list=[]
for a in range(4):
Qsa=0
for res in self.env.P[s][a]:
p,next_state,reward,done=res
Qsa+=p*(reward+self.gamma*self.v[next_state]*(1-done))
Qsa_list.append(Qsa)
max_Qsa=max(Qsa_list)
count_max_Qsa=Qsa_list.count(max_Qsa)#计算有几个动作得到最大的Q值
self.pi[s]=[1/count_max_Qsa if p==max_Qsa else 0 for p in Qsa_list]#让这些动作均分概率
print("策略提升完成")
return self.pi
def policy_iteration(self):#策略迭代
while 1:
self.policy_evaluation()
old_pi=self.pi.copy()#将列表进行深拷贝,方便接下来进行比较
new_pi=self.policy_improvement()
if new_pi==old_pi:break
#打印策略函数,打印当前策略在每一个状态下的价值以及智能体会采取的动作。对于打印出来的动作,用o↓o→表示等概率采取向上和向右两种动作,ooo→表示在当前状态下仅仅采取向右动作。
def print_agent(agent,action_meaning,disater=[],end=[]):
print("状态价值:")
for i in range(agent.env.nrow):
for j in range(agent.env.ncol):
print('%6.6s' % ('%.3f' % agent.v[i*agent.env.ncol+j]),end=' ')
print()
print("策略:")
for i in range(agent.env.nrow):
for j in range(agent.env.ncol):
#一些特殊的状态,例如悬崖漫步中的悬崖
if (i*agent.env.ncol+j) in disater:
print('****',end=' ')
elif (i*agent.env.ncol+j) in end:#目标状态
print('EEEE',end=' ')
else:
a=agent.pi[i*agent.env.ncol+j]
pi_str=''
for k in range(len(action_meaning)):
pi_str+=action_meaning[k] if a[k]>0 else 'o'
print(pi_str,end=' ')
print()#换行
import gym
import cv2
env=gym.make("FrozenLake-v1",map_name="4x4",render_mode='human')#创建环境 render_mode='human'/'rgb_array'/'ansi'
env.reset()
env=env.unwrapped#解封装才能访问状态转移矩阵P
action_meaning=['↑','↓','←','→']
theta=1e-5
gamma=0.9
agent=PolicyIteration(env,theta,gamma)
agent.policy_iteration()
print_agent(agent,action_meaning,[5,7,11,12],[15])
07-27
159