Bottleneck Transformers(单头,多头,关系矩阵理解)以及讲解视频,torch代码

Bottleneck transformers

对transformers不了解的可以看下以下视频,本文通过对数据流维度的标注,可以更容易理解Bottleneck transformers。

唐宇迪-transformer视频讲解
Bottleneck transformers论文视频讲解

Bottleneck transformers就是将restnet50的c5层的三个残差块中的3x3卷积操作进行替换
在这里插入图片描述

在这里插入图片描述

Bottleneck Transformers与自然语言处理中的自注意力对比

单头
在这里插入图片描述
数据流描述
对于右下角的操作,像non-local的操作,获取像素间的关系。
下图演示输入为b,c,h,w-(1,2048,16,8)的non-local关系矩阵,与nonlocal的区别是加了一个位置标签,以及可以扩展成多头
在这里插入图片描述

多头
这里展示head=2

  1. 多头就是将上述的c=10变成2x5,(64,10,196)->(64,2,5,196),将原来的数据进行拆分,为了后面数据的相乘产生多头的效果。
  2. 原本content-content单头数据(64,196,196)数据,通过各种矩阵相乘操作变成q(64,2,196,5) * k(64,2,5,196)=(64,2,196,196) 2头数据
  3. 两模型的模型大小一样
    在这里插入图片描述

代码

import torch
import torch.nn as nn
import torch.nn.functional as F

#单头
class S_MHSA(nn.Module):
    def __init__(self, n_dims, width=14, height=14):
        super(S_MHSA, self).__init__()

        self.query = nn.Conv2d(n_dims, n_dims, kernel_size=1)
        self.key = nn.Conv2d(n_dims, n_dims, kernel_size=1)
        self.value = nn.Conv2d(n_dims, n_dims, kernel_size=1)

        self.rel_h = nn.Parameter(torch.randn([1, n_dims, 1, height]), requires_grad=True)
        self.rel_w = nn.Parameter(torch.randn([1, n_dims, width, 1]), requires_grad=True)

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        n_batch, C, width, height = x.size()
        q = self.query(x).view(n_batch, C, -1)#self.query(x)=torch.Size([64, 10, 14, 14]) #torch.Size([64, 10, 196])
        k = self.key(x).view(n_batch, C, -1)
        v = self.value(x).view(n_batch, C, -1)

        content_content = torch.bmm(q.permute(0, 2, 1), k) #q.permute(0, 2, 1)=(64,196,10) k=(64,10,196) (torch.Size([64, 196, 196]))
        # print(self.rel_h.shape)#(1,10,1,14)
        # print(self.rel_w.shape)  # (1,10,14,1)
        # print((self.rel_h + self.rel_w).shape)#(1,10,14,14)
        content_position = (self.rel_h + self.rel_w).view(1, C, -1).permute(0, 2, 1)#torch.Size([1, 196, 10])
        content_position = torch.matmul(content_position, q)#torch.Size([64, 196, 196])

        energy = content_content + content_position
        attention = self.softmax(energy)#torch.Size([64, 196, 196])
        print(attention.shape)
        print(v.shape)
        out = torch.bmm(v, attention.permute(0, 2, 1))#torch.Size([64, 10, 196])
        out = out.view(n_batch, C, width, height)#torch.Size([64, 10, 14, 14])
        return out

#多头
class MHSA(nn.Module):
    def __init__(self, n_dims, width=14, height=14, heads=2):
        super(MHSA, self).__init__()
        self.heads = heads

        self.query = nn.Conv2d(n_dims, n_dims, kernel_size=1)
        self.key = nn.Conv2d(n_dims, n_dims, kernel_size=1)
        self.value = nn.Conv2d(n_dims, n_dims, kernel_size=1)

        self.rel_h = nn.Parameter(torch.randn([1, heads, n_dims // heads, 1, height]), requires_grad=True)#(1,2,5,1,14)
        self.rel_w = nn.Parameter(torch.randn([1, heads, n_dims // heads, width, 1]), requires_grad=True)#(1,2,5,14,1)

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        n_batch, C, width, height = x.size()
        q = self.query(x).view(n_batch, self.heads, C // self.heads, -1)#self.query(x)=torch.Size([64, 10, 14, 14]) #torch.Size([64, 2, 5,196])
        k = self.key(x).view(n_batch, self.heads, C // self.heads, -1)
        v = self.value(x).view(n_batch, self.heads, C // self.heads, -1)

        content_content = torch.matmul(q.permute(0, 1, 3, 2), k)#q.permute(0, 1, 3, 2)=(64,2,196,5) k=(64,2,5,196) content_content(64,2,196,196)
        # print(content_content.shape)
        # print(self.rel_h.shape)#(1,2,5,1,14)
        # print(self.rel_w.shape)  # (1,2,5,14,1)
        # print((self.rel_h + self.rel_w).shape)#(1,2,5,14,14)
        content_position = (self.rel_h + self.rel_w).view(1, self.heads, C // self.heads, -1).permute(0, 1, 3, 2)#(1,2,5,196)->(1,2,196,5)
        content_position = torch.matmul(content_position, q)#torch.Size([64, 2, 196, 196])
        energy = content_content + content_position#torch.Size([64, 2, 196, 196])
        attention = self.softmax(energy)

        out = torch.matmul(v, attention.permute(0, 1, 3, 2))#torch.Size([64, 2, 5, 196])
        out = out.view(n_batch, C, width, height)#(64,10,14,14)
        return out

if __name__ == '__main__':
    # x=torch.Tensor(64,10,14,14)#模拟resnet c5的第一个bottleneck,c的变化:第一个conv后的特征1024->512,第二个512->512,第三个512->2048
    # n_dims=10
    # model2=S_MHSA(n_dims=n_dims)
    # print("Model size: {:.5f}M".format(sum(p.numel() for p in model2.parameters()) / 1000000.0))#Model size: 0.00061M
    # out=model2(x)


    x1 = torch.Tensor(64, 10, 14, 14)  # 模拟resnet c5的第一个bottleneck,c的变化:第一个conv后的特征1024->512,第二个512->512,第三个512->2048
    n_dims = 10
    model1 = MHSA(n_dims=n_dims)
    print("Model size: {:.5f}M".format(sum(p.numel() for p in model1.parameters()) / 1000000.0))#Model size: 0.00061M
    out = model1(x1)
    # print(out.shape)

在这里插入图片描述

  • 11
    点赞
  • 27
    收藏
    觉得还不错? 一键收藏
  • 11
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 11
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值