import time

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader
# Hyperparameters
EPOCH = 1000       # number of full passes over the training set
BATCH_SIZE = 64    # samples per mini-batch
LR = 0.001         # Adam learning rate

# CIFAR-10 dataset, images converted to float tensors in [0, 1].
# ToTensor is stateless, so one shared instance serves both splits.
_to_tensor = transforms.ToTensor()
train_data = datasets.CIFAR10(root='/root/cifar10', train=True,
                              transform=_to_tensor, download=True)
test_data = datasets.CIFAR10(root='/root/cifar10', train=False,
                             transform=_to_tensor, download=True)
# (A commented-out sample-image visualization using matplotlib was removed
# here — dead code; re-add it if image inspection is needed.)

# Mini-batch loaders. The training set is reshuffled every epoch; the test
# set is NOT shuffled — evaluation order should be fixed and reproducible
# (the original shuffled it, which adds nondeterminism for no benefit).
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(dataset=test_data, batch_size=BATCH_SIZE, shuffle=False)
# Model: ResNet-18 trained from scratch.
# BUG FIX: torchvision's resnet18 ships with a 1000-way ImageNet classifier
# head, but CIFAR-10 has 10 classes. Replace the final fully-connected layer
# BEFORE constructing the optimizer, so the new head's parameters are
# included in model.parameters().
model = torchvision.models.resnet18(pretrained=False)
model.fc = nn.Linear(model.fc.in_features, 10)

# Loss: cross-entropy over the 10 class logits.
criterion = nn.CrossEntropyLoss()
# Optimizer: Adam over all parameters (including the replaced head).
optimizer = optim.Adam(model.parameters(), lr=LR)

# Run on the first GPU when available, otherwise fall back to CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model.to(device)
# Training loop
for epoch in range(EPOCH):
    model.train()  # ensure BatchNorm/Dropout are in training mode
    start_time = time.time()
    for i, data in enumerate(train_loader):
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)
        # Forward pass
        outputs = model(inputs)
        # Compute the loss
        loss = criterion(outputs, labels)
        # Clear gradients accumulated from the previous step
        optimizer.zero_grad()
        # Backward pass
        loss.backward()
        # Parameter update
        optimizer.step()
    # BUG FIX: the original format string had only two placeholders, so the
    # elapsed-time argument was silently dropped by str.format; add a slot.
    # (loss here is the last mini-batch's loss of the epoch.)
    print('epoch{:} loss:{:.4f} time:{:.2f}s'.format(
        epoch + 1, loss.item(), time.time() - start_time))

# Save the trained model (the whole module object, matching the torch.load
# call used for evaluation below).
file_name = 'cifar10_resnet.pt'
torch.save(model, file_name)
print('model saved')
# Evaluation
# SECURITY NOTE: torch.load unpickles arbitrary objects — only load
# checkpoints you trust; prefer saving/loading a state_dict in new code.
model = torch.load(file_name)
model.eval()  # inference mode for BatchNorm / Dropout
correct, total = 0, 0
# BUG FIX: wrap evaluation in no_grad() — the original built autograd graphs
# for every test batch, wasting memory and time.
with torch.no_grad():
    for data in test_loader:
        images, labels = data
        images, labels = images.to(device), labels.to(device)
        # Forward pass
        out = model(images)
        # Predicted class = index of the max logit per sample
        _, predicted = torch.max(out.data, 1)
        total = total + labels.size(0)
        correct += (predicted == labels).sum().item()
print('10000张测试图像 准确率{:.4}%'.format(100.0 * correct / total))
# PyTorch CIFAR-10 classification
# (blog-page residue removed: the original paste ended with the article
# title and a "published 2024-06-25" timestamp, which broke Python syntax)