6. News Topic Classification Task

Given a passage of text from a news report as input, we use a model to decide which type of news it most likely belongs to. This is a typical text classification problem. We assume here that the categories are mutually exclusive, i.e. each text has exactly one type: a news article cannot be both entertainment and finance, it belongs to only one class.

1. Data Download and Introduction

We use the AG_NEWS dataset, which is already integrated into torchtext. The code for downloading it is given below, after an important note.

Note:

If torchtext is not yet installed, there is a big pitfall when installing it with pip.

During installation, torchtext checks the installed PyTorch version. If the versions are incompatible, it will uninstall your torch and automatically install a compatible build in its place; in practice this ends up being the CPU-only build. The whole process is automatic and gives no obvious warning, or at least a warning most people never read, which makes it a serious trap.

When I first installed torchtext, I could not get the GPU to work at all. I assumed something was wrong with my graphics card, and only after a long struggle did I discover that torch had been swapped for the CPU build. Not knowing this at first, I uninstalled torch and reinstalled the CUDA build, but it made no difference: what ended up installed was still the CPU build of torch (torchtext is really pushy!). I went back and forth several times, and no matter how I installed it I kept getting the CPU build. A huge trap!

How do you find the right torchtext version?

A simple rule of thumb: torchtext's minor version number is one higher than torch's, its major version is 0, and the patch version should match as well. For example:

torch 1.13.1 corresponds to torchtext 0.14.1,
so it should be installed with
pip install torchtext==0.14.1

The rule above applies to torch 1.x; for torch 2.x a similar pattern can be followed.

Thanks to the blog post 《更新 torchtext 造成的torch版本不匹配的问题》 for pointing to the solution.
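
To confirm that the installed pair really matches and that the GPU build of torch survived the install, a quick check like the following can help (a minimal sketch; the version strings in the comments are just the example pair above):

# Check the installed torch/torchtext versions and whether CUDA is still usable
import torch
import torchtext

print(torch.__version__)           # e.g. 1.13.1+cu117; a '+cpu' suffix means the CPU build was installed
print(torchtext.__version__)       # e.g. 0.14.1
print(torch.cuda.is_available())   # should print True on a working GPU install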

# Import torch-related packages
import torch
import torchtext
# Import the text classification datasets from torchtext.datasets
from torchtext.datasets import AG_NEWS
import os

# Define the download path: the Datasets folder under the current directory
load_data_path = './Datasets/'
# Create the directory if it does not exist
if not os.path.exists(load_data_path):
    os.makedirs(load_data_path)

# Select the 'AG_NEWS' text classification dataset (news topic classification) from torchtext,
# save it under the specified directory, and load the raw train and test splits
train_data, test_data = AG_NEWS(
    root=load_data_path, split=('train', 'test'))

# AG_NEWS returns iterators; each element is a (label, text) tuple
for (label, text) in train_data:
    print(f"Label: {label}, Text: {text}")
for (label, text) in test_data:
    print(f"Label: {label}, Text: {text}")

After the download completes, there are two files ending in .csv.

Their contents look like this:

"3","Fears for T N pension after talks","Unions representing workers at Turner   Newall say they are 'disappointed' after talks with stricken parent firm Federal Mogul."
"4","The Race is On: Second Private Team Sets Launch Date for Human Spaceflight (SPACE.com)","SPACE.com - TORONTO, Canada -- A second\team of rocketeers competing for the  #36;10 million Ansari X Prize, a contest for\privately funded suborbital space flight, has officially announced the first\launch date for its manned rocket."

  1. The training set has 120,000 samples and the test set has 7,600 samples.
  2. There are four labels {1,2,3,4}, corresponding to {World, Sports, Business, Sci/Tech}: world news, sports news, business news, and science/technology news.
  3. Each sample has three columns: the first is the label saying which class the news belongs to, the second is the news title, and the third is the news description.
  4. test.csv and train.csv have the same format.
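
For readability, a small helper dict (not part of the original code) can map the numeric labels to topic names:

# Map the CSV's 1-based class indices to readable topic names
label_names = {1: 'World', 2: 'Sports', 3: 'Business', 4: 'Sci/Tech'}
print(label_names[3])   # -> 'Business', matching the first sample row shown above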

2. Building a Dataset Class to Read the Data

After saving the dataset with the code above, create a new Python file and build a Dataset class for reading the data. The code is as follows:

#!------------------------ Step 1: Read the data and build the Dataset class --------------------------------
class AG_NEWS_Data(Dataset):
    def __init__(self, train=True) -> None:
        super().__init__()
        data_path = os.path.join(BASE_PATH, 'train.csv') if train else os.path.join(
            BASE_PATH, 'test.csv')  # set the data path; only the training set is used in this experiment
        self.data = pd.read_csv(data_path, sep=',', header=None)  # read the data
        # print(self.data.head())

        sen_len = []  # sentence length of each sample
        contents_list = []  # cleaned text of every sample
        token_number = 0  # number of distinct tokens across all texts
        label_count = []  # labels of all samples

        # * Compute the length of each sample, collect its label, and gather its content
        for i in range(self.__len__()):
            content, label = self.__getitem__(i)
        # for content, label in data:
            sen_len.append(len(content.split(' ')))  # length of each sample
            label_count.append(label)  # collect the label of each sample
            contents_list.append(content)  # gather the sample content

        # Concatenate all sample contents; joining once avoids slow repeated string concatenation
        self.contents = ' ' + ' '.join(contents_list)

        vocab_dict = {v: idx for idx, v in enumerate(
            set(self.contents.split(' ')))}  # build the vocabulary from the set of all tokens
        token_number = len(vocab_dict)
        sen_len_distribution = {str(i): sen_len.count(i) for i in sorted(
            set(sen_len))}  # distribution of sentence lengths, e.g. {'80': 192, '81': 689, ...}: 192 sentences have length 80, ...
        label_n_distribution = {str(i): label_count.count(i) for i in set(
            label_count)}  # distribution of labels, e.g. {'1': 20000, '2': 20000, ...}: number of samples per label

        self.vocab_dict, self.token_number, self.sen_len_distribution, self.label_n_distribution = vocab_dict, token_number, sen_len_distribution, label_n_distribution

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, index):
        label = int(self.data.iloc[index, 0])  # extract the label and convert it to int
        content = self.data.iloc[index, 1]+' ' + \
            self.data.iloc[index, 2]  # concatenate the sample's title and description
        content = content.lower()  # lowercase all words
        # Use a regex to keep only word characters and digits, replacing everything else with spaces
        content = re.sub(r'[^\w\s]', ' ', content)
        content = re.sub(r'\s+', ' ', content)  # collapse multiple spaces into one

        return content, label
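
As a quick sanity check, the class can be exercised on its own (BASE_PATH is a module-level variable the class expects; the path below is only a placeholder, point it at the folder that actually contains train.csv and test.csv):

# Quick sanity check of the Dataset class
BASE_PATH = './Datasets/datasets/AG_NEWS'  # placeholder path; adjust to wherever train.csv/test.csv live
dataset = AG_NEWS_Data(train=True)
print(len(dataset))           # number of samples (120,000 for the training split)
print(dataset[0])             # (cleaned text, integer label)
print(dataset.token_number)   # vocabulary size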

3. Building the Network Model

Every layer in the network needs its weights initialized. The initial weights are generally drawn from a range smaller than 1; they can be close to zero, but must not be exactly 0, otherwise the model becomes very hard to train (a widely shared rule of thumb).

The network itself is just three simple linear layers.

#! ------------------- Step 2: Build the network model, a text classification model with an Embedding layer -----
class TextSentiment(nn.Module):
    """Text classification model"""

    def __init__(self, vocab_size, embed_dim, num_class):
        """Class initializer

        Args:
            vocab_size (int): total number of distinct tokens in the corpus
            embed_dim (int): dimensionality of the word embeddings
            num_class (int): number of text classes
        """
        super().__init__()
        # Instantiate the Embedding layer; sparse=True means that when gradients are computed
        # for this layer, only the embedding rows actually used in the batch are updated
        self.embedding = nn.Embedding(
            vocab_size, embedding_dim=embed_dim, sparse=True)
        # Instantiate the linear layers: flattened embeddings (LEN_STA*EMBED_DIM) -> 512 -> 64 -> num_class
        self.fc1 = nn.Linear(in_features=LEN_STA*EMBED_DIM, out_features=512)
        self.fc2 = nn.Linear(in_features=512, out_features=64)
        self.fc3 = nn.Linear(in_features=64, out_features=num_class)
        # Initialize the weights of each layer
        self.init_weight()

    def init_weight(self):
        """Initialize the weights.
        """
        # Range for the initial weights
        init_range = 0.5
        # Weights of every layer are initialized from a uniform distribution
        self.embedding.weight.data.uniform_(-init_range, init_range)
        for fc in [self.fc1, self.fc2, self.fc3]:
            fc.weight.data.uniform_(-init_range, init_range)
            # Biases are initialized to 0
            fc.bias.data.zero_()

    def forward(self, text):
        """Forward pass

        Args:
            text (Tensor): the numericalized text (token indices)

        Returns:
            tensor: a tensor with one score per class, used to decide the text's category
        """
        # Get the embedding result; embedded has shape (BATCH_SIZE, LEN_STA, EMBED_DIM),
        # where EMBED_DIM is the word-embedding dimension (32 here)
        # print(text.shape)
        embedded = self.embedding(text)
        # embedded = F.avg_pool1d(embedded, kernel_size=3)
        x = embedded.view(embedded.size(0), -1)
        # print(embedded.shape)
        # print(len_sta*EMBED_DIM)
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)

        return x

4. Serialization and Length Standardization

The input sentences must be normalized to a fixed length and serialized, i.e. converted into integer tensors, before the Embedding operation can be applied. One-hot encoding could be used for the serialization; here, for simplicity, plain integer indices [0, 1, 2, 3, 4, ...] are used instead.

def get_length_standard(rate=0.9):
    """Compute the standardized text length: based on the distribution of sample text lengths (ascending),
    take the length at the rate cut-off point as the standard length.

    Args:
        rate (float, optional): Defaults to 0.9.

    Returns:
        int: the length at the rate cut-off point of the sample length distribution
    """
    value_sum = 0  # running count of samples covered so far
    sample_len = len(AG_NEWS)  # total number of samples in the dataset

    # key = sentence length, value = number of samples with that length
    for key, value in AG_NEWS.sen_len_distribution.items():
        value_sum += int(value)
        if (value_sum/sample_len >= rate):
            return int(key)


def get_sen_ser(sentence, len_sta):
    """Standardize and serialize a sample's content: truncate if too long, pad with 0 if too short.

    Args:
        sentence (str): cleaned sample text
        len_sta (int): target standardized length

    Returns:
        list: token indices of length len_sta
    """
    # Serialize the sentence into token indices
    vocab_list = [AG_NEWS.vocab_dict[v] for v in sentence.split(' ')]

    if (len(vocab_list) >= len_sta):
        return vocab_list[:len_sta]
    else:
        vocab_list.extend([0]*(len_sta-len(vocab_list)))
        return vocab_list
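
A tiny illustration of the pad-or-truncate idea, using a made-up vocabulary instead of the real AG_NEWS.vocab_dict:

# Toy illustration of padding/truncation (hypothetical vocabulary, not the real vocab_dict)
toy_vocab = {'the': 1, 'race': 2, 'is': 3, 'on': 4}
tokens = [toy_vocab[w] for w in 'the race is on'.split(' ')]  # -> [1, 2, 3, 4]
len_sta = 6
padded = tokens + [0] * (len_sta - len(tokens))               # -> [1, 2, 3, 4, 0, 0]
truncated = tokens[:2]                                        # what get_sen_ser would return if len_sta were 2
print(padded, truncated)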

5. A Custom Batch-Generation Function

#! -------------------------- Step 4: A custom function for generating batches ----------------------

def generate_batch(batch):
    """Build one batch of data.

    Args:
        batch (list): a batch_size-length list of (sample, label) tuples, e.g. [(sample1, label1), (sample2, label2), ...]
    Returns:
        tensor: the samples and labels as separate tensors, e.g. text = tensor([sample1, sample2, ...]), label = tensor([label1, label2, ...])
    """
    label = []  # sample labels
    text = []  # sample texts
    for t, l in batch:
        # Standardize and serialize the sample text
        text.append(get_sen_ser(t, len_sta=LEN_STA))
        # Shift the labels from 1..4 to 0..3
        label.append(int(l)-1)
    # text = tc.cat(text)
    # text = torch.tensor(np.array(text), device=device)
    text = torch.tensor(text, device=device)
    return text, torch.tensor(label, device=device)
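
For intuition, generate_batch can also be called by hand on a couple of samples (this assumes the globals it relies on, AG_NEWS, LEN_STA and device, are already defined as in the main script below; during training the DataLoader calls it automatically through collate_fn):

# Hand-build a batch from the first two samples of the dataset
text_batch, label_batch = generate_batch([AG_NEWS[0], AG_NEWS[1]])
print(text_batch.shape)   # torch.Size([2, LEN_STA])
print(label_batch)        # 0-based labels, e.g. tensor([2, 2])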

6. Building the Training Function

#!--------------------------- Step 5: Build the training function ----------------------------
def train(train_data):
    """Train the model for one epoch."""
    # Make sure the model is back in training mode (val() switches it to eval mode)
    model.train()
    # Initialize the running training loss and number of correct predictions to 0
    train_loss = 0
    train_acc = 0

    # Use a DataLoader to produce BATCH_SIZE-sized batches for training
    # data yields batches that have already been processed by generate_batch
    data = DataLoader(train_data, batch_size=BATCH_SIZE,
                      shuffle=True, collate_fn=generate_batch)  # use the custom generate_batch function

    # Loop over data and update the parameters with each batch
    for text, label in data:
        # 1. Zero the optimizer's gradients
        optimizer.zero_grad()
        # 2. Feed one batch into the model and get its output
        label_pre = model(text)
        # 3. Compute the loss from the true labels and the model output
        loss = loss_F(label_pre, label)
        # 4. Backpropagate the error
        loss.backward()
        # 5. Update the parameters
        optimizer.step()

        # Add this batch's loss to the running total
        train_loss += loss.item()
        # Add this batch's number of correct predictions to the running total
        train_acc += (label_pre.argmax(1) == label).sum().item()

    # Let the learning-rate scheduler adjust the learning rate
    scheduler.step()

    # Return the per-sample average loss and the accuracy for this epoch
    return train_loss/len(train_data), train_acc/len(train_data)

7. Building the Validation Function

#!----------------------------- Step 6: Build the validation function ------------------------
def val(val_data):
    model.eval()

    # Initialize the running validation loss and number of correct predictions to 0
    val_loss = 0
    val_acc = 0

    # Use a DataLoader to produce BATCH_SIZE-sized batches for validation
    # data yields batches that have already been processed by generate_batch
    data = DataLoader(val_data, batch_size=BATCH_SIZE,
                      shuffle=True, collate_fn=generate_batch)  # use the custom generate_batch function
    with torch.no_grad():
        for text, label in data:
            label_pre = model(text)
            # Compute the loss from the true labels and the model output
            loss = loss_F(label_pre, label)

            # Add this batch's loss to the running total
            val_loss += loss.item()

            # Add this batch's number of correct predictions to the running total
            val_acc += (label_pre.argmax(1) == label).sum().item()
    # Return the per-sample average loss and the accuracy for this epoch
    return val_loss/len(val_data), val_acc/len(val_data)

8. Model Training and Validation

if __name__ == '__main__':

    # Path where the dataset is stored
    BASE_PATH = r'H:\Pytorch学习\Datasets\datasets\AG_NEWS'
    # Check whether a GPU is available
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    # Load the training data
    AG_NEWS = AG_NEWS_Data(train=True)  # load the dataset
    generator = torch.Generator().manual_seed(2024)  # random generator with a fixed seed
    AG_NEWS_train, AG_NEWS_val = random_split(  # split into training and validation sets
        AG_NEWS, [0.7, 0.3], generator=generator)

    VOCAB_SIZE = len(AG_NEWS.vocab_dict)  # number of distinct tokens in the training corpus
    BATCH_SIZE = 1000  # batch size
    EMBED_DIM = 32  # word-embedding dimension
    NUN_CLASS = 4  # number of classes
    LEARN_RATE = 0.005  # learning rate
    LEN_STA = get_length_standard(0.9)  # standardized sentence length: truncate if longer, pad if shorter
    EPOCH = 100  # number of passes over the dataset

    model = TextSentiment(VOCAB_SIZE, EMBED_DIM, NUN_CLASS).to(device)  # instantiate the model
    loss_F = nn.CrossEntropyLoss().to(device)  # loss function
    optimizer = optim.SGD(model.parameters(), lr=LEARN_RATE)  # optimizer
    scheduler = optim.lr_scheduler.StepLR(
        optimizer, step_size=1, gamma=0.9)  # learning-rate scheduler

    # Train and validate the model
    for epoch in range(EPOCH):
        train_loss, train_acc = train(AG_NEWS_train)
        print(
            f'epoch {epoch}:\ttrain_loss:{train_loss:.6f}\ttrain_acc:{train_acc:.6f}', end='\t')
        val_loss, val_acc = val(AG_NEWS_val)
        print(
            f'val_loss:{val_loss:.6f}\tval_acc:{val_acc:.6f}')

9. Complete Code and Output

(1) Complete Code

import re
import os
import pandas as pd
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
import torch.optim as optim
import torch


#!------------------------ Step 1: Read the data and build the Dataset class -----------------------------
class AG_NEWS_Data(Dataset):
    def __init__(self, train=True) -> None:
        super().__init__()
        data_path = os.path.join(BASE_PATH, 'train.csv') if train else os.path.join(
            BASE_PATH, 'test.csv')  # set the data path; only the training set is used in this experiment
        self.data = pd.read_csv(data_path, sep=',', header=None)  # read the data
        # print(self.data.head())

        sen_len = []  # sentence length of each sample
        contents_list = []  # cleaned text of every sample
        token_number = 0  # number of distinct tokens across all texts
        label_count = []  # labels of all samples

        # * Compute the length of each sample, collect its label, and gather its content
        for i in range(self.__len__()):
            content, label = self.__getitem__(i)
        # for content, label in data:
            sen_len.append(len(content.split(' ')))  # length of each sample
            label_count.append(label)  # collect the label of each sample
            contents_list.append(content)  # gather the sample content

        # Concatenate all sample contents; joining once avoids slow repeated string concatenation
        self.contents = ' ' + ' '.join(contents_list)

        vocab_dict = {v: idx for idx, v in enumerate(
            set(self.contents.split(' ')))}  # build the vocabulary from the set of all tokens
        token_number = len(vocab_dict)
        sen_len_distribution = {str(i): sen_len.count(i) for i in sorted(
            set(sen_len))}  # distribution of sentence lengths, e.g. {'80': 192, '81': 689, ...}: 192 sentences have length 80, ...
        label_n_distribution = {str(i): label_count.count(i) for i in set(
            label_count)}  # distribution of labels, e.g. {'1': 20000, '2': 20000, ...}: number of samples per label

        self.vocab_dict, self.token_number, self.sen_len_distribution, self.label_n_distribution = vocab_dict, token_number, sen_len_distribution, label_n_distribution

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, index):
        label = int(self.data.iloc[index, 0])  # extract the label and convert it to int
        content = self.data.iloc[index, 1]+' ' + \
            self.data.iloc[index, 2]  # concatenate the sample's title and description
        content = content.lower()  # lowercase all words
        # Use a regex to keep only word characters and digits, replacing everything else with spaces
        content = re.sub(r'[^\w\s]', ' ', content)
        content = re.sub(r'\s+', ' ', content)  # collapse multiple spaces into one

        return content, label

#! -------------------- Step 2: Build the network model, a text classification model with an Embedding layer ------
class TextSentiment(nn.Module):
    """Text classification model"""

    def __init__(self, vocab_size, embed_dim, num_class):
        """Class initializer

        Args:
            vocab_size (int): total number of distinct tokens in the corpus
            embed_dim (int): dimensionality of the word embeddings
            num_class (int): number of text classes
        """
        super().__init__()
        # Instantiate the Embedding layer; sparse=True means that when gradients are computed
        # for this layer, only the embedding rows actually used in the batch are updated
        self.embedding = nn.Embedding(
            vocab_size, embedding_dim=embed_dim, sparse=True)
        # Instantiate the linear layers: flattened embeddings (LEN_STA*EMBED_DIM) -> 512 -> 64 -> num_class
        self.fc1 = nn.Linear(in_features=LEN_STA*EMBED_DIM, out_features=512)
        self.fc2 = nn.Linear(in_features=512, out_features=64)
        self.fc3 = nn.Linear(in_features=64, out_features=num_class)
        # Initialize the weights of each layer
        self.init_weight()

    def init_weight(self):
        """Initialize the weights.
        """
        # Range for the initial weights
        init_range = 0.5
        # Weights of every layer are initialized from a uniform distribution
        self.embedding.weight.data.uniform_(-init_range, init_range)
        for fc in [self.fc1, self.fc2, self.fc3]:
            fc.weight.data.uniform_(-init_range, init_range)
            # Biases are initialized to 0
            fc.bias.data.zero_()

    def forward(self, text):
        """Forward pass

        Args:
            text (Tensor): the numericalized text (token indices)

        Returns:
            tensor: a tensor with one score per class, used to decide the text's category
        """
        # Get the embedding result; embedded has shape (BATCH_SIZE, LEN_STA, EMBED_DIM),
        # where EMBED_DIM is the word-embedding dimension (32 here)
        # print(text.shape)
        embedded = self.embedding(text)
        # embedded = F.avg_pool1d(embedded, kernel_size=3)
        x = embedded.view(embedded.size(0), -1)
        # print(embedded.shape)
        # print(len_sta*EMBED_DIM)
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)

        return x


#! ---------------------- Step 3: Standardize the length of each sample's sentence and serialize it -----------------
def get_length_standard(rate=0.9):
    """Compute the standardized text length: based on the distribution of sample text lengths (ascending),
    take the length at the rate cut-off point as the standard length.

    Args:
        rate (float, optional): Defaults to 0.9.

    Returns:
        int: the length at the rate cut-off point of the sample length distribution
    """
    value_sum = 0  # running count of samples covered so far
    sample_len = len(AG_NEWS)  # total number of samples in the dataset

    # key = sentence length, value = number of samples with that length
    for key, value in AG_NEWS.sen_len_distribution.items():
        value_sum += int(value)
        if (value_sum/sample_len >= rate):
            return int(key)


def get_sen_ser(sentence, len_sta):
    """Standardize and serialize a sample's content: truncate if too long, pad with 0 if too short.

    Args:
        sentence (str): cleaned sample text
        len_sta (int): target standardized length

    Returns:
        list: token indices of length len_sta
    """
    # Serialize the sentence into token indices
    vocab_list = [AG_NEWS.vocab_dict[v] for v in sentence.split(' ')]

    if (len(vocab_list) >= len_sta):
        return vocab_list[:len_sta]
    else:
        vocab_list.extend([0]*(len_sta-len(vocab_list)))
        return vocab_list


#! -------------------------- Step 4: A custom function for generating batches -------------------------
def generate_batch(batch):
    """Build one batch of data.

    Args:
        batch (list): a batch_size-length list of (sample, label) tuples, e.g. [(sample1, label1), (sample2, label2), ...]
    Returns:
        tensor: the samples and labels as separate tensors, e.g. text = tensor([sample1, sample2, ...]), label = tensor([label1, label2, ...])
    """
    label = []  # sample labels
    text = []  # sample texts
    for t, l in batch:
        # Standardize and serialize the sample text
        text.append(get_sen_ser(t, len_sta=LEN_STA))
        # Shift the labels from 1..4 to 0..3
        label.append(int(l)-1)
    # text = tc.cat(text)
    # text = torch.tensor(np.array(text), device=device)
    text = torch.tensor(text, device=device)
    return text, torch.tensor(label, device=device)


#!----------------------------------- Step 5: Build the training function -------------------------
def train(train_data):
    """Train the model for one epoch."""
    # Make sure the model is back in training mode (val() switches it to eval mode)
    model.train()
    # Initialize the running training loss and number of correct predictions to 0
    train_loss = 0
    train_acc = 0

    # Use a DataLoader to produce BATCH_SIZE-sized batches for training
    # data yields batches that have already been processed by generate_batch
    data = DataLoader(train_data, batch_size=BATCH_SIZE,
                      shuffle=True, collate_fn=generate_batch)  # use the custom generate_batch function

    # Loop over data and update the parameters with each batch
    for text, label in data:
        # 1. Zero the optimizer's gradients
        optimizer.zero_grad()
        # 2. Feed one batch into the model and get its output
        label_pre = model(text)
        # 3. Compute the loss from the true labels and the model output
        loss = loss_F(label_pre, label)
        # 4. Backpropagate the error
        loss.backward()
        # 5. Update the parameters
        optimizer.step()

        # Add this batch's loss to the running total
        train_loss += loss.item()
        # Add this batch's number of correct predictions to the running total
        train_acc += (label_pre.argmax(1) == label).sum().item()

    # Let the learning-rate scheduler adjust the learning rate
    scheduler.step()

    # Return the per-sample average loss and the accuracy for this epoch
    return train_loss/len(train_data), train_acc/len(train_data)

#!----------------------------- Step 6: Build the validation function --------------------------
def val(val_data):
    model.eval()

    # Initialize the running validation loss and number of correct predictions to 0
    val_loss = 0
    val_acc = 0

    # Use a DataLoader to produce BATCH_SIZE-sized batches for validation
    # data yields batches that have already been processed by generate_batch
    data = DataLoader(val_data, batch_size=BATCH_SIZE,
                      shuffle=True, collate_fn=generate_batch)  # use the custom generate_batch function
    with torch.no_grad():
        for text, label in data:
            label_pre = model(text)
            # Compute the loss from the true labels and the model output
            loss = loss_F(label_pre, label)

            # Add this batch's loss to the running total
            val_loss += loss.item()

            # Add this batch's number of correct predictions to the running total
            val_acc += (label_pre.argmax(1) == label).sum().item()
    # Return the per-sample average loss and the accuracy for this epoch
    return val_loss/len(val_data), val_acc/len(val_data)


#! -------------------------- Step 7: Train and validate the model ---------------------------------
if __name__ == '__main__':

    # Path where the dataset is stored
    BASE_PATH = r'H:\Pytorch学习\Datasets\datasets\AG_NEWS'
    # Check whether a GPU is available
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    # Load the training data
    AG_NEWS = AG_NEWS_Data(train=True)  # load the dataset
    generator = torch.Generator().manual_seed(2024)  # random generator with a fixed seed
    AG_NEWS_train, AG_NEWS_val = random_split(  # split into training and validation sets
        AG_NEWS, [0.7, 0.3], generator=generator)

    VOCAB_SIZE = len(AG_NEWS.vocab_dict)  # number of distinct tokens in the training corpus
    BATCH_SIZE = 1000  # batch size
    EMBED_DIM = 32  # word-embedding dimension
    NUN_CLASS = 4  # number of classes
    LEARN_RATE = 0.005  # learning rate
    LEN_STA = get_length_standard(0.9)  # standardized sentence length: truncate if longer, pad if shorter
    EPOCH = 100  # number of passes over the dataset

    model = TextSentiment(VOCAB_SIZE, EMBED_DIM, NUN_CLASS).to(device)  # instantiate the model
    loss_F = nn.CrossEntropyLoss().to(device)  # loss function
    optimizer = optim.SGD(model.parameters(), lr=LEARN_RATE)  # optimizer
    scheduler = optim.lr_scheduler.StepLR(
        optimizer, step_size=1, gamma=0.9)  # learning-rate scheduler

    # Train and validate the model
    for epoch in range(EPOCH):
        train_loss, train_acc = train(AG_NEWS_train)
        print(
            f'epoch {epoch}:\ttrain_loss:{train_loss:.6f}\ttrain_acc:{train_acc:.6f}', end='\t')
        val_loss, val_acc = val(AG_NEWS_val)
        print(
            f'val_loss:{val_loss:.6f}\tval_acc:{val_acc:.6f}')

(2) Output

With EPOCH = 100, i.e. 100 passes over the dataset, the results are shown below. The model's capacity is clearly limited: it does learn something, but only barely. With four balanced classes, random guessing would already give about 25% accuracy, so roughly 32% validation accuracy is only a small improvement. (The reported loss values look tiny because the summed batch losses are divided by the number of samples rather than the number of batches.)

epoch 0:        train_loss:0.017730     train_acc:0.267131      val_loss:0.003573       val_acc:0.266333
epoch 1:        train_loss:0.002014     train_acc:0.274238      val_loss:0.001452       val_acc:0.284139
epoch 2:        train_loss:0.001417     train_acc:0.289357      val_loss:0.001417       val_acc:0.289583
epoch 3:        train_loss:0.001396     train_acc:0.292762      val_loss:0.001392       val_acc:0.293667
epoch 4:        train_loss:0.001389     train_acc:0.294369      val_loss:0.001386       val_acc:0.298972
epoch 5:        train_loss:0.001383     train_acc:0.298071      val_loss:0.001384       val_acc:0.297611
epoch 6:        train_loss:0.001381     train_acc:0.300738      val_loss:0.001382       val_acc:0.303028
epoch 7:        train_loss:0.001379     train_acc:0.303667      val_loss:0.001379       val_acc:0.302861
epoch 8:        train_loss:0.001376     train_acc:0.304119      val_loss:0.001375       val_acc:0.303528
epoch 9:        train_loss:0.001375     train_acc:0.304893      val_loss:0.001376       val_acc:0.300528
epoch 10:       train_loss:0.001374     train_acc:0.307119      val_loss:0.001372       val_acc:0.308639
epoch 11:       train_loss:0.001372     train_acc:0.308905      val_loss:0.001374       val_acc:0.303667
epoch 12:       train_loss:0.001371     train_acc:0.310357      val_loss:0.001372       val_acc:0.309667
epoch 13:       train_loss:0.001370     train_acc:0.311393      val_loss:0.001372       val_acc:0.309917
epoch 14:       train_loss:0.001369     train_acc:0.311607      val_loss:0.001370       val_acc:0.308667
epoch 15:       train_loss:0.001369     train_acc:0.311929      val_loss:0.001370       val_acc:0.312222
epoch 16:       train_loss:0.001368     train_acc:0.312952      val_loss:0.001369       val_acc:0.309778
epoch 17:       train_loss:0.001368     train_acc:0.313524      val_loss:0.001367       val_acc:0.314528
epoch 18:       train_loss:0.001367     train_acc:0.313905      val_loss:0.001368       val_acc:0.315444
epoch 19:       train_loss:0.001367     train_acc:0.314810      val_loss:0.001367       val_acc:0.315694
epoch 20:       train_loss:0.001366     train_acc:0.315952      val_loss:0.001368       val_acc:0.313333
epoch 21:       train_loss:0.001366     train_acc:0.317262      val_loss:0.001367       val_acc:0.314750
epoch 22:       train_loss:0.001366     train_acc:0.315976      val_loss:0.001366       val_acc:0.316222
epoch 23:       train_loss:0.001365     train_acc:0.317345      val_loss:0.001366       val_acc:0.316139
epoch 24:       train_loss:0.001365     train_acc:0.315976      val_loss:0.001366       val_acc:0.316444
epoch 25:       train_loss:0.001365     train_acc:0.316786      val_loss:0.001366       val_acc:0.314111
epoch 26:       train_loss:0.001365     train_acc:0.316905      val_loss:0.001365       val_acc:0.318611
epoch 27:       train_loss:0.001364     train_acc:0.318774      val_loss:0.001365       val_acc:0.316944
epoch 28:       train_loss:0.001364     train_acc:0.319036      val_loss:0.001366       val_acc:0.314944
epoch 29:       train_loss:0.001364     train_acc:0.318393      val_loss:0.001365       val_acc:0.316111
epoch 30:       train_loss:0.001364     train_acc:0.319250      val_loss:0.001365       val_acc:0.316833
epoch 31:       train_loss:0.001364     train_acc:0.318440      val_loss:0.001365       val_acc:0.317444
epoch 32:       train_loss:0.001364     train_acc:0.319500      val_loss:0.001365       val_acc:0.316444
epoch 33:       train_loss:0.001364     train_acc:0.319333      val_loss:0.001365       val_acc:0.315972
epoch 34:       train_loss:0.001363     train_acc:0.319786      val_loss:0.001365       val_acc:0.315389
epoch 35:       train_loss:0.001363     train_acc:0.319560      val_loss:0.001365       val_acc:0.316583
epoch 36:       train_loss:0.001363     train_acc:0.320024      val_loss:0.001365       val_acc:0.316556
epoch 37:       train_loss:0.001363     train_acc:0.320774      val_loss:0.001365       val_acc:0.316639
epoch 38:       train_loss:0.001363     train_acc:0.320179      val_loss:0.001365       val_acc:0.315889
epoch 39:       train_loss:0.001363     train_acc:0.320393      val_loss:0.001365       val_acc:0.315139
epoch 40:       train_loss:0.001363     train_acc:0.320774      val_loss:0.001365       val_acc:0.316278
epoch 41:       train_loss:0.001363     train_acc:0.320821      val_loss:0.001365       val_acc:0.315167
epoch 42:       train_loss:0.001363     train_acc:0.321167      val_loss:0.001365       val_acc:0.315667
epoch 43:       train_loss:0.001363     train_acc:0.320619      val_loss:0.001365       val_acc:0.316167
epoch 44:       train_loss:0.001363     train_acc:0.320571      val_loss:0.001365       val_acc:0.316778
epoch 45:       train_loss:0.001363     train_acc:0.321714      val_loss:0.001365       val_acc:0.316611
epoch 46:       train_loss:0.001363     train_acc:0.321143      val_loss:0.001365       val_acc:0.316000
epoch 47:       train_loss:0.001363     train_acc:0.321262      val_loss:0.001365       val_acc:0.316056
epoch 48:       train_loss:0.001363     train_acc:0.321429      val_loss:0.001365       val_acc:0.315722
epoch 49:       train_loss:0.001363     train_acc:0.321036      val_loss:0.001365       val_acc:0.315917
epoch 50:       train_loss:0.001363     train_acc:0.321417      val_loss:0.001365       val_acc:0.315639
epoch 51:       train_loss:0.001362     train_acc:0.321560      val_loss:0.001365       val_acc:0.315889
epoch 52:       train_loss:0.001362     train_acc:0.321524      val_loss:0.001365       val_acc:0.316056
epoch 53:       train_loss:0.001362     train_acc:0.321690      val_loss:0.001365       val_acc:0.315889
epoch 54:       train_loss:0.001362     train_acc:0.321429      val_loss:0.001365       val_acc:0.316028
epoch 55:       train_loss:0.001362     train_acc:0.321536      val_loss:0.001365       val_acc:0.316083
epoch 56:       train_loss:0.001362     train_acc:0.321417      val_loss:0.001365       val_acc:0.315639
epoch 57:       train_loss:0.001362     train_acc:0.321476      val_loss:0.001365       val_acc:0.315750
epoch 58:       train_loss:0.001362     train_acc:0.321512      val_loss:0.001365       val_acc:0.315806
epoch 59:       train_loss:0.001362     train_acc:0.321452      val_loss:0.001365       val_acc:0.315861
epoch 60:       train_loss:0.001362     train_acc:0.321750      val_loss:0.001365       val_acc:0.316000
epoch 61:       train_loss:0.001362     train_acc:0.321298      val_loss:0.001365       val_acc:0.315889
epoch 62:       train_loss:0.001362     train_acc:0.321405      val_loss:0.001365       val_acc:0.316000
epoch 63:       train_loss:0.001362     train_acc:0.321607      val_loss:0.001365       val_acc:0.315972
epoch 64:       train_loss:0.001362     train_acc:0.321583      val_loss:0.001365       val_acc:0.316111
epoch 65:       train_loss:0.001362     train_acc:0.321452      val_loss:0.001365       val_acc:0.316056
epoch 66:       train_loss:0.001362     train_acc:0.321452      val_loss:0.001365       val_acc:0.316111
epoch 67:       train_loss:0.001362     train_acc:0.321583      val_loss:0.001365       val_acc:0.316083
epoch 68:       train_loss:0.001362     train_acc:0.321464      val_loss:0.001365       val_acc:0.316111
epoch 69:       train_loss:0.001362     train_acc:0.321679      val_loss:0.001365       val_acc:0.316139
epoch 70:       train_loss:0.001362     train_acc:0.321476      val_loss:0.001365       val_acc:0.316139
epoch 71:       train_loss:0.001362     train_acc:0.321714      val_loss:0.001365       val_acc:0.316111
epoch 72:       train_loss:0.001362     train_acc:0.321679      val_loss:0.001365       val_acc:0.316056
epoch 73:       train_loss:0.001362     train_acc:0.321560      val_loss:0.001365       val_acc:0.316056
epoch 74:       train_loss:0.001362     train_acc:0.321583      val_loss:0.001365       val_acc:0.316028
epoch 75:       train_loss:0.001362     train_acc:0.321548      val_loss:0.001365       val_acc:0.316028
epoch 76:       train_loss:0.001362     train_acc:0.321500      val_loss:0.001365       val_acc:0.316083
epoch 77:       train_loss:0.001362     train_acc:0.321548      val_loss:0.001365       val_acc:0.316056
epoch 78:       train_loss:0.001362     train_acc:0.321571      val_loss:0.001365       val_acc:0.316056
epoch 79:       train_loss:0.001362     train_acc:0.321548      val_loss:0.001365       val_acc:0.316028
epoch 80:       train_loss:0.001362     train_acc:0.321631      val_loss:0.001365       val_acc:0.316028
epoch 81:       train_loss:0.001362     train_acc:0.321512      val_loss:0.001365       val_acc:0.316028
epoch 82:       train_loss:0.001362     train_acc:0.321536      val_loss:0.001365       val_acc:0.316056
epoch 83:       train_loss:0.001362     train_acc:0.321583      val_loss:0.001365       val_acc:0.316056
epoch 84:       train_loss:0.001362     train_acc:0.321512      val_loss:0.001365       val_acc:0.316056
epoch 85:       train_loss:0.001362     train_acc:0.321560      val_loss:0.001365       val_acc:0.316056
epoch 86:       train_loss:0.001362     train_acc:0.321583      val_loss:0.001365       val_acc:0.316056
epoch 87:       train_loss:0.001362     train_acc:0.321536      val_loss:0.001365       val_acc:0.316056
epoch 88:       train_loss:0.001362     train_acc:0.321548      val_loss:0.001365       val_acc:0.316056
epoch 89:       train_loss:0.001362     train_acc:0.321571      val_loss:0.001365       val_acc:0.316056
epoch 90:       train_loss:0.001362     train_acc:0.321536      val_loss:0.001365       val_acc:0.316056
epoch 91:       train_loss:0.001362     train_acc:0.321560      val_loss:0.001365       val_acc:0.316056
epoch 92:       train_loss:0.001362     train_acc:0.321560      val_loss:0.001365       val_acc:0.316083
epoch 93:       train_loss:0.001362     train_acc:0.321583      val_loss:0.001365       val_acc:0.316083
epoch 94:       train_loss:0.001362     train_acc:0.321571      val_loss:0.001365       val_acc:0.316056
epoch 95:       train_loss:0.001362     train_acc:0.321583      val_loss:0.001365       val_acc:0.316056
epoch 96:       train_loss:0.001362     train_acc:0.321548      val_loss:0.001365       val_acc:0.316056
epoch 97:       train_loss:0.001362     train_acc:0.321536      val_loss:0.001365       val_acc:0.316056
epoch 98:       train_loss:0.001362     train_acc:0.321548      val_loss:0.001365       val_acc:0.316056
epoch 99:       train_loss:0.001362     train_acc:0.321560      val_loss:0.001365       val_acc:0.316056
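
One plausible contributor to the plateau (an assumption, not something verified in this post) is that the three Linear layers are stacked without any nonlinearity, so together they are equivalent to a single linear layer applied to the flattened embeddings. A hedged variant of TextSentiment.forward that inserts ReLU activations between the layers would look like this:

# Sketch of a forward pass with ReLU nonlinearities between the linear layers (not the original code)
import torch.nn.functional as F

def forward(self, text):
    embedded = self.embedding(text)          # (batch, LEN_STA, EMBED_DIM)
    x = embedded.view(embedded.size(0), -1)  # flatten to (batch, LEN_STA * EMBED_DIM)
    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x))
    return self.fc3(x)

A smaller initialization range than 0.5, and a padding index that does not collide with a real token (index 0 currently doubles as both), are other things worth experimenting with.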

For reference, the official description of the dataset (from the AG's News readme) is reproduced below.

AG's News Topic Classification Dataset, Version 3, Updated 09/09/2015.

ORIGIN: AG is a collection of more than 1 million news articles. News articles have been gathered from more than 2000 news sources by ComeToMyHead in more than 1 year of activity. ComeToMyHead is an academic news search engine which has been running since July, 2004. The dataset is provided by the academic community for research purposes in data mining (clustering, classification, etc), information retrieval (ranking, search, etc), xml, data compression, data streaming, and any other non-commercial activity. For more information, please refer to the link http://www.di.unipi.it/~gulli/AG_corpus_of_news_articles.html . The AG's news topic classification dataset is constructed by Xiang Zhang (xiang.zhang@nyu.edu) from the dataset above. It is used as a text classification benchmark in the following paper: Xiang Zhang, Junbo Zhao, Yann LeCun. Character-level Convolutional Networks for Text Classification. Advances in Neural Information Processing Systems 28 (NIPS 2015).

DESCRIPTION: The AG's news topic classification dataset is constructed by choosing the 4 largest classes from the original corpus. Each class contains 30,000 training samples and 1,900 testing samples. The total number of training samples is 120,000 and testing 7,600. The file classes.txt contains a list of classes corresponding to each label. The files train.csv and test.csv contain all the training samples as comma-separated values. There are 3 columns in them, corresponding to class index (1 to 4), title and description. The title and description are escaped using double quotes ("), and any internal double quote is escaped by 2 double quotes (""). New lines are escaped by a backslash followed with an "n" character, that is "\n".