import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision import datasets # 放置了许多常用数据集,包括手写数字识别import torch.nn.functional as F
deftest():
correct =0
total =0with torch.no_grad():for data in test_loader:
images, labels =data
outputs = model(images)
_, predicted = torch.max(outputs.data, dim =1)# 返回两个值,第一个是最大值,第二个是最大值的索引。dim=1表示在列维度求以上结果,dim = 0表示在行维度求以上结果。
total += labels.size(0)# 每一个batch_size 中labels是一个(N,1)的元组,size(0)=N
correct +=(predicted == labels).sum().item()# 对的总个数print("Accuracy on the test set %d %%"%(100*correct/total))
网络启动
if __name__=="__main__":for epoch inrange(10):
train(epoch)if epoch %2==0:
test()
[1, 300] loss: 2.190
[1, 600] loss: 0.789
[1, 900] loss: 0.414
Accuracy on the test set 89 %
[2, 300] loss: 0.316
[2, 600] loss: 0.265
[2, 900] loss: 0.224
[3, 300] loss: 0.190
[3, 600] loss: 0.162
[3, 900] loss: 0.159
Accuracy on the test set 96 %
[4, 300] loss: 0.131
[4, 600] loss: 0.121
[4, 900] loss: 0.114
[5, 300] loss: 0.097
[5, 600] loss: 0.100
[5, 900] loss: 0.090
Accuracy on the test set 96 %
[6, 300] loss: 0.074
[6, 600] loss: 0.077
[6, 900] loss: 0.075
[7, 300] loss: 0.061
[7, 600] loss: 0.059
[7, 900] loss: 0.061
Accuracy on the test set 97 %
[8, 300] loss: 0.049
[8, 600] loss: 0.049
[8, 900] loss: 0.050
[9, 300] loss: 0.037
[9, 600] loss: 0.040
[9, 900] loss: 0.042
Accuracy on the test set 97 %
[10, 300] loss: 0.033
[10, 600] loss: 0.028
[10, 900] loss: 0.035