Building the AlexNet Network Model

This post shows how to build the AlexNet network in PyTorch and apply it to an image classification task. It first lays out the AlexNet architecture, an 8-layer convolutional neural network, then covers the preprocessing used for the training and validation data, and finally walks through a complete training script, including loss computation, backpropagation, and optimization. A prediction script that loads the trained model and runs inference is also provided.

The AlexNet architecture is shown below: an 8-layer convolutional neural network in which the first 5 layers are convolutional and the last 3 are fully connected.

In the figure above, the authors split the computation across two GPUs, which is why the diagram is divided into an upper and a lower half. We build the single-GPU version here, so only one half of the structure diagram needs to be followed.
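As a quick sanity check for the shape comments in the code below, each convolution's output size follows the standard formula output = floor((input + 2*padding - kernel_size) / stride) + 1. For the first layer, for example, floor((224 + 2*2 - 11) / 4) + 1 = 55, which matches the 48 * 55 * 55 annotation.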

net.py:

import torch
from torch import nn
import torch.nn.functional as F

# 3 * 224 * 224
class MyAlexNet(nn.Module):
    def __init__(self, num_classes, init_weights=False):
        super(MyAlexNet, self).__init__()
        self.c1 = nn.Conv2d(in_channels=3, out_channels=48, kernel_size=11, stride=4, padding=2)  # 48 * 55 * 55
        self.ReLU = nn.ReLU()  # a single ReLU module is enough; it has no parameters and can be reused
        self.s2 = nn.MaxPool2d(2)  # 48 * 27 * 27
        self.c2 = nn.Conv2d(in_channels=48, out_channels=128, kernel_size=5, stride=1, padding=2)  # 128 * 27 * 27
        self.s3 = nn.MaxPool2d(2)  # 128 * 13 * 13
        self.c3 = nn.Conv2d(in_channels=128, out_channels=192, kernel_size=3, stride=1, padding=1)  # 192 * 13 * 13
        self.c4 = nn.Conv2d(in_channels=192, out_channels=192, kernel_size=3, stride=1, padding=1)  # 192 * 13 * 13
        self.c5 = nn.Conv2d(in_channels=192, out_channels=128, kernel_size=3, stride=1, padding=1)  # 128 * 13 * 13
        self.s5 = nn.MaxPool2d(kernel_size=3, stride=2)  # 128 * 6 * 6
        self.flatten = nn.Flatten()
        self.f6 = nn.Linear(4608, 2048)  # 128 * 6 * 6 = 4608
        self.f7 = nn.Linear(2048, 2048)
        self.f8 = nn.Linear(2048, 1000)
        self.f9 = nn.Linear(1000, num_classes)  # num_classes is the number of output classes

    def forward(self, x):
        x = self.ReLU(self.c1(x))
        x = self.s2(x)
        x = self.ReLU(self.c2(x))
        x = self.s3(x)
        x = self.ReLU(self.c3(x))
        x = self.ReLU(self.c4(x))
        x = self.ReLU(self.c5(x))
        x = self.s5(x)
        x = self.flatten(x)
        x = self.f6(x)
        x = F.dropout(x, p=0.5, training=self.training)  # tie dropout to the model's mode so it is disabled in eval()
        x = self.f7(x)
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.f8(x)
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.f9(x)
        return x
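
The shape annotations above can be verified by pushing a dummy batch through the model. This is a minimal sketch, not part of the original post; the 2-class setting matches the cat/dog task used in the training script below:

import torch

from net import MyAlexNet

model = MyAlexNet(num_classes=2)
model.eval()  # put the model in eval mode so dropout is inactive
dummy = torch.randn(1, 3, 224, 224)  # one fake RGB image of the expected input size
with torch.no_grad():
    out = model(dummy)
print(out.shape)  # expected: torch.Size([1, 2])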

train.py:

import os
import sys
import json
import time

import torch
import torch.nn as nn
from torchvision import transforms, datasets, utils
import matplotlib.pyplot as plt
import numpy as np
import torch.optim as optim
from torchvision.datasets import ImageFolder
from tqdm import tqdm

from net import MyAlexNet

ROOT_TRAIN = r'E:/cnn/AlexNet/data/train'
ROOT_TEST = r'E:/cnn/AlexNet/data/val'

def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
        "val": transforms.Compose([transforms.Resize((224, 224)),  # must be (224, 224); Resize(224) only resizes the shorter side
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}  # data preprocessing for training and validation

    train_dataset = ImageFolder(ROOT_TRAIN, transform=data_transform["train"])  # load the training set
    train_num = len(train_dataset)  # number of images in the training set
    animal_list = train_dataset.class_to_idx  # mapping from class name to index
    cla_dict = dict((val, key) for key, val in animal_list.items())  # invert it to index -> class name

    json_str = json.dumps(cla_dict, indent=4)  # write the index-to-class mapping to class_indices.json in the working directory
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    batch_size = 32
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size, shuffle=True,
                                               num_workers=0)

    validate_dataset = ImageFolder(ROOT_TEST, transform=data_transform["val"])  # load the validation set
    val_num = len(validate_dataset)  # number of images in the validation set
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=16, shuffle=False,
                                                  num_workers=0)

    # Optional: preview the dataset. Set batch_size on validate_loader above to the number of images
    # to view at once, and shuffle=True to shuffle their order.
    # test_data_iter = iter(validate_loader)
    # test_image, test_label = next(test_data_iter)
    #
    # def imshow(img):
    #     img = img / 2 + 0.5  # unnormalize
    #     npimg = img.numpy()
    #     plt.imshow(np.transpose(npimg, (1, 2, 0)))
    #     plt.show()
    #
    # print(' '.join('%5s' % cla_dict[test_label[j].item()] for j in range(4)))
    # imshow(utils.make_grid(test_image))


    net = MyAlexNet(num_classes=2)  # instantiate the network; num_classes is the number of classes

    net.to(device)  # move the network to the GPU or CPU
    loss_function = nn.CrossEntropyLoss()
    # para = list(net.parameters())  # (debug) inspect the model parameters
    optimizer = optim.Adam(net.parameters(), lr=0.0002)

    epochs = 1
    save_path = './AlexNet.pth'
    best_acc = 0.0
    train_steps = len(train_loader)
    t1 = time.perf_counter()
    for epoch in range(epochs):
        # train
        net.train()
        running_loss = 0.0
        #t1 = time.perf_counter()
        train_bar = tqdm(train_loader, file=sys.stdout)
        for step, data in enumerate(train_bar):
            images, labels = data
            optimizer.zero_grad()
            outputs = net(images.to(device))
            loss = loss_function(outputs, labels.to(device))
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()

            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
                                                                     epochs,
                                                                     loss)
        #print(time.perf_counter()-t1)

        # validate
        net.eval()
        acc = 0.0  # running count of correctly classified validation images in this epoch
        with torch.no_grad():
            val_bar = tqdm(validate_loader, file=sys.stdout)
            for val_data in val_bar:  # iterate over the validation set
                val_images, val_labels = val_data  # split each batch into images and labels
                outputs = net(val_images.to(device))  # move the images to the device and run a forward pass
                predict_y = torch.max(outputs, dim=1)[1]  # take the class with the highest score as the prediction
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()  # count how many predictions match the ground-truth labels

        val_accurate = acc / val_num  # correct predictions / total validation images
        print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
              (epoch + 1, running_loss / train_steps, val_accurate))

        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)
    print((time.perf_counter() - t1)/3600)  # total training time in hours
    print('Finished Training')


if __name__ == '__main__':
    main()
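
For reference, ImageFolder derives the class labels from the sub-directory names under ROOT_TRAIN and ROOT_TEST, so the data is expected to be laid out with one folder per class. A sketch of the assumed layout for the two-class cat/dog task (the cat/dog folder names are an assumption; any per-class folder names will work):

E:/cnn/AlexNet/data/
    train/
        cat/   <- cat training images
        dog/   <- dog training images
    val/
        cat/   <- cat validation images
        dog/   <- dog validation images

With that layout, class_to_idx assigns indices in alphabetical order, and the generated class_indices.json would read {"0": "cat", "1": "dog"}.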

predict.py:

import os
import json

import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt

from net import MyAlexNet

os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'  # common workaround for the duplicate OpenMP runtime error on Windows

def main():
    # device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    data_transform = transforms.Compose(
        [transforms.Resize((224, 224)),
         transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    classes = ("Cat", "Dog")

    net = MyAlexNet(num_classes=2)
    net.load_state_dict(torch.load('D:/cnn/All Classfication/AlexNet/save_model/last_model.pth'))  # load the trained weights
    net.eval()  # switch to evaluation mode so dropout is disabled during inference

    im = Image.open('2.jpg')
    plt.imshow(im)
    im = data_transform(im)  # resize, convert H x W x C -> [C, H, W], and normalize
    im = torch.unsqueeze(im, dim=0)  # add a batch dimension -> [N, C, H, W]


    with torch.no_grad():
        outputs = net(im)
        predict = torch.max(outputs, dim=1)[1].numpy()
        # val0 is the largest raw output score (a logit), not a probability;
        # apply softmax to the outputs to get an actual probability (see the sketch below)
        val0 = torch.max(outputs, dim=1)[0].numpy()
    print(classes[int(predict)])
    print(val0.item())

    plt.show()



if __name__ == '__main__':
    main()
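
As noted in the comments above, outputs holds raw logits, so val0 is not a probability. If an actual probability is wanted, softmax can be applied to the outputs. Below is a small hedged sketch; the helper name predict_with_probability is not part of the original script:

import torch

def predict_with_probability(net, im, classes):
    # net: a trained MyAlexNet in eval mode; im: a preprocessed [1, C, H, W] tensor
    with torch.no_grad():
        outputs = net(im)                      # raw logits, shape [1, num_classes]
        probs = torch.softmax(outputs, dim=1)  # convert logits to probabilities that sum to 1
        prob, idx = torch.max(probs, dim=1)    # highest probability and its class index
    return classes[idx.item()], prob.item()

Calling predict_with_probability(net, im, classes) inside main() returns the predicted class name together with its softmax probability instead of a raw score.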

reference:

3.2 使用pytorch搭建AlexNet并训练花分类数据集_哔哩哔哩_bilibili
