[Code Notes] An annotated walkthrough of Prioritized Experience Replay ("PER: Prioritized Experience Replay")

This post walks through the SumTree data structure as used in prioritized experience replay: how transition priorities are stored and updated, and how non-uniform sampling drives training. It covers the core PER (Prioritized Experience Replay) operations: adding new transitions, the sampling strategy, and the batch priority update.

Reference: https://yulizi123.github.io/tutorials/machine-learning/reinforcement-learning/4-6-prioritized-replay/
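Before diving into the code, the two formulas at the heart of PER can be sketched numerically: the priority derived from the TD error, and the importance-sampling weight that corrects the resulting bias. This snippet is purely illustrative; the TD-error values are made up, and the hyperparameters mirror the `Memory` class below.

```python
import numpy as np

# Hypothetical absolute TD errors of 4 stored transitions (made-up values).
abs_td_errors = np.array([2.0, 0.5, 0.1, 1.0])
epsilon, alpha, beta = 0.01, 0.6, 0.4   # same hyperparameters as the Memory class below

# Priority: p_i = (|delta_i| + epsilon) ** alpha, clipped to abs_err_upper = 1.0
p = np.minimum(abs_td_errors + epsilon, 1.0) ** alpha

# Sampling probability: P(i) = p_i / sum_k p_k
P = p / p.sum()

# Importance-sampling weight, normalized by the maximum weight so that w_i <= 1:
# w_i = (N * P(i)) ** (-beta) / max_j w_j  ==  (P(i) / min_j P(j)) ** (-beta)
w = (P / P.min()) ** (-beta)
```

Note that the normalization by the minimum probability is exactly what `Memory.sample` below computes via `min_prob`: the rarest transition gets weight 1, and every more frequently sampled transition gets a proportionally smaller update.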

#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
Created on Wed Apr 25 00:43:00 2018

@author: wuyuankai
"""

import numpy as np


class SumTree(object):
####################################################################
    def __init__(self, capacity):
        # 1) capacity: the number of transitions the buffer can hold
        self.capacity = capacity  # for all priority values (e.g. MEMORY_CAPACITY = 10000)
        # 2) Leaf nodes: store the priority value of each transition
        self.tree = np.zeros(2 * capacity - 1)
        # [--------------Parent nodes-------------][-------leaves to record priority-------]
        #             size: capacity - 1                       size: capacity
        # 3) self.data: stores the transitions themselves
        self.data = np.zeros(capacity, dtype=object)  # dtype=object lets each slot hold any Python object
        # [--------------data frame-------------]
        #             size: capacity


####################################################################
    data_pointer = 0

    def add(self, max_p, data):
        """
        Tree structure and array storage:
        Tree index:
             0         -> root node: stores the total priority sum
            / \
          1     2
         / \   / \
        3   4 5   6    -> leaf nodes: store the priority of each transition
        Array layout:
        [0, 1, 2, 3, 4, 5, 6]
        """
        # 1) Insert/overwrite one transition in self.data
        self.data[self.data_pointer] = data  # write the transition at position self.data_pointer
        # 2.1) Insert/update this transition's priority in the corresponding leaf of self.tree
        # 2.2) Propagate the resulting change up through the parent nodes of self.tree
        tree_idx = self.data_pointer + self.capacity - 1  # leaf indices start at capacity - 1 (e.g. node 3 above)
        self.update(tree_idx, max_p)  # update tree_frame
        self.data_pointer += 1
        # 3) Wrap the pointer once the buffer is full, so the oldest transitions get overwritten
        if self.data_pointer >= self.capacity:  # replace when exceeding the capacity
            self.data_pointer = 0



    def update(self, tree_idx, p):  # exploits the binary-tree layout
        # 2.1) Compute how much this leaf's priority changes
        change = p - self.tree[tree_idx]
        # 2.2) Write the new priority into the leaf
        self.tree[tree_idx] = p
        # Propagate the change to every ancestor up to the root
        while tree_idx != 0:    # this loop is faster than the recursive version in the reference code
            tree_idx = (tree_idx - 1) // 2
            self.tree[tree_idx] += change

#############################################################################
    def get_leaf(self, v):
        """
        Given a value v in [0, total_p], descend the tree to find the leaf
        whose cumulative-priority interval contains v (tree layout as in add()).
        """
        # 1) Start at the root
        parent_idx = 0
        while True:     # the while loop is faster than the method in the reference code
            cl_idx = 2 * parent_idx + 1         # this node's left and right children
            cr_idx = cl_idx + 1
            # 3) End: stop once parent_idx is a leaf (its children fall outside the array)
            if cl_idx >= len(self.tree):  # reached the bottom, end search
                leaf_idx = parent_idx
                break
            # 2) Descend: go left if v fits inside the left subtree's priority mass,
            #    otherwise subtract that mass and go right
            else:
                if v <= self.tree[cl_idx]:
                    parent_idx = cl_idx
                else:
                    v -= self.tree[cl_idx]
                    parent_idx = cr_idx

        # Key: map the leaf index in the tree back to the index in self.data
        # (same transition): if data_idx = i, then leaf_idx = (self.capacity - 1) + i
        data_idx = leaf_idx - (self.capacity - 1)
        return leaf_idx, self.tree[leaf_idx], self.data[data_idx]

    @property
    def total_p(self):
        # Sum of all leaf priorities
        return self.tree[0]  # the root
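As a quick sanity check of the tree mechanics, the following standalone demo fills a small tree with known priorities and verifies that the root holds the priority sum and that `get_leaf` maps a cumulative value back to the right transition. A condensed copy of the `SumTree` above is included so the snippet runs on its own.

```python
import numpy as np

# Condensed copy of the SumTree above, for a self-contained demonstration.
class SumTree(object):
    data_pointer = 0
    def __init__(self, capacity):
        self.capacity = capacity
        self.tree = np.zeros(2 * capacity - 1)   # internal nodes + leaves
        self.data = np.zeros(capacity, dtype=object)
    def add(self, p, data):
        self.data[self.data_pointer] = data
        self.update(self.data_pointer + self.capacity - 1, p)
        self.data_pointer = (self.data_pointer + 1) % self.capacity
    def update(self, tree_idx, p):
        change = p - self.tree[tree_idx]
        self.tree[tree_idx] = p
        while tree_idx != 0:                      # propagate the change to the root
            tree_idx = (tree_idx - 1) // 2
            self.tree[tree_idx] += change
    def get_leaf(self, v):
        parent_idx = 0
        while True:
            cl_idx = 2 * parent_idx + 1
            if cl_idx >= len(self.tree):
                leaf_idx = parent_idx
                break
            if v <= self.tree[cl_idx]:
                parent_idx = cl_idx
            else:
                v -= self.tree[cl_idx]
                parent_idx = cl_idx + 1
        data_idx = leaf_idx - (self.capacity - 1)
        return leaf_idx, self.tree[leaf_idx], self.data[data_idx]
    @property
    def total_p(self):
        return self.tree[0]

tree = SumTree(capacity=4)
for i, p in enumerate([1.0, 2.0, 3.0, 4.0]):
    tree.add(p, 'transition_%d' % i)   # dummy transitions with known priorities

print(tree.total_p)        # 10.0: the root holds the sum of all leaf priorities

# Cumulative masses are [0,1), [1,3), [3,6), [6,10); v = 3.5 lands in leaf 2.
leaf_idx, priority, data = tree.get_leaf(3.5)
print(priority, data)      # 3.0 transition_2
```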

############################################################################
class Memory(object):  # stored as ( s, a, r, s_ ) in SumTree

    epsilon = 0.01  # small amount to avoid zero priority
    alpha = 0.6  # [0~1] convert the importance of TD error to priority
    beta = 0.4  # importance-sampling, from initial value increasing to 1
    beta_increment_per_sampling = 0.001
    abs_err_upper = 1.  # clipped abs error

    def __init__(self, capacity):
        self.tree = SumTree(capacity)

    # 1) Priority rule: what priority to assign a new transition
    def store(self, transition):
        # Case 1: use the current maximum leaf priority max_p (if it is nonzero)
        max_p = np.max(self.tree.tree[-self.tree.capacity:])
        # Case 2: if every leaf priority is 0 (empty buffer), fall back to abs_err_upper = 1
        if max_p == 0:
            max_p = self.abs_err_upper
        self.tree.add(max_p, transition)   # new transitions get max priority so each is replayed at least once

    # 2) Sampling
    # Goal: draw BATCH_SIZE transitions from the buffer (via BATCH_SIZE values of v fed to get_leaf)
    # Note: also computes the importance-sampling weights that scale the learning update,
    #       and records which leaves were chosen so their priorities can be updated later
    def sample(self, batch_size):
        # 2.1) Allocate: tree indices b_idx: [batch_size], transitions b_memory: [batch_size, transition_size],
        #      importance-sampling weights ISWeights: [batch_size, 1]
        b_idx = np.empty((batch_size,), dtype=np.int32)
        b_memory = np.empty((batch_size, self.tree.data[0].size))
        ISWeights = np.empty((batch_size, 1))

        # 2.2) Split the total priority mass into batch_size equal segments (stratified sampling)
        pri_seg = self.tree.total_p / batch_size
        self.beta = np.min([1., self.beta + self.beta_increment_per_sampling])  # anneal beta towards 1
        min_prob = np.min(self.tree.tree[-self.tree.capacity:]) / self.tree.total_p     # for calculating ISWeights later

        # 2.3) Sample batch_size transitions
        for i in range(batch_size):
            # a) The i-th segment covers [a, b)
            a, b = pri_seg * i, pri_seg * (i + 1)
            # b) Draw v uniformly inside the segment
            v = np.random.uniform(a, b)
            # c) Retrieve the matching tree index, priority and transition
            idx, p, data = self.tree.get_leaf(v)
            # d) Correct the update magnitude via ISWeights:
            #    Problem: PER samples non-uniformly, and this sampling bias skews the estimates.
            #    Fix: weight each sample by w_i = (P(i)/min_prob)^(-beta); frequently sampled
            #    transitions get a smaller update (c_loss), and as beta anneals towards 1
            #    the correction becomes exact.
            prob = p / self.tree.total_p
            ISWeights[i, 0] = np.power(prob / min_prob, -self.beta)  # c_loss = td_error * self.ISWeights
            # e) Record the sampled batch (the result of the selection)
            b_idx[i], b_memory[i, :] = idx, data
        return b_idx, b_memory, ISWeights

    # 3) Batch update
    # Goal: after the network has learned from the sampled batch, write the
    #       transitions' new priorities back into the tree
    # Steps: 1) the learner supplies b_td_error_up = abs(q_target - q)
    #        2) b_td_error_up -> clipped_errors -> b_p
    def batch_update(self, b_idx, b_td_error_up):
        # b_p: the final updated priorities of the sampled batch, shape [batch_size, 1]
        b_td_error_up += self.epsilon  # shift abs errors by epsilon to avoid zero priority
        clipped_errors = np.minimum(b_td_error_up, self.abs_err_upper)
        b_p = np.power(clipped_errors, self.alpha)
        # Using b_idx and b_p, write each transition's post-learning priority back into the tree
        for idx, p in zip(b_idx, b_p):
            self.tree.update(idx, p)
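Putting the two classes together, a typical training loop stores transitions, samples a prioritized batch, learns, and feeds the new absolute TD errors back via `batch_update`. The sketch below is self-contained (condensed copies of both classes are inlined) and uses made-up 4-element transition vectors and random TD errors in place of a real learner.

```python
import numpy as np

# Condensed copies of the two classes above, so this demo runs standalone.
class SumTree(object):
    data_pointer = 0
    def __init__(self, capacity):
        self.capacity = capacity
        self.tree = np.zeros(2 * capacity - 1)
        self.data = np.zeros(capacity, dtype=object)
    def add(self, p, data):
        self.data[self.data_pointer] = data
        self.update(self.data_pointer + self.capacity - 1, p)
        self.data_pointer = (self.data_pointer + 1) % self.capacity
    def update(self, tree_idx, p):
        change = p - self.tree[tree_idx]
        self.tree[tree_idx] = p
        while tree_idx != 0:
            tree_idx = (tree_idx - 1) // 2
            self.tree[tree_idx] += change
    def get_leaf(self, v):
        parent_idx = 0
        while True:
            cl_idx = 2 * parent_idx + 1
            if cl_idx >= len(self.tree):
                leaf_idx = parent_idx
                break
            if v <= self.tree[cl_idx]:
                parent_idx = cl_idx
            else:
                v -= self.tree[cl_idx]
                parent_idx = cl_idx + 1
        return leaf_idx, self.tree[leaf_idx], self.data[leaf_idx - (self.capacity - 1)]
    @property
    def total_p(self):
        return self.tree[0]

class Memory(object):
    epsilon, alpha, beta = 0.01, 0.6, 0.4
    beta_increment_per_sampling, abs_err_upper = 0.001, 1.0
    def __init__(self, capacity):
        self.tree = SumTree(capacity)
    def store(self, transition):
        max_p = np.max(self.tree.tree[-self.tree.capacity:])
        self.tree.add(max_p if max_p > 0 else self.abs_err_upper, transition)
    def sample(self, batch_size):
        b_idx = np.empty((batch_size,), dtype=np.int32)
        b_memory = np.empty((batch_size, self.tree.data[0].size))
        ISWeights = np.empty((batch_size, 1))
        pri_seg = self.tree.total_p / batch_size
        self.beta = min(1., self.beta + self.beta_increment_per_sampling)
        min_prob = np.min(self.tree.tree[-self.tree.capacity:]) / self.tree.total_p
        for i in range(batch_size):
            v = np.random.uniform(pri_seg * i, pri_seg * (i + 1))
            idx, p, data = self.tree.get_leaf(v)
            ISWeights[i, 0] = (p / self.tree.total_p / min_prob) ** (-self.beta)
            b_idx[i], b_memory[i, :] = idx, data
        return b_idx, b_memory, ISWeights
    def batch_update(self, b_idx, b_td_errors):
        b_p = np.minimum(b_td_errors + self.epsilon, self.abs_err_upper) ** self.alpha
        for idx, p in zip(b_idx, b_p):
            self.tree.update(idx, p)

# Usage sketch: store 8 dummy transitions (s, a, r, s_) flattened into vectors,
# sample one batch, then feed back hypothetical absolute TD errors.
memory = Memory(capacity=8)
for i in range(8):
    memory.store(np.array([i, 0.0, 1.0, i + 1], dtype=np.float64))

b_idx, b_memory, ISWeights = memory.sample(batch_size=4)
print(b_memory.shape, ISWeights.shape)   # (4, 4) (4, 1)

memory.batch_update(b_idx, np.abs(np.random.randn(4)))  # new priorities from |TD error|
```

Since every freshly stored transition carries the same maximum priority here, the first sampled batch has all IS weights equal to 1; the weights only start to differ once `batch_update` has written learned TD errors back into the tree.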