Co-Attention Mechanism Implementation

An implementation of the image-text co-attention mechanism from "Hierarchical Question-Image Co-Attention for Visual Question Answering".

References:

https://github.com/SkyOL5/VQA-CoAttention/blob/master/coatt/coattention_net.py

https://github.com/Zhangtd/Models-reproducing/blob/master/NIPS2016/selfDef.py

(The co-attention schematic from the paper is omitted here; it illustrates the two mechanisms described below.)

There are two implementation variants: the parallel co-attention mechanism and the alternating co-attention mechanism.
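For reference, the parallel variant computes the following (written with the shapes used by the code below; note that, unlike the paper, this implementation does not apply tanh when forming the affinity matrix C):

    C   = Q W_b V                          (batch_size, seq_len, region_num) word-region affinity scores
    H_v = tanh(W_v V + (W_q Q^T) C)        (batch_size, co_attention_dim, region_num)
    H_q = tanh(W_q Q^T + (W_v V) C^T)      (batch_size, co_attention_dim, seq_len)
    a_v = softmax(w_hv^T H_v)              (batch_size, 1, region_num) attention over image regions
    a_q = softmax(w_hq^T H_q)              (batch_size, 1, seq_len) attention over question words
    v   = a_v V^T,  q = a_q Q              (batch_size, hidden_dim) attended image / question features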

A PyTorch implementation of the parallel co-attention mechanism follows:

from typing import Dict, Optional

import numpy as np
import torch.nn as nn
import torch
import torch.nn.functional as F
from torch import Tensor


def create_src_lengths_mask(
        batch_size: int, src_lengths: Tensor, max_src_len: Optional[int] = None
):
    """
    Generate boolean mask to prevent attention beyond the end of source
    Inputs:
      batch_size : int
      src_lengths : [batch_size] of sentence lengths
      max_src_len: Optionally override max_src_len for the mask
    Outputs:
      [batch_size, max_src_len]
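      Example (illustrative): src_lengths = [2, 4] with max_src_len = 4 yields
      [[1, 1, 0, 0], [1, 1, 1, 1]]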
    """
    if max_src_len is None:
        max_src_len = int(src_lengths.max())
    src_indices = torch.arange(0, max_src_len).unsqueeze(0).type_as(src_lengths)
    src_indices = src_indices.expand(batch_size, max_src_len)
    src_lengths = src_lengths.unsqueeze(dim=1).expand(batch_size, max_src_len)

    # returns [batch_size, max_src_len]
    return (src_indices < src_lengths).int().detach()


def masked_softmax(scores, src_lengths, src_length_masking=True):
    """Apply source length masking then softmax.
    Input and output have shape bsz x src_len"""
    if src_length_masking:
        bsz, max_src_len = scores.size()
        # print('bsz:', bsz)
        # compute masks
        src_mask = create_src_lengths_mask(bsz, src_lengths)
        # Fill pad positions with -inf
        scores = scores.masked_fill(src_mask == 0, -np.inf)

    # Cast to float and then back again to prevent loss explosion under fp16.
    return F.softmax(scores.float(), dim=-1).type_as(scores)
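
# Example (illustrative) usage of masked_softmax:
#   masked_softmax(torch.tensor([[1., 2., 3.]]), torch.LongTensor([2]))
#   re-normalizes only over the first two positions, giving roughly [[0.2689, 0.7311, 0.0000]].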


class ParallelCoAttentionNetwork(nn.Module):

    def __init__(self, hidden_dim, co_attention_dim, src_length_masking=True):
        super(ParallelCoAttentionNetwork, self).__init__()

        self.hidden_dim = hidden_dim
        self.co_attention_dim = co_attention_dim
        self.src_length_masking = src_length_masking

        self.W_b = nn.Parameter(torch.randn(self.hidden_dim, self.hidden_dim))
        self.W_v = nn.Parameter(torch.randn(self.co_attention_dim, self.hidden_dim))
        self.W_q = nn.Parameter(torch.randn(self.co_attention_dim, self.hidden_dim))
        self.w_hv = nn.Parameter(torch.randn(self.co_attention_dim, 1))
        self.w_hq = nn.Parameter(torch.randn(self.co_attention_dim, 1))

    def forward(self, V, Q, Q_lengths):
        """
        :param V: batch_size * hidden_dim * region_num, eg B x 512 x 196
        :param Q: batch_size * seq_len * hidden_dim, eg B x L x 512
        :param Q_lengths: batch_size
        :return:batch_size * 1 * region_num, batch_size * 1 * seq_len,
        batch_size * hidden_dim, batch_size * hidden_dim
        """
        # (batch_size, seq_len, region_num)
        C = torch.matmul(Q, torch.matmul(self.W_b, V))
        # (batch_size, co_attention_dim, region_num)
        H_v = nn.Tanh()(torch.matmul(self.W_v, V) + torch.matmul(torch.matmul(self.W_q, Q.permute(0, 2, 1)), C))
        # (batch_size, co_attention_dim, seq_len)
        H_q = nn.Tanh()(
            torch.matmul(self.W_q, Q.permute(0, 2, 1)) + torch.matmul(torch.matmul(self.W_v, V), C.permute(0, 2, 1)))

        # (batch_size, 1, region_num)
        a_v = F.softmax(torch.matmul(torch.t(self.w_hv), H_v), dim=2)
        # (batch_size, 1, seq_len)
        a_q = F.softmax(torch.matmul(torch.t(self.w_hq), H_q), dim=2)
        # (batch_size, 1, seq_len): re-normalize the (already softmaxed) question attention
        # over the unpadded positions only

        masked_a_q = masked_softmax(
            a_q.squeeze(1), Q_lengths, self.src_length_masking
        ).unsqueeze(1)

        # (batch_size, hidden_dim); squeeze only dim 1 so a batch of size 1 keeps its batch dimension
        v = torch.matmul(a_v, V.permute(0, 2, 1)).squeeze(1)
        # (batch_size, hidden_dim)
        q = torch.matmul(masked_a_q, Q).squeeze(1)

        return a_v, masked_a_q, v, q

Test code:

pcan = ParallelCoAttentionNetwork(6, 5)     # hidden_dim=6, co_attention_dim=5
v = torch.randn((5, 6, 10))                 # image features: batch_size=5, hidden_dim=6, region_num=10
q = torch.randn(5, 8, 6)                    # question features: batch_size=5, seq_len=8, hidden_dim=6
q_lens = torch.LongTensor([3, 4, 5, 8, 2])  # true (unpadded) question lengths
a_v, a_q, v, q = pcan(v, q, q_lens)
print(a_v)
print(a_v.shape)
print(a_q)
print(a_q.shape)
print(v)
print(v.shape)
print(q)
print(q.shape)

Output:

tensor([[[9.2527e-04, 1.1542e-03, 1.1542e-03, 1.1542e-03, 2.0009e-02,
          9.2527e-04, 4.0845e-02, 8.8328e-01, 1.1958e-03, 4.9358e-02]],

        [[4.5501e-01, 8.6522e-02, 8.6522e-02, 1.7235e-05, 3.8831e-03,
          2.5070e-04, 9.0637e-05, 4.0010e-03, 2.0196e-03, 3.6169e-01]],

        [[8.8455e-03, 7.2149e-04, 1.7595e-04, 2.1307e-04, 7.0610e-01,
          1.3427e-01, 4.3360e-04, 4.0731e-02, 4.0731e-02, 6.7774e-02]],

        [[4.0013e-01, 2.3081e-02, 3.8406e-02, 4.3583e-03, 9.9425e-05,
          3.8398e-02, 9.9425e-05, 9.4912e-02, 4.0013e-01, 3.9162e-04]],

        [[3.1121e-02, 8.0567e-05, 4.0445e-01, 1.4391e-03, 8.0567e-05,
          4.0445e-01, 7.6909e-02, 2.4837e-04, 4.3044e-03, 7.6909e-02]]],
       grad_fn=<SoftmaxBackward>)
torch.Size([5, 1, 10])
tensor([[[0.3466, 0.3267, 0.3267, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],

        [[0.2256, 0.3276, 0.2237, 0.2232, 0.0000, 0.0000, 0.0000, 0.0000]],

        [[0.1761, 0.2254, 0.2254, 0.1823, 0.1908, 0.0000, 0.0000, 0.0000]],

        [[0.1292, 0.1411, 0.1411, 0.1100, 0.1292, 0.1100, 0.1101, 0.1292]],

        [[0.5284, 0.4716, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]],
       grad_fn=<UnsqueezeBackward0>)
torch.Size([5, 1, 8])
tensor([[-0.7862,  1.0180,  0.1585,  0.4961, -1.5916, -0.3553],
        [ 0.3624, -0.2036,  0.2993, -0.4440,  0.2494,  1.4896],
        [ 0.1695, -0.2286,  0.4431,  0.6027, -1.6116,  0.0566],
        [ 0.2004,  0.8219, -0.2115, -0.6428,  0.3486,  1.3802],
        [ 1.4024, -0.1860,  0.1685,  0.2352, -0.4956,  1.0010]],
       grad_fn=<SqueezeBackward0>)
torch.Size([5, 6])
tensor([[ 0.3757,  0.1662,  0.2181,  0.0787,  0.0110, -0.5938],
        [-0.6106,  0.4000,  0.6068, -0.4054,  0.0193, -0.1147],
        [ 0.3877, -0.1800,  1.2430, -0.4881, -0.3598, -0.3592],
        [-0.3799, -0.3262,  0.0745, -0.2856,  0.0221, -0.1749],
        [ 0.1159, -0.4949, -0.5576, -0.6870, -1.2895,  0.0421]],
       grad_fn=<SqueezeBackward0>)
torch.Size([5, 6])

The co-attention mechanism is used in multimodal tasks: it attends over two different input sequences at the same time in order to better model the relationship between them. In the parallel co-attention mechanism, attention over the two inputs is computed simultaneously: via the word-region affinity matrix, an attention map over image regions is produced for the question and an attention map over question words is produced for the image, and this step can be stacked for further refinement. The alternating co-attention mechanism instead alternates between the two inputs: it computes the attention of one sequence guided by the other, then uses the attended result to guide attention in the opposite direction, which helps capture the correlation between the two sequences. In both cases, co-attention helps a model relate the two input modalities.
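
As a complement to the parallel implementation above, below is a minimal sketch of the alternating variant (not from the original post; the class name AlternatingCoAttention and its layers W_x, W_g, w_hx are illustrative). It implements the paper's attention operator x_hat = A(X; g) and applies it three times: summarize the question with no guidance, attend to image regions guided by that summary, then attend to question words guided by the attended image feature. Note that V here is laid out as batch_size * region_num * hidden_dim, i.e. transposed relative to the V expected by ParallelCoAttentionNetwork.

import torch
import torch.nn as nn
import torch.nn.functional as F


class AlternatingCoAttention(nn.Module):
    def __init__(self, hidden_dim, att_dim):
        super(AlternatingCoAttention, self).__init__()
        self.W_x = nn.Linear(hidden_dim, att_dim, bias=False)   # projects the attended sequence
        self.W_g = nn.Linear(hidden_dim, att_dim, bias=False)   # projects the guidance vector
        self.w_hx = nn.Linear(att_dim, 1, bias=False)           # scores each position

    def attend(self, X, g=None):
        """x_hat = A(X; g): X is (B, N, hidden_dim), g is (B, hidden_dim) or None (treated as 0)."""
        H = self.W_x(X) if g is None else self.W_x(X) + self.W_g(g).unsqueeze(1)
        H = torch.tanh(H)                                        # (B, N, att_dim)
        a = F.softmax(self.w_hx(H), dim=1)                       # (B, N, 1) attention weights
        return (a * X).sum(dim=1)                                # (B, hidden_dim) attended feature

    def forward(self, V, Q):
        # V: (B, region_num, hidden_dim), Q: (B, seq_len, hidden_dim)
        s_hat = self.attend(Q)           # step 1: question summary, guidance = 0
        v_hat = self.attend(V, s_hat)    # step 2: image attention guided by the summary
        q_hat = self.attend(Q, v_hat)    # step 3: question attention guided by the image feature
        return v_hat, q_hat

For example, AlternatingCoAttention(6, 5) applied to torch.randn(5, 10, 6) and torch.randn(5, 8, 6) returns two tensors of shape (5, 6). Length masking of the question (as done with masked_softmax above) could be added to the question attention steps in the same way.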