python中关于@的用法

这篇博客介绍了Python中装饰器的一种不常见应用,即用于矩阵的乘法运算。在Transformer的实现中,作者展示了如何通过@符号进行矩阵乘法,并解释了相关代码的工作原理。此外,还提供了示例代码来演示@符号与torch.mm和torch.matmul等函数在矩阵乘法上的等效性。
摘要由CSDN通过智能技术生成

第一种,表示的是修饰符,这篇Python中的装饰器及@用法详解博文里面有详细的讲解。
在这里记录一下另外种不常见的用法,就是用作矩阵的乘法。
在transformer里面有这样的一段代码,虽然我猜测到是矩阵的乘法计算,为了严谨另外也算是学习一下新的知识,查阅了一下资料。

class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, drop=0., attn_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(drop)

    def forward(self, x):
        # print(x.shape)  # 2 26 128
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)   # make torchscript happy (cannot use tensor as tuple)
        #   # 2 4 26 32 32是qkv的dim
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        # print(attn.shape)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        # print(x.shape)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
import torch
a = torch.tensor([[1,1],[2,2]])
b = torch.tensor([[2,2],[3,3]])
c = a@b
c
# //***********//
tensor([[ 5,  5],
        [10, 10]])

可以看到执行的就是矩阵的乘法,与torch.mm和torch.matmul起的作用是一样的。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值