This post explains Prioritized Replay DQN (DQN with prioritized experience replay).
The Problem It Solves
The earlier DQN variants share a weakness: every transition in the experience replay buffer is sampled with equal probability. That is a problem, and one of the reasons convergence is slow. We should focus on transitions whose absolute TD error is large, rather than on transitions the network already fits well.
Algorithm Basics
Recall the loss function we used before:

$\frac{1}{m}\sum\limits_{j=1}^m\left(y_j-Q(\phi(S_j),A_j,w)\right)^2$
The TD error follows directly from it:

$\delta_j = y_j - Q(\phi(S_j),A_j,w)$

The priority of a transition is then $p_j = |\delta_j|$: the larger $p_j$ is, the more likely the transition is to be sampled.
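In practice the absolute TD error is rarely used as the priority completely raw. As in the Memory class later in this post, a small constant keeps every priority strictly positive and an exponent alpha controls how aggressive the prioritization is. A minimal sketch (the TD error values are made up for illustration):

import numpy as np

td_errors = np.array([0.5, -2.0, 0.1, 3.2])   # hypothetical TD errors delta_j

PER_EPSILON = 0.01   # small offset so no transition ends up with zero priority
PER_ALPHA = 0.6      # alpha = 0 gives uniform sampling, alpha = 1 is fully prioritized

# priorities p_j that would become the SumTree leaf values
priorities = np.power(np.abs(td_errors) + PER_EPSILON, PER_ALPHA)
print(priorities)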
Because of this, we can no longer store transitions in a plain collection as before; instead we use a data structure called a SumTree.
SumTree Explained
Since we need to compare transitions by priority, and a larger $p_j$ means a higher chance of being sampled, we can treat each priority as the length of an interval and sample a point inside the combined interval. An example:
Suppose there are three transitions data1, data2 and data3 with priorities 1, 4 and 8. They correspond to intervals of length 1, 4 and 8, so the total interval has length 13. To sample, we draw a single number from [0, 13]: data1 owns [0, 1], data2 owns [1, 5] and data3 owns [5, 13], and whichever interval the draw falls into, that transition is chosen.
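Before introducing the tree, the interval idea itself can be checked with a few lines of NumPy (a quick sketch, not the SumTree):

import numpy as np

priorities = np.array([1.0, 4.0, 8.0])   # data1, data2, data3
edges = np.cumsum(priorities)            # interval boundaries: [1, 5, 13]

v = np.random.uniform(0, edges[-1])      # draw a number in [0, 13)
chosen = np.searchsorted(edges, v)       # 0 -> data1, 1 -> data2, 2 -> data3
print(v, 'falls into data', chosen + 1)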
Scaling this up, a SumTree organizes the priorities as a binary tree: internal nodes hold no data, only the leaf nodes do. Take a tree with 8 leaf nodes (the stored transitions), where each leaf's value is its priority and also the length of the sub-interval it owns (a first leaf with priority 3 owns [0, 3), for example); the largest leaf here holds 12. Every internal node stores the sum of its two children, so the root holds the total interval length, 42, and sampling means drawing a number from [0, 42].
Suppose we draw 24. At the root we compare with the left child, which holds 29: 24 lies within [0, 29], so we go left (read each node's value as the length of the interval it covers, with the left child covering the left sub-interval and the right child the right one). At the node holding 29, the left sub-interval is [0, 13] and the right is [13, 29]; 24 > 13, so it is not in the left interval and we go right, subtracting 13 to get 24 - 13 = 11, the position relative to the right sub-interval. Next we reach the node holding 16 and compare again: 11 < 12, so we enter the left child, the leaf with priority 12, and that transition is the one we sample.
Here is the SumTree implementation; have a look:
import numpy as np


class SumTree(object):
    """
    This SumTree code is a modified version and the original code is from:
    https://github.com/jaara/AI-blog/blob/master/SumTree.py
    Store data with its priority in the tree.
    """
    data_pointer = 0

    def __init__(self, capacity):
        self.capacity = capacity  # number of leaf nodes (stored transitions)
        self.tree = np.zeros(2 * capacity - 1)
        # [--------------Parent nodes-------------][-------leaves to record priority-------]
        #             size: capacity - 1                       size: capacity
        self.data = np.zeros(capacity, dtype=object)  # for all transitions
        # [--------------data frame-------------]
        #             size: capacity

    def add(self, p, data):
        tree_idx = self.data_pointer + self.capacity - 1
        self.data[self.data_pointer] = data  # update data_frame
        self.update(tree_idx, p)  # update tree_frame
        self.data_pointer += 1
        if self.data_pointer >= self.capacity:  # replace the oldest entry when exceeding the capacity
            self.data_pointer = 0

    def update(self, tree_idx, p):
        change = p - self.tree[tree_idx]
        self.tree[tree_idx] = p
        # then propagate the change up through the tree
        while tree_idx != 0:  # this method is faster than the recursive loop in the reference code
            tree_idx = (tree_idx - 1) // 2
            self.tree[tree_idx] += change

    def get_leaf(self, v):
        """
        Tree structure and array storage:
        Tree index:
             0         -> storing priority sum
            / \
          1     2
         / \   / \
        3   4 5   6    -> storing priority for transitions
        Array type for storing:
        [0, 1, 2, 3, 4, 5, 6]
        """
        parent_idx = 0
        while True:  # the while loop is faster than the method in the reference code
            cl_idx = 2 * parent_idx + 1  # this node's left and right children
            cr_idx = cl_idx + 1
            if cl_idx >= len(self.tree):  # reached the bottom, end the search
                leaf_idx = parent_idx
                break
            else:  # downward search, always descend towards the interval containing v
                if v <= self.tree[cl_idx]:
                    parent_idx = cl_idx
                else:
                    v -= self.tree[cl_idx]
                    parent_idx = cr_idx

        data_idx = leaf_idx - self.capacity + 1
        return leaf_idx, self.tree[leaf_idx], self.data[data_idx]

    @property
    def total_p(self):
        return self.tree[0]  # the root
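To tie this back to the walkthrough above, here is a quick usage sketch. The leaf priorities are hypothetical, but they are chosen so that the internal sums reproduce the example: the root is 42, its children are 29 and 13, and the priority-12 leaf sits under the node holding 16:

tree = SumTree(capacity=8)
for p, name in zip([3, 10, 12, 4, 1, 2, 8, 2],
                   ['d1', 'd2', 'd3', 'd4', 'd5', 'd6', 'd7', 'd8']):
    tree.add(p, name)

print(tree.total_p)          # 42.0, the root holds the total priority mass
leaf_idx, p, data = tree.get_leaf(24)
print(leaf_idx, p, data)     # ends in the leaf with priority 12, i.e. 'd3'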
The idea is quite simple. With that covered, the next thing to look at is the change to the loss function. Previously the Q network's loss was

$\frac{1}{m}\sum\limits_{j=1}^m\left(y_j-Q(\phi(S_j),A_j,w)\right)^2$

With prioritized replay we make a small change and weight each sample according to its priority, so the loss becomes:

$\frac{1}{m}\sum\limits_{j=1}^m w_j\left(y_j-Q(\phi(S_j),A_j,w)\right)^2$
The extra factor $w_j$ is an importance-sampling weight derived from the priority $p_j$:

$P(j) = \frac{p_j}{\sum\limits_i p_i}$

$w_j = \frac{(N \cdot P(j))^{-\beta}}{\max_i(w_i)}$

Here $N$ is the size of the replay memory, and the importance-sampling exponent $\beta$ is a hyperparameter you choose yourself (in the code below it starts at 0.4 and is annealed towards 1).
The weight can equivalently be computed as:

$w_j = \frac{(N \cdot P(j))^{-\beta}}{\max_i(w_i)} = \frac{(N \cdot P(j))^{-\beta}}{\max_i\left((N \cdot P(i))^{-\beta}\right)} = \frac{P(j)^{-\beta}}{\max_i\left(P(i)^{-\beta}\right)} = \left(\frac{P(j)}{\min_i P(i)}\right)^{-\beta}$

since $\max_i P(i)^{-\beta} = \left(\min_i P(i)\right)^{-\beta}$. This last form is exactly the prob / min_prob computation used in the sampling code below.
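A minimal numeric sketch of the two formulas above (the priorities are made up; the last line reproduces the prob / min_prob form used in Memory.sample below):

import numpy as np

priorities = np.array([1.0, 4.0, 8.0, 2.0])   # hypothetical p_j
beta = 0.4

P = priorities / priorities.sum()             # P(j) = p_j / sum_i p_i
w = np.power(len(P) * P, -beta)               # (N * P(j))^-beta
w /= w.max()                                  # divide by max_i w_i

w_alt = np.power(P / P.min(), -beta)          # simplified form (P(j) / min_i P(i))^-beta
print(np.allclose(w, w_alt))                  # True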
Algorithm Flow
Algorithm inputs: number of iterations $T$, state feature dimension $n$, action set $A$, step size $\alpha$, importance-sampling exponent $\beta$, discount factor $\gamma$, exploration rate $\epsilon$, current Q network $Q$, target Q network $Q'$, minibatch size $m$ for gradient descent, target network update frequency $C$, and number of SumTree leaf nodes $S$.
Output: the Q network parameters.
1. Randomly initialize the values $Q$ of all state-action pairs, i.e. randomly initialize all parameters $w$ of the current Q network, and initialize the target Q network $Q'$ with $w' = w$. Initialize the SumTree used for experience replay, setting the priority $p_j$ of all $S$ leaf nodes to 1.
2. For i from 1 to T, iterate:
a) Initialize $S$ as the first state of the current episode and obtain its feature vector $\phi(S)$.
b) Feed $\phi(S)$ into the $Q$ network to obtain the Q values of all actions, then use the $\epsilon$-greedy policy to select an action $A$ from that output.
c) Execute action $A$ in state $S$, obtaining the new state $S'$ with feature vector $\phi(S')$, the reward $R$, and the termination flag is_end.
d) Store the five-tuple $\{\phi(S), A, R, \phi(S'), \text{is\_end}\}$ in the SumTree.
e) $S = S'$
f) Sample m transitions $\{\phi(S_j), A_j, R_j, \phi(S'_j), \text{is\_end}_j\}$, $j = 1, 2, \ldots, m$ from the SumTree, where each transition is sampled with probability $P(j) = \frac{p_j}{\sum\limits_i p_i}$ and receives the loss weight $w_j = (N \cdot P(j))^{-\beta}/\max_i(w_i)$. Compute the current target Q values $y_j$:

$y_j= \begin{cases} R_j & \text{is\_end}_j\ \text{is true}\\ R_j + \gamma Q'\left(\phi(S'_j),\arg\max_{a'}Q(\phi(S'_j),a',w),w'\right) & \text{is\_end}_j\ \text{is false} \end{cases}$
g) Use the weighted mean-squared-error loss $\frac{1}{m}\sum\limits_{j=1}^m w_j\left(y_j-Q(\phi(S_j),A_j,w)\right)^2$ and update all parameters $w$ of the Q network by backpropagating its gradient through the network.
h) Recompute the TD error $\delta_j = y_j - Q(\phi(S_j),A_j,w)$ of every sampled transition and update the corresponding priorities $p_j = |\delta_j|$ in the SumTree.
i) If i % C == 1, update the target Q network parameters $w' = w$.
j) If $S'$ is a terminal state, the current episode ends; otherwise go back to step b).
Code
Next, a PyTorch version written with reference to the TensorFlow code:
# -*- coding: utf-8 -*-
"""
Created on Fri Dec 6 15:46:28 2019
@author: asus
"""
import gym
import torch
import torch.nn.functional as F
import numpy as np
import random
GAMMA = 0.9                 # discount factor for TD targets
INITIAL_EPSILON = 0.5       # starting epsilon for epsilon-greedy exploration
FINAL_EPSILON = 0.01        # final epsilon after decay
REPLAY_SIZE = 10000         # capacity of the prioritized replay memory (SumTree leaves)
BATCH_SIZE = 32             # minibatch size for training
ENV_NAME = 'CartPole-v0'
EPISODE = 3000              # Episode limitation
STEP = 300                  # Step limitation in an episode
TEST = 10                   # Number of evaluation episodes run every 100 training episodes
class SumTree(object):
    """
    This SumTree code is a modified version and the original code is from:
    https://github.com/jaara/AI-blog/blob/master/SumTree.py
    Store data with its priority in the tree.
    """
    data_pointer = 0

    def __init__(self, capacity):
        self.capacity = capacity  # number of leaf nodes (stored transitions)
        self.tree = np.zeros(2 * capacity - 1)
        # [--------------Parent nodes-------------][-------leaves to record priority-------]
        #             size: capacity - 1                       size: capacity
        self.data = np.zeros(capacity, dtype=object)  # for all transitions
        # [--------------data frame-------------]
        #             size: capacity

    def add(self, p, data):
        tree_idx = self.data_pointer + self.capacity - 1
        self.data[self.data_pointer] = data  # update data_frame
        self.update(tree_idx, p)  # update tree_frame
        self.data_pointer += 1
        if self.data_pointer >= self.capacity:  # replace the oldest entry when exceeding the capacity
            self.data_pointer = 0

    def update(self, tree_idx, p):
        change = p - self.tree[tree_idx]
        self.tree[tree_idx] = p
        # then propagate the change up through the tree
        while tree_idx != 0:  # this method is faster than the recursive loop in the reference code
            tree_idx = (tree_idx - 1) // 2
            self.tree[tree_idx] += change

    def get_leaf(self, v):
        """
        Tree structure and array storage:
        Tree index:
             0         -> storing priority sum
            / \
          1     2
         / \   / \
        3   4 5   6    -> storing priority for transitions
        Array type for storing:
        [0, 1, 2, 3, 4, 5, 6]
        """
        parent_idx = 0
        while True:  # the while loop is faster than the method in the reference code
            cl_idx = 2 * parent_idx + 1  # this node's left and right children
            cr_idx = cl_idx + 1
            if cl_idx >= len(self.tree):  # reached the bottom, end the search
                leaf_idx = parent_idx
                break
            else:  # downward search, always descend towards the interval containing v
                if v <= self.tree[cl_idx]:
                    parent_idx = cl_idx
                else:
                    v -= self.tree[cl_idx]
                    parent_idx = cr_idx

        data_idx = leaf_idx - self.capacity + 1
        return leaf_idx, self.tree[leaf_idx], self.data[data_idx]

    @property
    def total_p(self):
        return self.tree[0]  # the root
class Memory(object):  # stored as ( s, a, r, s_ ) in SumTree
    """
    This Memory class is modified based on the original code from:
    https://github.com/jaara/AI-blog/blob/master/Seaquest-DDQN-PER.py
    """
    epsilon = 0.01  # small amount to avoid zero priority
    alpha = 0.6  # [0~1] convert the importance of TD error to priority
    beta = 0.4  # importance-sampling, from initial value increasing to 1
    beta_increment_per_sampling = 0.001
    abs_err_upper = 1.  # clipped abs error

    def __init__(self, capacity):
        self.tree = SumTree(capacity)

    def store(self, transition):
        max_p = np.max(self.tree.tree[-self.tree.capacity:])
        if max_p == 0:
            max_p = self.abs_err_upper
        self.tree.add(max_p, transition)  # new transitions get the current max priority

    def sample(self, n):
        b_idx, b_memory, ISWeights = np.empty((n,), dtype=np.int32), np.empty((n, self.tree.data[0].size)), np.empty((n, 1))
        pri_seg = self.tree.total_p / n  # priority segment: split the total priority mass into n equal slices
        # beta is annealed towards 1 over time
        self.beta = np.min([1., self.beta + self.beta_increment_per_sampling])  # max = 1
        # smallest sampling probability, used to normalize the IS weights
        min_prob = np.min(self.tree.tree[-self.tree.capacity:]) / self.tree.total_p
        if min_prob == 0:
            min_prob = 0.00001
        for i in range(n):
            a, b = pri_seg * i, pri_seg * (i + 1)
            # draw uniformly inside this segment
            v = np.random.uniform(a, b)
            # p is the priority of the sampled leaf
            idx, p, data = self.tree.get_leaf(v)
            prob = p / self.tree.total_p
            ISWeights[i, 0] = np.power(prob / min_prob, -self.beta)
            b_idx[i], b_memory[i, :] = idx, data
        return b_idx, b_memory, ISWeights

    def batch_update(self, tree_idx, abs_errors):
        abs_errors += self.epsilon  # avoid zero priority
        clipped_errors = np.minimum(abs_errors.detach().numpy(), self.abs_err_upper)
        ps = np.power(clipped_errors, self.alpha)
        for ti, p in zip(tree_idx, ps):
            self.tree.update(ti, p)
class MODEL(torch.nn.Module):
    def __init__(self, env):
        super(MODEL, self).__init__()
        self.state_dim = env.observation_space.shape[0]
        self.action_dim = env.action_space.n
        self.fc1 = torch.nn.Linear(self.state_dim, 20)
        self.fc1.weight.data.normal_(0, 0.6)
        self.fc2 = torch.nn.Linear(20, self.action_dim)

    def create_Q_network(self, x):
        x = F.relu(self.fc1(x))
        Q_value = self.fc2(x)
        return Q_value

    def forward(self, x, action_input):
        Q_value = self.create_Q_network(x)
        Q_action = torch.mul(Q_value, action_input).sum(dim=1)
        return Q_action
class DQN():
    def __init__(self, env):
        self.replay_total = 0
        self.target_Q_net = MODEL(env)
        self.current_Q_net = MODEL(env)
        self.memory = Memory(capacity=REPLAY_SIZE)
        self.time_step = 0
        self.epsilon = INITIAL_EPSILON
        self.optimizer = torch.optim.Adam(params=self.current_Q_net.parameters(), lr=0.0001)

    def store_transition(self, s, a, r, s_, done):
        transition = np.hstack((s, a, r, s_, done))
        self.memory.store(transition)

    def perceive(self, state, action, reward, next_state, done):
        one_hot_action = np.zeros(self.current_Q_net.action_dim)
        one_hot_action[action] = 1
        self.store_transition(state, one_hot_action, reward, next_state, done)
        self.replay_total += 1
        if self.replay_total > BATCH_SIZE:
            self.train_Q_network()

    def train_Q_network(self):
        self.time_step += 1
        # Step 1: obtain a prioritized minibatch from replay memory
        tree_idx, minibatch, ISWeights = self.memory.sample(BATCH_SIZE)
        # transition layout: state (0:4), one-hot action (4:6), reward (6), next state (7:11), done (11)
        state_batch = torch.tensor(minibatch[:, 0:4], dtype=torch.float32)
        action_batch = torch.tensor(minibatch[:, 4:6], dtype=torch.float32)
        reward_batch = [data[6] for data in minibatch]
        next_state_batch = torch.tensor(minibatch[:, 7:11], dtype=torch.float32)
        # Step 2: calculate y (current net selects the action, target net evaluates it)
        y_batch = []
        current_a = self.current_Q_net.create_Q_network(next_state_batch)
        max_current_action_batch = torch.argmax(current_a, axis=1)
        Q_value_batch = self.target_Q_net.create_Q_network(next_state_batch)
        for i in range(0, BATCH_SIZE):
            done = minibatch[i][11]
            if done:
                y_batch.append(reward_batch[i])
            else:
                max_current_action = max_current_action_batch[i]
                # .item() keeps the target a plain float so no gradient flows into the target net
                y_batch.append(reward_batch[i] + GAMMA * Q_value_batch[i, max_current_action].item())
        y = self.current_Q_net(state_batch, action_batch)
        y_batch = torch.FloatTensor(y_batch)
        # ISWeights has shape (batch, 1); flatten it so the weighting stays element-wise
        cost = self.loss(y, y_batch, torch.FloatTensor(ISWeights).view(-1))
        self.optimizer.zero_grad()
        cost.backward()
        self.optimizer.step()
        # recompute the Q values after the update to refresh the TD errors / priorities
        y = self.current_Q_net(state_batch, action_batch)
        abs_errors = torch.abs(y_batch - y)
        self.memory.batch_update(tree_idx, abs_errors)

    def egreedy_action(self, state):
        Q_value = self.current_Q_net.create_Q_network(torch.FloatTensor(state))
        if random.random() <= self.epsilon:
            self.epsilon -= (INITIAL_EPSILON - FINAL_EPSILON) / 10000
            return random.randint(0, self.current_Q_net.action_dim - 1)
        else:
            self.epsilon -= (INITIAL_EPSILON - FINAL_EPSILON) / 10000
            return torch.argmax(Q_value).item()

    def action(self, state):
        return torch.argmax(self.target_Q_net.create_Q_network(torch.FloatTensor(state))).item()

    def update_target_params(self):
        torch.save(self.current_Q_net.state_dict(), 'net_params.pkl')
        self.target_Q_net.load_state_dict(torch.load('net_params.pkl'))

    def loss(self, y_output, y_true, ISWeights):
        # importance-sampling weighted mean squared error
        value = y_output - y_true
        return torch.mean(value * value * ISWeights)
def main():
    # initialize the OpenAI Gym env and the DQN agent
    env = gym.make(ENV_NAME)
    agent = DQN(env)
    for episode in range(EPISODE):
        # initialize task
        state = env.reset()
        # Train
        for step in range(STEP):
            action = agent.egreedy_action(state)  # e-greedy action for training
            next_state, reward, done, _ = env.step(action)
            # Define reward for agent
            reward = -1 if done else 0.1
            agent.perceive(state, action, reward, next_state, done)
            state = next_state
            if done:
                break
        # Test every 100 episodes
        if episode % 100 == 0:
            total_reward = 0
            for i in range(TEST):
                state = env.reset()
                for j in range(STEP):
                    # env.render()
                    action = agent.action(state)  # greedy action for evaluation
                    state, reward, done, _ = env.step(action)
                    total_reward += reward
                    if done:
                        break
            ave_reward = total_reward / TEST
            print('episode: ', episode, 'Evaluation Average Reward:', ave_reward)
            agent.update_target_params()


if __name__ == '__main__':
    main()
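One caveat if you want to run this yourself: the script targets the classic Gym API (CartPole-v0, with env.reset() returning just the observation and env.step() returning four values). On recent Gym (>= 0.26) and Gymnasium releases the interface changed, so the main loop would need a small adaptation along these lines (a sketch, not part of the original code):

# state, _ = env.reset()                                     # reset() now returns (obs, info)
# obs, reward, terminated, truncated, _ = env.step(action)   # step() now returns five values
# done = terminated or truncated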