class Attention(nn.Module):
def __init__(self, nx, n_ctx, config, scale=False, is_cross_attention=False):
super().__init__()
n_state = nx # in Attention: n_state=768 (nx=n_embd)
# [switch nx => n_state from Block to Attention to keep identical to TF implem]
# Assert that the hidden-state dimension n_state is evenly divisible by the number of attention heads config.n_head.
assert n_state % config.n_head == 0
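# With the default GPT-2 base configuration (assumed here: n_embd = 768, n_head = 12), each head gets
# head_features = n_state // config.n_head = 768 // 12 = 64 dimensions.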
# The two self.register_buffer() calls below effectively create two attributes on this Attention module,
# self.bias and self.masked_bias;
# self.bias is a lower-triangular matrix (entries on and below the diagonal are 1, entries above it are 0)
# of shape (1, 1, n_ctx, n_ctx), i.e. (1, 1, 1024, 1024) when n_ctx = 1024;
# self.masked_bias is a large negative constant, -1e4, used to mask out disallowed attention scores.
self.register_buffer(
"bias", torch.tril(torch.ones((n_ctx, n_ctx), dtype=torch.uint8)).view(1, 1, n_ctx, n_ctx)
)
self.register_buffer("masked_bias", torch.tensor(-1e4))
self.n_head = config.n_head
self.split_size = n_state
self.scale = scale
self.is_cross_attention = is_cross_attention
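# Cross-attention layers are only built when GPT-2 is used as a decoder with cross-attention enabled
# (typically config.add_cross_attention=True in an encoder-decoder setup); plain GPT-2 language modelling
# takes the else-branch below.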
if self.is_cross_attention:
# self.c_attn = Conv1D(2 * n_state, nx) acts like a fully connected layer: it projects the last dimension
# of its input tensor from nx (768) to 2 * n_state (2*768); here n_state = nx = num_head * head_features = 768.
self.c_attn = Conv1D(2 * n_state, nx)
# self.q_attn = Conv1D(n_state, nx) acts like a fully connected layer: it projects the last dimension
# of its input tensor from nx (768) to n_state (768); here n_state = nx = num_head * head_features = 768.
self.q_attn = Conv1D(n_state, nx)
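# Sketch of how these two projections are used in forward() when encoder_hidden_states is supplied
# (paraphrased, not a verbatim copy):
#     query = self.q_attn(hidden_states)                                        # (batch, seq_len, 768)
#     key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)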
else:
# self.c_attn = Conv1D(3 * n_state, nx) acts like a fully connected layer: it projects the last dimension
# of its input tensor from nx (768) to 3 * n_state (3*768), producing query, key and value in a single
# projection; here n_state = nx = num_head * head_features = 768.
self.c_attn = Conv1D(3 * n_state, nx)
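# Sketch of the self-attention case in forward(): the single projection is split into query, key and value
# (paraphrased, not a verbatim copy):
#     query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)  # each (batch, seq_len, 768)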
# self.c_proj here is Conv1D(n_state, nx) (all_head_size = n_state = nx = 768). It acts like a fully connected
# layer that projects the last dimension (all_head_size) of the merged multi-head attention output tensor a
# from n_state (768) to nx (768).
self.c_proj = Conv1D(n_state, nx)
self.attn_dropout = nn.Dropout(config.attn_pdrop)
self.resid_dropout = nn.Dropout(config.resid_pdrop)
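# With the default GPT2Config both dropout probabilities are 0.1 (an assumption based on the standard config).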
self.pruned_heads = set()
# prune_heads() is best understood together with https://github.com/huggingface/transformers/issues/850.
def prune_heads(self, heads):
if len(heads) == 0:
return
heads, index = find_pruneable_heads_and_indices(
heads, self.n_head, self.split_size // self.n_head, self.pruned_heads
)
index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)])
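# c_attn packs query, key and value along its output dimension (3 * split_size), so the pruning indices
# are repeated at offsets 0, split_size and 2 * split_size to hit all three blocks.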
# Prune conv1d layers
self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)
self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)
# Update hyper params
self.split_size = (self.split_size // self.n_head) * (self.n_head - len(heads))
self.n_head = self.n_head - len(heads)
self.pruned_heads = self.pruned_heads.union(heads)
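# Hypothetical usage (head indices are illustrative): calling attn.prune_heads({0, 2}) on a 12-head,
# 768-dim layer removes 2 heads, leaving n_head = 10 and split_size = 64 * 10 = 640.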
def merge_heads(self, x):
# Here x is the attention output tensor obtained by aggregating the value tensor with the computed
# attention scores (weights).
# x has shape (batch_size, num_head, seq_len, head_features).
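# merge_heads() transposes and reshapes it back to (batch_size, seq_len, all_head_size),
# e.g. (2, 12, 5, 64) -> (2, 5, 768) under the default 12-head / 768-dim assumption.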