对话诊断(X)Generative Adversarial Regularized Mutual Information Policy Gradient Framework for Automatic

Generative Adversarial Regularized Mutual Information Policy Gradient Framework for Automatic Diagnosis

先分享一篇发表在2020AAAI的使用GAN及互信息进行对话诊断的论文,重点关注模型的架构,解决的问题及所提方法。关于对话系统及对话诊断的发展及研究还有之前的论文以后再补。

禁二传二改

论文链接:http://zhoujingbo.github.io/paper/xia2020generative_aaai.pdf
Abstract:

使用强化学习进行自动诊断得到的精度不高,因此提出使用生成对抗式正则化互信息策略结构( Generative Adversarial regularized Mutual in formation Policy gradient framework ——GAMP)实现更精准更快速的自动诊断。在框架中,将GAN的生成器作为一个策略网络,并使用GAN的鉴别器作为奖励功能的一部分。这种生成式的对抗性正则化策略梯度框架可以尽量避免产生偏离常见诊断范式的症状查询的随机试验。此外,还添加了互信息来增强奖励功能,以鼓励模型选择最具鉴别性的症状来做出诊断。

Introduction:

自动诊断通过用户的自我叙述以及多轮交互的信息来预测患者最可能患的疾病。自动诊断系统在简化诊断程序、降低收集患者信息的成本、帮助做出更好更有效的决策方面具有巨大的潜力。近年来,研究人员对使用强化学习(RL)进行自动诊断问题建模越来越感兴趣。自动诊断被认为是医生的询问和病人的回答,符合RL使用反馈来处理顺序决策问题的特性。因此,RL被广泛认为是开发强大的自动诊断解决方案的合适候选者。

然而,RL在解决自动诊断方面仍存在一些挑战。

  • 首先,由于诊断数据的规模有限,RL倾向于产生随机试验,而不考虑常见诊断范式中的症状和疾病之间的相关性。而在现实生活中,医生总是依据医学诊断逻辑仔细选择询问的问题。并且RL需要大量的数据来学习这些潜在的知识(医学知识),而由于收集数据的成本和患者的隐私问题,现有可利用的诊断数据规模要小得多。
  • 其次,需要一种复杂的方法来设置奖励函数。现有的研究已经提到奖励对于RL中的策略学习至关重要,但目前仍然没有很好的解决方案来设置奖励功能。
模型架构:

在这里插入图片描述

整体对话流程如上,对话状态跟踪器(DST),用以跟踪用户和代理的状态。NLU从输入文本中提取医疗实体和关键问题,NLG可以向患者生成对话问题,与之前的对话诊断系统各部分定义相同。

在这里插入图片描述

先前方法存在的问题:①使用模拟数据,不能反映现实情况;②使用特定的奖励函数,函数的设计原理不清楚,大多采用DQN算法不能询问最有利的症状。本文的贡献:

  • GAMP框架的新颖之处在于生成对抗网络(GAN)与RL模型的集成。以GAN的生成器作为RL的策略网络。GAN的鉴别器可以用来估计症状序列是医生询问的“真实“序列的可能性,并基于此设计了一个奖励函数来指导策略网络的优化。这种新的策略学习策略称为生成式对抗性正则化策略梯度。之所以提出该模型,是因为医生通常会根据已有医学知识根据患者自述症状提出相关的问题。例如,在问了一个“你会头痛吗”的问题和有一个“是”的答案后,很少有医生会问“你有脚痛吗”,因为这些症状几乎不可能存在于同一种疾病中而强化学习并没有相关的医学基础,所以会询问一些不相关的症状,并且基于有限的训练数据,强化学习很难捕获潜在的、复杂的医学知识
  • 其次,本文还使用互信息增强奖励函数以优化模型。之所以使用互信息是因为在诊断过程中,医生会询问最具鉴别力的症状,从而排除不可能的疾病(鉴别诊断)。在模型框架中,使用一个推断机制计算当前状态疾病概率分布和紧邻的下一个疾病概率分布的互信息,而后将计算的互信息融入奖励函数指导策略学习。
提出的方法(重点介绍对话智能体)

主要包括生成器、判别器、推理引擎三部分组成,生成器用于查询可能出现症状的患者;鉴别器用于评估查询的序列是真实还是虚假;推理引擎用于推断可能的疾病。NLU使用简单的Bi-LSTM模型,NLG基于模板(不是重点)。

基本流程与以上对话系统相同:整个对话代理系统接收患者的自我报告,学习问询的症状与患者互动,并在对话结束时做出诊断。当达到最大回合 T T T,或推理引擎达到阈值 τ \tau τ,或生成器选择终止的节点(疾病节点)时,对话会话将终止。

动作空间

Agent:如果有 m m m种疾病 n n n个症状,对话系统的动作空间大小为 m + n + l m+n+l m+n+l,其中 l l l为附加的动作,如感谢、结束等。

用户的动作空间:为对询问症状的回复,即拒绝、确认、不确定。

在每个交互步,对话状态追踪(DST)会记录对话的状态,包括之前用户和系统Agent的动作。

用户模拟器:

每个用户模拟器都会包含一个目标,一个用户目标通常由疾病标签、显性症状、隐性症状、请求槽值四部分组成。当Agent询问某一症状时,用户有三种回复: True (for the positive symptom), False (for the negative symptom), and Not sure (for the unmentioned symptom)。在会话结束时,用户模拟器会判断Agent诊断的正确与否。如果诊断错误或到达最大会话步,对话失败。

策略学习:

主要包括三部分:生成策略(系统动作)的生成器 G θ G_{\theta} Gθ、评估序列的真实与否的判别器 D ψ D_\psi Dψ、用于推断可能疾病的推理引擎 D ϕ D_\phi Dϕ

生成器 G θ G_{\theta} Gθ:症状序列生成器,使用LSTM预测下一个症状,利用患者的自述症状和对话进行训练,训练过程通常是语言模型的最大似然估计(MLE)方法。在训练过程中,生成器根据输入的序列 Y 1 : t − 1 Y_{1:t-1} Y1:t1输出预测的下一个症状 y t y_t yt。在训练完 G θ G_\theta Gθ后,利用其生成一系列虚假的症状序列,并将虚假的症状序列存入虚假数据集。

判别器 D ψ D_\psi Dψ:用于评估序列的真实与否,使用MLP对输入的序列进行鉴别。利用生成器生成的虚假序列以及真实的序列进行训练。

推理引擎 D ϕ D_\phi Dϕ:用于推断可能的疾病,使用MLP,最后一层使用softmax激活函数计算各个疾病的概率。使用真实数据,利用有监督方法进行训练,输入患者的自述症状及对话数据,输出疾病概率分布,标签为医生的诊断结果,采用交叉熵进行Loss计算。需要注意的是,一旦 D ϕ D_\phi Dϕ训练完成,在框架当中就不会再涉及它的更新,即可以看作离线训练一个分类器。

(我之前没想清楚为什么要使用生成对抗网络进行对话诊断,后来想到是不是想要模型通过产生与真实医生问诊尽可能相似的序列来实现贴近实际的诊断。但是这样真的合理吗?但互信息确实有一定的道理)

在这里插入图片描述

在强化学习框架下,生成器(策略) G θ ( a t ∣ s t ) G_\theta(a_t|s_t) Gθ(atst)的目标是从初始状态产生一个序列以最大化期望回报,其中 R T R_T RT是症状序列的奖励值, τ \tau τ是交互生成的轨迹, Q D ψ G θ ( s , a ) Q_{D_\psi}^{G_\theta}(s,a) QDψGθ(s,a)是状态-动作价值函数,表示在某一个状态下采取某个动作的估计价值,通过判别器 D ψ D_\psi Dψ计算得到。
J ( θ ) = E [ R T ∣ s , θ ] = ∑ t a u G θ ( a ∣ s ) ⋅ Q D ψ G θ ( s , a ) ( 1 ) J(\theta)=E[R_T|s,\theta]=\sum_{tau}G_\theta(a|s) \cdot Q_{D_\psi}^{G_\theta}(s,a)\qquad(1) J(θ)=E[RTs,θ]=tauGθ(as)QDψGθ(s,a)(1)

目标函数计算从初始状态(用户自述症状)开始的对话序列的奖励值,生成器的目标是生成一个判别器无法分辨真假的序列(像真实序列的假序列)。本文使用基于策略梯度的REINFORCE算法,并将 D ψ D_\psi Dψ的评估值作为每步的奖励。判别器从部分观测的症状序列和全部症状序列中计算奖励值 R D R_D RD
R D = Q D ψ G θ ( s = Y 1 : t − 1 , a = y t ) = D ψ ( Y 1 : t − 1 ) ( 2 ) R_D=Q_{D_\psi}^{G_\theta}(s=Y_{1:t-1},a=y_t)=D_\psi(Y_{1:t-1})\qquad(2) RD=QDψGθ(s=Y1:t1,a=yt)=Dψ(Y1:t1)(2)
而后,一旦产生了更加真实的序列,就利用公式(3)重新训练判别器。
m i n ψ − E Y − p d a t a [ l o g D ψ ( Y ) ] − E Y − G θ [ l o g ( 1 − D ψ ( Y ) ) ] ( 3 ) min_\psi-E_{Y-p_{data}}[logD_\psi(Y)]-E_{Y-G_\theta}[log(1-D_\psi(Y))]\qquad(3) minψEYpdata[logDψ(Y)]EYGθ[log(1Dψ(Y))](3)

互信息正则化策略梯度

在真实世界中,医生会询问一些有鉴别性的症状来做鉴别诊断。因此本文想要通过不断更新生成器使其获得问询关键症状的能力以此更好的区分难分辨的疾病。在此基础上,提出使用互信息来提升模型的性能。

在信息论中熵可以衡量事件的不确定性。而疾病诊断需要降低这种不确定性(找出与患者症状最相关的疾病),为了一步步的减少不确定性,生成器需要考虑去除不确定性的症状,这也意味着系统询问的症状可以降低疾病概率分布的熵这是很重要的。现实问诊流程也是一步一步的确定患者得了某种病,而不是某些病)。在信息论中, X X X Y Y Y之间的互信息 I ( X ; Y ) I(X;Y) I(X;Y)可以看成是一个随机变量中包含的关于另一个随机变量的信息量,或者说是一个随机变量由于已知另一个随机变量而减少的不肯定性。可以通过以下两个项的差值计算得到:
I ( X ; Y ) = H ( X ) − H ( X ∣ Y ) ( 4 ) I(X;Y)=H(X)-H(X|Y)\qquad(4) I(X;Y)=H(X)H(XY)(4)
在本文中,计算当前状态下疾病的概率分布 O t − 1 O_{t-1} Ot1与相邻的下一个状态的疾病概率分布 O t O_t Ot之间的互信息(想要通过问询症状提升患者患病的确定性):
I ( O t − 1 ; O t ∣ D ϕ ) = H ( O t − 1 ∣ D ϕ ) − H ( O t ∣ D ϕ ) ( 5 ) I(O_{t-1};O_t|D_\phi)=H(O_{t-1}|D_\phi)-H(O_t|D_\phi)\qquad(5) I(Ot1;OtDϕ)=H(Ot1Dϕ)H(OtDϕ)(5)

O t − 1 = D ϕ ( Y 1 : t − 1 ′ ) ( 6 ) O_{t-1}=D_\phi(Y_{1:t-1}^{'})\qquad(6) Ot1=Dϕ(Y1:t1)(6)

其中 H ( ⋅ ) H(\cdot) H()表示熵计算, Y 1 : t − 1 ′ Y_{1:t-1}^{'} Y1:t1表示患者确定存在的症状(图二中圆圈表示), Y 1 : t − 1 Y_{1:t-1} Y1:t1表示Agent询问的症状(图二中的菱形)。 y t y_t yt表示下一个要询问的症状。

在医学诊断过程中,医生询问可能的症状。它通常有两个意图。首先,医生想根据病人的回答来确认他的初步诊断。第二,医生可以根据病人的回答来排除可能出现的疾病。医学诊断的过程是逐步消除候选疾病的必要方法。如图3所示,每多询问一个症状,疾病的分布应该更加确定,即疾病的概率分布应该呈现山峰-山谷的形状而不是平坦的形状。

在这里插入图片描述

因此,给定推理引擎 D ϕ D_\phi Dϕ(在训练生成器过程中参数 ϕ \phi ϕ固定),通过生成器产生的候选询问症状应该逐渐提升互信息的奖励值 R M R_M RM
R M = Q D ϕ G θ ( s = Y 1 : t − 1 ′ , a = y t ) = I ( O t − 1 ; O t ∣ D ϕ ) ( 7 ) R_M=Q_{D_\phi}^{G_\theta}(s=Y_{1:t-1}^{'},a=y_t)=I(O_{t-1};O_t|D_\phi)\qquad(7) RM=QDϕGθ(s=Y1:t1,a=yt)=I(Ot1;OtDϕ)(7)
为了使生成器生成的症状序列更加自然真实且具有区分疾病的能力,将判别器的评估奖励与推理引擎得到的互信息奖励融合起来共同训练生成器的参数。
R F = ( 1 − λ ) R M + λ ( R D − ϵ ) ( 8 ) R_F=(1-\lambda)R_M+\lambda(R_D-\epsilon)\qquad(8) RF=(1λ)RM+λ(RDϵ)(8)
其中 λ \lambda λ代表奖励的权重, ϵ \epsilon ϵ控制判别器的效果。如果判别器认为生成器生成的序列为真实的,则后半部分的奖励为正值。一般设置 ϵ \epsilon ϵ为0.5,即概率大于0.5时判别器认为生成的序列是真实的。

当判别器更新后,开始更新生成器。基于强化学习REINFORCE算法(基于策略梯度更新)的更新公式如下, J ( θ ) J(\theta) J(θ)为目标函数。
∇ θ J ( θ ) = ∑ t = 1 T E y t − G θ [ ∇ θ l o g G θ ( y t ∣ Y 1 : t − 1 ) ⋅ R F ( t ) ] ( 9 ) \nabla_{\theta}J(\theta)=\sum_{t=1}^TE_{y_t-G_\theta}[\nabla_{\theta}logG_\theta(y_t|Y_{1:t-1})\cdot R_F^{(t)}]\qquad(9) θJ(θ)=t=1TEytGθ[θlogGθ(ytY1:t1)RF(t)](9)

θ : = θ + α ∇ θ J ( θ ) ( 10 ) \theta:=\theta+\alpha\nabla_{\theta}J(\theta)\qquad(10) θ:=θ+αθJ(θ)(10)

生成器

在本文中,模仿医生现实的问诊流程设计了一个序列查询生成器 G θ G_\theta Gθ,通过多轮交互的方式从患者自述症状开始对患者进行多轮问话,患者依据生成器查询的症状进行否定或确认。使用RNN构建生成器(LSTM),RNN通过递归使用更新函数 δ \delta δ将查询序列 x 1 , . . . , x T x_1,...,x_T x1,...,xT映射到隐藏状态 h 1 , . . . , h T h_1,...,h_T h1,...,hT。同时应用一个softmax输出层将隐藏状态映射到输出标记分布中,其中权重矩阵 W W W和向量 b b b是参数。
h t = δ ( h t − 1 , x t ) ( 11 ) h_t=\delta(h_{t-1},x_t)\qquad(11) ht=δ(ht1,xt)(11)

p ( x t ∣ x 1 , . . . , x t − 1 ) = s o f t m a x ( W h t + b ) ( 12 ) p(x_t|x_1,...,x_{t-1})=softmax(Wh_t+b)\qquad(12) p(xtx1,...,xt1)=softmax(Wht+b)(12)

判别器和推理引擎

结构见上文

实验
数据集

使用2018年Liu等(该数据集包含710个用户目标和66种症状,包括四种标记疾病,包括上呼吸道感染、儿童功能性消化不良、婴儿腹泻和儿童支气管炎)及2019年Xu等公开的数据集(该数据集总共包含527个会话数据和41个症状。共含五种疾病类型,包括变应性鼻炎、上呼吸道感染、肺炎、儿童手足口病和儿童腹泻。)。

实验设置

实验评估:评估标准包括以下三个

  • 诊断的准确率
  • 平均对话轮次
  • 正确匹配率,即Agent正确询问用户隐性症状的概率。

训练设置:不赘述

实验结果

包括对比实验和消融实验。
在这里插入图片描述
在这里插入图片描述

结论

本文提出了一个用于自动诊断的生成对抗性正则化互信息策略梯度框架(GAMP),旨在使一个更好的医疗对话系统具有更高的诊断精度和更少的交互对话回合。首先,使用生成式对抗性正则化策略梯度优化诊断系统,试图避免询问偏离医生的常见诊断范式的不合理的症状。其次,设计了一种新的奖励机制,添加互信息作为奖励功能的一部分。在两个公共数据集上的实验评估验证了所提方法的有效性。该框架不仅可以提高诊断的准确性,而且可以使用较少的查询来做出诊断决策。

  • 2
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

wavehaha

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值