DQN Variants: Dueling DQN

This article explains the network architecture of Dueling DQN.

Problem Addressed

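Compared with the earlier DQN variants, Dueling DQN mainly changes the network architecture. It splits the $Q$ network into two parts. The first part depends only on the state $S$ and not on the action $A$ being taken; this is the value function part, written $V(S,w,\alpha)$. The second part depends on both the state $S$ and the action $A$; this is the advantage function part, written $A(S,A,w,\beta)$. The action-value function can then be rewritten as:

$$Q(S,A,w,\alpha,\beta) = V(S,w,\alpha) + A(S,A,w,\beta)$$

Here $w$ denotes the parameters of the shared part of the network, $\alpha$ the parameters specific to the value stream, and $\beta$ the parameters specific to the advantage stream. The network structures are shown below: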
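[Figure: DQN network structure vs. Dueling DQN network structure]
Left: the previous (DQN) structure. Right: the modified (Dueling DQN) structure.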
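To make the split concrete, here is a minimal plain-Python sketch of a dueling head, assuming one shared ReLU layer (parameters $w$) feeding a scalar value stream ($\alpha$) and a per-action advantage stream ($\beta$). The weights are hypothetical toy numbers, not values from the listing below, and the streams are combined with the plain sum $Q = V + A$ given above.

```python
def relu(xs):
    return [max(0.0, x) for x in xs]


def linear(W, b, x):
    # y[j] = sum_i W[j][i] * x[i] + b[j]
    return [sum(wji * xi for wji, xi in zip(row, x)) + bj for row, bj in zip(W, b)]


def dueling_forward_naive(state, w, alpha, beta):
    h = relu(linear(w["W"], w["b"], state))   # shared layer, parameters w
    v = linear(alpha["W"], alpha["b"], h)[0]   # value stream V(S, w, alpha): a single scalar
    adv = linear(beta["W"], beta["b"], h)      # advantage stream A(S, a, w, beta): one entry per action
    return [v + a for a in adv]                # plain sum Q = V + A


# Hypothetical toy parameters: 2-dimensional state, 2 hidden units, 3 actions
w = {"W": [[1.0, 0.0], [0.0, 1.0]], "b": [0.0, 0.0]}
alpha = {"W": [[1.0, 1.0]], "b": [0.0]}
beta = {"W": [[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]], "b": [0.0, 0.0, 0.5]}

print(dueling_forward_naive([1.0, 2.0], w, alpha, beta))  # [4.0, 5.0, 3.5]
```

The value stream outputs one number per state and is added to every action's advantage, so it shifts all Q-values of that state together.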
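We could use the combination formula above directly to obtain the action value, but from its output alone one cannot tell the separate contributions of $V(S,w,\alpha)$ and $A(S,A,w,\beta)$ apart. To make the decomposition identifiable, the combination actually used subtracts the mean advantage over the action set $\mathcal{A}$ (which contains $|\mathcal{A}|$ actions):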
$$Q(S,A,w,\alpha,\beta) = V(S,w,\alpha) + \Big(A(S,A,w,\beta) - \frac{1}{|\mathcal{A}|}\sum_{a' \in \mathcal{A}} A(S,a',w,\beta)\Big)$$
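Averaging this equation over all actions shows what the mean subtraction buys: the two advantage terms cancel and only the value term is left.

$$\frac{1}{|\mathcal{A}|}\sum_{a \in \mathcal{A}} Q(S,a,w,\alpha,\beta) = V(S,w,\alpha) + \frac{1}{|\mathcal{A}|}\sum_{a \in \mathcal{A}} A(S,a,w,\beta) - \frac{1}{|\mathcal{A}|}\sum_{a' \in \mathcal{A}} A(S,a',w,\beta) = V(S,w,\alpha)$$

So, given $Q$, the value term is fixed as the average of $Q(S,a,w,\alpha,\beta)$ over actions, and the centred advantage is $Q(S,A,w,\alpha,\beta) - V(S,w,\alpha)$.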
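Identifiability: whether a model's properties (here, the separate terms $V$ and $A$) can be uniquely determined from its input/output data. With the plain sum, adding any constant $c$ to $V$ and subtracting $c$ from every $A$ leaves $Q$ unchanged, so the two streams cannot be told apart. Loosely, identifiability is what lets each stream be interpreted on its own.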
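The following plain-Python check illustrates the point numerically with hypothetical values: under the plain sum, $(V, A)$ and $(V + c, A - c)$ produce identical Q-values, while under the mean-subtracted form the same shift changes $Q$, and $V$ can be read back as the mean of $Q$ over actions.

```python
def combine_naive(v, adv):
    # Q = V + A
    return [v + a for a in adv]


def combine_mean(v, adv):
    # Q = V + (A - mean over actions of A)
    mean_adv = sum(adv) / len(adv)
    return [v + (a - mean_adv) for a in adv]


def recover_streams(q):
    # Under the mean-subtracted form, V is the mean of Q over actions
    # and the centred advantage is Q - V.
    v = sum(q) / len(q)
    return v, [x - v for x in q]


v, adv, c = 1.0, [0.5, -0.5, 3.0], 3.0
shifted_v, shifted_adv = v + c, [a - c for a in adv]

# Plain sum: (V, A) and (V + c, A - c) give identical Q values
print(combine_naive(v, adv) == combine_naive(shifted_v, shifted_adv))   # True
# Mean-subtracted form: the same shift changes Q, so Q pins down V
print(combine_mean(v, adv) == combine_mean(shifted_v, shifted_adv))     # False
print(recover_streams(combine_mean(v, adv)))                           # (1.0, [-0.5, -1.5, 2.0])
```

Note that adding a constant to the raw advantages alone still leaves $Q$ unchanged under the mean-subtracted form; what becomes unique is $V$ together with the centred advantage.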

Code

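Below is a PyTorch implementation, written with reference to the TensorFlow code from the reference listed at the end of this post. I built it on top of Prioritized Replay DQN, whereas that TensorFlow code builds on Nature DQN. The target value in train_Q_network also follows the Double DQN rule: the current network selects the next action and the target network evaluates it.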
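Because the listing builds on Prioritized Replay DQN, its Memory class gives each transition the priority p_i = min(|δ_i| + epsilon, abs_err_upper) ** alpha (here alpha is Memory.alpha = 0.6, unrelated to the value-stream parameters $\alpha$ above) and corrects the sampling bias with importance-sampling weights w_i = (P(i) / min_j P(j)) ** (-beta). The sketch below computes both quantities for a few hypothetical TD errors using the listing's defaults. It leaves out the SumTree, which Memory.sample uses to draw one transition from each of n equal slices of the total priority, and it takes the minimum probability over the given priorities only, whereas Memory.sample takes it over all leaves (falling back to 0.00001 while empty leaves make it zero).

```python
def priorities(td_errors, eps=0.01, alpha=0.6, upper=1.0):
    # p_i = min(|delta_i| + eps, upper) ** alpha, same order of operations as Memory.batch_update
    return [min(abs(d) + eps, upper) ** alpha for d in td_errors]


def probs_and_weights(prios, beta=0.4):
    total = sum(prios)
    probs = [p / total for p in prios]           # P(i) = p_i / sum_k p_k
    min_prob = min(probs)
    # w_i = (P(i) / min_j P(j)) ** -beta: the least likely transition gets weight 1,
    # more likely ones get smaller weights
    weights = [(p / min_prob) ** -beta for p in probs]
    return probs, weights


td_errors = [0.0, 0.5, 2.0]                      # hypothetical TD errors
prios = priorities(td_errors)
probs, weights = probs_and_weights(prios)
print([round(p, 3) for p in probs])
print([round(w, 3) for w in weights])
```

Transitions with large TD errors are replayed more often, and their smaller weights scale down their contribution to the loss so the update is not dominated by them.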

# -*- coding: utf-8 -*-
"""
Created on Tue Dec 17 10:59:02 2019

@author: asus
"""

import gym
import torch
import torch.nn.functional as F
import numpy as np
import random

GAMMA = 0.9
INITIAL_EPSILON = 0.5
FINAL_EPSILON = 0.01
REPLAY_SIZE = 10000
BATCH_SIZE = 32
ENV_NAME = 'CartPole-v0'
EPISODE = 3000 # Episode limitation
STEP = 300 # Step limitation in an episode
TEST = 10 # Number of test episodes run every 100 training episodes

class SumTree(object):
    """
    This SumTree code is a modified version and the original code is from:
    https://github.com/jaara/AI-blog/blob/master/SumTree.py
    Stores data with its priority in the tree.
    """
    data_pointer = 0

    def __init__(self, capacity):
        self.capacity = capacity  # for all priority values
        self.tree = np.zeros(2 * capacity - 1)
        # [--------------Parent nodes-------------][-------leaves to record priority-------]
        #             size: capacity - 1                       size: capacity
        self.data = np.zeros(capacity, dtype=object)  # for all transitions
        # [--------------data frame-------------]
        #             size: capacity

    def add(self, p, data):
        tree_idx = self.data_pointer + self.capacity - 1
        self.data[self.data_pointer] = data  # update data_frame
        self.update(tree_idx, p)  # update tree_frame

        self.data_pointer += 1
        if self.data_pointer >= self.capacity:  # replace when exceed the capacity
            self.data_pointer = 0

    def update(self, tree_idx, p):
        change = p - self.tree[tree_idx]
        self.tree[tree_idx] = p
        # then propagate the change through tree
        while tree_idx != 0:    # this method is faster than the recursive loop in the reference code
            tree_idx = (tree_idx - 1) // 2
            self.tree[tree_idx] += change

    def get_leaf(self, v):
        """
        Tree structure and array storage:
        Tree index:
             0         -> storing priority sum
            / \
          1     2
         / \   / \
        3   4 5   6    -> storing priority for transitions
        Array type for storing:
        [0,1,2,3,4,5,6]
        """
        parent_idx = 0
        while True:     # the while loop is faster than the method in the reference code
            cl_idx = 2 * parent_idx + 1         # this leaf's left and right kids
            cr_idx = cl_idx + 1
            if cl_idx >= len(self.tree):        # reach bottom, end search
                leaf_idx = parent_idx
                break
            else:       # downward search, always search for a higher priority node
                if v <= self.tree[cl_idx]:
                    parent_idx = cl_idx
                else:
                    v -= self.tree[cl_idx]
                    parent_idx = cr_idx

        data_idx = leaf_idx - self.capacity + 1
        return leaf_idx, self.tree[leaf_idx], self.data[data_idx]

    @property
    def total_p(self):
        return self.tree[0]  # the root


class Memory(object):  # stored as (s, a, r, s_, done) in SumTree
    """
    This Memory class is modified based on the original code from:
    https://github.com/jaara/AI-blog/blob/master/Seaquest-DDQN-PER.py
    """
    epsilon = 0.01  # small amount to avoid zero priority
    alpha = 0.6  # [0~1] convert the importance of TD error to priority
    beta = 0.4  # importance-sampling, from initial value increasing to 1
    beta_increment_per_sampling = 0.001
    abs_err_upper = 1.  # clipped abs error

    def __init__(self, capacity):
        self.tree = SumTree(capacity)

    def store(self, transition):
        max_p = np.max(self.tree.tree[-self.tree.capacity:])
        if max_p == 0:
            max_p = self.abs_err_upper
        self.tree.add(max_p, transition)   # set the max p for new p

    def sample(self, n):
        b_idx, b_memory, ISWeights = np.empty((n,), dtype=np.int32), np.empty((n, self.tree.data[0].size)), np.empty((n, 1))
        pri_seg = self.tree.total_p / n       # priority segment: split the total priority into n equal intervals
        # beta is annealed toward 1 on every sampling call
        self.beta = np.min([1., self.beta + self.beta_increment_per_sampling])  # max = 1
        # minimum sampling probability
        min_prob = np.min(self.tree.tree[-self.tree.capacity:]) / self.tree.total_p     # for later calculate ISweight
        if min_prob == 0:
            min_prob = 0.00001
        for i in range(n):
            a, b = pri_seg * i, pri_seg * (i + 1)
            # draw uniformly within the i-th interval
            v = np.random.uniform(a, b)
            # p is the leaf's priority (derived from its TD error)
            idx, p, data = self.tree.get_leaf(v)
            prob = p / self.tree.total_p
            ISWeights[i, 0] = np.power(prob/min_prob, -self.beta)
            b_idx[i], b_memory[i, :] = idx, data
        return b_idx, b_memory, ISWeights

    def batch_update(self, tree_idx, abs_errors):
        abs_errors += self.epsilon  # convert to abs and avoid 0

        clipped_errors = np.minimum(abs_errors.detach().numpy(), self.abs_err_upper)
        ps = np.power(clipped_errors, self.alpha)
        for ti, p in zip(tree_idx, ps):
            self.tree.update(ti, p)

class MODEL(torch.nn.Module):
    def __init__(self, env):
        super(MODEL, self).__init__()
        self.state_dim = env.observation_space.shape[0]
        self.action_dim = env.action_space.n
        self.fc1 = torch.nn.Linear(self.state_dim, 20)
        self.fc1.weight.data.normal_(0, 0.6)
        self.advantage_fc = torch.nn.Linear(20, self.action_dim)
        self.advantage_fc.weight.data.normal_(0, 0.2)
        self.value_fc = torch.nn.Linear(20, 1)
        self.value_fc.weight.data.normal_(0, 0.2)
        
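    def create_Q_network(self, x):
        x = F.relu(self.fc1(x))                    # shared layer (parameters w)
        self.a_value = self.advantage_fc(x)        # advantage stream A(S, a, w, beta), shape (..., action_dim)
        self.v_value = self.value_fc(x)            # value stream V(S, w, alpha), shape (..., 1)
        # Q = V + (A - mean_a A): subtracting the mean advantage makes V and A identifiable
        Q_value = self.v_value + (self.a_value - torch.mean(self.a_value, dim=-1, keepdim=True))
        return Q_value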
    
    def forward(self, x, action_input):
        Q_value = self.create_Q_network(x)
        Q_action = torch.mul(Q_value, action_input).sum(dim=1)
        return Q_action
    
class DQN():
    def __init__(self, env):
        self.replay_total = 0
        self.target_Q_net = MODEL(env)
        self.current_Q_net = MODEL(env)
        # start the target network from the same weights as the current network
        self.target_Q_net.load_state_dict(self.current_Q_net.state_dict())
        self.memory = Memory(capacity=REPLAY_SIZE)
        self.time_step = 0
        self.epsilon = INITIAL_EPSILON
        self.optimizer = torch.optim.Adam(params=self.current_Q_net.parameters(), lr=0.0001)

    def store_transition(self, s, a, r, s_, done):
        # flattened layout (CartPole): s[0:4] | one-hot a[4:6] | r[6] | s_[7:11] | done[11]
        transition = np.hstack((s, a, r, s_, done))
        self.memory.store(transition)

    def perceive(self, state, action, reward, next_state, done):
        one_hot_action = np.zeros(self.current_Q_net.action_dim)
        one_hot_action[action] = 1
        self.store_transition(state, one_hot_action, reward, next_state, done)
        self.replay_total += 1
        if self.replay_total > BATCH_SIZE:
            self.train_Q_network()

    def train_Q_network(self):
        self.time_step += 1
        # Step 1: obtain a prioritized minibatch from replay memory
        tree_idx, minibatch, ISWeights = self.memory.sample(BATCH_SIZE)

        state_batch = torch.tensor(minibatch[:, 0:4], dtype=torch.float32)
        action_batch = torch.tensor(minibatch[:, 4:6], dtype=torch.float32)
        reward_batch = torch.tensor(minibatch[:, 6], dtype=torch.float32)
        next_state_batch = torch.tensor(minibatch[:, 7:11], dtype=torch.float32)
        done_batch = torch.tensor(minibatch[:, 11], dtype=torch.float32)

        # Step 2: calculate y with the Double DQN rule (no gradients through the target)
        with torch.no_grad():
            # the current network selects the next action ...
            max_current_action_batch = torch.argmax(
                self.current_Q_net.create_Q_network(next_state_batch), dim=1)
            # ... and the target network evaluates it
            Q_value_batch = self.target_Q_net.create_Q_network(next_state_batch)
            next_Q = Q_value_batch.gather(1, max_current_action_batch.unsqueeze(1)).squeeze(1)
            # terminal transitions (done = 1) keep only the reward
            y_batch = reward_batch + GAMMA * next_Q * (1.0 - done_batch)

        # Step 3: IS-weighted loss; ISWeights has shape (BATCH_SIZE, 1), flatten it to (BATCH_SIZE,)
        # so it multiplies the per-sample errors instead of broadcasting to (BATCH_SIZE, BATCH_SIZE)
        weights = torch.tensor(ISWeights, dtype=torch.float32).view(-1)
        y = self.current_Q_net(state_batch, action_batch)
        cost = self.loss(y, y_batch, weights)
        self.optimizer.zero_grad()
        cost.backward()
        self.optimizer.step()

        # Step 4: update the priorities with the new absolute TD errors
        with torch.no_grad():
            y = self.current_Q_net(state_batch, action_batch)
            abs_errors = torch.abs(y_batch - y)
        self.memory.batch_update(tree_idx, abs_errors)

    def egreedy_action(self, state):
        explore = random.random() <= self.epsilon
        # decay epsilon linearly, but never below FINAL_EPSILON
        self.epsilon = max(FINAL_EPSILON,
                           self.epsilon - (INITIAL_EPSILON - FINAL_EPSILON) / 10000)
        if explore:
            return random.randint(0, self.current_Q_net.action_dim - 1)
        with torch.no_grad():
            Q_value = self.current_Q_net.create_Q_network(torch.as_tensor(state, dtype=torch.float32))
        return torch.argmax(Q_value).item()

    def action(self, state):
        # greedy action from the target network (used for evaluation)
        with torch.no_grad():
            Q_value = self.target_Q_net.create_Q_network(torch.as_tensor(state, dtype=torch.float32))
        return torch.argmax(Q_value).item()

    def update_target_params(self):
        # copy the current network's weights into the target network in memory
        self.target_Q_net.load_state_dict(self.current_Q_net.state_dict())

    def loss(self, y_output, y_true, ISWeights):
        # importance-sampling weighted MSE; all three tensors have shape (BATCH_SIZE,)
        value = y_output - y_true
        return torch.mean(ISWeights * value * value)
    
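def reset_env(env):
    # gym >= 0.26 returns (obs, info) from reset(); older versions return obs only
    out = env.reset()
    return out[0] if isinstance(out, tuple) else out


def step_env(env, action):
    # gym >= 0.26 returns (obs, reward, terminated, truncated, info);
    # older versions return (obs, reward, done, info)
    out = env.step(action)
    if len(out) == 5:
        obs, reward, terminated, truncated, info = out
        return obs, reward, terminated or truncated, info
    return out


def main():
    # initialize OpenAI Gym env and dqn agent
    env = gym.make(ENV_NAME)
    agent = DQN(env)

    for episode in range(EPISODE):
        # initialize task
        state = reset_env(env)
        # Train
        for step in range(STEP):
            action = agent.egreedy_action(state)  # e-greedy action for train
            next_state, reward, done, _ = step_env(env, action)
            # Define reward for agent
            reward = -1 if done else 0.1
            agent.perceive(state, action, reward, next_state, done)
            state = next_state
            if done:
                break
        # Test every 100 episodes
        if episode % 100 == 0:
            total_reward = 0
            for i in range(TEST):
                state = reset_env(env)
                for j in range(STEP):
                    # env.render()
                    action = agent.action(state)  # direct action for test
                    state, reward, done, _ = step_env(env, action)
                    total_reward += reward
                    if done:
                        break
            ave_reward = total_reward / TEST
            print('episode: ', episode, 'Evaluation Average Reward:', ave_reward)
        # sync the target network with the current network once per episode
        agent.update_target_params()


if __name__ == '__main__':
    main()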
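The training step in train_Q_network combines two ideas: the Double DQN target (the current network picks the next action by argmax, the target network supplies its value) and a squared TD error weighted by the importance-sampling weights. The plain-Python sketch below restates just those two steps for a hypothetical two-transition batch; GAMMA matches the listing, while the rewards, Q-values and weights are made-up numbers.

```python
GAMMA = 0.9  # same discount factor as the listing


def argmax(values):
    # index of the first maximal entry
    return max(range(len(values)), key=values.__getitem__)


def double_dqn_targets(rewards, dones, q_next_current, q_next_target, gamma=GAMMA):
    targets = []
    for r, done, q_cur, q_tgt in zip(rewards, dones, q_next_current, q_next_target):
        if done:
            targets.append(r)                             # terminal: no bootstrap term
        else:
            a_star = argmax(q_cur)                        # current net selects the action
            targets.append(r + gamma * q_tgt[a_star])     # target net evaluates it
    return targets


def weighted_mse(predictions, targets, is_weights):
    # mean over the batch of w_i * (Q_i - y_i)^2, one weight per sample
    n = len(predictions)
    return sum(w * (q - y) ** 2 for q, y, w in zip(predictions, targets, is_weights)) / n


# Hypothetical batch of two transitions with two actions each
rewards, dones = [0.1, -1.0], [False, True]
q_next_current = [[1.0, 2.0], [0.0, 0.0]]
q_next_target = [[3.0, 0.5], [9.0, 9.0]]
targets = double_dqn_targets(rewards, dones, q_next_current, q_next_target)
print(targets)
print(weighted_mse([0.5, -1.0], targets, [1.0, 0.5]))
```

With these numbers the first target is 0.1 + 0.9 × 0.5 = 0.55, whereas the Nature DQN rule, max over the target network's own Q-values, would give 0.1 + 0.9 × 3.0 = 2.8. In the PyTorch listing, the weights returned by Memory.sample have shape (BATCH_SIZE, 1), so they need to be flattened to (BATCH_SIZE,) before being multiplied with the per-sample squared errors; otherwise broadcasting produces a BATCH_SIZE × BATCH_SIZE matrix.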

Reference: https://www.cnblogs.com/pinard/p/9923859.html
