以 VGG19 为例，测试脚本如下：
import os

import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms

from vgg import VGG19
from vgg import VGG34

# Per-channel normalization statistics used by this script. These are the
# commonly used ImageNet values, not statistics computed on CIFAR-10 itself.
_MEAN = [0.485, 0.456, 0.406]
_STD = [0.229, 0.224, 0.225]


def _cifar10_loader(train, batch_size, shuffle):
    """Build a DataLoader over one CIFAR-10 split stored under ``dataset/``.

    Args:
        train: True for the training split, False for the test split.
        batch_size: Number of samples per batch.
        shuffle: Whether the loader reshuffles every epoch.

    Returns:
        A ``DataLoader`` yielding ``(images, labels)`` batches. The data is
        downloaded on first use.
    """
    transform = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_MEAN, std=_STD),
    ])
    dataset = datasets.CIFAR10('dataset/', train=train,
                               transform=transform, download=True)
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)


def _evaluate(model, loader, device, n_classes):
    """Compute overall and per-class accuracy of ``model`` on ``loader``.

    Args:
        model: Classifier whose output is argmax-ed over dim 1.
        loader: Iterable of ``(x, label)`` batches.
        device: Device the batches are moved to before the forward pass.
        n_classes: Number of classes; labels must lie in ``range(n_classes)``.

    Returns:
        ``(acc, class_correct, class_total)``. ``acc`` is the overall
        fraction correct (0.0 for an empty loader). The two lists hold
        per-class hit counts and per-class sample counts.
    """
    class_correct = [0] * n_classes
    class_total = [0] * n_classes
    total_correct = 0
    total_num = 0

    model.eval()
    # Inference only: skip building autograd graphs to save memory and time.
    with torch.no_grad():
        for x, label in loader:
            x, label = x.to(device), label.to(device)
            pred = model(x).argmax(dim=1)
            # Keep this 1-D; the old .squeeze() made it 0-d for a batch of one.
            correct = pred == label
            total_correct += correct.sum().item()
            total_num += x.size(0)
            for lbl, hit in zip(label.tolist(), correct.tolist()):
                class_correct[lbl] += hit
                class_total[lbl] += 1

    acc = total_correct / total_num if total_num else 0.0
    return acc, class_correct, class_total


def main():
    """Load VGG19 (from ``model.pkl`` if present) and report CIFAR-10 test accuracy."""
    batchsz = 32
    cifar_train = _cifar10_loader(train=True, batch_size=batchsz, shuffle=True)
    # Evaluation results do not depend on order, so there is no need to shuffle.
    cifar_test = _cifar10_loader(train=False, batch_size=batchsz, shuffle=False)

    # Peek at one batch to sanity-check tensor shapes.
    x, label = next(iter(cifar_train))
    print('x:', x.shape, 'label:', label.shape)

    # Fall back to CPU so the script also runs on machines without CUDA.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = VGG19().to(device)
    print(model)
    if os.path.exists('model.pkl'):
        # map_location lets a checkpoint saved on GPU load on a CPU-only host.
        model.load_state_dict(torch.load('model.pkl', map_location=device))
        print('model loaded from model.pkl')
    else:
        print('model.pkl not found; evaluating untrained weights')

    classes = ('plane', 'car', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck')

    acc, class_correct, class_total = _evaluate(
        model, cifar_test, device, len(classes))
    print('acc: ', acc)
    for name, n_correct, n_total in zip(classes, class_correct, class_total):
        if n_total == 0:
            # Avoid ZeroDivisionError for a class with no test samples.
            print('Accuracy of %5s : N/A (no samples)' % name)
            continue
        print('Accuracy of %5s : %2d %%' % (name, 100 * n_correct / n_total))


if __name__ == '__main__':
    main()
若改为训练 VGG34，结果则如下：
参考：
https://blog.csdn.net/Arctic_Beacon/article/details/85068188