基于TextCNN的中文文本分类(三)
1. 模型的训练和评估
1.1 模型训练过程
- 开启训练模式、设置优化器optimizer、初始化超参数
- 遍历训练数据,进行批量训练,每训练100个batch查看一次训练集和验证集的效果
- 保存模型:若当前验证集的loss小于之前训练中最好的loss,则保存本次训练的模型;若连续超过 require_improvement 个batch验证集loss都没有下降,则提前停止训练
1.2 模型的评估
- 模型评估时不需要计算和更新梯度:评估前先调用 model.eval() 切换到评估模式,遍历验证集的代码需放在 with torch.no_grad(): 上下文中
1.3 代码
- 步骤一:TextCNN模型训练
train_eval.py
import numpy as np
import torch
import torch.nn.functional as F
from sklearn import metrics
def train(config, model, train_iter, dev_iter):
    """Train ``model`` with Adam, validating every 100 batches.

    Every 100 batches (starting with the first) the accuracy on the current
    training batch and the validation loss/accuracy from ``evaluate`` are
    printed. Whenever the validation loss beats the best value seen so far,
    ``model.state_dict()`` is saved to ``config.save_path`` and the log line
    is marked with ``*``. Training stops early once more than
    ``config.require_improvement`` batches have passed since the last
    improvement.

    Args:
        config: object providing ``learning_rate``, ``num_epochs``,
            ``save_path`` and ``require_improvement`` (plus whatever
            ``evaluate`` reads from it).
        model: ``torch.nn.Module`` whose output is fed to ``cross_entropy``
            as class logits.
        train_iter: iterable of ``(inputs, labels)`` training batches;
            iterated once per epoch.
        dev_iter: validation batches, passed unchanged to ``evaluate``.
    """
    print("begin")
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)
    total_batch = 0  # batches processed across all epochs
    dev_best_loss = float('inf')
    last_improve = 0  # value of total_batch at the last validation improvement
    for epoch in range(config.num_epochs):
        print('Epoch [{}/{}]'.format(epoch + 1, config.num_epochs))
        for trains, labels in train_iter:
            outputs = model(trains)
            model.zero_grad()
            loss = F.cross_entropy(outputs, labels)
            loss.backward()
            optimizer.step()
            if total_batch % 100 == 0:
                true = labels.data.cpu()
                predict = torch.max(outputs.data, 1)[1].cpu()
                train_acc = metrics.accuracy_score(true, predict)
                dev_acc, dev_loss = evaluate(config, model, dev_iter)
                if dev_loss < dev_best_loss:
                    dev_best_loss = dev_loss
                    torch.save(model.state_dict(), config.save_path)
                    improve = '*'
                    last_improve = total_batch
                else:
                    improve = ''
                # {5} shows the improvement marker; without it the sixth
                # format argument was silently ignored.
                msg = ('Iter: {0:>6}, Train Loss: {1:>5.2}, Train Acc: {2:>6.2%}, '
                       ' Val Loss: {3:>5.2}, Val Acc: {4:>6.2%} {5}')
                print(msg.format(total_batch, loss.item(), train_acc, dev_loss, dev_acc, improve))
                # evaluate() switches the model to eval mode; switch back.
                model.train()
            total_batch += 1
            if total_batch - last_improve > config.require_improvement:
                print("No optimization for a long time, auto-stopping...")
                # Returning ends both loops (replaces the flag + double break).
                return
def evaluate(config, model, data_iter, test=False):
    """Evaluate ``model`` on ``data_iter`` without tracking gradients.

    The model is left in eval mode; callers that continue training must call
    ``model.train()`` themselves.

    Args:
        config: object providing ``class_list`` (read only when ``test``).
        model: ``torch.nn.Module`` whose output is used as class logits.
        data_iter: iterable of ``(inputs, labels)`` batches.
        test: when True, also build a classification report and a
            confusion matrix.

    Returns:
        ``(acc, mean_loss)``, or ``(acc, mean_loss, report, confusion)`` when
        ``test`` is True. ``mean_loss`` is the mean of the per-batch
        cross-entropy losses, as a Python float.

    Raises:
        ValueError: if ``data_iter`` yields no batches.
    """
    model.eval()
    loss_total = 0.0
    n_batches = 0
    label_chunks = []
    predict_chunks = []
    with torch.no_grad():
        for texts, labels in data_iter:
            outputs = model(texts)
            loss = F.cross_entropy(outputs, labels)
            # .item() keeps the running total a Python float instead of a
            # device tensor, so the returned loss is a plain number.
            loss_total += loss.item()
            n_batches += 1
            label_chunks.append(labels.cpu().numpy())
            predict_chunks.append(torch.max(outputs, 1)[1].cpu().numpy())
    if n_batches == 0:
        raise ValueError("evaluate() received no batches from data_iter")
    # Concatenate once; np.append per batch copies the whole array each time.
    labels_all = np.concatenate(label_chunks).astype(int)
    predict_all = np.concatenate(predict_chunks).astype(int)
    acc = metrics.accuracy_score(labels_all, predict_all)
    mean_loss = loss_total / n_batches
    if test:
        report = metrics.classification_report(
            labels_all, predict_all, target_names=config.class_list, digits=4)
        confusion = metrics.confusion_matrix(labels_all, predict_all)
        return acc, mean_loss, report, confusion
    return acc, mean_loss
from TextCNN import Config
from TextCNN import Model
from load_data import build_dataset
from load_data_iter import build_iterator
from train_eval import train
if __name__ == "__main__":
    config = Config()
    print("Loading data...")
    # NOTE(review): the meaning of the boolean flag passed to build_dataset /
    # build_iterator is defined in load_data / load_data_iter - confirm there.
    vocab, train_data, dev_data, test_data = build_dataset(config, False)  # test_data unused here
    train_iter = build_iterator(train_data, config, False)
    dev_iter = build_iterator(dev_data, config, False)
    # Set before constructing Model, presumably so it can size its embedding
    # table from the vocabulary - verify in TextCNN.Model.
    config.n_vocab = len(vocab)
    model = Model(config).to(config.device)
    # Print the module structure once (previously a bound-method repr, twice).
    print(model)
    train(config, model, train_iter, dev_iter)
1.5 运行结果
运行结果:
D:\Users\tarena\PycharmProjects\nlp\venv\Scripts\python.exe D:/Users/tarena/PycharmProjects/nlp/unit27/run.py
Loading data...
Vocab size: 4762
180000it [00:02, 75269.58it/s]
10000it [00:00, 51721.76it/s]
10000it [00:00, 65092.25it/s]
<bound method Module.parameters of Model(
(embedding): Embedding(4762, 300, padding_idx=4761)
(convs): ModuleList(
(0): Conv2d(1, 256, kernel_size=(2, 300), stride=(1, 1))
(1): Conv2d(1, 256, kernel_size=(3, 300), stride=(1, 1))
(2): Conv2d(1, 256, kernel_size=(4, 300), stride=(1, 1))
)
(dropout): Dropout(p=0.5, inplace=False)
(fc): Linear(in_features=768, out_features=10, bias=True)
)>
<bound method Module.parameters of Model(
(embedding): Embedding(4762, 300, padding_idx=4761)
(convs): ModuleList(
(0): Conv2d(1, 256, kernel_size=(2, 300), stride=(1, 1))
(1): Conv2d(1, 256, kernel_size=(3, 300), stride=(1, 1))
(2): Conv2d(1, 256, kernel_size=(4, 300), stride=(1, 1))
)
(dropout): Dropout(p=0.5, inplace=False)
(fc): Linear(in_features=768, out_features=10, bias=True)
)>
begin
Epoch [1/5]
Iter: 0, Train Loss: 2.5, Train Acc: 12.50%, Val Loss: 2.4, Val Acc: 13.30%