Self-attention mechanism code (PyTorch version):
import torch
from torch import nn


class SelfAttention(nn.Module):
    """Self-attention module (spatial attention over the H*W positions)."""

    def __init__(self, in_dim):
        super(SelfAttention, self).__init__()
        self.channel_in = in_dim
        # 1x1 convolutions project the input into query, key and value maps
        self.query = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        self.key = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        self.value = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        # learnable residual weight, initialized to zero so the module starts as an identity
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward_single(self, x):
        """
        inputs:
            x: input feature maps (B x C x H x W)
        returns:
            out: attention value + input feature
            (the internal attention map has shape B x (H*W) x (H*W))
        """
        m_batchsize, C, height, width = x.size()
        # query: B x (H*W) x C, key: B x C x (H*W)
        proj_query = self.query(x).reshape(
            m_batchsize, -1, width * height).permute(0, 2, 1)
        proj_key = self.key(x).reshape(m_batchsize, -1, width * height)
        # energy: B x (H*W) x (H*W); softmax over the last dimension
        energy = proj_query.bmm(proj_key)
        attention = self.softmax(energy)
        # weight the value map by the attention and restore the spatial layout
        proj_value = self.value(x).reshape(m_batchsize, -1, width * height)
        out = proj_value.bmm(attention.permute(0, 2, 1))
        out = out.reshape(m_batchsize, C, height, width)
        # residual connection scaled by the learnable gamma
        out = self.gamma * out + x
        return out

    def forward(self, x):
        # 5D input (B x T x C x H x W): fold the time axis into the batch,
        # apply spatial self-attention frame by frame, then restore the shape
        if x.ndim == 5:
            B, T = x.shape[:2]
            x = self.forward_single(x.flatten(0, 1)).unflatten(0, (B, T))
            return x
        else:
            return self.forward_single(x)
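A minimal usage sketch follows, assuming the SelfAttention class above is in scope; the layer width, spatial resolution, and number of frames are arbitrary illustrative values, not values from the original text. It checks that the module preserves the input shape for both a 4D image batch and a 5D video-like batch, which forward() handles by folding the time axis into the batch.

import torch

# Smoke test for the SelfAttention module above.
# The sizes (64 channels, 32x32 resolution, 8 frames) are arbitrary examples.
attn = SelfAttention(in_dim=64)

# 4D image-like input: B x C x H x W
x_img = torch.randn(2, 64, 32, 32)
print(attn(x_img).shape)   # torch.Size([2, 64, 32, 32])

# 5D video-like input: B x T x C x H x W, processed frame by frame
x_vid = torch.randn(2, 8, 64, 32, 32)
print(attn(x_vid).shape)   # torch.Size([2, 8, 64, 32, 32])

Because gamma is initialized to zero, a freshly constructed module behaves as an identity mapping at first; the attention branch only contributes once gamma is learned away from zero during training.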