Self-Attention Mechanism

Self-attention mechanism code (PyTorch version):

```python
import torch
from torch import nn


class SelfAttention(nn.Module):
    """Self-attention module for 2D feature maps."""

    def __init__(self, in_dim):
        super(SelfAttention, self).__init__()
        self.channel_in = in_dim

        # 1x1 convolutions produce the query, key and value projections.
        self.query = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        self.key = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        self.value = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)

        # Learnable residual weight, initialized to zero so the module
        # starts out as an identity mapping.
        self.gamma = nn.Parameter(torch.zeros(1))

        self.softmax = nn.Softmax(dim=-1)

    def forward_single(self, x):
        """
            inputs :
                x : input feature maps (B x C x H x W)
            returns :
                out : attention-weighted features + input features (B x C x H x W)
        """
        m_batchsize, C, height, width = x.size()
        # Query: (B, HW, C); Key: (B, C, HW)
        proj_query = self.query(x).reshape(
            m_batchsize, -1, width * height).permute(0, 2, 1)
        proj_key = self.key(x).reshape(m_batchsize, -1, width * height)
        # Attention map over all spatial positions: (B, HW, HW)
        energy = proj_query.bmm(proj_key)
        attention = self.softmax(energy)
        # Value: (B, C, HW)
        proj_value = self.value(x).reshape(m_batchsize, -1, width * height)

        # Aggregate values with the attention weights, then restore the spatial layout.
        out = proj_value.bmm(attention.permute(0, 2, 1))
        out = out.reshape(m_batchsize, C, height, width)

        # Residual connection scaled by the learnable gamma.
        out = self.gamma * out + x
        return out

    def forward(self, x):
        # A 5D input (B x T x C x H x W) is treated as a batch of frames:
        # fold time into the batch dimension, apply attention per frame, then unfold.
        if x.ndim == 5:
            B, T = x.shape[:2]
            x = self.forward_single(x.flatten(0, 1)).unflatten(0, (B, T))
            return x
        else:
            return self.forward_single(x)
```
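
A minimal usage sketch of the module above; the batch size, channel count, and spatial/temporal sizes here are arbitrary illustration values, not from the original post:

```python
import torch

# Hypothetical shapes: batch of 2, 64 channels, 16x16 feature maps.
attn = SelfAttention(in_dim=64)
feat = torch.randn(2, 64, 16, 16)
out = attn(feat)
print(out.shape)  # torch.Size([2, 64, 16, 16])

# A 5D, video-style input (B x T x C x H x W) is also accepted;
# attention is applied independently to each frame.
clip = torch.randn(2, 4, 64, 16, 16)
out = attn(clip)
print(out.shape)  # torch.Size([2, 4, 64, 16, 16])
```

Because `gamma` is initialized to zero, the module initially passes its input through unchanged and learns how much attention-weighted content to mix in during training.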