问题:
这个 git 项目只提供了训练函数，没有提供测试（预测）函数，所以我自己写了一个，并根据自己的需求修改了数据的输入方式。
解决方案:
以 Excel 表格的形式批量输入数据（输入方式可以根据自己的需要自行修改）。
文本内容由正文字段（CONTENT_TEXT）和标题字段（TITLE）依次拼接而成。
预测结果直接用 print 输出；如果想保存到 Excel 表格，可以先转换成 DataFrame，再保存到需要的位置。
如有其他需要，可以自行修改。
# coding: UTF-8
import torch
import pandas as pd
import tqdm
import time
from utils import build_iterator, get_time_dif
from models.bert import Model
from pytorch_pretrained import BertTokenizer
# Configuration class
class Config(object):
    """Paths and hyper-parameters used to run the fine-tuned BERT classifier.

    Args:
        dataset: Directory that contains ``sample_data.xlsx`` (the rows to
            predict on) and ``saved_dict/<model_name>.ckpt`` (trained weights).
        all_class: List of class names; entry ``i`` is the name reported for
            model output index ``i``.
    """

    def __init__(self, dataset, all_class):
        self.model_name = 'top30bert'
        self.data_path = dataset + '/sample_data.xlsx'  # data to predict on
        self.class_list = all_class  # class names
        self.save_path = dataset + '/saved_dict/' + self.model_name + '.ckpt'  # trained model weights
        # Fall back to the CPU so the script also runs on machines without a GPU.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.require_improvement = 1000  # training: stop early if no improvement after 1000 batches
        self.num_classes = len(self.class_list)  # number of classes
        self.num_epochs = 3  # number of epochs (training)
        self.batch_size = 50  # mini-batch size (originally 128)
        self.pad_size = 32  # every text is padded / truncated to this many tokens
        self.learning_rate = 5e-5  # learning rate (training)
        self.bert_path = './bert_pretrain'
        self.tokenizer = BertTokenizer.from_pretrained(self.bert_path)
        self.hidden_size = 768
# Data-loading helpers
PAD, CLS = '[PAD]', '[CLS]'  # padding token; BERT's [CLS] token prepended to every sequence


def _cell_text(value):
    """Return an Excel cell as text, mapping empty (NaN) cells to ``''``."""
    # pandas reads empty cells as NaN; str(NaN) would feed the word "nan" to the model.
    return '' if pd.isna(value) else str(value)


def load_dataset(path, pad_size=32, tokenizer=None):
    """Read the prediction spreadsheet and encode every row for the BERT model.

    For each row, the ``CONTENT_TEXT`` and ``TITLE`` cells are concatenated
    (body first, then title), tokenized, prefixed with ``[CLS]`` and
    padded / truncated to ``pad_size`` tokens.

    Args:
        path: Excel file to read with ``pandas.read_excel``.
        pad_size: Target sequence length. A falsy value disables padding and
            truncation, and the mask is then left empty.
        tokenizer: Object providing ``tokenize`` and ``convert_tokens_to_ids``.
            Defaults to the module-level ``config.tokenizer``.

    Returns:
        A list of ``(token_ids, label, seq_len, mask)`` tuples. ``label`` is
        always the placeholder ``0`` because the prediction data is unlabelled.
    """
    if tokenizer is None:
        tokenizer = config.tokenizer  # module-level Config created by this script
    contents = []
    print('data_path:', path)
    data = pd.read_excel(path)
    for _, row in tqdm.tqdm(data.iterrows(), total=len(data)):
        # NOTE(review): the body comes before the title, so with a small pad_size the
        # title is usually truncated away — confirm this matches the training input.
        content = _cell_text(row['CONTENT_TEXT']) + _cell_text(row['TITLE'])
        tokens = [CLS] + tokenizer.tokenize(content)
        token_ids = tokenizer.convert_tokens_to_ids(tokens)
        seq_len = len(tokens)
        mask = []
        if pad_size:
            if seq_len < pad_size:
                # Pad with id 0 and mark the padded positions with 0 in the mask.
                mask = [1] * len(token_ids) + [0] * (pad_size - seq_len)
                token_ids += [0] * (pad_size - seq_len)
            else:
                mask = [1] * pad_size
                token_ids = token_ids[:pad_size]
                seq_len = pad_size
        contents.append((token_ids, 0, seq_len, mask))  # 0 = dummy label
    return contents
# Load the class list and build the configuration
dataset = 'data'
# NOTE(review): pd.read_excel treats the first row as a header by default, so
# class.xlsx must have a header row above the class names — otherwise the first
# class is dropped and every predicted label shifts by one. Confirm the file layout.
classfile = pd.read_excel(dataset + '/class.xlsx')
all_class = classfile.iloc[:, 0].tolist()  # all class names, indexed by model output id
config = Config(dataset, all_class)
# ---- Read and encode the rows to classify ----
print("Loading data...")
start_time = time.time()
data = load_dataset(config.data_path, pad_size=config.pad_size)  # list of encoded rows
print("data_len", len(data))
# Batch the encoded rows (build_iterator comes from the project's utils module).
data_iter = build_iterator(data, config)
time_dif = get_time_dif(start_time)
print("Loading data Time usage:", time_dif)
# Build the model and load the fine-tuned weights
model = Model(config).to(config.device)
# map_location lets a checkpoint that was saved on a GPU be loaded onto config.device,
# including the CPU; without it torch.load fails on machines without CUDA.
model.load_state_dict(torch.load(config.save_path, map_location=config.device))
model.eval()  # evaluation mode (changes the behaviour of layers such as dropout)
# ---- Predict a class name for every row, batch by batch ----
predictions = []
with torch.no_grad():  # inference only: no gradient bookkeeping
    for batch_inputs, _ in data_iter:
        logits = model(batch_inputs)
        _, best_ids = torch.max(logits, 1)  # index of the highest score per row
        predictions.extend(all_class[k] for k in best_ids.cpu().numpy())
print(len(predictions))
for label_name in predictions:
    print(label_name)