# Required import: from torch import nn [as alias]
# Or: from torch.nn import TransformerEncoderLayer [as alias]
def __init__(self, bert_config):
    """
    :param bert_config: configuration for bert model
    """
    super(BertABSATagger, self).__init__(bert_config)
    self.num_labels = bert_config.num_labels
    self.tagger_config = TaggerConfig()
    self.tagger_config.absa_type = bert_config.absa_type.lower()
    if bert_config.tfm_mode == 'finetune':
        # initialize with pre-trained BERT weights and fine-tune them
        # print("Fine-tuning the pre-trained BERT...")
        self.bert = BertModel(bert_config)
    else:
        raise Exception("Invalid transformer mode %s!!!" % bert_config.tfm_mode)
    self.bert_dropout = nn.Dropout(bert_config.hidden_dropout_prob)
    # fix the parameters in BERT and regard it as a feature extractor
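
The constructor above sets `self.tagger_config.absa_type`; when that type selects a transformer tagger, an `nn.TransformerEncoderLayer` is typically stacked on top of the (optionally frozen) BERT features as the tagging head. The following is a minimal, self-contained sketch of that idea, assuming illustrative values for `hidden_size`, `nhead`, `dim_feedforward`, `dropout`, and the `fix_tfm` flag; these are not the repository's exact settings.

import torch
from torch import nn

# Illustrative sizes only; the real values come from bert_config / TaggerConfig.
hidden_size = 768      # assumed BERT hidden size
fix_tfm = True         # assumed flag: freeze BERT and use it as a feature extractor

bert = nn.Linear(hidden_size, hidden_size)   # stand-in for the BertModel instance above
if fix_tfm:
    # freezing the encoder parameters turns BERT into a fixed feature extractor
    for p in bert.parameters():
        p.requires_grad = False

# A single TransformerEncoderLayer acting as the tagging head over BERT features.
tagger = nn.TransformerEncoderLayer(d_model=hidden_size,
                                    nhead=12,
                                    dim_feedforward=4 * hidden_size,
                                    dropout=0.1)

# By default the layer expects inputs of shape (seq_len, batch, d_model).
features = torch.randn(32, 4, hidden_size)
print(tagger(features).shape)   # torch.Size([32, 4, 768])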