27. Chinese Text Classification with TextCNN (Part 3)

1. Model Training and Evaluation

1.1 Training Procedure

  • Switch the model to training mode, create the optimizer (Adam), and initialize the hyperparameters, which are read from a config object (see the sketch below).
  • Iterate over the training data in batches; every 100 batches, report the loss and accuracy on the training and validation sets.
  • Save the model: if the current validation loss is lower than the best validation loss seen so far, save the weights of this run.
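The training and evaluation code below reads every hyperparameter from a config object (imported as Config from TextCNN in run.py, which is not shown in this part). A minimal sketch of the fields that code assumes, with purely illustrative values, could look like this:

import torch

# Sketch of the Config fields used by train()/evaluate(); the real definition
# lives in TextCNN.py. num_epochs and require_improvement match the values
# visible in the code and output below; the other values are placeholders.
class Config(object):
    def __init__(self):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.learning_rate = 1e-3          # Adam learning rate (placeholder)
        self.num_epochs = 5                # number of training epochs
        self.require_improvement = 1000    # stop if the dev loss stalls for this many batches
        self.save_path = './textcnn.ckpt'  # where the best checkpoint is written (placeholder)
        self.n_vocab = 0                   # filled in after the vocabulary is built
        self.class_list = []               # category names, used by the test report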

1.2 Model Evaluation

  • No gradients need to be computed during evaluation, so wrap the loop over the validation set in with torch.no_grad() (illustrated below).
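A minimal illustration of what torch.no_grad() does: tensors produced inside the block do not track gradients, which saves memory and computation during evaluation.

import torch

x = torch.randn(3, requires_grad=True)
with torch.no_grad():
    y = x * 2           # no computation graph is built inside the block
print(y.requires_grad)  # False: y will never receive gradients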

1.3 Code

  • Step 1: TextCNN model training, train_eval.py
# coding: UTF-8
import numpy as np
import torch
import torch.nn.functional as F
from sklearn import metrics

# Training
def train(config, model, train_iter, dev_iter):
    print("begin")
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)

    total_batch = 0  # number of batches processed so far
    dev_best_loss = float('inf')
    last_improve = 0  # batch at which the dev loss last improved
    flag = False  # whether training has stalled for too long
    for epoch in range(config.num_epochs):
        print('Epoch [{}/{}]'.format(epoch + 1, config.num_epochs))
        # iterate over the training data in mini-batches
        for i, (trains, labels) in enumerate(train_iter):
            outputs = model(trains)
            model.zero_grad()
            loss = F.cross_entropy(outputs, labels)
            loss.backward()
            optimizer.step()
            if total_batch % 100 == 0:
                # every 100 batches, report performance on the training and validation sets
                true = labels.data.cpu()
                predict = torch.max(outputs.data, 1)[1].cpu()
                train_acc = metrics.accuracy_score(true, predict)
                dev_acc, dev_loss = evaluate(config, model, dev_iter)
                if dev_loss < dev_best_loss:
                    dev_best_loss = dev_loss
                    torch.save(model.state_dict(), config.save_path)
                    improve = '*'
                    last_improve = total_batch
                else:
                    improve = ''
                msg = 'Iter: {0:>6},  Train Loss: {1:>5.2},  Train Acc: {2:>6.2%}, ' \
                      ' Val Loss: {3:>5.2},  Val Acc: {4:>6.2%} {5}'
                print(msg.format(total_batch, loss.item(), train_acc, dev_loss, dev_acc, improve))
                model.train()
            total_batch += 1
            if total_batch - last_improve > config.require_improvement:
                # the dev loss has not dropped for require_improvement (1000) batches: stop training early
                print("No optimization for a long time, auto-stopping...")
                flag = True
                break
        if flag:
            break
  • Step 2: model evaluation, evaluate() in train_eval.py
# Evaluation
def evaluate(config, model, data_iter, test=False):
    model.eval()
    loss_total = 0
    predict_all = np.array([], dtype=int)
    labels_all = np.array([], dtype=int)
    with torch.no_grad():
        for texts, labels in data_iter:
            outputs = model(texts)
            loss = F.cross_entropy(outputs, labels)
            loss_total += loss
            labels = labels.data.cpu().numpy()
            predict = torch.max(outputs.data, 1)[1].cpu().numpy()
            labels_all = np.append(labels_all, labels)
            predict_all = np.append(predict_all, predict)

    acc = metrics.accuracy_score(labels_all, predict_all)
    if test:
        report = metrics.classification_report(labels_all, predict_all, target_names=config.class_list, digits=4)
        confusion = metrics.confusion_matrix(labels_all, predict_all)
        return acc, loss_total / len(data_iter), report, confusion
    return acc, loss_total / len(data_iter)
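evaluate() already supports a test mode (test=True) that additionally returns a per-class classification report and a confusion matrix, but nothing in this part calls it. A minimal helper, assuming a test_iter built the same way as train_iter and dev_iter, might look like this (hypothetical, not part of the original train_eval.py):

# Hypothetical test helper: reload the best checkpoint written by train()
# and print accuracy, the per-class report and the confusion matrix.
def test(config, model, test_iter):
    model.load_state_dict(torch.load(config.save_path))
    model.eval()
    test_acc, test_loss, report, confusion = evaluate(config, model, test_iter, test=True)
    print('Test Loss: {0:>5.2},  Test Acc: {1:>6.2%}'.format(test_loss, test_acc))
    print(report)
    print(confusion)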
  • Step 3: main program, run.py
# coding:utf-8

from TextCNN import Config
from TextCNN import Model
from load_data import build_dataset
from load_data_iter import build_iterator
from train_eval import train

if __name__ == "__main__":
    config = Config()
    print("Loading data...")
    vocab, train_data, dev_data, test_data = build_dataset(config, False)
    # 1. build batch iterators over the training and validation data
    train_iter = build_iterator(train_data, config, False)
    dev_iter = build_iterator(dev_data, config, False)

    config.n_vocab = len(vocab)
    # 2. build the model
    model = Model(config).to(config.device)
    print(model.parameters)

    # init_network(model)
    print(model.parameters)
    train(config, model, train_iter, dev_iter)
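build_dataset also returns test_data, which the main program above never uses. A possible extension (hypothetical, relying on the test() helper sketched after evaluate) would append the following lines to the main block:

    # Hypothetical extension: evaluate the trained model on the test split
    from train_eval import test
    test_iter = build_iterator(test_data, config, False)
    test(config, model, test_iter)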

1.4 Results

Run output:

D:\Users\tarena\PycharmProjects\nlp\venv\Scripts\python.exe D:/Users/tarena/PycharmProjects/nlp/unit27/run.py
Loading data...
Vocab size: 4762
180000it [00:02, 75269.58it/s]
10000it [00:00, 51721.76it/s]
10000it [00:00, 65092.25it/s]
<bound method Module.parameters of Model(
  (embedding): Embedding(4762, 300, padding_idx=4761)
  (convs): ModuleList(
    (0): Conv2d(1, 256, kernel_size=(2, 300), stride=(1, 1))
    (1): Conv2d(1, 256, kernel_size=(3, 300), stride=(1, 1))
    (2): Conv2d(1, 256, kernel_size=(4, 300), stride=(1, 1))
  )
  (dropout): Dropout(p=0.5, inplace=False)
  (fc): Linear(in_features=768, out_features=10, bias=True)
)>
<bound method Module.parameters of Model(
  (embedding): Embedding(4762, 300, padding_idx=4761)
  (convs): ModuleList(
    (0): Conv2d(1, 256, kernel_size=(2, 300), stride=(1, 1))
    (1): Conv2d(1, 256, kernel_size=(3, 300), stride=(1, 1))
    (2): Conv2d(1, 256, kernel_size=(4, 300), stride=(1, 1))
  )
  (dropout): Dropout(p=0.5, inplace=False)
  (fc): Linear(in_features=768, out_features=10, bias=True)
)>
begin
Epoch [1/5]
Iter:      0,  Train Loss:   2.5,  Train Acc: 12.50%,  Val Loss:   2.4,  Val Acc: 13.30%