- 发表时间:2018
- 论文链接:https://www.aclweb.org/anthology/P18-2033
- 代码:https://github.com/LiuQL2/MedicalChatbot
- 代码语言:python
摘要
本文构建了一个用于自动诊断的对话系统。首先,从线上医学论坛上病人的自述以及病人医生间的交谈中提取症状,从而构建数据集;然后,本文提出了用于自动诊断的任务型对话系统框架,该系统能够通过与病人交谈,获取除病人自述外的其他症状。实验表明,从交谈中获取的额外症状能够极大地提升疾病诊断精度,本文的对话系统能够自动地收集这些症状,而且诊断准确度更高。
数据
数据是从中文医学网站上的儿科中收集的,包括四种疾病类型:上呼吸道感染、儿童功能性消化不良、腹泻和支气管炎。标记数据包括两个过程:症状提取、症状归一化。
症状提取
症状归一化
不同人表述症状是不一样的,比如有人说拉肚子,有人说腹泻,因此要将这些症状表述为专业术语,采用的是SNOMED CT(一种临床医学语标准)标准,如图2所示。通过病人所提供的症状可以分为两类:显性症状和隐性症状。显性症状是指病人在咨询时提供的症状,如病人说:“医生,我流鼻涕打喷嚏,这是怎么回事啊”,其中鼻流涕和打喷嚏就是两个显性症状。隐性症状是指医生通过咨询获知的症状,如医生接着问:“那你拉肚子不?”, 此处的腹泻就是隐性症状。
本文框架
本文对话系统框架包括三大模块:NLU(自然语言理解):检测用户意图、提取槽位值;DM(对话管理):追踪对话状态、给出系统行动;NLG(自然语言生成):根据系统行动生成自然语言。其中NLU和NLG模块都是采用基于模板的方法,重点研究DM模块。
DM模块包含两个子模块:对话状态追踪(DST)和策略学习。
对话策略学习
用户模拟器
训练是采用DQN,网络的输入是当前状态 S t S_t St,输出为agent的action。那么训练数据从何而来呢,这时候就需要用户模拟器,和agent模拟任务驱动对话过程。我们称这个过程为warm_start。此处假设有4种疾病,warm_start过程描述如下:
- 对话管理系统初始化
- 用户模拟器初始化action a u , 0 a_{u,0} au,0:随机从所有数据中选取一条,将数据中的所有显性症状作为inform_slots(例如上述例子中的鼻流涕和打喷嚏),action为:request, request_slots为disease。
- 利用初始的用户action更新状态
- 初始化agent。
- 模拟对话系统
- agent基于规则根据状态选择action:此处分析一下状态中用户提供了哪些显性症状,用户的需求是询问哪种疾病,例如:agent在数据集中发现只有鼻流涕和打喷嚏是不能确定任何一种疾病的,否则对话就结束了。那么agent就在数据集中查询4种疾病中哪种出现这两种症状的频率比较高,假设发现4种疾病中上呼吸道感染最容易出现这两种症状,但是还有三种症状需要询问病人去确认一下,那就从这三种症状中随机选一种吧,此时,agent采取action 为request,request_slots就是刚才随机选择的症状;
- agent采取行动后,DM就更新状态(此状态包含的信息比较多:对话轮数、agent_action、user_action、current_slots(提供当前需要的信息)、agent和用户的历史action信息等)
- user再根据状态和agent的action进行回复,如确认一下有没有agent询问的特征;
- DM再次更新状态;
上述过程不断进行,直到agent确诊了疾病是什么,或者达到我们设定的最大对话轮数。
从上述过程中,我们将(state,agent_action,reward,next_state,对话是否结束)这些信息记录下来,后续用作训练数据。
其中括号中的大多都是采用热编码的形式。(具体的过程请参看代码)
DQN训练
DQN是一种off-policy的深度强化学习算法,off-policy就表明了agent的动作选择网络和目标网络肯定不是同一个,因此有两个一样的网络,网络的输入是state,输出为agent_action。
- 网络结构:网络结构非常简单,一层全连接层;
- 输入(batch_size, 200):200是个假设的state表示维度,在实际使用中决定其大小的因素比较多;
- 输出:假设agent_action有300种,则输出为(batch_size,300)
动作选择网络每个iteration的参数都会利用loss回传进行更新,那么loss函数是怎么算的呢?
其标签就是将 next_state记为
s
′
s^{'}
s′送给目标网络,
y
i
=
r
+
γ
m
a
x
a
′
Q
(
s
′
,
a
′
)
y_{i} = r+\gamma max_{a^{'}}Q(s^{'},a^{'})
yi=r+γmaxa′Q(s′,a′),其中
r
r
r表示及时回报。
loss函数就是目标网络和动作选择网络输出
y
i
y_{i}
yi间的均方误差。
动作选择网络利用上述过程实时更新,目标网络在每个epoch结束时,直接copy动作选择网络的参数进行更新。