第四章-基于神经网络方法求解RL-DQN概述
主要内容为题主在学习机器学习时记录的内容,主要为DQN的一些概念、流程。
一、函数逼近
1.1 RL与Deep RL
上次学的Sarsa 与Q-learning 适合简单的场景,如走迷宫等。
因为它的状态很少,由一个Q表格就可以完全放下。
担当它操作到DRL时,其产生的状态特别多,如国际象棋有 10^ 47 种状态,而围棋有 10^ 170种状态,甚至模拟人类具体动作行为时有不可数的状态,那么用一张表格来存储这些状态,表格将变得非常非常巨大。
所以引入了值函数近似的概念。
1.2 值函数近似
1.2.1 定义
Q表格的算法中,我们最后会生成一张表来记录Q值:
一般是一张动作和状态的二维表。
那表格太大时怎么办?我们可以将表格的值近似为一个函数。
近似式为:
即在使用s,a,w 三个参数得到的Q值,与原表格的Q值近似。其中参数w,可以理解为神经网络的权值参数。
1.2.2 优缺点
表格方法的缺点:
- 表格会占据大量内存
- 表格太大时,查找效率会变低
值函数近似优点:
- 仅需要存储参数
- 状态泛化,相似状态可以输出一样。
二、神经网络
2.1 简介
其实神经网络就像一个黑盒子,你输入x 值会输出对应的 y 值。如果输入输出为向量时,则会把多个输入转换为多个输出。
就像输入 手写3 可以识别为 3。
输入猫猫照片,可以识别为猫。
输入房屋位置、面积可以计算出房价。
2.2 内部结构
2.2.1 两层&三层
神经网络这个词来源于系统架构设计背后的灵感——模仿生物大脑自身神经网络的基本结构。神经网络由多个神经元组成。
一般的神经网络主要有两层,即输入层和输出层。
也可以加一层隐藏层,构成三层的神经网络。
三层的作用分别为:
- 输入层——接受输入数据。
- 隐藏层——对输入数据进行数学计算。
- 输出层——为程序产生给定的输出。
2.2.2 深层神经网络+CNN
比三层神经网络更强的是引入了卷积(CNN)。
几个卷积过滤层可以自动地提取图片的特征,再加上几层全连接层。
引入CNN后可以实现强化学习观察图片来进行学习。
三、DQN基础
3.1 基本介绍
3.1.1 概念
DQN 最早由 DeepMind 发表在 NIPS 2013,改进后发表在 Nature 2015.(Mnih, Volodymyr, et al. “Human-level control through deep reinforcement learning.” Nature 518.7540 (2015): 529-533.)
它最初通过实现49个Atari游戏的强化学习,并有30个超越了人类水平,来证明DQN可以实现强化学习且效果非常优秀。
因为涉及到像素图像的输入,所以他使用了神经网络来代替Q表格,但他的本质还是Q-learning。
3.1.2 回顾一下Q-learning
Q-learning 是从环境当中获取到状态 s , 然后通过查表找到对应所有动作的Q值,然后输出最优的动作 a ,输出给环境,然后得到下一步的 s 。
同时也会得到奖励 r ,将 r 代入学习公式,使得存入表格的Q逼近真实的Q。
当然并不是每一次都去取最优值,这里有个sample函数来实现探索。
3.1.3 DQN的改进
DQN将Q-learning中的表格换为了神经网络,而在学习公式和策略方面并没有改变。
这样做的好处就是实现了函数近似而不是用占内存极大的表格来存储。
3.2 如何训练
3.2.1 监督学习
这里需要与监督学习进行比较。
监督学习输入 x ,输出 y,并于真实值 y 进行比较。比较是为了逼近真实值。
我们计算他们的均方差,送进优化函数里面,就可以对网络做自动的更新和优化。
3.2.2 DQN怎么做
DQN与监督学习的训练过程类似,他输入一堆状态s,输出对应的q值,输出的q是一个向量(假设动作是多维度的)。
接着我们需要让得到的q逼近Target q。
Target q的计算需要用到Q-learning的学习公式:
更新网络时需要的均方差公式为:
四、DQN的创新点
引入了神经网络,就会引入非线性函数。而用非线性函数来近似Q表格,并不能证明最后一定可以收敛。而DQN就使用了两大创新来保证算法的稳定性和收敛。
4.1 Experience replay-经验回放
4.1.1 经验回放概念
即经验回放。主要解决样本关联性和样本利用效率的问题。
上图为监督学习的网络,它的输入为独立的x,x之间并没有联系。
而强化学习是一个序列的学习问题,前后的步骤是有关联的,但神经网络作为有监督学习模型,需要数据满足独立分布。这样的矛盾被 Experience replay 用存储-采样的方法打破。
它的方法是存储一些经验数据,然后打乱顺序从中选取数据来更新网络。这样不仅可以减少相关性,还可以增大样本的利用率。
4.1.2 回顾Off-policy
在 Off-policy中有两种策略 Behaviour policy 与 Target policy。他们的关系就像士兵和军师,而q表格就像战术表一样。士兵通过战术表使用战术来攻打堡垒,拿到战术经验带给军师,军师则将经验分析并写入战术表中。
而经验回放机制提供了一个经验池,经验池是一个固定长度的队列, 士兵每次与环境交互会得到一条经验,从而存到经验池中。如果经验超过了经验池的容量,会将旧的一条经验弹出。而军师在优化战术时,会随机的从经验池中抽取一条经验来更新Q。
其实经验池就相当于一个缓冲区,(个人感觉这种缓冲区的思想在计算机学中很常被使用到,cache、快表等等,以后遇到问题时可以考虑使用),它可以打乱样本关联性(随机在经验池中挑选),并提高样本的利用率(存起来多用几次)。
4.1.3 ReplayMemory类
python中有写好的ReplayMemory类,可以引入collections库调用队列的函数来实现一个队列。
下面是ReplayMemory类的代码。其实在研究强化学习时,很多代码、算法都可以重复使用,所以如果能封装的话就会很方便,仅需要简单修改然后调一下环境就行。
import random
import collections
import numpy as np
class ReplayMemory(object):
def __init__(self, max_size):
self.buffer = collections.deque(maxlen=max_size)
# 增加一条经验到经验池中
def append(self, exp):
self.buffer.append(exp)
# 从经验池中选取N条经验出来
def sample(self, batch_size):
mini_batch = random.sample(self.buffer, batch_size)
obs_batch, action_batch, reward_batch, next_obs_batch, done_batch = [], [], [], [], []
for experience in mini_batch:
s, a, r, s_p, done = experience
obs_batch.append(s)
action_batch.append(a)
reward_batch.append(r)
next_obs_batch.append(s_p)
done_batch.append(done)
return np.array(obs_batch).astype('float32'), \
np.array(action_batch).astype('float32'), np.array(reward_batch).astype('float32'),\
np.array(next_obs_batch).astype('float32'), np.array(done_batch).astype('float32')
def __len__(self):
return len(self.buffer)
4.2 Fixed Q target-固定Q目标
因为监督学习是足够平稳和稳定的,所以可以进行对比。监督学习要让输出的y逼近真实的y,这个真实的y是固定的、稳定的。
但是DQN,输出Q的预测值,要逼近的目标值其实也是在不断地变化的。就像你练习射箭,但你的目标是一只不断跳动的兔子一样。为了有好的效果我们需要将兔子固定一下,变成固定的靶子,或者在一段时间内将兔子固定为靶子。即在一段时间内让目标的Q保持稳定,这样输出的Q也是稳定的,这就是DQN的第二大创新-Fixed Q target。
五、DQN代码
5.1 伪代码
经过上面的介绍,想必大家已经看出来,DQN其实不过是Q-learning引入了神经网络,使用了经验回放和固定Q目标两大创新。
它的伪代码如图:
5.2 DQN 流程
5.2.1 DQN流程图
- sample 函数,负责采样,保障所有的动作都被探索到。
- learn 函数,根据环境中拿到的数据来更新Q表格。
- DQN 使用model来代替Q表格。
- 同时使用一个经验池来实现经验回放,用append函数存经验,然后使用sample从经验池中挑选一个batch数据送到 learn 函数中进行学习训练。
- 还需要另一个和Q网络一样的target网络,它定期的从Q网络复制参数过来。Q网络来产生Q预测值,而target网络来产生一个相对稳定的Q目标值。两者的差值就是Q网络要优化的目标。
- 计算后的插值优化后可以更新 Q表格,这就是learn函数要实现的内容。
- 这是DQN最核心的一部分。
5.2.2 DQN三层结构
所以经过优化后,DQN可以分为三部分,model、algorithm、agent三部分(这也是parl的魅力所在)。
model 用来定义网络结构。
algorithm 定义具体的算法来更新Q网络。
agent 负责与环境做交互,交互过程中生成的数据交付给algorithm ,进而交付给model 。