【Transformer】single self-attention的Pytorch实现

在这里插入图片描述

import torch.nn as nn
import torch
import matplotlib.pyplot as plt


class Self_Attention(nn.Module):
    def __init__(self, dim, dk, dv):
        super(Self_Attention, self).__init__()
        self.scale = dk ** -0.5
        self.q = nn.Linear(dim, dk)
        self.k = nn.Linear(dim, dk)
        self.v = nn.Linear(dim, dv)


    def forward(self, x):
        q = self.q(x)
        k = self.k(x)
        v = self.v(x)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)

        x = attn @ v
        return x


att = Self_Attention(dim=2, dk=2, dv=3)
x = torch.rand((1, 4, 2))
output = att(x)



代码中首先创建了一个self-attention的类,然后在随机出输入x,(1,4,2) 1是batch_size 4是4个token,2是每个token的长度

然后把x传入对象,在self-attention的初始化函数中定义了 1 d k \frac{1}{\sqrt{d_k}} dk 1,并且从输入中提出q,k,v,其中q,k的维度是一定要保持一致的,

这里面 X的输入是(batchsize,num,dim_in)num是一维序列中token的个数,这里a1到a4就4个,dim_in是每个token的特征维数,这里每一个a都是1*2的向量,特征维度为2,dim_in就为2

对于Q、K、V的维度,W1 W2 W3分别是(dim_in,dq) (dim_in,dk) (dim_in,dv) 只不过dq肯定等于dk
这样X(batchsize,num,dim_in)才能分别与W1 、W2、W3相乘得到Q、K、V
Q(batchsize,num,dq) K(batchsize,num,dk) V(batchsize,num,dv)
这样 Q K T QK^T QKT (batchsize,num,dq)*(batchsize,dk,num) = (batchsize,num,num)

K T K^T KT是依靠k.transpose(-2,-1)实现的

self.scale = dk ** -0.5这一步是计算 1 d k \frac{1}{\sqrt{d_k}} dk 1

self.q = nn.Linear(dim, dk)
self.k = nn.Linear(dim, dk)
self.v = nn.Linear(dim, dv)

torch.nn.Linear的用法可以参考官方的文档

torch.nn.Linear

在这里插入图片描述
这里的三个操作实际上就构建了W1,W2,W3矩阵,这三个矩阵分别与X相乘就得到了Q K V

def forward(self, x):
    q = self.q(x)
    k = self.k(x)
    v = self.v(x)

这里就完成了Q K V的计算

    attn = (q @ k.transpose(-2, -1)) * self.scale
    attn = attn.softmax(dim=-1)

    x = attn @ v

这三步就分别完成了相似度分数的计算、相似度分数的归一化和最终计算

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

量子-Alex

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值