REFUEL: Exploring Sparse Features in Deep Reinforcement Learning for Fast Disease Diagnosis
文章来自 HTC Research & Healthcare,提出了一个新算法 ---- REFUEL (REward shaping and FeatUrE rebuiLding),该算法能够有效提升在线疾病诊断的速度和精度。
Contribution:
1)提出了一个新算法 — REFUEL,由 Reward shaping and Feature rebuilding 这两个技术与 RL 相结合而成。
2)将 REFUEL 算法用在疾病诊断上,有效提升了在线疾病诊断的速度和精度。
医学疾病诊断过程
其实疾病诊断过程还是比较容易理解的,基本上就是一个人生病了要去看医生,然后医生问你有没有头痛,有没有发烧等等相关症状,然后病人就回答有或者没有,最后医生根据病人的回答得到相关的症状,并判断出疾病类别,至此,诊断结束。
总结起来就是两个过程,其一是医生需要询问病人有没有某个症状,其二是医生问完后作出诊断。
- 动作空间 A \mathcal A A
当将 RL 算法应用于疾病诊断中时,Agent 即充当医生的角色,故需要作出两个不同类别的动作,询问症状和下诊断。 所以,文章的动作空间 A \mathcal A A 的表述是这样的:
![](https://i-blog.csdnimg.cn/blog_migrate/faf51805c780ea91661aab74bf47f14c.png)
当动作
a
a
a 的取值在
{
1
,
.
.
.
,
m
}
\{1,...,m\}
{1,...,m} 范围时,表示 Agent 在问病人的症状 (e.g.
a
=
2
a=2
a=2,表示提问病人有没有2号症状,这里设共有 m 个症状,且
m
≥
2
m \geq 2
m≥2);
当动作
a
a
a 的取值在
{
m
+
1
,
.
.
.
,
m
+
n
}
\{m+1,...,m+n\}
{m+1,...,m+n} 范围时,表示 Agent 在下诊断 (e.g.
a
=
m
+
3
a=m + 3
a=m+3,表示下诊断为第 3 种疾病,这里设共有 n 个疾病,且
n
≥
3
n \geq 3
n≥3)。
- 状态空间 S \mathcal S S
![](https://i-blog.csdnimg.cn/blog_migrate/f76d6e8263bf4b5e95f1db2ea67e0273.png)
同样,这里的 m 代表的是症状个数, − 1 , 0 , 1 -1,0,1 −1,0,1 分别代表未知症状,阴性症状,阳性症状。设 m=2,则该状态空间为 S = { − 1 , 0 , 1 } 2 = { − 1 , 0 , 1 } × { − 1 , 0 , 1 } = { ( − 1 , − 1 ) , ( − 1 , 0 ) , ( − 1 , 1 ) , ( 0 , − 1 ) , ( 0 , 0 ) , ( 0 , 1 ) , ( 1 , − 1 ) , ( 1 , 0 ) , ( 1 , 1 ) } \mathcal S = \{-1,0,1\}^2 = \{-1,0,1\} \times \{-1,0,1\} = \{(-1,-1),(-1,0),(-1,1),(0,-1),(0,0),(0,1),(1,-1),(1,0),(1,1)\} S={−1,0,1}2={−1,0,1}×{−1,0,1}={(−1,−1),(−1,0),(−1,1),(0,−1),(0,0),(0,1),(1,−1),(1,0),(1,1)},一共 ∣ S ∣ = 3 2 = 9 |\mathcal S| = 3^2 = 9 ∣S∣=32=9 种情况
- 状态转移函数
根据上述,医生是有了,那么其互动对象病人呢?文章给了一个数据集合 D \mathcal D D 来充当病人:
![](https://i-blog.csdnimg.cn/blog_migrate/bd88c5813d04376a3dc8b453e2329a10.png)
其中, X = { 0 , 1 } m , Y = { 1 , . . . , n } \mathcal X = \{0,1\}^m, \mathcal Y = \{1,...,n\} X={0,1}m,Y={1,...,n}。
集合
D
\mathcal D
D 共有
k
k
k 个病人,每个病人可以描述为
(
x
(
j
)
,
y
(
j
)
)
(x^{(j)},y^{(j)})
(x(j),y(j)),其中
x
(
j
)
x^{(j)}
x(j) 表示第 j 个病人的所有症状 (m 维),
y
(
j
)
y^{(j)}
y(j) 表示第 j 个病人的疾病代号 (1维)。
e.g. 设
m
=
5
,
n
≥
6
m = 5,n\geq 6
m=5,n≥6。若 j=2,则指第 2 号病人
(
x
(
2
)
,
y
(
2
)
)
(x^{(2)},y^{(2)})
(x(2),y(2)); 若病人的描述具体为
x
(
2
)
=
(
1
,
1
,
0
,
0
,
1
)
,
y
(
2
)
=
4
x^{(2)} = (1,1,0,0,1), \ y^{(2)} = 4
x(2)=(1,1,0,0,1), y(2)=4,则表示第 2 号病人得的是第 4 种病,其症状为(阳性,阳性,阴性,阴性,阳性),即该病的症状 1、症状2 和症状 5 均为阳性。
好,有了病人之后,Agent 在对他进行病症咨询的时候,如果病人的症状向量 x ( j ) x^{(j)} x(j) 对应的元素为 1 则回答有该症状,为 0 则答无。这样就 Agent 与 Env 就能互动了。
在病人回答了之后,Agent 得到下一个状态 S t + 1 S_{t+1} St+1:
![](https://i-blog.csdnimg.cn/blog_migrate/4fa4d2946ab1b096c8a63274d068face.png)
怎么理解这个式子?
其实,该式子表示的仅仅是状态向量
S
t
+
1
S_{_{t+1}}
St+1 的一个元素的更新方法,而
S
t
+
1
=
(
S
t
+
1
,
1
,
S
t
+
1
,
2
,
S
t
+
1
,
3
,
S
t
+
1
,
4
,
S
t
+
1
,
5
)
S_{_{t+1}} = (S_{_{t+1,1}},S_{_{t+1,2}},S_{_{t+1,3}},S_{_{t+1,4}},S_{_{t+1,5}})
St+1=(St+1,1,St+1,2,St+1,3,St+1,4,St+1,5)。如果 agent 执行动作
A
t
A_t
At,即问病人有没有第
A
t
A_t
At 号症状,那么状态向量
S
t
+
1
S_{_{t+1}}
St+1 对应的第
A
t
A_t
At 位元素就会更新,其他位置的元素不变。符号
x
j
x_j
xj 表示的是病人症状向量的第 j 个元素的值,只有阴性和阳性即 0 和 1 两个取值。若
x
j
=
0
x_j = 0
xj=0,则
S
t
+
1
,
j
=
−
1
S_{_{t+1,j}} = -1
St+1,j=−1; 若
x
j
=
1
x_j = 1
xj=1,则
S
t
+
1
,
j
=
1
S_{_{t+1,j}} = 1
St+1,j=1。值得注意的是,在状态向量
S
t
+
1
S_{_{t+1}}
St+1 中,用 1 和 -1 表示阳性和阴性。
比如,设在 t 时刻,状态为 S t = ( 0 , 1 , 0 , 0 , 0 ) S_{t} = (0,1,0,0,0) St=(0,1,0,0,0),Agent 提问第 2 号病人有没有第 5 号症状?根据第 2 号病人的症状向量 x ( 2 ) = ( 1 , 1 , 0 , 0 , 1 ) x^{(2)} = (1,1,0,0,1) x(2)=(1,1,0,0,1),该病人回答有,那么下一个状态则更新为 S t + 1 = ( 0 , 1 , 0 , 0 , 1 ) S_{t+1} = (0,1,0,0,1) St+1=(0,1,0,0,1),这样 agent 就收集到了该病人的另一个症状。通常来说,收集到越充足的病症,对于疾病诊断越准确。
Reward shaping
有了上面的介绍铺垫,可以直接开始重点了。
直接上奖励重塑公式:
![](https://i-blog.csdnimg.cn/blog_migrate/54efa0d1c5fff2d462aa576ac48a594d.png)
只是在原来奖励
r
r
r 的基础上,添加了一项辅助奖励
f
(
s
,
s
′
)
f(s,s')
f(s,s′)。
去看下
f
(
s
,
s
′
)
f(s,s')
f(s,s′) 是什么:
![](https://i-blog.csdnimg.cn/blog_migrate/8027b2d13b4f07031a65fc6a24d8c135.png)
其中, γ \gamma γ 是折扣因子,再深追一层,什么是 φ ( s ) \varphi(s) φ(s)?
![](https://i-blog.csdnimg.cn/blog_migrate/08b349ac08579c2319e9f54c5346e0ea.png)
这个
φ
(
s
)
\varphi(s)
φ(s) 叫做 bounded potential function;
解释:
首先,
λ
\lambda
λ 是个超参数。
再看条件,
s
∈
S
\
{
S
⊥
}
s\in S\backslash \{S_\perp\}
s∈S\{S⊥} 指的是除了结束状态集合
{
S
⊥
}
\{S_\perp\}
{S⊥} 之外的状态。
s
j
s_j
sj 指的是状态
s
s
s 的第 j 个元素;而
{
j
:
s
j
=
1
}
\{j:s_j=1\}
{j:sj=1} 指的是在状态向量
s
s
s 中,元素值为 1 的元素下标的集合。
∣
{
j
:
s
j
=
1
}
∣
|\{j:s_j=1\}|
∣{j:sj=1}∣ 指的是这个集合的元素个数。
所以,
φ
(
s
)
\varphi(s)
φ(s) 的大小跟状态向量
s
s
s 中,值为 1 的元素个数有关。而在状态向量中,元素值为 1 则表示为阳性,且是靠咨询得到的,所以这样设计 potential function 是为了鼓励 agent 去问出一个病人更多的阳性症状,以便更好地下诊断。
Feature rebuilding
现在来看论文提出的第二个技术 — 特征重建,其结构如下图所示:
![](https://i-blog.csdnimg.cn/blog_migrate/f62148c54069fabec56cbddc8f11a087.png)
其中,上面输出的是策略
π
\pi
π,下面输出的是重建的特征
z
z
z,而前三层网络是两者所共享的。
特征重建用的 loss 函数是二分类交叉熵的损失函数:
![](https://i-blog.csdnimg.cn/blog_migrate/16d271a948c951637034be232b4ac11b.png)
其中,症状向量 x x x 来自于数据集 D \mathcal D D (全局),而 z z z 来源于状态向量 s s s 的拟合,而 s s s 是不全面的,是局部的 (经过 agent 的多次询问, s s s 才有可能等于 x x x)。
然而,从这个损失函数来看,特征重建网络是想让局部特征 s s s 能逼近或预测出全局特征 x x x。当这个网络预测得越来越准的时候,说明这个网络的参数拥有了一定的全局观,又因为前三层是共享的网络,这样,上面的分支所输出的策略 π \pi π 就会带着一定的全局观来输出动作。
伪代码
所以,agent 的目标是最大化以下函数:
![](https://i-blog.csdnimg.cn/blog_migrate/f12bcf73c8d106a4a7fdec2460e82459.png)
共两部分,一个是标准的 VPG 强化学习算法的目标,另一个是特征重建的目标。
因为 VPG 算法的目标函数求导为:
![](https://i-blog.csdnimg.cn/blog_migrate/7ca1eb06c2bcc97b15114d36a2d0ea8e.png)
所以,更新网络参数的更新公式为 (熵正则化项 H H H 是作者额外添加的):
![](https://i-blog.csdnimg.cn/blog_migrate/8459db8c2a9502f1fd7bf779a3f29315.png)
最后,贴上伪代码:
![](https://i-blog.csdnimg.cn/blog_migrate/0a4d387617f19ff38d27d19eaa85856d.png)
其中,line 5 是先从
D
\mathcal D
D 中抽取一个病人,然后再从病人身上抽出一个症状;相当于是病人因为有了这个症状才来看病的;
line10-13 是医生问症状过程,line14-16 是医生下诊断过程。