解决Indexerror: dimension out of range (expected to be in range of [-1, 0], but got 1)

当batch_size设置为1时,代码在softmax层遇到维度错误。原因是输出矩阵的维度变为[1,maxlen,1],而softmax要求二维输入。解决方案是在调用squeeze后使用unsqueeze增加维度,确保softmax函数正确应用。通过在att_score上使用torch.unsqueeze(dim=0),然后进行softmax操作,可以避免此错误。
摘要由CSDN通过智能技术生成

问题描述

在复现代码时,把batch_size调整为1,结果softmax报以下错误:

# Super Sketch Network links a RNN and CNN together with an attention layer in the last layer.
class SSN(nn.Module):
    
    def __init__(self, cnn_model_name,rnn_model_name, d_frozen = True,num_classes=40):
        pass
                
        
    def forward(self, images,strokes):
        cnn_output,cnn_f = self.cnn(images)
        rnn_output,rnn_f = self.rnn(strokes,None)
        
        #Attention Layer linking RNN and CNN together.
        output = torch.stack([cnn_output,rnn_output],dim = 1)
        
        #Get the center feature
        ssn_feat = torch.cat((cnn_f,rnn_f),dim = 1)
        att_score = torch.matmul(output, self.attention).squeeze()
        att_score = F.softmax(att_score,dim = 1).view(output.size(0), output.size(1), 1)
        score = output * att_score

        score = torch.sum(score, dim=1)
        
        return score,ssn_feat

Indexerror: dimension out of range (expected to be in range of [-1, 0], but got 1)

解决方案

在使用softmax函数时需要保证矩阵是二维的,但是当batch_size=1时,整个output矩阵的维度为 

[batch_size =1,maxlen,1],

如果不指定output.squeeze的维度就会得到[maxlen]的维度导致报错。

解决方案是手动给张量升维。在PyTorch中,可以使用torch.unsqueeze函数增加张量的维度。该函数会在指定位置(默认为最后一个维度)增加一个维度,使得张量的维度增加1。例如,对于一个形状为(3, 4)的张量,使用torch.unsqueeze(input, dim=0)可以得到一个形状为(1, 3, 4)的张量。具体用法如下:

import torch

# 假设输入的张量为tensor,shape为(n,)
tensor = torch.tensor([1, 2, 3])

# 增加一个维度,变成(1, n)
tensor = torch.unsqueeze(tensor, dim=0)

应用在本段代码就是:

# Super Sketch Network links a RNN and CNN together with an attention layer in the last layer.
class SSN(nn.Module):
    
    def __init__(self, cnn_model_name,rnn_model_name, d_frozen = True,num_classes=40):
        pass
                
        
    def forward(self, images,strokes):
        cnn_output,cnn_f = self.cnn(images)
        rnn_output,rnn_f = self.rnn(strokes,None)
        
        #Attention Layer linking RNN and CNN together.
        output = torch.stack([cnn_output,rnn_output],dim = 1)
        
        #Get the center feature
        ssn_feat = torch.cat((cnn_f,rnn_f),dim = 1)
        att_score = torch.matmul(output, self.attention).squeeze()
        att_score = torch.unsqueeze(att_score,dim=0)
        att_score = F.softmax(att_score,dim = 1).view(output.size(0), output.size(1), 1)
        score = output * att_score

        score = torch.sum(score, dim=1)
        
        return score,ssn_feat

  • 2
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

北京地铁1号线

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值