使用chatGPT编写的self-attention模块

import torch

# 计算两个向量的注意力分数
def attention_score(query, key):
  return torch.matmul(query, key.transpose(-2, -1))

# 计算注意力权重
def attention_weights(query, key, values):
  score = attention_score(query, key)
  weights = torch.softmax(score, dim=-1)
  return torch.matmul(weights, values)

# 构建自注意力层
class SelfAttention(torch.nn.Module):
  def __init__(self, input_size, hidden_size):
    super(SelfAttention, self).__init__()
    self.query = torch.nn.Linear(input_size, hidden_size)
    self.key = torch.nn.Linear(input_size, hidden_size)
    self.value = torch.nn.Linear(input_size, hidden_size)
    self.output = torch.nn.Linear(hidden_size, input_size)
  
  def forward(self, inputs):
    query = self.query(inputs)
    key = self.key(inputs)
    value = self.value(inputs)
    attention = attention_weights(query, key, value)
    return self.output(attention)

测试用例

# 定义测试用的输入
inputs = torch.randn(4, 5, 8)

# 创建一个自注意力层
attention = SelfAttention(8, 8)

# 计算自注意力层的输出
outputs = attention(inputs)

# 打印输出的形状
print(outputs.shape)  # 输出: (4, 5, 8)

在上面的示例代码中,输入张量的形状为 (4, 5, 8),其中 4、5、8 分别表示的含义如下:

4:批大小,即一次计算时处理的样本数量。在这个例子中,我们一次处理了 4 个样本。
5:序列长度,即每个样本中包含的元素数量。在这个例子中,每个样本由 5 个元素组成。
8:元素的维度,即每个元素的特征数量。在这个例子中,每个元素都由 8 个特征构成。
在实际应用中,这些维度的含义可能会有所不同。例如,在自然语言处理中,序列长度通常表示句子中的单词数量,元素维度表示单词的维度(例如词嵌入)。在图像处理中,序列长度可能表示图像中像素的数量,元素维度表示像素的通道数。总之,各个维度的含义取决于模型的具体应用场景。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值