训练相关代码

import torch
import torchvision
import torch.nn as nn
from model import LeNet
import torch.optim as optim
import torchvision.transforms as transforms

import matplotlib.pyplot as plt
import numpy as np


def main():
    transform = transforms.Compose(
        [transforms.ToTensor(), #将图片转换为tensor
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])#标准化
    #torchvision.datasets. 下载数据集

    # 50000张训练图片
    # 第一次使用时要将download设置为True才会自动去下载数据集
    # root表示将数据集下载到什么地方 train = True表示导入训练数据集
    # transform = transform 对数据进行预处理
    train_set = torchvision.datasets.CIFAR10(root='./data', train=True,
                                             download=False, transform=transform)
    #导入训练集
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=36,          #导入训练集  shuffle = True 表示打乱数据集
                                               shuffle=True, num_workers=0)
    #导入测试集
    # 10000张验证图片
    # 第一次使用时要将download设置为True才会自动去下载数据集
    # val_set = torchvision.datasets.CIFAR10(root='./data', train=True,
    #                                        download=False, transform=transform)
    # val_loader = torch.utils.data.DataLoader(val_set, batch_size=10000,
    #                                          shuffle=False, num_workers=0)
    # val_data_iter = iter(val_loader)
    # val_image, val_label = val_data_iter.next()

    testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                           download=False, transform=transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=10000,
                                             shuffle=False, num_workers=0)#num_workers=0线程个数,windows下只能为0
    test_data_iter = iter(testloader)
    test_image, test_label = test_data_iter.next() #通过.next()获得图片和标签值
    #类别,元组类型 plane->0
    classes = ('plane', 'car', 'bird', 'cat',
                'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

    #测试
    # def imshow(img):
    #     img = img / 2 + 0.5  # unnormalize 对图像进行反标准化处理
    #     npimg = img.numpy()
    #     plt.imshow(np.transpose(npimg, (1, 2, 0))) # h= w= channel=0
    #     plt.show()
    #
    # # print labels
    # print(' '.join(f'{classes[test_label[j]]:5s}' for j in range(4)))
    # # show images
    # imshow(torchvision.utils.make_grid(test_image))


    net = LeNet()
    loss_function = nn.CrossEntropyLoss() #定义损失函数
    optimizer = optim.Adam(net.parameters(), lr=0.001)#使用Adam优化器 导入参数量,lr是学习率

    #训练过程
    for epoch in range(5):  # loop over the dataset multiple times 训练迭代多少轮

        running_loss = 0.0 #记录累积的训练损失
        for step, data in enumerate(train_loader, start=0): #遍历训练集样本
            # get the inputs; data is a list of [inputs, labels]
            inputs, labels = data

            # zero the parameter gradients 清除历史梯度
            optimizer.zero_grad()
            # forward + backward + optimize
            outputs = net(inputs) #将输入图片传入到网络中
            loss = loss_function(outputs, labels) #计算损失 outputs是网络预测的值,labels是输入图片对应的标签
            loss.backward() #将loss进行反向传播
            optimizer.step() #参数更新

            # print statistics
            running_loss += loss.item() #累加损失
            if step % 500 == 499:    # print every 500 mini-batches 每隔500步打印信息
                with torch.no_grad(): #with是一个上下文管理器
                    outputs = net(test_image)  # [batch, 10] 进行正向传播
                    predict_y = torch.max(outputs, dim=1)[1] #得到预测最大的值
                    accuracy = (predict_y==test_label).sum().item() /test_label.size(0) #将预测标签与真实标签比较 ,前面得到的是tensor数据,需要使用.item()进行数据转换
                                                                                                 #除以测试样本的数量
                    print('[%d, %5d] train_loss: %.3f  test_accuracy: %.3f' %
                          (epoch + 1, step + 1, running_loss / 500, accuracy))
                    running_loss = 0.0

    print('Finished Training')
    #保存模型
    save_path = './Lenet.pth'
    torch.save(net.state_dict(), save_path)


if __name__ == '__main__':
    main()

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值