全连接注意力和多头注意力都是注意力机制的一种,用于在神经网络中实现信息交互和权重计算。注意力机制通过对输入的每个位置进行加权计算,来获得每个位置的表示。在全连接注意力中,通过计算查询(query)和键(key)的内积得到注意力权重,再将权重与值(value)相乘得到最终的输出。下面我将分别解释全连接注意力和多头注意力,并提供用PyTorch实现的示例代码。
…全连接注意力示例…
class SelfAttention(nn.Module):
    """Single-head scaled dot-product self-attention.

    The input is projected into queries, keys and values by three
    independent linear layers. Attention scores are the query/key dot
    products divided by the square root of the feature size, normalised
    with a softmax over the last axis and used to average the values.

    Args:
        d_model: Size of the last dimension of the input; the output has
            the same size.
    """

    def __init__(self, d_model):
        super().__init__()
        # Separate projections: sharing a single layer would force
        # query, key and value to be the exact same tensor.
        self.query_linear = nn.Linear(d_model, d_model)
        self.key_linear = nn.Linear(d_model, d_model)
        self.value_linear = nn.Linear(d_model, d_model)

    def forward(self, x):
        """Apply self-attention to ``x``.

        Args:
            x: Tensor whose last dimension is ``d_model`` and whose
                second-to-last dimension is the axis attended over.

        Returns:
            Tensor with the same shape as ``x``.
        """
        query = self.query_linear(x)
        key = self.key_linear(x)
        value = self.value_linear(x)

        # Scale by sqrt(feature size) to keep the softmax inputs in a
        # reasonable range as d_model grows.
        scores = torch.matmul(query, key.transpose(-2, -1))
        scores = scores / (query.size(-1) ** 0.5)
        attention_weight = nn.functional.softmax(scores, dim=-1)
        return torch.matmul(attention_weight, value)
# Example: run the self-attention layer on a random input batch.
d_model = 4
attn = SelfAttention(d_model)
# Input shape: (batch_size=2, sequence_length=3, hidden_size=d_model).
input_shape = (2, 3, d_model)
x = torch.randn(input_shape)
output = attn(x)
print(output)
多头注意力是在全连接注意力的基础上引入了多个头(即多个并行的注意力机制)进行处理,以增加模型的表示能力和捕捉不同方面的信息。多头注意力将输入通过多个不同的查询、键和值映射得到多组不同的注意力表示,然后将这些表示进行拼接和线性变换得到最终输出。以下是用PyTorch实现的多头注意力示例代码:
…多头注意力示例…
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Queries, keys and values are linearly projected, split into
    ``num_heads`` heads of size ``d_model // num_heads``, attended
    independently per head, concatenated and passed through a final
    linear layer.

    Args:
        num_heads: Number of attention heads; must be positive.
        d_model: Feature size of the inputs and of the output; must be
            divisible by ``num_heads``.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide
            ``d_model``.
    """

    def __init__(self, num_heads, d_model):
        super().__init__()
        # Without this check the reshape in forward() either fails with
        # an obscure error or silently yields a wrong sequence length.
        if num_heads <= 0 or d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by a positive "
                f"num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.d_model = d_model
        self.depth = d_model // num_heads
        self.query_linear = nn.Linear(d_model, d_model)
        self.key_linear = nn.Linear(d_model, d_model)
        self.value_linear = nn.Linear(d_model, d_model)
        self.output_linear = nn.Linear(d_model, d_model)

    def _split_heads(self, x, batch_size):
        """Reshape to (batch, seq, heads, depth), then swap seq and heads."""
        x = x.view(batch_size, -1, self.num_heads, self.depth)
        return x.transpose(1, 2)

    def forward(self, query, key, value):
        """Compute multi-head attention.

        Args:
            query: Tensor of shape (batch, query_len, d_model).
            key: Tensor of shape (batch, key_len, d_model).
            value: Tensor of shape (batch, key_len, d_model).

        Returns:
            Tensor of shape (batch, query_len, d_model).
        """
        batch_size = query.size(0)
        query = self._split_heads(self.query_linear(query), batch_size)
        key = self._split_heads(self.key_linear(key), batch_size)
        value = self._split_heads(self.value_linear(value), batch_size)

        scores = torch.matmul(query, key.transpose(-2, -1)) / (self.depth ** 0.5)
        attention_weight = nn.functional.softmax(scores, dim=-1)
        attention_output = torch.matmul(attention_weight, value)

        # Undo the head split: back to (batch, seq, heads, depth), then
        # merge the last two axes. contiguous() is required before view()
        # because transpose() returns a non-contiguous tensor.
        attention_output = attention_output.transpose(1, 2).contiguous()
        attention_output = attention_output.view(batch_size, -1, self.d_model)
        return self.output_linear(attention_output)
…下面是tensorflow版全连接注意力和多头注意力…
import tensorflow as tf
def fully_connected_attention(query, key, value):
    """Return the dot-product attention output for the given inputs.

    The scores are the product of ``query`` with the transpose of
    ``key``; they are normalised with a softmax over the last axis and
    used to take a weighted sum of ``value``. The scores are not scaled
    before the softmax.
    """
    similarity = tf.matmul(query, key, transpose_b=True)
    weights = tf.nn.softmax(similarity, axis=-1)
    return tf.matmul(weights, value)
# Example: call the attention function on single-row inputs of shape (1, 3).
query = tf.constant([[1.0, 2.0, 3.0]])
key = tf.constant([[4.0, 5.0, 6.0]])
value = tf.constant([[7.0, 8.0, 9.0]])
output = fully_connected_attention(query=query, key=key, value=value)
print(output.numpy())
…多头注意力示例代码…
class MultiHeadAttention(tf.keras.layers.Layer):
    """Multi-head scaled dot-product attention as a Keras layer.

    Queries, keys and values are projected by dense layers to
    ``d_model`` features, split into ``num_heads`` heads of size
    ``d_model // num_heads``, attended per head, merged back and passed
    through a final dense layer.

    Args:
        num_heads: Number of attention heads; must be positive.
        d_model: Feature size after projection and of the output; must
            be divisible by ``num_heads``.
        **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide
            ``d_model``.
    """

    def __init__(self, num_heads, d_model, **kwargs):
        super().__init__(**kwargs)
        # Fail early with a clear message instead of an obscure reshape
        # error (or a silently wrong shape) inside call().
        if num_heads <= 0 or d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by a positive "
                f"num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.d_model = d_model
        self.depth = d_model // num_heads
        self.query_dense = tf.keras.layers.Dense(d_model)
        self.key_dense = tf.keras.layers.Dense(d_model)
        self.value_dense = tf.keras.layers.Dense(d_model)
        self.output_dense = tf.keras.layers.Dense(d_model)

    def split_heads(self, x, batch_size):
        """Reshape ``x`` to (batch, seq, heads, depth) and return it as
        (batch, heads, seq, depth)."""
        x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
        return tf.transpose(x, perm=[0, 2, 1, 3])

    def call(self, query, key, value):
        """Compute multi-head attention.

        Args:
            query: Query tensor; its first axis is the batch axis.
            key: Key tensor with the same batch size as ``query``.
            value: Value tensor with the same batch size as ``query``.

        Returns:
            Tensor of shape (batch, -1, d_model), where the middle axis
            is inferred by the final reshape.
        """
        batch_size = tf.shape(query)[0]

        query = self.split_heads(self.query_dense(query), batch_size)
        key = self.split_heads(self.key_dense(key), batch_size)
        value = self.split_heads(self.value_dense(value), batch_size)

        attention_weight = tf.matmul(query, key, transpose_b=True)
        # Cast the scale to the scores' dtype so non-float32 inputs do
        # not fail on a dtype mismatch; identical for float32.
        scale = tf.math.sqrt(tf.cast(self.depth, attention_weight.dtype))
        attention_weight = attention_weight / scale
        attention_weight = tf.nn.softmax(attention_weight, axis=-1)

        attention_output = tf.matmul(attention_weight, value)
        # Undo the head split: (batch, heads, seq, depth) ->
        # (batch, seq, heads, depth) -> (batch, seq, d_model).
        attention_output = tf.transpose(attention_output, perm=[0, 2, 1, 3])
        attention_output = tf.reshape(
            attention_output, (batch_size, -1, self.d_model)
        )
        return self.output_dense(attention_output)
# Example: build a two-head attention layer and call it on inputs of
# shape (1, 3).
num_heads = 2
d_model = 6
mha = MultiHeadAttention(num_heads=num_heads, d_model=d_model)
query = tf.constant([[1.0, 2.0, 3.0]])
key = tf.constant([[4.0, 5.0, 6.0]])
value = tf.constant([[7.0, 8.0, 9.0]])
output = mha(query, key, value)
print(output.numpy())
以上代码为示例代码,展示了如何实现全连接注意力和多头注意力。不同的注意力机制适用于不同的情况,可以根据具体需求选择合适的注意力机制来应用。