torch之训练过程train()
1. CPU 版本
import torchvision
from torch.utils.tensorboard import SummaryWriter
from model import *
from torch.utils.data import DataLoader
# --- Datasets ---------------------------------------------------------------
# CIFAR-10 train and test splits rooted at ./dataset; each sample goes through
# ToTensor(). download=True is passed to both constructors.
train_data = torchvision.datasets.CIFAR10(
    root='./dataset',
    train=True,
    transform=torchvision.transforms.ToTensor(),
    download=True,
)
test_data = torchvision.datasets.CIFAR10(
    root='./dataset',
    train=False,
    transform=torchvision.transforms.ToTensor(),
    download=True,
)

# Report how many samples each split holds.
train_data_size = len(train_data)
test_data_size = len(test_data)
print("训练数据集的长度为:{}".format(train_data_size))
print("测试数据集的长度为:{}".format(test_data_size))

# --- Loaders ----------------------------------------------------------------
# Both splits are served in batches of 64.
# NOTE(review): no shuffle argument is given for the training loader, so the
# DataLoader default applies — consider shuffle=True for training; confirm
# whether the fixed order is intentional.
train_dataloader = DataLoader(dataset=train_data, batch_size=64)
test_dataloader = DataLoader(dataset=test_data, batch_size=64)

# --- Model, loss, optimizer -------------------------------------------------
# FirstModel, nn and torch are not imported explicitly in this file; presumably
# they are provided by `from model import *` — verify against model.py.
firstmodel = FirstModel()
loss_fn = nn.CrossEntropyLoss()
learning_rate = 0.01
# The name `optimzer` (sic) is kept as spelled: code outside this block may
# refer to it.
optimzer = torch.optim.SGD(firstmodel.parameters(), lr=learning_rate)

# --- Bookkeeping ------------------------------------------------------------
# Step counters start at zero; `epoch` is the number of iterations of the
# training loop that follows.
total_train_step = total_test_step = 0
epoch = 10
# TensorBoard event writer targeting ./logs.
writer = SummaryWriter(log_dir="./logs")
for i in range(epoch):
print("—————————————第{}轮训练开始————————————".format(i + 1))