RuntimeError: size mismatch, got 3, 3x2,9
使用pytorch时,数据没输入对、或模型没设计对,导致参与运算的张量形状不匹配,无法运算,所以报错。
这里说明下各个数字指的是什么:
got 3, 3x2
其中的3表示:对应层原定义的输出维度是3
其中的2表示:对应层原定义的输入维度是2
class MyLinearModel(nn.Module):#定义模型类
def __init__(self,input_dim,output_dim):
super().__init__()#继承父类所有初始化函数
self.linear=nn.Linear(input_dim,output_dim)#声明用到的层
def forward(self,x):#定义前向传播
y=self.linear(x)
return y
my_model=MyLinearModel(2,3)
y=my_model(x_train)
9
9表示:实际输入数据的维度是9
x_train=np.arange(1,10)
x_train=x_train.astype(np.float32)
x_train=torch.from_numpy(x_train)
print(x_train.shape)#torch.Size([9])