# NOTE: This is an incomplete implementation — only BERT's basic building blocks,
# written from my own understanding (parts I have not yet understood are omitted,
# and there may be mistakes). For downstream-task model construction (NSP, MLM,
# etc.) please refer to Hugging Face's transformers library.
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import inspect
# BertForMaskedLM - BertModel,BertOnlyMLMHead
# BertModel - BertEmbedding, BertEncoder(no BertPooler)
# BertOnlyMLMHead - BertLMPredictionHead
# BertEncoder - BertLayer
# BertLayer - BertAttention, BertIntermediate, BertOutput
class Config(object):
    """Minimal BERT hyper-parameter container (toy values, not bert-base)."""
    def __init__(self):
        self.vocab_size = 1000
        self.hidden_size = 768
        self.max_position_embeddings = 32   # maximum input sequence length
        self.hidden_dropout_prob = 0.5
        self.layer_norm_eps = 1e-6          # epsilon added to the LayerNorm denominator
        self.type_vocab_size = 1            # number of segment types (e.g. question vs. answer)
        self.pad_token_id = 0
        self.num_attention_heads = 12
        self.attention_probs_dropout_prob = 0.5
        self.position_embedding_type = "absolute"
        self.intermediate_size = 512
        self.hidden_act = "relu"
        self.chunk_size_feed_forward = 0
        self.num_hidden_layers = 12
        # BUGFIX: BertPreTrainedModel._init_weights reads config.initializer_range,
        # which was missing here and would raise AttributeError during init.
        self.initializer_range = 0.02
class BertEmbedding(nn.Module):
    """Sum of word, position and token-type embeddings, then LayerNorm + dropout."""
    def __init__(self, config):
        super(BertEmbedding, self).__init__()
        # Three learnable embedding tables.
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)  # longest position in a passage
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)  # token_type distinguishes segment kinds
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # BUGFIX: the config attribute is "position_embedding_type"; the old code
        # looked up "position_embeddings_type" and silently ignored the setting.
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
        # Buffers move with the module and are used in forward(), but are not trained
        # and (persistent=False) are not saved/restored with the state dict.
        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False)
        self.register_buffer("token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False)

    def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0):
        """Build the combined embedding.

        past_key_values_length: offset into the position table when cached
        key/values from previous steps are used (e.g. generation variants).
        """
        if input_ids is None:  # if input_ids is absent, the input comes in inputs_embeds
            input_shape = inputs_embeds.size()[:-1]
        else:
            input_shape = input_ids.size()
        seq_length = input_shape[1]
        if position_ids is None:  # no position_ids given: take them from the buffer
            position_ids = self.position_ids[:, past_key_values_length:seq_length + past_key_values_length]  # [1, seq_len]
        if token_type_ids is None:
            if hasattr(self, "token_type_ids"):
                buffer_token_type_ids = self.token_type_ids[:, :seq_length]  # [1, seq_len]
                token_type_ids = buffer_token_type_ids.expand(input_shape[0], seq_length)  # [bs, seq_len]
            else:
                # Freshly created tensors must be placed on the right device.
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)
        embeddings = inputs_embeds + token_type_embeddings
        if self.position_embedding_type == "absolute":  # whether to use absolute position embeddings
            position_embeddings = self.position_embeddings(position_ids)  # [1, seq_len, hidden_size], broadcasts over batch
            embeddings = embeddings + position_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings
class BertSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention (BERT style)."""
    def __init__(self, config, position_embedding_type=None):
        super(BertSelfAttention, self).__init__()
        assert config.hidden_size % config.num_attention_heads == 0
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = config.hidden_size // config.num_attention_heads
        self.all_head_size = config.hidden_size
        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.position_embedding_type = position_embedding_type or getattr(config, "position_embedding_type", "absolute")

    def transpose_for_scores(self, x):
        # [bs, seq_len, hidden] -> [bs, seq_len, heads, head_size]
        new_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(new_shape)
        return x.permute(0, 2, 1, 3)  # [bs, heads, seq_len, head_size]

    def forward(self, hidden_states, attention_mask=None, head_mask=None):
        # head_mask zeroes out selected attention heads (e.g. for head-importance studies).
        # (HF's source also has encoder/cross-attention parameters — omitted here.)
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))
        query_layer = self.transpose_for_scores(self.query(hidden_states))
        # BUGFIX: scale by sqrt(per-head size), not sqrt(total hidden size) —
        # dividing by sqrt(all_head_size) under-scales the scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) / math.sqrt(self.attention_head_size)  # [bs, heads, seq_len, seq_len]
        if attention_mask is not None:
            # Masked positions carry a large negative value so softmax gives them ~0 weight.
            attention_scores += attention_mask
        attention_probs = F.softmax(attention_scores, dim=-1)
        attention_probs = self.dropout(attention_probs)
        if head_mask is not None:
            attention_probs = attention_probs * head_mask
        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        view_shape = context_layer.size()[:-2] + (self.all_head_size,)
        outputs = context_layer.view(view_shape)
        return (outputs,)  # same shape as input: [bs, seq_len, hidden_size]
class BertSelfOutput(nn.Module):
    """Projects the attention output, then applies the residual add & LayerNorm."""
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        """input_tensor is the residual branch (the attention block's input)."""
        projected = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(projected + input_tensor)
def prune_linear_layer(layer: nn.Linear, index: torch.LongTensor, dim: int = 0):
    """Return a new nn.Linear keeping only the entries of ``layer`` selected by ``index``.

    dim=0 keeps the selected output features (weight rows and matching bias entries);
    dim=1 keeps the selected input features (weight columns; all biases survive,
    since every output node is retained).
    """
    index = index.to(layer.weight.device)
    # weight has shape (out_features, in_features); select rows (dim=0) or columns (dim=1).
    W = layer.weight.index_select(dim, index).clone().detach()
    b = None
    if layer.bias is not None:  # bias has shape (out_features,)
        if dim == 1:
            # Pruning input features: the bias is untouched — keep all of it.
            b = layer.bias.clone().detach()
        else:
            # Pruning output features: keep only the matching bias entries.
            b = layer.bias[index].clone().detach()
    new_size = list(layer.weight.size())
    new_size[dim] = len(index)
    new_layer = nn.Linear(new_size[1], new_size[0], bias=layer.bias is not None).to(layer.weight.device)
    # BUGFIX: the original wrote ``new_layer.wieght`` (typo), which raised
    # AttributeError. Copy under no_grad so the in-place copy_ is not tracked.
    with torch.no_grad():
        new_layer.weight.copy_(W.contiguous())
        if b is not None:
            new_layer.bias.copy_(b.contiguous())
    new_layer.weight.requires_grad = True
    if layer.bias is not None:
        new_layer.bias.requires_grad = True
    return new_layer
class BertAttention(nn.Module):
    """Self-attention followed by its output projection + add & norm.

    (HF also supports pruning the q/k/v linears here — not implemented.)
    """
    def __init__(self, config):
        super().__init__()
        self.attention = BertSelfAttention(config)
        self.output = BertSelfOutput(config)

    def forward(self, hidden_states, attention_mask=None, head_mask=None):
        # BertSelfAttention returns a tuple; element 0 is the context tensor.
        self_outputs = self.attention(hidden_states, attention_mask, head_mask)
        attention_output = self.output(self_outputs[0], hidden_states)
        # Keep the tuple convention: (output, *extras).
        return (attention_output,) + self_outputs[1:]
# Map activation-name strings in the config to their implementations.
# BUGFIX: this dict was empty, so config.hidden_act = "relu" raised KeyError
# in BertIntermediate / BertPredictionHeadTransform.
ACT2FN = {
    "relu": F.relu,
    "gelu": F.gelu,
    "tanh": torch.tanh,
    "sigmoid": torch.sigmoid,
}
class BertIntermediate(nn.Module):
    """Feed-forward expansion: hidden_size -> intermediate_size, plus a nonlinearity."""
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        # hidden_act may be a name (looked up in ACT2FN) or a callable.
        act = config.hidden_act
        self.intermediate_act_fn = ACT2FN[act] if isinstance(act, str) else act

    def forward(self, hidden_states):
        return self.intermediate_act_fn(self.dense(hidden_states))
class BertOutput(nn.Module):
    """Feed-forward down-projection back to hidden_size, with dropout, residual add and LayerNorm.

    Used together with BertIntermediate (expansion + activation, then this block).
    """
    def __init__(self, config):
        # BUGFIX: nn.Module.__init__ was never called, so assigning submodules
        # below raised AttributeError ("cannot assign module before __init__").
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)  # back to hidden_size
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        """input_tensor is the residual branch (the FFN block's input)."""
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states
def apply_chunking_to_forward(forward_fn, chunk_size, chunk_dim, *input_tensors):
    """Apply ``forward_fn`` independently to chunks of the inputs along ``chunk_dim``
    to save memory, then concatenate the results. chunk_size <= 0 disables chunking.

    Steps: argument check, shape check, chunking, per-chunk forward, concatenation.
    """
    num_args_in_forward_chunk_fn = len(inspect.signature(forward_fn).parameters)
    if num_args_in_forward_chunk_fn != len(input_tensors):  # arity must match the number of tensors
        raise ValueError(
            f"forward_chunk_fn expects {num_args_in_forward_chunk_fn} arguments, but only {len(input_tensors)} input "
            "tensors are given"
        )
    if chunk_size > 0:
        tensor_shape = input_tensors[0].shape[chunk_dim]
        for input_tensor in input_tensors:  # every tensor must agree on the chunked dimension
            if input_tensor.shape[chunk_dim] != tensor_shape:
                # BUGFIX: error message said "tenors" instead of "tensors".
                raise ValueError(
                    f"All input tensors have to be of the same shape: {tensor_shape}, "
                    f"found shape {input_tensor.shape[chunk_dim]}"
                )
        if input_tensors[0].shape[chunk_dim] % chunk_size != 0:
            raise ValueError(
                f"The dimension to be chunked {input_tensors[0].shape[chunk_dim]} has to be a multiple of the chunk "
                f"size {chunk_size}"
            )
        num_chunks = input_tensors[0].shape[chunk_dim] // chunk_size
        input_tensors_chunks = tuple(input_tensor.chunk(num_chunks, dim=chunk_dim) for input_tensor in input_tensors)
        # Run the forward on each aligned group of chunks, then stitch along chunk_dim.
        output_chunks = tuple(forward_fn(*input_tensors_chunk) for input_tensors_chunk in zip(*input_tensors_chunks))
        return torch.cat(output_chunks, dim=chunk_dim)
    return forward_fn(*input_tensors)
class BertLayer(nn.Module):
    """One transformer encoder layer: self-attention then a (chunked) feed-forward block."""
    def __init__(self, config):
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1  # the feed-forward is chunked over the sequence dimension
        self.attention = BertAttention(config)
        # Cross-attention is not implemented here.
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)

    def feed_forward_chunk(self, attention_output):
        # FFN applied to one chunk of the sequence.
        return self.output(self.intermediate(attention_output), attention_output)

    def forward(self, hidden_states, attention_mask, head_mask):
        attn_results = self.attention(hidden_states, attention_mask, head_mask)
        attn_output, extras = attn_results[0], attn_results[1:]
        layer_output = apply_chunking_to_forward(
            self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attn_output
        )
        return (layer_output,) + extras
class BertEncoder(nn.Module):
    """Stack of config.num_hidden_layers BertLayer modules."""
    def __init__(self, config):
        # BUGFIX: nn.Module.__init__ was never called, so nn.ModuleList below
        # raised AttributeError at construction time.
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])

    def forward(self, hidden_states, attention_mask=None, output_hidden_states=False, head_mask=None):
        # output_hidden_states: also collect the hidden states entering every layer
        # plus the final one. head_mask is appended (defaulting to None) to stay
        # backward-compatible with existing positional callers.
        all_hidden_states = () if output_hidden_states else None
        for layer in self.layers:
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)
            # BUGFIX: BertLayer.forward requires head_mask (it was omitted), and it
            # returns a tuple — the old code assigned the whole tuple to
            # hidden_states, which broke the next iteration.
            layer_outputs = layer(hidden_states, attention_mask, head_mask)
            hidden_states = layer_outputs[0]
        if output_hidden_states:  # include the last layer's hidden_states
            all_hidden_states += (hidden_states,)
        return tuple(v for v in [hidden_states, all_hidden_states] if v is not None)  # 1- or 2-element tuple
class BertPooler(nn.Module):
    """Pools a sequence into one vector: take the first ([CLS]) token, Linear + Tanh."""
    def __init__(self, config):
        # BUGFIX: nn.Module.__init__ was never called, so assigning submodules
        # below raised AttributeError.
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        # BUGFIX: hidden_states[:0] selected ZERO batch rows (an empty tensor);
        # the first token of every sequence is hidden_states[:, 0].
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output
class BertPredictionHeadTransform(nn.Module):
    """Dense + activation + LayerNorm applied before the LM decoder projection."""
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        # hidden_act may be a name (looked up in ACT2FN) or a callable.
        act = config.hidden_act
        self.transform_act_fn = ACT2FN[act] if isinstance(act, str) else act
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states):
        transformed = self.transform_act_fn(self.dense(hidden_states))
        return self.LayerNorm(transformed)
class BertLMPredictionHead(nn.Module):
    """Maps hidden states to vocabulary-sized logits for masked-LM prediction."""
    def __init__(self, config):
        super().__init__()
        self.transform = BertPredictionHeadTransform(config)
        # The decoder projects hidden_size -> vocab_size. The bias is kept as a
        # separate Parameter and attached to the decoder afterwards (the HF
        # convention, so it resizes together with the token embeddings).
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.bias = nn.Parameter(torch.zeros(config.vocab_size))
        self.decoder.bias = self.bias

    def forward(self, hidden_states):
        return self.decoder(self.transform(hidden_states))
class BertOnlyMLMHead(nn.Module):
    """Thin wrapper exposing only the masked-LM prediction head."""
    def __init__(self, config):
        super().__init__()
        self.predictions = BertLMPredictionHead(config)

    def forward(self, sequence_output):
        return self.predictions(sequence_output)
class BertOnlyNSPHead(nn.Module):
    """Next-sentence-prediction head: binary classifier over the pooled output."""
    def __init__(self, config):
        super().__init__()
        # Two classes: the second sentence does / does not follow the first.
        self.seq_relationship = nn.Linear(config.hidden_size, 2)

    def forward(self, pooled_output):
        return self.seq_relationship(pooled_output)
class BertPreTrainedModel(PreTrainedModel):
    """Base class providing standard BERT weight initialization.

    NOTE(review): PreTrainedModel and BertConfig are not defined/imported in this
    file — presumably they come from Hugging Face transformers; confirm.
    """
    config_class = BertConfig
    base_model_prefix = "bert"
    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def _init_weights(self, module):
        """Initialize the weights of a single submodule."""
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            # nn.Embedding is a lookup table: row i of the weight matrix is the
            # vector for token id i, so the padding row is zeroed explicitly.
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
class BertModel(BertPreTrainedModel):
    """BERT backbone: embeddings -> encoder stack (-> optional pooler)."""
    def __init__(self, config, add_pooling_layer=True):
        # BUGFIX: HF's PreTrainedModel.__init__ requires the config argument.
        super().__init__(config)
        self.config = config
        self.embeddings = BertEmbedding(config)
        self.encoder = BertEncoder(config)
        self.pooler = BertPooler(config) if add_pooling_layer else None
        self.init_weights()

    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, output_hidden_states=None):
        """Returns (sequence_output, pooled_output, *encoder_extras)."""
        # BUGFIX: was `if attention_mask = None:` — assignment instead of `is None`
        # (a syntax error that made the whole file unimportable).
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
            attention_mask[input_ids == self.config.pad_token_id] = 0
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)
        # Broadcast [bs, seq_len] -> [bs, 1, 1, seq_len] to match the attention
        # score shape [bs, heads, seq_len, seq_len].
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
        # BUGFIX: .float() is not in-place; the result must be reassigned.
        extended_attention_mask = extended_attention_mask.float()
        # Masked positions become a large negative bias so softmax gives them ~0 weight.
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
        )
        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            output_hidden_states=output_hidden_states,
        )
        sequence_output = encoder_outputs[0]
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
        return (sequence_output, pooled_output) + encoder_outputs[1:]
# For downstream-task model construction code, please refer to Hugging Face's transformers library.