<link rel="stylesheet" href="https://csdnimg.cn/release/blogv2/dist/mdeditor/css/editerView/kdoc_html_views-1a98987dfd.css">
<link rel="stylesheet" href="https://csdnimg.cn/release/blogv2/dist/mdeditor/css/editerView/ck_htmledit_views-dc4a025e85.css">
<div id="content_views" class="markdown_views prism-atom-one-dark">
<svg xmlns="http://www.w3.org/2000/svg" style="display: none;">
<path stroke-linecap="round" d="M5,0 0,2.5 5,5z" id="raphael-marker-block" style="-webkit-tap-highlight-color: rgba(0, 0, 0, 0);"></path>
</svg>
<blockquote>
💗💗💗欢迎来到我的博客,你将找到有关如何使用技术解决问题的文章,也会找到某个技术的学习路线。无论你是何种职业,我都希望我的博客对你有所帮助。最后不要忘记订阅我的博客以获取最新文章,也欢迎在文章下方留下你的评论和反馈。我期待着与你分享知识、互相学习和建立一个积极的社区。谢谢你的光临,让我们一起踏上这个知识之旅!
文章目录
🥦引言
本文使用 IMDB 影评数据集,结合 PyTorch 实现门控卷积网络(GCNN)进行情感分析。
🥦完整代码
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
from torch import utils
import torchtext
from tqdm import tqdm
from torchtext.datasets import IMDB
from torchtext.datasets.imdb import NUM_LINES
from torchtext.data import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torchtext.data.functional import to_map_style_dataset
import os
import sys
import logging
import logging
# Route WARNING-and-above records to stdout with a timestamp, the emitting
# module and line number, and the level name in front of each message.
logging.basicConfig(
    level=logging.WARNING,
    stream=sys.stdout,
    format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
)
VOCAB_SIZE = 15000  # default number of rows in the models' embedding tables


# Step 1: the GCNN model, a gated convolutional network.
class GCNN(nn.Module):
    """Two-layer gated 1-D convolutional text classifier.

    Token ids are embedded, passed through two gated convolution layers
    (``A * sigmoid(B)``, where A and B come from two parallel convolutions),
    mean-pooled over the sequence axis and projected to class logits by two
    linear layers.

    Args:
        vocab_size: Number of rows in the embedding table.
        embedding_dim: Width of each token embedding.
        num_class: Number of output logits per sample.
    """

    _CHANNELS = 64
    _KERNEL_SIZE = 15
    _STRIDE = 7
    # Shortest input for which the second convolution still has a full
    # window: the first layer must emit at least _KERNEL_SIZE positions.
    _MIN_SEQ_LEN = _KERNEL_SIZE + _STRIDE * (_KERNEL_SIZE - 1)

    def __init__(self, vocab_size=VOCAB_SIZE, embedding_dim=64, num_class=2):
        super().__init__()

        self.embedding_table = nn.Embedding(vocab_size, embedding_dim)
        nn.init.xavier_uniform_(self.embedding_table.weight)

        # All convolutions are 1-D; each gated layer owns an A (content)
        # branch and a B (gate) branch with identical geometry.
        self.conv_A_1 = nn.Conv1d(embedding_dim, self._CHANNELS, self._KERNEL_SIZE, stride=self._STRIDE)
        self.conv_B_1 = nn.Conv1d(embedding_dim, self._CHANNELS, self._KERNEL_SIZE, stride=self._STRIDE)

        self.conv_A_2 = nn.Conv1d(self._CHANNELS, self._CHANNELS, self._KERNEL_SIZE, stride=self._STRIDE)
        self.conv_B_2 = nn.Conv1d(self._CHANNELS, self._CHANNELS, self._KERNEL_SIZE, stride=self._STRIDE)

        self.output_linear1 = nn.Linear(self._CHANNELS, 128)
        self.output_linear2 = nn.Linear(128, num_class)

    def forward(self, word_index):
        """Map a batch of token-id sequences to classification logits.

        Args:
            word_index: Integer tensor of shape [bs, seq_len].

        Returns:
            Float tensor of shape [bs, num_class] holding unnormalised logits.
        """
        # Sequences too short for the two strided convolutions would make the
        # second one fail, so right-pad them with index 0.
        # NOTE(review): 0 is the pad/<unk> id used by collate_fn in this file.
        seq_len = word_index.shape[-1]
        if seq_len < self._MIN_SEQ_LEN:
            word_index = F.pad(word_index, (0, self._MIN_SEQ_LEN - seq_len), value=0)

        # 1. Look up embeddings: [bs, seq_len, embedding_dim].
        word_embedding = self.embedding_table(word_index)

        # 2. Conv1d expects channels on dim 1: [bs, embedding_dim, seq_len].
        word_embedding = word_embedding.transpose(1, 2)

        # First gated layer; the stride shortens the sequence: [bs, 64, len_1].
        A = self.conv_A_1(word_embedding)
        B = self.conv_B_1(word_embedding)
        H = A * torch.sigmoid(B)

        # Second gated layer: [bs, 64, len_2].
        A = self.conv_A_2(H)
        B = self.conv_B_2(H)
        H = A * torch.sigmoid(B)

        # 3. Mean-pool over the sequence axis, then the two linear layers.
        pool_output = torch.mean(H, dim=-1)  # [bs, 64]
        linear1_output = self.output_linear1(pool_output)  # [bs, 128]
        logits = self.output_linear2(linear1_output)  # [bs, num_class]

        return logits
# A simple baseline model taken from the official PyTorch tutorial.
class TextClassificationModel(nn.Module):
    """Minimal bag-of-embeddings classifier: EmbeddingBag followed by a linear layer.

    Args:
        vocab_size: Number of rows in the embedding table.
        embed_dim: Width of each token embedding.
        num_class: Number of output logits per sample.
    """

    def __init__(self, vocab_size=VOCAB_SIZE, embed_dim=64, num_class=2):
        super().__init__()
        self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=False)
        self.fc = nn.Linear(embed_dim, num_class)

    def forward(self, token_index):
        """Return class logits of shape [bs, num_class] for a batch of token ids.

        Args:
            token_index: Integer tensor of token ids; assumed to be
                [bs, seq_len] as produced by collate_fn — TODO confirm.
        """
        # Bag of words: EmbeddingBag reduces each row to one vector.
        embedded = self.embedding(token_index)  # [bs, embed_dim]
        return self.fc(embedded)
# Step 2: build the IMDB DataLoader.
BATCH_SIZE = 64


def yeild_tokens(train_data_iter, tokenizer):
    """Yield the tokenized text of every sample in ``train_data_iter``.

    Args:
        train_data_iter: Iterable of ``(label, comment)`` pairs.
        tokenizer: Callable applied to each comment.

    Yields:
        Whatever ``tokenizer(comment)`` returns for each sample; the label is
        ignored.
    """
    for _label, comment in train_data_iter:
        yield tokenizer(comment)
train_data_iter = IMDB(root="./data", split="train")  # iterable of (label, comment) samples
tokenizer = get_tokenizer("basic_english")
# Keep only tokens that occur at least 20 times in the training split.
vocab = build_vocab_from_iterator(yeild_tokens(train_data_iter, tokenizer), min_freq=20, specials=["<unk>"])
vocab.set_default_index(0)  # tokens missing from the vocabulary resolve to index 0
print(f'单词表大小: {len(vocab)}')
# The models size their embedding tables from VOCAB_SIZE; a larger vocabulary
# would produce token ids that fall outside those tables.
if len(vocab) > VOCAB_SIZE:
    logging.warning(f"vocabulary size {len(vocab)} exceeds VOCAB_SIZE {VOCAB_SIZE}")
# Collate function: converts the list of samples drawn for one mini-batch
# into padded tensors.
def collate_fn(batch, *, tokenize=None, to_indices=None):
    """Post-process one mini-batch produced by the DataLoader.

    Args:
        batch: Sequence of ``(label, comment)`` samples.
        tokenize: Optional callable turning a comment into a token list;
            defaults to the module-level ``tokenizer``.
        to_indices: Optional callable turning a token list into a list of
            integer ids; defaults to the module-level ``vocab``.

    Returns:
        A ``(target, token_index)`` tuple: ``target`` is an int64 tensor of
        shape [bs] (0 for positive samples, 1 otherwise) and ``token_index``
        is an int32 tensor of shape [bs, max_len], right-padded with 0.
    """
    tokenize = tokenizer if tokenize is None else tokenize
    to_indices = vocab if to_indices is None else to_indices

    # Older torchtext releases label IMDB samples "pos"/"neg"; newer ones are
    # documented to use 2/1 — accept both spellings of "positive".
    positive_labels = ("pos", 2)

    target = []
    token_index = []
    for label, comment in batch:
        token_index.append(to_indices(tokenize(comment)))
        target.append(0 if label in positive_labels else 1)

    # Pad every sequence to the longest one in this batch.
    max_length = max(map(len, token_index), default=0)
    token_index = [index + [0] * (max_length - len(index)) for index in token_index]

    # F.one_hot (used by the training loss) requires int64 targets.
    return (torch.tensor(target).to(torch.int64), torch.tensor(token_index).to(torch.int32))
# Step 3: the training loop and its private helpers.
def _bce_loss(logits, target):
    """Return the mean binary cross-entropy between logits and one-hot targets."""
    # F.one_hot yields integers; the loss needs float32 targets.
    one_hot_target = F.one_hot(target, num_classes=logits.shape[-1]).to(torch.float32)
    # The *_with_logits form fuses the sigmoid into the loss for numerical stability.
    return F.binary_cross_entropy_with_logits(logits, one_hot_target)


def _evaluate(model, eval_data_loader):
    """Score ``model`` on ``eval_data_loader`` and return a dict of metrics."""
    metrics = {"accuracy": None, "precision": None, "recall": None, "f1": None}
    true_labels = []
    predicted_labels = []
    ema_eval_loss = 0.0

    model.eval()
    with torch.no_grad():  # no gradients are needed while scoring
        for eval_target, eval_token_index in eval_data_loader:
            eval_logits = model(eval_token_index)
            eval_bce_loss = _bce_loss(eval_logits, eval_target)
            ema_eval_loss = 0.9 * ema_eval_loss + 0.1 * eval_bce_loss.item()
            true_labels.extend(eval_target.tolist())
            predicted_labels.extend(torch.argmax(eval_logits, dim=-1).tolist())
    model.train()

    if not true_labels:
        logging.warning("evaluation skipped: eval_data_loader produced no samples")
        return metrics

    total_acc_account = sum(t == p for t, p in zip(true_labels, predicted_labels))
    # NOTE(review): scikit-learn's binary defaults score label 1 as the
    # "positive" class, while collate_fn in this file assigns 1 to negative
    # reviews — confirm that is the intended reading of precision/recall/F1.
    metrics["accuracy"] = accuracy_score(true_labels, predicted_labels)
    metrics["precision"] = precision_score(true_labels, predicted_labels)
    metrics["recall"] = recall_score(true_labels, predicted_labels)
    metrics["f1"] = f1_score(true_labels, predicted_labels)

    logging.warning(f"ema_eval_loss: {ema_eval_loss}, eval_acc: {total_acc_account / len(true_labels)}")
    logging.warning(f"Precision: {metrics['precision']}, Recall: {metrics['recall']}, F1: {metrics['f1']}, Accuracy: {metrics['accuracy']}")
    return metrics


def _save_checkpoint(save_path, epoch_index, step, model, optimizer, loss, metrics):
    """Write model/optimizer state plus the latest metrics to ``step_<step>.pt``."""
    os.makedirs(save_path, exist_ok=True)
    save_file = os.path.join(save_path, f"step_{step}.pt")
    torch.save({
        "epoch": epoch_index,
        "step": step,
        "model_state_dict": model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
        'accuracy': metrics["accuracy"],
        'precision': metrics["precision"],
        'recall': metrics["recall"],
        'f1': metrics["f1"],
    }, save_file)
    logging.warning(f"checkpoint has been saved in {save_file}")


def train(train_data_loader, eval_data_loader, model, optimizer, num_epoch, log_step_interval, save_step_interval, eval_step_interval, save_path, resume=""):
    """Train ``model``, periodically logging, evaluating and checkpointing.

    Args:
        train_data_loader: Sized loader yielding ``(target, token_index)`` batches.
        eval_data_loader: Loader with the same batch layout, used for evaluation.
        model: Module mapping ``token_index`` to per-class logits.
        optimizer: Optimizer over ``model``'s parameters.
        num_epoch: Epoch index at which training stops (exclusive).
        log_step_interval: Log the smoothed training loss every this many steps.
        save_step_interval: Write a checkpoint every this many steps.
        eval_step_interval: Run evaluation every this many steps.
        save_path: Directory that receives the ``step_<n>.pt`` checkpoints.
        resume: Path of a checkpoint to restore first; "" starts from scratch.
            Training restarts at the beginning of the checkpoint's epoch.
    """
    start_epoch = 0

    if resume != "":
        # Restore the parameters of a previous run.
        logging.warning(f"loading from {resume}")
        # NOTE: torch.load unpickles the file; only resume from trusted checkpoints.
        checkpoint = torch.load(resume)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch']

    # Metrics stay None until the first evaluation so that a checkpoint
    # written before any evaluation still has every key.
    metrics = {"accuracy": None, "precision": None, "recall": None, "f1": None}
    num_batches = len(train_data_loader)

    for epoch_index in tqdm(range(start_epoch, num_epoch), desc="epoch"):
        ema_loss = 0.0

        for batch_index, (target, token_index) in enumerate(train_data_loader):
            optimizer.zero_grad()
            step = num_batches * epoch_index + batch_index + 1  # global, 1-based

            logits = model(token_index)
            bce_loss = _bce_loss(logits, target)
            # .item() keeps the running average a plain float, so it does not
            # hold on to the autograd graph of every previous step.
            ema_loss = 0.9 * ema_loss + 0.1 * bce_loss.item()

            bce_loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 0.1)  # clip the gradient norm for stability
            optimizer.step()

            if step % log_step_interval == 0:
                logging.warning(f"epoch_index: {epoch_index}, batch_index: {batch_index}, ema_loss: {ema_loss}")

            # Evaluate before saving so that a checkpoint written on the same
            # step carries metrics for the weights it stores.
            if step % eval_step_interval == 0:
                logging.warning("start to do evaluation...")
                metrics = _evaluate(model, eval_data_loader)

            if step % save_step_interval == 0:
                _save_checkpoint(save_path, epoch_index, step, model, optimizer, bce_loss.detach(), metrics)
# Driver script: build the model, optimizer and both data loaders, then train.
# All string literals use plain ASCII quotes; typographic quotes are a
# SyntaxError in Python.
model = GCNN()
# Alternative baseline: model = TextClassificationModel()
print("模型总参数:", sum(p.numel() for p in model.parameters()))
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# The IMDB splits are converted to map-style datasets so the loader can
# index (and, for training, shuffle) them.
train_data_iter = IMDB(root="data", split="train")
train_data_loader = torch.utils.data.DataLoader(
    to_map_style_dataset(train_data_iter),
    batch_size=BATCH_SIZE,
    collate_fn=collate_fn,
    shuffle=True,
)

eval_data_iter = IMDB(root="data", split="test")
# collate_fn pads every mini-batch to its own longest comment.
eval_data_loader = torch.utils.data.DataLoader(
    to_map_style_dataset(eval_data_iter),
    batch_size=8,
    collate_fn=collate_fn,
)

# To continue from a checkpoint, set e.g. resume = "./data/step_500.pt".
resume = ""

train(
    train_data_loader,
    eval_data_loader,
    model,
    optimizer,
    num_epoch=10,
    log_step_interval=20,
    save_step_interval=500,
    eval_step_interval=300,
    save_path="./log_imdb_text_classification2",
    resume=resume,
)
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
- 25
- 26
- 27
- 28
- 29
- 30
- 31
- 32
- 33
- 34
- 35
- 36
- 37
- 38
- 39
- 40
- 41
- 42
- 43
- 44
- 45
- 46
- 47
- 48
- 49
- 50
- 51
- 52
- 53
- 54
- 55
- 56
- 57
- 58
- 59
- 60
- 61
- 62
- 63
- 64
- 65
- 66
- 67
- 68
- 69
- 70
- 71
- 72
- 73
- 74
- 75
- 76
- 77
- 78
- 79
- 80
- 81
- 82
- 83
- 84
- 85
- 86
- 87
- 88
- 89
- 90
- 91
- 92
- 93
- 94
- 95
- 96
- 97
- 98
- 99
- 100
- 101
- 102
- 103
- 104
- 105
- 106
- 107
- 108
- 109
- 110
- 111
- 112
- 113
- 114
- 115
- 116
- 117
- 118
- 119
- 120
- 121
- 122
- 123
- 124
- 125
- 126
- 127
- 128
- 129
- 130
- 131
- 132
- 133
- 134
- 135
- 136
- 137
- 138
- 139
- 140
- 141
- 142
- 143
- 144
- 145
- 146
- 147
- 148
- 149
- 150
- 151
- 152
- 153
- 154
- 155
- 156
- 157
- 158
- 159
- 160
- 161
- 162
- 163
- 164
- 165
- 166
- 167
- 168
- 169
- 170
- 171
- 172
- 173
- 174
- 175
- 176
- 177
- 178
- 179
- 180
- 181
- 182
- 183
- 184
- 185
- 186
- 187
- 188
- 189
- 190
- 191
- 192
- 193
- 194
- 195
- 196
- 197
- 198
- 199
- 200
- 201
- 202
- 203
- 204
- 205
- 206
- 207
- 208
- 209
- 210
- 211
- 212
- 213
- 214
- 215
- 216
- 217
- 218
- 219
- 220
- 221
- 222
- 223
- 224
- 225
- 226
- 227
- 228
- 229
- 230
- 231
- 232
- 233
- 234
- 235
- 236
- 237
- 238
- 239
- 240
- 241
- 242
- 243
- 244
- 245
- 246
- 247
- 248
- 249
- 250
- 251
- 252
🥦代码分析
🥦导库
首先导入需要的库
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
from torch import utils
import torchtext
from tqdm import tqdm
from torchtext.datasets import IMDB
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
-
torch (PyTorch):
PyTorch 是一个用于机器学习和深度学习的开源深度学习框架。它提供了张量计算、自动微分、神经网络层和优化器等功能,使用户能够构建和训练深度学习模型。 -
torch.nn:
torch.nn 模块包含了PyTorch中用于构建神经网络模型的类和函数。它包括各种神经网络层、激活函数和损失函数等(优化器位于 torch.optim 中)。 -
torch.nn.functional:
torch.nn.functional 模块提供了一组函数,用于构建神经网络的非参数化操作,如激活函数、池化和卷积等。这些函数通常与torch.nn一起使用。 -
sklearn.metrics (scikit-learn):
scikit-learn是一个用于机器学习的Python库,其中包含了一系列用于评估模型性能的度量工具。导入的precision_score、recall_score、f1_score 和 accuracy_score 用于计算分类模型的精确度、召回率、F1分数和准确性。 -
torch.utils:
torch.utils 包含了一些实用工具和数据加载相关的函数。在这段代码中,它用于构建数据加载器。 -
torchtext:
torchtext 是一个PyTorch的自然语言处理库,用于文本数据的处理和加载。它提供了用于文本数据预处理和构建数据集的功能。 -
tqdm:
tqdm 是一个Python库,用于创建进度条,可用于监视循环迭代的进度。在代码中,它用于显示训练和评估的进度。 -
torchtext.datasets.IMDB:
torchtext.datasets.IMDB 是TorchText库中的一个数据集,包含了IMDb电影评论的数据。这些评论用于情感分析任务,其中评论被标记为积极或消极。
🥦设置日志
# Configure the root logger: records at WARNING level and above are written
# to stdout, each prefixed with timestamp, "module:line" and level name.
logging.basicConfig(
level=logging.WARN, stream=sys.stdout, format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
)
- 1
- 2
- 3
在代码中设置日志的作用是记录程序的运行状态、调试信息和重要事件,以便在开发和生产环境中更轻松地诊断问题和了解程序的行为。设置日志有以下作用:
-
问题诊断:当程序出现错误或异常时,日志记录可以提供有关错误发生的位置、原因和上下文的信息。这有助于开发人员快速定位和修复问题。
-
性能分析:通过记录程序的运行时间和关键操作的时间戳,日志可以用于性能分析,帮助开发人员识别潜在的性能瓶颈。
-
跟踪进度:在长时间运行的任务中,例如训练深度学习模型,日志记录可以帮助跟踪任务的进度,以便了解训练状态、完成的步骤和剩余时间。
-
监控和警报:日志可以与监控系统集成,以便在发生关键事件或异常情况时触发警报。这对于及时响应问题非常重要。
-
审计和合规:在某些应用中,日志记录是合规性的一部分,用于追踪系统的操作和用户的活动。日志可以用于审计和调查。
在上述代码中,设置日志的目的是跟踪训练进度、记录训练损失和评估指标(日志本身并不负责保存检查点)。它允许开发人员监视模型训练的进展并在需要时查看详细信息,例如损失值和评估指标。此外,日志还可以用于调试和查看模型性能。
🥦模型定义
代码定义了两个模型:
GCNN:用于文本分类的门控卷积神经网络。
TextClassificationModel:使用嵌入和线性层的简单文本分类模型。
- 1
- 2
🥦GCNN
class GCNN(nn.Module):
    """Gated convolutional network for sentence classification.

    Token ids are embedded, passed through two gated 1-D convolution layers
    (``A * sigmoid(B)``), mean-pooled over the time axis and projected to
    class logits by two linear layers.

    Args:
        vocab_size: Number of rows in the embedding table.
        embedding_dim: Width of each token embedding.
        num_class: Number of output classes.
    """

    # Both gated layers share the same kernel size and stride.
    _KERNEL_SIZE = 15
    _STRIDE = 7
    # Shortest input for which the second convolution still sees one full
    # window: the first layer must emit at least _KERNEL_SIZE frames.
    _MIN_SEQ_LEN = _KERNEL_SIZE + (_KERNEL_SIZE - 1) * _STRIDE  # 113

    def __init__(self, vocab_size=VOCAB_SIZE, embedding_dim=64, num_class=2):
        super(GCNN, self).__init__()

        self.embedding_table = nn.Embedding(vocab_size, embedding_dim)
        nn.init.xavier_uniform_(self.embedding_table.weight)

        # Every convolution is 1-D; the A branch carries content and the B
        # branch produces the gate.
        self.conv_A_1 = nn.Conv1d(embedding_dim, 64, self._KERNEL_SIZE, stride=self._STRIDE)
        self.conv_B_1 = nn.Conv1d(embedding_dim, 64, self._KERNEL_SIZE, stride=self._STRIDE)

        self.conv_A_2 = nn.Conv1d(64, 64, self._KERNEL_SIZE, stride=self._STRIDE)
        self.conv_B_2 = nn.Conv1d(64, 64, self._KERNEL_SIZE, stride=self._STRIDE)

        self.output_linear1 = nn.Linear(64, 128)
        self.output_linear2 = nn.Linear(128, num_class)

    def forward(self, word_index):
        """Map a batch of token-id sequences to classification logits.

        Args:
            word_index: Integer tensor of shape ``[bs, seq_len]``.

        Returns:
            Tensor of shape ``[bs, num_class]``.
        """
        # Batches shorter than the conv stack's receptive field would make
        # Conv1d raise; right-pad them with index 0, the same value the
        # batch collation uses for padding.
        seq_len = word_index.size(-1)
        if seq_len < self._MIN_SEQ_LEN:
            word_index = F.pad(word_index, (0, self._MIN_SEQ_LEN - seq_len), value=0)

        # 1. Look up embeddings: [bs, seq_len, embedding_dim].
        word_embedding = self.embedding_table(word_index)

        # 2. First gated conv layer. Conv1d expects channels on dim 1, so
        #    transpose to [bs, embedding_dim, seq_len].
        word_embedding = word_embedding.transpose(1, 2)
        A = self.conv_A_1(word_embedding)
        B = self.conv_B_1(word_embedding)
        H = A * torch.sigmoid(B)  # [bs, 64, (seq_len - 15) // 7 + 1]

        # Second gated conv layer; the strided conv shortens the time axis again.
        A = self.conv_A_2(H)
        B = self.conv_B_2(H)
        H = A * torch.sigmoid(B)  # [bs, 64, shorter_len]

        # 3. Average over time, then the two linear layers.
        pool_output = torch.mean(H, dim=-1)  # [bs, 64]
        linear1_output = self.output_linear1(pool_output)  # [bs, 128]
        logits = self.output_linear2(linear1_output)  # [bs, num_class]

        return logits
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
- 25
- 26
- 27
- 28
- 29
- 30
- 31
- 32
- 33
- 34
- 35
- 36
- 37
- 38
- 39
- 40
- 41
- 42
- 43
🥦TextClassificationModel
class TextClassificationModel(nn.Module):
    """Minimal baseline: an ``EmbeddingBag`` followed by one linear layer."""

    def __init__(self, vocab_size=VOCAB_SIZE, embed_dim=64, num_class=2):
        super().__init__()
        self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=False)
        self.fc = nn.Linear(embed_dim, num_class)

    def forward(self, token_index):
        """Return class logits for a batch of token-id sequences."""
        # Bag-of-words pooling: each sequence becomes one vector,
        # presumably [bs, embed_dim] for a 2-D input.
        return self.fc(self.embedding(token_index))
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
🥦准备IMDb数据集
这行代码使用TorchText的IMDB数据集对象,导入IMDb数据集的训练集部分。
# Load the training split of the IMDB dataset, using ./data as the root directory.
train_data_iter = IMDB(root="./data", split="train")
- 1
- 2
这行代码创建了一个用于将文本分词为单词的分词器。
# Preprocessing: torchtext's "basic_english" tokenizer turns a raw comment
# string into a list of tokens.
tokenizer = get_tokenizer("basic_english")
- 1
- 2
这里,build_vocab_from_iterator 函数根据文本数据创建了一个词汇表,只包括出现频率大于等于20次的单词。特殊标记 <unk> 用于表示未知单词。然后,set_default_index(0) 将默认索引设为0(即 <unk> 的索引),这样词汇表之外的单词都会被映射为 <unk>。
# Build the vocabulary from the tokenized training comments, keeping only
# tokens that occur at least 20 times and registering "<unk>" as a special token.
vocab = build_vocab_from_iterator(yeild_tokens(train_data_iter, tokenizer), min_freq=20, specials=["<unk>"])
# Out-of-vocabulary tokens are looked up as index 0
# (presumably the "<unk>" entry, since specials are inserted first by default).
vocab.set_default_index(0)
- 1
- 2
- 3
这是一个自定义的整理函数(collate function),用于把DataLoader取出的一批样本整理成一个mini-batch,并将文本转换为可以输入模型的张量形式。
def collate_fn(batch):
    """Collate a list of ``(label, comment)`` samples into padded tensors.

    Each comment is tokenized and mapped to vocabulary indices, then every
    index list is right-padded with 0 to the longest comment in the batch.

    Args:
        batch: Sequence of ``(label, comment)`` pairs. A positive label may
            be the string ``"pos"`` (older torchtext) or the integer ``2``
            (newer torchtext); anything else counts as negative.

    Returns:
        Tuple ``(target, token_index)``: ``target`` is an int64 tensor of
        shape ``[bs]`` holding 0 for positive and 1 for negative, and
        ``token_index`` is an int32 tensor of shape ``[bs, max_length]``.
    """
    target = []
    token_index = []
    max_length = 0  # longest token sequence seen in this batch

    for label, comment in batch:
        tokens = tokenizer(comment)
        token_index.append(vocab(tokens))  # token strings -> vocabulary indices
        max_length = max(max_length, len(tokens))

        # Accept both label encodings; comparing only against "pos" would
        # turn every integer-labelled sample into class 1.
        target.append(0 if label in ("pos", 2) else 1)

    # Right-pad with index 0 so all rows have the same length.
    token_index = [index + [0] * (max_length - len(index)) for index in token_index]

    # Targets are int64, the integer type expected for class indices.
    return (torch.tensor(target).to(torch.int64), torch.tensor(token_index).to(torch.int32))
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
这行代码将IMDb训练数据集加载到DataLoader对象中,以便进行模型训练。collate_fn函数用于处理数据的批处理。
# Wrap the training split in a shuffling DataLoader. The dataset is first
# converted to map-style, and collate_fn assembles each mini-batch.
train_data_loader = torch.utils.data.DataLoader(
to_map_style_dataset(train_data_iter), batch_size=BATCH_SIZE, collate_fn=collate_fn, shuffle=True)
- 1
- 2
上述代码块执行了IMDb数据集的准备工作,包括导入数据、分词、构建词汇表和设置数据加载器。这些步骤是为了使数据集可用于训练文本分类模型。
🥦整理函数
这个 collate_fn 函数用于对 DataLoader 批次中的数据进行处理,确保每个批次中的文本序列具有相同的长度,并将标签转换为适用于模型输入的张量形式。它的工作包括以下几个方面:
提取标签和评论文本。
使用分词器将评论文本分词为单词。
确定批次中最长评论的长度。
根据最长评论的长度,将所有评论的单词索引序列填充到相同的长度。
将标签转换为适当的张量形式(这里是将标签转换为长整数型)。
返回处理后的批次数据,其中包括标签和填充后的单词索引序列。
- 1
- 2
- 3
- 4
- 5
- 6
这个整理函数确保了模型在训练期间能够处理不同长度的文本序列,并将它们转换为模型可接受的张量输入。
🥦训练函数
def _evaluate(model, eval_data_loader):
    """Run one full pass over ``eval_data_loader`` and return its metrics.

    The model is switched to eval mode for the pass and switched back to
    train mode afterwards, even if the pass raises.

    Returns:
        A dict with the keys ``accuracy``, ``precision``, ``recall`` and ``f1``
        computed with the scikit-learn metric functions over all eval batches.
    """
    logging.warning("start to do evaluation...")
    model.eval()
    try:
        ema_eval_loss = 0.0
        total_acc_account = 0
        total_account = 0
        eval_true_labels = []
        eval_predicted_labels = []

        # Evaluation never calls backward(), so skip building the autograd graph.
        with torch.no_grad():
            for eval_target, eval_token_index in eval_data_loader:
                total_account += eval_target.shape[0]
                eval_logits = model(eval_token_index)
                eval_predictions = torch.argmax(eval_logits, dim=-1)
                total_acc_account += (eval_predictions == eval_target).sum().item()

                # one_hot yields integers; the loss needs float32 targets.
                eval_bce_loss = F.binary_cross_entropy_with_logits(
                    eval_logits,
                    F.one_hot(eval_target, num_classes=2).to(torch.float32),
                )
                ema_eval_loss = 0.9 * ema_eval_loss + 0.1 * eval_bce_loss.item()

                eval_true_labels.extend(eval_target.tolist())
                eval_predicted_labels.extend(eval_predictions.tolist())

        accuracy = accuracy_score(eval_true_labels, eval_predicted_labels)
        precision = precision_score(eval_true_labels, eval_predicted_labels)
        recall = recall_score(eval_true_labels, eval_predicted_labels)
        f1 = f1_score(eval_true_labels, eval_predicted_labels)

        logging.warning(f"ema_eval_loss: {ema_eval_loss}, eval_acc: {total_acc_account / total_account}")
        logging.warning(f"Precision: {precision}, Recall: {recall}, F1: {f1}, Accuracy: {accuracy}")
    finally:
        model.train()

    return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1}


def train(train_data_loader, eval_data_loader, model, optimizer, num_epoch, log_step_interval,
          save_step_interval, eval_step_interval, save_path, resume=""):
    """Train ``model``, periodically logging, evaluating and checkpointing.

    The data loaders here wrap map-style datasets and yield
    ``(target, token_index)`` batches.

    Args:
        train_data_loader: Loader iterated once per epoch for training.
        eval_data_loader: Loader iterated in full at every evaluation.
        model: Module called as ``model(token_index)``; its output is treated
            as two-class logits.
        optimizer: Optimizer that updates ``model``'s parameters.
        num_epoch: Epoch index at which training stops (exclusive).
        log_step_interval: Log the smoothed training loss every this many steps.
        save_step_interval: Write a checkpoint every this many steps.
        eval_step_interval: Run an evaluation every this many steps.
        save_path: Directory for ``step_<step>.pt`` checkpoint files.
        resume: Path of a checkpoint to resume from; ``""`` starts from scratch.
            Resuming restarts at the beginning of the epoch stored in the
            checkpoint.

    Returns:
        The trained ``model`` (the same object that was passed in).
    """
    start_epoch = 0
    if resume != "":
        # Restore model and optimizer state from an earlier run.
        logging.warning("loading from resume")
        checkpoint = torch.load(resume)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch']

    # Metrics of the most recent evaluation. They stay None until the first
    # evaluation has run, so a checkpoint written earlier never hits an
    # undefined name.
    eval_metrics = {"accuracy": None, "precision": None, "recall": None, "f1": None}

    num_batches = len(train_data_loader)
    for epoch_index in tqdm(range(start_epoch, num_epoch), desc="epoch"):
        ema_loss = 0.0

        for batch_index, (target, token_index) in enumerate(train_data_loader):
            optimizer.zero_grad()
            # Global 1-based step counter across epochs.
            step = num_batches * epoch_index + batch_index + 1

            logits = model(token_index)
            # one_hot yields integers; the loss needs float32 targets.
            bce_loss = F.binary_cross_entropy_with_logits(
                logits,
                F.one_hot(target, num_classes=2).to(torch.float32),
            )
            # Exponential moving average of the loss. Using .item() keeps it a
            # plain float so the autograd graph of every step is not retained.
            ema_loss = 0.9 * ema_loss + 0.1 * bce_loss.item()

            bce_loss.backward()
            # Clip the gradient norm to keep training stable.
            nn.utils.clip_grad_norm_(model.parameters(), 0.1)
            optimizer.step()

            if step % log_step_interval == 0:
                logging.warning(f"epoch_index: {epoch_index}, batch_index: {batch_index}, ema_loss: {ema_loss}")

            # Evaluate before saving so that a checkpoint written on the same
            # step carries the metrics of the model state it stores.
            if step % eval_step_interval == 0:
                eval_metrics = _evaluate(model, eval_data_loader)

            if step % save_step_interval == 0:
                os.makedirs(save_path, exist_ok=True)
                save_file = os.path.join(save_path, f"step_{step}.pt")
                torch.save({
                    "epoch": epoch_index,
                    "step": step,
                    "model_state_dict": model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': bce_loss.detach(),
                    'accuracy': eval_metrics["accuracy"],
                    'precision': eval_metrics["precision"],
                    'recall': eval_metrics["recall"],
                    'f1': eval_metrics["f1"],
                }, save_file)
                logging.warning(f"checkpoint has been saved in {save_file}")

    return model
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
- 25
- 26
- 27
- 28
- 29
- 30
- 31
- 32
- 33
- 34
- 35
- 36
- 37
- 38
- 39
- 40
- 41
- 42
- 43
- 44
- 45
- 46
- 47
- 48
- 49
- 50
- 51
- 52
- 53
- 54
- 55
- 56
- 57
- 58
- 59
- 60
- 61
- 62
- 63
- 64
- 65
- 66
- 67
- 68
- 69
- 70
- 71
- 72
- 73
- 74
- 75
- 76
- 77
- 78
- 79
- 80
- 81
- 82
- 83
- 84
- 85
- 86
- 87
- 88
- 89
- 90
- 91
- 92
- 93
- 94
- 95
- 96
- 97
这段代码定义了一个名为 train 的函数,用于执行训练过程。下面是该函数的详细说明:
train 函数接受以下参数:
train_data_loader: 训练数据的 DataLoader,用于迭代训练数据。
eval_data_loader: 用于评估的 DataLoader,用于评估模型性能。
model: 要训练的神经网络模型。
optimizer: 用于更新模型参数的优化器。
num_epoch: 训练的总周期数。
log_step_interval: 记录日志的间隔步数。
save_step_interval: 保存模型检查点的间隔步数。
eval_step_interval: 执行评估的间隔步数。
save_path: 保存模型检查点的目录。
resume: 可选的,用于恢复训练的检查点文件路径。
训练函数的主要工作如下:
它首先检查是否有恢复训练的检查点文件。如果有,它会加载之前训练的模型参数和优化器状态,以便继续训练。
然后,它开始进行一系列的训练周期(epochs),每个周期内包含多个训练步(batches)。
在每个训练步中,它执行以下操作:
零化梯度,以准备更新模型参数。
计算模型的预测输出(logits)。
计算二进制交叉熵损失(binary cross-entropy loss)。
使用反向传播(backpropagation)计算梯度并更新模型参数。
记录损失、真实标签和预测标签。
如果步数达到了 log_step_interval,则记录损失。
如果步数达到了 save_step_interval,则保存模型检查点。
如果步数达到了 eval_step_interval,则执行评估:
将模型切换到评估模式(model.eval())。
对评估数据集中的每个批次执行以下操作:
计算模型的预测输出。
计算二进制交叉熵损失。
计算准确性、精确度、召回率和F1分数。
记录评估损失和评估指标。
将模型切换回训练模式(model.train())。
最后,训练函数返回经过训练的模型。
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
- 25
- 26
- 27
- 28
- 29
- 30
- 31
- 32
- 33
这个训练函数执行了完整的训练过程,包括了模型的前向传播、损失计算、梯度更新、日志记录、模型检查点的保存和评估。通过调用这个函数,你可以训练模型并监视其性能。
🥦模型初始化和优化器
# Build the gated-CNN classifier with its default hyper-parameters and an
# Adam optimizer over all of its parameters.
model = GCNN()
# model = TextClassificationModel()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)
- 1
- 2
- 3
🥦加载用于训练和评估的数据
在提供的代码中,加载用于训练和评估的数据的部分如下:
# Training split of the torchtext IMDB dataset, stored under ./data.
train_data_iter = IMDB(split="train", root="data")
- 1
这一行代码使用 TorchText 的 IMDB 数据集对象,导入 IMDB 数据集的训练集部分。这部分数据将用于模型的训练。
# Test split of the torchtext IMDB dataset, used here for evaluation.
eval_data_iter = IMDB(split="test", root="data")
- 1
这一行代码使用 TorchText 的 IMDB 数据集对象,导入 IMDB 数据集的测试集部分。这部分数据将用于评估模型的性能。
之后,这些数据集通过以下代码转化为 DataLoader 对象,以便用于模型训练和评估:
# Training DataLoader: shuffled batches of BATCH_SIZE samples.
train_data_loader = torch.utils.data.DataLoader(
    dataset=to_map_style_dataset(train_data_iter),
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=collate_fn,
)
- 1
- 2
- 3
# Evaluation DataLoader: fixed batches of 8 samples, no shuffling.
eval_data_loader = utils.data.DataLoader(
    dataset=to_map_style_dataset(eval_data_iter),
    batch_size=8,
    collate_fn=collate_fn,
)
- 1
- 2
- 3
这些 DataLoader 对象将数据加载到内存中,以便训练和评估使用。collate_fn 函数用于处理数据的批次,确保它们具有适当的格式,以便输入到模型中。
这些部分负责加载和准备用于训练和评估的数据,是机器学习模型训练和评估的重要准备步骤。训练数据用于训练模型,而评估数据用于评估模型的性能。
🥦恢复训练
start_epoch, start_step = 0, 0
if resume != "":
    # Load the checkpoint file written by an earlier training run
    logging.warning(f"loading from resume")
    checkpoint = torch.load(resume)
    # Restore model weights and optimizer state
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    # Pick up the epoch/step counters stored in the checkpoint
    start_epoch, start_step = checkpoint['epoch'], checkpoint['step']
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
上述代码段位于训练函数中的开头部分,主要用于检查是否有已经训练过的模型的检查点文件,以便继续训练。具体解释如下:
如果 resume 变量不为空(即存在要恢复的检查点文件路径),则执行以下操作:
通过 torch.load 加载之前训练过的模型的检查点文件。
使用 load_state_dict 方法将已保存的模型参数加载到当前的模型中,以便继续训练。
同样,使用 load_state_dict 方法将已保存的优化器状态加载到当前的优化器中,以确保继续从之前的状态开始训练。
获取之前训练的轮数和步数,以便从恢复的状态继续训练。
这部分代码的目的是允许从之前保存的模型检查点继续训练,而不是从头开始。这对于长时间运行的训练任务非常有用,可以在中途中断训练并在之后恢复,而不会丢失之前的训练进度。
🥦调用训练
# Train for 10 epochs: log every 20 steps, evaluate every 300, checkpoint every 500.
train(
    train_data_loader,
    eval_data_loader,
    model,
    optimizer,
    num_epoch=10,
    log_step_interval=20,
    save_step_interval=500,
    eval_step_interval=300,
    save_path="./log_imdb_text_classification2",
    resume=resume,
)
- 1
🥦保存文件的读取
import torch
# Path of an existing .pt checkpoint file
file_path = "./log_imdb_text_classification/step_3500.pt"  # replace with the actual file path
# Load the checkpoint dict with torch.load()
checkpoint = torch.load(file_path)
# Read the accuracy, precision, recall and F1 score stored in the checkpoint
accuracy = checkpoint["accuracy"]
precision = checkpoint["precision"]
recall = checkpoint["recall"]
f1 = checkpoint["f1"]
print("Accuracy:", accuracy)
print("Precision:", precision)
print("Recall:", recall)
print("F1 Score:", f1)
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
🥦扩展 LSTM、GRU
本文原作者使用的是门控卷积神经网络(GCNN,Gated CNN),它是卷积神经网络的一种改进模型。由于循环神经网络更擅长对文本这类序列数据建模,接下来我引入两个循环神经网络LSTM和GRU,用来与GCNN进行对比。
class LSTMModel(nn.Module):
    """Text classifier: embedding -> single-layer LSTM -> linear head."""

    def __init__(self, vocab_size=VOCAB_SIZE, embedding_dim=64, hidden_dim=64, num_class=2):
        super().__init__()
        self.embedding_table = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim)
        # batch_first=True: the LSTM consumes and produces (batch, seq, feature) tensors.
        self.lstm = nn.LSTM(input_size=embedding_dim, hidden_size=hidden_dim, num_layers=1, batch_first=True)
        self.output_linear = nn.Linear(in_features=hidden_dim, out_features=num_class)

    def forward(self, word_index):
        """Return class logits computed from the LSTM output at the last time step."""
        embedded = self.embedding_table(word_index)
        sequence_output = self.lstm(embedded)[0]
        # Keep only the output of the final time step of every sequence.
        final_step = sequence_output.select(dim=1, index=-1)
        return self.output_linear(final_step)
class GRUModel(nn.Module):
    """Text classifier: embedding -> single-layer GRU -> linear head.

    Args:
        vocab_size: Number of rows in the embedding table.
        embedding_dim: Size of each word embedding (GRU input size).
        hidden_dim: Size of the GRU hidden state.
        num_class: Number of output logits.
    """

    def __init__(self, vocab_size=VOCAB_SIZE, embedding_dim=64, hidden_dim=64, num_class=2):
        # The constructor must be named __init__ (not init): otherwise it is
        # never invoked and none of the layers below get registered.
        super(GRUModel, self).__init__()
        self.embedding_table = nn.Embedding(vocab_size, embedding_dim)
        # batch_first=True: the GRU consumes and produces (batch, seq, feature) tensors.
        self.gru = nn.GRU(embedding_dim, hidden_dim, num_layers=1, batch_first=True)
        self.output_linear = nn.Linear(hidden_dim, num_class)

    def forward(self, word_index):
        """Return class logits computed from the GRU output at the last time step.

        NOTE(review): collate_fn pads shorter sequences with index 0 at the
        end, so for those sequences the last time step is a padding position.
        """
        word_embedding = self.embedding_table(word_index)
        gru_out, _ = self.gru(word_embedding)
        gru_out = gru_out[:, -1, :]  # output of the last time step
        logits = self.output_linear(gru_out)
        return logits
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
- 25
- 26
- 27
# Build the LSTM model, report its parameter count, and create its optimizer
lstm_model = LSTMModel()
print("模型总参数:", sum(map(torch.numel, lstm_model.parameters())))
lstm_optimizer = torch.optim.Adam(params=lstm_model.parameters(), lr=1e-3)
# Build the GRU model (uncomment to use it instead)
# gru_model = GRUModel()
# print("模型总参数:", sum(p.numel() for p in gru_model.parameters()))
# gru_optimizer = torch.optim.Adam(gru_model.parameters(), lr=0.001)
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
# Train the LSTM model from scratch (no checkpoint to resume from)
train(
    train_data_loader,
    eval_data_loader,
    lstm_model,
    lstm_optimizer,
    num_epoch=10,
    log_step_interval=20,
    save_step_interval=500,
    eval_step_interval=300,
    save_path="./log_imdb_lstm",
    resume="",
)
# Train the GRU model (uncomment together with the GRU model creation)
# train(train_data_loader, eval_data_loader, gru_model, gru_optimizer, num_epoch=10, log_step_interval=20, save_step_interval=500, eval_step_interval=300, save_path="./log_imdb_gru", resume="")
- 1
- 2
- 3
- 4
- 5
感兴趣的小伙伴可以试试,对比一下
🥦总结
本文代码来自网络仅供学习,原文地址
挑战与创造都是很痛苦的,但是很充实。