模型预测时使用单例模式避免重复加载
def singleton(cls):
# 单下划线的作用是这个变量只能在当前模块里访问,仅仅是一种提示作用
# 创建一个字典用来保存类的实例对象
_instance = {}
def _singleton(*args, **kwargs):
# 先判断这个类有没有对象
if cls not in _instance:
_instance[cls] = cls(*args, **kwargs) # 创建一个对象,并保存到字典当中
# 将实例对象返回
return _instance[cls]
return _singleton
@singleton
class LoadModel:
def __new__(cls, *args, **kwargs):
config = ConfigParser()
config.read('../config/config.ini')
config_path = config['DEFAULT']
model = BertForSequenceClassification.from_pretrained(
'../bert_pretrain/', # 使用 12-layer 的 BERT 模型.
config=os.path.join(config_path['bert_pretrain_path'], 'bert_config.json'),
num_labels=5, # 多分类任务的输出标签为 5个.
output_attentions=False, # 不返回 attentions weights.
output_hidden_states=False, # 不返回 all hidden-states.
)
model.load_state_dict(torch.load(os.path.join(config_path['classifier_model_path'], 'pytorch_model.bin')))
model.cuda()
model.eval()
return config_path, model