Handling Variable-Length Sequences with LSTM in PyTorch

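Sequences fed to an LSTM often have different lengths. This post shows how to handle variable-length sequences with an LSTM in PyTorch, using two functions from the torch.nn.utils.rnn module: pack_padded_sequence() and pad_packed_sequence(). "Pack" means compressing the padded batch so that the padding positions are dropped; "pad" means filling the result back out into a rectangular tensor.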

 

The pack_padded_sequence() function

def pack_padded_sequence(input, lengths, batch_first=False, enforce_sorted=True):
    # type: (Tensor, Tensor, bool, bool) -> PackedSequence
    r"""Packs a Tensor containing padded sequences of variable length.

    :attr:`input` can be of size ``T x B x *`` where `T` is the length of the
    longest sequence (equal to ``lengths[0]``), ``B`` is the batch size, and
    ``*`` is any number of dimensions (including 0). If ``batch_first`` is
    ``True``, ``B x T x *`` :attr:`input` is expected.

    For unsorted sequences, use `enforce_sorted = False`. If :attr:`enforce_sorted` is
    ``True``, the sequences should be sorted by length in a decreasing order, i.e.
    ``input[:,0]`` should be the longest sequence, and ``input[:,B-1]`` the shortest
    one. `enforce_sorted = True` is only necessary for ONNX export.

    Note:
        This function accepts any input that has at least two dimensions. You
        can apply it to pack the labels, and use the output of the RNN with
        them to compute the loss directly. A Tensor can be retrieved from
        a :class:`PackedSequence` object by accessing its ``.data`` attribute.

    Arguments:
        input (Tensor): padded batch of variable length sequences.
        lengths (Tensor): list of sequences lengths of each batch element.
        batch_first (bool, optional): if ``True``, the input is expected in ``B x T x *``
            format.
        enforce_sorted (bool, optional): if ``True``, the input is expected to
            contain sequences sorted by length in a decreasing order. If
            ``False``, the input will get sorted unconditionally. Default: ``True``.

    Returns:
        a :class:`PackedSequence` object
    """

The pad_packed_sequence() function

def pad_packed_sequence(sequence, batch_first=False, padding_value=0.0, total_length=None):
    # type: (PackedSequence, bool, float, Optional[int]) -> Tuple[Tensor, Tensor]
    r"""Pads a packed batch of variable length sequences.

    It is an inverse operation to :func:`pack_padded_sequence`.

    The returned Tensor's data will be of size ``T x B x *``, where `T` is the length
    of the longest sequence and `B` is the batch size. If ``batch_first`` is True,
    the data will be transposed into ``B x T x *`` format.

    Example:
        >>> from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
        >>> seq = torch.tensor([[1,2,0], [3,0,0], [4,5,6]])
        >>> lens = [2, 1, 3]
        >>> packed = pack_padded_sequence(seq, lens, batch_first=True, enforce_sorted=False)
        >>> packed
        PackedSequence(data=tensor([4, 1, 3, 5, 2, 6]), batch_sizes=tensor([3, 2, 1]),
                       sorted_indices=tensor([2, 0, 1]), unsorted_indices=tensor([1, 2, 0]))
        >>> seq_unpacked, lens_unpacked = pad_packed_sequence(packed, batch_first=True)
        >>> seq_unpacked
        tensor([[1, 2, 0],
                [3, 0, 0],
                [4, 5, 6]])
        >>> lens_unpacked
        tensor([2, 1, 3])

    .. note::
        :attr:`total_length` is useful to implement the
        ``pack sequence -> recurrent network -> unpack sequence`` pattern in a
        :class:`~torch.nn.Module` wrapped in :class:`~torch.nn.DataParallel`.
        See :ref:`this FAQ section <pack-rnn-unpack-with-data-parallelism>` for
        details.

    Arguments:
        sequence (PackedSequence): batch to pad
        batch_first (bool, optional): if ``True``, the output will be in ``B x T x *``
            format.
        padding_value (float, optional): values for padded elements.
        total_length (int, optional): if not ``None``, the output will be padded to
            have length :attr:`total_length`. This method will throw :class:`ValueError`
            if :attr:`total_length` is less than the max sequence length in
            :attr:`sequence`.

    Returns:
        Tuple of Tensor containing the padded sequence, and a Tensor
        containing the list of lengths of each sequence in the batch.
        Batch elements will be re-ordered as they were ordered originally when
        the batch was passed to ``pack_padded_sequence`` or ``pack_sequence``.

    """

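The workflow is: pack the padded input according to the sequence lengths, run the packed sequence through the LSTM, and then pad the output back into a regular tensor. Because the padding positions are removed before the LSTM runs, they do not affect the outputs or final hidden states of the shorter sequences.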
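Padding the output back is not always necessary. The Note in the pack_padded_sequence() docstring says the labels can be packed as well and compared with the RNN output's .data directly. A minimal sketch of that idea, assuming a hypothetical per-token 3-class tagging task (the labels tensor and the classifier layer are made up for illustration); it also checks that the result matches a masked loss computed on the padded output:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

torch.manual_seed(0)
a = torch.tensor([[1, 2, 3], [1, 2, 0], [3, 0, 0], [2, 1, 0]])
lengths = torch.tensor([3, 2, 1, 2])
# Hypothetical per-token labels for a 3-class tagger; padded slots are never read
labels = torch.tensor([[0, 1, 2], [2, 1, 0], [1, 0, 0], [0, 2, 0]])

emb = nn.Embedding(4, 2, padding_idx=0)
lstm = nn.LSTM(input_size=2, hidden_size=6, batch_first=True)
classifier = nn.Linear(6, 3)             # hypothetical output layer

# Pack inputs and labels with the same lengths so both get the same ordering
packed_in = pack_padded_sequence(emb(a), lengths, batch_first=True, enforce_sorted=False)
packed_labels = pack_padded_sequence(labels, lengths, batch_first=True, enforce_sorted=False)
packed_out, _ = lstm(packed_in)

# .data holds only the real time steps: sum(lengths) = 8 rows
logits = classifier(packed_out.data)     # shape (8, 3)
loss_packed = F.cross_entropy(logits, packed_labels.data)

# The same loss computed the padded way, masking out the padding positions
out, _ = pad_packed_sequence(packed_out, batch_first=True)
mask = torch.arange(out.size(1))[None, :] < lengths[:, None]   # (4, 3) bool
loss_padded = F.cross_entropy(classifier(out[mask]), labels[mask])

print(logits.shape)                               # torch.Size([8, 3])
print(torch.allclose(loss_packed, loss_padded))   # True
```

Working on packed_out.data avoids both the pad step and the mask, since padding positions never appear in it.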

Code (sorting the batch by length manually):

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence


# A padded batch of 4 token-id sequences; 0 is the padding index
a = torch.tensor([[1, 2, 3],
                  [1, 2, 0],
                  [3, 0, 0],
                  [2, 1, 0]])
lengths = torch.tensor([3, 2, 1, 2])

# Sort by length in descending order (required when enforce_sorted=True)
a_lengths, idx = lengths.sort(0, descending=True)
print(a_lengths)    # tensor([3, 2, 2, 1])
print(idx)          # e.g. tensor([0, 1, 3, 2]); the order of equal lengths may vary

# un_idx is the inverse permutation of idx, used later to restore the original order
_, un_idx = torch.sort(idx, dim=0)
print(un_idx)       # tensor([0, 1, 3, 2]) when idx is tensor([0, 1, 3, 2])
a = a[idx]
print(a)

# Define the layers
emb = nn.Embedding(4, 2, padding_idx=0)
lstm = nn.LSTM(input_size=2, hidden_size=6, batch_first=True)

a_input = emb(a)
print(a_input)
print(a_input.shape)    # torch.Size([4, 3, 2])
a_packed_input = pack_padded_sequence(input=a_input, lengths=a_lengths, batch_first=True)
print(a_packed_input)
packed_out, _ = lstm(a_packed_input)
print(packed_out)
out, _ = pad_packed_sequence(packed_out, batch_first=True, total_length=3)
print(out)
# Use un_idx to put the outputs back in the original input order
out = torch.index_select(out, 0, un_idx)
print(out)
print(out.shape)        # torch.Size([4, 3, 6])
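Why does sorting idx give un_idx? idx lists, for each sorted position, which original row landed there; sorting idx back into ascending order yields, for each original row, its position in the sorted batch, which is exactly the inverse permutation. A small pure-Python check, using a hypothetical argsort helper (Python's sort is stable, so equal lengths keep their order here, whereas torch.sort makes no such promise by default):

```python
def argsort(values, reverse=False):
    """Hypothetical helper: indices that would sort `values` (stable)."""
    return sorted(range(len(values)), key=lambda i: values[i], reverse=reverse)


rows = ["seq0", "seq1", "seq2"]
lengths = [2, 1, 3]

idx = argsort(lengths, reverse=True)          # like lengths.sort(0, descending=True)[1]
un_idx = argsort(idx)                         # like torch.sort(idx, dim=0)[1]
sorted_rows = [rows[i] for i in idx]          # like a = a[idx]
restored = [sorted_rows[j] for j in un_idx]   # like index_select(out, 0, un_idx)

print(idx, un_idx)        # [2, 0, 1] [1, 2, 0]
print(restored == rows)   # True
```

These are the same sorted_indices and unsorted_indices values shown in the pad_packed_sequence() docstring example; with enforce_sorted=False, PyTorch keeps this pair inside the PackedSequence itself.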

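Alternatively, pass enforce_sorted=False to pack_padded_sequence(). This is simpler: the function sorts the batch internally, and pad_packed_sequence() returns the outputs in the original order, so no manual sorting or index restoring is needed.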

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence


# A padded batch of 4 token-id sequences; 0 is the padding index
a = torch.tensor([[1, 2, 3],
                  [1, 2, 0],
                  [3, 0, 0],
                  [2, 1, 0]])
lengths = torch.tensor([3, 2, 1, 2])

# Define the layers
emb = nn.Embedding(4, 2, padding_idx=0)
lstm = nn.LSTM(input_size=2, hidden_size=6, batch_first=True)

a_input = emb(a)
print(a_input)
print(a_input.shape)
# enforce_sorted=False: no manual sorting, the batch is sorted internally
a_packed_input = pack_padded_sequence(input=a_input, lengths=lengths, batch_first=True, enforce_sorted=False)
print('--------------->a_packed_input:', a_packed_input)
packed_out, _ = lstm(a_packed_input)
print('--------------->packed_out:', packed_out)
# The output comes back in the original batch order, padded to total_length=3
out, seq_len = pad_packed_sequence(packed_out, batch_first=True, total_length=3)
print('--------------->out:', out)
print('--------------->seq_len:', seq_len)

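Output (the numbers depend on the random initialization of the embedding and LSTM weights, so they will differ between runs):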

tensor([[[-0.1088,  0.1528],
         [-0.7593, -0.1880],
         [-0.0589, -1.0816]],

        [[-0.1088,  0.1528],
         [-0.7593, -0.1880],
         [ 0.0000,  0.0000]],

        [[-0.0589, -1.0816],
         [ 0.0000,  0.0000],
         [ 0.0000,  0.0000]],

        [[-0.7593, -0.1880],
         [-0.1088,  0.1528],
         [ 0.0000,  0.0000]]], grad_fn=<EmbeddingBackward>)
torch.Size([4, 3, 2])
--------------->a_packed_input: PackedSequence(data=tensor([[-0.1088,  0.1528],
        [-0.1088,  0.1528],
        [-0.7593, -0.1880],
        [-0.0589, -1.0816],
        [-0.7593, -0.1880],
        [-0.7593, -0.1880],
        [-0.1088,  0.1528],
        [-0.0589, -1.0816]], grad_fn=<PackPaddedSequenceBackward>), batch_sizes=tensor([4, 3, 1]), sorted_indices=tensor([0, 1, 3, 2]), unsorted_indices=tensor([0, 1, 3, 2]))
--------------->packed_out: PackedSequence(data=tensor([[-0.1006, -0.0041,  0.0238,  0.0290,  0.0865, -0.0759],
        [-0.1006, -0.0041,  0.0238,  0.0290,  0.0865, -0.0759],
        [-0.1205,  0.0228, -0.0018,  0.0025,  0.1195, -0.0669],
        [-0.1151, -0.0232, -0.0214, -0.0042,  0.1847, -0.0348],
        [-0.1604,  0.0171,  0.0086,  0.0087,  0.1646, -0.0883],
        [-0.1604,  0.0171,  0.0086,  0.0087,  0.1646, -0.0883],
        [-0.1696,  0.0064,  0.0206,  0.0213,  0.1473, -0.0896],
        [-0.2000, -0.0163, -0.0205, -0.0086,  0.2650, -0.0440]],
       grad_fn=<CatBackward>), batch_sizes=tensor([4, 3, 1]), sorted_indices=tensor([0, 1, 3, 2]), unsorted_indices=tensor([0, 1, 3, 2]))
--------------->out: tensor([[[-0.1006, -0.0041,  0.0238,  0.0290,  0.0865, -0.0759],
         [-0.1604,  0.0171,  0.0086,  0.0087,  0.1646, -0.0883],
         [-0.2000, -0.0163, -0.0205, -0.0086,  0.2650, -0.0440]],

        [[-0.1006, -0.0041,  0.0238,  0.0290,  0.0865, -0.0759],
         [-0.1604,  0.0171,  0.0086,  0.0087,  0.1646, -0.0883],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000]],

        [[-0.1151, -0.0232, -0.0214, -0.0042,  0.1847, -0.0348],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000]],

        [[-0.1205,  0.0228, -0.0018,  0.0025,  0.1195, -0.0669],
         [-0.1696,  0.0064,  0.0206,  0.0213,  0.1473, -0.0896],
         [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000]]],
       grad_fn=<IndexSelectBackward>)
--------------->seq_len: tensor([3, 2, 1, 2])
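For sequence classification, the quantity usually needed is each sequence's last real output. With enforce_sorted=False, the final hidden state h_n returned by the LSTM is already in the original batch order, and for this single-layer, unidirectional LSTM h_n[-1] equals the padded output at position length - 1 of each sequence. A short sketch that rebuilds the setup above (with an arbitrary seed) and checks this; note that out[:, -1] would be wrong for the shorter sequences, because those positions hold zero padding:

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

torch.manual_seed(0)
a = torch.tensor([[1, 2, 3], [1, 2, 0], [3, 0, 0], [2, 1, 0]])
lengths = torch.tensor([3, 2, 1, 2])

emb = nn.Embedding(4, 2, padding_idx=0)
lstm = nn.LSTM(input_size=2, hidden_size=6, batch_first=True)

packed = pack_padded_sequence(emb(a), lengths, batch_first=True, enforce_sorted=False)
packed_out, (h_n, c_n) = lstm(packed)
out, seq_len = pad_packed_sequence(packed_out, batch_first=True, total_length=3)

# h_n has shape (num_layers * num_directions, batch, hidden_size) = (1, 4, 6)
last_hidden = h_n[-1]

# Gather the output at each sequence's last real time step (length - 1)
last_step = out[torch.arange(out.size(0)), seq_len - 1]

print(h_n.shape)                                # torch.Size([1, 4, 6])
print(torch.allclose(last_hidden, last_step))   # True
```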