2021SC@SDUSC
文章目录
一、Train.py概述
在Extract.py中提取出sliver label、Model.py中搭建好模型Encoder/Decoder构件、my_dataset.py中处理并加载好用于训练模型的数据集后,在Train.py中训练seq2seq模型。
主要分为三步,构建模型及对加载入模型的数据处理、训练模型和评估模型。
首先来看构建模型及数据处理的相关实现。
二、源码分析
关键包导入
import torch
import pickle
from torch.utils.data import DataLoader
from my_dataloader import *
from create_vocabulary import *
from Model import Encoder, Decoder, Seq2Seq
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
-
从torch.utils.data包中导入DataLoader类,为了加载Dataset数据。
为了获取训练数据集导入my_dataloader。
为了获取模型构件导入Model。 -
torch.nn和torch.optim是pytorch中最常用的两个包。在之前的博客里学习分析过torch.nn包的常见用法,在此不再赘述。重点看torch.optim包的用法。
-
torch.optim包主要包含了用来更新参数的优化算法,比如SGD、AdaGrad、RMSProp、 Adam等。在用torch.nn包定义完网络后用torch.optim定义损失函数和优化方法。示例如下。
import torch.optim as optim#导入上面的torch.nn包之后还需导入torch.optim包
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
定义好了网络、损失函数、优化器之后,常规训练网络的做法如下。
#将输入输进网络得到输出y_pred
y_pred = model(x)
#运用上面定义的loss函数计算网络输出与标签之间的距离
loss = loss_fn(y_pred, y)
#在反向传播之前需要将优化器中的梯度值清零,因为在默认情况下反向传播的梯度值会进行累加
optimizer.zero_grad()