网络的输入:
batch=64
self.linear_1 = nn.Linear(28*28, 120)
self.linear_2 = nn.Linear(120, 84)
self.linear_3 = nn.Linear(84, 10)
输入形状为(64,784),输出形状为(64,10)
形状说明
y_pred = model(x)
loss = loss_fn(y_pred, y_label)
训练中,一次batch=64的数据中:
- 网络的输出y_pred.shape为(64,10)
- minist的label是一个数值(也相当于0-9索引),并不是独热编码,因此,学习目标y_label.shape为(64)
两者需要用交叉熵计算算loss,CrossEntropyLoss会将索引转化为one-hot编码。
注意loss计算时&#