Pytorch 实现LeNet卷积神经网络

Pytorch 实现LeNet卷积神经网络

github:https://github.com/QzGithub0617/Pytorch-LeNet.git

1、请先执行train.py,执行前请将train.py文件内的第18行、第23行的download都设置位True,下载训练集数据。训练数据下载结束后,请将两处的download设置位false

2、train.py运行结束后,运行pre.py,测试图片可根据自己需求修改

LeNet.py:

import torch.nn as nn
import torch.nn.functional as F

class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, 5)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(16, 32, 5)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(32*5*5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool1(x)
        x = F.relu(self.conv2(x))
        x = self.pool2(x)
        x = x.view(-1, 32*5*5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

train.py:

import torch
import torchvision
import torch.utils
import torch.nn as nn
from LeNet import LeNet
import torch.optim as optim
import torchvision.transforms as transforms
import matplotlib .pyplot as plt
import numpy as np

if __name__ == '__main__':

    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                            download=False, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=36,
                                              shuffle=True, num_workers=0)

    testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                           download=False, transform=transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=10000,
                                             shuffle=False, num_workers=0)
    test_data_iter = iter(testloader)
    test_image, test_label = test_data_iter.next()

    classes = ('plane', 'car', 'bird', 'cat',
               'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

    # def imshow(img):
    #     img = img / 2 + 0.5     # unnormalize
    #     npimg = img.numpy()
    #     plt.imshow(np.transpose(npimg, (1, 2, 0)))
    #     plt.show()
    #
    # # print labels
    # print(' '.join(f'{classes[test_label[j]]:5s}' for j in range(4)))
    # # show images
    # imshow(torchvision.utils.make_grid(test_image))
    net = LeNet()
    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=0.001)

    for epoch in range(5):
        running_loss = 0.0 # 损失累加
        for step, data in enumerate(trainloader, start=0):
            inputs, labels = data
            optimizer.zero_grad()

            outputs = net(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            if step % 500 == 499:
                with torch.no_grad():
                    outputs = net(test_image)
                    pre_y = torch.max(outputs, dim=1)[1] # 找最大的index,[1]表示只需要index,不需要知道值是多少
                    accuracy = (pre_y == test_label).sum().item() / test_label.size(0)
                    print('[%d, %5d], train_loss: %0.3f test_accuracy: %0.3f' %
                          (epoch + 1, step+1, running_loss / 500, accuracy))
                    running_loss = 0.0

    print('finished traning')
    save_path = './LeNet.pth'
    torch.save(net.state_dict(), save_path)

pre.py:

import torch
import torchvision.transforms as transforms
from PIL import Image
from LeNet import LeNet

transform = transforms.Compose(
    [transforms.Resize((32, 32)),
     transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

net = LeNet()
net.load_state_dict(torch.load('LeNet.pth'))
im = Image.open('1.jpg')
im = transform(im)
im = torch.unsqueeze(im, dim=0)

with torch.no_grad():
    outputs = net(im)
    pre = torch.max(outputs, dim=1)[1]
print(classes[int(pre)])

测试图片 1.png(来源于网络)

在这里插入图片描述

参考资料:https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html#sphx-glr-beginner-blitz-cifar10-tutorial-py

  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值