# PyTorch training: handwritten-digit recognition (MNIST)

import torch
import torch.nn as nn
import torchvision.models as models
import torch.optim as optim
import torchvision
import matplotlib.pyplot as plt

# Probe once for a CUDA GPU and report which device training will use.
train_on_gpu = torch.cuda.is_available()
if train_on_gpu:
    print('CUDA is available! Training on GPU...')
else:
    print('CUDA is not available. Training on CPU...')

# Every tensor and module below is moved onto this device.
device = torch.device("cuda" if train_on_gpu else "cpu")

batch_size = 512

# Shared MNIST preprocessing: tensor conversion plus normalization with
# the dataset's published mean/std.
mnist_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.1307,), (0.3081,)),
])

# Training split: shuffled each epoch.
train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('mnist data', train=True, download=True,
                               transform=mnist_transform),
    batch_size=batch_size, shuffle=True)

# Test split: fixed order for reproducible evaluation.
test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('mnist data', train=False, download=True,
                               transform=mnist_transform),
    batch_size=batch_size, shuffle=False)

# ResNet-18 backbone. NOTE: models.resnet18() with no weights argument is
# RANDOMLY initialized. The original code froze every backbone parameter
# here, which left the network stuck with random features — freezing only
# makes sense with pretrained weights, so the freeze loop is removed and
# the whole model is trained end to end.
model = models.resnet18()

num_classes = 10

# MNIST images are single-channel; replace the 3-channel stem convolution.
model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
# New classification head. CrossEntropyLoss (below) applies log-softmax
# internally, so the head must emit raw logits — the original trailing
# LogSoftmax layer was redundant and is removed.
model.fc = nn.Sequential(
    nn.Dropout(),
    nn.Linear(model.fc.in_features, num_classes),
)
model.to(device)
# print(model)

optimizer = optim.Adam(model.parameters(), lr=0.001)
# Expects raw logits and integer class labels.
criterion = nn.CrossEntropyLoss()

filename = "recognize_handwritten_digits.pt"


def save_checkpoint(epoch, model, optimizer, filename, loss=None):
    """Serialize training state so a run can be resumed later.

    Args:
        epoch: index of the epoch just completed.
        model: network whose ``state_dict`` is saved.
        optimizer: its optimizer (state included so e.g. Adam moments survive).
        filename: path or file-like object handed to ``torch.save``.
        loss: loss value to record in the checkpoint. The original
            implementation silently read a module-level ``loss`` global
            (whatever batch loss happened to be current); it is now an
            explicit, optional parameter so existing call sites still work.
    """
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
    }
    torch.save(checkpoint, filename)


num_epochs = 50
train_loss = []  # one entry per training batch (not per epoch)
for epoch in range(num_epochs):
    # --- training phase ---
    # Dropout lives in model.fc, so the train/eval mode switch matters:
    # without model.train()/model.eval() the test accuracy is stochastic.
    model.train()
    running_loss = 0
    correct = 0
    total = 0
    for inputs, labels in train_loader:
        # Move the batch onto the training device.
        inputs, labels = inputs.to(device), labels.to(device)

        # Forward pass.
        outputs = model(inputs)

        # Loss and gradients.
        loss = criterion(outputs, labels)
        optimizer.zero_grad()
        loss.backward()

        # Parameter update.
        optimizer.step()

        # Bookkeeping: per-batch loss and running accuracy.
        running_loss += loss.item()
        train_loss.append(loss.item())
        _, predicted = torch.max(outputs, 1)
        correct += (predicted == labels).sum().item()
        total += labels.size(0)
    accuracy_train = 100 * correct / total

    # --- evaluation phase: no gradients, deterministic (eval-mode) Dropout ---
    model.eval()
    with torch.no_grad():
        running_loss_test = 0
        correct_test = 0
        total_test = 0
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            running_loss_test += loss.item()

            _, predicted = torch.max(outputs, 1)
            correct_test += (predicted == labels).sum().item()
            total_test += labels.size(0)
        accuracy_test = 100 * correct_test / total_test

    # Per-epoch summary (the original printed "Loss:" twice with no
    # train/test distinction and a missing space).
    print("Epoch [{}/{}], Train Loss: {:.4f}, Train Accuracy: {:.2f}%, Test Loss: {:.4f}, Test Accuracy: {:.2f}%"
          .format(epoch + 1, num_epochs, running_loss / len(train_loader),
                  accuracy_train, running_loss_test / len(test_loader), accuracy_test))
    save_checkpoint(epoch, model, optimizer, filename)

plt.plot(train_loss, label='Train Loss')
# Add legend and axis labels. train_loss holds one value per *batch*,
# so the x-axis is batches — the original 'Epochs' label was wrong.
plt.legend()
plt.xlabel('Batch')
plt.ylabel('Loss')
plt.title('Training Loss')

# Render the figure.
plt.show()

# (Removed: CSDN web-page boilerplate — like/favorite buttons, survey and
# payment-widget text — accidentally captured when this script was scraped.
# It was not part of the program and made the file invalid Python.)