好文推荐 A transformer-based representation-learning model with unified processing of multimodal input

本文介绍了一种新型的医疗决策支持系统IRENE,它利用Transformer处理胸部医学中的多模态数据,包括图像、非结构化主诉和结构化信息。通过双向跨模态注意力机制,IRENE能捕捉不同模态间的关系。实验对比展示了其在临床诊断中的优势。
摘要由CSDN通过智能技术生成

论文地址:https://www.nature.com/articles/s41551-023-01045-x

代码地址:https://github.com/RL4M/IRENE


基于Transformer的表示学习模型,作为临床诊断辅助工具,以统一的方式处理多模态输入。将图像与文字转化为visual tokens和text tokens,通过一个双向的跨模态注意力机制块共同学习不同信息间的整体特征和其关联性来做出决策。

第一个以统一方式使用人工智能处理多模态信息,在临床上辅助医生进行决策诊断。为后续医学领域人工智能处理多模态信息提供一种新的思路。

Data

胸腔医学中,除了胸部X射线,医生还需要考虑患者的人口统计学信息(如年龄和性别)、主诉(如现病史和既往病史)以及实验室检查报告,以便做出准确的诊断决策。实际上,医生会首先将异常的放射学图像模式与主诉中提到的症状或实验室检查报告中的异常结果相关联。然后,医生依靠他们丰富的领域知识和多年的培训,通过共同解释这些多模态数据来做出最佳诊断。

在医学临床领域,常见的数据类型有三种:
图像(Radiograph)
主诉(Chief complaint):非结构化信息,现病史和既往病史
人口统计信息和实验室检查结果(Demographicsand lab test results):结构化信息,性别年龄等

    当前的临床辅助决策系统,常采用非同一的方式。首先将非结构化的主诉转化成结构化的数据,然后将不同模态的数据输入到不同的机器学习模块中,产生特定模态的特征。最后使用融合模块对这些特征进行融合。
但是这样做的一个问题是,特定模态模型的训练与融合过程相分离,导致不能获取不同模态之间的联系与关联。

    本文提出的 IRENE 共同学习图像、非结构化主诉和结构化临床信息的整体表示来进行决策。

图1.非同一方式与统一方式

Network structure

IRENE由嵌入层、两个多模态注意力块、10个自注意力块和一个输出层组成。

  • free-form embedding 将非结构化与结构化的文字转化成text tokens, image embedding将图像转化成image tokens
  • bidirectional multimodal blocks 不仅计算同一模态内部的注意力,还计算不同模态之间的注意力
    bidirectional 指的是text tokens要与image token做注意力,同时image token也要与text token做注意力
  • Self-attention blocks 经过两个双向的多模态注意力块后,连接text tokens与image tokens,然后进行自注意力计算

在这里1233插入图片描述

图2.network structure

  1. 图像:图片经过一个卷积层
    在这里插入图片描述
  2. 文字
      主诉(ChiComp):经过bert获得token_id, max_len设为40
      实验室检查结果(LabTest):经过bert获得token_id max_len设为92
      人口统计信息(Sex、Age):经过一个Liner获得 长度为1
    最终得到一个长度为40+92+1+1的一维向量,再送入到嵌入层

在这里插入图片描述
3. 双向多模态注意力块
自注意力:输入序列经过一个线性映射,得到K,Q,V  (n, d)
     ①Q与K相乘,计算相似度,得到权重分布  (n, d) * (d, n) = (n,n)
     ②权重分布经过softmax进行归一化
     ③权重分布与V相乘,加权求和  (n, n) * (n,d) = (n, d)
经过自注意力机制,可以捕捉到输入序列中不同位置之间的关系和依赖

现在有两个K,Q,V:
 text tokens的 KT, QT, VT
 image tokens的KI, QI, VI
所以:
QI与KI, VI计算注意力捕捉图像之间的依赖关系, QI与KT, VT计算注意力捕捉图像与文本之间的依赖关系
QT与KT, VT计算注意力捕捉文本之间的依赖关系, QT与KI, VI计算注意力捕捉文本与图像之间的依赖关系

在这里插入图片描述

在本文中,取λ为1。在第一层双向多模态注意力层中Xi 和Xt分别取平均送入第二层双向多模态注意力块,
在第二层双向多模态注意力层中Xi 和Xt分别取平均后,进行拼接,送入自注意力块中

在这里插入图片描述

实验

对比实验:

  • 只使用图像
  • 非统一方式早期融合
  • 非统一方式晚期融合
  • 多模态模型GIT: 在大量的图像-文本对中训练   问题:临床医学数据难获取
  • 多模态模型Perceiver: 将不同模态的数据进行拼接作为输入。 问题:某个模态数据较少时,关注度低
    在这里插入图片描述
    消融实验:
  • 不加入双向多模态注意力
  • 加入单向多模态注意力
  • 加入6层双向多模态注意力
  • 不加入主诉
  • 不加入实验室检查结果
  • 不加入图像数据在这里插入图片描述

代码解读

整体架构

class IRENE(nn.Module):
    def __init__(self, config, img_size=224, num_classes=21843, zero_head=False, vis=False):
        super(IRENE, self).__init__()
        self.num_classes = num_classes
        self.zero_head = zero_head
        self.classifier = config.classifier
        self.transformer = Transformer(config, img_size, vis)
        self.head = Linear(config.hidden_size, num_classes)
	# 经过一个自定义的Transformer,再经过一个Linear
    def forward(self, x, cc=None, lab=None, sex=None, age=None, labels=None):
        x, attn_weights = self.transformer(x, cc, lab, sex, age)
        logits = self.head(torch.mean(x, dim=1))

        if labels is not None:
            loss_fct = BCEWithLogitsLoss()
            loss = loss_fct(logits.view(-1, self.num_classes), labels.float())
            return loss
        else:
            return logits, attn_weights, torch.mean(x, dim=1)

自定义的Transformes:Embeddings和Encoder组成

class Transformer(nn.Module):
    def __init__(self, config, img_size, vis):
        super(Transformer, self).__init__()
        self.embeddings = Embeddings(config, img_size=img_size)
        self.encoder = Encoder(config, vis)

    def forward(self, input_ids, cc=None, lab=None, sex=None, age=None):
        embedding_output, cc, lab, sex, age = self.embeddings(input_ids, cc, lab, sex, age)
        text = torch.cat((cc, lab, sex, age), 1)
        encoded, attn_weights = self.encoder(embedding_output, text)
        return encoded, attn_weights

Encoder

class Encoder(nn.Module):
    def __init__(self, config, vis):
        super(Encoder, self).__init__()
        self.vis = vis
        self.layer = nn.ModuleList()
        self.encoder_norm = LayerNorm(config.hidden_size, eps=1e-6)
        for i in range(config.transformer["num_layers"]):
            if i < 2: # 两个双向多模态注意力
                layer = Block(config, vis, mm=True)
            else: # 自注意力
                layer = Block(config, vis)
            self.layer.append(copy.deepcopy(layer))

    def forward(self, hidden_states, text=None):
        attn_weights = []
        for (i, layer_block) in enumerate(self.layer):
            if i == 2: #在第二个双向多模态注意力块后,拼接img与text,送入自注意力块
                hidden_states = torch.cat((hidden_states, text), 1)  
                hidden_states, weights = layer_block(hidden_states)
            elif i < 2: # hidden_states:img
                hidden_states, text, weights = layer_block(hidden_states, text)
            else:
                hidden_states, weights = layer_block(hidden_states)
            if self.vis:
                attn_weights.append(weights)
        encoded = self.encoder_norm(hidden_states)
        return encoded, attn_weights

双向多模态注意力

# img-img
'''
	需要计算四个注意力:
	text-text   text-img   img-img   img-text
'''
attention_scores_img = torch.matmul(query_layer, key_layer.transpose(-1, -2))
attention_scores_img = attention_scores_img / math.sqrt(self.attention_head_size)
attention_probs_img = self.softmax(attention_scores_img)
weights = attention_probs_img if self.vis else None
attention_probs_img = self.attn_dropout(attention_probs_img)
context_layer_img = torch.matmul(attention_probs_img, value_layer_img)
context_layer_img = context_layer_img.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer_img.size()[:-2] + (self.all_head_size,)
context_layer_img = context_layer_img.view(*new_context_layer_shape)

# text-text
attention_scores_text = torch.matmul(query_layer_text, key_layer_text.transpose(-1, -2))
attention_scores_text = attention_scores_text / math.sqrt(self.attention_head_size)
attention_probs_text = self.softmax(attention_scores_text)
attention_probs_text = self.attn_dropout_text(attention_probs_text)
context_layer_text = torch.matmul(attention_probs_text, value_layer_text)
context_layer_text = context_layer_text.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer_text.size()[:-2] + (self.all_head_size,)
context_layer_text = context_layer_text.view(*new_context_layer_shape)

# img-text
attention_scores_it = torch.matmul(query_layer_img, key_layer_text.transpose(-1, -2))
attention_scores_it = attention_scores_it / math.sqrt(self.attention_head_size)
attention_probs_it = self.softmax(attention_scores_it)
attention_probs_it = self.attn_dropout_it(attention_probs_it)
context_layer_it = torch.matmul(attention_probs_it, value_layer_text)
context_layer_it = context_layer_it.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer_it.size()[:-2] + (self.all_head_size,)
context_layer_it = context_layer_it.view(*new_context_layer_shape)

# text-img
attention_scores_ti = torch.matmul(query_layer_text, key_layer_img.transpose(-1, -2))
attention_scores_ti = attention_scores_ti / math.sqrt(self.attention_head_size)
attention_probs_ti = self.softmax(attention_scores_ti)
attention_probs_ti = self.attn_dropout_ti(attention_probs_ti)
context_layer_ti = torch.matmul(attention_probs_ti, value_layer_img)
context_layer_ti = context_layer_ti.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer_ti.size()[:-2] + (self.all_head_size,)
context_layer_ti = context_layer_ti.view(*new_context_layer_shape)

# img-img 与 img-text取平均
attention_output_img = self.out((context_layer_img + context_layer_it)/2)
# text-text 与 text-img取平均
attention_output_text = self.out((context_layer_text + context_layer_ti)/2)

attention_output_img = self.proj_dropout(attention_output_img)
attention_output_text = self.proj_dropout_text(attention_output_text)
return attention_output_img, attention_output_text, weights
  • 13
    点赞
  • 16
    收藏
    觉得还不错? 一键收藏
  • 4
    评论
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值