关键词短语生成的无监督方法11——Train.py

本文详细介绍了Train.py在Seq2Seq模型训练中的作用,包括模型构建、数据处理、优化器使用和学习率调度。讲解了torch.optim包的优化算法,如StepLR的更新策略,并展示了如何根据训练epoch动态调整学习率。
摘要由CSDN通过智能技术生成

2021SC@SDUSC

一、Train.py概述

在Extract.py中提取出sliver label、Model.py中搭建好模型Encoder/Decoder构件、my_dataset.py中处理并加载好用于训练模型的数据集后,在Train.py中训练seq2seq模型。
主要分为三步,构建模型及对加载入模型的数据处理、训练模型和评估模型。
首先来看构建模型及数据处理的相关实现。

二、源码分析

关键包导入

import torch
import pickle
from torch.utils.data import DataLoader
from my_dataloader import *
from create_vocabulary import *
from Model import Encoder, Decoder, Seq2Seq
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
  • 从torch.utils.data包中导入DataLoader类,为了加载Dataset数据。
    为了获取训练数据集导入my_dataloader。
    为了获取模型构件导入Model。

  • torch.nn和torch.optim是pytorch中最常用的两个包。在之前的博客里学习分析过torch.nn包的常见用法,在此不再赘述。重点看torch.optim包的用法。

  • torch.optim包主要包含了用来更新参数的优化算法,比如SGD、AdaGrad、RMSProp、 Adam等。在用torch.nn包定义完网络后用torch.optim定义损失函数和优化方法。示例如下。

import torch.optim as optim#导入上面的torch.nn包之后还需导入torch.optim包
 
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

定义好了网络、损失函数、优化器之后,常规训练网络的做法如下。

#将输入输进网络得到输出y_pred
y_pred = model(x)
#运用上面定义的loss函数计算网络输出与标签之间的距离
loss = loss_fn(y_pred, y)
#在反向传播之前需要将优化器中的梯度值清零,因为在默认情况下反向传播的梯度值会进行累加
optimizer.zero_grad()
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值