便于理解DQN Prioritized Experience Replay的相关补充

由于最近在做HEV的DQN和DDPG,看了不少关于DQN的视频(以b站up主莫烦为主)和博客,发现对Prioritized Experience Replay的介绍很多,但都不详细,在看代码的过程中也遇到了某些疑惑。
本文通过对 SumTree 和 Memory 的代码进行注释来进行一些补充。

关于 SumTree 的原理,在莫烦本人的网站中介绍的已经非常清楚,此处不赘述。

# self.tree 用来 存放 p值  >>> tree[0] 为 第零层 tree[1] tree[2] 为第一层...以此类推
# self.data 用来 存放 transitions >>> transition 的结构为 [s,a,r,s_]
Class SumTree():

# self.data的索引,例如data_pointer = 0时,那么就把这条 transition 存放在 self.data[0]位置
#存放完一条data_pointer += 1,当 data_pointer > data可存放最大容量时候重新初始化为0
data_pointer = 0  

def __init__(self, capacity):
    self.capacity = capacity  # self.data 的容量 // 也是 self.tree 最底层的节点数
    self.tree = np.zeros(2 * capacity - 1)  
    # 底层节点数 为 capacity 那么 self.tree 一共有 2 * capacity - 1 个节点
    #self.tree[0:size: capacity - 1] 为父层节点,self.tree[-capacity:] 为最下一层叶子节点
    # [--------------Parent nodes-------------][-------leaves to recode priority-------]
    #             size: capacity - 1                       size: capacity
    self.data = np.zeros(capacity, dtype=object)  # for all transitions
    # 用来存放 transitions
    # [--------------data frame-------------]
    #             size: capacity

def add(self, p, data):  # 输入 p值 和 transition
    tree_idx = self.data_pointer + self.capacity - 1 
    # transition 和 p 的索引是相对应的,可根据 data_pointer 求出 p值的索引
    self.data[self.data_pointer] = data  # 存放transition
    self.update(tree_idx, p)  # 更新 tree 中 底层叶子节点 和 上层节点的 p值

    self.data_pointer += 1
    if self.data_pointer >= self.capacity:  # 如上所述,大于容量便初始化为0
        self.data_pointer = 0

def update(self, tree_idx, p):
    change = p - self.tree[tree_idx] 
    # 由于data中存入了新的数据,则tree中底层相对应索引处的p值也要被修改,需要自下而上更新各节点处的p值
    # 至于为什么要更新p值,https://mofanpy.com/tutorials/machine-learning/reinforcement-learning/prioritized-replay/
    self.tree[tree_idx] = p # 更新该索引处的p值
    # then propagate the change through tree
    while tree_idx != 0:    # this method is faster than the recursive loop in the reference code
        tree_idx = (tree_idx - 1) // 2 # 自下而上更新至第0层,下一层 索引和上一层索引的关系为 (tree_idx - 1) // 2 
        self.tree[tree_idx] += change

def get_leaf(self, v):
# 根据随机数v 来 >>> 抽取叶子节点 >>> 获得叶子节点的索引 >>> 进而抽取到data中相对应的数据
# 至于是如何根据v来抽取的,看代码很容易看懂,便不赘述,莫烦网站中举的例子也很清晰https://mofanpy.com/tutorials/machine-learning/reinforcement-learning/prioritized-replay/
    """
    Tree structure and array storage:

    Tree index:
         0         -> storing priority sum
        / \
      1     2
     / \   / \
    3   4 5   6    -> storing priority for transitions

    Array type for storing:
    [0,1,2,3,4,5,6]
    """
    parent_idx = 0
    while True:     # the while loop is faster than the method in the reference code
        cl_idx = 2 * parent_idx + 1         # this leaf's left and right kids
        cr_idx = cl_idx + 1
        if cl_idx >= len(self.tree):        # reach bottom, end search
            leaf_idx = parent_idx
            break
        else:       # downward search, always search for a higher priority node
            if v <= self.tree[cl_idx]:
                parent_idx = cl_idx
            else:
                v -= self.tree[cl_idx]
                parent_idx = cr_idx

    data_idx = leaf_idx - self.capacity + 1
    return leaf_idx, self.tree[leaf_idx], self.data[data_idx]

@property
def total_p(self):
    return self.tree[0]  # the root

至于Memory的代码中有一点需要说明
在莫烦视频中的代码,transitions在store前就会进行TD_errors的计算。然而在其github中的代码中并没有,而是将这些transitions的p值设为最大值令我很困惑,通过看论文、看视频找到了如下①②③中的原因,这条视频讲的相对比较清楚

p存储更新的代码用文字表达的话可以简化如下:
①self.data 没有存满前,那么self.tree 最低层中的p值都设为 abs_error_upper = 1
②当self.data存满了以后,假设这条transition要存在self.data[3],那么self.tree[3+self.capacity-1 ]就等于self.tree最底层的最大值(由于还没有被抽样进行训练,所以p值并不是真实的abs_error)
③当开始 learn,先抽取,再训练神经网络,再计算抽取的data的真实p值,再更新self.tree中相对应节点的p值。(当然更新后的p值并不一定是其真实p值,因为对p值设置了上限)
没有被抽样训练过的 transitions 的 p 值为 最大值,只有经过 训练以后才对这些 transitions 更新其真实的 p值

搞明白了这 ①②③ 再看代码就会容易很多

class Memory():
epsilon = 0.01  # small amount to avoid zero priority
alpha = 0.6  # [0~1] convert the importance of TD error to priority
beta = 0.4  # importance-sampling, from initial value increasing to 1
beta_increment_per_sampling = 0.001
abs_err_upper = 1.  # clipped abs error

def __init__(self, capacity):
    self.tree = SumTree(capacity)

def store(self, transition):
    max_p = np.max(self.tree.tree[-self.tree.capacity:]) 
    # store的时候这些数据还没进行抽取训练,因此transitions的p值为最大值
    if max_p == 0:
        max_p = self.abs_err_upper  # 一开始 子层 p值 都为 0
    self.tree.add(max_p, transition)   # set the max p for new p

def sample(self, n):
    b_idx, b_memory, ISWeights = np.empty((n,), dtype=np.int32), np.empty((n, self.tree.data[0].size)), np.empty((n, 1))  # 一维 二维 二维
    pri_seg = self.tree.total_p / n       # priority segment
    self.beta = np.min([1., self.beta + self.beta_increment_per_sampling])  # max = 1

    min_prob = np.min(self.tree.tree[-self.tree.capacity:]) / self.tree.total_p     # for later calculate ISweight
    for i in range(n):
        a, b = pri_seg * i, pri_seg * (i + 1)
        v = np.random.uniform(a, b)
        idx, p, data = self.tree.get_leaf(v)
        prob = p / self.tree.total_p
        ISWeights[i, 0] = np.power(prob/min_prob, -self.beta)
        b_idx[i], b_memory[i, :] = idx, data
    return b_idx, b_memory, ISWeights

def batch_update(self, tree_idx, abs_errors): 
# 训练完后,如上③所述要对batch的p值进行更新,这时更新的是真实的p值,但有如下代码的限制
    abs_errors += self.epsilon  # convert to abs and avoid 0
    clipped_errors = np.minimum(abs_errors, self.abs_err_upper)
    ps = np.power(clipped_errors, self.alpha)
    for ti, p in zip(tree_idx, ps):
        self.tree.update(ti, p)
  • 1
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值