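Week 27: Transformer in Practice: Text Classification

Contents

Preface

I. Preparation

1.1 Environment Setup

1.2 Loading the Data

II. Data Preprocessing

2.1 Building the Vocabulary

2.2 Generating Batches and Iterators

2.3 Building the Datasets

III. Model Construction

3.1 Defining the Positional Encoder

3.2 Defining the Transformer Model

3.3 Initializing the Model

3.4 Defining the Training Function

3.5 Defining the Evaluation Function

IV. Training the Model

4.1 Training

4.2 Evaluation

V. Model Tuning

Summary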


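Preface

Before we start

1) This week's tasks

  • Understand the logic of the code in this post and run it successfully
  • Tune the code based on your own understanding so that accuracy reaches 70%

2) Runtime environment: Python 3.8, PyCharm 2020, torch 1.12.1+cu113 (the output printed in 1.1 shows torch 2.0.0+cu118)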


I. Preparation

1.1 Environment Setup

This post implements text classification with the PyTorch framework.

Code:

# I. Preparation
# 1.1 Environment setup
import torch, torchvision
print(torch.__version__)
print(torchvision.__version__)
import torch.nn as nn
import warnings
import pandas as pd

warnings.filterwarnings("ignore")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

Output:

2.0.0+cu118
0.15.1+cu118
cuda

1.2 Loading the Data

Code:


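# 1.2 Load the data
# Load the custom Chinese dataset (tab-separated, no header: column 0 = text, column 1 = label)
train_data = pd.read_csv('train.csv', sep='\t', header=None)
print(train_data.head())
# Build a dataset iterator: a generator that yields (text, label) pairs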
def custom_data_iter(texts, labels):
    for x, y in zip(texts, labels):
        yield x, y

train_iter = custom_data_iter(train_data[0].values[:], train_data[1].values[:])

Output:

  0              1
0      还有双鸭山到淮阴的汽车票吗13号的   Travel-Query
1                从这里怎么回家   Travel-Query
2       随便播放一首专辑阁楼里的佛里的歌     Music-Play
3              给看一下墓王之王嘛  FilmTele-Play
4  我想看挑战两把s686打突变团竞的游戏视频     Video-Play
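custom_data_iter returns a generator, and a generator can be iterated only once. In this post the first train_iter is consumed by yield_tokens when the vocabulary is built in 2.1, which is why 2.3 calls custom_data_iter again before to_map_style_dataset. A quick plain-Python demonstration, using two rows copied from the data shown above:

```python
def custom_data_iter(texts, labels):
    # Same generator as above: yields (text, label) pairs lazily, one pass only
    for x, y in zip(texts, labels):
        yield x, y


texts = ["从这里怎么回家", "给看一下墓王之王嘛"]
labels = ["Travel-Query", "FilmTele-Play"]

it = custom_data_iter(texts, labels)
first_pass = list(it)    # consumes the generator
second_pass = list(it)   # the same generator is now exhausted
print(len(first_pass), len(second_pass))  # 2 0

# Calling the function again creates a fresh generator
print(len(list(custom_data_iter(texts, labels))))  # 2
```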
 

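II. Data Preprocessing

2.1 Building the Vocabulary

This step requires the jieba Chinese word-segmentation library, installed with: pip install jieba

Code (example):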

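# II. Data preprocessing
# 2.1 Build the vocabulary
from torchtext.vocab import build_vocab_from_iterator
import jieba

# Chinese word segmentation: jieba.lcut returns a list of words
tokenizer = jieba.lcut
def yield_tokens(data_iter):
    for text, _ in data_iter:
        yield tokenizer(text)

# Note: this consumes the train_iter generator; it is re-created in 2.3
vocab = build_vocab_from_iterator(yield_tokens(train_iter), specials=["<unk>"])
# Words not in the vocabulary map to the index of "<unk>"
vocab.set_default_index(vocab["<unk>"])

# Look up the indices of a pre-segmented sentence (the result is only displayed in an interactive session)
vocab(['我', '想', '看', '和平', '精英', '上', '战神', '必备', '技巧', '的', '游戏', '视频'])
# Note: set() has no guaranteed order, so the label order (and therefore the label ids)
# can change between runs, as the two runs in this post show
label_name = list(set(train_data[1].values[:]))
print('label name:', label_name)


text_pipeline = lambda x: vocab(tokenizer(x))
label_pipeline = lambda x: label_name.index(x)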

print(text_pipeline('我想看和平精英上战神必备技巧的游戏视频'))
print(label_pipeline('Video-Play'))

Output:

Building prefix dict from the default dictionary ...
Loading model from cache C:\Users\XiaoMa\AppData\Local\Temp\jieba.cache
Loading model cost 0.320 seconds.
Prefix dict has been built successfully.
label name: ['Radio-Listen', 'Other', 'Alarm-Update', 'Travel-Query', 'FilmTele-Play', 'Weather-Query', 'Audio-Play', 'HomeAppliance-Control', 'Music-Play', 'Calendar-Query', 'TVProgram-Play', 'Video-Play']
[2, 10, 13, 973, 1079, 146, 7724, 7574, 7793, 1, 186, 28]
11
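The label lists printed by the two runs in this post come out in different orders (label_pipeline('Video-Play') returns 11 here but 2 in the 50-epoch run of Section V), because list(set(...)) has no guaranteed order and Python randomizes string hashing per process. That matters as soon as a trained model is reloaded in a new process. The plain-Python sketch below, with hypothetical helpers build_vocab, encode and label_to_id, mimics what the two pipelines do: a vocabulary with the specials first and then tokens by descending frequency (my assumption about how build_vocab_from_iterator orders tokens), unknown words mapped to "<unk>", and a label mapping made deterministic with sorted().

```python
from collections import Counter


def build_vocab(token_lists, specials=("<unk>",)):
    # Hypothetical helper: specials first, then tokens by descending frequency
    # (ties broken by string order), each mapped to its position
    counter = Counter()
    for tokens in token_lists:
        counter.update(tokens)
    ordered = sorted(counter.items(), key=lambda kv: (-kv[1], kv[0]))
    itos = list(specials) + [tok for tok, _ in ordered if tok not in specials]
    return {tok: idx for idx, tok in enumerate(itos)}


def encode(vocab, tokens, unk="<unk>"):
    # Out-of-vocabulary tokens fall back to the <unk> index (cf. set_default_index)
    return [vocab.get(tok, vocab[unk]) for tok in tokens]


# Toy, already-segmented sentences standing in for jieba.lcut output
corpus = [["我", "想", "看", "视频"], ["我", "想", "听", "歌"], ["我", "看", "电影"]]
vocab = build_vocab(corpus)
ids = encode(vocab, ["我", "想", "看", "新闻"])  # "新闻" is not in the toy vocabulary
print(ids)

# Stable label ids: sort the label set instead of relying on set() iteration order
labels = ["Video-Play", "Music-Play", "Video-Play", "Travel-Query"]
label_name = sorted(set(labels))
label_to_id = {name: idx for idx, name in enumerate(label_name)}
print(label_name)                  # ['Music-Play', 'Travel-Query', 'Video-Play']
print(label_to_id["Video-Play"])   # 2
```

In the post's own code, the equivalent one-line change would be label_name = sorted(set(train_data[1].values[:])).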

2.2 Generating Batches and Iterators

Code:

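# 2.2 Generate data batches and iterators
from torch.utils.data import DataLoader
def collate_batch(batch):
    label_list, text_list, offsets = [], [], [0]
    for (_text, _label) in batch:
        # Label list
        label_list.append(label_pipeline(_label))
        # Text list: the token ids of each sample
        processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)
        text_list.append(processed_text)
        # Offsets: record each sample's length for now
        offsets.append(processed_text.size(0))

    label_list = torch.tensor(label_list, dtype=torch.int64)
    # Concatenate all samples into one flat 1-D tensor (the input format nn.EmbeddingBag expects)
    text_list = torch.cat(text_list)
    # Cumulative sum of [0, len_1, ..., len_{n-1}] gives the start index of each sample in text_list
    offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)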
    return text_list.to(device), label_list.to(device), offsets.to(device)
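collate_batch flattens a batch into one list of token ids plus the start offset of each sample, and nn.EmbeddingBag (default mode='mean') then averages the embeddings inside each slice. The plain-Python sketch below reproduces both steps with a toy 2-dimensional embedding table (assumed values). Note what the mean implies: each sample leaves the embedding layer as a single vector, so word order is already gone before the positional encoding and encoder of section III run.

```python
from itertools import accumulate


def make_offsets(lengths):
    # Start index of each sample in the flat id list: [0, l1, l1+l2, ...]
    # (what torch.tensor([0, l1, ..., l_{n-1}]).cumsum(dim=0) produces in collate_batch)
    return [0] + list(accumulate(lengths))[:-1]


def mean_bags(flat_ids, offsets, table):
    # Average the embedding rows inside each bag, as nn.EmbeddingBag(mode="mean") does
    bounds = offsets + [len(flat_ids)]
    pooled = []
    for start, end in zip(bounds[:-1], bounds[1:]):
        rows = [table[i] for i in flat_ids[start:end]]
        pooled.append([sum(col) / len(rows) for col in zip(*rows)])
    return pooled


# Three toy samples with 3, 1 and 2 token ids
samples = [[2, 10, 13], [7], [1, 5]]
flat_ids = [tok for s in samples for tok in s]   # like torch.cat(text_list)
offsets = make_offsets([len(s) for s in samples])
print(offsets)  # [0, 3, 4]

# Toy 2-d embedding table (assumed values): row i is [i, 2*i]
table = [[float(i), 2.0 * i] for i in range(20)]
pooled = mean_bags(flat_ids, offsets, table)
print(pooled[1], pooled[2])  # [7.0, 14.0] [3.0, 6.0]
```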

2.3 Building the Datasets

Code:

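# 2.3 Build the datasets
from torch.utils.data.dataset import random_split
from torchtext.data.functional import to_map_style_dataset

BATCH_SIZE = 4  # the model in 3.2 hard-codes this batch size
# Re-create the generator, because building the vocabulary in 2.1 exhausted the first one
train_iter = custom_data_iter(train_data[0].values[:], train_data[1].values[:])
train_dataset = to_map_style_dataset(train_iter)

# 80/20 train/validation split; taking the second size as the remainder
# guarantees the two lengths add up to len(train_dataset)
num_train = int(len(train_dataset) * 0.8)
split_train_, split_valid_ = random_split(train_dataset,
                                         [num_train, len(train_dataset) - num_train])
# drop_last=True avoids a final partial batch, which the model in 3.2 cannot handle
train_dataloader = DataLoader(split_train_, batch_size=BATCH_SIZE,
                              shuffle=True, collate_fn=collate_batch, drop_last=True)
valid_dataloader = DataLoader(split_valid_, batch_size=BATCH_SIZE,
                              shuffle=True, collate_fn=collate_batch, drop_last=True)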

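to_map_style_dataset(): converts an iterable-style dataset into a map-style dataset, so that elements can be accessed conveniently by index (for example, an integer). This is needed here because random_split works by indexing into the dataset. In PyTorch, datasets come in two types: iterable-style and map-style.
●An iterable-style dataset implements the __iter__() method: its elements can be iterated over, but not accessed by index.
●A map-style dataset implements the __getitem__() and __len__() methods: a specific element can be accessed directly by index, and the size of the dataset can be queried.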
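The difference between the two dataset types can be sketched in plain Python. ListBackedDataset below is a hypothetical stand-in, not torchtext's implementation: it reads the iterator once into a list so that indexing and len() work. The second helper shows why split sizes are safer computed as a remainder: random_split raises a ValueError when the lengths do not sum to the dataset size, and int(n*0.8) + int(n*0.2) can fall one short.

```python
class ListBackedDataset:
    """Hypothetical stand-in for what to_map_style_dataset provides:
    the iterator is read once into a list, so items support indexing and len()."""

    def __init__(self, iterable):
        self._data = list(iterable)

    def __getitem__(self, idx):
        return self._data[idx]

    def __len__(self):
        return len(self._data)


def split_sizes(n, train_frac=0.8):
    # Take the validation size as the remainder so the two sizes always sum to n
    n_train = int(n * train_frac)
    return n_train, n - n_train


pairs = ((f"text{i}", i % 3) for i in range(13))  # one-shot generator (iterable-style)
ds = ListBackedDataset(pairs)
print(len(ds), ds[4])                  # 13 ('text4', 1)
print(split_sizes(len(ds)))            # (10, 3)
print(int(13 * 0.8) + int(13 * 0.2))   # 12: the naive sizes fall one short of 13
```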

III. Model Construction

3.1 Defining the Positional Encoder

Code:

# III. Model construction
# 3.1 Define the positional encoding
import math
# Positional encoding
class PositionalEncoding(nn.Module):
    "Implements the sinusoidal positional encoding"
    def __init__(self, embed_dim, max_len=500):
        super(PositionalEncoding, self).__init__()
        # Initialize the PE (positional encoding) with shape (max_len, embed_dim)
        pe = torch.zeros(max_len, embed_dim)
        # Position indices as a tensor of shape [max_len, 1]
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)

        # The frequency term inside sin/cos, 1 / base^(2i/embed_dim), rewritten with exp and ln.
        # Note: the original Transformer uses base 10000.0; this post uses 100.0
        div_term = torch.exp(torch.arange(0, embed_dim, 2) * (-math.log(100.0) / embed_dim))

        pe[:, 0::2] = torch.sin(position * div_term)  # PE(pos, 2i)
        pe[:, 1::2] = torch.cos(position * div_term)  # PE(pos, 2i+1)

        # Add a batch dimension: the shape becomes [max_len, 1, embed_dim]
        pe = pe.unsqueeze(0).transpose(0, 1)

        # register_buffer stores a tensor that is saved with the model (in its state_dict)
        # but is not a parameter, so gradient descent does not update it
        self.register_buffer('pe', pe)

    def forward(self, x):
        # Add the positional encoding to x.
        # Note: here x comes from nn.EmbeddingBag with shape [batch, embed_dim], so
        # x + pe[:batch] (shape [batch, 1, embed_dim]) broadcasts to [batch, batch, embed_dim]
        x = x + self.pe[:x.size(0)]
        return x
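The sinusoidal encoding is PE(pos, 2i) = sin(pos / base^(2i/d)) and PE(pos, 2i+1) = cos(pos / base^(2i/d)), where d is embed_dim and base = 10000 in the original Transformer paper; the class above uses base = 100, which makes every dimension pair except the first oscillate faster. The plain-Python sketch below (positional_encoding_row is a hypothetical helper) computes one row with the same exp/ln rewrite as div_term, so the two bases can be compared.

```python
import math


def positional_encoding_row(pos, embed_dim, base=10000.0):
    # exp(2i * -ln(base) / d) equals 1 / base**(2i / d): the same rewrite as div_term above
    row = [0.0] * embed_dim
    for two_i in range(0, embed_dim, 2):
        angle = pos * math.exp(two_i * (-math.log(base) / embed_dim))
        row[two_i] = math.sin(angle)              # PE(pos, 2i)
        if two_i + 1 < embed_dim:
            row[two_i + 1] = math.cos(angle)      # PE(pos, 2i+1)
    return row


print(positional_encoding_row(0, 8))   # [0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0]
paper_row = positional_encoding_row(3, 8)              # base 10000 (original Transformer)
post_row = positional_encoding_row(3, 8, base=100.0)   # base 100 (as in the class above)
```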

3.2 Defining the Transformer Model

Code:

# 3.2 Define the Transformer model
from torch import nn
from torch.nn import TransformerEncoder, TransformerEncoderLayer

class TransformerModel(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_class, nhead=8, d_hid=256, nlayers=12, dropout=0.1):
        super().__init__()
        # EmbeddingBag (default mode='mean') averages the word embeddings of each sample: [batch, embed_dim]
        self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=False)
        self.pos_encoder = PositionalEncoding(embed_dim)

        # Define the encoder layers (batch_first=False: inputs are [seq_len, batch, embed_dim])
        encoder_layers = TransformerEncoderLayer(embed_dim, nhead, d_hid, dropout)
        self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
        self.embed_dim = embed_dim
        # The input size embed_dim*4 ties this layer to BATCH_SIZE = 4
        self.linear = nn.Linear(embed_dim * 4, num_class)

    def forward(self, src, offsets, src_mask=None):
        src = self.embedding(src, offsets)   # [4, embed_dim]
        src = self.pos_encoder(src)          # broadcast to [4, 4, embed_dim]
        output = self.transformer_encoder(src, src_mask)

        # Flatten to [4, embed_dim*4]. This only works for batches of exactly 4 samples,
        # and each output row concatenates features from all 4 samples in the batch
        output = output.view(4, self.embed_dim * 4)
        output = self.linear(output)

        return output

3.3 Initializing the Model

Code:

# 3.3 Initialize the model
vocab_size = len(vocab)
embed_dim = 64
num_class = len(label_name)

model = TransformerModel(vocab_size, embed_dim, num_class).to(device)

3.4 Defining the Training Function

Code:


# 3.4 Define the training function
import time
def train(dataloader):
    # Uses the globals model, optimizer, criterion and epoch defined in 3.3 and 4.1
    model.train()
    total_acc, train_loss, total_count = 0, 0, 0
    log_interval = 300
    start_time = time.time()

    for idx, (text, label, offsets) in enumerate(dataloader):
        predicted_label = model(text, offsets)
        optimizer.zero_grad()

        loss = criterion(predicted_label, label)
        loss.backward()

        optimizer.step()

        # Accumulate accuracy and loss
        total_acc += (predicted_label.argmax(1) == label).sum().item()
        train_loss += loss.item()   # loss.item() is already the mean over the batch
        total_count += label.size(0)

        if idx % log_interval == 0 and idx > 0:
            elapsed = time.time() - start_time
            # train_loss / total_count is the mean batch loss divided by the batch size
            print('| epoch{:1d} | {:4d}/{:4d} batches | '
                  'train_acc {:4.3f} train_loss {:4.5f}'.format(epoch, idx, len(dataloader),
                                                                total_acc / total_count, train_loss / total_count))
            total_acc, train_loss, total_count = 0, 0, 0
            start_time = time.time()
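CrossEntropyLoss defaults to reduction='mean', so loss.item() is already a per-batch average, and dividing the accumulated sum by total_count reports that average divided by the batch size. This explains the scale of the logged values: a uniform guess over 12 classes has cross-entropy ln(12) ≈ 2.485, which this logging shows as ≈ 0.621, close to the first logged values (0.635 and 0.631). The plain-Python sketch below, with hypothetical helpers reported_loss and mean_batch_loss, reproduces that arithmetic.

```python
import math

num_class = 12
batch_size = 4

# Cross-entropy of a uniform prediction over 12 classes: -ln(1/12) = ln(12)
uniform_ce = -math.log(1.0 / num_class)


def reported_loss(batch_mean_losses, batch_size):
    # What train() prints: sum of per-batch *mean* losses divided by the number of samples
    total_count = batch_size * len(batch_mean_losses)
    return sum(batch_mean_losses) / total_count


def mean_batch_loss(batch_mean_losses):
    # Dividing by the number of batches instead gives the usual mean loss
    return sum(batch_mean_losses) / len(batch_mean_losses)


losses = [uniform_ce] * 300                          # one 300-batch logging interval
print(round(uniform_ce, 3))                          # 2.485
print(round(reported_loss(losses, batch_size), 3))   # 0.621
print(round(mean_batch_loss(losses), 3))             # 2.485
```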

3.5 Defining the Evaluation Function

Code:

# 3.5 Define the evaluation function
def evaluate(dataloader):
    model.eval()
    total_acc, eval_loss, total_count = 0, 0, 0
    with torch.no_grad():
        for idx, (text, label, offsets) in enumerate(dataloader):
            predicted_label = model(text, offsets)

            loss = criterion(predicted_label, label)
            # Accumulate accuracy and loss
            total_acc += (predicted_label.argmax(1) == label).sum().item()
            eval_loss += loss.item()
            total_count += label.size(0)
    return total_acc/total_count, eval_loss/total_count

IV. Training the Model

4.1 Training

Code:

# IV. Train the model
# 4.1 Model training
epochs = 10  # the run below uses 10 epochs; Section V raises this to 50
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

for epoch in range(1, epochs+1):
    epoch_start_time = time.time()
    train(train_dataloader)
    val_acc, val_loss = evaluate(valid_dataloader)

    # Get the current learning rate
    lr = optimizer.state_dict()['param_groups'][0]['lr']

    print('-' * 69)
    print('| epoch {:d} | time:{:4.2f}s |'
          ' valid_acc {:4.3f} valid_loss {:4.3f} | lr {:4.6f}'.format(epoch, time.time() - epoch_start_time,
                                                                      val_acc, val_loss, lr))
    print('-' * 69)

Output:

| epoch1 |  300/2420 batchestrain_acc 0.105 train_loss 0.63530
| epoch1 |  600/2420 batchestrain_acc 0.101 train_loss 0.61687
| epoch1 |  900/2420 batchestrain_acc 0.133 train_loss 0.61168
| epoch1 | 1200/2420 batchestrain_acc 0.137 train_loss 0.60003
| epoch1 | 1500/2420 batchestrain_acc 0.135 train_loss 0.60039
| epoch1 | 1800/2420 batchestrain_acc 0.159 train_loss 0.58828
| epoch1 | 2100/2420 batchestrain_acc 0.142 train_loss 0.58723
| epoch1 | 2400/2420 batchestrain_acc 0.147 train_loss 0.57945
---------------------------------------------------------------------
| epoch 1 | time:32.54s | valid_acc 0.179 valid_loss 0.570 | lr 0.010000
---------------------------------------------------------------------
| epoch2 |  300/2420 batchestrain_acc 0.163 train_loss 0.57650
| epoch2 |  600/2420 batchestrain_acc 0.163 train_loss 0.56766
| epoch2 |  900/2420 batchestrain_acc 0.166 train_loss 0.57234
| epoch2 | 1200/2420 batchestrain_acc 0.168 train_loss 0.57338
| epoch2 | 1500/2420 batchestrain_acc 0.183 train_loss 0.56817
| epoch2 | 1800/2420 batchestrain_acc 0.200 train_loss 0.56484
| epoch2 | 2100/2420 batchestrain_acc 0.205 train_loss 0.56261
| epoch2 | 2400/2420 batchestrain_acc 0.207 train_loss 0.55943
---------------------------------------------------------------------
| epoch 2 | time:32.18s | valid_acc 0.204 valid_loss 0.557 | lr 0.010000
---------------------------------------------------------------------
| epoch3 |  300/2420 batchestrain_acc 0.190 train_loss 0.56242
| epoch3 |  600/2420 batchestrain_acc 0.191 train_loss 0.56532
| epoch3 |  900/2420 batchestrain_acc 0.206 train_loss 0.55682
| epoch3 | 1200/2420 batchestrain_acc 0.206 train_loss 0.56180
| epoch3 | 1500/2420 batchestrain_acc 0.226 train_loss 0.54646
| epoch3 | 1800/2420 batchestrain_acc 0.209 train_loss 0.55417
| epoch3 | 2100/2420 batchestrain_acc 0.202 train_loss 0.55837
| epoch3 | 2400/2420 batchestrain_acc 0.210 train_loss 0.54955
---------------------------------------------------------------------
| epoch 3 | time:31.72s | valid_acc 0.227 valid_loss 0.544 | lr 0.010000
---------------------------------------------------------------------
| epoch4 |  300/2420 batchestrain_acc 0.217 train_loss 0.54728
| epoch4 |  600/2420 batchestrain_acc 0.218 train_loss 0.54847
| epoch4 |  900/2420 batchestrain_acc 0.212 train_loss 0.55259
| epoch4 | 1200/2420 batchestrain_acc 0.220 train_loss 0.54698
| epoch4 | 1500/2420 batchestrain_acc 0.214 train_loss 0.55084
| epoch4 | 1800/2420 batchestrain_acc 0.235 train_loss 0.55156
| epoch4 | 2100/2420 batchestrain_acc 0.233 train_loss 0.54610
| epoch4 | 2400/2420 batchestrain_acc 0.223 train_loss 0.54245
---------------------------------------------------------------------
| epoch 4 | time:31.55s | valid_acc 0.232 valid_loss 0.544 | lr 0.010000
---------------------------------------------------------------------
| epoch5 |  300/2420 batchestrain_acc 0.208 train_loss 0.54385
| epoch5 |  600/2420 batchestrain_acc 0.234 train_loss 0.53694
| epoch5 |  900/2420 batchestrain_acc 0.217 train_loss 0.54622
| epoch5 | 1200/2420 batchestrain_acc 0.219 train_loss 0.54792
| epoch5 | 1500/2420 batchestrain_acc 0.253 train_loss 0.53146
| epoch5 | 1800/2420 batchestrain_acc 0.248 train_loss 0.53930
| epoch5 | 2100/2420 batchestrain_acc 0.239 train_loss 0.54326
| epoch5 | 2400/2420 batchestrain_acc 0.217 train_loss 0.54475
---------------------------------------------------------------------
| epoch 5 | time:31.55s | valid_acc 0.238 valid_loss 0.535 | lr 0.010000
---------------------------------------------------------------------
| epoch6 |  300/2420 batchestrain_acc 0.245 train_loss 0.53657
| epoch6 |  600/2420 batchestrain_acc 0.253 train_loss 0.53779
| epoch6 |  900/2420 batchestrain_acc 0.251 train_loss 0.53184
| epoch6 | 1200/2420 batchestrain_acc 0.258 train_loss 0.52866
| epoch6 | 1500/2420 batchestrain_acc 0.262 train_loss 0.53595
| epoch6 | 1800/2420 batchestrain_acc 0.250 train_loss 0.53333
| epoch6 | 2100/2420 batchestrain_acc 0.249 train_loss 0.52478
| epoch6 | 2400/2420 batchestrain_acc 0.269 train_loss 0.53164
---------------------------------------------------------------------
| epoch 6 | time:31.45s | valid_acc 0.283 valid_loss 0.519 | lr 0.010000
---------------------------------------------------------------------
| epoch7 |  300/2420 batchestrain_acc 0.273 train_loss 0.52782
| epoch7 |  600/2420 batchestrain_acc 0.289 train_loss 0.50863
| epoch7 |  900/2420 batchestrain_acc 0.296 train_loss 0.51765
| epoch7 | 1200/2420 batchestrain_acc 0.289 train_loss 0.51848
| epoch7 | 1500/2420 batchestrain_acc 0.320 train_loss 0.50162
| epoch7 | 1800/2420 batchestrain_acc 0.289 train_loss 0.50815
| epoch7 | 2100/2420 batchestrain_acc 0.316 train_loss 0.50151
| epoch7 | 2400/2420 batchestrain_acc 0.304 train_loss 0.51635
---------------------------------------------------------------------
| epoch 7 | time:31.54s | valid_acc 0.318 valid_loss 0.497 | lr 0.010000
---------------------------------------------------------------------
| epoch8 |  300/2420 batchestrain_acc 0.315 train_loss 0.49451
| epoch8 |  600/2420 batchestrain_acc 0.341 train_loss 0.49457
| epoch8 |  900/2420 batchestrain_acc 0.332 train_loss 0.48540
| epoch8 | 1200/2420 batchestrain_acc 0.328 train_loss 0.48078
| epoch8 | 1500/2420 batchestrain_acc 0.356 train_loss 0.47262
| epoch8 | 1800/2420 batchestrain_acc 0.373 train_loss 0.46420
| epoch8 | 2100/2420 batchestrain_acc 0.356 train_loss 0.47481
| epoch8 | 2400/2420 batchestrain_acc 0.395 train_loss 0.46700
---------------------------------------------------------------------
| epoch 8 | time:32.24s | valid_acc 0.359 valid_loss 0.471 | lr 0.010000
---------------------------------------------------------------------
| epoch9 |  300/2420 batchestrain_acc 0.395 train_loss 0.46218
| epoch9 |  600/2420 batchestrain_acc 0.384 train_loss 0.45515
| epoch9 |  900/2420 batchestrain_acc 0.399 train_loss 0.45004
| epoch9 | 1200/2420 batchestrain_acc 0.428 train_loss 0.44382
| epoch9 | 1500/2420 batchestrain_acc 0.396 train_loss 0.45083
| epoch9 | 1800/2420 batchestrain_acc 0.422 train_loss 0.43863
| epoch9 | 2100/2420 batchestrain_acc 0.409 train_loss 0.44288
| epoch9 | 2400/2420 batchestrain_acc 0.399 train_loss 0.44968
---------------------------------------------------------------------
| epoch 9 | time:32.30s | valid_acc 0.447 valid_loss 0.424 | lr 0.010000
---------------------------------------------------------------------
| epoch10 |  300/2420 batchestrain_acc 0.428 train_loss 0.43655
| epoch10 |  600/2420 batchestrain_acc 0.449 train_loss 0.42489
| epoch10 |  900/2420 batchestrain_acc 0.452 train_loss 0.41688
| epoch10 | 1200/2420 batchestrain_acc 0.438 train_loss 0.43090
| epoch10 | 1500/2420 batchestrain_acc 0.432 train_loss 0.43513
| epoch10 | 1800/2420 batchestrain_acc 0.477 train_loss 0.40354
| epoch10 | 2100/2420 batchestrain_acc 0.456 train_loss 0.41597
| epoch10 | 2400/2420 batchestrain_acc 0.478 train_loss 0.41560
---------------------------------------------------------------------
| epoch 10 | time:32.10s | valid_acc 0.457 valid_loss 0.433 | lr 0.010000
---------------------------------------------------------------------

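4.2 Evaluation

Code:

# 4.2 Model evaluation
test_acc, test_loss = evaluate(valid_dataloader)
print('Model accuracy: {:5.4f}'.format(test_acc))

Output:

Model accuracy: 0.4479

No separate test set is held out, so this figure is measured on the validation split. It differs slightly from the epoch-10 valid_acc (0.457), likely because valid_dataloader reshuffles on every pass and, in this model, each prediction depends on the other samples in its batch.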

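V. Model Tuning

5.1 Switching the optimizer to Adam made the results worse; in this experiment, SGD (lr=1e-2) worked better.

5.2 Increasing the number of epochs from the original 10 to 50

The model accuracy reached 76.2%.

Full output of the 50-epoch run: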
2.0.0+cu118
0.15.1+cu118
cuda
                       0              1
0      还有双鸭山到淮阴的汽车票吗13号的   Travel-Query
1                从这里怎么回家   Travel-Query
2       随便播放一首专辑阁楼里的佛里的歌     Music-Play
3              给看一下墓王之王嘛  FilmTele-Play
4  我想看挑战两把s686打突变团竞的游戏视频     Video-Play
Building prefix dict from the default dictionary ...
Loading model from cache C:\Users\XiaoMa\AppData\Local\Temp\jieba.cache
Loading model cost 0.400 seconds.
Prefix dict has been built successfully.
label name: ['HomeAppliance-Control', 'Music-Play', 'Video-Play', 'Weather-Query', 'Other', 'Audio-Play', 'Calendar-Query', 'FilmTele-Play', 'Travel-Query', 'Radio-Listen', 'TVProgram-Play', 'Alarm-Update']
[2, 10, 13, 973, 1079, 146, 7724, 7574, 7793, 1, 186, 28]
2
| epoch1 |  300/2420 batchestrain_acc 0.114 train_loss 0.63097
| epoch1 |  600/2420 batchestrain_acc 0.122 train_loss 0.61782
| epoch1 |  900/2420 batchestrain_acc 0.113 train_loss 0.61274
| epoch1 | 1200/2420 batchestrain_acc 0.159 train_loss 0.59832
| epoch1 | 1500/2420 batchestrain_acc 0.142 train_loss 0.59311
| epoch1 | 1800/2420 batchestrain_acc 0.154 train_loss 0.58369
| epoch1 | 2100/2420 batchestrain_acc 0.158 train_loss 0.58276
| epoch1 | 2400/2420 batchestrain_acc 0.152 train_loss 0.57642
---------------------------------------------------------------------
| epoch 1 | time:51.73s | valid_acc 0.146 valid_loss 0.577 | lr 0.010000
---------------------------------------------------------------------
| epoch2 |  300/2420 batchestrain_acc 0.154 train_loss 0.57415
| epoch2 |  600/2420 batchestrain_acc 0.182 train_loss 0.56760
| epoch2 |  900/2420 batchestrain_acc 0.187 train_loss 0.57234
| epoch2 | 1200/2420 batchestrain_acc 0.171 train_loss 0.57135
| epoch2 | 1500/2420 batchestrain_acc 0.201 train_loss 0.56401
| epoch2 | 1800/2420 batchestrain_acc 0.200 train_loss 0.55670
| epoch2 | 2100/2420 batchestrain_acc 0.182 train_loss 0.56958
| epoch2 | 2400/2420 batchestrain_acc 0.186 train_loss 0.56489
---------------------------------------------------------------------
| epoch 2 | time:51.24s | valid_acc 0.205 valid_loss 0.560 | lr 0.010000
---------------------------------------------------------------------
| epoch3 |  300/2420 batchestrain_acc 0.199 train_loss 0.55668
| epoch3 |  600/2420 batchestrain_acc 0.189 train_loss 0.56699
| epoch3 |  900/2420 batchestrain_acc 0.203 train_loss 0.56305
| epoch3 | 1200/2420 batchestrain_acc 0.188 train_loss 0.56545
| epoch3 | 1500/2420 batchestrain_acc 0.213 train_loss 0.55726
| epoch3 | 1800/2420 batchestrain_acc 0.211 train_loss 0.55215
| epoch3 | 2100/2420 batchestrain_acc 0.233 train_loss 0.55473
| epoch3 | 2400/2420 batchestrain_acc 0.187 train_loss 0.55733
---------------------------------------------------------------------
| epoch 3 | time:53.07s | valid_acc 0.221 valid_loss 0.550 | lr 0.010000
---------------------------------------------------------------------
| epoch4 |  300/2420 batchestrain_acc 0.229 train_loss 0.54875
| epoch4 |  600/2420 batchestrain_acc 0.211 train_loss 0.55788
| epoch4 |  900/2420 batchestrain_acc 0.207 train_loss 0.55474
| epoch4 | 1200/2420 batchestrain_acc 0.220 train_loss 0.54492
| epoch4 | 1500/2420 batchestrain_acc 0.241 train_loss 0.55204
| epoch4 | 1800/2420 batchestrain_acc 0.237 train_loss 0.54883
| epoch4 | 2100/2420 batchestrain_acc 0.225 train_loss 0.54803
| epoch4 | 2400/2420 batchestrain_acc 0.236 train_loss 0.54462
---------------------------------------------------------------------
| epoch 4 | time:52.62s | valid_acc 0.210 valid_loss 0.561 | lr 0.010000
---------------------------------------------------------------------
| epoch5 |  300/2420 batchestrain_acc 0.229 train_loss 0.54333
| epoch5 |  600/2420 batchestrain_acc 0.235 train_loss 0.55842
| epoch5 |  900/2420 batchestrain_acc 0.247 train_loss 0.53037
| epoch5 | 1200/2420 batchestrain_acc 0.235 train_loss 0.53962
| epoch5 | 1500/2420 batchestrain_acc 0.224 train_loss 0.54456
| epoch5 | 1800/2420 batchestrain_acc 0.238 train_loss 0.54416
| epoch5 | 2100/2420 batchestrain_acc 0.237 train_loss 0.53918
| epoch5 | 2400/2420 batchestrain_acc 0.228 train_loss 0.53598
---------------------------------------------------------------------
| epoch 5 | time:51.53s | valid_acc 0.250 valid_loss 0.531 | lr 0.010000
---------------------------------------------------------------------
| epoch6 |  300/2420 batchestrain_acc 0.238 train_loss 0.53026
| epoch6 |  600/2420 batchestrain_acc 0.251 train_loss 0.54157
| epoch6 |  900/2420 batchestrain_acc 0.247 train_loss 0.53594
| epoch6 | 1200/2420 batchestrain_acc 0.240 train_loss 0.53341
| epoch6 | 1500/2420 batchestrain_acc 0.255 train_loss 0.52823
| epoch6 | 1800/2420 batchestrain_acc 0.256 train_loss 0.53342
| epoch6 | 2100/2420 batchestrain_acc 0.248 train_loss 0.53735
| epoch6 | 2400/2420 batchestrain_acc 0.251 train_loss 0.52731
---------------------------------------------------------------------
| epoch 6 | time:51.63s | valid_acc 0.275 valid_loss 0.522 | lr 0.010000
---------------------------------------------------------------------
| epoch7 |  300/2420 batchestrain_acc 0.281 train_loss 0.51962
| epoch7 |  600/2420 batchestrain_acc 0.268 train_loss 0.52640
| epoch7 |  900/2420 batchestrain_acc 0.263 train_loss 0.52401
| epoch7 | 1200/2420 batchestrain_acc 0.274 train_loss 0.51380
| epoch7 | 1500/2420 batchestrain_acc 0.297 train_loss 0.52157
| epoch7 | 1800/2420 batchestrain_acc 0.300 train_loss 0.50638
| epoch7 | 2100/2420 batchestrain_acc 0.284 train_loss 0.51751
| epoch7 | 2400/2420 batchestrain_acc 0.298 train_loss 0.50865
---------------------------------------------------------------------
| epoch 7 | time:51.01s | valid_acc 0.321 valid_loss 0.498 | lr 0.010000
---------------------------------------------------------------------
| epoch8 |  300/2420 batchestrain_acc 0.313 train_loss 0.49857
| epoch8 |  600/2420 batchestrain_acc 0.320 train_loss 0.49511
| epoch8 |  900/2420 batchestrain_acc 0.338 train_loss 0.48080
| epoch8 | 1200/2420 batchestrain_acc 0.347 train_loss 0.47775
| epoch8 | 1500/2420 batchestrain_acc 0.318 train_loss 0.49126
| epoch8 | 1800/2420 batchestrain_acc 0.350 train_loss 0.48422
| epoch8 | 2100/2420 batchestrain_acc 0.353 train_loss 0.48038
| epoch8 | 2400/2420 batchestrain_acc 0.385 train_loss 0.46344
---------------------------------------------------------------------
| epoch 8 | time:52.93s | valid_acc 0.363 valid_loss 0.479 | lr 0.010000
---------------------------------------------------------------------
| epoch9 |  300/2420 batchestrain_acc 0.367 train_loss 0.46217
| epoch9 |  600/2420 batchestrain_acc 0.375 train_loss 0.46193
| epoch9 |  900/2420 batchestrain_acc 0.391 train_loss 0.45186
| epoch9 | 1200/2420 batchestrain_acc 0.391 train_loss 0.44914
| epoch9 | 1500/2420 batchestrain_acc 0.411 train_loss 0.44841
| epoch9 | 1800/2420 batchestrain_acc 0.379 train_loss 0.45993
| epoch9 | 2100/2420 batchestrain_acc 0.364 train_loss 0.46334
| epoch9 | 2400/2420 batchestrain_acc 0.432 train_loss 0.44234
---------------------------------------------------------------------
| epoch 9 | time:45.39s | valid_acc 0.430 valid_loss 0.435 | lr 0.010000
---------------------------------------------------------------------
| epoch10 |  300/2420 batchestrain_acc 0.429 train_loss 0.42428
| epoch10 |  600/2420 batchestrain_acc 0.439 train_loss 0.44019
| epoch10 |  900/2420 batchestrain_acc 0.444 train_loss 0.42484
| epoch10 | 1200/2420 batchestrain_acc 0.421 train_loss 0.42512
| epoch10 | 1500/2420 batchestrain_acc 0.444 train_loss 0.42230
| epoch10 | 1800/2420 batchestrain_acc 0.450 train_loss 0.43166
| epoch10 | 2100/2420 batchestrain_acc 0.436 train_loss 0.41611
| epoch10 | 2400/2420 batchestrain_acc 0.486 train_loss 0.40082
---------------------------------------------------------------------
| epoch 10 | time:45.42s | valid_acc 0.483 valid_loss 0.403 | lr 0.010000
---------------------------------------------------------------------
| epoch11 |  300/2420 batchestrain_acc 0.500 train_loss 0.40324
| epoch11 |  600/2420 batchestrain_acc 0.457 train_loss 0.40968
| epoch11 |  900/2420 batchestrain_acc 0.475 train_loss 0.39906
| epoch11 | 1200/2420 batchestrain_acc 0.486 train_loss 0.40340
| epoch11 | 1500/2420 batchestrain_acc 0.501 train_loss 0.38556
| epoch11 | 1800/2420 batchestrain_acc 0.480 train_loss 0.39941
| epoch11 | 2100/2420 batchestrain_acc 0.499 train_loss 0.39277
| epoch11 | 2400/2420 batchestrain_acc 0.531 train_loss 0.37051
---------------------------------------------------------------------
| epoch 11 | time:51.40s | valid_acc 0.488 valid_loss 0.391 | lr 0.010000
---------------------------------------------------------------------
| epoch12 |  300/2420 batchestrain_acc 0.523 train_loss 0.37236
| epoch12 |  600/2420 batchestrain_acc 0.541 train_loss 0.36260
| epoch12 |  900/2420 batchestrain_acc 0.552 train_loss 0.35538
| epoch12 | 1200/2420 batchestrain_acc 0.548 train_loss 0.35951
| epoch12 | 1500/2420 batchestrain_acc 0.531 train_loss 0.36266
| epoch12 | 1800/2420 batchestrain_acc 0.543 train_loss 0.36087
| epoch12 | 2100/2420 batchestrain_acc 0.545 train_loss 0.36039
| epoch12 | 2400/2420 batchestrain_acc 0.546 train_loss 0.37050
---------------------------------------------------------------------
| epoch 12 | time:51.89s | valid_acc 0.533 valid_loss 0.357 | lr 0.010000
---------------------------------------------------------------------
| epoch13 |  300/2420 batchestrain_acc 0.586 train_loss 0.33823
| epoch13 |  600/2420 batchestrain_acc 0.593 train_loss 0.34383
| epoch13 |  900/2420 batchestrain_acc 0.557 train_loss 0.36033
| epoch13 | 1200/2420 batchestrain_acc 0.567 train_loss 0.33469
| epoch13 | 1500/2420 batchestrain_acc 0.599 train_loss 0.33413
| epoch13 | 1800/2420 batchestrain_acc 0.602 train_loss 0.31916
| epoch13 | 2100/2420 batchestrain_acc 0.562 train_loss 0.35244
| epoch13 | 2400/2420 batchestrain_acc 0.593 train_loss 0.32772
---------------------------------------------------------------------
| epoch 13 | time:51.13s | valid_acc 0.584 valid_loss 0.340 | lr 0.010000
---------------------------------------------------------------------
| epoch14 |  300/2420 batchestrain_acc 0.640 train_loss 0.30234
| epoch14 |  600/2420 batchestrain_acc 0.628 train_loss 0.31170
| epoch14 |  900/2420 batchestrain_acc 0.581 train_loss 0.32656
| epoch14 | 1200/2420 batchestrain_acc 0.616 train_loss 0.31603
| epoch14 | 1500/2420 batchestrain_acc 0.590 train_loss 0.32491
| epoch14 | 1800/2420 batchestrain_acc 0.604 train_loss 0.31986
| epoch14 | 2100/2420 batchestrain_acc 0.602 train_loss 0.32168
| epoch14 | 2400/2420 batchestrain_acc 0.587 train_loss 0.32910
---------------------------------------------------------------------
| epoch 14 | time:46.20s | valid_acc 0.619 valid_loss 0.314 | lr 0.010000
---------------------------------------------------------------------
| epoch15 |  300/2420 batchestrain_acc 0.601 train_loss 0.30912
| epoch15 |  600/2420 batchestrain_acc 0.637 train_loss 0.30154
| epoch15 |  900/2420 batchestrain_acc 0.636 train_loss 0.29399
| epoch15 | 1200/2420 batchestrain_acc 0.623 train_loss 0.30342
| epoch15 | 1500/2420 batchestrain_acc 0.630 train_loss 0.30157
| epoch15 | 1800/2420 batchestrain_acc 0.639 train_loss 0.29125
| epoch15 | 2100/2420 batchestrain_acc 0.616 train_loss 0.31402
| epoch15 | 2400/2420 batchestrain_acc 0.613 train_loss 0.31384
---------------------------------------------------------------------
| epoch 15 | time:52.92s | valid_acc 0.614 valid_loss 0.317 | lr 0.010000
---------------------------------------------------------------------
| epoch16 |  300/2420 batchestrain_acc 0.641 train_loss 0.29299
| epoch16 |  600/2420 batchestrain_acc 0.630 train_loss 0.30191
| epoch16 |  900/2420 batchestrain_acc 0.656 train_loss 0.28631
| epoch16 | 1200/2420 batchestrain_acc 0.670 train_loss 0.27210
| epoch16 | 1500/2420 batchestrain_acc 0.636 train_loss 0.30204
| epoch16 | 1800/2420 batchestrain_acc 0.657 train_loss 0.28607
| epoch16 | 2100/2420 batchestrain_acc 0.682 train_loss 0.26433
| epoch16 | 2400/2420 batchestrain_acc 0.647 train_loss 0.28930
---------------------------------------------------------------------
| epoch 16 | time:55.40s | valid_acc 0.650 valid_loss 0.280 | lr 0.010000
---------------------------------------------------------------------
| epoch17 |  300/2420 batchestrain_acc 0.675 train_loss 0.26864
| epoch17 |  600/2420 batchestrain_acc 0.675 train_loss 0.26734
| epoch17 |  900/2420 batchestrain_acc 0.657 train_loss 0.27574
| epoch17 | 1200/2420 batchestrain_acc 0.639 train_loss 0.28399
| epoch17 | 1500/2420 batchestrain_acc 0.670 train_loss 0.27347
| epoch17 | 1800/2420 batchestrain_acc 0.679 train_loss 0.26036
| epoch17 | 2100/2420 batchestrain_acc 0.671 train_loss 0.27521
| epoch17 | 2400/2420 batchestrain_acc 0.677 train_loss 0.26035
---------------------------------------------------------------------
| epoch 17 | time:47.33s | valid_acc 0.650 valid_loss 0.312 | lr 0.010000
---------------------------------------------------------------------
| epoch18 |  300/2420 batchestrain_acc 0.685 train_loss 0.26377
| epoch18 |  600/2420 batchestrain_acc 0.688 train_loss 0.26276
| epoch18 |  900/2420 batchestrain_acc 0.665 train_loss 0.27293
| epoch18 | 1200/2420 batchestrain_acc 0.690 train_loss 0.26074
| epoch18 | 1500/2420 batchestrain_acc 0.703 train_loss 0.24556
| epoch18 | 1800/2420 batchestrain_acc 0.674 train_loss 0.27331
| epoch18 | 2100/2420 batchestrain_acc 0.694 train_loss 0.25756
| epoch18 | 2400/2420 batchestrain_acc 0.693 train_loss 0.26167
---------------------------------------------------------------------
| epoch 18 | time:52.69s | valid_acc 0.660 valid_loss 0.285 | lr 0.010000
---------------------------------------------------------------------
| epoch19 |  300/2420 batchestrain_acc 0.715 train_loss 0.23965
| epoch19 |  600/2420 batchestrain_acc 0.708 train_loss 0.24142
| epoch19 |  900/2420 batchestrain_acc 0.716 train_loss 0.24902
| epoch19 | 1200/2420 batchestrain_acc 0.677 train_loss 0.27742
| epoch19 | 1500/2420 batchestrain_acc 0.689 train_loss 0.25587
| epoch19 | 1800/2420 batchestrain_acc 0.711 train_loss 0.24643
| epoch19 | 2100/2420 batchestrain_acc 0.687 train_loss 0.25684
| epoch19 | 2400/2420 batchestrain_acc 0.688 train_loss 0.25094
---------------------------------------------------------------------
| epoch 19 | time:51.46s | valid_acc 0.679 valid_loss 0.270 | lr 0.010000
---------------------------------------------------------------------
| epoch20 |  300/2420 batchestrain_acc 0.706 train_loss 0.23483
| epoch20 |  600/2420 batchestrain_acc 0.708 train_loss 0.24640
| epoch20 |  900/2420 batchestrain_acc 0.722 train_loss 0.22958
| epoch20 | 1200/2420 batchestrain_acc 0.714 train_loss 0.23760
| epoch20 | 1500/2420 batchestrain_acc 0.717 train_loss 0.23799
| epoch20 | 1800/2420 batchestrain_acc 0.705 train_loss 0.24246
| epoch20 | 2100/2420 batchestrain_acc 0.723 train_loss 0.23437
| epoch20 | 2400/2420 batchestrain_acc 0.712 train_loss 0.24185
---------------------------------------------------------------------
| epoch 20 | time:48.92s | valid_acc 0.675 valid_loss 0.287 | lr 0.010000
---------------------------------------------------------------------
| epoch21 |  300/2420 batchestrain_acc 0.731 train_loss 0.22429
| epoch21 |  600/2420 batchestrain_acc 0.719 train_loss 0.22944
| epoch21 |  900/2420 batchestrain_acc 0.733 train_loss 0.22685
| epoch21 | 1200/2420 batchestrain_acc 0.743 train_loss 0.22088
| epoch21 | 1500/2420 batchestrain_acc 0.718 train_loss 0.23528
| epoch21 | 1800/2420 batchestrain_acc 0.739 train_loss 0.22832
| epoch21 | 2100/2420 batchestrain_acc 0.718 train_loss 0.23742
| epoch21 | 2400/2420 batchestrain_acc 0.727 train_loss 0.23075
---------------------------------------------------------------------
| epoch 21 | time:51.53s | valid_acc 0.690 valid_loss 0.280 | lr 0.010000
---------------------------------------------------------------------
| epoch22 |  300/2420 batchestrain_acc 0.744 train_loss 0.21029
| epoch22 |  600/2420 batchestrain_acc 0.736 train_loss 0.22299
| epoch22 |  900/2420 batchestrain_acc 0.738 train_loss 0.22259
| epoch22 | 1200/2420 batchestrain_acc 0.731 train_loss 0.22499
| epoch22 | 1500/2420 batchestrain_acc 0.713 train_loss 0.23047
| epoch22 | 1800/2420 batchestrain_acc 0.733 train_loss 0.23300
| epoch22 | 2100/2420 batchestrain_acc 0.729 train_loss 0.22323
| epoch22 | 2400/2420 batchestrain_acc 0.732 train_loss 0.22367
---------------------------------------------------------------------
| epoch 22 | time:49.76s | valid_acc 0.677 valid_loss 0.279 | lr 0.010000
---------------------------------------------------------------------
| epoch23 |  300/2420 batchestrain_acc 0.747 train_loss 0.21396
| epoch23 |  600/2420 batchestrain_acc 0.730 train_loss 0.22773
| epoch23 |  900/2420 batchestrain_acc 0.766 train_loss 0.19834
| epoch23 | 1200/2420 batchestrain_acc 0.748 train_loss 0.21389
| epoch23 | 1500/2420 batchestrain_acc 0.740 train_loss 0.22346
| epoch23 | 1800/2420 batchestrain_acc 0.730 train_loss 0.22744
| epoch23 | 2100/2420 batchestrain_acc 0.738 train_loss 0.22318
| epoch23 | 2400/2420 batchestrain_acc 0.743 train_loss 0.21736
---------------------------------------------------------------------
| epoch 23 | time:47.72s | valid_acc 0.698 valid_loss 0.269 | lr 0.010000
---------------------------------------------------------------------
| epoch24 |  300/2420 batchestrain_acc 0.749 train_loss 0.20925
| epoch24 |  600/2420 batchestrain_acc 0.757 train_loss 0.20227
| epoch24 |  900/2420 batchestrain_acc 0.761 train_loss 0.20799
| epoch24 | 1200/2420 batchestrain_acc 0.777 train_loss 0.18896
| epoch24 | 1500/2420 batchestrain_acc 0.766 train_loss 0.19932
| epoch24 | 1800/2420 batchestrain_acc 0.752 train_loss 0.20789
| epoch24 | 2100/2420 batchestrain_acc 0.773 train_loss 0.19881
| epoch24 | 2400/2420 batchestrain_acc 0.752 train_loss 0.21283
---------------------------------------------------------------------
| epoch 24 | time:47.79s | valid_acc 0.655 valid_loss 0.306 | lr 0.010000
---------------------------------------------------------------------
| epoch25 |  300/2420 batchestrain_acc 0.778 train_loss 0.18798
| epoch25 |  600/2420 batchestrain_acc 0.754 train_loss 0.20311
| epoch25 |  900/2420 batchestrain_acc 0.769 train_loss 0.19548
| epoch25 | 1200/2420 batchestrain_acc 0.773 train_loss 0.19514
| epoch25 | 1500/2420 batchestrain_acc 0.769 train_loss 0.20732
| epoch25 | 1800/2420 batchestrain_acc 0.763 train_loss 0.19751
| epoch25 | 2100/2420 batchestrain_acc 0.755 train_loss 0.20002
| epoch25 | 2400/2420 batchestrain_acc 0.767 train_loss 0.19577
---------------------------------------------------------------------
| epoch 25 | time:47.72s | valid_acc 0.700 valid_loss 0.267 | lr 0.010000
---------------------------------------------------------------------
| epoch26 |  300/2420 batchestrain_acc 0.775 train_loss 0.18995
| epoch26 |  600/2420 batchestrain_acc 0.782 train_loss 0.18044
| epoch26 |  900/2420 batchestrain_acc 0.776 train_loss 0.19263
| epoch26 | 1200/2420 batchestrain_acc 0.772 train_loss 0.18196
| epoch26 | 1500/2420 batchestrain_acc 0.782 train_loss 0.19367
| epoch26 | 1800/2420 batchestrain_acc 0.764 train_loss 0.19917
| epoch26 | 2100/2420 batchestrain_acc 0.757 train_loss 0.20754
| epoch26 | 2400/2420 batchestrain_acc 0.763 train_loss 0.20010
---------------------------------------------------------------------
| epoch 26 | time:47.80s | valid_acc 0.714 valid_loss 0.259 | lr 0.010000
---------------------------------------------------------------------
| epoch27 |  300/2420 batchestrain_acc 0.783 train_loss 0.18380
| epoch27 |  600/2420 batchestrain_acc 0.787 train_loss 0.17971
| epoch27 |  900/2420 batchestrain_acc 0.784 train_loss 0.18646
| epoch27 | 1200/2420 batchestrain_acc 0.796 train_loss 0.17358
| epoch27 | 1500/2420 batchestrain_acc 0.789 train_loss 0.19063
| epoch27 | 1800/2420 batchestrain_acc 0.802 train_loss 0.17082
| epoch27 | 2100/2420 batchestrain_acc 0.772 train_loss 0.19022
| epoch27 | 2400/2420 batchestrain_acc 0.768 train_loss 0.19923
---------------------------------------------------------------------
| epoch 27 | time:49.38s | valid_acc 0.713 valid_loss 0.264 | lr 0.010000
---------------------------------------------------------------------
| epoch28 |  300/2420 batchestrain_acc 0.806 train_loss 0.16722
| epoch28 |  600/2420 batchestrain_acc 0.802 train_loss 0.16295
| epoch28 |  900/2420 batchestrain_acc 0.773 train_loss 0.19564
| epoch28 | 1200/2420 batchestrain_acc 0.788 train_loss 0.18725
| epoch28 | 1500/2420 batchestrain_acc 0.801 train_loss 0.16580
| epoch28 | 1800/2420 batchestrain_acc 0.797 train_loss 0.17730
| epoch28 | 2100/2420 batchestrain_acc 0.787 train_loss 0.18217
| epoch28 | 2400/2420 batchestrain_acc 0.773 train_loss 0.20249
---------------------------------------------------------------------
| epoch 28 | time:47.99s | valid_acc 0.724 valid_loss 0.255 | lr 0.010000
---------------------------------------------------------------------
| epoch29 |  300/2420 batchestrain_acc 0.777 train_loss 0.17826
| epoch29 |  600/2420 batchestrain_acc 0.797 train_loss 0.16099
| epoch29 |  900/2420 batchestrain_acc 0.821 train_loss 0.15533
| epoch29 | 1200/2420 batchestrain_acc 0.784 train_loss 0.18379
| epoch29 | 1500/2420 batchestrain_acc 0.780 train_loss 0.17715
| epoch29 | 1800/2420 batchestrain_acc 0.795 train_loss 0.17738
| epoch29 | 2100/2420 batchestrain_acc 0.790 train_loss 0.17705
| epoch29 | 2400/2420 batchestrain_acc 0.782 train_loss 0.18708
---------------------------------------------------------------------
| epoch 29 | time:48.01s | valid_acc 0.719 valid_loss 0.252 | lr 0.010000
---------------------------------------------------------------------
| epoch30 |  300/2420 batchestrain_acc 0.821 train_loss 0.16117
| epoch30 |  600/2420 batchestrain_acc 0.805 train_loss 0.15564
| epoch30 |  900/2420 batchestrain_acc 0.803 train_loss 0.16268
| epoch30 | 1200/2420 batchestrain_acc 0.817 train_loss 0.16171
| epoch30 | 1500/2420 batchestrain_acc 0.810 train_loss 0.16449
| epoch30 | 1800/2420 batchestrain_acc 0.795 train_loss 0.17510
| epoch30 | 2100/2420 batchestrain_acc 0.779 train_loss 0.18525
| epoch30 | 2400/2420 batchestrain_acc 0.809 train_loss 0.16960
---------------------------------------------------------------------
| epoch 30 | time:48.12s | valid_acc 0.715 valid_loss 0.264 | lr 0.010000
---------------------------------------------------------------------
| epoch31 |  300/2420 batchestrain_acc 0.811 train_loss 0.15723
| epoch31 |  600/2420 batchestrain_acc 0.790 train_loss 0.16986
| epoch31 |  900/2420 batchestrain_acc 0.796 train_loss 0.17329
| epoch31 | 1200/2420 batchestrain_acc 0.808 train_loss 0.16572
| epoch31 | 1500/2420 batchestrain_acc 0.797 train_loss 0.16919
| epoch31 | 1800/2420 batchestrain_acc 0.802 train_loss 0.16382
| epoch31 | 2100/2420 batchestrain_acc 0.807 train_loss 0.15687
| epoch31 | 2400/2420 batchestrain_acc 0.775 train_loss 0.18029
---------------------------------------------------------------------
| epoch 31 | time:48.27s | valid_acc 0.721 valid_loss 0.264 | lr 0.010000
---------------------------------------------------------------------
| epoch32 |  300/2420 batchestrain_acc 0.807 train_loss 0.17263
| epoch32 |  600/2420 batchestrain_acc 0.817 train_loss 0.15635
| epoch32 |  900/2420 batchestrain_acc 0.792 train_loss 0.17293
| epoch32 | 1200/2420 batchestrain_acc 0.814 train_loss 0.15874
| epoch32 | 1500/2420 batchestrain_acc 0.799 train_loss 0.17187
| epoch32 | 1800/2420 batchestrain_acc 0.815 train_loss 0.16326
| epoch32 | 2100/2420 batchestrain_acc 0.799 train_loss 0.17070
| epoch32 | 2400/2420 batchestrain_acc 0.816 train_loss 0.15760
---------------------------------------------------------------------
| epoch 32 | time:50.59s | valid_acc 0.752 valid_loss 0.242 | lr 0.010000
---------------------------------------------------------------------
| epoch33 |  300/2420 batchestrain_acc 0.826 train_loss 0.14967
| epoch33 |  600/2420 batchestrain_acc 0.834 train_loss 0.13926
| epoch33 |  900/2420 batchestrain_acc 0.817 train_loss 0.16265
| epoch33 | 1200/2420 batchestrain_acc 0.820 train_loss 0.16022
| epoch33 | 1500/2420 batchestrain_acc 0.812 train_loss 0.15494
| epoch33 | 1800/2420 batchestrain_acc 0.816 train_loss 0.15696
| epoch33 | 2100/2420 batchestrain_acc 0.831 train_loss 0.15118
| epoch33 | 2400/2420 batchestrain_acc 0.824 train_loss 0.15765
---------------------------------------------------------------------
| epoch 33 | time:53.45s | valid_acc 0.711 valid_loss 0.284 | lr 0.010000
---------------------------------------------------------------------
| epoch34 |  300/2420 batchestrain_acc 0.838 train_loss 0.13297
| epoch34 |  600/2420 batchestrain_acc 0.836 train_loss 0.14296
| epoch34 |  900/2420 batchestrain_acc 0.809 train_loss 0.15700
| epoch34 | 1200/2420 batchestrain_acc 0.814 train_loss 0.16003
| epoch34 | 1500/2420 batchestrain_acc 0.810 train_loss 0.16390
| epoch34 | 1800/2420 batchestrain_acc 0.831 train_loss 0.13682
| epoch34 | 2100/2420 batchestrain_acc 0.828 train_loss 0.15188
| epoch34 | 2400/2420 batchestrain_acc 0.807 train_loss 0.16033
---------------------------------------------------------------------
| epoch 34 | time:53.36s | valid_acc 0.746 valid_loss 0.243 | lr 0.010000
---------------------------------------------------------------------
| epoch35 |  300/2420 batchestrain_acc 0.838 train_loss 0.13975
| epoch35 |  600/2420 batchestrain_acc 0.823 train_loss 0.14537
| epoch35 |  900/2420 batchestrain_acc 0.852 train_loss 0.12856
| epoch35 | 1200/2420 batchestrain_acc 0.840 train_loss 0.13151
| epoch35 | 1500/2420 batchestrain_acc 0.854 train_loss 0.12872
| epoch35 | 1800/2420 batchestrain_acc 0.828 train_loss 0.14587
| epoch35 | 2100/2420 batchestrain_acc 0.823 train_loss 0.15114
| epoch35 | 2400/2420 batchestrain_acc 0.823 train_loss 0.15159
---------------------------------------------------------------------
| epoch 35 | time:49.54s | valid_acc 0.743 valid_loss 0.249 | lr 0.010000
---------------------------------------------------------------------
| epoch36 |  300/2420 batchestrain_acc 0.837 train_loss 0.13681
| epoch36 |  600/2420 batchestrain_acc 0.847 train_loss 0.13260
| epoch36 |  900/2420 batchestrain_acc 0.825 train_loss 0.15109
| epoch36 | 1200/2420 batchestrain_acc 0.845 train_loss 0.13915
| epoch36 | 1500/2420 batchestrain_acc 0.857 train_loss 0.12454
| epoch36 | 1800/2420 batchestrain_acc 0.816 train_loss 0.15596
| epoch36 | 2100/2420 batchestrain_acc 0.853 train_loss 0.12449
| epoch36 | 2400/2420 batchestrain_acc 0.820 train_loss 0.15226
---------------------------------------------------------------------
| epoch 36 | time:46.81s | valid_acc 0.758 valid_loss 0.245 | lr 0.010000
---------------------------------------------------------------------
| epoch37 |  300/2420 batchestrain_acc 0.861 train_loss 0.13067
| epoch37 |  600/2420 batchestrain_acc 0.843 train_loss 0.14057
| epoch37 |  900/2420 batchestrain_acc 0.826 train_loss 0.14100
| epoch37 | 1200/2420 batchestrain_acc 0.847 train_loss 0.12774
| epoch37 | 1500/2420 batchestrain_acc 0.849 train_loss 0.12791
| epoch37 | 1800/2420 batchestrain_acc 0.821 train_loss 0.14428
| epoch37 | 2100/2420 batchestrain_acc 0.835 train_loss 0.14542
| epoch37 | 2400/2420 batchestrain_acc 0.850 train_loss 0.13312
---------------------------------------------------------------------
| epoch 37 | time:46.83s | valid_acc 0.745 valid_loss 0.257 | lr 0.010000
---------------------------------------------------------------------
| epoch38 |  300/2420 batchestrain_acc 0.863 train_loss 0.11401
| epoch38 |  600/2420 batchestrain_acc 0.862 train_loss 0.11872
| epoch38 |  900/2420 batchestrain_acc 0.859 train_loss 0.11716
| epoch38 | 1200/2420 batchestrain_acc 0.852 train_loss 0.12466
| epoch38 | 1500/2420 batchestrain_acc 0.873 train_loss 0.11293
| epoch38 | 1800/2420 batchestrain_acc 0.828 train_loss 0.14327
| epoch38 | 2100/2420 batchestrain_acc 0.830 train_loss 0.13937
| epoch38 | 2400/2420 batchestrain_acc 0.838 train_loss 0.14168
---------------------------------------------------------------------
| epoch 38 | time:47.06s | valid_acc 0.712 valid_loss 0.292 | lr 0.010000
---------------------------------------------------------------------
| epoch39 |  300/2420 batchestrain_acc 0.845 train_loss 0.12823
| epoch39 |  600/2420 batchestrain_acc 0.867 train_loss 0.11793
| epoch39 |  900/2420 batchestrain_acc 0.853 train_loss 0.12088
| epoch39 | 1200/2420 batchestrain_acc 0.843 train_loss 0.12443
| epoch39 | 1500/2420 batchestrain_acc 0.858 train_loss 0.13443
| epoch39 | 1800/2420 batchestrain_acc 0.846 train_loss 0.12992
| epoch39 | 2100/2420 batchestrain_acc 0.875 train_loss 0.10862
| epoch39 | 2400/2420 batchestrain_acc 0.848 train_loss 0.12784
---------------------------------------------------------------------
| epoch 39 | time:48.13s | valid_acc 0.740 valid_loss 0.267 | lr 0.010000
---------------------------------------------------------------------
| epoch40 |  300/2420 batchestrain_acc 0.860 train_loss 0.11646
| epoch40 |  600/2420 batchestrain_acc 0.878 train_loss 0.10498
| epoch40 |  900/2420 batchestrain_acc 0.841 train_loss 0.14007
| epoch40 | 1200/2420 batchestrain_acc 0.834 train_loss 0.13898
| epoch40 | 1500/2420 batchestrain_acc 0.858 train_loss 0.12757
| epoch40 | 1800/2420 batchestrain_acc 0.858 train_loss 0.12010
| epoch40 | 2100/2420 batchestrain_acc 0.844 train_loss 0.11948
| epoch40 | 2400/2420 batchestrain_acc 0.852 train_loss 0.13280
---------------------------------------------------------------------
| epoch 40 | time:47.16s | valid_acc 0.729 valid_loss 0.277 | lr 0.010000
---------------------------------------------------------------------
| epoch41 |  300/2420 batchestrain_acc 0.874 train_loss 0.10229
| epoch41 |  600/2420 batchestrain_acc 0.850 train_loss 0.12475
| epoch41 |  900/2420 batchestrain_acc 0.873 train_loss 0.11382
| epoch41 | 1200/2420 batchestrain_acc 0.844 train_loss 0.12744
| epoch41 | 1500/2420 batchestrain_acc 0.851 train_loss 0.12652
| epoch41 | 1800/2420 batchestrain_acc 0.855 train_loss 0.11761
| epoch41 | 2100/2420 batchestrain_acc 0.857 train_loss 0.11470
| epoch41 | 2400/2420 batchestrain_acc 0.831 train_loss 0.14415
---------------------------------------------------------------------
| epoch 41 | time:46.99s | valid_acc 0.735 valid_loss 0.252 | lr 0.010000
---------------------------------------------------------------------
| epoch42 |  300/2420 batchestrain_acc 0.883 train_loss 0.10204
| epoch42 |  600/2420 batchestrain_acc 0.868 train_loss 0.10268
| epoch42 |  900/2420 batchestrain_acc 0.840 train_loss 0.14107
| epoch42 | 1200/2420 batchestrain_acc 0.863 train_loss 0.12050
| epoch42 | 1500/2420 batchestrain_acc 0.850 train_loss 0.12474
| epoch42 | 1800/2420 batchestrain_acc 0.858 train_loss 0.11658
| epoch42 | 2100/2420 batchestrain_acc 0.866 train_loss 0.11895
| epoch42 | 2400/2420 batchestrain_acc 0.864 train_loss 0.11654
---------------------------------------------------------------------
| epoch 42 | time:46.76s | valid_acc 0.745 valid_loss 0.257 | lr 0.010000
---------------------------------------------------------------------
| epoch43 |  300/2420 batchestrain_acc 0.845 train_loss 0.12530
| epoch43 |  600/2420 batchestrain_acc 0.863 train_loss 0.11793
| epoch43 |  900/2420 batchestrain_acc 0.881 train_loss 0.10258
| epoch43 | 1200/2420 batchestrain_acc 0.884 train_loss 0.10542
| epoch43 | 1500/2420 batchestrain_acc 0.857 train_loss 0.11869
| epoch43 | 1800/2420 batchestrain_acc 0.863 train_loss 0.11984
| epoch43 | 2100/2420 batchestrain_acc 0.862 train_loss 0.12085
| epoch43 | 2400/2420 batchestrain_acc 0.864 train_loss 0.11249
---------------------------------------------------------------------
| epoch 43 | time:47.54s | valid_acc 0.728 valid_loss 0.270 | lr 0.010000
---------------------------------------------------------------------
| epoch44 |  300/2420 batchestrain_acc 0.871 train_loss 0.11203
| epoch44 |  600/2420 batchestrain_acc 0.876 train_loss 0.10731
| epoch44 |  900/2420 batchestrain_acc 0.876 train_loss 0.09843
| epoch44 | 1200/2420 batchestrain_acc 0.877 train_loss 0.10445
| epoch44 | 1500/2420 batchestrain_acc 0.844 train_loss 0.13204
| epoch44 | 1800/2420 batchestrain_acc 0.856 train_loss 0.12643
| epoch44 | 2100/2420 batchestrain_acc 0.858 train_loss 0.12898
| epoch44 | 2400/2420 batchestrain_acc 0.853 train_loss 0.13501
---------------------------------------------------------------------
| epoch 44 | time:44.21s | valid_acc 0.754 valid_loss 0.248 | lr 0.010000
---------------------------------------------------------------------
| epoch45 |  300/2420 batchestrain_acc 0.870 train_loss 0.11244
| epoch45 |  600/2420 batchestrain_acc 0.868 train_loss 0.11606
| epoch45 |  900/2420 batchestrain_acc 0.888 train_loss 0.09326
| epoch45 | 1200/2420 batchestrain_acc 0.876 train_loss 0.10719
| epoch45 | 1500/2420 batchestrain_acc 0.870 train_loss 0.10922
| epoch45 | 1800/2420 batchestrain_acc 0.882 train_loss 0.10089
| epoch45 | 2100/2420 batchestrain_acc 0.881 train_loss 0.10368
| epoch45 | 2400/2420 batchestrain_acc 0.870 train_loss 0.11590
---------------------------------------------------------------------
| epoch 45 | time:50.51s | valid_acc 0.762 valid_loss 0.248 | lr 0.010000
---------------------------------------------------------------------
| epoch46 |  300/2420 batchestrain_acc 0.886 train_loss 0.09462
| epoch46 |  600/2420 batchestrain_acc 0.873 train_loss 0.10219
| epoch46 |  900/2420 batchestrain_acc 0.875 train_loss 0.10946
| epoch46 | 1200/2420 batchestrain_acc 0.873 train_loss 0.10926
| epoch46 | 1500/2420 batchestrain_acc 0.873 train_loss 0.10560
| epoch46 | 1800/2420 batchestrain_acc 0.882 train_loss 0.10014
| epoch46 | 2100/2420 batchestrain_acc 0.884 train_loss 0.10224
| epoch46 | 2400/2420 batchestrain_acc 0.892 train_loss 0.10056
---------------------------------------------------------------------
| epoch 46 | time:47.69s | valid_acc 0.755 valid_loss 0.248 | lr 0.010000
---------------------------------------------------------------------
| epoch47 |  300/2420 batchestrain_acc 0.873 train_loss 0.10631
| epoch47 |  600/2420 batchestrain_acc 0.885 train_loss 0.09509
| epoch47 |  900/2420 batchestrain_acc 0.870 train_loss 0.11216
| epoch47 | 1200/2420 batchestrain_acc 0.882 train_loss 0.09893
| epoch47 | 1500/2420 batchestrain_acc 0.885 train_loss 0.10097
| epoch47 | 1800/2420 batchestrain_acc 0.880 train_loss 0.10543
| epoch47 | 2100/2420 batchestrain_acc 0.879 train_loss 0.10007
| epoch47 | 2400/2420 batchestrain_acc 0.877 train_loss 0.10747
---------------------------------------------------------------------
| epoch 47 | time:49.61s | valid_acc 0.740 valid_loss 0.273 | lr 0.010000
---------------------------------------------------------------------
| epoch48 |  300/2420 batchestrain_acc 0.880 train_loss 0.10058
| epoch48 |  600/2420 batchestrain_acc 0.861 train_loss 0.11509
| epoch48 |  900/2420 batchestrain_acc 0.887 train_loss 0.09902
| epoch48 | 1200/2420 batchestrain_acc 0.867 train_loss 0.11345
| epoch48 | 1500/2420 batchestrain_acc 0.888 train_loss 0.09145
| epoch48 | 1800/2420 batchestrain_acc 0.889 train_loss 0.10624
| epoch48 | 2100/2420 batchestrain_acc 0.886 train_loss 0.09808
| epoch48 | 2400/2420 batchestrain_acc 0.900 train_loss 0.09340
---------------------------------------------------------------------
| epoch 48 | time:47.14s | valid_acc 0.740 valid_loss 0.288 | lr 0.010000
---------------------------------------------------------------------
| epoch49 |  300/2420 batchestrain_acc 0.901 train_loss 0.09128
| epoch49 |  600/2420 batchestrain_acc 0.870 train_loss 0.11369
| epoch49 |  900/2420 batchestrain_acc 0.887 train_loss 0.09796
| epoch49 | 1200/2420 batchestrain_acc 0.886 train_loss 0.10062
| epoch49 | 1500/2420 batchestrain_acc 0.882 train_loss 0.10779
| epoch49 | 1800/2420 batchestrain_acc 0.888 train_loss 0.10213
| epoch49 | 2100/2420 batchestrain_acc 0.868 train_loss 0.11080
| epoch49 | 2400/2420 batchestrain_acc 0.868 train_loss 0.11578
---------------------------------------------------------------------
| epoch 49 | time:47.14s | valid_acc 0.734 valid_loss 0.303 | lr 0.010000
---------------------------------------------------------------------
| epoch50 |  300/2420 batchestrain_acc 0.900 train_loss 0.08726
| epoch50 |  600/2420 batchestrain_acc 0.887 train_loss 0.09685
| epoch50 |  900/2420 batchestrain_acc 0.895 train_loss 0.08955
| epoch50 | 1200/2420 batchestrain_acc 0.877 train_loss 0.10595
| epoch50 | 1500/2420 batchestrain_acc 0.897 train_loss 0.09040
| epoch50 | 1800/2420 batchestrain_acc 0.890 train_loss 0.09288
| epoch50 | 2100/2420 batchestrain_acc 0.905 train_loss 0.08954
| epoch50 | 2400/2420 batchestrain_acc 0.904 train_loss 0.08663
---------------------------------------------------------------------
| epoch 50 | time:47.19s | valid_acc 0.762 valid_loss 0.258 | lr 0.010000
---------------------------------------------------------------------
Model accuracy: 0.7620

Process finished with exit code 0

5.3 Other things to try: adjusting the learning rate, batch_size, etc.
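One way to organize these experiments is a small grid search over learning rate and batch size. The sketch below is a hypothetical helper, not part of the original code: `grid_search`, the `train_and_evaluate` callback and `fake_run` are assumed names. The callback is expected to rebuild the DataLoader with the given `batch_size`, train the model with the given learning rate, and return validation accuracy. A single 50-epoch run here takes tens of minutes, so keep the grid small or use fewer epochs per configuration.

```python
import itertools
from typing import Callable, Dict, Iterable, Tuple

Config = Tuple[float, int]  # (learning_rate, batch_size)


def grid_search(
    learning_rates: Iterable[float],
    batch_sizes: Iterable[int],
    train_and_evaluate: Callable[[float, int], float],
) -> Tuple[Config, Dict[Config, float]]:
    """Try every (lr, batch_size) pair; return the best pair and all scores."""
    scores: Dict[Config, float] = {}
    for lr, batch_size in itertools.product(learning_rates, batch_sizes):
        # Hypothetical callback: build the DataLoader with batch_size,
        # train with this lr, and return the validation accuracy.
        scores[(lr, batch_size)] = train_and_evaluate(lr, batch_size)
    best = max(scores, key=scores.get)  # config with the highest accuracy
    return best, scores


def fake_run(lr: float, batch_size: int) -> float:
    """Stand-in for a real training run; its score peaks at lr=1e-3, batch_size=64."""
    return 1.0 - abs(lr - 1e-3) * 100 - abs(batch_size - 64) / 1000


best, scores = grid_search([1e-2, 1e-3, 1e-4], [32, 64, 128], fake_run)
print(best)  # (0.001, 64)
```

In practice, replace `fake_run` with a function that wraps the training and evaluation functions defined earlier.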


Summary

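This week I applied a Transformer to text classification and reached an accuracy above 70% (0.7620 after 50 epochs). One problem remains: training is slow. Epochs 18 to 50 each took about 44 to 54 s.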
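The log suggests one inexpensive way to reduce runtime. After about epoch 36, valid_acc only oscillates between 0.71 and 0.76, while train_acc keeps climbing toward 0.90, which is a sign of overfitting. Early stopping, which was not used in this run, ends training once validation accuracy has not improved for `patience` consecutive epochs. The sketch below replays the logged valid_acc values through a hypothetical helper, `early_stopping_scan`. It covers only epochs 18 to 50, because earlier epochs fall outside this log excerpt. In a real training loop, the same check would run after each evaluation.

```python
from typing import List, Optional, Tuple

# valid_acc for epochs 18..50, copied from the training log above
VALID_ACC = [
    0.660, 0.679, 0.675, 0.690, 0.677, 0.698, 0.655, 0.700, 0.714, 0.713,  # epochs 18-27
    0.724, 0.719, 0.715, 0.721, 0.752, 0.711, 0.746, 0.743, 0.758, 0.745,  # epochs 28-37
    0.712, 0.740, 0.729, 0.735, 0.745, 0.728, 0.754, 0.762, 0.755, 0.740,  # epochs 38-47
    0.740, 0.734, 0.762,                                                    # epochs 48-50
]


def early_stopping_scan(
    val_accs: List[float], start_epoch: int, patience: int
) -> Tuple[int, float, Optional[int]]:
    """Return (best_epoch, best_acc, stop_epoch).

    stop_epoch is None if patience never ran out before the last epoch.
    """
    best_epoch, best_acc = start_epoch, float("-inf")
    epochs_without_improvement = 0
    for offset, acc in enumerate(val_accs):
        epoch = start_epoch + offset
        if acc > best_acc:
            # New best: in real training, save copy.deepcopy(model.state_dict()) here
            best_epoch, best_acc = epoch, acc
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return best_epoch, best_acc, epoch  # training would stop here
    return best_epoch, best_acc, None


for patience in (5, 10):
    best_epoch, best_acc, stop_epoch = early_stopping_scan(VALID_ACC, 18, patience)
    print(f"patience={patience}: best epoch {best_epoch} "
          f"(valid_acc {best_acc:.3f}), stopped at {stop_epoch}")
```

With `patience=5`, the scan stops at epoch 41 and keeps epoch 36 (0.758), which skips 9 epochs but misses the 0.762 reached at epoch 45. With `patience=10`, it runs to the end and keeps epoch 45. The patience value trades runtime against the chance of catching a late improvement.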

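A Transformer is a deep learning model for sequence data that has been very successful in natural language processing. Text classification is one of the most common NLP tasks: it assigns an input text to one of a set of predefined categories. A Transformer-based text classifier can be built in the following steps:

1. Data preprocessing: convert the text into a form the model can process. This is typically a sequence of integer token IDs, which are then mapped to word (or character) embeddings.
2. Model construction: use a Transformer as the backbone. The original Transformer stacks encoder and decoder layers. For classification, usually only the encoder stack is needed, because its job is to extract features from the input text.
3. Feature extraction: feed the text through the encoder and use its final hidden states as the text representation. These are usually reduced to one vector per text, for example by averaging over the tokens.
4. Classification layer: pass that vector to a fully connected layer that outputs one score (logit) per class. A softmax turns these scores into class probabilities.
5. Loss and optimization: use a suitable loss, such as cross-entropy, to measure the gap between predictions and true labels, and update the model parameters by backpropagation. In PyTorch, `nn.CrossEntropyLoss` applies log-softmax internally, so the model itself should output raw logits.
6. Training and evaluation: train on the training data, evaluate and tune on the validation data, and finally report performance on a held-out test set.

Implementation details vary with the framework and the task, so choose the tools and libraries that fit your situation.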