本方法基于hugginface的transformers项目改造
过程其实很简单,基于TFBertMainLayer下的call函数做了改造,如果直接用self.bert的输出其实就是CLS token 的结果。
使用时同样可以使用from_pretrained来加载下载好的BERT模型参数,
然后把输入的数据整理为input_ids, attention_mask,token_type_ids格式即可。
使用tf.datasets也可。
from transformers.modeling_tf_bert import TFBertModel
# 自定义一个Bert模型,其中self.index就是想要的层的索引
class Bert_layer(TFBertModel):
def __init__(self, config):
super(Bert_layer, self).__init__(config)
self.num_hidden_layers = config.num_hidden_layers
self.index = -1
def get_encoder_layer(self,
inputs,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
training=False
):
if isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
position_ids = inputs[3] if len(inputs) > 3 else position_ids
head_mask = inputs[4] if len(inputs) > 4 else head_mask
inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
output_attentions = inputs[6] if len(inputs) > 6 else output_attentions
output_hidden_states = inputs[7] if len(inputs) > 7 else output_hidden_states
return_dict = inputs[8] if len(inputs) > 8 else return_dict
assert len(inputs) <= 9, "Too many inputs."
elif isinstance(inputs, (dict)):
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask)
token_type_ids = inputs.get("token_type_ids", token_type_ids)
position_ids = inputs.get("position_ids", position_ids)
head_mask = inputs.get("head_mask", head_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
output_attentions = inputs.get("output_attentions", output_attentions)
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
return_dict = inputs.get("return_dict", return_dict)
assert len(inputs) <= 9, "Too many inputs."
else:
input_ids = inputs
output_hidden_states = True
extended_attention_mask = attention_mask[:, tf.newaxis, tf.newaxis, :]
extended_attention_mask = tf.cast(extended_attention_mask, tf.float32)
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
if head_mask is not None:
raise NotImplementedError
else:
head_mask = [None] * self.num_hidden_layers
# head_mask = tf.constant([0] * self.num_hidden_layers)
embedding_output = self.bert.embeddings(input_ids, position_ids, token_type_ids, inputs_embeds, training=training)
encoder_outputs = self.bert.encoder(
embedding_output,
extended_attention_mask,
head_mask,
output_attentions,
output_hidden_states,
return_dict,
)
return encoder_outputs[1][self.index] # 其实就是抄了以下transformers里的call函数,主要就修改了return的内容