构建卷积神经网络
卷积神经网络的输入层与传统神经网络有些区别,需要重新设计,训练模块基本一致
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets,transforms
import matplotlib.pyplot as plt
import numpy as np
%matplotlib inline
首先读取数据
分别构建训练集和测试集(验证集) DataLoader来迭代数据
#定义超参数
input_size = 28 #图像的总尺寸28*28
num_classes = 10 #标签的种类数
num_epochs = 3 #训练的总循环周期
batch_size = 64 #一个批次的大小,64张照片
#训练集
train_dataset = datasets.MNIST(root='./data',
train=True,
transform=transforms.ToTensor(),
download=True)
#测试集
test_dataset = datasets.MNIST(root='./data',
train=False,
transform=transforms.ToTensor(),)
#构建batch数据
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
batch_size=batch_size,
shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
batch_size=batch_size,
shuffle=True)
卷积网络模块构建
一般卷积层,relu层,池化层可以写成一个套餐 (conv+relu)+pool
注意卷积最后结果还是一个特征图,需要把图转换成向量才能作分类或者回归任务
class CNN(nn.Module):
def __init__(self):
super(CNN, self).__init__()
self.conv1 = nn.Sequential( #输入大小(1,28,28)
nn.Conv2d(
in_channels=1, #灰度图
out_channels=16, #要得到多少个特征图
kernel_size=5, #卷积核大小
stride=1, #步长
padding=2, #如果希望卷积后大小和原来一样,需要设置padding=(kernel_size-1)/2 if stride=1
), #输出的特征图为(16,28,28)
nn.ReLU(), #relu层
nn.MaxPool2d(kernel_size=2),#进行池化操作(2*2区域),输出结果为:(16,14,14)
)
self.conv2 = nn.Sequential( #下一个套餐的输入 (16,14,14)
nn.Conv2d(16, 32, 5, 1, 2), #输出(32,14,14)
nn.ReLU(), #relu层
nn.MaxPool2d(2), #输出(32,7,7)
)
self.out = nn.Linear(32 * 7 * 7, 10) #全连接层得到的结果
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
x = x.view(x.size(0), -1) #faltten操作,结果为:(batch_size, 32*7*7)
output = self.out(x)
return output
准确率作为评估标准
def accuracy(predictions,lables):
pred = torch.max(predictions.data, 1)[1]
rights = pred.eq(lables.data.view_as(pred)).sum()
return rights, len(lables)
训练模型
#实例化
net = CNN()
#损失函数
criterion = nn.CrossEntropyLoss()
#优化器
optimizer = optim.Adam(net.parameters(), lr=0.001) #定义优化器,普通的随机梯度下降算法
#开始训练循环
for epoch in range(num_epochs):
#当前epoch的结果保存下来
train_rights = []
for batch_idx, (data, target) in enumerate(train_loader): #针对容器中的每一个批进行循环
net.train()
output = net(data)
loss = criterion(output, target)
optimizer.zero_grad()
loss.backward()
optimizer.step()
right = accuracy(output, target)
train_rights.append(right)
if batch_idx % 100 == 0:
net.eval()
val_rights = []
for (data, target) in test_loader:
output = net(data)
right = accuracy(output, target)
val_rights.append(right)
#准确率计算
train_r = (sum([tup[0] for tup in train_rights]),sum([tup[1] for tup in train_rights]))
val_r = (sum([tup[0] for tup in val_rights]), sum([tup[1] for tup in val_rights]))
print('当前epoch:{} [{}/{} ({:.0f}%)]\t损失: {:.6f}\t训练集准确率: {:.2f}%\t测试集正确率: {:.2f}'.format(
epoch, batch_idx * batch_size, len(train_loader.dataset),
100. * batch_idx / len(train_loader),
loss.data,
100. * train_r[0].numpy() / train_r[1],
100. * val_r[0].numpy() / val_r[1]))
当前epoch:0 [0/60000 (0%)] 损失: 2.313051 训练集准确率: 7.81% 测试集正确率: 17.08 当前epoch:0 [64/60000 (0%)] 损失: 2.292140 训练集准确率: 10.94% 测试集正确率: 17.08 当前epoch:0 [128/60000 (0%)] 损失: 2.293331 训练集准确率: 10.94% 测试集正确率: 17.08 当前epoch:0 [192/60000 (0%)] 损失: 2.244610 训练集准确率: 16.80% 测试集正确率: 17.08 当前epoch:0 [256/60000 (0%)] 损失: 2.218900 训练集准确率: 18.75% 测试集正确率: 17.08 当前epoch:0 [320/60000 (1%)] 损失: 2.195498 训练集准确率: 21.88% 测试集正确率: 17.08 当前epoch:0 [384/60000 (1%)] 损失: 2.206018 训练集准确率: 21.65% 测试集正确率: 17.08 当前epoch:0 [448/60000 (1%)] 损失: 2.180433 训练集准确率: 22.46% 测试集正确率: 17.08 当前epoch:0 [512/60000 (1%)] 损失: 2.141994 训练集准确率: 23.26% 测试集正确率: 17.08 当前epoch:0 [576/60000 (1%)] 损失: 2.042990 训练集准确率: 25.47% 测试集正确率: 17.08 当前epoch:0 [640/60000 (1%)] 损失: 2.048994 训练集准确率: 25.71% 测试集正确率: 17.08 当前epoch:0 [704/60000 (1%)] 损失: 1.977731 训练集准确率: 25.26% 测试集正确率: 17.08 当前epoch:0 [768/60000 (1%)] 损失: 2.034307 训练集准确率: 24.64% 测试集正确率: 17.08 当前epoch:0 [832/60000 (1%)] 损失: 1.966112 训练集准确率: 24.11% 测试集正确率: 17.08 当前epoch:0 [896/60000 (1%)] 损失: 1.881626 训练集准确率: 24.90% 测试集正确率: 17.08 当前epoch:0 [960/60000 (2%)] 损失: 1.844498 训练集准确率: 26.17% 测试集正确率: 17.08 当前epoch:0 [1024/60000 (2%)] 损失: 1.773689 训练集准确率: 28.95% 测试集正确率: 17.08 当前epoch:0 [1088/60000 (2%)] 损失: 1.686953 训练集准确率: 31.34% 测试集正确率: 17.08 当前epoch:0 [1152/60000 (2%)] 损失: 1.692229 训练集准确率: 32.73% 测试集正确率: 17.08 当前epoch:0 [1216/60000 (2%)] 损失: 1.557774 训练集准确率: 34.06% 测试集正确率: 17.08 当前epoch:0 [1280/60000 (2%)] 损失: 1.487893 训练集准确率: 35.57% 测试集正确率: 17.08 当前epoch:0 [1344/60000 (2%)] 损失: 1.601223 训练集准确率: 36.29% 测试集正确率: 17.08 当前epoch:0 [1408/60000 (2%)] 损失: 1.326457 训练集准确率: 37.50% 测试集正确率: 17.08 当前epoch:0 [1472/60000 (2%)] 损失: 1.236989 训练集准确率: 38.80% 测试集正确率: 17.08 当前epoch:0 [1536/60000 (3%)] 损失: 1.168241 训练集准确率: 40.12% 测试集正确率: 17.08 当前epoch:0 [1600/60000 (3%)] 损失: 1.159773 训练集准确率: 41.23% 测试集正确率: 17.08 当前epoch:0 [1664/60000 (3%)] 损失: 1.126725 训练集准确率: 42.48% 测试集正确率: 17.08 当前epoch:0 [1728/60000 (3%)] 损失: 0.897731 训练集准确率: 43.97% 测试集正确率: 17.08 当前epoch:0 [1792/60000 (3%)] 损失: 0.936758 训练集准确率: 44.94% 测试集正确率: 17.08 当前epoch:0 [1856/60000 (3%)] 损失: 0.991007 训练集准确率: 45.83% 测试集正确率: 17.08 当前epoch:0 [1920/60000 (3%)] 损失: 0.866241 训练集准确率: 46.93% 测试集正确率: 17.08 当前epoch:0 [1984/60000 (3%)] 损失: 0.715318 训练集准确率: 47.90% 测试集正确率: 17.08 当前epoch:0 [2048/60000 (3%)] 损失: 0.925184 训练集准确率: 48.58% 测试集正确率: 17.08 当前epoch:0 [2112/60000 (4%)] 损失: 0.614515 训练集准确率: 49.68% 测试集正确率: 17.08 当前epoch:0 [2176/60000 (4%)] 损失: 0.793619 训练集准确率: 50.45% 测试集正确率: 17.08 当前epoch:0 [2240/60000 (4%)] 损失: 0.657147 训练集准确率: 51.22% 测试集正确率: 17.08 当前epoch:0 [2304/60000 (4%)] 损失: 0.796385 训练集准确率: 51.94% 测试集正确率: 17.08 当前epoch:0 [2368/60000 (4%)] 损失: 0.791919 训练集准确率: 52.59% 测试集正确率: 17.08 当前epoch:0 [2432/60000 (4%)] 损失: 0.791459 训练集准确率: 53.25% 测试集正确率: 17.08 当前epoch:0 [2496/60000 (4%)] 损失: 0.616070 训练集准确率: 53.87% 测试集正确率: 17.08 当前epoch:0 [2560/60000 (4%)] 损失: 0.483364 训练集准确率: 54.54% 测试集正确率: 17.08 当前epoch:0 [2624/60000 (4%)] 损失: 0.889030 训练集准确率: 54.87% 测试集正确率: 17.08 当前epoch:0 [2688/60000 (4%)] 损失: 0.370191 训练集准确率: 55.70% 测试集正确率: 17.08 当前epoch:0 [2752/60000 (5%)] 损失: 0.715300 训练集准确率: 56.11% 测试集正确率: 17.08 当前epoch:0 [2816/60000 (5%)] 损失: 0.772094 训练集准确率: 56.74% 测试集正确率: 17.08 当前epoch:0 [2880/60000 (5%)] 损失: 0.573602 训练集准确率: 57.27% 测试集正确率: 17.08 当前epoch:0 [2944/60000 (5%)] 损失: 0.737286 训练集准确率: 57.65% 测试集正确率: 17.08 当前epoch:0 [3008/60000 (5%)] 损失: 0.361655 训练集准确率: 58.30% 测试集正确率: 17.08 当前epoch:0 [3072/60000 (5%)] 损失: 0.711395 训练集准确率: 58.55% 测试集正确率: 17.08 当前epoch:0 [3136/60000 (5%)] 损失: 0.549022 训练集准确率: 59.06% 测试集正确率: 17.08 当前epoch:0 [3200/60000 (5%)] 损失: 0.543209 训练集准确率: 59.65% 测试集正确率: 17.08 当前epoch:0 [3264/60000 (5%)] 损失: 0.576940 训练集准确率: 60.01% 测试集正确率: 17.08 当前epoch:0 [3328/60000 (6%)] 损失: 0.737475 训练集准确率: 60.29% 测试集正确率: 17.08 当前epoch:0 [3392/60000 (6%)] 损失: 0.613304 训练集准确率: 60.62% 测试集正确率: 17.08 当前epoch:0 [3456/60000 (6%)] 损失: 0.511794 训练集准确率: 61.05% 测试集正确率: 17.08 当前epoch:0 [3520/60000 (6%)] 损失: 0.616450 训练集准确率: 61.41% 测试集正确率: 17.08 当前epoch:0 [3584/60000 (6%)] 损失: 0.516962 训练集准确率: 61.81% 测试集正确率: 17.08 当前epoch:0 [3648/60000 (6%)] 损失: 0.346120 训练集准确率: 62.34% 测试集正确率: 17.08 当前epoch:0 [3712/60000 (6%)] 损失: 0.381096 训练集准确率: 62.76% 测试集正确率: 17.08 当前epoch:0 [3776/60000 (6%)] 损失: 0.366492 训练集准确率: 63.18% 测试集正确率: 17.08 当前epoch:0 [3840/60000 (6%)] 损失: 0.536602 训练集准确率: 63.52% 测试集正确率: 17.08 当前epoch:0 [3904/60000 (7%)] 损失: 0.686056 训练集准确率: 63.81% 测试集正确率: 17.08 当前epoch:0 [3968/60000 (7%)] 损失: 0.359961 训练集准确率: 64.29% 测试集正确率: 17.08 当前epoch:0 [4032/60000 (7%)] 损失: 0.414718 训练集准确率: 64.70% 测试集正确率: 17.08 当前epoch:0 [4096/60000 (7%)] 损失: 0.436713 训练集准确率: 65.00% 测试集正确率: 17.08 当前epoch:0 [4160/60000 (7%)] 损失: 0.764519 训练集准确率: 65.18% 测试集正确率: 17.08 当前epoch:0 [4224/60000 (7%)] 损失: 0.518178 训练集准确率: 65.51% 测试集正确率: 17.08 当前epoch:0 [4288/60000 (7%)] 损失: 0.511101 训练集准确率: 65.85% 测试集正确率: 17.08 当前epoch:0 [4352/60000 (7%)] 损失: 0.287424 训练集准确率: 66.24% 测试集正确率: 17.08 当前epoch:0 [4416/60000 (7%)] 损失: 0.493250 训练集准确率: 66.50% 测试集正确率: 17.08 当前epoch:0 [4480/60000 (7%)] 损失: 0.287984 训练集准确率: 66.81% 测试集正确率: 17.08 当前epoch:0 [4544/60000 (8%)] 损失: 0.551467 训练集准确率: 67.06% 测试集正确率: 17.08 当前epoch:0 [4608/60000 (8%)] 损失: 0.464156 训练集准确率: 67.29% 测试集正确率: 17.08 当前epoch:0 [4672/60000 (8%)] 损失: 0.320018 训练集准确率: 67.61% 测试集正确率: 17.08 当前epoch:0 [4736/60000 (8%)] 损失: 0.449212 训练集准确率: 67.92% 测试集正确率: 17.08 当前epoch:0 [4800/60000 (8%)] 损失: 0.398034 训练集准确率: 68.19% 测试集正确率: 17.08 当前epoch:0 [4864/60000 (8%)] 损失: 0.280365 训练集准确率: 68.45% 测试集正确率: 17.08 当前epoch:0 [4928/60000 (8%)] 损失: 0.272768 训练集准确率: 68.79% 测试集正确率: 17.08 当前epoch:0 [4992/60000 (8%)] 损失: 0.357351 训练集准确率: 69.07% 测试集正确率: 17.08 当前epoch:0 [5056/60000 (8%)] 损失: 0.204072 训练集准确率: 69.36% 测试集正确率: 17.08 当前epoch:0 [5120/60000 (9%)] 损失: 0.321930 训练集准确率: 69.64% 测试集正确率: 17.08 当前epoch:0 [5184/60000 (9%)] 损失: 0.449989 训练集准确率: 69.82% 测试集正确率: 17.08 当前epoch:0 [5248/60000 (9%)] 损失: 0.284762 训练集准确率: 70.12% 测试集正确率: 17.08 当前epoch:0 [5312/60000 (9%)] 损失: 0.321816 训练集准确率: 70.37% 测试集正确率: 17.08 当前epoch:0 [5376/60000 (9%)] 损失: 0.400219 训练集准确率: 70.57% 测试集正确率: 17.08 当前epoch:0 [5440/60000 (9%)] 损失: 0.384535 训练集准确率: 70.75% 测试集正确率: 17.08 当前epoch:0 [5504/60000 (9%)] 损失: 0.302196 训练集准确率: 70.96% 测试集正确率: 17.08 当前epoch:0 [5568/60000 (9%)] 损失: 0.340139 训练集准确率: 71.22% 测试集正确率: 17.08 当前epoch:0 [5632/60000 (9%)] 损失: 0.206504 训练集准确率: 71.45% 测试集正确率: 17.08 当前epoch:0 [5696/60000 (9%)] 损失: 0.229951 训练集准确率: 71.70% 测试集正确率: 17.08 当前epoch:0 [5760/60000 (10%)] 损失: 0.457961 训练集准确率: 71.86% 测试集正确率: 17.08 当前epoch:0 [5824/60000 (10%)] 损失: 0.463543 训练集准确率: 72.06% 测试集正确率: 17.08 当前epoch:0 [5888/60000 (10%)] 损失: 0.600046 训练集准确率: 72.21% 测试集正确率: 17.08 当前epoch:0 [5952/60000 (10%)] 损失: 0.264341 训练集准确率: 72.41% 测试集正确率: 17.08 当前epoch:0 [6016/60000 (10%)] 损失: 0.421244 训练集准确率: 72.60% 测试集正确率: 17.08 当前epoch:0 [6080/60000 (10%)] 损失: 0.233128 训练集准确率: 72.79% 测试集正确率: 17.08 当前epoch:0 [6144/60000 (10%)] 损失: 0.542871 训练集准确率: 72.94% 测试集正确率: 17.08 当前epoch:0 [6208/60000 (10%)] 损失: 0.357706 训练集准确率: 73.12% 测试集正确率: 17.08 当前epoch:0 [6272/60000 (10%)] 损失: 0.237382 训练集准确率: 73.30% 测试集正确率: 17.08 当前epoch:0 [6336/60000 (11%)] 损失: 0.307530 训练集准确率: 73.50% 测试集正确率: 17.08 当前epoch:0 [6400/60000 (11%)] 损失: 0.217164 训练集准确率: 73.72% 测试集正确率: 91.22 当前epoch:0 [6464/60000 (11%)] 损失: 0.178109 训练集准确率: 73.91% 测试集正确率: 91.22 当前epoch:0 [6528/60000 (11%)] 损失: 0.293645 训练集准确率: 74.09% 测试集正确率: 91.22 当前epoch:0 [6592/60000 (11%)] 损失: 0.234155 训练集准确率: 74.26% 测试集正确率: 91.22 当前epoch:0 [6656/60000 (11%)] 损失: 0.300352 训练集准确率: 74.42% 测试集正确率: 91.22 当前epoch:0 [6720/60000 (11%)] 损失: 0.200646 训练集准确率: 74.60% 测试集正确率: 91.22 当前epoch:0 [6784/60000 (11%)] 损失: 0.323289 训练集准确率: 74.74% 测试集正确率: 91.22 当前epoch:0 [6848/60000 (11%)] 损失: 0.174315 训练集准确率: 74.93% 测试集正确率: 91.22 当前epoch:0 [6912/60000 (12%)] 损失: 0.322517 训练集准确率: 75.09% 测试集正确率: 91.22 当前epoch:0 [6976/60000 (12%)] 损失: 0.308520 训练集准确率: 75.26% 测试集正确率: 91.22 当前epoch:0 [7040/60000 (12%)] 损失: 0.116981 训练集准确率: 75.46% 测试集正确率: 91.22 当前epoch:0 [7104/60000 (12%)] 损失: 0.260045 训练集准确率: 75.60% 测试集正确率: 91.22 当前epoch:0 [7168/60000 (12%)] 损失: 0.237228 训练集准确率: 75.75% 测试集正确率: 91.22 当前epoch:0 [7232/60000 (12%)] 损失: 0.311560 训练集准确率: 75.88% 测试集正确率: 91.22 当前epoch:0 [7296/60000 (12%)] 损失: 0.273539 训练集准确率: 76.03% 测试集正确率: 91.22 当前epoch:0 [7360/60000 (12%)] 损失: 0.200470 训练集准确率: 76.19% 测试集正确率: 91.22 ..........当前epoch:2 [59456/60000 (99%)] 损失: 0.075698 训练集准确率: 98.66% 测试集正确率: 98.90 当前epoch:2 [59520/60000 (99%)] 损失: 0.018424 训练集准确率: 98.66% 测试集正确率: 98.90 当前epoch:2 [59584/60000 (99%)] 损失: 0.034391 训练集准确率: 98.66% 测试集正确率: 98.90当前epoch:2 [59648/60000 (99%)] 损失: 0.023720 训练集准确率: 98.66% 测试集正确率: 98.90 当前epoch:2 [59712/60000 (99%)] 损失: 0.089731 训练集准确率: 98.66% 测试集正确率: 98.90 当前epoch:2 [59776/60000 (100%)] 损失: 0.034569 训练集准确率: 98.66% 测试集正确率: 98.90 当前epoch:2 [59840/60000 (100%)] 损失: 0.008449 训练集准确率: 98.66% 测试集正确率: 98.90 当前epoch:2 [59904/60000 (100%)] 损失: 0.078446 训练集准确率: 98.66% 测试集正确率: 98.90 当前epoch:2 [59968/60000 (100%)] 损失: 0.016576 训练集准确率: 98.66% 测试集正确率: 98.90