文章目录
1.导库
import time
import os
import numpy as np
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
import argparse
import numpy as np
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as Data
import matplotlib.pyplot as plt
from PIL import Image
import torch.nn as nn
import torch.nn.functional as F
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets
# Run on the GPU when PyTorch reports CUDA as available, otherwise on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
2.获取数据集
CIFAR10数据集简单介绍
CIFAR-10 是一个包含60000张图片的数据集。其中每张照片为32*32的彩色照片,每个像素点包括RGB三个数值,数值范围 0 ~ 255。
所有照片分属10个不同的类别,分别是 ‘airplane’, ‘automobile’, ‘bird’, ‘cat’, ‘deer’, ‘dog’, ‘frog’, ‘horse’, ‘ship’, ‘truck’
其中五万张图片被划分为训练集,剩下的一万张图片属于测试集。
加载数据集
# Mini-batch size shared by the training and test loaders.
batch_size = 64

# CIFAR-10 is stored under ./data; only the training split requests a download.
# ToTensor() is the only transform applied to either split.
train_dataset = datasets.CIFAR10(
    root='data', train=True, transform=transforms.ToTensor(), download=True)
test_dataset = datasets.CIFAR10(
    root='data', train=False, transform=transforms.ToTensor())

# Training batches are reshuffled on every pass; test batches keep dataset order.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
3.创建VGG19模型
import time
import os
import numpy as np
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
class VGG19(torch.nn.Module):
    """VGG-19 style network with batch normalisation.

    The network has five convolutional blocks (2, 2, 4, 4 and 4 conv layers)
    followed by a three-layer fully connected classifier. Every block ends in
    a 2x2 max-pool, so a 32x32 input is reduced to a 1x1x512 feature map
    before the classifier.

    Args:
        num_classes: Number of output classes (size of the final layer).
    """

    def __init__(self, num_classes):
        super().__init__()
        # Attribute names and the layer order inside each block are kept
        # stable so that state_dict keys (e.g. ``block3.9.weight``) do not
        # change.
        self.block1 = self._make_block(3, 64, num_convs=2)
        self.block2 = self._make_block(64, 128, num_convs=2)
        self.block3 = self._make_block(128, 256, num_convs=4)
        self.block4 = self._make_block(256, 512, num_convs=4)
        self.block5 = self._make_block(512, 512, num_convs=4)
        self.classifier = nn.Sequential(
            nn.Linear(512 * 1 * 1, 4096),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(4096, 4096),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(4096, num_classes),
        )

    @staticmethod
    def _make_block(in_channels, out_channels, num_convs):
        """Build ``num_convs`` x (3x3 conv -> BatchNorm -> ReLU) + 2x2 max-pool."""
        layers = []
        for _ in range(num_convs):
            layers.append(nn.Conv2d(in_channels=in_channels,
                                    out_channels=out_channels,
                                    kernel_size=(3, 3),
                                    stride=(1, 1),
                                    padding=1))
            layers.append(nn.BatchNorm2d(num_features=out_channels))
            layers.append(nn.ReLU())
            # Only the first convolution of a block changes the channel count.
            in_channels = out_channels
        layers.append(nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2)))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return ``(logits, probas)`` for a batch of images.

        Args:
            x: Image batch with 3 channels, laid out as (N, 3, H, W).

        Returns:
            A tuple ``(logits, probas)``; ``probas`` is the softmax of
            ``logits`` over the class dimension (dim 1).
        """
        for block in (self.block1, self.block2, self.block3,
                      self.block4, self.block5):
            x = block(x)
        # Collapse whatever spatial extent is left to 1x1. For 32x32 inputs
        # the map is already 1x1, so this is the identity; for larger inputs
        # it keeps one row per sample instead of letting a reshape fold the
        # spatial positions into the batch dimension.
        x = F.adaptive_avg_pool2d(x, (1, 1))
        logits = self.classifier(torch.flatten(x, 1))
        probas = F.softmax(logits, dim=1)
        return logits, probas
简单查看一下网络结构
# Build a throw-away 10-class model, print its layer structure, then push one
# random 32x32 RGB image through it and print the (logits, probas) tuple.
net = VGG19(10)
print(net)
_demo_input = torch.randn(1, 3, 32, 32)
_demo_output = net(_demo_input)
print(_demo_output)
VGG19(
(block1): Sequential(
(0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
(3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU()
(6): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
)
(block2): Sequential(
(0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
(3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU()
(6): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
)
(block3): Sequential(
(0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
(3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU()
(6): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(7): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(8): ReLU()
(9): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(10): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(11): ReLU()
(12): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
)
(block4): Sequential(
(0): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
(3): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU()
(6): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(7): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(8): ReLU()
(9): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(10): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(11): ReLU()
(12): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
)
(block5): Sequential(
(0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
(3): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU()
(6): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(7): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(8): ReLU()
(9): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(10): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(11): ReLU()
(12): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
)
(classifier): Sequential(
(0): Linear(in_features=512, out_features=4096, bias=True)
(1): ReLU()
(2): Dropout(p=0.5, inplace=False)
(3): Linear(in_features=4096, out_features=4096, bias=True)
(4): ReLU()
(5): Dropout(p=0.5, inplace=False)
(6): Linear(in_features=4096, out_features=10, bias=True)
)
)
(tensor([[ 0.2961, 0.0183, 0.1557, -0.2880, -0.2378, -0.1892, -0.1059, -0.2017,
0.2670, 0.4931]], grad_fn=<AddmmBackward>), tensor([[0.1274, 0.0965, 0.1107, 0.0710, 0.0747, 0.0784, 0.0852, 0.0774, 0.1237,
0.1551]], grad_fn=<SoftmaxBackward>))
# Number of full passes over the training set.
NUM_EPOCHS = 15

# 10-class model placed on the selected device; Module.to() returns the module.
model = VGG19(num_classes=10).to(DEVICE)

# Adam over all model parameters with a learning rate of 1e-3.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)
4.开启训练
# NOTE: the "validation" loader is an alias of the test loader, so validation
# metrics reported during training are computed on the test split itself.
valid_loader = test_loader
def compute_accuracy_and_loss(model, data_loader, device):
    """Compute classification accuracy and mean cross-entropy over a loader.

    The model's train/eval mode is not changed here; callers switch it.

    Args:
        model: Callable returning ``(logits, probas)`` for a feature batch.
        data_loader: Iterable yielding ``(features, targets)`` tensor pairs.
        device: Device the batches are moved to before the forward pass.

    Returns:
        A tuple ``(accuracy, loss)``: ``accuracy`` is a 0-dim float tensor
        holding the percentage of correct predictions, and ``loss`` is a
        Python float holding the mean cross-entropy per example.

    Raises:
        ValueError: If ``data_loader`` yields no examples.
    """
    correct_pred, num_examples = 0, 0
    cross_entropy = 0.
    # Evaluation never back-propagates, so skip building the autograd graph.
    with torch.no_grad():
        for features, targets in data_loader:
            features = features.to(device)
            targets = targets.to(device)
            logits, probas = model(features)
            # Accumulate the *sum* of per-example losses. Summing per-batch
            # means and dividing by the example count would under-report the
            # loss by roughly a factor of the batch size.
            cross_entropy += F.cross_entropy(
                logits, targets, reduction='sum').item()
            predicted_labels = probas.argmax(dim=1)
            num_examples += targets.size(0)
            correct_pred += (predicted_labels == targets).sum()
    if num_examples == 0:
        raise ValueError('data_loader yielded no examples')
    return correct_pred.float() / num_examples * 100, cross_entropy / num_examples
# Train for NUM_EPOCHS epochs, evaluating on the training and validation
# loaders after every epoch and recording the metrics for later plotting.
start_time = time.time()
train_acc_lst, valid_acc_lst = [], []
train_loss_lst, valid_loss_lst = [], []

for epoch in range(NUM_EPOCHS):

    model.train()

    for batch_idx, (features, targets) in enumerate(train_loader):

        ### PREPARE MINIBATCH
        features = features.to(DEVICE)
        targets = targets.to(DEVICE)

        ### FORWARD AND BACK PROP
        logits, probas = model(features)
        cost = F.cross_entropy(logits, targets)
        optimizer.zero_grad()
        cost.backward()

        ### UPDATE MODEL PARAMETERS
        optimizer.step()

        ### LOGGING
        if batch_idx % 300 == 0:
            print (f'Epoch: {epoch+1:03d}/{NUM_EPOCHS:03d} | '
                   f'Batch {batch_idx:03d}/{len(train_loader):03d} |'
                   f' Cost: {cost.item():.4f}')

    # no need to build the computation graph for backprop when computing accuracy
    model.eval()
    with torch.no_grad():
        train_acc, train_loss = compute_accuracy_and_loss(model, train_loader, device=DEVICE)
        valid_acc, valid_loss = compute_accuracy_and_loss(model, valid_loader, device=DEVICE)
        # Store plain Python floats rather than 0-dim tensors: the history is
        # plotted later, and tensors living on a GPU may not be convertible
        # to NumPy by matplotlib.
        train_acc_lst.append(float(train_acc))
        valid_acc_lst.append(float(valid_acc))
        train_loss_lst.append(train_loss)
        valid_loss_lst.append(valid_loss)
        print(f'Epoch: {epoch+1:03d}/{NUM_EPOCHS:03d} Train Acc.: {train_acc:.2f}%'
              f' | Validation Acc.: {valid_acc:.2f}%')

    elapsed = (time.time() - start_time)/60
    print(f'Time elapsed: {elapsed:.2f} min')

elapsed = (time.time() - start_time)/60
print(f'Total Training Time: {elapsed:.2f} min')
训练结果
Epoch: 001/015 | Batch 000/782 | Cost: 2.2996
Epoch: 001/015 | Batch 300/782 | Cost: 1.9033
Epoch: 001/015 | Batch 600/782 | Cost: 1.8213
Epoch: 001/015 Train Acc.: 27.86% | Validation Acc.: 27.72%
Time elapsed: 0.79 min
Epoch: 002/015 | Batch 000/782 | Cost: 1.8215
Epoch: 002/015 | Batch 300/782 | Cost: 1.7313
Epoch: 002/015 | Batch 600/782 | Cost: 1.5908
Epoch: 002/015 Train Acc.: 31.12% | Validation Acc.: 30.99%
Time elapsed: 1.58 min
Epoch: 003/015 | Batch 000/782 | Cost: 1.4785
Epoch: 003/015 | Batch 300/782 | Cost: 1.6117
Epoch: 003/015 | Batch 600/782 | Cost: 1.4389
Epoch: 003/015 Train Acc.: 38.75% | Validation Acc.: 38.88%
Time elapsed: 2.37 min
Epoch: 004/015 | Batch 000/782 | Cost: 1.7106
Epoch: 004/015 | Batch 300/782 | Cost: 1.5534
Epoch: 004/015 | Batch 600/782 | Cost: 1.3664
Epoch: 004/015 Train Acc.: 48.24% | Validation Acc.: 48.28%
Time elapsed: 3.17 min
Epoch: 005/015 | Batch 000/782 | Cost: 1.3960
Epoch: 005/015 | Batch 300/782 | Cost: 1.2528
Epoch: 005/015 | Batch 600/782 | Cost: 1.0285
Epoch: 005/015 Train Acc.: 57.69% | Validation Acc.: 56.56%
Time elapsed: 3.96 min
Epoch: 006/015 | Batch 000/782 | Cost: 1.2201
Epoch: 006/015 | Batch 300/782 | Cost: 1.3973
Epoch: 006/015 | Batch 600/782 | Cost: 0.9932
Epoch: 006/015 Train Acc.: 63.45% | Validation Acc.: 62.00%
Time elapsed: 4.75 min
Epoch: 007/015 | Batch 000/782 | Cost: 1.0917
Epoch: 007/015 | Batch 300/782 | Cost: 0.9945
Epoch: 007/015 | Batch 600/782 | Cost: 0.9642
Epoch: 007/015 Train Acc.: 70.49% | Validation Acc.: 68.02%
Time elapsed: 5.55 min
Epoch: 008/015 | Batch 000/782 | Cost: 0.9970
Epoch: 008/015 | Batch 300/782 | Cost: 1.0228
Epoch: 008/015 | Batch 600/782 | Cost: 0.7637
Epoch: 008/015 Train Acc.: 71.08% | Validation Acc.: 68.62%
Time elapsed: 6.34 min
Epoch: 009/015 | Batch 000/782 | Cost: 0.8868
Epoch: 009/015 | Batch 300/782 | Cost: 0.6673
Epoch: 009/015 | Batch 600/782 | Cost: 0.6563
Epoch: 009/015 Train Acc.: 76.39% | Validation Acc.: 73.18%
Time elapsed: 7.13 min
Epoch: 010/015 | Batch 000/782 | Cost: 0.5915
Epoch: 010/015 | Batch 300/782 | Cost: 0.6803
Epoch: 010/015 | Batch 600/782 | Cost: 0.7547
Epoch: 010/015 Train Acc.: 80.88% | Validation Acc.: 76.85%
Time elapsed: 7.92 min
Epoch: 011/015 | Batch 000/782 | Cost: 0.7416
Epoch: 011/015 | Batch 300/782 | Cost: 0.5792
Epoch: 011/015 | Batch 600/782 | Cost: 0.6585
Epoch: 011/015 Train Acc.: 78.51% | Validation Acc.: 74.52%
Time elapsed: 8.71 min
Epoch: 012/015 | Batch 000/782 | Cost: 0.5723
Epoch: 012/015 | Batch 300/782 | Cost: 0.4954
Epoch: 012/015 | Batch 600/782 | Cost: 0.6602
Epoch: 012/015 Train Acc.: 86.85% | Validation Acc.: 81.15%
Time elapsed: 9.50 min
Epoch: 013/015 | Batch 000/782 | Cost: 0.4229
Epoch: 013/015 | Batch 300/782 | Cost: 0.3894
Epoch: 013/015 | Batch 600/782 | Cost: 0.6071
Epoch: 013/015 Train Acc.: 81.52% | Validation Acc.: 76.20%
Time elapsed: 10.30 min
Epoch: 014/015 | Batch 000/782 | Cost: 0.3842
Epoch: 014/015 | Batch 300/782 | Cost: 0.3581
Epoch: 014/015 | Batch 600/782 | Cost: 0.4606
Epoch: 014/015 Train Acc.: 89.17% | Validation Acc.: 82.13%
Time elapsed: 11.09 min
Epoch: 015/015 | Batch 000/782 | Cost: 0.3284
Epoch: 015/015 | Batch 300/782 | Cost: 0.3410
Epoch: 015/015 | Batch 600/782 | Cost: 0.3267
Epoch: 015/015 Train Acc.: 89.35% | Validation Acc.: 81.93%
Time elapsed: 11.88 min
Total Training Time: 11.88 min
训练损失和测试损失关系图
# Per-epoch cross-entropy on the training and validation loaders, with
# epochs numbered from 1 on the x-axis.
for _series, _label in ((train_loss_lst, 'Training loss'),
                        (valid_loss_lst, 'Validation loss')):
    plt.plot(range(1, NUM_EPOCHS+1), _series, label=_label)
plt.legend(loc='upper right')
plt.xlabel('Epoch')
plt.ylabel('Cross entropy')
plt.show()
训练精度和测试精度关系图
# Per-epoch accuracy on the training and validation loaders. float() accepts
# both plain numbers and 0-dim tensors, so this works whichever the history
# lists hold.
plt.plot(range(1, NUM_EPOCHS+1), [float(acc) for acc in train_acc_lst], label='Training accuracy')
plt.plot(range(1, NUM_EPOCHS+1), [float(acc) for acc in valid_acc_lst], label='Validation accuracy')
plt.legend(loc='upper left')
# This chart shows accuracy in percent; the previous 'Cross entropy' label was
# copied from the loss plot.
plt.ylabel('Accuracy (%)')
plt.xlabel('Epoch')
plt.show()
5.测试阶段
# Final evaluation on the test loader: eval mode, no gradient tracking.
model.eval()
with torch.no_grad():  # save memory during inference
    test_acc, test_loss = compute_accuracy_and_loss(
        model, test_loader, device=DEVICE)
print(f'Test accuracy: {test_acc:.2f}%')
Test accuracy: 81.93%
6.查看效果图
from PIL import Image
import matplotlib.pyplot as plt
# Take one mini-batch from the training loader.
features, targets = next(iter(train_loader))

# Number of images sent through the model and number actually drawn.
num_predicted = 8
num_shown = 6

# Prediction step: eval mode and no gradient tracking, since this is pure
# inference. Calling the module (not .forward) keeps any registered hooks.
model.eval()
with torch.no_grad():
    _, predictions = model(features[:num_predicted].to(DEVICE))
predictions = torch.argmax(predictions, dim=1)
print(predictions)

fig = plt.figure()
tname = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
for i in range(num_shown):
    plt.subplot(2, 3, i + 1)
    # The batch is laid out channel-first; imshow expects channel-last.
    plt.imshow(features[i].permute(1, 2, 0).numpy())
    actual_name = tname[int(targets[i])]
    predicted_name = tname[int(predictions[i])]
    plt.title("Actual value: {}".format(actual_name) + '\n'
              + "Prediction value: {}".format(predicted_name), size=10)
# Lay out the grid once, after every subplot exists.
plt.tight_layout()
plt.show()