RuntimeError: size mismatch, got 3, 3x2,9

在使用PyTorch时遇到了RuntimeError,原因是输入数据与模型设计的维度不匹配。模型期望的是3x2的输入,但实际输入数据的维度是9。示例中定义了一个MyLinearModel,其输入维度应为2,输出维度为3,而尝试用大小为9的一维张量x_train进行运算,导致错误。
摘要由CSDN通过智能技术生成

RuntimeError: size mismatch, got 3, 3x2,9

使用pytorch时,数据没输入对、或模型没设计对,导致参与运算的张量形状不匹配,无法运算,所以报错。

这里说明下各个数字指的是什么:
got 3, 3x2
其中的3表示:对应层原定义的输出维度是3
其中的2表示:对应层原定义的输入维度是2

class MyLinearModel(nn.Module):#定义模型类
    def __init__(self,input_dim,output_dim):
        super().__init__()#继承父类所有初始化函数
        self.linear=nn.Linear(input_dim,output_dim)#声明用到的层
    
    def forward(self,x):#定义前向传播
        y=self.linear(x)
        return y

my_model=MyLinearModel(2,3)
y=my_model(x_train)

9
9表示:实际输入数据的维度是9

x_train=np.arange(1,10)
x_train=x_train.astype(np.float32)
x_train=torch.from_numpy(x_train)
print(x_train.shape)#torch.Size([9])
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值