# -*- coding: utf-8 -*-
from importlib import import_module
import torch
from utils import build_iterator
# Special tokens used by the BERT tokenizer.
PAD = '[PAD]'  # padding token
CLS = '[CLS]'  # classification token prepended to every input sequence
# Fixed sequence length: every sample is padded/truncated to this many tokens.
pad_size = 32
class TestNews:
    """Load a trained BERT text classifier and preprocess single samples for inference.

    On construction this imports the project model module, builds the model on the
    configured device, restores the saved state dict, and switches to eval mode.
    """

    def __init__(self):
        # Raw string: the original relied on '\p', '\e', '\T' not being escape
        # sequences; a raw literal yields the identical path without that hazard.
        dataset = r'D:\pythonWorkSpace\bert\example\Bert-Chinese-Text-Classification-Pytorch-master\THUCNews'  # dataset root
        model_name = "bert"
        module = import_module('models.' + model_name)
        self.config = module.Config(dataset)
        self.model = module.Model(self.config).to(self.config.device)
        # The checkpoint stores a state dict (not a pickled model);
        # strict=False tolerates missing/unexpected keys.
        self.model.load_state_dict(
            torch.load(self.config.save_path, map_location=torch.device('cpu')),
            strict=False,
        )
        self.model.eval()  # inference mode: disables dropout etc.

    def pre_process_text(self, text=None):
        """Tokenize one labeled sample line and return a single-item batch iterator.

        Args:
            text: a "content<TAB>label" line. When None (backward compatible with
                the original behavior), falls back to the module-level ``test_data``.

        Returns:
            An iterator from ``build_iterator`` yielding one (inputs, label) batch.
        """
        if text is None:
            text = test_data  # original behavior: read the module-level global
        line = text.strip()
        # rsplit from the right guards against tabs inside the content itself;
        # fall back to the last whitespace when the sample is not tab-separated
        # (the sample at module level appears to use a plain space).
        parts = line.rsplit('\t', 1)
        if len(parts) != 2:
            parts = line.rsplit(maxsplit=1)
        content, label = parts
        token = [CLS] + self.config.tokenizer.tokenize(content)
        seq_len = len(token)
        token_ids = self.config.tokenizer.convert_tokens_to_ids(token)
        mask = []
        if pad_size:
            if len(token) < pad_size:
                # 1s over real tokens, 0s over padding; pad ids with 0.
                mask = [1] * len(token_ids) + [0] * (pad_size - len(token))
                token_ids += [0] * (pad_size - len(token))
            else:
                mask = [1] * pad_size
                token_ids = token_ids[:pad_size]
                seq_len = pad_size
        sample = (token_ids, int(label), seq_len, mask)
        return build_iterator([sample], self.config)
if __name__ == '__main__':
    test_news_object = TestNews()
    # One labeled sample: "content<separator>label".
    # NOTE(review): the label looks space-separated here; pre_process_text
    # splits on tab first, then last whitespace — confirm against the data format.
    test_data = "观众挑刺新《鹿鼎记》 张纪中回应不碍事(图) 9"
    model = test_news_object.model
    test_data_iterator = test_news_object.pre_process_text()
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        for test_item, labels in test_data_iterator:
            outputs = model(test_item)  # model prediction
            predic = torch.max(outputs.data, 1)[1].cpu().numpy()
            print(predic)