卷积神经网络效果

构建卷积神经网络

卷积神经网络的输入层与传统神经网络有些区别,需要重新设计,训练模块基本一致

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets,transforms
import matplotlib.pyplot as plt
import numpy as np
%matplotlib inline

首先读取数据

分别构建训练集和测试集(验证集) DataLoader来迭代数据

#定义超参数
input_size = 28  #图像的总尺寸28*28
num_classes = 10  #标签的种类数
num_epochs = 3   #训练的总循环周期
batch_size = 64  #一个批次的大小,64张照片

#训练集
train_dataset = datasets.MNIST(root='./data',
                              train=True,
                              transform=transforms.ToTensor(),
                              download=True)

#测试集
test_dataset = datasets.MNIST(root='./data',
                              train=False,
                              transform=transforms.ToTensor(),)

#构建batch数据
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                       batch_size=batch_size,
                                       shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                         batch_size=batch_size,
                                         shuffle=True)

卷积网络模块构建

一般卷积层,relu层,池化层可以写成一个套餐 (conv+relu)+pool

注意卷积最后结果还是一个特征图,需要把图转换成向量才能作分类或者回归任务

class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Sequential(       #输入大小(1,28,28)
            nn.Conv2d(
                in_channels=1,           #灰度图
                out_channels=16,         #要得到多少个特征图
                kernel_size=5,           #卷积核大小
                stride=1,                #步长
                padding=2,               #如果希望卷积后大小和原来一样,需要设置padding=(kernel_size-1)/2 if stride=1
            ),                           #输出的特征图为(16,28,28)
            nn.ReLU(),                   #relu层
            nn.MaxPool2d(kernel_size=2),#进行池化操作(2*2区域),输出结果为:(16,14,14)
        )
        self.conv2 = nn.Sequential(      #下一个套餐的输入 (16,14,14)
        nn.Conv2d(16, 32, 5, 1, 2),       #输出(32,14,14)
        nn.ReLU(),                       #relu层
        nn.MaxPool2d(2),                 #输出(32,7,7)
        )
        self.out = nn.Linear(32 * 7 * 7, 10)  #全连接层得到的结果
        
    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.view(x.size(0), -1)         #faltten操作,结果为:(batch_size, 32*7*7)
        output = self.out(x)
        return output

准确率作为评估标准

def accuracy(predictions,lables):
    pred = torch.max(predictions.data, 1)[1]
    rights = pred.eq(lables.data.view_as(pred)).sum()
    return rights, len(lables)

训练模型

#实例化
net = CNN()
#损失函数
criterion = nn.CrossEntropyLoss()
#优化器
optimizer = optim.Adam(net.parameters(), lr=0.001)             #定义优化器,普通的随机梯度下降算法

#开始训练循环
for epoch in range(num_epochs):
    #当前epoch的结果保存下来
    train_rights = []
    
    for batch_idx, (data, target) in enumerate(train_loader):   #针对容器中的每一个批进行循环  
        net.train()
        output = net(data)
        loss = criterion(output, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        right = accuracy(output, target)
        train_rights.append(right)
        
        if batch_idx % 100 == 0:
            
            net.eval()
            val_rights = []
            
            for (data, target) in test_loader:
                output = net(data)
                right = accuracy(output, target)
                val_rights.append(right)
                
        #准确率计算
        train_r = (sum([tup[0] for tup in train_rights]),sum([tup[1] for tup in train_rights]))
        val_r = (sum([tup[0] for tup in val_rights]), sum([tup[1] for tup in val_rights]))
        
        print('当前epoch:{} [{}/{} ({:.0f}%)]\t损失: {:.6f}\t训练集准确率: {:.2f}%\t测试集正确率: {:.2f}'.format(
             epoch, batch_idx * batch_size, len(train_loader.dataset),
             100. * batch_idx / len(train_loader),
             loss.data,
             100. * train_r[0].numpy() / train_r[1],
             100. * val_r[0].numpy() / val_r[1]))
当前epoch:0 [0/60000 (0%)]	损失: 2.313051	训练集准确率: 7.81%	测试集正确率: 17.08
当前epoch:0 [64/60000 (0%)]	损失: 2.292140	训练集准确率: 10.94%	测试集正确率: 17.08
当前epoch:0 [128/60000 (0%)]	损失: 2.293331	训练集准确率: 10.94%	测试集正确率: 17.08
当前epoch:0 [192/60000 (0%)]	损失: 2.244610	训练集准确率: 16.80%	测试集正确率: 17.08
当前epoch:0 [256/60000 (0%)]	损失: 2.218900	训练集准确率: 18.75%	测试集正确率: 17.08
当前epoch:0 [320/60000 (1%)]	损失: 2.195498	训练集准确率: 21.88%	测试集正确率: 17.08
当前epoch:0 [384/60000 (1%)]	损失: 2.206018	训练集准确率: 21.65%	测试集正确率: 17.08
当前epoch:0 [448/60000 (1%)]	损失: 2.180433	训练集准确率: 22.46%	测试集正确率: 17.08
当前epoch:0 [512/60000 (1%)]	损失: 2.141994	训练集准确率: 23.26%	测试集正确率: 17.08
当前epoch:0 [576/60000 (1%)]	损失: 2.042990	训练集准确率: 25.47%	测试集正确率: 17.08
当前epoch:0 [640/60000 (1%)]	损失: 2.048994	训练集准确率: 25.71%	测试集正确率: 17.08
当前epoch:0 [704/60000 (1%)]	损失: 1.977731	训练集准确率: 25.26%	测试集正确率: 17.08
当前epoch:0 [768/60000 (1%)]	损失: 2.034307	训练集准确率: 24.64%	测试集正确率: 17.08
当前epoch:0 [832/60000 (1%)]	损失: 1.966112	训练集准确率: 24.11%	测试集正确率: 17.08
当前epoch:0 [896/60000 (1%)]	损失: 1.881626	训练集准确率: 24.90%	测试集正确率: 17.08
当前epoch:0 [960/60000 (2%)]	损失: 1.844498	训练集准确率: 26.17%	测试集正确率: 17.08
当前epoch:0 [1024/60000 (2%)]	损失: 1.773689	训练集准确率: 28.95%	测试集正确率: 17.08
当前epoch:0 [1088/60000 (2%)]	损失: 1.686953	训练集准确率: 31.34%	测试集正确率: 17.08
当前epoch:0 [1152/60000 (2%)]	损失: 1.692229	训练集准确率: 32.73%	测试集正确率: 17.08
当前epoch:0 [1216/60000 (2%)]	损失: 1.557774	训练集准确率: 34.06%	测试集正确率: 17.08
当前epoch:0 [1280/60000 (2%)]	损失: 1.487893	训练集准确率: 35.57%	测试集正确率: 17.08
当前epoch:0 [1344/60000 (2%)]	损失: 1.601223	训练集准确率: 36.29%	测试集正确率: 17.08
当前epoch:0 [1408/60000 (2%)]	损失: 1.326457	训练集准确率: 37.50%	测试集正确率: 17.08
当前epoch:0 [1472/60000 (2%)]	损失: 1.236989	训练集准确率: 38.80%	测试集正确率: 17.08
当前epoch:0 [1536/60000 (3%)]	损失: 1.168241	训练集准确率: 40.12%	测试集正确率: 17.08
当前epoch:0 [1600/60000 (3%)]	损失: 1.159773	训练集准确率: 41.23%	测试集正确率: 17.08
当前epoch:0 [1664/60000 (3%)]	损失: 1.126725	训练集准确率: 42.48%	测试集正确率: 17.08
当前epoch:0 [1728/60000 (3%)]	损失: 0.897731	训练集准确率: 43.97%	测试集正确率: 17.08
当前epoch:0 [1792/60000 (3%)]	损失: 0.936758	训练集准确率: 44.94%	测试集正确率: 17.08
当前epoch:0 [1856/60000 (3%)]	损失: 0.991007	训练集准确率: 45.83%	测试集正确率: 17.08
当前epoch:0 [1920/60000 (3%)]	损失: 0.866241	训练集准确率: 46.93%	测试集正确率: 17.08
当前epoch:0 [1984/60000 (3%)]	损失: 0.715318	训练集准确率: 47.90%	测试集正确率: 17.08
当前epoch:0 [2048/60000 (3%)]	损失: 0.925184	训练集准确率: 48.58%	测试集正确率: 17.08
当前epoch:0 [2112/60000 (4%)]	损失: 0.614515	训练集准确率: 49.68%	测试集正确率: 17.08
当前epoch:0 [2176/60000 (4%)]	损失: 0.793619	训练集准确率: 50.45%	测试集正确率: 17.08
当前epoch:0 [2240/60000 (4%)]	损失: 0.657147	训练集准确率: 51.22%	测试集正确率: 17.08
当前epoch:0 [2304/60000 (4%)]	损失: 0.796385	训练集准确率: 51.94%	测试集正确率: 17.08
当前epoch:0 [2368/60000 (4%)]	损失: 0.791919	训练集准确率: 52.59%	测试集正确率: 17.08
当前epoch:0 [2432/60000 (4%)]	损失: 0.791459	训练集准确率: 53.25%	测试集正确率: 17.08
当前epoch:0 [2496/60000 (4%)]	损失: 0.616070	训练集准确率: 53.87%	测试集正确率: 17.08
当前epoch:0 [2560/60000 (4%)]	损失: 0.483364	训练集准确率: 54.54%	测试集正确率: 17.08
当前epoch:0 [2624/60000 (4%)]	损失: 0.889030	训练集准确率: 54.87%	测试集正确率: 17.08
当前epoch:0 [2688/60000 (4%)]	损失: 0.370191	训练集准确率: 55.70%	测试集正确率: 17.08
当前epoch:0 [2752/60000 (5%)]	损失: 0.715300	训练集准确率: 56.11%	测试集正确率: 17.08
当前epoch:0 [2816/60000 (5%)]	损失: 0.772094	训练集准确率: 56.74%	测试集正确率: 17.08
当前epoch:0 [2880/60000 (5%)]	损失: 0.573602	训练集准确率: 57.27%	测试集正确率: 17.08
当前epoch:0 [2944/60000 (5%)]	损失: 0.737286	训练集准确率: 57.65%	测试集正确率: 17.08
当前epoch:0 [3008/60000 (5%)]	损失: 0.361655	训练集准确率: 58.30%	测试集正确率: 17.08
当前epoch:0 [3072/60000 (5%)]	损失: 0.711395	训练集准确率: 58.55%	测试集正确率: 17.08
当前epoch:0 [3136/60000 (5%)]	损失: 0.549022	训练集准确率: 59.06%	测试集正确率: 17.08
当前epoch:0 [3200/60000 (5%)]	损失: 0.543209	训练集准确率: 59.65%	测试集正确率: 17.08
当前epoch:0 [3264/60000 (5%)]	损失: 0.576940	训练集准确率: 60.01%	测试集正确率: 17.08
当前epoch:0 [3328/60000 (6%)]	损失: 0.737475	训练集准确率: 60.29%	测试集正确率: 17.08
当前epoch:0 [3392/60000 (6%)]	损失: 0.613304	训练集准确率: 60.62%	测试集正确率: 17.08
当前epoch:0 [3456/60000 (6%)]	损失: 0.511794	训练集准确率: 61.05%	测试集正确率: 17.08
当前epoch:0 [3520/60000 (6%)]	损失: 0.616450	训练集准确率: 61.41%	测试集正确率: 17.08
当前epoch:0 [3584/60000 (6%)]	损失: 0.516962	训练集准确率: 61.81%	测试集正确率: 17.08
当前epoch:0 [3648/60000 (6%)]	损失: 0.346120	训练集准确率: 62.34%	测试集正确率: 17.08
当前epoch:0 [3712/60000 (6%)]	损失: 0.381096	训练集准确率: 62.76%	测试集正确率: 17.08
当前epoch:0 [3776/60000 (6%)]	损失: 0.366492	训练集准确率: 63.18%	测试集正确率: 17.08
当前epoch:0 [3840/60000 (6%)]	损失: 0.536602	训练集准确率: 63.52%	测试集正确率: 17.08
当前epoch:0 [3904/60000 (7%)]	损失: 0.686056	训练集准确率: 63.81%	测试集正确率: 17.08
当前epoch:0 [3968/60000 (7%)]	损失: 0.359961	训练集准确率: 64.29%	测试集正确率: 17.08
当前epoch:0 [4032/60000 (7%)]	损失: 0.414718	训练集准确率: 64.70%	测试集正确率: 17.08
当前epoch:0 [4096/60000 (7%)]	损失: 0.436713	训练集准确率: 65.00%	测试集正确率: 17.08
当前epoch:0 [4160/60000 (7%)]	损失: 0.764519	训练集准确率: 65.18%	测试集正确率: 17.08
当前epoch:0 [4224/60000 (7%)]	损失: 0.518178	训练集准确率: 65.51%	测试集正确率: 17.08
当前epoch:0 [4288/60000 (7%)]	损失: 0.511101	训练集准确率: 65.85%	测试集正确率: 17.08
当前epoch:0 [4352/60000 (7%)]	损失: 0.287424	训练集准确率: 66.24%	测试集正确率: 17.08
当前epoch:0 [4416/60000 (7%)]	损失: 0.493250	训练集准确率: 66.50%	测试集正确率: 17.08
当前epoch:0 [4480/60000 (7%)]	损失: 0.287984	训练集准确率: 66.81%	测试集正确率: 17.08
当前epoch:0 [4544/60000 (8%)]	损失: 0.551467	训练集准确率: 67.06%	测试集正确率: 17.08
当前epoch:0 [4608/60000 (8%)]	损失: 0.464156	训练集准确率: 67.29%	测试集正确率: 17.08
当前epoch:0 [4672/60000 (8%)]	损失: 0.320018	训练集准确率: 67.61%	测试集正确率: 17.08
当前epoch:0 [4736/60000 (8%)]	损失: 0.449212	训练集准确率: 67.92%	测试集正确率: 17.08
当前epoch:0 [4800/60000 (8%)]	损失: 0.398034	训练集准确率: 68.19%	测试集正确率: 17.08
当前epoch:0 [4864/60000 (8%)]	损失: 0.280365	训练集准确率: 68.45%	测试集正确率: 17.08
当前epoch:0 [4928/60000 (8%)]	损失: 0.272768	训练集准确率: 68.79%	测试集正确率: 17.08
当前epoch:0 [4992/60000 (8%)]	损失: 0.357351	训练集准确率: 69.07%	测试集正确率: 17.08
当前epoch:0 [5056/60000 (8%)]	损失: 0.204072	训练集准确率: 69.36%	测试集正确率: 17.08
当前epoch:0 [5120/60000 (9%)]	损失: 0.321930	训练集准确率: 69.64%	测试集正确率: 17.08
当前epoch:0 [5184/60000 (9%)]	损失: 0.449989	训练集准确率: 69.82%	测试集正确率: 17.08
当前epoch:0 [5248/60000 (9%)]	损失: 0.284762	训练集准确率: 70.12%	测试集正确率: 17.08
当前epoch:0 [5312/60000 (9%)]	损失: 0.321816	训练集准确率: 70.37%	测试集正确率: 17.08
当前epoch:0 [5376/60000 (9%)]	损失: 0.400219	训练集准确率: 70.57%	测试集正确率: 17.08
当前epoch:0 [5440/60000 (9%)]	损失: 0.384535	训练集准确率: 70.75%	测试集正确率: 17.08
当前epoch:0 [5504/60000 (9%)]	损失: 0.302196	训练集准确率: 70.96%	测试集正确率: 17.08
当前epoch:0 [5568/60000 (9%)]	损失: 0.340139	训练集准确率: 71.22%	测试集正确率: 17.08
当前epoch:0 [5632/60000 (9%)]	损失: 0.206504	训练集准确率: 71.45%	测试集正确率: 17.08
当前epoch:0 [5696/60000 (9%)]	损失: 0.229951	训练集准确率: 71.70%	测试集正确率: 17.08
当前epoch:0 [5760/60000 (10%)]	损失: 0.457961	训练集准确率: 71.86%	测试集正确率: 17.08
当前epoch:0 [5824/60000 (10%)]	损失: 0.463543	训练集准确率: 72.06%	测试集正确率: 17.08
当前epoch:0 [5888/60000 (10%)]	损失: 0.600046	训练集准确率: 72.21%	测试集正确率: 17.08
当前epoch:0 [5952/60000 (10%)]	损失: 0.264341	训练集准确率: 72.41%	测试集正确率: 17.08
当前epoch:0 [6016/60000 (10%)]	损失: 0.421244	训练集准确率: 72.60%	测试集正确率: 17.08
当前epoch:0 [6080/60000 (10%)]	损失: 0.233128	训练集准确率: 72.79%	测试集正确率: 17.08
当前epoch:0 [6144/60000 (10%)]	损失: 0.542871	训练集准确率: 72.94%	测试集正确率: 17.08
当前epoch:0 [6208/60000 (10%)]	损失: 0.357706	训练集准确率: 73.12%	测试集正确率: 17.08
当前epoch:0 [6272/60000 (10%)]	损失: 0.237382	训练集准确率: 73.30%	测试集正确率: 17.08
当前epoch:0 [6336/60000 (11%)]	损失: 0.307530	训练集准确率: 73.50%	测试集正确率: 17.08
当前epoch:0 [6400/60000 (11%)]	损失: 0.217164	训练集准确率: 73.72%	测试集正确率: 91.22
当前epoch:0 [6464/60000 (11%)]	损失: 0.178109	训练集准确率: 73.91%	测试集正确率: 91.22
当前epoch:0 [6528/60000 (11%)]	损失: 0.293645	训练集准确率: 74.09%	测试集正确率: 91.22
当前epoch:0 [6592/60000 (11%)]	损失: 0.234155	训练集准确率: 74.26%	测试集正确率: 91.22
当前epoch:0 [6656/60000 (11%)]	损失: 0.300352	训练集准确率: 74.42%	测试集正确率: 91.22
当前epoch:0 [6720/60000 (11%)]	损失: 0.200646	训练集准确率: 74.60%	测试集正确率: 91.22
当前epoch:0 [6784/60000 (11%)]	损失: 0.323289	训练集准确率: 74.74%	测试集正确率: 91.22
当前epoch:0 [6848/60000 (11%)]	损失: 0.174315	训练集准确率: 74.93%	测试集正确率: 91.22
当前epoch:0 [6912/60000 (12%)]	损失: 0.322517	训练集准确率: 75.09%	测试集正确率: 91.22
当前epoch:0 [6976/60000 (12%)]	损失: 0.308520	训练集准确率: 75.26%	测试集正确率: 91.22
当前epoch:0 [7040/60000 (12%)]	损失: 0.116981	训练集准确率: 75.46%	测试集正确率: 91.22
当前epoch:0 [7104/60000 (12%)]	损失: 0.260045	训练集准确率: 75.60%	测试集正确率: 91.22
当前epoch:0 [7168/60000 (12%)]	损失: 0.237228	训练集准确率: 75.75%	测试集正确率: 91.22
当前epoch:0 [7232/60000 (12%)]	损失: 0.311560	训练集准确率: 75.88%	测试集正确率: 91.22
当前epoch:0 [7296/60000 (12%)]	损失: 0.273539	训练集准确率: 76.03%	测试集正确率: 91.22
当前epoch:0 [7360/60000 (12%)]	损失: 0.200470	训练集准确率: 76.19%	测试集正确率: 91.22
..........
当前epoch:2 [59456/60000 (99%)]	损失: 0.075698	训练集准确率: 98.66%	测试集正确率: 98.90
当前epoch:2 [59520/60000 (99%)]	损失: 0.018424	训练集准确率: 98.66%	测试集正确率: 98.90
当前epoch:2 [59584/60000 (99%)]	损失: 0.034391	训练集准确率: 98.66%	测试集正确率: 98.90
当前epoch:2 [59648/60000 (99%)]	损失: 0.023720	训练集准确率: 98.66%	测试集正确率: 98.90
当前epoch:2 [59712/60000 (99%)]	损失: 0.089731	训练集准确率: 98.66%	测试集正确率: 98.90
当前epoch:2 [59776/60000 (100%)]	损失: 0.034569	训练集准确率: 98.66%	测试集正确率: 98.90
当前epoch:2 [59840/60000 (100%)]	损失: 0.008449	训练集准确率: 98.66%	测试集正确率: 98.90
当前epoch:2 [59904/60000 (100%)]	损失: 0.078446	训练集准确率: 98.66%	测试集正确率: 98.90
当前epoch:2 [59968/60000 (100%)]	损失: 0.016576	训练集准确率: 98.66%	测试集正确率: 98.90

 

  • 0
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

荔枝味啊~

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值