生存分析是一种典型的医疗时间事件(time-event)分析场景,其主要分析序列研究中事件(如复发、死亡、治愈等)随着时间变化的统计规律,从而发现其中的敏感/危险因子。
其经典的统计学手段主要有:
(1)参数估计法:即知道其分布函数,根据数据估计其分布参数;
(2)非参数估计法:如KM估计(Kaplan-Meier);
(3)半参数估计法:如Cox比例风险模型。
在生存分析中,无论是其计算(如KM估计、Cox比例风险模型),还是其评估(如C-index)均需要注意右删失(right-censored)数据的处理。比如在KM估计中,累乘的每个时间段的分母总数均需剔除右删失数据,而在Cox比例风险模型和C-index计算中,均仅考虑比基准案例时间更长的右删失案例作为pair进行计算。
Cox比例风险模型是一种广义线性回归模型,其通过将风险函数分解为独立的时间项(即基准风险函数)和敏感因子项,从而可以忽略时间因素的影响:
h
(
X
,
t
)
=
h
0
(
t
)
∗
exp
β
X
h(X,t)=h_0(t)*\exp^{\beta X}
h(X,t)=h0(t)∗expβX
其目标函数为局部似然(partial likelihood)函数:
L
=
∏
i
:
E
i
=
1
exp
h
^
θ
(
x
i
)
∑
j
∈
R
(
T
i
)
exp
h
^
θ
(
x
j
)
L=\prod\limits_{i:E_i=1}\frac{\exp^{\hat h_\theta(x_i)}}{\sum\limits_{j\in R(T_i)}\exp^{\hat h_\theta(x_j)}}
L=i:Ei=1∏j∈R(Ti)∑exph^θ(xj)exph^θ(xi)
其中 i , j i,j i,j均为样本编号, E = 1 E=1 E=1表示终止事件(死亡、复发)发生; R ( T i ) R(T_i) R(Ti)表示满足 T j > T i T_j>T_i Tj>Ti条件,对应的损失函数(定义为neg log partial likelihood loss): l = − log L = − ∑ i : E i = 1 ( h ^ θ ( x i ) − log ∑ j ∈ R ( T i ) exp h ^ θ ( x j ) ) l=-\log L=-\sum\limits_{i:E_i=1}(\hat h_\theta(x_i)-\log{\sum\limits_{j\in R(T_i)}\exp^{\hat h_\theta(x_j)}}) l=−logL=−i:Ei=1∑(h^θ(xi)−logj∈R(Ti)∑exph^θ(xj))
Cox比例风险模型将传统的敏感因子风险函数转换为线性方程的形式:
h
(
x
;
β
)
=
β
X
h(x;\beta)=\beta X
h(x;β)=βX其假设过强,仅能捕捉到线性关系。
而DeepSuv模型将深度学习理念应用于Cox比例风险模型,将敏感因子风险函数通过多层感知机(额外会引入非线性激活函数、drop out、BN等)的形式进行表达,更够更好的捕捉变量间的关系。
局部似然函数从本质上看其实就是list-wise ranking,区别于常规IR问题中list-wise ranking之处在于doc list是每个sample直接给出的;而在生存分析中是从全部样本中,对每个sample筛选出满足
R
(
T
i
)
R(T_i)
R(Ti)条件的作为list。因此,为了更方便的计算每个mini-batch
应当选取全量的样本数(配合下面给出的NegativeLogLikelihood
使用)。
在使用中,其每条训练数据包括如下三大部分:
(1)risk_pred: 预测的生存期/风险函数,即cox回归指数项上的结果,注意该数据与实际生存期间的正负关系(比如风险函数与生存期为反相关关系)
(2)y: 真实事件终止事件(可能为右删失数据,也有可能为真实事件终止)
(3)e: event indicator, 1-事件终止; 0-右删失
下面给出了其损失函数的定义,其中需要特别注意mask
矩阵的定义和应用:
import torch
import torch.nn as nn
# 写法一:以coloum为基准
class NegativeLogLikelihood(nn.Module):
def __init__(self):
super(NegativeLogLikelihood, self).__init__()
def forward(self, risk_pred, y, e):
"""
@params: risk_pred: 预测的生存期/风险函数,即cox回归指数项上的结果,注意该数据与实际生存期间的正负关系(比如风险函数与生存期为法相关系) shape: (N,1)
@params: y: 真实事件终止事件(可能为右删失数据,也有可能为真实事件终止) shape:(N,1)
@params: e: event indicator, 1-事件终止; 0-右删失 shape:(N,1)
"""
mask = torch.ones(y.shape[0], y.shape[0]) # mask矩阵, mask(i,j)中j表示基准事件,i为其它对比事件
mask[(y.T-y) > 0] = 0 # 基准事件真实存活期大于其它对比事件的,无需考虑
exp_loss = torch.exp(risk_pred) * mask # mask非必要项,(N, N)
log_loss = torch.log((exp_loss.sum(dim=0))/(mask.sum(dim=0))) # 这里取平均以消除pair中样本长度的影响, (N, 1)
log_loss = log_loss.reshape(-1, 1)
neg_log_loss = -torch.sum((risk_pred - log_loss) * e) / torch.sum(e) # 不需要计入右删失值
return neg_log_loss
# 写法二:以row为基准
class NegativeLogLikelihood2(nn.Module):
def __init__(self):
super(NegativeLogLikelihood2, self).__init__()
self.epsilon = torch.tensor(1e-5)
def forward(self, risk_pred, y, e):
"""
@params: risk_pred: 预测的生存期/风险函数,即cox回归指数项上的结果,注意该数据与实际生存期间的正负关系(比如风险函数与生存期为法相关系) shape: (N,1)
@params: y: 真实事件终止事件(可能为右删失数据,也有可能为真实事件终止) shape:(N,1)
@params: e: event indicator, 1-事件终止; 0-右删失 shape:(N,1)
"""
mask = torch.ones(y.shape[0], y.shape[0]) # mask矩阵, mask(i,j)中i表示基准事件,j为其它对比事件
mask[(y-y.T) > 0] = 0 # 基准事件真实存活期大于其它对比事件的,无需考虑
exp_loss = torch.exp(risk_pred.T) * mask # mask非必要项,(N, N)
log_loss = torch.log((exp_loss.sum(dim=1))/(mask.sum(dim=1))) # 取平均,注意防止数值下溢, (N,)
e = e.reshape(-1)
neg_log_loss = -torch.sum((risk_pred.T - log_loss) * e) / torch.sum(e) # 不需要计入右删失值
return neg_log_loss
DeepSuv模型除了能够给出风险函数,还可以用于治疗方案的对比。具体来说,可以在训练、测试数据中加入治疗方案的因子,通过对比在不同方案下(其他因素均保持一致)的风险函数大小给出治疗方案的排序。
【Reference】