Pytorch训练cifar10数据集

下面完整代码在github仓库:传送门


一、利用卷积神经网络训练cifar10数据集

from torch.utils.data import DataLoader
from full_conv import Net
import matplotlib.pyplot as plt
import torch
from torch import nn
from torchvision import transforms, datasets
from PIL import Image
import numpy as np

if __name__ == '__main__':
    batch_size = 100
    save_params = "./net_params.pth"
    save_net = "./net.pth"

    train_data = datasets.CIFAR100("./cifar100", True, transforms.ToTensor(), download=True)
    test_data = datasets.CIFAR100("./cifar100", False, transforms.ToTensor(), download=True)

    train_loader = DataLoader(train_data, batch_size, shuffle=True)
    test_loader = DataLoader(test_data, batch_size, shuffle=True)
    # print(train_data.data.shape)  # (50000, 32, 32, 3)
    # print(np.shape(train_data.targets))  # (50000,)
    # print(train_data.classes)  # ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
    # print(train_data.train)  # True
    # print()
    # print(test_data.data.shape)  # (10000, 32, 32, 3)
    # print(np.shape(test_data.targets))  # (10000,)
    # print(test_data.classes)
    # print(test_data.train)  # False

    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    net = Net().to(device)
    net.load_state_dict(torch.load(save_params))
    # net = torch.load(save_net).to(device)

    loss_fn = nn.CrossEntropyLoss()
    optim = torch.optim.Adam(net.parameters())

    plt.ion()
    net.train()
    a = []
    b = []
    for epoch in range(10):
        for i, (x, y) in enumerate(train_loader):

            # print(x[0])
            # xs = x[0].data.numpy()  # (3, 32, 32)
            # xs = xs.transpose(1, 2, 0)  # (32, 32, 3)
            #
            # xs = (xs*0.5+0.5)*255
            # img = Image.fromarray(np.uint8(xs))
            # plt.imshow(img)
            # plt.pause(1)

            # y = torch.zeros(y.size(0), 10).scatter_(1, y.reshape(-1, 1), 1)
            x = x.to(device)
            y = y.to(device)

            output = net(x).reshape(y.size(0), -1)
            loss = loss_fn(output, y)

            optim.zero_grad()
            loss.backward()
            optim.step()

            if i % 10 == 0:
                plt.clf()
                a.append(i + epoch*(len(train_data) / batch_size))
                b.append(loss.item())
                plt.plot(a, b)
                plt.pause(1)

                print("epoch:{}, loss:{}".format(epoch, loss.item()))

        torch.save(net.state_dict(), save_params)  # 每一轮保存网络参数
        # torch.save(net, save_net)

    net.eval()
    eval_loss = 0
    eval_acc = 0
    for i, (x, y) in enumerate(test_loader):
        # y = torch.zeros(y.size(0), 10).scatter_(1, y.reshape(-1, 1), 1)
        x = x.to(device)
        y = y.to(device)

        out = net(x).reshape(y.size(0), -1)
        loss = loss_fn(out, y)

        eval_loss += loss.item() * batch_size

        # max_y = torch.argmax(y, 1)
        max_out = torch.argmax(out, 1)
        eval_acc += (max_out == y).sum().item()

    mean_loss = eval_loss / len(test_data)
    mean_acc = eval_acc / len(test_data)
    print("mean loss:{}, mean acc:{}".format(mean_loss, mean_acc))

二、查看1D卷积、2D卷积、3D卷积

import torch
import torch.nn as nn
import numpy as np

a = torch.randn(100, 1, 20)  # N, C, L
conv_1d = nn.Conv1d(1, 16, 3, 2, 1)

b = conv_1d(a)
# print(b.shape)

rng = np.random.RandomState(0)
X = 5*rng.randn(100, 1, 10)

print(X)
import torch
import torch.nn as nn

a = torch.randn(100, 1, 20, 20)  # N, C, H, W
conv_2d = nn.Conv2d(1, 16, (3, 3), 2, 1)  # 卷积核大小为3*3

b = conv_2d(a)
print(b.shape)
import torch
import torch.nn as nn

a = torch.randn(100, 3, 9, 224, 224)  # N, C, D, H, W
conv_3d = nn.Conv3d(3, 16, (3, 3, 3), 2, 1)  # 卷积核大小为3*3,2*2

b = conv_3d(a)
print(b.shape)
  • 2
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值