PyTorch Classification Networks in Practice (1) ----- AlexNet

A hands-on classification network

1. Dataset

Dataset download link

https://www.kaggle.com/datasets/ayushv322/animal-classification

Dataset processing code (splits the images into training and validation sets)
import os
from shutil import copy, rmtree
import random


def mk_file(file_path: str):
    # if the folder already exists, remove it together with its contents, then recreate it
    if os.path.exists(file_path):
        rmtree(file_path)
    os.makedirs(file_path)


def main():
    # fix the random seed so the train/val split is reproducible
    random.seed(0)

    # fraction of each class that goes into the validation set
    split_rate = 0.1

    # root folder of the downloaded dataset
    cwd = r'C:\DL_dataset'
    data_root = os.path.join(cwd, "Kaggle_animals_classification")
    origin_flower_path = os.path.join(data_root, "animals_classification")
    assert os.path.exists(origin_flower_path), "path '{}' does not exist.".format(origin_flower_path)

    flower_class = [cla for cla in os.listdir(origin_flower_path)
                    if os.path.isdir(os.path.join(origin_flower_path, cla))]

    # create the folder that holds the training set
    train_root = os.path.join(data_root, "train")
    mk_file(train_root)
    for cla in flower_class:
        # create a sub-folder for each class
        mk_file(os.path.join(train_root, cla))

    # create the folder that holds the validation set
    val_root = os.path.join(data_root, "val")
    mk_file(val_root)
    for cla in flower_class:
        # create a sub-folder for each class
        mk_file(os.path.join(val_root, cla))

    for cla in flower_class:
        cla_path = os.path.join(origin_flower_path, cla)
        images = os.listdir(cla_path)
        num = len(images)
        # randomly sample the images that go into the validation set
        eval_index = random.sample(images, k=int(num*split_rate))
        for index, image in enumerate(images):
            if image in eval_index:
                # copy files assigned to the validation set into the corresponding folder
                image_path = os.path.join(cla_path, image)
                new_path = os.path.join(val_root, cla)
                copy(image_path, new_path)
            else:
                # copy files assigned to the training set into the corresponding folder
                image_path = os.path.join(cla_path, image)
                new_path = os.path.join(train_root, cla)
                copy(image_path, new_path)
            print("\r[{}] processing [{}/{}]".format(cla, index+1, num), end="")  # processing bar
        print()

    print("processing done!")


if __name__ == '__main__':
    main()
Dataset directory structure

Kaggle_animals_classification
…animals_classification
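
After running the splitting script above, the data root should look roughly like this (the class sub-folder names come from the dataset itself; the ones below are only placeholders):

Kaggle_animals_classification
├── animals_classification   # original images, one sub-folder per class
├── train                    # ~90% of the images of each class
│   ├── <class_1>
│   └── ...
└── val                      # ~10% of the images of each class
    ├── <class_1>
    └── ...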

2. Building the network

Network architecture diagram

As the architecture diagram shows, AlexNet consists of 5 convolutional layers followed by 3 fully connected layers, for a total depth of 8 layers.
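
The spatial sizes noted in the comments of model.py below follow the usual convolution/pooling formula output = floor((input - kernel + 2*padding) / stride) + 1. As a quick check, here is a small sketch of that arithmetic for the first convolution and the first max-pool (helper code for illustration only, not part of model.py):

def out_size(inp, kernel, stride=1, padding=0):
    # standard output-size formula for Conv2d and MaxPool2d
    return (inp - kernel + 2 * padding) // stride + 1

print(out_size(224, kernel=11, stride=4, padding=2))  # 55 -> first Conv2d
print(out_size(55, kernel=3, stride=2))               # 27 -> first MaxPool2d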

Building the model

model.py

import torch.nn as nn
import torch


class AlexNet(nn.Module):
    def __init__(self, num_classes=1000, init_weights=False):
        super(AlexNet, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 48, kernel_size=11, stride=4, padding=2),  # input[3, 224, 224]  output[48, 55, 55]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                  # output[48, 27, 27]
            nn.Conv2d(48, 128, kernel_size=5, padding=2),           # output[128, 27, 27]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                  # output[128, 13, 13]
            nn.Conv2d(128, 192, kernel_size=3, padding=1),          # output[192, 13, 13]
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 192, kernel_size=3, padding=1),          # output[192, 13, 13]
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 128, kernel_size=3, padding=1),          # output[128, 13, 13]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                  # output[128, 6, 6]
        )
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Linear(128 * 6 * 6, 2048),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(2048, 2048),
            nn.ReLU(inplace=True),
            nn.Linear(2048, num_classes),
        )
        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, start_dim=1)
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)
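
Before training, it can be worth confirming the shapes in the comments above with a dummy forward pass (a minimal sketch, assuming model.py is importable from the current directory):

import torch
from model import AlexNet

net = AlexNet(num_classes=4)
x = torch.randn(1, 3, 224, 224)  # a dummy batch of one 224x224 RGB image
print(net(x).shape)              # expected: torch.Size([1, 4])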

3. Training

train.py

import os
import sys
import json

import torch
import torch.nn as nn
from torchvision import transforms,datasets,utils
import matplotlib.pyplot as plt
import numpy as np
import torch.optim as optim
from tqdm import tqdm

from model import AlexNet
def main():
    device=torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    data_transform={
        'train':transforms.Compose(
            [transforms.RandomResizedCrop(224),
             transforms.RandomHorizontalFlip(),
             transforms.ToTensor(),
             transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))
             ]
        ),
        'val':transforms.Compose(
            [transforms.Resize((224,224)),
             transforms.ToTensor(),
             transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))
            ]
        )
    }

    image_path=r"C:\DL_dataset\Kaggle_animals_classification"

    train_dataset=datasets.ImageFolder(root=os.path.join(image_path,'train'),transform=data_transform['train'])
    train_num=len(train_dataset)

    # build an index -> class-name mapping and save it as a JSON file (read back later by predict.py)
    flower_list=train_dataset.class_to_idx
    cla_dict=dict((val,key) for key,val in flower_list.items())
    json_str=json.dumps(cla_dict,indent=4)

    with open('class_indices.json','w') as json_file:
        json_file.write(json_str)

    batch_size=32
    # this line is mainly useful on Linux; on Windows you can simply set num_workers=0
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers per process'.format(nw))
    # build the data loaders
    train_loader=torch.utils.data.DataLoader(train_dataset,batch_size=batch_size,shuffle=True,num_workers=nw)
    validate_dataset=datasets.ImageFolder(root=os.path.join(image_path,'val'),transform=data_transform['val'])
    val_num=len(validate_dataset)
    validate_loader=torch.utils.data.DataLoader(validate_dataset,batch_size=batch_size,shuffle=True,num_workers=nw)
    print("using {} images for training, {} images for validation.".format(train_num,val_num))
    # set up training; note: num_classes must equal the number of classes in your dataset (4 here)
    net=AlexNet(num_classes=4,init_weights=True)
    net.to(device)
    loss_function=nn.CrossEntropyLoss()
    optimizer=optim.Adam(net.parameters(),lr=0.0002)
    epochs=10
    # path where the best weights are saved
    save_path='./AlexNet.pth'
    best_acc=0.0
    train_steps=len(train_loader)
    for epoch in range(epochs):
        # train
        net.train()
        running_loss=0.0
        train_bar=tqdm(train_loader,file=sys.stdout)
        for step,data in enumerate(train_bar):
            images,labels=data
            optimizer.zero_grad()
            outputs=net(images.to(device))
            loss=loss_function(outputs,labels.to(device))
            loss.backward()
            optimizer.step()

            running_loss=running_loss+loss.item()

            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,epochs,loss)

        # validate
        net.eval()
        acc=0.0
        with torch.no_grad():
            val_bar=tqdm(validate_loader,file=sys.stdout)
            for val_data in val_bar:
                val_images,val_labels=val_data
                outputs=net(val_images.to(device))
                predict_y=torch.max(outputs,dim=1)[1]
                acc=acc+torch.eq(predict_y,val_labels.to(device)).sum().item()
            val_acc=acc/val_num
            print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %(epoch + 1, running_loss / train_steps, val_acc))
            if val_acc>best_acc:
                best_acc=val_acc
                torch.save(net.state_dict(),save_path)
    print('Finished training')
if __name__=='__main__':
    main()
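
After running train.py, class_indices.json contains the index-to-class mapping that predict.py reads back. With four class folders it looks roughly like this (the actual names are taken from your folder names; the ones below are only placeholders):

{
    "0": "class_a",
    "1": "class_b",
    "2": "class_c",
    "3": "class_d"
}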

4. Prediction

predict.py

import os
import json
import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
from model import AlexNet

def main():
    device=torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    data_transform=transforms.Compose(
        [
            transforms.Resize((224,224)),
            transforms.ToTensor(),
            transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))
        ]
    )
    # load the image to classify
    img_path='./检测图片.jpg'
    assert os.path.exists(img_path), "file: '{}' does not exist.".format(img_path)
    img = Image.open(img_path)

    plt.imshow(img)
    # [N, C, H, W]
    img = data_transform(img)
    # expand batch dimension
    img = torch.unsqueeze(img, dim=0)
    # read class_indict
    json_path = './class_indices.json'
    assert os.path.exists(json_path), "file: '{}' does not exist.".format(json_path)

    json_file = open(json_path, "r")
    class_indict = json.load(json_file)
    # print(class_indict['3'],9)

    # create model
    model = AlexNet(num_classes=4).to(device)

    # load model weights
    weights_path = "./AlexNet.pth"
    assert os.path.exists(weights_path), "file: '{}' does not exist.".format(weights_path)
    model.load_state_dict(torch.load(weights_path, map_location=device))  # map_location lets CPU-only machines load weights saved on GPU

    model.eval()
    with torch.no_grad():
        # predict class
        output = torch.squeeze(model(img.to(device))).cpu()
        predict = torch.softmax(output, dim=0)

        # print(predict)
        predict_cla = torch.argmax(predict).numpy()
        # print(predict_cla,8)
    print_res = "class: {}   prob: {:.3}".format(class_indict[str(predict_cla)],
                                                 predict[predict_cla].numpy())
    plt.title(print_res)
    # print(type(predict[predict_cla].numpy()))
    for i in range(len(predict)):
        print("class: {:10}   prob: {:.3}".format(class_indict[str(i)],
                                                  predict[i].numpy()))
    plt.show()

if __name__ == '__main__':
    main()
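
predict.py classifies a single image. To run the same model over a whole folder of images, a minimal sketch could look like the following (the folder name ./test_images is just an assumption; it reuses the transform, class_indices.json and AlexNet.pth from above):

import os
import json
import torch
from PIL import Image
from torchvision import transforms
from model import AlexNet

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
data_transform = transforms.Compose(
    [transforms.Resize((224, 224)),
     transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

with open('./class_indices.json', 'r') as f:
    class_indict = json.load(f)

model = AlexNet(num_classes=4).to(device)
model.load_state_dict(torch.load('./AlexNet.pth', map_location=device))
model.eval()

img_dir = './test_images'  # hypothetical folder holding the images to classify
with torch.no_grad():
    for name in os.listdir(img_dir):
        img = data_transform(Image.open(os.path.join(img_dir, name)).convert('RGB'))
        output = model(img.unsqueeze(0).to(device)).squeeze(0)
        probs = torch.softmax(output, dim=0)
        idx = int(torch.argmax(probs))
        print("{}: {}  prob: {:.3f}".format(name, class_indict[str(idx)], probs[idx].item()))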

5. Summary

Key characteristics of AlexNet:

  1. First network to use GPUs to accelerate training
  2. Uses ReLU as the activation function
  3. Applies Dropout to the first two fully connected layers, randomly deactivating part of the neurons to reduce overfitting