代码思路
Policy iteration的思路:
- 初始化决策policy和policy的值函数V,代码中有16个格子,且每个格子有上下左右四个方向,所以policy中的值是0-3,四种。
- 通过初始化后的policy去计算这个policy的V(初始化的V没有什么用),因为格子内有信息,比如这个格子状态的奖励是多少,沿着方向能到达什么状态等。然后判断policy的V是否收敛,收敛则退出循环。
- 计算q-table,q-table里面的横轴代表状态,数轴代表行为,里面值代表对应状态采取对应行为能够获得的V值,然后从一列中取出最高V值的状态,用这个行为代替原有policy中的行为。
- 至此就拿到了旧的policy和新的policy,对比他们,如果一样则代表不用继续迭代了,已经收敛,然后用新得到的policy去训练。
Value iteration的思路:
- 直接初始化V,不用初始化policy。然后在环境中直接计算Q-table(为什么能计算出来呢?因为到达goal的状态都能得到1的奖励,也就是从goal状态的附近开始更新V。)结束条件是V收敛或者达到最大迭代次数。通过Q-table,拿到那一列最高值的V,作为V的更新。
- 用新的V来计算Q-table,用Q-table拿到最佳的Action,用来更新Policy。
- 用新的policy来运行环境,得到结果。
主要代码片段
Policy iteration:
def compute_policy_v(env, policy, gamma=1.0):
""" Iteratively evaluate the value-function under policy.
Alternatively, we could formulate a set of linear equations in iterms of v[s]
and solve them to find the value function.
"""
v = np.zeros(env.env.nS)
eps = 1e-10
while True:
prev_v = np.copy(v)
for s in range(env.env.nS):
# 拿到初始化策略的值,0-3可以看做采取的行为,对应不同的动作,每个动作做完后也不一定会到达一个确定的状态
# 比如在P[0]采取3的时候有0.66的概率到达状态0,有0.33的概率到达1
policy_a = policy[s]
# s_是该状态采取行动可到达的下一个状态,P是采取行动的概率,r是当前状态奖励
# v[s]的计算公式是笔记里面的(10)
v[s] = sum([p * (r + gamma * prev_v[s_]) for p, s_, r, _ in env.env.P[s][policy_a]])
if (np.sum((np.fabs(prev_v - v))) <= eps):
# value converged
break
return v
def policy_iteration(env, gamma = 1.0):
""" Policy-Iteration algorithm """
policy = np.random.choice(env.env.nA, size=(env.env.nS)) # initialize a random policy
max_iterations = 200000
gamma = 1.0
for i in range(max_iterations):
old_policy_v = compute_policy_v(env, policy, gamma)
new_policy = extract_policy(old_policy_v, gamma)
if (np.all(policy == new_policy)):
print ('Policy-Iteration converged at step %d.' %(i+1))
break
policy = new_policy
return policy
Value iteration:
def value_iteration(env, gamma = 1.0):
""" Value-iteration algorithm """
v = np.zeros(env.env.nS) # initialize value-function
max_iterations = 100000
eps = 1e-20
for i in range(max_iterations):
prev_v = np.copy(v)
for s in range(env.env.nS):
q_sa = [sum([p*(r + gamma * prev_v[s_]) for p, s_, r, _ in env.env.P[s][a]]) for a in range(env.env.nA)]
v[s] = max(q_sa)
if (np.sum(np.fabs(prev_v - v)) <= eps):
print ('Value-iteration converged at iteration# %d.' %(i+1))
break
return v
理解代码的前提需要知道环境,也就是参数中的env,它是gym包提供的用来运行RL的虚拟环境,这两个代码用的是名叫“FrozenLake-v0”的环境,有S\F\H\G四个字符,S代表起始状态、F代表可以走的冰面、H代表洞,走到这里代表掉入水中,中止行为,不得到奖励、G代表重点,走到这里代表到达目的地,能够拿到1的奖励。此外Policy数组代表在index+1状态的时候要采取的什么动作。比如Poliycy[4]=3代表状态5的时候采取3的动作。细节部分还需要自己探索,这样才能加深印象。
上图是代码中用来计算的公式,虽然字符比较多,但过一遍周博磊老师的第一二节视频应该就会懂了。图片出自周博磊老师的视频。
链接: 周博磊老师的视频链接.
链接: 代码下载地址.