pytorch rnn 变长输入序列问题

pytorch实现变长输入的rnn分类

输入数据是长度不固定的序列数据,主要讲解两个部分

  1. Data.DataLoader的collate_fn用法,以及按batch进行padding数据
  2. pack_padded_sequence和pad_packed_sequence来处理变长序列

collate_fn

Dataloader的collate_fn参数,定义数据处理和合并成batch的方式。
由于pack_padded_sequence用到的tensor必须按照长度从大到小排过序的,所以在Collate_fn中,需要完成两件事,一是把当前batch的样本按照当前batch最大长度进行padding,二是将padding后的数据从大到小进行排序。

def pad_tensor(vec, pad):
    """
    args:
        vec - tensor to pad
        pad - the size to pad to

    return:
        a new tensor padded to 'pad'
    """
    return torch.cat([vec, torch.zeros(pad - len(vec), dtype=torch.float)], dim=0).data.numpy()

class Collate:
    """
    a variant of callate_fn that pads according to the longest sequence in
    a batch of sequences
    """

    def __init__(self):
        pass

    def _collate(self, batch):
        """
        args:
            batch - list of (tensor, label)

        reutrn:
            xs - a tensor of all examples in 'batch' before padding like:
                '''
                [tensor([1,2,3,4]),
                 tensor([1,2]),
                 tensor([1,2,3,4,5])]
                '''
            ys - a LongTensor of all labels in batch like:
                '''
                [1,0,1]
                '''
        """
        xs = [torch.FloatTensor(v[0]) for v in batch]
        ys = torch.LongTensor([v[1] for v in batch])
        # 获得每个样本的序列长度
        seq_lengths = torch.LongTensor([v for v in map(len, xs)])
        max_len = max([len(v) for v in xs])
        # 每个样本都padding到当前batch的最大长度
        xs = torch.FloatTensor([pad_tensor(v, max_len) for v in xs])
        # 把xs和ys按照序列长度从大到小排序
        seq_lengths, perm_idx = seq_lengths.sort(0, descending=True)
        xs = xs[perm_idx]
        ys = ys[perm_idx]
        return xs, seq_lengths, ys

    def __call__(self, batch):
        return self._collate(batch)

定义完collate类以后,在DataLoader中直接使用

train_data = Data.DataLoader(dataset=train_dataset, batch_size=32, num_workers=0, collate_fn=Collate())

torch.nn.utils.rnn.pack_padded_sequence()

pack_padded_sequence将一个填充过的变长序列压紧。输入参数包括

  • input(Variable)- 被填充过后的变长序列组成的batch data
  • lengths (list[int]) - 变长序列的原始序列长度
  • batch_first (bool,optional) - 如果是True,input的形状应该是(batch_size,seq_len,input_size)
    返回值:一个PackedSequence对象,可以直接作为rnn,lstm,gru的传入数据。
    用法:
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
# x是填充过后的batch数据,seq_lengths是每个样本的序列长度
packed_input = pack_padded_sequence(x, seq_lengths, batch_first=True)

RNN模型

定义了一个单向的LSTM模型,因为处理的是变长序列,forward函数传入的值是一个PackedSequence对象,返回值也是一个PackedSequence对象

class Model(nn.Module):
    def __init__(self, in_size, hid_size, n_layer, drop=0.1, bi=False):
        super(Model, self).__init__()
        self.lstm = nn.LSTM(input_size=in_size,
                            hidden_size=hid_size,
                            num_layers=n_layer,
                            batch_first=True,
                            dropout=drop,
                            bidirectional=bi)
        # 分类类别数目为2
        self.fc = nn.Linear(in_features=hid_size, out_features=2)

    def forward(self, x):
        '''
        :param x: 变长序列时,x是一个PackedSequence对象
        :return: PackedSequence对象
        '''
        # lstm_out: tensor of shape (batch, seq_len, num_directions * hidden_size)
        lstm_out, _ = self.lstm(x)  
        
        return lstm_out

model = Model()
lstm_out = model(packed_input)

torch.nn.utils.rnn.pad_packed_sequence()

这个操作和pack_padded_sequence()是相反的,把压紧的序列再填充回来。因为前面提到的LSTM模型传入和返回的都是PackedSequence对象,所以我们如果想要把返回的PackedSequence对象转换回Tensor,就需要用到pad_packed_sequence函数。
参数说明:

  • sequence (PackedSequence) – 将要被填充的 batch
  • batch_first (bool, optional) – 如果为True,返回的数据的形状为(batch_size,seq_len,input_size)

返回值: 一个tuple,包含被填充后的序列,和batch中序列的长度列表。
用法:

# 此处lstm_out是一个PackedSequence对象
output, _ = pad_packed_sequence(lstm_out)

返回的output是一个形状为(batch_size,seq_len,input_size)的tensor。

总结

  1. pytorch在自定义dataset时,可以在DataLoader的collate_fn参数中定义对数据的变换,操作以及合成batch的方式。
  2. 处理变长rnn问题时,通过pack_padded_sequence()将填充的batch数据转换成PackedSequence对象,直接传入rnn模型中。通过pad_packed_sequence()来将rnn模型输出的PackedSequence对象转换回相应的Tensor。

Reference

https://blog.csdn.net/qq_27505047/article/details/78764888
https://zhuanlan.zhihu.com/p/60129684

  • 7
    点赞
  • 21
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值