8.2 Prioritized Sweeping: Algorithm Implementation

Algorithm pseudocode (shown as an image in the original post, omitted here):
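
A rough restatement of the figure's content, following the standard prioritized sweeping loop for deterministic environments (Sutton & Barto), using the same quantities (Q, Model, PQueue, gamma, alpha, theta) as the code below:

Initialize Q(s, a) and Model(s, a) for all s, a; PQueue empty
Loop forever:
    S  <- current (non-terminal) state
    A  <- policy(S, Q)
    take action A; observe reward R and next state S'
    Model(S, A) <- R, S'
    P <- | R + gamma * max_a Q(S', a) - Q(S, A) |
    if P > theta: insert (S, A) into PQueue with priority P
    while PQueue is not empty (optionally at most n iterations):
        (S, A) <- pop the highest-priority pair from PQueue
        R, S'  <- Model(S, A)
        Q(S, A) <- Q(S, A) + alpha * [R + gamma * max_a Q(S', a) - Q(S, A)]
        for all (S_pred, A_pred) predicted by the model to lead to S:
            R_pred <- reward the model predicts for S_pred, A_pred -> S
            P <- | R_pred + gamma * max_a Q(S, a) - Q(S_pred, A_pred) |
            if P > theta: insert (S_pred, A_pred) into PQueue with priority P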
The example runs in the maze environment; the code for maze_env can be found here.
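
For readers who do not have maze_env at hand, the script only relies on the small interface sketched below. This stub is inferred from how Maze is used in the code; it is an assumption, not the actual maze_env implementation.

# Hypothetical stand-in sketch of maze_env.Maze, inferred from how the
# class is used below (an assumption, not the real maze_env code).
class Maze(object):
    n_actions = 4                      # number of discrete actions

    def reset(self):
        # start a new episode; return the initial observation
        # (a list of four canvas coordinates for the agent's rectangle)
        ...

    def step(self, action):
        # apply one action; return (observation_, reward, done), where
        # observation_ is the string 'terminal' once the episode ends
        ...

    def render(self):
        # refresh the GUI window
        ...

    def destroy(self):
        # close the window
        ...

The full implementation: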

import numpy as np
import pandas as pd
from maze_env import Maze
import queue


class Q(object):
    def __init__(self, action_space):
        self.nA = action_space
        self.actions = list(range(action_space))

        # Q-table: one row per state (string key), one column per action
        self.q_table = pd.DataFrame(columns=self.actions, dtype=np.float64)
        self.init_Q()

    def init_Q(self):
        # Pre-register every grid cell of the maze in the Q-table.
        # Cell corners are 40 px apart; the three special cells are the
        # terminal states (they share a single 'terminal' row).
        for x in range(5, 165, 40):
            for y in range(5, 165, 40):
                if (x, y) in ((45, 85), (85, 45), (85, 85)):
                    s = 'terminal'
                else:
                    s = str([x + 0.0, y + 0.0, x + 30.0, y + 30.0])
                if s not in self.q_table.index:
                    # DataFrame.append was removed in pandas 2.0;
                    # adding the row via .loc works on both old and new versions
                    self.q_table.loc[s] = [0.0] * len(self.actions)

    def target_policy(self, s):
        # behaviour policy: epsilon-greedy with respect to the current Q-table
        A = self.target_policy_probs(s)
        return np.random.choice(range(self.nA), p=A)

    def target_policy_probs(self, s, epsilon=.3):
        # probability epsilon of a uniform random action, 1 - epsilon of the greedy one
        A = np.ones(self.nA, dtype=float) * epsilon / self.nA
        best_action = np.argmax(self.q_table.loc[s, :])
        A[best_action] += (1.0 - epsilon)
        return A


class Model(object):
    def __init__(self):
        # deterministic model: (state, action) -> [reward, next_state]
        self.model = dict()

    def store(self, s, a, r, s_):
        # remember the last observed outcome of taking action a in state s
        self.model[s, a] = [r, s_]


if __name__ == '__main__':
    env = Maze()
    action_space = env.n_actions
    RL = Q(action_space)
    model = Model()
    # priority queue of (-priority, state, action); the pair with the
    # largest TD error is popped first
    PQueue = queue.PriorityQueue()

    gamma = 0.9    # discount factor
    alpha = 0.01   # learning rate
    theta = 0.5    # priority threshold for entering the queue

    for episode in range(100):
        state = env.reset()

        while True:
            env.render()
            action = RL.target_policy(str(state))
            state_, reward, done = env.step(action)

            # update the deterministic model with the observed transition
            model.store(str(state), action, reward, str(state_))

            # priority = magnitude of the one-step TD error for (state, action)
            P = abs(reward + gamma * np.max(RL.q_table.loc[str(state_), :])
                    - RL.q_table.loc[str(state), action])
            if P > theta:
                # PriorityQueue pops the smallest tuple first, so negate P
                PQueue.put((-P, str(state), action))

            # planning: sweep backwards from the highest-priority pairs
            while not PQueue.empty():
                _, S, A = PQueue.get()
                R, S_ = model.model[S, A]
                G = R + gamma * np.max(RL.q_table.loc[S_, :])
                RL.q_table.loc[S, A] += alpha * (G - RL.q_table.loc[S, A])

                # re-prioritise every (state, action) pair that the model
                # predicts leads into S; S itself is never 'terminal' here,
                # because the episode ends before a terminal state is acted from
                for (_s, _a), (_r, _s_) in model.model.items():
                    if _s_ != S:
                        continue
                    P = abs(_r + gamma * np.max(RL.q_table.loc[S, :])
                            - RL.q_table.loc[_s, _a])
                    if P > theta:
                        PQueue.put((-P, _s, _a))

            if done:
                break

            state = state_

    print('game over')
    env.destroy()
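
Because queue.PriorityQueue always pops the smallest item first, the main loop above stores entries as (-P, state, action) so that the pair with the largest TD error is processed first. A minimal, self-contained sketch of that ordering (the priorities, state labels, and actions below are made up purely for illustration):

import queue

pq = queue.PriorityQueue()
# made-up priorities, state labels, and actions, just to show the pop order
pq.put((-0.7, 's1', 2))
pq.put((-1.9, 's2', 0))
pq.put((-1.1, 's3', 3))

while not pq.empty():
    neg_p, s, a = pq.get()
    print(s, a, -neg_p)   # s2 first (priority 1.9), then s3 (1.1), then s1 (0.7)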
