图像分类:LeNet代码(pytorch)之train.py解读

import torch
import torchvision
import torch.nn as nn
from model import LeNet
import torch.optim as optim
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt

# def main():
transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) #Given mean: ``(mean[1],...,mean[n])`` and std: ``(std[1],..,std[n])``
# ``output[channel] = (input[channel] - mean[channel]) / std[channel]``
#
#     # 50000张训练图片
    # 第一次使用时要将download设置为True才会自动去下载数据集
train_set = torchvision.datasets.CIFAR10(root='./data', train=True,
                                             download=False,transform=transform)#这里的transform=transform,是对每一个图片进行transform处理
train_loader = torch.utils.data.DataLoader(train_set, batch_size=36,
                                               shuffle=True, num_workers=0)# 将train_set传入进来,batch_size=36表示一批拿36张图片,shuffle=True表示将数据集进行打乱打乱,num_workers=0表示载入数据的线程数,在liunx系统或者 Ubuntu系统下是可以自己定义,但在windows系统下只能num_wworks=0。
#torch.utils.data.DataLoader:是指批训练,把数据变成一小批一小批数据进行训练。DataLoader就是用来包装所使用的数据,每次抛出一批数据

#
#     # 10000张验证图片
#     # 第一次使用时要将download设置为True才会自动去下载数据集
val_set = torchvision.datasets.CIFAR10(root='./data', train=False,
                                           download=False, transform=transform)
val_loader = torch.utils.data.DataLoader(val_set, batch_size=10000,
                                             shuffle=False, num_workers=0)
#torch.utils.data.DataLoader参数:
#dataset (Dataset) – 加载数据的数据集。
# batch_size (int, optional) – 每个batch加载多少个样本(默认: 1)。
# shuffle (bool, optional) – 设置为True时会在每个epoch重新打乱数据(默认: False).
# sampler (Sampler, optional) – 定义从数据集中提取样本的策略,即生成index的方式,可以顺序也可以乱序
# num_workers (int, optional) – 用多少个子进程加载数据。0表示数据将在主进程中加载(默认: 0)
# collate_fn (callable, optional) –将一个batch的数据和标签进行合并操作。
# pin_memory (bool, optional) –设置pin_memory=True,则意味着生成的Tensor数据最开始是属于内存中的锁页内存,这样将内存的Tensor转义到GPU的显存就会更快一些。
# drop_last (bool, optional) – 如果数据集大小不能被batch size整除,则设置为True后可删除最后一个不完整的batch。如果设为False并且数据集的大小不能被batch size整除,则最后一个batch将更小。(默认: False)
# timeout,是用来设置数据读取的超时时间的,但超过这个时间还没读取到数据的话就会报错。

val_data_iter = iter(val_loader)
val_image, val_label = next(val_data_iter)
#
classes = ('plane', 'car', 'bird', 'cat',
            'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

# def imshow(img):
#     img=img/2+0.5#对图像进行反标准化处理 unnormalize   标准化处理(normalize)``output[channel] = (input[channel] - mean[channel]) / std[channel]``
#     #input=output/2+0.5                                                          #output=(input-0.5)*2=2input-1
#     nping=img.numpy()#将图像转换为numpy结构,因为再转换为tensor的过程中,Pytorch的通道排序为:[batch,channel,height,width],一般batch不填,所以channel一般在第一位,然后是高度与宽度
#     plt.imshow(np.transpose(nping,(1,2,0)))#通过np.transpose()将图像转换为原始的shape格式,原始的shape格式为[height,weight,channel]
#     plt.show()
#
# #print labels
# print(' '.join('%5s' % classes[val_label[j]] for j in range(4)))
# #show image
# imshow(torchvision.utils.make_grid(val_image))
net = LeNet()#实例化LeNet()网络
loss_function = nn.CrossEntropyLoss()#定义损失函数,CrossEntropyLoss()函数:包含了nn.LogSoftmax函数和nn.NLLLoss函数,所以不需要在网络的输出加上Softmax函数
optimizer = optim.Adam(net.parameters(), lr=0.001)#定义优化器,使用Adam优化器,net.parameters()表示将所有LeNet()可以训练的参数都进行训练,lr表示leraning rate即学习率

for epoch in range(5):  # loop over the dataset multiple times    # 表示将训练集迭代5次

    running_loss = 0.0#一个用来累加在训练过程中损失的变量
    for step, data in enumerate(train_loader, start=0):#该循环用来遍历训练集样本,enumerate用来返回每一批数据data,也可返回这一批数据data所对应的步数index
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data#将得到数据分离成输入的图像和图像所对应的标签

        # zero the parameter gradients
        optimizer.zero_grad()#将历史损失进行清零
        # forward + backward + optimize
        outputs = net(inputs)#将得到的图像输入LetNet网络当中,进行正向传播,得到输出
        loss = loss_function(outputs, labels)#计算损失,outputs就是网络预测的值,labels就是输入图片对应的真实标签
        loss.backward()#将损失进行反向传播
        optimizer.step()#进行参数更新

        # print statistics
        running_loss += loss.item()#将每次计算完的loss进行累加
        if step % 500 == 499:    # print every 500 mini-batches,每隔500步打印一次数据的信息
            with torch.no_grad():#with是一个上下文管理器,在这里的用处是
                outputs = net(val_image)  # [batch, 10]
                predict_y = torch.max(outputs, dim=1)[1]#寻找输出的最大index在什么位置,既可以理解为网络预测最可能归为那个类别,需要在维度1上寻找最大值,因为dim=0是batch,即需要在输出的10个节点中寻找最大的值。[1]表示只需要最大值的index值即索引,最后得到预测的最大值所对应的标签类别。
                accuracy = torch.eq(predict_y, val_label).sum().item() / val_label.size(0)#将预测的标签类别和真实的标签类别进行比较,比较之后,如果相等就返回1即ture,不相等就返回0即false,再通过求和函数求得在本次测试过程中预测对了多少个样本, 由于整个过程都是在tensor变量中进行计算的,计算的和是一个tensor,并不是一个数值,通过item()这个方法拿到这个数值,再除以测试样本的数目,就得到测试的准确率

                print('[%d, %5d] train_loss: %.3f  test_accuracy: %.3f' %
                          (epoch + 1, step + 1, running_loss / 500, accuracy))
                #分别打印:epoch+1是指训练在第几轮,step + 1是指在某一轮的第几步, running_loss 是指在训练过程中累加的误差,running_loss / 500是指每训练500步的平均误差,accuracy是指测试样本的准确率
                running_loss = 0.0#进行清零,然后进行下一次的迭代过程

print('Finished Training')#打印完成训练

save_path = './Lenet.pth'#为保存路径
torch.save(net.state_dict(), save_path)#将网络的所有参数进行保存,及保存在相应的路径,torch.nn.Module模块中的state_dict变量存放训练过程中需要学习的权重和偏执系数


# if __name__ == '__main__':
#     main()

if __name__=="__main__"的作用:在A模块中使用if __name__=="__main__",而B模块调用A模块即import A,则__name__变为A。在B模块调用A模块的时候会自动忽略掉A模块中的if __name__=="__main__"。而在A模块中添加if __name__=="__main__",是为了能更好的测试或者让该模块更好的执行,而不影响调用它的模块。

训练数据集处理:

1.预处理函数:transforms.Compose()

2.transforms.ToTensor():转化为张量,将一个PIL图像或者numpy的数据(H,W,C)转换为一个tensor(C,H,W),即将[0,255]区间转换为[0,1]区间

3.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]):标准化处理,防止突出数值较高的指标在综合分析中的作用。

测试数据集预处理同上

获取数据集:(获取训练集和测试集,训练集、测试集获取方式相同)

train_set = torchvision.datasets.CIFAR10(root='./data', train=True,
                                             download=False,transform=transform)

1.root='./data':把下载后的文件放在data文件夹下

2.train=True:导入数据集的训练集部分,若train=False就不会导入

3. download=False:不会下载训练集,若download=True就会下载

4.transform=transform:对训练集的图像进行预处理

train_loader = torch.utils.data.DataLoader(train_set, batch_size=36,
                                               shuffle=True, num_workers=0)

1.torch.utils.data.DataLoader参数:

dataset (Dataset) – 加载数据的数据集。

 batch_size (int, optional) – 每个batch加载多少个样本(默认: 1)。

 shuffle (bool, optional) – 设置为True时会在每个epoch重新打乱数据(默认: False).

sampler (Sampler, optional) – 定义从数据集中提取样本的策略,即生成index的方式,可以顺序也可以乱序

num_workers (int, optional) – 用多少个子进程加载数据。0表示数据将在主进程中加载(默认: 0)

 collate_fn (callable, optional) –将一个batch的数据和标签进行合并操作。

pin_memory (bool, optional) –设置pin_memory=True,则意味着生成的Tensor数据最开始是属于内存中的锁页内存,这样将内存的Tensor转义到GPU的显存就会更快一些。

drop_last (bool, optional) – 如果数据集大小不能被batch size整除,则设置为True后可删除最后一个不完整的batch。如果设为False并且数据集的大小不能被batch size整除,则最后一个batch将更小。(默认: False)

timeout,是用来设置数据读取的超时时间的,但超过这个时间还没读取到数据的话就会报错。

2.1 pytorch官方demo(Lenet)_哔哩哔哩_bilibili

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值