import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision import datasets # 放置了许多常用数据集,包括手写数字识别import torch.nn.functional as F
deftest():
correct =0
total =0with torch.no_grad():for data in test_loader:
images, labels =data
outputs = model(images)
_, predicted = torch.max(outputs.data, dim =1)# 返回两个值,第一个是最大值,第二个是最大值的索引。dim=1表示在列维度求以上结果,dim = 0表示在行维度求以上结果。
total += labels.size(0)# 每一个batch_size 中labels是一个(N,1)的元组,size(0)=N
correct +=(predicted == labels).sum().item()# 对的总个数print("Accuracy on the test set %d %%"%(100*correct/total))
网络启动
if __name__=="__main__":for epoch inrange(10):
train(epoch)if epoch %2==0:
test()
[1, 300] loss: 0.559
[1, 600] loss: 0.169
[1, 900] loss: 0.134
Accuracy on the test set 96 %
[2, 300] loss: 0.109
[2, 600] loss: 0.091
[2, 900] loss: 0.084
[3, 300] loss: 0.072
[3, 600] loss: 0.072
[3, 900] loss: 0.071
Accuracy on the test set 98 %
[4, 300] loss: 0.067
[4, 600] loss: 0.055
[4, 900] loss: 0.058
[5, 300] loss: 0.053
[5, 600] loss: 0.053
[5, 900] loss: 0.055
Accuracy on the test set 98 %
[6, 300] loss: 0.048
[6, 600] loss: 0.048
[6, 900] loss: 0.047
[7, 300] loss: 0.042
[7, 600] loss: 0.043
[7, 900] loss: 0.045
Accuracy on the test set 98 %
[8, 300] loss: 0.043
[8, 600] loss: 0.036
[8, 900] loss: 0.038
[9, 300] loss: 0.038
[9, 600] loss: 0.032
[9, 900] loss: 0.040
Accuracy on the test set 98 %
[10, 300] loss: 0.038
[10, 600] loss: 0.029
[10, 900] loss: 0.036