import torch
import torchvision
import torch.nn as nn
from model import LeNet
import torch.optim as optim
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
# def main():
transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) #Given mean: ``(mean[1],...,mean[n])`` and std: ``(std[1],..,std[n])``
# ``output[channel] = (input[channel] - mean[channel]) / std[channel]``
#
# # 50000张训练图片
# 第一次使用时要将download设置为True才会自动去下载数据集
train_set = torchvision.datasets.CIFAR10(root='./data', train=True,
download=False,transform=transform)#这里的transform=transform,是对每一个图片进行transform处理
train_loader = torch.utils.data.DataLoader(train_set, batch_size=36,
shuffle=True, num_workers=0)# 将train_set传入进来,batch_size=36表示一批拿36张图片,shuffle=True表示将数据集进行打乱打乱,num_workers=0表示载入数据的线程数,在liunx系统或者 Ubuntu系统下是可以自己定义,但在windows系统下只能num_wworks=0。
#torch.utils.data.DataLoader:是指批训练,把数据变成一小批一小批数据进行训练。DataLoader就是用来包装所使用的数据,每次抛出一批数据
#
# # 10000张验证图片
# # 第一次使用时要将download设置为True才会自动去下载数据集
val_set = torchvision.datasets.CIFAR10(root='./data', train=False,
download=False, transform=transform)
val_loader = torch.utils.data.DataLoader(val_set, batch_size=10000,
shuffle=False, num_workers=0)
#torch.utils.data.DataLoader参数:
#dataset (Dataset) – 加载数据的数据集。
# batch_size (int, optional) – 每个batch加载多少个样本(默认: 1)。
# shuffle (bool, optional) – 设置为True时会在每个epoch重新打乱数据(默认: False).
# sampler (Sampler, optional) – 定义从数据集中提取样本的策略,即生成index的方式,可以顺序也可以乱序
# num_workers (int, optional) – 用多少个子进程加载数据。0表示数据将在主进程中加载(默认: 0)
# collate_fn (callable, optional) –将一个batch的数据和标签进行合并操作。
# pin_memory (bool, optional) –设置pin_memory=True,则意味着生成的Tensor数据最开始是属于内存中的锁页内存,这样将内存的Tensor转义到GPU的显存就会更快一些。
# drop_last (bool, optional) – 如果数据集大小不能被batch size整除,则设置为True后可删除最后一个不完整的batch。如果设为False并且数据集的大小不能被batch size整除,则最后一个batch将更小。(默认: False)
# timeout,是用来设置数据读取的超时时间的,但超过这个时间还没读取到数据的话就会报错。
val_data_iter = iter(val_loader)
val_image, val_label = next(val_data_iter)
#
classes = ('plane', 'car', 'bird', 'cat',
'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
# def imshow(img):
# img=img/2+0.5#对图像进行反标准化处理 unnormalize 标准化处理(normalize)``output[channel] = (input[channel] - mean[channel]) / std[channel]``
# #input=output/2+0.5 #output=(input-0.5)*2=2input-1
# nping=img.numpy()#将图像转换为numpy结构,因为再转换为tensor的过程中,Pytorch的通道排序为:[batch,channel,height,width],一般batch不填,所以channel一般在第一位,然后是高度与宽度
# plt.imshow(np.transpose(nping,(1,2,0)))#通过np.transpose()将图像转换为原始的shape格式,原始的shape格式为[height,weight,channel]
# plt.show()
#
# #print labels
# print(' '.join('%5s' % classes[val_label[j]] for j in range(4)))
# #show image
# imshow(torchvision.utils.make_grid(val_image))
net = LeNet()#实例化LeNet()网络
loss_function = nn.CrossEntropyLoss()#定义损失函数,CrossEntropyLoss()函数:包含了nn.LogSoftmax函数和nn.NLLLoss函数,所以不需要在网络的输出加上Softmax函数
optimizer = optim.Adam(net.parameters(), lr=0.001)#定义优化器,使用Adam优化器,net.parameters()表示将所有LeNet()可以训练的参数都进行训练,lr表示leraning rate即学习率
for epoch in range(5): # loop over the dataset multiple times # 表示将训练集迭代5次
running_loss = 0.0#一个用来累加在训练过程中损失的变量
for step, data in enumerate(train_loader, start=0):#该循环用来遍历训练集样本,enumerate用来返回每一批数据data,也可返回这一批数据data所对应的步数index
# get the inputs; data is a list of [inputs, labels]
inputs, labels = data#将得到数据分离成输入的图像和图像所对应的标签
# zero the parameter gradients
optimizer.zero_grad()#将历史损失进行清零
# forward + backward + optimize
outputs = net(inputs)#将得到的图像输入LetNet网络当中,进行正向传播,得到输出
loss = loss_function(outputs, labels)#计算损失,outputs就是网络预测的值,labels就是输入图片对应的真实标签
loss.backward()#将损失进行反向传播
optimizer.step()#进行参数更新
# print statistics
running_loss += loss.item()#将每次计算完的loss进行累加
if step % 500 == 499: # print every 500 mini-batches,每隔500步打印一次数据的信息
with torch.no_grad():#with是一个上下文管理器,在这里的用处是
outputs = net(val_image) # [batch, 10]
predict_y = torch.max(outputs, dim=1)[1]#寻找输出的最大index在什么位置,既可以理解为网络预测最可能归为那个类别,需要在维度1上寻找最大值,因为dim=0是batch,即需要在输出的10个节点中寻找最大的值。[1]表示只需要最大值的index值即索引,最后得到预测的最大值所对应的标签类别。
accuracy = torch.eq(predict_y, val_label).sum().item() / val_label.size(0)#将预测的标签类别和真实的标签类别进行比较,比较之后,如果相等就返回1即ture,不相等就返回0即false,再通过求和函数求得在本次测试过程中预测对了多少个样本, 由于整个过程都是在tensor变量中进行计算的,计算的和是一个tensor,并不是一个数值,通过item()这个方法拿到这个数值,再除以测试样本的数目,就得到测试的准确率
print('[%d, %5d] train_loss: %.3f test_accuracy: %.3f' %
(epoch + 1, step + 1, running_loss / 500, accuracy))
#分别打印:epoch+1是指训练在第几轮,step + 1是指在某一轮的第几步, running_loss 是指在训练过程中累加的误差,running_loss / 500是指每训练500步的平均误差,accuracy是指测试样本的准确率
running_loss = 0.0#进行清零,然后进行下一次的迭代过程
print('Finished Training')#打印完成训练
save_path = './Lenet.pth'#为保存路径
torch.save(net.state_dict(), save_path)#将网络的所有参数进行保存,及保存在相应的路径,torch.nn.Module模块中的state_dict变量存放训练过程中需要学习的权重和偏执系数
# if __name__ == '__main__':
# main()
if __name__=="__main__"的作用:在A模块中使用if __name__=="__main__",而B模块调用A模块即import A,则__name__变为A。在B模块调用A模块的时候会自动忽略掉A模块中的if __name__=="__main__"。而在A模块中添加if __name__=="__main__",是为了能更好的测试或者让该模块更好的执行,而不影响调用它的模块。
训练数据集处理:
1.预处理函数:transforms.Compose()
2.transforms.ToTensor():转化为张量,将一个PIL图像或者numpy的数据(H,W,C)转换为一个tensor(C,H,W),即将[0,255]区间转换为[0,1]区间
3.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]):标准化处理,防止突出数值较高的指标在综合分析中的作用。
测试数据集预处理同上
获取数据集:(获取训练集和测试集,训练集、测试集获取方式相同)
train_set = torchvision.datasets.CIFAR10(root='./data', train=True,
download=False,transform=transform)
1.root='./data':把下载后的文件放在data文件夹下
2.train=True:导入数据集的训练集部分,若train=False就不会导入
3. download=False:不会下载训练集,若download=True就会下载
4.transform=transform:对训练集的图像进行预处理
train_loader = torch.utils.data.DataLoader(train_set, batch_size=36,
shuffle=True, num_workers=0)
1.torch.utils.data.DataLoader参数:
dataset (Dataset) – 加载数据的数据集。
batch_size (int, optional) – 每个batch加载多少个样本(默认: 1)。
shuffle (bool, optional) – 设置为True时会在每个epoch重新打乱数据(默认: False).
sampler (Sampler, optional) – 定义从数据集中提取样本的策略,即生成index的方式,可以顺序也可以乱序
num_workers (int, optional) – 用多少个子进程加载数据。0表示数据将在主进程中加载(默认: 0)
collate_fn (callable, optional) –将一个batch的数据和标签进行合并操作。
pin_memory (bool, optional) –设置pin_memory=True,则意味着生成的Tensor数据最开始是属于内存中的锁页内存,这样将内存的Tensor转义到GPU的显存就会更快一些。
drop_last (bool, optional) – 如果数据集大小不能被batch size整除,则设置为True后可删除最后一个不完整的batch。如果设为False并且数据集的大小不能被batch size整除,则最后一个batch将更小。(默认: False)
timeout,是用来设置数据读取的超时时间的,但超过这个时间还没读取到数据的话就会报错。