LeNet -- the pioneering convolutional neural network

A schematic of the LeNet network structure is shown below:

Input → Convolution → Pooling → Convolution → Pooling → Fully connected → Fully connected → Fully connected (output)
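
The spatial size after each convolution follows output = (W - K + 2P) / S + 1, where W is the input height/width, K the kernel size, P the padding and S the stride; each 2×2 max-pooling with stride 2 halves the size. A quick check for the 32×32 RGB input used in model.py below:

conv1 (K=5, P=0, S=1): (32 - 5)/1 + 1 = 28  →  16 × 28 × 28
pool1 (2×2, stride 2): 28 / 2 = 14          →  16 × 14 × 14
conv2 (K=5, P=0, S=1): (14 - 5)/1 + 1 = 10  →  32 × 10 × 10
pool2 (2×2, stride 2): 10 / 2 = 5           →  32 × 5 × 5 = 800 features, the input size of fc1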

model.py:

import torch.nn as nn
import torch.nn.functional as F

# input (3, 32, 32); PyTorch defaults: stride=1, padding=0
class LeNet(nn.Module):
    def __init__(self, num_classes=10, init_weights=True):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=5, stride=1, padding=0) # 16 * 28 * 28
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2) # 16 * 14 * 14
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5) # 32 * 10 * 10
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2) # 32 * 5 * 5
        self.fc1 = nn.Linear(32*5*5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)  # output size follows num_classes
        if init_weights:
            self._initialize_weights()

    def _initialize_weights(self):
        # Kaiming-normal initialization for all conv and linear layers
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.kaiming_normal_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = F.relu(self.conv1(x))    # input(3, 32, 32) output(16, 28, 28)
        x = self.pool1(x)            # output(16, 14, 14)
        x = F.relu(self.conv2(x))    # output(32, 10, 10)
        x = self.pool2(x)            # output(32, 5, 5)
        x = x.view(-1, 32*5*5)       # output(32*5*5); note: this size must match the image size the training transforms resize to
        x = F.relu(self.fc1(x))      # output(120)
        x = F.relu(self.fc2(x))      # output(84)
        x = self.fc3(x)              # output(10)
        return x
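
As a quick sanity check (not part of the original script), a dummy batch can be pushed through the network to confirm the shapes annotated above; the batch size of 4 is arbitrary:

import torch

from model import LeNet

net = LeNet(num_classes=10)
dummy = torch.randn(4, 3, 32, 32)  # a fake batch of 4 RGB 32x32 images
out = net(dummy)
print(out.shape)                   # expected: torch.Size([4, 10])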


train.py:

import json
import sys

import torch
import torchvision
import torch.nn as nn
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder

from model import LeNet
import torch.optim as optim
import torchvision.transforms as transforms
from tqdm import tqdm

import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'

# make matplotlib display Chinese characters correctly
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False


ROOT_TRAIN = r'E:/cnn/AlexNet/data/train'
ROOT_TEST = r'E:/cnn/AlexNet/data/val'


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(32),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
        "val": transforms.Compose([transforms.Resize((32, 32)),  # cannot 224, must (224, 224)
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}  # 数据预处理

    train_dataset = ImageFolder(ROOT_TRAIN, transform=data_transform["train"]) # load the training set
    train_num = len(train_dataset) # number of images in the training set
    animal_list = train_dataset.class_to_idx # class names and their corresponding indices
    cla_dict = dict((val, key) for key, val in animal_list.items()) # invert the mapping to index -> class name
    json_str = json.dumps(cla_dict, indent=4)  # write the classes and their indices to class_indices.json in the project root
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    batch_size = 32
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size, shuffle=True,
                                               num_workers=0)

    validate_dataset = ImageFolder(ROOT_TEST, transform=data_transform["val"]) # load the validation set
    val_num = len(validate_dataset) # number of images in the validation set
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=16, shuffle=False,
                                                  num_workers=0)

    print("using {} images for training, {} images for validation.".format(train_num, val_num))  # 用于打印总的训练集数量和验证集数量

    # to inspect the dataset, set batch_size of validate_loader above to how many images you want to view at once, and shuffle=True to shuffle the order
    # test_data_iter = iter(validate_loader)
    # test_image, test_label = next(test_data_iter)
    #
    # def imshow(img):
    #     img = img / 2 + 0.5  # unnormalize
    #     npimg = img.numpy()
    #     plt.imshow(np.transpose(npimg, (1, 2, 0)))
    #     plt.show()
    #
    # print(' '.join('%5s' % cla_dict[test_label[j].item()] for j in range(4)))
    # imshow(utils.make_grid(test_image))


    net = LeNet(num_classes=10, init_weights=True) # instantiate the network; num_classes is the number of classes

    net.to(device) # move the network to the GPU or CPU
    loss_function = nn.CrossEntropyLoss()
    # pata = list(net.parameters())
    optimizer = optim.Adam(net.parameters(), lr=0.0002)

    epochs = 10
    save_path = './LeNet.pth'
    best_acc = 0.0
    train_steps = len(train_loader)
    for epoch in range(epochs):
        # train
        net.train()
        running_loss = 0.0
        train_bar = tqdm(train_loader, file=sys.stdout)
        for step, data in enumerate(train_bar):
            images, labels = data
            optimizer.zero_grad()
            outputs = net(images.to(device))
            loss = loss_function(outputs, labels.to(device))
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()

            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
                                                                     epochs,
                                                                     loss)

        # validate
        net.eval()
        acc = 0.0  # accumulated number of correct predictions in this epoch
        with torch.no_grad():
            val_bar = tqdm(validate_loader, file=sys.stdout)
            for val_data in val_bar: # iterate over the validation set
                val_images, val_labels = val_data # split the batch into images and labels
                outputs = net(val_images.to(device)) # move the images to the device and run a forward pass
                predict_y = torch.max(outputs, dim=1)[1] # index of the most likely class (largest output value)
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item() # compare predictions with the true labels and accumulate the number of correct ones

        val_accurate = acc / val_num # correct predictions / number of validation images
        print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
              (epoch + 1, running_loss / train_steps, val_accurate))

        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)

    print('Finished Training')


if __name__ == '__main__':
    main()
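
For reference, with a directory layout such as data/train/cat/... and data/train/dog/... (the folder names here are only an assumption), ImageFolder assigns indices alphabetically, so the inverted dictionary written to class_indices.json would look like:

{
    "0": "cat",
    "1": "dog"
}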

predict.py:

import torch
import torchvision.transforms as transforms
from PIL import Image

from model import LeNet

def main():
    transform = transforms.Compose(
        [transforms.Resize((32, 32)),
         transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    classes = ('Cat', 'Dog') # this example is a binary (cat vs. dog) classification problem

    net = LeNet()
    net.load_state_dict(torch.load('LeNet.pth'))  # same file name as save_path in train.py

    im = Image.open('1.jpg')
    im = transform(im)  # [C, H, W]
    im = torch.unsqueeze(im, dim=0)  # [N, C, H, W]

    with torch.no_grad():
        outputs = net(im)
        predict = torch.max(outputs, dim=1)[1].numpy()
    print(classes[int(predict)])


if __name__ == '__main__':
    main()
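
The class names above are hard-coded; since train.py already writes the index-to-name mapping to class_indices.json, a small variant of the script (a sketch, assuming the JSON file and LeNet.pth sit next to the script) can read the names back and also report a softmax probability for the prediction:

import json

import torch
import torchvision.transforms as transforms
from PIL import Image

from model import LeNet

def main():
    transform = transforms.Compose(
        [transforms.Resize((32, 32)),
         transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    with open('class_indices.json', 'r') as f:
        cla_dict = json.load(f)  # e.g. {"0": "cat", "1": "dog"}; JSON keys are strings

    net = LeNet()
    net.load_state_dict(torch.load('LeNet.pth'))
    net.eval()

    im = transform(Image.open('1.jpg'))  # [C, H, W]
    im = torch.unsqueeze(im, dim=0)      # [N, C, H, W]

    with torch.no_grad():
        output = torch.squeeze(net(im))       # remove the batch dimension -> [num_classes]
        probs = torch.softmax(output, dim=0)  # turn raw scores into probabilities
        idx = torch.argmax(probs).item()
    print(cla_dict[str(idx)], float(probs[idx]))

if __name__ == '__main__':
    main()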
