pytorch中的pack和pad操作

背景

  进行训练和测试时,有时难以保证输入文本长度的一致性,因此常常需要截断操作(即将超过预设长度的文本截断)和pad操作(即对不足预设长度的文本进行补0)。
  在Pytorch中的torch.nn.utils.rnn,提供了pad和pack,pack_padded_sequence和pad_packed_sequence四种方法实现这一操作。

pad和pack

  举一个简单的例子:

from torch.nn.utils.rnn import pack_sequence, pad_sequence,pad_packed_sequence, pack_padded_sequence, 

text1 = torch.tensor([1,2,3,4])    # 可视为有4个文字的样本
text2 = torch.tensor([5,6,7])  # 可视为有3个文字的样本
text3 = torch.tensor([8,9])    # 可视为有2个文字的样本
sequences = [text1, text2, text3]  # 三个文本序列

  pack操作将原来的二维数据(batch*sequence)进行了压缩,但其排列是按照列(即sequence的顺序)进行排列,每个时间步一次性输出batch上的所有样本,即:
在这里插入图片描述
  pack后的返回值包括两数据。一类为data,即压缩后的数据;而另一类batch_sizes表示每个时间步,batch中包含的样本量。

[Input]  pack_sequence(sequences)
[Output] PackedSequence(data=tensor([1, 5, 8, 2, 6, 9, 3, 7, 4]), batch_sizes=tensor([3, 3, 2, 1]))

  pad操作即是将不同长度的文本序列进行补0。需要注意的是,这个没有dim参数,替代的是batch_first参数,即第一个维度是否是batch,在默认情况下,参数batch_first=False,这是rnn网络的推荐用法。其返回值的第一个维度将变成sequence,而第二个维度才为batch。如果是仅仅使用这种方法对数据进行补齐或截断,可以通过设置batch_first=True,使得返回值的第一个维度为batch,从而保持与输入值的一致性。
在这里插入图片描述

[Input]  pad_sequence(sequences)
[Output] tensor([[1, 5, 8],
        		[2, 6, 9],
        		[3, 7, 0],
        		[4, 0, 0]])

[Input]  pad_sequence(sequences, batch_first=True)
[Output] tensor([[1, 2, 3, 4],
        		[5, 6, 7, 0],
        		[8, 9, 0, 0]])

pack_padded_sequence和pad_packed_sequence

  因为pytorch中的RNN网络可以接受的是PackedSequence类型数据(通过pack操作实现),而pad操作又可以实现不等长文本的填充对齐,所以自然会想到将两个操作联合起来,这就是pytorch提供的pack_padded_sequence和pad_packed_sequence功能。
  pack_padded_sequence就是将经pad后的文本序列在做pack,从而实现对文本缺失位置的填0和维度压缩:
在这里插入图片描述

[Input]  pack_padded_sequence(pad_sequence(sequences,batch_first=True),lengths=[4,3,3], batch_first=True)
[Output] PackedSequence(data=tensor([1, 5, 8, 2, 6, 9, 3, 7, 0, 4, 0]), batch_sizes=tensor([3, 3, 3, 2]))

  pack_padded_sequence函数接收一个padded_sequence数据;根据batch_first参数明确该数据的布局(默认为batch_first=False);根据lengths参数明确batch内各样本的时间步长,选择数据;将上述数据按照时间维度进行压缩,得到目标的PackedSequence类型数据。
  pad_packed_sequence函数即为pack_padded_sequence的逆操作,其在参数设定时也通过batch_first控制返回值的维度顺序,同时可通过设置total_lengths来控制pad后的总步长(该值必须不小于输入PackedSequence数据的步长数):

[Input] pad_packed_sequence(pack_sequence(sequences),total_length=5,batch_first=True)
[Output] (tensor([[1, 2, 3, 4, 0],
         [5, 6, 7, 0, 0],
         [8, 9, 0, 0, 0]]), tensor([4, 3, 2]))

  • 29
    点赞
  • 14
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论
CTC (Connectionist Temporal Classification) 是一种用于无需对齐标签序列的序列学习方法,常被用于语音识别、光学字符识别等任务PyTorch 是一个流行的深度学习框架,提供了丰富的工具和接口来实现各种深度学习任务,包括使用 CTC 的序列学习。 要在 PyTorch 使用 CTC,可以使用 `torch.nn.CTCLoss` 模块计算 CTC 损失,该模块需要输入预测序列、标签序列和预测序列长度等参数。可以使用 `torch.nn.utils.rnn.pack_padded_sequence` 和 `torch.nn.utils.rnn.pad_packed_sequence` 模块来处理变长序列输入。 以下是一个使用 PyTorch 和 CTC 实现的简单语音识别示例: ``` import torch import torch.nn as nn # 定义模型 class SpeechRecognitionModel(nn.Module): def __init__(self, input_size, hidden_size, num_layers, num_classes): super(SpeechRecognitionModel, self).__init__() self.rnn = nn.LSTM(input_size, hidden_size, num_layers, bidirectional=True) self.fc = nn.Linear(hidden_size * 2, num_classes) def forward(self, x): x, lengths = nn.utils.rnn.pad_packed_sequence(x, batch_first=True) x = nn.utils.rnn.pack_padded_sequence(x, lengths, batch_first=True) x, _ = self.rnn(x) x, _ = nn.utils.rnn.pad_packed_sequence(x, batch_first=True) x = self.fc(x) return x # 计算 CTC 损失 loss_fn = nn.CTCLoss() # 定义数据和标签 data = torch.randn(10,20, 40) # (batch_size, seq_len, input_size) label = [torch.randint(1, 30, (5,), dtype=torch.long) for i in range(10)] # 变长标签序列 # 计算预测序列长度 input_lengths = torch.full((10,), 20, dtype=torch.long) # 计算标签序列长度 label_lengths = torch.tensor([len(l) for l in label], dtype=torch.long) # 初始化模型和优化器 model = SpeechRecognitionModel(40, 256, 3, 30) optimizer = torch.optim.Adam(model.parameters()) # 训练模型 for i in range(100): optimizer.zero_grad() output = model(data) output_lengths = torch.full((10,), output.shape[1], dtype=torch.long) loss = loss_fn(output, label, output_lengths, label_lengths) loss.backward() optimizer.step() print("Iteration {}: Loss = {}".format(i+1, loss.item())) ```

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

一本糊涂张~

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值