PyTorch 文本分类模型验证

# -*- coding: utf-8 -*-
from importlib import import_module

import torch

from utils import build_iterator

# Special tokens used when building model input. CLS is prepended to every
# tokenized sequence in TestNews.pre_process_text; PAD is not referenced in
# the code shown here.
PAD = '[PAD]'
CLS = '[CLS]'
# Target sequence length: shorter inputs are zero-padded up to it, longer
# inputs are truncated to it (see TestNews.pre_process_text).
pad_size = 32

class TestNews():
    def __init__(self):
        dataset = 'D:\pythonWorkSpace\\bert\example\Bert-Chinese-Text-Classification-Pytorch-master\THUCNews'  # 数据集
        model_name = "bert"  # bert
        x = import_module('models.' + model_name)
        self.config = x.Config(dataset)
        self.model = x.Model(self.config).to(self.config.device)
        # self.model = torch.load(self.config.save_path, map_location=torch.device('cpu'))
        self.model.load_state_dict(torch.load(self.config.save_path, map_location=torch.device('cpu')), False) #此处确保模型训练时保存的是状态矩阵
        self.model.eval()

    def pre_process_text(self):
        lin = test_data.strip()
        content, label = lin.split('\t')
        token = self.config.tokenizer.tokenize(content)
        token = [CLS] + token
        seq_len = len(token)
        mask = []
        token_ids = self.config.tokenizer.convert_tokens_to_ids(token)
        if pad_size:
            if len(token) < pad_size:
                mask = [1] * len(token_ids) + [0] * (pad_size - len(token))
                token_ids += ([0] * (pad_size - len(token)))
            else:
                mask = [1] * pad_size
                token_ids = token_ids[:pad_size]
                seq_len = pad_size
        data001 = (token_ids, int(label), seq_len, mask)
        datalist = []
        datalist.append(data001)
        test_iterator = build_iterator(datalist, self.config)
        return test_iterator

test_news_object = TestNews()
test_data = "观众挑刺新《鹿鼎记》 张纪中回应不碍事(图)    9"
model= test_news_object.model
test_data_iterator = test_news_object.pre_process_text()
for test_item, labels in test_data_iterator:
    outputs = model(test_item)#模型预测
    predic = torch.max(outputs.data, 1)[1].cpu().numpy()
    print(predic)

 

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值