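MNIST images are 28×28 pixels, so each image flattens to 28 × 28 = 784 input values. Here we simulate a simple single-layer neural network that maps those 784 inputs to 10 class scores, then use it to make predictions and compute the loss. The code is as follows.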
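The 784 comes from flattening each 28×28 image into one row vector. A minimal sketch, assuming a batch of images shaped (N, 1, 28, 28), which is the layout torchvision's MNIST dataset yields after `ToTensor()`. Random data stands in for real images here:

```python
import torch

# Hypothetical batch of 4 single-channel 28x28 images (random stand-ins for MNIST)
images = torch.randn(4, 1, 28, 28)

# Flatten everything except the batch dimension: 1 * 28 * 28 = 784 features per image
xb = images.view(images.size(0), -1)
print(xb.shape)  # torch.Size([4, 784])
```

Each row of `xb` is now one image, which is the (N, 784) shape the network below expects.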
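import torch
from torch import nn
import math
import torch.nn.functional as F

class Mnist_net(nn.Module):
    def __init__(self):
        super().__init__()
        # 784 inputs -> 10 outputs; random weights scaled by 1/sqrt(784) to keep initial outputs small
        self.weights = nn.Parameter(torch.randn(784, 10) / math.sqrt(784))
        self.bias = nn.Parameter(torch.zeros(10))

    def forward(self, xb):
        # Linear layer: (N, 784) @ (784, 10) + (10,) -> (N, 10) class scores (logits)
        xb = xb @ self.weights + self.bias
        return xb

model = Mnist_net()
loss_func = F.cross_entropy

xb = torch.randn(2, 784)   # fake batch: 2 flattened 28x28 "images"
yb = torch.tensor([1, 1])  # ground-truth class label for each of the 2 samples
print("model(xb):", model(xb))
print("yb:", yb)
loss = loss_func(model(xb), yb)
print("loss:", loss)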
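`F.cross_entropy` takes raw scores (logits) of shape (N, C) and integer class labels of shape (N,). It applies log-softmax over the classes, picks the log-probability of each sample's true class, negates it, and averages over the batch (the default `reduction='mean'`). The sketch below checks this against a manual computation. It uses its own random logits in place of `model(xb)`, with the same shapes as above:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(2, 10)  # stands in for model(xb): 2 samples x 10 class scores
yb = torch.tensor([1, 1])    # one class index per sample

# Built-in loss (default reduction='mean')
loss_builtin = F.cross_entropy(logits, yb)

# Manual version: log-softmax over the class dimension, select the true-class
# log-probability for each row, then negate and average over the batch
log_probs = F.log_softmax(logits, dim=1)
loss_manual = -log_probs[torch.arange(yb.size(0)), yb].mean()

print(loss_builtin.item(), loss_manual.item())
print(torch.allclose(loss_builtin, loss_manual))  # True
```

This is why the labels do not need to be one-hot encoded: each integer in `yb` is used directly as a column index into the (N, C) score matrix.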
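Note that yb holds the ground-truth labels, while model(xb) is the predicted output. The two have different shapes: yb has shape [2] (one class index per sample), whereas model(xb) has shape [2, 10] (one score per class for each sample), as shown below.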
model(xb): tensor([[ 0.3825, -0.0071, 0.4110, -0.0534, -0.3756, 0.7349, -3.1753, 0.7359,
-0.1801, 0.1951],
[-1.4873,