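The cause: with a dimension argument, torch.max(out, 1) returns a (values, indices) tuple rather than just the predicted class indices, so change torch.max() to torch.argmax():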
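import torch

# Inside the evaluation loop over (img, label) batches:
out = model(img)  # logits, shape (batch_size, num_classes)
loss = criterion(out, label)
eval_loss += loss.item() * label.size(0)  # assumes a mean-reduced loss; scale back to a per-batch sum
pred = torch.argmax(out, 1)  # predicted class index per sample, shape (batch_size,)
num_correct = (pred == label).sum()  # 0-dim tensor holding the number of correct predictions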
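A minimal sketch of the difference, using made-up logits and labels (the tensors below are hypothetical, not taken from the original model). It shows that torch.max(out, 1) yields both values and indices, while torch.argmax(out, 1) yields only the indices needed for comparison with label:

```python
import torch

# Hypothetical logits for a batch of 3 samples and 4 classes (no ties per row)
out = torch.tensor([[0.1, 2.0, 0.3, 0.5],
                    [1.5, 0.2, 0.1, 0.0],
                    [0.0, 0.1, 0.2, 3.0]])
label = torch.tensor([1, 0, 2])  # hypothetical ground-truth classes

# torch.max with a dim returns a (values, indices) named tuple, not just indices
values, indices = torch.max(out, 1)
print(values)   # maximum logit per row: 2.0, 1.5, 3.0
print(indices)  # position of that maximum: 1, 0, 3

# torch.argmax returns only the indices, which is what the label comparison needs
pred = torch.argmax(out, 1)
num_correct = (pred == label).sum()
print(num_correct.item())  # 2 (samples 0 and 1 are correct)
```

If you prefer to keep torch.max, unpacking the tuple as `_, pred = torch.max(out, 1)` gives the same indices as torch.argmax for rows without ties.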