论文阅读--Multi-Attention Multiple Instance Learning

该文介绍了注意力机制如何增强多实例学习(MIL)的性能,尤其是在图像分析任务中。通过引入注意力机制,模型能够更好地关注图像中的关键区域。文章详细阐述了模型的构建,包括特征提取、邻域注意力、模板学习和全局注意力。实验在MNIST数据集以及其它表格数据集上展示了模型的有效性,表明注意机制可以提高分类准确性和模型的解释性。
摘要由CSDN通过智能技术生成

注意力机制背后的最初想法是取代简单的平均输出
简单的平均输出: y ∗ = n − 1 ∑ i = 1 n y i y^*=n^{-1}\sum\limits_{i=1}^{n}y_i y=n1i=1nyi
引入回归模型后 x x x的标签的预测值为:
y ∗ = ∑ i = 1 n α ( x , x i ) y i y^*=\sum\limits_{i=1}^{n}\alpha(x,x_i)y_i y=i=1nα(x,xi)yi
其中的加权为: α ( x , x i ) \alpha(x,x_i) α(x,xi)表示 x i x_i xi x x x的相关性, x i x_i xi属于训练集的向量, y i y_i yi表示训练向量的对应标签。

权重的原始形式之一是由核K定义的,它可以被视为一个得分函数,估计向量 x i x_i xi x x x的相关性。权重公式如下:
α ( x , x i ) = K ( x , x i ) ∑ j = 1 n K ( x , x j ) \alpha(x,x_i)=\frac{K(x,x_i)}{\sum_{j=1}^{n} K(x,x_j)} α(x,xi)=j=1nK(x,xj)K(x,xi)

符号含义
X X X
I I I实例
N N N包个数
m k m_k mk k k k个包中实例的个数
( x i , y i ) (x_i,y_i) (xi,yi) i i i个实例在包中的坐标
N i ⊂ M \mathcal{N}_i\subset\mathcal{M} NiM i i i个小块的邻居的索引集
P = P 1 , … , P C \mathcal{P}={P_1,\dots,P_C} P=P1,,PC模板的集合

开局一张图

在这里插入图片描述

  • 最左边的图片表示一个包,其中的每一个小块表示包中的每一个实例,假设其中有 m m m个实例,使用两个不同颜色的高亮色块表示第 i i i个实例 I i I_i Ii和第 j j j个实例 I j I_j Ij

  • 所有小块(实例)都被输入经过训练的神经网络,使用一维卷积计算其嵌入 F i F_i Fi,目的是降低实例 I i I_i Ii的维度:
    F i = C o n v ( I i ) , i = 1 , … , m . F_i=Conv(I_i),i=1,\dots,m. Fi=Conv(Ii),i=1,,m.

  • 我们使用 ( x i , y i ) (x_i,y_i) (xi,yi)表示在整个图像中第 i i i个小块的坐标。
    使用 M = { 1 , … , m } \mathcal{M}=\{1,\dots,m\} M={1,,m}表示整幅图像中所有小块的索引集。
    i i i个小块的邻居的索引集 N i ⊂ M \mathcal{N}_i\subset\mathcal{M} NiM为:
    N = { j ∈ M ∣ 0 < max ⁡ ( ∣ x i − x j ∣ , ∣ y i − y j ∣ ) ≤ d } \mathcal{N}=\{j\in \mathcal{M}|0<\max(|x_i-x_j|,|y_i-y_j|)\leq d\} N={jM∣0<max(xixj,yiyj)d}
    d d d可以是任意值,在本实验中 d d d的取值为1。并且值得注意的是,邻居是在包内找的。

  • 邻居的使用可以改善分类结果,这里我们引入注意机制,计算 I i I_i Ii邻居联合嵌入 B i B_i Bi为:
    B i = ∑ j ∈ N i α j ( i ) F j B_i = \sum\limits_{j\in\mathcal{N}_i}\alpha^{(i)}_{j}F_j Bi=jNiαj(i)Fj

  • 根据第实例 I i I_i Ii的第 j j j个邻居的权重 α j ( i ) \alpha_j^{(i)} αj(i)
    α j ( i ) = s o f t m a x ( s ( i ) ) = e x p ( s j ( i ) ) ∑ t e x p ( s t ( i ) ) \alpha_j^{(i)}=\mathbf{softmax}(s^{(i)})=\frac{exp(s_j^{(i)})}{\sum_texp(s_t^{(i)})} αj(i)=softmax(s(i))=texp(st(i))exp(sj(i))
    其中 s j ( i ) s_j^{(i)} sj(i)是将两个向量 F i F_i Fi F j F_j Fj映射到标量的注意力评分函数:
    s j ( i ) = s c o r e n b ( F i , F j ) = F T t a n h ( V n b F j ) s_j^{(i)}=\mathbf{score}_{nb}(F_i,F_j)=F^Ttanh(V_{nb}F_j) sj(i)=scorenb(Fi,Fj)=FTtanh(VnbFj)
    其中 V V V是可学习参数的矩阵, t a n h tanh tanh是双曲正切函数
    在这里插入图片描述

  • 总之,我们为每个 F i F_i Fi得到一个嵌入 B i B_i Bi B i B_i Bi中包含了关于邻居的信息。将第 i i i个实例的嵌入与邻域 B i B_i Bi的联合嵌入串联起来,得到嵌入 T i T_i Ti:
    T i = ( F i , B j ) T_i=(F_i,B_j) Ti=(Fi,Bj)
    串联操作在图中用符号⊕表示. 现在每个包都由一组嵌入向量 T i T_i Ti表示。

该算法提出的另一个重要概念是一组嵌入模板,这些模板被视为可学习向量。此外,模板数量 C C C,其值决定分类质量。模板是实现 T i T_i Ti集的多注意力的方式,由于神经网络的初始化规则不同,它们是不同的在作者提供的代码中,使用如下方式获得template:

self.templates = nn.Parameter(torch.randn(self.n_templates, self.L2,                                                   requires_grad=True))

这里的templates是一个n_templates*L2大小的tensor, ( 10 ∗ 256 ) (10*256) (10256),这些随机数字满足标准正态分布(0~1)。代码中n_templates的值为 10 10 10,L2的值为 128 ∗ 2 128*2 1282。即这里生成了10个 1 ∗ 256 1*256 1256大小的向量,因为 P k P_k Pk要和 T i T_i Ti求相似性,所以这两个的形状是一样的。

使用 P = P 1 , … , P C \mathcal{P}={P_1,\dots,P_C} P=P1,,PC表示模板的集合,然后,每个模板产生相应的聚合嵌入 E k E_k Ek
E k = ∑ i = 1 m β k ( k ) T i , k = 1 , … , C E_k=\sum\limits_{i=1}^{m}\beta_k^{(k)}T_i,k=1,\dots,C Ek=i=1mβk(k)Ti,k=1,,C
这里其实是对包中的所有实例的映射向量进行求和,使用 β k ( k ) \beta_k^{(k)} βk(k)实现注意力机制的引入。

其中:这个公式不懂没有说score是怎么算的,猜测是和上述 F i F_i Fi F j F_j Fj的打分方法一样。
β k ( k ) = s o f t m a x ( ⟨ s c o r e ( P k , T i ) ⟩ i ) \beta_k^{(k)}=\mathbf{softmax}(\langle score(P_k,T_i)\rangle_i) βk(k)=softmax(⟨score(Pk,Ti)i)

这里引入softmax进行归一化操作。
k k k个聚合嵌入 E k E_k Ek是图像中具有邻域的所有小块(实例)嵌入的加权平均值,其权重由第 k k k个模板 P k P_k Pk定义。

  • 与所有 C C C个模板相对应的聚合嵌入 E k E_k Ek被分组到与整个图像或袋子和所有模板相对应的总向量 Z Z Z中, Z Z Z是包最后被映射为的向量。
    Z = ∑ k = 1 C γ k E k Z=\sum\limits_{k=1}^{C}\gamma_kE_k Z=k=1CγkEk
    其中:
    γ k = s o f t m a x ( ⟨ s c o r e ( G , E k ) ⟩ k ) \gamma_k=\mathbf{softmax}(\langle score(G,E_k)\rangle_k) γk=softmax(⟨score(G,Ek)k)
    G G G是训练向量的全局模板。模板 G G G和相应的最终注意 final attention决定了哪些聚合嵌入 E k E_k Ek和模板 P k P_k Pk是重要的,在分类之前使用注意机制可以使得模型具有可解释性。

整个神经网络使用随机梯度下降和Adam优化器进行端到端的训练。在每个步骤中,从训练集中选择一个图像(一个包),然后将其划分为多个小图像块。这些小图像块被用作神经网络的输入。神经网络的结果是对一类概率的估计,例如,这意味着图像是否包含恶性细胞。根据得到的估计和整个图像标签确定损失函数。为了更新神经网络权重,使用自动微分算法为每个训练参数确定损失函数的偏导数值。然后以标准方式更新训练参数值。

损失函数:

L = 1 N ∑ k = 1 N B C E ( Y k , f ( X k ) ) + 2 C ( C − 1 ) ∑ i < j ( P i T ⋅ P j ) 2 L=\frac{1}{ N}\sum\limits_{k=1}^{N}BCE(Y_k,f(X_k))+\frac{2}{C(C-1)}\sum\limits_{i<j}(P_{i}^{T}\cdot P_j)^2 L=N1k=1NBCE(Yk,f(Xk))+C(C1)2i<j(PiTPj)2

这里的第一项是标准二分类交叉熵(BCE)损失函数 f ( X k ) f(X_k) f(Xk)是整个神经网络的输出。 Y k Y_k Yk是真实标签

BCE损失函数: l o s s ( x , y ) = − [ y log ⁡ x + ( 1 − y ) log ⁡ ( 1 − x ) ] loss(x,y)=-[y\log x+(1-y)\log(1-x)] loss(x,y)=[ylogx+(1y)log(1x)] 其中 x x x是模型输出, y y y是真实标签。

在第二项中,当 P i P_i Pi P j P_j Pj的差距越大时,损失就越小,因此损失函数中的第二项旨在使得 P i P_i Pi P j P_j Pj尽可能不同。

计算包中实例的重要性

根据上述的步骤,一个包最终被映射成了一个向量的表示: Z = ∑ k = 1 C γ k E k Z=\sum\limits_{k=1}^{C}\gamma_kE_k Z=k=1CγkEk
可以通过如下公式推导出:
在这里插入图片描述
在这里可以把 w i = ∑ k = 1 C γ k β i ( k ) w_i=\sum_{k=1}^C\gamma_k\beta_i^{(k)} wi=k=1Cγkβi(k)看作是实例 I i I_i Ii在它的邻居中的重要性。并且 ∑ i = 1 m w i = 1 \sum_{i=1}^{m}w_i=1 i=1mwi=1,包中所有的实例的重要性之和为 1 1 1

注意!前面已经提到过,第实例 I i I_i Ii的第 j j j个邻居的重要性为 α j ( i ) \alpha_j^{(i)} αj(i)。因此我们可以考虑将实例邻居的重要性和实例本身的重要性联系起来。
如果我们假设向量 T i T_i Ti的所有元素都是等价的,那么它的两部分的重要性度量是相等的。但是,如果指数为 i i i j j j的图像小块是相邻的,则 B i B_i Bi的值可能会与 T j T_j Tj有关联。因此我们定义第 i i i个图像小块的最终重要性 v i v_i vi为:
v i = w i + ∑ j ∈ N i α j ( i ) w j v_i=w_i+\sum\limits_{j\in N_i}\alpha_{j}^{(i)}w_j vi=wi+jNiαj(i)wj

数据集

MNIST数据集

MNIST数据集是一个常用的28x28像素手写数字图像的大型数据集。它有一个60000个实例的训练集和10000个实例的测试集。数字的大小已规范化,并在固定大小的图像中居中。数据集位于http://yann.lecun.com/exdb/mnist/.
在这里插入图片描述

Datasets Musk1, Musk2, Fox, Tiger, Elephant

在我们的实验中,我们使用了与Attention MIL中相同的第一个特征提取器的架构、优化器和超参数。其他参数与MNIST的实验相同,模板的数量为10,所有嵌入Fi的维数都等于128,嵌入Ti、模板Pi、聚合嵌入Ek、向量G和Z的大小都相同128。
整个模型是在不使用邻居的情况下训练的,因为它们不相关,可能会恶化预测结果。
与图像数据集相比,所考虑的表格数据集没有任何邻接关系,即没有定义实例的邻居。MAMIL在所有实验的数据集上都优于另外两个对比算法。

代码分析

https://github.com/andruekonst/mamil
代码使用了Jupyter Notebook 和python代码

#设置随机数种子为123,注意每次重新运行程序时,相同的随机数种子返回的随机数相同
torch.manual_seed(args.seed)#在CPU上设置随机数种子
if args.cuda:
    torch.cuda.manual_seed(args.seed)#在GPU上设置随机数种子

#训练集
train_set = MnistBags(
    target_number=args.target_number,
    neighbour_number=args.neighbour_number,#设置邻居的数量
    mean_bag_length=args.mean_bag_length,#包长度的平均值?
    var_bag_length=args.var_bag_length,
    num_bag=args.num_bags_train,#训练包的个数,人为设置
    seed=args.seed,
    train=True
)

train_weights = get_class_balancing_weights(train_set)#每个包的权重
train_sampler = data_utils.WeightedRandomSampler(train_weights, len(train_weights))#根据训练包的权重进行随机采样,具体做法见https://blog.csdn.net/tyfwin/article/details/108435756
train_shuffle = False  #无需洗牌,因为采样器正在使用

MAMIL1D构造函数

    def __init__(self, n_templates: int = 10, bottleneck_width: int = 4):
        """Initializes 2D MAMIL model.

        Args:
          n_templates: Number of templates.
          bottleneck_width: Bottleneck spatial width (and height), depends on input patches size.
          bottleneck宽度(和高度)取决于输入面片(图像小块)的大小。
        """
        super().__init__()
        self.L = 128
        self.L2 = self.L * 2
        self.D = 128
        self.K = 1
        self.n_templates = n_templates
        self.bottleneck_width = bottleneck_width
        self.channels = 50
        self.embedding_dim = self.channels * self.bottleneck_width ** 2

        #特征提取第一部分
        self.feature_extractor_part1 = nn.Sequential(
            nn.Conv2d(1, 20, kernel_size=5),#卷积层
            nn.ReLU(),#激活函数
            nn.MaxPool2d(2, stride=2),#池化层
            nn.Conv2d(20, self.channels, kernel_size=5),#卷积层
            nn.ReLU(),#激活函数
            nn.MaxPool2d(2, stride=2)#池化层
        )
        #特征提取第二部分
        self.feature_extractor_part2 = nn.Sequential(
            nn.Linear(self.embedding_dim, self.L),#线性层
            nn.ReLU(),#激活函数
        )

        self.neighbours_attention = nn.Sequential(
            nn.Linear(self.L, self.L),#线性层
            nn.Tanh()#激活函数
        )

        #用来生成随机数字的tensor,这些随机数字满足标准正态分布(0~1) 这里生成形状为 n_templates*L2 的随机数矩阵
        self.templates = nn.Parameter(torch.randn(self.n_templates, self.L2,
                                                   requires_grad=True))
        self.proto_attention = nn.Sequential(
            nn.Linear(self.L2, self.L2),
            nn.Tanh()
        )
        self.global_attention = nn.Sequential(
            nn.Linear(self.L2, self.D),
            nn.Tanh(),
            nn.Linear(self.D, self.K)
        )

        self.classifier = nn.Sequential(
            nn.Linear(self.L2*self.K, 1),
            nn.Sigmoid()
        )

神经网络:

  def forward(self, x, return_proto_scores=False):
        """Calculates instance-level probabilities and an attention matrix.计算实例级概率和注意矩阵。

        Args:
          x: Input batch with exactly one bag of batches.
          return_proto_scores: Return templates scores or not.

        Returns:
          Tuple of (instance probabilities, instance labels, attention matrix),(实例概率、实例标签、注意矩阵)的元组,
          or (instance probabilities, instance labels, attention matrix, patch scores, template scores)实例概率、实例标签、注意矩阵、补丁得分、模板得分
          if `return_proto_scores` is True.
            
        """
        x = x.squeeze(0)

        H = self.feature_extractor_part1(x)
        H = H.view(-1, self.embedding_dim)
        H = self.feature_extractor_part2(H)  # NxL

        # Classical Attention-MIL:
        # A = self.attention(H)  # NxK
        # A = torch.transpose(A, 1, 0)  # KxN
        # A = F.softmax(A, dim=1)  # softmax over N
        # M = torch.mm(A, H)  # KxL

        # Neighbourhood attention:
        #   H shape: NxL
        #   Naive loop-based implementation:
        neighbourhood_embeddings = []
        for i in range(H.shape[0]):
            if i == 0:
                cur_neighbours = torch.cat((H[i + 1: i + 2], H[i + 1: i + 2]),
                                           axis=0)
            elif i == H.shape[0] - 1:
                cur_neighbours = torch.cat((H[i - 1: i], H[i - 1: i]),
                                           axis=0)
            else:
                cur_neighbours = torch.cat((H[i - 1: i], H[i + 1: i + 2]),
                                           axis=0)
            cur_instance_embedding = H[i: i + 1]
            # cur_neighbours shape: 2xL
            cur_alphas = torch.mm(
                self.neighbours_attention(cur_neighbours),
                cur_instance_embedding.T
            )  # 2x1
            cur_alphas = torch.transpose(cur_alphas, 1, 0)  # 1x2
            cur_alphas = F.softmax(cur_alphas, dim=1)  # 1x2
            cur_neighbourhood_emb = torch.mm(cur_alphas, cur_neighbours)  # 1xL
            neighbourhood_embeddings.append(cur_neighbourhood_emb)
        neighbourhood_embeddings = torch.cat(neighbourhood_embeddings, dim=0)
        H = torch.cat((H, neighbourhood_embeddings), dim=1)  # Nx2L

        # Multi-template:
        patch_scores = torch.mm(
            self.proto_attention(H), # H,
            self.templates.T
        )  # NxP, P = n_templates
        patch_scores = torch.transpose(patch_scores, 1, 0)  # PxN
        betas = F.softmax(patch_scores, dim=1)  # PxN
        embs = torch.mm(betas, H)  # PxL2

        template_scores = self.global_attention(embs)  # PxK
        template_scores = torch.transpose(template_scores, 1, 0)  # KxP
        gammas = F.softmax(template_scores, dim=1)  # KxP
        M = torch.mm(gammas, embs)  # KxL2
        A = torch.mm(gammas, betas)

        Y_prob = self.classifier(M)
        Y_hat = torch.ge(Y_prob, 0.5).float()

        if return_proto_scores:
            return Y_prob, Y_hat, A, patch_scores, template_scores
        return Y_prob, Y_hat, A

原文中提到了Template这一概念, 在代码中使用如下方式实现,

self.templates = nn.Parameter(torch.randn(self.n_templates, self.L2,
                                                   requires_grad=True))

这里的templates是一个n_templates*L2大小的tensor,这些随机数字满足标准正态分布(0~1)。
templates的提出有如下几个优点:

  • 模板可以用于实例类型的分类。这可以通过确定特定图像的模板γk的权重来实现
  • 另一个优点是可以简单地添加新模板或排除不必要的模板。由于在不改变网络权重的情况下,模板的数量可以不同,因此可以在训练后添加或排除模板。添加新模板对应于出现新类型单元的场景。添加了新的模板可训练向量,所有其他网络参数均未训练且保持不变。可以基于整个数据集上模板权重的统计信息来排除模板。在不降低预测质量的情况下,可以排除总权重最低的模板。
  • 模板可以可视化。对于每个模板Pk,我们可以计算其所有补丁嵌入Ti与所有数据集图像的邻域的相似性,并找到得分最大的补丁(Pk;Ti)。这个补丁可以看作是模板的近似可视化。可以训练从嵌入的面片图像重建面片图像的自动编码器,以便从模板重建的面片图像可以被视为模板的可视化。
  • 2
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值