LeNet详解与数据集的训练

1. 下载并导入数据集,,需要将数据集导入至代码所在文件夹中,并命名为data。本次展示所用的数据集是cifar数据集

classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')

traindata = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=False, transform=trans)
trainloader = torch.utils.data.DataLoader(
    traindata, batch_size=2048, shuffle=True, num_workers=2)
testdata = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=False, transform=trans)
testloader = torch.utils.data.DataLoader(
    testdata, batch_size=2048, shuffle=True, num_workers=0)

2. 建立LeNet网络,网络有些调整。LeNet详细的网络参数可以参考其论文,我这边做了一些调整,是为了验证不同参数对结果的影响。

#网络:LeNet5  改版
net = torch.nn.Sequential(
    nn.Conv2d(3, 6, kernel_size=5), nn.ReLU(),
    nn.AvgPool2d(kernel_size=2, stride=2),

    nn.Conv2d(6, 16, kernel_size=3), nn.ReLU(),
    nn.AvgPool2d(kernel_size=2, stride=2), nn.Flatten(),

    nn.Linear(16*5*5, 120),
    nn.ReLU(),
    nn.Linear(120, 84),
    nn.ReLU(),
    nn.Linear(84, 10),
)

3. 设置我们的损失函数为交叉熵损失,优化算法为SGDwithMomentum,学习率为1e-3

device = torch.device("cuda:0")
print(device)
criterion = nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
net = net.to(device)

4. 开始训练

for epoch in range(50):
    net.train()

    for i, (X, Y) in enumerate(trainloader, 0):
        optimizer.zero_grad()
        X, Y = X.to(device), Y.to(device)
        gc.collect()
        torch.cuda.empty_cache()
        outputs = net(X)
        loss = criterion(outputs, Y.long())
        loss.backward()
        optimizer.step()
    print(epoch, loss.item())

    net.eval()
    with torch.no_grad():
        # test
        total_correct = 0
        total_num = 0
        for text_X, test_Y in testloader:
            text_X, test_Y = text_X.to(device), test_Y.to(device)
            logits = net(text_X)
            pred = logits.argmax(dim=1)
            total_correct += torch.eq(pred, test_Y).float().sum().item()
            total_num += text_X.size(0)
        acc = total_correct / total_num
        print("epoch: ", epoch, "acc: ", acc)

5. 效果展示:目前迭代了13轮,效果很差哈哈哈哈。在后面的训练中,我们可以通过减小学习率,来不断的优化

0 2.302259922027588
epoch:  0 acc:  0.1289
1 2.301879405975342
epoch:  1 acc:  0.1347
2 2.305971384048462
epoch:  2 acc:  0.1357
3 2.2990949153900146
epoch:  3 acc:  0.1353
4 2.2980239391326904
epoch:  4 acc:  0.1359
5 2.2999300956726074
epoch:  5 acc:  0.1348
6 2.29781436920166
epoch:  6 acc:  0.1348
7 2.297546625137329
epoch:  7 acc:  0.1354
8 2.2953741550445557
epoch:  8 acc:  0.1356
9 2.293961763381958
epoch:  9 acc:  0.1364
10 2.2932660579681396
epoch:  10 acc:  0.1381
11 2.290314197540283
epoch:  11 acc:  0.1419
12 2.2890193462371826
epoch:  12 acc:  0.1478
13 2.287813186645508
epoch:  13 acc:  0.1549

  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值