SelfAttention教程

一个序列整3个向量出来 query key value
先随机算出来一个注意力权重和value,这里面的逻辑就是先论文的高手整出来的,不用理解
然后计算出output,然后和target对比 迭代更新,更新参数
这样能够对输入序列的不同位置关注度不同

note PyTorch 中,nn.Linear 层的权重和偏置默认是可训练的。
所以以下是可以被训练的参数(linear层就是简单的y = x * W^T + b)
self.W_q 中的权重和偏置。
self.W_k 中的权重和偏置。
self.W_v 中的权重和偏置。

本质就是计算一个token和其他所有token的关系

class SelfAttention(nn.Module):
    def __init__(self, input_dim):
        super(SelfAttention, self).__init__()
        self.input_dim = input_dim
        self.W_q = nn.Linear(input_dim, input_dim)
        self.W_k = nn.Linear(input_dim, input_dim)
        self.W_v = nn.Linear(input_dim, input_dim)

    def forward(self, x):
        Q = self.W_q(x)
        K = self.W_k(x)
        V = self.W_v(x)
        #自注意力机制可以捕获位置信息的原因是因为在计算注意力分数时,它考虑了查询和键之间的相对位置关系
        attention_scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.input_dim ** 0.5)

        attention_weights = F.softmax(attention_scores, dim=-1)
        output = torch.matmul(attention_weights, V)
        return output


# 输入序列的维度为 4
input_dim = 4
# 创建一个输入序列,假设有三个位置(Token)
input_sequence = torch.rand((3, input_dim), requires_grad=True)

# 创建 Self-Attention 模块
self_attention = SelfAttention(input_dim)

# 定义损失函数和优化器
criterion = nn.MSELoss()
optimizer = optim.SGD(self_attention.parameters(), lr=0.01)

# 模型训练
for epoch in range(100):
    # 前向传播
    output_sequence = self_attention(input_sequence)

    # 构造一个目标张量,这里简化为全零
    target = torch.zeros_like(output_sequence)

    # 计算损失
    loss = criterion(output_sequence, target)

    # 反向传播
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # 输出每 10 次迭代的损失
    if (epoch + 1) % 10 == 0:
        print(f'Epoch {epoch + 1}/{100}, Loss: {loss.item()}')

# 训练完成后,查看最终的输出序列
print("\nFinal Output Sequence after Self-Attention:")
print(output_sequence.detach().numpy())

  • 6
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值