数据集说明
CIFAR-10数据集由10个类的 60000 个 32x32 彩色图像组成,每个类有6000个图像。有50000个训练图像和10000个测试图像。数据集划分为5个训练批次和1个测试批次,每个批次有10000个图像,测试批次包含来自每个类别的恰好1000个随机选择的图像。训练批次以随机顺序包含剩余图像,但由于一些训练批次可能来源一个类别的图像比另一个多,因此总体来看,5个训练集之和包含来自每个类的正好5000张图像。
数据集下载
加载数据
这个可以采用PyTorch提供的数据集加载工具 torchvision,同时对数据进行预处理,可以预先把数据集下载好解压,并放在当前目录的 data下, 所以参数 download = False
import torchvision
import torchvision.transforms as transforms
transform = transforms.Compose(
# B,G,R 三个通道归一化 标准差为 0.5, 方差为0.5
[transforms.ToTensor(), transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))]
)
trainset = torchvision.datasets.CIFAR10(root = './data', train = True, download=False, transform = transform)
trainLoader = torch.utils.data.DataLoader(trainset,batch_size = 4, shuffle = True, num_workers = 2)
testset = torchvision.datasets.CIFAR10(root = './data', train = False, download=False, transform = transform)
testLoader = torch.utils.data.DataLoader(testset, batch_size = 4, shuffle = False, num_workers = 2)
classes = ('plane', 'car','bird','cat','deer','dog','frog','horse','ship','truck')
构建网络
class CNNNet(nn.Module):
def __init__(self):
super(CNNNet, self).__init__()
self.conv1 = nn.Conv2d(in_channels = 3, out_channels=16, kernel_size= 5, stride=1)
self.pool1 = nn.MaxPool2d(kernel_size=2, stride = 2)
self.conv2 = nn.Conv2d(in_channels=16, out_channels=36, kernel_size=3, stride=1)
self.pool2 = nn.MaxPool2d(kernel_size=2, stride = 2)
self.fc1 = nn.Linear(1296, 128)
self.fc2 = nn.Linear(128, 10)
def forward(self,x):
x = self.pool1(F.relu(self.conv1(x)))
x = self.pool2(F.relu(self.conv2(x)))
x = x.view(-1, 36*6*6)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
return x
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
net = CNNNet()
net = net.to(device)
查看一下网络结构
CNNNet(
(conv1): Conv2d(3, 16, kernel_size=(5, 5), stride=(1, 1))
(pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(conv2): Conv2d(16, 36, kernel_size=(3, 3), stride=(1, 1))
(pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(fc1): Linear(in_features=1296, out_features=128, bias=