Bottleneck Transformers
If you are not yet familiar with transformers, the videos below are a good starting point. This article annotates the tensor shapes along the data flow, which makes Bottleneck Transformers easier to follow.
唐宇迪 transformer video walkthrough
Bottleneck Transformers paper video walkthrough
Bottleneck Transformers replace the 3x3 convolutions inside the three residual blocks of ResNet-50's c5 stage with multi-head self-attention (MHSA), as sketched below.
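A minimal sketch of that replacement. This is illustrative only, not the paper's reference code: BotBlock, its channel sizes, and heads=4 are assumptions here, and MHSA refers to the class defined in the Code section below.

import torch.nn as nn
import torch.nn.functional as F

# Hypothetical c5 bottleneck where the 3x3 conv is swapped for MHSA
# (MHSA is the multi-head class defined in the Code section below).
class BotBlock(nn.Module):
    def __init__(self, in_c=2048, mid_c=512, fmap=14, heads=4):
        super(BotBlock, self).__init__()
        self.reduce = nn.Conv2d(in_c, mid_c, kernel_size=1)            # 1x1: 2048 -> 512
        self.mhsa = MHSA(mid_c, width=fmap, height=fmap, heads=heads)  # replaces the 3x3 conv
        self.expand = nn.Conv2d(mid_c, in_c, kernel_size=1)            # 1x1: 512 -> 2048
    def forward(self, x):
        out = F.relu(self.reduce(x))
        out = F.relu(self.mhsa(out))
        out = self.expand(out)
        return F.relu(out + x)  # residual connection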
Bottleneck Transformers vs. self-attention in NLP
Single head
Data-flow description
The operation in the bottom-right of the figure is like a non-local operation: it captures the relationships between pixels.
The figure below walks through the non-local relation matrix for an input of (b, c, h, w) = (1, 2048, 16, 8). The differences from non-local are the added position encoding and the fact that the operation can be extended to multiple heads.
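A minimal shape walk-through of that figure. For simplicity, q and k here are just reshaped views of the input rather than the 1x1-conv projections the real module uses:

import torch

x = torch.randn(1, 2048, 16, 8)              # (b, c, h, w) as in the figure
q = x.view(1, 2048, 16 * 8)                  # (1, 2048, 128): 128 spatial positions
k = x.view(1, 2048, 16 * 8)
relation = torch.bmm(q.permute(0, 2, 1), k)  # (1, 128, 128): pixel-to-pixel relations
print(relation.shape)                        # torch.Size([1, 128, 128])
# BoT then adds a learned position term of the same shape before the softmax,
# and the channel axis can be split into several heads.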
Multi-head
Here we show heads=2.
- Multi-head simply splits the c=10 above into 2x5, i.e. (64,10,196) -> (64,2,5,196); the tensor is reshaped so that the subsequent matrix multiplications produce the multi-head effect (see the shape sketch after this list).
- The single-head content-content tensor of shape (64,196,196) becomes a 2-head tensor via the reshaped multiplication q(64,2,196,5) * k(64,2,5,196) = (64,2,196,196).
- The two models have exactly the same number of parameters.
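A self-contained sketch of the split referenced above, with the same numbers (c=10, two heads of 5 channels each). The split is a pure reshape and adds no parameters, which is why the two model sizes match:

import torch

x = torch.randn(64, 10, 196)       # single-head layout: (batch, c, h*w)
multi = x.view(64, 2, 5, 196)      # split c=10 into heads=2 x 5 channels
q = multi.permute(0, 1, 3, 2)      # (64, 2, 196, 5)
attn = torch.matmul(q, multi)      # (64,2,196,5) @ (64,2,5,196) -> (64,2,196,196)
print(attn.shape)                  # torch.Size([64, 2, 196, 196])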
Code
import torch
import torch.nn as nn
import torch.nn.functional as F
# Single head
class S_MHSA(nn.Module):
    def __init__(self, n_dims, width=14, height=14):
        super(S_MHSA, self).__init__()
        # 1x1 convolutions produce the query/key/value projections
        self.query = nn.Conv2d(n_dims, n_dims, kernel_size=1)
        self.key = nn.Conv2d(n_dims, n_dims, kernel_size=1)
        self.value = nn.Conv2d(n_dims, n_dims, kernel_size=1)
        # learned relative position encodings, one per spatial axis
        self.rel_h = nn.Parameter(torch.randn([1, n_dims, 1, height]), requires_grad=True)
        self.rel_w = nn.Parameter(torch.randn([1, n_dims, width, 1]), requires_grad=True)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        n_batch, C, width, height = x.size()
        q = self.query(x).view(n_batch, C, -1)  # self.query(x)=torch.Size([64, 10, 14, 14]) -> torch.Size([64, 10, 196])
        k = self.key(x).view(n_batch, C, -1)
        v = self.value(x).view(n_batch, C, -1)
        content_content = torch.bmm(q.permute(0, 2, 1), k)  # q.permute(0, 2, 1)=(64,196,10), k=(64,10,196) -> torch.Size([64, 196, 196])
        # self.rel_h: (1,10,1,14); self.rel_w: (1,10,14,1); broadcast sum: (1,10,14,14)
        content_position = (self.rel_h + self.rel_w).view(1, C, -1).permute(0, 2, 1)  # torch.Size([1, 196, 10])
        content_position = torch.matmul(content_position, q)  # broadcasts over the batch -> torch.Size([64, 196, 196])
        energy = content_content + content_position
        attention = self.softmax(energy)  # torch.Size([64, 196, 196])
        out = torch.bmm(v, attention.permute(0, 2, 1))  # torch.Size([64, 10, 196])
        out = out.view(n_batch, C, width, height)  # torch.Size([64, 10, 14, 14])
        return out
# Multi-head
class MHSA(nn.Module):
    def __init__(self, n_dims, width=14, height=14, heads=2):
        super(MHSA, self).__init__()
        self.heads = heads
        self.query = nn.Conv2d(n_dims, n_dims, kernel_size=1)
        self.key = nn.Conv2d(n_dims, n_dims, kernel_size=1)
        self.value = nn.Conv2d(n_dims, n_dims, kernel_size=1)
        # position encodings are laid out per head: (1, heads, n_dims // heads, ...)
        self.rel_h = nn.Parameter(torch.randn([1, heads, n_dims // heads, 1, height]), requires_grad=True)  # (1,2,5,1,14)
        self.rel_w = nn.Parameter(torch.randn([1, heads, n_dims // heads, width, 1]), requires_grad=True)  # (1,2,5,14,1)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        n_batch, C, width, height = x.size()
        q = self.query(x).view(n_batch, self.heads, C // self.heads, -1)  # self.query(x)=torch.Size([64, 10, 14, 14]) -> torch.Size([64, 2, 5, 196])
        k = self.key(x).view(n_batch, self.heads, C // self.heads, -1)
        v = self.value(x).view(n_batch, self.heads, C // self.heads, -1)
        content_content = torch.matmul(q.permute(0, 1, 3, 2), k)  # (64,2,196,5) @ (64,2,5,196) -> (64,2,196,196)
        # self.rel_h + self.rel_w broadcasts to (1,2,5,14,14)
        content_position = (self.rel_h + self.rel_w).view(1, self.heads, C // self.heads, -1).permute(0, 1, 3, 2)  # (1,2,5,196) -> (1,2,196,5)
        content_position = torch.matmul(content_position, q)  # torch.Size([64, 2, 196, 196])
        energy = content_content + content_position  # torch.Size([64, 2, 196, 196])
        attention = self.softmax(energy)
        out = torch.matmul(v, attention.permute(0, 1, 3, 2))  # torch.Size([64, 2, 5, 196])
        out = out.view(n_batch, C, width, height)  # (64,10,14,14)
        return out
if __name__ == '__main__':
    # x = torch.randn(64, 10, 14, 14)  # simulates the first bottleneck of ResNet c5;
    # # channel changes: 1024->512 after the first conv, 512->512 in the second, 512->2048 in the third
    # model2 = S_MHSA(n_dims=10)
    # print("Model size: {:.5f}M".format(sum(p.numel() for p in model2.parameters()) / 1000000.0))  # Model size: 0.00061M
    # out = model2(x)
    x1 = torch.randn(64, 10, 14, 14)  # simulates the first bottleneck of ResNet c5; channel changes: 1024->512, 512->512, 512->2048
    n_dims = 10
    model1 = MHSA(n_dims=n_dims)
    print("Model size: {:.5f}M".format(sum(p.numel() for p in model1.parameters()) / 1000000.0))  # Model size: 0.00061M
    out = model1(x1)
    # print(out.shape)
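As a quick check of the earlier claim that the two models are the same size, parameters can be counted side by side (this snippet assumes the S_MHSA and MHSA classes above are in scope):

# Three 1x1 convs contribute 3 * (10*10 + 10) = 330 parameters in both models;
# rel_h/rel_w contribute 10*14 + 10*14 = 280 whether or not they are split into heads.
s = sum(p.numel() for p in S_MHSA(n_dims=10).parameters())
m = sum(p.numel() for p in MHSA(n_dims=10).parameters())
print(s, m, s == m)  # 610 610 True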