图像分类:AlexNet代码(pytorch)之train.py解读

附上代码:

import os
import sys
import json

import torch
import torch.nn as nn
from torchvision import transforms, datasets, utils
import matplotlib.pyplot as plt
import numpy as np
import torch.optim as optim
from tqdm import tqdm

from model import AlexNet


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")#表示如果当前有可使用的GPU设备,就使用设备上的GPU设备,没有就是用CPU设备
    print("using {} device.".format(device))

    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),#随机裁剪,将图像裁剪至224X224的大小
                                     transforms.RandomHorizontalFlip(),#在水平放心随机翻转
                                     transforms.ToTensor(),#转换成tensor
                                     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
        "val": transforms.Compose([transforms.Resize((224, 224)),  # cannot 224, must (224, 224)
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}

    data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))  # get data root path  #返回绝对路径
    image_path = os.path.join(data_root, "data_set", "flower_data")  # flower data set path
    assert os.path.exists(image_path), "{} path does not exist.".format(image_path)
    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                         transform=data_transform["train"])
    train_num = len(train_dataset)#打印训练集有多少张图片

    # {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
    flower_list = train_dataset.class_to_idx#利用class_to_idx去获取分类名称所对应的索引
    cla_dict = dict((val, key) for key, val in flower_list.items())
    # write dict into json file
    json_str = json.dumps(cla_dict, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    batch_size = 32
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers every process'.format(nw))

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size, shuffle=True,
                                               num_workers=nw)

    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform["val"])
    val_num = len(validate_dataset)
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=4, shuffle=True,
                                                  num_workers=nw)

    print("using {} images for training, {} images for validation.".format(train_num,
                                                                           val_num))
    # test_data_iter = iter(validate_loader)
    # test_image, test_label = test_data_iter.__next__()
    #
    # def imshow(img):
    #     img = img / 2 + 0.5  # unnormalize
    #     npimg = img.numpy()
    #     plt.imshow(np.transpose(npimg, (1, 2, 0)))
    #     plt.show()
    #
    # print(' '.join('%5s' % cla_dict[test_label[j].item()] for j in range(4)))
    # imshow(utils.make_grid(test_image))

    net = AlexNet(num_classes=5, init_weights=True)

    net.to(device)
    loss_function = nn.CrossEntropyLoss()
    # pata = list(net.parameters())
    optimizer = optim.Adam(net.parameters(), lr=0.0002)

    epochs = 10
    save_path = './AlexNet.pth'
    best_acc = 0.0
    train_steps = len(train_loader)
    for epoch in range(epochs):
        # train
        net.train()
        running_loss = 0.0
        train_bar = tqdm(train_loader, file=sys.stdout)
        for step, data in enumerate(train_bar):
            images, labels = data
            optimizer.zero_grad()
            outputs = net(images.to(device))
            loss = loss_function(outputs, labels.to(device))
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()

            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
                                                                     epochs,
                                                                     loss)

        # validate
        net.eval()
        acc = 0.0  # accumulate accurate number / epoch
        with torch.no_grad():
            val_bar = tqdm(validate_loader, file=sys.stdout)
            for val_data in val_bar:
                val_images, val_labels = val_data
                outputs = net(val_images.to(device))
                predict_y = torch.max(outputs, dim=1)[1]
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()

        val_accurate = acc / val_num
        print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
              (epoch + 1, running_loss / train_steps, val_accurate))

        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)

    print('Finished Training')


if __name__ == '__main__':
    main()

解释顺序就是代码阅读顺序

训练数据集处理:

1.

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

if torch.cuda.is_available():
	device = torch.device("cuda")
else:
	device = torch.device("cpu")

表示如果有GPU使用GPU进行计算训练,否则使用CPU

2.

  print("using {} device.".format(device))

一种格式化字符串的函数str.format()Python format 格式化函数 | 菜鸟教程 (runoob.com)

3.transforms.Compose():预处理函数

4.transforms.RandomResizedCrop(224):随机裁剪,将图像裁剪至224X224的大小

5.transforms.RandomHorizontalFlip():随机翻转,数据增强一种方法,这里是水平翻转。

6.transforms.ToTensor():转换成张量
7.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)):标准化处理,第一个(0.5, 0.5, 0.5)为均值,第二个(0.5, 0.5, 0.5)为方差

测试数据集处理:

1.transforms.Resize((224, 224):图像大小改为224

2.transforms.ToTensor():转换成张量

3.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)):标准化处理

获取数据集:

1.

data_root = os.path.abspath(os.path.join(os.getcwd(), "../..")) 

os.path.abspath():获取数据集所在根目录,即返回绝对路径

os.path.join():将传入两个路径连接在一起

os.getcwd():获取当前所在文件的目录

"../..":返回上上级目录

2.

image_path = os.path.join(data_root, "data_set", "flower_data")

从根目录开始向下进行完整目录的拼接

3.

assert os.path.exists(image_path), "{} path does not exist.".format(image_path)

os.path.exists():函数的功能是查看给定的文件/目录是否存在,存在返回True,不存在返回False。

assert:Python3 assert(断言) | 菜鸟教程 (runoob.com)      

             为assert断言语句添加异常参数:assert的异常参数,其实就是在断言表达式后添加字符串信息,用来解释断言并更好的知道是哪里出了问题。格式如下:

             assert expression [, arguments]

             assert 表达式 [, 参数]

4.

 train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                         transform=data_transform["train"])

datasets.ImageFolder():加载数据集

root=image_path + "/train":传入训练集数据路径

transform=data_transform["train"]:调用训练数据集预处理模块  即:

 "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                 transforms.RandomHorizontalFlip(),
                                 transforms.ToTensor(),
                                 transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

5.  train_num = len(train_dataset):打印训练集有多少张图片

分类处理:

1.flower_list = train_dataset.class_to_idx:利用class_to_idx去获取分类名称所对应的索引

2.cla_dict = dict((val, key) for key, val in flower_list.items()) :循环遍历数组索引该值并交换重新赋值给数组,这样模型预测出来的直接就是value类别值。Python 字典(Dictionary) items()方法 | 菜鸟教程 (runoob.com)

3.json_str = json.dumps(cla_dict, indent=4):把字典编码成json格式,indent参数决定添加几个空格

4. with open('class_indices.json', 'w') as json_file: json_file.write(json_str):把字典类别索引写入json文件

5.batch_size = 32:一次性载入32张图片

6.torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True,num_workers=0):加载数据集和其他参数

7.datasets.ImageFolder(root=image_path + "/val",transform=data_transform["val"]):加载测试集路径和测试集预处理模块。

查看数据集代码:

    test_data_iter = iter(validate_loader)
    test_image, test_label = test_data_iter.__next__()

    def imshow(img):
        img = img / 2 + 0.5  # unnormalize
        npimg = img.numpy()
        plt.imshow(np.transpose(npimg, (1, 2, 0)))
        plt.show()

    print(' '.join('%5s' % cla_dict[test_label[j].item()] for j in range(4)))
    imshow(utils.make_grid(test_image))

使用设计的模型:

1.AlexNet(num_classes=5, init_weights=True):传入参数调用类

2.num_classes=5:类别数为5,最终全连接生成5个值的向量。

3.init_weights=True:初始化模型训练权重参数,从头开始训练

4.net.to(device):设备(GPU或CPU)加载网络。

5.loss_function = nn.CrossEntropyLoss():设置损失函数。CrossEntropyLoss为交叉熵损失函数。

6.optimizer = optim.Adam(net.parameters(), lr=0.0002):设置Adam优化器。根据模型当前参数决定优化即调整参数的增减和幅度。Ir为学习率。过大过小都会影响准确率。

7.save_path = './AlexNet.pth':设置保存权重的路径

8.best_acc = 0.0:设置准确率变量

开始训练:

1.for epoch in range(10):遍历迭代10次

2.net.train():调用Dropout方法

3.running_loss = 0.0:设置训练损失值

4.t1 = time.perf_counter():设置记录训练开始时间以计算一个epoch所花费时间

5.for step, data in enumerate(train_loader, start=0):遍历数据集,返回数据data和步长step

6.images, labels = data:把data数组中的图像和标签分别赋值给变量images和label。

7.optimizer.zero_grad():清空之前的梯度信息。作用是将历史损失梯度进行清零,一般batch_size这个数值设置的越大,训练效果越好,但由于硬件设备受限,不可能用一个很大的batch_size进行训练,而通过Optimizer.zero_grad()可以变相实现一个很大batch数目的训练,即一次性计算多个小的batch。

8.outputs = net(images.to(device)):开始进行正向传播,并把图像计算与设备进行绑定。

9.loss = loss_function(outputs, labels.to(device)):训练得到预测输出之后与真实标签进行计算损失值。

10.loss.backward():将loss反向传播到各个节点。

11.optimizer.step():更新每个节点参数

12.running_loss += loss.item():进行一个loss的累加

13.rate = (step + 1) / len(train_loader):当前训练步数,比如72/900、73/900。 len(train_loader)表示训练一轮需要的步数

14.a = "*" * int(rate * 50),b = "." * int((1 - rate) * 50):使用*和.,打印进度百分比

15.print("\rtrain loss: {:^3.0f}%[{}->{}]{:.3f}".format(int(rate * 100), a, b, loss), end=""):打印训练进度信息

 验证过程:

1.net.eval():控制不使用Dropout。

2. with torch.no_grad():该函数在接下来计算过程中不要去计算每个节点的误差损失梯度,如果没有使用它,在运行时就会消耗更多的算力,占用更多的内存资源,甚至会内存崩掉。在这个函数的所有范围内的计算都不会去计算它的误差梯度

3.for val_data in validate_loader:遍历验证数据集

4.val_images, val_labels = val_data:把data数组中的图像和标签分别赋值给变量images和label

5.outputs = net(val_images.to(device)):开始进行正向传播,并把图像计算与设备进行绑定

6.predict_y = torch.max(outputs, dim=1)[1]:这代码的意思是获得这个batch中网络的预测标签,torch.max(outputs, dim=1)返回两个值,分别是最大值和其对应的索引,dim=1时按行返回最大索引,dim=0时按列返回最大索引。

7.acc += (predict_y == val_labels.to(device)).sum().item():将预测值与真实值进行比较,相等为1,不等为0,并求和,并通过item()获得数值。

8.val_accurate = acc / val_num:预测正确值之和除以总和计算准确率

9.        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path):判断当前准确率是否大于历史准确率,如果是保存当前模型权重

10.print('[epoch %d] train_loss: %.3f  test_accuracy: %.3f' %(epoch + 1, running_loss / step, val_accurate))

print('Finished Training'):打印训练轮数、损失值、步长、准确率,结束训练。
 

  • 1
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值