利用深度学习进行生存分析——DeepSuv模型小结

生存分析是一种典型的医疗时间事件(time-event)分析场景,其主要分析序列研究中事件(如复发、死亡、治愈等)随着时间变化的统计规律,从而发现其中的敏感/危险因子。

其经典的统计学手段主要有:
(1)参数估计法:即知道其分布函数,根据数据估计其分布参数;
(2)非参数估计法:如KM估计(Kaplan-Meier);
(3)半参数估计法:如Cox比例风险模型。

在生存分析中,无论是其计算(如KM估计、Cox比例风险模型),还是其评估(如C-index)均需要注意右删失(right-censored)数据的处理。比如在KM估计中,累乘的每个时间段的分母总数均需剔除右删失数据,而在Cox比例风险模型和C-index计算中,均仅考虑比基准案例时间更长的右删失案例作为pair进行计算。

Cox比例风险模型是一种广义线性回归模型,其通过将风险函数分解为独立的时间项(即基准风险函数)和敏感因子项,从而可以忽略时间因素的影响: h ( X , t ) = h 0 ( t ) ∗ exp ⁡ β X h(X,t)=h_0(t)*\exp^{\beta X} h(X,t)=h0(t)expβX
其目标函数为局部似然(partial likelihood)函数: L = ∏ i : E i = 1 exp ⁡ h ^ θ ( x i ) ∑ j ∈ R ( T i ) exp ⁡ h ^ θ ( x j ) L=\prod\limits_{i:E_i=1}\frac{\exp^{\hat h_\theta(x_i)}}{\sum\limits_{j\in R(T_i)}\exp^{\hat h_\theta(x_j)}} L=i:Ei=1jR(Ti)exph^θ(xj)exph^θ(xi)

其中 i , j i,j i,j均为样本编号, E = 1 E=1 E=1表示终止事件(死亡、复发)发生; R ( T i ) R(T_i) R(Ti)表示满足 T j > T i T_j>T_i Tj>Ti条件,对应的损失函数(定义为neg log partial likelihood loss): l = − log ⁡ L = − ∑ i : E i = 1 ( h ^ θ ( x i ) − log ⁡ ∑ j ∈ R ( T i ) exp ⁡ h ^ θ ( x j ) ) l=-\log L=-\sum\limits_{i:E_i=1}(\hat h_\theta(x_i)-\log{\sum\limits_{j\in R(T_i)}\exp^{\hat h_\theta(x_j)}}) l=logL=i:Ei=1(h^θ(xi)logjR(Ti)exph^θ(xj))

Cox比例风险模型将传统的敏感因子风险函数转换为线性方程的形式:
h ( x ; β ) = β X h(x;\beta)=\beta X h(x;β)=βX其假设过强,仅能捕捉到线性关系。

而DeepSuv模型将深度学习理念应用于Cox比例风险模型,将敏感因子风险函数通过多层感知机(额外会引入非线性激活函数、drop out、BN等)的形式进行表达,更够更好的捕捉变量间的关系。

局部似然函数从本质上看其实就是list-wise ranking,区别于常规IR问题中list-wise ranking之处在于doc list是每个sample直接给出的;而在生存分析中是从全部样本中,对每个sample筛选出满足 R ( T i ) R(T_i) R(Ti)条件的作为list。因此,为了更方便的计算每个mini-batch应当选取全量的样本数(配合下面给出的NegativeLogLikelihood使用)。

在使用中,其每条训练数据包括如下三大部分:
(1)risk_pred: 预测的生存期/风险函数,即cox回归指数项上的结果,注意该数据与实际生存期间的正负关系(比如风险函数与生存期为反相关关系)
(2)y: 真实事件终止事件(可能为右删失数据,也有可能为真实事件终止)
(3)e: event indicator, 1-事件终止; 0-右删失

下面给出了其损失函数的定义,其中需要特别注意mask矩阵的定义和应用:

import torch
import torch.nn as nn

# 写法一:以coloum为基准
class NegativeLogLikelihood(nn.Module):
    def __init__(self):
        super(NegativeLogLikelihood, self).__init__()

    def forward(self, risk_pred, y, e):
        """
        @params: risk_pred: 预测的生存期/风险函数,即cox回归指数项上的结果,注意该数据与实际生存期间的正负关系(比如风险函数与生存期为法相关系)   shape: (N,1)
        @params: y: 真实事件终止事件(可能为右删失数据,也有可能为真实事件终止)    shape:(N,1)
        @params: e: event indicator, 1-事件终止; 0-右删失     shape:(N,1)
        """
        mask = torch.ones(y.shape[0], y.shape[0])     # mask矩阵, mask(i,j)中j表示基准事件,i为其它对比事件
        mask[(y.T-y) > 0] = 0             # 基准事件真实存活期大于其它对比事件的,无需考虑
        exp_loss = torch.exp(risk_pred) * mask       # mask非必要项,(N, N)
        log_loss = torch.log((exp_loss.sum(dim=0))/(mask.sum(dim=0)))      # 这里取平均以消除pair中样本长度的影响, (N, 1)
        log_loss = log_loss.reshape(-1, 1)
        neg_log_loss = -torch.sum((risk_pred - log_loss) * e) / torch.sum(e)       # 不需要计入右删失值
        return neg_log_loss

# 写法二:以row为基准
class NegativeLogLikelihood2(nn.Module):
    def __init__(self):
        super(NegativeLogLikelihood2, self).__init__()
        self.epsilon = torch.tensor(1e-5)

    def forward(self, risk_pred, y, e):
        """
        @params: risk_pred: 预测的生存期/风险函数,即cox回归指数项上的结果,注意该数据与实际生存期间的正负关系(比如风险函数与生存期为法相关系)   shape: (N,1)
        @params: y: 真实事件终止事件(可能为右删失数据,也有可能为真实事件终止)    shape:(N,1)
        @params: e: event indicator, 1-事件终止; 0-右删失     shape:(N,1)
        """
        mask = torch.ones(y.shape[0], y.shape[0])     # mask矩阵, mask(i,j)中i表示基准事件,j为其它对比事件
        mask[(y-y.T) > 0] = 0             # 基准事件真实存活期大于其它对比事件的,无需考虑
        exp_loss = torch.exp(risk_pred.T) * mask       # mask非必要项,(N, N)
        log_loss = torch.log((exp_loss.sum(dim=1))/(mask.sum(dim=1)))      # 取平均,注意防止数值下溢, (N,)
        e = e.reshape(-1)
        neg_log_loss = -torch.sum((risk_pred.T - log_loss) * e) / torch.sum(e)       # 不需要计入右删失值
        return neg_log_loss

DeepSuv模型除了能够给出风险函数,还可以用于治疗方案的对比。具体来说,可以在训练、测试数据中加入治疗方案的因子,通过对比在不同方案下(其他因素均保持一致)的风险函数大小给出治疗方案的排序。

【Reference】

  1. 生存分析基本概念介绍
  2. DeepSurv: personalized treatment recommender system using a Cox proportional hazards deep neural network
  • 7
    点赞
  • 55
    收藏
    觉得还不错? 一键收藏
  • 4
    评论
深度学习生存期预测模型是一种利用深度学习算法来预测患者生存时间的模型。与传统的生存分析方法不同,深度学习生存期预测模型可以考虑更多的因素,并且能够进行个性化的生存率估计。 首先,深度学习生存期预测模型使用复杂的神经网络结构来处理生存数据。这些神经网络可以学习到输入特征与生存时间之间的复杂非线性关系。通过训练模型并使用大量的数据,深度学习可以提取出隐藏在数据中的重要特征,并用于预测患者的生存时间。 其次,深度学习生存期预测模型可以考虑患者的特殊情况。这意味着模型可以根据患者的个体特征和临床指标来进行个性化的生存率估计。例如,模型可以考虑患者的年龄、性别、疾病分期、基因表达等因素,从而更准确地预测患者的生存时间。 最后,深度学习生存期预测模型可以通过比较不同的患者群体来评估预后。模型可以根据不同患者群体的特征和临床指标,建立不同的生存函数,并将它们绘制在同一张图上进行比较。这样,我们可以了解不同患者群体之间生存概率的差异。 综上所述,深度学习生存期预测模型是一种能够考虑更多因素、进行个性化估计并比较不同患者群体的生存预测模型。它为医生和研究人员提供了更准确的生存时间预测,有助于制定更精准的治疗方案和预测患者的预后。<span class="em">1</span><span class="em">2</span><span class="em">3</span> #### 引用[.reference_title] - *1* *2* *3* [深度学习用于医学预后-第二课第三周14-15节-评估方法比较以及Kaplan-Meier估计](https://blog.csdn.net/u014264373/article/details/130690941)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v93^chatsearchT3_2"}}] [.reference_item style="max-width: 100%"] [ .reference_list ]

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值