最近学习了 PyTorch 的使用方法，并用 PyTorch 提高了 CIFAR-10 的分类准确率，效果令人满意。
PyTorch 的用法与 NumPy 相近，并且支持用 GPU 加速运算，能显著提升训练速度。
下面贴出代码。最终的准确率可以达到约 75%（评估使用的是训练批次中未参与训练的第 40000–44999 张图像，而不是官方的 test_batch）。
import torch
import pickle
import os
import numpy as np
import torch.nn as nn
import matplotlib.pyplot as plt
import torch.nn.functional as F
# Prepare the dataset.
def unpickle(file):
    """Load and return the object stored in the pickle file at *file*.

    ``encoding='bytes'`` is passed to ``pickle.load`` so that 8-bit strings
    from Python-2-era pickles are returned as ``bytes``; callers in this
    script index the result with ``bytes`` keys such as ``b'data'`` and
    ``b'labels'``.

    Args:
        file: path of the pickle file to read.

    Returns:
        The unpickled object.
    """
    # NOTE(security): unpickling can execute arbitrary code; only call this
    # on trusted files.
    with open(file, 'rb') as fo:
        # Named ``batch`` rather than ``dict`` to avoid shadowing the builtin.
        batch = pickle.load(fo, encoding='bytes')
    return batch
# Initialization: read every training batch file (data_batch*) from the
# CIFAR-10 python directory.
file_location = '/content/drive/My Drive/Colab Notebooks/cifar-10-batches-py'
# os.listdir returns entries in arbitrary order; sort so the batch order, and
# therefore the slices taken below, is reproducible across machines.
file_name_list = sorted(os.listdir(file_location))
image_batches = []
label_batches = []
for file_name in file_name_list:
    if file_name.startswith('data_batch'):
        # Unpickle each file once and read both fields from the same dict.
        batch_dict = unpickle(os.path.join(file_location, file_name))
        data_batch = batch_dict[b'data']
        label_batch = batch_dict[b'labels']
        image_batches.append(data_batch)
        label_batches.append(label_batch)
if not image_batches:
    raise FileNotFoundError(
        'no data_batch* files found in {}'.format(file_location))
# Images become an (N, 3, 32, 32) float32 tensor and labels an (N,) int64
# tensor; N is derived from the files found instead of hard-coded to 50000.
x_train = torch.from_numpy(np.concatenate(image_batches)).float().reshape(-1, 3, 32, 32)
y_train = torch.from_numpy(np.concatenate(label_batches).astype(np.int64))
# Training subset: the first 20000 images.
x_trainset = x_train[:20000]
y_trainset = y_train[:20000]
# Evaluation subset: images 40000-44999. They do not overlap the training
# subset, but they come from the data_batch files, not from a test file.
x_testset = x_train[40000:45000]
y_testset = y_train[40000:45000]
def img_show(img):
    """Display one channel-first image tensor with matplotlib.

    Assumes *img* has shape (3, H, W) with pixel values in [0, 255]; it is
    transposed to (H, W, 3) and scaled to [0, 1] for ``plt.imshow``. The
    caller's tensor is left unmodified.

    Example::

        img_show(x_trainset[5])
    """
    # Tensor.numpy() shares memory with the tensor, so an in-place
    # ``/= 255`` would silently rescale the caller's data. Dividing out of
    # place allocates a new array. detach()/cpu() also let this accept
    # tensors that require grad or live on a GPU.
    array = img.detach().cpu().numpy().transpose(1, 2, 0) / 255
    plt.imshow(array)
    plt.show()
class AlexNet(nn.Module):
    """AlexNet-style convolutional classifier that outputs log-probabilities.

    Five 3x3 stride-1, padding-1 convolutions (with ReLU) are interleaved
    with three 2x2 max-pools. An adaptive average pool then produces a
    256 x 6 x 6 map, which feeds a dropout-regularised
    4096-4096-``num_classes`` MLP. For 32x32 inputs the feature map is
    4x4 after the third pool, so the (6, 6) adaptive pool outputs a larger
    map than it receives.

    Args:
        num_classes: size of the final output layer.
    """

    def __init__(self, num_classes=10):
        super().__init__()

        def conv_relu(in_channels, out_channels):
            # 3x3 convolution that preserves spatial size, then in-place ReLU.
            return [
                nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
                nn.ReLU(inplace=True),
            ]

        def halve():
            # 2x2 max-pool with stride 2: halves height and width.
            return nn.MaxPool2d(kernel_size=2, stride=2)

        # Layers are created in the same order (and with the same Sequential
        # indices) as a flat listing, so initialisation and state_dict keys
        # are unchanged.
        self.features = nn.Sequential(
            *conv_relu(3, 64), halve(),
            *conv_relu(64, 192), halve(),
            *conv_relu(192, 384),
            *conv_relu(384, 256),
            *conv_relu(256, 256), halve(),
        )
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))

        hidden = 4096
        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(256 * 6 * 6, hidden),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(hidden, hidden),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, num_classes),
        )

    def forward(self, x):
        """Return per-class log-probabilities of shape (batch, num_classes)."""
        pooled = self.avgpool(self.features(x))
        logits = self.classifier(pooled.flatten(start_dim=1))
        return F.log_softmax(logits, dim=1)
def train(model, device, train_loader, optimizer, epoch, whether_to_print=True, print_interval=100):
    """Run one training epoch of *model* over *train_loader*.

    For each ``(inputs, labels)`` mini-batch:
      * move both to *device*;
      * compute ``F.nll_loss`` on the model output (so the model is expected
        to return log-probabilities);
      * back-propagate the loss;
      * let *optimizer* take one step.

    When *whether_to_print* is true, the loss of every *print_interval*-th
    batch (starting with batch 0) is printed together with *epoch*.
    """
    model.train()
    for step, (inputs, labels) in enumerate(train_loader):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        batch_loss = F.nll_loss(model(inputs), labels)
        batch_loss.backward()
        optimizer.step()
        if not whether_to_print:
            continue
        if step % print_interval == 0:
            print("Train epoch: {} Loss: {:.6f}".format(epoch, batch_loss.item()))
def test(model, device, test_loader,Test_batch_size = 200):
model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
for (data, target) in test_loader:
data, target = data.to(device), target.to(device)
output = model(data)
test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss
pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
pred = pred.squeeze(1)
correct += (pred == target).sum()
test_loss /= len(test_loader.dataset)
print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
test_loss, correct, len(test_loader.dataset),
100.*correct / len(test_loader.dataset)))
if __name__ == '__main__':
    # Fixed seed so weight initialisation and training shuffles are reproducible.
    torch.manual_seed(53113)
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    train_batch_size = test_batch_size = 200
    train_loader = torch.utils.data.DataLoader(
        list(zip(x_trainset, y_trainset)),
        batch_size=train_batch_size, shuffle=True)
    # Loss and accuracy do not depend on sample order, so the evaluation
    # data is not shuffled.
    test_loader = torch.utils.data.DataLoader(
        list(zip(x_testset, y_testset)),
        batch_size=test_batch_size, shuffle=False)
    lr = 0.01
    momentum = 0.5
    model = AlexNet().to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum)
    epochs = 100
    eval_interval = 10
    for epoch in range(1, epochs + 1):
        train(model, device, train_loader, optimizer, epoch)
        # Evaluate periodically, and always after the final epoch even when
        # `epochs` is not a multiple of `eval_interval`.
        if epoch % eval_interval == 0 or epoch == epochs:
            test(model, device, test_loader)
    save_model = False
    if save_model:
        torch.save(model.state_dict(), "cifar10_cnn.pt")