Reinforcement Learning Principles with Python 06 (Extension): DQN on FrozenLake


This chapter follows the Deep Q-learning chapter of Shiyu Zhao's textbook Mathematical Foundations of Reinforcement Learning; please read it alongside this post. This series focuses only on implementing the mathematical concepts in code.

Overview

This chapter extends the DQN chapter, using the FrozenLake game as the example.


FrozenLake is a 4×4 grid; each cell is either the start, the goal, a frozen cell, or a hole. The rules are simple: the player sets out from the start cell; if the player falls into a hole, the episode ends; if the player reaches the treasure (the goal), the player wins. The sketch below prints the default map.
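
A quick way to see the layout, assuming the standard gymnasium FrozenLake-v1 environment (the map is stored in the environment's desc attribute):

import gymnasium as gym

# inspect the default 4x4 map: S = start, F = frozen, H = hole, G = goal
env = gym.make("FrozenLake-v1", is_slippery=False)
for row in env.unwrapped.desc:
    print(row.tobytes().decode("utf-8"))
# SFFF
# FHFH
# FFFH
# HFFG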

Preparation

First, import the dependencies, build a network, and one-hot encode the discrete state; a quick check of the wrapper follows the code.

import random

import gymnasium as gym  # the 5-tuple step API used below assumes gymnasium (or gym >= 0.26)
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from gymnasium.envs.toy_text import frozen_lake


class Net(nn.Module):
    def __init__(self, obs_size, hidden_size, q_table_size):
        super(Net, self).__init__()

        self.net = nn.Sequential(
            # input is the one-hot state; a sample has shape (1, obs_size)
            nn.Linear(obs_size, hidden_size),
            nn.ReLU(),
            # nn.Linear(hidden_size, hidden_size),
            # nn.ReLU(),
            nn.Linear(hidden_size, q_table_size),
        )

    def forward(self, state):
        return self.net(state)


class DiscreteOneHotWrapper(gym.ObservationWrapper):
    """Convert the discrete state index into a one-hot vector observation."""

    def __init__(self, env):
        super(DiscreteOneHotWrapper, self).__init__(env)
        assert isinstance(env.observation_space, gym.spaces.Discrete)
        shape = (env.observation_space.n,)
        self.observation_space = gym.spaces.Box(0.0, 1.0, shape, dtype=np.float32)

    def observation(self, observation):
        # start from an all-zero vector and set the entry of the current state to 1
        res = np.copy(self.observation_space.low)
        res[observation] = 1.0
        return res
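
A quick sanity check of the wrapper (a sketch; check_env is a hypothetical name, and the environment construction mirrors the training script further below):

# build a wrapped environment and look at one observation
check_env = DiscreteOneHotWrapper(gym.make("FrozenLake-v1", is_slippery=False))
state, info = check_env.reset()
print(state.shape)   # (16,) -- one entry per grid cell
print(state.argmax())  # 0: the agent starts in the top-left cell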

ReplayBuffer

Build a buffer that carries out the following two steps of the DQN procedure (a short usage sketch follows the class definition):

4) Store the transition (s, a, r, s') in the replay buffer.

5) Sample a random mini-batch of transitions from the replay buffer.

class ReplayBuffer:
    def __init__(self, queue_size, replay_time):
        self.queue = []
        self.queue_size = queue_size
        self.replay_time = replay_time

    def get_batch_queue(self, env, action_trigger, batch_size, epsilon):
        def insert_sample_to_queue(env):
            state, info = env.reset()
            stop = 0

            while True:
                # with probability (1 - epsilon) take a random action,
                # otherwise the greedy action from action_trigger
                if np.random.uniform(0, 1) > epsilon:
                    action = env.action_space.sample()
                else:
                    action = action_trigger(state)

                next_state, reward, terminated, truncated, info = env.step(action)
                self.queue.append([state, action, next_state, reward, terminated])
                state = next_state
                if terminated or truncated:
                    state, info = env.reset()
                    stop += 1
                    continue
                if stop >= self.replay_time:
                    break

        def init_queue(env):
            # fill the buffer until it holds at least queue_size transitions
            while True:
                insert_sample_to_queue(env)
                if len(self.queue) >= self.queue_size:
                    break

        init_queue(env)
        insert_sample_to_queue(env)
        # keep only the most recent queue_size transitions
        self.queue = self.queue[-self.queue_size :]

        return random.sample(self.queue, batch_size)
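
A minimal usage sketch (demo_env, demo_buffer, and random_trigger are hypothetical names; in the training loop below the greedy-action callback comes from the DQN class):

# fill a small buffer with a purely random behaviour policy and sample a batch
demo_env = DiscreteOneHotWrapper(gym.make("FrozenLake-v1", is_slippery=False))
demo_buffer = ReplayBuffer(queue_size=100, replay_time=5)
random_trigger = lambda state: demo_env.action_space.sample()  # ignores the state
batch = demo_buffer.get_batch_queue(demo_env, random_trigger, batch_size=8, epsilon=0.0)
print(len(batch))  # 8
print(batch[0])    # [state, action, next_state, reward, terminated]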

Building the DQN
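
In this implementation the optimizer (defined in the training script below) updates tgt_net, while net is periodically overwritten with tgt_net's weights via update_net_parameters and acts as the frozen target network. For a sampled transition (s, a, r, s') with termination flag d, calculate_y_hat_and_y builds the regression pair

$$
y = r + (1 - d)\,\gamma \max_{a'} Q_{\text{net}}(s', a'), \qquad
\hat{y} = Q_{\text{tgt\_net}}(s, a),
$$

and the training loop minimizes the mean squared error between $\hat{y}$ and $y$ over the mini-batch.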

class DQN:
    def __init__(self, env, obs_size, hidden_size, q_table_size):
        self.env = env
        self.net = Net(obs_size, hidden_size, q_table_size)      # frozen target copy
        self.tgt_net = Net(obs_size, hidden_size, q_table_size)  # network being trained

    # synchronize the target copy: load tgt_net's weights into net
    def update_net_parameters(self, update=True):
        self.net.load_state_dict(self.tgt_net.state_dict())

    # greedy action from the trained network, used by the replay buffer
    def get_action_trigger(self, state):
        state = torch.Tensor(state)
        action = int(torch.argmax(self.tgt_net(state).detach()))
        return action

    # compute y_hat and y for a sampled mini-batch
    def calculate_y_hat_and_y(self, batch, gamma):
        y = []
        action_space = []
        state_space = []

        for state, action, next_state, reward, terminated in batch:
            q_table_net = self.net(torch.Tensor(next_state)).detach()
            y.append(reward + (1 - terminated) * gamma * float(torch.max(q_table_net)))
            action_space.append(action)
            state_space.append(state)
        y_hat = self.tgt_net(torch.Tensor(np.array(state_space)))
        y_hat = y_hat.gather(1, torch.LongTensor(action_space).reshape(-1, 1))
        return y_hat.reshape(-1), torch.tensor(y)

    # roll out the greedy policy of net for 100 steps and return the mean reward per step
    def predict_reward(self):
        state, info = self.env.reset()
        step = 0
        reward_space = []

        while True:
            step += 1
            state = torch.Tensor(state)
            action = int(torch.argmax(self.net(state).detach()))
            next_state, reward, terminated, truncated, info = self.env.step(action)
            reward_space.append(reward)
            state = next_state
            if terminated or truncated:
                state, info = self.env.reset()
                continue
            if step >= 100:
                break
        return float(np.mean(reward_space))

Training

hidden_size = 64
queue_size = 500
replay_time = 50

## initialize the environment
env = frozen_lake.FrozenLakeEnv(is_slippery=False)
env.spec = gym.spec("FrozenLake-v1")
env = gym.wrappers.TimeLimit(env, max_episode_steps=100)
env = DiscreteOneHotWrapper(env)

## initialize the replay buffer
replay_buffer = ReplayBuffer(queue_size, replay_time)

## initialize the DQN
obs_size = env.observation_space.shape[0]
q_table_size = env.action_space.n
dqn = DQN(env, obs_size, hidden_size, q_table_size)

# define the optimizer (it updates tgt_net, the network being trained)
opt = optim.Adam(dqn.tgt_net.parameters(), lr=0.01)

# define the loss function
loss = nn.MSELoss()

batch_size = 256
epsilon = 0.8
epochs = 200
gamma = 0.9

for epoch in range(epochs):
    batch = replay_buffer.get_batch_queue(
        env, dqn.get_action_trigger, batch_size, epsilon
    )
    y_hat, y = dqn.calculate_y_hat_and_y(batch, gamma)
    l = loss(y_hat, y)

    # backpropagation
    opt.zero_grad()
    l.backward()
    opt.step()

    # every 10 epochs, copy tgt_net's weights into the frozen target copy net
    if epoch % 10 == 0 and epoch != 0:
        dqn.update_net_parameters()

    print(
        "epoch:{},  MSE: {}, epsilon: {}, 100 steps reward: {}".format(
            epoch, l, epsilon, dqn.predict_reward()
        )
    )
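
After training, the greedy policy can also be inspected directly on the 4×4 grid without rendering. The snippet below is a hypothetical helper; the arrow symbols follow FrozenLake's action encoding (0 = left, 1 = down, 2 = right, 3 = up):

# print the greedy action of dqn.net for each of the 16 states as arrows
arrows = {0: "<", 1: "v", 2: ">", 3: "^"}
greedy = []
for s in range(16):
    one_hot = np.zeros(16, dtype=np.float32)
    one_hot[s] = 1.0
    a = int(torch.argmax(dqn.net(torch.Tensor(one_hot)).detach()))
    greedy.append(arrows[a])
for row in range(4):
    print(" ".join(greedy[row * 4 : (row + 1) * 4]))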

Visualizing the Prediction

DQN_Q = dqn.net

# render_mode="rgb_array" is needed so that RecordVideo can capture frames
env = frozen_lake.FrozenLakeEnv(is_slippery=False, render_mode="rgb_array")
env.spec = gym.spec("FrozenLake-v1")
# display_size = 512
# env.window_size = (display_size, display_size)
# env.cell_size = (
#     env.window_size[0] // env.ncol,
#     env.window_size[1] // env.nrow,
# )
env = gym.wrappers.TimeLimit(env, max_episode_steps=100)  # cap the rollout length
env = gym.wrappers.RecordVideo(env, video_folder="video")

env = DiscreteOneHotWrapper(env)

state, info = env.reset()
total_rewards = 0

while True:
    action = int(torch.argmax(DQN_Q(torch.Tensor(state))))
    state, reward, terminated, truncated, info = env.step(action)
    total_rewards += reward
    print(action)
    if terminated or truncated:
        break
env.close()
print("total reward:", total_rewards)

Ref

Mathematical Foundations of Reinforcement Learning, Shiyu Zhao
