Extracting pooled_output from BertForSequenceClassification as the final feature
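BertForSequenceClassification is not designed to return pooled_output. Its forward() computes it internally, but the returned object only carries the loss, logits, hidden states and attentions. The relevant part of the source: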
class BertForSequenceClassification(BertPreTrainedModel):
    # ....
    def forward(self, ...):
        # ....
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # .... (pooled_output = outputs[1] goes through dropout and the
        #       classifier to produce logits, but is not returned)
        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
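Sometimes, however, we want to use pooled_output for downstream processing or to study its properties. One option is to edit the transformers library source so that BertForSequenceClassification also returns it, but this is cumbersome, and it is especially ill-suited to running on multiple machines, where every installation would have to be patched in exactly the same way.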
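A simpler approach that leaves the library untouched is to train with BertForSequenceClassification, save the model, and then load the saved checkpoint with BertModel. When loading, BertModel picks up the encoder and pooler weights and discards the classification head. BertModel's forward() returns pooler_output, so passing the inputs to the loaded model is enough: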
class BertModel(BertPreTrainedModel):
    # ...
    def forward(self, ...):
        # ...
        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            past_key_values=encoder_outputs.past_key_values,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            cross_attentions=encoder_outputs.cross_attentions,
        )
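When analysing pooled_output it helps to know what it is. In the Hugging Face implementation, BertPooler takes the hidden state of the first token ([CLS]) from last_hidden_state, passes it through a Linear layer (`pooler.dense`) and applies tanh. The minimal sketch below checks this on a tiny, randomly initialised BertModel; the config sizes are arbitrary assumptions chosen only so it runs offline without downloading weights.

```python
import torch
from transformers import BertConfig, BertModel

# Tiny random model (arbitrary sizes) so the sketch runs without a download
config = BertConfig(vocab_size=100, hidden_size=32, num_hidden_layers=2,
                    num_attention_heads=4, intermediate_size=64)
model = BertModel(config).eval()

input_ids = torch.randint(0, config.vocab_size, (2, 8))
with torch.no_grad():
    out = model(input_ids=input_ids)
    # BertPooler: first-token ([CLS]) hidden state -> Linear -> tanh
    cls_state = out.last_hidden_state[:, 0]
    manual = torch.tanh(model.pooler.dense(cls_state))

# Attribute access and key access return the same values
same_values = torch.equal(out.pooler_output, out["pooler_output"])
print(out.last_hidden_state.shape, out.pooler_output.shape)
print(torch.allclose(manual, out.pooler_output, atol=1e-6))
```

Because of the tanh, every component of pooled_output lies in (-1, 1), which matters if you compare it with raw hidden states.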
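Sample code (checkpoint, num_class, device and dataloader are placeholders that must be defined beforehand):
import torch
from transformers import BertForSequenceClassification, BertModel

model = BertForSequenceClassification.from_pretrained(checkpoint, num_labels=num_class)
# ... fine-tune model here ...
model.save_pretrained('saved_path')  # writes the config and weights to saved_path

# The checkpoint path is the first positional argument of from_pretrained
model_eval = BertModel.from_pretrained('saved_path').to(device)
model_eval.eval()
for batch in dataloader:
    batch = {k: v.to(device) for k, v in batch.items()}
    with torch.no_grad():
        # BertModel.forward() has no labels argument, so labels are not passed
        pooled_out = model_eval(input_ids=batch['input_ids'],
                                token_type_ids=batch['token_type_ids'],
                                attention_mask=batch['attention_mask'])['pooler_output']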
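The sample above can be turned into an offline check. This is a minimal sketch, not the author's code: a tiny randomly initialised BertConfig stands in for `checkpoint` and the fine-tuning step, a temporary directory stands in for 'saved_path', and a handful of random token ids (the `examples` list is hypothetical) stand in for the real dataloader. It saves the classifier, reloads it as BertModel, collects pooler_output across all batches, and compares it with the pooled output the classification model computes internally through its `bert` submodule.

```python
import tempfile

import torch
from torch.utils.data import DataLoader
from transformers import BertConfig, BertForSequenceClassification, BertModel

torch.manual_seed(0)
device = torch.device("cpu")

# Tiny random config (arbitrary sizes) in place of a real checkpoint; no download needed
config = BertConfig(vocab_size=100, hidden_size=32, num_hidden_layers=2,
                    num_attention_heads=4, intermediate_size=64, num_labels=3)
clf = BertForSequenceClassification(config).eval()  # stands in for the fine-tuned model

# Hypothetical tokenised examples; default collate turns the list of dicts into dict batches
examples = [{"input_ids": torch.randint(0, config.vocab_size, (8,)),
             "token_type_ids": torch.zeros(8, dtype=torch.long),
             "attention_mask": torch.ones(8, dtype=torch.long),
             "labels": torch.tensor(i % config.num_labels)} for i in range(10)]
dataloader = DataLoader(examples, batch_size=4)

with tempfile.TemporaryDirectory() as saved_path:
    clf.save_pretrained(saved_path)
    # Loads the bert.* weights (encoder + pooler); the classifier.* weights are dropped
    model_eval = BertModel.from_pretrained(saved_path).to(device).eval()

features, reference, labels = [], [], []
with torch.no_grad():
    for batch in dataloader:
        batch = {k: v.to(device) for k, v in batch.items()}
        inputs = {k: batch[k] for k in ("input_ids", "token_type_ids", "attention_mask")}
        features.append(model_eval(**inputs).pooler_output.cpu())
        # The pooled output the classification head consumes, for comparison
        reference.append(clf.bert(**inputs).pooler_output.cpu())
        labels.append(batch["labels"].cpu())

features = torch.cat(features)    # (num_examples, hidden_size)
reference = torch.cat(reference)
labels = torch.cat(labels)        # (num_examples,)
print(features.shape, torch.allclose(features, reference, atol=1e-6))
```

Appending each batch and concatenating once at the end keeps every example's feature instead of overwriting pooled_out on each iteration, and keeping the labels alongside makes later analysis (clustering, probing, visualisation) straightforward. The comparison also shows that, when nothing needs to be saved to disk, calling the fine-tuned model's `bert` submodule yields the same pooler_output.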