概述
据博主了解,pytorch中目前没有能够直接去掉一个不为1的维度的函数。因此可以考虑用多个函数组合来进行这一操作,比如先把需要去掉的维度使用narrow()
减小到1,然后再用squeeze()
去掉这一维度。下文是具体的实验中遇到的实例和问题,以及narrow()
和squeeze()
配合用法的详解。
检查模型输出数据的形状
最近使用CRNN模型代码,跑自己的对比实验,在这篇博客:文本识别CRNN模型介绍以及pytorch代码实现中找到了现成的源码,但是模型放上去之后报错了:
ValueError: Expected input batch_size (55) to match target batch_size (8).
原因很明显是实际输入的数据和期待输入的数据形状不匹配。这时我注意到,原博客热心的博主还在代码里注释了模型输出output的形状:
return output # shape: (seq_len, batch, num_class)
output的形状是(seq_len, batch, num_class)
,说明这个output具有三个维度,分别是seq_len
, batch
, 和num_class
。
而检查一下我自己的代码,在经过模型之后得到推理值logits,和标签一起计算loss时需要的输入数据的形状是(batch, num_class)
,并不需要第一个维度seq_len。
Debug发现,我的batch_size
设置为8,而报错提到55正是seq_len
的值,因此让代码跑通的最简单的方式,就是想办法把output的第一个维度seq_len
直接去掉。但是如何去掉呢?
squeeze()和narrow()的用法
squeeze()的用法
想要减小维度,首先想到squeeze()
方法,这篇博客:pytorch中的squeeze和unsqueeze的用法小结给了squeeze()
和unsqueeze()
的详细用法介绍,但问题是,squeeze()
只能去掉维数为1的维度,即指定压缩第n维,如果它的维数为1,则压缩,反之不对该维度操作。而我要压缩的那个维度的维数是55,不为1,因此直接使用根本不起作用。
这可怎么办?我去全网搜索“pytorch如何去掉不为1的维度”,没有找到任何有效的解决方案,能找到的除了squeeze()
,还有reshape()
、resize_()
,前者不能改变总的维数,即要保证所有维度的所有维数相乘总数不变的情况下才能改变形状;而后者虽然可以改变维数,但不支持下一步的梯度计算,使用会报错:
RuntimeError: cannot resize variables that require grad
因此只能换其他方案,比如是否能先把seq_len
维度减小到1,然后再用squeeze()
去掉那个维度呢?果真有减小维数的方法!那就是narrow()
方法。
narrow()的用法
参考官网:torch.narrow()
用法:torch.narrow(input, dim, start, length) → Tensor
或input.narrow(dim, start, length) → Tensor
类似切片操作,返回输入张量的切片操作结果。 输入tensor和返回的tensor共享内存。
参数说明:
input (Tensor)
– 需切片的张量Tensordim (int)
– 需要被切片的维度start (int)
– 开始的索引,从那个元素开始切片length (int)
– 切片的长度
实例:
>>> x = torch.tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
>>> x.narrow(0, 0, 2)
tensor([[ 1, 2, 3],
[ 4, 5, 6]])
>>> x.narrow(1, 1, 2)
tensor([[ 2, 3],
[ 5, 6],
[ 8, 9]])
因此,要把 (seq_len, batch, num_class) = (55, 8, 5)
变为 (batch, num_class) = (8,5)
,首先可以对seq_len
维度切片,将其维数变为1,即(55, 8, 5) → (1, 8, 5)
,然后再使用squeeze
去掉第0维度,即(1, 8, 5) → (8, 5)
。代码如下:
output = output.narrow(0, 1, 1)
# 需要被切片的维度是0,从第1个元素开始切片,切片长度是1。
# 其实,我们在第0维度切片,只要保留任意一个元素即可,
# 因此第二个参数可以是0,1,2,...,54中的任何一个数,因为第0维长度是55。
output = torch.squeeze(output, dim=0)
我们找到CRNN模型的forward(),在return之前加上上面两行代码:
# CNN+LSTM前向计算
def forward(self, images):
# shape of images: (batch, channel, height, width)
conv = self.cnn(images)
batch, channel, height, width = conv.size()
conv = conv.view(batch, channel * height, width)
conv = conv.permute(2, 0, 1) # (width, batch, feature)
# 卷积接全连接。全连接输入形状为(width, batch, channel*height),
# 输出形状为(width, batch, hidden_layer),分别对应时序长度,batch,特征数,符合LSTM输入要求
seq = self.map_to_seq(conv)
recurrent, _ = self.rnn1(seq)
recurrent, _ = self.rnn2(recurrent)
output = self.dense(recurrent)
output = output.narrow(0, 1, 1)
output = torch.squeeze(output, dim=0)
return output # shape: (seq_len, batch, num_class)
然后再次试运行,跑通了!
总结
据博主所致,目前Python中没有直接封装好的去掉某个非1维度的函数,因此可以将多个函数组合使用。本文提倡先把需要去掉的维度使用narrow()
减小到1,然后再用squeeze()
去掉这一维度。