Bert-Chinese-Text-Classification-Pytorch项目测试函数编写

 问题:

在这个git项目给了训练函数但是未提供测试函数,然后我根据自己的需求改了一下输入的模式

解决方案:

excle表格的形式输入大批量数据(可以根据自己的需要进行更改输入方式)

文本内容是由标题字段和正文字段组成的

输出是直接print,如果想要存入excle表格可以转成dataframe格式然后自己存到需要的位置

如果有需要可以自行更改

# coding: UTF-8
import torch
import pandas as pd
import tqdm
import time
from utils import build_iterator, get_time_dif
from models.bert import Model
from pytorch_pretrained import BertTokenizer



# 配置类
class Config(object):

    """配置参数"""
    def __init__(self, dataset, all_class):
        self.model_name = 'top30bert'
        self.data_path = dataset + '/sample_data.xlsx'                                   # 预测集
        self.class_list = all_class                                                      # 类别名单
        self.save_path = dataset + '/saved_dict/' + self.model_name + '.ckpt'            # 模型训练结果
        self.device = torch.device('cuda')                                               # if torch.cuda.is_available() else 'cpu'
        self.require_improvement = 1000                                 # 若超过1000batch效果还没提升,则提前结束训练
        self.num_classes = len(self.class_list)                         # 类别数
        self.num_epochs = 3                                             # epoch数
        self.batch_size = 50                                           # mini-batch大小 原始128
        self.pad_size = 32                                              # 每句话处理成的长度(短填长切)
        self.learning_rate = 5e-5                                       # 学习率
        self.bert_path = './bert_pretrain'
        self.tokenizer = BertTokenizer.from_pretrained(self.bert_path)
        self.hidden_size = 768

# 加载数据函数
PAD, CLS = '[PAD]', '[CLS]'  # padding符号, bert中综合信息符号


def load_dataset(path, pad_size=32):
    contents = []
    print('data_path:', path)
    data = pd.read_excel(path)
    for i, row in tqdm.tqdm(data.iterrows()):
        # label = class2num[row['FOUR_TYPE_NAME']]
        content = str(row['CONTENT_TEXT']) + str(row['TITLE'])

        # content, label = lin.split('\t')  # 从tab分开出内容和标签
        token = config.tokenizer.tokenize(content)
        token = [CLS] + token
        seq_len = len(token)
        mask = []
        token_ids = config.tokenizer.convert_tokens_to_ids(token)

        if pad_size:
            if len(token) < pad_size:
                mask = [1] * len(token_ids) + [0] * (pad_size - len(token))
                token_ids += ([0] * (pad_size - len(token)))
            else:
                mask = [1] * pad_size
                token_ids = token_ids[:pad_size]
                seq_len = pad_size
        contents.append((token_ids, 0, seq_len, mask))  # contents.append((token_ids, int(label), seq_len, mask))
    return contents

# 加载配置
classfile = pd.read_excel('data/class.xlsx')
all_class = [value for i, value in enumerate(classfile.iloc[:,0].tolist())]  # 所有类别名称列表
dataset = 'data'
config = Config(dataset, all_class)


# 加载数据,预处理
print("Loading data...")
start_time = time.time()
data = load_dataset(config.data_path, config.pad_size)  # list
# print("data",data)
print("data_len", len(data))
data_iter = build_iterator(data, config)  # utils.DatasetIterater
# print("data_iter",type(data_iter))
time_dif = get_time_dif(start_time)
print("Loading data Time usage:", time_dif)

# 创建模型
model = Model(config).to(config.device)
model.load_state_dict(torch.load(config.save_path))
model.eval()

# 开始预测
predictions = []
with torch.no_grad():
    for texts, _ in data_iter:
        # print('texts:', texts)  # text为tensor
        outputs = model(texts)
        # print(outputs.size())
        predic = torch.max(outputs, 1)[1].cpu().numpy()
        predicted_classes = [all_class[idx] for idx in predic]
        predictions.extend(predicted_classes)  # 将预测结果添加到predictions列表中

print(len(predictions))
for i in predictions:
    print(i)

  • 2
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值