AlexNet网络代码详解(pytorch)

此代码是关于pytorch版本的AlexNet网络代码的详解,注释内容清晰,几乎每行都有解释,帮助很好的读懂相关代码。

1. model.py

import torch.nn as nn
import torch

class AlexNet(nn.Module):	# 创建类AlexNet,继承于父类nn.module
    def __init__(self, num_classes=1000, init_weights=False):	# 通过初始化函数,定义网络在正向传播中所需使用到的一些层结构
        super(AlexNet, self).__init__()	# super是将父类与子类关联起来
        self.features = nn.Sequential(	# features是专门用于提取图像特征,nn.Sequential此模块将一系列层结构进行打包组合成新结构
            nn.Conv2d(3, 48, kernel_size=11, stride=4, padding=2),  # input[3, 224, 224],output[48, 55, 55]
            nn.ReLU(inplace=True),	# inplace=True通过此方法可增加计算量,但可降低内存使用容量
            nn.MaxPool2d(kernel_size=3, stride=2),                  # output[48, 27, 27]
            nn.Conv2d(48, 128, kernel_size=5, padding=2),           # output[128, 27, 27]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                  # output[128, 13, 13]
            nn.Conv2d(128, 192, kernel_size=3, padding=1),          # output[192, 13, 13]
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 192, kernel_size=3, padding=1),          # output[192, 13, 13]
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 128, kernel_size=3, padding=1),          # output[128, 13, 13]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                  # output[128, 6, 6]
        )
        self.classifier = nn.Sequential(	# classifier结构包含最后三层连接层,亦为分类器;nn.Sequential将全连接层打包成新模块
            nn.Dropout(p=0.5),	# p=0.5表示随机失活的比例为0.5
            nn.Linear(128 * 6 * 6, 2048),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(2048, 2048),
            nn.ReLU(inplace=True),
            nn.Linear(2048, num_classes),	# num_classes输出为数据集类别个数
        )
        if init_weights:	# 初始化权重
            self._initialize_weights()

    def forward(self, x):	# 定义正向传播过程
        x = self.features(x)
        x = torch.flatten(x, start_dim=1)	# flatten展平处理,start_dim=1从索引1开始
        x = self.classifier(x)	# 展平后输入到分类结构中,即全连接层
        return x

    def _initialize_weights(self):	# 定义初始化权重
        for m in self.modules():	# 遍历self.modules模块
            if isinstance(m, nn.Conv2d):	# 如果此层为卷积层
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')	# 则用凯明初始化对权重w进行初始化
                if m.bias is not None:	# 如果偏置不为空
                    nn.init.constant_(m.bias, 0)	# 则0为它的初始化
            elif isinstance(m, nn.Linear):	# 否则如果传进的实例为全连接层
                nn.init.normal_(m.weight, 0, 0.01)	# 则通过正态分布给权重赋值,0表示正态分布的均值,0.01表示方差
                nn.init.constant_(m.bias, 0)	# 将偏置初始化为0

2.train.py

import os
import sys
import json
import torch
import torch.nn as nn
from torchvision import transforms, datasets, utils
import matplotlib.pyplot as plt
import numpy as np
import torch.optim as optim
from tqdm import tqdm
from model import AlexNet

def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")	# torch.device指定训练过程中所用设备
    print("using {} device.".format(device))	# 打印所使用的设备,0表示使用gpu,cpu表示使用cpu

    data_transform = {	# data_transform数据预处理函数
        "train": transforms.Compose([transforms.RandomResizedCrop(224),	# 随机裁剪图片大小为224×224
                                     transforms.RandomHorizontalFlip(),	# 随机翻转
                                     transforms.ToTensor(),	# 转换成tensor,即将灰度范围从0-255变换到0-1之间
                                     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),	# 标准化处理,将数据从0-1转换成-1到1之间,变成了均值为0,方差为1的标准正态分布
        "val": transforms.Compose([transforms.Resize((224, 224)),  # cannot 224, must (224, 224)
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}

    data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))  # os.path.join此函数可将输入的两个路径连在一起,os.getcwd此函数获取当前所在目录,../..前一个..表示返回上上层目录
    image_path = os.path.join(data_root, "data_set", "Plant_leave_diseases_dataset")  # 将数据集路径传入
    assert os.path.exists(image_path), "{} path does not exist.".format(image_path)	# 判断路径是否存在
    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),	
                                         transform=data_transform["train"])	# datasets.ImageFolder此函数用于加载数据集
    train_num = len(train_dataset)	#此函数可打印训练集有多少张图片

    # 生成的字典文件{'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
    # 生成的json文件{"0":"daisy", "1":"dandelion", "2":"roses", "3":"sunflower", "4":"tulips"}
    flower_list = train_dataset.class_to_idx	# class_to_idx用于获取分类名称所对应的索引
    cla_dict = dict((val, key) for key, val in flower_list.items())	#遍历刚刚获得的字典flower_list,将key与val反过来,即'daisy'和'0'反过来,这样预测所给的索引能直接通过此字典得到它所对应的类别
    json_str = json.dumps(cla_dict, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)	# 保存在json_file文件中,方便在预测时读取信息

    batch_size = 32
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers every process'.format(nw))
	
	# 将train_dataset加载进来
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size, shuffle=True,
                                               num_workers=nw)
                                               
    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform["val"])	# 传入验证集的所对应的预处理函数结果
    val_num = len(validate_dataset)	# 统计验证集文件个数
    # 将validate_dataset加载进来
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=4, shuffle=False,
                                                  num_workers=nw)

    print("using {} images for training, {} images for validation.".format(train_num,
                                                                           val_num))

    net = AlexNet(num_classes=5, init_weights=True)	# 导入自定义的AlexNet模型,并进行模型实例化

    net.to(device)	# 将网络指定到刚刚规定的设备上
    loss_function = nn.CrossEntropyLoss()	# 定义损失函数,交叉熵
    # pata = list(net.parameters())	# 查看模型参数
    # print(pata)
    optimizer = optim.Adam(net.parameters(), lr=0.0002)	# 定义Adam优化器,优化网络中所有参数

    epochs = 10	# 迭代10次
    save_path = './AlexNet.pth'	# 给定保存权重路径
    best_acc = 0.0	# 定义最佳准确率,为了保存准确率最高的训练模型
    train_steps = len(train_loader)
    for epoch in range(epochs):
        # 训练集训练过程
        net.train()	# 启用Dropout方法
        running_loss = 0.0	# 统计训练过程中的平均损失
        train_bar = tqdm(train_loader, file=sys.stdout)	# 显示训练过程进度条
        for step, data in enumerate(train_bar):	# 遍历数据集
            images, labels = data	# 将数据分为图像和对应的标签
            optimizer.zero_grad()	# 清空之前梯度信息
            outputs = net(images.to(device))	# 将训练图像也指定到设备中,再正向传播得到输出
            loss = loss_function(outputs, labels.to(device))	# 计算预测值与真实值的损失
            loss.backward()	# 将损失反向传播到每个节点中
            optimizer.step()	# 再通过optimizer更新每个节点的参数
            running_loss += loss.item()	再将loss值累加到running_loss中

            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
                                                                     epochs,
                                                                     loss)	# 打印训练进度

        # 验证集验证过程
        net.eval()	# 停止Dropout方法
        acc = 0.0  
        with torch.no_grad():	# 用此函数来禁止pytorch对参数跟踪,即在验证过程中不需计算损失梯度
            val_bar = tqdm(validate_loader, file=sys.stdout)
            for val_data in val_bar:	# 遍历验证集
                val_images, val_labels = val_data	#将数据分为图片和对应的标签
                outputs = net(val_images.to(device))	# 正向传播得到输出
                predict_y = torch.max(outputs, dim=1)[1]	# 求得输出最大值作为预测最有可能的类别
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()	# 将预测与真实标签对比,将预测对的个数进行求和

        val_accurate = acc / val_num	# 判断当前准确率是否大于历史最优准确率
        print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
              (epoch + 1, running_loss / train_steps, val_accurate))

        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)	# 保存当前权重

    print('Finished Training')

if __name__ == '__main__':
    main()
  • 1
    点赞
  • 13
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值