代码非原创，部分细节我暂时还没弄懂，是参考各处的代码拼凑后才实现了想要的效果。
在 CIFAR-10 上训练 20 轮后，测试准确率约为 73%（注意：下方示例代码默认只训练 1 轮）。如果想要达到更高的准确率，建议换用更强的模型。
CIFAR-10 训练
1. 图像预处理（调整尺寸、中心裁剪、转为张量并归一化）
# Preprocessing applied to every CIFAR-10 image:
#   1. Resize(32): for an int size torchvision scales the smaller edge to 32 px.
#   2. CenterCrop(32): cut out the central 32x32 patch.
#   3. ToTensor(): convert the image to a float tensor (values in [0, 1]).
#   4. Normalize: per-channel (v - 0.5) / 0.5 for each of the 3 channels,
#      which maps [0, 1] to [-1, 1].
transform = transforms.Compose(
    [
        transforms.Resize(32),
        transforms.CenterCrop(32),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,) * 3, std=(0.5,) * 3),
    ]
)
2.加载数据集
# CIFAR-10 training split (downloaded into PATH if missing), shuffled each epoch.
train_dataset = torchvision.datasets.CIFAR10(
    root=PATH,
    train=True,
    download=True,
    transform=transform,
)
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
)

# CIFAR-10 test split; order is kept fixed so evaluation is reproducible.
test_dataset = torchvision.datasets.CIFAR10(
    root=PATH,
    train=False,
    download=True,
    transform=transform,
)
test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
)

# Human-readable label names; the list index is the class id (0-9).
classes = 'airplane automobile bird cat deer dog frog horse ship truck'.split()
3.搭建网络模型
class CNN(nn.Module):
    """Small LeNet-style classifier for 3x32x32 images.

    Two (conv -> ReLU -> 2x2 max-pool) stages followed by three fully
    connected layers. ``forward`` returns 10 unnormalized class scores
    (logits); no softmax is applied.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)  # 3 -> 6 channels, 5x5 kernel
        self.pool = nn.MaxPool2d(2, 2)  # 2x2 window, stride 2; reused by both stages
        self.conv2 = nn.Conv2d(6, 16, 5)  # 6 -> 16 channels, 5x5 kernel
        # 16 channels * 5 * 5 spatial positions left after the two stages
        # for a 32x32 input (see the shape notes in forward).
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)  # one output per class

    def forward(self, x):
        """Map a (batch, 3, 32, 32) tensor to (batch, 10) logits."""
        # 32x32 -> conv(5) -> 28x28 -> pool -> 14x14
        x = self.pool(F.relu(self.conv1(x)))
        # 14x14 -> conv(5) -> 10x10 -> pool -> 5x5
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten everything except the batch dimension. Unlike
        # view(-1, 400), this never changes the batch size: an input with
        # the wrong spatial size fails in fc1 rather than being silently
        # reshaped into a different number of samples.
        x = x.flatten(1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)
4.选择设备
# Run on the CUDA device when PyTorch can see one; otherwise use the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)
5. 实例化网络模型并移动到所选设备
# Build the network, then move its parameters to the selected device.
# nn.Module.to() converts the module in place and returns the same object.
cnn = CNN()
cnn.to(device)
6. 定义损失函数和优化器（设置学习率与动量）
# Loss: cross-entropy applied to the raw logits returned by CNN.forward.
crit = torch.nn.CrossEntropyLoss()
# Optimizer: SGD with momentum over all of the model's parameters.
opti = torch.optim.SGD(
    params=cnn.parameters(),
    lr=0.002,
    momentum=0.9,
)
7.训练模型
def train(epochs=1, log_every=200):
    """Train the module-level ``cnn`` on ``train_loader``.

    Uses the module-level ``crit`` (loss), ``opti`` (optimizer) and
    ``device``. Prints the mean mini-batch loss every ``log_every`` batches.

    Args:
        epochs: number of full passes over the training set (default 1).
        log_every: logging interval in mini-batches; must be >= 1
            (default 200).

    Raises:
        ValueError: if ``log_every`` is less than 1.
    """
    if log_every < 1:
        raise ValueError('log_every must be >= 1, got %r' % (log_every,))

    for epoch in range(epochs):
        running_loss = 0.0
        for i, (inputs, labels) in enumerate(train_loader):
            inputs, labels = inputs.to(device), labels.to(device)

            opti.zero_grad()  # clear gradients left over from the previous step
            outputs = cnn(inputs)
            loss = crit(outputs, labels)
            loss.backward()
            opti.step()

            running_loss += loss.item()
            if i % log_every == log_every - 1:
                # running_loss holds exactly log_every batch losses since the
                # last reset, so dividing by log_every gives their true mean.
                print('[%d, %5d] loss: %.3f'
                      % (epoch + 1, i + 1, running_loss / log_every))
                running_loss = 0.0
    print('Finish training')
8.测试模型和计算代码运行时间
def test(): correct = 0 total = 0 for data in test_loader: images, labels = data outputs = cnn(Variable(images)) _, predicted = torch.max(outputs.data, 1) total += labels.size(0) correct += (predicted == labels).sum() print('Accuracy of the network on the 10000 test images: %d %%' % (100 * correct / total)) class_correct = list(0. for i in range(10)) class_total = list(0. for i in range(10)) with torch.no_grad(): for data in test_loader: images, labels = data images, labels = images.to(device), labels.to(