pytorch如何去掉某个不为1的维度

概述

据博主了解,pytorch中目前没有能够直接去掉一个不为1的维度的函数。因此可以考虑用多个函数组合来进行这一操作,比如先把需要去掉的维度使用narrow()减小到1,然后再用squeeze()去掉这一维度。下文是具体的实验中遇到的实例和问题,以及narrow()squeeze()配合用法的详解。

检查模型输出数据的形状

最近使用CRNN模型代码,跑自己的对比实验,在这篇博客:文本识别CRNN模型介绍以及pytorch代码实现中找到了现成的源码,但是模型放上去之后报错了:

在这里插入图片描述
ValueError: Expected input batch_size (55) to match target batch_size (8).

原因很明显是实际输入的数据和期待输入的数据形状不匹配。这时我注意到,原博客热心的博主还在代码里注释了模型输出output的形状:

return output # shape: (seq_len, batch, num_class)

output的形状是(seq_len, batch, num_class),说明这个output具有三个维度,分别是seq_len, batch, 和num_class
而检查一下我自己的代码,在经过模型之后得到推理值logits,和标签一起计算loss时需要的输入数据的形状是(batch, num_class),并不需要第一个维度seq_len。
Debug发现,我的batch_size设置为8,而报错提到55正是seq_len的值,因此让代码跑通的最简单的方式,就是想办法把output的第一个维度seq_len直接去掉。但是如何去掉呢?

squeeze()和narrow()的用法

squeeze()的用法

想要减小维度,首先想到squeeze()方法,这篇博客:pytorch中的squeeze和unsqueeze的用法小结给了squeeze()unsqueeze()的详细用法介绍,但问题是,squeeze()只能去掉维数为1的维度,即指定压缩第n维,如果它的维数为1,则压缩,反之不对该维度操作。而我要压缩的那个维度的维数是55,不为1,因此直接使用根本不起作用。
这可怎么办?我去全网搜索“pytorch如何去掉不为1的维度”,没有找到任何有效的解决方案,能找到的除了squeeze(),还有reshape()resize_(),前者不能改变总的维数,即要保证所有维度的所有维数相乘总数不变的情况下才能改变形状;而后者虽然可以改变维数,但不支持下一步的梯度计算,使用会报错:

这里是引用
RuntimeError: cannot resize variables that require grad

因此只能换其他方案,比如是否能先把seq_len维度减小到1,然后再用squeeze()去掉那个维度呢?果真有减小维数的方法!那就是narrow()方法。

narrow()的用法

参考官网:torch.narrow()

用法:torch.narrow(input, dim, start, length) → Tensor
input.narrow(dim, start, length) → Tensor
类似切片操作,返回输入张量的切片操作结果。 输入tensor和返回的tensor共享内存。

参数说明:

  • input (Tensor) – 需切片的张量Tensor
  • dim (int) – 需要被切片的维度
  • start (int) – 开始的索引,从那个元素开始切片
  • length (int) – 切片的长度

实例:

        >>> x = torch.tensor([[1, 2, 3], 
        					   [4, 5, 6],
        					   [7, 8, 9]])
        >>> x.narrow(0, 0, 2)
        tensor([[ 1,  2,  3],
                [ 4,  5,  6]])
        >>> x.narrow(1, 1, 2)
        tensor([[ 2,  3],
                [ 5,  6],
                [ 8,  9]])

因此,要把 (seq_len, batch, num_class) = (55, 8, 5)变为 (batch, num_class) = (8,5),首先可以对seq_len维度切片,将其维数变为1,即(55, 8, 5) → (1, 8, 5),然后再使用squeeze去掉第0维度,即(1, 8, 5) → (8, 5)。代码如下:

output = output.narrow(0, 1, 1)
# 需要被切片的维度是0,从第1个元素开始切片,切片长度是1。
# 其实,我们在第0维度切片,只要保留任意一个元素即可,
# 因此第二个参数可以是0,1,2,...,54中的任何一个数,因为第0维长度是55。
output = torch.squeeze(output, dim=0)

我们找到CRNN模型的forward(),在return之前加上上面两行代码:

 # CNN+LSTM前向计算
    def forward(self, images):
        # shape of images: (batch, channel, height, width)
        conv = self.cnn(images)
        batch, channel, height, width = conv.size()
        conv = conv.view(batch, channel * height, width)
        conv = conv.permute(2, 0, 1)  # (width, batch, feature)
        # 卷积接全连接。全连接输入形状为(width, batch, channel*height),
        # 输出形状为(width, batch, hidden_layer),分别对应时序长度,batch,特征数,符合LSTM输入要求
        seq = self.map_to_seq(conv)
        recurrent, _ = self.rnn1(seq)
        recurrent, _ = self.rnn2(recurrent)
        output = self.dense(recurrent)
        output = output.narrow(0, 1, 1)
        output = torch.squeeze(output, dim=0)
        return output  # shape: (seq_len, batch, num_class)

然后再次试运行,跑通了!
在这里插入图片描述

总结

据博主所致,目前Python中没有直接封装好的去掉某个非1维度的函数,因此可以将多个函数组合使用。本文提倡先把需要去掉的维度使用narrow()减小到1,然后再用squeeze()去掉这一维度。

  • 10
    点赞
  • 17
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Dijkstra's Monk-ey

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值