REFUEL: Exploring Sparse Features in Deep Reinforcement Learning for Fast Disease Diagnosis

REFUEL: Exploring Sparse Features in Deep Reinforcement Learning for Fast Disease Diagnosis

文章来自 HTC Research & Healthcare,提出了一个新算法 ---- REFUEL (REward shaping and FeatUrE rebuiLding),该算法能够有效提升在线疾病诊断的速度和精度。

Contribution:
1)提出了一个新算法 — REFUEL,由 Reward shaping and Feature rebuilding 这两个技术与 RL 相结合而成。
2)将 REFUEL 算法用在疾病诊断上,有效提升了在线疾病诊断的速度和精度。

医学疾病诊断过程

其实疾病诊断过程还是比较容易理解的,基本上就是一个人生病了要去看医生,然后医生问你有没有头痛,有没有发烧等等相关症状,然后病人就回答有或者没有,最后医生根据病人的回答得到相关的症状,并判断出疾病类别,至此,诊断结束。

总结起来就是两个过程,其一是医生需要询问病人有没有某个症状,其二是医生问完后作出诊断。

  • 动作空间 A \mathcal A A

当将 RL 算法应用于疾病诊断中时,Agent 即充当医生的角色,故需要作出两个不同类别的动作,询问症状和下诊断。 所以,文章的动作空间 A \mathcal A A 的表述是这样的:

当动作 a a a 的取值在 { 1 , . . . , m } \{1,...,m\} {1,...,m} 范围时,表示 Agent 在问病人的症状 (e.g. a = 2 a=2 a=2,表示提问病人有没有2号症状,这里设共有 m 个症状,且 m ≥ 2 m \geq 2 m2);
当动作 a a a 的取值在 { m + 1 , . . . , m + n } \{m+1,...,m+n\} {m+1,...,m+n} 范围时,表示 Agent 在下诊断 (e.g. a = m + 3 a=m + 3 a=m+3,表示下诊断为第 3 种疾病,这里设共有 n 个疾病,且 n ≥ 3 n \geq 3 n3)。

  • 状态空间 S \mathcal S S

同样,这里的 m 代表的是症状个数, − 1 , 0 , 1 -1,0,1 1,0,1 分别代表未知症状,阴性症状,阳性症状。设 m=2,则该状态空间为 S = { − 1 , 0 , 1 } 2 = { − 1 , 0 , 1 } × { − 1 , 0 , 1 } = { ( − 1 , − 1 ) , ( − 1 , 0 ) , ( − 1 , 1 ) , ( 0 , − 1 ) , ( 0 , 0 ) , ( 0 , 1 ) , ( 1 , − 1 ) , ( 1 , 0 ) , ( 1 , 1 ) } \mathcal S = \{-1,0,1\}^2 = \{-1,0,1\} \times \{-1,0,1\} = \{(-1,-1),(-1,0),(-1,1),(0,-1),(0,0),(0,1),(1,-1),(1,0),(1,1)\} S={1,0,1}2={1,0,1}×{1,0,1}={(1,1),(1,0),(1,1),(0,1),(0,0),(0,1),(1,1),(1,0),(1,1)},一共 ∣ S ∣ = 3 2 = 9 |\mathcal S| = 3^2 = 9 S=32=9 种情况

  • 状态转移函数

根据上述,医生是有了,那么其互动对象病人呢?文章给了一个数据集合 D \mathcal D D 来充当病人:

其中, X = { 0 , 1 } m , Y = { 1 , . . . , n } \mathcal X = \{0,1\}^m, \mathcal Y = \{1,...,n\} X={0,1}m,Y={1,...,n}

集合 D \mathcal D D 共有 k k k 个病人,每个病人可以描述为 ( x ( j ) , y ( j ) ) (x^{(j)},y^{(j)}) (x(j),y(j)),其中 x ( j ) x^{(j)} x(j) 表示第 j 个病人的所有症状 (m 维), y ( j ) y^{(j)} y(j) 表示第 j 个病人的疾病代号 (1维)。
e.g. 设 m = 5 , n ≥ 6 m = 5,n\geq 6 m=5,n6。若 j=2,则指第 2 号病人 ( x ( 2 ) , y ( 2 ) ) (x^{(2)},y^{(2)}) (x(2),y(2)); 若病人的描述具体为 x ( 2 ) = ( 1 , 1 , 0 , 0 , 1 ) ,   y ( 2 ) = 4 x^{(2)} = (1,1,0,0,1), \ y^{(2)} = 4 x(2)=(1,1,0,0,1), y(2)=4,则表示第 2 号病人得的是第 4 种病,其症状为(阳性,阳性,阴性,阴性,阳性),即该病的症状 1、症状2 和症状 5 均为阳性。

好,有了病人之后,Agent 在对他进行病症咨询的时候,如果病人的症状向量 x ( j ) x^{(j)} x(j) 对应的元素为 1 则回答有该症状,为 0 则答无。这样就 Agent 与 Env 就能互动了。

在病人回答了之后,Agent 得到下一个状态 S t + 1 S_{t+1} St+1

怎么理解这个式子?
其实,该式子表示的仅仅是状态向量 S t + 1 S_{_{t+1}} St+1 的一个元素的更新方法,而 S t + 1 = ( S t + 1 , 1 , S t + 1 , 2 , S t + 1 , 3 , S t + 1 , 4 , S t + 1 , 5 ) S_{_{t+1}} = (S_{_{t+1,1}},S_{_{t+1,2}},S_{_{t+1,3}},S_{_{t+1,4}},S_{_{t+1,5}}) St+1=(St+1,1,St+1,2,St+1,3,St+1,4,St+1,5)。如果 agent 执行动作 A t A_t At,即问病人有没有第 A t A_t At 号症状,那么状态向量 S t + 1 S_{_{t+1}} St+1 对应的第 A t A_t At 位元素就会更新,其他位置的元素不变。符号 x j x_j xj 表示的是病人症状向量的第 j 个元素的值,只有阴性和阳性即 0 和 1 两个取值。若 x j = 0 x_j = 0 xj=0,则 S t + 1 , j = − 1 S_{_{t+1,j}} = -1 St+1,j=1; 若 x j = 1 x_j = 1 xj=1,则 S t + 1 , j = 1 S_{_{t+1,j}} = 1 St+1,j=1。值得注意的是,在状态向量 S t + 1 S_{_{t+1}} St+1 中,用 1 和 -1 表示阳性和阴性。

比如,设在 t 时刻,状态为 S t = ( 0 , 1 , 0 , 0 , 0 ) S_{t} = (0,1,0,0,0) St=(0,1,0,0,0),Agent 提问第 2 号病人有没有第 5 号症状?根据第 2 号病人的症状向量 x ( 2 ) = ( 1 , 1 , 0 , 0 , 1 ) x^{(2)} = (1,1,0,0,1) x(2)=(1,1,0,0,1),该病人回答有,那么下一个状态则更新为 S t + 1 = ( 0 , 1 , 0 , 0 , 1 ) S_{t+1} = (0,1,0,0,1) St+1=(0,1,0,0,1),这样 agent 就收集到了该病人的另一个症状。通常来说,收集到越充足的病症,对于疾病诊断越准确。

Reward shaping

有了上面的介绍铺垫,可以直接开始重点了。
直接上奖励重塑公式:

只是在原来奖励 r r r 的基础上,添加了一项辅助奖励 f ( s , s ′ ) f(s,s') f(s,s)
去看下 f ( s , s ′ ) f(s,s') f(s,s) 是什么:

其中, γ \gamma γ 是折扣因子,再深追一层,什么是 φ ( s ) \varphi(s) φ(s)

这个 φ ( s ) \varphi(s) φ(s) 叫做 bounded potential function;
解释:
首先, λ \lambda λ 是个超参数。
再看条件, s ∈ S \ { S ⊥ } s\in S\backslash \{S_\perp\} sS\{S} 指的是除了结束状态集合 { S ⊥ } \{S_\perp\} {S} 之外的状态。
s j s_j sj 指的是状态 s s s 的第 j 个元素;而 { j : s j = 1 } \{j:s_j=1\} {j:sj=1} 指的是在状态向量 s s s 中,元素值为 1 的元素下标的集合。
∣ { j : s j = 1 } ∣ |\{j:s_j=1\}| {j:sj=1} 指的是这个集合的元素个数。
所以, φ ( s ) \varphi(s) φ(s) 的大小跟状态向量 s s s 中,值为 1 的元素个数有关。而在状态向量中,元素值为 1 则表示为阳性,且是靠咨询得到的,所以这样设计 potential function 是为了鼓励 agent 去问出一个病人更多的阳性症状,以便更好地下诊断。

Feature rebuilding

现在来看论文提出的第二个技术 — 特征重建,其结构如下图所示:

其中,上面输出的是策略 π \pi π,下面输出的是重建的特征 z z z,而前三层网络是两者所共享的。
特征重建用的 loss 函数是二分类交叉熵的损失函数:

其中,症状向量 x x x 来自于数据集 D \mathcal D D (全局),而 z z z 来源于状态向量 s s s 的拟合,而 s s s 是不全面的,是局部的 (经过 agent 的多次询问, s s s 才有可能等于 x x x)。

然而,从这个损失函数来看,特征重建网络是想让局部特征 s s s 能逼近或预测出全局特征 x x x。当这个网络预测得越来越准的时候,说明这个网络的参数拥有了一定的全局观,又因为前三层是共享的网络,这样,上面的分支所输出的策略 π \pi π 就会带着一定的全局观来输出动作。

伪代码

所以,agent 的目标是最大化以下函数:

共两部分,一个是标准的 VPG 强化学习算法的目标,另一个是特征重建的目标。

因为 VPG 算法的目标函数求导为:

所以,更新网络参数的更新公式为 (熵正则化项 H H H 是作者额外添加的):

最后,贴上伪代码:

其中,line 5 是先从 D \mathcal D D 中抽取一个病人,然后再从病人身上抽出一个症状;相当于是病人因为有了这个症状才来看病的;
line10-13 是医生问症状过程,line14-16 是医生下诊断过程。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值