import torch
import torchvision
import torch.nn as nn
from model import LeNet
import torch.optim as optim
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
def main():
    """Train LeNet on CIFAR-10 for 5 epochs and save weights to ./Lenet.pth.

    Loads the CIFAR-10 train/test splits from ./data (set download=True on
    the first run to fetch them), trains with Adam + cross-entropy, prints
    train loss and test accuracy every 500 mini-batches, then saves the
    model's state_dict.
    """
    # ToTensor converts PIL images to CHW float tensors in [0, 1];
    # Normalize maps each channel to roughly [-1, 1] via (x - 0.5) / 0.5.
    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    # 50,000 training images. root is where the dataset is stored;
    # train=True selects the training split.
    train_set = torchvision.datasets.CIFAR10(root='./data', train=True,
                                             download=False, transform=transform)
    # shuffle=True reshuffles the training data every epoch.
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=36,
                                               shuffle=True, num_workers=0)

    # 10,000 test images, loaded as one big batch for periodic evaluation.
    # num_workers=0: worker processes are not supported on Windows.
    test_set = torchvision.datasets.CIFAR10(root='./data', train=False,
                                            download=False, transform=transform)
    test_loader = torch.utils.data.DataLoader(test_set, batch_size=10000,
                                              shuffle=False, num_workers=0)
    test_data_iter = iter(test_loader)
    # BUGFIX: iterator.next() was removed in Python 3 — use the next() builtin.
    test_image, test_label = next(test_data_iter)

    # Class-index -> name mapping ('plane' -> 0, ..., 'truck' -> 9).
    classes = ('plane', 'car', 'bird', 'cat',
               'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

    net = LeNet()
    loss_function = nn.CrossEntropyLoss()            # classification loss
    optimizer = optim.Adam(net.parameters(), lr=0.001)

    for epoch in range(5):  # loop over the dataset multiple times
        running_loss = 0.0  # loss accumulated since the last report
        for step, data in enumerate(train_loader, start=0):
            # data is a list of [inputs, labels]
            inputs, labels = data
            optimizer.zero_grad()  # clear gradients from the previous step

            # forward + backward + optimize
            outputs = net(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            if step % 500 == 499:  # evaluate every 500 mini-batches
                with torch.no_grad():  # no gradients needed for evaluation
                    outputs = net(test_image)  # [batch, 10] class scores
                    # argmax over the class dimension -> predicted labels
                    predict_y = torch.max(outputs, dim=1)[1]
                    # .item() converts the 0-dim tensor count to a Python number
                    accuracy = (predict_y == test_label).sum().item() / test_label.size(0)

                    print('[%d, %5d] train_loss: %.3f test_accuracy: %.3f' %
                          (epoch + 1, step + 1, running_loss / 500, accuracy))
                    running_loss = 0.0

    print('Finished Training')

    # Persist only the learned parameters (state_dict), not the full module.
    save_path = './Lenet.pth'
    torch.save(net.state_dict(), save_path)


if __name__ == '__main__':
    main()
# Training script (adapted from a blog post; footer text removed so the file parses).