from transformers.models.bert.modeling_bert import BertOnlyMLMHead
from transformers import BertConfig, BertPreTrainedModel, BertModel
class BertLMHeadModel(BertPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.bert = BertModel(config, add_pooling_layer=False)
self.cls = BertOnlyMLMHead(config)
# Initialize weights and apply final processing
self.init_weights()
def forward(self):
pass
model = BertLMHeadModel.from_pretrained("bert-base-cased")
具体修改方法可以参考transformers.models.bert.modeling_bert文件,版本不一样会有部分改动。