[RL]优先经验回放(PER-DQN)原理及代码实现

之前说过DQN、DuelingDQN、DDQN原理及代码实现,今天来说另外一个DQN的变体:优先经验回放Prioritized Experience Replay(PER-DQN)

一、优先经验回放(PER-DQN)原理

论文Prioritized Experience Replay

一、概述

强化学习由于其算法特性,并没有现成的数据集,而仅靠单步获得的数据对未知的复杂环境信息进行感知决策并不高效可靠。DQN算法结合神经网络的同时,结合了经验回放机制,针对Q-learning的局限性,打消了采样数据相关性,使得数据分布变得更稳定。

但随着DQN算法的应用,研究人员发现,基于经验回放机制的DQN算法,仅采用均匀采样批次更新,导致部分数量少但价值高的经验没有被高效的利用。针对上述情况,Deep Mind团队提出了Prioritized Experience Replay(优先经验回放)机制

二、基本原理

1 经验回放(Experience Replay):

  • 创建一个经验池(Experience Replay Buffer),每一次Agent选择一个动作与环境交互就会储存一组数据 e t = ( s t , a t , r t , s t + 1 ) e_t = (s_t, a_t, r_t, s_{t+1}) et=(st,at,rt,st+1)到经验池中。
  • 维护这个经验池(队列),当储存的数据组数到达一定的阈值,数据到就会从队列中被提取出来。
  • 采用均匀采样的方式进行数据提取。

上述方法解决了经验数据的相关性(Correlated data)和非平稳分布(Non-stationary distribution)问题。它的做法是从以往的状态转移(经验)中均匀采样进行训练。优点是数据利用率高,一个样本被多次使用,且连续样本的相关性会使参数更新的方差(Variance)比较大,以此减少这种相关性。

然而,采用均匀采样方式存在的问题,作者举了例子如图所示:
左图表示一个(稀疏奖励)环境有初始状态为1,有n个状态,两个可选动作,仅当选择绿色线条动作的时候可以得到 reward=1 的奖励;右图为实验结果,黑色曲线代表均匀采样的结果,蓝色曲线为研究人员提出的一个名为“oracle”的最优次序,即每次采样的transition均采用“最好”的结果,实验结果可看出每次采用最优次序的方法在稀疏奖励(Reward sparse)环境能够明显优于均匀采样。

那么如何在实际应用当中找到这个“最优”次序,即如何在采样前提前设计好一个次序,使得每次采样的数据都尽可能让agent高效学习呢?

2 优先经验回放(Prioritized Experience Replay,PER)

针对经验回放机制存在的问题,DeepMind团队提出了两方面的思考:要存储哪些经验(which experiences to store),以及要重放哪些经验(which experiences to replay,and how to do so)。论文中仅针对后者,即怎么样选取要采样的数据以及实验的方法做了详尽的说明和研究。

PER机制将TD-error(时序误差)作为衡量标准评估被采样数据的优先级。TD-error指在时序差分中当前Q值和它目标Q值的差值,误差越大即表示该数据对网络参数的更新越有帮助。贪婪(选取最大值)的采样TD-error大的数据训练,理论上会加速收敛,但随之而来也会面临以下问题:

  • TD-error可看做对参数更新的信息增益,信息增益较大仅表示对于当前的价值网络参数而言增益较大,但却不能代表对于后续的价值网络没有较大的增益。若只贪婪的考虑信息增益来采样,当前TD-error较小的数据优先级会越来越低,后面会越来越难采样到该组数据。
  • 贪婪的选择使得神经网络总是更新某一部分样本,即“经验的一个子集”,很可能导致陷入局部最优,亦或是过估计的发生。

针对上述PER存在的问题,作者在文中提出了一种随机抽样的方法,该方法介于纯贪婪和均匀随机之间,确保transition基于优先级的被采样概率是单调的,同时即使对于最低优先级的transition也保证非零的概率。

三、相关改进工作

1 随机优先级(Stochastic Prioritization)

论文将采样transition i i i的概率定义为:

P ( i ) = p i α ∑ k p k α P(i)=\frac{p_i^\alpha}{\sum_k p_k^\alpha} P(i)=kpkαpiα

其中 p i > 0 p_i>0 pi>0表示transition i i i的优先级。指数 α α α表示决定使用多少优先级,可看做一个trade-off因子,用来权衡uniform和greedy的程度,当 α = 0 α=0 α=0时表示均匀采样, α = 1 α=1 α=1是表示贪婪(选取最大值)采样。

在DQN中: δ = y − Q ( s , a ) \delta=y -Q(s, a) δ=yQ(s,a), δ \delta δ表示TD-error,即每一步当前Q值与目标值 y y y之间的差值,在更新过程中也是为了让 δ 2 \delta^2 δ2的期望尽可能的小。

文中将随机优先经验回放划分为以下两个类型:

a)直接的,基于比例的:Proportional Prioritization

b)间接的,基于排名的:Rank-based Prioritization

  • a)Proportional Prioritization中,根据 ∣ δ ∣ |\delta| δ决定采样概率:

p i = ∣ δ i ∣ + ϵ p_i=\left|\delta_i\right|+\epsilon pi=δi+ϵ

其中 δ \delta δ表示TD-eroor, ϵ \epsilon ϵ是一个大于0的常数,为了保证无论TD-error取值如何,采样概率 p i p_i pi仍大于0,即仍有概率会被采样到。

  • b)Rank-based Prioritization中,根据 ∣ δ ∣ |\delta| δ排名(Rank) 来决定采样概率:

p i = 1 rank ⁡ ( i ) p_i=\frac{1}{\operatorname{rank}(i)} pi=rank(i)1

作者在文中对两种方法进行了比较:

a)从理论层次分析:
proportional prioritization优势在于可以直接获得 ∣ δ ∣ |\delta| δ 的信息,也就是它的信息增益多一些;而rank-based prioritization则没有 ∣ δ ∣ |\delta| δ 的信息,但其对异常点不敏感,因为异常点的TD-error过大或过小对rank值没有太大影响,也正因为此,rank-based prioritization具有更好的鲁棒性。

b)从实验层次分析:
结果如下图所示,可以看出这两种方法的表现大致相同。在这里插入图片描述

2 SumTree

Proportional Prioritization的实现较为复杂,可借助SumTree数据结构完成。SumTree是一种树形结构,每片树叶存储每个样本的优先级P,每个树枝节点只有两个分叉,节点的值是两个分叉的和,所以SumTree的顶端就是所有p的和。结构如下图所示, 顶层的节点是全部p的和。在这里插入图片描述
抽样时, 我们会将 p 的总和除以 batch size, 分成 batch size 多个区间,即 n = s u m ( p ) / b a t c h s i z e n={ sum(p) }/{batchsize} n=sum(p)/batchsize.。如果将所有节点的优先级加起来是42, 我们如果抽6个样本, 这时的区间拥有的 priority 可能是这样:[0-7], [7-14], [14-21], [21-28], [28-35], [35-42]

然后在每个区间里随机选取一个数. 比如在第区间 [21-28] 里选到了24, 就按照这个 24 从最顶上的42开始向下搜索. 首先看到最顶上 42 下面有两个 child nodes, 拿着手中的24对比左边的 child 29, 如果左边的 child 比自己手中的值大, 那我们就走左边这条路, 接着再对比 29 下面的左边那个点 13, 这时, 手中的 24 比 13 大, 那我们就走右边的路, 并且将手中的值根据 13 修改一下, 变成 24-13 = 11. 接着拿着 11 和 13 左下角的 12 比, 结果 12 比 11 大, 那我们就选 12 当做这次选到的 priority, 并且也选择 12 对应的数据。

以上面的树结构为例,根节点是42,如果要采样一个样本,我们可以在[0,42]之间做均匀采样,采样到哪个区间,就是哪个样本。比如我们采样到了26, 在(25-29)这个区间,那么就是第四个叶子节点被采样到。而注意到第三个叶子节点优先级最高,是12,它的区间13-25也是最长的,所以它会比其他节点更容易被采样到。

如果要采样两个样本,我们可以在[0,21],[21,42]两个区间做均匀采样,方法和上面采样一个样本类似。

3 消除偏差(Annealing the Bias)

使用优先经验回放还存在一个问题是改变了状态的分布,DQN中引入经验池是为了解决数据相关性,使数据(尽量)独立同分布。但是使用优先经验回放又改变了状态的分布,这样势必会引入偏差bias,对此,文中使用重要性采样结合退火因子来消除引入的偏差。

在DQN中,梯度的计算如下所示:

∇ θ i L ( θ i ) = E s , a , r , s ′ [ ( r + γ max ⁡ a ′ Q ( s ′ , a ′ ; θ i ) − Q ( s , a ; θ i ) ) ∇ θ i Q ( s , a ; θ i ) ] ( 1 ) \nabla_{\theta_i} L\left(\theta_i\right)=\mathbb{E}_{s, a, r, s^{\prime}}\left[\left(r+\gamma \max _{a^{\prime}} Q\left(s^{\prime}, a^{\prime} ; \theta_i\right)-Q\left(s, a ; \theta_i\right)\right) \nabla_{\theta_i} Q\left(s, a ; \theta_i\right)\right](1) θiL(θi)=Es,a,r,s[(r+γamaxQ(s,a;θi)Q(s,a;θi))θiQ(s,a;θi)]1

在随机梯度下降(SGD)中可表示为:

∇ θ L ( θ ) = δ ∇ θ Q ( s , a ) \nabla_\theta L(\theta)=\delta \nabla_\theta Q(s, a) θL(θ)=δθQ(s,a)

而重要性采样,就是给这个梯度加上一个权重 w w w

∇ θ L ( θ ) = w δ ∇ θ Q ( s , a ) \nabla_\theta L(\theta)=w \delta \nabla_\theta Q(s, a) θL(θ)=wδθQ(s,a)

重要性采样权重 w i w_i wi在文中定义为:

w i = ( 1 N ⋅ 1 P ( i ) ) β w_i=\left(\frac{1}{N} \cdot \frac{1}{P(i)}\right)^\beta wi=(N1P(i)1)β

N N N 表示Buffer里的样本数,而 β \beta β 是一个超参数,用来决定多大程度想抵消 Prioritized Experience Replay对收敛结果的影响。如果 β = 0 \beta=0 β=0,表示完全不使用重要性采样 ; β = 1 \beta=1 β=1时表示完全抵消掉影响,由于 ( s , a ) (s, a) (s,a) 不再是均匀分布随机选出来的了,而是以 P ( i ) P(i) P(i) 的概率选出来,因此,如果 β = 1 \beta=1 β=1 , 那么 w w w P ( i ) P(i) P(i) 就正好抵消了,于是Prioritized Experience Replay的作用也就被抵消了,即 β = 1 \beta=1 β=1等同于DQN中的 Experience Replay。

为了稳定性,我们需要对权重 w w w 归一化,但是不用真正意义上的归一化,只要除上 max ⁡ i w i \max _i w_i maxiwi 即可,即:

w j = ( N ∗ P ( j ) ) − β / max ⁡ i ( w i ) w_j=(N * P(j))^{-\beta} / \max _i\left(w_i\right) wj=(NP(j))β/imax(wi)

归一化后的 w i w_i wi 在编写代码时可推导转化为:

w j = ( N ∗ P ( j ) ) − β max ⁡ i ( w i ) = ( N ∗ P ( j ) ) − β max ⁡ i ( ( N ∗ P ( i ) ) − β ) = ( P ( j ) ) − β max ⁡ i ( ( P ( i ) ) − β ) = ( P j min ⁡ i P ( i ) ) − β w_j=\frac{(N * P(j))^{-\beta}}{\max _i\left(w_i\right)}=\frac{(N * P(j))^{-\beta}}{\max _i\left((N * P(i))^{-\beta}\right)}=\frac{(P(j))^{-\beta}}{\max _i\left((P(i))^{-\beta}\right)}=\left(\frac{P_j}{\min _i P(i)}\right)^{-\beta} wj=maxi(wi)(NP(j))β=maxi((NP(i))β)(NP(j))β=maxi((P(i))β)(P(j))β=(miniP(i)Pj)β

五、PER代码

1 Prioritized Replay DQN 算法流程

算法输入:迭代轮数 T T T ,状态特征维度 n n n ,动作集 A A A ,步长 α \alpha α ,采样权重系数 β \beta β ,衰减因子 γ \gamma γ ,探索率 ϵ \epsilon ϵ ,当前 Q Q Q 网络 Q Q Q ,目标 Q \mathrm{Q} Q 网络 Q ′ Q^{\prime} Q ,批量梯度下降的样本数 m m m ,目标 Q \mathrm{Q} Q 网络参数更新频率 C C C ,SumTree的叶子节点数 S S S

输出: Q网络参数。

  1. 随机初始化所有的状态和动作对应的价值 Q Q Q. 随机初始化当前 Q \mathrm{Q} Q 网络的所有参数 w w w, 初始化目标Q网络 Q ′ Q^{\prime} Q 的参数 w ′ = w w^{\prime}=w w=w 。初始化经验回放SumTree 的默认数据结构,所有SumTree的 S S S 个叶子节点的优先级 p j p_j pj 为 1 。

  2. for i from 1 to T T T ,进行迭代。

    a) 初始化S为当前状态序列的第一个状态,得到其特征向量 ϕ ( S ) \phi(S) ϕ(S)

    b) 在 Q \mathrm{Q} Q 网络中使用 ϕ ( S ) \phi(S) ϕ(S) 作为输入,得到 Q \mathrm{Q} Q 网络的所有动作对应的 Q \mathrm{Q} Q 值输出。用 ϵ \epsilon ϵ 一贪婪法在当前 Q \mathrm{Q} Q 值输出中选择对应的动作 A A A

    c) 在状态 S S S 执行当前动作 A A A, 得到新状态 S ′ S^{\prime} S 对应的特征向量 ϕ ( S ′ ) \phi\left(S^{\prime}\right) ϕ(S) 和奖励 R R R, 是否终止状态is_end

    d) 将 { ϕ ( S ) , A , R , ϕ ( S ′ ) , i s − e n d } \left\{\phi(S), A, R, \phi\left(S^{\prime}\right), i s_{-} e n d\right\} {ϕ(S),A,R,ϕ(S),isend} 这个五元组存入SumTree

    e) S = S ′ S=S^{\prime} S=S

    f) 从SumTree中采样 m m m 个样本 { ϕ ( S j ) , A j , R j , ϕ ( S j ′ ) , i s − e n d j } , j = 1 , 2. , , , m \left\{\phi\left(S_j\right), A_j, R_j, \phi\left(S_j^{\prime}\right), i s_{-} e n d_j\right\}, j=1,2 .,,, m {ϕ(Sj),Aj,Rj,ϕ(Sj),isendj},j=1,2.,,,m,每个样本被采样的概率基于 P ( j ) = p j ∑ i ( p i ) P(j)=\frac{p_j}{\sum_i\left(p_i\right)} P(j)=i(pi)pj ,损失函数权重 w j = ( N ∗ P ( j ) ) − β / max ⁡ i ( w i ) w_j=(N * P(j))^{-\beta} / \max _i\left(w_i\right) wj=(NP(j))β/maxi(wi) ,计算当前目标Q值 y j y_j yj :

    y j = { R j  is end  j  is true  R j + γ Q ′ ( ϕ ( S j ′ ) , arg ⁡ max ⁡ a ′ Q ( ϕ ( S j ′ ) , a , w ) , w ′ )  is end  j  is false  y_j= \begin{cases}R_j & \text { is end }_j \text { is true } \\ R_j+\gamma Q^{\prime}\left(\phi\left(S_j^{\prime}\right), \arg \max _{a^{\prime}} Q\left(\phi\left(S_j^{\prime}\right), a, w\right), w^{\prime}\right) & \text { is end }_j \text { is false }\end{cases} yj={RjRj+γQ(ϕ(Sj),argmaxaQ(ϕ(Sj),a,w),w) is end j is true  is end j is false 

    g)使用均方差损失函数 1 m ∑ j = 1 m w j ( y j − Q ( ϕ ( S j ) , A j , w ) ) 2 \frac{1}{m} \sum_{j=1}^m w_j\left(y_j-Q\left(\phi\left(S_j\right), A_j, w\right)\right)^2 m1j=1mwj(yjQ(ϕ(Sj),Aj,w))2 ,通过神经网络的梯度反向传播来更新Q网络的所有参数 w w w

    h) 重新计算所有样本的TD误差 δ j = y j − Q ( ϕ ( S j ) , A j , w ) \delta_j=y_j-Q\left(\phi\left(S_j\right), A_j, w\right) δj=yjQ(ϕ(Sj),Aj,w),更新SumTree中所有节点的优先级 p j = ∣ δ j ∣ p_j=\left|\delta_j\right| pj=δj

    i) 如果 i % C = 1 \mathrm{i} \% \mathrm{C}=1 i%C=1, 则更新目标 Q \mathrm{Q} Q 网络参数 w ′ = w w^{\prime}=w w=w

    j) 如果 S ′ S^{\prime} S 是终止状态,当前轮迭代完毕,否则转到步骤b)

四、代码实现

原DQN代码在DQN、DuelingDQN、DDQN原理及代码实现,只在原DQN添加sum_tree、ReplayTree、store_transition以及learn部分等内容
完整代码实现:

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import gym
import wandb
import random
BATCH_SIZE=32
LR=0.01
EPSILON_START = 0.1  # epsilon 的初始值
EPSILON_END = 0.1  # epsilon 的最终值
EPSILON_DECAY = 1000 # epsilon 的衰减步数
GAMMA=0.9
TARGET_REPLACE_ITER=100
MEMORY_CAPACITY=2000
env=gym.make("CartPole-v1",render_mode="human") 
N_ACTIONS=env.action_space.n
N_STATES=env.observation_space.shape[0]

定义优先度采样要用的树结构

class SumTree:
    def __init__(self, capacity: int):
        # 初始化SumTree,设定容量
        self.capacity = capacity
        # 数据指针,指示下一个要存储数据的位置
        self.data_pointer = 0
        # 数据条目数
        self.n_entries = 0
        # 构建SumTree数组,长度为(2 * capacity - 1),用于存储树结构
        self.tree = np.zeros(2 * capacity - 1)
        # 数据数组,用于存储实际数据
        self.data = np.zeros(capacity, dtype=object)

    def update(self, tree_idx, p):#更新采样权重
        # 计算权重变化
        change = p - self.tree[tree_idx]
        # 更新树中对应索引的权重
        self.tree[tree_idx] = p

        # 从更新的节点开始向上更新,直到根节点
        while tree_idx != 0:
            tree_idx = (tree_idx - 1) // 2
            self.tree[tree_idx] += change

    def add(self, p, data):#向SumTree中添加新数据
        # 计算数据存储在树中的索引
        tree_idx = self.data_pointer + self.capacity - 1
        # 存储数据到数据数组中
        self.data[self.data_pointer] = data
        # 更新对应索引的树节点权重
        self.update(tree_idx, p)

        # 移动数据指针,循环使用存储空间
        self.data_pointer += 1
        if self.data_pointer >= self.capacity:
            self.data_pointer = 0

        # 维护数据条目数
        if self.n_entries < self.capacity:
            self.n_entries += 1

    def get_leaf(self, v):#采样数据
        # 从根节点开始向下搜索,直到找到叶子节点
        parent_idx = 0
        while True:
            cl_idx = 2 * parent_idx + 1
            cr_idx = cl_idx + 1
            # 如果左子节点超出范围,则当前节点为叶子节点
            if cl_idx >= len(self.tree):
                leaf_idx = parent_idx
                break
            else:
                # 根据采样值确定向左还是向右子节点移动
                if v <= self.tree[cl_idx]:
                    parent_idx = cl_idx
                else:
                    v -= self.tree[cl_idx]
                    parent_idx = cr_idx

        # 计算叶子节点在数据数组中的索引
        data_idx = leaf_idx - self.capacity + 1
        return leaf_idx, self.tree[leaf_idx], self.data[data_idx]

    def total(self):
        return int(self.tree[0])

定义PER的ReplayTree

class ReplayTree:#ReplayTree for the per(Prioritized Experience Replay) DQN. 
    def __init__(self, capacity):
        self.capacity = capacity # 记忆回放的容量
        self.tree = SumTree(capacity)  # 创建一个SumTree实例
        self.abs_err_upper = 1.  # 绝对误差上限
        self.epsilon = 0.01
        ## 用于计算重要性采样权重的超参数
        self.beta_increment_per_sampling = 0.001
        self.alpha = 0.6
        self.beta = 0.4 
        self.abs_err_upper = 1.

    def __len__(self):# 返回存储的样本数量
        return self.tree.total()

    def push(self, error, sample):#Push the sample into the replay according to the importance sampling weight
        p = (np.abs(error.detach().numpy()) +self.epsilon) ** self.alpha
        self.tree.add(p, sample)         


    def sample(self, batch_size):
        pri_segment = self.tree.total() / batch_size

        priorities = []
        batch = []
        idxs = []

        is_weights = []

        self.beta = np.min([1., self.beta + self.beta_increment_per_sampling])
        min_prob = np.min(self.tree.tree[-self.tree.capacity:]) / self.tree.total() 

        for i in range(batch_size):
            a = pri_segment * i
            b = pri_segment * (i+1)

            s = random.uniform(a, b)
            idx, p, data = self.tree.get_leaf(s)

            priorities.append(p)
            batch.append(data)
            idxs.append(idx)
            prob = p / self.tree.total()

        sampling_probabilities = np.array(priorities) / self.tree.total()
        is_weights = np.power(self.tree.n_entries * sampling_probabilities, -self.beta)
        is_weights /= is_weights.max()

        return zip(*batch), idxs, is_weights
    
    def batch_update(self, tree_idx, abs_errors):#Update the importance sampling weight
        abs_errors += self.epsilon

        clipped_errors = np.minimum(abs_errors, self.abs_err_upper)
        ps = np.power(clipped_errors, self.alpha)

        for ti, p in zip(tree_idx, ps):
            self.tree.update(ti, p)

拟合Q值的网络

class Net(nn.Module):
    def __init__(self):
        super(Net,self).__init__()
        self.fc1=nn.Linear(N_STATES,50)
        self.fc1.weight.data.normal_(0,0.1)
        self.out=nn.Linear(50,N_ACTIONS)
        self.out.weight.data.normal_(0,0.1)

    def forward(self,x):
        x=F.relu(self.fc1(x))
        actions_value=self.out(x)
        return actions_value

定义DQN算法

class DQN(object):
    def __init__(self):
        self.eval_net,self.target_net=Net(),Net()
        self.learn_step_counter=0
        self.memory_counter=0      
        self.epsilon = EPSILON_START
        self.memory=ReplayTree(capacity=MEMORY_CAPACITY)
        self.optimizer=torch.optim.Adam(self.eval_net.parameters(),lr=LR)
        self.loss_func=nn.MSELoss()

    def choose_action(self,x):
        x=torch.unsqueeze(torch.FloatTensor(x),0)
        if np.random.uniform()<1-self.epsilon:#greedy
            actions_value=self.eval_net.forward(x)
            action=torch.max(actions_value,1)[1].data.numpy()
            action=action[0]

        else:#随机
            action=np.random.randint(0,N_ACTIONS)
        return action

    def store_transition(self,s,a,r,s_,done):
        policy_val =self.eval_net(torch.tensor(s))[a]
        target_val =self.target_net(torch.tensor(s_))
        transition = (s, a, r, s_)

        if done:
            error = abs(policy_val-r)
        else:
            error = abs(policy_val-r-GAMMA*torch.max(target_val))
        self.memory.push(error, transition)  # 添加经验和初始优先级
        self.memory_counter += 1

    def learn(self):
        if self.learn_step_counter%TARGET_REPLACE_ITER==0:
            self.target_net.load_state_dict(self.eval_net.state_dict())
        self.learn_step_counter+=1

        batch, tree_idx,is_weights = self.memory.sample(BATCH_SIZE)
        b_s,b_a,b_r,b_s_=batch
        b_s = torch.FloatTensor(b_s)
        b_a = torch.unsqueeze(torch.LongTensor(b_a), 1)
        b_r = torch.unsqueeze(torch.FloatTensor(b_r),1)
        b_s_ = torch.FloatTensor(b_s_)
        q_eval = self.eval_net(b_s).gather(1, b_a)
        q_next = self.target_net(b_s_).detach()
        q_target = b_r + GAMMA * q_next.max(1)[0].view(BATCH_SIZE, 1)
        loss = (torch.FloatTensor(is_weights) * self.loss_func(q_eval, q_target)).mean()

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        abs_errors = torch.abs(q_eval - q_target).detach().numpy().squeeze()
        self.memory.batch_update(tree_idx, abs_errors)  # 更新经验的优先级
        # 更新 epsilon
        self.epsilon = EPSILON_END + (EPSILON_START - EPSILON_END) * \
            np.exp(-1.0 * self.learn_step_counter / EPSILON_DECAY) if EPSILON_END + (EPSILON_START - EPSILON_END) * \
            np.exp(-1.0 * self.learn_step_counter / EPSILON_DECAY) > EPSILON_END else EPSILON_END
        self.learn_step_counter += 1

训练

dqn=DQN()
for i in range(100):
    print('<<<<<<<Episode:%s'%i)
    s=env.reset()
    episode_reward_sum=0

    while True:
        a=dqn.choose_action(s)
        s_,r,done,info=env.step(a)
        x,x_dot,theta,theta_dot=s_
        r1=(env.x_threshold-abs(x))/env.x_threshold-0.8
        r2=(env.theta_threshold_radians-abs(theta))/env.theta_threshold_radians-0.5
        new_r=r1+r2

        dqn.store_transition(s,a,new_r,s_,done)
        episode_reward_sum+=new_r

        s=s_

        if dqn.memory_counter>BATCH_SIZE:
            dqn.learn()

        if done:
            print('episode%s---reward_sum:%s'%(i,round(episode_reward_sum,2)))
            break

训练结果,学习速度明显优于其他几个DQN改进算法
在这里插入图片描述

期待宝贵意见!原理部分来自easy-rl-master,感谢!

  • 31
    点赞
  • 66
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值