N2 - Chinese Text Classification

1. Preparation

1.1 Environment

PyTorch 2.0.0 and a matching torchtext build (the torchtext 0.15.x releases pair with PyTorch 2.0.0)

1.2 Training data

train.csv, a tab-separated file with no header row: column 0 holds the text, column 1 holds the intent label

1.3 Loading the data

import torch
import torch.nn as nn

# Use the GPU if one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

import pandas as pd

train_data = pd.read_csv('./data/train.csv', sep='\t', header=None)
train_data.head()

# Build an iterator that yields (text, label) pairs
def custom_data_iter(texts, labels):
    for x, y in zip(texts, labels):
        yield x, y

train_iter = custom_data_iter(train_data[0].values[:], train_data[1].values[:])

2. Data preprocessing

2.1 Building the vocabulary

The jieba tokenizer needs to be installed first (pip install jieba).

from torchtext.vocab import build_vocab_from_iterator
import jieba

# Use jieba's lcut as the tokenizer (returns a list of tokens)
tokenizer = jieba.lcut

def yield_tokens(data_iter):
    for text, _ in data_iter:
        yield tokenizer(text)

# Note: building the vocabulary consumes the train_iter generator
vocab = build_vocab_from_iterator(yield_tokens(train_iter), specials=["<unk>"])
vocab.set_default_index(vocab["<unk>"])  # map out-of-vocabulary tokens to <unk>
vocab(['我','想','看','和平','精英','上','战神','必备','技巧','的','游戏','视频'])

[2, 10, 13, 973, 1079, 146, 7724, 7574, 7793, 1, 186, 28]
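
As a quick sanity check (a minimal sketch, not from the original post): because specials=["<unk>"] is registered and set as the default index, any token produced by jieba that was never seen while building the vocabulary is mapped to the <unk> index instead of raising an error.

# Minimal sketch: out-of-vocabulary tokens fall back to the <unk> index.
print(vocab["<unk>"])                               # the index reserved for unknown tokens (typically 0)
print(vocab(tokenizer("这句话里可能有词表之外的词")))  # unseen tokens map to that index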

# Collect the distinct labels (note: set ordering is not deterministic,
# so the index assigned to each label can vary between runs)
label_name = list(set(train_data[1].values[:]))
print(label_name)

['Calendar-Query', 'Video-Play', 'Alarm-Update', 'Other', 'Travel-Query', 'FilmTele-Play', 'Music-Play', 'Weather-Query', 'Radio-Listen', 'HomeAppliance-Control', 'Audio-Play', 'TVProgram-Play']

# Map raw text to a list of vocabulary indices, and a label string to its integer index
text_pipeline  = lambda x: vocab(tokenizer(x))
label_pipeline = lambda x: label_name.index(x)

print(text_pipeline('我想看和平精英上战神必备技巧的游戏视频'))
print(label_pipeline('Video-Play'))

[2, 10, 13, 973, 1079, 146, 7724, 7574, 7793, 1, 186, 28]
1

2.2 Generating data batches and iterators

from torch.utils.data import DataLoader

def collect_batch(batch):
    label_list, text_list, offsets = [], [], [0]
    for (_text, _label) in batch:
        label_list.append(label_pipeline(_label))
        processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)
        text_list.append(processed_text)
        offsets.append(processed_text.size(0))
    label_list = torch.tensor(label_list, dtype=torch.int64)
    text_list  = torch.cat(text_list)                      # flatten every sample into one 1-D tensor
    offsets    = torch.tensor(offsets[:-1]).cumsum(dim=0)  # start position of each sample
    return text_list.to(device), label_list.to(device), offsets.to(device)

# Note: train_iter was already exhausted while building the vocabulary; the
# DataLoaders actually used for training are rebuilt in section 4.1.
dataloader = DataLoader(train_iter, batch_size=64, shuffle=False, collate_fn=collect_batch)
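
To make the offsets mechanics concrete, here is a minimal standalone sketch (throwaway names, not part of the original code): collect_batch concatenates every sample's token ids into one 1-D tensor, and offsets stores where each sample starts inside it, which is exactly the format nn.EmbeddingBag expects.

# Sketch: three samples of lengths 3, 2 and 4 flattened into one tensor.
lengths      = [3, 2, 4]
offsets_demo = torch.tensor([0] + lengths[:-1]).cumsum(dim=0)
print(offsets_demo)  # tensor([0, 3, 5]): sample i occupies text_list[offsets[i] : offsets[i] + lengths[i]]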

3. Building the model

3.1 Defining the model

class TextClassificationModel(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_class):
        super().__init__()
        # EmbeddingBag pools (by default, averages) the embeddings of each sample's tokens
        self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=False)
        self.fc = nn.Linear(embed_dim, num_class)
        self.init_weights()

    def init_weights(self):
        initrange = 0.5
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.fc.weight.data.uniform_(-initrange, initrange)
        self.fc.bias.data.zero_()

    def forward(self, text, offsets):
        embedded = self.embedding(text, offsets)  # (batch_size, embed_dim)
        return self.fc(embedded)                  # (batch_size, num_class)
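
For intuition (a standalone sketch with toy sizes, not part of the original code): nn.EmbeddingBag with its default mode="mean" averages the embeddings of all tokens in each bag, so the model above is a bag-of-embeddings classifier with a single linear layer on top.

# Sketch: EmbeddingBag pools token embeddings per sample.
emb_demo = nn.EmbeddingBag(10, 4, mode="mean")
tokens   = torch.tensor([1, 2, 3, 7, 8])   # two samples flattened: [1, 2, 3] and [7, 8]
offsets  = torch.tensor([0, 3])            # start position of each sample
print(emb_demo(tokens, offsets).shape)     # torch.Size([2, 4]): one pooled vector per sample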

3.2 Initializing the model

num_class  = len(label_name)
vocab_size = len(vocab)
em_size    = 64
model      = TextClassificationModel(vocab_size, em_size, num_class).to(device)

Epoch         = 20
learning_rate = 5
batch_size    = 256
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss().to(device)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.1)
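
For reference (a standalone sketch with dummy names, not part of the training script): StepLR with step_size=1 and gamma=0.1 multiplies the learning rate by 0.1 on every scheduler.step() call; in section 4.2 below, step() is only invoked when validation accuracy stops improving.

# Sketch of the schedule on a throwaway optimizer (PyTorch may warn that
# optimizer.step() was never called; this is only a schedule demo).
dummy_opt   = torch.optim.SGD([torch.zeros(1, requires_grad=True)], lr=5)
dummy_sched = torch.optim.lr_scheduler.StepLR(dummy_opt, 1, gamma=0.1)
dummy_sched.step()
print(dummy_opt.param_groups[0]['lr'])  # 5 -> 0.5 after one step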

3.3 Training and evaluation functions

from tqdm import tqdm

def train(dataloader):
    model.train()  # switch to training mode
    total_acc, train_loss, total_count = 0, 0, 0
    par = tqdm(dataloader)
    for idx, (text, label, offsets) in enumerate(par):

        predicted_label = model(text, offsets)

        optimizer.zero_grad()                     # reset gradients
        loss = criterion(predicted_label, label)  # compare predictions with the true labels
        loss.backward()                           # backpropagate
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)  # gradient clipping
        optimizer.step()                          # update the parameters

        # Track accuracy and loss
        total_acc   += (predicted_label.argmax(1) == label).sum().item()
        train_loss  += loss.item()
        total_count += label.size(0)
        par.set_description('loss: %.3f | acc: %.3f' % (train_loss / (idx + 1), total_acc / total_count))

def evaluate(dataloader):
    model.eval()  # switch to evaluation mode
    total_acc, val_loss, total_count = 0, 0, 0

    with torch.no_grad():
        par = tqdm(dataloader)
        for idx, (text, label, offsets) in enumerate(par):
            predicted_label = model(text, offsets)

            loss = criterion(predicted_label, label)
            # Track evaluation statistics
            total_acc   += (predicted_label.argmax(1) == label).sum().item()
            val_loss    += loss.item()
            total_count += label.size(0)
            par.set_description('loss: %.3f | acc: %.3f' % (val_loss / (idx + 1), total_acc / total_count))
    # Return mean accuracy and per-sample loss (the progress bar shows per-batch loss)
    return total_acc / total_count, val_loss / total_count

torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1) is a PyTorch function that limits the size of the gradients during training. This operation, known as gradient clipping, guards against exploding gradients and improves training stability.
model.parameters() yields all of the model's parameters, typically the weights and bias terms.
0.1 is the threshold on the gradients' total L2 norm: if the computed norm exceeds it, the gradients are rescaled so that their norm equals the threshold.
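
A standalone sketch of the rule (toy parameter, not from the original post): here the gradient's L2 norm is 5.0, well above max_norm=0.1, so clip_grad_norm_ rescales it in place by roughly 0.1 / 5.0.

# Toy example of gradient clipping by total L2 norm.
p = torch.nn.Parameter(torch.ones(3))
p.grad = torch.tensor([3.0, 4.0, 0.0])             # gradient norm = 5.0
torch.nn.utils.clip_grad_norm_([p], max_norm=0.1)  # rescales the gradient in place
print(p.grad)                                      # approximately [0.06, 0.08, 0.0]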

4. Training the model

4.1 Splitting the data

from torch.utils.data.dataset import random_split
from torchtext.data.functional import to_map_style_dataset

# Rebuild the (already exhausted) generator and convert it to a map-style dataset
train_iter    = custom_data_iter(train_data[0].values[:], train_data[1].values[:])
train_dataset = to_map_style_dataset(train_iter)

# 80/20 train/validation split
num_train = int(len(train_dataset) * 0.8)
train_dataset, val_dataset = random_split(train_dataset, [num_train, len(train_dataset) - num_train])

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collect_batch)
val_dataloader   = DataLoader(val_dataset, batch_size=batch_size, shuffle=True, collate_fn=collect_batch)

4.2 Training the model

total_accu = None
for epoch in range(1, Epoch + 1):
    train(train_dataloader)
    val_acc, val_loss = evaluate(val_dataloader)
    lr = optimizer.state_dict()['param_groups'][0]['lr']
    # Decay the learning rate only when validation accuracy stops improving
    if total_accu is not None and total_accu > val_acc:
        scheduler.step()
    else:
        total_accu = val_acc
    print('epoch: %d | lr: %.4f | val loss: %.4f | val acc: %.4f' % (epoch, lr, val_loss, val_acc))

loss: 1.892 | acc: 0.453: 100%|██████████| 38/38 [00:00<00:00, 73.77it/s]
loss: 1.390 | acc: 0.663: 100%|██████████| 10/10 [00:00<00:00, 86.15it/s]
epoch: 1 | lr: 5.0000 | val loss: 0.0057 | val acc: 0.6632
loss: 1.083 | acc: 0.739: 100%|██████████| 38/38 [00:00<00:00, 72.82it/s]
loss: 0.874 | acc: 0.767: 100%|██████████| 10/10 [00:00<00:00, 81.83it/s]
epoch: 2 | lr: 5.0000 | val loss: 0.0036 | val acc: 0.7669
loss: 0.734 | acc: 0.804: 100%|██████████| 38/38 [00:00<00:00, 73.65it/s]
loss: 0.668 | acc: 0.815: 100%|██████████| 10/10 [00:00<00:00, 87.53it/s]
epoch: 3 | lr: 5.0000 | val loss: 0.0028 | val acc: 0.8149
loss: 0.578 | acc: 0.839: 100%|██████████| 38/38 [00:00<00:00, 76.48it/s]
loss: 0.556 | acc: 0.840: 100%|██████████| 10/10 [00:00<00:00, 82.11it/s]
epoch: 4 | lr: 5.0000 | val loss: 0.0023 | val acc: 0.8397
loss: 0.483 | acc: 0.864: 100%|██████████| 38/38 [00:00<00:00, 79.19it/s]
loss: 0.509 | acc: 0.855: 100%|██████████| 10/10 [00:00<00:00, 88.49it/s]
epoch: 5 | lr: 5.0000 | val loss: 0.0021 | val acc: 0.8550
loss: 0.416 | acc: 0.886: 100%|██████████| 38/38 [00:00<00:00, 76.90it/s]
loss: 0.459 | acc: 0.862: 100%|██████████| 10/10 [00:00<00:00, 77.67it/s]
epoch: 6 | lr: 5.0000 | val loss: 0.0019 | val acc: 0.8624
loss: 0.365 | acc: 0.899: 100%|██████████| 38/38 [00:00<00:00, 74.76it/s]
loss: 0.433 | acc: 0.870: 100%|██████████| 10/10 [00:00<00:00, 84.41it/s]
epoch: 7 | lr: 5.0000 | val loss: 0.0018 | val acc: 0.8698
loss: 0.325 | acc: 0.906: 100%|██████████| 38/38 [00:00<00:00, 76.51it/s]
loss: 0.411 | acc: 0.876: 100%|██████████| 10/10 [00:00<00:00, 85.80it/s]
epoch: 8 | lr: 5.0000 | val loss: 0.0017 | val acc: 0.8756
loss: 0.292 | acc: 0.920: 100%|██████████| 38/38 [00:00<00:00, 75.54it/s]
loss: 0.393 | acc: 0.877: 100%|██████████| 10/10 [00:00<00:00, 79.01it/s]
epoch: 9 | lr: 5.0000 | val loss: 0.0016 | val acc: 0.8769
loss: 0.263 | acc: 0.931: 100%|██████████| 38/38 [00:00<00:00, 73.83it/s]
loss: 0.377 | acc: 0.883: 100%|██████████| 10/10 [00:00<00:00, 82.71it/s]
epoch: 10 | lr: 5.0000 | val loss: 0.0016 | val acc: 0.8835
loss: 0.238 | acc: 0.937: 100%|██████████| 38/38 [00:00<00:00, 73.25it/s]
loss: 0.372 | acc: 0.887: 100%|██████████| 10/10 [00:00<00:00, 77.38it/s]
epoch: 11 | lr: 5.0000 | val loss: 0.0015 | val acc: 0.8868
loss: 0.217 | acc: 0.943: 100%|██████████| 38/38 [00:00<00:00, 71.55it/s]
loss: 0.361 | acc: 0.888: 100%|██████████| 10/10 [00:00<00:00, 77.22it/s]
epoch: 12 | lr: 5.0000 | val loss: 0.0015 | val acc: 0.8876
loss: 0.198 | acc: 0.950: 100%|██████████| 38/38 [00:00<00:00, 68.75it/s]
loss: 0.362 | acc: 0.890: 100%|██████████| 10/10 [00:00<00:00, 75.83it/s]
epoch: 13 | lr: 5.0000 | val loss: 0.0015 | val acc: 0.8905
loss: 0.181 | acc: 0.956: 100%|██████████| 38/38 [00:00<00:00, 65.00it/s]
loss: 0.353 | acc: 0.893: 100%|██████████| 10/10 [00:00<00:00, 78.17it/s]
epoch: 14 | lr: 5.0000 | val loss: 0.0015 | val acc: 0.8930
loss: 0.165 | acc: 0.961: 100%|██████████| 38/38 [00:00<00:00, 66.87it/s]
loss: 0.345 | acc: 0.895: 100%|██████████| 10/10 [00:00<00:00, 79.08it/s]
epoch: 15 | lr: 5.0000 | val loss: 0.0014 | val acc: 0.8946
loss: 0.152 | acc: 0.965: 100%|██████████| 38/38 [00:00<00:00, 67.15it/s]
loss: 0.347 | acc: 0.894: 100%|██████████| 10/10 [00:00<00:00, 74.05it/s]
epoch: 16 | lr: 5.0000 | val loss: 0.0014 | val acc: 0.8942
loss: 0.138 | acc: 0.971: 100%|██████████| 38/38 [00:00<00:00, 67.80it/s]
loss: 0.344 | acc: 0.894: 100%|██████████| 10/10 [00:00<00:00, 75.15it/s]
epoch: 17 | lr: 5.0000 | val loss: 0.0014 | val acc: 0.8938
loss: 0.137 | acc: 0.972: 100%|██████████| 38/38 [00:00<00:00, 67.71it/s]
loss: 0.346 | acc: 0.894: 100%|██████████| 10/10 [00:00<00:00, 76.61it/s]
epoch: 18 | lr: 5.0000 | val loss: 0.0014 | val acc: 0.8938
loss: 0.137 | acc: 0.972: 100%|██████████| 38/38 [00:00<00:00, 68.83it/s]
loss: 0.340 | acc: 0.894: 100%|██████████| 10/10 [00:00<00:00, 77.02it/s]
epoch: 19 | lr: 5.0000 | val loss: 0.0014 | val acc: 0.8938
loss: 0.136 | acc: 0.972: 100%|██████████| 38/38 [00:00<00:00, 70.25it/s]
loss: 0.347 | acc: 0.894: 100%|██████████| 10/10 [00:00<00:00, 76.30it/s]
epoch: 20 | lr: 5.0000 | val loss: 0.0014 | val acc: 0.8938

4.3 Testing on new text

def predict(text, text_pipeline):
    with torch.no_grad():
        text = torch.tensor(text_pipeline(text))
        # offsets=[0]: the whole token sequence is treated as a single bag
        output = model(text, torch.tensor([0]))
        return output.argmax(1).item()

ex_text_str = "随便播放一首专辑阁楼里的佛里的歌"
# ex_text_str = "还有双鸭山到淮阴的汽车票吗13号的"

# Move the model to the CPU so it matches the CPU input tensors
model = model.to("cpu")
print("Predicted category: %s" % label_name[predict(ex_text_str, text_pipeline)])

Predicted category: Music-Play
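
The torch.tensor([0]) passed as offsets tells nn.EmbeddingBag to treat the whole token sequence as a single bag, so predict classifies one sentence at a time. A hedged sketch of a batched variant (hypothetical helper, not in the original post), assuming the model has already been moved to the CPU:

# Hypothetical batch-prediction helper reusing the same pipelines.
def predict_batch(texts, text_pipeline):
    with torch.no_grad():
        token_lists = [torch.tensor(text_pipeline(t), dtype=torch.int64) for t in texts]
        offsets = torch.tensor([0] + [len(t) for t in token_lists[:-1]]).cumsum(dim=0)
        output  = model(torch.cat(token_lists), offsets)
        return [label_name[i] for i in output.argmax(1).tolist()]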

5. Summary

This week we used the jieba library to implement Chinese text classification; next, we will learn to build our own vocabulary for text classification.
