Deep Reinforcement Learning for Sepsis Treatment
1 Introduction
本文所提出的是一种基于深度强化学习的脓毒症治疗方法。
使用强化学习而非有监督学习有以下考虑:
- 在临床上,对于一个治疗效果的“好坏”的界定不一定那么清晰
- 强化学习可以从并非最优的行为中学出最优的策略
本文关注的是如何使用continuous-state-space modeling来用一个向量表示患者在某一个时刻的生理状态,并且使用Deep-Q学习的方法来学习合适的行为。
2 Background and related work
在Q-learning中,最优的action value function是由Bellman方程迭代得到的:
Q
∗
(
s
,
a
)
=
E
s
′
T
(
s
′
∣
s
,
a
)
[
r
+
γ
m
a
x
a
′
Q
∗
(
s
′
,
a
′
)
∣
s
t
=
s
,
a
t
=
a
]
Q^*(s,a)=\mathbb{E}_{s'~T(s'|s,a)}[r+\gamma max _{a'}Q^*(s',a')|s_t=s,a_t=a]
Q∗(s,a)=Es′ T(s′∣s,a)[r+γmaxa′Q∗(s′,a′)∣st=s,at=a]
其中
T
(
s
′
∣
s
,
a
)
T(s'|s,a)
T(s′∣s,a)代表的是状态转移的分布。Bellman方程的含义是从当前的状态
s
s
s出发,采取行为
a
a
a后所能得到的value的期望,其中迭代的部分取了上一个阶段最优的行为。
本文是基于下面的文献开展的
M. Komorowski, A. Gordon, L. A. Celi, and A. Faisal. A Markov Decision Process to suggest optimal treatment of severe infections in intensive care. In Neural Information Processing Systems Workshop on Machine Learning for Health, December 2016.
区别应该是在于现在这篇文献将以前的离散状态用连续状态来代替。
3 Methods
3.1 Data and preprocessing
本文使用的是MIMIC-III数据集,筛选出了符合Sepsis-3标准的患者群体。
对于每个患者,都有包括demographic, lab values, vital signs以及intake/output events在内的physiological parameters。数据被以4h为间隔进行切分,其中的指标通过取均值或加和进行聚合。最后得到一个
48
×
1
48\times1
48×1的特征向量,该向量也就是MDP中的
s
t
s_t
st。
选了的特征如下:
3.2 Action and rewards
本研究采取的action是离散化的药物计量选择,是将intravenous(IV)与vasopressor(VP)的剂量分别划分为五个选择然后进行组合。划分的依据是按照所有使用药物剂量的分位数。药物的剂量也要按照每四个小时的区间段进行聚合,IV是取的加和,VP取的最大值。也就是说,通过元组 ( t o t a l I V i n , m a x V P i n ) (\rm{total\ IV\ in,max\ VP\ in}) (total IV in,max VP in)可以表示一个行为。然后每个时间段内所对应的元组也就可以对应一个干预行为。
本文所使用的reward function是根据SOFA评分和lactate level(一种反映细胞缺氧程度的指标,脓毒症的患者这个指标会高一些)设计的,对高的SOFA评分、SOFA评分的升高以及lactate的升高会有一个惩罚,反之则有一个positive reward。在一条患者轨迹的最后,还会根据患者的生存与否给与一个reward值。
下面是本研究所设计的reward function:
这里这些参数选择的依据主要是要将每一步的reward限制在一个有限的幅度内,使得患者轨迹中所积累的reward不会在最后超过最后一步的奖赏。后面所使用的
t
a
n
h
(
⋅
)
\rm{tanh(\cdot)}
tanh(⋅)也是因为最大lactate level的变化值远大于了其平均值,通过这个函数实现放缩。
3.3 Model Structure
本文使用神经网络来拟合
Q
∗
(
s
,
a
)
Q^*(s,a)
Q∗(s,a),也就是最优的action value function。
所设计的Deep Q Network根据给定
<
s
,
a
,
r
,
s
′
>
<s,a,r,s'>
<s,a,r,s′>下网络的输出
Q
(
s
,
a
;
θ
)
Q(s,a;\theta)
Q(s,a;θ)与期望标签
Q
t
a
r
g
e
t
=
r
+
γ
m
a
x
a
′
Q
(
s
′
,
a
′
;
θ
)
Q_{target}=r+\gamma max_{a'}Q(s',a';\theta)
Qtarget=r+γmaxa′Q(s′,a′;θ)之间的squared error进行参数优化的,通过stochastic batch gradient就可以进行优化。
但是target values的不稳定会导致训练的不稳定,所以这里决定使用一个单独的网络来确定target Q values,也就是
Q
(
s
′
,
a
′
)
Q(s',a')
Q(s′,a′)。
本文采取的是基于简单的Q-Networks的改进强化学习神经网络。
- 为了克服Q-value的overestimation的问题,采用Double-Deep Q Network的网络结构。在这个网络结构中,target Q values是由一个前向传播的main network找到的行为进行计算的,而非直接由target网络得到。
- 为了使得患者状态的好坏和每一个时间步的行为是否采取得正确对于Q-values有独自的影响,本文使用Dueling Q Network。在该结构中,给定 ( s , a ) (s,a) (s,a)对,所对应的 Q ( s , a ) Q(s,a) Q(s,a)会被分别分为 v a l u e value value和 a d v a n t a g e advantage advantage,其中 v a l u e value value度量的是当前状态的好坏, a d v a n t a g e advantage advantage度量的是所选择的行为的好坏。
- 本文还采取了Prioritize Experience Replay以加速学习过程。
4 Results
这里的结果是对患者按照SOFA高低进行分群后展示的。
从图中可以看到,对于具有较低和中等SOFA的患者,医生 基本不太实用VP。而对于高SOFA的患者,本文显示的模型就不太有效果了。
Appendix learning material
Deep Reinforcement Learning with Double Q-learning文献阅读
Introduction
Q-learnin存在的问题
- 由于Q-learning中存在基于所估计的action values进行最大化,然后选择action的过程,所以很多时候会过高地估计action values
本文提出了Double Q-learning以克服上述问题,并且可以提升模型性能。
关于overestimation
- 在过去的研究中,overestimation大多被归因于insufficient flexible function approximation以及nosie
- 这篇文章不区分approximation error的来源,而直接将overestimation归因于inaccurate action values
Background
为了解决序贯决策问题,我们可以通过学习optimal value这种方式来实现。optimal value 的含义为在遵从一个最优的策略时在某个状态下采取一个行为所获得的未来奖励之和。
用公式表达就是
Q
π
(
s
,
a
)
≜
E
[
R
1
+
γ
R
2
+
.
.
.
∣
S
0
=
s
,
A
0
=
a
,
π
]
Q_{\pi}(s,a)\triangleq \mathbb{E}[R_1+\gamma R_2+...|S_0=s,A_0=a,\pi]
Qπ(s,a)≜E[R1+γR2+...∣S0=s,A0=a,π]
其中
γ
\gamma
γ为
[
0
,
1
]
[0,1]
[0,1]内的数,叫折扣因子,用于调整近期收益与长期收益的权衡。
而optimal values就是
Q
∗
(
s
,
a
)
=
m
a
x
π
Q
π
(
s
,
a
)
Q_*(s,a)=max_{\pi}Q_{\pi}(s,a)
Q∗(s,a)=maxπQπ(s,a),也就是选出一个
π
\pi
π使得值最大化,这里的理解就是说在特定的
s
s
s下选择一个
a
a
a的
R
R
R是和整个
π
\pi
π相关的,因为在当前状态下选择了一个行为之后的发展是受到策略影响的。而有了这个optimal values之后,再在每个状态下选择可以使值最大的action,就可以得到最大的收益。
而估计optimal function可以使用Q-learning的方法,这是TD-learning的off-policy版本。然而在大多数情况下状态-动作空间是很复杂的,所以类似于表格形式的将所有的state-action pair的value都计算出来是一件比较费劲的事情,所以可以使用参数化的函数
Q
(
s
,
a
;
θ
t
)
Q(s,a;\bf{\theta}_t)
Q(s,a;θt)来近似价值函数。