[Pytorch] Sequence-to-Sequence Decoder 代码学习

本文详细介绍了如何实现PyTorch中的Sequence-to-Sequence Decoder,包括import部分、初始化参数、forward one step过程及主forward进程。通过理解encoder是否双向、RNN结构、最大解码长度、注意力机制等关键点,解析了decoder如何根据上一步隐藏状态和encoder输出计算当前步的预测概率、隐藏状态和注意力分布。
摘要由CSDN通过智能技术生成

虽然对 encoder-decoder 框架的了解已经很多了,但是从未实现过,可谓是“最熟悉的陌生人了”。近期,由于研究的需要,故而参照 github 上某开源项目(pytorch-seq2seq),实现了一个句法分析系统。本文,来学习一下实现的 decoder 部分的代码。

首先是import 部分的代码

import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

from attention import Attention
from baseRNN import BaseRNN

if torch.cuda.is_available():
    import torch.cuda as device
else:
    import torch as device

在 import部分:首先导入 numpy以及 torch 中需要使用到的模块。除了公共包,此处导入了名叫 Attention以及 BaseRNN 的模块,其中BaseRNN 为对 torch.nn.rnn 模块的一个wrapper, Attention的机制也是在 seq2seq 中一个很重要的部分,用于获取解码时对于解码中某一时刻最为 care 的信息,很简短的代码,留待以后补充。


init部分:

def __init__(self, vocab_size, max_len, input_size, hidden_size,
                 sos_id, eos_id,
                 n_layers=1, rnn_cell='gru', bidirectional=False,
                 input_dropout_p=0, dropout_p=0, use_attention=False):
        super(DecoderRNN, self).__init__(vocab_size, max_len, input_size, hidden_size,input_dropout_p, dropout_p,n_layers, rnn_cell)

        self.bidirectional_encoder = bidirectional
        self.rnn = self.rnn_cell(input_size, hidden_size, n_layers, batch_first=True, dropout=dropout_p)

        self.output_size = vocab_size
        self.max_length = max_len
        self.use_attention = use_attention
        self.eos_id = eos_id
        self.sos_id = sos_id

        self.init_input = None

        self.embedding = nn.Embedding(self.output_size, self.input_size)
        if use_attention:
            self.attention = Attention(self.hidden_size)

        self.fflayer = nn.Linear(self.hidden_size, self.output_size)

以上是一系列解码过程中需要使用到的参数。

  • bidirectional: 指明 encoder 端的输入是否为 bidirectional,用于初始化 encoder hidden 
  • rnn: decoder 端为一个 rnn
  • output_size: decoder端 output 的“词表”大小
  • max_length: 最长解码长度
  • use_attention: 是否在解码端使用注意力机制构建 feature 表示
  • eos_id: 辅助用于判断解码终止
  • sos_id: 辅助用于解码端的第一个输入
  • init_input: 目前没什么用
  • embedding: 解码端 output的词表 embedding
  • fflayer: 在解码时提供计算 output 的

以下则进入我们解码时的每一步时执行的操作,即为 forward one step

def forward_step(self, input_var, hidden, encoder_outputs, function):
        """
        Args:
            input_var: input token ids
            hidden: last hidden state
            encoder_outputs: encoder-layer output
            function: probs function, default is F.log_softmax
        Return:
            the softmax output, the hidden state save, and the attention value
        """
        batch_size = input_var.size(0)
        output_size = input_var.size(1)
        embedded = self.embedding(input_var)
        embedded = self.input_dropout(embe
  • 1
    点赞
  • 14
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值