PyTorch and TensorFlow Implementations of Fully-Connected Attention and Multi-Head Attention [Machine Learning] Model Building

Fully-connected attention and multi-head attention are both forms of the attention mechanism, used to introduce attention into a neural network for information interaction and weight computation. The mechanism computes a weighted combination over every position of the input to obtain a representation for each position. In fully-connected attention, the attention weights are obtained from the dot product of the query and the key, and these weights are then multiplied with the value to produce the final output. Below I explain fully-connected attention and multi-head attention separately and provide example code implemented in PyTorch (TensorFlow versions follow later).
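For reference, the scaled dot-product attention computation described above (including the 1/√d_k scaling used in the code below) is commonly written as

\[
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^\top}{\sqrt{d_k}}\right) V
\]

where d_k is the dimension of the query/key vectors.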
…Fully-connected attention example (PyTorch)…

import torch
import torch.nn as nn


class SelfAttention(nn.Module):
    def __init__(self, d_model):
        super(SelfAttention, self).__init__()
        # Separate projections for query, key and value
        self.query_linear = nn.Linear(d_model, d_model)
        self.key_linear = nn.Linear(d_model, d_model)
        self.value_linear = nn.Linear(d_model, d_model)

    def forward(self, x):
        query = self.query_linear(x)
        key = self.key_linear(x)
        value = self.value_linear(x)

        # Scaled dot-product attention: softmax(QK^T / sqrt(d_k)) V
        attention_weight = torch.matmul(query, key.transpose(-2, -1))
        attention_weight /= query.size(-1) ** 0.5
        attention_weight = nn.functional.softmax(attention_weight, dim=-1)

        output = torch.matmul(attention_weight, value)

        return output


# Example: using fully-connected attention
d_model = 4
attn = SelfAttention(d_model)
x = torch.randn(2, 3, d_model)  # batch_size=2, sequence_length=3, hidden_size=d_model
output = attn(x)
print(output)

Multi-head attention builds on fully-connected attention by introducing multiple heads (i.e., multiple parallel attention mechanisms) to increase the model's representational capacity and capture different aspects of the information. Multi-head attention maps the input through several different query, key and value projections to obtain several attention representations, which are then concatenated and passed through a linear transformation to produce the final output. The following is example code for multi-head attention implemented in PyTorch:
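In formula form (the per-head projection matrices W_i^Q, W_i^K, W_i^V and the output matrix W^O correspond to the linear layers in the code):

\[
\mathrm{MultiHead}(Q, K, V) = \mathrm{Concat}(\mathrm{head}_1, \ldots, \mathrm{head}_h)\, W^O, \qquad \mathrm{head}_i = \mathrm{Attention}(Q W_i^Q,\, K W_i^K,\, V W_i^V)
\]

In the implementations below, the h per-head projections are realized by a single d_model × d_model linear layer whose output is reshaped into num_heads slices of size depth = d_model / num_heads.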
…Multi-head attention example (PyTorch)…

class MultiHeadAttention(nn.Module):
    def __init__(self, num_heads, d_model):
        super(MultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.num_heads = num_heads
        self.d_model = d_model
        self.depth = d_model // num_heads

        self.query_linear = nn.Linear(d_model, d_model)
        self.key_linear = nn.Linear(d_model, d_model)
        self.value_linear = nn.Linear(d_model, d_model)

        self.output_linear = nn.Linear(d_model, d_model)

    def forward(self, query, key, value):
        batch_size = query.size(0)

        # Project, then split into heads: (batch, num_heads, seq_len, depth)
        query = self.query_linear(query).view(batch_size, -1, self.num_heads, self.depth).transpose(1, 2)
        key = self.key_linear(key).view(batch_size, -1, self.num_heads, self.depth).transpose(1, 2)
        value = self.value_linear(value).view(batch_size, -1, self.num_heads, self.depth).transpose(1, 2)

        # Scaled dot-product attention within each head
        attention_weight = torch.matmul(query, key.transpose(-2, -1)) / (self.depth ** 0.5)
        attention_weight = nn.functional.softmax(attention_weight, dim=-1)

        # Weighted sum, then concatenate the heads back to (batch, seq_len, d_model)
        attention_output = torch.matmul(attention_weight, value)
        attention_output = attention_output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)

        return self.output_linear(attention_output)
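For completeness, a minimal usage sketch for this class, analogous to the earlier SelfAttention example (the tensor shapes are chosen purely for illustration):

# Example: using multi-head attention (PyTorch)
num_heads = 2
d_model = 4
mha = MultiHeadAttention(num_heads, d_model)
x = torch.randn(2, 3, d_model)  # batch_size=2, sequence_length=3, hidden_size=d_model
output = mha(x, x, x)           # self-attention: query, key and value are the same tensor
print(output.shape)             # torch.Size([2, 3, 4])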

…Below are the TensorFlow versions of fully-connected attention and multi-head attention…

import tensorflow as tf


def fully_connected_attention(query, key, value):
    # Compute attention scores (dot product of query and key)
    score = tf.matmul(query, key, transpose_b=True)

    # Scale by sqrt(d_k), consistent with the PyTorch version above
    score /= tf.math.sqrt(tf.cast(tf.shape(key)[-1], tf.float32))

    # Normalize the scores with softmax
    attention_weight = tf.nn.softmax(score, axis=-1)

    # Weighted sum over the values gives the attention output
    attention_output = tf.matmul(attention_weight, value)

    return attention_output


# Example: using fully-connected attention (TensorFlow)
query = tf.constant([[1.0, 2.0, 3.0]])
key = tf.constant([[4.0, 5.0, 6.0]])
value = tf.constant([[7.0, 8.0, 9.0]])
output = fully_connected_attention(query, key, value)
print(output.numpy())

…Multi-head attention example (TensorFlow)…

class MultiHeadAttention(tf.keras.layers.Layer):
    def __init__(self, num_heads, d_model):
        super(MultiHeadAttention, self).__init__()
        self.num_heads = num_heads
        self.d_model = d_model
        self.depth = d_model // num_heads

        self.query_dense = tf.keras.layers.Dense(d_model)
        self.key_dense = tf.keras.layers.Dense(d_model)
        self.value_dense = tf.keras.layers.Dense(d_model)

        self.output_dense = tf.keras.layers.Dense(d_model)

    def split_heads(self, x, batch_size):
        # Reshape to (batch, num_heads, seq_len, depth)
        x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
        return tf.transpose(x, perm=[0, 2, 1, 3])

    def call(self, query, key, value):
        batch_size = tf.shape(query)[0]

        # Linear projections
        query = self.query_dense(query)
        key = self.key_dense(key)
        value = self.value_dense(value)

        # Split into heads
        query = self.split_heads(query, batch_size)
        key = self.split_heads(key, batch_size)
        value = self.split_heads(value, batch_size)

        # Scaled dot-product attention within each head
        attention_weight = tf.matmul(query, key, transpose_b=True)
        attention_weight /= tf.math.sqrt(tf.cast(self.depth, tf.float32))
        attention_weight = tf.nn.softmax(attention_weight, axis=-1)

        # Weighted sum, then concatenate the heads back to (batch, seq_len, d_model)
        attention_output = tf.matmul(attention_weight, value)
        attention_output = tf.transpose(attention_output, perm=[0, 2, 1, 3])
        attention_output = tf.reshape(attention_output, (batch_size, -1, self.d_model))

        return self.output_dense(attention_output)


# Example: using multi-head attention (TensorFlow)
num_heads = 2
d_model = 6
mha = MultiHeadAttention(num_heads, d_model)
query = tf.constant([[1.0, 2.0, 3.0]])   # shape (1, 3): a batch of one, feature dim 3
key = tf.constant([[4.0, 5.0, 6.0]])
value = tf.constant([[7.0, 8.0, 9.0]])
output = mha(query, key, value)
print(output.numpy())                    # shape (1, 1, d_model)

The code above is example code showing how to implement fully-connected attention and multi-head attention. Different attention mechanisms suit different situations, and you can choose an appropriate one based on your specific requirements.
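In practice, both frameworks also provide built-in multi-head attention layers that can replace the hand-written classes above; a minimal sketch (the dimensions here are illustrative assumptions, not taken from the examples above):

import torch
import tensorflow as tf

# PyTorch built-in layer; batch_first=True expects inputs of shape (batch, seq_len, embed_dim)
torch_mha = torch.nn.MultiheadAttention(embed_dim=4, num_heads=2, batch_first=True)
x = torch.randn(2, 3, 4)
attn_output, attn_weights = torch_mha(x, x, x)
print(attn_output.shape)  # torch.Size([2, 3, 4])

# TensorFlow/Keras built-in layer; the call signature is layer(query, value, key)
tf_mha = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=3)
y = tf.random.normal((2, 3, 6))
tf_output = tf_mha(y, y, y)
print(tf_output.shape)  # (2, 3, 6)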
