B站学习小土堆视频进行完整的模型训练
模型训练代码
import torch
import torchvision.datasets
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from model import *
#准备数据集
train_data = torchvision.datasets.CIFAR10("./data",train=True,transform=torchvision.transforms.ToTensor(),download=True)
test_data = torchvision.datasets.CIFAR10("./data",train=False,transform=torchvision.transforms.ToTensor(),download=True)
#length 长度
train_data_size = len(train_data)
test_data_size = len(test_data)
#利用DataLoader加载数据
train_dataloader = DataLoader(train_data,batch_size=64)
test_dataloader = DataLoader(test_data,batch_size=64)
#创建网络模型
tudui = Tudui()
#损失函数
loss_fn = nn.CrossEntropyLoss()
#优化器
learning_rate = 1e-2
optimizer = torch.optim.SGD(tudui.parameters(),lr=learning_rate)
#设置训练网络的参数
#记录训练的次数
total_train_step = 0
#记录测试的次数
total_test_step = 0
#训练的轮数
epoch = 10
#添加tensorboard
writer = SummaryWriter("logs_train")
for i in range(epoch):
print("------第 {} 轮训练开始------".format(i+1))
#训练步骤开始
tudui.train()#
for data in train_dataloader:
imgs,targets = data
output = tudui(imgs)
loss = loss_fn(output,targets)
#优化器优化模型
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_train_step = total_train_step + 1
if total_train_step % 100 == 0:
print("训练次数:{}, Loss:{} ".format(total_train_step,loss.item()))
writer.add_scalar("train_loss",loss.item(),total_train_step)
#测试步骤开始
tudui.eval()#
total_test_loss = 0
total_test_accuracy = 0
#测试阶段不需要梯度
with torch.no_grad():
for data in test_dataloader:
imgs, targets = data
output = tudui(imgs)
loss = loss_fn(output,targets)
total_test_loss = total_test_loss + loss.item()
accuarcy = (output.argmax(1) == targets).sum()
total_test_accuracy = total_test_accuracy + accuarcy
print("整体测试集上Loss:{}".format(total_test_loss))
print("整体测试集上的正确率:{} ".format(total_test_accuracy/test_data_size))
writer.add_scalar("test_loss",total_test_loss,total_test_step)
writer.add_scalar("test_accuarcy",total_test_accuracy/test_data_size,total_test_step)
total_test_step = total_test_step + 1
torch.save(tudui,"tudui_{}.pth".format(i))
print("模型已保存")
writer.close()