用pytorch构建Alexnet模型(train模块)(个人笔记)

import os
import sys
import json

import torch
import torch.nn as nn
from torchvision import transforms, datasets, utils
import matplotlib.pyplot as plt
import numpy as np
import torch.optim as optim
from tqdm import tqdm

from model import AlexNet


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")  # 通过torch.device来指定训练设备(cpu或gpu)
    print("using {} device.".format(device))

    data_transform = {     # 定义数据预处理函数
        "train": transforms.Compose([transforms.RandomResizedCrop(224),  # 随机裁剪(裁剪到224*224)
                                     transforms.RandomHorizontalFlip(),  # 随机翻转
                                     transforms.ToTensor(),  # 转化为一个tensor
                                     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),  # 标准化处理
        "val": transforms.Compose([transforms.Resize((224, 224)),  # cannot 224, must (224, 224)
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}

    data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))  # get data root path(获取目录)
    image_path = os.path.join(data_root, "data_set", "flower_data")  # flower data set path
    assert os.path.exists(image_path), "{} path does not exist.".format(image_path)
    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),  # 通过atasets.ImageFolder加载数据集
                                         transform=data_transform["train"])  # transform数据集预处理,
    # 通过定义的data_transform这个字典,传入"train"这个key,他就会返回训练集对应的数据预处理(字典还能这么用?)
    train_num = len(train_dataset)   # 通过len()函数打印训练集有多少张图片

    # {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
    flower_list = train_dataset.class_to_idx   # 通过此获取分类类别的名称所对应的索引(如‘daisy’:0,)
    cla_dict = dict((val, key) for key, val in flower_list.items())  # 通过此遍历所获得的的字典(val:类别名称;key:类别索引)
    # write dict into json file
    json_str = json.dumps(cla_dict, indent=4)  # 通过json包,将cla_dict这个字典编码成json格式
    with open('class_indices.json', 'w') as json_file:  # 打开class_indices.json文件(可以查看这个文件)
        json_file.write(json_str)

    batch_size = 32  # 定义batch_size
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers every process'.format(nw))

    train_loader = torch.utils.data.DataLoader(train_dataset,   # 通过DataLoader加载train_dataset
                                               batch_size=batch_size, shuffle=True,  # 通过给定的 batch_size与随机参数(shuffle),
                                               # 随机的从样本中回去一批一批的数据
                                               num_workers=nw)  # window系统不能设为非零值

    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),  # 通过ImageFolder载入测试集
                                            transform=data_transform["val"])  # 传入测试集所对应的预处理方式
    val_num = len(validate_dataset)   # 统计测试集的文件个数
    validate_loader = torch.utils.data.DataLoader(validate_dataset,   # 载入测试集
                                                  batch_size=4, shuffle=True,
                                                  num_workers=nw)

    print("using {} images for training, {} images for validation.".format(train_num,
                                                                           val_num))
    # test_data_iter = iter(validate_loader)     #  如何查看数据集的代码
    # test_image, test_label = test_data_iter.next()
    #
    # def imshow(img):
    #     img = img / 2 + 0.5  # unnormalize
    #     npimg = img.numpy()
    #     plt.imshow(np.transpose(npimg, (1, 2, 0)))
    #     plt.show()
    #
    # print(' '.join('%5s' % cla_dict[test_label[j].item()] for j in range(4)))
    # imshow(utils.make_grid(test_image))

    net = AlexNet(num_classes=5, init_weights=True)  # 传入类别5 ; 初始化权重设为True

    net.to(device)  # 将网络指认到指定的刚刚设置的设备上(CPU或GPU)
    loss_function = nn.CrossEntropyLoss()  # 定义损失函数(使用是多类别交叉熵函数)
    # pata = list(net.parameters())  # 查看模型参数
    optimizer = optim.Adam(net.parameters(), lr=0.0002)  # 定义的优化器,优化对象是所有的可训练参数(net.parameters()),
    # 学习率为lr=0.0002(已经过调试,变大或变小都会影响准确率)

    epochs = 10
    save_path = './AlexNet.pth'  # 保存权重路径
    best_acc = 0.0   # 定义最佳准确率,保存最高准确率
    train_steps = len(train_loader)
    for epoch in range(epochs):     # 开始训练
        # train
        net.train()    # 通过 net.train() 与 net.eval() 方法来管理Dropout()方法,也可以管理BN层
        # 使用net.train()开启Dropout()方法,使用net.eval()关闭Dropout()方法
        running_loss = 0.0  # 统计训练过程中的实际损失
        train_bar = tqdm(train_loader, file=sys.stdout)
        for step, data in enumerate(train_bar):   # 遍历数据集
            images, labels = data   # 将数据分为images, labels
            optimizer.zero_grad()  # 清除历史梯度信息
            outputs = net(images.to(device))     # 将训练图像也指认到设备中
            loss = loss_function(outputs, labels.to(device))  # 通过loss_function计算预测值与真实值的损失,这里同样将labels指认到设备中
            loss.backward()   # 将损失反向传播到每个节点中
            optimizer.step()   # 通过optimizer更新每个节点的参数

            # print statistics
            running_loss += loss.item()  # 将得到的损失值累加到running_loss

            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
                                                                     epochs,
                                                                     loss)     # 打印训练进度

        # validate
        net.eval()
        acc = 0.0  # accumulate accurate number / epoch
        with torch.no_grad():   # 训练一轮之后开始验证,通过with torch.no_grad()禁止pytorch对我们的参数进行跟踪,在我们训练过程中不会计算损失梯度
            val_bar = tqdm(validate_loader, file=sys.stdout)
            for val_data in val_bar:  # 遍历训练集
                val_images, val_labels = val_data  # 将训练集划分为val_images, val_labels
                outputs = net(val_images.to(device))  # 通过正向传播得到输出
                predict_y = torch.max(outputs, dim=1)[1]   # 通过torch.max(),求得输出中预测最有可能的那个类别
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()  # 将预测与真实标签进行对比,正确为1,错误为0,
                                                                                # 然后加起来,通过item得到数值

        val_accurate = acc / val_num
        print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
              (epoch + 1, running_loss / train_steps, val_accurate))   # 打印信息

        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)  # 保存当前权重

    print('Finished Training')


if __name__ == '__main__':
    main()

(感谢“霹雳吧啦”)

  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值