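Sequences fed to an LSTM are often of variable length. This post shows how to handle variable-length sequences with an LSTM in PyTorch. It relies on two functions from PyTorch's torch.nn.utils.rnn module: pack_padded_sequence() and pad_packed_sequence(). Here "pack" means compress and "pad" means fill.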
The pack_padded_sequence() function
def pack_padded_sequence(input, lengths, batch_first=False, enforce_sorted=True):
    # type: (Tensor, Tensor, bool, bool) -> PackedSequence
    r"""Packs a Tensor containing padded sequences of variable length.

    :attr:`input` can be of size ``T x B x *`` where `T` is the length of the
    longest sequence (equal to ``lengths[0]``), ``B`` is the batch size, and
    ``*`` is any number of dimensions (including 0). If ``batch_first`` is
    ``True``, ``B x T x *`` :attr:`input` is expected.

    For unsorted sequences, use `enforce_sorted = False`. If :attr:`enforce_sorted` is
    ``True``, the sequences should be sorted by length in a decreasing order, i.e.
    ``input[:,0]`` should be the longest sequence, and ``input[:,B-1]`` the shortest
    one. `enforce_sorted = True` is only necessary for ONNX export.

    Note:
        This function accepts any input that has at least two dimensions. You
        can apply it to pack the labels, and use the output of the RNN with
        them to compute the loss directly. A Tensor can be retrieved from
        a :class:`PackedSequence` object by accessing its ``.data`` attribute.

    Arguments:
        input (Tensor): padded batch of variable length sequences.
        lengths (Tensor): list of sequences lengths of each batch element.
        batch_first (bool, optional): if ``True``, the input is expected in ``B x T x *``
            format.
        enforce_sorted (bool, optional): if ``True``, the input is expected to
            contain sequences sorted by length in a decreasing order. If
            ``False``, the input will get sorted unconditionally. Default: ``True``.

    Returns:
        a :class:`PackedSequence` object
    """
The pad_packed_sequence() function
def pad_packed_sequence(sequence, batch_first=False, padding_value=0.0, total_length=None):
    # type: (PackedSequence, bool, float, Optional[int]) -> Tuple[Tensor, Tensor]
    r"""Pads a packed batch of variable length sequences.

    It is an inverse operation to :func:`pack_padded_sequence`.

    The returned Tensor's data will be of size ``T x B x *``, where `T` is the length
    of the longest sequence and `B` is the batch size. If ``batch_first`` is True,
    the data will be transposed into ``B x T x *`` format.

    Example:
        >>> from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
        >>> seq = torch.tensor([[1,2,0], [3,0,0], [4,5,6]])
        >>> lens = [2, 1, 3]
        >>> packed = pack_padded_sequence(seq, lens, batch_first=True, enforce_sorted=False)
        >>> packed
        PackedSequence(data=tensor([4, 1, 3, 5, 2, 6]), batch_sizes=tensor([3, 2, 1]),
                       sorted_indices=tensor([2, 0, 1]), unsorted_indices=tensor([1, 2, 0]))
        >>> seq_unpacked, lens_unpacked = pad_packed_sequence(packed, batch_first=True)
        >>> seq_unpacked
        tensor([[1, 2, 0],
                [3, 0, 0],
                [4, 5, 6]])
        >>> lens_unpacked
        tensor([2, 1, 3])

    .. note::
        :attr:`total_length` is useful to implement the
        ``pack sequence -> recurrent network -> unpack sequence`` pattern in a
        :class:`~torch.nn.Module` wrapped in :class:`~torch.nn.DataParallel`.
        See :ref:`this FAQ section <pack-rnn-unpack-with-data-parallelism>` for
        details.

    Arguments:
        sequence (PackedSequence): batch to pad
        batch_first (bool, optional): if ``True``, the output will be in ``B x T x *``
            format.
        padding_value (float, optional): values for padded elements.
        total_length (int, optional): if not ``None``, the output will be padded to
            have length :attr:`total_length`. This method will throw :class:`ValueError`
            if :attr:`total_length` is less than the max sequence length in
            :attr:`sequence`.

    Returns:
        Tuple of Tensor containing the padded sequence, and a Tensor
        containing the list of lengths of each sequence in the batch.
        Batch elements will be re-ordered as they were ordered originally when
        the batch was passed to ``pack_padded_sequence`` or ``pack_sequence``.
    """
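The workflow is: pack the input according to the sequence lengths, run it through the LSTM, then pad the output back to a regular tensor.
The code is as follows: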
import torch as t
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
# Token ids of 4 sequences, padded with 0 to length 3
a = t.tensor([[1, 2, 3],
              [1, 2, 0],
              [3, 0, 0],
              [2, 1, 0]])
lengths = t.tensor([3, 2, 1, 2])
# Sort the batch by length, longest first
a_lengths, idx = lengths.sort(0, descending=True)
print(a_lengths) # tensor([3, 2, 2, 1])
print(idx) # e.g. tensor([0, 1, 3, 2]); the order of the two length-2 rows is not guaranteed
# Sorting idx yields the inverse permutation, used later to restore the original order
_, un_idx = t.sort(idx, dim=0)
print(un_idx) # tensor([0, 1, 3, 2]) for the idx shown above
a = a[idx]
print(a)
# Define the layers
emb = t.nn.Embedding(4, 2, padding_idx=0)
lstm = t.nn.LSTM(input_size=2, hidden_size=6, batch_first=True)
a_input = emb(a)
print(a_input)
print(a_input.shape)
a_packed_input = pack_padded_sequence(input=a_input, lengths=a_lengths, batch_first=True)
print(a_packed_input)
packed_out, _ = lstm(a_packed_input)
print(packed_out)
out, _ = pad_packed_sequence(packed_out, batch_first=True, total_length=3)
print(out)
# Use un_idx to put the output back into the original input order
out = t.index_select(out, 0, un_idx)
print(out)
print(out.shape) # torch.Size([4, 3, 6])
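The note in the pack_padded_sequence() docstring says the function can also pack labels, so that the loss is computed only on real (non-padded) time steps. The following is a minimal sketch of that idea, not the author's code: the per-step classification task, the tensors features and labels, the classifier layer, and the class count are all assumptions invented for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence

torch.manual_seed(0)

# Hypothetical toy batch: 3 sequences padded to length 3, 4 features per step.
features = torch.randn(3, 3, 4)
# Hypothetical per-step labels; -1 marks padded positions.
labels = torch.tensor([[1, 0, 2],
                       [2, 1, -1],
                       [0, -1, -1]])
lengths = torch.tensor([3, 2, 1])  # already sorted, longest first

lstm = nn.LSTM(input_size=4, hidden_size=5, batch_first=True)
classifier = nn.Linear(5, 3)  # 3 hypothetical classes

packed_in = pack_padded_sequence(features, lengths, batch_first=True)
packed_out, _ = lstm(packed_in)

# Pack the labels the same way so they line up row by row with packed_out.data.
packed_labels = pack_padded_sequence(labels, lengths, batch_first=True)

logits = classifier(packed_out.data)              # shape: (sum(lengths), 3)
loss = F.cross_entropy(logits, packed_labels.data)
print(packed_labels.data)  # the -1 padding labels have been dropped
print(loss)
```

Because both tensors are packed with the same lengths, no masking is needed: the padded positions never reach the loss.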
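Instead of sorting manually, you can set the enforce_sorted parameter of pack_padded_sequence() to False. This is simpler, because the function then sorts the batch internally and pad_packed_sequence() restores the original order.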
import torch as t
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
# Token ids of 4 sequences, padded with 0 to length 3
a = t.tensor([[1, 2, 3],
              [1, 2, 0],
              [3, 0, 0],
              [2, 1, 0]])
lengths = t.tensor([3, 2, 1, 2])
# Define the layers
emb = t.nn.Embedding(4, 2, padding_idx=0)
lstm = t.nn.LSTM(input_size=2, hidden_size=6, batch_first=True)
a_input = emb(a)
print(a_input)
print(a_input.shape)
a_packed_input = pack_padded_sequence(input=a_input, lengths=lengths, batch_first=True, enforce_sorted=False)
print('--------------->a_packed_input:', a_packed_input)
packed_out, _ = lstm(a_packed_input)
print('--------------->packed_out:', packed_out)
out, seq_len = pad_packed_sequence(packed_out, batch_first=True, total_length=3)
print('--------------->out:', out)
print('--------------->seq_len:', seq_len)
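Both scripts pass total_length=3 to pad_packed_sequence(). The sketch below tries to make the effect of that argument visible, using a small integer batch in the style of the docstring example; the extra padding column and the variable names are assumptions chosen for illustration.

```python
import torch
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

# Batch padded to width 4, although the longest sequence has only 3 items.
seq = torch.tensor([[1, 2, 0, 0],
                    [3, 0, 0, 0],
                    [4, 5, 6, 0]])
lens = torch.tensor([2, 1, 3])
packed = pack_padded_sequence(seq, lens, batch_first=True, enforce_sorted=False)

# Without total_length, the time dimension shrinks to the longest sequence (3).
short, _ = pad_packed_sequence(packed, batch_first=True)
print(short.shape)   # torch.Size([3, 3])

# With total_length, the output is padded back to a fixed width (4 here).
fixed, out_lens = pad_packed_sequence(packed, batch_first=True, total_length=4)
print(fixed.shape)   # torch.Size([3, 4])

# A total_length shorter than the longest sequence raises ValueError.
try:
    pad_packed_sequence(packed, batch_first=True, total_length=2)
    raised = False
except ValueError:
    raised = True
print(raised)        # True
```

Fixing the width this way keeps the output shape independent of which sequences happen to be in the batch, which is the situation the docstring note about DataParallel describes.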
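Output of the enforce_sorted=False script (the exact values vary between runs because the layer weights are randomly initialized):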
tensor([[[-0.1088, 0.1528],
[-0.7593, -0.1880],
[-0.0589, -1.0816]],
[[-0.1088, 0.1528],
[-0.7593, -0.1880],
[ 0.0000, 0.0000]],
[[-0.0589, -1.0816],
[ 0.0000, 0.0000],
[ 0.0000, 0.0000]],
[[-0.7593, -0.1880],
[-0.1088, 0.1528],
[ 0.0000, 0.0000]]], grad_fn=<EmbeddingBackward>)
torch.Size([4, 3, 2])
--------------->a_packed_input: PackedSequence(data=tensor([[-0.1088, 0.1528],
[-0.1088, 0.1528],
[-0.7593, -0.1880],
[-0.0589, -1.0816],
[-0.7593, -0.1880],
[-0.7593, -0.1880],
[-0.1088, 0.1528],
[-0.0589, -1.0816]], grad_fn=<PackPaddedSequenceBackward>), batch_sizes=tensor([4, 3, 1]), sorted_indices=tensor([0, 1, 3, 2]), unsorted_indices=tensor([0, 1, 3, 2]))
--------------->packed_out: PackedSequence(data=tensor([[-0.1006, -0.0041, 0.0238, 0.0290, 0.0865, -0.0759],
[-0.1006, -0.0041, 0.0238, 0.0290, 0.0865, -0.0759],
[-0.1205, 0.0228, -0.0018, 0.0025, 0.1195, -0.0669],
[-0.1151, -0.0232, -0.0214, -0.0042, 0.1847, -0.0348],
[-0.1604, 0.0171, 0.0086, 0.0087, 0.1646, -0.0883],
[-0.1604, 0.0171, 0.0086, 0.0087, 0.1646, -0.0883],
[-0.1696, 0.0064, 0.0206, 0.0213, 0.1473, -0.0896],
[-0.2000, -0.0163, -0.0205, -0.0086, 0.2650, -0.0440]],
grad_fn=<CatBackward>), batch_sizes=tensor([4, 3, 1]), sorted_indices=tensor([0, 1, 3, 2]), unsorted_indices=tensor([0, 1, 3, 2]))
--------------->out: tensor([[[-0.1006, -0.0041, 0.0238, 0.0290, 0.0865, -0.0759],
[-0.1604, 0.0171, 0.0086, 0.0087, 0.1646, -0.0883],
[-0.2000, -0.0163, -0.0205, -0.0086, 0.2650, -0.0440]],
[[-0.1006, -0.0041, 0.0238, 0.0290, 0.0865, -0.0759],
[-0.1604, 0.0171, 0.0086, 0.0087, 0.1646, -0.0883],
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],
[[-0.1151, -0.0232, -0.0214, -0.0042, 0.1847, -0.0348],
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],
[[-0.1205, 0.0228, -0.0018, 0.0025, 0.1195, -0.0669],
[-0.1696, 0.0064, 0.0206, 0.0213, 0.1473, -0.0896],
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]],
grad_fn=<IndexSelectBackward>)
--------------->seq_len: tensor([3, 2, 1, 2])
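In the padded output, positions past each sequence's length are zeros, so out[:, -1] is not the last real LSTM output for the shorter sequences. One possible way to pick the last valid step is to index with the returned seq_len. This is a sketch under assumptions: the names last_idx and last_out are illustrative, the seed is fixed only for repeatability, and the comparison with h_n applies to a single-layer, unidirectional LSTM as used in this post.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

torch.manual_seed(0)

a = torch.tensor([[1, 2, 3],
                  [1, 2, 0],
                  [3, 0, 0],
                  [2, 1, 0]])
lengths = torch.tensor([3, 2, 1, 2])

emb = nn.Embedding(4, 2, padding_idx=0)
lstm = nn.LSTM(input_size=2, hidden_size=6, batch_first=True)

packed = pack_padded_sequence(emb(a), lengths, batch_first=True, enforce_sorted=False)
packed_out, (h_n, c_n) = lstm(packed)
out, seq_len = pad_packed_sequence(packed_out, batch_first=True, total_length=3)

# Index of the last real time step of every sequence, shaped for gather().
last_idx = (seq_len - 1).view(-1, 1, 1).expand(-1, 1, out.size(2))
last_out = out.gather(1, last_idx).squeeze(1)   # shape: (batch, hidden_size)

print(last_out.shape)                       # torch.Size([4, 6])
# For a packed input, h_n already holds the state at each sequence's last valid step.
print(torch.allclose(last_out, h_n[-1]))    # True
```

For this simple setup h_n[-1] can be used directly; the gather() variant is mainly useful when only the padded out tensor is available downstream.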