Image Classification with PyTorch and a CNN
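PyTorch is an open-source machine learning library for Python, based on Torch, and is used for applications such as natural language processing. It is developed primarily by Facebook's artificial intelligence research group. Besides strong GPU acceleration, it supports dynamic neural networks (computation graphs built at run time), something that static-graph frameworks such as TensorFlow 1.x did not support natively. PyTorch provides two high-level features:
1. Tensor computation (like NumPy) with strong GPU acceleration
2. Deep neural networks built on an automatic differentiation system
Besides Facebook, organizations such as Twitter, GMU, and Salesforce have adopted PyTorch.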
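This article uses the CIFAR-10 dataset for image classification. The dataset consists of small color images divided into ten classes. Some sample images are shown below:
Checking whether a GPU is available
Each image in the dataset is 32x32 pixels with 3 color channels (32x32x3). Training is best run on a GPU for speed.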
import torch
import numpy as np

# Check whether a GPU can be used
train_on_gpu = torch.cuda.is_available()

if not train_on_gpu:
    print('CUDA is not available.')
else:
    print('CUDA is available!')
Output:
CUDA is available!
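Loading the data
Downloading the data may be slow, so please be patient. Load the training and test data, split the training data into a training set and a validation set, then create a DataLoader for each dataset.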
from torchvision import datasets
import torchvision.transforms as transforms
from torch.utils.data.sampler import SubsetRandomSampler

# number of subprocesses to use for data loading
num_workers = 0
# load 16 images per batch
batch_size = 16
# fraction of the training set to use for validation
valid_size = 0.2

# convert the data to torch.FloatTensor and normalize it
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# select the training and test datasets
train_data = datasets.CIFAR10('data', train=True,
                              download=True, transform=transform)
test_data = datasets.CIFAR10('data', train=False,
                             download=True, transform=transform)

# obtain training indices that will be used for validation
num_train = len(train_data)
indices = list(range(num_train))
np.random.shuffle(indices)
split = int(np.floor(valid_size * num_train))
train_idx, valid_idx = indices[split:], indices[:split]

# define samplers for obtaining training and validation batches
train_sampler = SubsetRandomSampler(train_idx)
valid_sampler = SubsetRandomSampler(valid_idx)

# prepare data loaders (combine dataset and sampler)
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size,
                                           sampler=train_sampler, num_workers=num_workers)
valid_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size,
                                           sampler=valid_sampler, num_workers=num_workers)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size,
                                          num_workers=num_workers)

# the 10 image classes
classes = ['airplane', 'automobile', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck']
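The index split used above can be checked in isolation, without downloading the dataset. The following is a minimal sketch using only the standard library in place of NumPy; `split_indices` and `num_batches` are hypothetical helper names introduced for illustration, the seed is an arbitrary choice, and the figure of 50,000 training images is the size of the standard CIFAR-10 training set.

```python
import math
import random


def split_indices(num_samples, valid_fraction, seed=0):
    """Shuffle 0..num_samples-1 and split off the first part for validation."""
    indices = list(range(num_samples))
    random.Random(seed).shuffle(indices)  # seeded so the split is repeatable
    split = int(math.floor(valid_fraction * num_samples))
    # same slicing as the listing above: validation takes the first `split` indices
    return indices[split:], indices[:split]


def num_batches(num_items, batch_size):
    """Batches per epoch when the last, possibly smaller, batch is kept."""
    return math.ceil(num_items / batch_size)


# the standard CIFAR-10 training set has 50,000 images
train_idx, valid_idx = split_indices(num_samples=50_000, valid_fraction=0.2)

print(len(train_idx), len(valid_idx))        # 40000 10000
print(set(train_idx).isdisjoint(valid_idx))  # True
print(num_batches(len(train_idx), 16))       # 2500
print(num_batches(len(valid_idx), 16))       # 625
```

Because both samplers draw from disjoint index lists over the same `train_data` object, no image appears in both the training and validation loaders.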
Viewing a batch of samples from the training set
import matplotlib.pyplot as plt
# Jupyter/IPython magic; omit this line when running as a plain script
%matplotlib inline

# helper function to un-normalize and display an image
def imshow(img):
    img = img / 2 + 0.5  # unnormalize
    plt.imshow(np.transpose(img, (1, 2, 0)))  # convert from (C, H, W) to (H, W, C)

# fetch one batch of samples
dataiter = iter(train_loader)
images, labels = next(dataiter)
images = images.numpy()  # convert images to numpy for display

# display the images, titled with their class names
fig = plt.figure(figsize=(25, 4))
# show 16 images
for idx in np.arange(16):
    # integer division: subplot row and column counts must be ints
    ax = fig.add_subplot(2, 16 // 2, idx + 1, xticks=[], yticks=[])
    imshow(images[idx])
    ax.set_title(classes[labels[idx]])
Output:
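The `imshow` helper relies on the fact that `img / 2 + 0.5` inverts the normalization applied at load time: `ToTensor()` scales pixel values to [0, 1], and `Normalize` with mean 0.5 and standard deviation 0.5 then maps each channel to [-1, 1] via (x - mean) / std. A minimal pure-Python sketch of that arithmetic follows; `normalize` and `unnormalize` are hypothetical names introduced here, not torchvision functions, and the sample pixel values are made up for illustration.

```python
def normalize(x, mean=0.5, std=0.5):
    """Per-value version of Normalize: (x - mean) / std."""
    return (x - mean) / std


def unnormalize(y):
    """Inverse used by imshow: y * 0.5 + 0.5, written as y / 2 + 0.5."""
    return y / 2 + 0.5


# illustrative pixel intensities already scaled to [0, 1]
pixels = [0.0, 0.25, 0.5, 1.0]
normalized = [normalize(p) for p in pixels]
restored = [unnormalize(v) for v in normalized]

print(normalized)  # [-1.0, -0.5, 0.0, 1.0]
print(restored)    # [0.0, 0.25, 0.5, 1.0]
```

Skipping the un-normalize step would hand matplotlib values in [-1, 1], which fall partly outside the [0, 1] range it expects for floating-point RGB images.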