在Pytorch中使用mnist数据集做分类案例

案例:

import torch
import torchvision.datasets as dataset
import torchvision.transforms as transforms
import torch.utils.data as data_utils
import cv2

# data
train_data = dataset.MNIST(root='./mnist',
                           train=True,
                           transform=transforms.ToTensor(),
                           download=True)

test_data = dataset.MNIST(root='./mnist',
                           train=False,
                           transform=transforms.ToTensor(),
                           download=False)

# batchsize
train_loader = data_utils.DataLoader(dataset=train_data,
                                     batch_size=64,
                                     shuffle=True)

test_loader = data_utils.DataLoader(dataset=test_data,
                                     batch_size=64,
                                     shuffle=True)


# net
class CNN(torch.nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels=1, out_channels=32, kernel_size=5, padding=2),
            torch.nn.BatchNorm2d(num_features=32),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2)
        )
        self.fc = torch.nn.Linear(in_features=14 * 14 * 32, out_features=10)

    def forward(self, x):
        out = self.conv(x)
        out = out.view(out.size()[0], -1)  # 输入数据整成全连接需要的数据格式
        out = self.fc(out)
        return out


cnn = CNN()

# loss
loss_func = torch.nn.CrossEntropyLoss()
# optimizer
optimizer = torch.optim.Adam(cnn.parameters(), lr=0.001)

# train
for epoch in range(10):
    for i, (images, lables) in enumerate(train_loader):
        # images = images.cuda()
        # lables = lables.cuda()
        # images为黑边图,通道数为1
        outputs = cnn(images)
        loss = loss_func(outputs, lables)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print('epoch is {}, ite is {}/{}, loss is {}'.format(
            epoch + 1, i, len(train_data)//64, loss.item()))


    # eval
    loss_test = 0
    acc = 0
    for i, (images, lables) in enumerate(test_loader):
        # images = images.cuda()
        # lables = lables.cuda()
        outputs = cnn(images)
        loss_test += loss_func(outputs, lables)
        _, pred = outputs.max(1) # 第一维度的最大值
        acc += (pred == lables).sum().item()

    acc = acc / len(test_data)
    loss_test = loss_test / (len(test_data) // 64)

    print('epoch:{}, acc:{}, loss:{}'.format(epoch + 1,acc, loss_test.item()))

# 保存模型
torch.save(cnn, './model/class.pkl')

# 加载模型做测试
# test/eval
cnn = torch.load('./model/class.pkl')
loss_test = 0
acc = 0
for i, (images, lables) in enumerate(test_loader):
    # images = images.cuda()
    # lables = lables.cuda()
    outputs = cnn(images)
    loss_test += loss_func(outputs, lables)
    _, pred = outputs.max(1)  # 第一维度的最大值
    acc += (pred == lables).sum().item()

    # 画图
    for idx in range(images.shape[0]):
        im_data = images[idx]
        im_label = lables[idx]
        im_data = im_data.transpose(1,2,0)
        cv2.imshow('imdata', im_data)
        cv2.waitKey(0)

acc = acc / len(test_data)
print('epoch:{}, acc:{}'.format(epoch + 1, acc))

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

浅蓝的风

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值