关于Pytorch 0.3 nn.Module的子类,前向传播过程的问题

22 篇文章 0 订阅
18 篇文章 1 订阅

先上代码:

class ft_net(nn.Module):

    def __init__(self, class_num ):
        super(ft_net, self).__init__()
        model_ft = models.resnet50(pretrained=True)
        # avg pooling to global pooling
        model_ft.avgpool = nn.AdaptiveAvgPool2d((1,1))
        self.model = model_ft
        self.classifier = ClassBlock(4096, 512)

    def forward(self, x,y):
        x = self.model.conv1(x)
        x = self.model.bn1(x)
        x = self.model.relu(x)
        x = self.model.maxpool(x)
        x = self.model.layer1(x)
        x = self.model.layer2(x)
        x = self.model.layer3(x)
        x = self.model.layer4(x)
        x = self.model.avgpool(x)
        x = torch.squeeze(x)
        print(numpy.shape(x))

当dataloader的batch_size大于1时,执行了打印语句,但是当batch_size等于1时,却没有执行打印语句,再对张量进行观察,当batch_size大于1时,x是[n,1024],但是当batch_size等于1时,却是[1024],
注意!!这里不是[1,1024],这是张量维数的变化,当你代码中使用.cat() .view()等函数时,这将会报错。
之所以出现这个原因,是加入了torch.squeeze(x),squeeze将输入张量形状中的1 去除并返回,所以 一旦batch_size为1,就把4D张量第一维给抹去了

  • 5
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值