This post mainly explains the structure of Dueling DQN.
The problem it solves
Compared with the earlier DQN, Dueling DQN mainly optimizes the network structure. It splits the Q-network into two parts. The first part depends only on the state S and is independent of the action A to be taken; we call this the value function part, written V(S,w,α). The second part depends on both the state S and the action A; this part is called the advantage function part, written A(S,A,w,β). The action-value function can then be rewritten as:
Q(S,A,w,\alpha,\beta) = V(S,w,\alpha) + A(S,A,w,\beta)
Here w denotes the parameters of the shared part of the network, α the parameters unique to the value-function stream, and β the parameters unique to the advantage-function stream. The network structure is shown below:
The left side is the original structure; the right side is the modified structure.
We could directly use the combination formula above to obtain the action value, but that expression cannot identify the respective contributions of V(S,w,α) and A(S,A,w,β) in the final output: adding a constant to V(S,w,α) and subtracting the same constant from A(S,A,w,β) yields exactly the same Q. To make the decomposition identifiable (identifiability), the combination formula actually used is:
Q(S,A,w,\alpha,\beta) = V(S,w,\alpha) + \left(A(S,A,w,\beta) - \frac{1}{|\mathcal{A}|}\sum\limits_{a' \in \mathcal{A}} A(S,a',w,\beta)\right)
Identifiability: identifiability means that the components of a model can be uniquely determined from its input-output behavior. Here, subtracting the mean advantage forces the advantage stream to have zero mean over actions, so V and A each have a well-defined value instead of being determined only up to a constant offset.
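As a minimal PyTorch sketch of just the dueling head and this mean-centered aggregation (the layer sizes and names below are illustrative assumptions, not the exact network used later):

import torch
import torch.nn as nn
import torch.nn.functional as F

class DuelingHead(nn.Module):
    # Illustrative dueling head: a shared layer (w), a value stream (alpha) and an advantage stream (beta)
    def __init__(self, state_dim, action_dim, hidden=20):
        super().__init__()
        self.fc1 = nn.Linear(state_dim, hidden)
        self.value_fc = nn.Linear(hidden, 1)
        self.advantage_fc = nn.Linear(hidden, action_dim)

    def forward(self, s):
        x = F.relu(self.fc1(s))
        v = self.value_fc(x)                           # V(S,w,alpha), shape (batch, 1)
        a = self.advantage_fc(x)                       # A(S,A,w,beta), shape (batch, action_dim)
        return v + (a - a.mean(dim=-1, keepdim=True))  # Q = V + (A - mean(A))

# quick shape check: a batch of 3 CartPole-like states (4-dimensional), 2 actions
q = DuelingHead(state_dim=4, action_dim=2)(torch.randn(3, 4))
print(q.shape)  # torch.Size([3, 2])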
Code
What follows is PyTorch code written with reference to the TensorFlow version. I wrote it on top of my Prioritized Replay DQN code, while the TensorFlow version was built on top of Nature DQN.
# -*- coding: utf-8 -*-
"""
Created on Tue Dec 17 10:59:02 2019
@author: asus
"""
import gym
import torch
import torch.nn.functional as F
import numpy as np
import random
GAMMA = 0.9
INITIAL_EPSILON = 0.5
FINAL_EPSILON = 0.01
REPLAY_SIZE = 10000
BATCH_SIZE = 32
ENV_NAME = 'CartPole-v0'
EPISODE = 3000 # Episode limitation
STEP = 300 # Step limitation in an episode
TEST = 10 # Number of evaluation episodes run every 100 training episodes
class SumTree(object):
    """
    This SumTree code is a modified version and the original code is from:
    https://github.com/jaara/AI-blog/blob/master/SumTree.py
    Store data with its priority in the tree.
    """
    data_pointer = 0

    def __init__(self, capacity):
        self.capacity = capacity  # for all priority values
        self.tree = np.zeros(2 * capacity - 1)
        # [--------------Parent nodes-------------][-------leaves to record priority-------]
        #             size: capacity - 1                       size: capacity
        self.data = np.zeros(capacity, dtype=object)  # for all transitions
        # [--------------data frame-------------]
        #             size: capacity

    def add(self, p, data):
        tree_idx = self.data_pointer + self.capacity - 1
        self.data[self.data_pointer] = data  # update data_frame
        self.update(tree_idx, p)  # update tree_frame
        self.data_pointer += 1
        if self.data_pointer >= self.capacity:  # replace when exceed the capacity
            self.data_pointer = 0

    def update(self, tree_idx, p):
        change = p - self.tree[tree_idx]
        self.tree[tree_idx] = p
        # then propagate the change through tree
        while tree_idx != 0:  # this method is faster than the recursive loop in the reference code
            tree_idx = (tree_idx - 1) // 2
            self.tree[tree_idx] += change

    def get_leaf(self, v):
        """
        Tree structure and array storage:
        Tree index:
             0         -> storing priority sum
            / \
          1     2
         / \   / \
        3   4 5   6    -> storing priority for transitions
        Array type for storing:
        [0, 1, 2, 3, 4, 5, 6]
        """
        parent_idx = 0
        while True:  # the while loop is faster than the method in the reference code
            cl_idx = 2 * parent_idx + 1  # this leaf's left and right kids
            cr_idx = cl_idx + 1
            if cl_idx >= len(self.tree):  # reach bottom, end search
                leaf_idx = parent_idx
                break
            else:  # downward search, always search for a higher priority node
                if v <= self.tree[cl_idx]:
                    parent_idx = cl_idx
                else:
                    v -= self.tree[cl_idx]
                    parent_idx = cr_idx
        data_idx = leaf_idx - self.capacity + 1
        return leaf_idx, self.tree[leaf_idx], self.data[data_idx]

    @property
    def total_p(self):
        return self.tree[0]  # the root
class Memory(object):  # stored as ( s, a, r, s_ ) in SumTree
    """
    This Memory class is modified based on the original code from:
    https://github.com/jaara/AI-blog/blob/master/Seaquest-DDQN-PER.py
    """
    epsilon = 0.01  # small amount to avoid zero priority
    alpha = 0.6  # [0~1] convert the importance of TD error to priority
    beta = 0.4  # importance-sampling, from initial value increasing to 1
    beta_increment_per_sampling = 0.001
    abs_err_upper = 1.  # clipped abs error

    def __init__(self, capacity):
        self.tree = SumTree(capacity)

    def store(self, transition):
        max_p = np.max(self.tree.tree[-self.tree.capacity:])
        if max_p == 0:
            max_p = self.abs_err_upper
        self.tree.add(max_p, transition)  # set the max p for new p

    def sample(self, n):
        b_idx, b_memory, ISWeights = np.empty((n,), dtype=np.int32), np.empty((n, self.tree.data[0].size)), np.empty((n, 1))
        pri_seg = self.tree.total_p / n  # priority segment: split the total priority into n equal intervals
        # beta anneals toward 1 as sampling goes on
        self.beta = np.min([1., self.beta + self.beta_increment_per_sampling])  # max = 1
        # minimum sampling probability, used to normalize the IS weights
        min_prob = np.min(self.tree.tree[-self.tree.capacity:]) / self.tree.total_p  # for later calculate ISweight
        if min_prob == 0:
            min_prob = 0.00001
        for i in range(n):
            a, b = pri_seg * i, pri_seg * (i + 1)
            # sample uniformly within this segment
            v = np.random.uniform(a, b)
            # p is the priority of the sampled leaf
            idx, p, data = self.tree.get_leaf(v)
            prob = p / self.tree.total_p
            ISWeights[i, 0] = np.power(prob / min_prob, -self.beta)
            b_idx[i], b_memory[i, :] = idx, data
        return b_idx, b_memory, ISWeights

    def batch_update(self, tree_idx, abs_errors):
        abs_errors += self.epsilon  # convert to abs and avoid 0
        clipped_errors = np.minimum(abs_errors.detach().numpy(), self.abs_err_upper)
        ps = np.power(clipped_errors, self.alpha)
        for ti, p in zip(tree_idx, ps):
            self.tree.update(ti, p)
class MODEL(torch.nn.Module):
    def __init__(self, env):
        super(MODEL, self).__init__()
        self.state_dim = env.observation_space.shape[0]
        self.action_dim = env.action_space.n
        self.fc1 = torch.nn.Linear(self.state_dim, 20)  # shared layer (parameters w)
        self.fc1.weight.data.normal_(0, 0.6)
        self.advantage_fc = torch.nn.Linear(20, self.action_dim)  # advantage stream (parameters beta)
        self.advantage_fc.weight.data.normal_(0, 0.2)
        self.value_fc = torch.nn.Linear(20, 1)  # value stream (parameters alpha)
        self.value_fc.weight.data.normal_(0, 0.2)

    def create_Q_network(self, x):
        x = F.relu(self.fc1(x))
        self.a_value = self.advantage_fc(x)
        self.v_value = self.value_fc(x)
        # Q = V + (A - mean(A)), the identifiable aggregation described above
        Q_value = self.v_value + (self.a_value - torch.mean(self.a_value, dim=-1, keepdim=True))
        return Q_value

    def forward(self, x, action_input):
        Q_value = self.create_Q_network(x)
        Q_action = torch.mul(Q_value, action_input).sum(dim=1)
        return Q_action
class DQN():
    def __init__(self, env):
        self.replay_total = 0
        self.target_Q_net = MODEL(env)
        self.current_Q_net = MODEL(env)
        self.memory = Memory(capacity=REPLAY_SIZE)
        self.time_step = 0
        self.epsilon = INITIAL_EPSILON
        self.optimizer = torch.optim.Adam(params=self.current_Q_net.parameters(), lr=0.0001)
        # self.loss = torch.nn.MSELoss()

    def store_transition(self, s, a, r, s_, done):
        transition = np.hstack((s, a, r, s_, done))
        self.memory.store(transition)

    def perceive(self, state, action, reward, next_state, done):
        one_hot_action = np.zeros(self.current_Q_net.action_dim)
        one_hot_action[action] = 1
        self.store_transition(state, one_hot_action, reward, next_state, done)
        self.replay_total += 1
        if self.replay_total > BATCH_SIZE:
            self.train_Q_network()

    def train_Q_network(self):
        self.time_step += 1
        # Step 1: obtain random minibatch from replay memory
        tree_idx, minibatch, ISWeights = self.memory.sample(BATCH_SIZE)
        state_batch = torch.tensor(minibatch[:, 0:4], dtype=torch.float32)
        action_batch = torch.tensor(minibatch[:, 4:6], dtype=torch.float32)
        reward_batch = [data[6] for data in minibatch]
        next_state_batch = torch.tensor(minibatch[:, 7:11], dtype=torch.float32)
        # Step 2: calculate y (current net selects the next action, target net evaluates it)
        y_batch = []
        current_a = self.current_Q_net.create_Q_network(next_state_batch)
        max_current_action_batch = torch.argmax(current_a, dim=1)
        Q_value_batch = self.target_Q_net.create_Q_network(next_state_batch)
        for i in range(0, BATCH_SIZE):
            done = minibatch[i][11]
            if done:
                y_batch.append(reward_batch[i])
            else:
                max_current_action = max_current_action_batch[i]
                y_batch.append(reward_batch[i] + GAMMA * Q_value_batch[i, max_current_action])
        y = self.current_Q_net(state_batch, action_batch)
        y_batch = torch.FloatTensor(y_batch)
        cost = self.loss(y, y_batch, torch.tensor(ISWeights, dtype=torch.float32))
        self.optimizer.zero_grad()
        cost.backward()
        self.optimizer.step()
        # recompute Q with the updated network to refresh the priorities
        y = self.current_Q_net(state_batch, action_batch)
        abs_errors = torch.abs(y_batch - y)
        self.memory.batch_update(tree_idx, abs_errors)

    def egreedy_action(self, state):
        Q_value = self.current_Q_net.create_Q_network(torch.FloatTensor(state))
        if random.random() <= self.epsilon:
            self.epsilon -= (INITIAL_EPSILON - FINAL_EPSILON) / 10000
            return random.randint(0, self.current_Q_net.action_dim - 1)
        else:
            self.epsilon -= (INITIAL_EPSILON - FINAL_EPSILON) / 10000
            return torch.argmax(Q_value).item()

    def action(self, state):
        return torch.argmax(self.target_Q_net.create_Q_network(torch.FloatTensor(state))).item()

    def update_target_params(self):
        torch.save(self.current_Q_net.state_dict(), 'net_params.pkl')
        self.target_Q_net.load_state_dict(torch.load('net_params.pkl'))

    def loss(self, y_output, y_true, ISWeights):
        value = y_output - y_true
        return torch.mean(value * value * ISWeights)
def main():
    # initialize OpenAI Gym env and dqn agent
    env = gym.make(ENV_NAME)
    agent = DQN(env)
    for episode in range(EPISODE):
        # initialize task
        state = env.reset()
        # Train
        for step in range(STEP):
            action = agent.egreedy_action(state)  # e-greedy action for train
            next_state, reward, done, _ = env.step(action)
            # Define reward for agent
            reward = -1 if done else 0.1
            agent.perceive(state, action, reward, next_state, done)
            state = next_state
            if done:
                break
        # Test every 100 episodes
        if episode % 100 == 0:
            total_reward = 0
            for i in range(TEST):
                state = env.reset()
                for j in range(STEP):
                    # env.render()
                    action = agent.action(state)  # direct action for test
                    state, reward, done, _ = env.step(action)
                    total_reward += reward
                    if done:
                        break
            ave_reward = total_reward / TEST
            print('episode: ', episode, 'Evaluation Average Reward:', ave_reward)
            agent.update_target_params()

if __name__ == '__main__':
    main()
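Note: like the reference TensorFlow code, this script assumes the classic Gym API (gym versions before 0.26), where env.reset() returns only the observation and env.step() returns (next_state, reward, done, info); newer gym/gymnasium releases changed both signatures.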