【写在前面】该论文发表于KDD2020,属于数据挖掘的顶级会议。HiTANet在局部和全局两个阶段建模时间信息,局部评估阶段有一个时间感知Transformer,它将时间信息嵌入到visit-level embedding中,并且为每个visit生成一个局部attention权重;全局综合阶段进一步采用一个时间感知的 key-query attention机制,给不同时间步分配全局权重。最后,将两种类型的注意力权重进行动态组合,去生成用于风险预测的病人的表示。
论文地址:https://dl.acm.org/doi/pdf/10.1145/3394486.3403107
论文源码:https://github.com/HiTANet2020/HiTANet
病人EHR数据举例
每个病人在就诊时都会在不同的时间间隔内有多次入院记录,即visit;每次visit都会得到一些诊断结果,以ICD码的形式给出,如490。作者就是根据这些以往的诊断结果和时序信息来推断当前的诊断结果。
Model
该模型分为三个部分:
1. Visit Analysis
在第t次visit,模型会学习一个向量表示 v t v_t vt,融合了诊断代码 x t x_t xt和时间间隔 δ t \delta_t δt。基于 [ v 1 , v 2 , . . . , v T , v ∗ ] [v_1,v_2,...,v_T,v_*] [v1,v2,...,vT,v∗]使用Transformer学习一个隐状态 h t \boldsymbol{h_t} ht,然后使用隐状态生成一个局部attention分数 α t \alpha_t αt
2. Comprehensive Analysis
为了建模疾病发展的过程,首先将 h t \boldsymbol{h_t} ht嵌入成一个“query”向量 q q q,并且将每个时间间隔 δ t \delta_t δt嵌入成一个“key"向量,根据一个key-query attention来为每个visit x t x_t xt得到一个全局attention分数 β t \beta_t βt
3. Time-aware Dynamic Attention Fusion
该模型使用 α t \alpha_t αt和 β t \beta_t βt以及隐向量 h ∗ h_* h∗,为了为每个visit获得一个overall attention 分数 γ t ′ \gamma_t^{'} γt′,然后得到最终的表示 h ′ h^{'} h′用来进行预测
Local Level: Visit Analysis
使用一个线性函数将稀疏的visit向量 x t x_t xt映射到一个低维稠密的空间:
经过线性函数之后每个病人的表示为:
E
=
[
e
1
,
e
2
,
.
.
.
,
e
T
,
e
∗
]
E=[e_1,e_2,...,e_T,e_*]
E=[e1,e2,...,eT,e∗]。
使用如下函数,将时间间隔转换成一个和病人表示向量维度相同的向量,然后在进行累加 v t = e t + r t \boldsymbol{v_t}=\boldsymbol{e_t} + \boldsymbol{r_t} vt=et+rt.
在得到
V
=
[
v
2
,
v
2
,
.
.
.
,
v
T
,
v
∗
]
V=[v_2,v_2,...,v_T,v_*]
V=[v2,v2,...,vT,v∗]之后,使用给一个一层的Transformer(F)来学习每个visit的长期依赖。
当医生诊断出来时,他们将不只关注当前的就诊,而是回顾历史医疗记录,并搜索与目标疾病高度相关的记录。
为了模仿以上过程,做做使用了局部attention机制,为每个visit都得到一个attention值:
最后将
η
=
[
η
1
,
η
2
,
.
.
.
,
η
T
]
\boldsymbol\eta=[\eta_1,\eta_2,...,\eta_T]
η=[η1,η2,...,ηT]经过一个Softmax函数来生成局部attention值:
采用时间感知的Transformer,可以得到每个visit的attention权重。
Global Level: Comprehensive Analysis
事实上,医生不仅关注个人就诊,而且还通过分析整体诊断(即x∗)来做出疾病的最终判断。
为了模仿以上过程,作者提出了一个时间感知key-query attention机制。
对于整个诊断的隐状态表示 h ∗ \boldsymbol{h}_* h∗,首先利用它得到一个query向量 q ∈ R S q \in R^S q∈RS:
这是使用了ReLU 来保证数值为正值,相比于负值,正值的结果可能对总结整体诊断的特征更有价值。
在分析整体诊断时,医生还想知道哪些时间点对疾病至关重要。
采用与公式(2)相似的方式将时间信息编码到一个隐含空间中:
- Eq.(2)专注于捕获出现与时间信息相关的诊断代码的重要性。Eq.(6)试图描述在疾病进展过程中时间信息本身的重要性,而不考虑任何诊断代码。
为了了解风险预测过程中每个时间间隔的意义,需要得到每个时间间隔向量对整体隐含向量的注意力权重:
得到之后最终还是要经过一个Softmax层来进行归一化:
Time-aware Dynamic Attention Fusion
Local注意机制可以看作是采用“正向”操作来模仿医生的诊断程序,而Global注意机制类似于“逆向”操作,这是回顾性分析时间信息的重要性。
经过上面两个过程可以得到两个注意力权重 α \boldsymbol{\alpha} α和 β \boldsymbol{\beta} β,在这里使用一个动态attention融合机制,去捕捉不同cases下的访问表示和时间表示的偏好。
首先,将overall 表示 h ∗ \boldsymbol{h_*} h∗嵌入到一个新的空间中,并且使用Softmax进行归一化:
然后,我们根据注意力权重和嵌入的整体表示z为每次访问生成相应的整体关注权重,如下:
最后对整体权重进行归一化:
Prediction
使用最终生成的权重,来得到最终病人表示:
在预测的过程使用了一个简单的线性层和softmax层来进行线性预测:
在目标函数中使用了一个交叉熵来计算损失:
Model Algorithm
Experiment
我们将风险预测任务表述为一个二元分类问题来预测患者是否有一个特定的发病。
Experimental Setup
Datasets
Obstructive Pulmonary Disease(COPD), Heart Failure and Kidney Disease/慢性阻塞性肺病(COPD)、心力衰竭和肾病。(自己组建的数据集,没有公开)数据统计如下:
Baseline
Traditional methods : SVM, LR, RF
Plain RNNs: LSTM, GRU
Attention-based Models: Dipole, Retain, SAnD
Time-based Models: RetainEx, T-LSTM , TimeLine
Mertics
Accuracy (Acc), Precision (Pre), Recall, F1, and Area Under Curve (Auc) scores
Evaluation Strategy
training data, validation data, and testing data, in a ratio 0.75:0.10:0.15.
作者在验证集上修复了最佳模型,并报告了在测试集中的性能,随后执行了五次随机运行,并报告了测试性能的均值和标准偏差
Implementation Details
实验硬件配置 :Ubuntu 16.04 内存64G 和Tesla V100 GPU