1. Defining the Algorithm
The uniform sampling used in DQN's experience replay may forget important experiences that would be useful later. To address this problem, PER_DQN introduces prioritized experience replay (PER), and applying this technique to DQN yields better results. PER_DQN succeeds for two reasons: (1) it introduces the sum tree, an efficient data structure with O(log n) update and sampling complexity; (2) it correctly applies weighted importance sampling to compensate for the bias of non-uniform sampling.
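The two ingredients above can be stated precisely (notation follows Schaul et al.'s PER paper; $\delta_i$ is the TD error of sample $i$):

$$P(i) = \frac{p_i^\alpha}{\sum_k p_k^\alpha}, \qquad w_i = \frac{\left(N \cdot P(i)\right)^{-\beta}}{\max_j w_j}$$

where $p_i = |\delta_i| + \epsilon$ is the priority, $\alpha$ controls how strongly prioritization is applied ($\alpha = 0$ recovers uniform sampling), and $\beta$ is annealed toward 1 over training so that the importance-sampling correction becomes exact.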
1.1 Defining the Model
The PER_DQN model is similar to that of DQN: a three-layer MLP.
import torch.nn as nn
import torch.nn.functional as F

class MLP(nn.Module):
    def __init__(self, n_states, n_actions, hidden_dim=128):
        """Initialize the Q network as a fully connected network.
        """
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(n_states, hidden_dim)    # input layer
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)  # hidden layer
        self.fc3 = nn.Linear(hidden_dim, n_actions)   # output layer

    def forward(self, x):
        # ReLU activations on the hidden layers; raw Q-values at the output
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)
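To sanity-check the shapes, here is a minimal self-contained run-through; the 4-dimensional state and 2 actions are hypothetical CartPole-like sizes, not values from the original text:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MLP(nn.Module):
    def __init__(self, n_states, n_actions, hidden_dim=128):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(n_states, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, n_actions)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

# Hypothetical sizes: 4-dim state, 2 actions (e.g. CartPole)
net = MLP(n_states=4, n_actions=2)
q_values = net(torch.randn(8, 4))  # a batch of 8 states
print(q_values.shape)              # torch.Size([8, 2]): one Q-value per action
```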
1.2 Defining the Experience Replay
The experience replay here is the biggest difference from DQN. It uses a sum tree, a special binary tree in which each parent node's value equals the sum of its children's values. A node's value is defined as a sample's priority, measured here by the TD error; the values at the leaves are the individual sample priorities.
Sum tree sampling: given the total priority at the root and the number of samples to draw, split the range into that many equal segments, then uniformly sample one value within each segment. Starting from the root, compare each sampled value with the nodes' priorities level by level to finally reach the leaf whose sample is to be drawn.
import numpy as np
import random

class SumTree:
    '''SumTree for the PER (Prioritized Experience Replay) DQN.
    This SumTree code is a modified version and the original code is from:
    https://github.com/MorvanZhou/Reinforcement-learning-with-tensorflow/blob/master/contents/5.2_Prioritized_Replay_DQN/RL_brain.py
    '''
    def __init__(self, capacity: int):
        self.capacity = capacity
        self.data_pointer = 0
        self.n_entries = 0
        # capacity - 1 internal nodes followed by capacity leaves
        self.tree = np.zeros(2 * capacity - 1)
        self.data = np.zeros(capacity, dtype=object)

    def update(self, tree_idx, p):
        '''Update the sampling weight of a leaf and propagate the change up to the root.
        '''
        change = p - self.tree[tree_idx]
        self.tree[tree_idx] = p
        while tree_idx != 0:
            tree_idx = (tree_idx - 1) // 2
            self.tree[tree_idx] += change

    def add(self, p, data):
        '''Add new data to the SumTree with priority p.
        '''
        tree_idx = self.data_pointer + self.capacity - 1
        self.data[self.data_pointer] = data
        self.update(tree_idx, p)
        self.data_pointer += 1
        if self.data_pointer >= self.capacity:  # wrap around, overwriting the oldest data
            self.data_pointer = 0
        if self.n_entries < self.capacity:
            self.n_entries += 1

    def get_leaf(self, v):
        '''Sample the leaf whose priority segment contains v: descend left
        when v falls within the left child's mass, otherwise subtract that
        mass and descend right.
        '''
        parent_idx = 0
        while True:
            cl_idx = 2 * parent_idx + 1
            cr_idx = cl_idx + 1
            if cl_idx >= len(self.tree):  # reached a leaf
                leaf_idx = parent_idx
                break
            if v <= self.tree[cl_idx]:
                parent_idx = cl_idx
            else:
                v -= self.tree[cl_idx]
                parent_idx = cr_idx
        data_idx = leaf_idx - self.capacity + 1
        return leaf_idx, self.tree[leaf_idx], self.data[data_idx]

    def total(self):
        # the root stores the total priority; do not truncate to int,
        # since priorities are generally fractional
        return self.tree[0]
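To see the descent logic concretely, here is a standalone miniature of the same sum-tree arithmetic over four hand-picked priorities (independent of the class above; the numbers are illustrative only):

```python
import numpy as np

# A sum tree over 4 leaves: indices 0-2 are internal nodes, 3-6 are leaves.
capacity = 4
tree = np.zeros(2 * capacity - 1)

def update(idx, p):
    # set a leaf's priority and propagate the change up to the root
    change = p - tree[idx]
    tree[idx] = p
    while idx != 0:
        idx = (idx - 1) // 2
        tree[idx] += change

# Assign priorities 1, 2, 3, 4 to the four leaves. The leaves then own
# the cumulative segments [0,1], (1,3], (3,6], (6,10] of the total mass.
for i, p in enumerate([1.0, 2.0, 3.0, 4.0]):
    update(i + capacity - 1, p)

print(tree[0])  # 10.0: the root holds the total priority

def get_leaf(v):
    # descend left if v lies in the left child's mass, else subtract and go right
    parent = 0
    while 2 * parent + 1 < len(tree):
        left = 2 * parent + 1
        if v <= tree[left]:
            parent = left
        else:
            v -= tree[left]
            parent = left + 1
    return parent

# v = 6.5 falls in the segment (6, 10] owned by the leaf with priority 4
leaf = get_leaf(6.5)
print(leaf - (capacity - 1), tree[leaf])  # 3 4.0
```

High-priority leaves own wider segments, so a uniformly drawn v lands on them proportionally more often, which is exactly the sampling scheme described above.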
class ReplayTree:
    '''ReplayTree for the PER (Prioritized Experience Replay) DQN.
    '''
    def __init__(self, capacity):
        self.capacity = capacity  # the capacity of the replay memory
        self.tree = SumTree(capacity)
        self.abs_err_upper = 1.
        ## hyperparameter for calculating the importance sampling weight
        self.beta_increment_per