PyTorch in Practice: IMDB Sentiment Analysis


💗💗💗 Welcome to my blog! Here you will find articles on solving problems with technology, as well as learning roadmaps for specific technologies. Whatever your profession, I hope my blog is useful to you. Don't forget to subscribe for the latest posts, and feel free to leave comments and feedback below. I look forward to sharing knowledge, learning together, and building a positive community. Thank you for visiting, and let's set out on this journey of knowledge together!

🥦Introduction

This post performs sentiment analysis on the IMDB dataset using PyTorch.

🥦Full Code

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score

from torch import utils

import torchtext
from tqdm import tqdm
from torchtext.datasets import IMDB

from torchtext.datasets.imdb import NUM_LINES
from torchtext.data import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torchtext.data.functional import to_map_style_dataset

import os
import sys
import logging

logging.basicConfig(
    level=logging.WARN, stream=sys.stdout,
    format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s")

VOCAB_SIZE = 15000

# Step 1: the GCNN model -- a gated convolutional network
class GCNN(nn.Module):
    def __init__(self, vocab_size=VOCAB_SIZE, embedding_dim=64, num_class=2):
        super(GCNN, self).__init__()

        self.embedding_table = nn.Embedding(vocab_size, embedding_dim)
        nn.init.xavier_uniform_(self.embedding_table.weight)

        # All convolutions are 1-D
        self.conv_A_1 = nn.Conv1d(embedding_dim, 64, 15, stride=7)
        self.conv_B_1 = nn.Conv1d(embedding_dim, 64, 15, stride=7)

        self.conv_A_2 = nn.Conv1d(64, 64, 15, stride=7)
        self.conv_B_2 = nn.Conv1d(64, 64, 15, stride=7)

        self.output_linear1 = nn.Linear(64, 128)
        self.output_linear2 = nn.Linear(128, num_class)

    def forward(self, word_index):
        """
        The GCNN forward pass: from sentence word-ID input to classification logits.
        """
        # 1. Look up word embeddings from word_index
        # word_index shape: [bs, max_seq_len]
        word_embedding = self.embedding_table(word_index)  # [bs, max_seq_len, embedding_dim]

        # 2. First gated 1-D convolution block; channels sit in dim 1
        word_embedding = word_embedding.transpose(1, 2)  # [bs, embedding_dim, max_seq_len]
        A = self.conv_A_1(word_embedding)
        B = self.conv_B_1(word_embedding)
        H = A * torch.sigmoid(B)  # [bs, 64, seq_len_1]

        A = self.conv_A_2(H)
        B = self.conv_B_2(H)
        H = A * torch.sigmoid(B)  # [bs, 64, seq_len_2]

        # 3. Pool over time, then apply the fully connected layers
        pool_output = torch.mean(H, dim=-1)  # mean pooling -> [bs, 64]
        linear1_output = self.output_linear1(pool_output)

        # The final layer maps to the number of classes
        logits = self.output_linear2(linear1_output)  # [bs, 2]

        return logits

# The simple model from the PyTorch website
class TextClassificationModel(nn.Module):
    """
    A minimal embedding + DNN model
    """

    def __init__(self, vocab_size=VOCAB_SIZE, embed_dim=64, num_class=2):
        super(TextClassificationModel, self).__init__()
        self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=False)
        self.fc = nn.Linear(embed_dim, num_class)

    def forward(self, token_index):
        # Bag of words: EmbeddingBag averages the token embeddings
        embedded = self.embedding(token_index)  # shape: [bs, embedding_dim]
        return self.fc(embedded)

# Step 2: build the IMDB DataLoader
BATCH_SIZE = 64

def yield_tokens(train_data_iter, tokenizer):
    for i, sample in enumerate(train_data_iter):
        label, comment = sample
        yield tokenizer(comment)  # convert the string into a list of tokens

train_data_iter = IMDB(root="./data", split="train")  # a Dataset-style object
tokenizer = get_tokenizer("basic_english")
# Only keep tokens that occur at least 20 times
vocab = build_vocab_from_iterator(yield_tokens(train_data_iter, tokenizer), min_freq=20, specials=["<unk>"])
vocab.set_default_index(0)  # unknown tokens map to index 0
print(f"Vocabulary size: {len(vocab)}")

# Collate function: batch is what the dataset returns; post-process one batch of data
def collate_fn(batch):
    """
    Post-process the mini-batch produced by the DataLoader
    """
    target = []
    token_index = []
    max_length = 0  # longest token sequence in the batch
    for i, (label, comment) in enumerate(batch):
        tokens = tokenizer(comment)
        token_index.append(vocab(tokens))  # convert the token list into an index list

        # Track the longest sentence
        if len(tokens) > max_length:
            max_length = len(tokens)

        if label == "pos":
            target.append(0)
        else:
            target.append(1)

    # Pad every sequence to max_length with index 0
    token_index = [index + [0] * (max_length - len(index)) for index in token_index]
    # F.one_hot requires long integers, so cast the targets to int64
    return (torch.tensor(target).to(torch.int64), torch.tensor(token_index).to(torch.int32))

# Step 3: training code
def train(train_data_loader, eval_data_loader, model, optimizer, num_epoch, log_step_interval,
          save_step_interval, eval_step_interval, save_path, resume=""):
    """
    Here the data loaders wrap map-style datasets
    """

    start_epoch = 0
    start_step = 0
    if resume != "":
        # Load the parameters from a previously trained checkpoint
        logging.warning("loading from resume")
        checkpoint = torch.load(resume)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch']
        start_step = checkpoint['step']

    for epoch_index in tqdm(range(start_epoch, num_epoch), desc="epoch"):
        ema_loss = 0
        total_acc_account = 0
        total_account = 0
        true_labels = []
        predicted_labels = []
        num_batches = len(train_data_loader)
        for batch_index, (target, token_index) in enumerate(train_data_loader):
            optimizer.zero_grad()
            step = num_batches * epoch_index + batch_index + 1
            logits = model(token_index)
            # The one-hot targets must be float32 for binary cross-entropy
            bce_loss = F.binary_cross_entropy(torch.sigmoid(logits),
                                              F.one_hot(target, num_classes=2).to(torch.float32))
            ema_loss = 0.9 * ema_loss + 0.1 * bce_loss  # exponential moving average of the loss
            bce_loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 0.1)  # clip the gradient norm to stabilize training
            optimizer.step()  # update the parameters

            true_labels.extend(target.tolist())
            predicted_labels.extend(torch.argmax(logits, dim=-1).tolist())

            if step % log_step_interval == 0:
                logging.warning(f"epoch_index: {epoch_index}, batch_index: {batch_index}, ema_loss: {ema_loss}")

            if step % save_step_interval == 0:
                os.makedirs(save_path, exist_ok=True)
                save_file = os.path.join(save_path, f"step_{step}.pt")
                torch.save({
                    "epoch": epoch_index,
                    "step": step,
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "loss": bce_loss
                }, save_file)

                logging.warning(f"checkpoint has been saved in {save_file}")

            if step % eval_step_interval == 0:
                logging.warning("start to do evaluation...")
                model.eval()
                ema_eval_loss = 0
                total_acc_account = 0
                total_account = 0
                true_labels = []
                predicted_labels = []

                for eval_batch_index, (eval_target, eval_token_index) in enumerate(eval_data_loader):
                    total_account += eval_target.shape[0]
                    eval_logits = model(eval_token_index)
                    total_acc_account += (torch.argmax(eval_logits, dim=-1) == eval_target).sum().item()
                    eval_bce_loss = F.binary_cross_entropy(torch.sigmoid(eval_logits),
                                                           F.one_hot(eval_target, num_classes=2).to(torch.float32))
                    ema_eval_loss = 0.9 * ema_eval_loss + 0.1 * eval_bce_loss

                    true_labels.extend(eval_target.tolist())
                    predicted_labels.extend(torch.argmax(eval_logits, dim=-1).tolist())

                accuracy = accuracy_score(true_labels, predicted_labels)
                precision = precision_score(true_labels, predicted_labels)
                recall = recall_score(true_labels, predicted_labels)
                f1 = f1_score(true_labels, predicted_labels)

                logging.warning(f"ema_eval_loss: {ema_eval_loss}, eval_acc: {total_acc_account / total_account}")
                logging.warning(f"Precision: {precision}, Recall: {recall}, F1: {f1}, Accuracy: {accuracy}")
                model.train()

model = GCNN()
# model = TextClassificationModel()
print("Total model parameters:", sum(p.numel() for p in model.parameters()))
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

train_data_iter = IMDB(root="data", split="train")  # a Dataset-style object
train_data_loader = torch.utils.data.DataLoader(
    to_map_style_dataset(train_data_iter), batch_size=BATCH_SIZE, collate_fn=collate_fn, shuffle=True)

eval_data_iter = IMDB(root="data", split="test")  # a Dataset-style object
eval_data_loader = utils.data.DataLoader(
    to_map_style_dataset(eval_data_iter), batch_size=8, collate_fn=collate_fn)

# resume = "./data/step_500.pt"
resume = ""

train(train_data_loader, eval_data_loader, model, optimizer, num_epoch=10, log_step_interval=20,
      save_step_interval=500, eval_step_interval=300, save_path="./log_imdb_text_classification2", resume=resume)
```


🥦Code Walkthrough

🥦Imports

First, import the required libraries:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
from torch import utils
import torchtext
from tqdm import tqdm
from torchtext.datasets import IMDB
```
• torch (PyTorch):
  PyTorch is an open-source framework for machine learning and deep learning. It provides tensor computation, automatic differentiation, neural network layers, and optimizers, letting you build and train deep learning models.

• torch.nn:
  The torch.nn module contains the classes and functions PyTorch uses to build neural network models, including the various network layers, loss functions, and optimizers.

• torch.nn.functional:
  The torch.nn.functional module provides functions for the non-parametric building blocks of a network, such as activation functions, pooling, and convolution. These functions are typically used alongside torch.nn.

• sklearn.metrics (scikit-learn):
  scikit-learn is a Python machine learning library that ships a suite of metrics for evaluating model performance. The imported precision_score, recall_score, f1_score, and accuracy_score compute a classifier's precision, recall, F1 score, and accuracy; a toy example follows this list.

• torch.utils:
  torch.utils contains utility tools and data-loading helpers. In this code it is used to build the data loaders.

• torchtext:
  torchtext is PyTorch's natural language processing library for handling and loading text data. It provides utilities for text preprocessing and dataset construction.

• tqdm:
  tqdm is a Python library for creating progress bars to monitor loop iterations. Here it displays training and evaluation progress.

• torchtext.datasets.IMDB:
  torchtext.datasets.IMDB is a TorchText dataset of IMDb movie reviews. The reviews are used for sentiment analysis, with each review labeled positive or negative.
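
Here is a tiny, self-contained illustration of the four sklearn metrics imported above, run on toy labels rather than IMDB data:

```python
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score

y_true = [0, 1, 1, 0]  # ground-truth classes
y_pred = [0, 1, 0, 0]  # model predictions

print(accuracy_score(y_true, y_pred))   # 0.75   -- 3 of 4 predictions correct
print(precision_score(y_true, y_pred))  # 1.0    -- every predicted 1 is truly 1
print(recall_score(y_true, y_pred))     # 0.5    -- only 1 of 2 true 1s was found
print(f1_score(y_true, y_pred))         # ~0.667 -- harmonic mean of the two
```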

🥦Logging Setup

```python
logging.basicConfig(
    level=logging.WARN, stream=sys.stdout, format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
)
```

Setting up logging records the program's runtime state, debugging information, and important events, which makes it easier to diagnose problems and understand program behavior in both development and production. Logging serves several purposes:

• Problem diagnosis: when the program errors or raises an exception, logs capture where, why, and in what context the failure occurred, helping developers locate and fix it quickly.

• Performance analysis: by recording run times and timestamps of key operations, logs support performance analysis and help identify potential bottlenecks.

• Progress tracking: in long-running jobs such as training a deep learning model, logs help track progress, completed steps, and remaining time.

• Monitoring and alerting: logs can be integrated with monitoring systems to trigger alerts on critical events or anomalies, which matters for responding to problems promptly.

• Auditing and compliance: in some applications, logging is part of compliance, tracing system operations and user activity for audits and investigations.

In this code, logging tracks training progress, records the training loss, and reports checkpoint saves. It lets you monitor how training is going and inspect details such as loss values and evaluation metrics, and it also helps with debugging and reviewing model performance.
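
With this configuration, every logging.warning(...) call prints a timestamped line to stdout. A minimal sketch; the timestamp and module name in the comment below are illustrative, not captured output:

```python
import logging
import sys

logging.basicConfig(
    level=logging.WARN, stream=sys.stdout,
    format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s")

logging.warning("start to do evaluation...")
# 2023-10-28 12:00:00,000 (train_imdb:9) WARNING: start to do evaluation...
```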

🥦Model Definitions

The code defines two models:

GCNN: a gated convolutional neural network for text classification.
TextClassificationModel: a simple text classifier built from an embedding layer and a linear layer.
🥦GCNN
```python
class GCNN(nn.Module):
    def __init__(self, vocab_size=VOCAB_SIZE, embedding_dim=64, num_class=2):
        super(GCNN, self).__init__()

        self.embedding_table = nn.Embedding(vocab_size, embedding_dim)
        nn.init.xavier_uniform_(self.embedding_table.weight)

        # All convolutions are 1-D
        self.conv_A_1 = nn.Conv1d(embedding_dim, 64, 15, stride=7)
        self.conv_B_1 = nn.Conv1d(embedding_dim, 64, 15, stride=7)

        self.conv_A_2 = nn.Conv1d(64, 64, 15, stride=7)
        self.conv_B_2 = nn.Conv1d(64, 64, 15, stride=7)

        self.output_linear1 = nn.Linear(64, 128)
        self.output_linear2 = nn.Linear(128, num_class)

    def forward(self, word_index):
        """
        The GCNN forward pass: from sentence word-ID input to classification logits.
        """
        # 1. Look up word embeddings from word_index
        # word_index shape: [bs, max_seq_len]
        word_embedding = self.embedding_table(word_index)  # [bs, max_seq_len, embedding_dim]

        # 2. First gated 1-D convolution block; channels sit in dim 1
        word_embedding = word_embedding.transpose(1, 2)  # [bs, embedding_dim, max_seq_len]
        A = self.conv_A_1(word_embedding)
        B = self.conv_B_1(word_embedding)
        H = A * torch.sigmoid(B)  # [bs, 64, seq_len_1]

        A = self.conv_A_2(H)
        B = self.conv_B_2(H)
        H = A * torch.sigmoid(B)  # [bs, 64, seq_len_2]

        # 3. Pool over time, then apply the fully connected layers
        pool_output = torch.mean(H, dim=-1)  # mean pooling -> [bs, 64]
        linear1_output = self.output_linear1(pool_output)

        # The final layer maps to the number of classes
        logits = self.output_linear2(linear1_output)  # [bs, 2]

        return logits
```
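
Before wiring up the data, a shape smoke test can catch mistakes early. A minimal hypothetical check (not part of the original post); note that with kernel size 15 and stride 7, the input must be at least 113 tokens long for the second convolution to see a valid window:

```python
import torch

model = GCNN()
fake_tokens = torch.randint(0, VOCAB_SIZE, (4, 200))  # batch of 4, 200 tokens each
logits = model(fake_tokens)
print(logits.shape)  # torch.Size([4, 2])
```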
🥦TextClassificationModel
```python
class TextClassificationModel(nn.Module):
    """
    A minimal embedding + DNN model
    """

    def __init__(self, vocab_size=VOCAB_SIZE, embed_dim=64, num_class=2):
        super(TextClassificationModel, self).__init__()
        self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=False)
        self.fc = nn.Linear(embed_dim, num_class)

    def forward(self, token_index):
        # Bag of words: EmbeddingBag averages the token embeddings
        embedded = self.embedding(token_index)  # shape: [bs, embedding_dim]
        return self.fc(embedded)
```
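
A quick shape check, as a hypothetical sketch: when nn.EmbeddingBag receives a 2-D index tensor it treats each row as one bag and, with its default mode="mean", returns one averaged embedding per row:

```python
import torch

model = TextClassificationModel()
fake_tokens = torch.randint(0, VOCAB_SIZE, (4, 50))  # 4 reviews, 50 tokens each
print(model(fake_tokens).shape)  # torch.Size([4, 2])
```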

🥦Preparing the IMDb Dataset

This line uses TorchText's IMDB dataset object to load the training split of IMDb.

```python
# Load the dataset
train_data_iter = IMDB(root="./data", split="train")
```
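
Each item the iterator yields is a (label, review_text) pair. In the torchtext version this post targets, the label is the string "pos" or "neg" (newer torchtext releases return integer labels instead), which is why collate_fn below compares against "pos". A quick hypothetical peek:

```python
label, text = next(iter(train_data_iter))
print(label, text[:60])
```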

This line creates a tokenizer that splits text into words.

```python
# Text preprocessing
tokenizer = get_tokenizer("basic_english")
```
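
The basic_english tokenizer lowercases the text and splits punctuation into separate tokens, for example:

```python
print(tokenizer("This movie is GREAT!"))
# ['this', 'movie', 'is', 'great', '!']
```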

Here, build_vocab_from_iterator builds a vocabulary from the text data, keeping only words that appear at least 20 times. The special <unk> token stands in for unknown words, and set_default_index maps out-of-vocabulary lookups to index 0.

```python
# Build the vocabulary
vocab = build_vocab_from_iterator(yield_tokens(train_data_iter, tokenizer), min_freq=20, specials=["<unk>"])
vocab.set_default_index(0)
```
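
The resulting vocab object is callable, exactly as collate_fn uses it below: it maps a list of tokens to their indices, and any out-of-vocabulary token falls back to the default index 0, i.e. <unk>. The concrete indices in the comment are illustrative:

```python
print(vocab(["this", "movie", "someveryrareword"]))
# e.g. [14, 24, 0] -- the unseen word maps to index 0
```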

This is a custom collate function that post-processes each mini-batch returned by the DataLoader, converting the raw text into tensors the model can take as input.

```python
def collate_fn(batch):
    """
    Post-process the mini-batch produced by the DataLoader
    """
    target = []
    token_index = []
    max_length = 0  # longest token sequence in the batch
    for i, (label, comment) in enumerate(batch):
        tokens = tokenizer(comment)
        token_index.append(vocab(tokens))  # convert the token list into an index list

        # Track the longest sentence
        if len(tokens) > max_length:
            max_length = len(tokens)

        if label == "pos":
            target.append(0)
        else:
            target.append(1)

    # Pad every sequence to max_length with index 0
    token_index = [index + [0] * (max_length - len(index)) for index in token_index]
    # F.one_hot requires long integers, so cast the targets to int64
    return (torch.tensor(target).to(torch.int64), torch.tensor(token_index).to(torch.int32))
```

This line wraps the IMDb training set in a DataLoader for model training; collate_fn handles the per-batch processing.

```python
train_data_loader = torch.utils.data.DataLoader(
    to_map_style_dataset(train_data_iter), batch_size=BATCH_SIZE, collate_fn=collate_fn, shuffle=True)
```
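
To confirm the loader and the collate function fit together, it can help to pull a single mini-batch; a minimal sketch, not in the original post:

```python
labels, token_ids = next(iter(train_data_loader))
print(labels.shape)     # torch.Size([64])
print(token_ids.shape)  # torch.Size([64, L]), L = longest review in this batch
```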

The code above completes the IMDb data preparation: loading the data, tokenizing it, building the vocabulary, and setting up the data loader. These steps make the dataset usable for training a text classification model.

🥦The Collate Function

This collate_fn processes the data in each DataLoader batch, making every text sequence in a batch the same length and converting the labels into tensors suitable for model input. It does the following:

1. Extract the label and the review text from each sample.
2. Tokenize the review text into words.
3. Find the length of the longest review in the batch.
4. Pad every review's token-index sequence to that length.
5. Convert the labels to the appropriate tensor type (here, 64-bit integers).
6. Return the processed batch: the labels and the padded token-index sequences.

This collate function ensures the model can handle variable-length text sequences during training by turning each batch into fixed-shape tensor inputs; a short sketch of its behavior follows.
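
Here is a small hand-made batch run through collate_fn, as a hypothetical check; recall that this post maps "pos" to 0 and everything else to 1:

```python
batch = [("pos", "a great movie"), ("neg", "terrible and far too long")]
labels, padded = collate_fn(batch)
print(labels)        # tensor([0, 1])
print(padded.shape)  # torch.Size([2, 5]) -- the shorter review is zero-padded to 5 tokens
```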

🥦The Training Function

```python
def train(train_data_loader, eval_data_loader, model, optimizer, num_epoch, log_step_interval,
          save_step_interval, eval_step_interval, save_path, resume=""):
    """
    Here the data loaders wrap map-style datasets
    """
    start_epoch = 0
    start_step = 0
    if resume != "":
        # Load the parameters from a previously trained checkpoint
        logging.warning("loading from resume")
        checkpoint = torch.load(resume)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch']
        start_step = checkpoint['step']

    for epoch_index in tqdm(range(start_epoch, num_epoch), desc="epoch"):
        ema_loss = 0
        total_acc_account = 0
        total_account = 0
        true_labels = []
        predicted_labels = []
        num_batches = len(train_data_loader)
        for batch_index, (target, token_index) in enumerate(train_data_loader):
            optimizer.zero_grad()
            step = num_batches * epoch_index + batch_index + 1
            logits = model(token_index)
            # The one-hot targets must be float32 for binary cross-entropy
            bce_loss = F.binary_cross_entropy(torch.sigmoid(logits),
                                              F.one_hot(target, num_classes=2).to(torch.float32))
            ema_loss = 0.9 * ema_loss + 0.1 * bce_loss  # exponential moving average of the loss
            bce_loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 0.1)  # clip the gradient norm to stabilize training
            optimizer.step()  # update the parameters

            true_labels.extend(target.tolist())
            predicted_labels.extend(torch.argmax(logits, dim=-1).tolist())

            if step % log_step_interval == 0:
                logging.warning(f"epoch_index: {epoch_index}, batch_index: {batch_index}, ema_loss: {ema_loss}")

            if step % save_step_interval == 0:
                os.makedirs(save_path, exist_ok=True)
                save_file = os.path.join(save_path, f"step_{step}.pt")
```
            torch<span class="token punctuation">.</span>save<span class="token punctuation">(</span><span class="token punctuation">{<!-- --></span>
                <span class="token string">"epoch"</span><span class="token punctuation">:</span> epoch_index<span class="token punctuation">,</span>
                <span class="token string">"step"</span><span class="token punctuation">:</span> step<span class="token punctuation">,</span>
                <span class="token string">"model_state_dict"</span><span class="token punctuation">:</span> model<span class="token punctuation">.</span>state_dict<span class="token punctuation">(</span><span class="token punctuation">)</span><span class="token punctuation">,</span>
                <span class="token string">'optimizer_state_dict'</span><span class="token punctuation">:</span> optimizer<span class="token punctuation">.</span>state_dict<span class="token punctuation">(</span><span class="token punctuation">)</span><span class="token punctuation">,</span>
                <span class="token string">'loss'</span><span class="token punctuation">:</span> bce_loss
            <span class="token punctuation">}</span><span class="token punctuation">,</span> save_file<span class="token punctuation">)</span>


            logging<span class="token punctuation">.</span>warning<span class="token punctuation">(</span><span class="token string-interpolation"><span class="token string">f"checkpoint has been saved in </span><span class="token interpolation"><span class="token punctuation">{<!-- --></span>save_file<span class="token punctuation">}</span></span><span class="token string">"</span></span><span class="token punctuation">)</span>
        <span class="token keyword">if</span> step <span class="token operator">%</span> save_step_interval <span class="token operator">==</span> <span class="token number">0</span><span class="token punctuation">:</span>
            os<span class="token punctuation">.</span>makedirs<span class="token punctuation">(</span>save_path<span class="token punctuation">,</span> exist_ok<span class="token operator">=</span><span class="token boolean">True</span><span class="token punctuation">)</span>
            save_file <span class="token operator">=</span> os<span class="token punctuation">.</span>path<span class="token punctuation">.</span>join<span class="token punctuation">(</span>save_path<span class="token punctuation">,</span> <span class="token string-interpolation"><span class="token string">f"step_</span><span class="token interpolation"><span class="token punctuation">{<!-- --></span>step<span class="token punctuation">}</span></span><span class="token string">.pt"</span></span><span class="token punctuation">)</span>
            torch<span class="token punctuation">.</span>save<span class="token punctuation">(</span><span class="token punctuation">{<!-- --></span>
                <span class="token string">"epoch"</span><span class="token punctuation">:</span> epoch_index<span class="token punctuation">,</span>
                <span class="token string">"step"</span><span class="token punctuation">:</span> step<span class="token punctuation">,</span>
                <span class="token string">"model_state_dict"</span><span class="token punctuation">:</span> model<span class="token punctuation">.</span>state_dict<span class="token punctuation">(</span><span class="token punctuation">)</span><span class="token punctuation">,</span>
                <span class="token string">'optimizer_state_dict'</span><span class="token punctuation">:</span> optimizer<span class="token punctuation">.</span>state_dict<span class="token punctuation">(</span><span class="token punctuation">)</span><span class="token punctuation">,</span>
                <span class="token string">'loss'</span><span class="token punctuation">:</span> bce_loss<span class="token punctuation">,</span>
                <span class="token string">'accuracy'</span><span class="token punctuation">:</span> accuracy<span class="token punctuation">,</span>
                <span class="token string">'precision'</span><span class="token punctuation">:</span> precision<span class="token punctuation">,</span>
                <span class="token string">'recall'</span><span class="token punctuation">:</span> recall<span class="token punctuation">,</span>
                <span class="token string">'f1'</span><span class="token punctuation">:</span> f1
            <span class="token punctuation">}</span><span class="token punctuation">,</span> save_file<span class="token punctuation">)</span>

            logging<span class="token punctuation">.</span>warning<span class="token punctuation">(</span><span class="token string-interpolation"><span class="token string">f"checkpoint has been saved in </span><span class="token interpolation"><span class="token punctuation">{<!-- --></span>save_file<span class="token punctuation">}</span></span><span class="token string">"</span></span><span class="token punctuation">)</span>

        <span class="token keyword">if</span> step <span class="token operator">%</span> eval_step_interval <span class="token operator">==</span> <span class="token number">0</span><span class="token punctuation">:</span>
            logging<span class="token punctuation">.</span>warning<span class="token punctuation">(</span><span class="token string">"start to do evaluation..."</span><span class="token punctuation">)</span>
            model<span class="token punctuation">.</span><span class="token builtin">eval</span><span class="token punctuation">(</span><span class="token punctuation">)</span>
            ema_eval_loss <span class="token operator">=</span> <span class="token number">0</span>
            total_acc_account <span class="token operator">=</span> <span class="token number">0</span>
            total_account <span class="token operator">=</span> <span class="token number">0</span>
            true_labels <span class="token operator">=</span> <span class="token punctuation">[</span><span class="token punctuation">]</span>
            predicted_labels <span class="token operator">=</span> <span class="token punctuation">[</span><span class="token punctuation">]</span>

            <span class="token keyword">for</span> eval_batch_index<span class="token punctuation">,</span> <span class="token punctuation">(</span>eval_target<span class="token punctuation">,</span> eval_token_index<span class="token punctuation">)</span> <span class="token keyword">in</span> <span class="token builtin">enumerate</span><span class="token punctuation">(</span>eval_data_loader<span class="token punctuation">)</span><span class="token punctuation">:</span>
                total_account <span class="token operator">+=</span> eval_target<span class="token punctuation">.</span>shape<span class="token punctuation">[</span><span class="token number">0</span><span class="token punctuation">]</span>
                eval_logits <span class="token operator">=</span> model<span class="token punctuation">(</span>eval_token_index<span class="token punctuation">)</span>
                total_acc_account <span class="token operator">+=</span> <span class="token punctuation">(</span>torch<span class="token punctuation">.</span>argmax<span class="token punctuation">(</span>eval_logits<span class="token punctuation">,</span> dim<span class="token operator">=</span><span class="token operator">-</span><span class="token number">1</span><span class="token punctuation">)</span> <span class="token operator">==</span> eval_target<span class="token punctuation">)</span><span class="token punctuation">.</span><span class="token builtin">sum</span><span class="token punctuation">(</span><span class="token punctuation">)</span><span class="token punctuation">.</span>item<span class="token punctuation">(</span><span class="token punctuation">)</span>
                eval_bce_loss <span class="token operator">=</span> F<span class="token punctuation">.</span>binary_cross_entropy<span class="token punctuation">(</span>torch<span class="token punctuation">.</span>sigmoid<span class="token punctuation">(</span>eval_logits<span class="token punctuation">)</span><span class="token punctuation">,</span>
                                                       F<span class="token punctuation">.</span>one_hot<span class="token punctuation">(</span>eval_target<span class="token punctuation">,</span> num_classes<span class="token operator">=</span><span class="token number">2</span><span class="token punctuation">)</span><span class="token punctuation">.</span>to<span class="token punctuation">(</span>torch<span class="token punctuation">.</span>float32<span class="token punctuation">)</span><span class="token punctuation">)</span>
                ema_eval_loss <span class="token operator">=</span> <span class="token number">0.9</span> <span class="token operator">*</span> ema_eval_loss <span class="token operator">+</span> <span class="token number">0.1</span> <span class="token operator">*</span> eval_bce_loss

                true_labels<span class="token punctuation">.</span>extend<span class="token punctuation">(</span>eval_target<span class="token punctuation">.</span>tolist<span class="token punctuation">(</span><span class="token punctuation">)</span><span class="token punctuation">)</span>
                predicted_labels<span class="token punctuation">.</span>extend<span class="token punctuation">(</span>torch<span class="token punctuation">.</span>argmax<span class="token punctuation">(</span>eval_logits<span class="token punctuation">,</span> dim<span class="token operator">=</span><span class="token operator">-</span><span class="token number">1</span><span class="token punctuation">)</span><span class="token punctuation">.</span>tolist<span class="token punctuation">(</span><span class="token punctuation">)</span><span class="token punctuation">)</span>

            accuracy <span class="token operator">=</span> accuracy_score<span class="token punctuation">(</span>true_labels<span class="token punctuation">,</span> predicted_labels<span class="token punctuation">)</span>
            precision <span class="token operator">=</span> precision_score<span class="token punctuation">(</span>true_labels<span class="token punctuation">,</span> predicted_labels<span class="token punctuation">)</span>
            recall <span class="token operator">=</span> recall_score<span class="token punctuation">(</span>true_labels<span class="token punctuation">,</span> predicted_labels<span class="token punctuation">)</span>
            f1 <span class="token operator">=</span> f1_score<span class="token punctuation">(</span>true_labels<span class="token punctuation">,</span> predicted_labels<span class="token punctuation">)</span>

            logging<span class="token punctuation">.</span>warning<span class="token punctuation">(</span><span class="token string-interpolation"><span class="token string">f"ema_eval_loss: </span><span class="token interpolation"><span class="token punctuation">{<!-- --></span>ema_eval_loss<span class="token punctuation">}</span></span><span class="token string">, eval_acc: </span><span class="token interpolation"><span class="token punctuation">{<!-- --></span>total_acc_account <span class="token operator">/</span> total_account<span class="token punctuation">}</span></span><span class="token string">"</span></span><span class="token punctuation">)</span>
            logging<span class="token punctuation">.</span>warning<span class="token punctuation">(</span><span class="token string-interpolation"><span class="token string">f"Precision: </span><span class="token interpolation"><span class="token punctuation">{<!-- --></span>precision<span class="token punctuation">}</span></span><span class="token string">, Recall: </span><span class="token interpolation"><span class="token punctuation">{<!-- --></span>recall<span class="token punctuation">}</span></span><span class="token string">, F1: </span><span class="token interpolation"><span class="token punctuation">{<!-- --></span>f1<span class="token punctuation">}</span></span><span class="token string">, Accuracy: </span><span class="token interpolation"><span class="token punctuation">{<!-- --></span>accuracy<span class="token punctuation">}</span></span><span class="token string">"</span></span><span class="token punctuation">)</span>
            model<span class="token punctuation">.</span>train<span class="token punctuation">(</span><span class="token punctuation">)</span>

This code defines a function named train that carries out the whole training procedure. In detail:

train takes the following parameters:
    train_data_loader: DataLoader that iterates over the training data.
    eval_data_loader: DataLoader used to evaluate the model's performance.
    model: the neural network model to train.
    optimizer: the optimizer that updates the model's parameters.
    num_epoch: total number of training epochs.
    log_step_interval: number of steps between log messages.
    save_step_interval: number of steps between checkpoint saves.
    eval_step_interval: number of steps between evaluations.
    save_path: directory in which model checkpoints are saved.
    resume: optional path to a checkpoint file from which to resume training.

The function's main work is as follows:
It first checks whether a resume checkpoint file was supplied. If so, it loads the previously trained model parameters and optimizer state so that training continues where it left off.
It then runs a series of epochs, each consisting of many training steps (batches).
Each training step:
zeroes the gradients in preparation for the parameter update;
computes the model's predictions (logits);
computes the binary cross-entropy loss;
backpropagates the loss, clips the gradient norm, and updates the parameters;
records the loss as an exponential moving average (see the small sketch below) together with the true and predicted labels.
Every log_step_interval steps the loss is logged.
Every save_step_interval steps a model checkpoint is saved.
Every eval_step_interval steps an evaluation is performed:
the model is switched to evaluation mode (model.eval());
for each batch of the evaluation set, the predictions and the binary cross-entropy loss are computed;
accuracy, precision, recall, and F1 score are computed over the full evaluation set;
the evaluation loss and metrics are logged;
the model is switched back to training mode (model.train()).

When all epochs are done, the model has been trained in place; the function does not return a value.
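
To make the EMA bookkeeping concrete, here is a tiny self-contained sketch of the same update rule with made-up loss values. Because the average starts at 0, the first few values are biased low:

# Minimal sketch of the EMA update used for logging (loss values are made up)
losses = [0.9, 0.7, 0.8, 0.5, 0.4]
ema_loss = 0.0
for loss in losses:
    ema_loss = 0.9 * ema_loss + 0.1 * loss  # same rule as in the training loop
    print(f"loss={loss:.2f}, ema_loss={ema_loss:.4f}")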


This training function covers the complete training process: forward pass, loss computation, gradient updates, logging, checkpoint saving, and evaluation. Calling it trains the model while letting you monitor its performance.

🥦Model initialization and optimizer

model = GCNN()
# model = TextClassificationModel()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
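
Before launching training, it can be reassuring to check the model's size, using the same pattern as the parameter count printed for the LSTM later in this post:

# Count the trainable parameters of the GCNN
print("Total parameters:", sum(p.numel() for p in model.parameters() if p.requires_grad))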

 
 

🥦Loading the training and evaluation data

In the code above, the data used for training and evaluation is loaded as follows:

train_data_iter = IMDB(root="data", split="train")

 
 

This line uses TorchText's IMDB dataset object to load the training split of the IMDB dataset, which is used to train the model.

eval_data_iter = IMDB(root="data", split="test")

 
 

This line loads the test split of the IMDB dataset, which is used to evaluate the model's performance.
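
If you want to see what a raw sample looks like, you can peek at the first item of the iterator. This is just a sketch; the exact label encoding (integers 1/2 here) depends on your torchtext version:

# Peek at one raw (label, text) pair from the test iterator
sample_label, sample_text = next(iter(IMDB(root="data", split="test")))
print(sample_label, sample_text[:80])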


These iterators are then wrapped in DataLoader objects so that they can be used for training and evaluation:

# DataLoader for the training data
train_data_loader = torch.utils.data.DataLoader(
    to_map_style_dataset(train_data_iter), batch_size=BATCH_SIZE, collate_fn=collate_fn, shuffle=True)

 
 
# DataLoader for the evaluation data
eval_data_loader = torch.utils.data.DataLoader(
    to_map_style_dataset(eval_data_iter), batch_size=8, collate_fn=collate_fn)

 
 

These DataLoader objects feed batches of data to the model during training and evaluation. The collate_fn function assembles each batch into the format the model expects; a sketch of what it has to do follows below.

Together these lines load and prepare the data, an essential preparation step: the training split is used to fit the model, while the test split is used to measure its performance.
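
The actual collate_fn is defined earlier in the full script. Purely for reference, a minimal sketch of what such a function has to do might look like the following; the tokenizer and vocab names refer to the objects built earlier, the 1/2 to 0/1 label mapping is an assumption about the torchtext label encoding, and the real version may also pad every sequence to a minimum length so the convolutions fit:

# Sketch of a collate_fn for (label, text) pairs -- illustrative only
def collate_fn_sketch(batch):
    targets, token_lists = [], []
    for label, text in batch:
        targets.append(int(label) - 1)   # assumed 1/2 labels mapped to 0/1
        tokens = vocab(tokenizer(text))  # text -> list of token indices
        token_lists.append(torch.tensor(tokens, dtype=torch.int64))
    # right-pad all sequences in the batch to the same length with index 0
    token_index = nn.utils.rnn.pad_sequence(token_lists, batch_first=True, padding_value=0)
    return torch.tensor(targets, dtype=torch.int64), token_index

Note the return order (target, token_index), which matches how the training loop unpacks each batch.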

🥦Resuming training

start_epoch = 0
start_step = 0
if resume != "":
    # Load the parameters saved by a previous training run
    logging.warning(f"loading from resume {resume}")
    checkpoint = torch.load(resume)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch']
    start_step = checkpoint['step']

 
 

This block sits at the beginning of the training function and checks whether a checkpoint from an earlier run exists, so that training can be resumed instead of restarted. Concretely:

If the resume variable is non-empty (i.e., a checkpoint file path was given):
torch.load reads the previously saved checkpoint file.
load_state_dict restores the saved model parameters into the current model so training can continue.
load_state_dict likewise restores the saved optimizer state, so optimization resumes from exactly where it stopped.
The previously reached epoch and step are read back so the loop picks up from that point.

The purpose of this code is to let training continue from a saved checkpoint rather than from scratch. This is very useful for long-running jobs: training can be interrupted and resumed later without losing progress.
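
For reference, here is a slightly more defensive version of the same resume block; the existence check and the map_location argument are my additions, not part of the original code:

# Sketch: defensive resume (additions: existence check, map_location)
if resume != "":
    if not os.path.exists(resume):
        raise FileNotFoundError(f"resume checkpoint not found: {resume}")
    logging.warning(f"loading from resume {resume}")
    checkpoint = torch.load(resume, map_location="cpu")  # load tensors onto the CPU first
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    start_epoch = checkpoint["epoch"]
    start_step = checkpoint["step"]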

🥦Launching training

train(train_data_loader, eval_data_loader, model, optimizer, num_epoch=10, log_step_interval=20, save_step_interval=500, eval_step_interval=300, save_path="./log_imdb_text_classification2", resume=resume)

 
 

🥦Reading a saved checkpoint

import torch

# Path to an existing .pt checkpoint file
file_path = "./log_imdb_text_classification/step_3500.pt"  # replace with your actual file path

# Load the file with torch.load()
checkpoint = torch.load(file_path)

# Read back accuracy, precision, recall, and F1 score
accuracy = checkpoint["accuracy"]
precision = checkpoint["precision"]
recall = checkpoint["recall"]
f1 = checkpoint["f1"]

print("Accuracy:", accuracy)
print("Precision:", precision)
print("Recall:", recall)
print("F1 Score:", f1)
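
If several checkpoints were saved, a short loop can compare their metrics. A sketch (checkpoints saved before the first evaluation store the initial metric values, so .get() is used defensively):

import glob
import os
import torch

# List the stored evaluation metrics of every checkpoint in the directory
for path in sorted(glob.glob("./log_imdb_text_classification/step_*.pt")):
    ckpt = torch.load(path)
    print(os.path.basename(path), "accuracy:", ckpt.get("accuracy"), "f1:", ckpt.get("f1"))

(sorted() here sorts lexically, so step_3500.pt comes before step_500.pt; for a quick look this is usually good enough.)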


🥦Extension: LSTM and GRU

The original code uses a gated convolutional network (GCNN). For comparison, I next introduce two recurrent alternatives for the same task: an LSTM and a GRU.

class LSTMModel(nn.Module):
    def __init__(self, vocab_size=VOCAB_SIZE, embedding_dim=64, hidden_dim=64, num_class=2):
        super(LSTMModel, self).__init__()
        self.embedding_table = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers=1, batch_first=True)
        self.output_linear = nn.Linear(hidden_dim, num_class)

    def forward(self, word_index):
        word_embedding = self.embedding_table(word_index)
        lstm_out, _ = self.lstm(word_embedding)
        lstm_out = lstm_out[:, -1, :]  # take the output of the last time step
        logits = self.output_linear(lstm_out)
        return logits

class GRUModel(nn.Module):
    def __init__(self, vocab_size=VOCAB_SIZE, embedding_dim=64, hidden_dim=64, num_class=2):
        super(GRUModel, self).__init__()
        self.embedding_table = nn.Embedding(vocab_size, embedding_dim)
        self.gru = nn.GRU(embedding_dim, hidden_dim, num_layers=1, batch_first=True)
        self.output_linear = nn.Linear(hidden_dim, num_class)

    def forward(self, word_index):
        word_embedding = self.embedding_table(word_index)
        gru_out, _ = self.gru(word_embedding)
        gru_out = gru_out[:, -1, :]  # take the output of the last time step
        logits = self.output_linear(gru_out)
        return logits
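
One caveat with both models: they read the output at the last time step, but in right-padded batches that position often corresponds to padding tokens rather than real words. A common workaround that needs no sequence lengths is mean pooling over time; here is a sketch of an LSTM variant using it (my addition, not part of the original post):

# Sketch: LSTM classifier that mean-pools over time instead of taking
# the last time step (which may land on padding in right-padded batches)
class LSTMMeanPoolModel(nn.Module):
    def __init__(self, vocab_size=VOCAB_SIZE, embedding_dim=64, hidden_dim=64, num_class=2):
        super().__init__()
        self.embedding_table = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers=1, batch_first=True)
        self.output_linear = nn.Linear(hidden_dim, num_class)

    def forward(self, word_index):
        word_embedding = self.embedding_table(word_index)
        lstm_out, _ = self.lstm(word_embedding)
        pooled = lstm_out.mean(dim=1)  # average the hidden states over the time dimension
        return self.output_linear(pooled)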
# Create the LSTM model
lstm_model = LSTMModel()
print("Total parameters:", sum(p.numel() for p in lstm_model.parameters()))
lstm_optimizer = torch.optim.Adam(lstm_model.parameters(), lr=0.001)

# Create the GRU model
# gru_model = GRUModel()
# print("Total parameters:", sum(p.numel() for p in gru_model.parameters()))
# gru_optimizer = torch.optim.Adam(gru_model.parameters(), lr=0.001)

# Train the LSTM model
train(train_data_loader, eval_data_loader, lstm_model, lstm_optimizer, num_epoch=10, log_step_interval=20, save_step_interval=500, eval_step_interval=300, save_path="./log_imdb_lstm", resume="")

# Train the GRU model
# train(train_data_loader, eval_data_loader, gru_model, gru_optimizer, num_epoch=10, log_step_interval=20, save_step_interval=500, eval_step_interval=300, save_path="./log_imdb_gru", resume="")


If you are curious, try both and compare the results.

🥦Summary

The code in this post comes from the web and is shared for learning purposes only; see the original source.


Challenge and creation are both painful, but deeply fulfilling.
