The first use of `@` is as a decorator marker; that usage is covered in detail in the post 《Python中的装饰器及@用法详解》. Here I want to record another, less common use: matrix multiplication. The Transformer code below uses `@` in this role; although I guessed it meant matrix multiplication, I looked it up anyway, both to be rigorous and to learn something new.
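For completeness, here is a minimal sketch of the decorator usage (the names `shout` and `greet` are made up for illustration):

```python
import functools

def shout(func):
    """A decorator that upper-cases the wrapped function's return value."""
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs).upper()
    return wrapper

@shout  # equivalent to: greet = shout(greet)
def greet(name):
    return f"hello, {name}"

print(greet("world"))  # HELLO, WORLD
```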
import torch
import torch.nn as nn

class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, drop=0., attn_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5  # 1 / sqrt(d_k), the usual attention scaling
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(drop)

    def forward(self, x):
        B, N, C = x.shape  # e.g. (2, 26, 128)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)
        # q, k, v each have shape (B, num_heads, N, head_dim), e.g. (2, 4, 26, 32)
        attn = (q @ k.transpose(-2, -1)) * self.scale  # @ here is a batched matrix multiply
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)  # merge heads back to (B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
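The two `@` operations in `forward` act on 4-D tensors. In that case `@` behaves like `torch.matmul`: the last two dimensions are matrix-multiplied and the leading dimensions are treated as batch dimensions. A small shape check (the sizes 2, 4, 26, 32 mirror the shapes noted in the comments above):

```python
import torch

q = torch.randn(2, 4, 26, 32)  # (B, num_heads, N, head_dim)
k = torch.randn(2, 4, 26, 32)
v = torch.randn(2, 4, 26, 32)

attn = q @ k.transpose(-2, -1)  # batched matmul over the last two dims
print(attn.shape)  # torch.Size([2, 4, 26, 26])

out = attn.softmax(dim=-1) @ v
print(out.shape)   # torch.Size([2, 4, 26, 32])
```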
import torch

a = torch.tensor([[1, 1], [2, 2]])
b = torch.tensor([[2, 2], [3, 3]])
c = a @ b
print(c)
# Output:
# tensor([[ 5,  5],
#         [10, 10]])
As you can see, `@` performs matrix multiplication, playing the same role as torch.mm and torch.matmul. To be precise, `a @ b` is equivalent to `torch.matmul(a, b)`: both support batched, broadcasting multiplication, whereas `torch.mm` only accepts 2-D tensors.
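The `@` operator dispatches to the `__matmul__` special method (introduced in PEP 465), which is why tensors, NumPy arrays, and even your own classes can support it. A quick check of the equivalence on the 2-D example above:

```python
import torch

a = torch.tensor([[1, 1], [2, 2]])
b = torch.tensor([[2, 2], [3, 3]])

# All three compute the same 2-D matrix product
assert torch.equal(a @ b, torch.mm(a, b))
assert torch.equal(a @ b, torch.matmul(a, b))
assert torch.equal(a @ b, a.__matmul__(b))  # @ is syntax for __matmul__
print(a @ b)
# tensor([[ 5,  5],
#         [10, 10]])
```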